Skip to content

Commit

Permalink
Fixed bpoint classifier loading/saing bug
Browse files Browse the repository at this point in the history
  • Loading branch information
mattcieslak committed Mar 16, 2017
1 parent 8cc8177 commit ad641bb
Show file tree
Hide file tree
Showing 4 changed files with 30 additions and 6 deletions.
2 changes: 1 addition & 1 deletion meap/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@
SEARCH_WINDOW=30 #samples
BLOOD_RESISTIVITY=135. # Ohms cm
n_regions=0
__version__= "1.1.0"
__version__= "1.1.1"

# Are we bundled?
if getattr( sys, 'frozen', False ) :
Expand Down
12 changes: 9 additions & 3 deletions meap/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -732,6 +732,7 @@ def _get_subject_l(self):
mea_hr = Array
tpr = Array
resp_corrected_tpr = Array
censored_secs_before = Array # Censored seconds between previous beat

def _config_default(self):
return MEAPConfig()
Expand All @@ -752,7 +753,11 @@ def save(self,outfile):
continue
if k in ("censored_regions","event_names"):
continue
v = getattr(self,k)
try:
v = getattr(self,k)
except Exception,e:
logger.info("Unable to access %s for saving", k)
continue
if type(v) == np.ndarray:
if v.size == 0: continue
if type(v) is set: continue
Expand All @@ -769,9 +774,10 @@ def save(self,outfile):
except Exception, e:
logger.warn("unable to save %s because of %s", k,e)
tmp.close()
if not outfile.endswith(".mat"): outfile += ".mat"
savemat(outfile, savedict,long_field_names=True)
if not os.path.exists(outfile+".mat"):
logger.critical("failed to save %s.mat", outfile)
if not os.path.exists(outfile):
logger.critical("failed to save %s", outfile)

def load_from_disk(matfile, config=None,verbose=False):
"""
Expand Down
20 changes: 19 additions & 1 deletion meap/moving_ensemble.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@
Bool, Enum, Instance, on_trait_change, Property,
DelegatesTo, Int, Button, List, Set )
import os
import joblib

# Needed for Tabular adapter
from traitsui.api import Item,HGroup,VGroup, HSplit
Expand Down Expand Up @@ -140,6 +141,13 @@ def __init__(self,**traits):

self._init_bpoint_clf_name()

@on_trait_change("bpoint_classifier_file")
def _file_updated(self):
logger.info("Checking for new bpoint classifier %s", self.bpoint_classifier_file)
if os.path.exists(self.bpoint_classifier_file):
self.bpoint_classifier = self._bpoint_classifier_default()


def _init_bpoint_clf_name(self):
"""Loops over various possible directories to find where to write
the bpoint_classifier file"""
Expand Down Expand Up @@ -302,6 +310,16 @@ def _b_apply_weighting_fired(self):

# Functions involving b-point classification
def _bpoint_classifier_default(self):
if os.path.exists(self.bpoint_classifier_file):
logger.info("attempting to load %s", self.bpoint_classifier_file)
try:
clf = joblib.load(self.bpoint_classifier_file)
self.bpoint_classifier = BPointClassifier(
physiodata=self.physiodata, classifier=clf)
logger.info("success")
except Exception, e:
logger.info("unable to load classifier file")
logger.info(e)
logger.info("Loading new bpoint classifier (init)")
return BPointClassifier(physiodata=self.physiodata)

Expand All @@ -319,7 +337,7 @@ def _b_apply_clf_fired(self):
progress.open()
for i, beat in enumerate(self.mea_beat_train.beats):
if not beat.hand_labeled:
beat.b.set_index(self.bpoint_classifier.estimate_bpoint(beat.id))
beat.b.set_index(int(self.bpoint_classifier.estimate_bpoint(beat.id)))
(cont,skip) = progress.update(i)
(cont,skip) = progress.update(i+1)
self.calculate_physio()
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,7 @@
"scikit-image",
"traits",
"traitsui",
"pyqt<5",
#"pyqt<5",
"kiwisolver",
"nibabel",
"bioread==0.9.5",
Expand Down

0 comments on commit ad641bb

Please sign in to comment.