2424import deeplabcut
2525from deeplabcut import Engine
2626from deeplabcut .core .config import read_config_as_dict , write_config
27- from deeplabcut .generate_training_dataset .metadata import TrainingDatasetMetadata , ShuffleMetadata , DataSplit
28- from deeplabcut .generate_training_dataset .trainingsetmanipulation import MakeInference_yaml
27+ from deeplabcut .generate_training_dataset .metadata import (
28+ TrainingDatasetMetadata ,
29+ ShuffleMetadata ,
30+ DataSplit ,
31+ )
32+ from deeplabcut .generate_training_dataset .trainingsetmanipulation import (
33+ MakeInference_yaml ,
34+ )
2935from deeplabcut .modelzoo .utils import get_super_animal_project_cfg
30- from deeplabcut .pose_estimation_pytorch .config .make_pose_config import add_metadata , make_pytorch_test_config
36+ from deeplabcut .pose_estimation_pytorch .config .make_pose_config import (
37+ add_metadata ,
38+ make_pytorch_test_config ,
39+ )
3140from deeplabcut .pose_estimation_pytorch .modelzoo .utils import load_super_animal_config
3241from deeplabcut .utils import auxiliaryfunctions
3342
@@ -344,13 +353,19 @@ def create_pretrained_project_pytorch(
344353 detector_name = "fasterrcnn_resnet50_fpn_v2"
345354
346355 if dataset not in get_available_datasets ():
347- raise ValueError (f"Invalid dataset '{ dataset } '. Available datasets are: { get_available_datasets ()} " )
356+ raise ValueError (
357+ f"Invalid dataset '{ dataset } '. Available datasets are: { get_available_datasets ()} "
358+ )
348359
349360 if net_name not in get_available_models (dataset ):
350- raise ValueError (f"Invalid net_name '{ net_name } ' for dataset { dataset } . The following net types are available: { get_available_models (dataset )} " )
361+ raise ValueError (
362+ f"Invalid net_name '{ net_name } ' for dataset { dataset } . The following net types are available: { get_available_models (dataset )} "
363+ )
351364
352365 if detector_name not in get_available_detectors (dataset ):
353- raise ValueError (f"Invalid detector_name '{ detector_name } ' for dataset { dataset } . The following detectors are available: { get_available_detectors (dataset )} " )
366+ raise ValueError (
367+ f"Invalid detector_name '{ detector_name } ' for dataset { dataset } . The following detectors are available: { get_available_detectors (dataset )} "
368+ )
354369
355370 # Create project
356371 cfg_path = deeplabcut .create_new_project (
@@ -380,7 +395,12 @@ def create_pretrained_project_pytorch(
380395
381396 # Create the shuffle train and test directories
382397 config = read_config_as_dict (cfg_path )
383- shuffle_dir = Path (cfg_path ).parent / auxiliaryfunctions .get_model_folder (trainFraction = config ["TrainingFraction" ][0 ], shuffle = 1 , cfg = config , engine = Engine .PYTORCH )
398+ shuffle_dir = Path (cfg_path ).parent / auxiliaryfunctions .get_model_folder (
399+ trainFraction = config ["TrainingFraction" ][0 ],
400+ shuffle = 1 ,
401+ cfg = config ,
402+ engine = Engine .PYTORCH ,
403+ )
384404 train_dir = shuffle_dir / "train"
385405 test_dir = shuffle_dir / "test"
386406 train_dir .mkdir (parents = True , exist_ok = True )
@@ -393,14 +413,14 @@ def create_pretrained_project_pytorch(
393413 download_huggingface_model (
394414 model_name = super_animal_detector_name ,
395415 target_dir = str (train_dir ),
396- rename_mapping = {f"{ super_animal_detector_name } .pt" : new_detector_name }
416+ rename_mapping = {f"{ super_animal_detector_name } .pt" : new_detector_name },
397417 )
398418 super_animal_model_name = f"{ dataset } _{ net_name } "
399419 new_snapshot_name = "snapshot-000.pt"
400420 download_huggingface_model (
401421 model_name = super_animal_model_name ,
402422 target_dir = str (train_dir ),
403- rename_mapping = {f"{ super_animal_model_name } .pt" : new_snapshot_name }
423+ rename_mapping = {f"{ super_animal_model_name } .pt" : new_snapshot_name },
404424 )
405425
406426 # Create pytorch_config.yaml
@@ -412,12 +432,16 @@ def create_pretrained_project_pytorch(
412432 )
413433 pytorch_config = add_metadata (config , pytorch_config , train_cfg_path )
414434 pytorch_config ["resume_training_from" ] = str (train_dir / new_snapshot_name )
415- pytorch_config ["detector" ]["resume_training_from" ] = str (train_dir / new_detector_name )
435+ pytorch_config ["detector" ]["resume_training_from" ] = str (
436+ train_dir / new_detector_name
437+ )
416438 write_config (train_cfg_path , pytorch_config )
417439
418440 # Create test pose_cfg.yaml
419441 test_cfg_path = test_dir / "pose_cfg.yaml"
420- make_pytorch_test_config (model_config = pytorch_config , test_config_path = test_cfg_path , save = True )
442+ make_pytorch_test_config (
443+ model_config = pytorch_config , test_config_path = test_cfg_path , save = True
444+ )
421445
422446 # Create inference_cfg.yaml if needed
423447 if multi_animal :
@@ -438,25 +462,23 @@ def create_pretrained_project_pytorch(
438462 return cfg_path , str (train_cfg_path )
439463
440464
def _create_inference_config(inference_cfg_path: str | Path, project_cfg: dict) -> None:
    """Write an ``inference_cfg.yaml`` for a multi-animal project.

    Builds a small set of overrides from the project configuration and merges
    them into DeepLabCut's default inference-config template via
    ``MakeInference_yaml``.

    Parameters
    ----------
    inference_cfg_path : str | Path
        Destination path where the inference configuration will be written.
    project_cfg : dict
        Project configuration. Must contain the keys
        ``"multianimalbodyparts"`` and ``"individuals"``; may contain
        ``"identity"`` (defaults to ``False`` when absent).
    """
    inf_updates = dict(
        # Half the number of multi-animal bodyparts, truncated to an int.
        minimalnumberofconnections=int(len(project_cfg["multianimalbodyparts"]) / 2),
        topktoretain=len(project_cfg["individuals"]),
        withid=project_cfg.get("identity", False),
    )
    # Use the default template shipped inside the DeepLabCut package as the base.
    default_inf_path = (
        Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
    )
    MakeInference_yaml(inf_updates, inference_cfg_path, default_inf_path)
453475
454476
455477def create_pretrained_project_tensorflow (
456478 project : str ,
457479 experimenter : str ,
458480 videos : list [str ],
459- model : str | None = None ,
481+ model : str | None = None ,
460482 working_directory : str | None = None ,
461483 copy_videos : bool = False ,
462484 videotype : str = "" ,
@@ -683,7 +705,10 @@ def create_pretrained_project_tensorflow(
683705 else :
684706 return "N/A" , "N/A"
685707
686- def _create_training_datasets_metadata (config : dict , shuffle_dir_name : str , engine : Engine ):
708+
709+ def _create_training_datasets_metadata (
710+ config : dict , shuffle_dir_name : str , engine : Engine
711+ ):
687712 # First create the metadata object
688713 metadata = TrainingDatasetMetadata .create (config )
689714
@@ -693,7 +718,7 @@ def _create_training_datasets_metadata(config: dict, shuffle_dir_name: str, engi
693718 train_fraction = config ["TrainingFraction" ][0 ],
694719 index = 1 ,
695720 engine = engine ,
696- split = DataSplit (train_indices = (),test_indices = ())
721+ split = DataSplit (train_indices = (), test_indices = ()),
697722 )
698723
699724 # Add the shuffle to metadata
@@ -717,7 +742,9 @@ def _process_videos(
717742
718743 if analyze_video :
719744 print ("Analyzing video..." )
720- deeplabcut .analyze_videos (cfg_path , [video_dir ], videotype = video_type , save_as_csv = True )
745+ deeplabcut .analyze_videos (
746+ cfg_path , [video_dir ], videotype = video_type , save_as_csv = True
747+ )
721748
722749 if create_labeled_video :
723750 if filtered :
@@ -727,4 +754,6 @@ def _process_videos(
727754 deeplabcut .create_labeled_video (
728755 cfg_path , [video_dir ], video_type , draw_skeleton = True , filtered = filtered
729756 )
730- deeplabcut .plot_trajectories (cfg_path , [video_dir ], video_type , filtered = filtered )
757+ deeplabcut .plot_trajectories (
758+ cfg_path , [video_dir ], video_type , filtered = filtered
759+ )
0 commit comments