4 changes: 1 addition & 3 deletions src/operator/contrib/adaptive_avg_pooling.cc
@@ -238,11 +238,9 @@ void AdaptiveAvgPoolComputeExCPU(const nnvm::NodeAttrs& attrs,
   oneDNN doesn't support adaptive pooling.
   Fallback is needed when padding is not equal 0;
   */
-  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
   if (SupportDNNL(inputs[0]) && SupportDNNLAveragePooling(inputs[0], outputs[0])) {
-    const NDArray* workspace = nullptr;
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
-    DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, true);
+    DNNLRun(DNNLPoolingCompute<true>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
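At every call site touched in this PR, the same refactor recurs: the per-operator unpacking (PoolingParam, workspace pointer, individual NDArrays) moves out of the call site and into the callee, so the whole argument tuple can be forwarded through DNNLRun unchanged. A minimal sketch of the shared shape this enables, assuming MXNet's usual FComputeEx convention (the function name below is illustrative, not from this PR):

// Every function handed to DNNLRun is expected to match this signature,
// which is why the runtime bool and workspace arguments had to move inside.
void SomePoolingComputeExCPU(const nnvm::NodeAttrs& attrs,        // parsed operator attributes
                             const OpContext& ctx,                // is_train, requested resources
                             const std::vector<NDArray>& inputs,
                             const std::vector<OpReqType>& req,   // write/add/null per output
                             const std::vector<NDArray>& outputs);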
38 changes: 23 additions & 15 deletions src/operator/nn/dnnl/dnnl_pooling-inl.h
@@ -28,6 +28,7 @@

 #include <dnnl.hpp>
 #include <utility>
+#include <vector>
 
 #include "../pooling-inl.h"
 #include "./dnnl_base-inl.h"
@@ -172,27 +173,34 @@ inline bool DNNLRequireWorkspace(const PoolingParam& param) {
 }
 
 typedef ParamOpSign<PoolingParam> DNNLPoolingSignature;
-void DNNLPoolingCompute(const OpContext& ctx,
-                        const PoolingParam& param,
-                        const NDArray& in_data,
-                        const OpReqType req,
-                        const NDArray& out_data,
-                        const NDArray* workspace,
-                        const bool use_adaptive_pooling);
-
-void DNNLPoolingGradCompute(const OpContext& ctx,
-                            const PoolingParam& param,
-                            const NDArray& out_grad,
-                            const NDArray& in_data,
-                            const NDArray* workspace,
-                            const OpReqType req,
-                            const NDArray& in_grad);
+void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs);
 
 DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param,
                               const bool is_train,
                               const NDArray& data,
                               const NDArray& output,
                               const bool use_adaptive_pooling);
 
+template <bool use_adaptive_pooling>
+void DNNLPoolingCompute(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const std::vector<NDArray>& in_data,
+                        const std::vector<OpReqType>& req,
+                        const std::vector<NDArray>& out_data) {
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+  const NDArray* workspace = nullptr;
+  if (DNNLRequireWorkspace(param)) {
+    CHECK_GT(out_data.size(), 1U);
+    workspace = &out_data[1];
+  }
+  auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data[0], out_data[0], use_adaptive_pooling);
+  fwd.Execute(in_data[0], req[0], out_data[0], workspace);
+}
 }  // namespace op
 }  // namespace mxnet
 #endif  // MXNET_USE_ONEDNN == 1
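Making use_adaptive_pooling a template parameter rather than a runtime argument keeps each instantiation's signature identical to the plain function type DNNLRun accepts; the flag is resolved at compile time instead of being threaded through the argument list. A self-contained sketch of the idiom, with illustrative names only (no MXNet types):

#include <iostream>
#include <vector>

// Generic runner that accepts any function of the fixed signature,
// standing in for DNNLRun (which would also add fallback handling).
using ComputeFn = void (*)(const std::vector<int>&, std::vector<int>*);

void Run(ComputeFn fn, const std::vector<int>& in, std::vector<int>* out) {
  fn(in, out);
}

// The bool is baked in per instantiation, so Compute<true> and
// Compute<false> are both ordinary functions matching ComputeFn.
template <bool use_adaptive>
void Compute(const std::vector<int>& in, std::vector<int>* out) {
  out->push_back(use_adaptive ? in[0] / 2 : in[0]);
}

int main() {
  std::vector<int> out;
  Run(Compute<true>, {84}, &out);   // "adaptive" instantiation
  Run(Compute<false>, {42}, &out);  // "standard" instantiation
  std::cout << out[0] << ' ' << out[1] << '\n';  // prints: 42 42
}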
47 changes: 26 additions & 21 deletions src/operator/nn/dnnl/dnnl_pooling.cc
@@ -310,17 +310,6 @@ DNNLPoolingFwd& GetPoolingFwd(const PoolingParam& param,
   return it->second;
 }
 
-void DNNLPoolingCompute(const OpContext& ctx,
-                        const PoolingParam& param,
-                        const NDArray& in_data,
-                        const OpReqType req,
-                        const NDArray& out_data,
-                        const NDArray* workspace,
-                        const bool use_adaptive_pooling) {
-  auto& fwd = GetPoolingFwd(param, ctx.is_train, in_data, out_data, use_adaptive_pooling);
-  fwd.Execute(in_data, req, out_data, workspace);
-}
-
 DNNLPoolingBwd::DNNLPoolingBwd(const dnnl::pooling_backward::primitive_desc& pdesc, bool with_ws)
     : with_workspace(with_ws), pd(pdesc) {
   bwd = std::make_shared<dnnl::pooling_backward>(pd);
@@ -384,22 +373,38 @@ DNNLPoolingBwd& GetPoolingBwd(const PoolingParam& param,
   return it->second;
 }
 
-void DNNLPoolingGradCompute(const OpContext& ctx,
-                            const PoolingParam& param,
-                            const NDArray& out_grad,
-                            const NDArray& in_data,
-                            const NDArray* workspace,
-                            const OpReqType req,
-                            const NDArray& in_grad) {
-  if (req == kNullOp) {
+void DNNLPoolingGradCompute(const nnvm::NodeAttrs& attrs,
+                            const OpContext& ctx,
+                            const std::vector<NDArray>& inputs,
+                            const std::vector<OpReqType>& req,
+                            const std::vector<NDArray>& outputs) {
+  if (req[0] == kNullOp) {
     return;
   }
 
+  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
+
+  const NDArray& out_grad = inputs[0];
+  const NDArray* workspace = nullptr;
+  const NDArray* in_data = nullptr;
+  if (DNNLRequireWorkspace(param)) {
+    // The first two elements are the gradient of the outputs in forward.
+    // The third is the input of forward.
+    // The fourth and the fifth are the outputs of forward.
+    CHECK_EQ(inputs.size(), 5U);
+    in_data = &inputs[2];
+    workspace = &inputs[4];
+  } else {
+    CHECK_EQ(inputs.size(), 3U);
+    in_data = &inputs[1];
+  }
+  const NDArray& in_grad = outputs[0];
+
   TmpMemMgr::Get()->Init(ctx.requested[0]);
 
-  auto& bwd = GetPoolingBwd(param, in_data, in_grad, out_grad);
+  auto& bwd = GetPoolingBwd(param, *in_data, in_grad, out_grad);
   auto diff_dst_mem = out_grad.GetDNNLDataReorder(bwd.pd.diff_dst_desc());
-  auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req);
+  auto diff_src_mem = CreateDNNLMem(in_grad, bwd.pd.diff_src_desc(), req[0]);
   dnnl_args_map_t args = {
       {DNNL_ARG_DIFF_DST, *diff_dst_mem},
       {DNNL_ARG_DIFF_SRC, *diff_src_mem.second},
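As a reading aid for the size checks above, here is the input layout the backward pass now decodes for itself (restating the PR's own comments; the role of inputs[3] and of the trailing forward output in the no-workspace case is inferred, not stated in the diff):

// With workspace (DNNLRequireWorkspace(param) == true), inputs.size() == 5:
//   inputs[0], inputs[1] : gradients of the two forward outputs
//   inputs[2]            : forward input (in_data)
//   inputs[3], inputs[4] : forward outputs; inputs[4] is the pooling workspace
// Without workspace, inputs.size() == 3:
//   inputs[0]            : gradient of the forward output
//   inputs[1]            : forward input (in_data)
//   inputs[2]            : forward output
// outputs[0] is in_grad in both cases.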
24 changes: 2 additions & 22 deletions src/operator/nn/pooling.cc
@@ -294,7 +294,6 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
                          const std::vector<OpReqType>& req,
                          const std::vector<NDArray>& outputs) {
   const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
-  const NDArray* workspace = nullptr;
 
   // Pooling does not currently support working with views
   if (inputs[0].IsView() || outputs[0].IsView()) {
@@ -303,12 +302,8 @@ void PoolingComputeExCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (SupportDNNLPooling(param, inputs[0])) {
-    if (DNNLRequireWorkspace(param)) {
-      CHECK_GT(outputs.size(), 1U);
-      workspace = &outputs[1];
-    }
     DNNL_OPCHECK_INIT(false, 1, inputs, outputs);
-    DNNLPoolingCompute(ctx, param, inputs[0], req[0], outputs[0], workspace, false);
+    DNNLRun(DNNLPoolingCompute<false>, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
@@ -329,23 +324,8 @@ void PoolingGradComputeExCPU(const nnvm::NodeAttrs& attrs,
   }
 
   if (SupportDNNLPooling(param, inputs[0])) {
-    const NDArray& out_grad = inputs[0];
-    const NDArray* workspace = nullptr;
-    const NDArray* in_data = nullptr;
-    if (DNNLRequireWorkspace(param)) {
-      // The first two elements are the gradient of the outputs in forward.
-      // The third is the input of forward.
-      // The fourth and the fifth are the outputs of forward.
-      CHECK_EQ(inputs.size(), 5U);
-      in_data = &inputs[2];
-      workspace = &inputs[4];
-    } else {
-      CHECK_EQ(inputs.size(), 3U);
-      in_data = &inputs[1];
-    }
-    const NDArray& in_grad = outputs[0];
     DNNL_OPCHECK_INIT(true, outputs.size(), inputs, outputs);
-    DNNLPoolingGradCompute(ctx, param, out_grad, *in_data, workspace, req[0], in_grad);
+    DNNLRun(DNNLPoolingGradCompute, attrs, ctx, inputs, req, outputs);
     DNNL_OPCHECK_RUN(PoolingGradCompute<cpu>, attrs, ctx, inputs, req, outputs);
     return;
   }
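With the unpacking gone, both CPU entry points reduce to a support check plus a DNNLRun call. For context, a sketch of how these entry points are presumably wired up, following MXNet's usual NNVM registration pattern (assumed; this registration code is not part of the PR's diff):

NNVM_REGISTER_OP(Pooling)
    .set_attr<FComputeEx>("FComputeEx<cpu>", PoolingComputeExCPU);

NNVM_REGISTER_OP(_backward_Pooling)
    .set_attr<FComputeEx>("FComputeEx<cpu>", PoolingGradComputeExCPU);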
4 changes: 1 addition & 3 deletions src/operator/quantization/dnnl/dnnl_quantized_pooling.cc
@@ -37,9 +37,7 @@ static void DNNLQuantizedPoolingForward(const nnvm::NodeAttrs& attrs,
                                         const std::vector<NDArray>& out_data) {
   CHECK(in_data[0].dtype() == mshadow::kUint8 || in_data[0].dtype() == mshadow::kInt8)
       << "dnnl_quantized_pooling op only supports uint8 and int8 as input type";
-  const PoolingParam& param = nnvm::get<PoolingParam>(attrs.parsed);
-  DNNLPoolingCompute(
-      ctx, param, in_data[0], req[0], out_data[0], nullptr, /*use_adaptive_pooling*/ false);
+  DNNLRun(DNNLPoolingCompute<false>, attrs, ctx, in_data, req, out_data);
   out_data[1].data().dptr<float>()[0] = in_data[1].data().dptr<float>()[0];
   out_data[2].data().dptr<float>()[0] = in_data[2].data().dptr<float>()[0];
 }
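The two scalar copies after the DNNLRun call survive the refactor because pooling only selects or averages existing values, so the output's quantization range equals the input's. A reading aid for the tensor layout, assuming MXNet's standard quantized-operator convention (the min/max roles are not spelled out in this file):

// in_data[0] : uint8/int8 tensor         -> out_data[0] : pooled tensor
// in_data[1] : float scalar, range min   -> out_data[1] : same min, copied through
// in_data[2] : float scalar, range max   -> out_data[2] : same max, copied through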