Commit 3d4493b
refactor and rename funcs and vars
MyPyDavid committed Feb 17, 2024
1 parent 149ce5b commit 3d4493b
Showing 7 changed files with 79 additions and 189 deletions.
8 changes: 4 additions & 4 deletions src/raman_fitting/delegating/main_delegator.py
@@ -21,7 +21,7 @@
 from raman_fitting.exports.exporter import ExportManager
 from raman_fitting.imports.files.file_indexer import (
     RamanFileIndex,
-    initialize_index,
+    initialize_index_from_source_files,
     groupby_sample_group,
     groupby_sample_id,
     IndexSelector,
@@ -62,7 +62,7 @@ class MainDelegator:
     fit_model_specific_names: Sequence[str] | None = None
     sample_IDs: Sequence[str] = field(default_factory=list)
     sample_groups: Sequence[str] = field(default_factory=list)
-    index: RamanFileIndex = field(default_factory=initialize_index)
+    index: RamanFileIndex = field(default_factory=initialize_index_from_source_files)

     selection: Sequence[RamanFileInfo] = field(init=False)
     selected_models: Sequence[RamanFileInfo] = field(init=False)
@@ -125,7 +125,7 @@ def select_fitting_model(
         try:
             return self.lmfit_models[window_name][model_name]
         except KeyError as exc:
-            raise ValueError(f"Model {window_name} {model_name} not found.") from exc
+            raise KeyError(f"Model {window_name} {model_name} not found.") from exc

     def main_run(self):
         selection = self.select_samples_from_index()
@@ -214,4 +214,4 @@ def make_examples():


 if __name__ == "__main__":
-    RamanIndex = make_examples()
+    example_run = make_examples()
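The behavioural change in this file is in `select_fitting_model`: a failed lookup now re-raises `KeyError` instead of converting it to `ValueError`, so callers handle the same exception type that the mapping itself raises, while `from exc` keeps the original failure chained in the traceback. A minimal sketch of the pattern, with a hypothetical `Delegator` class and `build_default_index` factory standing in for the real ones:

from dataclasses import dataclass, field


def build_default_index() -> dict:
    # hypothetical stand-in for a zero-argument factory such as
    # initialize_index_from_source_files
    return {"first_order": {"2peaks": "<lmfit model>"}}


@dataclass
class Delegator:
    # field(default_factory=...) takes any zero-argument callable, which is why
    # the rename above only touches the name passed to default_factory
    lmfit_models: dict = field(default_factory=build_default_index)

    def select_fitting_model(self, window_name: str, model_name: str):
        try:
            return self.lmfit_models[window_name][model_name]
        except KeyError as exc:
            # re-raise as KeyError, not ValueError, so callers catch one type;
            # "from exc" chains the original lookup failure into the traceback
            raise KeyError(f"Model {window_name} {model_name} not found.") from exc


print(Delegator().select_fitting_model("first_order", "2peaks"))  # <lmfit model>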
64 changes: 36 additions & 28 deletions src/raman_fitting/imports/files/file_indexer.py
@@ -1,6 +1,8 @@
 """ Indexer for raman data files """
 from itertools import groupby, filterfalse
 from typing import List, Sequence
+from typing import TypeAlias
+
 from pathlib import Path

 from pydantic import (
@@ -16,58 +18,62 @@
 from raman_fitting.imports.collector import collect_raman_file_infos
 from raman_fitting.imports.models import RamanFileInfo

-from .utils import load_dataset_from_file
+from raman_fitting.imports.files.utils import load_dataset_from_file


 from loguru import logger
 from tablib import Dataset


+RamanFileInfoSet: TypeAlias = Sequence[RamanFileInfo]
+
+
 class RamanFileIndex(BaseModel):
     model_config = ConfigDict(arbitrary_types_allowed=True)

-    source: NewPath | FilePath = Field(None, validate_default=False)
-    raman_files: Sequence[RamanFileInfo] = Field(None)
+    index_file: NewPath | FilePath = Field(None, validate_default=False)
+    raman_files: RamanFileInfoSet = Field(None)
     dataset: Dataset = Field(None)
     force_reload: bool = Field(True, validate_default=False)

     @model_validator(mode="after")
     def read_or_load_data(self) -> "RamanFileIndex":
-        if not any([self.source, self.raman_files, self.dataset]):
+        if not any([self.index_file, self.raman_files, self.dataset]):
             raise ValueError("Not all fields should be empty.")

-        if self.source is not None:
-            if self.source.exists() and not self.force_reload:
-                self.dataset = load_dataset_from_file(self.source)
+        if self.index_file is not None:
+            if self.index_file.exists() and not self.force_reload:
+                self.dataset = load_dataset_from_file(self.index_file)
                 self.raman_files = parse_dataset_to_index(self.dataset)
                 return self
-            elif self.source.exists() and self.force_reload:
-                logger.info(
-                    f"Index source file {self.source} exists and will be overwritten."
+            elif self.index_file.exists() and self.force_reload:
+                logger.warning(
+                    f"Index index_file file {self.index_file} exists and will be overwritten."
                 )
-            elif not self.source.exists() and self.force_reload:
+            elif not self.index_file.exists() and self.force_reload:
                 logger.info(
-                    "Index source file does not exists but was asked to reload from it."
+                    "Index index_file file does not exists but was asked to reload from it."
                 )
-            elif not self.source.exists() and not self.force_reload:
+            elif not self.index_file.exists() and not self.force_reload:
                 pass
         else:
-            logger.debug("Index source file not provided.")
+            logger.debug("Index file not provided, index will not be persisted.")

         if self.raman_files is not None:
-            if self.dataset is None:
-                self.dataset = cast_raman_files_to_dataset(self.raman_files)
-            else:
-                dataset_rf = cast_raman_files_to_dataset(self.raman_files)
+            dataset_rf = cast_raman_files_to_dataset(self.raman_files)
+            if self.dataset is not None:
                 assert (
                     dataset_rf == self.dataset
-                ), "Both dataset and raman_files provider but are different."
-        if self.dataset is not None:
-            self.raman_files = parse_dataset_to_index(self.dataset)
-        elif self.dataset is None and self.raman_files is None:
-            raise ValueError(
-                "Index error, both raman_files and dataset are not provided."
-            )
+                ), "Both dataset and raman_files provided and they are different."
+            self.dataset = dataset_rf
+        else:
+            if self.dataset is not None:
+                self.raman_files = parse_dataset_to_index(self.dataset)
+            else:
+                raise ValueError(
+                    "Index error, both raman_files and dataset are not provided."
+                )
         return self


 def cast_raman_files_to_dataset(raman_files: List[RamanFileInfo]) -> Dataset:
@@ -157,7 +163,7 @@ def select_index(

 def collect_raman_file_index_info(
     raman_files: Sequence[Path] | None = None, **kwargs
-) -> List[RamanFileInfo]:
+) -> RamanFileInfoSet:
     """loops over the files and scrapes the index data from each file"""
     if not raman_files:
         raman_files = list(settings.internal_paths.example_fixtures.glob("*.txt"))
@@ -166,7 +172,9 @@ def collect_raman_file_index_info(
     return index


-def initialize_index(files: Sequence[Path] | None = None, force_reload: bool = False):
+def initialize_index_from_source_files(
+    files: Sequence[Path] | None = None, force_reload: bool = False
+) -> RamanFileIndex:
     index_file = settings.destination_dir.joinpath("index.csv")
     raman_files = collect_raman_file_index_info(raman_files=files)
     index_data = {
@@ -199,4 +207,4 @@ def main():


 if __name__ == "__main__":
-    RamanIndex = main()
+    main()
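With `source` renamed to `index_file`, the three ways to build a `RamanFileIndex` are explicit: a persisted index file, an in-memory set of `raman_files`, or a tablib `dataset`, with the after-validator reconciling whichever of them is supplied. A hedged usage sketch (field and function names come from this commit; the paths are hypothetical):

from pathlib import Path

from raman_fitting.imports.files.file_indexer import (
    RamanFileIndex,
    initialize_index_from_source_files,
)

# Build an index by scraping a directory of Raman data files (hypothetical path);
# force_reload=True regenerates the persisted index even if it already exists.
my_files = list(Path("data/raman").glob("*.txt"))
index = initialize_index_from_source_files(files=my_files, force_reload=True)

# Reload later from the persisted CSV: with force_reload=False and an existing
# index_file, the validator takes the load_dataset_from_file branch instead of
# re-scraping the source files.
reloaded = RamanFileIndex(index_file=Path("results/index.csv"), force_reload=False)

The `RamanFileInfoSet` alias introduced here also gives signatures such as `collect_raman_file_index_info` a single named type instead of repeating `Sequence[RamanFileInfo]`.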
14 changes: 2 additions & 12 deletions src/raman_fitting/imports/models.py
@@ -6,7 +6,7 @@
     ConfigDict,
 )

-from .samples.sample_id_helpers import parse_sample_from_filepath
+from .samples.sample_id_helpers import extract_sample_metadata_from_filepath

 from .files.metadata import FileMetaData, get_file_metadata
 from .files.index_helpers import get_filename_id_from_path
@@ -29,7 +29,7 @@ def set_filename_id(self) -> "RamanFileInfo":

     @model_validator(mode="after")
     def parse_and_set_sample_from_file(self) -> "RamanFileInfo":
-        sample = parse_sample_from_filepath(self.file)
+        sample = extract_sample_metadata_from_filepath(self.file)
         self.sample = sample
         return self

@@ -38,13 +38,3 @@ def parse_and_set_metadata_from_filepath(self) -> "RamanFileInfo":
         file_metadata = get_file_metadata(self.file)
         self.file_metadata = FileMetaData(**file_metadata)
         return self
-
-
-# def extra_assign_export_dir_on_index(result_dir, index: List[RamanFileInfo]):
-#     """assign the DestDir column to index and sets column values as object type"""
-#     _index = []
-
-#     for rf_info in index:
-#         rf_info.export_dir = result_dir.joinpath(rf_info.sample.group)
-#         _index.append(rf_info)
-#     return _index
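`RamanFileInfo` fills in its derived fields through a series of pydantic v2 after-validators, each of which receives the constructed model, sets one attribute, and returns `self`. A self-contained sketch of that pattern (the `Record` model is illustrative, not the project's class):

from pathlib import Path

from pydantic import BaseModel, model_validator


class Record(BaseModel):
    file: Path
    filename_id: str = ""
    sample_id: str = ""

    @model_validator(mode="after")
    def set_filename_id(self) -> "Record":
        # runs after field validation; derives an ID from the file path
        self.filename_id = self.file.stem.lower()
        return self

    @model_validator(mode="after")
    def set_sample_id(self) -> "Record":
        # an independent second validator, mirroring parse_and_set_sample_from_file
        self.sample_id = self.file.stem.split("_")[0]
        return self


record = Record(file=Path("DW38_pos1.txt"))
print(record.filename_id, record.sample_id)  # dw38_pos1 DW38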
58 changes: 32 additions & 26 deletions src/raman_fitting/imports/samples/sample_id_helpers.py
@@ -35,66 +35,72 @@ def parse_string_to_sample_id_and_position(
     _lensplit = len(split)

     if _lensplit == 0:
-        sID, position = split[0], 0
+        sample_id, position = split[0], 0
     elif len(split) == 1:
-        sID, position = split[0], 0
+        sample_id, position = split[0], 0
     elif len(split) == 2:
-        sID = split[0]
+        sample_id = split[0]
         _pos_strnum = "".join(i for i in split[1] if i.isnumeric())
         if _pos_strnum:
             position = int(_pos_strnum)
         else:
             position = split[1]
     elif len(split) >= 3:
-        sID = "_".join(split[0:-1])
+        sample_id = "_".join(split[0:-1])
         position = int("".join(filter(str.isdigit, split[-1])))
-    return (sID, position)
+    return (sample_id, position)


-def sID_to_sgrpID(sID: str, max_len=4) -> str:
+def extract_sample_group_from_sample_id(sample_id: str, max_len=4) -> str:
     """adding the extra sample Group key from sample ID"""

-    _len = len(sID)
+    _len = len(sample_id)
     _maxalphakey = min(
-        [n for n, i in enumerate(sID) if not str(i).isalpha()], default=_len
+        [n for n, i in enumerate(sample_id) if not str(i).isalpha()], default=_len
     )
     _maxkey = min((_len, _maxalphakey, max_len))
-    sgrpID = "".join([i for i in sID[0:_maxkey] if i.isalpha()])
-    return sgrpID
+    sample_group_id = "".join([i for i in sample_id[0:_maxkey] if i.isalpha()])
+    return sample_group_id


-def overwrite_sID_from_mapper(sID: str, mapper: dict) -> str:
-    """Takes an sID and potentially overwrites from a mapper dict"""
-    _sID_map = mapper.get(sID, None)
-    if _sID_map:
-        sID = _sID_map
-    return sID
+def overwrite_sample_id_from_mapper(sample_id: str, mapper: dict) -> str:
+    """Takes an sample_id and potentially overwrites from a mapper dict"""
+    sample_id_map = mapper.get(sample_id)
+    if sample_id_map is not None:
+        return sample_id_map
+    return sample_id


-def overwrite_sgrpID_from_parts(parts: List[str], sgrpID: str, mapper: dict) -> str:
+def overwrite_sample_group_id_from_parts(
+    parts: List[str], sample_group_id: str, mapper: dict
+) -> str:
     for k, val in mapper.items():
         if k in parts:
-            sgrpID = val
-    return sgrpID
+            sample_group_id = val
+    return sample_group_id


-def parse_sample_from_filepath(
+def extract_sample_metadata_from_filepath(
     filepath: Path, sample_name_mapper: Optional[Dict[str, Dict[str, str]]] = None
 ) -> SampleMetaData:
-    """parse the sID, position and sgrpID from stem"""
+    """parse the sample_id, position and sgrpID from stem"""
     stem = filepath.stem
     parts = filepath.parts

-    sID, position = parse_string_to_sample_id_and_position(stem)
+    sample_id, position = parse_string_to_sample_id_and_position(stem)

     if sample_name_mapper is not None:
         sample_id_mapper = sample_name_mapper.get("sample_id", {})
-        sID = overwrite_sID_from_mapper(sID, sample_id_mapper)
-    sgrpID = sID_to_sgrpID(sID)
+        sample_id = overwrite_sample_id_from_mapper(sample_id, sample_id_mapper)
+    sample_group_id = extract_sample_group_from_sample_id(sample_id)

     if sample_name_mapper is not None:
         sample_grp_mapper = sample_name_mapper.get("sample_group_id", {})
-        sgrpID = overwrite_sgrpID_from_parts(parts, sgrpID, sample_grp_mapper)
+        sample_group_id = overwrite_sample_group_id_from_parts(
+            parts, sample_group_id, sample_grp_mapper
+        )

-    sample = SampleMetaData(**{"id": sID, "group": sgrpID, "position": position})
+    sample = SampleMetaData(
+        **{"id": sample_id, "group": sample_group_id, "position": position}
+    )
     return sample
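To trace the renamed helpers on a concrete stem, take "DW38_pos1" and assume the part of the function above this hunk splits it on underscores: the two-part branch keeps "DW38" as the sample_id, pulls the digits out of "pos1" for position 1, and the group ID is the alphabetic prefix of the sample_id, capped at max_len, giving "DW". A hedged usage sketch against this module:

from raman_fitting.imports.samples.sample_id_helpers import (
    extract_sample_group_from_sample_id,
    parse_string_to_sample_id_and_position,
)

# Two-part stem: split[1] contains a digit, so position becomes int("1")
sample_id, position = parse_string_to_sample_id_and_position("DW38_pos1")
print(sample_id, position)  # DW38 1

# Group ID: alphabetic characters up to the first non-alpha one, capped at max_len=4
print(extract_sample_group_from_sample_id("DW38"))  # DW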
4 changes: 3 additions & 1 deletion src/raman_fitting/interfaces/cli.py
@@ -4,6 +4,8 @@
 import argparse

 from raman_fitting.config.settings import RunModes
+from loguru import logger
+

 _RUN_MODES = ["normal", "testing", "debug", "make_index", "make_examples"]
@@ -88,6 +90,6 @@ def main():
         extra_kwargs.update(
             {"fit_model_specific_names": ["2peaks", "3peaks", "4peaks"]}
         )
-    print(f"CLI args: {args}")
+    logger.info(f"Starting raman_fitting with CLI args:\n{args}")
     kwargs = {**vars(args), **extra_kwargs}
     _main_run = rf.MainDelegator(**kwargs)
1 change: 0 additions & 1 deletion src/raman_fitting/models/deconvolution/base_model.py
@@ -153,7 +153,6 @@ def get_models_and_peaks_from_definitions(

 def main():
     models = get_models_and_peaks_from_definitions()
-    # breakpoint()
     print("Models: ", len(models))
