<output id="qn6qe"></output>

    1. <output id="qn6qe"><tt id="qn6qe"></tt></output>
    2. <strike id="qn6qe"></strike>

      亚洲 日本 欧洲 欧美 视频,日韩中文字幕有码av,一本一道av中文字幕无码,国产线播放免费人成视频播放,人妻少妇偷人无码视频,日夜啪啪一区二区三区,国产尤物精品自在拍视频首页,久热这里只有精品12

      [MoE] Tutel源碼解讀

      [MoE] Tutel源碼解讀

      前言

      最近MoE變得火了起來。但我在和別人討論MoE時,總有一些說不清楚地方,就算讀了論文也一知半解。于是我決定還是要看一看具體的代碼,看看每個細節究竟都是怎么實現的。

      作為實現參考,Tutel這篇工作就很不錯。最近的工作基本都拿Tutel作為Baseline比較,于是我決定讀一讀Tutel的源代碼,學習一下MoE編程。

      在讀這篇博客之前,希望你已經大致讀過Tutel的論文。如果對下面的代碼有不清楚的地方,建議多參考下面的幾張圖。

      Tutel論文

      Tutel源碼

      論文理解

      Tutel的主要貢獻就是提供了若干種可切換的并行策略。那么首先,我們就來理解一下Tutel的并行策略。

      Tutel考慮了所有DP(數據并行),MP(模型并行)和EP(專家并行)的組合,并分析它們的通信復雜度,最終選擇了最優的兩種:DP和EP+DP+MP。

      其中\(C_g\)是每個expert的容量(處理token的數量),\(P\)是所有expert的總參數量,\(W\)是GPU數量,\(E\)是expert數量,\(r\)為MP的度數。

      我們再看圖理解一下這兩種并行策略

      DP

      Tutel的DP參考了ZeRO-DP的stage 3,即對模型參數的劃分。關于ZeRO-DP的理解可以參考我之前的博客:ZeRO-DP技術簡析。

      ZeRO-DP的主要好處是避免了在每個GPU上都保存完整的模型(以及優化器狀態,主要是優化器狀態占顯存較高),造成冗余的保存,導致顯存浪費。簡單來說,在(修改版)ZeRO-DP中,每個GPU都只保存一部分模型。在前向對模型參數進行all-gather,使每個GPU獲取到全部參數;在反向對梯度進行reduce-scatter,每個GPU都只更新自己的那部分模型參數。

      (修改版)ZeRO-DP的通信量是\(2P\)?,與傳統DP(需要一次all-reduce,通信量也是\(2P\))相同。

      EP+DP+MP

      相對的,EP+DP+MP就比較復雜了。由于原論文的圖有點過于簡略,我這里手動畫了一個圖。圖中\(W=8,E=2,r=2\),圖里只畫了前5個GPU。

      其中\(X\)表示輸入數據。\(X_i\)是第\(i\)個GPU上的輸入數據。\(X_i^{j}\)為原本在第\(i\)個GPU上,要發往第\(j\)個expert的處理的數據。為了實現MP,Tutel把數據復制了\(r\)份。因為這里\(r=2\),所以將\(X_i^j\)復制為\(X_i^{ja}\)\(X_i^{jb}\)??。

      \(E\)表示expert,其中\(E_i\)為第\(i\)個expert。圖中\(W/E=4\),所以Tutel將一個模型切成4塊,即將\(E_0\)切成\(E_0^{\alpha0}, E_0^{\alpha1}, E_0^{\beta0},E_0^{\beta1}\)。又由于\(r=2\),所以要把這4塊分成2組做MP,即\([E_0^{\alpha0}, E_0^{\alpha1}]\)一組,\([E_0^{\beta0},E_0^{\beta1}]\)一組。在每個MP組內部,還要做ZeRO-DP,即在前向傳播時,\(E_0^{\alpha0}, E_0^{\alpha1}\)進行all-gather,得到\(E_0^{\alpha}\)?。

      \(Y\)表示輸出,\(Y_{i}^{0a, \alpha}\)只這個輸出來自于\(i\)上的輸入,是第\(a\)個復制,發往expert 0,被第\(\alpha\)塊處理得到的結果。

      這里解釋一下圖中的EP,DP,MP究竟是怎么運作的。

      • 首先,GPU0-3是一個EP Group,它們每個GPU都擁有expert 0的1/4切片。
      • 不過雖然GPU0-3拿到的都是1/4切片,不過取決于并行策略不同,這些1/4切片的同步方式也不同。比如在上圖中,\(r=2\),所以有2個MP Group:GPU0-1,GPU2-3。
      • 在每個MP Group內部,做的是ZeRO-DP。例如GPU0和1,它們在前向時需要對expert參數進行All-gather,GPU0和1都拿到expert 0的前1/2參數,然后使用不同的數據進行計算。
      • 而在不同的MP Group間做的是MP。例如GPU0-1和GPU2-3,它們各自具有expert 0的1/2參數。在前向時,這兩個MP Group使用相同的數據進行計算,最后再將它們的結果相加。
      • 最后看整體的EP。
        • 在前向,首先是Gate,決定每個token發往哪個expert。
        • 接著進行把數據復制\(r=2\)?份,然后準備進行第一次all-to-all(dispatch)。
          • 在all-to-all(dispatch)時,要先考慮EP:將發往相同expert的token放到一起。
          • 然后考慮ZeRO-DP,例如這里ZeRO-DP的度數(即MP Group的大?。?,因此將expert 0的前1/2個token發給GPU0和2,后1/2個token發給GPU1和3。
        • 在expert計算完成后,再進行all-to-all(combine),將token發回原來的位置。

      復雜度:在前向和反向各需要兩次all-to-all,又由于數據被復制了\(r\)倍,所以開銷是\(4C_gr\)。另外,由于引入了EP和MP,all-gather的規??s小了\(E\times r\)倍,所以開銷是\(P/E/r\)。因此總通信量是\(4C_gr+2P/E/r\)??。

      注意到在上面兩張圖中,每個GPU保存的模型參數是相同的,因此Tutel可以實現無縫切換并行策略。只需修改參數\(r\)即可:

      PP

      關于Pipeline Paralism,Tutel的處理方法如下

      MoE的PP與傳統PP的一個主要區別在于:傳統的PP是以層為粒度的;而MoE的PP要比層更細,是在層內的Dispatch-FFN-Combine之間做PP。因此體現在代碼上面,傳統的PP直接調度不同層就可以了,不用改每層的邏輯;但MoE的PP必須要寫在層內的處理邏輯里,看起來會稍微麻煩一些。

      代碼閱讀

      API

      首先看一下Tutel的API,看看一下大致有哪些參數

      * Usage of MOELayer Args:
              gate_type        : 對gate的描述
              gate_type        : dict-type gate description, e.g. {'type': 'top', 'k': 2, 'capacity_factor': -1.5, ..},
                                    or a list of dict-type gate descriptions, e.g. [{'type': 'top', 'k', 2}, {'type': 'top', 'k', 2}],
                                    the value of k in top-gating can be also negative, like -2, which indicates one GPU will hold 1/(-k) parameters of an expert
                                    capacity_factor X can be positive (factor = X), zero (factor = max(needed_volumes)) or negative (factor = min(-X, max(needed_volumes))).
              model_dim        : MoE輸入的維度
              model_dim        : the number of channels for MOE's input tensor
              experts          : 對expert的描述,具體選項在下面
              experts          : a dict-type config for builtin expert network
              scan_expert_func : 在初始化時,對expert的每個參數執行此函數
              scan_expert_func : allow users to specify a lambda function to iterate each experts param, e.g. `scan_expert_func = lambda name, param: setattr(param, 'expert', True)`
              result_func      : 在forward末尾,對輸出執行此函數
              result_func      : allow users to specify a lambda function to format the MoE output and aux_loss, e.g. `result_func = lambda output: (output, output.l_aux)`
              group            : all-to-all的world
              group            : specify the explicit communication group of all_to_all
              seeds            : 種子,一個三元組
              seeds            : a tuple containing a tripple of int to specify manual seed of (shared params, local params, others params after MoE's)
              a2a_ffn_overlap_degree : 對應上圖中PP的度數
              a2a_ffn_overlap_degree : the value to control a2a overlap depth, 1 by default for no overlap, 2 for overlap a2a with half gemm, ..
              parallel_type    : 并行策略,可以是'data', 'model', 'adaptive:x', 或者'auto'
              parallel_type    : the parallel method to compute MoE, valid types: 'auto', 'data', 'model'
              pad_samples      : deprecated
      
      * Usage of dict-type Experts Config:
      
              這些都比較好理解,就不解釋了
              num_experts_per_device : the number of local experts per device (by default, the value is 1 if not specified)
              hidden_size_per_expert : the hidden size between two linear layers for each expert (used for type == 'ffn' only)
              type             : available built-in experts implementation, e.g: ffn
              activation_fn    : the custom-defined activation function between two linear layers (used for type == 'ffn' only)
              has_fc1_bias     : If set to False, the expert bias parameters `batched_fc1_bias` is disabled. Default: True
              has_fc2_bias     : If set to False, the expert bias parameters `batched_fc2_bias` is disabled. Default: True
      

      接著我們直接看主體部分,即MOELayer,代碼位于 tutel/impls/moe_layer.py

      Gate

      首先一上來是Gate

      def __init__(self, ...):
          # 初始化experts
          # 對于ffn,expert_module為fused_experts.ExpertModule
          expert_modules = expert_module(**experts)
          # 初始化gate
          # 對于gate_type為top-k的情況,single_gate具體用的是LinearTopKGate
          gate_module = single_gate(model_dim=self.model_dim, num_global_experts=self.num_global_experts, **single_gate_type)
          
      def forward(self, input: Tensor, ...):
          # 記住最開始輸入的形狀
          original_shape, original_dtype  = input.shape, input.dtype
          assert len(original_shape) >= 2, "Input data must be at least 2D tensor: (s)amples, .., (m)odel_dim"
          # reserve_dims默認為1,用途為將輸入數據x轉換為2d,保持最后一個維度不變
          # 如(10, 20, 300) -> (200, 300)
          # 為了方便,我們記轉換后的x形狀為(s, h)
          x = input.reshape(-1, original_shape[-reserve_dims:].numel())
          
          # 定義路由函數
          def routing():
              # 經過gate
              # 我們記expert的數量為e,GPU數量為w
              # 則num_global_experts=e
              # 則logits形狀為(s, e)
              logits = gctx(x)
              
              # 對logits加noise的結果求softmax,得到scores
              scores = F.softmax(logits_w_noise, dim=1)
              
              # 省略一些對logits的處理
              
              # 默認self.sharded_count = self.world_size // self.num_global_experts
              # 代表每個expert被切塊的數目
              # 這個切塊要么用于ZeRO-DP,要么用于MP
              # 所以sharded_count*a2a_ffn_overlap_degree為每個expert的副本數量
              mega_up = max(megablocks_size, 1)
              alignment = (self.sharded_count * a2a_ffn_overlap_degree + mega_up - 1) // mega_up * mega_up
              if alignment > 256:
                  alignment = (alignment + 127) // 128 * 128
      
              # extract_critical會計算出每一個token要發往的expert,以及在expert內的編號
              return logits.dtype, extract_critical(scores,
                  top_k = top_k,
                  loss_fn = _loss_fn,
                  capacity_factor = capacity_factor or gctx.capacity_factor,
                  batch_prioritized_routing = self.batch_prioritized_routing,
                  normalize_gate = self.normalize_gate,
                  group = self.group,
                  alignment = alignment,
                  inequivalent_tokens = inequivalent_tokens,
              )
      

      這里有必要詳細解釋一下extract_critical函數

      def extract_critical(scores, top_k, ...):
          # 對scores求topk的索引
          # topk_indices的形狀為(s, k)
          topk_indices = torch.topk(scores, top_k, dim=1).indices
          
          # 將topk_indices轉換為列表
          # indices_s長度為k的列表,每個元素是形狀為s的tensor
          indices_s = [x.view(-1) for x in topk_indices.chunk(top_k, dim=1)]
          
          # 計算one-hot編碼
          # masks_se中的每個元素是一個形狀為(s, e)的tensor,
          # 若第i個token的t第1個expert是j,則對應masks_se的第1個元素的第(i, j)位是1;否則是0
          masks_se = [losses._one_hot_with_dtype(x, num_classes=num_global_experts, dtype=x.dtype) for x in indices_s]
          
          # gates_s的每個元素形狀為(s)
          gates_s = [(scores * x).sum(dim=1) for x in masks_se]
          
          # top-k的loss
          l_loss = loss_fn(scores, topk_indices) if loss_fn is not None else None
          
          # 計算location,其中compute_location = fast_cumsum_sub_one,即對維度0求前綴和再減1
          # locations_s的元素形狀為(s, e),其中(i,j)的值>=0表示token i是發往expert j的第幾個token
          locations1 = compute_location(masks_se[0])
          locations_s = [torch.sum(locations1 * masks_se[0], dim=1).to(torch.int32)]
          
          # 為top 1..k都計算locations_s,將結果求和
          if top_k > 1:
              acc_base = None
              for k in range(1, top_k):
                  # acc_base是這個expert的top0..k-1的token數量,形狀為(1, e)
                  acc_base = torch.sum(masks_se[k - 1], dim=0, keepdim=True) if acc_base is None else acc_base + torch.sum(masks_se[k - 1], dim=0, keepdim=True)
                  locations2 = compute_location(masks_se[k])
                  # locations_s的元素表示當前token的top-k是expert發往expert j的第幾個token(考慮所有的top_k)
                  locations2 += acc_base
                  locations_s.append(torch.sum(locations2 * masks_se[k], dim=1).to(torch.int32))
          locations2 = locations2[-1] + 1
          
          # num_samples = s
          num_samples = int(scores.size(0))
          
          samples_per_expert = (num_samples + num_global_experts - 1) // num_global_experts
          if capacity_factor > 0:
              # 若capacity_factor>0,則根據capacity_factor計算每個expert的capicity
              capacity = top_k * int(capacity_factor * samples_per_expert)
          else:
              # 若capacity_factor=0,expert的capicity是所有expert的capacity的最大值
              capacity = locations2.max()
              capacity = int(simple_all_reduce(capacity, group=group, op=torch.distributed.ReduceOp.MAX))
              if capacity_factor < 0:
                  # 若capacity_factor>0,則capacity_factor是capicity的upper_bound
                  capacity = min(capacity, top_k * int(-capacity_factor * samples_per_expert))
                  
          # 對齊到所有expert副本的數量
          remainder = capacity % alignment
          if remainder > 0:
              capacity = capacity + alignment - remainder
              
          return (num_global_experts, indices_s, locations_s, gates_s, capacity, locations2), l_loss
      

      Encode

      繼續看forward

      def forward(self, input: Tensor, ...):
          # 在routing之后
          logits_dtype, (crit, l_aux) = routing()
          
          # fast_encode內部使用kernel進行encode操作
          # 用c代表expert的capacity。則encode后y的形狀為(e, c, h)
          y = fast_encode(x.to(logits_dtype), crit, self.is_postscore).to(x.dtype)
      

      接下來就和并行策略有關了

      def forward(self, input: Tensor, ...):     
          # 在encode之后
          if self.adaptive_degree == 0:
              # 只有DP,在前向不需要傳任何東西
              y = self.expert_local(y, original_shape[-reserve_dims:])
          else:
              if self.auto_parallel:
                  # 比較數據量和模型參數量,
                  # 因為DP+EP通信量是4*數據量+2*參數量,若MP度數維r,則DP+EP+MP的通信量為
                  # 4*數據量*r+2*參數量/r,所以如果2*數據量<參數量,而可以使用MP
                  self.use_model_parallel = (y.numel() * (self.sharded_count - 1) * 2 < sum([x.numel() for x in self.experts.parameters()]))
      
              if self.num_global_experts < self.world_size:
                  if self.use_model_parallel:
                      # 記adaptive_degree=r(參照論文)。把數據復制r份
                      # 即上面圖中的[X_0^{0}]復制為[X_0^{0a},X_0^{0b},X_0^{0a},X_0^{0b}]
                      # y的形狀為(w, c*e*r/w, h)
                      y = y.repeat(1, self.adaptive_degree, 1).view(self.world_size, -1, y.size(2))
                  else:
                      # 記world_size為w
                      # 將y的形狀改為(w, c*e/w, h)
                      y = y.view(self.world_size, -1, y.size(2))
      
              if a2a_ffn_overlap_degree > 1 and y.is_cuda:
                  def expert_fn(expert_input):
                      return self.expert_local(expert_input, original_shape[-reserve_dims:])
                  # 在all-to-all和FNN之間做overlap
                  y = a2a_ffn_overlap_forward(y, expert_fn=expert_fn, a2a_ffn_overlap_degree=a2a_ffn_overlap_degree, use_2dh=self.use_2dh, group=self.group)
              else:
                  # 不做overlap
                  y = C.all_to_all(y, 1, 0, use_2dh=self.use_2dh, group=self.group)
                  y = self.expert_local(y, original_shape[-reserve_dims:])
                  y = C.all_to_all(y, 0, 1, use_2dh=self.use_2dh, group=self.group)
      
              if self.num_global_experts < self.world_size:
                  if self.use_model_parallel:
                      y = torch.sum(y.view(self.num_global_experts, self.adaptive_degree, -1, y.size(2)), dim=1)
                  else:
                      y = y.view(self.num_global_experts, -1, y.size(2))
      

      接下來詳細講一講兩個all-to-all和expert_local的部分,為了方便,我們先從不做overlap的開始看,也不考慮2dh。

      第一次all-to-all

      在不做overlap的分支里,第一次all-to-all的內部如下

      # 根據前文所講,reshaped_input的形狀是(max(e, w), -1, h)
      reshaped_input = input
      # 異步的調用pytorch的all_to_all_single函數,執行all-to-all
      output, f_wait = simple_all_to_all(reshaped_input, group, background=True)
      # all_to_all_single的接口如下
      #   def all_to_all_single(
      #       output,
      #       input,
      #       output_split_sizes=None,
      #       input_split_sizes=None,
      #       group=group.WORLD,
      #   ):
          """
          Each process splits input tensor and then scatters the split list to all processes in a group.
      
          Then concatenate the received tensors from all the processes in the group and return single output tensor.
      
          Arguments:
              output (Tensor): Gathered concatenated output tensor.
              input (Tensor): Input tensor to scatter.
              output_split_sizes: (list[Int], optional): Output split sizes for dim 0
                  if specified None or empty, dim 0 of ``output`` tensor must divide
                  equally by ``world_size``.
              input_split_sizes: (list[Int], optional): Input split sizes for dim 0
                  if specified None or empty, dim 0 of ``input`` tensor must divide
                  equally by ``world_size``.
      
          Returns:
              Tensor: Output of the collective.
      
          """
      # 在Tutel調用all_to_all_single時,并沒有指定output_split_sizes和input_split_sizes
      # 也就是說all-to-all是直接按照第0個維度進行的
      # 在all_to_all_single之后,output的形狀還是(max(e, w), -1, h),
      # 只不過第0維的意義不再是``發往哪個expert/GPU'',而是``來自哪個expert/GPU''
      
      def f_async():
          # 等待all-to-all結束
          f_wait()
          # local_input = output
          local_input = RestoreBackward.apply(output, reshaped_input)
          # 將local_input的形狀變為(w, num_local_experts, -1, h)
          # 其中num_local_experts=max(e/w, 1)
          local_input = local_input.view([world_size, -1] + list(local_input.shape[1:]))
          # 這里input_dim=1
          # 作用是將local_input的前兩個維度對調,形狀變為(num_local_experts, w, -1, h)
          local_input = local_input.permute(list(range(1, input_dim + 1)) + [0] + list(range(input_dim + 1, local_input.dim())))
          # 將local_input的形狀變為(num_local_experts, -1, h)
          local_input = local_input.contiguous().view(list(local_input.shape[:input_dim]) + [-1] + list(local_input.shape[input_dim + 2:]))
          return local_input
      
      # 異步執行上述過程
      return f_async
      

      Expert

      expert_local內容如下

      def expert_local(self, x, reserve_shape):
          # 輸入形狀為(num_local_experts, -1, h)
          y = self.experts(x.view(x.size(0), x.size(1), *reserve_shape), self)
          self.protected_shape = y.shape
          # 輸出形狀為(num_local_experts, -1, h)
          return y.reshape(y.size(0), y.size(1), -1)
      

      我們一會再看expert內的細節,先繼續往下。

      第二次all-to-all

      緊接著是第二個all-to-all

      # 輸入形狀為(num_local_experts, -1, h)
      # 這里output_dim=1
      # reshaped_input形狀為(num_local_experts, w, -1, h)
      reshaped_input = input.view(list(input.shape[:output_dim]) + [world_size, -1] + list(input.shape[output_dim + 1:]))
      # 將reshaped_input前兩個維度對調,形狀變為(w, num_local_experts, -1, h)
      reshaped_input = reshaped_input.permute([output_dim] + list(range(output_dim)) + list(range(output_dim + 1, reshaped_input.dim())))
      # 進行simple_all_to_all
      output, f_wait = simple_all_to_all(reshaped_input, group, background=True)
      
      def f_async():
          f_wait()
          local_input = RestoreBackward.apply(output, reshaped_input)
          # 將local_input形狀變為(max(w, e), -1, h)
          local_input = local_input.view([-1] + list(local_input.shape[2:]))
          return local_input
      
      return f_async
      

      Decode

      再回到forward函數,最后是decode

      def forward(self, input: Tensor, ...):
          # 在[all-to-all, FFN, all-to-all]之后
          
          # 使用kernel進行decode
          # decode之后y的形狀為(-1, h)
          y = fast_decode(y.to(logits_dtype), crit, self.is_postscore)
      
          # 將輸出的形狀變得與最開始輸入的形狀相同
          y = y.view(list(original_shape[:-reserve_dims]) + list(self.protected_shape[-reserve_dims:])).to(original_dtype)
          self.l_aux = y.l_aux = l_aux
          return self.result_func(y) if self.result_func is not None else y
      

      到目前為止,MOELayer的流程我們已經走下來了,其中每個tensor的形狀我們也都大致了解了。不過,我們還沒有講兩個東西:

      • expert_local中,expert究竟是如何執行的。因為Tutel使用了ZeRO-DP,所以在前向傳播時,要對expert參數進行all-gather。這部分的邏輯不在MOELayer中,而是在expert內部。
      • a2a_ffn_overlap_forward是如何實現PP的。

      Expert內部邏輯

      Tutel默認的expert是FusedExpertsNetwork,我們簡單看一下它的結構

      class FusedExpertsNetwork(torch.nn.Module):
          def __init__(self, model_dim, hidden_size_per_expert, num_experts_per_device, sharded_count, activation_fn=None, activation_fn_with_self=None, output_dim=None, has_fc1_bias=True, has_fc2_bias=True):
              # 模型為兩層FFN,其中沿著中間的隱藏層切成了sharded_count塊
              self.hidden_size = hidden_size_per_expert // sharded_count
              self.batched_fc1_w = torch.nn.Parameter(torch.empty(num_experts_per_device, self.hidden_size, model_dim))
              self.batched_fc2_w = torch.nn.Parameter(torch.empty(num_experts_per_device, self.hidden_size, self.output_dim))
              
              for i in range(self.batched_fc1_w.size(0)):
                  fc1 = torch.nn.Linear(self.model_dim, self.hidden_size)
                  fc2 = torch.nn.Linear(self.hidden_size, self.output_dim)
                  self.batched_fc1_w[i] = fc1.weight
                  self.batched_fc1_bias[i] = fc1.bias
                  self.batched_fc2_w[i] = fc2.weight.t()
              	# 這里注意batched_fc2_bias作用于最終的輸出,它的切塊是沿著輸出維度切的
                  self.batched_fc2_bias[i] = fc2.bias[:((self.output_dim + sharded_count - 1) // sharded_count)]
      

      然后來看它的forward函數

      def forward(self, x, ctx):
          # 輸入的x形狀為(num_local_experts, -1, h)
      
          if ctx.adaptive_degree == 0:
              # 如果只有DP,那就是對所有的GPU進行all_gather
              # num_local_experts就是num_global_experts
              # zero_gather內部調用pytorch的all_gather,獲取得到完整的fc1_w
              batched_fc1_w = net.zero_gather(batched_fc1_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc1_w.size(2))
              batched_fc2_w = net.zero_gather(batched_fc2_w, group=ctx.group).view(ctx.num_global_experts, -1, batched_fc2_w.size(2))
              if self.batched_fc1_bias is not None:
                  batched_fc1_bias = net.zero_gather(batched_fc1_bias, group=ctx.group).view(ctx.num_global_experts, 1, -1)
              if self.batched_fc2_bias is not None:
                  batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=ctx.group).view(ctx.num_global_experts, 1, -1)
          else:
              # 否則,DP+EP+MP
              if ctx.sharded_count > 1:
                  # 如果expert被切塊了(因為ZeRO-DP或者MP)
                  mesh_size = net.get_world_size(ctx.group)
                  if mesh_size > 1 and mesh_size < net.get_world_size():
                      ctx.adaptive_degree = ctx.sharded_count
                  group_size = ctx.sharded_count // ctx.adaptive_degree
      
                  if group_size > 1:
                      # expert因為DP而被切塊,則這些塊需要進行all-gather
                      # 在前面圖中的MP Group內部進行all-gather
                      # 即合并E^{alpha0}和E^{alpha1}得到E^{alpha}
                      ffn_zero_group = net.create_groups_from_world(group_count=-group_size, parent_group=ctx.group).model_group
                      batched_fc1_w = net.zero_gather(batched_fc1_w, group=ffn_zero_group).view(1, -1, ctx.model_dim)
                      batched_fc2_w = net.zero_gather(batched_fc2_w, group=ffn_zero_group).view(1, -1, self.output_dim)
                      if self.batched_fc1_bias is not None:
                          batched_fc1_bias = net.zero_gather(batched_fc1_bias, group=ffn_zero_group).view(1, 1, -1)
      
                  if self.batched_fc2_bias is not None:
                      # fc_bias2也要all-gather
                      # 在前面圖中的EP Group內部進行all-gather
                      # 得到的是完整的fc_bias2
                      batched_fc2_bias = net.zero_gather(batched_fc2_bias, group=net.create_groups_from_world(group_count=ctx.num_global_experts, parent_group=ctx.group).model_group)
                      batched_fc2_bias = batched_fc2_bias.view(1, 1, -1)
      
                      # 因為數據被復制了r份,所以fc2_bias也被加了r次
                      # 為了使最后累加的y不變,這里把fc2_bias除以r
                      if ctx.adaptive_degree > 1:
                          batched_fc2_bias = torch.mul(batched_fc2_bias, 1.0 / ctx.adaptive_degree)
      
          # 進行計算
          y = torch.matmul(x, batched_fc1_w.permute(0, 2, 1))
          if self.batched_fc1_bias is not None:
              y = torch.add(y, batched_fc1_bias)
          y = self.activation_fn(y)
          y = torch.matmul(y, batched_fc2_w)
          if self.batched_fc2_bias is not None:
              y = torch.add(y, batched_fc2_bias)
      

      All-gather的部分最好參照著前面的圖理解一下。

      PP實現

      最后我們看a2a_ffn_overlap_forward函數。為了方便我們依然不考慮2dh。

      def a2a_ffn_overlap_forward(input, expert_fn, a2a_ffn_overlap_degree, use_2dh, group):
          # input的形狀為(max(e,w), -1, h)
          split_dim = 1
          # 設置AllToAllStatus.num_split = a2a_ffn_overlap_degree
          # 為了方便,我們記a2a_ffn_overlap_degree=p
          # 即PP沿著input的維度1切分,切成p塊
          # init會初始化nccl環境
          C.AllToAllStatus.init(group, a2a_ffn_overlap_degree, split_dim)
          
          # 首先等待之前的計算任務完成
          # num_slices_per_split是input的第一維大小,即max(e, w)
          # num_slices_per_split = input.shape[:split_dim].numel()
          # length = input.nbytes();
          # num_slices = num_slices_per_split * num_split;
          # slice_size = length / num_slices;
          input_ready = C.CurrentStreamRelease.apply(input, 0)
          # 對于每一個PP塊進行all-to-all操作
          # 具體來說,all-to-all是對每個expert都進行send-recv實現的
          # 在每一個PP塊的操作后,都向cuda流中插入一個事件,用來檢測這個塊是否完成
          input_scattered_after_a2a = C.AllToAllScatterAsync.apply(input_ready)
          
          # 這段要從下往上看
          expert_output_scattered = [
              # 再插入一個事件i
              C.CurrentStreamRelease.apply(
                  # 跟pre_expert_permute相反
                  # 最終x的形狀是(max(e, w), -1/p, h)
                  C.post_expert_permute(
                      # expert函數,前邊已經講過了
                      expert_fn(
                          # 跟前面的permute作用相同
                          # 先把x拆成(w, max(e/w, 1), -1/p, h)
                          # 再把前兩位對調(max(e/w), w, -1/p, h)
                          # 再把x變成(max(e/w), w*-1/p, h)
                          C.pre_expert_permute(
                              # 等待第i個事件完成
                              C.CurrentStreamAcquire.apply(
                                  x,
                              i),
                          group=group)
                      ),
                  group=group),
              i)
              # 枚舉每一個PP塊
              # 其中x是第i個塊的輸出,形狀為(max(e,w), -1/p, h),即把input的維度1切成/p塊
              for i, x in enumerate(input_scattered_after_a2a)
          ]
          
          # 對于每個PP塊,等待第i個事件,然后使用send-recv進行all-to-all
          expert_output_gathered_after_a2a = C.AllToAllGatherAsync.apply(*expert_output_scattered)
          # 等待所有all-to-all完成
          input = C.CurrentStreamAcquire.apply(expert_output_gathered_after_a2a, 0)
          
          return input
      

      backward跟forward流程差不多,這里就不講了。

      posted @ 2025-02-14 00:12  CQzhangyu  閱讀(582)  評論(0)    收藏  舉報
      主站蜘蛛池模板: 国产盗摄xxxx视频xxxx| 亚洲一区中文字幕人妻| 亚洲精品欧美综合二区| 国产在线精品国偷产拍| 耒阳市| 亚洲av鲁丝一区二区三区黄| 99午夜精品亚洲一区二区| 亚洲熟妇少妇任你躁在线观看无码| 亚洲老熟女一区二区三区| 久久一级黄色大片免费观看| 中文字幕日韩视频欧美一区| 奉节县| 亚洲欧美国产精品久久久久久久| 青青草无码免费一二三区| 男人扒开添女人下部免费视频| 日韩精品 在线 国产 丝袜| 亚洲国产美女精品久久久| 免费人成自慰网站| 久久综合免费一区二区三区| 国产日韩一区二区四季| 中文字幕亚洲无线码A| 九九热在线免费视频精品| 在线观看美女网站大全免费| 亚洲成av人最新无码不卡短片| 和顺县| 男女性杂交内射女bbwxz| 蜜桃av无码免费看永久| 91香蕉国产亚洲一二三区| 韩国三级+mp4| 国产自国产自愉自愉免费24区| 精品国产成人国产在线视| 成人一区二区不卡国产| 又色又爽又黄的视频网站| 视频一区二区三区刚刚碰| 护士的小嫩嫩好紧好爽| 精品人妻少妇嫩草av专区| 最近中文字幕日韩有码| 欧美精品一区二区在线观看播放| 日韩乱码卡一卡2卡三卡四| 亚洲av一本二本三本| 97一期涩涩97片久久久久久久 |