Skip to content

Commit a67974b

Browse files
committed
Run black
1 parent 0fcabd3 commit a67974b

File tree

3 files changed

+58
-25
lines changed

3 files changed

+58
-25
lines changed

deeplabcut/create_project/modelzoo.py

Lines changed: 51 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -24,10 +24,19 @@
2424
import deeplabcut
2525
from deeplabcut import Engine
2626
from deeplabcut.core.config import read_config_as_dict, write_config
27-
from deeplabcut.generate_training_dataset.metadata import TrainingDatasetMetadata, ShuffleMetadata, DataSplit
28-
from deeplabcut.generate_training_dataset.trainingsetmanipulation import MakeInference_yaml
27+
from deeplabcut.generate_training_dataset.metadata import (
28+
TrainingDatasetMetadata,
29+
ShuffleMetadata,
30+
DataSplit,
31+
)
32+
from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
33+
MakeInference_yaml,
34+
)
2935
from deeplabcut.modelzoo.utils import get_super_animal_project_cfg
30-
from deeplabcut.pose_estimation_pytorch.config.make_pose_config import add_metadata, make_pytorch_test_config
36+
from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
37+
add_metadata,
38+
make_pytorch_test_config,
39+
)
3140
from deeplabcut.pose_estimation_pytorch.modelzoo.utils import load_super_animal_config
3241
from deeplabcut.utils import auxiliaryfunctions
3342

@@ -344,13 +353,19 @@ def create_pretrained_project_pytorch(
344353
detector_name = "fasterrcnn_resnet50_fpn_v2"
345354

346355
if dataset not in get_available_datasets():
347-
raise ValueError(f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}")
356+
raise ValueError(
357+
f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}"
358+
)
348359

349360
if net_name not in get_available_models(dataset):
350-
raise ValueError(f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}")
361+
raise ValueError(
362+
f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}"
363+
)
351364

352365
if detector_name not in get_available_detectors(dataset):
353-
raise ValueError(f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}")
366+
raise ValueError(
367+
f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}"
368+
)
354369

355370
# Create project
356371
cfg_path = deeplabcut.create_new_project(
@@ -380,7 +395,12 @@ def create_pretrained_project_pytorch(
380395

381396
# Create the shuffle train and test directories
382397
config = read_config_as_dict(cfg_path)
383-
shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder(trainFraction=config["TrainingFraction"][0], shuffle=1, cfg=config, engine=Engine.PYTORCH)
398+
shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder(
399+
trainFraction=config["TrainingFraction"][0],
400+
shuffle=1,
401+
cfg=config,
402+
engine=Engine.PYTORCH,
403+
)
384404
train_dir = shuffle_dir / "train"
385405
test_dir = shuffle_dir / "test"
386406
train_dir.mkdir(parents=True, exist_ok=True)
@@ -393,14 +413,14 @@ def create_pretrained_project_pytorch(
393413
download_huggingface_model(
394414
model_name=super_animal_detector_name,
395415
target_dir=str(train_dir),
396-
rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name}
416+
rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name},
397417
)
398418
super_animal_model_name = f"{dataset}_{net_name}"
399419
new_snapshot_name = "snapshot-000.pt"
400420
download_huggingface_model(
401421
model_name=super_animal_model_name,
402422
target_dir=str(train_dir),
403-
rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name}
423+
rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name},
404424
)
405425

406426
# Create pytorch_config.yaml
@@ -412,12 +432,16 @@ def create_pretrained_project_pytorch(
412432
)
413433
pytorch_config = add_metadata(config, pytorch_config, train_cfg_path)
414434
pytorch_config["resume_training_from"] = str(train_dir / new_snapshot_name)
415-
pytorch_config["detector"]["resume_training_from"] = str(train_dir / new_detector_name)
435+
pytorch_config["detector"]["resume_training_from"] = str(
436+
train_dir / new_detector_name
437+
)
416438
write_config(train_cfg_path, pytorch_config)
417439

418440
# Create test pose_cfg.yaml
419441
test_cfg_path = test_dir / "pose_cfg.yaml"
420-
make_pytorch_test_config(model_config=pytorch_config, test_config_path=test_cfg_path, save=True)
442+
make_pytorch_test_config(
443+
model_config=pytorch_config, test_config_path=test_cfg_path, save=True
444+
)
421445

