2 changes: 1 addition & 1 deletion examples/gpt/single/README.md
@@ -132,7 +132,7 @@ GPT training uses the AdamW optimizer and cosine learning-rate decay by default; here, via
| output_dir | Directory where output files are written |
| ckpt_dir | Directory to load checkpoints from |
| fused_linear | Whether to use fused_linear instead of the standard Linear to speed up training. Note: requires Paddle built with CUDA 11.6 or later. |
| tensor_fusion | Whether to use the tensor_fusion feature to speed up training |

## How to Run

1 change: 1 addition & 0 deletions examples/gpt/single/configs_1.3B_single_card.yaml
@@ -25,6 +25,7 @@ PreTraining:
output_dir: ./output
ckpt_dir:
fused_linear: True
tensor_fusion: True

Model:
vocab_size: 50304
1 change: 1 addition & 0 deletions examples/gpt/single/configs_345m_single_card.yaml
@@ -25,6 +25,7 @@ PreTraining:
output_dir: ./output
ckpt_dir:
fused_linear: True
tensor_fusion: True

Model:
vocab_size: 50304
16 changes: 11 additions & 5 deletions examples/gpt/single/gpt_module.py
@@ -20,6 +20,7 @@
from fleetx.utils import logger
from fleetx.optim import lr_scheduler as lr
from fleetx.core.module.basic_module import BasicModule
from fleetx.utils.tensor_fusion_helper import fused_parameters


class GPTModule(BasicModule):
@@ -61,6 +62,8 @@ def training_step_end(self, loss, epoch, step, reader_cost, train_cost):
def configure_optimizers(self):
if self.args.decay_steps is None:
self.args.decay_steps = self.args.max_steps
if self.args.tensor_fusion:
decay_fused_tensors, all_fused_tensors = fused_parameters(self.model)
warmup_step = self.args.warmup_rate * self.args.decay_steps
lr_scheduler = lr.CosineAnnealingWithWarmupDecay(
max_lr=self.args.max_lr,
@@ -74,17 +77,20 @@ def configure_optimizers(self):

# Generate parameter names needed to perform weight decay.
# All bias and LayerNorm parameters are excluded.
decay_params = [
p.name for n, p in self.model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
if self.args.tensor_fusion:
decay_params = [p.name for p in decay_fused_tensors]
else:
decay_params = [
p.name for n, p in self.model.named_parameters()
if not any(nd in n for nd in ["bias", "norm"])
]
optimizer = paddle.optimizer.AdamW(
learning_rate=lr_scheduler
if lr_scheduler is not None else self.args.max_lr,
beta1=self.args.adam_beta1,
beta2=self.args.adam_beta2,
epsilon=self.args.adam_epsilon,
parameters=self.model.parameters(),
parameters=all_fused_tensors if self.args.tensor_fusion else self.model.parameters(),
weight_decay=self.args.weight_decay,
grad_clip=clip,
apply_decay_param_fun=lambda x: x in decay_params,
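For context, a minimal sketch of the pattern `configure_optimizers` follows when `tensor_fusion` is enabled: the optimizer is built over the fused storage buffers returned by `fused_parameters` rather than the raw parameter list. The toy `Sequential` model and hyperparameters below are placeholders, and a GPU build of Paddle is assumed since the helper allocates its storages on `"gpu"`.

```python
import paddle
from fleetx.utils.tensor_fusion_helper import fused_parameters

# Hypothetical stand-in for the GPT model; any paddle.nn.Layer works here.
model = paddle.nn.Sequential(
    paddle.nn.Linear(16, 16),
    paddle.nn.LayerNorm(16),
)

# Split parameters into weight-decay / no-decay groups and pack each group
# into contiguous storage buffers.
decay_fused, all_fused = fused_parameters(model)
decay_names = [p.name for p in decay_fused]

# The optimizer then updates the fused buffers instead of the raw parameters.
optimizer = paddle.optimizer.AdamW(
    learning_rate=1e-4,
    parameters=all_fused,
    weight_decay=0.01,
    apply_decay_param_fun=lambda name: name in decay_names,
)
```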
3 changes: 3 additions & 0 deletions examples/gpt/tools.py
@@ -91,6 +91,9 @@ def add_dict(config, k, v):
add_dict(yaml_dict, "PreTraining", global_config["PreTraining"])
args = argparse.Namespace(**yaml_dict)

if not hasattr(args, 'recompute_granularity'):
args.recompute_granularity = None

args.test_iters = args.eval_iters * 10

if args.fused_linear and not is_fused_matmul_bias_supported():
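As an aside, the `hasattr` guard added above is a small backward-compatibility pattern: config keys that older YAML files do not define are backfilled on the parsed namespace. A minimal, self-contained sketch (the `Namespace` contents here are hypothetical):

```python
import argparse

# Hypothetical args parsed from a YAML file that predates the new option.
args = argparse.Namespace(fused_linear=True)

# Backfill keys that older configs may not define yet.
if not hasattr(args, 'recompute_granularity'):
    args.recompute_granularity = None

print(args.recompute_granularity)  # None
```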
28 changes: 14 additions & 14 deletions fleetx/core/engine/eager_engine.py
@@ -53,6 +53,20 @@ def __init__(self, module, configs=None):
"'loss_fn' must be sub classes of `paddle.nn.Layer` or any callable function, but got: {module.loss_fn.__class__.__name__}."
)

self._configs = configs

# configs
for k, v in configs.items():
self.__dict__.update({'_{}'.format(k): v})

if self._use_pure_fp16:
self._scaler = paddle.amp.GradScaler(
init_loss_scaling=self._scale_loss)
self._module.model = paddle.amp.decorate(
models=self._module.model, level='O2', save_dtype='float32')
else:
self._scaler = None

optimizers = module.configure_optimizers()

if optimizers and isinstance(optimizers, (
@@ -77,20 +91,6 @@
"Only support optimizer or/and lr_scheduler as outputs of `configure_optimizers`."
)

self._configs = configs

# configs
for k, v in configs.items():
self.__dict__.update({'_{}'.format(k): v})

if self._use_pure_fp16:
self._scaler = paddle.amp.GradScaler(
init_loss_scaling=self._scale_loss)
self._module.model = paddle.amp.decorate(
models=self._module.model, level='O2', save_dtype='float32')
else:
self._scaler = None

self._module.global_step = 0

def fit(self, epoch=1, train_data_loader=None, valid_data_loader=None):
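The moved block now runs before `module.configure_optimizers()`, presumably so that config-driven setup such as `paddle.amp.decorate` is applied to the model before the optimizer captures its parameter list. A minimal sketch of that ordering, assuming a GPU build of Paddle (the `Linear` model and learning rate are placeholders):

```python
import paddle

model = paddle.nn.Linear(8, 8)  # hypothetical stand-in for module.model

# Decorate first, so the optimizer is built over the decorated (O2) parameters.
model = paddle.amp.decorate(models=model, level='O2', save_dtype='float32')
optimizer = paddle.optimizer.AdamW(
    parameters=model.parameters(), learning_rate=1e-4)
```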
125 changes: 125 additions & 0 deletions fleetx/utils/tensor_fusion_helper.py
@@ -0,0 +1,125 @@
# Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import paddle
from paddle.framework import core
import numpy as np
from collections import OrderedDict

from paddle.fluid.framework import in_dygraph_mode, _in_legacy_dygraph

if in_dygraph_mode():
from paddle.distributed.fleet.meta_parallel.sharding.group_sharded_storage import ParamStorage, GradStorage
elif _in_legacy_dygraph():
from paddle.distributed.fleet.utils.internal_storage import ParamStorage, GradStorage

from paddle.distributed.fleet.meta_parallel.sharding.sharding_utils import Type

alignment = {
"gpu": 256,
}
align = {
Type.fp16.value: 2,
Type.fp32.value: 4,
}

def assign_group_by_size(parameters, group_size=256 * 1024 * 1024):
is_sparse_gradient = [False] * len(parameters)

if in_dygraph_mode():
group_indices = core.eager_assign_group_by_size(
parameters, is_sparse_gradient, [group_size, group_size])
elif _in_legacy_dygraph():
group_indices = core.assign_group_by_size(parameters,
is_sparse_gradient,
[group_size, group_size])

var_groups = OrderedDict()
for group_idx, indices in enumerate(group_indices):
for index in indices:
var_groups.setdefault(group_idx, []).append(parameters[index])
return var_groups

def flatten_dense_tensors(parameters):
_buffer_size = 0
_param2align = {}
dtype = parameters[0].dtype

for param in parameters:
assert param.trainable, "param must be trainable..."
size = np.prod(param.shape) * align[dtype]
remaining = size % alignment["gpu"]
ali = 0 if remaining == 0 else alignment["gpu"] - remaining
align_ = ali // align[dtype]
_buffer_size += np.prod(param.shape) + align_
_param2align[param.name] = align_

param_storage = ParamStorage(size=_buffer_size, dtype=dtype, device="gpu")

param_storage.add_rank_params(parameters, _param2align)

# process gradient
grad_storage = GradStorage(size=_buffer_size,
dtype=dtype,
device="gpu",
destination="0",
parm2align=_param2align)

for param in parameters:
grad_storage.add_grad(param, _param2align[param.name])

# param_storage --> grad_storage
param_storage.buffer._copy_gradient_from(grad_storage.buffer)
param_storage.buffer.stop_gradient = False
return param_storage, grad_storage

def obtain_storage(parameters):
if len(parameters) < 1:
return []

var_groups = assign_group_by_size(parameters)
storage = []
for group_idx, parameters in var_groups.items():
param_storage, grad_storage = flatten_dense_tensors(parameters)
storage.append(param_storage.buffer)
return storage

def fused_parameters(model, use_sharding=False):
decay_params = []
other_params = []

for param in model.parameters():
if not any(nd in param.name for nd in ["bias", "norm"]):
decay_params.append(param)
else:
other_params.append(param)

print("all parameters length:", len(model.parameters()))
print("decay_params len: {}, params len: {}".format(len(decay_params), len(other_params)))

decay_fused = decay_params if use_sharding else obtain_storage(decay_params)
other_fused = other_params if use_sharding else obtain_storage(other_params)
all_fused = decay_fused + other_fused

return decay_fused, all_fused

def all_reduce_parameters(params, group):
if group.nranks < 2:
return

div_factor = 1.0 / group.nranks
with paddle.framework.no_grad():
for p in params:
grad = p.grad.scale_(div_factor)
paddle.distributed.all_reduce(grad, group=group)
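
To make the buffer-size arithmetic in `flatten_dense_tensors` concrete, here is a standalone sketch of the same padding rule (pure NumPy, no GPU needed; the example shapes are arbitrary): each fp32 parameter is padded so that its byte size is a multiple of the 256-byte GPU alignment.

```python
import numpy as np

ALIGNMENT_BYTES = 256  # alignment["gpu"] in the helper above
ELEM_BYTES = 4         # bytes per fp32 element (align[Type.fp32.value])

def padded_numel(shape):
    """Elements a parameter occupies in the fused buffer after padding."""
    numel = int(np.prod(shape))
    remaining = numel * ELEM_BYTES % ALIGNMENT_BYTES
    pad_bytes = 0 if remaining == 0 else ALIGNMENT_BYTES - remaining
    return numel + pad_bytes // ELEM_BYTES

print(padded_numel([1000]))  # 4000 bytes -> padded to 4096 bytes -> 1024 elements
print(padded_numel([1024]))  # already 256-byte aligned -> stays 1024 elements
```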