Skip to content

Commit

Permalink
Faster preprocessing
Browse files Browse the repository at this point in the history
  • Loading branch information
leng-yue authored Dec 14, 2023
1 parent 3fa3153 commit a46342f
Showing 1 changed file with 111 additions and 46 deletions.
157 changes: 111 additions & 46 deletions tools/preprocessing/extract_features.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,21 @@
import argparse
import os
import random
from concurrent.futures import ProcessPoolExecutor
import subprocess as sp
import time
from copy import deepcopy
from datetime import timedelta
from pathlib import Path
from random import Random
from typing import Optional

import librosa
import numpy as np
import torch
import torch.multiprocessing as mp
import torchcrepe
from fish_audio_preprocess.utils.file import AUDIO_EXTENSIONS, list_files
from loguru import logger
from mmengine import Config
from tqdm import tqdm

from fish_diffusion.modules.energy_extractors import ENERGY_EXTRACTORS
from fish_diffusion.modules.feature_extractors import FEATURE_EXTRACTORS
Expand All @@ -38,14 +40,10 @@ def init(
global model_caches
device = torch.device("cpu")

rank = mp.current_process()._identity
rank = rank[0] if len(rank) > 0 else 0

if torch.cuda.is_available():
gpu_id = rank % torch.cuda.device_count()
device = torch.device(f"cuda:{gpu_id}")
device = torch.device("cuda")

logger.info(f"Rank {rank} uses device {device}")
logger.info(f"{curr_worker} Uses device {device}")

text_features_extractor = None
if getattr(config.preprocessing, "text_features_extractor", None):
Expand Down Expand Up @@ -213,7 +211,7 @@ def safe_process(args, config, audio_path: Path):

return aug_count + 1
except Exception as e:
logger.error(f"Error processing {audio_path}")
logger.error(f"{curr_worker} Error processing {audio_path}")

if args.debug:
logger.exception(e)
Expand All @@ -225,64 +223,131 @@ def parse_args():
parser.add_argument("--config", type=str, required=True)
parser.add_argument("--path", type=str, required=True)
parser.add_argument("--clean", action="store_true")
parser.add_argument("--num-workers", type=int, default=1)
parser.add_argument(
"--num-workers",
type=int,
default=1,
help="Number of workers, will launch a process pool if > 1",
)
parser.add_argument("--no-augmentation", action="store_true")
parser.add_argument("--debug", action="store_true")

# For multiprocessing
parser.add_argument("--rank", type=int, default=0)
parser.add_argument("--world-size", type=int, default=1)

return parser.parse_args()


def build_worker_env(base_env, idx, gpu_count):
    """Return a plain-dict environment that pins worker ``idx`` to one GPU.

    Args:
        base_env: Mapping of environment variables to start from. It is
            copied and never mutated.
        idx: Zero-based index of the worker being launched.
        gpu_count: Number of GPUs visible to the launcher; only used when
            ``CUDA_VISIBLE_DEVICES`` is not already set in ``base_env``.

    Returns:
        A new ``dict`` suitable for ``subprocess.Popen(env=...)``.
    """

    # Build a plain dict: assigning into os.environ (or a deepcopy of it,
    # which is still an environ-style mapping) also calls putenv() and would
    # leak the per-worker value into the launcher's own environment.
    env = dict(base_env)

    visible = env.get("CUDA_VISIBLE_DEVICES")
    if visible is not None:
        # Respect CUDA_VISIBLE_DEVICES: round-robin over the devices the
        # caller already exposed instead of over all physical devices.
        devices = visible.split(",")
        env["CUDA_VISIBLE_DEVICES"] = devices[idx % len(devices)]
    elif gpu_count > 0:
        env["CUDA_VISIBLE_DEVICES"] = str(idx % gpu_count)
    # Otherwise this is a CPU-only host: there is nothing to pin, and taking
    # a modulo by zero devices would raise ZeroDivisionError.

    return env


def augmentation_ratio(total_samples, num_files):
    """Return ``total_samples / num_files``, or ``0.0`` when there are no files."""

    if num_files == 0:
        return 0.0

    return total_samples / num_files


if __name__ == "__main__":
    # Imported here so this block does not depend on the file-level imports.
    import sys

    args = parse_args()

    # Log prefix. Assigned at module level on purpose: the log calls in
    # init() and safe_process() reference ``curr_worker`` as a global.
    curr_worker = f"[Rank {args.rank}]" if args.world_size > 1 else "[Main]"

    gpu_count = torch.cuda.device_count() if torch.cuda.is_available() else 0
    if gpu_count > 0:
        logger.info(f"{curr_worker} Found {gpu_count} GPUs")
    else:
        logger.warning(f"{curr_worker} No GPU found, using CPU")

    # Only clean on main process. Workers are launched without --clean, so
    # they never delete features that a sibling worker has just written.
    if args.clean and args.rank == 0:
        logger.info(f"{curr_worker} Cleaning *.npy files...")

        files = list_files(args.path, {".npy"}, recursive=True, sort=True)
        for f in files:
            f.unlink()

        logger.info(f"{curr_worker} Done!")

    # Multi-processing: re-invoke this script once per worker, each with its
    # own --rank, then wait for all of them.
    if args.num_workers > 1:
        logger.info(f"{curr_worker} Launching {args.num_workers} workers")

        # Use the interpreter running this launcher rather than whatever
        # "python" happens to resolve to on PATH.
        shared_args = [
            sys.executable,
            __file__,
            "--config",
            args.config,
            "--path",
            args.path,
            "--world-size",
            str(args.num_workers),
        ]

        if args.no_augmentation:
            shared_args.append("--no-augmentation")

        if args.debug:
            shared_args.append("--debug")

        processes = []
        for idx in range(args.num_workers):
            env = build_worker_env(os.environ, idx, gpu_count)
            worker_args = shared_args + ["--rank", str(idx)]

            processes.append(sp.Popen(worker_args, env=env))
            logger.info(f"{curr_worker} Launched worker {idx}")

        # Wait for every worker before deciding the exit code, so a failure
        # in one of them does not leave the others running unattended.
        exit_code = 0
        for idx, p in enumerate(processes):
            p.wait()

            if p.returncode != 0:
                logger.error(
                    f"{curr_worker} Worker {idx} failed with code {p.returncode}"
                )
                # Report the first failure as the launcher's exit status
                exit_code = exit_code or p.returncode

        if exit_code != 0:
            logger.error(f"{curr_worker} Some workers failed, exiting...")
            sys.exit(exit_code)

        logger.info(f"{curr_worker} All workers done!")
        sys.exit(0)

    # Load config
    config = Config.fromfile(args.config)
    files = list_files(args.path, AUDIO_EXTENSIONS, recursive=True, sort=True)

    # Shuffle files will balance the workload of workers. The listing is
    # sorted and the seed is fixed so that every rank derives the same order
    # and the strided slices below do not overlap.
    Random(42).shuffle(files)

    logger.info(f"{curr_worker} Found {len(files)} files, processing...")

    # Chunk files: rank r takes items r, r + world_size, r + 2 * world_size, ...
    if args.world_size > 1:
        files = files[args.rank :: args.world_size]
        logger.info(f"{curr_worker} Processing subset of {len(files)} files")

    # Main process
    total_samples, failed = 0, 0
    log_time = 0
    start_time = time.time()

    for idx, audio_path in enumerate(files):
        # safe_process returns an int sample count on success; any other
        # return value is counted as a failure.
        i = safe_process(args, config, audio_path)
        if isinstance(i, int):
            total_samples += i
        else:
            failed += 1

        # Progress is reported at most every 100 files and every 10 seconds
        if (idx + 1) % 100 == 0 and time.time() - log_time > 10:
            eta = (time.time() - start_time) / (idx + 1) * (len(files) - idx - 1)

            logger.info(
                f"{curr_worker} "
                + f"Processed {idx + 1}/{len(files)} files, "
                + f"{total_samples} samples, {failed} failed, "
                + f"ETA: {timedelta(seconds=eta)}"
            )

            log_time = time.time()

    # A rank can end up with no files (e.g. more workers than files); the
    # ratio helper avoids dividing by zero in that case.
    ratio = augmentation_ratio(total_samples, len(files))

    logger.info(
        f"{curr_worker} Done! "
        + f"Original samples: {len(files)}, "
        + f"Augmented samples: {total_samples} (x{ratio:.2f}), "
        + f"Failed: {failed}"
    )

0 comments on commit a46342f

Please sign in to comment.