Skip to content
Merged
Show file tree
Hide file tree
Changes from 6 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 5 additions & 1 deletion deeplabcut/core/metrics/bbox.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
"""
from __future__ import annotations

from unittest.mock import Mock, patch

import numpy as np

try:
Expand All @@ -26,6 +28,8 @@
with_pycocotools = False


@patch("pycocotools.coco.print", Mock())
@patch("pycocotools.cocoeval.print", Mock())
def compute_bbox_metrics(
ground_truth: dict[str, dict],
detections: dict[str, dict],
Expand Down Expand Up @@ -150,6 +154,6 @@ def _get_metric(
if len(s[s > -1]) == 0:
mean_s = -1
else:
mean_s = np.mean(s[s > -1])
mean_s = 100 * np.mean(s[s > -1])

return f"{metric_name}@{thresh}", mean_s
2 changes: 2 additions & 0 deletions deeplabcut/pose_estimation_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,4 +25,6 @@
PoseDatasetParameters,
)
from deeplabcut.pose_estimation_pytorch.data.dlcloader import DLCLoader
from deeplabcut.pose_estimation_pytorch.runners.snapshots import TorchSnapshotManager
from deeplabcut.pose_estimation_pytorch.task import Task
from deeplabcut.pose_estimation_pytorch.utils import fix_seeds
9 changes: 6 additions & 3 deletions deeplabcut/pose_estimation_pytorch/apis/train.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,7 +49,7 @@ def train(
device: str | None = "cpu",
gpus: list[int] | None = None,
logger_config: dict | None = None,
snapshot_path: str | None = None,
snapshot_path: str | Path | None = None,
transform: A.BaseCompose | None = None,
inference_transform: A.BaseCompose | None = None,
max_snapshots_to_keep: int | None = None,
Expand Down Expand Up @@ -111,6 +111,9 @@ def train(
if device == "mps" and task == Task.DETECT:
device = "cpu" # FIXME: Cannot train detectors on MPS

if snapshot_path is None:
snapshot_path = run_config.get("resume_training_from")

model.to(device) # Move model before giving its parameters to the optimizer
runner = build_training_runner(
runner_config=run_config["runner"],
Expand Down Expand Up @@ -197,8 +200,8 @@ def train_network(
trainingsetindex: int = 0,
modelprefix: str = "",
device: str | None = None,
snapshot_path: str | None = None,
detector_path: str | None = None,
snapshot_path: str | Path | None = None,
detector_path: str | Path | None = None,
batch_size: int | None = None,
epochs: int | None = None,
save_epochs: int | None = None,
Expand Down
4 changes: 2 additions & 2 deletions deeplabcut/pose_estimation_pytorch/apis/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,9 +163,9 @@ def get_model_snapshots(
raise ValueError(f"No best snapshot found in {model_folder}")
snapshots = [best_snapshot]
elif isinstance(index, str) and index.lower() == "all":
snapshots = snapshot_manager.snapshots(include_best=True)
snapshots = snapshot_manager.snapshots()
elif isinstance(index, int):
all_snapshots = snapshot_manager.snapshots(include_best=True)
all_snapshots = snapshot_manager.snapshots()
if (
len(all_snapshots) == 0
or len(all_snapshots) <= index
Expand Down
45 changes: 45 additions & 0 deletions deeplabcut/pose_estimation_pytorch/config/base/base_detector.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,45 @@
data:
colormode: RGB
inference:
normalize_images: true
train:
affine:
p: 0.5
rotation: 30
scaling: [ 1.0, 1.0 ]
translation: 40
collate:
type: ResizeFromDataSizeCollate
min_scale: 0.4
max_scale: 1.0
min_short_side: 128
max_short_side: 1152
multiple_of: 32
to_square: false
hflip: true
normalize_images: true
device: auto
runner:
type: DetectorTrainingRunner
key_metric: "test.mAP@50:95"
key_metric_asc: true
eval_interval: 10
optimizer:
type: AdamW
params:
lr: 1e-4
scheduler:
type: LRListScheduler
params:
milestones: [ 160 ]
lr_list: [ [ 1e-5 ] ]
snapshots:
max_snapshots: 5
save_epochs: 25
save_optimizer_state: false
train_settings:
batch_size: 1
dataloader_workers: 0
dataloader_pin_memory: true
display_iters: 500
epochs: 250
49 changes: 0 additions & 49 deletions deeplabcut/pose_estimation_pytorch/config/base/detector.yaml

This file was deleted.

Original file line number Diff line number Diff line change
@@ -1,48 +1,5 @@
data:
colormode: RGB
inference:
normalize_images: true
train:
affine:
p: 0.5
rotation: 30
scaling: [ 1.0, 1.0 ]
translation: 40
collate:
type: ResizeFromDataSizeCollate
min_scale: 0.4
max_scale: 1.0
min_short_side: 128
max_short_side: 1152
multiple_of: 32
to_square: false
hflip: true
normalize_images: true
device: auto
model:
type: FasterRCNN
freeze_bn_stats: true
freeze_bn_weights: false
variant: fasterrcnn_mobilenet_v3_large_fpn
runner:
type: DetectorTrainingRunner
eval_interval: 1
optimizer:
type: AdamW
params:
lr: 1e-4
scheduler:
type: LRListScheduler
params:
milestones: [ 160 ]
lr_list: [ [ 1e-5 ] ]
snapshots:
max_snapshots: 5
save_epochs: 25
save_optimizer_state: false
train_settings:
batch_size: 1
dataloader_workers: 0
dataloader_pin_memory: false
display_iters: 500
epochs: 250
Original file line number Diff line number Diff line change
@@ -1,48 +1,5 @@
data:
colormode: RGB
inference:
normalize_images: true
train:
affine:
p: 0.5
rotation: 30
scaling: [ 1.0, 1.0 ]
translation: 40
collate:
type: ResizeFromDataSizeCollate
min_scale: 0.4
max_scale: 1.0
min_short_side: 128
max_short_side: 1152
multiple_of: 32
to_square: false
hflip: true
normalize_images: true
device: auto
model:
type: FasterRCNN
freeze_bn_stats: true
freeze_bn_weights: false
variant: fasterrcnn_resnet50_fpn_v2
runner:
type: DetectorTrainingRunner
eval_interval: 1
optimizer:
type: AdamW
params:
lr: 1e-4
scheduler:
type: LRListScheduler
params:
milestones: [ 160 ]
lr_list: [ [ 1e-5 ] ]
snapshots:
max_snapshots: 5
save_epochs: 25
save_optimizer_state: false
train_settings:
batch_size: 1
dataloader_workers: 0
dataloader_pin_memory: false
display_iters: 500
epochs: 250
41 changes: 0 additions & 41 deletions deeplabcut/pose_estimation_pytorch/config/detectors/ssdlite.yaml
Original file line number Diff line number Diff line change
@@ -1,47 +1,6 @@
data:
colormode: RGB
inference:
normalize_images: true
train:
affine:
p: 0.5
rotation: 30
scaling: [ 1.0, 1.0 ]
translation: 40
collate:
type: ResizeFromDataSizeCollate
min_scale: 0.4
max_scale: 1.0
min_short_side: 128
max_short_side: 1152
multiple_of: 32
to_square: false
hflip: true
normalize_images: true
device: auto
model:
type: SSDLite
freeze_bn_stats: true
freeze_bn_weights: false
runner:
type: DetectorTrainingRunner
eval_interval: 1
optimizer:
type: AdamW
params:
lr: 1e-4
scheduler:
type: LRListScheduler
params:
milestones: [ 160 ]
lr_list: [ [ 1e-5 ] ]
snapshots:
max_snapshots: 5
save_epochs: 25
save_optimizer_state: false
train_settings:
batch_size: 16
dataloader_workers: 0
dataloader_pin_memory: false
display_iters: 500
epochs: 250
10 changes: 5 additions & 5 deletions deeplabcut/pose_estimation_pytorch/config/make_pose_config.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,14 +363,14 @@ def add_detector(

detector_type = detector_type.lower()
config = copy.deepcopy(config)
detector_config = read_config_as_dict(
configs_dir / "detectors" / f"{detector_type}.yaml"
detector_config = update_config(
read_config_as_dict(configs_dir / "base" / "base_detector.yaml"),
read_config_as_dict(configs_dir / "detectors" / f"{detector_type}.yaml"),
)
detector_config = replace_default_values(
detector_config,
num_individuals=num_individuals,
detector_config, num_individuals=num_individuals,
)
config["detector"] = detector_config
config["detector"] = dict(sorted(detector_config.items()))
return config


Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -212,14 +212,13 @@ def calc_peak_locations(
locrefs: torch.Tensor,
peak_inds_in_batch: torch.Tensor,
strides: tuple[float, float],
n_decimals: int = 3,
) -> torch.Tensor:
s, b, r, c = peak_inds_in_batch.T
stride_y, stride_x = strides
strides = torch.Tensor((stride_x, stride_y)).to(locrefs.device)
off = locrefs[s, b, :, r, c]
loc = strides * peak_inds_in_batch[:, [3, 2]] + strides // 2 + off
return torch.round(loc, decimals=n_decimals)
return loc

@staticmethod
def compute_edge_costs(
Expand Down Expand Up @@ -331,8 +330,8 @@ def compute_peaks_and_costs(
) -> list[dict[str, NDArray]]:
n_samples, n_channels = heatmaps.shape[:2]
n_bodyparts = n_channels - n_id_channels
pos = self.calc_peak_locations(locrefs, peak_inds_in_batch, strides, n_decimals)
pos = pos.detach().cpu().numpy()
pos = self.calc_peak_locations(locrefs, peak_inds_in_batch, strides)
pos = np.round(pos.detach().cpu().numpy(), decimals=n_decimals)
heatmaps = heatmaps.detach().cpu().numpy()
pafs = pafs.detach().cpu().numpy()
peak_inds_in_batch = peak_inds_in_batch.detach().cpu().numpy()
Expand Down
1 change: 1 addition & 0 deletions deeplabcut/pose_estimation_pytorch/runners/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
InferenceRunner,
PoseInferenceRunner,
)
from deeplabcut.pose_estimation_pytorch.runners.snapshots import TorchSnapshotManager
from deeplabcut.pose_estimation_pytorch.runners.train import (
build_training_runner,
DetectorTrainingRunner,
Expand Down
14 changes: 13 additions & 1 deletion deeplabcut/pose_estimation_pytorch/runners/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,22 @@ def __init__(
gpus: the list of GPU indices to use for multi-GPU training
snapshot_path: the path of a snapshot from which to load model weights
"""
if gpus is None:
gpus = []

if len(gpus) == 1:
if device != "cuda":
raise ValueError(
"When specifying a GPU index to train on, the device must be set "
f"to 'cuda'. Found {device}"
)
device = f"cuda:{gpus[0]}"

self.model = model
self.device = device
self.gpus = gpus
self.snapshot_path = snapshot_path
self._gpus = gpus
self._data_parallel = len(gpus) > 1

@staticmethod
def load_snapshot(
Expand Down
Loading