From d4da970ad663d4b4b620e2853d1b7d122e87f6aa Mon Sep 17 00:00:00 2001 From: piotrw Date: Mon, 13 Dec 2021 15:53:44 +0000 Subject: [PATCH 1/5] Unified oneDNN pooling implementation calls --- src/operator/contrib/adaptive_avg_pooling.cc | 4 +- src/operator/nn/dnnl/dnnl_pooling-inl.h | 37 +++++++++------ src/operator/nn/dnnl/dnnl_pooling.cc | 47 ++++++++++--------- src/operator/nn/pooling.cc | 24 +--------- .../dnnl/dnnl_quantized_pooling.cc | 4 +- 5 files changed, 52 insertions(+), 64 deletions(-) diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc index 6af2fa02d66a..54ae6c7dc1f0 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cc +++ b/src/operator/contrib/adaptive_avg_pooling.cc @@ -238,11 +238,9 @@ void AdaptiveAvgPoolComputeExCPU(const nnvm::NodeAttrs& attrs, oneDNN doesn't support adaptive pooling. Fallback is needed when padding is not equal 0; */ - const PoolingParam& param = nnvm::get(attrs.parsed); if (SupportDNNL(inputs[0]) && SupportDNNLAveragePooling(inputs[0], outputs[0])) { - const NDArray* workspace = nullptr; DNNL_OPCHECK_INIT(false, 1, inputs, outputs); - DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, true); + DNNLRun(DNNLPoolingCompute, attrs, ctx, inputs, req, outputs); DNNL_OPCHECK_RUN(PoolingCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/nn/dnnl/dnnl_pooling-inl.h b/src/operator/nn/dnnl/dnnl_pooling-inl.h index 15a544e38fd9..a8ec9da7b1dc 100644 --- a/src/operator/nn/dnnl/dnnl_pooling-inl.h +++ b/src/operator/nn/dnnl/dnnl_pooling-inl.h @@ -172,27 +172,34 @@ inline bool DNNLRequireWorkspace(const PoolingParam& param) { } typedef ParamOpSign DNNLPoolingSignature; -void DNNLPoolingCompute(const OpContext& ctx, - const PoolingParam& param, - const NDArray& in_data, - const OpReqType req, - const NDArray& out_data, - const NDArray* workspace, - const bool use_adaptive_pooling); - -void DNNLPoolingGradCompute(const OpContext& ctx, - const PoolingParam& param, - const NDArray& out_grad, - const NDArray& in_data, - const NDArray* workspace, - const OpReqType req, - const NDArray& in_grad); + +void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param, const bool is_train, const NDArray& data, const NDArray& output, const bool use_adaptive_pooling); + +template +void DNNLPoolingCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + const PoolingParam& param = nnvm::get(attrs.parsed); + const NDArray* workspace = nullptr; + if (DNNLRequireWorkspace(param)) { + CHECK_GT(out_data.size(), 1U); + workspace = &out_data[1]; + } + auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0], use_adaptive_pooling); + fwd.Execute(in_data[0], req[0], out_data[0], workspace); +} } // namespace op } // namespace mxnet #endif // MXNET_USE_ONEDNN == 1 diff --git a/src/operator/nn/dnnl/dnnl_pooling.cc b/src/operator/nn/dnnl/dnnl_pooling.cc index 418c832703ff..f85a86523a97 100644 --- a/src/operator/nn/dnnl/dnnl_pooling.cc +++ b/src/operator/nn/dnnl/dnnl_pooling.cc @@ -310,17 +310,6 @@ DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param, return it->second; } -void DNNLPoolingCompute(const OpContext& ctx, - const PoolingParam& param, - const NDArray& in_data, - const OpReqType req, - const NDArray& out_data, - const NDArray* workspace, - const bool use_adaptive_pooling) { - auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data, use_adaptive_pooling); - fwd.Execute(in_data, req, out_data, workspace); -} - DNNLPoolingBwd::DNNLPoolingBwd(const dnnl::pooling_backward::primitive_desc& pdesc, bool with_ws) : with_workspace(with_ws), pd(pdesc) { bwd = std::make_shared(pd); @@ -384,22 +373,38 @@ DNNLPoolingBwd& GetPoolingBwd(const PoolingParam& param, return it->second; } -void DNNLPoolingGradCompute(const OpContext& ctx, - const PoolingParam& param, - const NDArray& out_grad, - const NDArray& in_data, - const NDArray* workspace, - const OpReqType req, - const NDArray& in_grad) { - if (req == kNullOp) { +void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (req[0] == kNullOp) { return; } + const PoolingParam& param = nnvm::get(attrs.parsed); + + const NDArray& out_grad = inputs[0]; + const NDArray* workspace = nullptr; + const NDArray* in_data = nullptr; + if (DNNLRequireWorkspace(param)) { + // The first two elements are the gradient of the outputs in forward. + // The third is the input of forward. + // The fourth and the fifth are the outputs of forward. + CHECK_EQ(inputs.size(), 5U); + in_data = &inputs[2]; + workspace = &inputs[4]; + } else { + CHECK_EQ(inputs.size(), 3U); + in_data = &inputs[1]; + } + const NDArray& in_grad = outputs[0]; + TmpMemMgr::Get()->Init(ctx.requested[0]); - auto& bwd = GetPoolingBwd(param, in_data, in_grad, out_grad); + auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad); auto diff_dst_mem = out_grad.GetDNNLDataReorder(bwd.pd.diff_dst_desc()); - auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req); + auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]); dnnl_args_map_t args = { {DNNL_ARG_DIFF_DST, *diff_dst_mem}, {DNNL_ARG_DIFF_SRC, *diff_src_mem.second}, diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc index 7b302ee9db73..7781c6a1ee90 100644 --- a/src/operator/nn/pooling.cc +++ b/src/operator/nn/pooling.cc @@ -294,7 +294,6 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs, const std::vector& req, const std::vector& outputs) { const PoolingParam& param = nnvm::get(attrs.parsed); - const NDArray* workspace = nullptr; // Pooling does not currently support working with views if (inputs[0].IsView() || outputs[0].IsView()) { @@ -303,12 +302,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs, } if (SupportDNNLPooling(param, inputs[0])) { - if (DNNLRequireWorkspace(param)) { - CHECK_GT(outputs.size(), 1U); - workspace = &outputs[1]; - } DNNL_OPCHECK_INIT(false, 1, inputs, outputs); - DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, false); + DNNLRun(DNNLPoolingCompute, attrs, ctx, inputs, req, outputs); DNNL_OPCHECK_RUN(PoolingCompute, attrs, ctx, inputs, req, outputs); return; } @@ -329,23 +324,8 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs& attrs, } if (SupportDNNLPooling(param, inputs[0])) { - const NDArray& out_grad = inputs[0]; - const NDArray* workspace = nullptr; - const NDArray* in_data = nullptr; - if (DNNLRequireWorkspace(param)) { - // The first two elements are the gradient of the outputs in forward. - // The third is the input of forward. - // The fourth and the fifth are the outputs of forward. - CHECK_EQ(inputs.size(), 5U); - in_data = &inputs[2]; - workspace = &inputs[4]; - } else { - CHECK_EQ(inputs.size(), 3U); - in_data = &inputs[1]; - } - const NDArray& in_grad = outputs[0]; DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs); - DNNLPoolingGradCompute(ctx, param, out_grad, *in_data, workspace, req[0], in_grad); + DNNLRun(DNNLPoolingGradCompute, attrs, ctx, inputs, req, outputs); DNNL_OPCHECK_RUN(PoolingGradCompute, attrs, ctx, inputs, req, outputs); return; } diff --git a/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc b/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc index a6f89ee6b875..dd8b55cf3dc4 100644 --- a/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc +++ b/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc @@ -37,9 +37,7 @@ static void DNNLQuantizedPoolingForward(const nnvm::NodeAttrs& attrs, const std::vector& out_data) { CHECK(in_data[0].dtype() == mshadow::kUint8 || in_data[0].dtype() == mshadow::kInt8) << "dnnl_quantized_pooling op only supports uint8 and int8 as input type"; - const PoolingParam& param = nnvm::get(attrs.parsed); - DNNLPoolingCompute( - ctx, param, in_data[0], req[0], out_data[0], nullptr, /*use_adaptive_pooling*/ false); + DNNLRun(DNNLPoolingCompute, attrs, ctx, in_data, req, out_data); out_data[1].data().dptr()[0] = in_data[1].data().dptr()[0]; out_data[2].data().dptr()[0] = in_data[2].data().dptr()[0]; } From 3ef827a3daaec32b7f214b955493ba0f3e87fb68 Mon Sep 17 00:00:00 2001 From: piotrw Date: Mon, 13 Dec 2021 16:27:19 +0000 Subject: [PATCH 2/5] Added include for vector --- src/operator/nn/dnnl/dnnl_pooling-inl.h | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/nn/dnnl/dnnl_pooling-inl.h b/src/operator/nn/dnnl/dnnl_pooling-inl.h index a8ec9da7b1dc..eb233c0937af 100644 --- a/src/operator/nn/dnnl/dnnl_pooling-inl.h +++ b/src/operator/nn/dnnl/dnnl_pooling-inl.h @@ -28,6 +28,7 @@ #include #include +#include #include "../pooling-inl.h" #include "./dnnl_base-inl.h" From 7c626a1c8a8a922da2bbef41bd1e03b6ef9edc2f Mon Sep 17 00:00:00 2001 From: piotrw Date: Tue, 28 Dec 2021 10:16:17 +0000 Subject: [PATCH 3/5] Run CI again --- src/operator/contrib/adaptive_avg_pooling.cc | 1 + 1 file changed, 1 insertion(+) diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc index 54ae6c7dc1f0..03d68ea9f24c 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cc +++ b/src/operator/contrib/adaptive_avg_pooling.cc @@ -214,6 +214,7 @@ bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& out_data) return false; } } + const int IH = in_data.shape()[2]; const int IW = in_data.shape()[3]; const int OH = out_data.shape()[2]; From abca9966fcfa464740aa8e774e44140ff082d224 Mon Sep 17 00:00:00 2001 From: piotrw Date: Tue, 28 Dec 2021 10:34:07 +0000 Subject: [PATCH 4/5] CI run --- src/operator/contrib/adaptive_avg_pooling.cc | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc index 03d68ea9f24c..c6bdfd62d242 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cc +++ b/src/operator/contrib/adaptive_avg_pooling.cc @@ -214,7 +214,7 @@ bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& out_data) return false; } } - + const int IH = in_data.shape()[2]; const int IW = in_data.shape()[3]; const int OH = out_data.shape()[2]; From b3520dfe8bd640f8dd1fd80ea8781532c37bd04d Mon Sep 17 00:00:00 2001 From: piotrw Date: Tue, 11 Jan 2022 08:23:01 +0000 Subject: [PATCH 5/5] CI run --- src/operator/contrib/adaptive_avg_pooling.cc | 1 - 1 file changed, 1 deletion(-) diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc index c6bdfd62d242..54ae6c7dc1f0 100644 --- a/src/operator/contrib/adaptive_avg_pooling.cc +++ b/src/operator/contrib/adaptive_avg_pooling.cc @@ -214,7 +214,6 @@ bool SupportDNNLAveragePooling(const NDArray& in_data, const NDArray& out_data) return false; } } - const int IH = in_data.shape()[2]; const int IW = in_data.shape()[3]; const int OH = out_data.shape()[2];