Skip to content

Commit

Permalink
More testing
Browse files Browse the repository at this point in the history
  • Loading branch information
dnerini committed Aug 31, 2024
1 parent 4990161 commit 6c5bbb1
Showing 1 changed file with 16 additions and 6 deletions.
22 changes: 16 additions & 6 deletions tests/test_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand All @@ -34,6 +35,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand All @@ -48,6 +50,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand Down Expand Up @@ -76,6 +79,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand All @@ -101,6 +105,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand All @@ -117,6 +122,7 @@
{
"features": ["coe:x1"],
"targets": ["obs:y1"],
"normalizer": {"default": "MinMaxScaler"},
"model": {
"fully_connected_network": {
"hidden_layers": [10],
Expand Down Expand Up @@ -149,21 +155,23 @@ def test_train_fromfile(tmp_path, cfg):
cfg.update({"epochs": num_epochs})

splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
splitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
datanormalizer = DataTransformer(cfg["normalizer"])
batch_dims = ["forecast_reference_time", "t", "station"]
datamodule = DataModule(
features=cfg["features"],
targets=cfg["targets"],
batch_dims=batch_dims,
splitter=splitter,
splitter=datasplitter,
normalizer=datanormalizer,
data_dir=tmp_path.as_posix() + "/",
)
results = train.train(cfg, datamodule)

assert len(results) == 4
assert isinstance(results[0], Functional) # model
assert isinstance(results[1], dict) # custom_objects
assert isinstance(results[2], DataTransformer) # standardizer
assert isinstance(results[2], DataTransformer) # normalizer
assert isinstance(results[3], dict) # history

assert all([np.isfinite(v).all() for v in results[3].values()])
Expand All @@ -182,21 +190,23 @@ def test_train_fromds(features_dataset, targets_dataset, cfg):
cfg.update({"epochs": num_epochs})

splitter_options = ValidDataSplitterOptions(time="lists", station="lists")
splitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
datasplitter = DataSplitter(splitter_options.time_split, splitter_options.station_split)
datanormalizer = DataTransformer(**cfg["normalizer"])
batch_dims = ["forecast_reference_time", "t", "station"]
datamodule = DataModule(
features_dataset[cfg["features"]],
targets_dataset[cfg["targets"]],
batch_dims,
splitter,
splitter=datasplitter,
normalizer=datanormalizer,
group_samples=cfg.get("group_samples"),
)
results = train.train(cfg, datamodule)

assert len(results) == 4
assert isinstance(results[0], Functional) # model
assert isinstance(results[1], dict) # custom_objects
assert isinstance(results[2], DataTransformer) # standardizer
assert isinstance(results[2], DataTransformer) # normalizer
assert isinstance(results[3], dict) # history

assert all([np.isfinite(v).all() for v in results[3].values()])
Expand Down

0 comments on commit 6c5bbb1

Please sign in to comment.