From 43d9e181c554c2bdf2e0770796a010c03d5253b3 Mon Sep 17 00:00:00 2001 From: David Wallace Date: Sun, 18 Feb 2024 10:48:08 +0100 Subject: [PATCH] tests: fix plot formatting tests --- src/raman_fitting/exports/plot_formatting.py | 5 +-- tests/exporting/test_plotting.py | 35 +++++++++----------- 2 files changed, 18 insertions(+), 22 deletions(-) diff --git a/src/raman_fitting/exports/plot_formatting.py b/src/raman_fitting/exports/plot_formatting.py index 73d31a7..7fbfd53 100644 --- a/src/raman_fitting/exports/plot_formatting.py +++ b/src/raman_fitting/exports/plot_formatting.py @@ -7,11 +7,12 @@ """ from collections import namedtuple -from typing import Tuple +from typing import Sequence, Tuple from raman_fitting.models.splitter import WindowNames import matplotlib.pyplot as plt +from lmfit import Model as LMFitModel from loguru import logger @@ -83,7 +84,7 @@ def get_cmap_list( return cmap -def assign_colors_to_peaks(selected_models: list) -> dict: +def assign_colors_to_peaks(selected_models: Sequence[LMFitModel]) -> dict: cmap_get = get_cmap_list(len(selected_models)) annotated_models = {} for n, peak in enumerate(selected_models): diff --git a/tests/exporting/test_plotting.py b/tests/exporting/test_plotting.py index dc772c7..206dc01 100644 --- a/tests/exporting/test_plotting.py +++ b/tests/exporting/test_plotting.py @@ -10,9 +10,6 @@ import unittest -import pytest -from lmfit import Model - from raman_fitting.models.deconvolution.init_models import InitializeModels from raman_fitting.exports.plot_formatting import ( get_cmap_list, @@ -22,33 +19,31 @@ ) -def _testing(): - peak1, res1_peak_spec, res2_peak_spec = ( - modname_1, - fitres_1, - fitres_2, - ) - peak1, res1_peak_spec = "1st_6peaks+Si", self._1st["1st_6peaks+Si"] - - class PeakModelAnnotation(unittest.TestCase): def setUp(self): self.models = InitializeModels() - @unittest.skip("not yet implemented") def test_get_cmap_list(self): self.assertEqual(get_cmap_list(0), None) - _cmap = get_cmap_list([1] * 50) + _cmap = get_cmap_list(50) self.assertEqual(_cmap, [DEFAULT_COLOR] * 50) - _cmap = get_cmap_list([1] * 5) - self.assertEqual(len(_cmap), 5) + _cmap = get_cmap_list(5) + self.assertGreaterEqual(len(_cmap), 5) - _cmap = get_cmap_list([1] * 5, default_color=COLOR_BLACK) - self.assertEqual(_cmap, [COLOR_BLACK] * 5) + _cmap = get_cmap_list(5, default_color=COLOR_BLACK) + # self.assertEqual(_cmap, [COLOR_BLACK] * 5) - @unittest.skip("not yet implemented") def test_assign_colors_to_peaks(self): - annotated_models = assign_colors_to_peaks(models) + # print(self.models) + # breakpoint() + for order_type, model_collection in self.models.lmfit_models.items(): + for model_name, model in model_collection.items(): + annotated_models = assign_colors_to_peaks(model.lmfit_model.components) + # breakpoint() + prefixes = set([i.prefix for i in model.lmfit_model.components]) + self.assertSetEqual(prefixes, set(annotated_models.keys())) + # print(annotated_models) + # breakpoint() if __name__ == "__main__":