Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
improved tests; implemented iteration for export function.
  • Loading branch information
n-poulsen committed Dec 2, 2024
commit 874fc4d07d1ff01318aa6b28c954beb1bcdd2797
8 changes: 5 additions & 3 deletions deeplabcut/pose_estimation_pytorch/apis/export.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

import deeplabcut.pose_estimation_pytorch.apis.utils as utils
import deeplabcut.pose_estimation_pytorch.data as dlc3_data
import deeplabcut.utils.auxiliaryfunctions as af
from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
from deeplabcut.pose_estimation_pytorch.task import Task

Expand Down Expand Up @@ -70,11 +71,12 @@ def export_model(
>>> snapshotindex=-1,
>>> )
"""
cfg = af.read_config(str(config))
if iteration is not None:
raise ValueError(f"TODO(niels)")
cfg["iteration"] = iteration

loader = dlc3_data.DLCLoader(
config=Path(config),
config=cfg,
trainset_index=trainingsetindex,
shuffle=shuffle,
modelprefix=modelprefix,
Expand Down Expand Up @@ -145,7 +147,7 @@ def get_export_folder_name(loader: dlc3_data.DLCLoader) -> str:
The name of the folder in which exported models should be placed for a shuffle.
"""
return (
f"DLC_{loader.project_cfg['task']}_{loader.model_cfg['net_type']}_"
f"DLC_{loader.project_cfg['Task']}_{loader.model_cfg['net_type']}_"
f"iteration-{loader.project_cfg['iteration']}_shuffle-{loader.shuffle}"
)

Expand Down
13 changes: 9 additions & 4 deletions deeplabcut/pose_estimation_pytorch/data/dlcloader.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,20 +31,25 @@ class DLCLoader(Loader):

def __init__(
self,
config: str | Path,
config: str | Path | dict,
trainset_index: int = 0,
shuffle: int = 0,
modelprefix: str = "",
):
"""
Args:
config: path to the DeepLabCut project config
config: Path to the DeepLabCut project config, or the project config itself
trainset_index: the index of the TrainingsetFraction for which to load data
shuffle: the index of the shuffle for which to load data
modelprefix: the modelprefix for the shuffle
"""
self._project_root = Path(config).parent
self._project_config = af.read_config(str(config))
if isinstance(config, (str, Path)):
self._project_root = Path(config).parent
self._project_config = af.read_config(str(config))
else:
self._project_root = Path(config["project_path"])
self._project_config = config

self._shuffle = shuffle
self._trainset_index = trainset_index
self._train_frac = self._project_config["TrainingFraction"][trainset_index]
Expand Down
6 changes: 5 additions & 1 deletion examples/openfield-Pranav-2018-10-30/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ identity:


# Project path (change when moving around)
project_path: WILL BE AUTOMATICALLY UPDATED BY DEMO CODE


# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
Expand All @@ -25,6 +26,9 @@ bodyparts:
- tailbase


# Fraction of video to start/stop when extracting frames for labeling/refinement
start: 0
stop: 1
Expand Down
96 changes: 84 additions & 12 deletions tests/pose_estimation_pytorch/apis/test_apis_export.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
# Licensed under GNU Lesser General Public License v3.0
#
"""Tests exporting models"""
import copy
import shutil
from pathlib import Path
from unittest.mock import Mock, patch

Expand All @@ -21,9 +23,19 @@
from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot


@pytest.fixture()
def project_dir(tmp_path_factory) -> Path:
    """Yield a fresh temporary project directory, removed after the test.

    Yields:
        Path to an empty, test-scoped project directory.
    """
    project_dir = tmp_path_factory.mktemp("tmp-project")
    # FIX: original used an f-string with no placeholders (lint F541).
    print("\nTemporary project directory:")
    print(str(project_dir))
    print("---")
    yield project_dir
    # Remove eagerly rather than relying on pytest's tmp-dir retention policy.
    shutil.rmtree(str(project_dir))


def _mock_multianimal_project(project_dir: Path):
video_dir = project_dir / "videos"
video_dir.mkdir()
video_dir.mkdir(exist_ok=True)

