<output id="qn6qe"></output>

    1. <output id="qn6qe"><tt id="qn6qe"></tt></output>
    2. <strike id="qn6qe"></strike>

      亚洲 日本 欧洲 欧美 视频,日韩中文字幕有码av,一本一道av中文字幕无码,国产线播放免费人成视频播放,人妻少妇偷人无码视频,日夜啪啪一区二区三区,国产尤物精品自在拍视频首页,久热这里只有精品12

      [源碼解析] Pytorch 如何實(shí)現(xiàn)后向傳播 (2)---- 引擎靜態(tài)結(jié)構(gòu)

      [源碼解析] Pytorch 如何實(shí)現(xiàn)后向傳播 (2)---- 引擎靜態(tài)結(jié)構(gòu)

      0x00 摘要

      前文最終我們提到了如下代碼就是調(diào)用引擎來(lái)進(jìn)行反向傳播,其中:

      • roots是包含有前向傳播輸出節(jié)點(diǎn)的 gradient_edge()(即輸出節(jié)點(diǎn)的(grad_fn_, 0))的 vector,也就是edge_list。
      • inputs 是前向傳播產(chǎn)生的梯度,如果沒(méi)有配置,則初始化為(tensor(1.),)。
      • outputs 是依據(jù)前向傳播輸入節(jié)點(diǎn)構(gòu)建的后向傳播輸出邊,這些邊是(Function, input number) pair。
       Engine::execute(roots, inputs, keep_graph, create_graph, accumulate_grad, outputs);
      

      結(jié)合Engine定義,我們可以一一把這些輸入與 execute 的參數(shù)對(duì)應(yīng)起來(lái)。

      auto Engine::execute(const edge_list& roots, // 反向傳播的根節(jié)點(diǎn)
                           const variable_list& inputs, // 根節(jié)點(diǎn)的梯度
                           bool keep_graph, // 計(jì)算圖是否需要保留
                           bool create_graph, // 是否需要構(gòu)建微分圖以進(jìn)行高階求導(dǎo)
                           bool accumulate_grad,
                           const edge_list& outputs // 需要輸出梯度的節(jié)點(diǎn)
                          ) 
      

      所以本文我們首先從靜態(tài)角度來(lái)看引擎,就是看看其數(shù)據(jù)結(jié)構(gòu)和靜態(tài)性質(zhì)。

      系列前幾篇鏈接如下:

      深度學(xué)習(xí)利器之自動(dòng)微分(1)

      深度學(xué)習(xí)利器之自動(dòng)微分(2)

      [源碼解析]深度學(xué)習(xí)利器之自動(dòng)微分(3) --- 示例解讀

      [源碼解析]PyTorch如何實(shí)現(xiàn)前向傳播(1) --- 基礎(chǔ)類(lèi)(上)

      [源碼解析]PyTorch如何實(shí)現(xiàn)前向傳播(2) --- 基礎(chǔ)類(lèi)(下)

      [源碼解析] PyTorch如何實(shí)現(xiàn)前向傳播(3) --- 具體實(shí)現(xiàn)

      [源碼解析] Pytorch 如何實(shí)現(xiàn)后向傳播 (1)---- 調(diào)用引擎

      0x01 Engine

      Engine 是autograd的核心,其實(shí)現(xiàn)了后向傳播。后向傳播方向是從根節(jié)點(diǎn)(就是正向傳播的輸出)到輸出(就是正向傳播的輸入),在后向傳播過(guò)程之中依據(jù)前向傳播過(guò)程中設(shè)置的依賴(lài)關(guān)系生成了動(dòng)態(tài)計(jì)算圖。

      Engine 入口 是execute函數(shù),其邏輯如下:

      • 根據(jù)根節(jié)點(diǎn) roots 構(gòu)建GraphRoot。
      • 根據(jù) roots 之中的Node實(shí)例 metadata 以及各層之間的關(guān)系來(lái)構(gòu)建計(jì)算圖。
        • 通過(guò)next_edge不斷的找到指向的下一個(gè)Edge,最終完成整個(gè)計(jì)算圖的計(jì)算。
        • 利用 Queue 來(lái)多線(xiàn)程完成反向計(jì)算的工作。

      引擎定義在:torch/csrc/autograd/engine.cpp,這里只給出成員變量,最主要的變量是:

      • device_ready_queues_ :ReadyQueue 列表 device_ready_queues_ 之中的每一個(gè)ReadyQueue都啟動(dòng)了一個(gè)工作線(xiàn)程。各個(gè)線(xiàn)程之間通過(guò) device_ready_queues_ 來(lái)進(jìn)行交互。注意,因?yàn)镃PU線(xiàn)程會(huì)處理其調(diào)用的反向傳播的CPU相關(guān)工作,所以每個(gè) GraphTask 擁有自己的 cpu_ready_queue_,用戶(hù)可以向這些 cpu_ready_queue_ 發(fā)送待處理的工作。
      • thread_pool_shared_ :線(xiàn)程池,用來(lái)多線(xiàn)程處理后向傳播。

      具體代碼是:

      // A single instance of this struct should be created through the whole process lifetime.
      // The worker thread creation logic and Engine's destructor rely on this.
      struct TORCH_API Engine {
      
        // Ensures device_ready_queues_ are initialized only once
        std::once_flag start_device_threads_flag_;
        // Safe to read device_ready_queues_ without synchronization after initialization
        std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;
      
        std::vector<std::function<void()>> final_callbacks_;
        // To protect reads and writes to final_callbacks_
        std::mutex post_callbacks_lock_;
      
        // How many nested reentrant calls are allowed until a new thread is used
        int max_recursion_depth_;
      
        struct ThreadPoolShared {
          // Data structures used by the threads for executing reentrant backwards
          // tasks. See Note [Reentrant backwards]
          // Number of available threads for processing new GraphTasks.
          unsigned int num_workers_;
          // The threads will wait on work_ to be notified of GraphTasks
          std::condition_variable work_;
          // To protect reads and writes to graphtask_queue_ and num_workers_
          // and for synchronizing creating new threads when needed
          std::mutex mutex_;
          // Workers will process the GraphTasks added to this queue. A GraphTask is
          // allocated inside Engine::execute and lives for the duration of execute
          std::queue<std::weak_ptr<GraphTask>> graphtasks_queue_;
      
          ThreadPoolShared() : num_workers_(0) {}
       };
      
       // Temporary workaround until shutting down threads is done
       // We need shared ownership of all these objects because the threads are leaked
       // when Engine shuts down, so there may be threads waiting on work_
       // for the graphtasks_queue_ to be nonempty.
       std::shared_ptr<ThreadPoolShared> thread_pool_shared_;
      
      private:
        // Number of non-reentrant threads
        std::atomic<uint32_t> non_reentrant_device_thread_count_;
        // Destructor will wait for non-reentrant threads to finish
        std::condition_variable non_reentrant_device_thread_condvar_;
        std::mutex non_reentrant_device_thread_mutex_;
        // stop() must be called before the destruction path goes down to the base
        // class, in order to avoid a data-race-on-vptr. Use this boolean to guard
        // whether stop() has already been called, so we can call this in every
        // destructor of the class hierarchy.
        bool stopped_{false};
      };
      

      我們接下來(lái)就先介紹各種基礎(chǔ)類(lèi),每個(gè)類(lèi)我們力爭(zhēng)結(jié)合其使用代碼來(lái)分析。

      0x02 GraphRoot

      GraphRoot 是一個(gè)Node類(lèi)型,Node其實(shí)就是原來(lái)的Function類(lèi)。

      struct TORCH_API GraphRoot : public Node {
          
        GraphRoot(edge_list functions, variable_list inputs)
            : Node(std::move(functions)),
            outputs(std::move(inputs)) { // 把輸入的 input 配置給 outputs 成員變量。
          // Ensures calls to stream() on a GraphRoot instance reflect current stream(s)
          // on devices of root grad tensors at the time the instance is constructed.
          for (const auto& t : outputs) {
            add_input_metadata(t);
          }
        }
      
        variable_list apply(variable_list&& inputs) override {
          return outputs; // apply 方法僅僅返回它的輸入,就是梯度。Node 的其他派生類(lèi)會(huì)有自己不同的實(shí)現(xiàn)。
        }
      
        variable_list outputs; // 梯度。其只是通過(guò) apply() 來(lái)進(jìn)行使用,就是 apply 方法返回這個(gè)outputs。
      };
      
      struct TORCH_API Identity : public Node {
        variable_list apply(variable_list&& inputs) override;
      };
      

      2.1 構(gòu)建

      在 engine 之中,是用如下代碼構(gòu)建 GraphRoot。結(jié)合 execute 的調(diào)用方式,我們知道是使用 反向傳播的根節(jié)點(diǎn)(起始點(diǎn))和 根節(jié)點(diǎn)的梯度 inputs 來(lái)構(gòu)建 GraphRoot。

        // If we receive a single root, skip creating extra root node
        bool skip_dummy_node = roots.size() == 1;
        auto graph_root = skip_dummy_node ?
          roots.at(0).function :
          std::make_shared<GraphRoot>(roots, inputs); 
      

      我們?cè)倩貞浺幌?GraphRoot 之中的 Node這個(gè)基類(lèi)被如何構(gòu)建。可以看到 GraphRoot 就是使用邊列表構(gòu)建了基類(lèi) Node,反向傳播的根節(jié)點(diǎn) roots 就是 GraphRoot(Node)相關(guān)聯(lián)的邊,然后 GraphRoot 本身新增了成員變量 variable_list outputs(就是輸入 input 參數(shù))。

        explicit Node(edge_list&& next_edges = edge_list())
          : Node(/*sequence_nr=*/at::sequence_number::get_and_increment(),
          std::move(next_edges)) {}
      

      具體如下:

      +------------------------------------+
      | GraphRoot                          |
      |                                    |
      |   variable_list outputs +--------------->  inputs  梯度,被透?jìng)鹘o下游
      |                                    |
      |                                    |
      |   +----------------------------+   |
      |   | Node                       |   |
      |   |                            |   |
      |   |                            |   |
      |   |   edge_list next_edges_ +----------->  roots   起始點(diǎn)
      |   |                            |   |
      |   +----------------------------+   |
      |                                    |
      +------------------------------------+
      

      2.2 作用

      GraphRoot 的作用是:

      • GraphRoot 就是后向傳播的輸入,就是根節(jié)點(diǎn)。
      • 在構(gòu)造 graph_root 時(shí)候:
        • 如果只有一個(gè)root節(jié)點(diǎn),則就直接使用root作為 GraphRoot 。
        • 如果多個(gè)root,就構(gòu)造一個(gè)GraphRoot(可以認(rèn)為是虛擬根節(jié)點(diǎn)),把這些 root 作為參數(shù)構(gòu)建一個(gè)GraphRoot,這個(gè) GraphRoot 作為真正的根節(jié)點(diǎn)。root 就是 Node 的邊。
      • 從初始化函數(shù)可以看出來(lái),引擎的輸入inputs(反向傳播的輸入梯度)就是GraphRoot的輸出 outputs
      • Function 的靈魂是 apply 方法,對(duì)于 GraphRoot 來(lái)說(shuō),其apply函數(shù)僅僅返回它的輸入,這樣,原始輸入 input 就直接被 GraphRoot 透?jìng)鹘o反向傳播的下一階段
      • 后續(xù)計(jì)算 compute_dependencies 會(huì)用這個(gè) GraphRoot 來(lái)得到計(jì)算圖的依賴(lài)關(guān)系,就是利用 GraphRoot 的 next_edges_ 來(lái)得到計(jì)算圖的依賴(lài)關(guān)系。
        // If we receive a single root, skip creating extra root node
        bool skip_dummy_node = roots.size() == 1;
        auto graph_root = skip_dummy_node ?
          roots.at(0).function : // 如果只有一個(gè)root,就直接使用root作為 GraphRoot 
          std::make_shared<GraphRoot>(roots, inputs); // 如果多個(gè)root,就構(gòu)造一個(gè)GraphRoot
      
        auto min_topo_nr = compute_min_topological_nr(outputs);
        // Now compute the dependencies for all executable functions
        compute_dependencies(graph_root.get(), *graph_task, min_topo_nr);
      
      

      0x03 GraphTask

      我們先給出一個(gè)基本概念。GraphTask 實(shí)例代表一個(gè)動(dòng)態(tài)圖級(jí)別的資源管理對(duì)象,其擁有一次反向傳播執(zhí)行所需要的全部元數(shù)據(jù),比如計(jì)算圖中所有Node的依賴(lài)關(guān)系,還沒(méi)有準(zhǔn)備好Node的等待隊(duì)列等等。如果允許重入反向傳播,則會(huì)有多個(gè)GraphTask一起工作。

      3.1 定義

      GraphTask 其主要成員變量如下:

      • outstanding_tasks_ :用來(lái)記錄當(dāng)前任務(wù)數(shù)目,如果數(shù)目為0,則說(shuō)明任務(wù)結(jié)束了。 如果這個(gè)數(shù)量不為0,則此GraphTask依然需要運(yùn)行。
      • dependencies_ :用來(lái)判斷后續(xù)節(jié)點(diǎn)是否已經(jīng)可以被執(zhí)行。
      • not_ready_ :存儲(chǔ)沒(méi)有完成的function和其輸入。
      • grad_mode_ :是否需要進(jìn)行梯度計(jì)算。反向計(jì)算期間執(zhí)行的代碼邏輯依靠AutoGradMode::is_enabled() 來(lái)判斷當(dāng)前是否是要計(jì)算grad。
      • owner : GraphTask 所屬線(xiàn)程的Device 數(shù)值。GraphTask是在哪個(gè)線(xiàn)程中創(chuàng)建的,該值就是那個(gè)線(xiàn)程中的worker_device的值。
      • cpu_ready_queue_ :
        • CPU線(xiàn)程專(zhuān)用于處理反向傳播之中的CPU相關(guān)工作。因此所有Graph task都會(huì)維護(hù)自己的cpu_ready_queue_,CPU相關(guān)任務(wù)應(yīng)該將發(fā)送到該隊(duì)列。
        • 對(duì)于每個(gè)GraphTask,我們維護(hù)cpu_ready_queue_,這樣在設(shè)備線(xiàn)程(即GPU)上執(zhí)行時(shí),如果是下一個(gè)NodeTask 應(yīng)該在CPU上運(yùn)行,我們就知道應(yīng)該推送 NodeTask 到哪個(gè)就緒隊(duì)列。
      • mutex_ :保護(hù)如下變量:not_ready_, dependencies_, captured_vars,has_error_, future_result_, cpu_ready_queue_, and leaf_streams
      • keep_graph :用來(lái)指定一次反向計(jì)算后是否釋放資源。

      具體定義如下,這里只給出成員變量:

      // GraphTask holds metadata needed for a single execution of backward()
      struct GraphTask: std::enable_shared_from_this<GraphTask> {
        std::atomic<uint64_t> outstanding_tasks_{0};
        // Indicates if an error occurred while executing any task.  When this is
        // true, it signals all threads to stop executing.
        std::atomic_bool has_error_{false};
        std::atomic_bool future_completed_{false};
        // It is safe to read grad_mode_ and keep_graph_ without synchronization
        bool keep_graph_;
        bool grad_mode_;
      
        // To protect reads/writes to not_ready_, dependencies_, captured_vars_,
        // has_error_, future_result_, cpu_ready_queue_, and leaf_streams.
        std::mutex mutex_;
        std::unordered_map<Node*, InputBuffer> not_ready_;
        std::unordered_map<Node*, int> dependencies_;
      
        struct ExecInfo {
          struct Capture {
            Capture(const Capture&) = delete;
            Capture(Capture&&) = default;
      
            Capture(int input_idx, int output_idx)
                : input_idx_(input_idx), output_idx_(output_idx) {}
            int input_idx_; // within Node inputs
            int output_idx_; // within the output vector of a GraphTask
      
            // This hook will be executed after a grad is captured. The captured
            // grad will be replaced by the return value of the hook.
            struct GradCaptureHook {
              virtual ~GradCaptureHook() = default;
              virtual at::Tensor operator()(const at::Tensor& grad) = 0;
            };
            // The hooks will be called one by one in the order as they were added.
            // The input grad of a hook will be the output of its preceding hook. The
            // first hook will take the captured grad as the input. The output of the
            // last hook will replace the captured grad.
            std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
          };
      
          bool should_execute() const {
            return needed_ || captures_;
          }
      
          bool needed_ = false;
          std::unique_ptr<std::vector<Capture>> captures_;
        };
        // Exec info has a bit complicated semantics. If it's empty, it means the task
        // is run in a "default" mode, which means that all next_edges we encounter
        // should get executed. If it's not empty, only functions that have an entry
        // and this entry has needed == True should be executed. exec_info is only empty
        // when the graph is executed via .backward() and the inputs parameter is not passed.
        // Otherwise, when executed through .grad(), or when inputs arg is specified for
        // .backward(), exec_info will be non-empty.
        //
        // exec_info_ is safe to read without synchronization
        std::unordered_map<Node*, ExecInfo> exec_info_;
        // Captures variables are grads captured that we return to the user. After
        // execution of the GraphTask is completed, the captured_vars_ are moved
        // out of the GraphTask and are no longer valid.
        std::vector<Variable> captured_vars_;
      
        at::ThreadLocalState thread_locals_ =
            at::ThreadLocalState(/* keep_grad_mode */ false);
      
        std::unordered_set<c10::Stream> leaf_streams;
      
        // The value of worker_device in the thread that created this task.
        // See Note [Reentrant backwards]
        // Safe to read owner_ and reentrant_depth_ without synchronizaton
        int owner_;
        // The number of parent graph tasks for this graph task
        const int reentrant_depth_;
      
        // Whether or not to stop execution for this GraphTask when an error is
        // encountered. When set to true, this would cause Engine::execute() to throw
        // an exception as soon as the autograd engine receives an exception.
        bool exit_on_error_;
      
        // CPU threads are dedicated to processing CPU work for the backward they invoked.
        // So any given graph task maintains its own cpu_ready_queue_ where you should send
        // work for it to be done. We memoize the cpu_ready_queue_ per GraphTask so that
        // we know which ready queue we should push to if we are on device thread (i.e. GPU)
        // and but next NodeTask should be run on CPU.
        std::shared_ptr<ReadyQueue> cpu_ready_queue_;
      
        // Future representing the completion of the graph task. Notified when all
        // tasks are done.
        std::shared_ptr<at::ivalue::Future> future_result_;
      
        // Final callbacks installed during execution of this GraphTask
        std::vector<std::function<void()>> final_callbacks_;
        // To protect reads and writes to final_callbacks_. Intentionally no reusing
        // mutex_ as the two are protecting different data structures.
        std::mutex final_callbacks_lock_;
      };
      
      

      我們接下來(lái)看看一些重要成員變量。

      3.2 outstanding_tasks_

      是待處理 NodeTask的數(shù)量,用來(lái)判斷該GrapTask是否還需要執(zhí)行,其數(shù)值總是先加再減,如果數(shù)目為0,則說(shuō)明任務(wù)結(jié)束了。

      • 當(dāng) GraphTask 被創(chuàng)建出來(lái)時(shí)候,此數(shù)值為0。
      • 如果有一個(gè)NodeTask被送入到 ReadyQueue,則outstanding_tasks_ 增加 1。
      • 如果在工作線(xiàn)程作執(zhí)行一次 evaluate_function(task)后,outstanding_tasks的值減1。
      • 如果這個(gè)數(shù)量不為0,則此GraphTask依然需要運(yùn)行。

      3.2.1 任務(wù)結(jié)束

      以下代碼用來(lái)判斷GraphTask是否結(jié)束。

      bool GraphTask::completed() {
        return outstanding_tasks_.load() == 0 ||
            (exit_on_error_ && has_error_.load());
      }
      

      3.2.2 增加

      NodeTask任務(wù)增加時(shí) outstanding_tasks_ 就加一。即,往某一個(gè) ReadyQueue 之中插入一個(gè) NodeTask 時(shí)候, NodeTask 對(duì)應(yīng)的GraphTask 就會(huì)把其 outstanding_tasks_ 增加一。

      auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
        {
          // Lock mutex for writing to heap_
          std::lock_guard<std::mutex> lock(mutex_);
          if (incrementOutstandingTasks) {
            std::shared_ptr<GraphTask> graph_task = item.base_.lock();
            ++graph_task->outstanding_tasks_; // 增加
          }
          heap_.push(std::move(item));
        }
        not_empty_.notify_one();
      }
      

      3.2.3 遞減

      NodeTask 任務(wù)結(jié)束時(shí)候就減一,我們用簡(jiǎn)化代碼看看。

      auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
      
        while (graph_task == nullptr || !graph_task->future_result_->completed()) { //運(yùn)行 GraphTask
      
          std::shared_ptr<GraphTask> local_graph_task;
          {
            NodeTask task = local_ready_queue->pop();
      
            if (task.fn_ && !local_graph_task->has_error_.load()) {
              // 運(yùn)行 NodeTask
              evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
            }
          }
      
          // Decrement the outstanding tasks.
          --local_graph_task->outstanding_tasks_; // 運(yùn)行 NodeTask完畢,這里減一
      
          // Check if we've completed execution.
          if (local_graph_task->completed()) { // 判斷 GraphTask是否結(jié)束。
            // 做相關(guān)處理工作
          }
        }
      }
      

      3.3 keep_graph

      keep_graph 用來(lái)指定一次反向計(jì)算后是否釋放資源。資源就是在前向過(guò)程中建立起來(lái)的資源。keep_graph如果是False的話(huà),則會(huì)在 fn 執(zhí)行完畢后調(diào)用 fn 的 will_release_variables 方法來(lái)釋放該資源。

      當(dāng)執(zhí)行反向傳播時(shí)候,在 void Engine::evaluate_function 會(huì)調(diào)用

      auto outputs = call_function(graph_task, func, inputs);
      

      在 call_function 之中,如果發(fā)現(xiàn)不需要保持圖,就釋放資源。

      static variable_list call_function(
          std::shared_ptr<GraphTask>& graph_task,
          Node* func,
          InputBuffer& inputBuffer) {
        CheckpointValidGuard cpvguard(graph_task);
        auto& fn = *func;
        auto inputs =
            call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));
      
        if (!graph_task->keep_graph_) {
          fn.will_release_variables(); // 如果不需要保持圖,就調(diào)用釋放。
        }
        // 省略其他
      }
      

      3.4 dependencies_

      dependencies 用來(lái)判斷后續(xù)節(jié)點(diǎn)是否已經(jīng)可以被執(zhí)行,其類(lèi)型如下:

      std::unordered_map<Node*, int> dependencies_;
      

      dependencies成員在compute_dependencies調(diào)用中被初始化,只要一個(gè)grad_fn函數(shù)在別人的next_edges()中出現(xiàn)過(guò)一次,那么dependencies[this_grad_fn] 就自增1。如果dependencies[this_grad_fn]大于0,說(shuō)明this_grad_fn有一個(gè)后向傳播的依賴(lài),即this_grad_fn需要等被依賴(lài)者完成,才能進(jìn)行反向傳播。

      比如如下計(jì)算圖:

      # MulBackward0 被 SubBackward0 的next_edges引用 1 次,即 MulBackward0 需要等 SubBackward0 反向計(jì)算完成之后,才能進(jìn)行自己的反向傳播
      dependencies[MulBackward0] = 1
      
      #PowBackward0-1 被 MulBackward0 的next_edges用1次
      dependencies[PowBackward0-1] = 1
      
      #PowBackward0-2 被 MulBackward0 的next_edges用1次
      dependencies[PowBackward0-2] = 1
      

      我們結(jié)合具體代碼(刪除無(wú)關(guān)代碼)看看。

      void Engine::evaluate_function(
          std::shared_ptr<GraphTask>& graph_task,
          Node* func,
          InputBuffer& inputs,
          const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
      
        // 執(zhí)行后向計(jì)算
        auto outputs = call_function(graph_task, func, inputs);
      
        std::lock_guard<std::mutex> lock(graph_task->mutex_);
        for (int i = 0; i < num_outputs; ++i) { // 遍歷自己的輸出
          auto& output = outputs[i];
          const auto& next = fn.next_edge(i); // 找到第i個(gè)輸出
      
          // Check if the next function is ready to be computed
          bool is_ready = false;
            
          // 得到依賴(lài)關(guān)系  
          auto& dependencies = graph_task->dependencies_;
          auto it = dependencies.find(next.function.get()); // 找到第i個(gè)輸出的依賴(lài)關(guān)系
      
          if (it == dependencies.end()) {
            auto name = next.function->name();
            throw std::runtime_error(std::string("dependency not found for ") + name);
          } else if (--it->second == 0) { // 因?yàn)楸竟?jié)點(diǎn)的后向計(jì)算已經(jīng)完成,所以第i個(gè)輸出的依賴(lài)數(shù)目減一
            dependencies.erase(it); // 如果為0,說(shuō)明沒(méi)有依賴(lài)了,就從依賴(lài)關(guān)系之中刪除
            is_ready = true; // true 代表沒(méi)有依賴(lài)關(guān)系,可以構(gòu)建一個(gè) NodeTask 進(jìn)行下一步反向計(jì)算了
          }
        }
      }
      

      3.5 not_ready_

      用來(lái)暫存未就緒的function及其輸入,類(lèi)型如下:

      std::unordered_map<Node*, InputBuffer> not_ready_;
      

      not_ready_ 是針對(duì)未就緒節(jié)點(diǎn)和其輸入的map,假設(shè)某節(jié)點(diǎn) A 在反向傳播路徑上有兩個(gè)輸入,當(dāng)?shù)谝粋€(gè)輸入完成時(shí)候,因?yàn)榈诙€(gè)輸入沒(méi)有完成反向計(jì)算,所以需要有一個(gè)地方暫存這個(gè) A 和 其第一個(gè)輸入以備后續(xù)處理。not_ready_ 就是用來(lái)做這個(gè)的。

      not_ready_ 的 key 是未就緒節(jié)點(diǎn),value 是這個(gè)節(jié)點(diǎn)目前就緒的輸入列表。

      • 第一次遇到某節(jié)點(diǎn)的一個(gè)輸入之后,就把 (節(jié)點(diǎn) A, A 的輸入信息 ) 放入 not_ready_ 這里,得到 (節(jié)點(diǎn) A, [A 的輸入信息 1 ] )

      • 后續(xù)遇到 A 的其他輸入,就繼續(xù)調(diào)整這里,把 A 的其他輸入加入到 "A 的輸入信息" 之中,比如得到 (節(jié)點(diǎn) A, [A 的輸入信息 1,A的輸入信息 2 ] )

      • 如果 此時(shí) A 已經(jīng) ready,就把 A 和其輸入信息 放入 一個(gè) Ready Queue,然后從 not_ready_ 移除 節(jié)點(diǎn) A。

      • 如果 A 還沒(méi)有 ready(A還需要其他輸出),就繼續(xù)維持 not_ready_ 的狀態(tài),把目前 A 輸入都加入到 not_ready_ 里面。

      我們結(jié)合代碼看看。

          auto& not_ready = graph_task->not_ready_;
          auto not_ready_it = not_ready.find(next.function.get());
          if (not_ready_it == not_ready.end()) { // 如果未就緒隊(duì)列之中沒(méi)有next節(jié)點(diǎn)
            // Skip functions that aren't supposed to be executed
            if (!exec_info_.empty()) {
              auto it = exec_info_.find(next.function.get());
              if (it == exec_info_.end() || !it->second.should_execute()) {
                continue;
              }
            }
            // No buffers have been allocated for the function
            InputBuffer input_buffer(next.function->num_inputs()); // 整理 next 節(jié)點(diǎn)的輸入?yún)?shù)信息
      
            // Accumulates into buffer
            const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
            input_buffer.add(next.input_nr, // 插入 next 節(jié)點(diǎn)的輸入?yún)?shù)信息
                             std::move(output),
                             opt_parent_stream,
                             opt_next_stream);
      
            if (is_ready) { // is_ready 是前面小節(jié)之中,通過(guò)依賴(lài)關(guān)系計(jì)算出來(lái)的,true表示可以進(jìn)行反向計(jì)算了
              auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
              queue->push(
                  NodeTask(graph_task, next.function, std::move(input_buffer)));
            } else {
              // 還有依賴(lài)關(guān)系,不能進(jìn)行反向計(jì)算,只能放入未就緒隊(duì)列 not_ready_ 
              not_ready.emplace(next.function.get(), std::move(input_buffer));
            }
          } else { // 如果未就緒隊(duì)列之中已經(jīng)有next節(jié)點(diǎn)
              
            // The function already has a buffer
            auto &input_buffer = not_ready_it->second;
      
            // Accumulates into buffer
            const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
            input_buffer.add(next.input_nr, // 把最新完成反向計(jì)算的輸入插入輸入buffer input_buffer
                             std::move(output),
                             opt_parent_stream,
                             opt_next_stream);
            if (is_ready) { // 如果可以計(jì)算,就放入ready 隊(duì)列
              auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
              queue->push(
                  NodeTask(graph_task, next.function, std::move(input_buffer)));
              not_ready.erase(not_ready_it); // 同時(shí)從未就緒隊(duì)列之中移除
            }
          }
      

      3.6 exec_info_

      ExecInfo 主要作用就是判斷是否需要執(zhí)行,并且注冊(cè)了一個(gè)hook,用來(lái)在計(jì)算梯度時(shí)候做調(diào)用。

      3.6.1 定義

      定義如下:

      struct ExecInfo {
        struct Capture {
          Capture(const Capture&) = delete;
          Capture(Capture&&) = default;
      
          Capture(int input_idx, int output_idx)
              : input_idx_(input_idx), output_idx_(output_idx) {}
          int input_idx_; // within Node inputs
          int output_idx_; // within the output vector of a GraphTask
      
          // This hook will be executed after a grad is captured. The captured
          // grad will be replaced by the return value of the hook.
          struct GradCaptureHook {
            virtual ~GradCaptureHook() = default;
            virtual at::Tensor operator()(const at::Tensor& grad) = 0;
          };
          // The hooks will be called one by one in the order as they were added.
          // The input grad of a hook will be the output of its preceding hook. The
          // first hook will take the captured grad as the input. The output of the
          // last hook will replace the captured grad.
          std::vector<std::unique_ptr<GradCaptureHook>> hooks_;
        };
      
        bool should_execute() const {
          return needed_ || captures_;
        }
      
        bool needed_ = false;
        std::unique_ptr<std::vector<Capture>> captures_;
      };
      
      

      引擎之中生成如下成員變量。

      // Exec info has a bit complicated semantics. If it's empty, it means the task
      // is run in a "default" mode, which means that all next_edges we encounter
      // should get executed. If it's not empty, only functions that have an entry
      // and this entry has needed == True should be executed. exec_info is only empty
      // when the graph is executed via .backward() and the inputs parameter is not passed.
      // Otherwise, when executed through .grad(), or when inputs arg is specified for
      // .backward(), exec_info will be non-empty.
      //
      // exec_info_ is safe to read without synchronization
      std::unordered_map<Node*, ExecInfo> exec_info_;
      

      3.6.2 作用

      exec_info_ 的作用就是給 GraphTask 的每一個(gè)Node配置一個(gè)ExecInfo,就是執(zhí)行信息。

      • 如果exec_info_為空,說(shuō)明該task運(yùn)行在默認(rèn)模式,即,所有遇到的 next_edges 都需要執(zhí)行。

      • 如果 exec_info_ 非空,說(shuō)明只有特定 functions 才會(huì)被執(zhí)行,這些 Functions 的特點(diǎn)是:擁有 entry,并且這個(gè) entry 的 “has needed == True”。

      exec_info_ 何時(shí)為空?何時(shí)非空?

      • 當(dāng)圖被用 .backward() 執(zhí)行,并且沒(méi)有傳遞輸入?yún)?shù),則 exec_info 為空,就是全部執(zhí)行。
      • 如果只是使用用 .grad() 執(zhí)行,或者使用.backward() 執(zhí)行時(shí)候并且給定輸入?yún)?shù),那么 exec_info_ 非空

      所以,exec 和 captured_vars_ 就是針對(duì) grad() 和指定參數(shù)的 backward(),就是標(biāo)注在這種情況下需要計(jì)算哪些梯度。在這種情況下,只有某些節(jié)點(diǎn)需要執(zhí)行,從這些節(jié)點(diǎn)開(kāi)始,有一條路徑通向 outpus

      3.6.3 生成

      在 Engine::execute 之中會(huì)調(diào)用 init_to_execute 生成ExecInfo。

      if (!outputs.empty()) {
        graph_task->init_to_execute(*graph_root, outputs, accumulate_grad, min_topo_nr);
      }
      

      邏輯是:

      Populates exec_info so nodes that should be executed have `exec_info[node].needed_ = true` Only nodes that have a path to any edge in `outputs` should be executed.The code below populates exec_info using recursion, but the actual code does this iteratively. Refer to the numbering to see how the actual code corresponds.A difference to note is that in the iterative version, when you are working with the current Node, you are reponsible to update your parent's is_needed after all your children have been updated.
      

      從其注釋可知,其作用是:填充exec_info,以便應(yīng)執(zhí)行的節(jié)點(diǎn)具有exec_info[node].needed_ = true

      只具特定節(jié)點(diǎn)才應(yīng)該執(zhí)行,這些節(jié)點(diǎn)的性質(zhì)是:節(jié)點(diǎn)擁有一條路徑,這路徑可以通往outputs的任何一條邊。

      下面的代碼使用遞歸填充exec_info,但實(shí)際代碼以迭代方式執(zhí)行此操作。關(guān)鍵代碼如下,就是插入ExecInfo信息 exec_info_.emplace(stack.back().fn_, ExecInfo());。具體刪減版代碼如下:

      void GraphTask::init_to_execute(Node& graph_root, const edge_list& outputs, bool accumulate_grad, uint64_t min_topo_nr) {
        // Populates exec_info so nodes that should be executed have `exec_info[node].needed_ = true`
        // Only nodes that have a path to any edge in `outputs` should be executed.
        // The code below populates exec_info using recursion, but the actual code does this
        // iteratively. Refer to the numbering to see how the actual code corresponds.
        // A difference to note is that in the iterative version, when you are working with
        // the current Node, you are reponsible to update your parent's is_needed after all your
        // children have been updated.
        //
        // is_needed = {fn: True for fn in outputs}             # (0)
        // seen = {}
        // def compute_is_needed(fn):
        //   for next_edge in fn.next_edges:
        //     child_fn = next_edge.fn
        //     if child_fn in seen and is_needed[child_fn]:     # (1)
        //       is_needed[fn] = true
        //     else:
        //       seen.add(child_fn)
        //       if compute_is_needed(child_fn):
        //         is_needed[fn] = true                         # (2)
        //                                                      # (3) exit for-loop
        //   return is_needed[fn]
        // compute_is_needed(graph_root)
        //
        // NB: you might be wondering why we don't populate `seen` with outputs. We cannot
        // because in the case where two outputs lie on the same path, we still need to explore past
        // the first output or we would miss the nodes that are required to compute the second output.
        
        // 這一段就是針對(duì) grad() API 進(jìn)行處理,只有在所求梯度的張量路徑上的其他張量才會(huì)被計(jì)算梯度
        int output_idx = 0;
        for (auto & output_edge : outputs) { // 遍歷輸出邊
          // (0) `is_needed` above corresponds to `exec_info_[fn].needed_`
          Node *output = output_edge.function.get();
          auto & info = exec_info_[output];
          if (accumulate_grad) {
            // if called through `.backward()` we directly set `needed_` for all the outputs to true
            info.needed_ = true;
          } else {
            if (!info.captures_) {
              info.captures_ = make_unique<std::vector<ExecInfo::Capture>>();
            }
            // 第 i 個(gè)輸入對(duì)應(yīng)的輸出
            info.captures_->emplace_back(output_edge.input_nr, output_idx++);
          }
        }
        captured_vars_.resize(output_idx);
      
        auto nodeShouldExecute = [this](Node *fn) {
          auto it = exec_info_.find(fn);
          return it != exec_info_.end() && it->second.should_execute();
        };
      
        std::vector<Frame> stack;
        std::unordered_set<Node*> seen;
        stack.emplace_back(&graph_root);
        exec_info_.emplace(stack.back().fn_, ExecInfo()); // 這里會(huì)初始化 exec_info_,有多個(gè) exec_info
      
        while (!stack.empty()) {
          auto &frame = stack.back();
          const auto fn = frame.fn_;
      
          Node *child_fn = nullptr;
          while((child_fn = frame.get_next_fn()) && !seen.emplace(child_fn).second) {
            // (1) next child exists AND has already been seen
            if (nodeShouldExecute(child_fn)) {
              exec_info_[fn].needed_ = true;
            }
          }
      
          if (child_fn) {
            // (2) next child exists but has not been seen
            if (child_fn->topological_nr() < min_topo_nr) { 
              // child created before the first output means this child cannot have
              // an edge to output
              continue;
            }
            stack.emplace_back(child_fn);
          } else {
            // (3) no next child exists for `fn` means its `needed` has already been
            // finalized. pop stack and update parent
            stack.pop_back();
            if (nodeShouldExecute(fn) && !stack.empty()) {
              exec_info_[stack.back().fn_].needed_ = true;
            }
          }
        }
      }
      

      3.6.4 GradCaptureHook

      其中,ExecInfo.Capture.GradCaptureHook 是要對(duì)梯度再做后續(xù)處理。

      但是這個(gè)使用卻是主要在分布式狀態(tài)下,是因?yàn)榉植际揭嬗幸粋€(gè)累積梯度的需要,這個(gè)必須在正常梯度操作之后的后置處理中完成

      DistEngine::computeDependencies 之中有添加操作:

          // Create a dummy GraphRoot and run init_to_execute with it.
          GraphRoot dummyRoot(edges, {});
          graphTask->init_to_execute(dummyRoot, outputEdges, /*accumulate_grad=*/false, /*min_topo_nr=*/0);
          for (auto& mapEntry : graphTask->exec_info_) {
            auto& execInfo = mapEntry.second;
            if (!execInfo.captures_) {
              continue;
            }
            auto fn = mapEntry.first;
            // There may be nodes other than 'AccumulateGrad', e.g. RecvRPCBackward,
            // to be captured.
            if (auto accumulateGradFn = dynamic_cast<AccumulateGrad*>(fn)) {
              for (auto& capture : *execInfo.captures_) {
                capture.hooks_.push_back( // 在這里添加 hook
                    std::make_unique<DistAccumulateGradCaptureHook>(
                        std::dynamic_pointer_cast<AccumulateGrad>(
                            accumulateGradFn->shared_from_this()),
                        autogradContext));
              }
            }
          }
      

      在 Engine::evaluate_function 之中有使用操作。

        auto& exec_info_ = graph_task->exec_info_;
        if (!exec_info_.empty()) {
          auto& fn_info = exec_info_.at(func);
          if (auto* capture_vec = fn_info.captures_.get()) {
            // Lock mutex for writing to graph_task->captured_vars_.
            std::lock_guard<std::mutex> lock(graph_task->mutex_);
            for (const auto& capture : *capture_vec) {
              // 獲取到 captured_vars_,然后對(duì)其進(jìn)行后置操作
              auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
              // 這里是引用操作,所以 captured_grad 的賦值實(shí)際就是往 graph_task->captured_vars_ 賦值
              captured_grad = inputs[capture.input_idx_];
              for (auto& hook : capture.hooks_) {
                captured_grad = (*hook)(captured_grad); // 這里使用了 hook 進(jìn)行后置操作
              }
            }
          }
          if (!fn_info.needed_) {
            // Skip execution if we don't need to execute the function.
            return;
          }
        }
      

      3.7 captured_vars_

      上面提到了 captured_vars_,我們因此就一并分析。

      Captures variables是我們返回給用戶(hù)的捕獲梯度。GraphTask執(zhí)行完成后,Captures variables 將移出GraphTask,不再有效。

      // Captures variables are grads captured that we return to the user. After
      // execution of the GraphTask is completed, the captured_vars_ are moved
      // out of the GraphTask and are no longer valid.
      std::vector<Variable> captured_vars_;
      

      這個(gè) captured_vars_ 是可以進(jìn)行后續(xù)處理,就是使用上面提到的GradCaptureHook 在 evaluate_function 進(jìn)行處理,具體賦值也是在 evaluate_function 其中,參見(jiàn)前面代碼之中的注釋?zhuān)覀兒笪脑敿?xì)對(duì)函數(shù)也會(huì)有分析。

      // This hook will be executed after a grad is captured. The captured
      // grad will be replaced by the return value of the hook.
      

      引擎進(jìn)行后向傳播操作,最后返回給調(diào)用者(比如Python代碼)的output結(jié)果就是 captured_vars_

      void GraphTask::mark_as_completed_and_run_post_processing() {
        // Allow only one thread one attempt to process this logic.
        if (future_completed_.exchange(true)) {
          // Future is already marked complete, or being marked as such.
          // In case the marking complete is only in progress, we add a
          // wait() to guarantee the future is marked complete on exit.
          future_result_->wait();
          return;
        }
      
        try {
          // Run post processing, before marking the future as complete.
          // Drop lock prior to completing, to avoid holding across callbacks.
          std::unique_lock<std::mutex> lock(mutex_);
      
          exec_post_processing();
          std::vector<Variable> vars = std::move(captured_vars_); //最后返回的輸出
      
          // Need to unlock before we call markCompleted to avoid holding locks
          // when the callbacks are called.
          lock.unlock();
          // NOLINTNEXTLINE(performance-move-const-arg)
          future_result_->markCompleted(std::move(vars)); // 反向傳播最后的返回輸出
        } catch (std::exception& e) {
          future_result_->setErrorIfNeeded(std::current_exception());
        }
      }
      
      

      0x04 NodeTask

      4.1 緣由

      對(duì)于NodeTask,我們有一個(gè)疑問(wèn):為什么要再增加一個(gè)新類(lèi)型?而不是繼續(xù)使用 GraphTask。

      因?yàn)?GraphTask 只是包括本計(jì)算圖的總體信息,但是具體某一個(gè)節(jié)點(diǎn)如何計(jì)算梯度,GraphTask 是不知道的,所以引入了一個(gè)新類(lèi)型 NodeTask 來(lái)處理。NodeTask 這個(gè)類(lèi)的對(duì)象正是在queue中傳輸?shù)臇|西,就是一個(gè)可以被執(zhí)行的求導(dǎo)函數(shù)。從下面的定義可以看到,我們使用GraphTask、Node、InputBuffer來(lái)構(gòu)建一個(gè)NodeTask實(shí)例,可以認(rèn)為,生產(chǎn)者不停的向 ReadyQueue 插入 NodeTask,消費(fèi)者則從 ReadyQueue 之中提取 NodeTask 進(jìn)行處理。

      4.2 定義

      NodeTask 定義如下:

      struct NodeTask {
        std::weak_ptr<GraphTask> base_; // 所屬的GraphTask
        std::shared_ptr<Node> fn_; // 需要執(zhí)行的Node,比如 PowBackward0
        // This buffer serves as an implicit "addition" node for all of the
        // gradients flowing here.  Once all the dependencies are finished, we
        // use the contents of this buffer to run the function.
        InputBuffer inputs_; // fn_的輸入
        // When worker receives a task with isShutdownTask = true, it will immediately
        // exit. The engine sends a shutdown task to every queue upon its destruction.
        bool isShutdownTask_;
      
        int getReentrantDepth() const;
      
        NodeTask(
            std::weak_ptr<GraphTask> base,
            std::shared_ptr<Node> fn,
            InputBuffer inputs,
            bool isShutdownTask = false)
            : base_(base),
              fn_(std::move(fn)),
              inputs_(std::move(inputs)),
              isShutdownTask_(isShutdownTask) {}
      };
      

      在主線(xiàn)程和工作線(xiàn)程之中都可以插入 NodeTask,我們逐一分析。

      4.3 主線(xiàn)程生產(chǎn)

      主線(xiàn)程有兩種情況會(huì)產(chǎn)生 NodeTask。

      • 剛啟動(dòng)時(shí)候,在 execute_with_graph_task 之中,主線(xiàn)程就是往 index = -1 的 CPU 工作線(xiàn)程的queue 發(fā)送一個(gè) NodeTask。
      // Now that all the non-thread safe fields of the graph_task have been populated,
      // we can enqueue it.
      // 主線(xiàn)程之中
      queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
      
      • 在 execute_with_graph_task 之中,當(dāng)有重入的反向傳播時(shí)候,也會(huì)插入 NodeTask:
          // We set the worker_device to CPU_DEVICE only if worker_device was previously
          // NO_DEVICE. Setting it to CPU afterwards allow us to detect whether this is
          // a re-entrant call or not.
          set_device(CPU_DEVICE);
      
          // set the graph_task owner to the current device
          graph_task->owner_ = worker_device;
      
          // Now that all the non-thread safe fields of the graph_task have been populated,
          // we can enqueue it.
          queue->push(NodeTask(graph_task, std::move(graph_root), std::move(input_buffer)));
      

      graph_root 的初始化我們可以回憶一下:

        auto graph_root = skip_dummy_node ?
          roots.at(0).function : // 如果只有一個(gè)root,就直接使用root作為 GraphRoot 
          std::make_shared<GraphRoot>(roots, inputs); // 如果多個(gè)root,就構(gòu)造一個(gè)GraphRoot
      

      graph_root 由roots和inputs構(gòu)建,roots就是最終輸出節(jié)點(diǎn)的gradient_edge(),比如 [ (MulBackward0實(shí)例,0),(PowerBackward0, 0) ]。inputs 如果用戶(hù)沒(méi)有指定,就是默認(rèn)的 tensor(1.),如果指定了,就是起始梯度。

      4.4 工作線(xiàn)程生產(chǎn)

      在工作線(xiàn)程 thread_main 中,可以用如下方式構(gòu)建新NodeTask實(shí)例,添加到queue中。

      4.4.1 下一可計(jì)算節(jié)點(diǎn)

      在 evaluate_function 之中,當(dāng)完成一個(gè)節(jié)點(diǎn)的反向計(jì)算之后,會(huì)查找下一個(gè)可以計(jì)算的節(jié)點(diǎn),如果找到了,就取出當(dāng)前節(jié)點(diǎn)的下一條邊,然后依據(jù)這個(gè)邊構(gòu)建一個(gè)NodeTask,放入對(duì)應(yīng)的工作線(xiàn)程(依據(jù)下一條邊的device等等信息)的 ReadyQueue。

      for (int i = 0; i < num_outputs; ++i) { // 遍歷輸入節(jié)點(diǎn)
          
            const auto& next = fn.next_edge(i); // 查找下一個(gè)可以計(jì)算的節(jié)點(diǎn)
      
            if (not_ready_it == not_ready.end()) {
            // Skip functions that aren't supposed to be executed
      
            // No buffers have been allocated for the function
            InputBuffer input_buffer(next.function->num_inputs());
      
            // Accumulates into buffer
            const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
            input_buffer.add(next.input_nr,
                             std::move(output),
                             opt_parent_stream,
                             opt_next_stream);
      
            if (is_ready) {
              auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
              // 插入下一個(gè)需要計(jì)算的NodeTask
              queue->push(
                  NodeTask(graph_task, next.function, std::move(input_buffer)));
            } else {
              not_ready.emplace(next.function.get(), std::move(input_buffer));
            }
          } else {
            // The function already has a buffer
            auto &input_buffer = not_ready_it->second;
      
            // Accumulates into buffer
            const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
            input_buffer.add(next.input_nr,
                             std::move(output),
                             opt_parent_stream,
                             opt_next_stream);
            if (is_ready) {
              auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
              // 插入下一個(gè)需要計(jì)算的NodeTask  
              queue->push(
                  NodeTask(graph_task, next.function, std::move(input_buffer)));
              not_ready.erase(not_ready_it);
            }
          }
      }    
      
      

      其中,const auto& next = fn.next_edge(i); 就是用來(lái)查找下一個(gè)節(jié)點(diǎn)。

      next_edge 代碼如下:

      const Edge& next_edge(size_t index) const noexcept {
        return next_edges_[index];
      }
      

      next_edges_ 指向的是前向圖中該Node節(jié)點(diǎn)的輸入節(jié)點(diǎn),所以在反向傳播中,就是該節(jié)點(diǎn)的輸出節(jié)點(diǎn)

      4.4.2 喚醒

      在 thread_main 之中,有一個(gè) work around。就是:當(dāng)前工作線(xiàn)程完成 graph_task,但此時(shí),擁有g(shù)raph_task的線(xiàn)程可能正在pop()上等待休眠。因此,我們需要向所屬線(xiàn)程發(fā)送一個(gè)仿造的函數(shù)任務(wù),以喚醒它,這樣我們可以退出thread_main

          // Check if we've completed execution.
          if (local_graph_task->completed()) {
            local_graph_task->mark_as_completed_and_run_post_processing();
      
            auto base_owner = local_graph_task->owner_;
            // The current worker thread finish the graph_task, but the owning thread
            // of the graph_task might be sleeping on pop() if it does not have work.
            // So we need to send a dummy function task to the owning thread just to
            // ensure that it's not sleeping, so that we can exit the thread_main.
            // If it has work, it might see that graph_task->outstanding_tasks_ == 0
            // before it gets to the task, but it's a no-op anyway.
            //
            // NB: This is not necessary if the current thread is the owning thread.
            if (worker_device != base_owner) {
              // Synchronize outstanding_tasks_ with queue mutex
              std::atomic_thread_fence(std::memory_order_release);
              ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
                  ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
            }
          }
      

      4.5 工作線(xiàn)程消費(fèi)

      首先,我們可以回憶一下graph_root 的初始化,graph_root 由roots和inputs構(gòu)建,roots就是最終輸出節(jié)點(diǎn)的gradient_edge(),比如 [ (MulBackward0實(shí)例,0),(PowerBackward0, 0) ]。inputs 如果用戶(hù)沒(méi)有指定,就是默認(rèn)的 tensor(1.)。

        auto graph_root = skip_dummy_node ?
          roots.at(0).function : // 如果只有一個(gè)root,就直接使用root作為 GraphRoot 
          std::make_shared<GraphRoot>(roots, inputs); // 如果多個(gè)root,就構(gòu)造一個(gè)GraphRoot
      

      其次,我們看看如何消費(fèi)。

      當(dāng)worker線(xiàn)程剛被創(chuàng)建出來(lái)的時(shí)候,該線(xiàn)程被阻塞在queue->pop(),就是等待生產(chǎn)者往這個(gè)queue里插入一個(gè)task。當(dāng)主線(xiàn)程往 ReadyQueue 發(fā)送了 NodeTask 實(shí)例之后,消費(fèi)端的工作線(xiàn)程在 thread_main 的 pop 結(jié)束阻塞被喚醒。

      于是worker線(xiàn)程獲取 到了NodeTask。worker線(xiàn)程 然后:

      • 通過(guò)task.base來(lái)訪(fǎng)問(wèn)到這個(gè)GraphTask實(shí)例。
      • 通過(guò) task.fn_ 訪(fǎng)問(wèn)到這個(gè)roots實(shí)例,也就是該NodeTask需要執(zhí)行的后向計(jì)算方法,比如 MulBackward0。
      • 通過(guò)task.inputs_ 來(lái)訪(fǎng)問(wèn)這個(gè)InputBuffer實(shí)例,就是 MulBackward0 的輸入。
      • 后把NodeTask 的 fn_, inputs 傳給evaluate_function。進(jìn)行反向計(jì)算。

      具體代碼如下:

        // 工作線(xiàn)程之中如何消費(fèi) NodeTask
        NodeTask task = local_ready_queue->pop();
        if (task.fn_ && !local_graph_task->has_error_.load()) {
          AutoGradMode grad_mode(local_graph_task->grad_mode_);
          try {
            GraphTaskGuard guard(local_graph_task);
            NodeGuard ndguard(task.fn_);
            // 后向計(jì)算
            evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
          } catch (std::exception& e) {
            thread_on_exception(local_graph_task, task.fn_, e);
          }
        }
      }
      

      下面是生產(chǎn)者和消費(fèi)者的圖例。

      • 1)主線(xiàn)程往CPU ReadyQueue放入一個(gè) NodeTask。
      • 2)工作線(xiàn)程 1 從 CPU ReadyQueue 取出 NodeTask,開(kāi)始執(zhí)行。
      • 3)工作線(xiàn)程 1 結(jié)束之后,往 device_ready_queues_ 的 某一個(gè) ReadyQueue 插入一個(gè) NodeTask。
      • 4)ReadyQueue 對(duì)應(yīng)的 工作線(xiàn)程 2 取出 NodeTask,開(kāi)始執(zhí)行。
      +--------------+                                                     +-----------------+
      | Main Thread  |                                                     | Worker Thread 1 |
      |              |       1         +-----------------+       2         |                 |
      |              | push(NodeTask)  |                 |  pop(NodeTask)  |                 |
      |          +-------------------> | CPU ReadyQueue  +----------------------->           |
      |              |                 |                 |                 |                 |
      |              |                 +-----------------+                 |                 |
      |              |              +----------------------+               |                 |
      |              |              | device_ready_queues_ |               |                 |
      |              |              |                      |               |                 |
      |              |              |                      |    3          |                 |
      |              |              |    +-------------+   | push(NodeTask)|                 |
      |              |              |    | ReadyQueue  | <------------------------           |
      |              |              |    +------+------+   |               |                 |
      |              |              |           |          |               |                 |
      +--------------+              |           |          |               +-----------------+
                                    |           +------------------+
                                    |                      |       |       +-----------------+
                                    |                      |       |       | Worker Thread 2 |
                                    |                      |       |       |                 |
                                    |                      |       |       |                 |
                                    |                      |       |       |                 |
                                    |    +-------------+   |       +------------->           |
                                    |    | ReadyQueue  |   | pop(NodeTask) |                 |
                                    |    +-------------+   |     4         |                 |
                                    |                      |               |                 |
                                    |                      |               |                 |
                                    |    +-------------+   |               |                 |
                                    |    | ReadyQueue  |   |               |                 |
                                    |    +-------------+   |               |                 |
                                    |                      |               |                 |
                                    +----------------------+               +-----------------+
      
      

      0x05 InputBuffer

      因?yàn)橛械墓?jié)點(diǎn)在反向計(jì)算時(shí)候,有多個(gè)輸入,所以在計(jì)算梯度的時(shí)候, grad_fn 的 輸入可能從 很多條路徑上累積過(guò)來(lái),InputBuffer 就是用來(lái)累積 grad_fn 的輸入

      struct InputBuffer {
          // size 表示有幾個(gè)輸入
        explicit InputBuffer(size_t size)
          : buffer(size) {}
        InputBuffer(const InputBuffer& other) = delete;
        InputBuffer(InputBuffer&& other) = default;
        explicit InputBuffer(variable_list&& inputs): buffer(std::move(inputs)) {};
        InputBuffer& operator=(InputBuffer&& other) = default;
      
        // Accumulates the variable at a specified index.
        // The optional CUDA streams determine which stream the accumulation
        // is run on and how the addition is synchronized.
        void add(size_t pos,
                 Variable&& var,
                 const c10::optional<c10::Stream>& opt_producer_stream,
                 const c10::optional<c10::Stream>& opt_consumer_stream);
      
        at::Device device() const;
      
        Variable operator[](size_t pos) { return buffer[pos]; }
      
        // Returns the inputs as a list of variables. Destroys given InputBuffer.
        static std::vector<Variable> variables(InputBuffer&& g);
      
      private:
        // Variables, pair 中的 int 代表 version  
        std::vector<Variable> buffer;
      };
      

      如何通過(guò) input_buffer.device() 來(lái)得到對(duì)應(yīng)的 device?就是遍歷 input_buffer 中的 variables,其中第一個(gè)設(shè)備非cpu的variable的device將成為input_buffer的device,否則設(shè)備就是CPU。

      auto InputBuffer::device() const -> at::Device {
        // Since we pick the first non-CPU tensor, this won't work with
        // mixed device-type operations (e.g., an op that is both CUDA
        // and XLA).  This is *incredibly* unlikely, so we don't worry
        // about it.
        // 遍歷buffer,獲取第一個(gè)非CPU張量,然后得到他的device
        for (auto& var : buffer) {
          if (var.defined()) {
            auto device = var.device();
            if (device.type() != at::kCPU) {
              return device;
            }
          }
        }
        // Only report to the CPU thread if there really were no tensors
        // from other devices.
        return at::kCPU;
      }
      

      InputBuffer 對(duì)應(yīng)的部分方法如下,有添加參數(shù),也有累積參數(shù)。

        static void accumulate(std::vector<Variable>& buffer,
                               const size_t pos,
                               Variable&& var) {
          TORCH_INTERNAL_ASSERT(pos < buffer.size());
          auto& old_var = buffer[pos];
          // ATen doesn't route sparse additions correctly...
          // do dense + sparse in-place if possible
          if (old_var.is_sparse()) {
            //storage use_count is a big hammer, but for anything lighter there's an adversarial example with unexpected inplace modification
            if (!var.is_sparse() && var.is_contiguous() && var.storage().use_count() == 1) {
                buffer[pos] = var.add_(old_var);
            } else {
                buffer[pos] = var + old_var;
            }
          } else {
            if (var.is_sparse() && !old_var.is_sparse() && old_var.is_contiguous() && old_var.storage().use_count() == 1) {
                buffer[pos] = old_var.add_(var);
            } else {
                buffer[pos] = old_var + var;
            }
          }
        }
      
        void InputBuffer::add(size_t pos,
                              Variable&& var,
                              const c10::optional<c10::Stream>& opt_producer_stream,
                              const c10::optional<c10::Stream>& opt_consumer_stream) {
        TORCH_INTERNAL_ASSERT(pos < buffer.size());
        if (!var.defined()) {
          return;
        }
      
        // Switches to accumulate device
        // The device (and stream) chosen for accumulation is:
        //  (1) var is not a CUDA variable. Accumulation happens on var's device.
        //  (2) var is a CUDA variable and it, the consumer, and the producer share the same device:
        //       (2a) Uses the consumer's stream as the accumulation stream
        //       (2b) Syncs the accumulation stream with the producer's stream (if different)
        //       (2c) Accumulates.
        //  (3) var is a CUDA variable and it shares a device with the consumer but not the producer:
        //       (3a) Uses the consumer's stream as the accumulation stream
        //       (3b) Syncs the accumulation stream with the consumer device's default stream
        //       (3c) Accumulates.
        //  (4) var is a CUDA variable and it shares a device with the producer but not the consumer:
        //       (4a) Uses the producer device's default stream as the accumulation stream
        //       (4b) Syncs the accumulation stream with the the producer's stream
        //       (4c) Accumulates.
        //  (5) var is a CUDA variable and it does not share a device with the consumer or producer.
        //      Accumulation happens on the var device's default stream.
      
        c10::optional<c10::Stream> opt_accumulate_stream = c10::nullopt;
        if (device_of(var)->is_cuda()) {
          const auto on_producer = opt_producer_stream
                              && device_of(var) == opt_producer_stream->device();
          const auto on_consumer = opt_consumer_stream
                              && device_of(var) == opt_consumer_stream->device();
          if (on_producer && on_consumer) {
            // (2a)
            opt_accumulate_stream = opt_consumer_stream;
            if (opt_accumulate_stream != opt_producer_stream) {
              // (2b)
              auto event = c10::Event{c10::DeviceType::CUDA};
              event.record(*opt_producer_stream);
              opt_accumulate_stream->wait(event);
            }
          } else {
            c10::optional<c10::Stream> opt_sync_stream = c10::nullopt;
            const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
            if (on_consumer && !on_producer) {
              // (3a)
              opt_accumulate_stream = opt_consumer_stream;
              opt_sync_stream = guard.getDefaultStream(opt_consumer_stream->device());
            } else if (on_producer && !on_consumer) {
              // (4a)
              opt_accumulate_stream = guard.getDefaultStream(opt_producer_stream->device());
              opt_sync_stream = opt_producer_stream;
            } else {
              // (5)
              opt_accumulate_stream = guard.getDefaultStream(*device_of(var));
            }
            if (opt_sync_stream && (opt_accumulate_stream != opt_sync_stream)) {
              // (3b), (4b)
              c10::OptionalDeviceGuard device_guard{opt_sync_stream->device()};
              auto event = c10::Event{c10::DeviceType::CUDA};
              event.record(*opt_sync_stream);
              opt_accumulate_stream->wait(event);
            }
          }
        }
      
        auto& old_var = buffer[pos];
        if (!old_var.defined()) {
          buffer[pos] = std::move(var);
        } else {
          if (opt_accumulate_stream) {
            c10::OptionalStreamGuard stream_guard{opt_accumulate_stream};
            accumulate(buffer, pos, std::move(var));
          } else {
            // (1) non-CUDA variable
            //     Accumulation happens on variable's device
            c10::OptionalDeviceGuard device_guard{device_of(var)};
            accumulate(buffer, pos, std::move(var));
          }
        }
      }
      
      auto InputBuffer::variables(InputBuffer&& g) -> std::vector<Variable> {
        std::vector<Variable> result = std::move(g.buffer);
        return result;
      }
      

      0x06 ReadyQueue

      6.1 定義

      ReadyQueue 用來(lái)在主線(xiàn)程和worker線(xiàn)程之間、以及worker線(xiàn)程和worker線(xiàn)程之間傳輸任務(wù)(NodeTask對(duì)象)。為什么要傳遞 NodeTask?是因?yàn)?u>NodeTask 包含了求導(dǎo)函數(shù),逐一運(yùn)行NodeTask 就是在反向計(jì)算圖路徑上逐一運(yùn)行求導(dǎo)函數(shù),最后往輸出節(jié)點(diǎn)輸出最終梯度。ReadyQueue就指定了worker線(xiàn)程要執(zhí)行的工作流

      其定義如下:

      struct ReadyQueue {
       private:
        // Returns true when t2 should be (weakly) BEFORE t1 in the queue.
        // Shutdown tasks are first and then empty NodeTask are next.
        struct CompareNodeTaskTime {
          bool operator()(NodeTask const & t1, NodeTask const & t2) {
            // NOLINTNEXTLINE(bugprone-branch-clone)
            if (t2.isShutdownTask_) {
              return true;
            } else if (!t1.fn_ || t1.isShutdownTask_) {
              return false;
            } else if (!t2.fn_) {
              return true;
            } else if (t1.getReentrantDepth() == t2.getReentrantDepth()) {
              return t1.fn_->sequence_nr() < t2.fn_->sequence_nr();
            } else {
              return t1.getReentrantDepth() < t2.getReentrantDepth();
            }
          }
        };
      
        // To notify threads waiting on the ReadyQueue of available tasks on the heap_
        std::condition_variable not_empty_;
        // To protect read and writes to heap_
        mutable std::mutex mutex_;
      
        std::priority_queue<NodeTask, std::vector<NodeTask>, CompareNodeTaskTime> heap_;
      
       public:
        // incrementOutstandingTasks indicates whether or not we should increment
        // 'outstanding_tasks_' for the associated GraphTask. This should mostly
        // always be true and is only set false in certain cases (see docs for
        // DistEngine.execute_graph_task_until_ready_queue_empty)
        void push(NodeTask item, bool incrementOutstandingTasks = true);
        void pushShutdownTask();
        NodeTask pop();
        bool empty() const;
        size_t size() const;
      };
      

      ReadyQueue 主要成員函數(shù)/成員變量如下:

      • std::condition_variable not_empty_ 其作用是在線(xiàn)程之間同步。
      • Push 是生成者行為,使用 not_empty_.notify_one() 來(lái)通知消費(fèi)者,這樣就可以解鎖一個(gè)消費(fèi)者。
      • Pop 是消費(fèi)者行為,使用 not_empty_.wait(lock, [this]{ return !heap_.empty(); }) 來(lái)阻塞等待生產(chǎn)。
      • std::priority_queue heap_,使用 CompareNodeTaskTime 來(lái)做比較。
        • 每次 pop 時(shí)會(huì)取出 CompareNodeTaskTime 最小的 NodeTask。
        • CompareNodeTaskTime 依據(jù) ReentrantDepth 和 sequence_nr 做比較,哪一個(gè)小就消費(fèi)哪一個(gè)。因此消費(fèi)的順序不等同于生產(chǎn)的順序,這里生產(chǎn)的意思是往queue之中插入NodeTask
      auto ReadyQueue::push(NodeTask item, bool incrementOutstandingTasks) -> void {
        {
          // Lock mutex for writing to heap_
          std::lock_guard<std::mutex> lock(mutex_);
          if (incrementOutstandingTasks) {
            std::shared_ptr<GraphTask> graph_task = item.base_.lock();
            ++graph_task->outstanding_tasks_;
          }
          heap_.push(std::move(item));
        }
        not_empty_.notify_one();
      }
      
      auto ReadyQueue::pushShutdownTask() -> void {
        {
          std::lock_guard<std::mutex> lock(mutex_);
          heap_.push(NodeTask({}, nullptr, InputBuffer(0), true));
        }
        not_empty_.notify_one();
      }
      
      size_t ReadyQueue::size() const {
        // Lock mutex for accesses to heap_
        std::unique_lock<std::mutex> lock(mutex_);
        return heap_.size();
      }
      
      auto ReadyQueue::pop() -> NodeTask {
        // Lock mutex for accesses to heap_
        std::unique_lock<std::mutex> lock(mutex_);
        not_empty_.wait(lock, [this]{ return !heap_.empty(); });
        auto task = std::move(const_cast<NodeTask&>(heap_.top())); heap_.pop();
        return task;
      }
      
      bool ReadyQueue::empty() const {
        // Lock mutex for accesses to heap_
        std::unique_lock<std::mutex> lock(mutex_);
        return heap_.empty();
      }
      

      6.2 設(shè)備Queue 數(shù)量

      在引擎之中,線(xiàn)程數(shù)量和ReadyQueue 的數(shù)量是由據(jù)設(shè)備的數(shù)量來(lái)決定的。有多少個(gè)設(shè)備,就啟動(dòng)多少個(gè)工作線(xiàn)程,也生成與線(xiàn)程一一對(duì)應(yīng)的ReadyQueue。

      所以,引擎有如下成員變量,使用 vector 來(lái)統(tǒng)一管理 queue。

      // Safe to read device_ready_queues_ without synchronization after initialization
      std::vector<std::shared_ptr<ReadyQueue>> device_ready_queues_;
      

      生成queue具體如下面的代碼:

      auto Engine::start_device_threads() -> void {
        // See Note [Allocating GPUs to autograd threads]
        c10::DeviceIndex num_devices = 0;
        // 得到設(shè)備數(shù)量
        for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
          auto* impl = impl_atomic.load();
          if (impl) {
            num_devices = std::max(num_devices, impl->deviceCount());
          }
        }
      
        // 確定queue數(shù)量,并且生成queue
        // allocate one thread for every GPU device (but colocate GPUs of different
        // types), and pre-allocate the device_ready_queues_ to ensure safe reading on it.
        device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
        for (auto& queue : device_ready_queues_)    {
          // NOLINTNEXTLINE(modernize-make-shared)
          queue.reset(new ReadyQueue());
        }
          
        // 生成線(xiàn)程 
        thread_pool_shared_ = std::make_shared<ThreadPoolShared>();
        for (int i = 0; i < num_devices; ++i) {
          std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
          t.detach();
        }
        // Wait for the threads to start
        {
          std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
          while(non_reentrant_device_thread_count_.load() != static_cast<uint32_t>(num_devices)) {
            non_reentrant_device_thread_condvar_.wait(lk);
          }
        }
      }
      

      因?yàn)槭鞘褂?vector 來(lái)管理queue,所以可以使用設(shè)備號(hào)(device index)去vector里得到每個(gè)device專(zhuān)屬的ReadyQueue。

      auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
        if (device_index == CPU_DEVICE) {
          // return the cpu ready queue passed in
          TORCH_INTERNAL_ASSERT(cpu_ready_queue);
          return cpu_ready_queue;
        } else {
          // Static cast is ok here as the number of device should never overflow an int.
          TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
          // See Note [Allocating GPUs to autograd threads]
          // NB: This function would become obsolete if we truly allocated a CPU thread
          // per device, rather than colocate.
          return device_ready_queues_.at(device_index);
        }
      }
      

      6.3 線(xiàn)程角度看ReadyQueue

      現(xiàn)在,讓我們從線(xiàn)程角度來(lái)看看ReadyQueue。

      6.3.1 工作線(xiàn)程

      每個(gè)autogard 工作線(xiàn)程都與一個(gè)就緒隊(duì)列相關(guān)聯(lián),該隊(duì)列指定該線(xiàn)程要執(zhí)行的工作流,這個(gè)隊(duì)列定義如下。

      // Every autograd worker thread is associated with a ready queue, which specifies
      // the stream of work of this thread to do. This shared_ptr is a thread_local
      // pointer to each thread's ready_queue, and it should be initialized via the
      // Engine::init_local_ready_queue() call in each corresponding thread before execution.
      //
      // The CUDA, XLA threads are shared among all invocations of backwards via
      // device_ready_queues_, while CPU threads are dedicated to processing CPU work for
      // the backward they invoked. So any given graph task maintains its own cpu_ready_queue_
      // where you should send work for it to be done
      //
      // For reentrant backward calls, if we spawn new thread from the current thread
      // because we reached the maximum depth, the new thread will just reuse the same
      // ReadyQueue with the parent thread for performance improvement.
      // see Note [Reentrant backwards] for more details.
      
      static thread_local std::shared_ptr<ReadyQueue> local_ready_queue = nullptr;
      

      這個(gè)shared_ptr是一個(gè)thread_local指針,其指向每個(gè)線(xiàn)程的ready_queue,在執(zhí)行之前,應(yīng)該通過(guò)每個(gè)對(duì)應(yīng)線(xiàn)程中的 Engine::init_local_ready_queue() 調(diào)用對(duì)其進(jìn)行初始化。

      void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
        if (ready_queue) {
          // if ready_queue provided in the caller, use the caller's ready_queue to initialize local_ready_queue
          local_ready_queue = std::move(ready_queue);
        } else if (!local_ready_queue){
          // otherwise if local_ready_queue not allocated, allocate a new ready_queue
          local_ready_queue = std::make_shared<ReadyQueue>();
        }
      }
      

      對(duì)于可重入的向后調(diào)用,如果由于達(dá)到最大深度而從當(dāng)前線(xiàn)程生成新線(xiàn)程,則新線(xiàn)程將與父線(xiàn)程重用相同的ReadyQueue以提高性能。

      對(duì)于工作線(xiàn)程,其對(duì)應(yīng)的 ReadyQueue 是 device_ready_queues_ 之中對(duì)應(yīng)的 queue,比如下面是用 std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true) 來(lái)初始化。

      auto Engine::start_device_threads() -> void {
        // See Note [Allocating GPUs to autograd threads]
        c10::DeviceIndex num_devices = 0;
        for (const auto& impl_atomic : c10::impl::device_guard_impl_registry) {
          auto* impl = impl_atomic.load();
          if (impl) {
            num_devices = std::max(num_devices, impl->deviceCount());
          }
        }
      
        // allocate one thread for every GPU device (but colocate GPUs of different
        // types), and pre-allocate the device_ready_queues_ to ensure safe reading on it.
        device_ready_queues_ = std::vector<std::shared_ptr<ReadyQueue>>(num_devices);
        for (auto& queue : device_ready_queues_)    {
          // NOLINTNEXTLINE(modernize-make-shared)
          queue.reset(new ReadyQueue());
        }
      
        thread_pool_shared_ = std::make_shared<ThreadPoolShared>();
      
        for (int i = 0; i < num_devices; ++i) {
          std::thread t(&Engine::thread_init, this, i, device_ready_queues_[i], true);
          t.detach();
        }
        // Wait for the threads to start
        {
          std::unique_lock<std::mutex> lk(non_reentrant_device_thread_mutex_);
          while(non_reentrant_device_thread_count_.load() != static_cast<uint32_t>(num_devices)) {
            non_reentrant_device_thread_condvar_.wait(lk);
          }
        }
      }
      

      6.3.2 主線(xiàn)程

      對(duì)于主線(xiàn)程,則調(diào)用 init_local_ready_queue() 來(lái) 初始化local ready_queue。

      因?yàn)?init_local_ready_queue 沒(méi)有傳入?yún)?shù),所以新生成一個(gè)queue。

      void Engine::init_local_ready_queue(std::shared_ptr<ReadyQueue> ready_queue) {
        if (ready_queue) {
          // if ready_queue provided in the caller, use the caller's ready_queue to initialize local_ready_queue
          local_ready_queue = std::move(ready_queue);
        } else if (!local_ready_queue){
          // otherwise if local_ready_queue not allocated, allocate a new ready_queue
          local_ready_queue = std::make_shared<ReadyQueue>();
        }
      }
      

      這就是 CPU queue。我們把 CPU queue 和工作線(xiàn)程的queue做比較。

      • 設(shè)備 ReadyQueue 的數(shù)目 與 worker線(xiàn)程數(shù)目相同,每個(gè)worker有一個(gè)對(duì)應(yīng)的 ReadyQueue。CUDA、XLA線(xiàn)程在所有反向傳播調(diào)用之間通過(guò) device_ready_queues_進(jìn)行信息共享。
      • 而CPU線(xiàn)程專(zhuān)用于處理它們調(diào)用的反向傳播相關(guān)CPU工作。因此,任何給定的graph任務(wù)都會(huì)維護(hù)自己的cpu_ready_queue_,用戶(hù)應(yīng)該向其發(fā)送要完成的工作

      CPU queue 就是GraphTask 的成員變量 cpu_ready_queue_。

        // CPU threads are dedicated to processing CPU work for the backward they invoked.
        // So any given graph task maintains its own cpu_ready_queue_ where you should send
        // work for it to be done. We memoize the cpu_ready_queue_ per GraphTask so that
        // we know which ready queue we should push to if we are on device thread (i.e. GPU)
        // and but next NodeTask should be run on CPU.
        std::shared_ptr<ReadyQueue> cpu_ready_queue_;
      

      注意,CPU就緒隊(duì)列為每個(gè)GraphTask獨(dú)有,但CUDA設(shè)備就緒隊(duì)列在所有GraphTask中共享。

      所以,引擎之中就緒隊(duì)列數(shù)目是:設(shè)備數(shù)目 + GraphTask 數(shù)目。

      我們完善一下之前的圖例,加入了GraphTask 和 Engine 信息,具體如下圖:

      • 1)主線(xiàn)程往CPU ReadyQueue放入一個(gè) NodeTask。
      • 2)工作線(xiàn)程 1 從 CPU ReadyQueue 取出 NodeTask,開(kāi)始執(zhí)行。
      • 3)工作線(xiàn)程 1 結(jié)束之后,往 device_ready_queues_ 的 某一個(gè) ReadyQueue 插入一個(gè) NodeTask。
      • 4)ReadyQueue 對(duì)應(yīng)的 工作線(xiàn)程 2 取出 NodeTask,開(kāi)始執(zhí)行。
                                  +-------------------------+
                                  | GraphTask               |
                                  |                         |
                                  |        cpu_ready_queue_ |
                                  |            +            |
                                  |            |            |
                                  +-------------------------+
                                               |
      +--------------+                         |                           +-----------------+
      | Main Thread  |                         v                           | Worker Thread 1 |
      |              |       1         +-------+---------+       2         |                 |
      |              | push(NodeTask)  |                 |  pop(NodeTask)  |                 |
      |          +-------------------> | CPU ReadyQueue  +----------------------->           |
      |              |                 |                 |                 |                 |
      |              |                 +-----------------+                 |                 |
      |              |              +----------------------+               |                 |
      |              |              | Device ReadyQueues   |               |                 |
      |              |              |                      |               |                 |
      |              |              |                      |    3          |                 |
      |              |              |    +-------------+   | push(NodeTask)|                 |
      |              |              |    | ReadyQueue 1| <-----------------------+           |
      |              |              |    +------+------+   |               |                 |
      |              |              |           |          |               |                 |
      +--------------+              |           |          |               +-----------------+
                                    |           +------------------+
                                    |                      |       |       +-----------------+
      +------------------------+    |          .           |       |       | Worker Thread 2 |
      | Engine                 |    |          .           |       |       |                 |
      |                        |    |          .           |       |       |                 |
      |                        |    |                      |       |       |                 |
      |  device_ready_queues_ +---> |    +-------------+   |       +------------->           |
      |                        |    |    | ReadyQueue 2|   | pop(NodeTask) |                 |
      |                        |    |    +-------------+   |     4         |                 |
      +------------------------+    |                      |               |                 |
                                    |                      |               |                 |
                                    |    +-------------+   |               |                 |
                                    |    | ReadyQueue 3|   |               |                 |
                                    |    +-------------+   |               |                 |
                                    |                      |               |                 |
                                    +----------------------+               +-----------------+
      
      

      至此,靜態(tài)結(jié)構(gòu)和基礎(chǔ)類(lèi)介紹完畢,下一篇我們介紹動(dòng)態(tài)邏輯。

      0xFF 參考

      https://www.zhihu.com/column/gemfield

      【PyTorch】聊聊 backward 背后的代碼

      pytorch筆記(計(jì)算圖+autograd)-Node(1)

      詳解Pytorch中的網(wǎng)絡(luò)構(gòu)造

      PyTorch的優(yōu)化器

      PyTorch的分布式

      PyTorch的Tensor(下)

      PyTorch的Tensor(中)

      PyTorch的Tensor(上)

      PyTorch的動(dòng)態(tài)圖(下)

      PyTorch的動(dòng)態(tài)圖(上)

      PyTorch Internals 5:Autograd的實(shí)現(xiàn)

      A GENTLE INTRODUCTION TO TORCH.AUTOGRAD

      PyTorch學(xué)習(xí)筆記(12)——PyTorch中的Autograd機(jī)制介紹

      PyTorch 的 Autograd

      posted @ 2021-10-27 20:02  羅西的思考  閱讀(2251)  評(píng)論(0)    收藏  舉報(bào)
      主站蜘蛛池模板: 在线观看热码亚洲av每日更新| 国产精品高清中文字幕| 亚洲 校园 欧美 国产 另类 | 久久综合色一综合色88欧美| 欧美一区二区三区啪啪| 国产精品中文字幕一区| 久久亚洲精品11p| 亚洲精品日韩在线丰满| 色橹橹欧美在线观看视频高清| 国产一二三五区不在卡| 蜜桃无码一区二区三区| 乱中年女人伦av二区| 日韩高清不卡一区二区三区| 青青青爽在线视频观看| 日本无遮挡真人祼交视频| 日本一卡2卡3卡4卡无卡免费| 色久综合色久综合色久综合| 国产精品久久久久无码网站| 久久久精品2019中文字幕之3| 日韩中文字幕人妻精品| 酒泉市| 亚洲最大天堂在线看视频| 2021国产精品视频网站| 禹城市| 久久久久无码精品国产h动漫| 亚洲一二三区精品美妇| 安塞县| 欧美饥渴熟妇高潮喷水| 日本一区二区三区在线播放| 天啦噜国产精品亚洲精品| 性色欲情网站iwww九文堂| 亚洲中文字幕在线观看| 亚洲精品电影院| 韩国福利视频一区二区三区| jlzz大jlzz大全免费| 国产精品久久久久鬼色| 日韩一区二区三区高清视频| 疯狂做受xxxx高潮欧美日本| 蜜桃av无码免费看永久| 四虎永久播放地址免费| 天堂a无码a无线孕交|