【AI Study Notes 9】MNIST recognition using a CNN or an MLP in PyTorch
1. Recognizing MNIST with a CNN

As shown in the figure, the CNN consists of two convolutional layers, two pooling layers, and one fully connected (FC) layer. 【1】
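The spatial size of each feature map can be worked out by hand before writing the model. The sketch below is a minimal illustration, assuming 5×5 convolution kernels with stride 1 and no padding, 2×2 max pooling with stride 2, and 20 output channels after the second convolution (the values used by the code in the next section). The helper names conv_out and pool_out are made up for this example.

```python
def conv_out(size, kernel, stride=1, padding=0):
    """Output width/height of a convolution: floor((size + 2*padding - kernel) / stride) + 1."""
    return (size + 2 * padding - kernel) // stride + 1


def pool_out(size, kernel, stride=None):
    """Output width/height of a pooling layer; stride defaults to the kernel size."""
    stride = kernel if stride is None else stride
    return (size - kernel) // stride + 1


size = 28                      # MNIST images are 28x28
sizes = [size]
for _ in range(2):             # two convolution + pooling stages
    size = conv_out(size, kernel=5)
    sizes.append(size)
    size = pool_out(size, kernel=2)
    sizes.append(size)

flat_features = 20 * size * size   # assumed: 20 channels after the second convolution
print(sizes)           # [28, 24, 12, 8, 4]
print(flat_features)   # 320
```

The final value, 320, is the input width that the fully connected layer has to accept.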
2. Code for recognizing MNIST with a CNN 【2】【3】【4】【5】
# Load the required libraries
import torch
from torchvision import transforms
from torchvision import datasets
from torch.utils.data import DataLoader
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
import matplotlib.pyplot as plt
# Hyperparameters
Batch_Size = 64  # number of samples per batch
Epochs = 100  # number of passes over the training set
Learning_Rate = 0.01  # learning rate
# Prepare the dataset: convert images to tensors, then normalize with the MNIST mean and std
transform = transforms.Compose([transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))])
# Read the MNIST dataset (downloaded to root on first use)
train_dataset = datasets.MNIST(root='./dataset/mnist/', train=True, download=True, transform=transform)
train_loader = DataLoader(train_dataset, shuffle=True, batch_size=Batch_Size)  # MNIST handwritten-digit training set
test_dataset = datasets.MNIST(root='./dataset/mnist/', train=False, download=True, transform=transform)
test_loader = DataLoader(test_dataset, shuffle=False, batch_size=Batch_Size)  # MNIST handwritten-digit test set
# design model using class
class CNN(torch.nn.Module):
    def __init__(self):
        super(CNN, self).__init__()
        self.conv1 = torch.nn.Sequential(  # input is a grayscale image (1,28,28)
            # Convolution: 1 input channel, 10 output channels, 5×5 kernel, stride 1, no padding -> (10,24,24)
            torch.nn.Conv2d(in_channels=1, out_channels=10, kernel_size=5, stride=1, padding=0),
            torch.nn.ReLU(),  # ReLU activation
            torch.nn.MaxPool2d(2),  # max pooling, 2×2 kernel, stride 2 -> (10,12,12)
        )
        self.conv2 = torch.nn.Sequential(
            torch.nn.Conv2d(in_channels=10, out_channels=20, kernel_size=5, stride=1, padding=0),  # (20,8,8)
            torch.nn.ReLU(),
            torch.nn.MaxPool2d(2),  # (20,4,4)
        )
        self.fc = torch.nn.Linear(20*4*4, 10)  # fully connected layer: 320 flattened features -> 10 classes (0-9)

    # Forward pass
    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = x.view(x.size(0), -1)  # flatten to (batch_size, 320)
        output = self.fc(x)
        return output
model = CNN()
device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")  # use the GPU if available, otherwise the CPU
print(device)
model.to(device)
# construct loss and optimizer
criterion = torch.nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=Learning_Rate, momentum=0.5)
# training cycle forward, backward, update
def train(epoch):
    running_loss = 0.0
    for batch_idx, data in enumerate(train_loader, 0):
        inputs, target = data
        inputs, target = inputs.to(device), target.to(device)
        optimizer.zero_grad()
        # Forward pass
        outputs = model(inputs)
        loss = criterion(outputs, target)
        # Backward pass and parameter update
        loss.backward()
        optimizer.step()
        running_loss += loss.item()
        if batch_idx % 300 == 299:  # report the average loss every 300 batches
            print('[%d, %5d] loss: %.3f' % (epoch + 1, batch_idx + 1, running_loss / 300))
            running_loss = 0.0
def test():
    correct = 0
    total = 0
    with torch.no_grad():  # no gradients are needed for evaluation
        for data in test_loader:
            images, labels = data
            images, labels = images.to(device), labels.to(device)
            outputs = model(images)
            # torch.max with dim returns (values, indices); the indices are the predicted classes
            _, predicted = torch.max(outputs, dim=1)
            total += labels.size(0)
            correct += (predicted == labels).sum().item()
    print('accuracy on test set: %d %% ' % (100 * correct / total))
    return correct / total
if __name__ == '__main__':
    epoch_list = []
    acc_list = []
    for epoch in range(Epochs):
        train(epoch)
        acc = test()
        epoch_list.append(epoch)
        acc_list.append(acc)
    plt.plot(epoch_list, acc_list)
    plt.ylabel('accuracy')
    plt.xlabel('epoch')
    plt.show()
Output:
cuda:0
[1, 300] loss: 0.588
[1, 600] loss: 0.182
[1, 900] loss: 0.135
accuracy on test set: 97 %
...
...
[100, 300] loss: 0.000
[100, 600] loss: 0.001
[100, 900] loss: 0.001
accuracy on test set: 99 %
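The model above returns raw scores (logits) without a softmax layer, because CrossEntropyLoss applies log-softmax internally, and the predicted class is simply the index of the largest score. The following is a minimal pure-Python sketch of both steps for a single sample. The logits are made-up numbers, and cross_entropy and argmax are illustrative helpers written for this example, not PyTorch functions.

```python
import math


def cross_entropy(logits, target):
    """Negative log-probability of the target class under softmax(logits)."""
    m = max(logits)  # subtract the max for numerical stability
    log_sum_exp = m + math.log(sum(math.exp(z - m) for z in logits))
    return log_sum_exp - logits[target]


def argmax(values):
    """Index of the largest value, i.e. the predicted class."""
    return max(range(len(values)), key=lambda i: values[i])


logits = [2.0, 0.5, -1.0]  # hypothetical scores for a 3-class problem (MNIST has 10)
target = 0                 # hypothetical true label

loss = cross_entropy(logits, target)
predicted = argmax(logits)
print(round(loss, 4), predicted)
```

The loss shrinks toward 0 as the score of the correct class pulls ahead of the others, which matches the very small training losses printed in the last epoch.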

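3. Code for recognizing MNIST with an MLP 【6】
Replace the "# design model using class" block with the following code, and create the model with model = FC() instead of model = CNN():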
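# Define the network structure
class FC(torch.nn.Module):
    def __init__(self):
        super(FC, self).__init__()
        self.l1 = torch.nn.Linear(784, 15)  # 28*28 input pixels -> 15 hidden units
        self.l2 = torch.nn.Linear(15, 10)  # 15 hidden units -> 10 classes
        self.relu = torch.nn.ReLU()

    # Forward pass
    def forward(self, x):
        x = x.view(-1, 784)  # flatten each image to a 784-element vector
        x = self.l1(x)
        x = self.relu(x)
        x = self.l2(x)
        return x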
Output:
cuda:0
[1, 300] loss: 0.783
[1, 600] loss: 0.385
[1, 900] loss: 0.326
accuracy on test set: 91 %
...
...
[100, 300] loss: 0.087
[100, 600] loss: 0.088
[100, 900] loss: 0.095
accuracy on test set: 94 %
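One way to put the two results side by side is to count trainable parameters. The sketch below is plain arithmetic based on the layer shapes in the two model definitions above. The helper names conv_params and linear_params are made up for this example, and it assumes every layer keeps its default bias term.

```python
def conv_params(in_channels, out_channels, kernel_size):
    """Weights plus one bias per output channel for a square-kernel Conv2d."""
    return in_channels * out_channels * kernel_size * kernel_size + out_channels


def linear_params(in_features, out_features):
    """Weights plus one bias per output unit for a Linear layer."""
    return in_features * out_features + out_features


cnn_total = (
    conv_params(1, 10, 5)              # conv1
    + conv_params(10, 20, 5)           # conv2
    + linear_params(20 * 4 * 4, 10)    # fc
)
mlp_total = linear_params(784, 15) + linear_params(15, 10)  # l1 + l2

print(cnn_total)   # 8490
print(mlp_total)   # 11935
```

Under these counts, the CNN reaches the higher test accuracy reported above while using fewer parameters than the MLP.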

References:
【1】 架構師帶你玩轉AI, "Large Model Development: Understanding How CNNs Work (Convolution and Pooling)"
https://www.53ai.com/news/qianyanjishu/594.html
【2】 山山而川, "Handwritten Digit Recognition with PyTorch | MNIST Dataset (CNN, Convolutional Neural Network)"
http://www.rzrgm.cn/xinyangblog/p/16326476.html
【3】 月球背面, "Introduction to Deep Learning: CNN Fundamentals and Hands-On Practice"
https://juejin.cn/post/7238627611265253434
【4】 Tom2Code, "MNIST Handwritten Digit Recognition with a Hand-Written CNN"
https://cloud.tencent.com/developer/article/2216568
【5】 全棧程序員站長, "A Detailed Explanation of Implementing MNIST in PyTorch [Easy to Understand]"
https://cloud.tencent.com/developer/article/2055189
【6】 martin-wmx, "MNIST-pytorch"
https://github.com/martin-wmx/MNIST-pytorch/blob/master/FC/model.py
posted on 2025-02-23 18:45 JasonQiuStar