Skip to content

Commit 30151f8

Browse files
committed
augmentationの追加
1 parent a488c7e commit 30151f8

File tree

4 files changed

+109
-13
lines changed

4 files changed

+109
-13
lines changed

lib/rvc/models.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -525,6 +525,7 @@ def __init__(
525525
self.segment_size = segment_size
526526
self.gin_channels = gin_channels
527527
self.emb_channels = emb_channels
528+
self.sr = sr
528529
# self.hop_length = hop_length#
529530
self.spk_embed_dim = spk_embed_dim
530531
self.enc_p = TextEncoder(
@@ -644,6 +645,7 @@ def __init__(
644645
self.segment_size = segment_size
645646
self.gin_channels = gin_channels
646647
self.emb_channels = emb_channels
648+
self.sr = sr
647649
# self.hop_length = hop_length#
648650
self.spk_embed_dim = spk_embed_dim
649651
self.enc_p = TextEncoder(

lib/rvc/preprocessing/extract_feature.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,22 @@
1111
from fairseq import checkpoint_utils
1212
from tqdm import tqdm
1313

14+
# Directory four levels above this file (used as the project root).
ROOT_DIR = os.path.dirname(
    os.path.dirname(
        os.path.dirname(
            os.path.dirname(os.path.abspath(__file__))
        )
    )
)

# Directory holding model files, located directly under ROOT_DIR.
MODELS_DIR = os.path.join(ROOT_DIR, "models")

# Registry of known embedders, keyed by embedder name.
# Each value is a 3-tuple: (weights file name, embedder name, load source).
EMBEDDINGS_LIST = {
    "hubert-base-japanese": (
        "rinna_hubert_base_jp.pt",
        "hubert-base-japanese",
        "local",
    ),
    "contentvec": (
        "checkpoint_best_legacy_500.pt",
        "contentvec",
        "local",
    ),
}
24+
25+
def get_embedder(embedder_name):
    """Look up the registry entry for ``embedder_name``.

    Returns the value stored in ``EMBEDDINGS_LIST`` under that key, or
    ``None`` when the name is not registered.
    """
    return EMBEDDINGS_LIST.get(embedder_name)
29+
1430

1531
def load_embedder(embedder_path: str, device):
1632
try:

lib/rvc/train.py

Lines changed: 85 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@
1111
import torch
1212
import torch.distributed as dist
1313
import torch.multiprocessing as mp
14+
import torchaudio
1415
import tqdm
1516
from sklearn.cluster import MiniBatchKMeans
1617
from torch.cuda.amp import GradScaler, autocast
@@ -22,20 +23,15 @@
2223
from . import commons, utils
2324
from .checkpoints import save
2425
from .config import DatasetMetadata, TrainConfig
25-
from .data_utils import (
26-
DistributedBucketSampler,
27-
TextAudioCollate,
28-
TextAudioCollateMultiNSFsid,
29-
TextAudioLoader,
30-
TextAudioLoaderMultiNSFsid,
31-
)
26+
from .data_utils import (DistributedBucketSampler, TextAudioCollate,
27+
TextAudioCollateMultiNSFsid, TextAudioLoader,
28+
TextAudioLoaderMultiNSFsid)
3229
from .losses import discriminator_loss, feature_loss, generator_loss, kl_loss
3330
from .mel_processing import mel_spectrogram_torch, spec_to_mel_torch
34-
from .models import (
35-
MultiPeriodDiscriminator,
36-
SynthesizerTrnMs256NSFSid,
37-
SynthesizerTrnMs256NSFSidNono,
38-
)
31+
from .models import (MultiPeriodDiscriminator, SynthesizerTrnMs256NSFSid,
32+
SynthesizerTrnMs256NSFSidNono)
33+
from .preprocessing.extract_feature import (MODELS_DIR, get_embedder,
34+
load_embedder)
3935

4036

4137
def is_audio_file(file: str):
@@ -149,6 +145,60 @@ def list_data(dir: str):
149145
json.dump(meta, f, indent=2)
150146

151147

148+
def change_speaker(
    net_g,
    embedder,
    embedding_output_layer,
    phone,
    phone_lengths,
    pitch,
    pitchf,
    spec_lengths,
    sid,
):
    """Re-synthesize a batch with shuffled speakers/pitch and re-embed it.

    Randomly changes the formant/pitch of each item, inspired by
    https://github.com/auspicious3000/contentvec/blob/d746688a32940f4bee410ed7c87ec9cf8ff04f74/contentvec/data/audio/audio_utils_1.py#L179

    Steps:
      1. Compute a clipped per-item median of ``pitchf`` and pair every item
         with a randomly chosen other item of the batch (a permutation).
      2. Move each item's median pitch part of the way (in log2 space)
         towards its partner's median, apply an extra random shift, and
         rebuild the f0 contour while keeping the relative contour.
      3. Quantize the new f0 to coarse mel-scale bins.
      4. Run ``net_g.infer`` with the partner's speaker id, resample the
         result to 16 kHz and extract features with ``embedder``.

    Args:
        net_g: Generator exposing ``infer(...)`` and a ``sr`` attribute
            (its output sample rate).
        embedder: Model exposing ``extract_features(source=..., padding_mask=...,
            output_layer=...)`` and ``final_proj``.
        embedding_output_layer: Forwarded as ``output_layer`` to the embedder.
        phone: Phone embeddings; the last dimension selects whether
            ``final_proj`` is applied (768 means raw features are used).
        phone_lengths: Forwarded to ``net_g.infer``.
        pitch: Original coarse pitch. Kept for interface compatibility; the
            coarse pitch used for synthesis is derived from the augmented f0.
        pitchf: f0 contour, indexed as (batch, time) by this function.
        spec_lengths: Per-item lengths used to build the padding mask.
        sid: Speaker ids, indexed along the batch dimension.

    Returns:
        A tensor with the same shape as ``phone`` (dtype of ``pitchf``)
        holding the features of the re-synthesized audio, zero-padded where
        the extracted features are shorter than ``phone``.
    """
    batch_size = pitchf.shape[0]
    device = pitchf.device
    dtype = pitchf.dtype

    # Coarse-pitch quantization constants (mel scale).
    f0_bin = 256
    f0_max = 1100.0
    f0_min = 50.0
    f0_mel_min = 1127 * np.log(1 + f0_min / 700)
    f0_mel_max = 1127 * np.log(1 + f0_max / 700)

    # Per-item median pitch, clipped to [75, 250] when the median is below
    # 200 and to [100, 400] otherwise.
    pitch_median = torch.median(pitchf, 1).values
    is_high = (pitch_median >= 200).to(dtype=dtype)
    lo = 75. + 25. * is_high
    hi = 250. + 150. * is_high
    pitch_median = torch.clip(pitch_median, lo, hi).unsqueeze(1)

    # Interpolation weight towards the partner's median pitch, in (0.25, 1].
    ratio_speaker = (
        torch.pow(.5, 2. * torch.rand(batch_size)).unsqueeze(1).to(device, dtype)
    )
    # Extra random pitch shift spanning one octave in total (+/- half an octave).
    shift_pitch = (
        torch.exp2((1. - 2. * torch.rand(batch_size)) / 2)
        .unsqueeze(1)
        .to(device, dtype)
    )

    # Random permutation pairing each item with the item it is moved towards.
    shuffle_ixs = np.arange(batch_size)
    np.random.shuffle(shuffle_ixs)

    rel_pitch = pitchf / pitch_median
    new_pitch_median = torch.exp2(
        torch.log2(pitch_median[shuffle_ixs]) * ratio_speaker
        + torch.log2(pitch_median) * (1. - ratio_speaker)
    ) * shift_pitch
    new_pitchf = new_pitch_median * rel_pitch
    new_sid = sid[shuffle_ixs]

    # Quantize the AUGMENTED f0 to coarse bins. The mel value computed from
    # new_pitchf must feed the rescaling step; using the original `pitch`
    # here would discard the augmentation for the coarse pitch input.
    new_pitch = 1127. * torch.log(1. + new_pitchf / 700.)
    new_pitch = (
        (new_pitch - f0_mel_min) * (f0_bin - 2.) / (f0_mel_max - f0_mel_min) + 1.
    )
    new_pitch = torch.clip(new_pitch, 1, f0_bin - 1).to(dtype=torch.int)

    new_wave = net_g.infer(phone, phone_lengths, new_pitch, new_pitchf, new_sid)[0]
    new_wave_16k = torchaudio.functional.resample(
        new_wave, net_g.sr, 16000, rolloff=0.99
    ).squeeze(1)

    # True marks padded samples: positions beyond spec_lengths * 160.
    # NOTE(review): 160 is presumably the number of 16 kHz samples per
    # spectrogram frame -- confirm against the data config.
    sample_ixs = torch.arange(new_wave_16k.shape[1], device=device).unsqueeze(0)
    padding_mask = sample_ixs > (spec_lengths.unsqueeze(1) * 160).to(device)

    inputs = {
        "source": new_wave_16k.to(device, dtype),
        "padding_mask": padding_mask.to(device),
        "output_layer": embedding_output_layer,
    }
    logits = embedder.extract_features(**inputs)
    if phone.shape[-1] == 768:
        feats = logits[0]
    else:
        feats = embedder.final_proj(logits[0])

    # Repeat every feature frame twice along the time axis.
    # NOTE(review): presumably matches the frame rate of `phone` -- confirm.
    feats = torch.repeat_interleave(feats, 2, 1)

    # Copy as many frames as both tensors share; the remainder stays zero.
    new_phone = torch.zeros(phone.shape).to(device, dtype)
    new_phone[:, :feats.shape[1]] = feats[:, :phone.shape[1]]
    return new_phone.to(device)
200+
201+
152202
def train_index(
153203
training_dir: str,
154204
model_name: str,
@@ -225,6 +275,7 @@ def train_model(
225275
sample_rate: int,
226276
f0: bool,
227277
batch_size: int,
278+
augment: bool,
228279
cache_batch: bool,
229280
total_epoch: int,
230281
save_every_epoch: int,
@@ -261,6 +312,7 @@ def train_model(
261312
sample_rate,
262313
f0,
263314
batch_size,
315+
augment,
264316
cache_batch,
265317
total_epoch,
266318
save_every_epoch,
@@ -284,6 +336,7 @@ def train_model(
284336
sample_rate,
285337
f0,
286338
batch_size,
339+
augment,
287340
cache_batch,
288341
total_epoch,
289342
save_every_epoch,
@@ -319,6 +372,7 @@ def training_runner(
319372
sample_rate: int,
320373
f0: bool,
321374
batch_size: int,
375+
augment: bool,
322376
cache_in_gpu: bool,
323377
total_epoch: int,
324378
save_every_epoch: int,
@@ -359,6 +413,17 @@ def training_runner(
359413

360414
torch.manual_seed(config.train.seed)
361415

416+
if augment:
417+
embedder_filepath, _, embedder_load_from = get_embedder(embedder_name)
418+
419+
if embedder_load_from == "local":
420+
embedder_filepath = os.path.join(
421+
MODELS_DIR, "embeddings", embedder_filepath
422+
)
423+
embedder, _ = load_embedder(embedder_filepath, device)
424+
if not config.train.fp16_run:
425+
embedder = embedder.float()
426+
362427
if f0:
363428
train_dataset = TextAudioLoaderMultiNSFsid(training_meta, config.data)
364429
else:
@@ -520,6 +585,7 @@ def training_runner(
520585
cache = []
521586
progress_bar = tqdm.tqdm(range((total_epoch - epoch + 1) * len(train_loader)))
522587
progress_bar.set_postfix(epoch=epoch)
588+
step = -1
523589
for epoch in range(epoch, total_epoch + 1):
524590
train_loader.batch_sampler.set_epoch(epoch)
525591

@@ -536,6 +602,7 @@ def training_runner(
536602
shuffle(cache)
537603

538604
for batch_idx, batch in data:
605+
step += 1
539606
progress_bar.update(1)
540607
if f0:
541608
(
@@ -614,6 +681,12 @@ def training_runner(
614681
)
615682

616683
with autocast(enabled=config.train.fp16_run):
684+
if f0 and augment:
685+
with torch.no_grad():
686+
new_phone = change_speaker(net_g, embedder, embedding_output_layer, phone, phone_lengths, pitch, pitchf, spec_lengths, sid)
687+
weight = np.power(.5, step / len(train_dataset)) # 学習の初期はそのままのphone embeddingを使う
688+
phone = phone * weight + new_phone * (1. - weight)
689+
617690
if f0:
618691
(
619692
y_hat,

modules/tabs/training.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,7 +6,8 @@
66
import gradio as gr
77

88
from lib.rvc.preprocessing import extract_f0, extract_feature, split
9-
from lib.rvc.train import create_dataset_meta, glob_dataset, train_index, train_model
9+
from lib.rvc.train import (create_dataset_meta, glob_dataset, train_index,
10+
train_model)
1011
from modules import models, utils
1112
from modules.shared import MODELS_DIR, device, half_support
1213
from modules.ui import Tab
@@ -141,6 +142,7 @@ def train_all(
141142
norm_audio_when_preprocess,
142143
pitch_extraction_algo,
143144
batch_size,
145+
augment,
144146
cache_batch,
145147
num_epochs,
146148
save_every_epoch,
@@ -243,6 +245,7 @@ def train_all(
243245
sampling_rate_str,
244246
f0,
245247
batch_size,
248+
augment,
246249
cache_batch,
247250
num_epochs,
248251
save_every_epoch,
@@ -359,6 +362,7 @@ def train_all(
359362
step=1,
360363
label="Save every epoch",
361364
)
365+
augment = gr.Checkbox(label="Augment", value=False)
362366
cache_batch = gr.Checkbox(label="Cache batch", value=True)
363367
fp16 = gr.Checkbox(
364368
label="FP16", value=half_support, disabled=not half_support
@@ -438,6 +442,7 @@ def train_all(
438442
norm_audio_when_preprocess,
439443
pitch_extraction_algo,
440444
batch_size,
445+
augment,
441446
cache_batch,
442447
num_epochs,
443448
save_every_epoch,

0 commit comments

Comments
 (0)