File tree Expand file tree Collapse file tree 1 file changed +5
-7
lines changed
pytorch_optimizer/optimizer Expand file tree Collapse file tree 1 file changed +5
-7
lines changed Original file line number Diff line number Diff line change @@ -121,14 +121,12 @@ def step(self, closure: Closure = None) -> Loss:
121121 exp_avg .mul_ (beta1 ).add_ (grad , alpha = 1.0 - beta1 )
122122
123123 square_grad = grad .square ()
124- sign_term = torch .sign (square_grad - exp_avg_sq )
125- exp_avg_sq .mul_ (beta2 ).add_ (sign_term * square_grad , alpha = 1.0 - beta2 )
126-
127- v_hat = exp_avg_sq / bias_correction2
128- adjusted_lr = group ['lr' ] / v_hat .sqrt_ ().add_ (group ['eps' ])
124+ exp_avg_sq .mul_ (beta2 ).addcmul_ (
125+ torch .sign (square_grad - exp_avg_sq ), square_grad , value = 1.0 - beta2
126+ )
129127
130- update = adjusted_lr * grad . abs () * exp_avg . sign ( )
128+ de_nom = square_grad . copy_ ( exp_avg_sq ). div_ ( bias_correction2 ). sqrt_ (). add_ ( group [ 'eps' ] )
131129
132- p .add_ ( update , alpha = - 1.0 )
130+ p .addcdiv_ ( grad . abs (). mul_ ( exp_avg . sign ()), de_nom , value = - group [ 'lr' ] )
133131
134132 return loss
You can’t perform that action at this time.
0 commit comments