diff --git a/mlpp_lib/datasets.py b/mlpp_lib/datasets.py index 539a1bb..c8a4239 100644 --- a/mlpp_lib/datasets.py +++ b/mlpp_lib/datasets.py @@ -596,9 +596,7 @@ def __init__( self.shuffle = shuffle self.block_size = block_size self.num_samples = len(self.dataset.x) - self.num_batches = ( - self.num_samples // batch_size if batch_size <= self.num_samples else 1 - ) + self.num_batches = int(np.ceil(self.num_samples / batch_size)) self._indices = tf.range(self.num_samples) self._seed = 0 self._reset()