Skip to content

Commit

Permalink
sketch in JawsJobRunner class and basic unit test
Browse files Browse the repository at this point in the history
  • Loading branch information
mbthornton-lbl committed Dec 5, 2024
1 parent 00718bf commit 41b4695
Show file tree
Hide file tree
Showing 3 changed files with 99 additions and 19 deletions.
68 changes: 66 additions & 2 deletions nmdc_automation/workflow_automation/wfutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,7 +66,7 @@ def max_retries(self) -> int:
pass


class CromwellRunner(JobRunnerABC):
class CromwellJobRunner(JobRunnerABC):
"""Job runner for Cromwell"""
LABEL_SUBMITTER_VALUE = "nmdcda"
LABEL_PARAMETERS = ["release", "wdl", "git_repo"]
Expand Down Expand Up @@ -241,6 +241,70 @@ def max_retries(self) -> int:
return self._max_retries


class JawsJobRunner(JobRunnerABC):
    """Job runner for JAWS.

    Keeps the site configuration, the workflow state manager and a metadata
    dict for one job. The submission and status methods are placeholders that
    currently do nothing and return ``None``.
    """

    def __init__(
        self,
        site_config: SiteConfig,
        workflow: "WorkflowStateManager",
        job_metadata: Dict[str, Any] = None,
        max_retries: int = DEFAULT_MAX_RETRIES,
        dry_run: bool = False,
    ) -> None:
        """
        Build a JAWS job runner.

        :param site_config: SiteConfig object, stored as ``config``
        :param workflow: WorkflowStateManager object, stored as ``workflow``
        :param job_metadata: metadata for the job; a falsy value (``None`` or
            an empty dict) results in a fresh empty dict
        :param max_retries: maximum number of retries for a job
        :param dry_run: flag meant to suppress job submission (only stored here)
        :raises ValueError: if ``workflow`` is not a WorkflowStateManager
        """
        self.config = site_config
        if not isinstance(workflow, WorkflowStateManager):
            raise ValueError("workflow must be a WorkflowStateManager object")
        self.workflow = workflow
        # Fall back to a new dict so instances never share a default mapping.
        self._metadata = job_metadata if job_metadata else {}
        self._max_retries = max_retries
        self.dry_run = dry_run

    @property
    def job_id(self) -> Optional[str]:
        """ Return the ``"id"`` entry of the metadata, or None if absent """
        return self.metadata.get("id")

    @job_id.setter
    def job_id(self, job_id: str):
        """ Store the job id under the ``"id"`` key of the metadata """
        self.metadata["id"] = job_id

    def generate_submission_files(self) -> Dict[str, Any]:
        """ Placeholder: not implemented yet, currently returns None """

    def submit_job(self, force: bool = False) -> Optional[str]:
        """ Placeholder: not implemented yet, currently returns None """

    def get_job_status(self) -> str:
        """ Placeholder: not implemented yet, currently returns None """

    def get_job_metadata(self) -> Dict[str, Any]:
        """ Placeholder: not implemented yet, currently returns None """

    @property
    def outputs(self) -> Dict[str, str]:
        """ Return the ``"outputs"`` entry of the metadata, or an empty dict """
        return self.metadata.get("outputs", {})

    @property
    def metadata(self) -> Dict[str, Any]:
        """ Return the metadata dict held by this runner """
        return self._metadata

    @metadata.setter
    def metadata(self, metadata: Dict[str, Any]):
        """ Replace the metadata dict held by this runner """
        self._metadata = metadata

    @property
    def max_retries(self) -> int:
        """ Return the maximum number of retries configured for the job """
        return self._max_retries


