Skip to content

Commit a4212ff

Browse files
committed
black
1 parent 7645fe0 commit a4212ff

File tree

19 files changed

+153
-105
lines changed

19 files changed

+153
-105
lines changed

deeplabcut/compat.py

Lines changed: 30 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -52,7 +52,7 @@ def get_available_aug_methods(engine: Engine) -> tuple[str, ...]:
5252
if engine == Engine.TF:
5353
return "imgaug", "default", "deterministic", "scalecrop", "tensorpack"
5454
elif engine == Engine.PYTORCH:
55-
return ("albumentations", )
55+
return ("albumentations",)
5656

5757
raise RuntimeError(f"Unknown augmentation for engine: {engine}")
5858

@@ -217,6 +217,7 @@ def train_network(
217217

218218
if engine == Engine.TF:
219219
from deeplabcut.pose_estimation_tensorflow import train_network
220+
220221
if max_snapshots_to_keep is None:
221222
max_snapshots_to_keep = 5
222223

@@ -238,6 +239,7 @@ def train_network(
238239
)
239240
elif engine == Engine.PYTORCH:
240241
from deeplabcut.pose_estimation_pytorch.apis import train_network
242+
241243
_update_device(gputouse, torch_kwargs)
242244
if "display_iters" not in torch_kwargs:
243245
torch_kwargs["display_iters"] = displayiters
@@ -298,14 +300,18 @@ def return_train_network_path(
298300

299301
if engine == Engine.TF:
300302
from deeplabcut.pose_estimation_tensorflow import return_train_network_path
303+
301304
return return_train_network_path(
302305
config,
303306
shuffle=shuffle,
304307
trainingsetindex=trainingsetindex,
305308
modelprefix=modelprefix,
306309
)
307310
elif engine == Engine.PYTORCH:
308-
from deeplabcut.pose_estimation_pytorch.apis.utils import return_train_network_path
311+
from deeplabcut.pose_estimation_pytorch.apis.utils import (
312+
return_train_network_path,
313+
)
314+
309315
return return_train_network_path(
310316
config,
311317
shuffle=shuffle,
@@ -458,6 +464,7 @@ def evaluate_network(
458464

459465
if engine == Engine.TF:
460466
from deeplabcut.pose_estimation_tensorflow import evaluate_network
467+
461468
return evaluate_network(
462469
str(config),
463470
Shuffles=Shuffles,
@@ -473,6 +480,7 @@ def evaluate_network(
473480
)
474481
elif engine == Engine.PYTORCH:
475482
from deeplabcut.pose_estimation_pytorch.apis import evaluate_network
483+
476484
_update_device(gputouse, torch_kwargs)
477485
return evaluate_network(
478486
config,
@@ -553,6 +561,7 @@ def return_evaluate_network_data(
553561

554562
if engine == Engine.TF:
555563
from deeplabcut.pose_estimation_tensorflow import return_evaluate_network_data
564+
556565
return return_evaluate_network_data(
557566
config,
558567
shuffle=shuffle,
@@ -817,6 +826,7 @@ def analyze_videos(
817826

818827
if engine == Engine.TF:
819828
from deeplabcut.pose_estimation_tensorflow import analyze_videos
829+
820830
kwargs = {}
821831
if use_openvino is not None: # otherwise default comes from tensorflow API
822832
kwargs["use_openvino"] = use_openvino
@@ -847,6 +857,7 @@ def analyze_videos(
847857
)
848858
elif engine == Engine.PYTORCH:
849859
from deeplabcut.pose_estimation_pytorch.apis import analyze_videos
860+
850861
_update_device(gputouse, torch_kwargs)
851862

852863
if use_shelve:
@@ -859,7 +870,8 @@ def analyze_videos(
859870
print(
860871
f"You called analyze_videos with parameters ``batchsize={batchsize}"
861872
f"`` and batch_size={torch_kwargs['batch_size']}. Only one is "
862-
f"needed/used. Using batch size {torch_kwargs['batch_size']}")
873+
f"needed/used. Using batch size {torch_kwargs['batch_size']}"
874+
)
863875
else:
864876
torch_kwargs["batch_size"] = batchsize
865877

@@ -910,6 +922,7 @@ def create_tracking_dataset(
910922

911923
if engine == Engine.TF:
912924
from deeplabcut.pose_estimation_tensorflow import create_tracking_dataset
925+
913926
return create_tracking_dataset(
914927
config,
915928
videos,
@@ -997,6 +1010,7 @@ def analyze_time_lapse_frames(
9971010

9981011
if engine == Engine.TF:
9991012
from deeplabcut.pose_estimation_tensorflow import analyze_time_lapse_frames
1013+
10001014
return analyze_time_lapse_frames(
10011015
config,
10021016
directory,
@@ -1112,6 +1126,7 @@ def convert_detections2tracklets(
11121126

11131127
if engine == Engine.TF:
11141128
from deeplabcut.pose_estimation_tensorflow import convert_detections2tracklets
1129+
11151130
return convert_detections2tracklets(
11161131
config,
11171132
videos,
@@ -1214,6 +1229,7 @@ def extract_maps(
12141229

12151230
if engine == Engine.TF:
12161231
from deeplabcut.pose_estimation_tensorflow import extract_maps
1232+
12171233
return extract_maps(
12181234
config,
12191235
shuffle=shuffle,
@@ -1228,7 +1244,9 @@ def extract_maps(
12281244

12291245

12301246
def visualize_scoremaps(
1231-
image: np.ndarray, scmap: np.ndarray, engine: Engine = DEFAULT_ENGINE,
1247+
image: np.ndarray,
1248+
scmap: np.ndarray,
1249+
engine: Engine = DEFAULT_ENGINE,
12321250
):
12331251
"""
12341252
This function is only implemented for tensorflow models/shuffles, and will throw
@@ -1237,6 +1255,7 @@ def visualize_scoremaps(
12371255
if engine == Engine.TF:
12381256
# TODO: also works for Pytorch, but should not import as then requires TF
12391257
from deeplabcut.pose_estimation_tensorflow import visualize_scoremaps
1258+
12401259
return visualize_scoremaps(image, scmap)
12411260

12421261
raise NotImplementedError(f"This function is not implemented for {engine}")
@@ -1257,7 +1276,10 @@ def visualize_locrefs(
12571276
"""
12581277
if engine == Engine.TF:
12591278
from deeplabcut.pose_estimation_tensorflow import visualize_locrefs
1260-
return visualize_locrefs(image, scmap, locref_x, locref_y, step=step, zoom_width=zoom_width)
1279+
1280+
return visualize_locrefs(
1281+
image, scmap, locref_x, locref_y, step=step, zoom_width=zoom_width
1282+
)
12611283

12621284
raise NotImplementedError(f"This function is not implemented for {engine}")
12631285

@@ -1275,6 +1297,7 @@ def visualize_paf(
12751297
"""
12761298
if engine == Engine.TF:
12771299
from deeplabcut.pose_estimation_tensorflow import visualize_paf
1300+
12781301
return visualize_paf(image, paf, step=step, colors=colors)
12791302

12801303
raise NotImplementedError(f"This function is not implemented for {engine}")
@@ -1350,6 +1373,7 @@ def extract_save_all_maps(
13501373

13511374
if engine == Engine.TF:
13521375
from deeplabcut.pose_estimation_tensorflow import extract_save_all_maps
1376+
13531377
return extract_save_all_maps(
13541378
config,
13551379
shuffle=shuffle,
@@ -1443,6 +1467,7 @@ def export_model(
14431467

14441468
if engine == Engine.TF:
14451469
from deeplabcut.pose_estimation_tensorflow import export_model
1470+
14461471
return export_model(
14471472
cfg_path=cfg_path,
14481473
shuffle=shuffle,

deeplabcut/pose_estimation_pytorch/apis/analyze_videos.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -45,7 +45,10 @@ class VideoIterator(VideoReader):
4545
"""A class to iterate over videos, with possible added context"""
4646

4747
def __init__(
48-
self, video_path: str | Path, context: list[dict[str, Any]] | None = None, cropping: list[int] | None = None
48+
self,
49+
video_path: str | Path,
50+
context: list[dict[str, Any]] | None = None,
51+
cropping: list[int] | None = None,
4952
) -> None:
5053
super().__init__(str(video_path))
5154
self._context = context
@@ -267,7 +270,10 @@ def analyze_videos(
267270
pose_cfg = auxiliaryfunctions.read_plainconfig(pose_cfg_path)
268271

269272
snapshot_index, detector_snapshot_index = parse_snapshot_index_for_analysis(
270-
cfg, model_cfg, snapshot_index, detector_snapshot_index,
273+
cfg,
274+
model_cfg,
275+
snapshot_index,
276+
detector_snapshot_index,
271277
)
272278

273279
if cropping is None and cfg.get("cropping", False):
@@ -366,7 +372,7 @@ def analyze_videos(
366372
data=output_data,
367373
metadata=metadata,
368374
dir_path=output_path,
369-
name_basis=output_name_basis
375+
name_basis=output_name_basis,
370376
)
371377

372378
pred_bodyparts = np.stack([p["bodyparts"][..., :3] for p in predictions])
@@ -473,7 +479,7 @@ def create_df_from_prediction(
473479
index=range(len(pred_bodyparts)),
474480
)
475481
if pred_unique_bodyparts is not None:
476-
unique_columns = [dlc_scorer], ['single'], unique_bodyparts, coords
482+
unique_columns = [dlc_scorer], ["single"], unique_bodyparts, coords
477483
df_u = pd.DataFrame(
478484
pred_unique_bodyparts.reshape((len(pred_unique_bodyparts), -1)),
479485
columns=pd.MultiIndex.from_product(unique_columns, names=cols_names),
@@ -627,7 +633,7 @@ def _generate_output_data(
627633
}
628634

629635
if "bboxes" in frame_predictions:
630-
output[key]["bboxes"] = frame_predictions["bboxes"]
636+
output[key]["bboxes"] = frame_predictions["bboxes"]
631637
output[key]["bbox_scores"] = frame_predictions["bbox_scores"]
632638

633639
if "identity_scores" in frame_predictions:

deeplabcut/pose_estimation_pytorch/apis/convert_detections_to_tracklets.py

Lines changed: 8 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -70,7 +70,11 @@ def convert_detections2tracklets(
7070
# print("These are used for all videos, but won't be save to the cfg file.")
7171

7272
rel_model_dir = auxiliaryfunctions.get_model_folder(
73-
train_fraction, shuffle, cfg, modelprefix=modelprefix, engine=Engine.PYTORCH,
73+
train_fraction,
74+
shuffle,
75+
cfg,
76+
modelprefix=modelprefix,
77+
engine=Engine.PYTORCH,
7478
)
7579
model_dir = Path(cfg["project_path"]) / rel_model_dir
7680
path_test_config = model_dir / "test" / "pose_cfg.yaml"
@@ -119,7 +123,9 @@ def convert_detections2tracklets(
119123

120124
output_name_basis = video_name + dlc_scorer
121125
print(f"Loading From {output_path / output_name_basis} full and meta pickles")
122-
data, metadata = auxfun_multianimal.load_multianimal_full_meta_data(output_path, output_name_basis)
126+
data, metadata = auxfun_multianimal.load_multianimal_full_meta_data(
127+
output_path, output_name_basis
128+
)
123129
if track_method == "ellipse":
124130
method = "el"
125131
elif track_method == "box":

deeplabcut/pose_estimation_pytorch/apis/evaluate.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -222,7 +222,7 @@ def evaluate_snapshot(
222222
parameters = PoseDatasetParameters(
223223
bodyparts=project_bodyparts,
224224
unique_bpts=parameters.unique_bpts,
225-
individuals=parameters.individuals
225+
individuals=parameters.individuals,
226226
)
227227

228228
predictions = {}
@@ -311,7 +311,7 @@ def evaluate_snapshot(
311311
dot_size=cfg["dotsize"],
312312
alpha_value=cfg["alphavalue"],
313313
p_cutoff=cfg["pcutoff"],
314-
bboxes_cutoff=bboxes_cutoff
314+
bboxes_cutoff=bboxes_cutoff,
315315
)
316316

317317
return df_predictions

deeplabcut/pose_estimation_pytorch/apis/utils.py

Lines changed: 14 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -352,22 +352,27 @@ def build_predictions_dataframe(
352352
if image_name_to_index is not None:
353353
index_data.append(image_name_to_index(image))
354354
if "bboxes" in image_predictions:
355-
bboxes_data.append((image_predictions["bboxes"], image_predictions["bbox_scores"]))
355+
bboxes_data.append(
356+
(image_predictions["bboxes"], image_predictions["bbox_scores"])
357+
)
356358

357359
if len(index_data) > 0:
358360
index = pd.MultiIndex.from_tuples(index_data)
359361
else:
360362
index = list(predictions.keys())
361363

362-
return (pd.DataFrame(
363-
prediction_data,
364-
index=index,
365-
columns=build_dlc_dataframe_columns(
366-
scorer=scorer,
367-
parameters=parameters,
368-
with_likelihood=True,
364+
return (
365+
pd.DataFrame(
366+
prediction_data,
367+
index=index,
368+
columns=build_dlc_dataframe_columns(
369+
scorer=scorer,
370+
parameters=parameters,
371+
with_likelihood=True,
372+
),
369373
),
370-
), dict(zip(index, bboxes_data)))
374+
dict(zip(index, bboxes_data)),
375+
)
371376

372377

373378
def get_inference_runners(

deeplabcut/pose_estimation_pytorch/modelzoo/inference.py

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -147,10 +147,10 @@ def _video_inference_superanimal(
147147
pred_unique_bodyparts = None
148148

149149
bbox_keys_in_predictions = {"bboxes", "bbox_scores"}
150-
bboxes_list = [{key: value
151-
for key, value in p.items()
152-
if key in bbox_keys_in_predictions}
153-
for i, p in enumerate(predictions)]
150+
bboxes_list = [
151+
{key: value for key, value in p.items() if key in bbox_keys_in_predictions}
152+
for i, p in enumerate(predictions)
153+
]
154154

155155
bbox = cropping
156156
if cropping is None:

deeplabcut/pose_estimation_tensorflow/core/evaluate.py

Lines changed: 4 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -276,10 +276,8 @@ def return_evaluate_network_data(
276276
# Data=pd.read_hdf(os.path.join(cfg["project_path"],str(trainingsetfolder),'CollectedData_' + cfg["scorer"] + '.h5'),'df_with_missing')
277277

278278
# Get list of body parts to evaluate network for
279-
comparisonbodyparts = (
280-
auxiliaryfunctions.filter_bodyparts_from_config(
281-
cfg, comparisonbodyparts
282-
)
279+
comparisonbodyparts = auxiliaryfunctions.filter_bodyparts_from_config(
280+
cfg, comparisonbodyparts
283281
)
284282
##################################################
285283
# Load data...
@@ -707,10 +705,8 @@ def evaluate_network(
707705
)
708706

709707
# Get list of body parts to evaluate network for
710-
comparisonbodyparts = (
711-
auxiliaryfunctions.filter_bodyparts_from_config(
712-
cfg, comparisonbodyparts
713-
)
708+
comparisonbodyparts = auxiliaryfunctions.filter_bodyparts_from_config(
709+
cfg, comparisonbodyparts
714710
)
715711
# Make folder for evaluation
716712
auxiliaryfunctions.attempt_to_make_folder(

deeplabcut/pose_estimation_tensorflow/core/evaluate_multianimal.py

Lines changed: 3 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -145,10 +145,8 @@ def evaluate_multianimal_full(
145145
conversioncode.guarantee_multiindex_rows(Data)
146146

147147
# Get list of body parts to evaluate network for
148-
comparisonbodyparts = (
149-
auxiliaryfunctions.filter_bodyparts_from_config(
150-
cfg, comparisonbodyparts
151-
)
148+
comparisonbodyparts = auxiliaryfunctions.filter_bodyparts_from_config(
149+
cfg, comparisonbodyparts
152150
)
153151
all_bpts = np.asarray(
154152
len(cfg["individuals"]) * cfg["multianimalbodyparts"] + cfg["uniquebodyparts"]
@@ -555,7 +553,7 @@ def evaluate_multianimal_full(
555553
data=predicted_data,
556554
metadata=metadata,
557555
dir_path=output_path,
558-
name_basis=output_name_basis
556+
name_basis=output_name_basis,
559557
)
560558
tf.compat.v1.reset_default_graph()
561559

deeplabcut/pose_estimation_tensorflow/predict_multianimal.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -226,7 +226,7 @@ def AnalyzeMultiAnimalVideo(
226226
data=predicted_data,
227227
metadata=metadata,
228228
dir_path=output_path,
229-
name_basis=output_name_basis
229+
name_basis=output_name_basis,
230230
)
231231

232232

0 commit comments

Comments
 (0)