《記從零實現手寫數字識別——PyTorch實戰篇》
一、環境搭建
實驗采用Python3.8環境,主要依賴庫:
- PyTorch 1.12:深度學習框架
- Torchvision 0.13:提供MNIST數據集
- OpenCV 4.6:圖像預處理
安裝命令:pip install torch torchvision opencv-python
二、實戰開發步驟
- 數據加載技巧
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.1307,), (0.3081,)) # 數據集均值標準差
])
train_set = datasets.MNIST('data/', train=True, download=True, transform=transform)
test_set = datasets.MNIST('data/', train=False, transform=transform)
- 改進型網絡設計
class EnhancedCNN(nn.Module):
def __init__(self):
super().__init__()
self.features = nn.Sequential(
nn.Conv2d(1, 16, 3, padding=1),
nn.BatchNorm2d(16),
nn.ReLU(),
nn.MaxPool2d(2),
nn.Conv2d(16, 32, 3, padding=1),
nn.Dropout(0.25)
)
self.classifier = nn.Sequential(
nn.Linear(32*7*7, 128),
nn.ReLU(),
nn.Linear(128, 10))
- 訓練優化技巧
def train_model():
model = EnhancedCNN()
optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=3, gamma=0.1)
for epoch in range(10):
model.train()
for data, target in train_loader:
optimizer.zero_grad()
output = model(data)
loss = F.cross_entropy(output, target)
loss.backward()
optimizer.step()
scheduler.step()
三、效果驗證
在測試集上達到98.7%準確率的關鍵:
- 添加BatchNorm層加速收斂
- 使用Dropout防止過擬合
- 學習率階梯下降策略
四、模型部署示例
def predict_image(img_path):
img = cv2.imread(img_path, cv2.IMREAD_GRAYSCALE)
img = cv2.resize(img, (28,28))
img_tensor = transform(255 - img).unsqueeze(0)
with torch.no_grad():
pred = torch.argmax(model(img_tensor)).item()
return pred
思考延伸
嘗試使用數據增強(旋轉、平移)提升模型魯棒性,比較不同優化器的性能差異,思考如何將模型部署到移動端應用。

浙公網安備 33010602011771號