PyTorch模型訓練優(yōu)化、FastAPI跨域配置與Vue響應式交互的手寫數字識別實踐
本文圍繞手寫數字識別項目展開,涵蓋前端交互(Vue)、后端接口(FastAPI)、CNN模型訓練(PyTorch)全流程,把之前學習過的知識綜合運用起來。
內容包含環(huán)境搭建、代碼實現、操作步驟及問題解決,借助該項目來掌握前后端分離項目開發(fā)、MNIST數據集應用、LeNet5模型訓練與部署,獲取可復用的圖像分類項目流程,快速復現或擴展類似項目。
1 項目基礎與環(huán)境準備
1.1 項目介紹與目標
1.1.1 項目介紹
手寫數字識別是計算機視覺入門經典任務,基于MNIST數據集(含6萬訓練樣本、1萬測試樣本,每個樣本為28×28灰度圖,對應0-9數字),采用LeNet5卷積神經網絡(CNN)實現分類,架構為前端交互+后端預測+模型支撐的前后端分離模式。
1.1.2 項目目標
- 前端:提供畫布供用戶手寫數字,完成圖像預處理(縮放、灰度轉換),發(fā)起后端請求并展示結果。
- 后端:接收前端圖像,通過預訓練LeNet5模型預測數字,返回結果。
- 整體:實現端到端識別,準確率達98%以上,掌握全流程開發(fā)邏輯。
具體的流程可以參考下圖:

