port & update pr16744 numpy gcd  (#19547)

* numpy-compatible gcd operator

* use BinaryScalarRTCCompute

* Update _op.py

* Update np_elemwise_broadcast_op_extended.cc

* fix

* Update operator_tune.cc

* fix kernel

* add large tensor test

* add gcd interoperability workload

* Update test_numpy_interoperability.py

* Update np_elemwise_broadcast_op_extended.cc

* Update np_elemwise_broadcast_op_extended.cc

* avoid ci linspace issue

Co-authored-by: Hao Jin <hjjn.amzn@gmail.com>
diff --git a/ci/docker/install/requirements b/ci/docker/install/requirements
index 8741e33..e3d90d9 100644
--- a/ci/docker/install/requirements
+++ b/ci/docker/install/requirements
@@ -19,7 +19,7 @@
 # the whole docker cache for the image
 
 # Required dependencies
-numpy<1.20.0
+numpy>=1.17,<1.20.0
 requests>=2.20.0,<3
 graphviz<0.9.0,>=0.8.1
 contextvars;python_version<"3.7"
diff --git a/python/mxnet/amp/lists/symbol_fp16.py b/python/mxnet/amp/lists/symbol_fp16.py
index 7242a70..d58237b 100644
--- a/python/mxnet/amp/lists/symbol_fp16.py
+++ b/python/mxnet/amp/lists/symbol_fp16.py
@@ -245,6 +245,8 @@
     '_npi_logistic',
     '_npi_lcm',
     '_npi_lcm_scalar',
+    '_npi_gcd',
+    '_npi_gcd_scalar',
     '_npi_linspace',
     '_npi_logical_not',
     '_npi_logical_and_scalar',
diff --git a/python/mxnet/ndarray/numpy/_op.py b/python/mxnet/ndarray/numpy/_op.py
index 68eea62..a1f4bcf 100644
--- a/python/mxnet/ndarray/numpy/_op.py
+++ b/python/mxnet/ndarray/numpy/_op.py
@@ -44,7 +44,7 @@
            'max', 'min', 'amax', 'amin', 'logical_and', 'logical_or', 'logical_xor',
            'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
            'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
-           'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm',
+           'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd',
            'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
            'true_divide', 'nonzero', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'interp',
@@ -2083,6 +2083,46 @@
 
 @set_module('mxnet.ndarray.numpy')
 @wrap_np_binary_func
+def gcd(x1, x2, out=None, **kwargs):
+    """
+    Returns the greatest common divisor of ``|x1|`` and ``|x2|``.
+
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalar values
+        The arrays for computing greatest common divisor. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which may be the shape of
+        one or the other).
+
+    out : ndarray or None, optional
+        A location into which the result is stored. If provided, it must have a shape
+        that the inputs broadcast to. If not provided or None, a freshly-allocated array
+        is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The greatest common divisor of the absolute values of the inputs.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    See Also
+    --------
+    lcm : The lowest common multiple.
+
+    Examples
+    --------
+    >>> np.gcd(12, 20)
+    4
+    >>> np.gcd(np.arange(6, dtype=int), 20)
+    array([20,  1,  2,  1,  4,  5], dtype=int64)
+    """
+    if isinstance(x1, numeric_types) and isinstance(x2, numeric_types):
+        return _np.gcd(x1, x2, out=out)  # both operands are numeric_types scalars
+    return _api_internal.gcd(x1, x2, out)  # otherwise dispatch to the _api_internal op
+
+
+@set_module('mxnet.ndarray.numpy')
+@wrap_np_binary_func
 def lcm(x1, x2, out=None, **kwargs):
     """
     Returns the lowest common multiple of ``|x1|`` and ``|x2|``
diff --git a/python/mxnet/numpy/multiarray.py b/python/mxnet/numpy/multiarray.py
index e260a41..0f55c31 100644
--- a/python/mxnet/numpy/multiarray.py
+++ b/python/mxnet/numpy/multiarray.py
@@ -74,7 +74,7 @@
            'flip', 'flipud', 'fliplr', 'around', 'round', 'round_', 'arctan2', 'hypot',
            'triu_indices_from', 'triu_indices', 'tri',
            'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad',
-           'unique', 'lcm', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
+           'unique', 'lcm', 'gcd', 'tril', 'triu', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer',
            'cross', 'kron', 'equal', 'not_equal', 'interp',
            'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum', 'true_divide', 'nonzero',
            'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d', 'resize', 'matmul',
@@ -3622,6 +3622,44 @@
 
 @set_module('mxnet.numpy')
 @wrap_np_binary_func
+def gcd(x1, x2, out=None, **kwargs):
+    """
+    Returns the greatest common divisor of ``|x1|`` and ``|x2|``.
+
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalar values
+        The arrays for computing greatest common divisor. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which may be the shape of
+        one or the other).
+
+    out : ndarray or None, optional
+        A location into which the result is stored. If provided, it must have a shape
+        that the inputs broadcast to. If not provided or None, a freshly-allocated array
+        is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The greatest common divisor of the absolute values of the inputs.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    See Also
+    --------
+    lcm : The lowest common multiple.
+
+    Examples
+    --------
+    >>> np.gcd(12, 20)
+    4
+    >>> np.gcd(np.arange(6, dtype=int), 20)
+    array([20,  1,  2,  1,  4,  5], dtype=int64)
+    """
+    return _mx_nd_np.gcd(x1, x2, out=out)
+
+
+@set_module('mxnet.numpy')
+@wrap_np_binary_func
 def lcm(x1, x2, out=None, **kwargs):
     """
     Returns the lowest common multiple of ``|x1|`` and ``|x2|``
diff --git a/python/mxnet/numpy_dispatch_protocol.py b/python/mxnet/numpy_dispatch_protocol.py
index 01f2b72..f047076 100644
--- a/python/mxnet/numpy_dispatch_protocol.py
+++ b/python/mxnet/numpy_dispatch_protocol.py
@@ -249,6 +249,7 @@
     'degrees',
     'hypot',
     'lcm',
+    'gcd',
     # 'ldexp',
     'subtract',
     'multiply',
diff --git a/python/mxnet/symbol/numpy/_symbol.py b/python/mxnet/symbol/numpy/_symbol.py
index c1df972..821b6fa 100644
--- a/python/mxnet/symbol/numpy/_symbol.py
+++ b/python/mxnet/symbol/numpy/_symbol.py
@@ -49,7 +49,7 @@
            'flatnonzero', 'tril_indices', 'amax', 'amin', 'max', 'min', 'logical_and', 'logical_or', 'logical_xor',
            'swapaxes', 'clip', 'argmax', 'argmin', 'std', 'var', 'indices', 'copysign', 'ravel', 'unravel_index',
            'diag_indices_from', 'hanning', 'hamming', 'blackman', 'flip', 'flipud', 'fliplr',
-           'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'interp',
+           'hypot', 'bitwise_and', 'bitwise_xor', 'bitwise_or', 'rad2deg', 'deg2rad', 'unique', 'lcm', 'gcd', 'interp',
            'tril', 'triu', 'tri', 'identity', 'take', 'ldexp', 'vdot', 'inner', 'outer', 'cross', 'kron',
            'equal', 'not_equal', 'greater', 'less', 'greater_equal', 'less_equal', 'roll', 'rot90', 'einsum',
            'true_divide', 'quantile', 'percentile', 'shares_memory', 'may_share_memory', 'diff', 'ediff1d',
@@ -1680,6 +1680,37 @@
 
 @set_module('mxnet.symbol.numpy')
 @wrap_np_binary_func
