From 5a1f4b7afb8a822f66c0ddc75bc959a44a57d035 Mon Sep 17 00:00:00 2001 From: boltzmann-Li Date: Thu, 28 Oct 2021 12:40:38 +0000 Subject: [PATCH] work around for pytorch stft backward bug. --- train.py | 31 ++++++++++++++++--------------- train_ms.py | 29 +++++++++++++++-------------- 2 files changed, 31 insertions(+), 29 deletions(-) diff --git a/train.py b/train.py index 703d30cf..517cd789 100644 --- a/train.py +++ b/train.py @@ -151,24 +151,25 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade hps.data.mel_fmin, hps.data.mel_fmax) y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax - ) + y_hat = y_hat.float() + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax + ) - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) - loss_disc_all = loss_disc + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc optim_d.zero_grad() scaler.scale(loss_disc_all).backward() scaler.unscale_(optim_d) diff --git a/train_ms.py b/train_ms.py index 34870c62..d87c0639 100644 --- a/train_ms.py +++ b/train_ms.py @@ -153,24 +153,25 @@ def train_and_evaluate(rank, epoch, hps, nets, optims, schedulers, scaler, loade hps.data.mel_fmin, hps.data.mel_fmax) y_mel = commons.slice_segments(mel, ids_slice, hps.train.segment_size // hps.data.hop_length) - y_hat_mel = mel_spectrogram_torch( - y_hat.squeeze(1), - hps.data.filter_length, - hps.data.n_mel_channels, - hps.data.sampling_rate, - hps.data.hop_length, - hps.data.win_length, - hps.data.mel_fmin, - hps.data.mel_fmax + y_hat = y_hat.float() + y_hat_mel = mel_spectrogram_torch( + y_hat.squeeze(1), + hps.data.filter_length, + hps.data.n_mel_channels, + hps.data.sampling_rate, + hps.data.hop_length, + hps.data.win_length, + hps.data.mel_fmin, + hps.data.mel_fmax ) - y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice + y = commons.slice_segments(y, ids_slice * hps.data.hop_length, hps.train.segment_size) # slice # Discriminator - y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) - with autocast(enabled=False): - loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) - loss_disc_all = loss_disc + y_d_hat_r, y_d_hat_g, _, _ = net_d(y, y_hat.detach()) + with autocast(enabled=False): + loss_disc, losses_disc_r, losses_disc_g = discriminator_loss(y_d_hat_r, y_d_hat_g) + loss_disc_all = loss_disc optim_d.zero_grad() scaler.scale(loss_disc_all).backward() scaler.unscale_(optim_d)