
Implement Vision Transformers for Image Classification using the MNIST Dataset


JohanesSetiawan/ViT-MNIST


Vision Transformers (ViT) | MNIST Digit Dataset

Vision Transformer (ViT) is an image-classification architecture that adapts the Transformer, originally designed for sequences of text tokens, to images: it splits each image into fixed-size patches, embeds every patch as a token, and processes the resulting token sequence with a Transformer encoder. This notebook demonstrates how to use a Vision Transformer to classify the MNIST digit dataset, a collection of 28x28 grayscale images of handwritten digits (0-9) with 60,000 training images and 10,000 test images. The goal is to train a ViT model that assigns each image to its digit class.
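To make the patch step concrete, here is a minimal pure-Python sketch that cuts one 28x28 image into non-overlapping square patches. The patch size of 7 is an assumption for illustration (the patch size used by this repository is not stated here), and `patchify` is a hypothetical helper, not a function from the repository. With a patch size of 7, each image becomes 16 patches of 49 pixels.

```python
def patchify(image, patch_size):
    """Split an H x W image (list of pixel rows) into non-overlapping square patches.

    Patches are returned in row-major order (left to right, top to bottom),
    and each patch is flattened row by row into a single list of pixels.
    """
    height, width = len(image), len(image[0])
    if height % patch_size or width % patch_size:
        raise ValueError("image height and width must be divisible by patch_size")
    patches = []
    for top in range(0, height, patch_size):
        for left in range(0, width, patch_size):
            patches.append([
                image[r][c]
                for r in range(top, top + patch_size)
                for c in range(left, left + patch_size)
            ])
    return patches


# A dummy 28x28 "image" whose pixel value encodes its position (row * 28 + col).
image = [[r * 28 + c for c in range(28)] for r in range(28)]
patches = patchify(image, 7)  # patch size 7 is an assumption, not taken from the repo
print(len(patches), len(patches[0]))  # 16 49
```

In a full ViT, each flattened patch is then projected by a learned linear layer to a HIDDEN_DIM-sized token, position embeddings are added, a learnable class token is usually prepended (17 tokens here), and the sequence passes through NUM_ENCODERS Transformer encoder blocks.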

Model list

name_model          | BATCH_SIZE | NUM_HEADS | HIDDEN_DIM | NUM_ENCODERS | EPOCHS
trained_model_2.pth | 16         | 16        | 1024       | 16           | 150
trained_model_1.pth | 32         | 16        | 768        | 8            | 100
trained_model_0.pth | 512        | 8         | 768        | 4            | 60
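The table can also be read as a small configuration lookup. The sketch below copies the rows into a dictionary and derives two quantities from them: the per-head dimension (multi-head attention, including PyTorch's `nn.MultiheadAttention`, requires HIDDEN_DIM to be divisible by NUM_HEADS) and the number of optimizer steps. The step counts assume that all 60,000 training images are used every epoch and that the last partial batch is kept; if the training code holds out a validation split or drops the last batch, the real numbers are smaller. `MODELS` and `summarize` are hypothetical names, not part of the repository.

```python
import math

# Rows of the models table (values copied from the README).
MODELS = {
    "trained_model_2.pth": {"BATCH_SIZE": 16, "NUM_HEADS": 16, "HIDDEN_DIM": 1024, "NUM_ENCODERS": 16, "EPOCHS": 150},
    "trained_model_1.pth": {"BATCH_SIZE": 32, "NUM_HEADS": 16, "HIDDEN_DIM": 768, "NUM_ENCODERS": 8, "EPOCHS": 100},
    "trained_model_0.pth": {"BATCH_SIZE": 512, "NUM_HEADS": 8, "HIDDEN_DIM": 768, "NUM_ENCODERS": 4, "EPOCHS": 60},
}

# Assumption: every MNIST training image is used each epoch and the last partial batch is kept.
TRAIN_IMAGES = 60_000


def summarize(name):
    """Derive per-head width and optimizer-step counts for one model in the table."""
    cfg = MODELS[name]
    if cfg["HIDDEN_DIM"] % cfg["NUM_HEADS"] != 0:
        raise ValueError(f"{name}: HIDDEN_DIM must be divisible by NUM_HEADS")
    steps_per_epoch = math.ceil(TRAIN_IMAGES / cfg["BATCH_SIZE"])
    return {
        "head_dim": cfg["HIDDEN_DIM"] // cfg["NUM_HEADS"],
        "steps_per_epoch": steps_per_epoch,
        "total_steps": steps_per_epoch * cfg["EPOCHS"],
    }


for name in MODELS:
    print(name, summarize(name))
# trained_model_2.pth {'head_dim': 64, 'steps_per_epoch': 3750, 'total_steps': 562500}
# trained_model_1.pth {'head_dim': 48, 'steps_per_epoch': 1875, 'total_steps': 187500}
# trained_model_0.pth {'head_dim': 96, 'steps_per_epoch': 118, 'total_steps': 7080}
```

Under these assumptions, the large batch size of trained_model_0.pth means it takes far fewer optimizer steps than the other two models, even per epoch.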

The training and validation curves (loss and accuracy) are available here, and the trained models can be downloaded from here.

Installation

  • Linux:
pip install torch torchvision torchaudio pillow gradio

  • Windows (CUDA 11.8 | PyTorch 2.2.0):
pip3 install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu118

Then install the remaining packages:

pip install pillow gradio

  • Windows (CUDA 12.1 | PyTorch 2.2.0):
pip3 install torch==2.2.0 torchvision==0.17.0 torchaudio==2.2.0 --index-url https://download.pytorch.org/whl/cu121

Then install the remaining packages:

pip install pillow gradio
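After installing, a short check script can confirm that the packages are importable and whether PyTorch can see a GPU. This is a minimal sketch; `check_environment` is a hypothetical helper, not part of the repository. Pillow is imported under the name `PIL`. `torch.version.cuda` reports the CUDA version the installed wheel was built for (None for CPU-only builds), while `torch.cuda.is_available()` reports whether a usable GPU is found at runtime.

```python
import importlib
import importlib.util

# Import names of the packages installed above (Pillow is imported as PIL).
REQUIRED = ["torch", "torchvision", "torchaudio", "PIL", "gradio"]


def check_environment():
    """Report which packages are importable and, if PyTorch is present, its CUDA status."""
    # find_spec locates a top-level package without importing it.
    report = {name: importlib.util.find_spec(name) is not None for name in REQUIRED}
    if report["torch"]:
        torch = importlib.import_module("torch")
        report["torch_version"] = torch.__version__
        report["cuda_build"] = torch.version.cuda  # None for CPU-only wheels
        report["cuda_available"] = torch.cuda.is_available()
    return report


if __name__ == "__main__":
    for key, value in check_environment().items():
        print(f"{key}: {value}")
```

With one of the CUDA-enabled Windows commands and a compatible NVIDIA driver, `cuda_build` should show a CUDA version and `cuda_available` should be True; otherwise the model will run on the CPU if the app allows it.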

Usage

  1. Download a trained model.
  2. Place the model file in the models directory.
  3. Prepare MNIST test images to test the model with.
  4. Open utils/params.py and set the parameters (batch size, number of heads, hidden dimension, number of encoders, and epochs) to the values in the table above for the model you downloaded, then set PATH_MODELS to the model file. For example:
    # Example values for trained_model_1.pth; use the table row of the model you downloaded
    self.BATCH_SIZE = 32
    self.NUM_HEADS = 16
    self.HIDDEN_DIM = 768
    self.NUM_ENCODERS = 8
    self.PATH_MODELS = "models/trained_model_1.pth"  # path to the downloaded .pth file
  5. Run the app:
python app.py

Then open http://localhost:7860/ in your browser.
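For step 3, MNIST test images can be taken directly from the official test-set file, which uses the IDX format: a big-endian header holding the magic number 2051 (0x00000803, meaning unsigned-byte data with three dimensions), followed by the image count, rows and columns as 32-bit integers, and then the raw pixels. The sketch below parses that format in pure Python and can save a single digit as a PNG with Pillow. The file name `t10k-images-idx3-ubyte.gz` and the helper names `read_idx_images` and `save_png` are assumptions for illustration; check which image format app.py accepts before relying on PNG output.

```python
import struct

IDX3_UBYTE_MAGIC = 2051  # 0x00000803: unsigned-byte data, 3 dimensions


def read_idx_images(data, limit=None):
    """Parse IDX3 image bytes (e.g. the MNIST test set) into 2D lists of 0-255 pixels."""
    if len(data) < 16:
        raise ValueError("data is too short to hold an IDX3 header")
    magic, count, rows, cols = struct.unpack(">IIII", data[:16])
    if magic != IDX3_UBYTE_MAGIC:
        raise ValueError(f"not an IDX3 unsigned-byte file (magic={magic})")
    size = rows * cols
    pixels = data[16:]
    if len(pixels) < count * size:
        raise ValueError("pixel data is truncated")
    if limit is not None:
        count = min(count, limit)  # only decode the first `limit` images
    images = []
    for i in range(count):
        flat = pixels[i * size:(i + 1) * size]
        images.append([list(flat[r * cols:(r + 1) * cols]) for r in range(rows)])
    return images


def save_png(image, path):
    """Save one grayscale image (list of rows of 0-255 ints) as a PNG using Pillow."""
    from PIL import Image  # imported here so parsing works without Pillow installed

    height, width = len(image), len(image[0])
    img = Image.new("L", (width, height))
    img.putdata([p for row in image for p in row])
    img.save(path)


# Example usage (file name is an assumption):
# import gzip
# with gzip.open("t10k-images-idx3-ubyte.gz", "rb") as f:
#     test_images = read_idx_images(f.read(), limit=10)
# save_png(test_images[0], "digit_0.png")
```

MNIST digits are white strokes on a black background, so images saved this way keep that convention.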

References

YouTube | Implement and Train ViT From Scratch for Image Recognition - PyTorch

GitHub | Original Code
