Skip to content

Commit

Permalink
add missing plot function
Browse files Browse the repository at this point in the history
  • Loading branch information
Julia Schemm committed Aug 16, 2024
1 parent 75330ea commit 84c6115
Showing 1 changed file with 103 additions and 0 deletions.
103 changes: 103 additions & 0 deletions prosper_nn/utils/sensitivity_analysis.py
Original file line number Diff line number Diff line change
Expand Up @@ -254,6 +254,109 @@ def analyse_temporal_sensitivity(

return torch.stack(total_heat)

def plot_analyse_temporal_sensitivity(
sensis: torch.Tensor,
target_var: List[str],
features: List[str],
n_future_steps: int,
path: Optional[str] = None,
title: Optional[Union[dict, str]] = None,
xticks: Optional[Union[dict, str]] = None,
yticks: Optional[Union[dict, str]] = None,
xlabel: Optional[Union[dict, str]] = None,
ylabel: Optional[Union[dict, str]] = None,
figsize: List[float] = [12.4, 5.8],
) -> None:
"""
Plots a sensitivity analysis and creates a table with monotonie and total heat on the right side
for each task variable.
"""
# Calculate total heat and monotony
total_heat = torch.sum(torch.abs(sensis), dim=2)
total_heat = (total_heat * 100).round() / 100
monotonie = torch.sum(sensis, dim=2) / total_heat
monotonie = (monotonie * 100).round() / 100

plt.rcParams["figure.figsize"] = figsize
### Temporal Sensitivity Heatmap ###
# plot a sensitivity matrix for every feature/target variable to be investigated
for i, node in enumerate(target_var):
# Set description
if not title:
title = "Influence of auxiliary variables on {}"
if not xlabel:
xlabel = "Weeks into future"
if not ylabel:
ylabel = "Auxiliary variables"
if not xticks:
xticks = {
"ticks": range(1, n_future_steps + 1),
"labels": [
str(i) if i % 2 == 1 else None for i in range(1, n_future_steps + 1)
],
"horizontalalignment": "right",
}
if not yticks:
yticks = {
"ticks": range(len(features)),
"labels": [feature.replace("_", " ") for feature in features],
"rotation": 0,
"va": "top",
"size": "large",
}

sns.heatmap(sensis[i],
center=0,
cmap='coolwarm',
robust=True,
cbar_kws={'location':'right', 'pad': 0.22},
)
plt.ylabel(ylabel)
plt.xlabel(xlabel)
plt.xticks(**xticks)
plt.yticks(**yticks),
plt.title(title.format(node.replace("_", " ")), pad=25)

# Fade out row name if total heat is not that strong
for j, ticklabel in enumerate(plt.gca().get_yticklabels()):
if j >= len(target_var):
alpha = float(0.5 + (total_heat[i][j] / torch.max(total_heat)) / 2)
ticklabel.set_color(color=[0, 0, 0, alpha])
else:
ticklabel.set_color(color="C0")
plt.tight_layout()

### Table with total heat and monotonie ###
table_values = torch.stack((total_heat[i], monotonie[i])).T

# Colour of cells
cell_colours = [
["#E1E3E3" for _ in range(table_values.shape[1])]
for _ in range(table_values.shape[0])
]
cell_colours[torch.argmax(table_values, dim=0)[0]][0] = "#179C7D"
cell_colours[torch.argmax(torch.abs(table_values), dim=0)[1]][1] = "#179C7D"

# Plot table
plt.table(
table_values.numpy(),
loc='right',
colLabels=['Absolute', 'Monotony'],
colWidths=[0.2,0.2],
bbox=[1, 0, 0.3, 1.042], #[1, 0, 0.4, 1.042],
cellColours=cell_colours,
edges='BRT',
)
plt.subplots_adjust(left=0.05, right=1.0) # creates space for table

# Save and close
if path:
plt.savefig(
path + "sensi_analysis_{}.png".format(node), bbox_inches="tight"
)
else:
plt.show()
plt.close()

# %% Sensitivity for feed-forward models and other not-recurrent models

Expand Down

0 comments on commit 84c6115

Please sign in to comment.