diff --git a/docs/index.md b/docs/index.md index d1627c9d..a05a2bce 100644 --- a/docs/index.md +++ b/docs/index.md @@ -324,9 +324,6 @@ print(Settings().model_dump()) `env_nested_delimiter` can be configured via the `model_config` as shown above, or via the `_env_nested_delimiter` keyword argument on instantiation. -JSON is only parsed in top-level fields, if you need to parse JSON in sub-models, you will need to implement -validators on those models. - Nested environment variables take precedence over the top-level environment variable JSON (e.g. in the example above, `SUB_MODEL__V2` trumps `SUB_MODEL`). diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 335ba9dd..c9a32430 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -11,8 +11,8 @@ from dotenv import dotenv_values from pydantic import AliasChoices, AliasPath, BaseModel, Json, TypeAdapter -from pydantic._internal._typing_extra import origin_is_union -from pydantic._internal._utils import deep_update, lenient_issubclass +from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union +from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass from pydantic.fields import FieldInfo from typing_extensions import get_args, get_origin @@ -188,6 +188,8 @@ def _extract_field_info(self, field: FieldInfo, field_name: str) -> list[tuple[s ) else: # string validation alias field_info.append((v_alias, self._apply_case_sensitive(v_alias), False)) + elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): + field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), True)) else: field_info.append((field_name, self._apply_case_sensitive(self.env_prefix + field_name), False)) @@ -478,16 +480,13 @@ def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, val # simplest case, field is not complex, we only need to add the value if it was found return value - def _union_is_complex(self, annotation: type[Any] | None, metadata: list[Any]) -> bool: - return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) - def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: """ Find out if a field is complex, and if so whether JSON errors should be ignored """ if self.field_is_complex(field): allow_parse_failure = False - elif origin_is_union(get_origin(field.annotation)) and self._union_is_complex(field.annotation, field.metadata): + elif origin_is_union(get_origin(field.annotation)) and _union_is_complex(field.annotation, field.metadata): allow_parse_failure = True else: return False, False @@ -495,7 +494,7 @@ def _field_is_complex(self, field: FieldInfo) -> tuple[bool, bool]: return True, allow_parse_failure @staticmethod - def next_field(field: FieldInfo | None, key: str) -> FieldInfo | None: + def next_field(field: FieldInfo | Any | None, key: str) -> FieldInfo | None: """ Find the field in a sub model by key(env name) @@ -524,11 +523,17 @@ class Cfg(BaseSettings): Returns: Field if it finds the next field otherwise `None`. """ - if not field or origin_is_union(get_origin(field.annotation)): - # no support for Unions of complex BaseSettings fields + if not field: return None - elif field.annotation and hasattr(field.annotation, 'model_fields') and field.annotation.model_fields.get(key): - return field.annotation.model_fields[key] + + annotation = field.annotation if isinstance(field, FieldInfo) else field + if origin_is_union(get_origin(annotation)) or isinstance(annotation, WithArgsTypes): + for type_ in get_args(annotation): + type_has_key = EnvSettingsSource.next_field(type_, key) + if type_has_key: + return type_has_key + elif is_model_class(annotation) and annotation.model_fields.get(key): + return annotation.model_fields[key] return None @@ -716,3 +721,7 @@ def _annotation_is_complex_inner(annotation: type[Any] | None) -> bool: return lenient_issubclass(annotation, (BaseModel, Mapping, Sequence, tuple, set, frozenset, deque)) or is_dataclass( annotation ) + + +def _union_is_complex(annotation: type[Any] | None, metadata: list[Any]) -> bool: + return any(_annotation_is_complex(arg, metadata) for arg in get_args(annotation)) diff --git a/tests/test_settings.py b/tests/test_settings.py index b2371953..a7d6912b 100644 --- a/tests/test_settings.py +++ b/tests/test_settings.py @@ -198,6 +198,22 @@ class Cfg(BaseSettings): } +def test_nested_env_optional_json(env): + class Child(BaseModel): + num_list: Optional[List[int]] = None + + class Cfg(BaseSettings, env_nested_delimiter='__'): + child: Optional[Child] = None + + env.set('CHILD__NUM_LIST', '[1,2,3]') + cfg = Cfg() + assert cfg.model_dump() == { + 'child': { + 'num_list': [1, 2, 3], + }, + } + + def test_nested_env_delimiter_with_prefix(env): class Subsettings(BaseSettings): banana: str @@ -1212,6 +1228,21 @@ class Settings(BaseSettings): assert Settings().model_dump() == {'foo': {'a': 'b'}} +def test_secrets_nested_optional_json(tmp_path): + p = tmp_path / 'foo' + p.write_text('{"a": 10}') + + class Foo(BaseModel): + a: int + + class Settings(BaseSettings): + foo: Optional[Foo] = None + + model_config = SettingsConfigDict(secrets_dir=tmp_path) + + assert Settings().model_dump() == {'foo': {'a': 10}} + + def test_secrets_path_invalid_json(tmp_path): p = tmp_path / 'foo' p.write_text('{"a": "b"')