diff --git a/src/operator/nn/dnnl/dnnl_softmax-inl.h b/src/operator/nn/dnnl/dnnl_softmax-inl.h
new file mode 100644
index 000000000000..0978ab0cdfe1
--- /dev/null
+++ b/src/operator/nn/dnnl/dnnl_softmax-inl.h
@@ -0,0 +1,182 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+/*!
+ * \file dnnl_softmax-inl.h
+ * Naming convention:
+ *                 ________
+ *                |Softmax|
+ *  data  ------->|  FWD  |---> out
+ *                |_______|
+ *                 ________
+ *                |Softmax|<--- out
+ *  data_grad <---|  BWD  |
+ *                |_______|<--- out_grad
+ */
+
+#ifndef MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_INL_H_
+#define MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_INL_H_
+
+#if MXNET_USE_ONEDNN == 1
+#include <memory>
+#include <vector>
+
+#include "./dnnl_base-inl.h"
+#include "./dnnl_ops-inl.h"
+
+#include "../softmax-inl.h"
+
+namespace mxnet {
+namespace op {
+
+using softmax_fwd_t    = dnnl::softmax_forward;
+using softmax_fwd_pd_t = dnnl::softmax_forward::primitive_desc;
+
+using softmax_bwd_t    = dnnl::softmax_backward;
+using softmax_bwd_pd_t = dnnl::softmax_backward::primitive_desc;
+
+// Element-wise primitive used to scale a tensor by 1 / temperature.
+using linear_t    = dnnl::eltwise_forward;
+using linear_pd_t = dnnl::eltwise_forward::primitive_desc;
+
+/*!
+ * \brief Softmax forward built from oneDNN primitives: an optional element-wise
+ *        scaling by 1 / temperature followed by the softmax primitive.
+ */
+class DNNLSoftmaxFwd {
+ public:
+  /*! \brief Non-owning references to the arrays used by one forward call. */
+  struct Tensors {
+    Tensors(const NDArray& data, const NDArray& out);
+
+    const NDArray& data;
+    const NDArray& out;
+  };
+
+  /*! \brief Returns the cached instance for this call signature, creating it on a miss. */
+  static DNNLSoftmaxFwd& GetCached(const SoftmaxParam& param,
+                                   const Tensors& tensors,
+                                   const bool is_train);
+
+  /*! \brief Builds the softmax forward primitive descriptor for the layout of input_mem. */
+  static softmax_fwd_pd_t GetSoftmaxFwdPd(const dnnl::memory& input_mem,
+                                          const int axis,
+                                          const bool is_train);
+
+  /*! \brief Builds the descriptor of the linear primitive computing x * (1 / temperature). */
+  static linear_pd_t GetTemperaturePd(const dnnl::memory& input_mem, const float temperature);
+
+  DNNLSoftmaxFwd(const SoftmaxParam& param, const Tensors& tensors, const bool is_train);
+  /*! \brief Registers the primitives on the DNNL stream and submits it. */
+  void Execute(const Tensors& tensors) const;
+
+ private:
+  std::shared_ptr<softmax_fwd_pd_t> softmax_pd;
+  std::shared_ptr<softmax_fwd_t> softmax_fwd;
+  // Both stay null when the temperature is 1, i.e. no scaling step is needed.
+  std::shared_ptr<linear_pd_t> temperature_pd;
+  std::shared_ptr<linear_t> temperature_fwd;
+};
+
+// Out-of-class definitions in this header are `inline` so that including it from
+// more than one translation unit does not produce duplicate symbols.
+inline DNNLSoftmaxFwd::Tensors::Tensors(const NDArray& data, const NDArray& output)
+    : data(data), out(output) {}
+
+inline DNNLSoftmaxFwd::DNNLSoftmaxFwd(const SoftmaxParam& param,
+                                      const Tensors& tensors,
+                                      const bool is_train) {
+  const float temperature = param.temperature.has_value() ? param.temperature.value() : 1.0f;
+  const int axis          = CheckAxis(param.axis, tensors.data.shape().ndim());
+  const auto input_mem    = tensors.data.GetDNNLData();
+
+  softmax_pd  = std::make_shared<softmax_fwd_pd_t>(GetSoftmaxFwdPd(*input_mem, axis, is_train));
+  softmax_fwd = std::make_shared<softmax_fwd_t>(*softmax_pd);
+
+  // A temperature of 1 leaves the input unchanged, so the scaling primitive is skipped.
+  if (temperature != 1.0f) {
+    temperature_pd  = std::make_shared<linear_pd_t>(GetTemperaturePd(*input_mem, temperature));
+    temperature_fwd = std::make_shared<linear_t>(*temperature_pd);
+  }
+}
+
+/*!
+ * \brief Softmax backward built from oneDNN primitives: out_grad is optionally scaled
+ *        by 1 / temperature before the softmax backward primitive is run.
+ */
+class DNNLSoftmaxBwd {
+ public:
+  /*!
+   * \brief Non-owning references to the arrays used by one backward call:
+   *        inputs[0] is out_grad, inputs[1] is out and outputs[0] is data_grad.
+   */
+  struct Tensors {
+    Tensors(const std::vector<NDArray>& inputs, const std::vector<NDArray>& outputs);
+    const NDArray& out_grad;
+    const NDArray& out;
+    const NDArray& data_grad;
+  };
+  /*! \brief Returns the cached instance for this call signature, creating it on a miss. */
+  static DNNLSoftmaxBwd& GetCached(const SoftmaxParam& param, const Tensors& tensors);
+
+  /*! \brief Builds the softmax backward primitive descriptor, hinted by the forward one. */
+  static softmax_bwd_pd_t GetSoftmaxBwdPd(const dnnl::memory& out_grad_mem,
+                                          const dnnl::memory& out_mem,
+                                          const int axis,
+                                          const softmax_fwd_pd_t& hint_fwd_pd);
+
+  DNNLSoftmaxBwd(const SoftmaxParam& param, const Tensors& tensors);
+  /*! \brief Registers the primitives on the DNNL stream, commits data_grad and submits. */
+  void Execute(const Tensors& tensors, const std::vector<OpReqType>& req) const;
+
+ private:
+  std::shared_ptr<softmax_bwd_pd_t> softmax_bwd_pd;
+  std::shared_ptr<softmax_bwd_t> softmax_bwd;
+  // Both stay null when the temperature is 1, i.e. no scaling step is needed.
+  std::shared_ptr<linear_pd_t> temperature_pd;
+  std::shared_ptr<linear_t> temperature_fwd;
+};
+
+inline DNNLSoftmaxBwd::Tensors::Tensors(const std::vector<NDArray>& inputs,
+                                        const std::vector<NDArray>& outputs)
+    : out_grad(inputs[0]), out(inputs[1]), data_grad(outputs[0]) {}
+
+inline DNNLSoftmaxBwd::DNNLSoftmaxBwd(const SoftmaxParam& param, const Tensors& tensors) {
+  const float temperature = param.temperature.has_value() ? param.temperature.value() : 1.0f;
+  const int axis          = CheckAxis(param.axis, tensors.out.shape().ndim());
+  const auto out_grad_mem = tensors.out_grad.GetDNNLData();
+  const auto out_mem      = tensors.out.GetDNNLData();
+  // The backward descriptor is created with a forward (training) descriptor as a hint.
+  const auto softmax_fwd_pd = DNNLSoftmaxFwd::GetSoftmaxFwdPd(*out_mem, axis, true);
+
+  softmax_bwd_pd = std::make_shared<softmax_bwd_pd_t>(
+      GetSoftmaxBwdPd(*out_grad_mem, *out_mem, axis, softmax_fwd_pd));
+  softmax_bwd = std::make_shared<softmax_bwd_t>(*softmax_bwd_pd);
+
+  if (temperature != 1.0f) {
+    temperature_pd =
+        std::make_shared<linear_pd_t>(DNNLSoftmaxFwd::GetTemperaturePd(*out_mem, temperature));
+    temperature_fwd = std::make_shared<linear_t>(*temperature_pd);
+  }
+}
+
+}  // namespace op
+}  // namespace mxnet
+#endif
+#endif  // MXNET_OPERATOR_NN_DNNL_DNNL_SOFTMAX_INL_H_
diff --git a/src/operator/nn/dnnl/dnnl_softmax.cc b/src/operator/nn/dnnl/dnnl_softmax.cc
index f5e5f3e3681d..72a25d4c85b9 100644
--- a/src/operator/nn/dnnl/dnnl_softmax.cc
+++ b/src/operator/nn/dnnl/dnnl_softmax.cc
@@ -23,79 +23,61 @@
  * \author Da Zheng
  */
 
