359 lines
13 KiB
Python
359 lines
13 KiB
Python
import os
|
||
import numpy as np
|
||
import sophon_chakcy.sail as sail
|
||
from .postprocess_numpy import PostProcess
|
||
from sqlite_vfs.core import SQLiteVFS
|
||
from dataclasses import dataclass, field
|
||
from logging import Logger, getLogger
|
||
from typing import Optional
|
||
|
||
|
||
@dataclass
class YOLOv8Config:
    """Static configuration consumed by ``Detector``.

    All three fields are required (none declares a default).

    Attributes:
        bmodel_path: location of the compiled bmodel; a filesystem path, or
            the file key read from the SQLite VFS when the detector is built
            with ``is_vfs=True``.
        classes: class-name list the detector indexes when logging results.
        colors: box colors the detector indexes when drawing results.
    """

    # Location of the bmodel to load.
    bmodel_path: str = field(metadata={"description": "模型路径"})
    # Class names, indexed by the detector.
    classes: list = field(metadata={"description": "检测类别"})
    # Box colors, indexed by the detector.
    colors: list = field(metadata={"description": "边框颜色"})

    @classmethod
    def from_dict(cls, data: dict) -> "YOLOv8Config":
        """Construct a config by unpacking ``data`` as keyword arguments.

        Raises:
            TypeError: when a required key is missing or an unknown key is
                present (raised by the generated ``__init__``).
        """
        init_kwargs = data
        return cls(**init_kwargs)
