五、神經網絡的基本框架-nn.module的使用、卷積
前置知識
super().init() 是用于調用父類的構造函數(初始化方法)
搭建簡單的神經網絡框架
import torch
from torch import nn
class MM(nn.module):
def __init__(self):
super(MM,self).__init__() # 調用父類的初始化函數
def forward(self,input):
output=input+1
return output
mm=MM()
x=torch.tensor(1.0)
output=mm(x) # 這里可以直接調用類,是因為所有nn.module類及繼承自它的子類都含有特殊的call函數,而call函數中又會自動調用它內部的一些函數(例如這里的forward函數)
print(output)
卷積操作
以Conv2d為例,2D 卷積操作(Convolutional Operation),主要目的是 通過卷積核提取輸入圖像的局部特征。這種操作廣泛應用于 圖像處理、特征提取和深度學習中的 CNN(卷積神經網絡)。
以下是Conv2d所需要的參數,這里是torch.nn.functional中的conv2d(區別一下torch.nn和torch.nn.functional:前者是對后者功能的一個封裝)這里先介紹torch.nn.functional中的conv2d

事實上,卷積操作 是在輸入圖像上滑動卷積核,并計算 加權和 以生成特征圖。

以下將做一個簡單的示范:
import torch
import torch.nn.functional as F
# 輸入的圖像為以下數據
input=torch.tensor([[1,2,0,3,1],
[0,1,2,3,1],
[1,2,1,0,0],
[5,2,3,1,1],
[2,1,0,1,1]])
# 以下為卷積核
kernel=torch.tensor([[1,2,1],
[0,1,0],
[2,1,0]])
input=torch.reshape(input,(1,1,5,5))
kernel=torch.reshape(kernel,(1,1,3,3))
print(input.shape)
print(kernel.shape)
output=F.conv2d(input,kernel,stride=1)
print(output)
output2=F.conv2d(input,kernel,stride=2)
print(output2)
output3=F.conv2d(input,kernel,stride=1,padding=1)
print(output3)
得到的輸出結果如下:

卷積層
這里便是直接對torch.nn中的conv2d進行使用。
以下展示了官方文檔中的卷積操作動圖:

特殊:空洞卷積

簡單示例:
dataset=torchvision.datasets.CIFAR10(root='./dataset', train=False, transform=torchvision.transforms.ToTensor(), download=True)
dataloader=DataLoader(dataset,batch_size=64,shuffle=True)
class MM(nn.Module):
def __init__(self):
super(MM,self).__init__()
self.conv1=Conv2d(3,6,3,1,0)
def forward(self,x):
x=self.conv1(x)
return x
mm=MM()
writer=SummaryWriter("logs")
step=0
for data in dataloader:
imgs,targets=data
output=mm(imgs)
print(imgs.shape)
print(output.shape)
# torch.Size([64, 3, 32, 32])
writer.add_images("input",imgs,step)
# torch.Size([64, 6, 30, 30])
output=torch.reshape(output,(-1,3,30,30)) # 由于tensorboard無法展示6 channels的圖像,
# 所以在不太嚴謹的情況下需要將6切成3,
# 但因此會導致batch_size增多,在不知道它的大小之前,可以設置為-1
writer.add_images("output",output,step)
step+=1
在tensorboard中的輸出結果如下圖所示:

補充:3維卷積示意圖


浙公網安備 33010602011771號