From cd9e2ae3a9f5926197d20833b5853d28041dcb6b Mon Sep 17 00:00:00 2001
From: PaddlePaddle-Gardener
Date: Thu, 13 Jan 2022 14:28:22 +0800
Subject: [PATCH] mirgate_38831

---
 .../ir/ipu/popart_canonicalization_pass.cc    |  67 +++
 .../popart_canonicalization/activation_ops.cc | 102 ++++
 .../canonicalization_utils.cc                 | 196 +++++++
 .../canonicalization_utils.h                  |  62 +++
 .../ipu/popart_canonicalization/logic_ops.cc  |  50 ++
 .../ipu/popart_canonicalization/math_ops.cc   | 378 +++++++++++++
 .../ipu/popart_canonicalization/nn_ops.cc     | 312 +++++++++++
 .../ipu/popart_canonicalization/op_builder.cc | 217 ++++++++
 .../ipu/popart_canonicalization/op_builder.h  |  86 +++
 .../ipu/popart_canonicalization/other_ops.cc  |  65 +++
 .../ipu/popart_canonicalization/search_ops.cc |  95 ++++
 .../ipu/popart_canonicalization/tensor_ops.cc | 522 ++++++++++++++++++
 12 files changed, 2152 insertions(+)
 create mode 100644 paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc

diff --git a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
index e69de29bb2..d2d76f9a9a 100644
--- a/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
+++ b/paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.cc
@@ -0,0 +1,67 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/framework/ir/ipu/popart_canonicalization_pass.h"
+
+#include "paddle/fluid/framework/ir/pass_tester_helper.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+
+namespace paddle {
+namespace framework {
+namespace ir {
+
+using framework::ir::Graph;
+using framework::ir::Node;
+using platform::ipu::SymbolHandler;
+
+void PopartCanonicalizationPass::ApplyImpl(ir::Graph* graph) const {
+  VLOG(10) << "enter PopartCanonicalizationPass::ApplyImpl";
+  VLOG(10) << "Raw Graph: ";
+  VLOG(10) << DebugString(graph);
+
+  auto nodes = graph->Nodes();
+  for (auto* node : nodes) {
+    if (!node->IsOp()) {
+      continue;
+    }
+    auto* op = node->Op();
+    auto op_type = op->Type();
+
+    ir::Node* new_node = nullptr;
+    SymbolHandler handler = platform::ipu::GetHandler(op_type);
+    if (handler) {
+      VLOG(11) << "Raw Paddle Node:";
+      VLOG(11) << node->Op()->Proto()->DebugString();
+      new_node = handler(graph, node);
+      VLOG(11) << "Post Popart Node:";
+      VLOG(11) << new_node->Op()->Proto()->DebugString();
+
+      platform::ipu::ClearNode(node);
+      graph->RemoveNode(node);
+    } else {
+      LOG(ERROR) << "Cannot find OpHandler for op_type: " << op_type;
+    }
+  }
+
+  VLOG(10) << "Post Graph: ";
+  VLOG(10) << DebugString(graph);
+  VLOG(10) << "leave PopartCanonicalizationPass::ApplyImpl";
+}
+
+}  // namespace ir
+}  // namespace framework
+}  // namespace paddle
+
+REGISTER_PASS(popart_canonicalization_pass,
+              paddle::framework::ir::PopartCanonicalizationPass);
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
index e69de29bb2..fc2f1e476b 100644
--- a/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/activation_ops.cc
@@ -0,0 +1,102 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+ +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *activation_op_handler(Graph *graph, Node *node, const std::string &type) { + auto new_node = CreateBaseOp(graph, node, type, {GetInputVarNode("X", node)}, + node->outputs); + return new_node; +} + +Node *relu_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_relu"); +} + +Node *tanh_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_tanh"); +} + +Node *log_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_log"); +} + +Node *sigmoid_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_sigmoid"); +} + +Node *sqrt_handler(Graph *graph, Node *node) { + return activation_op_handler(graph, node, "popart_sqrt"); +} + +Node *gelu_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto approximate_ = BOOST_GET_CONST(bool, op->GetAttr("approximate")); + if (approximate_) { + return activation_op_handler(graph, node, "popart_gelu_v2"); + } else { + auto sqrt2 = CreateConst(graph, node, {}, {}, + {{"value", std::vector{1.4142135623730951}}, + {"dims", std::vector{1}}, + {"dtype", GetOutputVarDtype(node)}}); + auto zero_point_five = + CreateConst(graph, node, {}, {}, {{"value", std::vector{0.5}}, + {"dims", std::vector{1}}, + {"dtype", GetOutputVarDtype(node)}}); + auto one = + CreateConst(graph, node, {}, {}, {{"value", std::vector{1}}, + {"dims", std::vector{1}}, + {"dtype", GetOutputVarDtype(node)}}); + auto div = + CreateBaseOp(graph, node, "popart_div", + {GetInputVarNode("X", node), sqrt2->outputs[0]}, {}, {}); + auto erf = + CreateBaseOp(graph, node, "popart_erf", {div->outputs[0]}, {}, {}); + auto add = CreateBaseOp(graph, node, "popart_add", + {erf->outputs[0], one->outputs[0]}, {}, {}); + auto mul1 = + CreateBaseOp(graph, node, "popart_mul", + {GetInputVarNode("X", node), add->outputs[0]}, {}, {}); + return CreateBaseOp(graph, node, "popart_mul", + {mul1->outputs[0], zero_point_five->outputs[0]}, + {GetOutputVarNode("Out", node)}, {}); + } +} + +Node *log_softmax_handler(Graph *graph, Node *node) { + auto axis = BOOST_GET_CONST(int, node->Op()->GetAttr("axis")); + auto new_softmax = CreateSoftmaxOpset11(graph, node, node->inputs, {}, axis); + return CreateBaseOp(graph, node, "popart_log", new_softmax->outputs, + node->outputs); +} + +REGISTER_HANDLER(relu, relu_handler); +REGISTER_HANDLER(tanh, tanh_handler); +REGISTER_HANDLER(log, log_handler); +REGISTER_HANDLER(sigmoid, sigmoid_handler); +REGISTER_HANDLER(sqrt, sqrt_handler); +REGISTER_HANDLER(gelu, gelu_handler); +REGISTER_HANDLER(log_softmax, log_softmax_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc index e69de29bb2..3d22f75d34 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.cc @@ -0,0 +1,196 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" + +namespace paddle { +namespace platform { +namespace ipu { + +// This avoids the static initialisation order fiasco, +std::unordered_map &SymbolHandlers() { + static std::unordered_map symbol_handlers; + return symbol_handlers; +} + +bool RegisterHandler(const std::string &symbol, const SymbolHandler &handler) { + if (SymbolHandlers().count(symbol) != 0) { + LOG(WARNING) << "Trying to register popart handler twice for operator: " + << symbol; + return false; + } + bool new_handler = SymbolHandlers().emplace(symbol, handler).second; + return new_handler; +} + +// Return a pointer to a handler if one is registered for this kind of node or +// an empty std::function otherwise. +SymbolHandler GetHandler(const std::string &kind) { + auto it = SymbolHandlers().find(kind); + if (it != SymbolHandlers().end()) { + return it->second; + } + return {}; +} + +void ConnectNodes(Node *first_node, Node *next_node) { + first_node->outputs.push_back(next_node); + next_node->inputs.push_back(first_node); +} + +void DisConnectNodes(Node *first_node, Node *next_node) { + auto rm_by_value = [&](std::vector &vec, Node *n) { + vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end()); + }; + rm_by_value(first_node->outputs, next_node); + rm_by_value(next_node->inputs, first_node); + rm_by_value(first_node->inputs, next_node); + rm_by_value(next_node->outputs, first_node); +} + +void ClearNode(Node *node) { + auto rm_by_value = [&](std::vector &vec, Node *n) { + vec.erase(std::remove(vec.begin(), vec.end(), n), vec.end()); + }; + for (auto *node_in : node->inputs) { + rm_by_value(node_in->outputs, node); + } + for (auto *node_out : node->outputs) { + rm_by_value(node_out->inputs, node); + } +} + +void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, + bool override) { + if (new_op->HasAttr(attr_name) && !override) { + return; + } + if (op->HasAttr(attr_name)) { + VLOG(10) << "Copying attr: " << attr_name << " from " << op->Type() + << " to " << new_op->Type(); + new_op->SetAttr(attr_name, op->GetAttr(attr_name)); + new_op->Flush(); + } +} + +const int VarType2OnnxDtype(const int type) { + auto dtype = static_cast(type); + switch (dtype) { + case framework::proto::VarType::BOOL: + return static_cast(ONNXDataType::BOOL); + case framework::proto::VarType::INT16: + return static_cast(ONNXDataType::INT16); + case framework::proto::VarType::INT32: + return static_cast(ONNXDataType::INT32); + case framework::proto::VarType::INT64: + return static_cast(ONNXDataType::INT64); + case framework::proto::VarType::FP16: + return static_cast(ONNXDataType::FLOAT16); + case framework::proto::VarType::FP32: + return static_cast(ONNXDataType::FLOAT); + case framework::proto::VarType::FP64: + return static_cast(ONNXDataType::DOUBLE); + case framework::proto::VarType::UINT8: + return static_cast(ONNXDataType::UINT8); + case framework::proto::VarType::INT8: + 
return static_cast(ONNXDataType::INT8); + case framework::proto::VarType::BF16: + return static_cast(ONNXDataType::BFLOAT16); + case framework::proto::VarType::COMPLEX64: + return static_cast(ONNXDataType::COMPLEX64); + case framework::proto::VarType::COMPLEX128: + return static_cast(ONNXDataType::COMPLEX128); + default: + PADDLE_THROW( + platform::errors::Unimplemented("Unsupported data type: %d.", dtype)); + } +} + +const std::string VarType2PopStr(const int type) { + auto dtype = static_cast(type); + switch (dtype) { + case framework::proto::VarType::UINT8: + return "UINT8"; + case framework::proto::VarType::INT8: + return "INT8"; + case framework::proto::VarType::INT16: + return "INT16"; + case framework::proto::VarType::INT32: + return "INT32"; + case framework::proto::VarType::INT64: + return "INT64"; + case framework::proto::VarType::BOOL: + return "BOOL"; + case framework::proto::VarType::FP64: + return "DOUBLE"; + case framework::proto::VarType::FP32: + return "FLOAT"; + case framework::proto::VarType::FP16: + return "FLOAT16"; + default: + PADDLE_THROW( + paddle::platform::errors::Unavailable("Unsupported data type.")); + } +} + +Node *GetInputVarNode(const std::string &input_name, const Node *op_node, + const int id) { + auto var_name = op_node->Op()->Input(input_name).at(id); + return GetInputVarNodeByVarName(var_name, op_node); +} + +Node *GetOutputVarNode(const std::string &output_name, const Node *op_node, + const int id) { + auto var_name = op_node->Op()->Output(output_name).at(id); + return GetOutputVarNodeByVarName(var_name, op_node); +} + +Node *GetInputVarNodeByVarName(const std::string &var_name, + const Node *op_node) { + for (auto *var : op_node->inputs) { + if (var->Name() == var_name) { + return var; + } + } + return nullptr; +} + +Node *GetOutputVarNodeByVarName(const std::string &var_name, + const Node *op_node) { + for (auto *var : op_node->outputs) { + if (var->Name() == var_name) { + return var; + } + } + return nullptr; +} + +const bool is_float_equal(float a, float b, float eps) { + return std::fabs(a - b) <= eps; +} + +const int GetOutputVarDtype(const Node *node, const std::string &output_name) { + auto out_node = GetOutputVarNode(output_name, node); + PADDLE_ENFORCE_NOT_NULL(out_node, platform::errors::Unavailable( + "Node's out node does not exist.")); + auto var = out_node->Var(); + PADDLE_ENFORCE_NOT_NULL( + var, platform::errors::Unavailable("Node is not a variable.")); + auto proto_var_type = var->GetDataType(); + return VarType2OnnxDtype(proto_var_type); +} + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h index e69de29bb2..5725ec767a 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h @@ -0,0 +1,62 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
+// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/framework/ir/graph.h" +#include "paddle/fluid/framework/ir/node.h" +#include "paddle/fluid/framework/ir/pass.h" +#include "paddle/fluid/platform/device/ipu/ipu_utils.h" + +namespace paddle { +namespace platform { +namespace ipu { + +#define REGISTER_HANDLER(name, func) \ + static bool __UNUSED_##name = \ + paddle::platform::ipu::RegisterHandler(#name, func) + +using SymbolHandler = std::function; + +std::unordered_map &SymbolHandlers(); + +bool RegisterHandler(const std::string &, const SymbolHandler &); + +SymbolHandler GetHandler(const std::string &); + +void ConnectNodes(Node *first_node, Node *next_node); +void DisConnectNodes(Node *first_node, Node *next_node); +void ClearNode(Node *node); +void CopyOpAttr(const std::string &attr_name, OpDesc *op, OpDesc *new_op, + bool override = false); + +const int VarType2OnnxDtype(const int type); +const std::string VarType2PopStr(const int type); + +Node *GetInputVarNode(const std::string &input_name, const Node *op_node, + const int id = 0); +Node *GetOutputVarNode(const std::string &output_name, const Node *op_node, + const int id = 0); +Node *GetInputVarNodeByVarName(const std::string &var_name, + const Node *op_node); +Node *GetOutputVarNodeByVarName(const std::string &var_name, + const Node *op_node); + +const bool is_float_equal(float a, float b, float eps = 1e-8); +const int GetOutputVarDtype(const Node *node, + const std::string &output_name = "Out"); + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc index e69de29bb2..c980bb780c 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/logic_ops.cc @@ -0,0 +1,50 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *equal_handler(Graph *graph, Node *node) { + auto new_node = CreateBaseOp( + graph, node, "popart_equal", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, node->outputs); + return new_node; +} + +Node *logical_not_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_logical_not", + {GetInputVarNode("X", node)}, + {GetOutputVarNode("Out", node)}, {}); +} + +Node *greater_than_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_greater", + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + {GetOutputVarNode("Out", node)}, {}); +} + +REGISTER_HANDLER(equal, equal_handler); +REGISTER_HANDLER(logical_not, logical_not_handler); +REGISTER_HANDLER(greater_than, greater_than_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc index e69de29bb2..67012e8d4b 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/math_ops.cc @@ -0,0 +1,378 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *mean_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_reducemean", + {GetInputVarNode("X", node)}, + {GetOutputVarNode("Out", node)}, + { + {"keepdims", int64_t{0}}, + }); +} + +Node *pow_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + if (op->HasInput("FactorTensor") && !op->Input("FactorTensor").empty()) { + return CreateBaseOp( + graph, node, "popart_pow", + {GetInputVarNode("X", node), GetInputVarNode("FactorTensor", node)}, + node->outputs); + } else { + // Op(pow) -> Op(Constant)->Var(const_out)->Op(Pow) + auto value_ = BOOST_GET_CONST(float, op->GetAttr("factor")); + auto attrs = + MakeConstAttrMapFromValue(value_, {1}, GetOutputVarDtype(node)); + + auto new_node_const = CreateConst(graph, node, {}, {}, attrs); + return CreateBaseOp(graph, node, "popart_pow", {GetInputVarNode("X", node), + new_node_const->outputs[0]}, + node->outputs); + } +} + +Node *mul_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto x_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("x_num_col_dims")); + auto y_num_col_dims = BOOST_GET_CONST(int, op->GetAttr("y_num_col_dims")); + auto x_shape_ = GetInputVarNode("X", node)->Var()->GetShape(); + auto y_shape_ = GetInputVarNode("Y", node)->Var()->GetShape(); + + // build the shape for reshape + std::vector reshape_shape_{}; + for (int left = 0; left < x_num_col_dims; left++) { + reshape_shape_.push_back(int64_t(x_shape_[left])); + } + for (int right = y_num_col_dims; right < y_shape_.size(); right++) { + reshape_shape_.push_back(int64_t(y_shape_[right])); + } + auto x_flatten = + CreateBaseOp(graph, node, "popart_flatten", {GetInputVarNode("X", node)}, + {}, {{"axis", int64_t(x_num_col_dims)}}); + auto y_flatten = + CreateBaseOp(graph, node, "popart_flatten", {GetInputVarNode("Y", node)}, + {}, {{"axis", int64_t(y_num_col_dims)}}); + auto matmul = + CreateBaseOp(graph, node, "popart_matmul", + {x_flatten->outputs[0], y_flatten->outputs[0]}, {}, {}); + + auto reshape_const = CreateConst( + graph, node, {}, {}, + {{"value", reshape_shape_}, + {"dims", std::vector{int64_t(reshape_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}); + return CreateBaseOp(graph, node, "popart_reshape", + {matmul->outputs[0], reshape_const->outputs[0]}, + node->outputs, {}); +} + +Node *matmul_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("transpose_X")); + auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("transpose_Y")); + auto alpha = BOOST_GET_CONST(float, op->GetAttr("alpha")); + auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); + auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape(); + + int x_rank = x_shape.size(); + std::vector perm; + if (x_rank == 1) { + perm = std::vector{0}; + } else if (x_rank == 2) { + return CreateGemm(graph, node, + {GetInputVarNode("X", node), GetInputVarNode("Y", node)}, + node->outputs, transpose_x, transpose_y, alpha); + } else if (x_rank == 3) { + perm = std::vector{0, 2, 1}; + } else if (x_rank == 4) { + perm = std::vector{0, 1, 3, 2}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op matmul with input rank == %d", x_rank)); + } + + Node *x_node = GetInputVarNode("X", node); + 
Node *y_node = GetInputVarNode("Y", node); + if (transpose_x) { + x_node = CreateBaseOp(graph, node, "popart_transpose", + {GetInputVarNode("X", node)}, {}, {{"perm", perm}}); + x_node = x_node->outputs[0]; + } + if (transpose_y) { + y_node = CreateBaseOp(graph, node, "popart_transpose", + {GetInputVarNode("Y", node)}, {}, {{"perm", perm}}); + y_node = y_node->outputs[0]; + } + if (is_float_equal(alpha, 1.0)) { + return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, + node->outputs); + } else { + auto o_node = + CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, {}); + auto attr = MakeConstAttrMapFromValue(alpha, {1}, GetOutputVarDtype(node)); + auto const_node = CreateConst(graph, node, {}, {}, attr); + return CreateBaseOp(graph, node, "popart_mul", + {o_node->outputs[0], const_node->outputs[0]}, + node->outputs); + } +} + +Node *sum_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_sum", node->inputs, node->outputs); +} + +Node *softmax_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + int axis = -1; + if (op->HasAttr("axis")) { + axis = BOOST_GET_CONST(int, op->GetAttr("axis")); + } + return CreateSoftmaxOpset11(graph, node, node->inputs, node->outputs, axis); +} + +Node *scale_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto scale_ = BOOST_GET_CONST(float, op->GetAttr("scale")); + auto bias_ = BOOST_GET_CONST(float, op->GetAttr("bias")); + auto bias_after_scale_ = + BOOST_GET_CONST(bool, op->GetAttr("bias_after_scale")); + auto data_type_ = GetInputVarNode("X", node)->Var()->GetDataType(); + + auto cast = CreateCast(graph, node, {GetInputVarNode("X", node)}, {}, + static_cast(framework::proto::VarType::FP32)); + + Node *result = nullptr; + if (op->HasInput("ScaleTensor") && !op->Input("ScaleTensor").empty()) { + auto scale = GetInputVarNode("ScaleTensor", node); + if (is_float_equal(bias_, 0.0)) { + result = CreateBaseOp(graph, node, "popart_mul", + {cast->outputs[0], scale}, {}, {}); + } else { + auto bias = CreateConst(graph, node, {}, {}, + {{"value", std::vector{bias_}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}); + bias = bias->outputs[0]; + if (bias_after_scale_) { + auto mul = CreateBaseOp(graph, node, "popart_mul", + {cast->outputs[0], scale}, {}, {}); + result = CreateBaseOp(graph, node, "popart_add", + {mul->outputs[0], bias}, {}, {}); + } else { + auto add = CreateBaseOp(graph, node, "popart_add", + {cast->outputs[0], bias}, {}, {}); + result = CreateBaseOp(graph, node, "popart_mul", + {add->outputs[0], scale}, {}, {}); + } + } + } else { + if (is_float_equal(bias_, 0.0) && is_float_equal(scale_, 1.0)) { + return CreateBaseOp(graph, node, "popart_identity", + {GetInputVarNode("X", node)}, node->outputs, {}); + } else if (is_float_equal(scale_, 1.0)) { + auto bias = CreateConst(graph, node, {}, {}, + {{"value", std::vector{bias_}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}); + result = CreateBaseOp(graph, node, "popart_add", + {cast->outputs[0], bias->outputs[0]}, {}, {}); + } else if (is_float_equal(bias_, 0.0)) { + auto scale = CreateConst(graph, node, {}, {}, + {{"value", std::vector{scale_}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}); + result = CreateBaseOp(graph, node, "popart_mul", + {cast->outputs[0], scale->outputs[0]}, {}, {}); + } else { + auto bias = CreateConst(graph, node, {}, {}, + {{"value", std::vector{bias_}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}); + auto scale = CreateConst(graph, node, 
{}, {}, + {{"value", std::vector{scale_}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::FLOAT}}); + if (bias_after_scale_) { + auto mul = CreateBaseOp(graph, node, "popart_mul", + {cast->outputs[0], scale->outputs[0]}, {}, {}); + result = CreateBaseOp(graph, node, "popart_add", + {mul->outputs[0], bias->outputs[0]}, {}, {}); + } else { + auto add = CreateBaseOp(graph, node, "popart_add", + {cast->outputs[0], bias->outputs[0]}, {}, {}); + result = CreateBaseOp(graph, node, "popart_mul", + {add->outputs[0], scale->outputs[0]}, {}, {}); + } + } + } + auto result_after_cast = + CreateCast(graph, node, result->outputs, node->outputs, + static_cast(data_type_)); + return result_after_cast; +} + +Node *cross_entropy2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto ignoreIndex = BOOST_GET_CONST(int, op->GetAttr("ignore_index")); + Node *new_cast = nullptr; + if (GetInputVarNode("Label", node)->Var()->GetDataType() == + framework::proto::VarType::INT32) { + new_cast = GetInputVarNode("Label", node); + } else { + auto new_cast = CreateCast(graph, node, {GetInputVarNode("Label", node)}, + {}, framework::proto::VarType::INT32); + new_cast = new_cast->outputs[0]; + } + auto label_shape_ = GetInputVarNode("Label", node)->Var()->GetShape(); + if (label_shape_[label_shape_.size() - 1] != 1) { + auto log = CreateBaseOp(graph, node, "popart_log", + {GetInputVarNode("X", node)}, {}, {}); + return CreateBaseOp( + graph, node, "popart_nllloss_v2", {log->outputs[0], new_cast}, + {GetOutputVarNode("Y", node)}, + { + {"reduction", 2}, // popart::ReductionType::NoReduction + {"ignoreIndex", ignoreIndex}, + {"inputIsLogProbability", true}, + }); + } else { + std::vector new_shape_{label_shape_[0]}; + auto const_before_loss = CreateBaseOp( + graph, node, "popart_constant", {}, {}, + {{"value", new_shape_}, + {"dims", + std::vector{static_cast(new_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}); + + auto reshape_before_loss = + CreateBaseOp(graph, node, "popart_reshape", + {new_cast, const_before_loss->outputs[0]}, {}, {}); + + auto log = CreateBaseOp(graph, node, "popart_log", + {GetInputVarNode("X", node)}, {}, {}); + auto nllloss = CreateBaseOp( + graph, node, "popart_nllloss_v2", + {log->outputs[0], reshape_before_loss->outputs[0]}, {}, + { + {"reduction", 2}, // popart::ReductionType::NoReduction + {"ignoreIndex", ignoreIndex}, + {"inputIsLogProbability", true}, + }); + + auto const_after_loss = CreateBaseOp( + graph, node, "popart_constant", {}, {}, + {{"value", label_shape_}, + {"dims", + std::vector{static_cast(label_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}); + + auto reshape_after_loss = + CreateBaseOp(graph, node, "popart_reshape", + {nllloss->outputs[0], const_after_loss->outputs[0]}, + {GetOutputVarNode("Y", node)}, {}); + return reshape_after_loss; + } +} + +Node *cumsum_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto exclusive = BOOST_GET_CONST(bool, op->GetAttr("exclusive")); + int64_t popart_exclusive = 1 ? exclusive : 0; + auto reverse = BOOST_GET_CONST(bool, op->GetAttr("reverse")); + int64_t popart_reverse = 1 ? 
reverse : 0; + auto axis = BOOST_GET_CONST(int, op->GetAttr("axis")); + auto axis_node = + CreateConst(graph, node, {}, {}, {{"value", std::vector{axis}}, + {"dims", std::vector{1}}, + {"dtype", ONNXDataType::INT64}}); + return CreateBaseOp( + graph, node, "popart_cumsum", + {GetInputVarNode("X", node), axis_node->outputs[0]}, + {GetOutputVarNode("Out", node)}, + {{"exclusive", popart_exclusive}, {"reverse", popart_reverse}}); +} + +Node *matmul_v2_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto transpose_x = BOOST_GET_CONST(bool, op->GetAttr("trans_x")); + auto transpose_y = BOOST_GET_CONST(bool, op->GetAttr("trans_y")); + auto x_shape = GetInputVarNode("X", node)->Var()->GetShape(); + auto y_shape = GetInputVarNode("Y", node)->Var()->GetShape(); + + std::vector perm; + int x_rank = x_shape.size(); + if (x_rank == 1) { + perm = std::vector{0}; + } else if (x_rank == 2) { + perm = std::vector{1, 0}; + } else if (x_rank == 3) { + perm = std::vector{0, 2, 1}; + } else if (x_rank == 4) { + perm = std::vector{0, 1, 3, 2}; + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op matmul with input rank == %d", x_rank)); + } + + Node *x_node = GetInputVarNode("X", node); + Node *y_node = GetInputVarNode("Y", node); + + if (transpose_x) { + x_node = CreateBaseOp(graph, node, "popart_transpose", + {GetInputVarNode("X", node)}, {}, {{"perm", perm}}); + x_node = x_node->outputs[0]; + } + if (transpose_y) { + y_node = CreateBaseOp(graph, node, "popart_transpose", + {GetInputVarNode("Y", node)}, {}, {{"perm", perm}}); + y_node = y_node->outputs[0]; + } + + return CreateBaseOp(graph, node, "popart_matmul", {x_node, y_node}, + node->outputs); +} + +Node *arg_max_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto axis = BOOST_GET_CONST(int64_t, op->GetAttr("axis")); + return CreateBaseOp(graph, node, "popart_argmax", + {GetInputVarNode("X", node)}, + {GetOutputVarNode("Out", node)}, + {{"axis", axis}, {"keepdims", int64_t{0}}}); +} + +REGISTER_HANDLER(mean, mean_handler); +REGISTER_HANDLER(pow, pow_handler); +REGISTER_HANDLER(mul, mul_handler); +REGISTER_HANDLER(matmul, matmul_handler); +REGISTER_HANDLER(sum, sum_handler); +REGISTER_HANDLER(softmax, softmax_handler); +REGISTER_HANDLER(scale, scale_handler); +REGISTER_HANDLER(cross_entropy2, cross_entropy2_handler); +REGISTER_HANDLER(cumsum, cumsum_handler); +REGISTER_HANDLER(matmul_v2, matmul_v2_handler); +REGISTER_HANDLER(arg_max, arg_max_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc index e69de29bb2..b741200010 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/nn_ops.cc @@ -0,0 +1,312 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
+ +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *conv2d_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto dilations_ = BOOST_GET_CONST(std::vector, op->GetAttr("dilations")); + auto dilations = std::vector{dilations_.begin(), dilations_.end()}; + auto group_ = BOOST_GET_CONST(int, op->GetAttr("groups")); + auto pads_ = BOOST_GET_CONST(std::vector, op->GetAttr("paddings")); + if (pads_.size() == 2) { + pads_.push_back(pads_[0]); + pads_.push_back(pads_[1]); + } + auto pads = std::vector{pads_.begin(), pads_.end()}; + auto stride_ = BOOST_GET_CONST(std::vector, op->GetAttr("strides")); + auto stride = std::vector{stride_.begin(), stride_.end()}; + if (op->HasInput("Bias") && !op->Input("Bias").empty()) { + return CreateConv( + graph, node, + { + GetInputVarNode("Input", node), GetInputVarNode("Filter", node), + GetInputVarNode("Bias", node), + }, + node->outputs, dilations, group_, {}, pads, stride); + } else { + return CreateConv( + graph, node, + { + GetInputVarNode("Input", node), GetInputVarNode("Filter", node), + }, + node->outputs, dilations, group_, {}, pads, stride); + } +} + +Node *batch_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + std::vector inputs; + inputs.push_back(GetInputVarNode("X", node)); + inputs.push_back(GetInputVarNode("Scale", node)); + inputs.push_back(GetInputVarNode("Bias", node)); + inputs.push_back(GetInputVarNode("Mean", node)); + inputs.push_back(GetInputVarNode("Variance", node)); + int64_t num_outputs = 1; + std::vector outputs; + auto is_test_type = op->GetAttrType("is_test"); + bool is_test; + if (is_test_type == 0) { + // int + is_test = BOOST_GET_CONST(int, op->GetAttr("is_test")); + } else { + // bool + is_test = BOOST_GET_CONST(bool, op->GetAttr("is_test")); + } + outputs.push_back(GetOutputVarNode("Y", node)); + if (!is_test) { + outputs.push_back(GetOutputVarNode("MeanOut", node)); + outputs.push_back(GetOutputVarNode("VarianceOut", node)); + outputs.push_back(GetOutputVarNode("SavedMean", node)); + outputs.push_back(GetOutputVarNode("SavedVariance", node)); + num_outputs = 5; + } + // outputs.push_back(GetOutputVarNode("ReserveSpace", node)); + auto momentum = BOOST_GET_CONST(float, op->GetAttr("momentum")); + auto epsilon = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + // data_layout + return CreateBaseOp(graph, node, "popart_batchnormalization", inputs, outputs, + { + {"momentum", momentum}, + {"epsilon", epsilon}, + {"num_outputs", num_outputs}, + }); +} + +Node *pool2d_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto pooling_type = BOOST_GET_CONST(std::string, op->GetAttr("pooling_type")); + auto global_pooling = BOOST_GET_CONST(bool, op->GetAttr("global_pooling")); + if (global_pooling) { + if (pooling_type == "max") { + return CreateBaseOp(graph, node, "popart_globalmaxpool", node->inputs, + node->outputs); + } else if (pooling_type == "avg") { + return CreateBaseOp(graph, node, "popart_globalaveragepool", node->inputs, + node->outputs); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op pool2d with unkonwn pooling_type: %s", pooling_type)); + } + } + if (op->HasAttr("padding_algorithm")) { + auto padding_algorithm = + BOOST_GET_CONST(std::string, op->GetAttr("padding_algorithm")); + if (padding_algorithm != 
"EXPLICIT") { + PADDLE_THROW(platform::errors::InvalidArgument( + "op pool2d with unkonwn padding_algorithm: %s", padding_algorithm)); + } + } + + auto ksize = BOOST_GET_CONST(std::vector, op->GetAttr("ksize")); + auto kernel_shape = std::vector{ksize.begin(), ksize.end()}; + auto ceil_mode_ = BOOST_GET_CONST(bool, op->GetAttr("ceil_mode")); + auto ceil_mode = int64_t(ceil_mode_ ? 1 : 0); + auto paddings = BOOST_GET_CONST(std::vector, op->GetAttr("paddings")); + auto pads = std::vector{paddings.begin(), paddings.end()}; + if (pads.size() == 2) { + pads.push_back(paddings[0]); + pads.push_back(paddings[1]); + } + auto strides_ = BOOST_GET_CONST(std::vector, op->GetAttr("strides")); + auto strides = std::vector{strides_.begin(), strides_.end()}; + if (pooling_type == "max") { + int64_t num_outputs = 1; + auto dilations = std::vector{}; + int64_t storage_order = 0; + return CreateBaseOp(graph, node, "popart_maxpool", node->inputs, + node->outputs, { + {"num_outputs", num_outputs}, + {"kernel_shape", kernel_shape}, + {"ceil_mode", ceil_mode}, + {"dilations", dilations}, + {"pads", pads}, + {"storage_order", storage_order}, + {"strides", strides}, + }); + } else if (pooling_type == "avg") { + int64_t count_include_pad = 0; + return CreateBaseOp(graph, node, "popart_averagepool", node->inputs, + node->outputs, + { + {"kernel_shape", kernel_shape}, + {"ceil_mode", ceil_mode}, + {"count_include_pad", count_include_pad}, + {"pads", pads}, + {"strides", strides}, + }); + } else { + PADDLE_THROW(platform::errors::InvalidArgument( + "op pool2d with unkonwn pooling_type: %s", pooling_type)); + } +} + +Node *group_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + auto groups_ = BOOST_GET_CONST(int, op->GetAttr("groups")); + auto groups = int64_t{groups_}; + auto attrs_ = AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups}}; + + std::vector inputs_ = {GetInputVarNode("X", node), + GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}; + std::vector outputs_ = {GetOutputVarNode("Y", node), + GetOutputVarNode("Mean", node), + GetOutputVarNode("Variance", node)}; + return CreateBaseOp(graph, node, "popart_groupnormalization_v2", inputs_, + outputs_, attrs_); +} + +Node *instance_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + auto attrs_ = AttributeMap{{"epsilon", epsilon_}}; + + std::vector inputs_ = {GetInputVarNode("X", node), + GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}; + std::vector outputs_ = {GetOutputVarNode("Y", node)}; + return CreateBaseOp(graph, node, "popart_instancenormalization", inputs_, + outputs_, attrs_); +} + +Node *layer_norm_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto begin_norm_axis_ = BOOST_GET_CONST(int, op->GetAttr("begin_norm_axis")); + auto input_shape_ = GetInputVarNode("X", node)->Var()->GetShape(); + auto epsilon_ = BOOST_GET_CONST(float, op->GetAttr("epsilon")); + int64_t groups_ = 1; + + auto groupnorm_attrs_ = + AttributeMap{{"epsilon", epsilon_}, {"num_groups", groups_}}; + + if (input_shape_.size() == 2) { + return CreateBaseOp( + graph, node, "popart_groupnormalization_v2", + {GetInputVarNode("X", node), GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}, + {GetOutputVarNode("Y", node), GetOutputVarNode("Mean", node), + GetOutputVarNode("Variance", node)}, + groupnorm_attrs_); + } + + std::vector norm_shape_{1, 1}; + for (int 
i = 0; i < input_shape_.size(); i++) { + if (i < begin_norm_axis_) { + norm_shape_[0] *= input_shape_[i]; + } else { + norm_shape_[1] *= input_shape_[i]; + } + } + + auto attrs1 = AttributeMap{ + {"value", norm_shape_}, + {"dims", std::vector{static_cast(norm_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}; + auto reshape1_const = + CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs1); + auto new_node_reshape1 = CreateBaseOp( + graph, node, "popart_reshape", + {GetInputVarNode("X", node), reshape1_const->outputs[0]}, {}, {}); + + auto out_Y_ = MakeVarNode(graph, node); + CreateBaseOp(graph, node, "popart_groupnormalization_v2", + {new_node_reshape1->outputs[0], GetInputVarNode("Scale", node), + GetInputVarNode("Bias", node)}, + {out_Y_, GetOutputVarNode("Mean", node), + GetOutputVarNode("Variance", node)}, + groupnorm_attrs_); + + auto attrs2 = AttributeMap{ + {"value", input_shape_}, + {"dims", std::vector{static_cast(input_shape_.size())}}, + {"dtype", ONNXDataType::INT64}}; + auto reshape2_const = + CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs2); + auto new_node_reshape2 = CreateBaseOp(graph, node, "popart_reshape", + {out_Y_, reshape2_const->outputs[0]}, + {GetOutputVarNode("Y", node)}, {}); + return new_node_reshape2; +} + +Node *dropout_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto dropout_prob_ = BOOST_GET_CONST(float, op->GetAttr("dropout_prob")); + auto dropout_implementation_ = + BOOST_GET_CONST(std::string, op->GetAttr("dropout_implementation")); + auto is_test_type_ = op->GetAttrType("is_test"); + bool is_test_; + if (is_test_type_ == 0) { + // int + is_test_ = BOOST_GET_CONST(int, op->GetAttr("is_test")); + } else { + // bool + is_test_ = BOOST_GET_CONST(bool, op->GetAttr("is_test")); + } + + if (is_test_) { + if (dropout_implementation_ == "upscale_in_train") { + return CreateBaseOp(graph, node, "popart_identity", + {GetInputVarNode("X", node)}, + {GetOutputVarNode("Out", node)}, {}); + } else if (dropout_implementation_ == "downgrade_in_infer") { + auto scale = + CreateConst(graph, node, {}, {}, + {{"value", std::vector{1 - dropout_prob_}}, + {"dims", std::vector{1}}, + {"dtype", GetOutputVarDtype(node)}}); + return CreateBaseOp(graph, node, "popart_mul", + {GetInputVarNode("X", node), scale->outputs[0]}, + {GetOutputVarNode("Out", node)}, {}); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("Invalid dropout_implementation")); + } + } else { + if (dropout_implementation_ == "upscale_in_train") { + auto attrs_ = + AttributeMap{{"num_outputs", (int64_t)1}, {"ratio", dropout_prob_}}; + return CreateBaseOp(graph, node, "popart_dropout", + {GetInputVarNode("X", node)}, + {GetOutputVarNode("Out", node)}, attrs_); + } else if (dropout_implementation_ == "downgrade_in_infer") { + PADDLE_THROW(platform::errors::InvalidArgument( + "Do not support downgrade_in_infer with training")); + } else { + PADDLE_THROW( + platform::errors::InvalidArgument("Invalid dropout_implementation")); + } + } +} + +REGISTER_HANDLER(pool2d, pool2d_handler); +REGISTER_HANDLER(batch_norm, batch_norm_handler); +REGISTER_HANDLER(group_norm, group_norm_handler); +REGISTER_HANDLER(instance_norm, instance_norm_handler); +REGISTER_HANDLER(layer_norm, layer_norm_handler); +REGISTER_HANDLER(conv2d, conv2d_handler); +REGISTER_HANDLER(dropout, dropout_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc 
b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc index e69de29bb2..3ec1999edc 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.cc @@ -0,0 +1,217 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" + +namespace paddle { +namespace platform { +namespace ipu { + +// singleton +static int var_count = 0; +static int op_count = 0; + +const std::string GenerateVarName() { + return std::string("_gen_var_") + std::to_string(var_count++); +} + +const std::string GenerateOpName() { + return std::string("_gen_op_") + std::to_string(op_count++); +} + +const std::string CreateOpIdentifyId(Node *node) { + // format: + // if has custom op_namescope: + // {op_namescope}/op_type/_gen_* + // else: + // {op_type}/{out_var0}/{out_var1}/.../_gen_* + // this name will be used as op name when exporting onnx model from popart + auto op_type = node->Name(); + std::string op_namescope; + if (node->Op()->HasAttr("op_namescope")) { + op_namescope = + BOOST_GET_CONST(std::string, node->Op()->GetAttr("op_namescope")); + } else { + op_namescope = "/"; + } + + if (op_namescope != "/") { + return {op_namescope + op_type + "/" + GenerateOpName()}; + } else { + std::string op_out = ""; + for (auto *out_node : node->outputs) { + op_out += "/"; + op_out += out_node->Name(); + } + return {op_type + op_out + "/" + GenerateOpName()}; + } +} + +Node *MakeVarNode(Graph *graph, Node *node) { + auto var_name = GenerateVarName(); + auto var_desc = std::make_unique(var_name); + + auto var = graph->CreateVarNode(var_desc.get()); + return var; +} + +Node *MakeOpNode(Graph *graph, Node *node, const std::string &type, + const std::vector &inputs, + const std::vector &outputs) { + auto op_desc = std::make_unique(); + op_desc->SetType(type); + auto op = graph->CreateOpNode(op_desc.get()); + + for (auto *in : inputs) { + ConnectNodes(in, op); + } + if (outputs.empty()) { + auto var = MakeVarNode(graph, node); + ConnectNodes(op, var); + } else { + for (auto *out : outputs) { + ConnectNodes(op, out); + } + } + + // i/o + std::vector input_names; + for (auto node : op->inputs) { + input_names.push_back(node->Name()); + } + op->Op()->SetInput("__inputs__", input_names); + std::vector output_names; + for (auto node : op->outputs) { + output_names.push_back(node->Name()); + } + op->Op()->SetOutput("__outputs__", output_names); + op->Op()->Flush(); + + return op; +} + +Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type, + const std::vector &inputs, + const std::vector &outputs, + const AttributeMap &attrs) { + auto new_node = MakeOpNode(graph, node, type, inputs, outputs); + if (!attrs.empty()) { + new_node->Op()->SetAttrMap(attrs); + } + // deal special attr + if (!new_node->Op()->HasAttr(sIpuIndexAttr)) { + CopyOpAttr(sIpuIndexAttr, node->Op(), 
new_node->Op()); + } + if (!new_node->Op()->HasAttr(sIpuStageAttr)) { + CopyOpAttr(sIpuStageAttr, node->Op(), new_node->Op()); + } + if (node->Op()->HasAttr(sMatmulSerializeFactor)) { + CopyOpAttr(sMatmulSerializeFactor, node->Op(), new_node->Op()); + } + if (node->Op()->HasAttr(sMatmulSerializeMode)) { + CopyOpAttr(sMatmulSerializeMode, node->Op(), new_node->Op()); + } + { + new_node->Op()->SetAttr(sOpIdentifyIdAttr, CreateOpIdentifyId(node)); + new_node->Op()->Flush(); + } + + return new_node; +} + +Node *CreateConst(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const AttributeMap &attrs) { + return CreateBaseOp(graph, node, "popart_constant", inputs, outputs, attrs); +} + +Node *CreateCast(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, const int otype) { + auto to = VarType2PopStr(otype); + return CreateBaseOp(graph, node, "popart_cast", inputs, outputs, + {{"to", to}}); +} + +Node *CreateGemm(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, int64_t transA, + int64_t transB, float alpha, float beta) { + return CreateBaseOp(graph, node, "popart_gemm", inputs, outputs, + { + {"alpha", alpha}, + {"beta", beta}, + {"transA", transA}, + {"transB", transB}, + }); +} + +Node *CreateReshape(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const std::vector &oshape) { + auto attr = AttributeMap{ + {"value", oshape}, + {"dims", std::vector{static_cast(oshape.size())}}, + {"dtype", ONNXDataType::INT64}}; + auto new_node_const = + CreateBaseOp(graph, node, "popart_constant", {}, {}, attr); + auto new_node_reshape = + CreateBaseOp(graph, node, "popart_reshape", + {inputs[0], new_node_const->outputs[0]}, outputs); + return new_node_reshape; +} + +Node *CreateConv(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const std::vector &dilations, int64_t group, + const std::vector &kernel_shape, + const std::vector &pads, + const std::vector &strides) { + auto attrs = AttributeMap{ + {"dilations", dilations}, {"group", group}, + {"kernel_shape", kernel_shape}, {"pads", pads}, + {"strides", strides}, + }; + return CreateBaseOp(graph, node, "popart_conv", inputs, outputs, attrs); +} + +Node *CreateSoftmaxOpset11(Graph *graph, Node *node, + const std::vector &inputs, + const std::vector &outputs, int64_t axis) { + PADDLE_ENFORCE_EQ(inputs.size(), 1, platform::errors::InvalidArgument( + "Softmax op only support one input")); + auto x_shape = inputs[0]->Var()->GetShape(); + int x_rank = x_shape.size(); + if (axis < 0) { + axis = axis + x_rank; + } + if (axis == x_rank - 1) { + return CreateBaseOp(graph, node, "popart_softmax", inputs, outputs, + {{"axis", int64_t{-1}}}); + } else { + auto perm = std::vector(x_rank); + std::iota(perm.begin(), perm.end(), 0); + perm[x_rank - 1] = axis; + perm[axis] = x_rank - 1; + auto new_transpose_pre = CreateBaseOp(graph, node, "popart_transpose", + inputs, {}, {{"perm", perm}}); + auto new_softmax = + CreateBaseOp(graph, node, "popart_softmax", new_transpose_pre->outputs, + {}, {{"axis", int64_t{-1}}}); + return CreateBaseOp(graph, node, "popart_transpose", new_softmax->outputs, + outputs, {{"perm", perm}}); + } +} + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h index e69de29bb2..de3788e437 100644 --- 
a/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h @@ -0,0 +1,86 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#pragma once + +#include "paddle/fluid/platform/device/ipu/ipu_names.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" + +using paddle::framework::AttributeMap; +using paddle::framework::Attribute; + +namespace paddle { +namespace platform { +namespace ipu { + +template +AttributeMap MakeConstAttrMap(std::vector value, std::vector dims, + int dtype) { + return AttributeMap{{"value", value}, {"dims", dims}, {"dtype", dtype}}; +} + +template +AttributeMap MakeConstAttrMapFromValue(T v, std::vector dims, + int dtype) { + size_t size = 1; + for (auto &dim : dims) { + size *= dim; + } + return MakeConstAttrMap(std::vector(size, v), dims, dtype); +} + +const std::string GenerateVarName(); +const std::string CreateOpIdentifyId(Node *node); + +Node *MakeVarNode(Graph *graph, Node *node); +Node *MakeOpNode(Graph *graph, Node *node, const std::string &type, + const std::vector &inputs, + const std::vector &outputs); + +Node *CreateBaseOp(Graph *graph, Node *node, const std::string &type, + const std::vector &inputs, + const std::vector &outputs, + const AttributeMap &attrs = {}); + +Node *CreateConst(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const AttributeMap &attrs); + +// otype is framework::proto::VarType::Type +Node *CreateCast(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, const int otype); + +Node *CreateGemm(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, int64_t transA = 0, + int64_t transB = 0, float alpha = 1.0f, float beta = 1.0f); + +Node *CreateReshape(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const std::vector &oshape); + +Node *CreateConv(Graph *graph, Node *node, const std::vector &inputs, + const std::vector &outputs, + const std::vector &dilations = {1, 1}, + int64_t group = 1, + const std::vector &kernel_shape = {}, + const std::vector &pads = {0, 0, 0, 0}, + const std::vector &strides = {1, 1}); + +Node *CreateSoftmaxOpset11(Graph *graph, Node *node, + const std::vector &inputs, + const std::vector &outputs, int64_t axis); + +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc new file mode 100644 index 0000000000..0919afef4d --- /dev/null +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/other_ops.cc @@ -0,0 +1,65 @@ +// Copyright (c) 2022 PaddlePaddle Authors. All Rights Reserved. 
+// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h" +#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h" +#include "paddle/fluid/platform/enforce.h" + +namespace paddle { +namespace platform { +namespace ipu { +namespace { + +Node *custom_op_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto attrs = op->GetAttrMap(); + attrs.insert({"__op_type", node->Op()->Type()}); + auto new_node = CreateBaseOp(graph, node, "popart_custom_op", node->inputs, + node->outputs, attrs); + return new_node; +} + +Node *print_handler(Graph *graph, Node *node) { + auto *op = node->Op(); + auto print_phase = BOOST_GET_CONST(std::string, op->GetAttr("print_phase")); + int64_t print_gradient = 0; + if (print_phase != "forward") { + print_gradient = 1; + } + auto title = BOOST_GET_CONST(std::string, op->GetAttr("message")); + if (title.empty()) { + title = GetInputVarNode("In", node)->Var()->Name(); + } + auto attrs = + AttributeMap{{"print_gradient", print_gradient}, {"title", title}}; + return CreateBaseOp(graph, node, "popart_printtensor", node->inputs, + node->outputs, attrs); +} + +Node *popart_optimizer_handler(Graph *graph, Node *node) { return nullptr; } + +Node *checkpointoutput_handler(Graph *graph, Node *node) { + return CreateBaseOp(graph, node, "popart_checkpointoutput", node->inputs, + node->outputs); +} + +REGISTER_HANDLER(custom_op, custom_op_handler); +REGISTER_HANDLER(print, print_handler); +REGISTER_HANDLER(popart_optimizer, popart_optimizer_handler); +REGISTER_HANDLER(checkpointoutput, checkpointoutput_handler); + +} // namespace +} // namespace ipu +} // namespace platform +} // namespace paddle diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc index e69de29bb2..662660c23b 100644 --- a/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc +++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc @@ -0,0 +1,95 @@ +// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. 
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc
index e69de29bb2..662660c23b 100644
--- a/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/search_ops.cc
@@ -0,0 +1,95 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+namespace {
+
+Node *topk_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto attrs = AttributeMap{};
+
+  int axis_ = -1;
+  if (op->HasAttr("axis")) {
+    axis_ = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  }
+  if (axis_ == -1) {
+    auto shape = GetInputVarNode("X", node)->Var()->GetShape();
+    int rank = shape.size();
+    if (rank < 1) {
+      PADDLE_THROW(platform::errors::InvalidArgument(
+          "The rank of the input of top_k should be at least 1."));
+    }
+    axis_ = rank - 1;
+  }
+  int64_t axis = int64_t{axis_};
+  attrs.emplace("axis", axis);
+
+  bool largest = true;
+  if (op->HasAttr("largest")) {
+    largest = BOOST_GET_CONST(bool, op->GetAttr("largest"));
+  }
+  if (largest) {
+    // defaults to 1, largest values
+    attrs.emplace("largest", 1);
+  } else {
+    attrs.emplace("largest", 0);
+  }
+
+  bool sorted = true;
+  if (op->HasAttr("sorted")) {
+    sorted = BOOST_GET_CONST(bool, op->GetAttr("sorted"));
+  }
+  if (sorted) {
+    // defaults to 1, sorted results
+    attrs.emplace("sorted", 1);
+  } else {
+    attrs.emplace("sorted", 0);
+  }
+
+  Node *var_x = GetInputVarNode("X", node);
+  Node *var_k = nullptr;
+  if (op->HasInput("K") && !op->Input("K").empty()) {
+    var_k = GetInputVarNode("K", node);
+  } else {
+    auto k = BOOST_GET_CONST(int, op->GetAttr("k"));
+    auto *op_k =
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{k}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::INT64}});
+    var_k = op_k->outputs[0];
+  }
+
+  auto *var_i = MakeVarNode(graph, node);
+  CreateBaseOp(graph, node, "popart_topk", {var_x, var_k},
+               {GetOutputVarNode("Out", node), var_i},
+               {{"axis", int64_t{axis}},
+                {"largest", int64_t{largest}},
+                {"sorted", int64_t{sorted}}});
+  return CreateCast(graph, node, {var_i}, {GetOutputVarNode("Indices", node)},
+                    static_cast<int>(framework::proto::VarType::INT32));
+}
+
+REGISTER_HANDLER(top_k, topk_handler);
+REGISTER_HANDLER(top_k_v2, topk_handler);
+
+} // namespace
+} // namespace ipu
+} // namespace platform
+} // namespace paddle
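Paddle's top_k may omit "axis" (or pass -1), in which case the handler above
falls back to the last dimension of the input. A standalone sketch of that
normalization (illustrative only, not part of the patch):

    #include <cstdint>
    #include <stdexcept>
    #include <vector>

    // Resolve a possibly-missing or -1 axis to a concrete dimension index.
    int64_t NormalizeTopKAxis(int axis, const std::vector<int64_t> &x_shape) {
      if (x_shape.empty()) {
        throw std::invalid_argument("top_k input must have rank >= 1");
      }
      return axis == -1 ? static_cast<int64_t>(x_shape.size()) - 1 : axis;
    }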
diff --git a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
index e69de29bb2..296668890e 100644
--- a/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
+++ b/paddle/fluid/platform/device/ipu/popart_canonicalization/tensor_ops.cc
@@ -0,0 +1,522 @@
+// Copyright (c) 2021 PaddlePaddle Authors. All Rights Reserved.
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/canonicalization_utils.h"
+#include "paddle/fluid/platform/device/ipu/popart_canonicalization/op_builder.h"
+#include "paddle/fluid/platform/enforce.h"
+
+namespace paddle {
+namespace platform {
+namespace ipu {
+namespace {
+
+Node *fill_constant_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  if (op->HasInput("ShapeTensor") && !op->Input("ShapeTensor").empty()) {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("op fill_constant with ShapeTensor"));
+  }
+  auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+  auto dtype = VarType2OnnxDtype(dtype_);
+  auto dims = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
+  auto value_ = BOOST_GET_CONST(float, op->GetAttr("value"));
+  size_t size = 1;
+  for (auto &dim : dims) {
+    size *= dim;
+  }
+  Attribute value;
+  switch (dtype_) {
+    case framework::proto::VarType::FP32:
+      value = std::vector<float>(size, value_);
+      break;
+    case framework::proto::VarType::FP64:
+      value = std::vector<double>(size, value_);
+      break;
+    case framework::proto::VarType::INT32:
+      value = std::vector<int>(size, value_);
+      break;
+    case framework::proto::VarType::INT64:
+      value = std::vector<int64_t>(size, value_);
+      break;
+    default:
+      PADDLE_THROW(
+          platform::errors::Unimplemented("fill_constant dtype: %d", dtype_));
+  }
+  return CreateConst(graph, node, node->inputs, node->outputs,
+                     AttributeMap{
+                         {"value", value}, {"dims", dims}, {"dtype", dtype},
+                     });
+}
+
+Node *gaussian_random_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
+  auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+  auto dtype = VarType2OnnxDtype(dtype_);
+  auto mean = BOOST_GET_CONST(float, op->GetAttr("mean"));
+  auto scale = BOOST_GET_CONST(float, op->GetAttr("std"));
+  // the seed attribute does not take effect here
+  auto seed_ = BOOST_GET_CONST(int, op->GetAttr("seed"));
+  auto seed = static_cast<float>(seed_);
+  return CreateBaseOp(graph, node, "popart_randomnormal", node->inputs,
+                      node->outputs, {
+                                         {"shape", shape},
+                                         {"dtype", dtype},
+                                         {"mean", mean},
+                                         {"scale", scale},
+                                         {"seed", seed},
+                                     });
+}
+
+Node *uniform_random_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto shape = BOOST_GET_CONST(std::vector<int64_t>, op->GetAttr("shape"));
+  auto dtype_ = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+  auto dtype = VarType2OnnxDtype(dtype_);
+  auto high = BOOST_GET_CONST(float, op->GetAttr("max"));
+  auto low = BOOST_GET_CONST(float, op->GetAttr("min"));
+  // the seed attribute does not take effect here
+  auto seed_ = BOOST_GET_CONST(int, op->GetAttr("seed"));
+  auto seed = static_cast<float>(seed_);
+  return CreateBaseOp(graph, node, "popart_randomuniform", node->inputs,
+                      node->outputs, {
+                                         {"shape", shape},
+                                         {"dtype", dtype},
+                                         {"high", high},
+                                         {"low", low},
+                                         {"seed", seed},
+                                     });
+}
+
+Node *transpose_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+
+  auto axis_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axis"));
+  std::vector<int64_t> perm(axis_.begin(), axis_.end());
+  auto attrs = AttributeMap{{"perm", perm}};
+
+  auto new_node_transpose =
+      CreateBaseOp(graph, node, "popart_transpose", node->inputs,
+                   {GetOutputVarNode("Out", node)}, attrs);
+  return new_node_transpose;
+}
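fill_constant has to materialize the scalar "value" as a typed flat buffer
whose element type matches the requested dtype, which is what the switch above
does. A standalone sketch of the same dispatch (illustrative only; the enum is
a stand-in, not the framework's VarType):

    #include <cstdint>
    #include <stdexcept>
    #include <variant>
    #include <vector>

    enum class Dtype { kFP32, kINT32, kINT64 };
    using ConstValue = std::variant<std::vector<float>, std::vector<int>,
                                    std::vector<int64_t>>;

    // Broadcast `v` to `n` elements of the requested element type.
    ConstValue FillConstant(Dtype dtype, float v, size_t n) {
      switch (dtype) {
        case Dtype::kFP32:
          return std::vector<float>(n, v);
        case Dtype::kINT32:
          return std::vector<int>(n, static_cast<int>(v));
        case Dtype::kINT64:
          return std::vector<int64_t>(n, static_cast<int64_t>(v));
      }
      throw std::invalid_argument("unsupported dtype");
    }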
+
+Node *reshape_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto shape_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("shape"));
+  std::vector<int64_t> shape(shape_.begin(), shape_.end());
+  auto attrs = AttributeMap{
+      {"value", shape},
+      {"dims", std::vector<int64_t>{static_cast<int64_t>(shape.size())}},
+      {"dtype", ONNXDataType::INT64}};
+  auto new_node_const =
+      CreateBaseOp(graph, node, "popart_constant", {}, {}, attrs);
+
+  auto new_node_reshape =
+      CreateBaseOp(graph, node, "popart_reshape",
+                   {GetInputVarNode("X", node), new_node_const->outputs[0]},
+                   {GetOutputVarNode("Out", node)}, {});
+  return new_node_reshape;
+}
+
+Node *flatten2_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  return CreateBaseOp(
+      graph, node, "popart_flatten", {GetInputVarNode("X", node)},
+      {GetOutputVarNode("Out", node)}, {{"axis", int64_t(axis)}});
+}
+
+Node *gather_handler(Graph *graph, Node *node) {
+  auto new_node_gather =
+      CreateBaseOp(graph, node, "popart_gather",
+                   {GetInputVarNode("X", node), GetInputVarNode("Index", node)},
+                   {GetOutputVarNode("Out", node)}, {});
+  return new_node_gather;
+}
+
+Node *squeeze_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
+  auto input_shape_ = GetInputVarNode("X", node)->Var()->GetShape();
+
+  std::vector<int64_t> axes{axes_.begin(), axes_.end()};
+  if (axes_.empty()) {
+    for (size_t i = 0; i < input_shape_.size(); i++) {
+      if (input_shape_[i] == 1) {
+        axes.push_back(i);
+      }
+    }
+  }
+  auto new_node_squeeze =
+      CreateBaseOp(graph, node, "popart_squeeze", {GetInputVarNode("X", node)},
+                   {GetOutputVarNode("Out", node)}, {{"axes", axes}});
+
+  return new_node_squeeze;
+}
+
+Node *cast_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto otype = BOOST_GET_CONST(int, op->GetAttr("out_dtype"));
+  auto new_node_cast =
+      CreateCast(graph, node, node->inputs, node->outputs, otype);
+  return new_node_cast;
+}
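When squeeze2 carries no "axes" attribute, the handler above squeezes every
dimension of extent 1. A standalone sketch of that default (illustrative only,
not part of the patch):

    #include <cstdint>
    #include <vector>

    // With no axes given, squeeze removes all size-1 dimensions.
    std::vector<int64_t> DefaultSqueezeAxes(const std::vector<int64_t> &shape) {
      std::vector<int64_t> axes;
      for (size_t i = 0; i < shape.size(); ++i) {
        if (shape[i] == 1) axes.push_back(static_cast<int64_t>(i));
      }
      return axes;  // e.g. {1, 3, 1, 2} -> {0, 2}
    }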
+
+Node *lookup_table_op_handler(Graph *graph, Node *node,
+                              const std::string &type) {
+  auto *op = node->Op();
+  auto padding_idx_ = BOOST_GET_CONST(int64_t, op->GetAttr("padding_idx"));
+  auto w_shape_ = GetInputVarNode("W", node)->Var()->GetShape();
+  auto table_size_ = w_shape_[0];
+  auto emb_size_ = w_shape_[1];
+
+  Node *w_node;
+  if (padding_idx_ >= 0 && padding_idx_ < table_size_) {
+    std::vector<float> const_value_(emb_size_, 0);
+    std::vector<int64_t> const_shape_{1, emb_size_};
+    auto concat_const =
+        CreateConst(graph, node, {}, {}, {{"value", const_value_},
+                                          {"dims", const_shape_},
+                                          {"dtype", GetOutputVarDtype(node)}});
+    auto axes =
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::INT64}});
+    auto step =
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{1}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::INT64}});
+
+    auto left_start =
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<int64_t>{0}},
+                                          {"dims", std::vector<int64_t>{1}},
+                                          {"dtype", ONNXDataType::INT64}});
+    auto left_end =
+        CreateConst(graph, node, {}, {},
+                    {{"value", std::vector<int64_t>{padding_idx_}},
+                     {"dims", std::vector<int64_t>{1}},
+                     {"dtype", ONNXDataType::INT64}});
+
+    auto right_start =
+        CreateConst(graph, node, {}, {},
+                    {{"value", std::vector<int64_t>{padding_idx_ + 1}},
+                     {"dims", std::vector<int64_t>{1}},
+                     {"dtype", ONNXDataType::INT64}});
+    auto right_end =
+        CreateConst(graph, node, {}, {},
+                    {{"value", std::vector<int64_t>{table_size_}},
+                     {"dims", std::vector<int64_t>{1}},
+                     {"dtype", ONNXDataType::INT64}});
+
+    auto left_slice =
+        CreateBaseOp(graph, node, "popart_slice",
+                     {GetInputVarNode("W", node), left_start->outputs[0],
+                      left_end->outputs[0], axes->outputs[0], step->outputs[0]},
+                     {}, {});
+    auto right_slice = CreateBaseOp(
+        graph, node, "popart_slice",
+        {GetInputVarNode("W", node), right_start->outputs[0],
+         right_end->outputs[0], axes->outputs[0], step->outputs[0]},
+        {}, {});
+
+    if (padding_idx_ == 0) {
+      w_node = CreateBaseOp(graph, node, "popart_concat",
+                            {concat_const->outputs[0], right_slice->outputs[0]},
+                            {}, {{"axis", int64_t(0)}});
+      ClearNode(left_start);
+      ClearNode(left_end);
+      ClearNode(left_slice);
+    } else if (padding_idx_ == table_size_ - 1) {
+      w_node = CreateBaseOp(graph, node, "popart_concat",
+                            {left_slice->outputs[0], concat_const->outputs[0]},
+                            {}, {{"axis", int64_t{0}}});
+      ClearNode(right_start);
+      ClearNode(right_end);
+      ClearNode(right_slice);
+    } else {
+      w_node = CreateBaseOp(graph, node, "popart_concat",
+                            {left_slice->outputs[0], concat_const->outputs[0],
+                             right_slice->outputs[0]},
+                            {}, {{"axis", int64_t{0}}});
+    }
+    w_node = w_node->outputs[0];
+  } else {
+    w_node = GetInputVarNode("W", node);
+  }
+
+  // lookup_table and lookup_table_v2
+  auto ids = GetInputVarNode("Ids", node);
+  if (type == "v1") {
+    ids = CreateBaseOp(graph, node, "popart_squeeze",
+                       {GetInputVarNode("Ids", node)}, {},
+                       {{"axes", std::vector<int64_t>{-1}}});
+    ids = ids->outputs[0];
+  }
+
+  auto gather = CreateBaseOp(graph, node, "popart_gather", {w_node, ids},
+                             {GetOutputVarNode("Out", node)}, {});
+  return gather;
+}
+
+Node *lookup_table_handler(Graph *graph, Node *node) {
+  return lookup_table_op_handler(graph, node, "v1");
+}
+
+Node *lookup_table_v2_handler(Graph *graph, Node *node) {
+  return lookup_table_op_handler(graph, node, "v2");
+}
+
+Node *unsqueeze_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
+  std::vector<int64_t> axes{axes_.begin(), axes_.end()};
+  auto new_node_unsqueeze = CreateBaseOp(
+      graph, node, "popart_unsqueeze", {GetInputVarNode("X", node)},
+      {GetOutputVarNode("Out", node)}, {{"axes", axes}});
+
+  return new_node_unsqueeze;
+}
+
+Node *concat_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  int64_t axis_{BOOST_GET_CONST(int, op->GetAttr("axis"))};
+
+  auto new_node_concat =
+      CreateBaseOp(graph, node, "popart_concat", node->inputs, node->outputs,
+                   {{"axis", axis_}});
+  return new_node_concat;
+}
+
+Node *stack_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  int64_t axis_{BOOST_GET_CONST(int, op->GetAttr("axis"))};
+  std::vector<int64_t> axes_{axis_};
+
+  std::vector<Node *> unsqueeze_outputs_{};
+  for (auto input : node->inputs) {
+    auto new_unsqueeze_node = CreateBaseOp(graph, node, "popart_unsqueeze",
+                                           {input}, {}, {{"axes", axes_}});
+    unsqueeze_outputs_.push_back(new_unsqueeze_node->outputs[0]);
+    for (size_t i = 0; i < input->outputs.size(); ++i) {
+      if (input->outputs[i] == node) {
+        input->outputs[i] = new_unsqueeze_node;
+        break;
+      }
+    }
+  }
+  auto new_node_concat =
+      CreateBaseOp(graph, node, "popart_concat", unsqueeze_outputs_,
+                   {GetOutputVarNode("Y", node)}, {{"axis", axis_}});
+  return new_node_concat;
+}
+
+Node *shape_handler(Graph *graph, Node *node) {
+  auto new_node =
+      CreateBaseOp(graph, node, "popart_shape", node->inputs, node->outputs);
+  return new_node;
+}
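lookup_table's padding_idx row must read as all zeros, but a plain gather has
no such notion, so the handler above rebuilds the table as
concat(W[0:padding_idx], zeros(1, emb), W[padding_idx+1:]) before gathering.
A standalone sketch of that rebuild on plain vectors (illustrative only, not
part of the patch):

    #include <cstddef>
    #include <vector>

    using Row = std::vector<float>;

    // Return a copy of `table` whose `padding_idx` row is all zeros.
    std::vector<Row> ZeroPaddingRow(std::vector<Row> table, size_t padding_idx) {
      if (padding_idx < table.size()) {
        table[padding_idx].assign(table[padding_idx].size(), 0.0f);
      }
      return table;
    }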
+
+Node *slice_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  Node *starts = nullptr;
+  if (op->HasInput("StartsTensor") && !op->Input("StartsTensor").empty()) {
+    starts = GetInputVarNode("StartsTensor", node);
+  } else {
+    auto starts_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("starts"));
+    auto dim = int64_t(starts_.size());
+    auto attr = MakeConstAttrMap<int>(starts_, {dim}, ONNXDataType::INT32);
+    starts = CreateConst(graph, node, {}, {}, attr);
+    starts = starts->outputs[0];
+  }
+  Node *ends = nullptr;
+  if (op->HasInput("EndsTensor") && !op->Input("EndsTensor").empty()) {
+    ends = GetInputVarNode("EndsTensor", node);
+  } else {
+    auto ends_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("ends"));
+    auto dim = int64_t(ends_.size());
+    auto attr = MakeConstAttrMap<int>(ends_, {dim}, ONNXDataType::INT32);
+    ends = CreateConst(graph, node, {}, {}, attr);
+    ends = ends->outputs[0];
+  }
+  Node *axes = nullptr;
+  {
+    auto axes_ = BOOST_GET_CONST(std::vector<int>, op->GetAttr("axes"));
+    auto dim = int64_t(axes_.size());
+    auto attr = MakeConstAttrMap<int>(axes_, {dim}, ONNXDataType::INT32);
+    axes = CreateConst(graph, node, {}, {}, attr);
+  }
+
+  auto decrease_axis_ =
+      BOOST_GET_CONST(std::vector<int>, op->GetAttr("decrease_axis"));
+  auto input_shape_ = GetInputVarNode("Input", node)->Var()->GetShape();
+  auto output_shape_ = GetOutputVarNode("Out", node)->Var()->GetShape();
+  if (decrease_axis_.size() == 0) {
+    return CreateBaseOp(
+        graph, node, "popart_slice",
+        {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
+        node->outputs);
+  } else if (output_shape_ == std::vector<int64_t>{0} ||
+             input_shape_.size() > output_shape_.size()) {
+    auto slice = CreateBaseOp(
+        graph, node, "popart_slice",
+        {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]}, {},
+        {});
+    return CreateBaseOp(graph, node, "popart_squeeze", {slice->outputs[0]},
+                        {GetOutputVarNode("Out", node)},
+                        {{"axes", std::vector<int64_t>{decrease_axis_.begin(),
+                                                       decrease_axis_.end()}}});
+  } else {
+    return CreateBaseOp(
+        graph, node, "popart_slice",
+        {GetInputVarNode("Input", node), starts, ends, axes->outputs[0]},
+        node->outputs);
+  }
+}
+
+Node *expand_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  if (op->HasInput("expand_times_tensor") &&
+      !op->Input("expand_times_tensor").empty()) {
+    PADDLE_THROW(
+        platform::errors::Unimplemented("Expand op with expand_times_tensor"));
+  }
+
+  Node *expand_times = nullptr;
+  if (op->HasInput("ExpandTimes") && !op->Input("ExpandTimes").empty()) {
+    // cast to int64
+    expand_times =
+        CreateCast(graph, node, {GetInputVarNode("ExpandTimes", node)}, {},
+                   framework::proto::VarType::INT64);
+  } else {
+    auto expand_times_i32 =
+        BOOST_GET_CONST(std::vector<int>, op->GetAttr("expand_times"));
+    auto expand_times_ =
+        std::vector<int64_t>{expand_times_i32.begin(), expand_times_i32.end()};
+    auto dim = int64_t(expand_times_.size());
+    auto attr =
+        MakeConstAttrMap<int64_t>(expand_times_, {dim}, ONNXDataType::INT64);
+    expand_times = CreateConst(graph, node, {}, {}, attr);
+  }
+  auto new_node = CreateBaseOp(
+      graph, node, "popart_tile",
+      {GetInputVarNode("X", node), expand_times->outputs[0]}, node->outputs);
+  return new_node;
+}
+
+Node *assign_handler(Graph *graph, Node *node) {
+  return CreateBaseOp(graph, node, "popart_identity",
+                      {GetInputVarNode("X", node)},
+                      {GetOutputVarNode("Out", node)}, {});
+}
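Paddle's slice records "decrease_axis" when a dimension is indexed by an
integer (e.g. x[2, :]) and should disappear from the result, whereas the
popart slice keeps it; the handler above therefore appends a popart_squeeze in
that case. A standalone sketch of the resulting shape change (illustrative
only, not part of the patch):

    #include <algorithm>
    #include <cstdint>
    #include <vector>

    // Shape after removing the dimensions listed in decrease_axis.
    std::vector<int64_t> ShapeAfterDecrease(std::vector<int64_t> shape,
                                            std::vector<int64_t> decrease_axis) {
      std::sort(decrease_axis.rbegin(), decrease_axis.rend());
      for (auto axis : decrease_axis) {
        if (axis >= 0 && axis < static_cast<int64_t>(shape.size())) {
          shape.erase(shape.begin() + axis);
        }
      }
      return shape;  // e.g. {1, 4, 8} with decrease_axis {0} -> {4, 8}
    }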
+
+Node *fill_any_like_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto value = BOOST_GET_CONST(float, op->GetAttr("value"));
+  auto x_shape = GetInputVarNode("X", node)->Var()->GetShape();
+  auto dtype = BOOST_GET_CONST(int, op->GetAttr("dtype"));
+  auto x_dtype = static_cast<framework::proto::VarType::Type>(dtype);
+  size_t size = 1;
+  for (auto &dim : x_shape) {
+    size *= dim;
+  }
+
+  Attribute out_value;
+  switch (x_dtype) {
+    case framework::proto::VarType::FP32:
+      out_value = std::vector<float>(size, value);
+      break;
+    case framework::proto::VarType::FP64:
+      out_value = std::vector<double>(size, value);
+      break;
+    case framework::proto::VarType::INT32:
+      out_value = std::vector<int>(size, value);
+      break;
+    case framework::proto::VarType::INT64:
+      out_value = std::vector<int64_t>(size, value);
+      break;
+    case framework::proto::VarType::BOOL:
+      out_value = std::vector<bool>(size, value);
+      break;
+    default:
+      PADDLE_THROW(
+          platform::errors::Unimplemented("fill_any_like dtype: %d", x_dtype));
+  }
+  return CreateConst(graph, node, node->inputs, node->outputs,
+                     AttributeMap{
+                         {"value", out_value},
+                         {"dims", x_shape},
+                         {"dtype", VarType2OnnxDtype(dtype)},
+                     });
+}
+
+Node *one_hot_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto depth = BOOST_GET_CONST(int, op->GetAttr("depth"));
+  auto allow_out_of_range =
+      BOOST_GET_CONST(bool, op->GetAttr("allow_out_of_range"));
+  if (allow_out_of_range) {
+    PADDLE_THROW(platform::errors::Unimplemented(
+        "Do not support allow_out_of_range=True"));
+  } else {
+    auto depth_tensor = CreateConst(graph, node, {}, {},
+                                    {{"value", std::vector<int64_t>{depth}},
+                                     {"dims", std::vector<int64_t>{1}},
+                                     {"dtype", ONNXDataType::INT64}});
+    auto value_tensor =
+        CreateConst(graph, node, {}, {}, {{"value", std::vector<float>{0, 1}},
+                                          {"dims", std::vector<int64_t>{2}},
+                                          {"dtype", ONNXDataType::FLOAT}});
+    return CreateBaseOp(graph, node, "popart_onehot",
+                        {GetInputVarNode("X", node), depth_tensor->outputs[0],
+                         value_tensor->outputs[0]},
+                        {GetOutputVarNode("Out", node)},
+                        {{"axis", int64_t{-1}}});
+  }
+}
+
+Node *split_handler(Graph *graph, Node *node) {
+  auto *op = node->Op();
+  auto axis = BOOST_GET_CONST(int, op->GetAttr("axis"));
+  auto sections = BOOST_GET_CONST(std::vector<int>, op->GetAttr("sections"));
+  return CreateBaseOp(
+      graph, node, "popart_split", {GetInputVarNode("X", node)}, node->outputs,
+      {{"num_outputs", int64_t(sections.size())},
+       {"axis", int64_t(axis)},
+       {"split", std::vector<int64_t>{sections.begin(), sections.end()}}});
+}
+
+REGISTER_HANDLER(fill_constant, fill_constant_handler);
+REGISTER_HANDLER(gaussian_random, gaussian_random_handler);
+REGISTER_HANDLER(uniform_random, uniform_random_handler);
+REGISTER_HANDLER(transpose2, transpose_handler);
+REGISTER_HANDLER(reshape2, reshape_handler);
+REGISTER_HANDLER(flatten2, flatten2_handler);
+REGISTER_HANDLER(gather, gather_handler);
+REGISTER_HANDLER(squeeze2, squeeze_handler);
+REGISTER_HANDLER(cast, cast_handler);
+REGISTER_HANDLER(lookup_table, lookup_table_handler);
+REGISTER_HANDLER(unsqueeze2, unsqueeze_handler);
+REGISTER_HANDLER(concat, concat_handler);
+REGISTER_HANDLER(stack, stack_handler);
+REGISTER_HANDLER(shape, shape_handler);
+REGISTER_HANDLER(slice, slice_handler);
+REGISTER_HANDLER(expand, expand_handler);
+REGISTER_HANDLER(assign, assign_handler);
+REGISTER_HANDLER(fill_any_like, fill_any_like_handler);
+REGISTER_HANDLER(lookup_table_v2, lookup_table_v2_handler);
+REGISTER_HANDLER(split, split_handler);
+REGISTER_HANDLER(one_hot, one_hot_handler);
+
+} // namespace
+} // namespace ipu
+} // namespace platform
+} // namespace paddle
-- 
Gitee
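All of the files in this patch expose their handlers through REGISTER_HANDLER,
i.e. a name-to-function registry keyed by the Paddle op type. A standalone
sketch of such a registry (illustrative only; the real macro and lookup are
presumably provided by canonicalization_utils.h and are not reproduced here):

    #include <functional>
    #include <string>
    #include <unordered_map>
    #include <utility>

    struct Graph;
    struct Node;
    using Handler = std::function<Node *(Graph *, Node *)>;

    // One global map from op type ("top_k", "slice", ...) to its handler.
    std::unordered_map<std::string, Handler> &HandlerRegistry() {
      static std::unordered_map<std::string, Handler> registry;
      return registry;
    }

    bool RegisterHandler(const std::string &op_type, Handler handler) {
      return HandlerRegistry().emplace(op_type, std::move(handler)).second;
    }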