[TOPI] Expose `topi::collapse_sum` to Python and support symbolic shape (#14541)

TOPI has an internal implementation of collapse_sum (include/tvm/topi/reduction.h), but it is not exposed through the FFI and therefore cannot be called from Python. This patch registers the operator, adds a Python wrapper with tests, and extends the implementation of topi::collapse_sum to support symbolic shapes.
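
For reference, a minimal sketch of how the newly exposed operator is used from Python (it mirrors the new test below; names and shapes are illustrative):

```python
import numpy as np
import tvm
from tvm import te, tir, topi

# Static-shape case: collapse (2, 3) down to (3,) by summing over axis 0.
A = te.placeholder((2, 3), name="A")
B = topi.collapse_sum(A, (3,))
s = te.create_schedule([B.op])
f = tvm.build(s, [A, B], "llvm", name="collapse_sum")

a = tvm.nd.array(np.random.uniform(size=(2, 3)).astype("float32"))
b = tvm.nd.array(np.zeros((3,), dtype="float32"))
f(a, b)
np.testing.assert_allclose(b.numpy(), a.numpy().sum(axis=0), rtol=1e-5)

# Symbolic-shape case, enabled by this patch: shapes may contain tir.Var.
n = tir.Var("n", "int32")
C = te.placeholder((2, n), name="C")
D = topi.collapse_sum(C, (n,))  # sums over axis 0, keeps the symbolic axis
```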
diff --git a/include/tvm/topi/reduction.h b/include/tvm/topi/reduction.h
index 169ae01..9090ae9 100644
--- a/include/tvm/topi/reduction.h
+++ b/include/tvm/topi/reduction.h
@@ -333,21 +333,28 @@
 }
 
 inline Tensor collapse_sum(const Tensor& data, Array<PrimExpr> target_shape) {
-  ICHECK_GE(data->shape.size(), target_shape.size());
-  auto ishape = detail::GetConstIntValues(data->shape, "ishape");
-  auto oshape = detail::GetConstIntValues(target_shape, "oshape");
+  const auto& ishape = data->shape;
+  const auto& oshape = target_shape;
+  int isize = data->shape.size();
+  int osize = target_shape.size();
+
+  ICHECK_GE(isize, osize)
+      << "Invalid collapse: input dimensionality smaller than output dimensionality.\ninput shape: "
+      << data->shape << "\nvs\noutput shape: " << target_shape;
 
   std::vector<int> reduce_axes;
   std::vector<int> squeeze_axes;
-  for (int i_ax = ishape.size() - 1, o_ax = oshape.size() - 1; i_ax >= 0; --i_ax) {
-    if (o_ax >= 0 && ishape[i_ax] == oshape[o_ax]) {
+  tvm::PrimExpr one(1);
+
+  for (int i_ax = isize - 1, o_ax = osize - 1; i_ax >= 0; --i_ax) {
+    if (o_ax >= 0 && topi::detail::EqualCheck(ishape[i_ax], oshape[o_ax])) {
       --o_ax;
       continue;
     }
     reduce_axes.push_back(i_ax);
     if (o_ax < 0) {  // squeeze o_ax if was added during expansion
       squeeze_axes.push_back(i_ax);
-    } else if (oshape[o_ax] == 1) {
+    } else if (topi::detail::EqualCheck(one, oshape[o_ax])) {
       --o_ax;
     }
   }
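
The key change above: the shapes are no longer forced through detail::GetConstIntValues, which ICHECK-fails on any dimension that is not a constant int. The loop now compares the PrimExprs directly with topi::detail::EqualCheck, so a tir.Var can flow through. Roughly, the behavior difference (a sketch; names are illustrative):

```python
from tvm import te, tir, topi

n = tir.Var("n", "int32")
X = te.placeholder((2, n, 4), name="X")
# Before this patch: GetConstIntValues rejects `n` because it is not a constant.
# After: EqualCheck structurally matches `n` against `n`, so only axis 0 is reduced.
Y = topi.collapse_sum(X, (n, 4))  # Y has shape (n, 4)
```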
diff --git a/python/tvm/topi/reduction.py b/python/tvm/topi/reduction.py
index 45d07af..5045cb8 100644
--- a/python/tvm/topi/reduction.py
+++ b/python/tvm/topi/reduction.py
@@ -248,3 +248,34 @@
     ret : tvm.te.Tensor
     """
     return cpp.prod(data, axis, keepdims)
+
+
+def collapse_sum(data, target_shape):
+    """Return a summation of data to the given shape.
+
+    collapse_sum is intended as the backward operator of the topi broadcast operators in the
+    automatic differentiation process.
+
+    We expect data to be the result of broadcasting a tensor of target_shape through some
+    broadcast operation; hence target_shape and data.shape must follow the broadcast rules.
+
+    During computation, the axes of data.shape and target_shape are compared from right to left.
+    Data is summed over every axis that either:
+    - exists in data but not in target_shape, or
+    - is larger than 1 in data and equal to 1 in target_shape.
+    The result then has exactly target_shape.
+
+    Parameters
+    ----------
+    data : tvm.te.Tensor
+        The input tensor.
+
+    target_shape : Tuple[int]
+        The shape to collapse to.
+
+    Returns
+    -------
+    ret : tvm.te.Tensor
+        The result tensor after summation.
+    """
+    return cpp.collapse_sum(data, target_shape)
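
To make the right-to-left rule in the docstring concrete, here is the NumPy equivalent for one of the test cases below, (2, 3, 4, 5) collapsed to (3, 1, 5):

```python
import numpy as np

data = np.random.rand(2, 3, 4, 5).astype("float32")
# Matching right to left against target (3, 1, 5):
#   axis 3 (5) vs 5      -> kept
#   axis 2 (4) vs 1      -> summed; the size-1 axis stays in the target
#   axis 1 (3) vs 3      -> kept
#   axis 0 (2) vs <none> -> summed and squeezed
ref = data.sum(axis=(0, 2)).reshape(3, 1, 5)
```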
diff --git a/src/topi/reduction.cc b/src/topi/reduction.cc
index 3d1c6f9..a9d692c 100644
--- a/src/topi/reduction.cc
+++ b/src/topi/reduction.cc
@@ -64,5 +64,9 @@
   *rv = topi::any(args[0], ArrayOrInt(args[1]), args[2]);
 });
 
+TVM_REGISTER_GLOBAL("topi.collapse_sum").set_body([](TVMArgs args, TVMRetValue* rv) {
+  *rv = topi::collapse_sum(args[0], args[1]);
+});
+
 }  // namespace topi
 }  // namespace tvm
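
TVM_REGISTER_GLOBAL publishes the operator as a packed function named "topi.collapse_sum"; the `cpp.collapse_sum` call in the Python wrapper above resolves to this registration. The packed function can also be fetched directly, though the wrapper is the intended entry point (a sketch):

```python
import tvm
from tvm import te

f = tvm.get_global_func("topi.collapse_sum")
A = te.placeholder((2, 3), name="A")
B = f(A, (3,))  # equivalent to topi.collapse_sum(A, (3,))
```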
diff --git a/tests/python/topi/python/test_topi_reduce.py b/tests/python/topi/python/test_topi_reduce.py
index 3c4c170..8e45ae9 100644
--- a/tests/python/topi/python/test_topi_reduce.py
+++ b/tests/python/topi/python/test_topi_reduce.py
@@ -25,7 +25,7 @@
 import tvm.testing
 import tvm.topi.testing
 
-from tvm import te, topi
+from tvm import te, topi, tir
 
 in_shape, axis, keepdims, reduce_type, dtype = tvm.testing.parameters(
     ((32,), 0, False, "argmax", "float32"),
@@ -191,5 +191,53 @@
     tvm.testing.assert_allclose(out_tvm.numpy(), out_npy, 1e-3, 1e-3)
 
 
+n = tir.Var("n", "int32")
+m = tir.Var("m", "int32")
+true_value_map = {n: 3, m: 5}
+
+data_shape, target_shape = tvm.testing.parameters(
+    ((2, 3), (3,)),
+    ((2, 3, 4), (2, 1, 4)),
+    ((2, 3, 4, 5), (3, 1, 5)),
+    ((2, n, 4, m), (n, 1, m)),
+)
+
+
+def _my_npy_collapse_sum(data, target_shape):
+    reduce_axes = []
+    i = data.ndim - 1
+    j = len(target_shape) - 1
+    while i >= 0:
+        if j < 0:
+            reduce_axes.append(i)
+        elif target_shape[j] == 1 and data.shape[i] > 1:
+            reduce_axes.append(i)
+        i -= 1
+        j -= 1
+    return np.sum(data, tuple(reduce_axes)).reshape(target_shape)
+
+
+def test_collapse_sum(data_shape, target_shape):
+    A = te.placeholder(data_shape, name="A")
+    B = topi.collapse_sum(A, target_shape)
+    s = te.create_schedule([B.op])
+
+    data_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in A.shape]
+    target_shape_const = [
+        int(s) if s not in true_value_map else true_value_map[s] for s in target_shape
+    ]
+    a_np = np.random.uniform(size=data_shape_const).astype(A.dtype)
+    b_np = _my_npy_collapse_sum(a_np, target_shape_const)
+    dev = tvm.cpu(0)
+    a = tvm.nd.array(a_np, dev)
+    B_shape_const = [int(s) if s not in true_value_map else true_value_map[s] for s in B.shape]
+    b = tvm.nd.array(np.zeros(B_shape_const, dtype=B.dtype), dev)
+    # Building with the CSE pass disabled
+    with tvm.transform.PassContext(opt_level=3, disabled_pass=["tir.CommonSubexprElimTIR"]):
+        foo = tvm.build(s, [A, B], "llvm", name="collapse_sum")
+    foo(a, b)
+    tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-5)
+
+
 if __name__ == "__main__":
     tvm.testing.main()