
      [Source code analysis] TensorFlow distributed environment (8) --- Communication mechanism


      Once the computation graph has been partitioned across devices, data dependencies may exist between the per-device PartitionGraphs, so TF inserts Send/Recv nodes between them to carry the data across. In distributed mode, Send/Recv exchange data through RpcRemoteRendezvous, so we first need to look at Rendezvous, TF's data-exchange mechanism.

      So far in distributed machine learning we have met Rendezvous many times, mostly in the parts dealing with elasticity and communication. The specifics differ slightly from case to case, but the basic meaning always follows the original French word: a meeting, gathering, or appointment. TensorFlow's Rendezvous is the communication component and exchange mechanism for message transfer.

      This article again draws heavily on the work of two experts:

      [TensorFlow Internals](https://github.com/horance-liu/tensorflow-internals): although it does not analyze the latest code, anyone interested in TF's internal implementation is encouraged to read it; it is well worth the time.
      https://home.cnblogs.com/u/deep-learning-stacks/ 西門宇少 covers far more than TensorFlow; his WeChat public account also tracks the industry frontier in many other areas.

      Other articles in this series:

      [Translation] TensorFlow distributed papers: "TensorFlow: Large-Scale Machine Learning on Heterogeneous Distributed Systems"

      [Translation] TensorFlow distributed papers: "Implementation of Control Flow in TensorFlow"

      [Source code analysis] TensorFlow distributed environment (1) --- Overall architecture

      [Source code analysis] TensorFlow distributed environment (2) --- Master static logic

      [Source code analysis] TensorFlow distributed environment (3) --- Worker static logic

      [Source code analysis] TensorFlow distributed environment (4) --- WorkerCache

      [Source code analysis] TensorFlow distributed environment (5) --- Session

      [Source code analysis] TensorFlow distributed environment (7) --- Worker dynamic logic

      1. Mechanism

      In distributed mode, cross-device edges are split: a Send node is inserted at the sending end of the edge and a Recv node at the receiving end.

      • In-process Send and Recv nodes exchange data through IntraProcessRendezvous.
      • Cross-process Send and Recv nodes exchange data through GrpcRemoteRendezvous.

      Suppose Worker 0 has two GPUs. After the Send and Recv nodes are inserted, the effect is as follows: the arrow from Worker 1 to Worker 0 represents cross-process data exchange via GrpcRemoteRendezvous, the dashed arrow between the two GPUs inside Worker 0 represents in-process data exchange via IntraProcessRendezvous, and the solid arrows between workers indicate data exchange over RPC.

      When a step executes and two workers need to exchange data:

      • The producer (sender) first generates the tensor and puts it into its local table.
      • The consumer (receiver) sends a RecvTensorRequest message to the producer; the message carries the two-tuple (step_id, rendezvous_key).
      • The producer-side worker looks up the corresponding tensor in its local table and returns it in a RecvTensorResponse.

      The send/recv data transfer itself goes through a derived class of WorkerInterface, and WorkerInterface sits on top of the underlying gRPC communication library.
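      To make the (step_id, rendezvous_key) tuple concrete, here is a rough sketch of how the consumer side fills the request (simplified from the RpcRecvTensorCall::Init code shown later in this article, not a complete implementation):

      // Illustrative sketch: pack the (step_id, rendezvous_key) pair into the
      // RecvTensorRequest before issuing the RPC (see RpcRecvTensorCall::Init below).
      RecvTensorRequest req;
      req.set_step_id(step_id);                        // which step this recv belongs to
      req.set_rendezvous_key(key.data(), key.size());  // identifies the Send/Recv pair
      req.set_request_id(GetUniqueRequestId());        // lets the server de-duplicate retries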

      Figure 1: Send/receive

      1.1 Message identifier

      When we studied PyTorch's distributed support, we saw that every distributed communication needs a globally unique identifier, for example:

      • autogradMessageId identifies a pair of send/recv autograd functions. Each send-recv pair is assigned a globally unique autograd_message_id, which is useful for looking up the corresponding function on the remote node during the backward pass.
      • The container is also responsible for maintaining globally unique message ids, used to associate send/recv autograd function pairs. The format is a 64-bit integer: the upper 16 bits are the worker id and the lower 48 bits are an integer auto-incremented inside the worker.

      Similarly, TF needs to assign a unique identifier to every Send/Recv pair so that messages are not mismatched when several groups of messages are sent in parallel. That identifier is the ParsedKey.

      1.1.1 Definition

      Its fields are as follows:

      • src_device: the sending device.
      • src: the same information as src_device, just represented as a structure.
      • src_incarnation: used for debugging; its value changes after a worker restarts, which makes it possible to tell the restarted worker apart from the one that crashed.
      • dst_device: the receiving device.
      • dst: the same information as dst_device, just represented as a structure.
      • edge_name: the edge name, which can be a tensor name or some string with a special meaning.
      // Parses the key constructed by CreateKey and parse src/dst device
      // names into structures respectively.
      struct ParsedKey {
        StringPiece src_device;
        DeviceNameUtils::ParsedName src;
        uint64 src_incarnation = 0;
        StringPiece dst_device;
        DeviceNameUtils::ParsedName dst;
        StringPiece edge_name;
      
        ParsedKey() {}
        ParsedKey(const ParsedKey& b) { *this = b; }
      
        ParsedKey& operator=(const ParsedKey& b);
        StringPiece FullKey() const { return buf_; }
      
       private:
        friend class Rendezvous;
        friend class SendOp;
        friend class RecvOp;
        std::string buf_;
      };
      

      1.1.2 Creation

      The generated string key has the following layout:

      src_device ; HexString(src_incarnation) ; dst_device ; name ; frame_iter.frame_id : frame_iter.iter_id
      
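      For intuition, a concrete key built from this layout might look like the following (an illustrative example, not taken from a real run):

      /job:worker/replica:0/task:0/device:GPU:0;0000000000000001;/job:worker/replica:0/task:1/device:CPU:0;edge_5_y;0:0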

      The code is as follows:

      /*  static */
      string Rendezvous::CreateKey(const string& src_device, uint64 src_incarnation,
                                   const string& dst_device, const string& name,
                                   const FrameAndIter& frame_iter) {
        // NOTE: ';' is not used in the device name's job name.
        //
        // We include both sender and receiver in the key to facilitate
        // debugging. For correctness, we only need to encode the receiver.
        //
        // "src_incarnation" is used to distinguish a worker when it
        // restarts.
        char buf[strings::kFastToBufferSize];
        return strings::StrCat(
            src_device, ";", strings::Uint64ToHexString(src_incarnation, buf), ";",
            dst_device, ";", name, ";", frame_iter.frame_id, ":", frame_iter.iter_id);
      }
      

      The system then uses ParseKey to parse the key and produce a ParsedKey. ParseKey maps the first four fields of the input key and discards the fifth one, frame_iter.frame_id : frame_iter.iter_id. The other fields map literally; note only that edge_name corresponds to name.

      /* static */
      Status Rendezvous::ParseKey(StringPiece key, ParsedKey* out) {
        if (key.data() == out->buf_.data()) {
          // Caller used our buf_ string directly, so we don't need to copy.  (The
          // SendOp and RecvOp implementations do this, for example).
          DCHECK_EQ(key.size(), out->buf_.size());
        } else {
          // Make a copy that our StringPieces can point at a copy that will persist
          // for the lifetime of the ParsedKey object.
          out->buf_.assign(key.data(), key.size());
        }
        StringPiece s(out->buf_);
        StringPiece parts[5];
        for (int i = 0; i < 5; i++) {
          parts[i] = ConsumeNextPart(&s, ';');
        }
        if (s.empty() &&          // Consumed the whole string
            !parts[4].empty() &&  // Exactly five parts
            DeviceNameUtils::ParseFullName(parts[0], &out->src) &&
            strings::HexStringToUint64(parts[1], &out->src_incarnation) &&
            DeviceNameUtils::ParseFullName(parts[2], &out->dst) &&
            !parts[3].empty()) {
          out->src_device = StringPiece(parts[0].data(), parts[0].size());
          out->dst_device = StringPiece(parts[2].data(), parts[2].size());
          out->edge_name = StringPiece(parts[3].data(), parts[3].size());
          return Status::OK();
        }
        return errors::InvalidArgument("Invalid  rendezvous key: ", key);
      }
      

      1.2 Rendezvous

      Rendezvous is an abstraction for passing a tensor from a producer to a consumer. A rendezvous is a table of channels; each channel is keyed by a rendezvous key. The key encodes a <producer, consumer> pair, where the producer and the consumer are TensorFlow devices.

      The producer calls Send() to send a tensor over a named channel. The consumer calls Recv() to receive a tensor from a specified channel. A sequence of tensors can be passed from the producer to the consumer, and the consumer receives them in the order the producer sends them.

      The consumer can safely request the tensor either before or after it has been produced. The consumer may either make a blocking call or provide a callback; in both cases it receives the tensor as soon as it is available. The producer never blocks.
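      To make the semantics concrete, here is a minimal usage sketch that sends and receives one tensor through a local rendezvous (it relies on the NewLocalRendezvous and key helpers shown later in this article and is not code from the TF source tree):

      // Minimal sketch: produce and consume one tensor through a local rendezvous.
      Rendezvous* rendez = NewLocalRendezvous();
      Rendezvous::ParsedKey key;
      TF_CHECK_OK(Rendezvous::ParseKey(
          Rendezvous::CreateKey("/job:worker/replica:0/task:0/device:CPU:0",
                                /*src_incarnation=*/1,
                                "/job:worker/replica:0/task:0/device:GPU:0",
                                "edge_1_x", FrameAndIter(0, 0)),
          &key));
      Tensor val(DT_FLOAT, TensorShape({2}));
      // The producer never blocks: the value is buffered in the channel keyed by "key".
      TF_CHECK_OK(rendez->Send(key, Rendezvous::Args(), val, /*is_dead=*/false));
      Tensor out;
      bool is_dead = false;
      // The consumer may ask before or after Send; here the value is already buffered.
      TF_CHECK_OK(rendez->Recv(key, Rendezvous::Args(), &out, &is_dead));
      rendez->Unref();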

      1.2.1 Interface class

      RendezvousInterface is the interface class; it defines the virtual functions. ParsedKey is also defined here (we omit that part of the code).

      class RendezvousInterface {
       public:
        struct Args {
          DeviceContext* device_context = nullptr;
          AllocatorAttributes alloc_attrs;
          CancellationManager* cancellation_manager = nullptr;  // not owned.
        };
      
        // The caller is a tensor producer and it sends a message (a tensor
        // "val" and a bool "is_dead") under the given "key".
        //
        // {val, is_dead} is bundled as a message sent and received.
        // Typically, is_dead is set by some control flow nodes
        // (e.g., a not-taken branch).  args is passed by Send to the
        // Recv function to communicate any information that the Recv
        // function might need.  This is typically only necessary for
        // Send/Recv on the same worker.
        //
        // Send() never blocks.
        virtual Status Send(const ParsedKey& key, const Args& args, const Tensor& val,
                            const bool is_dead) = 0;
      
        // Callback provided by a tensor consumer waiting on the rendezvous.
        // It will be invoked when the tensor is available, or when a non-OK
        // status arises in the production of that tensor.  It also gets
        // two Rendezvous::Args, one provided by the sender, the other by the
        // receiver, which may be needed when a non-CPU device is in use
        // by either side.
        typedef std::function<void(const Status&, const Args&, const Args&,
                                   const Tensor&, const bool)>
            DoneCallback;
      
        virtual void RecvAsync(const ParsedKey& key, const Args& args,
                               DoneCallback done) = 0;
      
        // Synchronous wrapper for RecvAsync.
        Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
                    bool* is_dead, int64_t timeout_ms);
        Status Recv(const ParsedKey& key, const Args& args, Tensor* val,
                    bool* is_dead);
      
        // Aborts all pending and future Send/Recv with the given "status".
        // StartAbort() does not wait for ongoing calls to finish.
        // REQUIRES: !status.ok()
        virtual void StartAbort(const Status& status) = 0;
      
       protected:
        virtual ~RendezvousInterface();
      
        virtual bool is_cross_process() { return false; }
        friend class ProcessFunctionLibraryRuntime;
      };
      

      1.2.2 Basic implementation: Rendezvous

      The Rendezvous class provides the most basic implementations of Send, Recv and RecvAsync, and also provides the ParseKey functionality.

      // A reference-counted implementation of RendezvousInterface.
      //
      // This class is used in cases where a rendezvous may be shared between multiple
      // threads with no clear owner.
      class Rendezvous : public RendezvousInterface, public core::RefCounted {
       public:
        class Factory {
         public:
          // Default to a factory that evaluates to false.
          Factory() : valid_(false) {}
      
          Factory(std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
                      create_fn,
                  std::function<Status(const int64_t)> cleanup_fn)
              : valid_(true),
                create_fn_(std::move(create_fn)),
                cleanup_fn_(std::move(cleanup_fn)) {}
      
          // If no clean up fn is provided, just put in a dummy.
          // For backwards compatibility.
          explicit Factory(
              std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
                  create_fn)
              : valid_(true),
                create_fn_(std::move(create_fn)),
                cleanup_fn_([](const int64_t step_id) { return Status::OK(); }) {}
      
          explicit operator bool() const { return valid_; }
      
          Status operator()(const int64_t step_id, const DeviceMgr* device_mgr,
                            Rendezvous** rendez) const {
            return create_fn_(step_id, device_mgr, rendez);
          }
      
          Status CleanUp(const int64_t step_id) const { return cleanup_fn_(step_id); }
      
         private:
          bool valid_;
          std::function<Status(const int64_t, const DeviceMgr*, Rendezvous**)>
              create_fn_;
          std::function<Status(const int64_t)> cleanup_fn_;
        };
      
        // Constructs a rendezvous key for the tensor of "name" sent from
        // "src_device" to "dst_device". The tensor is generated in the frame
        // and iteration specified by "frame_iter".
        static std::string CreateKey(const std::string& src_device,
                                     uint64 src_incarnation,
                                     const std::string& dst_device,
                                     const std::string& name,
                                     const FrameAndIter& frame_iter);
      
        static Status ParseKey(StringPiece key, ParsedKey* out);
      };
      

      1.2.3 Cross-process RemoteRendezvous

      RemoteRendezvous inherits from Rendezvous and adds only one pure virtual method, Initialize. Every derived class used for cross-process communication must override it, because the initialization needs the help of a Session.

      RemoteRendezvous can handle the case where either the producer or the consumer lives in a remote process, adding the ability to coordinate with remote workers. RemoteRendezvous follows a two-phase initialization strategy: first the objects are constructed; eventually they are initialized. Clients of RendezvousMgrInterface must guarantee that Initialize is eventually called on the returned RemoteRendezvous.
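      A minimal sketch of the two-phase pattern follows (the call sites here are hypothetical; the real ones appear in the Worker code later in this article):

      // Phase 1: the rendezvous for this step is created/looked up by the manager.
      RemoteRendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
      // Phase 2: it must eventually be bound to a WorkerSession; remote transfers
      // only really start once Initialize has been called.
      Status s = rendez->Initialize(worker_session);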

      // RemoteRendezvous follow a 2-part initialization. First the objects are
      // constructed. Eventually, they will be initialized. Clients of the
      // RendezvousMgrInterface must guarantee to call Initialize on the returned
      // RemoteRendezvous eventually.
      //
      // Partially initialized RemoteRendezvous must respect the Rendezvous interface
      // (i.e. Send() must never block), however implementations are not expected to
      // actually perform the underlying operations until after the RemoteRendezvous
      // has been Initialize'd.
      class RemoteRendezvous : public Rendezvous {
       public:
        // Fully construct the RemoteRendezvous.
        virtual Status Initialize(WorkerSession* session) = 0;
      
       protected:
        bool is_cross_process() override { return true; }
      };
      

      1.2.4 BaseRemoteRendezvous

      Because cross-process communication can run over different protocols, each cross-process Rendezvous has to be implemented against its own protocol. TF therefore inserts an intermediate layer, BaseRemoteRendezvous, between RemoteRendezvous and the concrete protocol-specific Rendezvous classes. This class bridges the two levels and provides the common Send and Recv methods, so that as much code as possible is reused.

      The main member variable of BaseRemoteRendezvous is Rendezvous* local_. The code makes heavy use of BaseRecvTensorCall as a parameter; BaseRecvTensorCall is the abstraction of a single communication call.

      // RemoteRendezvous is a Rendezvous which can handle either
      // the producer or consumer being in a remote process.
      //
      // Buffering of Tensor values is delegated to a "local" Rendezvous
      // obtained from NewLocalRendezvous().  This class just adds
      // functionality to coordinate with remote workers.
      class BaseRemoteRendezvous : public RemoteRendezvous {
       public:
        BaseRemoteRendezvous(const WorkerEnv* env, int64_t step_id);
      
        // Upgrades the BaseRemoteRendezvous to full initialization.
        Status Initialize(WorkerSession* session) override;
      
        // Forwards to local_, where the Tensor "val" will be buffered and
        // any waiting callback stored.
        Status Send(const ParsedKey& key, const Rendezvous::Args& args,
                    const Tensor& val, const bool is_dead) override;
      
        // This method is called only by the RecvOp.  It tests to see
        // whether the value will be produced by a local or remote device
        // and handles accordingly.  In the local case it forwards to
        // local_, in the remote case it initiates an RPC request.
        void RecvAsync(const ParsedKey& key, const Rendezvous::Args& args,
                       DoneCallback done) override;
      
        void StartAbort(const Status& status) override;
      
        // This method is called only by the local Worker, forwarded through
        // the same method on RendezvousMgr.  This occurs when the Worker
        // has received a RecvTensor request, either locally or over the
        // network.  In either case it needs to retrieve a locally buffered
        // value from local_, and give it to its caller.
        //
        // Runs "done" as soon as the tensor for "parsed" is available or an error
        // is detected.
        //
        // REQUIRES: "parsed" is one that will be Saved into the local rendezvous.
        void RecvLocalAsync(const ParsedKey& parsed, DoneCallback done);
      
       protected:
        virtual void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                         const Rendezvous::Args& args,
                                         DoneCallback done) = 0;
      
        // Returns true if "src" and "dst" are located in the same worker,
        // and hence may use a local rendezvous.
        virtual bool IsSameWorker(DeviceNameUtils::ParsedName src,
                                  DeviceNameUtils::ParsedName dst);
      
        // If aborted, aborts "call". Otherwise, adds "call" into active_.
        void RegisterCall(BaseRecvTensorCall* call, const Rendezvous::Args& args);
      
        // Removes "call" from active_ if "call" is in active_.
        void DeregisterCall(BaseRecvTensorCall* call);
      
        WorkerSession* session();
      
        bool is_initialized();
      
        ~BaseRemoteRendezvous() override;
      
        const WorkerEnv* const env_;  // Not owned.
        const int64_t step_id_;
      
       private:
        Rendezvous* local_;  // Owns a Ref on this object.
      
        mutable mutex mu_;
      
        // Status given by StartAbort() if any.
        Status status_ TF_GUARDED_BY(mu_);
      
        WorkerSession* session_ TF_GUARDED_BY(mu_);  // Not owned.
      
        // Data structures to handle calls when partially initialized.
        struct DeferredCall {
          const ParsedKey parsed;
          DoneCallback done;
      
          DeferredCall(const ParsedKey& parsed, DoneCallback done);
        };
        std::vector<DeferredCall> deferred_calls_ TF_GUARDED_BY(mu_);
      
        typedef std::function<void()> InactiveCallback;
      
        std::unordered_map<BaseRecvTensorCall*, InactiveCallback> active_
            TF_GUARDED_BY(mu_);
      
        bool is_initialized_locked() TF_SHARED_LOCKS_REQUIRED(mu_) {
          return session_ != nullptr;
        }
      
        // If "is_src" is true, checks that the rendezvous key "parsed"'s
        // source is in this process. If "is_src" is false, checks that the
        // rendezvous key "parsed"'s destination is in this process.
        Status ValidateDevices(const Rendezvous::ParsedKey& parsed, bool is_src);
      
        // Callback handling the case when a rendezvous has been
        // accomplished in local_ and the consumer is local to this process.
        // Tensor "in" will be copied into "out". The key "parsed" encodes
        // the src and dst devices.
        void SameWorkerRecvDone(const Rendezvous::ParsedKey& parsed,
                                const Rendezvous::Args& in_args,
                                const Rendezvous::Args& out_args, const Tensor& in,
                                Tensor* out, StatusCallback done);
      
        // Must be called only if fully initialized.
        void RecvLocalAsyncInternal(const ParsedKey& parsed, DoneCallback done);
      
        TF_DISALLOW_COPY_AND_ASSIGN(BaseRemoteRendezvous);
      };
      
      class BaseRecvTensorCall {
       public:
        BaseRecvTensorCall() {}
        virtual ~BaseRecvTensorCall() {}
        virtual void Start(std::function<void()> recv_done) = 0;
        virtual void StartAbort(const Status& s) = 0;
        virtual Status status() const = 0;
       private:
        TF_DISALLOW_COPY_AND_ASSIGN(BaseRecvTensorCall);
      };
      

      A local Rendezvous is built at construction time; this local rendezvous does the basic buffering work.

      BaseRemoteRendezvous::BaseRemoteRendezvous(const WorkerEnv* env,
                                                 int64_t step_id)
          : env_(env),
            step_id_(step_id),
            local_(NewLocalRendezvous()),
            session_(nullptr) {}
      
      Rendezvous* NewLocalRendezvous() { return new LocalRendezvousWrapper; }
      

      LocalRendezvousWrapper is defined as follows:

      class LocalRendezvousWrapper : public Rendezvous {
       public:
        LocalRendezvousWrapper() : impl_(this) {}
      
        Status Send(const ParsedKey& key, const Args& send_args, const Tensor& val,
                    const bool is_dead) override {
          return impl_.Send(key, send_args, val, is_dead);
        }
      
        void RecvAsync(const ParsedKey& key, const Args& recv_args,
                       DoneCallback done) override {
          impl_.RecvAsync(key, recv_args, std::move(done));
        }
      
        void StartAbort(const Status& status) override { impl_.StartAbort(status); }
      
       private:
        LocalRendezvous impl_;
      
        TF_DISALLOW_COPY_AND_ASSIGN(LocalRendezvousWrapper);
      };
      

      Next let's look at the BaseRemoteRendezvous initialization method, which performs basic configuration such as setting the session.

      Status BaseRemoteRendezvous::Initialize(WorkerSession* session) {
        std::vector<DeferredCall> deferred_calls;
        {
          mutex_lock l(mu_);
          if (session_ != nullptr) {
            if (session_->worker_name() == session->worker_name()) {
              return Status::OK();
            }
            Status s = errors::Internal(
                "Double init! Worker names would have changed from: ",
                session_->worker_name(), " -> ", session->worker_name());
            return s;
          }
          session_ = session;
          std::swap(deferred_calls, deferred_calls_);
        }
        for (auto& call : deferred_calls) {
          RecvLocalAsyncInternal(call.parsed, std::move(call.done));
        }
        return Status::OK();
      }
      

      1.2.5 RpcRemoteRendezvous

      RpcRemoteRendezvous is the gRPC implementation of RemoteRendezvous.

      class RpcRemoteRendezvous : public BaseRemoteRendezvous {
       public:
        RpcRemoteRendezvous(const WorkerEnv* env, int64_t step_id)
            : BaseRemoteRendezvous(env, step_id) {}
      
       protected:
        void RecvFromRemoteAsync(const Rendezvous::ParsedKey& parsed,
                                 const Rendezvous::Args& args,
                                 DoneCallback done) override;
      
       private:
        ~RpcRemoteRendezvous() override {}
      
        TF_DISALLOW_COPY_AND_ASSIGN(RpcRemoteRendezvous);
      };
      

      The derived class corresponding to BaseRecvTensorCall is RpcRecvTensorCall.

      // Used only to retrieve tensors from remote processes.
      class RpcRecvTensorCall : public BaseRecvTensorCall {
       public:
        RpcRecvTensorCall() : wi_(nullptr), dst_device_(nullptr) {}
      
        void Init(WorkerInterface* wi, int64_t step_id, StringPiece key,
                  AllocatorAttributes alloc_attrs, Device* dst_device,
                  const Rendezvous::Args& recv_args, Rendezvous::DoneCallback done) {
          wi_ = wi;
          alloc_attrs_ = alloc_attrs;
          dst_device_ = dst_device;
          recv_args_ = recv_args;
          done_ = std::move(done);
          req_.set_step_id(step_id);
          req_.set_rendezvous_key(key.data(), key.size());
          req_.set_request_id(GetUniqueRequestId());
        }
      
        void Reset() {
          // The RpcRemoteRendezvous using this object is responsible for calling
          // ReleaseWorker() before Reset().
      
          alloc_attrs_ = AllocatorAttributes();
          dst_device_ = nullptr;
          // We don't clear opts_ and assume that Init will set up the state for
          // opts_ appropriately.
          req_.Clear();
          resp_.Clear();
          {
            mutex_lock l(mu_);
            status_ = Status::OK();
          }
          done_ = nullptr;
        }
      
        ~RpcRecvTensorCall() override {
          // Since only the RpcRecvTensorFreeList will delete an
          // RpcRecvTensorCall, we require that ReleaseWorker() has been called before
          // the user releases a Call object to the free list.
          CHECK_EQ(static_cast<WorkerInterface*>(nullptr), wi_)
              << "Leaking WorkerInterface in RpcRecvTensorCall destructor.";
        }
      
        void Start(std::function<void()> recv_done) override {
          StartRTCall(std::move(recv_done));
        }
      
        void StartAbort(const Status& s) override {
          {
            mutex_lock l(mu_);
            status_.Update(s);
          }
          opts_.StartCancel();
        }
      
        Status status() const override {
          mutex_lock l(mu_);
          return status_;
        }
      
        void ReleaseWorker(WorkerCacheInterface* worker_cache) {
          DCHECK_NE(static_cast<WorkerInterface*>(nullptr), wi_)
              << "RpcRecvTensorCall::ReleaseWorker() called twice.";
          worker_cache->ReleaseWorker(src_worker_, wi_);
          wi_ = nullptr;
        }
      
        const Tensor& tensor() const { return resp_.tensor(); }
      
        bool is_dead() const { return resp_.metadata().is_dead(); }
      
        Device* dst_device() const { return dst_device_; }
        const Rendezvous::Args& recv_args() const { return recv_args_; }
        const Rendezvous::DoneCallback& done() const { return done_; }
      
       private:
        friend class RpcRemoteRendezvous;
      
        // Start the main RecvTensor call, checking for an async abort.
        void StartRTCall(std::function<void()> recv_done) {
          resp_.InitAlloc(dst_device_, alloc_attrs_);
          auto abort_checked = std::make_shared<Notification>();
          auto cb = [this, abort_checked,
                     recv_done = std::move(recv_done)](const Status& s) {
            // Make sure the Rendezvous abort checking is finished before running the
            // callback, which might destroy the current call object.
            abort_checked->WaitForNotification();
            if (!s.ok()) {
              mutex_lock l(mu_);
              status_.Update(s);
            }
            recv_done();
          };
          wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
      
          // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
          // ordering is important because StartAbort could be called right before
          // the RecvTensorAsync request registers its RPC cancellation to opts_.
          // In that case, the previous StartAbort would not trigger the
          // cancellation of this call.
          Status s;
          {
            mutex_lock l(mu_);
            s = status_;
          }
          if (!s.ok()) {
            opts_.StartCancel();
          }
          // Notify that the abort check has finished.
          abort_checked->Notify();
        }
      
        string src_worker_;
        string src_rel_device_;
        WorkerInterface* wi_;  // Not owned.
        AllocatorAttributes alloc_attrs_;
        Device* dst_device_;
        CallOptions opts_;
        RecvTensorRequest req_;
        TensorResponse resp_;
        Rendezvous::Args recv_args_;
        Rendezvous::DoneCallback done_;
      
        mutable mutex mu_;
        Status status_ TF_GUARDED_BY(mu_);
      
        TF_DISALLOW_COPY_AND_ASSIGN(RpcRecvTensorCall);
      };
      

      The logical relationships so far are as follows:

      Figure 2: Rendezvous class relationships

      1.3 Manager classes

      RendezvousMgr is mainly responsible for creating and destroying RemoteRendezvous instances. It keeps track of a set of local rendezvous instances: all tensors sent by this worker are buffered in the RendezvousMgr until they are received. Each globally unique "step_id" corresponds to one local rendezvous instance managed by the RendezvousMgr.
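      As a sketch of that per-step lifecycle (the call sites are hypothetical; the real ones appear in the Worker and GraphMgr code below):

      // One rendezvous per step_id: look it up, use it for the step, then release it.
      RemoteRendezvous* rendez = worker_env->rendezvous_mgr->Find(step_id);
      // ... Initialize(), then Send()/RecvAsync() while the step runs ...
      rendez->Unref();                               // drop the caller's reference
      worker_env->rendezvous_mgr->Cleanup(step_id);  // remove the table entry for this step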

      1.3.1 Interface

      RendezvousMgrInterface is the interface class.

      // RendezvousMgr keeps track of a set of local rendezvous instances.
      // All tensors sent by this worker are buffered in a RendezvousMgr
      // until the tensor is received.  Each global unique "step_id"
      // corresponds to one local rendezvous instance managed by a
      // RendezvousMgr.
      //
      // E.g.,
      //   Rendezvous* rendez = worker_env->rendezvous_mgr->Find(0x8935);
      //   fork execution of an graph executor using "rendez"  on thread 1;
      //   fork execution of another graph executor using "rendez" on thread 2;
      //   ...
      //   join threads 1 and 2;
      //
      // In the example above, execution in thread 1 and 2 communicates with
      // each other by send/recv operations through the "rend".
      //
      // Tensors sent and recved through rendezvous managed by this
      // RendezvousMgr must have keys generated by Rendezvous::CreateKey.
      class RendezvousMgrInterface {
       public:
        RendezvousMgrInterface() {}
        virtual ~RendezvousMgrInterface() {}
      
        // Returns Rendezvous supporting send and recv among workers in the
        // "step_id".  The caller takes ownership of one reference on the
        // returned Rendezvous instance.
        //
        // Note: the caller must guarantee to eventually call Initialize on the
        // returned RemoteRendezvous
        virtual RemoteRendezvous* Find(int64_t step_id) = 0;
      
        // Finds the local rendezvous instance for the "step_id".  Runs
        // "done" when the tensor for "key" is produced or an error occurs.
        //
        // This method is used by the rpc handler of RecvTensor.
        virtual void RecvLocalAsync(int64_t step_id,
                                    const Rendezvous::ParsedKey& parsed,
                                    Rendezvous::DoneCallback done) = 0;
      
        // Synchronous wrapper for RecvLocalAsync.
        virtual Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                                 Tensor* val, bool* is_dead) = 0;
      
        // Removes rendezvous for "step_id".
        //
        // TODO(zhifengc): Have a background thread in worker that
        // periodically calls CleanupAll().
        virtual void Cleanup(int64_t step_id) = 0;
      };
      

      1.3.2 BaseRendezvousMgr

      BaseRendezvousMgr implements the basic functionality, such as looking up a Rendezvous by step_id.

      class BaseRendezvousMgr : public RendezvousMgrInterface {
       public:
        explicit BaseRendezvousMgr(const WorkerEnv* worker_env);
      
        ~BaseRendezvousMgr() override;
      
        // Returns Rendezvous supporting send and recv among workers in the
        // "step_id".  The caller takes ownership of one reference on the
        // returned Rendezvous instance.
        //
        // Note: the caller must guarantee to eventually call Initialize on the
        // returned RemoteRendezvous
        RemoteRendezvous* Find(int64_t step_id) override;
      
        // Finds the local rendezvous instance for the "step_id".  Runs
        // "done" when the tensor for "key" is produced or an error occurs.
        //
        // This method is used by the rpc handler of RecvTensor.
        void RecvLocalAsync(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                            Rendezvous::DoneCallback done) override;
      
        // Synchronous wrapper for RecvLocalAsync.
        Status RecvLocal(int64_t step_id, const Rendezvous::ParsedKey& parsed,
                         Tensor* val, bool* is_dead) override;
      
        // Removes rendezvous for "step_id".
        void Cleanup(int64_t step_id) override;
      
       protected:
        virtual BaseRemoteRendezvous* Create(int64_t step_id,
                                             const WorkerEnv* worker_env) = 0;
      
       private:
        // Maps step_id to rendezvous.
        typedef absl::flat_hash_map<int64_t, BaseRemoteRendezvous*> Table;
      
        // Not owned.
        const WorkerEnv* const worker_env_;
      
        mutex mu_;
        Table table_ TF_GUARDED_BY(mu_);
      
        BaseRemoteRendezvous* FindOrCreate(int64_t step_id);
      
        TF_DISALLOW_COPY_AND_ASSIGN(BaseRendezvousMgr);
      };
      

      2. Usage

      We already saw some uses of Rendezvous when we walked through computation execution earlier; now let's pick a few scenarios and analyze them.

      2.1 Worker receive

      Let's first look at the worker on the receiving side.

      2.1.1 DoRunGraph

      The Worker receives tensors inside its DoRunGraph method.

      void Worker::DoRunGraph(CallOptions* opts, RunGraphRequestWrapper* request,
                              MutableRunGraphResponseWrapper* response,
                              StatusCallback done) {
      
        session->graph_mgr()->ExecuteAsync(
            request->graph_handle(), step_id, session.get(), request->exec_opts(),
            collector, response, cm, in,
            [this, step_id, response, session, cm, out, token, collector,
             device_profiler_session, opts, done](const Status& status) {
              Status s = status;
              if (s.ok()) {
                // Receive the output tensors
                s = session->graph_mgr()->RecvOutputs(step_id, out);
              }
            });
      }
      

      The RecvOutputs method is shown below: it looks up a Rendezvous by step_id and then receives the message.

      Status GraphMgr::RecvOutputs(const int64_t step_id, NamedTensors* out) {
        Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
        Status s = RecvOutputsFromRendezvous(rendezvous, out, Rendezvous::Args());
        rendezvous->Unref();
        size_t output_size = 0;
        for (auto& p : *out) {
          output_size += p.second.AllocatedBytes();
        }
        return s;
      }
      

      The flow is shown in the figure below; the steps follow the numbers in the figure. Step 3 returns a Rendezvous, and RecvOutputsFromRendezvous is a global function.

      2.1.2 DoPartialRunGraph

      DoPartialRunGraph calls RecvOutputsAsync to do the receiving.

      void Worker::DoPartialRunGraph(CallOptions* opts,
                                     RunGraphRequestWrapper* request,
                                     MutableRunGraphResponseWrapper* response,
                                     StatusCallback done) {
        const int64_t step_id = request->step_id();
        const string& graph_handle = request->graph_handle();
      
        Status s = recent_request_ids_.TrackUnique(
            request->request_id(), "PartialRunGraph (Worker)", request);
      
        std::shared_ptr<WorkerSession> session;
        if (request->create_worker_session_called()) {
          s = env_->session_mgr->WorkerSessionForSession(request->session_handle(),
                                                         &session);
        } else {
          session = env_->session_mgr->LegacySession();
        }
      
        GraphMgr::NamedTensors in;
        GraphMgr::NamedTensors* out = new GraphMgr::NamedTensors;
        s = PrepareRunGraph(request, &in, out);
        auto finish = [done, out, opts](const Status& s) {
          opts->ClearCancelCallback();
          delete out;
          done(s);
        };
      
        CancellationManager* cm = nullptr;
        bool is_new_partial_run = partial_run_mgr_.FindOrCreate(step_id, &cm);
      
        // Before we start doing anything, we set the RPC cancellation.
        opts->SetCancelCallback([this, cm, step_id]() {
          cm->StartCancel();
          AbortStep(step_id);
        });
      
        // If this is a new partial run request, the request will need to start the
        // executors.
        if (is_new_partial_run) {
          CancellationToken token;
          token = cancellation_manager_.get_cancellation_token();
          cancellation_manager_.RegisterCallback(token,
                                                 [cm]() { cm->StartCancel(); });
          session->graph_mgr()->ExecuteAsync(
              graph_handle, step_id, session.get(), request->exec_opts(),
              nullptr /* collector */, nullptr /* response */, cm, in,
              [this, token, step_id, session](Status s) {
                cancellation_manager_.DeregisterCallback(token);
                partial_run_mgr_.ExecutorDone(step_id, s);
              });
        } else {
          // Send the partial run's new inputs.
          s = session->graph_mgr()->SendInputs(step_id, in);
        }
      
        // This calls RecvOutputsAsync to receive the tensors
        session->graph_mgr()->RecvOutputsAsync(
            step_id, out, [this, out, request, response, step_id, finish](Status s) {
              if (s.ok()) {
                // Construct and return the resp.
                for (const auto& p : *out) {
                  const string& key = p.first;
                  const Tensor& val = p.second;
                  response->AddRecv(key, val);
                }
              }
              if (request->is_last_partial_run()) {
                partial_run_mgr_.PartialRunDone(step_id, finish, s);
              } else {
                finish(s);
              }
            });
      }
      

      RecvOutputsAsync in turn calls RecvOutputsFromRendezvousAsync.

      void GraphMgr::RecvOutputsAsync(const int64_t step_id, NamedTensors* out,
                                      StatusCallback done) {
        Rendezvous* rendezvous = worker_env_->rendezvous_mgr->Find(step_id);
        std::vector<string> keys;
        std::vector<Tensor>* received_keys = new std::vector<Tensor>;
        keys.reserve(out->size());
        received_keys->reserve(out->size());
        for (const auto& p : *out) {
          keys.push_back(p.first);
          received_keys->push_back(p.second);
        }
        RecvOutputsFromRendezvousAsync(
            rendezvous, nullptr, {}, keys, received_keys,
            [done, rendezvous, received_keys, out, keys](const Status s) {
              rendezvous->Unref();
              size_t output_size = 0;
              for (int i = 0, end = keys.size(); i < end; ++i) {
                (*out)[keys[i]] = (*received_keys)[i];
                output_size += (*out)[keys[i]].AllocatedBytes();
              }
              metrics::RecordGraphOutputTensors(output_size);
              delete received_keys;
              done(s);
            });
      }
      

      The flow is shown in the figure below; the steps follow the numbers in the figure. Step 3 returns a Rendezvous, and RecvOutputsFromRendezvousAsync is a global function.

      2.2 GraphMgr send

      Tensors are sent inside ExecuteAsync.

      void GraphMgr::ExecuteAsync(const string& handle, const int64_t step_id,
                                  WorkerSession* session, const ExecutorOpts& opts,
                                  StepStatsCollector* collector,
                                  MutableRunGraphResponseWrapper* response,
                                  CancellationManager* cancellation_manager,
                                  const NamedTensors& in, StatusCallback done) {
      
        if (s.ok()) {
          // Send the tensors
          s = SendTensorsToRendezvous(rendezvous, nullptr, {}, keys, tensors_to_send);
        }
      
        // Execute the partitioned subgraph
        StartParallelExecutors(
            handle, step_id, item, rendezvous, ce_handle, collector, cost_graph,
            cancellation_manager, session, start_time_usecs,
            [item, rendezvous, ce_handle, done, start_time_usecs, input_size,
             step_id](const Status& s) {
            });
      }
      

      SendTensorsToRendezvous is as follows:

      Status SendTensorsToRendezvous(
          RendezvousInterface* rendezvous, DeviceContext* device_context,
          const std::vector<AllocatorAttributes>& alloc_attrs,
          const std::vector<string>& keys, gtl::ArraySlice<Tensor> tensors_to_send) {
      
        Rendezvous::ParsedKey parsed;
        for (int i = 0; i < keys.size(); ++i) {
          Rendezvous::Args rendez_args;
          rendez_args.device_context = device_context;
          if (!alloc_attrs.empty()) {
            rendez_args.alloc_attrs = alloc_attrs[i];
          }
          TF_RETURN_IF_ERROR(Rendezvous::ParseKey(keys[i], &parsed));
          TF_RETURN_IF_ERROR(
              rendezvous->Send(parsed, rendez_args, tensors_to_send[i], false));
        }
        return Status::OK();
      }
      

      Next let's analyze receiving and sending in detail.

      3. Sending

      Let's look at the send flow first. The Send path does not involve any cross-process transfer, so it is the same as Send in the purely local case: the tensor is simply put into the worker's local table. No network transfer is involved, and the call is non-blocking.

      3.1 BaseRemoteRendezvous

      The Send method delegates the work to local_->Send.

      Status BaseRemoteRendezvous::Send(const Rendezvous::ParsedKey& parsed,
                                        const Rendezvous::Args& args,
                                        const Tensor& val, const bool is_dead) {
      
        WorkerSession* sess = nullptr;
        {
          tf_shared_lock l(mu_);
          if (!status_.ok()) return status_;
          sess = session_;
        }
      
        if (!IsLocalDevice(sess->worker_name(), parsed.src_device)) {
          return errors::InvalidArgument(
              "Invalid rendezvous key (src): ", parsed.FullKey(), " @ ",
              sess->worker_name());
        }
      
        // Buffers "val" and "device_context" in local_.
        return local_->Send(parsed, args, val, is_dead);
      }
      

      3.2 LocalRendezvous

      LocalRendezvous::Send inserts the tensor into the local table.

      Status LocalRendezvous::Send(const Rendezvous::ParsedKey& key,
                                   const Rendezvous::Args& send_args,
                                   const Tensor& val, const bool is_dead) {
        uint64 key_hash = KeyHash(key.FullKey());
      
        if (is_dead) {
          static auto* rendezvous_dead_values_sent = monitoring::Counter<2>::New(
              "/tensorflow/core/rendezvous_dead_values_sent",
              "The number of dead values sent between a pair of devices.",
              "send_device", "recv_device");
          rendezvous_dead_values_sent
              ->GetCell(string(key.src_device), string(key.dst_device))
              ->IncrementBy(1);
        }
      
        mu_.lock();
        if (!status_.ok()) {
          // Rendezvous has been aborted.
          Status s = status_;
          mu_.unlock();
          return s;
        }
      
        ItemQueue* queue = &table_[key_hash];
        if (queue->head == nullptr || queue->head->type == Item::kSend) {
          // There is no waiter for this message. Append the message
          // into the queue. The waiter will pick it up when arrives.
          // Only send-related fields need to be filled.
          queue->push_back(new Item(send_args, val, is_dead));
          mu_.unlock();
          return Status::OK();
        }
      
        // There is an earliest waiter to consume this message.
        Item* item = queue->head;
      
        // Delete the queue when the last element has been consumed.
        if (item->next == nullptr) {
          table_.erase(key_hash);
        } else {
          queue->head = item->next;
        }
        mu_.unlock();
      
        // Notify the waiter by invoking its done closure, outside the
        // lock.
        DCHECK_EQ(item->type, Item::kRecv);
        (*item->recv_state.waiter)(Status::OK(), send_args, item->args, val, is_dead);
        delete item;
        return Status::OK();
      }
      

      The logic so far is shown below; here Worker 0 refers to a worker role, not the Worker class.

      Figure 3: Send logic

      4. Receiving

      The sender has now put the prepared tensor into its local table. The receiver needs to fetch the tensor from the sender's table, which is where cross-process transfer comes in. The receive process is:

      • The Recv side is the client: it assembles the ParsedKey of the tensor it needs and sends a request to the Send side, carrying the ParsedKey in the request.
      • The Send side is the server: upon receiving the request it immediately looks up the tensor the client needs in its local table, wraps it in a response, and sends it back to the Recv side.

      The key point here is that the data transfer is initiated by the recv side, which actively sends a request to the Send side to trigger the communication. This differs from the pattern we usually see. The Worker offers both synchronous and asynchronous calls; we follow the asynchronous path. The overall send/receive flow is given up front so you have the big picture; in the figure below, the dashed lines indicate returned tensors.

      Figure 4: Overall send/receive flow

      4.1 Client

      The client-side logic is as follows.

      4.1.1 RecvOutputsFromRendezvousAsync

      The global function RecvOutputsFromRendezvousAsync ends up calling rendezvous->RecvAsync.

      void RecvOutputsFromRendezvousAsync(
          RendezvousInterface* rendezvous, DeviceContext* device_context,
          const std::vector<AllocatorAttributes>& alloc_attrs,
          const std::vector<string>& keys, std::vector<Tensor>* received_tensors,
          StatusCallback done) {
        if (keys.empty()) {
          done(Status::OK());
          return;
        }
      
        received_tensors->reserve(keys.size());
        std::vector<
            std::tuple<string, Tensor*, Rendezvous::ParsedKey, AllocatorAttributes>>
            arguments;
        for (int i = 0; i < keys.size(); ++i) {
          Rendezvous::ParsedKey parsed;
          Status s = Rendezvous::ParseKey(keys[i], &parsed);
          received_tensors->push_back(Tensor());
          if (!s.ok()) {
            done(s);
            return;
          }
          AllocatorAttributes alloc_attr;
          if (!alloc_attrs.empty()) {
            alloc_attr = alloc_attrs[i];
          }
          arguments.emplace_back(keys[i], &((*received_tensors)[i]), parsed,
                                 alloc_attr);
        }
      
        auto status_cb = new ReffedStatusCallback(std::move(done));
        for (auto& p : arguments) {
          const string& key = std::get<0>(p);
          Tensor* val = std::get<1>(p);
          Rendezvous::ParsedKey parsed = std::get<2>(p);
          Rendezvous::Args rendez_args;
          rendez_args.device_context = device_context;
          rendez_args.alloc_attrs = std::get<3>(p);
          status_cb->Ref();
          rendezvous->RecvAsync(
              parsed, rendez_args,
              [val, key, status_cb](const Status& s,
                                    const Rendezvous::Args& send_args,
                                    const Rendezvous::Args& recv_args,
                                    const Tensor& v, const bool is_dead) {
                Status status = s;
                if (status.ok()) {
                  *val = v;
                  if (is_dead) {
                    status = errors::InvalidArgument("The tensor returned for ", key,
                                                     " was not valid.");
                  }
                }
                status_cb->UpdateStatus(status);
                status_cb->Unref();
              });
        }
        status_cb->Unref();
      }
      

      4.1.2 BaseRemoteRendezvous

      Because the two sides are not in the same process, RecvFromRemoteAsync is called.

      void BaseRemoteRendezvous::RecvAsync(const ParsedKey& parsed,
                                           const Rendezvous::Args& recv_args,
                                           DoneCallback done) {
        Status s = ValidateDevices(parsed, false /*!is_src*/);
      
        profiler::ScopedMemoryDebugAnnotation op_annotation("RecvAsync", step_id_);
        // Are src and dst in the same worker?
        if (IsSameWorker(parsed.src, parsed.dst)) {  // src and dst are on the same worker
          // Recv the tensor from local_.
          local_->RecvAsync(
              parsed, recv_args,
              [this, parsed, done](
                  const Status& status, const Rendezvous::Args& send_args,
                  const Rendezvous::Args& recv_args, const Tensor& in, bool is_dead) {
      
                Tensor* out = new Tensor;
                StatusCallback final_callback = [done, send_args, recv_args, out,
                                                 is_dead](const Status& s) {
                  done(s, send_args, recv_args, *out, is_dead);
                  delete out;
                };
      
                if (status.ok()) {
                  SameWorkerRecvDone(parsed, send_args, recv_args, in, out,
                                     std::move(final_callback));
                } else {
                  final_callback(status);
                }
              });
          return;
        } else {  // src and dst are on different workers
          RecvFromRemoteAsync(parsed, recv_args, std::move(done));
        }
      }
      

      4.1.3 RpcRemoteRendezvous

      RpcRemoteRendezvous checks the parameters, prepares an RpcRecvTensorCall, and then starts it with call->Start(); Start() in turn calls StartRTCall(). RpcRecvTensorCall derives from the abstract base class BaseRecvTensorCall; it is the abstraction of one gRPC call and wraps the rather involved downstream call chain. The key point is the following two statements, which show how the corresponding worker is used to set up the RpcRecvTensorCall:

      WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
      
      call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
                   recv_args, std::move(done));
      

      The complete code is as follows:

      void RpcRemoteRendezvous::RecvFromRemoteAsync(
          const Rendezvous::ParsedKey& parsed, const Rendezvous::Args& recv_args,
          DoneCallback done) {
        CHECK(is_initialized());
        Status s;
      
        // Prepare a RecvTensor call that can handle being aborted.
        // Create a call object
        RpcRecvTensorCall* call = get_call_freelist()->New();
      
        // key.src_device identifies a remote device.
        if (!DeviceNameUtils::SplitDeviceName(parsed.src_device, &call->src_worker_,
                                              &call->src_rel_device_)) {
          s = errors::Internal(parsed.src_device,
                               " is invalid remote source device.");
        }
        WorkerSession* sess = session();
        std::shared_ptr<WorkerCacheInterface> worker_cache =
            sess->GetSharedWorkerCache();
        // The worker will be released in a subsequent call to
        // sess->worker_cache()->ReleaseWorker() (if the call has not yet been
        // initialized) or call->ReleaseWorker() (if it has been initialized).
        
        // Get the corresponding worker
        WorkerInterface* rwi = worker_cache->GetOrCreateWorker(call->src_worker_);
      
        Device* dst_device;
        if (s.ok()) {
          s = sess->device_mgr()->LookupDevice(parsed.dst_device, &dst_device);
        }
        if (!s.ok()) {
          if (rwi != nullptr) {
            sess->worker_cache()->ReleaseWorker(call->src_worker_, rwi);
          }
          get_call_freelist()->Release(call);
          done(s, Args(), recv_args, Tensor{}, false);
          return;
        }
      
        // Initialize the call with the worker
        call->Init(rwi, step_id_, parsed.FullKey(), recv_args.alloc_attrs, dst_device,
                   recv_args, std::move(done));
      
        // Record "call" in active_ so that it can be aborted cleanly.
        RegisterCall(call, recv_args);
      
        // Start "call".
        Ref();
        call->Start([this, call, worker_cache]() {
          // Removes "call" from active_. Prevent StartAbort().
          DeregisterCall(call);
          // If StartAbort was called prior to DeregisterCall, then the
          // current status should be bad.
          Status s = call->status();
          // NOTE: *session() can potentially be deleted before we return from
          // call->done()(...), so we must release the worker before calling the
          // callback.
          call->ReleaseWorker(session()->worker_cache());
          call->done()(s, Args(), call->recv_args(), call->tensor(), call->is_dead());
          get_call_freelist()->Release(call);
          Unref();
        });
      }
      

      4.1.4 RpcRecvTensorCall

      RpcRecvTensorCall's Start method is shown below; it simply forwards to StartRTCall.

      void RpcRecvTensorCall::Start(std::function<void()> recv_done) override {
        StartRTCall(std::move(recv_done));
      }
      

      Inside RpcRecvTensorCall::StartRTCall, the worker's RecvTensorAsync is called to perform the transfer, which is in fact GrpcRemoteWorker::RecvTensorAsync.

      // Start the main RecvTensor call, checking for an async abort.
      void RpcRecvTensorCall::StartRTCall(std::function<void()> recv_done) {
        resp_.InitAlloc(dst_device_, alloc_attrs_);
        auto abort_checked = std::make_shared<Notification>();
        auto cb = [this, abort_checked,
                   recv_done = std::move(recv_done)](const Status& s) {
          // Make sure the Rendezvous abort checking is finished before running the
          // callback, which might destroy the current call object.
          abort_checked->WaitForNotification();
          if (!s.ok()) {
            mutex_lock l(mu_);
            status_.Update(s);
          }
          recv_done();
        };
        wi_->RecvTensorAsync(&opts_, &req_, &resp_, std::move(cb));
      
        // NOTE: Check if the rendezvous was aborted after sending out the RPC. The
        // ordering is important because StartAbort could be called right before
        // the RecvTensorAsync request registers its RPC cancellation to opts_.
        // In that case, the previous StartAbort would not trigger the
        // cancellation of this call.
        Status s;
        {
          mutex_lock l(mu_);
          s = status_;
        }
        if (!s.ok()) {
          opts_.StartCancel();
        }
        // Notify that the abort check has finished.
        abort_checked->Notify();
      }
      

      4.1.5 GrpcRemoteWorker

      An abridged version of the RecvTensorAsync method is shown below; at this point we are back in the familiar Worker flow.

      void GrpcRemoteWorker::RecvTensorAsync(CallOptions* call_opts, const RecvTensorRequest* request, TensorResponse* response, StatusCallback done) override {
        IssueRequest(request, response, recvtensor_, callback, call_opts);
      }
      

      So far we have covered the right half of the figure below, as marked by the circles in the figure.

      4.2 Server

      Now we come to the server side, which is actually the tensor sender. The logic after a RecvTensorRequest is received is as follows.

      4.2.1 GrpcWorkerService

      GrpcWorkerServiceThread::HandleRPCsLoop contains a for loop that enqueues 1000 handlers and arranges for GrpcWorkerMethod::kRecvTensor to be handled via EnqueueRecvTensorRequestRaw(). The handlers are registered ahead of time to speed up processing; furthermore, after one message has been handled, EnqueueRecvTensorRequestRaw calls EnqueueRequestForMethod to enqueue another handler.

      void GrpcWorkerServiceThread::HandleRPCsLoop() {
        // TODO(ncteisen): This may require performance engineering. We can
        // change the number of threads, the number of handlers per thread,
        // or even decide to specialize certain threads to certain methods.
        SETUP_FOR_REQUEST(GetStatus, 1, false);
        SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
        SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
        SETUP_FOR_REQUEST(CleanupAll, 1, false);
        SETUP_FOR_REQUEST(RegisterGraph, 1, false);
        SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
        SETUP_FOR_REQUEST(Logging, 1, false);
        SETUP_FOR_REQUEST(Tracing, 1, false);
        SETUP_FOR_REQUEST(CompleteGroup, 10, true);
        SETUP_FOR_REQUEST(CompleteInstance, 10, true);
        SETUP_FOR_REQUEST(GetStepSequence, 10, true);
        SETUP_FOR_REQUEST(RecvBuf, 500, true);
        SETUP_FOR_REQUEST(RunGraph, 100, true);
        SETUP_FOR_REQUEST(CleanupGraph, 100, false);
        SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
      
        // TODO(ncteisen): Determine a better policy for enqueuing the
        // appropriate number of each request type.
        for (int i = 0;
             i < gtl::FindWithDefault(
                     queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                     1000);
             ++i) {
          EnqueueRecvTensorRequestRaw();  // pre-register the handler
        }
      
        void* tag;
        bool ok;
      
        while (cq_->Next(&tag, &ok)) {
          UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
              static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
          CHECK(callback_tag);
          callback_tag->OnCompleted(this, ok);
        }
      }
      

      Here another handler is enqueued, and GrpcWorkerServiceThread::RecvTensorHandlerRaw is configured to continue handling GrpcWorkerMethod::kRecvTensor.

      void EnqueueRecvTensorRequestRaw() {
        mutex_lock l(shutdown_mu_);
        if (!is_shutdown_) {
          Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
               RecvTensorRequest, ::grpc::ByteBuffer>::
              EnqueueRequestForMethod(
                  worker_service_, cq_.get(),
                  static_cast<int>(GrpcWorkerMethod::kRecvTensor),
                  &GrpcWorkerServiceThread::RecvTensorHandlerRaw,
                  true /* supports cancel*/);
        }
      }
      

      4.2.2 GrpcWorkerServiceThread

      GrpcWorkerServiceThread is the thread class that handles requests on the server side. Here it simply calls GrpcWorker to continue the processing, passing a WorkerCall as the argument. WorkerCall is the class (actually an alias) that represents one gRPC request/response handled by the server.

      using WorkerCall =
          Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService,
               RequestMessage, ResponseMessage>;
      

      The code is as follows:

      void GrpcWorkerServiceThread::RecvTensorHandlerRaw(
          WorkerCall<RecvTensorRequest, ::grpc::ByteBuffer>* call) {
        Schedule([this, call]() {
          CallOptions* call_opts = new CallOptions;
          call->SetCancelCallback([call_opts]() { call_opts->StartCancel(); });
      
          worker_->GrpcRecvTensorAsync(
              call_opts, &call->request, &call->response,
              [call, call_opts](const Status& s) {
                call->ClearCancelCallback();
                delete call_opts;
                if (!s.ok()) {
                  VLOG(3) << "Bad response from RecvTensor:" << s;
                }
                call->SendResponse(ToGrpcStatus(s));
              });
        });
        EnqueueRecvTensorRequestRaw();
      }
      

      4.2.3 GrpcWorker

      GrpcWorker is the worker that actually handles the request logic; it is the server-side counterpart of GrpcRemoteWorker. The logic of GrpcWorker::GrpcRecvTensorAsync is:

      • Obtain the rendezvous and use rendezvous_mgr->RecvLocalAsync to look up the tensor the client needs in the local table.
      • Call grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response) to encode the tensor.
      • In the callback, call CopyDeviceToHost to copy the tensor from GPU to CPU.
      • Finally, send it back to the client over gRPC.
      // GrpcRecvTensorAsync: unlike the other Worker methods, which use protocol
      // buffers for a response object, to avoid extra protocol buffer serialization
      // overhead we generate our response directly into a ::grpc::ByteBuffer object
      void GrpcWorker::GrpcRecvTensorAsync(CallOptions* opts,
                                           const RecvTensorRequest* request,
                                           ::grpc::ByteBuffer* response,
                                           StatusCallback done) {
      
        const int64_t request_id = request->request_id();
        const int64_t step_id = request->step_id();
      
        bool cache_enabled = (response_cache_ != nullptr && request_id != 0);
      
        auto do_response = [response, done, cache_enabled](const Tensor& tensor,
                                                           bool is_dead,
                                                           const Status& status) {
          if (status.ok()) {
            grpc::EncodeTensorToByteBuffer(is_dead, tensor, cache_enabled, response);
          }
          done(status);
        };
      
        // If response cache is enabled and the response cache already contains the
        // request, we delegate this retry request to the response cache. Otherwise,
        // we add the request to the response cache and start the computation to
        // retrieve the requested data.
        if (cache_enabled &&
            response_cache_->QueueRequest(request_id, step_id, do_response)) {
          return;
        }
      
        auto rendezvous_done = [this, request_id, do_response, cache_enabled](
                                   const Tensor& tensor, bool is_dead,
                                   const Status& status) {
          if (cache_enabled) {
            // Data is ready. Process all pending requests in the response cache.
            response_cache_->OnRequestFinished(request_id, tensor, is_dead, status);
          } else {
            do_response(tensor, is_dead, status);
          }
        };
      
        auto fail = [&rendezvous_done](const Status& status) {
          rendezvous_done(Tensor(), false, status);
        };
      
        Status s = recent_request_ids_.TrackUnique(
            request_id, "RecvTensor (GrpcWorker)", *request);
      
        const string& key = request->rendezvous_key();
        Rendezvous::ParsedKey parsed;
        s = Rendezvous::ParseKey(key, &parsed);
        Device* src_dev = nullptr;
        if (s.ok()) {
          s = PrepareRecvTensor(parsed, &src_dev);
        }
      
        // Request the tensor associated with the rendezvous key.
        // Note that we log the cancellation here but do not abort the current step.
        // gRPC can generate cancellations in response to transient network failures,
        // and aborting the step eliminates the opportunity for client side retries.
        // Repeated client failures will eventually cause the step to be aborted by
        // the client.
        opts->SetCancelCallback(
            [step_id]() { LOG(WARNING) << "RecvTensor cancelled for " << step_id; });
        env_->rendezvous_mgr->RecvLocalAsync(
            step_id, parsed,
            [opts, rendezvous_done, src_dev, request](
                const Status& status, const Rendezvous::Args& send_args,
                const Rendezvous::Args& recv_args, const Tensor& val,
                const bool is_dead) {
              opts->ClearCancelCallback();
              if (status.ok()) {
                // DMA can only be used for Tensors that do not fall into
                // the following three odd edge cases: 1) a zero-size
                // buffer, 2) a dead tensor which has an uninit value, and
                // 3) the tensor has the on_host allocation attribute,
                // i.e. it's in CPU RAM *independent of its assigned
                // device type*.
                const bool on_host = send_args.alloc_attrs.on_host();
                {
                  // Non-DMA cases.
                  if (src_dev->tensorflow_gpu_device_info() && (!on_host)) {
                    DeviceContext* send_dev_context = send_args.device_context;
                    AllocatorAttributes alloc_attrs;
                    alloc_attrs.set_gpu_compatible(true);
                    alloc_attrs.set_on_host(true);
                    Allocator* alloc = src_dev->GetAllocator(alloc_attrs);
                    Tensor* copy = new Tensor(alloc, val.dtype(), val.shape());
                    // "val" is on an accelerator device. Uses the device_context to
                    // fill the copy on host.
                    StatusCallback copy_ready = [rendezvous_done, copy,
                                                 is_dead](const Status& s) {
                      // The value is now ready to be returned on the wire.
                      rendezvous_done(*copy, is_dead, s);
                      delete copy;
                    };
      
                    CopyDeviceToHost(&val, alloc, alloc, request->rendezvous_key(),
                                     src_dev, copy, send_dev_context, copy_ready);
                    return;
                  }
                }
              }
      
              rendezvous_done(val, is_dead, status);
            });
      }
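
The cache_enabled branch above deserves a closer look: caching is on only when a response cache is configured and request_id is non-zero, and its purpose is to make retried RecvTensor RPCs idempotent. Below is a simplified sketch of that idea; SimpleResponseCache is hypothetical and much cruder than TensorFlow's real response cache (entry eviction, step_id handling and error propagation are omitted). It only shows that duplicate requests with the same request_id wait for the first computation instead of re-reading the rendezvous.

// A simplified response cache: the first request with a given request_id
// triggers the computation, retries are parked, and everyone is answered
// once the result arrives. Hypothetical sketch, not the TensorFlow API.
#include <cstdint>
#include <functional>
#include <string>
#include <unordered_map>
#include <vector>

using ResponseCallback = std::function<void(const std::string& tensor_bytes)>;

class SimpleResponseCache {
 public:
  // Returns true if the caller should NOT start the computation: either the
  // result is already cached, or another request with this id is in flight.
  bool QueueRequest(int64_t request_id, ResponseCallback cb) {
    Entry& e = entries_[request_id];
    if (e.finished) {               // result already known: reply immediately
      cb(e.result);
      return true;
    }
    e.callbacks.push_back(std::move(cb));
    return e.callbacks.size() > 1;  // someone else is already computing
  }

  // Called once when the tensor for request_id is ready: fan the result out
  // to every queued (possibly retried) request.
  void OnRequestFinished(int64_t request_id, const std::string& tensor_bytes) {
    Entry& e = entries_[request_id];
    e.finished = true;
    e.result = tensor_bytes;
    for (auto& cb : e.callbacks) cb(tensor_bytes);
    e.callbacks.clear();
  }

 private:
  struct Entry {
    bool finished = false;
    std::string result;
    std::vector<ResponseCallback> callbacks;
  };
  std::unordered_map<int64_t, Entry> entries_;
};

int main() {
  SimpleResponseCache cache;
  auto reply = [](const std::string&) { /* SendResponse(...) in real code */ };
  if (!cache.QueueRequest(/*request_id=*/7, reply)) {
    // First occurrence: do the rendezvous lookup, then publish the result.
    cache.OnRequestFinished(7, "tensor-bytes");
  }
  cache.QueueRequest(7, reply);  // a retry is answered from the cache
  return 0;
}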
      

      4.2.4 BaseRendezvousMgr

BaseRendezvousMgr::RecvLocalAsync first finds (or creates) the rendezvous that belongs to this step_id, then asks it to look the tensor up in the local table (a minimal sketch of the per-step table behind FindOrCreate follows the code).

      void BaseRendezvousMgr::RecvLocalAsync(int64_t step_id,
                                             const Rendezvous::ParsedKey& parsed,
                                             Rendezvous::DoneCallback done) {
        auto rendez = FindOrCreate(step_id);
        auto done_cb = [rendez, done = std::move(done)](
                           const Status& s, const Rendezvous::Args& send_args,
                           const Rendezvous::Args& recv_args, const Tensor& v,
                           bool dead) {
          rendez->Unref();
          done(s, send_args, recv_args, v, dead);
        };
        rendez->RecvLocalAsync(parsed, std::move(done_cb));
      }
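
FindOrCreate above is backed by a table keyed by step_id: every step gets its own rendezvous, created lazily on first use. Here is a minimal sketch of that table under simplified assumptions: StepRendezvous is a hypothetical placeholder for BaseRemoteRendezvous, and std::shared_ptr stands in for TensorFlow's manual Ref()/Unref() counting (which is what the rendez->Unref() in the wrapped callback above is balancing).

// A per-step rendezvous table: step_id -> rendezvous object.
// Hypothetical sketch; the real manager also aborts and cleans up
// rendezvous when a step fails or times out.
#include <cstdint>
#include <memory>
#include <mutex>
#include <unordered_map>

struct StepRendezvous {
  explicit StepRendezvous(int64_t id) : step_id(id) {}
  int64_t step_id;
  // ... in the real class: the key -> tensor table for this step ...
};

class SimpleRendezvousMgr {
 public:
  // Look up the rendezvous for a step, creating it on first use.
  std::shared_ptr<StepRendezvous> FindOrCreate(int64_t step_id) {
    std::lock_guard<std::mutex> l(mu_);
    auto it = table_.find(step_id);
    if (it == table_.end()) {
      it = table_.emplace(step_id,
                          std::make_shared<StepRendezvous>(step_id)).first;
    }
    return it->second;
  }

  // Called when the step finishes; users that still hold a shared_ptr keep
  // the object alive until their callbacks have run.
  void Cleanup(int64_t step_id) {
    std::lock_guard<std::mutex> l(mu_);
    table_.erase(step_id);
  }

 private:
  std::mutex mu_;
  std::unordered_map<int64_t, std::shared_ptr<StepRendezvous>> table_;
};

int main() {
  SimpleRendezvousMgr mgr;
  auto rendez = mgr.FindOrCreate(/*step_id=*/42);  // created on first use
  auto same = mgr.FindOrCreate(42);                // the same object is returned
  mgr.Cleanup(42);                                 // step finished
  return rendez == same ? 0 : 1;
}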
      

      4.2.5 BaseRemoteRendezvous

Ultimately the call lands in RecvLocalAsyncInternal, whose key statement is local_->RecvAsync. Note that if the rendezvous has not been initialized yet (the RecvTensor RPC can arrive before the master's RunStep), the call is parked in deferred_calls_ and replayed after Initialize() (a sketch of this deferral pattern follows the code).

      void BaseRemoteRendezvous::RecvLocalAsync(const ParsedKey& parsed,
                                                DoneCallback done) {
        // Test whether the rendezvous is initialized using a shared lock, to avoid
        // the need for exclusive access in the common case.
        if (TF_PREDICT_FALSE(!is_initialized())) {
          mutex_lock l(mu_);
          if (!is_initialized_locked()) {
            // RecvLocalAsync can be called (due to an incoming RecvTensor RPC from a
            // remote worker) before the RunStep (or PartialRunStep) RPC from the
            // master arrives. RecvLocalAsync thus buffers the arguments until after
            // the RemoteRendezvous is Initialize()'d, when it completes the
            // rendezvous logic. At some point after Initialize() is called, a Tensor
            // is produced locally that will then be sent in response to the incoming
            // RPC.
            DeferredCall call(parsed, std::move(done));
            deferred_calls_.push_back(call);
            return;
          }
        }
        RecvLocalAsyncInternal(parsed, std::move(done));
      }
      
      void BaseRemoteRendezvous::RecvLocalAsyncInternal(const ParsedKey& parsed,
                                                        DoneCallback done) {
        Status s = ValidateDevices(parsed, true /* is_src */);
        if (!s.ok()) {
          done(s, Args(), Args(), Tensor(), false);
          return;
        }
        local_->RecvAsync(parsed, Args(), std::move(done));
      }
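
The deferred_calls_ buffering above is worth isolating: an incoming RecvTensor RPC can race ahead of the master's RunStep, so the recv is parked and replayed once the rendezvous is initialized. Below is a minimal sketch of just that control flow, with a hypothetical class name (DeferredRecvBuffer) and the parked call reduced to a plain std::function.

// Park recv calls that arrive before Initialize(), then replay them.
// Hypothetical sketch of the deferral pattern, not the TensorFlow class.
#include <functional>
#include <iostream>
#include <mutex>
#include <vector>

class DeferredRecvBuffer {
 public:
  using RecvFn = std::function<void()>;

  // Called for every local recv request.
  void RecvLocalAsync(RecvFn recv) {
    {
      std::lock_guard<std::mutex> l(mu_);
      if (!initialized_) {
        deferred_calls_.push_back(std::move(recv));  // too early: park it
        return;
      }
    }
    recv();  // already initialized: run right away
  }

  // Called once the step's worker session is ready: replay everything that
  // arrived too early.
  void Initialize() {
    std::vector<RecvFn> pending;
    {
      std::lock_guard<std::mutex> l(mu_);
      initialized_ = true;
      pending.swap(deferred_calls_);
    }
    for (auto& recv : pending) recv();
  }

 private:
  std::mutex mu_;
  bool initialized_ = false;
  std::vector<RecvFn> deferred_calls_;
};

int main() {
  DeferredRecvBuffer buf;
  buf.RecvLocalAsync([] { std::cout << "recv #1\n"; });  // parked
  buf.Initialize();                                      // replays recv #1
  buf.RecvLocalAsync([] { std::cout << "recv #2\n"; });  // runs immediately
  return 0;
}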
      

      4.2.6 LocalRendezvous

LocalRendezvous::RecvAsync performs the actual read from the local table: if a matching Send has already queued the value, it is consumed and the done callback fires immediately; otherwise the callback itself is queued (registering a cancellation callback when a CancellationManager is supplied) until the corresponding Send arrives. A stripped-down sketch of this matching logic follows the code.

      void LocalRendezvous::RecvAsync(const Rendezvous::ParsedKey& key,
                                      const Rendezvous::Args& recv_args,
                                      Rendezvous::DoneCallback done) {
        uint64 key_hash = KeyHash(key.FullKey());
      
        mu_.lock();
        if (!status_.ok()) {
          // Rendezvous has been aborted.
          Status s = status_;
          mu_.unlock();
          done(s, Rendezvous::Args(), recv_args, Tensor(), false);
          return;
        }
      
        ItemQueue* queue = &table_[key_hash];
        if (queue->head == nullptr || queue->head->type == Item::kRecv) {
          // There is no message to pick up.
          // Only recv-related fields need to be filled.
          CancellationManager* cm = recv_args.cancellation_manager;
          CancellationToken token = CancellationManager::kInvalidToken;
          bool already_cancelled = false;
          if (cm != nullptr) {
            // Increment the refcount when cancellation manager is present, to make
            // sure the rendezvous outlives the recv and its cancel callbacks.
            // This refcount is dropped in exactly one of the following cases:
            // (1) Recv registers cancellation callback to cm, and then cm is
            //     cancelled, unref in the cancellation callback;
            // (2) Recv registers cancellation callback to cm, but cm is already
            //     cancelled, unref in the already_cancelled check;
            // (3) Recv is successful, and item done callback finishes deregistering
            //     the cancellation callback, unref in the item done callback;
            // (4) Recv is successful, but the item done callback fails to deregister
            //     the cancellation callback because cm already StartCancel, in this
            //     case the cancellation callback will be invoked by the cm anyway,
            //     unref in the cancellation callback.
            if (rc_owner_) rc_owner_->Ref();
            token = cm->get_cancellation_token();
            already_cancelled = !cm->RegisterCallback(token, [this, token, key_hash] {
              Item* item = nullptr;
              {
                mutex_lock l(mu_);
                ItemQueue* queue = &table_[key_hash];
                // Find an item in the queue with a cancellation token that matches
                // token, and remove it.
                if (queue->head != nullptr && queue->head->type == Item::kRecv) {
                  for (Item *prev = nullptr, *curr = queue->head; curr != nullptr;
                       prev = curr, curr = curr->next) {
                    if (curr->recv_state.cancellation_token == token) {
                      item = curr;
                      if (queue->head->next == nullptr) {
                        // We have a single-element queue, so we can erase it from
                        // the table.
                        table_.erase(key_hash);
                      } else {
                        // Remove the current item from the queue.
                        if (curr == queue->head) {
                          DCHECK_EQ(prev, nullptr);
                          queue->head = curr->next;
                        } else {
                          DCHECK_NE(prev, nullptr);
                          prev->next = curr->next;
                        }
                        if (queue->tail == curr) {
                          queue->tail = prev;
                        }
                      }
                      break;
                    }
                  }
                }
              }
      
              if (item != nullptr) {
                (*item->recv_state.waiter)(
                    StatusGroup::MakeDerived(
                        errors::Cancelled("RecvAsync is cancelled.")),
                    Rendezvous::Args(), item->args, Tensor(), /*is_dead=*/false);
                delete item;
              }
              // Unref case (1) and (4)
              if (rc_owner_) rc_owner_->Unref();
            });
          }
          if (already_cancelled) {
            mu_.unlock();
            // Unref case (2)
            if (rc_owner_) rc_owner_->Unref();
            done(StatusGroup::MakeDerived(
                     errors::Cancelled("RecvAsync is cancelled.")),
                 Rendezvous::Args(), recv_args, Tensor(), /*is_dead=*/false);
            return;
          }
      
          // TODO(b/143786186): Investigate moving the allocation of Item outside
          // the lock.
          if (cm != nullptr) {
            // NOTE(mrry): We must wrap done with code that deregisters the
            // cancellation callback before calling the done callback, because the
            // cancellation manager may no longer be live after done is called.
            queue->push_back(new Item(
                recv_args,
                [this, cm, token, done = std::move(done)](
                    const Status& s, const Rendezvous::Args& send_args,
                    const Rendezvous::Args& recv_args, const Tensor& v, bool dead) {
                  // TryDeregisterCallback returns true when the cancellation callback
                  // is successfully deregistered. If it fails because the CM already
                  // StartAbort, Unref will happen inside the cancellation callback
                  // when called by the CM.
                  if (cm->TryDeregisterCallback(token)) {
                    // Unref case (3)
                    if (this->rc_owner_) this->rc_owner_->Unref();
                  }
                  done(s, send_args, recv_args, v, dead);
                },
                token));
          } else {
            queue->push_back(new Item(recv_args, std::move(done), token));
          }
      
          mu_.unlock();
          return;
        }
      
        // A message has already arrived and is queued in the table under
        // this key.  Consumes the message and invokes the done closure.
        Item* item = queue->head;
      
        // Delete the queue when the last element has been consumed.
        if (item->next == nullptr) {
          table_.erase(key_hash);
        } else {
          queue->head = item->next;
        }
        mu_.unlock();
      
        // Invoke done() without holding the table lock.
        DCHECK_EQ(item->type, Item::kSend);
        done(Status::OK(), item->args, recv_args, *item->send_state.value,
             item->send_state.is_dead);
        delete item;
      }
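
To summarize the producer/consumer matching that LocalRendezvous implements, here is a deliberately stripped-down sketch. TinyRendezvous is hypothetical, FakeTensor is just a string, and cancellation, is_dead, allocator attributes and ref-counting are all omitted. Whichever side arrives first, the Send with its value or the RecvAsync with its callback, is parked in the table under the rendezvous key until the counterpart shows up.

// The essence of the local rendezvous table: key -> (queued values, queued
// waiters). Send either wakes a waiting recv or parks its value; RecvAsync
// either consumes a parked value or parks its callback.
#include <deque>
#include <functional>
#include <iostream>
#include <mutex>
#include <string>
#include <unordered_map>

using FakeTensor = std::string;
using RecvDoneFn = std::function<void(const FakeTensor&)>;

class TinyRendezvous {
 public:
  // Producer side (the Send op): hand over a value under `key`.
  void Send(const std::string& key, FakeTensor value) {
    RecvDoneFn waiter;
    {
      std::lock_guard<std::mutex> l(mu_);
      Queue& q = table_[key];
      if (!q.waiters.empty()) {                // a recv got here first
        waiter = std::move(q.waiters.front());
        q.waiters.pop_front();
      } else {
        q.values.push_back(std::move(value));  // park the value
        return;
      }
    }
    waiter(value);  // run the callback outside the lock
  }

  // Consumer side (RecvAsync / RecvLocalAsync): `done` fires as soon as the
  // value under `key` is, or becomes, available.
  void RecvAsync(const std::string& key, RecvDoneFn done) {
    FakeTensor value;
    {
      std::lock_guard<std::mutex> l(mu_);
      Queue& q = table_[key];
      if (!q.values.empty()) {                 // the send already happened
        value = std::move(q.values.front());
        q.values.pop_front();
      } else {
        q.waiters.push_back(std::move(done));  // park the callback
        return;
      }
    }
    done(value);
  }

 private:
  struct Queue {
    std::deque<FakeTensor> values;   // sends waiting for a recv
    std::deque<RecvDoneFn> waiters;  // recvs waiting for a send
  };
  std::unordered_map<std::string, Queue> table_;
  std::mutex mu_;
};

int main() {
  TinyRendezvous rendez;
  const std::string key = "a-rendezvous-key";
  // Recv-first: the callback is parked until the Send shows up.
  rendez.RecvAsync(key, [](const FakeTensor& t) { std::cout << t << "\n"; });
  rendez.Send(key, "tensor payload #1");
  // Send-first also works: the value waits in the table until Recv consumes it.
  rendez.Send(key, "tensor payload #2");
  rendez.RecvAsync(key, [](const FakeTensor& t) { std::cout << t << "\n"; });
  return 0;
}

This send-first / recv-first symmetry is exactly why the consumer can issue its RecvTensor RPC before the producer has executed: the recv callback simply waits in the producer-side table until the Send op runs.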
      

This fills in all the remaining logic of the earlier figure. Put another way, the server-side path is simply the call chain GrpcWorker::GrpcRecvTensorAsync → BaseRendezvousMgr::RecvLocalAsync → BaseRemoteRendezvous::RecvLocalAsyncInternal → LocalRendezvous::RecvAsync.

0xFF References

      TensorFlow架構與設計:概述

      TensorFlow內核剖析

      TensorFlow架構與設計:OP本質論

      [譯] TensorFlow 白皮書

      2017TensorFlow開發者峰會

      https://jcf94.com/2018/02/28/2018-02-28-tfunpacking3/

      TensorFlow 拆包(五):Distributed

      TensorFlow Architecture

      『深度長文』Tensorflow代碼解析(五)

      什么是in-graph replication和between-graph replication?

      [騰訊機智] TensorFlow源碼解析(1): 創建會話

      05tensorflow分布式會話

      第八節,配置分布式TensorFlow

      TensorFlow 分布式(Distributed TensorFlow)

      tensorflow源碼解析之distributed_runtime

      Distributed TensorFlow: A Gentle Introduction

      一文說清楚Tensorflow分布式訓練必備知識

      TensorFlow中的Placement啟發式算法模塊——Placer

      TensorFlow的圖切割模塊——Graph Partitioner

      TensorFlow中的通信機制——Rendezvous(一)本地傳輸

      TensorFlow分布式采坑記

      TensorFlow技術內幕(九):模型優化之分布式執行

Tensorflow架構流程
