From f927083203404e1a6b48a49c6f0d52a88e30ca4a Mon Sep 17 00:00:00 2001 From: mishamsk Date: Wed, 3 Jul 2024 20:38:59 -0400 Subject: [PATCH] poc: array_like (de)serialization for dataclasses --- mashumaro/config.py | 1 + mashumaro/core/meta/code/builder.py | 48 +++++++++++++++++++++++------ tests/test_config.py | 29 +++++++++++++++++ 3 files changed, 68 insertions(+), 10 deletions(-) diff --git a/mashumaro/config.py b/mashumaro/config.py index f69cf880..dc786f95 100644 --- a/mashumaro/config.py +++ b/mashumaro/config.py @@ -70,3 +70,4 @@ class BaseConfig: sort_keys: bool = False allow_deserialization_not_by_alias: bool = False forbid_extra_keys: bool = False + array_like: bool = False diff --git a/mashumaro/core/meta/code/builder.py b/mashumaro/core/meta/code/builder.py index 41b028fb..7abf8b40 100644 --- a/mashumaro/core/meta/code/builder.py +++ b/mashumaro/core/meta/code/builder.py @@ -8,8 +8,12 @@ from contextlib import contextmanager # noinspection PyProtectedMember -from dataclasses import _FIELDS # type: ignore -from dataclasses import MISSING, Field, is_dataclass +from dataclasses import ( + _FIELDS, # type: ignore + MISSING, + Field, + is_dataclass, +) from functools import lru_cache try: @@ -483,7 +487,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None: ) with self.indent("try:"): - for fname, alias, ftype in filtered_fields: + for i, (fname, alias, ftype) in enumerate(filtered_fields): self.add_type_modules(ftype) metadata = self.metadatas.get(fname, {}) field_block = FieldUnpackerCodeBlockBuilder( @@ -491,6 +495,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None: ).build( fname=fname, ftype=ftype, + forder=i + int(discr is not None), metadata=metadata, alias=alias, ) @@ -862,6 +867,7 @@ def _add_pack_method_lines(self, method_name: str) -> None: serialize_by_alias = self.get_dialect_or_config_option( "serialize_by_alias", False ) + array_like = self.get_dialect_or_config_option("array_like", False) omit_none = self.get_dialect_or_config_option("omit_none", False) omit_default = self.get_dialect_or_config_option( "omit_default", False @@ -983,8 +989,17 @@ def _add_pack_method_lines(self, method_name: str) -> None: packer if packer != "value" else f"self.{fname}", ) ) - kwargs = ", ".join(f"'{k}': {v}" for k, v in kwargs_parts) - kwargs = f"{{{kwargs}}}" + + if not array_like: + kwargs = ", ".join(f"'{k}': {v}" for k, v in kwargs_parts) + kwargs = f"{{{kwargs}}}" + else: + kwargs = ", ".join(f"{v}" for _, v in kwargs_parts) + + if len(kwargs_parts) == 1: + kwargs = f"{kwargs}," + + kwargs = f"({kwargs})" post_serialize = self.get_declared_hook(__POST_SERIALIZE__) if self.encoder is not None: if self.encoder_kwargs: @@ -1309,6 +1324,7 @@ def build( self, fname: str, ftype: typing.Type, + forder: int, metadata: typing.Mapping, *, alias: typing.Optional[str] = None, @@ -1341,6 +1357,9 @@ def build( could_be_none=False if could_be_none else True, ) ) + + array_like = self.parent.get_config().array_like + if self.parent.get_config().allow_deserialization_not_by_alias: if unpacked_value != "value": self.add_line(f"value = d.get('{alias}', MISSING)") @@ -1360,15 +1379,24 @@ def build( unpacked_value = packed_value else: if unpacked_value != "value": - self.add_line(f"value = d.get('{alias or fname}', MISSING)") + if array_like: + self.add_line(f"value = d[{forder}]") + else: + self.add_line(f"value = d.get('{alias or fname}', MISSING)") packed_value = "value" elif has_default: - self.add_line(f"value = d.get('{alias or fname}', MISSING)") + if array_like: + self.add_line(f"value = d[{forder}]") + else: + self.add_line(f"value = d.get('{alias or fname}', MISSING)") packed_value = "value" else: - self.add_line( - f"__{fname} = d.get('{alias or fname}', MISSING)" - ) + if array_like: + self.add_line(f"__{fname} = d[{forder}]") + else: + self.add_line( + f"__{fname} = d.get('{alias or fname}', MISSING)" + ) packed_value = f"__{fname}" unpacked_value = packed_value if not has_default: diff --git a/tests/test_config.py b/tests/test_config.py index 4a20cb06..03696581 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -386,6 +386,35 @@ class Config(BaseConfig): assert exc_info.value.extra_keys == {"baz"} assert exc_info.value.target_type == ForbidKeysModel +def test_array_like(): + @dataclass + class FooModel(DataClassDictMixin): + foo: int + + class Config(BaseConfig): + array_like = True + + # Test packing works + assert FooModel(1).to_dict() == (1,) + + # Test unpacking works + assert FooModel.from_dict((1,)) == FooModel(1) + + # Nested + @dataclass + class BarModel(DataClassDictMixin): + bar: str + inner: FooModel + + class Config(BaseConfig): + array_like = True + + # Test packing works + assert BarModel("bar", FooModel(1)).to_dict() == ("bar", (1,)) + + # Test unpacking works + assert BarModel.from_dict(("bar", (1,))) == BarModel("bar", FooModel(1)) + @dataclass class _VariantByBase(DataClassDictMixin):