diff --git a/Makefile b/Makefile index ece927144366..d758c443241e 100644 --- a/Makefile +++ b/Makefile @@ -137,7 +137,7 @@ $(CUBIN) : $(NVCC) -o $@ $(NVCCFLAGS) -Xcompiler "$(CFLAGS)" -Xlinker "$(LDFLAGS)" $(filter %.cu %.cpp %.o, $^) lint: - python dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts test python + python dmlc-core/scripts/lint.py mxnet ${LINT_LANG} include src scripts python doxygen: doxygen doc/Doxyfile diff --git a/doc/python/python_guide.md b/doc/python/python_guide.md index 43d28b0f7811..29ec50519c3a 100644 --- a/doc/python/python_guide.md +++ b/doc/python/python_guide.md @@ -64,11 +64,11 @@ cpu_array.copyto(gpu_array) # create a new copy of NArray on GPU 0 gpu_array2 = cpu_array.copyto(mx.Context('gpu', 0)) -# do some operations on GPU -gpu_array = gpu_array + gpu_array2 +# do some operations on GPU, the result will be on same device. +gpu_array3 = gpu_array2 + 1.0 # copy back to CPU -gpu_array.copyto(cpu_array) +gpu_array3.copyto(cpu_array) # print the result print(cpu_array.numpy) @@ -76,7 +76,20 @@ print(cpu_array.numpy) In common workflow, it is encouraged to copy the data into a GPU NArray, do as much as computation as you can, and copy it back to CPU. +Besides the NArrays that are explicitly created, the computation will +generate result NArray that are sit on the same device. +It is important to note that mxnet do not support arthematic inputs +from two different devices. You need to insert a copyto explicitly +to do the computation, like showed in the following example. +```python +cpu_array = mx.narray.ones((10, 10)) +gpu_array = mx.narray.create((10, 10), mx.Context('gpu', 0)) +gpu_array2 = gpu_array + cpu_array.copyto(gpu_array.context) +``` + +We made this choice because the copy between devices creates additional overhead. +The current API makes the copy cost transparent to the user. ### Automatically Parallelizing Computation So far you have learnt the basics of NArray, hope you like the flavor so far. @@ -242,17 +255,50 @@ You can also specify it explicitly, like the following code. ['data', 'myweight', 'fc1_bias'] ``` +Besides the coarse grained neuralnet operators such as FullyConnected, Convolution. +MXNet also provides fine graned operations such as elementwise add, multiplications. +The following example first performs an elementwise add between two symbols, then feed +them to the FullyConnected operator. +``` +>>> import mxnet.symbol as sym +>>> lhs = sym.Variable('data1') +>>> rhs = sym.Variable('data2') +>>> net = sym.FullyConnected(data=lhs + rhs, + name='fc1', num_hidden=128) +>>> net.list_arguments() +['data1', 'data2', 'fc1_weight', 'fc1_bias'] +``` + ### More Complicated Composition In the previous example, Symbols are constructed in a forward compositional way. +Besides doing things in a forward compistion way. You can also treat composed symbols as functions, +and apply them to existing symbols. -TODO +```python +>>> import mxnet.symbol as sym +>>> data = sym.Variable('data') +>>> net = sym.FullyConnected(data=data, + name='fc1', num_hidden=128) +>>> net.list_arguments() +['data', 'fc1_weight', 'fc1_bias'] +>>> data2 = sym.Variable('data2') +>>> in_net = sym.FullyConnected(data=data, + name='in', num_hidden=128) +>>> composed_net = net(data=in_net, name='compose') +>>> composed_net.list_arguments() +['data2', 'in_weight', 'in_bias', 'compose_fc1_weight', 'compose_fc1_bias'] +``` +In the above example, net is used a function to apply to an existing symbol ```in_net```, the resulting +composed_net will replace the original ```data``` by the the in_net instead. This is useful when you +want to change the input of some neural-net to be other structure. ### Shape Inference -Now we have defined the computation graph. In the next section, we are going to bind them to execution devices and -really run these computations. But before doing so, we need to figure out the shapes of the arguments, specifically, -the shape of all the weights, bias and outputs. +Now we have defined the computation graph. A common problem in the computation graph, +is to figure out shapes of each parameters. +Usually, we want to know the shape of all the weights, bias and outputs. -You can use ```Symbol.infer_shape``` to do that. Basically, shape inference allows you to shapes of arguments that you know, +You can use ```Symbol.infer_shape``` to do that. THe shape inference function +allows you to pass in shapes of arguments that you know, and it will try to infer the shapes of all arguments and outputs. ```python >>> import mxnet.symbol as sym @@ -272,9 +318,11 @@ The ```infer_shape``` will detect if there is inconsitency in the shapes, and raise an Error if some of them are inconsistent. ### Bind the Symbols -Symbols are configuration objects that represents a computation graph(a configuration of neuralnet) -that do not +Symbols are configuration objects that represents a computation graph (a configuration of neuralnet). +So far we have introduced how to build up the computation graph (i.e. a configuration). +The remaining question is, how we can do computation using the defined graph. +TODO. ### How Efficient is Symbolic API In short, they design to be very efficienct in both memory and runtime. diff --git a/include/mxnet/c_api.h b/include/mxnet/c_api.h old mode 100755 new mode 100644 index 7a8d4033becc..4e22abfcb6cd --- a/include/mxnet/c_api.h +++ b/include/mxnet/c_api.h @@ -262,6 +262,11 @@ MXNET_DLL int MXSymbolListAtomicSymbolCreators(mx_uint *out_size, * \param arg_names Name of the arguments. * \param arg_type_infos Type informations about the arguments. * \param arg_descriptions Description information about the arguments. + * \param key_var_num_args The keyword argument for specifying variable number of arguments. + * When this parameter has non-zero length, the function allows variable number + * of positional arguments, and will need the caller to pass it in in + * MXSymbolCreateAtomicSymbol, + * With key = key_var_num_args, and value = number of positional arguments. * \return 0 when success, -1 when failure happens */ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, @@ -270,15 +275,8 @@ MXNET_DLL int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, - const char ***arg_descriptions); -/*! - * \brief Get the docstring of AtomicSymbol. - * \param creator the AtomicSymbolCreator - * \param out the returned name of the creator - * \return 0 when success, -1 when failure happens - */ -MXNET_DLL int MXSymbolGetAtomicSymbolDoc(AtomicSymbolCreator creator, - const char **out); + const char ***arg_descriptions, + const char **key_var_num_args); /*! * \brief Create an AtomicSymbol. * \param creator the AtomicSymbolCreator diff --git a/include/mxnet/narray.h b/include/mxnet/narray.h old mode 100755 new mode 100644 diff --git a/include/mxnet/operator.h b/include/mxnet/operator.h old mode 100755 new mode 100644 index 9b900db4fdb9..64c1515a6f3b --- a/include/mxnet/operator.h +++ b/include/mxnet/operator.h @@ -395,6 +395,24 @@ typedef OperatorProperty *(*OperatorPropertyFactory)(); struct OperatorPropertyReg : public dmlc::FunctionRegEntryBase { + /*! + * \brief Set key_var_num_args + * When this is set, the API caller is required to pass in a + * argument with key=key_num_args.c_str(), and value=num_args. + * num_args is number of positional argument when calling the function. + * + * This is used to pass in length of positional arguments + * for operators that can take variable length of input. + * Most operators do not need to set this property. + * + * \param key the key name to be set + */ + inline OperatorPropertyReg& set_key_var_num_args(const std::string &key) { // NOLINT(*) + this->key_var_num_args = key; + return *this; + } + /*! \brief The key num_args name. */ + std::string key_var_num_args; }; //-------------------------------------------------------------- diff --git a/include/mxnet/symbolic.h b/include/mxnet/symbolic.h old mode 100755 new mode 100644 diff --git a/python/mxnet/narray.py b/python/mxnet/narray.py old mode 100755 new mode 100644 diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 9b93db788173..f882933538b2 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -10,8 +10,10 @@ from .base import NArrayHandle, ExecutorHandle, SymbolHandle from .base import check_call from .context import Context +from .narray import NArray from .executor import Executor + class Symbol(object): """Symbol is symbolic graph of the mxnet.""" @@ -115,7 +117,7 @@ def _compose(self, *args, **kwargs): for arg in args: if not isinstance(arg, Symbol): raise TypeError('Compose expect `Symbol` as arguments') - for _, val in kwargs.items(): + for val in kwargs.values(): if not isinstance(val, Symbol): raise TypeError('Compose expect `Symbol` as arguments') @@ -148,7 +150,7 @@ def list_returns(self): Returns ------- - args: list of string + returns : list of string List of all the returns. """ size = ctypes.c_uint() @@ -162,8 +164,15 @@ def list_auxiliary_states(self): Returns ------- - args: list of string - List of all the auxiliary + aux_states : list of string + List the names of the auxiliary states. + + Notes + ----- + Auxiliary states are special states of symbols that do not corresponds to an argument, + and do not have gradient. But still be useful for the specific operations. + A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + Most operators do not have Auxiliary states. """ size = ctypes.c_uint() sarr = ctypes.POINTER(ctypes.c_char_p)() @@ -175,7 +184,7 @@ def infer_shape(self, *args, **kwargs): """Infer the shape of outputs and arguments of given known shapes of arguments. User can either pass in the known shapes in positional way or keyword argument way. - Pair of Nones is returned if there is not enough information passed in. + Tuple of Nones is returned if there is not enough information passed in. An error will be raised if there is inconsistency found in the known shapes passed in. Parameters @@ -255,7 +264,7 @@ def infer_shape(self, *args, **kwargs): tuple(aux_shape_data[i][:aux_shape_ndim[i]]) for i in range(aux_shape_size.value)] return (arg_shapes, out_shapes, aux_shapes) else: - return (None, None) + return (None, None, None) # pylint: enable=too-many-locals def debug_str(self): @@ -271,33 +280,145 @@ def debug_str(self): self.handle, ctypes.byref(debug_str))) return py_str(debug_str.value) - def bind(self, ctx, args, args_grad, reqs, aux_states=None): + @staticmethod + def _get_narray_handle(arg_key, args, arg_names, allow_missing): + """Helper function to get narray handles from various inputs. + + Parameters + ---------- + arg_key : str + The name of argument, used for error message. + + args : list of NArray or dict of str->NArray + Input arguments to the symbols. + If type is list of NArray, the position is in the same order of arg_names. + If type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray, + + args_names : list of string + List of argument names. + + allow_missing : boolean + Whether missing argument is allowed. + When allowed, the missing handle will be set to None(null) + + Returns + ------- + handles : list of NArrayHandle + The positional list of NArrayHandles generated from input. + """ + # setup args + arg_handles = [] + if isinstance(args, list): + if len(args) != len(arg_names): + raise ValueError('Length of %s do not match number of arguments' % arg_key) + for narr in args: + if not isinstance(narr, NArray): + raise TypeError('Only Accept list of NArrays or dict of str->NArray') + arg_handles.append(narr.handle) + elif isinstance(args, dict): + for name in arg_names: + if name in arg_names: + narr = args[name] + if not isinstance(narr, NArray): + raise TypeError('Only Accept list of NArrays or dict of str->NArray') + arg_handles.append(narr.handle) + else: + if allow_missing: + arg_handles.append(None) + else: + raise ValueError('Must specify all the arguments in %s' % arg_key) + else: + raise TypeError('Only Accept list of NArrays or dict of str->NArray') + return c_array(NArrayHandle, arg_handles) + + def bind(self, ctx, args, args_grad=None, grad_req='write', aux_states=None): """Bind current symbol to get an executor. Parameters ---------- ctx : Context - context executor to run on - args : Array of NArray - input args to the symbol - args_grad : Array of NArray - input args' gradient - reqs : Array of enum - graident requirements - aux_states : Array of NArray - input auxiliary states to the symbol + The device context the generated executor to run on. + + args : list of NArray or dict of str->NArray + Input arguments to the symbol. + - If type is list of NArray, the position is in the same order of list_arguments. + - If type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray, + - In either case, all the arguments must be provided. + + args_grad : list of NArray or dict of str->NArray, optional + When specified, args_grad provide NArrays to hold + the result of gradient value in backward. + - If type is list of NArray, the position is in the same order of list_arguments. + - If type is dict of str->NArray, then it maps the name of arguments + to the corresponding NArray. + - When the type is dict of str->NArray, users only need to provide the dict + for needed argument gradient. + Only the specified argument gradient will be calculated. + + grad_req : {'write', 'add', 'null'}, or list of str or dict of str->str, optional + Specifies how we should update the gradient to the args_grad. + - 'write' means everytime gradient is write to specified args_grad NArray. + - 'add' means everytime gradient is add to the specified NArray. + - 'null' means no action is taken, the gradient may not be calculated. + + aux_states : list of NArray, or dict of str->NArray, optional + Input auxiliary states to the symbol, only need to specify when + list_auxiliary_states is not empty. + - If type is list of NArray, the position is in the same order of list_auxiliary_states + - If type is dict of str->NArray, then it maps the name of auxiliary_states + to the corresponding NArray, + - In either case, all the auxiliary_states need to be provided. + + Returns + ------- + executor : mxnet.Executor + The generated Executor + + Notes + ----- + Auxiliary states are special states of symbols that do not corresponds to an argument, + and do not have gradient. But still be useful for the specific operations. + A common example of auxiliary state is the moving_mean and moving_variance in BatchNorm. + Most operators do not have auxiliary states and this parameter can be safely ignored. + + User can give up gradient by using a dict in args_grad and only specify + gradient they interested in. """ - # TODO(bing): consider a more friendly interface - # For example, pass in args_grad by dict - enum = {"null" : 0, "write_to" : 1, "in_place":2, "add_to" : 3} if not isinstance(ctx, Context): raise TypeError("Context type error") - if aux_states == None: + + args_handle = self._get_narray_handle('args', args, self.list_arguments(), False) + # setup args gradient + if args_grad is None: + args_grad_handle = c_array(NArrayHandle, [None] * len(args)) + else: + args_grad_handle = self._get_narray_handle('args_grad', args_grad, + self.list_arguments(), True) + + if aux_states is None: aux_states = [] - args_handle = c_array(NArrayHandle, [item.handle for item in args]) - args_grad_handle = c_array(NArrayHandle, [item.handle for item in args_grad]) - reqs_array = c_array(mx_uint, [mx_uint(enum[item]) for item in reqs]) - aux_args_handle = c_array(NArrayHandle, [item.handle for item in aux_states]) + aux_args_handle = self._get_narray_handle('aux_states', aux_states, + self.list_auxiliary_states(), False) + + # setup requirements + req_map = {'null' : 0, 'write' : 1, 'add' : 3} + if isinstance(grad_req, string_types): + if grad_req not in req_map: + raise ValueError('grad_req must be in %s' % str(req_map)) + reqs_array = c_array(mx_uint, [mx_uint(req_map[grad_req])] * len(self.list_arguments())) + elif isinstance(grad_req, list): + reqs_array = c_array(mx_uint, [mx_uint(req_map[item]) for item in grad_req]) + elif isinstance(grad_req, dict): + req_array = [] + for name in self.list_arguments(): + if name in grad_req: + req_array.append(mx_uint(req_map[grad_req[name]])) + else: + req_array.append(mx_uint(0)) + req_array = c_array(mx_uint, req_array) + handle = ExecutorHandle() check_call(_LIB.MXExecutorBind(self.handle, mx_uint(ctx.device_mask), @@ -314,10 +435,17 @@ def bind(self, ctx, args, args_grad, reqs, aux_states=None): def grad(self, wrt): """Get the autodiff of current symbol. + This function can only be used if current symbol is a loss function. + Parameters ---------- wrt : Array of String keyword arguments of the symbol that the gradients are taken. + + Returns + ------- + grad : Symbol + A gradient Symbol with returns to be the corresponding gradients. """ handle = SymbolHandle() c_wrt = c_array(ctypes.c_char_p, [c_str(key) for key in wrt]) @@ -377,25 +505,36 @@ def _make_atomic_symbol_function(handle): """Create an atomic symbol function by handle and funciton name.""" name = ctypes.c_char_p() desc = ctypes.c_char_p() + key_var_num_args = ctypes.c_char_p() num_args = mx_uint() arg_names = ctypes.POINTER(ctypes.c_char_p)() arg_types = ctypes.POINTER(ctypes.c_char_p)() arg_descs = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXSymbolGetAtomicSymbolInfo( handle, ctypes.byref(name), ctypes.byref(desc), ctypes.byref(num_args), ctypes.byref(arg_names), ctypes.byref(arg_types), - ctypes.byref(arg_descs))) + ctypes.byref(arg_descs), + ctypes.byref(key_var_num_args))) + key_var_num_args = py_str(key_var_num_args.value) func_name = py_str(name.value) param_str = [] for i in range(num_args.value): - ret = '%s : %s' % (py_str(arg_names[i]), py_str(arg_types[i])) + key = py_str(arg_names[i]) + if key == key_var_num_args: + continue + ret = '%s : %s' % (key, py_str(arg_types[i])) if len(arg_descs[i]) != 0: ret += '\n ' + py_str(arg_descs[i]) param_str.append(ret) + desc = py_str(desc.value) + if key_var_num_args: + desc = '\nThis function support variable length of positional input.' + doc_str = ('%s\n\n' + 'Parameters\n' + '----------\n' + @@ -406,7 +545,7 @@ def _make_atomic_symbol_function(handle): '-------\n' + 'symbol: Symbol\n'+ ' The result symbol.') - doc_str = doc_str % (py_str(desc.value), '\n'.join(param_str)) + doc_str = doc_str % (desc, '\n'.join(param_str)) def creator(*args, **kwargs): """Activation Operator of Neural Net. @@ -427,6 +566,10 @@ def creator(*args, **kwargs): symbol_kwargs = {} name = kwargs.pop('name', None) + if key_var_num_args and key_var_num_args not in kwargs: + param_keys.append(c_str(key_var_num_args)) + param_vals.append(c_str(str(len(args)))) + for k, v in kwargs.items(): if isinstance(v, Symbol): symbol_kwargs[k] = v @@ -446,6 +589,10 @@ def creator(*args, **kwargs): raise TypeError( '%s can only accept input' 'Symbols either as positional or keyword arguments, not both' % func_name) + if key_var_num_args and len(symbol_kwargs) != 0: + raise ValueError('This function support variable length of Symbol arguments.\n' + + 'Please pass all the input Symbols via positional arguments' + + ' instead of keyword arguments.') s = Symbol(sym_handle) s._compose(*args, name=name, **symbol_kwargs) diff --git a/src/c_api.cc b/src/c_api.cc old mode 100755 new mode 100644 index 69db585b29a7..48d83cffc688 --- a/src/c_api.cc +++ b/src/c_api.cc @@ -449,8 +449,10 @@ int MXSymbolGetAtomicSymbolInfo(AtomicSymbolCreator creator, mx_uint *num_args, const char ***arg_names, const char ***arg_type_infos, - const char ***arg_descriptions) { + const char ***arg_descriptions, + const char **key_var_num_args) { OperatorPropertyReg *e = static_cast(creator); + *key_var_num_args = e->key_var_num_args.c_str(); return MXAPIGetFunctionRegInfo(e, name, description, num_args, arg_names, arg_type_infos, arg_descriptions); } @@ -729,8 +731,15 @@ int MXExecutorBind(SymbolHandle symbol_handle, std::vector aux_states_vec; for (mx_uint i = 0; i < len; ++i) { in_args_vec.push_back(*(in_args_ptr[i])); - arg_grad_vec.push_back(*(arg_grad_ptr[i])); - grad_req_vec.push_back(static_cast(grad_req_type[i])); + if (arg_grad_ptr[i] == nullptr) { + arg_grad_vec.push_back(NArray()); + grad_req_vec.push_back(kNullOp); + LOG(INFO) << "nop"; + } else { + LOG(INFO) << "grad=" << grad_req_type[i]; + arg_grad_vec.push_back(*(arg_grad_ptr[i])); + grad_req_vec.push_back(static_cast(grad_req_type[i])); + } } for (mx_uint i = 0; i < aux_states_len; ++i) { aux_states_vec.push_back(*(aux_states_ptr[i])); diff --git a/src/narray/narray.cc b/src/narray/narray.cc old mode 100755 new mode 100644 diff --git a/src/narray/narray_function-inl.h b/src/narray/narray_function-inl.h old mode 100755 new mode 100644 diff --git a/src/narray/narray_function.h b/src/narray/narray_function.h old mode 100755 new mode 100644 diff --git a/src/operator/activation-inl.h b/src/operator/activation-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/batch_norm-inl.h b/src/operator/batch_norm-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/concat-inl.h b/src/operator/concat-inl.h new file mode 100644 index 000000000000..2fa308a10a8a --- /dev/null +++ b/src/operator/concat-inl.h @@ -0,0 +1,230 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concat-inl.h + * \brief + * \author Bing Xu +*/ +#ifndef MXNET_OPERATOR_CONCAT_INL_H_ +#define MXNET_OPERATOR_CONCAT_INL_H_ +#include +#include +#include +#include +#include +#include +#include +#include +#include "./operator_common.h" + +namespace mxnet { +namespace op { + +enum ConcatOpInputs {kData0, kData1, kData2, kData3, kData4}; +enum ConcatOpOutputs {kOut}; + +struct ConcatParam : public dmlc::Parameter { + int num_args; + DMLC_DECLARE_PARAMETER(ConcatParam) { + DMLC_DECLARE_FIELD(num_args).set_range(1, 6) + .describe("Number of inputs to be concated."); + } +}; // struct ConcatParam + +template +class ConcatOp : public Operator { + public: + explicit ConcatOp(ConcatParam param) + : size_(param.num_args) {} + + virtual void Forward(const OpContext &ctx, + const std::vector &in_data, + const std::vector &req, + const std::vector &out_data, + const std::vector &aux_args) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(static_cast(in_data.size()), size_); + CHECK_EQ(out_data.size(), 1); + CHECK_EQ(req[kOut], kWriteTo); + Stream *s = ctx.get_stream(); + std::vector > data(size_); + Tensor out; + if (in_data[kData0].ndim() == 2) { + uint32_t dim = 0; + for (int i = 0; i < size_; ++i) { + uint32_t ds[] = {in_data[i].shape_[0], in_data[i].shape_[1], 1, 1}; + TShape dshape(ds, ds + 4); + data[i] = in_data[i].get_with_shape(dshape, s); + dim += in_data[i].shape_[1]; + } + uint32_t ds_out[] = {in_data[kData0].shape_[0], dim, 1, 1}; + TShape dshape_out(ds_out, ds_out + 4); + out = out_data[kOut].get_with_shape(dshape_out, s); + } else { + for (int i = 0; i < size_; ++i) { + data[i] = in_data[i].get(s); + } + out = out_data[kOut].get(s); + } + switch (size_) { + case 2: + Assign(out, req[kOut], + concat<1>(data[kData0], data[kData1])); + break; + case 3: + Assign(out, req[kOut], + concat<1>(data[kData0], + concat<1>(data[kData1], data[kData2]))); + break; + case 4: + Assign(out, req[kOut], + concat<1>(data[kData0], + concat<1>(data[kData1], + concat<1>(data[kData2], data[kData3])))); + break; + case 5: + Assign(out, req[kOut], + concat<1>(data[kData0], + concat<1>(data[kData1], + concat<1>(data[kData2], + concat<1>(data[kData3], data[kData4]))))); + break; + default: + LOG(FATAL) << "Incorrect concat size_: " << size_; + } + } + + virtual void Backward(const OpContext &ctx, + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data, + const std::vector &req, + const std::vector &in_grad, + const std::vector &aux_states) { + using namespace mshadow; + using namespace mshadow::expr; + CHECK_EQ(out_grad.size(), 1); + CHECK_EQ(in_grad.size(), static_cast(size_)); + Stream *s = ctx.get_stream(); + std::vector > grad_in(size_); + Tensor grad; + if (out_grad[kOut].ndim() == 2) { + uint32_t dim = 0; + for (int i = 0; i < size_; ++i) { + uint32_t ds[] = {in_grad[i].shape_[0], in_grad[i].shape_[1], 1, 1}; + TShape dshape(ds, ds + 4); + grad_in[i] = in_grad[i].get_with_shape(dshape, s); + dim += in_grad[i].shape_[1]; + CHECK_EQ(req[i], kWriteTo); + } + uint32_t ds_out[] = {in_grad[kData0].shape_[0], dim, 1, 1}; + TShape dshape_out(ds_out, ds_out + 4); + grad = out_grad[kOut].get_with_shape(dshape_out, s); + } else { + for (int i = 0; i < size_; ++i) { + grad_in[i] = in_grad[i].get(s); + CHECK_EQ(req[i], kWriteTo); + } + grad = out_grad[kOut].get(s); + } + switch (size_) { + case 2: { + concat<1>(grad_in[kData0], grad_in[kData1]) = grad; + break; + } + case 3: { + concat<1>(grad_in[kData0], + concat<1>(grad_in[kData1], grad_in[kData2])) = grad; + break; + } + case 4: { + concat<1>(grad_in[kData0], + concat<1>(grad_in[kData1], + concat<1>(grad_in[kData2], grad_in[kData3]))) = grad; + break; + } + case 5: { + concat<1>(grad_in[kData0], + concat<1>(grad_in[kData1], + concat<1>(grad_in[kData2], + concat<1>(grad_in[kData3], grad_in[kData4])))) = grad; + break; + } + default: + LOG(FATAL) << "Incorrect concat size_: " << size_; + } + } + + private: + int size_; +}; // class ConcatOp + +template +Operator *CreateOp(ConcatParam param); + +#if DMLC_USE_CXX11 +class ConcatProp : public OperatorProperty { + public: + void Init(const std::vector >& kwargs) override { + param_.Init(kwargs); + } + + std::vector ListArguments() const override { + std::vector ret; + for (int i = 0; i < param_.num_args; ++i) { + ret.push_back(std::string("arg") + static_cast('0' + i)); + } + return ret; + } + + bool InferShape(std::vector *in_shape, + std::vector *out_shape, + std::vector *aux_shape) const override { + using namespace mshadow; + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + TShape dshape = in_shape->at(kData0); + if (dshape.ndim() == 0) return false; + CHECK_GT(dshape.ndim(), 1); + for (int i = 1; i < param_.num_args; ++i) { + const TShape &tmp = in_shape->at(i); + if (tmp.ndim() == 0) return false; + for (uint32_t j = 0; j < dshape.ndim(); ++j) { + if (j == 1) { + dshape[1] += tmp[1]; + } else { + CHECK_EQ(dshape[j], tmp[j]); + } + } + } + out_shape->clear(); + out_shape->push_back(dshape); + return true; + } + + OperatorProperty* Copy() const override { + auto ptr = new ConcatProp(); + ptr->param_ = param_; + return ptr; + } + + std::string TypeString() const override { + return "Concat"; + } + + std::vector DeclareBackwardDependency( + const std::vector &out_grad, + const std::vector &in_data, + const std::vector &out_data) const override { + return out_grad; + } + + Operator* CreateOperator(Context ctx) const; + + private: + ConcatParam param_; +}; // class ConcatProp +#endif // DMLC_USE_CXX11 +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_CONCAT_INL_H_ diff --git a/src/operator/concat.cc b/src/operator/concat.cc new file mode 100644 index 000000000000..75867ea7e05f --- /dev/null +++ b/src/operator/concat.cc @@ -0,0 +1,30 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concat.cc + * \brief + * \author Bing Xu +*/ + +#include "./concat-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(ConcatParam param) { + return new ConcatOp(param); +} + +Operator* ConcatProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateOp, param_); +} + +DMLC_REGISTER_PARAMETER(ConcatParam); + +MXNET_REGISTER_OP_PROPERTY(Concat, ConcatProp) +.describe("Perform an feature concat on channel dim (dim 1) over all the inputs.") +.add_arguments(ConcatParam::__FIELDS__()) +.set_key_var_num_args("num_args"); + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/concat.cu b/src/operator/concat.cu new file mode 100644 index 000000000000..4e24b45cc676 --- /dev/null +++ b/src/operator/concat.cu @@ -0,0 +1,19 @@ +/*! + * Copyright (c) 2015 by Contributors + * \file concat.cu + * \brief + * \author Bing Xu +*/ + +#include "./concat-inl.h" + +namespace mxnet { +namespace op { +template<> +Operator* CreateOp(ConcatParam param) { + return new ConcatOp(param); +} + +} // namespace op +} // namespace mxnet + diff --git a/src/operator/convolution-inl.h b/src/operator/convolution-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/elementwise_binary_op-inl.h b/src/operator/elementwise_binary_op-inl.h index 89d8b115bf6b..b3ae8adc3de1 100644 --- a/src/operator/elementwise_binary_op-inl.h +++ b/src/operator/elementwise_binary_op-inl.h @@ -18,30 +18,30 @@ namespace mxnet { namespace op { -enum ElementwiseBinaryOpInputs {kLhs, kRhs}; -enum ElementwiseBinaryOpOutputs {kOut}; -enum ElementwiseBinaryOpType {kPlus, kMinus, kMul, kDiv}; +enum ElementWiseBinaryOpInputs {kLhs, kRhs}; +enum ElementWiseBinaryOpOutputs {kOut}; +enum ElementWiseBinaryOpType {kPlus, kMinus, kMul, kDiv}; template -inline ElementwiseBinaryOpType GetOpType(); +inline ElementWiseBinaryOpType GetOpType(); template inline const char* GetOpTypeString(); template<> -inline ElementwiseBinaryOpType GetOpType() { +inline ElementWiseBinaryOpType GetOpType() { return kPlus; } template<> -inline ElementwiseBinaryOpType GetOpType() { +inline ElementWiseBinaryOpType GetOpType() { return kMinus; } template<> -inline ElementwiseBinaryOpType GetOpType() { +inline ElementWiseBinaryOpType GetOpType() { return kMul; } template<> -inline ElementwiseBinaryOpType GetOpType() { +inline ElementWiseBinaryOpType GetOpType() { return kDiv; } @@ -65,7 +65,7 @@ inline const char* GetOpTypeString() { } template -class ElementwiseBinaryOp : public Operator { +class ElementWiseBinaryOp : public Operator { public: virtual void Forward(const OpContext &ctx, const std::vector &in_data, @@ -132,16 +132,16 @@ class ElementwiseBinaryOp : public Operator { } } } -}; // class ElementwiseBinaryOp +}; // class ElementWiseBinaryOp template -inline Operator* CreateElementwiseBinaryOp_(ElementwiseBinaryOpType type) { +inline Operator* CreateElementWiseBinaryOp_(ElementWiseBinaryOpType type) { switch (type) { - case kPlus: return new ElementwiseBinaryOp(); - case kMinus: return new ElementwiseBinaryOp(); - case kMul: return new ElementwiseBinaryOp(); - case kDiv: return new ElementwiseBinaryOp(); + case kPlus: return new ElementWiseBinaryOp(); + case kMinus: return new ElementWiseBinaryOp(); + case kMul: return new ElementWiseBinaryOp(); + case kDiv: return new ElementWiseBinaryOp(); } LOG(FATAL) << "uknown op type"; return NULL; @@ -149,11 +149,11 @@ inline Operator* CreateElementwiseBinaryOp_(ElementwiseBinaryOpType type) { // Decalre Factory function, used for dispatch specialization template -Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type); +Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type); #if DMLC_USE_CXX11 template -class ElementwiseBinaryOpProp : public OperatorProperty { +class ElementWiseBinaryOpProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { CHECK_EQ(kwargs.size(), 0) @@ -183,7 +183,7 @@ class ElementwiseBinaryOpProp : public OperatorProperty { } OperatorProperty* Copy() const override { - return new ElementwiseBinaryOpProp(); + return new ElementWiseBinaryOpProp(); } std::string TypeString() const override { diff --git a/src/operator/elementwise_binary_op.cc b/src/operator/elementwise_binary_op.cc index bdf228df4851..0485707ffc18 100644 --- a/src/operator/elementwise_binary_op.cc +++ b/src/operator/elementwise_binary_op.cc @@ -8,23 +8,23 @@ namespace mxnet { namespace op { template<> -Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type) { - return CreateElementwiseBinaryOp_(type); +Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type) { + return CreateElementWiseBinaryOp_(type); } // DO_BIND_DISPATCH comes from static_operator_common.h template -Operator* ElementwiseBinaryOpProp::CreateOperator(Context ctx) const { - DO_BIND_DISPATCH(CreateElementwiseBinaryOp, GetOpType()); +Operator* ElementWiseBinaryOpProp::CreateOperator(Context ctx) const { + DO_BIND_DISPATCH(CreateElementWiseBinaryOp, GetOpType()); } -MXNET_REGISTER_OP_PROPERTY(_Plus, ElementwiseBinaryOpProp) +MXNET_REGISTER_OP_PROPERTY(_Plus, ElementWiseBinaryOpProp) .describe("Perform an elementwise plus."); -MXNET_REGISTER_OP_PROPERTY(_Minus, ElementwiseBinaryOpProp) +MXNET_REGISTER_OP_PROPERTY(_Minus, ElementWiseBinaryOpProp) .describe("Perform an elementwise minus."); -MXNET_REGISTER_OP_PROPERTY(_Mul, ElementwiseBinaryOpProp) +MXNET_REGISTER_OP_PROPERTY(_Mul, ElementWiseBinaryOpProp) .describe("Perform an elementwise mul."); -MXNET_REGISTER_OP_PROPERTY(_Div, ElementwiseBinaryOpProp) +MXNET_REGISTER_OP_PROPERTY(_Div, ElementWiseBinaryOpProp) .describe("Perform an elementwise div."); } // namespace op diff --git a/src/operator/elementwise_binary_op.cu b/src/operator/elementwise_binary_op.cu index 36616dbda797..ba8991707f12 100644 --- a/src/operator/elementwise_binary_op.cu +++ b/src/operator/elementwise_binary_op.cu @@ -8,8 +8,8 @@ namespace mxnet { namespace op { template<> -Operator* CreateElementwiseBinaryOp(ElementwiseBinaryOpType type) { - return CreateElementwiseBinaryOp_(type); +Operator* CreateElementWiseBinaryOp(ElementWiseBinaryOpType type) { + return CreateElementWiseBinaryOp_(type); } } // namespace op } // namespace mxnet diff --git a/src/operator/elementwise_sum-inl.h b/src/operator/elementwise_sum-inl.h old mode 100755 new mode 100644 index 390df9bd36e1..c2a890b2e976 --- a/src/operator/elementwise_sum-inl.h +++ b/src/operator/elementwise_sum-inl.h @@ -16,6 +16,7 @@ #include #include #include "./operator_common.h" +#include "./mshadow_op.h" namespace mxnet { namespace op { @@ -24,9 +25,9 @@ enum ElementWiseSumOpInputs {kData0, kData1, kData2, kData3}; enum ElementWiseSumOpOutputs {kOut}; struct ElementWiseSumParam : public dmlc::Parameter { - int size; + int num_args; DMLC_DECLARE_PARAMETER(ElementWiseSumParam) { - DMLC_DECLARE_FIELD(size).set_range(1, 100) + DMLC_DECLARE_FIELD(num_args).set_range(1, 100) .describe("Number of inputs to be sumed."); } }; @@ -35,7 +36,7 @@ template class ElementWiseSumOp : public Operator { public: explicit ElementWiseSumOp(ElementWiseSumParam param) - : size_(param.size) {} + : size_(param.num_args) {} virtual void Forward(const OpContext &ctx, const std::vector &in_data, @@ -74,10 +75,11 @@ class ElementWiseSumOp : public Operator { } default: { Tensor in_0 = in_data[kData0].FlatTo2D(s); - Assign(out, req[kOut], in_0); - for (int i = 0; i < size_; ++i) { + Assign(out, req[kOut], F(in_0)); + for (int i = 1; i < size_; ++i) { out += in_data[i].FlatTo2D(s); } + break; } } } @@ -91,14 +93,13 @@ class ElementWiseSumOp : public Operator { const std::vector &aux_args) { using namespace mshadow; using namespace mshadow::expr; - CHECK_EQ(out_grad.size(), static_cast(size_)); + CHECK_EQ(in_grad.size(), static_cast(size_)); Stream *s = ctx.get_stream(); Tensor ograd = out_grad[kOut].FlatTo2D(s); - for (int i = 0; i < size_; ++i) { if (req[i] == kNullOp || req[i] == kWriteInplace) continue; Tensor igrad = in_grad[i].FlatTo2D(s); - Assign(igrad, req[i], ograd); + Assign(igrad, req[i], F(ograd)); } } @@ -113,26 +114,39 @@ Operator* CreateOp(ElementWiseSumParam param); class ElementWiseSumProp : public OperatorProperty { public: void Init(const std::vector >& kwargs) override { - // TODO(bing) change directly to vector of pairs begin end - std::map kmap(kwargs.begin(), kwargs.end()); - param_.Init(kmap); + param_.Init(kwargs); } bool InferShape(std::vector *in_shape, - std::vector *out_shape, - std::vector *aux_shape) const override { + std::vector *out_shape, + std::vector *aux_shape) const override { using namespace mshadow; - CHECK_EQ(in_shape->size(), static_cast(param_.size)); - const TShape &dshape = in_shape->at(0); - if (dshape.ndim() == 0) return false; - for (int i = 1; i < param_.size; ++i) { - SHAPE_ASSIGN_CHECK(*in_shape, i, dshape); + CHECK_EQ(in_shape->size(), static_cast(param_.num_args)); + int sidx = -1; + for (int i = 0; i < param_.num_args; ++i) { + if (in_shape->at(i).ndim() != 0) { + sidx = i; break; + } + } + if (sidx == -1) return false; + for (int i = 0; i < param_.num_args; ++i) { + if (i != sidx) { + SHAPE_ASSIGN_CHECK(*in_shape, i, in_shape->at(sidx)); + } } out_shape->clear(); - out_shape->push_back(dshape); + out_shape->push_back(in_shape->at(sidx)); return true; } + std::vector ListArguments() const override { + std::vector ret; + for (int i = 0; i < param_.num_args; ++i) { + ret.push_back(std::string("arg") + static_cast('0' + i)); + } + return ret; + } + OperatorProperty* Copy() const override { auto ptr = new ElementWiseSumProp(); ptr->param_ = param_; diff --git a/src/operator/elementwise_sum.cc b/src/operator/elementwise_sum.cc index e8c3968c94ab..d8546148f76c 100644 --- a/src/operator/elementwise_sum.cc +++ b/src/operator/elementwise_sum.cc @@ -20,7 +20,8 @@ DMLC_REGISTER_PARAMETER(ElementWiseSumParam); MXNET_REGISTER_OP_PROPERTY(ElementWiseSum, ElementWiseSumProp) .describe("Perform an elementwise sum over all the inputs.") -.add_arguments(ElementWiseSumParam::__FIELDS__()); +.add_arguments(ElementWiseSumParam::__FIELDS__()) +.set_key_var_num_args("num_args"); } // namespace op } // namespace mxnet diff --git a/src/operator/fully_connected-inl.h b/src/operator/fully_connected-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/pooling-inl.h b/src/operator/pooling-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/reshape-inl.h b/src/operator/reshape-inl.h old mode 100755 new mode 100644 diff --git a/src/operator/reshape.cc b/src/operator/reshape.cc old mode 100755 new mode 100644 diff --git a/src/operator/softmax-inl.h b/src/operator/softmax-inl.h old mode 100755 new mode 100644 diff --git a/src/storage/gpu_device_storage.h b/src/storage/gpu_device_storage.h index 22956d192893..237c48389d13 100644 --- a/src/storage/gpu_device_storage.h +++ b/src/storage/gpu_device_storage.h @@ -34,7 +34,7 @@ class GPUDeviceStorage { }; // class GPUDeviceStorage inline void* GPUDeviceStorage::Alloc(size_t size) { - void* ret; + void* ret = nullptr; #if MXNET_USE_CUDA CUDA_CALL(cudaMalloc(&ret, size)); #else // MXNET_USE_CUDA diff --git a/src/symbol/graph_executor.cc b/src/symbol/graph_executor.cc old mode 100755 new mode 100644 diff --git a/src/symbol/graph_executor.h b/src/symbol/graph_executor.h old mode 100755 new mode 100644 diff --git a/src/symbol/static_graph.cc b/src/symbol/static_graph.cc old mode 100755 new mode 100644 index 05c8785de0c8..c24ee1a085d5 --- a/src/symbol/static_graph.cc +++ b/src/symbol/static_graph.cc @@ -172,7 +172,7 @@ StaticGraph::Node StaticGraph::CreateSumNode( Node agg_node; agg_node.op.reset(OperatorProperty::Create("ElementWiseSum")); os_size << grad_source.size(); - agg_node.op->Init({{"size", os_size.str()}}); + agg_node.op->Init({{"num_args", os_size.str()}}); agg_node.inputs = grad_source; return agg_node; } diff --git a/src/symbol/symbol.cc b/src/symbol/symbol.cc old mode 100755 new mode 100644 diff --git a/tests/python/test_bind.py b/tests/python/test_bind.py index 17ea3353fc23..934b66688934 100644 --- a/tests/python/test_bind.py +++ b/tests/python/test_bind.py @@ -25,12 +25,29 @@ def check_bind_with_uniform(uf, gf, dim): executor = ret.bind(mx.Context('cpu'), args=[lhs_arr, rhs_arr], - args_grad=[lhs_grad, rhs_grad], - reqs=['write_to'] * 2) + args_grad=[lhs_grad, rhs_grad]) + + exec3 = ret.bind(mx.Context('cpu'), + args=[lhs_arr, rhs_arr]) + + + exec4 = ret.bind(mx.Context('cpu'), + args={'rhs': rhs_arr, 'lhs': lhs_arr}) + + exec4 = ret.bind(mx.Context('cpu'), + args={'rhs': rhs_arr, 'lhs': lhs_arr}, + args_grad={'lhs': lhs_grad, 'rhs': rhs_grad}) + executor.forward() + exec3.forward() + exec4.forward() out2 = executor.heads()[0].numpy out1 = uf(lhs_arr.numpy, rhs_arr.numpy) + out3 = exec3.heads()[0].numpy + out4 = exec4.heads()[0].numpy assert reldiff(out1, out2) < 1e-6 + assert reldiff(out1, out3) < 1e-6 + assert reldiff(out1, out4) < 1e-6 # test gradient out_grad = mx.narray.create(shape) out_grad.numpy[:] = np.ones(shape) diff --git a/tests/python/test_conv.py b/tests/python/test_conv.py index f9c9806e0d30..8fd0a8ccac6b 100644 --- a/tests/python/test_conv.py +++ b/tests/python/test_conv.py @@ -49,10 +49,9 @@ def CalAcc(out, label): if "beta" in name: narray.numpy[:] = 0.0 -req = ['write_to' for i in range(len(arg_narrays))] # bind executer # TODO(bing): think of a better bind interface -executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req, aux_narrays) +executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, 'write', aux_narrays) # update out_narray = executor.heads()[0] diff --git a/tests/python/test_mlp.py b/tests/python/test_mlp.py index 92ec1ac38862..c917b0a0085f 100644 --- a/tests/python/test_mlp.py +++ b/tests/python/test_mlp.py @@ -34,10 +34,9 @@ def CalAcc(out, label): if "bias" in name: narray.numpy[:] = 0.0 -req = ['write_to' for i in range(len(arg_narrays))] # bind executer # TODO(bing): think of a better bind interface -executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays, req) +executor = softmax.bind(mx.Context('cpu'), arg_narrays, grad_narrays) # update out_narray = executor.heads()[0] diff --git a/tests/python/test_operator.py b/tests/python/test_operator.py new file mode 100644 index 000000000000..988610d34811 --- /dev/null +++ b/tests/python/test_operator.py @@ -0,0 +1,105 @@ +# pylint: skip-file + +import numpy as np +import mxnet as mx + +def reldiff(a, b): + diff = np.sum(np.abs(a - b)) + norm = np.sum(np.abs(a)) + if diff == 0: + return 0 + reldiff = diff / norm + return reldiff + + +def same(a, b): + return np.sum(a != b) == 0 + + +def check_elementwise_sum_with_shape(shape, n): + # forward + inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] + out = mx.symbol.ElementWiseSum(*inputs, name='esum') + arr = [mx.narray.create(shape) for i in range(n)] + arr_grad = [mx.narray.create(shape) for i in range(n)] + for i in range(n): + arr[i].numpy[:] = np.random.uniform(-10, 10, shape) + exec1 = out.bind(mx.Context('cpu'), + args=arr, + args_grad=arr_grad) + out1 = exec1.heads()[0].numpy + exec1.forward() + out1 = exec1.heads()[0].numpy + out = sum(a.numpy for a in arr) + assert reldiff(out, out1) < 1e-6 + out_grad = mx.narray.create(shape) + out_grad.numpy[:] = np.random.uniform(-10, 10, shape) + # backward + exec1.backward([out_grad]) + for a in arr_grad: + assert same(a.numpy, out_grad.numpy) + + +def test_elementwise_sum(): + np.random.seed(0) + nrepeat = 2 + maxdim = 4 + for repeat in range(nrepeat): + for dim in range(1, maxdim): + shape = tuple(np.random.randint(1, int(1000**(1.0/dim)), size=dim)) + check_elementwise_sum_with_shape(shape, np.random.randint(1, 8)) + +def check_concat_with_shape(shapes): + n = len(shapes) + # forward + target_dim = 0 + for shape in shapes: + target_dim += shape[1] + + inputs = [mx.symbol.Variable('arg%d' % i) for i in range(n)] + out = mx.symbol.Concat(*inputs, name='conc') + arr = [mx.narray.create(shape) for shape in shapes] + for i in range(n): + arr[i][:] = shapes[i][1] + arr_np = [np.copy(narray.numpy) for narray in arr] + arr_grad = [mx.narray.create(shape) for shape in shapes] + args = out.list_arguments() + arg_shapes, out_shapes, aux_shapes = out.infer_shape(**dict(zip(args, shapes))) + out_grad = mx.narray.create(out_shapes[0]) + exec1 = out.bind(mx.Context('cpu'), + args=arr, + args_grad=arr_grad) + exec1.forward() + out1 = exec1.heads()[0] + ret = np.concatenate([narray.numpy for narray in arr], axis=1) + assert same(out1.numpy, ret) + # backward + out1.copyto(out_grad) + out_grad[:] += 1 + exec1.backward([out_grad]) + for grad, np_grad in zip(arr_grad, arr_np): + assert same(grad.numpy, np_grad + 1) + +def test_concat(): + n = 2 + batch = 2 + ch = [2, 3, 4, 5, 6] + h = 3 + w = 4 + # test 2D + for dim in range(2, 6): + shapes = [] + for i in range(dim): + shapes.append((batch, ch[i])) + check_concat_with_shape(shapes) + # test 4D + for dim in range(2, 6): + shapes = [] + for i in range(dim): + shapes.append((batch, ch[i], h, w)) + check_concat_with_shape(shapes) + + +if __name__ == '__main__': + test_elementwise_sum() + test_concat()