[Source Code Analysis] PyTorch Distributed (6) -------- DistributedDataParallel -- Initialization & Store


0x00 Abstract

This is the sixth post in the PyTorch distributed series. It introduces two concepts that DistributedDataParallel depends on: the initialization method and the Store.

The other posts in this series are:

Deep Learning Toolbox: Automatic Differentiation (1)

Deep Learning Toolbox: Automatic Differentiation (2)

[Source Code Analysis] Deep Learning Toolbox: Automatic Differentiation (3) --- A Walkthrough by Example

[Source Code Analysis] How PyTorch Implements Forward Propagation (1) --- Base Classes (Part 1)

[Source Code Analysis] How PyTorch Implements Forward Propagation (2) --- Base Classes (Part 2)

[Source Code Analysis] How PyTorch Implements Forward Propagation (3) --- Concrete Implementation

[Source Code Analysis] How PyTorch Implements Backward Propagation (1) ---- Invoking the Engine

[Source Code Analysis] How PyTorch Implements Backward Propagation (2) ---- Static Structure of the Engine

[Source Code Analysis] How PyTorch Implements Backward Propagation (3) ---- Dynamic Logic of the Engine

[Source Code Analysis] How PyTorch Implements Backward Propagation (4) ---- The Concrete Algorithm

[Source Code Analysis] PyTorch Distributed (1) ------ History and Overview

[Source Code Analysis] PyTorch Distributed (2) ----- DataParallel (Part 1)

[Source Code Analysis] PyTorch Distributed (3) ----- DataParallel (Part 2)

[Source Code Analysis] PyTorch Distributed (4) ------ Basic Concepts of Distributed Applications

[Source Code Analysis] PyTorch Distributed (5) ------ DistributedDataParallel -- Overview & How to Use

0x01 Recap

1.1 Basic Concepts

For distributed communication, PyTorch provides several concepts: process group, backend, initialization, and Store.

  • Process group: DDP is true distributed training and can combine multiple machines into a single parallel job. To let the DDP workers talk to each other, PyTorch introduces the notion of a process group.
  • Backend: the backend is a logical concept; in essence it is an IPC mechanism.
  • Initialization: even with backends and process groups in place, how do workers discover each other before a process group is built? An initialization method is needed to tell every process how to reach the processes on the other machines.
  • Store: a distributed key/value store that the processes in a group use to share information, and that can also initialize the distributed package (by explicitly creating a store as an alternative to init_method).

1.2 Initializing the Process Group

Before calling any other DDP method, the package must be initialized with torch.distributed.init_process_group(). This call initializes the default distributed process group and the distributed package, and blocks until all processes have joined. The function is defined as follows:

init_process_group(backend,
                   init_method=None,
                   timeout=default_pg_timeout,
                   world_size=-1,
                   rank=-1,
                   store=None,
                   group_name='',
                   pg_options=None)

There are two main ways to initialize a process group:

  1. Explicitly specify store, rank, and world_size.
  2. Specify init_method (a URL string) that tells the processes where and how to discover their peers.

If neither is specified, init_method is assumed to be "env://". As you can see, store and init_method are mutually exclusive.

The parameters of init_process_group are:

  • backend – the backend to use. Valid values include mpi, gloo, and nccl. The field should be given as a lowercase string (e.g. "gloo"), and can also be accessed through the Backend attributes (e.g. Backend.GLOO). When using multiple processes per machine with the nccl backend, each process must have exclusive access to every GPU it uses, because sharing GPUs between processes can lead to deadlocks.
  • init_method – URL specifying how to initialize the process group. Defaults to "env://" if neither init_method nor store is specified. Mutually exclusive with store.
  • world_size – number of processes participating in the job. Required if store is specified.
  • rank – rank of the current process (a number between 0 and world_size-1). Required if store is specified.
  • store – a key/value store accessible to all workers, used to exchange connection/address information. Mutually exclusive with init_method.
  • timeout – timeout for operations executed against the process group. Defaults to 30 minutes. This applies to the gloo backend; for nccl it only takes effect when the environment variable NCCL_BLOCKING_WAIT or NCCL_ASYNC_ERROR_HANDLING is set to 1.
  • group_name – name of the group.
  • pg_options (ProcessGroupOptions, optional) – additional options to pass in when constructing a specific process group.

0x02 Initialization

2.1 Initialization Methods

The DDP module currently supports three initialization methods:

  • Environment variable initialization
  • Shared file-system initialization: init_method='file:///mnt/nfs/sharedfile'
  • TCP initialization: init_method='tcp://10.1.1.20:23456'

Environment variables

This method reads the configuration from environment variables and allows full control over how the information is obtained. By setting the following four environment variables on all machines, every process can connect to the master (the rank 0 process), obtain information about the other processes, and finally shake hands with them.

  • MASTER_PORT: a port on the machine hosting the rank 0 process.
  • MASTER_ADDR: the IP address of the machine hosting the rank 0 process.
  • WORLD_SIZE: total number of processes, so the master knows how many workers to wait for.
  • RANK: the rank of each process, so every process knows whether it is the master.
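
The four variables above can be exercised in a single-worker sanity check. A minimal sketch (assuming the gloo backend is built into your torch install and that port 29500 is free; both are assumptions, not requirements stated above):

```python
import os
import torch.distributed as dist

# Set the four variables in-process and let init_process_group pick
# them up via the default "env://" method.
os.environ["MASTER_ADDR"] = "127.0.0.1"
os.environ["MASTER_PORT"] = "29500"
os.environ["WORLD_SIZE"] = "1"
os.environ["RANK"] = "0"

dist.init_process_group(backend="gloo")  # init_method defaults to "env://"
rank, world_size = dist.get_rank(), dist.get_world_size()
print(rank, world_size)
dist.destroy_process_group()
```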

Shared file system

This approach requires that all processes have access to a shared file system, and coordinates them through a shared file. Each process opens the file, writes its information, and waits until every other process has done the same; afterwards, all the required information is available to every process. To avoid race conditions, the file system must support locking via fcntl.

dist.init_process_group(
    init_method='file:///mnt/nfs/sharedfile',
    rank=args.rank,
    world_size=4)

TCP

TCP initialization works by providing the IP and port of the rank 0 process; every worker then connects to the rank 0 process and exchanges the information needed to reach one another.

dist.init_process_group(
    init_method='tcp://10.1.1.20:23456',
    rank=args.rank,
    world_size=4)

2.2 init_method VS store

A natural question: why are there two parameters, init_method and store?

Reading the init_process_group code reveals the following pattern:

  • With the MPI backend, init_method is not used.

  • With non-MPI backends, if no store argument is given, init_method is used to build a store.

So everything ultimately comes down to the store; the store is the entity that does the actual work.

        if store is None:
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)
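
This can also be verified from the user side: handing init_process_group a ready-made store skips the init_method path entirely. A minimal single-process sketch (HashStore and the gloo backend are chosen only to keep the example self-contained; a real multi-machine job would pass a TCPStore):

```python
import torch.distributed as dist

# Build the store ourselves instead of letting init_method build it.
# HashStore is in-process, so world_size must be 1 here.
store = dist.HashStore()
dist.init_process_group(backend="gloo", store=store, rank=0, world_size=1)
initialized = dist.is_initialized()
print(initialized)
dist.destroy_process_group()
```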

The code of init_process_group is as follows:

def init_process_group(backend,
                       init_method=None,
                       timeout=default_pg_timeout,
                       world_size=-1,
                       rank=-1,
                       store=None,
                       group_name='',
                       pg_options=None):

    global _pg_group_ranks
    global _backend
    global _default_pg_init_method

    if store is not None:
        assert world_size > 0, 'world_size must be positive if using store'
        assert rank >= 0, 'rank must be non-negative if using store'
    elif init_method is None:
        init_method = "env://"

    backend = Backend(backend)

    if backend == Backend.MPI:
          default_pg = _new_process_group_helper(
            -1,
            -1,
            [],
            Backend.MPI,
            None,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)
    else:
        # backward compatible API
        if store is None:
            # If no store was given, build one from init_method.
            rendezvous_iterator = rendezvous(
                init_method, rank, world_size, timeout=timeout
            )
            store, rank, world_size = next(rendezvous_iterator)
            store.set_timeout(timeout)

        default_pg = _new_process_group_helper(
            world_size,
            rank,
            [],
            backend,
            store,
            pg_options=pg_options,
            group_name=group_name,
            timeout=timeout)
        _update_default_pg(default_pg)

    _pg_group_ranks[GroupMember.WORLD] = {i: i for i in range(GroupMember.WORLD.size())}  # type: ignore[attr-defined, index]
    _backend = _pg_map[GroupMember.WORLD][0]  # type: ignore[index]
    _default_pg_init_method = init_method

    # remainder omitted

2.3 rendezvous

The code above mentions rendezvous, so let's look at this concept.

Before we can run collective algorithms, the participating processes need to find each other and exchange information in order to communicate. We call this process rendezvous. The result of the rendezvous is a triple containing a shared key/value store, the rank of the process, and the total number of participating processes. If none of the built-in rendezvous methods fits your execution environment, you can register your own rendezvous handler: pick a unique name and identify it with a URL scheme when calling the rendezvous function.

The rendezvous function simply selects a handler based on its arguments:

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):

    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)
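
The URL surgery at the top of rendezvous is worth seeing in isolation: rank and world_size are merged into the query string before a handler is picked by scheme. A stand-alone re-creation using only the standard library (append_rank_args is an illustrative name, not a PyTorch function):

```python
from urllib.parse import urlparse, urlunparse

def append_rank_args(url, rank=-1, world_size=-1):
    # Mirrors rendezvous(): merge rank/world_size into the URL's query
    # string so that the handler can read them back later.
    result = urlparse(url)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))
    if rank != -1:
        query["rank"] = rank
    if world_size != -1:
        query["world_size"] = world_size
    result = result._replace(query="&".join(f"{k}={v}" for k, v in query.items()))
    return urlunparse(result)

print(append_rank_args("tcp://10.1.1.20:23456", rank=3, world_size=4))
```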

The handlers are shown below; as you can see, they correspond exactly to the three initialization methods:

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

2.4 Summary

From the analysis so far, we can draw the following conclusions:

  • init_method ultimately comes down to a store; the store is the entity that does the actual work.
  • Participating processes need to find each other and exchange information in order to communicate; this process is called rendezvous.

0x03 Store

Let's give a formal definition. A Store is the distributed key/value store provided by the distributed package; all workers access it to share information and to initialize the distributed package. Users can explicitly create a store as an alternative to init_method. There are currently three kinds of key/value stores: TCPStore, FileStore, and HashStore.
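
Of the three, HashStore is the easiest to poke at interactively, since it is a purely in-process map. A quick sketch (assuming a torch build with the distributed package enabled; note that get() returns bytes, and add() keeps a decimal counter under the key):

```python
import torch.distributed as dist

# HashStore: an in-memory, in-process key/value store, handy for
# single-process experiments with the Store API.
store = dist.HashStore()
store.set("first_key", "first_value")
value = store.get("first_key")   # values come back as bytes
store.add("counter", 1)          # numeric keys support atomic add
store.add("counter", 2)
counter = int(store.get("counter"))
print(value, counter)
```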

Let's continue with the handler concept from the previous section.

3.1 _rendezvous_handlers

PyTorch defines a global variable _rendezvous_handlers that stores the functions which return a store; think of it as a registry of factory methods.

_rendezvous_handlers = {}

They are registered as follows:

register_rendezvous_handler("tcp", _tcp_rendezvous_handler)
register_rendezvous_handler("env", _env_rendezvous_handler)
register_rendezvous_handler("file", _file_rendezvous_handler)

The registration code simply inserts the handler into the global variable:

def register_rendezvous_handler(scheme, handler):
    """Registers a new rendezvous handler.
    Args:
        scheme (str): URL scheme to identify your rendezvous handler.
        handler (function): Handler that is invoked when the
            `rendezvous()` function is called with a URL that uses
            the corresponding scheme. It must be a generator function
            that yields the triplet.
    """
    global _rendezvous_handlers
    if scheme in _rendezvous_handlers:
        raise RuntimeError(
            "Rendezvous handler for {}:// already registered".format(scheme)
        )
    _rendezvous_handlers[scheme] = handler
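
The registry pattern above is small enough to re-create in full. The following toy version (all names are illustrative, none are PyTorch's) shows the same scheme-to-generator dispatch that rendezvous() performs:

```python
from urllib.parse import urlparse

# A scheme -> generator-function map, plus a dispatcher, mimicking
# _rendezvous_handlers and rendezvous().
_handlers = {}

def register_handler(scheme, handler):
    if scheme in _handlers:
        raise RuntimeError(f"handler for {scheme}:// already registered")
    _handlers[scheme] = handler

def dispatch(url):
    # Pick the handler by the URL scheme, just like rendezvous() does.
    return _handlers[urlparse(url).scheme](url)

def fake_tcp_handler(url):
    # A real handler would build a store; here we just yield the triple.
    yield ("fake-store", 0, 1)

register_handler("tcp", fake_tcp_handler)
store, rank, world_size = next(dispatch("tcp://127.0.0.1:23456"))
print(store, rank, world_size)
```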

3.2 handlers

Looking closely at the handlers, you will find that each one simply returns a particular store. For example, _tcp_rendezvous_handler uses the given information to build a TCPStore and returns it.

Non-essential code is omitted in the listings below.

3.2.1 _file_rendezvous_handler

This one returns a FileStore.

def _file_rendezvous_handler(url: str, **kwargs):

    result = urlparse(url)
    path = result.path
    query: Dict[str, str]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    store = FileStore(path, world_size)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using file:// method")

3.2.2 _tcp_rendezvous_handler

This one returns a TCPStore.

def _tcp_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):
    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    # mypy doesn't allow dict() to accept List of values (#257)
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&")))  # type: ignore[misc, arg-type]

    rank = int(query["rank"])
    world_size = int(query["world_size"])
    start_daemon = rank == 0
    assert result.hostname is not None
    store = TCPStore(result.hostname, result.port, world_size, start_daemon, timeout)
    yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using tcp:// method")

3.2.3 _env_rendezvous_handler

Interestingly, this one also returns a TCPStore, but it extracts the required information from environment variables.

def _env_rendezvous_handler(url: str, timeout: timedelta = default_pg_timeout, **kwargs):

    result = urlparse(url)
    query: Dict[str, Union[int, str]]
    query = dict(pair.split("=") for pair in filter(None, result.query.split("&"))) 
    rank: Optional[Union[str, int]]
    world_size: Optional[Union[str, int]]
    master_port: Optional[Union[str, int]]

    if "rank" in query:
        rank = int(query["rank"])
    else:
        rank = int(_get_env_or_raise("RANK"))

    if "world_size" in query:
        world_size = int(query["world_size"])
    else:
        world_size = int(_get_env_or_raise("WORLD_SIZE"))

    master_addr = _get_env_or_raise("MASTER_ADDR")
    master_port = int(_get_env_or_raise("MASTER_PORT"))

    use_torchelastic_store = os.environ.get("TORCHELASTIC_USE_AGENT_STORE", None)

    if use_torchelastic_store == str(True):
        worker_process_prefix = "/worker"
        # When TORCHELASTIC_USE_AGENT_STORE is set up, the worker process is assumed
        # to be invoked by the torchelastic agent. Torchelastic agent creates a tcp daemon thread
        # on the GROUP_RANK=0, as a result all user worker processes should create store with: daemon=False
        tcp_store = TCPStore(master_addr, master_port, world_size, False, timeout)
        yield (PrefixStore(worker_process_prefix, tcp_store), rank, world_size)
    else:
        # Start the TCP store daemon on the rank 0
        start_daemon = rank == 0
        store = TCPStore(master_addr, master_port, world_size, start_daemon, timeout)
        yield (store, rank, world_size)

    # If this configuration is invalidated, there is nothing we can do about it
    raise RuntimeError("Unable to perform rerendezvous using env:// method")

3.3 Usage

3.3.1 Using the handlers

How are the handlers used? Inside init_process_group we find:

rendezvous_iterator = rendezvous(
    init_method, rank, world_size, timeout=timeout
)
store, rank, world_size = next(rendezvous_iterator)

rendezvous selects a _rendezvous_handler based on init_method, and the handler returns a store.

def rendezvous(url: str, rank: int = -1, world_size: int = -1, **kwargs):
    # Append node-specific arguments.
    result = urlparse(url)
    if rank != -1 or world_size != -1:
        query_dict: Dict[str, Union[int, str]] = dict(
            # mypy doesn't allow dict() to accept List of values (#257)
            pair.split("=") for pair in filter(None, result.query.split("&"))  # type: ignore[arg-type, misc]
        )
        if rank != -1:
            query_dict["rank"] = rank
        if world_size != -1:
            query_dict["world_size"] = world_size

        result = result._replace(
            query="{}".format("&".join(["{}={}".format(k, v) for k, v in query_dict.items()]))
        )
        url = urlunparse(result)

    return _rendezvous_handlers[result.scheme](url, **kwargs)

3.3.2 Using the Store

Let's continue with how the store is used. In the init_process_group code, the store is next used to initialize the process group.

default_pg = _new_process_group_helper(
    world_size,
    rank,
    [],
    backend,
    store,
    pg_options=pg_options,
    group_name=group_name,
    timeout=timeout)
_update_default_pg(default_pg)

3.3.2.1 _new_process_group_helper

Before diving into _new_process_group_helper, let's first look at a few global variables. The following variables store ProcessGroup information globally, e.g. _pg_map[pg] = (Backend.NCCL, store).

# Cached process groups
# For NCCL and GLOO pg, it is a map from ProcessGroup to (Backend, Store)
# For MPI pg, it is a map from ProcessGroup to (Backend, None)
_pg_map: Dict[ProcessGroup, Tuple[str, Optional[Store]]] = {}
# Process group's names, map from ProcessGroup to str
_pg_names: Dict[ProcessGroup, str] = {}
# Process group's global rank to local rank mapping
_pg_group_ranks: Dict[ProcessGroup, Dict[int, int]] = {}

Once _new_process_group_helper receives the store argument, it builds a prefix_store from it and then uses that prefix_store to construct, for example, a ProcessGroupGloo. The code of _new_process_group_helper is as follows:

def _new_process_group_helper(world_size,
                              rank,
                              group_ranks,
                              backend,
                              store,
                              pg_options=None,
                              group_name=None,
                              timeout=default_pg_timeout):
    """
    Create a new distributed process group.

    This function must be called by ALL processes in the global group, even if
    the calling process is not part of the newly created group. In that case,
    this function returns GroupMember.NON_GROUP_MEMBER.

    This function is called with ``group_ranks == []`` for the default group.
    """
    global _pg_map
    global _group_count
    global _pg_names

    if not group_name:
        group_name = str(_group_count)
        _group_count += 1

    # The list of group ranks is empty if we're creating the default group.
    is_default_group = (len(group_ranks) == 0)

    backend = Backend(backend)
    pg: Union[ProcessGroupGloo, ProcessGroupMPI, ProcessGroupNCCL]
    if backend == Backend.MPI: # the store is not used here
        pg = ProcessGroupMPI.create(group_ranks)
        if not pg:
            return GroupMember.NON_GROUP_MEMBER
        _pg_map[pg] = (Backend.MPI, None)
        _pg_names[pg] = group_name
    else:
        # this is where the store is used

        # If this is a subgroup (which means group_ranks is specified),
        # we check if the current process is a member of the new group.
        if not is_default_group:
            global_rank = _get_default_group().rank()
            if global_rank not in group_ranks:
                return GroupMember.NON_GROUP_MEMBER

        # Use the group name as prefix in the default store, such that
        # a single store can be reused by multiple groups.
        
        prefix_store = PrefixStore(group_name, store) # build a PrefixStore

        if backend == Backend.GLOO:
            pg = ProcessGroupGloo(
                prefix_store, # build the process group on top of the PrefixStore
                rank,
                world_size,
                timeout=timeout)
            _pg_map[pg] = (Backend.GLOO, store)
            _pg_names[pg] = group_name
        elif backend == Backend.NCCL:
            if pg_options is not None:
                assert isinstance(pg_options, ProcessGroupNCCL.Options), \
                    "Expected pg_options argument to be of type ProcessGroupNCCL.Options"
            else:
                # default pg_options for NCCL
                pg_options = ProcessGroupNCCL.Options()
                pg_options.is_high_priority_stream = False
                pg_options._timeout = timeout

            pg = ProcessGroupNCCL(
                prefix_store, # build the process group on top of the PrefixStore
                rank,
                world_size,
                pg_options)
            _pg_map[pg] = (Backend.NCCL, store)
            _pg_names[pg] = group_name
        else:
            pg = getattr(Backend, backend.upper())(
                prefix_store,
                rank,
                world_size,
                timeout)
            _pg_map[pg] = (backend, store)
            _pg_names[pg] = group_name

    return pg
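
The PrefixStore trick in the code above ("use the group name as prefix in the default store, such that a single store can be reused by multiple groups") can be observed directly from Python. A sketch (assuming c10d's usual "prefix/key" joining; the group names "0" and "1" are just examples):

```python
import torch.distributed as dist

# PrefixStore wraps another store and prepends "<prefix>/" to every key,
# which is how one underlying store is shared by multiple process groups.
base = dist.HashStore()
pg0 = dist.PrefixStore("0", base)
pg1 = dist.PrefixStore("1", base)

pg0.set("lr", "0.1")
pg1.set("lr", "0.01")

# The two groups do not collide; the base store sees the prefixed keys.
print(base.get("0/lr"), base.get("1/lr"))
```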

3.3.2.2 ProcessGroupGloo

ProcessGroupGloo puts the store to concrete use: for example, it wraps the PrefixStore in a GlooStore and uses the PrefixStore to establish the network.

ProcessGroupGloo::ProcessGroupGloo(
    const c10::intrusive_ptr<Store>& store,
    int rank,
    int size,
    c10::intrusive_ptr<Options> options)
    : ProcessGroup(rank, size),
      store_(new GlooStore(store)), // wrap the incoming PrefixStore in a GlooStore
      options_(options),
      stop_(false),
      collectiveCounter_(0) {
  auto& devices = options->devices;

  contexts_.reserve(options->devices.size());
  for (size_t i = 0; i < options->devices.size(); i++) {
    auto context = std::make_shared<::gloo::rendezvous::Context>(rank_, size_);
    // build yet another PrefixStore
    auto store = ::gloo::rendezvous::PrefixStore(std::to_string(i), *store_);
    context->setTimeout(options->timeout);
    // use the PrefixStore to establish the network
    context->connectFullMesh(store, options->devices[i]);
    contexts_.push_back(std::move(context));
  }

  // Every worker thread stores the AsyncWork object it's currently
  // working on in the workInProgress_ vector. It must have size equal
  // to the number of workers such that they can simply index into it
  // using the worker index they are started with.
  workInProgress_.resize(options->threads);

  threads_.resize(options->threads);
  for (size_t i = 0; i < threads_.size(); i++) {
    threads_[i] = std::thread(&ProcessGroupGloo::runLoop, this, i);
  }
}

The code below also uses store_, for example to wait on keys and to set and get values.

void ProcessGroupGloo::setSequenceNumberForGroup() {
  if (rank_ == 0) {
    // Create and broadcast sequence number
    auto seq = 1 + rand();
    sequenceNum_ = c10d::SequenceNum(seq);
    std::vector<char> values = c10d::toVec<char>(seq, kBytes);
    store_->set(kSeqNumStoreKey, values); // store the value
  } else {
    // Read rank 0's sequence number from store.
    sequenceNum_ = c10d::SequenceNum();
    store_->wait({kSeqNumStoreKey}, options_->timeout); // wait
    std::vector<char> values = store_->get(kSeqNumStoreKey); // fetch the value
    uint64_t num = c10d::fromVec<char>(values);
    sequenceNum_->set(num);
  }
}  

3.4 Summary

Based on the analysis so far, we can extend our conclusions as follows:

  • init_method ultimately comes down to a store; the store is the entity that does the actual work.
  • Participating processes need to find each other and exchange information in order to communicate; this process is called rendezvous.
  • rendezvous simply returns one kind of store for subsequent communication to use.
  • Within a process group, the store is used to build communication and to wait for, set, and get values.

Next we pick TCPStore for a detailed analysis.

0x04 TCPStore

TCPStore is a TCP-based distributed key/value store implementation. The server stores/saves the data, while store clients connect to the server over TCP and perform operations such as set() to insert a key/value pair and get() to retrieve one. There should be one fully initialized TCPStore server in the system, because the store clients will wait for this server in order to establish their connections.

The parameters of TCPStore are:

  • host_name (str) – hostname or IP address on which the store server runs.
  • port (int) – port on which the store server listens for incoming requests.
  • world_size (int, optional) – total number of users.
    • world_size = number of clients + 1, where the 1 is the server.
    • Defaults to -1 (a negative value means a non-fixed number of users).
  • is_master (bool, optional) – True when initializing the store server, False when initializing a store client. Defaults to False.
  • timeout (timedelta, optional) – timeout used by the store during initialization and by the get() and wait() methods. Defaults to timedelta(seconds=300).
  • wait_for_worker (bool, optional) – whether to wait for all workers to connect to the store server. Only applicable when world_size is a fixed value. Defaults to True.

A usage example:

import torch.distributed as dist
from datetime import timedelta
# Run on process 1 (server)
server_store = dist.TCPStore("127.0.0.1", 1234, 2, True, timedelta(seconds=30))
# Run on process 2 (client)
client_store = dist.TCPStore("127.0.0.1", 1234, 2, False)
# Use any of the store methods from either the client or server after initialization
server_store.set("first_key", "first_value")
client_store.get("first_key")

or

    >>> import torch.distributed as dist
    >>> from datetime import timedelta
    >>> # Using TCPStore as an example, other store types can also be used
    >>> store = dist.TCPStore("127.0.0.1", 0, 1, True, timedelta(seconds=30))
    >>> # This will throw an exception after 10 seconds
    >>> store.wait(["bad_key"], timedelta(seconds=10))

From the examples, this is a simple server/client (or master/worker) relationship; let's analyze it in detail.

4.1 TCPStore in python

In the Python world, this class simply holds the host and port.

class TCPStore(Store):
    def __init__(self, host_name, port, world_size=-1, is_master=False, timeout=None, *args, **kwargs): # real signature unknown; NOTE: unreliably restored from __doc__ 
        pass

    host = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the hostname on which the store listens for requests."""

    port = property(lambda self: object(), lambda self, v: None, lambda self: None)  # default
    """Gets the port number on which the store listens for requests."""

We need to descend into the C++ world to see more.

4.2 TCPStore in CPP

4.2.1 The API surface

First, the C++ TCPStore can be regarded as an API interface, defined as follows:

class TCPStore : public Store {
 public:
  explicit TCPStore(
      const std::string& masterAddr,
      PortType masterPort,
      c10::optional<int> numWorkers = c10::nullopt_t(-1),
      bool isServer = false,
      const std::chrono::milliseconds& timeout = kDefaultTimeout,
      bool waitWorkers = true);

  virtual ~TCPStore();

  void set(const std::string& key, const std::vector<uint8_t>& value) override;
  std::vector<uint8_t> compareSet(
      const std::string& key,
      const std::vector<uint8_t>& expectedValue,
      const std::vector<uint8_t>& desiredValue) override;
  std::vector<uint8_t> get(const std::string& key) override;
  int64_t add(const std::string& key, int64_t value) override;
  bool deleteKey(const std::string& key) override;

  // NOTE: calling other TCPStore APIs inside the callback is NOT threadsafe
  // watchKey() is a blocking operation. It will register the socket on
  // TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon. It will
  // return once it has verified the callback is registered on both background
  // threads. Only one thread can call watchKey() at a time.
  void watchKey(const std::string& key, WatchKeyCallback callback) override;
  bool check(const std::vector<std::string>& keys) override;
  int64_t getNumKeys() override;
  void wait(const std::vector<std::string>& keys) override;
  void wait(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout) override;
  // Waits for all workers to join.
  void waitForWorkers();
  // Returns the hostname used by the TCPStore.
  const std::string& getHost() const noexcept;
  // Returns the port used by the TCPStore.
  PortType getPort() const noexcept;

 private:
  int64_t addHelper_(const std::string& key, int64_t value);
  std::vector<uint8_t> getHelper_(const std::string& key);
  void waitHelper_(
      const std::vector<std::string>& keys,
      const std::chrono::milliseconds& timeout);

  std::mutex watchKeyMutex_;
  bool isServer_;
  int storeSocket_ = -1; // connects to the master port; wraps set/get etc.
  int listenSocket_ = -1; // connects to the master port; used for watchKey
  int masterListenSocket_ = -1; // the master listens here

  std::string tcpStoreAddr_;
  PortType tcpStorePort_;

  c10::optional<int> numWorkers_;
  const std::string initKey_;
  const std::string regularPrefix_;

  std::unique_ptr<TCPStoreMasterDaemon> tcpStoreMasterDaemon_ = nullptr;
  std::unique_ptr<TCPStoreWorkerDaemon> tcpStoreWorkerDaemon_ = nullptr;
};

4.2.2 The role of the sockets

Among its member variables, the three sockets are the most important; they are the essence (and the tricky part) of the store.

  int storeSocket_ = -1; // connects to the master port; wraps set/get etc.
  int listenSocket_ = -1; // connects to the master port; used for watchKey
  int masterListenSocket_ = -1; // the master listens here

4.2.2.1 Division of labor

In detail (we will revisit this with the code later):

  • masterListenSocket_ listens on masterPort.
    • tcpStoreMasterDaemon_ is the master, i.e. the server that serves the whole TCPStore.
    • tcpStoreMasterDaemon_ uses tcputil::addPollfd(fds, storeListenSocket_, POLLIN) to monitor masterListenSocket_.
    • The key/value data lives in std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.
  • storeSocket_ lives on tcpStoreWorkerDaemon_ and connects to masterListenSocket_ : masterPort.
    • storeSocket_ encapsulates the operations against the master port: users just call set, get, etc. without needing to know about the master port.
    • set(key, data) sends a "set key : value" request to the master through storeSocket_.
    • tcpStoreMasterDaemon_ notices the socket activity and responds.
    • tcpStoreMasterDaemon_ then adds key : value to std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.
  • listenSocket_ also lives on tcpStoreWorkerDaemon_ and also connects to masterListenSocket_ : masterPort. There is a decoupling here, as the comment says: "It will register the socket on TCPStoreMasterDaemon and the callback on TCPStoreWorkerDaemon".
    • listenSocket_ encapsulates the handling of watchKey. The store client requests registration with watchKey(const std::string& key, WatchKeyCallback callback), i.e.:
      • The worker requests registration. It uses tcpStoreWorkerDaemon_->setCallback(regKey, callback) to add a callback to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
      • The worker sends the request. Through listenSocket_ it sends (key, WATCH_KEY) to the master, telling it: if the value of key changes, invoke this callback.
    • The master performs the registration. Upon receiving the WATCH_KEY message, it calls watchHandler and records watchedSockets_[key].push_back(socket), reminding itself: if this key changes, send a message to this socket.
    • The master notifies the worker. In TCPStoreMasterDaemon::setHandler, after a new value is set, sendKeyUpdatesToClients is called; it iterates over watchedSockets_[key] and sends a change notification to every registered socket.
    • The worker runs the callback. So when the key changes, the callback is invoked inside tcpStoreWorkerDaemon_.
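
The register/notify flow above condenses into a small in-process model. The following toy code (purely illustrative; ToyStore and its members are not PyTorch names, and the real TCPStore does this across sockets and two daemon threads) keeps only the data flow: a callback table playing keyToCallbacks_, a watcher list playing watchedSockets_, and set() playing setHandler plus sendKeyUpdatesToClients:

```python
from collections import defaultdict

class ToyStore:
    def __init__(self):
        self._data = {}
        self._watchers = defaultdict(list)  # plays the role of watchedSockets_

    def watch_key(self, key, callback):
        # Worker-side registration: remember which callback to fire.
        self._watchers[key].append(callback)

    def set(self, key, value):
        # Master-side setHandler: store the value, then notify watchers
        # (the stand-in for sendKeyUpdatesToClients).
        old = self._data.get(key)
        self._data[key] = value
        for cb in self._watchers[key]:
            cb(key, old, value)

events = []
store = ToyStore()
store.watch_key("k", lambda k, old, new: events.append((k, old, new)))
store.set("k", "v1")
store.set("k", "v2")
print(events)
```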

4.2.2.2 A Set example

First, a Set example: the worker sets a value on the master through a socket.

                                                                          +
+----------------------------------------------------------------------+  |  +----------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                              Worker |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|                                                                      |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |                                              |
|   |                                                            |     |  |  |                                              |
|   |    TCPStore.masterListenSocket_                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |      | set(key, value)                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |    tcpStore_[key] = value  <------------------------------------------------+ |    storeSocket_                 |     |
|   |                                                            |     |  |  |      |                                 |     |
|   |                                                            |     |  |  |      +---------------------------------+     |
|   |                                                            |     |  |  |                                              |
|   +------------------------------------------------------------+     |  |  |                                              |
|                                                                      |  |  |                                              |
+----------------------------------------------------------------------+  |  +----------------------------------------------+
                                                                          +


4.2.2.3 Combining Set and watchKey

The diagram below shows Set and watchKey combined (the worker requests registration and runs the callback; the master performs the registration and notifies the worker):

  1. The worker requests registration. The store client's watchKey(const std::string& key, WatchKeyCallback callback) calls tcpStoreWorkerDaemon_->setCallback(regKey, callback) to add a callback to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
  2. The worker sends the request. Through listenSocket_ it sends (key, WATCH_KEY) to the master, telling it: if the value of key changes, invoke this callback.
  3. The master performs the registration. On receiving WATCH_KEY it calls watchHandler and records watchedSockets_[key].push_back(socket): if this key changes, send a message to this socket.
  4. Now suppose a store client (assumed here to be the same worker that registered; in practice it could be a different worker) sets a value.
  5. The master notifies the worker. In TCPStoreMasterDaemon::setHandler, after the new value is set, sendKeyUpdatesToClients iterates over watchedSockets_[key] and sends a change notification to every registered socket.
  6. The worker runs the callback. When the key changes, the callback is invoked inside tcpStoreWorkerDaemon_.

+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |      +---------------------------------+                               |
|   |                                                            |     |  |  |      |                                 |                               |
|   |                                                  2         |     |  |  |      | watchKey(key, callback) +----------------------+                |
|   |           TCPStore.masterListenSocket_   <----------------------------------+ |                                 |              |                |
|   |                       +                                    |     |  |  |      |    listenSocket_                |              |                |
|   |                       | 3                                  |     |  |  |      |                                 |            1 |                |
|   |                       v                                    |     |  |  |      |                                 |              |                |
|   |           watchedSockets_[key] = socket                    |     |  |  |      +---------------------------------+              |                |
|   |                                                            |     |  |  |                                                       |                |
|   |  +-------------------------------------------------+       |     |  |  |                                                       |                |
|   |  |                                                 |       |     |  |  |                                                       |                |
|   |  |    setHandler                                   |       |     |  |  |   +----------------------------------------------------------------+   |
|   |  |                                                 |       |     |  |  |   | TCPStoreWorkerDaemon                              |            |   |
|   |  |                                                 |       |     |  |  |   |                                                   v            |   |
|   |  |       tcpStore_[key] = newData                  |       |     |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_      |   |
|   |  |                   +                             |       |     |  |  |   |                                                                |   |
|   |  |                   |                             |       |     |  |  |   |   TCPStore.listenSocket_                                       |   |
|   |  |                   |                             |       |     |  |  |   |                                                                |   |
|   |  |                   v                             |       |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |       sendKeyUpdatesToClients                   |       |     |  |  |   |  | run                                                      |  |   |
|   |  |                   +                             |       |  5  |  |  |   |  |                                                          |  |   |
|   |  |                   |                             |  +---------------------->+                                        6                 |  |   |
|   |  |                   |                             |  |    |     |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|   |  |                   v                             |  |    |     |  |  |   |  |                                                          |  |   |
|   |  |                                                 |  |    |     |  |  |   |  +----------------------------------------------------------+  |   |
|   |  |    for (int socket : watchedSockets_[key]){     |  |    |     |  |  |   +----------------------------------------------------------------+   |
|   |  |       tcputil::sendString(socket, key, true) +-----+    |     |  |  |                                                                        |
|   |  |    }                                            |       |     |  |  |                                                                        |
|   |  |                                                 |       |     |  |  |       +------------------------+                                       |
|   |  |                                                 |       |  4  |  |  |       | set(key, newData)      |                                       |
|   |  |                                                 | <-----------------------+ |                        |                                       |
|   |  +-------------------------------------------------+       |     |  |  |       |                        |                                       |
|   |                                                            |     |  |  |       +------------------------+                                       |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+


4.2.3 Utility functions

TCPStore provides a number of utility functions.

void TCPStore::set(const std::string& key, const std::vector<uint8_t>& data) {
  std::string regKey = regularPrefix_ + key;
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::SET);
  tcputil::sendString(storeSocket_, regKey, true);
  tcputil::sendVector<uint8_t>(storeSocket_, data);
}

std::vector<uint8_t> TCPStore::get(const std::string& key) {
  std::string regKey = regularPrefix_ + key;
  return getHelper_(regKey);
}

int64_t TCPStore::add(const std::string& key, int64_t value) {
  std::string regKey = regularPrefix_ + key;
  return addHelper_(regKey, value);
}

int64_t TCPStore::addHelper_(const std::string& key, int64_t value) {
  tcputil::sendValue<QueryType>(storeSocket_, QueryType::ADD);
  tcputil::sendString(storeSocket_, key, true);
  tcputil::sendValue<int64_t>(storeSocket_, value);
  return tcputil::recvValue<int64_t>(storeSocket_);
}

These utility functions rely on the following low-level primitives to send and receive data.

// this is only for convenience when sending rvalues
template <typename T>
void sendValue(int socket, const T& value, bool moreData = false) {
  sendBytes<T>(socket, &value, 1, moreData);
}

template <typename T>
T recvValue(int socket) {
  T value;
  recvBytes<T>(socket, &value, 1);
  return value;
}
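These primitives ultimately write length-prefixed frames onto the socket: a query is a type tag followed by size/payload pairs (see the comment on query() later in this section). Below is a minimal Python sketch of that framing idea; the helper names, integer widths, and the numeric QueryType value are illustrative assumptions, not the actual tcputil layout.

```python
import struct

# Illustrative framing: little-endian uint64 length prefix before each payload.
def send_string(buf: bytearray, s: str) -> None:
    data = s.encode()
    buf += struct.pack("<Q", len(data)) + data

def send_vector(buf: bytearray, data: bytes) -> None:
    buf += struct.pack("<Q", len(data)) + data

def recv_string(buf: bytes, offset: int):
    (n,) = struct.unpack_from("<Q", buf, offset)
    start = offset + 8
    return buf[start:start + n].decode(), start + n

# Frame a SET query: type | size(key) | key | size(value) | value
SET = 2  # made-up numeric value standing in for QueryType::SET
buf = bytearray()
buf += struct.pack("<i", SET)
send_string(buf, "/init/")
send_vector(buf, b"1")

# Parse it back, as the receiving daemon would
(qt,) = struct.unpack_from("<i", buf, 0)
key, off = recv_string(bytes(buf), 4)
```

The key point is that both sides agree on the frame layout, so each handler can peel off exactly the arguments its query type expects.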

4.2.4 Constructor

From the constructor we can see:

  • For the server role, the main job is starting tcpStoreMasterDaemon_. Note that after starting this daemon, the server blocks in waitForWorkers() until every worker has joined, and only then runs the code further down that starts tcpStoreWorkerDaemon_ (the worker daemon is launched on all instances, including the server).
  • For the client role, the constructor starts tcpStoreWorkerDaemon_.
// TCPStore class methods
TCPStore::TCPStore(
    const std::string& masterAddr,
    PortType masterPort,
    c10::optional<int> numWorkers,
    bool isServer,
    const std::chrono::milliseconds& timeout,
    bool waitWorkers)
    : Store(timeout),
      isServer_(isServer),
      tcpStoreAddr_(masterAddr),
      tcpStorePort_(masterPort),
      numWorkers_(numWorkers),
      initKey_("init/"),
      regularPrefix_("/") {
  tcputil::socketInitialize();
  if (isServer_) { // if this is the server, listen on masterPort
    // Opening up the listening socket
    std::tie(masterListenSocket_, tcpStorePort_) = tcputil::listen(masterPort);
  }
  try {
    if (isServer_) { // if this is the server, start tcpStoreMasterDaemon_
      // Now start the daemon
      tcpStoreMasterDaemon_ =
          std::make_unique<TCPStoreMasterDaemon>(masterListenSocket_);
    }
    // Connect to the daemon
    // every instance (worker and server alike) connects to the master port
    storeSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    if (numWorkers.value_or(-1) >= 0 && waitWorkers) {
      waitForWorkers(); // the server waits for the workers
    }

    // socket to handle requests from the server, since the master also sends messages to workers
    listenSocket_ = tcputil::connect(
        tcpStoreAddr_, tcpStorePort_, /* wait= */ true, timeout_);
    // start the worker daemon
    tcpStoreWorkerDaemon_ =
        std::make_unique<TCPStoreWorkerDaemon>(listenSocket_);
  } catch (const std::exception&) {
    if (isServer_) {
      tcpStoreMasterDaemon_ = nullptr;
      tcputil::closeSocket(masterListenSocket_);
    }
    tcpStoreWorkerDaemon_ = nullptr;
    if (listenSocket_ != -1) {
      tcputil::closeSocket(listenSocket_);
    }
    if (storeSocket_ != -1) {
      tcputil::closeSocket(storeSocket_);
    }
    throw;
  }
}

The server uses the following function to wait for workers.

void TCPStore::waitForWorkers() {
  addHelper_(initKey_, 1);
  // Let server block until all workers have completed, this ensures that
  // the server daemon thread is always running until the very end
  if (isServer_) {
    const auto start = std::chrono::steady_clock::now();
    while (true) {
      std::vector<uint8_t> value = getHelper_(initKey_);
      auto buf = reinterpret_cast<const char*>(value.data());
      auto len = value.size();
      int numWorkersCompleted = std::stoi(std::string(buf, len));
      if (numWorkersCompleted >= numWorkers_.value_or(-1)) {
        break;
      }
      const auto elapsed = std::chrono::duration_cast<std::chrono::seconds>(
          std::chrono::steady_clock::now() - start);
      if (timeout_ != kNoTimeout && elapsed > timeout_) {
        break;
      }
      /* sleep override */
      std::this_thread::sleep_for(std::chrono::milliseconds(10));
    }
  }
}
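waitForWorkers boils down to a counter rendezvous: every instance increments the value stored under initKey_, and only the server polls that counter until it reaches numWorkers (or the timeout expires). A condensed Python sketch of this logic against an in-memory stand-in for the store (all names here are illustrative, not the real API):

```python
import threading
import time

class MiniStore:
    """In-memory stand-in for the TCPStore key-value table (illustrative)."""
    def __init__(self):
        self._kv = {}
        self._lock = threading.Lock()

    def add(self, key, amount):
        # mirrors the ADD query: the counter is stored as a decimal string
        with self._lock:
            new = int(self._kv.get(key, b"0")) + amount
            self._kv[key] = str(new).encode()
            return new

    def get(self, key):
        with self._lock:
            return self._kv.get(key, b"0")

def wait_for_workers(store, num_workers, timeout_s=5.0):
    # mirrors TCPStore::waitForWorkers: poll initKey until enough have joined
    deadline = time.monotonic() + timeout_s
    while int(store.get("init/")) < num_workers:
        if time.monotonic() > deadline:
            raise TimeoutError("workers did not join in time")
        time.sleep(0.01)

store = MiniStore()
num_workers = 4
# each "worker" announces itself by incrementing the init counter
workers = [threading.Thread(target=store.add, args=("init/", 1))
           for _ in range(num_workers)]
for t in workers:
    t.start()
wait_for_workers(store, num_workers)   # the "server" blocks here
for t in workers:
    t.join()
```

This also explains why the real code blocks the server until the end: the polling loop keeps the server daemon thread alive while the rest of the group initializes.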

4.2.5 TCPStoreWorkerDaemon

This daemon thread is used only to handle watchKey callbacks.

// Separate thread that is launched on all instances (including master)
// Right now only handles callbacks registered from watchKey()
class TCPStoreWorkerDaemon : public BackgroundThread {
 public:
  explicit TCPStoreWorkerDaemon(int listenSocket);
  // Set the callback to run key change
  void setCallback(std::string key, WatchKeyCallback cb);
  void waitForCallbackRegistration() {
    // Block until callback has been registered successfully
    std::unique_lock<std::mutex> callbackRegistrationLock(
        callbackRegistrationMutex_);
    callbackRegisteredCV_.wait(
        callbackRegistrationLock, [&] { return callbackRegisteredData_; });

    // Reset payload for next callback
    callbackRegisteredData_ = false;
  }
  void setCallbackRegistered() {
    callbackRegisteredData_ = true;
    callbackRegisteredCV_.notify_one();
  }

 private:
  void run();
  void callbackHandler(int socket);
  // List of callbacks map each watched key
  std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_;
  std::mutex keyToCallbacksMutex_;
  std::mutex callbackRegistrationMutex_;
  std::condition_variable callbackRegisteredCV_;
  bool callbackRegisteredData_ = false;
};


Its constructor simply spawns a thread.

// TCPStoreListener class methods
TCPStoreWorkerDaemon::TCPStoreWorkerDaemon(int listenSocket)
    : BackgroundThread(listenSocket) {
  daemonThread_ = std::thread(&TCPStoreWorkerDaemon::run, this);
}
4.2.5.1 watchKey

The client store uses watchKey(const std::string& key, WatchKeyCallback callback) to register a watched key with the master:

  • The worker registers the callback locally. tcpStoreWorkerDaemon_->setCallback(regKey, callback) adds an entry to tcpStoreWorkerDaemon_'s std::unordered_map<std::string, WatchKeyCallback> keyToCallbacks_.
  • The worker sends the request. Over listenSocket_ it sends (WATCH_KEY, key) to the master, asking to be notified whenever the value of key changes.
  • It then calls waitForCallbackRegistration to block until registration completes.
void TCPStore::watchKey(const std::string& key, WatchKeyCallback callback) {
  // Only allow one thread to perform watchKey() at a time
  const std::lock_guard<std::mutex> watchKeyLock(watchKeyMutex_);

  // Register callback with TCPStoreMasterDaemon to call TCPStoreWorkerDaemon on
  // key change
  std::string regKey = regularPrefix_ + key;
  tcpStoreWorkerDaemon_->setCallback(regKey, callback);
  tcputil::sendValue<QueryType>(listenSocket_, QueryType::WATCH_KEY);
  tcputil::sendString(listenSocket_, regKey);

  // Block until callback has been registered successfully
  tcpStoreWorkerDaemon_->waitForCallbackRegistration();
}
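Stripped of the socket plumbing, watchKey does three things: record the callback locally, ask the master to remember this client, and block until the acknowledgement arrives. The condition-variable handshake can be sketched in Python as follows (a single-process toy; the master's KEY_CALLBACK_REGISTERED response is simulated by a thread, and all names are illustrative):

```python
import threading

class WatchRegistry:
    """Client-side sketch: keyToCallbacks_ plus the registration handshake."""
    def __init__(self):
        self.key_to_callbacks = {}
        self._cv = threading.Condition()
        self._registered = False

    def set_callback(self, key, cb):
        self.key_to_callbacks[key] = cb

    def wait_for_callback_registration(self):
        # block until the master acknowledges, then reset for the next call
        with self._cv:
            self._cv.wait_for(lambda: self._registered)
            self._registered = False

    def set_callback_registered(self):
        with self._cv:
            self._registered = True
            self._cv.notify()

registry = WatchRegistry()
events = []

# 1. register the callback locally (setCallback)
registry.set_callback("/foo", lambda old, new: events.append((old, new)))

# 2. the master's KEY_CALLBACK_REGISTERED ack, simulated by a thread
ack = threading.Thread(target=registry.set_callback_registered)
ack.start()

# 3. block until registration is acknowledged (waitForCallbackRegistration)
registry.wait_for_callback_registration()
ack.join()

# later, a key-change notification would invoke the stored callback
registry.key_to_callbacks["/foo"](b"", b"bar")
```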
4.2.5.2 Run

The run loop has a Windows variant and a variant for other systems, but both do the same thing: receive a key-change message and run the matching callback.

  • The master registers the watch. On receiving a WATCH_KEY message, the master calls watchHandler, which records watchedSockets_[key].push_back(socket): if this key changes, send a message to that socket.
  • The master notifies the workers. In TCPStoreMasterDaemon::setHandler, after a new value is set, sendKeyUpdatesToClients iterates over watchedSockets_[key] and sends a change notification to every registered socket.
  • The worker runs the callback. When the notification arrives, tcpStoreWorkerDaemon_ invokes the registered callback.
#ifdef _WIN32 
void TCPStoreWorkerDaemon::run() { // Windows variant
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    // Check control and exit early if triggered
    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rvPoll = WaitForSingleObject(ghStopEvent_, 0);
      if (rvPoll != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[0].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      auto rvData = WaitForSingleObject(ghStopEvent_, 0);
      if (rvData != WAIT_TIMEOUT) {
        break;
      }
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[0].fd); // run the registered callback
  }
}
#else
void TCPStoreWorkerDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  while (true) {
    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // Check control and exit early if triggered
    // The pipe receives an event which tells us to shutdown the listener thread
    if (fds[0].revents != 0) {
      // Will be POLLHUP when the pipe is closed
      if (fds[0].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[0].revents));
      }
      break;
    }

    // if connection is closed gracefully by master, peeked data will return 0
    char data;
    int ret = recv(fds[1].fd, &data, 1, MSG_PEEK);
    if (ret == 0) {
      continue;
    }

    // valid request, perform callback logic
    callbackHandler(fds[1].fd); // run the registered callback
  }
}
#endif
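The non-Windows loop uses a classic shutdown idiom: the daemon polls the read end of a control pipe, and closing the write end wakes the poller with POLLHUP. The idiom can be demonstrated on POSIX in a few lines of Python:

```python
import os
import select

# Control pipe: the daemon polls the read end; closing the write end
# wakes the poll with POLLHUP, which means "shut down".
read_fd, write_fd = os.pipe()
poller = select.poll()
poller.register(read_fd, select.POLLHUP)

os.close(write_fd)             # signal the listener thread to stop
events = poller.poll(1000)     # wakes with POLLHUP on the read end
fd, revents = events[0]
shutdown = bool(revents & select.POLLHUP)
os.close(read_fd)
```

The advantage over a shared flag is that the poll wakes immediately instead of waiting for the next timeout, which is why both daemons shut down via the control pipe.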

4.2.6 TCPStoreMasterDaemon

Here, std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_; is the actual key-value table.

TCPStoreMasterDaemon is therefore the component responsible for operations on the kv table, such as set and get.

// Separate thread that is only launched on master
class TCPStoreMasterDaemon : public BackgroundThread {
 public:
  explicit TCPStoreMasterDaemon(int storeListenSocket);

 private:
  void run();
  void queryFds(std::vector<struct pollfd>& fds);
  void query(int socket);

  // The master runs on a single thread so only
  // one handler can be executed at a time
  void setHandler(int socket);
  void compareSetHandler(int socket);
  void addHandler(int socket);
  void getHandler(int socket) const;
  void checkHandler(int socket) const;
  void getNumKeysHandler(int socket) const;
  void deleteHandler(int socket);
  void waitHandler(int socket);
  void watchHandler(int socket);

  bool checkKeys(const std::vector<std::string>& keys) const;
  // Helper function to alerts waiting workers, used in setHandler, getHandler
  void wakeupWaitingClients(const std::string& key);
  // Helper function used when the key is changed
  // used in setHandler, addHandler, getHandler, deleteHandler
  void sendKeyUpdatesToClients(
      const std::string& key,
      const enum WatchResponseType& type,
      std::vector<uint8_t>& oldData,
      std::vector<uint8_t>& newData);
  std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_;
  // From key -> the list of sockets waiting on the key
  std::unordered_map<std::string, std::vector<int>> waitingSockets_;
  // From socket -> number of keys awaited
  std::unordered_map<int, size_t> keysAwaited_;
  // From key -> the list of sockets watching the key
  std::unordered_map<std::string, std::vector<int>> watchedSockets_;
};
4.2.6.1 Run

TCPStoreMasterDaemon waits on a socket: masterListenSocket_ listens on masterPort.

  • tcpStoreMasterDaemon_ uses tcputil::addPollfd(fds, storeListenSocket_, POLLIN) to monitor masterListenSocket_.
  • tcpStoreMasterDaemon_ itself acts as the master, i.e. the server that provides service for the whole TCPStore.
  • The key-value data is std::unordered_map<std::string, std::vector<uint8_t>> tcpStore_.
#ifdef _WIN32
void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    int res;
    SYSCHECK_ERR_RETURN_NEG1(
        res = WSAPoll(fds.data(), fds.size(), checkTimeout_.count()))
    if (res == 0) {
      auto rv = WaitForSingleObject(ghStopEvent_, 0);
      if (rv != WAIT_TIMEOUT) {
        finished = true;
        break;
      }
      continue;
    }

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) { // the listening socket has an event
      if (!(fds[0].revents & POLLIN)) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }
    queryFds(fds); // dispatch the queries
  }
}
#else

void TCPStoreMasterDaemon::run() {
  std::vector<struct pollfd> fds;
  tcputil::addPollfd(fds, storeListenSocket_, POLLIN);
  // Push the read end of the pipe to signal the stopping of the daemon run
  tcputil::addPollfd(fds, controlPipeFd_[0], POLLHUP);

  // receive the queries
  bool finished = false;
  while (!finished) {
    for (size_t i = 0; i < sockets_.size(); i++) {
      fds[i].revents = 0;
    }

    SYSCHECK_ERR_RETURN_NEG1(::poll(fds.data(), fds.size(), -1));

    // TCPStore's listening socket has an event and it should now be able to
    // accept new connections.
    if (fds[0].revents != 0) {
      if (fds[0].revents ^ POLLIN) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the master's listening socket: " +
                std::to_string(fds[0].revents));
      }
      int sockFd = std::get<0>(tcputil::accept(storeListenSocket_));
      sockets_.push_back(sockFd);
      tcputil::addPollfd(fds, sockFd, POLLIN);
    }

    // The pipe receives an event which tells us to shutdown the daemon
    if (fds[1].revents != 0) { // the control pipe has an event
      // Will be POLLHUP when the pipe is closed
      if (fds[1].revents ^ POLLHUP) {
        throw std::system_error(
            ECONNABORTED,
            std::system_category(),
            "Unexpected poll revent on the control pipe's reading fd: " +
                std::to_string(fds[1].revents));
      }
      finished = true;
      break;
    }
    queryFds(fds); // dispatch the queries
  }
}
#endif
4.2.6.2 Dispatching queries

queryFds calls different handlers depending on which sockets have events.

void TCPStoreMasterDaemon::queryFds(std::vector<struct pollfd>& fds) {
  // Skipping the fds[0] and fds[1],
  // fds[0] is master's listening socket
  // fds[1] is control pipe's reading fd, it is not for Windows platform
  for (size_t fdIdx = CONNECT_SOCKET_OFFSET; fdIdx < fds.size(); ++fdIdx) {
    if (fds[fdIdx].revents == 0) {
      continue;
    }

    // Now query the socket that has the event
    try {
      query(fds[fdIdx].fd); // handle the query
    } catch (...) {
      tcputil::closeSocket(fds[fdIdx].fd);

      // Remove all the tracking state of the close FD
      for (auto it = waitingSockets_.begin(); it != waitingSockets_.end();) {
        for (auto vecIt = it->second.begin(); vecIt != it->second.end();) {
          if (*vecIt == fds[fdIdx].fd) {
            vecIt = it->second.erase(vecIt);
          } else {
            ++vecIt;
          }
        }
        if (it->second.size() == 0) {
          it = waitingSockets_.erase(it);
        } else {
          ++it;
        }
      }
      for (auto it = keysAwaited_.begin(); it != keysAwaited_.end();) {
        if (it->first == fds[fdIdx].fd) {
          it = keysAwaited_.erase(it);
        } else {
          ++it;
        }
      }
      fds.erase(fds.begin() + fdIdx);
      sockets_.erase(sockets_.begin() + fdIdx - CONNECT_SOCKET_OFFSET);
      --fdIdx;
      continue;
    }
  }
}

4.2.6.3 Handling queries

query reads a message from the socket and dispatches to the matching handler based on its type.

// query communicates with the worker. The format
// of the query is as follows:
// type of query | size of arg1 | arg1 | size of arg2 | arg2 | ...
// or, in the case of wait
// type of query | number of args | size of arg1 | arg1 | ...
void TCPStoreMasterDaemon::query(int socket) {
  QueryType qt;
  tcputil::recvBytes<QueryType>(socket, &qt, 1);
  if (qt == QueryType::SET) {
    setHandler(socket);

  } else if (qt == QueryType::COMPARE_SET) {
    compareSetHandler(socket);

  } else if (qt == QueryType::ADD) {
    addHandler(socket);

  } else if (qt == QueryType::GET) {
    getHandler(socket);

  } else if (qt == QueryType::CHECK) {
    checkHandler(socket);

  } else if (qt == QueryType::WAIT) {
    waitHandler(socket);

  } else if (qt == QueryType::GETNUMKEYS) {
    getNumKeysHandler(socket);

  } else if (qt == QueryType::DELETE_KEY) {
    deleteHandler(socket);

  } else if (qt == QueryType::WATCH_KEY) {
    watchHandler(socket);

  } else {
    throw std::runtime_error("Unexpected query type");
  }
}
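The if/else chain in query is a plain dispatch on the query type. The same structure can be expressed as a handler table; the following Python sketch (with made-up query-type names and only three handlers) illustrates the shape:

```python
# Made-up query-type tags; the real code uses the QueryType enum.
SET, GET, ADD = "SET", "GET", "ADD"

class MiniMaster:
    """Sketch of TCPStoreMasterDaemon::query as a dispatch table."""
    def __init__(self):
        self.kv = {}
        self.handlers = {
            SET: self.set_handler,
            GET: self.get_handler,
            ADD: self.add_handler,
        }

    def query(self, qt, *args):
        if qt not in self.handlers:
            raise RuntimeError("Unexpected query type")
        return self.handlers[qt](*args)

    def set_handler(self, key, value):
        self.kv[key] = value

    def get_handler(self, key):
        return self.kv[key]

    def add_handler(self, key, inc):
        # counters are stored as decimal strings, as in the real TCPStore
        new = int(self.kv.get(key, b"0")) + inc
        self.kv[key] = str(new).encode()
        return new

m = MiniMaster()
m.query(SET, "/k", b"v")
```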

Set

This is where requests that set a value are handled.

void TCPStoreMasterDaemon::setHandler(int socket) {
  std::string key = tcputil::recvString(socket);
  std::vector<uint8_t> newData = tcputil::recvVector<uint8_t>(socket);
  std::vector<uint8_t> oldData;
  bool newKey = true;
  auto it = tcpStore_.find(key);
  if (it != tcpStore_.end()) {
    oldData = it->second;
    newKey = false;
  }
  tcpStore_[key] = newData;
  // On "set", wake up all clients that have been waiting
  wakeupWaitingClients(key);
  // Send key update to all watching clients
  newKey ? sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_CREATED, oldData, newData)
         : sendKeyUpdatesToClients(
               key, WatchResponseType::KEY_UPDATED, oldData, newData);
}
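The essence of setHandler is: capture the old value, decide between KEY_CREATED and KEY_UPDATED, store the new value, then fan the (type, old, new) triple out to every watcher. A compact Python sketch, with callbacks standing in for the sockets in watchedSockets_ (names illustrative):

```python
KEY_CREATED, KEY_UPDATED = "KEY_CREATED", "KEY_UPDATED"

class WatchedStore:
    """Sketch of the master's kv table plus watcher notification."""
    def __init__(self):
        self.kv = {}
        self.watchers = {}   # key -> callbacks, standing in for watchedSockets_

    def watch(self, key, cb):
        self.watchers.setdefault(key, []).append(cb)

    def set(self, key, new_data):
        old_data = self.kv.get(key)           # capture the old value, if any
        resp = KEY_CREATED if old_data is None else KEY_UPDATED
        self.kv[key] = new_data
        # sendKeyUpdatesToClients: fan (type, old, new) out to every watcher
        for cb in self.watchers.get(key, []):
            cb(resp, old_data, new_data)

events = []
s = WatchedStore()
s.watch("/k", lambda *a: events.append(a))
s.set("/k", b"1")   # first set -> KEY_CREATED
s.set("/k", b"2")   # second set -> KEY_UPDATED
```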
Get

This is where requests that fetch a value are handled.

void TCPStoreMasterDaemon::getHandler(int socket) const {
  std::string key = tcputil::recvString(socket);
  auto data = tcpStore_.at(key);
  tcputil::sendVector<uint8_t>(socket, data);
}

watchKey

This registers a key to be watched.

For WATCH_KEY, the socket is appended to the key's watcher list, so that notifications can be sent to it later.

void TCPStoreMasterDaemon::watchHandler(int socket) {
  std::string key = tcputil::recvString(socket);

  // Record the socket to respond to when the key is updated
  watchedSockets_[key].push_back(socket);

  // Send update to TCPStoreWorkerDaemon on client
  tcputil::sendValue<WatchResponseType>(
      socket, WatchResponseType::KEY_CALLBACK_REGISTERED);
}
Notify

When a key changes, the watching clients are notified.

void TCPStoreMasterDaemon::sendKeyUpdatesToClients(
    const std::string& key,
    const enum WatchResponseType& type,
    std::vector<uint8_t>& oldData,
    std::vector<uint8_t>& newData) {
  for (int socket : watchedSockets_[key]) {
    tcputil::sendValue<WatchResponseType>(socket, type);
    tcputil::sendString(socket, key, true);
    tcputil::sendVector<uint8_t>(socket, oldData);
    tcputil::sendVector<uint8_t>(socket, newData);
  }
}

4.2.7 Summary

The overall picture is summarized below:

  • The master listens for requests on MasterPort.
  • Setting and getting values:
    • On the worker, storeSocket_ is used to set/get a value, which corresponds to number 1 in the figure below.
    • On the master, this maps onto tcpStore_.
  • Watching keys:
    • On the worker, listenSocket_ is used to tell the master to watch a given key, number 2 in the figure; the worker also registers a callback for that key locally, number 3.
    • On the master, the watch is recorded via watchedSockets_[key].push_back(socket).
    • When the master sets a value and the key is being watched, it notifies watchedSockets_[key], number 4.
    • The worker then invokes the registered callback.
                                                                          +
+----------------------------------------------------------------------+  |  +------------------------------------------------------------------------+
| TCPStore                                                      Master |  |  | TCPStore                                                        Worker |
|                                                                      |  |  |                                                                        |
|   storeSocket_                                                       |  |  |                                                                        |
|                                                                      |  |  |                                                                        |
|   +------------------------------------------------------------+     |  |  |                                                                        |
|   | TcpStoreMasterDaemon_                            MasterPort|     |  |  |  1   +---------------------------------+                               |
|   |                                                            | <--------------+ | set(key, value)                 |                               |
|   |   unordered_map<string, vector<uint8_t> > tcpStore_+---+   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      |    storeSocket_                 |                               |
|   |   TCPStore.masterListenSocket_                         |   |     |  |  |      |                                 |                               |
|   |                                                        |   |     |  |  |      +---------------------------------+                               |
|   |   +-----------------------------------------------+    |   |     |  |  |                                                                        |
|   |   |  run                                          |    |   |     |  |  |  2   +---------------------------------+                               |
|   |   |                                               |    |   | <--------------+ |                                 |                               |
|   |   |    queryFds     query                         |    |   |     |  |  |      | watchKey(key, callback) +-------------------------------+       |
|   |   |                                               |    |   |     |  |  |      |                                 |        3              |       |
|   |   |    setHandler   getHandler                    |    |   |     |  |  |      |    listenSocket_                |                       |       |
|   |   |                                               |    |   |     |  |  |      |                                 |                       |       |
|   |   +-----------------------------------------------+    |   |     |  |  |      |                                 |                       |       |
|   |                                                        |   |     |  |  |      +---------------------------------+                       |       |
|   +------------------------------------------------------------+     |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |                                                                |       |
|                                                            |         |  |  |   +----------------------------------------------------------------+   |
|                                                            |         |  |  |   | TCPStoreWorkerDaemon                                       |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   unordered_map<string, WatchKeyCallback> keyToCallbacks_  |   |   |
|                                                            |         |  |  |   |                                                            |   |   |
|                                                            |         |  |  |   |   TCPStore.listenSocket_                              +----+   |   |
|                                                            |         |  |  |   |                                                       |        |   |
|                                                            |         |  |  |   |  +----------------------------------------------------------+  |   |
|                                                            |         |  |  |   |  | run                                                |     |  |   |
|                                                            |     4   |  |  |   |  |                                                    |     |  |   |
|                                                            +--------------------->+                                                    v     |  |   |
|                                                                      |  |  |   |  |       callbackHandler +-----> keyToCallbacks_(callback)  |  |   |
|                                                                      |  |  |   |  |                                                          |  |   |
|                                                                      |  |  |   |  +----------------------------------------------------------+  |   |
|                                                                      |  |  |   +----------------------------------------------------------------+   |
+----------------------------------------------------------------------+  +  +------------------------------------------------------------------------+


This concludes our walkthrough of the two concepts of initialization methods and Store; ultimately it is the Store that does the real work during initialization. The TCPStore analysis also showed the capabilities a Store must provide, such as setting key-value pairs and watching a key for changes; it is exactly these capabilities that let a group of processes discover one another.

In the next post we will introduce the concept of process groups. Stay tuned.