class WorkflowStateManager:
CHUNK_SIZE = 1000000 # 1 MB
GIT_RELEASES_PATH = "/releases/download"
Expand Down Expand Up @@ -412,7 +476,7 @@ def __init__(self, site_config: SiteConfig, workflow_state: Dict[str, Any] = Non
self.workflow = WorkflowStateManager(workflow_state, opid)
# default to CromwellJobRunner if no job_runner is provided
if job_runner is None:
job_runner = CromwellRunner(site_config, self.workflow, job_metadata)
job_runner = CromwellJobRunner(site_config, self.workflow, job_metadata)
self.job = job_runner

# Properties to access the site config, job state, and job runner attributes
Expand Down
6 changes: 3 additions & 3 deletions tests/test_watch_nmdc.py
Original file line number Diff line number Diff line change
Expand Up @@ -313,7 +313,7 @@ def test_job_manager_get_finished_jobs(site_config, initial_state_file_1_failure
def test_job_manager_process_successful_job(site_config, initial_state_file_1_failure, fixtures_dir):
# mock job.job.get_job_metadata - use fixture mags_job_metadata.json
job_metadata = json.load(open(fixtures_dir / "mags_job_metadata.json"))
with patch("nmdc_automation.workflow_automation.wfutils.CromwellRunner.get_job_metadata") as mock_get_metadata:
with patch("nmdc_automation.workflow_automation.wfutils.CromwellJobRunner.get_job_metadata") as mock_get_metadata:
mock_get_metadata.return_value = job_metadata

# Arrange
Expand Down Expand Up @@ -355,7 +355,7 @@ def test_job_manager_get_finished_jobs_1_failure(site_config, initial_state_file
failed_job = failed_jobs[0]
assert failed_job.job_status == "Failed"

@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellRunner.generate_submission_files")
@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellJobRunner.generate_submission_files")
def test_job_manager_process_failed_job_1_failure(
mock_generate_submission_files, site_config, initial_state_file_1_failure, mock_cromwell_api):
# Arrange
Expand Down Expand Up @@ -393,7 +393,7 @@ def test_job_manager_process_failed_job_2_failures(site_config, initial_state_fi
def mock_runtime_api_handler(site_config, mock_api):
pass

@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellRunner.submit_job")
@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellJobRunner.submit_job")
def test_claim_jobs(mock_submit, site_config_file, site_config, fixtures_dir):
# Arrange
mock_submit.return_value = {"id": "nmdc:1234", "detail": {"id": "nmdc:1234"}}
Expand Down
44 changes: 30 additions & 14 deletions tests/test_wfutils.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
from nmdc_automation.workflow_automation.wfutils import (
CromwellRunner,
CromwellJobRunner,
WorkflowJob,
WorkflowStateManager,
_json_tmp,
JawsJobRunner,
)
from nmdc_automation.models.nmdc import DataObject, workflow_process_factory
from nmdc_schema.nmdc import MagsAnalysis, EukEval
Expand Down Expand Up @@ -36,12 +36,29 @@ def test_workflow_job(site_config, fixtures_dir):


def test_cromwell_job_runner(site_config, fixtures_dir):
    """ Test basic initialization of CromwellJobRunner """
    # Load the cromwell metadata and workflow state fixtures; the context
    # managers make sure the file handles are closed once parsed.
    with open(fixtures_dir / "mags_job_metadata.json") as metadata_file:
        job_metadata = json.load(metadata_file)
    with open(fixtures_dir / "mags_workflow_state.json") as state_file:
        job_state = json.load(state_file)

    state_manager = WorkflowStateManager(job_state)
    job_runner = CromwellJobRunner(site_config, state_manager, job_metadata)
    assert job_runner
    # The runner must expose the job-runner methods checked below.
    assert hasattr(job_runner, "generate_submission_files")
    assert hasattr(job_runner, "submit_job")
    assert hasattr(job_runner, "get_job_status")
    assert hasattr(job_runner, "get_job_metadata")


def test_jaws_job_runner(site_config, fixtures_dir):
    """ Test basic initialization of JawsJobRunner """
    # Load the fixtures under context managers so the file handles are
    # closed once the JSON has been parsed.
    with open(fixtures_dir / "mags_job_metadata.json") as metadata_file:
        job_metadata = json.load(metadata_file)
    with open(fixtures_dir / "mags_workflow_state.json") as state_file:
        job_state = json.load(state_file)

    state_manager = WorkflowStateManager(job_state)
    job_runner = JawsJobRunner(site_config, state_manager, job_metadata)
    assert job_runner
    # The runner must expose the job-runner methods checked below.
    assert hasattr(job_runner, "generate_submission_files")
    assert hasattr(job_runner, "submit_job")
    assert hasattr(job_runner, "get_job_status")
    assert hasattr(job_runner, "get_job_metadata")


def test_cromwell_job_runner_get_job_status(site_config, fixtures_dir, mock_cromwell_api):
Expand All @@ -53,7 +70,7 @@ def test_cromwell_job_runner_get_job_status(site_config, fixtures_dir, mock_crom
job_metadata['id'] = "cromwell-job-id-12345"

state_manager = WorkflowStateManager(job_state)
job_runner = CromwellRunner(site_config, state_manager, job_metadata)
job_runner = CromwellJobRunner(site_config, state_manager, job_metadata)
status = job_runner.get_job_status()
assert status
assert status == "Succeeded"
Expand All @@ -62,7 +79,7 @@ def test_cromwell_job_runner_get_job_status(site_config, fixtures_dir, mock_crom
job_state['cromwell_jobid'] = "cromwell-job-id-54321"
job_metadata['id'] = "cromwell-job-id-54321"
state_manager = WorkflowStateManager(job_state)
job_runner = CromwellRunner(site_config, state_manager, job_metadata)
job_runner = CromwellJobRunner(site_config, state_manager, job_metadata)
status = job_runner.get_job_status()
assert status
assert status == "Failed"
Expand All @@ -77,7 +94,7 @@ def test_cromwell_job_runner_get_job_metadata(site_config, fixtures_dir, mock_cr
job_metadata['id'] = "cromwell-job-id-12345"

state_manager = WorkflowStateManager(job_state)
job_runner = CromwellRunner(site_config, state_manager, job_metadata)
job_runner = CromwellJobRunner(site_config, state_manager, job_metadata)
metadata = job_runner.get_job_metadata()
assert metadata
assert metadata['id'] == "cromwell-job-id-12345"
Expand Down Expand Up @@ -179,7 +196,7 @@ def test_workflow_manager_fetch_release_file_failed_write(mock_get, fixtures_dir
def test_cromwell_runner_setup_inputs_and_labels(site_config, fixtures_dir):
job_state = json.load(open(fixtures_dir / "mags_workflow_state.json"))
workflow = WorkflowStateManager(job_state)
runner = CromwellRunner(site_config, workflow)
runner = CromwellJobRunner(site_config, workflow)
inputs = runner._generate_workflow_inputs()
assert inputs
# we expect the inputs to be a key-value dict with URLs as values
Expand Down Expand Up @@ -212,7 +229,7 @@ def test_cromwell_runner_generate_submission_files( mock_fetch_release_file, sit
io.BytesIO(b"mock workflow inputs"), # workflowInputs file
io.BytesIO(b"mock labels") # labels file
]
runner = CromwellRunner(site_config, workflow)
runner = CromwellJobRunner(site_config, workflow)
submission_files = runner.generate_submission_files()
assert submission_files
assert "workflowSource" in submission_files
Expand All @@ -227,7 +244,7 @@ def test_cromwell_runner_generate_submission_files( mock_fetch_release_file, sit


@mock.patch("nmdc_automation.workflow_automation.wfutils.WorkflowStateManager.fetch_release_file")
@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellRunner._cleanup_files")
@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellJobRunner._cleanup_files")
def test_cromwell_runner_generate_submission_files_exception(mock_cleanup_files, mock_fetch_release_file,
site_config, fixtures_dir):
# Mock file fetching
Expand All @@ -247,14 +264,14 @@ def test_cromwell_runner_generate_submission_files_exception(mock_cleanup_files,
OSError("Failed to open file"), # workflowInputs file
io.BytesIO(b"mock labels") # labels file
]
runner = CromwellRunner(site_config, workflow)
runner = CromwellJobRunner(site_config, workflow)
with pytest.raises(OSError):
runner.generate_submission_files()
# Check that the cleanup function was called
mock_cleanup_files.assert_called_once()


@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellRunner.generate_submission_files")
@mock.patch("nmdc_automation.workflow_automation.wfutils.CromwellJobRunner.generate_submission_files")
def test_cromwell_job_runner_submit_job_new_job(mock_generate_submission_files, site_config, fixtures_dir, mock_cromwell_api):
mock_generate_submission_files.return_value = {
"workflowSource": "workflowSource",
Expand All @@ -270,7 +287,7 @@ def test_cromwell_job_runner_submit_job_new_job(mock_generate_submission_files,
wf_state['done'] = False # simulate a job that has not been submitted

wf_state_manager = WorkflowStateManager(wf_state)
job_runner = CromwellRunner(site_config, wf_state_manager)
job_runner = CromwellJobRunner(site_config, wf_state_manager)
jobid = job_runner.submit_job()
assert jobid

Expand Down Expand Up @@ -320,7 +337,6 @@ def test_workflow_execution_record_from_workflow_job(site_config, fixtures_dir,
assert wfe.ended_at_time



def test_workflow_job_from_database_job_record(site_config, fixtures_dir):
job_rec = json.load(open(fixtures_dir / "nmdc_api/unsubmitted_job.json"))
assert job_rec
Expand Down

0 comments on commit 41b4695

Please sign in to comment.