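"""PyTorch FSDP2 example.

Demonstrates sharding a small Transformer with ``fully_shard``: meta-device
initialization, an optional bfloat16 mixed-precision policy, optional explicit
forward/backward prefetching, and checkpoint save/load (optionally through the
DCP API). Requires at least 2 GPUs; see the launch note at the bottom of this
file.
"""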
import argparse
import os

import torch
from checkpoint import Checkpointer
from model import ModelArgs, Transformer
from torch.distributed.fsdp import fully_shard, MixedPrecisionPolicy
from utils import inspect_mixed_precision, inspect_model


def verify_min_gpu_count(min_gpus: int = 2) -> bool:
    """Verify that at least `min_gpus` accelerators are available to run this distributed example."""
    has_gpu = torch.accelerator.is_available()
    gpu_count = torch.accelerator.device_count()
    return has_gpu and gpu_count >= min_gpus


def set_modules_to_forward_prefetch(model, num_to_forward_prefetch):
    """Tell each FSDP-wrapped layer to prefetch the parameters of the next
    `num_to_forward_prefetch` layers during the forward pass."""
    for i, layer in enumerate(model.layers):
        if i >= len(model.layers) - num_to_forward_prefetch:
            break
        layers_to_prefetch = [
            model.layers[i + j] for j in range(1, num_to_forward_prefetch + 1)
        ]
        layer.set_modules_to_forward_prefetch(layers_to_prefetch)


def set_modules_to_backward_prefetch(model, num_to_backward_prefetch):
    """Tell each FSDP-wrapped layer to prefetch the parameters of the previous
    `num_to_backward_prefetch` layers during the backward pass."""
    for i, layer in enumerate(model.layers):
        if i < num_to_backward_prefetch:
            continue
        layers_to_prefetch = [
            model.layers[i - j] for j in range(1, num_to_backward_prefetch + 1)
        ]
        layer.set_modules_to_backward_prefetch(layers_to_prefetch)


def main(args):
    _min_gpu_count = 2
    if not verify_min_gpu_count(min_gpus=_min_gpu_count):
        print(f"Unable to locate the required {_min_gpu_count} GPUs to run this example. Exiting.")
        return

    rank = int(os.environ["LOCAL_RANK"])
    if torch.accelerator.is_available():
        device_type = torch.accelerator.current_accelerator()
        device = torch.device(f"{device_type}:{rank}")
        # Bind this process to its local accelerator before initializing
        # collectives or allocating tensors.
        torch.accelerator.set_device_index(rank)
        print(f"Running on rank {rank} on device {device}")
    else:
        device = torch.device("cpu")
        print(f"Running on device {device}")

    backend = torch.distributed.get_default_backend_for_device(device)
    torch.distributed.init_process_group(backend=backend, device_id=device)

    torch.manual_seed(0)
    vocab_size = 1024
    batch_size = 32
    seq_len = 64
    model_args = ModelArgs(
        n_layers=10,
        n_heads=4,
        vocab_size=vocab_size,
        max_seq_len=seq_len,
        dropout_p=0,
    )
    # Construct the model on the meta device so no real memory is allocated
    # until the parameters are materialized after sharding.
    with torch.device("meta"):
        model = Transformer(model_args)

    fsdp_kwargs = {}
    if args.mixed_precision:
        # Compute in bfloat16 but reduce gradients in float32 for stability.
        fsdp_kwargs["mp_policy"] = MixedPrecisionPolicy(
            param_dtype=torch.bfloat16,
            reduce_dtype=torch.float32,
        )
    # Shard each transformer block individually, then the root module; the
    # root call picks up the remaining parameters (e.g. embeddings).
    for layer in model.layers:
        fully_shard(layer, **fsdp_kwargs)
    fully_shard(model, **fsdp_kwargs)
    inspect_model(model)

    if args.explicit_prefetching:
        set_modules_to_forward_prefetch(model, num_to_forward_prefetch=2)
        set_modules_to_backward_prefetch(model, num_to_backward_prefetch=2)

    checkpointer = Checkpointer("checkpoints", dcp_api=args.dcp_api)
    if checkpointer.last_training_time is None:
        # Fresh run: materialize the meta-device parameters, then initialize them.
        model.to_empty(device=device)
        model.reset_parameters()
    else:
        checkpointer.load_model(model)

    if args.mixed_precision:
        inspect_mixed_precision(model)

    optim = torch.optim.Adam(model.parameters(), lr=1e-2)
    if checkpointer.last_training_time is not None:
        checkpointer.load_optim(model, optim)

    # Train for a few steps on random token batches; real data loading is
    # outside the scope of this example.
    for _ in range(10):
        if args.explicit_prefetching:
            model.unshard()
        x = torch.randint(0, vocab_size, (batch_size, seq_len), device=device)
        loss = model(x).sum()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
        optim.step()
        optim.zero_grad()

    checkpointer.save(model, optim)
    torch.distributed.destroy_process_group()


if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="PyTorch FSDP2 example")
    parser.add_argument("--explicit-prefetching", action="store_true", default=False)
    parser.add_argument("--mixed-precision", action="store_true", default=False)
    parser.add_argument("--dcp-api", action="store_true", default=False)
    args = parser.parse_args()
    main(args)
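
# Launch note: this script reads LOCAL_RANK, so it is meant to be started by a
# distributed launcher. The command below is an assumption (adjust the GPU
# count and flags to your setup):
#
#   torchrun --nproc_per_node=2 example.py --explicit-prefetching --mixed-precision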