diff --git a/.python-version b/.python-version index 2c07333..24ee5b1 100644 --- a/.python-version +++ b/.python-version @@ -1 +1 @@ -3.11 +3.13 diff --git a/pyproject.toml b/pyproject.toml index e1d825e..49573df 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -26,9 +26,9 @@ pyqt5 = ["qtpy", "PyQt5"] pyqt6 = ["qtpy", "PyQt6"] pyside6 = ["qtpy", "PySide6"] -# [[tool.uv.index]] -# default = true -# url = "https://pypi.tuna.tsinghua.edu.cn/simple" +[[tool.uv.index]] +default = true +url = "https://pypi.tuna.tsinghua.edu.cn/simple" [build-system] build-backend = "hatchling.build" diff --git a/scripts/clean_pycache.py b/scripts/clean_pycache.py index 070343f..d402b80 100644 --- a/scripts/clean_pycache.py +++ b/scripts/clean_pycache.py @@ -1,5 +1,6 @@ import os import shutil + # 需要遍历的目录 root_dir = "./" # 遍历目录 diff --git a/src/queue_sqlite/scheduler/_async/async_receive_scheduler.py b/src/queue_sqlite/scheduler/_async/async_receive_scheduler.py index 98dd24c..04ff0c3 100644 --- a/src/queue_sqlite/scheduler/_async/async_receive_scheduler.py +++ b/src/queue_sqlite/scheduler/_async/async_receive_scheduler.py @@ -16,6 +16,9 @@ import asyncio import threading import concurrent.futures import logging +from collections import namedtuple + +SendMessageCallback = namedtuple("send_message_callback", ["message", "callback"]) class AsyncReceiveScheduler: @@ -37,6 +40,21 @@ class AsyncReceiveScheduler: with self.lock: self.callbacks[message.id] = callback + def send_message_batch(self, message_callback_list: list[SendMessageCallback]): + message_list = [ + message_callback.message.to_dict_by_core() + for message_callback in message_callback_list + ] + if message_list: + self.queue_operation.enqueue_batch(message_list) + with self.lock: + for message_callback in message_callback_list: + message = message_callback.message + callback = message_callback.callback + if callback is None: + callback = lambda m: logging.info(f"receive message: {m.id}") + self.callbacks[message.id] = callback + async def receive_message(self): """单一轮询线程,并行执行回调""" diff --git a/src/queue_sqlite/scheduler/qt/receive_scheduler.py b/src/queue_sqlite/scheduler/qt/receive_scheduler.py index 497929e..4ea5fc7 100644 --- a/src/queue_sqlite/scheduler/qt/receive_scheduler.py +++ b/src/queue_sqlite/scheduler/qt/receive_scheduler.py @@ -17,6 +17,9 @@ from ...model import MessageItem import time import logging from .callback_task import QtCallbackTask +from collections import namedtuple + +SendMessageCallback = namedtuple("send_message_callback", ["message", "callback"]) class QtReceiveScheduler: @@ -42,6 +45,21 @@ class QtReceiveScheduler: with self.lock: self.callbacks[message.id] = callback + def send_message_batch(self, message_callback_list: list[SendMessageCallback]): + message_list = [ + message_callback.message.to_dict_by_core() + for message_callback in message_callback_list + ] + if message_list: + self.queue_operation.enqueue_batch(message_list) + with self.lock: + for message_callback in message_callback_list: + message = message_callback.message + callback = message_callback.callback + if callback is None: + callback = lambda m: logging.info(f"receive message: {m.id}") + self.callbacks[message.id] = callback + def _receive_loop(self): """接收消息循环""" while self.is_running: diff --git a/src/queue_sqlite/scheduler/standard/receive_scheduler.py b/src/queue_sqlite/scheduler/standard/receive_scheduler.py index d691e9b..4acf65d 100644 --- a/src/queue_sqlite/scheduler/standard/receive_scheduler.py +++ b/src/queue_sqlite/scheduler/standard/receive_scheduler.py @@ -16,6 +16,9 @@ from concurrent.futures import ThreadPoolExecutor import threading import time import logging +from collections import namedtuple + +SendMessageCallback = namedtuple("send_message_callback", ["message", "callback"]) class ReceiveScheduler: @@ -40,6 +43,21 @@ class ReceiveScheduler: with self.lock: self.callbacks[message.id] = callback + def send_message_batch(self, message_callback_list: list[SendMessageCallback]): + message_list = [ + message_callback.message.to_dict_by_core() + for message_callback in message_callback_list + ] + if message_list: + self.queue_operation.enqueue_batch(message_list) + with self.lock: + for message_callback in message_callback_list: + message = message_callback.message + callback = message_callback.callback + if callback is None: + callback = lambda m: logging.info(f"receive message: {m.id}") + self.callbacks[message.id] = callback + def receive_message(self): """单一轮询线程,并行执行回调""" diff --git a/src/queue_sqlite_core/queue_sqlite_core.pyi b/src/queue_sqlite_core/queue_sqlite_core.pyi index 90f9d0b..ff5b749 100644 --- a/src/queue_sqlite_core/queue_sqlite_core.pyi +++ b/src/queue_sqlite_core/queue_sqlite_core.pyi @@ -4,6 +4,7 @@ class QueueOperation: def __init__(self, queue_path: str): ... def init_db(self) -> None: ... def enqueue(self, message: dict) -> str: ... + def enqueue_batch(self, messages: list[dict]) -> list[str]: ... def dequeue(self, size: int = 1) -> list[dict]: ... def get_queue_length(self) -> int: ... def get_completed_messages(self) -> list[dict]: ... @@ -23,6 +24,7 @@ class ShardedQueueOperation: def __init__(self, shard_num: int, queue_name: str): ... def _get_shard_index(self, message_id: str) -> int: ... def enqueue(self, message: dict) -> str: ... + def enqueue_batch(self, messages: list[dict]) -> list[str]: ... def dequeue(self, size: int = 1) -> list[dict]: ... def get_queue_length(self) -> int: ... def get_completed_messages(self) -> list[dict]: ... diff --git a/src/queue_sqlite_core/src/queue_operation.rs b/src/queue_sqlite_core/src/queue_operation.rs index 412d8ea..3f8dd16 100644 --- a/src/queue_sqlite_core/src/queue_operation.rs +++ b/src/queue_sqlite_core/src/queue_operation.rs @@ -256,6 +256,104 @@ impl QueueOperation { Ok(id) } + pub fn enqueue_batch(&self, messages: Vec>) -> PyResult> { + let mut conn = self.get_connection()?; + let tx = conn + .transaction() + .map_err(|e| PyErr::new::(e.to_string()))?; + + let mut inserted_ids = Vec::new(); + { + let mut stmt = tx + .prepare( + "INSERT INTO messages ( + id, type, status, content, createtime, updatetime, result, + priority, source, destination, retry_count, expire_time, tags, metadata + ) VALUES (?1, ?2, ?3, ?4, ?5, ?6, ?7, ?8, ?9, ?10, ?11, ?12, ?13, ?14)", + ) + .map_err(|e| PyErr::new::(e.to_string()))?; + for message in messages { + let id: String = message.get_item("id")?.unwrap().extract()?; + let message_type: String = message.get_item("type")?.unwrap().extract()?; + let status: i32 = message.get_item("status")?.unwrap().extract()?; + let content: String = message.get_item("content")?.unwrap().extract()?; + let createtime: String = message.get_item("createtime")?.unwrap().extract()?; + let updatetime: String = message.get_item("updatetime")?.unwrap().extract()?; + let priority: i32 = message.get_item("priority")?.unwrap().extract()?; + let source: String = message.get_item("source")?.unwrap().extract()?; + let destination: String = message.get_item("destination")?.unwrap().extract()?; + let retry_count: i32 = message.get_item("retry_count")?.unwrap().extract()?; + let result: Option = match message.get_item("result")? { + Some(item) => { + if item.is_none() { + None + } else { + Some(item.extract()?) + } + } + None => None, + }; + + let expire_time: Option = match message.get_item("expire_time")? { + Some(item) => { + if item.is_none() { + None + } else { + Some(item.extract()?) + } + } + None => None, + }; + + let tags: Option = match message.get_item("tags")? { + Some(item) => { + if item.is_none() { + None + } else { + Some(item.extract()?) + } + } + None => None, + }; + + let metadata: Option = match message.get_item("metadata")? { + Some(item) => { + if item.is_none() { + None + } else { + Some(item.extract()?) + } + } + None => None, + }; + + stmt.execute(params![ + id, + message_type, + status, + content, + createtime, + updatetime, + result, + priority, + source, + destination, + retry_count, + expire_time, + tags, + metadata, + ]) + .map_err(|e| PyErr::new::(e.to_string()))?; + + inserted_ids.push(id); + } + } + + tx.commit() + .map_err(|e| PyErr::new::(e.to_string()))?; + Ok(inserted_ids) + } + /// 出队 /// Args: /// - self: QueueOperation 实例 @@ -661,6 +759,28 @@ impl ShardedQueueOperation { self.shards[shard_index].enqueue(message) } + pub fn enqueue_batch(&self, messages: Vec>) -> PyResult> { + let mut shard_groups: HashMap>> = HashMap::new(); + for message in messages { + let message_id_bound = message + .get_item("id")? + .ok_or_else(|| PyErr::new::("Missing is field"))?; + let message_id: String = message_id_bound.extract()?; + let shard_index = self._get_shard_index(&message_id); + shard_groups + .entry(shard_index) + .or_insert_with(Vec::new) + .push(message); + } + let mut all_ids = Vec::new(); + + for (shard_index, shard_messages) in shard_groups { + let shard_ids = self.shards[shard_index].enqueue_batch(shard_messages)?; + all_ids.extend(shard_ids); + } + Ok(all_ids) + } + /// 获取消息 /// Args: /// - self: ShardedQueueOperation 实例 diff --git a/tests/test_message_batch.py b/tests/test_message_batch.py new file mode 100644 index 0000000..e69de29