Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[PI-618] Bulk ETL (part 1 - Layers) #421

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@ repos:
rev: v1.4.0
hooks:
- id: detect-secrets
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog"
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog|src/etl/sds/worker/bulk/transform_bulk/tests|src/etl/sds/worker/bulk/tests/stage_data"

- repo: https://github.com/prettier/pre-commit
rev: 57f39166b5a5a504d6808b87ab98d41ebf095b46
Expand Down
59 changes: 26 additions & 33 deletions src/layers/domain/api/sds/query.py
Original file line number Diff line number Diff line change
@@ -1,49 +1,39 @@
from functools import cache
from itertools import chain, combinations

from pydantic import BaseModel, Extra, root_validator


class SearchSDSQueryParams(BaseModel):
class SearchSDSDeviceQueryParams(BaseModel, extra=Extra.forbid):
nhs_id_code: str
nhs_as_svc_ia: str
nhs_mhs_manufacturer_org: str = None
nhs_mhs_party_key: str = None

@root_validator(pre=True)
def client_to_id(cls, values: dict):
nhs_as_client = values.pop("nhs_as_client", None)
nhs_id_code = values.get("nhs_id_code")
if nhs_as_client and not nhs_id_code:
values["nhs_id_code"] = nhs_as_client
return values

def get_non_null_params(self):
return self.dict(exclude_none=True)

@classmethod
@cache
def allowed_field_combinations(cls) -> list[set[str]]:
"""
This method is used to generate all allowed combinations of search fields
for the given query parameters. Down the line this is also used to generate
Device.tags in the ETL
"""
mandatory_fields, optional_fields = [], []
for field_name, field in cls.__fields__.items():
if field.required:
mandatory_fields.append(field_name)
else:
optional_fields.append(field_name)

n_minimum_optional_fields = 0 if mandatory_fields else 1
n_optional_fields = len(optional_fields)
optional_field_combinations = chain.from_iterable(
combinations(optional_fields, n_fields)
for n_fields in range(n_minimum_optional_fields, n_optional_fields + 1)
)

return [
{*mandatory_fields, *_optional_field_combination}
for _optional_field_combination in optional_field_combinations
{"nhs_id_code", "nhs_as_svc_ia"},
{"nhs_id_code", "nhs_as_svc_ia", "nhs_mhs_party_key"},
]


class SearchSDSDeviceQueryParams(SearchSDSQueryParams, extra=Extra.forbid):
nhs_as_client: str
nhs_as_svc_ia: str
nhs_mhs_manufacturer_org: str = None
nhs_mhs_party_key: str = None

class SearchSDSEndpointQueryParams(BaseModel, extra=Extra.forbid):
# These searches can now effectively be achieved without tags, when given:
# 1. nhs_id_code and nhs_mhs_party_key -> only need party key
# 2. nhs_id_code and nhs_mhs_svc_ia -> can retrieve all devices for org, and filter on nhs_mhs_svc_ia
# 3. nhs_mhs_party_key and nhs_mhs_svc_ia -> can retrieve the mhs device for product, and filter on nhs_mhs_svc_ia --> it is possible that the device can be read directly by cpaid
# 4. nhs_id_code and nhs_mhs_party_key and nhs_mhs_svc_ia -> same as 3.

jaklinger marked this conversation as resolved.
Show resolved Hide resolved
class SearchSDSEndpointQueryParams(SearchSDSQueryParams, extra=Extra.forbid):
nhs_id_code: str = None
nhs_mhs_svc_ia: str = None
nhs_mhs_party_key: str = None
Expand All @@ -59,3 +49,6 @@ def check_filters(cls, values: dict):
"At least 2 query parameters should be provided of type, nhs_id_code, nhs_mhs_svc_ia and nhs_mhs_party_key"
)
return values

def get_non_null_params(self):
return self.dict(exclude_none=True)
40 changes: 0 additions & 40 deletions src/layers/domain/api/tests/test_search_query.py
Original file line number Diff line number Diff line change
@@ -1,10 +1,7 @@
from typing import Optional

import pytest
from domain.api.sds.query import (
SearchSDSDeviceQueryParams,
SearchSDSEndpointQueryParams,
SearchSDSQueryParams,
)
from pydantic import ValidationError

Expand Down Expand Up @@ -97,40 +94,3 @@ def test_endpoint_query_accepted(params):
def test_endpoint_query_invalid(params):
with pytest.raises(ValidationError):
search = SearchSDSEndpointQueryParams(**params)


def test_allowed_field_combinations():
    """
    With one mandatory field, the allowed combinations are the mandatory field
    on its own plus the mandatory field with every subset of the optional ones.
    """

    class MyModel(SearchSDSQueryParams):
        foo: str
        bar: Optional[str]
        bob: Optional[str]

    expected_combinations = [
        {"foo"},
        {"foo", "bar"},
        {"foo", "bob"},
        {"foo", "bar", "bob"},
    ]
    assert MyModel.allowed_field_combinations() == expected_combinations


