      [Source Code Analysis] TensorFlow Distributed Environment (7) --- Worker Dynamic Logic


      In the previous article, the Master used gRPC at several points in the flow to send commands to the remote workers; that is, every function in the GrpcRemoteWorker class initiates an asynchronous gRPC call via IssueRequest(). GrpcRemoteWorker sent two requests in total, RegisterGraphAsync and RunGraphAsync. Let us see how GrpcWorkerService handles them.

      As before, this article draws heavily on the work of two experts:

      The other articles in this series are:

      [Translation] TensorFlow Distributed: the paper "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

      [Translation] TensorFlow Distributed: the paper "Implementation of Control Flow in TensorFlow"

      [Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

      [Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

      [Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic

      [Source Code Analysis] TensorFlow Distributed Environment (4) --- WorkerCache

      [Source Code Analysis] TensorFlow Distributed Environment (5) --- Session

      1. Overview

      1.1 Looking Back

      Let us first review the relationships among the concepts we have seen so far.

      • The Client builds the full computation graph (FullGraph), but this full graph cannot be executed in parallel as-is, so it has to be partitioned and optimized.
      • The Master processes the full graph, for example by pruning it, to produce the ClientGraph (the minimal dependency subgraph that can be executed). It then splits the ClientGraph into multiple PartitionGraphs according to the Worker information and registers these PartitionGraphs with the corresponding Workers.
      • After receiving the registration request, each Worker splits the received PartitionGraph again according to its local set of computing devices into multiple per-device PartitionGraphs, and starts an Executor on each device to run that device's PartitionGraph.

      1.2 Looking Ahead

      Next, let us sketch the Worker-side flow. When the flow reaches a specific Worker node and that worker receives a RegisterGraphRequest, the message carries the session_handle allocated by the MasterSession and the subgraph graph_def (in GraphDef form). GraphDef is the result of serializing the computation graph built by the Client with Protocol Buffers, and it contains all of the graph's metadata. It can be converted into a Graph by the ConvertGraphDefToGraph method. A Graph holds not only the graph metadata but also the additional information needed at run time.

      The Worker splits the graph further into multiple PartitionGraphs according to its local device set, assigns a PartitionGraph to each device, and then starts an Executor on each computing device, which waits for subsequent execution commands. The Executor class is TensorFlow's abstraction of the session executor; it provides the virtual RunAsync method, which executes a local graph asynchronously, as well as Run, its synchronous wrapper.
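
      To make the relationship between Run and RunAsync concrete, here is a minimal, self-contained sketch of the "synchronous wrapper over an asynchronous run" pattern (MiniExecutor and the int status code are made-up stand-ins; the real Executor::Run blocks on a Notification and returns a Status):

      #include <functional>
      #include <future>
      #include <iostream>
      #include <memory>
      #include <thread>

      // Hypothetical stand-in for Executor: RunAsync does the work on another
      // thread and reports back through a callback; Run is the blocking wrapper.
      struct MiniExecutor {
        void RunAsync(std::function<void(int)> done) {
          std::thread([done] { done(0 /* pretend status: OK */); }).detach();
        }
        int Run() {
          auto status = std::make_shared<std::promise<int>>();
          auto result = status->get_future();
          RunAsync([status](int s) { status->set_value(s); });
          return result.get();  // block until the async callback fires
        }
      };

      int main() {
        MiniExecutor executor;
        std::cout << "status = " << executor.Run() << std::endl;
      }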

      When the Worker node receives RunGraphAsync, the devices start executing. The WorkerSession calls session->graph_mgr()->ExecuteAsync, which in turn calls StartParallelExecutors; this creates an ExecutorBarrier. Each time a computing device finishes executing its assigned PartitionGraph, it reports completion to the ExecutorBarrier; once all devices have finished their PartitionGraphs, the barrier fires its done callback and the step finishes.
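
      The counting logic of ExecutorBarrier can be illustrated with a small, self-contained sketch (CountingBarrier is a hypothetical name; the real class hands each executor a StatusCallback via barrier->Get() and merges error statuses before invoking the done callback):

      #include <atomic>
      #include <functional>
      #include <iostream>

      // One completion callback per executor; when the last one fires, run done() once.
      class CountingBarrier {
       public:
        CountingBarrier(int num_executors, std::function<void()> done)
            : pending_(num_executors), done_(std::move(done)) {}

        // Returns the callback handed to one per-device executor.
        std::function<void()> Get() {
          return [this]() {
            if (pending_.fetch_sub(1) == 1) {  // this was the last executor
              done_();
              delete this;  // the real ExecutorBarrier also deletes itself after done runs
            }
          };
        }

       private:
        std::atomic<int> pending_;
        std::function<void()> done_;
      };

      int main() {
        auto* barrier =
            new CountingBarrier(3, [] { std::cout << "all partitions finished\n"; });
        for (int i = 0; i < 3; ++i) barrier->Get()();  // each executor reports completion
      }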

      Let us now analyze this flow step by step.

      2. Registering a Subgraph

      When the worker node receives a RegisterGraphRequest, it first lands in GrpcWorkerService, so the method actually invoked is "/tensorflow.WorkerService/RegisterGraph". The corresponding code is below; once the macro is expanded it is simply RegisterGraphHandler:

      #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
        void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
          auto closure = [this, call]() {                                           \
            Status s = worker_->method(&call->request, &call->response);            \
            if (!s.ok()) {                                                          \
              VLOG(3) << "Bad response from " << #method << ": " << s;              \
            }                                                                       \
            call->SendResponse(ToGrpcStatus(s));                                    \
          };                                                                        \
          if ((may_block_on_compute_pool)) {                                        \
            worker_->env()->env->SchedClosure(std::move(closure));                  \
          } else {                                                                  \
            worker_->env()->compute_pool->Schedule(std::move(closure));             \
          }                                                                         \
          ENQUEUE_REQUEST(method, false);                                           \
        }
      
      HANDLE_CALL(RegisterGraph, false);
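
      For readability, here is the macro expanded by hand for RegisterGraph (nothing new is added; method is replaced by RegisterGraph, and because the second argument is false the closure is scheduled on the compute thread pool):

      // Hand-expanded HANDLE_CALL(RegisterGraph, false), for illustration only.
      void RegisterGraphHandler(
          WorkerCall<RegisterGraphRequest, RegisterGraphResponse>* call) {
        auto closure = [this, call]() {
          Status s = worker_->RegisterGraph(&call->request, &call->response);
          if (!s.ok()) {
            VLOG(3) << "Bad response from " << "RegisterGraph" << ": " << s;
          }
          call->SendResponse(ToGrpcStatus(s));
        };
        // may_block_on_compute_pool == false, so the else branch is taken:
        worker_->env()->compute_pool->Schedule(std::move(closure));
        ENQUEUE_REQUEST(RegisterGraph, false);
      }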
      

      2.1 GrpcWorker

      RegisterGraph actually invokes a method on WorkerInterface, which internally forwards to RegisterGraphAsync.

      Status WorkerInterface::RegisterGraph(const RegisterGraphRequest* request,
                           RegisterGraphResponse* response) {
        return CallAndWait(&ME::RegisterGraphAsync, request, response);
      }
      

      RegisterGraphAsync finally reaches the Worker implementation. It first looks up the WorkerSession by session_handle and then calls GraphMgr.

      // Accessor defined in WorkerSession:
      GraphMgr* graph_mgr() const { return graph_mgr_.get(); }
      

      RegisterGraphAsync is shown below:

      void Worker::RegisterGraphAsync(const RegisterGraphRequest* request,
                                      RegisterGraphResponse* response,
                                      StatusCallback done) {
        std::shared_ptr<WorkerSession> session;
        Status s;
        if (request->create_worker_session_called()) {
          s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                         &session);
        } else {
          session = env_->session_mgr->LegacySession();
        }
        if (s.ok()) {
          s = session->graph_mgr()->Register(
              request->session_handle(), request->graph_def(), session.get(),
              request->graph_options(), request->debug_options(),
              request->config_proto(), request->collective_graph_key(),
              session->cluster_flr(), response->mutable_graph_handle());
        }
        done(s);
      }
      

      2.2 GraphMgr

      GraphMgr keeps track of a set of computation graphs registered with a TensorFlow worker. Each registered graph is identified by a handle (graph_handle) generated by GraphMgr and returned to the caller. After a successful registration, the caller uses the graph handle to execute the graph. Each execution is distinguished from the others by a globally unique, caller-generated ID, the step_id. As long as different step_ids are used, multiple executions can use the same graph concurrently and independently, and multiple threads can call GraphMgr methods concurrently.

      2.2.1 Definition

      GraphMgr is defined as follows:

      class GraphMgr {
       private:
        typedef GraphMgr ME;
      
        struct ExecutionUnit {
          std::unique_ptr<Graph> graph = nullptr;
          Device* device = nullptr;               // not owned.
          Executor* root = nullptr;               // not owned.
          FunctionLibraryRuntime* lib = nullptr;  // not owned.
          // Build the cost model if this value is strictly positive.
          int64_t build_cost_model = 0;
        };
      
        struct Item : public core::RefCounted {
          ~Item() override;
      
          // Session handle.
          string session;
      
          // Graph handle.
          string handle;
      
          std::unique_ptr<FunctionLibraryDefinition> lib_def;
          // Owns the FunctionLibraryRuntime objects needed to execute functions, one
          // per device.
          std::unique_ptr<ProcessFunctionLibraryRuntime> proc_flr;
          // A graph is partitioned over multiple devices.  Each partition
          // has a root executor which may call into the runtime library.
          std::vector<ExecutionUnit> units;
      
          // Used to deregister a cost model when cost model is required in graph
          // manager.
          GraphMgr* graph_mgr;
      
          int64_t collective_graph_key;
        };
      
        const WorkerEnv* worker_env_;  // Not owned.
        const DeviceMgr* device_mgr_;
      
        CostModelManager cost_model_manager_;
      
        // Owned.
        mutex mu_;
        int64_t next_id_ TF_GUARDED_BY(mu_) = 0;
      
        // If true, blocks until device has finished all queued operations in a step.
        bool sync_on_finish_ = true;
      
        // Table mapping graph handles to registered graphs.
        //
        // TODO(zhifengc): If the client does not call Deregister, we'll
        // lose memory over time. We should implement a timeout-based
        // mechanism to gc these graphs.
        std::unordered_map<string, Item*> table_;
      
        TF_DISALLOW_COPY_AND_ASSIGN(GraphMgr);
      };
      

      The relationship between these classes and their roles is as follows: registering a graph means inserting a new Item into GraphMgr's table_ member, and executing a graph means executing a specific Item.
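
      As a rough mental model of that bookkeeping, here is a small self-contained sketch (MiniGraphMgr and MiniItem are made-up names): a mutex-protected map from generated hexadecimal handles to items, mirroring what Register below does with next_id_ and table_:

      #include <cstdio>
      #include <iostream>
      #include <memory>
      #include <mutex>
      #include <string>
      #include <unordered_map>

      struct MiniItem {
        std::string session;  // which session registered this graph
      };

      class MiniGraphMgr {
       public:
        // Registers an item and returns the generated graph handle.
        std::string Register(const std::string& session_handle) {
          auto item = std::make_shared<MiniItem>();
          item->session = session_handle;
          std::lock_guard<std::mutex> l(mu_);
          char buf[32];
          std::snprintf(buf, sizeof(buf), "%016llx",
                        static_cast<unsigned long long>(++next_id_));
          table_[buf] = std::move(item);
          return buf;
        }
        // Looks up an item by its handle (nullptr if unknown).
        std::shared_ptr<MiniItem> Find(const std::string& handle) {
          std::lock_guard<std::mutex> l(mu_);
          auto it = table_.find(handle);
          return it == table_.end() ? nullptr : it->second;
        }

       private:
        std::mutex mu_;
        long long next_id_ = 0;
        std::unordered_map<std::string, std::shared_ptr<MiniItem>> table_;
      };

      int main() {
        MiniGraphMgr mgr;
        std::string handle = mgr.Register("session_0");
        std::cout << "graph_handle = " << handle
                  << ", found = " << (mgr.Find(handle) != nullptr) << std::endl;
      }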

      2.2.2 Registering a Graph

      The registration code is as follows; it essentially delegates to InitItem, so we will look at InitItem next.

      Status GraphMgr::Register(
          const string& handle, const GraphDef& gdef, WorkerSession* session,
          const GraphOptions& graph_options, const DebugOptions& debug_options,
          const ConfigProto& config_proto, int64_t collective_graph_key,
          DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
        Item* item = new Item;
        Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                            config_proto, collective_graph_key, cluster_flr, item);
        if (!s.ok()) {
          item->Unref();
          return s;
        }
      
        // Inserts one item into table_.
        {
          mutex_lock l(mu_);
          *graph_handle =
              strings::Printf("%016llx", static_cast<long long>(++next_id_));
          item->handle = *graph_handle;
          CHECK(table_.insert({*graph_handle, item}).second);
        }
        return Status::OK();
      }
      

      The main responsibilities of InitItem are:

      • Create the executors, given a graph definition "gdef" of a "session".

      • If a node in "gdef" is shared by other graphs in the "session", the same op kernel is reused. For example, a params node is typically shared by multiple graphs in a session.

      • If "gdef" is assigned to multiple devices, extra nodes (for example send/recv nodes) may be added. The names of the extra nodes are generated by calling "new_name(old_name)".

      • On success, "executors" is filled with one executor per device, and the caller takes ownership of the returned executors.

      // Creates executors given a graph definition "gdef" of a "session".
      // If a node in "gdef" is shared by other graphs in "session", the
      // same op kernel is reused. E.g., typically a params node is shared
      // by multiple graphs in a session.
      //
      // If "gdef" is assigned to multiple devices, extra nodes (e.g.,
      // send/recv nodes) maybe added. The extra nodes' name are generated
      // by calling "new_name(old_name)".
      //
      // "executors" are filled with one executor per device if success and
      // the caller takes the ownership of returned executors.
      Status GraphMgr::InitItem(
          const string& handle, const GraphDef& gdef, WorkerSession* session,
          const GraphOptions& graph_options, const DebugOptions& debug_options,
          const ConfigProto& config_proto, int64_t collective_graph_key,
          DistributedFunctionLibraryRuntime* cluster_flr, Item* item) {
        item->session = handle;
        item->collective_graph_key = collective_graph_key;
        item->lib_def.reset(
            new FunctionLibraryDefinition(OpRegistry::Global(), gdef.library()));
      
        TF_RETURN_IF_ERROR(ValidateGraphDefForDevices(gdef));
      
        // We don't explicitly Validate the graph def because ConvertGraphDefToGraph
        // does that below.
        item->proc_flr.reset(new ProcessFunctionLibraryRuntime(
            device_mgr_, worker_env_->env, /*config=*/&config_proto,
            gdef.versions().producer(), item->lib_def.get(),
            graph_options.optimizer_options(), worker_env_->compute_pool, cluster_flr,
            /*session_metadata=*/nullptr,
            Rendezvous::Factory{
                [this, session](const int64_t step_id, const DeviceMgr*,
                                Rendezvous** r) -> Status {
                  auto* remote_r = this->worker_env_->rendezvous_mgr->Find(step_id);
                  TF_RETURN_IF_ERROR(remote_r->Initialize(session));
                  *r = remote_r;
                  return Status::OK();
                },
                [this](const int64_t step_id) {
                  this->worker_env_->rendezvous_mgr->Cleanup(step_id);
                  return Status::OK();
                }}));
      
        // Constructs the graph out of "gdef".
        Graph graph(OpRegistry::Global());
        GraphConstructorOptions opts;
        opts.allow_internal_ops = true;
        opts.expect_device_spec = true;
        opts.validate_nodes = true;
        TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(opts, gdef, &graph));
      
        // Splits "graph" into multiple subgraphs by device names.
        std::unordered_map<string, GraphDef> partitions;
        PartitionOptions popts;
        popts.node_to_loc = SplitByDevice; // SplitByDevice is used here
        popts.new_name = [this](const string& prefix) {
          mutex_lock l(mu_);
          return strings::StrCat(prefix, "_G", next_id_++);
        };
        popts.get_incarnation = [this](const string& name) -> int64 {
          Device* device = nullptr;
          Status s = device_mgr_->LookupDevice(name, &device);
          if (s.ok()) {
            return device->attributes().incarnation();
          } else {
            return PartitionOptions::kIllegalIncarnation;
          }
        };
        popts.flib_def = item->lib_def.get();
        popts.control_flow_added = true;
        popts.scheduling_for_recvs = graph_options.enable_recv_scheduling();
        TF_RETURN_IF_ERROR(Partition(popts, &graph, &partitions));
        if (popts.scheduling_for_recvs) {
          TF_RETURN_IF_ERROR(AddControlEdges(popts, &partitions));
        }
      
        std::unordered_map<string, std::unique_ptr<Graph>> partition_graphs;
        // Convert each partition's GraphDef back into a Graph
        for (auto& partition : partitions) {
          std::unique_ptr<Graph> device_graph(new Graph(OpRegistry::Global()));
          GraphConstructorOptions device_opts;
          // There are internal operations (e.g., send/recv) that we now allow.
          device_opts.allow_internal_ops = true;
          device_opts.expect_device_spec = true;
          TF_RETURN_IF_ERROR(ConvertGraphDefToGraph(
              device_opts, std::move(partition.second), device_graph.get()));
          partition_graphs.emplace(partition.first, std::move(device_graph));
        }
      
        GraphOptimizationPassOptions optimization_options;
        optimization_options.flib_def = item->lib_def.get();
        optimization_options.partition_graphs = &partition_graphs;
        TF_RETURN_IF_ERROR(OptimizationPassRegistry::Global()->RunGrouping(
            OptimizationPassRegistry::POST_PARTITIONING, optimization_options));
      
        LocalExecutorParams params;
      
        item->units.reserve(partitions.size());
        item->graph_mgr = this;
        const auto& optimizer_opts = graph_options.optimizer_options();
        GraphOptimizer optimizer(optimizer_opts);
        for (auto& p : partition_graphs) {
          const string& device_name = p.first;
          std::unique_ptr<Graph>& subgraph = p.second;
          item->units.resize(item->units.size() + 1);
          ExecutionUnit* unit = &(item->units.back());
      
          // Find the device.
          Status s = device_mgr_->LookupDevice(device_name, &unit->device);
          if (!s.ok()) {
            // Remove the empty unit from the item as the item destructor wants all
            // units to have valid devices.
            item->units.pop_back();
            return s;
          }
      
          // Check whether the device needs to rewrite its subgraph
          // Give the device an opportunity to rewrite its subgraph.
          TF_RETURN_IF_ERROR(unit->device->MaybeRewriteGraph(&subgraph));
      
          // Top-level nodes in the graph uses the op segment to cache
          // kernels. Therefore, as long as the executor is alive, we need
          // to ensure the kernels cached for the session are alive.
          auto opseg = unit->device->op_segment();
          opseg->AddHold(handle);
      
          // Function library runtime.
          FunctionLibraryRuntime* lib = item->proc_flr->GetFLR(unit->device->name());
      
          // Build the executor for this subgraph
          // Construct the root executor for the subgraph.
          params.device = unit->device;
          params.function_library = lib;
          params.create_kernel =
              [handle, lib, opseg](const std::shared_ptr<const NodeProperties>& props,
                                   OpKernel** kernel) {
                // NOTE(mrry): We must not share function kernels (implemented
                // using `CallOp`) between subgraphs, because `CallOp::handle_`
                // is tied to a particular subgraph. Even if the function itself
                // is stateful, the `CallOp` that invokes it is not.
                if (!OpSegment::ShouldOwnKernel(lib, props->node_def.op())) {
                  return lib->CreateKernel(props, kernel);
                }
                auto create_fn = [lib, &props](OpKernel** kernel) {
                  return lib->CreateKernel(props, kernel);
                };
                // Kernels created for subgraph nodes need to be cached.  On
                // cache miss, create_fn() is invoked to create a kernel based
                // on the function library here + global op registry.
                return opseg->FindOrCreate(handle, props->node_def.name(), kernel,
                                           create_fn);
              };
          params.delete_kernel = [lib](OpKernel* kernel) {
            if (kernel && !OpSegment::ShouldOwnKernel(lib, kernel->type_string())) {
              delete kernel;
            }
          };
      
          // Optimize the subgraph
          optimizer.Optimize(lib, worker_env_->env, params.device, &subgraph,
                             GraphOptimizer::Options());
      
          TF_RETURN_IF_ERROR(
              EnsureMemoryTypes(DeviceType(unit->device->device_type()),
                                unit->device->name(), subgraph.get()));
          unit->graph = std::move(subgraph);
          unit->build_cost_model = graph_options.build_cost_model();
          if (unit->build_cost_model > 0) {
            skip_cost_models_ = false;
          }
          TF_RETURN_IF_ERROR(NewLocalExecutor(params, *unit->graph, &unit->root));
        }
        return Status::OK();
      }
      

      One point to note above is that SplitByDevice is used to partition the graph a second time, this time by device.

      // NOTE: node->device_name() is not set by GraphConstructor.  We
      // expects that NodeDef in GraphDef given to workers fully specifies
      // device names.
      static string SplitByDevice(const Node* node) {
        return node->assigned_device_name();
      }
      
      inline const std::string& Node::assigned_device_name() const {
        return graph_->get_assigned_device_name(*this);
      }
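
      As a toy illustration of what node_to_loc = SplitByDevice means for the partitioning, the snippet below simply groups made-up node names by their assigned device string; the real Partition() additionally inserts Send/Recv node pairs on every cross-device edge:

      #include <iostream>
      #include <map>
      #include <string>
      #include <utility>
      #include <vector>

      int main() {
        // (node name, assigned device) pairs as they might look after placement.
        std::vector<std::pair<std::string, std::string>> nodes = {
            {"MatMul", "/job:worker/replica:0/task:0/device:GPU:0"},
            {"Add",    "/job:worker/replica:0/task:0/device:GPU:0"},
            {"Print",  "/job:worker/replica:0/task:0/device:CPU:0"},
        };
        // Group node names by device, i.e. by what SplitByDevice returns per node.
        std::map<std::string, std::vector<std::string>> partitions;
        for (const auto& n : nodes) partitions[n.second].push_back(n.first);
        for (const auto& p : partitions) {
          std::cout << p.first << " -> " << p.second.size() << " node(s)\n";
        }
      }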
      

      The result of graph registration is roughly as follows: the information sent by the Master is used to build an Item, which is registered in GraphMgr, and an ExecutionUnit is created for each partition inside the Item. The graph_handle returned to the caller is generated from GraphMgr's internal counter next_id_ (see Register above).

      Once the subgraph is registered, it can subsequently be run.

      3. Running a Subgraph

      The Master uses RunGraphRequest to execute all the subgraphs registered under a graph_handle. The Master generates a globally unique step_id to distinguish the different steps of the graph computation. Subgraphs can use the step_id when communicating with each other (for example in send/recv operations) so that tensors produced by different runs are kept apart.

      In a RunGraphRequest message, send carries the tensors fed as inputs to the subgraph, and recv_key names the tensors to be fetched as outputs. The RunGraphResponse returns the list of Tensors corresponding to recv_key.

      3.1 Service

      The request first arrives at GrpcWorkerService; the method invoked is "/tensorflow.WorkerService/RunGraph", and the corresponding code is:

      void RunGraphHandler(WorkerCall<RunGraphRequest, RunGraphResponse>* call) {
        // Use Schedule to put the computation task into the thread pool queue
        Schedule([this, call]() {
          CallOptions* call_opts = new CallOptions;
          ProtoRunGraphRequest* wrapped_request =
              new ProtoRunGraphRequest(&call->request);
          NonOwnedProtoRunGraphResponse* wrapped_response =
              new NonOwnedProtoRunGraphResponse(&call->response);
          call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
          worker_->RunGraphAsync(call_opts, wrapped_request, wrapped_response,
                                 [call, call_opts, wrapped_request,
                                  wrapped_response](const Status& s) {
                                   call->ClearCancelCallback();
                                   delete call_opts;
                                   delete wrapped_request;
                                   delete wrapped_response;
                                   call->SendResponse(ToGrpcStatus(s));
                                 });
        });
        ENQUEUE_REQUEST(RunGraph, true);
      }
      

      Here the computation task is placed into the thread pool's queue; the actual business logic is in Worker::RunGraphAsync.

      void Schedule(std::function<void()> f) {
        worker_->env()->compute_pool->Schedule(std::move(f));
      }
      

      3.2 GrpcWorker

      RunGraphAsync has two execution paths; we pick DoRunGraph for analysis.

      void Worker::RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                                 MutableRunGraphResponseWrapper* response,
                                 StatusCallback done) {
        if (request->store_errors_in_response_body()) {
          done = [response, done](const Status& status) {
            response->set_status(status);
            done(Status::OK());
          };
        }
        if (request->is_partial()) {
          DoPartialRunGraph(opts, request, response, std::move(done)); // interested readers can explore this path
        } else {
          DoRunGraph(opts, request, response, std::move(done)); // this is the path we analyze
        }
      }
      

      DoRunGraph mainly calls session->graph_mgr()->ExecuteAsync to execute the computation graph.

      void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                              MutableRunGraphResponseWrapper* response,
                              StatusCallback done) {
        const int64_t step_id = request->step_id();
        Status s = recent_request_ids_.TrackUnique(request->request_id(),
                                                   "RunGraph (Worker)", request);
        if (!s.ok()) {
          done(s);
          return;
        }
      
        std::shared_ptr<WorkerSession> session;
        if (request->create_worker_session_called()) {
          s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                         &session);
        } else {
          session = env_->session_mgr->LegacySession();
        }
        if (!s.ok()) {
          done(s);
          return;
        }
        GraphMgr::NamedTensors in;
        GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
        s = PrepareRunGraph(request, &in, out);
        if (!s.ok()) {
          delete out;
          done(s);
          return;
        }
        StepStatsCollector* collector = nullptr;
        if (request->exec_opts().report_tensor_allocations_upon_oom() ||
            request->exec_opts().record_timeline() ||
            request->exec_opts().record_costs()) {
          collector = new StepStatsCollector(response->mutable_step_stats());
        }
        DeviceProfilerSession* device_profiler_session = nullptr;
        if (collector && request->exec_opts().record_timeline()) {
          // If timeline was requested, assume we want hardware level tracing.
          device_profiler_session = DeviceProfilerSession::Create().release();
        }
        CancellationManager* cm = new CancellationManager;
        opts->SetCancelCallback([this, cm, step_id]() {
          cm->StartCancel();
          AbortStep(step_id);
        });
        CancellationToken token;
        token = cancellation_manager_.get_cancellation_token();
        bool already_cancelled = !cancellation_manager_.RegisterCallback(
            token, [cm]() { cm->StartCancel(); });
        if (already_cancelled) {
          opts->ClearCancelCallback();
          delete cm;
          delete collector;
          delete device_profiler_session;
          delete out;
          done(errors::Aborted("Call was aborted"));
          return;
        }
        session->graph_mgr()->ExecuteAsync(
            request->graph_handle(), step_id, session.get(), request->exec_opts(),
            collector, response, cm, in,
            [this, step_id, response, session, cm, out, token, collector,
             device_profiler_session, opts, done](const Status& status) {
              Status s = status;
              if (s.ok()) {
                // Receive the output tensors
                s = session->graph_mgr()->RecvOutputs(step_id, out);
              }
      
              opts->ClearCancelCallback();
              cancellation_manager_.DeregisterCallback(token);
              delete cm;
      
              if (device_profiler_session) {
                device_profiler_session->CollectData(response->mutable_step_stats())
                    .IgnoreError();
              }
      
              if (s.ok()) {
                for (const auto& p : *out) {
                  const string& key = p.first;
                  const Tensor& val = p.second;
                  response->AddRecv(key, val);
                }
              }
      
              if (collector) collector->Finalize();
              delete collector;
              delete device_profiler_session;
              delete out;
              done(s);
            });
      }
      

      3.3 GraphMgr

      ExecuteAsync calls StartParallelExecutors to carry out the parallel computation. The logic is roughly:

      • look up the subgraph (Item) by its handle;
      • collect the subgraph's cost information;
      • create a rendezvous for this step and initialize it with the current session; all subsequent communication goes through this rendezvous, which uses the session to communicate;
      • send the input tensors to the rendezvous;
      • call StartParallelExecutors to execute the sub-computation graph.

      void GraphMgr::ExecuteAsync(const string& handle, const int64_t step_id,
                                  WorkerSession* session, const ExecutorOpts& opts,
                                  StepStatsCollector* collector,
                                  MutableRunGraphResponseWrapper* response,
                                  CancellationManager* cancellation_manager,
                                  const NamedTensors& in, StatusCallback done) {
        const uint64 start_time_usecs = Env::Default()->NowMicros();
        profiler::TraceMeProducer activity(
            // To TraceMeConsumers in ExecutorState::Process/Finish or RunGraphDone.
            [step_id] {
              return profiler::TraceMeEncode(
                  "RunGraph", {{"id", step_id}, {"_r", 1} /*root_event*/});
            },
            profiler::ContextType::kTfExecutor, step_id,
            profiler::TraceMeLevel::kInfo);
        
        // Lookup an item. Holds one ref while executing.
        // Find the subgraph (Item)
        Item* item = nullptr;
        {
          mutex_lock l(mu_);
          auto iter = table_.find(handle);
          if (iter != table_.end()) {
            item = iter->second;
            item->Ref();
          }
        }
       
        // Collect the cost graph if requested
        CostGraphDef* cost_graph = nullptr;
        if (response != nullptr) {
          cost_graph = response->mutable_cost_graph();
          if (opts.record_partition_graphs()) {
            for (const ExecutionUnit& unit : item->units) {
              GraphDef graph_def;
              unit.graph->ToGraphDef(&graph_def);
              response->AddPartitionGraph(graph_def);
            }
          }
        }
      
        // Create a rendezvous for this step
        RemoteRendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
        // Initialize the rendezvous with this session; subsequent communication goes through this rendezvous, which uses the session to communicate
        Status s = rendezvous->Initialize(session); 
        CollectiveExecutor::Handle* ce_handle =
            item->collective_graph_key != BuildGraphOptions::kNoCollectiveGraphKey
                ? new CollectiveExecutor::Handle(
                      worker_env_->collective_executor_mgr->FindOrCreate(step_id),
                      true)
                : nullptr;
        // Sends values specified by the caller.
        // Send the input tensors to the rendezvous
        size_t input_size = 0;
        if (s.ok()) {
          std::vector<string> keys;
          std::vector<Tensor> tensors_to_send;
          keys.reserve(in.size());
          tensors_to_send.reserve(in.size());
          for (auto& p : in) {
            keys.push_back(p.first);
            tensors_to_send.push_back(p.second);
            input_size += p.second.AllocatedBytes();
          }
          // send the tensors
          s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
        }
      
        if (!s.ok()) {
          done(s);
          delete ce_handle;
          item->Unref();
          rendezvous->Unref();
          return;
        }
      
        // Execute the sub-computation graph
        StartParallelExecutors(
            handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
            cancellation_manager, session, start_time_usecs,
            [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
             step_id](const Status& s) {
              profiler::TraceMeConsumer activity(
                  // From TraceMeProducer in GraphMgr::ExecuteAsync.
                  [step_id] {
                    return profiler::TraceMeEncode("RunGraphDone", {{"id", step_id}});
                  },
                  profiler::ContextType::kTfExecutor, step_id,
                  profiler::TraceMeLevel::kInfo);
              done(s);
              metrics::RecordGraphInputTensors(input_size);
              metrics::UpdateGraphExecTime(Env::Default()->NowMicros() -
                                           start_time_usecs);
              rendezvous->Unref();
              item->Unref();
              delete ce_handle;
            });
      }
      

      In short, ExecuteAsync uses the handle to look up the Item and thus the computation graph. The session is used for communication and execution, and the step_id is tied to the communication, as can be seen in the code above.
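
      As a mental model for how SendTensorsToRendezvous and RecvOutputs pair up producers and consumers by key, here is a minimal self-contained sketch (MiniRendezvous is a made-up class storing plain floats; the real RemoteRendezvous is per-step, keyed by strings that also encode source and destination devices, and works across processes):

      #include <condition_variable>
      #include <iostream>
      #include <mutex>
      #include <string>
      #include <thread>
      #include <unordered_map>

      // Producers Send() a value under a key; Recv() blocks until that key arrives.
      class MiniRendezvous {
       public:
        void Send(const std::string& key, float value) {
          std::lock_guard<std::mutex> l(mu_);
          table_[key] = value;
          cv_.notify_all();
        }
        float Recv(const std::string& key) {
          std::unique_lock<std::mutex> l(mu_);
          cv_.wait(l, [&] { return table_.count(key) > 0; });
          return table_[key];
        }

       private:
        std::mutex mu_;
        std::condition_variable cv_;
        std::unordered_map<std::string, float> table_;  // key -> "tensor"
      };

      int main() {
        MiniRendezvous rendezvous;
        std::thread producer([&] { rendezvous.Send("step1/x", 3.14f); });
        std::cout << "received " << rendezvous.Recv("step1/x") << std::endl;
        producer.join();
      }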

      StartParallelExecutors creates an ExecutorBarrier. Each time a computing device finishes executing its assigned PartitionGraph, it reports completion to the ExecutorBarrier; once all devices have finished their PartitionGraphs, the barrier invokes its done callback and the step is finished.

      void GraphMgr::StartParallelExecutors(
          const string& handle, int64_t step_id, Item* item, Rendezvous* rendezvous,
          CollectiveExecutor::Handle* ce_handle, StepStatsCollector* collector,
          CostGraphDef* cost_graph, CancellationManager* cancellation_manager,
          WorkerSession* session, int64_t start_time_usecs, StatusCallback done) {
        const int num_units = item->units.size();
        ScopedStepContainer* step_container = new ScopedStepContainer(
            step_id,
            [this](const string& name) { device_mgr_->ClearContainers({name}); });
      
        ExecutorBarrier* barrier =
            new ExecutorBarrier(num_units, rendezvous,
                                [this, item, collector, cost_graph, step_container,
                                 done](const Status& s) {
                                  BuildCostModel(item, collector, cost_graph);
                                  done(s);
                                  delete step_container;
                                });
        Executor::Args args;
        args.step_id = step_id;
        args.rendezvous = rendezvous;
        args.collective_executor = ce_handle ? ce_handle->get() : nullptr;
        args.cancellation_manager = cancellation_manager;
        args.stats_collector = collector;
        args.step_container = step_container;
        args.sync_on_finish = sync_on_finish_;
        args.start_time_usecs = start_time_usecs;
        if (LogMemory::IsEnabled()) {
          LogMemory::RecordStep(args.step_id, handle);
        }
        thread::ThreadPool* pool = worker_env_->compute_pool;
        using std::placeholders::_1;
        // Line below is equivalent to this code, but does one less indirect call:
        //  args.runner = [pool](std::function<void()> fn) { pool->Schedule(fn); };
        auto default_runner = std::bind(&thread::ThreadPool::Schedule, pool, _1);
        for (const auto& unit : item->units) {
          thread::ThreadPool* device_thread_pool =
              unit.device->tensorflow_device_thread_pool();
          if (!device_thread_pool) {
            args.runner = default_runner;
          } else {
            args.runner =
                std::bind(&thread::ThreadPool::Schedule, device_thread_pool, _1);
          }
          unit.root->RunAsync(args, barrier->Get());
        }
      }
      

      3.4 Summary

      We summarize subgraph registration and execution with a figure.

      Figure 1: Registering and running subgraphs

      4. Conclusion

      We conclude by summarizing the whole distributed computation flow in a figure:

      Figure 2: The distributed computation flow

