Skip to content

Commit

Permalink
chore: clean up preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
David Wallace committed Mar 10, 2024
1 parent 0e003e4 commit 6b42cf8
Show file tree
Hide file tree
Showing 3 changed files with 54 additions and 189 deletions.
38 changes: 11 additions & 27 deletions src/raman_fitting/delegating/pre_processing.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,5 @@
from typing import List

from raman_fitting.models.spectrum import SpectrumData
from raman_fitting.models.splitter import RegionNames
from raman_fitting.imports.spectrumdata_parser import SpectrumReader
from raman_fitting.processing.post_processing import SpectrumProcessor
Expand All @@ -10,17 +9,17 @@
PreparedSampleSpectrum,
)

import numpy as np
from loguru import logger

from raman_fitting.config.path_settings import CLEAN_SPEC_REGION_NAME_PREFIX
from ..imports.spectrum.spectra_collection import SpectraDataCollection


def prepare_aggregated_spectrum_from_files(
region_name: RegionNames, raman_files: List[RamanFileInfo]
) -> AggregatedSampleSpectrum | None:
selected_processed_data = f"{CLEAN_SPEC_REGION_NAME_PREFIX}{region_name}"
clean_data_for_region = {}
select_region_key = f"{CLEAN_SPEC_REGION_NAME_PREFIX}{region_name}"
clean_data_for_region = []
data_sources = []
for i in raman_files:
read = SpectrumReader(i.file)
Expand All @@ -29,32 +28,17 @@ def prepare_aggregated_spectrum_from_files(
file_info=i, read=read, processed=processed
)
data_sources.append(prepared_spec)
selected_clean_data = processed.clean_spectrum.spec_regions[
selected_processed_data
]
clean_data_for_region[i.file] = selected_clean_data