def test_allowed_field_combinations_all_optional():
    """
    With no mandatory fields, the allowed combinations are every non-empty
    subset of the optional fields, ordered by increasing size.
    """

    class MyModel(SearchSDSQueryParams):
        foo: Optional[str]
        bar: Optional[str]
        bob: Optional[str]

    expected_combinations = [
        {"foo"},
        {"bar"},
        {"bob"},
        {"foo", "bar"},
        {"foo", "bob"},
        {"bar", "bob"},
        {"foo", "bar", "bob"},
    ]
    assert MyModel.allowed_field_combinations() == expected_combinations
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ def test_party_key_generator_validate_key_valid():
[
"ABC000124", # Missing hyphen
"ABC-1234", # Number part too short
"ABC-1234567", # Number part too long
"ABC-123456789101112", # Number part too long
"ABC-0001A4", # Number part contains a non-digit character
"", # Empty string
],
Expand Down
4 changes: 0 additions & 4 deletions src/layers/domain/core/device/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -159,10 +159,6 @@ class DeviceTagsClearedEvent(Event):

@dataclass(kw_only=True, slots=True)
class QuestionnaireResponseUpdatedEvent(Event):
"""
This is adding the initial questionnaire response from the event body request.
"""

id: str
questionnaire_responses: dict[str, list[QuestionnaireResponse]]
keys: list[DeviceKey]
Expand Down
2 changes: 1 addition & 1 deletion src/layers/domain/core/validation.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,7 @@ class AccreditedSystem:
ID_PATTERN = re.compile(rf"^[a-zA-Z-0-9]+$")

class PartyKey:
PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{6}}$"
PARTY_KEY_REGEX = rf"^{_ODS_CODE_REGEX}-[0-9]{{6,9}}$"
jaklinger marked this conversation as resolved.
Show resolved Hide resolved
ID_PATTERN = re.compile(PARTY_KEY_REGEX)

class CpaId:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,5 +14,5 @@
"nhs_as_acf": "AS ACF",
"nhs_temp_uid": "Temp UID",
"description": "Description",
"nhs_as_category_bag": "AS Category Bag"
"nhs_as_category_bag": "Category Bag"
}
2 changes: 1 addition & 1 deletion src/layers/etl_utils/ldif/_ldif.py
Original file line number Diff line number Diff line change
Expand Up @@ -203,7 +203,7 @@ def __next_key_and_value(self):
# All values should be valid ascii; we support UTF-8 as a
# non-official, backwards compatibility layer.
attr_value = unfolded_line[colon_pos + 1 :].encode("utf-8")
return attr_type.lower(), attr_value
return attr_type.lower(), attr_value.strip()

def _consume_empty_lines(self):
"""
Expand Down
13 changes: 12 additions & 1 deletion src/layers/etl_utils/ldif/ldif.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
import re
from base64 import b64decode
from collections import defaultdict
from io import BytesIO
from types import FunctionType
Expand Down Expand Up @@ -109,11 +110,16 @@ def __init__(self, group_field: str, filter_terms: list[tuple[str, str]]):
for key, value in filter_terms
]
self.group_filter = re.compile(rf"(?i)^({group_field}): (.*)\n$".encode()).match
self.group_filter_base_64 = re.compile(
rf"(?i)^({group_field}):: (.*)\n$".encode()
).match
self.reset()

def flush(self) -> str:
if self.group is None:
raise Exception
raise ValueError(
f"No group name assigned to the following group:\n{self.buffer}"
)
self.data[self.group].write(self.buffer)
self.reset()

Expand All @@ -126,6 +132,11 @@ def parse(self, line: bytes):
group_match = self.group_filter(line)
if group_match:
(_, self.group) = group_match.groups()
else:
b64_group_match = self.group_filter_base_64(line)
if b64_group_match:
(_, b64_group) = b64_group_match.groups()
self.group = b64decode(b64_group).strip()

if not self.keep and any(filter(line) for filter in self.filters):
self.keep = True
Expand Down
67 changes: 67 additions & 0 deletions src/layers/etl_utils/ldif/tests/test_ldif.py
Original file line number Diff line number Diff line change
Expand Up @@ -296,6 +296,54 @@
myOtherField: 123
"""

