Skip to content

Commit 13d9565

Browse files
committed
update: optimize the code
1 parent c6148ad commit 13d9565

File tree

1 file changed

+5
-7
lines changed
  • pytorch_optimizer/optimizer

1 file changed

+5
-7
lines changed

pytorch_optimizer/optimizer/ano.py

Lines changed: 5 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -121,14 +121,12 @@ def step(self, closure: Closure = None) -> Loss:
121121
exp_avg.mul_(beta1).add_(grad, alpha=1.0 - beta1)
122122

123123
square_grad = grad.square()
124-
sign_term = torch.sign(square_grad - exp_avg_sq)
125-
exp_avg_sq.mul_(beta2).add_(sign_term * square_grad, alpha=1.0 - beta2)
126-
127-
v_hat = exp_avg_sq / bias_correction2
128-
adjusted_lr = group['lr'] / v_hat.sqrt_().add_(group['eps'])
124+
exp_avg_sq.mul_(beta2).addcmul_(
125+
torch.sign(square_grad - exp_avg_sq), square_grad, value=1.0 - beta2
126+
)
129127

130-
update = adjusted_lr * grad.abs() * exp_avg.sign()
128+
de_nom = square_grad.copy_(exp_avg_sq).div_(bias_correction2).sqrt_().add_(group['eps'])
131129

132-
p.add_(update, alpha=-1.0)
130+
p.addcdiv_(grad.abs().mul_(exp_avg.sign()), de_nom, value=-group['lr'])
133131

134132
return loss

0 commit comments

Comments
 (0)