# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import inspect
import pytest
import tvm.testing
from tvm import relax
from tvm.script import ir as I, relax as R, tir as T
class Base:
def test_compare(self):
transform = relax.transform.AdjustMatmulOrder()
if inspect.isclass(self.Expected) and issubclass(self.Expected, Exception):
with pytest.raises(self.Expected):
transform(self.Before)
else:
after = transform(self.Before)
tvm.ir.assert_structural_equal(self.Expected, after)
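# A minimal usage sketch (not part of the test suite): AdjustMatmulOrder is a
# module-level pass, so it is instantiated and then applied to an IRModule,
# returning the rewritten module that the harness above compares structurally.
def _run_adjust_matmul_order(mod):
    """Illustrative helper mirroring the call in `Base.test_compare`."""
    return relax.transform.AdjustMatmulOrder()(mod)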
class TestLHS(Base):
"""Prefer (x*A)*B instead of x*(A*B)"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([16, 2]),
B: R.Tensor([2, 32]),
) -> R.Tensor([32]):
weight: R.Tensor([16, 32]) = R.matmul(A, B)
out: R.Tensor([32]) = R.matmul(x, weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([16, 2]),
B: R.Tensor([2, 32]),
) -> R.Tensor([32]):
x: R.Tensor([2]) = R.matmul(x, A)
x: R.Tensor([32]) = R.matmul(x, B)
return x
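# Worked multiply-add counts for the shapes in TestLHS (illustrative only):
# evaluating (x*A)*B costs 16*2 + 2*32 = 96 MACs, while x*(A*B) costs
# 16*2*32 + 16*32 = 1536 MACs, so the rewritten order is cheaper.  The
# mirrored shapes in TestRHS below give the same counts.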
class TestRHS(Base):
"""Prefer A*(B*x) instead of (A*B)*x"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, 2]),
B: R.Tensor([2, 16]),
) -> R.Tensor([32]):
weight: R.Tensor([32, 16]) = R.matmul(A, B)
out: R.Tensor([32]) = R.matmul(weight, x)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, 2]),
B: R.Tensor([2, 16]),
) -> R.Tensor([32]):
x: R.Tensor([2]) = R.matmul(B, x)
x: R.Tensor([32]) = R.matmul(A, x)
return x
class TestIdempotentLHS(Base):
"""The transform shouldn't undo itself if re-applied"""
Before = TestLHS.Expected
Expected = TestLHS.Expected
class TestIdempotentRHS(Base):
"""The transform shouldn't undo itself if re-applied"""
Before = TestRHS.Expected
Expected = TestRHS.Expected
class TestPreserveCompileTimeMatmulOnLHS(Base):
"""Prefer x*(A*B) if (A*B) can be pre-computed
If both `A` and `B` are known at compile-time, they can be lifted
out and computed at compile-time. Therefore, optimization should
avoid breaking apart the `(A*B)` expression.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, 2]),
B: R.Tensor([2, 16]),
) -> R.Tensor([32]):
R.func_attr({"num_input": 1})
weight = R.matmul(A, B)
out = R.matmul(weight, x)
return out
Expected = Before
class TestPreserveCompileTimeMatmulOnRHS(Base):
"""Prefer (A*B)*x if (A*B) can be pre-computed
If both `A` and `B` are known at compile-time, they can be lifted
out and computed at compile-time. Therefore, optimization should
avoid breaking apart the `(A*B)` expression.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([16, 2]),
B: R.Tensor([2, 32]),
) -> R.Tensor([32]):
R.func_attr({"num_input": 1})
weight = R.matmul(A, B)
out = R.matmul(x, weight)
return out
Expected = Before
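# Note on the two tests above (illustrative, not exercised here): the
# "num_input" attribute marks how many leading parameters are runtime inputs;
# the remaining parameters (`A` and `B`) are compile-time weights, so `A*B`
# can be folded ahead of time, for example by a parameter-lifting pass such
# as `relax.transform.LiftTransformParams`.  Splitting that product apart
# would forfeit the compile-time computation, which is why
# `AdjustMatmulOrder` leaves these modules unchanged.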
class TestLHSDynamic(Base):
"""Prefer (x*A)*B instead of x*(A*B)
This case appears when evaluating LoRA-tuned models with a dynamic
rank.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([16, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor([32]):
weight: R.Tensor([16, 32]) = R.matmul(A, B)
out: R.Tensor([32]) = R.matmul(x, weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([16, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor([32]):
lora_r = T.int64()
x: R.Tensor([lora_r]) = R.matmul(x, A)
x: R.Tensor([32]) = R.matmul(x, B)
return x
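# Worked cost comparison for TestLHSDynamic (illustrative only): for any value
# of lora_r, evaluating (x*A)*B costs 16*lora_r + lora_r*32 = 48*lora_r MACs,
# while x*(A*B) costs 16*lora_r*32 + 16*32 = 512*lora_r + 512, so the rewrite
# is profitable without needing an upper bound on lora_r.  The mirrored
# TestRHSDynamic below gives the same counts.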
class TestRHSDynamic(Base):
"""Prefer A*(B*x) instead of (A*B)*x"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor([32]):
weight: R.Tensor([32, 16]) = R.matmul(A, B)
out: R.Tensor([32]) = R.matmul(weight, x)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor([32]):
lora_r = T.int64()
x: R.Tensor([lora_r]) = R.matmul(B, x)
x: R.Tensor([32]) = R.matmul(A, x)
return x
class TestIdempotentLHSDynamic(Base):
"""The transform shouldn't undo itself if re-applied"""
Before = TestLHSDynamic.Expected
Expected = TestLHSDynamic.Expected
class TestIdempotentRHSDynamic(Base):
"""The transform shouldn't undo itself if re-applied"""
Before = TestRHSDynamic.Expected
Expected = TestRHSDynamic.Expected
class TestLHSDynamicWithBatch(Base):
"""Prefer (x*A)*B instead of x*(A*B)"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor(["batch_size", 1, 16]),
A: R.Tensor([16, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor(["batch_size", 1, 32]):
batch_size = T.int64()
weight: R.Tensor([16, 32]) = R.matmul(A, B)
out: R.Tensor([batch_size, 1, 32]) = R.matmul(x, weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(["batch_size", 1, 16]),
A: R.Tensor([16, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor(["batch_size", 1, 32]):
lora_r = T.int64()
batch_size = T.int64()
x: R.Tensor([batch_size, 1, lora_r]) = R.matmul(x, A)
x: R.Tensor([batch_size, 1, 32]) = R.matmul(x, B)
return x
class TestRHSDynamicWithBatch(Base):
"""Prefer A*(B*x) instead of (A*B)*x"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor(["batch_size", 16, 1]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor(["batch_size", 32, 1]):
batch_size = T.int64()
weight: R.Tensor([32, 16]) = R.matmul(A, B)
out: R.Tensor([batch_size, 32, 1]) = R.matmul(weight, x)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(["batch_size", 16, 1]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor(["batch_size", 32, 1]):
lora_r = T.int64()
batch_size = T.int64()
x: R.Tensor([batch_size, lora_r, 1]) = R.matmul(B, x)
x: R.Tensor([batch_size, 32, 1]) = R.matmul(A, x)
return x
class TestNoOpForFullyDynamicOnLHS(Base):
"""Keep existing order if no benefit can be proven
Like `TestNoOpForFullyDynamicOnRHS`, except the input has the LHS
computed first.
Here, it is uncertain whether a reordering would improve the
number of operations.
LHS first: (N+Q)*M*P
RHS first: (M+P)*N*Q
After dividing both costs by M*N*P*Q, the LHS should be performed
first if the following inequality holds:
1/N + 1/Q < 1/M + 1/P
"""
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor(["M", "N"]),
B: R.Tensor(["N", "P"]),
C: R.Tensor(["P", "Q"]),
):
out = R.matmul(R.matmul(A, B), C)
return out
Expected = Before
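# Illustrative sketch of the cost formulas from the docstring above (the
# helper below is hypothetical and not used by the tests): for A:[M,N],
# B:[N,P], C:[P,Q], each evaluation order has a closed-form multiply-add
# count, and neither can be proven smaller when all four extents are unknown.
def _matmul_chain_costs(M, N, P, Q):
    """Return (lhs_first, rhs_first) multiply-add counts for (A*B)*C."""
    lhs_first = (N + Q) * M * P  # compute (A*B) first
    rhs_first = (M + P) * N * Q  # compute (B*C) first
    return lhs_first, rhs_first
# e.g. _matmul_chain_costs(2, 16, 2, 16) == (128, 1024) favors LHS-first,
# while _matmul_chain_costs(16, 2, 16, 2) == (1024, 128) favors RHS-first.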
class TestNoOpForFullyDynamicOnRHS(Base):
"""Keep existing order if no benefit can be proven
Like `TestNoOpForFullyDynamicOnLHS`, except the input has the RHS
computed first.
"""
@I.ir_module
class Before:
@R.function
def main(
A: R.Tensor(["M", "N"]),
B: R.Tensor(["N", "P"]),
C: R.Tensor(["P", "Q"]),
):
out = R.matmul(A, R.matmul(B, C))
return out
Expected = Before
class TestRHSPermuteDims(Base):
"""Prefer (x*A)*B instead of x*(A*B)
Like `TestRHS`, but the combined weight (A*B) is transposed so that
it appears on the RHS of the final matmul.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, 2]),
B: R.Tensor([2, 16]),
) -> R.Tensor([32]):
linear_weight: R.Tensor([32, 16]) = R.matmul(A, B)
matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight)
out: R.Tensor([32]) = R.matmul(x, matmul_weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, 2]),
B: R.Tensor([2, 16]),
) -> R.Tensor([32]):
B_transpose = R.permute_dims(B)
x: R.Tensor([2]) = R.matmul(x, B_transpose)
A_transpose = R.permute_dims(A)
x: R.Tensor([32]) = R.matmul(x, A_transpose)
return x
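# This rewrite (and the permute_dims variants below) relies on the identity
# (A*B)^T == B^T * A^T, so x * (A*B)^T can instead be evaluated as
# (x * B^T) * A^T.  A hypothetical numpy sanity check of that identity
# (not part of the original tests; numpy is a hard dependency of TVM):
def _check_transpose_identity():
    import numpy as np

    rng = np.random.default_rng(0)
    A = rng.standard_normal((32, 2))
    B = rng.standard_normal((2, 16))
    x = rng.standard_normal(16)
    np.testing.assert_allclose(x @ (A @ B).T, (x @ B.T) @ A.T, rtol=1e-6, atol=1e-6)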
class TestRHSPermuteDimsDynamic(Base):
"""Prefer (x*A)*B instead of x*(A*B)
Like `TestRHSPermuteDims`, but the weights on the RHS have a
dynamic shape.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor([32]):
linear_weight: R.Tensor([32, 16]) = R.matmul(A, B)
matmul_weight: R.Tensor([16, 32]) = R.permute_dims(linear_weight)
out: R.Tensor([32]) = R.matmul(x, matmul_weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([16]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 16]),
) -> R.Tensor([32]):
lora_r = T.int64()
B_transpose = R.permute_dims(B)
x: R.Tensor([lora_r]) = R.matmul(x, B_transpose)
A_transpose = R.permute_dims(A)
x: R.Tensor([32]) = R.matmul(x, A_transpose)
return x
class TestRHSPermuteDimsWithDynamicBatch(Base):
"""Prefer (x*A)*B instead of x*(A*B)
Like `TestRHSPermuteDims`, but both the weights on the RHS and the
activations on the LHS have a dynamic dimension.
Unlike most of the tests for this transform, the
`tir_var_upper_bound` attribute is required. In order to make a
change, `AdjustMatmulOrder` must first prove that the modified
execution order reduces the number of computations.
ops_left_to_right = (4096 + 4096)*batch_size*lora_r
ops_right_to_left = (batch_size + lora_r)*4096*4096
Without an upper bound on `lora_r`, we cannot prove which of these
is the preferred execution order. With the upper bound, TVM can
determine the preferred order using the following arithmetic
reasoning.
(4096 + 4096)*batch_size*lora_r < (batch_size + lora_r)*4096*4096
batch_size*lora_r < (batch_size + lora_r)*2048
1/2048 < 1/batch_size + 1/lora_r
Since `lora_r <= 2048` implies `1/lora_r >= 1/2048`, this always
holds, so the left-to-right order is preferred.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor(["batch_size", 4096]),
A: R.Tensor([4096, "lora_r"]),
B: R.Tensor(["lora_r", 4096]),
) -> R.Tensor(["batch_size", 4096]):
R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
batch_size = T.int64()
linear_weight: R.Tensor([4096, 4096]) = R.matmul(A, B)
matmul_weight: R.Tensor([4096, 4096]) = R.permute_dims(linear_weight)
out: R.Tensor([batch_size, 4096]) = R.matmul(x, matmul_weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor(["batch_size", 4096]),
A: R.Tensor([4096, "lora_r"]),
B: R.Tensor(["lora_r", 4096]),
) -> R.Tensor(["batch_size", 4096]):
R.func_attr({"tir_var_upper_bound": {"lora_r": 2048}})
lora_r = T.int64()
batch_size = T.int64()
B_transpose = R.permute_dims(B)
x: R.Tensor([batch_size, lora_r]) = R.matmul(x, B_transpose)
A_transpose = R.permute_dims(A)
x: R.Tensor([batch_size, 4096]) = R.matmul(x, A_transpose)
return x
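# Illustrative instantiation of the bound reasoning above (values assumed):
# at the corner batch_size == lora_r == 2048, the left-to-right order costs
# (4096 + 4096)*2048*2048 = 2**35 MACs versus (2048 + 2048)*4096*4096 = 2**36
# for right-to-left, and shrinking either extent only widens the gap in
# favor of the rewritten order.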
class TestRHSPermuteDimsDynamicWithSquareMatrix(Base):
"""Prefer (x*A)*B instead of x*(A*B)
Like `TestRHSPermuteDimsDynamic`, but the combined weight matrix
(A*B) is square.
"""
@I.ir_module
class Before:
@R.function
def main(
x: R.Tensor([32]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor([32]):
linear_weight: R.Tensor([32, 32]) = R.matmul(A, B)
matmul_weight: R.Tensor([32, 32]) = R.permute_dims(linear_weight)
out: R.Tensor([32]) = R.matmul(x, matmul_weight)
return out
@I.ir_module
class Expected:
@R.function
def main(
x: R.Tensor([32]),
A: R.Tensor([32, "lora_r"]),
B: R.Tensor(["lora_r", 32]),
) -> R.Tensor([32]):
lora_r = T.int64()
B_transpose = R.permute_dims(B)
x: R.Tensor([lora_r]) = R.matmul(x, B_transpose)
A_transpose = R.permute_dims(A)
x: R.Tensor([32]) = R.matmul(x, A_transpose)
return x
if __name__ == "__main__":
tvm.testing.main()