diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py index 539a1bb..7d449d2 100644 --- a/mlpp_lib/datasets.py +++ b/mlpp_lib/datasets.py @@ -449,9 +449,9 @@ def drop_nans(self, group_size: int = 1): x, y, w = self._get_copies() event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"] - mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes) + mask = da.any(~da.isfinite(da.from_array(x, name="x")), axis=event_axes) if y is not None: - mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=event_axes) + mask = mask | da.any(~da.isfinite(da.from_array(y, name="y")), axis=event_axes) mask = (~mask).compute() # with grouped samples, nans have to be removed in blocks: