# blob: 76caccb53666490c8ec8206771a2b4a943e0f446 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import tvm.testing
from tvm.ir import VDevice
from tvm.relax.transform import UpdateVDevice
from tvm.script.parser import ir as I, relax as R, tir as T
def verify(input, new_vdevice, vdevice_index, expected):
    """Run ``UpdateVDevice`` on ``input`` and compare against ``expected``.

    The pass is constructed from ``new_vdevice`` and ``vdevice_index``,
    applied to ``input``, and the result is checked with
    ``tvm.ir.assert_structural_equal``.
    """
    update_pass = UpdateVDevice(new_vdevice, vdevice_index)
    actual = update_pass(input)
    tvm.ir.assert_structural_equal(actual, expected)
def test_update():
    """Exercise ``UpdateVDevice`` on two modules through ``verify``.

    Case 1 passes index 3 together with a ``metal`` device; case 2 passes
    index 1 together with an ``llvm`` device. Each expected module lists the
    replacement device at that index of its ``"vdevice"`` global infos and
    annotates every tensor with it.
    """
    # Every device is still constructed, as before; only the last two are
    # handed to the pass.
    device_pool = [
        VDevice("llvm"),
        VDevice("cuda", 0),
        VDevice("metal", 0, "global"),
        VDevice("cuda -arch=sm_80", 0),
        VDevice("metal", 1, "global"),
        VDevice("llvm", 1),
    ]
    metal_replacement = device_pool[4]
    llvm_replacement = device_pool[5]

    # --- Case 1: four registered vdevices, both parameters annotated. ---
    @I.ir_module
    class Input1:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice("metal", 0, "global"),
                    I.vdevice("cuda -arch=sm_80", 0),
                ]
            }
        )

        @R.function
        def main(
            a: R.Tensor((128, 128), dtype="float32", vdevice="cuda:1"),  # noqa: F722
            c: R.Tensor((128, 128), dtype="float32", vdevice="vdevice:3"),  # noqa: F722
        ) -> R.Tensor((128, 128), dtype="float32"):
            s = R.add(a, c)
            return s

    # Differs from Input1 at global-info index 3, and every tensor
    # annotation (params, return, binding) names "metal:1".
    @I.ir_module
    class Expect1:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                    I.vdevice("metal", 0, "global"),
                    I.vdevice("metal", 1, "global"),
                ]
            }
        )

        @R.function
        def main(
            a: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
            c: R.Tensor((128, 128), dtype="float32", vdevice="metal:1"),  # noqa: F722
        ) -> R.Tensor((128, 128), dtype="float32", vdevice="metal:1"):  # noqa: F722
            s: R.Tensor((128, 128), dtype="float32", vdevice="metal:1") = R.add(a, c)  # noqa: F722
            return s

    # --- Case 2: two registered vdevices, both parameters on "cuda:0". ---
    @I.ir_module
    class Input2:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("cuda", 0),
                ]
            }
        )

        @R.function
        def main(
            a: R.Tensor((128, 128), dtype="float32", vdevice="cuda:0"),  # noqa: F722
            c: R.Tensor((128, 128), dtype="float32", vdevice="cuda:0"),  # noqa: F722
        ) -> R.Tensor((128, 128), dtype="float32"):
            s = R.add(a, c)
            return s

    # Differs from Input2 at global-info index 1, and every tensor
    # annotation names "llvm:1".
    @I.ir_module
    class Expect2:
        I.module_attrs({"attr": 10})
        I.module_global_infos(
            {
                "vdevice": [
                    I.vdevice("llvm"),
                    I.vdevice("llvm", 1),
                ]
            }
        )

        @R.function
        def main(
            a: R.Tensor((128, 128), dtype="float32", vdevice="llvm:1"),  # noqa: F722
            c: R.Tensor((128, 128), dtype="float32", vdevice="llvm:1"),  # noqa: F722
        ) -> R.Tensor((128, 128), dtype="float32", vdevice="llvm:1"):  # noqa: F722
            s: R.Tensor((128, 128), dtype="float32", vdevice="llvm:1") = R.add(a, c)  # noqa: F722
            return s

    verify(Input1, metal_replacement, 3, Expect1)
    verify(Input2, llvm_replacement, 1, Expect2)
# Script entry point: when this file is executed directly, hand control to
# tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()