
Train Basic Model on CIFAR10 Dataset - 🎨🖥️ Utilizes CIFAR-10 dataset with 60000 32x32 color images in 10 classes. Demonstrates loading using torchvision and training with pretrained models like ResNet18, AlexNet, VGG16, DenseNet161, and Inception. Notebook available for experimentation.


deBUGger404/cifar-pytorch_model


Train Basic Model on CIFAR10-Dataset


Introduction

The CIFAR-10 dataset consists of 60000 32x32 colour images in 10 classes, with 6000 images per class. There are 50000 training images and 10000 test images.

Below are 6 random images with their respective labels:
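A figure like this can be reproduced with a short helper. The sketch below is an assumption, not code from this repository: `sample_images` is a hypothetical name, and it relies on `dataset[i]` returning `(image_tensor, label_index)` and on `dataset.classes` holding the class names, which is how `torchvision.datasets.CIFAR10` behaves when built with `transforms.ToTensor()`.

```python
import random

import torch


def sample_images(dataset, n=6, seed=0):
    """Return n randomly chosen images stacked into one tensor, plus their class names.

    Hypothetical helper. Assumes dataset[i] returns (image_tensor, label_index)
    and dataset.classes maps label indices to names, as torchvision's CIFAR10
    does when constructed with transform=transforms.ToTensor().
    """
    rng = random.Random(seed)  # seeded so the same images come back on every run
    indices = rng.sample(range(len(dataset)), n)
    images, names = [], []
    for i in indices:
        image, label = dataset[i]
        images.append(image)
        names.append(dataset.classes[label])
    return torch.stack(images), names


# Example usage (downloads CIFAR10 on the first run):
#   import torchvision
#   import torchvision.transforms as transforms
#   ds = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True,
#                                     transform=transforms.ToTensor())
#   images, names = sample_images(ds)
#   torchvision.utils.save_image(images, 'samples.png', nrow=len(names))  # one row of 6
#   print(names)
```

Writing the grid with `torchvision.utils.save_image` avoids an extra plotting dependency; in a notebook, matplotlib's `imshow` on `image.permute(1, 2, 0)` works just as well.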

The Python package torchvision provides a ready-made CIFAR10 dataset class and image transforms (torchvision.transforms); the dataset can then be batched and shuffled with torch.utils.data.DataLoader.
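A common transform for CIFAR10 is per-channel normalization using the training set's mean and standard deviation. The sketch below computes those statistics itself rather than hard-coding values; `channel_mean_std` is a hypothetical helper, it assumes 3-channel (RGB) image tensors of shape `(3, H, W)`, and it returns the population (not Bessel-corrected) standard deviation.

```python
import torch
from torch.utils.data import DataLoader


def channel_mean_std(dataset, batch_size=1000):
    """Per-channel mean and population std of a dataset of (3, H, W) image tensors.

    Running sums of x and x**2 are accumulated in float64, so the whole
    dataset never has to be held in memory at once.
    """
    loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
    total = torch.zeros(3, dtype=torch.float64)
    total_sq = torch.zeros(3, dtype=torch.float64)
    count = 0  # number of pixels seen per channel
    for images, _ in loader:
        images = images.to(torch.float64)
        total += images.sum(dim=(0, 2, 3))
        total_sq += (images ** 2).sum(dim=(0, 2, 3))
        count += images.shape[0] * images.shape[2] * images.shape[3]
    mean = total / count
    std = (total_sq / count - mean ** 2).sqrt()  # Var[x] = E[x^2] - E[x]^2
    return mean.float(), std.float()


# Example usage with a CIFAR10 training set built with transform=transforms.ToTensor():
#   mean, std = channel_mean_std(dataset)
#   transform = transforms.Compose([transforms.ToTensor(),
#                                   transforms.Normalize(mean.tolist(), std.tolist())])
```

The statistics should come from the training split only and then be reused unchanged for the test split.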

Below is an example of how to load the CIFAR10 dataset using torchvision:

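import torch
import torchvision
import torchvision.transforms as transforms

## load data CIFAR10 (ToTensor converts the PIL images to tensors so the DataLoader can batch them)
transform = transforms.ToTensor()
train_dataset = torchvision.datasets.CIFAR10(root='./train_data', train=True, download=True, transform=transform)
train_dataloader = torch.utils.data.DataLoader(train_dataset, batch_size=128, shuffle=True, num_workers=2)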
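The loader above uses all 50000 training images. If you want a held-out validation set for choosing values such as the learning rate or number of epochs, one option (an assumption, not something this repository describes) is to split the training set with `torch.utils.data.random_split`. `split_train_val` is a hypothetical helper, and it assumes a PyTorch release recent enough for `random_split` to accept a `generator` argument.

```python
import torch
from torch.utils.data import random_split


def split_train_val(dataset, val_fraction=0.1, seed=0):
    """Split a dataset into reproducible (train, validation) Subsets.

    Hypothetical helper: a fixed-seed generator makes the split identical
    across runs, so validation scores stay comparable between experiments.
    """
    n_val = round(len(dataset) * val_fraction)
    n_train = len(dataset) - n_val
    generator = torch.Generator().manual_seed(seed)
    return random_split(dataset, [n_train, n_val], generator=generator)


# Example usage with the CIFAR10 training set loaded above:
#   train_subset, val_subset = split_train_val(train_dataset)  # 45000 / 5000 images
#   train_dataloader = torch.utils.data.DataLoader(train_subset, batch_size=128, shuffle=True)
```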

Prerequisites

  • Python>=3.6
  • PyTorch>=1.4
  • Other required libraries are listed in requirements.txt

Training

I used a pretrained ResNet18 for model training. You can use any other pretrained model that suits your problem.

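import torchvision.models as models

# These calls build randomly initialised networks. To load ImageNet-pretrained
# weights, pass weights="DEFAULT" (torchvision >= 0.13) or pretrained=True (older releases).
resnet18 = models.resnet18()
alexnet = models.alexnet()
vgg16 = models.vgg16()
densenet = models.densenet161()
inception = models.inception_v3()  # expects 299x299 inputs, so resize the 32x32 CIFAR10 images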

There are two ways to train the model:

  1. Notebook: download it and experiment with it
  2. Python scripts:
    # Start training with:
    python main.py

    # Optionally pass the training hyperparameters explicitly:
    python main.py --lr=0.01 --epoch 20 --model_path './cifar_model.pth'

    # Run inference with:
    python prediction.py --model_path './cifar_model.pth'
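The source of main.py is not reproduced in this README, so the following is only a hypothetical sketch of what a training run configured by `--lr`, `--epoch` and `--model_path` could look like: cross-entropy loss, SGD with momentum 0.9 (an assumption), and the weights saved as a `state_dict`. The function names are illustrative, not the script's actual API.

```python
import torch
import torch.nn as nn


def train_one_epoch(model, loader, optimizer, device="cpu"):
    """One pass over loader with cross-entropy loss; returns the mean loss per sample."""
    model.train()
    criterion = nn.CrossEntropyLoss()
    total_loss, total_samples = 0.0, 0
    for images, labels in loader:
        images, labels = images.to(device), labels.to(device)
        optimizer.zero_grad()
        loss = criterion(model(images), labels)
        loss.backward()
        optimizer.step()
        total_loss += loss.item() * labels.size(0)
        total_samples += labels.size(0)
    return total_loss / total_samples


@torch.no_grad()
def accuracy(model, loader, device="cpu"):
    """Fraction of samples whose arg-max prediction matches the label."""
    model.eval()
    correct, total = 0, 0
    for images, labels in loader:
        preds = model(images.to(device)).argmax(dim=1)
        correct += (preds == labels.to(device)).sum().item()
        total += labels.size(0)
    return correct / total


def fit(model, train_loader, lr=0.01, epochs=20, model_path="./cifar_model.pth", device="cpu"):
    """Train with SGD (momentum 0.9 is an assumption), then save the weights."""
    model.to(device)
    optimizer = torch.optim.SGD(model.parameters(), lr=lr, momentum=0.9)
    for epoch in range(epochs):
        loss = train_one_epoch(model, train_loader, optimizer, device)
        print(f"epoch {epoch + 1}: loss {loss:.4f}")
    # Saving only the state_dict; reload later with model.load_state_dict(torch.load(model_path))
    torch.save(model.state_dict(), model_path)
    return model
```

For inference, rebuild the same architecture, load the saved `state_dict`, and call `accuracy` on a loader over the 10000 test images (`train=False`).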
    

Give a ⭐ to this Repository!
