From 1a5e598d17b1dd61e7c79ba47d1173b66312c4b7 Mon Sep 17 00:00:00 2001 From: Luke Carlson Date: Fri, 1 Nov 2024 10:18:38 -0400 Subject: [PATCH 1/4] use default_factory for mutable fields --- ml_mdm/diffusion.py | 2 +- ml_mdm/models/nested_unet.py | 10 +++++----- ml_mdm/models/unet.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index 22d55f2..c11b034 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -29,7 +29,7 @@ def sv(x, f): @dataclass class DiffusionConfig: sampler_config: samplers.SamplerConfig = field( - default=samplers.SamplerConfig(), metadata={"help": "Sampler configuration"} + default_factory=samplers.SamplerConfig, metadata={"help": "Sampler configuration"} ) model_output_scale: float = field( default=0, diff --git a/ml_mdm/models/nested_unet.py b/ml_mdm/models/nested_unet.py index 43a4abb..b87c20c 100644 --- a/ml_mdm/models/nested_unet.py +++ b/ml_mdm/models/nested_unet.py @@ -22,7 +22,7 @@ @dataclass class NestedUNetConfig(UNetConfig): inner_config: UNetConfig = field( - default=UNetConfig(nesting=True), + default_factory=lambda: UNetConfig(nesting=True), metadata={"help": "inner unet used as middle blocks"}, ) skip_mid_blocks: bool = field(default=True) @@ -55,7 +55,7 @@ class NestedUNetConfig(UNetConfig): @dataclass class Nested2UNetConfig(NestedUNetConfig): inner_config: NestedUNetConfig = field( - default=NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -63,7 +63,7 @@ class Nested2UNetConfig(NestedUNetConfig): @dataclass class Nested3UNetConfig(Nested2UNetConfig): inner_config: Nested2UNetConfig = field( - default=Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -71,7 +71,7 @@ class Nested3UNetConfig(Nested2UNetConfig): @dataclass class Nested4UNetConfig(Nested3UNetConfig): inner_config: Nested3UNetConfig = field( - default=Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -96,7 +96,7 @@ def download(vision_model_path): @config.register_model("nested_unet") class NestedUNet(UNet): def __init__(self, input_channels, output_channels, config: NestedUNetConfig): - super().__init__(input_channels, output_channels, config) + super().__init__(input_channels, output_channels=output_channels, config=config) config.inner_config.conditioning_feature_dim = config.conditioning_feature_dim if getattr(config.inner_config, "inner_config", None) is None: self.inner_unet = UNet(input_channels, output_channels, config.inner_config) diff --git a/ml_mdm/models/unet.py b/ml_mdm/models/unet.py index 94b7d60..43a8506 100644 --- a/ml_mdm/models/unet.py +++ b/ml_mdm/models/unet.py @@ -115,7 +115,7 @@ class UNetConfig: temporal_spatial_ds: bool = field(default=False) temporal_positional_encoding: bool = field(default=False) resnet_config: ResNetConfig = field( - default=ResNetConfig(), metadata={"help": "Resnet configs"} + default_factory=ResNetConfig, metadata={"help": "Resnet configs"} ) def __post_init__(self): From 642faf311fdb6dae01a8f3ae5b0ec316d7676d29 Mon Sep 17 00:00:00 2001 From: Isabella Deanhardt <139717054+bdeanhardt@users.noreply.github.com> Date: Sun, 10 Nov 2024 19:04:24 -0500 Subject: [PATCH 2/4] test_configs.py docstrings --- tests/test_configs.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/tests/test_configs.py b/tests/test_configs.py index 2d3fe83..299d0b3 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -7,16 +7,20 @@ def test_unet_in_registry(): + """Check that 'nested_unet' and 'unet' models are correctly registered in the Model Registry.""" assert config.get_model("nested_unet") is not None assert config.get_model("unet") is not None def test_unet_in_pipeline(): + """Check that 'nested_unet' and 'unet' models have corresponding pipelines defined.""" assert config.get_pipeline("unet") is not None assert config.get_pipeline("nested_unet") is not None def test_config_cc12m_64x64(): + """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo).""" + f = "configs/models/cc12m_64x64.yaml" f = "configs/models/cc12m_64x64.yaml" args = config.get_arguments( mode="trainer", @@ -44,6 +48,7 @@ def test_config_cc12m_64x64(): def test_config_cc12m_256x256(): + """Check that the 'cc12m_256x256' configuration loads with 'nested_unet' as model in all modes (trainer, sampler, demo).""" f = "configs/models/cc12m_256x256.yaml" args = config.get_arguments( args=["--model=nested_unet"], @@ -75,6 +80,7 @@ def test_config_cc12m_256x256(): def test_config_cc12m_1024x1024(): + """Check that the 'cc12m_1024x1024' configuration loads with 'nested2_unet' model in all modes (trainer, sampler, demo).""" f = "configs/models/cc12m_1024x1024.yaml" args = config.get_arguments( args=["--model=nested2_unet"], From 1a82a4bf32876a465c38f378d20dc68ab6a553e9 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Tue, 12 Nov 2024 09:25:09 -0500 Subject: [PATCH 3/4] added docstrings to remaining test cases --- tests/test_generate_batch.py | 10 ++++++++++ tests/test_generate_sample.py | 4 ++++ tests/test_imports.py | 4 ++++ tests/test_models.py | 3 +++ tests/test_reader.py | 3 +++ tests/test_train.py | 1 + 6 files changed, 25 insertions(+) diff --git a/tests/test_generate_batch.py b/tests/test_generate_batch.py index 2334ac2..4edb759 100644 --- a/tests/test_generate_batch.py +++ b/tests/test_generate_batch.py @@ -8,6 +8,10 @@ def test_small_batch(): + """ + Test small batch generation with T5 model. + Check that basic data generation pipeline works with minimal settings. + """ args = Namespace( batch_size=10, test_file_list="tests/test_files/sample_training_0.tsv", @@ -33,6 +37,12 @@ def test_small_batch(): def test_generate_batch(): + """ + Test batch generation with default config settings. + + Note: This test currently only sets up the configuration but doesn't execute + the generation (ends with pass statement). + """ args = config.get_arguments(mode="sampler") args.batch_size = 10 args.test_file_list = "tests/test_files/sample_training_0.tsv" diff --git a/tests/test_generate_sample.py b/tests/test_generate_sample.py index d36c889..02c70e5 100644 --- a/tests/test_generate_sample.py +++ b/tests/test_generate_sample.py @@ -4,6 +4,10 @@ def test_load_flick_config(): + """ + Test loading of cc12m_64x64.yaml config file. + Checks image dimensions are correctly loaded in reader config. + """ args = config.get_arguments( "", mode="demo", diff --git a/tests/test_imports.py b/tests/test_imports.py index 19138c2..04474a8 100644 --- a/tests/test_imports.py +++ b/tests/test_imports.py @@ -1,6 +1,7 @@ # For licensing see accompanying LICENSE file. # Copyright (C) 2024 Apple Inc. All rights reserved. def test_top_level_imports_work(): + """Checks that all top-level ml_mdm module imports are accessible.""" from ml_mdm import ( config, diffusion, @@ -16,6 +17,7 @@ def test_top_level_imports_work(): def test_cli_imports_work(): + """Checks that all CLI module imports are accessible.""" from ml_mdm.clis import ( download_tar_from_index, generate_batch, @@ -25,8 +27,10 @@ def test_cli_imports_work(): def test_model_imports_work(): + """Checks that all model module imports are accessible.""" from ml_mdm.models import model_ema, nested_unet, unet def test_lm_imports_work(): + """Checks that all language model module imports are accessible.""" from ml_mdm.language_models import factory, tokenizer diff --git a/tests/test_models.py b/tests/test_models.py index b945fb5..2a96dad 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -13,6 +13,7 @@ def test_initialize_unet(): + """Test UNet model and EMA initialization with default configs.""" unet_config = models.unet.UNetConfig() diffusion_config = diffusion.DiffusionConfig( use_vdm_loss_weights=True, model_output_scale=0.1 @@ -30,6 +31,7 @@ def test_initialize_unet(): def test_all_registered_models(): + """Test instantiation of all models in the registry with default configs.""" for config_name, additional_info in config.MODEL_CONFIG_REGISTRY.items(): model_name = additional_info["model"] config_cls = additional_info["config"] @@ -44,6 +46,7 @@ def test_all_registered_models(): @pytest.mark.gpu def test_initialize_pretrained(): + """Test loading pretrained 64x64 model on GPU if available.""" device = torch.device("cuda" if torch.cuda.is_available() else "cpu") args = config.get_arguments( diff --git a/tests/test_reader.py b/tests/test_reader.py index bcc3ef7..dd6d983 100644 --- a/tests/test_reader.py +++ b/tests/test_reader.py @@ -10,6 +10,7 @@ def test_get_dataset(): + """Test dataset loading and verify sample format and dimensions.""" tokenizer = factory.create_tokenizer("data/t5.vocab") dataset = reader.get_dataset( tokenizer=tokenizer, @@ -31,6 +32,7 @@ def test_get_dataset(): def test_get_dataset_partition(): + """Test dataset partitioning and iteration.""" tokenizer = factory.create_tokenizer("data/t5.vocab") train_loader = reader.get_dataset_partition( partition_num=0, @@ -46,6 +48,7 @@ def test_get_dataset_partition(): def test_process_text(): + """Test text tokenization with default reader config.""" line = "A bicycle on top of a boat." tokenizer = factory.create_tokenizer("data/t5.vocab") tokens = reader.process_text( diff --git a/tests/test_train.py b/tests/test_train.py index c1f2f52..073fcb4 100644 --- a/tests/test_train.py +++ b/tests/test_train.py @@ -23,6 +23,7 @@ reason="more effective to test this with torchrun, just here for documentation" ) def test_small(): + """Test minimal training run with single process setup.""" os.environ["RANK"] = "0" os.environ["WORLD_SIZE"] = "1" os.environ["LOCAL_RANK"] = "0" From 5cad65072b2680f2cc581cd75e6deebfa1f4e6d5 Mon Sep 17 00:00:00 2001 From: Bella Deanhardt Date: Tue, 12 Nov 2024 09:34:51 -0500 Subject: [PATCH 4/4] removed duplicate line --- tests/test_configs.py | 1 - 1 file changed, 1 deletion(-) diff --git a/tests/test_configs.py b/tests/test_configs.py index 299d0b3..ed622ae 100644 --- a/tests/test_configs.py +++ b/tests/test_configs.py @@ -21,7 +21,6 @@ def test_unet_in_pipeline(): def test_config_cc12m_64x64(): """Check that the 'cc12m_64x64' configuration file loads successfully for all pipeline modes (trainer, sampler, demo).""" f = "configs/models/cc12m_64x64.yaml" - f = "configs/models/cc12m_64x64.yaml" args = config.get_arguments( mode="trainer", additional_config_paths=[f],