Skip to content

Commit 5b6c8b7

Browse files
authored
DLC 3.0 - Function to export PyTorch models (#2800)
* first implementation of pytorch model export * use loader pose_task * updated userguide * added test; modularized code * improve tests * improved tests; implemented iteration for export function. * bug fix * ux improvements * fix typing * improved docs * more general model prefix * lazy-load train/test split * isort + black * isort + black
1 parent d940483 commit 5b6c8b7

File tree

6 files changed

+599
-28
lines changed

6 files changed

+599
-28
lines changed

deeplabcut/compat.py

Lines changed: 37 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -53,7 +53,7 @@ def get_available_aug_methods(engine: Engine) -> tuple[str, ...]:
5353
if engine == Engine.TF:
5454
return "imgaug", "default", "deterministic", "scalecrop", "tensorpack"
5555
elif engine == Engine.PYTORCH:
56-
return ("albumentations", )
56+
return ("albumentations",)
5757

5858
raise RuntimeError(f"Unknown augmentation for engine: {engine}")
5959

@@ -218,6 +218,7 @@ def train_network(
218218

219219
if engine == Engine.TF:
220220
from deeplabcut.pose_estimation_tensorflow import train_network
221+
221222
if max_snapshots_to_keep is None:
222223
max_snapshots_to_keep = 5
223224

@@ -239,6 +240,7 @@ def train_network(
239240
)
240241
elif engine == Engine.PYTORCH:
241242
from deeplabcut.pose_estimation_pytorch.apis import train_network
243+
242244
_update_device(gputouse, torch_kwargs)
243245
if "display_iters" not in torch_kwargs:
244246
torch_kwargs["display_iters"] = displayiters
@@ -299,14 +301,18 @@ def return_train_network_path(
299301

300302
if engine == Engine.TF:
301303
from deeplabcut.pose_estimation_tensorflow import return_train_network_path
304+
302305
return return_train_network_path(
303306
config,
304307
shuffle=shuffle,
305308
trainingsetindex=trainingsetindex,
306309
modelprefix=modelprefix,
307310
)
308311
elif engine == Engine.PYTORCH:
309-
from deeplabcut.pose_estimation_pytorch.apis.utils import return_train_network_path
312+
from deeplabcut.pose_estimation_pytorch.apis.utils import (
313+
return_train_network_path,
314+
)
315+
310316
return return_train_network_path(
311317
config,
312318
shuffle=shuffle,
@@ -458,6 +464,7 @@ def evaluate_network(
458464

459465
if engine == Engine.TF:
460466
from deeplabcut.pose_estimation_tensorflow import evaluate_network
467+
461468
return evaluate_network(
462469
str(config),
463470
Shuffles=Shuffles,
@@ -473,6 +480,7 @@ def evaluate_network(
473480
)
474481
elif engine == Engine.PYTORCH:
475482
from deeplabcut.pose_estimation_pytorch.apis import evaluate_network
483+
476484
_update_device(gputouse, torch_kwargs)
477485
return evaluate_network(
478486
config,
@@ -553,6 +561,7 @@ def return_evaluate_network_data(
553561

554562
if engine == Engine.TF:
555563
from deeplabcut.pose_estimation_tensorflow import return_evaluate_network_data
564+
556565
return return_evaluate_network_data(
557566
config,
558567
shuffle=shuffle,
@@ -816,6 +825,7 @@ def analyze_videos(
816825

817826
if engine == Engine.TF:
818827
from deeplabcut.pose_estimation_tensorflow import analyze_videos
828+
819829
kwargs = {}
820830
if use_openvino is not None: # otherwise default comes from tensorflow API
821831
kwargs["use_openvino"] = use_openvino
@@ -846,6 +856,7 @@ def analyze_videos(
846856
)
847857
elif engine == Engine.PYTORCH:
848858
from deeplabcut.pose_estimation_pytorch.apis import analyze_videos
859+
849860
_update_device(gputouse, torch_kwargs)
850861

851862
if batchsize is not None:
@@ -907,6 +918,7 @@ def create_tracking_dataset(
907918

908919
if engine == Engine.TF:
909920
from deeplabcut.pose_estimation_tensorflow import create_tracking_dataset
921+
910922
return create_tracking_dataset(
911923
config,
912924
videos,
@@ -994,6 +1006,7 @@ def analyze_time_lapse_frames(
9941006

9951007
if engine == Engine.TF:
9961008
from deeplabcut.pose_estimation_tensorflow import analyze_time_lapse_frames
1009+
9971010
return analyze_time_lapse_frames(
9981011
config,
9991012
directory,
@@ -1120,6 +1133,7 @@ def convert_detections2tracklets(
11201133

11211134
if engine == Engine.TF:
11221135
from deeplabcut.pose_estimation_tensorflow import convert_detections2tracklets
1136+
11231137
return convert_detections2tracklets(
11241138
config,
11251139
videos,
@@ -1233,6 +1247,7 @@ def extract_maps(
12331247

12341248
if engine == Engine.TF:
12351249
from deeplabcut.pose_estimation_tensorflow import extract_maps
1250+
12361251
return extract_maps(
12371252
config,
12381253
shuffle=shuffle,
@@ -1244,6 +1259,7 @@ def extract_maps(
12441259
)
12451260
elif engine == Engine.PYTORCH:
12461261
from deeplabcut.pose_estimation_pytorch import extract_maps
1262+
12471263
return extract_maps(
12481264
config,
12491265
shuffle=shuffle,
@@ -1404,6 +1420,7 @@ def extract_save_all_maps(
14041420

14051421
if engine == Engine.TF:
14061422
from deeplabcut.pose_estimation_tensorflow import extract_save_all_maps
1423+
14071424
return extract_save_all_maps(
14081425
config,
14091426
shuffle=shuffle,
@@ -1419,6 +1436,7 @@ def extract_save_all_maps(
14191436
)
14201437
elif engine == Engine.PYTORCH:
14211438
from deeplabcut.pose_estimation_pytorch import extract_save_all_maps
1439+
14221440
return extract_save_all_maps(
14231441
config,
14241442
shuffle=shuffle,
@@ -1450,14 +1468,12 @@ def export_model(
14501468
wipepaths: bool = False,
14511469
modelprefix: str = "",
14521470
engine: Engine | None = None,
1453-
):
1471+
) -> None:
14541472
"""Export DeepLabCut models for the model zoo or for live inference.
14551473
14561474
Saves the pose configuration, snapshot files, and frozen TF graph of the model to
1457-
directory named exported-models within the project directory
1458-
1459-
This function is only implemented for tensorflow models/shuffles, and will throw
1460-
an error if called with a PyTorch shuffle.
1475+
directory named exported-models within the project directory (and an
1476+
`exported-models-pytorch` directory for PyTorch models).
14611477
14621478
Parameters
14631479
-----------
@@ -1514,6 +1530,7 @@ def export_model(
15141530

15151531
if engine == Engine.TF:
15161532
from deeplabcut.pose_estimation_tensorflow import export_model
1533+
15171534
return export_model(
15181535
cfg_path=cfg_path,
15191536
shuffle=shuffle,
@@ -1526,6 +1543,19 @@ def export_model(
15261543
wipepaths=wipepaths,
15271544
modelprefix=modelprefix,
15281545
)
1546+
elif engine == Engine.PYTORCH:
1547+
from deeplabcut.pose_estimation_pytorch.apis.export import export_model
1548+
1549+
return export_model(
1550+
config=cfg_path,
1551+
shuffle=shuffle,
1552+
trainingsetindex=trainingsetindex,
1553+
snapshotindex=snapshotindex,
1554+
iteration=iteration,
1555+
overwrite=overwrite,
1556+
wipe_paths=wipepaths,
1557+
modelprefix=modelprefix,
1558+
)
15291559

15301560
raise NotImplementedError(f"This function is not implemented for {engine}")
15311561

deeplabcut/pose_estimation_pytorch/apis/__init__.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818
convert_detections2tracklets,
1919
)
2020
from deeplabcut.pose_estimation_pytorch.apis.evaluate import evaluate_network
21+
from deeplabcut.pose_estimation_pytorch.apis.export import export_model
2122
from deeplabcut.pose_estimation_pytorch.apis.train import train_network
2223
from deeplabcut.pose_estimation_pytorch.apis.visualization import (
2324
extract_maps,
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
#
2+
# DeepLabCut Toolbox (deeplabcut.org)
3+
# © A. & M.W. Mathis Labs
4+
# https://github.com/DeepLabCut/DeepLabCut
5+
#
6+
# Please see AUTHORS for contributors.
7+
# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
8+
#
9+
# Licensed under GNU Lesser General Public License v3.0
10+
#
11+
"""Code to export DeepLabCut models for DLCLive inference"""
12+
import copy
13+
from pathlib import Path
14+
15+
import torch
16+
17+
import deeplabcut.pose_estimation_pytorch.apis.utils as utils
18+
import deeplabcut.pose_estimation_pytorch.data as dlc3_data
19+
import deeplabcut.utils.auxiliaryfunctions as af
20+
from deeplabcut.pose_estimation_pytorch.runners.snapshots import Snapshot
21+
from deeplabcut.pose_estimation_pytorch.task import Task
22+
23+
24+
def export_model(
25+
config: str | Path,
26+
shuffle: int = 1,
27+
trainingsetindex: int = 0,
28+
snapshotindex: int | None = None,
29+
detector_snapshot_index: int | None = None,
30+
iteration: int | None = None,
31+
overwrite: bool = False,
32+
wipe_paths: bool = False,
33+
modelprefix: str | None = None,
34+
) -> None:
35+
"""Export DeepLabCut models for live inference.
36+
37+
Saves the pytorch_config.yaml configuration, snapshot files, of the model to a
38+
directory named exported-models-pytorch within the project directory.
39+
40+
Args:
41+
config: Path of the project configuration file
42+
shuffle : The shuffle of the model to export.
43+
trainingsetindex: The index of the training fraction for the model you wish to
44+
export.
45+
snapshotindex: The snapshot index for the weights you wish to export. If None,
46+
uses the snapshotindex as defined in ``config.yaml``.
47+
detector_snapshot_index: Only for TD models. If defined, uses the detector with
48+
the given index for pose estimation. If None, uses the snapshotindex as
49+
defined in the project ``config.yaml``.
50+
iteration: The project iteration (active learning loop) you wish to export. If
51+
None, the iteration listed in the project config file is used.
52+
overwrite : bool, optional
53+
If the model you wish to export has already been exported, whether to
54+
overwrite. default = False
55+
wipe_paths : bool, optional
56+
Removes the actual path of your project and the init_weights from the
57+
``pytorch_config.yaml``.
58+
modelprefix: Directory containing the deeplabcut models to use when evaluating
59+
the network. By default, the models are assumed to exist in the project
60+
folder.
61+
62+
Raises:
63+
ValueError: If no snapshots could be found for the shuffle.
64+
ValueError: If a top-down model is exported but no detector snapshots are found.
65+
66+
Examples:
67+
Export the last stored snapshot for model trained with shuffle 3:
68+
>>> import deeplabcut
69+
>>> deeplabcut.export_model(
70+
>>> "/analysis/project/reaching-task/config.yaml",
71+
>>> shuffle=3,
72+
>>> snapshotindex=-1,
73+
>>> )
74+
"""
75+
cfg = af.read_config(str(config))
76+
if iteration is not None:
77+
cfg["iteration"] = iteration
78+
79+
loader = dlc3_data.DLCLoader(
80+
config=cfg,
81+
trainset_index=trainingsetindex,
82+
shuffle=shuffle,
83+
modelprefix="" if modelprefix is None else modelprefix,
84+
)
85+
86+
if snapshotindex is None:
87+
snapshotindex = loader.project_cfg["snapshotindex"]
88+
snapshots = utils.get_model_snapshots(
89+
snapshotindex, loader.model_folder, loader.pose_task
90+
)
91+
92+
if len(snapshots) == 0:
93+
raise ValueError(
94+
f"Could not find any snapshots to export in ``{loader.model_folder}`` for "
95+
f"``snapshotindex={snapshotindex}``."
96+
)
97+
98+
detector_snapshots = [None]
99+
if loader.pose_task == Task.TOP_DOWN:
100+
if detector_snapshot_index is None:
101+
detector_snapshot_index = loader.project_cfg["detector_snapshotindex"]
102+
detector_snapshots = utils.get_model_snapshots(
103+
detector_snapshot_index, loader.model_folder, Task.DETECT
104+
)
105+
106+
if len(detector_snapshots) == 0:
107+
raise ValueError(
108+
"Attempting to export a top-down pose estimation model but no detector "
109+
f"snapshots were found in ``{loader.model_folder}`` for "
110+
f"``detector_snapshot_index={detector_snapshot_index}``. You must "
111+
f"export a detector snapshot with a top-down pose estimation model."
112+
)
113+
114+
export_folder_name = get_export_folder_name(loader)
115+
export_dir = loader.project_path / "exported-models-pytorch" / export_folder_name
116+
export_dir.mkdir(exist_ok=True, parents=True)
117+
118+
load_kwargs = dict(map_location="cpu", weights_only=True)
119+
for det_snapshot in detector_snapshots:
120+
detector_weights = None
121+
if det_snapshot is not None:
122+
detector_weights = torch.load(det_snapshot.path, **load_kwargs)["model"]
123+
124+
for snapshot in snapshots:
125+
export_filename = get_export_filename(loader, snapshot, det_snapshot)
126+
export_path = export_dir / export_filename
127+
if export_path.exists() and not overwrite:
128+
continue
129+
130+
model_cfg = copy.deepcopy(loader.model_cfg)
131+
if wipe_paths:
132+
wipe_paths_from_model_config(model_cfg)
133+
134+
pose_weights = torch.load(snapshot.path, **load_kwargs)["model"]
135+
export_dict = dict(config=model_cfg, pose=pose_weights)
136+
if detector_weights is not None:
137+
export_dict["detector"] = detector_weights
138+
139+
torch.save(export_dict, export_path)
140+
141+
142+
def get_export_folder_name(loader: dlc3_data.DLCLoader) -> str:
143+
"""
144+
Args:
145+
loader: The loader for the shuffle for which we want to export models.
146+
147+
Returns:
148+
The name of the folder in which exported models should be placed for a shuffle.
149+
"""
150+
return (
151+
f"DLC_{loader.project_cfg['Task']}_{loader.model_cfg['net_type']}_"
152+
f"iteration-{loader.project_cfg['iteration']}_shuffle-{loader.shuffle}"
153+
)
154+
155+
156+
def get_export_filename(
157+
loader: dlc3_data.DLCLoader,
158+
snapshot: Snapshot,
159+
detector_snapshot: Snapshot | None = None,
160+
) -> str:
161+
"""
162+
Args:
163+
loader: The loader for the shuffle for which we want to export models.
164+
snapshot: The pose model snapshot to export.
165+
detector_snapshot: The detector snapshot to export, for top-down models.
166+
167+
Returns:
168+
The name of the file in which the exported model should be stored.
169+
"""
170+
export_filename = get_export_folder_name(loader)
171+
if detector_snapshot is not None:
172+
export_filename += "_snapshot-detector" + detector_snapshot.uid()
173+
export_filename += "_snapshot-" + snapshot.uid()
174+
return export_filename + ".pt"
175+
176+
177+
def wipe_paths_from_model_config(model_cfg: dict) -> None:
178+
"""
179+
Removes all paths from the contents of the ``pytorch_config`` file.
180+
181+
Args:
182+
model_cfg: The model configuration to wipe.
183+
"""
184+
model_cfg["metadata"]["project_path"] = ""
185+
model_cfg["metadata"]["pose_config_path"] = ""
186+
if "weight_init" in model_cfg["train_settings"]:
187+
model_cfg["train_settings"]["weight_init"] = None
188+
if "resume_training_from" in model_cfg:
189+
model_cfg["resume_training_from"] = None
190+
if "resume_training_from" in model_cfg.get("detector", {}):
191+
model_cfg["detector"]["resume_training_from"] = None

0 commit comments

Comments
 (0)