import torch
import torch.distributed as dist
import torch.multiprocessing as mp
+import torchaudio
import tqdm
from sklearn.cluster import MiniBatchKMeans
from torch.cuda.amp import GradScaler, autocast
from . import commons, utils
from .checkpoints import save
from .config import DatasetMetadata, TrainConfig
-from .data_utils import (
-    DistributedBucketSampler,
-    TextAudioCollate,
-    TextAudioCollateMultiNSFsid,
-    TextAudioLoader,
-    TextAudioLoaderMultiNSFsid,
-)
+from .data_utils import (DistributedBucketSampler, TextAudioCollate,
+                         TextAudioCollateMultiNSFsid, TextAudioLoader,
+                         TextAudioLoaderMultiNSFsid)
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from .models import (
-    MultiPeriodDiscriminator,
-    SynthesizerTrnMs256NSFSid,
-    SynthesizerTrnMs256NSFSidNono,
-)
+from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
+                     SynthesizerTrnMs256NSFSidNono)
+from .preprocessing.extract_feature import (MODELS_DIR, get_embedder,
+                                            load_embedder)


def is_audio_file(file: str):
@@ -149,6 +145,91 @@ def list_data(dir: str):
        json.dump(meta, f, indent=2)


+def change_speaker(net_g, speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths):
+    """
+    Randomly change the formant (speaker timbre) of a batch.
+    Inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
+    """
+    N = phone.shape[0]
+    device = phone.device
+    dtype = phone.dtype
+
+    f0_bin = 256
+    f0_max = 1100.0
+    f0_min = 50.0
+    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+    f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
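+    # clamp each utterance's median f0 to a plausible range:
+    # [75, 250] Hz for lower-pitched voices, [100, 400] Hz for higher-pitched ones (split at 200 Hz)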
+    pitch_median = torch.median(pitchf, 1).values
+    lo = 75. + 25. * (pitch_median >= 200).to(dtype=dtype)
+    hi = 250. + 150. * (pitch_median >= 200).to(dtype=dtype)
+    pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1)
+
+    shift_pitch = torch.exp2((1. - 2. * torch.rand(N)) / 2).unsqueeze(1).to(device, dtype)  # shift pitch randomly within a one-octave range (±half an octave)
+
+    new_sid = np.random.choice(np.arange(len(speaker_info))[speaker_info > 0], size=N)
+    rel_pitch = pitchf / pitch_median
+    new_pitch_median = torch.from_numpy(speaker_info[new_sid]).to(device, dtype).unsqueeze(1) * shift_pitch
+    new_pitchf = new_pitch_median * rel_pitch
+    new_sid = torch.from_numpy(new_sid).to(device)
+
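+    # convert the new f0 curve to the coarse mel-scale bins (1..255) used for `pitch`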
+    new_pitch = 1127. * torch.log(1. + new_pitchf / 700.)
+    new_pitch = (new_pitch - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1.
+    new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int)
+
+    new_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0]
+    new_wave_16k = torchaudio.functional.resample(new_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
+    padding_mask = torch.arange(new_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)
+
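+    # re-extract phone features from the re-synthesized 16 kHz audio with the feature embedder;
+    # samples beyond each utterance's valid length (spec_lengths * hop 160) are masked out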
+    inputs = {
+        "source": new_wave_16k.to(device, dtype),
+        "padding_mask": padding_mask.to(device),
+        "output_layer": embedding_output_layer
+    }
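+    # features are repeated 2x along time to match the spectrogram frame rate;
+    # 768-dim phone features use the embedder output directly, otherwise final_proj maps it down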
+    logits = embedder.extract_features(**inputs)
+    if phone.shape[-1] == 768:
+        feats = logits[0]
+    else:
+        feats = embedder.final_proj(logits[0])
+    feats = torch.repeat_interleave(feats, 2, 1)
+    new_phone = torch.zeros(phone.shape).to(device, dtype)
+    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
+    return new_phone.to(device)
+
+
+def change_speaker_nono(net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths):
+    """
+    Randomly change the formant (speaker timbre) of a batch (no-f0 variant).
+    Inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
+    """
+    N = phone.shape[0]
+    device = phone.device
+    dtype = phone.dtype
+
+    new_sid = np.random.randint(net_g.spk_embed_dim, size=N)
+    new_sid = torch.from_numpy(new_sid).to(device)
+
+    new_wave = net_g.infer(phone, phone_lengths, new_sid)[0]
+    new_wave_16k = torchaudio.functional.resample(new_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
+    padding_mask = torch.arange(new_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)
+
+    inputs = {
+        "source": new_wave_16k.to(device, dtype),
+        "padding_mask": padding_mask.to(device),
+        "output_layer": embedding_output_layer
+    }
+
+    logits = embedder.extract_features(**inputs)
+    if phone.shape[-1] == 768:
+        feats = logits[0]
+    else:
+        feats = embedder.final_proj(logits[0])
+    feats = torch.repeat_interleave(feats, 2, 1)
+    new_phone = torch.zeros(phone.shape).to(device, dtype)
+    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
+    return new_phone.to(device)
+
+
def train_index(
    training_dir: str,
    model_name: str,
@@ -225,6 +306,9 @@ def train_model(
    sample_rate: int,
    f0: bool,
    batch_size: int,
+    augment: bool,
+    augment_path: Optional[str],
+    speaker_info_path: Optional[str],
    cache_batch: bool,
    total_epoch: int,
    save_every_epoch: int,
@@ -261,6 +345,9 @@ def train_model(
                sample_rate,
                f0,
                batch_size,
+                augment,
+                augment_path,
+                speaker_info_path,
                cache_batch,
                total_epoch,
                save_every_epoch,
@@ -284,6 +371,9 @@ def train_model(
            sample_rate,
            f0,
            batch_size,
+            augment,
+            augment_path,
+            speaker_info_path,
            cache_batch,
            total_epoch,
            save_every_epoch,
@@ -319,6 +409,9 @@ def training_runner(
    sample_rate: int,
    f0: bool,
    batch_size: int,
+    augment: bool,
+    augment_path: Optional[str],
+    speaker_info_path: Optional[str],
    cache_in_gpu: bool,
    total_epoch: int,
    save_every_epoch: int,
@@ -395,14 +488,15 @@ def training_runner(
            config.train.segment_size // config.data.hop_length,
            **config.model.dict(),
            is_half=config.train.fp16_run,
-            sr=sample_rate,
+            sr=int(sample_rate[:-1] + "000"),  # sample_rate like "40k" -> 40000
        )
    else:
        net_g = SynthesizerTrnMs256NSFSidNono(
            config.data.filter_length // 2 + 1,
            config.train.segment_size // config.data.hop_length,
            **config.model.dict(),
            is_half=config.train.fp16_run,
+            sr=int(sample_rate[:-1] + "000"),  # sample_rate like "40k" -> 40000
        )

    if is_multi_process:
@@ -440,6 +534,48 @@ def training_runner(
    last_d_state = utils.latest_checkpoint_path(state_dir, "D_*.pth")
    last_g_state = utils.latest_checkpoint_path(state_dir, "G_*.pth")

+    if augment:
+        # load embedder
+        embedder_filepath, _, embedder_load_from = get_embedder(embedder_name)
+
+        if embedder_load_from == "local":
+            embedder_filepath = os.path.join(
+                MODELS_DIR, "embeddings", embedder_filepath
+            )
+        embedder, _ = load_embedder(embedder_filepath, device)
+        if not config.train.fp16_run:
+            embedder = embedder.float()
+
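+        # if an external augmentation checkpoint was given, use it;
+        # otherwise fall back to the model currently being trained (else branch below)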
+        if augment_path is not None:
+            state_dict = torch.load(augment_path, map_location="cpu")
+            if state_dict["f0"] == 1:
+                augment_net_g = SynthesizerTrnMs256NSFSid(
+                    **state_dict["params"], is_half=config.train.fp16_run
+                )
+                augment_speaker_info = np.load(speaker_info_path)
+            else:
+                augment_net_g = SynthesizerTrnMs256NSFSidNono(
+                    **state_dict["params"], is_half=config.train.fp16_run
+                )
+
+            augment_net_g.load_state_dict(state_dict["weight"], strict=False)
+            augment_net_g.eval().to(device)
+
+            if config.train.fp16_run:
+                augment_net_g = augment_net_g.half()
+            else:
+                augment_net_g = augment_net_g.float()
+        else:
+            augment_net_g = net_g
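+            # compute each speaker's median f0 over voiced frames from the dataset;
+            # speakers with a median of 0 (no voiced frames) are skipped when sampling targets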
+            if f0:
+                medians = [[] for _ in range(augment_net_g.spk_embed_dim)]
+                for file in training_meta.files.values():
+                    f0f = np.load(file.f0nsf)
+                    if np.any(f0f > 0):
+                        medians[file.speaker_id].append(np.median(f0f[f0f > 0]))
+                augment_speaker_info = np.array([np.median(x) if len(x) else 0. for x in medians])
+                np.save(os.path.join(training_dir, "speaker_info.npy"), augment_speaker_info)
+
    if last_d_state is None or last_g_state is None:
        epoch = 1
        global_step = 0
@@ -520,6 +656,7 @@ def training_runner(
    cache = []
    progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader)))
    progress_bar.set_postfix(epoch=epoch)
+    step = -1
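+    # global batch counter; used to anneal the augmentation mix-in weight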
    for epoch in range(epoch, total_epoch + 1):
        train_loader.batch_sampler.set_epoch(epoch)

@@ -536,6 +673,7 @@ def training_runner(
            shuffle(cache)

        for batch_idx, batch in data:
+            step += 1
            progress_bar.update(1)
            if f0:
                (
@@ -614,6 +752,15 @@ def training_runner(
                )

            with autocast(enabled=config.train.fp16_run):
+                if augment:
+                    with torch.no_grad():
+                        if type(augment_net_g) == SynthesizerTrnMs256NSFSid:
+                            new_phone = change_speaker(augment_net_g, augment_speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths)
+                        else:
+                            new_phone = change_speaker_nono(augment_net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths)
+                    weight = np.power(.5, step / len(train_dataset))  # keep mostly the original phone embedding early in training; the weight decays with step
+                    phone = phone * weight + new_phone * (1. - weight)
+
                if f0:
                    (
                        y_hat,