Skip to content

Commit

Permalink
implemented debias newest version
Browse files Browse the repository at this point in the history
  • Loading branch information
fullbat committed Jul 18, 2024
1 parent 3fd5168 commit f47792c
Show file tree
Hide file tree
Showing 3 changed files with 122 additions and 37 deletions.
154 changes: 120 additions & 34 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ from dicelib.utils import format_time

import commit.models
import commit.solvers
from commit.operator import operator


logger = setup_logger('core')
Expand Down Expand Up @@ -81,6 +82,7 @@ cdef class Evaluation :
cdef public CONFIG
cdef public temp_data
cdef public confidence_map_img
cdef public debias_mask
cdef public verbose

def __init__( self, study_path='.', subject='.' ) :
Expand All @@ -103,6 +105,7 @@ cdef class Evaluation :
self.regularisation_params = None # set by "set_regularisation" method
self.x = None # set by "fit" method
self.confidence_map_img = None # set by "fit" method
self.debias_mask = None # set by "fit" method
self.verbose = 3

# store all the parameters of an evaluation with COMMIT
Expand Down Expand Up @@ -734,7 +737,10 @@ cdef class Evaluation :
logger.subinfo('')
logger.info( 'Building linear operator A' )

from commit.operator import operator
nF = self.DICTIONARY['IC']['nF'] # number of FIBERS
n2 = nF * self.KERNELS['wmr'].shape[0] * self.KERNELS['wmc'].shape[0]
self.DICTIONARY["IC"]["eval"] = np.ones( int(n2), dtype=np.uint32)

self.A = operator.LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, True if hasattr(self.model, 'nolut') else False )

logger.info( f'[ {format_time(time.time() - tic)} ]' )
Expand All @@ -751,7 +757,9 @@ cdef class Evaluation :
logger.error( 'Data not loaded; call "load_data()" first' )

y = self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float64)
# y[y < 0] = 0
if self.debias_mask is not None :
y *= self.debias_mask

return y


Expand Down Expand Up @@ -1332,19 +1340,53 @@ cdef class Evaluation :

self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t
if debias:


if (self.regularisation_params['regIC']!=None or self.regularisation_params['regEC']!= None or self.regularisation_params['regISO']!= None) and debias:

from commit.operator import operator
mask = np.zeros(self.DICTIONARY['IC']['nF']*self.KERNELS['wmc'].shape[0], dtype=np.uint32)
mask[self.x>0] = 1
mask[self.x<0] = 1
self.DICTIONARY['IC']['idx'] = np.ascontiguousarray(mask, dtype=np.uint32)
temp_verb = self.verbose
logger.info( 'Running debias' )
self.set_verbose(0)

nF = self.DICTIONARY['IC']['nF']

offset1 = nF * self.KERNELS['wmr'].shape[0] * self.KERNELS['wmc'].shape[0]
xic = self.x[:offset1]

mask = np.ones(offset1, dtype=np.uint32)

mask[self.x==0] = 0

self.DICTIONARY['IC']['eval'] = mask

self.A = operator.LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, True if hasattr(self.model, 'nolut') else False )

if self.KERNELS['wmc'].shape[0] > 1:
self.set_regularisation(is_nonnegative = (False, True, True))
else:
self.set_regularisation()

self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=self.x, regularisation=self.regularisation_params, confidence_array=confidence_array)

self.set_verbose(temp_verb)

logger.subinfo('Recomputing coefficients', indent_lvl=1, indent_char='*', with_progress=True)

x_debias = self.x.copy()
x_debias[:offset1] *= mask
x_debias[offset1:] = 0

y_mask = np.asarray(self.A.dot(x_debias))
# binarize y_debias
y_mask[y_mask<0] = 0
y_mask[y_mask>0] = 1

self.debias_mask = y_mask

with ProgressBar(disable=self.verbose!=3, hide_on_exit=True, subinfo=True) as pbar:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=self.x, regularisation=self.regularisation_params, confidence_array=confidence_array)

self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t
logger.info( f'[ {format_time(self.CONFIG["optimization"]["fit_time"])} ]' )


Expand Down Expand Up @@ -1388,10 +1430,6 @@ cdef class Evaluation :

return xic, xec, xiso

def compute_chunks(lst, n):
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n]


cpdef compute_contribution(self, x, norm_fib, metric):
Expand Down Expand Up @@ -1511,26 +1549,74 @@ cdef class Evaluation :
niiMAP_hdr['descrip'] = 'Created with COMMIT %s'%self.get_config('version')
niiMAP_hdr['db_name'] = ''

