[Source code analysis] TensorFlow distributed environment (3) --- Worker static logic
Before diving into the various TensorFlow distribution strategies, we first need to look at the foundation of distributed execution: the distributed runtime environment. Only with a solid grasp of this foundation can we clear away obstacles in later analysis and get twice the result with half the effort. This article introduces the static architecture of the Worker (and a series of related concepts).
Other articles in this series:
[Translation] TensorFlow distributed papers: "Implementation of Control Flow in TensorFlow"
[Source code analysis] TensorFlow distributed environment (1) --- Overall architecture
[Source code analysis] TensorFlow distributed environment (2) --- Master static logic
1. Inheritance hierarchy
1.1 The Worker role
The TensorFlow Worker class is the entity that performs computation. Its main responsibilities are:
- Receiving requests from the Master.
- Managing WorkerSessions.
- Processing registered subgraphs, for example splitting a subgraph a second time according to the devices available on its own node.
- Running the registered subgraphs on each device.
- Supporting worker-to-worker tensor transfer, and so on. How a transfer is carried out depends on where the two endpoints live: cudaMemcpyAsync between CPU and GPU, DMA between local GPUs, and gRPC or RDMA between remote workers.
- After execution finishes, fetching the results from the graph's terminal sink node.
See protobuf/worker_service.proto for more details on each method.
1.2 Interface
WorkerService is accessed through WorkerInterface. WorkerInterface is the worker's interface class, i.e., the interface for talking with the TensorFlow Worker service. It mainly:
- Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService and with the Protobuf definitions.
- Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through an adapter such as CallAndWait(&ME::CreateWorkerSessionAsync, request, response); a simplified sketch of this pattern follows.
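The real CallAndWait implementation appears later in section 3. Below is a minimal, self-contained sketch of the same call-and-wait idea, using std::promise in place of TensorFlow's Notification; DemoWorker, FakeStatus and the request strings are hypothetical names used only for illustration, not TensorFlow code.
#include <functional>
#include <future>
#include <iostream>
#include <memory>
#include <string>
#include <thread>
// Sketch of the call-and-wait pattern: a synchronous wrapper built on top of
// an asynchronous API blocks until the completion callback fires.
using FakeStatus = std::string;  // stand-in for tensorflow::Status
using StatusCallback = std::function<void(const FakeStatus&)>;
class DemoWorker {
 public:
  // Asynchronous API: completes on another thread and then invokes done.
  void CreateWorkerSessionAsync(const std::string& request, StatusCallback done) {
    std::thread([request, done] { done("OK: " + request); }).detach();
  }
  // Synchronous wrapper, analogous to WorkerInterface::CreateWorkerSession.
  FakeStatus CreateWorkerSession(const std::string& request) {
    auto p = std::make_shared<std::promise<FakeStatus>>();
    std::future<FakeStatus> f = p->get_future();
    CreateWorkerSessionAsync(request, [p](const FakeStatus& s) { p->set_value(s); });
    return f.get();  // block until the asynchronous callback runs
  }
};
int main() {
  DemoWorker worker;
  std::cout << worker.CreateWorkerSession("session_handle_0") << std::endl;
  return 0;
}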
1.3 WorkerInterface derived classes
As shown in the figure below, WorkerInterface has three implementations.
- Worker: this class can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes RecvTensorAsync() to support more efficient gRPC data structures for handling large binary data.
- GrpcWorker: derived from Worker, it is the Worker role in local mode. If the Master and Worker are both local, it can be called directly, without any RPC network transfer.
- GrpcRemoteWorker: in distributed mode, the Worker lives on a remote host, and the local side uses GrpcRemoteWorker to access it.
- GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
- GrpcWorkerService implements all the interfaces defined by WorkerService, but forwards the actual work to the local GrpcWorker.
The relationships are illustrated below:

Figure 1: Worker logical relationships
2. GrpcRemoteWorker
GrpcRemoteWorker is effectively a local proxy for a remote Worker.
- The local Master partitions the computation graph and then, depending on whether a partition is local or remote, calls either the local Worker or a GrpcRemoteWorker to execute that partition's subgraph.
- The local GrpcRemoteWorker is created in GetOrCreateWorker in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc.
- GrpcRemoteWorker sends gRPC requests to the remote side through IssueRequest.
- When the remote GrpcWorkerService daemon receives a request, it calls the local Worker to handle it and returns the result when done.
2.1 Definition
The GrpcRemoteWorker code is shown below; we omit part of it, such as the implementation of the DeleteWorkerSessionAsync method.
class GrpcRemoteWorker : public WorkerInterface {
public:
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger, const string& target)
: channel_(std::move(channel)),
stub_(channel_),
cq_(completion_queue),
callback_threadpool_(callback_threadpool),
getstatus_(Method(GrpcWorkerMethod::kGetStatus)),
createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)),
deleteworkersession_(Method(GrpcWorkerMethod::kDeleteWorkerSession)),
registergraph_(Method(GrpcWorkerMethod::kRegisterGraph)),
deregistergraph_(Method(GrpcWorkerMethod::kDeregisterGraph)),
rungraph_(Method(GrpcWorkerMethod::kRunGraph)),
cleanupgraph_(Method(GrpcWorkerMethod::kCleanupGraph)),
cleanupall_(Method(GrpcWorkerMethod::kCleanupAll)),
recvtensor_(Method(GrpcWorkerMethod::kRecvTensor)),
recvbuf_(Method(GrpcWorkerMethod::kRecvBuf)),
logging_(Method(GrpcWorkerMethod::kLogging)),
tracing_(Method(GrpcWorkerMethod::kTracing)),
completegroup_(Method(GrpcWorkerMethod::kCompleteGroup)),
instancesource_(Method(GrpcWorkerMethod::kCompleteInstance)),
getstepsequence_(Method(GrpcWorkerMethod::kGetStepSequence)),
markrecvfinished_(Method(GrpcWorkerMethod::kMarkRecvFinished)),
logger_(logger),
target_(target) {}
~GrpcRemoteWorker() override {}
void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response,
StatusCallback done) override {
IssueRequest(request, response, createworkersession_, std::move(done));
}
void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) override {
IssueRequest(request, response, registergraph_, std::move(done));
}
void RunGraphAsync(CallOptions* call_opts, const RunGraphRequest* request,
RunGraphResponse* response, StatusCallback done) override {
IssueRequest(request, response, rungraph_, std::move(done), call_opts);
}
void RunGraphAsync(CallOptions* call_opts, RunGraphRequestWrapper* request,
MutableRunGraphResponseWrapper* response,
StatusCallback done) override {
IssueRequest(&request->ToProto(), get_proto_from_wrapper(response),
rungraph_, std::move(done), call_opts);
}
private:
// Utility method for issuing a generic asynchronous request. The
// given callback, done, will be called when the RPC completes.
void IssueRequest(const protobuf::Message* request,
protobuf::Message* response, const ::grpc::string& method,
StatusCallback done, CallOptions* call_opts = nullptr,
bool fail_fast = true) {
new RPCState<protobuf::Message>(
&stub_, cq_, method, *request, response, std::move(done), call_opts,
callback_threadpool_, MaxRetries(), fail_fast, &target_);
}
void IssueRequest(const protobuf::Message* request, TensorResponse* response,
const ::grpc::string& method, StatusCallback done,
CallOptions* call_opts = nullptr) {
new RPCState<TensorResponse>(&stub_, cq_, method, *request, response,
std::move(done), call_opts,
callback_threadpool_, MaxRetries(),
/*fail_fast=*/true, &target_);
}
// Helper function for initializing the RpcMethod objects below.
const char* Method(GrpcWorkerMethod id) { return GrpcWorkerMethodName(id); }
// Helper function for configuring max GRPC retries. Defaults to 0 (no
// retries).
const int64_t MaxRetries() {
int64_t max_retries = -1;
TF_CHECK_OK(ReadInt64FromEnvVar("GRPC_MAX_RETRIES", 0, &max_retries));
return max_retries;
}
SharedGrpcChannelPtr channel_;
::grpc::GenericStub stub_;
::grpc::CompletionQueue* cq_;
thread::ThreadPool* callback_threadpool_;
const ::grpc::string getstatus_;
const ::grpc::string createworkersession_;
const ::grpc::string deleteworkersession_;
const ::grpc::string registergraph_;
const ::grpc::string deregistergraph_;
const ::grpc::string rungraph_;
const ::grpc::string cleanupgraph_;
const ::grpc::string cleanupall_;
const ::grpc::string recvtensor_;
const ::grpc::string recvbuf_;
const ::grpc::string logging_;
const ::grpc::string tracing_;
const ::grpc::string completegroup_;
const ::grpc::string instancesource_;
const ::grpc::string getstepsequence_;
const ::grpc::string markrecvfinished_;
// Support for logging.
WorkerCacheLogger* logger_;
const string target_;
TF_DISALLOW_COPY_AND_ASSIGN(GrpcRemoteWorker);
};
2.2 Construction
The factory function is as follows:
WorkerInterface* NewGrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger,
const string& target) {
return new GrpcRemoteWorker(std::move(channel), completion_queue,
callback_threadpool, logger, target);
}
The call site is in the worker cache, in tensorflow/core/distributed_runtime/rpc/grpc_worker_cache.cc, which decides what kind of Worker to return based on the target:
WorkerInterface* GetOrCreateWorker(const string& target) override {
if (target == local_target_) {
return local_worker_;
} else {
SharedGrpcChannelPtr channel = channel_cache_->FindWorkerChannel(target);
if (!channel) {
return nullptr;
}
size_t index = AssignWorkerToThread(target);
return NewGrpcRemoteWorker(
channel, worker_env_->GetCompletionQueue(index),
worker_env_->GetThreadPool(), &logger_, target);
}
}
2.3 Sending requests
Let's look next at how requests are sent. What CreateWorkerSessionAsync actually sends is the request named by the createworkersession_ string.
void CreateWorkerSessionAsync(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response,
StatusCallback done) override {
IssueRequest(request, response, createworkersession_, std::move(done));
}
IssueRequest was defined above and is repeated here. It invokes the remote method named by the method argument, which in our case is createworkersession_.
void IssueRequest(const protobuf::Message* request,
protobuf::Message* response, const ::grpc::string& method,
StatusCallback done, CallOptions* call_opts = nullptr,
bool fail_fast = true) {
new RPCState<protobuf::Message>(
&stub_, cq_, method, *request, response, std::move(done), call_opts,
callback_threadpool_, MaxRetries(), fail_fast, &target_);
}
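Note that IssueRequest only does new RPCState<...> and never deletes it: the call-state object drives the gRPC call and frees itself once the call completes and the done callback has run (see RPCState in grpc_state.h). Below is a minimal, self-contained sketch of this self-deleting call-state pattern; FakeTransport, DemoCallState and IssueDemoRequest are hypothetical names, and the transport runs synchronously for simplicity.
#include <functional>
#include <iostream>
#include <string>
#include <utility>
// Hypothetical sketch (not TensorFlow code) of the fire-and-forget call state
// behind "new RPCState<...>": the heap-allocated object owns the in-flight
// request and deletes itself after invoking the completion callback.
using StatusCallback = std::function<void(const std::string&)>;
struct FakeTransport {
  // Stand-in for the gRPC stub + completion queue; runs the call inline.
  void Submit(const std::string& method, const std::function<void()>& on_complete) {
    std::cout << "calling " << method << std::endl;
    on_complete();
  }
};
class DemoCallState {
 public:
  DemoCallState(FakeTransport* transport, std::string method, StatusCallback done)
      : transport_(transport), method_(std::move(method)), done_(std::move(done)) {}
  void Start() {
    // Copy what the transport needs, since OnCompleted() may delete `this`.
    const std::string method = method_;
    transport_->Submit(method, [this] { OnCompleted(); });
  }
 private:
  void OnCompleted() {
    done_("OK: " + method_);
    delete this;  // the call state cleans itself up; no delete at the call site
  }
  FakeTransport* transport_;
  std::string method_;
  StatusCallback done_;
};
// The caller just fires and forgets, exactly like IssueRequest above.
void IssueDemoRequest(FakeTransport* transport, const std::string& method,
                      StatusCallback done) {
  (new DemoCallState(transport, method, std::move(done)))->Start();
}
int main() {
  FakeTransport transport;
  IssueDemoRequest(&transport, "/tensorflow.WorkerService/CreateWorkerSession",
                   [](const std::string& s) { std::cout << s << std::endl; });
  return 0;
}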
createworkersession_ is configured in the constructor.
explicit GrpcRemoteWorker(SharedGrpcChannelPtr channel,
::grpc::CompletionQueue* completion_queue,
thread::ThreadPool* callback_threadpool,
WorkerCacheLogger* logger, const string& target)
: channel_(std::move(channel)),
createworkersession_(Method(GrpcWorkerMethod::kCreateWorkerSession)), // configured here
GrpcWorkerMethodName is defined in tensorflow/core/distributed_runtime/rpc/grpc_worker_service_impl.cc. It returns the concrete string, i.e., the name of the method on the remote GrpcWorker. As we can see, CreateWorkerSessionAsync actually calls "/tensorflow.WorkerService/CreateWorkerSession".
// Names of worker methods.
enum class GrpcWorkerMethod {
kGetStatus,
kCreateWorkerSession,
kDeleteWorkerSession,
kRegisterGraph,
kDeregisterGraph,
kRunGraph,
kCleanupGraph,
kCleanupAll,
kRecvTensor,
kRecvBuf,
kLogging,
kTracing,
kCompleteGroup,
kCompleteInstance,
kGetStepSequence,
kMarkRecvFinished,
};
const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
switch (id) {
case GrpcWorkerMethod::kGetStatus:
return "/tensorflow.WorkerService/GetStatus";
case GrpcWorkerMethod::kCreateWorkerSession:
return "/tensorflow.WorkerService/CreateWorkerSession";
case GrpcWorkerMethod::kDeleteWorkerSession:
return "/tensorflow.WorkerService/DeleteWorkerSession";
case GrpcWorkerMethod::kRegisterGraph:
return "/tensorflow.WorkerService/RegisterGraph";
case GrpcWorkerMethod::kDeregisterGraph:
return "/tensorflow.WorkerService/DeregisterGraph";
case GrpcWorkerMethod::kRunGraph:
return "/tensorflow.WorkerService/RunGraph";
case GrpcWorkerMethod::kCleanupGraph:
return "/tensorflow.WorkerService/CleanupGraph";
case GrpcWorkerMethod::kCleanupAll:
return "/tensorflow.WorkerService/CleanupAll";
case GrpcWorkerMethod::kRecvTensor:
return "/tensorflow.WorkerService/RecvTensor";
case GrpcWorkerMethod::kRecvBuf:
return "/tensorflow.WorkerService/RecvBuf";
case GrpcWorkerMethod::kLogging:
return "/tensorflow.WorkerService/Logging";
case GrpcWorkerMethod::kTracing:
return "/tensorflow.WorkerService/Tracing";
case GrpcWorkerMethod::kCompleteGroup:
return "/tensorflow.WorkerService/CompleteGroup";
case GrpcWorkerMethod::kCompleteInstance:
return "/tensorflow.WorkerService/CompleteInstance";
case GrpcWorkerMethod::kGetStepSequence:
return "/tensorflow.WorkerService/GetStepSequence";
case GrpcWorkerMethod::kMarkRecvFinished:
return "/tensorflow.WorkerService/MarkRecvFinished";
}
// Shouldn't be reached.
LOG(FATAL) << "Invalid id: this line shouldn't be reached.";
return "invalid id";
}
3. Worker Service
WorkerService is a gRPC service that defines a TensorFlow service. A WorkerService executes dataflow graphs on a set of local devices on behalf of the MasterService. A WorkerService keeps track of multiple "registered graphs". Each registered graph is a subgraph of the client's computation graph, containing only the nodes that should execute on this worker (plus any additional nodes required for inter-process communication via the RecvTensor method).
The Master looks up the other Server instances in the cluster according to the ClusterSpec and uses those Server instances as Worker roles. The Master then dispatches subgraphs to these Worker nodes and coordinates their execution of the subgraph computations. If Workers have data dependencies on each other, they interact through inter-process communication. Whether the Master calls a Worker or Workers call each other, the interface contract defined by WorkerService must be followed. All WorkerService interfaces are defined in the worker_service.proto file.
service WorkerService {
// See worker.proto for details.
rpc GetStatus(GetStatusRequest) returns (GetStatusResponse);
// See worker.proto for details.
rpc CreateWorkerSession(CreateWorkerSessionRequest)
returns (CreateWorkerSessionResponse);
// See worker.proto for details.
rpc DeleteWorkerSession(DeleteWorkerSessionRequest)
returns (DeleteWorkerSessionResponse);
// See worker.proto for details.
rpc RegisterGraph(RegisterGraphRequest) returns (RegisterGraphResponse);
// See worker.proto for details.
rpc DeregisterGraph(DeregisterGraphRequest) returns (DeregisterGraphResponse);
// See worker.proto for details.
rpc RunGraph(RunGraphRequest) returns (RunGraphResponse);
// See worker.proto for details.
rpc CleanupGraph(CleanupGraphRequest) returns (CleanupGraphResponse);
// See worker.proto for details.
rpc CleanupAll(CleanupAllRequest) returns (CleanupAllResponse);
// See worker.proto for details.
rpc RecvTensor(RecvTensorRequest) returns (RecvTensorResponse) {
// RecvTensor Method
}
// See worker.proto for details.
rpc Logging(LoggingRequest) returns (LoggingResponse);
// See worker.proto for details.
rpc Tracing(TracingRequest) returns (TracingResponse);
// See worker.proto for details.
rpc RecvBuf(RecvBufRequest) returns (RecvBufResponse) {}
// See worker.proto for details.
rpc GetStepSequence(GetStepSequenceRequest) returns (GetStepSequenceResponse);
// See worker.proto for details.
rpc CompleteGroup(CompleteGroupRequest) returns (CompleteGroupResponse);
// See worker.proto for details.
rpc CompleteInstance(CompleteInstanceRequest)
returns (CompleteInstanceResponse);
}
3.1 WorkerInterface
As with MasterService, WorkerService is accessed through WorkerInterface. WorkerInterface is the worker's interface class, i.e., the interface for talking with the TensorFlow Worker service. It mainly:
- Defines a set of asynchronous virtual functions, such as CreateWorkerSessionAsync, which derived classes implement. These virtual functions correspond one-to-one with the GrpcWorkerMethod values supported by GrpcWorkerService and with the Protobuf definitions.
- Defines a set of synchronous functions, such as CreateWorkerSession, which reach the corresponding asynchronous virtual function through an adapter such as CallAndWait(&ME::CreateWorkerSessionAsync, request, response).
We first list its asynchronous interface.
// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
public:
virtual void GetStatusAsync(CallOptions* opts,
const GetStatusRequest* request,
GetStatusResponse* response, bool fail_fast,
StatusCallback done) = 0;
virtual void CreateWorkerSessionAsync(
const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response, StatusCallback done) = 0;
virtual void DeleteWorkerSessionAsync(
CallOptions* opts, const DeleteWorkerSessionRequest* request,
DeleteWorkerSessionResponse* response, StatusCallback done) = 0;
virtual void RegisterGraphAsync(const RegisterGraphRequest* request,
RegisterGraphResponse* response,
StatusCallback done) = 0;
virtual void DeregisterGraphAsync(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response,
StatusCallback done) = 0;
virtual void RunGraphAsync(CallOptions* opts, RunGraphRequestWrapper* request,
MutableRunGraphResponseWrapper* response,
StatusCallback done) = 0;
virtual void RunGraphAsync(CallOptions* opts, const RunGraphRequest* request,
RunGraphResponse* response, StatusCallback done) {
RunGraphRequestWrapper* wrapped_request = new ProtoRunGraphRequest(request);
MutableRunGraphResponseWrapper* wrapped_response =
new NonOwnedProtoRunGraphResponse(response);
RunGraphAsync(opts, wrapped_request, wrapped_response,
[wrapped_request, wrapped_response,
done = std::move(done)](const Status& s) {
done(s);
delete wrapped_request;
delete wrapped_response;
});
}
virtual void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) = 0;
virtual void CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) = 0;
virtual void RecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
TensorResponse* response,
StatusCallback done) = 0;
virtual void LoggingAsync(const LoggingRequest* request,
LoggingResponse* response, StatusCallback done) = 0;
virtual void TracingAsync(const TracingRequest* request,
TracingResponse* response, StatusCallback done) = 0;
virtual void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) = 0;
virtual void CompleteGroupAsync(CallOptions* opts,
const CompleteGroupRequest* request,
CompleteGroupResponse* response,
StatusCallback done) = 0;
virtual void CompleteInstanceAsync(CallOptions* ops,
const CompleteInstanceRequest* request,
CompleteInstanceResponse* response,
StatusCallback done) = 0;
virtual void GetStepSequenceAsync(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response,
StatusCallback done) = 0;
}
WorkerInterface also provides synchronous methods, so that a Master or a Worker can call methods of a remote WorkerService as if they were local functions. The synchronous methods are implemented on top of the asynchronous ones by wrapping them with the CallAndWait adapter. In addition, to prevent external code from deleting a WorkerInterface instance illegally, a few restrictions are in place: the destructor is protected, WorkerCacheInterface is declared a friend, and WorkerCacheInterface::ReleaseWorker is responsible for deleting WorkerInterface instances. Below are the synchronous methods along with some helper functions and member variables.
// Interface for talking with the TensorFlow Worker service.
class WorkerInterface {
public:
virtual MutableRunGraphRequestWrapper* CreateRunGraphRequest() {
return new MutableProtoRunGraphRequest;
}
virtual MutableRunGraphResponseWrapper* CreateRunGraphResponse() {
return new OwnedProtoRunGraphResponse;
}
Status GetStatus(const GetStatusRequest* request,
GetStatusResponse* response) {
Status ret;
Notification n;
GetStatusAsync(/*opts=*/nullptr, request, response, /*fail_fast=*/true,
[&ret, &n](const Status& s) {
ret = s;
n.Notify();
});
n.WaitForNotification();
return ret;
}
Status CreateWorkerSession(const CreateWorkerSessionRequest* request,
CreateWorkerSessionResponse* response) {
return CallAndWait(&ME::CreateWorkerSessionAsync, request, response);
}
Status DeleteWorkerSession(const DeleteWorkerSessionRequest* request,
DeleteWorkerSessionResponse* response) {
return CallAndWaitWithOptions(&ME::DeleteWorkerSessionAsync, request,
response);
}
Status RegisterGraph(const RegisterGraphRequest* request,
RegisterGraphResponse* response) {
return CallAndWait(&ME::RegisterGraphAsync, request, response);
}
Status DeregisterGraph(const DeregisterGraphRequest* request,
DeregisterGraphResponse* response) {
return CallAndWait(&ME::DeregisterGraphAsync, request, response);
}
Status CleanupGraph(const CleanupGraphRequest* request,
CleanupGraphResponse* response) {
return CallAndWait(&ME::CleanupGraphAsync, request, response);
}
Status CleanupAll(const CleanupAllRequest* request,
CleanupAllResponse* response) {
return CallAndWait(&ME::CleanupAllAsync, request, response);
}
Status Logging(const LoggingRequest* request, LoggingResponse* response) {
return CallAndWait(&ME::LoggingAsync, request, response);
}
Status Tracing(const TracingRequest* request, TracingResponse* response) {
return CallAndWait(&ME::TracingAsync, request, response);
}
Status GetStepSequence(const GetStepSequenceRequest* request,
GetStepSequenceResponse* response) {
return CallAndWait(&ME::GetStepSequenceAsync, request, response);
}
protected:
// Instances of WorkerInterface must be deleted by a call to
// WorkerCacheInterface::ReleaseWorker().
virtual ~WorkerInterface() {}
friend class WorkerCacheInterface;
// NOTE: This should only be called by implementations of this
// interface whose CreateRunGraphResponse() method returns a
// proto-based wrappers for the RunGraphResponse message.
RunGraphResponse* get_proto_from_wrapper(
MutableRunGraphResponseWrapper* wrapper) {
return wrapper->get_proto();
}
private:
typedef WorkerInterface ME;
template <typename Method, typename Req, typename Resp>
Status CallAndWait(Method func, const Req* req, Resp* resp) {
Status ret;
Notification n;
(this->*func)(req, resp, [&ret, &n](const Status& s) {
ret = s;
n.Notify();
});
n.WaitForNotification();
return ret;
}
template <typename Method, typename Req, typename Resp>
Status CallAndWaitWithOptions(Method func, const Req* req, Resp* resp) {
CallOptions call_opts;
Status ret;
Notification n;
(this->*func)(&call_opts, req, resp, [&ret, &n](const Status& s) {
ret = s;
n.Notify();
});
n.WaitForNotification();
return ret;
}
};
3.2 Sorting out the concepts
The WorkerService interface involves quite a few concepts, so let's sort them out carefully.
As mentioned earlier, the Client and the Master cooperate through the session_handle / MasterSession pair, while the Master and the Workers cooperate through MasterSession and WorkerSession; a MasterSession centrally manages the multiple WorkerSessions that belong to it. We need to clarify the relationships among a few identifiers (a small sketch follows the list):
- session_handle: exists so that a MasterSession can centrally manage its subordinate WorkerSessions. It corresponds one-to-one with a MasterSession and is generated when the MasterSession is created. It is returned to the Client via CreateSessionResponse and sent to the Workers via CreateWorkerSessionRequest, so the whole chain from Client to Master to Worker is uniquely identified by this session_handle.
- graph_handle: generated by GraphMgr::Register when a subgraph is registered, and returned to the Master via RegisterGraphResponse. The subgraph is identified by this graph_handle; within the cluster, the (session_handle, graph_handle) pair uniquely identifies a particular subgraph.
- step_id: because the Master lets multiple Workers execute computation concurrently, it broadcasts RunGraph to all of them. To distinguish different steps, the Master generates a globally unique step_id for every RunStep and carries it to the Workers in the RunGraphRequest message.
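To make this chain concrete, here is a minimal, self-contained sketch (hypothetical DemoGraphRegistry and RegisteredSubgraph types, not the real GraphMgr) of how a registry might key subgraphs by the (session_handle, graph_handle) pair and stamp each run with a step_id; the handle formatting mirrors the GraphMgr::Register code shown next.
#include <cstdint>
#include <cstdio>
#include <iostream>
#include <map>
#include <string>
#include <utility>
struct RegisteredSubgraph {
  std::string graph_def;  // stand-in for the partitioned GraphDef
};
class DemoGraphRegistry {
 public:
  // Returns an opaque graph_handle, formatted the same way GraphMgr does.
  std::string Register(const std::string& session_handle, std::string graph_def) {
    char buf[32];
    std::snprintf(buf, sizeof(buf), "%016llx",
                  static_cast<unsigned long long>(++next_id_));
    std::string graph_handle(buf);
    table_[{session_handle, graph_handle}] = {std::move(graph_def)};
    return graph_handle;
  }
  // A subgraph is uniquely addressed by the (session_handle, graph_handle) pair.
  const RegisteredSubgraph* Find(const std::string& session_handle,
                                 const std::string& graph_handle) const {
    auto it = table_.find({session_handle, graph_handle});
    return it == table_.end() ? nullptr : &it->second;
  }
 private:
  uint64_t next_id_ = 0;
  std::map<std::pair<std::string, std::string>, RegisteredSubgraph> table_;
};
int main() {
  DemoGraphRegistry registry;
  const std::string session_handle = "session_0";  // created with the MasterSession
  std::string graph_handle =
      registry.Register(session_handle, "subgraph for /job:worker/task:0");
  uint64_t step_id = 1;  // one globally unique id per RunStep
  std::cout << "run (" << session_handle << ", " << graph_handle
            << ") at step " << step_id << std::endl;
  return 0;
}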
Let's trace graph_handle. GraphMgr::Register is where graph_handle is generated.
Status GraphMgr::Register(
const string& handle, const GraphDef& gdef, WorkerSession* session,
const GraphOptions& graph_options, const DebugOptions& debug_options,
const ConfigProto& config_proto, int64_t collective_graph_key,
DistributedFunctionLibraryRuntime* cluster_flr, string* graph_handle) {
Item* item = new Item;
Status s = InitItem(handle, gdef, session, graph_options, debug_options,
config_proto, collective_graph_key, cluster_flr, item);
// Inserts one item into table_.
{
mutex_lock l(mu_);
*graph_handle =
strings::Printf("%016llx", static_cast<long long>(++next_id_));
item->handle = *graph_handle;
CHECK(table_.insert({*graph_handle, item}).second);
}
return Status::OK();
}
RegisterGraphResponse returns the graph_handle to the Master.
message RegisterGraphResponse {
// If the registration succeeds, returns an opaque graph_handle to
// the master. The master calls RunGraph with graph_handle to
// compute different steps.
string graph_handle = 1;
}
Each partitioned, per-location subgraph (Part) carries a graph_handle.
// Graph partitioned into per-location subgraphs.
struct Part {
// Worker name.
string name;
// Maps feed names to rendezvous keys. Empty most of the time.
std::unordered_map<string, string> feed_key;
// Maps rendezvous keys to fetch names. Empty most of the time.
std::unordered_map<string, string> key_fetch;
// The interface to the worker. Owned.
WorkerInterface* worker = nullptr;
// After registration with the worker, graph_handle identifies
// this partition on the worker.
string graph_handle;
Part() : feed_key(3), key_fetch(3) {}
};
When registration returns, the graph_handle is stored on the corresponding partition.
Status MasterSession::ReffedClientGraph::DoRegisterPartitions(
const PartitionOptions& popts,
std::unordered_map<string, GraphDef> graph_partitions) {
partitions_.reserve(graph_partitions.size());
Status s;
for (auto& name_def : graph_partitions) {
partitions_.emplace_back();
Part* part = &partitions_.back();
part->name = name_def.first;
TrackFeedsAndFetches(part, name_def.second, popts);
part->worker = worker_cache_->GetOrCreateWorker(part->name);
if (part->worker == nullptr) {
s = errors::NotFound("worker ", part->name);
break;
}
}
if (!s.ok()) {
for (Part& part : partitions_) {
worker_cache_->ReleaseWorker(part.name, part.worker);
part.worker = nullptr;
}
return s;
}
struct Call {
RegisterGraphRequest req;
RegisterGraphResponse resp;
Status status;
};
const int num = partitions_.size();
gtl::InlinedVector<Call, 4> calls(num);
BlockingCounter done(num);
for (int i = 0; i < num; ++i) {
const Part& part = partitions_[i];
Call* c = &calls[i];
c->req.set_session_handle(session_handle_);
c->req.set_create_worker_session_called(!should_deregister_);
c->req.mutable_graph_def()->Swap(&graph_partitions[part.name]);
StripDefaultAttributes(*OpRegistry::Global(),
c->req.mutable_graph_def()->mutable_node());
*c->req.mutable_config_proto() = session_opts_.config;
*c->req.mutable_graph_options() = session_opts_.config.graph_options();
*c->req.mutable_debug_options() =
callable_opts_.run_options().debug_options();
c->req.set_collective_graph_key(collective_graph_key_);
auto cb = [c, &done](const Status& s) {
c->status = s;
done.DecrementCount();
};
part.worker->RegisterGraphAsync(&c->req, &c->resp, cb);
}
done.Wait();
for (int i = 0; i < num; ++i) {
Call* c = &calls[i];
s.Update(c->status);
partitions_[i].graph_handle = c->resp.graph_handle();
}
return s;
}
Later, graph_handle is used to uniquely identify a subgraph, for example when deregistering it:
// Asynchronously deregisters subgraphs on the workers, without waiting for the
// result.
void MasterSession::ReffedClientGraph::DeregisterPartitions() {
struct Call {
DeregisterGraphRequest req;
DeregisterGraphResponse resp;
};
for (Part& part : partitions_) {
// The graph handle may be empty if we failed during partition registration.
if (!part.graph_handle.empty()) {
Call* c = new Call;
c->req.set_session_handle(session_handle_);
c->req.set_create_worker_session_called(!should_deregister_);
c->req.set_graph_handle(part.graph_handle);
// NOTE(mrry): We must capture worker_cache_ since this
// could be deleted before the callback is called.
WorkerCacheInterface* worker_cache = worker_cache_;
const string name = part.name;
WorkerInterface* w = part.worker;
CHECK_NOTNULL(w);
auto cb = [worker_cache, c, name, w](const Status& s) {
delete c;
worker_cache->ReleaseWorker(name, w);
};
w->DeregisterGraphAsync(&c->req, &c->resp, cb);
}
}
}
3.3 WorkerInterface derived classes
As shown in the figure below, WorkerInterface has two implementations here.
- GrpcWorker: the Worker role in local mode. If the Master and Worker are both local, it can be called directly, without any RPC network transfer.
- GrpcRemoteWorker: in distributed mode, the Worker lives on a remote host, and the local side uses GrpcRemoteWorker to access it.
- GrpcRemoteWorker is a gRPC client that accesses the GrpcWorkerService on the remote Worker through a stub.
- GrpcWorkerService implements all the interfaces defined by WorkerService, but forwards the actual work to the local GrpcWorker.
The relationships are illustrated below:

Figure 1: WorkerInterface derived classes
3.4 Usage
When the Server is initialized, it builds the Worker Service with the following code.
// Create the GrpcWorker and its corresponding GrpcWorkerService.
worker_impl_ = opts.worker_func ? opts.worker_func(&worker_env_, config)
: NewGrpcWorker(&worker_env_, config);
worker_service_ = NewGrpcWorkerService(worker_impl_.get(), &builder,
opts.worker_service_options);
NewGrpcWorkerService simply returns the GrpcWorkerService implementation.
// Returns an implementation of WorkerService rpc service.
std::unique_ptr<AsyncServiceInterface> NewGrpcWorkerService(
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
GrpcWorkerServiceOptions options) {
return std::unique_ptr<AsyncServiceInterface>(
new GrpcWorkerService(worker, builder, options));
}
Inside GrpcServer, the worker_thread_ thread runs the GrpcWorkerService::HandleRPCsLoop method.
worker_thread_.reset(
env_->StartThread(ThreadOptions(), "TF_worker_service",
[this] { worker_service_->HandleRPCsLoop(); }));
3.5 Definition
GrpcWorkerService is defined as follows. Because it has to act as a daemon that handles incoming gRPC requests, its constructor creates several threads to serve the requests; HandleRPCsLoop then starts these threads and joins them.
class GrpcWorkerService : public AsyncServiceInterface {
public:
GrpcWorkerService(GrpcWorker* worker, ::grpc::ServerBuilder* builder,
GrpcWorkerServiceOptions options)
: is_shutdown_(false) {
builder->RegisterService(&worker_service_);
for (int i = 0; i < options.num_serving_threads; i++) {
threads_.emplace_back(
new GrpcWorkerServiceThread(worker, builder, options.queue_depth,
cache_.get(), &worker_service_));
}
}
// This method blocks forever handling requests from the completion queue.
void HandleRPCsLoop() override {
for (auto& worker_thread : threads_) {
worker_thread->Start();
}
for (auto& worker_thread : threads_) {
worker_thread->Join();
}
}
private:
grpc::WorkerService::AsyncService worker_service_;
std::vector<std::unique_ptr<GrpcWorkerServiceThread>> threads_;
std::unique_ptr<GrpcResponseCache> cache_;
mutex service_shutdown_mu_;
bool is_shutdown_ TF_GUARDED_BY(service_shutdown_mu_);
TF_DISALLOW_COPY_AND_ASSIGN(GrpcWorkerService);
};
3.6 Threads
The actual loop that services requests runs inside the threads; cq_ is the gRPC completion queue.
// GrpcWorkerService spawns one or more GrpcWorkerServiceThreads to service
// requests. Each thread operates on an independent completion queue.
class GrpcWorkerServiceThread {
public:
explicit GrpcWorkerServiceThread(
GrpcWorker* worker, ::grpc::ServerBuilder* builder,
std::unordered_map<int, int> queue_depth, GrpcResponseCache* cache,
grpc::WorkerService::AsyncService* worker_service)
: worker_(worker),
queue_depth_(queue_depth),
cache_(cache),
worker_service_(worker_service),
is_shutdown_(false) {
cq_ = builder->AddCompletionQueue();
}
void Start() {
thread_.reset(
worker_->env()->env->StartThread(ThreadOptions(), "grpc_worker_service",
[this]() { HandleRPCsLoop(); }));
}
}
The main loop
GrpcWorkerServiceThread::HandleRPCsLoop is the thread's main loop, similar to the one in the master service. It first prepares waiting queues of pending gRPC calls; these calls correspond one-to-one with the GrpcWorkerMethod values described later, and the handler code for each method is shown further below.
// Add one or more completion queue entries for each worker method, then
// begin servicing requests from the completion queue.
void GrpcWorkerServiceThread::HandleRPCsLoop() {
// TODO(ncteisen): This may require performance engineering. We can
// change the number of threads, the number of handlers per thread,
// or even decide to specialize certain threads to certain methods.
SETUP_FOR_REQUEST(GetStatus, 1, false);
SETUP_FOR_REQUEST(CreateWorkerSession, 1, false);
SETUP_FOR_REQUEST(DeleteWorkerSession, 1, false);
SETUP_FOR_REQUEST(CleanupAll, 1, false);
SETUP_FOR_REQUEST(RegisterGraph, 1, false);
SETUP_FOR_REQUEST(DeregisterGraph, 1, false);
SETUP_FOR_REQUEST(Logging, 1, false);
SETUP_FOR_REQUEST(Tracing, 1, false);
SETUP_FOR_REQUEST(CompleteGroup, 10, true);
SETUP_FOR_REQUEST(CompleteInstance, 10, true);
SETUP_FOR_REQUEST(GetStepSequence, 10, true);
SETUP_FOR_REQUEST(RecvBuf, 500, true);
SETUP_FOR_REQUEST(RunGraph, 100, true);
SETUP_FOR_REQUEST(CleanupGraph, 100, false);
SETUP_FOR_REQUEST(MarkRecvFinished, 10, false);
// TODO(ncteisen): Determine a better policy for enqueuing the
// appropriate number of each request type.
for (int i = 0;
i < gtl::FindWithDefault(
queue_depth_, static_cast<int>(GrpcWorkerMethod::kRecvTensor),
1000);
++i) {
EnqueueRecvTensorRequestRaw();
}
void* tag;
bool ok;
while (cq_->Next(&tag, &ok)) {
UntypedCall<GrpcWorkerServiceThread>::Tag* callback_tag =
static_cast<UntypedCall<GrpcWorkerServiceThread>::Tag*>(tag);
CHECK(callback_tag);
callback_tag->OnCompleted(this, ok);
}
}
gRPC requests
Requests are handled much like in the master. Each request invokes a business handler, namely the GrpcWorkerServiceThread::method##Handler defined by the macros below.
#define ENQUEUE_REQUEST(method, supports_cancel) \
do { \
mutex_lock l(shutdown_mu_); \
if (!is_shutdown_) { \
Call<GrpcWorkerServiceThread, grpc::WorkerService::AsyncService, \
method##Request, method##Response>:: \
EnqueueRequestForMethod( \
worker_service_, cq_.get(), \
static_cast<int>(GrpcWorkerMethod::k##method), \
&GrpcWorkerServiceThread::method##Handler, (supports_cancel)); \
} \
} while (0)
#define SETUP_FOR_REQUEST(method, default_depth, supports_cancel) \
for (int i = 0; \
i < gtl::FindWithDefault(queue_depth_, \
static_cast<int>(GrpcWorkerMethod::k##method), \
default_depth); \
++i) { \
ENQUEUE_REQUEST(method, supports_cancel); \
}
Each RPC service must be registered as an asynchronous service, which is done with gRPC's own AddMethod and MarkMethodAsync interfaces.
WorkerService::AsyncService::AsyncService() {
for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
::grpc::Service::MarkMethodAsync(i);
}
}
Handlers & thread pool
The concrete handlers are generated by a macro, shown below. Depending on the may_block_on_compute_pool flag, the closure that calls the Worker is either scheduled through Env::SchedClosure (for operations that may block on the compute pool) or submitted to the compute thread pool via compute_pool->Schedule; this is where the modules integrated into the worker env come into play.
// Handle all non-cancellable simple methods with a standard wrapper.
// The boolean may_block_on_compute_pool indicates whether or not the
// operation may block on activities (such as op execution) that run on the
// compute pool.
#define HANDLE_CALL(method, may_block_on_compute_pool) \
void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
auto closure = [this, call]() { \
Status s = worker_->method(&call->request, &call->response); \
if (!s.ok()) { \
VLOG(3) << "Bad response from " << #method << ": " << s; \
} \
call->SendResponse(ToGrpcStatus(s)); \
}; \
if ((may_block_on_compute_pool)) { \
worker_->env()->env->SchedClosure(std::move(closure)); \
} else { \
worker_->env()->compute_pool->Schedule(std::move(closure)); \
} \
ENQUEUE_REQUEST(method, false); \
}
HANDLE_CALL(GetStatus, false);
HANDLE_CALL(CreateWorkerSession, false);
HANDLE_CALL(DeleteWorkerSession, true);
HANDLE_CALL(CleanupAll, false);
HANDLE_CALL(RegisterGraph, false);
HANDLE_CALL(DeregisterGraph, false);
HANDLE_CALL(CleanupGraph, false);
HANDLE_CALL(Logging, false);
HANDLE_CALL(Tracing, false);
#undef HANDLE_CALL
Messages & methods
GrpcWorkerMethod enumerates the methods a worker provides.
// Names of worker methods.
enum class GrpcWorkerMethod {
kGetStatus,
kCreateWorkerSession,
kDeleteWorkerSession,
kRegisterGraph,
kDeregisterGraph,
kRunGraph,
kCleanupGraph,
kCleanupAll,
kRecvTensor,
kRecvBuf,
kLogging,
kTracing,
kCompleteGroup,
kCompleteInstance,
kGetStepSequence,
kMarkRecvFinished,
};
GrpcWorkerMethodName maps each of these methods to its fully qualified gRPC method name.
const char* GrpcWorkerMethodName(GrpcWorkerMethod id) {
switch (id) {
case GrpcWorkerMethod::kGetStatus:
return "/tensorflow.WorkerService/GetStatus";
case GrpcWorkerMethod::kCreateWorkerSession:
return "/tensorflow.WorkerService/CreateWorkerSession";
case GrpcWorkerMethod::kDeleteWorkerSession:
return "/tensorflow.WorkerService/DeleteWorkerSession";
case GrpcWorkerMethod::kRegisterGraph:
return "/tensorflow.WorkerService/RegisterGraph";
case GrpcWorkerMethod::kDeregisterGraph:
return "/tensorflow.WorkerService/DeregisterGraph";
case GrpcWorkerMethod::kRunGraph:
return "/tensorflow.WorkerService/RunGraph";
case GrpcWorkerMethod::kCleanupGraph:
return "/tensorflow.WorkerService/CleanupGraph";
case GrpcWorkerMethod::kCleanupAll:
return "/tensorflow.WorkerService/CleanupAll";
case GrpcWorkerMethod::kRecvTensor:
return "/tensorflow.WorkerService/RecvTensor";
case GrpcWorkerMethod::kRecvBuf:
return "/tensorflow.WorkerService/RecvBuf";
case GrpcWorkerMethod::kLogging:
return "/tensorflow.WorkerService/Logging";
case GrpcWorkerMethod::kTracing:
return "/tensorflow.WorkerService/Tracing";
case GrpcWorkerMethod::kCompleteGroup:
return "/tensorflow.WorkerService/CompleteGroup";
case GrpcWorkerMethod::kCompleteInstance:
return "/tensorflow.WorkerService/CompleteInstance";
case GrpcWorkerMethod::kGetStepSequence:
return "/tensorflow.WorkerService/GetStepSequence";
case GrpcWorkerMethod::kMarkRecvFinished:
return "/tensorflow.WorkerService/MarkRecvFinished";
}
// Shouldn't be reached.
return "invalid id";
}
AsyncService calls GrpcWorkerMethodName to register these methods with gRPC.
WorkerService::AsyncService::AsyncService() {
for (int i = 0; i < kGrpcNumWorkerMethods; ++i) {
AddMethod(new ::grpc::internal::RpcServiceMethod(
GrpcWorkerMethodName(static_cast<GrpcWorkerMethod>(i)),
::grpc::internal::RpcMethod::NORMAL_RPC, nullptr));
::grpc::Service::MarkMethodAsync(i);
}
}
Business processing
The actual business logic is carried out by calling the Worker, for example:
void GetStepSequenceHandler(
WorkerCall<GetStepSequenceRequest, GetStepSequenceResponse>* call) {
Schedule([this, call]() {
worker_->GetStepSequenceAsync(
&call->request, &call->response, [call](const Status& s) {
call->SendResponse(ToGrpcStatus(s));
});
});
ENQUEUE_REQUEST(GetStepSequence, true);
}
From the thread perspective, the logic looks as follows (assuming three threads here). The Server's worker_thread_ runs GrpcWorkerService::HandleRPCsLoop(), whose job is to start two GrpcWorkerServiceThreads; each GrpcWorkerServiceThread responds to gRPC requests and performs the business processing inside its own GrpcWorkerServiceThread::HandleRPCsLoop. Note that both GrpcWorkerService and GrpcWorkerServiceThread have a method named HandleRPCsLoop.

Figure 2: Thread view
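To make this Start/Join fan-out concrete, here is a minimal, self-contained sketch (hypothetical DemoWorkerService and DemoServiceThread types, no gRPC): the service owns several threads, each draining its own queue of pending calls, and HandleRPCsLoop simply starts them all and then joins them.
#include <functional>
#include <iostream>
#include <memory>
#include <queue>
#include <thread>
#include <vector>
// Hypothetical sketch of the GrpcWorkerService / GrpcWorkerServiceThread
// fan-out: each thread drains its own queue of pending calls (standing in
// for a gRPC completion queue).
class DemoServiceThread {
 public:
  explicit DemoServiceThread(int id) : id_(id) {
    // Pre-enqueue a few fake calls, loosely like SETUP_FOR_REQUEST does.
    for (int i = 0; i < 3; ++i)
      pending_.push([this, i] {
        std::cout << "thread " << id_ << " handled call " << i << "\n";
      });
  }
  void Start() { thread_ = std::thread([this] { HandleRPCsLoop(); }); }
  void Join() { thread_.join(); }
 private:
  void HandleRPCsLoop() {
    while (!pending_.empty()) {  // the real loop blocks on cq_->Next() instead
      pending_.front()();
      pending_.pop();
    }
  }
  int id_;
  std::queue<std::function<void()>> pending_;
  std::thread thread_;
};
class DemoWorkerService {
 public:
  explicit DemoWorkerService(int num_serving_threads) {
    for (int i = 0; i < num_serving_threads; ++i)
      threads_.emplace_back(new DemoServiceThread(i));
  }
  // Mirrors GrpcWorkerService::HandleRPCsLoop: start every thread, then join.
  void HandleRPCsLoop() {
    for (auto& t : threads_) t->Start();
    for (auto& t : threads_) t->Join();
  }
 private:
  std::vector<std::unique_ptr<DemoServiceThread>> threads_;
};
int main() {
  DemoWorkerService service(/*num_serving_threads=*/2);
  service.HandleRPCsLoop();  // in TensorFlow this runs on the server's worker_thread_
  return 0;
}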
3.7 Business logic
CreateWorkerSession
The CreateWorkerSessionRequest message carries the session_handle of the corresponding MasterSession; when the Worker receives it, it creates a WorkerSession. Within a cluster, whenever a MasterSession creates WorkerSessions it passes along its own session_handle, so each WorkerSession knows, via session_handle, which MasterSession it belongs to, and the MasterSession instance can centrally manage all WorkerSessions subordinate to it.
GrpcWorker manages WorkerSessions through SessionMgr; a WorkerSession can be located either by the master task name or by the session_handle.
class SessionMgr {
WorkerEnv* const worker_env_; // Not owned.
std::unique_ptr<WorkerCacheInterface> default_worker_cache_;
std::shared_ptr<WorkerSession> legacy_session_;
const WorkerCacheFactory worker_cache_factory_;
// A map from session identifier to internal session structure.
std::map<string, std::shared_ptr<WorkerSession>> sessions_ TF_GUARDED_BY(mu_);
// Incarnation and WorkerSession handle associated with a master task.
struct MasterAssociatedSession {
const int64_t master_incarnation;
const string session_handle;
};
// A map from master task name to its associated worker sessions.
std::unordered_multimap<string, MasterAssociatedSession>
master_to_associated_sessions_ TF_GUARDED_BY(mu_);
};
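Below is a minimal, self-contained sketch (hypothetical DemoSessionMgr, simplified to one session per master task) of the bookkeeping shown above together with the incarnation-based garbage collection described in the CreateWorkerSessionRequest comments further down: a new request from the same master task but a different incarnation evicts the stale WorkerSession.
#include <cstdint>
#include <iostream>
#include <map>
#include <memory>
#include <string>
#include <unordered_map>
struct DemoWorkerSession {
  std::string session_handle;
};
class DemoSessionMgr {
 public:
  void CreateSession(const std::string& session_handle,
                     const std::string& master_task, int64_t master_incarnation) {
    auto it = master_to_session_.find(master_task);
    if (it != master_to_session_.end() &&
        it->second.master_incarnation != master_incarnation) {
      // Same master task, different incarnation: the old master is gone,
      // so drop the stale WorkerSession to avoid a leak.
      sessions_.erase(it->second.session_handle);
      master_to_session_.erase(it);
    }
    sessions_[session_handle] =
        std::make_shared<DemoWorkerSession>(DemoWorkerSession{session_handle});
    master_to_session_[master_task] = {master_incarnation, session_handle};
  }
  std::shared_ptr<DemoWorkerSession> WorkerSessionForSession(const std::string& handle) {
    auto it = sessions_.find(handle);
    return it == sessions_.end() ? nullptr : it->second;
  }
 private:
  struct MasterAssociatedSession {
    int64_t master_incarnation;
    std::string session_handle;
  };
  std::map<std::string, std::shared_ptr<DemoWorkerSession>> sessions_;
  std::unordered_map<std::string, MasterAssociatedSession> master_to_session_;
};
int main() {
  DemoSessionMgr mgr;
  mgr.CreateSession("handle_A", "/job:master/task:0", /*master_incarnation=*/1);
  mgr.CreateSession("handle_B", "/job:master/task:0", /*master_incarnation=*/2);  // old master died
  std::cout << (mgr.WorkerSessionForSession("handle_A") ? "A alive" : "A collected") << std::endl;
  std::cout << (mgr.WorkerSessionForSession("handle_B") ? "B alive" : "B collected") << std::endl;
  return 0;
}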
The concrete messages are shown below. Note that CreateWorkerSessionResponse returns nothing:
message CreateWorkerSessionRequest {
// Sessions are identified by a given handle.
string session_handle = 1;
// Defines the configuration of a TensorFlow worker.
ServerDef server_def = 2;
// If true, any resources such as Variables used in the session will not be
// shared with other sessions.
bool isolate_session_state = 3;
// The device attributes of all the devices in the cluster.
repeated DeviceAttributes cluster_device_attributes = 4;
// The master task name from which the request is sent.
string master_task = 5;
// The incarnation ID of the master task local CPU device.
// If the target worker already has a WorkerSession created previously with
// the same master task name but a different incarnation, it usually indicates
// that the previous master failed before deleting the WorkerSession on the
// worker. To prevent memory leaks, the worker should garbage collect the old
// WorkerSessions.
int64 master_incarnation = 6;
}
message CreateWorkerSessionResponse {}

Figure 3: CreateWorkerSession
As mentioned earlier, the GrpcWorker handlers for these messages are generated by a macro.
#define HANDLE_CALL(method, may_block_on_compute_pool) \
void method##Handler(WorkerCall<method##Request, method##Response>* call) { \
auto closure = [this, call]() { \
Status s = worker_->method(&call->request, &call->response); \
if (!s.ok()) { \
VLOG(3) << "Bad response from " << #method << ": " << s; \
} \
call->SendResponse(ToGrpcStatus(s)); \
}; \
if ((may_block_on_compute_pool)) { \
worker_->env()->env->SchedClosure(std::move(closure)); \
} else { \
worker_->env()->compute_pool->Schedule(std::move(closure)); \
} \
ENQUEUE_REQUEST(method, false); \
}
HANDLE_CALL(GetStatus, false);
HANDLE_CALL(CreateWorkerSession, false);
HANDLE_CALL(DeleteWorkerSession, true);
HANDLE_CALL(CleanupAll, false);
HANDLE_CALL(RegisterGraph, false);
HANDLE_CALL(DeregisterGraph, false);
HANDLE_CALL(CleanupGraph, false);
HANDLE_CALL(Logging, false);
HANDLE_CALL(Tracing, false);
RegisterGraph
The RegisterGraphRequest message sends the session_handle of the corresponding MasterSession together with the subgraph graph_def. After the Worker receives the message and finishes registering and initializing the subgraph, it returns the subgraph's graph_handle to the Master.
For each session, after the master has placed every node on a device, it partitions the whole graph into many subgraphs. All nodes of one subgraph live on the same worker, but may be spread across many devices owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The master registers the subgraphs with a worker before running any step. A successful registration returns a graph handle to be used in later RunGraph requests.
////////////////////////////////////////////////////////////////////////////////
//
// RegisterGraph method request/response messages
//
// For each session, after the master placed every node on a device,
// it partitions the whole graph into many subgraphs. All the nodes in
// a subgraph were in the same worker, but potentially on many devices
// owned by that worker (e.g. cpu0, plus gpu0, gpu1, ..., gpu7). The
// master registers subgraphs for a worker before running any steps. A
// successful registration returns a graph handle to be used in latter
// RunGraph requests.
//
////////////////////////////////////////////////////////////////////////////////
message RegisterGraphRequest {
// Subgraphs are scoped within one session.
string session_handle = 1;
// Set to true if CreateWorkerSession was called for session_handle.
bool create_worker_session_called = 6;
// "graph_def" has the subgraph of nodes for this worker, with each node
// having its device_name filled in.
GraphDef graph_def = 2;
// True iff the graph (before partitioning) contains control flow nodes.
//
// As of 01/11/2015, this is no longer set by clients.
bool has_control_flow = 3 [deprecated = true];
// Configuration options for the session in which this graph was created.
GraphOptions graph_options = 4;
// Field(s) used by TensorFlow Debugger (tfdbg).
DebugOptions debug_options = 5;
// If graph_def contains any collective ops this must be a positive
// integer used to coordinate execution with other graphs. All
// graphs in a distributed execution with the same
// collective_graph_key will coordinate to use the same step_id
// concurrently so that BufRendezvous entries will make the correct
// values accessible.
int64 collective_graph_key = 7;
// ConfigProto from the session in which this graph was created.
// Contains additional parameters beyond graph_options, including
// the name of the requested executor.
ConfigProto config_proto = 8;
}
message RegisterGraphResponse {
// If the registration succeeds, returns an opaque graph_handle to
// the master. The master calls RunGraph with graph_handle to
// compute different steps.
string graph_handle = 1;
}

Figure 4: RegisterGraph
DeregisterGraph
When a computation graph is no longer needed (for example, the whole graph has been rescheduled and its nodes re-placed), the Master deregisters it using the corresponding graph_handle. If the Master restarts, the Worker automatically deregisters the graph_handle according to a TTL-based policy.
////////////////////////////////////////////////////////////////////////////////
//
// DeregisterGraph method request/response messages
//
// The master deregisters the given graph_handle when the graph is no
// longer needed (e.g., the overall graph is re-scheduled and nodes
// are re-placed).
//
// The worker deregisters a graph_handle automatically according to on
// a TTL-base policy in case of master restarts.
//
////////////////////////////////////////////////////////////////////////////////
message DeregisterGraphRequest {
// The session_handle used when registering the graph. If session_handle is
// empty, a single global namespace is used.
string session_handle = 2;
// Set to true if CreateWorkerSession was called for session_handle.
bool create_worker_session_called = 3;
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
}
message DeregisterGraphResponse {
// TODO(mrry): Optionally add summary stats for the graph.
}

Figure 5: DeregisterGraph
RunGraph
The Master uses RunGraphRequest to execute all subgraphs registered under a graph_handle.
The Master generates a globally unique step_id to distinguish different runs (steps) of the graph computation. Subgraphs can communicate with each other (e.g. via send/recv ops) using the step_id, so that tensors produced by different runs are kept apart.
The send field of RunGraphRequest carries the tensors fed into the subgraph, and recv_key names the tensors to fetch from it. RunGraphResponse returns the list of Tensors corresponding to recv_key.

Figure 6: RunGraph
////////////////////////////////////////////////////////////////////////////////
//
// RunGraph request / response messages
//
// The worker executes all subgraphs registered under graph_handle.
// RunGraph returns after the execution finishes or an error is
// encountered.
// A sequence of RunGraphRequests with is_partial may be sent to RunGraph for
// partial graph execution.
//
////////////////////////////////////////////////////////////////////////////////
// Options specific to the execution of a single step.
message ExecutorOpts {
bool record_costs = 1;
bool record_timeline = 3;
bool record_partition_graphs = 4;
bool report_tensor_allocations_upon_oom = 5;
}
message RunGraphRequest {
// session_handle is the master-generated unique id for this session.
// If session_handle is non-empty, it must be the same as used when
// registering the graph. If it is empty, a single global namespace is used to
// search for the graph_handle.
string session_handle = 8;
// Set to true if CreateWorkerSession was called for session_handle.
bool create_worker_session_called = 10;
// REQUIRED: graph_handle must be returned by a RegisterGraph call
// to the same WorkerService.
string graph_handle = 1;
// A unique ID to distinguish different runs of the same graph.
//
// The master generates a global unique step_id to distinguish
// different runs of the graph computation. Subgraphs communicate
// (e.g., send/recv ops) with each other using step_id to
// distinguish tensors generated by different runs.
int64 step_id = 2;
// Options for this step.
ExecutorOpts exec_opts = 5;
// Runs the graph.
//
// Sends the tensors in "send" into the graph before the run and
// fetches the keys into RunGraphResponse.recv after the run.
repeated NamedTensorProto send = 3;
repeated string recv_key = 4;
// True if the RunGraphRequest is a partial run request.
bool is_partial = 6;
// True if this is the last partial run request in a sequence of requests.
bool is_last_partial_run = 7;
// If true then some errors, e.g., execution errors that have long
// error messages, may return an OK RunGraphResponse with the actual
// error saved in the status_code/status_error_message fields of the
// response body. This is a workaround since the RPC subsystem may
// truncate long metadata messages.
bool store_errors_in_response_body = 9;
// Unique identifier for this request. Every RunGraphRequest must have a
// unique request_id, and retried RunGraphRequests must have the same
// request_id. If request_id is zero, retry detection is disabled.
//
// Retried RunGraphRequests are problematic because they may issue a
// RecvTensor that will have no corresponding sender and will wait forever.
// Workers use request_ids to reject retried RunGraph requests instead of
// waiting forever.
int64 request_id = 11;
// Next: 12
}
message RunGraphResponse {
// A list of tensors corresponding to those requested by
// RunGraphRequest.recv_key.
repeated NamedTensorProto recv = 1;
// If the request asked for execution stats, the cost graph, or the partition
// graphs, these are returned here.
// TODO(suharshs): Package these in a RunMetadata instead.
StepStats step_stats = 2;
CostGraphDef cost_graph = 3;
repeated GraphDef partition_graph = 4;
// If store_errors_in_response_body is true in the request, then
// optionally the server may return an OK status for the RPC and
// fill the true status into the fields below, to allow for messages
// that are too long to fit in metadata.
error.Code status_code = 5;
string status_error_message = 6;
}
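A small, self-contained sketch follows (hypothetical DemoRunGraphRequest, not the real proto) of the step_id discipline described above: the master draws one globally unique step_id per RunStep and stamps it on every partition's request, so tensors exchanged between subgraphs of different runs never collide.
#include <cstdint>
#include <iostream>
#include <string>
#include <vector>
// Hypothetical stand-in for RunGraphRequest, keeping only the identifiers.
struct DemoRunGraphRequest {
  std::string session_handle;
  std::string graph_handle;
  int64_t step_id;
};
int main() {
  int64_t next_step_id = 0;
  const std::string session_handle = "session_0";
  // graph_handles previously returned by RegisterGraph, one per partition.
  std::vector<std::string> partition_handles = {"0000000000000001", "0000000000000002"};
  const int64_t step_id = ++next_step_id;  // one id per RunStep, shared by all partitions
  std::vector<DemoRunGraphRequest> requests;
  for (const auto& graph_handle : partition_handles) {
    requests.push_back({session_handle, graph_handle, step_id});
  }
  for (const auto& req : requests) {
    std::cout << "RunGraph(" << req.session_handle << ", " << req.graph_handle
              << ", step " << req.step_id << ")\n";  // would be RunGraphAsync in TensorFlow
  }
  return 0;
}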
RecvTensor
During execution, two Workers may exchange data. The producer simply places the ready tensor into a rendezvous; the consumer actively issues a RecvTensorRequest, in which step_id identifies the step and rendezvous_key identifies the channel over which the tensor is to be received.
One RecvTensor request retrieves exactly one tensor from a channel, but multiple tensors can be sent and received over the same channel with multiple RecvTensor requests. In the end, the producer's tensor is returned to the consumer via RecvTensorResponse.

Figure 7: RecvTensor
////////////////////////////////////////////////////////////////////////////////
//
// RecvTensor method request/response messages
//
////////////////////////////////////////////////////////////////////////////////
message RecvTensorRequest {
// The step in which the tensor will be produced.
//
// REQUIRED: This must eventually correspond to the step_id passed
// into a RunGraph call on the same WorkerService.
int64 step_id = 1;
// A key identifying the channel to receive tensors from. A RecvTensor request
// retrieves one tensor from the channel, but multiple tensors can be sent and
// received over the same channel with multiple RecvTensor requests. See
// rendezvous.h for details.
string rendezvous_key = 2;
// If true, use an out-of-band DMA mechanism to transfer the
// received tensor.
bool dma_ok = 3;
// Optional information on client-side device locality.
DeviceLocality client_locality = 4;
// Optional information on server-side device locality.
DeviceLocality server_locality = 5;
// Optional information needed by the RPC subsystem.
google.protobuf.Any transport_options = 6;
// Unique identifier for this request. Every RecvTensorRequest must have a
// unique request_id, and retried RecvTensorRequests must have the same
// request_id. If request_id is zero, retry detection and response cache
// are disabled.
//
// Retried RecvTensorRequests are problematic because a RecvTensor with no
// corresponding sender will wait forever, and the tensor may have been
// delivered to a previous retry. Workers use request_ids to reject retried
// RecvTensor requests instead of waiting forever.
int64 request_id = 7;
}
message RecvTensorResponse {
// The tensor as a proto.
TensorProto tensor = 1;
// If true, this tensor was the output of a dead node, and the
// content is invalid.
bool is_dead = 2;
// The time at which tensor was available and started to be returned.
int64 send_start_micros = 3;
// Optional additional information about how to receive the tensor,
// e.g. in the event that RecvTensorRequest.dma_ok was true.
google.protobuf.Any transport_options = 4;
// Whether the receiver should send a MarkRecvFinishedRequest to the sender
// to ack the message.
bool require_ack = 5;
}
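Here is a minimal, self-contained sketch (hypothetical DemoRendezvous, with std::string standing in for a tensor) of the rendezvous idea behind the messages above: the producer deposits a value under a (step_id, rendezvous_key) channel, and the consumer blocks until it arrives. The real implementation lives in TensorFlow's Rendezvous classes.
#include <condition_variable>
#include <cstdint>
#include <iostream>
#include <map>
#include <mutex>
#include <string>
#include <thread>
#include <utility>
class DemoRendezvous {
 public:
  using Key = std::pair<int64_t, std::string>;  // (step_id, rendezvous_key)
  // Producer side: make the tensor available on the channel.
  void Send(int64_t step_id, const std::string& key, std::string tensor) {
    std::lock_guard<std::mutex> l(mu_);
    ready_[{step_id, key}] = std::move(tensor);
    cv_.notify_all();
  }
  // Consumer side: what a RecvTensorRequest ultimately resolves to.
  std::string Recv(int64_t step_id, const std::string& key) {
    std::unique_lock<std::mutex> l(mu_);
    Key k{step_id, key};
    cv_.wait(l, [&] { return ready_.count(k) > 0; });
    std::string tensor = std::move(ready_[k]);
    ready_.erase(k);
    return tensor;
  }
 private:
  std::mutex mu_;
  std::condition_variable cv_;
  std::map<Key, std::string> ready_;
};
int main() {
  DemoRendezvous rendezvous;
  const int64_t step_id = 42;
  const std::string key = "edge_7_tensor_a";  // hypothetical channel name
  std::thread producer([&] { rendezvous.Send(step_id, key, "tensor bytes"); });
  std::cout << "received: " << rendezvous.Recv(step_id, key) << std::endl;  // consumer pulls
  producer.join();
  return 0;
}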
4. Worker
The Worker class mainly provides the WorkerEnv and the PartialRunMgr. It can be subclassed to provide specialized implementations of particular methods for different transport mechanisms. For example, GrpcWorker specializes the RecvTensorAsync method to support more efficient gRPC data structures for handling large binary data.
class Worker : public WorkerInterface {
protected:
WorkerEnv* const env_; // Not owned.
RecentRequestIds recent_request_ids_;
private:
PartialRunMgr partial_run_mgr_;
CancellationManager cancellation_manager_;
TF_DISALLOW_COPY_AND_ASSIGN(Worker);
};
Here is one of its methods as an example; we will cover the others when we encounter them later.
void Worker::CleanupAllAsync(const CleanupAllRequest* request,
CleanupAllResponse* response,
StatusCallback done) {
std::vector<string> containers;
for (const auto& c : request->container()) containers.push_back(c);
env_->device_mgr->ClearContainers(containers);
done(Status::OK());
}
5. GrpcWorker
GrpcWorker is the remote Worker that GrpcRemoteWorker talks to, and it is the object that GrpcWorkerService calls; it implements the actual business logic. Its definition is below, where we can see that it overrides several methods.
class GrpcWorker : public Worker {
public:
GrpcWorker(WorkerEnv* env, const ConfigProto& config);
// Specialized version of RecvTensor for gRPC, which avoids a copy.
virtual void GrpcRecvTensorAsync(CallOptions* opts,
const RecvTensorRequest* request,
::grpc::ByteBuffer* response,
StatusCallback done);
void LoggingAsync(const LoggingRequest* request, LoggingResponse* response,
StatusCallback done) override;
void RecvBufAsync(CallOptions* opts, const RecvBufRequest* request,
RecvBufResponse* response, StatusCallback done) override;
void CleanupGraphAsync(const CleanupGraphRequest* request,
CleanupGraphResponse* response,
StatusCallback done) override;
WorkerEnv* env();
void EnableResponseCache();
void RemoveCacheEntryForId(int64 request_id);
private:
std::unique_ptr<GrpcResponseCache> response_cache_;
const int32 recv_buf_max_chunk_;
};
With that, we have finished introducing the Worker's static structure. The Worker's concrete functionality will be covered in detail in the later articles on Sessions.