
Commit 1ccca4a

Merge pull request #49 from nadare881/augment
experimental: add augmentation
2 parents a488c7e + 29897a4 commit 1ccca4a

4 files changed: 206 additions & 14 deletions


lib/rvc/models.py

Lines changed: 2 additions & 0 deletions
@@ -525,6 +525,7 @@ def __init__(
         self.segment_size = segment_size
         self.gin_channels = gin_channels
         self.emb_channels = emb_channels
+        self.sr = sr
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
         self.enc_p = TextEncoder(
@@ -644,6 +645,7 @@ def __init__(
         self.segment_size = segment_size
         self.gin_channels = gin_channels
         self.emb_channels = emb_channels
+        self.sr = sr
         # self.hop_length = hop_length#
         self.spk_embed_dim = spk_embed_dim
         self.enc_p = TextEncoder(

lib/rvc/preprocessing/extract_feature.py

Lines changed: 16 additions & 0 deletions
@@ -11,6 +11,22 @@
 from fairseq import checkpoint_utils
 from tqdm import tqdm
 
+ROOT_DIR = os.path.dirname(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
+MODELS_DIR = os.path.join(ROOT_DIR, "models")
+EMBEDDINGS_LIST = {
+    "hubert-base-japanese": (
+        "rinna_hubert_base_jp.pt",
+        "hubert-base-japanese",
+        "local",
+    ),
+    "contentvec": ("checkpoint_best_legacy_500.pt", "contentvec", "local"),
+}
+
+def get_embedder(embedder_name):
+    if embedder_name in EMBEDDINGS_LIST:
+        return EMBEDDINGS_LIST[embedder_name]
+    return None
+
 
 def load_embedder(embedder_path: str, device):
     try:
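The EMBEDDINGS_LIST registry added above maps an embedder name to a (checkpoint filename, embedder name, load location) tuple, and get_embedder is a plain lookup that returns None for unknown names. A minimal usage sketch of how the trainer is expected to resolve and load a local embedder with these helpers (the "contentvec" choice and the "cpu" device below are illustrative assumptions, not part of this commit):

import os

from lib.rvc.preprocessing.extract_feature import (
    MODELS_DIR,
    get_embedder,
    load_embedder,
)

entry = get_embedder("contentvec")  # ("checkpoint_best_legacy_500.pt", "contentvec", "local") or None
if entry is None:
    raise ValueError("unknown embedder")

filename, _, load_from = entry
if load_from == "local":
    # local checkpoints are expected under <repo>/models/embeddings/
    filename = os.path.join(MODELS_DIR, "embeddings", filename)

embedder, _ = load_embedder(filename, "cpu")  # device chosen for illustration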

lib/rvc/train.py

Lines changed: 160 additions & 13 deletions
@@ -11,6 +11,7 @@
 import torch
 import torch.distributed as dist
 import torch.multiprocessing as mp
+import torchaudio
 import tqdm
 from sklearn.cluster import MiniBatchKMeans
 from torch.cuda.amp import GradScaler, autocast
@@ -22,20 +23,15 @@
 from . import commons, utils
 from .checkpoints import save
 from .config import DatasetMetadata, TrainConfig
-from .data_utils import (
-    DistributedBucketSampler,
-    TextAudioCollate,
-    TextAudioCollateMultiNSFsid,
-    TextAudioLoader,
-    TextAudioLoaderMultiNSFsid,
-)
+from .data_utils import (DistributedBucketSampler, TextAudioCollate,
+                         TextAudioCollateMultiNSFsid, TextAudioLoader,
+                         TextAudioLoaderMultiNSFsid)
 from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
 from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
-from .models import (
-    MultiPeriodDiscriminator,
-    SynthesizerTrnMs256NSFSid,
-    SynthesizerTrnMs256NSFSidNono,
-)
+from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
+                     SynthesizerTrnMs256NSFSidNono)
+from .preprocessing.extract_feature import (MODELS_DIR, get_embedder,
+                                            load_embedder)
 
 
 def is_audio_file(file: str):
@@ -149,6 +145,91 @@ def list_data(dir: str):
         json.dump(meta, f, indent=2)
 
 
+def change_speaker(net_g, speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths):
+    """
+    random change formant
+    inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
+    """
+    N = phone.shape[0]
+    device = phone.device
+    dtype = phone.dtype
+
+    f0_bin = 256
+    f0_max = 1100.0
+    f0_min = 50.0
+    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
+    f0_mel_max = 1127 * np.log(1 + f0_max / 700)
+
+    pitch_median = torch.median(pitchf, 1).values
+    lo = 75. + 25. * (pitch_median >= 200).to(dtype=dtype)
+    hi = 250. + 150. * (pitch_median >= 200).to(dtype=dtype)
+    pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1)
+
+    shift_pitch = torch.exp2((1. - 2. * torch.rand(N)) / 2).unsqueeze(1).to(device, dtype)  # shift the pitch within a one-octave range
+
+    new_sid = np.random.choice(np.arange(len(speaker_info))[speaker_info > 0], size=N)
+    rel_pitch = pitchf / pitch_median
+    new_pitch_median = torch.from_numpy(speaker_info[new_sid]).to(device, dtype).unsqueeze(1) * shift_pitch
+    new_pitchf = new_pitch_median * rel_pitch
+    new_sid = torch.from_numpy(new_sid).to(device)
+
+    new_pitch = 1127. * torch.log(1. + new_pitchf / 700.)
+    new_pitch = (pitch - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1.
+    new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int)
+
+    new_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0]
+    new_wave_16k = torchaudio.functional.resample(new_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
+    padding_mask = torch.arange(new_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)
+
+    inputs = {
+        "source": new_wave_16k.to(device, dtype),
+        "padding_mask": padding_mask.to(device),
+        "output_layer": embedding_output_layer
+    }
+    logits = embedder.extract_features(**inputs)
+    if phone.shape[-1] == 768:
+        feats = logits[0]
+    else:
+        feats = embedder.final_proj(logits[0])
+    feats = torch.repeat_interleave(feats, 2, 1)
+    new_phone = torch.zeros(phone.shape).to(device, dtype)
+    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
+    return new_phone.to(device)
+
+
+def change_speaker_nono(net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths):
+    """
+    random change formant
+    inspired by https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179
+    """
+    N = phone.shape[0]
+    device = phone.device
+    dtype = phone.dtype
+
+    new_sid = np.random.randint(net_g.spk_embed_dim, size=N)
+    new_sid = torch.from_numpy(new_sid).to(device)
+
+    new_wave = net_g.infer(phone, phone_lengths, new_sid)[0]
+    new_wave_16k = torchaudio.functional.resample(new_wave, net_g.sr, 16000, rolloff=0.99).squeeze(1)
+    padding_mask = torch.arange(new_wave_16k.shape[1]).unsqueeze(0).to(device) > (spec_lengths.unsqueeze(1) * 160).to(device)
+
+    inputs = {
+        "source": new_wave_16k.to(device, dtype),
+        "padding_mask": padding_mask.to(device),
+        "output_layer": embedding_output_layer
+    }
+
+    logits = embedder.extract_features(**inputs)
+    if phone.shape[-1] == 768:
+        feats = logits[0]
+    else:
+        feats = embedder.final_proj(logits[0])
+    feats = torch.repeat_interleave(feats, 2, 1)
+    new_phone = torch.zeros(phone.shape).to(device, dtype)
+    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
+    return new_phone.to(device)
+
+
 def train_index(
     training_dir: str,
     model_name: str,
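Both helpers above follow the same pattern: re-synthesize the batch audio with a randomly drawn speaker embedding (and, in the f0 case, a pitch track rescaled to that speaker's median f0 plus a random shift), resample to 16 kHz, and re-extract the phone features from the resynthesized audio, so the model trains on speaker- and formant-shifted variants of the same content. The two purely numeric pieces can be checked in isolation; below is a small standalone sketch (same constants as the diff, illustrative values otherwise) of the random pitch-shift factor and the mel-scale coarse-f0 binning:

import math

import torch

# Pitch-shift factor: exp2 of a uniform value in (-0.5, 0.5], i.e. between
# 2**-0.5 and 2**0.5 -- at most half an octave up or down (a one-octave range),
# matching the translated comment in change_speaker.
N = 4
shift_pitch = torch.exp2((1. - 2. * torch.rand(N)) / 2)
assert torch.all((shift_pitch >= 2 ** -0.5) & (shift_pitch <= 2 ** 0.5))

# Coarse-f0 binning with the same constants as change_speaker: f0 in Hz is
# mapped to the mel scale and quantized into integer codes in 1..f0_bin-1.
f0_bin, f0_min, f0_max = 256, 50.0, 1100.0
f0_mel_min = 1127 * math.log(1 + f0_min / 700)
f0_mel_max = 1127 * math.log(1 + f0_max / 700)

f0_hz = torch.tensor([100.0, 220.0, 440.0])   # illustrative pitch values in Hz
f0_mel = 1127. * torch.log(1. + f0_hz / 700.)
codes = (f0_mel - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1.
codes = torch.clip(codes, 1, f0_bin - 1).to(dtype=torch.int)
print(codes)                                  # higher pitch -> larger code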
@@ -225,6 +306,9 @@ def train_model(
     sample_rate: int,
     f0: bool,
     batch_size: int,
+    augment: bool,
+    augment_path: Optional[str],
+    speaker_info_path: Optional[str],
     cache_batch: bool,
     total_epoch: int,
     save_every_epoch: int,
@@ -261,6 +345,9 @@ def train_model(
             sample_rate,
             f0,
             batch_size,
+            augment,
+            augment_path,
+            speaker_info_path,
             cache_batch,
             total_epoch,
             save_every_epoch,
@@ -284,6 +371,9 @@ def train_model(
             sample_rate,
             f0,
             batch_size,
+            augment,
+            augment_path,
+            speaker_info_path,
             cache_batch,
             total_epoch,
             save_every_epoch,
@@ -319,6 +409,9 @@ def training_runner(
     sample_rate: int,
     f0: bool,
     batch_size: int,
+    augment: bool,
+    augment_path: Optional[str],
+    speaker_info_path: Optional[str],
     cache_in_gpu: bool,
     total_epoch: int,
     save_every_epoch: int,
@@ -395,14 +488,15 @@
             config.train.segment_size // config.data.hop_length,
             **config.model.dict(),
             is_half=config.train.fp16_run,
-            sr=sample_rate,
+            sr=int(sample_rate[:-1] + "000"),
         )
     else:
         net_g = SynthesizerTrnMs256NSFSidNono(
             config.data.filter_length // 2 + 1,
             config.train.segment_size // config.data.hop_length,
             **config.model.dict(),
             is_half=config.train.fp16_run,
+            sr=int(sample_rate[:-1] + "000"),
         )
 
     if is_multi_process:
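The new sr keyword populates the self.sr attribute added in lib/rvc/models.py, which change_speaker/change_speaker_nono hand to torchaudio.functional.resample. Note that train_all passes sampling_rate_str here (e.g. "40k"), so despite the int annotation the value is a rate tag string, and int(sample_rate[:-1] + "000") turns it into Hz. A quick standalone check of that expression (the tag values below are illustrative):

def tag_to_hz(sample_rate: str) -> int:
    # "40k" -> int("40" + "000") -> 40000; same expression as in the diff.
    return int(sample_rate[:-1] + "000")

assert tag_to_hz("32k") == 32000
assert tag_to_hz("40k") == 40000
assert tag_to_hz("48k") == 48000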
@@ -440,6 +534,48 @@
     last_d_state = utils.latest_checkpoint_path(state_dir, "D_*.pth")
     last_g_state = utils.latest_checkpoint_path(state_dir, "G_*.pth")
 
+    if augment:
+        # load embedder
+        embedder_filepath, _, embedder_load_from = get_embedder(embedder_name)
+
+        if embedder_load_from == "local":
+            embedder_filepath = os.path.join(
+                MODELS_DIR, "embeddings", embedder_filepath
+            )
+        embedder, _ = load_embedder(embedder_filepath, device)
+        if not config.train.fp16_run:
+            embedder = embedder.float()
+
+        if (augment_path is not None):
+            state_dict = torch.load(augment_path, map_location="cpu")
+            if state_dict["f0"] == 1:
+                augment_net_g = SynthesizerTrnMs256NSFSid(
+                    **state_dict["params"], is_half=config.train.fp16_run
+                )
+                augment_speaker_info = np.load(speaker_info_path)
+            else:
+                augment_net_g = SynthesizerTrnMs256NSFSidNono(
+                    **state_dict["params"], is_half=config.train.fp16_run
+                )
+
+            augment_net_g.load_state_dict(state_dict["weight"], strict=False)
+            augment_net_g.eval().to(device)
+
+            if config.train.fp16_run:
+                augment_net_g = augment_net_g.half()
+            else:
+                augment_net_g = augment_net_g.float()
+        else:
+            augment_net_g = net_g
+            if f0:
+                medians = [[] for _ in range(augment_net_g.spk_embed_dim)]
+                for file in training_meta.files.values():
+                    f0f = np.load(file.f0nsf)
+                    if np.any(f0f > 0):
+                        medians[file.speaker_id].append(np.median(f0f[f0f > 0]))
+                augment_speaker_info = np.array([np.median(x) if len(x) else 0. for x in medians])
+                np.save(os.path.join(training_dir, "speaker_info.npy"), augment_speaker_info)
+
     if last_d_state is None or last_g_state is None:
         epoch = 1
         global_step = 0
@@ -520,6 +656,7 @@
         cache = []
     progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader)))
     progress_bar.set_postfix(epoch=epoch)
+    step = -1
    for epoch in range(epoch, total_epoch + 1):
        train_loader.batch_sampler.set_epoch(epoch)
 
@@ -536,6 +673,7 @@
            shuffle(cache)
 
        for batch_idx, batch in data:
+            step += 1
            progress_bar.update(1)
            if f0:
                (
@@ -614,6 +752,15 @@
                 )
 
             with autocast(enabled=config.train.fp16_run):
+                if augment:
+                    with torch.no_grad():
+                        if type(augment_net_g) == SynthesizerTrnMs256NSFSid:
+                            new_phone = change_speaker(augment_net_g, augment_speaker_info, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths)
+                        else:
+                            new_phone = change_speaker_nono(augment_net_g, embedder, embedding_output_layer, phone, phone_lengths, spec_lengths)
+                    weight = np.power(.5, step / len(train_dataset))  # use the unmodified phone embedding early in training
+                    phone = phone * weight + new_phone * (1. - weight)
+
                 if f0:
                     (
                         y_hat,
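The blend weight decays geometrically with the global step counter introduced above: it starts at 1, so early batches use the original phone features (per the translated comment), and it halves every len(train_dataset) steps, gradually shifting the mix toward the re-synthesized features. A standalone sketch of the schedule (the dataset size below is an illustrative assumption):

import numpy as np

def augment_blend_weight(step: int, dataset_len: int) -> float:
    # Same expression as the diff: 0.5 ** (step / len(train_dataset)).
    return float(np.power(.5, step / dataset_len))

dataset_len = 1000                    # illustrative, not from the commit
for step in (0, 500, 1000, 2000):
    w = augment_blend_weight(step, dataset_len)
    # phone = phone * w + new_phone * (1. - w)
    print(step, round(w, 3))          # -> 1.0, 0.707, 0.5, 0.25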

modules/tabs/training.py

Lines changed: 28 additions & 1 deletion
@@ -6,7 +6,8 @@
 import gradio as gr
 
 from lib.rvc.preprocessing import extract_f0, extract_feature, split
-from lib.rvc.train import create_dataset_meta, glob_dataset, train_index, train_model
+from lib.rvc.train import (create_dataset_meta, glob_dataset, train_index,
+                           train_model)
 from modules import models, utils
 from modules.shared import MODELS_DIR, device, half_support
 from modules.ui import Tab
@@ -141,6 +142,10 @@ def train_all(
        norm_audio_when_preprocess,
        pitch_extraction_algo,
        batch_size,
+        augment,
+        augment_from_pretrain,
+        augment_path,
+        speaker_info_path,
        cache_batch,
        num_epochs,
        save_every_epoch,
@@ -234,6 +239,10 @@ def train_all(
        )
        out_dir = os.path.join(MODELS_DIR, "checkpoints")
 
+        if not augment_from_pretrain:
+            augment_path = None
+            speaker_info_path = None
+
        train_model(
            gpu_ids,
            config,
@@ -243,6 +252,9 @@ def train_all(
            sampling_rate_str,
            f0,
            batch_size,
+            augment,
+            augment_path,
+            speaker_info_path,
            cache_batch,
            num_epochs,
            save_every_epoch,
@@ -363,6 +375,17 @@ def train_all(
                    fp16 = gr.Checkbox(
                        label="FP16", value=half_support, disabled=not half_support
                    )
+                with gr.Row().style(equal_height=False):
+                    augment = gr.Checkbox(label="Augment", value=False)
+                    augment_from_pretrain = gr.Checkbox(label="Augment From Pretrain", value=False)
+                    augment_path = gr.Textbox(
+                        label="Pre trained generator path (pth)",
+                        value="file is not prepared"
+                    )
+                    speaker_info_path = gr.Textbox(
+                        label="speaker info path (npy)",
+                        value="file is not prepared"
+                    )
                with gr.Row().style(equal_height=False):
                    pre_trained_generator = gr.Textbox(
                        label="Pre trained generator path",
@@ -438,6 +461,10 @@ def train_all(
                norm_audio_when_preprocess,
                pitch_extraction_algo,
                batch_size,
+                augment,
+                augment_from_pretrain,
+                augment_path,
+                speaker_info_path,
                cache_batch,
                num_epochs,
                save_every_epoch,
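End to end, the four new UI values travel as follows: the two checkboxes and two textboxes are appended to the click-handler inputs, train_all receives them as new arguments, and unless "Augment From Pretrain" is checked the two path fields are dropped before train_model is called, in which case training_runner augments with the model being trained (net_g) and derives speaker_info.npy from the dataset's f0 files. A condensed sketch of just that gating (everything else in train_all is omitted):

from typing import Optional

def gate_augment_paths(
    augment_from_pretrain: bool,
    augment_path: Optional[str],
    speaker_info_path: Optional[str],
):
    # Same gating as train_all: without "Augment From Pretrain" the textbox
    # placeholders are discarded and the trainer falls back to net_g itself.
    if not augment_from_pretrain:
        augment_path = None
        speaker_info_path = None
    return augment_path, speaker_info_path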
