import torch
import torch.distributed as dist
import torch.multiprocessing as mp
+ import torchaudio
import tqdm
from sklearn.cluster import MiniBatchKMeans
from torch.cuda.amp import GradScaler, autocast

from . import commons, utils
from .checkpoints import save
from .config import DatasetMetadata, TrainConfig
- from .data_utils import (
-     DistributedBucketSampler,
-     TextAudioCollate,
-     TextAudioCollateMultiNSFsid,
-     TextAudioLoader,
-     TextAudioLoaderMultiNSFsid,
- )
+ from .data_utils import (DistributedBucketSampler, TextAudioCollate,
+                          TextAudioCollateMultiNSFsid, TextAudioLoader,
+                          TextAudioLoaderMultiNSFsid)
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
- from .models import (
-     MultiPeriodDiscriminator,
-     SynthesizerTrnMs256NSFSid,
-     SynthesizerTrnMs256NSFSidNono,
- )
+ from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
+                      SynthesizerTrnMs256NSFSidNono)
+ from .preprocessing.extract_feature import (MODELS_DIR, get_embedder,
+                                             load_embedder)


def is_audio_file(file: str):
@@ -149,6 +145,60 @@ def list_data(dir: str):
        json.dump(meta, f, indent=2)


+ def change_speaker(net_g, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths, sid):
+     """
+     Randomly change the formant: shift each item's pitch toward another (shuffled) speaker's range
+     and re-extract its phone embedding from the re-synthesized audio.
+     Inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
+     """
+     N = pitchf.shape[0]
+     device = pitchf.device
+     dtype = pitchf.dtype
+
+     f0_bin = 256
+     f0_max = 1100.0
+     f0_min = 50.0
+     f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+     f0_mel_max = 1127 * np.log(1 + f0_max / 700)
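+     # mel scale: 1127 * ln(1 + f / 700); these bounds are used below to quantize the shifted f0 into coarse bins (1 .. f0_bin - 1)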
+
+     pitch_median = torch.median(pitchf, 1).values
+     lo = 75. + 25. * (pitch_median >= 200).to(dtype=dtype)
+     hi = 250. + 150. * (pitch_median >= 200).to(dtype=dtype)
+     pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1)
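+     # keep the median in a plausible vocal range: [75, 250] Hz for voices with a median below 200 Hz, [100, 400] Hz otherwise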
+
+     ratio_speaker = torch.pow(.5, 2. * torch.rand(N)).unsqueeze(1).to(device, dtype)  # blend the pitch median toward the target speaker's at a ratio in [0.25, 1]
+     shift_pitch = torch.exp2((1. - 2. * torch.rand(N)) / 2).unsqueeze(1).to(device, dtype)  # additionally shift the pitch within a one-octave range
+
+     shuffle_ixs = np.arange(N)
+     np.random.shuffle(shuffle_ixs)
+     rel_pitch = pitchf / pitch_median
+     new_pitch_median = torch.exp2(torch.log2(pitch_median[shuffle_ixs]) * ratio_speaker + torch.log2(pitch_median) * (1. - ratio_speaker)) * shift_pitch
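+     # new median = geometric interpolation (in log2 space) between the shuffled item's median and this item's own median, times the random octave shift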
+     new_pitchf = new_pitch_median * rel_pitch
+     new_sid = sid[shuffle_ixs]
+
+     new_pitch = 1127. * torch.log(1. + new_pitchf / 700.)
+     new_pitch = (new_pitch - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1.
+     new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int)
+
+     new_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0]
+     new_wave_16k = torchaudio.functional.resample(new_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
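+     # mask out samples beyond the valid length; spec_lengths * 160 assumes 10 ms spectrogram frames (160 samples at 16 kHz)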
+     padding_mask = torch.arange(new_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)
+
+     inputs = {
+         "source": new_wave_16k.to(device, dtype),
+         "padding_mask": padding_mask.to(device),
+         "output_layer": embedding_output_layer
+     }
+     logits = embedder.extract_features(**inputs)
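+     # use the unprojected features when the model expects 768-dim phones; otherwise apply the embedder's final_proj to match the smaller feature size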
+     if phone.shape[-1] == 768:
+         feats = logits[0]
+     else:
+         feats = embedder.final_proj(logits[0])
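+     # the embedder emits roughly one frame per 20 ms; repeating each frame twice matches the 10 ms phone frame rate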
+     feats = torch.repeat_interleave(feats, 2, 1)
+     new_phone = torch.zeros(phone.shape).to(device, dtype)
+     new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
+     return new_phone.to(device)
+
+
def train_index(
    training_dir: str,
    model_name: str,
@@ -225,6 +275,7 @@ def train_model(
    sample_rate: int,
    f0: bool,
    batch_size: int,
+     augment: bool,
    cache_batch: bool,
    total_epoch: int,
    save_every_epoch: int,
@@ -261,6 +312,7 @@ def train_model(
        sample_rate,
        f0,
        batch_size,
+         augment,
        cache_batch,
        total_epoch,
        save_every_epoch,
@@ -284,6 +336,7 @@ def train_model(
        sample_rate,
        f0,
        batch_size,
+         augment,
        cache_batch,
        total_epoch,
        save_every_epoch,
@@ -319,6 +372,7 @@ def training_runner(
    sample_rate: int,
    f0: bool,
    batch_size: int,
+     augment: bool,
    cache_in_gpu: bool,
    total_epoch: int,
    save_every_epoch: int,
@@ -359,6 +413,17 @@ def training_runner(

    torch.manual_seed(config.train.seed)

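+     # load the same embedder used during preprocessing so augmented audio can be re-embedded on the fly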
+     if augment:
+         embedder_filepath, _, embedder_load_from = get_embedder(embedder_name)
+
+         if embedder_load_from == "local":
+             embedder_filepath = os.path.join(
+                 MODELS_DIR, "embeddings", embedder_filepath
+             )
+         embedder, _ = load_embedder(embedder_filepath, device)
+         if not config.train.fp16_run:
+             embedder = embedder.float()
+
    if f0:
        train_dataset = TextAudioLoaderMultiNSFsid(training_meta, config.data)
    else:
@@ -520,6 +585,7 @@ def training_runner(
    cache = []
    progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader)))
    progress_bar.set_postfix(epoch=epoch)
+     step = -1
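+     # global step counter; used below to anneal the augmentation mixing weight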
    for epoch in range(epoch, total_epoch + 1):
        train_loader.batch_sampler.set_epoch(epoch)
@@ -536,6 +602,7 @@ def training_runner(
            shuffle(cache)

        for batch_idx, batch in data:
+             step += 1
            progress_bar.update(1)
            if f0:
                (
@@ -614,6 +681,12 @@ def training_runner(
                )

            with autocast(enabled=config.train.fp16_run):
+                 if f0 and augment:
+                     with torch.no_grad():
+                         new_phone = change_speaker(net_g, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths, sid)
+                     weight = np.power(.5, step / len(train_dataset))  # early in training, keep the original phone embedding as-is
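+                     # the weight halves every len(train_dataset) steps, so augmented embeddings contribute more as training progresses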
+                     phone = phone * weight + new_phone * (1. - weight)
+
                if f0:
                    (
                        y_hat,