422446
# Create inference_cfg.yaml if needed
423447
if multi_animal:
@@ -438,25 +462,23 @@ def create_pretrained_project_pytorch(
438462
return cfg_path, str(train_cfg_path)
439463

440464

441-
def _create_inference_config(inference_cfg_path: str|Path, project_cfg: dict):
465+
def _create_inference_config(inference_cfg_path: str | Path, project_cfg: dict):
442466
inf_updates = dict(
443467
minimalnumberofconnections=int(len(project_cfg["multianimalbodyparts"]) / 2),
444468
topktoretain=len(project_cfg["individuals"]),
445469
withid=project_cfg.get("identity", False),
446470
)
447-
default_inf_path = Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
448-
MakeInference_yaml(
449-
inf_updates,
450-
inference_cfg_path,
451-
default_inf_path
471+
default_inf_path = (
472+
Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
452473
)
474+
MakeInference_yaml(inf_updates, inference_cfg_path, default_inf_path)
453475

454476

455477
def create_pretrained_project_tensorflow(
456478
project: str,
457479
experimenter: str,
458480
videos: list[str],
459-
model: str|None = None,
481+
model: str | None = None,
460482
working_directory: str | None = None,
461483
copy_videos: bool = False,
462484
videotype: str = "",
@@ -683,7 +705,10 @@ def create_pretrained_project_tensorflow(
683705
else:
684706
return "N/A", "N/A"
685707

686-
def _create_training_datasets_metadata(config: dict, shuffle_dir_name: str, engine: Engine):
708+
709+
def _create_training_datasets_metadata(
710+
config: dict, shuffle_dir_name: str, engine: Engine
711+
):
687712
# First create the metadata object
688713
metadata = TrainingDatasetMetadata.create(config)
689714

@@ -693,7 +718,7 @@ def _create_training_datasets_metadata(config: dict, shuffle_dir_name: str, engi
693718
train_fraction=config["TrainingFraction"][0],
694719
index=1,
695720
engine=engine,
696-
split=DataSplit(train_indices=(),test_indices=())
721+
split=DataSplit(train_indices=(), test_indices=()),
697722
)
698723

699724
# Add the shuffle to metadata
@@ -717,7 +742,9 @@ def _process_videos(
717742

718743
if analyze_video:
719744
print("Analyzing video...")
720-
deeplabcut.analyze_videos(cfg_path, [video_dir], videotype=video_type, save_as_csv=True)
745+
deeplabcut.analyze_videos(
746+
cfg_path, [video_dir], videotype=video_type, save_as_csv=True
747+
)
721748

722749
if create_labeled_video:
723750
if filtered:
@@ -727,4 +754,6 @@ def _process_videos(
727754
deeplabcut.create_labeled_video(
728755
cfg_path, [video_dir], video_type, draw_skeleton=True, filtered=filtered
729756
)
730-
deeplabcut.plot_trajectories(cfg_path, [video_dir], video_type, filtered=filtered)
757+
deeplabcut.plot_trajectories(
758+
cfg_path, [video_dir], video_type, filtered=filtered
759+
)

deeplabcut/create_project/new.py

Lines changed: 6 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -26,7 +26,7 @@ def create_new_project(
2626
copy_videos: bool = False,
2727
videotype: str = "",
2828
multianimal: bool = False,
29-
individuals: list[str]|None = None,
29+
individuals: list[str] | None = None,
3030
):
3131
r"""Create the necessary folders and files for a new project.
3232
@@ -247,7 +247,11 @@ def create_new_project(
247247
cfg_file, ruamelFile = auxiliaryfunctions.create_config_template(multianimal)
248248
cfg_file["multianimalproject"] = multianimal
249249
cfg_file["identity"] = False
250-
cfg_file["individuals"] = individuals if individuals else ["individual1", "individual2", "individual3"]
250+
cfg_file["individuals"] = (
251+
individuals
252+
if individuals
253+
else ["individual1", "individual2", "individual3"]
254+
)
251255
cfg_file["multianimalbodyparts"] = ["bodypart1", "bodypart2", "bodypart3"]
252256
cfg_file["uniquebodyparts"] = []
253257
cfg_file["bodyparts"] = "MULTI!"

examples/testscript_superanimal_create_pretrained_project.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -37,4 +37,4 @@
3737
for directory in dirs_to_delete:
3838
shutil.rmtree(directory)
3939

40-
print("Test passed!")
40+
print("Test passed!")

0 commit comments

Comments
 (0)