
      [Source Code Analysis] Deep Learning Pipeline Parallelism: PipeDream (5) --- Communication Module

      0x00 Abstract

      In previous articles we covered PipeDream's overall architecture, the profiling phase, the compute partitioning phase, the model conversion phase, and the runtime engine. In this article we look at PipeDream's communication module, which is the foundation of the runtime engine and also a showcase of how PyTorch DDP and P2P communication can be used.

      Links to the other articles in this pipeline-parallelism series:

      [Source Code Analysis] Deep Learning Pipeline Parallelism: GPipe (1) --- Basic Pipeline Implementation

      [Source Code Analysis] Deep Learning Pipeline Parallelism: GPipe (2) --- Gradient Accumulation

      [Source Code Analysis] Deep Learning Pipeline Parallelism: GPipe (3) --- Recomputation

      [Source Code Analysis] Deep Learning Pipeline Parallelism: PipeDream (1) --- Profiling Phase

      [Source Code Analysis] Deep Learning Pipeline Parallelism: PipeDream (2) --- Compute Partitioning

      [Source Code Analysis] Deep Learning Pipeline Parallelism: PipeDream (3) --- Model Conversion

      [Source Code Analysis] Deep Learning Pipeline Parallelism: PipeDream (4) --- Runtime Engine

      0x01 Preface

      The communication module lives in runtime/communication.py. Before diving in, let's think about what functionality a communication module needs:

      • Communication between stages (Stage): how should it be handled when the stages are on different machines, and when they are on the same machine?

      • Since communication is mostly asynchronous and nodes may differ in speed, is a buffering mechanism needed to coordinate them, similar to backpressure?

      • Deep learning involves many parameters, tensors, and gradients, many layers, and a possibly different data-parallel degree per layer, so how can the forward and backward passes be guaranteed to run in a deterministic order?

      • Each node runs both forward and backward propagation, so multiple threads are needed to handle the transfers separately.

      We will keep these questions in mind during the analysis below.

      0x02 Class Definition

      CommunicationHandler is responsible for communication between stages (Stage):

      • If the stages are on different machines, PyTorch p2p send/recv is used.
      • If the stages are on the same machine, PyTorch broadcast (over a two-rank process group) is used.
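
      The two primitives behave quite differently: send/recv is point-to-point between exactly two ranks, while broadcast is a collective over a process group (here always a group of two ranks). Below is a minimal, self-contained sketch of both calls, independent of PipeDream's own code; the gloo backend, the port, and the two-process setup are assumptions made purely for illustration.

      import os
      import torch
      import torch.distributed as dist
      import torch.multiprocessing as mp

      def worker(rank, world_size):
          # Minimal distributed setup (gloo backend, localhost rendezvous).
          os.environ["MASTER_ADDR"] = "127.0.0.1"
          os.environ["MASTER_PORT"] = "29500"
          dist.init_process_group("gloo", rank=rank, world_size=world_size)

          # Point-to-point: rank 0 sends, rank 1 receives; the tag disambiguates tensors.
          if rank == 0:
              dist.send(torch.arange(4.0), dst=1, tag=7)
          else:
              buf = torch.zeros(4)
              dist.recv(buf, src=0, tag=7)

          # Collective: broadcast inside a dedicated two-rank process group.
          group = dist.new_group(ranks=[0, 1])
          b = torch.full((4,), 42.0) if rank == 0 else torch.zeros(4)
          dist.broadcast(b, src=0, group=group)   # both ranks end up with rank 0's data

          dist.destroy_process_group()

      if __name__ == "__main__":
          mp.spawn(worker, args=(2,), nprocs=2)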

      The code below mostly initializes member variables; the most familiar part is the DDP-related setup, such as init_process_group.

      class CommunicationHandler(object):
          """ Handles communication between stages.
      
          For stages on different machines, use send/recv.
          For stages on same machine, use broadcast.
          """
          def __init__(self, master_addr, master_port, rank,
                       local_rank, num_ranks_in_server,
                       world_size, fp16, backend):
              """ Set up process groups.
      
              Note: To turn off broadcasting, set num_ranks_in_server = 1.
              """
              self.rank = rank
              self.local_rank = local_rank
              self.backend = backend
              self.num_ranks_in_server = num_ranks_in_server
              self.world_size = world_size
              self.fp16 = fp16
              assert num_ranks_in_server > 0
      
              # Initialize the distributed environment.
              # the following sets up the distributed environment for DDP
              os.environ['MASTER_ADDR'] = master_addr
              os.environ['MASTER_PORT'] = str(master_port)
              dist.init_process_group(backend, rank=rank, world_size=world_size)
              assert dist.get_world_size() == self.world_size
      
              # Stores list of ranks of GPUs on the same server.
              self.ranks_in_server = []
      
              if num_ranks_in_server == 1:
                  return
      
              # Stores information about tensors sent directly GPU-to-GPU.
              self.connection_list = []
      
              # Stores process groups (for broadcast() connections).
              self.process_groups = {}
      
              # Populate ranks_in_server.
              rank_of_first_gpu_in_server = rank - rank % num_ranks_in_server
              for connected_rank in range(
                  rank_of_first_gpu_in_server,
                  rank_of_first_gpu_in_server + num_ranks_in_server):
                  if connected_rank == rank:
                      continue
                  self.ranks_in_server.append(connected_rank)
              assert len(self.ranks_in_server) == num_ranks_in_server - 1, \
                  self.ranks_in_server
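
      A quick worked example of the ranks_in_server computation above (the numbers are purely illustrative): with rank = 5 and num_ranks_in_server = 4, the first GPU in this server has rank 5 - 5 % 4 = 4, so ranks_in_server becomes [4, 6, 7].

      # Standalone sketch of the ranks_in_server logic, with assumed values.
      rank, num_ranks_in_server = 5, 4
      rank_of_first_gpu_in_server = rank - rank % num_ranks_in_server   # -> 4
      ranks_in_server = [r for r in range(rank_of_first_gpu_in_server,
                                          rank_of_first_gpu_in_server + num_ranks_in_server)
                         if r != rank]
      print(ranks_in_server)                                            # [4, 6, 7]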
      
      

      0x03 Construction

      3.1 Initialization

      As mentioned in the previous article, once the CommunicationHandler has been created, initialize is called on it.

              if self.comm_handler is not None:
                  self.comm_handler.initialize(
                      self.receive_ranks,
                      self.send_ranks,
                      self.tensor_tags,
                      self.target_tensor_names,
                      self.training_tensor_dtypes,
                      self.rank_in_stage,
                      self.num_ranks_in_stage,
                      self.ranks_in_previous_stage,
                      self.ranks_in_next_stage)
      
      

      The initialization code does the following:

      • Build the queues needed for communication.
      • Build the order in which messages are sent.
      • Build the process groups.
          def initialize(self, receive_ranks, send_ranks,
                         tensor_tags, target_tensor_names,
                         training_tensor_dtypes,
                         rank_in_stage,
                         num_ranks_in_stage,
                         ranks_in_previous_stage,
                         ranks_in_next_stage):
              """
              Initialize state needed for CommunicationHandler.
              """
              self.receive_ranks = receive_ranks
              self.send_ranks = send_ranks
              self.tensor_tags = tensor_tags
              self.target_tensor_names = target_tensor_names
              self.training_tensor_dtypes = training_tensor_dtypes
              self.rank_in_stage = rank_in_stage
              self.num_ranks_in_stage = num_ranks_in_stage
              self.ranks_in_previous_stage = ranks_in_previous_stage
              self.num_ranks_in_previous_stage = len(ranks_in_previous_stage)
              self.ranks_in_next_stage = ranks_in_next_stage
              self.num_ranks_in_next_stage = len(ranks_in_next_stage)
      
              self.setup_queues() # build the queues needed for communication
              self.setup_messaging_schedule() # build the message ordering
              self.create_process_groups() # build the process groups
      

      We analyze each step in detail below.

      3.2 Creating Queues

      Queues serve as the foundation for send and receive: the system locates the right queue by index and then operates on it.

      The initialize function receives two lists of ranks:

      • receive_ranks are the input ranks of this node.
      • send_ranks are the output ranks of this node.

      An example of such a rank list:

      receive_ranks = {dict: 3}  # the rank(s) from which each tensor is received
       'out8' = {list: 1} [2]    # out8 is a tensor name; {list: 1} [2] is the rank list for out8
       'out9' = {list: 1} [2]    # i.e. these tensors are all received from rank 2
       'out10' = {list: 1} [2]
       __len__ = {int} 3
      

      setup_queues builds four groups of queues accordingly:

      • forward_receive_queues: queues for receiving tensors in the forward pass; corresponds to receive_ranks.
      • backward_send_queues: queues for sending tensors in the backward pass; corresponds to receive_ranks, because whatever is received from in the forward pass is exactly what is sent to in the backward pass.
      • forward_send_queues: queues for sending tensors in the forward pass; corresponds to send_ranks.
      • backward_receive_queues: queues for receiving tensors in the backward pass; corresponds to send_ranks, because whatever is sent to in the forward pass is exactly what is received from in the backward pass.

      Roughly:

      forward_receive_queues <-----> receive_ranks <------->  backward_send_queues
      forward_send_queues  <------>  send_ranks    <------->  backward_receive_queues


      Take forward_receive_queues as an example:

      • forward_receive_queues contains multiple queues.
      • receive_ranks contains multiple ranks; during communication each entry corresponds to one tensor, so receive_ranks can be viewed as a set of tensors keyed by tensor name. Tensor names look like: target_tensor_names = {"target", "target_length"}.
      • Each queue in forward_receive_queues corresponds to one tensor in receive_ranks.
      • Each tensor has a unique tag. PipeDream's goal is to give every tag its own process group, because any stage may be replicated in parallel.
      • For this tensor and its unique tag, [tag, rank] is registered into connection_list.

      The code is as follows:

          def setup_queues(self):
              """
              Setup queues for communication between main compute thread
              and helper communication threads. One queue per tensor
              in forward / backward direction.
              """
              self.forward_receive_queues = {}
              self.backward_receive_queues = {}
              self.forward_send_queues = {}
              self.backward_send_queues = {}
              self.num_forward_threads = 0
              self.num_backward_threads = 0
      
              self.target_receive_rank_counts = {}
              self.target_send_rank_counts = {}
              # Setup queues for each tensor to be received and sent.
              for input_name in self.receive_ranks: # iterate over tensors
                  # queues corresponding to tensor input_name (input_name is the tensor name)
                  self.forward_receive_queues[input_name] = []
                  self.backward_send_queues[input_name] = []
                  # iterate over every rank for this tensor
                  for i in range(len(self.receive_ranks[input_name])):
                      self.forward_receive_queues[input_name].append(
                          threadsafe_queue.Queue())
                      self.backward_send_queues[input_name].append(
                          threadsafe_queue.Queue())
                      # get the rank
                      target_receive_rank = self.receive_ranks[input_name][i]
                      # register the tensor for this rank
                      self.register_tensor(
                          connected_rank=target_receive_rank,
                          tag=self.tensor_tags[input_name])
                      if target_receive_rank not in self.target_receive_rank_counts:
                          self.target_receive_rank_counts[target_receive_rank] = 0
                      self.target_receive_rank_counts[target_receive_rank] += 1
                      self.num_forward_threads += 1
                      self.num_backward_threads += 1
                      
              for output_name in self.send_ranks: # iterate over tensors
                  # queues corresponding to tensor output_name
                  self.backward_receive_queues[output_name] = []
                  self.forward_send_queues[output_name] = []
                  # iterate over every rank for this tensor
                  for i in range(len(self.send_ranks[output_name])):
                      self.backward_receive_queues[output_name].append(
                          threadsafe_queue.Queue())
                      self.forward_send_queues[output_name].append(
                          threadsafe_queue.Queue())
                      # get the rank
                      target_send_rank = self.send_ranks[output_name][i]
                      # register the tensor for this rank
                      self.register_tensor(
                          connected_rank=target_send_rank,
                          tag=self.tensor_tags[output_name])
                      if target_send_rank not in self.target_send_rank_counts:
                          self.target_send_rank_counts[target_send_rank] = 0
                      self.target_send_rank_counts[target_send_rank] += 1
                      self.num_forward_threads += 1
                      self.num_backward_threads += 1
      
              # handle the target tensors separately
              for target_tensor_name in self.target_tensor_names:
                  # Queues for target in forward pass.
                  self.forward_receive_queues[target_tensor_name] = []
                  self.forward_send_queues[target_tensor_name] = []
      
                  if self.num_ranks_in_previous_stage > 0:
                      self.receive_ranks[target_tensor_name] = self.ranks_in_previous_stage
                      for i in range(len(self.receive_ranks[target_tensor_name])):
                          # register the tensor for this rank
                          self.register_tensor(
                              connected_rank=self.receive_ranks[target_tensor_name][i],
                              tag=self.tensor_tags[target_tensor_name])
                          self.forward_receive_queues[target_tensor_name].append(
                              threadsafe_queue.Queue())
                          self.num_forward_threads += 1
      
                  if self.num_ranks_in_next_stage > 0:
                      self.send_ranks[target_tensor_name] = self.ranks_in_next_stage
                      for i in range(len(self.send_ranks[target_tensor_name])):
                          self.register_tensor(
                              connected_rank=self.send_ranks[target_tensor_name][i],
                              tag=self.tensor_tags[target_tensor_name])
                          self.forward_send_queues[target_tensor_name].append(
                              threadsafe_queue.Queue())
                          self.num_forward_threads += 1
      
              print ("Send ranks: ", self.send_ranks)
              print ("Receive ranks: ", self.receive_ranks)
      
              # Queues for ack for forward pass-only runs as a clocking mechanism.
              # handle the ack case separately
              self.num_ack_threads = 0
              if "ack" in self.tensor_tags:
                  self.backward_receive_queues["ack"] = []
                  self.backward_send_queues["ack"] = []
                  for i in range(self.num_ranks_in_previous_stage):
                      # register the tensor for this rank
                      self.register_tensor(
                          connected_rank=self.ranks_in_previous_stage[i],
                          tag=self.tensor_tags["ack"])
                      self.backward_send_queues["ack"].append(
                          threadsafe_queue.Queue())
                      self.num_ack_threads += 1
                  for i in range(self.num_ranks_in_next_stage):
                      # register the tensor for this rank
                      self.register_tensor(
                          connected_rank=self.ranks_in_next_stage[i],
                          tag=self.tensor_tags["ack"])
                      self.backward_receive_queues["ack"].append(
                          threadsafe_queue.Queue())
                      self.num_ack_threads += 1
      

      Note that each tensor has exactly one unique tag; for the tensor and its unique tag, [tag, rank] is registered into connection_list.

          def register_tensor(self, connected_rank, tag):
              """
              Builds connections list of tensors that are communicated GPU to GPU.
      
              For tensors that are sent GPU-to-GPU (intra-server for GLOO backend),
              make a list of destination/source ranks and the corresponding tag.
              This information is then used to create process groups.
              """
              if not self.is_gpu_to_gpu_comm(connected_rank=connected_rank):
                  return
              connection_info = [tag, connected_rank]
              self.connection_list.append(connection_info)
      
      

      The logic at this point is as follows (only some of the ranks and queues are shown as examples); the queues in forward_receive_queues serve as buffers for the corresponding tensors.

      +------------------------+         'out8' = {list: 1} [2]
      |                        |
      |     receive_ranks +----------->  'out9' = {list: 1} [2]
      |                        |
      +------------------------+         'out10' = {list: 1} [2]
      
      
      
      +--------------------------+
      |                          |         'out8' : Queue
      | forward_receive_queues+-------->
      |                          |         'out9' : Queue
      +--------------------------+
                                           'out10' : Queue
      
      
      
      
      +--------------------------+       'out8' : rank 2
      |                          |
      |    connection_list  +--------->  'out9' : rank 2
      |                          |
      +--------------------------+       'out10' : rank 2
      
      

      3.3 Forward/Backward Ordering

      Next, the ordering for forward and backward message passing is established, so that each worker records how to handle the ranks of its previous and next stages.

      3.3.1 Building the Order

      The setup_messaging_schedule method builds:

      • the order in which to receive in the forward pass;
      • the order in which to send in the backward pass.

      The key point is: if the previous stage has more ranks than this stage, then for local rank i, the previous-stage ranks i, i + num_ranks_in_stage, i + 2 * num_ranks_in_stage, ... are all added to rank i's plan (message_schedule).

      The resulting order is stored in the member variable self.messaging_schedule. If this stage has 3 ranks, then self.messaging_schedule holds the message_schedule of each of these 3 ranks, and every message_schedule contains the corresponding ranks of the previous stage.

      In more detail:

      • self.messaging_schedule is a list.
      • Each item of self.messaging_schedule is itself a list: self.messaging_schedule[i] is the schedule (message_schedule) of the i-th rank of this stage.
      • A schedule (message_schedule) consists of certain ranks of the previous (or next) stage.
      • The ranks contained in a message_schedule are indexes into the connected stage's ranks rather than actual global rank values; since they are used only internally, they just need to map onto the internal queues and other internal data structures.

      The code is as follows:

          def setup_messaging_schedule(self):
              """ Order in which to receive forward and send backwards.
      
              Separate indexes of ranks in previous stage based on their
              corresponding offset in this stage. Then each worker will go
              in increasing order within a subset, and process subsets in
              a decreasing order.
      
              This is done so that messages are processed in the order
              that they are sent. Backwards send is done so that that it
              matches up with forward receive.
              """
              self.messaging_schedule = []
              for i in range(self.num_ranks_in_stage): # number of ranks in this stage
                  idx = i
                  message_schedule = []
                  while idx < self.num_ranks_in_previous_stage: # number of ranks in the previous stage
                      message_schedule.append(idx)
                      # if the previous stage has more ranks than this stage, add i, i + num_ranks_in_stage, i + 2 * num_ranks_in_stage, ... to rank i's schedule
                      idx += self.num_ranks_in_stage
                  if len(message_schedule) > 0:
                      self.messaging_schedule.append(message_schedule)
      
              self.fwd_messaging_scheduling_row = self.rank_in_stage # this rank's own index
              self.fwd_messaging_scheduling_col = 0 # receive forward
              self.bwd_messaging_scheduling_row = self.rank_in_stage # this rank's own index
              self.bwd_messaging_scheduling_col = 0 # send backwards
      
              # For cases where previous stage has less workers than current stage.
              while self.fwd_messaging_scheduling_row >= \
                  len(self.messaging_schedule):
                  self.fwd_messaging_scheduling_row -= 1
                  self.bwd_messaging_scheduling_row -= 1
      

      The logic is as follows:

      +-------------------+                 +--------------------------------------------------+
      | Stage 0           |                 | Stage 1                                          |
      |                   |                 |                                                  |
      |                   |                 |                                                  |
      |                   |                 |     +----------------------------------------+   |
      |                   |   send_ranks    |     | messaging_schedule                     |   |
      |  ranks:           |                 |     |                                        |   |
      |                   +---------------> |     |                                        |   |
      |  [0,1,2,3,4,5,    |                 |     |   message_schedule +---> [0,3,6,9,12]  |   |
      |  6,7,8,9,10,11,12]|                 |     |                                        |   |
      |                   |                 |     |   message_schedule +---> [1,4,7,10]    |   |
      |                   |                 |     |                                        |   |
      |                   |                 |     |   message_schedule +---> [2,5,8,11]    |   |
      |                   |                 |     |                                        |   |
      |                   |                 |     +----------------------------------------+   |
      |                   |                 |                                                  |
      +-------------------+                 +--------------------------------------------------+
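
      To confirm this ordering, here is a standalone sketch of the same loop as in setup_messaging_schedule, with the stage sizes from the diagram above assumed (13 ranks in the previous stage, 3 in this one):

      # Reproduces the scheduling loop with assumed stage sizes.
      num_ranks_in_previous_stage = 13
      num_ranks_in_stage = 3

      messaging_schedule = []
      for i in range(num_ranks_in_stage):
          idx = i
          message_schedule = []
          while idx < num_ranks_in_previous_stage:
              message_schedule.append(idx)
              idx += num_ranks_in_stage
          if message_schedule:
              messaging_schedule.append(message_schedule)

      print(messaging_schedule)
      # [[0, 3, 6, 9, 12], [1, 4, 7, 10], [2, 5, 8, 11]]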
      
      

      3.3.2 Getting the Message Index

      The get_messaging_index method returns the peer of the current transfer, i.e., which rank to interact with this time.

          def get_messaging_index(self, sending):
              if sending:
                  connection_rank = self.messaging_schedule[
                      self.bwd_messaging_scheduling_row][
                          self.bwd_messaging_scheduling_col]
              else:
                  connection_rank = self.messaging_schedule[
                      self.fwd_messaging_scheduling_row][
                          self.fwd_messaging_scheduling_col]
      
              return connection_rank
      
      

      Where is get_messaging_index used? In the send and recv functions, whenever this stage interacts with the previous stage.

      For example:

          def recv(self, tensor_name, forward_minibatch_id,
                   backward_minibatch_id, backward=False):
              if backward:
                  index = (backward_minibatch_id + self.rank_in_stage) % \
                      len(self.backward_receive_queues[tensor_name])
                  tensor = self.backward_receive_queues[tensor_name][
                      index].remove()
                  return tensor
              else:
                  # used here: figure out which rank to interact with
                  index = self.get_messaging_index(sending=False)
                  # then fetch the latest tensor for this tensor name from the corresponding queue
                  tensor = self.forward_receive_queues[tensor_name][
                      index].remove()
                  if tensor.dtype == torch.float32:
                      tensor = tensor.requires_grad_()
                  return tensor
      

      3.3.3 Advancing the Message Index

      The increment_messaging_index method advances the message index, i.e., determines which message slot should be used next time.

      Two of the variables involved deserve explanation:

      • bwd_messaging_scheduling_col indicates which rank index of the upstream stage.

      • bwd_messaging_scheduling_row indicates this rank's own index.

      The method is as follows:

          def increment_messaging_index(self, sending):
              if sending:
                  self.bwd_messaging_scheduling_col += 1 # next rank for the backward send
                  if self.bwd_messaging_scheduling_col == len(
                          self.messaging_schedule[
                              self.bwd_messaging_scheduling_row]):
                      self.bwd_messaging_scheduling_col = 0
                      self.bwd_messaging_scheduling_row -= 1 # this rank's own index
                      if self.bwd_messaging_scheduling_row == -1:
                          # wrap around to the last row of self.messaging_schedule and start a new round
                          self.bwd_messaging_scheduling_row = \
                              len(self.messaging_schedule) - 1
              else:
                  self.fwd_messaging_scheduling_col += 1 # next rank for the forward receive
                  if self.fwd_messaging_scheduling_col == len(
                          self.messaging_schedule[
                              self.fwd_messaging_scheduling_row]): 
                      self.fwd_messaging_scheduling_col = 0
                      self.fwd_messaging_scheduling_row -= 1 # this rank's own index
                      if self.fwd_messaging_scheduling_row == -1:
                          # wrap around to the last row of self.messaging_schedule and start a new round
                          self.fwd_messaging_scheduling_row = \
                              len(self.messaging_schedule) - 1
      
      

      Where is it used? In the following functions:

          def receive_tensors_forward(self):
              if self.loader_iter is not None:
      			# ......
              else:
                  # Receive all required tensors from upstream machines.
      			# ......
                  # Used to track where to receive forward from.
                  self.comm_handler.increment_messaging_index(
                      sending=False)
      
          def send_tensors_backward(self):
              # Send all required gradients upstream.
      
              if self.num_ranks_in_previous_stage > 0:
                  # Used to track where to send tensors in the
                  # backward pass.
                  self.comm_handler.increment_messaging_index(
                      sending=True)    
                  
          def run_ack(self):
              if self.stage > 0:
                  self.comm_handler.send(
                      "ack",
                      torch.zeros(self.tensor_shapes["ack"],
                                  dtype=torch.int64).cuda(),
                      forward_minibatch_id=self.forward_minibatch_id,
                      backward_minibatch_id=self.backward_minibatch_id,
                      backward=True)
      
                  # Used to track where to receive forward from.
                  self.comm_handler.increment_messaging_index(sending=True)        
      

      3.4 Building Process Groups

      The goal: for every tensor, set up two process groups, one for the forward direction and one for the backward direction. Every tensor has its own tag, and every tag gets its own two process groups, because any stage may be replicated in parallel.

      3.4.1 Design

      First, let's read the docstring to understand why it is designed this way.

      The create_process_groups method creates process groups in the same order on all ranks. To achieve this, every worker collects the connection_list (GPU-to-GPU connections) of all other workers. To do so, every worker first gathers the largest connection_list size among all workers (L). Each worker then creates a tensor of size L x 2, where each row represents a connection, and fills this tensor according to the size of its own connection list; the worker(s) with the largest connection list fill the entire tensor.

      After this tensor is built, an all_gather is performed, after which every worker holds an identical N x L x 2 output, where N is the number of workers (world_size) and each index of the output represents one worker's connection list. For i = self.rank, the output equals this worker's local connection list.

      Every worker then iterates over the connection lists in the same order, checks whether each connection has already been created (every connection appears twice in the output), and creates a new process group for both the forward and backward directions if one does not exist yet. Since the ranks within a process group must always be listed identically, the smaller rank always comes first, followed by the larger rank.
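
      The core trick here, gathering variable-length lists in a deterministic order by padding them to the maximum length before an all_gather, can be sketched in isolation as follows. This is a minimal two-process gloo example, not PipeDream's code; the padding value -1, the port, and the list contents are assumptions made purely for illustration.

      import os
      import torch
      import torch.distributed as dist
      import torch.multiprocessing as mp

      def worker(rank, world_size):
          os.environ["MASTER_ADDR"] = "127.0.0.1"
          os.environ["MASTER_PORT"] = "29501"
          dist.init_process_group("gloo", rank=rank, world_size=world_size)

          # Each rank owns a connection list of a different length ([tag, rank] pairs).
          local = [[0, 1], [1, 1]] if rank == 0 else [[0, 0]]

          # 1) Agree on the maximum list length L.
          size = torch.tensor(len(local), dtype=torch.int)
          sizes = [torch.zeros_like(size) for _ in range(world_size)]
          dist.all_gather(sizes, size)
          max_size = int(max(s.item() for s in sizes))

          # 2) Pad the local list to L x 2 with -1 and all_gather it.
          padded = torch.full((max_size, 2), -1, dtype=torch.int)
          padded[:len(local)] = torch.tensor(local, dtype=torch.int)
          gathered = [torch.zeros_like(padded) for _ in range(world_size)]
          dist.all_gather(gathered, padded)

          # 3) Every rank now iterates over the identical N x L x 2 result in the
          #    same order, skipping padding rows, so process groups would be
          #    created in exactly the same order everywhere.
          for src, tensor in enumerate(gathered):
              for tag, dst in tensor.tolist():
                  if tag == -1:
                      continue
                  print(f"rank {rank}: connection src={src} dst={dst} tag={tag}")

          dist.destroy_process_group()

      if __name__ == "__main__":
          mp.spawn(worker, args=(2,), nprocs=2)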

      3.4.2 Code

      Back to the code; let's walk through it carefully.

      +--------------------------+       'out8' : rank 2
      |                          |
      |    connection_list  +--------->  'out9' : rank 2
      |                          |
      +--------------------------+       'out10' : rank 2
      

      This is where connection_list is used. The logic is:

      • Find the largest connection_list among the workers
      • Get the size of the local connection_list, i.e. connection_list_size
      • Aggregate connection_list_size with a collective; the resulting gathered_connection_list_sizes is the set of connection_list_size values from all nodes
      • Take the maximum connection_list size
      • Use the maximum size to build the tensor connection_list_tensor
      • Move the tensor onto the GPU
      • Aggregate connection_list_tensor with a collective, obtaining aggregated_connection_list
      • On every worker, build identical process groups with dist.new_group
      • Iterate over every connection in aggregated_connection_list
        • Get the tag of the corresponding tensor
        • For each tensor, set up two process groups, one forward and one backward

      The goal, therefore, is to build identical process groups on every worker: for each tensor, two process groups, one forward and one backward.

      The code is as follows:

          def create_process_groups(self):
              """ Create process groups in the same order across all ranks.
      
              To create process groups in the same order, each worker collects
              the connection_list of all other workers. To do this, every worker
              gathers the largest size of all other worker's connection_lists (L).
              Then every worker creates a tensor of size Lx2, where each row
              represents a connection, and fills up this tensor depending on how
              large its own connection list is. The worker(s) w/ the largest
              connection list will fill up the entire tensor.
      
              After constructing this list, an all_gather is performed, after which
              each worker has an identical NxLx2 output, where N is the number of
              workers (world_size), and each index of output represents a worker's
              connection list. For i=self.rank, the output will be identical to the
              workers local connection list.
      
              Each worker then iterates in the same order over the connections list,
              checking if each connection has been created yet (every connection will
              appear twice in the output), and creating a new process group if one
              doesn't exist for that connection, for both the forward and backward
              direction. Since ranks within process groups must always be identical,
              the smaller rank always goes first, followed by the larger rank.
              """
              if self.num_ranks_in_server == 1:
                  return
      
              print("Setting up process groups for broadcasts...")
      
              # Figure out the size of the largest connection list that any worker
              # has (L).
              # find the largest connection_list among the workers
              # get the size of the local connection_list, i.e. connection_list_size
              connection_list_size = torch.tensor(
                  len(self.connection_list), dtype=torch.int)
              if self.backend == NCCL:
                  connection_list_size = connection_list_size.cuda()
              gathered_connection_list_sizes = [
                  torch.ones_like(connection_list_size)
                  for _ in range(self.world_size)]
              
              # aggregate connection_list_size with a collective; gathered_connection_list_sizes holds the sizes from all nodes
              dist.all_gather(gathered_connection_list_sizes,
                              connection_list_size)
              # take the maximum size
              max_connection_list_size = max(
                  gathered_connection_list_sizes)
      
              if max_connection_list_size == 0:
                  return 
      
              # use the maximum size to build the tensor connection_list_tensor
              # Build tensor to send local connection list to all other workers.
              connection_list_tensor = torch.ones([max_connection_list_size, 2],
                                                  dtype=torch.int) * -1
              # move the tensor onto the GPU (NCCL backend)
              if self.backend == NCCL:
                  connection_list_tensor = connection_list_tensor.cuda()
              if len(self.connection_list) > 0:
                  connection_list_tensor[0:len(self.connection_list)] = \
                      torch.IntTensor(self.connection_list)
      
              # aggregate connection_list_tensor with a collective
              # Gather connection lists of all workers.
              aggregated_connection_list = [
                  torch.ones_like(connection_list_tensor)
                  for _ in range(self.world_size)]
              dist.all_gather(aggregated_connection_list,
                              connection_list_tensor)
      
              # build identical process groups on every worker with dist.new_group
              # Construct identical process groups on each worker.
              local_rank_connections = 0
              for src_rank in range(len(aggregated_connection_list)):
                  for connection in aggregated_connection_list[src_rank]:
                      # get the tag of this connection's tensor
                      tag = int(connection[0])
                      dst_rank = int(connection[1])
      
                      if tag == -1:
                          assert dst_rank == -1
                          continue
      
                      min_rank = min(src_rank, dst_rank)
                      max_rank = max(src_rank, dst_rank)
                      assert min_rank != max_rank
      
                      if min_rank not in self.process_groups:
                          self.process_groups[min_rank] = {}
      
                      if max_rank not in self.process_groups[min_rank]:
                          self.process_groups[min_rank][max_rank] = {}
      
                      if tag not in self.process_groups[min_rank][max_rank]:
                          # create the process groups via dist.new_group (used later for broadcast)
                          sub_process_group_fwd = dist.new_group(
                              ranks=[min_rank, max_rank])
                          sub_process_group_bwd = dist.new_group(
                              ranks=[min_rank, max_rank])
      
                          # set up the two process groups for this tensor
                          self.process_groups[min_rank][max_rank][tag] = {
                              'forward': sub_process_group_fwd,
                              'backward': sub_process_group_bwd
                          }
      
                          if min_rank == self.rank or max_rank == self.rank:
                              local_rank_connections += 1
              assert local_rank_connections == len(self.connection_list)
      

      How are these process groups used? They are used in functions such as recv_helper_thread_args, for example:

          def recv_helper_thread_args(self, tensor_name, index, dtype,
                                      backward, num_iterations):
              if backward:
                  src_rank = self.send_ranks[tensor_name][index]
              else:
                  src_rank = self.receive_ranks[tensor_name][index]
      
              sub_process_group = None
              # get the tag corresponding to tensor tensor_name
              tag = self.tensor_tags[tensor_name]
              if self.is_gpu_to_gpu_comm(connected_rank=src_rank) and tensor_name != "ack":
                  min_rank = min(self.rank, src_rank)
                  max_rank = max(self.rank, src_rank)
                  
                  if src_rank > self.rank:
                      # get the process group for this tag; the caller will use it later
                      sub_process_group = \
                          self.process_groups[min_rank][max_rank][tag]['backward']
                  else:
                      # get the process group for this tag; the caller will use it later
                      sub_process_group = \
                          self.process_groups[min_rank][max_rank][tag]['forward']
                  assert sub_process_group
      
              if backward:
                  queue = self.backward_receive_queues[tensor_name][index]
              else:
                  queue = self.forward_receive_queues[tensor_name][index]
              tensor_shape = self.tensor_shapes[tensor_name]
      
              return (queue, self.counter, self.local_rank, tensor_name,
                      src_rank, tag, tensor_shape, dtype, sub_process_group,
                      num_iterations)
      

      3.5 Starting Helper Threads

      start_helper_threads launches the helper threads, which exist for P2P communication.

      First, a rank-list example; the keys are tensor names and the values are lists of ranks.

      receive_ranks = {dict: 3}  # the rank(s) from which each tensor is received
       'out8' = {list: 1} [2]
       'out9' = {list: 1} [2]
       'out10' = {list: 1} [2]
       __len__ = {int} 3
      

      3.5.1 Creating the Threads

      Recall the four groups of queues built earlier:

      • forward_receive_queues: queues for receiving tensors in the forward pass; corresponds to receive_ranks.
      • backward_send_queues: queues for sending tensors in the backward pass; corresponds to receive_ranks, because whatever is received from in the forward pass is exactly what is sent to in the backward pass.
      • forward_send_queues: queues for sending tensors in the forward pass; corresponds to send_ranks.
      • backward_receive_queues: queues for receiving tensors in the backward pass; corresponds to send_ranks, because whatever is sent to in the forward pass is exactly what is received from in the backward pass.

      These four groups of queues correspond to four kinds of helper threads.

      The approach is:

      • Handle the receive side, i.e., iterate over the tensors in receive_ranks
        • Iterate over the ranks of each tensor; for every rank
          • If backward processing is needed, start a backward send thread
          • Start a receive helper thread
      • Handle the send side, i.e., iterate over the tensors in send_ranks
        • Iterate over the ranks of each tensor; for every rank
          • If backward processing is needed, start a backward receive thread
          • Start a send helper thread
      • Handle the target tensors
      • If running forward-only, set up the ack threads

      The code is:

          def start_helper_threads(self, num_iterations, forward_only):
              """
              Start helper communication threads, one for each queue.
              """
              if forward_only:
                  self.set_counter(self.num_forward_threads +
                                   self.num_ack_threads)
                  # For validation, receive acks in backward pass from next stage, send
                  # acks in backward pass to next stage.
                  self.receive_ranks["ack"] = self.ranks_in_previous_stage
                  self.send_ranks["ack"] = self.ranks_in_next_stage
              else:
                  self.set_counter(self.num_forward_threads +
                                   self.num_backward_threads)
                  if "ack" in self.receive_ranks:
                      del self.receive_ranks["ack"]
                  if "ack" in self.send_ranks:
                      del self.send_ranks["ack"]
      
              (num_iterations_for_forward_threads,
               num_iterations_for_backward_threads) = \
                  self.num_iterations_for_helper_threads(
                      num_iterations=num_iterations)
              dtype = torch.float16 if self.fp16 else torch.float32
      
              # Setup queues for each tensor to be received and sent.
              # handle the receive side
              for input_name in self.receive_ranks:
                  if input_name in self.target_tensor_names or input_name == "ack":
                      continue
      
                  # iterate over the ranks for this tensor
                  for i in range(len(self.receive_ranks[input_name])):
                      if not forward_only:
                          # backward processing is needed, so start a backward send thread
                          self.start_helper_thread(
                              self.send_helper_thread_args,
                              send_helper_thread,
                              [input_name, i, True],
                              num_iterations_for_backward_threads)
                      # start a forward receive helper thread
                      self.start_helper_thread(
                          self.recv_helper_thread_args,
                          recv_helper_thread,
                          [input_name,
                           i,
                           self.training_tensor_dtypes[input_name],
                           False],
                          num_iterations_for_backward_threads)
                   
              # handle the send side
              for output_name in self.send_ranks:
                  if output_name in self.target_tensor_names or output_name == "ack":
                      continue
      
                  # iterate over the ranks for this tensor
                  for i in range(len(self.send_ranks[output_name])):
                      if not forward_only:
                          # backward processing is needed, so start a backward receive thread
                          self.start_helper_thread(
                              self.recv_helper_thread_args,
                              recv_helper_thread,
                              [output_name, i,
                               self.training_tensor_dtypes[output_name],
                               True],
                              num_iterations_for_forward_threads)
                      # start a forward send helper thread
                      self.start_helper_thread(
                          self.send_helper_thread_args,
                          send_helper_thread,
                          [output_name, i, False],
                          num_iterations_for_forward_threads)
      
              # handle the target tensors
              for target_tensor_name in self.target_tensor_names:
                  if self.num_ranks_in_previous_stage > 0:
                      for i in range(len(self.receive_ranks[target_tensor_name])):
                          self.start_helper_thread(
                              self.recv_helper_thread_args,
                              recv_helper_thread,
                              [target_tensor_name, i, torch.int64,
                               False],
                              num_iterations_for_backward_threads)
      
                  if self.num_ranks_in_next_stage > 0:
                      for i in range(len(self.send_ranks[target_tensor_name])):
                          self.start_helper_thread(
                              self.send_helper_thread_args,
                              send_helper_thread,
                              [target_tensor_name, i, False],
                              num_iterations_for_forward_threads)
      
              # Start helper threads for ack for forward pass-only run as a clocking
              # mechanism.
              # forward-only runs need acks as a clocking mechanism
              if forward_only:
                  # set up the ack helper threads
                  if "ack" in self.receive_ranks:
                      for i in range(len(self.receive_ranks["ack"])):
                          self.start_helper_thread(self.send_helper_thread_args,
                                                   send_helper_thread,
                                                   ["ack", i, True],
                                                   num_iterations_for_backward_threads)
                  if "ack" in self.send_ranks:
                      for i in range(len(self.send_ranks["ack"])):
                          self.start_helper_thread(self.recv_helper_thread_args,
                                                   recv_helper_thread,
                                                   ["ack", i, torch.int64, True],
                                                   num_iterations_for_forward_threads)
      
      
      

      The thread-creation function is:

          def start_helper_thread(self, args_func, func, args_func_args, num_iterations):
              """
              Start passed-in func on a helper thread.
              """
              args_func_args += [num_iterations]
              args = args_func(*args_func_args) # note: a function builds the thread's arguments
              helper_thread = threading.Thread(target=func, # func is the thread's main function
                                               args=args)
              helper_thread.start()
      

      3.5.2 Thread Main Functions

      recv_helper_thread and send_helper_thread are the receive helper thread and the send helper thread; they call _recv and _send respectively to do the actual work.

      Note that their arguments are produced by functions, namely recv_helper_thread_args and send_helper_thread_args.

      def recv_helper_thread(queue, counter, local_rank, tensor_name,
                             src_rank, tag, tensor_shape, dtype,
                             sub_process_group, num_iterations):
          torch.cuda.set_device(local_rank)
          # This method is to be executed from a helper daemon thread.
          for i in range(num_iterations):
              tensor = _recv(
                  tensor_name, src_rank, tensor_shape=tensor_shape,
                  dtype=dtype, tag=tag,
                  sub_process_group=sub_process_group)
              queue.add(tensor)
          counter.decrement()
      
      def send_helper_thread(queue, counter, local_rank, tensor_name,
                             src_rank, dst_rank, tag,
                             sub_process_group, num_iterations):
          torch.cuda.set_device(local_rank)
          # This method is to be executed from a helper daemon thread.
          for i in range(num_iterations):
              tensor = queue.remove()
              _send(tensor, tensor_name, src_rank, dst_rank,
                    tag=tag,
                    sub_process_group=sub_process_group)
          counter.decrement()
      

      3.5.3 Building the Arguments

      Recall that the create_process_groups method contains the following code, which assigns process groups to every tag; the helper threads use these process groups to do their work:

      if tag not in self.process_groups[min_rank][max_rank]:
          sub_process_group_fwd = dist.new_group(ranks=[min_rank, max_rank])
          sub_process_group_bwd = dist.new_group(ranks=[min_rank, max_rank])

          self.process_groups[min_rank][max_rank][tag] = {
              'forward': sub_process_group_fwd,
              'backward': sub_process_group_bwd
          }
      

      The following functions build the arguments for the thread main functions. The basic logic is:

      • use the tensor name to get the corresponding rank;
      • use the tensor name to get the corresponding tag;
      • use the tag to get the corresponding process group;
      • use the tensor name and index to get the corresponding queue;
      • return the arguments.
          def recv_helper_thread_args(self, tensor_name, index, dtype,
                                      backward, num_iterations):
              # use the tensor name to get the corresponding rank
              if backward:
                  src_rank = self.send_ranks[tensor_name][index]
              else:
                  src_rank = self.receive_ranks[tensor_name][index]
      
              # use the tensor name to get the corresponding tag
              sub_process_group = None
              tag = self.tensor_tags[tensor_name]
              
              # use the tag to get the corresponding process group
              if self.is_gpu_to_gpu_comm(connected_rank=src_rank) and tensor_name != "ack":
                  min_rank = min(self.rank, src_rank)
                  max_rank = max(self.rank, src_rank)
                  if src_rank > self.rank:
                      sub_process_group = \
                          self.process_groups[min_rank][max_rank][tag]['backward']
                  else:
                      sub_process_group = \
                          self.process_groups[min_rank][max_rank][tag]['forward']
                  assert sub_process_group
      
              # get the corresponding queue
              if backward:
                  queue = self.backward_receive_queues[tensor_name][index]
              else:
                  queue = self.forward_receive_queues[tensor_name][index]
              tensor_shape = self.tensor_shapes[tensor_name]
      
              # return the arguments
              return (queue, self.counter, self.local_rank, tensor_name,
                      src_rank, tag, tensor_shape, dtype, sub_process_group,
                      num_iterations)
      
          def send_helper_thread_args(self, tensor_name, index,
                                      backward, num_iterations):
              # use the tensor name to get the corresponding rank
              if backward:
                  dst_rank = self.receive_ranks[tensor_name][index]
                  num_ranks_in_connected_stage = self.num_ranks_in_previous_stage
              else:
                  dst_rank = self.send_ranks[tensor_name][index]
                  num_ranks_in_connected_stage = self.num_ranks_in_next_stage
      
              # use the tag to get the corresponding process group
              sub_process_group = None
              tag = self.tensor_tags[tensor_name]
              if self.is_gpu_to_gpu_comm(connected_rank=dst_rank) and tensor_name != "ack":
                  min_rank = min(self.rank, dst_rank)
                  max_rank = max(self.rank, dst_rank)
                  if dst_rank > self.rank:
                      sub_process_group = \
                           self.process_groups[min_rank][max_rank][tag]['forward']
                  else:
                      sub_process_group = \
                          self.process_groups[min_rank][max_rank][tag]['backward']
                  assert sub_process_group
      
              # get the corresponding queue
              if backward:
                  queue = self.backward_send_queues[tensor_name][index]
              else:
                  queue = self.forward_send_queues[tensor_name][index]
      
              # return the arguments
              return (queue, self.counter, self.local_rank, tensor_name, self.rank,
                      dst_rank, tag, sub_process_group, num_iterations)
      

      0x04 Utility Functions

      The following utility functions are the ones ultimately used to carry out the pipeline communication logic.

      There is a decoupling here, achieved through the queues:

      • recv and send operate on the queues, adding tensors to them or removing tensors from them.
      • The helper threads call _recv and _send, and also operate on the queues.

      So let's first look at the Queue implementation. Both add and remove use a threading.Condition, which means several threads can wait and block on a Queue through add / remove, i.e., a producer/consumer pattern.

      class Queue:
          def __init__(self):
              self.queue = []
              self.cv = threading.Condition()

          def add(self, tensor):
              # Producer side: append a tensor and wake up one waiting consumer.
              self.cv.acquire()
              self.queue.append(tensor)
              self.cv.notify()
              self.cv.release()

          def remove(self):
              # Consumer side: block until a tensor is available, then pop it.
              self.cv.acquire()
              while len(self.queue) == 0:
                  self.cv.wait()
              tensor = self.queue.pop(0)
              self.cv.release()
              return tensor
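
      A minimal usage sketch of this producer/consumer pattern (the thread and the tensor below are only for illustration, not PipeDream code; it assumes the Queue class above is in scope):

      import threading
      import torch

      q = Queue()

      def consumer():
          t = q.remove()        # blocks until the producer adds a tensor
          print("consumer got", t)

      threading.Thread(target=consumer).start()
      q.add(torch.ones(3))      # wakes up the blocked consumer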
      

      4.1 Send Logic

      The send logic is as follows:

      1. The training code calls StageRuntime.run_backward.
      2. StageRuntime.run_backward calls StageRuntime.send_tensors_backward to send tensor tensor_name.
      3. send_tensors_backward calls CommunicationHandler.send to add the tensor to the CommunicationHandler member backward_send_queues[tensor_name][index]. Each tensor corresponds to several queues; this is the decoupling point.
      4. send calls backward_send_queues.add, which notifies the send_helper_thread blocked on that queue.
      5. The CommunicationHandler thread send_helper_thread, previously blocked on this queue, now removes the tensor from backward_send_queues[tensor_name][index].
      6. send_helper_thread calls _send to send the tensor.
      7. The final call is dist.send, i.e., PyTorch P2P.

      As shown below:

       StageRuntime            CommunicationHandler              send_helper_thread
      
            +                           +                                 +
            |                           |                                 |
            | 1                         |                                 |
            v                           |                                 |
       run_backward                     |                                 |
            |                           |                                 |
            | 2                         |                                 |
            |                           |                    wait on backward_send_queues
            v                  3        v                                 |
      send_tensors_backward +--------> send                               |
                                        |                                 |
                                        |                                 |
                                        |  4                              |
                                        v               5                 v
                     backward_send_queues.add(tensor) +----> tensor = queue.remove()
                                                      notify              |
                                                                          |
                                                                          | 6
                                                                          v
                                                                        _send
                                                                          |
                                                                          | 7
                                                                          |
                                                                          v
                                                                       dist.send
      
      

      4.2 Receive Logic

      The receive logic is as follows:

      1. The StageRuntime training code calls run_backward.
      2. run_backward calls receive_tensors_backward.
      3. receive_tensors_backward calls self.gradients[output_name] = self.comm_handler.recv to obtain the gradient. The CommunicationHandler recv function blocks on backward_receive_queues[tensor_name][index].
      4. Meanwhile, the CommunicationHandler recv_helper_thread thread calls _recv to receive tensors sent by other stages.
      5. _recv calls dist.recv or dist.broadcast to receive the tensor.
      6. _recv adds the tensor to backward_receive_queues[tensor_name][index], which notifies the blocked CommunicationHandler recv function.
      7. The CommunicationHandler recv function removes the gradient from backward_receive_queues[tensor_name][index] and returns it to StageRuntime; this is the return of step 3.

      As shown below:

          StageRuntime             CommunicationHandler           recv_helper_thread
                +                            +                            +
                |                            |                            |
                | 1                          |                            |
                |                            |                            | 4
                v                            |                            v
          run_backward                       |                         _recv
                |                            |                            |
                |                            |                            |
                |                            |                            | 5
                |                            |                            |
                | 2                          |                            v
                |                            |                  dist.recv / dist.broadcast
                |                            |                            |
                v                  3         v                            |
      receive_tensors_backward +--------->  recv                          |
                +                            |                            |
                |                            |                            |
                |                            |                            |
                |                            |                            |
                |                            v                            |
                |                 backward_receive_queues.remove()        |
                |                            |                            |
                |                            |                            |
                |                            |                            |
                |                            |                            |
                |               wait on backward_receive_queues           |
                |                            |                            |
                |                            |                            |
                |                            |                            |
                |                            |                 6          v
                |                  backward_receive_queues <-------+ queue.add(tensor)
                |                            |               notify
                |                            |  7
                v                  3 return  |
      gradients[output_name] <---------------+
      
      

      4.3 recv

      This simply fetches the corresponding tensor from the right queue, keyed by the tensor name.

          def recv(self, tensor_name, forward_minibatch_id,
                   backward_minibatch_id, backward=False):
              if backward:
                  index = (backward_minibatch_id + self.rank_in_stage) % \
                      len(self.backward_receive_queues[tensor_name])
                  tensor = self.backward_receive_queues[tensor_name][
                      index].remove()
                  return tensor
              else:
                  # in the forward pass, figure out which index of the previous stage to receive from
                  index = self.get_messaging_index(sending=False)
                  tensor = self.forward_receive_queues[tensor_name][
                      index].remove()
                  if tensor.dtype == torch.float32:
                      tensor = tensor.requires_grad_()
                  return tensor
      

      At runtime, the receive_tensors_forward and receive_tensors_backward functions call recv to fetch the tensors already stored in the corresponding queues. For example:

          def receive_tensors_backward(self):
              # Receive all required gradients from downstream
              # machines.
              for output_name in self.send_ranks:
                   if output_name in self.target_tensor_names:
                      continue
      
                   self.gradients[output_name] = \
                      self.comm_handler.recv( # recv is used here
                          output_name,
                          forward_minibatch_id=self.forward_minibatch_id,
                          backward_minibatch_id=self.backward_minibatch_id,
                          backward=True)
      
                   self.backward_stats.stats['receive_tensors_size'] += \
                       (self.gradients[output_name].element_size() *
                        self.gradients[output_name].nelement())
      

      4.4 send

      This places the tensor into the corresponding queue.

          def send(self, tensor_name, tensor, forward_minibatch_id,
                   backward_minibatch_id, backward=False):
              if backward:
                  # In the backward pass, we need to know which index (replica) of the previous stage to send to
                  index = self.get_messaging_index(sending=True)
                  dst_rank = self.receive_ranks[tensor_name][index]
                  self.backward_send_queues[tensor_name][index].add(tensor)
              else:
                  index = (forward_minibatch_id + self.rank_in_stage) % \
                      len(self.send_ranks[tensor_name])
                  self.forward_send_queues[tensor_name][index].add(tensor)
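
      The modulo expression above implements a simple round-robin: consecutive minibatches of one replica are spread evenly over the data-parallel replicas of the adjacent stage. A tiny worked example (the concrete numbers are hypothetical, just to show the pattern):

          # Suppose this stage is replica rank_in_stage = 1 and the next stage
          # has 2 replicas, so len(self.send_ranks[tensor_name]) == 2.
          rank_in_stage = 1
          num_destinations = 2
          for forward_minibatch_id in range(4):
              index = (forward_minibatch_id + rank_in_stage) % num_destinations
              print(forward_minibatch_id, "->", index)
          # Prints 0 -> 1, 1 -> 0, 2 -> 1, 3 -> 0: successive minibatches
          # alternate between the two forward_send_queues, i.e. they are
          # distributed round-robin across the next stage's replicas.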
      
      
      

      It is used inside send_tensors_backward and send_tensors_forward, for example:

          def send_tensors_backward(self):
              # Send all required gradients upstream.
              for input_name in self.receive_ranks:
                  if input_name in self.target_tensor_names:
                      continue
      
                  self.comm_handler.send(
                      input_name,
                      self.gradients[input_name],
                      forward_minibatch_id=self.forward_minibatch_id,
                      backward_minibatch_id=self.backward_minibatch_id,
                      backward=True)
      
                  self.backward_stats.stats['send_tensors_size'] += \
                      (self.gradients[input_name].element_size() *
                       self.gradients[input_name].nelement())
      
              if self.num_ranks_in_previous_stage > 0:
                  # Used to track where to send tensors in the
                  # backward pass.
                  self.comm_handler.increment_messaging_index(
                      sending=True)
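
      increment_messaging_index(sending=True) advances a round-robin cursor once per backward minibatch, and get_messaging_index(sending=True), seen in send() above, reads it, so that successive backward minibatches return their gradients to successive replicas of the previous stage. Below is a minimal illustration of such a cursor; the class and attribute names are made up, only the modulo cycling is the point.

          class MessagingIndex(object):
              """Hypothetical stand-in for the cursor behind
              get_messaging_index / increment_messaging_index."""
              def __init__(self, num_ranks_in_previous_stage):
                  self.num_ranks = num_ranks_in_previous_stage
                  self.current = 0

              def get(self):
                  # Which previous-stage replica the current minibatch targets.
                  return self.current

              def increment(self):
                  # Called once per backward minibatch; wraps around the replicas.
                  self.current = (self.current + 1) % self.num_ranks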
      

      4.5 _recv

      Among the parameters of _recv, sub_process_group is the process group constructed in the setup code shown earlier.

      If the peer is on the same node (a sub_process_group is passed in), dist.broadcast is used; otherwise dist.recv is used.

      def _recv(tensor_name, src_rank, tensor_shape=None, dtype=torch.float32,
                tensor=None, tag=None, sub_process_group=None):
          """
          Receives tensor by calling PyTorch's recv() call.
      
          Tensor will be copied to GPU prior to return.
          """
          assert tag is not None
          if tensor is None:
              assert tensor_shape is not None
              assert dtype is not None
              assert dtype != torch.float16
      
          if sub_process_group is not None:
              # Receive tensor shape.
              received_tensor_shape = torch.zeros(len(tensor_shape),
                                                  dtype=torch.int)
              dist.broadcast(tensor=received_tensor_shape,
                             src=src_rank,
                             group=sub_process_group)
              received_tensor_shape = list(map(lambda x: int(x),
                                               received_tensor_shape))
      
              # Receive tensor.
              tensor = torch.zeros(received_tensor_shape, dtype=dtype).cuda()
              dist.broadcast(tensor=tensor,
                             src=src_rank,
                             group=sub_process_group)
          else:
              # Receive tensor shape.
              received_tensor_shape = torch.zeros(len(tensor_shape),
                                                  dtype=torch.int)
              dist.recv(tensor=received_tensor_shape,
                        src=src_rank,
                        tag=tag)
              received_tensor_shape = list(map(lambda x: int(x),
                                               received_tensor_shape))
      
              # Receive tensor.
              tensor = torch.zeros(received_tensor_shape, dtype=dtype)
              dist.recv(tensor=tensor,
                        src=src_rank,
                        tag=tag)
              tensor = tensor.cuda()
      
          assert tensor.is_cuda
          return tensor
      

      _recv is called from within recv_helper_thread:

      def recv_helper_thread(queue, counter, local_rank, tensor_name,
                             src_rank, tag, tensor_shape, dtype,
                             sub_process_group, num_iterations):
          torch.cuda.set_device(local_rank)
          # This method is to be executed from a helper daemon thread.
          for i in range(num_iterations):
              tensor = _recv(
                  tensor_name, src_rank, tensor_shape=tensor_shape,
                  dtype=dtype, tag=tag,
                  sub_process_group=sub_process_group)
              queue.add(tensor)  # once the tensor has been received, put it into the queue
          counter.decrement()
      

      4.6 _send

      If the peer is on the same node, dist.broadcast is used; otherwise dist.send is used.

      def _send(tensor, tensor_name, src_rank, dst_rank, tag, sub_process_group=None):
          """
          Sends tensor by calling PyTorch's send() call.
      
          If tensor is being sent not via broadcast(), it will
          be first copied to the CPU.
          """
          if sub_process_group is not None:
              assert tensor.is_cuda
      
              # Send tensor shape.
              tensor_shape = torch.tensor(tensor.shape, dtype=torch.int)
              dist.broadcast(tensor=tensor_shape, src=src_rank,
                            group=sub_process_group)
      
              # Send tensor.
              contiguous_tensor = tensor.detach().clone()
              dist.broadcast(tensor=contiguous_tensor.contiguous(),
                             src=src_rank,
                             group=sub_process_group)
          else:
              assert tensor.is_cuda
              tensor = tensor.cpu()
      
              # Send tensor shape.
              tensor_shape = torch.tensor(tensor.shape, dtype=torch.int)
              dist.send(tensor=tensor_shape, dst=dst_rank, tag=tag)
      
              # Send tensor.
              dist.send(tensor=tensor, dst=dst_rank, tag=tag)
      

      send_helper_thread uses _send to send out the tensors it takes from its queue:

      def send_helper_thread(queue, counter, local_rank, tensor_name,
                             src_rank, dst_rank, tag,
                             sub_process_group, num_iterations):
          torch.cuda.set_device(local_rank)
          # This method is to be executed from a helper daemon thread.
          for i in range(num_iterations):
              tensor = queue.remove()
              # take a tensor out of the queue and send it out
              _send(tensor, tensor_name, src_rank, dst_rank,
                    tag=tag,
                    sub_process_group=sub_process_group)
          counter.decrement()
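
      Putting _send and _recv side by side, the cross-node path is a small two-message protocol: the shape tensor is sent first so that the receiver can allocate a buffer of the right size, and the payload follows with the same tag. The standalone sketch below reproduces just that protocol with plain torch.distributed point-to-point calls; the run() helper, the port number and the gloo/CPU setup are assumptions made for this demo, not PipeDream code.

          import os
          import torch
          import torch.distributed as dist
          import torch.multiprocessing as mp

          def run(rank):
              os.environ['MASTER_ADDR'] = '127.0.0.1'
              os.environ['MASTER_PORT'] = '29501'
              dist.init_process_group('gloo', rank=rank, world_size=2)
              tag = 0
              if rank == 0:
                  tensor = torch.randn(3, 4)
                  # Step 1: send the shape so the peer knows how much to allocate.
                  dist.send(torch.tensor(tensor.shape, dtype=torch.int), dst=1, tag=tag)
                  # Step 2: send the payload itself.
                  dist.send(tensor, dst=1, tag=tag)
              else:
                  shape = torch.zeros(2, dtype=torch.int)
                  dist.recv(tensor=shape, src=0, tag=tag)
                  buf = torch.zeros([int(x) for x in shape])
                  dist.recv(tensor=buf, src=0, tag=tag)
                  print(buf.shape)  # torch.Size([3, 4])

          if __name__ == '__main__':
              mp.spawn(run, nprocs=2)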
      

      At this point, the analysis of the communication module is complete. The next article will finally cover 1F1B.
