Skip to content

Commit

Permalink
Improve value type checks in union unpacker
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Nov 12, 2024
1 parent dcbdc4a commit 92c917c
Showing 1 changed file with 16 additions and 2 deletions.
18 changes: 16 additions & 2 deletions mashumaro/core/meta/types/unpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -171,18 +171,29 @@ def get_method_prefix(self) -> str:
return "union"

def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
orig_lines = lines
lines = CodeLines()
unpackers = set()
fallback_unpackers = []
type_arg_unpackers = []
type_match_statements = 0
for type_arg in self.union_args:
condition = ""
unpacker = UnpackerRegistry.get(
spec.copy(type=type_arg, expression="value")
)
type_arg_unpackers.append((type_arg, unpacker))
if isinstance(unpacker, TypeMatchEligibleExpression):
type_match_statements += 1
for type_arg, unpacker in type_arg_unpackers:
condition = ""
do_try = unpacker != "value"
unpacker_block = CodeLines()
if isinstance(unpacker, TypeMatchEligibleExpression):
do_try = False
condition = f"type(value) is {type_arg.__name__}"
if type_match_statements > 1:
condition = f"__value_type is {type_arg.__name__}"
else:
condition = f"type(value) is {type_arg.__name__}"
if (condition, unpacker) in unpackers: # pragma: no cover
# we shouldn't be here because condition is always unique
continue
Expand Down Expand Up @@ -219,6 +230,9 @@ def _add_body(self, spec: ValueSpec, lines: CodeLines) -> None:
)
else:
lines.append("raise ValueError(value)")
if type_match_statements > 1:
orig_lines.append("__value_type = type(value)")
orig_lines.extend(lines)


class TypeVarUnpackerBuilder(UnionUnpackerBuilder):
Expand Down

0 comments on commit 92c917c

Please sign in to comment.