diff --git a/src/operator/contrib/adaptive_avg_pooling.cc b/src/operator/contrib/adaptive_avg_pooling.cc
index 6af2fa02d66a..54ae6c7dc1f0 100644
--- a/src/operator/contrib/adaptive_avg_pooling.cc
+++ b/src/operator/contrib/adaptive_avg_pooling.cc
@@ -238,11 +238,9 @@ void AdaptiveAvgPoolComputeExCPU(const nnvm::NodeAttrs& attrs,
     oneDNN doesn't support adaptive pooling.
     Fallback is needed when padding is not equal 0;
   */
-  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
   if (SupportDNNL(inputs[0]) && SupportDNNLAveragePooling(inputs[0], outputs[0])) {
-    const NDArray* workspace = nullptr;
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
-    DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, true);
+    DNNLRun(DNNLPoolingCompute<true>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingCompute, attrs, ctx, inputs, req, outputs);
     return;
   }
diff --git a/src/operator/nn/dnnl/dnnl_pooling-inl.h b/src/operator/nn/dnnl/dnnl_pooling-inl.h
index 15a544e38fd9..eb233c0937af 100644
--- a/src/operator/nn/dnnl/dnnl_pooling-inl.h
+++ b/src/operator/nn/dnnl/dnnl_pooling-inl.h
@@ -28,6 +28,7 @@
 
 #include <dnnl.hpp>
 #include <utility>
+#include <vector>
 
 #include "../pooling-inl.h"
 #include "./dnnl_base-inl.h"
@@ -172,27 +173,34 @@ inline bool DNNLRequireWorkspace(const PoolingParam& param) {
 }
 
 typedef ParamOpSign<PoolingParam> DNNLPoolingSignature;
-void DNNLPoolingCompute(const OpContext& ctx,
-                        const PoolingParam& param,
-                        const NDArray& in_data,
-                        const OpReqType req,
-                        const NDArray& out_data,
-                        const NDArray* workspace,
-                        const bool use_adaptive_pooling);
-
-void DNNLPoolingGradCompute(const OpContext& ctx,
-                            const PoolingParam& param,
-                            const NDArray& out_grad,
-                            const NDArray& in_data,
-                            const NDArray* workspace,
-                            const OpReqType req,
-                            const NDArray& in_grad);
+
+void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs);
 
 DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param,
                               const bool is_train,
                               const NDArray& data,
                               const NDArray& output,
                               const bool use_adaptive_pooling);
+
+template <bool use_adaptive_pooling>
+void DNNLPoolingCompute(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& in_data,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& out_data) {
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+  const NDArray* workspace = nullptr;
+  if (DNNLRequireWorkspace(param)) {
+    CHECK_GT(out_data.size(), 1U);
+    workspace = &out_data[1];
+  }
+  auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0], use_adaptive_pooling);
+  fwd.Execute(in_data[0], req[0], out_data[0], workspace);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN == 1
diff --git a/src/operator/nn/dnnl/dnnl_pooling.cc b/src/operator/nn/dnnl/dnnl_pooling.cc
index 418c832703ff..f85a86523a97 100644
--- a/src/operator/nn/dnnl/dnnl_pooling.cc
+++ b/src/operator/nn/dnnl/dnnl_pooling.cc
@@ -310,17 +310,6 @@ DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param,
   return it->second;
 }
 
-void DNNLPoolingCompute(const OpContext& ctx,
-                        const PoolingParam& param,
-                        const NDArray& in_data,
-                        const OpReqType req,
-                        const NDArray& out_data,
-                        const NDArray* workspace,
-                        const bool use_adaptive_pooling) {
-  auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data, use_adaptive_pooling);
-  fwd.Execute(in_data, req, out_data, workspace);
-}
-
 DNNLPoolingBwd::DNNLPoolingBwd(const dnnl::pooling_backward::primitive_desc& pdesc, bool with_ws)
     : with_workspace(with_ws), pd(pdesc) {
   bwd = std::make_shared<dnnl::pooling_backward>(pd);
@@ -384,22 +373,38 @@ DNNLPoolingBwd& GetPoolingBwd(const PoolingParam& param,
   return it->second;
 }
 
-void DNNLPoolingGradCompute(const OpContext& ctx,
-                            const PoolingParam& param,
-                            const NDArray& out_grad,
-                            const NDArray& in_data,
-                            const NDArray* workspace,
-                            const OpReqType req,
-                            const NDArray& in_grad) {
-  if (req == kNullOp) {
+void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  if (req[0] == kNullOp) {
     return;
   }
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+
+  const NDArray& out_grad = inputs[0];
+  const NDArray* workspace = nullptr;
+  const NDArray* in_data = nullptr;
+  if (DNNLRequireWorkspace(param)) {
+    // The first two elements are the gradient of the outputs in forward.
+    // The third is the input of forward.
+    // The fourth and the fifth are the outputs of forward.
+    CHECK_EQ(inputs.size(), 5U);
+    in_data = &inputs[2];
+    workspace = &inputs[4];
+  } else {
+    CHECK_EQ(inputs.size(), 3U);
+    in_data = &inputs[1];
+  }
+  const NDArray& in_grad = outputs[0];
+
   TmpMemMgr::Get()->Init(ctx.requested[0]);
-  auto& bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
+  auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad);
   auto diff_dst_mem = out_grad.GetDNNLDataReorder(bwd.pd.diff_dst_desc());
-  auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req);
+  auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
 
   dnnl_args_map_t args = {
       {DNNL_ARG_DIFF_DST, *diff_dst_mem},
       {DNNL_ARG_DIFF_SRC, *diff_src_mem.second},
diff --git a/src/operator/nn/pooling.cc b/src/operator/nn/pooling.cc
index 7b302ee9db73..7781c6a1ee90 100644
--- a/src/operator/nn/pooling.cc
+++ b/src/operator/nn/pooling.cc
@@ -294,7 +294,6 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
   const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
-  const NDArray* workspace = nullptr;
 
   // Pooling does not currently support working with views
   if (inputs[0].IsView() || outputs[0].IsView()) {
@@ -303,12 +302,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (SupportDNNLPooling(param, inputs[0])) {
-    if (DNNLRequireWorkspace(param)) {
-      CHECK_GT(outputs.size(), 1U);
-      workspace = &outputs[1];
-    }
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
-    DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, false);
+    DNNLRun(DNNLPoolingCompute<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingCompute, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -329,23 +324,8 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (SupportDNNLPooling(param, inputs[0])) {
-    const NDArray& out_grad = inputs[0];
-    const NDArray* workspace = nullptr;
-    const NDArray* in_data = nullptr;
-    if (DNNLRequireWorkspace(param)) {
-      // The first two elements are the gradient of the outputs in forward.
-      // The third is the input of forward.
-      // The fourth and the fifth are the outputs of forward.
-      CHECK_EQ(inputs.size(), 5U);
-      in_data = &inputs[2];
-      workspace = &inputs[4];
-    } else {
-      CHECK_EQ(inputs.size(), 3U);
-      in_data = &inputs[1];
-    }
-    const NDArray& in_grad = outputs[0];
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLPoolingGradCompute(ctx, param, out_grad, *in_data, workspace, req[0], in_grad);
+    DNNLRun(DNNLPoolingGradCompute, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingGradCompute, attrs, ctx, inputs, req, outputs);
     return;
   }
diff --git a/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc b/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc
index a6f89ee6b875..dd8b55cf3dc4 100644
--- a/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc
+++ b/src/operator/quantization/dnnl/dnnl_quantized_pooling.cc
@@ -37,9 +37,7 @@ static void DNNLQuantizedPoolingForward(const nnvm::NodeAttrs& attrs,
                                         const std::vector<NDArray>& out_data) {
   CHECK(in_data[0].dtype() == mshadow::kUint8 || in_data[0].dtype() == mshadow::kInt8)
       << "dnnl_quantized_pooling op only supports uint8 and int8 as input type";
-  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
-  DNNLPoolingCompute(
-      ctx, param, in_data[0], req[0], out_data[0], nullptr, /*use_adaptive_pooling*/ false);
+  DNNLRun(DNNLPoolingCompute<false>, attrs, ctx, in_data, req, out_data);
   out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
   out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
 }
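For context on the pattern this patch applies at all four call sites: instead of each operator unpacking `param` and `workspace` and calling a bespoke `DNNLPoolingCompute` overload, every kernel now exposes the standard `(attrs, ctx, inputs, req, outputs)` signature and is dispatched through `DNNLRun`, with the adaptive/regular distinction lifted into the `use_adaptive_pooling` template parameter. Below is a minimal compilable sketch of this refactor; the types (`NDArray`, `NodeAttrs`, `OpContext`, `OpReqType`) and the `Run` helper are simplified hypothetical stand-ins, not the real MXNet/oneDNN classes.

```cpp
#include <cassert>
#include <functional>
#include <iostream>
#include <string>
#include <vector>

// Hypothetical stand-ins for the MXNet types used in the diff.
struct NDArray   { std::string name; };
enum  OpReqType  { kNullOp, kWriteTo };
struct NodeAttrs { bool needs_workspace = false; };  // stands in for attrs.parsed
struct OpContext { bool is_train = false; };

// One uniform compute signature shared by every kernel, so a single
// dispatcher can wrap any of them.
using FComputeEx = std::function<void(const NodeAttrs&,
                                      const OpContext&,
                                      const std::vector<NDArray>&,
                                      const std::vector<OpReqType>&,
                                      const std::vector<NDArray>&)>;

// Adaptive vs. regular pooling is a compile-time switch, mirroring
// DNNLPoolingCompute<use_adaptive_pooling>; workspace extraction now lives
// inside the kernel instead of being repeated at each call site.
template <bool use_adaptive_pooling>
void PoolingCompute(const NodeAttrs& attrs,
                    const OpContext& /*ctx*/,
                    const std::vector<NDArray>& in_data,
                    const std::vector<OpReqType>& req,
                    const std::vector<NDArray>& out_data) {
  if (req[0] == kNullOp) return;
  const NDArray* workspace = nullptr;
  if (attrs.needs_workspace) {
    assert(out_data.size() > 1);
    workspace = &out_data[1];  // extra forward output carries the workspace
  }
  std::cout << (use_adaptive_pooling ? "adaptive" : "regular") << " pooling: "
            << in_data[0].name << " -> " << out_data[0].name
            << (workspace ? " (with workspace)" : "") << '\n';
}

// Stand-in for DNNLRun: simply forwards the uniform argument list (the real
// wrapper adds oneDNN-specific handling around the call).
void Run(const FComputeEx& fn,
         const NodeAttrs& attrs,
         const OpContext& ctx,
         const std::vector<NDArray>& inputs,
         const std::vector<OpReqType>& req,
         const std::vector<NDArray>& outputs) {
  fn(attrs, ctx, inputs, req, outputs);
}

int main() {
  NodeAttrs attrs;
  OpContext ctx;
  std::vector<NDArray> in{{"x"}}, out{{"y"}};
  std::vector<OpReqType> req{kWriteTo};
  Run(PoolingCompute<false>, attrs, ctx, in, req, out);  // plain pooling path
  Run(PoolingCompute<true>, attrs, ctx, in, req, out);   // adaptive path
}
```

Making the flag a template parameter rather than a runtime argument keeps the kernel's signature identical to the dispatcher's expected shape while still letting each call site pick its variant statically, which is what allows the duplicated workspace/param plumbing in `pooling.cc`, `adaptive_avg_pooling.cc`, and `dnnl_quantized_pooling.cc` to be deleted.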