Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
6abbfcd
keepdeconvweights
n-poulsen Dec 9, 2024
f92e733
updated user_guide
n-poulsen Dec 10, 2024
708a3cd
remove extra param
n-poulsen Dec 10, 2024
647255f
video_analysis: in_random_order, n_tracks
n-poulsen Dec 10, 2024
461a718
Merge branch 'pytorch_dlc' into niels/missing_torch_api_args
n-poulsen Dec 10, 2024
b330456
implemented dynamic cropping. needs testing.
n-poulsen Dec 11, 2024
563521b
improved docs
n-poulsen Dec 12, 2024
d003de5
improve dynamic cropper
n-poulsen Dec 12, 2024
d87b4a9
update user_guide
n-poulsen Dec 12, 2024
df76d20
implement analyze_images
n-poulsen Dec 12, 2024
4a1af1d
Merge branch 'pytorch_dlc' into niels/missing_torch_api_args
n-poulsen Dec 16, 2024
82b7110
analyze_images - plotting option
n-poulsen Dec 16, 2024
80ebf7f
Merge branch 'pytorch_dlc' into niels/missing_torch_api_args
n-poulsen Dec 18, 2024
ffddaad
update analyze_images plotting
n-poulsen Dec 18, 2024
39f6deb
Update deeplabcut/compat.py
n-poulsen Dec 19, 2024
8d06b3c
Update deeplabcut/compat.py
n-poulsen Dec 19, 2024
f4ce116
Update deeplabcut/compat.py
n-poulsen Dec 19, 2024
b5a0891
Update deeplabcut/compat.py
n-poulsen Dec 19, 2024
9089979
Update deeplabcut/pose_estimation_pytorch/apis/convert_detections_to_…
n-poulsen Dec 19, 2024
f9631d5
Update docs/pytorch/user_guide.md
n-poulsen Dec 19, 2024
45ed76a
update user guide
n-poulsen Dec 19, 2024
01e6f3a
Merge branch 'pytorch_dlc' into niels/missing_torch_api_args
n-poulsen Dec 19, 2024
0b0b52b
fixed weights_only and tests
n-poulsen Dec 19, 2024
2261aa7
fix updated config
n-poulsen Dec 19, 2024
925bf51
Merge branch 'pytorch_dlc' into niels/missing_torch_api_args
n-poulsen Dec 19, 2024
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions deeplabcut/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -99,6 +99,7 @@
return_evaluate_network_data,
analyze_videos,
create_tracking_dataset,
analyze_images,
analyze_time_lapse_frames,
convert_detections2tracklets,
extract_maps,
Expand Down
266 changes: 242 additions & 24 deletions deeplabcut/compat.py

Large diffs are not rendered by default.

1 change: 1 addition & 0 deletions deeplabcut/pose_estimation_pytorch/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
# Licensed under GNU Lesser General Public License v3.0
#
from deeplabcut.pose_estimation_pytorch.apis import (
analyze_images,
analyze_videos,
convert_detections2tracklets,
evaluate_network,
Expand Down
110 changes: 91 additions & 19 deletions deeplabcut/pose_estimation_pytorch/apis/analyze_images.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,10 +24,12 @@

import deeplabcut.pose_estimation_pytorch.apis.visualization as visualization
import deeplabcut.pose_estimation_pytorch.config.utils as config_utils
import deeplabcut.pose_estimation_pytorch.data as data
import deeplabcut.pose_estimation_pytorch.modelzoo as modelzoo
from deeplabcut.core.engine import Engine
from deeplabcut.modelzoo.utils import get_superanimal_colormaps
from deeplabcut.pose_estimation_pytorch.apis.utils import (
build_predictions_dataframe,
get_inference_runners,
get_model_snapshots,
get_scorer_name,
Expand Down Expand Up @@ -215,15 +217,21 @@ def superanimal_analyze_images(
def analyze_images(
config: str | Path,
images: str | Path | list[str] | list[Path],
output_dir: str | Path,
frame_type: str | None = None,
output_dir: str | Path | None = None,
shuffle: int = 1,
trainingsetindex: int = 0,
snapshot_index: int | None = None,
detector_snapshot_index: int | None = None,
modelprefix: str = "",
device: str | None = None,
max_individuals: int | None = None,
save_as_csv: bool = False,
progress_bar: bool = True,
plotting: bool | str = False,
pcutoff: float | None = None,
bbox_pcutoff: float | None = None,
plot_skeleton: bool = True,
) -> dict[str, dict]:
"""Runs analysis on images using a pose model.

Expand All @@ -232,6 +240,9 @@ def analyze_images(
images: The image(s) to run inference on. Can be the path to an image, the path
to a directory containing images, or a list of image paths or directories
containing images.
frame_type: Filters the images to analyze to only the ones with the given suffix
(e.g. setting `frame_type`=".png" will only analyze ".png" images). The
default behavior analyzes all ".jpg", ".jpeg" and ".png" images.
output_dir: The directory where the predictions will be stored.
shuffle: The shuffle for which to run image analysis.
trainingsetindex: The trainingsetindex for which to run image analysis.
Expand All @@ -243,17 +254,20 @@ def analyze_images(
device: The device to use to run image analysis.
max_individuals: The maximum number of individuals to detect in each image. Set
to the number of individuals in the project if None.
save_as_csv: Whether to also save the predictions as a CSV file.
progress_bar: Whether to display a progress bar when running inference.
plotting: Whether to plot predictions on images.
pcutoff: The cutoff score when plotting pose predictions. Must be None or in
(0, 1). If None, the pcutoff is read from the project configuration file.
bbox_pcutoff: The cutoff score when plotting bounding box predictions. Must be
None or in (0, 1). If None, it is read from the project configuration file.
plot_skeleton: If a skeleton is defined in the model configuration file, whether
to plot the skeleton connecting the predicted bodyparts on the images.

Returns:
A dictionary mapping each image filename to the different types of predictions
for it (e.g. "bodyparts", "unique_bodyparts", "bboxes", "bbox_scores")
"""
if not output_dir.is_dir():
raise ValueError(
f"The `output_dir` must be a directory - please change: {output_dir}"
)

cfg = auxiliaryfunctions.read_config(config)
train_frac = cfg["TrainingFraction"][trainingsetindex]
model_folder = Path(cfg["project_path"]) / auxiliaryfunctions.get_model_folder(
Expand Down Expand Up @@ -285,22 +299,22 @@ def analyze_images(
images=images,
snapshot_path=snapshot.path,
detector_path=None if detector_snapshot is None else detector_snapshot.path,
frame_type=frame_type,
device=device,
max_individuals=max_individuals,
progress_bar=progress_bar,
)

if len(predictions) == 0:
logging.info(f"Found no images in {images}")
print(f"Found no images in {images}")
return {}

# FIXME(niels): store as H5
pred_json = {}
for image, pred in predictions.items():
pred_json[image] = dict(bodyparts=pred["bodyparts"].tolist())
for k in ("unique_bodyparts", "bboxes", "bbox_scores"):
if k in pred:
pred_json[image][k] = pred[k].tolist()
if output_dir is None:
images = list(predictions.keys())
output_dir = Path(images[0]).parent.resolve()
print(f"Setting output directory to {output_dir}")
output_dir = Path(output_dir)
output_dir.mkdir(exist_ok=True)

scorer = get_scorer_name(
cfg,
Expand All @@ -309,10 +323,60 @@ def analyze_images(
snapshot_uid=get_scorer_uid(snapshot, detector_snapshot),
modelprefix=modelprefix,
)
output_path = output_dir / f"{scorer}_image_predictions.json"
logging.info(f"Saving predictions to {output_path}")
with open(output_path, "w") as f:
json.dump(pred_json, f)
individuals = model_cfg["metadata"]["individuals"]
if max_individuals is not None:
individuals = [f"individual{i}" for i in range(max_individuals)]

df_predictions = build_predictions_dataframe(
scorer=scorer,
predictions=predictions,
parameters=data.PoseDatasetParameters(
bodyparts=model_cfg["metadata"]["bodyparts"],
unique_bpts=model_cfg["metadata"]["unique_bodyparts"],
individuals=individuals,
),
image_name_to_index=None,
)

output_filepath = output_dir / f"image_predictions_{scorer}.h5"
print(f"Saving predictions to {output_filepath}")

df_predictions.to_hdf(output_filepath, key="predictions")
if save_as_csv:
print(f"Saving CSV as {output_filepath}")
df_predictions.to_csv(output_filepath.with_suffix(".csv"))

if plotting:
plot_dir = output_dir / f"LabeledImages_{scorer}"
plot_dir.mkdir(exist_ok=True)

mode = plotting if isinstance(plotting, str) else "bodypart"

bodyparts = model_cfg["metadata"]["bodyparts"]
skeleton = None
if plot_skeleton and len(cfg.get("skeleton", [])) > 0:
skeleton = [
(bodyparts.index(bpt_0), bodyparts.index(bpt_1))
for bpt_0, bpt_1 in cfg["skeleton"]
]

if pcutoff is None:
pcutoff = cfg.get("pcutoff", 0.6)
if bbox_pcutoff is None:
bbox_pcutoff = cfg.get("bbox_pcutoff", 0.6)

visualization.create_labeled_images(
predictions=predictions,
out_folder=plot_dir,
pcutoff=pcutoff,
bboxes_pcutoff=bbox_pcutoff,
mode=mode,
cmap=cfg.get("colormap", "rainbow"),
dot_size=cfg.get("dotsize", 12),
alpha_value=cfg.get("alphavalue", 12),
skeleton=skeleton,
skeleton_color=cfg.get("skeleton_color"),
)

return predictions

Expand All @@ -322,6 +386,7 @@ def analyze_image_folder(
images: str | Path | list[str] | list[Path],
snapshot_path: str | Path,
detector_path: str | Path | None = None,
frame_type: str | None = None,
device: str | None = None,
max_individuals: int | None = None,
progress_bar: bool = True,
Expand All @@ -335,6 +400,9 @@ def analyze_image_folder(
snapshot_path: The path of the snapshot to use to analyze the images.
detector_path: The path of the detector snapshot to use to analyze the images,
if a top-down model was used.
frame_type: Filters the images to analyze to only the ones with the given suffix
(e.g. setting `frame_type`=".png" will only analyze ".png" images). The
default behavior analyzes all ".jpg", ".jpeg" and ".png" images.
device: The device to use to run image analysis.
max_individuals: The maximum number of individuals to detect in each image. Set
to the number of individuals in the project if None.
Expand Down Expand Up @@ -379,7 +447,11 @@ def analyze_image_folder(
detector_transform=None,
)

image_paths = parse_images_and_image_folders(images)
image_suffixes = ".png", ".jpg", ".jpeg"
if frame_type is not None:
image_suffixes = (frame_type, )

image_paths = parse_images_and_image_folders(images, image_suffixes)
pose_inputs = image_paths
if detector_runner is not None:
logging.info(f"Running object detection with {detector_path}")
Expand Down
Loading