Skip to content

Commit b71d595

Browse files
committed
Merge branch 'dev'
2 parents c7aa393 + c5c6b8c commit b71d595

20 files changed

+406
-394
lines changed

launch.py

Lines changed: 3 additions & 23 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,8 @@
1-
import subprocess
1+
import importlib.util
22
import os
3-
import sys
43
import shlex
5-
import importlib.util
4+
import subprocess
5+
import sys
66

77
commandline_args = os.environ.get("COMMANDLINE_ARGS", "")
88
sys.argv += shlex.split(commandline_args)
@@ -88,24 +88,6 @@ def extract_arg(args, name):
8888
return [x for x in args if x != name], name in args
8989

9090

91-
def fix_faiss():
92-
spec = importlib.util.find_spec("faiss")
93-
if (
94-
spec.submodule_search_locations is None
95-
or len(spec.submodule_search_locations) == 0
96-
):
97-
return
98-
dir = spec.submodule_search_locations[0]
99-
if os.path.exists(os.path.join(dir, "swigfaiss_avx2.py")):
100-
return
101-
try:
102-
os.symlink(
103-
os.path.join(dir, "swigfaiss.py"), os.path.join(dir, "swigfaiss_avx2.py")
104-
)
105-
except:
106-
pass
107-
108-
10991
def prepare_environment():
11092
commit = commit_hash()
11193

@@ -140,8 +122,6 @@ def prepare_environment():
140122
errdesc=f"Couldn't install requirements",
141123
)
142124

143-
fix_faiss()
144-
145125

146126
def start():
147127
os.environ["PATH"] = (

lib/rvc/checkpoints.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,10 +14,12 @@ def write_config(state_dict: Dict[str, Any], cfg: Dict[str, Any]):
1414

1515
def create_trained_model(
1616
weights: Dict[str, Any],
17+
version: Literal["v1", "v2"],
1718
sr: str,
18-
f0: int,
19+
f0: bool,
1920
emb_name: str,
2021
emb_ch: int,
22+
emb_output_layer: int,
2123
epoch: int,
2224
):
2325
state_dict = OrderedDict()
@@ -101,19 +103,23 @@ def create_trained_model(
101103
"sr": 32000,
102104
},
103105
)
106+
state_dict["version"] = version
104107
state_dict["info"] = f"{epoch}epoch"
105108
state_dict["sr"] = sr
106-
state_dict["f0"] = int(f0)
109+
state_dict["f0"] = 1 if f0 else 0
107110
state_dict["embedder_name"] = emb_name
111+
state_dict["embedder_output_layer"] = emb_output_layer
108112
return state_dict
109113

110114

111115
def save(
112116
model,
117+
version: Literal["v1", "v2"],
113118
sr: str,
114-
f0: int,
119+
f0: bool,
115120
emb_name: str,
116121
emb_ch: int,
122+
emb_output_layer: int,
117123
filepath: str,
118124
epoch: int,
119125
):
@@ -126,10 +132,12 @@ def save(
126132

127133
state_dict = create_trained_model(
128134
state_dict,
135+
version,
129136
sr,
130137
f0,
131138
emb_name,
132139
emb_ch,
140+
emb_output_layer,
133141
epoch,
134142
)
135143
os.makedirs(os.path.dirname(filepath), exist_ok=True)

lib/rvc/config.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -52,6 +52,7 @@ class TrainConfigModel(BaseModel):
5252

5353

5454
class TrainConfig(BaseModel):
55+
version: Literal["v1", "v2"] = "v2"
5556
train: TrainConfigTrain
5657
data: TrainConfigData
5758
model: TrainConfigModel

lib/rvc/mel_processing.py

Lines changed: 35 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -24,23 +24,21 @@ def dynamic_range_decompression_torch(x, C=1):
2424

2525

2626
def spectral_normalize_torch(magnitudes):
27-
output = dynamic_range_compression_torch(magnitudes)
28-
return output
27+
return dynamic_range_compression_torch(magnitudes)
2928

3029

3130
def spectral_de_normalize_torch(magnitudes):
32-
output = dynamic_range_decompression_torch(magnitudes)
33-
return output
31+
return dynamic_range_decompression_torch(magnitudes)
3432

3533

3634
mel_basis = {}
3735
hann_window = {}
3836

3937

4038
def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False):
41-
if torch.min(y) < -1.0:
39+
if torch.min(y) < -1.07:
4240
print("min value is ", torch.min(y))
43-
if torch.max(y) > 1.0:
41+
if torch.max(y) > 1.07:
4442
print("max value is ", torch.max(y))
4543

4644
global hann_window
@@ -58,33 +56,25 @@ def spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center=False)
5856
)
5957
y = y.squeeze(1)
6058

61-
# 現在、mpsはtorch.stftをサポートしていない。
59+
# mps does not support torch.stft.
6260
if y.device.type == "mps":
63-
spec = torch.stft(
64-
y.cpu(),
65-
n_fft,
66-
hop_length=hop_size,
67-
win_length=win_size,
68-
window=hann_window[wnsize_dtype_device].cpu(),
69-
center=center,
70-
pad_mode="reflect",
71-
normalized=False,
72-
onesided=True,
73-
return_complex=False,
74-
).to(device=y.device)
61+
i = y.cpu()
62+
win = hann_window[wnsize_dtype_device].cpu()
7563
else:
76-
spec = torch.stft(
77-
y,
78-
n_fft,
79-
hop_length=hop_size,
80-
win_length=win_size,
81-
window=hann_window[wnsize_dtype_device],
82-
center=center,
83-
pad_mode="reflect",
84-
normalized=False,
85-
onesided=True,
86-
return_complex=False,
87-
)
64+
i = y
65+
win = hann_window[wnsize_dtype_device]
66+
spec = torch.stft(
67+
i,
68+
n_fft,
69+
hop_length=hop_size,
70+
win_length=win_size,
71+
window=win,
72+
center=center,
73+
pad_mode="reflect",
74+
normalized=False,
75+
onesided=True,
76+
return_complex=False,
77+
).to(device=y.device)
8878

8979
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
9080
return spec
@@ -99,71 +89,25 @@ def spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax):
9989
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
10090
dtype=spec.dtype, device=spec.device
10191
)
102-
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
103-
spec = spectral_normalize_torch(spec)
104-
return spec
92+
melspec = torch.matmul(mel_basis[fmax_dtype_device], spec)
93+
melspec = spectral_normalize_torch(melspec)
94+
return melspec
10595

10696

10797
def mel_spectrogram_torch(
10898
y, n_fft, num_mels, sampling_rate, hop_size, win_size, fmin, fmax, center=False
10999
):
110-
if torch.min(y) < -1.0:
111-
print("min value is ", torch.min(y))
112-
if torch.max(y) > 1.0:
113-
print("max value is ", torch.max(y))
100+
"""Convert waveform into Mel-frequency Log-amplitude spectrogram.
114101
115-
global mel_basis, hann_window
116-
dtype_device = str(y.dtype) + "_" + str(y.device)
117-
fmax_dtype_device = str(fmax) + "_" + dtype_device
118-
wnsize_dtype_device = str(win_size) + "_" + dtype_device
119-
if fmax_dtype_device not in mel_basis:
120-
mel = librosa_mel_fn(sampling_rate, n_fft, num_mels, fmin, fmax)
121-
mel_basis[fmax_dtype_device] = torch.from_numpy(mel).to(
122-
dtype=y.dtype, device=y.device
123-
)
124-
if wnsize_dtype_device not in hann_window:
125-
hann_window[wnsize_dtype_device] = torch.hann_window(win_size).to(
126-
dtype=y.dtype, device=y.device
127-
)
128-
129-
y = torch.nn.functional.pad(
130-
y.unsqueeze(1),
131-
(int((n_fft - hop_size) / 2), int((n_fft - hop_size) / 2)),
132-
mode="reflect",
133-
)
134-
y = y.squeeze(1)
135-
136-
# 現在、mpsはtorch.stftをサポートしていない。
137-
if y.device.type == "mps":
138-
spec = torch.stft(
139-
y.cpu(),
140-
n_fft,
141-
hop_length=hop_size,
142-
win_length=win_size,
143-
window=hann_window[wnsize_dtype_device].cpu(),
144-
center=center,
145-
pad_mode="reflect",
146-
normalized=False,
147-
onesided=True,
148-
return_complex=False,
149-
).to(device=y.device)
150-
else:
151-
spec = torch.stft(
152-
y,
153-
n_fft,
154-
hop_length=hop_size,
155-
win_length=win_size,
156-
window=hann_window[wnsize_dtype_device],
157-
center=center,
158-
pad_mode="reflect",
159-
normalized=False,
160-
onesided=True,
161-
return_complex=False,
162-
)
163-
164-
spec = torch.sqrt(spec.pow(2).sum(-1) + 1e-6)
102+
Args:
103+
y :: (B, T) - Waveforms
104+
Returns:
105+
melspec :: (B, Freq, Frame) - Mel-frequency Log-amplitude spectrogram
106+
"""
107+
# Linear-frequency Linear-amplitude spectrogram :: (B, T) -> (B, Freq, Frame)
108+
spec = spectrogram_torch(y, n_fft, sampling_rate, hop_size, win_size, center)
165109

166-
spec = torch.matmul(mel_basis[fmax_dtype_device], spec)
167-
spec = spectral_normalize_torch(spec)
110+
# Mel-frequency Log-amplitude spectrogram :: (B, Freq, Frame) -> (B, Freq=num_mels, Frame)
111+
melspec = spec_to_mel_torch(spec, n_fft, num_mels, sampling_rate, fmin, fmax)
168112

169-
return spec
113+
return melspec

lib/rvc/models.py

Lines changed: 28 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -714,36 +714,6 @@ def infer(self, phone, phone_lengths, sid, max_len=None):
714714
return o, x_mask, (z, z_p, m_p, logs_p)
715715

716716

717-
class MultiPeriodDiscriminator(torch.nn.Module):
718-
def __init__(self, use_spectral_norm=False):
719-
super(MultiPeriodDiscriminator, self).__init__()
720-
periods = [2, 3, 5, 7, 11, 17]
721-
# periods = [3, 5, 7, 11, 17, 23, 37]
722-
723-
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
724-
discs = discs + [
725-
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
726-
]
727-
self.discriminators = nn.ModuleList(discs)
728-
729-
def forward(self, y, y_hat):
730-
y_d_rs = [] #
731-
y_d_gs = []
732-
fmap_rs = []
733-
fmap_gs = []
734-
for i, d in enumerate(self.discriminators):
735-
y_d_r, fmap_r = d(y)
736-
y_d_g, fmap_g = d(y_hat)
737-
# for j in range(len(fmap_r)):
738-
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
739-
y_d_rs.append(y_d_r)
740-
y_d_gs.append(y_d_g)
741-
fmap_rs.append(fmap_r)
742-
fmap_gs.append(fmap_g)
743-
744-
return y_d_rs, y_d_gs, fmap_rs, fmap_gs
745-
746-
747717
class DiscriminatorS(torch.nn.Module):
748718
def __init__(self, use_spectral_norm=False):
749719
super(DiscriminatorS, self).__init__()
@@ -851,3 +821,31 @@ def forward(self, x):
851821
x = torch.flatten(x, 1, -1)
852822

853823
return x, fmap
824+
825+
826+
class MultiPeriodDiscriminator(torch.nn.Module):
827+
def __init__(self, use_spectral_norm=False, periods=[2, 3, 5, 7, 11, 17]):
828+
super(MultiPeriodDiscriminator, self).__init__()
829+
830+
discs = [DiscriminatorS(use_spectral_norm=use_spectral_norm)]
831+
discs = discs + [
832+
DiscriminatorP(i, use_spectral_norm=use_spectral_norm) for i in periods
833+
]
834+
self.discriminators = nn.ModuleList(discs)
835+
836+
def forward(self, y, y_hat):
837+
y_d_rs = [] #
838+
y_d_gs = []
839+
fmap_rs = []
840+
fmap_gs = []
841+
for i, d in enumerate(self.discriminators):
842+
y_d_r, fmap_r = d(y)
843+
y_d_g, fmap_g = d(y_hat)
844+
# for j in range(len(fmap_r)):
845+
# print(i,j,y.shape,y_hat.shape,fmap_r[j].shape,fmap_g[j].shape)
846+
y_d_rs.append(y_d_r)
847+
y_d_gs.append(y_d_g)
848+
fmap_rs.append(fmap_r)
849+
fmap_gs.append(fmap_g)
850+
851+
return y_d_rs, y_d_gs, fmap_rs, fmap_gs

0 commit comments

Comments
 (0)