MLP 的局限:從 MNIST 到 Cats vs Dogs
!!!本次模擬訓練的時長在沒有下載的基礎上且使用cuda加速的情況下是4min多。 (需要保證體驗的話,需要使用cuda或者mps進行加速,且提前下載數據集)
在上一篇實驗中,我們看到 MLP 在 MNIST 手寫數字識別上可以達到接近 97.5% 的準確率。這說明 MLP 具備了較強的擬合能力,只要樣本量足夠并且合理調節超參數,就能在像手寫體這樣結構相對簡單的數據集上取得非常不錯的表現。
但這里也埋下了一個重要的問題:這樣的高準確率是否意味著 MLP 在其他任務上也能同樣泛化?
MNIST 數據集本身有幾個特點:
1:圖片分辨率低(28×28 灰度),輸入維度相對較小;
2:樣本居中、背景干凈、噪聲少,模式相對統一;
3:任務目標簡單:10 類數字分類,類間差異明顯。
這就解釋了為什么 MLP 在 MNIST 上可以輕松達到很高的準確率——因為它不需要建模復雜的局部結構,只要把像素展平后學習全局模式,就足以區分數字。然而,如果我們把手寫圖片做一些簡單的擾動,比如:
a:微小平移(數字偏移幾像素);b:輕微旋轉(±15°);c:添加噪聲(模糊或隨機點);d:換成其他的數據源,接下來我們通過設計實驗來觀察MLP在其他物體識別泛化能力怎樣。
import torch
import torch.nn as nn
import torch.optim as optim
import torchvision
import torchvision.transforms as transforms
from torchvision.transforms import functional as TF
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import rcParams
import matplotlib.font_manager as fm
import pandas as pd
# 嘗試多個中文字體(按優先級)
chinese_fonts = [
'PingFang SC', # macOS 默認
'Heiti TC', # macOS 黑體
'STHeiti', # 華文黑體
'Arial Unicode MS', # 支持中文的Arial
'SimHei', # 黑體
'Microsoft YaHei', # 微軟雅黑
]
# 查找可用的中文字體
available_fonts = [f.name for f in fm.fontManager.ttflist]
font_found = None
for font in chinese_fonts:
if font in available_fonts:
font_found = font
break
if font_found:
rcParams['font.sans-serif'] = [font_found]
rcParams['axes.unicode_minus'] = False
print(f"Using font: {font_found}")
else:
print("Warning: No Chinese font found, using English labels")
# 設置設備
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
print(f"Using device: {device}\n")
# ==================== MLP模型定義 ====================
class MLP(nn.Module):
def __init__(self, input_size=784, hidden_sizes=[512, 256], num_classes=10):
super().__init__()
layers = []
# 構建網絡
prev_size = input_size
for hidden_size in hidden_sizes:
layers.append(nn.Linear(prev_size, hidden_size))
layers.append(nn.ReLU())
layers.append(nn.Dropout(0.2))
prev_size = hidden_size
layers.append(nn.Linear(prev_size, num_classes))
self.network = nn.Sequential(*layers)
def forward(self, x):
x = x.view(x.size(0), -1) # 展平
return self.network(x)
# ==================== 訓練和評估函數 ====================
def train_model(model, train_loader, criterion, optimizer, epochs=10, device=device):
"""訓練模型"""
model = model.to(device)
model.train()
for epoch in range(epochs):
running_loss = 0.0
for images, labels in train_loader:
images, labels = images.to(device), labels.to(device)
optimizer.zero_grad()
outputs = model(images)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
running_loss += loss.item()
if (epoch + 1) % 2 == 0:
print(f" Epoch [{epoch+1}/{epochs}], Loss: {running_loss/len(train_loader):.4f}")
def evaluate_model(model, test_loader, device=device):
"""評估模型準確率"""
model.eval()
correct, total = 0, 0
with torch.no_grad():
for images, labels in test_loader:
images, labels = images.to(device), labels.to(device)
outputs = model(images)
_, predicted = torch.max(outputs, 1)
total += labels.size(0)
correct += (predicted == labels).sum().item()
accuracy = 100 * correct / total
return accuracy
# ==================== 數據增強和擾動 ====================
class AddGaussianNoise:
"""添加高斯噪聲"""
def __init__(self, mean=0., std=0.1):
self.mean = mean
self.std = std
def __call__(self, tensor):
return tensor + torch.randn(tensor.size()) * self.std + self.mean
class RandomShift:
"""隨機平移"""
def __init__(self, shift_range=4):
self.shift_range = shift_range
def __call__(self, img):
shift_x = np.random.randint(-self.shift_range, self.shift_range + 1)
shift_y = np.random.randint(-self.shift_range, self.shift_range + 1)
return TF.affine(img, angle=0, translate=(shift_x, shift_y), scale=1.0, shear=0)
# ==================== 實驗1: 基準測試 ====================
def experiment_baseline():
"""基準實驗:原始MNIST"""
print("="*60)
print("實驗1: 基準測試 - 原始MNIST")
print("="*60)
# 準備數據
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.MNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
# 訓練模型
model = MLP(input_size=784, hidden_sizes=[512, 256], num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("開始訓練...")
train_model(model, train_loader, criterion, optimizer, epochs=10)
# 評估
accuracy = evaluate_model(model, test_loader)
print(f"? 基準準確率: {accuracy:.2f}%\n")
return model, accuracy
# ==================== 實驗2: 平移擾動 ====================
def experiment_translation(trained_model):
"""實驗2: 測試平移不變性"""
print("="*60)
print("實驗2: 平移擾動測試")
print("="*60)
shift_ranges = [0, 2, 4, 6, 8]
accuracies = []
for shift in shift_ranges:
if shift == 0:
transform = transforms.Compose([transforms.ToTensor()])
else:
transform = transforms.Compose([
RandomShift(shift_range=shift),
transforms.ToTensor()
])
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
accuracy = evaluate_model(trained_model, test_loader)
accuracies.append(accuracy)
print(f" 平移范圍 ±{shift}px: {accuracy:.2f}%")
print()
return shift_ranges, accuracies
# ==================== 實驗3: 旋轉擾動 ====================
def experiment_rotation(trained_model):
"""實驗3: 測試旋轉不變性"""
print("="*60)
print("實驗3: 旋轉擾動測試")
print("="*60)
rotation_angles = [0, 5, 10, 15, 20, 30]
accuracies = []
for angle in rotation_angles:
if angle == 0:
transform = transforms.Compose([transforms.ToTensor()])
else:
transform = transforms.Compose([
transforms.RandomRotation(degrees=(angle, angle)),
transforms.ToTensor()
])
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
accuracy = evaluate_model(trained_model, test_loader)
accuracies.append(accuracy)
print(f" 旋轉角度 {angle}°: {accuracy:.2f}%")
print()
return rotation_angles, accuracies
# ==================== 實驗4: 噪聲擾動 ====================
def experiment_noise(trained_model):
"""實驗4: 測試噪聲魯棒性"""
print("="*60)
print("實驗4: 噪聲擾動測試")
print("="*60)
noise_levels = [0.0, 0.1, 0.2, 0.3, 0.4, 0.5]
accuracies = []
for noise_std in noise_levels:
if noise_std == 0.0:
transform = transforms.Compose([transforms.ToTensor()])
else:
transform = transforms.Compose([
transforms.ToTensor(),
AddGaussianNoise(mean=0., std=noise_std)
])
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
accuracy = evaluate_model(trained_model, test_loader)
accuracies.append(accuracy)
print(f" 噪聲標準差 {noise_std:.1f}: {accuracy:.2f}%")
print()
return noise_levels, accuracies
# ==================== 實驗5: 組合擾動 ====================
def experiment_combined(trained_model):
"""實驗5: 組合擾動測試"""
print("="*60)
print("實驗5: 組合擾動測試")
print("="*60)
test_cases = [
("原始", transforms.Compose([transforms.ToTensor()])),
("平移+旋轉", transforms.Compose([
RandomShift(shift_range=4),
transforms.RandomRotation(degrees=10),
transforms.ToTensor()
])),
("平移+噪聲", transforms.Compose([
RandomShift(shift_range=4),
transforms.ToTensor(),
AddGaussianNoise(std=0.2)
])),
("旋轉+噪聲", transforms.Compose([
transforms.RandomRotation(degrees=10),
transforms.ToTensor(),
AddGaussianNoise(std=0.2)
])),
("全部擾動", transforms.Compose([
RandomShift(shift_range=4),
transforms.RandomRotation(degrees=10),
transforms.ToTensor(),
AddGaussianNoise(std=0.2)
]))
]
case_names = []
accuracies = []
for name, transform in test_cases:
test_dataset = torchvision.datasets.MNIST(root='./data', train=False, transform=transform)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
accuracy = evaluate_model(trained_model, test_loader)
case_names.append(name)
accuracies.append(accuracy)
print(f" {name}: {accuracy:.2f}%")
print()
return case_names, accuracies
# ==================== 實驗6: Fashion-MNIST ====================
def experiment_fashion_mnist():
"""實驗6: Fashion-MNIST數據集"""
print("="*60)
print("實驗6: Fashion-MNIST 泛化測試")
print("="*60)
# 準備數據
transform = transforms.Compose([transforms.ToTensor()])
train_dataset = torchvision.datasets.FashionMNIST(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.FashionMNIST(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
# 訓練模型
model = MLP(input_size=784, hidden_sizes=[512, 256], num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("開始訓練...")
train_model(model, train_loader, criterion, optimizer, epochs=10)
# 評估
accuracy = evaluate_model(model, test_loader)
print(f"? Fashion-MNIST準確率: {accuracy:.2f}%\n")
return accuracy
# ==================== 實驗7: CIFAR-10 ====================
def experiment_cifar10():
"""實驗7: CIFAR-10數據集(彩色圖像)"""
print("="*60)
print("實驗7: CIFAR-10 泛化測試")
print("="*60)
# 準備數據
transform = transforms.Compose([
transforms.ToTensor(),
transforms.Normalize((0.5, 0.5, 0.5), (0.5, 0.5, 0.5))
])
train_dataset = torchvision.datasets.CIFAR10(root='./data', train=True, transform=transform, download=True)
test_dataset = torchvision.datasets.CIFAR10(root='./data', train=False, transform=transform)
train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=0)
test_loader = torch.utils.data.DataLoader(test_dataset, batch_size=1000, shuffle=False, num_workers=0)
# 訓練模型(輸入是32x32x3=3072維)
model = MLP(input_size=3072, hidden_sizes=[1024, 512, 256], num_classes=10)
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
print("開始訓練...")
train_model(model, train_loader, criterion, optimizer, epochs=15)
# 評估
accuracy = evaluate_model(model, test_loader)
print(f"? CIFAR-10準確率: {accuracy:.2f}%\n")
return accuracy
# ==================== 運行所有實驗 ====================
print("?? 開始MLP泛化能力綜合測試\n")
# 實驗1: 基準
trained_model, baseline_acc = experiment_baseline()
# 實驗2-4: 擾動測試
shift_x, shift_acc = experiment_translation(trained_model)
rotation_x, rotation_acc = experiment_rotation(trained_model)
noise_x, noise_acc = experiment_noise(trained_model)
# 實驗5: 組合擾動
combined_names, combined_acc = experiment_combined(trained_model)
# 實驗6-7: 其他數據集
fashion_acc = experiment_fashion_mnist()
cifar_acc = experiment_cifar10()
# ==================== 可視化結果 ====================
fig = plt.figure(figsize=(18, 12))
# 子圖1: 平移擾動
ax1 = plt.subplot(3, 3, 1)
ax1.plot(shift_x, shift_acc, marker='o', linewidth=2, markersize=10, color='#FF6B6B')
ax1.axhline(y=baseline_acc, color='gray', linestyle='--', alpha=0.5, label='基準')
ax1.set_xlabel('平移范圍 (±pixels)', fontsize=11)
ax1.set_ylabel('準確率 (%)', fontsize=11)
ax1.set_title('平移擾動對準確率的影響', fontsize=12, fontweight='bold')
ax1.grid(True, alpha=0.3)
ax1.legend()
ax1.fill_between(shift_x, shift_acc, baseline_acc, alpha=0.2, color='#FF6B6B')
# 子圖2: 旋轉擾動
ax2 = plt.subplot(3, 3, 2)
ax2.plot(rotation_x, rotation_acc, marker='s', linewidth=2, markersize=10, color='#4ECDC4')
ax2.axhline(y=baseline_acc, color='gray', linestyle='--', alpha=0.5, label='基準')
ax2.set_xlabel('旋轉角度 (度)', fontsize=11)
ax2.set_ylabel('準確率 (%)', fontsize=11)
ax2.set_title('旋轉擾動對準確率的影響', fontsize=12, fontweight='bold')
ax2.grid(True, alpha=0.3)
ax2.legend()
ax2.fill_between(rotation_x, rotation_acc, baseline_acc, alpha=0.2, color='#4ECDC4')
# 子圖3: 噪聲擾動
ax3 = plt.subplot(3, 3, 3)
ax3.plot(noise_x, noise_acc, marker='D', linewidth=2, markersize=10, color='#95E1D3')
ax3.axhline(y=baseline_acc, color='gray', linestyle='--', alpha=0.5, label='基準')
ax3.set_xlabel('噪聲標準差', fontsize=11)
ax3.set_ylabel('準確率 (%)', fontsize=11)
ax3.set_title('噪聲擾動對準確率的影響', fontsize=12, fontweight='bold')
ax3.grid(True, alpha=0.3)
ax3.legend()
ax3.fill_between(noise_x, noise_acc, baseline_acc, alpha=0.2, color='#95E1D3')
# 子圖4: 組合擾動
ax4 = plt.subplot(3, 3, 4)
colors = ['#74B9FF', '#FDA7DF', '#F8C471', '#A29BFE', '#E17055']
bars = ax4.bar(range(len(combined_names)), combined_acc, color=colors, alpha=0.8)
ax4.axhline(y=baseline_acc, color='gray', linestyle='--', alpha=0.5)
ax4.set_xlabel('擾動類型', fontsize=11)
ax4.set_ylabel('準確率 (%)', fontsize=11)
ax4.set_title('組合擾動測試', fontsize=12, fontweight='bold')
ax4.set_xticks(range(len(combined_names)))
ax4.set_xticklabels(combined_names, rotation=45, ha='right', fontsize=9)
ax4.grid(True, alpha=0.3, axis='y')
# 在柱子上標注數值
for i, (bar, acc) in enumerate(zip(bars, combined_acc)):
height = bar.get_height()
ax4.text(bar.get_x() + bar.get_width()/2., height,
f'{acc:.1f}%', ha='center', va='bottom', fontsize=9)
# 子圖5: 準確率下降對比
ax5 = plt.subplot(3, 3, 5)
perturbations = ['平移±8px', '旋轉30°', '噪聲0.5', '全部組合']
acc_drops = [
baseline_acc - shift_acc[-1],
baseline_acc - rotation_acc[-1],
baseline_acc - noise_acc[-1],
baseline_acc - combined_acc[-1]
]
colors_drop = ['#FF6B6B', '#4ECDC4', '#95E1D3', '#E17055']
bars = ax5.barh(range(len(perturbations)), acc_drops, color=colors_drop, alpha=0.8)
ax5.set_yticks(range(len(perturbations)))
ax5.set_yticklabels(perturbations)
ax5.set_xlabel('準確率下降 (%)', fontsize=11)
ax5.set_title('不同擾動的影響程度', fontsize=12, fontweight='bold')
ax5.grid(True, alpha=0.3, axis='x')
# 標注數值
for i, (bar, drop) in enumerate(zip(bars, acc_drops)):
width = bar.get_width()
ax5.text(width, bar.get_y() + bar.get_height()/2.,
f' {drop:.1f}%', ha='left', va='center', fontsize=10, fontweight='bold')
# 子圖6: 跨數據集性能對比
ax6 = plt.subplot(3, 3, 6)
datasets = ['MNIST\n(手寫數字)', 'Fashion-MNIST\n(服裝)', 'CIFAR-10\n(彩色物體)']
dataset_accs = [baseline_acc, fashion_acc, cifar_acc]
colors_dataset = ['#6C5CE7', '#00B894', '#FD79A8']
bars = ax6.bar(range(len(datasets)), dataset_accs, color=colors_dataset, alpha=0.8, width=0.6)
ax6.set_ylabel('準確率 (%)', fontsize=11)
ax6.set_title('不同數據集性能對比', fontsize=12, fontweight='bold')
ax6.set_xticks(range(len(datasets)))
ax6.set_xticklabels(datasets, fontsize=10)
ax6.set_ylim([0, 100])
ax6.grid(True, alpha=0.3, axis='y')
# 標注數值
for bar, acc in zip(bars, dataset_accs):
height = bar.get_height()
ax6.text(bar.get_x() + bar.get_width()/2., height + 2,
f'{acc:.1f}%', ha='center', va='bottom', fontsize=11, fontweight='bold')
# 子圖7: 綜合魯棒性雷達圖
ax7 = plt.subplot(3, 3, 7, projection='polar')
categories = ['平移不變性', '旋轉不變性', '噪聲魯棒性', '跨域泛化']
# 計算各項得分(基于準確率保持率)
scores = [
(shift_acc[-1] / baseline_acc) * 100,
(rotation_acc[-1] / baseline_acc) * 100,
(noise_acc[-1] / baseline_acc) * 100,
(fashion_acc / baseline_acc) * 100
]
scores += scores[:1] # 閉合
angles = np.linspace(0, 2 * np.pi, len(categories), endpoint=False).tolist()
angles += angles[:1]
ax7.plot(angles, scores, 'o-', linewidth=2, color='#6C5CE7')
ax7.fill(angles, scores, alpha=0.25, color='#6C5CE7')
ax7.set_xticks(angles[:-1])
ax7.set_xticklabels(categories, fontsize=10)
ax7.set_ylim(0, 100)
ax7.set_title('MLP魯棒性評分', fontsize=12, fontweight='bold', pad=20)
ax7.grid(True)
# 子圖8: 數據復雜度分析
ax8 = plt.subplot(3, 3, 8)
complexity_data = {
'MNIST': {'分辨率': 28*28, '通道數': 1, '復雜度': 1, '準確率': baseline_acc, 'color': '#6C5CE7'},
'Fashion-MNIST': {'分辨率': 28*28, '通道數': 1, '復雜度': 2, '準確率': fashion_acc, 'color': '#00B894'},
'CIFAR-10': {'分辨率': 32*32*3, '通道數': 3, '復雜度': 3, '準確率': cifar_acc, 'color': '#FD79A8'}
}
for i, (name, data) in enumerate(complexity_data.items()):
ax8.scatter(data['復雜度'], data['準確率'], s=500, alpha=0.6,
color=data['color'], label=name, edgecolors='black', linewidth=2)
ax8.text(data['復雜度'], data['準確率'] - 3, name,
ha='center', fontsize=9, fontweight='bold')
ax8.set_xlabel('數據復雜度', fontsize=11)
ax8.set_ylabel('準確率 (%)', fontsize=11)
ax8.set_title('數據復雜度與性能關系', fontsize=12, fontweight='bold')
ax8.set_xticks([1, 2, 3])
ax8.set_xticklabels(['簡單', '中等', '復雜'])
ax8.grid(True, alpha=0.3)
ax8.set_ylim([30, 100])
# 子圖9: 總結表格
ax9 = plt.subplot(3, 3, 9)
ax9.axis('off')
summary_data = [
['測試項目', '準確率', '性能變化'],
['基準 (MNIST)', f'{baseline_acc:.1f}%', '-'],
['平移 ±8px', f'{shift_acc[-1]:.1f}%', f'↓{baseline_acc-shift_acc[-1]:.1f}%'],
['旋轉 30°', f'{rotation_acc[-1]:.1f}%', f'↓{baseline_acc-rotation_acc[-1]:.1f}%'],
['噪聲 0.5', f'{noise_acc[-1]:.1f}%', f'↓{baseline_acc-noise_acc[-1]:.1f}%'],
['組合擾動', f'{combined_acc[-1]:.1f}%', f'↓{baseline_acc-combined_acc[-1]:.1f}%'],
['Fashion-MNIST', f'{fashion_acc:.1f}%', f'↓{baseline_acc-fashion_acc:.1f}%'],
['CIFAR-10', f'{cifar_acc:.1f}%', f'↓{baseline_acc-cifar_acc:.1f}%']
]
table = ax9.table(cellText=summary_data, cellLoc='center', loc='center',
colWidths=[0.4, 0.3, 0.3])
table.auto_set_font_size(False)
table.set_fontsize(10)
table.scale(1, 2)
# 設置表頭樣式
for i in range(3):
table[(0, i)].set_facecolor('#6C5CE7')
table[(0, i)].set_text_props(weight='bold', color='white')
# 設置行顏色
for i in range(1, len(summary_data)):
for j in range(3):
if i % 2 == 0:
table[(i, j)].set_facecolor('#F0F0F0')
else:
table[(i, j)].set_facecolor('white')
ax9.set_title('實驗結果匯總', fontsize=12, fontweight='bold', pad=20)
# 總標題
fig.suptitle('MLP泛化能力綜合評估報告', fontsize=16, fontweight='bold', y=0.995)
plt.tight_layout()
plt.savefig('mlp_generalization_analysis.png', dpi=300, bbox_inches='tight')
print("="*60)
print("?? 可視化報告已保存: mlp_generalization_analysis.png")
print("="*60)
plt.show()
# ==================== 文字總結 ====================
print("\n" + "="*80)
print("?? MLP泛化能力分析總結")
print("="*80)
print("\n【核心發現】")
print("1. ?? 空間不變性弱")
print(f" - 平移±8px導致準確率下降: {baseline_acc - shift_acc[-1]:.1f}%")
print(f" - 旋轉30°導致準確率下降: {baseline_acc - rotation_acc[-1]:.1f}%")
print(" ?? 原因: MLP缺乏局部特征提取能力,對空間位置敏感")
print("\n2. ?? 噪聲魯棒性差")
print(f" - 噪聲標準差0.5導致準確率下降: {baseline_acc - noise_acc[-1]:.1f}%")
print(" ?? 原因: 全連接層將所有像素等權對待,無法區分信號和噪聲")
print("\n3. ?? 組合擾動影響嚴重")
print(f" - 多重擾動準確率: {combined_acc[-1]:.1f}%")
print(f" - 相比基準下降: {baseline_acc - combined_acc[-1]:.1f}%")
print(" ?? 原因: 缺乏抗干擾機制,擾動效應疊加")
print("\n4. ?? 跨域泛化能力有限")
print(f" - MNIST (簡單): {baseline_acc:.1f}%")
print(f" - Fashion-MNIST (中等): {fashion_acc:.1f}%")
print(f" - CIFAR-10 (復雜): {cifar_acc:.1f}%")
print(" ?? 結論: 數據復雜度越高,MLP性能下降越明顯")
print("\n【MLP的局限性】")
print("? 無法學習平移不變性 - 同一物體在不同位置被視為不同模式")
print("? 無法學習旋轉不變性 - 對角度變化極度敏感")
print("? 無法提取局部特征 - 丟失了圖像的空間結構信息")
print("? 參數量大易過擬合 - 全連接導致參數爆炸")
print("? 泛化能力弱 - 在復雜、真實場景下性能大幅下降")
print("\n【為什么需要CNN?】")
print("? 卷積操作天然具有平移不變性")
print("? 局部感受野保留空間結構")
print("? 權重共享大幅減少參數")
print("? 層級特征提取(邊緣→紋理→物體)")
print("? 更強的泛化能力和魯棒性")
print("\n" + "="*80)
print("?? 結論: MLP在MNIST上表現優秀,但面對真實世界的復雜場景")
print(" (位置變化、視角變化、光照變化、遮擋等)時能力有限。")
print(" 這就是為什么計算機視覺領域需要卷積神經網絡(CNN)!")
print("="*80)


通過分析的表格可以明顯看出MLP的局限性,下一章我們將介紹MLP-CNN

浙公網安備 33010602011771號