Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Next Next commit
set weights_only parameter when loading model weights
  • Loading branch information
n-poulsen committed Dec 18, 2024
commit 286d8e5827eea152ff8212e113ef70f13de238e5
32 changes: 28 additions & 4 deletions deeplabcut/pose_estimation_pytorch/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@
#
from __future__ import annotations

import logging
import pickle
from abc import ABC
from pathlib import Path
from typing import Generic, TypeVar
Expand Down Expand Up @@ -63,20 +65,42 @@ def load_snapshot(
snapshot_path: str | Path,
device: str,
model: ModelType,
weights_only: bool = True,
) -> dict:
"""Loads the state dict for a model from a file

This method loads a file containing a DeepLabCut PyTorch model snapshot onto
a given device, and sets the model weights using the state_dict.

Args:
snapshot_path: the path containing the model weights to load
device: the device on which the model should be loaded
model: the model for which the weights are loaded
snapshot_path: The path containing the model weights to load
device: The device on which the model should be loaded
model: The model for which the weights are loaded
weights_only: Value for torch.load() `weights_only` parameter. If False, the
python pickle module is used implicitly, which is known to be insecure.
Only set to False if you're loading data that you trust (e.g. snapshots
that you created yourself). For more information, see:
https://pytorch.org/docs/stable/generated/torch.load.html

Returns:
The content of the snapshot file.
"""
snapshot = torch.load(snapshot_path, map_location=device)
try:
snapshot = torch.load(
snapshot_path, map_location=device, weights_only=weights_only,
)
except pickle.UnpicklingError as err:
print(
f"\nFailed to load the snapshot: {snapshot_path}.\n"
"If you trust the snapshot that you're trying to load, you can try "
"calling `Runner.load_snapshot` with `weights_only=False`. See "
"the message below for more information and warnings.\n"
"You can set the `weights_only` parameter in the model configuration ("
"the content of the pytorch_config.yaml), as:\n```\n"
"runner:\n"
" load_weights_only: False\n```\n"
)
raise err

