Skip to content

Commit

Permalink
Add no_copy dialect option
Browse files Browse the repository at this point in the history
  • Loading branch information
Fatal1ty committed Sep 8, 2023
1 parent 711705f commit 13ba405
Show file tree
Hide file tree
Showing 8 changed files with 86 additions and 17 deletions.
1 change: 1 addition & 0 deletions mashumaro/core/meta/code/builder.py
Original file line number Diff line number Diff line change
Expand Up @@ -1161,6 +1161,7 @@ def _get_field_packer(
metadata=metadata,
),
could_be_none=False,
no_copy=self._get_dialect_or_config_option("no_copy", False),
)
)
return packer, alias, could_be_none
Expand Down
1 change: 1 addition & 0 deletions mashumaro/core/meta/types/common.py
Original file line number Diff line number Diff line change
Expand Up @@ -78,6 +78,7 @@ class ValueSpec:
could_be_none: bool = True
annotated_type: Optional[Type] = None
owner: Optional[Type] = None
no_copy: bool = False

def __setattr__(self, key: str, value: Any) -> None:
if key == "type":
Expand Down
56 changes: 40 additions & 16 deletions mashumaro/core/meta/types/pack.py
Original file line number Diff line number Diff line change
Expand Up @@ -673,6 +673,18 @@ def inner_expr(
)
)

def _check_sequence_pass_through(ie: Expression) -> bool:
return spec.no_copy and ie == "value"

def _make_sequence_expression(ie: Expression) -> Expression:
return f"[{ie} for value in {spec.expression}]"

def _check_mapping_pass_through(ke: Expression, ve: Expression) -> bool:
return spec.no_copy and ke == "key" and ve == "value"

def _make_mapping_expression(ke: Expression, ve: Expression) -> Expression:
return f"{{{ke}: {ve} for key, value in {spec.expression}.items()}}"

if issubclass(spec.origin_type, typing.ByteString): # type: ignore
spec.builder.ensure_object_imported(encodebytes)
return f"encodebytes({spec.expression}).decode()"
Expand All @@ -686,32 +698,44 @@ def inner_expr(
elif ensure_generic_collection_subclass(
spec, typing.List, typing.Deque, typing.AbstractSet
):
return f"[{inner_expr()} for value in {spec.expression}]"
ie = inner_expr()
if _check_sequence_pass_through(ie):
return spec.expression
return _make_sequence_expression(ie)
elif ensure_generic_mapping(spec, args, typing.ChainMap):
ke = inner_expr(0, "key")
ve = inner_expr(1)
if _check_mapping_pass_through(ke, ve):
return spec.expression
return (
f'[{{{inner_expr(0, "key")}: {inner_expr(1)} '
f"for key, value in m.items()}} "
f"[{{{ke}: {ve} for key, value in m.items()}} "
f"for m in {spec.expression}.maps]"
)
elif ensure_generic_mapping(spec, args, typing.OrderedDict):
return (
f'{{{inner_expr(0, "key")}: {inner_expr(1)} '
f"for key, value in {spec.expression}.items()}}"
)
ke = inner_expr(0, "key")
ve = inner_expr(1)
if _check_mapping_pass_through(ke, ve):
return spec.expression
return _make_mapping_expression(ke, ve)
elif ensure_generic_mapping(spec, args, typing.Counter):
return (
f'{{{inner_expr(0, "key")}: {inner_expr(1, v_type=int)} '
f"for key, value in {spec.expression}.items()}}"
)
ke = inner_expr(0, "key")
ve = inner_expr(1, v_type=int)
if _check_mapping_pass_through(ke, ve):
return spec.expression
return _make_mapping_expression(ke, ve)
elif is_typed_dict(spec.origin_type):
return pack_typed_dict(spec)
elif ensure_generic_mapping(spec, args, typing.Mapping):
return (
f'{{{inner_expr(0, "key")}: {inner_expr(1)} '
f"for key, value in {spec.expression}.items()}}"
)
ke = inner_expr(0, "key")
ve = inner_expr(1)
if _check_mapping_pass_through(ke, ve):
return spec.expression
return _make_mapping_expression(ke, ve)
elif ensure_generic_collection_subclass(spec, typing.Sequence):
return f"[{inner_expr()} for value in {spec.expression}]"
ie = inner_expr()
if _check_sequence_pass_through(ie):
return spec.expression
return _make_sequence_expression(ie)


@register
Expand Down
1 change: 1 addition & 0 deletions mashumaro/dialect.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,4 @@ class Dialect:
serialization_strategy: Dict[Any, SerializationStrategyValueType] = {}
omit_none: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
omit_default: Union[bool, Literal[Sentinel.MISSING]] = Sentinel.MISSING
no_copy: bool = False
1 change: 1 addition & 0 deletions mashumaro/mixins/msgpack.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@


class MessagePackDialect(Dialect):
no_copy = True
serialization_strategy = {
bytes: pass_through, # type: ignore
bytearray: {
Expand Down
1 change: 1 addition & 0 deletions mashumaro/mixins/orjson.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@


class OrjsonDialect(Dialect):
no_copy = True
serialization_strategy = {
datetime: {"serialize": pass_through},
date: {"serialize": pass_through},
Expand Down
1 change: 1 addition & 0 deletions mashumaro/mixins/toml.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@


class TOMLDialect(Dialect):
no_copy = True
omit_none = True
serialization_strategy = {
datetime: pass_through,
Expand Down
41 changes: 40 additions & 1 deletion tests/test_dialect.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
import collections
import typing
from dataclasses import dataclass, field
from datetime import date, datetime
from typing import (
Expand All @@ -15,7 +17,7 @@
import pytest
from typing_extensions import TypedDict

from mashumaro import DataClassDictMixin
from mashumaro import DataClassDictMixin, pass_through
from mashumaro.config import ADD_DIALECT_SUPPORT, BaseConfig
from mashumaro.dialect import Dialect
from mashumaro.exceptions import BadDialect
Expand Down Expand Up @@ -1161,3 +1163,40 @@ def test_dataclass_omit_default_dialects():
DataClassWithDefaultAndDialectSupport().to_dict(dialect=EmptyDialect)
== complete_dict
)


def test_dialect_no_copy():
class NoCopyDialect(Dialect):
no_copy = True
serialization_strategy = {int: {"serialize": pass_through}}

@dataclass
class DataClass(DataClassDictMixin):
a: List[str]
b: Set[str]
c: typing.ChainMap[str, str]
d: typing.OrderedDict[str, str]
e: typing.Counter[str]
f: typing.Dict[str, str]
g: typing.Sequence[str]

class Config(BaseConfig):
dialect = NoCopyDialect

obj = DataClass(
a=["foo"],
b={"foo"},
c=collections.ChainMap({"foo": "bar"}),
d=collections.OrderedDict({"foo": "bar"}),
e=collections.Counter({"foo": 1}),
f={"foo": "bar"},
g=["foo"],
)
data = obj.to_dict()
assert data["a"] is obj.a
assert data["b"] is obj.b
assert data["c"] is obj.c
assert data["d"] is obj.d
assert data["e"] is obj.e
assert data["f"] is obj.f
assert data["g"] is obj.g

0 comments on commit 13ba405

Please sign in to comment.