selected_clean_data = processed.clean_spectrum.spec_regions[select_region_key]
clean_data_for_region.append(selected_clean_data)
if not clean_data_for_region:
logger.info("prepare_mean_data_for_fitting received no files.")
logger.warning(
f"prepare_mean_data_for_fitting received no files. {region_name}"
)
return
# wrap this in a ProcessedSpectraCollection model
mean_int = np.mean(
np.vstack([i.intensity for i in clean_data_for_region.values()]), axis=0
)
mean_ramanshift = np.mean(
np.vstack([i.ramanshift for i in clean_data_for_region.values()]), axis=0
)
source_files = list(map(str, clean_data_for_region.keys()))
mean_spec = SpectrumData(
**{
"ramanshift": mean_ramanshift,
"intensity": mean_int,
"label": f"clean_{region_name}_mean",
"region_name": region_name,
"source": source_files,
}
spectra_collection = SpectraDataCollection(
spectra=clean_data_for_region, region_name=region_name
)
aggregated_spectrum = AggregatedSampleSpectrum(
sources=data_sources, spectrum=mean_spec
sources=data_sources, spectrum=spectra_collection.mean_spectrum
)
return aggregated_spectrum
120 changes: 43 additions & 77 deletions src/raman_fitting/imports/spectrum/spectra_collection.py
Original file line number Diff line number Diff line change
@@ -1,97 +1,63 @@
import logging
from operator import itemgetter
from typing import Dict, List
from typing import List

import numpy as np

from pydantic import BaseModel, ValidationError, model_validator, ConfigDict
from pydantic import BaseModel, ValidationError, model_validator

from .spectrum_constructor import SpectrumDataLoader
from raman_fitting.models.deconvolution.spectrum_regions import RegionNames
from raman_fitting.models.spectrum import SpectrumData

logger = logging.getLogger(__name__)
SPECTRUM_KEYS = ("ramanshift", "intensity")


class PostProcessedSpectrum(BaseModel):
pass


class SpectraDataCollection(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)
spectra: List[SpectrumData]
region_name: RegionNames
mean_spectrum: SpectrumData | None = None

spectra: List[SpectrumDataLoader]
@model_validator(mode="after")
def check_spectra_have_same_label(self) -> "SpectraDataCollection":
"""checks member of lists"""
labels = set(i.label for i in self.spectra)
if len(labels) > 1:
raise ValidationError(f"Spectra have different labels {labels}")
return self

@model_validator(mode="after")
def check_spectra_have_clean_spectrum(self) -> "SpectraDataCollection":
def check_spectra_have_same_region(self) -> "SpectraDataCollection":
"""checks member of lists"""
if not all(hasattr(spec, "clean_spectrum") for spec in self.spectra):
raise ValidationError("missing clean_data attribute")
region_names = set(i.region_name for i in self.spectra)
if len(region_names) > 1:
raise ValidationError(f"Spectra have different region_names {region_names}")
return self

@model_validator(mode="after")
def check_spectra_lengths(self) -> "SpectraDataCollection":
unique_lengths = set([i.spectrum_length for i in self.spectra])
if len(unique_lengths) > 1:
unique_lengths_rs = set(len(i.ramanshift) for i in self.spectra)
unique_lengths_int = set(len(i.intensity) for i in self.spectra)
if len(unique_lengths_rs) > 1:
raise ValidationError(
f"The spectra have different lenghts where they should be the same.\n\t{unique_lengths}"
f"The spectra have different ramanshift lengths where they should be the same.\n\t{unique_lengths_rs}"
)
if len(unique_lengths_int) > 1:
raise ValidationError(
f"The spectra have different intensity lengths where they should be the same.\n\t{unique_lengths_int}"
)
return self


def get_mean_spectra_info(spectra: List[SpectrumDataLoader]) -> Dict:
"""retrieves the info dict from spec instances and merges dict in keys that have 1 common value"""

all_spec_info = [spec.info for spec in spectra]
_all_spec_info_merged = {k: val for i in all_spec_info for k, val in i.items()}
_all_spec_info_sets = [
(k, set([i.get(k, None) for i in all_spec_info])) for k in _all_spec_info_merged
]
mean_spec_info = {
k: list(val)[0] for k, val in _all_spec_info_sets if len(val) == 1
}
mean_spec_info.update({"mean_spectrum": True})
return mean_spec_info


def calculate_mean_spectrum_from_spectra(
spectra: List[SpectrumDataLoader],
) -> Dict[str, SpectrumData]:
"""retrieves the clean data from spec instances and makes lists of tuples"""

try:
spectra_regions = [i.clean_spectrum.spec_regions for i in spectra]
mean_spec_regions = {}
for region_name in spectra_regions[0].keys():
regions_specs = [i[region_name] for i in spectra_regions]
ramanshift_mean = np.mean([i.ramanshift for i in regions_specs], axis=0)
intensity_mean = np.mean([i.intensity for i in regions_specs], axis=0)

_data = {
"ramanshift": ramanshift_mean,
"intensity": intensity_mean,
"label": regions_specs[0].label + "_mean",
"region_name": region_name + "_mean",
}
mean_spec = SpectrumData(**_data)
mean_spec_regions[region_name] = mean_spec

except Exception:
logger.warning(f"get_mean_spectra_prep_data failed for spectra {spectra}")
mean_spec_regions = {}

return mean_spec_regions


def get_best_guess_spectra_length(spectra: List[SpectrumDataLoader]) -> List:
lengths = [i.spectrum_length for i in spectra]
set_lengths = set(lengths)
if len(set_lengths) == 1:
# print(f'Spectra all same length {set_lengths}')
return spectra
return self

length_counts = [(i, lengths.count(i)) for i in set_lengths]
best_guess_length = max(length_counts, key=itemgetter(1))[0]
print(f"Spectra not same length {length_counts} took {best_guess_length}")
spectra = [spec for spec in spectra if spec.spectrum_length == best_guess_length]
return spectra
@model_validator(mode="after")
def set_mean_spectrum(self) -> "SpectraDataCollection":
# wrap this in a ProcessedSpectraCollection model
mean_int = np.mean(np.vstack([i.intensity for i in self.spectra]), axis=0)
mean_ramanshift = np.mean(
np.vstack([i.ramanshift for i in self.spectra]), axis=0
)
source_files = list(set(i.source for i in self.spectra))
_label = "".join(map(str, set(i.label for i in self.spectra)))
mean_spec = SpectrumData(
ramanshift=mean_ramanshift,
intensity=mean_int,
label=f"clean_{self.region_name}_mean",
region_name=self.region_name,
source=source_files,
)
self.mean_spectrum = mean_spec
85 changes: 0 additions & 85 deletions src/raman_fitting/imports/spectrum/spectrum_constructor.py

This file was deleted.

0 comments on commit 6b42cf8

Please sign in to comment.