Skip to content

Commit 3d78552

Browse files
fcdl94fcdl94
authored andcommitted
Fix TrainedWI model
1 parent 208510e commit 3d78552

File tree

1 file changed

+5
-3
lines changed

1 file changed

+5
-3
lines changed

methods/imprinting.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -228,7 +228,7 @@ def cool_down(self, dataset, epochs=1):
228228
for c in range(self.task.get_n_classes()[0]):
229229
if c == cls:
230230
ds = dataset.get_k_image_of_class(cl=cls, k=K) # get K images of class c
231-
wc = get_prototype(self.model, ds, cls, self.device, return_all=True)
231+
wc = get_prototype(model, ds, cls, self.device, return_all=True)
232232
if wc is None:
233233
# print("WC is None!!")
234234
weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
@@ -337,7 +337,7 @@ def cool_down(self, dataset, epochs=1):
337337
state[k[7:]] = v
338338
model.load_state_dict(state, strict=True)
339339
model = model.to(self.device)
340-
model.body.eval()
340+
model.eval()
341341
for p in model.body.parameters():
342342
p.requires_grad = False
343343

@@ -358,13 +358,14 @@ def cool_down(self, dataset, epochs=1):
358358
optimizer.zero_grad()
359359

360360
for e in range(self.EPISODE):
361+
model.eval()
361362
weight = torch.zeros_like(model.cls.cls[0].weight)
362363
K = random.choice([1, 2, 5])
363364
cls = random.choice(classes) # sample N classes
364365
for c in range(self.task.get_n_classes()[0]):
365366
if c == cls:
366367
ds = dataset.get_k_image_of_class(cl=cls, k=K) # get K images of class c
367-
wc = get_prototype(self.model, ds, cls, self.device, return_all=True)
368+
wc = get_prototype(model, ds, cls, self.device, return_all=True)
368369
if wc is None:
369370
# print("WC is None!!")
370371
weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
@@ -374,6 +375,7 @@ def cool_down(self, dataset, epochs=1):
374375
else:
375376
weight[c] = F.normalize(model.cls.cls[0].weight[c], dim=0)
376377

378+
model.train()
377379
# get a batch of images from dataloader
378380
it, batch = get_batch(it, dataloader)
379381
ds = dataset.get_k_image_of_class(cl=cls, k=self.BATCH_SIZE) # get K images of class c

0 commit comments

Comments
 (0)