[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

目录

0x00 摘要

前文中我们介绍了反向传播引擎的动态逻辑,因为具体反向传播算法是在设备线程中完成的,所以我们单独用一章来讲解。

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

本系列其他文章如下:

深度学习利器之自动微分(1)

深度学习利器之自动微分(2)

[源码解析]深度学习利器之自动微分(3) --- 示例解读

[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)

[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)

[源码解析] PyTorch如何实现前向传播(3) --- 具体实现

[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎

[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构

[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑

0x01 工作线程主体

thread_main是工作线程的主体函数,主要逻辑就是围绕着 ReadyQueue 执行一个 while 循环,工作线程阻塞在 ReadyQueue -> pop 这里,如果主线程或者其他线程插入了一个 NodeTask,则 pop 会返回取出一个 NodeTask,工作线程处理这个 NodeTask,完成后向计算的一个环节,如果有需要就继续往某一ReadyQueue插入新的 NodeTask,驱动引擎继续执行后向计算其他环节。

thread_main 从如下途径被调用:

  1. CUDA, XLA 设备的 autograd threads 会调用。
  2. CPU 之上的反向传播主线程会调用。
  3. 前两个case 进行可重入反向传播,也会调用。

1.1 线程主体代码

工作线程的计算始于动态图的GraphRoot函数,反向传播就以 Node 的edge为纽带,层层从前向后计算,直到来到了leaf节点,最终完成了反向计算,具体如下:

  • local_graph_task表示我们从队列中检索的graph_task。外部graph_ 任务表示我们需要执行的可重入执行的总体 graph_任务。
  • 从自己的ReadyQueue之中取出NodeTask实例,使用 local_graph_task 为参数来执行evaluate_function(反向传播函数)。
  • outstanding_tasks 自减 1。
  • 如果本 local_graph_task 已经结束(可重入反向传播会运行多个 GraphTask),即:
    • 执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted。
    • 如果这个task是来自其它worker thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。

具体代码如下:

// thread_main is used by:
// 1). autograd threads for devices (i.e. CUDA, XLA)
// 2). the caller/owning thread of the backward call on CPU (sync mode)
// 3). Renetrant backward that invoked by either 1) or 2)
// The exit conditions are different for the above three cases.
// For 1), we are spinning on running the thread_main on device autograd
//         threads throughout the Engine lifetime, thread_main will get
//         terminated during Engine destruction by pushing shutdown tasks
// For 2), the owning thread of the backward call drives the thread_main
//         synchronously until the graph_task of that owning thread is
//         completed and exit the thread_main to continue executing the
//         result of caller's code.
// For 3), the reentrant backward that invokes
//         thread_main, either from 1) or 2), will not spin and will exit as
//         long as graph_task is completed and notify the owning thread as
//         needed.
auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  // When graph_task is nullptr, this is a long running thread that processes
  // tasks (ex: device threads). When graph_task is non-null (ex: reentrant
  // backwards, user thread), this function is expected to exit once that
  // graph_task complete.

  // local_ready_queue should already been initialized when we get into thread_main
  while (graph_task == nullptr || !graph_task->future_result_->completed()) {
    // local_graph_task represents the graph_task we retrieve from the queue.
    // The outer graph_task represents the overall graph_task we need to execute
    // for reentrant execution.
    std::shared_ptr<GraphTask> local_graph_task;
    {
      // Scope this block of execution since NodeTask is not needed after this
      // block and can be deallocated (release any references to grad tensors
      // as part of inputs_).
      NodeTask task = local_ready_queue->pop(); // 阻塞等待
      // This will only work if the worker is running a non backward task
      // TODO Needs to be fixed this to work in all cases
      if (task.isShutdownTask_) {
        break;
      }

      if (!(local_graph_task = task.base_.lock())) {
        // GraphTask for function is no longer valid, skipping further
        // execution.
        continue;
      }

      if (task.fn_ && !local_graph_task->has_error_.load()) {
       // 利用grad_mode_来配置AutoGradMode,整个反向计算期间的代码都靠GradMode::is_enabled()来判断当前是否是要计算grad  
        AutoGradMode grad_mode(local_graph_task->grad_mode_);
        try {
          // The guard sets the thread_local current_graph_task on construction
          // and restores it on exit. The current_graph_task variable helps
          // queue_callback() to find the target GraphTask to append final
          // callbacks.
          GraphTaskGuard guard(local_graph_task);
          NodeGuard ndguard(task.fn_);
          // 执行后向计算
          evaluate_function(local_graph_task, task.fn_.get(), task.inputs_, local_graph_task->cpu_ready_queue_);
        } catch (std::exception& e) {
          thread_on_exception(local_graph_task, task.fn_, e);
        }
      }
    }

    // Decrement the outstanding tasks.
    --local_graph_task->outstanding_tasks_;

    // Check if we've completed execution.
    if (local_graph_task->completed()) { // 已经结束了,进行后续处理
      local_graph_task->mark_as_completed_and_run_post_processing();

      auto base_owner = local_graph_task->owner_; // 后续是需要在 GraphTask 的 owner_ 处理
      // The current worker thread finish the graph_task, but the owning thread
      // of the graph_task might be sleeping on pop() if it does not have work.
      // So we need to send a dummy function task to the owning thread just to
      // ensure that it's not sleeping, so that we can exit the thread_main.
      // If it has work, it might see that graph_task->outstanding_tasks_ == 0
      // before it gets to the task, but it's a no-op anyway.
      //
      // NB: This is not necessary if the current thread is the owning thread.
      if (worker_device != base_owner) {
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        // 获取后续工作的queue
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }
  }
}

1.2 使用 Ready Queue

上述代码之中,最后使用 ready_queue_by_index 获取到后续工作对应的queue。

ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
    ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));

如何获取Ready Queue?具体策略是:

  • 如果下一个 需要执行的设备是 CPU,则选用cpu_ready_queue。
  • 否则从device_ready_queues_选取一个GPU对应的 ReadyQueue。

代码如下:

auto Engine::ready_queue_by_index(std::shared_ptr<ReadyQueue> cpu_ready_queue, int device_index) -> std::shared_ptr<ReadyQueue> {
  if (device_index == CPU_DEVICE) {
    // return the cpu ready queue passed in
    TORCH_INTERNAL_ASSERT(cpu_ready_queue);
    return cpu_ready_queue;
  } else {
    // Static cast is ok here as the number of device should never overflow an int.
    TORCH_INTERNAL_ASSERT(0 <= device_index && device_index < static_cast<int>(device_ready_queues_.size()));
    // See Note [Allocating GPUs to autograd threads]
    // NB: This function would become obsolete if we truly allocated a CPU thread
    // per device, rather than colocate.
    return device_ready_queues_.at(device_index);
  }
}

逻辑如下:

+---------------------------------------------------------------------+
|  Main Thread                                                        |
|                                                                     |
|            push(NodeTask)+--------------+                           |
|                                         |                           |
+---------------------------------------------------------------------+
                                          |
                                          |
                                          v
                                   +------+-----+
                                   |            |
                                   | ReadyQueue |
                                   |            |
                                   +------+-----+
                                          |
                                          |
                                          |
+---------------------------------------------------------------------+
| Worker Thread 1                         |                           |
|                                         |                           |
|  thread_main{                           |                           |
|                                         v                           |
|     NodeTask task = local_ready_queue->pop()                        |
|                                                                     |
|     evaluate_function(task.fn_.get(),task.inputs_)                  |
|  }                                                                  |
+---------------------------------------------------------------------+

0x02 反向计算总体逻辑

evaluate_function 方法完成了反向计算的逻辑,总体逻辑如下:

  • 准备工作:如果exec_info需要处理,则处理 captured_vars_。
  • 反向计算:调用 call_function(graph_task, func, inputs),这是反向传播中计算相关的核心逻辑:
    • 调用pre hooks。
    • 调用fn进行计算。
    • 调用post hooks。
  • 扫尾工作:
    • 如果不需要keep graph,则fn.release_variables();
    • 依据 call_function的输出 outputs,进行计算 num_outputs = outputs.size(),得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量)。
  • 准备下一步工作,具体就是查找后续需要计算的NodeTask,num_outputs 就是在这里被用到。这部分比较复杂。

总体代码如下:

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func, // 导数计算方法
    InputBuffer& inputs, // 当前Node的输入梯度
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
    
  // 进行准备工作  
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        // captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
        // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        // 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
        captured_grad = inputs[capture.input_idx_];
        // 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
        // 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

  // Set the ThreadLocalState before calling the function.
  // NB: The ThreadLocalStateGuard doesn't set the grad_mode because GraphTask
  // always saves ThreadLocalState without grad_mode.
  at::ThreadLocalStateGuard tls_guard(graph_task->thread_locals_);

  // Switches to a function's CUDA stream (if applicable) before calling it
  const auto opt_parent_stream = (*func).stream(c10::DeviceType::CUDA);
  c10::OptionalStreamGuard parent_stream_guard{opt_parent_stream};

  // 进行反向计算
  auto outputs = call_function(graph_task, func, inputs);

  // 如果不需要保持计算图,则本节点释放变量
  auto& fn = *func;
  if (!graph_task->keep_graph_) {
    fn.release_variables();
  }

  // 得到 num_outputs的元素数量(该数量等同于当前fn的next_edge()返回的list中的元素数量),后续遍历本节点输出时候会用到
  int num_outputs = outputs.size();
  if (num_outputs == 0) { // Note: doesn't acquire the mutex
    // Records leaf stream (if applicable)
    // See note "Streaming backwards"
    if (opt_parent_stream) {
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      graph_task->leaf_streams.emplace(*opt_parent_stream);
    }
    return;
  }

  if (AnomalyMode::is_enabled()) {
    AutoGradMode grad_mode(false);
    for (int i = 0; i < num_outputs; ++i) {
      auto& output = outputs[i];
      at::OptionalDeviceGuard guard(device_of(output));
      if (output.defined() && isnan(output).any().item<uint8_t>()) {
        std::stringstream ss;
      }
    }
  }

  // 准备下一步工作
  // Lock mutex for the accesses to GraphTask dependencies_, not_ready_ and cpu_ready_queue_ below
  std::lock_guard<std::mutex> lock(graph_task->mutex_);
  for (int i = 0; i < num_outputs; ++i) {
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i); // next_edge是该node在前向传播图中的输入,在反向传播时候就是本节点的输出,所以next就是下一个可能运算的节点

    if (!next.is_valid()) continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_;
    auto it = dependencies.find(next.function.get()); // 找到下一个节点的依赖

    if (it == dependencies.end()) {
      auto name = next.function->name();
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);
      is_ready = true; // 下一个节点没有入度了,那么说明计算该节点梯度依赖的其他节点梯度都已经计算完成
    }

    // 要去 not_ready里面看看,是否已经存储了
    auto& not_ready = graph_task->not_ready_;
    auto not_ready_it = not_ready.find(next.function.get());
    if (not_ready_it == not_ready.end()) {
      // 下一个节点的梯度还没有进行计算
      // Skip functions that aren't supposed to be executed
      // 跳过不需要计算的节点
      if (!exec_info_.empty()) {
        auto it = exec_info_.find(next.function.get());
        if (it == exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs()); // 下一个节点前置梯度的buffer,就是下一个节点的输入梯度

      // Accumulates into buffer
      // 下一个节点的输入梯度就是当前节点的输出,所以要拷贝过去
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        // 既然依赖全部完成,就插入到ReadyQueue 之中
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        // 下一个节点的输入依赖还没有完成,就放到not_ready之中。
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // 如果下一个节点已经开始计算,但是没有完成(就是依赖梯度还有),此时应该在not_ready之中
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
        
      // Graph中每一个node(fn)的输出是下一个node(fn)的输入,下面4句代码来将前一个fn的输出转化为下一个fn的输入  
      if (is_ready) {
        // 如果此时已经没有输入依赖,就放入新的NodeTask,就是下一个需要计算梯度的NodeTask
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        //已经完成下一个节点前置梯度计算,从not_ready中移除相应的buffer
        not_ready.erase(not_ready_it);
      }
    }
  }
}

