<output id="qn6qe"></output>

    1. <output id="qn6qe"><tt id="qn6qe"></tt></output>
    2. <strike id="qn6qe"></strike>

      亚洲 日本 欧洲 欧美 视频,日韩中文字幕有码av,一本一道av中文字幕无码,国产线播放免费人成视频播放,人妻少妇偷人无码视频,日夜啪啪一区二区三区,国产尤物精品自在拍视频首页,久热这里只有精品12

      聯邦學習中的模型聚合

      論文[1]在聯邦學習的情景下引入了多任務學習,其采用的手段是使每個client/task節點的訓練數據分布不同,從而使各任務節點學習到不同的模型,且每個任務節點以及全局(global)的模型都由多個分量模型集成。該論文最關鍵與核心的地方在于將各任務節點學習到的模型進行聚合/通信,依據模型聚合方式的不同,可以將模型采用的算法分為client-server方法,和fully decentralized(完全去中心化)的方法(其實還有其他的聚合方法沒,如論文[3]提出的簇狀聚合方法,代碼參見[4]我們這里暫時略過),其中這兩種方法在具體實現上都可以替換為對代理損失函數的優化,不過我們這里暫時略過。

      因為有多種任務聚合器(Aggregator)要實現,論文代碼(已開源在Github上,參見[2])采取的措施是先實現Aggregator抽象基類,實現好一些通用方法,并規定好抽象方法的接口,然后具體的任務聚合類繼承抽象基類,然后做具體的實現。

      我們先來看任務聚合器(Aggregator)這一抽象基類

      class Aggregator(ABC):
          r"""Aggregator的基類. `Aggregator`規定了client之間的通信"""
          def __init__(
                  self,
                  clients,
                  global_learners_ensemble,
                  log_freq,
                  global_train_logger,
                  global_test_logger,
                  sampling_rate=1.,
                  sample_with_replacement=False,
                  test_clients=None,
                  verbose=0,
                  seed=None,
                  *args,
                  **kwargs
          ):
      
              rng_seed = (seed if (seed is not None and seed >= 0) else int(time.time()))
              self.rng = random.Random(rng_seed) # 隨機數生成器
              self.np_rng = np.random.default_rng(rng_seed) # numpy隨機數生成器
      
              if test_clients is None:
                  test_clients = []
      
              self.clients = clients #  List[Client]
              self.test_clients = test_clients #  List[Client]
      
              self.global_learners_ensemble = global_learners_ensemble # List[Learner]
              self.device = self.global_learners_ensemble.device
      
      
              self.log_freq = log_freq
              self.verbose = verbose
              # verbose: 調整輸出打印的冗余度(verbosity), 
              # `0` 表示quiet(無任何打印輸出), `1` 顯示日志, `2` 顯示所有局部日志; 默認是 `0`
              self.global_train_logger = global_train_logger
              self.global_test_logger = global_test_logger
      
              self.model_dim = self.global_learners_ensemble.model_dim # #模型特征維度
      
              self.n_clients = len(clients)
              self.n_test_clients = len(test_clients)
              self.n_learners = len(self.global_learners_ensemble)
      
              # 存儲為每個client分配的權重(權重為0-1之間的小數)
              self.clients_weights =\
                  torch.tensor(
                      [client.n_train_samples for client in self.clients],
                      dtype=torch.float32
                  )
              self.clients_weights = self.clients_weights / self.clients_weights.sum()
      
              self.sampling_rate = sampling_rate  #  clients在每一輪使用的比例,默認為`1.`
              self.sample_with_replacement = sample_with_replacement #對client進行采用是可重復還是無重復的,with_replacement=True表示可重復的,否則是不可重復的
      
              # 每輪迭代需要使用到的client個數
              self.n_clients_per_round = max(1, int(self.sampling_rate * self.n_clients))
      
              # 采樣得到的client列表
              self.sampled_clients = list()
      
              # 記載當前的迭代通信輪數
              self.c_round = 0 
              self.write_logs()
      
          @abstractmethod
          def mix(self): 
              """
              該方法用于完成各client之間的權重參數與通信操作
              """
              pass
      
          @abstractmethod
          def update_clients(self): 
              """
              該方法用于將所有全局分量模型拷貝到各個client,相當于boardcast操作
              """
              pass
      
          def update_test_clients(self):
              """
              將全局(gobal)的所有分量模型都拷貝到各個client上
              """
      
          def write_logs(self):
              """
              對全局(global)的train和test數據集的loss和acc做記錄
              需要對所有client的所有樣本做累加,然后除以所有client的樣本總數做平均。
              """
      
          def save_state(self, dir_path):
              """
              保存aggregator的模型state,。例如, `global_learners_ensemble`中每個分量模型'learner'的state字典(以`.pt`文件格式),以及`self.clients` 中每個client的 `learners_weights` (注意,這個權重不是模型內部的參數,而是進行繼承的時候對各個分量模型賦予的權重,包含train和test兩部分,以一個大小為n_clients(n_test_clients)× n_learners的numpy數組的格式,即`.npy` 文件)。
              """
      
          def load_state(self, dir_path):
              """
              加載aggregator的模型state,即save_state方法里保存的那些
              """
      
          def sample_clients(self):
              """
              對clients進行采樣,
              如果self.sample_with_replacement為True,則為可重復采樣,
              否則,則為不可重復采用。
              最終得到一個clients子集列表并賦予self.sampled_clients
              """
      

      1.client-server 算法

      這種方式的通信/聚合方法也稱中心化(centralized)方法,因為該方法在每一輪迭代最后將所有client的權重數據匯集到server節點。這種方法的優化迭代部分的偽代碼示意如下:
      CV多任務學習

      落實到具體代碼實現上,這種方法的Aggregator設計如下:

      class CentralizedAggregator(Aggregator):
          r""" 標準的中心化Aggreagator
          所有clients在每一輪迭代末和average client完全同步.
          """
          def mix(self):
              self.sample_clients()
      
              # 對self.sampled_clients中每個client的參數進行優化
              for client in self.sampled_clients:
                  # 相當于偽代碼第11行調用的LocalSolver函數
                  client.step()
      
              # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
              # 相當于偽代碼第13行
              for learner_id, learner in enumerate(self.global_learners_ensemble):
                  # 獲取所有client中對應learner_id的分量模型
                  learners = [client.learners_ensemble[learner_id] for client in self.clients]
                  # global模型的分量模型為所有client對應分量模型取平均,相當于偽代碼第14行
                  average_learners(learners, learner, weights=self.clients_weights)
      
              # 將更新后的模型賦予所有clients,相當于偽代碼第5行的boardcast操作
              self.update_clients()
      
              # 通信輪數+1
              self.c_round += 1
      
              if self.c_round % self.log_freq == 0:
                  self.write_logs()
      
          def update_clients(self):
              """
              此函數負責將所有全局分量模型拷貝到各個client,相當于偽代碼中第5行的boardcast操作
              """
              for client in self.clients:
                  for learner_id, learner in enumerate(client.learners_ensemble):
                      copy_model(learner.model, self.global_learners_ensemble[learner_id].model)
      
                      if callable(getattr(learner.optimizer, "set_initial_params", None)):
                          learner.optimizer.set_initial_params(
                              self.global_learners_ensemble[learner_id].model.parameters()
                          )
      

      2. fully decentralized(完全去中心化)算法

      這種方法之所以被稱為去中心化的,因為該方法在每一輪迭代不需要所有client的權重數據匯集到一個特定的server節點,而只需要完成每個節點和其鄰居進行通信(參數共享)即可。這種方法的優化迭代部分的偽代碼示意如下:
      CV多任務學習
      落實到具體代碼實現上,這種方法的Aggregator設計如下:

      class DecentralizedAggregator(Aggregator):
          def __init__(
                  self,
                  clients,
                  global_learners_ensemble,
                  mixing_matrix,
                  log_freq,
                  global_train_logger,
                  global_test_logger,
                  sampling_rate=1.,
                  sample_with_replacement=True,
                  test_clients=None,
                  verbose=0,
                  seed=None):
      
              super(DecentralizedAggregator, self).__init__(
                  clients=clients,
                  global_learners_ensemble=global_learners_ensemble,
                  log_freq=log_freq,
                  global_train_logger=global_train_logger,
                  global_test_logger=global_test_logger,
                  sampling_rate=sampling_rate,
                  sample_with_replacement=sample_with_replacement,
                  test_clients=test_clients,
                  verbose=verbose,
                  seed=seed
              )
      
              self.mixing_matrix = mixing_matrix
              assert self.sampling_rate >= 1, "partial sampling is not supported with DecentralizedAggregator"
      
          def update_clients(self):
              pass
      
          def mix(self):
              
              # 對各clients的模型參數進行優化
              for client in self.clients:
                  client.step()
      
              # 存儲每個模型各參數混合的權重
              # 行對應不同的client,列對應單個模型中不同的參數
              # (注意:每個分量有獨立的mixing_matrix)
              mixing_matrix = torch.tensor(
                  self.mixing_matrix.copy(),
                  dtype=torch.float32,
                  device=self.device
              )
      
              # 遍歷global模型(self.global_learners_ensemble) 中每一個分量模型(learner)
              # 相當于偽代碼第14行
              for learner_id, global_learner in enumerate(self.global_learners_ensemble):
                  # 用于將指定learner_id的各client的模型state讀出暫存
                  state_dicts = [client.learners_ensemble[learner_id].model.state_dict() for client in self.clients]
      
                  # 遍歷global模型中的各參數, key對應模型中參數的名稱
                  for key, param in global_learner.model.state_dict().items():
                      shape_ = param.shape
                      models_params = torch.zeros(self.n_clients, int(np.prod(shape_)), device=self.device)
      
                      for ii, sd in enumerate(state_dicts):
                          # models_params的第ii個下標存儲的是第ii個client的(名為key的)參數
                          models_params[ii] = sd[key].view(1, -1) 
      
                      # models_params的每一行是一個client的參數
                      # @符號表示矩陣乘/矩陣向量乘
                      # 故這里表示每個client參數是其他所有client參數的混合
                      models_params = mixing_matrix @ models_params
      
                      for ii, sd in enumerate(state_dicts):
                          # 將第ii個client的(名為key的)參數存入state_dicts中對應位置
                          sd[key] = models_params[ii].view(shape_)
      
                  # 將更新好的參數從state_dicts存入各client節點的模型中
                  for client_id, client in enumerate(self.clients):
                      client.learners_ensemble[learner_id].model.load_state_dict(state_dicts[client_id])
      
              # 通信輪數+1
              self.c_round += 1
      
              if self.c_round % self.log_freq == 0:
                  self.write_logs()
      

      參考文獻

      posted @ 2021-12-02 22:45  orion-orion  閱讀(6365)  評論(3)    收藏  舉報
      主站蜘蛛池模板: 欧洲精品码一区二区三区| 热久在线免费观看视频| 九九热在线精品视频观看| 亚洲欧美人成人让影院| 白丝乳交内射一二三区| 国内少妇人妻丰满av| 久久综合伊人77777| 久久精品熟女亚洲av麻| 中国国产一级毛片| 亚洲日韩欧洲乱码av夜夜摸| 无码日韩av一区二区三区| 午夜福利免费视频一区二区| 国产资源精品中文字幕| 最新精品国偷自产在线| 久久精品国产福利一区二区| 精品一区精品二区制服| 99久久久无码国产精品免费| 人成午夜免费大片| 欧美高清一区三区在线专区| 久久99精品中文字幕在| 东京热人妻无码一区二区av| 国产精品 第一页第二页| 1769国内精品视频在线播放 | 国产四虎永久免费观看| 人妻内射视频麻豆| 亚洲精品麻豆一二三区| 国产内射xxxxx在线| 青草热在线观看精品视频| 欧美孕妇乳喷奶水在线观看| 国产成人精品2021欧美日韩| 欧美交a欧美精品喷水| 激情的视频一区二区三区| 国产精品福利中文字幕| 美女裸体视频永久免费| 亚洲av成人午夜福利| 精品熟女日韩中文十区| 久久久久成人片免费观看蜜芽| 北川| 久久这里只精品国产2| 国产精品国产三级国产试看| 国产亚洲一区二区三区av|