diff --git a/Makefile b/Makefile index 1374138baf55..ece927144366 100644 --- a/Makefile +++ b/Makefile @@ -12,9 +12,6 @@ ifndef DMLC_CORE DMLC_CORE = dmlc-core endif -ifndef RABIT - RABIT = rabit -endif # use customized config file include $(config) @@ -65,15 +62,14 @@ endif ENGINE=naive_engine.o BIN = tests/test_simple_engine OBJ = narray_function_cpu.o -OBJCXX11 = batch_norm_cpu.o reshape_cpu.o narray.o c_api.o operator.o symbol.o storage.o fully_connected_cpu.o static_graph.o activation_cpu.o graph_executor.o softmax_cpu.o elementwise_sum_cpu.o pooling_cpu.o convolution_cpu.o io.o iter_mnist.o $(ENGINE) -CUOBJ = +OBJCXX11 = narray.o c_api.o operator.o symbol.o storage.o static_graph.o graph_executor.o io.o iter_mnist.o $(ENGINE) +CUOBJ = narray_function_gpu.o SLIB = lib/libmxnet.so ALIB = lib/libmxnet.a LIB_DEP = $(DMLC_CORE)/libdmlc.a - -ifeq ($(USE_CUDA), 1) - CUOBJ += batch_norm_gpu.o reshape_gpu.o narray_function_gpu.o fully_connected_gpu.o activation_gpu.o elementwise_sum_gpu.o pooling_gpu.o softmax_gpu.o convolution_gpu.o -endif +ALL_DEP = $(OBJ) $(OBJCXX11) $(LIB_DEP) +# common headers, change them will results in rebuild of all files +COMMON_HEADERS=include/mxnet/*.h src/common/*.h .PHONY: clean all test lint doc @@ -85,36 +81,36 @@ $(DMLC_CORE)/libdmlc.a: storage.o: src/storage/storage.cc naive_engine.o: src/dag_engine/naive_engine.cc dag_engine.o: src/dag_engine/dag_engine.cc -simple_engine.o: src/dag_engine/simple_engine.cc +simple_engine.o: src/dag_engine/simple_engine.cc narray.o: src/narray/narray.cc narray_function_cpu.o: src/narray/narray_function.cc src/narray/narray_function-inl.h narray_function_gpu.o: src/narray/narray_function.cu src/narray/narray_function-inl.h -symbol.o: src/symbol/symbol.cc -graph_executor.o: src/symbol/graph_executor.cc -static_graph.o : src/symbol/static_graph.cc +symbol.o: src/symbol/symbol.cc src/symbol/*.h +graph_executor.o: src/symbol/graph_executor.cc src/symbol/*.h +static_graph.o : src/symbol/static_graph.cc src/symbol/*.h operator.o: src/operator/operator.cc c_api.o: src/c_api.cc -fully_connected_cpu.o: src/operator/fully_connected.cc -fully_connected_gpu.o: src/operator/fully_connected.cu -activation_cpu.o: src/operator/activation.cc -activation_gpu.o: src/operator/activation.cu -elementwise_sum_cpu.o: src/operator/elementwise_sum.cc -elementwise_sum_gpu.o: src/operator/elementwise_sum.cu -pooling_cpu.o: src/operator/pooling.cc -pooling_gpu.o: src/operator/pooling.cu -softmax_cpu.o: src/operator/softmax.cc -softmax_gpu.o: src/operator/softmax.cu -convolution_cpu.o: src/operator/convolution.cc -convolution_gpu.o: src/operator/convolution.cu -reshape_cpu.o: src/operator/reshape.cc -reshape_gpu.o: src/operator/reshape.cu -batch_norm_cpu.o: src/operator/batch_norm.cc -batch_norm_gpu.o: src/operator/batch_norm.cu io.o: src/io/io.cc -iter_mnist.o: src/io/iter_mnist.cc +iter_mnist.o: src/io/iter_mnist.cc src/io/*.h -lib/libmxnet.a: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP) -lib/libmxnet.so: $(OBJ) $(OBJCXX11) $(CUOBJ) $(LIB_DEP) +# Rules for operators +OPERATOR_HDR=$(wildcard src/operator/*-inl.h) +OPERATOR_OBJ=$(patsubst %-inl.h, %_cpu.o, $(OPERATOR_HDR)) +OPERATOR_CUOBJ=$(patsubst %-inl.h, %_gpu.o, $(OPERATOR_HDR)) + +ALL_DEP += $(OPERATOR_OBJ) +ifeq ($(USE_CUDA), 1) + ALL_DEP += $(OPERATOR_CUOBJ) $(CUOBJ) +endif + +src/operator/%_cpu.o : src/operator/%.cc src/operator/%-inl.h src/operator/mshadow_op.h src/operator/operator_common.h $(COMMON_HEADERS) + $(CXX) -std=c++0x -c $(CFLAGS) -o $@ $(filter %.cpp %.c %.cc, $^) + +src/operator/%_gpu.o : src/operator/%.cu src/operator/%-inl.h src/operator/operator_common.h src/operator/mshadow_op.h $(COMMON_HEADERS) + $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $(filter %.cu, $^) + +lib/libmxnet.a: $(ALL_DEP) +lib/libmxnet.so: $(ALL_DEP) tests/test_storage: tests/test_storage.cc lib/libmxnet.a tests/test_simple_engine: tests/test_simple_engine.cc lib/libmxnet.a @@ -122,10 +118,10 @@ tests/test_simple_engine: tests/test_simple_engine.cc lib/libmxnet.a $(BIN) : $(CXX) $(CFLAGS) -std=c++0x -o $@ $(filter %.cpp %.o %.c %.a %.cc, $^) $(LDFLAGS) -$(OBJ) : +$(OBJ) : $(COMMON_HEADERS) $(CXX) -c $(CFLAGS) -o $@ $(filter %.cpp %.c %.cc, $^) -$(OBJCXX11) : +$(OBJCXX11) : $(COMMON_HEADERS) $(CXX) -std=c++0x -c $(CFLAGS) -o $@ $(filter %.cpp %.c %.cc, $^) $(SLIB) : @@ -134,13 +130,12 @@ $(SLIB) : $(ALIB): $(OBJ) $(OBJCXX11) ar cr $@ $+ -$(CUOBJ) : +$(CUOBJ) :$(COMMON_HEADERS) $(NVCC) -c -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" $(filter %.cu, $^) $(CUBIN) : $(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^) - lint: python dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts test python @@ -148,5 +143,5 @@ doxygen: doxygen doc/Doxyfile clean: - $(RM) $(OBJ) $(OBJCXX11) $(BIN) $(CUBIN) $(CUOBJ) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ + $(RM) $(ALL_DEP) $(SLIB) $(ALIB) *~ */*~ */*/*~ */*/*/*~ cd $(DMLC_CORE); make clean; cd - diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h index 5d445261f4c1..06da4d841944 100755 --- a/include/mxnet/narray.h +++ b/include/mxnet/narray.h @@ -269,8 +269,9 @@ class NArray { friend void BinaryOp(const NArray &lhs, const NArray &rhs, NArray *out); template friend void UnaryOp(const NArray &lhs, const NArray &rhs, NArray *out); - template + template friend void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out); + friend void SetValueOp(const real_t &rhs, NArray *out); }; /*! @@ -385,6 +386,23 @@ struct NArrayFunctionReg num_mutate_vars(0), num_scalars(0), type_mask(0) {} + /*! + * \brief set the function body to a NArray setvalue function + * this will also auto set the parameters correctly + * \param fsetvalue function body to set + * \return ref to the registered entry, used to set properties + */ + inline NArrayFunctionReg &set_function(void fsetvalue(const real_t &rhs, + NArray *out)) { + body = [fsetvalue] (NArray **used_vars, + real_t *s, NArray **mutate_vars) { + fsetvalue(s[0], mutate_vars[0]); + }; + num_mutate_vars = 1; num_scalars = 1; + // type_mask = kNArrayArgBeforeScalar; + this->add_argument("rhs", "real_t", "Right operand to the function."); + return *this; + } /*! * \brief set the function body to a binary NArray function * this will also auto set the parameters correctly diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h index 5fe7f2d7ee5a..9b900db4fdb9 100755 --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -80,7 +80,7 @@ class Operator { * \param req the request types of saving operation, can only be kWriteTo or kWriteInplace. * \param out_data array of output data, pointer is used to indicate that this is holder * the space of TBlob in out_data must be pre-allocated with InferShape - * \param aux_states Auxiliary states of operator. Normally operator doesn't + * \param aux_states Auxiliary states of operator. Normally operator doesn't * need, epecial case like Batch Norm requires. * \sa OpReqType, OpContext */ @@ -411,11 +411,12 @@ struct OperatorPropertyReg * \endcode */ #define MXNET_REGISTER_OP_PROPERTY(name, OperatorPropertyType) \ - static ::mxnet::OperatorProperty* __create__ ## OperatorPropertyType ## __() { \ + static ::mxnet::OperatorProperty* __create__ ## OperatorProperty ## name ## __() { \ return new OperatorPropertyType; \ } \ DMLC_REGISTRY_REGISTER(::mxnet::OperatorPropertyReg, OperatorPropertyReg, name) \ - .set_body(__create__ ## OperatorPropertyType ## __) + .set_body(__create__ ## OperatorProperty ## name ## __) + #endif // DMLC_USE_CXX11 } // namespace mxnet #endif // MXNET_OPERATOR_H_ diff --git a/python/mxnet/base.py b/python/mxnet/base.py index 1e7b66bb1317..df91998e9a45 100644 --- a/python/mxnet/base.py +++ b/python/mxnet/base.py @@ -15,11 +15,13 @@ #---------------------------- if sys.version_info[0] == 3: string_types = str, + numeric_types = (float, int) # this function is needed for python3 # to convert ctypes.char_p .value back to python str py_str = lambda x: x.decode('utf-8') else: string_types = basestring, + numeric_types = (float, int, long) py_str = lambda x: x diff --git a/python/mxnet/executor.py b/python/mxnet/executor.py old mode 100755 new mode 100644 diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py index b0f38e94c3a5..76fdc5f893d0 100755 --- a/python/mxnet/narray.py +++ b/python/mxnet/narray.py @@ -3,8 +3,9 @@ from __future__ import absolute_import import ctypes +import warnings import sys -from .base import _LIB, string_types +from .base import _LIB, string_types, numeric_types from .base import c_array, py_str, c_str from .base import mx_uint, mx_float, NArrayHandle, FunctionHandle from .base import ctypes2numpy_shared, ctypes2buffer @@ -66,7 +67,7 @@ def __del__(self): def __add__(self, other): if isinstance(other, NArray): return NArray._plus(self, other) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._plus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) @@ -74,7 +75,7 @@ def __add__(self, other): def __iadd__(self, other): if isinstance(other, NArray): return NArray._plus(self, other, out=self) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._plus_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -85,7 +86,7 @@ def __radd__(self, other): def __sub__(self, other): if isinstance(other, NArray): return NArray._minus(self, other) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._minus_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) @@ -93,23 +94,32 @@ def __sub__(self, other): def __isub__(self, other): if isinstance(other, NArray): return NArray._minus(self, other, out=self) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._minus_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) + def __rsub__(self, other): + if isinstance(other, numeric_types): + return NArray._rminus_scalar(self, float(other)) + else: + raise TypeError('type %s not supported' % str(type(other))) + def __mul__(self, other): if isinstance(other, NArray): return NArray._mul(self, other) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._mul_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) + def __neg__(self): + return NArray._mul_scalar(self, -1.0, out=self) + def __imul__(self, other): if isinstance(other, NArray): return NArray._mul(self, other, out=self) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._mul_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -120,15 +130,21 @@ def __rmul__(self, other): def __div__(self, other): if isinstance(other, NArray): return NArray._div(self, other) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._div_scalar(self, float(other)) else: raise TypeError('type %s not supported' % str(type(other))) + def __rdiv__(self, other): + if isinstance(other, numeric_types): + return NArray._rdiv_scalar(self, float(other)) + else: + raise TypeError('type %s not supported' % str(type(other))) + def __idiv__(self, other): if isinstance(other, NArray): return NArray._div(self, other, out=self) - elif isinstance(other, float) or isinstance(other, int): + elif isinstance(other, numeric_types): return NArray._div_scalar(self, float(other), out=self) else: raise TypeError('type %s not supported' % str(type(other))) @@ -163,9 +179,13 @@ def __setitem__(self, in_slice, value): """Set narray value""" if in_slice.step != None: raise Exception("Set NArray should use empty index array[:] = target_array") - if isinstance(value, NArray) == False: + if isinstance(value, NArray): + if value.handle is not self.handle: + value.copyto(self) + elif isinstance(value, numeric_types): + return NArray._set_value(float(value), out=self) + else: raise TypeError('type %s not supported' % str(type(value))) - value.copyto(self) def __getitem__(self, in_slice): """Get narray""" @@ -238,6 +258,10 @@ def copyto(self, other): The copy target NArray """ if isinstance(other, NArray): + if other.handle is self.handle: + warnings.warn('copy an array to itself, is it intended?', + RuntimeWarning) + return return NArray._copyto(self, out=other) elif isinstance(other, Context): hret = NArray(_new_alloc_handle(self.shape, other, True)) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py old mode 100755 new mode 100644 index 3fb5ae665764..9b93db788173 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -15,6 +15,7 @@ class Symbol(object): """Symbol is symbolic graph of the mxnet.""" + # pylint: disable=no-member def __init__(self, handle): """Initialize the function with handle @@ -25,6 +26,39 @@ def __init__(self, handle): """ self.handle = handle + def __add__(self, other): + if isinstance(other, Symbol): + return Symbol._Plus(self, other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __radd__(self, other): + return self.__add__(other) + + def __sub__(self, other): + if isinstance(other, Symbol): + return Symbol._Minus(self, other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __mul__(self, other): + if isinstance(other, Symbol): + return Symbol._Mul(self, other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __rmul__(self, other): + return self.__mul__(other) + + def __div__(self, other): + if isinstance(other, Symbol): + return Symbol._Div(self, other) + else: + raise TypeError('type %s not supported' % str(type(other))) + + def __truediv__(self, other): + return self.__div__(other) + def __del__(self): check_call(_LIB.MXSymbolFree(self.handle)) @@ -238,19 +272,19 @@ def debug_str(self): return py_str(debug_str.value) def bind(self, ctx, args, args_grad, reqs, aux_states=None): - """bind current symbol to get an executor. + """Bind current symbol to get an executor. Parameters ---------- - ctx: Context + ctx : Context context executor to run on - args: Array of NArray + args : Array of NArray input args to the symbol - args_grad: Array of NArray + args_grad : Array of NArray input args' gradient - reqs: Array of enum + reqs : Array of enum graident requirements - aux_states: Array of NArray + aux_states : Array of NArray input auxiliary states to the symbol """ # TODO(bing): consider a more friendly interface @@ -278,11 +312,11 @@ def bind(self, ctx, args, args_grad, reqs, aux_states=None): return Executor(handle) def grad(self, wrt): - """get the autodiff of current symbol. + """Get the autodiff of current symbol. Parameters ---------- - wrt: Array of String + wrt : Array of String keyword arguments of the symbol that the gradients are taken. """ handle = SymbolHandle() @@ -292,6 +326,8 @@ def grad(self, wrt): c_wrt, ctypes.byref(handle))) return Symbol(handle) + # pylint: enable= no-member + def Variable(name): """Create a symbolic variable with specified name. @@ -431,7 +467,10 @@ def _init_symbol_module(): for i in range(size.value): hdl = SymbolHandle(plist[i]) function = _make_atomic_symbol_function(hdl) - setattr(module_obj, function.__name__, function) + if function.__name__.startswith('_'): + setattr(Symbol, function.__name__, staticmethod(function)) + else: + setattr(module_obj, function.__name__, function) # Initialize the atomic symbo in startups _init_symbol_module() diff --git a/src/narray/narray.cc b/src/narray/narray.cc index a9f7ebde678d..eee59ed8ecd1 100755 --- a/src/narray/narray.cc +++ b/src/narray/narray.cc @@ -78,6 +78,34 @@ inline void BinaryOp(const NArray &lhs, } } +inline void SetValueOp(const real_t &rhs, NArray *out) { + CHECK_NE(out->is_none(), true) << "Set value target must not be empty"; + // important: callback must always capture by value + NArray ret = *out; + switch (ret.ctx().dev_mask) { + case cpu::kDevMask: { + auto func = [rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(rhs, &tmp, ctx); + }; + DAGEngine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + break; + } +#if MXNET_USE_CUDA + case gpu::kDevMask: { + auto func = [rhs, ret](RunContext ctx) { + ret.ptr_->CheckAndAlloc(); + TBlob tmp = ret.data(); + narray::Eval(rhs, &tmp, ctx); + }; + DAGEngine::Get()->Push(func, ret.ctx(), {}, {ret.ptr_->var}); + break; + } +#endif + default: LOG(FATAL) << "GPU is not enabled"; + } +} /*! * \brief run a binary operation * \param lhs left operand @@ -85,7 +113,7 @@ inline void BinaryOp(const NArray &lhs, * \param out the output narray * \param binary_op the real */ -template +template inline void ScalarOp(const NArray &lhs, const real_t &rhs, NArray *out) { @@ -104,7 +132,7 @@ inline void ScalarOp(const NArray &lhs, auto func = [lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); + narray::Eval(lhs.data(), rhs, &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var) { DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); @@ -118,7 +146,7 @@ inline void ScalarOp(const NArray &lhs, auto func = [lhs, rhs, ret](RunContext ctx) { ret.ptr_->CheckAndAlloc(); TBlob tmp = ret.data(); - narray::Eval(lhs.data(), rhs, &tmp, ctx); + narray::Eval(lhs.data(), rhs, &tmp, ctx); }; if (lhs.ptr_->var == ret.ptr_->var) { DAGEngine::Get()->Push(func, lhs.ctx(), {}, {ret.ptr_->var}); @@ -194,11 +222,11 @@ inline NArray BinaryOpRet(const NArray &lhs, return ret; } -template +template inline NArray ScalarOpRet(const NArray &lhs, const real_t &rhs) { NArray ret; - ScalarOp(lhs, rhs, &ret); + ScalarOp(lhs, rhs, &ret); return ret; } @@ -212,9 +240,10 @@ inline NArray &BinaryOpApply(NArray *dst, template inline NArray &ScalarOpApply(NArray *dst, const real_t &src) { - ScalarOp(*dst, src, dst); + ScalarOp(*dst, src, dst); return *dst; } + // Binary NArray operator+(const NArray &lhs, const NArray &rhs) { return BinaryOpRet(lhs, rhs); @@ -230,18 +259,23 @@ NArray operator/(const NArray &lhs, const NArray &rhs) { } // Scalar NArray operator+(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); + return ScalarOpRet(lhs, rhs); } NArray operator-(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); + return ScalarOpRet(lhs, rhs); } NArray operator*(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); + return ScalarOpRet(lhs, rhs); } NArray operator/(const NArray &lhs, const real_t &rhs) { - return ScalarOpRet(lhs, rhs); + return ScalarOpRet(lhs, rhs); } // Binary +NArray &NArray::operator=(real_t scalar) { + SetValueOp(scalar, this); + return *this; +} + NArray &NArray::operator+=(const NArray &src) { return BinaryOpApply(this, src); } @@ -333,16 +367,28 @@ NArray NArray::Copy(Context ctx) const { // register API function // those with underscore will be registered at NArray +MXNET_REGISTER_NARRAY_FUN(_set_value).set_function(SetValueOp); + MXNET_REGISTER_NARRAY_FUN(_plus).set_function(BinaryOp); MXNET_REGISTER_NARRAY_FUN(_minus).set_function(BinaryOp); MXNET_REGISTER_NARRAY_FUN(_mul).set_function(BinaryOp); MXNET_REGISTER_NARRAY_FUN(_div).set_function(BinaryOp); -/////// -MXNET_REGISTER_NARRAY_FUN(_plus_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_minus_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_mul_scalar).set_function(ScalarOp); -MXNET_REGISTER_NARRAY_FUN(_div_scalar).set_function(ScalarOp); +// register API function +// those with underscore will be registered at NArray +// scalar +MXNET_REGISTER_NARRAY_FUN(_plus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_minus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_mul_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_div_scalar).set_function(ScalarOp); + +// register API function +// those with underscore will be registered at NArray +// scalar +// reverse scalar +MXNET_REGISTER_NARRAY_FUN(_rminus_scalar).set_function(ScalarOp); +MXNET_REGISTER_NARRAY_FUN(_rdiv_scalar).set_function(ScalarOp); + // copy function is special // that we need to remove kAcceptEmptyMutateTarget from it MXNET_REGISTER_NARRAY_FUN(_copyto) diff --git a/src/narray/narray_function-inl.h b/src/narray/narray_function-inl.h index 6c79d74e3f5a..155a4e19c1b7 100755 --- a/src/narray/narray_function-inl.h +++ b/src/narray/narray_function-inl.h @@ -17,13 +17,23 @@ #endif #ifndef DECL_SCALAR -#define DECL_SCALAR(XPU, OP, FUN) \ +#define DECL_SCALAR(XPU, OP, FUN, REVERSE) \ template<> \ - void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ - FUN(lhs, rhs, ret, ctx); \ + void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { \ + FUN(lhs, rhs, ret, ctx); \ + } +#endif + +#ifndef DECL_SETVALUE +#define DECL_SETVALUE(XPU) \ + template<> \ + void Eval(const real_t &rhs, TBlob *ret, RunContext ctx) { \ + mshadow::Stream *s = static_cast*>(ctx.stream); \ + ret->FlatTo2D(s) = rhs; \ } #endif + #if defined(__CUDACC__) #define DEVICE gpu #else @@ -43,23 +53,37 @@ inline void EvalBinary_(const TBlob &lhs, const TBlob &rhs, rhs.FlatTo2D(s)); } -template +template inline void EvalScalar_(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx) { using namespace mshadow::expr; mshadow::Stream *s = static_cast*>(ctx.stream); - ret->FlatTo2D(s) - = F(lhs.FlatTo2D(s), rhs); + if (reverse) { + ret->FlatTo2D(s) + = F(rhs, lhs.FlatTo2D(s)); + } else { + ret->FlatTo2D(s) + = F(lhs.FlatTo2D(s), rhs); + } } + + // declarations DECL_BINARY(DEVICE, Plus, EvalBinary_) DECL_BINARY(DEVICE, Minus, EvalBinary_) DECL_BINARY(DEVICE, Mul, EvalBinary_) DECL_BINARY(DEVICE, Div, EvalBinary_) -DECL_SCALAR(DEVICE, Plus, EvalScalar_) -DECL_SCALAR(DEVICE, Minus, EvalScalar_) -DECL_SCALAR(DEVICE, Mul, EvalScalar_) -DECL_SCALAR(DEVICE, Div, EvalScalar_) +DECL_SCALAR(DEVICE, Plus, EvalScalar_, true) +DECL_SCALAR(DEVICE, Minus, EvalScalar_, true) +DECL_SCALAR(DEVICE, Mul, EvalScalar_, true) +DECL_SCALAR(DEVICE, Div, EvalScalar_, true) +// for reverse seq +DECL_SCALAR(DEVICE, Plus, EvalScalar_, false) +DECL_SCALAR(DEVICE, Minus, EvalScalar_, false) +DECL_SCALAR(DEVICE, Mul, EvalScalar_, false) +DECL_SCALAR(DEVICE, Div, EvalScalar_, false) +// +DECL_SETVALUE(DEVICE) } // namespace narray } // namespace mxnet diff --git a/src/narray/narray_function.h b/src/narray/narray_function.h old mode 100644 new mode 100755 index 4ea556883ede..dc879c28c1e8 --- a/src/narray/narray_function.h +++ b/src/narray/narray_function.h @@ -36,9 +36,12 @@ struct Div : public BinaryBase { template void Eval(const TBlob &lhs, const TBlob &rhs, TBlob *ret, RunContext ctx); -template +template void Eval(const TBlob &lhs, const real_t &rhs, TBlob *ret, RunContext ctx); +template +void Eval(const real_t &rhs, TBlob *ret, RunContext ctx); + // copy function when only cpu is involved template void Copy(const TBlob &from, TBlob *to, diff --git a/src/operator/elementwise_binary_op-inl.h b/src/operator/elementwise_binary_op-inl.h new file mode 100644 index 000000000000..89d8b115bf6b --- /dev/null +++ b/src/operator/elementwise_binary_op-inl.h @@ -0,0 +1,234 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file elementwise_binary_op-inl.h + * \brief Elementwise binary operation, plus, minus, mul, div +*/ +#ifndef MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ +#define MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ + +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" +#include "./mshadow_op.h" + +namespace mxnet { +namespace op { + +enum ElementwiseBinaryOpInputs {kLhs, kRhs}; +enum ElementwiseBinaryOpOutputs {kOut}; +enum ElementwiseBinaryOpType {kPlus, kMinus, kMul, kDiv}; + +template +inline ElementwiseBinaryOpType GetOpType(); + +template +inline const char* GetOpTypeString(); + +template<> +inline ElementwiseBinaryOpType GetOpType() { + return kPlus; +} +template<> +inline ElementwiseBinaryOpType GetOpType() { + return kMinus; +} +template<> +inline ElementwiseBinaryOpType GetOpType() { + return kMul; +} +template<> +inline ElementwiseBinaryOpType GetOpType() { + return kDiv; +} + +template<> +inline const char* GetOpTypeString() { + return "_Plus"; +} +template<> +inline const char* GetOpTypeString() { + return "_Minus"; +} + +template<> +inline const char* GetOpTypeString() { + return "_Mul"; +} + +template<> +inline const char* GetOpTypeString() { + return "_Div"; +} + +template +class ElementwiseBinaryOp : public Operator { + public: + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(in_data.size(), 2); + CHECK_EQ(out_data.size(), 1); + Stream *s = ctx.get_stream(); + Tensor lhs = in_data[kLhs].FlatTo2D(s); + Tensor rhs = in_data[kRhs].FlatTo2D(s); + Tensor out = out_data[kOut].FlatTo2D(s); + Assign(out, req[kOut], F(lhs, rhs)); + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK(in_data.size() == 2 && in_grad.size() == 2); + CHECK_EQ(req.size(), 2); + + Stream *s = ctx.get_stream(); + Tensor m_out_grad = out_grad[kOut].FlatTo2D(s); + Tensor lhs_grad = in_grad[kLhs].FlatTo2D(s); + Tensor rhs_grad = in_grad[kRhs].FlatTo2D(s); + switch (GetOpType()) { + case kPlus: { + Assign(lhs_grad, req[kLhs], F(m_out_grad)); + Assign(rhs_grad, req[kRhs], F(m_out_grad)); + break; + } + case kMinus: { + Assign(lhs_grad, req[kLhs], F(m_out_grad)); + Assign(rhs_grad, req[kRhs], F(m_out_grad)); + break; + } + case kMul: { + Tensor lhs_data = in_data[kLhs].FlatTo2D(s); + Tensor rhs_data = in_data[kRhs].FlatTo2D(s); + // rhs cannot do inplace + CHECK_NE(req[kRhs], kWriteInplace); + Assign(rhs_grad, req[kRhs], lhs_data * m_out_grad); + Assign(lhs_grad, req[kLhs], rhs_data * m_out_grad); + break; + } + case kDiv: { + Tensor lhs_data = in_data[kLhs].FlatTo2D(s); + Tensor rhs_data = in_data[kRhs].FlatTo2D(s); + // rhs cannot do inplace + CHECK_NE(req[kRhs], kWriteInplace); + Assign(rhs_grad, req[kRhs], + F(m_out_grad * lhs_data) / F(rhs_data)); + Assign(lhs_grad, req[kLhs], m_out_grad / rhs_data); + break; + } + } + } +}; // class ElementwiseBinaryOp + + +template +inline Operator* CreateElementwiseBinaryOp_(ElementwiseBinaryOpType type) { + switch (type) { + case kPlus: return new ElementwiseBinaryOp(); + case kMinus: return new ElementwiseBinaryOp(); + case kMul: return new ElementwiseBinaryOp(); + case kDiv: return new ElementwiseBinaryOp(); + } + LOG(FATAL) << "uknown op type"; + return NULL; +} + +// Decalre Factory function, used for dispatch specialization +template +Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type); + +#if DMLC_USE_CXX11 +template +class ElementwiseBinaryOpProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + CHECK_EQ(kwargs.size(), 0) + << TypeString() << " do not take any additional keyword arguments besides lhs and rhs"; + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), 2) << "Input:[lhs, rhs]"; + if (in_shape->at(kLhs).ndim() != 0) { + SHAPE_ASSIGN_CHECK(*in_shape, kRhs, in_shape->at(kLhs)); + } else if (in_shape->at(kRhs).ndim() != 0) { + in_shape->at(kLhs) = in_shape->at(kRhs); + } else { + return false; + } + const TShape &dshape = in_shape->at(kLhs); + out_shape->clear(); + out_shape->push_back(dshape); + return true; + } + + std::vector ListArguments() const override { + return {"lhs", "rhs"}; + } + + OperatorProperty* Copy() const override { + return new ElementwiseBinaryOpProp(); + } + + std::string TypeString() const override { + return GetOpTypeString(); + } + + // decalre dependency and inplace optimization options + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + switch (GetOpType()) { + case kPlus: + case kMinus: return {out_grad[kOut]}; + case kMul: + case kDiv: return {out_grad[kOut], in_data[kLhs], in_data[kRhs]}; + } + LOG(FATAL) << "not reached"; + return {}; + } + + std::vector > BackwardInplaceOption( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &in_grad) const override { + switch (GetOpType()) { + case kPlus: + case kMinus: return {}; + case kMul: + case kDiv: return {{out_grad[kOut], in_grad[kLhs]}}; + } + LOG(FATAL) << "not reached"; + return {}; + } + + std::vector > ForwardInplaceOption( + const std::vector &in_data, + const std::vector &out_data) const override { + return {{in_data[kLhs], out_data[kOut]}}; + } + + Operator* CreateOperator(Context ctx) const override; +}; +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet +#endif // MXNET_OPERATOR_ELEMENTWISE_BINARY_OP_INL_H_ diff --git a/src/operator/elementwise_binary_op.cc b/src/operator/elementwise_binary_op.cc new file mode 100644 index 000000000000..bdf228df4851 --- /dev/null +++ b/src/operator/elementwise_binary_op.cc @@ -0,0 +1,31 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file elementwise_binary_op.cc + * \brief elementwise binary operator +*/ +#include "./elementwise_binary_op-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type) { + return CreateElementwiseBinaryOp_(type); +} + +// DO_BIND_DISPATCH comes from static_operator_common.h +template +Operator* ElementwiseBinaryOpProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateElementwiseBinaryOp, GetOpType()); +} + +MXNET_REGISTER_OP_PROPERTY(_Plus, ElementwiseBinaryOpProp) +.describe("Perform an elementwise plus."); +MXNET_REGISTER_OP_PROPERTY(_Minus, ElementwiseBinaryOpProp) +.describe("Perform an elementwise minus."); +MXNET_REGISTER_OP_PROPERTY(_Mul, ElementwiseBinaryOpProp) +.describe("Perform an elementwise mul."); +MXNET_REGISTER_OP_PROPERTY(_Div, ElementwiseBinaryOpProp) +.describe("Perform an elementwise div."); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/elementwise_binary_op.cu b/src/operator/elementwise_binary_op.cu new file mode 100644 index 000000000000..36616dbda797 --- /dev/null +++ b/src/operator/elementwise_binary_op.cu @@ -0,0 +1,15 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file elementwise_binary_op.cu + * \brief elementwise binary operator +*/ +#include "./elementwise_binary_op-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type) { + return CreateElementwiseBinaryOp_(type); +} +} // namespace op +} // namespace mxnet diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h index 010cf0ce7cc9..1994006b2205 100644 --- a/src/operator/mshadow_op.h +++ b/src/operator/mshadow_op.h @@ -18,12 +18,20 @@ struct identity { return a; } }; + struct identity_grad { MSHADOW_XINLINE static real_t Map(real_t a) { return 1.0f; } }; + +struct negation { + MSHADOW_XINLINE static real_t Map(real_t a) { + return -a; + } +}; + /*! \brief sigmoid unit */ struct sigmoid { MSHADOW_XINLINE static real_t Map(real_t a) { diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc index ddc7d96556e6..6daa30ef21d1 100755 --- a/src/symbol/symbol.cc +++ b/src/symbol/symbol.cc @@ -250,8 +250,9 @@ void Symbol::Compose(const std::vector& args, CHECK(!heads_[0].source->is_variable()) << "Variable cannot be composed"; heads_[0].source->name = name; for (size_t i = 0; i < args.size(); ++i) { - CHECK_NE(args[i].NumReturns(), 1) - << "Argument " << i << " is a tuple, scalar is required"; + CHECK_EQ(args[i].NumReturns(), 1) + << "Argument " << i << " is a tuple with " << args[i].NumReturns() + << " elements, scalar is required"; } // positional arguments requires all arguments for now. // TODO(bing) consider partial assignments diff --git a/tests/python/test_bind.py b/tests/python/test_bind.py new file mode 100644 index 000000000000..17ea3353fc23 --- /dev/null +++ b/tests/python/test_bind.py @@ -0,0 +1,64 @@ +import numpy as np +import mxnet as mx + + +def reldiff(a, b): + diff = np.sum(np.abs(a - b)) + norm = np.sum(np.abs(a)) + reldiff = diff / norm + return reldiff + + +def check_bind_with_uniform(uf, gf, dim): + """check function consistency with uniform random numbers""" + shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) + lhs = mx.symbol.Variable('lhs') + rhs = mx.symbol.Variable('rhs') + ret = uf(lhs, rhs) + assert ret.list_arguments() == ['lhs', 'rhs'] + lhs_arr = mx.narray.create(shape) + rhs_arr = mx.narray.create(shape) + lhs_grad = mx.narray.create(shape) + rhs_grad = mx.narray.create(shape) + lhs_arr.numpy[:] = np.random.uniform(-10, 10, shape) + rhs_arr.numpy[:] = np.random.uniform(-10, 10, shape) + + executor = ret.bind(mx.Context('cpu'), + args=[lhs_arr, rhs_arr], + args_grad=[lhs_grad, rhs_grad], + reqs=['write_to'] * 2) + executor.forward() + out2 = executor.heads()[0].numpy + out1 = uf(lhs_arr.numpy, rhs_arr.numpy) + assert reldiff(out1, out2) < 1e-6 + # test gradient + out_grad = mx.narray.create(shape) + out_grad.numpy[:] = np.ones(shape) + lhs_grad2, rhs_grad2 = gf(out_grad.numpy, + lhs_arr.numpy, + rhs_arr.numpy) + executor.backward([out_grad]) + assert reldiff(lhs_grad.numpy, lhs_grad2) < 1e-6 + assert reldiff(rhs_grad.numpy, rhs_grad2) < 1e-6 + + +def test_bind(): + np.random.seed(0) + nrepeat = 10 + maxdim = 4 + for repeat in range(nrepeat): + for dim in range(1, maxdim): + check_bind_with_uniform(lambda x, y: x + y, + lambda g, x, y: (g, g), + dim) + check_bind_with_uniform(lambda x, y: x - y, + lambda g, x, y: (g, -g), + dim) + check_bind_with_uniform(lambda x, y: x * y, + lambda g, x, y: (y * g, x * g), + dim) + check_bind_with_uniform(lambda x, y: x / y, + lambda g, x, y: (g / y, -x * g/ (y**2)), + dim) + + diff --git a/tests/python/test_narray.py b/tests/python/test_narray.py index 6e3466fbe473..b6325112ba99 100644 --- a/tests/python/test_narray.py +++ b/tests/python/test_narray.py @@ -44,7 +44,7 @@ def test_narray_elementwise(): check_with_uniform(lambda x, y: x + y, 2, dim) check_with_uniform(lambda x, y: x - y, 2, dim) check_with_uniform(lambda x, y: x * y, 2, dim) - # check_with_uniform(lambda x, y: x / y, 2, dim) + check_with_uniform(lambda x, y: x / y, 2, dim) def test_narray_copy(): c = mx.narray.create((10,10)) @@ -60,8 +60,12 @@ def test_narray_scalar(): d.numpy[:] = 1.0 d -= c * 2 / 3 * 6.0 c += 0.5 - assert(np.sum(c.numpy) == 100) - assert(np.sum(d.numpy) == -100) + assert(np.sum(c.numpy) - 100 < 1e-5) + assert(np.sum(d.numpy) + 100 < 1e-5) + c[:] = 2 + assert(np.sum(c.numpy) - 200 < 1e-5) + d = -c + 2 + assert(np.sum(c.numpy) < 1e-5) def test_narray_pickle(): np.random.seed(0)