Skip to content

Commit

Permalink
[feature/PI-618-bulk_etl] Layer updates
Browse files Browse the repository at this point in the history
  • Loading branch information
jaklinger committed Dec 2, 2024
1 parent bce830f commit 44d1bf7
Show file tree
Hide file tree
Showing 17 changed files with 220 additions and 357 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
rev: v1.4.0
hooks:
- id: detect-secrets
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog"
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog|src/etl/sds/worker/bulk/transform_bulk/tests|src/etl/sds/worker/bulk/tests/stage_data"

- repo: https://github.com/prettier/pre-commit
rev: 57f39166b5a5a504d6808b87ab98d41ebf095b46
Expand Down
59 changes: 26 additions & 33 deletions src/layers/domain/api/sds/query.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,39 @@
from functools import cache
from itertools import chain, combinations

from pydantic import BaseModel, Extra, root_validator


class SearchSDSQueryParams(BaseModel):
class SearchSDSDeviceQueryParams(BaseModel, extra=Extra.forbid):
nhs_id_code: str
nhs_as_svc_ia: str
nhs_mhs_manufacturer_org: str = None
nhs_mhs_party_key: str = None

@root_validator
def client_to_id(cls, values):
    """Back-fill ``nhs_id_code`` from the legacy ``nhs_as_client`` field.

    If a caller supplied ``nhs_as_client`` but not ``nhs_id_code``, the
    client value is copied into ``nhs_id_code`` so downstream code can rely
    on a single canonical field.
    # NOTE(review): assumes a subclass declares ``nhs_as_client``; the base
    # model shown here does not — confirm against the concrete models.
    """
    nhs_as_client = values.get("nhs_as_client")
    nhs_id_code = values.get("nhs_id_code")
    # Never overwrite an explicitly provided nhs_id_code.
    if nhs_as_client and not nhs_id_code:
        values["nhs_id_code"] = nhs_as_client
    return values

def get_non_null_params(self):
    """Return only the query parameters the caller actually supplied.

    Fields left at their ``None`` defaults are dropped via pydantic's
    ``exclude_none`` so they are not treated as search filters.
    """
    return self.dict(exclude_none=True)

@classmethod
@cache
def allowed_field_combinations(cls) -> list[set[str]]:
"""
This method is used to generate all allowed combinations of search fields
for the given query parameters. Down the line this also used to generate
Device.tags in the ETL
"""
mandatory_fields, optional_fields = [], []
for field_name, field in cls.__fields__.items():
if field.required:
mandatory_fields.append(field_name)
else:
optional_fields.append(field_name)

n_minimum_optional_fields = 0 if mandatory_fields else 1
n_optional_fields = len(optional_fields)
optional_field_combinations = chain.from_iterable(
combinations(optional_fields, n_fields)
for n_fields in range(n_minimum_optional_fields, n_optional_fields + 1)
)

return [
{*mandatory_fields, *_optional_field_combination}
for _optional_field_combination in optional_field_combinations
{"nhs_id_code", "nhs_as_svc_ia"},
{"nhs_id_code", "nhs_as_svc_ia", "nhs_mhs_party_key"},
]


class SearchSDSDeviceQueryParams(SearchSDSQueryParams, extra=Extra.forbid):
nhs_as_client: str
nhs_as_svc_ia: str
nhs_mhs_manufacturer_org: str = None
nhs_mhs_party_key: str = None

class SearchSDSEndpointQueryParams(BaseModel, extra=Extra.forbid):
# can effectively achieve these without tags now
# if
# 1. nhs_id_code and nhs_mhs_party_key -> only need party key
# 2. nhs_id_code and nhs_mhs_svc_ia -> can retrieve all devices for org, and filter on nhs_mhs_svc_ia
# 3. nhs_mhs_party_key and nhs_mhs_svc_ia -> can retrieve the mhs device for product, and filter on nhs_mhs_svc_ia --> possible that can read device directly by cpaid
# 4. nhs_id_code and nhs_mhs_party_key and nhs_mhs_svc_ia -> same as 3.

class SearchSDSEndpointQueryParams(SearchSDSQueryParams, extra=Extra.forbid):
nhs_id_code: str = None
nhs_mhs_svc_ia: str = None
nhs_mhs_party_key: str = None
Expand All @@ -59,3 +49,6 @@ def check_filters(cls, values: dict):
"At least 2 query parameters should be provided of type, nhs_id_code, nhs_mhs_svc_ia and nhs_mhs_party_key"
)
return values

