Skip to content

Commit

Permalink
Fix JSON Schema override for optional fields and dataclasses
Browse files (browse the repository at this point in the history)
  • Loading branch information
Fatal1ty committed Aug 20, 2023
1 parent c0a4cf8 commit c886b01
Show file tree
Hide file tree
Showing 3 changed files with 183 additions and 72 deletions.
170 changes: 104 additions & 66 deletions mashumaro/jsonschema/schema.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -14,7 +14,6 @@
Dict,
Iterable,
List,
Mapping,
Optional,
Tuple,
Type,
Expand Down Expand Up @@ -104,36 +103,50 @@ class Instance:
name: Optional[str] = None
origin_type: Type = field(init=False)
annotations: List[Annotation] = field(init=False, default_factory=list)
__builder: Optional[CodeBuilder] = None
__metadata: Optional[Dict[str, Any]] = None
__owner_builder: Optional[CodeBuilder] = None
__self_builder: Optional[CodeBuilder] = None

@property
def _builder(self) -> CodeBuilder:
assert self.__builder
return self.__builder
def metadata(self) -> Dict[str, Any]:
if self.__metadata is None:
if not self.name:
self.__metadata = {}
elif self.__owner_builder:
self.__metadata = dict(
**self.__owner_builder.metadatas.get(self.name, {})
)
else:
self.__metadata = {}
return self.__metadata

@property
def metadata(self) -> Mapping[str, Any]:
if not self.name:
return {}
return self._builder.metadatas.get(self.name, {})
def _self_builder(self) -> CodeBuilder:
assert self.__self_builder
return self.__self_builder

@property
def alias(self) -> Optional[str]:
alias = self.metadata.get("alias")
if alias is None:
alias = self.get_config().aliases.get(self.name) # type: ignore
aliases_config = self.get_owner_config().aliases
alias = aliases_config.get(self.name) # type: ignore
if alias is None:
alias = self.name
return alias

@property
def holder_class(self) -> Optional[Type]:
if self.__builder:
return self.__builder.cls
def owner_class(self) -> Optional[Type]:
if self.__owner_builder:
return self.__owner_builder.cls
return None

def copy(self, **changes: Any) -> "Instance":
return replace(self, **changes)
def derive(self, **changes: Any) -> "Instance":
self_builder = self.__self_builder
new_instance = replace(self, **changes)
if self_builder:
new_instance.__owner_builder = self.__self_builder
return new_instance

def __post_init__(self) -> None:
self.update_type(self.type)
Expand All @@ -143,27 +156,29 @@ def __post_init__(self) -> None:
self.origin_type = get_type_origin(self.type)

