Skip to content

Commit

Permalink
update logger
Browse files Browse the repository at this point in the history
  • Loading branch information
yxlllc committed Sep 11, 2024
1 parent f875fbf commit 5f31240
Show file tree
Hide file tree
Showing 5 changed files with 47 additions and 55 deletions.
3 changes: 2 additions & 1 deletion configs/combsub.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,5 @@ train:
interval_log: 10
interval_val: 2000
lr: 0.0005
weight_decay: 0
weight_decay: 0
save_opt: false
1 change: 1 addition & 0 deletions configs/sins.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -38,3 +38,4 @@ train:
interval_val: 2000
lr: 0.0005
weight_decay: 0
save_opt: false
5 changes: 3 additions & 2 deletions ddsp/loss.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,13 +13,14 @@ def __init__(self, block_size, fft_min, fft_max, n_scale, lambda_uv, device):
self.loss_uv_func = UVLoss(block_size)
self.lambda_uv = lambda_uv

def forward(self, signal, s_h, x_true, uv_true, detach_uv=False, uv_tolerance=0.05):
def forward(self, signal, s_h, x_true, uv_true, detach_uv=False, uv_tolerance=0.05, prefix='train/'):
loss_rss = self.loss_rss_func(signal, x_true)
loss_uv = self.loss_uv_func(signal, s_h, uv_true)
if detach_uv or loss_uv < uv_tolerance:
loss_uv = loss_uv.detach()
loss = loss_rss + self.lambda_uv * loss_uv
return loss, (loss_rss, loss_uv)
loss_dict = {prefix+'loss': loss.item(), prefix + 'loss_rss': loss_rss.item(), prefix+'loss_uv': loss_uv.item()}
return loss, loss_dict

class UVLoss(nn.Module):
def __init__(self, block_size, eps = 1e-5):
Expand Down
13 changes: 9 additions & 4 deletions logger/saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,10 +106,15 @@ def save_model(
print(' [*] model checkpoint saved: {}'.format(path_pt))

# save
torch.save({
'global_step': self.global_step,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()}, path_pt)
if optimizer is not None:
torch.save({
'global_step': self.global_step,
'model': model.state_dict(),
'optimizer': optimizer.state_dict()}, path_pt)
else:
torch.save({
'global_step': self.global_step,
'model': model.state_dict()}, path_pt)

# to json
if to_json:
Expand Down
80 changes: 32 additions & 48 deletions solver.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,14 +10,10 @@ def test(args, model, loss_func, loader_test, saver):
print(' [*] testing...')
model.eval()

# losses
test_loss = 0.
test_loss_rss = 0.
test_loss_uv = 0.

# intialization
num_batches = len(loader_test)
rtf_all = []
test_loss_dict = {}

# run
with torch.no_grad():
Expand Down Expand Up @@ -50,26 +46,23 @@ def test(args, model, loss_func, loader_test, saver):
rtf_all.append(rtf)

# loss
loss, (loss_rss, loss_uv) = loss_func(signal, s_h, data['audio'], data['uv'])
loss, loss_dict = loss_func(signal, s_h, data['audio'], data['uv'], prefix='validation/')

test_loss += loss.item()
test_loss_rss += loss_rss.item()
test_loss_uv += loss_uv.item()
if test_loss_dict == {}:
for key, value in loss_dict.items():
test_loss_dict[key] = value / num_batches
else:
for key, value in loss_dict.items():
test_loss_dict[key] += value / num_batches

# log
saver.log_audio({fn+'/gt.wav': data['audio'], fn+'/pred.wav': signal})

# report
test_loss /= num_batches
test_loss_rss /= num_batches
test_loss_uv /= num_batches

# check
print(' [test_loss] test_loss:', test_loss)
print(' [test_loss_rss] test_loss_rss:', test_loss_rss)
print(' [test_loss_uv] test_loss_uv:', test_loss_uv)
print(' [test_loss] test_loss:', test_loss_dict['validation/loss'])
print(' [test_loss] test_loss_rss:', test_loss_dict['validation/loss_rss'])
print(' Real Time Factor', np.mean(rtf_all))
return test_loss, test_loss_rss, test_loss_uv
return test_loss_dict


def train(args, initial_global_step, model, optimizer, loss_func, loader_train, loader_test):
Expand Down Expand Up @@ -103,7 +96,14 @@ def train(args, initial_global_step, model, optimizer, loss_func, loader_train,
detach_uv = False
if saver.global_step < args.loss.detach_uv_step:
detach_uv = True
loss, (loss_rss, loss_uv) = loss_func(signal, s_h, data['audio'], data['uv'], detach_uv = detach_uv, uv_tolerance = args.loss.uv_tolerance)
loss, loss_dict = loss_func(
signal,
s_h,
data['audio'],
data['uv'],
detach_uv = detach_uv,
uv_tolerance = args.loss.uv_tolerance,
prefix = 'train/')

# handle nan loss
if torch.isnan(loss):
Expand All @@ -116,53 +116,37 @@ def train(args, initial_global_step, model, optimizer, loss_func, loader_train,
# log loss
if saver.global_step % args.train.interval_log == 0:
saver.log_info(
'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | rss: {:.3f}| uv: {:.3f} | time: {} | step: {}'.format(
'epoch: {} | {:3d}/{:3d} | {} | batch/s: {:.2f} | loss: {:.3f} | rss: {:.3f} | time: {} | step: {}'.format(
epoch,
batch_idx,
num_batches,
args.env.expdir,
args.train.interval_log/saver.get_interval_time(),
loss.item(),
loss_rss.item(),
loss_uv.item(),
loss_dict['train/loss'],
loss_dict['train/loss_rss'],
saver.get_total_time(),
saver.global_step
)
)

saver.log_value({
'train/loss': loss.item(),
'train/rss': loss_rss.item(),
'train/uv': loss_uv.item()
})
saver.log_value(loss_dict)

# validation
if saver.global_step % args.train.interval_val == 0:
optimizer_save = optimizer if args.train.save_opt else None

# save latest
saver.save_model(model, optimizer, postfix=f'{saver.global_step}')
saver.save_model(model, optimizer_save, postfix=f'{saver.global_step}')

# run testing set
test_loss, test_loss_rss, test_loss_uv = test(args, model, loss_func, loader_test, saver)
test_loss_dict = test(args, model, loss_func, loader_test, saver)

saver.log_info(
' --- <validation> --- \nloss: {:.3f} | rss: {:.3f} | uv: {:.3f}. '.format(
test_loss,
test_loss_rss,
test_loss_uv
' --- <validation> --- \nloss: {:.3f} | rss: {:.3f}. '.format(
test_loss_dict['validation/loss'],
test_loss_dict['validation/loss_rss']
)
)

saver.log_value({
'validation/loss': test_loss,
'validation/rss': test_loss_rss,
'validation/uv': test_loss_uv
})
)
saver.log_value(test_loss_dict)
model.train()

# save best model
if test_loss < best_loss:
saver.log_info(' [V] best model updated.')
saver.save_model(model, optimizer, postfix='best')
best_loss = test_loss


0 comments on commit 5f31240

Please sign in to comment.