Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Run black
  • Loading branch information
maximpavliv committed Feb 25, 2025
commit a67974bcbe4a75dd67b8af1c2e95fdbdb2d1ba25
73 changes: 51 additions & 22 deletions deeplabcut/create_project/modelzoo.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,19 @@
import deeplabcut
from deeplabcut import Engine
from deeplabcut.core.config import read_config_as_dict, write_config
from deeplabcut.generate_training_dataset.metadata import TrainingDatasetMetadata, ShuffleMetadata, DataSplit
from deeplabcut.generate_training_dataset.trainingsetmanipulation import MakeInference_yaml
from deeplabcut.generate_training_dataset.metadata import (
TrainingDatasetMetadata,
ShuffleMetadata,
DataSplit,
)
from deeplabcut.generate_training_dataset.trainingsetmanipulation import (
MakeInference_yaml,
)
from deeplabcut.modelzoo.utils import get_super_animal_project_cfg
from deeplabcut.pose_estimation_pytorch.config.make_pose_config import add_metadata, make_pytorch_test_config
from deeplabcut.pose_estimation_pytorch.config.make_pose_config import (
add_metadata,
make_pytorch_test_config,
)
from deeplabcut.pose_estimation_pytorch.modelzoo.utils import load_super_animal_config
from deeplabcut.utils import auxiliaryfunctions

