modify: 实现视频流推理
This commit is contained in:
parent
7880c9e15b
commit
2ff9a2ecb1
@ -1,11 +1,14 @@
|
||||
# yolov8_inference_with_colors.py
|
||||
import os
|
||||
|
||||
import time
|
||||
import numpy as np
|
||||
import sophon_chakcy.sail as sail
|
||||
from .postprocess_numpy import PostProcess
|
||||
from sqlite_vfs.core import SQLiteVFS
|
||||
from dataclasses import dataclass, field
|
||||
from typing import Optional
|
||||
import threading
|
||||
|
||||
|
||||
@dataclass
|
||||
class YOLOv8Config:
|
||||
@ -149,37 +152,17 @@ class Detector:
|
||||
|
||||
return img_bgr_planar
|
||||
|
||||
def detect_single_image(self, image_path, output_dir="./detection_results"):
|
||||
def preprocess(self, bmimg: sail.BMImage):
|
||||
"""
|
||||
检测单张图像
|
||||
|
||||
预处理图像
|
||||
|
||||
Args:
|
||||
image_path: 图像路径
|
||||
output_dir: 输出目录
|
||||
|
||||
image: 输入图像 (HWC, BGR格式)
|
||||
|
||||
Returns:
|
||||
检测结果列表
|
||||
(预处理后的BMImage, 原始尺寸, 缩放比例, 填充偏移)
|
||||
"""
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"\n处理图像: {image_path}")
|
||||
|
||||
# 1. 解码图像
|
||||
decoder = sail.Decoder(image_path, True, self.dev_id)
|
||||
bmimg = sail.BMImage()
|
||||
ret = decoder.read(self.handle, bmimg)
|
||||
|
||||
if ret != 0:
|
||||
print(f"错误: 无法解码图像 {image_path}")
|
||||
return []
|
||||
|
||||
orig_width = bmimg.width()
|
||||
orig_height = bmimg.height()
|
||||
print(f"图像大小: {orig_width}x{orig_height}")
|
||||
|
||||
# 2. 预处理
|
||||
# 转换为RGB格式
|
||||
# 转换为RGB
|
||||
rgb_img = sail.BMImage(self.handle, bmimg.height(), bmimg.width(),
|
||||
sail.Format.FORMAT_RGB_PLANAR, sail.DATA_TYPE_EXT_1N_BYTE)
|
||||
self.bmcv.convert_format(bmimg, rgb_img)
|
||||
@ -217,6 +200,16 @@ class Detector:
|
||||
(self.ab[2], self.ab[3]),
|
||||
(self.ab[4], self.ab[5])))
|
||||
|
||||
return preprocessed_img, r, tx1, ty1
|
||||
|
||||
def detect(self, bmimg: sail.BMImage):
|
||||
orig_width = bmimg.width()
|
||||
orig_height = bmimg.height()
|
||||
print(f"图像大小: {orig_width}x{orig_height}")
|
||||
|
||||
# 2. 预处理
|
||||
preprocessed_img, r, tx1, ty1 = self.preprocess(bmimg)
|
||||
|
||||
# 3. 准备输入张量
|
||||
input_tensor = sail.Tensor(self.handle, self.input_shape, self.input_dtype, False, False)
|
||||
self.bmcv.bm_image_to_tensor(preprocessed_img, input_tensor)
|
||||
@ -245,6 +238,38 @@ class Detector:
|
||||
|
||||
# 执行后处理
|
||||
detections = self.postprocess([pred_with_batch], org_size_list, ratios_batch, txy_batch)
|
||||
return detections
|
||||
|
||||
def detect_single_image(self, image_path, output_dir="./detection_results"):
|
||||
"""
|
||||
检测单张图像
|
||||
|
||||
Args:
|
||||
image_path: 图像路径
|
||||
output_dir: 输出目录
|
||||
|
||||
Returns:
|
||||
检测结果列表
|
||||
"""
|
||||
# 创建输出目录
|
||||
os.makedirs(output_dir, exist_ok=True)
|
||||
|
||||
print(f"\n处理图像: {image_path}")
|
||||
|
||||
# 1. 解码图像
|
||||
decoder = sail.Decoder(image_path, True, self.dev_id)
|
||||
bmimg = sail.BMImage()
|
||||
ret = decoder.read(self.handle, bmimg)
|
||||
|
||||
if ret != 0:
|
||||
print(f"错误: 无法解码图像 {image_path}")
|
||||
return []
|
||||
|
||||
orig_width = bmimg.width()
|
||||
orig_height = bmimg.height()
|
||||
print(f"图像大小: {orig_width}x{orig_height}")
|
||||
|
||||
detections = self.detect(bmimg)
|
||||
|
||||
if len(detections) > 0:
|
||||
detections = detections[0] # 取第一个batch的结果
|
||||
|
||||
32
main.py
32
main.py
@ -1,6 +1,19 @@
|
||||
import os
|
||||
import argparse
|
||||
from sophon_chakcy.get_detector import get_detector
|
||||
from sophon_chakcy.utils.detect_video_stream import DetectVideoStream
|
||||
from threading import Thread
|
||||
|
||||
|
||||
def detector_task(args):
|
||||
detector_class, model_config = get_detector(args.name, True)
|
||||
detector = detector_class(model_config, True, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
|
||||
detect_video_stream = DetectVideoStream([detector], args.input)
|
||||
t = Thread(target=detect_video_stream.run_task)
|
||||
t.start()
|
||||
import time
|
||||
time.sleep(5)
|
||||
detect_video_stream.stop_task()
|
||||
t.join()
|
||||
|
||||
|
||||
def main():
|
||||
@ -12,22 +25,9 @@ def main():
|
||||
parser.add_argument('--nms_thresh', type=float, default=0.7, help='NMS阈值')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
detector_class, model_config = get_detector(args.name)
|
||||
|
||||
# 初始化检测器
|
||||
detector = detector_class(model_config, False, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
|
||||
|
||||
# 根据输入类型进行处理
|
||||
if os.path.isfile(args.input):
|
||||
# 单张图像
|
||||
detector.detect_single_image(args.input)
|
||||
detector_task(args)
|
||||
|
||||
elif os.path.isdir(args.input):
|
||||
# 图像目录
|
||||
detector.detect_images_in_directory(args.input)
|
||||
else:
|
||||
print(f"错误: {args.input} 既不是文件也不是目录")
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -2,27 +2,15 @@ import os
|
||||
import argparse
|
||||
from sophon_chakcy.get_detector import get_detector
|
||||
from module_bank import PythonToSQLite
|
||||
from multiprocessing import Process
|
||||
|
||||
# 安装导入器
|
||||
packer = PythonToSQLite("plugins.mbank")
|
||||
finder = packer.install_importer()
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="YOLOv8目标检测")
|
||||
parser.add_argument('--name', type=str, required=True, help='算法名称')
|
||||
parser.add_argument('--input', type=str, required=True, help='输入图像或目录路径')
|
||||
parser.add_argument('--dev_id', type=int, default=0, help='设备ID')
|
||||
parser.add_argument('--conf_thresh', type=float, default=0.25, help='置信度阈值')
|
||||
parser.add_argument('--nms_thresh', type=float, default=0.7, help='NMS阈值')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
def detector_task(args):
|
||||
detector_class, model_config = get_detector(args.name, True)
|
||||
|
||||
# 初始化检测器
|
||||
detector = detector_class(model_config, True, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
|
||||
|
||||
# 根据输入类型进行处理
|
||||
if os.path.isfile(args.input):
|
||||
# 单张图像
|
||||
@ -34,5 +22,37 @@ def main():
|
||||
else:
|
||||
print(f"错误: {args.input} 既不是文件也不是目录")
|
||||
|
||||
|
||||
def main():
|
||||
parser = argparse.ArgumentParser(description="YOLOv8目标检测")
|
||||
parser.add_argument('--name', type=str, required=True, help='算法名称')
|
||||
parser.add_argument('--input', type=str, required=True, help='输入图像或目录路径')
|
||||
parser.add_argument('--dev_id', type=int, default=0, help='设备ID')
|
||||
parser.add_argument('--conf_thresh', type=float, default=0.25, help='置信度阈值')
|
||||
parser.add_argument('--nms_thresh', type=float, default=0.7, help='NMS阈值')
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
p = Process(target=detector_task,args=(args,))
|
||||
p.start()
|
||||
p.join()
|
||||
|
||||
# detector_class, model_config = get_detector(args.name, True)
|
||||
|
||||
# # 初始化检测器
|
||||
# detector = detector_class(model_config, True, dev_id=args.dev_id, conf_thresh=args.conf_thresh, nms_thresh=args.nms_thresh)
|
||||
|
||||
# # 根据输入类型进行处理
|
||||
# if os.path.isfile(args.input):
|
||||
# # 单张图像
|
||||
# detector.detect_single_image(args.input)
|
||||
|
||||
# elif os.path.isdir(args.input):
|
||||
# # 图像目录
|
||||
# detector.detect_images_in_directory(args.input)
|
||||
# else:
|
||||
# print(f"错误: {args.input} 既不是文件也不是目录")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
main()
|
||||
|
||||
@ -23,18 +23,3 @@ def get_config(json_path: str, is_vfs: bool = False, vfs_path: str="./applicatio
|
||||
with open(json_path, 'r') as file:
|
||||
data = json.load(file)
|
||||
return data
|
||||
|
||||
|
||||
def is_img(file_name):
|
||||
"""judge the file is available image or not
|
||||
Args:
|
||||
file_name (str): input file name
|
||||
Returns:
|
||||
(bool) : whether the file is available image
|
||||
"""
|
||||
fmt = os.path.splitext(file_name)[-1]
|
||||
if isinstance(fmt, str) and fmt.lower() in ['.jpg','.png','.jpeg','.bmp','.jpeg','.webp']:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
0
src/sophon_chakcy/utils/__init__.py
Normal file
0
src/sophon_chakcy/utils/__init__.py
Normal file
154
src/sophon_chakcy/utils/detect_video_stream.py
Normal file
154
src/sophon_chakcy/utils/detect_video_stream.py
Normal file
@ -0,0 +1,154 @@
|
||||
import sophon_chakcy.sail as sail
|
||||
import threading
|
||||
import time
|
||||
|
||||
|
||||
class DetectVideoStream:
|
||||
def __init__(self,
|
||||
detectors,
|
||||
video_path: str,
|
||||
max_frames: int = 0,
|
||||
frame_skip: int = 0,
|
||||
reconnect_attempts: int = 3,
|
||||
reconnect_delay: int = 5,
|
||||
is_save: bool = False):
|
||||
"""
|
||||
处理视频流(视频文件或RTSP流)
|
||||
|
||||
Args:
|
||||
detector: 检测器
|
||||
video_path: 视频文件路径或 RTSP URL
|
||||
max_frames: 最大处理帧数, 0表示处理所有帧
|
||||
frame_skip: 跳帧处理, 0表示处理每一帧
|
||||
reconnect_attempts: RTSP断流重连尝试次数
|
||||
reconnect_delay: 重连延迟时间 (秒)
|
||||
|
||||
Returns:
|
||||
处理的帧数, 检测到的物体总数
|
||||
"""
|
||||
self.detectors = detectors
|
||||
self.dev_id = self.detectors[0].dev_id
|
||||
self.handle = self.detectors[0].handle
|
||||
self.video_path = video_path
|
||||
self.max_frames = max_frames
|
||||
self.frame_skip = frame_skip
|
||||
if not video_path.startswith("rtsp"):
|
||||
self.reconnect_attempts = 0
|
||||
else:
|
||||
self.reconnect_attempts = reconnect_attempts
|
||||
self.reconnect_delay = reconnect_delay
|
||||
self.decoder = sail.Decoder(video_path, True, self.dev_id)
|
||||
self.stop_event = threading.Event()
|
||||
self.is_save = is_save
|
||||
if self.is_save:
|
||||
import os
|
||||
from pathlib import Path
|
||||
import shutil
|
||||
|
||||
result_dir = r"results"
|
||||
# 清理旧的目录
|
||||
if os.path.exists(result_dir):
|
||||
print(f"清理旧的构建目录: {result_dir}")
|
||||
shutil.rmtree(result_dir, ignore_errors=True)
|
||||
|
||||
# 创建目录
|
||||
result_dir_path = Path(result_dir)
|
||||
if not result_dir_path.exists():
|
||||
result_dir_path.mkdir(parents=True, exist_ok=True)
|
||||
|
||||
enc_params = f"width=1920:height=1080:bitrate=2000:gop=32:gop_preset=2:framerate=25"
|
||||
self.encoder = sail.Encoder("results/output.mp4", self.dev_id, 'h264_bm', 'NV12', enc_params, 10)
|
||||
|
||||
def run_task(self):
|
||||
if not self.decoder.is_opened():
|
||||
print(f"错误: 无法打开视频流 {self.video_path}")
|
||||
return 0, 0
|
||||
|
||||
frame_shape = self.decoder.get_frame_shape()
|
||||
print(f"视频帧形状: {frame_shape}")
|
||||
|
||||
frame_count = 0
|
||||
reconnect_count = 0
|
||||
|
||||
try:
|
||||
while not self.stop_event.is_set():
|
||||
if self.max_frames > 0 and frame_count >= self.max_frames:
|
||||
print(f"\n达到最大帧数限制: {self.max_frames}")
|
||||
break
|
||||
|
||||
# 读取一帧
|
||||
bmimg = sail.BMImage()
|
||||
ret = self.decoder.read(self.handle, bmimg)
|
||||
if ret != 0:
|
||||
print(f"读取帧失败,错误码: {ret}")
|
||||
reconnect_count += 1
|
||||
|
||||
if reconnect_count > self.reconnect_attempts:
|
||||
print("视频流无法继续读取")
|
||||
break
|
||||
|
||||
print(f"尝试重连 ({reconnect_count}/{self.reconnect_attempts})...")
|
||||
time.sleep(self.reconnect_delay)
|
||||
self.decoder.release()
|
||||
decoder = sail.Decoder(self.video_path, True, self.dev_id)
|
||||
if decoder.is_opened():
|
||||
reconnect_count = 0
|
||||
print("重连成功,继续处理...")
|
||||
continue
|
||||
else:
|
||||
print("重连失败")
|
||||
|
||||
if self.frame_skip > 0 and frame_count % (self.frame_skip + 1) != 0:
|
||||
frame_count += 1
|
||||
if self.is_save:
|
||||
self.encoder.video_write(bmimg)
|
||||
continue
|
||||
|
||||
# 开始处理当前帧
|
||||
start_time = time.time()
|
||||
frame_count += 1
|
||||
|
||||
if frame_count % 10 == 0:
|
||||
print(f"处理第 {frame_count} 帧...")
|
||||
|
||||
# 检测
|
||||
for detector in self.detectors:
|
||||
detections = detector.detect(bmimg)
|
||||
|
||||
if len(detections) > 0:
|
||||
detections = detections[0]
|
||||
|
||||
filtered_detections = []
|
||||
for det in detections:
|
||||
if len(det) >= 6:
|
||||
x1, y1, x2, y2, conf, cls_id = det[:6]
|
||||
if conf > detector.conf_thresh:
|
||||
filtered_detections.append([x1, y1, x2, y2, conf, cls_id])
|
||||
|
||||
if len(filtered_detections) > 0:
|
||||
result_img = detector.draw_detections(bmimg, filtered_detections, None)
|
||||
else:
|
||||
result_img = bmimg
|
||||
else:
|
||||
result_img = bmimg
|
||||
|
||||
if self.is_save:
|
||||
self.encoder.video_write(result_img)
|
||||
|
||||
except KeyboardInterrupt:
|
||||
print("\n检测到键盘中断,停止处理")
|
||||
except Exception as e:
|
||||
print(f"处理视频流时发生错误: {e}")
|
||||
|
||||
finally:
|
||||
# 清理资源
|
||||
self.decoder.release()
|
||||
if self.is_save:
|
||||
self.encoder.release()
|
||||
print("视频解码器已释放")
|
||||
|
||||
def stop_task(self):
|
||||
print("提前关闭")
|
||||
self.stop_event.set()
|
||||
|
||||
|
||||
0
src/sophon_chakcy/utils/task_manger.py
Normal file
0
src/sophon_chakcy/utils/task_manger.py
Normal file
Loading…
x
Reference in New Issue
Block a user