Skip to content

Commit

Permalink
fixed and refactor code
Browse files Browse the repository at this point in the history
  • Loading branch information
fullbat committed Jun 11, 2024
1 parent d4613ec commit a3e99a9
Showing 1 changed file with 10 additions and 19 deletions.
29 changes: 10 additions & 19 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,6 @@ cdef class Evaluation :
cdef public contribution_mask
cdef public contribution_fibs
cdef public contribution_voxels
cdef public debias
cdef public verbose

def __init__( self, study_path='.', subject='.' ) :
Expand All @@ -118,7 +117,6 @@ cdef class Evaluation :
self.contribution_mask = None # set by "fit" method
self.contribution_voxels = None # set by "fit" method
self.contribution_fibs = None
self.debias = False
self.verbose = 3

# store all the parameters of an evaluation with COMMIT
Expand Down Expand Up @@ -763,8 +761,6 @@ cdef class Evaluation :
logger.error( 'Data not loaded; call "load_data()" first' )

y = self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float64)
# y[y < 0] = 0
print(f"place of first non-zero voxel in input data: {np.where(y==1.07405806)}")
return y


Expand Down Expand Up @@ -1223,7 +1219,7 @@ cdef class Evaluation :
logger.info( f'[ {format_time(time.time() - tr)} ]' )


def fit( self, tol_fun=1e-3, tol_x=1e-6, max_iter=100, x0=None, confidence_map_filename=None, confidence_map_rescale=False, debias=False ) :
def fit( self, tol_fun=1e-3, tol_x=1e-6, max_iter=100, x0=None, confidence_map_filename=None, confidence_map_rescale=False, debias=True ) :
"""Fit the model to the data.
Parameters
Expand Down Expand Up @@ -1340,19 +1336,15 @@ cdef class Evaluation :
self.CONFIG['optimization']['max_iter'] = max_iter
self.CONFIG['optimization']['regularisation'] = self.regularisation_params

if debias:
self.debias = True
self.CONFIG['optimization']['x0'] = x0
self.confidence_array = confidence_array


# run solver
t = time.time()
with ProgressBar(disable=self.verbose!=3, hide_on_exit=True) as pb:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

if self.debias:
logger.info( 'Recomputing coefficients' )
if (self.regularisation_params['regIC']!=None or self.regularisation_params['regEC']!= None or self.regularisation_params['regISO']!= None) and debias:
temp_verb = self.verbose
logger.info( 'Running debias' )
self.set_verbose(0)
xic, _, _ = self.get_coeffs()
weights_in = pjoin( self.get_config('TRACKING_path'), 'streamline_weights.txt' )
np.savetxt(weights_in, xic)
Expand All @@ -1363,9 +1355,6 @@ cdef class Evaluation :

filter(dictionary_info['filename_tractogram'], tractogram_filtered, minweight=0.000000000000001, weights_in=weights_in, force=True, verbose=0)

# # RE-RUN COMMIT WITH THE FILTERED TRACTOGRAM
# path_COMMIT = os.path.join(local_path, "COMMIT_master_debias")

trk2dictionary.run(
filename_tractogram = tractogram_filtered,
TCK_ref_image = dictionary_info['TCK_ref_image'],
Expand Down Expand Up @@ -1396,9 +1385,11 @@ cdef class Evaluation :

self.A = operator.LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, nolut=True if hasattr(self.model, 'nolut') else False )
self.set_regularisation()
self.set_verbose(temp_verb)

with ProgressBar(disable=self.verbose!=3, hide_on_exit=True) as pb:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=self.CONFIG['optimization']['tol_fun'], tol_x=self.CONFIG['optimization']['tol_x'], max_iter=self.CONFIG['optimization']['max_iter'], verbose=self.verbose, x0=self.CONFIG['optimization']['x0'], regularisation=self.regularisation_params, confidence_array=self.confidence_array)
logger.subinfo('Recomputing coefficients', indent_lvl=1, indent_char='*', with_progress=True)
with ProgressBar(disable=self.verbose< 3, hide_on_exit=True, subinfo=True) as pbar:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=0, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t
Expand Down Expand Up @@ -1446,7 +1437,7 @@ cdef class Evaluation :
return xic, xec, xiso


def save_results( self, path_suffix=None, coeffs_format='%.5e', stat_coeffs='sum', save_est_dwi=False, do_reweighting=True, debias=False ) :
def save_results( self, path_suffix=None, coeffs_format='%.5e', stat_coeffs='sum', save_est_dwi=False, do_reweighting=True ) :
"""Save the output (coefficients, errors, maps etc).
Parameters
Expand Down

0 comments on commit a3e99a9

Please sign in to comment.