Skip to content

Commit

Permalink
Merge pull request #256 from Fatal1ty/new-union-unpacker
Browse files Browse the repository at this point in the history
Make Union deserialization algorithm more robust
  • Loading branch information
Fatal1ty authored Nov 12, 2024
2 parents 3436b13 + 92c917c commit ad41838
Show file tree
Hide file tree
Showing 8 changed files with 233 additions and 62 deletions.
6 changes: 1 addition & 5 deletions benchmark/libs/mashumaro/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,17 +6,13 @@
import pyperf

from benchmark.common import AbstractBenchmark
from mashumaro import field_options, pass_through
from mashumaro import field_options
from mashumaro.codecs import BasicDecoder, BasicEncoder
from mashumaro.dialect import Dialect


class DefaultDialect(Dialect):
serialize_by_alias = True
serialization_strategy = {
str: {"deserialize": str, "serialize": pass_through},
int: {"serialize": pass_through},
}


class IssueState(Enum):
Expand Down
2 changes: 1 addition & 1 deletion mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,7 +487,7 @@ def _add_unpack_method_lines(self, method_name: str) -> None:
self.add_type_modules(ftype)
metadata = self.metadatas.get(fname, {})
field_block = FieldUnpackerCodeBlockBuilder(
self, self.lines.branch_off()
self, CodeLines()
).build(
fname=fname,
ftype=ftype,
Expand Down
8 changes: 2 additions & 6 deletions mashumaro/core/meta/code/lines.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,8 @@ def append(self, line: str) -> None:
self._lines.append(f"{self._current_indent}{line}")

def extend(self, lines: "CodeLines") -> None:
self._lines.extend(lines._lines)
for line in lines._lines:
self._lines.append(f"{self._current_indent}{line}")

@contextmanager
def indent(
Expand All @@ -34,8 +35,3 @@ def as_text(self) -> str:
def reset(self) -> None:
self._lines = []
self._current_indent = ""

def branch_off(self) -> "CodeLines":
branch = CodeLines()
branch._current_indent = self._current_indent
return branch
7 changes: 6 additions & 1 deletion mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
Sequence,
Type,
TypeVar,
Union,
)

from typing_extensions import ParamSpec, TypeAlias
Expand All @@ -39,8 +40,12 @@
CodeBuilder = Any


class TypeMatchEligibleExpression(str):
pass


NoneType = type(None)
Expression: TypeAlias = str
Expression: TypeAlias = Union[str, TypeMatchEligibleExpression]

P = ParamSpec("P")
T = TypeVar("T")
Expand Down
74 changes: 55 additions & 19 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@
ExpressionWrapper,
NoneType,
Registry,
TypeMatchEligibleExpression,
ValueSpec,
clean_id,
ensure_generic_collection,
Expand Down Expand Up @@ -170,28 +171,52 @@ def get_method_prefix(self) -> str:
return "union"

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
ambiguous_unpacker_types = []
orig_lines = lines
lines = CodeLines()
unpackers = set()
fallback_unpackers = []
type_arg_unpackers = []
type_match_statements = 0
for type_arg in self.union_args:
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
)
if type_arg in (bool, str) and unpacker == "value":
ambiguous_unpacker_types.append(type_arg)
if unpacker in unpackers:
type_arg_unpackers.append((type_arg, unpacker))
if isinstance(unpacker, TypeMatchEligibleExpression):
type_match_statements += 1
for type_arg, unpacker in type_arg_unpackers:
condition = ""
do_try = unpacker != "value"
unpacker_block = CodeLines()
if isinstance(unpacker, TypeMatchEligibleExpression):
do_try = False
if type_match_statements > 1:
condition = f"__value_type is {type_arg.__name__}"
else:
condition = f"type(value) is {type_arg.__name__}"
if (condition, unpacker) in unpackers: # pragma: no cover
# we shouldn't be here because condition is always unique
continue
with unpacker_block.indent(f"if {condition}:"):
unpacker_block.append("return value")
if (condition, unpacker) not in unpackers:
fallback_unpackers.append(unpacker)
elif (condition, unpacker) in unpackers:
continue
else:
unpacker_block.append(f"return {unpacker}")

if do_try:
with lines.indent("try:"):
lines.extend(unpacker_block)
lines.append("except Exception: pass")
else:
lines.extend(unpacker_block)
unpackers.add((condition, unpacker))
for fallback_unpacker in fallback_unpackers:
with lines.indent("try:"):
lines.append(f"return {unpacker}")
lines.append(f"return {fallback_unpacker}")
lines.append("except Exception: pass")
unpackers.add(unpacker)
# if len(ambiguous_unpacker_types) >= 2:
# warnings.warn(
# f"{type_name(spec.builder.cls)}.{spec.field_ctx.name} "
# f"({type_name(spec.type)}): "
# "In the next release, data marked with Union type "
# "containing 'str' and 'bool' will be coerced to the value "
# "of the type specified first instead of passing it as is"
# )
field_type = spec.builder.get_type_name_identifier(
typ=spec.type,
resolved_type_params=spec.builder.get_field_resolved_type_params(
Expand All @@ -205,6 +230,9 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
)
else:
lines.append("raise ValueError(value)")
if type_match_statements > 1:
orig_lines.append("__value_type = type(value)")
orig_lines.extend(lines)


class TypeVarUnpackerBuilder(UnionUnpackerBuilder):
Expand Down Expand Up @@ -843,13 +871,21 @@ def unpack_special_typing_primitive(spec: ValueSpec) -> Optional[Expression]:
@register
def unpack_number(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (int, float):
return f"{type_name(spec.origin_type)}({spec.expression})"
return TypeMatchEligibleExpression(
f"{type_name(spec.origin_type)}({spec.expression})"
)


@register
def unpack_bool_and_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (bool, NoneType, None):
return spec.expression
def unpack_bool(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type is bool:
return TypeMatchEligibleExpression(f"bool({spec.expression})")


@register
def unpack_none(spec: ValueSpec) -> Optional[Expression]:
if spec.origin_type in (NoneType, None):
return TypeMatchEligibleExpression("None")


@register
Expand Down Expand Up @@ -1199,7 +1235,7 @@ def inner_expr(
spec.builder.ensure_object_imported(decodebytes)
return f"bytearray(decodebytes({spec.expression}.encode()))"
elif issubclass(spec.origin_type, str):
return spec.expression
return TypeMatchEligibleExpression(f"str({spec.expression})")
elif ensure_generic_collection_subclass(spec, List):
return f"[{inner_expr()} for value in {spec.expression}]"
elif ensure_generic_collection_subclass(spec, typing.Deque):
Expand Down
86 changes: 68 additions & 18 deletions tests/test_aliases.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,12 @@
from dataclasses import dataclass, field
from email.policy import default
from importlib.metadata import metadata
from typing import Optional

import pytest
from typing_extensions import Annotated

from mashumaro import DataClassDictMixin
from mashumaro import DataClassDictMixin, pass_through
from mashumaro.config import (
TO_DICT_ADD_BY_ALIAS_FLAG,
TO_DICT_ADD_OMIT_NONE_FLAG,
Expand Down Expand Up @@ -233,34 +235,74 @@ class DataClass(DataClassDictMixin):
b2: Annotated[Optional[int], Alias("alias_b2")]
c1: Optional[str] = field(metadata={"alias": "alias_c1"})
c2: Annotated[Optional[str], Alias("alias_c2")]
d1: int = field(metadata={"alias": "alias_d1"}, default=4)
d2: Annotated[int, Alias("alias_d2")] = 4
e1: Optional[int] = field(metadata={"alias": "alias_e1"}, default=5)
e2: Annotated[Optional[int], Alias("alias_e2")] = 5
f1: Optional[str] = field(metadata={"alias": "alias_f1"}, default="6")
f2: Annotated[Optional[str], Alias("alias_f2")] = "6"
d1: int = field(
metadata={"alias": "alias_d1", "deserialize": pass_through}
)
d2: Annotated[int, Alias("alias_d2")] = field(
metadata={"deserialize": pass_through}
)
e1: int = field(metadata={"alias": "alias_e1"}, default=5)
e2: Annotated[int, Alias("alias_e2")] = 5
f1: Optional[int] = field(metadata={"alias": "alias_f1"}, default=6)
f2: Annotated[Optional[int], Alias("alias_f2")] = 6
g1: Optional[str] = field(metadata={"alias": "alias_g1"}, default="7")
g2: Annotated[Optional[str], Alias("alias_g2")] = "7"
h1: int = field(
metadata={"alias": "alias_h1", "deserialize": pass_through},
default=8,
)
h2: Annotated[int, Alias("alias_h2")] = field(
metadata={"deserialize": pass_through}, default=8
)

class Config(BaseConfig):
serialize_by_alias = True
code_generation_options = [TO_DICT_ADD_BY_ALIAS_FLAG]
allow_deserialization_not_by_alias = True

instance = DataClass(a1=1, a2=1, b1=2, b2=2, c1="3", c2="3")
instance = DataClass(a1=1, a2=1, b1=2, b2=2, c1="3", c2="3", d1=4, d2=4)
assert (
DataClass.from_dict(
{
"a1": 1,
"a2": 1,
"alias_b1": 2,
"alias_b2": 2,
"b1": 2,
"b2": 2,
"c1": "3",
"c2": "3",
"alias_d1": 4,
"alias_d2": 4,
"d1": 4,
"d2": 4,
"e1": 5,
"e2": 5,
"alias_f1": "6",
"alias_f2": "6",
"f1": 6,
"f2": 6,
"g1": "7",
"g2": "7",
"h1": 8,
"h2": 8,
}
)
== instance
)
assert (
DataClass.from_dict(
{
"alias_a1": 1,
"alias_a2": 1,
"alias_b1": 2,
"alias_b2": 2,
"alias_c1": "3",
"alias_c2": "3",
"alias_d1": 4,
"alias_d2": 4,
"alias_e1": 5,
"alias_e2": 5,
"alias_f1": 6,
"alias_f2": 6,
"alias_g1": "7",
"alias_g2": "7",
"alias_h1": 8,
"alias_h2": 8,
}
)
== instance
Expand All @@ -276,8 +318,12 @@ class Config(BaseConfig):
"alias_d2": 4,
"alias_e1": 5,
"alias_e2": 5,
"alias_f1": "6",
"alias_f2": "6",
"alias_f1": 6,
"alias_f2": 6,
"alias_g1": "7",
"alias_g2": "7",
"alias_h1": 8,
"alias_h2": 8,
}
assert instance.to_dict(by_alias=False) == {
"a1": 1,
Expand All @@ -290,8 +336,12 @@ class Config(BaseConfig):
"d2": 4,
"e1": 5,
"e2": 5,
"f1": "6",
"f2": "6",
"f1": 6,
"f2": 6,
"g1": "7",
"g2": "7",
"h1": 8,
"h2": 8,
}


Expand Down
1 change: 1 addition & 0 deletions tests/test_data_types.py
Original file line number Diff line number Diff line change
Expand Up @@ -1304,6 +1304,7 @@ class DataClass(DataClassDictMixin):

obj = DataClass(x=None, y=None, z=[None])
assert DataClass.from_dict({"x": None, "y": None, "z": [None]}) == obj
assert DataClass.from_dict({"x": 42, "y": "foo", "z": ["bar"]}) == obj
assert obj.to_dict() == {"x": None, "y": None, "z": [None]}


Expand Down
Loading

0 comments on commit ad41838

Please sign in to comment.