From 21f5eaf46e2e21a58a5927963d46ed3679444778 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Tue, 20 Nov 2018 15:10:34 -0800 Subject: [PATCH 01/11] Aggregate SGD --- python/mxnet/model.py | 8 +- python/mxnet/optimizer/optimizer.py | 224 +++++++++++++++------ src/operator/optimizer_op-inl.h | 295 ++++++++++++++++++++++++++++ src/operator/optimizer_op.cc | 209 +++++++++++++++++++- src/operator/optimizer_op.cu | 9 + 5 files changed, 685 insertions(+), 60 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 2666f8bbcd4f..821555abc167 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -87,7 +87,7 @@ def _create_kvstore(kvstore, num_device, arg_params): arg_params : dict of str to `NDArray`. Model parameter, dict of name to `NDArray` of net's weights. """ - update_on_kvstore = True + update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', 1))) if kvstore is None: kv = None elif isinstance(kvstore, kvs.KVStore): @@ -157,6 +157,7 @@ def _update_params_on_kvstore(param_arrays, grad_arrays, kvstore, param_names): def _update_params(param_arrays, grad_arrays, updater, num_device, kvstore=None, param_names=None): """Perform update of param_arrays from grad_arrays not on kvstore.""" + updates = [[] for _ in range(num_device)] for i, pair in enumerate(zip(param_arrays, grad_arrays)): arg_list, grad_list = pair if grad_list[0] is None: @@ -173,7 +174,10 @@ def _update_params(param_arrays, grad_arrays, updater, num_device, # state for the same index but on diff devs, TODO(mli) # use a better solution later w, g = p - updater(index*num_device+k, g, w) + updates[k].append((index*num_device+k, g, w)) + for dev_updates in updates: + i, w, g = zip(*dev_updates) + updater(i, w, g) def _multiple_callbacks(callbacks, *args, **kwargs): diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index d632a8c7c640..e0643ac39d94 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -22,12 +22,15 @@ import math import pickle import warnings +import os import numpy from ..base import py_str from ..ndarray import (NDArray, zeros, clip, sqrt, cast, maximum, abs as NDabs, array, multiply) from ..ndarray import (sgd_update, sgd_mom_update, adam_update, rmsprop_update, rmspropalex_update, mp_sgd_update, mp_sgd_mom_update, square, ftrl_update, ftml_update, - signsgd_update, signum_update) + signsgd_update, signum_update, + multi_sgd_update, multi_sgd_mom_update, multi_mp_sgd_update, + multi_mp_sgd_mom_update) from ..ndarray import sparse from ..random import normal @@ -37,6 +40,8 @@ 'Test', 'Updater', 'ccSGD', 'create', 'get_updater', 'register' ] +def _flatten_list(nested_list): + return [item for sublist in nested_list for item in sublist] class Optimizer(object): """The base class inherited by all optimizers. @@ -101,6 +106,7 @@ def __init__(self, rescale_grad=1., param_idx2name=None, wd=0., self._index_update_count = {} self.clip_gradient = clip_gradient self.multi_precision = multi_precision + self.aggregate_num = 0 if param_idx2name is None: param_idx2name = {} @@ -376,13 +382,44 @@ def _update_count(self, index): Parameters ---------- - index : int + index : int or list of int The index to be updated. """ - if index not in self._index_update_count: - self._index_update_count[index] = self.begin_num_update - self._index_update_count[index] += 1 - self.num_update = max(self._index_update_count[index], self.num_update) + if not isinstance(index, (list, tuple)): + index = [index] + for idx in index: + if idx not in self._index_update_count: + self._index_update_count[idx] = self.begin_num_update + self._index_update_count[idx] += 1 + self.num_update = max(self._index_update_count[idx], self.num_update) + + def _get_lrs(self, indices): + """Gets the learning rates given the indices of the weights. + + Parameters + ---------- + indices : list of int + Indices corresponding to weights. + + Returns + ------- + lrs : list of float + Learning rates for those indices. + """ + if self.lr_scheduler is not None: + lr = self.lr_scheduler(self.num_update) + else: + lr = self.lr + + lrs = [lr for _ in indices] + for i, index in enumerate(indices): + if index in self.param_dict: + lrs[i] *= self.param_dict[index].lr_mult + elif index in self.lr_mult: + lrs[i] *= self.lr_mult[index] + elif index in self.idx2name: + lrs[i] *= self.lr_mult.get(self.idx2name[index], 1.0) + return lrs def _get_lr(self, index): """Gets the learning rate given the index of the weight. @@ -397,18 +434,31 @@ def _get_lr(self, index): lr : float Learning rate for this index. """ - if self.lr_scheduler is not None: - lr = self.lr_scheduler(self.num_update) - else: - lr = self.lr + return self._get_lrs([index])[0] + + def _get_wds(self, indices): + """Gets weight decays for indices. + Returns 0 for non-weights if the name of weights are provided for `__init__`. - if index in self.param_dict: - lr *= self.param_dict[index].lr_mult - elif index in self.lr_mult: - lr *= self.lr_mult[index] - elif index in self.idx2name: - lr *= self.lr_mult.get(self.idx2name[index], 1.0) - return lr + Parameters + ---------- + indices : list of int + Indices of weights. + + Returns + ------- + wds : list of float + Weight decays for those indices. + """ + wds = [self.wd for _ in indices] + for i, index in enumerate(indices): + if index in self.param_dict: + wds[i] *= self.param_dict[index].wd_mult + elif index in self.wd_mult: + wds[i] *= self.wd_mult[index] + elif index in self.idx2name: + wds[i] *= self.wd_mult.get(self.idx2name[index], 1.0) + return wds def _get_wd(self, index): """Gets weight decay for index. @@ -417,21 +467,14 @@ def _get_wd(self, index): Parameters ---------- index : int - The index for weight. + The index of weight. Returns ------- wd : float Weight decay for this index. """ - wd = self.wd - if index in self.param_dict: - wd *= self.param_dict[index].wd_mult - elif index in self.wd_mult: - wd *= self.wd_mult[index] - elif index in self.idx2name: - wd *= self.wd_mult.get(self.idx2name[index], 1.0) - return wd + return self._get_wds([index])[0] def __getstate__(self): ret = self.__dict__.copy() @@ -498,6 +541,7 @@ def __init__(self, momentum=0.0, lazy_update=True, **kwargs): super(SGD, self).__init__(**kwargs) self.momentum = momentum self.lazy_update = lazy_update + self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', 4)) def create_state_multi_precision(self, index, weight): weight_master_copy = None @@ -518,12 +562,22 @@ def create_state(self, index, weight): momentum = zeros(weight.shape, weight.context, dtype=weight.dtype, stype=stype) return momentum - def _update_impl(self, index, weight, grad, state, multi_precision=False): - assert(isinstance(weight, NDArray)) - assert(isinstance(grad, NDArray)) - self._update_count(index) - lr = self._get_lr(index) - wd = self._get_wd(index) + def _update_impl(self, indices, weights, grads, states, multi_precision=False): + aggregate = True + if not isinstance(indices, (tuple, list)): + indices = [indices] + weights = [weights] + grads = [grads] + states = [states] + for weight, grad in zip(weights, grads): + assert(isinstance(weight, NDArray)) + assert(isinstance(grad, NDArray)) + aggregate = (aggregate and + weight.stype == 'default' and + grad.stype == 'default') + self._update_count(indices) + lrs = self._get_lrs(indices) + wds = self._get_wds(indices) kwargs = {'rescale_grad': self.rescale_grad} if self.momentum > 0: @@ -531,26 +585,49 @@ def _update_impl(self, index, weight, grad, state, multi_precision=False): if self.clip_gradient: kwargs['clip_gradient'] = self.clip_gradient - if not multi_precision: - if state is not None: - sgd_mom_update(weight, grad, state, out=weight, - lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) + if aggregate: + if not multi_precision: + if self.momentum > 0: + multi_sgd_mom_update(*_flatten_list(zip(weights, grads, states)), out=weights, + num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) + else: + multi_sgd_update(*_flatten_list(zip(weights, grads)), out=weights, + num_weights=len(weights), lrs=lrs, wds=wds, **kwargs) else: - sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, - lr=lr, wd=wd, **kwargs) + if self.momentum > 0: + multi_mp_sgd_mom_update(*_flatten_list(zip(weights, grads, *zip(*states))), + out=weights, num_weights=len(weights), + lrs=lrs, wds=wds, **kwargs) + else: + multi_mp_sgd_update(*_flatten_list(zip(weights, grads, + list(zip(*states))[1])), + out=weights, num_weights=len(weights), + lrs=lrs, wds=wds, **kwargs) else: - if state[0] is not None: - mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, - lr=lr, wd=wd, **kwargs) - else: - mp_sgd_update(weight, grad, state[1], out=weight, - lr=lr, wd=wd, **kwargs) + for weight, grad, state, lr, wd in zip(weights, grads, states, lrs, wds): + if not multi_precision: + if state is not None: + sgd_mom_update(weight, grad, state, out=weight, + lazy_update=self.lazy_update, lr=lr, wd=wd, **kwargs) + else: + sgd_update(weight, grad, out=weight, lazy_update=self.lazy_update, + lr=lr, wd=wd, **kwargs) + else: + if state[0] is not None: + mp_sgd_mom_update(weight, grad, state[0], state[1], out=weight, + lr=lr, wd=wd, **kwargs) + else: + mp_sgd_update(weight, grad, state[1], out=weight, + lr=lr, wd=wd, **kwargs) def update(self, index, weight, grad, state): self._update_impl(index, weight, grad, state, multi_precision=False) def update_multi_precision(self, index, weight, grad, state): - use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 + if not isinstance(index, (tuple, list)): + use_multi_precision = self.multi_precision and weight.dtype == numpy.float16 + else: + use_multi_precision = self.multi_precision and weight[0].dtype == numpy.float16 self._update_impl(index, weight, grad, state, multi_precision=use_multi_precision) @@ -1513,20 +1590,55 @@ def __init__(self, optimizer): self.optimizer = optimizer self.states = {} self.states_synced = {} + self.aggregate_updates = optimizer.aggregate_num > 0 def __call__(self, index, grad, weight): """Updates weight given gradient and index.""" - # convert ctypes.char_p.value back to python str if needed - if isinstance(index, bytes): - index = py_str(index) - if index not in self.states: - self.states[index] = self.optimizer.create_state_multi_precision(index, weight) - self.states_synced[index] = True - elif not self.states_synced[index]: - self.states[index] = \ - self.sync_state_context(self.states[index], weight.context) - self.states_synced[index] = True - self.optimizer.update_multi_precision(index, weight, grad, self.states[index]) + if not isinstance(index, (list, tuple)): + indices = [index] + grads = [grad] + weights = [weight] + else: + indices = index + grads = grad + weights = weight + for i, idx in enumerate(indices): + # convert ctypes.char_p.value back to python str if needed + if isinstance(idx, bytes): + indices[i] = py_str(idx) + idx = indices[i] + if idx not in self.states: + self.states[idx] = self.optimizer.create_state_multi_precision(idx, weights[i]) + self.states_synced[idx] = True + elif not self.states_synced[idx]: + self.states[idx] = \ + self.sync_state_context(self.states[idx], weights[i].context) + self.states_synced[idx] = True + if self.aggregate_updates: + # segregate values based on type + type_map = {} + for i, w, g in zip(indices, weights, grads): + if w.dtype in type_map: + type_map[w.dtype].append((i, w, g)) + else: + type_map[w.dtype] = [(i, w, g)] + for idx in type_map: + current_index = 0 + indices, weights, grads = zip(*type_map[idx]) + while current_index < len(indices): + states = [] + step = min(self.optimizer.aggregate_num, len(indices) - current_index) + for j in range(step): + states.append(self.states[indices[current_index + j]]) + self.optimizer.update_multi_precision( + indices[current_index:current_index + self.optimizer.aggregate_num], + weights[current_index:current_index + self.optimizer.aggregate_num], + grads[current_index:current_index + self.optimizer.aggregate_num], + states) + current_index += self.optimizer.aggregate_num + else: + for i, w, g in zip(indices, weights, grads): + self.optimizer.update_multi_precision(i, w, g, self.states[i]) def sync_state_context(self, state, context): """sync state context.""" diff --git a/src/operator/optimizer_op-inl.h b/src/operator/optimizer_op-inl.h index 9251b8614806..223a1aa6c37d 100644 --- a/src/operator/optimizer_op-inl.h +++ b/src/operator/optimizer_op-inl.h @@ -82,6 +82,301 @@ struct SGDParam : public dmlc::Parameter { } }; +struct MultiSGDParam : public dmlc::Parameter { + nnvm::Tuple lrs; + nnvm::Tuple wds; + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(MultiSGDParam) { + DMLC_DECLARE_FIELD(lrs) + .describe("Learning rates."); + DMLC_DECLARE_FIELD(wds) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +struct MultiSGDMomParam : public dmlc::Parameter { + nnvm::Tuple lrs; + nnvm::Tuple wds; + float momentum; + float rescale_grad; + float clip_gradient; + int num_weights; + DMLC_DECLARE_PARAMETER(MultiSGDMomParam) { + DMLC_DECLARE_FIELD(lrs) + .describe("Learning rates."); + DMLC_DECLARE_FIELD(wds) + .describe("Weight decay augments the objective function with a " + "regularization term that penalizes large weights. " + "The penalty scales with the square of the magnitude of each weight."); + DMLC_DECLARE_FIELD(momentum) + .set_default(0.0f) + .describe("The decay rate of momentum estimates at each epoch."); + DMLC_DECLARE_FIELD(rescale_grad) + .set_default(1.0f) + .describe("Rescale gradient to grad = rescale_grad*grad."); + DMLC_DECLARE_FIELD(clip_gradient) + .set_default(-1.0f) + .describe("Clip gradient to the range of [-clip_gradient, clip_gradient] " + "If clip_gradient <= 0, gradient clipping is turned off. " + "grad = max(min(grad, clip_gradient), -clip_gradient)."); + DMLC_DECLARE_FIELD(num_weights) + .set_default(1) + .describe("Number of updated weights."); + } +}; + +template +inline bool MultiSGDShape(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_shapes = *in_attrs; + auto& output_shapes = *out_attrs; + // Learning rates + CHECK_EQ(param.lrs.ndim(), param.num_weights) + << "Number of learning rates is inconsistent with num_weights " + << "parameter passed. Expected number of learning rates: " + << param.num_weights << ", and got " << param.lrs.ndim(); + // Weight decays + CHECK_EQ(param.wds.ndim(), param.num_weights) + << "Number of weight decays is inconsistent with num_weights " + << "parameter passed. Expected number of weight decays: " + << param.num_weights << ", and got " << param.wds.ndim(); + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_shapes[i]}); + for (int j = 0; j < input_stride; ++j) { + input_vec.push_back(input_shapes[i * input_stride + j]); + } + all_inferred = all_inferred && ElemwiseShape(attrs, &input_vec, &output_vec); + } + return all_inferred; +} + +template +inline bool MP_MultiSGD_InferType(const nnvm::NodeAttrs& attrs, + std::vector *in_attrs, + std::vector *out_attrs) { + const ParamType& param = dmlc::get(attrs.parsed); + CHECK_EQ(in_attrs->size(), input_stride * param.num_weights); + CHECK_EQ(out_attrs->size(), param.num_weights); + + bool all_inferred = true; + auto& input_types = *in_attrs; + auto& output_types = *out_attrs; + // Weights and gradients + for (int i = 0; i < param.num_weights; ++i) { + std::vector input_vec; + std::vector output_vec({output_types[i]}); + for (int j = 0; j < input_stride - num_fp32_inputs; ++j) { + input_vec.push_back(input_types[i * input_stride + j]); + } + all_inferred = all_inferred && + ElemwiseType(attrs, &input_vec, &output_vec); + } + // master copies of weights + for (int i = 0; i < param.num_weights; ++i) { + for (int j = 0; j < num_fp32_inputs; ++j) { + TYPE_ASSIGN_CHECK(input_types, input_stride * i + input_stride - 1 - j, mshadow::kFloat32); + } + } + return all_inferred; +} + +template +struct MultiSGDKernelParam { + static const int N = 60; + int count; + size_t max_size; + size_t sizes[N]; + DType * weights[N]; + DType * grads[N]; + MPDType * mom[N]; + MPDType * weights32[N]; + DType * out_data[N]; + MPDType lrs[N]; + MPDType wds[N]; + MPDType clip_gradient; + MPDType rescale_grad; + MPDType momentum; +}; + +template +struct MultiSGDKernel { + template + MSHADOW_XINLINE static void Map(int i, const MultiSGDKernelParam& param, + const OpReqType req) { + for (int index = 0; index < param.count; ++index) { + if ((size_t)i < param.sizes[index]) { + MPDType w = has_mixed_precision ? param.weights32[index][i] : + MPDType(param.weights[index][i]); + MPDType mom = has_momentum ? param.mom[index][i] : MPDType(0); + if (param.clip_gradient >= 0.0f) { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index] + *mshadow_op::clip::Map(param.rescale_grad * + static_cast(param.grads[index][i]), + param.clip_gradient); + } else { + mom = param.momentum*mom + - param.lrs[index]*param.wds[index]*w + - param.lrs[index]*param.rescale_grad*static_cast(param.grads[index][i]); + } + if (has_momentum) { + param.mom[index][i] = mom; + } + w = w + mom; + if (has_mixed_precision) { + param.weights32[index][i] = w; + } + KERNEL_ASSIGN(param.out_data[index][i], req, w); + } + } + } +}; + +template +MultiSGDKernelParam FillMultiSGDKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const ParamType& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MultiSGDKernelParam param; + param.clip_gradient = p.clip_gradient; + param.rescale_grad = p.rescale_grad; + param.momentum = 0; + param.count = p.num_weights; + param.max_size = 0; + for (int i = 0; i < param.count; ++i) { + param.sizes[i] = inputs[i * input_stride].shape_.Size(); + if (param.max_size < param.sizes[i]) { + param.max_size = param.sizes[i]; + } + param.weights[i] = inputs[i * input_stride].FlatTo2D(s).dptr_; + param.grads[i] = inputs[i * input_stride + 1].FlatTo2D(s).dptr_; + // if mixed precision, then the last input in a set + // is 32-bit master copy of the weights + if (!std::is_same::value) { + param.weights32[i] = inputs[i * input_stride + input_stride - 1] + .FlatTo2D(s).dptr_; + } + param.out_data[i] = outputs[i].FlatTo2D(s).dptr_; + param.lrs[i] = p.lrs[i]; + param.wds[i] = p.wds[i]; + } + + return param; +} + + +template +MultiSGDKernelParam FillMultiSGDMomKernelParam(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &outputs) { + using namespace mxnet_op; + const MultiSGDMomParam& p = nnvm::get(attrs.parsed); + Stream* s = ctx.get_stream(); + MultiSGDKernelParam param = + FillMultiSGDKernelParam(attrs, ctx, inputs, outputs); + param.momentum = p.momentum; + for (int i = 0; i < param.count; ++i) { + param.mom[i] = inputs[i * input_stride + 2].FlatTo2D(s).dptr_; + } + + return param; +} + +template +class type_identity { + public: + using type = T; +}; + +template +class single_precision { + public: + using type = float; +}; + +template class MPTypeChooser, int input_stride> +inline void MultiSGDUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiSGDKernelParam param = + FillMultiSGDKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} + +template class MPTypeChooser, int input_stride> +inline void MultiSGDMomUpdate(const nnvm::NodeAttrs& attrs, + const OpContext &ctx, + const std::vector &inputs, + const std::vector &req, + const std::vector &outputs) { + using namespace mxnet_op; + Stream* s = ctx.get_stream(); + MSHADOW_REAL_TYPE_SWITCH(outputs[0].type_flag_, DType, { + using MPDType = typename MPTypeChooser::type; + MultiSGDKernelParam param = + FillMultiSGDMomKernelParam(attrs, ctx, inputs, outputs); + Kernel::value>, + xpu>::Launch(s, param.max_size, param, req[0]); + }); +} struct SGDKernel { template diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 6c44f99c1443..1f4c840cfb50 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -31,6 +31,8 @@ namespace op { DMLC_REGISTER_PARAMETER(SGDParam); DMLC_REGISTER_PARAMETER(SGDMomParam); +DMLC_REGISTER_PARAMETER(MultiSGDParam); +DMLC_REGISTER_PARAMETER(MultiSGDMomParam); DMLC_REGISTER_PARAMETER(FTMLParam); DMLC_REGISTER_PARAMETER(AdamParam); DMLC_REGISTER_PARAMETER(RMSPropParam); @@ -52,7 +54,7 @@ It updates the weights using:: weight = weight - learning_rate * sign(gradient) -.. note:: +.. note:: - sparse ndarray not supported for this optimizer yet. )code" ADD_FILELINE) .set_num_inputs(2) @@ -81,7 +83,7 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -.. note:: +.. note:: - sparse ndarray not supported for this optimizer yet. )code" ADD_FILELINE) .set_num_inputs(3) @@ -313,6 +315,209 @@ inline bool SGDStorageType(const nnvm::NodeAttrs& attrs, return dispatched; } +NNVM_REGISTER_OP(multi_sgd_update) +.describe(R"code(Update function for Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 2); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FCompute", MultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_sgd_mom_update) +.describe(R"code(Momentum update function for Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage +type is the same as momentum's storage type, +only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: + + for row in gradient.indices: + v[row] = momentum[row] * v[row] - learning_rate * gradient[row] + weight[row] += v[row] + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", ElemwiseType<-1, -1>) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", MultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights, gradients and momentum") +.add_arguments(MultiSGDMomParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_update) +.describe(R"code(Update function for multi-precision Stochastic Gradient Descent (SDG) optimizer. + +It updates the weights using:: + + weight = weight - learning_rate * (gradient + wd * weight) + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 3); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", MP_MultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 3 + 2); + } + return ret; + }) +.set_attr("FCompute", MultiSGDUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDParam::__FIELDS__()); + +NNVM_REGISTER_OP(multi_mp_sgd_mom_update) +.describe(R"code(Momentum update function for multi-precision Stochastic Gradient Descent (SGD) optimizer. + +Momentum update has better convergence rates on neural networks. Mathematically it looks +like below: + +.. math:: + + v_1 = \alpha * \nabla J(W_0)\\ + v_t = \gamma v_{t-1} - \alpha * \nabla J(W_{t-1})\\ + W_t = W_{t-1} + v_t + +It updates the weights using:: + + v = momentum * v - learning_rate * gradient + weight += v + +Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. + +However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage +type is the same as momentum's storage type, +only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: + + for row in gradient.indices: + v[row] = momentum[row] * v[row] - learning_rate * gradient[row] + weight[row] += v[row] + +)code" ADD_FILELINE) +.set_num_inputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights * 4); + }) +.set_num_outputs([](const nnvm::NodeAttrs& attrs) { + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + return static_cast(param.num_weights); + }) +.set_attr_parser(ParamParser) +.set_attr("FInferShape", MultiSGDShape) +.set_attr("FInferType", MP_MultiSGD_InferType) +.set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + uint32_t num_args = dmlc::get(attrs.parsed).num_weights; + std::vector ret; + for (uint32_t i = 0; i < num_args; ++i) { + ret.push_back(std::string("weight_") + std::to_string(i)); + ret.push_back(std::string("grad_") + std::to_string(i)); + ret.push_back(std::string("mom_") + std::to_string(i)); + ret.push_back(std::string("weight32_") + std::to_string(i)); + } + return ret; + }) +.set_attr("FMutateInputs", + [](const nnvm::NodeAttrs& attrs) { + std::vector ret; + const MultiSGDMomParam& param = dmlc::get(attrs.parsed); + for (int i = 0; i < param.num_weights; ++i) { + ret.push_back(i * 4 + 2); + ret.push_back(i * 4 + 3); + } + return ret; + }) +.set_attr("FCompute", MultiSGDMomUpdate) +.add_argument("data", "NDArray-or-Symbol[]", "Weights") +.add_arguments(MultiSGDMomParam::__FIELDS__()); NNVM_REGISTER_OP(sgd_update) MXNET_ADD_SPARSE_OP_ALIAS(sgd_update) diff --git a/src/operator/optimizer_op.cu b/src/operator/optimizer_op.cu index 0fd2ca83fda4..c42cf1831c43 100644 --- a/src/operator/optimizer_op.cu +++ b/src/operator/optimizer_op.cu @@ -242,6 +242,15 @@ NNVM_REGISTER_OP(mp_sgd_update) NNVM_REGISTER_OP(mp_sgd_mom_update) .set_attr("FCompute", MP_SGDMomUpdate); +NNVM_REGISTER_OP(multi_sgd_update) +.set_attr("FCompute", MultiSGDUpdate); +NNVM_REGISTER_OP(multi_sgd_mom_update) +.set_attr("FCompute", MultiSGDMomUpdate); +NNVM_REGISTER_OP(multi_mp_sgd_update) +.set_attr("FCompute", MultiSGDUpdate); +NNVM_REGISTER_OP(multi_mp_sgd_mom_update) +.set_attr("FCompute", MultiSGDMomUpdate); + NNVM_REGISTER_OP(ftml_update) .set_attr("FCompute", FTMLUpdate); From 3280482eade81b834e7699362300f7b5c4340348 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 21 Nov 2018 11:54:13 -0800 Subject: [PATCH 02/11] Make OpWrapperGenerator understand Tuple --- cpp-package/scripts/OpWrapperGenerator.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 1b5f8b56b924..1406a506f133 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -97,7 +97,8 @@ class Arg: 'double':'double',\ 'double or None':'dmlc::optional',\ 'Shape or None':'dmlc::optional',\ - 'string':'const std::string&'} + 'string':'const std::string&',\ + 'tuple of ':'nnvm::Tuple'} name = '' type = '' description = '' From 74f7b9bb313cc086b9eb6a1b7eabd50138ba79d6 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 26 Nov 2018 08:45:36 -0800 Subject: [PATCH 03/11] Trigger From 81d3b793bac533ac65bdfb6b4502c10dec45da57 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 26 Nov 2018 15:38:55 -0800 Subject: [PATCH 04/11] Add NNVM Tuple to cpp-package op.h --- cpp-package/scripts/OpWrapperGenerator.py | 1 + 1 file changed, 1 insertion(+) diff --git a/cpp-package/scripts/OpWrapperGenerator.py b/cpp-package/scripts/OpWrapperGenerator.py index 1406a506f133..7563dffe92df 100644 --- a/cpp-package/scripts/OpWrapperGenerator.py +++ b/cpp-package/scripts/OpWrapperGenerator.py @@ -406,6 +406,7 @@ def ParseAllOps(): "#include \"mxnet-cpp/op_util.h\"\n" "#include \"mxnet-cpp/operator.h\"\n" "#include \"dmlc/optional.h\"\n" + "#include \"nnvm/tuple.h\"\n" "\n" "namespace mxnet {\n" "namespace cpp {\n" From c25903ddb47b12053ac3ca91cdb75e27d9648245 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 30 Nov 2018 20:47:01 -0800 Subject: [PATCH 05/11] Trigger From a29a191d3eb88967648d4c235b090ec6ad9f9770 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 3 Dec 2018 14:53:37 -0800 Subject: [PATCH 06/11] Fix pylint aggregate SGD --- python/mxnet/model.py | 4 ++-- python/mxnet/optimizer/optimizer.py | 2 +- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/python/mxnet/model.py b/python/mxnet/model.py index 821555abc167..f70c7a592990 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -87,14 +87,14 @@ def _create_kvstore(kvstore, num_device, arg_params): arg_params : dict of str to `NDArray`. Model parameter, dict of name to `NDArray` of net's weights. """ - update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', 1))) + update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', "1"))) if kvstore is None: kv = None elif isinstance(kvstore, kvs.KVStore): kv = kvstore elif isinstance(kvstore, str): # create kvstore using the string type - if num_device is 1 and 'dist' not in kvstore: + if num_device == 1 and 'dist' not in kvstore: # no need to use kv for single device and single machine kv = None else: diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index e0643ac39d94..310bfa1d9873 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -541,7 +541,7 @@ def __init__(self, momentum=0.0, lazy_update=True, **kwargs): super(SGD, self).__init__(**kwargs) self.momentum = momentum self.lazy_update = lazy_update - self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', 4)) + self.aggregate_num = int(os.getenv('MXNET_OPTIMIZER_AGGREGATION_SIZE', "4")) def create_state_multi_precision(self, index, weight): weight_master_copy = None From e03cd057f5cf6b4940ff2c6cf86479e30acce0d5 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 3 Dec 2018 15:41:32 -0800 Subject: [PATCH 07/11] Update info about new ENV vars and modifying 2 tests that require update_on_kvstore to be true --- docs/faq/env_var.md | 9 +++++++++ tests/python/unittest/test_gluon_trainer.py | 9 +++++++-- tests/python/unittest/test_module.py | 3 +++ 3 files changed, 19 insertions(+), 2 deletions(-) diff --git a/docs/faq/env_var.md b/docs/faq/env_var.md index e373377ee8de..46215f7eec41 100644 --- a/docs/faq/env_var.md +++ b/docs/faq/env_var.md @@ -139,6 +139,10 @@ $env:MXNET_STORAGE_FALLBACK_LOG_VERBOSE=0 - If true, MXNet tries to use GPU peer-to-peer communication, if available on your device, when kvstore's type is `device`. +* MXNET_UPDATE_ON_KVSTORE + - Values: 0(false) or 1(true) ```(default=1)``` + - If true, weight updates are performed during the communication step, if possible. + ## Memonger * MXNET_BACKWARD_DO_MIRROR @@ -202,6 +206,11 @@ When USE_PROFILER is enabled in Makefile or CMake, the following environments ca If no such algorithm exists given other constraints, MXNet will error out. This variable affects the choice of CUDNN convolution algorithms. Please see [CUDNN developer guide](https://docs.nvidia.com/deeplearning/sdk/cudnn-developer-guide/index.html) for more details. +* MXNET_OPTIMIZER_AGGREGATION_SIZE + - Values: Int ```(default=4)``` + - Maximum value is 60. + - This variable controls how many weights will be updated in a single call to optimizer (for optimizers that support aggregation, currently limited to SGD). + Settings for Minimum Memory Usage --------------------------------- - Make sure ```min(MXNET_EXEC_NUM_TEMP, MXNET_GPU_WORKER_NTHREADS) = 1``` diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index 72c01acb2652..ba36f1e44563 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -98,6 +98,9 @@ def dict_equ(a, b): @with_seed() def test_trainer_save_load(): + previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1") + os.putenv('MXNET_UPDATE_ON_KVSTORE', '1') + x = gluon.Parameter('x', shape=(10,), lr_mult=1.0) x.initialize(ctx=[mx.cpu(0), mx.cpu(1)], init='zeros') trainer = gluon.Trainer([x], 'sgd', {'learning_rate': 0.1}) @@ -112,6 +115,7 @@ def test_trainer_save_load(): x.lr_mult = 2.0 # check if parameter dict is correctly associated with optimizer after load_state assert trainer._kvstore._updater.optimizer._get_lr(0) == 0.2 + os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) @with_seed() def test_trainer_sparse_save_load(): @@ -230,7 +234,8 @@ def check_trainer_sparse_kv(kv, stype, grad_stype, update_on_kv): assert (updated_w == -0.2).asnumpy().all() kvs = ['local', 'device'] + global_update_on_kvstore = bool(int(os.getenv('MXNET_UPDATE_ON_KVSTORE', "1"))) for kv in kvs: - check_trainer_sparse_kv(kv, 'default', 'default', True) + check_trainer_sparse_kv(kv, 'default', 'default', global_update_on_kvstore) check_trainer_sparse_kv(kv, 'default', 'row_sparse', False) - check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', True) + check_trainer_sparse_kv(kv, 'row_sparse', 'row_sparse', global_update_on_kvstore) diff --git a/tests/python/unittest/test_module.py b/tests/python/unittest/test_module.py index 7347723a39c6..d9d7175f540e 100644 --- a/tests/python/unittest/test_module.py +++ b/tests/python/unittest/test_module.py @@ -174,6 +174,8 @@ def test_module_layout(): @with_seed() def test_save_load(): + previous_update_on_kvstore = os.getenv('MXNET_UPDATE_ON_KVSTORE', "1") + os.putenv('MXNET_UPDATE_ON_KVSTORE', '1') def dict_equ(a, b): assert set(a) == set(b) for k in a: @@ -211,6 +213,7 @@ def dict_equ(a, b): assert mod._symbol.tojson() == mod2._symbol.tojson() dict_equ(mod.get_params()[0], mod2.get_params()[0]) dict_equ(mod._kvstore._updater.states, mod2._updater.states) + os.putenv('MXNET_UPDATE_ON_KVSTORE', previous_update_on_kvstore) @with_seed() From 08314a6bb2083489f1003c7dd54aacd72d7c9b42 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Mon, 3 Dec 2018 16:44:50 -0800 Subject: [PATCH 08/11] Fix --- tests/python/unittest/test_gluon_trainer.py | 1 + 1 file changed, 1 insertion(+) diff --git a/tests/python/unittest/test_gluon_trainer.py b/tests/python/unittest/test_gluon_trainer.py index ba36f1e44563..3dd293032f2d 100644 --- a/tests/python/unittest/test_gluon_trainer.py +++ b/tests/python/unittest/test_gluon_trainer.py @@ -17,6 +17,7 @@ import mxnet as mx import unittest +import os import numpy as np from mxnet import gluon from mxnet.gluon import nn From 47bfcf4f0bc576889e3e174b03cfe1104f1b793a Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Thu, 20 Dec 2018 16:24:28 -0800 Subject: [PATCH 09/11] Aggregate SGD support for Gluon trainer --- python/mxnet/gluon/trainer.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index c4d49e82c908..f5b74964847f 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -356,6 +356,8 @@ def update(self, batch_size, ignore_stale_grad=False): self._update(ignore_stale_grad) def _update(self, ignore_stale_grad=False): + updates = [[] for _ in self._updaters] + for i, param in enumerate(self._params): if param.grad_req == 'null': continue @@ -379,11 +381,17 @@ def _update(self, ignore_stale_grad=False): self._kvstore.pull(i, param.list_data(), priority=-i) continue - for upd, arr, grad in zip(self._updaters, param.list_data(), param.list_grad()): + for upd, arr, grad in zip(updates, param.list_data(), param.list_grad()): if not ignore_stale_grad or arr._fresh_grad: - upd(i, grad, arr) + upd.append((i, grad, arr)) arr._fresh_grad = False + if not (self._kvstore and self._update_on_kvstore): + for updater, upd in zip(self._updaters, updates): + if upd: + i, w, g = zip(*upd) + updater(i, w, g) + def save_states(self, fname): """Saves trainer states (e.g. optimizer, momentum) to a file. From d2752da4a5946ae639c8f0b88545125bf1feb928 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Fri, 4 Jan 2019 13:27:59 -0800 Subject: [PATCH 10/11] Added text to doc about aggregate update in SGD optimizer --- python/mxnet/optimizer/optimizer.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/python/mxnet/optimizer/optimizer.py b/python/mxnet/optimizer/optimizer.py index 42f798b52e4c..9d387d4fd9c6 100644 --- a/python/mxnet/optimizer/optimizer.py +++ b/python/mxnet/optimizer/optimizer.py @@ -514,6 +514,13 @@ class SGD(Optimizer): provides slightly different semantics than the original update, and may lead to different empirical results. + In the case when ``update_on_kvstore`` is set to False (either globally via + MXNET_UPDATE_ON_KVSTORE=0 environment variable or as a parameter in + :class:`~mxnet.gluon.Trainer`) SGD optimizer can perform aggregated update + of parameters, which may lead to improved performance. The aggregation size + is controlled by MXNET_OPTIMIZER_AGGREGATION_SIZE environment variable and + defaults to 4. + Otherwise, **standard updates** are applied by:: rescaled_grad = lr * (rescale_grad * clip(grad, clip_gradient) + wd * weight) From cc37a34fdeb12f97c39b1b5cc28e42e629c6c885 Mon Sep 17 00:00:00 2001 From: Przemek Tredak Date: Wed, 16 Jan 2019 10:07:48 -0800 Subject: [PATCH 11/11] Docs changes from review --- python/mxnet/gluon/trainer.py | 3 ++- src/operator/optimizer_op.cc | 16 ---------------- 2 files changed, 2 insertions(+), 17 deletions(-) diff --git a/python/mxnet/gluon/trainer.py b/python/mxnet/gluon/trainer.py index b6d424977c0c..8060f38ac2aa 100644 --- a/python/mxnet/gluon/trainer.py +++ b/python/mxnet/gluon/trainer.py @@ -60,7 +60,8 @@ class Trainer(object): See mxnet.KVStore.set_gradient_compression method for more details on gradient compression. update_on_kvstore : bool, default None Whether to perform parameter updates on kvstore. If None, then trainer will choose the more - suitable option depending on the type of kvstore. + suitable option depending on the type of kvstore. If the `update_on_kvstore` argument is + provided, environment variable `MXNET_UPDATE_ON_KVSTORE` will be ignored. Properties ---------- diff --git a/src/operator/optimizer_op.cc b/src/operator/optimizer_op.cc index 1f4c840cfb50..81d5cf5434c0 100644 --- a/src/operator/optimizer_op.cc +++ b/src/operator/optimizer_op.cc @@ -367,14 +367,6 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage -type is the same as momentum's storage type, -only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: - - for row in gradient.indices: - v[row] = momentum[row] * v[row] - learning_rate * gradient[row] - weight[row] += v[row] - )code" ADD_FILELINE) .set_num_inputs([](const nnvm::NodeAttrs& attrs) { const MultiSGDMomParam& param = dmlc::get(attrs.parsed); @@ -473,14 +465,6 @@ It updates the weights using:: Where the parameter ``momentum`` is the decay rate of momentum estimates at each epoch. -However, if grad's storage type is ``row_sparse``, ``lazy_update`` is True and weight's storage -type is the same as momentum's storage type, -only the row slices whose indices appear in grad.indices are updated (for both weight and momentum):: - - for row in gradient.indices: - v[row] = momentum[row] * v[row] - learning_rate * gradient[row] - weight[row] += v[row] - )code" ADD_FILELINE) .set_num_inputs([](const nnvm::NodeAttrs& attrs) { const MultiSGDMomParam& param = dmlc::get(attrs.parsed);