
      [Source Code Analysis] TensorFlow Distributed Environment (3) --- Worker Static Logic


      Before we introduce the various distributed Strategies in TensorFlow, we first need to look at the foundation of distribution: the distributed runtime environment. Only with a solid grasp of this foundation can we clear away obstacles in later analysis and get twice the result with half the effort. This article introduces the static architecture of the Worker (and a series of related concepts).

      Other articles in this series:

      [Translation] TensorFlow Distributed, Paper Series: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

      [Translation] TensorFlow Distributed, Paper Series: "Implementation of Control Flow in TensorFlow"

      [Source Code Analysis] TensorFlow Distributed Environment (1) --- Overall Architecture

      [Source Code Analysis] TensorFlow Distributed Environment (2) --- Master Static Logic

      1. Inheritance Hierarchy

      1.1 Roles

      The TensorFlow Worker class is the entity that actually performs computation. Its main responsibilities are to:

      • Receive requests from the Master.
      • Manage WorkerSessions.
      • Process registered subgraphs, for example by further partitioning a subgraph according to the devices available on its own node.
      • Run the registered subgraphs on each device.
      • Support worker-to-worker tensor transfer, etc. How a transfer is carried out depends on where the two endpoints sit: cudaMemcpyAsync between CPU and GPU, DMA between local GPUs, and gRPC or RDMA between remote workers.
      • After execution finishes, collect the results from sink, the terminal node of the computation graph.

      See protobuf/worker_service.proto for more details on each method.
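
      The Worker class itself is not listed in this article, so to make the responsibilities above concrete, here is a heavily simplified sketch of what its declaration roughly looks like (an approximation of tensorflow/core/distributed_runtime/worker.h, not verbatim source):

      // Simplified sketch (assumption): Worker implements the WorkerInterface async
      // methods on top of a WorkerEnv that provides devices, the session manager,
      // thread pools, and so on.
      class Worker : public WorkerInterface {
       public:
        explicit Worker(WorkerEnv* env);

        void GetStatusAsync(CallOptions* opts, const GetStatusRequest* request,
                            GetStatusResponse* response, bool fail_fast,
                            StatusCallback done) override;

        void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) override;

        void RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) override;

        void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) override;

        void RecvTensorAsync(CallOptions* opts, const RecvTensorRequest* request,
                             TensorResponse* response, StatusCallback done) override;

        // ... the remaining *Async methods follow the same pattern ...

       protected:
        WorkerEnv* const env_;  // Not owned; shared per-process worker environment.
      };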

      1.2 Interface

      Access to the WorkerService goes through WorkerInterface. WorkerInterface is the worker's interface class, i.e. the interface for talking to the TensorFlow Worker service. It mainly:

      • Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethods supported by GrpcWorkerService, and also one-to-one with the protobuf definitions.
      • Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through an adapter call such as CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

      1.3 WorkerInterface Derived Classes

      As shown in the figure below, WorkerInterface has three implementations.

      • Worker : this class can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes RecvTensorAsync() to support more efficient gRPC data structures for handling large binary payloads.
      • GrpcWorker : derived from Worker, it plays the Worker role in local mode. If the Master and the Worker are both local, it can be called directly, without RPC network transport.
      • GrpcRemoteWorker : in distributed mode the Worker lives on a remote host, and the local side uses GrpcRemoteWorker to access it.
        • GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
        • GrpcWorkerService implements every interface defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

      The relationships are illustrated below:

      Figure 1: Worker logical relationships

      2. GrpcRemoteWorker

      GrpcRemoteWorker acts as a local proxy for a remote Worker.

      • The local Master partitions the computation graph and then, depending on whether a partition is local or remote, calls either the local Worker or a GrpcRemoteWorker to execute the partitioned subgraph.
      • The local GrpcRemoteWorker is created in GetOrCreateWorker in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
      • GrpcRemoteWorker sends gRPC requests to the remote side through IssueRequest.
      • When the remote GrpcWorkerService daemon receives a request, it calls the local Worker to handle it and returns the result once processing finishes.
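
      Putting these pieces together, here is a hedged usage sketch (illustrative only, not verbatim TensorFlow code) of how Master-side code obtains a worker through the worker cache and issues a call. Whether the returned object is the local worker or a GrpcRemoteWorker is transparent to the caller; the target name and surrounding variables are hypothetical.

      // Hedged sketch: the caller only ever sees WorkerInterface.
      const string name = "/job:worker/replica:0/task:1";           // hypothetical target
      WorkerInterface* wi = worker_cache->GetOrCreateWorker(name);  // local worker or GrpcRemoteWorker
      if (wi != nullptr) {
        RegisterGraphRequest req;
        RegisterGraphResponse resp;
        // Synchronous wrapper; internally forwarded to RegisterGraphAsync.
        Status s = wi->RegisterGraph(&req, &resp);
        // On success, resp.graph_handle() identifies the registered subgraph on that worker.

        // WorkerInterface instances must be released through the cache, never deleted directly.
        worker_cache->ReleaseWorker(name, wi);
      }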

      2.1 Definition

      The GrpcRemoteWorker code is shown below; we omit some of it, such as the implementation of the DeleteWorkerSessionAsync method.

      class GrpcRemoteWorker : public WorkerInterface {
       public:
        explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                  ::grpc::CompletionQueue* completion_queue,
                                  thread::ThreadPool* callback_threadpool,
                                  WorkerCacheLogger* logger, const string& target)
            : channel_(std::move(channel)),
              stub_(channel_),
              cq_(completion_queue),
              callback_threadpool_(callback_threadpool),
              getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
              createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
              deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
              registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
              deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
              rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
              cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
              cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
              recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
              recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
              logging_(Method(GrpcWorkerMethod::kLogging)),
              tracing_(Method(GrpcWorkerMethod::kTracing)),
              completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
              instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
              getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
              markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
              logger_(logger),
              target_(target) {}
      
        ~GrpcRemoteWorker() override {}
      
        void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) override {
          IssueRequest(request, response, createworkersession_, std::move(done));
        }
      
        void RegisterGraphAsync(const RegisterGraphRequest* request,
                                RegisterGraphResponse* response,
                                StatusCallback done) override {
          IssueRequest(request, response, registergraph_, std::move(done));
        }
      
        void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
                           RunGraphResponse* response, StatusCallback done) override {
          IssueRequest(request, response, rungraph_, std::move(done), call_opts);
        }
        void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
                           MutableRunGraphResponseWrapper* response,
                           StatusCallback done) override {
          IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
                       rungraph_, std::move(done), call_opts);
        }
      
       private:
        // Utility method for issuing a generic asynchronous request. The
        // given callback, done, will be called when the RPC completes.
        void IssueRequest(const protobuf::Message* request,
                          protobuf::Message* response, const ::grpc::string& method,
                          StatusCallback done, CallOptions* call_opts = nullptr,
                          bool fail_fast = true) {
          new RPCState<protobuf::Message>(
              &stub_, cq_, method, *request, response, std::move(done), call_opts,
              callback_threadpool_, MaxRetries(), fail_fast, &target_);
        }
      
        void IssueRequest(const protobuf::Message* request, TensorResponse* response,
                          const ::grpc::string& method, StatusCallback done,
                          CallOptions* call_opts = nullptr) {
          new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
                                       std::move(done), call_opts,
                                       callback_threadpool_, MaxRetries(),
                                       /*fail_fast=*/true, &target_);
        }
      
        // Helper function for initializing the RpcMethod objects below.
        const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
      
        // Helper function for configuring max GRPC retries. Defaults to 0 (no
        // retries).
        const int64_t MaxRetries() {
          int64_t max_retries = -1;
          TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
          return max_retries;
        }
      
        SharedGrpcChannelPtr channel_;
        ::grpc::GenericStub stub_;
        ::grpc::CompletionQueue* cq_;
        thread::ThreadPool* callback_threadpool_;
      
        const ::grpc::string getstatus_;
        const ::grpc::string createworkersession_;
        const ::grpc::string deleteworkersession_;
        const ::grpc::string registergraph_;
        const ::grpc::string deregistergraph_;
        const ::grpc::string rungraph_;
        const ::grpc::string cleanupgraph_;
        const ::grpc::string cleanupall_;
        const ::grpc::string recvtensor_;
        const ::grpc::string recvbuf_;
        const ::grpc::string logging_;
        const ::grpc::string tracing_;
        const ::grpc::string completegroup_;
        const ::grpc::string instancesource_;
        const ::grpc::string getstepsequence_;
        const ::grpc::string markrecvfinished_;
      
        // Support for logging.
        WorkerCacheLogger* logger_;
        const string target_;
      
        TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
      };
      

      2.2 Construction

      The factory function is as follows:

      WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                           ::grpc::CompletionQueue* completion_queue,
                                           thread::ThreadPool* callback_threadpool,
                                           WorkerCacheLogger* logger,
                                           const string& target) {
        return new GrpcRemoteWorker(std::move(channel), completion_queue,
                                    callback_threadpool, logger, target);
      }
      

      The actual call site is in the worker cache, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides which kind of worker to produce based on the target:

      WorkerInterface* GetOrCreateWorker(const string& target) override {
        if (target == local_target_) {
          return local_worker_;
        } else {
          SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
          if (!channel) {
            return nullptr;
          }
          size_t index = AssignWorkerToThread(target);
          return NewGrpcRemoteWorker(
              channel, worker_env_->GetCompletionQueue(index),
              worker_env_->GetThreadPool(), &logger_, target);
        }
      }
      

      2.3 Sending Requests

      Next, let us look at how requests are sent. CreateWorkerSessionAsync actually sends the request named by the createworkersession_ string.

        void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
                                      CreateWorkerSessionResponse* response,
                                      StatusCallback done) override {
          IssueRequest(request, response, createworkersession_, std::move(done));
        }
      

      IssueRequest was already shown in the definition above; we list it again here. You can see that it invokes the remote method named by the method argument, which in our case is createworkersession_.

      void IssueRequest(const protobuf::Message* request,
                        protobuf::Message* response, const ::grpc::string& method,
                        StatusCallback done, CallOptions* call_opts = nullptr,
                        bool fail_fast = true) {
        new RPCState<protobuf::Message>(
            &stub_, cq_, method, *request, response, std::move(done), call_opts,
            callback_threadpool_, MaxRetries(), fail_fast, &target_);
      }
      

      createworkersession_ is configured in the constructor.

      explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
                                ::grpc::CompletionQueue* completion_queue,
                                thread::ThreadPool* callback_threadpool,
                                WorkerCacheLogger* logger, const string& target)
          : channel_(std::move(channel)),
            createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), // 配置
      

      GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It maps each method id to its concrete string, i.e. the method name on the remote GrpcWorker. As you can see, CreateWorkerSessionAsync actually invokes "/tensorflow.WorkerService/CreateWorkerSession".

      // Names of worker methods.
      enum class GrpcWorkerMethod {
        kGetStatus,
        kCreateWorkerSession,
        kDeleteWorkerSession,
        kRegisterGraph,
        kDeregisterGraph,
        kRunGraph,
        kCleanupGraph,
        kCleanupAll,
        kRecvTensor,
        kRecvBuf,
        kLogging,
        kTracing,
        kCompleteGroup,
        kCompleteInstance,
        kGetStepSequence,
        kMarkRecvFinished,
      };
      
      const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
        switch (id) {
          case GrpcWorkerMethod::kGetStatus:
            return "/tensorflow.WorkerService/GetStatus";
          case GrpcWorkerMethod::kCreateWorkerSession:
            return "/tensorflow.WorkerService/CreateWorkerSession";
          case GrpcWorkerMethod::kDeleteWorkerSession:
            return "/tensorflow.WorkerService/DeleteWorkerSession";
          case GrpcWorkerMethod::kRegisterGraph:
            return "/tensorflow.WorkerService/RegisterGraph";
          case GrpcWorkerMethod::kDeregisterGraph:
            return "/tensorflow.WorkerService/DeregisterGraph";
          case GrpcWorkerMethod::kRunGraph:
            return "/tensorflow.WorkerService/RunGraph";
          case GrpcWorkerMethod::kCleanupGraph:
            return "/tensorflow.WorkerService/CleanupGraph";
          case GrpcWorkerMethod::kCleanupAll:
            return "/tensorflow.WorkerService/CleanupAll";
          case GrpcWorkerMethod::kRecvTensor:
            return "/tensorflow.WorkerService/RecvTensor";
          case GrpcWorkerMethod::kRecvBuf:
            return "/tensorflow.WorkerService/RecvBuf";
          case GrpcWorkerMethod::kLogging:
            return "/tensorflow.WorkerService/Logging";
          case GrpcWorkerMethod::kTracing:
            return "/tensorflow.WorkerService/Tracing";
          case GrpcWorkerMethod::kCompleteGroup:
            return "/tensorflow.WorkerService/CompleteGroup";
          case GrpcWorkerMethod::kCompleteInstance:
            return "/tensorflow.WorkerService/CompleteInstance";
          case GrpcWorkerMethod::kGetStepSequence:
            return "/tensorflow.WorkerService/GetStepSequence";
          case GrpcWorkerMethod::kMarkRecvFinished:
            return "/tensorflow.WorkerService/MarkRecvFinished";
        }
        // Shouldn't be reached.
        LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
        return "invalid id";
      }
      

      3. Worker Service

      WorkerService is a gRPC service that defines a TensorFlow service. A WorkerService executes dataflow graphs on a set of local devices on behalf of the MasterService. A WorkerService keeps track of multiple "registered graphs". Each registered graph is a subgraph of the client's computation graph, containing only the nodes that should run on this worker (plus any additional nodes needed for inter-process communication via the RecvTensor method).

      Based on the ClusterSpec, the Master locates the other Server instances in the cluster and uses them in the Worker role. The Master then distributes subgraphs to these Worker nodes and orchestrates them to carry out the actual subgraph computation. If there are data dependencies between Workers, they interact through inter-process communication. Whether it is the Master calling a Worker or Workers calling each other, the interaction must follow the interface contract defined by WorkerService. All WorkerService interfaces are defined in the worker_service.proto file.

      service WorkerService {
        // See worker.proto for details.
        rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
      
        // See worker.proto for details.
        rpc CreateWorkerSession(CreateWorkerSessionRequest)
            returns (CreateWorkerSessionResponse);
      
        // See worker.proto for details.
        rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
            returns (DeleteWorkerSessionResponse);
      
        // See worker.proto for details.
        rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
      
        // See worker.proto for details.
        rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
      
        // See worker.proto for details.
        rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
      
        // See worker.proto for details.
        rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
      
        // See worker.proto for details.
        rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
      
        // See worker.proto for details.
        rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
          // RecvTensor Method
        }
      
        // See worker.proto for details.
        rpc Logging(LoggingRequest) returns (LoggingResponse);
      
        // See worker.proto for details.
        rpc Tracing(TracingRequest) returns (TracingResponse);
      
        // See worker.proto for details.
        rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}
      
        // See worker.proto for details.
        rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
      
        // See worker.proto for details.
        rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);
      
        // See worker.proto for details.
        rpc CompleteInstance(CompleteInstanceRequest)
            returns (CompleteInstanceResponse);
      }
      

      3.1 WorkerInterface

      As with MasterService, access to the WorkerService goes through WorkerInterface. WorkerInterface is the worker's interface class, i.e. the interface for talking to the TensorFlow Worker service. It mainly:

      • Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethods supported by GrpcWorkerService, and also one-to-one with the protobuf definitions.
      • Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through an adapter call such as CallAndWait(&ME::CreateWorkerSessionAsync, request, response).

      We first list its asynchronous interface.

      // Interface for talking with the TensorFlow Worker service.
      class WorkerInterface {
       public:
        virtual void GetStatusAsync(CallOptions* opts,
                                    const GetStatusRequest* request,
                                    GetStatusResponse* response, bool fail_fast,
                                    StatusCallback done) = 0;
      
        virtual void CreateWorkerSessionAsync(
            const CreateWorkerSessionRequest* request,
            CreateWorkerSessionResponse* response, StatusCallback done) = 0;
      
        virtual void DeleteWorkerSessionAsync(
            CallOptions* opts, const DeleteWorkerSessionRequest* request,
            DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
      
        virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
                                        RegisterGraphResponse* response,
                                        StatusCallback done) = 0;
      
        virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
                                          DeregisterGraphResponse* response,
                                          StatusCallback done) = 0;
      
        virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
                                   MutableRunGraphResponseWrapper* response,
                                   StatusCallback done) = 0;
      
        virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
                                   RunGraphResponse* response, StatusCallback done) {
          RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
          MutableRunGraphResponseWrapper* wrapped_response =
              new NonOwnedProtoRunGraphResponse(response);
          RunGraphAsync(opts, wrapped_request, wrapped_response,
                        [wrapped_request, wrapped_response,
                         done = std::move(done)](const Status& s) {
                          done(s);
                          delete wrapped_request;
                          delete wrapped_response;
                        });
        }
      
        virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
                                       CleanupGraphResponse* response,
                                       StatusCallback done) = 0;
      
        virtual void CleanupAllAsync(const CleanupAllRequest* request,
                                     CleanupAllResponse* response,
                                     StatusCallback done) = 0;
      
        virtual void RecvTensorAsync(CallOptions* opts,
                                     const RecvTensorRequest* request,
                                     TensorResponse* response,
                                     StatusCallback done) = 0;
      
        virtual void LoggingAsync(const LoggingRequest* request,
                                  LoggingResponse* response, StatusCallback done) = 0;
      
        virtual void TracingAsync(const TracingRequest* request,
                                  TracingResponse* response, StatusCallback done) = 0;
      
        virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                                  RecvBufResponse* response, StatusCallback done) = 0;
      
        virtual void CompleteGroupAsync(CallOptions* opts,
                                        const CompleteGroupRequest* request,
                                        CompleteGroupResponse* response,
                                        StatusCallback done) = 0;
      
        virtual void CompleteInstanceAsync(CallOptions* ops,
                                           const CompleteInstanceRequest* request,
                                           CompleteInstanceResponse* response,
                                           StatusCallback done) = 0;
      
        virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
                                          GetStepSequenceResponse* response,
                                          StatusCallback done) = 0;
      }
      

      WorkerInterface also provides a synchronous interface, so that a Master or a Worker can call methods of a remote WorkerService as if they were local functions. The synchronous interface is implemented on top of the asynchronous one, using the CallAndWait adapter to wrap the asynchronous calls. In addition, to prevent external code from deleting WorkerInterface instances illegally, a few restrictions are in place: the destructor is protected, WorkerCacheInterface is declared a friend, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. Below are the synchronous interface, a few helper functions, and the member variables.

      // Interface for talking with the TensorFlow Worker service.
      class WorkerInterface {
       public:
      
        virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
          return new MutableProtoRunGraphRequest;
        }
      
        virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
          return new OwnedProtoRunGraphResponse;
        }
      
        Status GetStatus(const GetStatusRequest* request,
                         GetStatusResponse* response) {
          Status ret;
          Notification n;
          GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
                         [&ret, &n](const Status& s) {
                           ret = s;
                           n.Notify();
                         });
          n.WaitForNotification();
          return ret;
        }
      
        Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
                                   CreateWorkerSessionResponse* response) {
          return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
        }
      
        Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
                                   DeleteWorkerSessionResponse* response) {
          return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
                                        response);
        }
      
        Status RegisterGraph(const RegisterGraphRequest* request,
                             RegisterGraphResponse* response) {
          return CallAndWait(&ME::RegisterGraphAsync, request, response);
        }
      
        Status DeregisterGraph(const DeregisterGraphRequest* request,
                               DeregisterGraphResponse* response) {
          return CallAndWait(&ME::DeregisterGraphAsync, request, response);
        }
      
        Status CleanupGraph(const CleanupGraphRequest* request,
                            CleanupGraphResponse* response) {
          return CallAndWait(&ME::CleanupGraphAsync, request, response);
        }
      
        Status CleanupAll(const CleanupAllRequest* request,
                          CleanupAllResponse* response) {
          return CallAndWait(&ME::CleanupAllAsync, request, response);
        }
      
        Status Logging(const LoggingRequest* request, LoggingResponse* response) {
          return CallAndWait(&ME::LoggingAsync, request, response);
        }
      
        Status Tracing(const TracingRequest* request, TracingResponse* response) {
          return CallAndWait(&ME::TracingAsync, request, response);
        }
      
        Status GetStepSequence(const GetStepSequenceRequest* request,
                               GetStepSequenceResponse* response) {
          return CallAndWait(&ME::GetStepSequenceAsync, request, response);
        }
      
       protected:
        // Instances of WorkerInterface must be deleted by a call to
        // WorkerCacheInterface::ReleaseWorker().
        virtual ~WorkerInterface() {}
        friend class WorkerCacheInterface;
      
        // NOTE: This should only be called by implementations of this
        // interface whose CreateRunGraphResponse() method returns a
        // proto-based wrappers for the RunGraphResponse message.
        RunGraphResponse* get_proto_from_wrapper(
            MutableRunGraphResponseWrapper* wrapper) {
          return wrapper->get_proto();
        }
      
       private:
        typedef WorkerInterface ME;
      
        template <typename Method, typename Req, typename Resp>
        Status CallAndWait(Method func, const Req* req, Resp* resp) {
          Status ret;
          Notification n;
          (this->*func)(req, resp, [&ret, &n](const Status& s) {
            ret = s;
            n.Notify();
          });
          n.WaitForNotification();
          return ret;
        }
      
        template <typename Method, typename Req, typename Resp>
        Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
          CallOptions call_opts;
          Status ret;
          Notification n;
          (this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
            ret = s;
            n.Notify();
          });
          n.WaitForNotification();
          return ret;
        }
      };
      
      

      3.2 Sorting Out the Concepts

      The WorkerService interface involves quite a few concepts, which we need to sort out carefully.

      As mentioned earlier, the Client and the Master cooperate through the session_handle / MasterSession pair, while the Master and the Workers cooperate through MasterSession and WorkerSession; a MasterSession centrally manages all the WorkerSessions that belong to it. We need to clarify the relationships among a few identifiers:

      • session_handle : its purpose is to let a MasterSession uniformly manage the WorkerSessions beneath it. It corresponds one-to-one with a MasterSession and is generated when the MasterSession is created. It is returned to the Client via CreateSessionResponse and sent to the Workers via CreateWorkerSessionRequest, so the whole chain from Client to Master to Worker is uniquely identified by this session_handle.
      • graph_handle : generated by GraphMgr::Register when a subgraph is registered, and returned to the Master via RegisterGraphResponse. The subgraph is identified by this graph_handle; within the cluster, the pair (session_handle, graph_handle) uniquely identifies a particular subgraph.
      • step_id : because the Master lets multiple Workers run the computation concurrently, it broadcasts RunGraph to all of them. To distinguish different steps, the Master generates a globally unique step_id for every RunStep and carries it to the Workers in the RunGraphRequest message.
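
      To make the relationship concrete, here is a hedged sketch (based only on the RunGraphRequest fields and the Part structure shown later in this section; call_opts and done are assumed to exist) of how the three identifiers travel together on a single RunGraph call:

      // Hedged sketch, not verbatim TensorFlow code: issuing one step against one registered subgraph.
      RunGraphRequest req;
      req.set_session_handle(session_handle_);   // identifies the MasterSession / WorkerSession chain
      req.set_graph_handle(part.graph_handle);   // identifies the subgraph registered on this worker
      req.set_step_id(step_id);                  // distinguishes this run from other concurrent runs
      RunGraphResponse resp;
      part.worker->RunGraphAsync(&call_opts, &req, &resp, done);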

      Let us trace graph_handle. GraphMgr::Register generates the graph_handle.

      Status GraphMgr::Register(
          const string& handle, const GraphDef& gdef, WorkerSession* session,
          const GraphOptions& graph_options, const DebugOptions& debug_options,
          const ConfigProto& config_proto, int64_t collective_graph_key,
          DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
        Item* item = new Item;
        Status s = InitItem(handle, gdef, session, graph_options, debug_options,
                            config_proto, collective_graph_key, cluster_flr, item);
        // Inserts one item into table_.
        {
          mutex_lock l(mu_);
          *graph_handle =
              strings::Printf("%016llx", static_cast<long long>(++next_id_));
          item->handle = *graph_handle;
          CHECK(table_.insert({*graph_handle, item}).second);
        }
        return Status::OK();
      }
      
      

      RegisterGraphResponse returns the graph_handle to the Master.

      message RegisterGraphResponse {
        // If the registration succeeds, returns an opaque graph_handle to
        // the master. The master calls RunGraph with graph_handle to
        // compute different steps.
        string graph_handle = 1;
      }
      
      

      The per-location partition structure carries the graph_handle.

      // Graph partitioned into per-location subgraphs.
      struct Part {
        // Worker name.
        string name;
      
        // Maps feed names to rendezvous keys. Empty most of the time.
        std::unordered_map<string, string> feed_key;
      
        // Maps rendezvous keys to fetch names. Empty most of the time.
        std::unordered_map<string, string> key_fetch;
      
        // The interface to the worker. Owned.
        WorkerInterface* worker = nullptr;
      
        // After registration with the worker, graph_handle identifies
        // this partition on the worker.
        string graph_handle;
      
        Part() : feed_key(3), key_fetch(3) {}
      };
      
      

      When registration returns, the graph_handle is recorded on the corresponding partition.

      Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
          const PartitionOptions& popts,
          std::unordered_map<string, GraphDef> graph_partitions) {
        partitions_.reserve(graph_partitions.size());
        Status s;
        for (auto& name_def : graph_partitions) {
          partitions_.emplace_back();
          Part* part = &partitions_.back();
          part->name = name_def.first;
          TrackFeedsAndFetches(part, name_def.second, popts);
          part->worker = worker_cache_->GetOrCreateWorker(part->name);
          if (part->worker == nullptr) {
            s = errors::NotFound("worker ", part->name);
            break;
          }
        }
        if (!s.ok()) {
          for (Part& part : partitions_) {
            worker_cache_->ReleaseWorker(part.name, part.worker);
            part.worker = nullptr;
          }
          return s;
        }
        struct Call {
          RegisterGraphRequest req;
          RegisterGraphResponse resp;
          Status status;
        };
        const int num = partitions_.size();
        gtl::InlinedVector<Call, 4> calls(num);
        BlockingCounter done(num);
        for (int i = 0; i < num; ++i) {
          const Part& part = partitions_[i];
          Call* c = &calls[i];
          c->req.set_session_handle(session_handle_);
          c->req.set_create_worker_session_called(!should_deregister_);
          c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
          StripDefaultAttributes(*OpRegistry::Global(),
                                 c->req.mutable_graph_def()->mutable_node());
          *c->req.mutable_config_proto() = session_opts_.config;
          *c->req.mutable_graph_options() = session_opts_.config.graph_options();
          *c->req.mutable_debug_options() =
              callable_opts_.run_options().debug_options();
          c->req.set_collective_graph_key(collective_graph_key_);
      
          auto cb = [c, &done](const Status& s) {
            c->status = s;
            done.DecrementCount();
          };
          part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
        }
        done.Wait();
        for (int i = 0; i < num; ++i) {
          Call* c = &calls[i];
          s.Update(c->status);
          partitions_[i].graph_handle = c->resp.graph_handle();
        }
        return s;
      }
      
      

      When the subgraph is used later, the graph_handle uniquely identifies it.

      // Asynchronously deregisters subgraphs on the workers, without waiting for the
      // result.
      void MasterSession::ReffedClientGraph::DeregisterPartitions() {
        struct Call {
          DeregisterGraphRequest req;
          DeregisterGraphResponse resp;
        };
        for (Part& part : partitions_) {
          // The graph handle may be empty if we failed during partition registration.
          if (!part.graph_handle.empty()) {
            Call* c = new Call;
            c->req.set_session_handle(session_handle_);
            c->req.set_create_worker_session_called(!should_deregister_);
            c->req.set_graph_handle(part.graph_handle);
            // NOTE(mrry): We must capture worker_cache_ since this
            // could be deleted before the callback is called.
            WorkerCacheInterface* worker_cache = worker_cache_;
            const string name = part.name;
            WorkerInterface* w = part.worker;
            CHECK_NOTNULL(w);
            auto cb = [worker_cache, c, name, w](const Status& s) {
               delete c;
              worker_cache->ReleaseWorker(name, w);
            };
            w->DeregisterGraphAsync(&c->req, &c->resp, cb);
          }
        }
      }
      
      

      3.3 WorkerInterface Derived Classes

      As shown in the figure below, WorkerInterface has two gRPC-facing implementations.

      • GrpcWorker : the Worker role in local mode. If the Master and the Worker are both local, it can be called directly, without RPC network transport.
      • GrpcRemoteWorker : in distributed mode the Worker lives on a remote host, and the local side uses GrpcRemoteWorker to access it.
        • GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
        • GrpcWorkerService implements every interface defined by WorkerService, but the actual work is forwarded to the local GrpcWorker.

      The relationships are illustrated below:

      Figure 1: WorkerInterface derived classes

      3.4 Usage

      When the Server initializes, the Worker Service is built with the following code.

        // Create the GrpcWorker and its corresponding GrpcWorkerService
        worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
                                        : NewGrpcWorker(&worker_env_, config);
        worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
                                               opts.worker_service_options)
      
      
      

      Concretely, this returns a GrpcWorkerService.

      // Returns an implementation of WorkerService rpc service.
      std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
          GrpcWorker* worker, ::grpc::ServerBuilder* builder,
          GrpcWorkerServiceOptions options) {
        return std::unique_ptr<AsyncServiceInterface>(
            new GrpcWorkerService(worker, builder, options));
      }
      
      
      

      Inside GrpcServer, the worker_thread_ thread runs GrpcWorkerService's HandleRPCsLoop method.

      worker_thread_.reset(
          env_->StartThread(ThreadOptions(), "TF_worker_service",
                            [this] { worker_service_->HandleRPCsLoop(); }));
      
      
      

      3.5 Definition

      GrpcWorkerService is defined as follows. Because it must act as a daemon handling incoming gRPC requests, its constructor creates a number of threads to serve those requests, and HandleRPCsLoop then starts these threads and joins them.

      class GrpcWorkerService : public AsyncServiceInterface {
       public:
        GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
                          GrpcWorkerServiceOptions options)
            : is_shutdown_(false) {
          builder->RegisterService(&worker_service_);
      
          for (int i = 0; i < options.num_serving_threads; i++) {
            threads_.emplace_back(
                new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
                                            cache_.get(), &worker_service_));
          }
        }
      
        // This method blocks forever handling requests from the completion queue.
        void HandleRPCsLoop() override {
          for (auto& worker_thread : threads_) {
            worker_thread->Start();
          }
          for (auto& worker_thread : threads_) {
            worker_thread->Join();
          }
        }
      
       private:
        grpc::WorkerService::AsyncService worker_service_;
        std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
      
        std::unique_ptr<GrpcResponseCache> cache_;
        mutex service_shutdown_mu_;
        bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);
      
        TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
      };
      
      

      3.6 Threads

      The actual loop that serves requests runs inside these threads; cq_ is the gRPC completion queue.

      // GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
      // requests.  Each thread operates on an independent completion queue.
      class GrpcWorkerServiceThread {
       public:
        explicit GrpcWorkerServiceThread(
            GrpcWorker* worker, ::grpc::ServerBuilder* builder,
            std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
            grpc::WorkerService::AsyncService* worker_service)
            : worker_(worker),
              queue_depth_(queue_depth),
              cache_(cache),
              worker_service_(worker_service),
              is_shutdown_(false) {
          cq_ = builder->AddCompletionQueue();
        }
      
        void Start() {
          thread_.reset(
              worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
                                               [this]() { HandleRPCsLoop(); }));
        }
      }
      
      
      Main loop

      GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop, similar to the master service. It first sets up queues of pending gRPC calls; these call requests correspond one-to-one with the GrpcWorkerMethods introduced later, and the handler code for each method is shown further below.

      // Add one or more completion queue entries for each worker method, then
      // begin servicing requests from the completion queue.
      void GrpcWorkerServiceThread::HandleRPCsLoop() {
        // TODO(ncteisen): This may require performance engineering. We can
        // change the number of threads, the number of handlers per thread,
        // or even decide to specialize certain threads to certain methods.
        SETUP_FOR_REQUEST(GetStatus, 1, false);
        SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
        SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
        SETUP_FOR_REQUEST(CleanupAll, 1, false);
        SETUP_FOR_REQUEST(RegisterGraph, 1, false);
        SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
        SETUP_FOR_REQUEST(Logging, 1, false);
        SETUP_FOR_REQUEST(Tracing, 1, false);
        SETUP_FOR_REQUEST(CompleteGroup, 10, true);
        SETUP_FOR_REQUEST(CompleteInstance, 10, true);
        SETUP_FOR_REQUEST(GetStepSequence, 10, true);
        SETUP_FOR_REQUEST(RecvBuf, 500, true);
        SETUP_FOR_REQUEST(RunGraph, 100, true);
        SETUP_FOR_REQUEST(CleanupGraph, 100, false);
        SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
      
        // TODO(ncteisen): Determine a better policy for enqueuing the
        // appropriate number of each request type.
        for (int i = 0;
             i < gtl::FindWithDefault(
                     queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                     1000);
             ++i) {
          EnqueueRecvTensorRequestRaw();
        }
      
        void* tag;
        bool ok;
      
        while (cq_->Next(&tag, &ok)) {
          UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
              static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
          CHECK(callback_tag);
          callback_tag->OnCompleted(this, ok);
        }
      }
      
      gRPC requests

      Requests are handled much the same way as on the master side. Each request invokes a business handler, namely the GrpcWorkerServiceThread::method##Handler generated by the macro below.

      #define ENQUEUE_REQUEST(method, supports_cancel)                             \
        do {                                                                       \
          mutex_lock l(shutdown_mu_);                                              \
          if (!is_shutdown_) {                                                     \
            Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,       \
                 method##Request, method##Response>::                              \
                EnqueueRequestForMethod(                                           \
                    worker_service_, cq_.get(),                                    \
                    static_cast<int>(GrpcWorkerMethod::k##method),                 \
                    &GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
          }                                                                        \
        } while (0)
      
      #define SETUP_FOR_REQUEST(method, default_depth, supports_cancel)              \
        for (int i = 0;                                                              \
             i < gtl::FindWithDefault(queue_depth_,                                  \
                                      static_cast<int>(GrpcWorkerMethod::k##method), \
                                      default_depth);                                \
             ++i) {                                                                  \
          ENQUEUE_REQUEST(method, supports_cancel);                                  \
        }
      
      

      這里需要把每個 RPC 服務(wù)注冊為異步服務(wù),這使用 gRPC 自帶的 AddMethod 接口和 MarkMethodAsync 接口來完成。

      WorkerService::AsyncService::AsyncService() {
        for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
          AddMethod(new ::grpc::internal::RpcServiceMethod(
              GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
              ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
          ::grpc::Service::MarkMethodAsync(i);
        }
      }
      
      Handlers & thread pool

      The concrete handlers are configured via the macro below. It builds a closure around the Call and decides, according to its configuration, where to schedule the work, for example on the compute thread pool via compute_pool->Schedule. This is where the modules bundled into the worker env come into play.

        // Handle all non-cancellable simple methods with a standard wrapper.
        // The boolean may_block_on_compute_pool indicates whether or not the
        // operation may block on activities (such as op execution) that run on the
        // compute pool.
      #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
        void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
          auto closure = [this, call]() {                                           \
            Status s = worker_->method(&call->request, &call->response);            \
            if (!s.ok()) {                                                          \
              VLOG(3) << "Bad response from " << #method << ": " << s;              \
            }                                                                       \
            call->SendResponse(ToGrpcStatus(s));                                    \
          };                                                                        \
          if ((may_block_on_compute_pool)) {                                        \
            worker_->env()->env->SchedClosure(std::move(closure));                  \
          } else {                                                                  \
            worker_->env()->compute_pool->Schedule(std::move(closure));             \
          }                                                                         \
          ENQUEUE_REQUEST(method, false);                                           \
        }
      
        HANDLE_CALL(GetStatus, false);
        HANDLE_CALL(CreateWorkerSession, false);
        HANDLE_CALL(DeleteWorkerSession, true);
        HANDLE_CALL(CleanupAll, false);
        HANDLE_CALL(RegisterGraph, false);
        HANDLE_CALL(DeregisterGraph, false);
        HANDLE_CALL(CleanupGraph, false);
        HANDLE_CALL(Logging, false);
        HANDLE_CALL(Tracing, false);
      
      #undef HANDLE_CALL
      
      Messages & methods

      GrpcWorkerMethod enumerates which methods the worker provides.

      // Names of worker methods.
      enum class GrpcWorkerMethod {
        kGetStatus,
        kCreateWorkerSession,
        kDeleteWorkerSession,
        kRegisterGraph,
        kDeregisterGraph,
        kRunGraph,
        kCleanupGraph,
        kCleanupAll,
        kRecvTensor,
        kRecvBuf,
        kLogging,
        kTracing,
        kCompleteGroup,
        kCompleteInstance,
        kGetStepSequence,
        kMarkRecvFinished,
      };
      
      

      Which method string each of these message names corresponds to is determined by GrpcWorkerMethodName.

      const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
        switch (id) {
          case GrpcWorkerMethod::kGetStatus:
            return "/tensorflow.WorkerService/GetStatus";
          case GrpcWorkerMethod::kCreateWorkerSession:
            return "/tensorflow.WorkerService/CreateWorkerSession";
          case GrpcWorkerMethod::kDeleteWorkerSession:
            return "/tensorflow.WorkerService/DeleteWorkerSession";
          case GrpcWorkerMethod::kRegisterGraph:
            return "/tensorflow.WorkerService/RegisterGraph";
          case GrpcWorkerMethod::kDeregisterGraph:
            return "/tensorflow.WorkerService/DeregisterGraph";
          case GrpcWorkerMethod::kRunGraph:
            return "/tensorflow.WorkerService/RunGraph";
          case GrpcWorkerMethod::kCleanupGraph:
            return "/tensorflow.WorkerService/CleanupGraph";
          case GrpcWorkerMethod::kCleanupAll:
            return "/tensorflow.WorkerService/CleanupAll";
          case GrpcWorkerMethod::kRecvTensor:
            return "/tensorflow.WorkerService/RecvTensor";
          case GrpcWorkerMethod::kRecvBuf:
            return "/tensorflow.WorkerService/RecvBuf";
          case GrpcWorkerMethod::kLogging:
            return "/tensorflow.WorkerService/Logging";
          case GrpcWorkerMethod::kTracing:
            return "/tensorflow.WorkerService/Tracing";
          case GrpcWorkerMethod::kCompleteGroup:
            return "/tensorflow.WorkerService/CompleteGroup";
          case GrpcWorkerMethod::kCompleteInstance:
            return "/tensorflow.WorkerService/CompleteInstance";
          case GrpcWorkerMethod::kGetStepSequence:
            return "/tensorflow.WorkerService/GetStepSequence";
          case GrpcWorkerMethod::kMarkRecvFinished:
            return "/tensorflow.WorkerService/MarkRecvFinished";
        }
        // Shouldn't be reached.
        return "invalid id";
      }
      
      

      AsyncService calls GrpcWorkerMethodName to register the methods with gRPC.

      WorkerService::AsyncService::AsyncService() {
        for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
          AddMethod(new ::grpc::internal::RpcServiceMethod(
              GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
              ::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
          ::grpc::Service::MarkMethodAsync(i);
        }
      }
      
      
      Business handling

      The actual business logic is carried out by calling into the Worker.

      void GetStepSequenceHandler(
          WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
        Schedule([this, call]() {
          worker_->GetStepSequenceAsync(
              &call->request, &call->response, [call](const Status& s) {
                call->SendResponse(ToGrpcStatus(s));
              });
        });
        ENQUEUE_REQUEST(GetStepSequence, true);
      }
      
      

      From the thread perspective, the logic currently looks like the figure below, assuming three threads. The Server's worker_thread_ runs GrpcWorkerService::HandleRPCsLoop(), whose job is to start two GrpcWorkerServiceThreads; each GrpcWorkerServiceThread then responds to gRPC requests and performs the business handling in GrpcWorkerServiceThread::HandleRPCsLoop. Note that GrpcWorkerService and GrpcWorkerServiceThread both have a method named HandleRPCsLoop.

      Figure 2: Thread perspective

      3.7 Business Logic

      CreateWorkerSession

      The CreateWorkerSessionRequest message carries the session_handle of the corresponding MasterSession. When a Worker receives the message, it creates a WorkerSession. Within a cluster, whenever a MasterSession creates WorkerSessions it passes along its own session_handle; this way, each WorkerSession knows from the session_handle which MasterSession it belongs to, and the MasterSession instance can centrally manage all of the WorkerSessions under it.

      GrpcWorker manages WorkerSessions through SessionMgr: a WorkerSession can be looked up either by the master task name or by the session_handle.
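
      For example, a worker-side handler would resolve the WorkerSession named by an incoming request roughly like this (a hedged sketch; the lookup method name comes from SessionMgr in the TensorFlow source, while the surrounding variables are illustrative only):

      // Hedged sketch: resolve the WorkerSession identified by the request's session_handle.
      std::shared_ptr<WorkerSession> session;
      Status s = worker_env->session_mgr->WorkerSessionForSession(
          request->session_handle(), &session);
      if (s.ok()) {
        // The session then gives access to the pieces needed to serve the call,
        // e.g. its GraphMgr and DeviceMgr.
      }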

      class SessionMgr {
      
        WorkerEnv* const worker_env_;  // Not owned.
        std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
        std::shared_ptr<WorkerSession> legacy_session_;
        const WorkerCacheFactory worker_cache_factory_;
      
        // A map from session identifier to internal session structure.
        std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);
      
        // Incarnation and WorkerSession handle associated with a master task.
        struct MasterAssociatedSession {
          const int64_t master_incarnation;
          const string session_handle;
        };
        // A map from master task name to its associated worker sessions.
        std::unordered_multimap<string, MasterAssociatedSession>
            master_to_associated_sessions_ TF_GUARDED_BY(mu_);
      };
      
      

      The messages are shown below; note that CreateWorkerSessionResponse returns nothing:

      message CreateWorkerSessionRequest {
        // Sessions are identified by a given handle.
        string session_handle = 1;
      
        // Defines the configuration of a TensorFlow worker.
        ServerDef server_def = 2;
      
        // If true, any resources such as Variables used in the session will not be
        // shared with other sessions.
        bool isolate_session_state = 3;
      
        // The device attributes of all the devices in the cluster.
        repeated DeviceAttributes cluster_device_attributes = 4;
      
        // The master task name from which the request is sent.
        string master_task = 5;
      
        // The incarnation ID of the master task local CPU device.
        // If the target worker already has a WorkerSession created previously with
        // the same master task name but a different incarnation, it usually indicates
        // that the previous master failed before deleting the WorkerSession on the
        // worker. To prevent memory leaks, the worker should garbage collect the old
        // WorkerSessions.
        int64 master_incarnation = 6;
      }
      
      message CreateWorkerSessionResponse {}
      
      

      Figure 3: CreateWorkerSession

      As mentioned before, these GrpcWorker handlers are generated by a macro.

      #define HANDLE_CALL(method, may_block_on_compute_pool)                        \
        void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
          auto closure = [this, call]() {                                           \
            Status s = worker_->method(&call->request, &call->response);            \
            if (!s.ok()) {                                                          \
              VLOG(3) << "Bad response from " << #method << ": " << s;              \
            }                                                                       \
            call->SendResponse(ToGrpcStatus(s));                                    \
          };                                                                        \
          if ((may_block_on_compute_pool)) {                                        \
            worker_->env()->env->SchedClosure(std::move(closure));                  \
          } else {                                                                  \
            worker_->env()->compute_pool->Schedule(std::move(closure));             \
          }                                                                         \
          ENQUEUE_REQUEST(method, false);                                           \
        }
      
        HANDLE_CALL(GetStatus, false);
        HANDLE_CALL(CreateWorkerSession, false);
        HANDLE_CALL(DeleteWorkerSession, true);
        HANDLE_CALL(CleanupAll, false);
        HANDLE_CALL(RegisterGraph, false);
        HANDLE_CALL(DeregisterGraph, false);
        HANDLE_CALL(CleanupGraph, false);
        HANDLE_CALL(Logging, false);
        HANDLE_CALL(Tracing, false);
      
      
      RegisterGraph

      The RegisterGraphRequest message sends the session_handle of the corresponding MasterSession together with the subgraph graph_def. When the Worker receives the message and finishes registering and initializing the subgraph, it returns that subgraph's graph_handle to the Master.

      For each session, after the master has placed every node on a device, it partitions the whole graph into many subgraphs. All nodes of one subgraph live on the same worker, but potentially across many devices owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The master registers the subgraphs with the worker before running any step. A successful registration returns a graph handle to be used in later RunGraph requests.

      ////////////////////////////////////////////////////////////////////////////////
      //
      // RegisterGraph method request/response messages
      //
      // For each session, after the master placed every node on a device,
      // it partitions the whole graph into many subgraphs. All the nodes in
      // a subgraph were in the same worker, but potentially on many devices
      // owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
      // master registers subgraphs for a worker before running any steps. A
      // successful registration returns a graph handle to be used in latter
      // RunGraph requests.
      //
      ////////////////////////////////////////////////////////////////////////////////
      
      message RegisterGraphRequest {
        // Subgraphs are scoped within one session.
        string session_handle = 1;
      
        // Set to true if CreateWorkerSession was called for session_handle.
        bool create_worker_session_called = 6;
      
        // "graph_def" has the subgraph of nodes for this worker, with each node
        // having its device_name filled in.
        GraphDef graph_def = 2;
      
        // True iff the graph (before partitioning) contains control flow nodes.
        //
        // As of 01/11/2015, this is no longer set by clients.
        bool has_control_flow = 3 [deprecated = true];
      
        // Configuration options for the session in which this graph was created.
        GraphOptions graph_options = 4;
      
        // Field(s) used by TensorFlow Debugger (tfdbg).
        DebugOptions debug_options = 5;
      
        // If graph_def contains any collective ops this must be a positive
        // integer used to coordinate execution with other graphs.  All
        // graphs in a distributed execution with the same
        // collective_graph_key will coordinate to use the same step_id
        // concurrently so that BufRendezvous entries will make the correct
        // values accessible.
        int64 collective_graph_key = 7;
      
        // ConfigProto from the session in which this graph was created.
        // Contains additional parameters beyond graph_options, including
        // the name of the requested executor.
        ConfigProto config_proto = 8;
      }
      
      message RegisterGraphResponse {
        // If the registration succeeds, returns an opaque graph_handle to
        // the master. The master calls RunGraph with graph_handle to
        // compute different steps.
        string graph_handle = 1;
      }
      
      

      Figure 4: RegisterGraph

      DeregisterGraph

      When a computation graph is no longer needed (for example, the whole graph was rescheduled and its nodes re-placed), the Master uses the graph's graph_handle to deregister it. In case of a Master restart, the Worker automatically deregisters the graph_handle according to a TTL-based policy.

      ////////////////////////////////////////////////////////////////////////////////
      //
      // DeregisterGraph method request/response messages
      //
      // The master deregisters the given graph_handle when the graph is no
      // longer needed (e.g., the overall graph is re-scheduled and nodes
      // are re-placed).
      //
      // The worker deregisters a graph_handle automatically according to on
      // a TTL-base policy in case of master restarts.
      //
      ////////////////////////////////////////////////////////////////////////////////
      
      message DeregisterGraphRequest {
        // The session_handle used when registering the graph. If session_handle is
        // empty, a single global namespace is used.
        string session_handle = 2;
      
        // Set to true if CreateWorkerSession was called for session_handle.
        bool create_worker_session_called = 3;
      
        // REQUIRED: graph_handle must be returned by a RegisterGraph call
        // to the same WorkerService.
        string graph_handle = 1;
      }
      
      message DeregisterGraphResponse {
        // TODO(mrry): Optionally add summary stats for the graph.
      }
      
      

      Figure 5: DeregisterGraph

      RunGraph

      The Master uses RunGraphRequest to execute all the subgraphs registered under a graph_handle.

      The Master generates a globally unique step_id to distinguish different runs of the graph computation. Subgraphs can communicate with each other (e.g. via send/recv ops) using the step_id to tell apart tensors produced by different runs.

      In a RunGraphRequest, send carries the tensors fed into the subgraph, and recv_key names the tensors the subgraph should output. RunGraphResponse returns the list of tensors corresponding to recv_key.

      Figure 6: RunGraph

      ////////////////////////////////////////////////////////////////////////////////
      //
      // RunGraph request / response messages
      //
      // The worker executes all subgraphs registered under graph_handle.
      // RunGraph returns after the execution finishes or an error is
      // encountered.
      // A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
      // partial graph execution.
      //
      ////////////////////////////////////////////////////////////////////////////////
      
      // Options specific to the execution of a single step.
      message ExecutorOpts {
        bool record_costs = 1;
        bool record_timeline = 3;
        bool record_partition_graphs = 4;
        bool report_tensor_allocations_upon_oom = 5;
      }
      
      message RunGraphRequest {
        // session_handle is the master-generated unique id for this session.
        // If session_handle is non-empty, it must be the same as used when
        // registering the graph. If it is empty, a single global namespace is used to
        // search for the graph_handle.
        string session_handle = 8;
      
        // Set to true if CreateWorkerSession was called for session_handle.
        bool create_worker_session_called = 10;
      
        // REQUIRED: graph_handle must be returned by a RegisterGraph call
        // to the same WorkerService.
        string graph_handle = 1;
      
        // A unique ID to distinguish different runs of the same graph.
        //
        // The master generates a global unique step_id to distinguish
        // different runs of the graph computation. Subgraphs communicate
        // (e.g., send/recv ops) with each other using step_id to
        // distinguish tensors generated by different runs.
        int64 step_id = 2;
      
        // Options for this step.
        ExecutorOpts exec_opts = 5;
      
        // Runs the graph.
        //
        // Sends the tensors in "send" into the graph before the run and
        // fetches the keys into RunGraphResponse.recv after the run.
        repeated NamedTensorProto send = 3;
        repeated string recv_key = 4;
      
        // True if the RunGraphRequest is a partial run request.
        bool is_partial = 6;
        // True if this is the last partial run request in a sequence of requests.
        bool is_last_partial_run = 7;
      
        // If true then some errors, e.g., execution errors that have long
        // error messages, may return an OK RunGraphResponse with the actual
        // error saved in the status_code/status_error_message fields of the
        // response body. This is a workaround since the RPC subsystem may
        // truncate long metadata messages.
        bool store_errors_in_response_body = 9;
      
        // Unique identifier for this request. Every RunGraphRequest must have a
        // unique request_id, and retried RunGraphRequests must have the same
        // request_id. If request_id is zero, retry detection is disabled.
        //
        // Retried RunGraphRequests are problematic because they may issue a
        // RecvTensor that will have no corresponding sender and will wait forever.
        // Workers use request_ids to reject retried RunGraph requests instead of
        // waiting forever.
        int64 request_id = 11;
      
        // Next: 12
      }
      
      message RunGraphResponse {
        // A list of tensors corresponding to those requested by
        // RunGraphRequest.recv_key.
        repeated NamedTensorProto recv = 1;
      
        // If the request asked for execution stats, the cost graph, or the partition
        // graphs, these are returned here.
        // TODO(suharshs): Package these in a RunMetadata instead.
        StepStats step_stats = 2;
        CostGraphDef cost_graph = 3;
        repeated GraphDef partition_graph = 4;
      
        // If store_errors_in_response_body is true in the request, then
        // optionally the server may return an OK status for the RPC and
        // fill the true status into the fields below, to allow for messages
        // that are too long to fit in metadata.
        error.Code status_code = 5;
        string status_error_message = 6;
      }
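
      The request-building sketch promised above: a minimal, hypothetical view of how a master-side caller could drive one step (wi, graph_handle and the feed/fetch key strings "feed_x:0" / "fetch_y:0" are assumptions; the real code goes through MasterSession and the RunGraphRequestWrapper classes). Feeds travel in send, fetches are named by recv_key, and the same step_id is shared by all subgraphs of one run.

      RunGraphRequest req;
      req.set_graph_handle(graph_handle);
      req.set_step_id(42);                     // in reality a globally unique id per run
      auto* feed = req.add_send();
      feed->set_name("feed_x:0");              // hypothetical key of the fed tensor
      Tensor x(DT_FLOAT, TensorShape({2}));
      x.flat<float>().setConstant(1.0f);
      x.AsProtoField(feed->mutable_tensor());  // serialize the feed value into the request
      req.add_recv_key("fetch_y:0");           // hypothetical key of the tensor we want back

      CallOptions call_opts;
      RunGraphResponse resp;
      wi->RunGraphAsync(&call_opts, &req, &resp, [&resp](const Status& s) {
        if (s.ok() && resp.recv_size() > 0) {
          // resp.recv(0).tensor() holds the fetched value for "fetch_y:0".
        }
      });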
      
      
      RecvTensor

      During actual execution, two workers may need to exchange data. The producer simply places the prepared tensor into the rendezvous, and the consumer actively issues a RecvTensorRequest; in that request, step_id identifies which step the tensor belongs to, and rendezvous_key identifies the channel over which the tensor is to be received.

      A single RecvTensor request retrieves one tensor from the channel, and multiple tensors can be sent and received over the same channel via multiple RecvTensor requests. The producer's tensor is ultimately returned to the consumer in a RecvTensorResponse. A consumer-side sketch follows the message definitions below.

      Figure 7 RecvTensor

      ////////////////////////////////////////////////////////////////////////////////
      //
      // RecvTensor method request/response messages
      //
      ////////////////////////////////////////////////////////////////////////////////
      
      message RecvTensorRequest {
        // The step in which the tensor will be produced.
        //
        // REQUIRED: This must eventually correspond to the step_id passed
        // into a RunGraph call on the same WorkerService.
        int64 step_id = 1;
      
        // A key identifying the channel to receive tensors from. A RecvTensor request
        // retrieves one tensor from the channel, but multiple tensors can be sent and
        // received over the same channel with multiple RecvTensor requests. See
        // rendezvous.h for details.
        string rendezvous_key = 2;
      
        // If true, use an out-of-band DMA mechanism to transfer the
        // received tensor.
        bool dma_ok = 3;
      
        // Optional information on client-side device locality.
        DeviceLocality client_locality = 4;
      
        // Optional information on server-side device locality.
        DeviceLocality server_locality = 5;
      
        // Optional information needed by the RPC subsystem.
        google.protobuf.Any transport_options = 6;
      
        // Unique identifier for this request. Every RecvTensorRequest must have a
        // unique request_id, and retried RecvTensorRequests must have the same
        // request_id. If request_id is zero, retry detection and response cache
        // are disabled.
        //
        // Retried RecvTensorRequests are problematic because a RecvTensor with no
        // corresponding sender will wait forever, and the tensor may have been
        // delivered to a previous retry. Workers use request_ids to reject retried
        // RecvTensor requests instead of waiting forever.
        int64 request_id = 7;
      }
      
      message RecvTensorResponse {
        // The tensor as a proto.
        TensorProto tensor = 1;
      
        // If true, this tensor was the output of a dead node, and the
        // content is invalid.
        bool is_dead = 2;
      
        // The time at which tensor was available and started to be returned.
        int64 send_start_micros = 3;
      
        // Optional additional information about how to receive the tensor,
        // e.g. in the event that RecvTensorRequest.dma_ok was true.
        google.protobuf.Any transport_options = 4;
      
        // Whether the receiver should send a MarkRecvFinishedRequest to the sender
        // to ack the message.
        bool require_ack = 5;
      }
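
      The consumer-side sketch promised above (device names, the edge name and the step id are hypothetical, and in practice this request is issued by the RPC RemoteRendezvous machinery rather than written by hand): Rendezvous::CreateKey builds the channel name from the source device, its incarnation, the destination device and the edge name.

      string key = Rendezvous::CreateKey(
          "/job:worker/replica:0/task:0/device:CPU:0", 1 /* src incarnation */,
          "/job:worker/replica:0/task:1/device:CPU:0", "edge_5_y", FrameAndIter(0, 0));

      RecvTensorRequest req;
      req.set_step_id(42);           // must match the producer's RunGraphRequest.step_id
      req.set_rendezvous_key(key);

      CallOptions call_opts;
      TensorResponse resp;           // real callers first call resp.InitAlloc(dst_device, alloc_attrs)
      wi->RecvTensorAsync(&call_opts, &req, &resp, [&resp](const Status& s) {
        if (s.ok() && !resp.metadata().is_dead()) {
          const Tensor& received = resp.tensor();  // the producer's value
        }
      });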
      
      

      4. Worker

      The Worker class mainly provides the WorkerEnv and the PartialRunMgr. It can be subclassed to supply specialized implementations of particular methods for different transport mechanisms; for example, GrpcWorker specializes the RecvTensorAsync method so that a more efficient gRPC data structure can be used to handle large binary data.

      class Worker : public WorkerInterface {
       protected:
        WorkerEnv* const env_;  // Not owned.
        RecentRequestIds recent_request_ids_;
      
       private:
        PartialRunMgr partial_run_mgr_;
      
        CancellationManager cancellation_manager_;
      
        TF_DISALLOW_COPY_AND_ASSIGN(Worker);
      };
      
      

      Let us look at one method as an example (a short usage sketch follows it); the other methods will be discussed later as we encounter them.

      void Worker::CleanupAllAsync(const CleanupAllRequest* request,
                                   CleanupAllResponse* response,
                                   StatusCallback done) {
        std::vector<string> containers;
        for (const auto& c : request->container()) containers.push_back(c);
        env_->device_mgr->ClearContainers(containers);
        done(Status::OK());
      }
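
      A small usage sketch for this method (the worker pointer and the container name are hypothetical): the request simply lists the resource containers to be cleared on every device owned by this worker.

      CleanupAllRequest req;
      req.add_container("my_container");  // hypothetical container name
      CleanupAllResponse resp;
      worker->CleanupAllAsync(&req, &resp, [](const Status& s) {
        if (!s.ok()) LOG(ERROR) << "CleanupAll failed: " << s.error_message();
      });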
      
      

      5. GrpcWorker

      GrpcWorker is the remote worker that corresponds to GrpcRemoteWorker. It is also the object invoked by GrpcWorkerService, and it implements the actual business logic. Its definition is as follows (a brief wiring sketch comes after the class); we can see that it implements several methods.

      class GrpcWorker : public Worker {
       public:
        GrpcWorker(WorkerEnv* env, const ConfigProto& config);
      
        // Specialized version of RecvTensor for gRPC, which avoids a copy.
        virtual void GrpcRecvTensorAsync(CallOptions* opts,
                                         const RecvTensorRequest* request,
                                         ::grpc::ByteBuffer* response,
                                         StatusCallback done);
      
        void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
                          StatusCallback done) override;
      
        void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
                          RecvBufResponse* response, StatusCallback done) override;
      
        void CleanupGraphAsync(const CleanupGraphRequest* request,
                               CleanupGraphResponse* response,
                               StatusCallback done) override;
      
        WorkerEnv* env();
      
        void EnableResponseCache();
      
        void RemoveCacheEntryForId(int64 request_id);
      
       private:
        std::unique_ptr<GrpcResponseCache> response_cache_;
        const int32 recv_buf_max_chunk_;
      };
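
      As a rough sketch of how this class is wired up in practice (worker_env, config and builder are assumed to come from the surrounding GrpcServer setup code): NewGrpcWorker builds the business-logic object, and NewGrpcWorkerService wraps it in the gRPC service that answers WorkerService calls.

      std::unique_ptr<GrpcWorker> grpc_worker = NewGrpcWorker(worker_env, config);
      std::unique_ptr<AsyncServiceInterface> worker_service =
          NewGrpcWorkerService(grpc_worker.get(), &builder);  // registers the async service with the ServerBuilder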
      
      

      At this point we have finished introducing the static structure of the Worker. The concrete functionality of the Worker will be covered in the later articles on Session.

