From 642faf311fdb6dae01a8f3ae5b0ec316d7676d29 Mon Sep 17 00:00:00 2001
From: Isabella Deanhardt <139717054+bdeanhardt@users.noreply.github.com>
Date: Sun, 10 Nov 2024 19:04:24 -0500
Subject: [PATCH 1/3] test_configs.py docstrings

---
 tests/test_configs.py | 6 ++++++
 1 file changed, 6 insertions(+)

diff --git a/tests/test_configs.py b/tests/test_configs.py
index 2d3fe83..299d0b3 100644
--- a/tests/test_configs.py
+++ b/tests/test_configs.py
@@ -7,16 +7,20 @@ def test_unet_in_registry():
+    """Check that 'nested_unet' and 'unet' models are correctly registered in the Model Registry."""
     assert config.get_model("nested_unet") is not None
     assert config.get_model("unet") is not None
 
 
 def test_unet_in_pipeline():
+    """Check that 'nested_unet' and 'unet' models have corresponding pipelines defined."""
     assert config.get_pipeline("unet") is not None
     assert config.get_pipeline("nested_unet") is not None
 
 
 def test_config_cc12m_64x64():
+    """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo)."""
+    f = "configs/models/cc12m_64x64.yaml"
     f = "configs/models/cc12m_64x64.yaml"
     args = config.get_arguments(
         mode="trainer",
         additional_config_paths=[f],
@@ -44,6 +48,7 @@ def test_config_cc12m_256x256():
+    """Check that the 'cc12m_256x256' configuration loads with 'nested_unet' as the model in all modes (trainer, sampler, demo)."""
     f = "configs/models/cc12m_256x256.yaml"
     args = config.get_arguments(
         args=["--model=nested_unet"],
@@ -75,6 +80,7 @@ def test_config_cc12m_1024x1024():
+    """Check that the 'cc12m_1024x1024' configuration loads with the 'nested2_unet' model in all modes (trainer, sampler, demo)."""
     f = "configs/models/cc12m_1024x1024.yaml"
     args = config.get_arguments(
         args=["--model=nested2_unet"],

From 1a82a4bf32876a465c38f378d20dc68ab6a553e9 Mon Sep 17 00:00:00 2001
From: Bella Deanhardt
Date: Tue, 12 Nov 2024 09:25:09 -0500
Subject: [PATCH 2/3] added docstrings to remaining test cases

---
 tests/test_generate_batch.py  | 10 ++++++++++
 tests/test_generate_sample.py |  4 ++++
 tests/test_imports.py         |  4 ++++
 tests/test_models.py          |  3 +++
 tests/test_reader.py          |  3 +++
 tests/test_train.py           |  1 +
 6 files changed, 25 insertions(+)

diff --git a/tests/test_generate_batch.py b/tests/test_generate_batch.py
index 2334ac2..4edb759 100644
--- a/tests/test_generate_batch.py
+++ b/tests/test_generate_batch.py
@@ -8,6 +8,10 @@ def test_small_batch():
+    """
+    Test small batch generation with the T5 model.
+    Check that the basic data generation pipeline works with minimal settings.
+    """
     args = Namespace(
         batch_size=10,
         test_file_list="tests/test_files/sample_training_0.tsv",
@@ -33,6 +37,12 @@ def test_generate_batch():
+    """
+    Test batch generation with default config settings.
+
+    Note: This test currently only sets up the configuration but doesn't execute
+    the generation (it ends with a pass statement).
+    """
     args = config.get_arguments(mode="sampler")
     args.batch_size = 10
     args.test_file_list = "tests/test_files/sample_training_0.tsv"
diff --git a/tests/test_generate_sample.py b/tests/test_generate_sample.py
index d36c889..02c70e5 100644
--- a/tests/test_generate_sample.py
+++ b/tests/test_generate_sample.py
@@ -4,6 +4,10 @@ def test_load_flick_config():
+    """
+    Test loading of the cc12m_64x64.yaml config file.
+    Checks that image dimensions are correctly loaded in the reader config.
+ """ args = config.get_arguments( "", mode="demo", diff --git a/tests/test_imports.py b/tests/test_imports.py index 19138c2..04474a8 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,6 +1,7 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. def test_top_level_imports_work(): + """Checks that all top-level ml_mdm module imports are accessible.""" from ml_mdm import ( config, diffusion, @@ -16,6 +17,7 @@ def test_top_level_imports_work(): def test_cli_imports_work(): + """Checks that all CLI module imports are accessible.""" from ml_mdm.clis import ( download_tar_from_index, generate_batch, @@ -25,8 +27,10 @@ def test_cli_imports_work(): def test_model_imports_work(): + """Checks that all model module imports are accessible.""" from ml_mdm.models import model_ema, nested_unet, unet def test_lm_imports_work(): + """Checks that all language model module imports are accessible.""" from ml_mdm.language_models import factory, tokenizer diff --git a/tests/test_models.py b/tests/test_models.py index b945fb5..2a96dad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,7 @@ def test_initialize_unet(): + """Test UNet model and EMA initialization with default configs.""" unet_config = models.unet.UNetConfig() diffusion_config = diffusion.DiffusionConfig( use_vdm_loss_weights=True, model_output_scale=0.1 @@ -30,6 +31,7 @@ def test_initialize_unet(): def test_all_registered_models(): + """Test instantiation of all models in the registry with default configs.""" for config_name, additional_info in config.MODEL_CONFIG_REGISTRY.items(): model_name = additional_info["model"] config_cls = additional_info["config"] @@ -44,6 +46,7 @@ def test_all_registered_models(): @pytest.mark.gpu def test_initialize_pretrained(): + """Test loading pretrained 64x64 model on GPU if available.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args = config.get_arguments( diff --git a/tests/test_reader.py b/tests/test_reader.py index bcc3ef7..dd6d983 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -10,6 +10,7 @@ def test_get_dataset(): + """Test dataset loading and verify sample format and dimensions.""" tokenizer = factory.create_tokenizer("data/t5.vocab") dataset = reader.get_dataset( tokenizer=tokenizer, @@ -31,6 +32,7 @@ def test_get_dataset(): def test_get_dataset_partition(): + """Test dataset partitioning and iteration.""" tokenizer = factory.create_tokenizer("data/t5.vocab") train_loader = reader.get_dataset_partition( partition_num=0, @@ -46,6 +48,7 @@ def test_get_dataset_partition(): def test_process_text(): + """Test text tokenization with default reader config.""" line = "A bicycle on top of a boat." 
     tokenizer = factory.create_tokenizer("data/t5.vocab")
     tokens = reader.process_text(
diff --git a/tests/test_train.py b/tests/test_train.py
index c1f2f52..073fcb4 100644
--- a/tests/test_train.py
+++ b/tests/test_train.py
@@ -23,6 +23,7 @@
     reason="more effective to test this with torchrun, just here for documentation"
 )
 def test_small():
+    """Test minimal training run with single process setup."""
     os.environ["RANK"] = "0"
     os.environ["WORLD_SIZE"] = "1"
     os.environ["LOCAL_RANK"] = "0"

From 5cad65072b2680f2cc581cd75e6deebfa1f4e6d5 Mon Sep 17 00:00:00 2001
From: Bella Deanhardt
Date: Tue, 12 Nov 2024 09:34:51 -0500
Subject: [PATCH 3/3] removed duplicate line

---
 tests/test_configs.py | 1 -
 1 file changed, 1 deletion(-)

diff --git a/tests/test_configs.py b/tests/test_configs.py
index 299d0b3..ed622ae 100644
--- a/tests/test_configs.py
+++ b/tests/test_configs.py
@@ -21,7 +21,6 @@ def test_unet_in_pipeline():
 def test_config_cc12m_64x64():
     """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo)."""
     f = "configs/models/cc12m_64x64.yaml"
-    f = "configs/models/cc12m_64x64.yaml"
     args = config.get_arguments(
         mode="trainer",
         additional_config_paths=[f],