Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 11 additions & 2 deletions deeplabcut/pose_estimation_pytorch/apis/analyze_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,14 @@ def __init__(
if self._crop:
self.set_bbox(*cropping)

def set_crop(self, cropping: list[int] | None = None) -> None:
"""Sets the cropping parameters for the video."""
self._crop = cropping is not None
if self._crop:
self.set_bbox(*cropping)
else:
self.set_bbox(0, 1, 0, 1, relative=True)

def get_context(self) -> list[dict[str, Any]] | None:
if self._context is None:
return None
Expand Down Expand Up @@ -157,6 +165,8 @@ def video_inference(
"""
if not isinstance(video, VideoIterator):
video = VideoIterator(str(video), cropping=cropping)
elif cropping is not None:
video.set_crop(cropping)

n_frames = video.get_n_frames(robust=robust_nframes)
vid_w, vid_h = video.dimensions
Expand Down Expand Up @@ -421,7 +431,7 @@ def analyze_videos(
output_prefix = video.stem + dlc_scorer
output_pkl = output_path / f"{output_prefix}_full.pickle"

video_iterator = VideoIterator(video)
video_iterator = VideoIterator(video, cropping=cropping)

shelf_writer = None
if use_shelve:
Expand All @@ -439,7 +449,6 @@ def analyze_videos(
video=video_iterator,
pose_runner=pose_runner,
detector_runner=detector_runner,
cropping=cropping,
shelf_writer=shelf_writer,
robust_nframes=robust_nframes,
)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -143,7 +143,7 @@ def get_pose_prediction(

batch_size, num_joints = x.shape

dz = torch.zeros((batch_size, 1, num_joints, 3)).to(x.device)
dz = torch.zeros((batch_size, 1, num_joints, 3), device=heatmap.device)
for b in range(batch_size):
for j in range(num_joints):
dz[b, 0, j, 2] = heatmap[b, y[b, j], x[b, j], j]
Expand All @@ -155,7 +155,7 @@ def get_pose_prediction(
x = x * scale_factors[1] + 0.5 * scale_factors[1] + dz[:, :, :, 0]
y = y * scale_factors[0] + 0.5 * scale_factors[0] + dz[:, :, :, 1]

pose = torch.empty((batch_size, 1, num_joints, 3))
pose = torch.zeros((batch_size, 1, num_joints, 3), device=heatmap.device)
pose[:, :, :, 0] = x
pose[:, :, :, 1] = y
pose[:, :, :, 2] = dz[:, :, :, 2]
Expand Down