Skip to content

Commit

Permalink
chore: refactor fit run
Browse files Browse the repository at this point in the history
  • Loading branch information
MyPyDavid committed Feb 11, 2024
1 parent 4a498b3 commit 9b6f35e
Show file tree
Hide file tree
Showing 2 changed files with 39 additions and 173 deletions.
16 changes: 16 additions & 0 deletions src/raman_fitting/models/deconvolution/prepare_parameters.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,7 @@
from raman_fitting.models.splitter import WindowNames


# TODO add params to run fit post processing
def add_D_G_ratios(a: str, t: str, peaks, result):
RatioParams = {}
if {"G_", "D_"}.issubset(peaks):
Expand Down Expand Up @@ -60,3 +64,15 @@ def add_D1D1_GD1_ratio(
RatioParams.update({f"{a}D1D1/{a}GD1": result["D1D1" + t] / result["GD1" + t]})
if extra_fit_results:
RatioParams.update(add_ratio_combined_params(window_name, a, t))


def add_ratio_combined_params_second_order(
    result, model_result, extra_fit_results, a, t
):
    """Combine a first-order fit with second-order results into ratio params.

    When *model_result* is a first-order model and a second-order fit is
    present in *extra_fit_results*, copy the second-order D1D1 value into
    *result* (NOTE: deliberate side effect on the caller's dict) and return
    the equivalent crystallite-size parameter ``Leq_<a>``.

    Returns an empty dict when the combination does not apply.
    """
    second = WindowNames.second_order
    # Guard clauses: only first-order models paired with an available
    # second-order result produce a combined parameter.
    if not model_result._modelname.startswith("first"):
        return {}
    if second not in extra_fit_results:  # redundant .keys() removed
        return {}
    d1d1 = extra_fit_results[second].FitParameters.loc[f"Model_{second}", "D1D1" + t]
    result.update({"D1D1" + t: d1d1})
    # 8.8 is the empirical proportionality constant relating the D1D1/D
    # ratio to Leq — presumably from Raman literature; TODO confirm source.
    return {f"Leq_{a}": 8.8 * d1d1 / result["D" + t]}
196 changes: 23 additions & 173 deletions src/raman_fitting/models/fit_models.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,18 @@
logger = logging.getLogger(__name__)


# Statistics attributes copied verbatim from an lmfit fit result when
# exporting fit quality metrics (used with a "lmfit_" key prefix in
# SpectrumFitModel.process_fit_results).
FIT_RESULT_ATTR_LIST = (
    "chisqr",
    "redchi",
    "bic",
    "aic",
    "method",
    "message",
    "success",
    "nfev",
)


class SpectrumFitModel(BaseModel):
model_config = ConfigDict(arbitrary_types_allowed=True)

Expand Down Expand Up @@ -45,6 +57,17 @@ def run_fit(self) -> ModelResult:
self.elapsed_time = elapsed_seconds
self.fit_result = fit_result

def process_fit_results(self):
# TODO add parameter post processing steps
self.fit_result

fit_attrs = {
f"lmfit_{i}": getattr(self.fit_result, i) for i in FIT_RESULT_ATTR_LIST
}
self.add_ratio_params()

self.result.update(fit_attrs)


def run_fit(
model: LMFitModel, spectrum: SpectrumData, method="leastsq", **kws
Expand All @@ -56,147 +79,7 @@ def run_fit(
return out


class Fitter:
    """
    Fitter class for executing the fitting functions and optimizations

    IDEA: implement sensing of spectrum for Si samples
    """

    # Spectral windows this fitter knows how to handle, in the order the
    # keys appear in incoming data dicts.
    fit_windows = ["first_order", "second_order"]

    def __init__(self, spectra_arg, models=None, start_fit=True):
        """Store the spectra/models and (optionally) start fitting immediately.

        spectra_arg: dict of window-name -> DataFrame, or an object exposing
            `mean_data` / `clean_df` (see the `spectra` setter).
        models: object with `first_order` / `second_order` model mappings.
        start_fit: when True, fitting is delegated right away from __init__.
        """
        self._qcnm = self.__class__.__qualname__
        logger.debug(f"{self._qcnm} is called with spectrum\n\t{spectra_arg}\n")
        self.start_fit = start_fit
        self.models = models

        self.spectra_arg = spectra_arg
        self.spectra = spectra_arg
        self.fit_delegator()

    @property
    def spectra(self):
        # Mapping window-name -> spectrum DataFrame, populated by the setter.
        return self._spectra

    @spectra.setter
    def spectra(self, value):
        """Checks if value is dict or else takes a dict from class instance value"""

        _errtxt = f"This assignment {value} does not contain valid spectra"
        if isinstance(value, dict):
            _data = value
        # BUG FIX: the original called isinstance() with a *string* as the
        # second argument (isinstance(value, "SpectrumDataCollection")),
        # which raises TypeError; compare on the type name instead since the
        # classes are not imported here.
        elif type(value).__name__ == "SpectrumDataCollection":
            _data = value.mean_data
        elif type(value).__name__ == "SpectrumDataLoader":
            _data = value.clean_df
        else:
            raise ValueError(_errtxt)

        _specs = {
            k: val
            for k, val in _data.items()
            if k in self.fit_windows and isinstance(val, pd.DataFrame)
        }
        # No usable windows -> nothing to fit; disable instead of asserting.
        if not _specs:
            self.start_fit = False

        self._spectra = _specs
        self.FitResults = {}
        info = {}
        if hasattr(value, "info"):
            info = {**info, **value.info}
        self.info = info

    def fit_delegator(self):
        """Run the second-order models first, then the first-order models."""
        if self.start_fit:
            logger.info(
                f"\n{self._qcnm} is starting to fit the models on spectrum:\n\t{self.info.get('SampleID','no name')}"
            )

            self.fit_models(self.models.second_order)  # second order should go first
            logger.info(
                f"\t - second order finished, {len(self.models.second_order)} model"
            )
            self.fit_models(self.models.first_order)
            logger.info(
                f"\t - first order finished, {len(self.models.first_order)} models"
            )

    def fit_models(self, model_selection):
        """Fit each model in *model_selection* and collect PrepareParams results.

        Failures are logged per model and do not abort the remaining fits.
        """
        _fittings = {}
        logger.debug(f"{self._qcnm} fit_models starting.")
        for modname, model in model_selection.items():
            # Match the model name prefix (e.g. "1st"/"2nd") to a window name.
            _windowname = [i for i in self.fit_windows if modname[0:3] in i][0]
            _data = self.spectra.get(_windowname)
            _int_lbl = self.get_int_label(_data)
            try:
                out = self.run_fit(
                    model.lmfit_model,
                    _data,
                    _int_lbl=_int_lbl,
                    _modelname=modname,
                    _info=self.info,
                )
                prep = PrepareParams(out, extra_fit_results=self.FitResults)
                _fittings.update({modname: prep.FitResult})
            except Exception as e:
                logger.warning(
                    f"{self._qcnm} fit_model failed for {modname}: {model}, because:\n {e}"
                )

        self.FitResults.update(**_fittings)

    def run_fit(self, model, _data, method="leastsq", **kws):
        """Run one lmfit fit and attach the extra kwargs as attributes on the result.

        ideas: improve fitting loop so that starting parameters from modelX
        and modelX+Si are shared, faster...
        """
        init_params = model.make_params()
        x, y = _data.ramanshift, _data[kws.get("_int_lbl")]
        out = model.fit(y, init_params, x=x, method=method)  # 'leastsq'
        # Attach metadata without clobbering existing ModelResult attributes:
        # fall back to an underscore-prefixed name, or skip entirely.
        for k, val in kws.items():
            if not hasattr(out, k):
                _attrkey = k
            elif not hasattr(out, f"_{k}"):
                _attrkey = f"_{k}"
            else:
                _attrkey = None
            if _attrkey:
                setattr(out, _attrkey, val)
        return out

    def get_int_label(self, value: pd.DataFrame):
        """Return the intensity column label of *value*, or '' if none is found.

        Preference order when several candidate columns exist:
        a column containing 'mean', then one containing 'int'.
        """
        _lbl = ""
        if not isinstance(value, pd.DataFrame):
            return _lbl
        cols = [i for i in value.columns if "ramanshift" not in i]
        if not cols:
            return _lbl

        if len(cols) == 1:
            _lbl = cols[0]
        elif len(cols) > 1:
            if any("mean" in i for i in cols):
                _lbl = [i for i in cols if "mean" in i][0]
            elif any("int" in i for i in cols):
                _lbl = [i for i in cols if "int" in i][0]
        return _lbl


class PrepareParams:
fit_attr_export_lst = (
"chisqr",
"redchi",
"bic",
"aic",
"method",
"message",
"success",
"nfev",
)
fit_result_template = namedtuple(
"FitResult",
[
Expand All @@ -212,44 +95,18 @@ class PrepareParams:
_standard_2nd_order = "2nd_4peaks"

def __init__(self, model_result, extra_fit_results={}):
self._qcnm = self.__class__.__qualname__
logger.debug(f"{self._qcnm} is called with model_result\n\t{model_result}\n")
self.extra_fit_results = extra_fit_results
self.model_result = model_result

    @property
    def model_result(self):
        # The lmfit ModelResult (or list/tuple of them) set via the setter below.
        return self._model_result

@model_result.setter
def model_result(self, value):
"""
Takes the ModelResult class instance from lmfit.
Optional extra functionality with a list of instances.
"""
self.result = {}

if "ModelResult" in type(value).__name__:
self.result.update(value.params.valuesdict())
self.comps = value.model.components
elif ("list" or "tuple") in type(value).__name__:
assert all("ModelResult" in type(i).__name__ for i in value)
[self.result.update(mod.params.valuesdict()) for mod in value]
self.comps = [i for mod in value for i in mod.model.components]

self.peaks = set(
[i.prefix for i in self.comps]
) # peaks is prefix from components

_mod_lbl = "Model"
if hasattr(value, "_modelname"):
_mod_lbl = f'Model_{getattr(value,"_modelname")}'
self.model_name_lbl = _mod_lbl

self.raw_data_lbl = value._int_lbl

self._model_result = value

self.make_result()

def make_result(self):
Expand Down Expand Up @@ -280,13 +137,6 @@ def prep_extra_info(self):
}

def prep_params(self):
fit_attrs = OrderedDict(
zip(
[f"lmfit_{i}" for i in self.fit_attr_export_lst],
[getattr(self.model_result, i) for i in self.fit_attr_export_lst],
)
)
self.result.update(fit_attrs)
try:
self.add_ratio_params()
except Exception as e:
Expand Down

0 comments on commit 9b6f35e

Please sign in to comment.