@@ -150,6 +150,7 @@ def train_all(
150
150
save_every_epoch ,
151
151
save_wav_with_checkpoint ,
152
152
fp16 ,
153
+ save_only_last ,
153
154
pre_trained_bottom_model_g ,
154
155
pre_trained_bottom_model_d ,
155
156
run_train_index ,
@@ -264,7 +265,7 @@ def train_all(
264
265
pre_trained_bottom_model_d ,
265
266
embedder_name ,
266
267
int (embedding_output_layer ),
267
- False ,
268
+ save_only_last ,
268
269
None if len (gpu_ids ) > 1 else device ,
269
270
)
270
271
@@ -380,6 +381,9 @@ def train_all(
380
381
fp16 = gr .Checkbox (
381
382
label = "FP16" , value = half_support , disabled = not half_support
382
383
)
384
+ save_only_last = gr .Checkbox (
385
+ label = "Save only the latest G and D files" , value = False
386
+ )
383
387
with gr .Row (equal_height = False ):
384
388
augment = gr .Checkbox (label = "Augment" , value = False )
385
389
augment_from_pretrain = gr .Checkbox (
@@ -477,6 +481,7 @@ def train_all(
477
481
save_every_epoch ,
478
482
save_wav_with_checkpoint ,
479
483
fp16 ,
484
+ save_only_last ,
480
485
pre_trained_generator ,
481
486
pre_trained_discriminator ,
482
487
run_train_index ,
0 commit comments