diff --git a/src/blueapi/cli/cli.py b/src/blueapi/cli/cli.py index c98800c20..db8a22d1c 100644 --- a/src/blueapi/cli/cli.py +++ b/src/blueapi/cli/cli.py @@ -18,7 +18,7 @@ from blueapi.client.client import BlueapiClient from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient from blueapi.client.rest import BlueskyRemoteControlError -from blueapi.config import ApplicationConfig, ConfigLoader +from blueapi.config import ApplicationConfig from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent from blueapi.worker import ProgressEvent, Task, WorkerEvent @@ -33,23 +33,16 @@ ) @click.pass_context def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None: - # if no command is supplied, run with the options passed - - config_loader = ConfigLoader(ApplicationConfig) - if config is not None: - configs = (config,) if isinstance(config, Path) else config - for path in configs: - if path.exists(): - config_loader.use_values_from_yaml(path) - else: - raise FileNotFoundError(f"Cannot find file: {path}") - - ctx.ensure_object(dict) - loaded_config: ApplicationConfig = config_loader.load() - - ctx.obj["config"] = loaded_config + # Override default yaml_file path in the model_config if `config` is provided + ApplicationConfig.model_config["yaml_file"] = config + app_config = ApplicationConfig() # Instantiates with customized sources + ctx.obj["config"] = app_config + + # note: this is the key result of the 'main' function, it loaded the config + # and due to 'pass context' flag above + # it's left for the handler of words that are later in the stdin logging.basicConfig( - format="%(asctime)s - %(message)s", level=loaded_config.logging.level + format="%(asctime)s - %(message)s", level=app_config.logging.level ) if ctx.invoked_subcommand is None: @@ -163,18 +156,18 @@ def get_devices(obj: dict) -> None: def listen_to_events(obj: dict) -> None: """Listen to events output by blueapi""" config: ApplicationConfig = obj["config"] - if config.stomp is not None: - event_bus_client = EventBusClient( - StompClient.for_broker( - broker=Broker( - host=config.stomp.host, - port=config.stomp.port, - auth=config.stomp.auth, - ) + if config.stomp is None: + raise RuntimeError("Message bus needs to be configured") + + event_bus_client = EventBusClient( + StompClient.for_broker( + broker=Broker( + host=config.stomp.host, + port=config.stomp.port, + auth=config.stomp.auth, ) ) - else: - raise RuntimeError("Message bus needs to be configured") + ) fmt = obj["fmt"] diff --git a/src/blueapi/config.py b/src/blueapi/config.py index 16e470c0d..2dccd12fd 100644 --- a/src/blueapi/config.py +++ b/src/blueapi/config.py @@ -1,16 +1,17 @@ -from collections.abc import Mapping from enum import Enum from pathlib import Path -from typing import Any, Generic, Literal, TypeVar +from typing import Literal -import yaml from bluesky_stomp.models import BasicAuthentication -from pydantic import BaseModel, Field, TypeAdapter, ValidationError +from pydantic import BaseModel, Field +from pydantic_settings import BaseSettings, SettingsConfigDict, YamlConfigSettingsSource -from blueapi.utils import BlueapiBaseModel, InvalidConfigError +from blueapi.utils import BlueapiBaseModel LogLevel = Literal["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"] +DEFAULT_PATH = Path("config.yaml") # Default YAML file path + class SourceKind(str, Enum): PLAN_FUNCTIONS = "planFunctions" @@ -77,7 +78,7 @@ class ScratchConfig(BlueapiBaseModel): repositories: list[ScratchRepository] = Field(default_factory=list) -class ApplicationConfig(BlueapiBaseModel): +class ApplicationConfig(BaseSettings, cli_parse_args=True, cli_prog_name="blueapi"): """ Config for the worker application as a whole. Root of config tree. @@ -89,83 +90,16 @@ class ApplicationConfig(BlueapiBaseModel): api: RestConfig = Field(default_factory=RestConfig) scratch: ScratchConfig | None = None - def __eq__(self, other: object) -> bool: - if isinstance(other, ApplicationConfig): - return ( - (self.stomp == other.stomp) - & (self.env == other.env) - & (self.logging == other.logging) - & (self.api == other.api) - ) - return False - - -C = TypeVar("C", bound=BaseModel) - - -class ConfigLoader(Generic[C]): - """ - Small utility class for loading config from various sources. - You must define a config schema as a dataclass (or series of - nested dataclasses) that can then be loaded from some combination - of default values, dictionaries, YAML/JSON files etc. - """ - - def __init__(self, schema: type[C]) -> None: - self._adapter = TypeAdapter(schema) - self._values: dict[str, Any] = {} - - def use_values(self, values: Mapping[str, Any]) -> None: - """ - Use all values provided in the config, override any defaults - and values set by previous calls into this class. - - Args: - values (Mapping[str, Any]): Dictionary of override values, - does not need to be exhaustive - if defaults provided. - """ - - def recursively_update_map(old: dict[str, Any], new: Mapping[str, Any]) -> None: - for key in new: - if ( - key in old - and isinstance(old[key], dict) - and isinstance(new[key], dict) - ): - recursively_update_map(old[key], new[key]) - else: - old[key] = new[key] - - recursively_update_map(self._values, values) - - def use_values_from_yaml(self, path: Path) -> None: - """ - Use all values provided in a YAML/JSON file in the - config, override any defaults and values set by - previous calls into this class. - - Args: - path (Path): Path to YAML/JSON file - """ - - with path.open("r") as stream: - values = yaml.load(stream, yaml.Loader) - self.use_values(values) - - def load(self) -> C: - """ - Finalize and load the config as an instance of the `schema` - dataclass. - - Returns: - C: Dataclass instance holding config - """ - - try: - return self._adapter.validate_python(self._values) - except ValidationError as exc: - error_details = "\n".join(str(e) for e in exc.errors()) - raise InvalidConfigError( - f"Something is wrong with the configuration file: \n {error_details}" - ) from exc + model_config = SettingsConfigDict( + env_nested_delimiter="__", yaml_file=DEFAULT_PATH, yaml_file_encoding="utf-8" + ) + + @classmethod + def customize_sources(cls, init_settings, env_settings, file_secret_settings): + path = cls.model_config.get("yaml_file") + return ( + init_settings, + YamlConfigSettingsSource(settings_cls=cls, yaml_file=path), + env_settings, + file_secret_settings, + ) diff --git a/src/blueapi/utils/__init__.py b/src/blueapi/utils/__init__.py index b871f842a..b74cf5e08 100644 --- a/src/blueapi/utils/__init__.py +++ b/src/blueapi/utils/__init__.py @@ -1,5 +1,4 @@ from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig -from .invalid_config_error import InvalidConfigError from .modules import load_module_all from .serialization import serialize from .thread_exception import handle_all_exceptions @@ -7,10 +6,8 @@ __all__ = [ "handle_all_exceptions", "load_module_all", - "ConfigLoader", "serialize", "BlueapiBaseModel", "BlueapiModelConfig", "BlueapiPlanModelConfig", - "InvalidConfigError", ] diff --git a/src/blueapi/utils/invalid_config_error.py b/src/blueapi/utils/invalid_config_error.py deleted file mode 100644 index be99d5a9e..000000000 --- a/src/blueapi/utils/invalid_config_error.py +++ /dev/null @@ -1,3 +0,0 @@ -class InvalidConfigError(Exception): - def __init__(self, message="Configuration is invalid"): - super().__init__(message) diff --git a/tests/unit_tests/test_config.py b/tests/unit_tests/test_config.py index 5e2ec84f9..c92cb6aca 100644 --- a/tests/unit_tests/test_config.py +++ b/tests/unit_tests/test_config.py @@ -11,8 +11,7 @@ from bluesky_stomp.models import BasicAuthentication from pydantic import BaseModel, Field -from blueapi.config import ApplicationConfig, ConfigLoader -from blueapi.utils import InvalidConfigError +from blueapi.config import ApplicationConfig class Config(BaseModel): @@ -60,70 +59,6 @@ def default_yaml(package_root: Path) -> Path: return package_root.parent.parent / "config" / "defaults.yaml" -@pytest.mark.parametrize("schema", [ConfigWithDefaults, NestedConfigWithDefaults]) -def test_load_defaults(schema: type[Any]) -> None: - loader = ConfigLoader(schema) - assert loader.load() == schema() - - -def test_load_some_defaults() -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values({"foo": 4}) - assert loader.load() == ConfigWithDefaults(foo=4) - - -def test_load_override_all() -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values({"foo": 4, "bar": "hi"}) - assert loader.load() == ConfigWithDefaults(foo=4, bar="hi") - - -def test_load_override_all_nested() -> None: - loader = ConfigLoader(NestedConfig) - loader.use_values({"nested": {"foo": 4, "bar": "hi"}, "baz": True}) - assert loader.load() == NestedConfig(nested=Config(foo=4, bar="hi"), baz=True) - - -def test_load_defaultless_schema() -> None: - loader = ConfigLoader(Config) - with pytest.raises(InvalidConfigError): - loader.load() - - -def test_inject_values_into_defaultless_schema() -> None: - loader = ConfigLoader(Config) - loader.use_values({"foo": 4, "bar": "hi"}) - assert loader.load() == Config(foo=4, bar="hi") - - -def test_load_yaml(config_yaml: Path) -> None: - loader = ConfigLoader(Config) - loader.use_values_from_yaml(config_yaml) - assert loader.load() == Config(foo=5, bar="test string") - - -def test_load_yaml_nested(nested_config_yaml: Path) -> None: - loader = ConfigLoader(NestedConfig) - loader.use_values_from_yaml(nested_config_yaml) - assert loader.load() == NestedConfig( - nested=Config(foo=6, bar="other test string"), baz=True - ) - - -def test_load_yaml_override(override_config_yaml: Path) -> None: - loader = ConfigLoader(ConfigWithDefaults) - loader.use_values_from_yaml(override_config_yaml) - - assert loader.load() == ConfigWithDefaults(foo=7) - - -def test_error_thrown_if_schema_does_not_match_yaml(nested_config_yaml: Path) -> None: - loader = ConfigLoader(Config) - loader.use_values_from_yaml(nested_config_yaml) - with pytest.raises(InvalidConfigError): - loader.load() - - @mock.patch.dict(os.environ, {"FOO": "bar"}, clear=True) def test_auth_from_env(): auth = BasicAuthentication(username="${FOO}", password="baz") @@ -150,7 +85,7 @@ def test_auth_from_env_throws_when_not_available(): with pytest.raises(KeyError): BasicAuthentication(username="${BAZ}", password="baz") with pytest.raises(KeyError): - BasicAuthentication(username="${baz}", passcode="baz") + BasicAuthentication(username="${baz}", passcode="baz") # type: ignore def is_subset(subset: Mapping[str, Any], superset: Mapping[str, Any]) -> bool: @@ -231,12 +166,11 @@ def test_config_yaml_parsed(temp_yaml_config_file): temp_yaml_file_path, config_data = temp_yaml_config_file # Initialize loader and load config from the YAML file - loader = ConfigLoader(ApplicationConfig) - loader.use_values_from_yaml(temp_yaml_file_path) - loaded_config = loader.load() + ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path + app_config = ApplicationConfig() # Instantiates with customized sources # Parse the loaded config JSON into a dictionary - target_dict_json = json.loads(loaded_config.model_dump_json()) + target_dict_json = json.loads(app_config.model_dump_json()) # Assert that config_data is a subset of target_dict_json assert is_subset(config_data, target_dict_json) @@ -311,17 +245,16 @@ def test_config_yaml_parsed_complete(temp_yaml_config_file: dict): temp_yaml_file_path, config_data = temp_yaml_config_file # Initialize loader and load config from the YAML file - loader = ConfigLoader(ApplicationConfig) - loader.use_values_from_yaml(temp_yaml_file_path) - loaded_config = loader.load() + ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path + app_config = ApplicationConfig() # Instantiates with customized sources # Parse the loaded config JSON into a dictionary - target_dict_json = json.loads(loaded_config.model_dump_json()) + target_dict_json = json.loads(app_config.model_dump_json()) - assert loaded_config.stomp is not None - assert loaded_config.stomp.auth is not None + assert app_config.stomp is not None + assert app_config.stomp.auth is not None assert ( - loaded_config.stomp.auth.password.get_secret_value() + app_config.stomp.auth.password.get_secret_value() == config_data["stomp"]["auth"]["password"] # noqa: E501 ) # Remove the password field to not compare it again in the full dict comparison