參數更新
1. loss
是一個單值
假設輸入的詞元id是[0, 1]
目標詞元id是[1, 2]
也就是根據輸入得到兩個預測輸出,
注意上面的是id,每個id實際上是一個嵌入向量,比如768維向量,
假設詞匯表是3,實際詞匯表可能是5w
通過模型矩陣計算后,對于輸入的每一個位置,都會輸出一個3維度的向量,對齊進行softmax選擇最大的概率作為預測輸出,
這里輸入序列有兩個詞元,因此會預測出兩個結果,實際上是兩個3維度的概率向量,比如[[0.320, 0.333, 0.347], [0.301, 0.332, 0.367]],
這兩個概率向量表明,預測輸出都是2的概率最大
但實際上目標值第一個是1,第二個是2,
計算損失實際上是根據目標詞元id,對預測結果中對應位置的概率求負對數
位置0: -log(0.333) ≈ 1.100
位置1: -log(0.367) ≈ 1.003
平均損失 = (1.100 + 1.003)/2 = 1.0515
其意義是,如果目標詞元位置的概率很大,說明預測的準,那么這個熵損失值就很小,概率趨于1損失就趨于0,如果預測的不準就是概率小,那么損失值就很大,概率趨于0,損失值就趨于無窮大,
2. 梯度
2.1 損失對logits的梯度
交叉熵損失梯度公式:?L/?logits = softmax(logits) - one_hot(target)
位置0(目標=1)
one_hot(1) = [0, 1, 0]
?L/?logits0 = [0.320, 0.333, 0.347] - [0, 1, 0] = [0.320, -0.667, 0.347]
位置1(目標=2):
one_hot(2) = [0, 0, 1]
?L/?logits1 = [0.301, 0.332, 0.367] - [0, 0, 1] = [0.301, 0.332, -0.633]
可以看到,目標位置的梯度為負,且預測的越準的話,這個梯度絕對值就越小,那么在進行梯度下降是動作就要“輕微”點
非目標位置的梯度是正值,且非目標位置如果概率越大,表明越不準確,那么進行梯度下降時這個地方要“劇烈”點
2.2 損失對參數的梯度
?L/?W = 嵌入^T × (?L/?logits)
假設輸入嵌入矩陣是n*d,那么其轉置是d*n,那么轉置的每一行表示了n個位置每個位置的一部分,
(?L/?logits)是n*w矩陣,表示對于輸入的n個位置,每個位置對于詞匯表每個詞匯的預測概率相應的損失,
這兩矩陣相乘,結果是d*w矩陣,
可以用第一個值舉例,這個值是由n個輸入向量取每個第一個值,同時對n個輸出概率向量每個取對詞匯表第一個詞匯的梯度值,進行相乘得到一個標量值,
那么結果的d*w矩陣,第i行第j列,包含了每個位置預測結果中第j個詞元的概率梯度綜合,以及輸入序列嵌入矩陣每個輸入的第i個值,
實際上,這個d*w矩陣就是參數矩陣
2.3
我們再回憶下參數流
嵌入矩陣:w*d
輸入:n*d --h3
#忽略QKV
#QKV矩陣:d*d
# QK得到n*n自注意力矩陣
# 再與V得到n*d矩陣
FFN網絡矩陣1:d*f --W3
得到n*f矩陣 -h2
FFN網絡矩陣2: f*d --W2
得到n*d矩陣 --h1
輸出層矩陣:d*w --W1
得到n*w矩陣
2.3.1
根據n*w矩陣,我們得到對logits的梯度,也就是n*w矩陣,
下面我們反向一步步得到每個參數矩陣的梯度,
對輸出層矩陣參數d*w,因為我們根據n*d矩陣和d*w矩陣相乘得到n*w矩陣,那么需要n*d的轉置與n*w相乘就可得到d*w參數梯度。
其中n*d在有些教程中表示為h(隱藏矩陣),是一種中間態數據,需要報存在顯存
2.3.2
第一步,計算?L/?logits, 這個通過預測結果softmax矩陣與目標序列one-hot矩陣相減得到,是一個n*w矩陣,
第二步,計算?L/?W1 = h1(T) * (?L/?logits) ,是一個d*w矩陣,這個矩陣形狀和W1一樣, -- h1*W1 = n*w
第三步,計算?L/?h1 = (?L/?logits) * W1(T), 是一個n*d矩陣,和h1形狀一樣,
第四步,計算?L/?W2 = h2(T) * ?L/?h1, 是一個f*d矩陣,形狀和W2一樣, -- h2*W2 = h1
第五步,計算?L/?h2 = ?L/?h1 * W2(T), 是一個n*f矩陣,
第六步,計算?L/?W3 = h3(T)* ?L/?h2, 是一個d*f矩陣 , -- h3*W3 = h2, h3就是inputs
第七步,計算?L/?h3 = ?L/?h2 * W3(T), 是一個n*d矩陣,
3. 參數更新
根據上面計算出的梯度,對每一個實體矩陣(嵌入矩陣以及參數矩陣)進行更新,中間態矩陣不用更新。
假設lr=0.01
3.1
更新inputs,也即是h3,也即是詞嵌入
對于輸入序列n個詞的第i個嵌入,
E[i] = E[i] - 0.01*?L/?h3 [i]
或者整體上
E = E - 0.01
3.2 更新 W3
使用與其形狀相同的梯度矩陣乘以學習率,然后將W3與結果相減,得到新的W3
4. 優化器
上面的還沒有提到優化器,現在我們加入優化器
優化器包含兩個狀態m和v,或者叫動量和方差,對于每個參數值,都有一一對應的動量和方差,也就是說,每一個參數值同時對應兩個優化器值,
所有的優化器狀態初始化為0,
另外,還有幾個值以及初始化舉例如下,
學習率 lr=0.01
beta1=0.9, beta2=0.999, 即 β1,β2
epsilon=1e-8, 即 ε
時間步 t=1 (初始)
也就是說,這些值是優化器的一部分,可以看到,學習率成為了優化器的一部分,
使用優化器對每一個參數值進行更新,以W3舉例,
首先已經得到W3的梯度矩陣,和W3的形狀一樣,設為W3Grad,現在對W3[0,0]進行更新,公式
m = β1·m + (1-β1)·grad
v = β2·v + (1-β2)·grad2
m? = m / (1 - β1?)
v? = v / (1 - β2?)
param = param - lr·m? / (√v? + ε)
其中grad就是W3Grad[0,0], m,v也是對應的優化器狀態值,目前初始值0,
這里我們更新的param 就是W3[0,0]
更新完成后,對應的m,v都變為新的值了,全部更新完后,時間t加1
也就是說,每次更新參數時,先將對應優化器值進行更新(根據給定的β1、β2以及計算出來的梯度值和時間步),然后,使用更新后的m、v,lr以及t對參數進行更新,
使用β1對m進行更新的意義,大部分還是保持m當前的值(使用β1乘以當前m,β1值比較接近1),少部分根據梯度值增加,也就是說,如果梯度越大,那么這個動量m變化也越大,梯度可以為負,所以動量可以往小變,一般來說,目標詞元處的梯度為負,其他為正,
同理,梯度越大,方差v變化也越大,這個方差始終往大變
完了以后,這倆公式
m? = m / (1 - β1?)
v? = v / (1 - β2?)
將m和v的值按比例放大,時間步越大,這個放大比例越小,
最后就是更新參數了,方差大的話,更新的幅度小,動量大的話,更新的幅度大。
5 總結
從輸入序列a矩陣到最后輸出z矩陣中間會有x個參數矩陣,總共有x-1個中間態矩陣,或者叫臨時矩陣
我們將臨時矩陣標注為h1,,,h(x-1)
將參數矩陣標注為p1,,,px
a * p1 = h1,
h1* p2 = h2,
h(x-1)* px = z
最后的 z 是一個經過softmax后是一個概率矩陣,如果輸入序列是n * d維度,那么z 就是 n * w維度,其中n是詞元數,d是嵌入維度,w是詞匯量
根據目標序列,將每個目標詞元轉為w維度的one-hot向量,組成一個n*w矩陣,與z進行減法運算,計算出梯度,實際的意義就是在目標詞元處的概率如果大,梯度就小,如果概率小,梯度就大。
然后進行反向傳播,
從z開始,依次計算px、h(x-1)、p(x-1),h(x-2),,,,,一直到p1,a的梯度,
最后根據梯度進行參數更新,注意,a實際上對應的是嵌入式表中輸入詞元對應的行,
我們用h0表示a,hx表示z,
要計算pi的梯度,前提是hi梯度已經得到,
根據公式h(i-1) * pi = hi,
得到pi = T(h(i-1)) * hi, 將hi的梯度帶入,就得到pi的梯度,
同理h(i-1) = hi * T(pi) ,將hi的梯度帶入得到h(i-1)的梯度。
posted on 2025-06-23 16:56 longbigfish 閱讀(22) 評論(0) 收藏 舉報
浙公網安備 33010602011771號