Skip to content

Commit

Permalink
working version with dot prod, need to fix RMSE and NRMSE
Browse files Browse the repository at this point in the history
  • Loading branch information
fullbat committed Nov 3, 2024
1 parent 43010c7 commit 1ccd11a
Show file tree
Hide file tree
Showing 5 changed files with 104 additions and 69 deletions.
99 changes: 54 additions & 45 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -90,9 +90,7 @@ cdef class Evaluation :
cdef public temp_data
cdef public confidence_array
cdef public confidence_map_img
cdef public contribution_mask
cdef public contribution_fibs
cdef public contribution_voxels
cdef public debias_mask
cdef public verbose

def __init__( self, study_path='.', subject='.' ) :
Expand All @@ -116,10 +114,8 @@ cdef class Evaluation :
self.x = None # set by "fit" method
self.confidence_array = None # set by "fit" method
self.confidence_map_img = None # set by "fit" method
self.contribution_mask = None # set by "fit" method
self.contribution_voxels = None # set by "fit" method
self.debias_mask = None # set by "fit" method
self.x_nnls = None # set by "fit" method (coefficients of IC compartment estimated without regularization)
self.contribution_fibs = None
self.verbose = 3

# store all the parameters of an evaluation with COMMIT
Expand Down Expand Up @@ -1374,54 +1370,58 @@ cdef class Evaluation :
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

if (self.regularisation_params['regIC']!=None or self.regularisation_params['regEC']!= None or self.regularisation_params['regISO']!= None) and debias:

from commit.operator import operator
temp_verb = self.verbose
logger.info( 'Running debias' )
self.set_verbose(0)
xic, _, _ = self.get_coeffs()
weights_in = pjoin( self.get_config('TRACKING_path'), 'streamline_weights.txt' )
np.savetxt(weights_in, xic)

dictionary_info = load_dictionary_info( pjoin(self.get_config('TRACKING_path'), 'dictionary_info.pickle') )
tractogram = dictionary_info['filename_tractogram']
tractogram_filtered = tractogram.replace('.tck', '_filtered.tck')

filter(dictionary_info['filename_tractogram'], tractogram_filtered, minweight=0.000000000000001, weights_in=weights_in, force=True, verbose=0)

trk2dictionary.run(
filename_tractogram = tractogram_filtered,
TCK_ref_image = dictionary_info['TCK_ref_image'],
path_out = dictionary_info['path_out'],
filename_peaks = dictionary_info['filename_peaks'],
filename_mask = dictionary_info['filename_mask'],
do_intersect = dictionary_info['do_intersect'],
fiber_shift = dictionary_info['fiber_shift'],
min_seg_len = dictionary_info['min_seg_len'],
min_fiber_len = dictionary_info['min_fiber_len'],
max_fiber_len = dictionary_info['max_fiber_len'],
vf_THR = dictionary_info['vf_THR'],
peaks_use_affine = dictionary_info['peaks_use_affine'],
flip_peaks = dictionary_info['flip_peaks'],
blur_core_extent = dictionary_info['blur_core_extent'],
blur_gauss_extent = dictionary_info['blur_gauss_extent'],
blur_gauss_min = dictionary_info['blur_gauss_min'],
blur_spacing = dictionary_info['blur_spacing'],
ndirs = dictionary_info['ndirs'],
n_threads = dictionary_info['n_threads'],
verbose = 0
)

self.load_dictionary(dictionary_info['path_out'])

self.set_threads()
self.build_operator()
nF = self.DICTIONARY['IC']['nF']
nE = self.DICTIONARY['EC']['nE']
nV = self.DICTIONARY['nV']

offset1 = nF * self.KERNELS['wmr'].shape[0]
xic = self.x[:offset1]

mask = np.ones(nF, dtype=np.uint32)
mask[xic<0.000000000000001] = 0

self.DICTIONARY["IC"]["eval"] = mask

self.A = operator.LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, nolut=True if hasattr(self.model, 'nolut') else False )

self.set_regularisation()

self.set_verbose(temp_verb)

logger.subinfo('Recomputing coefficients', indent_lvl=1, indent_char='*', with_progress=True)
with ProgressBar(disable=self.verbose< 3, hide_on_exit=True, subinfo=True) as pbar:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=0, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

x_debias = self.x.copy()

logger.debug( f'positive values of x before debias: {np.sum(x_debias>0)}' )
logger.debug( f'positive values of mask: {np.sum(mask>0)}' )

x_debias[:nF] *= mask
x_debias[offset1:] = 0

logger.debug( f"positive values of masked x: {np.sum(x_debias[:nF]>0)}" )

logger.debug( f'Shape of y: {self.get_y().size} Number of non zero values in y before: {np.sum(self.get_y()>0)}' )

y_mask = np.asarray(self.A.dot(x_debias))
print(f"number of non zero values in y_mask before bin: {np.sum(y_mask>0)}")
# binarize y_debias
y_mask[y_mask<0] = 0
y_mask[y_mask>0] = 1
print(f"number of non zero values in y_mask after bin: {np.sum(y_mask>0)}")

self.debias_mask = y_mask
logger.debug( f'Shape of y: {self.get_y().size} Number of non zero values in y after: {np.sum(self.get_y()>0)}' )

# print the first 10 non zero values of y_debias
logger.debug( f'First 10 non zero values of y_debias: {self.get_y()[:10]}' )
with ProgressBar(disable=self.verbose!=3, hide_on_exit=True, subinfo=True) as pbar:
self.x, opt_details = commit.solvers.solve(self.get_y(), self.A, self.A.T, tol_fun=tol_fun, tol_x=tol_x, max_iter=max_iter, verbose=self.verbose, x0=x0, regularisation=self.regularisation_params, confidence_array=confidence_array)

self.CONFIG['optimization']['fit_details'] = opt_details
self.CONFIG['optimization']['fit_time'] = time.time()-t
Expand Down Expand Up @@ -1533,8 +1533,17 @@ cdef class Evaluation :

y_mea = np.reshape( self.niiDWI_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'], : ].flatten().astype(np.float32), (nV,-1) )
y_est = np.reshape( self.A.dot(self.x), (nV,-1) ).astype(np.float32)

tmp = np.sqrt( np.mean((y_mea-y_est)**2,axis=1) )

if self.debias_mask is not None:
y_mask = np.reshape(self.debias_mask, (nV,-1))
# compute tmp only for the voxels of y_mea and y_est that are non zero in y_mask
idx = np.where(y_mask.flatten()>0)
tmp = np.sqrt( np.mean((y_mea.flatten()[idx]-y_est.flatten()[idx])**2) )




logger.subinfo(f'RMSE: {tmp.mean():.3f} +/- {tmp.std():.3f}', indent_lvl=2, indent_char='-')
niiMAP_img[ self.DICTIONARY['MASK_ix'], self.DICTIONARY['MASK_iy'], self.DICTIONARY['MASK_iz'] ] = tmp
niiMAP_hdr['cal_min'] = 0
Expand Down
27 changes: 18 additions & 9 deletions commit/operator/operator.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@ cdef extern void COMMIT_A(
cdef extern void COMMIT_At(
int _nF, int _n, int _nE, int _nV, int _nS, int _ndirs,
double *_v_in, double *_v_out,
unsigned int *_ICf, unsigned int *_ICv, unsigned short *_ICo, float *_ICl,
unsigned int *_ICf, unsigned int *_ICeval, unsigned int *_ICv, unsigned short *_ICo, float *_ICl,
unsigned int *_ECv, unsigned short *_ECo,
unsigned int *_ISOv,
float *_wmrSFP, float *_wmhSFP, float *_isoSFP,
Expand All @@ -31,7 +31,7 @@ cdef extern void COMMIT_At(
cdef extern void COMMIT_A_nolut(
int _nF,
double *_v_in, double *_v_out,
unsigned int *_ICf, unsigned int *_ICv, float *_ICl,
unsigned int *_ICf, unsigned int *_ICeval, unsigned int *_ICv, float *_ICl,
unsigned int *_ISOv,
unsigned int* _ICthreads, unsigned int* _ISOthreads,
unsigned int _nISO, unsigned int _nThreads
Expand All @@ -40,7 +40,7 @@ cdef extern void COMMIT_A_nolut(
cdef extern void COMMIT_At_nolut(
int _nF, int _n,
double *_v_in, double *_v_out,
unsigned int *_ICf, unsigned int *_ICv, float *_ICl,
unsigned int *_ICf, unsigned int *_ICeval, unsigned int *_ICv, float *_ICl,
unsigned int *_ISOv,
unsigned char* _ICthreadsT, unsigned int* _ISOthreadsT,
unsigned int _nISO, unsigned int _nThreads
Expand All @@ -62,6 +62,7 @@ cdef class LinearOperator :
cdef nolut

cdef unsigned int* ICf
cdef unsigned int* ICeval
cdef float* ICl
cdef unsigned int* ICv
cdef unsigned short* ICo
Expand All @@ -84,6 +85,7 @@ cdef class LinearOperator :

def __init__( self, DICTIONARY, KERNELS, THREADS, nolut=False ) :
"""Set the pointers to the data structures used by the C code."""

self.DICTIONARY = DICTIONARY
self.KERNELS = KERNELS
self.THREADS = THREADS
Expand All @@ -98,6 +100,7 @@ cdef class LinearOperator :
self.n = DICTIONARY['IC']['n'] # numbner of IC segments
self.ndirs = KERNELS['wmr'].shape[1] # number of directions


if KERNELS['wmr'].size > 0 :
self.nS = KERNELS['wmr'].shape[2] # number of SAMPLES
elif KERNELS['wmh'].size > 0 :
Expand All @@ -108,11 +111,17 @@ cdef class LinearOperator :
self.adjoint = 0 # direct of inverse product

self.n1 = self.nV*self.nS
self.n2 = self.nR*self.nF + self.nT*self.nE + self.nI*self.nV
self.n2 = self.nR*self.nF + self.nT*self.nE + self.nI*self.nV

# get C pointers to arrays in DICTIONARY
cdef unsigned int [::1] ICf = DICTIONARY['IC']['fiber']
self.ICf = &ICf[0]
cdef unsigned int [::1] ICeval = DICTIONARY["IC"]["eval"]
self.ICeval = &ICeval[0]

# for i in range(self.n2):
# print(f"ICeval after assignment: {self.ICeval[i]}")

cdef float [::1] ICl = DICTIONARY['IC']['len']
self.ICl = &ICl[0]
cdef unsigned int [::1] ICv = DICTIONARY['IC']['v']
Expand Down Expand Up @@ -153,7 +162,7 @@ cdef class LinearOperator :
@property
def T( self ) :
"""Transpose of the explicit matrix."""
C = LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, self.nolut )
C = LinearOperator( self.DICTIONARY, self.KERNELS, self.THREADS, nolut=self.nolut )
C.adjoint = 1 - C.adjoint
return C

Expand Down Expand Up @@ -201,7 +210,7 @@ cdef class LinearOperator :
COMMIT_A_nolut(
self.nF,
&v_in[0], &v_out[0],
self.ICf, self.ICv, self.ICl,
self.ICf, self.ICeval, self.ICv, self.ICl,
self.ISOv,
self.ICthreads, self.ISOthreads,
nISO, nthreads
Expand All @@ -211,7 +220,7 @@ cdef class LinearOperator :
COMMIT_A(
self.nF, self.nE, self.nV, self.nS, self.ndirs,
&v_in[0], &v_out[0],
self.ICf, self.ICv, self.ICo, self.ICl,
self.ICf, self.ICeval, self.ICv, self.ICo, self.ICl,
self.ECv, self.ECo,
self.ISOv,
self.LUT_IC, self.LUT_EC, self.LUT_ISO,
Expand All @@ -225,7 +234,7 @@ cdef class LinearOperator :
COMMIT_At_nolut(
self.nF, self.n,
&v_in[0], &v_out[0],
self.ICf, self.ICv, self.ICl,
self.ICf, self.ICeval, self.ICv, self.ICl,
self.ISOv,
self.ICthreadsT, self.ISOthreadsT,
nISO, nthreads
Expand All @@ -235,7 +244,7 @@ cdef class LinearOperator :
COMMIT_At(
self.nF, self.n, self.nE, self.nV, self.nS, self.ndirs,
&v_in[0], &v_out[0],
self.ICf, self.ICv, self.ICo, self.ICl,
self.ICf, self.ICeval, self.ICv, self.ICo, self.ICl,
self.ECv, self.ECo,
self.ISOv,
self.LUT_IC, self.LUT_EC, self.LUT_ISO,
Expand Down
3 changes: 3 additions & 0 deletions commit/solvers.py
Original file line number Diff line number Diff line change
Expand Up @@ -274,6 +274,9 @@ def fista(y, A, At, omega, prox, sqrt_W=None, tol_fun=1e-4, tol_x=1e-6, max_iter
res = A.dot(xhat) - y
grad = np.asarray(At.dot(res))

print("grad", grad[grad!=0][:50])
print("res", res[res!=0][:50])

prox( xhat, 1.0 )
reg_term = omega( xhat )
prev_obj = 0.5 * np.linalg.norm(res)**2 + reg_term
Expand Down
8 changes: 4 additions & 4 deletions setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -65,10 +65,10 @@ def run(self):
build_ext.finalize_options(self)
build_ext.run(self)

# generate the operator_c.c file
sys.path.insert(0, os.path.dirname(__file__))
from setup_operator import write_operator_c_file
write_operator_c_file()
# # generate the operator_c.c file
# sys.path.insert(0, os.path.dirname(__file__))
# from setup_operator import write_operator_c_file
# write_operator_c_file()

# create the 'build' directory
if not os.path.exists('build'):
Expand Down
Loading

0 comments on commit 1ccd11a

Please sign in to comment.