add first unit tests for hcnn and sensitivity analysis
Julia Schemm committed Aug 16, 2024
1 parent 247c20b · commit 1b6debc
Showing 4 changed files with 293 additions and 0 deletions.
The first of the four changed files is empty.
@@ -0,0 +1,94 @@
import pytest
import torch

from prosper_nn.models.hcnn import HCNN


class TestHcnn:
    @pytest.mark.parametrize(
        "n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
        [
            (10, 4, 5, 2, 1, 0, 1, True, True),
            (1, 1, 1, 1, 1, 0, 1, True, True),
            (10, 4, 5, 2, 1, 0.5, 1, True, True),
            (10, 4, 5, 2, 1, 0, 0.5, False, True),
            (10, 4, 5, 2, 1, 0, 0.5, False, False),
            (10, 4, 5, 2, 1, 0, 0.5, True, False),
        ],
    )
    def test_forward(
        self,
        n_state_neurons,
        n_features_Y,
        past_horizon,
        forecast_horizon,
        batchsize,
        sparsity,
        teacher_forcing,
        backward_full_Y,
        ptf_in_backward,
    ):
        hcnn = HCNN(
            n_state_neurons=n_state_neurons,
            n_features_Y=n_features_Y,
            past_horizon=past_horizon,
            forecast_horizon=forecast_horizon,
            sparsity=sparsity,
            teacher_forcing=teacher_forcing,
            backward_full_Y=backward_full_Y,
            ptf_in_backward=ptf_in_backward,
        )
        observation = torch.zeros(past_horizon, batchsize, n_features_Y)
        output_ = hcnn(observation)

        assert output_.shape == torch.Size(
            (past_horizon + forecast_horizon, batchsize, n_features_Y)
        )
        assert isinstance(output_, torch.Tensor)
        assert not output_.isnan().any()

    @pytest.mark.parametrize(
        "n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
        [
            (5, 50, 5, 1, 0, 1, True, True),
        ],
    )
    def test_train(
        self,
        n_state_neurons,
        past_horizon,
        forecast_horizon,
        batchsize,
        sparsity,
        teacher_forcing,
        backward_full_Y,
        ptf_in_backward,
    ):
        n_features_Y = 1
        n_epochs = 10000

        hcnn = HCNN(
            n_state_neurons=n_state_neurons,
            n_features_Y=n_features_Y,
            past_horizon=past_horizon,
            forecast_horizon=forecast_horizon,
            sparsity=sparsity,
            teacher_forcing=teacher_forcing,
            backward_full_Y=backward_full_Y,
            ptf_in_backward=ptf_in_backward,
        )
        # A sine wave covering past and forecast horizon, shaped
        # (past_horizon + forecast_horizon, batchsize=1, n_features_Y=1).
        observation = torch.sin(
            torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon)
        )
        observation = observation.unsqueeze(1).unsqueeze(1)

        optimizer = torch.optim.Adam(hcnn.parameters(), lr=0.001)
        # On the past horizon the HCNN outputs error terms, so the training
        # target is zero.
        target = torch.zeros_like(observation[:past_horizon])
        loss_fct = torch.nn.MSELoss()

        start_weight = hcnn.HCNNCell.A.weight.clone()

        for epoch in range(n_epochs):
            output_ = hcnn(observation[:past_horizon])
            loss = loss_fct(output_[:past_horizon], target)
            loss.backward()
            assert hcnn.HCNNCell.A.weight.grad is not None
            optimizer.step()
            if epoch == 1:
                start_loss = loss.detach()
                assert (hcnn.HCNNCell.A.weight != start_weight).all()
            hcnn.zero_grad()

        forecast = hcnn(observation[:past_horizon])[past_horizon:]
        assert loss < start_loss
        assert torch.isclose(observation[past_horizon:], forecast, atol=1).all()

    @pytest.mark.parametrize(
        "teacher_forcing, decrease_teacher_forcing, result",
        [(1, 0, 1), (1, 0.2, 0.8), (0, 0.1, 0)],
    )
    def test_adjust_teacher_forcing(
        self, teacher_forcing, decrease_teacher_forcing, result
    ):
        hcnn = HCNN(
            n_state_neurons=10,
            n_features_Y=2,
            past_horizon=10,
            forecast_horizon=5,
            teacher_forcing=teacher_forcing,
            decrease_teacher_forcing=decrease_teacher_forcing,
        )
        hcnn.adjust_teacher_forcing()
        assert hcnn.HCNNCell.teacher_forcing == result
        assert hcnn.teacher_forcing == result
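
For orientation, the training pattern that test_train exercises, condensed into a standalone sketch (assuming prosper_nn is installed; the epoch count and learning rate are illustrative, not library defaults):

import torch
from prosper_nn.models.hcnn import HCNN

past_horizon, forecast_horizon = 50, 5
model = HCNN(
    n_state_neurons=5,
    n_features_Y=1,
    past_horizon=past_horizon,
    forecast_horizon=forecast_horizon,
)
# Toy series shaped (time, batch, features), as in the test above.
series = torch.sin(
    torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon)
).reshape(-1, 1, 1)

optimizer = torch.optim.Adam(model.parameters(), lr=0.001)
# Past-horizon outputs are error terms trained toward zero; the forecast is
# read off the time steps after past_horizon.
target = torch.zeros(past_horizon, 1, 1)
for _ in range(10_000):
    output = model(series[:past_horizon])
    loss = torch.nn.functional.mse_loss(output[:past_horizon], target)
    loss.backward()
    optimizer.step()
    model.zero_grad()

forecast = model(series[:past_horizon])[past_horizon:]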
@@ -0,0 +1,116 @@
import torch

from prosper_nn.models.hcnn.hcnn_cell import HCNNCell, PartialTeacherForcing


class TestPartialTeacherForcing:
    ptf = PartialTeacherForcing(p=0.5)

    def test_evaluation(self):
        self.ptf.eval()
        input = torch.randn((20, 1, 100))

        output = self.ptf(input)
        # fill dropped nodes
        output = torch.where(output == 0, input, output)
        assert (output == input).all()

    def test_train(self):
        self.ptf.train()
        input = torch.randn((20, 1, 100))

        output = self.ptf(input)
        # fill dropped nodes
        output = torch.where(output == 0, input, output)
        assert (output == input).all()


class TestHcnnCell:
    n_state_neurons = 10
    n_features_Y = 5
    batchsize = 7

    hcnn_cell = HCNNCell(
        n_state_neurons=n_state_neurons,
        n_features_Y=n_features_Y,
    )
    hcnn_cell.A.weight = torch.nn.Parameter(torch.ones_like(hcnn_cell.A.weight))
    state = 0.5 * torch.ones((batchsize, n_state_neurons))
    expectation = state[..., :n_features_Y]
    observation = torch.ones(batchsize, n_features_Y)

    def test_get_teacher_forcing_full_Y(self):
        self.hcnn_cell.ptf_dropout.p = 0
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
            self.observation, self.expectation
        )
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

        # with partial teacher forcing
        self.hcnn_cell.ptf_dropout.p = 0.5
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
            self.observation, self.expectation
        )

        # fill dropped nodes
        teacher_forcing_[..., : self.n_features_Y] = torch.where(
            teacher_forcing_[..., : self.n_features_Y] == 0,
            -0.5,
            teacher_forcing_[..., : self.n_features_Y],
        )

        self.checks_get_teacher_forcing(output_, teacher_forcing_)

    def test_get_teacher_forcing_partial_Y(self):
        self.hcnn_cell.ptf_dropout.p = 0
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
            self.observation, self.expectation
        )
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

        # with partial teacher forcing
        self.hcnn_cell.ptf_dropout.p = 0.5
        output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
            self.observation, self.expectation
        )
        # fill dropped nodes
        teacher_forcing_[..., : self.n_features_Y] = torch.where(
            teacher_forcing_[..., : self.n_features_Y] == 0,
            -0.5,
            teacher_forcing_[..., : self.n_features_Y],
        )
        output_ = torch.where(output_ == 0, -0.5, output_)
        self.checks_get_teacher_forcing(output_, teacher_forcing_)

    def checks_get_teacher_forcing(self, output_, teacher_forcing_):
        assert (output_ == -0.5 * torch.ones(self.batchsize, self.n_features_Y)).all()
        assert (teacher_forcing_[..., : self.n_features_Y] == -self.expectation).all()
        assert (teacher_forcing_[..., self.n_features_Y :] == 0).all()
        assert (
            (self.expectation - teacher_forcing_[..., : self.n_features_Y])
            == self.observation
        ).all()

    def test_forward(self):
        state_, output_ = self.hcnn_cell.forward(self.state)
        self.checks_forward(state_, output_)

        state_, output_ = self.hcnn_cell.forward(self.state, self.observation)
        self.checks_forward(state_, output_)

    def test_forward_past_horizon(self):
        state_, output_ = self.hcnn_cell.forward_past_horizon(
            self.state, self.observation, self.expectation
        )
        self.checks_forward(state_, output_)

    def test_forward_forecast_horizon(self):
        state_, output_ = self.hcnn_cell.forward_forecast_horizon(
            self.state, self.expectation
        )
        self.checks_forward(state_, output_)

    def checks_forward(self, state_, output_):
        assert state_.shape == torch.Size((self.batchsize, self.n_state_neurons))
        assert output_.shape == torch.Size((self.batchsize, self.n_features_Y))
        assert not state_.isnan().any()
        assert not output_.isnan().any()
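
The "fill dropped nodes" pattern used throughout this file, shown in isolation: a minimal sketch with a hand-made mask, relying on the same assumptions the assertions above make (dropped entries are exactly zero, surviving entries pass through unscaled):

import torch

x = torch.randn(4, 3)
mask = torch.bernoulli(0.5 * torch.ones_like(x))  # hand-made drop mask
dropped = x * mask                                # zeroed entries stand in for dropped nodes
restored = torch.where(dropped == 0, x, dropped)  # fill the zeros back from x
assert (restored == x).all()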
@@ -0,0 +1,83 @@
import torch

from prosper_nn.models.hcnn import HCNN
from prosper_nn.utils import sensitivity_analysis


def test_sensitivity_analysis():
    in_features = 10
    out_features = 5
    batchsize = 3
    n_batches = 4
    model = torch.nn.Linear(in_features=in_features, out_features=out_features)
    data = torch.randn(n_batches, batchsize, in_features)
    sensi = sensitivity_analysis.sensitivity_analysis(
        model, data=data, output_neuron=(slice(0, batchsize), 0), batchsize=batchsize
    )
    assert isinstance(sensi, torch.Tensor)
    assert sensi.shape == torch.Size((batchsize * n_batches, in_features))


def test_calculate_sensitivity_analysis():
    in_features = 10
    out_features = 5
    batchsize = 3
    n_batches = 4
    model = torch.nn.Linear(in_features=in_features, out_features=out_features)
    data = torch.randn(n_batches, batchsize, in_features)
    sensi = sensitivity_analysis.calculate_sensitivity_analysis(
        model, data, output_neuron=(slice(0, batchsize), 0), batchsize=batchsize
    )
    sensi = sensi.reshape((sensi.shape[0], -1))
    assert isinstance(sensi, torch.Tensor)
    assert sensi.shape == torch.Size((batchsize * n_batches, in_features))


def test_plot_sensitivity_curve():
    in_features = 5
    samples = 10
    sensi = torch.randn(samples, in_features)
    sensitivity_analysis.plot_sensitivity_curve(sensi, output_neuron=1)


def test_analyse_temporal_sensitivity():
    n_features_Y = 3
    n_state_neurons = 5
    batchsize = 2
    past_horizon = 4
    forecast_horizon = 3
    task_nodes = [0, 1]

    model = HCNN(
        n_features_Y=n_features_Y,
        n_state_neurons=n_state_neurons,
        past_horizon=past_horizon,
        forecast_horizon=forecast_horizon,
    )
    data = torch.randn(past_horizon, batchsize, n_features_Y)

    sensi = sensitivity_analysis.analyse_temporal_sensitivity(
        model,
        data=data,
        task_nodes=task_nodes,
        n_future_steps=forecast_horizon,
        past_horizon=past_horizon,
        n_features=n_features_Y,
    )
    assert isinstance(sensi, torch.Tensor)
    assert sensi.shape == torch.Size((len(task_nodes), forecast_horizon, n_features_Y))


def test_plot_analyse_temporal_sensitivity():
    n_features_Y = 3
    n_target_vars = 2
    forecast_horizon = 3
    target_var = [f"target_var_{i}" for i in range(n_target_vars)]
    features = [f"feat_{i}" for i in range(n_features_Y)]

    sensis = torch.randn(len(target_var), forecast_horizon, n_features_Y)
    sensitivity_analysis.plot_analyse_temporal_sensitivity(
        sensis,
        target_var,
        features,
        n_future_steps=forecast_horizon,
    )
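
For intuition about what these shape checks describe: assuming the utility computes the gradient of a chosen output neuron with respect to the inputs (a hand-rolled sketch, not the library call), the core computation for the Linear case above would be:

import torch

model = torch.nn.Linear(10, 5)
x = torch.randn(3, 10, requires_grad=True)  # (batchsize, in_features)
y = model(x)
y[:, 0].sum().backward()  # output_neuron=(slice(0, batchsize), 0) above
sensitivity = x.grad      # shape (3, 10): d(output neuron 0) / d(input)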