Skip to content

Commit

Permalink
Merge pull request #28 from bdeanhardt/patch-1
Browse files Browse the repository at this point in the history
Add doc strings to testcases (25)
  • Loading branch information
luke-carlson authored Nov 13, 2024
2 parents 17e36c6 + 5cad650 commit 534bea9
Show file tree
Hide file tree
Showing 7 changed files with 30 additions and 0 deletions.
5 changes: 5 additions & 0 deletions tests/test_configs.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,16 +7,19 @@


def test_unet_in_registry():
"""Check that 'nested_unet' and 'unet' models are correctly registered in the Model Registry."""
assert config.get_model("nested_unet") is not None
assert config.get_model("unet") is not None


def test_unet_in_pipeline():
"""Check that 'nested_unet' and 'unet' models have corresponding pipelines defined."""
assert config.get_pipeline("unet") is not None
assert config.get_pipeline("nested_unet") is not None


def test_config_cc12m_64x64():
"""Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_64x64.yaml"
args = config.get_arguments(
mode="trainer",
Expand Down Expand Up @@ -44,6 +47,7 @@ def test_config_cc12m_64x64():


def test_config_cc12m_256x256():
"""Check that the 'cc12m_256x256' configuration loads with 'nested_unet' as model in all modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_256x256.yaml"
args = config.get_arguments(
args=["--model=nested_unet"],
Expand Down Expand Up @@ -75,6 +79,7 @@ def test_config_cc12m_256x256():


def test_config_cc12m_1024x1024():
"""Check that the 'cc12m_1024x1024' configuration loads with 'nested2_unet' model in all modes (trainer, sampler, demo)."""
f = "configs/models/cc12m_1024x1024.yaml"
args = config.get_arguments(
args=["--model=nested2_unet"],
Expand Down
10 changes: 10 additions & 0 deletions tests/test_generate_batch.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,6 +8,10 @@


def test_small_batch():
"""
Test small batch generation with T5 model.
Check that basic data generation pipeline works with minimal settings.
"""
args = Namespace(
batch_size=10,
test_file_list="tests/test_files/sample_training_0.tsv",
Expand All @@ -33,6 +37,12 @@ def test_small_batch():


def test_generate_batch():
"""
Test batch generation with default config settings.
Note: This test currently only sets up the configuration but doesn't execute
the generation (ends with pass statement).
"""
args = config.get_arguments(mode="sampler")
args.batch_size = 10
args.test_file_list = "tests/test_files/sample_training_0.tsv"
Expand Down
4 changes: 4 additions & 0 deletions tests/test_generate_sample.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,10 @@


def test_load_flick_config():
"""
Test loading of cc12m_64x64.yaml config file.
Checks image dimensions are correctly loaded in reader config.
"""
args = config.get_arguments(
"",
mode="demo",
Expand Down
4 changes: 4 additions & 0 deletions tests/test_imports.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
# For licensing see accompanying LICENSE file.
# Copyright (C) 2024 Apple Inc. All rights reserved.
def test_top_level_imports_work():
"""Checks that all top-level ml_mdm module imports are accessible."""
from ml_mdm import (
config,
diffusion,
Expand All @@ -16,6 +17,7 @@ def test_top_level_imports_work():


def test_cli_imports_work():
"""Checks that all CLI module imports are accessible."""
from ml_mdm.clis import (
download_tar_from_index,
generate_batch,
Expand All @@ -25,8 +27,10 @@ def test_cli_imports_work():


def test_model_imports_work():
"""Checks that all model module imports are accessible."""
from ml_mdm.models import model_ema, nested_unet, unet


def test_lm_imports_work():
"""Checks that all language model module imports are accessible."""
from ml_mdm.language_models import factory, tokenizer
3 changes: 3 additions & 0 deletions tests/test_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@


def test_initialize_unet():
"""Test UNet model and EMA initialization with default configs."""
unet_config = models.unet.UNetConfig()
diffusion_config = diffusion.DiffusionConfig(
use_vdm_loss_weights=True, model_output_scale=0.1
Expand All @@ -30,6 +31,7 @@ def test_initialize_unet():


def test_all_registered_models():
"""Test instantiation of all models in the registry with default configs."""
for config_name, additional_info in config.MODEL_CONFIG_REGISTRY.items():
model_name = additional_info["model"]
config_cls = additional_info["config"]
Expand All @@ -44,6 +46,7 @@ def test_all_registered_models():

@pytest.mark.gpu
def test_initialize_pretrained():
"""Test loading pretrained 64x64 model on GPU if available."""
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

args = config.get_arguments(
Expand Down
3 changes: 3 additions & 0 deletions tests/test_reader.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@


def test_get_dataset():
"""Test dataset loading and verify sample format and dimensions."""
tokenizer = factory.create_tokenizer("data/t5.vocab")
dataset = reader.get_dataset(
tokenizer=tokenizer,
Expand All @@ -31,6 +32,7 @@ def test_get_dataset():


def test_get_dataset_partition():
"""Test dataset partitioning and iteration."""
tokenizer = factory.create_tokenizer("data/t5.vocab")
train_loader = reader.get_dataset_partition(
partition_num=0,
Expand All @@ -46,6 +48,7 @@ def test_get_dataset_partition():


def test_process_text():
"""Test text tokenization with default reader config."""
line = "A bicycle on top of a boat."
tokenizer = factory.create_tokenizer("data/t5.vocab")
tokens = reader.process_text(
Expand Down
1 change: 1 addition & 0 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
reason="more effective to test this with torchrun, just here for documentation"
)
def test_small():
"""Test minimal training run with single process setup."""
os.environ["RANK"] = "0"
os.environ["WORLD_SIZE"] = "1"
os.environ["LOCAL_RANK"] = "0"
Expand Down

0 comments on commit 534bea9

Please sign in to comment.