[LLM] ZeRO-DP技術簡析
[LLM] ZeRO-DP技術簡析
本文對ZeRO: Memory Optimizations Toward Training Trillion Parameter Models中提出的ZeRO-DP進行簡要總結。相關的講解其實網上也有很多了,不過只看網上的終究還是有點走馬觀花,所以我還是決定自己寫一篇博客,記錄一下我自己的理解。這篇博客講的不會太細,但是希望能用更易于理解的方式,講明白文中的重要內容。
為什么需要ZeRO-DP?
-
數據并行(DP)是分布式訓練中最基本的并行方式,它通過把數據分發到不同的GPU上從而提升效率。但數據并行不會降低每個GPU的顯存開銷。在一個數據并行組中,不同的GPU保存的模型參數、優化器狀態、和梯度其實都是同一份。每次迭代時,需要對模型參數進行All-Reduce來同步狀態。
-
為了避免存儲冗余狀態,降低顯存開銷,ZeRO-DP選擇把這些狀態也分割到不同的GPU上(注意:這不同于模型并行MP。ZeRO-DP本質上還是DP,它是把狀態在DP組內進行分割,它可以于MP同時存在。)在前向傳播的時候,每個GPU從其他GPU那里獲取到全部狀態并進行計算;在反向傳播的時候,只把劃分后的狀態發給每個GPU。
概述
-
圖中,\(\Psi\)代表模型參數量,圖中使用fp16參數,所以模型參數占用內存為\(2\Psi\);\(N_d\)表示DP度數(DP組的大小);\(K\)表示優化器狀態的參數量是模型參數量的多少倍,圖中使用Adam優化器中\(K=12\)。
-
ZeRO-DP一共分為三個階段:
- \(P_{os}\)對優化器狀態進行劃分。
- \(P_{os+g}\)對優化器狀態和梯度進行劃分。
- \(P_{os+g+p}\)對優化器狀態,梯度和模型參數進行劃分。
-
圖中可以明顯的看出每個階段的劃分所帶來的顯存降低收益。
通信量分析
- 很明顯的,ZeRO-DP將狀態劃分到不同的GPU上,從而降低了顯存開銷。但是在這個過程中,拉取和分發狀態是否會導致額外的通信開銷呢?所以我們來分析一下ZeRO-DP的通信開銷。
前置知識
-
為了方便,我們這里先不考慮模型并行MP,只考慮數據并行DP。這里的通信開銷指的是每臺GPU所需的通信量。
-
All-reduce的通信開銷是\(2\Psi\)。Reduce-scatter和All-gather的通信開銷都是\(\Psi\)。
傳統DP的通信開銷
在下面圖中,\(D\)表示數據,\(P\)表示模型參數,\(G\)表示梯度,\(O\)表示優化器參數。下標表示數據劃分的第\(i\)塊,上標表示模型劃分的第\(j\)塊。這里只考慮2個GPU。
在傳統DP中,正向傳播不需要任何通信。但是在反向傳播中,由于所有GPU上的模型參數是副本關系,所以它們要進行All-reduce完成同步,所需通信量是\(2\Psi\)?。
\(P_{os+g}\)?的通信開銷
\(P_{os}\)和\(P_{os+g}\)的通信量相同。在前向,每個GPU都能計算完整的梯度。在反向,需要對梯度進行reduce-scatter,每個GPU對自己的部分梯度進行聚合,使用自己的優化器得到參數。最后再對參數進行all-gather發給每個GPU。總的通信量為\(\Psi+\Psi=2\Psi\)?,和傳統DP是一樣的。
\(P_{os+g+p}\)的通信開銷
接著考慮對模型參數進行劃分。在前向,在一開始額外對參數進行一次all-gather,使每個GPU獲取到全部的參數。在反向,依然對梯度進行reduce-scatter。
注意到模型有很多層。在前向,在我們使用了一層的全部參數計算完成后,我們可以直接釋放掉這些參數的顯存,接著算后面的層,防止這些參數一直占用著顯存。但這樣的話,在反向,我們需要再進行一次all-gather重新獲得這一層的參數才行。因此,總的通信量是\(2\Psi+\Psi=3\Psi\)。
| 歡迎來原網站坐坐! >原文鏈接<

浙公網安備 33010602011771號