Expand Down Expand Up @@ -344,13 +353,19 @@ def create_pretrained_project_pytorch(
detector_name = "fasterrcnn_resnet50_fpn_v2"

if dataset not in get_available_datasets():
raise ValueError(f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}")
raise ValueError(
f"Invalid dataset '{dataset}'. Available datasets are: {get_available_datasets()}"
)

if net_name not in get_available_models(dataset):
raise ValueError(f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}")
raise ValueError(
f"Invalid net_name '{net_name}' for dataset {dataset}. The following net types are available: {get_available_models(dataset)}"
)

if detector_name not in get_available_detectors(dataset):
raise ValueError(f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}")
raise ValueError(
f"Invalid detector_name '{detector_name}' for dataset {dataset}. The following detectors are available: {get_available_detectors(dataset)}"
)

# Create project
cfg_path = deeplabcut.create_new_project(
Expand Down Expand Up @@ -380,7 +395,12 @@ def create_pretrained_project_pytorch(

# Create the shuffle train and test directories
config = read_config_as_dict(cfg_path)
shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder(trainFraction=config["TrainingFraction"][0], shuffle=1, cfg=config, engine=Engine.PYTORCH)
shuffle_dir = Path(cfg_path).parent / auxiliaryfunctions.get_model_folder(
trainFraction=config["TrainingFraction"][0],
shuffle=1,
cfg=config,
engine=Engine.PYTORCH,
)
train_dir = shuffle_dir / "train"
test_dir = shuffle_dir / "test"
train_dir.mkdir(parents=True, exist_ok=True)
Expand All @@ -393,14 +413,14 @@ def create_pretrained_project_pytorch(
download_huggingface_model(
model_name=super_animal_detector_name,
target_dir=str(train_dir),
rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name}
rename_mapping={f"{super_animal_detector_name}.pt": new_detector_name},
)
super_animal_model_name = f"{dataset}_{net_name}"
new_snapshot_name = "snapshot-000.pt"
download_huggingface_model(
model_name=super_animal_model_name,
target_dir=str(train_dir),
rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name}
rename_mapping={f"{super_animal_model_name}.pt": new_snapshot_name},
)

# Create pytorch_config.yaml
Expand All @@ -412,12 +432,16 @@ def create_pretrained_project_pytorch(
)
pytorch_config = add_metadata(config, pytorch_config, train_cfg_path)
pytorch_config["resume_training_from"] = str(train_dir / new_snapshot_name)
pytorch_config["detector"]["resume_training_from"] = str(train_dir / new_detector_name)
pytorch_config["detector"]["resume_training_from"] = str(
train_dir / new_detector_name
)
write_config(train_cfg_path, pytorch_config)

# Create test pose_cfg.yaml
test_cfg_path = test_dir / "pose_cfg.yaml"
make_pytorch_test_config(model_config=pytorch_config, test_config_path=test_cfg_path, save=True)
make_pytorch_test_config(
model_config=pytorch_config, test_config_path=test_cfg_path, save=True
)

# Create inference_cfg.yaml if needed
if multi_animal:
Expand All @@ -438,25 +462,23 @@ def create_pretrained_project_pytorch(
return cfg_path, str(train_cfg_path)


def _create_inference_config(inference_cfg_path: str|Path, project_cfg: dict):
def _create_inference_config(inference_cfg_path: str | Path, project_cfg: dict):
inf_updates = dict(
minimalnumberofconnections=int(len(project_cfg["multianimalbodyparts"]) / 2),
topktoretain=len(project_cfg["individuals"]),
withid=project_cfg.get("identity", False),
)
default_inf_path = Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
MakeInference_yaml(
inf_updates,
inference_cfg_path,
default_inf_path
default_inf_path = (
Path(auxiliaryfunctions.get_deeplabcut_path()) / "inference_cfg.yaml"
)
MakeInference_yaml(inf_updates, inference_cfg_path, default_inf_path)


def create_pretrained_project_tensorflow(
project: str,
experimenter: str,
videos: list[str],
model: str|None = None,
model: str | None = None,
working_directory: str | None = None,
copy_videos: bool = False,
videotype: str = "",
Expand Down Expand Up @@ -683,7 +705,10 @@ def create_pretrained_project_tensorflow(
else:
return "N/A", "N/A"

def _create_training_datasets_metadata(config: dict, shuffle_dir_name: str, engine: Engine):

def _create_training_datasets_metadata(
config: dict, shuffle_dir_name: str, engine: Engine
):
# First create the metadata object
metadata = TrainingDatasetMetadata.create(config)

Expand All @@ -693,7 +718,7 @@ def _create_training_datasets_metadata(config: dict, shuffle_dir_name: str, engi
train_fraction=config["TrainingFraction"][0],
index=1,
engine=engine,
split=DataSplit(train_indices=(),test_indices=())
split=DataSplit(train_indices=(), test_indices=()),
)

# Add the shuffle to metadata
Expand All @@ -717,7 +742,9 @@ def _process_videos(

if analyze_video:
print("Analyzing video...")
deeplabcut.analyze_videos(cfg_path, [video_dir], videotype=video_type, save_as_csv=True)
deeplabcut.analyze_videos(
cfg_path, [video_dir], videotype=video_type, save_as_csv=True
)

if create_labeled_video:
if filtered:
Expand All @@ -727,4 +754,6 @@ def _process_videos(
deeplabcut.create_labeled_video(
cfg_path, [video_dir], video_type, draw_skeleton=True, filtered=filtered
)
deeplabcut.plot_trajectories(cfg_path, [video_dir], video_type, filtered=filtered)
deeplabcut.plot_trajectories(
cfg_path, [video_dir], video_type, filtered=filtered
)
8 changes: 6 additions & 2 deletions deeplabcut/create_project/new.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@ def create_new_project(
copy_videos: bool = False,
videotype: str = "",
multianimal: bool = False,
individuals: list[str]|None = None,
individuals: list[str] | None = None,
):
r"""Create the necessary folders and files for a new project.

Expand Down Expand Up @@ -247,7 +247,11 @@ def create_new_project(
cfg_file, ruamelFile = auxiliaryfunctions.create_config_template(multianimal)
cfg_file["multianimalproject"] = multianimal
cfg_file["identity"] = False
cfg_file["individuals"] = individuals if individuals else ["individual1", "individual2", "individual3"]
cfg_file["individuals"] = (
individuals
if individuals
else ["individual1", "individual2", "individual3"]
)
cfg_file["multianimalbodyparts"] = ["bodypart1", "bodypart2", "bodypart3"]
cfg_file["uniquebodyparts"] = []
cfg_file["bodyparts"] = "MULTI!"
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -37,4 +37,4 @@
for directory in dirs_to_delete:
shutil.rmtree(directory)

print("Test passed!")
print("Test passed!")