diff --git a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py index 799957675df7..77179bd46966 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py +++ b/python/mxnet/contrib/onnx/mx2onnx/_op_translations.py @@ -822,7 +822,7 @@ def convert_leakyrelu(node, **kwargs): inputs=input_nodes, outputs=[name], name=name) - elif act_type in ('gelu'): + elif act_type in ('gelu',): sqrt2 = np.float32(1.4142135623730951) create_const_scalar_node(name+"_sqrt2", sqrt2, kwargs) create_const_scalar_node(name+"_one", np.float32(1.0), kwargs) @@ -1225,15 +1225,14 @@ def scalar_op_helper(node, op_name, **kwargs): data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[new_initializer.dtype] dims = np.shape(new_initializer) - new_a_node = input_nodes[0] + str(kwargs["idx"]) - tensor_node = onnx.helper.make_tensor_value_info(new_a_node, data_type, dims) + tensor_node = onnx.helper.make_tensor_value_info(name, data_type, dims) initializer.append( onnx.helper.make_tensor( - name=new_a_node, + name=name, data_type=data_type, dims=dims, - vals=new_initializer, + vals=new_initializer.flatten(), raw=False, ) ) @@ -2841,6 +2840,8 @@ def convert_zeros(node, **kwargs): dtype = attrs.get('dtype') data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] shape = convert_string_to_list(attrs.get('shape')) + # replace 0 with 1 + shape = [x if x else 1 for x in shape] create_tensor(shape, name+'_shape', kwargs['initializer']) tensor_value = make_tensor(name+'_zero', data_type, [1], [0]) nodes = [ @@ -2858,6 +2859,8 @@ def convert_ones(node, **kwargs): dtype = attrs.get('dtype') data_type = onnx.mapping.NP_TYPE_TO_TENSOR_TYPE[np.dtype(dtype)] shape = convert_string_to_list(attrs.get('shape')) + # replace 0 with 1 + shape = [x if x else 1 for x in shape] create_tensor(shape, name+'_shape', kwargs['initializer']) tensor_value = make_tensor(name+'_one', data_type, [1], [1]) nodes = [ @@ -4040,6 +4043,7 @@ def convert_one_hot(node, 
**kwargs): """Map MXNet's one_hot operator attributes to onnx's OneHot operator """ from onnx.helper import make_node + from onnx import TensorProto name, input_nodes, attrs = get_inputs(node, kwargs) depth = int(attrs.get('depth')) @@ -4050,7 +4054,8 @@ def convert_one_hot(node, **kwargs): create_tensor([off_value, on_value], name+'_values', kwargs['initializer'], dtype=np.dtype(dtype)) create_tensor([depth], name+'_depth', kwargs['initializer']) nodes = [ - make_node('OneHot', [input_nodes[0], name+'_depth', name+'_values'], [name], name=name) + make_node('Cast', [input_nodes[0]], [name+'_cast'], to=int(TensorProto.INT64)), + make_node('OneHot', [name+'_cast', name+'_depth', name+'_values'], [name], name=name) ] return nodes @@ -4106,7 +4111,6 @@ def convert_sequence_reverse(node, **kwargs): return nodes - @mx_op.register("RNN") def convert_RNN(node, **kwargs): """Map MXNet's RNN operator attributes to onnx's operators diff --git a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py index 898a8df2d5c2..89f061d4c161 100644 --- a/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py +++ b/python/mxnet/contrib/onnx/mx2onnx/export_onnx.py @@ -330,6 +330,17 @@ def create_onnx_graph_proto(self, sym, params, in_shape, in_type, verbose=False, else: logging.info("Operator converter function should always return a list") + # sometimes the graph output can also be in the initializer + for i in initializer: + if i.name in graph_outputs: + onnx_processed_outputs.append( + make_tensor_value_info( + name=i.name, + elem_type=graph_outputs[i.name]['dtype'], + shape=graph_outputs[i.name]['shape'] + ) + ) + graph = helper.make_graph( onnx_processed_nodes, "mxnet_converted_model",