| commit | 6248b5db43505fbcfb13cc289d11877d5d2649e8 | [log] [tgz] |
|---|---|---|
| author | Nguyen Duy Loc <77536430+locnd182644@users.noreply.github.com> | Sat Dec 13 14:29:23 2025 +0700 |
| committer | GitHub <noreply@github.com> | Sat Dec 13 02:29:23 2025 -0500 |
| tree | aed9f110e6068a7152261c49ff35d2ff0066bb3c | |
| parent | 85a877085714b4d10d65e2c267dab3937915e8a1 [diff] |
[Relax][Torch] Fixed issues related to sum op when without dim and keep dim (#18583)
## Issue 1: Without Dim
### Summary:
In the `_sum` function (`BaseFXGraphImporter`), after `retrieve_args`, `args[1]` is `[]` and is still passed to `relax.op.sum` as the axis, so the result is incorrect.
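For context, below is a minimal sketch of the dim-to-axis normalization this fix calls for; the standalone helper `normalize_sum_axis` is an illustrative assumption, not the actual `_sum` body in the importer.
```
from typing import List, Optional

# Hedged sketch: not the importer's code, just the dim -> axis mapping
# described above for torch.sum.
def normalize_sum_axis(dim_arg) -> Optional[List[int]]:
    # torch.sum(x) passes no dim, which surfaces as None or [] after
    # retrieve_args. relax.op.sum(axis=[]) reduces over *no* axes and
    # returns the input unchanged, so the empty case must map to
    # axis=None (reduce over all axes).
    if dim_arg is None or dim_arg == []:
        return None
    if isinstance(dim_arg, int):
        return [dim_arg]
    return list(dim_arg)

# normalize_sum_axis([])  -> None   (full reduction)
# normalize_sum_axis(1)   -> [1]
# normalize_sum_axis([1]) -> [1]
```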
### Steps to Reproduce
- Module
```
class SumWithoutDim(nn.Module):
    def forward(self, x):
        return torch.sum(x)
```
- Generated Relax module (before the fix):
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 3), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 3), dtype="float32") = R.sum(x, axis=[], keepdims=False)
            gv: R.Tuple(R.Tensor((2, 3), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result:
  - Input: `tensor([[1., 1., 1.], [1., 1., 1.]])`
  - Torch output: `tensor(6.)`, shape `torch.Size([])`
  - TVM output: `[[1. 1. 1.] [1. 1. 1.]]`, shape `(2, 3)`
### Expected
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((), dtype="float32") = R.sum(x, axis=None, keepdims=False)
            gv: R.Tuple(R.Tensor((), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result: TVM output: `6.0`; TVM output shape: `()`
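An end-to-end check along the lines below can confirm the expected scalar shape; the `from_exported_program` frontend, `relax.build`, and the Relax VM calls are assumptions about the current TVM Python API and may need adapting to a specific build.
```
import numpy as np
import torch
from torch import nn
from torch.export import export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class SumWithoutDim(nn.Module):
    def forward(self, x):
        return torch.sum(x)


x = torch.ones(2, 3)
expected = SumWithoutDim()(x)  # tensor(6.), shape torch.Size([])

# Import via torch.export, compile for CPU, and run on the Relax VM.
mod = from_exported_program(export(SumWithoutDim(), (x,)))
ex = relax.build(mod, target="llvm")  # newer builds may prefer tvm.compile
vm = relax.VirtualMachine(ex, tvm.cpu())
out = vm["main"](tvm.nd.array(x.numpy()))[0]

print(out.numpy(), out.numpy().shape)  # expect 6.0 and () after the fix
np.testing.assert_allclose(out.numpy(), expected.numpy())
```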
## Issue 2: Keep Dim
### Summary:
In the `_sum` function (`BaseFXGraphImporter`), the `keepdim` value was previously read only from `node.kwargs` and was not passed to `relax.op.sum`. It is now also taken from `args[2]` and forwarded to `relax.op.sum`.
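A hedged sketch of the combined handling (positional `args[2]` plus `node.kwargs`) is shown below; the free-standing `convert_sum` signature and argument unpacking are illustrative, not the exact `_sum` implementation.
```
from tvm import relax


def convert_sum(bb: relax.BlockBuilder, args, kwargs) -> relax.Var:
    """Illustrative sketch of the fix described above.

    keepdim may arrive positionally (torch.sum(x, 1, True)) as args[2]
    or as a keyword (keepdim=True); either way it must be forwarded to
    relax.op.sum instead of being dropped.
    """
    x = args[0]
    axis = args[1] if len(args) > 1 and args[1] else None
    keepdim = args[2] if len(args) > 2 else kwargs.get("keepdim", False)
    return bb.emit(relax.op.sum(x, axis=axis, keepdims=keepdim))
```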
### Steps to Reproduce
- Module
```
class SumKeepDim(nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=True)
```
- Generated Relax module (before the fix):
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2,), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2,), dtype="float32") = R.sum(x, axis=[1], keepdims=False)
            gv: R.Tuple(R.Tensor((2,), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result:
  - Input: `tensor([[1., 1., 1.], [1., 1., 1.]])`
  - Torch output: `tensor([[3.], [3.]])`, shape `torch.Size([2, 1])`
  - TVM VM output: `[3. 3.]`, shape `(2,)`
### Expected
```
class Module:
    def main(x: R.Tensor((2, 3), dtype="float32")) -> R.Tuple(R.Tensor((2, 1), dtype="float32")):
        with R.dataflow():
            lv: R.Tensor((2, 1), dtype="float32") = R.sum(x, axis=[1], keepdims=True)
            gv: R.Tuple(R.Tensor((2, 1), dtype="float32")) = (lv,)
            R.output(gv)
        return gv
```
- Result: TVM output: `[[3.] [3.]]`; TVM output shape: `(2, 1)`
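Both cases can also be wrapped into a small regression check along these lines; the frontend and build entry points are the same assumptions as in the earlier reproduction sketch, not existing TVM test utilities.
```
import numpy as np
import pytest
import torch
from torch import nn
from torch.export import export

import tvm
from tvm import relax
from tvm.relax.frontend.torch import from_exported_program


class SumWithoutDim(nn.Module):
    def forward(self, x):
        return torch.sum(x)


class SumKeepDim(nn.Module):
    def forward(self, x):
        return torch.sum(x, dim=1, keepdim=True)


@pytest.mark.parametrize("model", [SumWithoutDim(), SumKeepDim()])
def test_sum_matches_torch(model):
    x = torch.ones(2, 3)
    expected = model(x)

    mod = from_exported_program(export(model, (x,)))
    ex = relax.build(mod, target="llvm")  # newer builds may prefer tvm.compile
    vm = relax.VirtualMachine(ex, tvm.cpu())
    out = vm["main"](tvm.nd.array(x.numpy()))[0]

    # Both the values and the shape (scalar vs. kept dim) must match torch.
    assert out.numpy().shape == tuple(expected.shape)
    np.testing.assert_allclose(out.numpy(), expected.numpy())
```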