Skip to content
Merged
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
remove debug prints
  • Loading branch information
n-poulsen committed Jan 27, 2025
commit 295575c1df4c20fb8a5f79079d5f2d5078e56da7
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,13 @@ def forward(
>>> poses = predictor.forward(stride, output)
"""
heatmaps = outputs["heatmap"]
# print("---")
# print(f"HEATMAPS SHAPE: {heatmaps.shape}")
scale_factors = stride, stride

if self.apply_sigmoid:
heatmaps = self.sigmoid(heatmaps)

heatmaps = heatmaps.permute(0, 2, 3, 1)
batch_size, height, width, num_joints = heatmaps.shape
# print(f"HEATMAPS SHAPE: {heatmaps.shape}")
# print(f"STRIDE: {stride}")
# for i in range(heatmaps.shape[-1]):
# hm = heatmaps[0, :, :, i]
# y0, x0 = stride * (hm == torch.max(hm)).nonzero()[0]
# print(i, f"({x0}, {y0}):", torch.max(heatmaps[..., i]))

locrefs = None
if self.location_refinement:
Expand All @@ -99,16 +91,10 @@ def forward(
locrefs = locrefs * self.locref_std

poses = self.get_pose_prediction(heatmaps, locrefs, scale_factors)
# print(f"PREDICTION: {poses.shape}")
# for p in poses[0]:
# for kpt in p:
# print(kpt)
# print("---")

if self.clip_scores:
poses[..., 2] = torch.clip(poses[..., 2], min=0, max=1)

# raise ValueError("")
return {"poses": poses}

def get_top_values(
Expand Down