diff --git a/paddle/fluid/framework/distributed_strategy.proto b/paddle/fluid/framework/distributed_strategy.proto
index e6a7d74cc43433318ea825927e72c779b14ab43c..3627a8cf71c1e418dababee234deb55c1c99213e 100644
--- a/paddle/fluid/framework/distributed_strategy.proto
+++ b/paddle/fluid/framework/distributed_strategy.proto
@@ -1,4 +1,5 @@
 // Copyright (c) 2020 PaddlePaddle Authors. All Rights Reserved.
+// Copyright (c) 2021 NVIDIA Corporation. All rights reserved.
 //
 // Licensed under the Apache License, Version 2.0 (the "License");
 // you may not use this file except in compliance with the License.
@@ -41,12 +42,14 @@ message ShardingConfig {
   optional bool optimize_offload = 9 [ default = false ];
   optional bool pp_allreduce_in_optimize = 10 [ default = false ];
   optional int32 pp_degree = 11 [ default = 1 ];
+  optional bool optimize_cast = 12 [ default = false ];
 }
 
 message HybridConfig {
   optional int32 dp_degree = 1 [ default = -1 ];
   optional int32 mp_degree = 2 [ default = 1 ];
   optional int32 pp_degree = 3 [ default = 1 ];
+  optional int32 sharding_degree = 4 [ default = 1 ];
 }
 
 message AMPConfig {
@@ -109,6 +112,7 @@ message BuildStrategy {
   optional bool fuse_bn_add_act_ops = 10 [ default = true ];
   optional bool enable_auto_fusion = 11 [ default = false ];
   optional bool enable_addto = 12 [ default = false ];
+  optional bool fix_op_run_order = 13 [ default = false ];
 }
 
 message ExecutionStrategy {
@@ -118,6 +122,16 @@
   optional bool use_thread_barrier = 4 [ default = false ];
 }
 
+message GradientScaleConfig {
+  // Optional value: 'avg', 'sum' or 'customized'
+  // If avg, loss@grad will be divided by the number of devices,
+  // that is, the gradient will be accumulated and averaged among
+  // multiple devices.
+  // Else if sum, the gradient will be accumulated among multiple
+  // devices without averaging.
+  optional string scale_strategy = 1 [ default = 'avg' ];
+}
+
 message AsyncConfig {
   optional int32 k_steps = 1 [ default = -1 ];
   optional int32 max_merge_var_num = 2 [ default = 1 ];
@@ -133,10 +147,23 @@
   optional int32 use_ps_gpu = 12 [ default = 0 ];
 }
 
+message TrainerDescConfig {
+  optional string dump_fields_path = 1;
+  repeated string dump_fields = 2;
+  repeated string dump_param = 3;
+  repeated string stat_var_names = 4;
+}
+
 message PipelineConfig {
   optional int32 micro_batch_size = 1 [ default = 1 ];
   optional int32 accumulate_steps = 2 [ default = 1 ];
   optional string schedule_mode = 3 [ default = '1F1B' ];
+  optional bool p2p_cache_shape = 4 [ default = true ];
+}
+
+message TensorParallelConfig {
+  optional int32 tensor_parallel_degree = 1 [ default = 1 ];
+  optional int32 tensor_init_seed = 2 [ default = -1 ];
 }
 
 message DistributedStrategy {
@@ -168,7 +195,13 @@
   optional bool fp16_allreduce = 25 [ default = false ];
   optional bool sharding = 26 [ default = false ];
   optional float last_comm_group_size_MB = 27 [ default = 1 ];
-  optional bool find_unused_parameters = 28 [ default = true ];
+  optional bool find_unused_parameters = 28 [ default = false ];
+  optional bool tensor_parallel = 29 [ default = false ];
+  optional bool without_graph_optimization = 30 [ default = false ];
+  optional int32 fuse_grad_size_in_num = 31 [ default = 8 ];
+  optional bool calc_comm_same_stream = 32 [ default = false ];
+  optional bool asp = 33 [ default = false ];
+  optional bool fuse_grad_merge = 34 [ default = false ];
 
   optional RecomputeConfig recompute_configs = 101;
   optional AMPConfig amp_configs = 102;
@@ -182,8 +215,11 @@
   optional AdaptiveLocalSGDConfig adaptive_localsgd_configs = 110;
   optional ShardingConfig sharding_configs = 111;
   optional HybridConfig hybrid_configs = 112;
+  optional TensorParallelConfig tensor_parallel_configs = 113;
+  optional TrainerDescConfig trainer_desc_configs = 114;
   optional BuildStrategy build_strategy = 201;
   optional ExecutionStrategy execution_strategy = 202;
+  optional GradientScaleConfig gradient_scale_configs = 203;
 }
 
 message DistributedJobInfo {
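The new proto fields surface on the Python side through paddle.distributed.fleet.DistributedStrategy, whose config attributes accept plain dicts (the unit test below sets hybrid_configs and pipeline_configs exactly this way). A minimal usage sketch follows; the sharding_degree key and the gradient_scale_configs attribute are assumed to mirror the new proto fields, and the values are illustrative only:

    import paddle.distributed.fleet as fleet

    strategy = fleet.DistributedStrategy()
    strategy.hybrid_configs = {
        "dp_degree": 1,
        "mp_degree": 1,
        "pp_degree": 2,
        "sharding_degree": 1,  # new HybridConfig field added in this change
    }
    strategy.pipeline_configs = {
        "micro_batch_size": 2,
        "accumulate_steps": 4,
        # p2p_cache_shape defaults to true and caches p2p tensor shapes
    }
    # 'avg' divides the accumulated gradient by the device count; 'sum' only accumulates it.
    strategy.gradient_scale_configs = {"scale_strategy": "avg"}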
diff --git a/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
new file mode 100644
index 0000000000000000000000000000000000000000..706d64d8d35b6112f3feb270c4952a7c4276b00f
--- /dev/null
+++ b/python/paddle/distributed/fleet/meta_parallel/pipeline_parallel.py
@@ -0,0 +1,327 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+import paddle
+import paddle.fluid as fluid
+from .meta_parallel_base import MetaParallelBase
+from .pp_utils.utils import is_float_tensor, _initialize_recompute_hcg
+from .parallel_layers.pp_layers import PipelineLayer
+
+from ..utils.hybrid_parallel_util import broadcast_mp_parameters
+from ..utils.hybrid_parallel_util import broadcast_dp_parameters
+from ..utils.log_util import logger
+from ..meta_optimizers.dygraph_optimizer import HybridParallelOptimizer, HybridParallelGradScaler
+from .pp_utils import p2p_communication as p2p
+
+__all__ = []
+
+
+class PipelineParallel(MetaParallelBase):
+    def __init__(self, layers, hcg, strategy):
+        if not isinstance(layers, PipelineLayer):
+            raise TypeError(
+                "The Layer should be a derived class of PipelineLayer.")
+        super(PipelineParallel, self).__init__(layers, hcg, strategy)
+        self.use_data_parallel = self._hcg.get_data_parallel_world_size() > 1
+        self.use_model_parallel = self._hcg.get_model_parallel_world_size() > 1
+
+        self.total_loss = None
+
+        self.micro_batch_size = self._strategy.pipeline_configs[
+            'micro_batch_size']
+        self.accumulate_steps = self._strategy.pipeline_configs[
+            'accumulate_steps']
+
+        self._using_cache = self._strategy.pipeline_configs['p2p_cache_shape']
+
+        self.num_stages = self._hcg.get_pipe_parallel_world_size()
+        self.stage_id = self._hcg.get_stage_id()
+        self.pp_group = self._hcg.get_pipe_parallel_group()
+
+        p2p.initialize_p2p_groups(hcg, self._using_cache)
+
+        _initialize_recompute_hcg(hcg)
+
+        self.is_first_stage = self.stage_id == 0
+        self.is_last_stage = (self.stage_id == (self.num_stages - 1))
+        self.global_rank = self._hcg.get_global_rank()
+        self.micro_batch_id = 0
+
+        self._compute_loss = True
+
+        logger.info("Pipeline Info -- num_stages: {}, stage_id: {}".format(
+            self.num_stages, self.stage_id))
+
+        if self.use_model_parallel:
+            logger.info("start broadcast mp parameters")
+            broadcast_mp_parameters(self._layers, self._hcg)
+
+        if self.use_data_parallel:
+            logger.info("start broadcast dp parameters")
+            broadcast_dp_parameters(self._layers, self._hcg)
+
+    def train_batch(self, data, optimizer, lr_scheduler=None, scaler=None):
+        assert isinstance(optimizer, HybridParallelOptimizer), (
+            'The optimizer should be an instance of HybridParallelOptimizer.')
+        if scaler is not None:
+            assert isinstance(scaler, HybridParallelGradScaler), (
+                'The scaler should be an instance of HybridParallelGradScaler or None.')
+        assert fluid.framework._dygraph_tracer()._has_grad, (
+            'Please enable the generation of gradients.')
+
+        if self.is_first_stage or self.is_last_stage:
+            assert data is not None, (
+                "For the first and the last stage, the data must be set.")
+        else:
+            data = None
+
+        self.optimizer = optimizer
+        self.lr_scheduler = lr_scheduler
+        self.scaler = scaler
+        self.data = data
+        self._compute_loss = True
+
+        self._layers.train()
+
+        # store total loss of entire batch
+        self.total_loss = None
+
+        # store data id for micro_batch
+        self.micro_batch_id = 0
+
+        # Next, use the 1F1B scheduling strategy.
+        # This strategy is inspired by:
+        # https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/schedules.py
+
+        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = min(startup_steps, self.accumulate_steps)
+        steady_steps = self.accumulate_steps - startup_steps
+
+        input_buffers = []
+        output_buffers = []
+
+        for step_id in range(startup_steps):
+            input_tensor = p2p.recv_forward()
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+        if steady_steps > 0:
+            input_tensor = p2p.recv_forward()
+
+        for i in range(steady_steps):
+            last_iter = (i == (steady_steps - 1))
+
+            output_tensor = self._forward_step(input_tensor)
+
+            output_tensor_grad = p2p.send_forward_recv_backward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+            input_tensor, output_tensor = input_buffers.pop(
+                0), output_buffers.pop(0)
+
+            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
+                                                    output_tensor_grad)
+
+            if last_iter:
+                input_tensor = None
+                p2p.send_backward(input_tensor_grad)
+            else:
+                input_tensor = p2p.send_backward_recv_forward(input_tensor_grad)
+
+        for i in range(startup_steps):
+            input_tensor = input_buffers.pop(0)
+            output_tensor = output_buffers.pop(0)
+
+            output_tensor_grad = p2p.recv_backward()
+
+            input_tensor_grad = self._backward_step(input_tensor, output_tensor,
+                                                    output_tensor_grad)
+            p2p.send_backward(input_tensor_grad)
+
+        self._layers.allreduce_shared_weight_gradients()
+
+        self.train_loss = self._broadcast_final_loss()
+
+        # optimizer
+        self._optimizer_step()
+        return self.train_loss
+
+    def eval_batch(self, data, compute_loss=False):
+        self._layers.eval()
+        self._compute_loss = compute_loss
+
+        # save data for eval
+        self.data = data
+        # store data id for micro_batch
+        self.micro_batch_id = 0
+
+        # store total loss of entire batch
+        self.total_loss = None
+
+        startup_steps = (self.num_stages - self.stage_id - 1)
+        startup_steps = min(startup_steps, self.accumulate_steps)
+        steady_steps = self.accumulate_steps - startup_steps
+
+        input_buffers = []
+        output_buffers = []
+
+        for step_id in range(startup_steps):
+            input_tensor = p2p.recv_forward()
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+        if steady_steps > 0:
+            input_tensor = p2p.recv_forward()
+
+        for i in range(steady_steps):
+            last_iter = (i == (steady_steps - 1))
+
+            output_tensor = self._forward_step(input_tensor)
+            p2p.send_forward(output_tensor)
+
+            input_buffers.append(input_tensor)
+            output_buffers.append(output_tensor)
+
+            if not last_iter:
+                input_tensor = p2p.recv_forward()
+
+        return self.total_loss if self._compute_loss else output_buffers
+
+    def _forward_step(self, input_tensor):
+        if self.stage_id == 0:
+            input_tensor = self._load_micro_batch(self.micro_batch_id)
+
+        output_tensor = self._layers.forward(input_tensor)
+
+        if self.is_last_stage:
+            # compute the loss on the last stage during training
+            if self._compute_loss:
+                assert self._layers._loss_fn is not None, "loss function should exist to compute loss"
+                labels = self._load_micro_batch(self.micro_batch_id)
+                output_tensor = self._layers._loss_fn(output_tensor, labels)
+                assert isinstance(
+                    output_tensor, paddle.Tensor
+                ), "Currently, loss_fn should return a paddle.Tensor"
+
+                if self.accumulate_steps > 1:
+                    output_tensor = output_tensor / self.accumulate_steps
+
+                if self.total_loss is None:
+                    self.total_loss = paddle.zeros_like(output_tensor)
+                self.total_loss += output_tensor.detach()
+
+        self.micro_batch_id += 1
+        return output_tensor
+
+    def _backward_step(self, input_tensor, output_tensor, output_tensor_grad):
+        if self.is_last_stage:
+            assert output_tensor_grad is None
+            if self.scaler:
+                paddle.autograd.backward(self.scaler.scale(output_tensor))
+            else:
+                paddle.autograd.backward(output_tensor)
+        else:
+            if isinstance(output_tensor, tuple):
+                outputs = [t for t in output_tensor if not t.stop_gradient]
+                assert len(outputs) == len(output_tensor_grad)
+                paddle.autograd.backward(
+                    tensors=outputs,
+                    grad_tensors=[t for t in output_tensor_grad])
+            else:
+                paddle.autograd.backward(
+                    tensors=[output_tensor], grad_tensors=[output_tensor_grad])
+
+        input_tensor_grad = None
+        if input_tensor is not None:
+            if isinstance(input_tensor, tuple):
+                input_tensor_grad = tuple(
+                    [t.grad for t in input_tensor if not t.stop_gradient])
+            else:
+                input_tensor_grad = input_tensor.grad
+        return input_tensor_grad
+
+    def _load_micro_batch(self, cache_id):
+        inputs = self.data
+        begin = cache_id * self.micro_batch_size
+        end = begin + self.micro_batch_size
+
+        if self.is_first_stage:
+            assert len(inputs) == 2, "length of input should be 2"
+            if isinstance(inputs[0], tuple):
+                assert len(
+                    inputs[0]
+                ) > 1, "If you use tuple for input data, it should have at least two inputs."
+                batch_size = inputs[0][0].shape[0]
+                assert self.micro_batch_size * self.accumulate_steps == batch_size, (
+                    "batch_size must equal micro_batch_size * accumulate_steps. Currently, "
+                    "batch_size = %d, micro_batch_size = %d, accumulate_steps = %d."
+                    %
+                    (batch_size, self.micro_batch_size, self.accumulate_steps))
+                data = [input[begin:end, :].detach() for input in inputs[0]]
+                return tuple(data)
+            else:
+                batch_size = inputs[0].shape[0]
+                assert self.micro_batch_size * self.accumulate_steps == batch_size
+                return inputs[0][begin:end, :].detach()
+        elif self.is_last_stage:
+            assert len(inputs) == 2, "length of input should be 2"
+            if isinstance(inputs[1], tuple):
+                batch_size = inputs[1][0].shape[0]
+                assert self.micro_batch_size * self.accumulate_steps == batch_size
+                data = [input[begin:end, :].detach() for input in inputs[1]]
+                return tuple(data)
+            else:
+                batch_size = inputs[1].shape[0]
+                assert self.micro_batch_size * self.accumulate_steps == batch_size
+                return inputs[1][begin:end, :].detach()
+        else:
+            # No data input is required for other stages
+            inputs = None
+
+    def _broadcast_final_loss(self):
+        if self.is_last_stage:
+            assert self.total_loss is not None, "train_batch() in last stage should obtain valid loss"
+            loss = self.total_loss.detach()
+            paddle.distributed.broadcast(
+                loss,
+                src=self.global_rank,
+                use_calc_stream=True,
+                group=self.pp_group)
+        else:
+            loss = paddle.zeros(shape=[1], dtype="float32")
+            paddle.distributed.broadcast(
+                loss,
+                src=self._hcg.get_rank_from_stage(self.num_stages - 1),
+                use_calc_stream=True,
+                group=self.pp_group)
+        return loss
+
+    def _optimizer_step(self):
+        if self.scaler:
+            self.scaler.minimize(self.optimizer, self.train_loss)
+        else:
+            self.optimizer.step()
+
+        self.optimizer.clear_grad()
+        if self.lr_scheduler:
+            self.lr_scheduler.step()
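For reference, the warm-up/steady/cool-down split used by the 1F1B schedule above can be checked in isolation: a stage runs num_stages - stage_id - 1 forward-only warm-up steps (capped at accumulate_steps), alternates one forward and one backward per micro-batch in the steady phase, and then drains the warm-up micro-batches with backward-only steps. A standalone sketch of that arithmetic (plain Python, no Paddle dependency):

    def one_f_one_b_schedule(num_stages, stage_id, accumulate_steps):
        # Warm-up: forward-only micro-batches until the pipeline is full.
        startup_steps = min(num_stages - stage_id - 1, accumulate_steps)
        # Steady phase: one forward followed by one backward per micro-batch.
        steady_steps = accumulate_steps - startup_steps
        # Cool-down: backward-only steps that drain the warm-up micro-batches.
        cooldown_steps = startup_steps
        return startup_steps, steady_steps, cooldown_steps

    # With pp_degree = 2 and accumulate_steps = 4, as in the unit test below:
    assert one_f_one_b_schedule(2, 0, 4) == (1, 3, 1)  # first stage
    assert one_f_one_b_schedule(2, 1, 4) == (0, 4, 0)  # last stage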
diff --git a/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
new file mode 100644
index 0000000000000000000000000000000000000000..c4c1e565068b2bc9254db3a9c23a6be5b98a5adc
--- /dev/null
+++ b/python/paddle/fluid/tests/unittests/hybrid_parallel_pp_transformer.py
@@ -0,0 +1,190 @@
+# Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+#
+# Licensed under the Apache License, Version 2.0 (the "License");
+# you may not use this file except in compliance with the License.
+# You may obtain a copy of the License at
+#
+#     http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+
+from __future__ import division
+from __future__ import print_function
+
+import unittest
+import paddle
+import numpy as np
+import random
+import paddle.distributed as dist
+import paddle.distributed.fleet as fleet
+from paddle.fluid import layers
+import paddle.nn.functional as F
+from paddle.distributed.fleet.meta_parallel import PipelineLayer, LayerDesc
+from paddle.fluid.dygraph.layers import Layer
+import paddle.nn as nn
+
+
+def set_random_seed(seed, dp_id, rank_id):
+    """Set random seed for reproducibility."""
+    random.seed(seed)
+    np.random.seed(seed + dp_id)
+    paddle.seed(seed + dp_id)
+
+
+batch_size = 8
+length = 8
+micro_batch_size = 2
+vocab_size = 128
+hidden_size = 16
+d_model = hidden_size
+dim_feedforward = 4 * d_model
+
+
+class EmbeddingNet(Layer):
+    def __init__(self):
+        super(EmbeddingNet, self).__init__()
+        self.word_embeddings = nn.Embedding(vocab_size, hidden_size)
+        self.position_embeddings = nn.Embedding(vocab_size, hidden_size)
+
+    def forward(self, x):
+        attention_mask = paddle.tensor.triu(
+            (paddle.ones(
+                (length, length), dtype="float32") * -1e9), 1)
+
+        no_used = paddle.ones((3, 3), dtype="int32")
+
+        w_emb = self.word_embeddings(x)
+        p_emb = self.position_embeddings(x)
+        w_emb = w_emb + p_emb
+
+        attention_mask.stop_gradient = True
+        no_used.stop_gradient = True
+        # need to fix bug of backward()
+        return w_emb, attention_mask, no_used, p_emb
+
+
+class TransformerNet(Layer):
+    def __init__(self):
+        super(TransformerNet, self).__init__()
+        self.linear1 = nn.Linear(d_model, dim_feedforward)
+        self.linear2 = nn.Linear(dim_feedforward, d_model)
+
+        self.q_proj = nn.Linear(d_model, d_model)
+        self.k_proj = nn.Linear(d_model, d_model)
+        self.v_proj = nn.Linear(d_model, d_model)
+
+        self.norm1 = nn.LayerNorm(d_model, epsilon=1e-5)
+
+    def forward(self, x, mask):
+        q = self.q_proj(x)
+        k = self.k_proj(x)
+        v = self.v_proj(x)
+        product = layers.matmul(x=q, y=k, transpose_y=True, alpha=d_model**-0.5)
+
+        weights = F.softmax(product + mask)
+        # TODO(shenliang03): save/load in PipelineParallel cannot support dropout for now.
+        # weights = F.dropout(weights, 0.2)
+        tgt = layers.matmul(weights, v)
+        residual = tgt
+        tgt = self.norm1(tgt)
+        tgt = residual + tgt
+
+        out = self.linear2(F.gelu(self.linear1(tgt), approximate=True))
+        return out
+
+
+class EmbeddingPipe(EmbeddingNet):
+    def forward(self, x):
+        return super().forward(x)
+
+
+class TransformerNetPipe(TransformerNet):
+    def forward(self, args):
+        x, mask, no_used, p_emb = args[0], args[1], args[2], args[3]
+
+        output = super().forward(x, mask)
+        output = output + p_emb
+        mask.stop_gradient = True
+        return output, mask, no_used, p_emb
+
+
+class CriterionPipe(Layer):
+    def __init__(self):
+        super(CriterionPipe, self).__init__()
+
+    def forward(self, out, label):
+        loss = out.mean()
+        return loss
+
+
+class ModelPipe(PipelineLayer):
+    def __init__(self, topology):
+        self.descs = []
+        self.descs.append(LayerDesc(EmbeddingPipe))
+
+        for x in range(6):
+            self.descs.append(LayerDesc(TransformerNetPipe))
+
+        self.descs.append(lambda x: x[0])
+
+        super().__init__(
+            layers=self.descs,
+            loss_fn=CriterionPipe(),
+            topology=topology,
+            seg_method="layer:TransformerNetPipe")
+
+
+class TestDistPPTraning(unittest.TestCase):
+    def setUp(self):
+        strategy = fleet.DistributedStrategy()
+        self.model_parallel_size = 1
+        self.data_parallel_size = 1
+        self.pipeline_parallel_size = 2
+        strategy.hybrid_configs = {
+            "dp_degree": self.data_parallel_size,
+            "mp_degree": self.model_parallel_size,
+            "pp_degree": self.pipeline_parallel_size,
+        }
+        strategy.pipeline_configs = {
+            "accumulate_steps": batch_size // micro_batch_size,
+            "micro_batch_size": micro_batch_size
+        }
+        fleet.init(is_collective=True, strategy=strategy)
+
+    def test_pp_model(self):
+        hcg = fleet.get_hybrid_communicate_group()
+        word_size = hcg.get_model_parallel_world_size()
+        dp_id = hcg.get_data_parallel_rank()
+        pp_id = hcg.get_stage_id()
+        rank_id = dist.get_rank()
+        topology = hcg.topology()
+        set_random_seed(1024, dp_id, rank_id)
+
+        model = ModelPipe(topology)
+        scheduler = paddle.optimizer.lr.PiecewiseDecay(
+            boundaries=[2], values=[0.001, 0.002], verbose=True)
+        optimizer = paddle.optimizer.SGD(learning_rate=scheduler,
+                                         parameters=model.parameters())
+
+        model = fleet.distributed_model(model)
+        optimizer = fleet.distributed_optimizer(optimizer)
+
+        for step_id in range(5):
+            x_data = np.random.randint(0, vocab_size, size=[batch_size, length])
+            x = paddle.to_tensor(x_data)
+            x.stop_gradient = True
+
+            e_loss = model.eval_batch([x, x], True)
+            loss = model.train_batch([x, x], optimizer, scheduler)
+
+            # TODO(shenliang03): add a unit test for the loss value
+            if pp_id != 0:
+                np.testing.assert_allclose(loss.numpy(), e_loss.numpy())
+
+
+if __name__ == "__main__":
+    unittest.main()
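As _load_micro_batch enforces, train_batch and eval_batch expect a two-element list: element 0 feeds the first stage and element 1 feeds the loss function on the last stage, with batch_size equal to micro_batch_size * accumulate_steps. A small sketch of that slicing contract, using the sizes from the test above:

    batch_size, micro_batch_size = 8, 2
    accumulate_steps = batch_size // micro_batch_size  # 4 micro-batches per train_batch call

    for micro_batch_id in range(accumulate_steps):
        begin = micro_batch_id * micro_batch_size
        end = begin + micro_batch_size
        # first stage reads data[0][begin:end], last stage reads data[1][begin:end]
        print("micro-batch", micro_batch_id, "covers rows", begin, "to", end - 1)

The test itself needs one process per pipeline stage (pp_degree = 2), so it is typically run under a distributed launcher such as python -m paddle.distributed.launch with two devices; the exact launch wrapper is not part of this diff.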