1919from deeplabcut .utils .auxfun_videos import VideoReader
2020from deeplabcut .utils .auxiliaryfunctions import attempt_to_make_folder
2121from matplotlib .path import Path
22- from matplotlib .widgets import Slider , LassoSelector , Button , CheckButtons
22+ from matplotlib .widgets import Slider , LassoSelector , Button , CheckButtons , TextBox
2323from PySide6 .QtWidgets import QMessageBox
2424from PySide6 .QtCore import QMutex
2525
@@ -327,6 +327,9 @@ def __init__(self, manager, videoname, trail_len=50):
327327
328328 self .dps = []
329329
330+ self .swap_id1 = None
331+ self .swap_id2 = None
332+
330333 def _prepare_canvas (self , manager , fig ):
331334 params = {
332335 "keymap.save" : "s" ,
@@ -358,7 +361,7 @@ def _prepare_canvas(self, manager, fig):
358361
359362 img = self .video .read_frame ()
360363 self .im = self .ax1 .imshow (img )
361- self .scat = self .ax1 .scatter ([], [], s = self .dotsize ** 2 , picker = True )
364+ self .scat = self .ax1 .scatter ([], [], s = self .dotsize ** 2 , picker = True )
362365 self .scat .set_offsets (manager .xy [:, 0 ])
363366 self .scat .set_color (self .colors )
364367 self .trails = sum (
@@ -374,6 +377,7 @@ def _prepare_canvas(self, manager, fig):
374377 )
375378 self .vline_x = self .ax2 .axvline (0 , 0 , 1 , c = "k" , ls = ":" )
376379 self .vline_y = self .ax3 .axvline (0 , 0 , 1 , c = "k" , ls = ":" )
380+
377381 custom_lines = [
378382 plt .Line2D ([0 ], [0 ], color = self .cmap (i ), lw = 4 )
379383 for i in range (len (manager .individuals ))
@@ -420,10 +424,15 @@ def _prepare_canvas(self, manager, fig):
420424 self .ax_flag = self .fig .add_axes ([0.75 , 0.1 , 0.05 , 0.03 ])
421425 self .ax_save = self .fig .add_axes ([0.80 , 0.1 , 0.05 , 0.03 ])
422426 self .ax_help = self .fig .add_axes ([0.85 , 0.1 , 0.05 , 0.03 ])
427+ self .ax_swap = self .fig .add_axes ([0.90 , 0.1 , 0.05 , 0.03 ]) # New button
428+
423429 self .save_button = Button (self .ax_save , "Save" , color = "darkorange" )
424430 self .save_button .on_clicked (self .save )
425431 self .help_button = Button (self .ax_help , "Help" )
426432 self .help_button .on_clicked (self .display_help )
433+ self .swap_button = Button (self .ax_swap , "Swap" ) # New button
434+ self .swap_button .on_clicked (self .swap_tracklets ) # Placeholder action
435+
427436 self .drag_toggle = CheckButtons (self .ax_drag , ["Drag" ])
428437 self .drag_toggle .on_clicked (self .toggle_draggable_points )
429438 self .flag_button = Button (self .ax_flag , "Flag" )
@@ -441,9 +450,75 @@ def _prepare_canvas(self, manager, fig):
441450 self .ax1_background = self .fig .canvas .copy_from_bbox (self .ax1 .bbox )
442451 self .fig .show ()
443452
453+ # Create dropdowns for selecting tracklets to swap, placing them near the swap button
454+ self .ax_dropdown1 = self .fig .add_axes ([0.9 , 0.15 , 0.05 , 0.03 ])
455+ self .ax_dropdown2 = self .fig .add_axes ([0.9 , 0.20 , 0.05 , 0.03 ])
456+ self .textbox1 = TextBox (self .ax_dropdown1 , "ID 1" )
457+ self .textbox2 = TextBox (self .ax_dropdown2 , "ID 2" )
458+ self .textbox1 .on_submit (self .set_swap_id1 )
459+ self .textbox2 .on_submit (self .set_swap_id2 )
460+
444461 def show (self , fig = None ):
445462 self ._prepare_canvas (self .manager , fig )
446463
464+ def swap_tracklets (self , event ):
465+ if self .swap_id1 is not None and self .swap_id2 is not None :
466+
467+ # Get tracklet indices for each individual
468+ inds1 = [
469+ k
470+ for k in range (len (self .manager .tracklet2id ))
471+ if self .manager .tracklet2id [k ] == self .swap_id1
472+ ]
473+ inds2 = [
474+ k
475+ for k in range (len (self .manager .tracklet2id ))
476+ if self .manager .tracklet2id [k ] == self .swap_id2
477+ ]
478+
479+ print (f"Swapping tracklets { self .swap_id1 } and { self .swap_id2 } " )
480+
481+ # Frames to swap
482+ frames = []
483+ if len (self .cuts ) == 2 :
484+ frames = list (range (min (self .cuts ), max (self .cuts ) + 1 ))
485+ elif len (self .cuts ) == 1 :
486+ frames = [self .cuts [0 ]]
487+ else :
488+ frames = list (range (self .curr_frame , self .manager .nframes ))
489+
490+ # Swap the tracklets
491+ for i in range (min (len (inds1 ), len (inds2 ))):
492+ self .manager .swap_tracklets (inds1 [i ], inds2 [i ], frames )
493+ self .display_traces ()
494+ self .slider .set_val (self .curr_frame )
495+
496+ def set_swap_id1 (self , val ):
497+ # check that the input is a valid from the list of individuals
498+ if int (val ) in self .manager .tracklet2id :
499+ self .swap_id1 = int (val )
500+ print ("ID 1 set." )
501+ else :
502+ print (
503+ f"Invalid ID. Please select a valid ID from the list of individuals: { set (self .manager .tracklet2id )} "
504+ )
505+ self .swap_id1 = None
506+
507+ def set_swap_id2 (self , val ):
508+ # check that the input is a valid from the list of individuals
509+ if int (val ) in self .manager .tracklet2id :
510+ self .swap_id2 = int (val )
511+ print ("ID 2 set." )
512+ else :
513+ print (
514+ f"Invalid ID. Please select a valid ID from the list of individuals: { set (self .manager .tracklet2id )} "
515+ )
516+ self .swap_id2 = None
517+
518+ def terminate (self , event ):
519+ plt .close (self .fig )
520+ self .player .terminate ()
521+
447522 def fill_shaded_areas (self ):
448523 self .clean_collections ()
449524 if self .picked_pair :
@@ -587,9 +662,9 @@ def on_press(self, event):
587662 if len (self .cuts ) > 1 :
588663 self .cuts .sort ()
589664 if self .picked_pair :
590- self .manager .tracklet_swaps [self .picked_pair ][
591- self .cuts
592- ] = ~ self . manager . tracklet_swaps [ self . picked_pair ][ self . cuts ]
665+ self .manager .tracklet_swaps [self .picked_pair ][self . cuts ] = (
666+ ~ self .manager . tracklet_swaps [ self . picked_pair ][ self . cuts ]
667+ )
593668 self .fill_shaded_areas ()
594669 self .cuts = []
595670 for line in self .ax_slider .lines :
@@ -807,7 +882,7 @@ def on_change(self, val):
807882
808883 def update_dotsize (self , val ):
809884 self .dotsize = val
810- self .scat .set_sizes ([self .dotsize ** 2 ])
885+ self .scat .set_sizes ([self .dotsize ** 2 ])
811886
812887 @staticmethod
813888 def calc_distance (x1 , y1 , x2 , y2 ):
0 commit comments