Skip to content

Commit

Permalink
Added set_wLasso_profiles() method for weighted lasso regularization …
Browse files Browse the repository at this point in the history
…using profiles
  • Loading branch information
ilariagabusi committed Oct 16, 2024
1 parent 1e58ced commit f15699a
Showing 1 changed file with 60 additions and 2 deletions.
62 changes: 60 additions & 2 deletions commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -764,6 +764,63 @@ cdef class Evaluation :

return y

def set_wLasso_profiles(self, prof_weights, lambda_perc_IC):
"""
Compute array of weights for all the streamlines' profiles and set weighted lasso regularisation.
Parameters
----------
prof_weights - np.array(np.float64) :
array of weights for the profiles.
NB: this array must have the same size as the number of profiles set before.
NB: the weights must be greater or equal to 1.
lambda_perc_IC - float :
percentage of the maximum value of the regularisation parameter for the IC compartment.
NB: lambda_perc_IC must be a float greater than 0.
"""

# check if all the necessary functions have been called
if self.niiDWI is None :
logger.error( 'Data not loaded; call "load_data()" first' )
if self.DICTIONARY is None :
logger.error( 'Dictionary not loaded; call "load_dictionary()" first' )
if self.KERNELS is None :
logger.error( 'Response functions not generated; call "generate_kernels()" and "load_kernels()" first' )
if self.THREADS is None :
logger.error( 'Threads not set; call "set_threads()" first' )
if self.A is None :
logger.error( 'Operator not built; call "build_operator()" first' )

if self.DICTIONARY['IC']['nF'] <= 0 :
logger.error( 'No streamline found in the dictionary; check your data' )

# regularisation['sizeIC'] = int( self.DICTIONARY['IC']['nF'] * self.KERNELS['wmr'].shape[0] * self.KERNELS['wmc'].shape[0])
# self.DICTIONARY['IC']['nF'] = number of kept streamlines
# self.KERNELS['wmr'] = array with shape (1, ndirs, nS), where nS is the number of samples in the scheme
# self.KERNELS['wmc'] = array with shape (num_prof, num_samples)

# check prof_weights
if type(prof_weights) not in [list, np.ndarray]:
logger.error('"prof_weights" must be a list or a numpy array')
prof_weights = np.array(prof_weights, dtype=np.float64)
if np.any(prof_weights < 1):
logger.error('All group weights must be greater or equal to 1')
if prof_weights.size != self.KERNELS['wmc'].shape[0]:
logger.error( f'The number of profiles\' weights in the provided array does not match the number of profiles set in "load_kernels()" (got {prof_weights.size} but {self.KERNELS["wmc"].shape[0]} expected)' )

# compute complete array of weights
prof_weights_all = np.repeat(prof_weights, len(self.DICTIONARY["TRK"]["kept"]))
dict_ic = {}
dict_ic['coeff_weights'] = prof_weights_all

# set reg
self.set_regularisation(
regularisers = ('lasso', None, None),
lambdas = (lambda_perc_IC, None, None),
is_nonnegative = (False, True, True),
params = (dict_ic, None, None))


def set_regularisation(self, regularisers=(None, None, None), lambdas=(None, None, None), is_nonnegative=(True, True, True), params=(None, None, None)):
"""
Expand Down Expand Up @@ -960,9 +1017,10 @@ cdef class Evaluation :
if dictIC_params is not None and 'coeff_weights' in dictIC_params:
if np.any(dictIC_params['coeff_weights'] < 0):
logger.error('All coefficients weights must be non-negative')
if dictIC_params['coeff_weights'].size != len(self.DICTIONARY['TRK']['kept']):
dict_kept = np.tile(self.DICTIONARY['TRK']['kept'], self.KERNELS['wmc'].shape[0])
if dictIC_params['coeff_weights'].size != dict_kept.size:
logger.error(f'"coeff_weights" must have the same size as the number of elements in the IC compartment (got {dictIC_params["coeff_weights"].size} but {len(self.DICTIONARY["TRK"]["kept"])} expected)')
dictIC_params['coeff_weights_kept'] = dictIC_params['coeff_weights'][self.DICTIONARY['TRK']['kept']==1]
dictIC_params['coeff_weights_kept'] = dictIC_params['coeff_weights'][dict_kept==1]

# check if group parameters are consistent with the regularisation
if regularisation['regIC'] not in ['group_lasso', 'sparse_group_lasso'] and dictIC_params is not None:
Expand Down

0 comments on commit f15699a

Please sign in to comment.