sophon_ai_project/YOLOv8/yolov8_detector.py
2026-02-09 20:47:46 +08:00

359 lines
13 KiB
Python
Raw Permalink Blame History

This file contains ambiguous Unicode characters

This file contains Unicode characters that might be confused with other characters. If you think that this is intentional, you can safely ignore this warning. Use the Escape button to reveal them.

import os
import numpy as np
import sophon_chakcy.sail as sail
from .postprocess_numpy import PostProcess
from sqlite_vfs.core import SQLiteVFS
from dataclasses import dataclass, field
from logging import Logger, getLogger
from typing import Optional
@dataclass
class YOLOv8Config:
    """Static configuration of a YOLOv8 detector.

    All three fields are required; there are no defaults.
    """

    # Path of the compiled bmodel (a regular file path, or a path inside the VFS).
    bmodel_path: str = field(metadata=dict(description="模型路径"))
    # Class-name table used to label detections.
    classes: list = field(metadata=dict(description="检测类别"))
    # Colour table used to draw bounding boxes.
    colors: list = field(metadata=dict(description="边框颜色"))

    @classmethod
    def from_dict(cls, data: dict) -> "YOLOv8Config":
        """Build a config by passing every key of ``data`` as a keyword argument.

        Raises:
            TypeError: if ``data`` lacks a field or contains an unknown key.
        """
        instance = cls(**data)
        return instance