因为这部分代码十分复杂,我们逐一进行分析。

0x03 准备工作

首先我们看看准备工作,具体如下:

  • 取出当前 Node 的 ExecInfo。
  • 取出其 captures_,遍历其中每一个 Capture。
  • 遍历Capture 的 hooks,链式调用hook进行计算。
    • captured_grad 不停的作为输入和输出在流水线中流淌,针对 captured_vars_[capture.output_idx_]陆续计算。
    • 最终结果保存在 captured_vars_[capture.output_idx_] 之中。

代码中有一个细节,就是captured_grad 只是临时存储,每次node计算都会更新,最终输出给调用者,相当于引用

void Engine::evaluate_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func, // 导数计算方法
    InputBuffer& inputs, // 当前Node的输入梯度
    const std::shared_ptr<ReadyQueue>& cpu_ready_queue) {
    
  // 进行准备工作  
  // If exec_info_ is not empty, we have to instrument the execution
  auto& exec_info_ = graph_task->exec_info_;
  if (!exec_info_.empty()) {
    auto& fn_info = exec_info_.at(func); // 取出当前的进行处理
    if (auto* capture_vec = fn_info.captures_.get()) {
      // Lock mutex for writing to graph_task->captured_vars_.
      std::lock_guard<std::mutex> lock(graph_task->mutex_);
      for (const auto& capture : *capture_vec) {
        // captured_grad 就是临时存储下,每次node计算都会更新,最终输出给调用者,相当于引用
        // 1. captured_grad 引用了captured_vars_[capture.output_idx_],
        auto& captured_grad = graph_task->captured_vars_[capture.output_idx_];
        // 2. 给 captured_vars_[capture.output_idx_] 赋值 inputs[capture.input_idx_]
        captured_grad = inputs[capture.input_idx_];
        // 遍历hooks,链式调用hook进行计算,captured_grad 不停的作为输入和输出在流水线中流淌
        // 就是针对 captured_vars_[capture.output_idx_]不停的计算,最终结果还是在 captured_vars_[capture.output_idx_] 之中。
        for (auto& hook : capture.hooks_) {
          captured_grad = (*hook)(captured_grad);
        }
      }
    }
    if (!fn_info.needed_) {
      // Skip execution if we don't need to execute the function.
      return;
    }
  }

0x04 核心逻辑

call_function是反向传播中计算相关的核心逻辑。

  • 调用注册在本 node上的pre_hooks;
  • 调用node本身,比如MeanBackward0、MulBackward0等。
    • 输入是InputBuffer::variables(std::move(inputBuffer)),一组Variable的实例。当动态图刚开始进行反向计算时,引擎首先执行的是图的根节点——graph_root,它的输入是task.inputs——InputBuffer(0)。
    • 调用的是fn的apply(),apply是多态实现,针对不同的operation会dispatch到operation对应的apply实现上。
    • 输出也是一组Variable的实例 outputs = fn(std::move(inputs_copy)),outputs 要作为下一个fn的输入。
  • 调用注册在node上的post hooks。
  • 返回当前节点对应的导数,这是一个variable_list。

具体代码如下:

static variable_list call_function(
    std::shared_ptr<GraphTask>& graph_task,
    Node* func,
    InputBuffer& inputBuffer) {
  CheckpointValidGuard cpvguard(graph_task);
  auto& fn = *func;
  auto inputs =
      call_pre_hooks(fn, InputBuffer::variables(std::move(inputBuffer)));

  if (!graph_task->keep_graph_) {
    fn.will_release_variables();
  }

  const auto has_post_hooks = !fn.post_hooks().empty();
  variable_list outputs;

  if (has_post_hooks) {
    // In functions/accumulate_grad.cpp, there is some logic to check the
    // conditions under which the incoming gradient can be stolen directly
    // (which elides a deep copy) instead of cloned. One of these conditions
    // is that the incoming gradient's refcount must be 1 (nothing else is
    // referencing the same data).  Stashing inputs_copy here bumps the
    // refcount, so if post hooks are employed, it's actually still ok for
    // accumulate_grad.cpp to steal the gradient if the refcount is 2.
    //
    // "new_grad.use_count() <= 1 + !post_hooks().empty()" in
    // accumulate_grad.cpp accounts for this, but also creates a silent
    // dependency between engine.cpp (ie, this particular engine
    // implementation) and accumulate_grad.cpp.
    //
    // If you change the logic here, make sure it's compatible with
    // accumulate_grad.cpp.
    auto inputs_copy = inputs;
    outputs = fn(std::move(inputs_copy));
  } else {
    outputs = fn(std::move(inputs));
  }

  validate_outputs(fn.next_edges(), outputs, [&](const std::string& msg) {
    std::ostringstream ss;
    return ss.str();
  });

  if(has_post_hooks){
    return call_post_hooks(fn, std::move(outputs), inputs);
  }
  return outputs;
}

0x05 准备下一步工作

这部分是反向传播的复杂之处。

现在调用 call_function,得到了后向传播的输出,记录到了 outputs 之中。

auto outputs = call_function(graph_task, func, inputs);

所以,后半部分就是从 outputs 之中寻找后续可以计算的Node

总体思路就是:遍历后向传播的输出节点(就是该节点在前向计算图中的入边连接的节点),逐一衡量输出节点。遍历循环中分为两段代码,对于每一个输出节点做如下操作:

  • 第一段是依据依赖排查这个节点,得到这个节点是否就绪。核心就是看看这个输出节点在GraphTask的dependencies的计数是否降为0
    • 如果是0,就说明这个节点就绪了,说明这个node不会被未来的计算所依赖了。
    • 如果非0,就说明这个节点有多个输入,即,被多个node连接,而且有的输入还没有计算完成梯度。
  • 第二段是依据是否就绪来处理这个节点,比如放入哪一个queue

5.1 依据依赖排查节点

第一段代码功能是依据依赖关系来 排查节点,得到这个节点是否就绪,具体如下:

  • 假定某一个节点是 output,我们得到对应的边,遍历输出边。

    • 每次把一个输出边记录为 next,func 是 NodeTask 之中的函数。

    • 利用 dependencies_ 的信息,next 是否可以计算。dependencies_ 里面记录的是图中所有节点的依赖。

    • 从 dependencies_ 之中找到 next 对应的依赖数目,把依赖数目减一(通常因为有多个 input)。

      • 如果--it->second == 0,说明该前置节点计算梯度所依赖的其他节点梯度都已经完成计算。则
        • 把该前置节点对应的信息GraphTask中移除,即从GraphTask的dependencies中移除(后续也会从GraphTask的 not_ready 成员变量之中移除)。
        • 将is_ready 置为true,后续会依据这个 is_ready 的数值进行操作。
    • 从 not_ready_ 之中得到 next 对应的输入buffer(后续代码就是对此进行操作);

      • std::unordered_map<Node*, InputBuffer> not_ready_;
        

    代码如下:

  for (int i = 0; i < num_outputs; ++i) { // 遍历输出节点,逐一衡量
    auto& output = outputs[i];
    const auto& next = fn.next_edge(i); // 获得一个输出节点
      
    if (!next.is_valid()) continue;

    // Check if the next function is ready to be computed
    bool is_ready = false;
    auto& dependencies = graph_task->dependencies_; // 拿到GraphTask的依赖关系
    auto it = dependencies.find(next.function.get()); // 找到输出节点的依赖项

    if (it == dependencies.end()) {
      auto name = next.function->name(); // 没找到
      throw std::runtime_error(std::string("dependency not found for ") + name);
    } else if (--it->second == 0) {
      dependencies.erase(it);  // 找到了,并且已经计算完毕
      is_ready = true;
    }

    auto& not_ready = graph_task->not_ready_; 
    auto not_ready_it = not_ready.find(next.function.get()); // 找到输入buffer     

现在已经找到了某一个输出节点,也知道其是否计算完毕(依据有没有依赖项),也拿到了其存在"未就绪队列"的输入buffer(如果存在的话)。

5.2 处理这个节点

第二段是依据是否就绪来处理这个节点,比如放入哪一个queue,是就绪队列?还是未就绪队列?核心是:

  • 如果就绪,就放到该节点对应的 ReadyQueue 去处理。
  • 如果没有就绪,就新建立一个NodeTask放到 GraphTask的 not_ready 等待后续处理。需要注意的是,这个新的NodeTask 是在 worker thread 之中创建的。
  • 如何找到 ReadyQueue?需要看这个 Node 节点的 input_buffer.device() ,即,这个新 NodeTask 应该发送到 input_buffer.device() 那个 device 对应的 ReadyQueue。

我们具体看看如何依据 is_ready 的数值来对 not_ready 进行操作。

  • 如果在 未就绪队列 not_ready 之中 没有找到 next_edge 对应的元素,则:
    • 如果 exec_info_ 不为空,则在 exec_info_ 之中查找 next_edge 对应的元素,如果有元素且注明了不需要执行,就跳到for循环的下一个。
    • 用 next_edge 的流,inut_nr 等信息构建一个 input_buffer。
    • 如果 is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue(利用 input_buffer.device() 来得到对应的 queue)。这就要唤醒下一个 worker 线程
    • 如果 is_ready 是 False,这通常表明这个node有多个输入(被更多的node连接,使用num_inputs()可以获得数量),也说明此次处理的是这个node的第一个输入,后续还需要使用这个 next_edge,所以这个 next_edge 需要被放到 not_ready 之中。则把 next.function,input_buffer 放入到 not_ready 之中,这个input_buffer 就是 next_edge 后续执行时候需要的各种输入。
  • 如果在 未就绪队列 not_ready 之中找到了 next_edge 对应的元素,则:
    • 拿出来该元素对应的 input_buffer,把信息累积到 input_buffer 之中。此次累积的是该节点的其他输入。 input_buffer.add(next.input_nr, std::move(output), opt_parent_stream, opt_next_stream) 完成了累积操作,next.input_nr 就表明当前的node是反向传播中要流向的node(next)的第几个输入。
    • 如果is_ready 是 True,就用 本 GraphTask,next.function,input_buffer构建一个NodeTask,放入 ReadyQueue。这就要唤醒下一个 worker 线程
    • 从 not_ready 之中移除此元素,就是从 GraphTask 的依赖关系之中去除。

代码如下:

    if (not_ready_it == not_ready.end()) {
      // Skip functions that aren't supposed to be executed
      if (!exec_info_.empty()) {
        auto it = exec_info_.find(next.function.get());
        if (it == exec_info_.end() || !it->second.should_execute()) {
          continue;
        }
      }
      // No buffers have been allocated for the function
      InputBuffer input_buffer(next.function->num_inputs());

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);

      if (is_ready) {
        // 找出了下一个Node的queue
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push( //
            NodeTask(graph_task, next.function, std::move(input_buffer)));
      } else {
        not_ready.emplace(next.function.get(), std::move(input_buffer));
      }
    } else {
      // The function already has a buffer
      auto &input_buffer = not_ready_it->second;

      // Accumulates into buffer
      const auto opt_next_stream = next.function->stream(c10::DeviceType::CUDA);
      input_buffer.add(next.input_nr,
                       std::move(output),
                       opt_parent_stream,
                       opt_next_stream);
      if (is_ready) {
        // 找出了下一个Node的queue
        auto queue = ready_queue(cpu_ready_queue, input_buffer.device());
        queue->push(
            NodeTask(graph_task, next.function, std::move(input_buffer)));
        not_ready.erase(not_ready_it);
      }
    }

具体逻辑图如下:

  1. func 指向了目前正在进行反向计算的 Node。
  2. func 调用自己的 apply 方法进行计算,得出了 outputs,假设有3个输出,遍历,我们选择第三个为 output。
  3. func 的边是 next_edges_ 成员变量,遍历,我们选择第三个边为next。
  4. 用 next 和 GraphTask 的 dependencies_ 来判断 next 是不是就绪。
  5. 如果就绪,把 output 构建一个 input_buffer,然后生成一个 NodeTask,插入到对应的 ReadyQuieue。
  6. 如果没就绪,把 output 构建一个 input_buffer,和 next 一起放入 GraphTask 的 not_ready_,后续会使用。
       1  +---------------+
func +--> | Node          |              +---> ...
          |               |              |
          |               |              |
          |  apply() +------> outputs +------> ...  2
          |               |              |
          |               |              |
          |               |              |                 +--------------+
          |               |              +---> output +--> | input_buffer +--+
          |               |                                +--------------+  |
          |               |                                                  |
          |               |                                                  |
          |               |                                                  | 5
          |               |                                                  |
          |               |                                                  |
          |               |   +----> ...                                     |
          |               |   |                                              +---------+
          |               |   |                                              |         |
          |  next_edges_+---> +----> ...  3                                  |         |
          |               |   |                                              |         |
          |               |   |                                              |         |
          |               |   |                                         5    v         |
          |               |   +----> next +------>+              YES                   |     +------------+
          +---------------+                       |             +---> push(NodeTask) +-----> | ReadyQueue |
                                                  |      4      |                      |     +------------+
                                                  |             |                      |
          +---------------+                       +--> Ready? +-+                      |
          | GraphTask     |                       |             |       6              |
          |               |                       |             | NO                   | 6
          |               |                       |             +----> next.function   |
          | dependencies_+--> map<Node*, int> +-->+                          +         |
          |               |                                                  |         |
          |               |                                                  |         |
          |               |                              6                   v         v
          | not_ready_ +--------------------------------------------->  map<Node*, InputBuffer>
          |               |
          +---------------+

手机如下:

[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法

0x06 扫尾操作

在 thread_main 之中,如果本task已经结束,即做后续操作,具体代码如下。

auto Engine::thread_main(const std::shared_ptr<GraphTask>& graph_task) -> void {
  
    // 忽略前面代码
  
    // Check if we've completed execution.
	  if (local_graph_task->completed()) { // 判断是否结束
      // 如果结束了,就进行后续操作
      local_graph_task->mark_as_completed_and_run_post_processing();

      auto base_owner = local_graph_task->owner_;
      // The current worker thread finish the graph_task, but the owning thread
      // of the graph_task might be sleeping on pop() if it does not have work.
      // So we need to send a dummy function task to the owning thread just to
      // ensure that it's not sleeping, so that we can exit the thread_main.
      // If it has work, it might see that graph_task->outstanding_tasks_ == 0
      // before it gets to the task, but it's a no-op anyway.
      //
      // NB: This is not necessary if the current thread is the owning thread.
      if (worker_device != base_owner) {
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0)));
      }
    }

