Skip to content

Commit

Permalink
Merge pull request #23 from apple/fix/default-factory
Browse files Browse the repository at this point in the history
Use default_factory for mutable fields
  • Loading branch information
luke-carlson authored Nov 7, 2024
2 parents 32d5b01 + 1a5e598 commit 17e36c6
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 7 deletions.
2 changes: 1 addition & 1 deletion ml_mdm/diffusion.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ def sv(x, f):
@dataclass
class DiffusionConfig:
sampler_config: samplers.SamplerConfig = field(
default=samplers.SamplerConfig(), metadata={"help": "Sampler configuration"}
default_factory=samplers.SamplerConfig, metadata={"help": "Sampler configuration"}
)
model_output_scale: float = field(
default=0,
Expand Down
10 changes: 5 additions & 5 deletions ml_mdm/models/nested_unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
@dataclass
class NestedUNetConfig(UNetConfig):
inner_config: UNetConfig = field(
default=UNetConfig(nesting=True),
default_factory=lambda: UNetConfig(nesting=True),
metadata={"help": "inner unet used as middle blocks"},
)
skip_mid_blocks: bool = field(default=True)
Expand Down Expand Up @@ -55,23 +55,23 @@ class NestedUNetConfig(UNetConfig):
@dataclass
class Nested2UNetConfig(NestedUNetConfig):
inner_config: NestedUNetConfig = field(
default=NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None)
default_factory=lambda: NestedUNetConfig(nesting=True, initialize_inner_with_pretrained=None)
)


@config.register_model_config("nested3_unet", "nested_unet")
@dataclass
class Nested3UNetConfig(Nested2UNetConfig):
inner_config: Nested2UNetConfig = field(
default=Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None)
default_factory=lambda: Nested2UNetConfig(nesting=True, initialize_inner_with_pretrained=None)
)


@config.register_model_config("nested4_unet", "nested_unet")
@dataclass
class Nested4UNetConfig(Nested3UNetConfig):
inner_config: Nested3UNetConfig = field(
default=Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None)
default_factory=lambda: Nested3UNetConfig(nesting=True, initialize_inner_with_pretrained=None)
)


Expand All @@ -96,7 +96,7 @@ def download(vision_model_path):
@config.register_model("nested_unet")
class NestedUNet(UNet):
def __init__(self, input_channels, output_channels, config: NestedUNetConfig):
super().__init__(input_channels, output_channels, config)
super().__init__(input_channels, output_channels=output_channels, config=config)
config.inner_config.conditioning_feature_dim = config.conditioning_feature_dim
if getattr(config.inner_config, "inner_config", None) is None:
self.inner_unet = UNet(input_channels, output_channels, config.inner_config)
Expand Down
2 changes: 1 addition & 1 deletion ml_mdm/models/unet.py
Original file line number Diff line number Diff line change
Expand Up @@ -115,7 +115,7 @@ class UNetConfig:
temporal_spatial_ds: bool = field(default=False)
temporal_positional_encoding: bool = field(default=False)
resnet_config: ResNetConfig = field(
default=ResNetConfig(), metadata={"help": "Resnet configs"}
default_factory=ResNetConfig, metadata={"help": "Resnet configs"}
)

def __post_init__(self):
Expand Down

0 comments on commit 17e36c6

Please sign in to comment.