论文解读:SCAFFOLD: Stochastic Controlled Averaging for Federated Learning
作者:Sai Praneeth Karimireddy, Satyen Kale, Mehryar Mohri,Sashank J. Reddi, Sebastian U. Stich, Ananda Theertha Suresh
论文地址:https://arxiv.org/abs/1910.06378
一、背景介绍
联邦学习是一种特殊的分布式学习,它有以下几个特点
-
客户端不稳定
-
数据在客户端上是Non-IID的
-
每轮只有部分客户端参加
-
负载不平衡
-
高昂的通信代价
通信代价是联邦学习一个主要的研究方向
影响通信代价的几个因素:
- 参数 X = ( x 1 , x 2 , . . , x m ) X = (x_1,x_2,..,x_m) X=(x1,x2,..,xm)的维度m的大小
- 梯度 g = ( t 1 , t 2 , . . . , t m ) g = (t_1,t_2,...,t_m) g=(t1,t2,...,tm)的维度m的大小
- 网络延迟
- 带宽
- 每次通信的客户端个数
减少通信代价的两种方式:
- 提高server全局模型的收敛速度(减少通信次数)
- 对参数或梯度进行数据压缩(减少单次传输代价)
目前研究联邦学习通信代价的论文有很多只是侧重于研究减少其中一个方面,如只减少通信次数,而不在乎单次传输的代价(本文也是这样设计的),然后博主以为要判断通信代价的高低,更加可信的指标是你达到收敛时的总通信数据量
总通信数据量 = 通信次数 * 单次传输的数据量
FedAvg联邦平均算法是FL领域一个开山鼻祖的算法,后续很多论文中提出的算法都是对它的一些小的改进。
FedAvg实际上就是通过增加本地更新的次数,来减少通信次数(每个客户端在本地多更新几次再进行一次和server的通信。FedAvg算法在IID数据和Non-IID数据下都是收敛的,然后FedAvg的收敛速度受限于数据集的分布,在Non-IID数据集中,FedAvg的收敛速度缓慢。
以下两篇论文就展示了这一点
Federated Learning with Non-IID Data
Measuring the Effects of Non-Identical Data Distribution for Federated Visual Classification
那有没有一种算法能够改进FedAvg的这个缺点,在Non-IID数据集上仍能有良好的效果呢?
二、本文工作
-
本文提出了一种叫做SCAFFOLD的算法,简单来说就是增加了一个额外的参数control variate来修正FedAvg出现的client-drift。
-
作者用严格的数学理论证明了该算法是收敛的。
-
实验部分证明了这种算法是不受客户端数据异质性(数据分布)的影响,并且可以加快收敛速度,从而显著减少通信次数。
三、SCAFFOLD算法核心
FedAvg更新公式:
本地更新k次
y
i
=
y
i
−
η
g
i
(
y
i
)
y_i = y_i - \eta g_i(y_i)
yi=yi−ηgi(yi)
与server进行一次通信
x
=
1
S
∑
i
∈
S
y
i
x = \frac{1}{S} \sum_{i \in S} y_i
x=S1i∈S∑yi
SCAFFOLD更新公式:
本地更新k次
y
i
=
y
i
−
η
(
g
i
(
y
i
)
+
c
−
c
i
)
y_i = y_i - \eta (g_i(y_i) + c - c_i)
yi=yi−η(gi(yi)+c−ci)
与server进行一次通信
x
=
1
S
∑
i
∈
S
y
i
x = \frac{1}{S} \sum_{i \in S} y_i
x=S1i∈S∑yi
SCAFFOLD算法相比FedAvg只是加了一个修正项 c − c i c-c_i c−ci,有了这个修正项可以有效的解决本地更新时的client-drift,让每个客户端在本地更新的时候“看到”其他客户端更新的信息,这里不是指真的看到,而是一种guess。
也就说我不再是每次只在自己的数据上训练,我还去猜测别人这一步可能在做什么。 有了这个猜测我就好像用所有的数据在本地训练模型一样。
对于这个算法,我不仅要更新参数,我还要更新我的这种猜测。
下面用两个Non-IID客户端损失函数的等高线图来模拟这个过程:
蓝色是客户端1的损失函数,红色是客户端2的损失函数,绿色是他们的平均损失函数
Optimum表示全局模型的最优点(收敛点),我们的目标是达到这个全局最优点
SGD更新过程:在本地更新一次就进行一次通信(路径与server理想的更新路径相近,但是通信代价大)
FedAvg更新过程:在本地更新k次才进行一次通信(由于Non-IID,路径与server理想的更新路径发生了偏移,即client-drift,导致收敛速度缓慢)
SCAFFOLD 偏移项
SCAFFOLD更新过程:在本地更新k次才进行一次通信(由于修正项的存在,每次本地更新都会被拉回到理想的更新路径附近,收敛速度较快)
更新参数
C
C
C和
C
i
C_i
Ci
SCAFFOLD算法伪代码
对比FedAvg算法伪代码
实验
Conclusion
创新点:
-
本文研究了数据异质性对联邦优化算法的影响。证明了FedAvg可能会因不同client的梯度差异收敛速度受到严重影响,甚至可能比SGD慢。
-
本文提出了一种针对Non-IID数据集的新算法SCAFFOLD,设计了一种更好的本地更新策略,利用control variates(guess)来克服client drift
-
实验表明SCAFFOLD有着较快的收敛速度,且不受client数据分布的影响,能有效减少通信次数
不足之处:
-
每次传给Server不仅要传模型的参数,还要传control variates(梯度和参数向量维度一样),这增加了单次通信的代价(2倍)
-
要求每个client是stateful的,因此只适用于少量的客户端(thousands级别)
-
本文用的是上一轮的梯度状态去猜测下一轮梯度应该往这个方向,这要求所有的loss function时smooth的
思考:
-
更新 C i C_i Ci时,能不能只用最后一个梯度去更新?
-
用总的通信量来评估你模型的优劣可能更有可信度。
-
C i C_i Ci在第3轮被选中,但下次可能在第100轮才被选中,这时的 C i C_i Ci存的还是第3轮被选中时参数所在位置的梯度,那会有误差吧?