我们接下来分析这些扫尾工作。注意,这里是 thread_main 之中的扫尾工作

6.1 判断结束

以下代码用来判断本 GraphTask是否结束,其实就是 ReadyQueue 之中是否还有待运行的 NodeTask。

outstanding_tasks_ 是待处理 NodeTask的数量,用来判断该GrapTask是否还需要执行,其数值总是先加再减,如果数目为0,则说明任务结束了。

  • 当 GraphTask 被创建出来时候,此数值为0。
  • 如果有一个NodeTask被送入到 ReadyQueue,则outstanding_tasks_ 增加 1。
  • 如果在工作线程作执行一次 evaluate_function(task)后,outstanding_tasks的值减 1。
  • 如果这个数量不为0,则此GraphTask依然需要运行。
bool GraphTask::completed() {
  // outstanding_tasks在evaluate_function中可能会被改变
  return outstanding_tasks_.load() == 0 ||
      (exit_on_error_ && has_error_.load());
}

6.2 后续&通知

mark_as_completed_and_run_post_processing 就是进行后续处理。

执行后续操作 exec_post_processing,然后使用 future_result_->markCompleted 通知主线程。

void GraphTask::mark_as_completed_and_run_post_processing() {
  // Allow only one thread one attempt to process this logic.
  if (future_completed_.exchange(true)) {
    // Future is already marked complete, or being marked as such.
    // In case the marking complete is only in progress, we add a
    // wait() to guarantee the future is marked complete on exit.
    future_result_->wait();
    return;
  }

  try {
    // Run post processing, before marking the future as complete.
    // Drop lock prior to completing, to avoid holding across callbacks.
    std::unique_lock<std::mutex> lock(mutex_);

    exec_post_processing(); // 进行后续操作
    std::vector<Variable> vars = std::move(captured_vars_);

    // Need to unlock before we call markCompleted to avoid holding locks
    // when the callbacks are called.
    lock.unlock();
    future_result_->markCompleted(std::move(vars));  // 通知主线程
  } catch (std::exception& e) {
    future_result_->setErrorIfNeeded(std::current_exception());
  }
}

