198 lines
7.1 KiB
Python
198 lines
7.1 KiB
Python
from sophon_chakcy.utils.multi_decoder_manager import MultiDecoderManager, TaskArgs, DecodeMode
|
||
import sophon_chakcy.sail as sail
|
||
import asyncio
|
||
import websockets
|
||
from sophon_chakcy.utils.bmimg_to_base64 import bmimage_to_base64
|
||
from concurrent.futures import ThreadPoolExecutor
|
||
|
||
|
||
executor = ThreadPoolExecutor(max_workers=64)
|
||
HOST = "0.0.0.0"
|
||
PORT = 8765
|
||
|
||
|
||
class ChannelBroadcaster:
|
||
"""每个通道的广播器:独立读取帧并分发给所有订阅客户端"""
|
||
def __init__(self,
|
||
channel_idx: int,
|
||
decoder_manager: MultiDecoderManager,
|
||
handle: sail.Handle):
|
||
self.channel_idx = channel_idx
|
||
self.decoder_manager = decoder_manager
|
||
self.handle = handle
|
||
self.subscribers = []
|
||
self._task = None
|
||
self._running = False
|
||
|
||
def subscribe(self):
|
||
"""客户端订阅,获得一个专属队列"""
|
||
q = asyncio.Queue(maxsize=10)
|
||
self.subscribers.append(q)
|
||
return q
|
||
|
||
def unsubscribe(self, q):
|
||
"""客户端取消订阅"""
|
||
if q in self.subscribers:
|
||
self.subscribers.remove(q)
|
||
|
||
async def start(self):
|
||
"""启动广播循环"""
|
||
if self._running:
|
||
return
|
||
self._running = True
|
||
self._task = asyncio.create_task(self._broadcast_loop())
|
||
|
||
async def stop(self):
|
||
"""停止广播"""
|
||
self._running = False
|
||
if self._task:
|
||
self._task.cancel()
|
||
try:
|
||
await self._task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
self._task = None
|
||
|
||
async def _broadcast_loop(self):
|
||
"""持续从解码器取帧,处理后分发给所有订阅者"""
|
||
loop = asyncio.get_running_loop()
|
||
try:
|
||
while self._running:
|
||
bmimg = self.decoder_manager.read_frame(self.channel_idx, DecodeMode.NO_WAIT)
|
||
if bmimg is not None:
|
||
# 可选:在此统一进行算法推理和绘制,避免每个客户端重复计算
|
||
algorithms = self.decoder_manager.get_algorithms(self.channel_idx)
|
||
if algorithms:
|
||
def process_frame(bmimg):
|
||
for algo in algorithms:
|
||
detections = algo.detect(bmimg)
|
||
bmimg = algo.draw_detections(bmimg, detections, None)
|
||
return bmimg
|
||
bmimg = await loop.run_in_executor(executor, process_frame, bmimg)
|
||
# 转换为 base64(统一转换,节省资源)
|
||
image_b64 = await loop.run_in_executor(executor, bmimage_to_base64, bmimg, self.handle)
|
||
if image_b64:
|
||
# 分发到每个订阅者的队列
|
||
for q in self.subscribers[:]: # 使用副本遍历,避免修改
|
||
try:
|
||
q.put_nowait(image_b64)
|
||
except asyncio.QueueFull:
|
||
# 如果队列满,丢弃最旧的帧
|
||
await q.get()
|
||
q.put_nowait(image_b64)
|
||
pass
|
||
else:
|
||
# 无帧时短暂休眠,避免空转
|
||
await asyncio.sleep(0.001)
|
||
except asyncio.CancelledError:
|
||
pass
|
||
finally:
|
||
pass
|
||
|
||
broadcasters = {}
|
||
|
||
async def video_stream(websocket, channel_idx, decoder_manager, handle):
|
||
"""处理单个WebSocket连接:订阅对应通道,发送帧"""
|
||
# 获取或创建对应通道的广播器
|
||
if channel_idx not in broadcasters:
|
||
broadcaster = ChannelBroadcaster(channel_idx, decoder_manager, handle)
|
||
await broadcaster.start()
|
||
broadcasters[channel_idx] = broadcaster
|
||
else:
|
||
broadcaster = broadcasters[channel_idx]
|
||
|
||
# 订阅通道,获得专属队列
|
||
queue = broadcaster.subscribe()
|
||
|
||
# 发送任务:从队列取帧并发送
|
||
async def sender():
|
||
try:
|
||
while True:
|
||
frame = await queue.get()
|
||
await websocket.send(frame)
|
||
except asyncio.CancelledError:
|
||
print(f"通道 {channel_idx} 发送任务取消")
|
||
except Exception as e:
|
||
print(f"通道 {channel_idx} 发送任务异常: {e}")
|
||
import traceback
|
||
traceback.print_exc()
|
||
|
||
send_task = asyncio.create_task(sender())
|
||
|
||
try:
|
||
# 监听连接关闭(客户端不会发消息,但连接关闭时会抛出异常)
|
||
await websocket.wait_closed()
|
||
except websockets.exceptions.ConnectionClosed:
|
||
pass
|
||
finally:
|
||
# 清理:取消发送任务、取消订阅、若通道无订阅者则停止广播器
|
||
print(f"连接关闭: 通道 {channel_idx}")
|
||
send_task.cancel()
|
||
try:
|
||
await send_task
|
||
except asyncio.CancelledError:
|
||
pass
|
||
|
||
broadcaster.unsubscribe(queue)
|
||
print(f"通道 {channel_idx} 取消订阅,剩余订阅者: {len(broadcaster.subscribers)}")
|
||
if not broadcaster.subscribers:
|
||
await broadcaster.stop()
|
||
if channel_idx in broadcasters:
|
||
del broadcasters[channel_idx]
|
||
|
||
async def ws_handler(websocket):
|
||
"""WebSocket 请求路由"""
|
||
path = websocket.request.path
|
||
parts = path.strip('/').split('/')
|
||
if len(parts) != 2 or parts[0] != 'video':
|
||
await websocket.close(1008, "路径格式错误,应为 /video/<通道索引>")
|
||
return
|
||
try:
|
||
channel_idx = int(parts[1])
|
||
except ValueError:
|
||
await websocket.close(1008, "通道索引必须是整数")
|
||
return
|
||
|
||
# 调用 video_stream 处理该连接
|
||
await video_stream(websocket, channel_idx, decoder_manager, handle)
|
||
|
||
async def main():
|
||
global decoder_manager, handle
|
||
|
||
handle = sail.Handle(0)
|
||
input_output_path = [
|
||
# f"rtsp://192.168.137.1:8554/stream_{i}"
|
||
f"resources/datasets/test_car_person_1080P_{i}.mp4"
|
||
for i in range(16)
|
||
]
|
||
with MultiDecoderManager() as decoder_manager:
|
||
decoder_manager.set_local_flag(True)
|
||
for input_path in input_output_path:
|
||
decoder_manager.add_channel(
|
||
input_path,
|
||
frame_skip_num=1,
|
||
loopnum=None,
|
||
task_args=TaskArgs(
|
||
task_name=f"task_{input_path.split('/')[-1]}",
|
||
algorithms=["yolo_example", "yolo_example", "yolo_example"],
|
||
handle=handle
|
||
)
|
||
)
|
||
|
||
# 为每个通道创建异步队列
|
||
active_channels = decoder_manager.get_active_channels()
|
||
channel_count = len(active_channels)
|
||
if channel_count == 0:
|
||
print("没有可用的视频通道,程序退出。")
|
||
return
|
||
|
||
print(f"可用通道索引: 0 ~ {len(active_channels)-1}")
|
||
|
||
async with websockets.serve(ws_handler, HOST, PORT):
|
||
print(f"WebSocket 服务器启动: ws://{HOST}:{PORT}/video/{{index}}")
|
||
# 保持运行
|
||
await asyncio.Future()
|
||
|
||
if __name__ == "__main__":
|
||
asyncio.run(main())
|