diff --git a/commit/core.pyx b/commit/core.pyx index a0cb193..e212966 100644 --- a/commit/core.pyx +++ b/commit/core.pyx @@ -764,6 +764,63 @@ cdef class Evaluation : return y + def set_wLasso_profiles(self, prof_weights, lambda_perc_IC): + """ + Compute array of weights for all the streamlines' profiles and set weighted lasso regularisation. + + Parameters + ---------- + prof_weights - np.array(np.float64) : + array of weights for the profiles. + NB: this array must have the same size as the number of profiles set before. + NB: the weights must be greater or equal to 1. + + lambda_perc_IC - float : + percentage of the maximum value of the regularisation parameter for the IC compartment. + NB: lambda_perc_IC must be a float greater than 0. + """ + + # check if all the necessary functions have been called + if self.niiDWI is None : + logger.error( 'Data not loaded; call "load_data()" first' ) + if self.DICTIONARY is None : + logger.error( 'Dictionary not loaded; call "load_dictionary()" first' ) + if self.KERNELS is None : + logger.error( 'Response functions not generated; call "generate_kernels()" and "load_kernels()" first' ) + if self.THREADS is None : + logger.error( 'Threads not set; call "set_threads()" first' ) + if self.A is None : + logger.error( 'Operator not built; call "build_operator()" first' ) + + if self.DICTIONARY['IC']['nF'] <= 0 : + logger.error( 'No streamline found in the dictionary; check your data' ) + + # regularisation['sizeIC'] = int( self.DICTIONARY['IC']['nF'] * self.KERNELS['wmr'].shape[0] * self.KERNELS['wmc'].shape[0]) + # self.DICTIONARY['IC']['nF'] = number of kept streamlines + # self.KERNELS['wmr'] = array with shape (1, ndirs, nS), where nS is the number of samples in the scheme + # self.KERNELS['wmc'] = array with shape (num_prof, num_samples) + + # check prof_weights + if type(prof_weights) not in [list, np.ndarray]: + logger.error('"prof_weights" must be a list or a numpy array') + prof_weights = np.array(prof_weights, dtype=np.float64) + if np.any(prof_weights < 1): + logger.error('All group weights must be greater or equal to 1') + if prof_weights.size != self.KERNELS['wmc'].shape[0]: + logger.error( f'The number of profiles\' weights in the provided array does not match the number of profiles set in "load_kernels()" (got {prof_weights.size} but {self.KERNELS["wmc"].shape[0]} expected)' ) + + # compute complete array of weights + prof_weights_all = np.repeat(prof_weights, len(self.DICTIONARY["TRK"]["kept"])) + dict_ic = {} + dict_ic['coeff_weights'] = prof_weights_all + + # set reg + self.set_regularisation( + regularisers = ('lasso', None, None), + lambdas = (lambda_perc_IC, None, None), + is_nonnegative = (False, True, True), + params = (dict_ic, None, None)) + def set_regularisation(self, regularisers=(None, None, None), lambdas=(None, None, None), is_nonnegative=(True, True, True), params=(None, None, None)): """ @@ -960,9 +1017,10 @@ cdef class Evaluation : if dictIC_params is not None and 'coeff_weights' in dictIC_params: if np.any(dictIC_params['coeff_weights'] < 0): logger.error('All coefficients weights must be non-negative') - if dictIC_params['coeff_weights'].size != len(self.DICTIONARY['TRK']['kept']): + dict_kept = np.tile(self.DICTIONARY['TRK']['kept'], self.KERNELS['wmc'].shape[0]) + if dictIC_params['coeff_weights'].size != dict_kept.size: logger.error(f'"coeff_weights" must have the same size as the number of elements in the IC compartment (got {dictIC_params["coeff_weights"].size} but {len(self.DICTIONARY["TRK"]["kept"])} expected)') - dictIC_params['coeff_weights_kept'] = dictIC_params['coeff_weights'][self.DICTIONARY['TRK']['kept']==1] + dictIC_params['coeff_weights_kept'] = dictIC_params['coeff_weights'][dict_kept==1] # check if group parameters are consistent with the regularisation if regularisation['regIC'] not in ['group_lasso', 'sparse_group_lasso'] and dictIC_params is not None: