保姆级教程:手把手教你修改YOLOv8源码,集成DeepSORT并输出带类别和置信度的跟踪结果

张开发
2026/4/17 18:12:10 15 分钟阅读

分享文章

保姆级教程:手把手教你修改YOLOv8源码,集成DeepSORT并输出带类别和置信度的跟踪结果
从零实现YOLOv8与DeepSORT深度整合输出带类别和置信度的多目标跟踪系统在计算机视觉领域目标跟踪技术正从单纯的边界框追踪向更丰富的语义信息表达演进。本文将带您深入YOLOv8框架内部通过源码级改造实现与DeepSORT算法的无缝集成最终输出包含目标类别标签和检测置信度的完整跟踪结果。不同于简单的API调用我们将从底层修改检测与跟踪的数据流确保每个跟踪目标都携带完整的语义信息。1. 环境准备与架构设计1.1 基础环境配置确保已安装以下组件Python 3.8PyTorch 1.12 (建议使用CUDA 11.3版本)Ultralytics YOLOv8最新版OpenCV 4.5# 创建conda环境可选 conda create -n yolo_deepsort python3.8 conda activate yolo_deepsort # 安装核心依赖 pip install torch torchvision --extra-index-url https://download.pytorch.org/whl/cu113 pip install ultralytics opencv-python1.2 系统架构设计我们的改造方案采用分层设计思想YOLOv8检测层 → 检测结果格式化 → DeepSORT跟踪层 → 结果增强输出 ↑ (类别、置信度注入)关键修改点集中在三个层面检测结果格式化将YOLOv8输出的原始检测信息转换为DeepSORT可处理的格式跟踪数据增强修改DeepSORT内部数据结构以保存类别和置信度结果可视化定制输出界面同时显示跟踪ID、类别和置信度2. DeepSORT核心模块改造2.1 Detection类改造首先修改detection.py增加类别标签存储能力class Detection(object): def __init__(self, tlwh, confidence, feature, label): self.tlwh np.asarray(tlwh, dtypenp.float32) # 格式转换 self.confidence float(confidence) # 原始置信度 self.feature np.asarray(feature, dtypenp.float32) # 特征向量 self.label int(label) # 新增类别标签 self.confs confidence # 新增置信度副本2.2 Tracker类改造在tracker.py中扩展跟踪目标属性class Track: def __init__(self, mean, covariance, track_id, label, confs): # ...原有卡尔曼滤波初始化代码... self.label label # 新增类别标签 self.confs confs # 新增置信度记录 def update(self, kf, detection): # ...原有更新逻辑... self.label detection.label # 同步更新类别 self.confs detection.confs # 同步更新置信度2.3 输出接口改造修改deep_sort.py中的输出格式def update(self, detections): # ...原有跟踪逻辑... outputs [] for track in tracks: bbox track.to_tlbr() outputs.append(np.array([ bbox[0], bbox[1], bbox[2], bbox[3], # 边界框坐标 track.label, # 类别标签 track.track_id, # 跟踪ID int(track.confs * 100) # 置信度(放大100倍存储) ], dtypenp.int32)) return outputs3. YOLOv8检测集成方案3.1 多模型协同检测通过创建VideoTracker类整合多个YOLOv8模型class VideoTracker: def __init__(self, track_cfg, predictors): self.predictors predictors # 多个检测器实例 self.deepsort build_tracker(deepsort_cfg) # 初始化DeepSORT def image_track(self, img): # 并行执行多个检测 det_results [predictor(img)[0] for predictor in self.predictors] # 合并检测结果 bbox_xywh torch.cat([res.boxes.xywh for res in det_results]) confs torch.cat([res.boxes.conf for res in det_results]) cls torch.cat([res.boxes.cls for res in det_results]) # 送入DeepSORT return self.deepsort.update(bbox_xywh.cpu(), confs.cpu(), img, cls.cpu())3.2 配置文件设计创建track.yaml配置文件input_path: test.mp4 save_option: save: True root: runs/track txt: True img: True class_name: 0: person 1: car 2: bicycle4. 可视化与输出增强4.1 带语义信息的可视化改进可视化函数同时显示三类信息def plot_track(self, img, outputs): for box in outputs: x1, y1, x2, y2, label, track_id, confidence box color self.colors[label] # 绘制边界框 cv2.rectangle(img, (x1,y1), (x2,y2), color, 2) # 左上角显示类别和置信度 label_text f{self.class_names[label]}:{confidence/100:.2f} cv2.putText(img, label_text, (x1, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2) # 右上角显示跟踪ID id_text fID:{track_id} cv2.putText(img, id_text, (x2-50, y1-10), cv2.FONT_HERSHEY_SIMPLEX, 0.6, color, 2) return img4.2 结构化数据输出提供多种格式的结果保存def save_results(self, frame_id, outputs): # 保存原始跟踪数据 np.savetxt(ftrack_{frame_id:05d}.txt, outputs, fmt%d) # 保存归一化坐标(适用于不同分辨率) normalized outputs.copy() h, w self.img_size normalized[:,[0,2]] / w # 归一化x坐标 normalized[:,[1,3]] / h # 归一化y坐标 np.savetxt(fnorm_{frame_id:05d}.txt, normalized, fmt%.6f)5. 性能优化技巧5.1 推理加速方案通过以下手段提升系统实时性优化手段实现方式预期收益半精度推理model.half()提升40%推理速度TensorRT部署转换ONNX后优化提升2-3倍速度多线程预处理使用Queue预加载减少20%延迟5.2 跟踪参数调优DeepSORT关键参数经验值DEEPSORT: MAX_DIST: 0.2 # 特征匹配阈值 MIN_CONFIDENCE: 0.3 # 检测置信度过滤 MAX_IOU_DISTANCE: 0.7 # IoU关联阈值 MAX_AGE: 30 # 目标丢失保持帧数 N_INIT: 3 # 新目标确认帧数6. 实际应用案例6.1 智能零售场景在货架监控中系统可同时跟踪顾客(人)和商品(物)并记录它们的交互过程。通过分析跟踪数据我们可以得到顾客停留热点区域商品被拿取的频率顾客-商品关联关系6.2 交通监控系统改造后的系统在交通场景中表现出色准确区分车辆类型轿车、卡车、自行车记录每辆车的行驶轨迹和速度统计不同类别车辆的通过数量# 交通场景特殊处理 if label truck: roi expand_roi(bbox) # 扩大卡车检测区域7. 常见问题解决方案7.1 ID切换问题当发生严重遮挡时可能出现ID切换。改进方案运动一致性检查通过卡尔曼滤波预测位置与实际检测偏差过大时暂不更新ID特征相似度验证比较历史特征与当前特征的余弦距离时空约束限制同一ID在相邻帧中的最大移动距离7.2 类别混淆处理对于易混淆类别如猫/狗建议提高分类模型的置信度阈值添加后处理规则如尺寸过滤使用时序一致性校验连续n帧确认8. 扩展与进阶8.1 自定义特征提取器替换默认的ReID模型from torchvision.models import resnet50 class CustomExtractor: def __init__(self): self.model resnet50(pretrainedTrue) self.model.fc nn.Identity() # 移除分类层 def __call__(self, img): with torch.no_grad(): return self.model(img)8.2 多模态跟踪融合其他传感器数据def update(self, detections, radar_dataNone): if radar_data: # 雷达辅助验证 detections filter_by_radar(detections, radar_data) return super().update(detections)经过上述改造您的YOLOv8DeepSORT系统将具备完整的语义跟踪能力。在实际部署中建议先在小规模数据上验证各模块的稳定性再逐步扩大应用范围。这种深度整合方案相比直接使用官方API在灵活性和信息丰富度上都有显著提升。

更多文章