|
||
|
||
|
||
class Detector:
    """YOLOv8 detector running through ``sophon.sail``, with box visualization.

    The constructor loads a bmodel (from disk or from an SQLite VFS), caches
    the graph's input/output metadata, pre-allocates the output tensors and
    builds a ``PostProcess`` instance. ``detect`` runs one ``sail.BMImage``
    through preprocess -> inference -> postprocess; ``draw_detections``
    renders the resulting boxes.
    """

    def __init__(self,
                 handle: sail.Handle,
                 config_dict: dict,
                 is_vfs: bool = False,
                 vfs_path: str = "./application.svfs",
                 conf_thresh: float = 0.25,
                 nms_thresh: float = 0.7,
                 logger: Optional[Logger] = None):
        """Load the model and cache its I/O metadata.

        Args:
            handle: sail device handle shared by the engine, Bmcv and tensors.
            config_dict: keyword arguments for ``YOLOv8Config``
                (``bmodel_path``, ``classes``, ``colors``).
            is_vfs: when True, read ``bmodel_path`` from the SQLite VFS at
                ``vfs_path`` instead of from the filesystem.
            vfs_path: path of the VFS container; only used when ``is_vfs``.
            conf_thresh: confidence threshold given to ``PostProcess`` and
                re-applied (strictly greater-than) in ``detect``.
            nms_thresh: NMS threshold given to ``PostProcess``.
            logger: logger to use; defaults to this module's logger.
        """
        self.logger = logger if logger is not None else getLogger(__name__)
        self.handle = handle
        yolov8_config = YOLOv8Config.from_dict(config_dict)

        # Load the bmodel, then query the (first) graph name once.
        self.net = sail.Engine(self.handle)
        self._load_bmodel(yolov8_config.bmodel_path, is_vfs, vfs_path)
        self.graph_name = self.net.get_graph_names()[0]
        self.net.set_io_mode(self.graph_name, sail.IOMode.DEVIO)
        self.logger.info(f"加载模型: {yolov8_config.bmodel_path}")

        self.bmcv = sail.Bmcv(self.handle)

        # Input metadata: only the first model input is used.
        self.input_name = self.net.get_input_names(self.graph_name)[0]
        self.input_shape = self.net.get_input_shape(self.graph_name, self.input_name)
        self.input_dtype = self.net.get_input_dtype(self.graph_name, self.input_name)
        self.input_scale = self.net.get_input_scale(self.graph_name, self.input_name)

        # bm_image data format matching the model's input dtype.
        self.img_dtype = self.bmcv.get_bm_image_data_format(self.input_dtype)

        # One output tensor per model output, allocated once and reused by
        # every ``detect`` call.
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_tensors = {}
        for output_name in self.output_names:
            output_shape = self.net.get_output_shape(self.graph_name, output_name)
            output_dtype = self.net.get_output_dtype(self.graph_name, output_name)
            self.output_tensors[output_name] = sail.Tensor(
                self.handle, output_shape, output_dtype, True, True)

        # The input shape is indexed as (batch, channels, height, width).
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]
        self.net_w = self.input_shape[3]

        # Per-channel (alpha, beta) pairs for ``bmcv.convert_to``: multiply
        # every pixel by input_scale / 255 and add no offset.
        self.ab = [x * self.input_scale / 255. for x in [1, 0, 1, 0, 1, 0]]

        # Post-processing parameters.
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.postprocess = PostProcess(
            conf_thresh=self.conf_thresh,
            nms_thresh=self.nms_thresh,
            agnostic=False,
            multi_label=False,
            max_det=300,
        )

        self.colors = yolov8_config.colors
        self.classes = yolov8_config.classes

        self.logger.info("模型信息:")
        self.logger.info(f" 输入形状: {self.input_shape}")
        self.logger.info(f" 输入类型: {self.input_dtype}")
        self.logger.info(f" colors数量: {len(self.colors)}")
        self.logger.info(f" classes类别数: {len(self.classes)}")

    def _load_bmodel(self, bmodel_path, is_vfs, vfs_path):
        """Load ``bmodel_path`` into ``self.net`` from disk or from the SQLite VFS."""
        if is_vfs:
            vfs = SQLiteVFS(vfs_path)
            bmodel_bytes = vfs.read_file(bmodel_path)
            bmodel_size = vfs.get_file_info(bmodel_path)["file_size"]  # type: ignore
            self.net.load(bmodel_bytes, bmodel_size)
        else:
            self.net.load(bmodel_path)

    def _label_and_color(self, cls_id):
        """Return ``(class_name, color)`` for a model class id.

        Both lookups use index ``cls_id + 1``. NOTE(review): this presumably
        means entry 0 of ``classes``/``colors`` is reserved (e.g. a background
        label); confirm against the config. Colors wrap around with a modulo;
        a class index past the end of ``classes`` yields a placeholder name.
        """
        idx = int(cls_id) + 1
        color = self.colors[idx % len(self.colors)]
        if idx < len(self.classes):
            class_name = self.classes[idx]
        else:
            class_name = f"未知类别{int(cls_id)}"
        return class_name, color

    def draw_detections(self, image, detections, save_path=None):
        """Draw detection boxes on a BGR-planar conversion of ``image``.

        Args:
            image: ``sail.BMImage`` to visualize; boxes are drawn on the image
                returned by ``bmcv.convert_format(image)``.
            detections: sequence of ``[x1, y1, x2, y2, confidence, class_id]``
                rows; rows with fewer than 6 values are skipped.
            save_path: when truthy, the drawn image is written there via
                ``bmcv.imwrite``.

        Returns:
            The drawn image, or ``image`` itself when ``detections`` is empty
            (nothing is saved in that case).
        """
        if len(detections) == 0:
            self.logger.info("没有检测到任何物体")
            return image

        img_bgr_planar = self.bmcv.convert_format(image)
        thickness = 2

        self.logger.info(f"\n绘制 {len(detections)} 个检测结果:")
        self.logger.info("-" * 80)

        for i, det in enumerate(detections):
            if len(det) < 6:
                continue
            x1, y1, x2, y2, conf, cls_id = det[:6]
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            class_name, color = self._label_and_color(cls_id)

            self.logger.info(f"检测 {i+1:2d}: {class_name:15s} | 置信度: {conf:.4f} | 位置: [{x1:4d}, {y1:4d}, {x2:4d}, {y2:4d}] | 大小: {x2-x1:4d}x{y2-y1:4d} | 颜色: {color}")

            # Skip boxes too small to hold a rectangle of this line thickness.
            if (x2 - x1) > thickness * 2 and (y2 - y1) > thickness * 2:
                self.bmcv.rectangle(img_bgr_planar, x1, y1,  # type: ignore
                                    (x2 - x1), (y2 - y1),
                                    color, thickness)

        self.logger.info("-" * 80)

        if save_path:
            self.bmcv.imwrite(save_path, img_bgr_planar)  # type: ignore
            self.logger.info(f"\n检测结果已保存到: {save_path}")

        return img_bgr_planar

    def preprocess(self, bmimg: sail.BMImage):
        """Letterbox ``bmimg`` to the network input size and normalize it.

        The image is converted to RGB planar, resized with its aspect ratio
        preserved, centered on a canvas padded with value 114, and then
        scaled with ``self.ab`` into a ``self.img_dtype`` image.

        Args:
            bmimg: decoded input image.

        Returns:
            ``(preprocessed_img, r, tx1, ty1)``: the normalized BMImage, the
            resize ratio, and the left/top padding offsets in pixels.
        """
        # Convert to RGB planar.
        rgb_img = sail.BMImage(self.handle, bmimg.height(), bmimg.width(),  # type: ignore
                               sail.Format.FORMAT_RGB_PLANAR, sail.DATA_TYPE_EXT_1N_BYTE)  # type: ignore
        self.bmcv.convert_format(bmimg, rgb_img)

        # Letterbox geometry: uniform scale that fits the network input,
        # centered with symmetric padding.
        img_w = rgb_img.width()
        img_h = rgb_img.height()
        r = min(self.net_w / img_w, self.net_h / img_h)
        tw = int(round(r * img_w))
        th = int(round(r * img_h))
        tx1 = int(round((self.net_w - tw) / 2))
        ty1 = int(round((self.net_h - th) / 2))

        attr = sail.PaddingAtrr()
        attr.set_stx(tx1)
        attr.set_sty(ty1)
        attr.set_w(tw)
        attr.set_h(th)
        attr.set_r(114)
        attr.set_g(114)
        attr.set_b(114)

        resized_img = self.bmcv.crop_and_resize_padding(rgb_img, 0, 0, img_w, img_h,
                                                        self.net_w, self.net_h, attr)

        # Normalize into an image of the model's input data format.
        preprocessed_img = sail.BMImage(self.handle, self.net_h, self.net_w,  # type: ignore
                                        sail.Format.FORMAT_RGB_PLANAR, self.img_dtype)  # type: ignore
        self.bmcv.convert_to(resized_img, preprocessed_img,
                             ((self.ab[0], self.ab[1]),
                              (self.ab[2], self.ab[3]),
                              (self.ab[4], self.ab[5])))

        return preprocessed_img, r, tx1, ty1

    def detect(self, bmimg: sail.BMImage):
        """Run preprocess, inference and postprocess on one image.

        Only the first model output and its first batch item are
        post-processed.

        Returns:
            List of ``[x1, y1, x2, y2, conf, cls_id]`` rows whose confidence
            is strictly greater than ``self.conf_thresh``. NOTE(review):
            coordinates are whatever ``PostProcess`` produces, presumably in
            original-image pixels; confirm against PostProcess.
        """
        orig_width = bmimg.width()
        orig_height = bmimg.height()
        self.logger.info(f"图像大小: {orig_width}x{orig_height}")

        preprocessed_img, r, tx1, ty1 = self.preprocess(bmimg)

        input_tensor = sail.Tensor(self.handle, self.input_shape, self.input_dtype, False, False)
        self.bmcv.bm_image_to_tensor(preprocessed_img, input_tensor)

        input_tensors = {self.input_name: input_tensor}
        input_shapes = {self.input_name: self.input_shape}
        self.net.process(self.graph_name, input_tensors, input_shapes, self.output_tensors)

        # Copy only the output actually used to host memory.
        pred = self.output_tensors[self.output_names[0]].asnumpy()[0]
        pred_with_batch = pred[np.newaxis, :, :]  # restore a batch axis of size 1

        detections = self.postprocess([pred_with_batch],
                                      [(orig_width, orig_height)],
                                      [(r, r)],
                                      [(tx1, ty1)])

        # PostProcess returns one result per batch item; keep the first.
        detections = detections[0] if len(detections) > 0 else []

        filtered_detections = [
            list(det[:6])
            for det in detections
            if len(det) >= 6 and det[4] > self.conf_thresh
        ]

        self.logger.info(f"原始检测数量: {len(detections)}")
        self.logger.info(f"过滤后数量 (置信度 > {self.conf_thresh}): {len(filtered_detections)}")

        return filtered_detections

    def detect_single_image(self, image_path, output_dir="./detection_results"):
        """Decode one image file, detect objects and save an annotated copy.

        The annotated image is written to ``output_dir/detected_<basename>``
        only when at least one object is detected.

        Args:
            image_path: path of the image file.
            output_dir: directory for annotated output (created if missing).

        Returns:
            The filtered detections, or ``[]`` when decoding fails or nothing
            is detected.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.logger.info(f"\n处理图像: {image_path}")

        decoder = sail.Decoder(image_path, True)
        bmimg = sail.BMImage()
        ret = decoder.read(self.handle, bmimg)
        if ret != 0:
            self.logger.error(f"错误: 无法解码图像 {image_path}")
            return []

        # ``detect`` logs the image size itself.
        filtered_detections = self.detect(bmimg)
        if len(filtered_detections) == 0:
            self.logger.info("没有检测到任何物体")
            return []

        output_path = os.path.join(output_dir, f"detected_{os.path.basename(image_path)}")
        self.draw_detections(bmimg, filtered_detections, output_path)
        return filtered_detections

    def detect_images_in_directory(self, input_dir, output_dir="./detection_results",
                                   extensions=None):
        """Run ``detect_single_image`` on every file under ``input_dir`` (recursive).

        NOTE: outputs are written flat into ``output_dir`` as
        ``detected_<basename>``, so files with the same name in different
        subdirectories overwrite each other.

        Args:
            input_dir: directory to walk.
            output_dir: directory for annotated output.
            extensions: optional iterable of file extensions including the
                leading dot (e.g. ``(".jpg", ".png")``), matched
                case-insensitively. When None, every file is processed.
        """
        if not os.path.isdir(input_dir):
            self.logger.error(f"错误: 目录不存在 {input_dir}")
            return

        os.makedirs(output_dir, exist_ok=True)

        self.logger.info(f"开始检测目录: {input_dir}")
        self.logger.info(f"结果保存到: {output_dir}")

        allowed = None if extensions is None else {ext.lower() for ext in extensions}
        total_images = 0
        total_detections = 0

        for root, _dirs, files in os.walk(input_dir):
            for file in files:
                if allowed is not None and os.path.splitext(file)[1].lower() not in allowed:
                    continue

                image_path = os.path.join(root, file)
                total_images += 1

                self.logger.info(f"\n{'='*80}")
                self.logger.info(f"处理图像 {total_images}: {file}")

                detections = self.detect_single_image(image_path, output_dir)
                total_detections += len(detections)

        self.logger.info(f"\n{'='*80}")
        self.logger.info("检测完成!")
        self.logger.info(f"总处理图像数: {total_images}")
        self.logger.info(f"总检测物体数: {total_detections}")
        self.logger.info(f"平均每张图像检测数: {total_detections/max(total_images, 1):.2f}")
|