diff --git a/mashumaro/jsonschema/schema.py b/mashumaro/jsonschema/schema.py index 00bfe696..6024846b 100644 --- a/mashumaro/jsonschema/schema.py +++ b/mashumaro/jsonschema/schema.py @@ -14,7 +14,6 @@ Dict, Iterable, List, - Mapping, Optional, Tuple, Type, @@ -104,36 +103,50 @@ class Instance: name: Optional[str] = None origin_type: Type = field(init=False) annotations: List[Annotation] = field(init=False, default_factory=list) - __builder: Optional[CodeBuilder] = None + __metadata: Optional[Dict[str, Any]] = None + __owner_builder: Optional[CodeBuilder] = None + __self_builder: Optional[CodeBuilder] = None @property - def _builder(self) -> CodeBuilder: - assert self.__builder - return self.__builder + def metadata(self) -> Dict[str, Any]: + if self.__metadata is None: + if not self.name: + self.__metadata = {} + elif self.__owner_builder: + self.__metadata = dict( + **self.__owner_builder.metadatas.get(self.name, {}) + ) + else: + self.__metadata = {} + return self.__metadata @property - def metadata(self) -> Mapping[str, Any]: - if not self.name: - return {} - return self._builder.metadatas.get(self.name, {}) + def _self_builder(self) -> CodeBuilder: + assert self.__self_builder + return self.__self_builder @property def alias(self) -> Optional[str]: alias = self.metadata.get("alias") if alias is None: - alias = self.get_config().aliases.get(self.name) # type: ignore + aliases_config = self.get_owner_config().aliases + alias = aliases_config.get(self.name) # type: ignore if alias is None: alias = self.name return alias @property - def holder_class(self) -> Optional[Type]: - if self.__builder: - return self.__builder.cls + def owner_class(self) -> Optional[Type]: + if self.__owner_builder: + return self.__owner_builder.cls return None - def copy(self, **changes: Any) -> "Instance": - return replace(self, **changes) + def derive(self, **changes: Any) -> "Instance": + self_builder = self.__self_builder + new_instance = replace(self, **changes) + if self_builder: + new_instance.__owner_builder = self.__self_builder + return new_instance def __post_init__(self) -> None: self.update_type(self.type) @@ -143,27 +156,29 @@ def __post_init__(self) -> None: self.origin_type = get_type_origin(self.type) def update_type(self, new_type: Type) -> None: - if self.__builder: - self.type = self.__builder._get_real_type( + if self.__owner_builder: + self.type = self.__owner_builder._get_real_type( field_name=self.name, # type: ignore field_type=new_type, ) self.origin_type = get_type_origin(self.type) if is_dataclass(self.origin_type): type_args = get_args(self.type) - self.__builder = CodeBuilder(self.origin_type, type_args) - self.__builder.reset() + self.__self_builder = CodeBuilder(self.origin_type, type_args) + self.__self_builder.reset() + else: + self.__self_builder = None def fields(self) -> Iterable[Tuple[str, Type, bool, Any]]: - for f_name, f_type in self._builder.get_field_types( + for f_name, f_type in self._self_builder.get_field_types( include_extras=True ).items(): - f = self._builder.dataclass_fields.get(f_name) # type: ignore + f = self._self_builder.dataclass_fields.get(f_name) # type: ignore if f and not f.init: continue f_default = f.default if f_default is MISSING: - f_default = self._builder.namespace.get(f_name, MISSING) + f_default = self._self_builder.namespace.get(f_name, MISSING) if f_default is not MISSING: f_default = _default(f_type, f_default) @@ -176,12 +191,14 @@ def fields(self) -> Iterable[Tuple[str, Type, bool, Any]]: def get_overridden_serialization_method( self, ) -> Optional[Union[Callable, str]]: - if not self.__builder: + if not self.__owner_builder: return None serialize_option = self.metadata.get("serialize") if serialize_option is not None: + if callable(serialize_option): + self.metadata.pop("serialize", None) # prevent recursion return serialize_option - for strategy in self.__builder.iter_serialization_strategies( + for strategy in self.__owner_builder.iter_serialization_strategies( self.metadata, self.type ): if strategy is pass_through: @@ -194,9 +211,15 @@ def get_overridden_serialization_method( return serialize_option return None - def get_config(self) -> Type[BaseConfig]: - if self.__builder: - return self.__builder.get_config() + def get_owner_config(self) -> Type[BaseConfig]: + if self.__owner_builder: + return self.__owner_builder.get_config() + else: + return BaseConfig + + def get_self_config(self) -> Type[BaseConfig]: + if self.__self_builder: + return self.__self_builder.get_config() else: return BaseConfig @@ -235,7 +258,7 @@ def get_schema( raise NotImplementedError( ( f'Type {type_name(instance.type)} of field "{instance.name}" ' - f"in {type_name(instance.holder_class)} isn't supported" + f"in {type_name(instance.owner_class)} isn't supported" ) ) @@ -261,31 +284,41 @@ class CC(DataClassJSONMixin): register = Registry.register -def override_field_instance_type_if_needed( - root_instance: Instance, field_instance: Instance -) -> None: - overridden_method = field_instance.get_overridden_serialization_method() +@register +def on_type_with_overridden_deserialization( + instance: Instance, ctx: Context +) -> Optional[JSONSchema]: + def override_with_any(reason: Any) -> None: + if instance.owner_class is not None: + name = f"{type_name(instance.owner_class)}.{instance.name}" + else: + name = type_name(instance.type) + warnings.warn( + f"Type Any will be used for {name} with " + f"overridden serialization method: {reason}" + ) + instance.update_type(Any) # type: ignore[arg-type] + + overridden_method = instance.get_overridden_serialization_method() if overridden_method is pass_through: - return + return None elif callable(overridden_method): try: - field_instance.update_type( - get_function_return_annotation(overridden_method) - ) + new_type = get_function_return_annotation(overridden_method) + if new_type is instance.type: + return None + else: + instance.update_type(new_type) except Exception as e: - warnings.warn( - f"Type Any will be used for " - f"{type_name(root_instance.type)}.{field_instance.name} with " - f"overridden serialization method: {e}" - ) - field_instance.update_type(Any) # type: ignore[arg-type] + override_with_any(e) + return get_schema(instance, ctx) @register def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]: # TODO: Self references might not work if is_dataclass(instance.origin_type): - jsonschema_config = instance.get_config().json_schema + jsonschema_config = instance.get_self_config().json_schema schema = JSONObjectSchema( title=instance.origin_type.__name__, additionalProperties=jsonschema_config.get( @@ -297,11 +330,10 @@ def on_dataclass(instance: Instance, ctx: Context) -> Optional[JSONSchema]: field_schema_overrides = jsonschema_config.get("properties", {}) for f_name, f_type, has_default, f_default in instance.fields(): override = field_schema_overrides.get(f_name) - f_instance = instance.copy(type=f_type, name=f_name) + f_instance = instance.derive(type=f_type, name=f_name) if override: f_schema = JSONSchema.from_dict(override) else: - override_field_instance_type_if_needed(instance, f_instance) f_schema = get_schema(f_instance, ctx) if f_instance.alias: f_name = f_instance.alias @@ -361,7 +393,7 @@ def on_special_typing_primitive( if is_union(instance.type): return JSONSchema( - anyOf=[get_schema(instance.copy(type=arg), ctx) for arg in args] + anyOf=[get_schema(instance.derive(type=arg), ctx) for arg in args] ) elif is_type_var_any(instance.type): return EmptyJSONSchema() @@ -370,25 +402,29 @@ def on_special_typing_primitive( if constraints: return JSONSchema( anyOf=[ - get_schema(instance.copy(type=arg), ctx) + get_schema(instance.derive(type=arg), ctx) for arg in constraints ] ) else: bound = getattr(instance.type, "__bound__") - return get_schema(instance.copy(type=bound), ctx) + return get_schema(instance.derive(type=bound), ctx) elif is_new_type(instance.type): - return get_schema(instance.copy(type=instance.type.__supertype__), ctx) + return get_schema( + instance.derive(type=instance.type.__supertype__), ctx + ) elif is_literal(instance.type): return on_literal(instance, ctx) # elif is_self(instance.type): # raise NotImplementedError elif is_required(instance.type) or is_not_required(instance.type): - return get_schema(instance.copy(type=args[0]), ctx) + return get_schema(instance.derive(type=args[0]), ctx) elif is_unpack(instance.type): - return get_schema(instance.copy(type=get_args(instance.type)[0]), ctx) + return get_schema( + instance.derive(type=get_args(instance.type)[0]), ctx + ) elif is_type_var_tuple(instance.type): - return get_schema(instance.copy(type=Tuple[Any, ...]), ctx) + return get_schema(instance.derive(type=Tuple[Any, ...]), ctx) @register @@ -518,7 +554,7 @@ def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema: if not PY_311_MIN: return JSONArraySchema(maxItems=0) if len(args) == 2 and args[1] is Ellipsis: - items_schema = _get_schema_or_none(instance.copy(type=args[0]), ctx) + items_schema = _get_schema_or_none(instance.derive(type=args[0]), ctx) return JSONArraySchema(items=items_schema) else: min_items: Optional[int] = 0 @@ -532,10 +568,10 @@ def on_tuple(instance: Instance, ctx: Context) -> JSONArraySchema: min_items += 1 # type: ignore if not unpack_schema: prefix_items.append( - get_schema(instance.copy(type=arg), ctx) + get_schema(instance.derive(type=arg), ctx) ) else: - unpack_schema = get_schema(instance.copy(type=arg), ctx) + unpack_schema = get_schema(instance.derive(type=arg), ctx) unpack_idx = arg_idx if unpack_schema: prefix_items.extend(unpack_schema.prefixItems or []) @@ -566,7 +602,7 @@ def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema: } fields = getattr(instance.type, "_fields", ()) defaults = getattr(instance.type, "_field_defaults", {}) - as_dict = instance.get_config().namedtuple_as_dict + as_dict = instance.get_owner_config().namedtuple_as_dict serialize_option = instance.get_overridden_serialization_method() if serialize_option == "as_dict": as_dict = True @@ -575,7 +611,7 @@ def on_named_tuple(instance: Instance, ctx: Context) -> JSONSchema: properties = {} for f_name in fields: f_type = annotations.get(f_name, typing.Any) - f_schema = get_schema(instance.copy(type=f_type), ctx) + f_schema = get_schema(instance.derive(type=f_type), ctx) f_default = defaults.get(f_name, MISSING) if f_default is not MISSING: if isinstance(f_schema, EmptyJSONSchema): @@ -608,7 +644,7 @@ def on_typed_dict(instance: Instance, ctx: Context) -> JSONObjectSchema: required_keys = getattr(instance.type, "__required_keys__", all_keys) return JSONObjectSchema( properties={ - key: get_schema(instance.copy(type=annotations[key]), ctx) + key: get_schema(instance.derive(type=annotations[key]), ctx) for key in all_keys } or None, @@ -689,7 +725,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: return apply_array_constraints( instance, JSONArraySchema( - items=_get_schema_or_none(instance.copy(type=args[0]), ctx) + items=_get_schema_or_none(instance.derive(type=args[0]), ctx) if args else None ), @@ -707,7 +743,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: return apply_array_constraints( instance, JSONArraySchema( - items=_get_schema_or_none(instance.copy(type=args[0]), ctx) + items=_get_schema_or_none(instance.derive(type=args[0]), ctx) if args else None, uniqueItems=True, @@ -720,7 +756,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: instance, JSONArraySchema( items=get_schema( - instance=instance.copy( + instance=instance.derive( type=( Dict[args[0], args[1]] # type: ignore if args @@ -735,11 +771,11 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: instance.origin_type, typing.Counter ): schema = JSONObjectSchema( - additionalProperties=get_schema(instance.copy(type=int), ctx), + additionalProperties=get_schema(instance.derive(type=int), ctx), ) if args: schema.propertyNames = _get_schema_or_none( - instance.copy(type=args[0]), ctx + instance.derive(type=args[0]), ctx ) return apply_object_constraints(instance, schema) elif is_typed_dict(instance.origin_type): @@ -749,11 +785,13 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: ): schema = JSONObjectSchema( additionalProperties=_get_schema_or_none( - instance.copy(type=args[1]), ctx + instance.derive(type=args[1]), ctx ) if args else None, - propertyNames=_get_schema_or_none(instance.copy(type=args[0]), ctx) + propertyNames=_get_schema_or_none( + instance.derive(type=args[0]), ctx + ) if args else None, ) @@ -764,7 +802,7 @@ def on_collection(instance: Instance, ctx: Context) -> Optional[JSONSchema]: return apply_array_constraints( instance, JSONArraySchema( - items=_get_schema_or_none(instance.copy(type=args[0]), ctx) + items=_get_schema_or_none(instance.derive(type=args[0]), ctx) if args else None ), diff --git a/tests/test_jsonschema/test_jsonschema_builder.py b/tests/test_jsonschema/test_jsonschema_builder.py index 4e93b303..60c2fd23 100644 --- a/tests/test_jsonschema/test_jsonschema_builder.py +++ b/tests/test_jsonschema/test_jsonschema_builder.py @@ -27,10 +27,10 @@ class DataClass: instance = Instance(int) assert instance.metadata == {} - assert instance.holder_class is None + assert instance.owner_class is None instance = Instance(DataClass) - assert instance.holder_class is DataClass + assert instance.owner_class is None assert instance.metadata == {} diff --git a/tests/test_jsonschema/test_jsonschema_generation.py b/tests/test_jsonschema/test_jsonschema_generation.py index f8d84fdf..5e882824 100644 --- a/tests/test_jsonschema/test_jsonschema_generation.py +++ b/tests/test_jsonschema/test_jsonschema_generation.py @@ -950,22 +950,95 @@ def test_overridden_serialization_method_with_return_annotation(): def as_timestamp(dt: datetime.datetime) -> float: return dt.timestamp() # pragma no cover + def first_datetime_as_timestamp( + seq: List[datetime.datetime], + ) -> float: + return as_timestamp(seq[0]) # pragma no cover + @dataclass class DataClass: - x: datetime.datetime - y: datetime.datetime = field(metadata={"serialize": as_timestamp}) + a: datetime.datetime + b: datetime.datetime = field(metadata={"serialize": as_timestamp}) + c: List[datetime.datetime] + d: List[datetime.datetime] = field( + metadata={"serialize": first_datetime_as_timestamp} + ) + e: Optional[datetime.datetime] + f: List[Optional[datetime.datetime]] class Config(BaseConfig): serialization_strategy = { datetime.datetime: {"serialize": as_timestamp} } - assert build_json_schema(DataClass).properties["x"] == JSONSchema( + schema = build_json_schema(DataClass) + assert schema.properties["a"] == JSONSchema( type=JSONSchemaInstanceType.NUMBER ) - assert build_json_schema(DataClass).properties["y"] == JSONSchema( + assert schema.properties["b"] == JSONSchema( type=JSONSchemaInstanceType.NUMBER ) + assert schema.properties["c"] == JSONArraySchema( + items=JSONSchema(type=JSONSchemaInstanceType.NUMBER) + ) + assert schema.properties["d"] == JSONSchema( + type=JSONSchemaInstanceType.NUMBER + ) + assert schema.properties["e"] == JSONSchema( + anyOf=[ + JSONSchema(type=JSONSchemaInstanceType.NUMBER), + JSONSchema(type=JSONSchemaInstanceType.NULL), + ] + ) + assert schema.properties["f"] == JSONArraySchema( + items=JSONSchema( + anyOf=[ + JSONSchema(type=JSONSchemaInstanceType.NUMBER), + JSONSchema(type=JSONSchemaInstanceType.NULL), + ] + ) + ) + + +def test_dataclass_overridden_serialization_method(): + def serialize_as_str(value: Any) -> str: + return str(value) # pragma no cover + + @dataclass + class Inner: + x: int + + @dataclass + class DataClass: + a: Inner + b: Optional[Inner] + c: List[Inner] + d: List[Optional[Inner]] + + class Config(BaseConfig): + serialization_strategy = {Inner: {"serialize": serialize_as_str}} + + schema = build_json_schema(DataClass) + assert schema.properties["a"] == JSONSchema( + type=JSONSchemaInstanceType.STRING + ) + assert schema.properties["b"] == JSONSchema( + anyOf=[ + JSONSchema(type=JSONSchemaInstanceType.STRING), + JSONSchema(type=JSONSchemaInstanceType.NULL), + ] + ) + assert schema.properties["c"] == JSONArraySchema( + items=JSONSchema(type=JSONSchemaInstanceType.STRING) + ) + assert schema.properties["d"] == JSONArraySchema( + items=JSONSchema( + anyOf=[ + JSONSchema(type=JSONSchemaInstanceType.STRING), + JSONSchema(type=JSONSchemaInstanceType.NULL), + ] + ) + ) def test_jsonschema_with_override_for_properties():