1.2 開發(fā)環(huán)境準備
1.2.1 基礎環(huán)境要求
- 編程語言:Python 3.8+(后端+模型訓練)、JavaScript(前端Vue)
- 運行環(huán)境:Node.js 16+(Vue項目依賴管理)、Python虛擬環(huán)境
1.2.2 依賴庫安裝
1.2.2.1 Python依賴(后端+模型)
通過pip安裝核心庫,命令如下:
# 后端框架與網絡請求 pip install fastapi uvicorn # PyTorch核心(含CPU版本,GPU版本需替換命令) pip install torch torchvision # 圖像處理與數據處理 pip install pillow numpy # 前端請求庫(Vue側后續(xù)安裝)
這里還有個要注意的點就是,如果電腦里有多個python環(huán)境,在這里用pip下載最好指定一下,不然會默認用全局的python環(huán)境去下載。
比如:
D:\Python\Scripts\pip3.12.exe install [安裝包]
1.2.2.2 Vue依賴(前端)
進入前端項目目錄(mnist-frontend),通過npm安裝:
# 初始化Vue項目(若未創(chuàng)建) npm create vue@latest mnist-frontend # 進入目錄并安裝axios(請求后端) cd mnist-frontend npm install axios
1.2.3 項目目錄結構
參考實際文件路徑(D:\ProjectPython\DNN_CNN),規(guī)范結構如下(便于后續(xù)復用):
DNN_CNN/ # 項目根目錄 ├─ mnist-frontend/ # 前端Vue項目 │ ├─ src/ │ │ ├─ App.vue # 前端核心文件(模板+邏輯+樣式) │ │ ├─ main.js # Vue入口文件 │ │ └─ style.css # 全局樣式(本項目用組件內聯樣式) │ └─ package.json # Vue依賴配置 ├─ CNN_Proj.py # 模型訓練腳本(生成權重文件) ├─ main.py # 后端FastAPI服務腳本 ├─ LeNet5_mnist.pth # 預訓練模型權重(訓練后生成) └─ dataset/ # MNIST數據集(訓練腳本自動下載)
2 前端實現(Vue)
2.1 前端核心功能定位
前端是用戶交互入口,需解決如何讓用戶輸入數字、如何將輸入轉為模型可識別格式和如何與后端通信三個核心問題,最終實現繪制→預處理→請求→展示的這一閉環(huán)。
2.2 模板結構設計(App.vue的<template>)
模板需包含交互組件+反饋組件,結構如下:
<template>
<div class="container">
<h1>手寫數字識別</h1>
<!-- 1. 主畫布(用戶繪制數字) -->
<canvas
ref="canvas"
width="280"
height="280"
@mousedown="startDrawing"
@mousemove="draw"
@mouseup="stopDrawing"
@mouseleave="stopDrawing"
></canvas>
<!-- 2. 調試畫布(預覽28×28預處理圖像,便于排查問題) -->
<div class="debug-section" v-show="showDebug">
<h3>預處理后圖像(28x28 放大)</h3>
<canvas ref="debugCanvas" width="280" height="280"></canvas>
<p class="debug-info">實際尺寸 28x28 | 放大 10 倍</p>
</div>
<!-- 3. 控制按鈕(功能操作) -->
<div class="buttons">
<button @click="clearCanvas" :disabled="isLoading">清除畫布</button>
<button @click="predictDigit" :disabled="isLoading">
{{ isLoading ? '識別中...' : '識別' }}
</button>
<button @click="toggleDebug">顯示/隱藏調試</button>
</div>
<!-- 4. 結果與錯誤反饋 -->
<div class="result" v-if="recognitionResult">識別結果:{{ recognitionResult }}</div>
<div class="error" v-if="errorMessage">錯誤:{{ errorMessage }}</div>
</div>
</template>
2.3 核心邏輯實現(App.vue的<script setup>)
2.3.1 響應式變量定義
通過Vue的ref定義狀態(tài)變量,確保視圖與數據同步:
import { ref, onMounted, nextTick, watch } from 'vue';
import axios from 'axios';
// 畫布DOM引用
const canvas = ref(null);
const debugCanvas = ref(null);
// 控制狀態(tài)
const showDebug = ref(false); // 調試視圖開關
const isDrawing = ref(false); // 繪制狀態(tài)
const isLoading = ref(false); // 識別加載狀態(tài)
// 結果反饋
const recognitionResult = ref(''); // 識別結果
const errorMessage = ref(''); // 錯誤信息
// 繪制輔助變量
let ctx = null; // 主畫布上下文
let debugCtx = null; // 調試畫布上下文
let lastX = 0; // 上一次繪制X坐標
let lastY = 0; // 上一次繪制Y坐標
2.3.2 畫布初始化(onMounted鉤子)
畫布需在DOM渲染完成后初始化,確保上下文獲取成功,同時配置繪制參數(匹配模型輸入要求):
onMounted(async () => {
await nextTick(); // 等待DOM完全渲染
// 主畫布初始化(280×280,后續(xù)縮放為28×28,避免繪制精度不足)
if (canvas.value) {
ctx = canvas.value.getContext('2d', { willReadFrequently: true });
if (ctx) {
ctx.fillStyle = '#ffffff'; // 純白背景(匹配MNIST數據集背景)
ctx.fillRect(0, 0, 280, 280);
ctx.lineWidth = 12; // 畫筆寬度(過細會導致預處理后線條消失)
ctx.strokeStyle = 'black'; // 黑色畫筆(與MNIST數字顏色一致)
ctx.lineCap = 'round'; // 畫筆端點圓潤(避免鋸齒)
ctx.lineJoin = 'round'; // 畫筆拐角圓潤(提升繪制體驗)
} else {
errorMessage.value = '主畫布初始化失敗,請刷新';
}
}
// 調試畫布初始化(與主畫布邏輯一致,用于預覽預處理結果)
if (debugCanvas.value) {
debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });
if (debugCtx) {
debugCtx.fillStyle = '#ffffff';
debugCtx.fillRect(0, 0, 280, 280);
} else {
console.warn('調試畫布初始化失?。ú挥绊懼鞴δ埽?);
}
}
});
2.3.3 繪制邏輯(鼠標事件處理)
通過mousedown/mousemove/mouseup事件實現連續(xù)繪制,需處理畫布縮放導致的坐標偏移問題:
// 開始繪制(記錄初始坐標)
function startDrawing(e) {
if (!ctx) return;
isDrawing.value = true;
const rect = canvas.value.getBoundingClientRect(); // 獲取畫布在頁面中的位置
// 計算畫布內真實坐標(解決瀏覽器縮放導致的坐標偏差)
lastX = (e.clientX - rect.left) * (canvas.value.width / rect.width);
lastY = (e.clientY - rect.top) * (canvas.value.height / rect.height);
ctx.beginPath();
ctx.moveTo(lastX, lastY);
ctx.lineTo(lastX + 0.1, lastY + 0.1); // 繪制初始點(避免點擊不拖動無痕跡)
ctx.stroke();
}
// 實時繪制
function draw(e) {
if (!ctx || !isDrawing.value) return;
const rect = canvas.value.getBoundingClientRect();
const x = (e.clientX - rect.left) * (canvas.value.width / rect.width);
const y = (e.clientY - rect.top) * (canvas.value.height / rect.height);
ctx.lineTo(x, y); // 連接上一坐標與當前坐標
ctx.stroke();
lastX = x; // 更新上一坐標
lastY = y;
}
// 結束繪制
function stopDrawing() {
isDrawing.value = false;
}
2.3.4 圖像預處理(關鍵步驟)
模型輸入要求為1×1×28×28灰度圖(batch×通道×高×寬)+ 歸一化,需通過輔助函數實現轉換:
2.3.4.1 畫布空檢測(checkCanvasEmpty)
避免前端發(fā)送空圖像請求,通過亮度閾值判斷是否有繪制內容:
async function checkCanvasEmpty() {
return new Promise((resolve) => {
if (!ctx) { resolve(true); return; }
const imageData = ctx.getImageData(0, 0, 280, 280);
const data = imageData.data; // 像素數據(RGBA,每4個值對應一個像素)
const threshold = 250; // 亮度閾值(純白亮度255,低于250視為有繪制)
for (let i = 0; i < data.length; i += 4) {
const brightness = (data[i] + data[i+1] + data[i+2]) / 3; // 計算亮度(灰度值)
if (brightness < threshold) {
resolve(false); // 有繪制內容
return;
}
}
resolve(true); // 無繪制內容
});
}
2.3.4.2 28×28灰度轉換與反轉(canvasTo28x28Gray)
MNIST數據集為黑底白字,而前端繪制是白底黑字,需反轉顏色;同時縮放為28×28:
function canvasTo28x28Gray(canvasEl) {
return new Promise((resolve) => {
// 1. 創(chuàng)建臨時畫布(28×28,模型輸入尺寸)
const tempCanvas = document.createElement('canvas');
tempCanvas.width = 28;
tempCanvas.height = 28;
const tempCtx = tempCanvas.getContext('2d');
if (!tempCtx) { resolve({ imgBlob: null, tempCanvas: null }); return; }
// 2. 縮放繪制(保持比例居中,避免拉伸)
tempCtx.fillStyle = '#ffffff';
tempCtx.fillRect(0, 0, 28, 28); // 填充純白背景
const scale = Math.min(28 / canvasEl.width, 28 / canvasEl.height); // 等比例縮放
const xOffset = (28 - canvasEl.width * scale) / 2; // X軸居中偏移
const yOffset = (28 - canvasEl.height * scale) / 2; // Y軸居中偏移
tempCtx.drawImage(
canvasEl,
0, 0, canvasEl.width, canvasEl.height, // 源圖像區(qū)域
xOffset, yOffset, canvasEl.width * scale, canvasEl.height * scale // 目標繪制區(qū)域
);
// 3. 灰度轉換與顏色反轉(匹配MNIST數據分布)
const imageData = tempCtx.getImageData(0, 0, 28, 28);
const data = imageData.data;
for (let i = 0; i < data.length; i += 4) {
const brightness = (data[i] + data[i+1] + data[i+2]) / 3; // 灰度值
const inverted = 255 - brightness; // 反轉:白底黑字→黑底白字
data[i] = data[i+1] = data[i+2] = inverted; // RGB通道統一為反轉后值
data[i+3] = 255; // 透明度保持100%
}
tempCtx.putImageData(imageData, 0, 0);
// 4. 生成Blob(用于FormData傳輸)
tempCanvas.toBlob((blob) => {
resolve({ imgBlob: blob, tempCanvas: tempCanvas });
}, 'image/png', 1.0); // 無損壓縮,避免圖像細節(jié)丟失
});
}
2.3.5 后端請求邏輯(predictDigit)
通過axios發(fā)送POST請求,傳遞圖像Blob,處理響應與錯誤:
async function predictDigit() {
if (!ctx) { errorMessage.value = '畫布未初始化,請刷新'; return; }
isLoading.value = true;
errorMessage.value = '';
try {
// 步驟1:檢查畫布是否有內容
const isEmpty = await checkCanvasEmpty();
if (isEmpty) {
errorMessage.value = '請先繪制數字';
isLoading.value = false;
return;
}
// 步驟2:預處理圖像(轉為28×28灰度Blob)
const { imgBlob, tempCanvas } = await canvasTo28x28Gray(canvas.value);
if (!imgBlob) { throw new Error('圖像轉換失敗,無法生成有效數據'); }
// 步驟3:預覽調試圖像(若開啟調試)
if (showDebug.value && debugCtx && tempCanvas) {
debugCtx.drawImage(tempCanvas, 0, 0, 280, 280); // 放大10倍顯示
}
// 步驟4:構建FormData(后端接收文件格式)
const formData = new FormData();
formData.append('file', imgBlob, 'digit.png'); // 參數名'file'需與后端一致
// 步驟5:發(fā)送請求(不手動設置Content-Type,axios自動處理邊界符)
const response = await axios.post(
'http://localhost:8000/predict', // 后端接口地址
formData
);
// 步驟6:處理響應(驗證數據格式)
if (response.data && 'predicted_digit' in response.data) {
recognitionResult.value = response.data.predicted_digit;
} else {
throw new Error('后端返回數據格式異常');
}
} catch (error) {
// 精細化錯誤提示(便于排查問題)
if (error.response) {
// 后端返回錯誤(如422參數錯誤、500服務器錯誤)
errorMessage.value = `識別失?。?{error.response.status} - ${
error.response.data?.error || error.response.data?.detail || '未知錯誤'
}`;
} else if (error.request) {
// 無響應(后端未啟動、跨域問題)
errorMessage.value = '識別失敗:無法連接后端服務,請檢查后端是否運行';
} else {
// 前端本地錯誤(如圖像轉換失敗)
errorMessage.value = `識別失?。?{error.message}`;
}
console.error('預測錯誤詳情:', error);
} finally {
isLoading.value = false; // 無論成功失敗,結束加載狀態(tài)
}
}
這里簡單說下圖像Blob,圖像Blob(Binary Large Object)簡單說就是以二進制形式存儲的圖像文件數據,比如PNG、JPG格式的圖像在計算機中實際存儲的字節(jié)流,就屬于Blob。
在項目里,前端把畫布繪制的內容(28×28灰度圖)轉成Blob,是因為:
- 后端接口接收的是“文件”類型數據(
UploadFile),Blob能模擬文件的二進制格式; - 配合
FormData(表單數據)傳遞時能保持圖像的原始編碼,避免文本格式轉換導致的數據損壞。
比如項目中canvasTo28x28Gray函數里,通過tempCanvas.toBlob(...)生成Blob,再用formData.append('file', imgBlob, 'digit.png')附加到請求里,就能讓后端像接收本地圖片文件一樣解析它。
2.4 樣式設計(App.vue的<style scoped>)
樣式保證交互友好性,沒有放過多冗雜的東西,核心代碼如下:
<style scoped>
.container {
text-align: center;
padding: 20px;
max-width: 600px;
margin: 0 auto; /* 容器居中 */
}
canvas {
border: 2px solid #ccc;
margin: 10px auto;
display: block;
background-color: #ffffff; /* 匹配畫布初始化背景 */
touch-action: none; /* 禁止瀏覽器默認觸摸行為(適配移動端) */
}
.debug-section {
margin-top: 20px;
padding: 15px;
background-color: #f9f9f9;
border-radius: 8px; /* 圓角提升美觀度 */
}
.debug-info {
color: #666;
font-size: 14px;
margin-top: 5px;
}
.buttons {
margin: 20px 0;
}
button {
padding: 10px 20px;
margin: 0 10px;
cursor: pointer;
background-color: #42b983; /* Vue默認主題色,辨識度高 */
color: white;
border: none;
border-radius: 4px;
transition: opacity 0.3s; /* hover過渡效果 */
}
button:disabled {
background-color: #ccc;
cursor: not-allowed; /* 禁用狀態(tài)光標提示 */
opacity: 0.7;
}
button:hover:not(:disabled) {
opacity: 0.8; /* hover時降低透明度,反饋交互 */
}
.result {
font-size: 20px;
margin-top: 20px;
color: #42b983; /* 成功顏色 */
}
.error {
font-size: 16px;
color: #e53e3e; /* 錯誤顏色 */
margin-top: 10px;
}
</style>
3 后端實現(FastAPI + PyTorch)
3.1 后端核心功能定位
后端需解決如何接收前端圖像、如何用模型預測和如何返回結果這三個問題,核心是提供高可用的預測接口,確保與前端數據格式兼容、與模型輸入匹配。
3.2 FastAPI服務搭建
3.2.1 初始化FastAPI實例
from fastapi import FastAPI, File, UploadFile from fastapi.middleware.cors import CORSMiddleware import torch import torch.nn as nn from PIL import Image import numpy as np # 初始化FastAPI應用 app = FastAPI()
3.2.2 跨域配置(關鍵)
前端(默認5173端口)與后端(8000端口)端口不同,會觸發(fā)瀏覽器跨域攔截,需配置CORSMiddleware:
app.add_middleware(
CORSMiddleware,
allow_origins=["*"], # 開發(fā)環(huán)境允許所有源(生產環(huán)境需指定具體域名)
allow_credentials=True, # 允許攜帶Cookie(本項目暫用不到,保留擴展性)
allow_methods=["*"], # 允許所有HTTP方法(GET/POST等)
allow_headers=["*"], # 允許所有請求頭
)
3.3 LeNet5模型定義(與訓練腳本一致)
模型結構必須與訓練時完全相同,否則權重加載失敗。LeNet5是經典CNN架構,適配MNIST數據:
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
# 網絡層序列(卷積→激活→池化→卷積→激活→池化→卷積→激活→展平→全連接→激活→全連接)
self.net = nn.Sequential(
# C1層:1→6通道,5×5卷積核,padding=2(保持28×28輸出)
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
nn.Tanh(), # 激活函數(LeNet5原設計,引入非線性)
nn.AvgPool2d(kernel_size=2, stride=2), # S2層:2×2平均池化,輸出14×14
# C3層:6→16通道,5×5卷積核(無padding,輸出10×10)
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2), # S4層:輸出5×5
# C5層:16→120通道,5×5卷積核(輸出1×1,等效全連接)
nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
nn.Tanh(),
nn.Flatten(), # 展平:120×1×1→120維向量
# F6層:全連接,120→84
nn.Linear(in_features=120, out_features=84),
nn.Tanh(),
# 輸出層:全連接,84→10(對應0-9數字)
nn.Linear(in_features=84, out_features=10)
)
# 前向傳播(定義數據流動路徑)
def forward(self, x):
return self.net(x)
3.4 模型加載與圖像預處理
3.4.1 模型初始化與權重加載
加載訓練生成的LeNet5_mnist.pth權重,切換為評估模式(禁用訓練相關層):
# 初始化模型
model = LeNet5()
# 加載權重(map_location='cpu'適配無GPU環(huán)境)
state_dict = torch.load('LeNet5_mnist.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict) # 權重參數映射到模型
model.eval() # 切換為評估模式(關鍵:禁用Dropout/BatchNorm等訓練層)
3.4.2 圖像預處理函數(preprocess_image)
前端傳入的是28×28 PNG圖像,需轉為模型要求的1×1×28×28張量+歸一化:
def preprocess_image(image):
# 1. 轉為灰度圖(即使前端已處理,后端二次確認,避免格式錯誤)
image = image.convert('L') # 'L'模式為單通道灰度圖
# 2. 確保尺寸為28×28(前端可能因異常未縮放,后端兜底)
image = image.resize((28, 28), Image.Resampling.LANCZOS) # 高質量插值縮放
# 3. 轉為numpy數組并歸一化(匹配訓練時的數據分布)
image = np.array(image, dtype=np.float32) # 轉為32位浮點數數組
mean = 0.1307 # MNIST數據集均值(訓練時計算,需固定)
std = 0.3081 # MNIST數據集標準差(訓練時計算,需固定)
image = (image / 255.0 - mean) / std # 步驟:0-255→0-1→標準化(均值0,標準差1)
# 4. 調整維度(模型輸入:batch×通道×高×寬)
image = np.expand_dims(image, axis=0) # 增加通道維度:(28,28)→(1,28,28)
image = np.expand_dims(image, axis=0) # 增加batch維度:(1,28,28)→(1,1,28,28)
# 5. 轉為PyTorch張量
return torch.tensor(image)
3.5 預測接口實現(/predict)
定義POST接口,接收前端UploadFile類型文件,處理流程為讀取圖像→預處理→預測→返回結果:
@app.post("/predict")
async def predict_digit(file: UploadFile = File(...)):
try:
# 1. 打印調試信息(便于排查文件接收問題)
print(f"收到文件: {file.filename}, 類型: {file.content_type}")
# 2. 讀取圖像(PIL.Image打開)
image = Image.open(file.file)
print(f"原始圖像 - 尺寸: {image.size}, 模式: {image.mode}")
# 3. 圖像預處理
input_tensor = preprocess_image(image)
print(f"預處理后 - 張量維度: {input_tensor.shape}, 數據類型: {input_tensor.dtype}")
# 4. 模型預測(禁用梯度計算,節(jié)省資源)
with torch.no_grad():
output = model(input_tensor) # 模型輸出:(1,10)(1個樣本,10個類別概率)
predicted_digit = torch.argmax(output, dim=1).item() # 取概率最大的類別
# 5. 返回結果(JSON格式,前端可直接解析)
return {"predicted_digit": predicted_digit}
except Exception as e:
# 異常捕獲(打印錯誤信息,返回錯誤提示)
print(f"處理請求時出錯: {str(e)}")
return {"error": str(e)}
# 啟動服務(當腳本直接運行時)
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000) # 0.0.0.0允許局域網訪問,端口8000
4 模型訓練(PyTorch + MNIST)
4.1 訓練核心目標
生成可復用的權重文件(LeNet5_mnist.pth),該模型在MNIST測試集上準確率為98.17%,準確率還算不錯,用它來為后端提供預測能力。
4.2 訓練腳本實現(CNN_Proj.py)
4.2.1 數據準備(prepare_data)
加載MNIST數據集,應用與后端一致的預處理(歸一化),用DataLoader按批次加載:
import torch
import torch.nn as nn
from torch.utils.data import DataLoader
from torchvision import transforms, datasets
import matplotlib.pyplot as plt
# 解決中文顯示問題
plt.rcParams['font.sans-serif'] = ['SimSun']
plt.rcParams['axes.unicode_minus'] = False
def prepare_data():
# 數據轉換 pipeline(與后端預處理邏輯一致)
transform = transforms.Compose([
transforms.ToTensor(), # 轉為張量:(H,W,C)→(C,H,W),值歸一化到0-1
transforms.Normalize(0.1307, 0.3081) # 標準化(均值+標準差)
])
# 加載訓練集(train=True),自動下載到./dataset/mnist/
train_dataset = datasets.MNIST(
root='./dataset/mnist/',
train=True,
download=True,
transform=transform
)
# 加載測試集(train=False)
test_dataset = datasets.MNIST(
root='./dataset/mnist/',
train=False,
download=True,
transform=transform
)
# 創(chuàng)建DataLoader(按批次加載,訓練集打亂)
train_loader = DataLoader(
train_dataset,
batch_size=256, # 批次大?。ǜ鶕却嬲{整,256兼顧速度與內存)
shuffle=True # 訓練集打亂,增強泛化能力
)
test_loader = DataLoader(
test_dataset,
batch_size=256,
shuffle=False # 測試集無需打亂
)
return train_loader, test_loader
4.2.2 模型訓練(train_model)
定義訓練循環(huán),包含“前向傳播→損失計算→反向傳播→參數更新”核心步驟:
def train_model(model, train_loader, epochs=5, lr=0.9):
# 1. 損失函數:交叉熵損失(分類任務專用,含Softmax激活)
criterion = nn.CrossEntropyLoss()
# 2. 優(yōu)化器:隨機梯度下降(SGD),lr=0.9為LeNet5經典學習率
optimizer = torch.optim.SGD(model.parameters(), lr=lr)
# 3. 記錄損失(用于繪制曲線,觀察訓練效果)
train_losses = []
# 4. 訓練循環(huán)
print("\n開始訓練...")
for epoch in range(epochs):
model.train() # 切換為訓練模式(啟用Dropout/BatchNorm)
total_loss = 0.0
# 遍歷訓練集批次
for batch_idx, (images, labels) in enumerate(train_loader):
# 前向傳播:輸入圖像,獲取模型輸出
outputs = model(images)
# 計算損失:輸出與真實標簽的差異
loss = criterion(outputs, labels)
# 反向傳播與參數更新
optimizer.zero_grad() # 清空上一輪梯度(避免累積)
loss.backward() # 反向傳播計算梯度
optimizer.step() # 根據梯度更新模型參數
# 記錄損失
train_losses.append(loss.item())
total_loss += loss.item()
# 每100個批次打印一次中間結果
if (batch_idx + 1) % 100 == 0:
print(f"輪次 [{epoch+1}/{epochs}], 批次 [{batch_idx+1}/{len(train_loader)}], "
f"當前批次損失: {loss.item():.4f}")
# 打印本輪平均損失
avg_loss = total_loss / len(train_loader)
print(f"輪次 [{epoch+1}/{epochs}] 平均損失: {avg_loss:.4f}")
# 5. 繪制損失曲線(直觀觀察訓練收斂情況)
plt.figure(figsize=(10, 4))
plt.plot(train_losses, label='訓練損失')
plt.xlabel('批次')
plt.ylabel('損失值')
plt.title('訓練損失變化曲線')
plt.legend()
plt.show()
# 6. 保存模型權重(僅保存狀態(tài)字典,節(jié)省空間)
torch.save(model.state_dict(), 'LeNet5_mnist.pth')
print(f"模型已保存為 'LeNet5_mnist.pth'")
return model, train_losses
4.2.3 模型測試(test_model)
評估模型在測試集上的準確率,驗證泛化能力:
def test_model(model, test_loader):
model.eval() # 切換為評估模式
correct = 0 # 正確預測數
total = 0 # 總樣本數
# 禁用梯度計算(測試階段無需更新參數)
with torch.no_grad():
print("\n開始測試...")
for images, labels in test_loader:
outputs = model(images)
_, predicted = torch.max(outputs.data, 1) # 取概率最大的類別
total += labels.size(0)
correct += (predicted == labels).sum().item() # 統計正確數
# 計算并打印準確率
accuracy = 100 * correct / total
print(f"測試集準確率: {accuracy:.2f}%")
return accuracy
4.2.4 主函數(串聯訓練流程)
def main():
# 步驟1:準備數據
train_loader, test_loader = prepare_data()
print("數據準備完成,訓練集樣本數:", len(train_loader.dataset),
"測試集樣本數:", len(test_loader.dataset))
# 步驟2:初始化模型(與后端LeNet5完全一致)
model = LeNet5()
print("\nLeNet-5模型初始化完成")
# 步驟3:訓練模型
trained_model, losses = train_model(model, train_loader, epochs=5)
# 步驟4:測試模型
test_model(trained_model, test_loader)
if __name__ == "__main__":
main()
5 完整項目操作流程
5.1 前置準備
安裝基礎環(huán)境:
搭建項目目錄:
- 在
D:\ProjectPython\下創(chuàng)建DNN_CNN文件夾(根目錄)。 - 在
DNN_CNN下創(chuàng)建mnist-frontend文件夾(前端目錄)。
安裝依賴:
- 打開命令提示符(CMD),執(zhí)行Python依賴安裝:
pip install fastapi uvicorn torch torchvision pillow numpy
- 進入前端目錄,執(zhí)行Vue依賴安裝:
cd D:\ProjectPython\DNN_CNN\mnist-frontend npm create vue@latest . # 初始化Vue項目,全部選“NO”(簡化配置) npm install axios
5.2 模型訓練(可選,已有權重可跳過)
在DNN_CNN根目錄創(chuàng)建CNN_Proj.py,第4章的訓練腳本程序放在該py文件里。
運行訓練腳本:
cd D:\ProjectPython\DNN_CNN python CNN_Proj.py
等待訓練完成,根目錄會生成LeNet5_mnist.pth(權重文件),這個時候可以管擦測試集準確率,一般來說滿足≥95%就可以了。
比如我這邊自己訓練的,

