張量鏈式法則(下篇):揭秘Transpose、Summation等復雜算子反向傳播,徹底掌握深度學習求導精髓!
本文首發于本人的微信公眾號,鏈接:https://mp.weixin.qq.com/s/eEDo6WF0oJtRvLYTeUYxYg
摘要
本文緊接系列的上篇,介紹了 transpose,summation,broadcast_to 等更為復雜的深度學習算子的反向傳播公式推導。
寫在前面
本系列文章的上篇介紹了張量函數鏈式法則公式,并以幾個簡單的算子為例子展示了公式的使用方法。本文將繼續以更復雜的算子為例子演示公式的使用方法,求解這些算子的反向傳播公式也是我研究張量函數鏈式法則的目的:因為對于 transpose,broadcast_to 這類會根據傳入的參數改變輸出張量維度數量的算子,常規的矩陣鏈式法則公式已無法解決。
常見算子的反向傳播推導(下半部分)
復習一下
張量函數鏈式法則的公式為:
求解步驟為:我們首先需要確定各個張量的形狀,然后把注意力集中到自變量里的某個元素,寫出這個元素的導數表達式,然后再推廣到整個導數張量。
接下來我們繼續常見算子的反向傳播公式推導。
Summation
這個算子是對輸入張量沿著某些軸求和,這個算子有一個參數 axes,表示求和的規約軸,例如,對于一個四維張量 \(X \in \mathbb{R}^{d_1 \times d_2 \times d_3 \times d_4}\),如果 axes=(2, 3),\(F(X) \in \mathbb{R}^{d_1 \times d_4}\),是一個二維張量,且 \(f_{ij} = \sum_{k=1}^{d_2} \sum_{l=1}^{d_3} X_{iklj}\)。
由此可見,對于多個軸的 summation 操作其實可以拆分為多次的對于一個軸的 summation,所以我們僅討論 axes 只有一個軸的公式,對于有多個軸的場景可以將其視為復合函數,通過反復使用該公式來進行擴展。
單軸 Summation 問題描述
所以我們要解決的問題就變成了:函數 \(F\) 會對張量 \(X\) 的第 \(a\) 個維度進行求和,求該函數的反向傳播公式。
(注:本文統一以 1 為起始下標,實際編程時 axes 是以 0 為起始下標,這個差異需要注意)
首先確定各個張量的形狀,如果自變量 \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\) 是一個 \(n\) 維張量,那么 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times d_{a+1} \times \cdots \times d_n}\) 為 \(n-1\) 維張量。
單軸 Summation 問題求解
接下來可以寫出每個自變量的導數的表達式:
注意到,當且僅當 \(\mu_1 = \lambda_1, \mu_2 = \lambda_2, \ldots, \mu_n = \lambda_n\) 時,這個表達式值不為 0,且滿足上述條件時,只有當 \(i = \lambda_a\) 時,求和表達式值為 1,\(i\) 為其他值時都為 0,所以這一項的最終結果是 \(g_{\lambda_1 \lambda_2 \cdots \lambda_n}\)。
所以最終的 \(\nabla = \text{broadcast}(G, a)\),即把張量 \(G\) 在第 \(a\) 個軸做 broadcast_to(broadcast_to 操作的定義見下文)。
當然,這里實際操作時首先要對 \(G\) 做 reshape,把因為求和丟掉的軸 unsqueeze 回來,然后再通過 broadcast_to 操作廣播到 \(X\) 的形狀,具體可以參考下面的具體代碼實現:
a = node.inputs[0]
target_dim_num = len(a.shape)
grad_new_shape = []
for i in range(target_dim_num):
if i in self.axes:
grad_new_shape.append(1)
else:
grad_new_shape.append(a.shape[i])
return broadcast_to(reshape(out_grad, grad_new_shape), a.shape)
多軸 Summation 問題求解
接下來討論 axes 有多個的情形,通過上面的討論,容易想到:只需要把求和規約掉的多個軸通過 reshape 進行 unsqueeze,然后再進行 broadcast 就行了。
實際情況正是如此,以兩個軸為例,這種情況可以認為是兩個單軸 summation 操作的復合,在實際進行反向傳播時,會先傳播到第一個單軸 summation,此時會進行一次 broadcast_to,然后這個結果會作為 \(G\) 繼續傳播到第二個單軸 summation,此時又會進行一次 broadcast_to,最終結果等價于把這兩次 broadcast_to 放到一起完成。
嚴格的數學推導這里就不展開了,留作習題自證不難。
所以對于 Summation,最終的導數結果為:
BroadcastTo
這個算子是對一個張量進行廣播操作,也就是把張量的元素在若干個軸上進行“復制”的操作,形成一個更“充實”的張量。numpy,pytorch 等框架在處理形狀不同的張量時會自動進行廣播操作。例如,\(A\) 的形狀是 \((6, 6, 5, 4)\),\(B\) 的形狀是 \((6, 5, 4)\),在執行 \(A \odot B\) 時,框架會自動在 \(B\) 的左邊補上維度 1,變成 \((1, 6, 5, 4)\),然后再執行廣播變成 \((6, 6, 5, 4)\),然后再做哈達馬積。
這里我們同樣先討論只針對一個軸進行 broadcast_to 的情形,多軸的情形同樣可以視為多個單軸 broadcast_to 的嵌套。
(注:以下討論涉及到的參數和實際編程中的參數有差異,實際編程中是直接傳入 broadcast_to 之后的形狀作為參數)
單軸 BroadcastTo 算子定義
單軸 broadcast_to 算子有兩個參數:
- 參數
a,表示在哪一個軸進行廣播,該算子要求自變量在這一維度的大小為 1 - 參數
b,表示要將這一維度廣播到多大
這一算子的形式化的定義為:
- \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\),是 \(n\) 維張量,其中 \(d_a = 1\),\(F(X) = \text{broadcast\_to}(X, a)\)
- 則 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times b \times d_{a+1} \times \cdots \times d_n}\),其中 \(f_{\lambda_1 \lambda_2 \cdots \lambda_n} = x_{\lambda_1 \lambda_2 \cdots \lambda_{a-1} 1 \lambda_{a+1} \cdots \lambda_n}\)。
直觀上來看就是把 \(X\) 在第 \(a\) 維的元素復制了 \(b\) 份。
單軸 BroadcastTo 問題求解
首先可以確認,\(X\) 和 \(\nabla\) 形狀相同,為 \(\mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times 1 \times d_{a+1} \times \cdots \times d_n}\),\(G\) 和 \(F(X)\) 的形狀相同,為 \(\mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times b \times d_{a+1} \times \cdots \times d_n}\)。
寫出 \(\nabla\) 的表達式可得:
把 \(F\) 的定義式代入,原式子可寫作:
注意到,只有當 \(\mu_1 = \lambda_1, \mu_2 = \lambda_2, \ldots, \mu_{a-1} = \lambda_{a-1}, \mu_{a+1} = \lambda_{a+1}, \ldots, \mu_n = \lambda_n\) 時,求和式不為 0,所以這個式子可以進一步化簡為:
這個表達式的值恰好就等于張量 \(G\) 在 \(a\) 軸做 Summation,所以有:
多軸 BroadcastTo 問題求解
和 Summation 類似,多軸情形下只需要對所有廣播過的軸做 Summation 即可,由此可得,多軸情形下:
其中 \(a_1, a_2, \ldots, a_m\) 是所有經過廣播的軸的編號,具體可以參考以下代碼實現:
old_shape = node.inputs[0].shape
new_shape = self.shape
sum_axes = []
for i in range(len(new_shape)):
if i >= len(old_shape) or (old_shape[i] == 1 and new_shape[i] != 1):
sum_axes.append(i)
return reshape(summation(out_grad, tuple(sum_axes)), old_shape)
Reshape
顧名思義,這個算子的作用就是改變張量的形狀。numpy 對于這個操作的描述是:在不改變數組內容的情況下為數組賦予新的形狀。可以認為 numpy 存儲的多維張量本質上是一個連續的一維數組,形狀只是我們看這個數組的一個視角,以二維張量為例,假設這個一維數組是 \([1,2,3,4,5,6]\),如果以 \(2 \times 3\) 矩陣的視角去看,那就會是:
如果以 \(6 \times 1\) 的矩陣視角去看,那就會是:
Reshape 問題求解
這里我們可以猜一下,以三維張量為例,\(\nabla, X \in \mathbb{R}^{e_1 \times e_2 \times e_3}\),\(G, F(X) \in \mathbb{R}^{d_1 \times d_2 \times d_3}\),其中 \(e_1 \times e_2 \times e_3 = d_1 \times d_2 \times d_3\)。
注意到 \(\nabla\) 和 \(G\) 的元素數量相同,只是形狀不同,那只需要進行一次 reshape 即可。
事實正是如此,對于 \(F(X) = \text{reshape}(X, \text{new\_shape})\),其反向傳播導數:
這里具體的數學推導就不再贅述了,留作習題供讀者練習。
(提示:可以考慮定義一個輔助函數,將原來軸的參數映射到新的軸上的參數)
Transpose
這一算子的定義是做轉置,二維矩陣的轉置很顯然,就是行列互換。推廣到 \(n\) 維張量,就是選擇兩個軸,然后在這兩個軸上做互換。
(注:這里的 transpose 是 CMU Homework1 里面定義的,而非 numpy 里的定義,這里只會轉置兩個軸,但是這里推導得到的結果可以輕易推廣到多軸的情形)
Transpose 形式化定義
- 這一算子有 2 個參數
a和b,表示需要轉置的兩個軸 - 若 \(X \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_n}\),是 \(n\) 維張量
- 則 \(F(X) \in \mathbb{R}^{d_1 \times d_2 \times \cdots \times d_{a-1} \times d_b \times d_{a+1} \times \cdots \times d_{b-1} \times d_a \times d_{b+1} \times \cdots \times d_n}\) 也是 \(n\) 維張量,只是第 \(a\) 維和第 \(b\) 維的大小互換了
- 且其中:
Transpose 問題求解
這里也很容易才到,對 \(G\) 做同樣的轉置即可得到,這里同樣不展開贅述了,留作習題供讀者練習。
(提示:同樣可以考慮定義映射軸的輔助函數來解決)
MatMul
這一算子是矩陣乘法,二維矩陣的公式已經在上一篇文章里給出,這里主要補充一下 batch 模式下的矩陣乘法。根據 numpy 里的定義,進行 MatMul 的兩個張量 \(X\),\(Y\) 可以是兩個高維的張量,例如,當 \(X\) 的形狀為 \((6, 6, 5, 3)\),\(Y\) 的形狀為 \((6, 6, 3, 4)\) 時,會把 \(X\) 視為是 36 個 \(5 \times 3\) 矩陣按照 \(6 \times 6\) 的格式排列,然后把 \(Y\) 視為 36 個 \(3 \times 4\) 的矩陣按照 \(6 \times 6\) 排列,最后將兩個大矩陣中對應位置的兩個小矩陣做矩陣乘法,最終會得到 36 個 \(5 \times 4\) 的小矩陣,組成一個形狀為 \((6, 6, 5, 4)\) 的張量。
這一操作同樣支持廣播,即:如果 \(X\) 形狀為 \((6, 6, 5, 3)\),\(Y\) 的形狀為 \((3,4)\),那么最終結果會是形狀為 \((6, 6, 5, 4)\) 的張量,即 \(X\) 的 36 個小矩陣每一個都和 \(Y\) 做矩陣乘法。
這種情形下,如果記單個矩陣乘法的函數為 MatMul,批量矩陣乘法函數為 MatMul_Batch,那么此時 MatMul_Batch 實際上是 MatMul(X, broadcast_to(Y, X.shape)),所以在處理 MatMul_Batch 對 \(Y\) 求導時,需要考慮到這里實際上是嵌套了一層廣播的,而廣播的反向傳播是做 Summation,所以在套用單矩陣 MatMul 的反向傳播公式之后還需要做一個 Summation 將形狀變回和 \(Y\) 相同的形狀,具體過程可以參考如下的代碼實現:
(注:理論上是需要先做 Summation 再做 Matmul 的反向傳播,但是先做 Summation 和后做是等價的,為了代碼實現方便就統一放到后面來做了)
a, b = node.inputs
a_grad, b_grad = matmul(out_grad, transpose(b)), matmul(transpose(a), out_grad)
if len(a_grad.shape) > len(a.shape):
sum_axes = tuple((i for i in range(len(a_grad.shape) - len(a.shape))))
a_grad = summation(a_grad, sum_axes)
if len(b_grad.shape) > len(b.shape):
sum_axes = tuple((i for i in range(len(b_grad.shape) - len(b.shape))))
b_grad = summation(b_grad, sum_axes)
return a_grad, b_grad
一些剩下的簡單算子
接下來放一些簡單算子的反向傳播公式,這里就只給出結果而省略推導過程了。
Negate
這個算子是把張量中所有元素取相反數,很顯然:
Log
這個算子是對張量中所有元素取自然對數,很顯然:
Exp
這個算子是對張量中所有元素過一次指數函數 \(y = e^x\),很顯然:
EWisePow
這個算子接收 2 個相同形狀的自變量 \(X\) 和 \(Y\)(如果形狀不同會進行廣播到相同形狀),對于 \(X\) 里的每一個 \(x\),取 \(Y\) 對應位置上的元素 \(y\),做 \(x^y\)。
很顯然:

浙公網安備 33010602011771號