-
Notifications
You must be signed in to change notification settings - Fork 5
/
save_alignment_paths.py
131 lines (96 loc) · 4.03 KB
/
save_alignment_paths.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
"""
Save the alignment matrices (hard attention weights) for the MUSDB dataset using the specified model (tag)
"""
import os
import pickle
import json
import torch
import numpy as np
import matplotlib.pyplot as plt
import testx
import data
import model
from estimate_alignment import optimal_alignment_path
# Which trained alignment model to use, and the source it was trained on.
tag = 'JOINT3'  # tag of alignment model
target = 'vocals'
model_path = 'trained_models/%s' % tag

# Seed torch's RNG so any randomness in the pipeline is repeatable.
torch.manual_seed(0)

# Inference runs on the CPU.
device = 'cpu'
print("Device:", device)
# load model
unmix = testx.load_model(target, model_path, device)
# presumably makes the forward pass also return the attention weights
# (alphas) and alignment scores, which are needed for the alignment path
unmix.return_alphas = True
# NOTE(review): forces centered STFT frames — confirm this matches training
unmix.stft.center = True

# Read the hyper-parameters stored alongside the trained model.
config_path = os.path.join(model_path, target + '.json')
try:
    with open(config_path, 'r') as stream:
        config = json.load(stream)
except FileNotFoundError:
    print('no config file found!')
    # quit() is an interactive-shell helper (missing under `python -S`) and
    # exits with status 0; report the failure with a non-zero exit status.
    raise SystemExit(1)

args = config['args']
samplerate = args['samplerate']
text_units = args['text_units']
nb_channels = args['nb_channels']
nfft = args['nfft']
nhop = args['nhop']
data_set = args['dataset']
# the key is optional in the config; default to False when it is absent
space_token_only = args.get('space_token_only', False)
# NOTE(review): all three datasets are built with space_token_only=True,
# regardless of the `space_token_only` value read from the model config —
# confirm this is intended.
test_set = data.MUSDBLyricsDataTest(samplerate=samplerate, text_units=text_units,
                                    space_token_only=True)
val_set = data.MUSDBLyricsDataVal(samplerate=samplerate, text_units=text_units,
                                  space_token_only=True, return_name=True)
train_set = data.MUSDBLyricsDataTrain(samplerate=samplerate, text_units=text_units,
                                      add_silence=False, random_track_mixing=False,
                                      space_token_only=True, return_name=True)

# Presumably maps token index -> CMU phoneme symbol. Loaded from a file inside
# the repository; pickle must never be pointed at untrusted data.
# `with` closes the handle, which the original open() call leaked.
with open('dicts/idx2cmu_phoneme.pickle', 'rb') as pickle_in:
    idx2symbol = pickle.load(pickle_in)
# go through data sets and save alignment path
# use clean vocals for training and validation set
# use mixtures for test set

# Make one output directory per split under base_path.
base_path = 'evaluation/{}/musdb_alignments/'.format(tag)
for split in ('train', 'val', 'test'):
    # exist_ok=True and one call per split: if base_path already exists but a
    # split directory is missing, that directory is still created (previously
    # the whole creation step was skipped and torch.save failed later).
    os.makedirs(os.path.join(base_path, split), exist_ok=True)
def _save_alignment_path(audio, text, split, name):
    """Run the alignment model on one track and save its optimal path.

    A leading batch dimension of size 1 is added to ``audio`` and ``text``
    before the forward pass (presumably the datasets return unbatched
    tensors — confirm). The path is converted to a float32 tensor and saved
    to ``<base_path>/<split>/<name>.pt``.

    Args:
        audio: model input audio for the track (mixture or clean vocals).
        text: text tensor for the track.
        split: output sub-directory, one of 'train', 'val', 'test'.
        name: track name, used as the file stem.
    """
    with torch.no_grad():
        _, _, scores = unmix((audio.unsqueeze(dim=0).to(device),
                              text.unsqueeze(dim=0).to(device)))
    optimal_path = optimal_alignment_path(scores, mode='max_numpy', init=2000)
    optimal_path = torch.from_numpy(optimal_path).type(torch.float32)
    torch.save(optimal_path, os.path.join(base_path, split, name + '.pt'))


# Samples are fetched by index: the datasets are only known to support
# len() and [] here, not iteration.

# TEST SET: the model is fed the mixtures.
for idx in range(len(test_set)):
    track = test_set[idx]
    # NOTE(review): drops the first two characters of the track name —
    # confirm the naming scheme this strips.
    name = track['name'][2:]
    _save_alignment_path(track['mix'], track['text'], 'test', name)
    print(idx, name)

# TRAIN and VAL SETS: the model is fed the clean vocals.
# Each sample is (mix, vocals, text, track name). It is bound to `sample`,
# not `data`, so the imported `data` module is not shadowed.
for split, dataset in (('train', train_set), ('val', val_set)):
    for idx in range(len(dataset)):
        sample = dataset[idx]
        name = sample[3]  # track name
        _save_alignment_path(sample[1], sample[2], split, name)
        print(idx, name)