從訓練結果來看,這個 LeNet-5 模型在 MNIST 測試集上達到了98.17% 的準確率,對于基礎的手寫數字識別任務來說,這個性能算是比較理想的,直接用于簡單的手寫數字識別這個實際場景是足夠的。
5.3 后端部署
在DNN_CNN根目錄創(chuàng)建main.py,程序詳見第3章的后端腳本程序。
確保LeNet5_mnist.pth在根目錄下,啟動后端服務:
python main.py
看到“Uvicorn running on http://0.0.0.0:8000”表示啟動成功,不要關閉CMD窗口。
這里有兩個點要說清楚,
第一,如果直接在 Python 里運行 main.py(比如點擊 IDE 的“運行”按鈕),程序會加載模型 → 定義 FastAPI 實例 → 定義路由,但不會但不會啟動 Web 服務!代碼里的 API 接口(/predict )根本沒法被外部訪問, Postman 也連不上。
第二,uvicorn main:app --reload 是干啥的? uvicorn 是一個 ASGI 服務器,作用是:
- 找到你的
main.py文件,加載里面的app = FastAPI()實例 - 啟動一個 Web 服務,讓你的 API(
/predict)能被外部訪問(比如 Postman、前端頁面 ) --reload:文件改動時自動重啟服務(開發(fā)時超方便,不用手動重啟 )
main.py完整程序如下:
# 后端 main.py(PyTorch 版本)
from fastapi import FastAPI, File, UploadFile
from fastapi.middleware.cors import CORSMiddleware
import torch
import torch.nn as nn
from PIL import Image
import numpy as np
app = FastAPI()
# 允許跨域
app.add_middleware(
CORSMiddleware,
allow_origins=["*"],
allow_credentials=True,
allow_methods=["*"],
allow_headers=["*"],
)
# 定義與CNN_Proj.py中一致的LeNet5模型結構
class LeNet5(nn.Module):
def __init__(self):
super(LeNet5, self).__init__()
self.net = nn.Sequential(
nn.Conv2d(in_channels=1, out_channels=6, kernel_size=5, padding=2),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=6, out_channels=16, kernel_size=5),
nn.Tanh(),
nn.AvgPool2d(kernel_size=2, stride=2),
nn.Conv2d(in_channels=16, out_channels=120, kernel_size=5),
nn.Tanh(),
nn.Flatten(),
nn.Linear(in_features=120, out_features=84),
nn.Tanh(),
nn.Linear(in_features=84, out_features=10)
)
def forward(self, x):
return self.net(x)
# 初始化模型
model = LeNet5()
# 加載權重(無需修改鍵名,直接匹配)
state_dict = torch.load('LeNet5_mnist.pth', map_location=torch.device('cpu'))
model.load_state_dict(state_dict)
model.eval() # 切換為評估模式
# 圖像預處理(適配MNIST數據集的預處理方式)
def preprocess_image(image):
# 確保圖像轉為灰度圖(即使前端已處理,后端再次確認)
image = image.convert('L') # 轉為灰度圖
# 確保圖像尺寸為28x28(即使前端已處理,后端再次確認)
image = image.resize((28, 28), Image.Resampling.LANCZOS) # 使用高質量插值方法
# 轉換為numpy數組并歸一化
image = np.array(image, dtype=np.float32) # 轉為數組
# 按照訓練時的方式歸一化(MNIST的均值和標準差)
mean = 0.1307
std = 0.3081
image = (image / 255.0 - mean) / std # 先歸一化到0-1再標準化
# 確保輸入維度正確
image = np.expand_dims(image, axis=0) # 增加通道維度 (1,28,28)
image = np.expand_dims(image, axis=0) # 增加batch維度 (1,1,28,28)
return torch.tensor(image)
# 預測接口
@app.post("/predict")
async def predict_digit(file: UploadFile = File(...)):
try:
# 打印文件基本信息用于調試
print(f"收到文件: {file.filename}, 類型: {file.content_type}")
# 讀取圖像
image = Image.open(file.file)
print(f"原始圖像 - 尺寸: {image.size}, 模式: {image.mode}") # 檢查圖像初始狀態(tài)
# 預處理
input_tensor = preprocess_image(image)
print(f"預處理后 - 張量維度: {input_tensor.shape}, 數據類型: {input_tensor.dtype}") # 檢查處理后狀態(tài)
# 預測
with torch.no_grad():
output = model(input_tensor)
predicted_digit = torch.argmax(output, dim=1).item()
return {"predicted_digit": predicted_digit}
except Exception as e:
# 打印異常信息用于調試
print(f"處理請求時出錯: {str(e)}")
return {"error": str(e)}
if __name__ == "__main__":
import uvicorn
uvicorn.run(app, host="0.0.0.0", port=8000)```
接下來簡單展示一下啟動步驟:
#### 1. 打開終端,進入 `main.py` 所在目錄
以我的文件結構來舉例:
D:\ProjectPython\DNN_CNN
├── main.py
├── CNN_Proj.py
└── LeNet5_mnist.pth
在 **VS Code** 里: - 點擊左側“資源管理器”,找到 `DNN_CNN` 文件夾 - 點擊頂部菜單 **終端 → 新建終端**(會自動進入當前目錄 ) - 也可以直接用cd + 文件路徑 #### 2. 運行 `uvicorn` 命令 在終端里輸入: ```bash uvicorn main:app --reload
main:app:告訴 uvicorn:
- 找
main.py文件(main) - 加載里面的
app = FastAPI()實例(app)
--reload:開發(fā)模式,改代碼后自動重啟
3. 看啟動結果
如果成功,終端會顯示:
INFO: Uvicorn running on http://127.0.0.1:8000 (Press CTRL+C to quit) INFO: Started reloader process [12345] INFO: Started server process [12346] INFO: Waiting for application startup. INFO: Application startup complete.
這說明:
- 你的 API 服務啟動了,地址是
http://127.0.0.1:8000 - 現在可以用 Postman 訪問
http://127.0.0.1:8000/predict測試

