forked from koyeongmin/PINet
-
Notifications
You must be signed in to change notification settings - Fork 0
/
evaluation.py
102 lines (93 loc) · 3.93 KB
/
evaluation.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
############################################################
##
## This is cloned from the official TuSimple dataset evaluation code
## https://github.com/TuSimple/tusimple-benchmark/blob/master/evaluate/lane.py
##
############################################################
import numpy as np
from sklearn.linear_model import LinearRegression
import ujson as json
class LaneEval(object):
    """TuSimple lane-detection benchmark: per-image and per-file Accuracy / FP / FN.

    Every method is static; the class only groups them together with the
    benchmark's tuning constants.
    """

    # Shared regressor used by ``get_angle`` to fit x as a linear function of y.
    # NOTE(review): this is shared mutable state, so concurrent calls to
    # ``get_angle`` from several threads would race on it.
    lr = LinearRegression()
    # Base tolerance (pixels) between predicted and GT x at the same row;
    # ``bench`` widens it by 1/cos(angle) for slanted lanes.
    pixel_thresh = 20
    # Minimum fraction of matching points for a GT lane to count as detected.
    pt_thresh = 0.85

    @staticmethod
    def get_angle(xs, y_samples):
        """Return the slant angle (radians) of a lane relative to the y axis.

        A line x = k * y + b is fitted through the valid points, and
        arctan(k) is returned. Points with a negative x are treated as
        "no lane at this row" and ignored. With fewer than two valid
        points the angle is 0.

        Args:
            xs: numpy array of x coordinates, one per entry of ``y_samples``.
            y_samples: numpy array of the row (y) coordinates.
        """
        valid = xs >= 0
        xs, ys = xs[valid], y_samples[valid]
        if len(xs) > 1:
            LaneEval.lr.fit(ys[:, None], xs)
            k = LaneEval.lr.coef_[0]
            theta = np.arctan(k)
        else:
            theta = 0
        return theta

    @staticmethod
    def line_accuracy(pred, gt, thresh):
        """Return the fraction of rows where ``pred`` is within ``thresh`` px of ``gt``.

        Missing points (negative x) in either lane are mapped to -100, so two
        missing points count as a match. A missing point paired with a real
        one (x >= 0) is at least 100 px away and normally fails, unless
        ``thresh`` exceeds 100 (a nearly horizontal lane).

        Args:
            pred: sequence of predicted x coordinates.
            gt: sequence of ground-truth x coordinates, same length as ``pred``.
            thresh: maximum allowed horizontal distance in pixels.
        """
        pred = np.array([p if p >= 0 else -100 for p in pred])
        gt = np.array([g if g >= 0 else -100 for g in gt])
        return np.sum(np.where(np.abs(pred - gt) < thresh, 1., 0.)) / len(gt)

    @staticmethod
    def bench(pred, gt, y_samples, running_time):
        """Score one image's predicted lanes against its ground-truth lanes.

        Args:
            pred: list of predicted lanes. Each lane is a list of x coordinates
                aligned with ``y_samples``; a negative x means "no point".
            gt: list of ground-truth lanes in the same format.
            y_samples: row coordinates shared by all lanes of the image.
            running_time: reported inference time for this image.

        Returns:
            Tuple ``(accuracy, fp_rate, fn_rate)`` for this image.

        Raises:
            Exception: if a predicted lane's length differs from ``y_samples``.
        """
        if any(len(p) != len(y_samples) for p in pred):
            raise Exception('Format of lanes error.')
        # Too slow, or far more predictions than GT lanes: accuracy 0, FP 0, FN 1.
        if running_time > 200 or len(gt) + 2 < len(pred):
            return 0., 0., 1.
        ys = np.array(y_samples)
        angles = [LaneEval.get_angle(np.array(x_gts), ys) for x_gts in gt]
        # Lanes are compared by horizontal distance. Dividing by cos(angle)
        # keeps the tolerance roughly constant perpendicular to a tilted lane.
        threshs = [LaneEval.pixel_thresh / np.cos(angle) for angle in angles]
        line_accs = []
        fn = 0.
        matched = 0.
        for x_gts, thresh in zip(gt, threshs):
            gt_lane = np.array(x_gts)
            # Best point-accuracy that any predicted lane achieves on this GT lane.
            accs = [LaneEval.line_accuracy(np.array(x_preds), gt_lane, thresh)
                    for x_preds in pred]
            max_acc = np.max(accs) if len(accs) > 0 else 0.
            if max_acc < LaneEval.pt_thresh:
                fn += 1
            else:
                matched += 1
            line_accs.append(max_acc)
        # One prediction may be the best match for several GT lanes, so this
        # can be negative. That is kept as-is to match the official benchmark.
        fp = len(pred) - matched
        # Normalisers are capped at 4 lanes. With more than 4 GT lanes, one
        # miss is forgiven and the worst lane accuracy is dropped.
        if len(gt) > 4 and fn > 0:
            fn -= 1
        s = sum(line_accs)
        if len(gt) > 4:
            s -= min(line_accs)
        denom = max(min(4.0, len(gt)), 1.)
        fp_rate = fp / len(pred) if len(pred) > 0 else 0.
        return s / denom, fp_rate, fn / denom

    @staticmethod
    def _load_json_lines(path):
        """Parse a JSON-lines file into a list of objects, skipping blank lines."""
        with open(path) as f:
            return [json.loads(line) for line in f if line.strip()]

    @staticmethod
    def bench_one_submit(pred_file, gt_file):
        """Evaluate a whole prediction file against a ground-truth file.

        Both files hold one JSON object per line; blank lines are ignored.
        Ground-truth objects must provide ``raw_file``, ``lanes`` and
        ``h_samples``. Prediction objects must provide ``raw_file``,
        ``lanes`` and ``run_time``.

        Returns:
            A JSON string: a list of ``{'name', 'value', 'order'}`` entries for
            Accuracy, FP and FN, each averaged over the ground-truth images.

        Raises:
            Exception: if the prediction file cannot be read or parsed, the
                number of predictions differs from the number of GT entries,
                a prediction lacks a required key or names an unknown
                ``raw_file``, lanes are badly formatted, or the ground-truth
                file has no entries.
        """
        try:
            json_pred = LaneEval._load_json_lines(pred_file)
        except Exception as e:
            raise Exception('Fail to load json file of the prediction.') from e
        json_gt = LaneEval._load_json_lines(gt_file)
        if len(json_gt) != len(json_pred):
            raise Exception('We do not get the predictions of all the test tasks')
        gts = {l['raw_file']: l for l in json_gt}
        num = len(gts)
        if num == 0:
            raise Exception('No test tasks found in the ground-truth file.')
        accuracy, fp, fn = 0., 0., 0.
        for pred in json_pred:
            if 'raw_file' not in pred or 'lanes' not in pred or 'run_time' not in pred:
                raise Exception('raw_file or lanes or run_time not in some predictions.')
            raw_file = pred['raw_file']
            pred_lanes = pred['lanes']
            run_time = pred['run_time']
            if raw_file not in gts:
                raise Exception('Some raw_file from your predictions do not exist in the test tasks.')
            gt = gts[raw_file]
            gt_lanes = gt['lanes']
            y_samples = gt['h_samples']
            try:
                a, p, n = LaneEval.bench(pred_lanes, gt_lanes, y_samples, run_time)
            except Exception as e:
                raise Exception('Format of lanes error.') from e
            accuracy += a
            fp += p
            fn += n
        # the first return parameter is the default ranking parameter
        return json.dumps([
            {'name': 'Accuracy', 'value': accuracy / num, 'order': 'desc'},
            {'name': 'FP', 'value': fp / num, 'order': 'asc'},
            {'name': 'FN', 'value': fn / num, 'order': 'asc'}
        ])
if __name__ == '__main__':
    import sys

    # Usage: python <this script> [pred_file [gt_file]]
    # Without arguments, the original hard-coded file names are used.
    args = sys.argv[1:]
    pred_path = args[0] if len(args) > 0 else "test_result.json"
    gt_path = args[1] if len(args) > 1 else "test_label.json"
    print(LaneEval.bench_one_submit(pred_path, gt_path))