From a589d63ee32bd7fb610ee7515bbbcbbc5577bd8e Mon Sep 17 00:00:00 2001 From: David Wallace Date: Sun, 18 Feb 2024 23:03:53 +0100 Subject: [PATCH] refactor fit spectrum plot --- .../exports/plotting_fit_results.py | 327 +++++++++++------- 1 file changed, 198 insertions(+), 129 deletions(-) diff --git a/src/raman_fitting/exports/plotting_fit_results.py b/src/raman_fitting/exports/plotting_fit_results.py index 321a225..9b36987 100644 --- a/src/raman_fitting/exports/plotting_fit_results.py +++ b/src/raman_fitting/exports/plotting_fit_results.py @@ -1,79 +1,196 @@ -# ruff: noqa from typing import Dict import matplotlib import matplotlib.pyplot as plt from matplotlib import gridspec -from matplotlib.ticker import AutoMinorLocator, FormatStrFormatter, MultipleLocator +from matplotlib.axes import Axes + +from matplotlib.text import Text +from matplotlib.ticker import AutoMinorLocator + +from raman_fitting.imports.samples.models import SampleMetaData +from raman_fitting.models.fit_models import SpectrumFitModel -matplotlib.rcParams.update({"font.size": 14}) from raman_fitting.config.settings import ExportPathSettings from raman_fitting.models.splitter import WindowNames from raman_fitting.delegating.models import AggregatedSampleSpectrumFitResult from loguru import logger -# TODO fix big spectrum plot + + +matplotlib.rcParams.update({"font.size": 14}) +# TODO fix big spectrum plot, reduce complexity if-statements def fit_spectrum_plot( aggregated_spectra: Dict[WindowNames, AggregatedSampleSpectrumFitResult], export_paths: ExportPathSettings | None = None, - plot_Annotation=True, - plot_Residuals=True, + plot_annotation=True, + plot_residuals=True, ): # pragma: no cover - # %% first_order = aggregated_spectra[WindowNames.first_order] second_order = aggregated_spectra[WindowNames.second_order] sources = first_order.aggregated_spectrum.sources sample = sources[0].file_info.sample + # first_model_name = "4peaks" + second_model_name = "2nd_4peaks" + # first_model = [first_model_name] + second_model = second_order.fit_model_results.get(second_model_name) + # breakpoint() + for first_model_name, first_model in first_order.fit_model_results.items(): + # for second_model_name, second_model in second_order.fit_model_results.items(): + prepare_combined_spectrum_fit_result_plot( + first_model, + second_model, + sample, + export_paths, + plot_annotation=plot_annotation, + plot_residuals=plot_residuals, + ) + + +def prepare_combined_spectrum_fit_result_plot( + first_model: SpectrumFitModel, + second_model: SpectrumFitModel, + sample: SampleMetaData, + export_paths: ExportPathSettings, + plot_annotation=True, + plot_residuals=True, +): + plt.figure(figsize=(28, 24)) + gs = gridspec.GridSpec(4, 1, height_ratios=[4, 1, 4, 1]) + ax = plt.subplot(gs[0]) + ax_res = plt.subplot(gs[1]) + ax.set_title(f"{sample.id}") + # breakpoint() + + first_model_name = first_model.model.name + + fit_plot_first(ax, ax_res, first_model, plot_residuals=plot_residuals) + + ax2nd = plt.subplot(gs[2]) + ax2nd_res = plt.subplot(gs[3]) + + if second_model is not None: + fit_plot_second(ax2nd, ax2nd_res, second_model, plot_residuals=plot_residuals) + + _bbox_artists = tuple() + if plot_annotation: + annotate_report_first = prepare_annotate_fit_report_first( + ax, first_model.fit_result + ) + _bbox_artists = (annotate_report_first,) + if second_model is not None: + annotate_report_second = prepare_annotate_fit_report_second( + ax2nd, second_model.fit_result + ) + if annotate_report_second is not None: + _bbox_artists = (annotate_report_first, annotate_report_second) - first_model = first_order.fit_model_results["4peaks"] + # set axes labels and legend + set_axes_labels_and_legend(ax) + + plot_special_si_components(ax, first_model, first_model_name) + + if export_paths is not None: + savepath = export_paths.plots.joinpath(f"Model_{first_model_name}").with_suffix( + ".png" + ) + plt.savefig( + savepath, + dpi=100, + bbox_extra_artists=_bbox_artists, + bbox_inches="tight", + ) + logger.debug(f"Plot saved to {savepath}") + plt.close() + + +def fit_plot_first( + ax, ax_res, first_model: SpectrumFitModel, plot_residuals: bool = True +) -> matplotlib.text.Text | None: first_result = first_model.fit_result first_components = first_model.fit_result.components first_eval_comps = first_model.fit_result.eval_components() first_model_name = first_model.model.name - first_pars = first_model.fit_result.best_values - - if second_order.fit_model_results: - second_model = second_order.fit_model_results["2nd_4peaks"] - second_result = second_model.fit_result - second_components = second_model.fit_result.components - second_eval_comps = second_model.fit_result.eval_components() - second_model_name = second_model.model.name - second_pars = second_model.fit_result.best_values - else: - second_order = None - second_components = [] - second_result = None - plt.figure(figsize=(28, 24)) - gs = gridspec.GridSpec(4, 1, height_ratios=[4, 1, 4, 1]) - ax = plt.subplot(gs[0]) ax.grid(True, "both") - axRes = plt.subplot(gs[1]) - axRes.grid(True, "both") - if second_order: - ax2nd = plt.subplot(gs[2]) - ax2nd.grid(True) - ax2ndRes = plt.subplot(gs[3]) - ax2ndRes.grid(True) + ax_res.grid(True, "both") ax.get_yaxis().set_tick_params(direction="in") ax.get_xaxis().set_tick_params(direction="in") - ax.set_title(f"{sample.id}") ax.xaxis.set_minor_locator(AutoMinorLocator(2)) ax.yaxis.set_minor_locator(AutoMinorLocator(2)) ax.tick_params(which="both", direction="in") ax.set_facecolor("oldlace") - axRes.set_facecolor("oldlace") - if second_order: + ax_res.set_facecolor("oldlace") + ax.plot( + first_model.spectrum.ramanshift, + first_result.best_fit, + label=first_model_name, + lw=3, + c="r", + ) + ax.plot( + first_model.spectrum.ramanshift, + first_result.data, + label="Data", + lw=3, + c="grey", + alpha=0.8, + ) + + if plot_residuals: + ax_res.plot( + first_model.spectrum.ramanshift, + first_result.residual, + label="Residual", + lw=3, + c="k", + alpha=0.8, + ) + + for _component in first_components: # automatic color cycle 'cyan' ... + peak_name = _component.prefix.rstrip("_") + ax.plot( + first_model.spectrum.ramanshift, + first_eval_comps[_component.prefix], + ls="--", + lw=4, + label=peak_name, + ) + center_col = _component.prefix + "center" + ax.annotate( + f"{peak_name}:\n {first_result.best_values[center_col]:.0f}", + xy=( + first_result.best_values[center_col] * 0.97, + 0.7 * first_result.params[_component.prefix + "height"].value, + ), + xycoords="data", + ) + + +def fit_plot_second( + ax2nd, ax2nd_res, second_model: SpectrumFitModel, plot_residuals: bool = True +) -> None: + if second_model: + second_result = second_model.fit_result + second_components = second_model.fit_result.components + second_eval_comps = second_model.fit_result.eval_components() + second_model_name = second_model.model.name + else: + second_components = [] + second_result = None + if second_model: + ax2nd.grid(True) + ax2nd_res.grid(True) ax2nd.xaxis.set_minor_locator(AutoMinorLocator(2)) ax2nd.yaxis.set_minor_locator(AutoMinorLocator(2)) ax2nd.tick_params(which="both", direction="in") ax2nd.set_facecolor("oldlace") - ax2ndRes.set_facecolor("oldlace") + ax2nd_res.set_facecolor("oldlace") ax2nd.plot( second_model.spectrum.ramanshift, second_result.best_fit, @@ -89,8 +206,8 @@ def fit_spectrum_plot( c="grey", alpha=0.5, ) - if plot_Residuals: - ax2ndRes.plot( + if plot_residuals: + ax2nd_res.plot( second_model.spectrum.ramanshift, second_result.residual, label="Residual", @@ -118,52 +235,48 @@ def fit_spectrum_plot( xycoords="data", ) ax2nd.set_ylim(-0.02, second_result.data.max() * 1.5) - ax.plot( - first_model.spectrum.ramanshift, - first_result.best_fit, - label=first_model_name, - lw=3, - c="r", - ) - ax.plot( - first_model.spectrum.ramanshift, - first_result.data, - label="Data", - lw=3, - c="grey", - alpha=0.8, + + set_axes_labels_and_legend(ax2nd) + + +def prepare_annotate_fit_report_second(ax2nd, second_result) -> Text: + props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) + annotate_report_second = ax2nd.text( + 1.01, + 0.7, + second_result.fit_report(), + transform=ax2nd.transAxes, + fontsize=11, + verticalalignment="top", + bbox=props, ) - if plot_Residuals: - axRes.plot( - first_model.spectrum.ramanshift, - first_result.residual, - label="Residual", - lw=3, - c="k", - alpha=0.8, - ) + return annotate_report_second - for _component in first_components: # automatic color cycle 'cyan' ... - peak_name = _component.prefix.rstrip("_") - ax.plot( - first_model.spectrum.ramanshift, - first_eval_comps[_component.prefix], - ls="--", - lw=4, - label=peak_name, - ) - center_col = _component.prefix + "center" - ax.annotate( - f"{peak_name}:\n {first_result.best_values[center_col]:.0f}", - xy=( - first_result.best_values[center_col] * 0.97, - 0.7 * first_result.params[_component.prefix + "height"].value, - ), - xycoords="data", - ) - for si_comp in [i for i in first_components if i.prefix.startswith("Si")]: +def prepare_annotate_fit_report_first(ax, first_result): + fit_report = first_result.fit_report() + if len(fit_report) > -1: # TODO remove + fit_report = fit_report.replace("prefix='D3_'", "prefix='D3_' \n") + props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) + + annotate_report_first = ax.text( + 1.01, + 1, + fit_report, + transform=ax.transAxes, + fontsize=11, + verticalalignment="top", + bbox=props, + ) + return annotate_report_first + + +def plot_special_si_components(ax, first_model, model_name: str): + first_result = first_model.fit_result + si_components = filter(lambda x: x.prefix.startswith("Si"), first_result.components) + first_eval_comps = first_model.fit_result.eval_components() + for si_comp in si_components: si_result = si_comp # TODO should be si_fit_results ax.plot( @@ -182,54 +295,10 @@ def fit_spectrum_plot( ), xycoords="data", ) - if plot_Annotation: - fit_report = first_result.fit_report() - if len(fit_report) > -1: # TODO remove - fit_report = fit_report.replace("prefix='D3_'", "prefix='D3_' \n") - props = dict(boxstyle="round", facecolor="wheat", alpha=0.5) - - Report1 = ax.text( - 1.01, - 1, - fit_report, - transform=ax.transAxes, - fontsize=11, - verticalalignment="top", - bbox=props, - ) - _bbox_artists = (Report1,) - if second_result: - Report2 = ax2nd.text( - 1.01, - 0.7, - second_result.fit_report(), - transform=ax2nd.transAxes, - fontsize=11, - verticalalignment="top", - bbox=props, - ) - _bbox_artists = (Report1, Report2) - ( - ax.legend(loc=1), - ax.set_xlabel("Raman shift (cm$^{-1}$)"), - ax.set_ylabel("normalized I / a.u."), - ) - if second_order: - ( - ax2nd.legend(loc=1), - ax2nd.set_xlabel("Raman shift (cm$^{-1}$)"), - ax2nd.set_ylabel("normalized I / a.u."), - ) - if export_paths is not None: - savepath = export_paths.plots.joinpath(f"Model_{first_model_name}").with_suffix( - ".png" - ) - plt.savefig( - savepath, - dpi=100, - bbox_extra_artists=_bbox_artists, - bbox_inches="tight", - ) - logger.debug(f"Plot saved to {savepath}") - plt.close() + +def set_axes_labels_and_legend(ax: Axes): + # set axes labels and legend + ax.legend(loc=1) + ax.set_xlabel("Raman shift (cm$^{-1}$)") + ax.set_ylabel("normalized I / a.u.")