Skip to content

Commit

Permalink
Merge pull request #42 from NREL/dt/issue-22
Browse files Browse the repository at this point in the history
Store associations between components
  • Loading branch information
daniel-thom authored Aug 27, 2024
2 parents 1603653 + 85d3308 commit 9dc0e3e
Show file tree
Hide file tree
Showing 6 changed files with 416 additions and 26 deletions.
38 changes: 38 additions & 0 deletions docs/explanation/system.md
Original file line number Diff line number Diff line change
Expand Up @@ -58,3 +58,41 @@ Component.model_json_schema()
- `infrasys` includes some basic quantities in [infrasys.quantities](#quantity-api).
- Pint will automatically convert a list or list of lists of values into a `numpy.ndarray`.
infrasys will handle serialization/de-serialization of these types.


### Component Associations
The system tracks associations between components in order to optimize lookups.

For example, suppose a Generator class has a field for a Bus. It is trivial to find a generator's
bus. However, if you need to find all generators connected to specific bus, you would have to
traverse all generators in the system and check their bus values.

Every time you add a component to a system, `infrasys` inspects the component type for composed
components. It checks for directly connected components, such as `Generator.bus`, and lists of
components. (It does not inspect other composite data structures like dictionaries.)

`infrasys` stores these component associations in a SQLite table and so lookups are fast.

Here is how to complete this example:

```python
generators = system.list_parent_components(bus)
```

If you only want to find specific types, you can pass that type as well.
```python
generators = system.list_parent_components(bus, component_type=Generator)
```

**Warning**: There is one potentially problematic case.

Suppose that you have a system with generators and buses and then reassign the buses, as in
```
gen1.bus = other_bus
```

`infrasys` cannot detect such reassignments and so the component associations will be incorrect.
You must inform `infrasys` to rebuild its internal table.
```
system.rebuild_component_associations()
```
126 changes: 126 additions & 0 deletions src/infrasys/component_associations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,126 @@
import sqlite3
from typing import Optional, Type
from uuid import UUID

from loguru import logger

from infrasys.component import Component
from infrasys.utils.sqlite import execute


class ComponentAssociations:
"""Stores associations between components. Allows callers to quickly find components composed
by other components, such as the generator to which a bus is connected."""

TABLE_NAME = "component_associations"

def __init__(self) -> None:
# This uses a different database because it is not persisted when the system
# is saved to files. It will be rebuilt during de-serialization.
self._con = sqlite3.connect(":memory:")
self._create_metadata_table()

def _create_metadata_table(self):
schema = [
"id INTEGER PRIMARY KEY",
"component_uuid TEXT",
"component_type TEXT",
"attached_component_uuid TEXT",
"attached_component_type TEXT",
]
schema_text = ",".join(schema)
cur = self._con.cursor()
execute(cur, f"CREATE TABLE {self.TABLE_NAME}({schema_text})")
execute(
cur,
f"CREATE INDEX by_c_uuid ON {self.TABLE_NAME}(component_uuid, attached_component_uuid)",
)
self._con.commit()
logger.debug("Created in-memory component associations table")

def add(self, *components: Component):
"""Store an association between each component and directly attached subcomponents.
- Inspects the type of each field of each component's type. Looks for subtypes of
Component and lists of subtypes of Component.
- Does not consider component fields that are dictionaries or other data structures.
"""
rows = []
for component in components:
for field in type(component).model_fields:
val = getattr(component, field)
if isinstance(val, Component):
rows.append(self._make_row(component, val))
elif isinstance(val, list) and val and isinstance(val[0], Component):
for item in val:
rows.append(self._make_row(component, item))

if rows:
self._insert_rows(rows)

def clear(self) -> None:
"""Clear all component associations."""
execute(self._con.cursor(), f"DELETE FROM {self.TABLE_NAME}")
logger.info("Cleared all component associations.")

def list_child_components(
self, component: Component, component_type: Optional[Type[Component]] = None
) -> list[UUID]:
"""Return a list of all component UUIDS that this component composes.
For example, return the bus attached to a generator.
"""
where_clause = "WHERE component_uuid = ?"
if component_type is None:
params = [str(component.uuid)]
else:
params = [str(component.uuid), component_type.__name__]
where_clause += " AND attached_component_type = ?"
query = f"SELECT attached_component_uuid FROM {self.TABLE_NAME} {where_clause}"
cur = self._con.cursor()
return [UUID(x[0]) for x in execute(cur, query, params)]

def list_parent_components(
self, component: Component, component_type: Optional[Type[Component]] = None
) -> list[UUID]:
"""Return a list of all component UUIDS that compose this component.
For example, return all components connected to a bus.
"""
where_clause = "WHERE attached_component_uuid = ?"
if component_type is None:
params = [str(component.uuid)]
else:
params = [str(component.uuid), component_type.__name__]
where_clause += " AND component_type = ?"
query = f"SELECT component_uuid FROM {self.TABLE_NAME} {where_clause}"
cur = self._con.cursor()
return [UUID(x[0]) for x in execute(cur, query, params)]

def remove(self, component: Component) -> None:
"""Delete all rows with this component."""
query = f"""
DELETE
FROM {self.TABLE_NAME}
WHERE component_uuid = ? OR attached_component_uuid = ?
"""
params = [str(component.uuid), str(component.uuid)]
execute(self._con.cursor(), query, params)
logger.debug("Removed all associations with component {}", component.label)

def _insert_rows(self, rows: list[tuple]) -> None:
cur = self._con.cursor()
placeholder = ",".join(["?"] * len(rows[0]))
query = f"INSERT INTO {self.TABLE_NAME} VALUES({placeholder})"
try:
cur.executemany(query, rows)
finally:
self._con.commit()

@staticmethod
def _make_row(component: Component, attached_component: Component):
return (
None,
str(component.uuid),
type(component).__name__,
str(attached_component.uuid),
type(attached_component).__name__,
)
107 changes: 92 additions & 15 deletions src/infrasys/component_manager.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,19 @@
"""Manages components"""

from collections import defaultdict
import itertools
from typing import Any, Callable, Iterable, Type
from collections import defaultdict
from typing import Any, Callable, Iterable, Optional, Type
from uuid import UUID
from loguru import logger

from infrasys.component import Component
from infrasys.exceptions import ISAlreadyAttached, ISNotStored, ISOperationNotAllowed
from infrasys.component_associations import ComponentAssociations
from infrasys.exceptions import (
ISAlreadyAttached,
ISNotStored,
ISOperationNotAllowed,
ISInvalidParameter,
)
from infrasys.models import make_label, get_class_and_name_from_label


Expand All @@ -23,6 +29,7 @@ def __init__(
self._components_by_uuid: dict[UUID, Component] = {}
self._uuid = uuid
self._auto_add_composed_components = auto_add_composed_components
self._associations = ComponentAssociations()

@property
def auto_add_composed_components(self) -> bool:
Expand All @@ -34,17 +41,23 @@ def auto_add_composed_components(self, val: bool) -> None:
"""Set auto_add_composed_components."""
self._auto_add_composed_components = val

def add(self, *args: Component, deserialization_in_progress=False) -> None:
def add(self, *components: Component, deserialization_in_progress=False) -> None:
"""Add one or more components to the system.
Raises
------
ISAlreadyAttached
Raised if a component is already attached to a system.
"""
for component in args:
if not components:
msg = "add_associations requires at least one component"
raise ISInvalidParameter(msg)

for component in components:
self._add(component, deserialization_in_progress)

self._associations.add(*components)

def get(self, component_type: Type[Component], name: str) -> Any:
"""Return the component with the passed type and name.
Expand Down Expand Up @@ -115,6 +128,10 @@ def get_types(self) -> Iterable[Type[Component]]:
"""Return an iterable of all stored types."""
return self._components.keys()

def has_component(self, component) -> bool:
"""Return True if the component is attached."""
return component.uuid in self._components_by_uuid

def iter(
self, *component_types: Type[Component], filter_func: Callable | None = None
) -> Iterable[Any]:
Expand Down Expand Up @@ -167,8 +184,33 @@ def iter_all(self) -> Iterable[Any]:
"""Return an iterator over all components."""
return self._components_by_uuid.values()

def list_child_components(
self, component: Component, component_type: Optional[Type[Component]] = None
) -> list[Component]:
"""Return a list of all components that this component composes."""
return [
self.get_by_uuid(x)
for x in self._associations.list_child_components(
component, component_type=component_type
)
]

def list_parent_components(
self, component: Component, component_type: Optional[Type[Component]] = None
) -> list[Component]:
"""Return a list of all components that compose this component."""
return [
self.get_by_uuid(x)
for x in self._associations.list_parent_components(
component, component_type=component_type
)
]

def to_records(
self, component_type: Type[Component], filter_func: Callable | None = None, **kwargs
self,
component_type: Type[Component],
filter_func: Callable | None = None,
**kwargs,
) -> Iterable[dict]:
"""Return a dictionary representation of the requested components.
Expand All @@ -189,7 +231,7 @@ def to_records(
subcomponent[i] = sub_component_.label
yield data

def remove(self, component: Component) -> Any:
def remove(self, component: Component, cascade_down: bool = True, force: bool = False) -> Any:
"""Remove the component from the system and return it.
Notes
Expand All @@ -200,28 +242,54 @@ def remove(self, component: Component) -> Any:
component_type = type(component)
# The system method should have already performed the check, but for completeness in case
# someone calls it directly, check here.
if (
component_type not in self._components
or component.name not in self._components[component_type]
):
key = component.name or component.label
if component_type not in self._components or key not in self._components[component_type]:
msg = f"{component.label} is not stored"
raise ISNotStored(msg)

container = self._components[component_type][component.name]
self._check_parent_components_for_remove(component, force)
container = self._components[component_type][key]
for i, comp in enumerate(container):
if comp.uuid == component.uuid:
container.pop(i)
if not self._components[component_type][component.name]:
self._components[component_type].pop(component.name)
if not self._components[component_type][key]:
self._components[component_type].pop(key)
self._components_by_uuid.pop(component.uuid)
if not self._components[component_type]:
self._components.pop(component_type)
logger.debug("Removed component {}", component.label)
if cascade_down:
child_components = self._associations.list_child_components(component)
else:
child_components = []
self._associations.remove(component)
for child_uuid in child_components:
child = self.get_by_uuid(child_uuid)
parent_components = self.list_parent_components(child)
if not parent_components:
self.remove(child, cascade_down=cascade_down, force=force)
return

msg = f"Component {component.label} is not stored"
raise ISNotStored(msg)

def _check_parent_components_for_remove(self, component: Component, force: bool) -> None:
parent_components = self.list_parent_components(component)
if parent_components:
parent_labels = ", ".join((x.label for x in parent_components))
if force:
logger.warning(
"Remove {} even though it is attached to these components: {}",
component.label,
parent_labels,
)
else:
msg = (
f"Cannot remove {component.label} because it is attached to these components: "
f"{parent_labels}"
)
raise ISOperationNotAllowed(msg)

def copy(
self,
component: Component,
Expand Down Expand Up @@ -259,6 +327,14 @@ def change_uuid(self, component: Component) -> None:
msg = "change_component_uuid"
raise NotImplementedError(msg)

def rebuild_component_associations(self) -> None:
"""Clear the component associations and rebuild the table. This may be necessary
if a user reassigns connected components that are part of a system.
"""
self._associations.clear()
self._associations.add(*self.iter_all())
logger.info("Rebuilt all component associations.")

def update(
self,
component_type: Type[Component],
Expand Down Expand Up @@ -292,6 +368,7 @@ def _add(self, component: Component, deserialization_in_progress: bool) -> None:

self._components[cls][name].append(component)
self._components_by_uuid[component.uuid] = component

logger.debug("Added {} to the system", component.label)

def _check_component_addition(self, component: Component) -> None:
Expand All @@ -303,7 +380,7 @@ def _check_component_addition(self, component: Component) -> None:
self._handle_composed_component(val)
# Recurse.
self._check_component_addition(val)
if isinstance(val, list) and val and isinstance(val[0], Component):
elif isinstance(val, list) and val and isinstance(val[0], Component):
for item in val:
self._handle_composed_component(item)
# Recurse.
Expand Down
6 changes: 5 additions & 1 deletion src/infrasys/exceptions.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,13 +18,17 @@ class ISFileExists(ISBaseException):


class ISConflictingArguments(ISBaseException):
"""Raised if the arguments are conflict."""
"""Raised if the arguments conflict."""


class ISConflictingSystem(ISBaseException):
"""Raised if the system has conflicting values."""


class ISInvalidParameter(ISBaseException):
"""Raised if a parameter is invalid."""


class ISNotStored(ISBaseException):
"""Raised if the requested object is not stored."""

Expand Down
Loading

0 comments on commit 9dc0e3e

Please sign in to comment.