diff --git a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
index 7e039d9852fc3895b23c5c96010a5f75c90577d9..81294dd568926f0c4e86c597f3f82f7b8b13cb62 100644
--- a/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
+++ b/paddle/fluid/framework/ir/mkldnn/fc_act_mkldnn_fuse_pass.h
@@ -42,4 +42,4 @@ class FuseFCActOneDNNPass : public FusePassBase {
 
 }  // namespace ir
 }  // namespace framework
-}  // namespace paddlea
+}  // namespace paddle
diff --git a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
index 1bde58f7c4edbb66c642e217e82e2fee6ffd999e..0526ae52b390305695d6537cb2c161391fc85ad0 100644
--- a/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
+++ b/paddle/fluid/operators/mkldnn/conv_mkldnn_op.cc
@@ -218,13 +218,15 @@ class ConvMKLDNNHandlerT
                                          : dnnl::prop_kind::forward_training;
 
       float sum_scale = 1.0f;
+      float activation_scale = 1.0f;
       std::vector<float> output_shift_scale;
       if (platform::is_int8<T>())
-        std::tie(sum_scale, output_shift_scale) = get_int8_scales(ctx);
+        std::tie(sum_scale, output_shift_scale, activation_scale) =
+            get_int8_scales(ctx);
 
       const dnnl::primitive_attr conv_attr = CreatePostOps(
           fuse_activation, fuse_alpha, fuse_beta, fuse_residual_conn,
-          output_shift_scale, sum_scale);  // for INT8 only!
+          output_shift_scale, sum_scale, activation_scale);  // for INT8 only!
 
       if (bias) {
         auto bias_tz = framework::vectorize(bias->dims());
@@ -432,7 +434,7 @@ class ConvMKLDNNHandlerT
     return bias_scale_tuple;
   }
 
-  std::tuple<float, std::vector<float>> get_int8_scales(
+  std::tuple<float, std::vector<float>, float> get_int8_scales(
       const framework::ExecutionContext& ctx) const {
     const auto* filter = ctx.Input<Tensor>("Filter");
     const auto& weights_tz = framework::vectorize(filter->dims());
@@ -445,8 +447,14 @@ class ConvMKLDNNHandlerT
     const auto& scale_in_eltwise_data = ctx.Attr<float>("Scale_in_eltwise");
     auto scale_weights_data = ctx.Attr<std::vector<float>>("Scale_weights");
     bool is_multi_channel = scale_weights_data.size() > 1;
+    bool has_activation = !ctx.Attr<std::string>("fuse_activation").empty();
+    float activation_scale =
+        force_fp32_output ? 1.0f : has_activation ? ctx.Attr<float>("Scale_out")
+                                                  : 1.0f;
     auto scale_out_data =
-        force_fp32_output ? 1.0f : ctx.Attr<float>("Scale_out");
+        force_fp32_output ? 1.0f : has_activation
+                                       ? 1.0f
+                                       : ctx.Attr<float>("Scale_out");
     float sum_scale =
         fuse_residual_conn ? scale_out_data / scale_in_eltwise_data : 1.0f;
     int count =
@@ -468,13 +476,13 @@ class ConvMKLDNNHandlerT
                            static_cast<double>(scale_weights_data[i])));
     }
 
-    return std::make_tuple(sum_scale, output_shift_scale);
+    return std::make_tuple(sum_scale, output_shift_scale, activation_scale);
   }
 
   dnnl::primitive_attr CreatePostOps(
       std::string fuse_activation, float fuse_alpha, float fuse_beta,
      bool fuse_residual_conn, const std::vector<float> output_shift_scale = {},
-      float sum_scale = 1.0f) {
+      float sum_scale = 1.0f, float activation_scale = 1.0f) {
     dnnl::primitive_attr conv_attr;
     dnnl::post_ops post_operations;
     if (output_shift_scale.size() > 0) {
@@ -492,30 +500,34 @@ class ConvMKLDNNHandlerT
     }
     // Fusion with ReLU layer is executed through the PostOps feature. Create a
    // PostOps object and configure it to execute an eltwise relu operation.
-    constexpr float scale = 1.0f;
     if (fuse_activation == "relu" || fuse_activation == "leaky_relu") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_relu,
-                                     fuse_alpha, fuse_beta);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_relu, fuse_alpha,
+                                     fuse_beta);
     } else if (fuse_activation == "relu6") {
-      post_operations.append_eltwise(
-          scale, dnnl::algorithm::eltwise_bounded_relu, fuse_alpha, fuse_beta);
-    } else if (fuse_activation == "swish") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_swish,
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_bounded_relu,
                                      fuse_alpha, fuse_beta);
+    } else if (fuse_activation == "swish") {
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_swish, fuse_alpha,
+                                     fuse_beta);
     } else if (fuse_activation == "hard_swish") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_hardswish,
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_hardswish,
                                      fuse_alpha, fuse_beta);
     } else if (fuse_activation == "hard_sigmoid") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_linear,
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_linear,
                                      fuse_alpha, fuse_beta);
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_clip, 0.0f,
-                                     1.0f);
+      post_operations.append_eltwise(activation_scale,
+                                     dnnl::algorithm::eltwise_clip, 0.0f, 1.0f);
     } else if (fuse_activation == "gelu_tanh") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_tanh,
-                                     0.0f, 0.0f);
+      post_operations.append_eltwise(
+          activation_scale, dnnl::algorithm::eltwise_gelu_tanh, 0.0f, 0.0f);
     } else if (fuse_activation == "gelu_erf") {
-      post_operations.append_eltwise(scale, dnnl::algorithm::eltwise_gelu_erf,
-                                     0.0f, 0.0f);
+      post_operations.append_eltwise(
+          activation_scale, dnnl::algorithm::eltwise_gelu_erf, 0.0f, 0.0f);
     }
     conv_attr.set_post_ops(post_operations);
     return conv_attr;
diff --git a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
index 7dbd927874d1968e017b3d5059056b22610ef2a5..0251dd693f66f04ad129c0e53492a3cfac541a9e 100644
--- a/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
+++ b/python/paddle/fluid/contrib/slim/quantization/quant2_int8_mkldnn_pass.py
@@ -426,6 +426,7 @@ class Quant2Int8MkldnnPass(object):
         graph = self._apply_pass(graph, 'conv_elementwise_add_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'conv_relu6_mkldnn_fuse_pass')
+        graph = self._apply_pass(graph, 'conv_hard_swish_mkldnn_fuse_pass')
         graph = self._apply_pass(graph, 'fc_fuse_pass',
                                  ['use_gpu', 'use_fc_padding'], [False, False])
         graph = self._apply_pass(graph, 'repeated_fc_relu_fuse_pass')
diff --git a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
index 2cfb6146f3f55d1b939d3a5d3e6b141a517524e1..6fc01488c7ea0413ced572c044f7d16f09132983 100644
--- a/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
+++ b/python/paddle/fluid/tests/unittests/mkldnn/test_conv2d_int8_mkldnn_op.py
@@ -23,13 +23,12 @@ from paddle.fluid.tests.unittests.test_conv2d_op import conv2d_forward_naive, Te
 
 
 def conv2d_forward_refer(input, filter, group, conv_param):
-    out, in_n, out_h, out_w, out_c = conv2d_forward_naive(input, filter, group,
-                                                          conv_param)
+    out, _, _, _, _ = conv2d_forward_naive(input, filter, group, conv_param)
     return out
 
 
-@unittest.skipIf(not core.supports_bfloat16(),
-                 "place does not support BF16 evaluation")
+@unittest.skipIf(not core.supports_int8(),
+                 "place does not support int8 computation")
 class TestConv2DInt8Op(TestConv2DOp):
     def setUp(self):
         self.op_type = "conv2d"
@@ -44,7 +43,7 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.init_group()
         self.init_dilation()
         self.init_test_case()
-        self.init_fuse_relu()
+        self.init_fuse_activation()
         self.init_fuse_residual()
         self.init_data_type()
 
@@ -53,73 +52,75 @@ class TestConv2DInt8Op(TestConv2DOp):
             'pad': self.pad,
             'dilation': self.dilations
         }
-
+        # This implementation of convolution quantization is based on OneDNN documentation
+        # https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html#doxid-dev-guide-int8-computations-1dg-i8-comp-s11
+        inner_scale = 1. if self.fuse_activation != "" else self.scale_out
+        activation_scale = self.scale_out if self.fuse_activation != "" else 1.
+        scale_output_shift = (inner_scale /
+                              (self.scale_in * self.scale_weights[0]))
         filter = np.random.random(self.filter_size).astype(self.weighttype)
-        if self.srctype == np.uint8:
-            input = np.random.randint(0, 10,
-                                      self.input_size).astype(self.srctype)
-        else:
-            input = np.random.randint(-5, 5,
+
+        # When the Intel AVX2 or Intel AVX512 Instruction Set is used
+        # the reorder additionally scales the weights by 0.5
+        # to overcome the potential overflow issue. If the processor supports VNNI instructions,
+        # modification of the weights is not necessary.
+        avx_scale = 0.5 if not core.supports_vnni(
+        ) and self.srctype == np.int8 else 1.
+        filter_int = np.round(filter * self.scale_weights[0] *
+                              avx_scale).astype(np.int32)
+        scale_output_shift = scale_output_shift / avx_scale
+
+        def conv2d_forward_refer_helper(input_):
+            return conv2d_forward_refer(
+                input_.astype(np.int32), filter_int, self.groups,
+                conv2d_param).astype(np.float32) * scale_output_shift
+
+        def residual_helper(init_low, init_high, output_):
+            input_residual_ = np.random.randint(
+                init_low, init_high,
+                self.input_residual_size).astype(self.srctype)
+            return (output_ + input_residual_ *
+                    (inner_scale / self.scale_in_eltwise)), input_residual_
+
+        if self.srctype == np.int8:
+            init_low, init_high = (-5, 5)
+            input = np.random.randint(init_low, init_high,
                                       self.input_size).astype(self.srctype)
             input_shift = (np.ones(self.input_size) * 128).astype(np.uint8)
 
-        if self.srctype == np.int8:
-            filter_int = np.round(filter * self.scale_weights[0] *
-                                  0.5).astype(np.int32)
-            scale_output_shift = self.scale_out / (self.scale_in *
-                                                   self.scale_weights[0] * 0.5)
-            output1 = conv2d_forward_refer(
-                np.round((input.astype(np.int32) + input_shift) *
-                         self.scale_in).astype(np.int32), filter_int,
-                self.groups,
-                conv2d_param).astype(np.float32) * scale_output_shift
-            output2 = conv2d_forward_refer(
-                np.round((input_shift) * self.scale_in).astype(np.int32),
-                filter_int, self.groups,
-                conv2d_param).astype(np.float32) * scale_output_shift
-            if self.fuse_residual:
-                input_residual = np.random.randint(
-                    -5, 5, self.input_residual_size).astype(self.srctype)
-                output_tmp = np.round(output1 - output2 + input_residual.astype(
-                    self.srctype) * (self.scale_out / self.scale_in_eltwise))
-                if self.fuse_activation == "relu":
-                    output = np.maximum(output_tmp, 0).astype(self.dsttype)
-                else:
-                    output = output_tmp.astype(self.dsttype)
-            else:
-                if self.fuse_activation == "relu":
-                    output = np.maximum(np.round(output1 - output2),
-                                        0).astype(self.dsttype)
-                else:
-                    output = np.round(output1 - output2).astype(self.dsttype)
+            output1 = conv2d_forward_refer_helper(
+                np.round(input + input_shift).astype(np.int32))
+            output2 = conv2d_forward_refer_helper(
+                np.round(input_shift).astype(np.int32))
+            output = output1 - output2
+        else:
+            init_low, init_high = (0, 10)
+            input = np.random.randint(init_low, init_high,
+                                      self.input_size).astype(self.srctype)
+            output = conv2d_forward_refer_helper(input)
 
+        if self.fuse_residual:
+            output, input_residual = residual_helper(init_low, init_high,
+                                                     output)
+
+        if self.fuse_activation == "":
+            pass
+        elif self.fuse_activation == "relu":
+            output = activation_scale * np.maximum(output, 0)
+        elif self.fuse_activation == "hard_swish":
+            output = activation_scale * output / 6. * np.minimum(
+                np.maximum(0, output + 3.), 6)
+        elif self.fuse_activation == "relu6":
+            output = activation_scale * np.maximum(0, np.minimum(6, output))
+        elif self.fuse_activation == "swish":
+            output = activation_scale * output / (1. + np.exp(-1. * output))
+        elif self.fuse_activation == "leaky_relu":
+            output = activation_scale * np.maximum(output, 0.02 * output)
         else:
-            filter_int = np.round(filter *
-                                  self.scale_weights[0]).astype(np.int32)
-            scale_output_shift = self.scale_out / (self.scale_in *
-                                                   self.scale_weights[0])
-            output1 = conv2d_forward_refer(
-                input.astype(np.int32), filter_int, self.groups,
-                conv2d_param).astype(np.float32)
-            output1_tmp = np.round(output1 * (
-                self.scale_out / (self.scale_in * self.scale_weights[0])))
-
-            if self.fuse_residual:
-                input_residual = np.random.randint(
-                    0, 10, self.input_residual_size).astype(self.srctype)
-                output_tmp_res = np.round(output1 * (self.scale_out / (
-                    self.scale_in * self.scale_weights[
-                        0])) + input_residual.astype(np.int32) * (
-                    self.scale_out / self.scale_in_eltwise))
-                if self.fuse_activation == "relu":
-                    output = np.maximum(output_tmp_res, 0).astype(self.dsttype)
-                else:
-                    output = output_tmp_res.astype(self.dsttype)
-            else:
-                if self.fuse_activation == "relu":
-                    output = np.maximum(output1_tmp, 0).astype(self.dsttype)
-                else:
-                    output = output1_tmp.astype(self.dsttype)
+            raise NotImplementedError("test for " + self.fuse_activation +
+                                      " activation not implemented")
+
+        output = np.round(output).astype(self.dsttype)
 
         self.inputs = {
             'Input':
@@ -144,6 +145,8 @@ class TestConv2DInt8Op(TestConv2DOp):
             'Scale_weights': self.scale_weights,
             'Scale_in_eltwise': self.scale_in_eltwise,
             'fuse_activation': self.fuse_activation,
+            'fuse_alpha': self.fuse_alpha,
+            'fuse_beta': self.fuse_beta,
             'fuse_residual_connection': self.fuse_residual,
             'mkldnn_data_type': self.mkldnn_data_type
         }
@@ -169,7 +172,7 @@ class TestConv2DInt8Op(TestConv2DOp):
         f_c = self.input_size[1] // self.groups
         self.input_residual_size = [1, 2, 3, 3]
         self.filter_size = [2, f_c, 3, 3]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.5
         self.scale_weights = [10.0]
         self.scale_in_eltwise = 0.6
@@ -178,14 +181,16 @@ class TestConv2DInt8Op(TestConv2DOp):
         self.srctype = np.uint8
         self.dsttype = np.int8
 
-    def init_fuse_relu(self):
+    def init_fuse_activation(self):
         self.fuse_activation = "relu"
+        self.fuse_alpha = 0
+        self.fuse_beta = 0
 
     def init_fuse_residual(self):
         self.fuse_residual = True
 
 
-#--------------------test conv2d u8 in and u8 out with residual fuse--------------------
+# --------------------test conv2d u8 in and u8 out with residual fuse--------------------
 
 
 class TestConv2D(TestConv2DInt8Op):
@@ -197,12 +202,40 @@ class TestConv2D(TestConv2DInt8Op):
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
         self.filter_size = [6, f_c, 3, 3]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.5
         self.scale_weights = [10.0]
         self.scale_in_eltwise = 0.6
 
 
+class TestWithHardSwish(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "hard_swish"
+        self.fuse_alpha = 0
+        self.fuse_beta = 0
+
+
+class TestWithRelu6(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "relu6"
+        self.fuse_alpha = 6
+        self.fuse_beta = 0
+
+
+class TestWithSwish(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "swish"
+        self.fuse_alpha = 1
+        self.fuse_beta = 0
+
+
+class TestWithLeakyRelu(TestConv2D):
+    def init_fuse_activation(self):
+        self.fuse_activation = "leaky_relu"
+        self.fuse_alpha = 0.02
+        self.fuse_beta = 0
+
+
 class TestWithPad(TestConv2D):
     def init_test_case(self):
         TestConv2D.init_test_case(self)
@@ -224,7 +257,7 @@ class TestWithStride(TestConv2DInt8Op):
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
         self.filter_size = [6, f_c, 3, 3]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.8
         self.scale_weights = [10.0]
         self.scale_in_eltwise = 0.5
@@ -240,7 +273,7 @@ class TestWithDilations(TestConv2DInt8Op):
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
         self.filter_size = [6, f_c, 3, 3]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.8
         self.scale_weights = [10.0]
         self.scale_in_eltwise = 0.5
@@ -255,7 +288,7 @@ class TestWith1x1(TestConv2DInt8Op):
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
         self.filter_size = [6, f_c, 1, 1]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.5
         self.scale_weights = [12.0]
         self.scale_in_eltwise = 0.5
@@ -270,7 +303,7 @@ class TestWithInput1x1Filter1x1(TestConv2DInt8Op):
         assert np.mod(self.input_size[1], self.groups) == 0
         f_c = self.input_size[1] // self.groups
         self.filter_size = [6, f_c, 1, 1]
-        self.scale_in = 1.0
+        self.scale_in = 0.95
         self.scale_out = 0.5
         self.scale_weights = [10.0]
         self.scale_in_eltwise = 0.8
@@ -290,32 +323,32 @@ def init_data_type_with_fusion(self, input_dt, fuse_activation, fuse_residual):
 
 
 def create_test_int8_class(parent):
-    #--------------------test conv2d s8 in and u8 out--------------------
+    # --------------------test conv2d s8 in and u8 out--------------------
     class TestS8U8Case(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.int8, "relu", False)
 
-    #--------------------test conv2d s8 in and s8 out--------------------
+    # --------------------test conv2d s8 in and s8 out--------------------
     class TestS8S8Case(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.int8, "", False)
 
-    #--------------------test conv2d u8 in and s8 out--------------------
+    # --------------------test conv2d u8 in and s8 out--------------------
     class TestU8S8Case(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.uint8, "", False)
 
-    #--------------------test conv2d u8 in and u8 out without residual fuse--------------------
+    # --------------------test conv2d u8 in and u8 out without residual fuse--------------------
     class TestU8U8Case(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.uint8, "relu", False)
 
-    #--------------------test conv2d s8 in and s8 out with residual fuse--------------------
+    # --------------------test conv2d s8 in and s8 out with residual fuse--------------------
     class TestS8S8ResCase(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.int8, "", True)
 
-    #--------------------test conv2d u8 in and s8 out with residual fuse--------------------
+    # --------------------test conv2d u8 in and s8 out with residual fuse--------------------
     class TestU8S8ResCase(parent):
         def init_data_type(self):
             init_data_type_with_fusion(self, np.uint8, "", True)
@@ -333,9 +366,9 @@ def create_test_int8_class(parent):
     TestS8S8Case.__name__ = cls_name_s8s8
     TestU8S8Case.__name__ = cls_name_u8s8
     TestU8U8Case.__name__ = cls_name_u8u8
-
     TestS8S8ResCase.__name__ = cls_name_s8s8_re_1
     TestU8S8ResCase.__name__ = cls_name_u8s8_re_1
+
     globals()[cls_name_s8u8] = TestS8U8Case
     globals()[cls_name_s8s8] = TestS8S8Case
     globals()[cls_name_u8s8] = TestU8S8Case
@@ -344,7 +377,7 @@ def create_test_int8_class(parent):
     globals()[cls_name_u8s8_re_1] = TestU8S8ResCase
 
     if os.name != 'nt':
-        #--------------------test conv2d s8 in and u8 out with residual fuse--------------------
+        # --------------------test conv2d s8 in and u8 out with residual fuse--------------------
         class TestS8U8ResCase(parent):
             def init_data_type(self):
                 init_data_type_with_fusion(self, np.int8, "relu", True)
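
Below is a small NumPy sketch, not part of the patch, of the scale placement that the updated test encodes: when an activation is fused, the int8 convolution accumulator is rescaled only by 1 / (Scale_in * Scale_weights) and Scale_out is applied by the activation post-op; without a fused activation, Scale_out stays on the convolution output. The names used here (int8_conv_reference, acc) are illustrative only and are not part of Paddle's API.

```python
import numpy as np


def int8_conv_reference(acc, scale_in, scale_weight, scale_out,
                        fuse_activation=""):
    """acc: int32 accumulator of quantized input x quantized weights."""
    # Mirrors inner_scale / activation_scale in the updated test.
    inner_scale = 1.0 if fuse_activation else scale_out
    activation_scale = scale_out if fuse_activation else 1.0
    # Accumulator rescaling (oneDNN output_shift_scale).
    out = acc.astype(np.float32) * (inner_scale / (scale_in * scale_weight))
    # Fused activation post-op carries the output scale.
    if fuse_activation == "relu":
        out = activation_scale * np.maximum(out, 0)
    elif fuse_activation == "relu6":
        out = activation_scale * np.maximum(0, np.minimum(6, out))
    elif fuse_activation:
        raise NotImplementedError(fuse_activation)
    return np.round(out).astype(np.int8)


# Example: one accumulator value with relu6 fused, using the test's scales.
print(int8_conv_reference(np.array([1200], dtype=np.int32),
                          scale_in=0.95, scale_weight=10.0,
                          scale_out=0.5, fuse_activation="relu6"))
```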