C# TorchSharp 圖像分類實戰:VGG大規模圖像識別的超深度卷積網絡
教程名稱:使用 C# 入門深度學習
作者:癡者工良
教程地址:
電子書倉庫:https://github.com/whuanle/cs_pytorch
Maomi.Torch 項目倉庫:https://github.com/whuanle/Maomi.Torch
VGG大規模圖像識別的超深度卷積網絡
本文主要講解用于大規模圖像識別的超深度卷積網絡 VGG,通過 VGG 實現自有數據集進行圖像分類訓練模型和識別,VGG 有 vgg11、vgg11_bn、vgg13、vgg13_bn、vgg16、vgg16_bn、vgg19、vgg19_bn 等變種,VGG 架構的實現可參考論文:https://arxiv.org/abs/1409.1556
論文中文版地址:
數據集
本文主要使用經典圖像分類數據集 CIFAR-10 進行訓練,CIFAR-10 數據集中有 10 個分類,每個類別均有 60000 張圖像,50000 張訓練圖像和 10000 張測試圖像,每個圖像都經過了預處理,生成 32x32 彩色圖像。
CIFAR-10 的 10 個分類分別是:
airplane
automobile
bird
cat
deer
dog
frog
horse
ship
truck
下面給出幾種數據集的本地化導入方式。
直接下載
由于 CIFAR-10 是經典數據集,因此 TorchSharp 默認支持下載該數據集,但是由于網絡問題,國內下載數據庫需要開飛機,數據集自動下載和導入:
// 加載訓練和驗證數據
var train_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: true, download: true, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: false, download: true, target_transform: transform);
opendatalab 數據集社區
opendatalab 是一個開源數據集社區倉庫,里面有大量免費下載的數據集,借此機會給讀者講解一下如何從 opendatalab 下載數據集,這對讀者學習非常有幫助。
CIFAR-10 數據集倉庫地址:
https://opendatalab.com/OpenDataLab/CIFAR-10/cli/main
打開 https://opendatalab.com 注冊賬號,然后在個人信息中心添加密鑰。

然后下載 openxlab 提供的 cli 工具:
pip install openxlab #安裝
安裝 openxlab 后,會要求添加路徑到環境變量,環境變量地址是 Scripts 地址,示例:
C:\Users\%USER%\AppData\Roaming\Python\Python312\Scripts
接著進行登錄,輸入命令后按照提示輸入 key 和 secret:
openxlab login # 進行登錄,輸入對應的AK/SK,可在個人中心查看AK/SK
然后打開空目錄下載數據集,數據集倉庫會被下載到 OpenDataLab___CIFAR-10 目錄中:
openxlab dataset info --dataset-repo OpenDataLab/CIFAR-10 # 數據集信息及文件列表查看
openxlab dataset get --dataset-repo OpenDataLab/CIFAR-10 #數據集下載

數據集信息及文件列表查看
openxlab dataset info --dataset-repo OpenDataLab/CIFAR-10

下載的文件比較多,但是我們只需要用到 cifar-10-binary.tar.gz,直接解壓 cifar-10-binary.tar.gz 到目錄中(也可以不解壓)。

