From cf13244303a480b7accf8ad69d89386533ede208 Mon Sep 17 00:00:00 2001 From: tqchen Date: Sun, 27 Sep 2015 20:10:53 -0400 Subject: [PATCH] [MODEL] Support update on kvstore --- example/imagenet/alexnet.py | 6 ++-- example/imagenet/data.py | 8 ++--- python/mxnet/model.py | 59 ++++++++++++++++++++++++++-------- tests/python/train/test_mlp.py | 3 +- 4 files changed, 54 insertions(+), 22 deletions(-) diff --git a/example/imagenet/alexnet.py b/example/imagenet/alexnet.py index e4e1663406c4..9a74631a2174 100644 --- a/example/imagenet/alexnet.py +++ b/example/imagenet/alexnet.py @@ -16,7 +16,7 @@ conv2 = mx.symbol.Convolution( data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256) relu2 = mx.symbol.Activation(data=conv2, act_type="relu") -pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2)) +pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max") lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5) # stage 3 conv3 = mx.symbol.Convolution( @@ -28,7 +28,7 @@ conv5 = mx.symbol.Convolution( data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256) relu5 = mx.symbol.Activation(data=conv5, act_type="relu") -pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2)) +pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max") # stage 4 flatten = mx.symbol.Flatten(data=pool3) fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096) @@ -48,7 +48,7 @@ train, val = ilsvrc12_iterator(batch_size=batch_size, input_shape=(3,224,224)) ## train -num_gpus = 2 +num_gpus = 4 gpus = [mx.gpu(i) for i in range(num_gpus)] model = mx.model.FeedForward( ctx = gpus, diff --git a/example/imagenet/data.py b/example/imagenet/data.py index cfca1db5e084..2f53902b3c96 100644 --- a/example/imagenet/data.py +++ b/example/imagenet/data.py @@ -7,8 +7,8 @@ def ilsvrc12_iterator(batch_size, input_shape): """return train and val iterators for imagenet""" train_dataiter = mx.io.ImageRecordIter( - path_imgrec = 
"data/ilsvrc12/train.rec", - mean_img = "data/ilsvrc12/mean.bin", + path_imgrec = "data/train.rec", + mean_img = "data/mean.bin", rand_crop = True, rand_mirror = True, prefetch_buffer = 4, @@ -16,8 +16,8 @@ def ilsvrc12_iterator(batch_size, input_shape): data_shape = input_shape, batch_size = batch_size) val_dataiter = mx.io.ImageRecordIter( - path_imgrec = "data/ilsvrc12/val.rec", - mean_img = "data/ilsvrc12/mean.bin", + path_imgrec = "data/val.rec", + mean_img = "data/mean.bin", rand_crop = False, rand_mirror = False, prefetch_buffer = 4, diff --git a/python/mxnet/model.py b/python/mxnet/model.py index c6f1665f1524..b0b4f46ccb65 100644 --- a/python/mxnet/model.py +++ b/python/mxnet/model.py @@ -122,6 +122,7 @@ def _train_multi_device(symbol, ctx, input_shape, begin_round, end_round, optimizer, train_data, eval_data=None, eval_metric=None, iter_end_callback=None, epoch_end_callback=None, + update_on_kvstore=False, logger=None): """Internal training function on multiple devices. @@ -172,12 +173,18 @@ def _train_multi_device(symbol, ctx, input_shape, epoch_end_callback: callable(iteration) A callback that is invoked at end of each batch + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. + logger : logging logger When not specified, default logger will be used. Notes ----- - This function will inplace update the NDArrays in arg_parans and aux_states. + - This function will inplace update the NDArrays in arg_parans and aux_states. + - Turning update_on_kvstore on and off can affect speed of multi-gpu training. + - update_on_kvstore=True works well for inception type nets that contains many small weights. + - update_on_kvstore=False works better for Alexnet style net with bulk weights. 
""" if logger is None: logger = logging @@ -203,9 +210,11 @@ def _train_multi_device(symbol, ctx, input_shape, for texec in train_execs: texec.copy_params_from(arg_params, aux_params) - # ky value store kv = kvstore.create() if num_device != 1 else None + if kv is None: + update_on_kvstore = False + opt_state_blocks = [] # If there are multiple devices, initialize the weights. for index, pair in enumerate(zip(arg_blocks, grad_blocks)): @@ -214,11 +223,20 @@ def _train_multi_device(symbol, ctx, input_shape, if kv: kv.init(index, arg_list[0]) # attach state direct to weight - opt_list = [optimizer.create_state(index, w) for w in arg_list] - opt_state_blocks.append(opt_list) + if update_on_kvstore: + opt_state_blocks.append(nd.zeros(arg_list[0].shape, cpu())) + else: + opt_list = [optimizer.create_state(index, w) for w in arg_list] + opt_state_blocks.append(opt_list) else: opt_state_blocks.append(None) + def kv_updater(index, grad, weight): + """Internal updater on KVstore, used when update_on_kvstore=True.""" + optimizer.update(index, weight, grad, opt_state_blocks[index]) + if update_on_kvstore: + kv.set_updater(kv_updater) + # Input and output data structure data_index, label_index = _check_arguments(symbol) merged_shape = list(train_execs[0].outputs[0].shape) @@ -255,12 +273,17 @@ def _train_multi_device(symbol, ctx, input_shape, if kv: # push gradient, priority is negative index kv.push(index, grad_list, priority=-index) - # pull back the sum, to the same locations. - kv.pull(index, grad_list, priority=-index) - opt_list = opt_state_blocks[index] - # optimizea - for w, g, state in zip(arg_list, grad_list, opt_list): - optimizer.update(index, w, g, state) + if update_on_kvstore: + # pull back the weights + kv.pull(index, arg_list, priority=-index) + else: + # pull back the sum gradients, to the same locations. 
+ kv.pull(index, grad_list, priority=-index) + if not update_on_kvstore: + opt_list = opt_state_blocks[index] + # optimize + for w, g, state in zip(arg_list, grad_list, opt_list): + optimizer.update(index, w, g, state) nbatch += 1 # epoch callback (for print purpose) if epoch_end_callback != None: @@ -562,7 +585,8 @@ def predict(self, X): return np.concatenate(outputs) def fit(self, X, y=None, eval_data=None, eval_metric='acc', - iter_end_callback=None, epoch_end_callback=None, logger=None): + iter_end_callback=None, epoch_end_callback=None, + update_on_kvstore=False, logger=None): """Fit the model. Parameters @@ -592,6 +616,9 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', A callback that is invoked at end of each batch For print purpose + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. + logger : logging logger, optional When not specified, default logger will be used. """ @@ -622,7 +649,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc', eval_metric=eval_metric, iter_end_callback=iter_end_callback, epoch_end_callback=epoch_end_callback, - logger=logger) + update_on_kvstore=update_on_kvstore, logger=logger) def save(self, prefix, iteration=None): """Checkpoint the model checkpoint into file. @@ -684,7 +711,7 @@ def load(prefix, iteration, ctx=None): def create(symbol, X, y=None, ctx=None, num_round=None, optimizer='sgd', initializer=Xavier(), eval_data=None, eval_metric='acc', iter_end_callback=None, - logger=None, **kwargs): + update_on_kvstore=False, logger=None, **kwargs): """Functional style to create a model. This function will be more consistent with functional @@ -726,10 +753,14 @@ def create(symbol, X, y=None, ctx=None, A callback that is invoked at end of each iteration. This can be used to checkpoint model each iteration. + update_on_kvstore: boolean, optional + Whether to perform parameter update on kvstore instead of training device. 
+ logger : logging logger, optional """ model = FeedForward(symbol, ctx=ctx, num_round=num_round, optimizer=optimizer, initializer=initializer, **kwargs) model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric, - iter_end_callback=iter_end_callback, logger=logger) + iter_end_callback=iter_end_callback, + update_on_kvstore=update_on_kvstore, logger=logger) return model diff --git a/tests/python/train/test_mlp.py b/tests/python/train/test_mlp.py index 14e8f6c700a8..fdf89142fc11 100644 --- a/tests/python/train/test_mlp.py +++ b/tests/python/train/test_mlp.py @@ -53,7 +53,8 @@ def test_mlp(): ctx=[mx.cpu(i) for i in range(2)], num_round=num_round, learning_rate=0.01, wd=0.0004, - momentum=0.9) + momentum=0.9, + update_on_kvstore=True) logging.info('Finish traning...') prob = model.predict(val_dataiter)