从MNIST代码到真实项目:用PyTorch搭建一个简易手写数字识别Web应用(Flask + ONNX部署)

张开发
2026/4/21 15:51:36 15 分钟阅读

分享文章

从MNIST代码到真实项目:用PyTorch搭建一个简易手写数字识别Web应用(Flask + ONNX部署)
从MNIST代码到真实项目用PyTorch搭建一个简易手写数字识别Web应用Flask ONNX部署当你第一次在Jupyter Notebook里跑通MNIST手写数字识别模型时那种成就感就像解开一道数学难题。但很快你会发现这个躺在.ipynb文件里的模型就像实验室里的标本——完美却与真实世界隔绝。本文将带你跨越这道鸿沟把PyTorch模型变成任何人都能通过网页使用的智能工具。想象这样一个场景产品经理拿着手机拍下白板上的数字3秒后网页自动显示识别结果。这种端到端的实现涉及模型轻量化、接口设计和前后端协同远比单纯的准确率提升更有商业价值。下面我们就用FlaskONNX这套轻量级组合实现从实验代码到生产应用的华丽转身。1. 模型训练与ONNX导出打破框架壁垒原始MNIST训练代码往往止步于model.pth的保存但想要跨平台部署我们需要更通用的中间格式。ONNXOpen Neural Network Exchange就像神经网络界的PDF能让PyTorch模型在不同环境中保持一致性。1.1 改造训练代码在原有模型基础上我们需要确保网络结构完全支持ONNX导出。特别注意动态运算如条件分支可能导致导出失败。以下是关键改造点# 在原始Net类中添加尺寸注释 def forward(self, input): # 添加维度注释便于ONNX理解输入格式 # input: [batch_size, 1, 28, 28] output self.model(input) return output1.2 执行模型导出使用torch.onnx.export时需特别注意输入输出的命名规范这对后续Web接口开发至关重要# 创建示例输入 dummy_input torch.randn(1, 1, 28, 28).to(device) # 导出模型 torch.onnx.export( net, dummy_input, mnist.onnx, input_names[input_image], # 前端将使用此名称 output_names[prediction], # 返回结果标识 dynamic_axes{ input_image: {0: batch_size}, # 支持动态batch prediction: {0: batch_size} } )注意导出后务必使用onnxruntime验证模型可用性。安装onnxruntime后执行以下测试import onnxruntime as ort sess ort.InferenceSession(mnist.onnx) outputs sess.run( [prediction], {input_image: dummy_input.numpy()} ) print(ONNX测试输出:, outputs)2. Flask服务端开发构建AI中间件轻量级Web框架Flask在这里扮演着桥梁角色既要高效处理模型推理又要规范数据接口。我们采用经典的蓝图Blueprint结构组织代码。2.1 项目结构设计/mnist_webapp │── static/ # 前端资源 │ └── sketchpad.js # 画板交互逻辑 │── templates/ # HTML模板 │ └── index.html │── app.py # 主程序入口 │── model_loader.py # ONNX模型加载器 └── requirements.txt2.2 核心API实现在app.py中创建预测接口特别注意图像预处理必须与训练时保持一致from flask import Flask, request, jsonify import numpy as np from PIL import Image import io import model_loader # 自定义模型加载模块 app Flask(__name__) app.route(/predict, methods[POST]) def predict(): # 接收前端Base64编码图像 img_data request.json[image].split(,)[1] img Image.open(io.BytesIO(base64.b64decode(img_data))) # 执行与训练一致的预处理 img img.convert(L).resize((28, 28)) img_array np.array(img) / 255.0 img_array (img_array - 0.5) / 0.5 # 标准化 img_array img_array.reshape(1, 1, 28, 28).astype(np.float32) # 调用ONNX模型 pred model_loader.predict(img_array) return jsonify({digit: int(pred)})模型加载模块model_loader.py需要管理ONNX运行时会话import onnxruntime as ort class MNISTModel: def __init__(self, model_path): self.session ort.InferenceSession(model_path) def predict(self, input_array): outputs self.session.run( [prediction], {input_image: input_array} ) return np.argmax(outputs[0]) # 全局模型实例 model MNISTModel(mnist.onnx)3. 前端交互设计打造零门槛体验好的AI产品应该让技术隐形。我们设计两种输入方式画板绘制和图片上传覆盖不同用户场景。3.1 画板实现方案使用HTML5 Canvas构建手写区域关键点在于笔触粗细设置为15px模拟真实书写添加撤销重做功能提升体验导出时自动裁剪空白区域// static/sketchpad.js const canvas document.getElementById(drawing-board); const ctx canvas.getContext(2d); let isDrawing false; ctx.lineWidth 15; ctx.lineCap round; canvas.addEventListener(mousedown, startDrawing); canvas.addEventListener(mousemove, draw); canvas.addEventListener(mouseup, stopDrawing); function startDrawing(e) { isDrawing true; draw(e); // 立即记录第一个点 } function draw(e) { if (!isDrawing) return; ctx.beginPath(); ctx.moveTo(lastX, lastY); ctx.lineTo(e.offsetX, e.offsetY); ctx.stroke(); [lastX, lastY] [e.offsetX, e.offsetY]; } function stopDrawing() { isDrawing false; } // 导出图像供模型识别 function exportImage() { // 自动裁剪逻辑 const imageData ctx.getImageData(0, 0, canvas.width, canvas.height); const cropped autoCrop(imageData); // 转换为Base64 return cropped.toDataURL(image/png); }3.2 预测结果可视化使用Chart.js动态展示模型对各数字的置信度分布div classresults canvas idconfidence-chart/canvas div classprediction识别结果: span iddigit?/span/div /div script function showResults(confidences) { const ctx document.getElementById(confidence-chart).getContext(2d); new Chart(ctx, { type: bar, data: { labels: [0,1,2,3,4,5,6,7,8,9], datasets: [{ data: confidences, backgroundColor: #4e73df }] }, options: { scales: { y: { beginAtZero: true } } } }); } /script4. 性能优化与生产部署当访问量增加时原始实现可能遇到性能瓶颈。以下是关键优化策略4.1 模型推理加速优化手段实现方法预期提升量化压缩ONNX运行时启用int8量化推理速度提升3倍批处理前端累积多个请求一并发送吞吐量提升5倍缓存机制对相同图像哈希值缓存结果重复请求零延迟启用量化推理的改造代码# 在model_loader.py中 self.session ort.InferenceSession( model_path, providers[CUDAExecutionProvider], # 启用GPU加速 sess_optionsort.SessionOptions() ) # 设置量化参数 ort.set_default_logger_severity(3) # 减少日志输出4.2 容器化部署使用Docker确保环境一致性Dockerfile配置要点FROM python:3.8-slim WORKDIR /app COPY requirements.txt . RUN pip install --no-cache-dir -r requirements.txt COPY . . # 启用性能优化模式 ENV ONNX_THREAD_NUM2 ENV FLASK_ENVproduction EXPOSE 5000 CMD [gunicorn, -w 4, -b :5000, app:app]构建命令docker build -t mnist-webapp . docker run -d -p 5000:5000 --name mnist_app mnist-webapp5. 异常处理与监控生产系统必须考虑各种边界情况5.1 输入验证防御app.route(/predict, methods[POST]) def predict(): try: # 验证图像有效性 if image not in request.json: raise ValueError(Missing image data) img_data request.json[image] if not img_data.startswith(data:image/png;base64,): raise ValueError(Invalid image format) # 继续处理流程... except Exception as e: app.logger.error(fPrediction error: {str(e)}) return jsonify({error: str(e)}), 4005.2 健康检查端点添加/health接口供运维监控app.route(/health) def health_check(): try: # 测试模型可用性 test_input np.random.rand(1, 1, 28, 28).astype(np.float32) model.predict(test_input) return jsonify({status: healthy}) except Exception as e: return jsonify({status: unhealthy, error: str(e)}), 500在Kubernetes中配置livenessProbelivenessProbe: httpGet: path: /health port: 5000 initialDelaySeconds: 30 periodSeconds: 106. 扩展思考从Demo到产品完成基础版本后可以考虑以下增强功能用户反馈循环添加识别正确按钮收集错误样本用于模型迭代多模型切换在界面增加模型选择器如LeNet vs ResNet离线支持使用TensorFlow.js实现浏览器端直接推理历史记录本地存储用户之前的识别记录实现模型热加载的进阶代码class ModelManager: def __init__(self): self.models {} self.current_model mnist def load_model(self, name, path): self.models[name] MNISTModel(path) def switch_model(self, name): if name in self.models: self.current_model name def predict(self, input_array): return self.models[self.current_model].predict(input_array) # 初始化时加载多个模型 model_manager ModelManager() model_manager.load_model(mnist, mnist.onnx) model_manager.load_model(lenet, lenet.onnx)前端切换逻辑document.getElementById(model-selector).addEventListener(change, (e) { fetch(/switch_model, { method: POST, headers: { Content-Type: application/json }, body: JSON.stringify({ model: e.target.value }) }).then(response { console.log(Switched to ${e.target.value} model); }); });这个项目最有趣的部分是看到非技术用户对手写识别功能的反应。他们会故意写潦草的字测试系统极限或是惊讶于识别速度——这些真实反馈比验证集上的准确率数字更有价值。

更多文章