Skip to content

Commit

Permalink
[feature/PI-605-bulk_etl_transform] bulk etl transform
Browse files Browse the repository at this point in the history
  • Loading branch information
jaklinger committed Nov 22, 2024
1 parent e5a9818 commit 1ddacf9
Show file tree
Hide file tree
Showing 19 changed files with 723 additions and 381 deletions.
2 changes: 1 addition & 1 deletion src/layers/domain/repository/cpm_product_repository/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ def read(self, product_team_id: str, id: str):
return super()._read(parent_ids=(product_team_id,), id=id)

def search(self, product_team_id: str):
return super()._query(parent_ids=(product_team_id,))
return super()._search(parent_ids=(product_team_id,))

def handle_CpmProductCreatedEvent(self, event: CpmProductCreatedEvent):
return self.create_index(
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ def read(self, product_team_id: str, product_id: str, id: str):
return super()._read(parent_ids=(product_team_id, product_id), id=id)

def search(self, product_team_id: str, product_id: str):
return super()._query(parent_ids=(product_team_id, product_id))
return super()._search(parent_ids=(product_team_id, product_id))

def handle_DeviceReferenceDataCreatedEvent(
self, event: DeviceReferenceDataCreatedEvent
Expand Down
78 changes: 15 additions & 63 deletions src/layers/domain/repository/device_repository/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -94,27 +94,6 @@ def create_tag_index(
)


def create_tag_index_batch(device_id: str, tag_value: str, data: dict):
    """
    Batch-write counterpart of `create_tag_index`.

    `create_tag_index` emits `TransactItem`s for the event-based
    `handle_TagAddedEvent` path (invoked via the base `write` method, which
    uses `client.transact_write_items`). This helper instead emits a plain
    `dict` shaped as a `BatchWriteItem`, for the entity-based `handle_bulk`
    path (invoked via the base `write_bulk` method, which uses
    `client.batch_write_items`).
    """
    partition_key = TableKey.DEVICE_TAG.key(tag_value)
    sort_key = TableKey.DEVICE.key(device_id)
    item = marshall(
        pk=partition_key,
        sk=sort_key,
        pk_read=partition_key,
        sk_read=sort_key,
        root=False,
        **data,
    )
    return {"PutRequest": {"Item": item}}


def delete_tag_index(table_name: str, device_id: str, tag_value: str) -> TransactItem:
pk = TableKey.DEVICE_TAG.key(tag_value)
sk = TableKey.DEVICE.key(device_id)
Expand Down Expand Up @@ -158,11 +137,16 @@ def __init__(self, table_name, dynamodb_client):
table_key=TableKey.DEVICE,
)

def _query(self, parent_ids: tuple[str], id: str = None):
    """Run the base `_query` and decompress device fields on every yielded item."""
    raw_items = super()._query(parent_ids=parent_ids, id=id)
    return map(decompress_device_fields, raw_items)

def read(self, product_team_id: str, product_id: str, id: str):
    """Fetch a single record scoped under the given product team and product."""
    scope = (product_team_id, product_id)
    return super()._read(parent_ids=scope, id=id)

def search(self, product_team_id: str, product_id: str):
return super()._query(parent_ids=(product_team_id, product_id))
return super()._search(parent_ids=(product_team_id, product_id))

def handle_DeviceCreatedEvent(self, event: DeviceCreatedEvent) -> TransactItem:
return self.create_index(
Expand Down Expand Up @@ -310,7 +294,9 @@ def handle_DeviceKeyDeletedEvent(
def handle_DeviceTagAddedEvent(
self, event: DeviceTagAddedEvent
) -> list[TransactItem]:
data = {"tags": event.tags, "updated_on": event.updated_on}
data = compress_device_fields(
{"tags": event.tags, "updated_on": event.updated_on}
)

# Create a copy of the Device indexed against the new tag
create_tag_transaction = create_tag_index(
Expand Down Expand Up @@ -343,7 +329,9 @@ def handle_DeviceTagAddedEvent(
)

def handle_DeviceTagsAddedEvent(self, event: DeviceTagsAddedEvent):
data = {"tags": event.tags, "updated_on": event.updated_on}
data = compress_device_fields(
{"tags": event.tags, "updated_on": event.updated_on}
)
_data = compress_device_fields(
event, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS
)
Expand Down Expand Up @@ -390,9 +378,8 @@ def handle_DeviceTagsClearedEvent(self, event: DeviceTagsClearedEvent):
]

keys = {DeviceKey(**key).key_value for key in event.keys}
update_transactions = self.update_indexes(
id=event.id, keys=keys, data={"tags": []}
)
data = compress_device_fields({"tags": []})
update_transactions = self.update_indexes(id=event.id, keys=keys, data=data)
return delete_tags_transactions + update_transactions

def handle_DeviceReferenceDataIdAddedEvent(
Expand All @@ -408,42 +395,7 @@ def handle_DeviceReferenceDataIdAddedEvent(
def handle_QuestionnaireResponseUpdatedEvent(
self, event: QuestionnaireResponseUpdatedEvent
) -> TransactItem:
data = asdict(event)
data.pop("id")
return self.update_indexes(id=event.id, keys=[], data=data)

def handle_bulk(self, item: dict) -> list[dict]:
    """
    Render one raw device `item` into batch-write requests: one root record,
    one non-root copy per device key, and one non-root copy per tag.
    """
    parent_key = (item["product_team_id"], item["product_id"])

    # The root record carries the fully compressed payload.
    root_data = compress_device_fields(item)
    device_transaction = self.create_index_batch(
        id=item["id"], parent_key_parts=parent_key, data=root_data, root=True
    )

    # Key-indexed and tag-indexed copies share the same non-root payload.
    non_root_data = compress_device_fields(
        item, fields_to_compress=NON_ROOT_FIELDS_TO_COMPRESS
    )

    key_transactions = []
    for key in item["keys"]:
        key_transactions.append(
            self.create_index_batch(
                id=key["key_value"],
                parent_key_parts=parent_key,
                data=non_root_data,
                root=False,
            )
        )

    tag_transactions = []
    for tag in item["tags"]:
        tag_transactions.append(
            create_tag_index_batch(
                device_id=item["id"], tag_value=tag, data=non_root_data
            )
        )

    return [device_transaction, *key_transactions, *tag_transactions]
return self.handle_DeviceUpdatedEvent(event=event)

def query_by_tag(
self,
Expand Down
11 changes: 7 additions & 4 deletions src/layers/domain/repository/product_team_repository/v1.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,20 +16,23 @@ def __init__(self, table_name: str, dynamodb_client):
)

def read(self, id: str) -> ProductTeam:
return super()._read(parent_ids=(), id=id)
return super()._read(parent_ids=("",), id=id)

def search(self) -> list[ProductTeam]:
return super()._search(parent_ids=("",))

def handle_ProductTeamCreatedEvent(self, event: ProductTeamCreatedEvent):
create_root_transaction = self.create_index(
id=event.id, parent_key_parts=(event.id,), data=asdict(event), root=True
id=event.id, parent_key_parts=("",), data=asdict(event), root=True
)

keys = {ProductTeamKey(**key) for key in event.keys}
create_key_transactions = [
self.create_index(
id=key.key_value,
parent_key_parts=(key.key_value,),
parent_key_parts=("",),
data=asdict(event),
root=True,
root=False,
)
for key in keys
]
Expand Down
145 changes: 1 addition & 144 deletions src/layers/domain/repository/repository/tests/test_repository_v1.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,7 @@
import pytest
from domain.repository.errors import AlreadyExistsError, ItemNotFound
from domain.repository.repository import (
exponential_backoff_with_jitter,
retry_with_jitter,
)

from .model_v1 import (
MyEventAdd,
MyEventDelete,
MyModel,
MyOtherEventAdd,
MyRepository,
MyTableKey,
)
from .model_v1 import MyEventAdd, MyEventDelete, MyModel, MyOtherEventAdd, MyRepository


@pytest.fixture
Expand Down Expand Up @@ -125,135 +114,3 @@ def test_repository_add_and_delete_separate_transactions(repository: MyRepositor

with pytest.raises(ItemNotFound):
repository.read(id=value)


@pytest.mark.integration
def test_repository_write_bulk(repository: MyRepository):
    """Writing 51 items at batch size 25 flushes >= 3 batches; all items read back intact."""
    n_items = 51
    items = []
    for i in range(n_items):
        key = str(i)
        read_key = MyTableKey.FOO.key(key)
        items.append(
            {
                "pk": key,
                "sk": key,
                "pk_read": read_key,
                "sk_read": read_key,
                "field": f"boo-{i}",
            }
        )

    responses = repository.write_bulk(items, batch_size=25)
    assert len(responses) >= 3  # ceil(51 / 25)

    for i in range(n_items):
        assert repository.read(id=str(i)).field == f"boo-{i}"


def test_exponential_backoff_with_jitter():
    """Delays stay within [min_delay, max_delay], are unique (jitter), and grow with retries.

    BUG FIX: the original final assertion compared `delays[n_samples:]` —
    always the empty slice, summing to 0 — against the whole list, so the
    "later delays are larger" property could never fail. Compare the second
    half of the samples against the first half instead.
    """
    base_delay = 0.1
    max_delay = 5
    min_delay = 0.05
    n_samples = 1000

    delays = []
    for retry in range(n_samples):
        delay = exponential_backoff_with_jitter(
            n_retries=retry,
            base_delay=base_delay,
            min_delay=min_delay,
            max_delay=max_delay,
        )
        assert max_delay >= delay >= min_delay
        delays.append(delay)
    assert len(set(delays)) == n_samples  # all delays should be unique

    midpoint = n_samples // 2
    assert sum(delays[:midpoint]) < sum(
        delays[midpoint:]
    )  # later delays should be larger than earlier delays


@pytest.mark.parametrize(
    "error_code",
    [
        "ProvisionedThroughputExceededException",
        "ThrottlingException",
        "InternalServerError",
    ],
)
def test_retry_with_jitter_all_fail(error_code: str):
    """If every attempt raises a retryable error, an ExceptionGroup with one entry per retry escapes."""

    class FakeClientError(Exception):
        def __init__(self, error_code):
            self.response = {"Error": {"Code": error_code}}

    max_retries = 3

    @retry_with_jitter(max_retries=max_retries, error=FakeClientError)
    def always_fail(error_code):
        raise FakeClientError(error_code=error_code)

    with pytest.raises(ExceptionGroup) as exc_info:
        always_fail(error_code=error_code)

    group = exc_info.value
    assert group.message == f"Failed to put item after {max_retries} retries"
    assert len(group.exceptions) == max_retries
    for wrapped in group.exceptions:
        assert isinstance(wrapped, FakeClientError)


@pytest.mark.parametrize(
    "error_code",
    [
        "ProvisionedThroughputExceededException",
        "ThrottlingException",
        "InternalServerError",
    ],
)
def test_retry_with_jitter_third_passes(error_code: str):
    """If the call succeeds on the last allowed attempt, the result is returned and nothing escapes."""

    class FakeClientError(Exception):
        # Class-level counter shared across attempts of the decorated call.
        retries = 0

        def __init__(self, error_code):
            self.response = {"Error": {"Code": error_code}}

    max_retries = 3

    @retry_with_jitter(max_retries=max_retries, error=FakeClientError)
    def flaky_call(error_code):
        # Fail the first (max_retries - 1) attempts, then succeed.
        if FakeClientError.retries == max_retries - 1:
            return "foo"
        FakeClientError.retries += 1
        raise FakeClientError(error_code=error_code)

    assert flaky_call(error_code=error_code) == "foo"


@pytest.mark.parametrize(
    "error_code",
    [
        "SomeOtherError",
    ],
)
def test_retry_with_jitter_other_code(error_code: str):
    """A non-retryable error code propagates unchanged instead of being retried/wrapped."""

    class FakeClientError(Exception):
        def __init__(self, error_code):
            self.response = {"Error": {"Code": error_code}}

    @retry_with_jitter(max_retries=3, error=FakeClientError)
    def always_fail(error_code):
        raise FakeClientError(error_code=error_code)

    with pytest.raises(FakeClientError) as exc_info:
        always_fail(error_code=error_code)

    assert exc_info.value.response == {"Error": {"Code": error_code}}


def test_retry_with_jitter_other_exception():
    """An exception type other than the configured one is not retried and propagates unchanged."""

    @retry_with_jitter(max_retries=3, error=ValueError)
    def raise_unexpected():
        raise TypeError()

    with pytest.raises(TypeError):
        raise_unexpected()
Loading

0 comments on commit 1ddacf9

Please sign in to comment.