https://arxiv.org/pdf/2010.11929.pdf
---------------------------------------------------------
2021-08-30
Transformer 缺少 CNN 固有的归纳偏置（平移等变性、局部性），因此在数据量不足时泛化较差；在大规模数据集上预训练可以弥补这一点。
class PatchEmbeddin(nn.Module):
    """Split an image into patches and embed them as a token sequence (ViT stem).

    A Conv2d with ``kernel_size == stride == patch_size`` projects every
    non-overlapping patch to an ``emb_size`` vector. The patch grid is then
    flattened into a sequence, a learnable class token is prepended, and a
    learnable position embedding is added.

    Args:
        in_channel: Number of input image channels.
        patch_size: Side length (pixels) of each square patch.
        emb_size: Embedding dimension of every token.
        img_size: Side length of the input image. It fixes the number of
            position embeddings at ``(img_size // patch_size) ** 2 + 1``.
    """

    def __init__(self, in_channel: int = 3, patch_size: int = 16, emb_size: int = 768, img_size: int = 224):
        super().__init__()
        self.patch_size = patch_size
        self.projection = nn.Sequential(
            nn.Conv2d(in_channel, emb_size, kernel_size=patch_size, stride=patch_size),
            # (b, e, h, w) -> (b, h*w, e): one token per patch.
            Rearrange("b e (h) (w) -> b (h w) e"),
        )
        self.cls_token = nn.Parameter(torch.randn(1, 1, emb_size))
        # One position vector per patch plus one for the class token. The
        # shape (num_patches + 1, emb_size) broadcasts over the batch dimension.
        self.position = nn.Parameter(torch.randn((img_size // patch_size) ** 2 + 1, emb_size))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Embed a batch of images.

        Args:
            x: Image batch. Assumes (b, in_channel, img_size, img_size);
                a different spatial size yields a token count that does not
                match ``self.position``, so the final addition fails.

        Returns:
            Tokens of shape (b, (img_size // patch_size) ** 2 + 1, emb_size),
            with the class token at index 0.
        """
        b = x.size(0)
        x = self.projection(x)
        # expand() is a broadcast view; torch.cat below materialises it, so
        # no separate per-sample copy of the class token is needed.
        cls_tokens = self.cls_token.expand(b, -1, -1)
        x = torch.cat([cls_tokens, x], dim=1)
        return x + self.position


class MultiHeadAttention(nn.Module):
    """Multi-head scaled dot-product self-attention.

    Args:
        emb_size: Token embedding dimension. It must be divisible by
            ``num_headas``.
        num_headas: Number of attention heads. The misspelled name is kept so
            existing keyword callers keep working.
        dropout: Dropout probability applied to the attention weights.

    Raises:
        ValueError: If ``emb_size`` is not divisible by ``num_headas``.
    """

    def __init__(self, emb_size: int = 768, num_headas: int = 8, dropout: float = 0):
        super().__init__()
        if emb_size % num_headas != 0:
            raise ValueError(
                f"emb_size ({emb_size}) must be divisible by num_headas ({num_headas})"
            )
        self.emb_size = emb_size
        self.num_heads = num_headas
        self.head_dim = emb_size // num_headas
        # A single linear layer produces queries, keys and values together.
        self.qkv = nn.Linear(emb_size, emb_size * 3)
        self.ett_drop = nn.Dropout(dropout)  # dropout on the attention weights
        self.projection = nn.Linear(emb_size, emb_size)

    def forward(self, x: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor:
        """Apply self-attention to a token sequence.

        Args:
            x: Tokens of shape (b, n, emb_size).
            mask: Optional boolean tensor. ``True`` marks query/key pairs that
                may attend and ``False`` marks blocked pairs. It must
                broadcast to (b, num_heads, n, n).

        Returns:
            Tensor of shape (b, n, emb_size).
        """
        b, n, _ = x.shape
        # The qkv features are laid out as (h d qkv) with qkv innermost,
        # the same split as the einops pattern
        # "b n (h d qkv) -> qkv b h n d", so the self.qkv weights are
        # interpreted identically.
        qkv = self.qkv(x).reshape(b, n, self.num_heads, self.head_dim, 3)
        queries, keys, values = qkv.permute(4, 0, 2, 1, 3).unbind(0)  # each (b, h, n, d)

        # Scale the logits (not the softmax probabilities) by sqrt(head_dim).
        energy = torch.einsum("bhqd,bhkd->bhqk", queries, keys) / self.head_dim ** 0.5
        if mask is not None:
            # masked_fill is out of place, so the result must be reassigned.
            # A large finite negative (not -inf) keeps a fully masked row
            # from turning into NaN after the softmax.
            energy = energy.masked_fill(~mask, torch.finfo(energy.dtype).min)
        att = F.softmax(energy, dim=-1)
        att = self.ett_drop(att)
        out = torch.einsum("bhqk,bhkd->bhqd", att, values)
        # Merge heads back: (b, h, n, d) -> (b, n, h*d).
        out = out.transpose(1, 2).reshape(b, n, self.emb_size)
        return self.projection(out)