如果想要保險起見,可以先用下面這一步來測試一下,測試下API的情況。
4. 測試 API(用 Postman 或瀏覽器)
打開 Postman:
- 請求方法:POST
- URL:
http://127.0.0.1:8000/predict
** Body → form-data**:
Key選file,類型選FileValue選一張手寫數字的圖片(28x28 黑白圖最佳 )
發(fā)送請求后,就能看到返回的 predicted_digit(識別結果 )
打開后配置請求信息:
Step 1:選請求方法 + 填 URL
- 選 POST(必須和你
main.py里的@app.post("/predict")對應); - 中間 URL 輸入框,填
http://127.0.0.1:8000/predict(就是你 FastAPI 服務的地址 + 接口名)。
Step 2:配置 Body(上傳圖片)
- 點擊請求下方的 “Body” 標簽 → 勾選 “Form Data”(表單上傳,和
main.py接收UploadFile對應); - 第一行“Key”輸入
file(必須和main.py里predict(file: UploadFile = File(...))的參數名一致); - 第一行“Value”右側,點擊 “File” 按鈕(默認是“Text”,要改成文件上傳),然后選擇一張你的手寫數字圖片(28x28 黑白圖最佳,手機拍的手寫數字照片也能試)。
Step 3:發(fā)送請求
- 點擊右上角的 “Send” 按鈕(藍色箭頭),發(fā)送請求。

