Skip to content

Commit

Permalink
Merge pull request #31 from Fraunhofer-IIS/27-include-unit-tests
Browse files Browse the repository at this point in the history
add first unit tests for hcnn and sensitivity analysis
  • Loading branch information
bknico-iis authored Sep 18, 2024
2 parents c9934c1 + 84c6115 commit 38279ca
Show file tree
Hide file tree
Showing 5 changed files with 396 additions and 0 deletions.
103 changes: 103 additions & 0 deletions prosper_nn/utils/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,109 @@ def analyse_temporal_sensitivity(

return torch.stack(total_heat)

def plot_analyse_temporal_sensitivity(
sensis: torch.Tensor,
target_var: List[str],
features: List[str],
n_future_steps: int,
path: Optional[str] = None,
title: Optional[Union[dict, str]] = None,
xticks: Optional[Union[dict, str]] = None,
yticks: Optional[Union[dict, str]] = None,
xlabel: Optional[Union[dict, str]] = None,
ylabel: Optional[Union[dict, str]] = None,
figsize: List[float] = [12.4, 5.8],
) -> None:
"""
Plots a sensitivity analysis and creates a table with monotonie and total heat on the right side
for each task variable.
"""
# Calculate total heat and monotony
total_heat = torch.sum(torch.abs(sensis), dim=2)
total_heat = (total_heat * 100).round() / 100
monotonie = torch.sum(sensis, dim=2) / total_heat
monotonie = (monotonie * 100).round() / 100

plt.rcParams["figure.figsize"] = figsize
### Temporal Sensitivity Heatmap ###
# plot a sensitivity matrix for every feature/target variable to be investigated
for i, node in enumerate(target_var):
# Set description
if not title:
title = "Influence of auxiliary variables on {}"
if not xlabel:
xlabel = "Weeks into future"
if not ylabel:
ylabel = "Auxiliary variables"
if not xticks:
xticks = {
"ticks": range(1, n_future_steps + 1),
"labels": [
str(i) if i % 2 == 1 else None for i in range(1, n_future_steps + 1)
],
"horizontalalignment": "right",
}
if not yticks:
yticks = {
"ticks": range(len(features)),
"labels": [feature.replace("_", " ") for feature in features],
"rotation": 0,
"va": "top",
"size": "large",
}

sns.heatmap(sensis[i],
center=0,
cmap='coolwarm',
robust=True,
cbar_kws={'location':'right', 'pad': 0.22},
)
plt.ylabel(ylabel)
plt.xlabel(xlabel)
plt.xticks(**xticks)
plt.yticks(**yticks),
plt.title(title.format(node.replace("_", " ")), pad=25)

# Fade out row name if total heat is not that strong
for j, ticklabel in enumerate(plt.gca().get_yticklabels()):
if j >= len(target_var):
alpha = float(0.5 + (total_heat[i][j] / torch.max(total_heat)) / 2)
ticklabel.set_color(color=[0, 0, 0, alpha])
else:
ticklabel.set_color(color="C0")
plt.tight_layout()

### Table with total heat and monotonie ###
table_values = torch.stack((total_heat[i], monotonie[i])).T

# Colour of cells
cell_colours = [
["#E1E3E3" for _ in range(table_values.shape[1])]
for _ in range(table_values.shape[0])
]
cell_colours[torch.argmax(table_values, dim=0)[0]][0] = "#179C7D"
cell_colours[torch.argmax(torch.abs(table_values), dim=0)[1]][1] = "#179C7D"

# Plot table
plt.table(
table_values.numpy(),
loc='right',
colLabels=['Absolute', 'Monotony'],
colWidths=[0.2,0.2],
bbox=[1, 0, 0.3, 1.042], #[1, 0, 0.4, 1.042],
cellColours=cell_colours,
edges='BRT',
)
plt.subplots_adjust(left=0.05, right=1.0) # creates space for table

# Save and close
if path:
plt.savefig(
path + "sensi_analysis_{}.png".format(node), bbox_inches="tight"
)
else:
plt.show()
plt.close()

def plot_analyse_temporal_sensitivity(
sensis: torch.Tensor,
Expand Down
Empty file added tests/__init__.py
Empty file.
94 changes: 94 additions & 0 deletions tests/models/hcnn/test_hcnn.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,94 @@
import torch
from prosper_nn.models.hcnn import HCNN
import pytest


class TestHcnn:
@pytest.mark.parametrize(
"n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
[
(10, 4, 5, 2, 1, 0, 1, True, True),
(1, 1, 1, 1, 1, 0, 1, True, True),
(10, 4, 5, 2, 1, 0.5, 1, True, True),
(10, 4, 5, 2, 1, 0, 0.5, False, True),
(10, 4, 5, 2, 1, 0, 0.5, False, False),
(10, 4, 5, 2, 1, 0, 0.5, True, False),
(10, 4, 5, 2, 1, 0, 1, True, True),
],
)
def test_forward(self, n_state_neurons, n_features_Y, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
hcnn = HCNN(
n_state_neurons=n_state_neurons,
n_features_Y=n_features_Y,
past_horizon=past_horizon,
forecast_horizon=forecast_horizon,
sparsity=sparsity,
teacher_forcing=teacher_forcing,
backward_full_Y=backward_full_Y,
ptf_in_backward=ptf_in_backward,
)
observation = torch.zeros(past_horizon, batchsize, n_features_Y)
output_ = hcnn(observation)

assert output_.shape == torch.Size((past_horizon + forecast_horizon, batchsize, n_features_Y))
assert isinstance(output_, torch.Tensor)
assert not (output_.isnan()).any()

@pytest.mark.parametrize(
"n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward",
[
(5, 50, 5, 1, 0, 1, True, True),
],
)
def test_train(self, n_state_neurons, past_horizon, forecast_horizon, batchsize, sparsity, teacher_forcing, backward_full_Y, ptf_in_backward):
n_features_Y = 1
n_epochs = 10000

hcnn = HCNN(
n_state_neurons=n_state_neurons,
n_features_Y=n_features_Y,
past_horizon=past_horizon,
forecast_horizon=forecast_horizon,
sparsity=sparsity,
teacher_forcing=teacher_forcing,
backward_full_Y=backward_full_Y,
ptf_in_backward=ptf_in_backward,
)
observation = torch.zeros(past_horizon, batchsize, n_features_Y)
observation = torch.sin(torch.linspace(0.5, 10 * torch.pi, past_horizon + forecast_horizon))
observation = observation.unsqueeze(1).unsqueeze(1)

optimizer = torch.optim.Adam(hcnn.parameters(), lr=0.001)
target = torch.zeros_like(observation[:past_horizon])
loss_fct = torch.nn.MSELoss()

start_weight = hcnn.HCNNCell.A.weight.clone()

for epoch in range(n_epochs):
output_ = hcnn(observation[:past_horizon])
loss = loss_fct(output_[:past_horizon], target)
loss.backward()
assert hcnn.HCNNCell.A.weight.grad is not None
optimizer.step()
if epoch == 1:
start_loss = loss.detach()
assert (hcnn.HCNNCell.A.weight != start_weight).all()
hcnn.zero_grad()

forecast = hcnn(observation[:past_horizon])[past_horizon:]
assert loss < start_loss
assert torch.isclose(observation[past_horizon:], forecast, atol=1).all()

@pytest.mark.parametrize("teacher_forcing, decrease_teacher_forcing, result", [(1, 0, 1), (1, 0.2, 0.8), (0, 0.1, 0)],)
def test_adjust_teacher_forcing(self, teacher_forcing, decrease_teacher_forcing, result):
hcnn = HCNN(
n_state_neurons=10,
n_features_Y=2,
past_horizon=10,
forecast_horizon=5,
teacher_forcing=teacher_forcing,
decrease_teacher_forcing=decrease_teacher_forcing)
hcnn.adjust_teacher_forcing()
assert hcnn.HCNNCell.teacher_forcing == result
assert hcnn.teacher_forcing == result
116 changes: 116 additions & 0 deletions tests/models/hcnn/test_hcnn_cell.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,116 @@
import torch
from prosper_nn.models.hcnn.hcnn_cell import HCNNCell, PartialTeacherForcing


class TestPartialTeacherForcing:
ptf = PartialTeacherForcing(p=0.5)

def test_evaluation(self):
self.ptf.eval()
input = torch.randn((20, 1, 100))

output = self.ptf(input)
# fill dropped nodes
output = torch.where(output == 0, input, output)
assert (output == input).all()

def test_train(self):
self.ptf.train()
input = torch.randn((20, 1, 100))

output = self.ptf(input)
# fill dropped nodes
output = torch.where(output == 0, input, output)
assert (output == input).all()


class TestHcnnCell:
n_state_neurons = 10
n_features_Y = 5
batchsize = 7

hcnn_cell = HCNNCell(
n_state_neurons=n_state_neurons,
n_features_Y=n_features_Y,
)
hcnn_cell.A.weight = torch.nn.Parameter(torch.ones_like(hcnn_cell.A.weight))
state = 0.5 * torch.ones((batchsize, n_state_neurons))
expectation = state[..., :n_features_Y]
observation = torch.ones(batchsize, n_features_Y)

def test_get_teacher_forcing_full_Y(self):
self.hcnn_cell.ptf_dropout.p = 0
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
self.observation, self.expectation
)
self.checks_get_teacher_forcing(output_, teacher_forcing_)

### with partial teacher forcing
self.hcnn_cell.ptf_dropout.p = 0.5
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_full_Y(
self.observation, self.expectation
)

# fill dropped nodes
teacher_forcing_[..., : self.n_features_Y] = torch.where(
teacher_forcing_[..., : self.n_features_Y] == 0,
-0.5,
teacher_forcing_[..., : self.n_features_Y],
)

self.checks_get_teacher_forcing(output_, teacher_forcing_)

def test_get_teacher_forcing_partial_Y(self):
self.hcnn_cell.ptf_dropout.p = 0
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
self.observation, self.expectation
)
self.checks_get_teacher_forcing(output_, teacher_forcing_)

### with partial teacher forcing
self.hcnn_cell.ptf_dropout.p = 0.5
output_, teacher_forcing_ = self.hcnn_cell.get_teacher_forcing_partial_Y(
self.observation, self.expectation
)
# fill dropped nodes
teacher_forcing_[..., : self.n_features_Y] = torch.where(
teacher_forcing_[..., : self.n_features_Y] == 0,
-0.5,
teacher_forcing_[..., : self.n_features_Y],
)
output_ = torch.where(output_ == 0, -0.5, output_)
self.checks_get_teacher_forcing(output_, teacher_forcing_)

def checks_get_teacher_forcing(self, output_, teacher_forcing_):
assert (output_ == -0.5 * torch.ones(self.batchsize, self.n_features_Y)).all()
assert (teacher_forcing_[..., : self.n_features_Y] == -self.expectation).all()
assert (teacher_forcing_[..., self.n_features_Y :] == 0).all()
assert (
(self.expectation - teacher_forcing_[..., : self.n_features_Y])
== self.observation
).all()

def test_forward(self):
state_, output_ = self.hcnn_cell.forward(self.state)
self.checks_forward(state_, output_)

state_, output_ = self.hcnn_cell.forward(self.state, self.observation)
self.checks_forward(state_, output_)

def test_forward_past_horizon(self):
state_, output_ = self.hcnn_cell.forward_past_horizon(
self.state, self.observation, self.expectation
)
self.checks_forward(state_, output_)

def test_forward_forecast_horizon(self):
state_, output_ = self.hcnn_cell.forward_forecast_horizon(
self.state, self.expectation
)
self.checks_forward(state_, output_)

def checks_forward(self, state_, output_):
assert state_.shape == torch.Size((self.batchsize, self.n_state_neurons))
assert output_.shape == torch.Size((self.batchsize, self.n_features_Y))
assert not (state_.isnan()).any()
assert not (output_.isnan()).any()
Loading

0 comments on commit 38279ca

Please sign in to comment.