diff --git a/paddle/fluid/operators/prelu_op.cc b/paddle/fluid/operators/prelu_op.cc index 8a18843a97263689efed737741c71dc19f593897..63d12f790ff5bd2b3b157f5ea57bbe8fc464121e 100644 --- a/paddle/fluid/operators/prelu_op.cc +++ b/paddle/fluid/operators/prelu_op.cc @@ -95,9 +95,17 @@ class PReluOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } }; @@ -126,6 +134,21 @@ There are modes: )DOC"); AddAttr("mode", "The mode for inputs to share weights.") .SetDefault("all"); + AddAttr("use_mkldnn", + "(bool, default false) Only used in mkldnn kernel") + .SetDefault(false) + .AsExtra(); + AddAttr( + "mkldnn_data_type", + "(string, default \"float32\"). Data type of mkldnn kernel") + .SetDefault("float32") + .InEnum({"float32", "bfloat16"}) + .AsExtra(); + AddAttr("is_test", + "(bool, default false) Set to true for inference only, false " + "for training. Some layers may run faster when this is true.") + .SetDefault(false) + .AsExtra(); } }; @@ -153,9 +176,17 @@ class PReluGradOp : public framework::OperatorWithKernel { protected: framework::OpKernelType GetExpectedKernelType( const framework::ExecutionContext &ctx) const override { - return framework::OpKernelType( - OperatorWithKernel::IndicateVarDataType(ctx, "X"), - ctx.device_context()); + auto input_data_type = + framework::OperatorWithKernel::IndicateVarDataType(ctx, "X"); + +#ifdef PADDLE_WITH_MKLDNN + if (this->CanMKLDNNBeUsed(ctx, input_data_type)) { + return framework::OpKernelType(input_data_type, ctx.GetPlace(), + framework::DataLayout::kMKLDNN, + framework::LibraryType::kMKLDNN); + } +#endif + return framework::OpKernelType(input_data_type, ctx.GetPlace()); } };