Fix NaN bug in validation #59

Merged 4 commits on Oct 8, 2024
Changes from 2 commits
32 changes: 30 additions & 2 deletions mlpp_lib/datasets.py
@@ -449,9 +449,9 @@ def drop_nans(self, group_size: int = 1):
        x, y, w = self._get_copies()

        event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"]
-       mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes)
+       mask = da.any(~da.isfinite(da.from_array(x, name="x")), axis=event_axes)
        if y is not None:
-           mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=event_axes)
+           mask = mask | da.any(~da.isfinite(da.from_array(y, name="y")), axis=event_axes)
        mask = (~mask).compute()

        # with grouped samples, nans have to be removed in blocks:
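Aside (not part of the diff): the change above replaces da.isnan with ~da.isfinite, so the mask also catches infinite values, not just NaNs. A minimal sketch of the difference, using plain NumPy in place of dask:

import numpy as np

x = np.array([1.0, np.nan, np.inf, -np.inf])
print(np.isnan(x))       # [False  True False False] -> misses +/-Inf
print(~np.isfinite(x))   # [False  True  True  True] -> flags NaN and both infinities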
@@ -596,6 +596,7 @@ def __init__(
        self.shuffle = shuffle
        self.block_size = block_size
        self.num_samples = len(self.dataset.x)
        # self.num_batches = int(np.ceil(self.num_samples / batch_size))
        self.num_batches = (
            self.num_samples // batch_size if batch_size <= self.num_samples else 1
        )
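Aside (not part of the diff): the commented-out formula keeps a final partial batch via np.ceil, while the floor-division version drops it (and falls back to a single batch when batch_size exceeds num_samples). A quick illustration:

import numpy as np

num_samples, batch_size = 10, 3
print(int(np.ceil(num_samples / batch_size)))                         # 4 -> includes the final partial batch
print(num_samples // batch_size if batch_size <= num_samples else 1)  # 3 -> drops the final partial batch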
@@ -609,12 +610,39 @@ def __len__(self) -> int:
    def __getitem__(self, index) -> tuple[tf.Tensor, ...]:
        if index >= self.num_batches:
            self._reset()
            LOGGER.info("End of epoch (will raise IndexError).")
            raise IndexError
        start = index * self.batch_size
        end = index * self.batch_size + self.batch_size
        output = [self.dataset.x[start:end], self.dataset.y[start:end]]
        if not self.shuffle:
            count = np.count_nonzero(~np.isfinite(output[0]))
            if count > 0:
                LOGGER.info(f"Input x (val): {count=}, {output[0].shape=}")
                invalid_rows, invalid_cols = np.where(~np.isfinite(output[0]))
                for col in np.unique(invalid_cols):
                    rows_for_col = invalid_rows[invalid_cols == col]

                    # Extract the invalid values in this column
                    invalid_values_for_col = output[0][rows_for_col, col]

                    # Count NaNs, +Inf, and -Inf in this column
                    nan_count = np.sum(np.isnan(invalid_values_for_col))
                    posinf_count = np.sum(np.isposinf(invalid_values_for_col))
                    neginf_count = np.sum(np.isneginf(invalid_values_for_col))

                    if nan_count > 0:
                        LOGGER.info(f"NaNs in column {col}: {nan_count}")
                    if posinf_count > 0:
                        LOGGER.info(f"+Inf in column {col}: {posinf_count}")
                    if neginf_count > 0:
                        LOGGER.info(f"-Inf in column {col}: {neginf_count}")
        if self.dataset.w is not None:
            output.append(self.dataset.w[start:end])
            if not self.shuffle:
                count = np.count_nonzero(~np.isfinite(output[2]))
                if count > 0:
                    LOGGER.info(f"Input w (val): {count=}")
        return tuple(output)

    def on_epoch_end(self) -> None:
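Aside (not part of the diff): raising IndexError once index reaches num_batches is the pattern that ends a plain for loop over the object (Python's legacy sequence-iteration protocol); whether the training loop relies on this or on __len__ is not visible in this excerpt. A toy sketch with a hypothetical class (not the library's actual loader):

class ToyLoader:
    def __init__(self, num_batches):
        self.num_batches = num_batches

    def __getitem__(self, index):
        if index >= self.num_batches:
            raise IndexError  # signals the end of the epoch
        return index

for batch in ToyLoader(3):
    print(batch)  # prints 0, 1, 2, then the loop stops cleanly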