diff --git a/example/quantization/README.md b/example/quantization/README.md new file mode 100644 index 000000000000..5f7e5c679305 --- /dev/null +++ b/example/quantization/README.md @@ -0,0 +1,191 @@ + + + + + + + + + + + + + + + + + +# Model Quantization with Calibration Examples + +This folder contains examples of quantizing a FP32 model with Intel® oneAPI Deep Neural Network Library (oneDNN) to (U)INT8 model. + +

Model Quantization with Intel® oneDNN

+ +Intel® oneDNN supports quantization with subgraph features on Intel® CPU Platform and can bring performance improvements on the [Intel® Xeon® Scalable Platform](https://www.intel.com/content/www/us/en/processors/xeon/scalable/xeon-scalable-platform.html). + +``` +usage: python imagenet_gen_qsym_onednn.py [-h] [--model MODEL] [--epoch EPOCH] + [--no-pretrained] [--batch-size BATCH_SIZE] + [--calib-dataset CALIB_DATASET] + [--image-shape IMAGE_SHAPE] + [--data-nthreads DATA_NTHREADS] + [--num-calib-batches NUM_CALIB_BATCHES] + [--exclude-first-conv] [--shuffle-dataset] + [--calib-mode CALIB_MODE] + [--quantized-dtype {auto,int8,uint8}] + [--quiet] + +Generate a calibrated quantized model from a FP32 model with Intel oneDNN support + +optional arguments: + -h, --help show this help message and exit + --model MODEL model to be quantized. If no-pretrained is set then + model must be provided to `model` directory in the same path + as this python script, default is `resnet50_v1` + --epoch EPOCH number of epochs, default is `0` + --no-pretrained If enabled, will not download pretrained model from + MXNet or Gluon-CV modelzoo, default is `False` + --batch-size BATCH_SIZE + batch size to be used when calibrating model, default is `32` + --calib-dataset CALIB_DATASET + path of the calibration dataset, default is `data/val_256_q90.rec` + --image-shape IMAGE_SHAPE + number of channels, height and width of input image separated by comma, + default is `3,224,224` + --data-nthreads DATA_NTHREADS + number of threads for data loading, default is `0` + --num-calib-batches NUM_CALIB_BATCHES + number of batches for calibration, default is `10` + --exclude-first-conv excluding quantizing the first conv layer since the + input data may have negative value which doesn't + support at moment + --shuffle-dataset shuffle the calibration dataset + --calib-mode CALIB_MODE + calibration mode used for generating calibration table + for the quantized symbol; supports 1. none: no + calibration will be used. The thresholds for + quantization will be calculated on the fly. This will + result in inference speed slowdown and loss of + accuracy in general. 2. naive: simply take min and max + values of layer outputs as thresholds for + quantization. In general, the inference accuracy + worsens with more examples used in calibration. It is + recommended to use `entropy` mode as it produces more + accurate inference results. 3. entropy: calculate KL + divergence of the FP32 output and quantized output for + optimal thresholds. This mode is expected to produce + the best inference accuracy of all three kinds of + quantized models if the calibration dataset is + representative enough of the inference dataset. + default is `entropy` + --quantized-dtype {auto,int8,uint8} + quantization destination data type for input data, + default is `auto` + --quiet suppress most of log +``` + +A new benchmark script `launch_inference_onednn.sh` has been designed to launch performance benchmark for FP32 or INT8 image-classification models with Intel® oneDNN. +``` +usage: bash ./launch_inference_onednn.sh -s symbol_file [-b batch_size] [-iter iteraton] [-ins instance] [-c cores/instance] [-h] + +arguments: + -h, --help show this help message and exit + -s, --symbol_file symbol file for benchmark, required + -b, --batch_size inference batch size + default: 64 + -iter, --iteration inference iteration + default: 500 + -ins, --instance launch multi-instance inference + default: one instance per socket + -c, --core number of cores per instance + default: divide full physical cores + +example: resnet INT8 performance benchmark on c5.24xlarge(duo sockets, 24 physical cores per socket). + + bash ./launch_inference_onednn.sh -s ./model/resnet50_v1-quantized-5batches-naive-symbol.json + +will launch two instances for throughput benchmark and each instance will use 24 physical cores. +``` + +The following models have been tested on Linux systems. Accuracy is collected on Intel XEON Cascade Lake CPU. For CPU with Skylake Lake or eariler architecture, the accuracy may not be the same. +| Model | Source | Dataset | FP32 Accuracy (top-1/top-5)| INT8 Accuracy (top-1/top-5)| +|:---|:---|---|:---:|:---:| +| ResNet18-V1 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |70.45%/89.55%|70.22%/89.38%| +| ResNet50-V1 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |76.36%/93.49%|76.04%/93.30%| +| ResNet101-V1 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |78.23%/93.99%|77.85%/93.69%| +| MobileNet v2 1.0 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |71.72%/90.28%|71.22%/89.92%| +| VGG16 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |72.83%/91.11%|72.81%/91.10%| +| VGG19 | [MXNet ModelZoo](https://github.com/apache/incubator-mxnet/tree/master/python/mxnet/gluon/model_zoo) | [Validation Dataset](http://data.mxnet.io/data/val_256_q90.rec) |73.67%/91.63%|73.67%/91.67%| +*Measured on validation ImageNet (ILSVRC2012) with batch-size=64, num-calib-batches=10 and calib-mode=entropy* + +

Pre-trained Model

+ +The following command is to download the pre-trained model from [MXNet ModelZoo](http://data.mxnet.io/models/imagenet/resnet/152-layers/) and transfer it into the symbolic model which would be finally quantized. The [validation dataset](http://data.mxnet.io/data/val_256_q90.rec) is available for testing the pre-trained models: + +``` +python imagenet_gen_qsym_onednn.py --model=resnet50_v1 --num-calib-batches=5 --calib-mode=naive +``` + +The model would be automatically replaced in fusion and quantization format. It is then saved as the quantized symbol and parameter files in the `./model` directory. Set `--model` to one of above listed verified models to quantize them. The following command is to launch inference. + +``` +# Launch FP32 Inference +python imagenet_inference.py --symbol-file=./model/resnet50_v1-symbol.json --param-file=./model/resnet50_v1-0000.params --rgb-mean=0.485,0.456,0.406 --rgb-std=0.229,0.224,0.225 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +# Launch INT8 Inference +python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-5batches-naive-symbol.json --param-file=./model/resnet50_v1-quantized-0000.params --rgb-mean=0.485,0.456,0.406 --rgb-std=0.229,0.224,0.225 --num-skipped-batches=50 --batch-size=64 --num-inference-batches=500 --dataset=./data/val_256_q90.rec + +# Launch dummy data Inference +bash ./launch_inference_onednn.sh -s ./model/resnet50_v1-symbol.json +bash ./launch_inference_onednn.sh -s ./model/resnet50_v1-quantized-5batches-naive-symbol.json +``` + +

Custom Model

+ +This script also supports custom symbolic models. Quantization layer configs can easily be added in `imagenet_gen_qsym_onednn.py` like below: + +``` +if logger: + frameinfo = getframeinfo(currentframe()) + logger.info(F'Please set proper RGB configs inside this script below {frameinfo.filename}:{frameinfo.lineno} for model {args.model}!') +# add rgb mean/std of your model. +rgb_mean = '0,0,0' +rgb_std = '0,0,0' +# add layer names that shouldn't be quantized. +if logger: + frameinfo = getframeinfo(currentframe()) + logger.info(F'Please set proper excluded_sym_names inside this script below {frameinfo.filename}:{frameinfo.lineno} for model {args.model} if required!') +excluded_sym_names += [] +if exclude_first_conv: + excluded_sym_names += [] +``` + +Some tips on quantization configs: + +1. First, data, symbol file (custom-symbol.json) and parameter file (custom-0000.params) of FP32 symbolic model should be prepared. +2. Then, following command should be run to verify that FP32 symbolic model runs inference as expected. + +``` +# Launch FP32 Inference +python imagenet_inference.py --symbol-file=./model/custom-symbol.json --param-file=./model/custom-0000.params --rgb-mean=* --rgb-std=* --num-skipped-batches=* --batch-size=* --num-inference-batches=*--dataset=./data/* +``` + +3. Proper `rgb_mean`, `rgb_std` and `excluded_sym_names` should be added in `imagenet_gen_qsym_onednn.py` script. + +4. Run following command for quantization: + +``` +python imagenet_gen_qsym_onednn.py --model=custom --num-calib-batches=5 --calib-mode=naive +``` + +5. After quantization, the quantized symbol and parameter files will be saved in the `model/` directory. + +6. Finally, INT8 inference can be run: + +``` +# Launch INT8 Inference +python imagenet_inference.py --symbol-file=./model/resnet50_v1-quantized-10batches-entropy-symbol.json --param-file=./model/resnet50_v1-quantized-10batches-entropy-0000.params --benchmark + +# Launch dummy data Inference +bash ./launch_inference_onednn.sh -s ./model/*.json +``` \ No newline at end of file diff --git a/example/quantization/imagenet_gen_qsym_onednn.py b/example/quantization/imagenet_gen_qsym_onednn.py new file mode 100644 index 000000000000..bb75a056d1b5 --- /dev/null +++ b/example/quantization/imagenet_gen_qsym_onednn.py @@ -0,0 +1,272 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import os +import re +import sys +from inspect import currentframe, getframeinfo + +import mxnet as mx +from mxnet import gluon +from mxnet.contrib.quantization import quantize_net +from mxnet.gluon.data import DataLoader +from mxnet.gluon.data.vision import transforms +from mxnet.gluon.model_zoo.vision import get_model + +sys.path.append('../..') +from tools.rec2idx import IndexCreator + + +def download_calib_dataset(dataset_url, calib_dataset, logger=None): + if logger is not None: + logger.info('Downloading calibration dataset from %s to %s' % (dataset_url, calib_dataset)) + mx.test_utils.download(dataset_url, calib_dataset) + +def get_from_gluon(model_name, classes=1000, logger=None): + dir_path = os.path.dirname(os.path.realpath(__file__)) + model_path = os.path.join(dir_path, 'model') + if logger is not None: + logger.info('Converting model from Gluon-CV ModelZoo %s... into path %s' % (model_name, model_path)) + net = get_model(name=model_name, classes=classes, pretrained=True) + prefix = os.path.join(model_path, model_name) + return net, prefix + +def regex_find_excluded_symbols(patterns_dict, model_name): + for key, value in patterns_dict.items(): + if re.search(key, model_name) is not None: + return value + return None + +def get_exclude_symbols(model_name, exclude_first_conv): + """Grouped supported models at the time of commit: + - alexnet + - densenet121, densenet161 + - densenet169, densenet201 + - inceptionv3 + - mobilenet0.25, mobilenet0.5, mobilenet0.75, mobilenet1.0, + - mobilenetv2_0.25, mobilenetv2_0.5, mobilenetv2_0.75, mobilenetv2_1.0 + - resnet101_v1, resnet152_v1, resnet18_v1, resnet34_v1, resnet50_v1 + - resnet101_v2, resnet152_v2, resnet18_v2, resnet34_v2, resnet50_v2 + - squeezenet1.0, squeezenet1.1 + - vgg11, vgg11_bn, vgg13, vgg13_bn, vgg16, vgg16_bn, vgg19, vgg19_bn + """ + exclude_symbol_regex = { + 'mobilenet[^v]': ['mobilenet_hybridsequential0_flatten0_flatten0', 'mobilenet_hybridsequential0_globalavgpool2d0_fwd'], + 'mobilenetv2': ['mobilenetv2_hybridsequential1_flatten0_flatten0'], + # resnetv2_hybridsequential0_hybridsequential0_bottleneckv20_batchnorm0_fwd is excluded for the sake of accuracy + 'resnet.*v2': ['resnetv2_hybridsequential0_flatten0_flatten0', 'resnetv2_hybridsequential0_hybridsequential0_bottleneckv20_batchnorm0_fwd'], + 'squeezenet1': ['squeezenet_hybridsequential1_flatten0_flatten0'], + } + excluded_sym_names = regex_find_excluded_symbols(exclude_symbol_regex, model_name) + if excluded_sym_names is None: + excluded_sym_names = [] + if exclude_first_conv: + first_conv_regex = { + 'alexnet': ['alexnet_hybridsequential0_conv2d0_fwd'], + 'densenet': ['densenet_hybridsequential0_conv2d0_fwd'], + 'inceptionv3': ['inception3_hybridsequential0_hybridsequential0_conv2d0_fwd'], + 'mobilenet[^v]': ['mobilenet_hybridsequential0_conv2d0_fwd'], + 'mobilenetv2': ['mobilenetv2_hybridsequential0_conv2d0_fwd'], + 'resnet.*v1': ['resnetv1_hybridsequential0_conv2d0_fwd'], + 'resnet.*v2': ['resnetv2_hybridsequential0_conv2d0_fwd'], + 'squeezenet1': ['squeezenet_hybridsequential0_conv2d0_fwd'], + 'vgg': ['vgg_hybridsequential0_conv2d0_fwd'], + } + excluded_first_conv_sym_names = regex_find_excluded_symbols(first_conv_regex, model_name) + if excluded_first_conv_sym_names is None: + raise ValueError('Currently, model %s is not supported in this script' % model_name) + excluded_sym_names += excluded_first_conv_sym_names + return excluded_sym_names + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Generate a calibrated quantized model from a FP32 model with Intel oneDNN support') + parser.add_argument('--model', type=str, default='resnet50_v1', + help='model to be quantized. If no-pretrained is set then' + 'model must be provided to `model` directory in the same path' + 'as this python script') + parser.add_argument('--epoch', type=int, default=0, + help='number of epochs, default is 0') + parser.add_argument('--no-pretrained', action='store_true', default=False, + help='If enabled, will not download pretrained model from MXNet or Gluon-CV modelzoo.') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--calib-dataset', type=str, default='data/val_256_q90.rec', + help='path of the calibration dataset') + parser.add_argument('--image-shape', type=str, default='3,224,224', + help='number of channels, height and width of input image separated by comma') + parser.add_argument('--data-nthreads', type=int, default=0, + help='number of threads for data loading') + parser.add_argument('--num-calib-batches', type=int, default=10, + help='number of batches for calibration') + parser.add_argument('--exclude-first-conv', action='store_true', default=False, + help='excluding quantizing the first conv layer since the' + ' input data may have negative value which doesn\'t support at moment' ) + parser.add_argument('--shuffle-dataset', action='store_true', + help='shuffle the calibration dataset') + parser.add_argument('--calib-mode', type=str, default='entropy', + help='calibration mode used for generating calibration table for the quantized symbol; supports' + ' 1. none: no calibration will be used. The thresholds for quantization will be calculated' + ' on the fly. This will result in inference speed slowdown and loss of accuracy' + ' in general.' + ' 2. naive: simply take min and max values of layer outputs as thresholds for' + ' quantization. In general, the inference accuracy worsens with more examples used in' + ' calibration. It is recommended to use `entropy` mode as it produces more accurate' + ' inference results.' + ' 3. entropy: calculate KL divergence of the fp32 output and quantized output for optimal' + ' thresholds. This mode is expected to produce the best inference accuracy of all three' + ' kinds of calibration modes if the calibration dataset is representative enough of the' + ' inference dataset.') + parser.add_argument('--quantized-dtype', type=str, default='auto', + choices=['auto', 'int8', 'uint8'], + help='quantization destination data type for input data') + parser.add_argument('--quiet', action='store_true', default=False, + help='suppress most of log') + args = parser.parse_args() + ctx = mx.cpu(0) + logger = None + + if not args.quiet: + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + if logger: + logger.info(args) + logger.info('shuffle_dataset=%s' % args.shuffle_dataset) + logger.info('calibration mode set to %s' % args.calib_mode) + + calib_mode = args.calib_mode + + # download calibration dataset + if calib_mode != 'none': + idx_file_name = os.path.splitext(args.calib_dataset)[0] + '.idx' + if not os.path.isfile(idx_file_name): + download_calib_dataset('http://data.mxnet.io/data/val_256_q90.rec', args.calib_dataset) + creator = IndexCreator(args.calib_dataset, idx_file_name) + creator.create_index() + creator.close() + + # get image shape + image_shape = args.image_shape + data_shape = [(1,) + tuple(int(i) for i in image_shape.split(','))] + + # check if directory for output model exists + dir_path = os.path.dirname(os.path.realpath(__file__)) + dir_path = os.path.join(dir_path, 'model') + if not os.path.exists(dir_path): + os.mkdir(dir_path) # without try catch block as we expect to finish + # script if it fail + + # download model + if not args.no_pretrained: + if logger: + logger.info('Get pre-trained model from Gluon-CV modelzoo.') + logger.info('If you want to use custom model, please set --no-pretrained.') + net, prefix = get_from_gluon(model_name=args.model, classes=1000, logger=logger) + rgb_mean = '0.485,0.456,0.406' + rgb_std = '0.229,0.224,0.225' + epoch = 0 + net.hybridize() + net(mx.nd.zeros(data_shape[0])) # dummy forward pass to build graph + net.export(prefix) # save model + net.hybridize(active=False) # disable hybridization - it will be handled in quantization API + else: + prefix = os.path.join(dir_path, args.model) + epoch = args.epoch + net = gluon.SymbolBlock.imports("{}-symbol.json".format(prefix), ['data'], "{}-0000.params".format(prefix)) + + + # get batch size + batch_size = args.batch_size + if logger: + logger.info('batch size = %d for calibration' % batch_size) + + # get number of batches for calibration + num_calib_batches = args.num_calib_batches + if logger: + if calib_mode == 'none': + logger.info('skip calibration step as calib_mode is none') + else: + logger.info('number of batches = %d for calibration' % num_calib_batches) + + # get number of threads for decoding the dataset + data_nthreads = args.data_nthreads + + exclude_first_conv = args.exclude_first_conv + if args.quantized_dtype == "uint8": + if logger: + logger.info('quantized dtype is set to uint8, will exclude first conv.') + exclude_first_conv = True + excluded_sym_names = [] + if not args.no_pretrained: + excluded_sym_names += get_exclude_symbols(args.model, args.exclude_first_conv) + else: + if logger: + frameinfo = getframeinfo(currentframe()) + logger.info(F'Please set proper RGB configs inside this script below {frameinfo.filename}:{frameinfo.lineno} for model {args.model}!') + # add rgb mean/std of your model. + rgb_mean = '0,0,0' + rgb_std = '0,0,0' + # add layer names you donnot want to quantize. + if logger: + frameinfo = getframeinfo(currentframe()) + logger.info(F'Please set proper excluded_sym_names inside this script below {frameinfo.filename}:{frameinfo.lineno} for model {args.model} if required!') + excluded_sym_names += [] + if exclude_first_conv: + excluded_sym_names += [] + + if logger: + logger.info('These layers have been excluded %s' % excluded_sym_names) + logger.info('Input data shape = %s' % str(data_shape)) + logger.info('rgb_mean = %s' % rgb_mean) + logger.info('rgb_std = %s' % rgb_std) + + rgb_mean = [float(i) for i in rgb_mean.split(',')] + mean_args = {'mean_r': rgb_mean[0], 'mean_g': rgb_mean[1], 'mean_b': rgb_mean[2]} + rgb_std = [float(i) for i in rgb_std.split(',')] + std_args = {'std_r': rgb_std[0], 'std_g': rgb_std[1], 'std_b': rgb_std[2]} + if calib_mode == 'none': + if logger: + logger.info('Quantizing FP32 model %s' % args.model) + qsym = quantize_net(net, ctx=ctx, exclude_layers_match=excluded_sym_names, data_shapes=data_shape, + calib_mode=calib_mode, quantized_dtype=args.quantized_dtype, + logger=logger) + suffix = '-quantized' + else: + if logger: + logger.info('Creating DataLoader for reading calibration dataset') + dataset = mx.gluon.data.vision.ImageRecordDataset(args.calib_dataset) + transformer = transforms.Compose([transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=rgb_mean, std=rgb_std)]) + data_loader = DataLoader(dataset.transform_first(transformer), batch_size, shuffle=args.shuffle_dataset, num_workers=data_nthreads) + qsym = quantize_net(net, ctx=ctx, exclude_layers_match=excluded_sym_names, + calib_mode=calib_mode, calib_data=data_loader, num_calib_batches=num_calib_batches, + quantized_dtype=args.quantized_dtype, logger=logger) + if calib_mode == 'entropy': + suffix = '-quantized-%dbatches-entropy' % num_calib_batches + elif calib_mode == 'naive': + suffix = '-quantized-%dbatches-naive' % num_calib_batches + else: + raise ValueError('unknow calibration mode %s received, only supports `none`, `naive`, and `entropy`' + % calib_mode) + save_path = prefix + suffix + model_path, params_path = qsym.export(save_path, epoch) + if logger is not None: + logger.info(F'Saved quantized model into:\n{model_path}\n{params_path}') diff --git a/example/quantization/imagenet_inference.py b/example/quantization/imagenet_inference.py new file mode 100644 index 000000000000..fae4ab97a886 --- /dev/null +++ b/example/quantization/imagenet_inference.py @@ -0,0 +1,187 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +import argparse +import logging +import time + +import mxnet as mx +import numpy as np +from mxnet import gluon +from mxnet.gluon.data import DataLoader +from mxnet.gluon.data.vision import transforms + + +def download_dataset(dataset_url, dataset_dir, logger=None): + if logger is not None: + logger.info('Downloading dataset for inference from %s to %s' % (dataset_url, dataset_dir)) + mx.test_utils.download(dataset_url, dataset_dir) + + +def score(symblock, data, ctx, max_num_examples, skip_num_batches, logger=None): + metrics = [gluon.metric.create('acc')], + gluon.metric.create('top_k_accuracy', top_k=5)] + + # make sure that fp32 inference works on the same images as calibrated quantized model + logger.info('Skipping the first %d batches' % skip_num_batches) + + tic = time.time() + num = 0 + for i, input_data in enumerate(data): + if i < skip_num_batches: + continue + x = input_data[0].as_in_context(ctx) + label = input_data[1].as_in_context(ctx) + outputs = symblock.forward(x) + for m in metrics: + m.update(label, outputs) + num += batch_size + if max_num_examples is not None and num >= max_num_examples: + break + + speed = num / (time.time() - tic) + + if logger is not None: + logger.info('Finished inference with %d images' % num) + logger.info('Finished with %f images per second', speed) + for m in metrics: + logger.info(m.get()) + +def initialize_block_params(block, initializer): + for _, param in block.collect_params('.*gamma|.*moving_var|.*running_var').items(): + param.initialize(mx.init.Constant(1)) + for _, param in block.collect_params('.*beta|.*moving_mean|.*running_mean|.*bias').items(): + param.initialize(mx.init.Constant(0)) + for _, param in block.collect_params('.*weight').items(): + param.initialize(initializer) + +def benchmark_score(symblock, ctx, batch_size, warmup_batches, num_batches, data_layer_type): + if data_layer_type == "int8": + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.int8) + elif data_layer_type == 'uint8': + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.uint8) + else: # float32 + dshape = mx.io.DataDesc(name='data', shape=( + batch_size,) + data_shape, dtype=np.float32) + + # get data + if data_layer_type == "float32": + data = [mx.random.uniform(-1.0, 1.0, shape=shape, ctx=ctx, dtype=data_layer_type) + for _, shape in [dshape]] + else: + data = [mx.nd.full(shape=shape, val=127, ctx=ctx, dtype=data_layer_type) + for _, shape in [dshape]] + + # run + for i in range(warmup_batches+num_batches): + if i == warmup_batches: + tic = time.time() + outputs = symblock.forward(*data) + for output in outputs: + output.wait_to_read() + + # return num images per second + return num_batches * batch_size / (time.time() - tic) + +if __name__ == '__main__': + parser = argparse.ArgumentParser(description='Score a model on a dataset') + parser.add_argument('--ctx', type=str, default='cpu') + parser.add_argument('--benchmark', type=bool, default=False, help='dummy data benchmark') + parser.add_argument('--symbol-file', type=str, required=True, help='symbol file path') + parser.add_argument('--param-file', type=str, required=False, help='param file path') + parser.add_argument('--batch-size', type=int, default=32) + parser.add_argument('--dataset', type=str, required=False, help='dataset path') + parser.add_argument('--rgb-mean', type=str, default='0,0,0') + parser.add_argument('--rgb-std', type=str, default='1,1,1') + parser.add_argument('--image-shape', type=str, default='3,224,224') + parser.add_argument('--data-nthreads', type=int, default=60, help='number of threads for data decoding') + parser.add_argument('--num-skipped-batches', type=int, default=0, help='skip the number of batches for inference') + parser.add_argument('--num-inference-batches', type=int, required=True, help='number of images used for inference') + parser.add_argument('--num-warmup-batches', type=int, default=5, help='number of warmup batches used for benchmark') + parser.add_argument('--shuffle-dataset', action='store_true', default=True, + help='shuffle the score dataset') + parser.add_argument('--data-layer-type', type=str, default='float32', + choices=['float32', 'int8', 'uint8'], + help='data type for data layer (only with --benchmark)') + + args = parser.parse_args() + + logging.basicConfig() + logger = logging.getLogger('logger') + logger.setLevel(logging.INFO) + + if args.ctx == 'cpu': + ctx = mx.cpu(0) + elif args.ctx == 'gpu': + ctx = mx.gpu(0) + logger.warning('Notice that oneDNN optimized and quantized model may not work with GPU context') + else: + raise ValueError('ctx %s is not supported in this script' % args.ctx) + + symbol_file = args.symbol_file + param_file = args.param_file + data_nthreads = args.data_nthreads + + batch_size = args.batch_size + logger.info('batch size = %d for inference' % batch_size) + + rgb_mean = args.rgb_mean + logger.info('rgb_mean = %s' % rgb_mean) + rgb_mean = [float(i) for i in rgb_mean.split(',')] + rgb_std = args.rgb_std + logger.info('rgb_std = %s' % rgb_std) + rgb_std = [float(i) for i in rgb_std.split(',')] + + image_shape = args.image_shape + data_shape = tuple([int(i) for i in image_shape.split(',')]) + logger.info('Input data shape = %s' % str(data_shape)) + + data_layer_type = args.data_layer_type + + if not args.benchmark: + dataset = args.dataset + download_dataset('http://data.mxnet.io/data/val_256_q90.rec', dataset) + logger.info('Dataset for inference: %s' % dataset) + + dataset = mx.gluon.data.vision.ImageRecordDataset(dataset) + transformer = transforms.Compose([transforms.Resize(256), + transforms.CenterCrop(224), + transforms.ToTensor(), + transforms.Normalize(mean=rgb_mean, std=rgb_std)]) + data_loader = DataLoader(dataset.transform_first( + transformer), batch_size, shuffle=args.shuffle_dataset, num_workers=data_nthreads) + + # loading model + symblock = gluon.SymbolBlock.imports(symbol_file, ['data'], param_file) + + num_inference_images = args.num_inference_batches * batch_size + logger.info('Running model %s for inference' % symbol_file) + score(symblock, data_loader, ctx, max_num_examples=num_inference_images, + skip_num_batches=args.num_skipped_batches, logger=logger) + else: + # loading model + symblock = gluon.SymbolBlock.imports(symbol_file, ['data']) + initialize_block_params(symblock, mx.init.One()) + + logger.info(f'Running model {symbol_file} for inference.') + logger.info(f'Warmup batches: {args.num_warmup_batches}') + logger.info(f'Inference batches: {args.num_inference_batches}') + speed = benchmark_score(symblock, ctx, batch_size, + args.num_warmup_batches, args.num_inference_batches, data_layer_type) + logger.info('batch size %2d, image/sec: %f', batch_size, speed) diff --git a/example/quantization/launch_inference_onednn.sh b/example/quantization/launch_inference_onednn.sh new file mode 100755 index 000000000000..a98fb29ea6e3 --- /dev/null +++ b/example/quantization/launch_inference_onednn.sh @@ -0,0 +1,116 @@ +#!/bin/sh + +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +usage() +{ + echo "usage: bash ./launch_inference_onednn.sh [[[-s symbol_file ] [-b batch_size] [-iter iteraton] [-ins instance] [-c cores/instance]] | [-h]]" +} + +while [ $# -gt 0 ]; do + case "$1" in + --symbol | -s) + shift + SYMBOL=$1 + ;; + --batch-size | -b) + shift + BS=$1 + ;; + --iteration | -iter) + shift + ITERATIONS=$1 + ;; + --instance | -ins) + shift + INS=$1 + ;; + --core | -c) + shift + CORES=$1 + ;; + --help | -h) + usage + exit 1 + ;; + *) + usage + exit 1 + esac + shift +done + +NUM_SOCKET=`lscpu | grep 'Socket(s)' | awk '{print $NF}'` +NUM_NUMA_NODE=`lscpu | grep 'NUMA node(s)' | awk '{print $NF}'` +CORES_PER_SOCKET=`lscpu | grep 'Core(s) per socket' | awk '{print $NF}'` +NUM_CORES=$((CORES_PER_SOCKET * NUM_SOCKET)) +CORES_PER_NUMA=$((NUM_CORES / NUM_NUMA_NODE)) +echo "target machine has $NUM_CORES physical core(s) on $NUM_NUMA_NODE numa nodes of $NUM_SOCKET socket(s)." + +if [ -z $SYMBOL ]; then + echo "Error: Need a symbol file as input." +fi +if [ -z $INS ]; then + echo "Default: launch one instance per socket." + INS=$NUM_SOCKET +fi +if [ -z $CORES ]; then + echo "Default: divide full physical cores." + CORES=$((NUM_CORES / $INS)) +fi +if [ -z $BS ]; then + echo "Default: set batch size to 64." + BS=64 +fi +if [ -z $ITERATIONS ]; then + echo "Default: set iterations to 500." + ITERATIONS=500 +fi + +echo " benchmark configs" +echo " cores per instance: $CORES" +echo " total instances: $INS" +echo " batch size: $BS" +echo " iterations: $ITERATIONS" +echo "" + +rm BENCHMARK_*.log || echo "benchmarking..." + +i=0 +while [ "$i" -lt $INS ]; do + a=$((i * CORES)) + b=$((a + CORES - 1)) + memid=$((b/CORES_PER_NUMA % NUM_NUMA_NODE)) + LOG=BENCHMARK_$i.log + echo " Instance $i use $a-$b cores and mem $memid with $LOG" + KMP_AFFINITY=granularity=fine,noduplicates,compact,1,0 \ + OMP_NUM_THREADS=$CORES \ + nohup numactl --physcpubind=$a-$b --membind=$memid python imagenet_inference.py --symbol-file=$SYMBOL --batch-size=$BS --num-inference-batches=$ITERATIONS --benchmark=True > $LOG 2>&1 & + i=$(( i + 1 )) +done +wait + +fps=`grep image/sec BENCHMARK_*.log | awk '{ sum += $(NF) }; END { print sum }'` +if [ -z "$fps" ]; then + echo "FPS not found in benchmark log." + return 1 +fi +latency=$(awk "BEGIN {printf \"%.2f\", 1000*${BS}*${INS}/${fps}}") +echo "overall throughput (image/sec): $fps" +echo "latency per batch per instance (ms): $latency" +echo "benchmark finish:)" diff --git a/python/mxnet/contrib/__init__.py b/python/mxnet/contrib/__init__.py index ab4303f3d845..9b67d5f7e24d 100644 --- a/python/mxnet/contrib/__init__.py +++ b/python/mxnet/contrib/__init__.py @@ -29,4 +29,6 @@ from . import text from . import onnx from . import io +from . import quantization +from . import quantization as quant from . import tensorrt diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py new file mode 100644 index 000000000000..6ea74fe7f97d --- /dev/null +++ b/python/mxnet/contrib/quantization.py @@ -0,0 +1,938 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Quantization module for generating quantized (INT8) models from FP32 models.""" + +import abc +import ctypes +import logging +import os +import warnings +import numpy as np +import mxnet as mx +from ..base import _LIB, check_call, py_str +from ..base import c_array, c_str, mx_uint, mx_real_t, c_str_array +from ..base import SymbolHandle +from ..symbol import Symbol +from .. import ndarray +from ..io import DataDesc +from ..context import cpu, Context +from ..util import is_np_array + +def _quantize_params(qsym, params, min_max_dict): + """Given a quantized symbol and a dict of params that have not been quantized, + generate quantized params. Currently only supports quantizing the arg_params + with names of `weight` or `bias`, not aux_params. If `qsym` contains symbols + that are excluded from being quantized, their corresponding params will + not be quantized, but saved together with quantized params of the symbols that + have been quantized. + + Parameters + ---------- + qsym : Symbol + Quantized symbol from FP32 symbol. + params : dict of str->NDArray + min_max_dict: dict of min/max pairs of layers' output + """ + inputs_name = qsym.list_arguments() + quantized_params = {} + if is_np_array(): + quantize_fn = mx.npx.contrib_quantize + min_fn = lambda arr: mx.np.array([mx.np.min(arr)]) + max_fn = lambda arr: mx.np.array([mx.np.max(arr)]) + array_cls = mx.np + else: + quantize_fn = mx.nd.contrib.quantize + min_fn = mx.nd.min + max_fn = mx.nd.max + array_cls = mx.nd + + for name in inputs_name: + if name.endswith(('weight_quantize', 'bias_quantize')): + original_name = name[:-len('_quantize')] + param = params[original_name] + # pylint: disable=unbalanced-tuple-unpacking + param_min = min_fn(param) + param_max = max_fn(param) + val, vmin, vmax = quantize_fn(data=param, + min_range=param_min, + max_range=param_max, + out_type='int8') + quantized_params[name] = val + quantized_params[name+'_min'] = vmin + quantized_params[name+'_max'] = vmax + elif name in params: + quantized_params[name] = params[name] + elif name.endswith(('_min')): + output = name[: - len('_min')] + if output in min_max_dict: + quantized_params[name] = array_cls.array([min_max_dict[output][0]]) + elif name.endswith(('_max')): + output = name[: - len('_min')] + if output in min_max_dict: + quantized_params[name] = array_cls.array([min_max_dict[output][1]]) + return quantized_params + +def _quantize_symbol(sym, ctx, excluded_symbols=None, excluded_operators=None, + offline_params=None, quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise'): + """Given a symbol object representing a neural network of data type FP32, + quantize it into a INT8 network. + + Parameters + ---------- + sym : Symbol + FP32 neural network symbol. + ctx : Context + Defines the device that users want to run quantized symbol. + excluded_symbols : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + excluded_operators : list of strings + A list of strings representing the names of the operators that users want to excluding + from being quantized. + offline_params : list of strs + Names of the parameters that users want to quantize offline. It's always recommended to + quantize parameters offline so that quantizing parameters during the inference can be + avoided. + quantized_dtype: str + The quantized destination type for input data. + quantize_mode: str + The mode that quantization pass to apply. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. + """ + num_excluded_symbols = 0 + if excluded_symbols is not None: + assert isinstance(excluded_symbols, list) + num_excluded_symbols = len(excluded_symbols) + else: + excluded_symbols = [] + + num_excluded_ops = 0 + if excluded_operators is not None: + assert isinstance(excluded_operators, list) + num_excluded_ops = len(excluded_operators) + else: + excluded_operators = [] + + num_offline = 0 + offline = [] + if offline_params is not None: + num_offline = len(offline_params) + for k in offline_params: + offline.append(c_str(k)) + + out = SymbolHandle() + size = mx_uint() + calib_str = ctypes.POINTER(ctypes.c_char_p)() + check_call(_LIB.MXQuantizeSymbol(sym.handle, + ctypes.byref(out), + ctypes.byref(ctypes.c_int(ctx.device_typeid)), + mx_uint(num_excluded_symbols), + c_str_array(excluded_symbols), + mx_uint(num_excluded_ops), + c_str_array(excluded_operators), + mx_uint(num_offline), + c_array(ctypes.c_char_p, offline), + c_str(quantized_dtype), + ctypes.c_bool(True), + c_str(quantize_mode), + c_str(quantize_granularity), + ctypes.byref(size), + ctypes.byref(calib_str))) + calib_layers = [] + calib_layers = [py_str(calib_str[i]) for i in range(size.value)] + return Symbol(out), calib_layers + + +class CalibrationCollector(object): + """Base class for all other collectors used with quantization""" + __metaclass__ = abc.ABCMeta + + def __init__(self): + self.include_layers = None + self.min_max_dict = {} + + @abc.abstractmethod + def collect(self, name, op_name, arr): + """Function which is registered to Block as monitor callback. Names of layers + requiring calibration are stored in `self.include_layers` variable. + Parameters + ---------- + name : str + Node name from which collected data comes from + op_name : str + Operator name from which collected data comes from. Single operator + can have multiple inputs/ouputs nodes - each should have different name + arr : NDArray + NDArray containing data of monitored node + """ + + def post_collect(self): + """ Function called after collecting parameters. Returns dictionary of min and max values + for each calibrated layer. If not overriden, returns content of `self.min_max_dict`. + """ + return self.min_max_dict + + +class _LayerHistogramCollector(CalibrationCollector): + """Saves layer histogram in a dict with layer names as keys and lists of NDArrays as + values. The collected histogram will be used for calculating the optimal thresholds for + quantization using KL divergence. + """ + def __init__(self, quantized_dtype, num_bins=8001, include_layers=None, logger=None): + super(_LayerHistogramCollector, self).__init__() + self.hist_dict = {} + self.num_bins = num_bins + self.include_layers = include_layers + self.logger = logger + self.quantized_dtype = quantized_dtype + + def collect(self, name, op_name, arr): + """Callback function for collecting layer output NDArrays.""" + if name not in self.include_layers: + return + arr = arr.copyto(cpu()).asnumpy() + if self.logger: + self.logger.debug("Collecting layer %s histogram of shape %s" % (name, arr.shape)) + min_range = np.min(arr) + max_range = np.max(arr) + th = max(abs(min_range), abs(max_range)) + if name in self.hist_dict: + self.hist_dict[name] = self.combine_histogram(self.hist_dict[name], arr, min_range, max_range, th) + else: + hist, hist_edges = np.histogram(arr, bins=self.num_bins, range=(-th, th)) + self.hist_dict[name] = (hist, hist_edges, min_range, max_range, th) + + def post_collect(self): + min_max_dict = self.get_optimal_thresholds(self.hist_dict, self.quantized_dtype, logger=self.logger) + return min_max_dict + + @staticmethod + def combine_histogram(old_hist, arr, new_min, new_max, new_th): + """ Collect layer histogram for arr and combine it with old histogram. + """ + (old_hist, old_hist_edges, old_min, old_max, old_th) = old_hist + if new_th <= old_th: + hist, _ = np.histogram(arr, bins=len(old_hist), range=(-old_th, old_th)) + return (old_hist + hist, old_hist_edges, min(old_min, new_min), max(old_max, new_max), old_th) + else: + # Need to generate new histogram with new_th + old_num_bins = len(old_hist) + old_step = 2 * old_th / old_num_bins + half_increased_bins = int((new_th - old_th) // old_step + 1) + new_num_bins = half_increased_bins * 2 + old_num_bins + new_th = half_increased_bins * old_step + old_th + hist, hist_edges = np.histogram(arr, bins=new_num_bins, range=(-new_th, new_th)) + hist[half_increased_bins:new_num_bins - half_increased_bins] += old_hist + return (hist, hist_edges, min(old_min, new_min), max(old_max, new_max), new_th) + + # pylint: disable=line-too-long + @staticmethod + def get_optimal_threshold(hist_data, quantized_dtype, num_quantized_bins=255): + """Given a dataset, find the optimal threshold for quantizing it. + The reference distribution is `q`, and the candidate distribution is `p`. + `q` is a truncated version of the original distribution. + + Ref: http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + """ + (hist, hist_edges, min_val, max_val, _) = hist_data + num_bins = len(hist) + assert (num_bins % 2 == 1) + if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: + # We need to move negative bins to positive bins to fit uint8 range. + num_quantized_bins = num_quantized_bins * 2 + 1 + hist = ndarray.array(hist, ctx=cpu()) + hist_edges = ndarray.array(hist_edges, ctx=cpu()) + threshold, divergence = ndarray.contrib.calibrate_entropy(hist=hist, + hist_edges=hist_edges, + num_quantized_bins=num_quantized_bins) + threshold = threshold.asnumpy() + divergence = divergence.asnumpy() + return min_val, max_val, threshold, divergence + # pylint: enable=line-too-long + + @staticmethod + def get_optimal_thresholds(hist_dict, quantized_dtype, num_quantized_bins=255, logger=None): + """Given a ndarray dict, find the optimal threshold for quantizing each value of the key.""" + assert isinstance(hist_dict, dict) + if logger is not None: + logger.info('Calculating optimal thresholds for quantization using KL divergence' + ' with num_quantized_bins=%d' % num_quantized_bins) + th_dict = {} + # copy hist_dict keys since the keys() only returns a view in python3 + layer_names = list(hist_dict.keys()) + for name in layer_names: + assert name in hist_dict + min_val, max_val, th, divergence = \ + _LayerHistogramCollector.get_optimal_threshold(hist_dict[name], quantized_dtype, + num_quantized_bins=num_quantized_bins) + if min_val >= 0 and quantized_dtype in ['auto', 'uint8']: + th_dict[name] = (0, th) + else: + th_dict[name] = (-th, th) + del hist_dict[name] # release the memory + if logger: + logger.debug(f"layer={name}, min_val={min_val}, max_val={max_val}, th={th}, divergence={divergence}") + return th_dict + +class _LayerOutputMinMaxCollector(CalibrationCollector): + """Saves layer output min and max values in a dict with layer names as keys. + The collected min and max values will be directly used as thresholds for quantization. + """ + def __init__(self, quantized_dtype, include_layers=None, logger=None): + super(_LayerOutputMinMaxCollector, self).__init__() + self.min_max_dict = {} + self.quantized_dtype = quantized_dtype + self.include_layers = include_layers + self.logger = logger + + def collect(self, name, op_name, arr): + """Callback function for collecting min and max values from an NDArray.""" + if name not in self.include_layers: + return + arr = arr.copyto(cpu()).asnumpy() + min_range = np.min(arr) + max_range = np.max(arr) + if name in self.min_max_dict: + cur_min_max = self.min_max_dict[name] + self.min_max_dict[name] = (min(cur_min_max[0], min_range), + max(cur_min_max[1], max_range)) + else: + self.min_max_dict[name] = (min_range, max_range) + if self.logger: + self.logger.debug("Collecting layer %s min_range=%f, max_range=%f" + % (name, min_range, max_range)) + +def _calibrate_quantized_sym(qsym, min_max_dict): + """Given a dictionary containing the thresholds for quantizing the layers, + set the thresholds into the quantized symbol as the params of requantize operators. + """ + if min_max_dict is None or len(min_max_dict) == 0: + return qsym + num_layer_outputs = len(min_max_dict) + layer_output_names = [] + min_vals = [] + max_vals = [] + for k, v in min_max_dict.items(): + layer_output_names.append(k) + min_vals.append(v[0]) + max_vals.append(v[1]) + + calibrated_sym = SymbolHandle() + check_call(_LIB.MXSetCalibTableToQuantizedSymbol(qsym.handle, + mx_uint(num_layer_outputs), + c_str_array(layer_output_names), + c_array(ctypes.c_float, min_vals), + c_array(ctypes.c_float, max_vals), + ctypes.byref(calibrated_sym))) + return Symbol(calibrated_sym) + + +def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_batches=None, logger=None): + if not isinstance(data, mx.gluon.data.DataLoader): + raise ValueError('Only supports data as a type of DataLoader, while received type %s' + % str(type(data))) + sym_block.register_op_hook(collector.collect, monitor_all=True) + num_batches = 0 + for batch in data: + if not isinstance(batch, list): + batch = [batch] + batch = [b.as_in_context(mx.cpu()) for b in batch] + sym_block(*batch[:num_inputs]) + num_batches += 1 + if num_calib_batches is not None and num_batches >= num_calib_batches: + break + if logger is not None: + logger.info("Collected statistics from %d batches" % (num_batches)) + return num_batches + + + + +def _generate_list_of_data_desc(data_shapes, data_types): + """"Convert list ot tuples to list of DataDesc.""" + if isinstance(data_shapes, list): + if all(isinstance(x, DataDesc) for x in data_shapes): + return data_shapes + if all(isinstance(x, tuple) for x in data_shapes): + if len(data_shapes) == 1: + data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])] + else: + data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i], + dtype=data_types[i]) for i in range(len(data_shapes))] + return data_shapes + raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple') + + +def quantize_model(sym, arg_params, aux_params, data_names=('data',), + ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', + calib_data=None, num_calib_batches=None, + quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise', logger=None): + """User-level API for generating a quantized model from a FP32 model w/ or w/o calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + The quantization implementation adopts the TensorFlow's approach: + https://www.tensorflow.org/performance/quantization. + The calibration implementation borrows the idea of Nvidia's 8-bit Inference with TensorRT: + http://on-demand.gputechconf.com/gtc/2017/presentation/s7310-8-bit-inference-with-tensorrt.pdf + and adapts the method to MXNet. + + .. _`quantize_model_params`: + Parameters + ---------- + sym : str or Symbol + Defines the structure of a neural network for FP32 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + data_names : a list of strs + Data names required for creating a Module object to run forward propagation on the + calibration dataset. + ctx : Context + Defines the device that users want to run forward propagation on the calibration + dataset for collecting layer output statistics. Currently, only supports single context. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + excluded_op_names : list of strings + A list of strings representing the names of the operators that users want to excluding + from being quantized. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + calib_data : DataLoader + A DataLoader initialized by the calibration dataset. + num_calib_batches : int or None + The maximum number of batches that user would like to use for calibration. If not provided, + the whole calibration dataset will be used. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8', 'uint8' and 'auto'. + 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + quantize_mode : str + The mode that quantization pass to apply. Support 'full' and 'smart'. + 'full' means quantize all operator if possible. + 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + quantized_model: tuple + A tuple of quantized symbol, quantized arg_params, and aux_params. + """ + warnings.warn('WARNING: This will be deprecated please use quantize_net with Gluon models') + if excluded_sym_names is None: + excluded_sym_names = [] + if not isinstance(excluded_sym_names, list): + raise ValueError('excluded_sym_names must be a list of strings representing' + ' the names of the symbols that will not be quantized,' + ' while received type %s' % str(type(excluded_sym_names))) + + if excluded_op_names is None: + excluded_op_names = [] + if not isinstance(excluded_op_names, list): + raise ValueError('excluded_op_names must be a list of strings representing' + ' the names of the operators that will not be quantized,' + ' while received type %s' % str(type(excluded_op_names))) + + if logger: + os.environ['MXNET_QUANTIZATION_VERBOSE'] = '1' + logger.info('Quantizing symbol') + if quantized_dtype not in ('int8', 'uint8', 'auto'): + raise ValueError('unknown quantized_dtype %s received,' + ' expected `int8`, `uint8` or `auto`' % quantized_dtype) + if quantize_granularity not in ('tensor-wise', 'channel-wise'): + raise ValueError('unkonwn quantize_granularity %s received,' + ' expected `tensor-wise` or `channel-wise`.' % quantize_granularity) + qsym, calib_layers = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, + excluded_operators=excluded_op_names, + offline_params=list(arg_params.keys()), + quantized_dtype=quantized_dtype, + quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity) + min_max_dict = {} + if calib_mode is not None and calib_mode != 'none': + if not isinstance(ctx, Context): + raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) + if calib_data is None: + raise ValueError('calib_data must be provided when calib_mode=%s' % calib_mode) + if not isinstance(calib_data, mx.gluon.data.DataLoader): + raise ValueError('calib_data must be of DataLoader type when calib_mode=%s,' + ' while received type %s' % (calib_mode, str(type(calib_data)))) + + inputs = [mx.sym.var(dname) for dname in data_names] + param_dict = arg_params + param_dict.update(aux_params) + sym_block = mx.gluon.SymbolBlock(sym, inputs) + sym_block.load_dict(param_dict) + + if calib_mode == 'entropy': + collector = _LayerHistogramCollector(quantized_dtype=quantized_dtype, + include_layers=calib_layers, + logger=logger) + elif calib_mode == 'naive': + collector = _LayerOutputMinMaxCollector(quantized_dtype=quantized_dtype, + include_layers=calib_layers, + logger=logger) + + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, or `entropy`' % calib_mode) + + num_batches = _collect_layer_statistics(sym_block, calib_data, collector, + len(inputs), num_calib_batches, logger) + if logger: + logger.info('Collected layer output min/max values from FP32 model using %d batches' + % num_batches) + logger.info('Performing calibration post collecting operations') + + min_max_dict = collector.post_collect() + qsym = _calibrate_quantized_sym(qsym, min_max_dict) + + if logger: + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, min_max_dict) + + if is_np_array(): + qsym = qsym.as_np_ndarray() + + return qsym, qarg_params, aux_params + +def quantize_model_mkldnn(sym, arg_params, aux_params, data_names=('data',), + ctx=cpu(), excluded_sym_names=None, excluded_op_names=None, + calib_mode='entropy', calib_data=None, num_calib_batches=None, + quantized_dtype='int8', quantize_mode='smart', + quantize_granularity='tensor-wise', logger=None): + """User-level API for generating a fusion + quantized model from a FP32 model + w/ or w/o calibration with Intel MKL-DNN. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + + Parameters + ---------- + all + :ref:`As in quantize_model` + + + Returns + ------- + quantized_model: tuple + A tuple of quantized symbol, quantized arg_params, and aux_params. + """ + if not isinstance(ctx, Context): + raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) + if ctx.device_type != 'cpu': + raise ValueError( + 'quantize_model_mkldnn only support Intel cpu platform with MKL-DNN Backend') + + sym = sym.optimize_for(backend='MKLDNN_QUANTIZE') + + qsym, qarg_params, aux_params = quantize_model(sym=sym, arg_params=arg_params, aux_params=aux_params, + data_names=data_names, ctx=ctx, + excluded_sym_names=excluded_sym_names, + excluded_op_names=excluded_op_names, + calib_mode=calib_mode, calib_data=calib_data, + num_calib_batches=num_calib_batches, + quantized_dtype=quantized_dtype, quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity, logger=logger) + + qsym = qsym.optimize_for(backend='MKLDNN_QUANTIZE') + + return qsym, qarg_params, aux_params + +def quantize_graph(sym, arg_params, aux_params, ctx=cpu(), + excluded_sym_names=None, excluded_op_names=None, + calib_mode='entropy', quantized_dtype='int8', + quantize_mode='full', quantize_granularity='tensor-wise', + LayerOutputCollector=None, logger=None): + """User-level API for generating a quantized model from a FP32 model w/o calibration + and a collector for naive or entropy calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + Parameters + ---------- + sym : str or Symbol + Defines the structure of a neural network for FP32 data types. + ctx : Context + Defines the device that users want to run forward propagation on the calibration + dataset for collecting layer output statistics. Currently, only supports single context. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + excluded_sym_names : list of strings + A list of strings representing the names of the symbols that users want to excluding + from being quantized. + excluded_op_names : list of strings + A list of strings representing the names of the operators that users want to excluding + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + quantize_mode : str + The mode that quantization pass to apply. Support 'full' and 'smart'. + 'full' means quantize all operator if possible. + 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. + LayerOutputCollector : subclass of CalibrationCollector + For custom calibration method usage. + Passed object's include_layers attribute will be feed with names of layers which needs calibration + logger : Object + A logging object for printing information during the process of quantization. + Returns + ------- + quantized_model : tuple + A tuple of quantized symbol, quantized arg_params, aux_params and collector. + """ + if excluded_sym_names is None: + excluded_sym_names = [] + if not isinstance(excluded_sym_names, list): + raise ValueError('excluded_sym_names must be a list of strings representing' + ' the names of the symbols that will not be quantized,' + ' while received type %s' % str(type(excluded_sym_names))) + if not isinstance(ctx, Context): + raise ValueError('currently only supports single ctx, while received %s' % str(ctx)) + if logger: + os.environ['MXNET_QUANTIZATION_VERBOSE'] = '1' + logger.info('Quantizing graph') + if quantized_dtype not in ('int8', 'uint8', 'auto'): + raise ValueError('unknown quantized_dtype %s received,' + ' expected `int8`, `uint8` or `auto`' % quantized_dtype) + if quantize_granularity not in ('tensor-wise', 'channel-wise'): + raise ValueError('unkonwn quantize_granularity %s received,' + ' expected `tensor-wise` or `channel-wise`.' % quantize_granularity) + qsym, calib_layers = _quantize_symbol(sym, ctx, excluded_symbols=excluded_sym_names, + excluded_operators=excluded_op_names, + offline_params=list(arg_params.keys()), + quantized_dtype=quantized_dtype, + quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity) + + collector = None + if calib_mode is not None and calib_mode != 'none': + if calib_mode == 'entropy': + collector = _LayerHistogramCollector(quantized_dtype=quantized_dtype, + include_layers=calib_layers, logger=logger) + if logger: + logger.info( + 'Create a layer output collector for entropy calibration.') + elif calib_mode == 'naive': + collector = _LayerOutputMinMaxCollector(quantized_dtype=quantized_dtype, + include_layers=calib_layers, logger=logger) + if logger: + logger.info( + 'Create a layer output minmax collector for naive calibration') + elif calib_mode == 'custom' and LayerOutputCollector is not None: + if not isinstance(LayerOutputCollector, CalibrationCollector): + raise ValueError('LayerOutputCollecotr must be a subclass of a CalibrationCollector class,' + ' but it is %s' % LayerOutputCollector.__class__) + collector = LayerOutputCollector + + # Inject layer names that need calibration to collector + if hasattr(collector, "include_layers"): + if collector.include_layers is not None: + logger.info('Custom collector has set include_layers attribute. ' + 'Calibration layers not passed') + else: + collector.include_layers = calib_layers + if logger: + logger.info( + 'Create a custom layer output minmax collector for calibration') + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, `entropy` or `custom`' % calib_mode) + if logger: + logger.info('Collector created, please use set_monitor_callback' + ' to collect calibration information.') + + if logger: + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, min_max_dict={}) + + if is_np_array(): + qsym = qsym.as_np_ndarray() + + return qsym, qarg_params, aux_params, collector, calib_layers + +def calib_graph(qsym, arg_params, aux_params, collector, + calib_mode='entropy', logger=logging): + """User-level API for calibrating a quantized model using a filled collector. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + Parameters + ---------- + qsym : str or Symbol + Defines the structure of a neural network for INT8 data types. + arg_params : dict + Dictionary of name to `NDArray`. + aux_params : dict + Dictionary of name to `NDArray`. + collector : function + layer collector for naive or entropy calibration. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + logger : Object + A logging object for printing information during the process of quantization. + Returns + ------- + quantized_model : tuple + A tuple of calibrated symbol, quantized arg_params, aux_params. + """ + min_max_dict = {} + if calib_mode is not None and calib_mode != 'none': + if calib_mode in ('entropy', 'naive', 'custom'): + min_max_dict = collector.post_collect() + + else: + raise ValueError('unknown calibration mode %s received,' + ' expected `none`, `naive`, `entropy` or `custom`' % calib_mode) + qsym = _calibrate_quantized_sym(qsym, min_max_dict) + else: + raise ValueError('Please set calibration mode to naive, entropy or custom (with custom CalibrationCollector)') + + if logger: + logger.info('Quantizing parameters') + qarg_params = _quantize_params(qsym, arg_params, min_max_dict) + + if is_np_array(): + qsym = qsym.as_np_ndarray() + + return qsym, qarg_params, aux_params + +def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize_granularity='tensor-wise', + exclude_layers=None, exclude_layers_match=None, exclude_operators=None, + calib_data=None, data_shapes=None, calib_mode='none', + num_calib_batches=None, ctx=cpu(), LayerOutputCollector=None, logger=None): + """User-level API for Gluon users to generate a quantized SymbolBlock from a FP32 HybridBlock w/ or w/o calibration. + The backend quantized operators are only enabled for Linux systems. Please do not run + inference using the quantized models on Windows for now. + + Parameters + ---------- + network : Gluon HybridBlock + Defines the structure of a neural network for FP32 data types. + quantized_dtype : str + The quantized destination type for input data. Currently support 'int8' + , 'uint8' and 'auto'. 'auto' means automatically select output type according to calibration result. + Default value is 'int8'. + quantize_mode : str + The mode that quantization pass to apply. Support 'full' and 'smart'. + 'full' means quantize all operator if possible. + 'smart' means quantization pass will smartly choice which operator should be quantized. + quantize_granularity: str + The granularity of quantization, currently supports 'tensor-wise' and 'channel-wise' + quantization. The default value is 'tensor-wise'. + exclude_layers : list of strings + A list of strings representing the names of the symbols that users want to excluding + exclude_layers_match : list of strings + A list of strings wildcard matching the names of the symbols that users want to excluding + from being quantized. + exclude_operators : list of strings + A list of strings representing the names of the operators that users want to excluding + calib_data : gluon.DataLoader + A iterable data loading object. + data_shapes : list of DataDesc or list of tuple + A list of data shapes. Required if calib_data is not provided. In case of tuples, + the names of inputs are generated. + calib_mode : str + If calib_mode='none', no calibration will be used and the thresholds for + requantization after the corresponding layers will be calculated at runtime by + calling min and max operators. The quantized models generated in this + mode are normally 10-20% slower than those with calibrations during inference. + If calib_mode='naive', the min and max values of the layer outputs from a calibration + dataset will be directly taken as the thresholds for quantization. + If calib_mode='entropy' (default mode), the thresholds for quantization will be + derived such that the KL divergence between the distributions of FP32 layer outputs and + quantized layer outputs is minimized based upon the calibration dataset. + If calib_mode='custom', the provided LayerOutputCollector will be used to determine + the thresholds for quantization. For more information refer to CalibrationCollector + documentation. + num_calib_batches : int or None + The maximum number of batches that user would like to use for calibration. If not provided, + the whole calibration dataset will be used. + ctx : Context + Defines the device that users want to run forward propagation on the calibration + dataset for collecting layer output statistics. Currently, only supports single context. + LayerOutputCollector : subclass of CalibrationCollector + For `custom` calibration method usage. + Passed object's include_layers attribute will be feed with names of layers which needs calibration + logger : Object + A logging object for printing information during the process of quantization. + + Returns + ------- + network : Gluon SymbolBlock + Defines the structure of a neural network for INT8 data types. + """ + from ..gluon import SymbolBlock + + if ctx != mx.cpu(): + raise ValueError('Quantization currently supports only CPU context') + backend = 'MKLDNN_QUANTIZE' + + network.hybridize(static_alloc=False, static_shape=False) + data_types = None + if data_shapes is None: + if calib_data is None: + raise ValueError('At least one of data_shapes or calib_data has to be provided.') + + if isinstance(calib_data, mx.gluon.data.DataLoader): + x = iter(calib_data) + batch = next(x) + if isinstance(batch, list): + data_shapes = [b.shape for b in batch] + data_types = [b.dtype for b in batch] + else: + data_shapes = [batch.shape] + data_types = [batch.dtype] + else: + raise ValueError('calib_data expects mx.gluon.data.DataLoader') + + if data_types is None: + data_types = [mx_real_t] * len(data_shapes) + data_descs = _generate_list_of_data_desc(data_shapes, data_types) + + num_inputs = len(data_descs) + data_nd = [] + for desc in data_descs: + if is_np_array(): + data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype)) + else: + data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype)) + while True: + try: + network(*data_nd) + except TypeError as err: + if logger: + logger.warning(err) + logger.warning("Deduced input data descriptors failed to run forward pass." + " Trying again with one less input.") + del data_nd[-1] + num_inputs -= 1 + data_shapes = [b.shape for b in data_nd] + data_types = [b.dtype for b in data_nd] + data_descs = _generate_list_of_data_desc(data_shapes, data_types) + continue + else: + break + + symnet, params = network.export(None) + symnet = symnet.optimize_for(backend=backend) + + if is_np_array(): + symnet = symnet.as_np_ndarray() + + args, auxs = dict(), dict() + for k, v in params.items(): + ptype, pname = k[:3], k[4:] + if ptype == "arg": + args[pname] = v + else: + auxs[pname] = v + + if exclude_layers is None: + exclude_layers = [] + if exclude_layers_match is None: + exclude_layers_match = [] + if exclude_operators is None: + exclude_operators = [] + for name_match in exclude_layers_match: + for layers in list(symnet.get_internals()): + if layers.name.find(name_match) != -1: + exclude_layers.append(layers.name) + if logger: + logger.info('These layers have been excluded %s' % exclude_layers) + + qsym, qarg_params, aux_params, collector, _ = quantize_graph( + sym=symnet, arg_params=args, aux_params=auxs, ctx=ctx, + excluded_sym_names=exclude_layers, excluded_op_names=exclude_operators, + calib_mode=calib_mode, quantized_dtype=quantized_dtype, quantize_mode=quantize_mode, + quantize_granularity=quantize_granularity, LayerOutputCollector=LayerOutputCollector, + logger=logger) + + if calib_mode is not None and calib_mode != 'none': + if not isinstance(ctx, Context): + raise ValueError( + 'currently only supports single ctx, while received %s' % str(ctx)) + if calib_data is None: + raise ValueError( + 'calib_data must be provided when calib_mode=%s' % calib_mode) + if calib_mode in ['naive', 'entropy', 'custom']: + inputs = [mx.sym.var(desc.name) for desc in data_descs] + calib_net = SymbolBlock(symnet, inputs) + calib_net.load_dict(params, cast_dtype=True, dtype_source='saved') + calib_net.hybridize(static_alloc=False, static_shape=False) + num_batches = _collect_layer_statistics(calib_net, calib_data, collector, num_inputs, + num_calib_batches, logger) + + if logger: + logger.info('Collected layer output values from FP32 model using %d batches' + % num_batches) + + qsym, qarg_params, aux_params = calib_graph( + qsym=qsym, arg_params=args, aux_params=auxs, collector=collector, + calib_mode=calib_mode, logger=logger) + else: + raise ValueError('calib_mode has to be one of: naive, entropy, custom') + elif calib_mode is not None and calib_mode == 'none': + inputs = [mx.sym.var(desc.name) for desc in data_descs] + + net = SymbolBlock(qsym, inputs) + all_params = {('arg:%s' % k): v.as_in_context(cpu()) for k, v in qarg_params.items()} + all_params.update({('aux:%s' % k): v.as_in_context(cpu()) for k, v in aux_params.items()}) + net.load_dict(all_params, cast_dtype=True, dtype_source='saved') + net.optimize_for(data_nd, backend=backend, skip_infer=True) + return net diff --git a/src/common/utils.h b/src/common/utils.h index 5582f711ae1f..aa36dc1eae47 100644 --- a/src/common/utils.h +++ b/src/common/utils.h @@ -855,15 +855,6 @@ void ExecuteMonOutputCallback( size_t nid, const std::function &monitor_callback); -/*! - * \brief This is function can return the output names of a NodeEntry. - */ -static inline std::string GetOutputName(const nnvm::NodeEntry& e) { - nnvm::Symbol sym; - sym.outputs.push_back(e); - return sym.ListOutputNames()[0]; -} - inline mxnet::TShape CanonicalizeAxes(const mxnet::TShape& src) { // convert negative axes to positive values const int ndim = src.ndim(); diff --git a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h index ac2d3169340e..84299eeed2cf 100644 --- a/src/operator/nn/mkldnn/mkldnn_convolution-inl.h +++ b/src/operator/nn/mkldnn/mkldnn_convolution-inl.h @@ -42,6 +42,7 @@ struct MKLDNNConvParam : public dmlc::Parameter { bool with_sum; bool with_postsum_act; bool quantized; + bool dedup_sum; dmlc::optional min_calib_range; // min float value calculated from calibration dataset dmlc::optional max_calib_range; // max float value calculated from calibration dataset @@ -57,6 +58,8 @@ struct MKLDNNConvParam : public dmlc::Parameter { .describe("Add post activation after sum"); DMLC_DECLARE_FIELD(quantized).set_default(false) .describe("enable quantization"); + DMLC_DECLARE_FIELD(dedup_sum).set_default(false). + describe("deduplicated sum input"); DMLC_DECLARE_FIELD(min_calib_range) .set_default(dmlc::optional()) .describe("The minimum scalar value in the form of float32 obtained " diff --git a/src/operator/quantization/calibrate.cc b/src/operator/quantization/calibrate.cc index 3852eab979f8..5d4237da43ae 100644 --- a/src/operator/quantization/calibrate.cc +++ b/src/operator/quantization/calibrate.cc @@ -191,6 +191,7 @@ static inline bool CalibrateType(const nnvm::NodeAttrs& attrs, std::vector* } NNVM_REGISTER_OP(_contrib_calibrate_entropy) +.add_alias("_npx_contrib_calibrate_entropy") .describe(R"code(Provide calibrated min/max for input histogram. .. Note:: diff --git a/src/operator/quantization/quantize.cc b/src/operator/quantization/quantize.cc index 63467506b99b..a77832829b36 100644 --- a/src/operator/quantization/quantize.cc +++ b/src/operator/quantization/quantize.cc @@ -49,6 +49,7 @@ bool QuantizeStorageType(const nnvm::NodeAttrs& attrs, } NNVM_REGISTER_OP(_contrib_quantize) +.add_alias("_npx_contrib_quantize") .describe(R"code(Quantize a input tensor from float to `out_type`, with user-specified `min_range` and `max_range`. diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index ff75801af410..624ad047e061 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -27,6 +27,7 @@ #include #include #include +#include #include #include #include @@ -42,6 +43,24 @@ using nnvm::ObjectPtr; using nnvm::NodeEntry; using nnvm::Graph; +static inline std::string GetOutputName(const nnvm::Node* node, int index) { + // Map where Key is Op and Value is function registered by FListOutputNames + static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); + std::vector output_names; + + // If operator has registered FListOutputNames function + // get names by calling it with passed node attributes as argument + // else return output index as a string + if (flist_outputs.count(node->op())) { + output_names = flist_outputs[node->op()](node->attrs); + CHECK_GT(output_names.size(), index); + return output_names[index]; + } + + CHECK_GT(node->num_outputs(), index); + return std::to_string(index); +} + static inline size_t GetNumOutputs(ObjectPtr node) { // Get NumOutputs, check if current node has NumVisibleOutputs function, if yes, return // num_visible_outputs @@ -265,7 +284,6 @@ static void MarkQuantizedNodes(const Graph& src, } Graph QuantizeGraph(Graph &&src) { - static const auto& flist_outputs = nnvm::Op::GetAttr("FListOutputNames"); static const auto& need_requantize_map = Op::GetAttr("FNeedRequantize"); static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); @@ -322,18 +340,17 @@ Graph QuantizeGraph(Graph &&src) { // Or the output name is not ending with 'output', just put the output name here // to better align with calibration phase. No need to change name to weights/bias. std::string suffix = ""; + std::string new_name = e.node->attrs.name; + if (mirror_node->op() != nullptr) { - auto list_output_names_func = flist_outputs.get(e.node->op(), nullptr); - if (list_output_names_func != nullptr) { - std::vector names = list_output_names_func(e.node->attrs); - suffix = "_" + names[e.index]; - } else { - suffix = "_" + std::to_string(e.index); - } + std::string name = GetOutputName(e.node.get(), e.index); + suffix = "_" + name; + } else if (!offline_params.count(new_name)) { + new_name = node->attrs.name + "_" + e.node->attrs.name; } ObjectPtr quantize_node = InsertNode("_contrib_quantize_v2", - e.node->attrs.name + suffix + "_quantize", new_node, mirror_entry); + new_name + suffix + "_quantize", new_node, mirror_entry); quantize_node->attrs.dict["out_type"] = quantized_dtype; quantize_node->op()->attr_parser(&(quantize_node->attrs)); mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; @@ -486,22 +503,42 @@ Graph QuantizeGraph(Graph &&src) { Op::GetAttr("FNeedCalibrateInput"); static const auto& need_calib_output_map = Op::GetAttr("FNeedCalibrateOutput"); + + std::stack calib_variables; std::vector calib_nodes; DFSVisit(ret.outputs, [&](const ObjectPtr& node) { + if (node->op() && !calib_variables.empty()) { + if (reverse_mirror_map.count(node)) { + const std::string& var_name = calib_variables.top(); + const auto& fp32_in_node = reverse_mirror_map[node]; + for (const auto &input_node : fp32_in_node->inputs) { + if (var_name == input_node.node->attrs.name) { + calib_nodes.push_back(fp32_in_node->attrs.name + "_" + var_name); + calib_variables.pop(); + break; + } + } + } + } if (need_calib_input_map.count(node->op())) { const auto calib_idx = need_calib_input_map[node->op()](node->attrs); for (const auto &idx : calib_idx) { if (reverse_mirror_map.count(node)) { - calib_nodes.push_back(common::GetOutputName( - {reverse_mirror_map[node], node->inputs[idx].index, node->inputs[idx].version})); + const auto& fp32_in_node = reverse_mirror_map[node]; + std::string name = GetOutputName(fp32_in_node.get(), node->inputs[idx].index); + calib_nodes.push_back(fp32_in_node->attrs.name + "_" + name); } else { const auto& e = node->inputs[idx]; if (e.node->is_variable()) { - calib_nodes.push_back(e.node->attrs.name); + // monitor callback join operator name and variable name as observable node, + // utilize fact that we're using DFS and put variable name on stack to + // find operator node name for this variable node + calib_variables.emplace(e.node->attrs.name); } else { if (reverse_mirror_map.count(e.node)) { const auto& fp32_in_node = reverse_mirror_map.at(e.node); - calib_nodes.push_back(common::GetOutputName({fp32_in_node, e.index, e.version})); + std::string name = GetOutputName(fp32_in_node.get(), e.index); + calib_nodes.push_back(fp32_in_node->attrs.name + "_" + name); } else { LOG(FATAL) << "Can't find calibration node for " << node->attrs.name; } @@ -512,10 +549,12 @@ Graph QuantizeGraph(Graph &&src) { const auto calib_idx = need_calib_output_map[node->op()](node->attrs); for (const auto& idx : calib_idx) { if (reverse_mirror_map.count(node)) { - calib_nodes.push_back( - common::GetOutputName({reverse_mirror_map[node], static_cast(idx), 0})); + const auto& fp32_in_node = reverse_mirror_map[node]; + std::string name = GetOutputName(fp32_in_node.get(), static_cast(idx)); + calib_nodes.push_back(fp32_in_node->attrs.name + "_" + name); } else { - calib_nodes.push_back(common::GetOutputName({node, static_cast(idx), 0})); + std::string name = GetOutputName(node.get(), static_cast(idx)); + calib_nodes.push_back(node->attrs.name + "_" + name); } } } @@ -527,12 +566,22 @@ Graph QuantizeGraph(Graph &&src) { static inline void SetCalibTableForEntry( const NodeEntry& e, const ObjectPtr& node, const std::unordered_map>& calib_table) { - std::string out_data_name = common::GetOutputName(e); + std::string out_name = GetOutputName(e.node.get(), e.index); + std::string full_node_name = e.node->attrs.name; + + if (!e.node->is_variable()) { + full_node_name += "_" + out_name; + } else { + const std::string suffix = "_quantize"; + full_node_name = node->attrs.name; + full_node_name = std::string(full_node_name.begin(), full_node_name.end() - suffix.size()); + } + const std::string prefix = "quantized_"; if (e.node->attrs.name.rfind(prefix, 0) == 0) { - out_data_name = out_data_name.substr(prefix.size()); + full_node_name = full_node_name.substr(prefix.size()); } - const auto calib_table_iter = calib_table.find(out_data_name); + const auto calib_table_iter = calib_table.find(full_node_name); static int verbose = dmlc::GetEnv("MXNET_QUANTIZATION_VERBOSE", 0); if (calib_table_iter != calib_table.end()) { if (verbose) { diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index 9a30386723be..d8b355964f72 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -64,6 +64,7 @@ static OpStatePtr CreateQuantizeV2State(const nnvm::NodeAttrs& attrs, Context ct } NNVM_REGISTER_OP(_contrib_quantize_v2) +.add_alias("_npx_contrib_quantize_v2") .describe(R"code(Quantize a input tensor from float to `out_type`, with user-specified `min_calib_range` and `max_calib_range` or the input range collected at runtime. diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv.cc b/src/operator/subgraph/mkldnn/mkldnn_conv.cc index 0868e0c8da21..4f703c311483 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv.cc +++ b/src/operator/subgraph/mkldnn/mkldnn_conv.cc @@ -78,6 +78,9 @@ static void UpdateConvWeightBias(NDArray *weight, NDArray *bias, bool no_bias, } static inline size_t GetInSumIndex(const MKLDNNConvFusionParam ¶m) { + if (param.full_conv_param.mkldnn_param.dedup_sum) { + return 0; + } return 2 + (param.full_conv_param.conv_param.no_bias ? 0 : 1) + (param.full_conv_param.mkldnn_param.with_bn ? 4 : 0); } @@ -127,8 +130,15 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, 2 + (conv_param.no_bias ? 0 : 1) + (mkldnn_param.with_bn ? 4 : 0) + (mkldnn_param.with_sum ? 1 : 0) + (mkldnn_param.quantized ? 2 + (full_conv_param.mkldnn_param.with_sum ? 2 : 0) : 0); + // When dedup is on, in_data is used to calculate sum instead of in_sum + if (mkldnn_param.dedup_sum) { + input_size -= 1; + if (mkldnn_param.quantized) { + input_size -= 2; + } + } CHECK_EQ(inputs.size(), input_size); - size_t idx = 0; + index_t idx = 0; auto in_data = idx++; auto in_weight = idx++; @@ -137,17 +147,22 @@ void SgMKLDNNConvOperator::Forward(const OpContext &ctx, auto in_beta = mkldnn_param.with_bn ? (idx++) : 0; auto in_mean = mkldnn_param.with_bn ? (idx++) : 0; auto in_var = mkldnn_param.with_bn ? (idx++) : 0; - auto in_sum = mkldnn_param.with_sum ? (idx++) : 0; + auto in_sum = mkldnn_param.with_sum ? (mkldnn_param.dedup_sum? in_data : idx++) : -1; float data_min = mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; float data_max = mkldnn_param.quantized ? inputs[idx++].data().dptr()[0] : 0.0; - float sum_min = (mkldnn_param.with_sum && mkldnn_param.quantized) - ? inputs[idx++].data().dptr()[0] - : 0.0; - float sum_max = (mkldnn_param.with_sum && mkldnn_param.quantized) - ? inputs[idx++].data().dptr()[0] - : 0.0; + float sum_min = 0.0f; + float sum_max = 0.0f; + if (mkldnn_param.with_sum && mkldnn_param.quantized) { + if (mkldnn_param.dedup_sum) { + sum_min = data_min; + sum_max = data_max; + } else { + sum_min = inputs[idx++].data().dptr()[0]; + sum_max = inputs[idx++].data().dptr()[0]; + } + } CHECK_EQ(input_size, idx); bool has_bias = mkldnn_param.with_bn || !conv_param.no_bias; NDArray data = inputs[in_data]; @@ -370,7 +385,8 @@ static uint32_t SgMKLDNNConvNumInputs(const NodeAttrs &attrs) { auto const ¶m = nnvm::get(attrs.parsed); auto num_input = DefaultSubgraphOpNumInputs(attrs); if (param.full_conv_param.mkldnn_param.quantized) - return num_input + 2 + (param.full_conv_param.mkldnn_param.with_sum ? 2 : 0); + return num_input + 2 + (param.full_conv_param.mkldnn_param.with_sum + && !param.full_conv_param.mkldnn_param.dedup_sum ? 2 : 0); else return num_input; } @@ -458,13 +474,14 @@ static std::vector SgMKLDNNConvListInputNames(const NodeAttrs &attr input_names.emplace_back("mean"); input_names.emplace_back("var"); } - if (param.full_conv_param.mkldnn_param.with_sum) { + auto &mkldnn_param = param.full_conv_param.mkldnn_param; + if (mkldnn_param.with_sum && !mkldnn_param.dedup_sum) { input_names.emplace_back("sum"); } if (param.full_conv_param.mkldnn_param.quantized) { input_names.emplace_back("data_min"); input_names.emplace_back("data_max"); - if (param.full_conv_param.mkldnn_param.with_sum) { + if (mkldnn_param.with_sum && !mkldnn_param.dedup_sum) { input_names.emplace_back("sum_min"); input_names.emplace_back("sum_max"); } @@ -498,7 +515,7 @@ static void FilterMinMaxIndice(const MKLDNNConvParam &mkldnn_param, std::unordered_set *minmax_indice) { base_out_shapes->push_back(out_shapes->at(0)); size_t last = in_shapes->size() - 1; - if (mkldnn_param.with_sum) { + if (mkldnn_param.with_sum && !mkldnn_param.dedup_sum) { minmax_indice->insert(last); minmax_indice->insert(last - 1); minmax_indice->insert(last - 2); @@ -560,14 +577,15 @@ static bool SgMKLDNNConvInferType(const nnvm::NodeAttrs &attrs, int orig_data = base_in_types[0]; base_in_types[0] = mshadow::kFloat32; int orig_sum = base_in_types[0]; - if (param.full_conv_param.mkldnn_param.with_sum) { + auto &mkldnn_param = param.full_conv_param.mkldnn_param; + if (param.full_conv_param.mkldnn_param.with_sum && !mkldnn_param.dedup_sum) { auto sum_index = GetInSumIndex(param); orig_sum = base_in_types[sum_index]; base_in_types[sum_index] = mshadow::kFloat32; } bool result = DefaultSubgraphOpType(attrs, &base_in_types, &base_out_types); base_in_types[0] = orig_data; - if (param.full_conv_param.mkldnn_param.with_sum) { + if (param.full_conv_param.mkldnn_param.with_sum && !mkldnn_param.dedup_sum) { auto sum_index = GetInSumIndex(param); base_in_types[sum_index] = orig_sum; } @@ -634,7 +652,8 @@ static bool SgMKLDNNConvOpStorageType(const nnvm::NodeAttrs &attrs, std::vector> SgMKLDNNConvInplaceOption( const NodeAttrs &attrs) { auto const ¶m = nnvm::get(attrs.parsed); - if (param.full_conv_param.mkldnn_param.with_sum) { + if (param.full_conv_param.mkldnn_param.with_sum && + !param.full_conv_param.mkldnn_param.dedup_sum) { return std::vector>{{GetInSumIndex(param), 0}}; } else { return std::vector>(); diff --git a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h index cae2fcdc7331..6eaa930f5422 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_conv_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_conv_property.h @@ -98,22 +98,23 @@ class SgMKLDNNConvSelector : public SubgraphSelector { // Use status_ machine to do selection. The status_ change is // kStart -> kBN -> kSum -> kSuccess + const auto node_name = new_node.op()->name; switch (status_) { case kStart: - if ((!disable_conv_bn_) && new_node.op()->name == "BatchNorm") { + if ((!disable_conv_bn_) && node_name == "BatchNorm") { matched_list_.push_back(&new_node); status_ = kBN; return true; } case kBN: - if ((!disable_conv_sum_) && new_node.op()->name == "elemwise_add") { + if ((!disable_conv_sum_) && (node_name == "elemwise_add" || node_name == "_npi_add")) { matched_list_.push_back(&new_node); status_ = kSum; return true; } case kSum: default: - if ((!disable_conv_act_) && new_node.op()->name == "Activation") { + if ((!disable_conv_act_) && node_name == "Activation") { const ActivationParam ¶m = nnvm::get(new_node.attrs.parsed); if ((quantize_ && SupportQuantizedMKLDNNAct(param)) || @@ -123,7 +124,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { status_ = kSuccess; return true; } - } else if ((!disable_conv_act_) && new_node.op()->name == "LeakyReLU") { + } else if ((!disable_conv_act_) && node_name == "LeakyReLU") { const LeakyReLUParam ¶m = nnvm::get(new_node.attrs.parsed); if (param.act_type == leakyrelu::kLeakyReLU || @@ -133,7 +134,7 @@ class SgMKLDNNConvSelector : public SubgraphSelector { status_ = kSuccess; return true; } - } else if ((!disable_conv_act_) && new_node.op()->name == "clip") { + } else if ((!disable_conv_act_) && node_name == "clip") { if (!(quantize_ && (status_ == kSum))) { // TODO(zhennan): doesn't support int8 conv+sum+relu6 at moment. To support this, we // need to fuse conv+sum first, and calibrate with it. Then fuse int8 relu6 into fused @@ -215,11 +216,10 @@ class SgMKLDNNConvProperty : public SubgraphProperty { } else if (sub_name == "BatchNorm") { node_name << "bn_"; n->attrs.dict["with_bn"] = "true"; - } else if (sub_name == "elemwise_add") { + } else if (sub_name == "elemwise_add" || sub_name == "_npi_add") { node_name << "add_"; n->attrs.dict["with_sum"] = "true"; _with_sum = true; - } else if (sub_name == "Activation" || sub_name == "LeakyReLU" || sub_name == "clip") { node_name << "act_"; if (!_with_sum) { @@ -259,10 +259,20 @@ class SgMKLDNNConvProperty : public SubgraphProperty { std::vector *orig_input_entries) const override { auto sym = n->attrs.subgraphs[0]; std::unordered_set node_sets; + nnvm::Node* conv_input = nullptr; DFSVisit(sym->outputs, [&](const nnvm::ObjectPtr &node) { if (node->is_variable()) return; node_sets.insert(node.get()); - if (node->op()->name == "elemwise_add") { + if (node->op()->name == "Convolution") { + conv_input = node->inputs[0].node.get(); + } else if (node->op()->name == "elemwise_add" || node->op()->name == "_npi_add") { + if (dedup_subgraph && + (conv_input == node->inputs[1].node.get() || + conv_input == node->inputs[0].node.get())) { + n->attrs.dict["dedup_sum"] = "true"; + n->op()->attr_parser(&(n->attrs)); + return; + } // Make sure n is the left operand of sum, if not, // switch sum operands sequence to ensure that // the extra sum operand stays in the last of inputs. diff --git a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h index 085dd494dcd2..e2bb7ca6413d 100644 --- a/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h +++ b/src/operator/subgraph/mkldnn/mkldnn_post_quantize_property.h @@ -50,6 +50,7 @@ class SgMKLDNNPostQuantizeSelector : public SubgraphSelector { SgMKLDNNPostQuantizeSelector() { support_requantize_fusion_op_name.insert("_sg_mkldnn_conv"); support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add"); + support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add"); } bool Select(const nnvm::Node &n) override { @@ -62,7 +63,8 @@ class SgMKLDNNPostQuantizeSelector : public SubgraphSelector { matched_list.push_back(&n); return true; } - } else if (n.op()->name == "_contrib_quantized_elemwise_add") { + } else if (n.op()->name == "_contrib_quantized_elemwise_add" || + n.op()->name == "_contrib_quantized_npi_add") { status = kStart; matched_list.clear(); matched_list.push_back(&n); @@ -121,6 +123,7 @@ class SgMKLDNNPostQuantizeProperty : public SubgraphProperty { SgMKLDNNPostQuantizeProperty() { support_requantize_fusion_op_name.insert("_sg_mkldnn_conv"); support_requantize_fusion_op_name.insert("_contrib_quantized_elemwise_add"); + support_requantize_fusion_op_name.insert("_contrib_quantized_npi_add"); } static SubgraphPropertyPtr Create() { static const std::string &name = "MKLDNN post-quantization optimization pass"; diff --git a/tests/python/mkl/test_quantization_mkldnn.py b/tests/python/mkl/test_quantization_mkldnn.py new file mode 100644 index 000000000000..d3251ea604bc --- /dev/null +++ b/tests/python/mkl/test_quantization_mkldnn.py @@ -0,0 +1,31 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import sys +import mxnet as mx + +os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] = '1' +os.environ['MXNET_SUBGRAPH_BACKEND'] = 'NONE' +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../quantization')) +from test_quantization import * + +if __name__ == '__main__': + import pytest + pytest.main() + del os.environ['ENABLE_MKLDNN_QUANTIZATION_TEST'] + del os.environ['MXNET_SUBGRAPH_BACKEND'] diff --git a/tests/python/mkl/test_subgraph.py b/tests/python/mkl/test_subgraph.py index 2f09854d754c..a88ca555a4bc 100644 --- a/tests/python/mkl/test_subgraph.py +++ b/tests/python/mkl/test_subgraph.py @@ -20,9 +20,13 @@ import mxnet as mx import numpy as np import unittest -import ctypes import pytest -from mxnet.test_utils import assert_almost_equal +import ctypes + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.append(os.path.join(curr_path, '../unittest/')) +from mxnet.contrib import quant +from mxnet.test_utils import assert_almost_equal, assert_almost_equal_with_err, DummyIter def test_float64_fallback(): sym = mx.sym.FullyConnected( @@ -40,6 +44,921 @@ def test_float64_fallback(): ex.outputs[0].wait_to_read() +OP_NAME='op_name' +QUANTIZED_OP_NAME='quantized_op_name' +SG_PASS_NAME='MKLDNN' +QUANTIZE_SG_PASS_NAME='MKLDNN_QUANTIZE' +config = { + 'conv': { + OP_NAME: 'sg_mkldnn_conv', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_conv' + }, + 'fc': { + OP_NAME: 'sg_mkldnn_fully_connected', + QUANTIZED_OP_NAME: 'quantized_sg_mkldnn_fully_connected' + } +} + +DATA_SHAPE=[(64, 4, 10, 10), (4, 3, 24, 24), (1, 16, 32, 32)] +fc_post_ops_list=['relu', 'sigmoid', 'tanh', 'softrelu', + 'square', 'square_root', 'abs', 'exp', 'bounded_relu'] + + +def initialize_block_params(block, initializer): + for name, param in block.collect_params('.*gamma|.*moving_var|.*running_var').items(): + param.initialize(mx.init.Constant(1)) + for name, param in block.collect_params('.*beta|.*moving_mean|.*running_mean|.*bias').items(): + param.initialize(mx.init.Constant(0)) + for name, param in block.collect_params('.*weight').items(): + param.initialize(initializer) + +def collect_block_args_aux(block, sym): + arg_params, aux_params = dict(), dict() + for k, v in block.collect_params().items(): + if k in sym.list_arguments(): + arg_params[k]= v._reduce() + elif k in sym.list_auxiliary_states(): + aux_params[k]= v._reduce() + return arg_params, aux_params + +def check_qsym_calibrated(qsym, out_type, name='conv'): + quantized_op_name = 'quantized_' + name + assert ''.join(qsym.attr_dict().keys()).find(quantized_op_name) != -1 + for k, v in qsym.attr_dict().items(): + if k.find('_quantize') != -1: + assert v['out_type'] == out_type + if k.find(quantized_op_name) != -1: + if quantized_op_name.startswith("quantized_sg_mkldnn_fully_connected") and 'enable_float_output' in v: + continue + assert 'min_calib_range' in v + assert 'max_calib_range' in v + +def check_qsym_scale_align(qsym): + assert ''.join(qsym.attr_dict().keys()).find('quantized_sg_mkldnn_conv') != -1 + init = False + for k, v in qsym.attr_dict().items(): + if k.find('quantized_sg_mkldnn_conv') != -1: + assert 'min_calib_range' in v + assert 'max_calib_range' in v + if not init: + min_calib_range = v['min_calib_range'] + max_calib_range = v['max_calib_range'] + init = True + else: + assert min_calib_range == v['min_calib_range'] + assert max_calib_range == v['max_calib_range'] + + +def check_qsym_dummy_forward(qsym, data, data_shape): + inputs = mx.sym.var('data', dtype='float32') + sym_block = mx.gluon.SymbolBlock(qsym, inputs) + initialize_block_params(sym_block, mx.init.One()) + + outputs = sym_block(data) + for output in outputs: + output.wait_to_read() + return outputs + +def check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data): + param_dict = {('arg:%s' % k): v.as_in_context(mx.current_context()) for k, v in qarg_params.items()} + param_dict.update({('aux:%s' % k): v.as_in_context(mx.current_context()) for k, v in qaux_params.items()}) + + # create SymbolBlock + net = mx.gluon.SymbolBlock(qsym, mx.sym.var('data')) + net.load_dict(param_dict, cast_dtype=True, dtype_source='saved') + net.reset_ctx(ctx = mx.current_context()) + net.hybridize() + outputs = net(data) + for output in outputs: + output.wait_to_read() + return outputs + +def check_quantize(sym, data_shape, out_type, name='conv', + check_calibration=True, check_scale_align=False): + quantize_granularity_list = ['tensor-wise'] + if name == 'fc': + quantize_granularity_list += ['channel-wise'] + + if name in config: + name = config[name][OP_NAME] + sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + + inputs = mx.sym.var('data', dtype='float32') + sym_block = mx.gluon.SymbolBlock(sym, inputs) + initialize_block_params(sym_block, mx.init.Normal(0.5)) + + min_value = -1 if out_type != 'uint8' else 0 + data = mx.random.uniform(min_value, 1.0, shape=data_shape, dtype='float32', ctx=mx.current_context()) + + outputs = sym_block(data) + for output in outputs: + output.wait_to_read() + ref_out = outputs + arg_params, aux_params = collect_block_args_aux(sym_block, sym) + + excluded_sym_names = [] + excluded_op_names = [] + + calib_data = mx.gluon.data.DataLoader(data, batch_size=1) + for quantize_granularity in quantize_granularity_list: + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + excluded_sym_names=excluded_sym_names, + excluded_op_names=excluded_op_names, + quantized_dtype=out_type, + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1, + quantize_mode='full', + quantize_granularity=quantize_granularity) + qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + if check_calibration: + check_qsym_calibrated(qsym, out_type, name=name) + if check_scale_align: + check_qsym_scale_align(qsym) + + quantized_out = check_qsym_gluon_forward(qsym, qarg_params, qaux_params, data) + for i in range(len(ref_out)): + min_range = mx.nd.min(ref_out[i]).asscalar() + max_range = mx.nd.max(ref_out[i]).asscalar() + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(quantized_out[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) + check_qsym_dummy_forward(qsym, data, data_shape) + + +@pytest.mark.parametrize('qdtype', ['uint8', 'int8', 'auto']) +def test_quantize_whole_model_with_forward(qdtype): + batch_size = 4 + data_shape = (batch_size, 4, 10, 10) + data = mx.sym.Variable('data') + conv0 = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv0') + sym = mx.sym.Convolution(conv0, kernel=(1, 1), num_filter=16, name='conv1') + + sym_block = mx.gluon.SymbolBlock(outputs=sym, inputs=data) + initialize_block_params(sym_block, mx.init.Normal(0.5)) + + in_data = mx.random.uniform(0.0 if qdtype=='uint8' else -1.0, 1.0, shape=data_shape) + ref_out = sym_block(in_data) + + excluded_layers = [] + + calib_data = mx.nd.random.uniform(0.0 if qdtype=='uint8' else -1.0, 1.0, shape=data_shape) + calib_data = mx.gluon.data.DataLoader(calib_data, batch_size=batch_size) + qsym = mx.contrib.quantization.quantize_net(sym_block, + ctx=mx.current_context(), + exclude_layers=excluded_layers, + quantized_dtype=qdtype, + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1, + quantize_mode='full') + + outputs = qsym(in_data) + for output in outputs: + output.wait_to_read() + + for i in range(len(ref_out)): + min_range = mx.nd.min(ref_out[i]).asscalar() + max_range = mx.nd.max(ref_out[i]).asscalar() + atol = 0.1 * max(abs(min_range), abs(max_range)) + assert_almost_equal_with_err(outputs[i].asnumpy(), ref_out[i].asnumpy(), rtol=0.1, atol=atol, etol=0.2) + + + +def check_fusion(sym, data_shape, attrs_dict, check_fp32_fusion=True, check_quantization=True, + out_types=['uint8', 'int8', 'auto'], dedup_subgraph=True): + if check_fp32_fusion: + data_min = -1.0 + data_max = 1.0 + if ''.join(sym.get_internals().list_outputs()).find('sqrt') != -1: + check_quantization = False + data_min = 0 + + sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=dedup_subgraph, skip_infer=True) + for name, attrs in attrs_dict.items(): + if name in config: + op_name = config[name][OP_NAME] + else: + op_name = name + assert ''.join(sym_sg.get_internals().list_outputs()).find(op_name) != -1 + if len(attrs): + found = False + for k, v in sym_sg.attr_dict().items(): + if k.find(op_name) != -1: + found = True + for attr_name, attr_value in attrs.items(): + assert v[attr_name].lower() == attr_value.lower() + assert found + arg_shapes, _, aux_shapes = sym.infer_shape() + aux_names = sym.list_auxiliary_states() + arg_names = sym.list_arguments() + arg_dict = {name : mx.nd.random.uniform(data_min, data_max, shape=shape, dtype='float32') + for shape, name in zip(arg_shapes, arg_names)} + + aux_dict = {name : mx.nd.random.uniform(shape=shape, dtype='float32') for shape, name in zip(aux_shapes, aux_names)} + + exe = sym._bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null') + exe.forward() + + exe_sg = sym_sg._bind(ctx=mx.current_context(), args=arg_dict, aux_states=aux_dict, grad_req='null') + exe_sg.forward() + for i in range(len(exe.outputs)): + assert_almost_equal(exe.outputs[i].asnumpy(), exe_sg.outputs[i].asnumpy(), rtol=1e-3, atol=1e-1) + + if check_quantization: + # fp32 to int8 + for out_type in out_types: + check_quantize(sym, data_shape, out_type, name=name) + +def check_neg_fusion(syms, attrs_name=None, excluded_attrs=None, + date_shape=(4,4,10,10), name='conv'): + op_name = config[name][OP_NAME] + + for sym, attrs, excluded_attr in zip(syms, attrs_name, excluded_attrs): + sym_sg = sym.optimize_for(SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + exe_sg = sym_sg._simple_bind(mx.cpu(), data=date_shape, grad_req='null') + + attrs_dict = sym_sg.attr_dict() + for k, v in attrs_dict.items(): + if k.find(op_name) != -1: + for attr in attrs: + assert v[attr] == 'true' + for exc_attr in excluded_attr: + assert exc_attr not in v.keys() + +def head_symbol(data_shape): + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight = mx.symbol.Variable('weight', dtype='float32') + return data, weight + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +def test_pos_single_conv(no_bias, data_shape): +# single conv fusion case + attr = {'conv': []} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + check_fusion(conv, data_shape, attr) + +# conv + bn fusion case +def conv_bn(no_bias, data_shape): + attr = {'conv': {'with_bn': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + return bn1, attr + +# conv + act fusion case +def conv_act(no_bias, data_shape, alg): + attr = {'conv': {'with_act': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + if alg == "relu6": + relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=conv, slope=0.25, act_type='leaky') + elif alg == "gelu": + relu = mx.symbol.LeakyReLU(data=conv, act_type='gelu') + else: + relu = mx.symbol.Activation(data=conv, name=alg, act_type=alg) + return relu, attr + +# conv + act + sum fusion case +def conv_act_sum(no_bias, data_shape, alg): + attr = {'conv': {'with_act': 'true', 'with_sum': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + if alg == "relu6": + relu = mx.symbol.clip(data=conv, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=conv, slope=0.25, act_type='leaky') + elif alg == "gelu": + relu = mx.symbol.LeakyReLU(data=conv, act_type='gelu') + else: + relu = mx.symbol.Activation(data=conv, name=alg, act_type=alg) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + sum = relu + conv1 + return sum, attr + + +# conv + add fusion case +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +def test_pos_conv_add(no_bias, data_shape): + attr = {'conv': {'with_sum': 'true'}} + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = conv1 + pool + check_fusion(sum, data_shape, attr) + + +# conv + add fusion case 2 +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +def test_pos_conv_add2(no_bias, data_shape): + attr = {'conv': {'with_sum': 'true'}} + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + conv2 = mx.symbol.Convolution(data=data, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + pool = mx.sym.Pooling(data=conv2, kernel=(1, 1), pool_type='avg', name='pool') + sum = pool + conv1 + check_fusion(sum, data_shape, attr) + + +# conv + bn + act fusion case +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('alg,quantize', [ + ("relu", True), + ("sigmoid", True), + ("tanh", True), + ("softrelu", True), + ("relu6", True), + ("leakyrelu", True), + ("gelu", True) +]) +@pytest.mark.parametrize('no_bias', [True, False]) +def test_pos_conv_bn_act(no_bias, data_shape, alg, quantize): + attr = {'conv': {'with_bn': 'true', 'with_act': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + if alg == "relu6": + relu = mx.symbol.clip(data=bn1, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=bn1, slope=0.25, act_type='leaky') + elif alg == "gelu": + relu = mx.symbol.LeakyReLU(data=bn1, act_type='gelu') + else: + relu = mx.symbol.Activation(data=bn1, name=alg, act_type=alg) + check_fusion(relu, data_shape, attr, check_quantization=quantize) + + +# conv + bn + add + act fusion case +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('alg,quantize', [ + ("relu", True), + ("sigmoid", True), + ("tanh", True), + #("softrelu", True), #TODO(bgawrych): failing fusion check - difference in random single element + ("relu6", False), + ("leakyrelu", True), + ("gelu", False) +]) +@pytest.mark.parametrize('no_bias', [True, False]) +def test_pos_conv_bn_sum_act(no_bias, data_shape, alg, quantize): + attr = {'conv': {'with_sum': 'true', 'with_postsum_act': 'true', 'with_bn': 'true'}} + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=no_bias) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1)) + sum1 = bn1 + conv1 + if alg == "relu6": + relu = mx.symbol.clip(data=sum1, name='relu6', a_min=0, a_max=6) + elif alg == "leakyrelu": + relu = mx.symbol.LeakyReLU(data=sum1, slope=0.25, act_type='leaky') + elif alg == "gelu": + relu = mx.symbol.LeakyReLU(data=sum1, act_type='gelu') + else: + relu = mx.symbol.Activation(data=sum1, name=alg, act_type=alg) + check_fusion(relu, data_shape, attr, check_quantization=quantize) + +def run_sym_block(model, data_nd, dedup_subgraph): + data = mx.symbol.Variable('data', shape=data_nd.shape, dtype='float32') + sym = model.optimize_for(backend='MKLDNN', dedup_subgraph = dedup_subgraph, skip_infer = True) + sym_block = mx.gluon.SymbolBlock(sym, data) + initialize_block_params(sym_block, mx.init.One()) + return sym_block(data_nd) + +def conv_bn_sum(data_shape, reverse_sum_order): + attr = {'sg_mkldnn_conv_bn_add_0' : {'with_bn': 'true'}} + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight = mx.symbol.Variable('conv_weight', dtype='float32') + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=4, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn = mx.symbol.BatchNorm(data=conv, name="bn") + sum = bn + data if reverse_sum_order else data + bn + return sum, attr + + +@pytest.mark.parametrize('reverse_sum_order', [True, False]) +@pytest.mark.parametrize('dedup_subgraph', [True, False]) +def test_conv_bn_sum(reverse_sum_order, dedup_subgraph): + data_shape=(64, 4, 10, 10) + net, attrs = conv_bn_sum(data_shape=data_shape, reverse_sum_order=reverse_sum_order) + check_fusion(net, data_shape, attrs, out_types=['int8', 'auto'], dedup_subgraph=dedup_subgraph) + + +@pytest.mark.parametrize('reverse_sum_order', [False, True]) +@pytest.mark.parametrize('model_name', ['conv_bn_sum', 'mobilenetv2_struct']) +def test_dedup(reverse_sum_order, model_name): + shape = (64, 4, 10, 10) + data = mx.symbol.Variable('data', shape=shape, dtype='float32') + data_nd = mx.random.uniform(-1, 1, shape=shape, ctx=mx.cpu()) + if (model_name == 'mobilenetv2_struct'): + model, _ = mobilenetv2_struct(data_shape=shape, reverse_sum_order=reverse_sum_order) + else: + model, _ = conv_bn_sum(data_shape=shape, reverse_sum_order=reverse_sum_order) + out = run_sym_block(model, data_nd, dedup_subgraph = False) + out_dedup = run_sym_block(model, data_nd, dedup_subgraph = True) + assert_almost_equal(out.asnumpy(), out_dedup.asnumpy(), rtol=1e-3, atol=1e-1) + + +# single concat case +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('input_num,dim', [ + (2, -1), + (2, 1), + (4, 2), + (4, 3) +]) +@pytest.mark.parametrize('out_type', ['int8', 'auto']) +def test_pos_single_concat(data_shape, input_num, dim, out_type): + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + inputs = [] + for i in range(input_num): + inputs.append(data) + concat = mx.symbol.Concat(*inputs, name="concat", dim=dim) + check_quantize(concat, data_shape, out_type, name='conv', + check_calibration=False) + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('out_type', ['int8', 'auto']) +def test_pos_single_concat_pos_neg(data_shape, out_type): + data, weight = head_symbol(data_shape) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=4, + kernel=(1, 1), stride=(1, 1), no_bias=True) + relu = mx.symbol.Activation(data=conv, name='relu', act_type='relu') + inputs = [data, relu] + concat = mx.symbol.Concat(*inputs, name="concat", dim=1) + check_quantize(concat, data_shape, out_type, name='', check_calibration=False) + + +# concat scale alignment case +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('out_type', ['int8', 'auto']) +def test_pos_concat_scale_align(data_shape, out_type): + data, weight = head_symbol(data_shape) + conv1 = mx.symbol.Convolution(data=data, weight=weight, name='conv1', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=True) + conv2 = mx.symbol.Convolution(data=data, weight=weight * 2, name='conv2', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=True) + conv3 = mx.symbol.Convolution(data=data, weight=weight * 3, name='conv3', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=True) + conv4 = mx.symbol.Convolution(data=data, weight=weight * 4, name='conv4', num_filter=64, + kernel=(3, 3), stride=(1, 1), no_bias=True) + concat = mx.symbol.Concat(*[conv1, conv2, conv3, conv4], name="concat", dim=1) + check_quantize(concat, data_shape, out_type, check_calibration=True, + check_scale_align=True) + + +# mobilenetv2 case +def mobilenetv2_struct(data_shape, reverse_sum_order=False): + attr = {'sg_mkldnn_conv_bn_0' : {'with_bn': 'true'}} + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight1 = mx.symbol.Variable('conv1_weight', dtype='float32') + weight2 = mx.symbol.Variable('conv2_weight', dtype='float32') + conv1 = mx.symbol.Convolution(data=data, weight=weight1, name='conv1', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1") + conv2 = mx.symbol.Convolution(data=bn1, weight=weight2, name='conv2', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2") + sum = bn2 + bn1 if reverse_sum_order else bn1 + bn2 + return sum, attr + +def tail_neg_symbol(sym1, sym2): + fc1 = mx.sym.FullyConnected(data=sym1, num_hidden=10, flatten=True, name='fc1') + fc2 = mx.sym.FullyConnected(data=sym2, num_hidden=10, flatten=True, name='fc2') + concat = mx.sym.Concat(*[fc1, fc2], name="concat") + sym = mx.sym.softmax(data=concat, name='softmax') + return sym + +# conv + bn can't be fusion case +# eg.1 +# conv --------- > bn +# | +# | +# -------------> [custom op] +def neg_conv_bn(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn1 = mx.symbol.BatchNorm(data=conv, name="bn1") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(bn1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + relu can't be fusion case +# eg.1 +# conv -----------> relu +# | +# | +# ---------------> [custom op] +def neg_conv_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + relu = mx.symbol.Activation(data=conv, name='relu', act_type="relu") + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(relu, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + +# conv + add can't be fusion case +# eg.1 +# ---------------> [custom op] +# | +# | +# conv -----------> add +# | +# | +# added ------------> +def neg_conv_add(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + val = mx.symbol.Variable('addval') + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool, [added op] = val) + conv = mx.symbol.Convolution(data=data, weight=weight, name='conv', num_filter=64, kernel=(3, 3), stride=(1, 1)) + sum1 = conv + val + pool = mx.sym.Pooling(data=conv, kernel=(4, 4), pool_type='avg', name='pool') + sym = tail_neg_symbol(sum1, pool) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append('with_sum') + return syms, attrs, excluded_attrs + +# conv + bn + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +# +# eg.2 +# --------------> [custom op] +# | +# conv -----------> bn -----------> relu +def neg_conv_bn_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool11) + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + relu11 = mx.symbol.Activation(data=bn11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append([]) + + # eg.2 ([custom op] = pool) + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + relu21 = mx.symbol.Activation(data=bn21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_act']) + return syms, attrs, excluded_attrs + +# conv + bn + add + relu can't be fusion case +# eg.1 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.2 +# -------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +# +# eg.3 +# --------------> [custom op] +# | +# conv -----------> bn -----------> add -----------> relu +def neg_conv_bn_add_relu(data_shape): + syms = [] + attrs = [] + excluded_attrs = [] + addVal = mx.symbol.Variable('addval') + data, weight = head_symbol(data_shape) + + # eg.1 + conv11 = mx.symbol.Convolution(data=data, weight=weight, name='conv11', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn11 = mx.symbol.BatchNorm(data=conv11, name="bn11") + sum11 = bn11 + addVal + relu11 = mx.symbol.Activation(data=sum11, name='relu11', act_type="relu") + pool11 = mx.sym.Pooling(data=conv11, kernel=(4, 4), pool_type='avg', name='pool11') + sym1 = tail_neg_symbol(relu11, pool11) + + syms.append(sym1) + attrs.append([]) + excluded_attrs.append(['with_sum', 'with_postsum_act', 'with_bn']) + + # eg.2 + conv21 = mx.symbol.Convolution(data=data, weight=weight, name='conv21', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn21 = mx.symbol.BatchNorm(data=conv21, name="bn21") + sum21 = bn21 + addVal + relu21 = mx.symbol.Activation(data=sum21, name='relu21', act_type="relu") + pool21 = mx.sym.Pooling(data=bn21, kernel=(4, 4), pool_type='avg', name='pool21') + sym2 = tail_neg_symbol(relu21, pool21) + + syms.append(sym2) + attrs.append(['with_bn']) + excluded_attrs.append(['with_sum', 'with_postsum_act']) + + # eg.3 + conv31 = mx.symbol.Convolution(data=data, weight=weight, name='conv31', num_filter=64, kernel=(3, 3), stride=(1, 1)) + bn31 = mx.symbol.BatchNorm(data=conv31, name="bn31") + sum31 = bn31 + addVal + relu31 = mx.symbol.Activation(data=sum31, name='relu31', act_type="relu") + pool31 = mx.sym.Pooling(data=sum31, kernel=(4, 4), pool_type='avg', name='pool31') + sym3 = tail_neg_symbol(relu31, pool31) + + syms.append(sym3) + attrs.append(['with_bn', 'with_sum']) + excluded_attrs.append(['with_postsum_act']) + return syms, attrs, excluded_attrs + +def single_fc(no_bias, data_shape, flatten=True): + attr = {'fc': {}} + data, weight = head_symbol(data_shape) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + return fc, attr + +# fc + eltwise fusion case +def fc_eltwise(no_bias, data_shape, flatten=True, alg='relu'): + assert alg in fc_post_ops_list + + attr = {'fc': {'with_eltwise': 'true'}} + data, weight = head_symbol(data_shape) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + if alg in ['relu', 'sigmoid', 'tanh', 'softrelu']: + sym = mx.symbol.Activation(data=fc, name='act', act_type=alg) + elif alg == 'square': + sym = mx.symbol.square(data=fc, name='square') + elif alg == 'square_root': + sym = mx.symbol.sqrt(data=fc, name='sqrt') + elif alg == 'abs': + sym = mx.symbol.abs(data=fc, name='abs') + elif alg == 'exp': + sym = mx.symbol.exp(data=fc, name='exp') + else: + sym = mx.symbol.clip(data=fc, name='bounded_relu', a_min=0, a_max=1.0) + + return sym, attr + +# fc + relu can't be fusion case +# eg.1 +# fc -----------> relu +# | +# | +# ---------------> [custom op] +def neg_fc_relu(no_bias, data_shape, flatten=True): + syms = [] + attrs = [] + excluded_attrs = [] + data, weight = head_symbol(data_shape) + + # eg.1 ([custom op] = pool) + fc = mx.symbol.FullyConnected(name='fc', data=data, weight=weight, num_hidden=64, + no_bias=no_bias, flatten=flatten) + relu = mx.symbol.Activation(data=fc, name='relu', act_type="relu") + sigmoid = mx.symbol.Activation(data=fc, name='sigmoid', act_type="sigmoid") + sym = tail_neg_symbol(relu, sigmoid) + + syms.append(sym) + attrs.append([]) + excluded_attrs.append([]) + return syms, attrs, excluded_attrs + + +def test_pos_conv_act(): + act_list = {"relu": True, + "sigmoid": True, + "tanh": True, + "softrelu": True, + "relu6": True, + "leakyrelu": True, + "gelu": True} + for data_shape in DATA_SHAPE: + for (alg, quantize) in act_list.items(): + net, attrs = conv_act(False, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) + net, attrs = conv_act(True, data_shape, alg) + check_fusion(net, data_shape, attrs, check_quantization=quantize) + + +def test_pos_conv_bn(): + for data_shape in DATA_SHAPE: + net, attrs = conv_bn(False, data_shape) + check_fusion(net, data_shape, attrs) + net, attrs = conv_bn(True, data_shape) + check_fusion(net, data_shape, attrs) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('reverse_sum_order', [True, False]) +@pytest.mark.parametrize('dedup_subgraph', [True, False]) +def test_mobilenetv2_struct(data_shape, reverse_sum_order, dedup_subgraph): + net, attrs = mobilenetv2_struct(data_shape, reverse_sum_order=reverse_sum_order) + check_fusion(net, data_shape, attrs, out_types=['int8', 'auto'], dedup_subgraph=dedup_subgraph) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_neg_conv_bn(data_shape): + syms, attrs, excluded_attrs = neg_conv_bn(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_neg_conv_relu(data_shape): + syms, attrs, excluded_attrs = neg_conv_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_neg_conv_add(data_shape): + syms, attrs, excluded_attrs = neg_conv_add(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_neg_conv_bn_relu(data_shape): + syms, attrs, excluded_attrs = neg_conv_bn_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +def test_neg_conv_bn_add_relu(data_shape): + syms, attrs, excluded_attrs = neg_conv_bn_add_relu(data_shape) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +@pytest.mark.parametrize('flatten', [True, False]) +def test_single_fc(data_shape, no_bias, flatten): + syms, attrs = single_fc(no_bias, data_shape, flatten) + check_fusion(syms, data_shape, attrs, check_quantization=flatten) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +@pytest.mark.parametrize('flatten', [True, False]) +@pytest.mark.parametrize('alg', fc_post_ops_list) +def test_fc_eltwise(data_shape, no_bias, flatten, alg): + syms, attrs = fc_eltwise(no_bias, data_shape, flatten, alg) + check_fusion(syms, data_shape, attrs, check_quantization=flatten) + + +@pytest.mark.parametrize('data_shape', DATA_SHAPE) +@pytest.mark.parametrize('no_bias', [True, False]) +@pytest.mark.parametrize('flatten', [True, False]) +def test_neg_fc_relu(data_shape, no_bias, flatten): + syms, attrs, excluded_attrs = neg_fc_relu(no_bias, data_shape, flatten) + check_neg_fusion(syms, attrs, excluded_attrs, data_shape, name='fc') + + +@pytest.mark.parametrize('data_min,data_max,weight_min,weight_max', [ + (-1, 1, 0, 0), + (-1, 1, -1e-6, +1e-6), + (0, 0, 1, 1), + (-1e-6, +1e-6, -1, 1), + (-1e-6, +1e-6, -1e-6, +1e-6), + (0, 0, 0, 0) +]) +def test_quantized_conv_bias_overflow(data_min, data_max, weight_min, weight_max): + data_shape = (1, 32, 2, 2) + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight = mx.symbol.Variable('weight', dtype='float32') + bias = mx.symbol.Variable('bias', dtype='float32') + sym = mx.symbol.Convolution(data=data, weight=weight, bias=bias, name='conv', num_filter=64, + kernel=(1, 1), stride=(1, 1)) + data_nd = mx.random.uniform(data_min, data_max, shape=data_shape, ctx=mx.cpu()) + weight_nd = mx.random.uniform(weight_min, weight_max, shape=[64, 32, 1, 1], ctx=mx.cpu()) + bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu()) + arg_params = { + 'weight': weight_nd, + 'bias': bias_nd + } + + ex = sym._bind(mx.cpu(), arg_params, args_grad=None) + ex.forward(data = data_nd) + ex.outputs[0].wait_to_read() + sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + + calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=data_shape[0]) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params={}, + ctx=mx.cpu(), + excluded_sym_names=None, + excluded_op_names=None, + quantized_dtype='int8', + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1, + quantize_mode='full') + qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + qarg_params['data'] = data_nd + qex = qsym._bind(mx.cpu(), qarg_params, args_grad=None) + qex.forward() + qex.outputs[0].wait_to_read() + assert_almost_equal_with_err(ex.outputs[0].asnumpy(), qex.outputs[0].asnumpy(), + rtol=1e-2, atol=1e-2, etol=0.01) + + +@pytest.mark.parametrize('data_min,data_max,weight_min,weight_max', [ + (-1, 1, 0, 0), + (-1, 1, -1e-6, +1e-6), + (0, 0, 1, 1), + (-1e-6, +1e-6, -1, 1), + (-1e-6, +1e-6, -1e-6, +1e-6), + (0, 0, 0, 0) +]) +def test_quantized_fc_bias_overflow(data_min, data_max, weight_min, weight_max): + data_shape = (1, 32) + data = mx.symbol.Variable('data', shape=data_shape, dtype='float32') + weight = mx.symbol.Variable('weight', dtype='float32') + bias = mx.symbol.Variable('bias', dtype='float32') + sym = mx.symbol.FullyConnected(data=data, weight=weight, bias=bias, name='fc', num_hidden=64) + data_nd = mx.random.uniform(data_min, data_max, shape=data_shape, ctx=mx.cpu()) + weight_nd = mx.random.uniform(weight_min, weight_max, shape=[64, 32], ctx=mx.cpu()) + bias_nd = mx.random.uniform(-1, +1, shape=[64], ctx=mx.cpu()) + arg_params = { + 'weight': weight_nd, + 'bias': bias_nd + } + + ex = sym._bind(mx.cpu(), arg_params, args_grad=None) + ex.forward(data = data_nd) + ex.outputs[0].wait_to_read() + sym_sg = sym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + + calib_data = mx.gluon.data.DataLoader(data_nd, batch_size=1) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym_sg, + arg_params=arg_params, + aux_params={}, + ctx=mx.cpu(), + excluded_sym_names=None, + excluded_op_names=None, + quantized_dtype='int8', + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1, + quantize_mode='full') + qarg_params['data'] = data_nd + qsym = qsym.optimize_for(QUANTIZE_SG_PASS_NAME, dedup_subgraph=True, skip_infer=True) + qex = qsym._bind(mx.cpu(), qarg_params, args_grad=None) + qex.forward() + qex.outputs[0].wait_to_read() + assert_almost_equal_with_err(ex.outputs[0].asnumpy(), qex.outputs[0].asnumpy(), + rtol=1e-2, atol=1e-2, etol=0.01) + @pytest.mark.parametrize('axis', [0, 1, 2, 3]) def test_bn_relu_fusion(axis): dummy_data = mx.nd.uniform(-1.0, 1.0, shape=(32, 3, 224, 224)) diff --git a/tests/python/quantization/common.py b/tests/python/quantization/common.py new file mode 120000 index 000000000000..dccb90b10675 --- /dev/null +++ b/tests/python/quantization/common.py @@ -0,0 +1 @@ +../unittest/common.py \ No newline at end of file diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py new file mode 100644 index 000000000000..948b2d3b8cee --- /dev/null +++ b/tests/python/quantization/test_quantization.py @@ -0,0 +1,1207 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""Some of the tests using CUDNN require a special GPU instruction called dp4a. +Ref: http://images.nvidia.com/content/pdf/tesla/184457-Tesla-P4-Datasheet-NV-Final-Letter-Web.pdf +""" +import os +import mxnet as mx +import numpy as np +from mxnet.gluon.model_zoo import vision +from mxnet.test_utils import assert_almost_equal, assert_exception, rand_ndarray, rand_shape_nd, same, DummyIter +from common import xfail_when_nonstandard_decimal_separator +from mxnet.io import NDArrayIter +import unittest +import operator + + +def initialize_block_params(block, initializer): + for name, param in block.collect_params('.*gamma|.*running_var|.*moving_var').items(): + param.initialize(mx.init.Constant(1)) + for name, param in block.collect_params('.*beta|.*bias|.*moving_mean|.*running_mean').items(): + param.initialize(mx.init.Constant(0)) + for name, param in block.collect_params('.*weight').items(): + param.initialize(initializer) + +def collect_block_args_aux(block, sym): + arg_params, aux_params = dict(), dict() + for k, v in block.collect_params().items(): + if k in sym.list_arguments(): + arg_params[k]= v._reduce() + elif k in sym.list_auxiliary_states(): + aux_params[k]= v._reduce() + return arg_params, aux_params + +def is_test_for_gpu(): + return mx.current_context().device_type == 'gpu' + + +def is_test_for_mkldnn(): + return (mx.current_context().device_type == 'cpu' + and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') == '1') + + +def is_test_for_native_cpu(): + return (mx.current_context().device_type == 'cpu' + and os.environ.get('ENABLE_MKLDNN_QUANTIZATION_TEST') == None) + + +def test_quantize_float32_to_int8(): + shape = rand_shape_nd(4) + data = rand_ndarray(shape, 'default', dtype='float32') + min_range = mx.nd.min(data) + max_range = mx.nd.max(data) + qdata, min_val, max_val = mx.nd.contrib.quantize(data, min_range, max_range, out_type='int8') + data_np = data.asnumpy() + min_range = min_range.asscalar() + max_range = max_range.asscalar() + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + quantized_range = 127.0 + scale = quantized_range / real_range + assert qdata.dtype == np.int8 + assert min_val.dtype == np.float32 + assert max_val.dtype == np.float32 + assert same(min_val.asscalar(), -real_range) + assert same(max_val.asscalar(), real_range) + qdata_np = (np.sign(data_np) * np.minimum(np.abs(data_np) * scale + 0.5, quantized_range)).astype(np.int8) + assert_almost_equal(qdata.asnumpy(), qdata_np, atol = 1) + + +def test_dequantize_int8_to_float32(): + + def get_test_data(real_range, qdata_np): + qdata = mx.nd.array(qdata_np, dtype=np.int8) + min_range = mx.nd.array([-real_range], dtype=np.float32) + max_range = mx.nd.array([real_range], dtype=np.float32) + return qdata, min_range, max_range + + def baseline_dequantization(qdata, real_range, qdata_np): + quantized_range = 127.0 + scale = real_range / quantized_range + data_np = qdata_np * scale + return data_np + + def test_nd_array_dequantization(qdata, min_range, max_range, expected_result): + data = mx.nd.contrib.dequantize(qdata, min_range, max_range, out_type='float32') + assert data.dtype == np.float32 + assert_almost_equal(data.asnumpy(), expected_result, atol = 1) + + def test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result): + sym_data = mx.sym.Variable('data') + sym_min_range = mx.sym.Variable('min_range') + sym_max_range = mx.sym.Variable('max_range') + dequant = mx.sym.contrib.dequantize(sym_data, sym_min_range, + sym_max_range, out_type='float32') + out = dequant._bind(ctx=mx.current_context(), + args={'data':qdata, 'min_range':min_range, 'max_range':max_range}) + data = out.forward()[0] + assert data.dtype == np.float32 + assert_almost_equal(data.asnumpy(), expected_result, atol = 1) + + real_range = 128 + shape = rand_shape_nd(4) + qdata_np = np.random.uniform(low=-127, high=127, size=shape).astype(dtype=np.int8) + qdata, min_range, max_range = get_test_data(real_range, qdata_np) + expected_result = baseline_dequantization(qdata, real_range, qdata_np) + # test nd array implementation. + test_nd_array_dequantization(qdata, min_range, max_range, expected_result) + # test symbolic api implementaion. + test_symbolic_api_dequantization(qdata, min_range, max_range, expected_result) + + +def test_requantize_int32_to_int8(): + def quantized_int32_to_float(qdata, min_range, max_range): + assert qdata.dtype == 'int32' + quantized_range = np.iinfo('int32').max + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + scale = float(real_range) / float(quantized_range) + return qdata.astype('float32') * scale + + def float_to_quantized_int8(data, min_range, max_range): + assert data.dtype == 'float32' + real_range = np.maximum(np.abs(min_range), np.abs(max_range)) + quantized_range = np.iinfo('int8').max + scale = float(quantized_range) / float(real_range) + return (np.sign(data) * np.minimum(np.abs(data) * scale + 0.5, quantized_range)).astype('int8') + + def requantize(qdata, min_data, max_data, real_range): + data = quantized_int32_to_float(qdata, min_data, max_data) + output = float_to_quantized_int8(data, -real_range, real_range) + return output, -real_range, real_range + + def requantize_baseline(qdata, min_data, max_data, min_calib_range=None, max_calib_range=None): + if min_calib_range is not None and max_calib_range is not None: + real_range = np.maximum(np.abs(min_calib_range), np.abs(max_calib_range)) + return requantize(qdata, min_data, max_data, real_range) + else: + min_range = quantized_int32_to_float(np.min(qdata), min_data, max_data) + max_range = quantized_int32_to_float(np.max(qdata), min_data, max_data) + return requantize(qdata, min_data, max_data, np.maximum(np.abs(min_range), np.abs(max_range))) + + def check_requantize(shape, min_calib_range=None, max_calib_range=None): + qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32') + min_range = mx.nd.array([-1010.0]) + max_range = mx.nd.array([1020.0]) + if min_calib_range is None or max_calib_range is None: + qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range) + else: + qdata_int8, min_output, max_output = mx.nd.contrib.requantize(qdata, min_range, max_range, + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + + qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(), + max_range.asscalar(), + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1) + assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) + assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) + + def check_requantize_with_symbol(shape, min_calib_range=None, max_calib_range=None): + qdata = mx.nd.random.uniform(low=-1000.0, high=1000.0, shape=shape).astype('int32') + min_range = mx.nd.array([-1010.0]) + max_range = mx.nd.array([1020.0]) + sym_data = mx.sym.Variable('data') + sym_min_range = mx.sym.Variable('min_range') + sym_max_range = mx.sym.Variable('max_range') + if min_calib_range is None or max_calib_range is None: + requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range) + out = requant._bind(ctx=mx.current_context(), + args={'data':qdata, 'min_range':min_range, + 'max_range':max_range}) + qdata_int8, min_output, max_output = out.forward() + else: + requant = mx.sym.contrib.requantize(sym_data, sym_min_range, sym_max_range, + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + out = requant._bind(ctx=mx.current_context(), args={'data':qdata, 'min_range':min_range, + 'max_range':max_range}) + qdata_int8, min_output, max_output = out.forward() + + qdata_int8_np, min_output_np, max_output_np = requantize_baseline(qdata.asnumpy(), min_range.asscalar(), + max_range.asscalar(), + min_calib_range=min_calib_range, + max_calib_range=max_calib_range) + assert_almost_equal(qdata_int8.asnumpy(), qdata_int8_np, atol = 1) + assert_almost_equal(min_output.asnumpy(), np.array([min_output_np])) + assert_almost_equal(max_output.asnumpy(), np.array([max_output_np])) + + # test with symbol API. + check_requantize_with_symbol((3, 4, 10, 10)) + check_requantize_with_symbol((32, 3, 23, 23)) + check_requantize_with_symbol((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) + check_requantize_with_symbol((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43) + # Test with nd array API + check_requantize((3, 4, 10, 10)) + check_requantize((32, 3, 23, 23)) + check_requantize((3, 4, 10, 10), min_calib_range=-1050.0, max_calib_range=1040.0) + check_requantize((32, 3, 23, 23), min_calib_range=-134.349, max_calib_range=523.43) + + +def test_quantized_conv(): + def check_quantized_conv(data_shape, kernel, num_filter, pad, stride, dilate, no_bias, qdtype): + if is_test_for_native_cpu(): + print('skipped testing quantized_conv for native cpu since it is not supported yet') + return + elif is_test_for_mkldnn(): + # (TODO)Xinyu: https://github.com/apache/incubator-mxnet/issues/16830 + print('skipped testing quantized_conv for mkldnn cpu since it is a flaky case') + return + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing quantized_conv for gpu uint8 since it is not supported yet') + return + elif is_test_for_gpu() and len(data_shape) != 4: + print('skipped testing quantized_conv for gpu 5d layout since it is not supported yet') + return + + # run fp32 conv + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + conv = mx.sym.Convolution(data=data, kernel=kernel, num_filter=num_filter, pad=pad, stride=stride, + dilate=dilate, no_bias=no_bias, cudnn_off=False, name='conv') + arg_shapes, _, _ = conv.infer_shape(data=data_shape) + arg_names = conv.list_arguments() + conv_exe_fp32 = conv._simple_bind(ctx=mx.current_context(), grad_req='null') + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + conv_exe_fp32.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high, + shape=data_shape).astype('int32') + conv_exe_fp32.arg_dict[arg_names[1]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[1]).astype('int32') + if not no_bias: + conv_exe_fp32.arg_dict[arg_names[2]][:] = mx.nd.random.uniform(low=-127.0, high=127.0, + shape=arg_shapes[2]).astype('int32') + output = conv_exe_fp32.forward()[0] + + # run quantized conv + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) + qweight = mx.sym.Variable(name='qweight', dtype='int8') + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + min_weight = mx.sym.Variable(name='min_weight') + max_weight = mx.sym.Variable(name='max_weight') + quantized_conv = mx.sym.contrib.quantized_conv(data=qdata, weight=qweight, min_data=min_data, + max_data=max_data, min_weight=min_weight, + max_weight=max_weight, kernel=kernel, + num_filter=num_filter, pad=pad, stride=stride, + dilate=dilate, no_bias=no_bias) + qarg_names = quantized_conv.list_arguments() + type_dict = None + if not no_bias: + type_dict = {qarg_names[2]: 'int8'} + conv_exe_int8 = quantized_conv._simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') + conv_exe_int8.arg_dict[qarg_names[0]][:] = conv_exe_fp32.arg_dict[arg_names[0]].astype(qdtype) + conv_exe_int8.arg_dict[qarg_names[1]][:] = conv_exe_fp32.arg_dict[arg_names[1]].astype('int8') + quantized_range = 127.0 + if no_bias: + conv_exe_int8.arg_dict[qarg_names[2]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[3]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[4]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[5]][:] = quantized_range + else: + conv_exe_int8.arg_dict[qarg_names[2]][:] = conv_exe_fp32.arg_dict[arg_names[2]].astype('int8') + conv_exe_int8.arg_dict[qarg_names[3]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[4]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[5]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[6]][:] = quantized_range + conv_exe_int8.arg_dict[qarg_names[7]][:] = -quantized_range + conv_exe_int8.arg_dict[qarg_names[8]][:] = quantized_range + qoutput, min_range, max_range = conv_exe_int8.forward() + + if no_bias: + assert_almost_equal(output.asnumpy(), qoutput.asnumpy(), atol = 1) + else: + # with adding bias, accuracy loss should not be greater than one + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + for qdtype in ['int8', 'uint8']: + check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), True, qdtype) + check_quantized_conv((3, 4, 28, 28), (3, 3), 128, (1, 1), (1, 1), (1, 1), False, qdtype) + check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), False, qdtype) + check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (1, 1, 1), True, qdtype) + check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), False, qdtype) + check_quantized_conv((1, 3, 4, 28, 28), (1, 3, 3), 128, (1, 1, 1), (1, 1, 1), (2, 2, 2), True, qdtype) + + +def test_quantized_elemwise_add(): + def check_quantized_elemwise_add(data_shape, qtype): + if is_test_for_native_cpu(): + print('skipped testing quantized_elemwise_add for native cpu since it is not supported yet') + return + elif qtype != 'uint8' and qtype != 'int8': + print('skipped testing quantized_elemwise_add for not supported data type') + return + elif is_test_for_gpu(): + print('skipped testing quantized_elemwise_add for gpu since it is not supported yet') + return + + dataA = mx.sym.Variable(name='dataA', shape=data_shape, dtype='float32') + dataB = mx.sym.Variable(name='dataB', shape=data_shape, dtype='float32') + elemwise_add_fp32 = mx.sym.elemwise_add(dataA, dataB) + arg_names = elemwise_add_fp32.list_arguments() + elemwise_add_fp32_exe = elemwise_add_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + if qtype == 'uint8': + data_low = 0.0 + data_high = 255.0 + else: + data_low = -127.0 + data_high = 127.0 + + dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32') + dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32') + elemwise_add_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val + + elemwise_add_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val + + output = elemwise_add_fp32_exe.forward()[0] + qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype) + qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype) + min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32') + max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32') + min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32') + max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32') + quantized_elemwise_add = mx.sym.contrib.quantized_elemwise_add(qdataA, qdataB, min_dataA, max_dataA, min_dataB, max_dataB) + elemwise_add_int8_exe = quantized_elemwise_add._simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_elemwise_add.list_arguments() + elemwise_add_int8_exe.arg_dict[qarg_names[0]][:] = elemwise_add_fp32_exe.arg_dict[arg_names[0]].astype(qtype) + elemwise_add_int8_exe.arg_dict[qarg_names[1]][:] = elemwise_add_fp32_exe.arg_dict[arg_names[1]].astype(qtype) + quantized_range = 127.0 + elemwise_add_int8_exe.arg_dict[qarg_names[2]][:] = data_low + elemwise_add_int8_exe.arg_dict[qarg_names[3]][:] = data_high + elemwise_add_int8_exe.arg_dict[qarg_names[4]][:] = data_low + elemwise_add_int8_exe.arg_dict[qarg_names[5]][:] = data_high + qoutput, min_range, max_range = elemwise_add_int8_exe.forward() + int8_rslt = qoutput.astype(output.dtype)*max_range/0x7fffffff + diff = mx.nd.abs(output - int8_rslt) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + for qtype in ['int8', 'uint8']: + check_quantized_elemwise_add((4, 6), qtype) + check_quantized_elemwise_add((13, 74, 52), qtype) + check_quantized_elemwise_add((3, 4, 56, 56), qtype) + check_quantized_elemwise_add((32, 56, 64, 11), qtype) + + +def test_quantized_elemwise_mul(): + def check_quantized_elemwise_mul(data_shape, qtype): + if is_test_for_native_cpu(): + print('skipped testing quantized_elemwise_mul for native cpu since it is not supported yet') + return + elif qtype != 'int8': + print('skipped testing quantized_elemwise_mul for not supported data type') + return + elif is_test_for_gpu(): + print('skipped testing quantized_elemwise_mul for gpu since it is not supported yet') + return + + dataA = mx.sym.Variable(name='dataA', shape=data_shape, dtype='float32') + dataB = mx.sym.Variable(name='dataB', shape=data_shape, dtype='float32') + elemwise_mul_fp32 = mx.sym.elemwise_mul(dataA, dataB) + arg_names = elemwise_mul_fp32.list_arguments() + elemwise_mul_fp32_exe = elemwise_mul_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + if qtype == 'uint8': + data_low = 0.0 + data_high = 255.0 + else: + data_low = -127.0 + data_high = 127.0 + + dataA_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32') + dataB_val = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape).astype('int32') + elemwise_mul_fp32_exe.arg_dict[arg_names[0]][:] = dataA_val + + elemwise_mul_fp32_exe.arg_dict[arg_names[1]][:] = dataB_val + + output = elemwise_mul_fp32_exe.forward()[0] + + qdataA = mx.sym.Variable(name='qdataA', shape=data_shape, dtype=qtype) + qdataB = mx.sym.Variable(name='qdataB', shape=data_shape, dtype=qtype) + min_dataA = mx.sym.Variable(name='min_dataA', dtype='float32') + max_dataA = mx.sym.Variable(name='max_dataA', dtype='float32') + min_dataB = mx.sym.Variable(name='min_dataB', dtype='float32') + max_dataB = mx.sym.Variable(name='max_dataB', dtype='float32') + quantized_elemwise_mul = mx.sym.contrib.quantized_elemwise_mul(qdataA, qdataB, min_dataA, max_dataA, min_dataB, max_dataB) + elemwise_mul_int8_exe = quantized_elemwise_mul._simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_elemwise_mul.list_arguments() + elemwise_mul_int8_exe.arg_dict[qarg_names[0]][:] = elemwise_mul_fp32_exe.arg_dict[arg_names[0]].astype(qtype) + elemwise_mul_int8_exe.arg_dict[qarg_names[1]][:] = elemwise_mul_fp32_exe.arg_dict[arg_names[1]].astype(qtype) + quantized_range = 127.0 + elemwise_mul_int8_exe.arg_dict[qarg_names[2]][:] = data_low + elemwise_mul_int8_exe.arg_dict[qarg_names[3]][:] = data_high + elemwise_mul_int8_exe.arg_dict[qarg_names[4]][:] = data_low + elemwise_mul_int8_exe.arg_dict[qarg_names[5]][:] = data_high + qoutput, min_range, max_range = elemwise_mul_int8_exe.forward() + + fp32_rslt = output.asnumpy() + int8_rslt = qoutput.astype(output.dtype) + assert_almost_equal(fp32_rslt, int8_rslt, atol = 1e-4) + + for qtype in ['int8', 'uint8']: + check_quantized_elemwise_mul((4, 6), qtype) + check_quantized_elemwise_mul((13, 74, 52), qtype) + check_quantized_elemwise_mul((3, 4, 56, 56), qtype) + check_quantized_elemwise_mul((32, 56, 64, 11), qtype) + + +def test_quantized_pooling(): + def check_quantized_pooling(data_shape, kernel, pool_type, pad, stride, global_pool, qdtype, convention='valid'): + if is_test_for_native_cpu(): + print('skipped testing quantized_pooling for native cpu since it is not supported yet') + return + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing quantized_pooling for gpu uint8 since it is not supported yet') + return + elif is_test_for_gpu() and len(data_shape) != 4: + print('skipped testing quantized_pooling for gpu 5d layout since it is not supported yet') + return + + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + pooling_fp32 = mx.sym.Pooling(data=data, kernel=kernel, pad=pad, stride=stride, + pool_type=pool_type, global_pool=global_pool, cudnn_off=False, + pooling_convention=convention) + arg_shapes, _, _ = pooling_fp32.infer_shape(data=data_shape) + arg_names = pooling_fp32.list_arguments() + pooling_fp32_exe = pooling_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + pooling_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, high=data_high, + shape=data_shape).astype('int32') + output = pooling_fp32_exe.forward()[0] + + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + quantized_pooling = mx.sym.contrib.quantized_pooling(data=qdata, min_data=min_data, + max_data=max_data, kernel=kernel, + pad=pad, stride=stride, pool_type=pool_type, + global_pool=global_pool, + pooling_convention=convention) + pooling_int8_exe = quantized_pooling._simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_pooling.list_arguments() + pooling_int8_exe.arg_dict[qarg_names[0]][:] = pooling_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) + quantized_range = 127.0 + pooling_int8_exe.arg_dict[qarg_names[1]][:] = -quantized_range + pooling_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range + qoutput, min_range, max_range = pooling_int8_exe.forward() + + if pool_type == 'max': + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + elif pool_type == 'avg': # for avg pooling, fp32 and int8 may be different due to rounding errors + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + for qdtype in ['int8', 'uint8']: + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype) + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype) + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype) + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype) + check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), False, qdtype) + check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), True, qdtype) + check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), False, qdtype) + check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), True, qdtype) + + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), False, qdtype, 'full') + check_quantized_pooling((3, 4, 56, 56), (3, 3), 'max', (0, 0), (2, 2), True, qdtype, 'full') + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), False, qdtype, 'full') + check_quantized_pooling((3, 512, 7, 7), (7, 7), 'avg', (0, 0), (1, 1), True, qdtype, 'full') + check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), False, qdtype, 'full') + check_quantized_pooling((3, 4, 3, 56, 56), (1, 3, 3), 'max', (0, 0, 0), (1, 2, 2), True, qdtype, 'full') + check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), False, qdtype, 'full') + check_quantized_pooling((3, 512, 3, 7, 7), (1, 7, 7), 'avg', (0, 0, 0), (1, 2, 2), True, qdtype, 'full') + + +def test_quantized_fc(): + def check_quantized_fc(data_shape, num_hidden, no_bias, qdtype, flatten=True): + if is_test_for_native_cpu(): + hasMKL = False + for key in os.environ.keys(): + if operator.eq(key, "BUILD_TAG"): + if os.environ['BUILD_TAG'].find("MKL") != -1: + hasMKL = True + break + if hasMKL == False: + print('skipped testing quantized_fc on cpu since s8u8s32 is only supported by MKL BLAS library') + return + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing quantized_fc for gpu uint8 since it is not supported yet') + return + + def maxabs(a, b): + return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b)) + + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + fc_fp32 = mx.sym.FullyConnected(data=data, num_hidden=num_hidden, no_bias=no_bias, flatten=flatten) + arg_shapes, _, _ = fc_fp32.infer_shape(data=data_shape) + arg_names = fc_fp32.list_arguments() + fc_fp32_exe = fc_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + int8_range = 127.0 + if qdtype == 'uint8': + data_low = 0.0 + data_high = 63.0 + quantized_range = 255.0 + else: + data_low = -63.0 + data_high = 63.0 + quantized_range = 127.0 + + data = mx.nd.random.uniform(low=data_low, high=data_high, + shape=data_shape).astype('int32') + weight = mx.nd.random.uniform(low=data_low, high=data_high, + shape=arg_shapes[1]).astype('int32') + fc_fp32_exe.arg_dict[arg_names[0]][:] = data + fc_fp32_exe.arg_dict[arg_names[1]][:] = weight + + data_min = mx.nd.min(data).astype('float32') + data_max = mx.nd.max(data).astype('float32') + weight_min = mx.nd.min(weight).astype('float32') + weight_max = mx.nd.max(weight).astype('float32') + data_range = maxabs(data_min, data_max) + weight_range = maxabs(weight_min, weight_max) + + if not no_bias: + bias = mx.nd.random.uniform(low=data_low, high=data_high, + shape=arg_shapes[2]).astype('int32') + bias_min = mx.nd.min(bias).astype('float32') + bias_max = mx.nd.max(bias).astype('float32') + bias_range = maxabs(bias_min, bias_max) + + bias_scale = int8_range / bias_range + data_scale = quantized_range / data_range + weight_scale = int8_range / weight_range + bias_int32_rescale = data_scale * weight_scale / bias_scale + new_bias = mx.nd.cast(bias, dtype='float32') * bias_int32_rescale + fc_fp32_exe.arg_dict[arg_names[2]][:] = new_bias.astype('int32') + + output = fc_fp32_exe.forward()[0] + + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) + fc_int8 = mx.sym.contrib.quantized_fully_connected(data=qdata, num_hidden=num_hidden, + no_bias=no_bias, flatten=flatten) + qarg_names = fc_int8.list_arguments() + type_dict = {qarg_names[1]: 'int8'} + if not no_bias: + type_dict.update({qarg_names[2]: 'int8'}) + fc_int8_exe = fc_int8._simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') + fc_int8_exe.arg_dict[qarg_names[0]][:] = fc_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) + fc_int8_exe.arg_dict[qarg_names[1]][:] = fc_fp32_exe.arg_dict[arg_names[1]].astype('int8') + if no_bias: + fc_int8_exe.arg_dict[qarg_names[2]][:] = -data_range + fc_int8_exe.arg_dict[qarg_names[3]][:] = data_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = -weight_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = weight_range + else: + fc_int8_exe.arg_dict[qarg_names[2]][:] = bias.astype('int8') + fc_int8_exe.arg_dict[qarg_names[3]][:] = -data_range + fc_int8_exe.arg_dict[qarg_names[4]][:] = data_range + fc_int8_exe.arg_dict[qarg_names[5]][:] = -weight_range + fc_int8_exe.arg_dict[qarg_names[6]][:] = weight_range + fc_int8_exe.arg_dict[qarg_names[7]][:] = -bias_range + fc_int8_exe.arg_dict[qarg_names[8]][:] = bias_range + qoutput, min_range, max_range = fc_int8_exe.forward() + + if no_bias: + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + else: + # with adding bias, accuracy loss should not be greater than one + diff = mx.nd.abs(output - qoutput.astype(output.dtype)) + cond = mx.nd.lesser(2, diff).sum().asscalar() + assert cond == 0 + + for qdtype in ['int8', 'uint8']: + if is_test_for_mkldnn(): + check_quantized_fc((32, 512, 2), 100, True, qdtype, flatten=False) + check_quantized_fc((32, 512, 2), 100, False, qdtype, flatten=False) + check_quantized_fc((32, 512, 2, 2), 100, True, qdtype, flatten=False) + check_quantized_fc((32, 512, 2, 2), 100, False, qdtype, flatten=False) + check_quantized_fc((32, 512, 2, 2), 100, True, qdtype) + check_quantized_fc((32, 111, 2, 2), 100, True, qdtype) + check_quantized_fc((32, 512, 2, 2), 100, False, qdtype) + check_quantized_fc((32, 111, 2, 2), 100, False, qdtype) + check_quantized_fc((256, 2048, 2, 2), 800, False, qdtype) + check_quantized_fc((256, 111, 2, 2), 800, False, qdtype) + check_quantized_fc((256, 2048, 2, 2), 800, True, qdtype) + check_quantized_fc((256, 111, 2, 2), 800, True, qdtype) + + +def test_quantized_embedding(): + def check_quantized_embedding(data_shape, input_dim, output_dim): + if is_test_for_gpu(): + print('skipped testing test_quantized_embedding for gpu since it is not supported yet') + return + + def maxabs(a, b): + return mx.nd.maximum(mx.nd.abs(a), mx.nd.abs(b)) + + data0 = mx.sym.Variable(name='data', shape=data_shape, dtype='int32') + embedding_fp32 = mx.sym.Embedding(data=data0, input_dim=input_dim, output_dim=output_dim) + arg_shapes, _, _ = embedding_fp32.infer_shape(data=data_shape) + arg_names = embedding_fp32.list_arguments() + embedding_fp32_exe = embedding_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + int8_range = 127.0 + data = mx.nd.random.uniform(low=0, high=input_dim, + shape=arg_shapes[0]).astype('int32') + weight = mx.nd.random.uniform(low=-int8_range, high=int8_range, + shape=arg_shapes[1]).astype('int32') + embedding_fp32_exe.arg_dict[arg_names[0]][:] = data + embedding_fp32_exe.arg_dict[arg_names[1]][:] = weight + + weight_min = mx.nd.min(weight).astype('float32') + weight_max = mx.nd.max(weight).astype('float32') + weight_range = maxabs(weight_min, weight_max) + + output = embedding_fp32_exe.forward()[0] + + embedding_int8 = mx.sym.contrib.quantized_embedding(data=data0, input_dim=input_dim, output_dim=output_dim) + qarg_names = embedding_int8.list_arguments() + type_dict = {qarg_names[1]: 'int8'} + embedding_int8_exe = embedding_int8._simple_bind(ctx=mx.current_context(), type_dict=type_dict, grad_req='null') + embedding_int8_exe.arg_dict[qarg_names[0]][:] = embedding_fp32_exe.arg_dict[arg_names[0]] + embedding_int8_exe.arg_dict[qarg_names[1]][:] = embedding_fp32_exe.arg_dict[arg_names[1]].astype('int8') + embedding_int8_exe.arg_dict[qarg_names[2]][:] = -weight_range + embedding_int8_exe.arg_dict[qarg_names[3]][:] = weight_range + qoutput, min_range, max_range = embedding_int8_exe.forward() + + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + + check_quantized_embedding((1,), 1000, 256) + check_quantized_embedding((1,), 1024, 512) + check_quantized_embedding((32,), 1000, 256) + check_quantized_embedding((32,), 1024, 512) + + +def test_quantized_flatten(): + def check_quantized_flatten(shape, qdtype): + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + qdata = mx.nd.random.uniform(low=data_low, high=data_high, shape=shape).astype(qdtype) + min_data = mx.nd.array([-1023.343], dtype='float32') + max_data = mx.nd.array([2343.324275], dtype='float32') + qoutput, min_output, max_output = mx.nd.contrib.quantized_flatten(qdata, min_data, max_data) + assert qoutput.ndim == 2 + assert qoutput.shape[0] == qdata.shape[0] + assert qoutput.shape[1] == np.prod(qdata.shape[1:]) + assert same(qdata.asnumpy().flatten(), qoutput.asnumpy().flatten()) + assert same(min_data.asnumpy(), min_output.asnumpy()) + assert same(max_data.asnumpy(), max_output.asnumpy()) + + for qdtype in ['int8', 'uint8']: + check_quantized_flatten((10,), qdtype) + check_quantized_flatten((10, 15), qdtype) + check_quantized_flatten((10, 15, 18), qdtype) + check_quantized_flatten((3, 4, 23, 23), qdtype) + + +def test_quantized_act(): + def check_quantized_act(data_shape, qdtype): + if is_test_for_native_cpu(): + print('skipped testing quantized_act for native cpu since it is not supported yet') + return + elif qdtype == 'int8' and is_test_for_mkldnn(): + print('skipped testing quantized_act for mkldnn cpu int8 since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing quantized_act for gpu since it is not supported yet') + return + data = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + act_fp32 = mx.sym.Activation(data=data, act_type='relu', name='relu') + arg_shapes, _, _ = act_fp32.infer_shape(data=data_shape) + arg_names = act_fp32.list_arguments() + act_fp32_exe = act_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + if qdtype == 'uint8': + data_low = 0.0 + data_high = 127.0 + else: + data_low = -127.0 + data_high = 127.0 + + act_fp32_exe.arg_dict[arg_names[0]][:] = mx.nd.random.uniform(low=data_low, + high=data_high, shape=data_shape).astype(qdtype) + output = act_fp32_exe.forward()[0] + + qdata = mx.sym.Variable(name='qdata', shape=data_shape, dtype=qdtype) + min_data = mx.sym.Variable(name='min_data') + max_data = mx.sym.Variable(name='max_data') + quantized_act = mx.sym.contrib.quantized_act(data=qdata, min_data=min_data, max_data=max_data, act_type='relu') + act_int8_exe = quantized_act._simple_bind(ctx=mx.current_context(), grad_req='null') + qarg_names = quantized_act.list_arguments() + + act_int8_exe.arg_dict[qarg_names[0]][:] = act_fp32_exe.arg_dict[arg_names[0]].astype(qdtype) + quantized_range_min = mx.nd.min(act_int8_exe.arg_dict[qarg_names[0]][:]) + quantized_range_max = mx.nd.max(act_int8_exe.arg_dict[qarg_names[0]][:]) + act_int8_exe.arg_dict[qarg_names[1]][:] = quantized_range_min.astype(qdtype) + act_int8_exe.arg_dict[qarg_names[2]][:] = quantized_range_max.astype(qdtype) + qoutput, min_range, max_range = act_int8_exe.forward() + + assert_almost_equal(output.asnumpy(), qoutput.asnumpy()) + assert_almost_equal(min_range.asscalar(), quantized_range_min.asscalar()) + assert_almost_equal(max_range.asscalar(), quantized_range_max.asscalar()) + + for qdtype in ['int8', 'uint8']: + check_quantized_act((10,), qdtype) + check_quantized_act((10, 15), qdtype) + check_quantized_act((10, 15, 18), qdtype) + check_quantized_act((3, 4, 23, 23), qdtype) + + +def test_quantized_bn(): + def get_mean_var(data): + mean = mx.ndarray.mean(data, axis=1, exclude=1) + mean_broad = mx.ndarray.expand_dims(mean, axis=0) + mean_broad = mx.ndarray.expand_dims(mean_broad, axis=2) + mean_broad = mx.ndarray.expand_dims(mean_broad, axis=3) + mean_broad = mx.ndarray.broadcast_like(mean_broad, data) + var = mx.ndarray.multiply(data - mean_broad, data - mean_broad) + var = mx.ndarray.mean(var, axis=1, exclude=1) + return mean, var + + def check_quantized_bn(data_shape, qdtype): + if is_test_for_native_cpu(): + print('skipped testing quantize_bn for native cpu since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing quantize_bn for gpu since it is not supported yet') + return + + # qdtype = uint8 + if qdtype == 'uint8': + data_low = 0.0 + data_high = 255.0 + else: + data_low = -127.0 + data_high = 127.0 + + # run fp32 bn + data_sym = mx.sym.Variable(name='data', shape=data_shape, dtype='float32') + bn_fp32 = mx.sym.BatchNorm(data=data_sym, name='bn', use_global_stats=True, fix_gamma=False) + arg_shapes, out_shapes, aux_shapes = bn_fp32.infer_shape(data=data_shape) + arg_names = bn_fp32.list_arguments() + aux_names = bn_fp32.list_auxiliary_states() + data = mx.nd.random.uniform(low=data_low, high=data_high, shape=data_shape) + gamma = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[1]) + beta = mx.nd.random.uniform(low=data_low, high=data_high, shape=arg_shapes[2]) + moving_mean, moving_var = get_mean_var(data) + + bn_fp32_exe = bn_fp32._simple_bind(ctx=mx.current_context(), grad_req='null') + bn_fp32_exe.arg_dict[arg_names[0]][:] = data + bn_fp32_exe.arg_dict[arg_names[1]][:] = gamma + bn_fp32_exe.arg_dict[arg_names[2]][:] = beta + bn_fp32_exe.aux_dict[aux_names[0]][:] = moving_mean + bn_fp32_exe.aux_dict[aux_names[1]][:] = moving_var + + output = bn_fp32_exe.forward()[0] + + # generate int8 bn from fp32 bn + arg_params = dict() + for k, v in bn_fp32_exe.arg_dict.items(): + if 'data' in k or 'softmax_label' in k: + continue + arg_params[k] = v + + calib_data = mx.gluon.data.DataLoader(data, batch_size=data_shape[0]) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=bn_fp32, + arg_params=arg_params, + aux_params=bn_fp32_exe.aux_dict, + ctx=mx.current_context(), + quantized_dtype=qdtype, + quantize_mode='full', + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1) + + sym_block = mx.gluon.SymbolBlock(outputs=qsym, inputs=data_sym) + params = qarg_params + params.update(qaux_params) + sym_block.load_dict(params) + output_int8_to_fp32 = sym_block.forward(data) + + assert_almost_equal(output.asnumpy(), output_int8_to_fp32.asnumpy(), rtol=1e-1, atol=8) + + for qdtype in ['int8', 'uint8']: + check_quantized_bn((32, 512, 4, 4), qdtype) + check_quantized_bn((32, 1024, 8, 8), qdtype) + check_quantized_bn((32, 3, 224, 224), qdtype) + + +def test_quantize_params(): + if is_test_for_native_cpu(): + print('skipped testing quantized_params for native cpu since it is not supported yet') + return + + data = mx.sym.Variable('data') + conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=2048, name='conv') + sym = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, momentum=0.9, use_global_stats=False, name='bn') + offline_params = [name for name in sym.list_arguments() + if not name.startswith('data') and not name.endswith('label')] + params = {} + for name in offline_params: + params[name] = mx.nd.uniform(shape=(2, 2)) + qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + offline_params=offline_params, quantize_mode='full') + qparams = mx.contrib.quant._quantize_params(qsym, params, min_max_dict = {}) + param_names = params.keys() + qparam_names = qparams.keys() + for name in qparam_names: + if name.startswith('bn'): + assert name in param_names + elif name.startswith('conv'): + assert name not in param_names + assert name.find('quantize') != -1 + + +def get_fp32_sym(): + data = mx.sym.Variable('data') + conv = mx.sym.Convolution(data, kernel=(1, 1), num_filter=16, name='conv') + bn = mx.sym.BatchNorm(data=conv, eps=2e-05, fix_gamma=False, momentum=0.9, use_global_stats=False, name='bn') + act = mx.sym.Activation(data=bn, act_type='relu', name='relu') + pool = mx.sym.Pooling(act, kernel=(4, 4), pool_type='avg', name='pool') + fc = mx.sym.FullyConnected(pool, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.softmax(fc, name='softmax') + return sym + +def get_fp32_residual(): + data = mx.sym.Variable('data') + conv0 = mx.sym.Convolution(data=data, num_filter=4, kernel=(1,1), pad=(0,0), + no_bias=True, name='conv0') + bn = mx.sym.BatchNorm(data=conv0, fix_gamma=False, eps=2e-5, momentum=0.9, name='bn') + sum0 = mx.sym.elemwise_add(bn, data, name='sum0') + act0 = mx.sym.Activation(data=sum0, act_type='relu', name='relu0') + pool0 = mx.sym.Pooling(act0, kernel=(4, 4), pool_type='avg', name='pool0') + conv1 = mx.sym.Convolution(data=pool0, num_filter=4, kernel=(1,1), pad=(0,0), + no_bias=False, name='conv1') + act1 = mx.sym.Activation(data=conv1, act_type='relu', name='relu1') + pool1 = mx.sym.Pooling(act1, kernel=(4, 4), pool_type='avg', name='pool1') + fc = mx.sym.FullyConnected(pool1, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.SoftmaxOutput(fc, grad_scale=1, ignore_label=-1, multi_output=False, + out_grad=False, preserve_shape=False, use_ignore=False, name='softmax') + return sym + +def get_fp32_sym_with_multiple_outputs(length=1): + data = mx.sym.Variable('data') + inputs = list(mx.sym.split(data, axis=0, num_outputs=length, squeeze_axis=1, name='split')) + + _conv_outs = [] + for i in range(length): + _conv_outs.append(mx.sym.Convolution(data=inputs[i], kernel=(1, 1), num_filter=16, name='conv_{0}'.format(i))) + conv_out = [mx.sym.expand_dims(i, axis=0) for i in _conv_outs] + conv_out = mx.sym.Concat(*conv_out, dim=0, name='concat') + reshape_out = mx.sym.reshape(data=conv_out, shape=((length, -1)), name='reshape') + fc_out = mx.sym.FullyConnected(reshape_out, num_hidden=10, flatten=True, name='fc') + sym = mx.sym.softmax(fc_out, name='softmax') + return sym + +def get_fp32_sym_with_multiple_inputs(): + data1 = mx.symbol.Variable('data1', shape=(64, 4, 10, 10), dtype='float32') + data2 = mx.symbol.Variable('data2', shape=(64, 4, 10, 10), dtype='float32') + weight1 = mx.symbol.Variable('conv1_weight', dtype='float32') + weight2 = mx.symbol.Variable('conv2_weight', dtype='float32') + conv1 = mx.symbol.Convolution(data=data1, weight=weight1, name='conv1', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn1 = mx.symbol.BatchNorm(data=conv1, name="bn1") + conv2 = mx.symbol.Convolution(data=data2, weight=weight2, name='conv2', num_filter=64, + kernel=(1, 1), stride=(1, 1), no_bias=True) + bn2 = mx.symbol.BatchNorm(data=conv2, name="bn2") + sum = bn2 + bn1 + return sum + + +@xfail_when_nonstandard_decimal_separator +def test_quantize_model(): + def check_params(params, qparams, qsym=None): + if qsym is None: + assert len(params) == len(qparams) + for k, v in params.items(): + assert k in qparams + assert same(v.asnumpy(), qparams[k].asnumpy()) + else: + qparams_ground_truth = mx.contrib.quant._quantize_params(qsym, params, min_max_dict = {}) + assert len(qparams) == len(qparams_ground_truth) + for k, v in qparams_ground_truth.items(): + assert k in qparams + assert same(v.asnumpy(), qparams[k].asnumpy()) + + def check_qsym_calibrated(qsym): + attrs = qsym.attr_dict() + for k, v in attrs.items(): + if k.find('requantize_') != -1: + assert 'min_calib_range' in v + assert 'max_calib_range' in v + + def check_qsym_qdtype(qsym, qdtype): + attrs = qsym.attr_dict() + for k, v in attrs.items(): + if k.find('_quantize') != -1: + assert 'out_type' in v + assert v['out_type'] == qdtype + + def skip_not_supported(): + if is_test_for_native_cpu(): + print('skipped testing quantize_model for native cpu since it is not supported yet') + return True + elif qdtype == 'int8' and is_test_for_mkldnn(): + print('skipped testing quantize_model for mkldnn cpu int8 since it is not supported yet') + return True + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing quantize_model for gpu uint8 since it is not supported yet') + return True + return False + + def check_quantize_model(qdtype): + if is_test_for_native_cpu(): + print('skipped testing quantize_model for native cpu since it is not supported yet') + return + elif qdtype == 'int8' and is_test_for_mkldnn(): + print('skipped testing quantize_model for mkldnn cpu int8 since it is not supported yet') + return + elif qdtype == 'uint8' and is_test_for_gpu(): + print('skipped testing quantize_model for gpu uint8 since it is not supported yet') + return + + sym = get_fp32_sym() + batch_size = 4 + data_shape = (batch_size, 4, 10, 10) + + length = batch_size # specify num of outputs from split op + msym = get_fp32_sym_with_multiple_outputs(length) + msym_data_shape = (length, 4, 4, 10, 10) + + for s, dshape in zip((sym, msym), (data_shape, msym_data_shape)): + data = mx.sym.Variable('data') + sym_block = mx.gluon.SymbolBlock(outputs=s, inputs=data) + initialize_block_params(sym_block, mx.init.One()) + data = mx.nd.random.uniform(low=0, high=1, shape=dshape) + sym_block.forward(data) + arg_params, aux_params = collect_block_args_aux(sym_block, s) + + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='none', + quantize_mode='full') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + + calib_data = mx.nd.random.uniform(shape=dshape) + calib_data = mx.gluon.data.DataLoader(calib_data, batch_size=batch_size) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=s, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='naive', + calib_data=calib_data, + num_calib_batches=1, + quantize_mode='full') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + check_qsym_calibrated(qsym) + check_qsym_qdtype(qsym, qdtype) + + def check_quantize_model_multiple_inputs(qdtype): + if skip_not_supported(): + return + + sym = get_fp32_sym_with_multiple_inputs() + dshape = (64, 4, 10, 10) + data1 = mx.sym.Variable('data1') + data2 = mx.sym.Variable('data2') + sym_block = mx.gluon.SymbolBlock(outputs=sym, inputs=[data1, data2]) + initialize_block_params(sym_block, mx.init.One()) + data = [mx.nd.random.uniform(low=0, high=1, shape=dshape), + mx.nd.random.uniform(low=0, high=1, shape=dshape)] + sym_block.forward(*data) + arg_params, aux_params = collect_block_args_aux(sym_block, sym) + + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='none', + quantize_mode='full') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + + calib_data = [mx.nd.random.uniform(shape=dshape), + mx.nd.random.uniform(shape=dshape)] + calib_data = mx.gluon.data.DataLoader(mx.gluon.data.ArrayDataset(*calib_data), batch_size=4) + qsym, qarg_params, qaux_params = mx.contrib.quant.quantize_model(sym=sym, + arg_params=arg_params, + aux_params=aux_params, + ctx=mx.current_context(), + quantized_dtype=qdtype, + calib_mode='naive', + calib_data=calib_data, + data_names=["data1","data2"], + num_calib_batches=1, + quantize_mode='full') + check_params(arg_params, qarg_params, qsym) + check_params(aux_params, qaux_params) + check_qsym_calibrated(qsym) + check_qsym_qdtype(qsym, qdtype) + + + for qdtype in ['int8', 'uint8']: + check_quantize_model(qdtype) + check_quantize_model_multiple_inputs(qdtype) + + +def test_quantize_gluon_with_forward(): + def check_quantize_net(qdtype): + if is_test_for_native_cpu(): + print('skipped testing test_quantize_model_with_forward for native cpu since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing test_quantize_model_with_forward for gpu uint8 since it is not supported yet') + return + + data_shape = (32, 3, 224, 224) + batch_size = 1 + resnet18_v1 = vision.resnet18_v1(pretrained=True) + resnet18_v1.reset_ctx(mx.current_context()) + excluded_names_match = [] + if mx.current_context() == mx.gpu(): + excluded_names_match += ['activation', 'relu', 'conv0'] + num_calib_batches = 1 + + random_data = mx.random.uniform(shape=data_shape) + calib_data = mx.gluon.data.DataLoader(random_data, batch_size=batch_size) + + quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, + exclude_layers=None, + exclude_layers_match=excluded_names_match, + calib_mode='none', + data_shapes=[data_shape], + ctx=mx.current_context()) + quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) + quantized_resnet18_v1(random_data) + + for mode in ['naive', 'entropy']: + for quantize_granularity in ['tensor-wise', 'channel-wise']: + qdtype = qdtype if mode is 'naive' else 'auto' + quantized_resnet18_v1 = mx.contrib.quant.quantize_net(resnet18_v1, quantized_dtype=qdtype, + exclude_layers=None, + exclude_layers_match=excluded_names_match, + calib_data=calib_data, + calib_mode=mode, + quantize_granularity=quantize_granularity, + num_calib_batches=num_calib_batches, + ctx=mx.current_context()) + + quantized_resnet18_v1.hybridize(static_alloc=True, static_shape=True) + quantized_resnet18_v1(random_data) + + for qdtype in ['int8', 'uint8']: + check_quantize_net(qdtype) + + +@xfail_when_nonstandard_decimal_separator +def test_quantize_sym_with_calib(): + if is_test_for_native_cpu(): + print('skipped testing quantized_pooling for native cpu since it is not supported yet') + return + + sym = get_fp32_sym() + offline_params = [name for name in sym.list_arguments() + if not name.startswith('data') and not name.endswith('label')] + qsym, _ = mx.contrib.quant._quantize_symbol(sym, ctx=mx.current_context(), + offline_params=offline_params, quantize_mode='full') + requantize_op_names = ['requantize_conv', 'requantize_fc'] + min_max_dict = {'conv_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0)), + 'fc_output': (np.random.uniform(low=100.0, high=200.0), np.random.uniform(low=100.0, high=200.0))} + op_name_to_th_name = {'requantize_conv': 'conv_output', 'requantize_fc': 'fc_output'} + cqsym = mx.contrib.quant._calibrate_quantized_sym(qsym, min_max_dict) + attr_dict = cqsym.attr_dict() + for name in requantize_op_names: + assert name in attr_dict + lhs = float(attr_dict[name]['min_calib_range']) + rhs = min_max_dict[op_name_to_th_name[name]][0] + assert_almost_equal(np.array([lhs]), np.array([rhs])) + lhs = float(attr_dict[name]['max_calib_range']) + rhs = min_max_dict[op_name_to_th_name[name]][1] + assert_almost_equal(np.array([lhs]), np.array([rhs]), rtol=1e-3, atol=1e-4) + + +def test_quantization_net_with_different_data_inputs_options(): + if is_test_for_native_cpu(): + print('skipped testing test_quantization_net_with_different_data_inputs_options for native cpu since it is not supported yet') + return + elif is_test_for_gpu(): + print('skipped testing test_quantization_net_with_different_data_inputs_options for gpu since it is not supported yet') + return + + sym = get_fp32_sym() + net = mx.gluon.SymbolBlock(sym, mx.sym.var('data')) + initialize_block_params(net, mx.init.Normal(0.2)) + + batch_size = 32 + data_shape = (batch_size, 3, 224, 224) + random_data = mx.random.uniform(shape=data_shape) + + # pass data_shapes as list of tuples + quantized_net = mx.contrib.quant.quantize_net(net, + quantized_dtype='auto', + data_shapes=[data_shape], + ctx=mx.current_context()) + out = quantized_net(random_data) + out.wait_to_read() + + + # pass data_shapes as list of DataDescs + net2 = mx.gluon.SymbolBlock(sym, mx.sym.var('data')) + initialize_block_params(net2, mx.init.Normal(0.2)) + data_desc = mx.io.DataDesc('data', data_shape) + quantized_net2 = mx.contrib.quant.quantize_net(net, + quantized_dtype='auto', + data_shapes=[data_desc], + ctx=mx.current_context()) + out2 = quantized_net2(random_data) + out2.wait_to_read() + + + # pass data as DataLoader + net3 = mx.gluon.SymbolBlock(sym, mx.sym.var('data')) + initialize_block_params(net3, mx.init.Normal(0.2)) + data_loader = mx.gluon.data.DataLoader(random_data, batch_size=batch_size) + quantized_net3 = mx.contrib.quant.quantize_net(net, + quantized_dtype='auto', + calib_data=data_loader, + ctx=mx.current_context()) + out3 = quantized_net3(random_data) + out3.wait_to_read() + + + +def test_optimal_threshold_adversarial_case(): + # The worst case for the optimal_threshold function is when the values are concentrated + # at one edge: [0, 0, ..., 1000]. (histogram) + # We want to make sure that the optimal threshold in this case is the max. + hist = [] + hist_edges = [] + min_val = -2 + max_val = 2 + for i in range(0, 998): + hist.append(0) + for i in range(0, 999): + hist_edges.append((max_val - min_val) / 999 * i + min_val) + hist.append(1000) + hist_edges.append(max_val) + hist_data = (hist, hist_edges, min_val, max_val, max_val) + for dtype in ['uint8', 'int8', 'auto']: + res = mx.contrib.quant._LayerHistogramCollector.get_optimal_threshold(hist_data, dtype, num_quantized_bins=5) + # The threshold should be 2. + print (res) + assert abs(res[2] - 2) < 1e-5 + + +def test_get_optimal_thresholds(): + # Given an ndarray with elements following a uniform distribution, the optimal threshold + # for quantizing the ndarray should be either abs(min(nd)) or abs(max(nd)). + def get_threshold(nd): + min_nd = mx.nd.min(nd) + max_nd = mx.nd.max(nd) + return mx.nd.maximum(mx.nd.abs(min_nd), mx.nd.abs(max_nd)).asnumpy() + + for dtype in ['uint8', 'int8', 'auto']: + nd = mx.nd.uniform(low=-10.532, high=11.3432, shape=(8, 3, 23, 23), dtype=np.float64) + expected_threshold = get_threshold(nd) + arr = nd.asnumpy() + min_range = np.min(arr) + max_range = np.max(arr) + th = max(abs(min_range), abs(max_range)) + hist, hist_edges = np.histogram(arr, bins=8001, range=(-th, th)) + hist_dict = {'layer1' : (hist, hist_edges, min_range, max_range, th)} + min_max_dict = mx.contrib.quant._LayerHistogramCollector.get_optimal_thresholds(hist_dict, dtype) + assert 'layer1' in min_max_dict + assert_almost_equal(np.array([min_max_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) + diff --git a/tests/python/test_quantization_gpu.py b/tests/python/test_quantization_gpu.py new file mode 100644 index 000000000000..0f14fa1ac961 --- /dev/null +++ b/tests/python/test_quantization_gpu.py @@ -0,0 +1,27 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +import os +import sys +import mxnet as mx + + +curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__))) +sys.path.insert(0, os.path.join(curr_path, '../quantization')) +from mxnet.test_utils import set_default_context +from test_quantization import * + +set_default_context(mx.gpu(0)) diff --git a/tools/rec2idx.py b/tools/rec2idx.py index 111ce9ba0a22..0219d2fc2a2a 100644 --- a/tools/rec2idx.py +++ b/tools/rec2idx.py @@ -31,8 +31,8 @@ class IndexCreator(mx.recordio.MXRecordIO): Example usage: ---------- >>> creator = IndexCreator('data/test.rec','data/test.idx') - >>> record.create_index() - >>> record.close() + >>> creator.create_index() + >>> creator.close() >>> !ls data/ test.rec test.idx