Commit 65c37b2

Merge pull request #55 from Iamgoofball/multispeaker_adjustments
Crepe Support + Changed Multispeaker Training
2 parents b2c22cd + 6bf1c33 commit 65c37b2

File tree

6 files changed: +204 -28 lines changed

lib/rvc/pipeline.py

Lines changed: 85 additions & 1 deletion
@@ -8,7 +8,8 @@
 import scipy.signal as signal
 import torch
 import torch.nn.functional as F
-
+import torchcrepe
+from torch import Tensor
 # from faiss.swigfaiss_avx2 import IndexIVFFlat # cause crash on windows' faiss-cpu installed from pip
 from fairseq.models.hubert import HubertModel

@@ -51,6 +52,85 @@ def __init__(self, tgt_sr: int, device: Union[str, torch.device], is_half: bool)
         self.device = device
         self.is_half = is_half

+    def get_optimal_torch_device(self, index: int = 0) -> torch.device:
+        # Get cuda device
+        if torch.cuda.is_available():
+            return torch.device(f"cuda:{index % torch.cuda.device_count()}")  # Very fast
+        elif torch.backends.mps.is_available():
+            return torch.device("mps")
+        # Insert an else here to grab "xla" devices if available. TO DO later. Requires the torch_xla.core.xla_model library
+        # Else wise return the "cpu" as a torch device
+        return torch.device("cpu")
+
+    def get_f0_crepe_computation(
+        self,
+        x,
+        f0_min,
+        f0_max,
+        p_len,
+        hop_length=64,  # 512 before. Hop length changes the speed that the voice jumps to a different dramatic pitch. Lower hop lengths means more pitch accuracy but longer inference time.
+        model="full",  # Either use crepe-tiny "tiny" or crepe "full". Default is full
+    ):
+        x = x.astype(np.float32)  # fixes the F.conv2D exception. We needed to convert double to float.
+        x /= np.quantile(np.abs(x), 0.999)
+        torch_device = self.get_optimal_torch_device()
+        audio = torch.from_numpy(x).to(torch_device, copy=True)
+        audio = torch.unsqueeze(audio, dim=0)
+        if audio.ndim == 2 and audio.shape[0] > 1:
+            audio = torch.mean(audio, dim=0, keepdim=True).detach()
+        audio = audio.detach()
+        print("Initiating prediction with a crepe_hop_length of: " + str(hop_length))
+        pitch: Tensor = torchcrepe.predict(
+            audio,
+            self.sr,
+            hop_length,
+            f0_min,
+            f0_max,
+            model,
+            batch_size=hop_length * 2,
+            device=torch_device,
+            pad=True
+        )
+        p_len = p_len or x.shape[0] // hop_length
+        # Resize the pitch for final f0
+        source = np.array(pitch.squeeze(0).cpu().float().numpy())
+        source[source < 0.001] = np.nan
+        target = np.interp(
+            np.arange(0, len(source) * p_len, len(source)) / p_len,
+            np.arange(0, len(source)),
+            source
+        )
+        f0 = np.nan_to_num(target)
+        return f0  # Resized f0
+
+    def get_f0_official_crepe_computation(
+        self,
+        x,
+        f0_min,
+        f0_max,
+        model="full",
+    ):
+        # Pick a batch size that doesn't cause memory errors on your gpu
+        batch_size = 512
+        # Compute pitch using first gpu
+        audio = torch.tensor(np.copy(x))[None].float()
+        f0, pd = torchcrepe.predict(
+            audio,
+            self.sr,
+            self.window,
+            f0_min,
+            f0_max,
+            model,
+            batch_size=batch_size,
+            device=self.device,
+            return_periodicity=True,
+        )
+        pd = torchcrepe.filter.median(pd, 3)
+        f0 = torchcrepe.filter.mean(f0, 3)
+        f0[pd < 0.1] = 0
+        f0 = f0[0].cpu().numpy()
+        return f0
+
     def get_f0(
         self,
         x: np.ndarray,

@@ -84,6 +164,10 @@ def get_f0(
             )
             f0 = pyworld.stonemask(x.astype(np.double), f0, t, self.sr)
             f0 = signal.medfilt(f0, 3)
+        elif f0_method == "mangio-crepe":
+            f0 = self.get_f0_crepe_computation(x, f0_min, f0_max, p_len, 160, "full")
+        elif f0_method == "crepe":
+            f0 = self.get_f0_official_crepe_computation(x, f0_min, f0_max, "full")

         f0 *= pow(2, f0_up_key / 12)
         tf0 = self.sr // self.window  # f0 points per second
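
For context on the resize step in get_f0_crepe_computation: torchcrepe produces one pitch value per hop, so the raw track has roughly len(x) // hop_length frames and is stretched to the p_len frames the rest of the pipeline expects. A minimal standalone sketch of that interpolation, with made-up frame counts (100 source frames, a p_len of 150) purely for illustration:

import numpy as np

# Pretend crepe produced 100 f0 frames but the pipeline expects 150 (p_len).
source = np.abs(np.random.randn(100)) * 200.0 + 50.0  # fake f0 values in Hz
source[source < 0.001] = np.nan  # values under the threshold would become NaN (none do in this fake data)
p_len = 150

# Same resize as in get_f0_crepe_computation: evaluate p_len evenly spaced
# positions over the source index range and linearly interpolate.
target = np.interp(
    np.arange(0, len(source) * p_len, len(source)) / p_len,
    np.arange(0, len(source)),
    source,
)
f0 = np.nan_to_num(target)  # NaN (unvoiced) frames back to 0
print(source.shape, target.shape)  # (100,) (150,)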

lib/rvc/preprocessing/extract_f0.py

Lines changed: 86 additions & 0 deletions
@@ -5,10 +5,92 @@

 import numpy as np
 import pyworld
+import torch
+import torchcrepe
+from torch import Tensor
 from tqdm import tqdm

 from lib.rvc.utils import load_audio

+def get_optimal_torch_device(index: int = 0) -> torch.device:
+    # Get cuda device
+    if torch.cuda.is_available():
+        return torch.device(f"cuda:{index % torch.cuda.device_count()}")  # Very fast
+    elif torch.backends.mps.is_available():
+        return torch.device("mps")
+    # Insert an else here to grab "xla" devices if available. TO DO later. Requires the torch_xla.core.xla_model library
+    # Else wise return the "cpu" as a torch device
+    return torch.device("cpu")
+
+def get_f0_official_crepe_computation(
+    x,
+    sr,
+    f0_min,
+    f0_max,
+    model="full",
+):
+    batch_size = 512
+    torch_device = get_optimal_torch_device()
+    audio = torch.tensor(np.copy(x))[None].float()
+    f0, pd = torchcrepe.predict(
+        audio,
+        sr,
+        160,
+        f0_min,
+        f0_max,
+        model,
+        batch_size=batch_size,
+        device=torch_device,
+        return_periodicity=True,
+    )
+    pd = torchcrepe.filter.median(pd, 3)
+    f0 = torchcrepe.filter.mean(f0, 3)
+    f0[pd < 0.1] = 0
+    f0 = f0[0].cpu().numpy()
+    f0 = f0[1:]  # Get rid of extra first frame
+    return f0
+
+def get_f0_crepe_computation(
+    x,
+    sr,
+    f0_min,
+    f0_max,
+    hop_length=160,  # 512 before. Hop length changes the speed that the voice jumps to a different dramatic pitch. Lower hop lengths means more pitch accuracy but longer inference time.
+    model="full",  # Either use crepe-tiny "tiny" or crepe "full". Default is full
+):
+    x = x.astype(np.float32)  # fixes the F.conv2D exception. We needed to convert double to float.
+    x /= np.quantile(np.abs(x), 0.999)
+    torch_device = get_optimal_torch_device()
+    audio = torch.from_numpy(x).to(torch_device, copy=True)
+    audio = torch.unsqueeze(audio, dim=0)
+    if audio.ndim == 2 and audio.shape[0] > 1:
+        audio = torch.mean(audio, dim=0, keepdim=True).detach()
+    audio = audio.detach()
+    print("Initiating prediction with a crepe_hop_length of: " + str(hop_length))
+    pitch: Tensor = torchcrepe.predict(
+        audio,
+        sr,
+        hop_length,
+        f0_min,
+        f0_max,
+        model,
+        batch_size=hop_length * 2,
+        device=torch_device,
+        pad=True
+    )
+    p_len = x.shape[0] // hop_length
+    # Resize the pitch for final f0
+    source = np.array(pitch.squeeze(0).cpu().float().numpy())
+    source[source < 0.001] = np.nan
+    target = np.interp(
+        np.arange(0, len(source) * p_len, len(source)) / p_len,
+        np.arange(0, len(source)),
+        source
+    )
+    f0 = np.nan_to_num(target)
+    f0 = f0[1:]  # Get rid of extra first frame
+    return f0  # Resized f0
+

 def compute_f0(
     path: str,

@@ -37,6 +119,10 @@ def compute_f0(
             frame_period=1000 * hop / fs,
         )
         f0 = pyworld.stonemask(x.astype(np.double), f0, t, fs)
+    elif f0_method == "mangio-crepe":
+        f0 = get_f0_crepe_computation(x, fs, f0_min, f0_max, 160, "full")
+    elif f0_method == "crepe":
+        f0 = get_f0_official_crepe_computation(x.astype(np.double), fs, f0_min, f0_max, "full")
     return f0
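
A quick note on the numbers above, as a rough sanity check rather than anything stated in the commit: with the hop of 160 samples passed to both crepe branches, each f0 frame covers 160 / fs seconds, so at an assumed 16 kHz dataset rate that is 10 ms per frame.

# Rough frame-count arithmetic for the crepe extractors above.
# The 16 kHz rate and 3-second clip are assumptions for illustration only.
fs = 16000
hop = 160
clip_seconds = 3.0

samples = int(fs * clip_seconds)
frames = samples // hop  # one f0 value per hop
print(frames, "frames,", 1000 * hop / fs, "ms per frame")  # 300 frames, 10.0 ms per frame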

lib/rvc/train.py

Lines changed: 27 additions & 22 deletions
@@ -13,6 +13,7 @@
 import torch.multiprocessing as mp
 import torchaudio
 import tqdm
+import json
 from sklearn.cluster import MiniBatchKMeans
 from torch.cuda.amp import GradScaler, autocast
 from torch.nn import functional as F

@@ -56,33 +57,37 @@ def glob_dataset(
     recursive: bool = True,
 ):
     globs = glob_str.split(",")
+    speaker_count = 0
     datasets_speakers = []
+    speaker_to_id_mapping = {}
     for glob_str in globs:
         if os.path.isdir(glob_str):
-            files = os.listdir(glob_str)
             if multiple_speakers:
-                # pattern: {glob_str}/{decimal}[_]* and isdir
-                multi_speakers_dir = [
-                    (os.path.join(glob_str, f), int(f.split("_")[0]))
-                    for f in files
-                    if os.path.isdir(os.path.join(glob_str, f))
-                    and f.split("_")[0].isdecimal()
-                ]
-
-                if len(multi_speakers_dir) > 0:
-                    # multi speakers at once train
-                    datasets_speakers = [
-                        (file, dir[1])
-                        for dir in multi_speakers_dir
-                        for file in glob.iglob(
-                            os.path.join(dir[0], "*"), recursive=recursive
-                        )
-                        if is_audio_file(file)
-                    ]
-                    continue
+                # Multispeaker format:
+                # dataset_path/
+                #   - speakername/
+                #     - {wav name here}.wav
+                #     - ...
+                #   - next_speakername/
+                #     - {wav name here}.wav
+                #     - ...
+                #   - ...
+                print("Multispeaker dataset enabled; Processing speakers.")
+                for dir in tqdm.tqdm(os.listdir(glob_str)):
+                    print("Speaker ID " + str(speaker_count) + ": " + dir)
+                    speaker_to_id_mapping[dir] = speaker_count
+                    speaker_path = glob_str + "/" + dir
+                    for audio in tqdm.tqdm(os.listdir(speaker_path)):
+                        if is_audio_file(glob_str + "/" + dir + "/" + audio):
+                            datasets_speakers.append((glob_str + "/" + dir + "/" + audio, speaker_count))
+                    speaker_count += 1
+                with open("./speaker_info.json", "w") as outfile:
+                    print("Dumped speaker info to ./speaker_info.json")
+                    json.dump(speaker_to_id_mapping, outfile)
+                continue  # Skip the normal speaker extend

         glob_str = os.path.join(glob_str, "**", "*")
-
+        print("Single speaker dataset enabled; Processing speaker as ID " + str(speaker_id) + ".")
         datasets_speakers.extend(
             [
                 (file, speaker_id)

@@ -91,7 +96,7 @@ def glob_dataset(
             ]
         )

-    return sorted(datasets_speakers, key=operator.itemgetter(0))
+    return sorted(datasets_speakers)


 def create_dataset_meta(training_dir: str, f0: bool):
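
To make the new multispeaker bookkeeping concrete, here is a simplified standalone sketch of the ID assignment and the speaker_info.json dump. The speaker names are invented, and unlike the code above this sketch skips non-directory entries rather than assuming every entry under the dataset root is a speaker folder:

import json
import os

def build_speaker_mapping(dataset_path: str) -> dict:
    # Every sub-directory of the dataset root becomes one speaker,
    # numbered in the order os.listdir returns them.
    mapping = {}
    speaker_count = 0
    for name in os.listdir(dataset_path):
        if not os.path.isdir(os.path.join(dataset_path, name)):
            continue
        mapping[name] = speaker_count
        speaker_count += 1
    with open("./speaker_info.json", "w") as outfile:
        json.dump(mapping, outfile)
    return mapping

# For a layout like dataset/alice/*.wav and dataset/bob/*.wav this would
# write {"alice": 0, "bob": 1} to ./speaker_info.json.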

modules/tabs/inference.py

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ def inference_options_ui(show_out_dir=True):
             minimum=-20, maximum=20, value=0, step=1, label="Transpose"
         )
         pitch_extraction_algo = gr.Radio(
-            choices=["dio", "harvest"],
-            value="dio",
+            choices=["dio", "harvest", "mangio-crepe", "crepe"],
+            value="crepe",
             label="Pitch Extraction Algorithm",
         )
         embedding_model = gr.Radio(
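
The new radio choices are the same strings matched against f0_method in lib/rvc/pipeline.py and lib/rvc/preprocessing/extract_f0.py. A minimal sketch of how such a choice can be read back out of a Gradio radio component; the callback wiring here is illustrative only, not the repo's actual UI code:

import gradio as gr

with gr.Blocks() as demo:
    algo = gr.Radio(
        choices=["dio", "harvest", "mangio-crepe", "crepe"],
        value="crepe",
        label="Pitch Extraction Algorithm",
    )
    out = gr.Textbox(label="Selected f0_method")
    # The selected string is what would be passed along as f0_method.
    algo.change(fn=lambda v: v, inputs=algo, outputs=out)

# demo.launch()  # uncomment to try it locally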

modules/tabs/training.py

Lines changed: 2 additions & 2 deletions
@@ -354,8 +354,8 @@ def train_all(
             label="Normalize audio volume when preprocess",
         )
         pitch_extraction_algo = gr.Radio(
-            choices=["dio", "harvest"],
-            value="harvest",
+            choices=["dio", "harvest", "mangio-crepe", "crepe"],
+            value="crepe",
             label="Pitch extraction algorithm",
         )
         with gr.Row().style(equal_height=False):

requirements/main.txt

Lines changed: 2 additions & 1 deletion
@@ -5,13 +5,14 @@ faiss-cpu==1.7.3
 fairseq==0.12.2
 matplotlib==3.7.1
 scipy==1.9.3
-librosa==0.9.2
+librosa==0.9.1
 pyworld==0.3.2
 soundfile==0.12.1
 ffmpeg-python==0.2.0
 pydub==0.25.1
 soxr==0.3.5
 transformers==4.28.1
+torchcrepe==0.0.20

 tensorboard
 tensorboardX
