modify: 代码优化

This commit is contained in:
chakcy 2025-09-25 21:53:58 +08:00
parent bcb3304338
commit cfc4028a51
8 changed files with 173 additions and 40 deletions

View File

@ -5,7 +5,7 @@ use pyo3::prelude::*;
use queue_operation::QueueOperation;
use task_mounter::TaskMounter; // 导入结构体 // 导入结构体
#[pymodule]
#[pymodule(gil_used = false)]
fn core(_py: Python<'_>, m: &Bound<'_, PyModule>) -> PyResult<()> {
m.add_class::<TaskMounter>()?;
m.add_class::<QueueOperation>()?;

View File

@ -35,18 +35,133 @@ class core:
"""
...
def dequeue(self, size: int = 1) -> List[dict]: ...
def get_queue_length(self) -> int: ...
def get_completed_messages(self) -> List[dict]: ...
def get_result(self, message_id: str): ...
def update_status(self, message_id: str, status: MessageStatus): ...
def update_result(self, message_id: str, result: str): ...
def delete_message(self, message_id: str): ...
def clean_old_messages(self, days: int): ...
def clean_expired_messages(self): ...
def remove_expired_messages(self, days: int): ...
def dequeue(self, size: int = 1) -> List[dict]:
"""
Dequeue messages.
Args:
size (int, optional): The number of messages to dequeue. Defaults to 1.
Returns:
List[dict]: The dequeued messages.
"""
...
def get_queue_length(self) -> int:
"""
Get the length of the queue.
Returns:
int: The length of the queue.
"""
...
def get_completed_messages(self) -> List[dict]:
"""
Get the completed messages.
Returns:
List[dict]: The completed messages.
"""
...
def get_result(self, message_id: str):
"""
Get the result of a message.
Args:
message_id (str): The message id.
Returns:
str: The result of the message.
"""
...
def update_status(self, message_id: str, status: MessageStatus):
"""
Update the status of a message.
Args:
message_id (str): The message id.
status (MessageStatus): The status of the message.
Raises:
ValueError: If the status is not valid.
"""
...
def update_result(self, message_id: str, result: str):
"""
Update the result of a message.
Args:
message_id (str): The message id.
result (str): The result of the message.
"""
...
def delete_message(self, message_id: str):
"""
Delete a message.
Args:
message_id (str): The message id.
Raises:
ValueError: If the message id is not valid.
"""
...
def clean_old_messages(self, days: int):
"""
Clean old messages.
Args:
days (int): The number of days to keep messages.
"""
...
def clean_expired_messages(self):
"""
Clean expired messages.
Raises:
Exception: If an unknown error occurs.
"""
...
def remove_expired_messages(self, days: int):
"""
Remove expired messages.
Args:
days (int): The number of days to keep messages.
"""
...
class TaskMounter:
def __init__(self, task_class: type[TaskMounter]): ...
def get_task_list(self) -> List[str]: ...
def get_task_function(self, name: str) -> Callable: ...
def __init__(self, task_class: type[TaskMounter]):
"""
Initialize the task mounter.
Args:
task_class (type[TaskMounter]): The task class.
"""
...
def get_task_list(self) -> List[str]:
"""
Get the task list.
Returns:
List[str]: The task list.
"""
...
def get_task_function(self, name: str) -> Callable:
"""
Get the task function.
Args:
name (str): The name of the task.
Returns:
Callable: The task function.
Raises:
ValueError: If the task name is not valid.
"""
...

View File

@ -1,4 +1,4 @@
from .task_mounter import TaskMounter
from .listen_mounter import ListenMounter
from .task_mounter import task
from .listen_mounter import listener
__all__ = ["TaskMounter", "ListenMounter"]
__all__ = ["task", "listener"]

View File

@ -5,16 +5,18 @@ class ListenMounter:
@classmethod
def mount_Listener(cls, function: Callable):
setattr(cls, function.__name__, function)
@staticmethod
def listener():
"""带参数的装饰器"""
def decorator(function: Callable):
# 使用自定义名称或函数原名
setattr(ListenMounter, function.__name__, function)
return function
return decorator
@classmethod
def get_Listener_function(cls, name: str):
return getattr(cls, name, None)
@ -23,7 +25,7 @@ class ListenMounter:
def get_Listener_list(cls) -> List[str]:
"""获取所有挂载的监听函数名称列表"""
listener_list = []
# 遍历类属性字典
for attr_name, attr_value in vars(cls).items():
# 过滤条件:
@ -31,11 +33,20 @@ class ListenMounter:
# 2. 不是类自带的特殊方法(非双下划线开头)
# 3. 不是类方法本身(如 mount_Listener, listener 等)
if (
callable(attr_value)
and not attr_name.startswith("__")
and attr_name not in ["mount_Listener", "listener", "get_Listener_function", "get_Listener_list"]
callable(attr_value)
and not attr_name.startswith("__")
and attr_name
not in [
"mount_Listener",
"listener",
"get_Listener_function",
"get_Listener_list",
]
):
listener_list.append(attr_name)
return listener_list
def listener():
return ListenMounter.listener()

View File

@ -5,17 +5,19 @@ class TaskMounter:
@classmethod
def mount_task(cls, function: Callable):
setattr(cls, function.__name__, function)
@staticmethod
def task(meta: dict = {}):
"""带参数的装饰器"""
def decorator(function: Callable):
# 使用自定义名称或函数原名
function.meta = meta # type: ignore
function.meta = meta # type: ignore
setattr(TaskMounter, function.__name__, function)
return function
return decorator
@classmethod
def get_task_function(cls, name: str):
return getattr(cls, name, None)
@ -24,7 +26,7 @@ class TaskMounter:
def get_task_list(cls) -> List[str]:
"""获取所有挂载的任务函数名称列表"""
task_list = []
# 遍历类属性字典
for attr_name, attr_value in vars(cls).items():
# 过滤条件:
@ -32,11 +34,16 @@ class TaskMounter:
# 2. 不是类自带的特殊方法(非双下划线开头)
# 3. 不是类方法本身(如 mount_task, get_task_list 等)
if (
callable(attr_value)
and not attr_name.startswith("__")
and attr_name not in ["mount_task", "task", "get_task_function", "get_task_list"]
callable(attr_value)
and not attr_name.startswith("__")
and attr_name
not in ["mount_task", "task", "get_task_function", "get_task_list"]
):
task_list.append(attr_name)
return task_list
# task 装饰器
def task(meta: dict = {}):
return TaskMounter.task(meta)

View File

@ -1,8 +1,8 @@
from queue_sqlite.mounter.task_mounter import TaskMounter
from queue_sqlite.mounter import task
from queue_sqlite.model import MessageItem
@TaskMounter.task(meta={"task_name": "test"})
@task(meta={"task_name": "test"})
async def async_example(message_item: MessageItem):
def fibonacci_generator():
a, b = 0, 1

View File

@ -1,8 +1,8 @@
from queue_sqlite.mounter.task_mounter import TaskMounter
from queue_sqlite.mounter import task
from queue_sqlite.model import MessageItem
@TaskMounter.task(meta={"task_name": "test"})
@task(meta={"task_name": "test"})
def example(message_item: MessageItem):
def fibonacci_generator():
a, b = 0, 1
@ -16,4 +16,4 @@ def example(message_item: MessageItem):
return message_item.to_json()
# 输出:[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]
# 输出:[0, 1, 1, 2, 3, 5, 8, 13, 21, 34]

View File

@ -1,5 +1,5 @@
from queue_sqlite.core import core
from queue_sqlite.mounter import TaskMounter
from queue_sqlite.mounter.task_mounter import TaskMounter
from queue_sqlite.model import MessageItem