Skip to content

Commit

Permalink
Fix alias resolution for default settings source.
Browse files Browse the repository at this point in the history
  • Loading branch information
kschwab committed Nov 5, 2024
1 parent 0922bc1 commit d1ab119
Show file tree
Hide file tree
Showing 2 changed files with 67 additions and 41 deletions.
86 changes: 45 additions & 41 deletions pydantic_settings/sources.py
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@
from dotenv import dotenv_values
from pydantic import AliasChoices, AliasPath, BaseModel, Json, RootModel, TypeAdapter
from pydantic._internal._repr import Representation
from pydantic._internal._signature import _field_name_for_signature
from pydantic._internal._typing_extra import WithArgsTypes, origin_is_union, typing_base
from pydantic._internal._utils import deep_update, is_model_class, lenient_issubclass
from pydantic.dataclasses import is_pydantic_dataclass
Expand Down Expand Up @@ -336,10 +335,12 @@ def __init__(self, settings_cls: type[BaseSettings], nested_model_default_partia
)
if self.nested_model_default_partial_update:
for field_name, field_info in settings_cls.model_fields.items():
alias_names, *_ = _get_alias_names(field_name, field_info)
preferred_alias = alias_names[0]
if is_dataclass(type(field_info.default)):
self.defaults[_field_name_for_signature(field_name, field_info)] = asdict(field_info.default)
self.defaults[preferred_alias] = asdict(field_info.default)
elif is_model_class(type(field_info.default)):
self.defaults[_field_name_for_signature(field_name, field_info)] = field_info.default.model_dump()
self.defaults[preferred_alias] = field_info.default.model_dump()

def get_field_value(self, field: FieldInfo, field_name: str) -> tuple[Any, str, bool]:
# Nothing to do here. Only implement the return statement to make mypy happy
Expand Down Expand Up @@ -1422,41 +1423,6 @@ def _get_sub_models(self, model: type[BaseModel], field_name: str, field_info: F
sub_models.append(type_) # type: ignore
return sub_models

def _get_alias_names(
self, field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str]
) -> tuple[tuple[str, ...], bool]:
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not self.case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not self.case_sensitive:
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only

def _verify_cli_flag_annotations(self, model: type[BaseModel], field_name: str, field_info: FieldInfo) -> None:
if _CliImplicitFlag in field_info.metadata:
cli_flag_name = 'CliImplicitFlag'
Expand All @@ -1481,7 +1447,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
if not field_info.is_required():
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'subcommand argument {model.__name__}.{field_name} has multiple aliases')
field_types = [type_ for type_ in get_args(field_info.annotation) if type_ is not type(None)]
Expand All @@ -1495,7 +1461,7 @@ def _sort_arg_fields(self, model: type[BaseModel]) -> list[tuple[str, FieldInfo]
if not field_info.is_required():
raise SettingsError(f'positional argument {model.__name__}.{field_name} has a default value')
else:
alias_names, *_ = self._get_alias_names(field_name, field_info, {})
alias_names, *_ = _get_alias_names(field_name, field_info)
if len(alias_names) > 1:
raise SettingsError(f'positional argument {model.__name__}.{field_name} has multiple aliases')
positional_args.append((field_name, field_info))
Expand Down Expand Up @@ -1597,7 +1563,9 @@ def _add_parser_args(
alias_path_args: dict[str, str] = {}
for field_name, field_info in self._sort_arg_fields(model):
sub_models: list[type[BaseModel]] = self._get_sub_models(model, field_name, field_info)
alias_names, is_alias_path_only = self._get_alias_names(field_name, field_info, alias_path_args)
alias_names, is_alias_path_only = _get_alias_names(
field_name, field_info, alias_path_args=alias_path_args, case_sensitive=self.case_sensitive
)
preferred_alias = alias_names[0]
if _CliSubCommand in field_info.metadata:
for model in sub_models:
Expand Down Expand Up @@ -2241,5 +2209,41 @@ def _get_model_fields(model_cls: type[Any]) -> dict[str, FieldInfo]:
raise SettingsError(f'Error: {model_cls.__name__} is not subclass of BaseModel or pydantic.dataclasses.dataclass')


def _get_alias_names(
field_name: str, field_info: FieldInfo, alias_path_args: dict[str, str] = {}, case_sensitive: bool = True
) -> tuple[tuple[str, ...], bool]:
alias_names: list[str] = []
is_alias_path_only: bool = True
if not any((field_info.alias, field_info.validation_alias)):
alias_names += [field_name]
is_alias_path_only = False
else:
new_alias_paths: list[AliasPath] = []
for alias in (field_info.alias, field_info.validation_alias):
if alias is None:
continue
elif isinstance(alias, str):
alias_names.append(alias)
is_alias_path_only = False
elif isinstance(alias, AliasChoices):
for name in alias.choices:
if isinstance(name, str):
alias_names.append(name)
is_alias_path_only = False
else:
new_alias_paths.append(name)
else:
new_alias_paths.append(alias)
for alias_path in new_alias_paths:
name = cast(str, alias_path.path[0])
name = name.lower() if not case_sensitive else name
alias_path_args[name] = 'dict' if len(alias_path.path) > 2 else 'list'
if not alias_names and is_alias_path_only:
alias_names.append(name)
if not case_sensitive:
alias_names = [alias_name.lower() for alias_name in alias_names]
return tuple(dict.fromkeys(alias_names)), is_alias_path_only


def _is_function(obj: Any) -> bool:
return isinstance(obj, (FunctionType, BuiltinFunctionType))
22 changes: 22 additions & 0 deletions tests/test_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@
from annotated_types import MinLen
from pydantic import (
AliasChoices,
AliasGenerator,
AliasPath,
BaseModel,
Discriminator,
Expand All @@ -32,6 +33,7 @@

from pydantic_settings import (
BaseSettings,
CliApp,
DotEnvSettingsSource,
EnvSettingsSource,
InitSettingsSource,
Expand Down Expand Up @@ -621,6 +623,26 @@ def settings_customise_sources(
assert s.model_dump() == s_final


def test_alias_nested_model_default_partial_update():
class SubModel(BaseModel):
v1: str = 'default'
v2: bytes = b'hello'
v3: int

class Settings(BaseSettings):
model_config = SettingsConfigDict(
nested_model_default_partial_update=True, alias_generator=AliasGenerator(lambda s: s.replace('_', '-'))
)

v0: str = 'ok'
sub_model: SubModel = SubModel(v1='top default', v3=33)

assert CliApp.run(Settings, cli_args=['--sub-model.v1=cli']).model_dump() == {
'v0': 'ok',
'sub_model': {'v1': 'cli', 'v2': b'hello', 'v3': 33},
}


def test_env_str(env):
class Settings(BaseSettings):
apple: str = Field(None, validation_alias='BOOM')
Expand Down

0 comments on commit d1ab119

Please sign in to comment.