基于redis实现的ttl-cache

有时候我们希望函数调用的缓存能够夸进程使用或者在程序重启后仍然有效。

一个显然的想法是缓存到redis,于是就有了下面的代码

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118

import os
import functools
import redis
from redis.connection import ConnectionPool
import msgpack
import msgpack_numpy
import numpy as np
import pandas as pd
import io

class _HashedSeq(list):
__slots__ = 'hashvalue'

def __init__(self, tup, hash=hash):
self[:] = tup
self.hashvalue = hash(tup)

def __hash__(self):
return self.hashvalue


def _make_key(args, kw, typed=False):
key = args
if kw:
kwd_mark = (object(),)
key += kwd_mark
for item in kw.items():
key += item
if typed:
key += tuple(type(v) for v in args)
if kw:
key += tuple(type(v) for v in kw.values())
elif len(key) == 1 and type(key[0]) in {int, str}:
return key[0]
return _HashedSeq(key)


def dumps_to_feather(df):
columns = df.columns
df.columns = [str(e) for e in df.columns]
buffer = io.BytesIO() # 创建一个内存中的字节流缓冲区
df.to_feather(buffer) # 将 DataFrame 序列化为 Feather 格式并存储到缓冲区
buffer.seek(0) # 重新定位到字节流的开头
serialized_data = buffer.getvalue() # 获取序列化后的二进制数据
return {"data": serialized_data, "columns": list(columns)}


def loads_from_feather(data):
buffer = io.BytesIO(data["data"]) # 创建一个内存中的字节流缓冲区,并加载序列化的数据
df = pd.read_feather(buffer) # 从 Feather 格式的序列化数据中加载 DataFrame
df.columns = data["columns"]
return df


def _custom_encode(obj):
if isinstance(obj, set):
return {b'__set__': list(obj)} # 将 set 对象转换为字典形式
if isinstance(obj, pd.DataFrame):
return {b'__feather__': dumps_to_feather(obj)}
return msgpack_numpy.encode(obj)


def _custom_decode(obj):
if b'__set__' in obj: # 判断是否是我们之前转换过的 set 对象
return set(obj[b'__set__']) # 将字典形式的 set 转换回原始的 set 对象
if b'__feather__' in obj:
return loads_from_feather(obj[b'__feather__'])
return msgpack_numpy.decode(obj) # 不需要转换的情况下直接返回原始对象


class RedisTTLCache:
def __init__(self, host='localhost', port=6379):
self.pool = ConnectionPool(host=host, port=port)

def cache(self, ttl, scope=os.path.basename(__file__)):
def decorator(func):
@functools.wraps(func)
def wrapper(*args, **kwargs):
# 构建唯一的缓存键
cache_key = f"FUNC_CACHE_{(scope, func.__name__, _make_key(args, kwargs))}"

# 尝试从Redis中获取缓存
redis_client = redis.Redis(connection_pool=self.pool)
result = redis_client.get(cache_key)
if result is not None:
# 如果缓存存在,直接返回结果
result = msgpack.loads(result, object_hook=_custom_decode)
return result

# 如果缓存中不存在该结果,则重新计算函数结果
result = func(*args, **kwargs)

# 将计算结果存入Redis缓存
redis_client.set(cache_key, msgpack.dumps(result, default=_custom_encode), ex=ttl)

return result
return wrapper
return decorator


if __name__ == '__main__':
redis_ttl_cache = RedisTTLCache(host="localhost")
ttl_cached = redis_ttl_cache.cache

@ttl_cached(ttl=5)
def func(a, b):
print("Hello world")
return {
'result': a + b,
'colors': {'red', 'green', 'blue'},
'arr': np.array([[0, 1], [2,3]]),
'df': pd.DataFrame([[0, 1], [2,3]]),
}

print(func(1, 2))
print(func(1, 2))

这里的主要问题是使用什么作为redis的key来确保真正相同的调用使用到相同的缓存。

如果我们在key中拼入进程ID,那么则是非常保守的,不同进程将无法使用到相同的缓存。

我这里默认使用 脚本文件名+函数名+参数 作为key,但是或许有两个相同脚本名的脚本中定义了同名的但功能不同的函数,这时候不应该使用相同的缓存。

总之我们只能折中处理,因为我的目的就是跨进程使用缓存,所以没有特别保守,同时提供了scope参数,可以自定义一个拼入key的标识,在必要时由用户决定能否使用相同的缓存。