[BugFix][Relax][ONNX] Resolve param Vars in Concat to handle mixed Shape/Tensor inputs (#19498)
## Description
When `from_onnx(model, keep_params_in_input=True)` is used, every ONNX
initializer becomes a `relax.Var` instead of a `relax.Constant`. The
`Concat` handler's `is_shape_like()` check only recognizes
`relax.ShapeExpr` and 1D-int64 `relax.Constant`, so a 1D-int64 shape
value loaded as a Var is no longer recognized.
When such a Var is concatenated with a `ShapeExpr` — the standard
pattern for dynamic-batch `Reshape` in PyTorch-exported ONNX models —
the heterogeneous `Tuple(ShapeExpr, Tensor)` is rejected by
`relax.op.concat` with:
```
InternalError: Op(relax.concat) expects the input to be a Tuple of Tensors.
However, the given input is R.Tuple(R.Shape([N]), R.Tensor((1,), dtype="int64"))
```
This effectively breaks `keep_params_in_input=True` for any model with
dynamic-batch `Reshape` (extremely common in PyTorch ONNX exports).
## Fix
Resolve each `Concat` input before the `is_shape_like` check: a local
`resolve` helper (mirroring the existing `get_constant` logic) replaces
any `Var` that maps to a known 1D-int64 param with its baked `Constant`,
restoring the all-shape-like fast path. Inputs that are not shape-like
params are passed through unchanged, so the tensor fallback still sees
the original `Var`s and runtime weights are not folded.
## Minimal repro
An 8-node ONNX graph (`Shape` → `Slice` → `Concat([dyn_n, [12]])` →
`Reshape`) fails with `keep_params_in_input=True` before this PR and
passes after. Two regression tests (`test_concat_with_param_shape_value`
and `test_concat_with_param_tensor_keeps_runtime_param`) cover the
shape-like fast path and the tensor fallback respectively.
## Testing
```
pytest tests/python/relax/test_frontend_onnx.py -k concat
```
10 passed (2 new + 8 existing).
diff --git a/python/tvm/relax/frontend/onnx/onnx_frontend.py b/python/tvm/relax/frontend/onnx/onnx_frontend.py
index 9d65fe0..268d91b 100644
--- a/python/tvm/relax/frontend/onnx/onnx_frontend.py
+++ b/python/tvm/relax/frontend/onnx/onnx_frontend.py
@@ -1014,6 +1014,7 @@
@classmethod
def _impl_v13(cls, bb, inputs, attr, params):
axis = attr.get("axis", 0)
+ _, param_dict = params
def is_shape_like(x: Any) -> bool:
if isinstance(x, relax.ShapeExpr):
@@ -1023,10 +1024,22 @@
else:
return False
+ # Resolve 1D-int64 param Vars to constants only for the shape-like
+ # fast path; tensor fallback keeps the original Vars so runtime
+ # weights aren't folded under keep_params_in_input=True.
+ def resolve(x):
+ if isinstance(x, relax.Var) and x.name_hint in param_dict:
+ arr = param_dict[x.name_hint][1].numpy()
+ if arr.ndim == 1 and arr.dtype == _np.int64:
+ return relax.const(arr, "int64")
+ return x
+
+ resolved = [resolve(inp) for inp in inputs]
+
# If all inputs are shape expr, perform computation directly.
- if all([is_shape_like(inp) for inp in inputs]):
+ if all([is_shape_like(inp) for inp in resolved]):
const_inputs = []
- for inp in inputs:
+ for inp in resolved:
if isinstance(inp, relax.ShapeExpr):
const_inputs.extend(inp.values)
elif isinstance(inp, relax.Constant):
diff --git a/tests/python/relax/test_frontend_onnx.py b/tests/python/relax/test_frontend_onnx.py
index db68476..5a8d84b 100644
--- a/tests/python/relax/test_frontend_onnx.py
+++ b/tests/python/relax/test_frontend_onnx.py
@@ -29,7 +29,7 @@
import onnxruntime
import pytest
import tvm_ffi
-from onnx import ModelProto, TensorProto, helper
+from onnx import ModelProto, TensorProto, helper, numpy_helper
import tvm
import tvm.testing
@@ -533,6 +533,57 @@
verify_binary("Concat", [1, 32], [1, 32], [2, 32], attrs={"axis": 0})
+def test_concat_with_param_shape_value():
+ """Concat must handle a 1D-int64 initializer mixed with a ShapeExpr when
+ keep_params_in_input=True. Standard pattern in PyTorch-exported ONNX
+ models for dynamic-batch Reshape: Reshape(x, Concat(Shape(x)[:1], [12]))."""
+ inp = helper.make_tensor_value_info("x", TensorProto.FLOAT, ["N", 3, 4])
+ out = helper.make_tensor_value_info("y", TensorProto.FLOAT, ["N", 12])
+ twelve = numpy_helper.from_array(np.array([12], dtype=np.int64), "twelve")
+ starts = numpy_helper.from_array(np.array([0], dtype=np.int64), "starts")
+ ends = numpy_helper.from_array(np.array([1], dtype=np.int64), "ends")
+ nodes = [
+ helper.make_node("Shape", ["x"], ["x_shape"]),
+ helper.make_node("Slice", ["x_shape", "starts", "ends"], ["dyn_n"]),
+ helper.make_node("Concat", ["dyn_n", "twelve"], ["new_shape"], axis=0),
+ helper.make_node("Reshape", ["x", "new_shape"], ["y"]),
+ ]
+ graph = helper.make_graph(
+ nodes, "concat_param_shape", [inp], [out],
+ initializer=[twelve, starts, ends],
+ )
+ model = helper.make_model(
+ graph, opset_imports=[helper.make_opsetid("", 13)]
+ )
+ model.ir_version = 8
+ onnx.checker.check_model(model)
+ # Both modes should succeed; previously True crashed with
+ # "Op(relax.concat) expects the input to be a Tuple of Tensors".
+ from_onnx(model, keep_params_in_input=False)
+ from_onnx(model, keep_params_in_input=True)
+
+
+def test_concat_with_param_tensor_keeps_runtime_param():
+ """Concat(input, weight) under keep_params_in_input=True must keep `weight`
+ as a runtime param, not fold it into a constant."""
+ weight_np = np.arange(8, dtype=np.float32).reshape(2, 4)
+ graph = helper.make_graph(
+ [helper.make_node("Concat", ["x", "w"], ["y"], axis=0)],
+ "concat_param_tensor",
+ [helper.make_tensor_value_info("x", TensorProto.FLOAT, [2, 4])],
+ [helper.make_tensor_value_info("y", TensorProto.FLOAT, [4, 4])],
+ initializer=[numpy_helper.from_array(weight_np, "w")],
+ )
+ model = helper.make_model(graph, opset_imports=[helper.make_opsetid("", 13)])
+ model.ir_version = 8
+ onnx.checker.check_model(model)
+
+ mod, params = relax.frontend.detach_params(from_onnx(model, keep_params_in_input=True))
+ assert "w" in [p.name_hint for p in mod["main"].params]
+ assert len(params["main"]) == 1
+ np.testing.assert_array_equal(params["main"][0].numpy(), weight_np)
+
+
@pytest.mark.parametrize("op_name", ["Add", "Sub", "Mul", "Div", "Pow"])
def test_binary(op_name: str):
verify_binary(op_name, [1, 32], [1, 32], [1, 32])