class Detector:
    """YOLOv8 detector running through SOPHON SAIL, with result visualisation.

    Pipeline: BMCV letterbox preprocessing -> bmodel inference ->
    ``PostProcess`` (NumPy NMS) -> optional box drawing and saving.
    """

    def __init__(self,
                 handle: sail.Handle,
                 config_dict: dict,
                 is_vfs=False,
                 vfs_path: str = "./application.svfs",
                 conf_thresh=0.25,
                 nms_thresh=0.7,
                 logger: Optional[Logger] = None):
        """Load the bmodel and cache its input/output metadata.

        Args:
            handle: SAIL handle shared by the engine, the tensors and BMCV.
            config_dict: Mapping accepted by ``YOLOv8Config.from_dict``
                (keys ``bmodel_path``, ``classes``, ``colors``).
            is_vfs: If True, read the bmodel bytes from the SQLite VFS at
                ``vfs_path`` instead of loading ``bmodel_path`` from disk.
            vfs_path: VFS container path (only used when ``is_vfs``).
            conf_thresh: Confidence threshold for post-processing and for
                the extra filtering done in ``detect``.
            nms_thresh: IoU threshold passed to NMS.
            logger: Logger to use; defaults to the root logger.
        """
        self.logger = logger if logger is not None else getLogger()
        self.handle = handle
        yolov8_config = YOLOv8Config.from_dict(config_dict)

        # Load the bmodel either from the VFS container or from the filesystem.
        self.net = sail.Engine(self.handle)
        if is_vfs:
            # NOTE(review): the SQLiteVFS object is never explicitly closed here;
            # check whether its API requires that.
            vfs = SQLiteVFS(vfs_path)
            bmodel_bytes = vfs.read_file(yolov8_config.bmodel_path)
            bmodel_size = vfs.get_file_info(yolov8_config.bmodel_path)["file_size"]  # type: ignore
            self.net.load(bmodel_bytes, bmodel_size)
        else:
            self.net.load(yolov8_config.bmodel_path)
        self.graph_name = self.net.get_graph_names()[0]
        # DEVIO: the engine leaves inputs and outputs in device memory, so
        # detect() must copy the output back to the host before reading it.
        self.net.set_io_mode(self.graph_name, sail.IOMode.DEVIO)
        self.logger.info(f"加载模型: {yolov8_config.bmodel_path}")
        self.bmcv = sail.Bmcv(self.handle)

        # Input metadata (the first input of the first graph).
        self.input_name = self.net.get_input_names(self.graph_name)[0]
        self.input_shape = self.net.get_input_shape(self.graph_name, self.input_name)
        self.input_dtype = self.net.get_input_dtype(self.graph_name, self.input_name)
        self.input_scale = self.net.get_input_scale(self.graph_name, self.input_name)
        # bm_image data format matching the network input dtype.
        self.img_dtype = self.bmcv.get_bm_image_data_format(self.input_dtype)

        # Output tensors are allocated once (host + device buffers) and reused
        # for every inference.
        self.output_names = self.net.get_output_names(self.graph_name)
        self.output_tensors = {}
        for output_name in self.output_names:
            output_shape = self.net.get_output_shape(self.graph_name, output_name)
            output_dtype = self.net.get_output_dtype(self.graph_name, output_name)
            self.output_tensors[output_name] = sail.Tensor(
                self.handle, output_shape, output_dtype, True, True)

        # Input shape is indexed as (N, C, H, W): batch at 0, height at 2, width at 3.
        self.batch_size = self.input_shape[0]
        self.net_h = self.input_shape[2]
        self.net_w = self.input_shape[3]

        # Per-channel (alpha, beta) pairs for bmcv.convert_to:
        # alpha = input_scale / 255, beta = 0 for each of the 3 channels.
        self.ab = [x * self.input_scale / 255. for x in [1, 0, 1, 0, 1, 0]]

        # Post-processing parameters.
        self.conf_thresh = conf_thresh
        self.nms_thresh = nms_thresh
        self.postprocess = PostProcess(
            conf_thresh=self.conf_thresh,
            nms_thresh=self.nms_thresh,
            agnostic=False,
            multi_label=False,
            max_det=300
        )
        self.colors = yolov8_config.colors
        self.classes = yolov8_config.classes
        self.logger.info("模型信息:")
        self.logger.info(f" 输入形状: {self.input_shape}")
        self.logger.info(f" 输入类型: {self.input_dtype}")
        self.logger.info(f" colors数量: {len(self.colors)}")
        self.logger.info(f" classes类别数: {len(self.classes)}")

    def draw_detections(self, image, detections, save_path=None):
        """Draw detection boxes on a BGR-planar copy of ``image``.

        Args:
            image: ``sail.BMImage`` the detections were computed on.
            detections: Sequence of ``[x1, y1, x2, y2, confidence, class_id]``.
                Extra trailing items are ignored; entries shorter than 6 are skipped.
            save_path: If truthy, the drawn image is written to this path.

        Returns:
            The BGR-planar image with boxes drawn, or ``image`` itself when
            ``detections`` is empty.
        """
        # len() rather than truthiness: detections may be a NumPy array.
        if len(detections) == 0:
            self.logger.info("没有检测到任何物体")
            return image
        # Convert to BGR planar for drawing.
        img_bgr_planar = self.bmcv.convert_format(image)
        thickness = 2
        self.logger.info(f"\n绘制 {len(detections)} 个检测结果:")
        self.logger.info("-" * 80)
        for i, det in enumerate(detections):
            if len(det) < 6:
                continue
            x1, y1, x2, y2, conf, cls_id = det[:6]
            x1, y1, x2, y2 = int(x1), int(y1), int(x2), int(y2)
            # The class and colour tables are looked up at class_id + 1. The
            # configured lists are therefore expected to reserve index 0
            # (e.g. for a background entry). NOTE(review): confirm against configs.
            label_idx = int(cls_id) + 1
            # Modulo keeps the colour lookup in range. With no configured
            # colours, use a fixed fallback instead of dividing by zero.
            color = self.colors[label_idx % len(self.colors)] if self.colors else (0, 0, 255)
            # Lower bound stops negative indices from wrapping to the list end.
            if 0 <= label_idx < len(self.classes):
                class_name = self.classes[label_idx]
            else:
                class_name = f"未知类别{int(cls_id)}"
            self.logger.info(f"检测 {i+1:2d}: {class_name:15s} | 置信度: {conf:.4f} | 位置: [{x1:4d}, {y1:4d}, {x2:4d}, {y2:4d}] | 大小: {x2-x1:4d}x{y2-y1:4d} | 颜色: {color}")
            # Only draw boxes larger than the border on both axes.
            if (x2 - x1) > thickness * 2 and (y2 - y1) > thickness * 2:
                self.bmcv.rectangle(img_bgr_planar, x1, y1,  # type: ignore
                                    (x2 - x1), (y2 - y1),
                                    color, thickness)
        self.logger.info("-" * 80)
        if save_path:
            self.bmcv.imwrite(save_path, img_bgr_planar)  # type: ignore
            self.logger.info(f"\n检测结果已保存到: {save_path}")
        return img_bgr_planar

    def preprocess(self, bmimg: sail.BMImage):
        """Letterbox ``bmimg`` to the network input size and normalise it.

        The image is converted to RGB planar and resized with its aspect
        ratio preserved. It is then centred on a (114, 114, 114) canvas of
        ``net_w x net_h`` and linearly scaled with ``self.ab`` into a bm_image
        of format ``self.img_dtype``.

        Args:
            bmimg: Decoded input image.

        Returns:
            Tuple ``(preprocessed_img, r, tx1, ty1)``:
            - ``preprocessed_img``: the network-ready ``sail.BMImage``;
            - ``r``: the resize ratio applied to both axes;
            - ``tx1``, ``ty1``: left/top padding offsets in pixels.
        """
        # Convert to RGB planar.
        rgb_img = sail.BMImage(self.handle, bmimg.height(), bmimg.width(),  # type: ignore
                               sail.Format.FORMAT_RGB_PLANAR, sail.DATA_TYPE_EXT_1N_BYTE)  # type: ignore
        self.bmcv.convert_format(bmimg, rgb_img)
        # Aspect-preserving scale and centring offsets.
        img_w = rgb_img.width()
        img_h = rgb_img.height()
        r = min(self.net_w / img_w, self.net_h / img_h)
        tw = int(round(r * img_w))
        th = int(round(r * img_h))
        tx1 = int(round((self.net_w - tw) / 2))
        ty1 = int(round((self.net_h - th) / 2))
        # Placement of the resized image inside the padded canvas.
        attr = sail.PaddingAtrr()
        attr.set_stx(tx1)
        attr.set_sty(ty1)
        attr.set_w(tw)
        attr.set_h(th)
        attr.set_r(114)
        attr.set_g(114)
        attr.set_b(114)
        resized_img = self.bmcv.crop_and_resize_padding(rgb_img, 0, 0, img_w, img_h,
                                                        self.net_w, self.net_h, attr)
        # Normalise into the network input data format.
        preprocessed_img = sail.BMImage(self.handle, self.net_h, self.net_w,  # type: ignore
                                        sail.Format.FORMAT_RGB_PLANAR, self.img_dtype)  # type: ignore
        self.bmcv.convert_to(resized_img, preprocessed_img,
                             ((self.ab[0], self.ab[1]),
                              (self.ab[2], self.ab[3]),
                              (self.ab[4], self.ab[5])))
        return preprocessed_img, r, tx1, ty1

    def detect(self, bmimg: sail.BMImage):
        """Run preprocessing, inference and post-processing on one image.

        Args:
            bmimg: Decoded input image.

        Returns:
            List of ``[x1, y1, x2, y2, confidence, class_id]`` whose confidence
            is strictly greater than ``conf_thresh``. Boxes are as returned by
            ``PostProcess`` given the original size, ratio and padding offsets.
            Presumably these are in original-image coordinates (confirm in
            PostProcess).
        """
        orig_width = bmimg.width()
        orig_height = bmimg.height()
        self.logger.info(f"图像大小: {orig_width}x{orig_height}")

        preprocessed_img, r, tx1, ty1 = self.preprocess(bmimg)

        # The tensor is created without its own host/device buffers (False, False).
        # bm_image_to_tensor fills it from the preprocessed image.
        input_tensor = sail.Tensor(self.handle, self.input_shape, self.input_dtype, False, False)
        self.bmcv.bm_image_to_tensor(preprocessed_img, input_tensor)

        input_tensors = {self.input_name: input_tensor}
        input_shapes = {self.input_name: self.input_shape}
        self.net.process(self.graph_name, input_tensors, input_shapes, self.output_tensors)

        # Only the first output is consumed. In DEVIO mode the engine does not
        # copy results back to the host, so sync explicitly before asnumpy().
        output_tensor = self.output_tensors[self.output_names[0]]
        output_tensor.sync_d2s()
        # First image of the batch, keeping its batch axis.
        pred_with_batch = output_tensor.asnumpy()[:1]

        org_size_list = [(orig_width, orig_height)]
        ratios_batch = [(r, r)]
        txy_batch = [(tx1, ty1)]
        detections = self.postprocess([pred_with_batch], org_size_list, ratios_batch, txy_batch)
        # Results of the first (only) image.
        detections = detections[0] if len(detections) > 0 else []

        filtered_detections = [
            list(det[:6])
            for det in detections
            if len(det) >= 6 and det[4] > self.conf_thresh
        ]
        self.logger.info(f"原始检测数量: {len(detections)}")
        self.logger.info(f"过滤后数量 (置信度 > {self.conf_thresh}): {len(filtered_detections)}")
        return filtered_detections

    def detect_single_image(self, image_path, output_dir="./detection_results"):
        """Decode one image file, run detection, and save an annotated copy.

        The annotated image is written to ``output_dir/detected_<basename>``
        only when something is detected.

        Args:
            image_path: Path of the image file.
            output_dir: Directory for the annotated image (created if missing).

        Returns:
            The filtered detections (see ``detect``). Returns ``[]`` when
            decoding fails or nothing is detected.
        """
        os.makedirs(output_dir, exist_ok=True)
        self.logger.info(f"\n处理图像: {image_path}")
        decoder = sail.Decoder(image_path, True)
        bmimg = sail.BMImage()
        ret = decoder.read(self.handle, bmimg)
        if ret != 0:
            self.logger.error(f"错误: 无法解码图像 {image_path}")
            return []
        # detect() logs the image size itself.
        filtered_detections = self.detect(bmimg)
        if len(filtered_detections) == 0:
            self.logger.info("没有检测到任何物体")
            return []
        output_path = os.path.join(output_dir, f"detected_{os.path.basename(image_path)}")
        self.draw_detections(bmimg, filtered_detections, output_path)
        return filtered_detections

    def detect_images_in_directory(self, input_dir, output_dir="./detection_results"):
        """Run ``detect_single_image`` on every file under ``input_dir``.

        Every file found by ``os.walk`` is attempted. Files that fail to decode
        are logged and counted with zero detections. Annotated images are
        written flat into ``output_dir``, so equal basenames in different
        sub-directories overwrite each other.

        Args:
            input_dir: Directory to scan recursively.
            output_dir: Directory for annotated images (created if missing).
        """
        if not os.path.exists(input_dir):
            self.logger.error(f"错误: 目录不存在 {input_dir}")
            return
        os.makedirs(output_dir, exist_ok=True)
        self.logger.info(f"开始检测目录: {input_dir}")
        self.logger.info(f"结果保存到: {output_dir}")

        # If output_dir lies inside input_dir, skip it. Otherwise the walk would
        # descend into it and re-detect the "detected_*" results written earlier.
        output_real = os.path.realpath(output_dir)
        total_images = 0
        total_detections = 0
        for root, dirs, files in os.walk(input_dir):
            # In-place pruning works because os.walk is top-down by default.
            dirs[:] = [d for d in dirs
                       if os.path.realpath(os.path.join(root, d)) != output_real]
            for file in files:
                image_path = os.path.join(root, file)
                total_images += 1
                self.logger.info(f"\n{'='*80}")
                self.logger.info(f"处理图像 {total_images}: {file}")
                detections = self.detect_single_image(image_path, output_dir)
                total_detections += len(detections)

        self.logger.info(f"\n{'='*80}")
        self.logger.info("检测完成!")
        self.logger.info(f"总处理图像数: {total_images}")
        self.logger.info(f"总检测物体数: {total_detections}")
        self.logger.info(f"平均每张图像检测数: {total_detections/max(total_images, 1):.2f}")