Partition epoch as a multi-GPU dataset distribution method #712

Open
wants to merge 6 commits into master
28 changes: 28 additions & 0 deletions returnn/datasets/basic.py
@@ -121,6 +121,7 @@ def __init__(self, name=None,
self.random_seed_offset = random_seed_offset
self.partition_epoch = partition_epoch or 1
self.repeat_epoch = repeat_epoch or 1
self.disable_horovod_partition = False  # can be set by a meta-dataset to handle multi-GPU partitioning on the meta-level
self.seq_tags_filter = set(self._load_seq_list_file(seq_list_filter_file)) if seq_list_filter_file else None
self.unique_seq_tags = unique_seq_tags
self._seq_order_seq_lens_file = seq_order_seq_lens_file
@@ -483,6 +484,8 @@ def get_seq_order_for_epoch(self, epoch, num_seqs, get_seq_len=None):
seq_index = self._apply_partition_epoch(seq_index, partition_epoch, epoch)
if repeat_epoch > 1:
seq_index = seq_index * repeat_epoch
if not self.disable_horovod_partition:
seq_index = self._apply_multi_gpu_partition(seq_index)
if self.seq_tags_filter is not None:
# Note: This is as generic as possible, but requires that get_all_tags is implemented.
assert seq_index
@@ -517,6 +520,31 @@ def _apply_partition_epoch(cls, seq_index, partition_epoch, epoch):

return seq_index

@classmethod
def _apply_multi_gpu_partition(cls, seq_index):
"""
Applies the multi-GPU partitioning enabled via horovod_dataset_distribution = "partition"; does nothing if that is not set.

:param list[int] seq_index:
:return: partition of seq_index for the current process, i.e. we split the sequence order across the different GPUs
:rtype: list[int]
"""
from returnn.config import get_global_config
config = get_global_config(raise_exception=False)
if not config or not config.is_true("use_horovod"):
return seq_index

import returnn.tf.horovod
if returnn.tf.horovod.get_ctx().get_dataset_distribution_type() != "partition":
return seq_index

rank = returnn.tf.horovod.get_ctx().rank() + 1  # one-based, so it can be used like the "epoch" argument
num_gpus = returnn.tf.horovod.get_ctx().size()

# Reuse the partition epoch logic to split the current sub-epoch between the different GPUs.
seq_index = cls._apply_partition_epoch(seq_index, partition_epoch=num_gpus, epoch=rank)
return seq_index

def _get_random_seed_for_epoch(self, epoch):
"""
:param int|None epoch:
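For illustration, a minimal standalone sketch of the split this relies on. The chunking details here (contiguous, near-equal chunks, remainder handed to the first chunks) are an assumption; the actual behavior is whatever _apply_partition_epoch implements:

def split_for_rank(seq_index, num_gpus, rank_one_based):
  """
  Hypothetical stand-in for _apply_partition_epoch(seq_index, partition_epoch=num_gpus, epoch=rank_one_based).

  :param list[int] seq_index: full sequence order of the current (sub-)epoch
  :param int num_gpus: Horovod world size
  :param int rank_one_based: Horovod rank + 1, used like the "epoch" argument
  :rtype: list[int]
  """
  base, rest = divmod(len(seq_index), num_gpus)
  sizes = [base + 1 if i < rest else base for i in range(num_gpus)]  # near-equal chunk sizes
  start = sum(sizes[:rank_one_based - 1])
  return seq_index[start:start + sizes[rank_one_based - 1]]

# 10 sequences over 3 GPUs -> chunks of size 4/3/3, disjoint and together covering every index,
# which is what test_horovod_partition below checks against the real implementation.
order = list(range(10))
parts = [split_for_rank(order, num_gpus=3, rank_one_based=r + 1) for r in range(3)]
assert sorted(sum(parts, [])) == order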
5 changes: 4 additions & 1 deletion returnn/datasets/map.py
@@ -102,8 +102,11 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
"""
super(MapDatasetWrapper, self).init_seq_order(epoch=epoch, seq_list=seq_list, seq_order=seq_order)

if seq_list or seq_order:
if seq_list:
raise NotImplementedError
if seq_order:
self._seq_order = seq_order
return True

try:
self._seq_order = self._dataset.get_seq_order(epoch=epoch)
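As a usage sketch of what this change enables (illustrative values, mirroring the wrapper usage in the tests below): an externally computed seq_order, e.g. from a meta-dataset, can now be passed through instead of raising NotImplementedError:

import numpy
from returnn.datasets.map import FromListDataset, MapDatasetWrapper

data = [{"data": numpy.array([i])} for i in range(4)]
dataset = MapDatasetWrapper(FromListDataset(data_list=data, data_types=None))
# With this change, the explicitly given order is stored and used directly.
dataset.init_seq_order(epoch=1, seq_order=[3, 1, 0, 2])
dataset.load_seqs(0, 1)
first = dataset.get_data(0, "data")
# first should be corpus sequence 3, i.e. numpy.array([3]): the wrapper follows the given order.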
8 changes: 7 additions & 1 deletion returnn/datasets/meta.py
@@ -863,8 +863,9 @@ def __init__(self,
# This will only initialize datasets needed for features occurring in data_map
self.datasets = {key: init_dataset(datasets[key]) for key in self.dataset_keys}

self._estimated_num_seqs = sum([self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())])
self.estimated_num_seq_per_subset = [self.datasets[k].estimated_num_seqs for k in sorted(self.datasets.keys())]
if all(num_seq is not None for num_seq in self.estimated_num_seq_per_subset):
self._estimated_num_seqs = sum(self.estimated_num_seq_per_subset)

if data_dims:
data_dims = convert_data_dims(data_dims)
@@ -913,6 +914,9 @@ def init_seq_order(self, epoch=None, seq_list=None, seq_order=None):
# partition epoch of the individual sub-datasets is still supported. Later we will call init_seq_order again with a
# sequence list to e.g. apply joint sorting or partition epoch of all sequences.
for dataset in self.datasets.values():
if self.sampling_sizes:
# Partitioning does not make sense if we sample a fixed number of sequences anyway.
dataset.disable_horovod_partition = True
dataset.init_seq_order(epoch=epoch)

# noinspection PyBroadException
@@ -1076,6 +1080,8 @@ def _get_sampling_seq_order(self):
# We want to additionally sort the sequences in the current sample. For this, create a sequence order on a
# range of length of the number of sequences in the sample. Note that we have to map the indices to make use
# of self._get_seq_length here.
# This get_seq_order_for_epoch call now also handles horovod_dataset_distribution = 'partition', which we
# disabled on the sub-dataset level via 'disable_horovod_partition' above.
seq_order_remapping = self.get_seq_order_for_epoch(
epoch=epoch, num_seqs=len(seq_order), get_seq_len=lambda i: self._get_seq_length(seq_order[i]))

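A toy illustration (plain Python, not RETURNN API) of why the partitioning has to run on the sampled meta-level order rather than on the sub-datasets; the numbers match test_horovod_partition_combined_dataset_sampling below (10 sequences, sampling size 12, 3 GPUs):

seqs = list(range(10))

# Partition after sampling (what this change does on the meta-level):
sample = (seqs * 2)[:12]                    # "sample 12 in order": 0..9 plus 0 and 1 again
per_gpu = [sample[i::3] for i in range(3)]  # some split of the sampled order across 3 GPUs
assert sorted(sum(per_gpu, [])) == sorted(sample)  # globally still exactly the intended sample

# Partition before sampling (what disable_horovod_partition prevents):
# each GPU would see only 3-4 of the 10 sequences and then sample 12 from just those,
# so globally 3 * 12 sequences would be processed per sub-epoch, with heavy over-repetition.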
2 changes: 1 addition & 1 deletion returnn/tf/horovod.py
@@ -99,7 +99,7 @@ def get_dataset_distribution_type(self):
:rtype: str
"""
dataset_distribution = self._config.value("horovod_dataset_distribution", "shard")
assert dataset_distribution in {"shard", "random_seed_offset"}
assert dataset_distribution in {"shard", "random_seed_offset", "partition"}
return dataset_distribution

def is_dataset_distribution_shard(self):
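For context, enabling the new mode from a config would look roughly like this (a sketch; only the entries read by this context are shown):

# RETURNN config sketch, other required settings omitted:
use_horovod = True
# "shard" and "random_seed_offset" existed before; "partition" splits the current
# (sub-)epoch's sequence order across the Horovod ranks instead.
horovod_dataset_distribution = "partition"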
80 changes: 80 additions & 0 deletions tests/test_Dataset.py
@@ -6,8 +6,10 @@
import sys
import _setup_test_env # noqa
import unittest
import numpy
from nose.tools import assert_equal, assert_is_instance, assert_in, assert_not_in, assert_true, assert_false
from returnn.datasets.generating import GeneratingDataset, DummyDataset, DummyDatasetMultipleSequenceLength
from returnn.datasets.map import FromListDataset, MapDatasetWrapper
from returnn.engine.batch import Batch
from returnn.datasets.basic import DatasetSeq
from returnn.util.basic import NumbersDict
@@ -320,6 +322,84 @@ def test_task12ax_window():
assert_equal(list(data2a[-1, 2]), [0] * input_dim) # zero-padded right


def test_horovod_partition():
num_seqs = 10
dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
# Use FromListDataset because DummyDataset does not support sequence ordering and thus cannot be partitioned.
dataset = MapDatasetWrapper(
FromListDataset(data_list=dummy_data, data_types=None), seq_ordering="random")
from returnn.config import get_global_config
global_config = get_global_config(auto_create=True)
global_config.set("use_horovod", True)
global_config.set("horovod_dataset_distribution", "partition")
from returnn.tf import horovod

horovod_size = 3
data_out = []
for rank in range(horovod_size):
# Simulate a multi-GPU setup.
def get_dummy_ctx(config=None):
class DummyHorovodContext(horovod.HorovodContext):
def __init__(self, config):
self._rank = rank
self._size = horovod_size
self._config = config
return DummyHorovodContext(config or global_config)
horovod.get_ctx = get_dummy_ctx
dataset.init_seq_order(epoch=1)
seq_idx = 0
while dataset.is_less_than_num_seqs(seq_idx):
dataset.load_seqs(seq_idx, seq_idx + 1)
data = dataset.get_data(seq_idx, "data")
data_out.extend(data.tolist())
seq_idx += 1
assert len(data_out) == num_seqs
assert set(data_out) == set(range(num_seqs))


def test_horovod_partition_combined_dataset_sampling():
num_seqs = 10
sampling_size = 12
dummy_data = [{"data": numpy.array([i])} for i in range(num_seqs)]
from returnn.datasets.meta import CombinedDataset
dataset = MapDatasetWrapper(FromListDataset(data_list=dummy_data))
combined_dataset = CombinedDataset(
datasets={"dataset": dataset}, data_map={("dataset", "data"): "data"}, sampling_sizes={"dataset": sampling_size},
data_dims={"data": (1, 1)}, seq_ordering="random")
from returnn.config import get_global_config
global_config = get_global_config(auto_create=True)
global_config.set("use_horovod", True)
global_config.set("horovod_dataset_distribution", "partition")
from returnn.tf import horovod

horovod_size = 3
data_out = []
for rank in range(horovod_size):
# Simulate a multi-GPU setup.
def get_dummy_ctx(config=None):
class DummyHorovodContext(horovod.HorovodContext):
def __init__(self, config):
self._rank = rank
self._size = horovod_size
self._config = config
return DummyHorovodContext(config or global_config)
horovod.get_ctx = get_dummy_ctx
combined_dataset.init_seq_order(epoch=None)
seq_idx = 0
while combined_dataset.is_less_than_num_seqs(seq_idx):
combined_dataset.load_seqs(seq_idx, seq_idx + 1)
data = combined_dataset.get_data(seq_idx, "data")
data_out.extend(data.tolist())
seq_idx += 1
# We sample 12 values from range(10) "in order", so 0 and 1 should appear twice and all other values once.
# This would not be the case if, e.g., the sub-dataset were partitioned before sampling;
# see Dataset.disable_horovod_partition.
assert len(data_out) == sampling_size
assert set(data_out) == set(range(num_seqs))
assert data_out.count(0) == 2
assert data_out.count(1) == 2


if __name__ == "__main__":
better_exchook.install()
if len(sys.argv) <= 1: