-
Notifications
You must be signed in to change notification settings - Fork 11
/
train_uvit.py
65 lines (53 loc) · 2.64 KB
/
train_uvit.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from types import SimpleNamespace
import wandb
import torch
from cloud_diffusion.dataset import download_dataset, CloudDataset
from cloud_diffusion.utils import NoisifyDataloader, MiniTrainer, set_seed, parse_args
from cloud_diffusion.simple_diffusion import noisify_uvit, simple_diffusion_sampler
from cloud_diffusion.models import UViT, get_uvit_params
# Debug switch: when true, train_func keeps only the first 5 dataset files.
DEBUG = True

# Project name handed to both download_dataset() and wandb.init().
PROJECT_NAME = "ddpm_clouds"

# Identifier of the dataset artifact passed to download_dataset().
DATASET_ARTIFACT = "capecape/gtc/np_dataset:v1"
# Default run settings.  The __main__ section passes this object to
# parse_args() before training and then to wandb.init() / train_func().
config = SimpleNamespace(
    # --- run / model ---
    epochs=100,                    # how many epochs to train for
    model_name="uvit_small",       # name used when saving the model
    strategy="simple_diffusion",   # one of [ddpm, simple_diffusion]
    # --- diffusion process ---
    noise_steps=1000,              # noise steps of the diffusion process
    sampler_steps=500,             # steps taken by the sampler
    # --- reproducibility / data loading ---
    seed=42,                       # random seed
    batch_size=6,                  # samples per batch
    img_size=512,                  # image size
    device="cuda",                 # torch device string
    num_workers=8,                 # dataloader worker count
    num_frames=4,                  # frames fed to the model as input
    # --- optimisation / evaluation ---
    lr=5e-4,                       # learning rate
    validation_days=3,             # days held out for validation
    n_preds=8,                     # how many predictions to make
    log_every_epoch=5,             # log to wandb every n epochs
)
def train_func(config):
    """Train a UViT model on the cloud dataset.

    The function resolves the model parameters, calls ``set_seed``,
    downloads the dataset files, splits them into training and validation
    files, wraps both datasets in ``NoisifyDataloader`` objects, builds the
    model and the sampler, and finally calls ``MiniTrainer.fit``.

    Args:
        config: Namespace of run settings. This function reads
            ``model_name``, ``num_frames``, ``seed``, ``device``,
            ``validation_days``, ``img_size``, ``batch_size``,
            ``num_workers`` and ``sampler_steps``. The whole object is
            also forwarded to ``trainer.fit``.

    Side effects:
        Sets ``config.model_params`` to the value returned by
        ``get_uvit_params``.
    """
    config.model_params = get_uvit_params(config.model_name, config.num_frames)

    set_seed(config.seed)
    device = torch.device(config.device)

    # download the dataset from the wandb.Artifact
    files = download_dataset(DATASET_ARTIFACT, PROJECT_NAME)
    if DEBUG:
        # Debug runs only look at the first 5 files.
        files = files[:5]

    # The last `validation_days` files are held out for validation.
    # Use an explicit split index instead of `files[:-k]` / `files[-k:]`:
    # for k == 0 those become `files[:0]` (empty) and `files[0:]` (all
    # files), which would leave nothing to train on.  Clamping at 0 keeps
    # the original result when k exceeds the number of files.
    split = max(len(files) - config.validation_days, 0)
    train_days, valid_days = files[:split], files[split:]

    train_ds = CloudDataset(
        files=train_days, num_frames=config.num_frames, img_size=config.img_size
    )
    valid_ds = CloudDataset(
        files=valid_days, num_frames=config.num_frames, img_size=config.img_size
    ).shuffle()

    # UViT dataloaders
    train_dataloader = NoisifyDataloader(
        train_ds,
        config.batch_size,
        shuffle=True,
        noise_func=noisify_uvit,
        num_workers=config.num_workers,
    )
    valid_dataloader = NoisifyDataloader(
        valid_ds,
        config.batch_size,
        shuffle=False,
        noise_func=noisify_uvit,
        num_workers=config.num_workers,
    )

    # model setup
    model = UViT(**config.model_params)

    # sampler
    sampler = simple_diffusion_sampler(steps=config.sampler_steps)

    # A simple training loop
    trainer = MiniTrainer(train_dataloader, valid_dataloader, model, sampler, device)
    trainer.fit(config)
if __name__ == "__main__":
    # Hand the defaults to parse_args first (presumably it applies
    # command-line overrides to `config` -- confirm in cloud_diffusion.utils).
    parse_args(config)

    # Tags are built after parse_args so they see the final model name.
    run_tags = ["sd", config.model_name]

    # Training runs inside the context manager returned by wandb.init.
    with wandb.init(project=PROJECT_NAME, config=config, tags=run_tags):
        train_func(config)