[feature/PI-618-bulk_etl_e2e] load fanout and tag rejig
jaklinger committed Dec 2, 2024
1 parent a90825b commit 0811a36
Showing 41 changed files with 5,990 additions and 469 deletions.
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
@@ -4,7 +4,7 @@ repos:
rev: v1.4.0
hooks:
- id: detect-secrets
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog"
exclude: ".pre-commit-config.yaml|infrastructure/localstack/provider.tf|src/etl/sds/tests/changelog|src/etl/sds/worker/bulk/transform_bulk/tests|src/etl/sds/worker/bulk/tests/stage_data"

- repo: https://github.com/prettier/pre-commit
rev: 57f39166b5a5a504d6808b87ab98d41ebf095b46
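The extra branches in the detect-secrets `exclude` stop the hook from scanning the new bulk test fixtures. A small illustration of how the exclude pattern behaves, assuming pre-commit's usual treatment of `exclude` as a single Python regular expression matched against each file path:

```python
import re

# Assumption: pre-commit applies the `exclude` option as a Python regex
# (re.search-style) against each candidate file path.
EXCLUDE = re.compile(
    ".pre-commit-config.yaml"
    "|infrastructure/localstack/provider.tf"
    "|src/etl/sds/tests/changelog"
    "|src/etl/sds/worker/bulk/transform_bulk/tests"
    "|src/etl/sds/worker/bulk/tests/stage_data"
)

assert EXCLUDE.search("src/etl/sds/worker/bulk/tests/stage_data/1.extract_output.json")
assert not EXCLUDE.search("src/etl/sds/worker/bulk/extract_bulk/extract_bulk.py")
```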
3 changes: 1 addition & 2 deletions src/etl/sds/worker/bulk/extract_bulk/extract_bulk.py
@@ -46,8 +46,7 @@ def extract(
s3_client: "S3Client", s3_input_path: str, s3_output_path: str, max_records: int
) -> WorkerActionResponse:
unprocessed_records = _read(s3_client=s3_client, s3_input_path=s3_input_path)

processed_records = []
processed_records = deque([])

exception = apply_action(
unprocessed_records=unprocessed_records,
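For context on the `deque` change above: `apply_action` (from `etl_utils.worker.action`, also used by the new fan-out worker below) is not shown in this commit. A minimal sketch of the contract it appears to follow, stated as an assumption rather than the actual implementation:

```python
from collections import deque
from typing import Callable


def apply_action(
    unprocessed_records: deque,
    processed_records: deque,
    action: Callable,
    max_records: int | None = None,
) -> Exception | None:
    """Assumed contract: pop records from the left of `unprocessed_records`,
    apply `action`, and append each result to `processed_records`; return the
    first exception instead of raising so the worker can report it."""
    count = 0
    while unprocessed_records and (max_records is None or count < max_records):
        record = unprocessed_records.popleft()
        try:
            processed_records.append(action(record))
        except Exception as exception:
            # Assumed: put the failing record back so it is counted as unprocessed.
            unprocessed_records.appendleft(record)
            return exception
        count += 1
    return None
```

Under that reading, using a `deque` for `processed_records` keeps both sides of the call in the same container type that the workers already pickle between stages (`pkl_dumps_lz4(deque(...))` in the tests).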
@@ -1,6 +1,5 @@
import json
import os
from pathlib import Path
from typing import Callable
from unittest import mock

@@ -13,8 +12,9 @@
from moto import mock_aws
from mypy_boto3_s3 import S3Client

from etl.sds.worker.bulk.tests.test_bulk_e2e import PATH_TO_STAGE_DATA

BUCKET_NAME = "my-bucket"
PATH_TO_HERE = Path(__file__).parent


@pytest.fixture
@@ -55,7 +55,7 @@ def test_extract_worker_pass(
from etl.sds.worker.bulk.extract_bulk import extract_bulk

# Initial state
with open(PATH_TO_HERE / "extract_bulk_input.ldif") as f:
with open(PATH_TO_STAGE_DATA / "0.extract_input.ldif") as f:
input_data = f.read()

put_object(key=WorkerKey.EXTRACT, body=input_data)
@@ -75,7 +75,7 @@
output_data = get_object(key=WorkerKey.TRANSFORM)
n_final_unprocessed = len(_split_ldif(final_unprocessed_data))

with open(PATH_TO_HERE / "extract_bulk_output.json") as f:
with open(PATH_TO_STAGE_DATA / "1.extract_output.json") as f:
expected_output_data = json_load(f)

assert (
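Both worker tests now read their fixtures from a shared stage-data directory exported by the end-to-end test module. The definition of `PATH_TO_STAGE_DATA` is not part of this diff; judging by the import and by the new detect-secrets exclude (`src/etl/sds/worker/bulk/tests/stage_data`), it is presumably along these lines:

```python
# Assumed definition in etl/sds/worker/bulk/tests/test_bulk_e2e.py (not shown in this commit)
from pathlib import Path

PATH_TO_STAGE_DATA = Path(__file__).parent / "stage_data"
```

The numeric prefixes on the fixtures (`0.extract_input.ldif`, `1.extract_output.json`, `2.transform_output.json`, `3.load_fanout_output.{i}.json`) chain each stage's expected output into the next stage's input, so every worker test exercises the same end-to-end data set.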
100 changes: 77 additions & 23 deletions src/etl/sds/worker/bulk/load_bulk/tests/test_load_bulk_worker.py
@@ -1,6 +1,7 @@
import os
from collections import deque
from itertools import chain
from pathlib import Path
from typing import Callable
from unittest import mock

@@ -21,15 +22,15 @@
from event.json import json_load
from moto import mock_aws
from mypy_boto3_s3 import S3Client
from sds.epr.bulk_create.bulk_load_fanout import FANOUT
from sds.epr.constants import AS_DEVICE_SUFFIX, MHS_DEVICE_SUFFIX

from etl.sds.worker.bulk.transform_bulk.tests.test_transform_bulk_worker import (
PATH_TO_HERE as PATH_TO_TRANSFORM_OUTPUT,
)
from etl.sds.worker.bulk.tests.test_bulk_e2e import PATH_TO_STAGE_DATA
from test_helpers.dynamodb import mock_table

BUCKET_NAME = "my-bucket"
TABLE_NAME = "my-table"
PATH_TO_HERE = Path(__file__).parent


@pytest.fixture
@@ -69,33 +70,48 @@ def test_load_worker_pass(
from etl.sds.worker.bulk.load_bulk import load_bulk

# Initial state
with open(PATH_TO_TRANSFORM_OUTPUT / "transform_bulk_output.json") as f:
input_data = json_load(f)
for i in range(FANOUT):
with open(PATH_TO_STAGE_DATA / f"3.load_fanout_output.{i}.json") as f:
input_data = json_load(f)

put_object(
key=WorkerKey.LOAD,
body=pkl_dumps_lz4(deque(input_data)),
)
put_object(
key=f"{WorkerKey.LOAD}.{i}",
body=pkl_dumps_lz4(deque(input_data)),
)

# Execute the load worker
with mock_table(TABLE_NAME) as dynamodb_client:
load_bulk.CACHE.REPOSITORY.client = dynamodb_client
response = load_bulk.handler(
event={"s3_input_path": f"s3://{BUCKET_NAME}/{WorkerKey.LOAD}"},
context=None,
)

assert response == {
"stage_name": "load",
"processed_records": len(input_data),
"unprocessed_records": 0,
"error_message": None,
}

# Final state
final_unprocessed_data = pkl_loads_lz4(get_object(key=WorkerKey.LOAD))
assert final_unprocessed_data == deque([])
responses = []
for i in range(FANOUT):
response = load_bulk.handler(
event={"s3_input_path": f"s3://{BUCKET_NAME}/{WorkerKey.LOAD}.{i}"},
context=None,
)
responses.append(response)

# Final state
final_unprocessed_data = pkl_loads_lz4(
get_object(key=f"{WorkerKey.LOAD}.{i}")
)
assert final_unprocessed_data == deque([])

assert responses == [
{
"stage_name": "load",
"processed_records": 10,
"unprocessed_records": 0,
"error_message": None,
}
] * (FANOUT - 1) + [
{
"stage_name": "load",
"processed_records": 7,
"unprocessed_records": 0,
"error_message": None,
}
]
product_team_repo = ProductTeamRepository(
table_name=TABLE_NAME, dynamodb_client=dynamodb_client
)
@@ -149,6 +165,9 @@ def test_load_worker_pass(
)
assert len(device_ref_datas) == 4

with open(PATH_TO_STAGE_DATA / "2.transform_output.json") as f:
input_data: list[dict[str, dict]] = json_load(f)

(input_product_team,) = (
ProductTeam(**data)
for item in input_data
@@ -193,3 +212,38 @@
assert products == input_products
assert devices == input_devices
assert device_ref_datas == input_device_ref_data


# @pytest.mark.integration
# @pytest.mark.parametrize("path", (PATH_TO_HERE / "edge_cases").iterdir())
# def test_load_worker_edge_cases(path: str):

# s3_client = boto3.client("s3")
# lambda_client = boto3.client("lambda")
# bucket = read_terraform_output("sds_etl.value.bucket")
# lambda_name = read_terraform_output("sds_etl.value.bulk_load_lambda_arn")

# with open(path) as f:
# s3_client.put_object(
# Key=f"{WorkerKey.LOAD}.0",
# Bucket=bucket,
# Body=pkl_dumps_lz4(deque(json_load(f))),
# )

# response = lambda_client.invoke(
# FunctionName=lambda_name,
# Payload=json.dumps({"s3_input_path": f"s3://{bucket}/{WorkerKey.LOAD}.0"}),
# )
# assert json_load(response["Payload"]) == {
# "stage_name": "load",
# "processed_records": 10,
# "unprocessed_records": 0,
# "error_message": None,
# }

# # Final state
# _final_unprocessed_data = s3_client.get_object(
# Key=f"{WorkerKey.LOAD}.0", Bucket=bucket
# )["Body"]
# final_unprocessed_data = pkl_load_lz4(_final_unprocessed_data)
# assert final_unprocessed_data == deque([])
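The `assert responses == [...]` expectation in `test_load_worker_pass` above — `FANOUT - 1` shards of 10 records plus a final shard of 7 — is consistent with ceil-division batching over the staged records. The concrete numbers below (`FANOUT = 13`, 127 records) are illustrative assumptions; the real `FANOUT` constant and `calculate_batch_size` live in `sds.epr.bulk_create.bulk_load_fanout` and are not shown in this diff:

```python
from itertools import batched  # Python 3.12+
from math import ceil

FANOUT = 13                  # assumed value, for illustration only
records = list(range(127))   # assumed number of records staged by the fan-out worker

# Assumed batching rule: the smallest batch size that yields at most FANOUT batches.
batch_size = ceil(len(records) / FANOUT)

shard_sizes = [len(batch) for batch in batched(records, batch_size)]
assert shard_sizes == [10] * (FANOUT - 1) + [7]
```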
Empty file.
155 changes: 155 additions & 0 deletions src/etl/sds/worker/bulk/load_bulk_fanout/load_bulk_fanout.py
@@ -0,0 +1,155 @@
from collections import deque
from dataclasses import asdict
from itertools import batched
from types import FunctionType
from typing import TYPE_CHECKING

import boto3
from etl_utils.constants import WorkerKey
from etl_utils.io import pkl_dump_lz4, pkl_load_lz4
from etl_utils.smart_open import smart_open
from etl_utils.worker.action import apply_action
from etl_utils.worker.model import WorkerActionResponse
from etl_utils.worker.steps import (
execute_action,
save_processed_records,
save_unprocessed_records,
)
from etl_utils.worker.worker_step_chain import (
_log_action_without_inputs,
_render_response,
log_exception,
)
from event.environment import BaseEnvironment
from event.step_chain import StepChain
from sds.epr.bulk_create.bulk_load_fanout import FANOUT, calculate_batch_size
from sds.epr.bulk_create.bulk_repository import BulkRepository

if TYPE_CHECKING:
from mypy_boto3_s3 import S3Client


class TransformWorkerEnvironment(BaseEnvironment):
ETL_BUCKET: str

def s3_path(self, key) -> str:
return f"s3://{self.ETL_BUCKET}/{key}"


ENVIRONMENT = TransformWorkerEnvironment.build()
S3_CLIENT = boto3.client("s3")


def execute_step_chain(
action: FunctionType,
s3_client,
s3_input_path: str,
s3_output_path: str,
unprocessed_dumper: FunctionType,
processed_dumper: FunctionType,
max_records: int = None,
**kwargs,
) -> list[dict]:
# Run the main action chain
action_chain = StepChain(
step_chain=[execute_action], step_decorators=[_log_action_without_inputs]
)
action_chain.run(
init=(action, s3_client, s3_input_path, s3_output_path, max_records, kwargs)
)
if isinstance(action_chain.result, Exception):
log_exception(action_chain.result)

# Save the action chain results if there were no unhandled (fatal) exceptions
count_unprocessed_records = None
count_processed_records = None
save_chain_response = None

worker_response = []
if isinstance(action_chain.result, WorkerActionResponse):
if isinstance(action_chain.result.exception, Exception):
log_exception(action_chain.result.exception)

count_unprocessed_records = len(action_chain.result.unprocessed_records)

batch_size = calculate_batch_size(
sequence=action_chain.result.processed_records, n_batches=FANOUT
)
for i, batch in enumerate(
batched(action_chain.result.processed_records, n=batch_size)
):
count_processed_records = len(batch)
_action_response = WorkerActionResponse(
unprocessed_records=action_chain.result.unprocessed_records,
s3_input_path=action_chain.result.s3_input_path,
processed_records=deque(batch),
s3_output_path=ENVIRONMENT.s3_path(f"{WorkerKey.LOAD}.{i}"),
exception=action_chain.result.exception,
)

save_chain = StepChain(
step_chain=[save_unprocessed_records, save_processed_records],
step_decorators=[_log_action_without_inputs],
)
save_chain.run(
init=(
_action_response,
s3_client,
unprocessed_dumper,
processed_dumper,
)
)
save_chain_response = save_chain.result

if isinstance(save_chain.result, Exception):
log_exception(save_chain.result)

# Summarise the outcome of action_chain and step_chain
worker_response_item = _render_response(
action_name=action.__name__,
action_chain_response=_action_response,
save_chain_response=save_chain_response,
count_unprocessed_records=count_unprocessed_records,
count_processed_records=count_processed_records,
)
_response = asdict(worker_response_item)
_response["s3_input_path"] = ENVIRONMENT.s3_path(f"{WorkerKey.LOAD}.{i}")
worker_response.append(_response)
return worker_response


def load_bulk_fanout(
s3_client: "S3Client", s3_input_path: str, s3_output_path: str, max_records: int
) -> WorkerActionResponse:
with smart_open(s3_path=s3_input_path, s3_client=s3_client) as f:
unprocessed_records: deque[dict] = pkl_load_lz4(f)
processed_records = deque()

exception = apply_action(
unprocessed_records=unprocessed_records,
processed_records=processed_records,
action=lambda record: BulkRepository(
None, None
).generate_transaction_statements(record),
max_records=max_records,
)

return WorkerActionResponse(
unprocessed_records=unprocessed_records,
processed_records=processed_records,
s3_input_path=s3_input_path,
s3_output_path=s3_output_path,
exception=exception,
)


def handler(event: dict, context):
return execute_step_chain(
action=load_bulk_fanout,
s3_client=S3_CLIENT,
s3_input_path=ENVIRONMENT.s3_path(WorkerKey.LOAD),
s3_output_path=None,
unprocessed_dumper=pkl_dump_lz4,
processed_dumper=pkl_dump_lz4,
max_records=None,
)
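Each item in the fan-out handler's response carries the `s3_input_path` of one shard (`{WorkerKey.LOAD}.{i}`). A minimal sketch of how a caller might then drive the existing load worker, mirroring what `test_load_bulk_worker.py` does above; the real orchestration layer is not part of this diff and this driver is hypothetical:

```python
# Hypothetical driver, assuming the load worker accepts the same event shape
# exercised in test_load_bulk_worker.py.
from etl.sds.worker.bulk.load_bulk import load_bulk
from etl.sds.worker.bulk.load_bulk_fanout import load_bulk_fanout

fanout_responses = load_bulk_fanout.handler(event={}, context=None)

for shard in fanout_responses:
    # Each shard response points at s3://<ETL_BUCKET>/<WorkerKey.LOAD>.<i>.
    load_bulk.handler(
        event={"s3_input_path": shard["s3_input_path"]},
        context=None,
    )
```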
4 changes: 4 additions & 0 deletions src/etl/sds/worker/bulk/load_bulk_fanout/make/make.py
@@ -0,0 +1,4 @@
from builder.lambda_build import build

if __name__ == "__main__":
build(__file__)