[2101] Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet

paper: arXiv
code: pytorch

main limitations of ViT

  1. the straightforward tokenization of input images by hard split makes ViT unable to model local structure (e.g., edges and lines) among neighboring pixels, so it requires far more training samples than CNNs to achieve similar performance
  2. the attention backbone of ViT is not as well-designed as CNNs for vision tasks; it contains redundancy, which leads to limited feature richness and more difficult training

Figure: Feature visualization of ResNet50, ViT-L/16 and T2T-ViT-24 trained on ImageNet. Green boxes highlight learned low-level structure features such as edges and lines; red boxes highlight invalid feature maps with zero or overly large values. Note that the feature maps visualized for ViT and T2T-ViT are not attention maps, but image features reshaped from tokens. For better visualization, the input image is scaled to 1024x1024 or 2048x2048.

ResNet captures the desired local structure (edges, lines, textures, etc.) progressively from the bottom layer (conv1) to the middle layer (conv25)
in ViT, structure information is poorly modeled, while global relations (e.g., the whole dog) are captured by all attention blocks
note that ViT ignores local structure when directly splitting images into tokens of fixed length.

many channels in ViT have zero values
note that the backbone of ViT is not as efficient as ResNets and offers limited feature richness when training samples are not enough.

Contribution

  1. propose a progressive tokenization module that aggregates neighboring Tokens-to-Token, which models local structure information from surrounding tokens and reduces the length of tokens iteratively
  2. borrow architecture designs from CNNs to build the transformer layers and improve feature richness, and find that a deep-narrow architecture design brings much better performance in ViT

Figure: Comparison of T2T-ViT with ViT, ResNets and MobileNets when trained from scratch on ImageNet. Left: performance curve of MACs vs. top-1 accuracy. Right: performance curve of model size vs. top-1 accuracy.

Method

model architecture

  1. layer-wise T2T module: models local information of images and reduces the length of tokens progressively
  2. efficient T2T-ViT backbone: draws global attention relations on the tokens produced by the T2T module

Figure: The overall network architecture of T2T-ViT. In the T2T module, the input image is first soft split into patches and then unfolded into a sequence of tokens $T_0$. The length of tokens is reduced progressively in the T2T module (2 iterations are used here, outputting $T_f$). The T2T-ViT backbone then takes the fixed-length tokens as input and outputs the predictions. PE is Position Embedding.

token-to-token

aims to overcome the limitation of simple tokenization in ViT
progressively structurizes an image into tokens and models local structure information, so the length of tokens is reduced iteratively

Figure: Illustration of the T2T process. The tokens $T_i$ are re-structurized into an image $I_i$ after transformation and reshaping; $I_i$ is then split with overlapping into tokens $T_{i+1}$ again. Specifically, as shown in the pink panel, the four tokens (1, 2, 4, 5) of the input $I_i$ are concatenated to form one token in $T_{i+1}$. The T2T transformer can be a normal Transformer layer or another efficient transformer such as a Performer layer when GPU memory is limited.

re-structurization
given a sequence of tokens $T_i$ from the preceding transformer layer, transform it into $T_i'$ with a self-attention block
$T_i' = \mathrm{MLP}(\mathrm{MSA}(T_i))$
where $T_i \in \mathbb{R}^{L \times C}$ and $T_i' \in \mathbb{R}^{L \times C}$
the tokens $T_i'$ are then reshaped as an image in the spatial dimension
$I_i = \mathrm{Reshape}(T_i')$
where $\mathrm{Reshape}(\cdot)$ re-organizes $T_i' \in \mathbb{R}^{L \times C}$ into $I_i \in \mathbb{R}^{H \times W \times C}$, with $L = H \times W$
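As a rough illustration, the re-structurization step might look like the PyTorch sketch below (the class and argument names are mine, not the authors'; the official code organizes the block differently and may use Performer attention):

```python
import torch.nn as nn

# Minimal sketch of re-structurization: tokens T_i of shape (B, L, C) pass through
# a standard pre-norm transformer block (MSA + MLP), and T_i' is then reshaped back
# into a 2D feature map I_i of shape (B, C, H, W), with L = H * W.
class Restructurization(nn.Module):
    def __init__(self, dim, num_heads=1):
        super().__init__()
        self.norm1 = nn.LayerNorm(dim)
        self.attn = nn.MultiheadAttention(dim, num_heads, batch_first=True)
        self.norm2 = nn.LayerNorm(dim)
        self.mlp = nn.Sequential(nn.Linear(dim, dim), nn.GELU(), nn.Linear(dim, dim))

    def forward(self, tokens, h, w):
        x = self.norm1(tokens)
        x = tokens + self.attn(x, x, x, need_weights=False)[0]  # MSA with residual
        x = x + self.mlp(self.norm2(x))                         # MLP with residual
        # I_i = Reshape(T_i'): (B, L, C) -> (B, C, H, W)
        return x.transpose(1, 2).reshape(x.size(0), -1, h, w)
```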

soft split
models local structure information and reduces the length of tokens
to avoid information loss when generating tokens from the re-structurized image, the image is split into patches with overlapping

  1. each patch is correlated with its surrounding patches, establishing a prior that there should be stronger correlations between surrounding tokens.
  2. the tokens in each split patch are concatenated into one token (red and blue boxes), so regional information is aggregated from surrounding pixels and patches.

$T_{i+1} = \mathrm{SS}(I_i)$
where $\mathrm{SS}(\cdot)$ is the soft split operation, implemented with nn.Unfold
in nn.Unfold, given a tensor $X \in \mathbb{R}^{B \times C \times H \times W}$, a $k \times k$ kernel slides over $X$ to extract each patch $X_1 \in \mathbb{R}^{C \times k \times k}$, which is then flattened into $X_1' \in \mathbb{R}^{Ck^2}$
the output tensor is $Y \in \mathbb{R}^{B \times Ck^2 \times L_0}$, with $L_0 = H_0 \times W_0$, $H_0 = \lfloor \frac{H - k + 2p}{s} + 1 \rfloor$, $W_0 = \lfloor \frac{W - k + 2p}{s} + 1 \rfloor$
similarly, given $I_i \in \mathbb{R}^{H \times W \times C}$, the output tokens are $T_{i+1} \in \mathbb{R}^{L_0 \times Ck^2}$, with $L_0 = \lfloor \frac{H - k + 2p}{s} + 1 \rfloor \times \lfloor \frac{W - k + 2p}{s} + 1 \rfloor$
after the soft split, the output tokens are fed into the next T2T process
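A small self-contained check of the soft split shapes with nn.Unfold (the kernel/stride/padding values below are only illustrative, not tied to a particular T2T-ViT stage):

```python
import torch
import torch.nn as nn

# Soft split via nn.Unfold: an overlapping k x k window with stride s and padding p
# turns an image I_i of shape (B, C, H, W) into tokens T_{i+1} of shape (B, L0, C*k*k).
B, C, H, W = 2, 64, 28, 28
k, s, p = 3, 2, 1

image = torch.randn(B, C, H, W)                   # I_i
soft_split = nn.Unfold(kernel_size=k, stride=s, padding=p)
patches = soft_split(image)                       # (B, C*k*k, L0)
tokens = patches.transpose(1, 2)                  # (B, L0, C*k*k)

H0 = (H - k + 2 * p) // s + 1                     # floor((H - k + 2p)/s + 1)
W0 = (W - k + 2 * p) // s + 1
print(tokens.shape, H0 * W0)                      # torch.Size([2, 196, 576]) 196
```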

T2T module
based on a transformer block, with 2 extra components

  1. reshape tokens into an $H \times W \times C$ image, so that more local information can be learned later
  2. unfold the image into $L_0 \times Ck^2$ tokens, capturing details for more efficient modeling in the following transformer

for the input image $I_0$, only a soft split is applied at first to split it into tokens: $T_1 = \mathrm{SS}(I_0)$
after the last T2T iteration, the output tokens $T_f$ have a fixed length, so the T2T-ViT backbone can model global relations on $T_f$
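Putting the two steps together, a hedged end-to-end sketch of the T2T module, re-using the Restructurization sketch above (the kernel sizes 7/3/3 and strides 4/2/2 give the paper's overall 16x reduction for a 224x224 input, but the projection placement and dimensions are simplifications, not the official implementation):

```python
import torch
import torch.nn as nn

class T2TModule(nn.Module):
    # Two T2T iterations: soft split -> re-structurization -> soft split ->
    # re-structurization -> final soft split + projection to the backbone dim.
    def __init__(self, in_chans=3, token_dim=64, embed_dim=384):
        super().__init__()
        self.split0 = nn.Unfold(kernel_size=7, stride=4, padding=2)   # T_1 = SS(I_0)
        self.split1 = nn.Unfold(kernel_size=3, stride=2, padding=1)
        self.split2 = nn.Unfold(kernel_size=3, stride=2, padding=1)
        self.proj1 = nn.Linear(in_chans * 7 * 7, token_dim)
        self.proj2 = nn.Linear(token_dim * 3 * 3, token_dim)
        self.block1 = Restructurization(token_dim)
        self.block2 = Restructurization(token_dim)
        self.project = nn.Linear(token_dim * 3 * 3, embed_dim)

    def forward(self, img):                                   # img: (B, 3, 224, 224)
        t = self.proj1(self.split0(img).transpose(1, 2))      # T_1: (B, 56*56, token_dim)
        i1 = self.block1(t, 56, 56)                           # I_1: (B, token_dim, 56, 56)
        t = self.proj2(self.split1(i1).transpose(1, 2))       # T_2: (B, 28*28, token_dim)
        i2 = self.block2(t, 28, 28)                           # I_2: (B, token_dim, 28, 28)
        return self.project(self.split2(i2).transpose(1, 2))  # T_f: (B, 14*14, embed_dim)

tokens_f = T2TModule()(torch.randn(1, 3, 224, 224))
print(tokens_f.shape)  # torch.Size([1, 196, 384])
```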

T2T-ViT backbone

explore five architecture designs borrowed from CNNs for the transformer backbone:

  1. dense connection, as in DenseNet
  2. deep-narrow vs. shallow-wide structure, as in Wide-ResNets (the deep-narrow design is adopted; see the sketch after this list)
  3. channel attention, as in Squeeze-and-Excitation networks
  4. more split heads in the multi-head attention layer, as in ResNeXt
  5. Ghost operations, as in GhostNet
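As a rough illustration of what "deep-narrow vs. shallow-wide" means in practice (the depths and widths below are placeholders chosen to give roughly comparable parameter counts, not the published T2T-ViT configurations):

```python
import torch.nn as nn

def transformer_backbone(depth, dim, heads, mlp_ratio=3.0):
    # Plain pre-norm transformer encoder, used here only to compare parameter budgets.
    layer = nn.TransformerEncoderLayer(
        d_model=dim, nhead=heads, dim_feedforward=int(dim * mlp_ratio),
        activation="gelu", batch_first=True, norm_first=True)
    return nn.TransformerEncoder(layer, num_layers=depth)

deep_narrow = transformer_backbone(depth=14, dim=384, heads=6)   # more layers, smaller hidden dim
shallow_wide = transformer_backbone(depth=4, dim=768, heads=12)  # fewer layers, larger hidden dim

params = lambda m: sum(p.numel() for p in m.parameters()) / 1e6
print(f"deep-narrow: {params(deep_narrow):.1f}M, shallow-wide: {params(shallow_wide):.1f}M")
```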

architecture variants

Table: Structure details of T2T-ViT. T2T-ViT-14/19/24 have model sizes comparable to ResNet50/101/152, and T2T-ViT-7/12 have model sizes comparable to MobileNetV1/V2. For the T2T transformer layer, a Transformer layer is adopted for T2T-ViTt-14 and a Performer layer for T2T-ViT-14 under limited GPU memory. For ViT, 'S' means Small, 'B' Base and 'L' Large; 'ViT-S/16' is a variant of the original ViT-B/16 with smaller MLP size and layer depth.

Experiment
