受立体匹配(Stereo Matching)中代价聚合(cost aggregation)和 Transformers中 self-attention 的启发, 提出一个聚合模块, 它通过考虑实例之间的相似性来聚合标记和未标记数据的初始伪标签. 为了扩大当前 mini-batch之外的聚合候选者, 利用一个队列来记忆训练期间先前 batch 样本的信息, 从而提高可扩展性. 此外, 文中还提出了一个类平衡的置信度感知队列, 该队列是通过考虑置信度和类分布而构建的, 并使用动量进行更新, 从而鼓励更好的聚合. 最后, 通过使用队列的多个不相交子集测量可能的伪标签之间的一致性, 提出了一种新的伪标签置信度度量, 这有助于提高对噪声的鲁棒性.
SSL 近年来使用的比较主流的方法包括: Pseudo-Labeling
, Consistency Regularization
, Label Propagation
, 以及后来的各种 Holistic Methods
. 不同方法的主要思想对比如下:
上诉的 SSL 方法都存在一个限制, 即它们只能当模型经过充分训练时生成自信的伪标签. 在这种情况下, 训练早期的不成熟模型可能会根据伪标签的质量而显着退化. 例如, 如果标记数据
X
\mathcal{X}
X 与未标记数据
U
\mathcal{U}
U 分布不同,
X
\mathcal{X}
X 的大小不够, 或者
X
\mathcal{X}
X 和
U
\mathcal{U}
U 受到噪声和偏差的污染, 则伪标签会过度拟合错误的伪标签, 从而导致优化模型的参数得不到预期的结果, 也称为确认偏差. 为了防止上述问题, 基于置信度的策略与伪标签相结合, 如 FixMatch, 通过设置预定义的置信度阈值来关注具有高置信度的伪标签. 然而, 在训练的早期迭代中, 用简单的阈值估计伪标签的置信度的约束过强.
在本文中, 提出了一种新的 SSL 框架: AggMatch, 它使用不同实例之间的关系来实现更自信的伪标签, 如下图所示.
AggMatch 与 FixMatch 相比, 其通过聚合初始伪标签, AggMatch 生成更自信的伪标签, 从而提高分类性能.
文中符号系统定义如下:
- X = { ( x b , y b ) , b ∈ ( 1 , … , B ) } \mathcal{X}=\{(x_b,y_b),b\in(1,\dots,B)\} X={(xb,yb),b∈(1,…,B)} 为标记数据集, 其中标签总数为 Y Y Y, B B B 为 batch_size.
- U = { ( u b ) : b ∈ ( 1 , … , μ B ) } \mathcal{U}=\{(u_b):b\in(1,\dots,\mu B)\} U={(ub):b∈(1,…,μB)} 为未标记数据集, 其中 μ \mu μ 是一个超参数, 用于确定 U \mathcal{U} U 相对于 X \mathcal{X} X 的大小.
- p m o d e l ( y ∣ r ; θ ) p_{model}(y\vert r;\theta) pmodel(y∣r;θ) 为同时利用 X \mathcal{X} X 和 U \mathcal{U} U 训练的模型, r r r 为实例. p m o d e l p_{model} pmodel 由两部分组成: 1.特征提取器, 从实例 r r r 中提取特征 v \mathbf{v} v, 2.分类器, 从 v \mathbf{v} v 中预测分类概率 P \mathrm{P} P.
1. AggMatch 算法概述
在 AggMatch 中, 主要研究如何准确测量实例之间的相似性, 以及如何通过考虑相似性来有效地聚合候选的类分布. 受立体匹配(Stereo Matching)中代价聚合(cost aggregation)和 Transformers中 self-attention 的启发, 提出了通过 query、key 和 value 来测量特征嵌入的相似性. query 可以是实例的特征 v b \mathbf{v}_b vb, keys 可以是 X \mathcal{X} X 和 U \mathcal{U} U 中的其他特征 v l \mathbf{v}_l vl, 以及类分布 P b \mathrm{P}_b Pb 和 P l \mathrm{P}_l Pl. 然后将相应的初始类分布 P l \mathrm{P}_l Pl 视为 values.
聚合分布的质量很大程度上取决于 batch_size 的大小, 但有限的 GPU 内存无法保证 big batch_size. 为了解决这个问题, 利用一个队列来记忆训练期间上一次 batch 的信息. 基于此, 有选择地使用动量模型将置信样本加入队列, 并且队列缓慢而稳定地变化, 从而实现一致的传播. 同时还通过使用队列的多个子集测量其多个假设之间的一致性来测量伪标签的置信度. 整体网络架构如下图所示:
对于通过弱增强( α ( u b ) \alpha(u_b) α(ub))和强增强( A ( u b ) \mathcal{A}(u_b) A(ub))而受到不同扰动的图像, 对模型进行正则化以在它们之间生成一致的预测. 聚合模块在一个大而一致的队列的帮助下细化了弱增强图像的类分布, 然后置信度估计器根据随机均匀划分出队列, 利用这些队列生成的假设来估计被选为不确定的不可靠样本.
2. 类分布聚合(Class Distribution Aggregation)
聚合模块旨在利用其他置信度感知样本之间的关系来改进嘈杂的类分布. 具体来说,
u
b
u_b
ub 的类分布
P
b
\mathrm{P}_b
Pb 由
X
\mathcal{X}
X 和
U
\mathcal{U}
U 的集合中的其他实例
r
l
r_l
rl 的
p
l
\mathrm{p}_l
pl 聚合而成,
P
b
\mathrm{P}_b
Pb 和
P
l
\mathrm{P}_l
Pl 之间的相似性为:
P
‾
b
=
∑
l
(
exp
(
S
(
u
b
,
r
l
)
/
τ
s
i
m
)
∑
j
(
u
b
,
r
j
)
/
τ
s
i
m
)
)
P
l
(2)
\overline{\mathrm{P}}_b=\sum_{l}(\frac{\exp(\mathcal{S}(u_b,r_l)/\tau_{\mathrm{sim}})}{\sum_j(u_b,r_j)/\tau_{\mathrm{sim}})})\mathrm{P}_l \tag{2}
Pb=l∑(∑j(ub,rj)/τsim)exp(S(ub,rl)/τsim))Pl(2)
其中
l
l
l 和
j
j
j 是
X
\mathcal{X}
X 和
U
\mathcal{U}
U 中所有样本的索引,
τ
s
i
m
\tau_{\mathrm{sim}}
τsim 是温度参数. 相似度函数
S
\mathcal{S}
S 测量实例之间的注意力权重.
一种更直接的方法是仅考虑基于 class similarity term
的类分布
P
\mathrm{P}
P, 它忽略了类分布中涉及的噪声. 另外使用 feature similarity term
来定义
u
b
u_b
ub 和
r
l
r_l
rl 之间的特征相似性:
S
(
u
b
,
r
l
)
=
(
v
b
⋅
v
l
)
/
∥
v
b
∥
∥
v
l
∥
+
λ
s
i
m
J
S
(
P
b
∥
P
l
)
(3)
\mathcal{S}(u_b,r_l)=(\mathbf{v}_b \cdot \mathbf{v}_l)/\lVert \mathbf{v}_b\rVert\lVert\mathbf{v}_l\rVert+\lambda_{\mathrm{sim}}\mathrm{JS}(\mathrm{P}_b\lVert \mathrm{P}_l) \tag{3}
S(ub,rl)=(vb⋅vl)/∥vb∥∥vl∥+λsimJS(Pb∥Pl)(3)
其中第一项表示特征相似度, 第二项表示类相似度,
λ
s
i
m
\lambda_{\mathrm{sim}}
λsim 是权重参数. 通过余弦相似度来测量特征之间的相似度, 使用 Jensen-Shannon 距离(JS 散度)来测量分布
P
b
\mathrm{P}_b
Pb 和
P
l
\mathrm{P}_l
Pl 之间的相似性. 与传统的一致性正则化方法不同, 分布聚合显式地利用不同实例之间的关系来生成更自信的伪标签.
3. 类平衡的置信度感知队列
队列定义为 Q = { ( v l , P l ) } \mathcal{Q} = \{(\mathbf{v}_l, \mathrm{P}_l)\} Q={(vl,Pl)}, 在每个训练步骤中, 当前 batches 入队, 最早的 batches 出队, 使队列能够在所有样本中隐式地强加知识集合. 为了防止噪声样本的传播, 忽略了低于阈值 τ \tau τ 的低置信度预测. 在 PAWS 中, 其也试图利用与标记样本的相似性, 但传播仅依赖于非常稀疏的标记样本. 与此不同的是, 由于提出了自信感知队列, 我们可以通过使用未标记的样本轻松地扩大候选样本. 另外, 在 MSSR 算法中也利用到相似性, 其利用孪生神经网络进行相似度度量.
- Class Balancing. 类平衡队列中, 每个样本都根据其伪标签预测进行排队. 类 y y y 的队列定义为 Q y = { ( v y , l , P y , l ) : l ∈ ( 1 , L ) } \mathcal{Q}_y = \{(\mathbf{v}_{y,l}, \mathrm{P}_{y,l}):l \in (1,L)\} Qy={(vy,l,Py,l):l∈(1,L)}, L L L 为每个类对应的样本数(对于所有类都是相同的).
- Momentum Update. 一种用于队列更新的动量技术, 它缓解了队列中的特征与当前 batch 之间的不一致问题. 参数为
θ
m
\theta_m
θm 的动量模型由动量系数
λ
m
\lambda_m
λm 和模型参数
θ
\theta
θ 控制:
θ m ← λ m θ m + ( 1 − λ m ) θ m (4) \theta_m \leftarrow \lambda_m\theta_m+(1-\lambda_m)\theta_m \tag{4} θm←λmθm+(1−λm)θm(4)
这其实也是 EMA, 在 Mean-Teacher 中利用其进行 Teacher Model 的参数更新.
4. 伪标签的置信度估计
FixMatch 使用一种简单的阈值技术作为伪标签的置信度度量, 但它需要一个手动调整的阈值参数, 并且在早期迭代中施加了过强的约束, 导致大多数可能的置信伪标签将被拒绝. 为了解决这个问题, AggMatch 提出基于多个伪标签假设之间的共识来衡量聚合伪标签 P ‾ b \overline{\mathrm{P}}_b Pb 的置信度.
为了生成这些假设, 队列
Q
\mathcal{Q}
Q 在每个类中被均匀且随机地划分为
M
M
M 个不相交的子集, 即
Q
m
\mathcal{Q}^m
Qm, 其中
m
=
1
,
…
,
M
m=1,\dots,M
m=1,…,M. 给定子集
Q
m
\mathcal{Q}^m
Qm, 像式子(1)一样, 聚合实现
P
‾
b
m
\overline{\mathrm{P}}_b^m
Pbm, 然后, 聚合的类概率会经历一个投票过程. 具体来说, 我们总结了每个预测的 one-hot 编码, 使得
a
b
=
1
M
∑
m
e
b
m
a_b = \frac{1}{M}\sum_m e^m_b
ab=M1∑mebm, 其中
e
b
m
e^m_b
ebm 表示通过在
P
‾
b
m
\overline{\mathrm{P}}_b^m
Pbm 上操作
arg max
\argmax
argmax 的 one-hot 向量. 根据这种经验概率
a
b
a_b
ab, 它表示假设中每个类别的出现, 通过熵测量伪标签的置信度
c
‾
b
\overline{c}_b
cb, 使得
c
‾
b
=
exp
(
∑
a
b
log
a
b
)
\overline{c}_b = \exp (\sum a_b \log a_b)
cb=exp(∑ablogab). 最终的伪标签
P
‾
b
\overline{\mathrm{P}}_b
Pb 可以简单地通过对
m
m
m 的
P
‾
b
m
\overline{\mathrm{P}}_b^m
Pbm 进行平均来实现. 整体置信度估计过程如下图所示:
5. 损失函数
无标记样本的损失定义为:
L
U
=
1
μ
B
∑
b
=
1
μ
B
c
‾
b
D
(
q
‾
b
,
p
m
o
d
e
l
(
y
∣
A
(
u
b
)
;
θ
)
)
(5)
\mathcal{L_U}=\frac{1}{\mu B} \sum_{b=1}^{\mu B} \overline{c}_b \mathcal{D}(\overline{\mathrm{q}}_b,p_{model}(y\vert \mathcal{A}(u_b);\theta)) \tag{5}
LU=μB1b=1∑μBcbD(qb,pmodel(y∣A(ub);θ))(5)
其中
q
‾
b
\overline{\mathrm{q}}_b
qb 为聚合伪标签,
c
‾
b
\overline{\mathrm{c}}_b
cb 自信预测结果. 伪标签
q
‾
b
\overline{\mathrm{q}}_b
qb 通过使用温度缩放
T
T
T 锐化
P
‾
b
\overline{\mathrm{P}}_b
Pb 生成.
标记样本的监督损失定义为:
L
S
=
1
B
∑
b
=
1
B
D
(
y
b
,
p
m
o
d
e
l
(
y
∣
α
(
x
b
)
;
θ
)
)
(6)
\mathcal{L_S}=\frac{1}{B} \sum_{b=1}^{B}\mathcal{D}(y_b,p_{model}(y\vert \alpha(x_b);\theta)) \tag{6}
LS=B1b=1∑BD(yb,pmodel(y∣α(xb);θ))(6)
最终的损失函数定义为:
L
=
L
S
+
λ
L
U
(7)
\mathcal{L}=\mathcal{L_S}+\lambda\mathcal{L_U} \tag{7}
L=LS+λLU(7)
完整的 AggMatch 算法如下: