
Vision Transformer - Pytorch

Implementation of Vision Transformer, a simple way to achieve SOTA in vision classification with only a single transformer encoder, in Pytorch. Significance is further explained in Yannic Kilcher's video. There's really not much to code here, but may as well lay it out for everyone so we expedite the attention revolution.

For a Pytorch implementation with pretrained models, please see Ross Wightman's repository here.

The official Jax repository is here.

A tensorflow2 translation also exists here, created by research scientist Junho Kim! 🙏

Flax translation by Enrico Shippole!

Install

$ pip install vit-pytorch

Usage

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

Parameters

  • image_size: int.
    Image size. If you have rectangular images, make sure your image size is the maximum of the width and height
  • patch_size: int.
    Size of patches. image_size must be divisible by patch_size.
    The number of patches is: n = (image_size // patch_size) ** 2 and n must be greater than 16.
  • num_classes: int.
    Number of classes to classify.
  • dim: int.
    Last dimension of output tensor after linear transformation nn.Linear(..., dim).
  • depth: int.
    Number of Transformer blocks.
  • heads: int.
    Number of heads in Multi-head Attention layer.
  • mlp_dim: int.
    Dimension of the MLP (FeedForward) layer.
  • channels: int, default 3.
    Number of image's channels.
  • dropout: float between [0, 1], default 0.
    Dropout rate.
  • emb_dropout: float between [0, 1], default 0.
    Embedding dropout rate.
  • pool: string, either 'cls' (CLS token pooling) or 'mean' (mean pooling)
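
As an example of the optional parameters above, here is a configuration for single-channel images with mean pooling. The specific values chosen below are purely illustrative.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,      # 256 // 32 = 8 patches per side, so n = 8 ** 2 = 64
    num_classes = 10,
    dim = 512,
    depth = 6,
    heads = 8,
    mlp_dim = 1024,
    channels = 1,         # e.g. grayscale images
    pool = 'mean',        # mean pooling instead of the CLS token
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 1, 256, 256)

preds = v(img) # (1, 10)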

Simple ViT

An update from some of the same authors of the original paper proposes simplifications to ViT that allow it to train faster and better.

These simplifications include 2D sinusoidal positional embeddings, global average pooling (no CLS token), no dropout, batch sizes of 1024 rather than 4096, and the use of RandAugment and MixUp augmentations. They also show that a simple linear layer at the end is not significantly worse than the original MLP head.

You can use it by importing the SimpleViT as shown below

import torch
from vit_pytorch import SimpleViT

v = SimpleViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)

NaViT

This paper proposes to leverage the flexibility of attention and masking for variable-length sequences to train on images of multiple resolutions, packed into a single batch. They demonstrate much faster training and improved accuracy, with the only cost being extra complexity in the architecture and dataloading. They use factorized 2D positional encodings, token dropping, as well as query-key normalization.

You can use it as follows

import torch
from vit_pytorch.na_vit import NaViT

v = NaViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    token_dropout_prob = 0.1  # token dropout of 10% (keep 90% of tokens)
)

# 5 images of different resolutions - List[List[Tensor]]

# for now, you'll have to correctly place images in the same batch element so as not to exceed the maximum allowed sequence length for self-attention with masking

images = [
    [torch.randn(3, 256, 256), torch.randn(3, 128, 128)],
    [torch.randn(3, 128, 256), torch.randn(3, 256, 128)],
    [torch.randn(3, 64, 256)]
]

preds = v(images) # (5, 1000) - 5, because 5 images of different resolution above

Or, if you would rather the framework automatically group the images into variable-length sequences that do not exceed a certain maximum length

images = [
    torch.randn(3, 256, 256),
    torch.randn(3, 128, 128),
    torch.randn(3, 128, 256),
    torch.randn(3, 256, 128),
    torch.randn(3, 64, 256)
]

preds = v(
    images,
    group_images = True,
    group_max_seq_len = 64
) # (5, 1000)

Distillation

A recent paper has shown that using a distillation token for distilling knowledge from convolutional nets to vision transformers can yield small and efficient vision transformers. This repository offers the means to do distillation easily.

ex. distilling from Resnet50 (or any teacher) to a vision transformer

import torch
from torchvision.models import resnet50

from vit_pytorch.distill import DistillableViT, DistillWrapper

teacher = resnet50(pretrained = True)

v = DistillableViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

distiller = DistillWrapper(
    student = v,
    teacher = teacher,
    temperature = 3,           # temperature of distillation
    alpha = 0.5,               # trade between main loss and distillation loss
    hard = False               # whether to use soft or hard distillation
)

img = torch.randn(2, 3, 256, 256)
labels = torch.randint(0, 1000, (2,))

loss = distiller(img, labels)
loss.backward()

# after lots of training above ...

pred = v(img) # (2, 1000)

The DistillableViT class is identical to ViT except for how the forward pass is handled, so you should be able to load the parameters back to ViT after you have completed distillation training.
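
For example, a minimal sketch of copying the distilled weights into a fresh ViT; strict = False is used here defensively, in case the distillable variant carries any extra parameters.

from vit_pytorch import ViT

vit = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

vit.load_state_dict(v.state_dict(), strict = False)  # copy the distilled weights into a plain ViT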

You can also use the handy .to_vit method on the DistillableViT instance to get back a ViT instance.

v = v.to_vit()
type(v) # <class 'vit_pytorch.vit_pytorch.ViT'>

Deep ViT

This paper notes that ViT struggles to attend at greater depths (past 12 layers), and suggests mixing the attention of each head post-softmax as a solution, dubbed Re-attention. The results line up with the Talking Heads paper from NLP.

You can use it as follows

import torch
from vit_pytorch.deepvit import DeepViT

v = DeepViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
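
For intuition, the re-attention operation (mixing the heads' post-softmax attention maps with a learned matrix) could be sketched roughly as below. This is a simplified illustration, not the exact module used in vit_pytorch.deepvit.

import torch
import torch.nn as nn

class ReAttention(nn.Module):
    def __init__(self, heads):
        super().__init__()
        self.mix_heads = nn.Parameter(torch.randn(heads, heads))  # learned head-mixing matrix

    def forward(self, attn, v):
        # attn: (batch, heads, n, n) attention maps after softmax
        # v:    (batch, heads, n, dim_head) values
        attn = torch.einsum('b h i j, h g -> b g i j', attn, self.mix_heads)  # mix attention across heads
        return attn @ v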

CaiT

This paper also notes difficulty in training vision transformers at greater depths and proposes two solutions. First it proposes to do per-channel multiplication of the output of the residual block. Second, it proposes to have the patches attend to one another, and only allow the CLS token to attend to the patches in the last few layers.

They also add Talking Heads, noting improvements

You can use this scheme as follows

import torch
from vit_pytorch.cait import CaiT

v = CaiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,             # depth of transformer for patch to patch attention only
    cls_depth = 2,          # depth of cross attention of CLS tokens to patch
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05    # randomly dropout 5% of the layers
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
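
For intuition, the per-channel multiplication of the residual branch output could be sketched roughly as below. Again, this is a simplified illustration, not the exact module used in vit_pytorch.cait.

import torch
import torch.nn as nn

class ScaledResidual(nn.Module):
    def __init__(self, dim, fn, init_eps = 1e-5):
        super().__init__()
        self.fn = fn
        self.scale = nn.Parameter(torch.full((dim,), init_eps))  # one learned scale per channel

    def forward(self, x):
        return x + self.fn(x) * self.scale  # scale the block output before the residual addition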

Token-to-Token ViT

This paper proposes that the first couple of layers should downsample the image sequence by unfolding, leading to overlapping image data in each token as shown in the figure above. You can use this variant of the ViT as follows.

import torch
from vit_pytorch.t2t import T2TViT

v = T2TViT(
    dim = 512,
    image_size = 224,
    depth = 5,
    heads = 8,
    mlp_dim = 512,
    num_classes = 1000,
    t2t_layers = ((7, 4), (3, 2), (3, 2)) # tuples of the kernel size and stride of each consecutive layer of the initial token-to-token module
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

CCT

CCT proposes compact transformers by using convolutions instead of patching and performing sequence pooling. This allows CCT to achieve high accuracy with a low number of parameters.

You can use it in one of two ways

import torch
from vit_pytorch.cct import CCT

cct = CCT(
    img_size = (224, 448),
    embedding_dim = 384,
    n_conv_layers = 2,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

img = torch.randn(1, 3, 224, 448)
pred = cct(img) # (1, 1000)

Alternatively, you can use one of several pre-defined models [2, 4, 6, 7, 8, 14, 16], which fix the number of layers, number of attention heads, mlp ratio, and embedding dimension.

import torch
from vit_pytorch.cct import cct_14

cct = cct_14(
    img_size = 224,
    n_conv_layers = 1,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_classes = 1000,
    positional_embedding = 'learnable', # ['sine', 'learnable', 'none']
)

The official repository includes links to pretrained model checkpoints.

Cross ViT

This paper proposes to have two vision transformers processing the image at different scales, cross-attending to one another every so often. They show improvements on top of the base vision transformer.

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,               # number of multi-scale encoding blocks
    sm_dim = 192,            # high res dimension
    sm_patch_size = 16,      # high res patch size (should be smaller than lg_patch_size)
    sm_enc_depth = 2,        # high res depth
    sm_enc_heads = 8,        # high res heads
    sm_enc_mlp_dim = 2048,   # high res feedforward dimension
    lg_dim = 384,            # low res dimension
    lg_patch_size = 64,      # low res patch size
    lg_enc_depth = 3,        # low res depth
    lg_enc_heads = 8,        # low res heads
    lg_enc_mlp_dim = 2048,   # low res feedforward dimensions
    cross_attn_depth = 2,    # cross attention rounds
    cross_attn_heads = 8,    # cross attention heads
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 256)

pred = v(img) # (1, 1000)

PiT

This paper proposes to downsample the tokens through a pooling procedure using depth-wise convolutions.

import torch
from vit_pytorch.pit import PiT

v = PiT(
    image_size = 224,
    patch_size = 14,
    dim = 256,
    num_classes = 1000,
    depth = (3, 3, 3),     # list of depths, indicating the number of rounds of each stage before a downsample
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

LeViT

This paper proposes a number of changes, including (1) convolutional embedding instead of patch-wise projection (2) downsampling in stages (3) extra non-linearity in attention (4) 2d relative positional biases instead of initial absolute positional bias (5) batchnorm in place of layernorm.

Official repository

import torch
from vit_pytorch.levit import LeViT

levit = LeViT(
    image_size = 224,
    num_classes = 1000,
    stages = 3,             # number of stages
    dim = (256, 384, 512),  # dimensions at each stage
    depth = 4,              # transformer of depth 4 at each stage
    heads = (4, 6, 8),      # heads at each stage
    mlp_mult = 2,
    dropout = 0.1
)

img = torch.randn(1, 3, 224, 224)

levit(img) # (1, 1000)

CvT

This paper proposes mixing convolutions and attention. Specifically, convolutions are used to embed and downsample the image / feature map in three stages. Depthwise convolution is also used to project the queries, keys, and values for attention.

import torch
from vit_pytorch.cvt import CvT

v = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,        # stage 1 - dimension
    s1_emb_kernel = 7,      # stage 1 - conv kernel
    s1_emb_stride = 4,      # stage 1 - conv stride
    s1_proj_kernel = 3,     # stage 1 - attention ds-conv kernel size
    s1_kv_proj_stride = 2,  # stage 1 - attention key / value projection stride
    s1_heads = 1,           # stage 1 - heads
    s1_depth = 1,           # stage 1 - depth
    s1_mlp_mult = 4,        # stage 1 - feedforward expansion factor
    s2_emb_dim = 192,       # stage 2 - (same as above)
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,       # stage 3 - (same as above)
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

img = torch.randn(1, 3, 224, 224)

pred = v(img) # (1, 1000)

Twins SVT

This paper proposes mixing local and global attention, along with a positional encoding generator (proposed in CPVT) and global average pooling, to achieve the same results as Swin without the extra complexity of shifted windows, CLS tokens, or positional embeddings.

import torch
from vit_pytorch.twins_svt import TwinsSVT

model = TwinsSVT(
    num_classes = 1000,       # number of output classes
    s1_emb_dim = 64,          # stage 1 - patch embedding projected dimension
    s1_patch_size = 4,        # stage 1 - patch size for patch embedding
    s1_local_patch_size = 7,  # stage 1 - patch size for local attention
    s1_global_k = 7,          # stage 1 - global attention key / value reduction factor, defaults to 7 as specified in paper
    s1_depth = 1,             # stage 1 - number of transformer blocks (local attn -> ff -> global attn -> ff)
    s2_emb_dim = 128,         # stage 2 (same as above)
    s2_patch_size = 2,
    s2_local_patch_size = 7,
    s2_global_k = 7,
    s2_depth = 1,
    s3_emb_dim = 256,         # stage 3 (same as above)
    s3_patch_size = 2,
    s3_local_patch_size = 7,
    s3_global_k = 7,
    s3_depth = 5,
    s4_emb_dim = 512,         # stage 4 (same as above)
    s4_patch_size = 2,
    s4_local_patch_size = 7,
    s4_global_k = 7,
    s4_depth = 4,
    peg_kernel_size = 3,      # positional encoding generator kernel size
    dropout = 0.              # dropout
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

RegionViT

This paper proposes to divide up the feature map into local regions, whereby the local tokens attend to each other. Each local region has its own regional token which then attends to all its local tokens, as well as other regional tokens.

You can use it as follows

import torch
from vit_pytorch.regionvit import RegionViT

model = RegionViT(
    dim = (64, 128, 256, 512),      # tuple of size 4, indicating dimension at each stage
    depth = (2, 2, 8, 2),           # depth of the region to local transformer at each stage
    window_size = 7,                # window size, which should be either 7 or 14
    num_classes = 1000,             # number of output classes
    tokenize_local_3_conv = False,  # whether to use a 3 layer convolution to encode the local tokens from the image. the paper uses this for the smaller models, but uses only 1 conv (set to False) for the larger models
    use_peg = False,                # whether to use positional generating module. they used this for object detection for a boost in performance
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

CrossFormer

This paper beats PVT and Swin using alternating local and global attention. The global attention is done across the windowing dimension for reduced complexity, much like the scheme used for axial attention.

They also have a cross-scale embedding layer, which they show to be a generic layer that can improve all vision transformers. A dynamic relative positional bias was also formulated to allow the network to generalize to images of greater resolution.

import torch
from vit_pytorch.crossformer import CrossFormer

model = CrossFormer(
    num_classes = 1000,                # number of output classes
    dim = (64, 128, 256, 512),         # dimension at each stage
    depth = (2, 2, 8, 2),              # depth of transformer at each stage
    global_window_size = (8, 4, 2, 1), # global window sizes at each stage
    local_window_size = 7,             # local window size (can be customized for each stage, but in paper, held constant at 7 for all stages)
)

img = torch.randn(1, 3, 224, 224)

pred = model(img) # (1, 1000)

ScalableViT

This Bytedance AI paper proposes the Scalable Self Attention (SSA) and the Interactive Windowed Self Attention (IWSA) modules. The SSA alleviates the computation needed at earlier stages by reducing the key / value feature map by some factor (reduction_factor), while modulating the dimension of the queries and keys (ssa_dim_key). The IWSA performs self attention within local windows, similar to other vision transformer papers. However, they add a residual of the values, passed through a convolution of kernel size 3, which they named Local Interactive Module (LIM).

They make the claim in this paper that this scheme outperforms Swin Transformer, and also demonstrate competitive performance against Crossformer.

You can use it as follows (ex. ScalableViT-S)

import torch
from vit_pytorch.scalable_vit import ScalableViT

model = ScalableViT(
    num_classes = 1000,
    dim = 64,                               # starting model dimension. at every stage, dimension is doubled
    heads = (2, 4, 8, 16),                  # number of attention heads at each stage
    depth = (2, 2, 20, 2),                  # number of transformer blocks at each stage
    ssa_dim_key = (40, 40, 40, 32),         # the dimension of the attention keys (and queries) for SSA. in the paper, they represented this as a scale factor on the base dimension per key (ssa_dim_key / dim_key)
    reduction_factor = (8, 4, 2, 1),        # downsampling of the key / values in SSA. in the paper, this was represented as (reduction_factor ** -2)
    window_size = (64, 32, None, None),     # window size of the IWSA at each stage. None means no windowing needed
    dropout = 0.1,                          # attention and feedforward dropout
)

img = torch.randn(1, 3, 256, 256)

preds = model(img) # (1, 1000)

SepViT

Another Bytedance AI paper, it proposes a depthwise-pointwise self-attention layer that seems largely inspired by mobilenet's depthwise-separable convolution. The most interesting aspect is the reuse of the feature map from the depthwise self-attention stage as the values for the pointwise self-attention, as shown in the diagram above.

I have decided to include only the version of SepViT with this specific self-attention layer, as the grouped attention layers are neither remarkable nor novel, and the authors were not clear on how they treated the window tokens for the group self-attention layer. Besides, it seems that with the DSSA layer alone, they were able to beat Swin.

ex. SepViT-Lite

import torch
from vit_pytorch.sep_vit import SepViT

v = SepViT(
    num_classes = 1000,
    dim = 32,               # dimensions of first stage, which doubles every stage (32, 64, 128, 256) for SepViT-Lite
    dim_head = 32,          # attention head dimension
    heads = (1, 2, 4, 8),   # number of heads per stage
    depth = (1, 2, 6, 2),   # number of transformer blocks per stage
    window_size = 7,        # window size of DSS Attention block
    dropout = 0.1           # dropout
)

img = torch.randn(1, 3, 224, 224)

preds = v(img) # (1, 1000)

MaxViT

This paper proposes a hybrid convolutional / attention network, using MBConv from the convolution side, and then block / grid axial sparse attention.

They also claim this specific vision transformer is good for generative models (GANs).

ex. MaxViT-S

import torch
from vit_pytorch.max_vit import MaxViT

v = MaxViT(
    num_classes = 1000,
    dim_conv_stem = 64,               # dimension of the convolutional stem, would default to dimension of first layer if not specified
    dim = 96,                         # dimension of first layer, doubles every layer
    dim_head = 32,                    # dimension of attention heads, kept at 32 in paper
    depth = (2, 2, 5, 2),             # number of MaxViT blocks per stage, which consists of MBConv, block-like attention, grid-like attention
    window_size = 7,                  # window size for block and grids
    mbconv_expansion_rate = 4,        # expansion rate of MBConv
    mbconv_shrinkage_rate = 0.25,     # shrinkage rate of squeeze-excitation in MBConv
    dropout = 0.1                     # dropout
)

img = torch.randn(2, 3, 224, 224)

preds = v(img) # (2, 1000)

NesT

This paper decided to process the image in hierarchical stages, with attention only within the tokens of local blocks, which are aggregated as they move up the hierarchy. The aggregation is done in the image plane, and contains a convolution and a subsequent maxpool to allow information to pass across the block boundaries.

You can use it with the following code (ex. NesT-T)

import torch
from vit_pytorch.nest import NesT

nest = NesT(
    image_size = 224,
    patch_size = 4,
    dim = 96,
    heads = 3,
    num_hierarchies = 3,        # number of hierarchies
    block_repeats = (2, 2, 8),  # the number of transformer blocks at each hierarchy, starting from the bottom
    num_classes = 1000
)

img = torch.randn(1, 3, 224, 224)

pred = nest(img) # (1, 1000)

MobileViT

This paper introduces MobileViT, a light-weight and general-purpose vision transformer for mobile devices. MobileViT presents a different perspective for the global processing of information with transformers.

You can use it with the following code (ex. mobilevit_xs)

import torch
from vit_pytorch.mobile_vit import MobileViT

mbvit_xs = MobileViT(
    image_size = (256, 256),
    dims = [96, 120, 144],
    channels = [16, 32, 48, 48, 64, 64, 80, 80, 96, 96, 384],
    num_classes = 1000
)

img = torch.randn(1, 3, 256, 256)

pred = mbvit_xs(img) # (1, 1000)

XCiT

This paper introduces the cross covariance attention (abbreviated XCA). One can think of it as doing attention across the feature dimension rather than the spatial one (another perspective would be a dynamic 1x1 convolution, the kernel being the attention map defined by spatial correlations).

Technically, this amounts to simply transposing the queries, keys, and values before executing cosine similarity attention with a learned temperature.

import torch
from vit_pytorch.xcit import XCiT

v = XCiT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 12,                     # depth of xcit transformer
    cls_depth = 2,                  # depth of cross attention of CLS tokens to patch, attention pool at end
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1,
    layer_dropout = 0.05,           # randomly dropout 5% of the layers
    local_patch_kernel_size = 3     # kernel size of the local patch interaction module (depthwise convs)
)

img = torch.randn(1, 3, 256, 256)

preds = v(img) # (1, 1000)
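
For intuition, the core cross covariance attention operation could be sketched roughly as below. This is a simplified illustration only; the module in vit_pytorch.xcit layers more on top (e.g. the local patch interaction controlled by local_patch_kernel_size above).

import torch
import torch.nn.functional as F

def xca(q, k, v, temperature):
    # q, k, v: (batch, heads, tokens, dim_head)
    q, k, v = map(lambda t: t.transpose(-2, -1), (q, k, v))  # transpose so attention runs over the feature dimension
    q, k = map(lambda t: F.normalize(t, dim = -1), (q, k))   # cosine similarity attention
    attn = (q @ k.transpose(-2, -1)) * temperature           # (batch, heads, dim_head, dim_head), learned temperature
    out = attn.softmax(dim = -1) @ v                         # (batch, heads, dim_head, tokens)
    return out.transpose(-2, -1)                             # back to (batch, heads, tokens, dim_head)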

Simple Masked Image Modeling

This paper proposes a simple masked image modeling (SimMIM) scheme, using only a linear projection from the masked tokens into pixel space, followed by an L1 loss against the pixel values of the masked patches. Results are competitive with other, more complicated approaches.

You can use this as follows

import torch
from vit_pytorch import ViT
from vit_pytorch.simmim import SimMIM

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mim = SimMIM(
    encoder = v,
    masking_ratio = 0.5  # they found 50% to yield the best results
)

images = torch.randn(8, 3, 256, 256)

loss = mim(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn
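
# a hedged sketch of such a loop (sample_unlabelled_images is a placeholder for your own data loading)

opt = torch.optim.AdamW(mim.parameters(), lr = 1e-4)

def sample_unlabelled_images():
    return torch.randn(8, 3, 256, 256)   # substitute your own unlabelled images here

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mim(images)
    opt.zero_grad()
    loss.backward()
    opt.step()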

torch.save(v.state_dict(), './trained-vit.pt')

Masked Autoencoder

A new Kaiming He paper proposes a simple autoencoder scheme where the vision transformer attends to a set of unmasked patches, and a smaller decoder tries to reconstruct the masked pixel values.

DeepReader quick paper review

AI Coffeebreak with Letitia

You can use it with the following code

import torch
from vit_pytorch import ViT, MAE

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

mae = MAE(
    encoder = v,
    masking_ratio = 0.75,   # the paper recommended 75% masked patches
    decoder_dim = 512,      # paper showed good results with just 512
    decoder_depth = 6       # anywhere from 1 to 8
)

images = torch.randn(8, 3, 256, 256)

loss = mae(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')

Masked Patch Prediction

Thanks to Zach, you can train using the original masked patch prediction task presented in the paper, with the following code.

import torch
from vit_pytorch import ViT
from vit_pytorch.mpp import MPP

model = ViT(
    image_size=256,
    patch_size=32,
    num_classes=1000,
    dim=1024,
    depth=6,
    heads=8,
    mlp_dim=2048,
    dropout=0.1,
    emb_dropout=0.1
)

mpp_trainer = MPP(
    transformer=model,
    patch_size=32,
    dim=1024,
    mask_prob=0.15,          # probability of using token in masked prediction task
    random_patch_prob=0.30,  # probability of randomly replacing a token being used for mpp
    replace_prob=0.50,       # probability of replacing a token being used for mpp with the mask token
)

opt = torch.optim.Adam(mpp_trainer.parameters(), lr=3e-4)

def sample_unlabelled_images():
    return torch.FloatTensor(20, 3, 256, 256).uniform_(0., 1.)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = mpp_trainer(images)
    opt.zero_grad()
    loss.backward()
    opt.step()

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

Masked Position Prediction

A new paper introduces a masked position prediction pre-training criterion. This strategy is more efficient than the Masked Autoencoder strategy and has comparable performance.

import torch
from vit_pytorch.mp3 import ViT, MP3

v = ViT(
    num_classes = 1000,
    image_size = 256,
    patch_size = 8,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
)

mp3 = MP3(
    vit = v,
    masking_ratio = 0.75
)

images = torch.randn(8, 3, 256, 256)

loss = mp3(images)
loss.backward()

# that's all!
# do the above in a for loop many times with a lot of images and your vision transformer will learn

# save your improved vision transformer
torch.save(v.state_dict(), './trained-vit.pt')

Adaptive Token Sampling

This paper proposes to use the CLS attention scores, re-weighted by the norms of the value heads, as a means to discard unimportant tokens at different layers.

import torch
from vit_pytorch.ats_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    max_tokens_per_depth = (256, 128, 64, 32, 16, 8), # a tuple that denotes the maximum number of tokens that any given layer should have. if the layer has greater than this amount, it will undergo adaptive token sampling
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

# you can also get a list of the final sampled patch ids
# a value of -1 denotes padding

preds, token_ids = v(img, return_sampled_token_ids = True) # (4, 1000), (4, <=8)

Patch Merger

This paper proposes a simple module (Patch Merger) for reducing the number of tokens at any layer of a vision transformer without sacrificing performance.

import torch
from vit_pytorch.vit_with_patch_merger import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 12,
    heads = 8,
    patch_merge_layer = 6,        # at which transformer layer to do patch merging
    patch_merge_num_tokens = 8,   # the output number of tokens from the patch merge
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

One can also use the PatchMerger module by itself

import torch
from vit_pytorch.vit_with_patch_merger import PatchMerger

merger = PatchMerger(
    dim = 1024,
    num_tokens_out = 8   # output number of tokens
)

features = torch.randn(4, 256, 1024) # (batch, num tokens, dimension)

out = merger(features) # (4, 8, 1024)

Vision Transformer for Small Datasets

This paper proposes a new image-to-patch function that incorporates shifts of the image before normalizing and dividing the image into patches. I have found shifting to be extremely helpful in some other transformer work, so I decided to include this for further exploration. It also includes the LSA with the learned temperature and masking out of a token's attention to itself.

You can use as follows:

import torch
from vit_pytorch.vit_for_small_dataset import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

You can also use the SPT from this paper as a standalone module

import torch
from vit_pytorch.vit_for_small_dataset import SPT

spt = SPT(
    dim = 1024,
    patch_size = 16,
    channels = 3
)

img = torch.randn(4, 3, 256, 256)

tokens = spt(img) # (4, 256, 1024)
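
For intuition, the LSA mentioned above (learned temperature plus masking out each token's attention to itself) could be sketched roughly as below. This is a simplified illustration, not the exact attention module in vit_pytorch.vit_for_small_dataset.

import torch

def lsa(q, k, v, temperature):
    # q, k, v: (batch, heads, tokens, dim_head); temperature is a learned scalar
    scores = (q @ k.transpose(-2, -1)) / temperature
    i = torch.arange(scores.shape[-1], device = scores.device)
    scores[..., i, i] = -torch.finfo(scores.dtype).max  # mask out each token's attention to itself
    attn = scores.softmax(dim = -1)
    return attn @ v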

3D ViT

By popular request, I will start extending a few of the architectures in this repository to 3D ViTs, for use with video, medical imaging, etc.

You will need to pass in two additional hyperparameters: (1) the number of frames, frames, and (2) the patch size along the frame dimension, frame_patch_size

For starters, 3D ViT

import torch
from vit_pytorch.vit_3d import ViT

v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

3D Simple ViT

import torch
from vit_pytorch.simple_vit_3d import SimpleViT

v = SimpleViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

3D version of CCT

import torch
from vit_pytorch.cct_3d import CCT

cct = CCT(
    img_size = 224,
    num_frames = 8,
    embedding_dim = 384,
    n_conv_layers = 2,
    frame_kernel_size = 3,
    kernel_size = 7,
    stride = 2,
    padding = 3,
    pooling_kernel_size = 3,
    pooling_stride = 2,
    pooling_padding = 1,
    num_layers = 14,
    num_heads = 6,
    mlp_ratio = 3.,
    num_classes = 1000,
    positional_embedding = 'learnable'
)

video = torch.randn(1, 3, 8, 224, 224) # (batch, channels, frames, height, width)
pred = cct(video) # (1, 1000)

ViViT

This paper offers 3 different types of architectures for efficient attention of videos, with the main theme being factorizing the attention across space and time. This repository will offer the first variant, which is a spatial transformer followed by a temporal one.

import torch
from vit_pytorch.vivit import ViT

v = ViT(
    image_size = 128,          # image size
    frames = 16,               # number of frames
    image_patch_size = 16,     # image patch size
    frame_patch_size = 2,      # frame patch size
    num_classes = 1000,
    dim = 1024,
    spatial_depth = 6,         # depth of the spatial transformer
    temporal_depth = 6,        # depth of the temporal transformer
    heads = 8,
    mlp_dim = 2048
)

video = torch.randn(4, 3, 16, 128, 128) # (batch, channels, frames, height, width)

preds = v(video) # (4, 1000)

Parallel ViT

This paper proposes parallelizing multiple attention and feedforward blocks per layer (2 blocks), claiming that this makes the network easier to train without loss of performance.

You can try this variant as follows

import torch
from vit_pytorch.parallel_vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    num_parallel_branches = 2,  # in paper, they claimed 2 was optimal
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)

preds = v(img) # (4, 1000)

Learnable Memory ViT

This paper shows that adding learnable memory tokens at each layer of a vision transformer can greatly enhance fine-tuning results (in addition to learnable task specific CLS token and adapter head).

You can use this with a specially modified ViT as follows

import torch
from vit_pytorch.learnable_memory_vit import ViT, Adapter

# normal base ViT

v = ViT(
    image_size = 256,
    patch_size = 16,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(4, 3, 256, 256)
logits = v(img) # (4, 1000)

# do your usual training with ViT
# ...


# then, to finetune, just pass the ViT into the Adapter class
# you can do this for multiple Adapters, as shown below

adapter1 = Adapter(
    vit = v,
    num_classes = 2,               # number of output classes for this specific task
    num_memories_per_layer = 5     # number of learnable memories per layer, 10 was sufficient in paper
)

logits1 = adapter1(img) # (4, 2) - predict 2 classes off frozen ViT backbone with learnable memories and task specific head

# yet another task to finetune on, this time with 4 classes

adapter2 = Adapter(
    vit = v,
    num_classes = 4,
    num_memories_per_layer = 10
)

logits2 = adapter2(img) # (4, 4) - predict 4 classes off frozen ViT backbone with learnable memories and task specific head

Dino

You can train ViT with the recent SOTA self-supervised learning technique, Dino, with the following code.

Yannic Kilcher video

import torch
from vit_pytorch import ViT, Dino

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 8,
    mlp_dim = 2048
)

learner = Dino(
    model,
    image_size = 256,
    hidden_layer = 'to_latent',        # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper 
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.Adam(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(20, 3, 256, 256)

for _ in range(100):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(model.state_dict(), './pretrained-net.pt')

EsViT

EsViT is a variant of Dino (from above) re-engineered to support efficient ViTs with patch merging / downsampling by taking into account an extra regional loss between the augmented views. To quote the abstract, it outperforms its supervised counterpart on 17 out of 18 datasets at 3 times higher throughput.

Even though it is named as though it were a new ViT variant, it actually is just a strategy for training any multistage ViT (in the paper, they focused on Swin). The example below will show how to use it with CvT. You'll need to set the hidden_layer to the name of the layer within your efficient ViT that outputs the non-average pooled visual representations, just before the global pooling and projection to logits.

import torch
from vit_pytorch.cvt import CvT
from vit_pytorch.es_vit import EsViTTrainer

cvt = CvT(
    num_classes = 1000,
    s1_emb_dim = 64,
    s1_emb_kernel = 7,
    s1_emb_stride = 4,
    s1_proj_kernel = 3,
    s1_kv_proj_stride = 2,
    s1_heads = 1,
    s1_depth = 1,
    s1_mlp_mult = 4,
    s2_emb_dim = 192,
    s2_emb_kernel = 3,
    s2_emb_stride = 2,
    s2_proj_kernel = 3,
    s2_kv_proj_stride = 2,
    s2_heads = 3,
    s2_depth = 2,
    s2_mlp_mult = 4,
    s3_emb_dim = 384,
    s3_emb_kernel = 3,
    s3_emb_stride = 2,
    s3_proj_kernel = 3,
    s3_kv_proj_stride = 2,
    s3_heads = 4,
    s3_depth = 10,
    s3_mlp_mult = 4,
    dropout = 0.
)

learner = EsViTTrainer(
    cvt,
    image_size = 256,
    hidden_layer = 'layers',           # hidden layer name or index, from which to extract the embedding
    projection_hidden_size = 256,      # projector network hidden dimension
    projection_layers = 4,             # number of layers in projection network
    num_classes_K = 65336,             # output logits dimensions (referenced as K in paper)
    student_temp = 0.9,                # student temperature
    teacher_temp = 0.04,               # teacher temperature, needs to be annealed from 0.04 to 0.07 over 30 epochs
    local_upper_crop_scale = 0.4,      # upper bound for local crop - 0.4 was recommended in the paper
    global_lower_crop_scale = 0.5,     # lower bound for global crop - 0.5 was recommended in the paper
    moving_average_decay = 0.9,        # moving average of encoder - paper showed anywhere from 0.9 to 0.999 was ok
    center_moving_average_decay = 0.9, # moving average of teacher centers - paper showed anywhere from 0.9 to 0.999 was ok
)

opt = torch.optim.AdamW(learner.parameters(), lr = 3e-4)

def sample_unlabelled_images():
    return torch.randn(8, 3, 256, 256)

for _ in range(1000):
    images = sample_unlabelled_images()
    loss = learner(images)
    opt.zero_grad()
    loss.backward()
    opt.step()
    learner.update_moving_average() # update moving average of teacher encoder and teacher centers

# save your improved network
torch.save(cvt.state_dict(), './pretrained-net.pt')

Accessing Attention

If you would like to visualize the attention weights (post-softmax) for your research, just follow the procedure below

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Recorder and wrap the ViT

from vit_pytorch.recorder import Recorder
v = Recorder(v)

# forward pass now returns predictions and the attention maps

img = torch.randn(1, 3, 256, 256)
preds, attns = v(img)

# there is one extra patch due to the CLS token

attns # (1, 6, 16, 65, 65) - (batch x layers x heads x patch x patch)

To clean up the class and the hooks once you have collected enough data

v = v.eject()  # wrapper is discarded and original ViT instance is returned

Accessing Embeddings

You can similarly access the embeddings with the Extractor wrapper

import torch
from vit_pytorch.vit import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

# import Extractor and wrap the ViT

from vit_pytorch.extractor import Extractor
v = Extractor(v)

# forward pass now returns predictions and the embeddings

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # (1, 65, 1024) - (batch x patches x model dim)

Or say for CrossViT, which has a multi-scale encoder that outputs two sets of embeddings for 'large' and 'small' scales

import torch
from vit_pytorch.cross_vit import CrossViT

v = CrossViT(
    image_size = 256,
    num_classes = 1000,
    depth = 4,
    sm_dim = 192,
    sm_patch_size = 16,
    sm_enc_depth = 2,
    sm_enc_heads = 8,
    sm_enc_mlp_dim = 2048,
    lg_dim = 384,
    lg_patch_size = 64,
    lg_enc_depth = 3,
    lg_enc_heads = 8,
    lg_enc_mlp_dim = 2048,
    cross_attn_depth = 2,
    cross_attn_heads = 8,
    dropout = 0.1,
    emb_dropout = 0.1
)

# wrap the CrossViT

from vit_pytorch.extractor import Extractor
v = Extractor(v, layer_name = 'multi_scale_encoder') # take embedding coming from the output of multi-scale-encoder

# forward pass now returns predictions and the embeddings

img = torch.randn(1, 3, 256, 256)
logits, embeddings = v(img)

# there is one extra token due to the CLS token

embeddings # ((1, 257, 192), (1, 17, 384)) - (batch x patches x dimension) <- small and large scales respectively

Research Ideas

Efficient Attention

There may be some coming from computer vision who think attention still suffers from quadratic costs. Fortunately, we have a lot of new techniques that may help. This repository offers a way for you to plug in your own sparse attention transformer.

An example with Nystromformer

$ pip install nystrom-attention
import torch
from vit_pytorch.efficient import ViT
from nystrom_attention import Nystromformer

efficient_transformer = Nystromformer(
    dim = 512,
    depth = 12,
    heads = 8,
    num_landmarks = 256
)

v = ViT(
    dim = 512,
    image_size = 2048,
    patch_size = 32,
    num_classes = 1000,
    transformer = efficient_transformer
)

img = torch.randn(1, 3, 2048, 2048) # your high resolution picture
v(img) # (1, 1000)

Other sparse attention frameworks I would highly recommend are Routing Transformer and Sinkhorn Transformer.

Combining with other Transformer improvements

This paper purposely used the most vanilla of attention networks to make a statement. If you would like to use some of the latest improvements for attention nets, please use the Encoder from this repository.

ex.

$ pip install x-transformers
import torch
from vit_pytorch.efficient import ViT
from x_transformers import Encoder

v = ViT(
    dim = 512,
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    transformer = Encoder(
        dim = 512,                  # set to be the same as the wrapper
        depth = 12,
        heads = 8,
        ff_glu = True,              # ex. feed forward GLU variant https://arxiv.org/abs/2002.05202
        residual_attn = True        # ex. residual attention https://arxiv.org/abs/2012.11747
    )
)

img = torch.randn(1, 3, 224, 224)
v(img) # (1, 1000)

FAQ

  • How do I pass in non-square images?

You can already pass in non-square images - you just have to make sure your height and width are less than or equal to the image_size, and that both are divisible by the patch_size

ex.

import torch
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128) # <-- not a square

preds = v(img) # (1, 1000)
  • How do I pass in non-square patches?
import torch
from vit_pytorch import ViT

v = ViT(
    num_classes = 1000,
    image_size = (256, 128),  # image size is a tuple of (height, width)
    patch_size = (32, 16),    # patch size is a tuple of (height, width)
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    dropout = 0.1,
    emb_dropout = 0.1
)

img = torch.randn(1, 3, 256, 128)

preds = v(img) # (1, 1000)

Resources

Coming from computer vision and new to transformers? Here are some resources that greatly accelerated my learning.

  1. Illustrated Transformer - Jay Alammar

  2. Transformers from Scratch - Peter Bloem

  3. The Annotated Transformer - Harvard NLP

Citations

@article{hassani2021escaping,
    title   = {Escaping the Big Data Paradigm with Compact Transformers},
    author  = {Ali Hassani and Steven Walton and Nikhil Shah and Abulikemu Abuduweili and Jiachen Li and Humphrey Shi},
    year    = 2021,
    url     = {https://arxiv.org/abs/2104.05704},
    eprint  = {2104.05704},
    archiveprefix = {arXiv},
    primaryclass = {cs.CV}
}
@misc{dosovitskiy2020image,
    title   = {An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale},
    author  = {Alexey Dosovitskiy and Lucas Beyer and Alexander Kolesnikov and Dirk Weissenborn and Xiaohua Zhai and Thomas Unterthiner and Mostafa Dehghani and Matthias Minderer and Georg Heigold and Sylvain Gelly and Jakob Uszkoreit and Neil Houlsby},
    year    = {2020},
    eprint  = {2010.11929},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2020training,
    title   = {Training data-efficient image transformers & distillation through attention}, 
    author  = {Hugo Touvron and Matthieu Cord and Matthijs Douze and Francisco Massa and Alexandre Sablayrolles and Hervé Jégou},
    year    = {2020},
    eprint  = {2012.12877},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yuan2021tokenstotoken,
    title   = {Tokens-to-Token ViT: Training Vision Transformers from Scratch on ImageNet},
    author  = {Li Yuan and Yunpeng Chen and Tao Wang and Weihao Yu and Yujun Shi and Francis EH Tay and Jiashi Feng and Shuicheng Yan},
    year    = {2021},
    eprint  = {2101.11986},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{zhou2021deepvit,
    title   = {DeepViT: Towards Deeper Vision Transformer},
    author  = {Daquan Zhou and Bingyi Kang and Xiaojie Jin and Linjie Yang and Xiaochen Lian and Qibin Hou and Jiashi Feng},
    year    = {2021},
    eprint  = {2103.11886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{touvron2021going,
    title   = {Going deeper with Image Transformers}, 
    author  = {Hugo Touvron and Matthieu Cord and Alexandre Sablayrolles and Gabriel Synnaeve and Hervé Jégou},
    year    = {2021},
    eprint  = {2103.17239},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chen2021crossvit,
    title   = {CrossViT: Cross-Attention Multi-Scale Vision Transformer for Image Classification},
    author  = {Chun-Fu Chen and Quanfu Fan and Rameswar Panda},
    year    = {2021},
    eprint  = {2103.14899},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{wu2021cvt,
    title   = {CvT: Introducing Convolutions to Vision Transformers},
    author  = {Haiping Wu and Bin Xiao and Noel Codella and Mengchen Liu and Xiyang Dai and Lu Yuan and Lei Zhang},
    year    = {2021},
    eprint  = {2103.15808},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{heo2021rethinking,
    title   = {Rethinking Spatial Dimensions of Vision Transformers}, 
    author  = {Byeongho Heo and Sangdoo Yun and Dongyoon Han and Sanghyuk Chun and Junsuk Choe and Seong Joon Oh},
    year    = {2021},
    eprint  = {2103.16302},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{graham2021levit,
    title   = {LeViT: a Vision Transformer in ConvNet's Clothing for Faster Inference},
    author  = {Ben Graham and Alaaeldin El-Nouby and Hugo Touvron and Pierre Stock and Armand Joulin and Hervé Jégou and Matthijs Douze},
    year    = {2021},
    eprint  = {2104.01136},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{li2021localvit,
    title   = {LocalViT: Bringing Locality to Vision Transformers},
    author  = {Yawei Li and Kai Zhang and Jiezhang Cao and Radu Timofte and Luc Van Gool},
    year    = {2021},
    eprint  = {2104.05707},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chu2021twins,
    title   = {Twins: Revisiting Spatial Attention Design in Vision Transformers},
    author  = {Xiangxiang Chu and Zhi Tian and Yuqing Wang and Bo Zhang and Haibing Ren and Xiaolin Wei and Huaxia Xia and Chunhua Shen},
    year    = {2021},
    eprint  = {2104.13840},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{su2021roformer,
    title   = {RoFormer: Enhanced Transformer with Rotary Position Embedding}, 
    author  = {Jianlin Su and Yu Lu and Shengfeng Pan and Bo Wen and Yunfeng Liu},
    year    = {2021},
    eprint  = {2104.09864},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@misc{zhang2021aggregating,
    title   = {Aggregating Nested Transformers},
    author  = {Zizhao Zhang and Han Zhang and Long Zhao and Ting Chen and Tomas Pfister},
    year    = {2021},
    eprint  = {2105.12723},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{chen2021regionvit,
    title   = {RegionViT: Regional-to-Local Attention for Vision Transformers}, 
    author  = {Chun-Fu Chen and Rameswar Panda and Quanfu Fan},
    year    = {2021},
    eprint  = {2106.02689},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{wang2021crossformer,
    title   = {CrossFormer: A Versatile Vision Transformer Hinging on Cross-scale Attention}, 
    author  = {Wenxiao Wang and Lu Yao and Long Chen and Binbin Lin and Deng Cai and Xiaofei He and Wei Liu},
    year    = {2021},
    eprint  = {2108.00154},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{caron2021emerging,
    title   = {Emerging Properties in Self-Supervised Vision Transformers},
    author  = {Mathilde Caron and Hugo Touvron and Ishan Misra and Hervé Jégou and Julien Mairal and Piotr Bojanowski and Armand Joulin},
    year    = {2021},
    eprint  = {2104.14294},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{he2021masked,
    title   = {Masked Autoencoders Are Scalable Vision Learners}, 
    author  = {Kaiming He and Xinlei Chen and Saining Xie and Yanghao Li and Piotr Dollár and Ross Girshick},
    year    = {2021},
    eprint  = {2111.06377},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{xie2021simmim,
    title   = {SimMIM: A Simple Framework for Masked Image Modeling}, 
    author  = {Zhenda Xie and Zheng Zhang and Yue Cao and Yutong Lin and Jianmin Bao and Zhuliang Yao and Qi Dai and Han Hu},
    year    = {2021},
    eprint  = {2111.09886},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{fayyaz2021ats,
    title   = {ATS: Adaptive Token Sampling For Efficient Vision Transformers},
    author  = {Mohsen Fayyaz and Soroush Abbasi Kouhpayegani and Farnoush Rezaei Jafari and Eric Sommerlade and Hamid Reza Vaezi Joze and Hamed Pirsiavash and Juergen Gall},
    year    = {2021},
    eprint  = {2111.15667},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{mehta2021mobilevit,
    title   = {MobileViT: Light-weight, General-purpose, and Mobile-friendly Vision Transformer},
    author  = {Sachin Mehta and Mohammad Rastegari},
    year    = {2021},
    eprint  = {2110.02178},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{lee2021vision,
    title   = {Vision Transformer for Small-Size Datasets}, 
    author  = {Seung Hoon Lee and Seunghyun Lee and Byung Cheol Song},
    year    = {2021},
    eprint  = {2112.13492},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{renggli2022learning,
    title   = {Learning to Merge Tokens in Vision Transformers},
    author  = {Cedric Renggli and André Susano Pinto and Neil Houlsby and Basil Mustafa and Joan Puigcerver and Carlos Riquelme},
    year    = {2022},
    eprint  = {2202.12015},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@misc{yang2022scalablevit,
    title   = {ScalableViT: Rethinking the Context-oriented Generalization of Vision Transformer}, 
    author  = {Rui Yang and Hailong Ma and Jie Wu and Yansong Tang and Xuefeng Xiao and Min Zheng and Xiu Li},
    year    = {2022},
    eprint  = {2203.10790},
    archivePrefix = {arXiv},
    primaryClass = {cs.CV}
}
@inproceedings{Touvron2022ThreeTE,
    title   = {Three things everyone should know about Vision Transformers},
    author  = {Hugo Touvron and Matthieu Cord and Alaaeldin El-Nouby and Jakob Verbeek and Hervé Jégou},
    year    = {2022}
}
@inproceedings{Sandler2022FinetuningIT,
    title   = {Fine-tuning Image Transformers using Learnable Memory},
    author  = {Mark Sandler and Andrey Zhmoginov and Max Vladymyrov and Andrew Jackson},
    year    = {2022}
}
@inproceedings{Li2022SepViTSV,
    title   = {SepViT: Separable Vision Transformer},
    author  = {Wei Li and Xing Wang and Xin Xia and Jie Wu and Xuefeng Xiao and Minghang Zheng and Shiping Wen},
    year    = {2022}
}
@inproceedings{Tu2022MaxViTMV,
    title   = {MaxViT: Multi-Axis Vision Transformer},
    author  = {Zhengzhong Tu and Hossein Talebi and Han Zhang and Feng Yang and Peyman Milanfar and Alan Conrad Bovik and Yinxiao Li},
    year    = {2022}
}
@article{Li2021EfficientSV,
    title   = {Efficient Self-supervised Vision Transformers for Representation Learning},
    author  = {Chunyuan Li and Jianwei Yang and Pengchuan Zhang and Mei Gao and Bin Xiao and Xiyang Dai and Lu Yuan and Jianfeng Gao},
    journal = {ArXiv},
    year    = {2021},
    volume  = {abs/2106.09785}
}
@misc{Beyer2022BetterPlainViT,
    title     = {Better plain ViT baselines for ImageNet-1k},
    author    = {Beyer, Lucas and Zhai, Xiaohua and Kolesnikov, Alexander},
    publisher = {arXiv},
    year      = {2022}
}
@article{Arnab2021ViViTAV,
    title   = {ViViT: A Video Vision Transformer},
    author  = {Anurag Arnab and Mostafa Dehghani and Georg Heigold and Chen Sun and Mario Lucic and Cordelia Schmid},
    journal = {2021 IEEE/CVF International Conference on Computer Vision (ICCV)},
    year    = {2021},
    pages   = {6816-6826}
}
@article{Liu2022PatchDropoutEV,
    title   = {PatchDropout: Economizing Vision Transformers Using Patch Dropout},
    author  = {Yue Liu and Christos Matsoukas and Fredrik Strand and Hossein Azizpour and Kevin Smith},
    journal = {ArXiv},
    year    = {2022},
    volume  = {abs/2208.07220}
}
@misc{https://doi.org/10.48550/arxiv.2302.01327,
    doi     = {10.48550/ARXIV.2302.01327},
    url     = {https://arxiv.org/abs/2302.01327},
    author  = {Kumar, Manoj and Dehghani, Mostafa and Houlsby, Neil},
    title   = {Dual PatchNorm},
    publisher = {arXiv},
    year    = {2023},
    copyright = {Creative Commons Attribution 4.0 International}
}
@inproceedings{Dehghani2023PatchNP,
    title   = {Patch n' Pack: NaViT, a Vision Transformer for any Aspect Ratio and Resolution},
    author  = {Mostafa Dehghani and Basil Mustafa and Josip Djolonga and Jonathan Heek and Matthias Minderer and Mathilde Caron and Andreas Steiner and Joan Puigcerver and Robert Geirhos and Ibrahim M. Alabdulmohsin and Avital Oliver and Piotr Padlewski and Alexey A. Gritsenko and Mario Lučić and Neil Houlsby},
    year    = {2023}
}
@misc{vaswani2017attention,
    title   = {Attention Is All You Need},
    author  = {Ashish Vaswani and Noam Shazeer and Niki Parmar and Jakob Uszkoreit and Llion Jones and Aidan N. Gomez and Lukasz Kaiser and Illia Polosukhin},
    year    = {2017},
    eprint  = {1706.03762},
    archivePrefix = {arXiv},
    primaryClass = {cs.CL}
}
@inproceedings{dao2022flashattention,
    title   = {Flash{A}ttention: Fast and Memory-Efficient Exact Attention with {IO}-Awareness},
    author  = {Dao, Tri and Fu, Daniel Y. and Ermon, Stefano and Rudra, Atri and R{\'e}, Christopher},
    booktitle = {Advances in Neural Information Processing Systems},
    year    = {2022}
}
@inproceedings{Darcet2023VisionTN,
    title   = {Vision Transformers Need Registers},
    author  = {Timothée Darcet and Maxime Oquab and Julien Mairal and Piotr Bojanowski},
    year    = {2023},
    url     = {https://api.semanticscholar.org/CorpusID:263134283}
}
@inproceedings{ElNouby2021XCiTCI,
    title   = {XCiT: Cross-Covariance Image Transformers},
    author  = {Alaaeldin El-Nouby and Hugo Touvron and Mathilde Caron and Piotr Bojanowski and Matthijs Douze and Armand Joulin and Ivan Laptev and Natalia Neverova and Gabriel Synnaeve and Jakob Verbeek and Herv{\'e} J{\'e}gou},
    booktitle = {Neural Information Processing Systems},
    year    = {2021},
    url     = {https://api.semanticscholar.org/CorpusID:235458262}
}

I visualise a time when we will be to robots what dogs are to humans, and I’m rooting for the machines. — Claude Shannon

vit-pytorch's People

Contributors

adimyth, alihassanijr, ankandrew, chinhsuanwu, developer0hye, eify, jon-tow, l0wgear, loctruong96, lucidrains, minhlong94, murufeng, roydenwa, ryanrussell, shabie, soumya1729, stevenwalton, umbertov, vishu26, vztu, zankner


vit-pytorch's Issues

Image size constraint

Is there any constraint the image size places on performance? Also, how can one use ViT for non-square images, e.g. 32×64?

Trained on a small dataset with pre-trained weights, but the results are not good.

import timm
import torch.nn as nn

pretrained_v = timm.create_model('vit_base_patch16_224', pretrained=True)
pretrained_v.head = nn.Linear(768, 2)

I tried the Kaggle Cats vs Dogs dataset for binary classification. It didn't work; the output is all cat or all dog.

Any idea how to make it work on a small dataset (fewer than 10,000 or even fewer than 1,000 images)?

PS: Adam, lr = 1e-2
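
A hedged sketch of a more conservative fine-tuning setup for a small dataset (the hyperparameters and the frozen-backbone choice are illustrative assumptions, not a verified fix): keep the pretrained backbone frozen, train only the new head, and use a much smaller learning rate than 1e-2.

import timm
import torch
import torch.nn as nn

model = timm.create_model('vit_base_patch16_224', pretrained=True)
model.head = nn.Linear(768, 2)

# freeze everything except the new classification head
for name, p in model.named_parameters():
    p.requires_grad = name.startswith('head')

# 1e-2 is usually far too high for fine-tuning a pretrained ViT
optimizer = torch.optim.Adam(
    (p for p in model.parameters() if p.requires_grad),
    lr = 1e-4
)
criterion = nn.CrossEntropyLoss()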

Why only use the first patch? Thanks

I don't understand in line 124 of vit_pytorch.py:
x = self.to_cls_token(x[:, 0])
If the first dimension of x is the batch, then index 0 along the second dimension should be a patch, since x has shape [batch, patch, feature]. Does this mean only the first patch is used? Could anybody help me with this? Thanks a lot.

How to handle variant image sizes? Thanks

I have a question about variant image sizes.

If we have images with different sizes (which actually happens often if no resizing is used), say image1 has 256 patches and image2 has 512 patches, I would guess self.pos_embedding could be defined sufficiently large, e.g.
self.pos_embedding = nn.Parameter(torch.randn(1, 10000, dim))
and then sliced when used:
num_patches = x.shape[1]
x += self.pos_embedding[:, :(num_patches + 1)]
But I am not quite sure if this approach works. Could you please advise?

Test on various image size

Hi,
Great coding job. When I read this paper, I always have a question about testing on images whose size differs from the training data. Suppose we train on 224x224x3 images with a 16x16x3 patch size, so the sequence length is 196. However, if I want to test the model on a 220x220x3 image (not divisible by 16), how can we handle this? Does it mean we need to randomly crop to a size that is divisible by 16, e.g. 208x208x3? If we do so, we might miss some information from the image, e.g. when the whole image only contains the face of a bear. CNNs do not have this problem.
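
One common workaround (a generic sketch, not this repo's API) is to pad the test image on the bottom and right so that the height and width become divisible by the patch size; for this repo's ViT you would pad up to the image_size used at construction so the positional embeddings still match.

import torch
import torch.nn.functional as F

def pad_to_multiple(img, patch_size = 16):
    # img: (B, C, H, W); pad bottom/right so H and W are divisible by patch_size
    _, _, h, w = img.shape
    pad_h = (patch_size - h % patch_size) % patch_size
    pad_w = (patch_size - w % patch_size) % patch_size
    return F.pad(img, (0, pad_w, 0, pad_h))  # pad order: (left, right, top, bottom)

img = torch.randn(1, 3, 220, 220)
print(pad_to_multiple(img).shape)  # torch.Size([1, 3, 224, 224])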

how to use this model for image generation?

Thanks for the great work. I removed the classification head and am trying to use this repo for image generation, but I get really bad results. All images look patchy and have very low quality. I played with the number of heads, number of layers, LR, etc., but it didn't really matter.

What would be the most sensible approach to generate images with the encoder part?

Patch To Embedding correct?

In line 95 of ViT: self.patch_to_embedding = nn.Linear(patch_dim, dim)
Is it supposed to be an nn.Linear layer? I believe it's a learnable tensor. The paper says "E is a trainable linear projection that maps each vectorized patch to the model dimension D". Yannic also referred to E as a linear projection matrix. Could you please share your thoughts?

NB: I have modified and run E as an nn.Parameter tensor and it produces similar results.
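
For what it's worth, a minimal check (the patch_dim/dim values are just illustrative) showing that nn.Linear is exactly a trainable projection matrix E, so the two formulations coincide:

import torch
import torch.nn as nn

patch_dim, dim = 768, 1024
x = torch.randn(1, 64, patch_dim)                   # (batch, num_patches, patch_dim)

linear = nn.Linear(patch_dim, dim, bias = False)    # holds a trainable weight matrix
E = nn.Parameter(linear.weight.t().clone())         # the same matrix as an explicit parameter

assert torch.allclose(linear(x), x @ E, atol = 1e-6)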

About flattening the patch

Hi,

Thanks for sharing this implementation. I have one question: when flattening the HxW patch to D dimensions, you use an FC layer to map it to a fixed dimension. But in the original paper they train at 224x224 and test at 384x384, which cannot be achieved if the flattening layer is fixed. Also, in another repo you shared (https://github.com/rwightman/pytorch-image-models/blob/6f43aeb2526f9fc35cde5262df939ea23a18006c/timm/models/vision_transformer.py#L146), they use a conv layer to avoid the resolution-mismatch problem. Do you know which one is correct? Thanks!

Increase Performance

Hello @lucidrains ,
I use ViT for specific data. The image size is 320x320 and the number of classes is 2. I set the parameters for my dataset and it reached 64.5% test accuracy. Do you have any suggestions for the parameters? I get around 83% test accuracy with other models.

efficient_transformer = Linformer(
    dim = 256,
    seq_len = 1024 + 1,  # 32x32 patches + 1 cls token
    depth = 12,
    heads = 8,
    k = 64
)

model = ViT(
    dim = 256,
    image_size = 320,
    patch_size = 10,
    num_classes = 2,
    transformer = efficient_transformer,
    channels = 3,
).to(device)

cls_token

All samples in a batch share the same cls_token (in the code, the cls_token is repeated across the batch), so how do they become different during the backward pass? Since the cls_token is used as the classifier input, won't all samples in a batch be classified with the same label?

The model doesn't converge

Hi,

Thank you for your work.

I have tested training your implementation for action classification on Kinetics-400, but find that the training does not converge.
[training loss curve screenshot]
Note that the learning rate is calculated per the paper (4096 ~ 6e-4) using the linear scaling rule. I also applied warmup, but the loss plateaus while warming up.

I have also tested the pretrained model in timm; it also does not converge.
[training loss curve screenshot]

Do you have any suggestion for better training such stand-alone transformers? Thanks.

Attention maps

Hi! First, thanks for the great resource.
I was wondering how difficult it would be to reproduce the attention visualizations they show in Fig. 6 and Fig. 13 of the paper.
I am not very familiar with transformers. Is this similar to Grad-CAM or a different approach?

Why is the number of parameters inconsistent?

In Table 1 of the paper, the configuration of the ViT-Base model is provided. However, the parameter counts of the paper's model and this model are inconsistent: 86M (in the paper) vs. 91M (in this model).

Is it because the paper does not take the MLP into account?

That's my script:

import torch
from vit_pytorch import ViT

model = ViT(
    image_size = 224,
    patch_size = 16,
    num_classes = 1000,
    dim = 768,
    depth = 12,
    heads = 12,
    mlp_dim = 3072
)

n_parameters = sum(p.numel() for p in model.parameters() if p.requires_grad)
print('number of params:', n_parameters)

for name, p in model.named_parameters():
    print(name, p.numel())

and that's the output:

number of params: 91206376
pos_embedding 151296
cls_token 768
patch_to_embedding.weight 589824
patch_to_embedding.bias 768
transformer.layers.0.0.fn.norm.weight 768
transformer.layers.0.0.fn.norm.bias 768
transformer.layers.0.0.fn.fn.to_qkv.weight 1769472
transformer.layers.0.0.fn.fn.to_out.0.weight 589824
transformer.layers.0.0.fn.fn.to_out.0.bias 768
transformer.layers.0.1.fn.norm.weight 768
transformer.layers.0.1.fn.norm.bias 768
transformer.layers.0.1.fn.fn.net.0.weight 2359296
transformer.layers.0.1.fn.fn.net.0.bias 3072
transformer.layers.0.1.fn.fn.net.3.weight 2359296
transformer.layers.0.1.fn.fn.net.3.bias 768
transformer.layers.1.0.fn.norm.weight 768
...
transformer.layers.11.1.fn.fn.net.3.bias 768
mlp_head.0.weight 768
mlp_head.0.bias 768
mlp_head.1.weight 2359296
mlp_head.1.bias 3072
mlp_head.4.weight 3072000
mlp_head.4.bias 1000

Using masks as preprocessing for classification [FR]

Maybe it's a little bit too early to ask for this, but would it be possible to specify regions within an image for ViT to perform the prediction on? I was thinking of a binary mask, for example, which could be used in the tiling step to obtain different image sequences.

I am thinking of a pipeline where, in order to increase resolution, you could specify the regions to train on based on whatever criterion you find suitable (previous attention maps, for example 😄).

Train result on my own dataset. A Big gap between Train and Valid dataset

Hey guys.
First of all, this is a great job; thanks to the authors.
My question is the following.
Recently I used this code on my own dataset for a simple binary-classification problem.
The performance on the training dataset is good, but the validation dataset does not do as well.
The loss curve is:
[loss curve screenshot]

My model is
model = ViT(
    dim = 128,
    image_size = 224,
    patch_size = 32,
    num_classes = 2,
    depth = 12,
    heads = 8,
    mlp_dim = 512,
    channels = 3,
)

Training Dataset has 1200+ images, Validation Dataset has 300+ images.

Can someone give me some suggestions, how to solve this problem?

I think there are several possibilities. Maybe I need a pretrained model? Or did I do something wrong in training the transformer model?

Anyone tried to train this code with Imagenet from scratch ?

Thanks for the amazing work !!

I followed the hyperparameters described in the original paper, with the Adam optimizer, batch size = 4096, lr = 3e-3, weight_decay = 0.3, dropout = 0.1, but it seems that the regularization is too strong and the model cannot converge well.

git clone not working with jupyter notebook

Very happy to see transformers moving further into vision and big thanks for this repo!

I wanted to point out that git clone isn't working properly for this repo due to some issue with the new .ipynb notebook:

C:\Users\lessw>git clone https://github.com/lucidrains/vit-pytorch.git vit2
Cloning into 'vit2'...
remote: Enumerating objects: 94, done.
remote: Counting objects: 100% (94/94), done.
remote: Compressing objects: 100% (73/73), done.
remote: Total 94 (delta 53), reused 48 (delta 18), pack-reused 0
Unpacking objects: 100% (94/94), done.
error: unable to create file examples/VisualTransformer | Cats&Dogs Edition.ipynb: Invalid argument
fatal: unable to checkout working tree
warning: Clone succeeded, but checkout failed.
You can inspect what was checked out with 'git status'
and retry with 'git restore --source=HEAD :/'

I hit this yesterday as well and it further complicated any git pulls to sync to latest updates. (hence trying now with vit2)

How to set appropriate learning rate ?

vit = ViT(
    image_size = 448,
    patch_size = 32,
    num_classes = 180,
    dim = 1024,
    depth = 8,
    heads = 8,
    mlp_dim = 2048,
    dropout = 0.5,
    emb_dropout = 0.5
).cuda()

optimizer = torch.optim.Adam(vit.parameters(), lr=5e-3, weight_decay=0.1)
I tried to train ViT on a 180-class dataset with the configuration shown above, but the loss does not decrease during training.
Any suggestions?
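
A hedged sketch of a gentler optimization setup, reusing the vit defined above (all hyperparameters below are illustrative assumptions, not the repo's recommendation): a ViT trained from scratch usually wants a much lower peak learning rate than 5e-3, together with linear warmup and cosine decay.

import math
import torch

optimizer = torch.optim.AdamW(vit.parameters(), lr = 3e-4, weight_decay = 0.05)

warmup_steps, total_steps = 1000, 50000

def lr_lambda(step):
    if step < warmup_steps:
        return step / max(1, warmup_steps)                    # linear warmup
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return 0.5 * (1 + math.cos(math.pi * progress))           # cosine decay

scheduler = torch.optim.lr_scheduler.LambdaLR(optimizer, lr_lambda)

# in the training loop: loss.backward(); optimizer.step(); scheduler.step(); optimizer.zero_grad()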

Multi-label classification

My dataset contains 5 unique labels for instance:

1 0 NaN 1 0

Where 1 means it has that feature, 0 means it doesn't, and NaN means that we have no observation about that feature and it shouldn't be included in the loss.

Is it possible to make ViT multi-label for a dataset like this?
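
In principle, yes: treat it as five independent sigmoid outputs and mask out the NaN entries in the loss. A minimal sketch (assuming a ViT built with num_classes = 5, so the model emits one logit per label):

import torch
import torch.nn.functional as F

def masked_bce_loss(logits, labels):
    # labels: (batch, 5) with values 1., 0., or NaN for "no observation"
    mask = ~torch.isnan(labels)                      # True where a label exists
    targets = torch.nan_to_num(labels, nan = 0.0)    # placeholder values, masked out below
    loss = F.binary_cross_entropy_with_logits(logits, targets, reduction = 'none')
    return (loss * mask).sum() / mask.sum().clamp(min = 1)

logits = torch.randn(2, 5)
labels = torch.tensor([[1., 0., float('nan'), 1., 0.],
                       [0., float('nan'), 1., 1., float('nan')]])
print(masked_bce_loss(logits, labels))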

Isn't the softmax of the attention matrix calculated incorrectly?

The softmax is currently calculated as follows:

attn = dots.softmax(dim=-1)

The softmax is only applied over the last dimension instead of the whole matrix. Instead, I think it should be implemented like this:

attn = rearrange(dots, 'b h i j -> b h (i j)').softmax(dim=-1).view_as(dots)

Is this correct or am I missing anything?
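
For reference, softmax over the last dimension is the standard normalization in scaled dot-product attention: each query gets its own probability distribution over the keys. A quick check:

import torch

dots = torch.randn(2, 8, 65, 65)    # (batch, heads, queries, keys)
attn = dots.softmax(dim = -1)
print(attn.sum(dim = -1))           # all ones: one distribution per query row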

Eq. 4 in the paper.

Thank you for creating this repo! It's so beneficial for us!
I'm wondering about Eq.4( i.e. MLP Head) in their paper and your implementation.

In their paper, Eq.4 is written in:

[screenshot of Eq. 4 from the paper: y = LN(z_L^0)]

This equation is supposed to define the MLP head, but it is clearly not an MLP but a LayerNorm. Is this a typo? Or am I missing something?

And in your implementation of the MLP head, it seems to be just an FFN (except for the output shape). So, which part of the paper did you base it on (apparently not Eq. 4)?

In short, my questions are:

  1. Is Eq. 4 a typo? Or am I missing something?
  2. Which part of the paper did you base the MLP head on?

Again, thank you for such a quick work. Looking forward to your reply!

Regression

Hi, I am doing a regression task on images, predicting 6 numbers from each image. Does it make sense to use the CLS token, or can I just pool the last layer of the transformer and connect it to the MLP head?
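
Either choice can work; the pool argument (see the Parameters section) switches between CLS-token pooling and mean pooling. A hedged sketch of a 6-output regression setup, with illustrative hyperparameters:

import torch
import torch.nn.functional as F
from vit_pytorch import ViT

model = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 6,     # 6 regression targets instead of class logits
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048,
    pool = 'mean'        # or 'cls' to use the CLS token
)

img = torch.randn(8, 3, 256, 256)
target = torch.randn(8, 6)
loss = F.mse_loss(model(img), target)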

Masking Patches Question

Hi! Hope you're doing well and thanks for the great resource.

Many of my images have padding patches, so I am trying to use the mask input to have the model ignore them. When I use these masks, the output of my model becomes [nan, nan] (for a 2-class problem).

I believe this may be because in the masked_fill_ (line 59 in vit_pytorch.py), the patch indices I want to ignore (which I've set to False in the mask) get set to float('-inf') in dots. This then gets softmaxed, and when every entry in a row is -inf, the softmax returns nan.

As a fix, I've added this after line 62 in vit_pytorch.py: attn[torch.isnan(attn)] = float(0)

I'm relatively new to attention, and was wondering if this is a correct way to approach this.

Thanks!
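
A small demonstration of the failure mode described above (generic PyTorch, not repo-specific): a row of attention scores that is entirely -inf softmaxes to nan, and zero-filling those entries is one workaround.

import torch

dots = torch.full((1, 1, 2, 4), float('-inf'))   # every key masked out for these queries
attn = dots.softmax(dim = -1)
print(attn)                                      # all nan

attn = torch.nan_to_num(attn, nan = 0.0)         # same effect as attn[torch.isnan(attn)] = 0
print(attn)                                      # all zeros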

Image retrieval / similarity

Thanks for this great resource.
I'm planning to play with your cats and dogs example.
I am wondering whether this approach could be used for image retrieval / image similarity?
Have you looked into it?
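
One hedged sketch of how this might be done: swap the classification head for an identity so the model returns the pooled latent, then compare embeddings with cosine similarity. This assumes the model exposes an mlp_head attribute, as the parameter names printed elsewhere in these issues suggest.

import torch
import torch.nn as nn
import torch.nn.functional as F
from vit_pytorch import ViT

v = ViT(
    image_size = 256,
    patch_size = 32,
    num_classes = 1000,
    dim = 1024,
    depth = 6,
    heads = 16,
    mlp_dim = 2048
)
v.mlp_head = nn.Identity()   # the model now outputs the pooled embedding of size `dim`

with torch.no_grad():
    emb1 = v(torch.randn(1, 3, 256, 256))
    emb2 = v(torch.randn(1, 3, 256, 256))

print(F.cosine_similarity(emb1, emb2))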

fastai compatibility

Would it be possible to make the distillable ViT compatible with fastai? Both the vanilla ViT and the efficient ViT work fine.

Pair images prediction

Hi,
Is it possible to input two images (a pair), like Image1<sep>Image2, into the network?

Pre-training weights

Will you provide pre-trained weights? I have very little data, so pre-trained weights may help me a lot.

Loss doesn't drop when training on ImageNet

Hi, great thanks for sharing the code!
I found that the loss stays stable at around 7 when I train it on ImageNet on a single 3090. Have you tried it on ImageNet successfully with vit-pytorch?

These are the hyperparameters I have:

batch_size = 192
image_size = 256
patch_size = 16
num_layers = 8
num_head = 8
mlp_dim = 512
dim_model = 512
num_class = 1000
channel = 3
dropout = 0.4
learning_rate = 3e-4
beta1 = 0.9
beta2 = 0.999
weight_decay = 0.01
epoches = 20
num_workers = 4

Here are some logs from training:

2020-12-01 20:12:48,410 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6519/6672 - Iter loss : 27.8335 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:50,900 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6520/6672 - Iter loss : 33.7667 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:52,467 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6521/6672 - Iter loss : 16.5766 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:53,450 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6522/6672 - Iter loss : 9.5950 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:54,919 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6523/6672 - Iter loss : 12.1596 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:57,193 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6524/6672 - Iter loss : 8.5739 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:12:58,304 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6525/6672 - Iter loss : 8.2490 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:12:59,285 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6526/6672 - Iter loss : 6.9780 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:05,256 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6527/6672 - Iter loss : 6.9443 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:07,644 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6528/6672 - Iter loss : 7.0152 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:09,009 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6529/6672 - Iter loss : 7.0805 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:09,993 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6530/6672 - Iter loss : 7.1097 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:13,525 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6531/6672 - Iter loss : 7.1950 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:20,461 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6532/6672 - Iter loss : 10.4532 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:23,267 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6533/6672 - Iter loss : 11.6160 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:24,247 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6534/6672 - Iter loss : 14.5333 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:25,253 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6535/6672 - Iter loss : 49.1107 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:27,575 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6536/6672 - Iter loss : 36.7893 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:29,475 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6537/6672 - Iter loss : 22.1718 - Iter acc: 0.0104 - Num correct: 2
2020-12-01 20:13:30,457 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6538/6672 - Iter loss : 7.3235 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:31,441 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6539/6672 - Iter loss : 7.0744 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:13:38,798 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6540/6672 - Iter loss : 8.2893 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:45,064 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6541/6672 - Iter loss : 7.2673 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:46,039 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6542/6672 - Iter loss : 8.7027 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:49,190 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6543/6672 - Iter loss : 7.0044 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:13:59,167 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6544/6672 - Iter loss : 6.9146 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:05,451 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6545/6672 - Iter loss : 7.1504 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:06,431 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6546/6672 - Iter loss : 6.9319 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:07,437 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6547/6672 - Iter loss : 6.9291 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:09,486 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6548/6672 - Iter loss : 6.8974 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:13,987 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6549/6672 - Iter loss : 6.9187 - Iter acc: 0.0104 - Num correct: 2
2020-12-01 20:14:14,967 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6550/6672 - Iter loss : 7.0641 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:15,952 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6551/6672 - Iter loss : 7.0853 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:18,097 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6552/6672 - Iter loss : 7.2872 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:23,045 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6553/6672 - Iter loss : 7.0384 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:24,026 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6554/6672 - Iter loss : 6.9754 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:25,033 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6555/6672 - Iter loss : 7.0169 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:26,194 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6556/6672 - Iter loss : 6.9243 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:29,584 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6557/6672 - Iter loss : 6.9077 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:30,563 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6558/6672 - Iter loss : 6.9486 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:31,734 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6559/6672 - Iter loss : 6.9292 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:14:41,763 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6560/6672 - Iter loss : 6.9972 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:47,781 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6561/6672 - Iter loss : 7.0302 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:48,764 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6562/6672 - Iter loss : 6.9987 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:50,913 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6563/6672 - Iter loss : 6.9165 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:14:57,334 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6564/6672 - Iter loss : 6.9228 - Iter acc: 0.0052 - Num correct: 1
2020-12-01 20:15:03,513 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6565/6672 - Iter loss : 7.0496 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:04,494 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6566/6672 - Iter loss : 6.9134 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:05,499 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6567/6672 - Iter loss : 7.1323 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:06,459 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6568/6672 - Iter loss : 6.9309 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:10,015 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6569/6672 - Iter loss : 6.9452 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:10,995 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6570/6672 - Iter loss : 6.9216 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:11,978 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6571/6672 - Iter loss : 7.0044 - Iter acc: 0.0000 - Num correct: 0
2020-12-01 20:15:12,938 - train.py[line:103] - INFO: Epoch : 1/20 - Iter : 6572/6672 - Iter loss : 6.9816 - Iter acc: 0.0000 - Num correct: 0

Issue in one line, urgent, please resolve

train_data = CatsDogsDataset(train_list, transform=train_transforms)
valid_data = CatsDogsDataset(valid_list, transform=test_transforms)  # this one
test_data = CatsDogsDataset(test_list, transform=test_transforms)

In this line, shouldn't it be transform=val_transforms for valid_data? Am I correct?

Code for image classification in PyTorch

Hi!! Awesome work!!
The code you used (vit_pytorch for classifying cats and dogs) is for image classification. I referred to the code; it runs the model on the train and validation sets. If we have a separate test set, how do we run the model on it and generate a confusion matrix? Also, there is no provision for checkpoints and callbacks.

Also, in my case I am classifying cancer as benign or malignant, and when I run code similar to yours on the validation set, it shows all predicted labels and targets as the SAME. Please help!
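
A hedged sketch of running a trained model on a separate test set and producing a confusion matrix; model, test_data, and device are assumed to follow the cats-and-dogs example, and scikit-learn is used only for the matrix itself.

import torch
from torch.utils.data import DataLoader
from sklearn.metrics import confusion_matrix

test_loader = DataLoader(test_data, batch_size = 64, shuffle = False)

model.eval()
all_preds, all_targets = [], []
with torch.no_grad():
    for imgs, labels in test_loader:
        logits = model(imgs.to(device))
        all_preds += logits.argmax(dim = 1).cpu().tolist()
        all_targets += labels.tolist()

print(confusion_matrix(all_targets, all_preds))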

Finetune pretrained model for classfication

To fine-tune the already pre-trained self-supervised model, do I use the snippet below?

import torch
import torch.nn as nn

checkpoint = torch.load('./pretrained-net.pt')
model.load_state_dict(checkpoint)  # load_state_dict modifies `model` in place
model.output = nn.Linear(2048, 1000)  # Do I need to add this, since the output layer was already given in the ViT model?

Is there pre-trained weight?

Hi,
ViT seems to perform better when pre-trained on a large dataset.
Will pre-trained weights be provided in the future?

Thanks.

Train result on cifar10 classification dataset without pretrained weight

Just sharing a result.

Transformer structure

net = ViT(
    image_size = 32,
    patch_size = 16,
    num_classes = 10,
    dim = 1024,
    depth = 12,
    heads = 8,
    mlp_dim = 2048
)

Result :
Epoch: 197
[============================ 391/391 ===========================>] Step: 90ms | Tot: 44s860ms | Loss: 1.394 | Acc: 49.418% (24709/50000)
[============================ 100/100 ===========================>] Step: 22ms | Tot: 2s680ms | Loss: 1.387 | Acc: 51.130% (5113/10000)

Epoch: 198
[============================ 391/391 ===========================>] Step: 91ms | Tot: 44s277ms | Loss: 1.398 | Acc: 49.366% (24683/50000)
[============================ 100/100 ===========================>] Step: 29ms | Tot: 2s703ms | Loss: 1.349 | Acc: 51.000% (5100/10000)

Epoch: 199
[============================ 391/391 ===========================>] Step: 91ms | Tot: 44s584ms | Loss: 1.396 | Acc: 49.508% (24754/50000)
[============================ 100/100 ===========================>] Step: 29ms | Tot: 2s735ms | Loss: 1.410 | Acc: 49.110% (4911/10000)

Use on 1-Dimensional data

Thank you very much for your code.
I was wondering if we could use the model on data of shape 1x512x2 (HxWxC).
What changes would we need to make in the code to make it compatible with such data dimensions?

Model doesn't converge

We are trying to apply this method to a medical dataset of about 70K images (224 resolution) across 5 classes. However, our training does not converge; we tried a range of learning rates (e.g. 3e-3, 3e-4, etc.) but it does not seem to work. Currently our model reaches 45% accuracy, whereas the average accuracy for this dataset is around 85-90% (we trained for 100 epochs). Is there anything else we should tune?

Also, here is our configuration:

batch_size = 64
epochs = 400
lr = 3e-4
gamma = 0.7
seed = 42

efficient_transformer = Linformer(
    dim=128,
    seq_len=49 + 1,  # 7x7 patches + 1 cls-token
    depth=4,
    heads=8,
    k=64
)

# Visual Transformer

model = ViT(
    dim=128,
    image_size=224,
    patch_size=32,
    num_classes=5,
    transformer=efficient_transformer,  # nn.Transformer(d_model=128, nhead=8),
    channels=1,
).to(device)

Thank you very much!

Problem with ResNet

Hi! I would like to train ViT using the distiller on a dataset of grayscale images, but I am having problems with the ResNet, since it expects inputs with 3 channels and my images have only 1. Do you have any suggestions? Thanks!

This is the error:

RuntimeError: Given groups=1, weight of size [64, 3, 7, 7], expected input[64, 1, 224, 224] to have 3 channels, but got 1 channels instead
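
Two common workarounds (a generic sketch, not repo-specific): repeat the single channel to three before the ResNet, or have the dataset transform emit 3-channel images.

import torch
from torchvision import transforms

x = torch.randn(64, 1, 224, 224)    # grayscale batch
x3 = x.repeat(1, 3, 1, 1)           # (64, 3, 224, 224), satisfies the 3-channel ResNet stem

# or, inside the dataset pipeline (applied to PIL images):
to_three_channels = transforms.Grayscale(num_output_channels = 3)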

Purpose of nn.Identity() in the ViT class

Hey!

Awesome project :) Not really an issue, more of a question. Referring to this snippet:

self.to_cls_token = nn.Identity()

...

x = self.transformer(x, mask)
x = self.to_cls_token(x[:, 0])
return self.mlp_head(x)

Can I please know what the purpose of the nn.Identity() in the ViT class is? As an identity transformation, it doesn't really change the incoming x value from the Transformer(), right?

What effect does it have on the inputs passed into the MLP head?

Any help appreciated!

Thanks :D
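
For reference, a quick check (plain PyTorch, not specific to this repo) that nn.Identity() returns its input unchanged; its value is as a named placeholder that can later be swapped for another module, e.g. a projection head, without touching the forward pass.

import torch
import torch.nn as nn

to_cls_token = nn.Identity()
x = torch.randn(1, 1024)
assert torch.equal(to_cls_token(x), x)   # values pass through untouched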