y_mea = np.reshape( self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
y_est = np.reshape( self.A.dot(self.x), (nV,-1) ).astype(np.float32)

tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )
logger.subinfo(f'RMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
niiMAP_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = tmp.max()
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_RMSE.nii.gz') )

tmp = np.sum(y_mea**2,axis=1)
idx = np.where( tmp < 1E-12 )
tmp[ idx ] = 1
tmp = np.sqrt( np.sum((y_mea-y_est)**2,axis=1) / tmp )
tmp[ idx ] = 0
logger.subinfo(f'NRMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
niiMAP_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = 1
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_NRMSE.nii.gz') )
if self.debias_mask is not None:
nV = int(np.sum(self.debias_mask)/self.niiDWI_img.shape[3])

if nV == 0:
logger.warning("Streamlines contributions are all zero.")
return 0
ind_mask = np.where(self.debias_mask>0)[0]
vox_mask = np.reshape( self.debias_mask[ind_mask], (nV,-1) )

y_mea = np.reshape( self.get_y()[ind_mask], (nV,-1) )

y_est_ = np.asarray(self.A.dot(self.x))
y_est = np.reshape( y_est_[ind_mask], (nV,-1) )

tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )
rmse = tmp.mean()
logger.subinfo(f'RMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')

tmp = np.sum(y_mea**2,axis=1)
idx = np.where( tmp < 1E-12 )
tmp[ idx ] = 1
tmp = np.sqrt( np.sum((y_mea-y_est)**2,axis=1) / tmp )
tmp[ idx ] = 0
logger.subinfo(f'NRMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')

y_mea = np.reshape( self.get_y(), (self.DICTIONARY['nV'],-1) )

y_est_ = np.asarray(self.A.dot(self.x))
y_est = np.reshape( y_est_, (self.DICTIONARY['nV'],-1) )
tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )

niiMAP_img[self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz']] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = tmp.max()
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_RMSE.nii.gz') )

tmp = np.sum(y_mea**2,axis=1)
idx = np.where( tmp < 1E-12 )
tmp[ idx ] = 1
tmp = np.sqrt( np.sum((y_mea-y_est)**2,axis=1) / tmp )
tmp[ idx ] = 0

niiMAP_img[self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz']] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = 1
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_NRMSE.nii.gz') )

else:
y_mea = np.reshape( self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
y_est = np.reshape( self.A.dot(self.x), (nV,-1) ).astype(np.float32)
tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )

logger.subinfo(f'RMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
niiMAP_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = tmp.max()
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_RMSE.nii.gz') )

tmp = np.sum(y_mea**2,axis=1)
idx = np.where( tmp < 1E-12 )
tmp[ idx ] = 1
tmp = np.sqrt( np.sum((y_mea-y_est)**2,axis=1) / tmp )
tmp[ idx ] = 0
logger.subinfo(f'NRMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
niiMAP_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = tmp
niiMAP_hdr['cal_min'] = 0
niiMAP_hdr['cal_max'] = 1
nibabel.save( niiMAP, pjoin(RESULTS_path,'fit_NRMSE.nii.gz') )

if self.confidence_map_img is not None:
confidence_array = np.reshape( self.confidence_map_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
Expand Down Expand Up @@ -1705,7 +1791,7 @@ cdef class Evaluation :
with open( pjoin(RESULTS_path,'results.pickle'), 'wb+' ) as fid :
self.CONFIG['optimization']['regularisation'].pop('omega', None)
self.CONFIG['optimization']['regularisation'].pop('prox', None)
pickle.dump( [self.CONFIG, x, self.x], fid, protocol=2 )
pickle.dump( [self.CONFIG, x, self.x, rmse], fid, protocol=2 )

if save_est_dwi :
logger.subinfo('Estimated signal:', indent_char='-', indent_lvl=2, with_progress=True)
Expand Down
2 changes: 1 addition & 1 deletion commit/operator/operator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -127,7 +127,7 @@ cdef class LinearOperator :
cdef unsigned int [::1] ICf = DICTIONARY['IC']['fiber']
self.ICf = &ICf[0]

cdef unsigned int [::1] ICeval = DICTIONARY['IC']['idx']
cdef unsigned int [::1] ICeval = DICTIONARY['IC']['eval']
self.ICeval = &ICeval[0]


Expand Down
3 changes: 1 addition & 2 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -273,8 +273,7 @@ def fista(y, A, At, omega, prox, sqrt_W=None, tol_fun=1e-4, tol_x=1e-6, max_iter
else:
res = A.dot(xhat) - y
grad = np.asarray(At.dot(res))



prox( xhat, 1.0 )
reg_term = omega( xhat )
prev_obj = 0.5 * np.linalg.norm(res)**2 + reg_term
Expand Down

0 comments on commit f47792c

Please sign in to comment.