From ad641bbc9fc5091f8f1adf824433ce9b744a45bb Mon Sep 17 00:00:00 2001 From: Matt Cieslak Date: Thu, 16 Mar 2017 08:06:51 -0700 Subject: [PATCH] Fixed bpoint classifier loading/saing bug --- meap/__init__.py | 2 +- meap/io.py | 12 +++++++++--- meap/moving_ensemble.py | 20 +++++++++++++++++++- setup.py | 2 +- 4 files changed, 30 insertions(+), 6 deletions(-) diff --git a/meap/__init__.py b/meap/__init__.py index 3a02397..3536fee 100644 --- a/meap/__init__.py +++ b/meap/__init__.py @@ -22,7 +22,7 @@ SEARCH_WINDOW=30 #samples BLOOD_RESISTIVITY=135. # Ohms cm n_regions=0 -__version__= "1.1.0" +__version__= "1.1.1" # Are we bundled? if getattr( sys, 'frozen', False ) : diff --git a/meap/io.py b/meap/io.py index f922fbe..58bd0c7 100644 --- a/meap/io.py +++ b/meap/io.py @@ -732,6 +732,7 @@ def _get_subject_l(self): mea_hr = Array tpr = Array resp_corrected_tpr = Array + censored_secs_before = Array # Censored seconds between previous beat def _config_default(self): return MEAPConfig() @@ -752,7 +753,11 @@ def save(self,outfile): continue if k in ("censored_regions","event_names"): continue - v = getattr(self,k) + try: + v = getattr(self,k) + except Exception,e: + logger.info("Unable to access %s for saving", k) + continue if type(v) == np.ndarray: if v.size == 0: continue if type(v) is set: continue @@ -769,9 +774,10 @@ def save(self,outfile): except Exception, e: logger.warn("unable to save %s because of %s", k,e) tmp.close() + if not outfile.endswith(".mat"): outfile += ".mat" savemat(outfile, savedict,long_field_names=True) - if not os.path.exists(outfile+".mat"): - logger.critical("failed to save %s.mat", outfile) + if not os.path.exists(outfile): + logger.critical("failed to save %s", outfile) def load_from_disk(matfile, config=None,verbose=False): """ diff --git a/meap/moving_ensemble.py b/meap/moving_ensemble.py index 90a8f35..979e2f0 100644 --- a/meap/moving_ensemble.py +++ b/meap/moving_ensemble.py @@ -3,6 +3,7 @@ Bool, Enum, Instance, on_trait_change, Property, DelegatesTo, Int, Button, List, Set ) import os +import joblib # Needed for Tabular adapter from traitsui.api import Item,HGroup,VGroup, HSplit @@ -140,6 +141,13 @@ def __init__(self,**traits): self._init_bpoint_clf_name() + @on_trait_change("bpoint_classifier_file") + def _file_updated(self): + logger.info("Checking for new bpoint classifier %s", self.bpoint_classifier_file) + if os.path.exists(self.bpoint_classifier_file): + self.bpoint_classifier = self._bpoint_classifier_default() + + def _init_bpoint_clf_name(self): """Loops over various possible directories to find where to write the bpoint_classifier file""" @@ -302,6 +310,16 @@ def _b_apply_weighting_fired(self): # Functions involving b-point classification def _bpoint_classifier_default(self): + if os.path.exists(self.bpoint_classifier_file): + logger.info("attempting to load %s", self.bpoint_classifier_file) + try: + clf = joblib.load(self.bpoint_classifier_file) + self.bpoint_classifier = BPointClassifier( + physiodata=self.physiodata, classifier=clf) + logger.info("success") + except Exception, e: + logger.info("unable to load classifier file") + logger.info(e) logger.info("Loading new bpoint classifier (init)") return BPointClassifier(physiodata=self.physiodata) @@ -319,7 +337,7 @@ def _b_apply_clf_fired(self): progress.open() for i, beat in enumerate(self.mea_beat_train.beats): if not beat.hand_labeled: - beat.b.set_index(self.bpoint_classifier.estimate_bpoint(beat.id)) + beat.b.set_index(int(self.bpoint_classifier.estimate_bpoint(beat.id))) (cont,skip) = progress.update(i) (cont,skip) = progress.update(i+1) self.calculate_physio() diff --git a/setup.py b/setup.py index d8817b6..31dd713 100644 --- a/setup.py +++ b/setup.py @@ -15,7 +15,7 @@ "scikit-image", "traits", "traitsui", - "pyqt<5", + #"pyqt<5", "kiwisolver", "nibabel", "bioread==0.9.5",