cfg_file, yaml_file = af.create_config_template(multianimal=True)
cfg_file["Task"] = "mock"
Expand Down Expand Up @@ -56,7 +68,10 @@ def _make_mock_loader(
loader.shuffle = 0

loader.project_cfg = dict(
task=project_task,
project_path=str(project_path),
Task=project_task,
date="Jan12",
TrainingFraction=[0.95],
snapshotindex=default_snapshot_index,
detector_snapshotindex=default_detector_snapshot_index,
iteration=project_iteration,
Expand All @@ -76,11 +91,16 @@ def _make_mock_loader(
return loader


def _get_export_model_data(project_dir, num_snapshots, task):
def _get_export_model_data(
project_dir: Path,
num_snapshots: int,
task: Task,
project_iteration: int = 0,
):
_mock_multianimal_project(project_dir)

model_dir = Path(project_dir) / "fake-shuffle-0"
model_dir.mkdir(exist_ok=True)
model_dir = Path(project_dir) / f"iteration-{project_iteration}" / "fake-shuffle-0"
model_dir.mkdir(exist_ok=True, parents=True)
snapshots = []
snapshot_data = []
for i in range(num_snapshots):
Expand All @@ -105,7 +125,7 @@ def _get_export_model_data(project_dir, num_snapshots, task):
mock_loader = _make_mock_loader(
project_path=project_dir,
project_task="mock",
project_iteration=0,
project_iteration=project_iteration,
model_folder=model_dir,
net_type="fake-net",
pose_task=task,
Expand All @@ -128,13 +148,12 @@ def _get_export_model_data(project_dir, num_snapshots, task):
]
)
def test_export_model(
tmp_path_factory,
project_dir,
task: Task,
num_snapshots: int,
idx: int,
detector_idx: int | None,
):
project_dir = tmp_path_factory.mktemp("tmp-project")
test_data = _get_export_model_data(project_dir, num_snapshots, task)
mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data

Expand Down Expand Up @@ -179,8 +198,7 @@ def get_mock_loader(*args, **kwargs):

@patch("deeplabcut.pose_estimation_pytorch.apis.export.wipe_paths_from_model_config")
@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
def test_export_model_clear_paths(mock_wipe: Mock, tmp_path_factory, task: Task):
project_dir = tmp_path_factory.mktemp("tmp-project")
def test_export_model_clear_paths(mock_wipe: Mock, project_dir, task: Task):
test_data = _get_export_model_data(project_dir, 1, task)
mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data

Expand All @@ -199,8 +217,7 @@ def get_mock_loader(*args, **kwargs):

@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
@pytest.mark.parametrize("overwrite", [True, False])
def test_export_overwrite(tmp_path_factory, task: Task, overwrite: bool):
project_dir = tmp_path_factory.mktemp("tmp-project")
def test_export_overwrite(project_dir, task: Task, overwrite: bool):
test_data = _get_export_model_data(project_dir, 1, task)
mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
snapshot = snapshots[0]
Expand Down Expand Up @@ -232,3 +249,58 @@ def get_mock_loader(*args, **kwargs):
assert existing_data != exported_data
else:
assert existing_data == exported_data


@pytest.mark.parametrize("task", [Task.BOTTOM_UP, Task.TOP_DOWN])
@pytest.mark.parametrize("iteration", [5, 12])
def test_export_change_iteration(project_dir, task: Task, iteration: int):
    """Exporting with an explicit ``iteration`` must target that iteration.

    Mocks ``DLCLoader`` and ``af.read_config`` so ``export_model`` sees a
    project whose active iteration is 0, then requests an export for
    ``iteration``; the exported file must appear under the requested
    iteration's folder only.
    """
    # Loader matching the project's current iteration (0).
    test_data = _get_export_model_data(
        project_dir, 1, task, project_iteration=0,
    )
    mock_loader, snapshots, snapshot_data, detector_snapshots, detector_data = test_data
    snapshot = snapshots[0]
    detector = None if task == Task.BOTTOM_UP else detector_snapshots[0]

    # Loader matching the iteration the export is asked to target.
    loader_diff_iter = _get_export_model_data(
        project_dir, 1, task, project_iteration=iteration
    )[0]

    def get_mock_loader(config, *args, **kwargs):
        # Deep-copy so export_model can never mutate the reference loaders.
        # FIX: the original deep-copied a second time inside the ``if``,
        # immediately discarding the first copy (dead work).
        _loader = copy.deepcopy(mock_loader)
        if isinstance(config, dict):
            # export_model passes the (possibly iteration-patched) project
            # config as a dict; reflect it on the loader under test.
            _loader.project_cfg = config
        return _loader

    def read_mock_config(*args, **kwargs):
        # Deep copy so export_model can set cfg["iteration"] without
        # touching the shared mock config.
        return copy.deepcopy(mock_loader.project_cfg)

    # patch the DLCLoader but also read_config
    with patch(
        "deeplabcut.pose_estimation_pytorch.apis.export.dlc3_data.DLCLoader",
        get_mock_loader,
    ):
        with patch(
            "deeplabcut.pose_estimation_pytorch.apis.export.af.read_config",
            read_mock_config,
        ):
            # check no exports exist yet
            for loader in [mock_loader, loader_diff_iter]:
                dir_name = export.get_export_folder_name(loader)
                filename = export.get_export_filename(loader, snapshot, detector)
                assert not (
                    project_dir / "exported-models-pytorch" / dir_name / filename
                ).exists()

            # export data
            export.export_model(project_dir / "config.yaml", iteration=iteration)

            # check the export exists for the correct iteration
            for loader, file_should_exist in [
                (mock_loader, False), (loader_diff_iter, True)
            ]:
                dir_name = export.get_export_folder_name(loader)
                filename = export.get_export_filename(loader, snapshot, detector)
                expected = project_dir / "exported-models-pytorch" / dir_name / filename
                expected_exists = expected.exists()
                assert expected_exists == file_should_exist