1 import redis
2 import time
3 from typing import Dict, List, Tuple, Any, Optional
4
5 from config.model import settings
6 from pydantic import BaseModel
7
8
class StreamMessage(BaseModel):
    """A single Redis stream entry: its ID plus its field/value mapping."""

    message_id: str  # stream entry ID ("<ms>-<seq>" when auto-generated by XADD)
    message_data: Dict[str, Any]  # the entry's fields as returned by Redis
12
13
class StreamMessages(BaseModel):
    """A batch of entries read from one stream."""

    stream_name: str  # key of the stream the entries belong to
    messages: List[StreamMessage]  # entries in the order Redis returned them
17
18
class RedisStreamManager:
    """Manage one Redis stream and its consumer groups.

    Wraps the redis-py stream commands (XADD, XREADGROUP, XACK, XPENDING,
    XCLAIM, XRANGE) for a single stream key and converts read results into
    ``StreamMessages`` models.
    """

    def __init__(self, stream_name: str, redis_url: str = "", timeout: int = 2, max_length: int = 100000):
        """Connect to Redis and trim the stream to ``max_length`` entries.

        Args:
            stream_name: Key of the Redis stream to manage.
            redis_url: Connection URL; ``settings.OMS_REDIS_URL`` is used when empty.
            timeout: Seconds a delivered-but-unacknowledged message may stay idle
                before it counts as timed out (also the XCLAIM min idle time).
            max_length: MAXLEN applied by XTRIM here and by XADD in ``add``.

        Raises:
            redis.ConnectionError: If the server does not answer PING.
        """
        self.redis_url = redis_url or settings.OMS_REDIS_URL
        self.redis = self._connect_to_redis()
        self.stream_name = stream_name
        self.timeout = timeout
        self.max_length = max_length

        # Trim an existing stream to exactly max_length entries. XTRIM does not
        # create the key: a missing stream is created by the first XADD or by
        # ensure_group (mkstream=True).
        self.redis.xtrim(self.stream_name, maxlen=self.max_length, approximate=False)

    def _connect_to_redis(self):
        """Create a client for ``self.redis_url`` and verify it with PING.

        Responses are decoded to ``str`` (``decode_responses=True``), so message
        IDs and field values come back as text.

        Raises:
            redis.ConnectionError: If the server cannot be reached.
        """
        client = redis.StrictRedis.from_url(self.redis_url, decode_responses=True)
        client.ping()  # fail fast at construction instead of on the first command
        return client

    def _to_stream_messages(self, entries) -> StreamMessages:
        """Convert an iterable of (message_id, fields) pairs into ``StreamMessages``."""
        return StreamMessages(
            stream_name=self.stream_name,
            messages=[
                StreamMessage(message_id=message_id, message_data=dict(message_data))
                for message_id, message_data in entries
            ],
        )

    def _pending_entries(self, group_name: str) -> List[Dict[str, Any]]:
        """Return the group's whole pending entries list (extended XPENDING).

        Each item is a redis-py dict with ``message_id``, ``consumer``,
        ``time_since_delivered`` (milliseconds) and ``times_delivered``.
        """
        pending_count = self.redis.xpending(self.stream_name, group_name).get("pending") or 0
        if not pending_count:
            return []
        return self.redis.xpending_range(
            self.stream_name, group_name, min="-", max="+", count=pending_count
        )

    def add(self, message_data: Dict[str, str]) -> str:
        """Append ``message_data`` to the stream and return the new entry ID.

        The stream is capped at ``max_length`` entries (redis-py's ``xadd``
        trims approximately by default).
        """
        return self.redis.xadd(self.stream_name, message_data, maxlen=self.max_length)

    def ensure_group(self, group_name: str):
        """Create consumer group ``group_name`` if it does not exist yet.

        The group starts at ID "0", so it sees every entry already in the
        stream. ``mkstream=True`` creates the stream if it is missing.

        Raises:
            redis.exceptions.ResponseError: For any error other than BUSYGROUP.
        """
        try:
            self.redis.xgroup_create(self.stream_name, group_name, id="0", mkstream=True)
        except redis.exceptions.ResponseError as e:
            # BUSYGROUP means the group already exists, which is the desired state.
            if "BUSYGROUP" not in str(e):
                raise

    def consume(
        self, group_name: str, consumer_name: str, count: int = 10, block: Optional[int] = 5000
    ) -> StreamMessages:
        """Read up to ``count`` never-delivered messages for ``consumer_name``.

        Uses XREADGROUP with ID ">", so only new messages are returned. They
        stay in the group's pending list until acknowledged with ``ack``.

        Args:
            group_name: Consumer group (created on demand).
            consumer_name: Consumer within the group.
            count: Maximum number of messages to return.
            block: Milliseconds to wait for new messages; ``None`` does not block.

        Returns:
            The messages read; ``messages`` is empty if none arrived in time.
        """
        self.ensure_group(group_name)
        raw_messages = self.redis.xreadgroup(
            group_name, consumer_name, {self.stream_name: ">"}, count=count, block=block
        )
        # The reply is a list of (stream, entries) pairs; only our stream was requested.
        entries = [entry for _, stream_entries in (raw_messages or []) for entry in stream_entries]
        return self._to_stream_messages(entries)

    def ack(self, group_name: str, message_id: str) -> int:
        """Acknowledge ``message_id`` so it leaves the group's pending list.

        Returns:
            Number of messages acknowledged (0 if it was not pending).
        """
        return self.redis.xack(self.stream_name, group_name, message_id)

    def reassign(self, group_name: str, consumer_name: str, message_id: str):
        """Claim pending ``message_id`` for ``consumer_name`` via XCLAIM.

        Redis only transfers the message if it has been idle for at least
        ``self.timeout`` seconds. Otherwise it stays with its current owner.

        Returns:
            The claimed entries as returned by redis-py (empty if nothing was claimed).
        """
        # redis-py's xclaim takes a list through ``message_ids``. The previous
        # ``id=`` keyword is not a parameter and raised TypeError.
        return self.redis.xclaim(
            self.stream_name,
            group_name,
            consumer_name,
            min_idle_time=self.timeout * 1000,
            message_ids=[message_id],
        )

    def query_unconfirmed(self, group_name: str) -> StreamMessages:
        """Return the bodies of messages delivered to ``group_name`` but not yet acknowledged.

        Only IDs actually in the pending entries list are fetched. Reading the
        whole XRANGE between the pending min/max IDs would also return
        acknowledged messages that lie in between. Pending IDs whose stream
        entry has been trimmed or deleted are skipped.
        """
        pending_ids = [entry["message_id"] for entry in self._pending_entries(group_name)]
        if not pending_ids:
            return self._to_stream_messages([])

        # Fetch each pending entry by exact ID, batched into one round trip.
        pipe = self.redis.pipeline(transaction=False)
        for message_id in pending_ids:
            pipe.xrange(self.stream_name, message_id, message_id)
        entries = [entry for found in pipe.execute() for entry in found]
        return self._to_stream_messages(entries)

    def check_timeout(self, message_id: str) -> bool:
        """Return True if the entry was created more than ``timeout`` seconds ago.

        Age comes from the millisecond timestamp in an auto-generated stream ID
        ("<ms>-<seq>"). It measures time since XADD, not time since delivery.
        ``get_all_timeout_messages`` uses the delivery idle time instead.
        """
        created_ms = int(message_id.split("-")[0])
        return (time.time() - created_ms / 1000) > self.timeout

    def get_all_timeout_messages(self, group_name: str) -> List[str]:
        """Return IDs of pending messages idle for at least ``timeout`` seconds.

        Idle time is the time since last delivery, as reported by XPENDING.
        This matches the min idle time ``reassign`` passes to XCLAIM, so the
        returned IDs are the ones XCLAIM should accept.
        """
        timeout_ms = self.timeout * 1000
        return [
            entry["message_id"]
            for entry in self._pending_entries(group_name)
            if entry["time_since_delivered"] >= timeout_ms
        ]

    def handle_timeout(self, group_name: str, consumer_name: str) -> List[str]:
        """Claim every timed-out pending message for ``consumer_name``.

        Returns:
            The timed-out IDs that were passed to ``reassign``.
        """
        timeout_messages = self.get_all_timeout_messages(group_name)
        for message_id in timeout_messages:
            self.reassign(group_name, consumer_name, message_id)
        return timeout_messages

    def reassign_all_unconfirmed(self, group_name: str, consumer_name: str):
        """Try to claim every pending message of ``group_name`` for ``consumer_name``.

        Each claim goes through ``reassign``, so XCLAIM still only moves
        messages idle for at least ``timeout`` seconds. The rest keep their
        current owner.
        """
        for entry in self._pending_entries(group_name):
            self.reassign(group_name, consumer_name, entry["message_id"])

    def query_all(self, start: str = "-", end: str = "+"):
        """Return the XRANGE ``start``..``end`` entries as (id, fields) pairs.

        Returns ``None`` (after printing the error) if Redis rejects the range.
        """
        try:
            return self.redis.xrange(self.stream_name, min=start, max=end)
        except redis.exceptions.ResponseError as e:
            print(f"Error querying all messages: {e}")
            return None

    def weixiaofei(self, group_name):
        """Return the raw XPENDING summary for ``group_name``.

        The name is pinyin for "unconsumed". The redis-py reply is a dict with
        ``pending``, ``min``, ``max`` and ``consumers``.
        """
        return self.redis.xpending(self.stream_name, group_name)
135
136
def test4():
    """Demo: add five messages, consume and acknowledge one, consume a second
    one without acknowledging it, then print the group's pending messages."""
    group = "my_group"
    consumer = "my_consumer"
    manager = RedisStreamManager(stream_name="my_stream")
    manager.ensure_group(group)  # make sure the consumer group exists

    print("Adding messages...")
    for index in range(5):
        new_id = manager.add({"key": f"value_{index}"})
        print(f"Added message with ID: {new_id}")

    print(f"\ndata:{manager.query_all()}")

    # First read: take one message and acknowledge it.
    first_batch = manager.consume(group, consumer, count=1)
    first_id = first_batch.messages[0].message_id
    print(f"\n未消費的:{first_id}")
    manager.ack(group, first_id)

    # Second read: take one message but leave it pending.
    second_batch = manager.consume(group, consumer, count=1)
    print(f"\n消費未確認:{second_batch}")

    print(f"\n未消費列表:{manager.query_unconfirmed(group)}")
164
165
def test3():
    """Demo: print the pending (unacknowledged) messages of ``my_group``."""
    group = "my_group"
    manager = RedisStreamManager(stream_name="my_stream")
    manager.ensure_group(group)  # make sure the consumer group exists
    pending = manager.query_unconfirmed(group)
    print(f"\n未消費的:{pending}")
