
Merge branch 'main' into fix-all
tolgacangoz authored Oct 16, 2024
2 parents 412fbe1 + f9e87fc commit ea93b76
Showing 6 changed files with 65 additions and 51 deletions.
README.md: 28 changes (14 additions, 14 deletions)
@@ -1,6 +1,6 @@
# ml_mdm - Matryoshka Diffusion Models

`ml_mdm` is a python package for efficiently training high quality text-to-image diffusion models — brought to the public by [Luke Carlson](https://github.com/jlukecarlson), [Jiatao Gu](https://github.com/MultiPath), [Shuangfei Zhai](https://github.com/Shuangfei), and [Navdeep Jaitly](https://github.com/ndjaitly).
`ml_mdm` is a python package for efficiently training high quality text-to-image diffusion models — brought to the public by [Luke Carlson](https://github.com/luke-carlson), [Jiatao Gu](https://github.com/MultiPath), [Shuangfei Zhai](https://github.com/Shuangfei), and [Navdeep Jaitly](https://github.com/ndjaitly).


---
@@ -79,7 +79,7 @@ We've uploaded model checkpoints to:
Feel free to download the models or skip further down to train your own. Once a pretrained model is downloaded locally, you can use it in our web demo or pass it as an argument to training, sampling, and more.

```bash
```console
export ASSET_PATH=https://docs-assets.developer.apple.com/ml-research/models/mdm

curl $ASSET_PATH/flickr64/vis_model.pth --output vis_model_64x64.pth
@@ -91,7 +91,7 @@ curl $ASSET_PATH/flickr1024/vis_model.pth --output vis_model_1024x1024.pth
### Web Demo
You can run your own instance of the web demo (after downloading the checkpoints) with this command:

```bash
```console
torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port $YOUR_PORT
```

@@ -133,7 +133,7 @@ In the `ml_mdm.models` submodule, we've open sourced our implementations of:

Once you've installed `ml_mdm`, download these checkpoints into the repo's directory.

```
```console
curl https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr64/vis_model.pth --output vis_model_64x64.pth
curl https://docs-assets.developer.apple.com/ml-research/models/mdm/flickr256/vis_model.pth --output vis_model_256x256.pth
```
@@ -145,7 +145,7 @@ The web demo will load each model with a corresponding configuration:

In the demo, you can change a variety of settings and peek into the internals of the model. Set the port you'd like to use by swapping in `$YOUR_PORT` and then run:

```
```console
torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_sample.py --port $YOUR_PORT
```

@@ -154,7 +154,7 @@ If you just want to step through the process of training a model and running a p

> Feel free to try changing a variety of arguments, either directly in the CLI or by editing the config YAML file
```
```console
torchrun --standalone --nproc_per_node=1 ml_mdm/clis/train_parallel.py \
--file-list=tests/test_files/sample_training_0.tsv \
--multinode=0 \
@@ -174,7 +174,7 @@ You should see a `outputs/vis_model_000100.pth` file. Now lets do something a bi

> The script is based on [img2dataset's CC12M script](https://github.com/rom1504/img2dataset/blob/main/dataset_examples/cc12m.md).
```
```console
curl https://storage.googleapis.com/conceptual_12m/cc12m.tsv | head -n 1000 > cc12m_index.tsv

# Add headers to the file
@@ -186,13 +186,13 @@ Then prepare and split into train/validation

> This script requires `img2dataset`; either run `pip install '.[data_prep]'` or just `pip install img2dataset`
```
```console
python3 -m ml_mdm.clis.scrape_cc12m \
--cc12m_index cc12m_index.tsv \
--cc12m_local_dir cc12m_download
```
After running this command you will see the following files:
```
```console
training.0.tsv # train index file
validation.tsv # validation index file
cc12m_download/
@@ -203,7 +203,7 @@ cc12m_download/
### 2. Train
Now that we have our training file, we can select a model config and pass any additional training arguments:

```
```console
# Modify torchrun arguments to fit your GPU setup
torchrun --standalone --nproc_per_node=8 ml_mdm/clis/train_parallel.py \
--file-list=training_0.tsv \
@@ -216,15 +216,15 @@ torchrun --standalone --nproc_per_node=8 ml_mdm/clis/train_parallel.py \
> If you've downloaded a pretrained model, you can set the `--pretrained-vision-file` argument to point to its location on disk (a sketch of such a run appears below the log output).
Once training completes, you'll find the model in the folder defined by the `--output-dir` argument:
```
```console
2024-07-22:17:58:46,649 INFO [model_ema.py:33] Saving EMA model file: /mnt/data/outputs/vis_model_000100.pth
2024-07-22:17:58:47,448 INFO [unet.py:794] Saving model file: /mnt/data/outputs/vis_model_noema_000100.pth
```
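
As a rough sketch (the flag names are taken from the README above; other required arguments, such as the model config, are elided here and may still be needed), warm-starting from the downloaded 64x64 checkpoint might look like:

```console
# Hypothetical warm-start run; adjust torchrun arguments to fit your GPU setup
torchrun --standalone --nproc_per_node=8 ml_mdm/clis/train_parallel.py \
    --file-list=training_0.tsv \
    --pretrained-vision-file=vis_model_64x64.pth \
    --output-dir=/mnt/data/outputs
```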


### 3. Sample from the model
Now that we have a trained model, we can generate samples from the diffusion model:
```
```console
torchrun --standalone --nproc_per_node=1 ml_mdm/clis/generate_batch.py \
--config_path configs/models/cc12m_64x64.yaml \
--min-examples 3 --test-file-list validation.tsv \
@@ -261,7 +261,7 @@ reader_config:
Then you can use our dataset download helper:
```
```console
python -m ml_mdm.clis.download_tar_from_index \
--dataset-config-file configs/datasets/cc12m.yaml \
--subset train --download_tar
@@ -304,7 +304,7 @@ eval:
### Dataset Structure
The S3 bucket contains a series of files in this format; take a look at `ml_mdm/clis/scrape_cc12m.py` to generate your own.
```bash
```console
2023-04-01 01:31:30 36147200 images_00000.tar
2023-05-10 11:34:49 1108424 images_00000.tsv
2023-04-01 01:31:26 36454400 images_00001.tar
configs/models/cc12m_1024x1024.yaml: 30 changes (16 additions, 14 deletions)
@@ -9,24 +9,26 @@ test_file_list: validation.tsv
# reader-config-file: configs/datasets/reader_config_eval.yaml
# shared_arguments:
output_dir: /mnt/data/outputs
num_diffusion_steps: 1000
reproject_signal: false
model_output_scale: 0
prediction_type: V_PREDICTION
loss_target_type: DDPM
schedule_type: DEEPFLOYD
diffusion_config:
sampler_config:
num_diffusion_steps: 1000
reproject_signal: false
prediction_type: V_PREDICTION
loss_target_type: DDPM
schedule_type: DEEPFLOYD
rescale_signal: 1
schedule_shifted: true
schedule_shifted_power: 2
model_output_scale: 0
use_vdm_loss_weights: false
use_double_loss: true
no_use_residual: true
prediction_length: 129
use_vdm_loss_weights: false
use_double_loss: true
no_use_residual: true
num_training_steps: 1000000
avg_lm_steps: 0
categorical_conditioning: 0
rescale_signal: 1
schedule_shifted: true
schedule_shifted_power: 2
skip_normalization: true
random_low_noise: true
# skip_normalization: true
# random_low_noise: true
vocab_file: data/t5.vocab
text_model: google/flan-t5-xl
model: nested2_unet
configs/models/cc12m_256x256.yaml: 27 changes (14 additions, 13 deletions)
@@ -9,23 +9,24 @@ test_file_list: validation.tsv
#reader-config-file: configs/datasets/reader_config_eval.yaml
# shared_arguments
output_dir: /mnt/data/outputs
num_diffusion_steps: 1000
reproject_signal: false
model_output_scale: 0
prediction_type: V_PREDICTION
loss_target_type: DDPM
schedule_type: DEEPFLOYD
diffusion_config:
sampler_config:
num_diffusion_steps: 1000
reproject_signal: false
prediction_type: V_PREDICTION
loss_target_type: DDPM
schedule_type: DEEPFLOYD
rescale_signal: 1
schedule_shifted: true
model_output_scale: 0
use_vdm_loss_weights: false
use_double_loss: true
no_use_residual: true
prediction_length: 129
use_vdm_loss_weights: false
use_double_loss: true
no_use_residual: true
num_training_steps: 1000000
avg_lm_steps: 0
categorical_conditioning: 0
rescale_signal: 1
schedule_shifted: true
skip_normalization: true
random_low_noise: true

vocab_file: data/t5.vocab
text_model: google/flan-t5-xl
model: nested_unet
configs/models/cc12m_64x64.yaml: 21 changes (12 additions, 9 deletions)
@@ -16,17 +16,20 @@ sample_image_size: 64
test_file_list: validation.tsv
device: cuda
model: unet

output_dir: /mnt/data/outputs
num_diffusion_steps: 1000
reproject_signal: false
predict_variances: false
model_output_scale: 0
prediction_type: V_PREDICTION
loss_target_type: HA_STYLE
schedule_type: DEEPFLOYD

diffusion_config:
sampler_config:
num_diffusion_steps: 1000
reproject_signal: false
predict_variances: false
prediction_type: V_PREDICTION
loss_target_type: HA_STYLE
schedule_type: DEEPFLOYD
model_output_scale: 0
use_vdm_loss_weights: false

prediction_length: 129
use_vdm_loss_weights: false
loss_factor: 1
num_training_steps: 5000
num_epochs: 20000
ml_mdm/diffusion.py: 2 changes (1 addition, 1 deletion)
@@ -261,7 +261,7 @@ def forward(
# recompute the noise from pred_low
if not self.diffusion_config.no_use_residual:
assert (
self.diffusion_config.mixed_batch is None
self.diffusion_config.mixed_ratio is None
), "do not support mixed-batch"
x_t, x_t_low = x_t
pred, pred_low = p_t
ml_mdm/samplers.py: 8 changes (8 additions, 0 deletions)
@@ -110,6 +110,12 @@ class SamplerConfig:
"help": "automatically shift the noise schedule based on the resolution ratios."
},
)
schedule_shifted_power: float = field(
default=1,
metadata={
"help": "noise shifted ratio, by default using 1."
},
)


##########################################################################################
@@ -242,6 +248,8 @@ def get_image_rescaled(self, images, scale_factor=None):

def get_schedule_shifted(self, gammas, scale_factor=None):
if (scale_factor is not None) and (scale_factor > 1): # rescale noise schedule
p = self._config.schedule_shifted_power
scale_factor = scale_factor ** p
snr = gammas / (1 - gammas)
scaled_snr = snr / scale_factor
gammas = 1 / (1 + 1 / scaled_snr)
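
For intuition, here is a small standalone NumPy sketch (an illustration, not the repository's API) of what `schedule_shifted_power` does: the per-step signal level is converted to an SNR, divided by `scale_factor ** power`, and mapped back, so a higher power shifts the schedule further toward the noisy end at larger resolutions.

```python
import numpy as np

def shift_schedule(gammas: np.ndarray, scale_factor: float, power: float = 1.0) -> np.ndarray:
    """Sketch of the SNR-based shift in get_schedule_shifted.

    gammas: per-step signal levels in (0, 1).
    scale_factor: resolution ratio (e.g. 4 when moving from 64x64 to 256x256).
    power: corresponds to schedule_shifted_power; power=1 matches the previous behavior.
    """
    scale = scale_factor ** power
    snr = gammas / (1.0 - gammas)           # signal-to-noise ratio per step
    scaled_snr = snr / scale                # larger scale -> lower SNR (more noise)
    return scaled_snr / (1.0 + scaled_snr)  # same as 1 / (1 + 1 / scaled_snr)

# With power=2, a 4x resolution ratio divides the SNR by 16 instead of 4.
print(shift_schedule(np.linspace(0.05, 0.95, 5), scale_factor=4.0, power=2.0))
```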
