Skip to content
Merged
8 changes: 7 additions & 1 deletion deeplabcut/pose_estimation_tensorflow/lib/inferenceutils.py
Original file line number Diff line number Diff line change
Expand Up @@ -809,7 +809,13 @@ def wrapped(i):
if unique is not None:
self.unique[i] = unique
pbar.update()


def from_pickle(self, pickle_path):
with open(pickle_path, "rb") as file:
data = pickle.load(file)
self.unique = data.pop("single", {})
self.assemblies = data

@staticmethod
def parse_metadata(data):
params = dict()
Expand Down
51 changes: 28 additions & 23 deletions deeplabcut/pose_estimation_tensorflow/predict_videos.py
Original file line number Diff line number Diff line change
Expand Up @@ -1495,7 +1495,7 @@ def _convert_detections_to_tracklets(
)
tracklets = {}

ass = inferenceutils.Assembler(
assembly_builder = inferenceutils.Assembler(
data,
max_n_individuals=inference_cfg["topktoretain"],
n_multibodyparts=len(cfg["multianimalbodyparts"]),
Expand All @@ -1512,22 +1512,22 @@ def _convert_detections_to_tracklets(
str(trainingsetfolder),
"CollectedData_" + cfg["scorer"] + ".h5",
)
ass.calibrate(train_data_file)
ass.assemble()
assembly_builder.calibrate(train_data_file)
assembly_builder.assemble()

output_path, _ = os.path.splitext(output_path)
output_path += ".pickle"
ass.to_pickle(output_path.replace(".pickle", "_assemblies.pickle"))
assembly_builder.to_pickle(output_path.replace(".pickle", "_assemblies.pickle"))

if cfg["uniquebodyparts"]:
tracklets["single"] = {}
tracklets["single"].update(ass.unique)
tracklets["single"].update(assembly_builder.unique)

for i, imname in tqdm(enumerate(ass.metadata["imnames"])):
assemblies = ass.assemblies.get(i)
for i, imname in tqdm(enumerate(assembly_builder.metadata["imnames"])):
assemblies = assembly_builder.assemblies.get(i)
if assemblies is None:
continue
animals = np.stack([ass.data[:, :3] for ass in assemblies])
animals = np.stack([assembly_builder.data[:, :3] for assembly_builder in assemblies])
if track_method == "box":
xy = trackingutils.calc_bboxes_from_keypoints(
animals, inference_cfg.get("boundingboxslack", 0)
Expand Down Expand Up @@ -1798,7 +1798,7 @@ def convert_detections2tracklets(
)
tracklets = {}
multi_bpts = cfg["multianimalbodyparts"]
ass = inferenceutils.Assembler(
assembly_builder = inferenceutils.Assembler(
data,
max_n_individuals=inferencecfg["topktoretain"],
n_multibodyparts=len(multi_bpts),
Expand All @@ -1808,16 +1808,21 @@ def convert_detections2tracklets(
window_size=window_size,
identity_only=identity_only,
)
if calibrate:
trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg)
train_data_file = os.path.join(
cfg["project_path"],
str(trainingsetfolder),
"CollectedData_" + cfg["scorer"] + ".h5",
)
ass.calibrate(train_data_file)
ass.assemble()
ass.to_pickle(dataname.split(".h5")[0] + "_assemblies.pickle")
assemblies_filename = dataname.split(".h5")[0] + "_assemblies.pickle"
if not os.path.exists(assemblies_filename) or overwrite:
if calibrate:
trainingsetfolder = auxiliaryfunctions.get_training_set_folder(cfg)
train_data_file = os.path.join(
cfg["project_path"],
str(trainingsetfolder),
"CollectedData_" + cfg["scorer"] + ".h5",
)
assembly_builder.calibrate(train_data_file)
assembly_builder.assemble()
assembly_builder.to_pickle(assemblies_filename)
else:
assembly_builder.from_pickle(assemblies_filename)
print(f"Loading assemblies from {ass_filename}")
try:
data.close()
except AttributeError:
Expand All @@ -1829,7 +1834,7 @@ def convert_detections2tracklets(
tracklets["single"] = {}
_single = {}
for index, imname in enumerate(imnames):
single_detection = ass.unique.get(index)
single_detection = assembly_builder.unique.get(index)
if single_detection is None:
continue
imindex = int(re.findall(r"\d+", imname)[0])
Expand All @@ -1839,18 +1844,18 @@ def convert_detections2tracklets(
if inferencecfg["topktoretain"] == 1:
tracklets[0] = {}
for index, imname in tqdm(enumerate(imnames)):
assemblies = ass.assemblies.get(index)
assemblies = assembly_builder.assemblies.get(index)
if assemblies is None:
continue
tracklets[0][imname] = assemblies[0].data
else:
keep = set(multi_bpts).difference(ignore_bodyparts or [])
keep_inds = sorted(multi_bpts.index(bpt) for bpt in keep)
for index, imname in tqdm(enumerate(imnames)):
assemblies = ass.assemblies.get(index)
assemblies = assembly_builder.assemblies.get(index)
if assemblies is None:
continue
animals = np.stack([ass.data for ass in assemblies])
animals = np.stack([assembly_builder.data for assembly_builder in assemblies])
if not identity_only:
if track_method == "box":
xy = trackingutils.calc_bboxes_from_keypoints(
Expand Down