This repository uses the native PyTorch Automatic Mixed Precision (AMP) package for fast training, and in this document we go through the advantages of using it.
Remark 1: Note that the Apex module developed by Nvidia provides the same functionality, but we use the native PyTorch module instead to keep the code more compact.
The main advantage of using Automatic Mixed Precision is speed: training throughput is roughly doubled, at a minimum.
Remark 2: Be mindful that Automatic Mixed Precision is used only in the training process.
We mainly follow PyTorch's "Typical Mixed Precision Training" documentation.
Which modules contain the classes and code snippets that implement PyTorch Automatic Mixed Precision?
In our code, two modules deal with Automatic Mixed Precision:

- `main.py`
- `utils/train_validation.py`
- To get started with Automatic Mixed Precision, create a `GradScaler` instance in the `main.py` file before the beginning of the training:

```python
scaler = torch.cuda.amp.GradScaler()
```
- Again in `main.py`, pass the `scaler` instance to the `TraVal` class to create an instance of training and validation:

```python
traval = TraVal(model, train_loader, optimizer,
                criterion, scaler,
                args, validation_loader,
                writer=writer if args.local_rank == 0 else None,
                curr_scen_name=scenario.curr_scen_name if args.local_rank == 0 else None)
```
- In `train_validation.py`, import the `autocast` class from `torch.cuda.amp`:

```python
from torch.cuda.amp import autocast
```
- In the training process, only forward passes are recommended to run with autocasting; backward passes should not run under autocast. Therefore, `autocast` wraps only the forward pass(es) of the network, including the loss computation(s):

```python
with autocast():
    output = self.model(input)
    loss = self.criterion(output, target)
```
- Scale the loss before calling backward: training with Automatic Mixed Precision requires scaling the loss to prevent the gradients from underflowing:

```python
self.scaler.scale(loss).backward()
```
- Finally, `self.scaler.step(self.optimizer)` first unscales the gradients of the optimizer's assigned parameters. If these gradients do not contain `inf`s or `NaN`s, `self.optimizer.step()` is then called; otherwise, `self.optimizer.step()` is skipped.
- Update the scale for the next iteration:

```python
self.scaler.update()
```
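The steps above can be combined into a minimal, self-contained training-step sketch. The model, data, and hyperparameters here are illustrative placeholders, not taken from this repository, and autocast and the scaler are disabled when no GPU is available so the snippet also runs on CPU:

```python
import torch
import torch.nn as nn

device = "cuda" if torch.cuda.is_available() else "cpu"
use_amp = device == "cuda"  # CUDA autocast/GradScaler are no-ops when disabled

# Placeholder model, optimizer, and loss (not from the repo)
model = nn.Linear(10, 2).to(device)
optimizer = torch.optim.SGD(model.parameters(), lr=0.1)
criterion = nn.CrossEntropyLoss()
scaler = torch.cuda.amp.GradScaler(enabled=use_amp)

for _ in range(3):  # a few dummy iterations
    input = torch.randn(4, 10, device=device)
    target = torch.randint(0, 2, (4,), device=device)

    optimizer.zero_grad()
    # Only the forward pass and loss computation run under autocast
    with torch.cuda.amp.autocast(enabled=use_amp):
        output = model(input)
        loss = criterion(output, target)

    scaler.scale(loss).backward()  # scaled backward to avoid gradient underflow
    scaler.step(optimizer)         # unscales grads; skips the step on inf/NaN
    scaler.update()                # adjusts the scale for the next iteration
```

With `enabled=False`, `GradScaler` and `autocast` become transparent pass-throughs, so the same loop works unchanged with or without mixed precision.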