From ae29087e58f4eca0bd992a48ffc87b95b6ee8075 Mon Sep 17 00:00:00 2001 From: David Wallace Date: Mon, 19 Feb 2024 22:52:47 +0100 Subject: [PATCH] chore: fix load from index --- .../imports/files/file_indexer.py | 43 +++++++++-------- .../imports/files/index_funcs.py | 46 ++----------------- src/raman_fitting/imports/files/utils.py | 12 ++++- src/raman_fitting/imports/models.py | 23 +++++++++- .../imports/spectrum/datafile_parsers.py | 10 ++-- tests/indexing/test_indexer.py | 15 ++++-- 6 files changed, 74 insertions(+), 75 deletions(-) diff --git a/src/raman_fitting/imports/files/file_indexer.py b/src/raman_fitting/imports/files/file_indexer.py index e7de46f..a6c0509 100644 --- a/src/raman_fitting/imports/files/file_indexer.py +++ b/src/raman_fitting/imports/files/file_indexer.py @@ -14,20 +14,24 @@ ) from raman_fitting.config import settings from raman_fitting.imports.collector import collect_raman_file_infos -from raman_fitting.imports.files.utils import load_dataset_from_file +from raman_fitting.imports.files.utils import ( + load_dataset_from_file, + write_dataset_to_file, +) from raman_fitting.imports.models import RamanFileInfo from tablib import Dataset -RamanFileInfoSet: TypeAlias = Sequence[RamanFileInfo] +RamanFileInfoSet: TypeAlias = List[RamanFileInfo] class RamanFileIndex(BaseModel): model_config = ConfigDict(arbitrary_types_allowed=True) index_file: NewPath | FilePath = Field(None, validate_default=False) - raman_files: RamanFileInfoSet = Field(None) - dataset: Dataset = Field(None) - force_reload: bool = Field(True, validate_default=False) + raman_files: RamanFileInfoSet | None = Field(None) + dataset: Dataset | None = Field(None) + force_reindex: bool = Field(False, validate_default=False) + persist_to_file: bool = Field(True, validate_default=False) @model_validator(mode="after") def read_or_load_data(self) -> "RamanFileIndex": @@ -35,7 +39,7 @@ def read_or_load_data(self) -> "RamanFileIndex": raise ValueError("Not all fields should be empty.") reload_from_file = validate_reload_from_index_file( - self.index_file, self.force_reload + self.index_file, self.force_reindex ) if reload_from_file: self.dataset = load_dataset_from_file(self.index_file) @@ -57,20 +61,24 @@ def read_or_load_data(self) -> "RamanFileIndex": raise ValueError( "Index error, both raman_files and dataset are not provided." ) + + if self.persist_to_file and self.index_file is not None: + write_dataset_to_file(self.index_file, self.dataset) + return self def validate_reload_from_index_file( - index_file: Path | None, force_reload: bool + index_file: Path | None, force_reindex: bool ) -> bool: if index_file is None: logger.debug( "Index file not provided, index will not be reloaded or persisted." ) return False - if index_file.exists() and not force_reload: + if index_file.exists() and not force_reindex: return True - elif force_reload: + elif force_reindex: logger.warning( f"Index index_file file {index_file} exists and will be overwritten." ) @@ -90,11 +98,11 @@ def cast_raman_files_to_dataset(raman_files: List[RamanFileInfo]) -> Dataset: def parse_dataset_to_index(dataset: Dataset) -> List[RamanFileInfo]: - raman_file = [] + raman_files = [] for row in dataset: row_data = dict(zip(dataset.headers, row)) - raman_file.append(RamanFileInfo(**row_data)) - return raman_file + raman_files.append(RamanFileInfo(**row_data)) + return raman_files class IndexSelector(BaseModel): @@ -178,16 +186,13 @@ def collect_raman_file_index_info( def initialize_index_from_source_files( - files: Sequence[Path] | None = None, force_reload: bool = False + files: Sequence[Path] | None = None, force_reindex: bool = False ) -> RamanFileIndex: index_file = settings.destination_dir.joinpath("index.csv") raman_files = collect_raman_file_index_info(raman_files=files) - index_data = { - "file": index_file, - "raman_files": raman_files, - "force_reload": force_reload, - } - raman_index = RamanFileIndex(**index_data) + raman_index = RamanFileIndex( + index_file=index_file, raman_files=raman_files, force_reindex=force_reindex + ) logger.info( f"index_delegator index prepared with len {len(raman_index.raman_files)}" ) diff --git a/src/raman_fitting/imports/files/index_funcs.py b/src/raman_fitting/imports/files/index_funcs.py index 6b3af71..0d536a5 100644 --- a/src/raman_fitting/imports/files/index_funcs.py +++ b/src/raman_fitting/imports/files/index_funcs.py @@ -1,11 +1,10 @@ -import logging import sys from pathlib import Path from raman_fitting.imports.spectrum.datafile_parsers import load_dataset_from_file -logger = logging.getLogger(__name__) +from loguru import logger def get_dtypes_filepath(index_file): @@ -35,40 +34,16 @@ def export_index(index, index_file): ) -def load_index(index_file, reload=False): +def load_index(index_file): """loads the index from from defined Index file""" - if not index_file.exists() or reload: + if not index_file.exists(): logger.error( f"Error in load_index: {index_file} does not exists, starting reload index ... " ) return - # index = make_index() - # export_index(index, index_file) - # return index - # breakpoint() try: - _dtypes = load_dataset_from_file( - get_dtypes_filepath(index_file), index_col=[0] - ).to_dict()["dtypes"] - - _dtypes_datetime = { - k: val - for k, val in _dtypes.items() - if "datetime" in val or k.endswith("Date") - } - - _dtypes_no_datetime = { - k: val for k, val in _dtypes.items() if k not in _dtypes_datetime.keys() - } - - index = load_dataset_from_file( - index_file, - index_col=[0], - dtype=_dtypes_no_datetime, - parse_dates=list(_dtypes_datetime.keys()), - ) - # index = _extra_assign_destdir_and_set_paths(index) + index = load_dataset_from_file(index_file) logger.info( f"Succesfully imported Raman Index file from {index_file}, with len({len(index)})" @@ -79,8 +54,6 @@ def load_index(index_file, reload=False): \nlength of loaded index not same as number of raman files \n starting reload index ... """ ) - # index = make_index() - # return index except Exception as e: logger.error( @@ -120,7 +93,7 @@ def index_selection(index, **kwargs): default_selection = kwargs.get("default_selection", "all") if "normal" not in kwargs.get("run_mode", default_selection): default_selection = "all" - index_selection = None # pd.DataFrame() + index_selection = None logger.info( f"starting index selection from index({len(index)}) with:\n default selection: {default_selection}\n and {kwargs}" ) @@ -138,14 +111,8 @@ def index_selection(index, **kwargs): index = list( filter(lambda x: x.sample.group in kwargs["samplegroups"], index) ) - # index_selection = index.loc[ - # index.SampleGroup.str.contains("|".join(kwargs["samplegroups"])) - # ] if "sampleIDs" in kwargs: index = list(filter(lambda x: x.sample.id in kwargs["sampleIDs"], index)) - # index_selection = index.loc[ - # index.SampleID.str.contains("|".join(kwargs["sampleIDs"])) - # ] if "extra" in kwargs: runq = kwargs.get("run") @@ -164,9 +131,6 @@ def index_selection(index, **kwargs): } ) - # if "make_examples" in run_mode: - # index_selection = index.loc[~index.SampleID.str.startswith("Si")] - logger.debug( f"finished index selection from index({len(index)}) with:\n {default_selection}\n and {kwargs}\n selection len({len(index_selection )})" ) diff --git a/src/raman_fitting/imports/files/utils.py b/src/raman_fitting/imports/files/utils.py index 21b66d7..8eb53d8 100644 --- a/src/raman_fitting/imports/files/utils.py +++ b/src/raman_fitting/imports/files/utils.py @@ -2,13 +2,21 @@ from tablib import Dataset +from loguru import logger + def write_dataset_to_file(file: Path, dataset: Dataset) -> None: - with open(file, "wb", encoding="utf-8") as f: - f.write(dataset.export(file.suffix)) + if file.suffix == ".csv": + with open(file, "w", newline="") as f: + f.write(dataset.export("csv")) + else: + with open(file, "wb", encoding="utf-8") as f: + f.write(dataset.export(file.suffix)) + logger.debug(f"Wrote dataset {len(dataset)} to {file}") def load_dataset_from_file(file) -> Dataset: with open(file, "r", encoding="utf-8") as fh: imported_data = Dataset().load(fh) + logger.debug(f"Read dataset {len(imported_data)} from {file}") return imported_data diff --git a/src/raman_fitting/imports/models.py b/src/raman_fitting/imports/models.py index 58e8aac..76ba638 100644 --- a/src/raman_fitting/imports/models.py +++ b/src/raman_fitting/imports/models.py @@ -1,3 +1,4 @@ +import json from pydantic import ( BaseModel, FilePath, @@ -18,8 +19,10 @@ class RamanFileInfo(BaseModel): file: FilePath filename_id: str = Field(None, init_var=False, validate_default=False) - sample: SampleMetaData = Field(None, init_var=False, validate_default=False) - file_metadata: FileMetaData = Field(None, init_var=False, validate_default=False) + sample: SampleMetaData | str = Field(None, init_var=False, validate_default=False) + file_metadata: FileMetaData | str = Field( + None, init_var=False, validate_default=False + ) @model_validator(mode="after") def set_filename_id(self) -> "RamanFileInfo": @@ -38,3 +41,19 @@ def parse_and_set_metadata_from_filepath(self) -> "RamanFileInfo": file_metadata = get_file_metadata(self.file) self.file_metadata = FileMetaData(**file_metadata) return self + + @model_validator(mode="after") + def initialize_sample_and_file_from_dict(self) -> "RamanFileInfo": + if isinstance(self.sample, dict): + self.sample = SampleMetaData(**self.sample) + elif isinstance(self.sample, str): + _sample = json.loads(self.sample.replace("'", '"')) + self.sample = SampleMetaData(**_sample) + + if isinstance(self.file_metadata, dict): + self.file_metadata = FileMetaData(**self.file_metadata) + elif isinstance(self.file_metadata, str): + _file_metadata = json.loads(self.file_metadata.replace("'", '"')) + self.file_metadata = SampleMetaData(**_file_metadata) + + return self diff --git a/src/raman_fitting/imports/spectrum/datafile_parsers.py b/src/raman_fitting/imports/spectrum/datafile_parsers.py index 4af7186..76fbe20 100644 --- a/src/raman_fitting/imports/spectrum/datafile_parsers.py +++ b/src/raman_fitting/imports/spectrum/datafile_parsers.py @@ -1,5 +1,3 @@ -import logging - from typing import List, Sequence from pathlib import Path @@ -7,7 +5,7 @@ import pandas as pd from tablib import Dataset -logger = logging.getLogger(__name__) +from loguru import logger def filter_data_for_numeric(data: Dataset): @@ -16,7 +14,7 @@ def filter_data_for_numeric(data: Dataset): for row in data: try: - digits_row = tuple(map(lambda x: float(x), row)) + digits_row = tuple(map(float, row)) except ValueError: continue except TypeError: @@ -84,7 +82,7 @@ def use_np_loadtxt(filepath, usecols=(0, 1), **kwargs) -> np.array: except Exception as exc: _msg = f"Can not load data from txt file: {filepath}\n{exc}" logger.error(_msg) - raise ValueError(_msg) + raise ValueError(_msg) from exc return array @@ -101,7 +99,7 @@ def cast_array_into_spectrum_frame(array, keys: List[str] = None) -> pd.DataFram except Exception as exc: _msg = f"Can not create DataFrame from array object: {array}\n{exc}" logger.error(_msg) - raise ValueError(_msg) + raise ValueError(_msg) from exc def load_spectrum_from_txt(filepath, **kwargs) -> pd.DataFrame: diff --git a/tests/indexing/test_indexer.py b/tests/indexing/test_indexer.py index fc7b100..b767064 100644 --- a/tests/indexing/test_indexer.py +++ b/tests/indexing/test_indexer.py @@ -9,7 +9,6 @@ RamanFileIndex, initialize_index_from_source_files, ) -from raman_fitting.imports.files.index_funcs import load_index from raman_fitting.imports.models import RamanFileInfo run_mode = RunModes.PYTEST @@ -26,7 +25,7 @@ def setUp(self): self.all_test_files = self.example_files + self.pytest_fixtures_files index = initialize_index_from_source_files( - files=self.all_test_files, force_reload=True + files=self.all_test_files, force_reindex=True ) self.index = index @@ -37,10 +36,16 @@ def test_RamanFileIndex_make_examples(self): self.assertEqual(len(self.index.dataset), len(self.example_files)) - @unittest.skip("export_index not yet implemented") + # @unittest.skip("export_index not yet implemented") def test_load_index(self): - _loaded_index = load_index() - self.assertTrue(isinstance(_loaded_index, RamanFileIndex)) + self.index.index_file.exists() + try: + new_index = RamanFileIndex( + index_file=self.index.index_file, force_reindex=False + ) + except Exception as e: + raise e from e + self.assertTrue(isinstance(new_index, RamanFileIndex)) if __name__ == "__main__":