Skip to content

Commit

Permalink
first look approximation@
Browse files Browse the repository at this point in the history
  • Loading branch information
stan-dot committed Nov 5, 2024
1 parent 1c50a4b commit 9ec6d18
Show file tree
Hide file tree
Showing 5 changed files with 51 additions and 197 deletions.
47 changes: 20 additions & 27 deletions src/blueapi/cli/cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
from blueapi.client.client import BlueapiClient
from blueapi.client.event_bus import AnyEvent, BlueskyStreamingError, EventBusClient
from blueapi.client.rest import BlueskyRemoteControlError
from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.config import ApplicationConfig
from blueapi.core import OTLP_EXPORT_ENABLED, DataEvent
from blueapi.worker import ProgressEvent, Task, WorkerEvent

Expand All @@ -33,23 +33,16 @@
)
@click.pass_context
def main(ctx: click.Context, config: Path | None | tuple[Path, ...]) -> None:
# if no command is supplied, run with the options passed

config_loader = ConfigLoader(ApplicationConfig)
if config is not None:
configs = (config,) if isinstance(config, Path) else config
for path in configs:
if path.exists():
config_loader.use_values_from_yaml(path)
else:
raise FileNotFoundError(f"Cannot find file: {path}")

ctx.ensure_object(dict)
loaded_config: ApplicationConfig = config_loader.load()

ctx.obj["config"] = loaded_config
# Override default yaml_file path in the model_config if `config` is provided
ApplicationConfig.model_config["yaml_file"] = config
app_config = ApplicationConfig() # Instantiates with customized sources
ctx.obj["config"] = app_config

# note: this is the key result of the 'main' function, it loaded the config
# and due to 'pass context' flag above
# it's left for the handler of words that are later in the stdin
logging.basicConfig(
format="%(asctime)s - %(message)s", level=loaded_config.logging.level
format="%(asctime)s - %(message)s", level=app_config.logging.level
)

