
      [Source Code Analysis] TensorFlow Distributed Environment (6) --- Master Dynamic Logic

      Before introducing TensorFlow's various distributed Strategies in detail, we first need to look at the foundation of distribution: the distributed environment. Only with a solid foundation can we clear away as many obstacles as possible in the later analysis and get twice the result with half the effort. Starting from the Client, this article looks at how the Master processes the computation graph.

      This article again draws heavily on the work of two experts:

      Other articles in this series:

      [Translation] TensorFlow Distributed Papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

      [Translation] TensorFlow Distributed Papers: "Implementation of Control Flow in TensorFlow"

      [Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

      [Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

      [Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

      [Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

      [Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

      1. GrpcSession

      1.1 Run

      First, the client calls GrpcSession to start a run, and its Run method simply forwards to RunHelper.

      Status GrpcSession::Run(const RunOptions& run_options,
                              const std::vector<std::pair<string, Tensor>>& inputs,
                              const std::vector<string>& output_tensor_names,
                              const std::vector<string>& target_node_names,
                              std::vector<Tensor>* outputs,
                              RunMetadata* run_metadata) {
        return RunHelper(run_options, inputs, output_tensor_names, target_node_names,
                         outputs, run_metadata, /* prun_handle */ "");
      }
      

      The RunHelper method is shown below. The important steps are adding the feeds and fetches, then calling RunProto to actually run the session.

      Status GrpcSession::RunHelper(
          const RunOptions& run_options,
          const std::vector<std::pair<string, Tensor>>& inputs,
          const std::vector<string>& output_tensor_names,
          const std::vector<string>& target_node_names, std::vector<Tensor>* outputs,
          RunMetadata* run_metadata, const string& prun_handle) {
        // Convert to proto
        std::unique_ptr<MutableRunStepRequestWrapper> req(
            master_->CreateRunStepRequest());
        std::unique_ptr<MutableRunStepResponseWrapper> resp(
            master_->CreateRunStepResponse());
      
        *req->mutable_options() = run_options;
      
        if (run_options.timeout_in_ms() == 0) {
          req->mutable_options()->set_timeout_in_ms(
              options_.config.operation_timeout_in_ms());
        }
      
        if (!prun_handle.empty()) {
          req->set_partial_run_handle(prun_handle);
        }
      
        for (const auto& it : inputs) {
          req->add_feed(it.first, it.second);
        }
      
        // Support long error messages by storing the error code in the response body.
        req->set_store_errors_in_response_body(true);
      
        // Build an index from fetch tensor name to first index in
        // output_tensor_names.
        std::unordered_map<string, int> output_name_to_offset;
        for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
          const string& name = output_tensor_names[i];
          if (output_name_to_offset.insert(std::make_pair(name, i)).second) {
            req->add_fetch(name);
          }
        }
        for (const string& target : target_node_names) {
          req->add_target(target);
        }
      
        CallOptions call_options;
        call_options.SetTimeout(req->options().timeout_in_ms());
        
        // Call RunProto to run the session
        TF_RETURN_IF_ERROR(RunProto(&call_options, req.get(), resp.get()));
      
        // Look for an extended error returned in the response body.
        if (resp->status_code() != error::Code::OK) {
          return resp->status();
        }
      
        if (!output_tensor_names.empty()) {
          outputs->resize(output_tensor_names.size());
        }
      
        // Convert response back to Tensors in the correct order.
        for (size_t i = 0; i < resp->num_tensors(); ++i) {
          auto fetch_it = output_name_to_offset.find(resp->tensor_name(i));
          if (fetch_it == output_name_to_offset.end()) {
            return errors::Internal("Received response for unrequested fetch: ",
                                    resp->tensor_name(i));
          }
      
          Tensor output;
          TF_RETURN_IF_ERROR(resp->TensorValue(i, &output));
          (*outputs)[fetch_it->second] = output;
        }
        // In the unlikely event that output_tensor_names contains duplicates, fill in
        // the duplicate values.
        if (output_name_to_offset.size() != output_tensor_names.size()) {
          for (int i = 0, end = output_tensor_names.size(); i < end; ++i) {
            const string& name = output_tensor_names[i];
            int offset = output_name_to_offset[name];
            if (offset != i) {
              (*outputs)[i] = (*outputs)[offset];
            }
          }
        }
      
        if (run_metadata) {
          run_metadata->Swap(resp->mutable_metadata());
        }
      
        return Status::OK();
      }
      

      RunProto ultimately calls master_->RunStep to do the real work.

      Status GrpcSession::RunProto(CallOptions* call_options,
                                   MutableRunStepRequestWrapper* req,
                                   MutableRunStepResponseWrapper* resp) {
        string handle;
        TF_RETURN_IF_ERROR(Handle(&handle));
        req->set_session_handle(handle);
        return master_->RunStep(call_options, req, resp);
      }
      

      master_ is a GrpcRemoteMaster, so we continue there.

      1.2 GrpcRemoteMaster

      GrpcRemoteMaster is the gRPC client implementation on the Client side. Its RunStep method simply calls the RunStep interface of the remote MasterService through the gRPC stub; in effect, it sends a RunStepRequest.

      Status RunStep(CallOptions* call_options, RunStepRequestWrapper* request,
                     MutableRunStepResponseWrapper* response) override {
        return CallWithRetry(call_options, &request->ToProto(),
                             get_proto_from_wrapper(response),
                             &MasterServiceStub::RunStep, "RunStep/Client");
      }
      

      At this point, the Client-side logic expands as follows:

      Figure 1: Master dynamic logic (1)

      2. Master

      From here on we are inside the server playing the Master role. GrpcMasterService runs the gRPC service; when a RunStepRequest arrives, the system invokes RunStepHandler. The code lives in tensorflow/core/distributed_runtime/rpc/grpc_master_service.cc.

      // RPC handler for running one step in a session.
      void RunStepHandler(MasterCall<RunStepRequest, RunStepResponse>* call) {
        auto* trace = TraceRpc("RunStep/Server", call->client_metadata());
        CallOptions* call_opts = new CallOptions;
        if (call->request.options().timeout_in_ms() > 0) {
          call_opts->SetTimeout(call->request.options().timeout_in_ms());
        } else {
          call_opts->SetTimeout(default_session_config_.operation_timeout_in_ms());
        }
        RunStepRequestWrapper* wrapped_request =
            new ProtoRunStepRequest(&call->request);
        MutableRunStepResponseWrapper* wrapped_response =
            new NonOwnedProtoRunStepResponse(&call->response);
        call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
        master_impl_->RunStep(
            call_opts, wrapped_request, wrapped_response,
            [call, call_opts, wrapped_request, trace](const Status& status) {
              call->ClearCancelCallback();
              delete call_opts;
              delete wrapped_request;
              delete trace;
              if (call->request.store_errors_in_response_body() && !status.ok()) {
                call->response.set_status_code(status.code());
                call->response.set_status_error_message(status.error_message());
                call->SendResponse(ToGrpcStatus(Status::OK()));
              } else {
                call->SendResponse(ToGrpcStatus(status));
              }
            });
        ENQUEUE_REQUEST(RunStep, true);
      }
      

      master_impl_ is the Master instance. Master::RunStep looks up the master session and runs it.

      void Master::RunStep(CallOptions* opts, const RunStepRequestWrapper* req,
                           MutableRunStepResponseWrapper* resp, MyClosure done) {

        // Find the session (heavily abridged; error handling omitted)
        auto session = FindMasterSession(req->session_handle());

        // Run the session on a separate thread
        SchedClosure([session, opts, req, resp, done]() {
          Status status = session->Run(opts, *req, resp);
          session->Unref();  // Balances the Ref() taken in FindMasterSession
          done(status);
        });
      }
      

      Now we are formally inside the Master's business logic; let's see how the request is processed further.

      2.1 Overview

      First, an overview (a condensed call-chain sketch follows the list). On the Master:

      • First, the FullGraph is pruned to produce a ClientGraph.
      • Then, the ClientGraph is split along Worker boundaries into multiple PartitionGraphs.
      • Finally, the list of PartitionGraphs is registered with each Worker (one RPC), and each Worker is started so the PartitionGraphs execute concurrently (another RPC).
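
      As a condensed sketch of the call chain behind these steps (every function named here is listed in full later in this article):

      // Condensed call chain (signatures abridged):
      //
      //   GrpcSession::Run                          // client
      //     -> GrpcRemoteMaster::RunStep            // gRPC: RunStepRequest
      //       -> Master::RunStep                    // master process
      //         -> MasterSession::Run
      //           -> DoRunWithLocalExecution
      //                -> StartStep                     // build + prune -> ClientGraph
      //                -> BuildAndRegisterPartitions    // partition + RegisterGraphAsync (RPC)
      //                -> RunPartitions                 // RunGraphAsync per worker (RPC)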

      Looking at the code: first, the Master calls FindMasterSession to find the MasterSession corresponding to session_handle; from then on, the MasterSession takes over.

      MasterSession* Master::FindMasterSession(const string& handle) {
        MasterSession* session = nullptr;
        {
          mutex_lock l(mu_);
          session = gtl::FindPtrOrNull(sessions_, handle);
          if (session != nullptr) {
            session->Ref();
          }
        }
        return session;
      }
      

      Next, MasterSession::Run can take two paths; we follow DoRunWithLocalExecution here.

      Status MasterSession::Run(CallOptions* opts, const RunStepRequestWrapper& req,
                                MutableRunStepResponseWrapper* resp) {
        UpdateLastAccessTime();
        {
          mutex_lock l(mu_);
          if (closed_) {
            return errors::FailedPrecondition("Session is closed.");
          }
          ++num_running_;
          // Note: all code paths must eventually call MarkRunCompletion()
          // in order to appropriate decrement the num_running_ counter.
        }
        Status status;
        if (!req.partial_run_handle().empty()) {
          status = DoPartialRun(opts, req, resp);
        } else {
          status = DoRunWithLocalExecution(opts, req, resp);
        }
        return status;
      }
      

      DoRunWithLocalExecution performs three main operations:

      • StartStep calls BuildGraph to produce the ClientGraph; this is where pruning happens.
      • BuildAndRegisterPartitions splits the graph into multiple subgraphs by location.
      • RunPartitions executes the subgraphs. Each subgraph corresponds to one worker, i.e., one worker service.

      Status MasterSession::DoRunWithLocalExecution(
          CallOptions* opts, const RunStepRequestWrapper& req,
          MutableRunStepResponseWrapper* resp) {
      
        PerStepState pss;
        pss.start_micros = Env::Default()->NowMicros();
        auto cleanup = gtl::MakeCleanup([this] { MarkRunCompletion(); });
      
        // Prepare.
        BuildGraphOptions bgopts;
        BuildBuildGraphOptions(req, session_opts_.config, &bgopts);
        ReffedClientGraph* rcg = nullptr;
        int64 count;
        // StartStep calls BuildGraph to produce the ClientGraph; pruning happens here
        TF_RETURN_IF_ERROR(StartStep(bgopts, false, &rcg, &count));
      
        // Unref "rcg" when out of scope.
        core::ScopedUnref unref(rcg);
      
        // Partition the computation graph
        TF_RETURN_IF_ERROR(BuildAndRegisterPartitions(rcg));
      
        // Keeps the highest 8 bits 0x01: we reserve some bits of the
        // step_id for future use.
        uint64 step_id = NewStepId(rcg->collective_graph_key());
      
        std::unique_ptr<ProfileHandler> ph;
        FillPerStepState(rcg, req.options(), step_id, count, &pss, &ph);
      
        if (pss.collect_partition_graphs &&
            session_opts_.config.experimental().disable_output_partition_graphs()) {
          return errors::InvalidArgument(
              "RunOptions.output_partition_graphs() is not supported when "
              "disable_output_partition_graphs is true.");
        }
      
        // Execute the computation graph
        Status s = rcg->RunPartitions(env_, step_id, count, &pss, opts, req, resp,
                                      &cancellation_manager_, false);
      
        cleanup.release();  // MarkRunCompletion called in PostRunCleanup().
        return PostRunCleanup(rcg, step_id, req.options(), &pss, ph, s,
                              resp->mutable_metadata());
      }
      

      We now analyze the three main operations of DoRunWithLocalExecution one by one.

      2.2 Build & Prune

      2.2.1 Building the graph

      The key job of StartStep is to build the computation graph and prune it.

      Status MasterSession::StartStep(const BuildGraphOptions& opts, bool is_partial,
                                      ReffedClientGraph** out_rcg,
                                      int64_t* out_count) {
        const uint64 hash = HashBuildGraphOptions(opts);
        {
          mutex_lock l(mu_);
          RCGMap* m = is_partial ? &partial_run_graphs_ : &run_graphs_;
          auto iter = m->find(hash);
          if (iter == m->end()) {
            // We have not seen this subgraph before. Build the subgraph and
            // cache it.
            std::unique_ptr<ClientGraph> client_graph;
          // Build the graph
            TF_RETURN_IF_ERROR(execution_state_->BuildGraph(opts, &client_graph));
            WorkerCacheInterface* worker_cache = get_worker_cache();
            auto entry = new ReffedClientGraph(
                handle_, opts, std::move(client_graph), session_opts_,
                stats_publisher_factory_, is_partial, worker_cache,
                !should_delete_worker_sessions_);
            iter = m->insert({hash, entry}).first;
          }
          *out_rcg = iter->second;
          (*out_rcg)->Ref();
          *out_count = (*out_rcg)->get_and_increment_execution_count();
        }
        return Status::OK();
      }
      

      2.2.2 Pruning

      The most critical step inside BuildGraph is the call to PruneGraph, which prunes the graph. PruneGraph itself is not listed in this article, so a conceptual sketch comes first; the full BuildGraph follows.
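
      Conceptually, pruning rewrites the feeds and fetches into special endpoint nodes and then removes every node that cannot reach a fetch or target. A minimal self-contained sketch of that reachability step (illustrative only; the names and data layout are invented, not TensorFlow's):

      #include <queue>
      #include <set>
      #include <vector>
      
      // Keep only the nodes from which a fetch/target node is reachable, by
      // walking input edges backwards from the fetch/target set.
      std::set<int> PruneSketch(const std::vector<std::vector<int>>& inputs_of,
                                const std::vector<int>& fetch_nodes) {
        std::set<int> keep(fetch_nodes.begin(), fetch_nodes.end());
        std::queue<int> pending;
        for (int n : fetch_nodes) pending.push(n);
        while (!pending.empty()) {
          int n = pending.front();
          pending.pop();
          for (int producer : inputs_of[n]) {  // every node feeding n
            if (keep.insert(producer).second) pending.push(producer);
          }
        }
        return keep;  // everything outside this set is dropped from the graph
      }

      The full BuildGraph is as follows.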

      Status GraphExecutionState::BuildGraph(const BuildGraphOptions& options,
                                             std::unique_ptr<ClientGraph>* out) {
        // Grappler optimization might change the structure of a graph itself, and
        // also it can add/prune functions to/from the library.
        std::unique_ptr<Graph> optimized_graph;
        std::unique_ptr<FunctionLibraryDefinition> optimized_flib;
      
        Status s = OptimizeGraph(options, *graph_, flib_def_.get(), &optimized_graph,
                                 &optimized_flib);
        if (!s.ok()) {
          // Simply copy the original graph and the function library if we couldn't
          // optimize it.
          optimized_graph.reset(new Graph(flib_def_.get()));
          CopyGraph(*graph_, optimized_graph.get());
          optimized_flib.reset(new FunctionLibraryDefinition(*flib_def_));
        }
      
        subgraph::RewriteGraphMetadata rewrite_metadata;
        if (session_options_ == nullptr ||
            !session_options_->config.graph_options().place_pruned_graph()) {
          TF_RETURN_IF_ERROR(  // PruneGraph does the pruning
              PruneGraph(options, optimized_graph.get(), &rewrite_metadata));
        } else {
          // This GraphExecutionState represents a graph that was
          // pruned when this was constructed, so we copy the metadata from
          // a member variable.
          CHECK(rewrite_metadata_);
          rewrite_metadata = *rewrite_metadata_;
        }
      
        GraphOptimizationPassOptions optimization_options;
        optimization_options.session_options = session_options_;
        optimization_options.graph = &optimized_graph;
        optimization_options.flib_def = optimized_flib.get();
        optimization_options.device_set = device_set_;
      
        TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
            OptimizationPassRegistry::POST_REWRITE_FOR_EXEC, optimization_options));
      
        int64_t collective_graph_key = options.collective_graph_key;
        if (collective_graph_key == BuildGraphOptions::kNoCollectiveGraphKey) {
          // BuildGraphOptions does not specify a collective_graph_key.  Check all
          // nodes in the Graph and FunctionLibraryDefinition for collective ops and
          // if found, initialize a collective_graph_key as a hash of the ordered set
          // of instance keys.
          std::set<int32> instance_key_set;
          bool has_collective_v2 = false;
          for (Node* node : optimized_graph->nodes()) {
            if (node->IsCollective()) {
              int32_t instance_key;
              TF_RETURN_IF_ERROR(
                  GetNodeAttr(node->attrs(), "instance_key", &instance_key));
              instance_key_set.emplace(instance_key);
            } else if (IsCollectiveV2(node->type_string())) {
              has_collective_v2 = true;
            } else {
              const FunctionDef* fdef = optimized_flib->Find(node->def().op());
              if (fdef != nullptr) {
                for (const NodeDef& ndef : fdef->node_def()) {
                  if (ndef.op() == "CollectiveReduce" ||
                      ndef.op() == "CollectiveBcastSend" ||
                      ndef.op() == "CollectiveBcastRecv" ||
                      ndef.op() == "CollectiveGather") {
                    int32_t instance_key;
                    TF_RETURN_IF_ERROR(
                        GetNodeAttr(ndef, "instance_key", &instance_key));
                    instance_key_set.emplace(instance_key);
                  } else if (IsCollectiveV2(ndef.op())) {
                    has_collective_v2 = true;
                  }
                }
              }
            }
          }
          if (!instance_key_set.empty()) {
            uint64 hash = 0x8774aa605c729c72ULL;
            for (int32_t instance_key : instance_key_set) {
              hash = Hash64Combine(instance_key, hash);
            }
            collective_graph_key = hash;
          } else if (has_collective_v2) {
            collective_graph_key = 0x8774aa605c729c72ULL;
          }
        }
      
        // Make collective execution order deterministic if needed.
        if (options.collective_order != GraphCollectiveOrder::kNone) {
          TF_RETURN_IF_ERROR(
              OrderCollectives(optimized_graph.get(), options.collective_order));
        }
      
        // Copy the extracted graph in order to make its node ids dense,
        // since the local CostModel used to record its stats is sized by
        // the largest node id.
        std::unique_ptr<ClientGraph> dense_copy(
            new ClientGraph(std::move(optimized_flib), rewrite_metadata.feed_types,
                            rewrite_metadata.fetch_types, collective_graph_key));
        CopyGraph(*optimized_graph, &dense_copy->graph);
      
        metrics::UpdateGraphBuildTime(Env::Default()->NowMicros() - start_time_usecs);
        *out = std::move(dense_copy);
        return Status::OK();
      }
      

      2.3 Partitioning & Registration

      2.3.1 Principle

      Because a single device has neither enough compute power nor enough memory, large models must be sharded: in essence, the model and its computation are split up and assigned to different devices.

      TensorFlow's Placement mechanism solves this model-sharding problem: its job is to mark which operation is placed on which device. The term (and mechanism) Placement probably originates with Google Spanner, where it manages cross-region data migration and carries some load-balancing meaning. TF's Placement borrows that idea; its principles are: satisfy the user's requirements as far as possible; prefer faster devices; prefer locality to avoid copies; and ensure the placed program can actually run.

      Once Placement completes, every node carries placement information, and the Partition function can split the computation graph based on that information.

      2.3.2 Configuration

      BuildAndRegisterPartitions calls RegisterPartitions to partition the graph and register the pieces. Our first concern is how the partitioning is configured: as the code shows, SplitByWorker is set as the partitioning criterion.

      Status MasterSession::BuildAndRegisterPartitions(ReffedClientGraph* rcg) {
        // Configure the partitioning
        PartitionOptions popts;
        popts.node_to_loc = SplitByWorker;  // partition by worker
        popts.new_name = [this](const string& prefix) {
          mutex_lock l(mu_);
          return strings::StrCat(prefix, "_S", next_node_id_++);
        };
        popts.get_incarnation = [this](const string& name) -> int64 {
          Device* d = devices_->FindDeviceByName(name);
          if (d == nullptr) {
            return PartitionOptions::kIllegalIncarnation;
          } else {
            return d->attributes().incarnation();
          }
        };
        popts.control_flow_added = false;  // let Partition add the control-flow edges
        const bool enable_bfloat16_sendrecv =
            session_opts_.config.graph_options().enable_bfloat16_sendrecv();
        // whether (and to what) to cast on send/recv
        popts.should_cast = [enable_bfloat16_sendrecv](const Edge* e) {
          if (e->IsControlEdge()) {
            return DT_FLOAT;
          }
          DataType dtype = BaseType(e->src()->output_type(e->src_output()));
          if (enable_bfloat16_sendrecv && dtype == DT_FLOAT) {
            return DT_BFLOAT16;
          } else {
            return dtype;
          }
        };
        if (session_opts_.config.graph_options().enable_recv_scheduling()) {
          popts.scheduling_for_recvs = true;
          popts.need_to_record_start_times = true;
        }
      
        // Partition and register the subgraphs
        TF_RETURN_IF_ERROR(rcg->RegisterPartitions(std::move(popts)));
      
        return Status::OK();
      }
      

      The SplitByWorker function is as follows; it returns the task portion of a node's assigned device name as the partition key.

      static string SplitByWorker(const Node* node) {
        string task;
        string device;
        CHECK(DeviceNameUtils::SplitDeviceName(node->assigned_device_name(), &task,
                                               &device))
            << "node: " << node->name() << " dev: " << node->assigned_device_name();
        return task;
      }
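
      For illustration, a small self-contained sketch (not TensorFlow code; the parsing is simplified) of the partition key SplitByWorker derives from an assigned device name:

      #include <iostream>
      #include <string>
      
      // Simplified stand-in for DeviceNameUtils::SplitDeviceName: everything
      // before "/device:" is the task (the partition key), the remainder is the
      // device. The real helper parses the name field by field.
      bool SplitDeviceNameSketch(const std::string& full, std::string* task,
                                 std::string* device) {
        const std::string marker = "/device:";
        const auto pos = full.find(marker);
        if (pos == std::string::npos) return false;
        *task = full.substr(0, pos);
        *device = full.substr(pos + marker.size());
        return true;
      }
      
      int main() {
        std::string task, device;
        SplitDeviceNameSketch("/job:worker/replica:0/task:1/device:GPU:0", &task,
                              &device);
        std::cout << task << "\n";  // prints /job:worker/replica:0/task:1
      }

      Every node assigned to the same task, regardless of which of its devices, therefore lands in the same partition.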
      

      BuildAndRegisterPartitions then calls RegisterPartitions, which calls DoBuildPartitions to build the partitions and DoRegisterPartitions to register them.

      Status MasterSession::ReffedClientGraph::RegisterPartitions(
          PartitionOptions popts) {
        {  // Ensure register once.
          mu_.lock();
          if (client_graph_before_register_) {
            // The `ClientGraph` is no longer needed after partitions are registered.
            // Since it can account for a large amount of memory, we consume it here,
            // and it will be freed after concluding with registration.
      
            std::unique_ptr<ClientGraph> client_graph;
            std::swap(client_graph_before_register_, client_graph);
            mu_.unlock();
            std::unordered_map<string, GraphDef> graph_defs;
            popts.flib_def = client_graph->flib_def.get();
            
            // Build the partitions
            Status s = DoBuildPartitions(popts, client_graph.get(), &graph_defs);
            if (s.ok()) {
              // NOTE(mrry): The pointers in `graph_defs_for_publishing` do not remain
              // valid after the call to DoRegisterPartitions begins, so
              // `stats_publisher_` must make a copy if it wants to retain the
              // GraphDef objects.
              std::vector<const GraphDef*> graph_defs_for_publishing;
              graph_defs_for_publishing.reserve(partitions_.size());
              for (const auto& name_def : graph_defs) {
                graph_defs_for_publishing.push_back(&name_def.second);
              }
              
              stats_publisher_->PublishGraphProto(graph_defs_for_publishing);
              
            // Register the partitions
              s = DoRegisterPartitions(popts, std::move(graph_defs));
            }
            mu_.lock();
            init_result_ = s;
            init_done_.Notify();
          } else {
            mu_.unlock();
            init_done_.WaitForNotification();
            mu_.lock();
          }
          const Status result = init_result_;
          mu_.unlock();
          return result;
        }
      }
      

      2.3.3 Partitioning

      DoBuildPartitions calls Partition, where the actual splitting happens.

      #include "tensorflow/core/graph/graph_partition.h"
      
      Status MasterSession::ReffedClientGraph::DoBuildPartitions(
          PartitionOptions popts, ClientGraph* client_graph,
          std::unordered_map<string, GraphDef>* out_partitions) {
        if (popts.need_to_record_start_times) {
          CostModel cost_model(true);
          cost_model.InitFromGraph(client_graph->graph);
          // TODO(yuanbyu): Use the real cost model.
          // execution_state_->MergeFromGlobal(&cost_model);
          SlackAnalysis sa(&client_graph->graph, &cost_model);
          sa.ComputeAsap(&popts.start_times);
        }
      
        // Partition the graph.
        return Partition(popts, &client_graph->graph, out_partitions);
      }
      
      2.3.3.1 Partition

      The main logic of Partition is:

      • Split the original graph into multiple subgraphs.
      • If nodes on different devices depend on each other, insert Send/Recv node pairs.
      • If needed, insert control-flow edges.

      In more detail:

      • Analyze the original graph and complete the control-flow edges.
        • Add "code" for distributed execution of control flow. Code is added only for frames placed on multiple devices. The new graph is an equivalent transformation of the original, with the property that it can subsequently be partitioned arbitrarily (down to the single-device level) for distributed execution.
      • Build memory/device information for every operator's nodes and edges, also in preparation for partitioning.
        • TF wants tensors participating in computation placed on the device and tensors participating in control placed on the host, so every op must be analyzed to determine its CPU or GPU version and the memory types of its input and output tensors. For example, some ops live on the GPU yet still need to read data from the CPU, and some data must be forced onto the CPU because it is GPU-unfriendly.
      • Walk the nodes of the graph, analyzing and splitting, inserting Send/Recv nodes and control edges, to finally obtain the subgraphs.
        • Take a node dst from the original graph, get dst's location, use the location to find its GraphDef within partitions, add the Node, and set its device.
        • Work out dst's input edges in the original graph and, together with the control edges, push them into the inputs array.
        • Take one of dst's input edges, get the edge's src node, and get the graph src belongs to.
          • If src and dst are in the same graph, they share a partition and have compatible memory types, so connect src and dst inside that graph and move on to dst's next edge.
          • If src and dst are not in the same graph, communication is needed: build a communication key from the edge, src, etc., and look up a Recv node in the cache by that key. If one is found, connect it to dst and move on to dst's next edge.
          • If there is a control edge, the dependency crosses devices and must be expressed equivalently across them. Although a control edge transfers no real tensor, a message must still be sent so the receiver knows the dependency exists: insert a dummy const node on src's device, insert an Identity node on the receiver to read the shape-{0} dummy const, and make the Identity a control dependency of the receiver.
          • Add the Send node and the Recv node.
          • Do further fix-ups of the control/data relationships.
            • For a send/recv pair on the same device there is a data copy between them, so add a control edge from send to recv. This keeps the asynchronous recv kernel from being scheduled before the data is available, preserving execution order.
            • Otherwise the pair is cross-device, and the control edge must be redirected to the real recv node following the data flow.
      • Finishing touches, such as filling in each subgraph's version information, function library, and the incarnation of the send/recv nodes.

      For example, after splitting:

      Figure 2: Partitioning the computation graph (from TensorFlow)

      After inserting the Send/Recv nodes:

      Figure 3: Inserted nodes (from TensorFlow)

      The Partition code, heavily abridged, is as follows.

      Status Partition(const PartitionOptions& opts, Graph* g,
                       std::unordered_map<string, GraphDef>* partitions) {
        Status status;
        partitions->clear();
      
        GraphInfo g_info;
        if (!opts.control_flow_added) {
          // Analyze the original graph and complete the control-flow edges.
          // Add "code" for distributed execution of control flow, only for frames placed on multiple devices. The new graph is an equivalent transformation of the original and can subsequently be partitioned arbitrarily (down to the single-device level) for distributed execution.
          status = AddControlFlow(opts, g, &g_info);
          if (!status.ok()) return status;
        }
      
        // At this point, all the graph mutations have been done. Build memory
        // and device type info for every node and edge in the graph.
        // Build Memory/Device info for every operator's nodes/edges, also in preparation for partitioning.
        // TF wants compute tensors on the device and control tensors on the host, so each op is analyzed to pick its CPU or GPU version and to determine the memory types of its inputs and outputs: some ops on the GPU still read data from the CPU, and some data is forced onto the CPU because it is GPU-unfriendly.
        status = BuildMemoryDeviceInfo(*g, &g_info);
        if (!status.ok()) return status;
      
        string dstp;
        std::vector<const Edge*> inputs;
        DupRecvTable dup_recv(3);
        // For a node dst, 'ref_recvs' are the recvs introduced to dst by ref edges, and 'ref_control_inputs' are dst's inputs over non-ref edges.
        // For every pair in (ref_recvs x ref_control_inputs) we add a control edge.
        std::vector<NodeDef*> ref_recvs;
        std::vector<string> ref_control_inputs;
      
        int32_t num_data = 0;
        int32_t num_control = 0;
        for (const Node* dst : g->op_nodes()) {  // Walk the nodes: analyze, split, insert Send/Recv nodes and control edges
          // Take a node dst from the original graph
          dstp = opts.node_to_loc(dst);  // Get dst's location
          GraphDef* dst_graph = &(*partitions)[dstp];  // Use the location to find dst's GraphDef within partitions
          NodeDef* dst_def = dst_graph->add_node();  // Add the Node
          *dst_def = dst->def();
          dst_def->set_device(dst->assigned_device_name());  // Set the device
          dst_def->clear_input();  // Inputs are filled below
      
          // Arrange the incoming edges to dst so that input[i] holds the
          // input flowing into slot numbered i. Trailing entries in input[]
          // hold control edges.
          // Work out dst's input edges from the original graph and push them, together with the control edges, into the inputs array.
          inputs.clear();
          inputs.resize(dst->num_inputs(), nullptr);
          ref_recvs.clear();
          ref_control_inputs.clear();
          const Edge* control_flow_edge = nullptr;
          int32_t num_control_flow_edges = 0;
          int32_t num_input_edges = 0;
          for (const Edge* edge : dst->in_edges()) {
            if (edge->IsControlEdge()) {
              if (IsMerge(edge->src()) && IsControlLoop(edge->src())) {
                // This is one of the control edges added for control flow. There
                // can be multiple such edges as the dest node may have multiple
                // remote inputs. We keep track of the number of such edges.
                control_flow_edge = edge;
                ++num_control_flow_edges;
              } else {
                inputs.push_back(edge);
              }
            } else {
              DCHECK(inputs[edge->dst_input()] == nullptr);
              inputs[edge->dst_input()] = edge;
              ++num_input_edges;
            }
          }
      
          // Process in order so that all data edges are added as inputs to
          // dst in Edge::dst_input() order.
          for (const Edge* edge : inputs) {  // Take one of dst's edges
            const Node* src = edge->src();  // Get the edge's src node
            if (!src->IsOp()) continue;  // Skip Sink/Source nodes.
      
            GraphDef* src_graph = &(*partitions)[opts.node_to_loc(src)];  // Uses the configured SplitByWorker (or SplitByDevice) to get src's partition graph
            if (src_graph == dst_graph && !NeedSameDeviceSendRecv(edge, g_info)) {
              // Same graph: same partition with compatible memory types, so connect src and dst inside this graph
              // Same partition and compatible memory types:
              AddInput(dst_def, src->name(), edge->src_output());
              if (edge->IsControlEdge() ||
                  !IsRefType(src->output_type(edge->src_output()))) {
                ref_control_inputs.push_back(src->name());
              }
              continue;  // Move on to dst's next edge
            }
      
            // Check whether there is already a send/recv pair transferring
            // the same tensor/control from the src to dst partition.
            const bool on_host = IsDstInputOnHost(edge, g_info);
            // src and dst are not in the same graph, so communication is needed: build a communication key from the edge, src, etc.
            DupRecvKey key{src->id(), edge->src_output(), dst_graph, on_host};
            auto iter = dup_recv.find(key);  // Look up a Recv node in the cache by key
            if (iter != dup_recv.end()) {  // Found one: connect the Recv node to dst
              // We found one. Reuse the data/control transferred already.
              const string& recv_node_name = iter->second.recv->name();
              if (edge->IsControlEdge()) {
                AddInput(dst_def, recv_node_name, Graph::kControlSlot);
              } else {
                AddInput(dst_def, recv_node_name, 0);
              }
              ref_control_inputs.push_back(recv_node_name);
              continue;  // Move on to dst's next edge
            }
      
            // Add the Send and Recv nodes
            NodeDefBuilder::NodeOut send_from;  // Describes what the Send will read from
            if (edge->IsControlEdge()) {
              // Insert a dummy const node that will generate a tiny
              // data element to be sent from send to recv.
              // A control edge that crosses devices must have its dependency expressed equivalently across them.
              // Although a control edge transfers no real tensor, a message must still be sent so the receiver knows the dependency exists: insert a dummy const node on src's device, insert an Identity node on the receiver to read the shape-{0} dummy const, and make the Identity a control dependency of the receiver.
              NodeDef* dummy = AddDummyConst(opts, src_graph, edge, &status);
              if (!status.ok()) return status;
              AddInput(dummy, src->name(), Graph::kControlSlot);
              send_from.Reset(dummy->name(), 0, DT_FLOAT);
            } else {
              send_from.Reset(src->name(), edge->src_output(), EdgeType(edge));
            }
      
            // Need to split edge by placing matching send/recv nodes on
            // the src/dst sides of the edge.
            NodeDef* send = AddSend(opts, g_info, src_graph, edge, send_from,
                                    send_start_time, &status);
            if (!status.ok()) return status;
      
            NodeDef* real_recv = nullptr;
            NodeDef* recv =
                AddRecv(opts, g_info, dst_graph, edge, &real_recv, &status);
            if (!status.ok()) return status;
      
             if (src_graph == dst_graph) {
              // For same device send/recv, add a control edge from send to recv.
              // This prevents the asynchronous recv kernel from being scheduled
              // before the data is available.
              // A same-device send/recv pair involves a data copy, so add a control edge from send to recv; this keeps the asynchronous recv kernel from being scheduled before the data is available, preserving execution order.
              AddInput(real_recv, send->name(), Graph::kControlSlot);
            } else if (control_flow_edge != nullptr) {
              // Redirect control edge to the real recv since this is not the same
              // device send/recv.
              // Otherwise the pair is cross-device: redirect the control edge to the real recv node following the data flow.
              --num_control_flow_edges;
              AddInput(real_recv, control_flow_edge->src()->name(),
                       Graph::kControlSlot);
            }
      
            if (!edge->IsControlEdge() &&
                IsRefType(src->output_type(edge->src_output()))) {
              // If src is of ref type and the edge is not a control edge, dst has
              // read semantics and therefore we must control the recv.
              ref_recvs.push_back(real_recv);
            } else {
              // Memorize the send/recv pair, only if this is not a "ref" edge.
              // NOTE(yuanbyu): Collapsing ref edges requires extreme care so
              // for now we don't do it.
              dup_recv[key] = {recv, real_recv, recv_start_time};
              ref_control_inputs.push_back(recv->name());
            }
      
            if (edge->IsControlEdge()) {
              ++num_control;
              AddInput(dst_def, recv->name(), Graph::kControlSlot);
            } else {
              ++num_data;
              AddInput(dst_def, recv->name(), 0);
            }
          }
      
          // Add control edges from 'ref_control_inputs' to 'ref_recvs'.
          // NOTE(yuanbyu): Adding these control edges should not introduce
          // deadlocks. 'dst' has implicit "read" nodes that, when we split
          // across devices, are made explicit; Retargeting the dependencies
          // to 'dst' to those nodes would not introduce cycles if there isn't
          // one before the transformation.
          // NOTE(yuanbyu): This may impact performance because it defers the
          // execution of recvs until all the other inputs become available.
          AddReadControl(ref_recvs, ref_control_inputs);
      
          // Add back the control edges for control flow that are not used.
          if (control_flow_edge != nullptr) {
            for (int i = 0; i < num_control_flow_edges; ++i) {
              AddInput(dst_def, control_flow_edge->src()->name(),
                       Graph::kControlSlot);
            }
          }
        }
      
        // Finishing touches: fill in each subgraph's versions, function library, and send/recv incarnations
        const FunctionLibraryDefinition* flib_def = opts.flib_def;
        if (flib_def == nullptr) {
          flib_def = &g->flib_def();
        }
      
        // Set versions, function library and send/recv incarnation.
        for (auto& it : *partitions) {
          GraphDef* gdef = &it.second;
          *gdef->mutable_versions() = g->versions();
          // Prune unreachable functions from `flib_def` before adding them to `gdef`.
          *gdef->mutable_library() = flib_def->ReachableDefinitions(*gdef).ToProto();
      
          // Traverse the graph to fill every send/recv op's incarnation
          // information.
          SetIncarnation(opts, gdef);
        }
      
        return Status::OK();
      }
      

      Some of the helper functions used by Partition follow.

      2.3.3.2 AddDummyConst

      If there is a control edge that crosses devices, the dependency must be expressed equivalently across them. Although a control edge transfers no real tensor, a message must still be sent so the receiver knows the dependency exists.

      So a dummy const node is inserted on src's device to express this downstream control dependency, and an Identity node is inserted on the receiver to read the shape-{0} dummy const; the Identity is also made a control dependency of the receiver. The dummy const node plays the producer role and the Identity the consumer role, which satisfies the cross-device communication requirement.

      NodeDef* AddDummyConst(const PartitionOptions& opts, GraphDef* gdef,
                             const Edge* edge, Status* status) {
        const Node* src = edge->src();
        Tensor tensor(DT_FLOAT, TensorShape({0}));
        NodeDef* result = gdef->add_node();
        *status = NodeDefBuilder(opts.new_name(src->name()), "Const")
                      .Device(src->assigned_device_name())
                      .Attr("dtype", DT_FLOAT)
                      .Attr("value", tensor)
                      .Finalize(result, /*consume=*/true);
        return result;
      }
      
      2.3.3.3 AddSend

      If src and dst belong to two different partitions, the ordinary edge between them must be cut and Send and Recv nodes inserted in the middle, so that each half can be assigned to its own partition.

      NodeDef* AddSend(const PartitionOptions& opts, const GraphInfo& g_info,
                       GraphDef* gdef, const Edge* edge,
                       NodeDefBuilder::NodeOut send_from, int64_t start_time,
                       Status* status) {
        const DataType dtype = send_from.data_type;
        const DataType cast_dtype = opts.should_cast ? opts.should_cast(edge) : dtype;
        const Node* src = edge->src();
        const int src_port = edge->src_output();
      
        // host_memory = true iff we need to use HostSend/HostCast.
        bool host_memory = false;
        if (!edge->IsControlEdge()) {
          auto src_it = g_info.output_types.find({src->id(), src_port});
          host_memory = (src_it->second == HOST_MEMORY);
        }
      
        // Add a cast node that casts dtype to cast_dtype.
        // NOTE(yuanbyu): Only cast for cross-device send/recv.
        if (dtype != cast_dtype && !NeedSameDeviceSendRecv(edge, g_info)) {
          const string cast_op = (host_memory) ? "_HostCast" : "Cast";
          NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                      NodeDebugInfo(*src));
          cast_builder.Device(src->assigned_device_name()).Input(send_from);
          cast_builder.Attr("DstT", cast_dtype);
      
          if (cast_dtype == DT_BFLOAT16) {
            // the below attribute specifies that the cast to bfloat16 should use
            // truncation. This is needed to retain legacy behavior when we change
            // the default bfloat16 casts to use rounding instead of truncation
            cast_builder.Attr("Truncate", true);
          }
      
          NodeDef* cast = gdef->add_node();
          *status = cast_builder.Finalize(cast, /*consume=*/true);
          if (!status->ok()) return nullptr;
      
          // Connect the Send op to the cast.
          send_from.Reset(cast->name(), 0, cast_dtype);
        }
      
        // Add the send node.
        const string send_op = (host_memory) ? "_HostSend" : "_Send";
        NodeDefBuilder send_builder(opts.new_name(src->name()), send_op,
                                    NodeDebugInfo(*src));
        SetSendRecvAttrs(opts, edge, &send_builder);
        send_builder.Device(src->assigned_device_name()).Input(send_from);
      
        NodeDef* send = gdef->add_node();
        *status = send_builder.Finalize(send, /*consume=*/true);
        return send;
      }
      
      2.3.3.4 AddRecv

      The Identity node mentioned earlier, inserted on the receiver to read the shape-{0} dummy const and made a control dependency of the receiver, is implemented here. Identity is the identity transformation: it outputs the tensor directly, which both strips the ref designation from variables and avoids a memory copy.

      NodeDef* AddRecv(const PartitionOptions& opts, const GraphInfo& g_info,
                       GraphDef* gdef, const Edge* edge, NodeDef** real_recv,
                       Status* status) {
        const DataType dtype = EdgeType(edge);
        const Node* src = edge->src();
        const Node* dst = edge->dst();
        const int dst_port = edge->dst_input();
        DataType cast_dtype = dtype;
      
        // NOTE(yuanbyu): Only cast for cross-device send/recv.
        if (opts.should_cast && !NeedSameDeviceSendRecv(edge, g_info)) {
          cast_dtype = opts.should_cast(edge);
        }
      
        // host_memory = true iff we need to use HostRecv/HostCast.
        // Also log the introduction of the send-recv pair, for performance debugging.
        bool host_memory = false;
        if (!edge->IsControlEdge()) {
          auto dst_it = g_info.input_types.find({dst->id(), dst_port});
          DCHECK(dst_it != g_info.input_types.end());
          host_memory = (dst_it->second == HOST_MEMORY);
          bool src_host_memory = false;
        } else {
          // Log control-edge transfers too, but don't mention memory space since it's
          // irrelevant.
          // logging omitted
        }
      
        // Add the recv node.
        const string recv_op = (host_memory) ? "_HostRecv" : "_Recv";
        NodeDefBuilder recv_builder(opts.new_name(src->name()), recv_op,
                                    NodeDebugInfo(*src));
        SetSendRecvAttrs(opts, edge, &recv_builder);
        recv_builder.Device(dst->assigned_device_name())
            .Attr("tensor_type", cast_dtype);
        NodeDef* recv = gdef->add_node();
        *status = recv_builder.Finalize(recv, /*consume=*/true);
        if (!status->ok()) return nullptr;
        *real_recv = recv;
      
        // Add the cast node (from cast_dtype to dtype) or an Identity node.
        if (dtype != cast_dtype) {
          const string cast_op = (host_memory) ? "_HostCast" : "Cast";
          NodeDefBuilder cast_builder(opts.new_name(src->name()), cast_op,
                                      NodeDebugInfo(*src));
          cast_builder.Attr("DstT", dtype);
          cast_builder.Device(dst->assigned_device_name())
              .Input(recv->name(), 0, cast_dtype);
          NodeDef* cast = gdef->add_node();
          *status = cast_builder.Finalize(cast, /*consume=*/true);
          if (!status->ok()) return nullptr;
          return cast;
        } else if (edge->IsControlEdge()) {
          // An Identity is only needed for control edges.
          // This is where the "Identity" is added.
          NodeDefBuilder id_builder(opts.new_name(src->name()), "Identity",
                                    NodeDebugInfo(*src));
          id_builder.Device(dst->assigned_device_name())
              .Input(recv->name(), 0, cast_dtype);
          NodeDef* id = gdef->add_node();
          *status = id_builder.Finalize(id, /*consume=*/true);
          if (!status->ok()) return nullptr;
          return id;
        } else {
          return recv;
        }
      }
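
      Putting AddDummyConst, AddSend, and AddRecv together, a cross-device control edge A --(control)--> B ends up rewritten roughly as sketched below (node names are generated by opts.new_name; the layout is only illustrative):

      // On A's device:
      //   A --(control)--> Const(dummy, shape {0}) --> _Send
      // On B's device:
      //   _Recv --> Identity --(control)--> B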
      
      2.3.3.5 AddInput

      AddInput adds an input to the downstream node.

      // Add an input to dst that comes from the "src_slot" output of the
      // node named by "src_name".
      void AddInput(NodeDef* dst, StringPiece src_name, int src_slot) {
        if (src_slot == Graph::kControlSlot) {
          dst->add_input(strings::StrCat("^", src_name));
        } else if (src_slot == 0) {
          dst->add_input(src_name.data(), src_name.size());
        } else {
          dst->add_input(strings::StrCat(src_name, ":", src_slot));
        }
      }
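
      For example (hypothetical node name "src"), the three encodings AddInput writes into the NodeDef's input list are:

      // AddInput(dst, "src", Graph::kControlSlot);  // adds "^src"  (control input)
      // AddInput(dst, "src", 0);                    // adds "src"   (output slot 0)
      // AddInput(dst, "src", 3);                    // adds "src:3" (output slot 3)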
      
      2.3.3.6 AddReadControl

      AddReadControl establishes the control dependencies simply via add_input.

      // Add a control edge from each input to each recv.
      void AddReadControl(const std::vector<NodeDef*>& recvs,
                          const std::vector<string>& inputs) {
        for (NodeDef* recv : recvs) {
          for (const string& input : inputs) {
            recv->add_input(strings::StrCat("^", input));
          }
        }
      }
      

      2.3.4 Registration

      With partitioning complete, we come to the registration stage.

      2.3.4.1 DoRegisterPartitions

      DoRegisterPartitions decides which worker is responsible for which partition. The key calls are:

      • part->worker = worker_cache_->GetOrCreateWorker(part->name) sets each partition's worker.

      • part.worker->RegisterGraphAsync(&c->req, &c->resp, cb) registers the graph.

      Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
          const PartitionOptions& popts,
          std::unordered_map<string, GraphDef> graph_partitions) {
        partitions_.reserve(graph_partitions.size());
        Status s;
        for (auto& name_def : graph_partitions) {
          partitions_.emplace_back();
          Part* part = &partitions_.back();
          part->name = name_def.first;
          TrackFeedsAndFetches(part, name_def.second, popts);
          part->worker = worker_cache_->GetOrCreateWorker(part->name);
          if (part->worker == nullptr) {
            s = errors::NotFound("worker ", part->name);
            break;
          }
        }
        if (!s.ok()) {
          for (Part& part : partitions_) {
            worker_cache_->ReleaseWorker(part.name, part.worker);
            part.worker = nullptr;
          }
          return s;
        }
        struct Call {
          RegisterGraphRequest req;
          RegisterGraphResponse resp;
          Status status;
        };
        const int num = partitions_.size();
        gtl::InlinedVector<Call, 4> calls(num);
        BlockingCounter done(num);
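        // One RegisterGraphAsync RPC is issued per partition below; the
        // BlockingCounter lets this thread wait until every callback has fired.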
        for (int i = 0; i < num; ++i) {
          const Part& part = partitions_[i];
          Call* c = &calls[i];
          c->req.set_session_handle(session_handle_);
          c->req.set_create_worker_session_called(!should_deregister_);
          c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
          StripDefaultAttributes(*OpRegistry::Global(),
                                 c->req.mutable_graph_def()->mutable_node());
          *c->req.mutable_config_proto() = session_opts_.config;
          *c->req.mutable_graph_options() = session_opts_.config.graph_options();
          *c->req.mutable_debug_options() =
              callable_opts_.run_options().debug_options();
          c->req.set_collective_graph_key(collective_graph_key_);
      
          auto cb = [c, &done](const Status& s) {
            c->status = s;
            done.DecrementCount();
          };
          part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
        }
        done.Wait();
        for (int i = 0; i < num; ++i) {
          Call* c = &calls[i];
          s.Update(c->status);
          partitions_[i].graph_handle = c->resp.graph_handle();
        }
        return s;
      }
      
      2.3.4.2 GrpcRemoteWorker

      The part.worker->RegisterGraphAsync call above lands in GrpcRemoteWorker, which finally sends the RegisterGraphRequest to the downstream Worker.

      In tensorflow/core/distributed_runtime/rpc/grpc_remote_worker.cc, RegisterGraphAsync issues the RPC.

      void RegisterGraphAsync(const RegisterGraphRequest* request,
                              RegisterGraphResponse* response,
                              StatusCallback done) override {
        IssueRequest(request, response, registergraph_, std::move(done));
      }
      

      Note that unless the graph's nodes are rearranged or the Master process is restarted, the Master executes RegisterGraph only once. Conceptually:

      Figure 4: Registering graphs (from TensorFlow)

      2.4 Executing the computation graph

      With partitioning done and the subgraphs registered on the remote Workers, each worker owns its own subgraph; the next step is to run them.

      The Master triggers subgraph execution on the Workers by calling RunGraph, and the Workers execute the TensorFlow kernels on their GPU/CPU devices. Between Workers and devices, different transports are used depending on the situation:

      • cudaMemcpyAsync between a local GPU and the local CPU.
      • peer-to-peer DMA between local GPUs.
      • gRPC (over TCP) and RDMA (over Converged Ethernet) between Workers.

      Figure 5: Running the subgraphs

      2.4.1 RunPartitions

      RunPartitions calls RunPartitionsHelper to execute the subgraphs. (The overload shown below is the RunCallable variant; the RunStep variant used above follows the same pattern.)

      Status MasterSession::ReffedClientGraph::RunPartitions(
          const MasterEnv* env, int64_t step_id, int64_t execution_count,
          PerStepState* pss, CallOptions* call_opts, const RunCallableRequest& req,
          RunCallableResponse* resp, CancellationManager* cm) {
      
        // Maps the names of fed tensors to their index in `req`.
        std::unordered_map<StringPiece, size_t, StringPieceHasher> feeds(3);
        for (size_t i = 0, end = callable_opts_.feed_size(); i < end; ++i) {
          if (!feeds.insert({callable_opts_.feed(i), i}).second) {
            // MakeCallable will fail if there are two feeds with the same name.
            return errors::Internal("Duplicated feeds in callable: ",
                                    callable_opts_.feed(i));
          }
        }
      
        // Create a wrapped response object to collect the fetched values and
        // rearrange them for the RunCallableResponse.
        RunCallableResponseWrapper wrapped_resp;
        wrapped_resp.resp = resp;
      
        // The actual execution is invoked here
        TF_RETURN_IF_ERROR(RunPartitionsHelper(
            feeds, callable_opts_.fetch(), env, step_id, execution_count, pss,
            call_opts, req, &wrapped_resp, cm, false /* is_last_partial_run */));
      
        // Collects fetches.
        for (const string& fetch : callable_opts_.fetch()) {
          TensorProto* fetch_proto = resp->mutable_fetch()->Add();
          auto iter = wrapped_resp.fetch_key_to_protos.find(fetch);
          if (iter == wrapped_resp.fetch_key_to_protos.end()) {
            return errors::Internal("Worker did not return a value for fetch: ",
                                    fetch);
          }
          fetch_proto->Swap(&iter->second);
        }
        return Status::OK();
      }
      

      2.4.2 RunPartitionsHelper

      RunPartitionsHelper executes the subgraphs. Its logic is:

      • Configure one RunManyGraphs::Call per partition: set up its request, response, session handle, graph handle, request id, and recv keys.
      • Send RunGraphAsync to each worker.
        • One subgraph is assigned to one worker, corresponding to one worker service.
        • part.worker is the WorkerInterface object of each partition: a GrpcRemoteWorker instance if remote, otherwise a Worker instance.
      • Register the various callbacks and wait for the RunGraphAsync results.
      • Process the results.

      template <class FetchListType, class ClientRequestType,
                class ClientResponseType>
      Status MasterSession::ReffedClientGraph::RunPartitionsHelper(
          const std::unordered_map<StringPiece, size_t, StringPieceHasher>& feeds,
          const FetchListType& fetches, const MasterEnv* env, int64_t step_id,
          int64_t execution_count, PerStepState* pss, CallOptions* call_opts,
          const ClientRequestType& req, ClientResponseType* resp,
          CancellationManager* cm, bool is_last_partial_run) {
        // Collect execution cost stats on a smoothly decreasing frequency.
        ExecutorOpts exec_opts;
        // statistics code omitted
      
        const int num = partitions_.size();
        RunManyGraphs calls(num);
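        // RunManyGraphs (a helper defined in master_session.cc) holds one Call
        // per partition and supports cancelling and waiting on the whole group.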
      
        for (int i = 0; i < num; ++i) {
          // Configure a RunManyGraphs::Call for this partition
          const Part& part = partitions_[i];
          RunManyGraphs::Call* c = calls.get(i);
          c->worker_name = &part.name;
          c->req.reset(part.worker->CreateRunGraphRequest());  // Set up the request
          c->resp.reset(part.worker->CreateRunGraphResponse());  // Set up the response
          if (is_partial_) {
            c->req->set_is_partial(is_partial_);
            c->req->set_is_last_partial_run(is_last_partial_run);
          }
          c->req->set_session_handle(session_handle_);  // Set the session handle
          c->req->set_create_worker_session_called(!should_deregister_);
          c->req->set_graph_handle(part.graph_handle);  // Set the graph handle
          c->req->set_step_id(step_id);
          *c->req->mutable_exec_opts() = exec_opts;
          c->req->set_store_errors_in_response_body(true);
          c->req->set_request_id(GetUniqueRequestId());  // Set the request id
          // If any feeds are provided, send the feed values together
          // in the RunGraph request.
          // In the partial case, we only want to include feeds provided in the req.
          // In the non-partial case, all feeds in the request are in the part.
          // We keep these as separate paths for now, to ensure we aren't
          // inadvertently slowing down the normal run path.
          if (is_partial_) {
            for (const auto& name_index : feeds) {
              const auto iter = part.feed_key.find(string(name_index.first));
              if (iter == part.feed_key.end()) {
                // The provided feed must be for a different partition.
                continue;
              }
              const string& key = iter->second;
              TF_RETURN_IF_ERROR(AddSendFromClientRequest(req, c->req.get(),
                                                          name_index.second, key));
            }
            // TODO(suharshs): Make a map from feed to fetch_key to make this faster.
            // For now, we just iterate through partitions to find the matching key.
            for (const string& req_fetch : fetches) {
              for (const auto& key_fetch : part.key_fetch) {
                if (key_fetch.second == req_fetch) {
                  c->req->add_recv_key(key_fetch.first);  // Set a recv key
                  break;
                }
              }
            }
          } else {
            for (const auto& feed_key : part.feed_key) {
              const string& feed = feed_key.first;
              const string& key = feed_key.second;
              auto iter = feeds.find(feed);
              if (iter == feeds.end()) {
                return errors::Internal("No feed index found for feed: ", feed);
              }
              const int64_t feed_index = iter->second;
              TF_RETURN_IF_ERROR(
                  AddSendFromClientRequest(req, c->req.get(), feed_index, key));
            }
            for (const auto& key_fetch : part.key_fetch) {
              const string& key = key_fetch.first;
              c->req->add_recv_key(key);  // Set a recv key
            }
          }
        }
      
        // Issues RunGraph calls.
        for (int i = 0; i < num; ++i) {
          const Part& part = partitions_[i];
          RunManyGraphs::Call* call = calls.get(i);
          part.worker->RunGraphAsync(  // Send RunGraphAsync to each worker
              &call->opts, call->req.get(), call->resp.get(),
              std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
        }
      
        // Waits for the RunGraph calls.
        // Register the callbacks and wait for the results
        call_opts->SetCancelCallback([&calls]() {
          calls.StartCancel();
        });
        auto token = cm->get_cancellation_token();
        const bool success =
            cm->RegisterCallback(token, [&calls]() { calls.StartCancel(); });
        if (!success) {
          calls.StartCancel();
        }
        calls.Wait();
        call_opts->ClearCancelCallback();
        if (success) {
          cm->DeregisterCallback(token);
        } else {
          return errors::Cancelled("Step was cancelled");
        }
      
        // Collects fetches and metadata.
        // Process the run results
        Status status;
        for (int i = 0; i < num; ++i) {
          const Part& part = partitions_[i];
          MutableRunGraphResponseWrapper* run_graph_resp = calls.get(i)->resp.get();
          for (size_t j = 0; j < run_graph_resp->num_recvs(); ++j) {
            auto iter = part.key_fetch.find(run_graph_resp->recv_key(j));
            if (iter == part.key_fetch.end()) {
              status.Update(errors::Internal("Unexpected fetch key: ",
                                             run_graph_resp->recv_key(j)));
              break;
            }
            const string& fetch = iter->second;
            status.Update(
                resp->AddTensorFromRunGraphResponse(fetch, run_graph_resp, j));
            if (!status.ok()) {
              break;
            }
          }
          if (pss->collect_timeline) {
            pss->step_stats[i].Swap(run_graph_resp->mutable_step_stats());
          }
          if (pss->collect_costs) {
            CostGraphDef* cost_graph = run_graph_resp->mutable_cost_graph();
            for (int j = 0; j < cost_graph->node_size(); ++j) {
              resp->mutable_metadata()->mutable_cost_graph()->add_node()->Swap(
                  cost_graph->mutable_node(j));
            }
          }
          if (pss->collect_partition_graphs) {
            protobuf::RepeatedPtrField<GraphDef>* partition_graph_defs =
                resp->mutable_metadata()->mutable_partition_graphs();
            for (size_t i = 0; i < run_graph_resp->num_partition_graphs(); i++) {
              partition_graph_defs->Add()->Swap(
                  run_graph_resp->mutable_partition_graph(i));
            }
          }
        }
        return status;
      }
      

      2.4.3 GrpcRemoteWorker

      The code above notifies the remote Worker to run the subgraph:

      part.worker->RunGraphAsync(
          &call->opts, call->req.get(), call->resp.get(),
          std::bind(&RunManyGraphs::WhenDone, &calls, i, std::placeholders::_1));
      

      RunGraphAsync is defined in GrpcRemoteWorker; every GrpcRemoteWorker method calls IssueRequest() to initiate an asynchronous gRPC call.

      void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                         RunGraphResponse* response, StatusCallback done) override {
        IssueRequest(request, response, rungraph_, std::move(done), call_opts);
      }
      

      The GrpcWorkerService running on the remote end acts as a daemon and handles the incoming gRPC requests.

      The overall logic of DoRunWithLocalExecution is summarized below:

      Figure 6: Overall logic of DoRunWithLocalExecution

      2.5 Summary

      The run logic is summarized below. Note that two gRPC calls are involved: one register and one run. First, register is called to register the subgraphs on the remote Workers; then run is called to have the remote Workers execute the subgraph computation.

      Figure 7: Master dynamic logic (2)

      Next we will head over to the Worker to investigate.

      0xFF References

      [1]. Abadi M, Agarwal A, Barham P, et al. Tensorflow: Large-scale machine learning on heterogeneous distributed systems[J]. arXiv preprint arXiv:1603.04467, 2016.

      [2] TensorFlow's graph-partitioning module: Graph Partitioner

      [3] TensorFlow's placement heuristics module: Placer

      [4] Device management in TensorFlow: how Devices are created and registered
