
Commit

add retry decorator to api methods
mbthornton-lbl committed Dec 20, 2024
1 parent 9db3bdf commit ccd00d2
Showing 4 changed files with 74 additions and 36 deletions.
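The change applied across the API client is a `tenacity` retry decorator. A minimal sketch of the pattern as used in this commit; the function name and endpoint below are illustrative, not part of the diff:

import requests
from tenacity import retry, stop_after_attempt, wait_exponential

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
def fetch_resource(base_url: str, header: dict, resource_id: str) -> dict:
    # Waits roughly 8s, 16s, 32s, ... (capped at 120s) between attempts,
    # gives up after 6 attempts, and re-raises the last exception.
    resp = requests.get(f"{base_url}objects/{resource_id}", headers=header)
    resp.raise_for_status()
    return resp.json()

Because tenacity retries on any raised exception by default, HTTP errors surfaced by raise_for_status() are retried along with connection errors.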
56 changes: 42 additions & 14 deletions nmdc_automation/api/nmdcapi.py
@@ -16,7 +16,10 @@
import logging
from tenacity import retry, wait_exponential, stop_after_attempt

logging.basicConfig(level=logging.INFO)
logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)

SECONDS_IN_DAY = 86400
@@ -74,11 +77,7 @@ def _get_token(self, *args, **kwargs):

return _get_token

@retry(
wait=wait_exponential(multiplier=4, min=8, max=120),
stop=stop_after_attempt(6),
reraise=True,
)
@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
def get_token(self):
"""
Get a token using a client id/secret.
Expand Down Expand Up @@ -120,12 +119,13 @@ def get_token(self):
"Content-Type": "application/json",
"Authorization": "Bearer %s" % (self.token),
}
logging.info(f"New token expires at {self.expires_at}")
logging.debug(f"New token expires at {self.expires_at}")
return response_body

def get_header(self):
return self.header

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def minter(self, id_type, informed_by=None):
url = f"{self._base_url}pids/mint"
@@ -143,6 +143,7 @@ def minter(self, id_type, informed_by=None):
raise ValueError("Failed to bind metadata to pid")
return id

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def mint(self, ns, typ, ct):
"""
@@ -155,15 +156,20 @@ def mint(self, ns, typ, ct):
url = self._base_url + "ids/mint"
d = {"populator": "", "naa": ns, "shoulder": typ, "number": ct}
resp = requests.post(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_object(self, obj, decode=False):
"""
Helper function to get object info
"""
url = "%sobjects/%s" % (self._base_url, obj)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
data = resp.json()
if decode and "description" in data:
try:
@@ -173,6 +179,8 @@ def get_object(self, obj, decode=False):

return data


@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def create_object(self, fn, description, dataurl):
"""
@@ -214,8 +222,11 @@ def create_object(self, fn, description, dataurl):
"self_uri": "todo",
}
resp = requests.post(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def post_objects(self, obj_data):
url = self._base_url + "workflows/workflow_executions"
@@ -225,23 +236,29 @@ def post_objects(self, obj_data):
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def set_type(self, obj, typ):
url = "%sobjects/%s/types" % (self._base_url, obj)
d = [typ]
resp = requests.put(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def bump_time(self, obj):
url = "%sobjects/%s" % (self._base_url, obj)
now = datetime.today().isoformat()

d = {"created_time": now}
resp = requests.patch(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

# TODO test that this concatenates multi-page results
@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_jobs(self, filt=None, max=100) -> List[dict]:
url = "%sjobs?max_page_size=%s" % (self._base_url, max)
@@ -259,8 +276,6 @@ def list_jobs(self, filt=None, max=100) -> List[dict]:
except Exception as e:
logging.error(f"Failed to parse response: {resp.text}")
raise e


if "resources" not in response_json:
logging.warning(str(response_json))
break
@@ -270,15 +285,19 @@ def list_jobs(self, filt=None, max=100) -> List[dict]:
url = orig_url + "&page_token=%s" % (response_json["next_page_token"])
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_job(self, job):
url = "%sjobs/%s" % (self._base_url, job)
def get_job(self, job_id: str):
url = "%sjobs/%s" % (self._base_url, job_id)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def claim_job(self, job):
url = "%sjobs/%s:claim" % (self._base_url, job)
def claim_job(self, job_id: str):
url = "%sjobs/%s:claim" % (self._base_url, job_id)
resp = requests.post(url, headers=self.header)
if resp.status_code == 409:
claimed = True
@@ -302,6 +321,7 @@ def _page_query(self, url):
url = orig_url + "&page_token=%s" % (resp["next_page_token"])
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_objs(self, filt=None, max_page_size=40):
url = "%sobjects?max_page_size=%d" % (self._base_url, max_page_size)
@@ -310,6 +330,7 @@ def list_objs(self, filt=None, max_page_size=40):
results = self._page_query(url)
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def list_ops(self, filt=None, max_page_size=40):
url = "%soperations?max_page_size=%d" % (self._base_url, max_page_size)
@@ -329,12 +350,16 @@ def list_ops(self, filt=None, max_page_size=40):
url = orig_url + "&page_token=%s" % (resp["next_page_token"])
return results

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def get_op(self, opid):
url = "%soperations/%s" % (self._base_url, opid)
resp = requests.get(url, headers=self.header)
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def update_op(self, opid, done=None, results=None, meta=None):
"""
@@ -357,8 +382,11 @@ def update_op(self, opid, done=None, results=None, meta=None):
d["metadata"] = cur["metadata"]
d["metadata"]["extra"] = meta
resp = requests.patch(url, headers=self.header, data=json.dumps(d))
if not resp.ok:
resp.raise_for_status()
return resp.json()

@retry(wait=wait_exponential(multiplier=4, min=8, max=120), stop=stop_after_attempt(6), reraise=True)
@refresh_token
def run_query(self, query):
url = "%squeries:run" % self._base_url
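Throughout nmdcapi.py the new @retry sits above the existing @refresh_token wrapper, so each retry attempt re-enters the token check before the HTTP call is re-issued. A simplified, self-contained sketch of that ordering; the refresh_token wrapper and client below are stand-ins, not the actual implementation:

import functools
import requests
from tenacity import retry, stop_after_attempt, wait_exponential

def refresh_token(fn):
    # stand-in for the wrapper defined in nmdcapi.py: make sure the client
    # holds a usable bearer token before calling the wrapped method
    @functools.wraps(fn)
    def wrapper(self, *args, **kwargs):
        if not getattr(self, "token", None):
            self.get_token()
        return fn(self, *args, **kwargs)
    return wrapper

class Client:
    def get_token(self):
        self.token = "example-token"  # illustrative only

    # @retry is outermost: each of the (up to 6) attempts re-runs the
    # refresh_token check and then re-issues the request.
    @retry(wait=wait_exponential(multiplier=4, min=8, max=120),
           stop=stop_after_attempt(6), reraise=True)
    @refresh_token
    def get_object(self, obj_id):
        resp = requests.get(f"https://example.org/objects/{obj_id}")
        resp.raise_for_status()
        return resp.json()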
2 changes: 1 addition & 1 deletion nmdc_automation/run_process/run_import.py
@@ -13,7 +13,7 @@
from nmdc_automation.api import NmdcRuntimeApi


logging.basicConfig(level=logging.INFO)
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)


7 changes: 6 additions & 1 deletion nmdc_automation/workflow_automation/watch_nmdc.py
@@ -12,6 +12,7 @@
import importlib.resources
from functools import lru_cache
import traceback
import os

from nmdc_schema.nmdc import Database
from nmdc_automation.api import NmdcRuntimeApi
@@ -23,8 +24,12 @@
DEFAULT_STATE_DIR = Path(__file__).parent / "_state"
DEFAULT_STATE_FILE = DEFAULT_STATE_DIR / "state.json"
INITIAL_STATE = {"jobs": []}

logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)


class FileHandler:
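The log level in watch_nmdc.py now comes from the NMDC_LOG_LEVEL environment variable, falling back to logging.DEBUG; logging.basicConfig accepts either a level name string or a numeric constant, so both forms work. A small usage sketch; the environment assignment is illustrative and would normally be done by the deployment, not in code:

import logging
import os

os.environ.setdefault("NMDC_LOG_LEVEL", "INFO")  # illustrative; e.g. exported by the deployment

logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
    level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logging.getLogger(__name__).info("watcher starting")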
45 changes: 25 additions & 20 deletions nmdc_automation/workflow_automation/wfutils.py
@@ -22,6 +22,11 @@

DEFAULT_MAX_RETRIES = 2

logging_level = os.getenv("NMDC_LOG_LEVEL", logging.DEBUG)
logging.basicConfig(
level=logging_level, format="%(asctime)s %(levelname)s: %(message)s"
)
logger = logging.getLogger(__name__)

class JobRunnerABC(ABC):
"""Abstract base class for job runners"""
@@ -132,7 +137,7 @@ def generate_submission_files(self) -> Dict[str, Any]:
"workflowInputs": open(_json_tmp(self._generate_workflow_inputs()), "rb"),
"labels": open(_json_tmp(self._generate_workflow_labels()), "rb"), }
except Exception as e:
logging.error(f"Failed to generate submission files: {e}")
logger.error(f"Failed to generate submission files: {e}")
self._cleanup_files(list(files.values()))
raise e
return files
@@ -144,7 +149,7 @@ def _cleanup_files(self, files: List[Union[tempfile.NamedTemporaryFile, tempfile
file.close()
os.unlink(file.name)
except Exception as e:
logging.error(f"Failed to cleanup file: {e}")
logger.error(f"Failed to cleanup file: {e}")

def submit_job(self, force: bool = False) -> Optional[str]:
"""
@@ -154,7 +159,7 @@ def submit_job(self, force: bool = False) -> Optional[str]:
"""
status = self.workflow.last_status
if status in self.NO_SUBMIT_STATES and not force:
logging.info(f"Job {self.job_id} in state {status}, skipping submission")
logger.info(f"Job {self.job_id} in state {status}, skipping submission")
return
cleanup_files = []
try:
@@ -165,12 +170,12 @@
response.raise_for_status()
self.metadata = response.json()
self.job_id = self.metadata["id"]
logging.info(f"Submitted job {self.job_id}")
logger.info(f"Submitted job {self.job_id}")
else:
logging.info(f"Dry run: skipping job submission")
logger.info(f"Dry run: skipping job submission")
self.job_id = "dry_run"

logging.info(f"Job {self.job_id} submitted")
logger.info(f"Job {self.job_id} submitted")
start_time = datetime.now(pytz.utc).isoformat()
# update workflow state
self.workflow.done = False
@@ -179,7 +184,7 @@
self.workflow.update_state({"last_status": "Submitted"})
return self.job_id
except Exception as e:
logging.error(f"Failed to submit job: {e}")
logger.error(f"Failed to submit job: {e}")
raise e
finally:
self._cleanup_files(cleanup_files)
@@ -191,7 +196,7 @@ def get_job_status(self) -> str:
status_url = f"{self.service_url}/{self.workflow.cromwell_jobid}/status"
# There can be a delay between submitting a job and it
# being available in Cromwell so handle 404 errors
logging.debug(f"Getting job status from {status_url}")
logger.debug(f"Getting job status from {status_url}")
try:
response = requests.get(status_url)
response.raise_for_status()
@@ -355,9 +360,9 @@ def fetch_release_file(self, filename: str, suffix: str = None) -> str:
Download a release file from the Git repository and save it as a temporary file.
Note: the temporary file is not deleted automatically.
"""
logging.debug(f"Fetching release file: {filename}")
logger.debug(f"Fetching release file: {filename}")
url = self._build_release_url(filename)
logging.debug(f"Fetching release file from URL: {url}")
logger.debug(f"Fetching release file from URL: {url}")
# download the file as a stream to handle large files
response = requests.get(url, stream=True)
try:
@@ -371,9 +376,9 @@

def _build_release_url(self, filename: str) -> str:
"""Build the URL for a release file in the Git repository."""
logging.debug(f"Building release URL for {filename}")
logger.debug(f"Building release URL for {filename}")
release = self.config["release"]
logging.debug(f"Release: {release}")
logger.debug(f"Release: {release}")
base_url = self.config["git_repo"].rstrip("/")
url = f"{base_url}{self.GIT_RELEASES_PATH}/{release}/{filename}"
return url
@@ -388,7 +393,7 @@ def _write_stream_to_file(self, response: requests.Response, file: tempfile.Name
except Exception as e:
# clean up the temporary file
Path(file.name).unlink(missing_ok=True)
logging.error(f"Error writing stream to file: {e}")
logger.error(f"Error writing stream to file: {e}")
raise e


@@ -508,16 +513,16 @@ def make_data_objects(self, output_dir: Union[str, Path] = None) -> List[DataObj

for output_spec in self.workflow.data_outputs: # specs are defined in the workflow.yaml file under Outputs
output_key = f"{self.workflow.input_prefix}.{output_spec['output']}"
logging.info(f"Processing output {output_key}")
logger.info(f"Processing output {output_key}")
# get the full path to the output file from the job_runner
output_file_path = Path(self.job.outputs[output_key])
logging.info(f"Output file path: {output_file_path}")
logger.info(f"Output file path: {output_file_path}")
if output_key not in self.job.outputs:
if output_spec.get("optional"):
logging.debug(f"Optional output {output_key} not found in job outputs")
logger.debug(f"Optional output {output_key} not found in job outputs")
continue
else:
logging.warning(f"Required output {output_key} not found in job outputs")
logger.warning(f"Required output {output_key} not found in job outputs")
continue


@@ -531,7 +536,7 @@
# copy the file to the output directory
shutil.copy(output_file_path, new_output_file_path)
else:
logging.warning(f"Output directory not provided, not copying {output_file_path} to output directory")
logger.warning(f"Output directory not provided, not copying {output_file_path} to output directory")

# create a DataObject object
data_object = DataObject(
Expand Down Expand Up @@ -562,7 +567,7 @@ def make_workflow_execution(self, data_objects: List[DataObject]) -> WorkflowExe
if attr_val.startswith("{outputs."):
match = re.match(pattern, attr_val)
if not match:
logging.warning(f"Invalid output reference {attr_val}")
logger.warning(f"Invalid output reference {attr_val}")
continue
logical_names.add(match.group(1))
field_names.add(match.group(2))
@@ -579,7 +584,7 @@
if field_name in data:
wf_dict[field_name] = data[field_name]
else:
logging.warning(f"Field {field_name} not found in {data_path}")
logger.warning(f"Field {field_name} not found in {data_path}")

wfe = workflow_process_factory(wf_dict)
return wfe
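The wfutils.py edits also replace module-wide logging.* calls with a named module logger, so records carry the originating module and its level can be tuned independently. A minimal sketch of the pattern; the function and message are illustrative:

import logging

logger = logging.getLogger(__name__)

def report_submission(job_id: str) -> None:
    # emitted through the named logger rather than the root logger, so
    # handlers and filters can target this module specifically
    logger.info("Submitted job %s", job_id)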