6.2.1 后续操作

后续操作,如果之前有注册了 callback,则进行调用。也会进行流同步。

void GraphTask::exec_post_processing() {
  if (!not_ready_.empty()) {
    throw std::runtime_error("could not compute gradients for some functions");
  }

  // set the thread_local current_graph_task_ as more callbacks can be installed
  // by existing final callbacks.
  GraphTaskGuard guard(shared_from_this());
  // Lock mutex during each iteration for accessing final_callbacks.size()
  // Unlocking is necessary, because the callback can register
  // more callbacks (or they can be registered from other threads
  // while it's waiting.
  std::unique_lock<std::mutex> cb_lock(final_callbacks_lock_);
  // WARNING: Don't use a range-for loop here because more callbacks may be
  // added in between callback calls, so iterators may become invalidated.
  for (size_t i = 0; i < final_callbacks_.size(); ++i) {
    cb_lock.unlock();
    final_callbacks_[i]();
    cb_lock.lock();
  }

  // Syncs leaf streams with default streams (if necessary)
  // See note "Streaming backwards"
  for (const auto& leaf_stream : leaf_streams) {
    const auto guard = c10::impl::VirtualGuardImpl{c10::DeviceType::CUDA};
    const auto default_stream = guard.getDefaultStream(leaf_stream.device());
    if (leaf_stream != default_stream) {
      auto event = c10::Event{c10::DeviceType::CUDA};
      event.record(leaf_stream);
      default_stream.wait(event);
    }
  }
}

