Capturing the Attention Shift with MLLM-LAWS
对于具有 L L L 层的MLLM,我们定义Layer-level Attention Weights(MLLM-LAWS)为:
LAWS V ∗ = [ a V ∗ 1 , a V ∗ 2 , ⋯ , a V ∗ L ] , a V ∗ l = ∑ i = 0 s ∑ j ∈ V ∗ a i j l ( h l − 1 ) \text{LAWS}_{\mathcal{V}_{*}} = \left[ a_{\mathcal{V}_{*}}^{1}, a_{\mathcal{V}_{*}}^{2}, \cdots, a_{\mathcal{V}_{*}}^{L} \right], \quad a_{\mathcal{V}_{*}}^{l} = \sum_{i=0}^{s} \sum_{j \in \mathcal{V}_{*}} a_{ij}^{l} \left( h^{l-1} \right) LAWSV∗=[aV∗1,aV∗2,⋯,aV∗L],aV∗l=i=0∑sj∈V∗∑aijl(hl−1)
a i j l a_{ij}^{l} aijl 表示MLLM的第 l l l 层中第 i i i 个token对第 j j j 个token的注意力权重, h l h^{l} hl 表示第 l l l 层的隐藏状态, ∀ i , ∑ j = 0 s a i j l ( h l − 1 ) = 1 \forall i, \sum_{j=0}^{s} a_{ij}^{l} \left( \mathbf{h}^{l-1} \right) = 1 ∀i,∑j=0saijl(hl−1)=1, V ∗ \mathcal{V}_{*} V∗ 表示一个token序列,可以是 V itself \mathcal{V}_\text{itself} Vitself、 V before \mathcal{V}_\text{before} Vbefore 或 V after \mathcal{V}_\text{after} Vafter,分别表示视觉序列、视觉序列前的文本序列和视觉序列后的文本序列。
LAWS V ∗ \text{LAWS}_{\mathcal{V}_{*}} LAWSV∗ 可以表示MLLM对当前序列 V ∗ \mathcal{V}_{*} V∗ 在所有MLLM层上的注意力动态曲线。
- (a): 视觉编码器和投影器的视觉特征token被插入到文本特征序列中。
- (b): 文本层中在文本token的前图像、图像本身和后图像上的注意力权重比例。红色曲线来自在纯文本任务表现更好的MLLM,而蓝色曲线来自在纯文本任务表现更差的MLLM。
- ©: 实验表明,在100多个MLLM中,视觉token前后的皮尔逊相关系数与MLLM在纯文本上的表现存在正相关关系。
因此,注意力转移可通过视觉序列前后 LAWS \text{LAWS} LAWS 之间的皮尔逊相关系数量化为:
Attention Shift = E x [ − ρ ( LAWS V before , LAWS V after ) ] + 1 \text{Attention Shift} = \mathbb{E}_{\mathbf{x}} \left[ -\rho \left( \text{LAWS}_{\mathcal{V}_{\text{before}}}, \, \text{LAWS}_{\mathcal{V}_{\text{after}}} \right) \right] + 1 Attention Shift=Ex[−ρ(LAWSVbefore,LAWSVafter)]+1
WINGS
我们引入了与主要注意力平行的额外模块,作为增强Learner来补偿注意力转移。我们在一侧训练Visual Learner,减轻一些转移的注意力。然后,我们基于路由转移的注意力权重协同训练Visual Learner和Textual Learner。
WINGS由Low-Rank Residual Attention(LoRRA)模块构建,其中前一个隐藏状态作为Query,而视觉/文本特征作为Key和Value。训练从Visual Learner和Projector开始,接着是动态注意力Router。
Learner:
Learner ∗ ( Q = h l , K , V = x ∗ ) ∗ ∈ { V , T } = Softmax ( h l ( 1 + W Q ) ⋅ ( x ∗ ( 1 + W K ) ) ⊤ d head ) x ∗ ( 1 + W V ) W O \text{Learner}^* \left( \rm{Q} = \mathbf{h}^l, \rm{K}, \rm{V} = \mathbf{x}_* \right)_{* \in \{\mathbf{V}, \mathbf{T}\}} = \text{Softmax} \left( \frac{\mathbf{h}^l (1 + \mathbf{W}^{\rm{Q}}) \cdot (\mathbf{x}_* (1 + \mathbf{W}^{\rm{K}}))^\top}{\sqrt{d_{\text{head}}}} \right) \mathbf{x}_* (1 + \mathbf{W}^{\rm{V}}) \mathbf{W}^{\rm{O}} Learner∗(Q=hl,K,V=x∗)∗∈{V,T}=Softmax(dheadhl(1+WQ)⋅(x∗(1+WK))⊤)x∗(1+WV)WO
Router接受一个注意力权重作为输入,通过单层MLP和Softmax处理,然后将Learner的输出叠加到主注意力上。
Att WINGS = Att main + ∑ ∗ ∈ { V , T } Router ( a ) ⋅ Learner ∗ ( h l , x ∗ ) \text{Att}^{\text{WINGS}} = \text{Att}^{\text{main}} + \sum_{* \in \{V, T\}} \text{Router}(\mathbf{a}) \cdot \text{Learner}^* \left( \mathbf{h}^l, \mathbf{x}_* \right) AttWINGS=Attmain+∗∈{V,T}∑Router(a)⋅Learner∗(hl,x∗)
训练
WINGS的架构包含四个元素:Vision Encoder、Projector、LLM以及带有Router的Learner。
- Stage 1
- ❄️ Vision Encoder、LLM
- ???? Projector、Visual Learner
- Visual Learner的输出直接加到主分支上
- Stage 2
- ❄️ Vision Encoder
- ???? LLM、Projector、Router、Visual Learner、Textual Learner