diff --git a/commit/core.pyx b/commit/core.pyx index 94c7ae9..a7d7b30 100644 --- a/commit/core.pyx +++ b/commit/core.pyx @@ -1766,6 +1766,99 @@ cdef class Evaluation : if dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0: xic[ self.DICTIONARY['TRK']['kept']==1 ] *= self.DICTIONARY['TRK']['lenTot'] / self.DICTIONARY['TRK']['len'] + if self.KERNELS['wmc'].shape[0] > 1: + if "tractogram_centr_idx" in dictionary_info.keys(): + ordered_idx = dictionary_info["tractogram_centr_idx"].astype(np.int64) + unravel_weights = np.zeros( (dictionary_info['n_count'], self.KERNELS['wmc'].shape[1]), dtype=np.float64) + unravel_weights[ordered_idx, 0] = self.DICTIONARY['TRK']['kept'].astype(np.float64) + temp_weights = unravel_weights[ordered_idx] + idx_temp_weights = np.where(temp_weights>0)[0] + + # retrieve the contribution of each profile for each streamline + num_prof = self.KERNELS['wmc'].shape[0] + fib_w = [] + for i in range(nF): + start_idx = i + subarray = [] + for j in range(num_prof): + index = start_idx + j * nF + subarray.append(xic[index]) + fib_w.append(subarray) + + streamline_profs = [] + for streamline_idx in range(nF): + streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + for bf_idx in range(self.KERNELS['wmc'].shape[0]): + bf = self.KERNELS['wmc'][bf_idx] + bundle_prof += bf * fib_w[streamline_idx][bf_idx] + streamline_profs.append(bundle_prof) + + if dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0: + for i in range(nF): + streamline_profs[i] *= self.DICTIONARY['TRK']['lenTot'][i] / self.DICTIONARY['TRK']['len'][i] + + st_i = 0 + for idx in idx_temp_weights: + temp_weights[idx] = streamline_profs[st_i] + st_i += 1 + unravel_weights[ordered_idx] = temp_weights + xic = unravel_weights + else: + st_i = 0 + for idx in idx_temp_weights: + temp_weights[idx] = streamline_profs[st_i] + st_i += 1 + unravel_weights[ordered_idx] = temp_weights + xic = unravel_weights + + elif "tractogram_centr_idx" not in dictionary_info.keys() and ( dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0): + + xic = np.zeros(self.DICTIONARY['TRK']['kept'].size) + num_prof = self.KERNELS['wmc'].shape[0] + fib_w = [] + for i in range(nF): + start_idx = i + subarray = [] + for j in range(num_prof): + index = start_idx + j * nF + subarray.append(xic[index]) + fib_w.append(subarray) + + streamline_profs = [] + for streamline_idx in range(nF): + streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + for bf_idx in range(self.KERNELS['wmc'].shape[0]): + bf = self.KERNELS['wmc'][bf_idx] + bundle_prof += bf * fib_w[streamline_idx][bf_idx] + streamline_profs.append(bundle_prof) + + xic[ self.DICTIONARY['TRK']['kept']==1 ] = streamline_profs * self.DICTIONARY['TRK']['lenTot'] / self.DICTIONARY['TRK']['len'] + + else: + num_prof = self.KERNELS['wmc'].shape[0] + fib_w = [] + for i in range(nF): + start_idx = i + subarray = [] + for j in range(num_prof): + index = start_idx + j * nF + subarray.append(xic[index]) + fib_w.append(subarray) + + streamline_profs = [] + for streamline_idx in range(nF): + streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1]) + for bf_idx in range(self.KERNELS['wmc'].shape[0]): + bf = self.KERNELS['wmc'][bf_idx] + bundle_prof += bf * fib_w[streamline_idx][bf_idx] + streamline_profs.append(bundle_prof) + + xic = np.zeros( (self.DICTIONARY['TRK']['kept'].size, self.KERNELS['wmc'].shape[1]) ) + xic[ self.DICTIONARY['TRK']['kept']==1 ] = streamline_profs + self.temp_data['DICTIONARY'] = self.DICTIONARY self.temp_data['niiIC_img'] = niiIC_img @@ -1778,7 +1871,10 @@ cdef class Evaluation : if hasattr(self.model, '_postprocess') and do_reweighting: self.model._postprocess(self.temp_data, verbose=self.verbose) - np.savetxt( pjoin(RESULTS_path,'streamline_weights.txt'), xic, fmt=coeffs_format ) + if self.KERNELS['wmc'].shape[0] > 1: + np.save( pjoin(RESULTS_path,'streamline_weights.npy'), xic ) + else: + np.savetxt( pjoin(RESULTS_path,'streamline_weights.txt'), xic, fmt=coeffs_format ) self.set_config('stat_coeffs', stat_coeffs) # Save to a pickle file the following items: