Skip to content

Commit fae3965

Browse files
authored
Merge pull request #65 from nadare881/multispeaker
multispeaker対応+名前の自動化+sampleの保存
2 parents bf7c10c + a56e23c commit fae3965

File tree

3 files changed: +114 additions, -73 deletions

lib/rvc/checkpoints.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -21,6 +21,7 @@ def create_trained_model(
2121
emb_ch: int,
2222
emb_output_layer: int,
2323
epoch: int,
24+
speaker_info: Optional[dict[str, int]]
2425
):
2526
state_dict = OrderedDict()
2627
state_dict["weight"] = {}
@@ -47,7 +48,7 @@ def create_trained_model(
4748
"upsample_rates": [10, 10, 2, 2],
4849
"upsample_initial_channel": 512,
4950
"upsample_kernel_sizes": [16, 16, 4, 4],
50-
"spk_embed_dim": 109,
51+
"spk_embed_dim": 109 if speaker_info is None else len(speaker_info),
5152
"gin_channels": 256,
5253
"emb_channels": emb_ch,
5354
"sr": 40000,
@@ -72,7 +73,7 @@ def create_trained_model(
7273
"upsample_rates": [10, 6, 2, 2, 2],
7374
"upsample_initial_channel": 512,
7475
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
75-
"spk_embed_dim": 109,
76+
"spk_embed_dim": 109 if speaker_info is None else len(speaker_info),
7677
"gin_channels": 256,
7778
"emb_channels": emb_ch,
7879
"sr": 48000,
@@ -97,7 +98,7 @@ def create_trained_model(
9798
"upsample_rates": [10, 4, 2, 2, 2],
9899
"upsample_initial_channel": 512,
99100
"upsample_kernel_sizes": [16, 16, 4, 4, 4],
100-
"spk_embed_dim": 109,
101+
"spk_embed_dim": 109 if speaker_info is None else len(speaker_info),
101102
"gin_channels": 256,
102103
"emb_channels": emb_ch,
103104
"sr": 32000,
@@ -109,6 +110,8 @@ def create_trained_model(
109110
state_dict["f0"] = 1 if f0 else 0
110111
state_dict["embedder_name"] = emb_name
111112
state_dict["embedder_output_layer"] = emb_output_layer
113+
if not speaker_info is None:
114+
state_dict["speaker_info"] = {str(v): str(k) for k, v in speaker_info.items()}
112115
return state_dict
113116

114117

@@ -122,6 +125,7 @@ def save(
122125
emb_output_layer: int,
123126
filepath: str,
124127
epoch: int,
128+
speaker_info: Optional[dict[str, int]]
125129
):
126130
if hasattr(model, "module"):
127131
state_dict = model.module.state_dict()
@@ -139,6 +143,7 @@ def save(
139143
emb_ch,
140144
emb_output_layer,
141145
epoch,
146+
speaker_info
142147
)
143148
os.makedirs(os.path.dirname(filepath), exist_ok=True)
144149
torch.save(state_dict, filepath)

Comments (0): 0 commit comments