# LDIF input fixture in which the group field of the first and last records is
# written in base64 form ("::"): QUFB is base64 for AAA and QkJC for BBB. The
# remaining records carry the same values as plain text, with mixed-case
# spellings of the field name (myField / myfield).
LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64 = """
dn: uniqueIdentifier=AAA1
myField:: QUFB
myOtherField: 123

dn: uniqueIdentifier=BBB1
myfield: BBB
myOtherField: 123

dn: uniqueIdentifier=BBB2
myfield: BBB
myOtherField: 123

dn: uniqueIdentifier=AAA2
myfield: AAA
myOtherField: 123

dn: uniqueIdentifier=AAA3
myField: AAA
myOtherField: 234

dn: uniqueIdentifier=BBB3
myfield:: QkJC
myOtherField: 123
"""

# Expected output for LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64: the records are
# regrouped so that all AAA records precede all BBB records (base64-encoded
# group values are kept verbatim in the output), and record AAA3, whose
# myOtherField is 234 rather than 123, is absent.
FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64 = """
dn: uniqueIdentifier=AAA1
myField:: QUFB
myOtherField: 123

dn: uniqueIdentifier=AAA2
myfield: AAA
myOtherField: 123

dn: uniqueIdentifier=BBB1
myfield: BBB
myOtherField: 123

dn: uniqueIdentifier=BBB2
myfield: BBB
myOtherField: 123

dn: uniqueIdentifier=BBB3
myfield:: QkJC
myOtherField: 123
"""


@pytest.mark.parametrize(
("raw_distinguished_name", "parsed_distinguished_name"),
Expand Down Expand Up @@ -390,6 +438,25 @@ def test_filter_and_group_ldif_from_s3_by_property(mocked_open):
)


@mock.patch(
    "etl_utils.ldif.ldif._smart_open",
    return_value=BytesIO(LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64.encode()),
)
def test_filter_and_group_ldif_from_s3_by_property_with_b64encoded_group(mocked_open):
    """
    Records whose group field is base64-encoded ("myField:: ...") must end up
    in the same group as records carrying the equivalent plain-text value.
    """
    with mock_aws():
        grouped_chunks = filter_and_group_ldif_from_s3_by_property(
            s3_client=boto3.client("s3"),
            s3_path="s3://dummy_bucket/dummy_key",
            group_field="myField",
            filter_terms=[("myOtherField", "123")],
        )
        # Stitch the returned buffers back into one LDIF document to compare
        rendered_ldif = "".join(chunk.tobytes().decode() for chunk in grouped_chunks)
        assert (
            rendered_ldif
            == FILTERED_AND_GROUPED_LDIF_TO_FILTER_AND_GROUP_EXAMPLE_BASE_64
        )


@pytest.mark.parametrize(
["raw_ldif", "parsed_ldif"],
[
Expand Down
14 changes: 12 additions & 2 deletions src/layers/etl_utils/worker/exception.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,14 @@
TRUNCATED = "[TRUNCATED]\n"


def truncate_message(
    item: str, truncation_depth=TRUNCATION_DEPTH, truncated_note=TRUNCATED
):
    """
    Cap 'item' at 'truncation_depth' characters.

    When 'item' is longer than 'truncation_depth' it is cut to that length and
    'truncated_note' is appended; otherwise 'item' is returned unchanged.
    """
    exceeds_depth = len(item) > truncation_depth
    return (item[:truncation_depth] + truncated_note) if exceeds_depth else item


def _render_exception(exception: Exception) -> str:
"""Concatenates an exception with its notes"""
_notes = exception.__dict__.get("__notes__", [])
Expand Down Expand Up @@ -92,9 +100,11 @@ def render_exception(

_traceback = "".join(TracebackException.from_exception(exception).format())
if truncation_depth is not None and len(formatted_exception) > truncation_depth:
formatted_exception = formatted_exception[:truncation_depth] + TRUNCATED
formatted_exception = truncate_message(
formatted_exception, truncation_depth=truncation_depth
)

if truncation_depth is not None and len(_traceback) > truncation_depth:
_traceback = _traceback[:truncation_depth] + TRUNCATED
_traceback = truncate_message(_traceback, truncation_depth=truncation_depth)

return f"{indentation}{formatted_exception}\n{_traceback}\n"
2 changes: 1 addition & 1 deletion src/layers/sds/domain/nhs_accredited_system.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ class NhsAccreditedSystem(SdsBaseModel):
nhs_product_name: Optional[str] = Field(alias="nhsproductname")
nhs_product_version: Optional[str] = Field(alias="nhsproductversion")
nhs_as_acf: Optional[set[str]] = Field(alias="nhsasacf")
nhs_as_client: Optional[set[str]] = Field(alias="nhsasclient")
nhs_as_client: Optional[set[str]] = Field(alias="nhsasclient", default_factory=set)
jaklinger marked this conversation as resolved.
Show resolved Hide resolved
nhs_as_svc_ia: set[str] = Field(alias="nhsassvcia")
nhs_temp_uid: Optional[str] = Field(alias="nhstempuid")
description: Optional[str] = Field(alias="description")
Expand Down
Loading
Loading