class CenterLoss(nn.Module):
    """Center loss: pull each feature vector toward a learnable center for its class.

    One center per class is stored in ``self.centers`` with shape
    ``(num_class, num_feature)``. The centers are initialized from a standard
    normal distribution. They are an ``nn.Parameter``, so an optimizer over
    this module's parameters will update them.

    Args:
        num_class: Number of classes, i.e. number of rows in ``centers``.
        num_feature: Dimensionality of each center. It must equal the size of
            the last dimension of the features passed to ``forward``.
    """

    def __init__(self, num_class=10, num_feature=2):
        super().__init__()
        self.num_class = num_class
        self.num_feature = num_feature
        self.centers = nn.Parameter(torch.randn(self.num_class, self.num_feature))

    def forward(self, x, labels):
        """Return the mean squared distance between features and their class centers.

        Args:
            x: Feature tensor whose last dimension has size ``num_feature``.
                It is presumably ``(batch, num_feature)``; only the last
                dimension is checked.
            labels: Integer class indices used to gather one center per
                feature vector. Values must lie in ``[0, num_class)`` because
                they index rows of ``centers``.

        Returns:
            The per-sample squared Euclidean distances, each clamped to
            ``[1e-12, 1e12]`` and then averaged over the last remaining
            dimension. For 2-D ``x`` this is a scalar.

        Raises:
            ValueError: If ``x`` is 0-dimensional or its last dimension is not
                ``num_feature``.
        """
        # Without this check, a size-1 feature dimension would broadcast
        # against the gathered centers and silently produce a wrong loss.
        if x.dim() == 0 or x.shape[-1] != self.num_feature:
            raise ValueError(
                f"expected features with last dimension {self.num_feature}, "
                f"got shape {tuple(x.shape)}"
            )
        center = self.centers[labels]
        dist = (x - center).pow(2).sum(dim=-1)
        # Keep each distance within [1e-12, 1e12]. Distances below 1e-12 get
        # zero gradient through the clamp.
        loss = torch.clamp(dist, min=1e-12, max=1e12).mean(dim=-1)
        return loss