diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 72c5f6c28823..dc6176fe8b51 100644 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -400,7 +400,7 @@ class OperatorProperty { }; /*! \brief typedef the factory function of operator property */ -typedef OperatorProperty *(*OperatorPropertyFactory)(); +typedef std::function OperatorPropertyFactory; /*! * \brief Registry entry for OperatorProperty factory functions. */ @@ -454,12 +454,8 @@ struct OperatorPropertyReg * \endcode */ #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ - static ::mxnet::OperatorProperty* __create__ ## OperatorProperty ## name ## __() { \ - OperatorProperty* ret = new OperatorPropertyType(); \ - return ret; \ - } \ DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \ - .set_body(__create__ ## OperatorProperty ## name ## __) \ + .set_body([]() { return new OperatorPropertyType(); }) \ .check_name() #endif // DMLC_USE_CXX11 diff --git a/src/common/tblob_op_registry.cc b/src/common/tblob_op_registry.cc index e205f29cc42c..ae1f54da3c3a 100644 --- a/src/common/tblob_op_registry.cc +++ b/src/common/tblob_op_registry.cc @@ -11,11 +11,14 @@ namespace mxnet { namespace common { - +class TBlobUnaryOpProp; class TBlobOpRegEntryImpl : public TBlobOpRegEntry { public: - TSelf& set_function(int dev_mask, UnaryFunction funary) override { + // functions + TSelf& set_function(int dev_mask, + UnaryFunction funary, + bool inplace_in_out) override { std::lock_guard lock(mutex_); ++reg_counter_; if (funary_.size() <= static_cast(dev_mask)) { @@ -26,54 +29,46 @@ class TBlobOpRegEntryImpl : public TBlobOpRegEntry { << " already registerd for device " << dev_mask; } funary_[dev_mask] = funary; - // return if it is already registered. - if (reg_counter_ != 1) return *this; + inplace_in0_out_forward_ = inplace_in_out; + if (reg_counter_ == 1) this->DoRegisterUnary(); + return *this; + } - // The body to be registered - auto body = [this] (NDArray **used_vars, - real_t *s, - NDArray **mutate_vars) { - NDArray src = *used_vars[0]; - NDArray *out = mutate_vars[0]; - - if (out->is_none()) { - *out = NDArray(src.shape(), src.ctx(), true); - } else { - CHECK(out->ctx() == src.ctx()) << "target context mismatch"; - CHECK(out->shape() == src.shape()) << "target shape mismatch"; - } - // important: callback must always capture by value - NDArray ret = *out; - // get the const variables - std::vector const_vars; - if (src.var() != ret.var()) const_vars.push_back(src.var()); - // check if the function exist - int dev_mask = src.ctx().dev_mask(); - if (static_cast(dev_mask) >= funary_.size() || - funary_[dev_mask] == nullptr) { - if (dev_mask == gpu::kDevMask) LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; - LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; - } - // invoke the function - UnaryFunction fun = funary_[dev_mask]; - Engine::Get()->PushSync([src, ret, fun, dev_mask](RunContext ctx) { - ret.CheckAndAlloc(); - TBlob tmp = ret.data(); - (*fun)(src.data(), &tmp, ctx); -#if MXNET_USE_CUDA - if (dev_mask == gpu::kDevMask) { - ctx.get_stream()->Wait(); - } -#endif - }, src.ctx(), const_vars, {ret.var()}); - }; - // register the function. - NDArrayReg() - .set_body(body) - .set_num_use_vars(1) - .set_num_mutate_vars(1) - .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) - .add_argument("src", "NDArray", "Source input to the function"); + TSelf& set_gradient(int dev_mask, + UnaryGradType1 fgrad, + bool inplace_out_in_grad) override { + std::lock_guard lock(mutex_); + if (funary_grad_t1_.size() <= static_cast(dev_mask)) { + funary_grad_t1_.resize(dev_mask + 1, nullptr); + } + if (funary_grad_t1_[dev_mask] != nullptr) { + LOG(FATAL) << "Device gradient function " << this->name + << " already registerd for device " << dev_mask; + } + funary_grad_t1_[dev_mask] = fgrad; + inplace_out_in0_grad_ = inplace_out_in_grad; + return *this; + } + + TSelf& set_gradient(int dev_mask, + UnaryGradType2 fgrad, + bool inplace_out_in_grad) override { + std::lock_guard lock(mutex_); + if (funary_grad_t2_.size() <= static_cast(dev_mask)) { + funary_grad_t2_.resize(dev_mask + 1, nullptr); + } + if (funary_grad_t2_[dev_mask] != nullptr) { + LOG(FATAL) << "Device gradient function " << this->name + << " already registerd for device " << dev_mask; + } + funary_grad_t2_[dev_mask] = fgrad; + inplace_out_in0_grad_ = inplace_out_in_grad; + return *this; + } + + TSelf& set_shape_infer(UnaryShapeInfer fshapeinfer) override { + std::lock_guard lock(mutex_); + unary_infer_ = fshapeinfer; return *this; } @@ -81,22 +76,32 @@ class TBlobOpRegEntryImpl : public TBlobOpRegEntry { std::lock_guard lock(mutex_); if (reg_counter_ != 1) return *this; NDArrayReg().describe(description); + OpReg().describe(description); return *this; } - GenericTBlobOp *GetOp() const override { - return nullptr; - } - private: + // make friend with unary op + friend class TBlobUnaryOpProp; // internal mutex std::mutex mutex_; - // unary functions on each device mask - std::vector funary_; // registration counter int reg_counter_{0}; + // unary shape inferencer + UnaryShapeInfer unary_infer_{nullptr}; + // unary functions on each device mask + std::vector funary_; + // type 1 gradient function + std::vector funary_grad_t1_; + // type 2 gradient function + std::vector funary_grad_t2_; + // whether do inplace optimization of in 0 and output + bool inplace_in0_out_forward_{true}; + // whether do inplace optimization of out_grad and in_grad0 + bool inplace_out_in0_grad_{false}; // NDArray registry NDArrayFunctionReg *ndarray_reg_{nullptr}; + OperatorPropertyReg *op_reg_{nullptr}; // internal function to register NDArray function. inline NDArrayFunctionReg &NDArrayReg() { if (ndarray_reg_ == nullptr) { @@ -106,8 +111,209 @@ class TBlobOpRegEntryImpl : public TBlobOpRegEntry { } return *ndarray_reg_; } + // internal function to register NDArray function. + inline OperatorPropertyReg &OpReg() { + if (op_reg_ == nullptr) { + OperatorPropertyReg ® = + ::dmlc::Registry::Get()->__REGISTER__(this->name); + op_reg_ = ® + } + return *op_reg_; + } + // start registering all stuffs + void DoRegisterUnary(); +}; + +// Unary operator to invoke generic TBlob function. +struct TBlobUnaryOperator : public Operator { + TBlobOpRegEntry::UnaryFunction forward; + TBlobOpRegEntry::UnaryGradType1 backward1{nullptr}; + TBlobOpRegEntry::UnaryGradType2 backward2{nullptr}; + + void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) override { + CHECK_EQ(in_data.size(), 1); + CHECK_EQ(out_data.size(), 1); + TBlob out = out_data[0]; + (*forward)(in_data[0], &out, req[0], ctx.run_ctx); + } + + void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) override { + CHECK_EQ(out_grad.size(), 1); + CHECK(in_data.size() == 1 && in_grad.size() == 1); + CHECK_EQ(req.size(), 1); + arg::OutGrad ograd; ograd.data = out_grad[0]; + TBlob igrad = in_grad[0]; + if (backward1 != nullptr) { + arg::OutValue out_value; out_value.data = out_data[0]; + (*backward1)(ograd, out_value, &igrad, req[0], ctx.run_ctx); + } else if (backward2 != nullptr) { + arg::Input0 in0; in0.data = in_data[0]; + (*backward2)(ograd, in0, &igrad, req[0], ctx.run_ctx); + } else { + LOG(FATAL) << "Backward is not supported"; + } + } +}; // class UnaryOperator + +class TBlobUnaryOpProp : public OperatorProperty { + public: + std::string name; + TBlobOpRegEntryImpl* source; + + void Init(const std::vector >& kwargs) override { + } + + std::map GetParams() const override { + return std::map(); + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 1) << "Input:[data]"; + const TShape &dshape = in_shape->at(0); + if (dshape.ndim() == 0) return false; + out_shape->clear(); + if (source->unary_infer_ == nullptr) { + out_shape->push_back(dshape); + } else { + out_shape->push_back((*(source->unary_infer_))(dshape)); + } + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new TBlobUnaryOpProp(); + ptr->source = source; + ptr->name = name; + return ptr; + } + + std::string TypeString() const override { + return name; + } + + // decalre dependency and inplace optimization options + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + if (source->funary_grad_t1_.size() != 0) { + return {out_grad[0], out_data[0]}; + } else if (source->funary_grad_t2_.size() != 0) { + return {out_grad[0], in_data[0]}; + } else { + LOG(FATAL) << "Backward of " << name << " is not decalred"; + return {}; + } + } + + std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const override { + if (source->inplace_out_in0_grad_) { + return {{out_grad[0], in_grad[0]}}; + } else { + return {}; + } + } + + std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const override { + if (source->inplace_in0_out_forward_) { + return {{in_data[0], out_data[0]}}; + } else { + return {}; + } + } + + Operator* CreateOperator(Context ctx) const { + size_t dev_mask = ctx.dev_mask(); + TBlobUnaryOperator *op = new TBlobUnaryOperator(); + CHECK(dev_mask < source->funary_.size() && source->funary_[dev_mask] != nullptr); + op->forward = source->funary_[dev_mask]; + if (dev_mask < source->funary_grad_t1_.size()) { + op->backward1 = source->funary_grad_t1_[dev_mask]; + } + if (dev_mask < source->funary_grad_t2_.size()) { + op->backward2 = source->funary_grad_t2_[dev_mask]; + } + return op; + } }; +void TBlobOpRegEntryImpl::DoRegisterUnary() { + CHECK_EQ(reg_counter_, 1); + // The body to be registered + auto body = [this] (NDArray **used_vars, + real_t *s, + NDArray **mutate_vars) { + NDArray src = *used_vars[0]; + NDArray *out = mutate_vars[0]; + + if (out->is_none()) { + *out = NDArray(src.shape(), src.ctx(), true); + } else { + CHECK(out->ctx() == src.ctx()) << "target context mismatch"; + CHECK(out->shape() == src.shape()) << "target shape mismatch"; + } + // important: callback must always capture by value + NDArray ret = *out; + // get the const variables + std::vector const_vars; + if (src.var() != ret.var()) const_vars.push_back(src.var()); + // check if the function exist + int dev_mask = src.ctx().dev_mask(); + if (static_cast(dev_mask) >= funary_.size() || + funary_[dev_mask] == nullptr) { + if (dev_mask == gpu::kDevMask) LOG(FATAL) << MXNET_GPU_NOT_ENABLED_ERROR; + LOG(FATAL) << "Function " << this->name << "not registered for device " << dev_mask; + } + // invoke the function + UnaryFunction fun = funary_[dev_mask]; + Engine::Get()->PushSync([src, ret, fun, dev_mask](RunContext ctx) { + ret.CheckAndAlloc(); + TBlob tmp = ret.data(); + (*fun)(src.data(), &tmp, kWriteTo, ctx); +#if MXNET_USE_CUDA + if (dev_mask == gpu::kDevMask) { + ctx.get_stream()->Wait(); + } +#endif + }, src.ctx(), const_vars, {ret.var()}); + }; + // register the function. + NDArrayReg() + .set_body(body) + .set_num_use_vars(1) + .set_num_mutate_vars(1) + .set_type_mask(kNDArrayArgBeforeScalar | kAcceptEmptyMutateTarget) + .add_argument("src", "NDArray", "Source input to the function"); + // register the operator + auto op_factory = [this]() { + TBlobUnaryOpProp *prop = new TBlobUnaryOpProp(); + prop->name = this->name; + prop->source = this; + return prop; + }; + OpReg() + .set_body(op_factory) + .add_argument("src", "Symbol", "Source symbolic input to the function"); +} TBlobOpRegEntry& TBlobOpRegistry::__REGISTER_OR_FIND__(const std::string &name) { if (fmap_.count(name) != 0) return *fmap_.at(name); @@ -127,6 +333,5 @@ TBlobOpRegistry::~TBlobOpRegistry() { delete kv.second; } } - } // namespace common } // namespace mxnet diff --git a/src/common/tblob_op_registry.h b/src/common/tblob_op_registry.h index 910543efacb3..495144aa931e 100644 --- a/src/common/tblob_op_registry.h +++ b/src/common/tblob_op_registry.h @@ -11,44 +11,90 @@ #include #include +#include #include #include #include +#if DMLC_USE_CXX11 +#include +#endif + namespace mxnet { namespace common { +/*! \brief namespace of arguments */ +namespace arg { +/*! \brief super class of all gradient function argument */ +struct GradFunctionArgument { + /*! \brief The real data */ + TBlob data; +}; +/*! \brief First input to the function */ +struct Input0 : GradFunctionArgument {}; +/*! \brief Second input to the function */ +struct Input1 : GradFunctionArgument {}; -/*! \brief pre-declare generic TBlob function*/ -struct GenericTBlobOp; +/*! \brief Ouput value of the function to the function */ +struct OutValue : GradFunctionArgument {}; +/*! \brief Gradient of output value */ +struct OutGrad : GradFunctionArgument {}; +} // namespace arg /*! \brief registry for function entry */ class TBlobOpRegEntry { public: - /*! \brief unary tblob function */ typedef void (*UnaryFunction)(const TBlob &src, - TBlob *ret, + TBlob* ret, + OpReqType req, RunContext ctx); + typedef TShape (*UnaryShapeInfer)(const TShape &src); + typedef void (*UnaryGradType1)(const arg::OutGrad& out_grad, + const arg::OutValue& out_value, + TBlob* in_grad, + OpReqType req, + RunContext ctx); + typedef void (*UnaryGradType2)(const arg::OutGrad& out_grad, + const arg::Input0& in_data0, + TBlob* in_grad, + OpReqType req, + RunContext ctx); /*! \brief declare self type */ typedef TBlobOpRegEntry TSelf; /*! \brief name of the entry */ std::string name; + /*! + * \brief set shape inference function, by default use same shape. + * \param dev_mask The device mask of the function can act on. + * \param funary The unary function that peforms the operation. + */ + virtual TSelf& set_shape_infer(UnaryShapeInfer fshapeinfer) = 0; /*! * \brief set function of the function to be funary * \param dev_mask The device mask of the function can act on. * \param funary The unary function that peforms the operation. + * \param inplace_in_out Whether do inplace optimization on in and out. + */ + virtual TSelf& set_function(int dev_mask, + UnaryFunction funary, + bool inplace_in_out) = 0; + /*! + * \brief set gradient of the function of this function. + * \param dev_mask The device mask of the function can act on. + * \param fgrad The gradient function to be set. + * \param inplace_out_in_grad whether out_grad and in_grad can share memory. */ - virtual TSelf& set_function(int dev_mask, UnaryFunction funary) = 0; + virtual TSelf& set_gradient(int dev_mask, + UnaryGradType1 fgrad, + bool inplace_out_in_grad) = 0; + virtual TSelf& set_gradient(int dev_mask, + UnaryGradType2 fgrad, + bool inplace_out_in_grad) = 0; /*! * \brief Describe the function. * \param description The description of the function. * \return reference to self. */ virtual TSelf& describe(const std::string &description) = 0; - /*! - * \brief get the internal function representation - * \return the internal function representation. - */ - virtual GenericTBlobOp *GetOp() const = 0; /*! \brief destructor */ virtual ~TBlobOpRegEntry() {} }; @@ -80,22 +126,10 @@ class TBlobOpRegistry { std::map fmap_; }; -#if DMLC_USE_CXX11 -struct GenericTBlobOp { - /*! \brief function type of the function */ - typedef std::function &in, - TBlob *out, - RunContext ctx)> OpType; - /*! \brief the real operator */ - OpType op; -}; -#endif - #define MXNET_REGISTER_TBLOB_FUN(Name, DEV) \ static ::mxnet::common::TBlobOpRegEntry & \ __make_ ## TBlobOpRegEntry ## _ ## Name ## __ ## DEV ##__ = \ ::mxnet::common::TBlobOpRegistry::Get()->__REGISTER_OR_FIND__(#Name) - } // namespace common } // namespace mxnet #endif // MXNET_COMMON_TBLOB_OP_REGISTRY_H_ diff --git a/src/ndarray/unary_function-inl.h b/src/ndarray/unary_function-inl.h index 7832ce1798cd..45e3e42f2495 100644 --- a/src/ndarray/unary_function-inl.h +++ b/src/ndarray/unary_function-inl.h @@ -8,39 +8,86 @@ #include "../common/tblob_op_registry.h" #include "../operator/mshadow_op.h" - +#include "../operator/operator_common.h" #if defined(__CUDACC__) -#define DEVICE gpu +#define XPU gpu #else -#define DEVICE cpu +#define XPU cpu #endif namespace mxnet { namespace ndarray { +using namespace common; // NOLINT(*) + template -void EvalUnary_(const TBlob &src, - TBlob *ret, RunContext ctx) { +void UnaryForward_(const TBlob &src, + TBlob *ret, + OpReqType req, + RunContext ctx) { + using namespace mxnet::op; using namespace mshadow::expr; mshadow::Stream *s = ctx.get_stream(); - ret->FlatTo2D(s) - = F(src.FlatTo2D(s)); + mshadow::Tensor out = ret->FlatTo2D(s); + Assign(out, req, F(src.FlatTo2D(s))); } -// helper macro to register mshadow element-wise unary opts -// usually you only need to use this to register common operations -#define REGISTER_MSHADOW_UNARY(Name, Op) \ - MXNET_REGISTER_TBLOB_FUN(Name, DEVICE) \ - .set_function(DEVICE::kDevMask, EvalUnary_) +// backward function that takes input value of the op +template +void UnaryBackwardUseIn_(const arg::OutGrad& out_grad, + const arg::Input0& in_data0, + TBlob *in_grad, + OpReqType req, + RunContext ctx) { + using namespace mxnet::op; + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + mshadow::Tensor igrad = in_grad->FlatTo2D(s); + Assign(igrad, req, + F(in_data0.data.FlatTo2D(s)) * + out_grad.data.FlatTo2D()); +} +// backward function that takes output value of the op +template +void UnaryBackwardUseOut_(const arg::OutGrad& out_grad, + const arg::OutValue& out_value, + TBlob *in_grad, + OpReqType req, + RunContext ctx) { + using namespace mxnet::op; + using namespace mshadow::expr; + mshadow::Stream *s = ctx.get_stream(); + mshadow::Tensor igrad = in_grad->FlatTo2D(s); + Assign(igrad, req, + F(out_value.data.FlatTo2D(s)) * + out_grad.data.FlatTo2D()); +} -// register all unary operations here -REGISTER_MSHADOW_UNARY(square, op::mshadow_op::square) +// Register all unary operations here +// Square +struct square_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return 2.0f * a; + } +}; +// The true means inplace can be enabled. +MXNET_REGISTER_TBLOB_FUN(square, XPU) +.set_function(XPU::kDevMask, UnaryForward_, true) +.set_gradient(XPU::kDevMask, UnaryBackwardUseIn_, true) .describe("Take square of the src"); -REGISTER_MSHADOW_UNARY(sqrt, op::mshadow_op::square_root) -.describe("Take square root of the src"); +// Square root +struct square_root_grad { + MSHADOW_XINLINE static real_t Map(real_t a) { + return 0.5f / a; + } +}; +MXNET_REGISTER_TBLOB_FUN(sqrt, XPU) +.set_function(XPU::kDevMask, UnaryForward_, true) +.set_gradient(XPU::kDevMask, UnaryBackwardUseOut_, true) +.describe("Take square root of the src"); } // namespace ndarray } // namespace mxnet #endif // MXNET_NDARRAY_UNARY_FUNCTION_INL_H_ diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 9238ee049c0b..c8ca495d3349 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -80,7 +80,6 @@ struct tanh_grad { } }; - struct square { MSHADOW_XINLINE static real_t Map(real_t a) { return a * a; @@ -107,6 +106,7 @@ struct square_root { return sqrt(a); } }; + } // namespace mshadow_op } // namespace op } // namespace mxnet