From a0924bcd978e7b784441beb05342c5f0420fd993 Mon Sep 17 00:00:00 2001 From: Kyle Schwab Date: Wed, 27 Nov 2024 02:00:04 -0700 Subject: [PATCH] Fix alias resolution to use preferred key. (#481) --- pydantic_settings/sources.py | 18 +++++++++++++----- tests/test_source_cli.py | 33 ++++++++++++++++++++++++++++++++- 2 files changed, 45 insertions(+), 6 deletions(-) diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 4b0e984..66966e6 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, a flag to determine whether value is complex. """ - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + field_infos = self._extract_field_info(field, field_name) + preferred_key, *_ = field_infos[0] + for field_key, env_name, value_is_complex in field_infos: # paths reversed to match the last-wins behaviour of `env_file` for secrets_path in reversed(self.secrets_paths): path = self.find_case_path(secrets_path, env_name, self.case_sensitive) @@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, continue if path.is_file(): - return path.read_text().strip(), field_key, value_is_complex + if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)): + preferred_key = field_key + return path.read_text().strip(), preferred_key, value_is_complex else: warnings.warn( f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.', stacklevel=4, ) - return None, field_key, value_is_complex + return None, preferred_key, value_is_complex def __repr__(self) -> str: return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})' @@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, """ env_val: str | None = None - for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name): + field_infos = self._extract_field_info(field, field_name) + preferred_key, *_ = field_infos[0] + for field_key, env_name, value_is_complex in field_infos: env_val = self.env_vars.get(env_name) if env_val is not None: + if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)): + preferred_key = field_key break - return env_val, field_key, value_is_complex + return env_val, preferred_key, value_is_complex def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any: """ diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index d9288c4..fa623fc 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -10,6 +10,7 @@ import typing_extensions from pydantic import ( AliasChoices, + AliasGenerator, AliasPath, BaseModel, ConfigDict, @@ -107,7 +108,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace: return self.parser.parse_args(*args, **kwargs) -def test_validation_alias_with_cli_prefix(): +def test_cli_validation_alias_with_cli_prefix(): class Settings(BaseSettings, cli_exit_on_error=False): foobar: str = Field(validation_alias='foo') @@ -119,6 +120,36 @@ class Settings(BaseSettings, cli_exit_on_error=False): assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar' +@pytest.mark.parametrize( + 'alias_generator', + [ + AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))), + AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)), + ], +) +def test_cli_alias_resolution_consistency_with_env(env, alias_generator): + class SubModel(BaseModel): + v1: str = 'model default' + + class Settings(BaseSettings): + model_config = SettingsConfigDict( + env_nested_delimiter='__', + nested_model_default_partial_update=True, + alias_generator=alias_generator, + ) + + sub_model: SubModel = SubModel(v1='top default') + + assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}} + + env.set('SUB_MODEL__V1', 'env default') + assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}} + + assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == { + 'sub_model': {'v1': 'cli default'} + } + + def test_cli_nested_arg(): class SubSubValue(BaseModel): v6: str