Skip to content

Commit

Permalink
Minimum delta value of 1km power spectrum
Browse files: browse the repository at this point in the history
  • Loading branch information
OpheliaMiralles committed Nov 22, 2024
1 parent 162b58f commit 477634c
Showing 1 changed file with 41 additions and 149 deletions.
190 changes: 41 additions & 149 deletions src/anemoi/training/diagnostics/plots.py
Original file line number | Diff line number | Diff line change
Expand Up @@ -18,11 +18,15 @@
import numpy as np
from anemoi.models.layers.mapper import GraphEdgeMixin
from matplotlib.collections import LineCollection
from matplotlib.colors import BoundaryNorm, ListedColormap, TwoSlopeNorm
from pyshtools.expand import SHGLQ, SHExpandGLQ
from matplotlib.colors import BoundaryNorm
from matplotlib.colors import ListedColormap
from matplotlib.colors import TwoSlopeNorm
from pyshtools.expand import SHGLQ
from pyshtools.expand import SHExpandGLQ
from scipy.interpolate import griddata

from anemoi.training.diagnostics.maps import Coastlines, EquirectangularProjection
from anemoi.training.diagnostics.maps import Coastlines
from anemoi.training.diagnostics.maps import EquirectangularProjection

if TYPE_CHECKING:
from matplotlib.figure import Figure
Expand Down Expand Up @@ -172,29 +176,20 @@ def plot_power_spectrum(
# Calculate delta_lon and delta_lat on the projected grid
delta_lon = abs(np.diff(pc_lon))
non_zero_delta_lon = delta_lon[delta_lon != 0]
delta_lat = abs(np.diff(pc_lat))
delta_lat = abs(np.diff(pc_lat))
non_zero_delta_lat = delta_lat[delta_lat != 0]
min_delta_lon = max(0.0003, np.min(abs(non_zero_delta_lon)))
min_delta_lat = max(0.0003, np.min(abs(non_zero_delta_lat)))

# Define a regular grid for interpolation
n_pix_lon = max(
int(
np.floor(abs(pc_lon.max() - pc_lon.min()) / abs(np.min(non_zero_delta_lon)))
),
1414,
) # around 400 for O96
n_pix_lat = max(
int(
np.floor(abs(pc_lat.max() - pc_lat.min()) / abs(np.min(non_zero_delta_lat)))
),
955,
) # around 192 for O96
n_pix_lon = int(np.floor(abs(pc_lon.max() - pc_lon.min()) / min_delta_lon)) # around 400 for O96
n_pix_lat = int(np.floor(abs(pc_lat.max() - pc_lat.min()) / min_delta_lat)) # around 192 for O96
regular_pc_lon = np.linspace(pc_lon.min(), pc_lon.max(), n_pix_lon)
regular_pc_lat = np.linspace(pc_lat.min(), pc_lat.max(), n_pix_lat)
grid_pc_lon, grid_pc_lat = np.meshgrid(regular_pc_lon, regular_pc_lat)

for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(
parameters.items()
):
for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
print(variable_name)
yt = y_true[..., variable_idx].squeeze()
yp = y_pred[..., variable_idx].squeeze()

Expand All @@ -204,35 +199,11 @@ def plot_power_spectrum(
method = "linear" if nan_flag else "cubic"
if output_only:
xt = x[..., variable_idx].squeeze()
yt_i = griddata(
(pc_lon, pc_lat),
(yt - xt),
(grid_pc_lon, grid_pc_lat),
method=method,
fill_value=0.0,
)
yp_i = griddata(
(pc_lon, pc_lat),
(yp - xt),
(grid_pc_lon, grid_pc_lat),
method=method,
fill_value=0.0,
)
yt_i = griddata((pc_lon, pc_lat), (yt - xt), (grid_pc_lon, grid_pc_lat), method=method, fill_value=0.0)
yp_i = griddata((pc_lon, pc_lat), (yp - xt), (grid_pc_lon, grid_pc_lat), method=method, fill_value=0.0)
else:
yt_i = griddata(
(pc_lon, pc_lat),
yt,
(grid_pc_lon, grid_pc_lat),
method=method,
fill_value=0.0,
)
yp_i = griddata(
(pc_lon, pc_lat),
yp,
(grid_pc_lon, grid_pc_lat),
method=method,
fill_value=0.0,
)
yt_i = griddata((pc_lon, pc_lat), yt, (grid_pc_lon, grid_pc_lat), method=method, fill_value=0.0)
yp_i = griddata((pc_lon, pc_lat), yp, (grid_pc_lon, grid_pc_lat), method=method, fill_value=0.0)

# Masking NaN values
if nan_flag:
Expand Down Expand Up @@ -282,7 +253,7 @@ def compute_spectra(field: np.ndarray) -> np.ndarray:
field = np.array(field)

# compute real and imaginary parts of power spectra of field
lmax = field.shape[0] - 1 # maximum degree of expansion
lmax = min(710*640, field.shape[0]) - 1 # maximum degree of expansion
zero_w = SHGLQ(lmax)
coeffs_field = SHExpandGLQ(field, w=zero_w[1], zero=zero_w[0])

Expand Down Expand Up @@ -331,9 +302,7 @@ def plot_histogram(
figsize = (n_plots_y * 4, n_plots_x * 3)
fig, ax = plt.subplots(n_plots_x, n_plots_y, figsize=figsize)

for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(
parameters.items()
):
for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
yt = y_true[..., variable_idx].squeeze()
yp = y_pred[..., variable_idx].squeeze()
# postprocessed outputs so we need to handle possible NaNs
Expand All @@ -345,53 +314,25 @@ def plot_histogram(
yt_xt = yt - xt
yp_xt = yp - xt
# enforce the same binning for both histograms
bin_min = min(np.nanpercentile(yt_xt, 0.05), np.nanpercentile(yp_xt, 0.05))
bin_max = max(np.nanpercentile(yt_xt, 0.95), np.percentile(yp_xt, 0.95))
hist_yt, bins_yt = np.histogram(
yt_xt[~np.isnan(yt_xt)],
bins=100,
density=True,
range=[bin_min, bin_max],
)
hist_yp, bins_yp = np.histogram(
yp_xt[~np.isnan(yp_xt)],
bins=100,
density=True,
range=[bin_min, bin_max],
)
bin_min = min(np.nanmin(yt_xt), np.nanmin(yp_xt))
bin_max = max(np.nanmax(yt_xt), np.nanmax(yp_xt))
hist_yt, bins_yt = np.histogram(yt_xt[~np.isnan(yt_xt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp_xt[~np.isnan(yp_xt)], bins=100, density=True, range=[bin_min, bin_max])
else:
# enforce the same binning for both histograms
bin_min = min(np.nanpercentile(yt, 0.05), np.nanpercentile(yp, 0.05))
bin_max = max(np.nanpercentile(yt, 0.95), np.nanpercentile(yp, 0.95))
hist_yt, bins_yt = np.histogram(
yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max]
)
hist_yp, bins_yp = np.histogram(
yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max]
)
bin_min = min(np.nanmin(yt), np.nanmin(yp))
bin_max = max(np.nanmax(yt), np.nanmax(yp))
hist_yt, bins_yt = np.histogram(yt[~np.isnan(yt)], bins=100, density=True, range=[bin_min, bin_max])
hist_yp, bins_yp = np.histogram(yp[~np.isnan(yp)], bins=100, density=True, range=[bin_min, bin_max])

# Visualization trick for tp
if variable_name in precip_and_related_fields:
# in-place multiplication does not work here because variables are different numpy types
hist_yt = hist_yt * bins_yt[:-1]
hist_yp = hist_yp * bins_yp[:-1]
# Plot the modified histogram
ax[plot_idx].bar(
bins_yt[:-1],
hist_yt,
width=np.diff(bins_yt),
color="blue",
alpha=0.7,
label="Truth (data)",
)
ax[plot_idx].bar(
bins_yp[:-1],
hist_yp,
width=np.diff(bins_yp),
color="red",
alpha=0.7,
label="Predicted",
)
ax[plot_idx].bar(bins_yt[:-1], hist_yt, width=np.diff(bins_yt), color="blue", alpha=0.7, label="Truth (data)")
ax[plot_idx].bar(bins_yp[:-1], hist_yp, width=np.diff(bins_yp), color="red", alpha=0.7, label="Predicted")

ax[plot_idx].set_title(variable_name)
ax[plot_idx].set_xlabel(variable_name)
Expand Down Expand Up @@ -455,9 +396,7 @@ def plot_predicted_multilevel_flat_sample(
lat, lon = latlons[:, 0], latlons[:, 1]
pc_lon, pc_lat = pc(lon, lat)

for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(
parameters.items()
):
for plot_idx, (variable_idx, (variable_name, output_only)) in enumerate(parameters.items()):
xt = x[..., variable_idx].squeeze() * int(output_only)
yt = y_true[..., variable_idx].squeeze()
yp = y_pred[..., variable_idx].squeeze()
Expand Down Expand Up @@ -549,26 +488,8 @@ def plot_flat_sample(
# converting to mm from m
truth *= 1000.0
pred *= 1000.0
scatter_plot(
fig,
ax[1],
lon=lon,
lat=lat,
data=truth,
cmap=precip_colormap,
norm=norm,
title=f"{vname} target",
)
scatter_plot(
fig,
ax[2],
lon=lon,
lat=lat,
data=pred,
cmap=precip_colormap,
norm=norm,
title=f"{vname} pred",
)
scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=precip_colormap, norm=norm, title=f"{vname} target")
scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=precip_colormap, norm=norm, title=f"{vname} pred")
scatter_plot(
fig,
ax[3],
Expand All @@ -588,27 +509,9 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
return np.where(tmp > 180, tmp - 360, tmp)

sample_shape = truth.shape
pred = np.maximum(
np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred))
)
scatter_plot(
fig,
ax[1],
lon=lon,
lat=lat,
data=truth,
cmap=cyclic_colormap,
title=f"{vname} target",
)
scatter_plot(
fig,
ax[2],
lon=lon,
lat=lat,
data=pred,
cmap=cyclic_colormap,
title=f"capped {vname} pred",
)
pred = np.maximum(np.zeros(sample_shape), np.minimum(360 * np.ones(sample_shape), (pred)))
scatter_plot(fig, ax[1], lon=lon, lat=lat, data=truth, cmap=cyclic_colormap, title=f"{vname} target")
scatter_plot(fig, ax[2], lon=lon, lat=lat, data=pred, cmap=cyclic_colormap, title=f"capped {vname} pred")
err_plot = error_plot_in_degrees(truth, pred)
scatter_plot(
fig,
Expand Down Expand Up @@ -711,9 +614,7 @@ def error_plot_in_degrees(array1: np.ndarray, array2: np.ndarray) -> np.ndarray:
title=f"{vname} persist err: {np.nanmean(np.abs(err_plot)):.{4}f} deg.",
)
else:
scatter_plot(
fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input"
)
scatter_plot(fig, ax[0], lon=lon, lat=lat, data=input_, title=f"{vname} input")
scatter_plot(
fig,
ax[4],
Expand Down Expand Up @@ -857,16 +758,11 @@ def plot_graph_node_features(model: nn.Module) -> Figure:
Figure object handle
"""
nrows = len(nodes_name := model._graph_data.node_types)
ncols = min(
model.node_attributes.trainable_tensors[m].trainable.shape[1]
for m in nodes_name
)
ncols = min(model.node_attributes.trainable_tensors[m].trainable.shape[1] for m in nodes_name)
figsize = (ncols * 4, nrows * 3)
fig, ax = plt.subplots(nrows, ncols, figsize=figsize)

for row, (mesh, trainable_tensor) in enumerate(
model.node_attributes.trainable_tensors.items()
):
for row, (mesh, trainable_tensor) in enumerate(model.node_attributes.trainable_tensors.items()):
latlons = model.node_attributes.get_coordinates(mesh).cpu().numpy()
node_features = trainable_tensor.trainable.cpu().detach().numpy()

Expand Down Expand Up @@ -907,13 +803,9 @@ def plot_graph_edge_features(model: nn.Module, q_extreme_limit: float = 0.05) ->
}

if isinstance(model.processor, GraphEdgeMixin):
trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = (
model.processor
)
trainable_modules[model._graph_name_hidden, model._graph_name_hidden] = model.processor

ncols = min(
module.trainable.trainable.shape[1] for module in trainable_modules.values()
)
ncols = min(module.trainable.trainable.shape[1] for module in trainable_modules.values())
nrows = len(trainable_modules)
figsize = (ncols * 4, nrows * 3)
fig, ax = plt.subplots(nrows, ncols, figsize=figsize)
Expand Down

0 comments on commit 477634c

Please sign in to comment.