A PyTorch implementation of a DDPM demo and a classifier-free DDPM demo. Experiments were conducted on the MNIST and CIFAR-10 datasets.
- Red Hat Enterprise Linux Server release 7.6 (Maipo)
- pytorch 1.6.0
- torchvision 0.7.0
- python 3.7.16
Download the datasets and .pth files from DDPM-Demo file. The project is laid out as follows:
├── classifier_free_ddpm
│ ├── models
│ └── photos
│ └── ...
├── datasets
├── ddpm
│ ├── models
│ └── photos
│ └── ...
└── README.md
Train on MNIST:
python train.py -b 64 -d 0 -e 20 -t 500
Train on CIFAR-10:
python train.py -b 64 -d 1 -e 100 -t 1000
Parameter meaning:
| abbreviation | full name | meaning |
|---|---|---|
| -b | --batch_size | batch size |
| -d | --datasets_type | dataset type (0: MNIST, 1: CIFAR-10) |
| -e | --epochs | number of training epochs |
| -t | --timesteps | number of diffusion timesteps |
| -dp | --datasets_path | path to the datasets |
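
For readers who want to see how flags like these are typically wired up, here is a minimal sketch using Python's standard `argparse` module. It is not taken from `train.py`: the function name `build_parser`, the default values, and the help strings are assumptions made for illustration; only the flag names and their meanings come from the table above.

```python
import argparse


def build_parser():
    """Build a parser mirroring the flags in the table (hypothetical sketch)."""
    parser = argparse.ArgumentParser(description="DDPM demo training options")
    # All defaults below are assumptions, not values confirmed by the repository.
    parser.add_argument("-b", "--batch_size", type=int, default=64,
                        help="batch size")
    parser.add_argument("-d", "--datasets_type", type=int, choices=[0, 1],
                        default=0, help="dataset type, 0: MNIST, 1: CIFAR-10")
    parser.add_argument("-e", "--epochs", type=int, default=20,
                        help="number of training epochs")
    parser.add_argument("-t", "--timesteps", type=int, default=500,
                        help="number of diffusion timesteps")
    parser.add_argument("-dp", "--datasets_path", type=str, default="datasets",
                        help="path to the datasets (assumed default)")
    return parser


# Parse the same arguments as the CIFAR-10 training command.
args = build_parser().parse_args(["-b", "64", "-d", "1", "-e", "100", "-t", "1000"])
print(args.batch_size, args.datasets_type, args.epochs, args.timesteps)
```

Restricting `--datasets_type` with `choices` makes an unsupported dataset index fail at parse time rather than later in training.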
Run inference on MNIST:
python inference.py -b 64 -d 0 -t 500 -p models/mnist-500-20-0.0005.pth
Run inference on CIFAR-10:
python inference.py -b 64 -d 1 -t 1000 -p models/cifar10-1000-100-0.0002.pth
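
The checkpoint names above appear to encode the training settings as `<dataset>-<timesteps>-<epochs>-<value>.pth`: the second and third fields match the `-t` and `-e` values of the training commands. The last field (0.0005, 0.0002) is assumed here to be the learning rate, which the README does not confirm. Under that assumption, a small hypothetical helper, `parse_checkpoint_name`, can recover the settings from a file name so the `-t` passed to `inference.py` can be checked against the checkpoint:

```python
from pathlib import Path


def parse_checkpoint_name(path):
    """Split a checkpoint file name into its assumed components.

    Assumed pattern: <dataset>-<timesteps>-<epochs>-<lr>.pth
    The meaning of the last field (learning rate) is a guess.
    """
    stem = Path(path).stem  # drops the directory and the ".pth" suffix
    # rsplit keeps any hyphens inside the dataset name intact.
    dataset, timesteps, epochs, lr = stem.rsplit("-", 3)
    return {
        "dataset": dataset,
        "timesteps": int(timesteps),
        "epochs": int(epochs),
        "lr": float(lr),
    }


info = parse_checkpoint_name("models/mnist-500-20-0.0005.pth")
print(info)
```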