From d747857d1a055e839b015176eee225598ba9de37 Mon Sep 17 00:00:00 2001 From: Bing Xu Date: Sat, 12 Sep 2015 23:38:05 -0600 Subject: [PATCH] minor switch to kwargs --- example/cifar10/cifar10.py | 2 +- python/mxnet/symbol.py | 19 +++++++++++-------- 2 files changed, 12 insertions(+), 9 deletions(-) diff --git a/example/cifar10/cifar10.py b/example/cifar10/cifar10.py index 892f6fff5d1d..9b387b8d297a 100644 --- a/example/cifar10/cifar10.py +++ b/example/cifar10/cifar10.py @@ -162,7 +162,7 @@ def RandomInit(narray): data_shape = (batch_size, 3, 28, 28) in_data = mx.narray.empty(data_shape, mx.gpu()) -executor = loss.simple_bind(mx.gpu(), {"data": in_data}) +executor = loss.simple_bind(mx.gpu(), data = in_data) out_narray = executor.heads()[0] pred = mx.narray.zeros(out_narray.shape, mx.cpu()) diff --git a/python/mxnet/symbol.py b/python/mxnet/symbol.py index 90df0b663615..4cf15d7c60f5 100644 --- a/python/mxnet/symbol.py +++ b/python/mxnet/symbol.py @@ -332,14 +332,19 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing): raise TypeError('Only Accept list of NArrays or dict of str->NArray') return c_array(NArrayHandle, arg_handles) - def simple_bind(self, ctx, args, grad_req='write'): + def simple_bind(self, ctx, grad_req='write', **kwargs): """Simply bind current symbol to get an executor Parameters ---------- ctx : Context The device context the generated executor to run on. - - args : list of NArray or dict of str->NArray + grad_req: string + {'write', 'add', 'null'}, or list of str or dict of str->str, optional + Specifies how we should update the gradient to the args_grad. + - 'write' means everytime gradient is write to specified args_grad NArray. + - 'add' means everytime gradient is add to the specified NArray. + - 'null' means no action is taken, the gradient may not be calculated. + kwargs : dict of str->NArray Input arguments to the symbol. - type is dict of str->NArray, then it maps the name of arguments to the corresponding NArray, @@ -349,9 +354,7 @@ def simple_bind(self, ctx, args, grad_req='write'): executor : mxnet.Executor The generated Executor """ - if not isinstance(args, dict): - raise TypeError("args must be dict of str->NArray") - input_shapes = dict((name, arr.shape) for name, arr in args.items()) + input_shapes = dict((name, arr.shape) for name, arr in kwargs.items()) # pylint: disable=unused-variable arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes) # pylint: enable=unused-variable @@ -360,8 +363,8 @@ def simple_bind(self, ctx, args, grad_req='write'): # alloc space arg_narrays = [] for name, shape in zip(self.list_arguments(), arg_shapes): - if name in args: - arg_narrays.append(args[name]) + if name in kwargs: + arg_narrays.append(kwargs[name]) else: arg_narrays.append(zeros(shape, ctx)) # TODO(bing): specail treat input data grad