Skip to content

Commit

Permalink
Fix alias resolution to use preferred key. (#481)
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab authored Nov 27, 2024
1 parent 6fe3bd1 commit a0924bc
Show file tree
Hide file tree
Showing 2 changed files with 45 additions and 6 deletions.
18 changes: 13 additions & 5 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -661,7 +661,9 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
a flag to determine whether value is complex.
"""

for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
# paths reversed to match the last-wins behaviour of `env_file`
for secrets_path in reversed(self.secrets_paths):
path = self.find_case_path(secrets_path, env_name, self.case_sensitive)
Expand All @@ -670,14 +672,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
continue

if path.is_file():
return path.read_text().strip(), field_key, value_is_complex
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
return path.read_text().strip(), preferred_key, value_is_complex
else:
warnings.warn(
f'attempted to load secret file "{path}" but found a {path_type_label(path)} instead.',
stacklevel=4,
)

return None, field_key, value_is_complex
return None, preferred_key, value_is_complex

def __repr__(self) -> str:
return f'{self.__class__.__name__}(secrets_dir={self.secrets_dir!r})'
Expand Down Expand Up @@ -725,12 +729,16 @@ def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str,
"""

env_val: str | None = None
for field_key, env_name, value_is_complex in self._extract_field_info(field, field_name):
field_infos = self._extract_field_info(field, field_name)
preferred_key, *_ = field_infos[0]
for field_key, env_name, value_is_complex in field_infos:
env_val = self.env_vars.get(env_name)
if env_val is not None:
if value_is_complex or (self.config.get('populate_by_name', False) and (field_key == field_name)):
preferred_key = field_key
break

return env_val, field_key, value_is_complex
return env_val, preferred_key, value_is_complex

def prepare_field_value(self, field_name: str, field: FieldInfo, value: Any, value_is_complex: bool) -> Any:
"""
Expand Down
33 changes: 32 additions & 1 deletion tests/test_source_cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
import typing_extensions
from pydantic import (
AliasChoices,
AliasGenerator,
AliasPath,
BaseModel,
ConfigDict,
Expand Down Expand Up @@ -107,7 +108,7 @@ def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace:
return self.parser.parse_args(*args, **kwargs)


def test_validation_alias_with_cli_prefix():
def test_cli_validation_alias_with_cli_prefix():
class Settings(BaseSettings, cli_exit_on_error=False):
foobar: str = Field(validation_alias='foo')

Expand All @@ -119,6 +120,36 @@ class Settings(BaseSettings, cli_exit_on_error=False):
assert CliApp.run(Settings, cli_args=['--p.foo', 'bar']).foobar == 'bar'


@pytest.mark.parametrize(
'alias_generator',
[
AliasGenerator(validation_alias=lambda s: AliasChoices(s, s.replace('_', '-'))),
AliasGenerator(validation_alias=lambda s: AliasChoices(s.replace('_', '-'), s)),
],
)
def test_cli_alias_resolution_consistency_with_env(env, alias_generator):
class SubModel(BaseModel):
v1: str = 'model default'

class Settings(BaseSettings):
model_config = SettingsConfigDict(
env_nested_delimiter='__',
nested_model_default_partial_update=True,
alias_generator=alias_generator,
)

sub_model: SubModel = SubModel(v1='top default')

assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'top default'}}

env.set('SUB_MODEL__V1', 'env default')
assert CliApp.run(Settings, cli_args=[]).model_dump() == {'sub_model': {'v1': 'env default'}}

assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli default']).model_dump() == {
'sub_model': {'v1': 'cli default'}
}


def test_cli_nested_arg():
class SubSubValue(BaseModel):
v6: str
Expand Down

0 comments on commit a0924bc

Please sign in to comment.