Fix job configuration

javiermtorres committed Dec 11, 2024
1 parent bd448a8 commit 1d25f6b
Showing 2 changed files with 64 additions and 34 deletions.
47 changes: 32 additions & 15 deletions lumigator/python/mzai/backend/backend/tests/conftest.py
@@ -110,9 +110,11 @@ def db_session(db_engine: Engine):
 @pytest.fixture(scope="function")
 def fake_s3fs() -> S3FileSystem:
     # ...and patch the s3fs name with the Fake
-    fsspec.register_implementation("s3", MemoryFileSystem, clobber=True, errtxt="Failed to register mock S3FS")
+    fsspec.register_implementation(
+        "s3", MemoryFileSystem, clobber=True, errtxt="Failed to register mock S3FS"
+    )
     yield MemoryFileSystem()
-    print(f'final s3fs contents: {str(MemoryFileSystem.store)}')
+    print(f"final s3fs contents: {str(MemoryFileSystem.store)}")


 @pytest.fixture(scope="function")
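A note on the fixture above: fsspec.register_implementation re-points the "s3" protocol at fsspec's in-memory filesystem, so anything the code under test reads or writes through s3:// URLs stays in process memory. A minimal standalone sketch of the same trick, assuming only that fsspec is installed (the bucket path and contents are illustrative):

```python
import fsspec
from fsspec.implementations.memory import MemoryFileSystem

# Re-point the "s3" protocol at the in-memory implementation, as the fixture does.
fsspec.register_implementation("s3", MemoryFileSystem, clobber=True)

fs = fsspec.filesystem("s3")
print(type(fs))  # <class 'fsspec.implementations.memory.MemoryFileSystem'>

fs.pipe("/bucket/hello.txt", b"hello")  # write bytes into the shared store
print(fs.cat("/bucket/hello.txt"))      # b'hello'
print(MemoryFileSystem.store)           # class-level dict shared by all instances
```

Because store is a class attribute, the fixture can hand the same dict to FakeS3Client below and both "clients" see one fake object store.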
@@ -125,6 +127,7 @@ def fake_s3_client(fake_s3fs) -> S3Client:
     os.environ["AWS_ENDPOINT_URL"] = "http://example.com:4566"
     return FakeS3Client(MemoryFileSystem.store)

+
 @pytest.fixture(scope="function")
 def boto_s3_client() -> S3Client:
     # Initialize S3
@@ -136,12 +139,13 @@ def boto_s3_client() -> S3Client:
     os.environ["AWS_ENDPOINT_URL"] = "http://localhost:4566"
     return boto3.client("s3")

+
 @pytest.fixture(scope="function")
 def boto_s3fs() -> S3FileSystem:
     s3fs = S3FileSystem()
     mock_s3fs = Mock(wraps=s3fs)
     yield mock_s3fs
-    print(f'intercepted s3fs calls: {str(mock_s3fs.mock_calls)}')
+    print(f"intercepted s3fs calls: {str(mock_s3fs.mock_calls)}")


 @pytest.fixture(scope="session")
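The boto_s3fs fixture uses Mock(wraps=...) as a spy: every call passes through to the real S3FileSystem while unittest.mock records it in mock_calls, which is what the teardown print dumps. The pattern in isolation (the Storage class here is invented purely for illustration):

```python
from unittest.mock import Mock

class Storage:
    def put(self, key: str, data: bytes) -> int:
        return len(data)

# A spy: delegates to the real object, but records every call.
spy = Mock(wraps=Storage())
result = spy.put("a/b.txt", b"123")  # runs Storage.put, returns 3

print(result)          # 3
print(spy.mock_calls)  # [call.put('a/b.txt', b'123')]
spy.put.assert_called_once_with("a/b.txt", b"123")
```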
@@ -177,7 +181,9 @@ def local_client(app: FastAPI):


 @pytest.fixture(scope="function")
-def dependency_overrides_fakes(app: FastAPI, db_session: Session, fake_s3_client: S3Client, fake_s3fs: S3FileSystem) -> None:
+def dependency_overrides_fakes(
+    app: FastAPI, db_session: Session, fake_s3_client: S3Client, fake_s3fs: S3FileSystem
+) -> None:
     """Override the FastAPI dependency injection for test DB sessions. Uses mocks/fakes for unit tests.

     Reference: https://fastapi.tiangolo.com/he/advanced/testing-database/
@@ -198,7 +204,9 @@ def get_s3_filesystem_override():


 @pytest.fixture(scope="function")
-def dependency_overrides_services(app: FastAPI, db_session: Session, boto_s3_client: S3Client, boto_s3fs: S3FileSystem) -> None:
+def dependency_overrides_services(
+    app: FastAPI, db_session: Session, boto_s3_client: S3Client, boto_s3fs: S3FileSystem
+) -> None:
     """Override the FastAPI dependency injection for test DB sessions. Uses real clients for integration tests.

     Reference: https://fastapi.tiangolo.com/he/advanced/testing-database/
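Both dependency_overrides_* fixtures rely on FastAPI's app.dependency_overrides mapping, which substitutes a dependency callable with a test double for as long as the entry is present. A self-contained sketch of the mechanism, with get_db and its override invented for illustration (not Lumigator's actual dependency names):

```python
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

app = FastAPI()

def get_db() -> str:
    return "real-database"

@app.get("/source")
def source(db: str = Depends(get_db)):
    return {"db": db}

def get_db_override() -> str:
    return "fake-database"

# Swap the dependency for a test double; FastAPI resolves the override instead.
app.dependency_overrides[get_db] = get_db_override

client = TestClient(app)
assert client.get("/source").json() == {"db": "fake-database"}
```

The two fixtures differ only in which doubles they install: in-memory fakes for unit tests, real boto3/s3fs clients against localstack for integration tests.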
@@ -287,7 +295,7 @@ def backend_settings():
 def simple_eval_template():
     return """{{
         "name": "{job_name}/{job_id}",
-        "model": {{ "path": "{model_path}" }},
+        "model": {{ "path": "{model_uri}" }},
         "dataset": {{ "path": "{dataset_path}" }},
         "evaluation": {{
             "metrics": ["meteor", "rouge"],
@@ -299,15 +307,24 @@ def simple_eval_template():
         }}
     }}"""

+
 @pytest.fixture(scope="session")
 def simple_infer_template():
     return """{{
-        "name": "{job_name}/{job_id}",
-        "model": {{ "path": "{model_path}" }},
-        "dataset": {{ "path": "{dataset_path}" }},
-        "job": {{
-            "max_samples": {max_samples},
-            "storage_path": "{storage_path}",
-            "output_field": "{output_field}"
-        }}
-    }}"""
+        "name": "{job_name}/{job_id}",
+        "dataset": {{ "path": "{dataset_path}" }},
+        "hf_pipeline": {{
+            "model_uri": "{model_uri}",
+            "task": "{task}",
+            "accelerator": "{accelerator}",
+            "revision": "{revision}",
+            "use_fast": "{use_fast}",
+            "trust_remote_code": "{trust_remote_code}",
+            "torch_dtype": "{torch_dtype}",
+            "max_length": 200
+        }},
+        "job": {{
+            "max_samples": {max_samples},
+            "storage_path": "{storage_path}"
+        }}
+    }}"""
51 changes: 32 additions & 19 deletions
@@ -1,8 +1,8 @@
+import time
 from time import sleep

 import pytest
 import requests
-import time
 from fastapi.testclient import TestClient
 from lumigator_schemas.datasets import DatasetFormat, DatasetResponse
 from lumigator_schemas.experiments import ExperimentResponse
@@ -18,7 +18,13 @@ def test_health_ok(local_client: TestClient):
     assert response.status_code == 200


-def test_upload_data_launch_job(local_client: TestClient, dialog_dataset, simple_eval_template, simple_infer_template, dependency_overrides_services):
+def test_upload_data_launch_job(
+    local_client: TestClient,
+    dialog_dataset,
+    simple_eval_template,
+    simple_infer_template,
+    dependency_overrides_services,
+):
     response = local_client.get("/health")
     assert response.status_code == 200

@@ -28,7 +34,7 @@ def test_upload_data_launch_job(local_client: TestClient, dialog_dataset, simple
files={"dataset": dialog_dataset, "format": (None, DatasetFormat.JOB.value)},
)

print(f'response: {create_response.text}')
print(f"response: {create_response.text}")
assert create_response.status_code == 201

created_dataset = DatasetResponse.model_validate(create_response.json())
@@ -40,6 +46,8 @@ def test_upload_data_launch_job(local_client: TestClient, dialog_dataset, simple
     eval_payload = {
         "name": "test_run_hugging_face",
         "description": "Test run for Huggingface model",
+        # "model": "hf-internal-testing/tiny-random-BartForCausalLM",
+        # "model": "mlabonne/dummy-CodeLlama-7b-hf",
         "model": "hf://hf-internal-testing/tiny-random-LlamaForCausalLM",
         "dataset": str(created_dataset.id),
         "config_template": simple_eval_template,
@@ -55,14 +63,16 @@ def test_upload_data_launch_job(local_client: TestClient, dialog_dataset, simple
     )
     assert create_evaluation_job_response.status_code == 201

-    create_evaluation_job_response_model = JobResponse.model_validate(create_evaluation_job_response.json())
+    create_evaluation_job_response_model = JobResponse.model_validate(
+        create_evaluation_job_response.json()
+    )

     succeeded = False
-    for i in range (1, 200):
-        get_job_response = local_client.get(f'/jobs/{create_evaluation_job_response_model.id}')
+    for i in range(1, 200):
+        get_job_response = local_client.get(f"/jobs/{create_evaluation_job_response_model.id}")
         assert get_job_response.status_code == 200
         get_job_response_model = JobResponse.model_validate(get_job_response.json())
-        print(f'--> try {i}: {get_job_response_model}')
+        print(f"--> try {i}: {get_job_response_model}")
         if get_job_response_model.status == JobStatus.SUCCEEDED.value:
             succeeded = True
             break
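This polling loop (bounded retries, break on a terminal status, sleep between attempts) reappears twice more below. A generic helper expressing the same pattern might look like the sketch here; fetch_status stands in for the GET /jobs/{id} call, and the name and signature are illustrative, not part of the codebase:

```python
import time
from typing import Callable

def wait_for_job(fetch_status: Callable[[], str],
                 max_tries: int = 200, delay_s: float = 1.0) -> bool:
    """Poll until the job reaches a terminal status or the retry budget runs out."""
    for attempt in range(1, max_tries):
        status = fetch_status()
        print(f"--> try {attempt}: {status}")
        if status == "SUCCEEDED":
            return True
        if status == "FAILED":
            return False
        time.sleep(delay_s)  # back off between polls
    return False
```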
@@ -77,46 +87,49 @@ def test_upload_data_launch_job(local_client: TestClient, dialog_dataset, simple
"description": "Test run for Huggingface model",
"model": "hf://hf-internal-testing/tiny-random-LlamaForCausalLM",
"dataset": str(created_dataset.id),
"config_template": simple_infer_template,
# "config_template": simple_infer_template,
"max_samples": 10,
# Investigate!
# "model_url": "string",
# "system_prompt": "string",
# "config_template": "string",
}

create_inference_job_response = local_client.post(
"/jobs/inference/", headers=headers, json=infer_payload
)
assert create_inference_job_response.status_code == 201

create_inference_job_response_model = JobResponse.model_validate(create_inference_job_response.json())
create_inference_job_response_model = JobResponse.model_validate(
create_inference_job_response.json()
)

succeeded = False
for i in range (1, 200):
get_job_response = local_client.get(f'/jobs/{create_inference_job_response_model.id}')
for i in range(1, 200):
get_job_response = local_client.get(f"/jobs/{create_inference_job_response_model.id}")
assert get_job_response.status_code == 200
get_job_response_model = JobResponse.model_validate(get_job_response.json())
print(f'--> try {i}: {get_job_response_model}')
print(f"--> try {i}: {get_job_response_model}")
if get_job_response_model.status == JobStatus.SUCCEEDED.value:
succeeded = True
break
if get_job_response_model.status == JobStatus.FAILED.value:
succeeded = False
break
time.sleep(1)
time.sleep(20)
assert succeeded


def test_full_experiment_launch(local_client: TestClient, dialog_dataset, dependency_overrides_services):
def test_full_experiment_launch(
local_client: TestClient, dialog_dataset, dependency_overrides_services
):
response = local_client.get("/health")
assert response.status_code == 200
create_response = local_client.post(
"/datasets/",
data={},
files={"dataset": dialog_dataset, "format": (None, DatasetFormat.JOB.value)},
)
print(f'response: {create_response.text}')
print(f"response: {create_response.text}")
assert create_response.status_code == 201
created_dataset = DatasetResponse.model_validate(create_response.json())
headers = {
@@ -145,11 +158,11 @@ def test_full_experiment_launch(local_client: TestClient, dialog_dataset, depend
     assert get_experiment_response.status_code == 200

     succeeded = False
-    for i in range (1, 200):
-        get_job_response = local_client.get(f'/jobs/{get_experiments.items[0].id}')
+    for i in range(1, 200):
+        get_job_response = local_client.get(f"/jobs/{get_experiments.items[0].id}")
         assert get_job_response.status_code == 200
         get_job_response_model = JobResponse.model_validate(get_job_response.json())
-        print(f'--> try {i}: {get_job_response_model}')
+        print(f"--> try {i}: {get_job_response_model}")
         if get_job_response_model.status == JobStatus.SUCCEEDED.value:
             succeeded = True
             break
