Update train_accelerate.py: replace print output with the logging module and remove the wandb and Prodigy integrations
SUC-DriverOld committed Aug 16, 2024
1 parent ba1a5ee commit 27bc31c
Showing 1 changed file with 28 additions and 41 deletions.
train_accelerate.py: 69 changes (28 additions & 41 deletions)
@@ -16,7 +16,6 @@
current_dir = os.path.dirname(os.path.realpath(__file__))
sys.path.append(current_dir)
import torch
-import wandb
import auraloss
import torch.nn as nn
from torch.optim import Adam, AdamW, SGD, RAdam, RMSprop
@@ -32,6 +31,11 @@

warnings.filterwarnings("ignore")

+import logging
+log_format = "%(asctime)s.%(msecs)03d [%(levelname)s] %(module)s - %(message)s"
+date_format = "%H:%M:%S"
+logging.basicConfig(level = logging.INFO, format = log_format, datefmt = date_format)
+logger = logging.getLogger(__name__)

def valid(model, valid_loader, args, config, device, verbose=False):
instruments = config.training.instruments
@@ -82,7 +86,7 @@ def __init__(self, args):
for valid_path in args.valid_path:
part = sorted(glob.glob(valid_path + '/*/mixture.wav'))
if len(part) == 0:
-print('No validation data found in: {}'.format(valid_path))
+logger.info('No validation data found in: {}'.format(valid_path))
all_mixtures_path += part

self.list_of_files = all_mixtures_path
@@ -113,7 +117,6 @@ def train_model(args):
parser.add_argument("--use_multistft_loss", action='store_true', help="Use MultiSTFT Loss (from auraloss package)")
parser.add_argument("--use_mse_loss", action='store_true', help="Use default MSE loss")
parser.add_argument("--use_l1_loss", action='store_true', help="Use L1 loss")
parser.add_argument("--wandb_key", type=str, default='', help='wandb API Key')
if args is None:
args = parser.parse_args()
else:
@@ -126,21 +129,14 @@

model, config = get_model_from_config(args.model_type, args.config_path)
if accelerator.is_main_process:
print("Instruments: {}".format(config.training.instruments))
logger.info("Instruments: {}".format(config.training.instruments))

if not os.path.isdir(args.results_path):
os.mkdir(args.results_path)

device_ids = args.device_ids
batch_size = config.training.batch_size

-# wandb
-if accelerator.is_main_process and args.wandb_key is not None and args.wandb_key.strip() != '':
-wandb.login(key = args.wandb_key)
-wandb.init(project = 'msst-accelerate', config = { 'config': config, 'args': args, 'device_ids': device_ids, 'batch_size': batch_size })
-else:
-wandb.init(mode = 'disabled')

trainset = MSSDataset(
config,
args.data_path,
@@ -171,7 +167,7 @@

if args.start_check_point != '':
if accelerator.is_main_process:
-print('Start from checkpoint: {}'.format(args.start_check_point))
+logger.info('Start from checkpoint: {}'.format(args.start_check_point))
if 1:
load_not_compatible_weights(model, args.start_check_point, verbose=False)
else:
@@ -182,7 +178,7 @@
optim_params = dict()
if 'optimizer' in config:
optim_params = dict(config['optimizer'])
-print('Optimizer params from config:\n{}'.format(optim_params))
+logger.info('Optimizer params from config:\n{}'.format(optim_params))

if config.training.optimizer == 'adam':
optimizer = Adam(model.parameters(), lr=config.training.lr, **optim_params)
@@ -192,19 +188,14 @@
optimizer = RAdam(model.parameters(), lr=config.training.lr, **optim_params)
elif config.training.optimizer == 'rmsprop':
optimizer = RMSprop(model.parameters(), lr=config.training.lr, **optim_params)
-elif config.training.optimizer == 'prodigy':
-from prodigyopt import Prodigy
-# you can choose weight decay value based on your problem, 0 by default
-# We recommend using lr=1.0 (default) for all networks.
-optimizer = Prodigy(model.parameters(), lr=config.training.lr, **optim_params)
elif config.training.optimizer == 'adamw8bit':
import bitsandbytes as bnb
optimizer = bnb.optim.AdamW8bit(model.parameters(), lr=config.training.lr, **optim_params)
elif config.training.optimizer == 'sgd':
-print('Use SGD optimizer')
+logger.info('Use SGD optimizer')
optimizer = SGD(model.parameters(), lr=config.training.lr, **optim_params)
else:
-print('Unknown optimizer: {}'.format(config.training.optimizer))
+logger.info('Unknown optimizer: {}'.format(config.training.optimizer))
exit()

gradient_accumulation_steps = 1
@@ -214,8 +205,8 @@
pass

if accelerator.is_main_process:
-print('Processes GPU: {}'.format(accelerator.num_processes))
-print("Patience: {} Reduce factor: {} Batch size: {} Grad accum steps: {} Effective batch size: {} Optimizer: {}".format(
+logger.info('Processes GPU: {}'.format(accelerator.num_processes))
+logger.info("Patience: {} Reduce factor: {} Batch size: {} Grad accum steps: {} Effective batch size: {} Optimizer: {}".format(
config.training.patience,
config.training.reduce_factor,
batch_size,
@@ -238,7 +229,7 @@
except:
loss_options = dict()
if accelerator.is_main_process:
-print('Loss options: {}'.format(loss_options))
+logger.info('Loss options: {}'.format(loss_options))
loss_multistft = auraloss.freq.MultiResolutionSTFTLoss(
**loss_options
)
@@ -249,37 +240,37 @@
sdr_list = accelerator.gather(sdr_list)
accelerator.wait_for_everyone()

-# print(sdr_list)
+# logger.info(sdr_list)

sdr_avg = 0.0
instruments = config.training.instruments
if config.training.target_instrument is not None:
instruments = [config.training.target_instrument]

for instr in instruments:
-# print(sdr_list[instr])
+# logger.info(sdr_list[instr])
sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
sdr_val = sdr_data.mean()
if accelerator.is_main_process:
print("Valid length: {}".format(valid_dataset_length))
print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
logger.info("Valid length: {}".format(valid_dataset_length))
logger.info("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
sdr_val = sdr_data[:valid_dataset_length].mean()
if accelerator.is_main_process:
print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
logger.info("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
sdr_avg += sdr_val
sdr_avg /= len(instruments)
if len(instruments) > 1:
if accelerator.is_main_process:
-print('SDR Avg: {:.4f}'.format(sdr_avg))
+logger.info('SDR Avg: {:.4f}'.format(sdr_avg))
sdr_list = None

if accelerator.is_main_process:
-print('Train for: {}'.format(config.training.num_epochs))
+logger.info('Train for: {}'.format(config.training.num_epochs))
best_sdr = -100
for epoch in range(config.training.num_epochs):
model.train().to(device)
if accelerator.is_main_process:
-print('Train epoch: {} Learning rate: {}'.format(epoch, optimizer.param_groups[0]['lr']))
+logger.info('Train epoch: {} Learning rate: {}'.format(epoch, optimizer.param_groups[0]['lr']))
loss_val = 0.
total = 0

Expand Down Expand Up @@ -330,12 +321,10 @@ def train_model(args):
loss_val += li
total += 1
if accelerator.is_main_process:
-wandb.log({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1), 'total': total, 'loss_val': loss_val, 'i': i })
pbar.set_postfix({'loss': 100 * li, 'avg_loss': 100 * loss_val / (i + 1)})

if accelerator.is_main_process:
-print('Training loss: {:.6f}'.format(loss_val / total))
-wandb.log({'train_loss': loss_val / total, 'epoch': epoch})
+logger.info('Training loss: {:.6f}'.format(loss_val / total))

# Save last
store_path = args.results_path + '/last_{}.ckpt'.format(args.model_type)
@@ -355,24 +344,22 @@

for instr in instruments:
if accelerator.is_main_process and 0:
-print(sdr_list[instr])
+logger.info(sdr_list[instr])
sdr_data = torch.cat(sdr_list[instr], dim=0).cpu().numpy()
# sdr_val = sdr_data.mean()
sdr_val = sdr_data[:valid_dataset_length].mean()
if accelerator.is_main_process:
print("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
wandb.log({ f'{instr}_sdr': sdr_val })
logger.info("Instr SDR {}: {:.4f} Debug: {}".format(instr, sdr_val, len(sdr_data)))
sdr_avg += sdr_val
sdr_avg /= len(instruments)
if len(instruments) > 1:
if accelerator.is_main_process:
-print('SDR Avg: {:.4f}'.format(sdr_avg))
-wandb.log({'sdr_avg': sdr_avg, 'best_sdr': best_sdr})
+logger.info('SDR Avg: {:.4f}'.format(sdr_avg))

if accelerator.is_main_process:
if sdr_avg > best_sdr:
store_path = args.results_path + '/model_{}_ep_{}_sdr_{:.4f}.ckpt'.format(args.model_type, epoch, sdr_avg)
-print('Store weights: {}'.format(store_path))
+logger.info('Store weights: {}'.format(store_path))
unwrapped_model = accelerator.unwrap_model(model)
accelerator.save(unwrapped_model.state_dict(), store_path)
best_sdr = sdr_avg
@@ -382,4 +369,4 @@


if __name__ == "__main__":
-train_model(None)
\ No newline at end of file
+train_model(None)
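For reference, the logging setup this commit adds behaves like the following standalone sketch; the sample message and the module name in the output are illustrative, not taken from the repository:

import logging

# Same format and date settings as the committed code.
log_format = "%(asctime)s.%(msecs)03d [%(levelname)s] %(module)s - %(message)s"
date_format = "%H:%M:%S"
logging.basicConfig(level=logging.INFO, format=log_format, datefmt=date_format)
logger = logging.getLogger(__name__)

# An illustrative call:
logger.info("Instruments: {}".format(["vocals", "other"]))
# emits a line shaped like:
# 14:23:05.123 [INFO] demo - Instruments: ['vocals', 'other']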

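One subtlety in the validation code above: accelerator.gather pads each process's share so all processes contribute equally, which can duplicate trailing samples, and that appears to be why the script averages sdr_data[:valid_dataset_length] rather than the full gathered tensor (it logs both means, with the array length as "Debug"). A minimal sketch of the effect, with made-up numbers:

import torch

valid_dataset_length = 5
gathered = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 5.0])  # last sample duplicated by padding
print(gathered.mean())                         # tensor(3.3333), biased by the duplicate
print(gathered[:valid_dataset_length].mean())  # tensor(3.), the true mean over 5 samples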