發(fā)送后,右側會顯示服務器返回的結果:
- 成功情況:如果返回類似
{"predicted_digit": 5},說明模型識別出圖片里的數字是 5,API 調用成功!
常見問題排查:
- 若顯示“Connection refused”:檢查 FastAPI 服務是否啟動(終端里的
uvicorn命令有沒有在運行); - 若顯示“找不到文件”:檢查
main.py里torch.load("LeNet5_mnist.pth")的模型路徑是否正確,確保LeNet5_mnist.pth和main.py在同一目錄; - 若識別結果錯誤:檢查
preprocess_image函數的預處理邏輯(比如是否轉灰度、是否 resize 到 28x28),要和訓練時完全一致。
5.4 前端部署
首先確保你的 Node.js 環(huán)境已經準備好,接下來用 Vue3 + Vite 實現手寫數字識別的前端界面并和后端 API 打通:
5.4.1 創(chuàng)建 Vue3 + Vite 項目
打開終端(CMD/PowerShell/VS Code 終端都可以 );
創(chuàng)建項目(按順序執(zhí)行 ):
# 1. 創(chuàng)建 Vue3 項目(項目名 mnist-frontend,模板選 vue) npm create vite@latest mnist-frontend -- --template vue # 2. 進入項目目錄 cd mnist-frontend # 3. 安裝依賴(等待安裝完成) npm install # 4. 啟動開發(fā)環(huán)境(啟動后,瀏覽器訪問 http://127.0.0.1:5173) npm run dev
執(zhí)行完后,瀏覽器會自動打開 Vue3 初始頁面(或手動訪問 http://127.0.0.1:5173 ),看到 Vue 的歡迎界面,說明項目創(chuàng)建成功。

5.4.2 編寫前端界面
在 VS Code 中打開項目目錄 mnist-frontend,找到 src/App.vue 文件,替換成以下完整程序:
<template>
<div class="container">
<h1>手寫數字識別</h1>
<!-- 主畫布 -->
<canvas
ref="canvas"
width="280"
height="280"
@mousedown="startDrawing"
@mousemove="draw"
@mouseup="stopDrawing"
@mouseleave="stopDrawing"
></canvas>
<!-- 調試畫布(v-show 保持 DOM 存在) -->
<div class="debug-section" v-show="showDebug">
<h3>預處理后圖像(28x28 放大)</h3>
<canvas ref="debugCanvas" width="280" height="280"></canvas>
<p class="debug-info">實際尺寸 28x28 | 放大 10 倍</p>
</div>
<!-- 控制按鈕 -->
<div class="buttons">
<button @click="clearCanvas" :disabled="isLoading">清除畫布</button>
<button @click="predictDigit" :disabled="isLoading">
{{ isLoading ? '識別中...' : '識別' }}
</button>
<button @click="toggleDebug">顯示/隱藏調試</button>
</div>
<!-- 結果與錯誤提示 -->
<div class="result" v-if="recognitionResult">識別結果:{{ recognitionResult }}</div>
<div class="error" v-if="errorMessage">錯誤:{{ errorMessage }}</div>
</div>
</template>
<script setup>
import { ref, onMounted, nextTick, watch } from 'vue';
import axios from 'axios';
// 響應式變量
const canvas = ref(null);
const debugCanvas = ref(null);
const showDebug = ref(false);
const isDrawing = ref(false);
const isLoading = ref(false);
const recognitionResult = ref('');
const errorMessage = ref('');
let ctx = null;
let debugCtx = null;
let lastX = 0;
let lastY = 0;
// 初始化畫布(確保 DOM 渲染完成)
onMounted(async () => {
await nextTick(); // 等待 DOM 完全渲染
// 主畫布初始化
if (canvas.value) {
ctx = canvas.value.getContext('2d', { willReadFrequently: true });
if (ctx) {
ctx.fillStyle = '#ffffff'; // 改為純白背景,與MNIST訓練數據背景一致
ctx.fillRect(0, 0, 280, 280);
ctx.lineWidth = 12; // 調整畫筆寬度,避免預處理后線條過細
ctx.strokeStyle = 'black';
ctx.lineCap = 'round'; // 畫筆端點圓潤,避免鋸齒
ctx.lineJoin = 'round'; // 畫筆拐角圓潤,提升繪制體驗
} else {
errorMessage.value = '主畫布初始化失敗,請刷新';
}
} else {
errorMessage.value = '未找到主畫布元素,請檢查代碼';
}
// 調試畫布初始化(v-show 已確保 DOM 存在)
if (debugCanvas.value) {
debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });
if (debugCtx) {
debugCtx.fillStyle = '#ffffff';
debugCtx.fillRect(0, 0, 280, 280);
} else {
console.warn('調試畫布初始化失?。ú挥绊懼鞴δ埽?);
}
}
});
// 監(jiān)聽 showDebug 變化,重新初始化調試畫布
watch(showDebug, (newVal) => {
if (newVal && debugCanvas.value && !debugCtx) {
debugCtx = debugCanvas.value.getContext('2d', { willReadFrequently: true });
if (debugCtx) {
debugCtx.fillStyle = '#ffffff';
debugCtx.fillRect(0, 0, 280, 280);
}
}
});
// 繪制邏輯 - 修復坐標計算與繪制連續(xù)性問題
function startDrawing(e) {
if (!ctx) return;
isDrawing.value = true;
const rect = canvas.value.getBoundingClientRect();
// 計算畫布內真實坐標(處理畫布縮放場景)
lastX = (e.clientX - rect.left) * (canvas.value.width / rect.width);
lastY = (e.clientY - rect.top) * (canvas.value.height / rect.height);
ctx.beginPath();
ctx.moveTo(lastX, lastY);
// 繪制初始點(解決點擊畫布不拖動無痕跡問題)
ctx.lineTo(lastX + 0.1, lastY + 0.1);
ctx.stroke();
}
function draw(e) {
if (!ctx || !isDrawing.value) return;
const rect = canvas.value.getBoundingClientRect();
// 計算畫布內真實坐標
const x = (e.clientX - rect.left) * (canvas.value.width / rect.width);
const y = (e.clientY - rect.top) * (canvas.value.height / rect.height);
ctx.lineTo(x, y);
ctx.stroke();
lastX = x;
lastY = y;
}
function stopDrawing() {
isDrawing.value = false;
}
// 清除畫布
function clearCanvas() {
if (!ctx) return;
ctx.fillStyle = '#ffffff';
ctx.fillRect(0, 0, 280, 280);
// 清除調試畫布
if (debugCtx) {
debugCtx.fillStyle = '#ffffff';
debugCtx.fillRect(0, 0, 280, 280);
}
recognitionResult.value = '';
errorMessage.value = '';
}
// 切換調試視圖
function toggleDebug() {
showDebug.value = !showDebug.value;
}
// 預測邏輯 - 修復FormData構建與錯誤處理
async function predictDigit() {
if (!ctx) {
errorMessage.value = '畫布未初始化,請刷新';
return;
}
isLoading.value = true;
errorMessage.value = '';
try {
// 檢查畫布是否有內容(優(yōu)化閾值,適配純白背景)
const isEmpty = await checkCanvasEmpty();
if (isEmpty) {
errorMessage.value = '請先繪制數字';
isLoading.value = false;
return;
}
// 轉換為 28x28 灰度圖(前端預處理)
const { imgBlob, tempCanvas } = await canvasTo28x28Gray(canvas.value);
if (!imgBlob) {
throw new Error('圖像轉換失敗,無法生成有效圖像數據');
}
// 顯示調試圖像(放大)
if (showDebug.value && debugCtx && tempCanvas) {
debugCtx.drawImage(tempCanvas, 0, 0, 280, 280);
}
// 調用后端識別 - 修復FormData構建,移除手動設置Content-Type(axios自動處理)
const formData = new FormData();
formData.append('file', imgBlob, 'digit.png'); // 參數名改為'file',與后端UploadFile參數名匹配
const response = await axios.post(
'http://localhost:8000/predict',
formData
// 移除手動設置的Content-Type,避免邊界符缺失問題
);
// 驗證響應數據格式
if (response.data && 'predicted_digit' in response.data) {
recognitionResult.value = response.data.predicted_digit;
} else {
throw new Error('后端返回數據格式異常');
}
} catch (error) {
// 精細化錯誤提示
if (error.response) {
// 后端返回錯誤(如422、500)
errorMessage.value = `識別失?。?{error.response.status} - ${
error.response.data?.error || error.response.data?.detail || '未知錯誤'
}`;
} else if (error.request) {
// 無響應(如后端未啟動、跨域問題)
errorMessage.value = '識別失敗:無法連接后端服務,請檢查后端是否運行';
} else {
// 前端本地錯誤(如圖像轉換)
errorMessage.value = `識別失?。?{error.message}`;
}
console.error('預測錯誤詳情:', error);
} finally {
isLoading.value = false;
}
}
// 輔助函數:檢查畫布是否為空(優(yōu)化閾值,適配純白背景)
async function checkCanvasEmpty() {
return new Promise((resolve) => {
if (!ctx) {
resolve(true);
return;
}
const imageData = ctx.getImageData(0, 0, 280, 280);
const data = imageData.data;
const threshold = 250; // 純白背景下,低于250視為有繪制內容
for (let i = 0; i < data.length; i += 4) {
const brightness = (data[i] + data[i+1] + data[i+2]) / 3;
if (brightness < threshold) {
resolve(false);
return;
}
}
resolve(true);
});
}
// 輔助函數:Canvas 轉 28x28 灰度圖(修復圖像反轉邏輯,匹配MNIST)
function canvasTo28x28Gray(canvasEl) {
return new Promise((resolve) => {
const tempCanvas = document.createElement('canvas');
tempCanvas.width = 28;
tempCanvas.height = 28;
const tempCtx = tempCanvas.getContext('2d');
if (!tempCtx) {
resolve({ imgBlob: null, tempCanvas: null });
return;
}
// 1. 繪制時保持圖像比例,避免拉伸(居中繪制)
tempCtx.fillStyle = '#ffffff';
tempCtx.fillRect(0, 0, 28, 28); // 先填充純白背景
// 計算縮放比例(確保圖像完全放入28x28畫布,保留比例)
const scale = Math.min(28 / canvasEl.width, 28 / canvasEl.height);
const xOffset = (28 - canvasEl.width * scale) / 2;
const yOffset = (28 - canvasEl.height * scale) / 2;
tempCtx.drawImage(
canvasEl,
0, 0, canvasEl.width, canvasEl.height,
xOffset, yOffset, canvasEl.width * scale, canvasEl.height * scale
);
// 2. 轉灰度并反轉(MNIST:白底黑字 → 黑底白字,增強特征)
const imageData = tempCtx.getImageData(0, 0, 28, 28);
const data = imageData.data;
for (let i = 0; i < data.length; i += 4) {
// 計算亮度(灰度值)
const brightness = (data[i] + data[i+1] + data[i+2]) / 3;
// 反轉:白色(高亮度)→ 黑色(0),黑色(低亮度)→ 白色(255),匹配MNIST數據分布
const inverted = 255 - brightness;
data[i] = data[i+1] = data[i+2] = inverted;
data[i+3] = 255; // 保持不透明
}
tempCtx.putImageData(imageData, 0, 0);
// 3. 生成Blob(指定質量,避免數據損壞)
tempCanvas.toBlob((blob) => {
resolve({ imgBlob: blob, tempCanvas: tempCanvas });
}, 'image/png', 1.0); // 1.0表示無損壓縮,確保圖像細節(jié)不丟失
});
}
</script>
<style scoped>
.container {
text-align: center;
padding: 20px;
max-width: 600px;
margin: 0 auto;
}
canvas {
border: 2px solid #ccc;
margin: 10px auto;
display: block;
background-color: #ffffff; /* 匹配初始化的純白背景 */
touch-action: none;
}
.debug-section {
margin-top: 20px;
padding: 15px;
background-color: #f9f9f9;
border-radius: 8px;
}
.debug-info {
color: #666;
font-size: 14px;
margin-top: 5px;
}
.buttons {
margin: 20px 0;
}
button {
padding: 10px 20px;
margin: 0 10px;
cursor: pointer;
background-color: #42b983;
color: white;
border: none;
border-radius: 4px;
transition: opacity 0.3s;
}
button:disabled {
background-color: #ccc;
cursor: not-allowed;
opacity: 0.7;
}
button:hover:not(:disabled) {
opacity: 0.8;
}
.result {
font-size: 20px;
margin-top: 20px;
color: #42b983;
}
.error {
font-size: 16px;
color: #e53e3e;
margin-top: 10px;
}
</style>
已經在mnist-frontend/src目錄下創(chuàng)建好App.vue,程序詳見第2章的前端腳本程序。
啟動前端服務:
cd D:\ProjectPython\DNN_CNN\mnist-frontend\src npm run dev
看到Local: http://localhost:5173/表示啟動成功,復制鏈接在瀏覽器打開。

5.5 功能測試
- 在瀏覽器頁面的畫布上,用鼠標繪制0-9任意數字。
- 點擊
顯示/隱藏調試,查看28×28預處理圖像。 - 點擊
識別按鈕,下方會顯示識別結果。 - 點擊
清除畫布可重新繪制,測試其他數字。
結果如下,只列舉部分:




當然,你在終端上也可以看到具體的信息,如果出現錯誤也可以從中看到是什么錯誤:

在前端頁面上也可以通過Fn + 12來打開瀏覽器后臺查看具體信息。
在你創(chuàng)建好后,如果未更改前后端文件,后續(xù)你的啟動步驟就只需要兩步:
1.啟動后端API服務:
uvicorn main:app --reload
2.啟動前端開發(fā)環(huán)境:
npm run dev
6 問題復盤與解決
6.1 錯誤1:422 Unprocessable Entity(前端請求后端失?。?/h3>
這個算是一開始很常見的問題,具體來說很大概率基本都是參數名與后端不匹配。
- 原因:前端FormData參數名與后端不匹配(原前端用
image,后端需file);手動設置Content-Type: multipart/form-data導致請求邊界符缺失。 - 解決思路:前端
formData.append('file', imgBlob, 'digit.png');刪除axios的headers配置,讓axios自動處理。
6.2 錯誤2:預測結果不準確(如“3”識別為“8”)
- 原因:前端圖像未反轉(與MNIST黑底白字分布相反);畫筆過細導致預處理后線條消失。
- 解決思路:在
canvasTo28x28Gray中添加灰度反轉(255 - brightness);將ctx.lineWidth設為12-15。
7 小結
7.1 收獲
技術棧整合:切身體會Vue(前端交互)、FastAPI(后端接口)和PyTorch(CNN模型)的前后端分離開發(fā)模式,理解各模塊間的數據流轉邏輯(圖像→Blob→FormData→張量→預測結果)。
關鍵技術點:
圖像預處理:灰度轉換、尺寸縮放、顏色反轉、歸一化,核心是“匹配模型訓練時的數據分布”。模型部署:訓練權重加載、評估模式切換、無梯度預測,確保模型高效且正確運行。問題排查:通過調試信息(如后端打印的文件尺寸、張量維度)定位數據格式問題,通過精細化錯誤提示快速排查接口問題。
7.2 可擴展方向
- 功能擴展:支持手寫字母識別(替換數據集為EMNIST)、多數字識別(修改模型輸出層為多分類)。
- 性能優(yōu)化:用ResNet-18替換LeNet5提升準確率,用TensorRT加速模型推理,前端添加防抖繪制減少冗余數據。
- 場景適配:開發(fā)移動端頁面,添加歷史記錄功能,部署到云服務器實現公網訪問,但相關知識目前還沒學完,后面有時間試試。
7.3 可復用方向
本筆記的環(huán)境搭建→代碼實現→操作流程可直接復用于圖像分類類項目(如驗證碼識別、水果分類),只需替換三個部分:
- 數據集:將MNIST替換為目標數據集(如EMNIST、Fruits-360)。
- 模型結構:根據數據集復雜度調整CNN層數(簡單任務用LeNet5,復雜任務用ResNet)。
- 前端交互:根據輸入類型修改交互組件(將畫布改為圖片上傳)。
以上為個人經驗,希望能給大家一個參考,也希望大家多多支持腳本之家。
相關文章
Mac中Python 3環(huán)境下安裝scrapy的方法教程
作為一名python爬蟲愛好者,怎能不折騰下Scrapy?折騰了許久之后終于安裝到了mac中,所以下面這篇文章主要給大家介紹了關于Mac系統中Python 3環(huán)境下安裝scrapy的相關資料,文中將實現的步驟一步步介紹的非常詳細,需要的朋友可以參考下。2017-10-10

