forked from KamitaniLab/GenericObjectDecoding
-
Notifications
You must be signed in to change notification settings - Fork 0
/
analysis_CategoryIdentification.py
124 lines (86 loc) · 3.81 KB
/
analysis_CategoryIdentification.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
'''
Object category identification
This file is a part of GenericDecoding_demo.
'''
from __future__ import print_function
import os
import pickle
import numpy as np
import pandas as pd
import bdpy
from bdpy.stats import corrmat
import god_config as config
# Main #################################################################
def main():
    '''
    Run pair-wise category identification on the saved decoding results.

    Loads the pickled results from ``config.results_file``. For each entry
    of ``results['feature']`` it computes the mean pair-wise identification
    correct rate of the averaged predicted features, separately for
    perception and imagery. Each prediction is scored against its category
    features plus additional candidate categories read from
    ``config.image_feature_file``.

    The rates are stored in ``results`` under
    ``'catident_correct_rate_percept'`` and
    ``'catident_correct_rate_imagery'``. The pickle is then re-saved in
    place, and ROI-by-feature pivot tables of both rates are printed.
    '''
    output_file = config.results_file
    image_feature_file = config.image_feature_file

    # Load results -----------------------------------------------------
    print('Loading %s' % output_file)
    # NOTE(review): pickle.load can execute arbitrary code; only load
    # result files produced by this pipeline.
    with open(output_file, 'rb') as fin:
        results = pickle.load(fin)

    data_feature = bdpy.BData(image_feature_file)

    # Category identification ------------------------------------------
    print('Running pair-wise category identification')

    feature_list = results['feature']
    pred_percept = results['predicted_feature_averaged_percept']
    pred_imagery = results['predicted_feature_averaged_imagery']
    cat_feature_percept = results['category_feature_averaged_percept']
    cat_feature_imagery = results['category_feature_averaged_imagery']

    # Rows whose FeatureType == 4 are appended as extra (non-target)
    # candidates in every identification test below.
    ind_cat_other = (data_feature.select('FeatureType') == 4).flatten()

    pwident_cr_pt = []  # Prop correct in pair-wise identification (perception)
    pwident_cr_im = []  # Prop correct in pair-wise identification (imagery)

    for feat, fpt, fim, pred_pt, pred_im in zip(feature_list,
                                                cat_feature_percept,
                                                cat_feature_imagery,
                                                pred_percept,
                                                pred_imagery):
        feat_other = data_feature.select(feat)[ind_cat_other, :]
        # Keep only as many units as the category features have, so the
        # extra candidates can be stacked under them.
        n_unit = fpt.shape[1]
        feat_other = feat_other[:, :n_unit]

        pwident_cr_pt.append(_mean_pwident_correctrate(pred_pt, fpt, feat_other))
        pwident_cr_im.append(_mean_pwident_correctrate(pred_im, fim, feat_other))

    results['catident_correct_rate_percept'] = pwident_cr_pt
    results['catident_correct_rate_imagery'] = pwident_cr_im

    # Save the merged dataframe ----------------------------------------
    with open(output_file, 'wb') as fout:
        pickle.dump(results, fout)
    print('Saved %s' % output_file)

    # Show results -----------------------------------------------------
    # aggfunc is given as the string 'mean' rather than np.mean: recent
    # pandas versions warn that passing the NumPy callable is deprecated.
    tb_pt = pd.pivot_table(results, index=['roi'], columns=['feature'],
                           values=['catident_correct_rate_percept'],
                           aggfunc='mean')
    tb_im = pd.pivot_table(results, index=['roi'], columns=['feature'],
                           values=['catident_correct_rate_imagery'],
                           aggfunc='mean')

    print(tb_pt)
    print(tb_im)


def _mean_pwident_correctrate(pred, feat_true, feat_other):
    '''
    Return the mean pair-wise identification correct rate of one predictor.

    The candidate set is `feat_true` followed by `feat_other`. The true
    categories come first so that row i of `pred` is matched to candidate i,
    which is what get_pwident_correctrate expects.
    '''
    feat_candidate = np.vstack([feat_true, feat_other])
    simmat = corrmat(pred, feat_candidate)
    return np.mean(get_pwident_correctrate(simmat))
# Functions ############################################################
def get_pwident_correctrate(simmat):
    '''
    Returns correct rate in pairwise identification

    For prediction i (row i of `simmat`) the correct candidate is column i.
    The prediction is compared pair-wise against every other candidate. Its
    correct rate is the fraction of those other candidates whose similarity
    is not strictly greater than the similarity to the correct candidate, so
    ties count as correct. NaN comparisons evaluate to False: a NaN
    similarity is never counted as "greater".

    Parameters
    ----------
    simmat : numpy array [num_prediction * num_category]
        Similarity matrix. Column i must hold the correct candidate for row
        i, so num_category >= num_prediction is required.

    Returns
    -------
    correct_rate : list of float
        Correct rate of pair-wise identification, one per prediction.
        Empty list when there are no predictions.

    Raises
    ------
    ValueError
        If there is at least one prediction but fewer than two candidates
        (pair-wise identification is undefined; the old code produced NaN).
    IndexError
        If num_category < num_prediction.
    '''
    simmat = np.asarray(simmat)
    num_pred = simmat.shape[0]
    if num_pred == 0:
        return []

    num_cand = simmat.shape[1]
    if num_cand < 2:
        raise ValueError('pair-wise identification needs at least two '
                         'candidates, got %d' % num_cand)

    idx = np.arange(num_pred)
    # Similarity of each prediction to its own (correct) candidate.
    correct_sim = simmat[idx, idx]
    # Number of candidates that beat the correct one; the correct candidate
    # itself is never strictly greater than itself, so it is not counted.
    num_better = np.sum(simmat > correct_sim[:, np.newaxis], axis=1)
    num_other = num_cand - 1

    return ((num_other - num_better) / float(num_other)).tolist()
# Run as a script ######################################################
if __name__ == '__main__':
    # Script entry point: delegate everything to main() so that running
    # the file creates no module-level (global) state.
    main()