modify: 解决 message_item result 字段重复序列化问题

This commit is contained in:
chakcy 2025-09-13 19:02:40 +08:00
parent 2f68ecac71
commit 80cde8aff3
3 changed files with 70 additions and 26 deletions

View File

@ -8,11 +8,11 @@ from ..core import core
from ..constant import MessageStatus from ..constant import MessageStatus
class QueueOperation(): class QueueOperation:
def __init__(self, shard_num: int = 4, queue_name: str = "default"): def __init__(self, shard_num: int = 4, queue_name: str = "default"):
self.shard_num = shard_num self.shard_num = shard_num
self.db_dir = os.path.join("cache", queue_name) self.db_dir = os.path.join("cache", queue_name)
self.shard_connections = threading.local() self.shard_connections = threading.local()
if not os.path.exists(self.db_dir): if not os.path.exists(self.db_dir):
os.makedirs(self.db_dir) os.makedirs(self.db_dir)
self.init_shards() self.init_shards()
@ -62,7 +62,7 @@ class QueueOperation():
for shard_index in shard_order: for shard_index in shard_order:
if collected >= size: if collected >= size:
break break
conn = self._get_shard_conn(shard_index) conn = self._get_shard_conn(shard_index)
shard_messages = conn.dequeue(size - collected) shard_messages = conn.dequeue(size - collected)
messages.extend(shard_messages) messages.extend(shard_messages)
@ -73,7 +73,7 @@ class QueueOperation():
# 获取队列长度 # 获取队列长度
def get_queue_length(self) -> int: def get_queue_length(self) -> int:
"""获取队列中待处理消息的数量 """获取队列中待处理消息的数量
Returns: Returns:
int: 待处理消息数量 int: 待处理消息数量
""" """
@ -82,7 +82,7 @@ class QueueOperation():
conn = self._get_shard_conn(i) conn = self._get_shard_conn(i)
total += conn.get_queue_length() total += conn.get_queue_length()
return total return total
# 获取完成/失败的消息 # 获取完成/失败的消息
def get_completed_messages(self) -> List[dict]: def get_completed_messages(self) -> List[dict]:
messages = [] messages = []
@ -96,16 +96,16 @@ class QueueOperation():
def get_result(self, message_id: str) -> Tuple[bool, Union[str, dict]]: def get_result(self, message_id: str) -> Tuple[bool, Union[str, dict]]:
conn = self._get_shard_conn(self._get_shard_index(message_id)) conn = self._get_shard_conn(self._get_shard_index(message_id))
return conn.get_result(message_id) return conn.get_result(message_id)
# 更新消息状态 # 更新消息状态
def update_status(self, message_id: str, status: MessageStatus): def update_status(self, message_id: str, status: MessageStatus):
conn = self._get_shard_conn(self._get_shard_index(message_id)) conn = self._get_shard_conn(self._get_shard_index(message_id))
conn.update_status(message_id, status) conn.update_status(message_id, status)
# 更新消息结果 # 更新消息结果
def update_result(self, message_id: str, result: dict): def update_result(self, message_id: str, result: str):
conn = self._get_shard_conn(self._get_shard_index(message_id)) conn = self._get_shard_conn(self._get_shard_index(message_id))
conn.update_result(message_id, json.dumps(result)) conn.update_result(message_id, result)
# 删除消息 # 删除消息
def delete_message(self, message_id: str): def delete_message(self, message_id: str):
@ -121,4 +121,3 @@ class QueueOperation():
def clean_expired_messages(self, shard_index: int): def clean_expired_messages(self, shard_index: int):
conn = self._get_shard_conn(shard_index) conn = self._get_shard_conn(shard_index)
conn.clean_expired_messages() conn.clean_expired_messages()

View File