def get_non_null_params(self):
    """Return only the query parameters the caller actually supplied.

    Fields left at their ``None`` defaults are dropped via pydantic's
    ``exclude_none`` so they are not treated as search filters.
    """
    return self.dict(exclude_none=True)
40 changes: 0 additions & 40 deletions src/layers/domain/api/tests/test_search_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from typing import Optional

import pytest
from domain.api.sds.query import (
SearchSDSDeviceQueryParams,
SearchSDSEndpointQueryParams,
SearchSDSQueryParams,
)
from pydantic import ValidationError

Expand Down Expand Up @@ -97,40 +94,3 @@ def test_endpoint_query_accepted(params):
def test_endpoint_query_invalid(params):
with pytest.raises(ValidationError):
search = SearchSDSEndpointQueryParams(**params)


def test_allowed_field_combinations():
    """One mandatory field: expect that field plus every optional subset."""

    class ExampleParams(SearchSDSQueryParams):
        foo: str
        bar: Optional[str]
        bob: Optional[str]

    expected = [
        {"foo"},
        {"foo", "bar"},
        {"foo", "bob"},
        {"foo", "bar", "bob"},
    ]
    assert ExampleParams.allowed_field_combinations() == expected


def test_allowed_field_combinations_all_optional():
    """No mandatory fields: expect every non-empty subset of the optionals."""

    class ExampleParams(SearchSDSQueryParams):
        foo: Optional[str]
        bar: Optional[str]
        bob: Optional[str]

    expected = [
        {"foo"},
        {"bar"},
        {"bob"},
        {"foo", "bar"},
        {"foo", "bob"},
        {"bar", "bob"},
        {"foo", "bar", "bob"},
    ]
    assert ExampleParams.allowed_field_combinations() == expected
4 changes: 0 additions & 4 deletions src/layers/domain/core/device/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ class DeviceTagsClearedEvent(Event):

@dataclass(kw_only=True, slots=True)
class QuestionnaireResponseUpdatedEvent(Event):
"""
This is adding the initial questionnaire response from the event body request.
"""

id: str
questionnaire_responses: dict[str, list[QuestionnaireResponse]]
keys: list[DeviceKey]
Expand Down
2 changes: 1 addition & 1 deletion src/layers/domain/core/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class AccreditedSystem:
ID_PATTERN = re.compile(rf"^[a-zA-Z-0-9]+$")

class PartyKey:
PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{6}}$"
PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{6,9}}$"
ID_PATTERN = re.compile(PARTY_KEY_REGEX)

class CpaId:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"nhs_as_acf": "AS ACF",
"nhs_temp_uid": "Temp UID",
"description": "Description",
"nhs_as_category_bag": "AS Category Bag"
"nhs_as_category_bag": "Category Bag"
}
2 changes: 1 addition & 1 deletion src/layers/etl_utils/ldif/_ldif.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __next_key_and_value(self):
# All values should be valid ascii; we support UTF-8 as a
# non-official, backwards compatibility layer.
attr_value = unfolded_line[colon_pos + 1 :].encode("utf-8")
return attr_type.lower(), attr_value
return attr_type.lower(), attr_value.strip()

def _consume_empty_lines(self):
"""
Expand Down
13 changes: 12 additions & 1 deletion src/layers/etl_utils/ldif/ldif.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from base64 import b64decode
from collections import defaultdict
from io import BytesIO
from types import FunctionType
Expand Down Expand Up @@ -109,11 +110,16 @@ def __init__(self, group_field: str, filter_terms: list[tuple[str, str]]):
for key, value in filter_terms
]
self.group_filter = re.compile(rf"(?i)^({group_field}): (.*)\n$".encode()).match
self.group_filter_base_64 = re.compile(
rf"(?i)^({group_field}):: (.*)\n$".encode()
).match
self.reset()

def flush(self) -> str:
if self.group is None:
raise Exception
raise ValueError(
f"No group name assigned to the following group:\n{self.buffer}"
)
self.data[self.group].write(self.buffer)
self.reset()

Expand All @@ -126,6 +132,11 @@ def parse(self, line: bytes):
group_match = self.group_filter(line)
if group_match:
(_, self.group) = group_match.groups()
else:
b64_group_match = self.group_filter_base_64(line)
if b64_group_match:
(_, b64_group) = b64_group_match.groups()
self.group = b64decode(b64_group).strip()

if not self.keep and any(filter(line) for filter in self.filters):
self.keep = True
Expand Down
67 changes: 67 additions & 0 deletions src/layers/etl_utils/ldif/tests/test_ldif.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,54 @@
myOtherField: 123
"""

LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64 = """
dn: uniqueIdentifier=AAA1
myField:: QUFB
myOtherField: 123
dn: uniqueIdentifier=BBB1
myfield: BBB
myOtherField: 123
dn: uniqueIdentifier=BBB2
myfield: BBB
myOtherField: 123
dn: uniqueIdentifier=AAA2
myfield: AAA
myOtherField: 123
dn: uniqueIdentifier=AAA3
myField: AAA
myOtherField: 234
dn: uniqueIdentifier=BBB3
myfield:: QkJC
myOtherField: 123
"""

FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE = """
dn: uniqueIdentifier=AAA1
myField:: QUFB
myOtherField: 123
dn: uniqueIdentifier=AAA2
myfield: AAA
myOtherField: 123
dn: uniqueIdentifier=BBB1
myfield: BBB
myOtherField: 123
dn: uniqueIdentifier=BBB2
myfield: BBB
myOtherField: 123
dn: uniqueIdentifier=BBB3
myfield:: QkJC
myOtherField: 123
"""


@pytest.mark.parametrize(
("raw_distinguished_name", "parsed_distinguished_name"),
Expand Down Expand Up @@ -390,6 +438,25 @@ def test_filter_and_group_ldif_from_s3_by_property(mocked_open):
)


@mock.patch(
"etl_utils.ldif.ldif._smart_open",
return_value=BytesIO(LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64.encode()),
)
def test_filter_and_group_ldif_from_s3_by_property_with_b64encoded_group(mocked_open):
with mock_aws():
s3_client = boto3.client("s3")
filtered_ldif = filter_and_group_ldif_from_s3_by_property(
s3_client=s3_client,
s3_path="s3://dummy_bucket/dummy_key",
group_field="myField",
filter_terms=[("myOtherField", "123")],
)
assert (
"".join(data.tobytes().decode() for data in filtered_ldif)
== FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE
)


@pytest.mark.parametrize(
["raw_ldif", "parsed_ldif"],
[
Expand Down
14 changes: 12 additions & 2 deletions src/layers/etl_utils/worker/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
TRUNCATED = "[TRUNCATED]\n"


def truncate_message(
    item: str, truncation_depth=TRUNCATION_DEPTH, truncated_note=TRUNCATED
):
    """Return *item* clipped to at most *truncation_depth* characters.

    Strings within the limit are returned unchanged; longer strings are cut
    at the limit and suffixed with *truncated_note* to flag the truncation.
    """
    if len(item) <= truncation_depth:
        return item
    return item[:truncation_depth] + truncated_note


def _render_exception(exception: Exception) -> str:
"""Concatenates an exception with its notes"""
_notes = exception.__dict__.get("__notes__", [])
Expand Down Expand Up @@ -92,9 +100,11 @@ def render_exception(

_traceback = "".join(TracebackException.from_exception(exception).format())
if truncation_depth is not None and len(formatted_exception) > truncation_depth:
formatted_exception = formatted_exception[:truncation_depth] + TRUNCATED
formatted_exception = truncate_message(
formatted_exception, truncation_depth=truncation_depth
)

if truncation_depth is not None and len(_traceback) > truncation_depth:
_traceback = _traceback[:truncation_depth] + TRUNCATED
_traceback = truncate_message(_traceback, truncation_depth=truncation_depth)

return f"{indentation}{formatted_exception}\n{_traceback}\n"
2 changes: 1 addition & 1 deletion src/layers/sds/domain/nhs_accredited_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class NhsAccreditedSystem(SdsBaseModel):
nhs_product_name: Optional[str] = Field(alias="nhsproductname")
nhs_product_version: Optional[str] = Field(alias="nhsproductversion")
nhs_as_acf: Optional[set[str]] = Field(alias="nhsasacf")
nhs_as_client: Optional[set[str]] = Field(alias="nhsasclient")
nhs_as_client: Optional[set[str]] = Field(alias="nhsasclient", default_factory=set)
nhs_as_svc_ia: set[str] = Field(alias="nhsassvcia")
nhs_temp_uid: Optional[str] = Field(alias="nhstempuid")
description: Optional[str] = Field(alias="description")
Expand Down
Loading

0 comments on commit 44d1bf7

Please sign in to comment.