use generator for project files and free up memory
mbthornton-lbl committed Dec 10, 2024
1 parent 1dad830 commit e91c770
Showing 2 changed files with 31 additions and 21 deletions.
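
In short, the commit swaps an eagerly built list of project file paths for a lazy generator, so only one path needs to be held in memory at a time. A minimal sketch of the before/after pattern, simplified from the diff below:

    import os

    def list_files_eagerly(project_path):
        # Old approach: the whole directory listing is materialized up front
        return [
            os.path.join(project_path, f)
            for f in os.listdir(os.path.abspath(project_path))
            if os.path.isfile(os.path.join(project_path, f))
        ]

    def iter_files_lazily(project_path):
        # New approach: paths are yielded one at a time as the caller iterates
        for entry in os.scandir(os.path.abspath(project_path)):
            if entry.is_file():
                yield entry.path

As a side benefit, os.scandir's DirEntry.is_file() can usually answer from information cached during the directory scan, avoiding the extra per-file stat call that os.path.isfile performs.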
18 changes: 9 additions & 9 deletions nmdc_automation/import_automation/activity_mapper.py
@@ -6,7 +6,7 @@
 import pytz
 import yaml
 
-from typing import List, Dict, Union, Tuple
+from typing import List, Dict, Union, Tuple, Iterable
 from nmdc_schema import nmdc
 
 from nmdc_automation.api import NmdcRuntimeApi
@@ -20,7 +20,7 @@ class GoldMapper:
     def __init__(
         self,
         iteration,
-        file_list: List[Union[str, Path]],
+        project_files: Iterable[Union[str, Path]],
         nucelotide_sequencing_id: str,
         yaml_file: Union[str, Path],
         project_directory: Union[str, Path],
@@ -30,7 +30,7 @@ def __init__(
         Initialize the GoldMapper object.
 
         Args:
-            file_list: List of file paths to be processed.
+            project_files: Iterable with file paths to be processed.
             nucelotide_sequencing_id: Identifier for the omics data.
             yaml_file: File path of the yaml file containing import data.
             root_directory: Root directory path.
@@ -40,7 +40,7 @@ def __init__(
         self.import_data = self.load_yaml_file(yaml_file)
         self.nmdc_db = nmdc.Database()
         self.iteration = iteration
-        self.file_list = file_list
+        self.project_files = project_files
         self.nucelotide_sequencing_id = nucelotide_sequencing_id
         self.root_dir = os.path.join(
             self.import_data["Workflow Metadata"]["Root Directory"], nucelotide_sequencing_id
@@ -79,7 +79,7 @@ def map_sequencing_data(self) -> Tuple[nmdc.Database, Dict]:
         has_output = []
         for data_object_dict in sequencing_import_data:
             # get the file(s) that match the import suffix
-            for import_file in self.file_list:
+            for import_file in self.project_files:
                 import_file = str(import_file)
                 if re.search(data_object_dict["import_suffix"], import_file):
                     logging.debug(f"Processing {data_object_dict['data_object_type']}")
@@ -134,7 +134,7 @@ def map_sequencing_data(self) -> Tuple[nmdc.Database, Dict]:
         return db, update
 
 
-    def map_data(self,db: nmdc.Database, unique: bool = True) -> Tuple[nmdc.Database, Dict]:
+    def map_data(self,db: nmdc.Database, unique: bool = True) -> nmdc.Database:
         """
         Map data objects to the NMDC database.
         """
@@ -194,22 +194,22 @@ def process_files(files: Union[str, List[str]], data_object_dict: Dict, workflow
 
         # Process unique data objects
         if unique:
-            for file in map(str, self.file_list):
+            for file in map(str, self.project_files):
                 if re.search(data_object_spec["import_suffix"], file):
                     workflow_execution_id = self.get_workflow_execution_id(data_object_spec["output_of"])
                     db.data_object_set.append(process_files(file, data_object_spec, workflow_execution_id))
 
         # Process multiple data data files into a single data object
         else:
             multiple_files = []
-            for file in map(str, self.file_list):
+            for file in map(str, self.project_files):
                 if re.search(data_object_spec["import_suffix"], file):
                     multiple_files.append(file)
             if multiple_files:
                 workflow_execution_id = self.get_workflow_execution_id(data_object_spec["output_of"])
                 db.data_object_set.append(process_files(multiple_files, data_object_spec, workflow_execution_id, multiple=True))
 
-        return db, self.data_object_map
+        return db
 
 
     def map_workflow_executions(self, db) -> nmdc.Database:
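
Two things are worth noting about this file's changes. First, the parameter annotation is loosened from List to Iterable, which both lists and generators satisfy; the caveat is that a generator can be consumed only once, whereas a list can be re-iterated. A minimal illustration (names are hypothetical):

    def gen():
        yield "reads.fastq"

    g = gen()
    print(list(g))  # ['reads.fastq']
    print(list(g))  # [] -- a generator is exhausted after one pass, unlike a list

Second, map_data now returns only the database rather than a (database, data_object_map) tuple, so callers no longer hold an extra reference to the mapping after the call returns.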
34 changes: 22 additions & 12 deletions nmdc_automation/run_process/run_import.py
@@ -37,16 +37,18 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
     for data_import in data_imports:
         project_path = data_import["project_path"]
         nucleotide_sequencing_id = data_import["nucleotide_sequencing_id"]
-        files_list = [
-            os.path.join(project_path, f)
-            for f in os.listdir(os.path.abspath(project_path))
-            if os.path.isfile(os.path.join(project_path, f))
-        ]
+        # files_list = [
+        #     os.path.join(project_path, f)
+        #     for f in os.listdir(os.path.abspath(project_path))
+        #     if os.path.isfile(os.path.join(project_path, f))
+        # ]
+        # Replace the above with a generator
+        project_files = get_project_files(project_path)
 
-        logger.info(f"Importing {nucleotide_sequencing_id} from {project_path}: {len(files_list)} files")
+        logger.info(f"Importing {nucleotide_sequencing_id} from {project_path}")
         mapper = GoldMapper(
             iteration,
-            files_list,
+            project_files,
             nucleotide_sequencing_id,
             import_yaml,
             project_path,
@@ -58,10 +60,10 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
         db, data_generation_update = mapper.map_sequencing_data()
         # Map the rest of the data files - single files
         logger.info("Mapping single data files")
-        db, do_mapping = mapper.map_data(db)
+        db = mapper.map_data(db)
         # Map the rest of the data files - multiple files
         logger.info("Mapping multiple data files")
-        db, do_mapping = mapper.map_data(db, unique=False)
+        db = mapper.map_data(db, unique=False)
 
         # map the workflow executions
         logger.info("Mapping workflow executions")
@@ -70,6 +72,7 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
         # validate the database
         logger.info("Validating imported data")
         db_dict = yaml.safe_load(yaml_dumper.dumps(db))
+        del db # free up memory
         validation_report = linkml.validator.validate(db_dict, nmdc_materialized)
         if validation_report.results:
             logger.error(f"Validation Failed")
@@ -88,17 +91,24 @@ def import_projects(import_file, import_yaml, site_configuration, iteration):
             logger.error(data_generation_update)
             raise e
 
 
         # Post the data to the API
         logger.info("Posting data to the API")
         try:
             runtime.post_objects(db_dict)
+            del db_dict # free up memory
         except Exception as e:
             logger.error(f"Error posting data to the API: {e}")
             raise e
 
 
+def get_project_files(project_path):
+    """
+    A generator that returns all the files in a project directory
+    """
+    abs_project_path = os.path.abspath(project_path)
+    for f in os.scandir(abs_project_path):
+        if f.is_file():
+            yield f.path
+
 @lru_cache(maxsize=None)
 def _get_nmdc_materialized():
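
For reference, a sketch of how the new generator streams paths (the directory name below is illustrative):

    # Paths are produced one at a time as the loop advances, so the full
    # directory listing is never held in memory at once.
    for path in get_project_files("/data/projects/example_project"):
        print(path)

    # Materializing is still possible when a list is genuinely needed:
    all_paths = list(get_project_files("/data/projects/example_project"))

The paired del statements (del db, del db_dict) drop the last local references to the large in-memory documents so that, assuming no other references remain, CPython can reclaim them before the next project is processed.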
