modify: 实现视频流推理

This commit is contained in:
chakcy 2026-02-04 23:12:12 +08:00
parent 7880c9e15b
commit 2ff9a2ecb1
7 changed files with 258 additions and 74 deletions

View File

@ -1,11 +1,14 @@
# yolov8_inference_with_colors.py
import os
import time
import numpy as np
import sophon_chakcy.sail as sail
from .postprocess_numpy import PostProcess
from sqlite_vfs.core import SQLiteVFS
from dataclasses import dataclass, field
from typing import Optional
import threading
@dataclass
class YOLOv8Config:
@ -149,37 +152,17 @@ class Detector:
return img_bgr_planar
def detect_single_image(self, image_path, output_dir="./detection_results"):
def preprocess(self, bmimg: sail.BMImage):
"""
检测单张图像
预处理图像
Args:
image_path: 图像路径
output_dir: 输出目录
image: 输入图像 (HWC, BGR格式)
Returns:
检测结果列表
(预处理后的BMImage, 原始尺寸, 缩放比例, 填充偏移)
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
print(f"\n处理图像: {image_path}")
# 1. 解码图像
decoder = sail.Decoder(image_path, True, self.dev_id)
bmimg = sail.BMImage()
ret = decoder.read(self.handle, bmimg)
if ret != 0:
print(f"错误: 无法解码图像 {image_path}")
return []
orig_width = bmimg.width()
orig_height = bmimg.height()
print(f"图像大小: {orig_width}x{orig_height}")
# 2. 预处理
# 转换为RGB格式
# 转换为RGB
rgb_img = sail.BMImage(self.handle, bmimg.height(), bmimg.width(),
sail.Format.FORMAT_RGB_PLANAR, sail.DATA_TYPE_EXT_1N_BYTE)
self.bmcv.convert_format(bmimg, rgb_img)
@ -217,6 +200,16 @@ class Detector:
(self.ab[2], self.ab[3]),
(self.ab[4], self.ab[5])))
return preprocessed_img, r, tx1, ty1
def detect(self, bmimg: sail.BMImage):
orig_width = bmimg.width()
orig_height = bmimg.height()
print(f"图像大小: {orig_width}x{orig_height}")
# 2. 预处理
preprocessed_img, r, tx1, ty1 = self.preprocess(bmimg)
# 3. 准备输入张量
input_tensor = sail.Tensor(self.handle, self.input_shape, self.input_dtype, False, False)
self.bmcv.bm_image_to_tensor(preprocessed_img, input_tensor)
@ -245,6 +238,38 @@ class Detector:
# 执行后处理
detections = self.postprocess([pred_with_batch], org_size_list, ratios_batch, txy_batch)
return detections
def detect_single_image(self, image_path, output_dir="./detection_results"):
"""
检测单张图像
Args:
image_path: 图像路径
output_dir: 输出目录
Returns:
检测结果列表
"""
# 创建输出目录
os.makedirs(output_dir, exist_ok=True)
print(f"\n处理图像: {image_path}")
# 1. 解码图像
decoder = sail.Decoder(image_path, True, self.dev_id)
bmimg = sail.BMImage()
ret = decoder.read(self.handle, bmimg)
if ret != 0:
print(f"错误: 无法解码图像 {image_path}")
return []
orig_width = bmimg.width()
orig_height = bmimg.height()
print(f"图像大小: {orig_width}x{orig_height}")
detections = self.detect(bmimg)
if len(detections) > 0:
detections = detections[0] # 取第一个batch的结果

32
main.py
View File

@ -1,6 +1,19 @@
import os
import argparse
from sophon_chakcy.get_detector import get_detector
from sophon_chakcy.utils.detect_video_stream import DetectVideoStream
from threading import Thread
def detector_task(args):
    """Run video-stream detection on ``args.input`` in a worker thread for 5 seconds.

    Builds the detector named by ``args.name``, wraps it in a
    ``DetectVideoStream`` and runs ``run_task`` on a background thread. After
    five seconds the stream is asked to stop and the thread is joined.

    Args:
        args: parsed command-line namespace; reads ``name``, ``input``,
            ``dev_id``, ``conf_thresh`` and ``nms_thresh``.
    """
    import time

    detector_class, model_config = get_detector(args.name, True)
    detector = detector_class(
        model_config,
        True,
        dev_id=args.dev_id,
        conf_thresh=args.conf_thresh,
        nms_thresh=args.nms_thresh,
    )
    detect_video_stream = DetectVideoStream([detector], args.input)
    worker = Thread(target=detect_video_stream.run_task)
    worker.start()
    try:
        time.sleep(5)
    finally:
        # Always stop and join, even if the sleep is interrupted (e.g. Ctrl+C);
        # otherwise the non-daemon worker keeps the process alive.
        detect_video_stream.stop_task()
        worker.join()
def main():
@ -12,22 +25,9 @@ def main():
parser.add_argument('--nms_thresh', type=float, default=0.7, help='NMS阈值')
args = parser.parse_args()
detector_class, model_config = get_detector(args.name)
# 初始化检测器
detector = detector_class(model_config, False, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
# 根据输入类型进行处理
if os.path.isfile(args.input):
# 单张图像
detector.detect_single_image(args.input)
detector_task(args)
elif os.path.isdir(args.input):
# 图像目录
detector.detect_images_in_directory(args.input)
else:
print(f"错误: {args.input} 既不是文件也不是目录")
if __name__ == "__main__":
main()
main()

View File

@ -2,27 +2,15 @@ import os
import argparse
from sophon_chakcy.get_detector import get_detector
from module_bank import PythonToSQLite
from multiprocessing import Process
# 安装导入器
packer = PythonToSQLite("plugins.mbank")
finder = packer.install_importer()
def main():
parser = argparse.ArgumentParser(description="YOLOv8目标检测")
parser.add_argument('--name', type=str, required=True, help='算法名称')
parser.add_argument('--input', type=str, required=True, help='输入图像或目录路径')
parser.add_argument('--dev_id', type=int, default=0, help='设备ID')
parser.add_argument('--conf_thresh', type=float, default=0.25, help='置信度阈值')
parser.add_argument('--nms_thresh', type=float, default=0.7, help='NMS阈值')
args = parser.parse_args()
def detector_task(args):
detector_class, model_config = get_detector(args.name, True)
# 初始化检测器
detector = detector_class(model_config, True, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
# 根据输入类型进行处理
if os.path.isfile(args.input):
# 单张图像
@ -34,5 +22,37 @@ def main():
else:
print(f"错误: {args.input} 既不是文件也不是目录")
def main():
    """Parse the command-line options and run ``detector_task`` in a child process.

    The parent process only parses arguments, starts the child and waits for
    it to finish.
    """
    parser = argparse.ArgumentParser(description="YOLOv8目标检测")

    # (flag, keyword arguments) pairs, registered in declaration order.
    option_specs = (
        ('--name', {"type": str, "required": True, "help": '算法名称'}),
        ('--input', {"type": str, "required": True, "help": '输入图像或目录路径'}),
        ('--dev_id', {"type": int, "default": 0, "help": '设备ID'}),
        ('--conf_thresh', {"type": float, "default": 0.25, "help": '置信度阈值'}),
        ('--nms_thresh', {"type": float, "default": 0.7, "help": 'NMS阈值'}),
    )
    for flag, spec in option_specs:
        parser.add_argument(flag, **spec)

    args = parser.parse_args()

    child = Process(target=detector_task, args=(args,))
    child.start()
    child.join()
# detector_class, model_config = get_detector(args.name, True)
# # 初始化检测器
# detector = detector_class(model_config, True, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
# # 根据输入类型进行处理
# if os.path.isfile(args.input):
# # 单张图像
# detector.detect_single_image(args.input)
# elif os.path.isdir(args.input):
# # 图像目录
# detector.detect_images_in_directory(args.input)
# else:
# print(f"错误: {args.input} 既不是文件也不是目录")
if __name__ == "__main__":
main()
main()

View File

@ -23,18 +23,3 @@ def get_config(json_path: str, is_vfs: bool = False, vfs_path: str="./applicatio
with open(json_path, 'r') as file:
data = json.load(file)
return data
def is_img(file_name):
    """Tell whether a file name carries a supported image extension.

    The comparison is case-insensitive and looks only at the final extension
    returned by ``os.path.splitext``.

    Args:
        file_name (str): input file name

    Returns:
        (bool) : True for .jpg, .jpeg, .png, .bmp and .webp names, else False
    """
    suffix = os.path.splitext(file_name)[-1]
    # splitext yields a non-str suffix for non-str paths; those are rejected.
    if not isinstance(suffix, str):
        return False
    return suffix.lower() in ('.jpg', '.png', '.jpeg', '.bmp', '.webp')

View File

View File

@ -0,0 +1,154 @@
import sophon_chakcy.sail as sail
import threading
import time
class DetectVideoStream:
    """Decode a video file or RTSP stream and run detectors on its frames.

    ``run_task`` is the blocking processing loop (suitable as a thread
    target); ``stop_task`` asks that loop to exit from another thread.
    """

    def __init__(self,
                 detectors,
                 video_path: str,
                 max_frames: int = 0,
                 frame_skip: int = 0,
                 reconnect_attempts: int = 3,
                 reconnect_delay: int = 5,
                 is_save: bool = False):
        """
        Open the decoder (and, when saving, the encoder) for a video source.

        Args:
            detectors: non-empty sequence of detectors; the first one supplies
                the device id and handle used for decoding
            video_path: video file path or RTSP URL
            max_frames: maximum number of frames to handle, 0 means no limit
            frame_skip: number of frames to skip between processed frames,
                0 means every frame is processed
            reconnect_attempts: reconnect attempts after a failed read; forced
                to 0 when video_path does not start with "rtsp"
            reconnect_delay: delay before each reconnect attempt (seconds)
            is_save: when True, the "results" directory is recreated and the
                output frames are encoded to "results/output.mp4"

        Raises:
            ValueError: if detectors is empty
        """
        if not detectors:
            raise ValueError("detectors must contain at least one detector")

        self.detectors = detectors
        self.dev_id = self.detectors[0].dev_id
        self.handle = self.detectors[0].handle
        self.video_path = video_path
        self.max_frames = max_frames
        self.frame_skip = frame_skip
        # Only RTSP sources are retried; a failed read on a file ends the run.
        if video_path.startswith("rtsp"):
            self.reconnect_attempts = reconnect_attempts
        else:
            self.reconnect_attempts = 0
        self.reconnect_delay = reconnect_delay
        self.decoder = sail.Decoder(video_path, True, self.dev_id)
        self.stop_event = threading.Event()
        self.is_save = is_save

        if self.is_save:
            import os
            import shutil
            from pathlib import Path

            result_dir = r"results"
            # Remove output left over from a previous run.
            if os.path.exists(result_dir):
                print(f"清理旧的构建目录: {result_dir}")
                shutil.rmtree(result_dir, ignore_errors=True)
            Path(result_dir).mkdir(parents=True, exist_ok=True)

            # NOTE(review): the encoder size is fixed at 1920x1080 — confirm
            # this matches the frames produced by the source.
            enc_params = "width=1920:height=1080:bitrate=2000:gop=32:gop_preset=2:framerate=25"
            self.encoder = sail.Encoder("results/output.mp4", self.dev_id, 'h264_bm', 'NV12', enc_params, 10)

    def _reconnect(self):
        """Replace the decoder with a freshly opened one; return True if it opened."""
        self.decoder.release()
        # Store the new decoder on self so that subsequent reads use it.
        self.decoder = sail.Decoder(self.video_path, True, self.dev_id)
        if self.decoder.is_opened():
            print("重连成功,继续处理...")
            return True
        print("重连失败")
        return False

    def _annotate(self, bmimg):
        """Run every detector on one frame; return (output image, boxes kept)."""
        result_img = bmimg
        kept_total = 0
        for detector in self.detectors:
            detections = detector.detect(bmimg)
            if len(detections) == 0:
                continue
            # Only the first batch entry is used; keep well-formed boxes whose
            # confidence exceeds this detector's threshold.
            kept = [
                list(det[:6])
                for det in detections[0]
                if len(det) >= 6 and det[4] > detector.conf_thresh
            ]
            if kept:
                # NOTE(review): boxes are drawn onto the running result so that
                # several detectors accumulate on one frame — confirm that
                # draw_detections accepts its own output as input.
                result_img = detector.draw_detections(result_img, kept, None)
                kept_total += len(kept)
        return result_img, kept_total

    def run_task(self):
        """
        Read frames until stopped, exhausted or limited, and run the detectors.

        Skipped frames are still written to the encoder when saving is
        enabled. Errors raised while processing are printed, not propagated.

        Returns:
            (frames handled, detections kept); (0, 0) if the source could not
            be opened
        """
        if not self.decoder.is_opened():
            print(f"错误: 无法打开视频流 {self.video_path}")
            return 0, 0

        frame_shape = self.decoder.get_frame_shape()
        print(f"视频帧形状: {frame_shape}")

        frame_count = 0
        total_detections = 0
        reconnect_count = 0

        try:
            while not self.stop_event.is_set():
                if self.max_frames > 0 and frame_count >= self.max_frames:
                    print(f"\n达到最大帧数限制: {self.max_frames}")
                    break

                # Read one frame.
                bmimg = sail.BMImage()
                ret = self.decoder.read(self.handle, bmimg)

                if ret != 0:
                    print(f"读取帧失败,错误码: {ret}")
                    reconnect_count += 1
                    if reconnect_count > self.reconnect_attempts:
                        print("视频流无法继续读取")
                        break

                    print(f"尝试重连 ({reconnect_count}/{self.reconnect_attempts})...")
                    # Wait on the stop event instead of sleeping so that
                    # stop_task() interrupts the delay.
                    if self.stop_event.wait(self.reconnect_delay):
                        break
                    self._reconnect()
                    # Never process the frame whose read failed.
                    continue

                # Reset on a successful read (not a successful open) so that
                # consecutive failures stay bounded by reconnect_attempts.
                reconnect_count = 0

                if self.frame_skip > 0 and frame_count % (self.frame_skip + 1) != 0:
                    frame_count += 1
                    if self.is_save:
                        self.encoder.video_write(bmimg)
                    continue

                # Process the current frame.
                frame_count += 1
                if frame_count % 10 == 0:
                    print(f"处理第 {frame_count} 帧...")

                result_img, kept = self._annotate(bmimg)
                total_detections += kept

                # One output frame per input frame.
                if self.is_save:
                    self.encoder.video_write(result_img)

        except KeyboardInterrupt:
            print("\n检测到键盘中断,停止处理")
        except Exception as e:
            print(f"处理视频流时发生错误: {e}")
        finally:
            # Release resources.
            self.decoder.release()
            if self.is_save:
                self.encoder.release()
            print("视频解码器已释放")

        return frame_count, total_detections

    def stop_task(self):
        """Ask a running ``run_task`` loop to exit at its next check."""
        print("提前关闭")
        self.stop_event.set()

View File