Skip to content

Commit

Permalink
chore: fix save_results weights with blur
Browse files Browse the repository at this point in the history
  • Loading branch information
fullbat committed Aug 27, 2024
1 parent fe3c736 commit 3cc13b4
Showing 1 changed file with 97 additions and 1 deletion.
98 changes: 97 additions & 1 deletion commit/core.pyx
Original file line number Diff line number Diff line change
Expand Up @@ -1766,6 +1766,99 @@ cdef class Evaluation :
if dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0:
xic[ self.DICTIONARY['TRK']['kept']==1 ] *= self.DICTIONARY['TRK']['lenTot'] / self.DICTIONARY['TRK']['len']

if self.KERNELS['wmc'].shape[0] > 1:
if "tractogram_centr_idx" in dictionary_info.keys():
ordered_idx = dictionary_info["tractogram_centr_idx"].astype(np.int64)
unravel_weights = np.zeros( (dictionary_info['n_count'], self.KERNELS['wmc'].shape[1]), dtype=np.float64)
unravel_weights[ordered_idx, 0] = self.DICTIONARY['TRK']['kept'].astype(np.float64)
temp_weights = unravel_weights[ordered_idx]
idx_temp_weights = np.where(temp_weights>0)[0]

# retrieve the contribution of each profile for each streamline
num_prof = self.KERNELS['wmc'].shape[0]
fib_w = []
for i in range(nF):
start_idx = i
subarray = []
for j in range(num_prof):
index = start_idx + j * nF
subarray.append(xic[index])
fib_w.append(subarray)

streamline_profs = []
for streamline_idx in range(nF):
streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1])
bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1])
for bf_idx in range(self.KERNELS['wmc'].shape[0]):
bf = self.KERNELS['wmc'][bf_idx]
bundle_prof += bf * fib_w[streamline_idx][bf_idx]
streamline_profs.append(bundle_prof)

if dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0:
for i in range(nF):
streamline_profs[i] *= self.DICTIONARY['TRK']['lenTot'][i] / self.DICTIONARY['TRK']['len'][i]

st_i = 0
for idx in idx_temp_weights:
temp_weights[idx] = streamline_profs[st_i]
st_i += 1
unravel_weights[ordered_idx] = temp_weights
xic = unravel_weights
else:
st_i = 0
for idx in idx_temp_weights:
temp_weights[idx] = streamline_profs[st_i]
st_i += 1
unravel_weights[ordered_idx] = temp_weights
xic = unravel_weights

elif "tractogram_centr_idx" not in dictionary_info.keys() and ( dictionary_info['blur_gauss_extent'] > 0 or dictionary_info['blur_core_extent'] > 0):

xic = np.zeros(self.DICTIONARY['TRK']['kept'].size)
num_prof = self.KERNELS['wmc'].shape[0]
fib_w = []
for i in range(nF):
start_idx = i
subarray = []
for j in range(num_prof):
index = start_idx + j * nF
subarray.append(xic[index])
fib_w.append(subarray)

streamline_profs = []
for streamline_idx in range(nF):
streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1])
bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1])
for bf_idx in range(self.KERNELS['wmc'].shape[0]):
bf = self.KERNELS['wmc'][bf_idx]
bundle_prof += bf * fib_w[streamline_idx][bf_idx]
streamline_profs.append(bundle_prof)

xic[ self.DICTIONARY['TRK']['kept']==1 ] = streamline_profs * self.DICTIONARY['TRK']['lenTot'] / self.DICTIONARY['TRK']['len']

else:
num_prof = self.KERNELS['wmc'].shape[0]
fib_w = []
for i in range(nF):
start_idx = i
subarray = []
for j in range(num_prof):
index = start_idx + j * nF
subarray.append(xic[index])
fib_w.append(subarray)

streamline_profs = []
for streamline_idx in range(nF):
streamline_prof = np.zeros(self.KERNELS['wmc'].shape[1])
bundle_prof = np.zeros(self.KERNELS['wmc'].shape[1])
for bf_idx in range(self.KERNELS['wmc'].shape[0]):
bf = self.KERNELS['wmc'][bf_idx]
bundle_prof += bf * fib_w[streamline_idx][bf_idx]
streamline_profs.append(bundle_prof)

xic = np.zeros( (self.DICTIONARY['TRK']['kept'].size, self.KERNELS['wmc'].shape[1]) )
xic[ self.DICTIONARY['TRK']['kept']==1 ] = streamline_profs


self.temp_data['DICTIONARY'] = self.DICTIONARY
self.temp_data['niiIC_img'] = niiIC_img
Expand All @@ -1778,7 +1871,10 @@ cdef class Evaluation :
if hasattr(self.model, '_postprocess') and do_reweighting:
self.model._postprocess(self.temp_data, verbose=self.verbose)

np.savetxt( pjoin(RESULTS_path,'streamline_weights.txt'), xic, fmt=coeffs_format )
if self.KERNELS['wmc'].shape[0] > 1:
np.save( pjoin(RESULTS_path,'streamline_weights.npy'), xic )
else:
np.savetxt( pjoin(RESULTS_path,'streamline_weights.txt'), xic, fmt=coeffs_format )
self.set_config('stat_coeffs', stat_coeffs)

# Save to a pickle file the following items:
Expand Down

0 comments on commit 3cc13b4

Please sign in to comment.