Skip to content

Commit

Permalink
Separating out setting model type into a separate method in the job service (#385)
Browse files Browse the repository at this point in the history

* test input model params

* ruff

* test

* fmt

* fix job-all and pass settings from backend

* test fix
  • Loading branch information
veekaybee authored Nov 20, 2024
1 parent 5676542 commit e0dd518
Show file tree
Hide file tree
Showing 5 changed files with 79 additions and 22 deletions.
10 changes: 10 additions & 0 deletions .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -17,3 +17,13 @@ repos:
- id: end-of-file-fixer
- id: requirements-txt-fixer
exclude: requirements_lock.txt

- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.7.4
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format
4 changes: 2 additions & 2 deletions lumigator/python/mzai/backend/backend/api/routes/health.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,8 @@ def get_job_metadata(job_id: UUID) -> JobSubmissionResponse:
)
elif resp.status_code != HTTPStatus.OK:
loguru.logger.error(
f"Unexpected status code getting job metadata text: "
"{resp.status_code}, error: {resp.text or ''}"
"Unexpected status code getting job metadata text: "
f"{resp.status_code}, error: {resp.text or ''}"
)
raise HTTPException(
status_code=HTTPStatus.INTERNAL_SERVER_ERROR,
Expand Down
14 changes: 10 additions & 4 deletions lumigator/python/mzai/backend/backend/services/jobs.py
Original file line number Diff line number Diff line change
Expand Up @@ -104,17 +104,23 @@ def _get_config_template(self, job_type: str, model_name: str) -> str:

return config_template

def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:
# get dataset S3 path from UUID
dataset_s3_path = self.data_service.get_dataset_s3_path(request.dataset)

def _set_model_type(self, request: BaseModel) -> str:
    """Resolve the model URL from the protocol prefix of the requested model.

    Models addressed via a hosted-API scheme (``oai://`` or ``mistral://``)
    map to the corresponding API URL from settings; anything else falls back
    to the URL supplied on the request itself.
    """
    # Ordered mapping of protocol prefix -> hosted-API endpoint.
    api_urls_by_prefix = {
        "oai://": settings.OAI_API_URL,
        "mistral://": settings.MISTRAL_API_URL,
    }
    for prefix, api_url in api_urls_by_prefix.items():
        if request.model.startswith(prefix):
            return api_url
    # Not a hosted-API scheme: use whatever URL the caller provided (may be None).
    return request.model_url

def _get_job_params(self, job_type: str, record, request: BaseModel) -> dict:
# get dataset S3 path from UUID
dataset_s3_path = self.data_service.get_dataset_s3_path(request.dataset)

model_url = self._set_model_type(request)

# provide a reasonable system prompt for services where none was specified
if request.system_prompt is None and not request.model.startswith("hf://"):
request.system_prompt = settings.DEFAULT_SUMMARIZER_PROMPT
Expand Down
Original file line number Diff line number Diff line change
@@ -1,24 +1,65 @@

import pytest
from lumigator_schemas.jobs import (
JobInferenceCreate,
)

from backend.services.jobs import JobService
from backend.settings import settings


def test_set_null_inference_job_params(job_record,job_service ):
request = JobInferenceCreate(name="test_run_hugging_face",
description="Test run for Huggingface model",
model="Test run for Huggingface model",
dataset="cced289c-f869-4af1-9195-1d58e32d1cc1")
params = job_service._get_job_params("INFERENCE", job_record,request)
def test_set_null_inference_job_params(job_record, job_service):
    """When max_samples is omitted, job params default it to -1 (all samples)."""
    request_payload = {
        "name": "test_run_hugging_face",
        "description": "Test run for Huggingface model",
        "model": "hf://facebook/bart-large-cnn",
        "dataset": "cced289c-f869-4af1-9195-1d58e32d1cc1",
    }
    request = JobInferenceCreate(**request_payload)
    job_params = job_service._get_job_params("INFERENCE", job_record, request)
    assert job_params["max_samples"] == -1

def test_set_explicit_inference_job_params(job_record,job_service ):
request = JobInferenceCreate(name="test_run_hugging_face",
description="Test run for Huggingface model",
max_samples=10,
model="Test run for Huggingface model",
dataset="cced289c-f869-4af1-9195-1d58e32d1cc1")
params = job_service._get_job_params("INFERENCE", job_record,request)

def test_set_explicit_inference_job_params(job_record, job_service):
    """An explicitly supplied max_samples is carried through to the job params."""
    request_payload = {
        "name": "test_run_hugging_face",
        "description": "Test run for Huggingface model",
        "max_samples": 10,
        "model": "hf://facebook/bart-large-cnn",
        "dataset": "cced289c-f869-4af1-9195-1d58e32d1cc1",
    }
    request = JobInferenceCreate(**request_payload)
    job_params = job_service._get_job_params("INFERENCE", job_record, request)
    assert job_params["max_samples"] == 10


@pytest.mark.parametrize(
    ["model", "input_model_url", "returned_model_url"],
    [
        # generic HF model loaded locally
        ("hf://facebook/bart-large-cnn", None, None),
        # vLLM served model (with HF model name specified to be passed as "engine")
        (
            "hf://mistralai/Mistral-7B-Instruct-v0.3",
            "http://localhost:8000/v1/chat/completions",
            "http://localhost:8000/v1/chat/completions",
        ),
        # llamafile served model (with custom model name)
        (
            "llamafile://mistralai/Mistral-7B-Instruct-v0.2",
            "http://localhost:8000/v1/chat/completions",
            "http://localhost:8000/v1/chat/completions",
        ),
        # openai model (from API)
        ("oai://gpt-4-turbo", None, settings.OAI_API_URL),
        # mistral model (from API)
        ("mistral://open-mistral-7b", None, settings.MISTRAL_API_URL),
    ],
)
def test_set_model(job_service, model, input_model_url, returned_model_url):
    """Each model protocol prefix must resolve to the expected model URL."""
    request_payload = {
        "name": "test_run",
        "description": "Test run to verify how model URL is set",
        "model": model,
        "model_url": input_model_url,
        "dataset": "d34dd34d-d34d-d34d-d34d-d34dd34dd34d",
    }
    request = JobInferenceCreate(**request_payload)
    resolved_url = job_service._set_model_type(request)
    assert resolved_url == returned_model_url
4 changes: 2 additions & 2 deletions lumigator/python/mzai/sdk/tests/data/job-all.json
Original file line number Diff line number Diff line change
Expand Up @@ -4,8 +4,8 @@
"model": "some_model",
"dataset": "6f6487ac-7170-4a11-af7a-0f6db1ec9a74",
"max_samples": 5,
"model_url": "s3://some_url",
"model_url": "https://huggingface.co/facebook/bart-large-cnn",
"system_prompt": "some prompt",
"config_infer_template": "some_infer_template",
"config_eval_template": "some_eval_template"
}
}

0 comments on commit e0dd518

Please sign in to comment.