Skip to content
This repository was archived by the owner on Nov 17, 2023. It is now read-only.
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions example/imagenet/alexnet.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,7 @@
conv2 = mx.symbol.Convolution(
data=lrn1, kernel=(5, 5), pad=(2, 2), num_filter=256)
relu2 = mx.symbol.Activation(data=conv2, act_type="relu")
pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2))
pool2 = mx.symbol.Pooling(data=relu2, kernel=(3, 3), stride=(2, 2), pool_type="max")
lrn2 = mx.symbol.LRN(data=pool2, alpha=0.0001, beta=0.75, knorm=1, nsize=5)
# stage 3
conv3 = mx.symbol.Convolution(
Expand All @@ -28,7 +28,7 @@
conv5 = mx.symbol.Convolution(
data=relu4, kernel=(3, 3), pad=(1, 1), num_filter=256)
relu5 = mx.symbol.Activation(data=conv5, act_type="relu")
pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2))
pool3 = mx.symbol.Pooling(data=relu5, kernel=(3, 3), stride=(2, 2), pool_type="max")
# stage 4
flatten = mx.symbol.Flatten(data=pool3)
fc1 = mx.symbol.FullyConnected(data=flatten, num_hidden=4096)
Expand All @@ -48,7 +48,7 @@
train, val = ilsvrc12_iterator(batch_size=batch_size, input_shape=(3,224,224))

## train
num_gpus = 2
num_gpus = 4
gpus = [mx.gpu(i) for i in range(num_gpus)]
model = mx.model.FeedForward(
ctx = gpus,
Expand Down
8 changes: 4 additions & 4 deletions example/imagenet/data.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,17 +7,17 @@
def ilsvrc12_iterator(batch_size, input_shape):
"""return train and val iterators for imagenet"""
train_dataiter = mx.io.ImageRecordIter(
path_imgrec = "data/ilsvrc12/train.rec",
mean_img = "data/ilsvrc12/mean.bin",
path_imgrec = "data/train.rec",
mean_img = "data/mean.bin",
rand_crop = True,
rand_mirror = True,
prefetch_buffer = 4,
preprocess_threads = 4,
data_shape = input_shape,
batch_size = batch_size)
val_dataiter = mx.io.ImageRecordIter(
path_imgrec = "data/ilsvrc12/val.rec",
mean_img = "data/ilsvrc12/mean.bin",
path_imgrec = "data/val.rec",
mean_img = "data/mean.bin",
rand_crop = False,
rand_mirror = False,
prefetch_buffer = 4,
Expand Down
59 changes: 45 additions & 14 deletions python/mxnet/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -122,6 +122,7 @@ def _train_multi_device(symbol, ctx, input_shape,
begin_round, end_round, optimizer,
train_data, eval_data=None, eval_metric=None,
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False,
logger=None):
"""Internal training function on multiple devices.

Expand Down Expand Up @@ -172,12 +173,18 @@ def _train_multi_device(symbol, ctx, input_shape,
epoch_end_callback: callable(iteration)
A callback that is invoked at end of each batch

update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.

logger : logging logger
When not specified, default logger will be used.

Notes
-----
This function will inplace update the NDArrays in arg_parans and aux_states.
- This function will update the NDArrays in arg_params and aux_states in place.
- Turning update_on_kvstore on and off can affect speed of multi-gpu training.
- update_on_kvstore=True works well for Inception-type nets that contain many small weights.
- update_on_kvstore=False works better for Alexnet style net with bulk weights.
"""
if logger is None:
logger = logging
Expand All @@ -203,9 +210,11 @@ def _train_multi_device(symbol, ctx, input_shape,

for texec in train_execs:
texec.copy_params_from(arg_params, aux_params)

# key value store
kv = kvstore.create() if num_device != 1 else None
if kv is None:
update_on_kvstore = False

opt_state_blocks = []
# If there are multiple devices, initialize the weights.
for index, pair in enumerate(zip(arg_blocks, grad_blocks)):
Expand All @@ -214,11 +223,20 @@ def _train_multi_device(symbol, ctx, input_shape,
if kv:
kv.init(index, arg_list[0])
# attach state direct to weight
opt_list = [optimizer.create_state(index, w) for w in arg_list]
opt_state_blocks.append(opt_list)
if update_on_kvstore:
opt_state_blocks.append(nd.zeros(arg_list[0].shape, cpu()))
else:
opt_list = [optimizer.create_state(index, w) for w in arg_list]
opt_state_blocks.append(opt_list)
else:
opt_state_blocks.append(None)

def kv_updater(index, grad, weight):
"""Internal updater on KVstore, used when update_on_kvstore=True."""
optimizer.update(index, weight, grad, opt_state_blocks[index])
if update_on_kvstore:
kv.set_updater(kv_updater)

# Input and output data structure
data_index, label_index = _check_arguments(symbol)
merged_shape = list(train_execs[0].outputs[0].shape)
Expand Down Expand Up @@ -255,12 +273,17 @@ def _train_multi_device(symbol, ctx, input_shape,
if kv:
# push gradient, priority is negative index
kv.push(index, grad_list, priority=-index)
# pull back the sum, to the same locations.
kv.pull(index, grad_list, priority=-index)
opt_list = opt_state_blocks[index]
# optimizea
for w, g, state in zip(arg_list, grad_list, opt_list):
optimizer.update(index, w, g, state)
if update_on_kvstore:
# pull back the weights
kv.pull(index, arg_list, priority=-index)
else:
# pull back the sum gradients, to the same locations.
kv.pull(index, grad_list, priority=-index)
if not update_on_kvstore:
opt_list = opt_state_blocks[index]
# optimize
for w, g, state in zip(arg_list, grad_list, opt_list):
optimizer.update(index, w, g, state)
nbatch += 1
# epoch callback (for print purpose)
if epoch_end_callback != None:
Expand Down Expand Up @@ -562,7 +585,8 @@ def predict(self, X):
return np.concatenate(outputs)

def fit(self, X, y=None, eval_data=None, eval_metric='acc',
iter_end_callback=None, epoch_end_callback=None, logger=None):
iter_end_callback=None, epoch_end_callback=None,
update_on_kvstore=False, logger=None):
"""Fit the model.

Parameters
Expand Down Expand Up @@ -592,6 +616,9 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
A callback that is invoked at end of each batch
For print purpose

update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.

logger : logging logger, optional
When not specified, default logger will be used.
"""
Expand Down Expand Up @@ -622,7 +649,7 @@ def fit(self, X, y=None, eval_data=None, eval_metric='acc',
eval_metric=eval_metric,
iter_end_callback=iter_end_callback,
epoch_end_callback=epoch_end_callback,
logger=logger)
update_on_kvstore=update_on_kvstore, logger=logger)

def save(self, prefix, iteration=None):
"""Checkpoint the model checkpoint into file.
Expand Down Expand Up @@ -684,7 +711,7 @@ def load(prefix, iteration, ctx=None):
def create(symbol, X, y=None, ctx=None,
num_round=None, optimizer='sgd', initializer=Xavier(),
eval_data=None, eval_metric='acc', iter_end_callback=None,
logger=None, **kwargs):
update_on_kvstore=False, logger=None, **kwargs):
"""Functional style to create a model.

This function will be more consistent with functional
Expand Down Expand Up @@ -726,10 +753,14 @@ def create(symbol, X, y=None, ctx=None,
A callback that is invoked at end of each iteration.
This can be used to checkpoint model each iteration.

update_on_kvstore: boolean, optional
Whether to perform parameter update on kvstore instead of training device.

logger : logging logger, optional
"""
model = FeedForward(symbol, ctx=ctx, num_round=num_round,
optimizer=optimizer, initializer=initializer, **kwargs)
model.fit(X, y, eval_data=eval_data, eval_metric=eval_metric,
iter_end_callback=iter_end_callback, logger=logger)
iter_end_callback=iter_end_callback,
update_on_kvstore=update_on_kvstore, logger=logger)
return model
3 changes: 2 additions & 1 deletion tests/python/train/test_mlp.py
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ def test_mlp():
ctx=[mx.cpu(i) for i in range(2)],
num_round=num_round,
learning_rate=0.01, wd=0.0004,
momentum=0.9)
momentum=0.9,
update_on_kvstore=True)

logging.info('Finish traning...')
prob = model.predict(val_dataiter)
Expand Down