YHRen/DDPM_tutorial

Bare-bones implementation of DDPM

A bare-bones implementation of DDPM (Denoising Diffusion Probabilistic Models).

CIFAR10:

[Image: CIFAR10]

FashionMNIST:

[Image: FashionMNIST]

Usage

Setup Environment

DDPM is computationally expensive, and this repo requires CUDA. If your GPU does not have enough memory, try reducing the batch size.

conda env create -f environment.yml
conda activate ddpm_torch
python main.py -h

Train model

Current code supports two datasets: CIFAR10 and FashionMNIST.

python main.py train -h
python main.py train --dataset=cifar10
python main.py train --dataset=fashion

Additional flags include batch_size, epochs, timesteps, and the checkpoint interval ckpt_interval. Run python main.py train -h for the full list.

The model will start training and saving model weights to ./checkpoints/.
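
To make the timesteps flag more concrete, here is a minimal sketch of the forward noising schedule that DDPM is built on. It assumes the linear schedule and the default values (1e-4 to 0.02) from the original DDPM paper; the schedule used in this repo may differ, and the function names below are illustrative rather than taken from the code.

```python
import math


def linear_beta_schedule(timesteps, beta_start=1e-4, beta_end=0.02):
    """Linearly spaced noise variances beta_1..beta_T (assumed defaults)."""
    if timesteps == 1:
        return [beta_start]
    step = (beta_end - beta_start) / (timesteps - 1)
    return [beta_start + i * step for i in range(timesteps)]


def cumulative_alphas(betas):
    """alpha_bar_t = product over s <= t of (1 - beta_s)."""
    alpha_bars, running = [], 1.0
    for beta in betas:
        running *= 1.0 - beta
        alpha_bars.append(running)
    return alpha_bars


def noising_coefficients(alpha_bar_t):
    """Coefficients (a, b) in x_t = a * x_0 + b * eps, with eps ~ N(0, I)."""
    return math.sqrt(alpha_bar_t), math.sqrt(1.0 - alpha_bar_t)


if __name__ == "__main__":
    alpha_bars = cumulative_alphas(linear_beta_schedule(1000))
    # Show how much signal and noise remain early, midway, and at the end.
    for t in (0, 499, 999):
        a, b = noising_coefficients(alpha_bars[t])
        print(f"t={t + 1:4d}  signal={a:.4f}  noise={b:.4f}")
```

With more timesteps, each step adds less noise, and the final step ends closer to pure Gaussian noise. Training and sampling cost grow accordingly.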

To Sample

python main.py infer -h
python main.py infer <epoch> --sample_n=16 --dataset=cifar10

Use the last checkpoint (cifar10_epc_999.pt) to sample some images. Images will be saved in ./images/.
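
If it is unclear which epoch to pass, a small helper can look it up. The sketch below assumes checkpoints are named `<dataset>_epc_<epoch>.pt`, a pattern inferred only from the single example name above, and that the `<epoch>` argument is the number in that filename. `latest_epoch` is a hypothetical helper, not part of the repo.

```python
import re
from pathlib import Path

# Assumed naming pattern, inferred from the example "cifar10_epc_999.pt".
CKPT_PATTERN = re.compile(r"^(?P<dataset>.+)_epc_(?P<epoch>\d+)\.pt$")


def latest_epoch(ckpt_dir, dataset):
    """Return the highest epoch saved for `dataset`, or None if none exist."""
    directory = Path(ckpt_dir)
    if not directory.is_dir():
        return None
    epochs = []
    for path in directory.glob("*.pt"):
        match = CKPT_PATTERN.match(path.name)
        if match and match.group("dataset") == dataset:
            # Compare epochs numerically so that 999 ranks above 100.
            epochs.append(int(match.group("epoch")))
    return max(epochs, default=None)


if __name__ == "__main__":
    epoch = latest_epoch("./checkpoints", "cifar10")
    if epoch is None:
        print("no cifar10 checkpoints found")
    else:
        print(f"python main.py infer {epoch} --sample_n=16 --dataset=cifar10")
```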

One can combine all 16 sample trajectories into a single image using ImageMagick's montage command.

montage -density 300 -tile 16x0 -geometry +1+1 -border 2 images/*.png out.png
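
The same montage call can also be driven from Python, which avoids relying on shell globbing. This is a sketch under the assumption that ImageMagick is installed and the samples are PNG files in one directory. `build_montage_command` and `make_montage` are hypothetical helpers, not part of the repo.

```python
import shutil
import subprocess
from pathlib import Path


def build_montage_command(image_dir, out_path, columns=16):
    """Assemble the montage call with the PNG files listed explicitly."""
    images = sorted(str(p) for p in Path(image_dir).glob("*.png"))
    if not images:
        raise FileNotFoundError(f"no PNG files found in {image_dir}")
    return [
        "montage",
        "-density", "300",
        "-tile", f"{columns}x0",  # fixed column count, as many rows as needed
        "-geometry", "+1+1",
        "-border", "2",
        *images,
        str(out_path),
    ]


def make_montage(image_dir="images", out_path="out.png", columns=16):
    """Run montage if it is on PATH; return False when it is not installed."""
    if shutil.which("montage") is None:
        return False
    subprocess.run(build_montage_command(image_dir, out_path, columns), check=True)
    return True
```

Calling `make_montage()` from the repo root after sampling should write out.png next to main.py, provided montage is available.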

More complete implementation of DDPM
