[源码解析] PyTorch 分布式之弹性训练(1) --- 总体思路
0x00 摘要
在前面的文章之中,我们已经学习了PyTorch 分布式的基本模块,介绍了官方的几个例子,我们接下来会介绍PyTorch的弹性训练,本文是第一篇,介绍其历史和设计理念,也会与Horovod做一下对比。
注:后续会对Horovod文章进行系统整理,届时对本文进行更新,会加入更多对比分析。
0x01 痛点
因为机器学习的模型越来越庞大,单个GPU显存早已无法容纳模型参数,所以一般都是使用大量节点或者集群进行训练,随着训练规模扩大,硬件薄弱或设计原因会导致单点故障概率随之增加,这就带来了一些问题或者痛点,比如:
-
痛点 1:缺少容灾功能。
- 问题点:单个节点故障往往会导致整个训练job结束。虽然框架提供了checkpoint功能,但是频繁调用会导致性能问题,所以依然会丢失一段时间的训练成果,并且还得继续进行任务排队。
- 理想状态:单个节点失败不会影响整体训练,在节点故障时候,自动剔除该节点,同时训练继续平滑进行。
-
痛点 2:缺少弹性算力感知和动态训练扩缩容机制。
- 问题点:用户只能在提交任务时候确定所需要的固定静态资源,无法对集群资源进行实时动态感知,导致集群资源利用率低。
- 理想状态:应该在有少量空闲机器时候就开始训练,当有更多资源时候,弹性任务同上层调度系统可以和i进行配合,从而能有效检测到这些潜在资源,在训练过程中可以自动增加worker数量。当本任务有空闲算力时候,会自动释放资源。而且在worker数量变化时,不会中断训练任务,做到平滑过渡。
-
痛点 3:集群资源配置/调度机制不灵活
- 问题点:目前不支持动态配置worker,不支持高优先级抢占实例。因此当资源不足时,无法按需为其他高优先级业务腾出资源, 只能等待任务自己主动终止或者出错终止。
- 理想状态:训练任务可以被抢占,可以主动腾出资源,可以在不同用途/配置的机器间进行漂移。
0x02 难点
我们接下来看看实现弹性训练需要面对哪些挑战和难点,这里只从工程角度来看,不考虑数据切分/学习率/batch size 调整等问题。
- 难点1 :需要一个节点/进程之间彼此发现的机制。
节点/训练进程自动进入或者退出时候,其他节点/训练进程如何感知。
- 难点2:如何处理成员变更
当发现有成员变更之后,如何处理。
- 难点3:如何捕获单个进程训练失败。
如何在单个节点上管理所有训练进程,从而当某个进程发生错误时候,可以捕获其失败,或者重试或者重启该进程。
- 难点4:如何与现有训练代码集成。
如何以最小工作量与现有代码进行集成,不引入复杂的抽象,或者对用户最大限度屏蔽底层实现。
0x03 TorchElastic
我们接下来看看 PyTorch 弹性机制总体状况。PyTorch 弹性机制是合并自 TorchElastic https://github.com/pytorch/elastic,所以我们就从 TorchElastic 看起。
TorchElastic(TE)是从 PyTorch 1.9 正式引入的,我们从两个地方看弹性训练的i历史。
3.1 历史
3.1.1 PyTorch 1.7
Release note 里面提到了 TE 已经被加入了 PyTorch Docker 之中:
[Stable] TorchElastic now bundled into PyTorch docker image
Torchelastic提供了"torch.distributed.launch" CLI的一个严格超集,并添加了容错和弹性功能。如果用户对容错不感兴趣,他们可以通过设置"max_restarts=0"来获得准确的功能/行为,并增加自动分配"RANK"和"MASTER_ADDR"端口的便利性(而不是在"torch.distributed.launch"中手动指定)。
3.1.2 PyTorch 1.9
从TE repository 可以看到,TE现在已经被合并到主版本 PyTorch 1.9 之中。
IMPORTANT: This repository is deprecated.
- TorchElastic has been upstreamed to PyTorch 1.9 under
torch.distributed.elastic
. Please refer to the PyTorch documentation here.
3.2 设计理念
PET经历了两个版本,从 https://github.com/pytorch/elastic/blob/master/design/torchelastic/0.2.0/design_doc.md 可以看到其设计理念。
3.2.1 基本功能
PyTorch Elastic Trainer (PET) 提供了一个可以用容错和弹性方式跨集群来训练模型的框架。PET 通过两种方式提供这些功能:
- 当 PyTorch worker 进程抛出某类可重试错误时,它会被 PET 捕获并重试训练过程。
- 只要worker的数量维持在开始工作时指定的范围内,新worker就可以随时离开或加入到现有训练job的进程池。当成员发生变化时,所有worker会重新集合(re-rendezvous)以建立一个新的进程组,并从以前的良好状态之中恢复训练。
为了与 PET 集成,PyTorch 用户需要对其训练逻辑进行以下更改:
- 用户需要使 PET 来控制他们的训练循环。
- 本质上,用户提供了一个“内部训练”循环,该循环被 PET 包裹在一个可重试的循环中。
- PET循环是可重试的循环,其负责建立或重新建立过程组,以及将用户的训练恢复到良好状态。
- 在新worker加入进程池时,用户需要指定状态是什么以及如何把状态施加到一个新worker之上。
3.2.2 新设计概述
PET v0.2 从 v0.1 之中获取了不少经验,下面讲讲 v0.2的设计理念。
- 动态范围
在 PET v.0.2 中,我们不再尝试恢复训练函数中的错误。相反,PET 尝试维护工作进程的数量,使它们保持在作业所需的 [ min , max ] 范围内。应用编写者负责从现有可用还原点文件加载和重新启动。与 v0.1 不同,PET v0.2 不强制指定如何管理checkpoints。应用编写者可以任意使用torch.save
和 torch.load
或更高层次的框架如PyTorch Lightening 进行处理。
- 本地代理
PET v0.2 使用了一个名为elastic-agent
的新进程,每个节点有一个独立的elastic-agent
。每个代理进程只负责管理该节点的一组本地工作进程,并与本作业其他节点上的弹性代理一起协调来确定进程组成员身份的变化。具体如下图所示:
图片来源:https://github.com/pytorch/elastic/raw/master/design/torchelastic/0.2.0/torchelastic_diagram.jpg
- 成员变更
成员变更的处理方式如下:当一个工作进程失败时,管理它的弹性代理会杀死该节点上的所有worker,然后与其他代理建立一个集合操作(rendezvous),并使用新的集合信息来重启worker。但是,当代理以非零错误代码退出时,应该由上层调度模块(例如 Kubernetes)来重新启动代理(同理,此代理将重新启动它负责的所有worker)。相同的恢复机制也适用于节点级故障。编排工具(诸如 Kubernetes )会调度作业以便job可以使用最小数目的代理副本运行,然后每个代理将依次编排用户的训练脚本。
- 兼容
要采用 PET v0.2,应用程序只需让其入口点或main函数与PyTorch distributed launcher兼容 。我们期望通过分布式启动器启动的分布式训练作业可以通过弹性代理无缝启动,无需更改或最小化代码更改。唯一的区别是在后一种情况下,应用程序将能够在出现某些故障的情况下依然取得进展。
3.2.3 bare-bones
新的PET设计是想成为一个“bare-bones”:它从简单性和健壮性两方面权衡了应用程序可恢复的粒度。
将来,TE 希望为检查点机制提供更多更方便的API,开发人员可以选择使用这些API来实现更高效的重启语义。
因为 PET 是 “bare-bones”,所以对用户如何处理也给了一些指导性意见,比如 checkpoint 的处理。
一旦发生故障或成员变更,所有幸存的worker将立即被杀掉。所以用户需要手动地处理 checkpoint,定期保存你的工作进度,来保证重启后训练能够继续下去。检查点的频率应取决于用户job对于失败的容忍度。建议用户脚本采用如下结构进行处理:
def main():
load_checkpoint(checkpoint_path)
initialize()
train()
def train():
for batch in iter(dataset):
train_step(batch)
if should_checkpoint:
save_checkpoint(checkpoint_path)
3.3 小结
不难发现,TE的设计理念主要就是回答了之前提到的4个难点。
- 难点1 :需要一个节点/进程之间彼此发现的机制。
TE的答案是:当成员发生变化时,所有worker会重新集合(re-rendezvous)以建立一个新的进程组。rendezvous就是这个发现机制。
- 难点2:如何处理成员变更
TE的答案是:当一个工作进程失败时,管理它的弹性代理会杀死该节点上的所有worker,然后与其他代理建立一个集合操作(rendezvous),并使用新的集合信息来重启worker。但是,当代理以非零错误代码退出时,应该由上层调度模块(例如 Kubernetes)来重新启动代理(同理,此代理将重新启动它负责的所有worker)。
- 难点3:如何捕获单个进程训练失败,如何在单个节点上管理所有训练进程。
TE的答案是:每个代理进程只负责管理该节点的一组本地工作进程,并与本作业其他节点上的弹性代理一起协调来确定进程组成员身份的变化。
- 难点4:如何与现有训练代码集成。
TE的答案是:应用程序只需让其入口点或main函数与PyTorch distributed launcher兼容 。
0x04 问题
4.1 VS Horovod
因为我们已经有了 Horovod 弹性训练的基础,所以我们就用 Horovod 为基准,提出一系列问题,然后去 PyTorch 探寻比对。
-
如何管理本地训练进程?
Horovod 通过后台 Driver 进程来管理本地训练进程。
TE 通过后台 Agent 进程来管理本地训练进程。
-
如何保存状态?
- Horovod 提供了内置实现,在每次训练间隙,使用 state.commit() 完成checkpoint。
- TE 需要用自己实现保存/加载 checkpoint。
-
如何发现新节点?
- Horovod 让用户自己实现节点发现的逻辑,这需要用户提供一个
discovery_hosts.sh
,其中指定了正在参与训练的节点。Horovod 会定期执行这个脚本来发现当前节点。 - TE 利用分布式一致性中间件 ETCD 或者自带的 C10D后端(基于TcpStore)来解决节点之间互相发现的问题。
- Horovod 让用户自己实现节点发现的逻辑,这需要用户提供一个
-
如何捕获异常?
-
Horovod 捕获集合通信异常/节点异常/扩缩容,转换为Horovod自己的Exception,然后会依据配置重(比如内部建立异常节点黑名单)新建立环,继续训练。
-
TE定义了一个monitor方法,定时调用来监控本地进程异常,转换为内部状态数值,进行处理,如果有一个worker出现了问题,则该node上的agent会重启本node的所有worker进行新一轮rendezvous,因为是新一轮 rendezvous,所以其他节点也会重启其worker,然后大家一起继续训练。
-
4.2 TE 问题
下面是关于一些TE内部的问题,我们后续分析会逐步解答这些问题。
- RANK 和 WORLD_SIZE 这些字段不再需要手动设置,如何做到?
- 如何在不同的节点间确定 RANK?
RANK 0
的实例会作为 master 的角色存在? - worker 失败之后,如何实现重启worker操作?
- TE 发现了新worker 之后,如何处理?
- 每个代理上有一个 rendezvous,这些 rendezvous 有master,slave 概念吗?有一个master专门记录当前集群状态嘛?
- 如何支持动态地增加或减少参与训练的 worker 数量?
0x05 PyTorch分布式系列
PyTorch分布式其他文章如下:
[源码解析]PyTorch如何实现前向传播(1) --- 基础类(上)
[源码解析]PyTorch如何实现前向传播(2) --- 基础类(下)
[源码解析] PyTorch如何实现前向传播(3) --- 具体实现
[源码解析] Pytorch 如何实现后向传播 (1)---- 调用引擎
[源码解析] Pytorch 如何实现后向传播 (2)---- 引擎静态结构
[源码解析] Pytorch 如何实现后向传播 (3)---- 引擎动态逻辑
[源码解析] PyTorch 如何实现后向传播 (4)---- 具体算法
[源码解析] PyTorch 分布式(1)------历史和概述
[源码解析] PyTorch 分布式(2) ----- DataParallel(上)
[源码解析] PyTorch 分布式(3) ----- DataParallel(下)
[源码解析] PyTorch 分布式(4)------分布式应用基础概念
[源码解析] PyTorch分布式(5) ------ DistributedDataParallel 总述&如何使用
[源码解析] PyTorch分布式(6) ---DistributedDataParallel -- 初始化&store
[源码解析] PyTorch 分布式(7) ----- DistributedDataParallel 之进程组
[源码解析] PyTorch 分布式(8) -------- DistributedDataParallel之论文篇
[源码解析] PyTorch 分布式(9) ----- DistributedDataParallel 之初始化
[源码解析] PyTorch 分布式(10)------DistributedDataParallel 之 Reducer静态架构
[源码解析] PyTorch 分布式(11) ----- DistributedDataParallel 之 构建Reducer和Join操作
[源码解析] PyTorch 分布式(12) ----- DistributedDataParallel 之 前向传播
[源码解析] PyTorch 分布式(13) ----- DistributedDataParallel 之 反向传播
[源码解析] PyTorch 分布式 Autograd (1) ---- 设计
[源码解析] PyTorch 分布式 Autograd (2) ---- RPC基础
[源码解析] PyTorch 分布式 Autograd (3) ---- 上下文相关
[源码解析] PyTorch 分布式 Autograd (4) ---- 如何切入引擎
[源码解析] PyTorch 分布式 Autograd (5) ---- 引擎(上)
[源码解析] PyTorch 分布式 Autograd (6) ---- 引擎(下)
[源码解析] PyTorch分布式优化器(1)----基石篇
[源码解析] PyTorch分布式优化器(2)----数据并行优化器
[源码解析] PyTorch分布式优化器(3)---- 模型并行
[源码解析] PyTorch 分布式(14) --使用 Distributed Autograd 和 Distributed Optimizer
[源码解析] PyTorch 分布式(15) --- 使用分布式 RPC 框架实现参数服务器
[源码解析] PyTorch 分布式(16) --- 使用异步执行实现批处理 RPC
[源码解析] PyTorch 分布式(17) --- 结合DDP和分布式 RPC 框架
[源码解析] PyTorch 分布式(18) --- 使用 RPC 的分布式管道并行