Skip to content

Commit

Permalink
chore: fix default model loading, expand directory indexing by known suffixes, and tidy imports
Browse files Browse the repository at this point in the history
  • Loading branch information
David Wallace authored and MyPyDavid committed Mar 10, 2024
1 parent b770780 commit 8400ef4
Show file tree
Hide file tree
Showing 7 changed files with 33 additions and 18 deletions.
16 changes: 11 additions & 5 deletions src/raman_fitting/config/base_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,13 +11,21 @@
from raman_fitting.models.deconvolution.base_model import (
get_models_and_peaks_from_definitions,
)
from raman_fitting.models.deconvolution.spectrum_regions import get_default_regions_from_toml_files
from raman_fitting.models.deconvolution.spectrum_regions import (
get_default_regions_from_toml_files,
)
from .default_models import load_config_from_toml_files
from .path_settings import create_default_package_dir_or_ask, InternalPathSettings


def get_default_models_and_peaks_from_definitions():
    """Load the packaged TOML model definitions and build the default models/peaks."""
    definitions = load_config_from_toml_files()
    return get_models_and_peaks_from_definitions(definitions)


class Settings(BaseSettings):
default_models: Dict[str, Dict[str, BaseLMFitModel]] = Field(
default_factory=get_models_and_peaks_from_definitions,
default_factory=get_default_models_and_peaks_from_definitions,
alias="my_default_models",
init_var=False,
validate_default=False,
Expand All @@ -28,7 +36,5 @@ class Settings(BaseSettings):
init_var=False,
validate_default=False,
)
destination_dir: Path = Field(
default_factory=create_default_package_dir_or_ask
)
destination_dir: Path = Field(default_factory=create_default_package_dir_or_ask)
internal_paths: InternalPathSettings = Field(default_factory=InternalPathSettings)
2 changes: 2 additions & 0 deletions src/raman_fitting/config/path_settings.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,8 @@ class RunModes(StrEnum):
def get_run_mode_paths(run_mode: RunModes, user_package_home: Path = None):
if user_package_home is None:
user_package_home = USER_HOME_PACKAGE
if isinstance(run_mode, str):
run_mode = RunModes(run_mode)

RUN_MODE_PATHS = {
RunModes.PYTEST.name: {
Expand Down
12 changes: 4 additions & 8 deletions src/raman_fitting/delegating/main_delegator.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,16 @@
from dataclasses import dataclass, field
from typing import Dict, List, Sequence, Any


from raman_fitting.config.path_settings import (
RunModes,
ERROR_MSG_TEMPLATE,
initialize_run_mode_paths,
)
from raman_fitting.config import settings

from raman_fitting.imports.models import RamanFileInfo

from raman_fitting.models.deconvolution.base_model import BaseLMFitModel
from raman_fitting.models.deconvolution.base_model import (
get_models_and_peaks_from_definitions,
)
from raman_fitting.models.splitter import RegionNames
from raman_fitting.exports.exporter import ExportManager
from raman_fitting.imports.files.file_indexer import (
Expand All @@ -32,7 +29,7 @@
prepare_aggregated_spectrum_from_files,
)
from raman_fitting.types import LMFitModelCollection
from .run_fit_spectrum import run_fit_over_selected_models
from raman_fitting.delegating.run_fit_spectrum import run_fit_over_selected_models


from loguru import logger
Expand All @@ -51,7 +48,7 @@ class MainDelegator:
run_mode: RunModes
use_multiprocessing: bool = False
lmfit_models: LMFitModelCollection = field(
default_factory=get_models_and_peaks_from_definitions
default_factory=lambda: settings.default_models
)
fit_model_region_names: Sequence[RegionNames] = field(
default=(RegionNames.first_order, RegionNames.second_order)
Expand All @@ -60,7 +57,6 @@ class MainDelegator:
sample_ids: Sequence[str] = field(default_factory=list)
sample_groups: Sequence[str] = field(default_factory=list)
index: RamanFileIndex = None

selection: Sequence[RamanFileInfo] = field(init=False)
selected_models: Sequence[RamanFileInfo] = field(init=False)

Expand Down Expand Up @@ -191,7 +187,7 @@ def get_results_over_selected_models(
def make_examples():
    """Run an example fitting session in pytest run mode and return the delegator.

    Uses a small fixed selection of model names so the example run stays quick.
    """
    _main_run = MainDelegator(
        run_mode="pytest", fit_model_specific_names=["2peaks", "3peaks", "2nd_4peaks"]
    )
    _main_run.main_run()
    return _main_run
Expand Down
14 changes: 13 additions & 1 deletion src/raman_fitting/imports/files/file_indexer.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@
from raman_fitting.imports.models import RamanFileInfo
from tablib import Dataset

from raman_fitting.imports.spectrum import SPECTRUM_FILETYPE_PARSERS

RamanFileInfoSet: TypeAlias = Sequence[RamanFileInfo]


Expand Down Expand Up @@ -180,7 +182,16 @@ def collect_raman_file_index_info(
raman_files: Sequence[Path] | None = None, **kwargs
) -> RamanFileInfoSet:
"""loops over the files and scrapes the index data from each file"""
index, files = collect_raman_file_infos(raman_files, **kwargs)
raman_files = list(raman_files)
total_files = []
dirs = [i for i in raman_files if i.is_dir()]
files = [i for i in raman_files if i.is_file()]
total_files += files
suffixes = [i.lstrip(".") for i in SPECTRUM_FILETYPE_PARSERS.keys()]
for d1 in dirs:
paths = [path for i in suffixes for path in d1.glob(f"*.{i}")]
total_files += paths
index, files = collect_raman_file_infos(total_files, **kwargs)
logger.info(f"successfully made index {len(index)} from {len(files)} files")
return index

Expand All @@ -191,6 +202,7 @@ def initialize_index_from_source_files(
force_reindex: bool = False,
) -> RamanFileIndex:
raman_files = collect_raman_file_index_info(raman_files=files)
# breakpoint()
raman_index = RamanFileIndex(
index_file=index_file, raman_files=raman_files, force_reindex=force_reindex
)
Expand Down
3 changes: 2 additions & 1 deletion src/raman_fitting/imports/samples/sample_id_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

def parse_string_to_sample_id_and_position(
string: str, seps=("_", " ", "-")
) -> Tuple[str, str]:
) -> Tuple[str, int]:
"""
Parser for the filenames -> finds SampleID and sample position
Expand Down Expand Up @@ -48,6 +48,7 @@ def parse_string_to_sample_id_and_position(
elif len(split) >= 3:
sample_id = "_".join(split[0:-1])
position = int("".join(filter(str.isdigit, split[-1])))
position = position or 0
return (sample_id, position)


Expand Down
2 changes: 0 additions & 2 deletions src/raman_fitting/imports/spectrum/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,5 +15,3 @@
"method": read_file_with_tablib,
},
}

supported_filetypes = [".txt"]
2 changes: 1 addition & 1 deletion src/raman_fitting/utils/string_operations.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,7 +4,7 @@
def join_prefix_suffix(prefix: str, suffix: str) -> str:
    """Join *prefix* and *suffix* with a single underscore.

    Trailing underscores on the prefix and leading underscores on the suffix
    are stripped first. If the stripped suffix already occurs anywhere in the
    original prefix, the stripped prefix is returned unchanged to avoid
    duplicating the suffix.
    """
    prefix_ = prefix.rstrip("_")
    suffix_ = suffix.lstrip("_")
    if suffix_ in prefix:
        return prefix_
    return f"{prefix_}_{suffix_}"

Expand Down

0 comments on commit 8400ef4

Please sign in to comment.