@ -2,7 +2,7 @@ from ..queue_operation import QueueOperation
from ..model import MessageItem from ..model import MessageItem
from ..constant import MessageStatus from ..constant import MessageStatus
from ..cycle.task_cycle import TaskCycle from ..cycle.task_cycle import TaskCycle
from ..mounter.task_mounter import TaskMounter from ..mounter.task_mounter import TaskMounter
from concurrent.futures import ThreadPoolExecutor from concurrent.futures import ThreadPoolExecutor
from datetime import datetime from datetime import datetime
import time import time
@ -25,22 +25,22 @@ class TaskScheduler:
message.result = {"result": "success"} message.result = {"result": "success"}
message.updatetime = datetime.now() message.updatetime = datetime.now()
return message return message
task_function = TaskMounter.get_task_function(message.destination) task_function = TaskMounter.get_task_function(message.destination)
task_cycle = TaskCycle(message, task_function) task_cycle = TaskCycle(message, task_function)
task_cycle.run() task_cycle.run()
message.status = task_cycle.get_task_status() # type: ignore message.status = task_cycle.get_task_status() # type: ignore
if message.status == MessageStatus.FAILED: if message.status == MessageStatus.FAILED:
message.result = {"error": task_cycle.get_task_error()} message.result = {"error": task_cycle.get_task_error()}
else: else:
message.result = task_cycle.get_task_result() # type: ignore message.result = task_cycle.get_task_result() # type: ignore
message.updatetime = datetime.now() message.updatetime = datetime.now()
except Exception as e: except Exception as e:
message.status = MessageStatus.FAILED message.status = MessageStatus.FAILED
message.result = {"error": str(e)} message.result = {"error": str(e)}
message.updatetime = datetime.now() message.updatetime = datetime.now()
return message return message
def _update_result(self, message): def _update_result(self, message):
@ -55,24 +55,26 @@ class TaskScheduler:
"""单一轮询线程,并行执行任务""" """单一轮询线程,并行执行任务"""
while self.is_running: while self.is_running:
try: try:
message_list = self.queue_operation.dequeue(size=self.task_thread_num * 2) message_list = self.queue_operation.dequeue(
size=self.task_thread_num * 2
)
if message_list: if message_list:
# 并行处理所有获取到的消息 # 并行处理所有获取到的消息
for message in message_list: for message in message_list:
self.executor.submit( self.executor.submit(
lambda m: self._update_result(self._process_message(m)), lambda m: self._update_result(self._process_message(m)),
MessageItem.from_dict(message) MessageItem.from_dict(message),
) )
else: else:
time.sleep(0.1) # 适当休眠 time.sleep(0.1) # 适当休眠
except Exception as e: except Exception as e:
print(f"任务调度错误: {str(e)}") print(f"任务调度错误: {str(e)}")
time.sleep(1) time.sleep(1)
def start_task_thread(self): def start_task_thread(self):
if self.is_running: if self.is_running:
return return
self.is_running = True self.is_running = True
# 创建单一轮询线程 # 创建单一轮询线程
self.task_thread = threading.Thread(target=self.task_callback, daemon=True) self.task_thread = threading.Thread(target=self.task_callback, daemon=True)
@ -81,11 +83,11 @@ class TaskScheduler:
def stop_task_thread(self): def stop_task_thread(self):
if not self.is_running: if not self.is_running:
return return
self.is_running = False self.is_running = False
# 等待轮询线程结束 # 等待轮询线程结束
if self.task_thread and self.task_thread.is_alive(): if self.task_thread and self.task_thread.is_alive():
self.task_thread.join(timeout=2.0) self.task_thread.join(timeout=2.0)
# 关闭线程池 # 关闭线程池
self.executor.shutdown(wait=True) self.executor.shutdown(wait=True)

View File

@ -1,9 +1,10 @@
from queue_sqlite.model import MessageItem from queue_sqlite.model import MessageItem
import json
class TestMessageItem: class TestMessageItem:
messageItem = MessageItem(content={"num": 1}, destination="test") messageItem = MessageItem(content={"num": 1}, destination="test")
@classmethod @classmethod
def test_to_dict(cls): def test_to_dict(cls):
print(cls.messageItem.to_dict()) print(cls.messageItem.to_dict())
@ -15,4 +16,46 @@ class TestMessageItem:
@classmethod @classmethod
def test_from_dict(cls): def test_from_dict(cls):
messageItem = MessageItem.from_dict(cls.messageItem.to_dict()) messageItem = MessageItem.from_dict(cls.messageItem.to_dict())
print(messageItem.to_dict()) data = {
"id": "ddb66277-503c-4921-8e7f-5091eace49e3",
"type": "task",
"status": 2,
"content": '{"num": 1}',
"createtime": "2025-09-12T18:41:08.221531",
"updatetime": "2025-09-12T18:41:08.221531",
"result": '"{\\"message\\": \\"\\\\u6d4b\\\\u8bd5\\\\u6210\\\\u529f\\"}"',
"priority": 1,
"source": "client",
"destination": "test",
"retry_count": 0,
"expire_time": "null",
"tags": "null",
"metadata": "{}",
}
print(messageItem.to_dict())
print(MessageItem.from_dict(data).to_dict())
@classmethod
def test_from_json(cls):
json_str = """
{
"content": {
"num": 1
},
"createtime": "2025-09-12T16:42:50.663248",
"destination": "test",
"expire_time": null,
"id": "b83c3d72-0b06-4c34-ab4f-32d696aa3875",
"metadata": {},
"priority": 1,
"result": "{\"message\": \"\\u6d4b\\u8bd5\\u6210\\u529f\"}",
"retry_count": 0,
"source": "client",
"status": 2,
"tags": null,
"type": "task",
"updatetime": "2025-09-12T16:42:50.663248"
}
"""
messageItem = MessageItem.from_dict(json.loads(json_str))
print(messageItem.to_json())