然后導入數據:
// 加載訓練和驗證數據
var train_dataset = datasets.CIFAR10(root: "E:/datasets/OpenDataLab___CIFAR-10", train: true, download: false, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/OpenDataLab___CIFAR-10", train: false, download: false, target_transform: transform);
自定義數據集
Maomi.Torch 提供了自定義數據集導入方式,降低了開發者制作數據集的難度。自定義數據集也要區分訓練數據集和測試數據集,訓練數據集用于特征識別和訓練,而測試數據集用于驗證模型訓練的準確率和損失值。
測試數據集和訓練數據集可以放到不同的目錄中,具體名稱沒有要求,然后每個分類單獨一個目錄,目錄名稱就是分類名稱,按照目錄名稱的排序從 0 生成標簽值。
├─test
│ ├─airplane
│ ├─automobile
│ ├─bird
│ ├─cat
│ ├─deer
│ ├─dog
│ ├─frog
│ ├─horse
│ ├─ship
│ └─truck
└─train
│ ├─airplane
│ ├─automobile
│ ├─bird
│ ├─cat
│ ├─deer
│ ├─dog
│ ├─frog
│ ├─horse
│ ├─ship
│ └─truck

讀者可以參考 exportdataset項目,將 CIFAR-10 數據集生成導出到目錄中。
通過自定義目錄導入數據集的代碼為:
var train_dataset = MM.Datasets.ImageFolder(root: "E:/datasets/t1/train", target_transform: transform);
var val_dataset = MM.Datasets.ImageFolder(root: "E:/datasets/t1/test", target_transform: transform);
模型訓練
定義圖像預處理轉換代碼,代碼如下所示:
Device defaultDevice = MM.GetOpTimalDevice();
torch.set_default_device(defaultDevice);
Console.WriteLine("當前正在使用 {defaultDevice}");
// 數據預處理
var transform = transforms.Compose([
transforms.Resize(32, 32),
transforms.ConvertImageDtype( ScalarType.Float32),
MM.transforms.ReshapeTransform(new long[]{ 1,3,32,32}),
transforms.Normalize(means: new double[] { 0.485, 0.456, 0.406 }, stdevs: new double[] { 0.229, 0.224, 0.225 }),
MM.transforms.ReshapeTransform(new long[]{ 3,32,32})
]);
因為 TorchSharp 對圖像維度處理的兼容性不好,沒有 Pytorch 的自動處理,因此導入的圖片維度和批處理維度、transforms 處理的維度兼容性不好,容易報錯,因此這里需要使用 Maomi.Torch 的轉換函數,以便在導入圖片和進行圖像批處理的時候,保障 shape 符合要求。
分批加載數據集:
// 加載訓練和驗證數據
var train_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: true, download: true, target_transform: transform);
var val_dataset = datasets.CIFAR10(root: "E:/datasets/CIFAR-10", train: false, download: true, target_transform: transform);
var train_loader = new DataLoader(train_dataset, batchSize: 1024, shuffle: true, device: defaultDevice, num_worker: 10);
var val_loader = new DataLoader(val_dataset, batchSize: 1024, shuffle: false, device: defaultDevice, num_worker: 10);
初始化 vgg16 網絡:
var model = torchvision.models.vgg16(num_classes: 10);
model.to(device: defaultDevice);
設置損失函數和優化器:
var criterion = nn.CrossEntropyLoss();
var optimizer = optim.SGD(model.parameters(), learningRate: 0.001, momentum: 0.9);
訓練模型并保存:
int num_epochs = 150;
for (int epoch = 0; epoch < num_epochs; epoch++)
{
model.train();
double running_loss = 0.0;
int i = 0;
foreach (var item in train_loader)
{
var (inputs, labels) = (item["data"], item["label"]);
var inputs_device = inputs.to(defaultDevice);
var labels_device = labels.to(defaultDevice);
optimizer.zero_grad();
var outputs = model.call(inputs_device);
var loss = criterion.call(outputs, labels_device);
loss.backward();
optimizer.step();
running_loss += loss.item<float>() * inputs.size(0);
Console.WriteLine($"[{epoch}/{num_epochs}][{i % train_loader.Count}/{train_loader.Count}]");
i++;
}
double epoch_loss = running_loss / train_dataset.Count;
Console.WriteLine($"Train Loss: {epoch_loss:F4}");
model.eval();
long correct = 0;
int total = 0;
using (torch.no_grad())
{
foreach (var item in val_loader)
{
var (inputs, labels) = (item["data"], item["label"]);
var inputs_device = inputs.to(defaultDevice);
var labels_device = labels.to(defaultDevice);
var outputs = model.call(inputs_device);
var predicted = outputs.argmax(1);
total += (int)labels.size(0);
correct += (predicted == labels_device).sum().item<long>();
}
}
double val_accuracy = 100.0 * correct / total;
Console.WriteLine($"Validation Accuracy: {val_accuracy:F2}%");
}
model.save("model.dat");
啟動項目后可以直接執行訓練,訓練一百多輪后,準確率在 70% 左右,損失值在 0.0010 左右,繼續訓練已經提高不了準確率了。
導出的模型還是比較大的:
513M model.dat
下面來編寫圖像識別測試,在示例項目 vggdemo 中自帶了三張圖片,讀者可以直接導入使用。
model.load("model.dat");
model.to(device: defaultDevice);
model.eval();
var classes = new string[] {
"airplane",
"automobile",
"bird",
"cat",
"deer",
"dog",
"frog",
"horse",
"ship",
"truck"
};
List<Tensor> imgs = new();
imgs.Add(transform.call(MM.LoadImage("airplane.jpg").to(defaultDevice)).view(1, 3, 32, 32));
imgs.Add(transform.call(MM.LoadImage("cat.jpg").to(defaultDevice)).view(1, 3, 32, 32));
imgs.Add(transform.call(MM.LoadImage("dog.jpg").to(defaultDevice)).view(1, 3, 32, 32));
using (torch.no_grad())
{
foreach (var data in imgs)
{
var outputs = model.call(data);
var index = outputs[0].argmax(0).ToInt32();
// 轉換為歸一化的概率
// outputs.shape = [1,10],所以取 [dim:1]
var array = torch.nn.functional.softmax(outputs, dim: 1);
var max = array[0].ToFloat32Array();
var predicted1 = classes[index];
Console.WriteLine($"識別結果 {predicted1},準確率:{max[index] * 100}%");
}
}
識別結果:
當前正在使用 cuda:0
識別結果 airplane,準確率:99.99983%
識別結果 cat,準確率:99.83113%
識別結果 dog,準確率:100%
用到的三張圖片均從網絡上搜索而來:




浙公網安備 33010602011771號