用張量廣播機(jī)制實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)反向傳播
正向傳播
要想了解反向傳播,先要了解正向傳播:正向傳播的每一步是,用一個(gè)或很多輸入生成一個(gè)輸出。
反向傳播
反向傳播的作用是計(jì)算模型參數(shù)的偏導(dǎo)數(shù)。再具體一點(diǎn),反向傳播的每一個(gè)step就是:已知正向傳播的輸入本身,和輸出的偏導(dǎo)數(shù),求出每個(gè)輸入的偏導(dǎo)數(shù)的過程。
反向傳播既簡單,又復(fù)雜:
- 它的原理很簡單:鏈?zhǔn)椒▌t求偏導(dǎo)。
- 它的公式又很復(fù)雜:因?yàn)樗墓娇雌饋碚娴暮軓?fù)雜。
模型的參數(shù)
反向傳播就是計(jì)算模型的參數(shù)的偏導(dǎo)數(shù),所以介紹一下模型的參數(shù):
- 模型里有很多參數(shù),參數(shù)的本質(zhì)是張量,可以把張量看成多維數(shù)組,也可以把張量看成一顆樹。
- 張量有形狀,張量的偏導(dǎo)數(shù)是一個(gè)
同樣形狀的張量。
線性函數(shù)的反向傳播
線性函數(shù)就是 y = wx + b,我們輸入x,w,和 b 就能得到y(tǒng)。y是我們算出來的,這個(gè)算y的過程就是正向傳播。
我們規(guī)定字母后面加 .g 表示偏導(dǎo)數(shù),如 y.g 就是y的偏導(dǎo)數(shù),w.g 就是w的偏導(dǎo)數(shù)。
那么我們的目的,就是根據(jù) x, w, b 和 y.g 的值,分別算出 w,x,和b的偏導(dǎo)數(shù),而這個(gè)過程,就是反向傳播。
為了便于說明,我們假設(shè)了每個(gè)變量的形狀: x(1000, 784), w(784, 50), b(50), y(1000, 50)。
計(jì)算 x.g
y = wx + b 對 x 求偏導(dǎo) 得 w,即我們要用 w 和 y.g 計(jì)算出 x.g。
w 的形狀是 (784, 50),y.g的形狀跟y相同,是(1000, 50),如何用這兩個(gè)形狀湊出 x.g 的(1000, 784)?
emmm,很簡單,就是這樣,然后那樣,就行了。看玩笑的。。其實(shí)就是 y.g 中間加一維,變成 (1000, 1, 50) ,然后再跟 w 搞一下,得到一個(gè) (1000, 784, 50) 的形狀,再把最后一維消去,就得到 (1000, 784) 的形狀了。
即:
x.g = (y.g.unsqueeze(1) * w).sum(dim=-1)
計(jì)算 w.g
同理咯,y = wx + b 對 w 求偏導(dǎo) 得 x,即我們要用 x 和 y.g 計(jì)算出 w.g。
x 的形狀是 (1000, 784),y.g的形狀跟y相同,是(1000, 50),如何用這兩個(gè)形狀湊出 w.g` 的(784, 50)?
先將 x 最后加一維,變成 (1000, 784, 1),再將 y.g 中間加一維,變成 (1000, 1, 50),這倆搞一下,變成 (1000, 784, 50),再把開頭的那一維消去,就變成 (784, 50)了。
即:
w.g = (x.unsqueeze(-1) * y.g.unsqueeze(1)).sum(dim=0)
計(jì)算 b.g
y = wx + b 對 b 求偏導(dǎo) 得常數(shù) 1,所以直接用形狀為(1000, 50)的y.g來湊出形狀為(50)的b.g就可以了。
那么就非常簡單了,直接把(1000, 50)消去最開始的那一維就能得到(50),即:
b.g = y.g.sum(0)
線性函數(shù)的反向傳播代碼
已知線性函數(shù)的輸入是 inp,輸出是 out,計(jì)算過程用到的兩個(gè)參數(shù)是 w和b,則反向傳播的代碼如下:
def back_lin(inp, w, b, out):
inp.g = (out.g.unsqueeze(1) * w).sum(dim=-1)
w.g = (inp.unsqueeze(-1) * out.g.unsqueeze(1)).sum(dim=0)
b.g = out.g.sum(0)
relu函數(shù)的反向傳播
relu函數(shù)表示起來很簡單,就是 max(x, 0),但是在 pytorch 中這樣寫是行不通的,所以用這面這個(gè)函數(shù)表示:
def relu(x):
return x.clamp_min(0)
其反向傳播表示為:
def back_relu(inp, out):
return (inp > 0).float() * out.g
mse函數(shù)的反向傳播
mse函數(shù)用代碼表示為:
def mse(pred, target):
return (pred.squeeze(dim=-1)-target).pow(2).mean()
其反向傳播則是:
def back_mse(pred, target):
return 2. * (pred.squeeze(dim=-1) - target).unsqueeze(dim=-1) / pred.shape[0]
測試
假設(shè)我們的模型結(jié)果為:輸入一個(gè)x,進(jìn)行一次線性變換,再經(jīng)過一次relu,然后再經(jīng)過一次線性變換得到結(jié)果。
先隨機(jī)生成 輸入、輸出和各個(gè)參數(shù):
# 偽造輸入和答案
import torch
torch.manual_seed(0)
input_ = torch.randn(1000, 784).requires_grad_(True) # 輸入
target = torch.randn(1000) # 答案
# 創(chuàng)建其它參數(shù)
w1 = torch.randn(784, 50).requires_grad_(True)
b1 = torch.randn(50).requires_grad_(True)
w2 = torch.randn(50, 1).requires_grad_(True)
b2 = torch.randn(1).requires_grad_(True)
正向傳播得到模型的輸出:
l1 = input_ @ w1 + b1
l2 = relu(l1)
output = l2 @ w2 + b2
loss = mse(output, target)
反向傳播:
back_mse(output, target)
back_lin(l2, w2, b2, output)
back_relu(l1, l2)
back_lin(input_, w1, b1, l1)
此時(shí) w1.g,b1.g和 w2.g,b2.g均已求出。
然后用pytorch自帶的反向傳播求一下梯度:
# 先保存一下手動(dòng)求的梯度
w1g = w1.g.clone()
b1g = b1.g.clone()
w2g = w2.g.clone()
b2g = b2.g.clone()
input_ = input_.clone().requires_grad_(True)
w1 = w1.clone().requires_grad_(True)
b1 = b1.clone().requires_grad_(True)
w2 = w2.clone().requires_grad_(True)
b2 = b2.clone().requires_grad_(True)
l1 = input_ @ w1 + b1
l2 = relu(l1)
output = l2 @ w2 + b2
loss = mse(output, target)
loss.backward()
此時(shí)對比一下我們手動(dòng)求得的梯度和調(diào)用系統(tǒng)函數(shù)求得的梯度,發(fā)現(xiàn)二者是相等的:
def is_same(a, b):
return (a - b).max() < 1e-4
is_same(w1g, w1.grad), is_same(b2g, b2.grad), is_same(w2g, w2.grad), is_same(b2g, b2.grad)
"""輸出
(tensor(True), tensor(True), tensor(True), tensor(True))
"""
總結(jié)
借助簡單的求導(dǎo)和張量的廣播機(jī)制,就可以推導(dǎo)實(shí)現(xiàn)神經(jīng)網(wǎng)絡(luò)的反向傳播。

浙公網(wǎng)安備 33010602011771號