6.2.2 通知主线程

之前在 execute 之中会用 fut->wait() 来等待任务完成。下面我们省略了部分代码。

auto Engine::execute(const edge_list& roots,
                     const variable_list& inputs,
                     bool keep_graph,
                     bool create_graph,
                     bool accumulate_grad,
                     const edge_list& outputs) -> variable_list {

  
  // Queue the root
  if (skip_dummy_node) {
    execute_with_graph_task(graph_task, graph_root, std::move(input_buffer));
  } else {
    execute_with_graph_task(graph_task, graph_root, InputBuffer(variable_list()));
  }
  auto& fut = graph_task->future_result_;
  fut->wait();
  return fut->value().toTensorVector();
}

在 mark_as_completed_and_run_post_processing 会用如下代码来通知主线程。

future_result_->markCompleted(std::move(vars));  // 通知主线程

6.3 通知其他线程

如果这个task是来自其它work thread,即 worker_device != base_owner,则向那个worker thread的queue发送一个dummy function task,让那个工作线程也执行起来。

local_graph_task 表示我们从队列中检索的 graph_task。外部graph_ 任务表示我们需要执行的可重入执行的总体graph_任务。

在 thread_main 之中,有一个 work around。就是:当前工作线程完成 graph_task,但此时,拥有graph_task的线程可能正在pop()上等待休眠。因此,我们需要向所属线程发送一个仿造的函数任务,以唤醒它,这样我们可以退出thread_main。