if ctx.invoked_subcommand is None:
Expand Down Expand Up @@ -163,18 +156,18 @@ def get_devices(obj: dict) -> None:
def listen_to_events(obj: dict) -> None:
"""Listen to events output by blueapi"""
config: ApplicationConfig = obj["config"]
if config.stomp is not None:
event_bus_client = EventBusClient(
StompClient.for_broker(
broker=Broker(
host=config.stomp.host,
port=config.stomp.port,
auth=config.stomp.auth,
)
if config.stomp is None:
raise RuntimeError("Message bus needs to be configured")

event_bus_client = EventBusClient(
StompClient.for_broker(
broker=Broker(
host=config.stomp.host,
port=config.stomp.port,
auth=config.stomp.auth,
)
)
else:
raise RuntimeError("Message bus needs to be configured")
)

fmt = obj["fmt"]

Expand Down
106 changes: 20 additions & 86 deletions src/blueapi/config.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,17 @@
from collections.abc import Mapping
from enum import Enum
from pathlib import Path
from typing import Any, Generic, Literal, TypeVar
from typing import Literal

import yaml
from bluesky_stomp.models import BasicAuthentication
from pydantic import BaseModel, Field, TypeAdapter, ValidationError
from pydantic import BaseModel, Field
from pydantic_settings import BaseSettings, SettingsConfigDict, YamlConfigSettingsSource

from blueapi.utils import BlueapiBaseModel, InvalidConfigError
from blueapi.utils import BlueapiBaseModel

LogLevel = Literal["NOTSET", "DEBUG", "INFO", "WARNING", "ERROR", "CRITICAL"]

DEFAULT_PATH = Path("config.yaml") # Default YAML file path


class SourceKind(str, Enum):
PLAN_FUNCTIONS = "planFunctions"
Expand Down Expand Up @@ -77,7 +78,7 @@ class ScratchConfig(BlueapiBaseModel):
repositories: list[ScratchRepository] = Field(default_factory=list)


class ApplicationConfig(BlueapiBaseModel):
class ApplicationConfig(BaseSettings, cli_parse_args=True, cli_prog_name="blueapi"):
"""
Config for the worker application as a whole. Root of
config tree.
Expand All @@ -89,83 +90,16 @@ class ApplicationConfig(BlueapiBaseModel):
api: RestConfig = Field(default_factory=RestConfig)
scratch: ScratchConfig | None = None

def __eq__(self, other: object) -> bool:
if isinstance(other, ApplicationConfig):
return (
(self.stomp == other.stomp)
& (self.env == other.env)
& (self.logging == other.logging)
& (self.api == other.api)
)
return False


C = TypeVar("C", bound=BaseModel)


class ConfigLoader(Generic[C]):
"""
Small utility class for loading config from various sources.
You must define a config schema as a dataclass (or series of
nested dataclasses) that can then be loaded from some combination
of default values, dictionaries, YAML/JSON files etc.
"""

def __init__(self, schema: type[C]) -> None:
self._adapter = TypeAdapter(schema)
self._values: dict[str, Any] = {}

def use_values(self, values: Mapping[str, Any]) -> None:
"""
Use all values provided in the config, override any defaults
and values set by previous calls into this class.
Args:
values (Mapping[str, Any]): Dictionary of override values,
does not need to be exhaustive
if defaults provided.
"""

def recursively_update_map(old: dict[str, Any], new: Mapping[str, Any]) -> None:
for key in new:
if (
key in old
and isinstance(old[key], dict)
and isinstance(new[key], dict)
):
recursively_update_map(old[key], new[key])
else:
old[key] = new[key]

recursively_update_map(self._values, values)

def use_values_from_yaml(self, path: Path) -> None:
"""
Use all values provided in a YAML/JSON file in the
config, override any defaults and values set by
previous calls into this class.
Args:
path (Path): Path to YAML/JSON file
"""

with path.open("r") as stream:
values = yaml.load(stream, yaml.Loader)
self.use_values(values)

def load(self) -> C:
"""
Finalize and load the config as an instance of the `schema`
dataclass.
Returns:
C: Dataclass instance holding config
"""

try:
return self._adapter.validate_python(self._values)
except ValidationError as exc:
error_details = "\n".join(str(e) for e in exc.errors())
raise InvalidConfigError(
f"Something is wrong with the configuration file: \n {error_details}"
) from exc
model_config = SettingsConfigDict(
env_nested_delimiter="__", yaml_file=DEFAULT_PATH, yaml_file_encoding="utf-8"
)

@classmethod
def customize_sources(cls, init_settings, env_settings, file_secret_settings):
path = cls.model_config.get("yaml_file")
return (
init_settings,
YamlConfigSettingsSource(settings_cls=cls, yaml_file=path),
env_settings,
file_secret_settings,
)
3 changes: 0 additions & 3 deletions src/blueapi/utils/__init__.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,13 @@
from .base_model import BlueapiBaseModel, BlueapiModelConfig, BlueapiPlanModelConfig
from .invalid_config_error import InvalidConfigError
from .modules import load_module_all
from .serialization import serialize
from .thread_exception import handle_all_exceptions

__all__ = [
"handle_all_exceptions",
"load_module_all",
"ConfigLoader",
"serialize",
"BlueapiBaseModel",
"BlueapiModelConfig",
"BlueapiPlanModelConfig",
"InvalidConfigError",
]
3 changes: 0 additions & 3 deletions src/blueapi/utils/invalid_config_error.py

This file was deleted.

89 changes: 11 additions & 78 deletions tests/unit_tests/test_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,8 +11,7 @@
from bluesky_stomp.models import BasicAuthentication
from pydantic import BaseModel, Field

from blueapi.config import ApplicationConfig, ConfigLoader
from blueapi.utils import InvalidConfigError
from blueapi.config import ApplicationConfig


class Config(BaseModel):
Expand Down Expand Up @@ -60,70 +59,6 @@ def default_yaml(package_root: Path) -> Path:
return package_root.parent.parent / "config" / "defaults.yaml"


@pytest.mark.parametrize("schema", [ConfigWithDefaults, NestedConfigWithDefaults])
def test_load_defaults(schema: type[Any]) -> None:
loader = ConfigLoader(schema)
assert loader.load() == schema()


def test_load_some_defaults() -> None:
loader = ConfigLoader(ConfigWithDefaults)
loader.use_values({"foo": 4})
assert loader.load() == ConfigWithDefaults(foo=4)


def test_load_override_all() -> None:
loader = ConfigLoader(ConfigWithDefaults)
loader.use_values({"foo": 4, "bar": "hi"})
assert loader.load() == ConfigWithDefaults(foo=4, bar="hi")


def test_load_override_all_nested() -> None:
loader = ConfigLoader(NestedConfig)
loader.use_values({"nested": {"foo": 4, "bar": "hi"}, "baz": True})
assert loader.load() == NestedConfig(nested=Config(foo=4, bar="hi"), baz=True)


def test_load_defaultless_schema() -> None:
loader = ConfigLoader(Config)
with pytest.raises(InvalidConfigError):
loader.load()


def test_inject_values_into_defaultless_schema() -> None:
loader = ConfigLoader(Config)
loader.use_values({"foo": 4, "bar": "hi"})
assert loader.load() == Config(foo=4, bar="hi")


def test_load_yaml(config_yaml: Path) -> None:
loader = ConfigLoader(Config)
loader.use_values_from_yaml(config_yaml)
assert loader.load() == Config(foo=5, bar="test string")


def test_load_yaml_nested(nested_config_yaml: Path) -> None:
loader = ConfigLoader(NestedConfig)
loader.use_values_from_yaml(nested_config_yaml)
assert loader.load() == NestedConfig(
nested=Config(foo=6, bar="other test string"), baz=True
)


def test_load_yaml_override(override_config_yaml: Path) -> None:
loader = ConfigLoader(ConfigWithDefaults)
loader.use_values_from_yaml(override_config_yaml)

assert loader.load() == ConfigWithDefaults(foo=7)


def test_error_thrown_if_schema_does_not_match_yaml(nested_config_yaml: Path) -> None:
loader = ConfigLoader(Config)
loader.use_values_from_yaml(nested_config_yaml)
with pytest.raises(InvalidConfigError):
loader.load()


@mock.patch.dict(os.environ, {"FOO": "bar"}, clear=True)
def test_auth_from_env():
auth = BasicAuthentication(username="${FOO}", password="baz")
Expand All @@ -150,7 +85,7 @@ def test_auth_from_env_throws_when_not_available():
with pytest.raises(KeyError):
BasicAuthentication(username="${BAZ}", password="baz")
with pytest.raises(KeyError):
BasicAuthentication(username="${baz}", passcode="baz")
BasicAuthentication(username="${baz}", passcode="baz") # type: ignore


def is_subset(subset: Mapping[str, Any], superset: Mapping[str, Any]) -> bool:
Expand Down Expand Up @@ -231,12 +166,11 @@ def test_config_yaml_parsed(temp_yaml_config_file):
temp_yaml_file_path, config_data = temp_yaml_config_file

# Initialize loader and load config from the YAML file
loader = ConfigLoader(ApplicationConfig)
loader.use_values_from_yaml(temp_yaml_file_path)
loaded_config = loader.load()
ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path
app_config = ApplicationConfig() # Instantiates with customized sources

# Parse the loaded config JSON into a dictionary
target_dict_json = json.loads(loaded_config.model_dump_json())
target_dict_json = json.loads(app_config.model_dump_json())

# Assert that config_data is a subset of target_dict_json
assert is_subset(config_data, target_dict_json)
Expand Down Expand Up @@ -311,17 +245,16 @@ def test_config_yaml_parsed_complete(temp_yaml_config_file: dict):
temp_yaml_file_path, config_data = temp_yaml_config_file

# Initialize loader and load config from the YAML file
loader = ConfigLoader(ApplicationConfig)
loader.use_values_from_yaml(temp_yaml_file_path)
loaded_config = loader.load()
ApplicationConfig.model_config["yaml_file"] = temp_yaml_file_path
app_config = ApplicationConfig() # Instantiates with customized sources

# Parse the loaded config JSON into a dictionary
target_dict_json = json.loads(loaded_config.model_dump_json())
target_dict_json = json.loads(app_config.model_dump_json())

assert loaded_config.stomp is not None
assert loaded_config.stomp.auth is not None
assert app_config.stomp is not None
assert app_config.stomp.auth is not None
assert (
loaded_config.stomp.auth.password.get_secret_value()
app_config.stomp.auth.password.get_secret_value()
== config_data["stomp"]["auth"]["password"] # noqa: E501
)
# Remove the password field to not compare it again in the full dict comparison
Expand Down

0 comments on commit 9ec6d18

Please sign in to comment.