177
178
def test1():
    """Demo: add messages, then consume them. The first message is left
    unacknowledged and the second is acknowledged only after a delay. Then
    inspect the stream, re-read, handle timeouts and list pending messages."""
    stream_name = "my_stream"
    group_name = "my_group"
    consumer_name = "my_consumer"
    manager = RedisStreamManager(stream_name=stream_name)

    # Make sure the consumer group exists.
    manager.ensure_group(group_name)

    # Add some messages.
    print("Adding messages...")
    for i in range(5):
        message_id = manager.add({"key": f"value_{i}"})
        print(f"Added message with ID: {message_id}")

    data = manager.query_all()
    print(f"\ndata:{data}")

    # Consume messages.
    test_cus = False
    test_timeout = False
    print("\nConsuming messages...")
    messages = manager.consume(group_name, consumer_name)
    # consume() returns a StreamMessages model, so iterate its `messages` list.
    # Iterating the pydantic model itself yields (field_name, value) pairs,
    # which made the old `for stream, message_list in messages` unpacking crash.
    for message in messages.messages:
        message_id = message.message_id
        message_data = message.message_data
        print(f"Received message {message_id}: {message_data}")

        if not test_cus:
            # Leave the first message unacknowledged so it stays pending.
            test_cus = True
            continue
        elif not test_timeout:
            test_timeout = True
            # Hold the second message past the timeout before acknowledging it.
            time.sleep(2 * manager.timeout)

        # Acknowledge the message.
        manager.ack(group_name, message_id)
        print(f"Acknowledged message {message_id}")

    data = manager.query_all()
    print(f"\n所有數據:{data}")

    data = manager.consume(group_name, consumer_name)
    print(f"\n未消費的:{data}")

    data = manager.handle_timeout(group_name, consumer_name)
    print(f"\n超時的:{data}")

    data = manager.query_unconfirmed(group_name)
    print(f"\n待確認的:{data}")
231
232
def test2():
    """Demo: consume pending-new messages, acknowledging each after a delay,
    then run the timeout handling and reassignment helpers."""
    stream_name = "my_stream"
    group_name = "my_group"
    consumer_name = "my_consumer"
    manager = RedisStreamManager(stream_name=stream_name)

    # Make sure the consumer group exists.
    manager.ensure_group(group_name)

    # Consume messages.
    print("\nConsuming messages...")
    messages = manager.consume(group_name, consumer_name)
    # consume() returns a StreamMessages model, so iterate its `messages` list.
    # Iterating the pydantic model itself yields (field_name, value) pairs,
    # which made the old `for stream, message_list in messages` unpacking crash.
    for message in messages.messages:
        message_id = message.message_id
        message_data = message.message_data
        print(f"Received message {message_id}: {message_data}")

        # Hold the message past the timeout to simulate slow processing.
        time.sleep(2 * manager.timeout)

        # Acknowledge the message.
        manager.ack(group_name, message_id)
        print(f"Acknowledged message {message_id}")

    # Handle timed-out messages.
    print("\nHandling timeout messages...")
    manager.handle_timeout(group_name, consumer_name)
    print("Handled timeout messages")

    # List all timed-out messages.
    print("\nGetting all timeout messages...")
    timeout_messages = manager.get_all_timeout_messages(group_name)
    print(f"All timeout messages: {timeout_messages}")

    # Try to claim every unacknowledged message for this consumer.
    print("\nReassigning all unconfirmed messages...")
    manager.reassign_all_unconfirmed(group_name, consumer_name)
    print("Reassigned all unconfirmed messages")