-
Notifications
You must be signed in to change notification settings - Fork 3
/
kernel_reg.rs
118 lines (90 loc) · 2.99 KB
/
kernel_reg.rs
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
use pointprocesses::estimators::nadarayawatson;
use pointprocesses::estimators::kernels;
use rand::prelude::*;
use std::fs;
use plotters::prelude::*;
use ndarray::Array1;
// Font family used for the chart caption in `main`.
static TITLE_FONT: &str = "Arial";
// Dimensions handed to the bitmap backend for the output image
// (presumably width x height in pixels -- confirm against the plotters API).
static IMG_SIZE: (u32, u32) = (640, 480);
/// Example: Nadaraya-Watson kernel regression on noisy synthetic data.
///
/// Steps performed:
/// 1. Sample a known regression function on a grid and add Gaussian noise.
/// 2. Draw the noisy observations and the noiseless reference curve.
/// 3. Fit a Nadaraya-Watson estimator with a Gaussian kernel and draw its
///    predictions on a dense grid.
///
/// The chart is written to `lib/examples/images/nadwat_estimator.png`
/// (relative to the current working directory).
///
/// # Panics
///
/// Panics if any of the drawing calls fails (they are `unwrap`ped).
fn main() {
    use kernels::*;
    use nadarayawatson::*;
    use ndarray::{stack, Axis};
    use std::f64::consts::PI;

    // Actual regression function
    let func = |x: &f64| (2. * PI * x + 0.1).sin() + 1.5 * x + 1.0 * x * x;

    // Noisy regression data: the 20-point design grid is appended to itself
    // five times, so every design point carries six noisy observations.
    let x_arr_init = Array1::linspace(0., 1., 20);
    let mut x_arr = x_arr_init.clone();
    for _ in 0..5 {
        x_arr = stack![Axis(0), x_arr, x_arr_init];
    }

    let mut rng = thread_rng();
    let normal = rand_distr::StandardNormal;
    let sigma = 0.4;
    let mut z_arr = x_arr.map(func);
    z_arr.mapv_inplace(|y| {
        let eps: f64 = normal.sample(&mut rng);
        y + sigma * eps
    });

    // Setup chart.
    // Create the directory the image is actually written to (including any
    // missing parents). This stays best-effort: a failure is reported on
    // stderr instead of being silently discarded, and does not abort here.
    let output_path = "lib/examples/images/nadwat_estimator.png";
    if let Some(output_dir) = std::path::Path::new(output_path).parent() {
        if let Err(err) = fs::create_dir_all(output_dir) {
            eprintln!(
                "warning: could not create output directory {}: {}",
                output_dir.display(),
                err
            );
        }
    }

    let root = BitMapBackend::new(output_path, IMG_SIZE).into_drawing_area();
    root.fill(&WHITE).unwrap();

    let caption = "Nadaraya-Watson estimator (Gaussian kernel)";
    let mut chart = ChartBuilder::on(&root)
        .caption(caption, (TITLE_FONT, 20).into_font())
        .margin(10)
        .x_label_area_size(30)
        .y_label_area_size(30)
        .build_ranged(-0.05..1.05, -1.0..3.5)
        .unwrap();

    chart.configure_mesh().draw().unwrap();

    // Noisy observations, drawn as small filled red circles.
    let noisy_data = x_arr.iter().zip(z_arr.iter());
    let size: u32 = 2;
    chart
        .draw_series(noisy_data.map(|(x, y)| Circle::new((*x, *y), size, RED.filled())))
        .unwrap();

    // Reference function evaluated on a dense grid.
    let x_dense_arr = Array1::linspace(0., 1., 50);
    let y_arr = x_dense_arr.map(func);
    let data_reference_func = x_dense_arr
        .iter()
        .zip(y_arr.iter())
        .map(|(x, y)| (*x, *y));
    let line = LineSeries::new(data_reference_func, &BLUE);
    chart
        .draw_series(line)
        .unwrap()
        .label("Reference function")
        .legend(|(x, y)| Path::new(vec![(x, y), (x + 20, y)], &BLUE));

    // Create and fit the NW estimator
    let bandwidth = 0.05;
    let kernel = GaussianKernel::new(bandwidth);
    let estimator = NadWatEstimator::new(kernel).fit(&x_arr, &z_arr);

    // Predict on the same dense grid as the reference curve; the grid is
    // only read, so no copy of it is needed.
    let y0_predict = x_dense_arr.map(|x0| estimator.predict(*x0));
    let predict_data = x_dense_arr
        .iter()
        .zip(y0_predict.iter())
        .map(|(x, y)| (*x, *y));
    let line_predict = LineSeries::new(predict_data, &BLACK);
    chart
        .draw_series(line_predict)
        .unwrap()
        .label("NW estimates.")
        .legend(|(x, y)| Path::new(vec![(x, y), (x + 20, y)], &BLACK));

    chart
        .configure_series_labels()
        .background_style(&WHITE.mix(0.8))
        .border_style(&BLACK)
        .draw()
        .unwrap();
}