diff --git a/docs/index.md b/docs/index.md index 76c8b5a..218da6a 100644 --- a/docs/index.md +++ b/docs/index.md @@ -969,6 +969,44 @@ For `BaseModel` and `pydantic.dataclasses.dataclass` types, `CliApp.run` will in The alias generator for kebab case does not propagate to subcommands or submodels and will have to be manually set in these cases. +### Mutually Exclusive Groups + +CLI mutually exclusive groups can be created by inheriting from the `CliMutuallyExclusiveGroup` class. + +!!! note + A `CliMutuallyExclusiveGroup` cannot be used in a union or contain nested models. + +```py +from typing import Optional + +from pydantic import BaseModel + +from pydantic_settings import CliApp, CliMutuallyExclusiveGroup, SettingsError + + +class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = None + diameter: Optional[float] = None + perimeter: Optional[float] = None + + +class Settings(BaseModel): + circle: Circle + + +try: + CliApp.run( + Settings, + cli_args=['--circle.radius=1', '--circle.diameter=2'], + cli_exit_on_error=False, + ) +except SettingsError as e: + print(e) + """ + error parsing CLI: argument --circle.diameter: not allowed with argument --circle.radius + """ +``` + ### Customizing the CLI Experience The below flags can be used to customise the CLI experience to your needs. diff --git a/pydantic_settings/__init__.py b/pydantic_settings/__init__.py index fd42d3e..5b3aa9f 100644 --- a/pydantic_settings/__init__.py +++ b/pydantic_settings/__init__.py @@ -4,6 +4,7 @@ AzureKeyVaultSettingsSource, CliExplicitFlag, CliImplicitFlag, + CliMutuallyExclusiveGroup, CliPositionalArg, CliSettingsSource, CliSubCommand, @@ -34,6 +35,7 @@ 'CliPositionalArg', 'CliExplicitFlag', 'CliImplicitFlag', + 'CliMutuallyExclusiveGroup', 'InitSettingsSource', 'JsonConfigSettingsSource', 'PyprojectTomlConfigSettingsSource', diff --git a/pydantic_settings/sources.py b/pydantic_settings/sources.py index 06fdee9..fb32095 100644 --- a/pydantic_settings/sources.py +++ b/pydantic_settings/sources.py @@ -149,6 +149,10 @@ def error(self, message: str) -> NoReturn: super().error(message) +class CliMutuallyExclusiveGroup(BaseModel): + pass + + T = TypeVar('T') CliSubCommand = Annotated[Union[T, None], _CliSubCommand] CliPositionalArg = Annotated[T, _CliPositionalArg] @@ -1483,7 +1487,7 @@ def _connect_parser_method( if ( parser_method is not None and self.case_sensitive is False - and method_name == 'parsed_args_method' + and method_name == 'parse_args_method' and isinstance(self._root_parser, _CliInternalArgParser) ): @@ -1515,6 +1519,26 @@ def none_parser_method(*args: Any, **kwargs: Any) -> Any: else: return parser_method + def _connect_group_method(self, add_argument_group_method: Callable[..., Any] | None) -> Callable[..., Any]: + add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') + + def add_group_method(parser: Any, **kwargs: Any) -> Any: + if not kwargs.pop('_is_cli_mutually_exclusive_group'): + kwargs.pop('required') + return add_argument_group(parser, **kwargs) + else: + main_group_kwargs = {arg: kwargs.pop(arg) for arg in ['title', 'description'] if arg in kwargs} + main_group_kwargs['title'] += ' (mutually exclusive)' + group = add_argument_group(parser, **main_group_kwargs) + if not hasattr(group, 'add_mutually_exclusive_group'): + raise SettingsError( + 'cannot connect CLI settings source root parser: ' + 'group object is missing add_mutually_exclusive_group but is needed for connecting' + ) + return group.add_mutually_exclusive_group(**kwargs) + + return add_group_method + def _connect_root_parser( self, root_parser: T, @@ -1531,9 +1555,9 @@ def _parse_known_args(*args: Any, **kwargs: Any) -> Namespace: self._root_parser = root_parser if parse_args_method is None: parse_args_method = _parse_known_args if self.cli_ignore_unknown_args else ArgumentParser.parse_args - self._parse_args = self._connect_parser_method(parse_args_method, 'parsed_args_method') + self._parse_args = self._connect_parser_method(parse_args_method, 'parse_args_method') self._add_argument = self._connect_parser_method(add_argument_method, 'add_argument_method') - self._add_argument_group = self._connect_parser_method(add_argument_group_method, 'add_argument_group_method') + self._add_group = self._connect_group_method(add_argument_group_method) self._add_parser = self._connect_parser_method(add_parser_method, 'add_parser_method') self._add_subparsers = self._connect_parser_method(add_subparsers_method, 'add_subparsers_method') self._formatter_class = formatter_class @@ -1665,6 +1689,7 @@ def _add_parser_args( if is_parser_submodel: self._add_parser_submodels( parser, + model, sub_models, added_args, arg_prefix, @@ -1680,7 +1705,7 @@ def _add_parser_args( elif not is_alias_path_only: if group is not None: if isinstance(group, dict): - group = self._add_argument_group(parser, **group) + group = self._add_group(parser, **group) added_args += list(arg_names) self._add_argument(group, *(f'{flag_prefix[:len(name)]}{name}' for name in arg_names), **kwargs) else: @@ -1724,6 +1749,7 @@ def _get_arg_names( def _add_parser_submodels( self, parser: Any, + model: type[BaseModel], sub_models: list[type[BaseModel]], added_args: list[str], arg_prefix: str, @@ -1736,10 +1762,23 @@ def _add_parser_submodels( alias_names: tuple[str, ...], model_default: Any, ) -> None: + if issubclass(model, CliMutuallyExclusiveGroup): + # Argparse has deprecated "calling add_argument_group() or add_mutually_exclusive_group() on a + # mutually exclusive group" (https://docs.python.org/3/library/argparse.html#mutual-exclusion). + # Since nested models result in a group add, raise an exception for nested models in a mutually + # exclusive group. + raise SettingsError('cannot have nested models in a CliMutuallyExclusiveGroup') + model_group: Any = None model_group_kwargs: dict[str, Any] = {} model_group_kwargs['title'] = f'{arg_names[0]} options' model_group_kwargs['description'] = field_info.description + model_group_kwargs['required'] = kwargs['required'] + model_group_kwargs['_is_cli_mutually_exclusive_group'] = any( + issubclass(model, CliMutuallyExclusiveGroup) for model in sub_models + ) + if model_group_kwargs['_is_cli_mutually_exclusive_group'] and len(sub_models) > 1: + raise SettingsError('cannot use union with CliMutuallyExclusiveGroup') if self.cli_use_class_docs_for_groups and len(sub_models) == 1: model_group_kwargs['description'] = None if sub_models[0].__doc__ is None else dedent(sub_models[0].__doc__) @@ -1762,7 +1801,7 @@ def _add_parser_submodels( if not self.cli_avoid_json: added_args.append(arg_names[0]) kwargs['help'] = f'set {arg_names[0]} from JSON string' - model_group = self._add_argument_group(parser, **model_group_kwargs) + model_group = self._add_group(parser, **model_group_kwargs) self._add_argument(model_group, *(f'{flag_prefix}{name}' for name in arg_names), **kwargs) for model in sub_models: self._add_parser_args( @@ -1788,7 +1827,7 @@ def _add_parser_alias_paths( if alias_path_args: context = parser if group is not None: - context = self._add_argument_group(parser, **group) if isinstance(group, dict) else group + context = self._add_group(parser, **group) if isinstance(group, dict) else group is_nested_alias_path = arg_prefix.endswith('.') arg_prefix = arg_prefix[:-1] if is_nested_alias_path else arg_prefix for name, metavar in alias_path_args.items(): diff --git a/tests/test_source_cli.py b/tests/test_source_cli.py index c53e777..9cc374b 100644 --- a/tests/test_source_cli.py +++ b/tests/test_source_cli.py @@ -33,6 +33,7 @@ CLI_SUPPRESS, CliExplicitFlag, CliImplicitFlag, + CliMutuallyExclusiveGroup, CliPositionalArg, CliSettingsSource, CliSubCommand, @@ -79,30 +80,30 @@ class SettingWithIgnoreEmpty(BaseSettings): class CliDummyArgGroup(BaseModel, arbitrary_types_allowed=True): group: argparse._ArgumentGroup - def add_argument(self, *args, **kwargs) -> None: + def add_argument(self, *args: Any, **kwargs: Any) -> None: self.group.add_argument(*args, **kwargs) class CliDummySubParsers(BaseModel, arbitrary_types_allowed=True): sub_parser: argparse._SubParsersAction - def add_parser(self, *args, **kwargs) -> 'CliDummyParser': + def add_parser(self, *args: Any, **kwargs: Any) -> 'CliDummyParser': return CliDummyParser(parser=self.sub_parser.add_parser(*args, **kwargs)) class CliDummyParser(BaseModel, arbitrary_types_allowed=True): parser: argparse.ArgumentParser = Field(default_factory=lambda: argparse.ArgumentParser()) - def add_argument(self, *args, **kwargs) -> None: + def add_argument(self, *args: Any, **kwargs: Any) -> None: self.parser.add_argument(*args, **kwargs) - def add_argument_group(self, *args, **kwargs) -> CliDummyArgGroup: + def add_argument_group(self, *args: Any, **kwargs: Any) -> CliDummyArgGroup: return CliDummyArgGroup(group=self.parser.add_argument_group(*args, **kwargs)) - def add_subparsers(self, *args, **kwargs) -> CliDummySubParsers: + def add_subparsers(self, *args: Any, **kwargs: Any) -> CliDummySubParsers: return CliDummySubParsers(sub_parser=self.parser.add_subparsers(*args, **kwargs)) - def parse_args(self, *args, **kwargs) -> argparse.Namespace: + def parse_args(self, *args: Any, **kwargs: Any) -> argparse.Namespace: return self.parser.parse_args(*args, **kwargs) @@ -1826,11 +1827,11 @@ class Cfg(BaseSettings): args = ['--fruit', 'pear'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'bird', 'command': None, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'bird', 'command': None, } @@ -1838,28 +1839,28 @@ class Cfg(BaseSettings): arg_prefix = f'{prefix}.' if prefix else '' args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=parsed_args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=parsed_args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': None, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': None, } parsed_args = parser.parse_args(['--fruit', 'kiwi', f'--{arg_prefix}pet', 'cat']) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'cat', 'command': None, } args = ['--fruit', 'kiwi', f'--{arg_prefix}pet', 'dog', 'command', '--name', 'ralph', '--command', 'roll'] parsed_args = parser.parse_args(args) - assert Cfg(_cli_settings_source=cli_cfg_settings(parsed_args=vars(parsed_args))).model_dump() == { + assert CliApp.run(Cfg, cli_args=vars(parsed_args), cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': {'name': 'ralph', 'command': 'roll'}, } - assert Cfg(_cli_settings_source=cli_cfg_settings(args=args)).model_dump() == { + assert CliApp.run(Cfg, cli_args=args, cli_settings_source=cli_cfg_settings).model_dump() == { 'pet': 'dog', 'command': {'name': 'ralph', 'command': 'roll'}, } @@ -2085,3 +2086,140 @@ class Settings(BaseSettings, cli_parse_args=True): -h, --help show this help message and exit """ ) + + +def test_cli_mutually_exclusive_group(capsys, monkeypatch): + class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = 21 + diameter: Optional[float] = 22 + perimeter: Optional[float] = 23 + + class Settings(BaseModel): + circle_optional: Circle = Circle(radius=None, diameter=None, perimeter=24) + circle_required: Circle + + CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-optional.radius=1']).model_dump() == { + 'circle_optional': {'radius': 1, 'diameter': 22, 'perimeter': 24}, + 'circle_required': {'radius': 1, 'diameter': 22, 'perimeter': 23}, + } + + with pytest.raises(SystemExit): + CliApp.run(Settings, cli_args=['--circle-required.radius=1', '--circle-required.diameter=2']) + assert ( + 'error: argument --circle-required.diameter: not allowed with argument --circle-required.radius' + in capsys.readouterr().err + ) + + with pytest.raises(SystemExit): + CliApp.run( + Settings, + cli_args=['--circle-required.radius=1', '--circle-optional.radius=1', '--circle-optional.diameter=2'], + ) + assert ( + 'error: argument --circle-optional.diameter: not allowed with argument --circle-optional.radius' + in capsys.readouterr().err + ) + + with monkeypatch.context() as m: + m.setattr(sys, 'argv', ['example.py', '--help']) + with pytest.raises(SystemExit): + CliApp.run(Settings) + usage = ( + """usage: example.py [-h] [--circle-optional.radius float | + --circle-optional.diameter float | + --circle-optional.perimeter float] + (--circle-required.radius float | + --circle-required.diameter float | + --circle-required.perimeter float)""" + if sys.version_info >= (3, 13) + else """usage: example.py [-h] + [--circle-optional.radius float | --circle-optional.diameter float | --circle-optional.perimeter float] + (--circle-required.radius float | --circle-required.diameter float | --circle-required.perimeter float)""" + ) + assert ( + capsys.readouterr().out + == f"""{usage} + +{ARGPARSE_OPTIONS_TEXT}: + -h, --help show this help message and exit + +circle-optional options (mutually exclusive): + --circle-optional.radius float + (default: None) + --circle-optional.diameter float + (default: None) + --circle-optional.perimeter float + (default: 24.0) + +circle-required options (mutually exclusive): + --circle-required.radius float + (default: 21) + --circle-required.diameter float + (default: 22) + --circle-required.perimeter float + (default: 23) +""" + ) + + +def test_cli_mutually_exclusive_group_exceptions(): + class Circle(CliMutuallyExclusiveGroup): + radius: Optional[float] = 21 + diameter: Optional[float] = 22 + perimeter: Optional[float] = 23 + + class Settings(BaseSettings): + circle: Circle + + parser = CliDummyParser() + with pytest.raises( + SettingsError, + match='cannot connect CLI settings source root parser: group object is missing add_mutually_exclusive_group but is needed for connecting', + ): + CliSettingsSource( + Settings, + root_parser=parser, + parse_args_method=CliDummyParser.parse_args, + add_argument_method=CliDummyParser.add_argument, + add_argument_group_method=CliDummyParser.add_argument_group, + add_parser_method=CliDummySubParsers.add_parser, + add_subparsers_method=CliDummyParser.add_subparsers, + ) + + class SubModel(BaseModel): + pass + + class SettingsInvalidUnion(BaseSettings): + union: Union[Circle, SubModel] + + with pytest.raises(SettingsError, match='cannot use union with CliMutuallyExclusiveGroup'): + CliApp.run(SettingsInvalidUnion) + + class CircleInvalidSubModel(Circle): + square: Optional[SubModel] = None + + class SettingsInvalidOptSubModel(BaseModel): + circle: CircleInvalidSubModel = CircleInvalidSubModel() + + class SettingsInvalidReqSubModel(BaseModel): + circle: CircleInvalidSubModel + + for settings in [SettingsInvalidOptSubModel, SettingsInvalidReqSubModel]: + with pytest.raises(SettingsError, match='cannot have nested models in a CliMutuallyExclusiveGroup'): + CliApp.run(settings) + + class CircleRequiredField(Circle): + length: float + + class SettingsOptCircleReqField(BaseModel): + circle: CircleRequiredField = CircleRequiredField(length=2) + + assert CliApp.run(SettingsOptCircleReqField, cli_args=[]).model_dump() == { + 'circle': {'diameter': 22.0, 'length': 2.0, 'perimeter': 23.0, 'radius': 21.0} + } + + class SettingsInvalidReqCircleReqField(BaseModel): + circle: CircleRequiredField + + with pytest.raises(ValueError, match='mutually exclusive arguments must be optional'): + CliApp.run(SettingsInvalidReqCircleReqField)