model.load_state_dict(snapshot["model"])
return snapshot
33 changes: 26 additions & 7 deletions deeplabcut/pose_estimation_pytorch/runners/inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,14 +41,21 @@ def __init__(
snapshot_path: str | Path | None = None,
preprocessor: Preprocessor | None = None,
postprocessor: Postprocessor | None = None,
load_weights_only: bool = True,
):
"""
Args:
model: the model to run actions on
device: the device to use (e.g. {'cpu', 'cuda:0', 'mps'})
snapshot_path: if defined, the path of a snapshot from which to load pretrained weights
preprocessor: the preprocessor to use on images before inference
postprocessor: the postprocessor to use on images after inference
model: The model to run actions on
device: The device to use (e.g. {'cpu', 'cuda:0', 'mps'})
snapshot_path: If defined, the path of a snapshot from which to load
pretrained weights
preprocessor: The preprocessor to use on images before inference
postprocessor: The postprocessor to use on images after inference
load_weights_only: Value for the torch.load() `weights_only` parameter. If
False, the python pickle module is used implicitly, which is known to be
insecure. Only set to False if you're loading data that you trust (e.g.
snapshots that you created yourself). For more information, see:
https://pytorch.org/docs/stable/generated/torch.load.html
"""
super().__init__(model=model, device=device, snapshot_path=snapshot_path)
if not isinstance(batch_size, int) or batch_size <= 0:
Expand All @@ -59,7 +66,12 @@ def __init__(
self.postprocessor = postprocessor

if self.snapshot_path is not None and self.snapshot_path != "":
self.load_snapshot(self.snapshot_path, self.device, self.model)
self.load_snapshot(
self.snapshot_path,
self.device,
self.model,
weights_only=load_weights_only,
)

self._batch: torch.Tensor | None = None
self._contexts: list[dict] = []
Expand Down Expand Up @@ -293,6 +305,7 @@ def build_inference_runner(
batch_size: int = 1,
preprocessor: Preprocessor | None = None,
postprocessor: Postprocessor | None = None,
load_weights_only: bool = True,
) -> InferenceRunner:
"""
Build a runner object according to a pytorch configuration file
Expand All @@ -305,9 +318,14 @@ def build_inference_runner(
batch_size: the batch size to use to run inference
preprocessor: the preprocessor to use on images before inference
postprocessor: the postprocessor to use on images after inference
load_weights_only: Value for the torch.load() `weights_only` parameter. If
False, the python pickle module is used implicitly, which is known to be
insecure. Only set to False if you're loading data that you trust (e.g.
snapshots that you created yourself). For more information, see:
https://pytorch.org/docs/stable/generated/torch.load.html

Returns:
the inference runner
The inference runner.
"""
kwargs = dict(
model=model,
Expand All @@ -316,6 +334,7 @@ def build_inference_runner(
batch_size=batch_size,
preprocessor=preprocessor,
postprocessor=postprocessor,
load_weights_only=load_weights_only,
)
if task == Task.DETECT:
return DetectorInferenceRunner(**kwargs)
Expand Down
19 changes: 15 additions & 4 deletions deeplabcut/pose_estimation_pytorch/runners/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,12 @@ class TrainingRunner(Runner, Generic[ModelType], metaclass=ABCMeta):
might be used.
logger: Logger to monitor training (e.g. a WandBLogger).
log_filename: Name of the file in which to store training stats.
load_weights_only: Value for the torch.load() `weights_only` parameter if
`snapshot_path` is not None. If False, the python pickle module is used
implicitly, which is known to be insecure. Only set to False if you're
loading data that you trust (e.g. snapshots that you created yourself). For
more information, see:
https://pytorch.org/docs/stable/generated/torch.load.html
"""

def __init__(
Expand All @@ -78,6 +84,7 @@ def __init__(
load_scheduler_state_dict: bool = True,
logger: BaseLogger | None = None,
log_filename: str = "learning_stats.csv",
load_weights_only: bool = True,
):
super().__init__(
model=model, device=device, gpus=gpus, snapshot_path=snapshot_path
Expand All @@ -104,7 +111,12 @@ def __init__(
self._print_valid_loss = True

if self.snapshot_path:
snapshot = self.load_snapshot(self.snapshot_path, self.device, self.model)
snapshot = self.load_snapshot(
self.snapshot_path,
self.device,
self.model,
weights_only=load_weights_only,
)
self.starting_epoch = snapshot.get("metadata", {}).get("epoch", 0)

if "optimizer" in snapshot:
Expand Down Expand Up @@ -639,9 +651,7 @@ def build_training_runner(
Returns:
the runner that was built
"""
optim_cfg = runner_config["optimizer"]
optim_cls = getattr(torch.optim, optim_cfg["type"])
optimizer = optim_cls(params=model.parameters(), **optim_cfg["params"])
optimizer = build_optimizer(model, runner_config["optimizer"])
scheduler = schedulers.build_scheduler(runner_config.get("scheduler"), optimizer)

# if no custom snapshot prefix is defined, use the default one
Expand All @@ -668,6 +678,7 @@ def build_training_runner(
scheduler=scheduler,
load_scheduler_state_dict=runner_config.get("load_scheduler_state_dict", True),
logger=logger,
load_weights_only=runner_config.get("load_weights_only", True),
)
if task == Task.DETECT:
return DetectorTrainingRunner(**kwargs)
Expand Down
6 changes: 5 additions & 1 deletion examples/openfield-Pranav-2018-10-30/config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,8 @@ identity:


# Project path (change when moving around)
project_path: WILL BE AUTOMATICALLY UPDATED BY DEMO CODE
project_path:
/Users/niels/Documents/upamathis/repos/DeepLabCut/examples/openfield-Pranav-2018-10-30


# Default DeepLabCut engine to use for shuffle creation (either pytorch or tensorflow)
Expand All @@ -25,6 +26,9 @@ bodyparts:
- tailbase


# Fraction of video to start/stop when extracting frames for labeling/refinement
start: 0
stop: 1
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
import dlclibrary
import pytest
import torch

from deeplabcut.pose_estimation_pytorch.modelzoo import get_super_animal_snapshot_path


@pytest.mark.skip(reason="require-models")
def test_load_superanimal_models_weights_only():
    """Check that every SuperAnimal snapshot loads with ``weights_only=True``.

    For each dataset reported by ``dlclibrary``, all detector snapshots are
    loaded first, followed by all pose model snapshots. The test contains no
    explicit assertion: it passes as long as ``torch.load`` does not raise.
    """
    # Detectors are listed (and loaded) before pose models for each dataset.
    model_listers = (
        dlclibrary.get_available_detectors,
        dlclibrary.get_available_models,
    )
    for super_animal in dlclibrary.get_available_datasets():
        print(f"\nTesting {super_animal}")
        for list_models in model_listers:
            for model_name in list_models(super_animal):
                print(super_animal, model_name)
                snapshot_file = get_super_animal_snapshot_path(super_animal, model_name)
                torch.load(snapshot_file, map_location="cpu", weights_only=True)
29 changes: 29 additions & 0 deletions tests/pose_estimation_pytorch/runners/test_runners.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
#
# DeepLabCut Toolbox (deeplabcut.org)
# © A. & M.W. Mathis Labs
# https://github.com/DeepLabCut/DeepLabCut
#
# Please see AUTHORS for contributors.
# https://github.com/DeepLabCut/DeepLabCut/blob/main/AUTHORS
#
# Licensed under GNU Lesser General Public License v3.0
#
import pickle
from unittest.mock import Mock

import numpy as np
import pytest
import torch

import deeplabcut.pose_estimation_pytorch.runners as runners


def test_load_snapshot_weights_only_error(tmpdir_factory):
    """Expect an unpickling error for a numpy payload with ``weights_only=True``.

    A snapshot file whose content is a dict holding a numpy array is written to
    a temporary directory. ``Runner.load_snapshot`` is then called on it with
    ``weights_only=True`` and must raise ``pickle.UnpicklingError``.

    Args:
        tmpdir_factory: The pytest fixture used to create the temporary
            directory in which the snapshot is written.
    """
    snapshot_path = tmpdir_factory.mktemp("snapshot-dir") / "snapshot.pt"
    payload = {"content": np.zeros(10)}
    torch.save(payload, snapshot_path)

    load_kwargs = dict(device="cpu", model=Mock(), weights_only=True)
    with pytest.raises(pickle.UnpicklingError):
        runners.Runner.load_snapshot(snapshot_path, **load_kwargs)
21 changes: 20 additions & 1 deletion tests/pose_estimation_pytorch/runners/test_runners_inference.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,7 +9,7 @@
# Licensed under GNU Lesser General Public License v3.0
#
"""Tests inference runners"""
from unittest.mock import Mock
from unittest.mock import Mock, patch

import numpy as np
import pytest
Expand All @@ -18,6 +18,25 @@
import deeplabcut.pose_estimation_pytorch.data.postprocessor as post
import deeplabcut.pose_estimation_pytorch.data.preprocessor as prep
import deeplabcut.pose_estimation_pytorch.runners.inference as inference
from deeplabcut.pose_estimation_pytorch.task import Task


@patch("deeplabcut.pose_estimation_pytorch.runners.train.build_optimizer", Mock())
@pytest.mark.parametrize("task", [Task.DETECT, Task.TOP_DOWN, Task.BOTTOM_UP])
@pytest.mark.parametrize("weights_only", [True, False])
def test_load_weights_only_with_build_training_runner(task: Task, weights_only: bool):
    """Check that ``load_weights_only`` is forwarded to ``torch.load``.

    NOTE: despite its name, this test exercises ``build_inference_runner``.

    ``torch.load`` (as referenced from the runners ``base`` module) is replaced
    with a mock, an inference runner is built with a snapshot path, and the mock
    must have been called exactly once with the requested ``weights_only`` value.

    Args:
        task: The task for which the inference runner is built.
        weights_only: The value passed as ``load_weights_only``.
    """
    snapshot_file = "snapshot.pt"
    torch_load_target = "deeplabcut.pose_estimation_pytorch.runners.base.torch.load"
    with patch(torch_load_target) as mocked_load:
        build_kwargs = dict(
            task=task,
            model=Mock(),
            device="cpu",
            snapshot_path=snapshot_file,
            load_weights_only=weights_only,
        )
        inference.build_inference_runner(**build_kwargs)
        mocked_load.assert_called_once_with(
            snapshot_file, map_location="cpu", weights_only=weights_only
        )


class MockInferenceRunner(inference.InferenceRunner):
Expand Down
25 changes: 25 additions & 0 deletions tests/pose_estimation_pytorch/runners/test_runners_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,31 @@

import deeplabcut.pose_estimation_pytorch.runners.schedulers as schedulers
import deeplabcut.pose_estimation_pytorch.runners.train as train_runners
from deeplabcut.pose_estimation_pytorch.task import Task


@patch("deeplabcut.pose_estimation_pytorch.runners.train.build_optimizer", Mock())
@patch("deeplabcut.pose_estimation_pytorch.runners.train.CSVLogger", Mock())
@pytest.mark.parametrize("task", [Task.DETECT, Task.TOP_DOWN, Task.BOTTOM_UP])
@pytest.mark.parametrize("weights_only", [True, False])
def test_load_weights_only_with_build_training_runner(task: Task, weights_only: bool):
    """Check that the ``load_weights_only`` runner option reaches ``torch.load``.

    A training runner is built from a runner configuration containing
    ``load_weights_only``, with ``torch.load`` (as referenced from the runners
    ``base`` module) replaced by a mock. The mock must have been called exactly
    once with the configured ``weights_only`` value.

    Args:
        task: The task for which the training runner is built.
        weights_only: The value stored under ``load_weights_only`` in the
            runner configuration.
    """
    snapshot_file = "snapshot.pt"
    runner_config = {
        "optimizer": {},
        "snapshots": {
            "max_snapshots": 1,
            "save_epochs": 5,
            "save_optimizer_state": False,
        },
        "load_weights_only": weights_only,
    }
    torch_load_target = "deeplabcut.pose_estimation_pytorch.runners.base.torch.load"
    with patch(torch_load_target) as mocked_load:
        build_kwargs = dict(
            runner_config=runner_config,
            model_folder=Mock(),
            task=task,
            model=Mock(),
            device="cpu",
            snapshot_path=snapshot_file,
        )
        train_runners.build_training_runner(**build_kwargs)
        mocked_load.assert_called_once_with(
            snapshot_file, map_location="cpu", weights_only=weights_only
        )


@dataclass
Expand Down