PyTorch Implementation of the Vision Transformer (ViT) from Scratch

An easy and minimal implementation of the Vision Transformer (ViT) in PyTorch, from scratch. This repo reimplements the paper:

"An Image is Worth 16x16 Words: Transformers for Image Recognition at Scale", Dosovitskiy et al., 2020.

arXiv: https://arxiv.org/abs/2010.11929

If you use the code in this repo and find the project useful, please consider giving it a star ⭐!
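
The paper's title refers to the core idea: an image is split into fixed-size patches, each patch is flattened into a vector, and the resulting sequence is fed to a standard Transformer as tokens. A minimal sketch of that patch extraction (generic PyTorch, independent of this repo's code):

import torch

x = torch.randn(1, 3, 32, 32)                # [B, C, H, W]
p = 4                                        # patch size
B, C, H, W = x.shape

# Cut the image into non-overlapping p x p patches, then flatten each patch
patches = x.unfold(2, p, p).unfold(3, p, p)  # [B, C, H/p, W/p, p, p]
patches = patches.permute(0, 2, 3, 1, 4, 5)  # [B, H/p, W/p, C, p, p]
patches = patches.reshape(B, (H // p) * (W // p), C * p * p)

print(patches.shape)  # torch.Size([1, 64, 48]): 64 tokens of dimension 48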


Usage

# Imports
import torch
from models.vit import ViT

# Create the model
vit = ViT(
    patch_size=4,         # side of each square image patch (4x4 pixels)
    num_layers=2,         # number of Transformer encoder blocks
    h_dim=256,            # hidden (token embedding) dimension
    num_heads=8,          # attention heads per block
    num_classes=10,       # number of output classes
    d_ff=2048,            # inner dimension of the feed-forward MLP
    max_time_steps=1000,  # maximum number of input tokens (positions)
    use_clf_token=True,   # prepend a learnable classification token
)

# Inference
vit.eval()
x = torch.randn(1, 3, 32, 32)  # [B, C, H, W]
with torch.no_grad():
    logits = vit(x)            # [B, N_CL]
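
The logits can be turned into probabilities and a hard prediction with standard PyTorch ops (a generic sketch, not an API of this repo):

probs = logits.softmax(dim=-1)  # [B, N_CL], rows sum to 1
pred = probs.argmax(dim=-1)     # [B], index of the predicted class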

Model Configurations

From the paper (Table 1):

Model      Layers   Hidden Size   MLP Size   Heads   Params
ViT-Base       12           768       3072      12     86 M
ViT-Large      24          1024       4096      16    307 M
ViT-Huge       32          1280       5120      16    632 M
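
The columns map onto the constructor arguments from the Usage example: Layers → num_layers, Hidden Size → h_dim, MLP Size → d_ff, Heads → num_heads. A hedged sketch of ViT-Base built that way (patch_size, num_classes, max_time_steps, and use_clf_token are assumptions carried over from the Usage example, not values from the table):

# Sketch: ViT-Base from the table above, using the constructor shown in Usage
vit_base = ViT(
    patch_size=16,        # the paper's 16x16 patches
    num_layers=12,        # "Layers" column
    h_dim=768,            # "Hidden Size" column
    num_heads=12,         # "Heads" column
    num_classes=1000,     # assumption, e.g. ImageNet-1k
    d_ff=3072,            # "MLP Size" column
    max_time_steps=1000,  # assumption, carried over from the Usage example
    use_clf_token=True,   # assumption, carried over from the Usage example
)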

Train

$ python main.py \
    --mode "train" \
    --model "vit-base" \
    --patch_size 8 \
    --lr 3e-4 \
    --epochs 100

Test

$ python main.py \
    --mode "test" \
    --model "vit-base" \
    --patch_size 8 \
    --model_checkpoint "./checkpoints/vit_base.ckpt"
