diff --git a/train.py b/train.py index 703d30cf..c217473d 100644 --- a/train.py +++ b/train.py @@ -72,11 +72,11 @@ def run(rank, n_gpus, hps): rank=rank, shuffle=True) collate_fn = TextAudioCollate() - train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, + train_loader = DataLoader(train_dataset, num_workers=n_gpus if n_gpus > 1 else 0, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler) if rank == 0: eval_dataset = TextAudioLoader(hps.data.validation_files, hps.data) - eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, + eval_loader = DataLoader(eval_dataset, num_workers=n_gpus if n_gpus > 1 else 0, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn) diff --git a/train_ms.py b/train_ms.py index 34870c62..4e1cbe97 100644 --- a/train_ms.py +++ b/train_ms.py @@ -72,11 +72,11 @@ def run(rank, n_gpus, hps): rank=rank, shuffle=True) collate_fn = TextAudioSpeakerCollate() - train_loader = DataLoader(train_dataset, num_workers=8, shuffle=False, pin_memory=True, + train_loader = DataLoader(train_dataset, num_workers=n_gpus if n_gpus > 1 else 0, shuffle=False, pin_memory=True, collate_fn=collate_fn, batch_sampler=train_sampler) if rank == 0: eval_dataset = TextAudioSpeakerLoader(hps.data.validation_files, hps.data) - eval_loader = DataLoader(eval_dataset, num_workers=8, shuffle=False, + eval_loader = DataLoader(eval_dataset, num_workers=n_gpus if n_gpus > 1 else 0, shuffle=False, batch_size=hps.train.batch_size, pin_memory=True, drop_last=False, collate_fn=collate_fn)