From cfb571bdc6f7730df5b3ed4f631b0420b1ab2b48 Mon Sep 17 00:00:00 2001 From: barry-jin Date: Wed, 14 Apr 2021 21:18:25 -0700 Subject: [PATCH] fix broadcast_lie --- .../operator/numpy_extension/npx_broadcast_like_op.cc | 8 ++++++-- tests/python/unittest/test_numpy_op.py | 11 +++++++++++ 2 files changed, 17 insertions(+), 2 deletions(-) diff --git a/src/api/operator/numpy_extension/npx_broadcast_like_op.cc b/src/api/operator/numpy_extension/npx_broadcast_like_op.cc index 62ea2a031a73..bd882665208e 100644 --- a/src/api/operator/numpy_extension/npx_broadcast_like_op.cc +++ b/src/api/operator/numpy_extension/npx_broadcast_like_op.cc @@ -44,14 +44,18 @@ MXNET_REGISTER_API("_npx.broadcast_like") // lhs_axes if (args[2].type_code() == kNull) { param.lhs_axes = dmlc::optional(); + } else if (args[2].type_code() == kDLInt) { + param.lhs_axes = TShape(1, args[2].operator int64_t()); } else { param.lhs_axes = mxnet::TShape(args[2].operator ObjectRef()); } // rhs_axes - if (args[2].type_code() == kNull) { + if (args[3].type_code() == kNull) { param.rhs_axes = dmlc::optional(); + } else if (args[3].type_code() == kDLInt) { + param.rhs_axes = TShape(1, args[3].operator int64_t()); } else { - param.rhs_axes = mxnet::TShape(args[2].operator ObjectRef()); + param.rhs_axes = mxnet::TShape(args[3].operator ObjectRef()); } attrs.op = op; diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py index 2fe0ec019b6f..1e253f20923d 100644 --- a/tests/python/unittest/test_numpy_op.py +++ b/tests/python/unittest/test_numpy_op.py @@ -10338,3 +10338,14 @@ def test_modulated_deformable_convolution(num_batch, num_channel_data, num_defor rtol, atol = 1.0, 1e-2 else: rtol, atol = 0.05, 1e-3 + + +@use_np +def test_broadcast_like_different_types(): + x = mx.np.zeros((2, 1)) + y = mx.np.ones((2, 2)) + + y = mx.np.array(y).astype('int32') + z = mx.npx.broadcast_like(x, y, 1, 1) + assert_almost_equal(z.asnumpy(), np.array([[0,0],[0,0]])) + assert x.dtype == z.dtype