From a46342fea005f5e0d72f6940cfa8673a41814998 Mon Sep 17 00:00:00 2001
From: Leng Yue
Date: Thu, 14 Dec 2023 05:45:47 -0800
Subject: [PATCH] Faster preprocessing

---
 tools/preprocessing/extract_features.py | 157 +++++++++++++++++-------
 1 file changed, 111 insertions(+), 46 deletions(-)

diff --git a/tools/preprocessing/extract_features.py b/tools/preprocessing/extract_features.py
index c670ba0a..268f6f7a 100644
--- a/tools/preprocessing/extract_features.py
+++ b/tools/preprocessing/extract_features.py
@@ -1,19 +1,21 @@
 import argparse
+import os
 import random
-from concurrent.futures import ProcessPoolExecutor
+import subprocess as sp
+import time
 from copy import deepcopy
+from datetime import timedelta
 from pathlib import Path
+from random import Random
 from typing import Optional
 
 import librosa
 import numpy as np
 import torch
-import torch.multiprocessing as mp
 import torchcrepe
 from fish_audio_preprocess.utils.file import AUDIO_EXTENSIONS, list_files
 from loguru import logger
 from mmengine import Config
-from tqdm import tqdm
 
 from fish_diffusion.modules.energy_extractors import ENERGY_EXTRACTORS
 from fish_diffusion.modules.feature_extractors import FEATURE_EXTRACTORS
@@ -38,14 +40,10 @@ def init(
     global model_caches
 
     device = torch.device("cpu")
-    rank = mp.current_process()._identity
-    rank = rank[0] if len(rank) > 0 else 0
-
     if torch.cuda.is_available():
-        gpu_id = rank % torch.cuda.device_count()
-        device = torch.device(f"cuda:{gpu_id}")
+        device = torch.device("cuda")
 
-    logger.info(f"Rank {rank} uses device {device}")
+    logger.info(f"{curr_worker} Uses device {device}")
 
     text_features_extractor = None
     if getattr(config.preprocessing, "text_features_extractor", None):
@@ -213,7 +211,7 @@ def safe_process(args, config, audio_path: Path):
 
         return aug_count + 1
     except Exception as e:
-        logger.error(f"Error processing {audio_path}")
+        logger.error(f"{curr_worker} Error processing {audio_path}")
 
         if args.debug:
             logger.exception(e)
@@ -225,64 +223,131 @@ def parse_args():
     parser.add_argument("--config", type=str, required=True)
     parser.add_argument("--path", type=str, required=True)
     parser.add_argument("--clean", action="store_true")
-    parser.add_argument("--num-workers", type=int, default=1)
+    parser.add_argument(
+        "--num-workers",
+        type=int,
+        default=1,
+        help="Number of workers; if > 1, one subprocess is launched per worker",
+    )
     parser.add_argument("--no-augmentation", action="store_true")
     parser.add_argument("--debug", action="store_true")
 
+    # For multiprocessing
+    parser.add_argument("--rank", type=int, default=0)
+    parser.add_argument("--world-size", type=int, default=1)
+
     return parser.parse_args()
 
 
 if __name__ == "__main__":
-    # mp.set_start_method("spawn", force=True)
-
     args = parse_args()
-
-    logger.info(f"Using {args.num_workers} workers")
+    curr_worker = f"[Rank {args.rank}]" if args.world_size > 1 else "[Main]"
 
     if torch.cuda.is_available():
-        logger.info(f"Found {torch.cuda.device_count()} GPUs")
+        logger.info(f"{curr_worker} Found {torch.cuda.device_count()} GPUs")
     else:
-        logger.warning("No GPU found, using CPU")
+        logger.warning(f"{curr_worker} No GPU found, using CPU")
 
-    if args.clean:
-        logger.info("Cleaning *.npy files...")
+    # Only clean on the main process
+    if args.clean and args.rank == 0:
+        logger.info(f"{curr_worker} Cleaning *.npy files...")
 
         files = list_files(args.path, {".npy"}, recursive=True, sort=True)
 
         for f in files:
             f.unlink()
 
-        logger.info("Done!")
+        logger.info(f"{curr_worker} Done!")
+
+    # Multi-processing
+    if args.num_workers > 1:
+        logger.info(f"{curr_worker} Launching {args.num_workers} workers")
+
+        processes = []
+        for idx in range(args.num_workers):
+            new_args = [
+                "python",
+                __file__,
+                "--config",
+                args.config,
+                "--path",
+                args.path,
+                "--rank",
+                str(idx),
+                "--world-size",
+                str(args.num_workers),
+            ]
+
+            if args.no_augmentation:
+                new_args.append("--no-augmentation")
+
+            if args.debug:
+                new_args.append("--debug")
+
+            env = deepcopy(os.environ)
+
+            # Respect CUDA_VISIBLE_DEVICES
+            if "CUDA_VISIBLE_DEVICES" in env:
+                devices = env["CUDA_VISIBLE_DEVICES"].split(",")
+                env["CUDA_VISIBLE_DEVICES"] = devices[idx % len(devices)]
+            else:
+                env["CUDA_VISIBLE_DEVICES"] = str(idx % max(torch.cuda.device_count(), 1))
+
+            processes.append(sp.Popen(new_args, env=env))
+            logger.info(f"{curr_worker} Launched worker {idx}")
+
+        for idx, p in enumerate(processes):
+            p.wait()
+
+            if p.returncode != 0:
+                logger.error(
+                    f"{curr_worker} Worker {idx} failed with code {p.returncode}, exiting..."
+                )
+                exit(p.returncode)
+
+        logger.info(f"{curr_worker} All workers done!")
+        exit(0)
+
+    # Load config
     config = Config.fromfile(args.config)
 
-    files = list_files(args.path, AUDIO_EXTENSIONS, recursive=True, sort=False)
-    logger.info(f"Found {len(files)} files, processing...")
+    files = list_files(args.path, AUDIO_EXTENSIONS, recursive=True, sort=True)
 
     # Shuffle files will balance the workload of workers
-    random.shuffle(files)
+    Random(42).shuffle(files)
+
+    logger.info(f"{curr_worker} Found {len(files)} files, processing...")
+
+    # Chunk files
+    if args.world_size > 1:
+        files = files[args.rank :: args.world_size]
+        logger.info(f"{curr_worker} Processing subset of {len(files)} files")
+
+    # Main process
     total_samples, failed = 0, 0
+    log_time = 0
+    start_time = time.time()
+
+    for idx, audio_path in enumerate(files):
+        i = safe_process(args, config, audio_path)
+        if isinstance(i, int):
+            total_samples += i
+        else:
+            failed += 1
+
+        if (idx + 1) % 100 == 0 and time.time() - log_time > 10:
+            eta = (time.time() - start_time) / (idx + 1) * (len(files) - idx - 1)
+
+            logger.info(
+                f"{curr_worker} "
+                + f"Processed {idx + 1}/{len(files)} files, "
+                + f"{total_samples} samples, {failed} failed, "
+                + f"ETA: {timedelta(seconds=eta)}"
+            )
+
+            log_time = time.time()
 
-    if args.num_workers <= 1:
-        for audio_path in tqdm(files):
-            i = safe_process(args, config, audio_path)
-            if isinstance(i, int):
-                total_samples += i
-            else:
-                failed += 1
-    else:
-        with ProcessPoolExecutor(
-            max_workers=args.num_workers,
-        ) as executor:
-            params = [(args, config, audio_path) for audio_path in files]
-
-            for i in tqdm(executor.map(safe_process, *zip(*params)), total=len(params)):
-                if isinstance(i, int):
-                    total_samples += i
-                else:
-                    failed += 1
-
-    logger.info(f"Finished!")
-    logger.info(f"Original samples: {len(files)}")
     logger.info(
-        f"Augmented samples: {total_samples} (x{total_samples / len(files):.2f})"
+        f"{curr_worker} Done! "
+        + f"Original samples: {len(files)}, "
+        + f"Augmented samples: {total_samples} (x{total_samples / len(files):.2f}), "
+        + f"Failed: {failed}"
     )
-    logger.info(f"Failed: {failed}")
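
A minimal sketch (not part of the patch) of why the fixed-seed shuffle matters: every
worker re-shuffles the same sorted file list with Random(42) before taking its
files[rank::world_size] stride, so the resulting shards are disjoint and jointly cover
every file exactly once. The file names below are hypothetical stand-ins for
list_files() output.

    from random import Random

    # Hypothetical stand-ins for the sorted list_files() output.
    files = [f"clip_{i:03d}.wav" for i in range(10)]

    world_size = 3
    shards = []
    for rank in range(world_size):
        worker_files = list(files)
        Random(42).shuffle(worker_files)  # same seed -> same order in every process
        shards.append(worker_files[rank::world_size])

    # The strided shards partition the list: no overlap, nothing dropped.
    assert sorted(sum(shards, [])) == sorted(files)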
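
Similarly, a hedged sketch of the CUDA_VISIBLE_DEVICES round-robin: each child process
ends up seeing exactly one physical GPU, which PyTorch exposes as the only CUDA device,
so init() can use a bare torch.device("cuda"). The helper assign_device is an
illustrative name, not part of the patch.

    from typing import Optional

    def assign_device(idx: int, visible: Optional[str], gpu_count: int) -> str:
        # Honour an existing CUDA_VISIBLE_DEVICES list if set,
        # otherwise round-robin over the physical GPU indices.
        if visible:
            devices = visible.split(",")
            return devices[idx % len(devices)]
        return str(idx % max(gpu_count, 1))

    # Four workers over CUDA_VISIBLE_DEVICES="0,1" alternate between the two GPUs.
    assert [assign_device(i, "0,1", 2) for i in range(4)] == ["0", "1", "0", "1"]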