def update_type(self, new_type: Type) -> None:
if self.__builder:
self.type = self.__builder._get_real_type(
if self.__owner_builder:
self.type = self.__owner_builder._get_real_type(
field_name=self.name, # type: ignore
field_type=new_type,
)
self.origin_type = get_type_origin(self.type)
if is_dataclass(self.origin_type):
type_args = get_args(self.type)
self.__builder = CodeBuilder(self.origin_type, type_args)
self.__builder.reset()
self.__self_builder = CodeBuilder(self.origin_type, type_args)
self.__self_builder.reset()
else:
self.__self_builder = None

def fields(self) -> Iterable[Tuple[str, Type, bool, Any]]:
for f_name, f_type in self._builder.get_field_types(
for f_name, f_type in self._self_builder.get_field_types(
include_extras=True
).items():
f = self._builder.dataclass_fields.get(f_name) # type: ignore
f = self._self_builder.dataclass_fields.get(f_name) # type: ignore
if f and not f.init:
continue
f_default = f.default
if f_default is MISSING:
f_default = self._builder.namespace.get(f_name, MISSING)
f_default = self._self_builder.namespace.get(f_name, MISSING)
if f_default is not MISSING:
f_default = _default(f_type, f_default)

Expand All @@ -176,12 +191,14 @@ def fields(self) -> Iterable[Tuple[str, Type, bool, Any]]:
def get_overridden_serialization_method(
self,
) -> Optional[Union[Callable, str]]:
if not self.__builder:
if not self.__owner_builder:
return None
serialize_option = self.metadata.get("serialize")
if serialize_option is not None:
if callable(serialize_option):
self.metadata.pop("serialize", None) # prevent recursion
return serialize_option
for strategy in self.__builder.iter_serialization_strategies(
for strategy in self.__owner_builder.iter_serialization_strategies(
self.metadata, self.type
):
if strategy is pass_through:
Expand All @@ -194,9 +211,15 @@ def get_overridden_serialization_method(
return serialize_option
return None

def get_config(self) -> Type[BaseConfig]:
if self.__builder:
return self.__builder.get_config()
def get_owner_config(self) -> Type[BaseConfig]:
if self.__owner_builder:
return self.__owner_builder.get_config()
else:
return BaseConfig

def get_self_config(self) -> Type[BaseConfig]:
if self.__self_builder:
return self.__self_builder.get_config()
else:
return BaseConfig

Expand Down Expand Up @@ -235,7 +258,7 @@ def get_schema(
raise NotImplementedError(
(
f'Type {type_name(instance.type)} of field "{instance.name}" '
f"in {type_name(instance.holder_class)} isn't supported"
f"in {type_name(instance.owner_class)} isn't supported"
)
)

Expand All @@ -261,31 +284,41 @@ class CC(DataClassJSONMixin):
register = Registry.register


def override_field_instance_type_if_needed(
root_instance: Instance, field_instance: Instance
) -> None:
overridden_method = field_instance.get_overridden_serialization_method()
@register
def on_type_with_overridden_deserialization(
instance: Instance, ctx: Context
) -> Optional[JSONSchema]:
def override_with_any(reason: Any) -> None:
if instance.owner_class is not None:
name = f"{type_name(instance.owner_class)}.{instance.name}"
else:
name = type_name(instance.type)
warnings.warn(
f"Type Any will be used for {name} with "
f"overridden serialization method: {reason}"
)
instance.update_type(Any) # type: ignore[arg-type]

overridden_method = instance.get_overridden_serialization_method()
if overridden_method is pass_through:
return
return None
elif callable(overridden_method):
try:
field_instance.update_type(
get_function_return_annotation(overridden_method)
)
new_type = get_function_return_annotation(overridden_method)
if new_type is instance.type:
return None
else:
instance.update_type(new_type)
except Exception as e:
warnings.warn(
f"Type Any will be used for "
f"{type_name(root_instance.type)}.{field_instance.name} with "
f"overridden serialization method: {e}"
)
field_instance.update_type(Any) # type: ignore[arg-type]
override_with_any(e)
return get_schema(instance, ctx)


@register
def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
# TODO: Self references might not work
if is_dataclass(instance.origin_type):
jsonschema_config = instance.get_config().json_schema
jsonschema_config = instance.get_self_config().json_schema
schema = JSONObjectSchema(
title=instance.origin_type.__name__,
additionalProperties=jsonschema_config.get(
Expand All @@ -297,11 +330,10 @@ def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
field_schema_overrides = jsonschema_config.get("properties", {})
for f_name, f_type, has_default, f_default in instance.fields():
override = field_schema_overrides.get(f_name)
f_instance = instance.copy(type=f_type, name=f_name)
f_instance = instance.derive(type=f_type, name=f_name)
if override:
f_schema = JSONSchema.from_dict(override)
else:
override_field_instance_type_if_needed(instance, f_instance)
f_schema = get_schema(f_instance, ctx)
if f_instance.alias:
f_name = f_instance.alias
Expand Down Expand Up @@ -361,7 +393,7 @@ def on_special_typing_primitive(

if is_union(instance.type):
return JSONSchema(
anyOf=[get_schema(instance.copy(type=arg), ctx) for arg in args]
anyOf=[get_schema(instance.derive(type=arg), ctx) for arg in args]
)
elif is_type_var_any(instance.type):
return EmptyJSONSchema()
Expand All @@ -370,25 +402,29 @@ def on_special_typing_primitive(
if constraints:
return JSONSchema(
anyOf=[
get_schema(instance.copy(type=arg), ctx)
get_schema(instance.derive(type=arg), ctx)
for arg in constraints
]
)
else:
bound = getattr(instance.type, "__bound__")
return get_schema(instance.copy(type=bound), ctx)
return get_schema(instance.derive(type=bound), ctx)
elif is_new_type(instance.type):
return get_schema(instance.copy(type=instance.type.__supertype__), ctx)
return get_schema(
instance.derive(type=instance.type.__supertype__), ctx
)
elif is_literal(instance.type):
return on_literal(instance, ctx)
# elif is_self(instance.type):
# raise NotImplementedError
elif is_required(instance.type) or is_not_required(instance.type):
return get_schema(instance.copy(type=args[0]), ctx)
return get_schema(instance.derive(type=args[0]), ctx)
elif is_unpack(instance.type):
return get_schema(instance.copy(type=get_args(instance.type)[0]), ctx)
return get_schema(
instance.derive(type=get_args(instance.type)[0]), ctx
)
elif is_type_var_tuple(instance.type):
return get_schema(instance.copy(type=Tuple[Any, ...]), ctx)
return get_schema(instance.derive(type=Tuple[Any, ...]), ctx)


@register
Expand Down Expand Up @@ -518,7 +554,7 @@ def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema:
if not PY_311_MIN:
return JSONArraySchema(maxItems=0)
if len(args) == 2 and args[1] is Ellipsis:
items_schema = _get_schema_or_none(instance.copy(type=args[0]), ctx)
items_schema = _get_schema_or_none(instance.derive(type=args[0]), ctx)
return JSONArraySchema(items=items_schema)
else:
min_items: Optional[int] = 0
Expand All @@ -532,10 +568,10 @@ def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema:
min_items += 1 # type: ignore
if not unpack_schema:
prefix_items.append(
get_schema(instance.copy(type=arg), ctx)
get_schema(instance.derive(type=arg), ctx)
)
else:
unpack_schema = get_schema(instance.copy(type=arg), ctx)
unpack_schema = get_schema(instance.derive(type=arg), ctx)
unpack_idx = arg_idx
if unpack_schema:
prefix_items.extend(unpack_schema.prefixItems or [])
Expand Down Expand Up @@ -566,7 +602,7 @@ def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema:
}
fields = getattr(instance.type, "_fields", ())
defaults = getattr(instance.type, "_field_defaults", {})
as_dict = instance.get_config().namedtuple_as_dict
as_dict = instance.get_owner_config().namedtuple_as_dict
serialize_option = instance.get_overridden_serialization_method()
if serialize_option == "as_dict":
as_dict = True
Expand All @@ -575,7 +611,7 @@ def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema:
properties = {}
for f_name in fields:
f_type = annotations.get(f_name, typing.Any)
f_schema = get_schema(instance.copy(type=f_type), ctx)
f_schema = get_schema(instance.derive(type=f_type), ctx)
f_default = defaults.get(f_name, MISSING)
if f_default is not MISSING:
if isinstance(f_schema, EmptyJSONSchema):
Expand Down Expand Up @@ -608,7 +644,7 @@ def on_typed_dict(instance: Instance, ctx: Context) -> JSONObjectSchema:
required_keys = getattr(instance.type, "__required_keys__", all_keys)
return JSONObjectSchema(
properties={
key: get_schema(instance.copy(type=annotations[key]), ctx)
key: get_schema(instance.derive(type=annotations[key]), ctx)
for key in all_keys
}
or None,
Expand Down Expand Up @@ -689,7 +725,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
return apply_array_constraints(
instance,
JSONArraySchema(
items=_get_schema_or_none(instance.copy(type=args[0]), ctx)
items=_get_schema_or_none(instance.derive(type=args[0]), ctx)
if args
else None
),
Expand All @@ -707,7 +743,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
return apply_array_constraints(
instance,
JSONArraySchema(
items=_get_schema_or_none(instance.copy(type=args[0]), ctx)
items=_get_schema_or_none(instance.derive(type=args[0]), ctx)
if args
else None,
uniqueItems=True,
Expand All @@ -720,7 +756,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
instance,
JSONArraySchema(
items=get_schema(
instance=instance.copy(
instance=instance.derive(
type=(
Dict[args[0], args[1]] # type: ignore
if args
Expand All @@ -735,11 +771,11 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
instance.origin_type, typing.Counter
):
schema = JSONObjectSchema(
additionalProperties=get_schema(instance.copy(type=int), ctx),
additionalProperties=get_schema(instance.derive(type=int), ctx),
)
if args:
schema.propertyNames = _get_schema_or_none(
instance.copy(type=args[0]), ctx
instance.derive(type=args[0]), ctx
)
return apply_object_constraints(instance, schema)
elif is_typed_dict(instance.origin_type):
Expand All @@ -749,11 +785,13 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
):
schema = JSONObjectSchema(
additionalProperties=_get_schema_or_none(
instance.copy(type=args[1]), ctx
instance.derive(type=args[1]), ctx
)
if args
else None,
propertyNames=_get_schema_or_none(instance.copy(type=args[0]), ctx)
propertyNames=_get_schema_or_none(
instance.derive(type=args[0]), ctx
)
if args
else None,
)
Expand All @@ -764,7 +802,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]:
return apply_array_constraints(
instance,
JSONArraySchema(
items=_get_schema_or_none(instance.copy(type=args[0]), ctx)
items=_get_schema_or_none(instance.derive(type=args[0]), ctx)
if args
else None
),
Expand Down
4 changes: 2 additions & 2 deletions tests/test_jsonschema/test_jsonschema_builder.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -27,10 +27,10 @@ class DataClass:

instance = Instance(int)
assert instance.metadata == {}
assert instance.holder_class is None
assert instance.owner_class is None

instance = Instance(DataClass)
assert instance.holder_class is DataClass
assert instance.owner_class is None
assert instance.metadata == {}


Expand Down
Loading

0 comments on commit c886b01

Please sign in to comment.