Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
update the default PAF config
  • Loading branch information
n-poulsen committed Jul 19, 2024
commit 17ccb5fce7253d52cf028a35e473785778ff3aa1
Original file line number Diff line number Diff line change
@@ -1,5 +1,4 @@
type: DLCRNetHead
weight_init: normal
predictor:
type: PartAffinityFieldPredictor
num_animals: "num_individuals"
Expand All @@ -11,12 +10,12 @@ predictor:
min_affinity: 0.05
graph: "paf_graph"
edges_to_keep: "paf_edges_to_keep"
apply_sigmoid: false
clip_scores: true
apply_sigmoid: true
clip_scores: false
target_generator:
type: SequentialGenerator
generators:
- type: HeatmapGaussianGenerator
- type: HeatmapPlateauGenerator
num_heatmaps: "num_bodyparts"
pos_dist_thresh: 17
heatmap_mode: KEYPOINT
Expand All @@ -27,7 +26,7 @@ target_generator:
width: 20
criterion:
heatmap:
type: WeightedMSECriterion
type: WeightedBCECriterion
weight: 1.0
locref:
type: WeightedHuberCriterion
Expand Down
4 changes: 4 additions & 0 deletions deeplabcut/pose_estimation_pytorch/models/heads/dlcrnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@ def __init__(
in_refined_channels
)
locref_config["channels"][0] = locref_config["channels"][-1]

super().__init__(
predictor,
target_generator,
Expand All @@ -64,6 +65,9 @@ def __init__(
locref_config,
weight_init,
)
if num_stages > 0:
self.stride *= 2 # extra deconv layer where it's multi-stage

self.paf_head = DeconvModule(**paf_config)

self.convt1 = self._make_layer_same_padding(
Expand Down