Prompt-to-Prompt Image Editing with Cross Attention Control
TL; DR:prompt2prompt 提出通过替换 UNet 中的交叉注意力图,在图像编辑过程中根据新的 prompt 语义生图的同时,保持图像整体布局结构不变。从而实现了基于纯文本(不用 mask 等额外信息)的图像编辑。
导语
基于文本控制图像生成模型取得了飞速的进展,一个自然的想法是如何实现基于文本控制的图像编辑模型。图像编辑需要在保持原始图片构图不变的情况下,对部分元素或整体风格进行修改。然而,在文生图模型中,文本 prompt 中一个小小的修改(比如替换一个单词),通常都会使得生图结果完全不同。现有的图像编辑方法一般是通过要求用户同时提供编辑区域的 mask,并只在 mask 内进行图像编辑,从而保持图像整体构图布局不变。
本文中,作者深入分析了现有的文生图模型,发现交叉注意力层是控制图像空间布局与 prompt 中每个单词之间关系的关键。基于此,作者通过替换交叉注意力图,构建了一种仅需修改 prompt (无需额外 mask) 的图像编辑框架。并基于该框架介绍了三种图像编辑应用,包括替换单词的局部编辑、添加描述的全局编辑以及更改单词权重来修改单词在生图结果中的反映程度。
下图展示了 prompt2prompt 图像编辑的一些结果,可以看到,仅通过在 prompt 中替换、删除或添加几个单词,就能完成对应的图像编辑,关键是在这个过程过程中保持了图像整体空间布局和其他背景元素不变。
方法
记 I \mathcal{I} I 为文生图模型在 prompt 为 P \mathcal{P} P ,随机种子为 s s s 下生成的一张图像,我们的目标是仅通过修改 prompt 为 P ∗ \mathcal{P}^* P∗ ,同时固定随机种子 s s s,生成出语义符合 P ∗ \mathcal{P}^* P∗ 且空间构图与 I \mathcal{I} I 一致的编辑图像 I ∗ \mathcal{I}^* I∗。
图文交叉注意力
prompt2prompt 的方法注意力替换技术如下图所示。带噪图像
ϕ
(
z
t
)
\phi(z_t)
ϕ(zt) 映射为 query 矩阵
Q
=
l
Q
(
ϕ
(
z
t
)
Q=l_Q(\phi(z_t)
Q=lQ(ϕ(zt) ,文本 prompt 嵌入
ψ
(
P
)
\psi(\mathcal{P})
ψ(P) 映射为 key 矩阵
K
=
l
K
ψ
(
P
)
K=l_K\psi(\mathcal{P})
K=lKψ(P) 和 value 矩阵
V
=
l
V
ψ
(
P
)
V=l_V\psi(\mathcal{P})
V=lVψ(P) ,其中
l
Q
,
l
K
,
l
V
l_Q,l_K,l_V
lQ,lK,lV 是三个线性映射层。注意力图(attention maps)
M
M
M 为:
M
=
Softmax
(
Q
K
T
d
)
M=\text{Softmax}(\frac{QK^T}{\sqrt{d}})
M=Softmax(dQKT)
其中 d d d 是 key 和 value 的隐层维度,每个单元 M i j M_{ij} Mij 表示第 j j j 个文本 token 和第 i i i 个图像块 token 的相关性,该值会作为 value 矩阵的权重,从而交叉注意力层最终的输出为 ϕ ^ ( z t ) = M V \hat\phi(z_t)=MV ϕ^(zt)=MV 。
总结一下,交叉注意力层的输出 M V MV MV 是 value 矩阵 V V V 的加权平均,而各位置的权重就是注意力图 M M M , M M M 其实就是 query 矩阵 Q Q Q 和 key 矩阵 K K K 的各位置的相似度。
交叉注意力图替换
本文的关键发现是:生成图像的空间布局是由交叉注意力图决定的。该发现可由下图看出。图中展示了文本 prompt 中的各个单词对应的交叉注意力图,其中上方一行展示了交叉注意力图在所有时间步平均的结果,可以看到:像素位置与描述它的单词之间的注意力权重更高(如 bear、bird 都对应于各自的位置);下方两行展示了 bear 和 bird 两个单词在整个生图过程中每一时间步的对应注意力图,可以看到:文本单词与空间位置的对应关系在生图早期就已经确定下来了。这个可视化实验支撑了本文的关键发现:生成图像的空间布局是由交叉注意力图决定的。
既然已经知道了生图的空间构图是由交叉注意力图决定的,那要保持图像编辑过程中的空间构图的思路就很直接了:将第一次生图 I \mathcal{I} I 时的交叉注意力图 M M M 保存下来,在图像编辑生图 I ∗ \mathcal{I}^* I∗ 的过程中替换上去,让对应位置的单词还是在编辑生图的对应空间位置。这样就可以使得编辑图像 I ∗ \mathcal{I}^* I∗ 不仅语义与编辑 prompt P ∗ \mathcal{P}^* P∗ 一致,同时还能保持空间构图与原图 I \mathcal{I} I 一致。
接下来,本文提出了一种统一的图像编辑框架,并介绍了该框架下的三种图像编辑应用:分别是替换单词的局部编辑、添加描述的全局编辑以及更改单词权重来修改单词在生图结果中的反映程度。
记扩散模型的在时间步 t t t 的去噪过程为 D M ( z t , P , t , s ) DM(z_t,\mathcal{P},t,s) DM(zt,P,t,s) ,其输出为预测的上一时间步的噪声图像 z t − 1 z_{t-1} zt−1,我们同时保存该步去噪过程中的交叉注意力图,记为 M t M_t Mt。记 D M ( z t , P , t , s ) { M ← M ^ } DM(z_t,\mathcal{P},t,s)\{M\leftarrow \hat{M}\} DM(zt,P,t,s){M←M^} 为将交叉注意力图 M M M 替换为 M ^ \hat{M} M^ 的去噪步,注意其中的 value 矩阵 V V V 不做替换,仅替换交叉注意力图。我们定义 E d i t ( M t , M t ∗ , t ) Edit(M_t,M_t^*,t) Edit(Mt,Mt∗,t) 为交叉注意力图编辑函数,其中 M t ∗ M_t^* Mt∗ 为生成编辑图像时的交叉注意力图。编辑函数输入为:生成原始图片时的注意力图、生成编辑图片时的注意力图和时间步,输出为当前步的注意力图。
prompt2prompt 框架的算法伪代码如下所示:
注意 prompt2prompt 需要保持生成原始图片时的随机种子和生成编辑图片时的随机种子一致,因为在扩散模型中,即使文本 prompt 相同,如果随机种子不同,生成结果也会非常不同。
通过在上述框架下使用不同的编辑函数,就能实现不同的图像编辑应用。本文提到的三种应用图示见图 2 下侧,以下分别具体介绍三种图像编辑应用。
应用
word swap
通过单词替换来改变原图中的物体或属性。比如原 prompt 为 “a big red car”,编辑 prompt 为 “a big red bike”。由于 prompt 中 token 数目没变,这种情况可以直接通过交叉注意力替换来实现语义与编辑 prompt 一致,而构图与原图一致。但是,在去噪过程全程都进行注意力替换约束可能会过强,尤其是当编辑前后 prompt 的轮廓发生很大变化时(如 car -> bike)。考虑到生图的空间位置关系是在早期的去噪步中确定的,这里通过指定在某个时间步
τ
\tau
τ 之前进行注意力替换,来使得整体约束相对宽松,给编辑图像生图一定的*度:
E
d
i
t
(
M
t
,
M
t
∗
,
t
)
:
=
{
M
t
∗
if
t
<
τ
M
t
oherwise
Edit(M_t,M_t^*,t):= \begin{cases} M_t^*\ \ \ \text{if}\ t<\tau \\ M_t\ \ \ \text{oherwise} \end{cases}
Edit(Mt,Mt∗,t):={Mt∗ if t<τMt oherwise
如果编辑 prompt 中的新单词编码为多个 token,还可以对不同的 token 分别指定不同的
τ
\tau
τ ;如果编辑前后两个单词编码的 token 数不同,可以通过复制、平均或使用对齐函数来对齐。
adding a new phrase
用户还可以在原 prompt 中添加一些单词,得到编辑 prompt。比如原 prompt 为 “a castle next to a river”,编辑 prompt 为 “children drawing of a castle next to a river”。此时我们仅对编辑前后两个 prompt 中相同的单词进行注意力图替换。具体老说,我们使用一个对齐函数
A
A
A,将编辑 prompt 中的 token 的索引映射为原始 prompt 中相同 token 的索引,如果原始 prompt 没有就返回 None。此时编辑函数为:
(
E
d
i
t
(
M
t
,
M
t
∗
,
t
)
)
i
,
j
:
=
{
(
M
t
∗
)
i
,
j
if
A
(
j
)
=
N
o
n
e
(
M
t
)
i
,
j
oherwise
(Edit(M_t,M_t^*,t))_{i,j}:= \begin{cases} (M_t^*)_{i,j}\ \ \ \text{if}\ A(j)=None \\ (M_t)_{i,j}\ \ \ \text{oherwise} \end{cases}
(Edit(Mt,Mt∗,t))i,j:={(Mt∗)