@@ -228,7 +228,7 @@ def cool_down(self, dataset, epochs=1):
228228 for c in range (self .task .get_n_classes ()[0 ]):
229229 if c == cls :
230230 ds = dataset .get_k_image_of_class (cl = cls , k = K ) # get K images of class c
231- wc = get_prototype (self . model , ds , cls , self .device , return_all = True )
231+ wc = get_prototype (model , ds , cls , self .device , return_all = True )
232232 if wc is None :
233233 # print("WC is None!!")
234234 weight [c ] = F .normalize (model .cls .cls [0 ].weight [c ], dim = 0 )
@@ -337,7 +337,7 @@ def cool_down(self, dataset, epochs=1):
337337 state [k [7 :]] = v
338338 model .load_state_dict (state , strict = True )
339339 model = model .to (self .device )
340- model .body . eval ()
340+ model .eval ()
341341 for p in model .body .parameters ():
342342 p .requires_grad = False
343343
@@ -358,13 +358,14 @@ def cool_down(self, dataset, epochs=1):
358358 optimizer .zero_grad ()
359359
360360 for e in range (self .EPISODE ):
361+ model .eval ()
361362 weight = torch .zeros_like (model .cls .cls [0 ].weight )
362363 K = random .choice ([1 , 2 , 5 ])
363364 cls = random .choice (classes ) # sample N classes
364365 for c in range (self .task .get_n_classes ()[0 ]):
365366 if c == cls :
366367 ds = dataset .get_k_image_of_class (cl = cls , k = K ) # get K images of class c
367- wc = get_prototype (self . model , ds , cls , self .device , return_all = True )
368+ wc = get_prototype (model , ds , cls , self .device , return_all = True )
368369 if wc is None :
369370 # print("WC is None!!")
370371 weight [c ] = F .normalize (model .cls .cls [0 ].weight [c ], dim = 0 )
@@ -374,6 +375,7 @@ def cool_down(self, dataset, epochs=1):
374375 else :
375376 weight [c ] = F .normalize (model .cls .cls [0 ].weight [c ], dim = 0 )
376377
378+ model .train ()
377379 # get a batch of images from dataloader
378380 it , batch = get_batch (it , dataloader )
379381 ds = dataset .get_k_image_of_class (cl = cls , k = self .BATCH_SIZE ) # get K images of class c
0 commit comments