Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion example/cifar10/cifar10.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ def RandomInit(narray):
data_shape = (batch_size, 3, 28, 28)

in_data = mx.narray.empty(data_shape, mx.gpu())
executor = loss.simple_bind(mx.gpu(), {"data": in_data})
executor = loss.simple_bind(mx.gpu(), data = in_data)
out_narray = executor.heads()[0]
pred = mx.narray.zeros(out_narray.shape, mx.cpu())

Expand Down
19 changes: 11 additions & 8 deletions python/mxnet/symbol.py
Original file line number Diff line number Diff line change
Expand Up @@ -332,14 +332,19 @@ def _get_narray_handle(arg_key, args, arg_names, allow_missing):
raise TypeError('Only Accept list of NArrays or dict of str->NArray')
return c_array(NArrayHandle, arg_handles)

def simple_bind(self, ctx, args, grad_req='write'):
def simple_bind(self, ctx, grad_req='write', **kwargs):
"""Simply bind current symbol to get an executor
Parameters
----------
ctx : Context
The device context the generated executor to run on.

args : list of NArray or dict of str->NArray
grad_req: string
{'write', 'add', 'null'}, or list of str or dict of str->str, optional
Specifies how we should update the gradient to the args_grad.
- 'write' means everytime gradient is write to specified args_grad NArray.
- 'add' means everytime gradient is add to the specified NArray.
- 'null' means no action is taken, the gradient may not be calculated.
kwargs : dict of str->NArray
Input arguments to the symbol.
- type is dict of str->NArray, then it maps the name of arguments
to the corresponding NArray,
Expand All @@ -349,9 +354,7 @@ def simple_bind(self, ctx, args, grad_req='write'):
executor : mxnet.Executor
The generated Executor
"""
if not isinstance(args, dict):
raise TypeError("args must be dict of str->NArray")
input_shapes = dict((name, arr.shape) for name, arr in args.items())
input_shapes = dict((name, arr.shape) for name, arr in kwargs.items())
# pylint: disable=unused-variable
arg_shapes, out_shapes, aux_shapes = self.infer_shape(**input_shapes)
# pylint: enable=unused-variable
Expand All @@ -360,8 +363,8 @@ def simple_bind(self, ctx, args, grad_req='write'):
# alloc space
arg_narrays = []
for name, shape in zip(self.list_arguments(), arg_shapes):
if name in args:
arg_narrays.append(args[name])
if name in kwargs:
arg_narrays.append(kwargs[name])
else:
arg_narrays.append(zeros(shape, ctx))
# TODO(bing): specail treat input data grad
Expand Down