+def gcd(x1, x2, out=None, **kwargs):
+    """
+    Returns the greatest common divisor of ``|x1|`` and ``|x2|``.
+
+    Parameters
+    ----------
+    x1, x2 : ndarrays or scalar values
+        The arrays for computing greatest common divisor. If x1.shape != x2.shape,
+        they must be broadcastable to a common shape (which may be the shape of
+        one or the other).
+
+    out : ndarray or None, optional
+        A location into which the result is stored. If provided, it must have a shape
+        that the inputs broadcast to. If not provided or None, a freshly-allocated array
+        is returned.
+
+    Returns
+    -------
+    y : ndarray or scalar
+        The greatest common divisor of the absolute values of the inputs.
+        This is a scalar if both `x1` and `x2` are scalars.
+
+    See Also
+    --------
+    lcm : The lowest common multiple.
+    """
+    return _ufunc_helper(x1, x2, _npi.gcd, _np.gcd, _npi.gcd_scalar, None, out)
+
+
+@set_module('mxnet.symbol.numpy')
+@wrap_np_binary_func
 def matmul(a, b, out=None, **kwargs):
     """
     Matrix product of two arrays.
diff --git a/src/api/operator/numpy/np_elemwise_broadcast_op.cc b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
index f55f82c..a411b06 100644
--- a/src/api/operator/numpy/np_elemwise_broadcast_op.cc
+++ b/src/api/operator/numpy/np_elemwise_broadcast_op.cc
@@ -88,6 +88,14 @@
   UFuncHelper(args, ret, op, op_scalar, nullptr);
 });
 
+MXNET_REGISTER_API("_npi.gcd")
+.set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
+  using namespace runtime;
+  const nnvm::Op* op = Op::Get("_npi_gcd");
+  const nnvm::Op* op_scalar = Op::Get("_npi_gcd_scalar");
+  UFuncHelper(args, ret, op, op_scalar, nullptr);
+});
+
 MXNET_REGISTER_API("_npi.logical_and")
 .set_body([](runtime::MXNetArgs args, runtime::MXNetRetValue* ret) {
   using namespace runtime;
diff --git a/src/common/cuda/rtc/forward_functions-inl.h b/src/common/cuda/rtc/forward_functions-inl.h
index 6568bae..f85916f 100644
--- a/src/common/cuda/rtc/forward_functions-inl.h
+++ b/src/common/cuda/rtc/forward_functions-inl.h
@@ -542,6 +542,49 @@
 }
 
 template <typename DType, typename DType2>
+__device__ inline typename type_util::mixed_type<DType, DType2>::type
+gcd(const DType a, const DType2 b) {
+  // Intermediates use the promoted type so a wider operand is never narrowed to DType.
+  using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
+  if (type_util::is_integral<DType>::value &&
+      type_util::is_integral<DType2>::value) {
+    mixed_type A = a;
+    mixed_type B = b;
+    // gcd is taken over absolute values, so negative inputs are negated first.
+    if (a < 0) {
+      A = -a;
+    }
+    if (b < 0) {
+      B = -b;
+    }
+    // zero-valued cases: gcd(0, x) == |x| (which is 0 when x == 0 as well).
+    mixed_type c;
+    if (a == 0) {
+      c = B;
+    } else if (b == 0) {
+      c = A;
+    } else {
+      mixed_type tmp;
+      if (A < B) {  // order the pair so that A >= B before iterating
+        tmp = A;
+        A = B;
+        B = tmp;
+      }
+      while (A % B != 0) {  // Euclidean step: (A, B) becomes (B, A % B)
+        A = A % B;
+        tmp = A;
+        A = B;
+        B = tmp;
+      }
+      c = B;  // the loop exits once B divides A, so B is the result
+    }
+    return c;
+  } else {
+    return 0;  // non-integral operand types: always 0
+  }
+}
+
+template <typename DType, typename DType2>
 __device__ inline typename type_util::mixed_type<DType, DType2>::type bitwise_xor(const DType a,
                                                                        const DType2 b) {
   using mixed_type = typename type_util::mixed_type<DType, DType2>::type;
diff --git a/src/operator/mshadow_op.h b/src/operator/mshadow_op.h
index 82170cd..7c7c18f 100644
--- a/src/operator/mshadow_op.h
+++ b/src/operator/mshadow_op.h
@@ -1704,6 +1704,52 @@
 #pragma GCC diagnostic ignored "-Wint-in-bool-context"
 #pragma GCC diagnostic ignored "-Wbool-compare"
 #endif
+
+/*! \brief binary greatest common divisor of |a| and |b|; non-integral types yield 0 */
+struct gcd : public mxnet_op::tunable {
+  template<typename DType>
+  MSHADOW_XINLINE static typename enable_if<is_integral<DType>::value, DType>::type
+  Map(DType a, DType b) {
+    // gcd is taken over absolute values, so negative inputs are negated first.
+    if (a < 0) {
+      a = -a;
+    }
+    if (b < 0) {
+      b = -b;
+    }
+    // zero-valued cases: gcd(0, x) == |x| and gcd(0, 0) == 0.
+    DType c;
+    if (a == 0 && b != 0) {
+      c = b;
+    } else if (b == 0 && a != 0) {
+      c = a;
+    } else if (a == 0 && b == 0) {
+      c = 0;
+    } else {
+      DType tmp;
+      if (a < b) {  // order the pair so that a >= b before iterating
+        tmp = a;
+        a = b;
+        b = tmp;
+      }
+      while (a % b != 0) {  // Euclidean step: (a, b) becomes (b, a % b)
+        a = a % b;
+        tmp = a;
+        a = b;
+        b = tmp;
+      }
+      c = b;  // the loop exits once b divides a, so b is the result
+    }
+    return c;
+  }
+
+  template<typename DType>
+  MSHADOW_XINLINE static typename enable_if<!is_integral<DType>::value, DType>::type
+  Map(DType a, DType b) {
+    return DType(0.0f);  // overload selected for non-integral DType: always 0
+  }
+};
+
 /*! \brief used for computing binary lowest common multiple */
 struct lcm : public mxnet_op::tunable {
   template<typename DType>
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
index 90a48d4..188b6c8 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cc
@@ -63,6 +63,39 @@
 .set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastBackwardUseIn<cpu, mshadow_op::copysign_grad,
                                                                   mshadow_op::copysign_rgrad>);
 
+NNVM_REGISTER_OP(_npi_gcd)
+.set_num_inputs(2)
+.set_num_outputs(1)
+.set_attr<nnvm::FListInputNames>("FListInputNames",
+[](const NodeAttrs& attrs) {
+     return std::vector<std::string>{"lhs", "rhs"};
+})
+.set_attr<mxnet::FInferShape>("FInferShape", BinaryBroadcastShape)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<2, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+[](const NodeAttrs& attrs){
+     return std::vector<std::pair<int, int> >{{0, 0}, {1, 0}};
+})
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.set_attr<FCompute>("FCompute<cpu>", BinaryBroadcastIntCompute<cpu, mshadow_op::gcd>)
+.add_argument("lhs", "NDArray-or-Symbol", "First input to the function")
+.add_argument("rhs", "NDArray-or-Symbol", "Second input to the function");
+
+NNVM_REGISTER_OP(_npi_gcd_scalar)
+.set_num_inputs(1)
+.set_num_outputs(1)
+.set_attr_parser(ParamParser<NumpyBinaryScalarParam>)
+.set_attr<mxnet::FInferShape>("FInferShape", ElemwiseShape<1, 1>)
+.set_attr<nnvm::FInferType>("FInferType", ElemwiseIntType<1, 1>)
+.set_attr<nnvm::FInplaceOption>("FInplaceOption",
+  [](const NodeAttrs& attrs){
+    return std::vector<std::pair<int, int> >{{0, 0}};
+  })
+.set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
+.add_argument("data", "NDArray-or-Symbol", "source input")
+.add_arguments(NumpyBinaryScalarParam::__FIELDS__())
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::gcd>);
+
 NNVM_REGISTER_OP(_npi_lcm)
 .set_num_inputs(2)
 .set_num_outputs(1)
@@ -94,7 +127,7 @@
 .set_attr<nnvm::FGradient>("FGradient", MakeZeroGradNodes)
 .add_argument("data", "NDArray-or-Symbol", "source input")
 .add_arguments(NumpyBinaryScalarParam::__FIELDS__())
-.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::Compute<cpu, mshadow_op::lcm>);
+.set_attr<FCompute>("FCompute<cpu>", BinaryScalarOp::ComputeInt<cpu, mshadow_op::lcm>);
 
 NNVM_REGISTER_OP(_npi_bitwise_and)
 .set_num_inputs(2)
diff --git a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu
index b1d7e71..ff1cedf 100644
--- a/src/operator/numpy/np_elemwise_broadcast_op_extended.cu
+++ b/src/operator/numpy/np_elemwise_broadcast_op_extended.cu
@@ -31,6 +31,9 @@
 NNVM_REGISTER_OP(_npi_copysign)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"copysign"});
 
+NNVM_REGISTER_OP(_npi_gcd)
+.set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"gcd"});
+
 NNVM_REGISTER_OP(_npi_lcm)
 .set_attr<FCompute>("FCompute<gpu>", BinaryBroadcastRTCCompute{"lcm"});
 
@@ -82,6 +85,9 @@
 NNVM_REGISTER_OP(_backward_npi_rarctan2_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCBackward{"rarctan2_grad"});
 
+NNVM_REGISTER_OP(_npi_gcd_scalar)
+.set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"gcd"});
+
 NNVM_REGISTER_OP(_npi_lcm_scalar)
 .set_attr<FCompute>("FCompute<gpu>", BinaryScalarRTCCompute{"lcm"});
 
diff --git a/src/operator/operator_tune.cc b/src/operator/operator_tune.cc
index 9af3364..557338e 100644
--- a/src/operator/operator_tune.cc
+++ b/src/operator/operator_tune.cc
@@ -417,6 +417,7 @@
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::bitwise_or);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD(mxnet::op::mshadow_op::smooth_l1_loss);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_BWD(mxnet::op::mshadow_op::smooth_l1_gradient);  // NOLINT()
+IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::gcd);  // NOLINT()
 IMPLEMENT_BINARY_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mshadow_op::lcm);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<0>);  // NOLINT()
 IMPLEMENT_BLANK_WORKLOAD_FWD_WITH_BOOL(mxnet::op::mxnet_op::set_to_int<1>);  // NOLINT()
diff --git a/tests/nightly/test_np_large_array.py b/tests/nightly/test_np_large_array.py
index a423253..fa103fd 100644
--- a/tests/nightly/test_np_large_array.py
+++ b/tests/nightly/test_np_large_array.py
@@ -907,6 +907,22 @@
 
 
 @use_np
+def test_gcd():
+    inp1 = np.ones((2, INT_OVERFLOW), dtype='int32')
+    inp2 = np.ones((2, INT_OVERFLOW), dtype='int32')
+    inp1[-1, -1] = 12  # last element pair is (12, 20); every other pair is (1, 1)
+    inp2[-1, -1] = 20
+    inp1.attach_grad()
+    with mx.autograd.record():
+        out = np.gcd(inp1, inp2)
+        out.backward()
+    assert out.shape == inp1.shape
+    assert out[-1, -1] == 4  # gcd(12, 20)
+    assert inp1.grad.shape == inp1.shape
+    assert inp1.grad[-1, -1] == 0  # _npi_gcd registers MakeZeroGradNodes as FGradient
+
+
+@use_np
 def test_log_family():
     def batch_check(funcs, exp):
         inp.attach_grad()
diff --git a/tests/python/unittest/test_numpy_interoperability.py b/tests/python/unittest/test_numpy_interoperability.py
index a05bc79..1fa7d52 100644
--- a/tests/python/unittest/test_numpy_interoperability.py
+++ b/tests/python/unittest/test_numpy_interoperability.py
@@ -1488,6 +1488,12 @@
     OpArgMngr.add_workload('lcm', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32))
 
 
+def _add_workload_gcd():
+    OpArgMngr.add_workload('gcd', np.array([24, 30], dtype=np.int8), np.array([20, 75], dtype=np.int8))
+    OpArgMngr.add_workload('gcd', np.array([24, 30], dtype=np.uint8), np.array([20, 75], dtype=np.uint8))
+    OpArgMngr.add_workload('gcd', np.array(195225786*2, dtype=np.int32), np.array(195225786*5, dtype=np.int32))
+
+
 def _add_workload_bitwise_or():
     OpArgMngr.add_workload('bitwise_or', np.array([False, False, True, True], dtype=np.bool),
                            np.array([False, True, False, True], dtype=np.bool))
@@ -3071,6 +3077,7 @@
     _add_workload_interp()
     _add_workload_hypot()
     _add_workload_lcm()
+    _add_workload_gcd()
     _add_workload_bitwise_and()
     _add_workload_bitwise_xor()
     _add_workload_bitwise_or()
diff --git a/tests/python/unittest/test_numpy_op.py b/tests/python/unittest/test_numpy_op.py
index 650c420..6bea510 100644
--- a/tests/python/unittest/test_numpy_op.py
+++ b/tests/python/unittest/test_numpy_op.py
@@ -3021,6 +3021,7 @@
                       [[_np.float16, _np.float32, _np.float64], [_np.int32]]),
         'power': (1.0, 3.0, [lambda y, x1, x2: _np.power(x1, x2 - 1.0) * x2],
                              [lambda y, x1, x2: _np.power(x1, x2) * _np.log(x1)]),
+        'gcd': (-100, 100, [None], None, [[_np.int32]]),
         'lcm': (-100, 100, [None], None, [[_np.int32]]),
         'bitwise_and': (-100, 100, [None], None, [[_np.int32]]),
         'bitwise_xor': (-100, 100, [None], None, [[_np.int32]]),