diff --git a/ml_mdm/diffusion.py b/ml_mdm/diffusion.py index 22d55f2..c11b034 100644 --- a/ml_mdm/diffusion.py +++ b/ml_mdm/diffusion.py @@ -29,7 +29,7 @@ def sv(x, f): @dataclass class DiffusionConfig: sampler_config: samplers.SamplerConfig = field( - default=samplers.SamplerConfig(), metadata={"help": "Sampler configuration"} + default_factory=samplers.SamplerConfig, metadata={"help": "Sampler configuration"} ) model_output_scale: float = field( default=0, diff --git a/ml_mdm/models/nested_unet.py b/ml_mdm/models/nested_unet.py index 43a4abb..b87c20c 100644 --- a/ml_mdm/models/nested_unet.py +++ b/ml_mdm/models/nested_unet.py @@ -22,7 +22,7 @@ @dataclass class NestedUNetConfig(UNetConfig): inner_config: UNetConfig = field( - default=UNetConfig(nesting=True), + default_factory=lambda: UNetConfig(nesting=True), metadata={"help": "inner unet used as middle blocks"}, ) skip_mid_blocks: bool = field(default=True) @@ -55,7 +55,7 @@ class NestedUNetConfig(UNetConfig): @dataclass class Nested2UNetConfig(NestedUNetConfig): inner_config: NestedUNetConfig = field( - default=NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -63,7 +63,7 @@ class Nested2UNetConfig(NestedUNetConfig): @dataclass class Nested3UNetConfig(Nested2UNetConfig): inner_config: Nested2UNetConfig = field( - default=Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -71,7 +71,7 @@ class Nested3UNetConfig(Nested2UNetConfig): @dataclass class Nested4UNetConfig(Nested3UNetConfig): inner_config: Nested3UNetConfig = field( - default=Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) + default_factory=lambda: Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None) ) @@ -96,7 +96,7 @@ def download(vision_model_path): @config.register_model("nested_unet") class NestedUNet(UNet): def __init__(self, input_channels, output_channels, config: NestedUNetConfig): - super().__init__(input_channels, output_channels, config) + super().__init__(input_channels, output_channels=output_channels, config=config) config.inner_config.conditioning_feature_dim = config.conditioning_feature_dim if getattr(config.inner_config, "inner_config", None) is None: self.inner_unet = UNet(input_channels, output_channels, config.inner_config) diff --git a/ml_mdm/models/unet.py b/ml_mdm/models/unet.py index 94b7d60..43a8506 100644 --- a/ml_mdm/models/unet.py +++ b/ml_mdm/models/unet.py @@ -115,7 +115,7 @@ class UNetConfig: temporal_spatial_ds: bool = field(default=False) temporal_positional_encoding: bool = field(default=False) resnet_config: ResNetConfig = field( - default=ResNetConfig(), metadata={"help": "Resnet configs"} + default_factory=ResNetConfig, metadata={"help": "Resnet configs"} ) def __post_init__(self):