基于Redis的分布式锁

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
# -*- coding: utf-8 -*-
from redis import Redis
import socket
import os
import threading
import time
import uuid
from contextlib import ExitStack, contextmanager

# 加锁时间片
# 对于超过一个加锁时间片的锁应该重复使用expire设置过期时间
# 为的是在程序异常结束时锁能被尽快自动释放,所以时间片不应太长,
# 也不应设置过短,不然一是会频繁的访问redis续过期时长导致资源浪费,
# 二是在很卡的时候负责续时的线程可能无法正常续时。
# 建议设置在5~20秒
LOCK_TIME_SLICE = 10


class KeepLockThread(threading.Thread):

def __init__(self, redis, key_name):
super().__init__()
self.redis = redis
self.key_name = key_name

self.terminated = False
self.lock = threading.Lock()

def run(self):
while True:
time.sleep(LOCK_TIME_SLICE//3+1)

with self.lock:
if not self.terminated:
self.redis.expire(self.key_name, LOCK_TIME_SLICE)
else:
return

def stop(self):
with self.lock:
self.terminated = True


# 互斥锁
class Mutex:
def __init__(self, name, server="127.0.0.1"):
self.name = name
self.key_name = "MUTEX_" + name
hostname = socket.gethostname()
pid = os.getpid()
tid = threading.get_ident()
# 得到唯一ID,主机名 PID TID 是为了调试用
self.id = f"{hostname}_{pid}_{tid}_{uuid.uuid4().hex}"
self.redis = Redis(host=server)
self.thread = None

# blocking: 是否阻塞
# shortlived: 是否短暂的加锁,如果是短暂的则不会创建额外的线程维护过期时间,所谓短暂就是确定加锁时间小于LOCK_TIME_SLICE
# 加锁时间是从获得锁到释放锁的时间,不包括阻塞等待的时间,可以阻塞等待很长时间仍然是短暂的加锁。
def acquire(self, blocking=True, shortlived=False):
r = self.redis.set(self.key_name, self.id, ex=LOCK_TIME_SLICE, nx=True)
if blocking:
while not r:
time.sleep(0.05)
r = self.redis.set(self.key_name, self.id, ex=LOCK_TIME_SLICE, nx=True)
if r:
if not shortlived:
# 如果加锁成功了,那么应该创建线程维护过期时间
self.thread = KeepLockThread(self.redis, self.key_name)
self.thread.daemon = True
self.thread.start()
return r

def release(self):
if self.acquired():
if self.thread:
self.thread.stop()
self.thread = None
self.redis.delete(self.key_name)

def acquired(self):
r = self.redis.get(self.key_name)
if r is not None and r.decode() == str(self.id):
return True
else:
return False

def __enter__(self):
self.acquire()

def __exit__(self, exc_type, exc_value, exc_trackback):
self.release()
if exc_value is not None:
raise exc_value


def mutex(names, server="127.0.0.1"):
def medium(func):
def wrapper(*args, **kw):
if type(names) == list:
with ExitStack() as stack:
for name in names:
stack.enter_context(Mutex(name))
return func(*args, **kw)
else:
with Mutex(names, server):
return func(*args, **kw)
wrapper.__name__ = func.__name__
wrapper.__doc__ = func.__doc__
return wrapper
return medium


class KeepReadLockThread(threading.Thread):

def __init__(self, redis, lock_id, rlock_name):
super().__init__()
self.redis = redis
self.id = lock_id
self.rlock_name = rlock_name

self.terminated = False
self.lock = threading.Lock()

def run(self):
while True:
time.sleep(LOCK_TIME_SLICE//3+1)

with self.lock:
if not self.terminated:
pipeline = self.redis.pipeline()
now = time.time()
pipeline.zadd(self.rlock_name, {self.id: int(now*1000)} )
pipeline.expire(self.rlock_name, LOCK_TIME_SLICE)
pipeline.execute()
else:
return

def stop(self):
with self.lock:
self.terminated = True


# 读写锁
class ReadWriteLock:

def __init__(self, name, server="127.0.0.1"):
self.name = name
self.server = server
self.rlock_name = "RLOCK_" + name
self.wlock_name = "WLOCK_" + name
self.meta_lock_name = "META_" + name

hostname = socket.gethostname()
pid = os.getpid()
tid = threading.get_ident()
# 得到唯一ID,主机名 PID TID 是为了调试用
self.id = f"{hostname}_{pid}_{tid}_{uuid.uuid4().hex}"

self.redis = Redis(host=server)
self.lock_type = None
self.thread = None

def acquire_read_lock(self, blocking=True, shortlived=False):
# 这个mutex用来维护读写锁内部数据本身使用,在离开ReadWriteLock函数前必须释放
mutex = Mutex(self.meta_lock_name, self.server)
try:
mutex.acquire(shortlived=True)
wlock_locked = self.redis.get(self.wlock_name)
if wlock_locked:
if blocking:
while wlock_locked:
mutex.release()
time.sleep(0.05)
mutex.acquire(shortlived=True)
wlock_locked = self.redis.get(self.wlock_name)
else:
return False

pipeline = self.redis.pipeline()
now = time.time()
# 移除一定时间前失效的读锁(主动处理异常未释放的读锁)
pipeline.zremrangebyscore(self.rlock_name, 0, int((now-LOCK_TIME_SLICE)*1000) )
# 添加新的读锁
pipeline.zadd(self.rlock_name, {self.id: int(now*1000)} )
pipeline.expire(self.rlock_name, LOCK_TIME_SLICE)
pipeline.execute()
self.lock_type = 'R'

if not shortlived:
# 创建线程维护读锁过期时间
self.thread = KeepReadLockThread(self.redis, self.id, self.rlock_name)
self.thread.daemon = True
self.thread.start()

return True
finally:
mutex.release()

def acquire_write_lock(self, blocking=True, shortlived=False):
mutex = Mutex(self.meta_lock_name, self.server)
try:
# 注意这里使用mutex时shortlived始终为True,它不是由外部使用读写锁是否短暂决定
# 而是加读写锁时内部需要用到mutex进行短暂加锁,以保证获取到读锁为0后到加写锁前读锁不会新增
mutex.acquire(shortlived=True)
# 移除一定时间前失效的读锁(主动处理异常未释放的读锁)
self.redis.zremrangebyscore(self.rlock_name, 0, int((time.time()-LOCK_TIME_SLICE)*1000) )
# 获取当前读锁数量
r = self.redis.zcard(self.rlock_name)
if r:
if blocking:
while r:
mutex.release()
time.sleep(0.05)
mutex.acquire(shortlived=True)
# 移除一定时间前失效的读锁(主动处理异常未释放的读锁)
self.redis.zremrangebyscore(self.rlock_name, 0, int((time.time()-LOCK_TIME_SLICE)*1000) )
# 获取当前读锁数量
r = self.redis.zcard(self.rlock_name)
else:
return False
# 尝试获取写锁
r = self.redis.set(self.wlock_name, self.id, ex=LOCK_TIME_SLICE, nx=True)
if blocking:
while not r:
mutex.release()
time.sleep(0.05)
# 没有获取到写锁,那么可能读锁又会被其他地方获取,所以要重新先等到读锁为0
mutex.acquire(shortlived=True)
self.redis.zremrangebyscore(self.rlock_name, 0, int((time.time()-LOCK_TIME_SLICE)*1000) )
r = self.redis.zcard(self.rlock_name)
if r:
if blocking:
while r:
mutex.release()
time.sleep(0.05)
mutex.acquire(shortlived=True)
self.redis.zremrangebyscore(self.rlock_name, 0, int((time.time()-LOCK_TIME_SLICE)*1000) )
r = self.redis.zcard(self.rlock_name)
else:
return False
r = self.redis.set(self.wlock_name, self.id, ex=LOCK_TIME_SLICE, nx=True)
if r:
if not shortlived:
self.thread = KeepLockThread(self.redis, self.wlock_name)
self.thread.daemon = True
self.thread.start()

self.lock_type = 'W'
return r
finally:
mutex.release()

def acquired(self, lock_type = "any"):
if lock_type == "any":
if self.lock_type is not None:
return True
else:
return False
else:
if self.lock_type == lock_type:
return True
else:
return False

def release(self):
if self.thread:
self.thread.stop()
self.thread = None
if self.lock_type == 'R':
self.redis.zrem(self.rlock_name, self.id)
self.lock_type = None
elif self.lock_type == 'W':
r = self.redis.get(self.wlock_name)
if r is not None and r.decode() == str(self.id):
self.redis.delete(self.wlock_name)
self.lock_type = None


@contextmanager
def read_lock(name, server="127.0.0.1"):
rwlock = ReadWriteLock(name=name)
try:
rwlock.acquire_read_lock()
yield rwlock
finally:
rwlock.release()


@contextmanager
def write_lock(name, server="127.0.0.1"):
rwlock = ReadWriteLock(name=name)
try:
rwlock.acquire_write_lock()
yield rwlock
finally:
rwlock.release()

Mutex支持阻塞和非阻塞加锁,默认阻塞,非阻塞用法。
考虑了锁住后崩溃,解决方案是利用redis设置过期时间,超时后自动删除,一次长时间的加锁会通过一个线程不断的续时,如果加锁后崩溃,锁会在一个时间片内的时间被自动释放。

又基于Mutex实现了一个ReadWriteLock,也支持阻塞和非阻塞,以及崩溃后自动释放。