Skip to content

Commit

Permalink
refactor Scheduler args and get_mongo_db() to work with local db
Browse files Browse the repository at this point in the history
  • Loading branch information
mbthornton-lbl committed Nov 21, 2024
1 parent 7213e8b commit 11fe8b9
Show file tree
Hide file tree
Showing 3 changed files with 38 additions and 22 deletions.
13 changes: 13 additions & 0 deletions nmdc_automation/api/nmdcapi.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,6 +330,19 @@ def run_query(self, query):
resp.raise_for_status()
return resp.json()

@refresh_token
def get_workflow_executions_for_informed_by(self, informed_by, type):
    """
    Get the workflow executions for a given informed_by and type.

    Queries the ``nmdcschema/workflow_execution_set`` endpoint with a
    filter on the ``informed_by`` and ``type`` fields.

    :param informed_by: value to match against the ``informed_by`` field.
    :param type: value to match against the ``type`` field. The name
        shadows the builtin but is kept so keyword callers keep working.
    :return: the decoded JSON response body.
    :raises requests.HTTPError: if the API responds with an error status.
    """
    url = f"{self._base_url}nmdcschema/workflow_execution_set"
    filt = {
        "informed_by": informed_by,
        "type": type,
    }
    # Pass the filter as a query parameter: a request body on a GET is
    # typically ignored by servers, which would return unfiltered results.
    resp = requests.get(
        url,
        headers=self.header,
        params={"filter": json.dumps(filt)},
    )
    # Fail loudly on HTTP errors (as run_query does) instead of handing an
    # error payload back to the caller as if it were data.
    resp.raise_for_status()
    return resp.json()


# TODO - This is deprecated and should be removed along with the re_iding code that uses it
class NmdcRuntimeUserApi:
Expand Down
35 changes: 19 additions & 16 deletions nmdc_automation/workflow_automation/sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,17 +29,15 @@
logger = logging.getLogger(__name__)
@lru_cache
def get_mongo_db() -> MongoDatabase:
for k in ["HOST", "USERNAME", "PASSWORD", "DBNAME"]:
if f"MONGO_{k}" not in os.environ:
raise KeyError(f"Missing MONGO_{k}")
_client = MongoClient(
host=os.getenv("MONGO_HOST"),
host=os.getenv("MONGO_HOST", "localhost"),
port=int(os.getenv("MONGO_PORT", "27017")),
username=os.getenv("MONGO_USERNAME"),
password=os.getenv("MONGO_PASSWORD"),
username=os.getenv("MONGO_USERNAME", None),
password=os.getenv("MONGO_PASSWORD", None),
directConnection=True,
)
return _client[os.getenv("MONGO_DBNAME")]
)[os.getenv("MONGO_DBNAME", "nmdc")]
return _client



def within_range(wf1: WorkflowConfig, wf2: WorkflowConfig, force=False) -> bool:
Expand Down Expand Up @@ -84,12 +82,12 @@ def __init__(self, workflow: WorkflowConfig, trigger_act: WorkflowProcessNode):

class Scheduler:

def __init__(self, db, wfn="workflows.yaml",
def __init__(self, db, workflow_yaml,
site_conf="site_configuration.toml"):
logging.info("Initializing Scheduler")
# Init
wf_file = os.environ.get(_WF_YAML_ENV, wfn)
self.workflows = load_workflow_configs(wf_file)
# wf_file = os.environ.get(_WF_YAML_ENV, wfn)
self.workflows = load_workflow_configs(workflow_yaml)
self.db = db
self.api = NmdcRuntimeApi(site_conf)
# TODO: Make force an optional parameter
Expand Down Expand Up @@ -318,17 +316,18 @@ def cycle(self, dryrun: bool = False, skiplist: set = set(),
return job_recs


def main(): # pragma: no cover
def main(site_conf, wf_file): # pragma: no cover
"""
Main function
"""
site_conf = os.environ.get("NMDC_SITE_CONF", "site_configuration.toml")
sched = Scheduler(get_mongo_db(), site_conf=site_conf)
# site_conf = os.environ.get("NMDC_SITE_CONF", "site_configuration.toml")
db = get_mongo_db()
sched = Scheduler(db, wf_file, site_conf=site_conf)
dryrun = False
if os.environ.get("DRYRUN") == "1":
dryrun = True
skiplist = set()
allowlist = None
# allowlist = None
if os.environ.get("SKIPLISTFILE"):
with open(os.environ.get("SKIPLISTFILE")) as f:
for line in f:
Expand All @@ -338,6 +337,8 @@ def main(): # pragma: no cover
with open(os.environ.get("ALLOWLISTFILE")) as f:
for line in f:
allowlist.add(line.rstrip())
# for local testing
allowlist = ["nmdc:omprc-11-cegmwy02"]
while True:
sched.cycle(dryrun=dryrun, skiplist=skiplist, allowlist=allowlist)
if dryrun:
Expand All @@ -347,4 +348,6 @@ def main(): # pragma: no cover

if __name__ == "__main__": # pragma: no cover
logging.basicConfig(level=logging.INFO)
main()
main("/Users/MBThornton/Documents/code/nmdc_automation/.local/site_conf.toml",
"/Users/MBThornton/Documents/code/nmdc_automation/nmdc_automation/config/workflows/workflows.yaml"
)
12 changes: 6 additions & 6 deletions tests/test_sched.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def test_scheduler_cycle(test_db, mock_api, workflow_file, workflows_config_dir,
# Scheduler will find one job to create
exp_num_jobs_initial = 1
exp_num_jobs_cycle_1 = 0
jm = Scheduler(test_db, wfn=workflows_config_dir / workflow_file,
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / workflow_file,
site_conf=site_config_file)
resp = jm.cycle()
assert len(resp) == exp_num_jobs_initial
Expand All @@ -50,7 +50,7 @@ def test_progress(test_db, mock_api, workflow_file, workflows_config_dir, site_c



jm = Scheduler(test_db, wfn=workflows_config_dir / workflow_file,
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / workflow_file,
site_conf= site_config_file)
workflow_by_name = dict()
for wf in jm.workflows:
Expand Down Expand Up @@ -127,7 +127,7 @@ def test_multiple_versions(test_db, mock_api, workflows_config_dir, site_config_
load_fixture(test_db, "data_object_set.json")
load_fixture(test_db, "data_generation_set.json")

jm = Scheduler(test_db, wfn=workflows_config_dir / "workflows.yaml",
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / "workflows.yaml",
site_conf=site_config_file)
workflow_by_name = dict()
for wf in jm.workflows:
Expand Down Expand Up @@ -160,7 +160,7 @@ def test_out_of_range(test_db, mock_api, workflows_config_dir, site_config_file)
test_db.jobs.delete_many({})
load_fixture(test_db, "data_object_set.json")
load_fixture(test_db, "data_generation_set.json")
jm = Scheduler(test_db, wfn=workflows_config_dir / "workflows.yaml",
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / "workflows.yaml",
site_conf=site_config_file)
# Let's create two RQC records. One will be in range
# and the other will not. We should only get new jobs
Expand All @@ -185,7 +185,7 @@ def test_type_resolving(test_db, mock_api, workflows_config_dir, site_config_fil
load_fixture(test_db, "data_generation_set.json")
load_fixture(test_db, "read_qc_analysis.json", col="workflow_execution_set")

jm = Scheduler(test_db, wfn=workflows_config_dir / "workflows.yaml",
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / "workflows.yaml",
site_conf=site_config_file)
workflow_by_name = dict()
for wf in jm.workflows:
Expand Down Expand Up @@ -213,7 +213,7 @@ def test_scheduler_add_job_rec(test_db, mock_api, workflow_file, workflows_confi
load_fixture(test_db, "data_object_set.json")
load_fixture(test_db, "data_generation_set.json")

jm = Scheduler(test_db, wfn=workflows_config_dir / workflow_file,
jm = Scheduler(test_db, workflow_yaml=workflows_config_dir / workflow_file,
site_conf=site_config_file)
# sanity check
assert jm
Expand Down

0 comments on commit 11fe8b9

Please sign in to comment.