这种情况发生在可重入反向传播的情形。

// If worker_device is any devices (i.e. CPU, CUDA): this is a re-entrant
//    backward call from that device.
graph_task->owner_ = worker_device;

具体代码如下:

    // Check if we've completed execution.
    if (local_graph_task->completed()) {
      local_graph_task->mark_as_completed_and_run_post_processing();
      auto base_owner = local_graph_task->owner_; // 当前设备
        
      if (worker_device != base_owner) {
          
        // 不是同一个设备
          
        // Synchronize outstanding_tasks_ with queue mutex
        std::atomic_thread_fence(std::memory_order_release);
        ready_queue_by_index(local_graph_task->cpu_ready_queue_, base_owner)
            ->push(NodeTask(local_graph_task, nullptr, InputBuffer(0))); // dummy task
      }
    }

其他线程当收到了 dummy task 之后,不会处理,因为 function 是 nullptr,然后就调用 local_ready_queue->pop() 继续从自己的queue 中读取下一个 task

具体如下:

  1. 主线程等待。
  2. 如果工作线程发现GraphTask 已经结束,就通知主线程。
  3. 如果需要唤醒其他线程,就向该线程对应的 queue 插入 NodeTask。
  4. 对应线程取出 NodeTask 进行执行。
                                         +------------------------------------------------+
                                         | Worker Thread 1                                |
                                         |                                                |
                                         |  thread_main{                                  |
                                         |                                                |
                                         |     mark_as_completed_and_run_post_processing  |
                       2 markCompleted() |     {                                          |
                                 +-------------------+                                    |
                                 |       |     }                                          |
                                 |       |                                                |
+---------------+                |       |     push(NodeTask) +-----+                     |
| Main Thread   |                |       |                          |                     |
|               |                |       |   }                      |                     |
|               |                |       |                          |                     |
|               |                |       +------------------------------------------------+
|               |                |                                  |
|               |                |                                3 |
|               |                v                                  v
|               |                                           +-------+-------+
|               |   1      +----------------+               |               |
|               | wait()   |                |               |  ReadyQueue   |
|           +------------> | future_result_ |               |               |
|               |          |                |               +-------+-------+
|               |          +----------------+                       |
|               |                                                   |
|               |                                                 4 | pop(NodeTask)
|               |                                                   |
|               |                                                   v
|               |                                          +--------+---------------------+
|               |                                          | Worker Thread 2              |
|               |                                          |                              |
|               |                                          |                              |
+---------------+                                          |                              |
                                                           |                              |
                                                           |                              |
                                                           +------------------------------+

至此,后向传播已经分析完毕,从下一篇开始,我们正式进入 PyTorch 分布式训练。

0xFF 参考

https://www.zhihu.com/column/gemfield

【PyTorch】聊聊 backward 背后的代码

pytorch笔记(计算图+autograd)-Node(1)

详解Pytorch中的网络构造

PyTorch的优化器

PyTorch的分布式

PyTorch的Tensor(下)

PyTorch的Tensor(中)

PyTorch的Tensor(上)

PyTorch的动态图(下)

PyTorch的动态图(上)

PyTorch Internals 5:Autograd的实现

A GENTLE INTRODUCTION TO TORCH.AUTOGRAD

PyTorch学习笔记(12)——PyTorch中的Autograd机制介绍

PyTorch 的 Autograd

上一篇:图(Graph)-图的遍历(DFS&BFS)


下一篇:MySQL的insert ignore与replace into不同