一種偏主觀的矩陣乘法梯度推導(dǎo)方法
很早在紙上推導(dǎo)過梯度的計算方法,但每次都忘記推導(dǎo)過程反復(fù)推導(dǎo)。于此想總結(jié)新的記憶方法。
梯度下降推導(dǎo)過程難以記憶來自于矩陣微積分,矩陣微積分中涉及標量、向量、矩陣之間兩兩求導(dǎo)操作,其定義如下圖,√ 表示存在定義,x 表示不存在定義[1]:
| 函數(shù) \ 自變量 | scalar | vector | matrix |
|---|---|---|---|
| scalar | √ | √(Nabla) | √ |
| vector | √ | √(Jacobi) | × |
| matrix | √ | × | × |
- 矩陣和矩陣之間、向量和矩陣之間不存在導(dǎo)數(shù)
- 導(dǎo)數(shù)的維度是函數(shù)、自變量維度之和,反應(yīng)定義導(dǎo)數(shù)時函數(shù)和自變量之間的維度是正交關(guān)系,導(dǎo)數(shù)中每個元素都是標量求導(dǎo)
矩陣之間、矩陣向量之間不存在導(dǎo)數(shù)的原因是我們將定義范圍限制在最高二維矩陣中,所以維度不能超過2,此類導(dǎo)數(shù)定義不存在。對于矩陣乘法 \(\mathbf{Y}_{B , O} = \mathbf{X}_{B , I} \mathbf{W}_{O , I}^T\) ,我們當然希望能夠以矩陣形式優(yōu)雅地推導(dǎo)整體公式,但矩陣-矩陣之間并不存在導(dǎo)數(shù)的定義,需要頻繁在標量-向量-矩陣之間切換,或者引入許多額外 Nabla 在矩陣作用的關(guān)系,導(dǎo)致推理過程別扭不美觀。
因此引入張量,允許變量維度超過 2 維,便可定義任意兩個張量之間的導(dǎo)數(shù)。注意知道 \(\mathbf{Y}\) 的維度 B 來自 \(\mathbf{X}\),維度 O 來自 \(\mathbf{W}\),但從函數(shù)角度來看,只是定義了一個 \(R^{B\times I} , R^{O \times I} \rightarrow R^{B \times O}\) 的函數(shù),應(yīng)將這些維度用不同的符號區(qū)分。
定義矩陣乘法和維度:
已知損失函數(shù) F 相對輸出張量梯度:
則有
這個過程中假設(shè)張量也滿足鏈式法則,且鏈式法則傳遞關(guān)系通過張量乘法也就是 Einsum 規(guī)約相同符號維度,是否嚴格成立需要補充證明。
接下來需要求解具體 \(\frac{\partial \mathbf{Y}}{\partial \mathbf{W}}\) ,不幸的是仍然需要拆分到標量求導(dǎo)。這種定義方法只包含兩個層次,多維度的張量表示,以及具體計算的標量表示。
可得:
換句話說,\(\frac{\partial \mathbf{Y}}{\partial \mathbf{W}}\) 選取任意坐標 \((B=b', I=i)\) 做切片,切出來的 \(R^{O'\times O}\) 的矩陣是單位陣的倍數(shù)。
附帶驗證程序:
import torch
import einops
B = 8
O = 32
I = 256
B_ = B
O_ = O
I_ = I
x = torch.randn(B, I_).to("cuda")
w = torch.randn(O, I).to("cuda").requires_grad_(True)
y = torch.einsum('bi,oi->bo', x, w)
df_dy = torch.randn(B_, O_).to("cuda")
with torch.no_grad():
dy_dw = torch.zeros(B_, O_, O, I).to("cuda")
for b_ in range(B_):
print(f'batch {b_}')
for o_ in range(O_):
for o in range(O):
for i in range(I):
if o_ == o:
dy_dw[b_, o_, o, i] = x[b_, i]
df_dw = torch.einsum('BO,BOoi->oi', df_dy, dy_dw)
if w.grad is not None:
w.grad.zero_()
y.backward(df_dy)
auto_dy_dw = w.grad
if auto_dy_dw is None:
print("Error: w.grad is None. Gradient was not properly calculated.")
else:
diff = torch.abs(auto_dy_dw - df_dw)
print(f'Max diff: {diff.max()}')
這邊建模好處是用統(tǒng)一張量求導(dǎo)運算替代混亂的各種 矩陣-向量-標量 求導(dǎo)公式,但該求導(dǎo)方法普適和矩陣乘法無關(guān),需要額外理解和建模維度之間的約束關(guān)系。到底哪種方便見仁見智了。
或者可以利用這種推導(dǎo)理解為什么直接用維度湊 einsum 表達式就可以直接表示梯度。
import torch
import einops
B = 8
O = 32
I = 256
x = torch.randn(B, I_).to("cuda")
w = torch.randn(O, I).to("cuda")
y = torch.einsum('bi,oi->bo', x, w)
df_dy = torch.randn(B, O).to("cuda")
df_dw = torch.einsum('bo,bi->oi', df_dy, x)

浙公網(wǎng)安備 33010602011771號