-
Notifications
You must be signed in to change notification settings - Fork 6
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
- Loading branch information
David Wallace
committed
Mar 10, 2024
1 parent
0e003e4
commit 6b42cf8
Showing
3 changed files
with
54 additions
and
189 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
120 changes: 43 additions & 77 deletions
120
src/raman_fitting/imports/spectrum/spectra_collection.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,97 +1,63 @@ | ||
import logging | ||
from operator import itemgetter | ||
from typing import Dict, List | ||
from typing import List | ||
|
||
import numpy as np | ||
|
||
from pydantic import BaseModel, ValidationError, model_validator, ConfigDict | ||
from pydantic import BaseModel, ValidationError, model_validator | ||
|
||
from .spectrum_constructor import SpectrumDataLoader | ||
from raman_fitting.models.deconvolution.spectrum_regions import RegionNames | ||
from raman_fitting.models.spectrum import SpectrumData | ||
|
||
logger = logging.getLogger(__name__) | ||
SPECTRUM_KEYS = ("ramanshift", "intensity") | ||
|
||
|
||
class PostProcessedSpectrum(BaseModel): | ||
pass | ||
|
||
|
||
class SpectraDataCollection(BaseModel): | ||
model_config = ConfigDict(arbitrary_types_allowed=True) | ||
spectra: List[SpectrumData] | ||
region_name: RegionNames | ||
mean_spectrum: SpectrumData | None = None | ||
|
||
spectra: List[SpectrumDataLoader] | ||
@model_validator(mode="after") | ||
def check_spectra_have_same_label(self) -> "SpectraDataCollection": | ||
"""checks member of lists""" | ||
labels = set(i.label for i in self.spectra) | ||
if len(labels) > 1: | ||
raise ValidationError(f"Spectra have different labels {labels}") | ||
return self | ||
|
||
@model_validator(mode="after") | ||
def check_spectra_have_clean_spectrum(self) -> "SpectraDataCollection": | ||
def check_spectra_have_same_region(self) -> "SpectraDataCollection": | ||
"""checks member of lists""" | ||
if not all(hasattr(spec, "clean_spectrum") for spec in self.spectra): | ||
raise ValidationError("missing clean_data attribute") | ||
region_names = set(i.region_name for i in self.spectra) | ||
if len(region_names) > 1: | ||
raise ValidationError(f"Spectra have different region_names {region_names}") | ||
return self | ||
|
||
@model_validator(mode="after") | ||
def check_spectra_lengths(self) -> "SpectraDataCollection": | ||
unique_lengths = set([i.spectrum_length for i in self.spectra]) | ||
if len(unique_lengths) > 1: | ||
unique_lengths_rs = set(len(i.ramanshift) for i in self.spectra) | ||
unique_lengths_int = set(len(i.intensity) for i in self.spectra) | ||
if len(unique_lengths_rs) > 1: | ||
raise ValidationError( | ||
f"The spectra have different lenghts where they should be the same.\n\t{unique_lengths}" | ||
f"The spectra have different ramanshift lengths where they should be the same.\n\t{unique_lengths_rs}" | ||
) | ||
if len(unique_lengths_int) > 1: | ||
raise ValidationError( | ||
f"The spectra have different intensity lengths where they should be the same.\n\t{unique_lengths_int}" | ||
) | ||
return self | ||
|
||
|
||
def get_mean_spectra_info(spectra: List[SpectrumDataLoader]) -> Dict: | ||
"""retrieves the info dict from spec instances and merges dict in keys that have 1 common value""" | ||
|
||
all_spec_info = [spec.info for spec in spectra] | ||
_all_spec_info_merged = {k: val for i in all_spec_info for k, val in i.items()} | ||
_all_spec_info_sets = [ | ||
(k, set([i.get(k, None) for i in all_spec_info])) for k in _all_spec_info_merged | ||
] | ||
mean_spec_info = { | ||
k: list(val)[0] for k, val in _all_spec_info_sets if len(val) == 1 | ||
} | ||
mean_spec_info.update({"mean_spectrum": True}) | ||
return mean_spec_info | ||
|
||
|
||
def calculate_mean_spectrum_from_spectra( | ||
spectra: List[SpectrumDataLoader], | ||
) -> Dict[str, SpectrumData]: | ||
"""retrieves the clean data from spec instances and makes lists of tuples""" | ||
|
||
try: | ||
spectra_regions = [i.clean_spectrum.spec_regions for i in spectra] | ||
mean_spec_regions = {} | ||
for region_name in spectra_regions[0].keys(): | ||
regions_specs = [i[region_name] for i in spectra_regions] | ||
ramanshift_mean = np.mean([i.ramanshift for i in regions_specs], axis=0) | ||
intensity_mean = np.mean([i.intensity for i in regions_specs], axis=0) | ||
|
||
_data = { | ||
"ramanshift": ramanshift_mean, | ||
"intensity": intensity_mean, | ||
"label": regions_specs[0].label + "_mean", | ||
"region_name": region_name + "_mean", | ||
} | ||
mean_spec = SpectrumData(**_data) | ||
mean_spec_regions[region_name] = mean_spec | ||
|
||
except Exception: | ||
logger.warning(f"get_mean_spectra_prep_data failed for spectra {spectra}") | ||
mean_spec_regions = {} | ||
|
||
return mean_spec_regions | ||
|
||
|
||
def get_best_guess_spectra_length(spectra: List[SpectrumDataLoader]) -> List: | ||
lengths = [i.spectrum_length for i in spectra] | ||
set_lengths = set(lengths) | ||
if len(set_lengths) == 1: | ||
# print(f'Spectra all same length {set_lengths}') | ||
return spectra | ||
return self | ||
|
||
length_counts = [(i, lengths.count(i)) for i in set_lengths] | ||
best_guess_length = max(length_counts, key=itemgetter(1))[0] | ||
print(f"Spectra not same length {length_counts} took {best_guess_length}") | ||
spectra = [spec for spec in spectra if spec.spectrum_length == best_guess_length] | ||
return spectra | ||
@model_validator(mode="after") | ||
def set_mean_spectrum(self) -> "SpectraDataCollection": | ||
# wrap this in a ProcessedSpectraCollection model | ||
mean_int = np.mean(np.vstack([i.intensity for i in self.spectra]), axis=0) | ||
mean_ramanshift = np.mean( | ||
np.vstack([i.ramanshift for i in self.spectra]), axis=0 | ||
) | ||
source_files = list(set(i.source for i in self.spectra)) | ||
_label = "".join(map(str, set(i.label for i in self.spectra))) | ||
mean_spec = SpectrumData( | ||
ramanshift=mean_ramanshift, | ||
intensity=mean_int, | ||
label=f"clean_{self.region_name}_mean", | ||
region_name=self.region_name, | ||
source=source_files, | ||
) | ||
self.mean_spectrum = mean_spec |
85 changes: 0 additions & 85 deletions
85
src/raman_fitting/imports/spectrum/spectrum_constructor.py
This file was deleted.
Oops, something went wrong.