1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118
| import os import functools import redis from redis.connection import ConnectionPool import msgpack import msgpack_numpy import numpy as np import pandas as pd import io
class _HashedSeq(list): __slots__ = 'hashvalue'
def __init__(self, tup, hash=hash): self[:] = tup self.hashvalue = hash(tup)
def __hash__(self): return self.hashvalue
def _make_key(args, kw, typed=False): key = args if kw: kwd_mark = (object(),) key += kwd_mark for item in kw.items(): key += item if typed: key += tuple(type(v) for v in args) if kw: key += tuple(type(v) for v in kw.values()) elif len(key) == 1 and type(key[0]) in {int, str}: return key[0] return _HashedSeq(key)
def dumps_to_feather(df):
    """Serialize a DataFrame to Feather bytes plus its original column labels.

    Feather only accepts string column names, so the labels are stringified
    for writing and the originals are returned alongside the bytes so that
    ``loads_from_feather`` can restore them. The caller's DataFrame is not
    modified.

    Args:
        df: The ``pd.DataFrame`` to serialize.

    Returns:
        dict with ``"data"`` (Feather-encoded ``bytes``) and ``"columns"``
        (list of the original column labels).
    """
    columns = list(df.columns)
    # Relabel a shallow copy (data shared, axes not): assigning to
    # ``df.columns`` directly would leak string labels back to the caller,
    # e.g. to the uncached return value of a cached function.
    renamed = df.copy(deep=False)
    renamed.columns = [str(label) for label in columns]
    buffer = io.BytesIO()
    renamed.to_feather(buffer)
    return {"data": buffer.getvalue(), "columns": columns}
def loads_from_feather(data):
    """Rebuild a DataFrame from a ``{"data": ..., "columns": ...}`` payload.

    Args:
        data: Mapping whose ``"data"`` entry holds Feather-encoded bytes and
            whose ``"columns"`` entry holds the column labels to apply over
            the ones stored in the Feather bytes.

    Returns:
        The deserialized ``pd.DataFrame`` with ``"columns"`` as its labels.
    """
    payload, labels = data["data"], data["columns"]
    frame = pd.read_feather(io.BytesIO(payload))
    frame.columns = labels
    return frame
def _custom_encode(obj): if isinstance(obj, set): return {b'__set__': list(obj)} if isinstance(obj, pd.DataFrame): return {b'__feather__': dumps_to_feather(obj)} return msgpack_numpy.encode(obj)
def _custom_decode(obj): if b'__set__' in obj: return set(obj[b'__set__']) if b'__feather__' in obj: return loads_from_feather(obj[b'__feather__']) return msgpack_numpy.decode(obj)
class RedisTTLCache:
    """Decorator factory that memoizes function results in Redis with a TTL.

    Results are serialized with msgpack (using ``_custom_encode`` /
    ``_custom_decode``) and stored under the string key
    ``"FUNC_CACHE_" + str((scope, func.__name__, _make_key(args, kwargs)))``.
    """

    def __init__(self, host='localhost', port=6379):
        # Shared pool; every wrapped call builds a lightweight client on it.
        self.pool = ConnectionPool(host=host, port=port)

    def cache(self, ttl, scope=os.path.basename(__file__)):
        """Return a decorator that caches the wrapped function's results.

        Args:
            ttl: Passed to Redis ``SET`` as ``ex`` (expiry in seconds).
            scope: Namespace mixed into the key. The default is evaluated once,
                when this class is defined, as the basename of this module's
                file, not the decorated function's module.
                NOTE(review): functions with the same ``__name__`` in different
                modules therefore share keys unless ``scope`` is passed.
                NOTE(review): msgpack returns tuples as lists, so a cache hit
                may differ in container type from a fresh result.
        """
        def decorator(func):
            @functools.wraps(func)
            def wrapper(*args, **kwargs):
                key_parts = (scope, func.__name__, _make_key(args, kwargs))
                redis_key = "FUNC_CACHE_" + str(key_parts)
                client = redis.Redis(connection_pool=self.pool)
                cached = client.get(redis_key)
                if cached is None:
                    # Miss: compute, store with expiry, return the fresh value.
                    fresh = func(*args, **kwargs)
                    client.set(
                        redis_key,
                        msgpack.dumps(fresh, default=_custom_encode),
                        ex=ttl,
                    )
                    return fresh
                return msgpack.loads(cached, object_hook=_custom_decode)

            return wrapper

        return decorator
if __name__ == '__main__':
    # Demo (needs a Redis server reachable at localhost:6379). The first call
    # runs the function and prints "Hello world"; the second call, made within
    # the 5-second TTL, is answered from Redis if the entry is still present.
    backend = RedisTTLCache(host="localhost")

    @backend.cache(ttl=5)
    def func(a, b):
        print("Hello world")
        return {
            'result': a + b,
            'colors': {'red', 'green', 'blue'},
            'arr': np.array([[0, 1], [2, 3]]),
            'df': pd.DataFrame([[0, 1], [2, 3]]),
        }

    for _ in range(2):
        print(func(1, 2))
|