-
Notifications
You must be signed in to change notification settings - Fork 3
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #31 from Fraunhofer-IIS/27-include-unit-tests
add first unit tests for hcnn and sensitivity analysis
- Loading branch information
Showing
5 changed files
with
396 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,94 @@ | ||
import torch | ||
from prosper_nn.models.hcnn import HCNN | ||
import pytest | ||
|
||
|
||
class TestHcnn: | ||
@pytest.mark.parametrize( | ||
"n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward", | ||
[ | ||
(10, 4, 5, 2, 1, 0, 1, True, True), | ||
(1, 1, 1, 1, 1, 0, 1, True, True), | ||
(10, 4, 5, 2, 1, 0.5, 1, True, True), | ||
(10, 4, 5, 2, 1, 0, 0.5, False, True), | ||
(10, 4, 5, 2, 1, 0, 0.5, False, False), | ||
(10, 4, 5, 2, 1, 0, 0.5, True, False), | ||
(10, 4, 5, 2, 1, 0, 1, True, True), | ||
], | ||
) | ||
def test_forward(self, n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward): | ||
hcnn = HCNN( | ||
n_state_neurons=n_state_neurons, | ||
n_features_Y=n_features_Y, | ||
past_horizon=past_horizon, | ||
forecast_horizon=forecast_horizon, | ||
sparsity=sparsity, | ||
teacher_forcing=teacher_forcing, | ||
backward_full_Y=backward_full_Y, | ||
ptf_in_backward=ptf_in_backward, | ||
) | ||
observation = torch.zeros(past_horizon, batchsize, n_features_Y) | ||
output_ = hcnn(observation) | ||
|
||
assert output_.shape == torch.Size((past_horizon + forecast_horizon, batchsize, n_features_Y)) | ||
assert isinstance(output_, torch.Tensor) | ||
assert not (output_.isnan()).any() | ||
|
||
@pytest.mark.parametrize( | ||
"n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward", | ||
[ | ||
(5, 50, 5, 1, 0, 1, True, True), | ||
], | ||
) | ||
def test_train(self, n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward): | ||
n_features_Y = 1 | ||
n_epochs = 10000 | ||
|
||
hcnn = HCNN( | ||
n_state_neurons=n_state_neurons, | ||
n_features_Y=n_features_Y, | ||
past_horizon=past_horizon, | ||
forecast_horizon=forecast_horizon, | ||
sparsity=sparsity, | ||
teacher_forcing=teacher_forcing, | ||
backward_full_Y=backward_full_Y, | ||
ptf_in_backward=ptf_in_backward, | ||
) | ||
observation = torch.zeros(past_horizon, batchsize, n_features_Y) | ||
observation = torch.sin(torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon)) | ||
observation = observation.unsqueeze(1).unsqueeze(1) | ||
|
||
optimizer = torch.optim.Adam(hcnn.parameters(), lr=0.001) | ||
target = torch.zeros_like(observation[:past_horizon]) | ||
loss_fct = torch.nn.MSELoss() | ||
|
||
start_weight = hcnn.HCNNCell.A.weight.clone() | ||
|
||
for epoch in range(n_epochs): | ||
output_ = hcnn(observation[:past_horizon]) | ||
loss = loss_fct(output_[:past_horizon], target) | ||
loss.backward() | ||
assert hcnn.HCNNCell.A.weight.grad is not None | ||
optimizer.step() | ||
if epoch == 1: | ||
start_loss = loss.detach() | ||
assert (hcnn.HCNNCell.A.weight != start_weight).all() | ||
hcnn.zero_grad() | ||
|
||
forecast = hcnn(observation[:past_horizon])[past_horizon:] | ||
assert loss < start_loss | ||
assert torch.isclose(observation[past_horizon:], forecast, atol=1).all() | ||
|
||
@pytest.mark.parametrize("teacher_forcing, decrease_teacher_forcing, result", [(1, 0, 1), (1, 0.2, 0.8), (0, 0.1, 0)],) | ||
def test_adjust_teacher_forcing(self, teacher_forcing, decrease_teacher_forcing, result): | ||
hcnn = HCNN( | ||
n_state_neurons=10, | ||
n_features_Y=2, | ||
past_horizon=10, | ||
forecast_horizon=5, | ||
teacher_forcing=teacher_forcing, | ||
decrease_teacher_forcing=decrease_teacher_forcing) | ||
hcnn.adjust_teacher_forcing() | ||
assert hcnn.HCNNCell.teacher_forcing == result | ||
assert hcnn.teacher_forcing == result |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,116 @@ | ||
import torch | ||
from prosper_nn.models.hcnn.hcnn_cell import HCNNCell, PartialTeacherForcing | ||
|
||
|
||
class TestPartialTeacherForcing: | ||
ptf = PartialTeacherForcing(p=0.5) | ||
|
||
def test_evaluation(self): | ||
self.ptf.eval() | ||
input = torch.randn((20, 1, 100)) | ||
|
||
output = self.ptf(input) | ||
# fill dropped nodes | ||
output = torch.where(output == 0, input, output) | ||
assert (output == input).all() | ||
|
||
def test_train(self): | ||
self.ptf.train() | ||
input = torch.randn((20, 1, 100)) | ||
|
||
output = self.ptf(input) | ||
# fill dropped nodes | ||
output = torch.where(output == 0, input, output) | ||
assert (output == input).all() | ||
|
||
|
||
class TestHcnnCell: | ||
n_state_neurons = 10 | ||
n_features_Y = 5 | ||
batchsize = 7 | ||
|
||
hcnn_cell = HCNNCell( | ||
n_state_neurons=n_state_neurons, | ||
n_features_Y=n_features_Y, | ||
) | ||
hcnn_cell.A.weight = torch.nn.Parameter(torch.ones_like(hcnn_cell.A.weight)) | ||
state = 0.5 * torch.ones((batchsize, n_state_neurons)) | ||
expectation = state[..., :n_features_Y] | ||
observation = torch.ones(batchsize, n_features_Y) | ||
|
||
def test_get_teacher_forcing_full_Y(self): | ||
self.hcnn_cell.ptf_dropout.p = 0 | ||
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y( | ||
self.observation, self.expectation | ||
) | ||
self.checks_get_teacher_forcing(output_, teacher_forcing_) | ||
|
||
### with partial teacher forcing | ||
self.hcnn_cell.ptf_dropout.p = 0.5 | ||
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y( | ||
self.observation, self.expectation | ||
) | ||
|
||
# fill dropped nodes | ||
teacher_forcing_[..., : self.n_features_Y] = torch.where( | ||
teacher_forcing_[..., : self.n_features_Y] == 0, | ||
-0.5, | ||
teacher_forcing_[..., : self.n_features_Y], | ||
) | ||
|
||
self.checks_get_teacher_forcing(output_, teacher_forcing_) | ||
|
||
def test_get_teacher_forcing_partial_Y(self): | ||
self.hcnn_cell.ptf_dropout.p = 0 | ||
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y( | ||
self.observation, self.expectation | ||
) | ||
self.checks_get_teacher_forcing(output_, teacher_forcing_) | ||
|
||
### with partial teacher forcing | ||
self.hcnn_cell.ptf_dropout.p = 0.5 | ||
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y( | ||
self.observation, self.expectation | ||
) | ||
# fill dropped nodes | ||
teacher_forcing_[..., : self.n_features_Y] = torch.where( | ||
teacher_forcing_[..., : self.n_features_Y] == 0, | ||
-0.5, | ||
teacher_forcing_[..., : self.n_features_Y], | ||
) | ||
output_ = torch.where(output_ == 0, -0.5, output_) | ||
self.checks_get_teacher_forcing(output_, teacher_forcing_) | ||
|
||
def checks_get_teacher_forcing(self, output_, teacher_forcing_): | ||
assert (output_ == -0.5 * torch.ones(self.batchsize, self.n_features_Y)).all() | ||
assert (teacher_forcing_[..., : self.n_features_Y] == -self.expectation).all() | ||
assert (teacher_forcing_[..., self.n_features_Y :] == 0).all() | ||
assert ( | ||
(self.expectation - teacher_forcing_[..., : self.n_features_Y]) | ||
== self.observation | ||
).all() | ||
|
||
def test_forward(self): | ||
state_, output_ = self.hcnn_cell.forward(self.state) | ||
self.checks_forward(state_, output_) | ||
|
||
state_, output_ = self.hcnn_cell.forward(self.state, self.observation) | ||
self.checks_forward(state_, output_) | ||
|
||
def test_forward_past_horizon(self): | ||
state_, output_ = self.hcnn_cell.forward_past_horizon( | ||
self.state, self.observation, self.expectation | ||
) | ||
self.checks_forward(state_, output_) | ||
|
||
def test_forward_forecast_horizon(self): | ||
state_, output_ = self.hcnn_cell.forward_forecast_horizon( | ||
self.state, self.expectation | ||
) | ||
self.checks_forward(state_, output_) | ||
|
||
def checks_forward(self, state_, output_): | ||
assert state_.shape == torch.Size((self.batchsize, self.n_state_neurons)) | ||
assert output_.shape == torch.Size((self.batchsize, self.n_features_Y)) | ||
assert not (state_.isnan()).any() | ||
assert not (output_.isnan()).any() |
Oops, something went wrong.