blob: b50c570eb77a189d4d0eed14f4ad93c2473bf90e [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test relax vm through rpc."""
import tvm
import numpy as np
from tvm import rpc, relax
from tvm.contrib import utils, tvmjs
from tvm.script import relax as R
# Host and port of the RPC proxy that test_rpc connects to with key "wasm".
proxy_host, proxy_port = "127.0.0.1", 9090
def get_model():
    """Build the element-wise add module and schedule it for a GPU target.

    A Relax ``main`` adding two ``[1024]`` float32 tensors is lowered with
    ``relax.get_pipeline()``. The loop of block ``T_add`` inside function
    ``add`` is then split by 128, and the two pieces are bound to
    ``blockIdx.x`` / ``threadIdx.x``.

    Returns:
        The scheduled module (``Schedule.mod``).
    """
    lowering_pipeline = relax.get_pipeline()

    @tvm.script.ir_module
    class Mod:
        @R.function
        def main(x: R.Tensor([1024], "float32"), y: R.Tensor([1024], "float32")):
            lv0 = R.add(x, y)
            return lv0

    schedule = tvm.s_tir.Schedule(lowering_pipeline(Mod))
    # Manually transform the legalized add loop: split it by 128 and map the
    # outer/inner parts onto GPU block and thread axes.
    schedule.work_on("add")
    add_block = schedule.get_sblock("T_add")
    (loop,) = schedule.get_loops(block=add_block)
    outer, inner = schedule.split(loop, [None, 128])
    for loop_rv, thread_tag in ((outer, "blockIdx.x"), (inner, "threadIdx.x")):
        schedule.bind(loop_rv, thread_tag)
    return schedule.mod
def test_rpc():
    """Compile the add model to wasm/WebGPU and verify it over an RPC proxy.

    The module from ``get_model`` is built for a ``webgpu`` target with a
    wasm32 LLVM host and exported as ``relax.wasm`` via
    ``tvmjs.create_tvmjs_wasm``. The test then connects to the proxy at
    ``proxy_host:proxy_port`` with key ``"wasm"``, starting an
    ``rpc.WasmSession`` with the compiled binary. Finally it runs ``main``
    on the remote WebGPU device and compares the output with NumPy.

    Returns early without checking anything when ``tvm.runtime.enabled("rpc")``
    is false.
    """
    if not tvm.runtime.enabled("rpc"):
        return
    # Must match the static [1024] float32 input shape declared in get_model.
    n = 1024
    dtype = "float32"
    temp = utils.tempdir()
    wasm_path = temp.relpath("relax.wasm")
    target = tvm.target.Target(
        "webgpu", host={"kind": "llvm", "mtriple": "wasm32-unknown-unknown-wasm"}
    )

    mod = get_model()
    ex = relax.build(mod, target)
    ex.export_library(wasm_path, fcompile=tvmjs.create_tvmjs_wasm)
    # Close the file deterministically instead of leaking the handle.
    with open(wasm_path, "rb") as wasm_file:
        wasm_binary = wasm_file.read()

    remote = rpc.connect(
        proxy_host,
        proxy_port,
        key="wasm",
        session_constructor_args=["rpc.WasmSession", wasm_binary],
    )

    def check(remote):
        """Run ``main`` on ``remote``'s WebGPU device 0 and verify the sum."""
        dev = remote.webgpu(0)
        # invoke the function
        vm = relax.VirtualMachine(remote.system_lib(), device=dev)
        adata = np.random.uniform(size=n).astype(dtype)
        bdata = np.random.uniform(size=n).astype(dtype)
        a = tvm.runtime.tensor(adata, dev)
        b = tvm.runtime.tensor(bdata, dev)
        vm.set_input("main", a, b)
        vm.invoke_stateful("main")
        c = vm.get_outputs("main")
        # Compare against the host-side source data: this avoids two extra
        # device-to-host copies and also catches corrupted uploads.
        np.testing.assert_equal(c.numpy(), adata + bdata)

    check(remote)
if __name__ == "__main__":
    # Run only when executed as a script. Importing this module (e.g. during
    # pytest collection) should not trigger a build and an RPC connection.
    test_rpc()