RNN的PyTorch實現
官方實現
PyTorch已經實現了一個RNN類,就在torch.nn工具包中,通過torch.nn.RNN調用。
使用步驟:
- 實例化類;
- 將輸入層向量和隱藏層向量初始狀態值傳給實例化后的對象,獲得RNN的輸出。
在實例化該類時,需要傳入如下屬性:
- input_size:輸入層神經元個數;
- hidden_size:每層隱藏層的神經元個數;
- num_layers:隱藏層層數,默認設置為1層;
- nonlinearity:激活函數的選擇,可選是'tanh'或者'relu',默認設置為'tanh';
- bias:偏置系數,可選是'True'或者'False',默認設置為'True';
- batch_first:可選是'True'或者'False',默認設置為'False';
- dropout:默認設置為0。若為非0,將在除最后一層的每層RNN輸出上引入Dropout層,dropout概率就是該非零值;
- bidirectional:默認設置為False。若為True,即為雙向RNN。
RNN的輸入有兩個,一個是input,一個是h0。input就是輸入層向量,h0就是隱藏層初始狀態值。
若沒有采用批量輸入,則輸入層向量的形狀為(L, Hin);
若采用批量輸入,且batch_first為False,則輸入層向量的形狀為(L, N, Hin);
若采用批量輸入,且batch_first為True,則輸入層向量的形狀為(N, L, Hin);
對于(N, L, Hin),在文本輸入時,可以按順序理解為(每次輸入幾句話,每句話有幾個字,每個字由多少維的向量表示)。
若沒有采用批量輸入,則隱藏層向量的形狀為(D * num_layers, Hout);
若采用批量輸入,則隱藏層向量的形狀為(D * num_layers, N, Hout);
注意,batch_first的設置對隱藏層向量的形狀不起作用。
RNN的輸出有兩個,一個是output,一個是hn。output包含了每個時間步最后一層的隱藏層狀態,hn包含了最后一個時間步每層的隱藏層狀態。
若沒有采用批量輸入,則輸出層向量的形狀為(L, D * Hout);
若采用批量輸入,且batch_first為False,則輸出層向量的形狀為(L, N, D * Hout);
若采用批量輸入,且batch_first為True,則輸出層向量的形狀為(N, L, D * Hout)。
參數解釋:
- N代表的是批量大小;
- L代表的是輸入的序列長度;
- 若是雙向RNN,則D的值為2;若是單向RNN,則D的值為1;
- Hin在數值上是輸入層神經元個數;
- Hout在數值上是隱藏層神經元個數。
復現代碼
import torch
import torch.nn as nn
rnn = nn.RNN(10, 20, 1, batch_first=True) # 實例化一個單向單層RNN
input = torch.randn(5, 3, 10)
h0 = torch.randn(1, 5, 20)
output, hn = rnn(input, h0)
手寫復現
復現代碼
import torch
import torch.nn as nn
class MyRNN(nn.Module):
def __init__(self, input_size, hidden_size):
super().__init__()
self.input_size = input_size
self.hidden_size = hidden_size
self.weight_ih = torch.randn(self.hidden_size, self.input_size) * 0.01
self.weight_hh = torch.randn(self.hidden_size, self.hidden_size) * 0.01
self.bias_ih = torch.randn(self.hidden_size)
self.bias_hh = torch.randn(self.hidden_size)
def forward(self, input, h0):
N, L, input_size = input.shape
output = torch.zeros(N, L, self.hidden_size)
for t in range(L):
x = input[:, t, :].unsqueeze(2) # 獲得當前時刻的輸入特征,[N, input_size, 1]。unsqueeze(n),在第n維上增加一維
w_ih_batch = self.weight_ih.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, input_size]
w_hh_batch = self.weight_hh.unsqueeze(0).tile(N, 1, 1) # [N, hidden_size, hidden_size]
w_times_x = torch.bmm(w_ih_batch, x).squeeze(-1) # [N, hidden_size]。squeeze(n),在第n維上減小一維
w_times_h = torch.bmm(w_hh_batch, h0.unsqueeze(2)).squeeze(-1) # [N, hidden_size]
h0 = torch.tanh(w_times_x + self.bias_ih + w_times_h + self.bias_hh)
output[:, t, :] = h0
return output, h0.unsqueeze(0)
驗證正確性
my_rnn = MyRNN(10, 20)
input = torch.randn(5, 3, 10)
h0 = torch.randn(5, 20)
my_output, my_hn = my_rnn(input, h0)
print(output.shape == my_output.shape, hn.shape == my_hn.shape)
True True

浙公網安備 33010602011771號