Pytorch训练及导出部署全流程

张开发
2026/4/17 22:57:24 15 分钟阅读

分享文章

Pytorch训练及导出部署全流程
本文主要描述训练全流程跟各部分的作用同时对训练好的模型进行导出部署如想深究原理部分推荐观看吴恩达老师的深度学习课程。本文只列出关键代码相关yolov8的训练部署全流程在姊妹篇可看完整的模型代码在GitHub可观看。本篇内容比较多大家可配合目录选择性观看。一、模型训练1.数据集的划分与清洗为了方便代码调用避免出现路径报错等问题直接将数据集文件夹与代码文件配置在同一根目录下后续调用时只需填写相对路径即可无需繁琐的绝对路径配置。数据集的标注工作我是使用labelImg工具独立完成的具体标注流程如下首先需要配置类别名称文件明确模型需要识别的目标类别同时配置label文件夹用于存放标注生成的xml文件之后打开labelImg工具并导入需要标注的数据集文件夹使用键盘W键即可框选目标区域框选完成后从下拉菜单中选择对应的类别名称确认无误后按CtrlS键保存标注结果按D键可快速切换至下一张图片继续标注。若标注过程中出现错误选中错误的框选区域按delete键即可删除若有遗漏的目标重新按W键框选目标、选择类别并保存即可整个标注过程简单易操作适合新手上手。注意对应好标签跟图片分别放在两个文件夹中且label文件需要转换为txt方能使用。标注完成后需要配置数据类以适配模型训练具体需继承PyTorch中的Dataset类自定义三个核心方法__init__方法用于初始化主要实现图片路径与对应标注文件路径的拼接确保每张图片都能精准匹配到其标注信息__getitem__方法用于按索引获取单张图片及其对应的标注数据方便后续批量加载__len__方法用于返回数据集的总样本数供DataLoader计算批次等参数使用。配置 transforms用于对数据集进行清洗和数据增强操作可根据数据集的实际情况添加对应配置比如图片尺寸统一、归一化、随机翻转、随机裁剪、亮度/对比度调整等既能统一输入模型的数据格式又能增加数据集的多样性提升模型的泛化能力避免过拟合。transform transforms.Compose([ transforms.Resize((640, 640)), # 维度匹配尺寸对齐 transforms.ToTensor(), # HWC → CHW 转张量 transforms.Normalize(mean, std) # 数据缩放归一化 ])如果在标注时没有分割数据集可使用torch.utils.data.random_split方法对处理好的数据集进行训练集和验证集的划分划分比例可根据数据集规模灵活调整通常采用8:2或7:3的比例其中训练集用于模型的训练拟合验证集用于实时监控模型的训练效果及时发现过拟合或欠拟合问题。train_dataset, val_dataset torch.utils.data.random_split( dataset, # 要划分的完整数据集 lengths, # 划分长度列表/比例如[8000,2000]或[0.8,0.2] generatortorch.Generator().manual_seed(42) # 随机种子保证可复现 )配置DataLoader对划分好的训练集和验证集分别进行打包处理主要配置参数包括批次大小batch_size、是否打乱数据shuffle、是否使用多线程加载num_workers等。通过DataLoader可实现数据的批量加载减少内存占用同时打乱数据能避免模型学习到数据的顺序规律提升模型的泛化能力多线程加载则能加快数据读取速度提升训练效率。dataloader torch.utils.data.DataLoader( dataset, # 加载的数据集 batch_size1, # 批次大小 shuffleFalse, # 是否打乱数据 num_workers0, # 加载线程数 collate_fnNone, # 数据合并函数 pin_memoryFalse, # 锁页内存加速GPU传输 drop_lastFalse # 是否丢弃最后一个不完整批次 )2.神经网络的搭建神经网络的搭建有两种常用方式传统的搭建方法是创建model类并继承nn.Model在类的__init__方法内部搭建网络结构根据任务需求依次配置各层组件再定义forward方法实现网络的前向传播明确数据在网络中的流动路径完成特征提取和输出预测。这里不再详细展开每一层的具体代码配置重点介绍各个常用网络层的核心作用方便大家理解网络结构的设计逻辑。Flatten拉直层核心作用是将卷积层、池化层输出的多维特征图通常为四维格式为[batch_size, channels, height, width]拉直为一维向量便于后续输入全连接层进行分类或预测是连接卷积特征提取与全连接分类的关键层。Dense全连接层也叫线性层nn.Linear核心作用是对输入的一维特征向量进行加权求和与偏置调整将提取到的高维特征映射到目标类别空间常用于模型的最终分类或回归预测常与sigmoid、softmax等激活函数混用实现非线性分类提升分类精度。现在一般可用全局平均池化进行代替。Conv2D卷积层是图像识别模型中核心的特征提取层通过设置不同大小的卷积核、步长stride和填充padding对输入图片进行卷积运算逐步提取图片的底层特征如边缘、纹理卷积核的数量决定了提取特征的丰富度步长决定了卷积运算的步幅填充则用于避免卷积后特征图尺寸缩小过多。卷积层还有别的几种这只是最通用的可根据需求切换。如空洞卷积在不丢失的情况下扩大感受野。Activation激活层核心作用是为神经网络加入非线性因素打破线性映射的局限性让模型能够学习到数据中的复杂规律和非线性关系常用的激活函数有ReLU、sigmoid、tanh等其中ReLU函数因计算高效、能缓解梯度消失问题被广泛应用于卷积神经网络中。Maxpool池化层也叫最大池化层核心作用是对卷积层输出的特征图进行下采样压缩数据尺寸通常使图像高宽减半减少模型的计算量和参数数量同时保留特征图中的关键特征避免冗余信息干扰进一步提升模型的泛化能力。另有平均池化层一般在图像模型中不用。Dropout随机失活层训练过程中会随机关掉部分神经元按设定的概率核心作用是防止模型过拟合避免模型过度依赖某几个神经元的输出迫使模型学习到更具通用性的特征提升模型在未见过的数据上的预测能力测试时会恢复所有神经元的工作。需注意在导出模型参数和进行模型评估时需关闭否则参数不完整。BatchNorm批归一化层核心作用是对每一批次的输入数据进行标准化处理将数据调整到合适的分布范围通常均值为0、方差为1减少梯度消失或梯度爆炸的风险让模型的训练过程更稳定、收敛速度更快同时还能缓解过拟合问题。Add残差连接主要用于深层神经网络中核心作用是将输入的特征与网络某一层的输出特征进行叠加形成残差结构能够有效缓解深层网络中的梯度消失问题让模型能够训练更深的网络层数提升特征提取能力和模型性能。在transform架构中常见。在实际操作中我没有选择从零搭建网络而是直接加载官方预训练模型和预训练权重进行微调fine-tune这种方式既能节省训练时间又能利用预训练模型学到的通用特征提升模型的训练效果。并且这种方法官方封装功能比较完善根据个人的任务需求需要重点对模型的输出类别通过yaml.py进行重新配置与自己标注的数据集类别数量一致同时调整输入分辨率与自己的数据集图片尺寸匹配若自己的数据集图像大小与原本模型的默认输入大小相差过大还需要调整锚框anchor尺寸确保模型能够精准识别目标在封装好的yolov8中会自动调用K均值聚类。3.训练方法配置损失函数Loss Function损失函数是衡量模型预测结果与真实标签之间误差的核心指标其选择需根据具体任务类型确定若为分类任务常用交叉熵损失函数若为回归任务常用均方误差损失函数若为目标检测任务常用置信度损失、坐标损失和类别损失的组合损失函数。损失函数的输出值越小说明模型的预测结果越接近真实值训练效果越好。配置优化器Optimizer优化器的核心作用是根据损失函数计算的梯度更新模型的参数最小化损失函数常用的优化器有SGD、Adam、RMSprop等其中Adam优化器因收敛速度快、适应性强被广泛应用于深度学习模型训练中。配置优化器时需要设置学习率learning rate、权重衰减weight decay等超参数学习率决定了参数更新的步幅过大容易导致训练不收敛过小则会导致训练速度过慢权重衰减则用于抑制过拟合。配置学习调度器Learning Rate Scheduler该配置为可选项核心作用是根据训练轮数动态调整学习率比如训练前期使用较大的学习率加快收敛训练后期使用较小的学习率精细调整参数避免模型在最优解附近震荡。如果识别目标较小不推荐开启。# 损失函数 criterion torch.nn.CrossEntropyLoss() # 分类任务 criterion torch.nn.MSELoss() # 回归任务 # 优化器 optimizer torch.optim.Adam( model.parameters(), # 模型参数 lr0.001, # 学习率 weight_decay0.0001 # 权重衰减 ) # 学习率调度器 scheduler torch.optim.lr_scheduler.StepLR( optimizer, # 优化器 step_size10, # 每多少轮衰减 gamma0.1 # 衰减系数 )4.训练过程首先定义训练轮数epoch和批次大小batch_size训练轮数是指整个数据集被模型训练的次数批次大小则是指每次输入模型的样本数量两者需根据数据集规模、硬件配置GPU内存大小灵活调整数据集规模较大时可适当增加训练轮数减小批次大小硬件配置较强时可增大批次大小提升训练速度。注可将参数等送至GPU进行数据处理增加效率。训练过程的核心流程分为四个步骤循环执行直至完成所有训练轮数第一步是前向传播将训练集的批量数据输入模型通过模型的网络结构进行特征提取和预测得到模型的预测结果第二步是计算损失使用配置好的损失函数对比模型的预测结果与真实标签计算出当前批次的损失值第三步是反向传播根据损失值计算模型各参数的梯度明确参数需要调整的方向和幅度第四步是参数更新通过配置好的优化器根据梯度信息更新模型的所有可训练参数逐步减小损失值。训练过程中需要注意模式切换训练时需开启模型的训练模式model.train()此时模型中的Dropout层、BatchNorm层会按照训练逻辑工作验证和导出模型参数时需切换为模型的评估模式model.eval()此时Dropout层会关闭BatchNorm层会使用训练过程中积累的均值和方差避免影响验证结果的准确性。训练过程中建议实时打印关键指标包括每一批次的训练损失train_loss、每一轮的验证损失val_loss、验证集的准确率accuracy、精确率precision、召回率recall等通过这些指标的变化可以直观判断模型的训练状态若训练损失和验证损失均持续下降说明模型在正常收敛若训练损失持续下降但验证损失开始上升说明模型出现过拟合若两者均持续居高不下说明模型可能欠拟合或超参数设置不合理需及时调整。打印图像类信息我推荐使用tensorboard进行配置可通过命令行生成网站地址在网站上进行查看不仅能看训练过程也能查看数据清洗后的结果。tensorboard --logdir./runs --port6006 --host0.0.0.0from torch.utils.tensorboard import SummaryWriter # 初始化 writer SummaryWriter(log_dir./runs/exp) # 记录标量loss/acc/lr writer.add_scalar(train/loss, loss, epoch) writer.add_scalar(train/acc, acc, epoch) # 记录模型结构图 writer.add_graph(model, input_tensor) # 关闭 writer.close()5.模型评估模型评估的核心目的是检验模型的泛化能力判断模型是否达到预期的训练效果评估过程需使用独立的验证集未参与模型训练的数据避免使用训练集评估导致结果失真。评估时可通过with语句首先将模型切换为评估模式固定参数防止Dropout随机丢失影响模型精度然后将验证集数据通过DataLoader批量输入模型得到模型的预测结果之后计算各类评估指标对于分类任务主要计算精确率预测为正类且实际为正类的样本占所有预测为正类样本的比例、召回率实际为正类且预测为正类的样本占所有实际为正类样本的比例、准确率预测正确的样本占总样本的比例、F1分数精确率和召回率的调和平均数等对于目标检测任务主要计算mAP平均精度均值、IOU交并比等指标其中mAP是衡量目标检测模型性能的核心指标数值越高模型的检测效果越好。为了更直观地观察模型的训练状态建议绘制损失曲线和准确率曲线以训练轮数为横轴损失值、准确率为纵轴通过曲线可以清晰看到训练过程中损失和准确率的变化趋势便于及时发现模型训练中的问题如过拟合、欠拟合。根据评估结果对模型进行优化调整若模型过拟合可通过增加Dropout概率、增大权重衰减、增加数据增强力度或扩大数据集若模型欠拟合可增加训练轮数、增大学习率、增加网络层数或神经元数量若某类目标的识别精度较低可针对性增加该类目标的样本数量优化标注质量。解决的前提一定要先定位好位置不要盲目进行调整。6.保存模型模型训练完成并通过评估后需要将模型保存下来方便后续的加载、测试和部署保存时需注意区分模型文件的类型避免后续使用时出现兼容问题。导出时建议配置自动生成导出路径。可用模型名加数据集命名清晰明了。常用的模型保存方式有两种第一种是保存完整的模型结构权重参数这种方式保存的文件包含了模型的网络结构和所有训练好的参数加载时无需重新定义网络结构直接加载即可使用缺点是文件体积较大一般不推荐第二种是只保存模型的权重参数state_dict这种方式保存的文件体积较小通用性更强后续进行onnx导出也更方便不受网络结构定义方式的影响加载时需要先重新定义好网络结构再将权重参数加载到网络中。# 方法1保存完整模型不推荐 torch.save(model, model_full.pth) # 方法2仅保存模型参数推荐 torch.save(model.state_dict(), model_weights.pth)# 加载模型参数推荐用法 model.load_state_dict(torch.load(model_weights.pth, map_locationdevice))7.性能测试模型保存完成后需要进行全面的性能测试检验模型的最终能力确保模型能够满足实际部署需求性能测试主要围绕推理速度、识别精度、稳定性等方面展开。首先使用独立的测试集未参与训练和评估的数据对模型进行测试统计模型的推理速度包括单张图片的平均推理时间、批量图片的推理速度推理速度直接影响模型的部署体验尤其是实时检测场景需要确保推理速度满足要求。记录性能测试的所有结果包括推理速度、识别精度、各场景下的识别表现等形成性能报告为后续的模型导出和部署提供参考若性能不满足需求需返回模型训练环节进行优化调整。这项需要根据模型的实际用途进行选择性测试看重准确率还是实时性等等。二、模型导出1. 添加logger进行迅速定位模型导出过程中可能会出现各种异常问题如环境配置错误、模型加载失败、导出格式异常等为了快速定位问题所在需要添加logger日志记录工具配置日志的输出级别如INFO、ERROR、输出格式、输出路径等。在最开始可添加屏蔽警告类信息减少日志生成内容以便更好地定位。import warnings warnings.filterwarnings(ignore)通过logger可以实时记录导出过程中的每一步操作包括环境检查结果、模型加载状态、导出进度、异常信息等当导出出现错误时无需逐行排查代码只需查看日志文件即可快速找到错误原因如某条依赖库版本不兼容、GPU内存不足等大幅提升问题排查效率节省时间。2. 添加模型注册表并对库的版本进行检查添加模型注册表Model Registry主要用于统一管理模型的注册、加载和导出避免因模型名称、结构不统一导致导出失败同时便于后续扩展多种模型的导出功能实现一键导出不同类型的模型。注意导入新的模型时要在模型注册表中进行添加若为模型神经网络需进行包的导入。模型导出对相关依赖库的版本要求较高不同版本的库可能存在兼容性问题导致导出失败或导出后的模型无法正常使用因此需要对核心依赖库的版本进行检查包括PyTorch、ONNX、ONNXRuntime、OpenCV等确保各库的版本符合导出要求若版本不兼容需及时升级或降级同时将版本检查结果记录到日志中便于后续追溯。环境要求需谨慎配置这里推荐Anaconda可对环境进行分割。详情可看我的另一篇作品。3. 定义基础命令行传参为了提升模型导出的灵活性和通用性避免每次导出都需要修改代码需要定义基础的命令行传参通过命令行输入参数的方式配置模型导出的相关设置无需修改代码即可完成不同参数的导出操作。python train.py --batch_size 64 --lr 0.001 --epochs 100 --device cuda:0如果不知道要传什么可用以下命令进行显示python train.py -h常用的命令行参数包括模型权重文件路径、导出模型的保存路径、模型的输入尺寸、是否使用GPU导出、动态维度配置开关、精度对齐检验开关等每个参数都设置默认值同时添加参数说明方便用户根据自身需求灵活调整提升导出操作的便捷性。4. 检查导出环境模型导出前需要全面检查导出环境确保环境满足导出要求避免因环境问题导致导出失败主要检查内容包括操作系统是否兼容如Windows、Linux、GPU驱动是否安装正常、CUDA版本与PyTorch版本是否匹配、相关依赖库是否完整安装等。device torch.device(cuda if torch.cuda.is_available() else cpu)若使用GPU导出需检查GPU是否可用是否能被PyTorch识别同时检查GPU内存是否充足若使用CPU导出需确保CPU性能满足导出需求避免因CPU性能不足导致导出过程卡顿或失败。检查完成后将环境检查结果通过logger记录若存在环境问题及时提示用户进行修复。5. 查看GPU内存占用模型导出过程中会占用一定的GPU内存尤其是大型模型若GPU内存不足会导致导出失败因此需要在导出前查看GPU的内存占用情况确保有足够的空闲内存用于模型导出。可通过torch.cuda.memory_allocated()、torch.cuda.memory_reserved()等方法查看当前GPU的内存占用情况若空闲内存不足可关闭其他占用GPU内存的程序释放内存或调整模型的输入尺寸、批量大小等参数减少内存占用确保导出过程顺利进行。同时将GPU内存占用情况记录到日志中便于后续排查内存相关问题。每次查看后可顺手清理缓存释放内存。6. 加载模型及其权重模型导出前需要先加载训练好的模型及其权重参数加载过程需与模型保存的方式对应若保存的是完整模型直接通过torch.load()方法加载即可若保存的是权重参数需先重新定义好网络结构再通过model.load_state_dict(torch.load())方法将权重参数加载到模型中。这里调用的模型须在模型注册表内声明一般推荐第二种方式将网络结构拉过来比较清晰明了通用性强。7. 虚拟输入配置模型导出尤其是导出为ONNX格式时需要构造与真实输入尺寸一致的虚拟张量dummy input用于模型的tracing或scripting明确模型的输入维度和数据类型避免导出后的模型输入格式不匹配。虚拟输入的尺寸需与模型的实际输入层尺寸一致包括批次大小、通道数、高度、宽度等数据类型需与训练时的输入数据类型一致通常为float32同时将虚拟输入转移到对应的设备CPU/GPU上。构造完成后可将虚拟输入输入模型验证模型的输出维度是否符合预期确保虚拟输入配置正确。8. 动态维度配置实际部署过程中输入模型的图片尺寸可能会不固定因此需要配置动态输入维度让导出后的模型能够适配不同分辨率的图片推理提升模型的通用性避免因输入尺寸固定导致部署时无法处理不同尺寸的图片。配置动态维度时需指定动态变化的维度通常为批次大小batch_size、图片高度height、图片宽度width设置动态维度后导出的模型可以接受不同批次、不同尺寸的输入数据。配置完成后需验证动态维度是否生效可输入不同尺寸的虚拟输入查看模型的输出是否正常。一般是输入输出的batch都要一致开启高宽一般只有输入开启如果不开启动态维度匹配每次只能读一张照片。配置后在推理引擎配置阶段开启使用即模型部署。9. 模型简化和ONNX模型导出ONNXOpen Neural Network Exchange是一种跨平台、跨框架的模型格式能够实现不同深度学习框架如PyTorch、TensorFlow之间的模型互转便于模型的部署和推理因此这里主要导出为ONNX格式模型。导出前建议使用onnx-simplifier工具对模型进行简化去除模型中的冗余节点、无用操作减少模型文件体积提升模型的推理速度同时避免因冗余节点导致导出后的模型出现兼容性问题。模型简化完成后通过torch.onnx.export()方法进行ONNX模型导出导出时需指定模型、虚拟输入、导出路径、动态维度配置、输入输出节点名称等参数确保导出的ONNX模型格式正确导出过程中logger会实时记录导出进度若导出成功会提示模型保存路径若导出失败会记录错误信息便于用户排查问题。导出完成后可查看模型文件的大小和格式确认导出无误。如果模型比较复杂需注意导出后进行计算图检查看关键层是否缺失自定义算子是否按对应结构导出。# 导出 ONNX 模型 torch.onnx.export( model, # 待导出的模型 dummy_input, # 示例输入张量 model.onnx, # 保存路径 input_names[input], # 输入节点名称 output_names[output],# 输出节点名称 # 动态维度设置 dynamic_axes{input: {0: batch}, output: {0: batch}}, opset_version12, # ONNX opset 版本 do_constant_foldingTrue# 开启常量折叠优化 ) # 加载 ONNX 模型 onnx_model onnx.load(model.onnx) # 简化模型 simplified_model, success simplify(onnx_model) # 保存简化后的模型 onnx.save(simplified_model, model_simplified.onnx)10. 精度对齐检验模型导出为ONNX格式后可能会出现精度丢失的问题导致模型的预测结果与原始PyTorch模型的预测结果偏差较大因此需要进行精度对齐检验确保导出后的ONNX模型精度符合要求。精度对齐检验的核心方法是分别使用原始PyTorch模型和导出的ONNX模型输入相同的测试数据得到两组预测结果然后计算两组结果的误差如均方误差、绝对误差若误差在可接受范围通常为1e-5以内说明模型精度没有丢失导出合格若误差过大说明导出过程中出现问题需检查导出参数、虚拟输入配置、模型简化等环节重新导出。检验过程中需记录两组预测结果和误差值形成精度检验报告便于后续追溯同时确保导出的模型能够保持原始模型的识别精度。11. 性能测试ONNX模型导出完成后需要进行性能测试检验导出后模型的推理速度、内存占用等性能指标确保模型能够满足部署需求同时对比原始PyTorch模型的性能确认导出过程没有导致性能损耗。这个跟训练阶段按需求可配置一样的主要是验证在导出后模型性能方面的改变。性能测试主要包括测试ONNX模型的推理速度单张图片平均推理时间、批量图片推理速度、GPU/CPU内存占用情况同时测试模型在不同输入尺寸、不同批次大小下的性能表现记录各项性能指标。将ONNX模型的性能指标与原始PyTorch模型进行对比若推理速度提升、内存占用降低说明模型导出和简化效果良好若推理速度下降、内存占用增加需检查模型简化过程或导出参数进行优化调整。同时测试模型的稳定性连续输入多组数据确保模型能够稳定运行无异常报错。12. 主函数主函数的核心作用是整合模型导出的全流程将上述所有步骤日志配置、版本检查、命令行传参、环境检查、模型加载、虚拟输入配置、动态维度配置、模型简化、ONNX导出、精度对齐、性能测试串联起来实现一键导出功能用户只需运行主函数输入对应的命令行参数即可完成整个模型导出过程无需逐行执行代码提升操作便捷性。推荐配置__main__ 魔术方法限制运行。主函数中需要加入异常捕获机制try-except当导出过程中出现异常如模型加载失败、内存不足、导出格式错误等时能够及时捕获异常信息通过logger记录错误详情并给出相应的错误提示避免程序崩溃同时便于用户排查问题。导出完成后主函数会输出导出成功的提示信息包括模型保存路径、精度检验结果、性能测试结果等方便用户快速了解导出情况同时将所有导出相关的日志信息保存到日志文件中便于后续追溯和问题排查。三、模型部署1. 基础参数模型部署有通用的推理引擎ONNX Runtime 和专门适配英伟达显卡的推理引擎Tensor RT其中TRT比较复杂一些但性能更加强悍缺点就是不够通用。另有在移动端部署的TensorFlow Lite。模型部署前需要配置一系列基础参数确保部署过程顺利进行同时适配实际的部署场景基础参数的配置需结合模型的特点和部署需求主要包括以下几类一是检测相关参数配置置信度阈值confidence threshold用于过滤置信度过低的预测框避免误检测通常设置为0.5左右可根据实际需求调整配置NMS非极大值抑制阈值用于去除重复、冗余的检测框保留置信度最高、最准确的目标框通常设置为0.3-0.5之间配置模型的输入尺寸需与模型训练、导出时的输入尺寸一致确保输入数据格式匹配配置类别名称列表与数据集标注的类别名称一致用于后续目标类别标注。特别注意模型内部的输入输出名称二者需一致推荐在导出时顺便打印记录。需注意输出头的信息一般是[batch,class4,j检测框数]这一步是为了后续后处理画框做准备。如果输出头信息有误定位的框会杂乱识别极差甚至找不到框。二是部署设备参数定义部署所使用的设备CPU/GPU若部署环境有GPU可选择GPU部署提升推理速度若没有GPU可选择CPU部署确保模型能够正常运行同时设置线程数根据部署设备的性能调整合理的线程数能够提升数据读取和推理速度。参数配置完成后建议将参数保存到配置文件中便于后续修改和维护同时确保参数的一致性避免因参数配置错误导致部署失败。2. ONNX模型加载模型部署时需要使用推理引擎加载导出好的ONNX模型ONNXRuntime是微软推出的开源推理引擎支持跨平台、高速度推理兼容性强适合大多数部署场景这里主要使用ONNXRuntime加载模型。加载模型前需要安装ONNXRuntime依赖库根据部署设备CPU/GPU选择对应的版本安装完成后通过onnxruntime.InferenceSession()方法加载ONNX模型加载时需指定模型文件路径同时配置推理设备CPU/GPU。加载完成后需要检查模型的输入输出节点名称和维度确保与模型导出时的输入输出配置一致避免因节点名称或维度不匹配导致推理失败。可通过session.get_inputs()、session.get_outputs()方法查看输入输出节点信息记录节点名称和维度便于后续输入数据和解析输出结果。# 加载ONNX模型创建推理会话 session ort.InferenceSession( model.onnx, providers[CPUExecutionProvider, CUDAExecutionProvider] ) # 获取输入输出名称 input_name session.get_inputs()[0].name output_name session.get_outputs()[0].name # 执行推理 outputs session.run( [output_name], {input_name: input_data} )3. 预处理预处理是模型部署的关键环节其目的是将原始图片数据转换为模型能够接受的格式预处理方式必须与模型训练时的预处理方式保持一致否则会导致模型预测结果不准确预处理主要包括以下步骤第一步是图片读取使用OpenCV或PIL工具读取原始图片获取图片的原始尺寸高度、宽度便于后续调整检测框坐标第二步是通道转换模型训练时通常使用RGB通道而OpenCV读取的图片是BGR通道因此需要将BGR通道转换为RGB通道第三步是尺寸调整将图片调整为模型的输入尺寸调整过程中需注意保持图片的宽高比避免图片拉伸导致目标变形可通过padding方式进行灰度填充补充空白区域确保图片尺寸符合模型要求第四步是归一化将图片的像素值归一化到指定范围如[0,1]或[-1,1]与训练时的归一化方式一致消除像素值差异对模型预测的影响第五步是维度转换将图片的维度从[height, width, channels]转换为[batch_size, channels, height, width]符合模型的输入维度要求。预处理完成后将处理后的图片数据转换为模型能够接受的数据类型通常为float32并转移到对应的部署设备CPU/GPU上为后续推理做好准备。# 数据转移到指定设备 images images.to(device) labels labels.to(device) # 模型转移到指定设备 model model.to(device)4. 后处理后处理的核心作用是解析模型的输出结果将模型输出的原始数据转换为直观、可用的检测结果如目标坐标、类别、置信度便于后续绘制检测框和展示结果后处理步骤需根据模型的输出格式灵活调整主要包括以下内容首先获取模型的输出结果根据加载模型时获取的输出节点名称通过session.run()方法获取模型的原始输出数据原始输出数据通常包含预测框坐标、置信度、类别索引等信息不同模型的输出格式可能不同需根据模型类型进行解析。其次分离预测框、置信度、类别索引将原始输出数据中的预测框坐标通常为归一化后的坐标、每个预测框的置信度、类别索引分离出来便于后续处理然后将归一化的预测框坐标转换为原始图片尺寸下的坐标根据预处理时记录的原始图片尺寸将归一化坐标乘以对应的宽高得到目标在原始图片中的实际坐标x1, y1, x2, y2其中x1, y1为预测框左上角坐标x2, y2为预测框右下角坐标。最后根据置信度阈值过滤低置信度预测框将置信度低于设定阈值的预测框剔除减少误检测保留置信度较高的预测框为后续NMS过滤做准备。5. NMS过滤框经过后处理和置信度过滤后可能还会存在一些重复、冗余的检测框同一目标被多次检测影响检测结果的准确性和美观度因此需要通过NMS非极大值抑制算法过滤这些冗余检测框。NMS算法的核心逻辑是对于同一类别的所有预测框按照置信度从高到低排序选取置信度最高的预测框作为基准框计算其他预测框与基准框的IOU交并比若IOU大于设定的NMS阈值则认为该预测框与基准框是同一目标将其剔除若IOU小于阈值则保留该预测框重复此过程直到所有预测框都被处理完毕最终保留每个目标的最优检测框。NMS阈值的设置需要结合实际检测场景阈值过大可能会导致重复检测框无法被完全剔除阈值过小可能会导致漏检通常设置为0.3-0.5之间可根据实际检测效果灵活调整。6. 图片检测图片检测是模型部署的核心功能主要实现将预处理后的图片输入模型完成推理、后处理、NMS过滤等一系列操作最终得到准确的检测结果支持单张图片检测、批量图片检测、视频帧检测、摄像头实时检测等多种场景。单张图片检测将单张图片经过预处理后输入加载好的ONNX模型通过推理引擎得到模型输出结果再经过后处理和NMS过滤得到最终的检测结果包括目标的坐标、类别、置信度等信息。批量图片检测将多张图片批量进行预处理统一输入模型进行推理批量处理能够提升推理效率适合大量图片的检测场景处理完成后分别对每张图片的输出结果进行后处理和NMS过滤得到每张图片的检测结果。视频帧/摄像头实时检测将视频的每一帧或摄像头捕获的实时画面按照单张图片的预处理方式进行处理然后输入模型推理实时输出检测结果实现动态检测适合实时监控、实时识别等场景需注意优化推理速度确保检测的实时性。7. 绘制检测框为了直观展示检测结果方便用户查看需要在原始图片上绘制检测框并标注目标的类别名称和置信度绘制过程需结合后处理得到的目标坐标、类别、置信度等信息具体步骤如下首先获取原始图片和检测结果根据检测结果中的目标坐标在原始图片上使用矩形框绘制目标区域矩形框的颜色可根据不同类别进行区分如不同类别使用不同颜色的框便于区分不同目标然后在矩形框的上方或旁边标注目标的类别名称和置信度置信度可保留两位小数让用户清晰了解目标的识别精度最后将绘制好检测框的图片保存到指定路径或实时显示出来适用于视频、摄像头实时检测场景。绘制过程中需注意矩形框的位置要准确对应目标区域标注的字体大小、颜色要清晰可见避免遮挡目标同时保持图片的清晰度确保检测结果直观、易读。8. 主函数主函数是模型部署的入口核心作用是整合部署全流程将基础参数配置、ONNX模型加载、图片预处理、模型推理、后处理、NMS过滤、检测框绘制等步骤串联起来实现一键部署功能支持多种检测场景单张图片、批量图片、视频、摄像头。主函数中需要添加参数解析功能用户可通过命令行输入检测类型、图片/视频路径、模型路径、置信度阈值、NMS阈值等参数灵活调整部署配置无需修改代码。同时加入异常捕获机制当部署过程中出现异常如模型加载失败、图片读取失败、推理报错等时能够及时捕获异常信息给出错误提示避免程序崩溃同时记录日志便于问题排查。部署完成后主函数会输出部署成功的提示信息同时将检测结果绘制好检测框的图片、视频保存到指定路径或实时显示检测画面方便用户查看检测效果。此外主函数的代码设计需简洁、可移植能够轻松适配不同的部署环境便于后续的二次开发和优化。

更多文章