@@ -52,7 +52,7 @@ def get_available_aug_methods(engine: Engine) -> tuple[str, ...]:
5252 if engine == Engine .TF :
5353 return "imgaug" , "default" , "deterministic" , "scalecrop" , "tensorpack"
5454 elif engine == Engine .PYTORCH :
55- return ("albumentations" , )
55+ return ("albumentations" ,)
5656
5757 raise RuntimeError (f"Unknown augmentation for engine: { engine } " )
5858
@@ -217,6 +217,7 @@ def train_network(
217217
218218 if engine == Engine .TF :
219219 from deeplabcut .pose_estimation_tensorflow import train_network
220+
220221 if max_snapshots_to_keep is None :
221222 max_snapshots_to_keep = 5
222223
@@ -238,6 +239,7 @@ def train_network(
238239 )
239240 elif engine == Engine .PYTORCH :
240241 from deeplabcut .pose_estimation_pytorch .apis import train_network
242+
241243 _update_device (gputouse , torch_kwargs )
242244 if "display_iters" not in torch_kwargs :
243245 torch_kwargs ["display_iters" ] = displayiters
@@ -298,14 +300,18 @@ def return_train_network_path(
298300
299301 if engine == Engine .TF :
300302 from deeplabcut .pose_estimation_tensorflow import return_train_network_path
303+
301304 return return_train_network_path (
302305 config ,
303306 shuffle = shuffle ,
304307 trainingsetindex = trainingsetindex ,
305308 modelprefix = modelprefix ,
306309 )
307310 elif engine == Engine .PYTORCH :
308- from deeplabcut .pose_estimation_pytorch .apis .utils import return_train_network_path
311+ from deeplabcut .pose_estimation_pytorch .apis .utils import (
312+ return_train_network_path ,
313+ )
314+
309315 return return_train_network_path (
310316 config ,
311317 shuffle = shuffle ,
@@ -458,6 +464,7 @@ def evaluate_network(
458464
459465 if engine == Engine .TF :
460466 from deeplabcut .pose_estimation_tensorflow import evaluate_network
467+
461468 return evaluate_network (
462469 str (config ),
463470 Shuffles = Shuffles ,
@@ -473,6 +480,7 @@ def evaluate_network(
473480 )
474481 elif engine == Engine .PYTORCH :
475482 from deeplabcut .pose_estimation_pytorch .apis import evaluate_network
483+
476484 _update_device (gputouse , torch_kwargs )
477485 return evaluate_network (
478486 config ,
@@ -553,6 +561,7 @@ def return_evaluate_network_data(
553561
554562 if engine == Engine .TF :
555563 from deeplabcut .pose_estimation_tensorflow import return_evaluate_network_data
564+
556565 return return_evaluate_network_data (
557566 config ,
558567 shuffle = shuffle ,
@@ -817,6 +826,7 @@ def analyze_videos(
817826
818827 if engine == Engine .TF :
819828 from deeplabcut .pose_estimation_tensorflow import analyze_videos
829+
820830 kwargs = {}
821831 if use_openvino is not None : # otherwise default comes from tensorflow API
822832 kwargs ["use_openvino" ] = use_openvino
@@ -847,6 +857,7 @@ def analyze_videos(
847857 )
848858 elif engine == Engine .PYTORCH :
849859 from deeplabcut .pose_estimation_pytorch .apis import analyze_videos
860+
850861 _update_device (gputouse , torch_kwargs )
851862
852863 if use_shelve :
@@ -859,7 +870,8 @@ def analyze_videos(
859870 print (
860871 f"You called analyze_videos with parameters ``batchsize={ batchsize } "
861872 f"`` and batch_size={ torch_kwargs ['batch_size' ]} . Only one is "
862- f"needed/used. Using batch size { torch_kwargs ['batch_size' ]} " )
873+ f"needed/used. Using batch size { torch_kwargs ['batch_size' ]} "
874+ )
863875 else :
864876 torch_kwargs ["batch_size" ] = batchsize
865877
@@ -910,6 +922,7 @@ def create_tracking_dataset(
910922
911923 if engine == Engine .TF :
912924 from deeplabcut .pose_estimation_tensorflow import create_tracking_dataset
925+
913926 return create_tracking_dataset (
914927 config ,
915928 videos ,
@@ -997,6 +1010,7 @@ def analyze_time_lapse_frames(
9971010
9981011 if engine == Engine .TF :
9991012 from deeplabcut .pose_estimation_tensorflow import analyze_time_lapse_frames
1013+
10001014 return analyze_time_lapse_frames (
10011015 config ,
10021016 directory ,
@@ -1112,6 +1126,7 @@ def convert_detections2tracklets(
11121126
11131127 if engine == Engine .TF :
11141128 from deeplabcut .pose_estimation_tensorflow import convert_detections2tracklets
1129+
11151130 return convert_detections2tracklets (
11161131 config ,
11171132 videos ,
@@ -1214,6 +1229,7 @@ def extract_maps(
12141229
12151230 if engine == Engine .TF :
12161231 from deeplabcut .pose_estimation_tensorflow import extract_maps
1232+
12171233 return extract_maps (
12181234 config ,
12191235 shuffle = shuffle ,
@@ -1228,7 +1244,9 @@ def extract_maps(
12281244
12291245
12301246def visualize_scoremaps (
1231- image : np .ndarray , scmap : np .ndarray , engine : Engine = DEFAULT_ENGINE ,
1247+ image : np .ndarray ,
1248+ scmap : np .ndarray ,
1249+ engine : Engine = DEFAULT_ENGINE ,
12321250):
12331251 """
12341252 This function is only implemented for tensorflow models/shuffles, and will throw
@@ -1237,6 +1255,7 @@ def visualize_scoremaps(
12371255 if engine == Engine .TF :
12381256 # TODO: also works for Pytorch, but should not import as then requires TF
12391257 from deeplabcut .pose_estimation_tensorflow import visualize_scoremaps
1258+
12401259 return visualize_scoremaps (image , scmap )
12411260
12421261 raise NotImplementedError (f"This function is not implemented for { engine } " )
@@ -1257,7 +1276,10 @@ def visualize_locrefs(
12571276 """
12581277 if engine == Engine .TF :
12591278 from deeplabcut .pose_estimation_tensorflow import visualize_locrefs
1260- return visualize_locrefs (image , scmap , locref_x , locref_y , step = step , zoom_width = zoom_width )
1279+
1280+ return visualize_locrefs (
1281+ image , scmap , locref_x , locref_y , step = step , zoom_width = zoom_width
1282+ )
12611283
12621284 raise NotImplementedError (f"This function is not implemented for { engine } " )
12631285
@@ -1275,6 +1297,7 @@ def visualize_paf(
12751297 """
12761298 if engine == Engine .TF :
12771299 from deeplabcut .pose_estimation_tensorflow import visualize_paf
1300+
12781301 return visualize_paf (image , paf , step = step , colors = colors )
12791302
12801303 raise NotImplementedError (f"This function is not implemented for { engine } " )
@@ -1350,6 +1373,7 @@ def extract_save_all_maps(
13501373
13511374 if engine == Engine .TF :
13521375 from deeplabcut .pose_estimation_tensorflow import extract_save_all_maps
1376+
13531377 return extract_save_all_maps (
13541378 config ,
13551379 shuffle = shuffle ,
@@ -1443,6 +1467,7 @@ def export_model(
14431467
14441468 if engine == Engine .TF :
14451469 from deeplabcut .pose_estimation_tensorflow import export_model
1470+
14461471 return export_model (
14471472 cfg_path = cfg_path ,
14481473 shuffle = shuffle ,
0 commit comments