-#include "../softmax-inl.h"
-#include "./dnnl_base-inl.h"
-#include "./dnnl_ops-inl.h"
-
 #if MXNET_USE_ONEDNN == 1
+#include "./dnnl_softmax-inl.h"
+
 namespace mxnet {
 namespace op {
 
-static dnnl::softmax_forward::primitive_desc GetSoftmaxFwdPd(bool is_train,
-                                                             const int axis,
-                                                             const dnnl::memory& input_mem) {
-  dnnl::memory::desc data_md = input_mem.get_desc();
-  auto cpu_engine            = CpuEngine::Get()->get_engine();
-  auto prop = is_train ? dnnl::prop_kind::forward_training : dnnl::prop_kind::forward_scoring;
-  auto desc = dnnl::softmax_forward::desc(prop, data_md, axis);
-  return dnnl::softmax_forward::primitive_desc(desc, cpu_engine);
-}
-
-static dnnl::softmax_backward::primitive_desc GetSoftmaxBwdPd(
-    const dnnl::memory& diff_mem,
-    const dnnl::memory& data_mem,
-    const int axis,
-    const dnnl::softmax_forward::primitive_desc& hint_fwd_pd) {
-  dnnl::memory::desc diff_md = diff_mem.get_desc();
-  dnnl::memory::desc data_md = data_mem.get_desc();
-  auto cpu_engine            = CpuEngine::Get()->get_engine();
-  auto desc                  = dnnl::softmax_backward::desc(diff_md, data_md, axis);
-  return dnnl::softmax_backward::primitive_desc(desc, cpu_engine, hint_fwd_pd);
-}
-
 bool SupportDNNLSoftmax(const SoftmaxParam& param, const NDArray& data, const NDArray& output) {
-  // DNNL does not support temperature argument in their softmax function
-  // now. Need update this once they start to support it.
   const int ndim      = data.shape().ndim();
   const int in_dtype  = data.dtype();
   const int out_dtype = output.dtype();
   const int axis      = CheckAxis(param.axis, ndim);
-  // DNNL does not support temperature argument in their softmax function
-  // now. Need update this once they start to support it.
-  // Currently, DNNL shows bad performance when softmax is not performed on the last dimension
-  if (param.temperature.has_value() || in_dtype != mshadow::kFloat32 || in_dtype != out_dtype ||
-      axis != (ndim - 1)) {
+
+  // The temperature is applied as a 1 / temperature scale, so zero cannot be handled here.
+  if (param.temperature.has_value() && param.temperature.value() == 0.0) {
     return false;
   }
 
-  // only supports ndim = 1, 2, 3, 4 for now
-  return (ndim >= 1 && ndim <= 4);
-}
+  if (in_dtype != mshadow::kFloat32 || in_dtype != out_dtype || axis != (ndim - 1)) {
+    return false;
+  }
 
-class DNNLSoftmaxFwd {
- public:
-  dnnl::softmax_forward::primitive_desc pd;
+  // Supports ndim up to 6
+  return (ndim >= 1 && ndim <= 6);
+}
 
-  DNNLSoftmaxFwd(const bool is_train, const int axis, const dnnl::memory& input)
-      : pd(GetSoftmaxFwdPd(is_train, axis, input)) {
-    fwd_ = std::make_shared<dnnl::softmax_forward>(pd);
-  }
+/*!
+ * \brief Runs the softmax forward pass through the cached oneDNN primitives.
+ *        Returns immediately for kNullOp; kAddTo is rejected.
+ */
+void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
+                        const OpContext& ctx,
+                        const NDArray& in_data,
+                        const OpReqType& req,
+                        const NDArray& out_data) {
+  if (req == kNullOp)
+    return;
+  // Same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now
+  CHECK_NE(req, kAddTo);
 
-  const dnnl::softmax_forward& GetFwd() const {
-    return *fwd_;
+  const auto& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  // Execute() may allocate a scratch buffer for the temperature-scaled input.
+  if (param.temperature.has_value()) {
+    TmpMemMgr::Get()->Init(ctx.requested[0]);
   }
 
- private:
-  std::shared_ptr<dnnl::softmax_forward> fwd_;
-};
+  const bool is_train = ctx.is_train;
+  const auto tensors  = DNNLSoftmaxFwd::Tensors(in_data, out_data);
+  const auto& fwd     = DNNLSoftmaxFwd::GetCached(param, tensors, is_train);
+  fwd.Execute(tensors);
+}
 
 typedef ParamOpSign<SoftmaxParam> DNNLSoftmaxSignature;
-
-static DNNLSoftmaxFwd& GetSoftmaxFwd(const SoftmaxParam& param,
-                                     const int real_axis,
-                                     const bool is_train,
-                                     const NDArray& data,
-                                     const NDArray& output) {
+DNNLSoftmaxFwd& DNNLSoftmaxFwd::GetCached(const SoftmaxParam& param,
+                                          const Tensors& tensors,
+                                          const bool is_train) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<DNNLSoftmaxSignature, DNNLSoftmaxFwd, OpHash> fwds;
 #else
@@ -103,108 +85,160 @@ static DNNLSoftmaxFwd& GetSoftmaxFwd(const SoftmaxParam& param,
 #endif
 
   DNNLSoftmaxSignature key(param);
-  key.AddSign(real_axis);
+  const float temperature = param.temperature.has_value() ? param.temperature.value() : 1.0f;
+  const int axis          = CheckAxis(param.axis, tensors.data.shape().ndim());
+  key.AddSign(axis);
   key.AddSign(is_train);
-  key.AddSign(data);
-  key.AddSign(output);
-
+  key.AddSign(temperature);
+  key.AddSign(tensors.data);
+  key.AddSign(tensors.out);
   auto it = fwds.find(key);
   if (it == fwds.end()) {
-    DNNLSoftmaxFwd fwd(is_train, real_axis, *(data.GetDNNLData()));
+    DNNLSoftmaxFwd fwd(param, tensors, is_train);
     it = AddToCache(&fwds, key, fwd);
   }
   return it->second;
 }
 
-void DNNLSoftmaxForward(const nnvm::NodeAttrs& attrs,
-                        const OpContext& ctx,
-                        const NDArray& in_data,
-                        const OpReqType& req,
-                        const NDArray& out_data) {
-  if (req == kNullOp)
-    return;
-  // same as the FCompute path, softmax only supports kWriteTo and kWriteInplace for now.
-  CHECK_NE(req, kAddTo);
+softmax_fwd_pd_t DNNLSoftmaxFwd::GetSoftmaxFwdPd(const dnnl::memory& input_mem,
+                                                 const int axis,
+                                                 const bool is_train) {
+  const auto data_md    = input_mem.get_desc();
+  const auto cpu_engine = CpuEngine::Get()->get_engine();
+  const auto prop = is_train ? dnnl::prop_kind::forward_training : dnnl::prop_kind::forward_scoring;
+  const auto desc = dnnl::softmax_forward::desc(prop, data_md, axis);
+  return softmax_fwd_pd_t(desc, cpu_engine);
+}
 
-  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  int axis                  = CheckAxis(param.axis, in_data.shape().ndim());
-  auto fwd                  = GetSoftmaxFwd(param, axis, ctx.is_train, in_data, out_data);
+linear_pd_t DNNLSoftmaxFwd::GetTemperaturePd(const dnnl::memory& input_mem,
+                                             const float temperature) {
+  const auto data_md    = input_mem.get_desc();
+  const auto cpu_engine = CpuEngine::Get()->get_engine();
+  const auto desc       = dnnl::eltwise_forward::desc(dnnl::prop_kind::forward_scoring,
+                                                dnnl::algorithm::eltwise_linear,
+                                                data_md,
+                                                1.0f / temperature,
+                                                0.0f);
+  return linear_pd_t(desc, cpu_engine);
+}
 
-  auto in_mem        = in_data.GetDNNLData();
-  auto out_mem       = out_data.GetDNNLData(fwd.pd.dst_desc());
+void DNNLSoftmaxFwd::Execute(const Tensors& tensors) const {
   DNNLStream* stream = DNNLStream::Get();
-  stream->RegisterPrimArgs(fwd.GetFwd(), {{DNNL_ARG_SRC, *in_mem}, {DNNL_ARG_DST, *out_mem}});
+
+  auto original_input_mem = tensors.data.GetDNNLData();
+  const auto out_mem      = tensors.out.GetDNNLData(softmax_pd->dst_desc());
+
+  dnnl::memory* softmax_input_mem;
+  if (temperature_fwd) {
+    // check whether additional buffer is needed, when temperature parameter is being used
+    if (original_input_mem->get_desc() != out_mem->get_desc()) {
+      softmax_input_mem = TmpMemMgr::Get()->Alloc(original_input_mem->get_desc());
+    } else {
+      // Descriptors match: scale straight into the output and run softmax in place.
+      softmax_input_mem = const_cast<dnnl::memory*>(out_mem);
+    }
+    stream->RegisterPrimArgs(
+        *temperature_fwd,
+        {{DNNL_ARG_SRC, *original_input_mem}, {DNNL_ARG_DST, *softmax_input_mem}});
+  } else {
+    softmax_input_mem = const_cast<dnnl::memory*>(original_input_mem);
+  }
+
+  stream->RegisterPrimArgs(*softmax_fwd,
+                           {{DNNL_ARG_SRC, *softmax_input_mem}, {DNNL_ARG_DST, *out_mem}});
   stream->Submit();
 }
 
-class DNNLSoftmaxBwd {
- public:
-  dnnl::softmax_backward::primitive_desc pd;
+softmax_bwd_pd_t DNNLSoftmaxBwd::GetSoftmaxBwdPd(const dnnl::memory& out_grad_mem,
+                                                 const dnnl::memory& out_mem,
+                                                 const int axis,
+                                                 const softmax_fwd_pd_t& hint_fwd_pd) {
+  dnnl::memory::desc out_grad_md = out_grad_mem.get_desc();
+  dnnl::memory::desc out_md      = out_mem.get_desc();
+  const auto cpu_engine          = CpuEngine::Get()->get_engine();
+  const auto desc                = dnnl::softmax_backward::desc(out_grad_md, out_md, axis);
+  return softmax_bwd_pd_t(desc, cpu_engine, hint_fwd_pd);
+}
 
-  DNNLSoftmaxBwd(const dnnl::memory& diff_mem,
-                 const dnnl::memory& data_mem,
-                 const int axis,
-                 const dnnl::softmax_forward::primitive_desc& hint_fwd_pd)
-      : pd(GetSoftmaxBwdPd(diff_mem, data_mem, axis, hint_fwd_pd)) {
-    bwd_ = std::make_shared<dnnl::softmax_backward>(pd);
-  }
+/*!
+ * \brief Runs the softmax backward pass through the cached oneDNN primitives.
+ *        Returns immediately when req[0] is kNullOp.
+ */
+void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
+                         const OpContext& ctx,
+                         const std::vector<NDArray>& inputs,
+                         const std::vector<OpReqType>& req,
+                         const std::vector<NDArray>& outputs) {
+  if (req[0] == kNullOp)
+    return;
+  // DNNLSoftmaxBwd::Tensors reads inputs[0], inputs[1] and outputs[0], so make sure they exist.
+  CHECK_GE(inputs.size(), 2U);
+  CHECK_GE(outputs.size(), 1U);
 
-  const dnnl::softmax_backward& GetBwd() const {
-    return *bwd_;
+  const auto& param = nnvm::get<SoftmaxParam>(attrs.parsed);
+  // Execute() may allocate a scratch buffer for the temperature-scaled out_grad.
+  if (param.temperature.has_value()) {
+    TmpMemMgr::Get()->Init(ctx.requested[0]);
   }
 
- private:
-  std::shared_ptr<dnnl::softmax_backward> bwd_;
-};
+  const auto tensors = DNNLSoftmaxBwd::Tensors(inputs, outputs);
+  const auto& bwd    = DNNLSoftmaxBwd::GetCached(param, tensors);
+  bwd.Execute(tensors, req);
+}
 
-static DNNLSoftmaxBwd& GetSoftmaxBwd(const SoftmaxParam& param,
-                                     const int real_axis,
-                                     const std::vector<NDArray>& data,
-                                     const std::vector<NDArray>& output) {
+DNNLSoftmaxBwd& DNNLSoftmaxBwd::GetCached(const SoftmaxParam& param, const Tensors& tensors) {
 #if DMLC_CXX11_THREAD_LOCAL
   static thread_local std::unordered_map<DNNLSoftmaxSignature, DNNLSoftmaxBwd, OpHash> bwds;
 #else
   static MX_THREAD_LOCAL std::unordered_map<DNNLSoftmaxSignature, DNNLSoftmaxBwd, OpHash> bwds;
 #endif
 
+  const float temperature = param.temperature.has_value() ? param.temperature.value() : 1.0f;
+  const int axis          = CheckAxis(param.axis, tensors.out.shape().ndim());
   DNNLSoftmaxSignature key(param);
-  key.AddSign(real_axis);
-  key.AddSign(data);
-  key.AddSign(output);
+  key.AddSign(axis);
+  key.AddSign(tensors.out);
+  key.AddSign(tensors.out_grad);
+  key.AddSign(temperature);
 
   auto it = bwds.find(key);
   if (it == bwds.end()) {
-    auto diff_mem = data[0].GetDNNLData();
-    auto data_mem = data[1].GetDNNLData();
-    auto fwd_pd   = GetSoftmaxFwdPd(true, real_axis, *data_mem);
-    DNNLSoftmaxBwd bwd(*diff_mem, *data_mem, real_axis, fwd_pd);
+    DNNLSoftmaxBwd bwd(param, tensors);
     it = AddToCache(&bwds, key, bwd);
   }
   return it->second;
 }
 
-void DNNLSoftmaxBackward(const nnvm::NodeAttrs& attrs,
-                         const OpContext& ctx,
-                         const std::vector<NDArray>& in_data,
-                         const std::vector<OpReqType>& req,
-                         const std::vector<NDArray>& out_data) {
-  if (req[0] == kNullOp)
-    return;
-  CHECK_EQ(in_data.size(), 2U);
-  const SoftmaxParam& param = nnvm::get<SoftmaxParam>(attrs.parsed);
-  int axis                  = CheckAxis(param.axis, in_data[1].shape().ndim());
-  auto diff_mem             = in_data[0].GetDNNLData();
-  auto data_mem             = in_data[1].GetDNNLData();
-  auto bwd                  = GetSoftmaxBwd(param, axis, in_data, out_data);
-
-  auto out_mem              = CreateDNNLMem(out_data[0], bwd.pd.diff_src_desc(), req[0]);
-  DNNLStream* stream        = DNNLStream::Get();
-  dnnl_args_map_t args      = {{DNNL_ARG_DST, *data_mem},
-                               {DNNL_ARG_DIFF_DST, *diff_mem},
-                               {DNNL_ARG_DIFF_SRC, *out_mem.second}};
-
-  stream->RegisterPrimArgs(bwd.GetBwd(), args);
-  CommitOutput(out_data[0], out_mem);
+void DNNLSoftmaxBwd::Execute(const Tensors& tensors, const std::vector<OpReqType>& req) const {
+  DNNLStream* stream = DNNLStream::Get();
+
+  const auto original_out_grad_mem = tensors.out_grad.GetDNNLData();
+  const auto out_mem               = tensors.out.GetDNNLData();
+  const auto data_grad_mem =
+      CreateDNNLMem(tensors.data_grad, softmax_bwd_pd->diff_src_desc(), req[0]);
+
+  dnnl::memory* out_grad_mem;
+  if (temperature_fwd) {
+    // check whether additional buffer is needed, when temperature parameter is being used
+    if (original_out_grad_mem->get_desc() != softmax_bwd_pd->diff_src_desc()) {
+      out_grad_mem = TmpMemMgr::Get()->Alloc(original_out_grad_mem->get_desc());
+    } else {
+      // Descriptors match: reuse the data_grad memory as the scaled out_grad buffer.
+      out_grad_mem = const_cast<dnnl::memory*>(data_grad_mem.second);
+    }
+    stream->RegisterPrimArgs(
+        *temperature_fwd, {{DNNL_ARG_SRC, *original_out_grad_mem}, {DNNL_ARG_DST, *out_grad_mem}});
+  } else {
+    out_grad_mem = const_cast<dnnl::memory*>(original_out_grad_mem);
+  }
+
+  dnnl_args_map_t args = {{DNNL_ARG_DST, *out_mem},
+                          {DNNL_ARG_DIFF_DST, *out_grad_mem},
+                          {DNNL_ARG_DIFF_SRC, *data_grad_mem.second}};
+
+  stream->RegisterPrimArgs(*softmax_bwd, args);
+
+  CommitOutput(tensors.data_grad, data_grad_mem);
   stream->Submit();
 }
 
diff --git a/src/operator/nn/softmax.cc b/src/operator/nn/softmax.cc
index 318446165247..8c88d53de939 100644
--- a/src/operator/nn/softmax.cc
+++ b/src/operator/nn/softmax.cc
@@ -153,6 +153,10 @@ Example::
     .set_attr<bool>("TIsDNNL", true)
     .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxComputeExCPU)
     .set_attr<FInferStorageType>("FInferStorageType", SoftmaxStorageType)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
 #endif
     .set_attr<nnvm::FGradient>("FGradient", SoftmaxFGradient{"_backward_softmax"})
     // .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
@@ -186,6 +190,10 @@ NNVM_REGISTER_OP(_backward_softmax)
     .set_attr<bool>("TIsDNNL", true)
     .set_attr<FComputeEx>("FComputeEx<cpu>", SoftmaxGradComputeExCPU)
     .set_attr<FInferStorageType>("FInferStorageType", SoftmaxGradStorageType)
+    .set_attr<FResourceRequest>("FResourceRequest",
+                                [](const NodeAttrs& attrs) {
+                                  return std::vector<ResourceRequest>{ResourceRequest::kTempSpace};
+                                })
 #endif
     .set_attr<FCompute>("FCompute<cpu>",
                         SoftmaxGradCompute<cpu, op::mshadow_op::mul, mxnet_op::softmax_bwd>);