
Merge branch 'main' into feature/ValueFiller
Louis Poulain--Auzéau committed Oct 18, 2024
2 parents a8ce0db + 41e5c37 commit 4774838
Showing 19 changed files with 1,681 additions and 578 deletions.
6 changes: 3 additions & 3 deletions .github/workflows/publish-package.yml
@@ -11,10 +11,10 @@ jobs:
     runs-on: ubuntu-latest

     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
-          python-version: ${{ matrix.python-version }}
+          python-version: "3.10"
       - name: Setup poetry
         uses: abatilo/actions-poetry@v2
       - name: Install the project dependencies
4 changes: 2 additions & 2 deletions .github/workflows/run-tests.yml
@@ -17,8 +17,8 @@ jobs:
         python-version: ["3.9", "3.10"]

     steps:
-      - uses: actions/checkout@v3
-      - uses: actions/setup-python@v4
+      - uses: actions/checkout@v4
+      - uses: actions/setup-python@v5
         with:
           python-version: ${{ matrix.python-version }}
       - name: Setup poetry
22 changes: 12 additions & 10 deletions mlpp_lib/datasets.py
@@ -11,7 +11,7 @@
 from typing_extensions import Self

 from .model_selection import DataSplitter
-from .standardizers import DataTransformer
+from .normalizers import DataTransformer

 LOGGER = logging.getLogger(__name__)
@@ -153,16 +153,18 @@ def apply_filter(self):
         self.x, self.y = self.filter.apply(self.x, self.y)

     def normalize(self, stage=None):
-        LOGGER.info("Standardizing data.")
+        LOGGER.info("Normalizing data.")

         if self.normalizer is None:
             if stage == "test":
-                raise ValueError("Must provide standardizer for `test` stage.")
+                raise ValueError("Must provide normalizer for `test` stage.")
             else:
-                self.normalizer = DataTransformer({"Identity": (list(self.train[0].data_vars), {})})
+                LOGGER.warning("No normalizer found, data are standardized by default.")
+                self.normalizer = DataTransformer(
+                    {"Standardizer": list(self.train[0].data_vars)}
+                )

         if stage == "fit" or stage is None:
             self.normalizer.fit(self.train[0])
             self.train = (
                 tuple(self.normalizer.transform(self.train[0])) + self.train[1:]
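The new fallback replaces a silent Identity transform with an explicit warning plus per-variable standardization. A minimal sketch of what that default amounts to, assuming plain z-score scaling of each variable in an xarray Dataset (illustrative only; the real API is DataTransformer in mlpp_lib.normalizers):

import numpy as np
import xarray as xr

def fit_standardizer(ds: xr.Dataset) -> dict:
    # per-variable mean/std, fitted on the training split
    return {v: (ds[v].mean().item(), ds[v].std().item()) for v in ds.data_vars}

def standardize(ds: xr.Dataset, stats: dict) -> xr.Dataset:
    # z-score: zero mean, unit variance per variable
    return xr.Dataset({v: (ds[v] - m) / s for v, (m, s) in stats.items()})

ds = xr.Dataset({"t2m": ("sample", np.array([280.0, 285.0, 290.0]))})
print(standardize(ds, fit_standardizer(ds)))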
@@ -446,9 +448,9 @@ def drop_nans(self, group_size: int = 1):
         x, y, w = self._get_copies()

         event_axes = [self.dims.index(dim) for dim in self.dims if dim != "s"]
-        mask = da.any(da.isnan(da.from_array(x, name="x")), axis=event_axes)
+        mask = da.any(~da.isfinite(da.from_array(x, name="x")), axis=event_axes)
         if y is not None:
-            mask = mask | da.any(da.isnan(da.from_array(y, name="y")), axis=event_axes)
+            mask = mask | da.any(~da.isfinite(da.from_array(y, name="y")), axis=event_axes)
         mask = (~mask).compute()

         # with grouped samples, nans have to be removed in blocks:
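Switching from da.isnan to ~da.isfinite widens the mask: samples containing inf or -inf are now dropped along with NaN samples. A small sketch of the difference:

import dask.array as da
import numpy as np

x = da.from_array(np.array([[0.0, np.nan], [np.inf, 1.0], [2.0, 3.0]]), name="x")

print(da.any(da.isnan(x), axis=1).compute())      # [ True False False] -> inf row slips through
print(da.any(~da.isfinite(x), axis=1).compute())  # [ True  True False] -> inf row is dropped too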
@@ -593,7 +595,7 @@ def __init__(
         self.shuffle = shuffle
         self.block_size = block_size
         self.num_samples = len(self.dataset.x)
-        self.num_batches = self.num_samples // batch_size
+        self.num_batches = int(np.ceil(self.num_samples / batch_size))
         self._indices = tf.range(self.num_samples)
         self._seed = 0
         self._reset()
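Floor division dropped the trailing partial batch, so up to batch_size - 1 samples per epoch were never served; the ceiling keeps them. For example:

import numpy as np

num_samples, batch_size = 10, 4
print(num_samples // batch_size)               # 2 batches -> the last 2 samples are skipped
print(int(np.ceil(num_samples / batch_size)))  # 3 batches -> every sample is included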
48 changes: 34 additions & 14 deletions mlpp_lib/model_selection.py
@@ -102,7 +102,6 @@ def __init__(
         seed: Optional[int] = 10,
         time_dim_name: str = "forecast_reference_time",
     ):
-
         if not time_split.keys() == station_split.keys():
             raise ValueError(
                 "Time split and station split must be defined "
@@ -198,14 +197,23 @@ def _time_partitioning(self) -> None:
             self._time_indexers = self.time_split
         else:
             self._time_indexers = {}
-            if all(
-                [isinstance(v, float) for v in self.time_split.values()]
-            ):  # only fractions
+            if all([isinstance(v, float) for v in self.time_split.values()]):
                 res = self._time_partition_method(self.time_split)
                 self._time_indexers.update(res)
-            else:  # mixed
+            else:  # mixed fractions and labels
                 _time_split = self.time_split.copy()
-                self._time_indexers.update({"test": _time_split.pop("test")})
+                test_indexers = _time_split.pop("test")
+                if isinstance(test_indexers, tuple) and len(test_indexers) == 2:
+                    # convert time slice to list of time indices
+                    start_date = np.datetime64(test_indexers[0])
+                    end_date = np.datetime64(test_indexers[1])
+                    test_indexers = self.time_index[
+                        np.logical_and(
+                            self.time_index >= start_date, self.time_index <= end_date
+                        )
+                    ]
+                test_indexers = [t for t in test_indexers if t in self.time_index]
+                self._time_indexers.update({"test": test_indexers})
                 res = self._time_partition_method(_time_split)
                 self._time_indexers.update(res)
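The added branch lets the test split be given as a (start, end) tuple and expands it to the explicit timestamps present in the index. Roughly what it does, as a self-contained sketch with made-up dates:

import numpy as np

time_index = np.arange("2020-01-01", "2020-01-10", dtype="datetime64[D]")
test_indexers = ("2020-01-03", "2020-01-05")

start_date, end_date = np.datetime64(test_indexers[0]), np.datetime64(test_indexers[1])
test_times = time_index[(time_index >= start_date) & (time_index <= end_date)]
print(list(test_times))  # [2020-01-03, 2020-01-04, 2020-01-05]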

@@ -223,8 +231,11 @@ def _time_partitioning(self) -> None:
             self.partitions[partition].update(indexer)

     def _time_partition_method(self, fractions: Mapping[str, float]):
+        time_index = [
+            t for t in self.time_index if t not in self._time_indexers.get("test", [])
+        ]
         if self.time_split_method == "sequential":
-            return sequential_split(self.time_index, fractions)
+            return sequential_split(time_index, fractions)

     def _station_partitioning(self):
         """
@@ -236,9 +247,11 @@ def _station_partitioning(self):
             if all([isinstance(v, float) for v in self.station_split.values()]):
                 res = self._station_partition_method(self.station_split)
                 self._station_indexers.update(res)
-            else:
+            else:  # mixed fractions and labels
                 _station_split = self.station_split.copy()
-                self._station_indexers.update({"test": _station_split.pop("test")})
+                test_indexers = _station_split.pop("test")
+                test_indexers = [s for s in test_indexers if s in self.station_index]
+                self._station_indexers.update({"test": test_indexers})
                 res = self._station_partition_method(_station_split)
                 self._station_indexers.update(res)
         else:
@@ -254,10 +267,15 @@ def _station_partition_method(
         self, fractions: Mapping[str, float]
     ) -> Mapping[str, np.ndarray]:

+        station_index = [
+            sta
+            for sta in self.station_index
+            if sta not in self._station_indexers.get("test", [])
+        ]
         if self.station_split_method == "random":
-            out = random_split(self.station_index, fractions, seed=self.seed)
+            out = random_split(station_index, fractions, seed=self.seed)
         elif self.station_split_method == "sequential":
-            out = sequential_split(self.station_index, fractions)
+            out = sequential_split(station_index, fractions)
         return out

     def _check_time(self, time_split: dict, time_split_method: str):
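With explicit test stations excluded from the pool, the train/val fractions now apply only to the remaining stations. A hedged usage sketch (station IDs are made up; random_split is the helper defined further down in this file, assumed importable from an installed mlpp_lib):

import numpy as np
from mlpp_lib.model_selection import random_split

station_index = np.array(["AAA", "BBB", "CCC", "DDD", "EEE", "FFF"])
test_stations = ["BBB", "EEE"]  # labels reserved for the "test" partition

pool = [s for s in station_index if s not in test_stations]
print(random_split(pool, {"train": 0.5, "val": 0.5}, seed=10))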
@@ -313,7 +331,9 @@ def from_dict(cls, splits: dict) -> Self:

     def to_dict(self, sort_values=False):
         if not hasattr(self, "time_index") or not hasattr(self, "station_index"):
-            raise ValueError("DataSplitter wasn't applied on any data yet, run `fit` first.")
+            raise ValueError(
+                "DataSplitter wasn't applied on any data yet, run `fit` first."
+            )
         if not hasattr(self, "partitions"):
             self._time_partitioning()
             self._station_partitioning()
@@ -362,15 +382,15 @@ def random_split(
     seed: int = 10,
 ) -> dict[str, np.ndarray]:
     """Split an input index array randomly"""
-    np.random.seed(seed)
+    rng = np.random.default_rng(np.random.PCG64(seed))

     assert np.isclose(sum(split_fractions.values()), 1.0)

     n_samples = len(index)
     partitions = list(split_fractions.keys())
     fractions = np.array(list(split_fractions.values()))

-    shuffled_index = np.random.permutation(index)
+    shuffled_index = rng.permutation(index)
     indices = np.floor(np.cumsum(fractions)[:-1] * n_samples).astype(int)
     sub_arrays = np.split(shuffled_index, indices)
     return dict(zip(partitions, sub_arrays))
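np.random.seed mutated global RNG state, so any other np.random call in the process could silently change the resulting split; a dedicated Generator keeps the shuffle reproducible in isolation. A quick demonstration:

import numpy as np

index = np.arange(6)

rng_a = np.random.default_rng(np.random.PCG64(10))
np.random.shuffle(np.arange(100))  # unrelated global RNG activity...
rng_b = np.random.default_rng(np.random.PCG64(10))

print(rng_a.permutation(index))
print(rng_b.permutation(index))  # ...has no effect: both permutations match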