diff --git a/include/mxnet/op_attr_types.h b/include/mxnet/op_attr_types.h index c936d3e84afa..0bc2a8f62daf 100644 --- a/include/mxnet/op_attr_types.h +++ b/include/mxnet/op_attr_types.h @@ -343,6 +343,19 @@ using FNeedRequantize = std::function; using FAvoidQuantizeInput = std::function< bool(const NodeAttrs& attrs, const size_t index, const std::string quantize_granularity)>; +/*! + * \brief Register a function to determine if the input of a quantized operator + * needs to be quantized asymmetrically. + */ +using FNeedAsymQuantizeInput = std::function; + +/*! + * \brief Register a function to determine if the output of a quantized operator + * needs to be dequantized. This is usually used for the quantized operators + * which can produce fp32 outputs directly. + */ +using FAvoidDequantizeOutput = std::function; + /*! * \brief Register a function to determine if the input of a quantized operator * needs to be calibrated. This is usually used for the quantized operators diff --git a/python/mxnet/amp/lists/symbol_bf16.py b/python/mxnet/amp/lists/symbol_bf16.py index dd545a778578..ed2416a477cb 100644 --- a/python/mxnet/amp/lists/symbol_bf16.py +++ b/python/mxnet/amp/lists/symbol_bf16.py @@ -119,6 +119,7 @@ '_contrib_index_copy', '_contrib_quadratic', '_contrib_quantize', + '_contrib_quantize_asym', '_contrib_quantize_v2', '_contrib_quantized_concat', '_contrib_quantized_conv', @@ -127,6 +128,7 @@ '_contrib_quantized_pooling', '_contrib_quantized_elemwise_add', '_contrib_quantized_act', + '_contrib_quantized_rnn', '_image_crop', '_linspace', '_contrib_requantize', diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py index 52a8f459f302..ad1f0ad4b293 100644 --- a/python/mxnet/amp/lists/symbol_fp16.py +++ b/python/mxnet/amp/lists/symbol_fp16.py @@ -99,6 +99,7 @@ '_contrib_index_copy', '_contrib_quadratic', '_contrib_quantize', + '_contrib_quantize_asym', '_contrib_quantize_v2', '_contrib_quantized_concat', '_contrib_quantized_conv', @@ -108,6 +109,7 @@ '_contrib_quantized_elemwise_add', '_contrib_quantized_act', '_contrib_quantized_reshape', + '_contrib_quantized_rnn', '_contrib_quantized_transpose', '_npx_quantized_reshape', '_npx_quantized_transpose', diff --git a/python/mxnet/contrib/quantization.py b/python/mxnet/contrib/quantization.py index 10d2455cb9ae..64942353db29 100644 --- a/python/mxnet/contrib/quantization.py +++ b/python/mxnet/contrib/quantization.py @@ -33,6 +33,20 @@ from ..util import is_np_array, wrap_ctx_to_device_func +def _multilist_iterator(arg, func): + """Iterate over multidiemnsional list and returns new list + with same dimensions, but applied `func` function on list elements. + E.g. _multilist_iterator([1, 2, [3, 4]], lambda x: x**2) = [1, 4, [9, 16]] + """ + ret = [] + if isinstance(arg, list): + for el in arg: + ret.append(_multilist_iterator(el, func)) + else: + return func(arg) + + return ret + def _quantize_params(qsym, params, min_max_dict): """Given a quantized symbol and a dict of params that have not been quantized, generate quantized params. Currently only supports quantizing the arg_params @@ -357,7 +371,7 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_ for batch in data: if not isinstance(batch, list): batch = [batch] - batch = [b.as_in_context(mx.cpu()) for b in batch] + batch = _multilist_iterator(batch, lambda b: b.as_in_context(mx.cpu())) sym_block(*batch[:num_inputs]) num_batches += 1 if num_calib_batches is not None and num_batches >= num_calib_batches: @@ -368,20 +382,45 @@ def _collect_layer_statistics(sym_block, data, collector, num_inputs, num_calib_ def _generate_list_of_data_desc(data_shapes, data_types): - """"Convert list ot tuples to list of DataDesc.""" - if isinstance(data_shapes, list): - if all(isinstance(x, DataDesc) for x in data_shapes): - return data_shapes - if all(isinstance(x, tuple) for x in data_shapes): - if len(data_shapes) == 1: - data_shapes = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])] + """Convert list of tuples to list of DataDesc.""" + def flatten_list(arg): + ret = [] + for el in arg: + if isinstance(el, list): + ret += flatten_list(el) else: - data_shapes = [DataDesc(name='data' + str(i), shape=data_shapes[i], - dtype=data_types[i]) for i in range(len(data_shapes))] - return data_shapes - raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple') + ret.append(el) + return ret + + flattened_data_types = flatten_list(data_types) + flattened_data_shapes = flatten_list(data_shapes) + + if all(isinstance(x, DataDesc) for x in flattened_data_shapes): + return data_shapes + + assert len(flattened_data_types) == len(flattened_data_shapes) + + # pass integral type as reference + counter = [0] + def get_data_desc(data_shape, counter=counter, data_types=flattened_data_types): + if isinstance(data_shape, DataDesc): + return data_shape + elif isinstance(data_shape, tuple): + desc = DataDesc(name='data' + str(counter[0]), shape=data_shape, + dtype=data_types[counter[0]]) + counter[0] += 1 + return desc + else: + raise ValueError('data_shapes must be either a list of DataDesc or a list of Tuple') + if len(data_shapes) == 1 and not isinstance(data_shapes[0], list): + data_descs = [DataDesc(name='data', shape=data_shapes[0], dtype=data_types[0])] + else: + data_descs = _multilist_iterator(data_shapes, get_data_desc) + + return data_descs + @wrap_ctx_to_device_func def quantize_model(sym, arg_params, aux_params, data_names=('data',), device=cpu(), excluded_sym_names=None, excluded_op_names=None, calib_mode='entropy', @@ -841,8 +880,8 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize x = iter(calib_data) batch = next(x) if isinstance(batch, list): - data_shapes = [b.shape for b in batch] - data_types = [b.dtype for b in batch] + data_shapes = _multilist_iterator(batch, lambda x: x.shape) + data_types = _multilist_iterator(batch, lambda x: x.dtype) else: data_shapes = [batch.shape] data_types = [batch.dtype] @@ -850,16 +889,15 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize raise ValueError('calib_data expects mx.gluon.data.DataLoader') if data_types is None: - data_types = [mx_real_t] * len(data_shapes) + data_types = _multilist_iterator(data_shapes, lambda x: mx_real_t) + data_descs = _generate_list_of_data_desc(data_shapes, data_types) num_inputs = len(data_descs) data_nd = [] - for desc in data_descs: - if is_np_array(): - data_nd.append(mx.np.zeros(shape=desc.shape, dtype=desc.dtype)) - else: - data_nd.append(mx.nd.zeros(shape=desc.shape, dtype=desc.dtype)) + arr_fn = mx.np if is_np_array() else mx.nd + data_nd = _multilist_iterator(data_descs, lambda d, F=arr_fn: F.zeros(shape=d.shape, dtype=d.dtype)) + while True: try: network(*data_nd) @@ -919,7 +957,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize raise ValueError( 'calib_data must be provided when calib_mode=%s' % calib_mode) if calib_mode in ['naive', 'entropy', 'custom']: - inputs = [mx.sym.var(desc.name) for desc in data_descs] + inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name)) calib_net = SymbolBlock(symnet, inputs) for k, v in calib_net.collect_params().items(): v.grad_req = 'null' @@ -939,7 +977,7 @@ def quantize_net(network, quantized_dtype='auto', quantize_mode='full', quantize else: raise ValueError('calib_mode has to be one of: naive, entropy, custom') elif calib_mode is not None and calib_mode == 'none': - inputs = [mx.sym.var(desc.name) for desc in data_descs] + inputs = _multilist_iterator(data_descs, lambda dd: mx.sym.var(dd.name)) net = SymbolBlock(qsym, inputs) for k, v in net.collect_params().items(): diff --git a/python/mxnet/io/io.py b/python/mxnet/io/io.py index 4d78cd999bae..b0a01290c4f1 100644 --- a/python/mxnet/io/io.py +++ b/python/mxnet/io/io.py @@ -643,8 +643,11 @@ def provide_data(self): @property def provide_label(self): """The name and shape of label provided by this iterator.""" + batch_axis = self.layout.find('N') return [ - DataDesc(k, tuple([self.batch_size] + list(v.shape[1:])), v.dtype) + DataDesc(k, tuple(list(v.shape[:batch_axis]) + \ + [self.batch_size] + list(v.shape[batch_axis + 1:])), + v.dtype, layout=self.layout) for k, v in self.label ] diff --git a/src/operator/nn/dnnl/dnnl_rnn-inl.h b/src/operator/nn/dnnl/dnnl_rnn-inl.h index f28753461e58..6165dfaeb4c4 100644 --- a/src/operator/nn/dnnl/dnnl_rnn-inl.h +++ b/src/operator/nn/dnnl/dnnl_rnn-inl.h @@ -32,10 +32,20 @@ #include "operator/rnn-inl.h" #include "dnnl_base-inl.h" +#include "operator/quantization/quantized_rnn-inl.h" namespace mxnet { namespace op { +struct DNNLRnnParam : public dmlc::Parameter { + bool quantized; + + DMLC_DECLARE_PARAMETER(DNNLRnnParam) { + DMLC_DECLARE_FIELD(quantized).set_default(false).describe( + "Whether it's a quantized RNN operator"); + } +}; + struct DNNLRnnLayerParam { using memory = dnnl::memory; using dims = dnnl::memory::dims; @@ -66,6 +76,10 @@ struct DNNLRnnLayerParam { size_t native_single_b_size; // bias size of a single cell from framework size_t single_state_size; // state size of a single cell, hy, cy + bool quantized; // whether this layer is quantized + bool enable_u8_output; // true by default, only be false when it is the last fusion layer of the + // quantized rnn operator + DNNLRnnLayerParam(int num_layer, index_t batch_size, index_t seq_len, @@ -82,7 +96,9 @@ struct DNNLRnnLayerParam { input_size(input_size), state_size(state_size), proj_size(proj_size), - seq_len(seq_len) {} + seq_len(seq_len), + quantized(false), + enable_u8_output(false) {} void SetDims(); }; @@ -90,10 +106,11 @@ struct DNNLRnnLayerParam { typedef std::vector LayerParamVector; struct DNNLRnnFullParam { RNNParam default_param; + DNNLRnnParam dnnl_param; LayerParamVector layer_params; }; -DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, +DNNLRnnFullParam DNNLRnnFullParamParser(const nnvm::NodeAttrs& attrs, const index_t seq_len, const index_t batch_size, const index_t input_size); @@ -105,7 +122,7 @@ class DNNLRnnMemMgr { // The memory buffer in NDArray life-cycle NDArray workspace_; // This points to the memory buffer from a NDArray - char* curr_mem; + char* curr_mem = nullptr; // The total bytes of the workspace of a DNNLRnnOp size_t mem_size = 0; // The current available memory bytes @@ -121,7 +138,7 @@ class DNNLRnnMemMgr { * \param size byte number * \param ctx Context of device enviroment */ - void Init(dim_t size, const Context& ctx); + void Init(const dim_t size, const Context& ctx); // Return the bytes number of the buffer const size_t Size() { @@ -135,6 +152,8 @@ class DNNLRnnMemMgr { dnnl::memory* Alloc(const dnnl::memory::desc& md); }; +typedef std::shared_ptr shared_dnnl_attr_t; + /* * Rnn Primitive. */ @@ -144,15 +163,15 @@ class RnnPrimitive { * lstm_forward, lbr_gru_forward, vanilla_rnn_forward */ template - static RnnPrimitive Create(Args&&... args) { + static RnnPrimitive Create(const shared_dnnl_attr_t attr, Args&&... args) { RnnPrimitive rnn_fwd_prim; auto fwd_desc = typename rnn_fwd::desc(std::forward(args)...); rnn_fwd_prim.fwd_pd_.reset( - new typename rnn_fwd::primitive_desc(fwd_desc, CpuEngine::Get()->get_engine()), - [](typename rnn_fwd::primitive_desc* pd) { - delete reinterpret_cast(pd); - }); + new typename rnn_fwd::primitive_desc( + fwd_desc, attr ? *attr : dnnl::primitive_attr(), CpuEngine::Get()->get_engine()), + [](void* pd) { delete reinterpret_cast(pd); }); auto fwd_pd = reinterpret_cast(rnn_fwd_prim.fwd_pd_.get()); + rnn_fwd_prim.attr_ = attr; rnn_fwd_prim.weights_layer_desc_ = fwd_pd->weights_layer_desc(); rnn_fwd_prim.weights_iter_desc_ = fwd_pd->weights_iter_desc(); rnn_fwd_prim.weights_proj_desc_ = fwd_pd->weights_projection_desc(); @@ -164,6 +183,7 @@ class RnnPrimitive { } RnnPrimitive() { + this->attr_ = nullptr; this->fwd_pd_ = nullptr; this->primitive_ = nullptr; this->weights_layer_desc_ = dnnl::memory::desc(); @@ -173,6 +193,7 @@ class RnnPrimitive { } RnnPrimitive(const RnnPrimitive& rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -183,6 +204,7 @@ class RnnPrimitive { RnnPrimitive& operator=(const RnnPrimitive& rnn_fwd_prim) { if (this != &rnn_fwd_prim) { + this->attr_ = rnn_fwd_prim.attr_; this->fwd_pd_ = rnn_fwd_prim.fwd_pd_; this->primitive_ = rnn_fwd_prim.primitive_; this->weights_layer_desc_ = rnn_fwd_prim.weights_layer_desc_; @@ -217,9 +239,14 @@ class RnnPrimitive { return workspace_desc_; } + const dnnl::primitive_attr& GetPrimAttr() const { + return *attr_; + } + private: std::shared_ptr fwd_pd_; std::shared_ptr primitive_; + shared_dnnl_attr_t attr_; dnnl::memory::desc weights_layer_desc_; dnnl::memory::desc weights_iter_desc_; dnnl::memory::desc weights_proj_desc_; @@ -229,7 +256,8 @@ class RnnPrimitive { RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params); + const NDArray& params, + const shared_dnnl_attr_t attr = nullptr); /* * Use this to manage memory and primitive of DNNL RNN forward inference. @@ -240,11 +268,12 @@ class DNNLRnnForward { const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) + const NDArray& params, + const shared_dnnl_attr_t attr = nullptr) : ctx_(ctx), initialized_(false), param_(layer_param), - fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params)) {} + fwd_inf_(GetRnnFwdPrim(layer_param, false, data, params, attr)) {} void SetNewDataMem(void* x, void* hx, @@ -263,6 +292,10 @@ class DNNLRnnForward { return fwd_inf_.GetPrim(); } + void ResetFwd(const NDArray& data, const NDArray& params, const shared_dnnl_attr_t& attr) { + fwd_inf_ = GetRnnFwdPrim(this->param_, false, data, params, attr); + } + const size_t GetSize() const { const size_t size = fwd_inf_.GetLayerDesc().get_size() + fwd_inf_.GetIterDesc().get_size() + fwd_inf_.GetProjDesc().get_size(); @@ -482,13 +515,13 @@ class DNNLRnnBackward { */ class DNNLRnnOp { public: - explicit DNNLRnnOp(const RNNParam& param, + explicit DNNLRnnOp(const nnvm::NodeAttrs& attrs, const int seq_len, const int batch_size, const int input_size) : initialized_(false), weights_version_(0), - full_param_(DNNLRnnFullParamParser(param, seq_len, batch_size, input_size)) {} + full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} void Forward(const OpContext& ctx, const std::vector& inputs, diff --git a/src/operator/nn/dnnl/dnnl_rnn.cc b/src/operator/nn/dnnl/dnnl_rnn.cc index 0d65eb99350d..bdda9b5e2259 100644 --- a/src/operator/nn/dnnl/dnnl_rnn.cc +++ b/src/operator/nn/dnnl/dnnl_rnn.cc @@ -33,6 +33,8 @@ namespace mxnet { namespace op { +DMLC_REGISTER_PARAMETER(DNNLRnnParam); + inline int GetRnnGatesNum(int mode) { switch (mode) { case rnn_enum::kLstm: @@ -88,13 +90,28 @@ void DNNLRnnLayerParam::SetDims() { reserve_size = 0; } -DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, +DNNLRnnFullParam DNNLRnnFullParamParser(const NodeAttrs& attrs, const index_t seq_len, const index_t batch_size, const index_t input_size) { + const RNNParam& rnn_param = nnvm::get(attrs.parsed); DNNLRnnFullParam full_param; full_param.default_param = rnn_param; - const int state_size = rnn_param.state_size; + try { + full_param.dnnl_param.Init(attrs.dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs.op->name << "(" + << "name=\"" << attrs.name << "\""; + for (const auto& k : attrs.dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + + const int state_size = rnn_param.state_size; const int proj_size = rnn_param.projection_size.has_value() ? rnn_param.projection_size.value() : -1; const int iter_size = @@ -135,15 +152,20 @@ DNNLRnnFullParam DNNLRnnFullParamParser(const RNNParam& rnn_param, false); } - // Set dims, workspace size, and state_outputs flag + // Set dims, workspace size, state_outputs, quantized and enable_u8_output flag for (auto& layer_param : layer_params) { layer_param.SetDims(); - layer_param.state_outputs = rnn_param.state_outputs; + layer_param.state_outputs = rnn_param.state_outputs; + layer_param.quantized = full_param.dnnl_param.quantized; + layer_param.enable_u8_output = true; } + // Quantized RNN operator produces kFloat32 outputs. + if (full_param.dnnl_param.quantized) + layer_params.back().enable_u8_output = false; return full_param; } -void DNNLRnnMemMgr::Init(dim_t size, const Context& ctx) { +void DNNLRnnMemMgr::Init(const dim_t size, const Context& ctx) { workspace_ = NDArray(TShape({size}), ctx, false, mshadow::kUint8); if (workspace_.data().dptr_ == nullptr) LOG(FATAL) << "oneDNN RNN operator memory allocation error."; @@ -178,39 +200,48 @@ dnnl::memory* DNNLRnnMemMgr::Alloc(const dnnl::memory::desc& md) { RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, const bool is_train, const NDArray& data, - const NDArray& params) { + const NDArray& params, + const shared_dnnl_attr_t attr) { using namespace dnnl; - using tag = dnnl::memory::format_tag; - const int mode = layer_param.mode; - memory::data_type data_type = get_dnnl_type(data.dtype()); - memory::data_type weight_type = get_dnnl_type(params.dtype()); + using tag = dnnl::memory::format_tag; + const int mode = layer_param.mode; + memory::data_type src_layer_dtype = get_dnnl_type(data.dtype()); + memory::data_type iter_dtype = get_dnnl_type(mshadow::kFloat32); + memory::data_type weight_dtype = + get_dnnl_type(layer_param.quantized ? mshadow::kInt8 : params.dtype()); + memory::data_type bias_dtype = get_dnnl_type(mshadow::kFloat32); + memory::data_type dst_layer_dtype = + get_dnnl_type((layer_param.quantized && layer_param.enable_u8_output) ? mshadow::kUint8 : + mshadow::kFloat32); + const prop_kind prop = is_train ? prop_kind::forward_training : prop_kind::forward_inference; const rnn_direction dnnl_rnn_direction = layer_param.bidirectional ? rnn_direction::bidirectional_concat : rnn_direction::unidirectional; - auto src_layer_desc = memory::desc(layer_param.src_dims, data_type, tag::tnc); - auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_type, tag::any); - auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_type, tag::any); - auto bias_desc = memory::desc(layer_param.bias_dims, data_type, tag::ldgo); - auto dst_layer_desc = memory::desc(layer_param.dst_dims, data_type, tag::tnc); - auto src_state_desc = memory::desc(layer_param.state_dims, data_type, tag::ldnc); - auto src_cell_desc = memory::desc(layer_param.cell_dims, data_type, tag::ldnc); + auto src_layer_desc = memory::desc(layer_param.src_dims, src_layer_dtype, tag::tnc); + auto weight_layer_desc = memory::desc(layer_param.weight_layer_dims, weight_dtype, tag::any); + auto weight_iter_desc = memory::desc(layer_param.weight_iter_dims, weight_dtype, tag::any); + auto bias_desc = memory::desc(layer_param.bias_dims, bias_dtype, tag::ldgo); + auto dst_layer_desc = memory::desc(layer_param.dst_dims, dst_layer_dtype, tag::tnc); + auto src_state_desc = memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc); + auto src_cell_desc = memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc); auto weight_peep_desc = memory::desc(); auto weight_proj_desc = layer_param.proj_size > 0 ? - memory::desc(layer_param.weight_proj_dims, weight_type, tag::any) : + memory::desc(layer_param.weight_proj_dims, weight_dtype, tag::any) : memory::desc(); auto dst_state_desc = layer_param.state_outputs ? - memory::desc(layer_param.state_dims, data_type, tag::ldnc) : + memory::desc(layer_param.state_dims, iter_dtype, tag::ldnc) : memory::desc(); auto dst_cell_desc = layer_param.state_outputs ? - memory::desc(layer_param.cell_dims, data_type, tag::ldnc) : + memory::desc(layer_param.cell_dims, iter_dtype, tag::ldnc) : memory::desc(); auto fwd = RnnPrimitive(); switch (mode) { case rnn_enum::kLstm: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, dnnl_rnn_direction, src_layer_desc, src_state_desc, @@ -225,7 +256,8 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, dst_cell_desc); break; case rnn_enum::kGru: - fwd = RnnPrimitive::Create(prop, + fwd = RnnPrimitive::Create(attr, + prop, dnnl_rnn_direction, src_layer_desc, src_state_desc, @@ -238,6 +270,7 @@ RnnPrimitive GetRnnFwdPrim(const DNNLRnnLayerParam& layer_param, case rnn_enum::kRnnRelu: case rnn_enum::kRnnTanh: fwd = RnnPrimitive::Create( + attr, prop, mode == rnn_enum::kRnnTanh ? algorithm::eltwise_tanh : algorithm::eltwise_relu, dnnl_rnn_direction, @@ -449,11 +482,19 @@ void DNNLRnnForward::SetNewDataMem(void* x, auto& cpu_engine = CpuEngine::Get()->get_engine(); dnnl_args_map_t& args = net_args_; + int src_dtype = dtype; + int dst_dtype = dtype; + if (param_.quantized) { + src_dtype = mshadow::kUint8; + if (param_.enable_u8_output) + dst_dtype = mshadow::kUint8; + } + RNN_HANDLE_FUNC(RNN_HANDLE_FUNC_NAME); // Set various data memory - RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, dtype); - RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dtype); + RNN_FWD_SET(SRC, param_.src_dims, format_tag::tnc, x, src_dtype); + RNN_FWD_SET(DST, param_.dst_dims, format_tag::tnc, y, dst_dtype); RNN_FWD_SET(SRC_ITER, param_.state_dims, format_tag::ldnc, hx, dtype); if (param_.state_outputs) { @@ -495,10 +536,25 @@ inline void DNNLMemoryReorder(const dnnl::memory& src, const dnnl::memory& dst) * with primitive-prefered format. */ void DNNLRnnForward::ReorderWeights() { - DNNLMemoryReorder(*weights_layer_r_, *weights_layer_); - DNNLMemoryReorder(*weights_iter_r_, *weights_iter_); - if (param_.proj_size > 0) - DNNLMemoryReorder(*weights_proj_r_, *weights_proj_); + if (param_.quantized) { + const dnnl::primitive_attr& attr = this->fwd_inf_.GetPrimAttr(); + auto ReorderWithAttr = [&](dnnl::memory& src, dnnl::memory& dst) { + auto reorder_pd = dnnl::reorder::primitive_desc(src, dst, attr); + dnnl_args_map_t net_args; + net_args[DNNL_ARG_SRC] = src; + net_args[DNNL_ARG_DST] = dst; + DNNLStream::Get()->RegisterPrimArgs(dnnl::reorder(reorder_pd), net_args); + }; + ReorderWithAttr(*weights_layer_r_, *weights_layer_); + ReorderWithAttr(*weights_iter_r_, *weights_iter_); + if (param_.proj_size > 0) + ReorderWithAttr(*weights_proj_r_, *weights_proj_); + } else { + DNNLMemoryReorder(*weights_layer_r_, *weights_layer_); + DNNLMemoryReorder(*weights_iter_r_, *weights_iter_); + if (param_.proj_size > 0) + DNNLMemoryReorder(*weights_proj_r_, *weights_proj_); + } } void AdjustGruGateOrder(char* weight, @@ -573,7 +629,7 @@ inline void EmplaceNetArgs(dnnl_args_map_t* net_args, const int arg_name, const */ void DNNLRnnForward::SetWeightsMem(void* w_ptr, void* b_ptr, const bool is_train, const int dtype) { using format_tag = dnnl::memory::format_tag; - auto dnnl_dtype = get_dnnl_type(dtype); + const auto dnnl_dtype = get_dnnl_type(dtype); const size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); const size_t buffer_bytes = @@ -702,7 +758,7 @@ void DNNLRnnForward::SetWeightsMem(void* w_ptr, void* b_ptr, const bool is_train // in forward training path, we use plain memory (ldxxx) as the space for weights and // their gradients. Then, forward training primitives could fetch them from the scope // of forward inference. And from there, we don't need to reorder the plain memory to - // the optimal rnn-packed memory for forward inference. + // the optimal rnn-packed memory for forward inference ReorderWeights(); initialized_ = true; } @@ -764,6 +820,19 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, const std::vector& outputs) { using format_tag = dnnl::memory::format_tag; + // Get the bytes of a real type + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + // In the `autograd.record()` context, RNNOp is required to run into // `forward_training` mode. const bool is_training = (op_ctx.is_train || op_ctx.need_grad); @@ -772,7 +841,7 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, if (fwd_inf_vec_.size() < num_fusion) { for (auto& layer_param : full_param_.layer_params) { fwd_inf_vec_.emplace_back( - ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams]); + ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], nullptr); } } @@ -783,19 +852,6 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, } } - // Get the bytes of a real type - const NDArray& weights = inputs[rnn_enum::kParams]; - int dtype = weights.dtype(); - size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); - - const RNNParam& default_param = full_param_.default_param; - char* weights_ptr = static_cast(weights.data().dptr_); - char* bias_ptr = - weights_ptr + (weights.data().Size() - GetRnnBiasSize(default_param.num_layers, - default_param.state_size, - default_param.bidirectional + 1, - default_param.mode)) * - dtype_bytes; for (auto& fwd_layer : fwd_inf_vec_) { size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; @@ -819,7 +875,7 @@ void DNNLRnnOp::Init(const OpContext& op_ctx, CHECK_EQ(num_fusion, fwd_inf_vec_.size()) << "Layer vector's size has a different value than the number of fusion."; if (dst_.size() < num_fusion - 1) { - int data_dtype = outputs[rnn_enum::kOut].dtype(); + const int data_dtype = outputs[rnn_enum::kOut].dtype(); const size_t data_dbytes = mshadow::mshadow_sizeof(data_dtype); mgr_.Init((outputs[rnn_enum::kOut].data().Size() * data_dbytes + kDNNLAlign) * (num_fusion - 1), op_ctx.run_ctx.ctx); @@ -1121,7 +1177,7 @@ void DNNLRnnOp::Forward(const OpContext& ctx, } // Get data type - int data_dtype = inputs[rnn_enum::kData].dtype(); + int data_dtype = outputs[rnn_enum::kOut].dtype(); // Get temporary memory for output, state_out, statecell_out const int num_layers = default_param.num_layers; const int seq_length = default_param.seq_length_; diff --git a/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h new file mode 100644 index 000000000000..9bbbd2d9eb54 --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantize_asym-inl.h @@ -0,0 +1,161 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation using DNNL + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ +#if MXNET_USE_ONEDNN == 1 + +#include +#include +#include "operator/nn/dnnl/dnnl_base-inl.h" +#include "operator/quantization/quantize_asym-inl.h" + +namespace mxnet { +namespace op { + +class DNNLQuantizeAsymOp { + public: + explicit DNNLQuantizeAsymOp(const nnvm::NodeAttrs& attrs) + : param_(nnvm::get(attrs.parsed)) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + QuantizeAsymParam param_; + bool initialized_{false}; + float cached_scale_{0.f}; + float cached_shift_{0.f}; + dnnl::memory::desc o_desc_; + dnnl_args_map_t args_; + std::shared_ptr fwd_pd_; +}; + +void DNNLQuantizeAsymOp::Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + NDArray in_buffer = inputs[0]; + float scale = 0.f; + float shift = 0.f; + + // Pass through quantized data + if (inputs[0].dtype() == mshadow::kUint8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 0; + if (req[0] != kWriteInplace) { + const_cast(outputs[0]).CopyFrom(*inputs[0].GetDNNLData()); + DNNLStream::Get()->Submit(); + } + } else { + in_buffer = inputs[0].Reorder2Default(); + const dnnl::memory* i_mem = in_buffer.GetDNNLData(); + float* in_ptr = in_buffer.data().dptr(); + const int nthreads = engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + if (inputs[0].dtype() == mshadow::kInt8) { + *outputs[1].data().dptr() = 1; + *outputs[2].data().dptr() = 128; +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); ++i) { + in_ptr[i] += 128.0f; + } + } else if (inputs[0].dtype() == mshadow::kFloat32) { + if (param_.min_calib_range.has_value() && param_.max_calib_range.has_value()) { + scale = + MaxValue() / (param_.max_calib_range.value() - param_.min_calib_range.value()); + shift = MaxValue() - param_.max_calib_range.value() * scale; + } else { + float data_min = mshadow::red::limits::MaxValue(); + float data_max = mshadow::red::limits::MinValue(); + std::vector data_maxs(nthreads, data_max); + std::vector data_mins(nthreads, data_min); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(in_buffer.shape().Size()); i++) { + int tid = omp_get_thread_num(); + if (in_ptr[i] > data_maxs[tid]) + data_maxs[tid] = in_ptr[i]; + if (in_ptr[i] < data_mins[tid]) + data_mins[tid] = in_ptr[i]; + } + for (index_t i = 0; i < nthreads; i++) { + if (data_maxs[i] > data_max) + data_max = data_maxs[i]; + if (data_mins[i] < data_min) + data_min = data_mins[i]; + } + scale = MaxValue() / (data_max - data_min); + shift = MaxValue() - data_max * scale; + } + + if (initialized_ && (cached_scale_ != scale || cached_shift_ != shift)) + initialized_ = false; + } + + *outputs[1].data().dptr() = scale; + *outputs[2].data().dptr() = shift; + + if (!initialized_) { + cached_scale_ = scale; + cached_shift_ = shift; + dnnl::primitive_attr attr; + attr.set_rnn_data_qparams(scale, shift); + const dnnl::engine& cpu_engine = mxnet::CpuEngine::Get()->get_engine(); + const dnnl::memory::desc& i_desc = i_mem->get_desc(); + o_desc_ = i_desc; + o_desc_.data.data_type = get_dnnl_type_t(outputs[0].dtype()); + dnnl::reorder::primitive_desc reorder_pd(cpu_engine, i_desc, cpu_engine, o_desc_, attr); + fwd_pd_ = std::make_shared(reorder_pd); + initialized_ = true; + } + dnnl_output_t o_mem = CreateDNNLMem(outputs[0], o_desc_, req[0]); + args_[DNNL_ARG_FROM] = *i_mem; + args_[DNNL_ARG_TO] = *o_mem.second; + DNNLStream::Get()->RegisterPrimArgs(*fwd_pd_, args_); + CommitOutput(outputs[0], o_mem); + DNNLStream::Get()->Submit(); + } +} + +void DNNLQuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + if (inputs[0].shape().ndim() == 3 && inputs[0].dtype() == mshadow::kFloat32) { + DNNLQuantizeAsymOp& op = state_ptr.get_state(); + op.Forward(ctx, inputs, req, outputs); + } else { + FallBackCompute(QuantizeAsymForward, state_ptr, ctx, inputs, req, outputs); + } +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h new file mode 100644 index 000000000000..cdd5417e3ea3 --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn-inl.h @@ -0,0 +1,82 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ + +#if MXNET_USE_ONEDNN == 1 + +#include +#include "operator/nn/dnnl/dnnl_rnn-inl.h" +#include "operator/rnn-inl.h" +#include "operator/quantization/quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +class DNNLQuantizedRnnOp { + public: + explicit DNNLQuantizedRnnOp(const nnvm::NodeAttrs& attrs, + const int seq_len, + const int batch_size, + const int input_size) + : initialized_(false), + weights_ver_(0), + rnn_attr_(new dnnl::primitive_attr), + full_param_(DNNLRnnFullParamParser(attrs, seq_len, batch_size, input_size)) {} + + void Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); + + private: + bool initialized_; + size_t weights_ver_; + shared_dnnl_attr_t rnn_attr_; + DNNLRnnFullParam full_param_; + DNNLRnnMemMgr mgr_; + std::vector fwd_inf_vec_; // forward inference layers + + // Used to store the intermediate results of multi-layer + std::vector dst_; + // According to + // https://oneapi-src.github.io/oneDNN/dev_guide_int8_computations.html, the + // non-symmetric quantization is assumed by LSTM primitive. Namely, the + // formula is: + // data_f32 = (data_u8 - shift) / scale + float cached_data_shift_{0.0}; + float cached_data_scale_{0.0}; + void Init(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs); +}; + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 +#endif // MXNET_OPERATOR_QUANTIZATION_DNNL_DNNL_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc new file mode 100644 index 000000000000..73393d9b4c36 --- /dev/null +++ b/src/operator/quantization/dnnl/dnnl_quantized_rnn.cc @@ -0,0 +1,366 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file dnnl_quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#if MXNET_USE_ONEDNN == 1 + +#include "operator/quantization/quantization_utils.h" +#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h" + +namespace mxnet { +namespace op { + +std::vector GetDNNLRnnWeightsQParams(const DNNLRnnFullParam& full_param, float* w_ptr) { + const int nthreads = mxnet::engine::OpenMP::Get()->GetRecommendedOMPThreadCount(); + const int num_gates = 4; + const RNNParam& default_param = full_param.default_param; + const LayerParamVector& layer_params = full_param.layer_params; + + const DNNLRnnLayerParam& layer_param0 = layer_params.at(0); + const size_t w_size0 = layer_param0.single_w_size; + const size_t wx_size0 = num_gates * layer_param0.state_size * layer_param0.input_size; + const size_t wh_size0 = num_gates * layer_param0.state_size * layer_param0.state_size; + + int directions = 1; + float* wx = w_ptr; + float* wh = wx + wx_size0; + float* fake_wx = wx; + float* fake_wh = wh; + + std::vector wx_goi_max; + std::vector wh_goi_max; + if (default_param.bidirectional) { + directions = 2; + wx_goi_max.resize(wx_size0); + wh_goi_max.resize(wh_size0); + fake_wx = wx_goi_max.data(); + fake_wh = wh_goi_max.data(); +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wx_size0); ++i) { + fake_wx[i] = MaxAbs(wx[i], wx[i + w_size0]); + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + fake_wh[i] = MaxAbs(wh[i], wh[i + w_size0]); + } + } + std::vector w_max(num_gates * layer_param0.state_size, 0.0); + const index_t input_size = layer_param0.input_size; // input + const index_t state_size = layer_param0.state_size; // state + const index_t gates_nblks = num_gates * layer_param0.state_size; // gates * state + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp_max = w_max[go]; + for (index_t i = 0; i < input_size; ++i) { + tmp_max = MaxAbs(fake_wx[go * input_size + i], tmp_max); + } + for (index_t i = 0; i < state_size; ++i) { + tmp_max = MaxAbs(fake_wh[go * state_size + i], tmp_max); + } + w_max[go] = tmp_max; + } + wx += layer_param0.single_w_size * directions; + wh += layer_param0.single_w_size * directions; + + std::vector goi_max(wh_size0, 0.0); + for (size_t lyr = 1; lyr < layer_params.size(); ++lyr) { + const DNNLRnnLayerParam& layer_param = layer_params.at(lyr); + const int weight_nblks = layer_param.num_layer * directions; + for (int blk = 0; blk < weight_nblks; ++blk) { +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(wh_size0); ++i) { + goi_max[i] = MaxAbs(wx[i], wh[i]); + } + for (index_t go = 0; go < gates_nblks; ++go) { + float tmp = w_max[go]; +// NOTES: min/max reductions were supported since OpenMP 3.1, which was +// released in Jul 2011 (hence the version number). +#if _OPENMP >= 201107 +#pragma omp parallel for reduction(max : tmp) num_threads(nthreads) +#endif + for (index_t i = 0; i < state_size; ++i) { + tmp = Max(goi_max[go * state_size + i], tmp); + } + w_max[go] = tmp; + } + } + wx += layer_param.single_w_size * directions; + wh = wx + wh_size0; + } +#pragma omp parallel for num_threads(nthreads) + for (index_t i = 0; i < static_cast(w_max.size()); ++i) { + w_max[i] = mshadow::red::limits::MaxValue() / w_max[i]; + } + return w_max; +} + +void DNNLQuantizedRnnOp::Init(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using format_tag = dnnl::memory::format_tag; + + // Get the bytes of a real type + const Context& ctx = op_ctx.run_ctx.ctx; + const NDArray& weights = inputs[rnn_enum::kParams]; + int dtype = weights.dtype(); + int weights_dtype = weights.dtype(); + size_t dtype_bytes = mshadow::mshadow_sizeof(dtype); + const RNNParam& default_param = full_param_.default_param; + const size_t weights_size = + weights.data().Size() - GetRnnBiasSize(default_param.num_layers, + default_param.state_size, + default_param.bidirectional + 1, + default_param.mode); + char* weights_ptr = static_cast(weights.data().dptr_); + char* bias_ptr = weights_ptr + weights_size * dtype_bytes; + + // In the `autograd.record()` context, RNNOp is required to run into + // `forward_training` mode. + + const size_t num_fusion = full_param_.layer_params.size(); + if (fwd_inf_vec_.size() < num_fusion) { + size_t buffer_size = 0; // Element number, instead of bytes, in the buffer + for (auto& layer_param : full_param_.layer_params) { + buffer_size += layer_param.workspace_size + layer_param.reserve_size; + } + buffer_size += outputs[rnn_enum::kOut].data().Size() * (num_fusion - 1); + buffer_size += kDNNLAlign * num_fusion * 5; // Add margin for alignment + + for (auto& layer_param : full_param_.layer_params) { + fwd_inf_vec_.emplace_back( + ctx, layer_param, false, inputs[rnn_enum::kData], inputs[rnn_enum::kParams], rnn_attr_); + buffer_size += fwd_inf_vec_.back().GetSize(); + } + mgr_.Init(buffer_size, ctx); + } + + for (auto& fwd_layer : fwd_inf_vec_) { + size_t single_w_bytes = fwd_layer.GetParam().single_w_size * dtype_bytes; + size_t single_b_bytes = fwd_layer.GetParam().native_single_b_size * dtype_bytes; + size_t directions = fwd_layer.GetParam().bidirectional ? 2 : 1; + size_t layer_weights_bytes = single_w_bytes * directions; + size_t layer_bias_bytes = single_b_bytes * directions; // Native MXNet has double bias + + if (!fwd_layer.IsInitialized()) + fwd_layer.SetWeightsMem(weights_ptr, bias_ptr, false, weights_dtype); + weights_ptr += layer_weights_bytes; + bias_ptr += layer_bias_bytes; + } + + CHECK_EQ(num_fusion, fwd_inf_vec_.size()) + << "Layer vector's size has a different value than the number of fusion."; + if (dst_.size() < num_fusion - 1) { + const int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Here we need `fwd_inf_vec_.size() - 1` spaces for the intermediate + // results of the multiple fused layers. And for the result of the last + // fused layer, `outputs[rnn_enum::kOut]` could provide the space. Hence, + // `forward_inf_vec_.back()` is excluded when allocates the spaces for + // intermediate results. + for (std::vector::const_iterator fwd = fwd_inf_vec_.begin(); + fwd != fwd_inf_vec_.end() - 1; + ++fwd) + dst_.push_back( + mgr_.Alloc({fwd->GetParam().dst_dims, get_dnnl_type(data_dtype), format_tag::tnc})); + } + + initialized_ = true; +} + +template +inline void RegisterDNNLRnn(DNNLRnnX const& rnn) { + DNNLStream::Get()->RegisterPrimArgs(rnn.GetFwd(), rnn.GetArgsMap()); +} + +template <> +inline void RegisterDNNLRnn(DNNLRnnBackward const& rnn) { + DNNLStream::Get()->RegisterPrimArgs(rnn.GetBwd(), rnn.GetArgsMap()); + rnn.SetNativeWeightsGrads(); +} + +void DNNLQuantizedRnnOp::Forward(const OpContext& op_ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + TmpMemMgr::Get()->Init(op_ctx.requested[0]); + + const RNNParam& default_param = full_param_.default_param; + const uint32_t num_base_inputs = GetRnnNumInputs(default_param); + float data_scale = inputs[num_base_inputs + quantized_rnn::kDataScale].data().dptr()[0]; + float data_shift = inputs[num_base_inputs + quantized_rnn::kDataShift].data().dptr()[0]; + + const bool need_reset_weight = (!dmlc::GetEnv("MXNET_RNN_USE_WEIGHT_CACHE", 0) && + weights_ver_ != inputs[rnn_enum::kParams].version()) ? + true : + false; + const NDArray& weights = inputs.at(rnn_enum::kParams); + float* weights_ptr = weights.data().dptr(); + if (!initialized_ || fwd_inf_vec_.empty()) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + rnn_attr_->set_rnn_data_qparams(data_scale, data_shift); + if (need_reset_weight || fwd_inf_vec_.empty()) + rnn_attr_->set_rnn_weights_qparams(0 + (1 << 3) + (1 << 4), + GetDNNLRnnWeightsQParams(full_param_, weights_ptr)); + } + + // Initialize weights version + if (!initialized_ && weights_ver_ == 0) { + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!fwd_inf_vec_.empty() && + ((cached_data_scale_ != data_scale || cached_data_shift_ != data_shift))) { + initialized_ = false; + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + // Check if weights NDArray was changed. If so, reset initialized_ + if (fwd_inf_vec_.size() > 0 && weights_ver_ != inputs[rnn_enum::kParams].version()) { + initialized_ = false; + for (auto& fwd : fwd_inf_vec_) + fwd.Reset(); + weights_ver_ = inputs[rnn_enum::kParams].version(); + cached_data_scale_ = data_scale; + cached_data_shift_ = data_shift; + } + + if (!initialized_ || fwd_inf_vec_.empty()) { + Init(op_ctx, inputs, req, outputs); + } + + // Get data type + int data_dtype = outputs[rnn_enum::kOut].dtype(); + // Get temporary memory for output, state_out, statecell_out + const int num_layers = default_param.num_layers; + const int seq_length = default_param.seq_length_; + const int batch_size = default_param.batch_size_; + const int state_size = default_param.state_size; + const int directions = default_param.bidirectional ? 2 : 1; + dnnl::memory::desc dst_desc({seq_length, batch_size, directions * state_size}, + get_dnnl_type(data_dtype), + dnnl::memory::format_tag::tnc); + dnnl::memory::desc state_desc({num_layers, directions, batch_size, state_size}, + get_dnnl_type(data_dtype), + dnnl::memory::format_tag::ldnc); + auto out_mem = CreateDNNLMem(outputs[rnn_enum::kOut], dst_desc, req[rnn_enum::kOut]); + dnnl_output_t stateout_mem; + dnnl_output_t statecellout_mem; + + // Get input & output NDArray + char* src = static_cast(inputs[rnn_enum::kData].data().dptr_); + char* src_state = static_cast(inputs[rnn_enum::kState].data().dptr_); + char* dst = static_cast(out_mem.second->get_data_handle()); + char* dst_state = nullptr; // Output state + char* src_state_cell = nullptr; // Used in LSTM for cell state + char* dst_state_cell = nullptr; // Used in LSTM for cell state + const size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + if (default_param.state_outputs && req[rnn_enum::kStateOut] != kNullOp) { + stateout_mem = + CreateDNNLMem(outputs[rnn_enum::kStateOut], state_desc, req[rnn_enum::kStateOut]); + dst_state = static_cast(stateout_mem.second->get_data_handle()); + } + + if (default_param.mode == rnn_enum::kLstm) { + src_state_cell = static_cast(inputs[rnn_enum::kStateCell].data().dptr_); + if (default_param.state_outputs && req[rnn_enum::kStateCellOut] != kNullOp) { + statecellout_mem = + CreateDNNLMem(outputs[rnn_enum::kStateCellOut], state_desc, req[rnn_enum::kStateCellOut]); + dst_state_cell = static_cast(statecellout_mem.second->get_data_handle()); + } + } + + if (fwd_inf_vec_.size() == 1) { + fwd_inf_vec_.front().SetNewDataMem( + src, src_state, src_state_cell, dst, dst_state, dst_state_cell, data_dtype); + } else { + CHECK_EQ(fwd_inf_vec_.size(), dst_.size() + 1) << "Output memory error."; + size_t cell_bytes = (default_param.bidirectional + 1) * default_param.batch_size_ * + default_param.state_size * mshadow::mshadow_sizeof(data_dtype); + + // Set input data memory for the first layer. This stores intermediate + // output results in this->xxx, used as the source input of the next layer. + fwd_inf_vec_.front().SetNewDataMem(src, + src_state, + src_state_cell, + this->dst_.front()->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + // 1st_lyr -> dst_handle -> next_lyr -> dst_handle -> next_lyr -> ... + for (size_t lyr = 1; lyr < fwd_inf_vec_.size() - 1; ++lyr) { + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.at(lyr).SetNewDataMem(this->dst_.at(lyr - 1)->get_data_handle(), + src_state, + src_state_cell, + this->dst_.at(lyr)->get_data_handle(), + dst_state, + dst_state_cell, + data_dtype); + } + // Set output data memory for the last layer. + src_state += cell_bytes; + if (src_state_cell) + src_state_cell += cell_bytes; + if (dst_state) + dst_state += cell_bytes; + if (dst_state_cell) + dst_state_cell += cell_bytes; + fwd_inf_vec_.back().SetNewDataMem(this->dst_.back()->get_data_handle(), + src_state, + src_state_cell, + dst, + dst_state, + dst_state_cell, + data_dtype); + } + + for (auto& inf_lyr : fwd_inf_vec_) + RegisterDNNLRnn(inf_lyr); + + CommitOutput(outputs[rnn_enum::kOut], out_mem); + if (default_param.state_outputs) { + CommitOutput(outputs[rnn_enum::kStateOut], stateout_mem); + if (default_param.mode == rnn_enum::kLstm) + CommitOutput(outputs[rnn_enum::kStateCellOut], statecellout_mem); + } + DNNLStream::Get()->Submit(); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_USE_ONEDNN == 1 diff --git a/src/operator/quantization/quantize_asym-inl.h b/src/operator/quantization/quantize_asym-inl.h new file mode 100644 index 000000000000..3aa44c4e4fd6 --- /dev/null +++ b/src/operator/quantization/quantize_asym-inl.h @@ -0,0 +1,177 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym-inl.h + * \brief implementation of asymmetric quantize operation + */ +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ + +#include +#include +#include +#include +#include + +#include "../mshadow_op.h" +#include "../mxnet_op.h" +#include "../tensor/broadcast_reduce_op.h" +#include "./quantization_utils.h" + +namespace mxnet { +namespace op { + +struct QuantizeAsymParam : public dmlc::Parameter { + dmlc::optional min_calib_range; + dmlc::optional max_calib_range; + + DMLC_DECLARE_PARAMETER(QuantizeAsymParam) { + DMLC_DECLARE_FIELD(min_calib_range) + .set_default(dmlc::optional()) + .describe( + "The minimum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + DMLC_DECLARE_FIELD(max_calib_range) + .set_default(dmlc::optional()) + .describe( + "The maximum scalar value in the form of float32. If " + "present, it will be used to " + "quantize the fp32 data."); + } +}; + +// quantize float to uint8_t +struct quantize_asymmetric { + template + MSHADOW_XINLINE static void Map(int i, + DstDType* out, + float* oscale, + float* oshift, + const SrcDType* in, + const float scale, + const float shift) { + out[i] = static_cast(in[i] * scale + shift + 0.5); + *oscale = scale; + *oshift = shift; + } +}; + +template +class QuantizeAsymOp { + public: + explicit QuantizeAsymOp(const nnvm::NodeAttrs& attrs) : attrs_(attrs) {} + + void Forward(const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + using namespace mshadow; + using namespace mxnet_op; + using mshadow::red::limits::MaxValue; + using mshadow::red::limits::MinValue; + + CHECK_EQ(outputs[0].type_flag_, mshadow::kUint8) + << "Asymmetric quantization only supports uint8 outputs."; + mshadow::Stream* s = ctx.get_stream(); + const int input_data_dtype = inputs[0].type_flag_; + if (input_data_dtype == mshadow::kUint8) { + *outputs[1].dptr() = 1; + *outputs[2].dptr() = 0; + UnaryOp::IdentityCompute(attrs_, ctx, {inputs[0]}, req, outputs); + } else if (input_data_dtype == mshadow::kInt8) { + const float scale = 1; + const float shift = 128; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else if (input_data_dtype == mshadow::kFloat32) { + const QuantizeAsymParam& param = nnvm::get(attrs_.parsed); + if (param.min_calib_range.has_value() && param.max_calib_range.has_value()) { + const float scale = + MaxValue() / (param.max_calib_range.value() - param.min_calib_range.value()); + const float shift = MaxValue() - param.max_calib_range.value() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } else { + mxnet::TShape src_shape, dst_shape; + const size_t float_bytes = sizeof(float); + const size_t temp_reduce_size = ConfigReduce( + s, inputs[0].shape_, mxnet::TShape(1, 1), &src_shape, &dst_shape); + Tensor temp_space = ctx.requested[0].get_space_typed( + Shape1(2 * float_bytes + temp_reduce_size), s); + const int dev_id = ctx.run_ctx.ctx.dev_id; + TBlob in_min_t( + reinterpret_cast(temp_space.dptr_), Shape1(1), xpu::kDevMask, dev_id); + TBlob in_max_t( + reinterpret_cast(temp_space.dptr_) + 1, Shape1(1), xpu::kDevMask, dev_id); + Tensor workspace( + temp_space.dptr_ + 2 * float_bytes, Shape1(temp_reduce_size), s); + broadcast::Reduce( + s, in_min_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + broadcast::Reduce( + s, in_max_t.reshape(dst_shape), kWriteTo, workspace, inputs[0].reshape(src_shape)); + const float scale = + MaxValue() / (*in_max_t.dptr() - *in_min_t.dptr()); + const float shift = MaxValue() - *in_max_t.dptr() * scale; + Kernel::Launch(s, + outputs[0].Size(), + outputs[0].dptr(), + outputs[1].dptr(), + outputs[2].dptr(), + inputs[0].dptr(), + scale, + shift); + } + } else { + LOG(FATAL) << "Asymmetric quantizaiton only supports int8, uint8 and " + "float inputs"; + } + } + + private: + nnvm::NodeAttrs attrs_; +}; + +template +void QuantizeAsymForward(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& inputs, + const std::vector& req, + const std::vector& outputs) { + QuantizeAsymOp& op = state_ptr.get_state>(); + op.Forward(ctx, inputs, req, outputs); +} + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZE_ASYM_INL_H_ diff --git a/src/operator/quantization/quantize_asym.cc b/src/operator/quantization/quantize_asym.cc new file mode 100644 index 000000000000..4cb2669cd1c9 --- /dev/null +++ b/src/operator/quantization/quantize_asym.cc @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantize_asym.cc + * \brief implementation of asymmetric quantize operation + */ + +#include + +#include "operator/quantization/quantize_asym-inl.h" +#if MXNET_USE_ONEDNN == 1 +#include "operator/quantization/dnnl/dnnl_quantize_asym-inl.h" +#endif + +namespace mxnet { +namespace op { + +DMLC_REGISTER_PARAMETER(QuantizeAsymParam); + +inline bool QuantizeAsymShape(const nnvm::NodeAttrs& attrs, + mxnet::ShapeVector* in_attrs, + mxnet::ShapeVector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + mxnet::TShape dshape = in_attrs->at(0); + SHAPE_ASSIGN_CHECK(*out_attrs, 0, dshape); + SHAPE_ASSIGN_CHECK(*out_attrs, 1, TShape(1, 1)); + SHAPE_ASSIGN_CHECK(*out_attrs, 2, TShape(1, 1)); + + if (out_attrs->at(0).ndim() > 0) { + dshape[0] = out_attrs->at(0)[0]; + SHAPE_ASSIGN_CHECK(*in_attrs, 0, dshape); + } + + return !shape_is_none(out_attrs->at(0)); +} + +inline bool QuantizeAsymType(const nnvm::NodeAttrs& attrs, + std::vector* in_attrs, + std::vector* out_attrs) { + CHECK_EQ(in_attrs->size(), 1U); + CHECK_EQ(out_attrs->size(), 3U); + + CHECK_EQ(in_attrs->at(0), mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_attrs, 0, mshadow::kUint8); + TYPE_ASSIGN_CHECK(*out_attrs, 1, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*out_attrs, 2, mshadow::kFloat32); + + return !type_is_none(out_attrs->at(0)); +} + +bool QuantizeAsymStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + *dispatch_mode = DispatchMode::kFCompute; +#if MXNET_USE_ONEDNN == 1 + if (dev_mask == mshadow::cpu::kDevMask) { + *dispatch_mode = DispatchMode::kFComputeEx; + } +#endif + out_attrs->at(0) = kDefaultStorage; + out_attrs->at(1) = kDefaultStorage; + out_attrs->at(2) = kDefaultStorage; + return true; +} + +OpStatePtr CreateQuantizeAsymState(const nnvm::NodeAttrs& attrs, + const Context& ctx, + const std::vector& in_shapes, + const std::vector& in_types) { + OpStatePtr state; + if (ctx.dev_type == kGPU) { + state = OpStatePtr::Create>(attrs); + } else { +#if MXNET_USE_ONEDNN == 1 + if (in_shapes[0].ndim() == 3 && in_types[0] == mshadow::kFloat32) { + state = OpStatePtr::Create(attrs); + return state; + } +#else + state = OpStatePtr::Create>(attrs); +#endif + } + return state; +} + +NNVM_REGISTER_OP(_contrib_quantize_asym) + .describe(R"code(Quantize a input tensor from float to uint8_t. +Output `scale` and `shift` are scalar floats that specify the quantization +parameters for the input data. The output is calculated using the following equation: + +`out[i] = in[i] * scale + shift + 0.5`, + +where `scale = uint8_range / (max_range - min_range)` and +`shift = numeric_limits::max - max_range * scale`. + +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_attr_parser(ParamParser) + .set_num_inputs(1) + .set_num_outputs(3) + .set_attr("FListInputNames", + [](const NodeAttrs& attrs) { + return std::vector{"data"}; + }) + .set_attr("FListOutputNames", + [](const NodeAttrs& attrs) { + return std::vector{"output", "scale", "shift"}; + }) + .set_attr("FInferShape", QuantizeAsymShape) + .set_attr("FInferType", QuantizeAsymType) + .set_attr("FInferStorageType", QuantizeAsymStorageType) + .set_attr("FGradient", MakeZeroGradNodes) + .set_attr("FCreateOpState", CreateQuantizeAsymState) +#if MXNET_USE_ONEDNN == 1 + .set_attr("TIsDNNL", true) + .set_attr("FStatefulComputeEx", DNNLQuantizeAsymForward) +#endif + .set_attr("FStatefulCompute", QuantizeAsymForward) + .set_attr("FNeedCalibrateInput", + [](const NodeAttrs& attrs) { return std::vector{0}; }) + .set_attr("FResourceRequest", + [](const NodeAttrs& attrs) { + const QuantizeAsymParam& param = + nnvm::get(attrs.parsed); + if (param.max_calib_range.has_value() && + param.max_calib_range.has_value()) { + return std::vector(); + } else { + return std::vector( + 1, ResourceRequest::kTempSpace); + } + }) + .add_argument("data", "NDArray-or-Symbol", "A ndarray/symbol of type `float32`") + .add_arguments(QuantizeAsymParam::__FIELDS__()); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/quantization/quantize_graph_pass.cc b/src/operator/quantization/quantize_graph_pass.cc index 3835f1a3a9c9..a4e3086653b0 100644 --- a/src/operator/quantization/quantize_graph_pass.cc +++ b/src/operator/quantization/quantize_graph_pass.cc @@ -288,6 +288,10 @@ Graph QuantizeGraph(Graph&& src) { static const auto& avoid_quantize_input_map = Op::GetAttr("FAvoidQuantizeInput"); static const auto& flist_inputs = nnvm::Op::GetAttr("FListInputNames"); + static const auto& avoid_dequantize_map = + Op::GetAttr("FAvoidDequantizeOutput"); + static const auto& need_asym_quantize_map = + Op::GetAttr("FNeedAsymQuantizeInput"); const auto offline_params = src.GetAttr>("offline_params"); const auto quantized_dtype = src.GetAttr("quantized_dtype"); const auto quantize_granularity = src.GetAttr("quantize_granularity"); @@ -331,7 +335,14 @@ Graph QuantizeGraph(Graph&& src) { if (avoid_quantize_input_map.count(node->op()) && avoid_quantize_input_map[node->op()](node->attrs, i, quantize_granularity)) { new_node->inputs.emplace_back(mirror_entry); - } else if (!quantized_node_map.count(e.node)) { + } else if (!quantized_node_map.count(e.node) || + (avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { + // If the input of current quantized node has non-support of quantization, a quantize op + // is supposed to insert into the position after the input node to quantize the float + // input to int8/uint8 type. Also, a quantized operator with avoid-dequantize attribute + // can produce float outputs directly. A quantize op is necessary to convert them into + // int8/uint8 type as the input of current quantized node. if (mirror_entry_map.count(e)) { new_node->inputs.emplace_back(mirror_entry_map[e]); } else { @@ -354,10 +365,20 @@ Graph QuantizeGraph(Graph&& src) { new_name = node->attrs.name + "_" + e.node->attrs.name; } } - - ObjectPtr quantize_node = InsertNode( - "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry); - quantize_node->attrs.dict["out_type"] = quantized_dtype; + ObjectPtr quantize_node; + if (need_asym_quantize_map.count(node->op()) && + need_asym_quantize_map[node->op()](node->attrs, i)) { + quantize_node = InsertNode("_contrib_quantize_asym", + new_name + suffix + "_quantize", + new_node, + mirror_entry); + } else { + quantize_node = InsertNode( + "_contrib_quantize_v2", new_name + suffix + "_quantize", new_node, mirror_entry); + // If current node is rnn op, the quantize op is supposed to quantize the result of + // pre-node to uint8, as quantized rnn op requires uint8 input. + quantize_node->attrs.dict["out_type"] = quantized_dtype; + } quantize_node->op()->attr_parser(&(quantize_node->attrs)); mirror_entry_map[e] = NodeEntry{quantize_node, 0, e.version}; } @@ -439,9 +460,13 @@ Graph QuantizeGraph(Graph&& src) { for (const auto& e : node->inputs) { ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; - // if input node is quantized operator, add dequantize node + // If input node is quantized operator, add dequantize node. But if input node is a + // quantized operator with avoid-dequantize attribute, its output may be already in float + // type, which dosen't need a dequantize op. if (quantized_node_map.count(e.node) && - (mirror_node->op() != Op::Get("_contrib_dequantize"))) { + mirror_node->op() != Op::Get("_contrib_dequantize") && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // here we calculate the output number (exclude min/max, in order to // calculate min/max index from mirror node) based on assumption that // there is only 1 min and 1 max output from mirror node (which is @@ -473,7 +498,9 @@ Graph QuantizeGraph(Graph&& src) { std::vector outputs; for (const auto& e : src.outputs) { - if (quantized_node_map.count(e.node)) { + if (quantized_node_map.count(e.node) && + !(avoid_dequantize_map.count(e.node->op()) && + avoid_dequantize_map[e.node->op()](e.node->attrs, e.index))) { // Only insert dequantize for those Ops supports quantize and not excluded. ObjectPtr mirror_node = mirror_map.at(e.node.get()); NodeEntry mirror_entry = NodeEntry{mirror_node, e.index, e.version}; diff --git a/src/operator/quantization/quantize_v2.cc b/src/operator/quantization/quantize_v2.cc index e08bd0d5f76d..497ea37cdf28 100644 --- a/src/operator/quantization/quantize_v2.cc +++ b/src/operator/quantization/quantize_v2.cc @@ -18,7 +18,7 @@ */ /*! - * \file quantize.cc + * \file quantize_v2.cc * \brief */ diff --git a/src/operator/quantization/quantized_rnn-inl.h b/src/operator/quantization/quantized_rnn-inl.h new file mode 100644 index 000000000000..6ab53cef867c --- /dev/null +++ b/src/operator/quantization/quantized_rnn-inl.h @@ -0,0 +1,41 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn-inl.h + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#ifndef MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ +#define MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ + +namespace mxnet { +namespace op { + +namespace quantized_rnn { +enum QuantizedRnnInputs { kData, kParams, kState, kStateCell }; +enum QuantizedRnnInputMinMax { kDataScale, kDataShift }; +enum QuantizedRnnOutputs { kOut, kStateOut, kStateCellOut }; +} // namespace quantized_rnn + +} // namespace op +} // namespace mxnet + +#endif // MXNET_OPERATOR_QUANTIZATION_QUANTIZED_RNN_INL_H_ diff --git a/src/operator/quantization/quantized_rnn.cc b/src/operator/quantization/quantized_rnn.cc new file mode 100644 index 000000000000..88c80bca3cc7 --- /dev/null +++ b/src/operator/quantization/quantized_rnn.cc @@ -0,0 +1,363 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ + +/*! + * \file quantized_rnn.cc + * \brief Common functions for quantized recurrent neural network + * \author Zixuan Wei + */ + +#include +#include +#include +#include + +#include "operator/rnn-inl.h" +#include "operator/quantization/quantization_utils.h" +#include "operator/quantization/quantized_rnn-inl.h" + +#if MXNET_USE_ONEDNN == 1 +#include "operator/quantization/dnnl/dnnl_quantized_rnn-inl.h" +#endif + +namespace mxnet { +namespace op { + +uint32_t QuantizedRnnNumInputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return 6U; +} + +uint32_t QuantizedRnnNumOutputs(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return param.state_outputs ? 3U : 1U; +} + +std::vector QuantizedRnnInputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + return std::vector{ + "data", "parameters", "state", "state_cell", "min_data", "max_data"}; +} + +std::vector QuantizedRnnOutputNames(const NodeAttrs& attrs) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) + << "Quantized recurrent neural network only supports LSTM operator on " + "CPU."; + if (param.state_outputs) { + return std::vector{"output", "state_output", "statecell_ouput"}; + } else { + return std::vector{"output"}; + } +} + +bool QuantizedRnnShape(const nnvm::NodeAttrs& attrs, + std::vector* in_shape, + std::vector* out_shape) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_shape->size(), num_inputs) + << "Arguments' size of quantized RNN operator is mismatched. Expected " << num_inputs + << " argmuments but got " << in_shape->size() << "."; + CHECK_EQ(out_shape->size(), num_outputs); + + const mxnet::TShape dshape = in_shape->at(quantized_rnn::kData); + if (!mxnet::ndim_is_known(dshape)) + return false; + CHECK_EQ(dshape.ndim(), 3U) << "Input data of RNN operator should be 3-rank " + "tensor of dim [steps, batch, input size]"; + const dim_t batch_size = dshape[1]; + const dim_t input_size = dshape[2]; + const dim_t directions = param.bidirectional ? 2 : 1; + const dim_t total_lyrs = directions * param.num_layers; + const dim_t state_size = param.state_size; + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kState, Shape3(total_lyrs, batch_size, state_size)); + if (param.mode == rnn_enum::kLstm) + SHAPE_ASSIGN_CHECK( + *in_shape, quantized_rnn::kStateCell, Shape3(total_lyrs, batch_size, state_size)); + + const int param_size_fp = GetRnnParamSize( + param.num_layers, input_size, state_size, directions, param.mode, param.projection_size); + SHAPE_ASSIGN_CHECK(*in_shape, quantized_rnn::kParams, Shape1(param_size_fp)); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + SHAPE_ASSIGN_CHECK(*in_shape, i, Shape1(1)); + + out_shape->clear(); + out_shape->push_back({dshape[0], batch_size, directions * state_size}); // output dim: [T, N, C] + if (param.state_outputs) { + out_shape->push_back({total_lyrs, batch_size, state_size}); // state dim: [L*D, N, C] + if (param.mode == rnn_enum::kLstm) + out_shape->push_back({total_lyrs, batch_size, state_size}); // cell dim: [L*D, N, C] + } + return true; +} + +bool QuantizedRnnType(const nnvm::NodeAttrs& attrs, + std::vector* in_type, + std::vector* out_type) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_type->size(), num_inputs); + CHECK_EQ(out_type->size(), num_outputs); + + CHECK_EQ(in_type->at(quantized_rnn::kData), mshadow::kUint8) + << "Quantized RNN operator only supports uint8 input, while " + << in_type->at(quantized_rnn::kData) << " is given."; + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kParams, mshadow::kFloat32); + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kState, mshadow::kFloat32); + const uint32_t num_base_inputs = GetRnnNumInputs(param); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*in_type, quantized_rnn::kStateCell, mshadow::kFloat32); + for (size_t i = num_base_inputs; i < num_inputs; ++i) + TYPE_ASSIGN_CHECK(*in_type, i, mshadow::kFloat32); + + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kOut, mshadow::kFloat32); + if (param.state_outputs) { + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateOut, mshadow::kFloat32); + if (param.mode == rnn_enum::kLstm) + TYPE_ASSIGN_CHECK(*out_type, quantized_rnn::kStateCellOut, mshadow::kFloat32); + } + return true; +} + +bool QuantizedRnnStorageType(const nnvm::NodeAttrs& attrs, + const int dev_mask, + DispatchMode* dispatch_mode, + std::vector* in_attrs, + std::vector* out_attrs) { + const uint32_t num_inputs = QuantizedRnnNumInputs(attrs); + const uint32_t num_outputs = QuantizedRnnNumOutputs(attrs); + CHECK_EQ(in_attrs->size(), num_inputs); + CHECK_EQ(out_attrs->size(), num_outputs); + +#if MXNET_USE_ONEDNN == 1 + return DNNLStorageType(attrs, dev_mask, true, dispatch_mode, in_attrs, out_attrs); +#else + *dispatch_mode = DispatchMode::kFCompute; + + for (auto& v : *out_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + + for (auto& v : *in_attrs) { + v = kDefaultStorage; + if (common::stype_string(v).compare("unknown") == 0) { + return false; + } + } + return true; +#endif +} + +void QuantizedRnnParamParser(nnvm::NodeAttrs* attrs) { + RNNParam param; + attrs->dict["quantized"] = "true"; + try { + param.Init(attrs->dict, dmlc::parameter::kAllowUnknown); + } catch (const dmlc::ParamError& e) { + std::ostringstream os; + os << e.what(); + os << ", in operator " << attrs->op->name << "(" + << "name=\"" << attrs->name << "\""; + for (const auto& k : attrs->dict) { + os << ", " << k.first << "=\"" << k.second << "\""; + } + os << ")"; + throw dmlc::ParamError(os.str()); + } + attrs->parsed = std::move(param); +} + +OpStatePtr CreateQuantizedRnnState(const nnvm::NodeAttrs& attrs, + const Context ctx, + const mxnet::ShapeVector& in_shapes, + const std::vector& in_types) { + const RNNParam& param = nnvm::get(attrs.parsed); + CHECK_EQ(param.mode, rnn_enum::kLstm) << "Quantized RNN operator only supports LSTM mode."; + OpStatePtr state = OpStatePtr(); +#if MXNET_USE_ONEDNN == 1 + const int data_type = in_types[quantized_rnn::kData]; + const int weight_type = in_types[quantized_rnn::kParams]; + if (data_type == mshadow::kUint8 && weight_type == mshadow::kFloat32) { + const mxnet::TShape& data_shape = in_shapes[quantized_rnn::kData]; + state = + OpStatePtr::Create(attrs, data_shape[0], data_shape[1], data_shape[2]); + } +#else + LOG(FATAL) << "Quantized RNN operator relies on oneDNN library." + << " Please build MXNet with USE_ONEDNN=ON to leverage this operator."; +#endif + return state; +} + +void QuantizedRnnForwardCPU(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + LOG(FATAL) << "Quantized RNN operator relies on oneDNN library." + << " Please build MXNet with USE_ONEDNN=ON to leverage this operator."; +} + +#if MXNET_USE_ONEDNN == 1 +void QuantizedRnnForwardCPUEx(const OpStatePtr& state_ptr, + const OpContext& ctx, + const std::vector& in_data, + const std::vector& req, + const std::vector& out_data) { + DNNLQuantizedRnnOp& op = state_ptr.get_state(); + op.Forward(ctx, in_data, req, out_data); +} +#endif // MXNET_USE_ONEDNN == 1 + +bool NeedAsymQuantizeRnnInput(const NodeAttrs& attrs, const size_t index_to_check) { + bool need_asym_quantize = false; + switch (index_to_check) { + case rnn_enum::kData: { + need_asym_quantize = true; + break; + } + default: { + need_asym_quantize = false; + } + } + return need_asym_quantize; +} + +bool AvoidRnnQuantizeInput(const NodeAttrs& attrs, + const size_t index_to_check, + const std::string quantize_granularity) { + std::unordered_set avoid_indexes; + avoid_indexes.insert({quantized_rnn::kParams, quantized_rnn::kState, quantized_rnn::kStateCell}); + + return avoid_indexes.count(index_to_check); +} + +bool AvoidRnnDequantizeOutput(const NodeAttrs& attrs, const size_t index_to_check) { + return true; +} + +static std::vector QuantizedRnnResourceEx(const NodeAttrs& attrs, + const int dev_mask, + const DispatchMode dispatch_mode) { + std::vector request; + if (dev_mask == kGPU) { +#if MXNET_USE_CUDNN == 1 + LOG(FATAL) << "Currently, quantized RNN is not supported on the GPU platform."; +#endif + } else { +#if MXNET_USE_ONEDNN == 1 + request.emplace_back(ResourceRequest::kTempSpace); +#endif + } + return request; +} + +NNVM_REGISTER_OP(_contrib_quantized_rnn) + .add_alias("_npx_contrib_quantized_rnn") + .describe(R"code(RNN operator for input data type of uint8. The weight of each +gates is converted to int8, while bias is accumulated in type float32. +The hidden state and cell state are in type float32. For the input data, two more arguments +of type float32 must be provided representing the thresholds of quantizing argument from +data type float32 to uint8. The final outputs contain the recurrent result in float32. +It only supports quantization for Vanilla LSTM network. + +.. Note:: + This operator only supports forward propagation. DO NOT use it in training.)code" ADD_FILELINE) + .set_num_inputs(QuantizedRnnNumInputs) + .set_num_outputs(QuantizedRnnNumOutputs) + .set_attr_parser(QuantizedRnnParamParser) + .set_attr("FListInputNames", QuantizedRnnInputNames) + .set_attr("FListOutputNames", QuantizedRnnOutputNames) + .set_attr("FInferShape", QuantizedRnnShape) + .set_attr("FInferType", QuantizedRnnType) + .set_attr("FInferStorageType", QuantizedRnnStorageType) + .set_attr("FCreateOpState", CreateQuantizedRnnState) + .set_attr("FStatefulCompute", QuantizedRnnForwardCPU) +#if MXNET_USE_ONEDNN == 1 + .set_attr("TIsDNNL", true) + .set_attr("FStatefulComputeEx", QuantizedRnnForwardCPUEx) +#endif + .set_attr("FResourceRequestEx", QuantizedRnnResourceEx) + .add_argument("data", "NDArray-or-Symbol", "Input data.") + .add_argument("parameters", "NDArray-or-Symbol", "weight.") + .add_argument("state", "NDArray-or-Symbol", "initial hidden state of the RNN") + .add_argument("state_cell", + "NDArray-or-Symbol", + "initial cell state for LSTM networks (only for LSTM)") + .add_argument("data_scale", "NDArray-or-Symbol", "quantization scale of data.") + .add_argument("data_shift", "NDArray-or-Symbol", "quantization shift of data.") + .add_arguments(RNNParam::__FIELDS__()); + +NNVM_REGISTER_OP(RNN) + .set_attr("FQuantizable", + [](const NodeAttrs& attrs) { +#if MXNET_USE_ONEDNN == 1 + const RNNParam& param = nnvm::get(attrs.parsed); + if (param.mode != rnn_enum::kLstm) + LOG(INFO) << "Quantized RNN only supports LSTM mode."; + if (param.mode == rnn_enum::kLstm && + !param.projection_size.has_value()) { + return QuantizeType::kMust; + } else { + return QuantizeType::kNone; + } +#else + LOG(INFO) << "Quantized RNN is not supported by this MXNet release. Please enable oneDNN to " + << "use the feature."; + return QuantizeType::kNone; +#endif // MXNET_USE_ONEDNN == 1 + }) + .set_attr("FQuantizedOp", + [](const NodeAttrs& attrs) { + nnvm::ObjectPtr node = nnvm::Node::Create(); + node->attrs.op = Op::Get("_contrib_quantized_rnn"); + node->attrs.name = "quantized_" + attrs.name; + node->attrs.dict = attrs.dict; + node->attrs.dict["quantized"] = "true"; + if (node->op()->attr_parser != nullptr) { + node->op()->attr_parser(&(node->attrs)); + } + return node; + }) + .set_attr("FNeedAsymQuantizeInput", NeedAsymQuantizeRnnInput) + .set_attr("FAvoidQuantizeInput", AvoidRnnQuantizeInput) + .set_attr("FAvoidDequantizeOutput", AvoidRnnDequantizeOutput); + +} // namespace op +} // namespace mxnet diff --git a/src/operator/rnn-inl.h b/src/operator/rnn-inl.h index c34855468c5c..eac274f96a9d 100644 --- a/src/operator/rnn-inl.h +++ b/src/operator/rnn-inl.h @@ -291,9 +291,9 @@ inline size_t GetRNNReserveSpaceSize(int num_layer, return size; } -inline size_t GetNumInputArguments(RNNParam param_) { - size_t num_inputs = (param_.mode == rnn_enum::kLstm) ? 4U : 3U; - if (param_.use_sequence_length) +inline size_t GetRnnNumInputs(RNNParam param) { + size_t num_inputs = (param.mode == rnn_enum::kLstm) ? 4U : 3U; + if (param.use_sequence_length) num_inputs += 1U; return num_inputs; } @@ -748,7 +748,7 @@ class RNNOp { using namespace mshadow::expr; CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -1125,7 +1125,7 @@ class RNNOp { CHECK(param_.p >= 0.0f && param_.p < 1.0f) << "unsupported dropout value, should be 0 <= dropout < 1"; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; @@ -1369,7 +1369,7 @@ class RNNOp { const std::vector& out_data) { using namespace mshadow; - size_t num_inputs = GetNumInputArguments(param_); + size_t num_inputs = GetRnnNumInputs(param_); // kOut size_t num_outputs = 1; if (param_.state_outputs) { diff --git a/src/operator/rnn.cc b/src/operator/rnn.cc index e4b84dd0d927..5a03b06674c8 100644 --- a/src/operator/rnn.cc +++ b/src/operator/rnn.cc @@ -34,31 +34,41 @@ namespace mxnet { namespace op { DMLC_REGISTER_PARAMETER(RNNParam); -static inline std::vector ListArguments(const RNNParam& param_) { +static inline std::vector ListRnnInputNames(const RNNParam& param) { // All RNNs start off with same 3 input arguments std::vector arguments{"data", "parameters", "state"}; // LSTMs also have an additional state_cell argument - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { arguments.emplace_back("state_cell"); } // All RNNs have option of additional sequence_length argument - if (param_.use_sequence_length) { + if (param.use_sequence_length) { arguments.emplace_back("sequence_length"); } return arguments; } +static inline std::vector ListRnnOutputNames(const RNNParam& param) { + std::vector names{"output"}; + if (param.state_outputs) { + names.emplace_back("state_output"); + if (param.mode == rnn_enum::kLstm) + names.emplace_back("statecell_output"); + } + return names; +} + static bool RNNShape(const nnvm::NodeAttrs& attrs, std::vector* in_shape, std::vector* out_shape) { - const RNNParam& param_ = nnvm::get(attrs.parsed); using namespace mshadow; + const RNNParam& param = nnvm::get(attrs.parsed); - // Query param_ object to figure out what the expectd input arguments are - std::vector expected_arguments = ListArguments(param_); + // Query param object to figure out what the expectd input arguments are + std::vector expected_arguments = ListRnnInputNames(param); CHECK_EQ(in_shape->size(), expected_arguments.size()) << "Input shape mismatch. Expected " << expected_arguments.size() @@ -76,29 +86,29 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, } int batch_size = dshape[1]; int input_size = dshape[2]; - int numDirections = param_.bidirectional ? 2 : 1; - int total_layers = numDirections * param_.num_layers; // double for bidirectional + int numDirections = param.bidirectional ? 2 : 1; + int total_layers = numDirections * param.num_layers; // double for bidirectional int layer_size = - (param_.projection_size.has_value()) ? param_.projection_size.value() : param_.state_size; + (param.projection_size.has_value()) ? param.projection_size.value() : param.state_size; SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kState, Shape3(total_layers, batch_size, layer_size)); - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { SHAPE_ASSIGN_CHECK( - *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param_.state_size)); + *in_shape, rnn_enum::kStateCell, Shape3(total_layers, batch_size, param.state_size)); } // calculate parameter vector length - int param_size = GetRnnParamSize(param_.num_layers, + int param_size = GetRnnParamSize(param.num_layers, input_size, - param_.state_size, + param.state_size, numDirections, - param_.mode, - param_.projection_size); + param.mode, + param.projection_size); SHAPE_ASSIGN_CHECK(*in_shape, rnn_enum::kParams, Shape1(param_size)); // Check on sequence_length shape if using - if (param_.use_sequence_length) { + if (param.use_sequence_length) { size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; SHAPE_ASSIGN_CHECK(*in_shape, seq_len_input_idx, Shape1(batch_size)); @@ -107,29 +117,29 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, out_shape->clear(); // output: [sequence len, batch, output size] TShape oshape = dshape; - if (param_.projection_size.has_value()) { - oshape[2] = numDirections * param_.projection_size.value(); + if (param.projection_size.has_value()) { + oshape[2] = numDirections * param.projection_size.value(); } else { - oshape[2] = numDirections * param_.state_size; + oshape[2] = numDirections * param.state_size; } out_shape->push_back(oshape); - if (param_.state_outputs) { + if (param.state_outputs) { // outStateShape: [layer_num, batch, state size] TShape outStateShape = dshape; outStateShape[0] = total_layers; outStateShape[1] = batch_size; - if (param_.projection_size.has_value()) { - outStateShape[2] = param_.projection_size.value(); + if (param.projection_size.has_value()) { + outStateShape[2] = param.projection_size.value(); } else { - outStateShape[2] = param_.state_size; + outStateShape[2] = param.state_size; } out_shape->push_back(outStateShape); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { TShape cellStateShape = dshape; cellStateShape[0] = total_layers; cellStateShape[1] = batch_size; - cellStateShape[2] = param_.state_size; + cellStateShape[2] = param.state_size; out_shape->push_back(cellStateShape); } } @@ -140,34 +150,34 @@ static bool RNNShape(const nnvm::NodeAttrs& attrs, static bool RNNType(const nnvm::NodeAttrs& attrs, std::vector* in_type, std::vector* out_type) { - const RNNParam& param_ = nnvm::get(attrs.parsed); + const RNNParam& param = nnvm::get(attrs.parsed); - CHECK_EQ(in_type->size(), GetNumInputArguments(param_)); + CHECK_EQ(in_type->size(), GetRnnNumInputs(param)); size_t seq_len_input_idx = rnn_enum::kSequenceLength; - if (param_.mode != rnn_enum::kLstm) + if (param.mode != rnn_enum::kLstm) --seq_len_input_idx; int dtype = (*in_type)[0]; CHECK_NE(dtype, -1) << "First input must have specified type"; - std::vector arguments = ListArguments(param_); + std::vector arguments = ListRnnInputNames(param); for (size_t i = 0; i < in_type->size(); ++i) { if ((*in_type)[i] == -1) { TYPE_ASSIGN_CHECK(*in_type, i, dtype); } else { // If using sequence length argument, it has its own indexing type // All other input arguments must match the main data type - if (!(param_.use_sequence_length && i == seq_len_input_idx)) { + if (!(param.use_sequence_length && i == seq_len_input_idx)) { UNIFORM_TYPE_CHECK((*in_type)[i], dtype, arguments[i]); } } } out_type->clear(); out_type->push_back(dtype); - if (param_.state_outputs) { + if (param.state_outputs) { out_type->push_back(dtype); // Deal with lstm cell state - if (param_.mode == rnn_enum::kLstm) { + if (param.mode == rnn_enum::kLstm) { out_type->push_back(dtype); } } @@ -248,7 +258,7 @@ static OpStatePtr CreateRNNState(const nnvm::NodeAttrs& attrs, #if MXNET_USE_ONEDNN == 1 if (ctx.dev_type == kCPU && SupportDNNLRnn(param, in_types[rnn_enum::kData])) { const mxnet::TShape& data_shape = in_shapes[rnn_enum::kData]; - state = OpStatePtr::Create(param, data_shape[0], data_shape[1], data_shape[2]); + state = OpStatePtr::Create(attrs, data_shape[0], data_shape[1], data_shape[2]); return state; } #endif // MXNET_USE_ONEDNN == 1 @@ -370,7 +380,7 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr_parser(ParamParser) .set_num_inputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); @@ -386,18 +396,12 @@ The definition of GRU here is slightly different from paper but compatible with .set_attr("FListInputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return ListArguments(params); + return ListRnnInputNames(params); }) .set_attr("FListOutputNames", [](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - std::vector names{"output"}; - if (params.state_outputs) { - names.emplace_back("state_output"); - if (params.mode == rnn_enum::kLstm) - names.emplace_back("statecell_output"); - } - return names; + return ListRnnOutputNames(params); }) .set_attr("FInferShape", RNNShape) .set_attr("FInferType", RNNType) @@ -441,7 +445,7 @@ NNVM_REGISTER_OP(_backward_RNN) }) .set_num_outputs([](const NodeAttrs& attrs) { const RNNParam& params = nnvm::get(attrs.parsed); - return GetNumInputArguments(params); + return GetRnnNumInputs(params); }) .set_attr_parser(ParamParser) .set_attr("TIsLayerOpBackward", true) diff --git a/tests/python/quantization/test_quantization.py b/tests/python/quantization/test_quantization.py index dcd4bbd5b546..6b74a49a9d56 100644 --- a/tests/python/quantization/test_quantization.py +++ b/tests/python/quantization/test_quantization.py @@ -1414,3 +1414,106 @@ def get_threshold(nd): assert 'layer1' in min_max_dict assert_almost_equal(onp.array([min_max_dict['layer1'][1]]), expected_threshold, rtol=1e-2, atol=1e-4) + +@use_np +def test_rnn_quantization(): + data_low = -1 + data_high = 1 + def check_rnn_quantization(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): + data_shape = (seq_len, batch_size, input_dim) + + rnn_fp32 = mx.gluon.rnn.LSTM(hidden_size=state_size, + num_layers = num_layers, + bidirectional=bidirectional) + + data = mx.np.random.uniform(low=data_low, high=data_high, size=data_shape) + states_shape = (num_layers * 2 if bidirectional else num_layers, batch_size, state_size) + states = [mx.np.zeros((states_shape)) for _ in range(batch_size)] + + rnn_fp32.initialize() + rnn_fp32.hybridize() + ref_out = rnn_fp32(data, states) + + class RNNDataLoader(mx.gluon.data.DataLoader): + def __init__(self, data, states): + super().__init__(mx.gluon.data.SimpleDataset([]), 1) + self.data = data + self.states = states + + def __iter__(self): + return self + + def __next__(self): + return [self.data, self.states] + + def __bool__(self): + return bool(self.dataiter.iter_next()) + + calib_data = RNNDataLoader(data, states) + quant_rnn = mx.contrib.quant.quantize_net(rnn_fp32, + quantized_dtype='auto', + quantize_mode='full', + calib_data=calib_data, + calib_mode='naive', + num_calib_batches=1, + device=mx.current_device()) + qout = quant_rnn(data, states) + + qsym, _ = quant_rnn.export(None) + assert qsym.tojson().find("quantized_rnn") != -1 + + ref_out = [ref_out[0], ref_out[1][0], ref_out[1][1]] + for i in range(len(qout)): + mse = onp.mean((ref_out[i].asnumpy() - qout[i].asnumpy())**2) + assert mse < 0.001 + + check_rnn_quantization(1, False, 5, 2, 16, 16) + check_rnn_quantization(1, True, 5, 2, 16, 16) + + + +@use_np +def test_quantized_rnn(): + def check_quantized_rnn(num_layers, bidirectional, seq_len, batch_size, input_dim, state_size): + ndir = 2 if bidirectional else 1 + size = ndir*state_size*4 + first_lyr_param_size = (input_dim + state_size + 2) * size + other_lyr_param_size = (state_size * ndir + state_size + 2) * size + full_param_size = first_lyr_param_size + (num_layers - 1) * other_lyr_param_size + + data = mx.np.random.uniform(-1, 1, (seq_len, batch_size, input_dim)) + state = mx.np.random.uniform(-1, 1, (num_layers*ndir, batch_size, state_size)) + state_cell = mx.np.random.uniform(0, 1, (num_layers*ndir, batch_size, state_size)) + params = mx.np.random.normal(0, 1, (full_param_size,)) + + out = npx.rnn(data=data, + parameters=params, + mode='lstm', + state=state, + state_size=state_size, + state_cell=state_cell, + num_layers=num_layers, + bidirectional=bidirectional) + + data_min = mx.np.min(data) + data_max = mx.np.max(data) + data_scale = mx.np.array(128.0 / (data_max - data_min)).reshape((1,)) + data_shift = mx.np.array(128.0 - data_max * data_scale).reshape((1,)) + + qdata = (data * data_scale + data_shift + 0.5).astype('uint8') + qout = npx.contrib_quantized_rnn(data=qdata, + parameters=params, + mode='lstm', + state=state, + state_size=state_size, + state_cell=state_cell, + num_layers=num_layers, + bidirectional=bidirectional, + data_scale=data_scale, + data_shift=data_shift) + + mse = onp.mean((out.asnumpy() - qout.asnumpy())**2) + assert mse < 0.001 + + check_quantized_rnn(1, False, 5, 2, 16, 16) + check_quantized_rnn(1, True, 5, 2, 16, 16) \ No newline at end of file