From f6e11bc25d1d97e977317d3e4f7be8d49e7e7d46 Mon Sep 17 00:00:00 2001 From: louisPoulain Date: Thu, 27 Jun 2024 11:46:15 +0200 Subject: [PATCH] Prevent DataLoader from not working in cases where the total number of samples is less than the batch number (especially useful for val and test) --- mlpp_lib/datasets.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py index c8b7bfe..cacfbe9 100644 --- a/mlpp_lib/datasets.py +++ b/mlpp_lib/datasets.py @@ -593,7 +593,7 @@ def __init__( self.shuffle = shuffle self.block_size = block_size self.num_samples = len(self.dataset.x) - self.num_batches = self.num_samples // batch_size + self.num_batches = self.num_samples // batch_size if batch_size <= self.num_samples else 1 self._indices = tf.range(self.num_samples) self._seed = 0 self._reset()