Skip to content

Commit 0c45b98

Browse files
authored
Merge pull request #2629 from yeshaokai/shaokai/educational_superanimal_notebook
New PyTorch DeepLabCut SuperAnimal demo notebook
2 parents 83f1acb + 27a58af commit 0c45b98

File tree

9 files changed

+2506
-234
lines changed

9 files changed

+2506
-234
lines changed

deeplabcut/modelzoo/generalized_data_converter/datasets/base.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -153,20 +153,28 @@ def populate_generic(self):
153153
raise NotImplementedError("Must implement this function")
154154

155155
def materialize(
156-
self, proj_root, framework="coco", deepcopy=False, append_image_id=True
156+
self,
157+
proj_root,
158+
framework="coco",
159+
deepcopy=False,
160+
append_image_id=True,
161+
no_image_copy=False,
157162
):
158163
mat_func = mat_func_factory(framework)
159164
self.meta["mat_datasets"] = {self.meta["dataset_name"]: self}
160165
self.meta["imageid2datasetname"] = self.imageid2datasetname
166+
kwargs = dict(deepcopy=deepcopy, append_image_id=append_image_id)
167+
if framework == "coco":
168+
kwargs["no_image_copy"] = no_image_copy
169+
161170
mat_func(
162171
proj_root,
163172
self.generic_train_images,
164173
self.generic_test_images,
165174
self.generic_train_annotations,
166175
self.generic_test_annotations,
167176
self.meta,
168-
deepcopy=deepcopy,
169-
append_image_id=append_image_id,
177+
**kwargs,
170178
)
171179

172180
def whether_anno_image_match(self, images, annotations):

deeplabcut/modelzoo/generalized_data_converter/datasets/materialize.py

Lines changed: 44 additions & 39 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@
1212
import os
1313
import pickle
1414
import shutil
15+
from pathlib import Path
1516

1617
import numpy as np
1718
import pandas as pd
@@ -660,9 +661,10 @@ def _generic2coco(
660661
train_annotations,
661662
test_annotations,
662663
meta,
663-
deepcopy=False,
664-
full_image_path=True,
665-
append_image_id=True,
664+
deepcopy: bool = False,
665+
full_image_path: bool = True,
666+
append_image_id: bool = True,
667+
no_image_copy: bool = False,
666668
):
667669
"""
668670
Take generic data and create coco structure
@@ -672,6 +674,17 @@ def _generic2coco(
672674
annotations
673675
- train.json
674676
- test.json
677+
678+
Args:
679+
deepcopy: Only when no_image_copy=False. If False, images are not copied from
680+
their original location and symlinks are created instead.
681+
full_image_path: Only when no_image_copy=False. If True, the ``file_name`` for
682+
the images in the annotation files contain the resolved path to the images.
683+
Otherwise, a relative path is used.
684+
append_image_id: Only when no_image_copy=False. Appends the image IDs in the
685+
dataset to the image names.
686+
no_image_copy: Instead of copying images to the COCO dataset, the full paths to
687+
the images in the original dataset are used in the annotations.
675688
"""
676689

677690
os.makedirs(os.path.join(proj_root, "images"), exist_ok=True)
@@ -693,54 +706,46 @@ def _generic2coco(
693706
broken_links = []
694707
# copying images via symbolic link
695708
for image in train_images + test_images:
696-
src = image["file_name"]
709+
# important to resolve the filepath! Otherwise, errors can occur when running
710+
# this code from Jupyter Notebooks
711+
src = Path(image["file_name"]).resolve()
697712
image_id = image["id"]
698713

699-
if not os.path.exists(src):
714+
if not src.exists():
700715
print("problem comes from", image["source_dataset"])
701716
print(src)
702717
broken_links.append(image_id)
703718
continue
704-
else:
705-
pass
706-
# print ('success comes from', image['source_dataset'])
707-
# print (src)
708-
709-
# in dlc, some images have same name but under different folder
710-
# we used to use a parent folder to distinguish them, but it's only applicable to DLC
711-
# so here it's easier to just append a id into the filename
712719

713-
image_name = src.split(os.sep)[-1]
720+
file_name = str(src)
721+
dest = src
722+
if not no_image_copy:
723+
# in dlc, some images have same name but under different folder
724+
# we used to use a parent folder to distinguish them, but it's only
725+
# applicable to DLC so here it's easier to append an id into the filename
714726

715-
if image_name.count(".") > 1:
716-
sep = image_name.rfind(".")
717-
pre, suffix = image_name[:sep], image_name[sep + 1 :]
718-
else:
719-
# this does not work for image file that looks like image9.5.jpg..
720-
pre, suffix = image_name.split(".")
721-
722-
# not to repeatedly add image id in memory replay training
723-
if append_image_id:
724-
dest_image_name = f"{pre}_{image_id}.{suffix}"
725-
else:
726-
dest_image_name = image_name
727-
dest = os.path.join(proj_root, "images", dest_image_name)
727+
# not to repeatedly add image id in memory replay training
728+
dest_image_name = src.name
729+
if append_image_id:
730+
dest_image_name = f"{src.stem}_{image_id}{src.suffix}"
728731

729-
# now, we will also need to update the path in the config files
732+
dest = Path(proj_root) / "images" / dest_image_name
733+
dest = dest.resolve()
730734

731-
if full_image_path:
732-
image["file_name"] = dest
733-
else:
734-
image["file_name"] = os.path.join("images", dest_image_name)
735+
file_name = str(Path(*dest.parts[-2:]))
736+
if full_image_path:
737+
file_name = str(dest)
735738

736-
if deepcopy:
737-
shutil.copy(src, dest)
738-
else:
739-
try:
740-
os.symlink(src, dest)
741-
except:
742-
pass
739+
if deepcopy:
740+
shutil.copy(src, dest)
741+
else:
742+
try:
743+
os.symlink(src, dest)
744+
except Exception as err:
745+
print(f"Could not create a symlink from {src} to {dest}: {err}")
746+
pass
743747

748+
image["file_name"] = file_name
744749
lookuptable[dest] = src
745750

746751
train_annotations = [

deeplabcut/pose_estimation_pytorch/apis/__init__.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,10 @@
99
# Licensed under GNU Lesser General Public License v3.0
1010
#
1111

12-
from deeplabcut.pose_estimation_pytorch.apis.analyze_images import analyze_images
12+
from deeplabcut.pose_estimation_pytorch.apis.analyze_images import (
13+
analyze_images,
14+
superanimal_analyze_images,
15+
)
1316
from deeplabcut.pose_estimation_pytorch.apis.analyze_videos import analyze_videos
1417
from deeplabcut.pose_estimation_pytorch.apis.convert_detections_to_tracklets import (
1518
convert_detections2tracklets,

deeplabcut/pose_estimation_pytorch/apis/analyze_images.py

Lines changed: 28 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -47,33 +47,37 @@ def superanimal_analyze_images(
4747
images: str | Path | list[str] | list[Path],
4848
max_individuals: int,
4949
out_folder: str,
50+
bbox_threshold: float = 0.6,
5051
progress_bar: bool = True,
5152
device: str | None = None,
5253
customized_pose_checkpoint: str | None = None,
5354
customized_detector_checkpoint: str | None = None,
5455
customized_model_config: str | None = None,
5556
):
5657
"""
57-
This funciton inferences a superanimal model on a set of images and saves the results as labeled images.
58+
This function runs inference with a superanimal model on a set of images and saves the
59+
results as labeled images.
5860
5961
Parameters
6062
----------
6163
superanimal_name: str
62-
The name of the superanimal to analyze.
63-
supported list:
64-
superanimal_topviewmouse
65-
superanimal_quadruped
64+
The name of the superanimal to analyze. Supported list:
65+
- "superanimal_topviewmouse"
66+
- "superanimal_quadruped"
6667
model_name: str
67-
The name of the model to use for inference.
68-
supported list:
69-
hrnetw32
68+
The name of the model to use for inference. Supported list:
69+
- "hrnetw32"
7070
images: str | Path | list[str] | list[Path]
7171
The images to analyze. Can either be a directory containing images, or
7272
a list of paths of images.
7373
max_individuals: int
7474
The maximum number of individuals to detect in each image.
7575
out_folder: str
7676
The directory where the labeled images will be saved.
77+
bbox_threshold: float, default=0.6
78+
The minimum confidence score to keep bounding box detections. Must be in (0, 1).
79+
Only used when `customized_model_config=None` (otherwise, edit your
80+
`customized_model_config` with the desired bbox_threshold).
7781
progress_bar: bool
7882
Whether to display a progress bar when running inference.
7983
device: str | None
@@ -95,17 +99,19 @@ def superanimal_analyze_images(
9599
--------
96100
>>> import deeplabcut
97101
>>> from deeplabcut.pose_estimation_pytorch.apis.analyze_images import superanimal_analyze_images
98-
>>> superanimal_name = 'superanimal_quadruped'
99-
>>> model_name = 'hrnetw32'
100-
>>> device = 'cuda'
102+
>>> superanimal_name = "superanimal_quadruped"
103+
>>> model_name = "hrnetw32"
104+
>>> device = "cuda"
101105
>>> max_individuals = 3
102-
>>> test_images_folder = 'test_rodent_images'
103-
>>> out_images_folder = 'vis_test_rodent_images'
104-
>>> ret = superanimal_analyze_images(superanimal_name,
105-
model_name,
106-
test_images_folder,
107-
max_individuals,
108-
out_images_folder)
106+
>>> test_images_folder = "test_rodent_images"
107+
>>> out_images_folder = "vis_test_rodent_images"
108+
>>> ret = superanimal_analyze_images(
109+
...     superanimal_name,
110+
...     model_name,
111+
...     test_images_folder,
112+
...     max_individuals,
113+
...     out_images_folder
114+
... )
109115
"""
110116

111117
os.makedirs(out_folder, exist_ok=True)
@@ -119,6 +125,10 @@ def superanimal_analyze_images(
119125
snapshot_path,
120126
detector_path,
121127
) = get_config_model_paths(superanimal_name, model_name)
128+
129+
if "detector" in model_cfg:
130+
model_cfg["detector"]["model"]["box_score_thresh"] = bbox_threshold
131+
122132
config = {**project_config, **model_cfg}
123133
config = update_config(config, max_individuals, device)
124134
else:
@@ -146,9 +156,7 @@ def superanimal_analyze_images(
146156

147157
superanimal_colormaps = get_superanimal_colormaps()
148158
colormap = superanimal_colormaps[superanimal_name]
149-
150159
create_labeled_images_from_predictions(predictions, out_folder, colormap)
151-
152160
return predictions
153161

154162

@@ -164,8 +172,6 @@ def analyze_images(
164172
device: str | None = None,
165173
max_individuals: int | None = None,
166174
progress_bar: bool = True,
167-
superanimal_name=None,
168-
model_name=None,
169175
) -> dict[str, dict]:
170176
"""Runs analysis on images using a pose model.
171177

deeplabcut/pose_estimation_pytorch/apis/train.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -257,8 +257,10 @@ def train_network(
257257
dataset_params = loader.get_dataset_parameters()
258258
backbone_name = loader.model_cfg["model"]["backbone"]["model_name"]
259259
model_name = modelzoo_utils.get_pose_model_type(backbone_name)
260-
# at some point train_network should support a different train_file passing so memory replay can also take the same train file
260+
# at some point train_network should support a different train_file passing
261+
# so memory replay can also take the same train file
261262

263+
print("Preparing data for memory replay (this can take some time)")
262264
prepare_memory_replay(
263265
loader.project_path,
264266
shuffle,
@@ -271,6 +273,7 @@ def train_network(
271273
customized_pose_checkpoint=weight_init.customized_pose_checkpoint,
272274
)
273275

276+
print("Loading memory replay data")
274277
loader = COCOLoader(
275278
project_root=Path(loader.model_folder).parent / "memory_replay",
276279
model_config_path=loader.model_config_path,

deeplabcut/pose_estimation_pytorch/config/base/base.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -5,7 +5,7 @@ runner:
55
gpus: null
66
key_metric: "test.mAP"
77
key_metric_asc: true
8-
eval_interval: 1
8+
eval_interval: 10
99
optimizer:
1010
type: AdamW
1111
params:

0 commit comments

Comments
 (0)