blob: 07aa0bd752df59362e3921358d59a9d0ab3ad2d3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import numpy as np
import tvm
import tvm.script.relax as R
from tvm.contrib import ndk, utils
from tvm.relax.backend.contrib.nnapi import partition_for_nnapi
# pylint: disable=import-outside-toplevel,missing-function-docstring
def reshape_matmul(mod: tvm.IRModule):
    """Rewrite rank-2 ``relax.matmul`` calls in ``mod["main"]`` as rank-3 matmuls.

    Every ``relax.matmul`` call whose two operands both report a shape of
    length 2 is replaced by: reshape each operand to ``[1, *shape]``, multiply
    the reshaped operands, then reshape the product back to the shape of the
    original call. Any other matmul call is returned unchanged.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``"main"`` function is rewritten.

    Returns
    -------
    The same ``mod`` object, with ``mod["main"]`` reassigned to the rewritten
    function.
    """
    from typing import Dict

    from tvm.relax import Expr
    from tvm.relax.dpl import DFPattern, rewrite_call
    from tvm.relax.dpl.pattern import is_op, wildcard

    lhs_pat = wildcard()
    rhs_pat = wildcard()
    matmul_pat = is_op("relax.matmul")(lhs_pat, rhs_pat)

    def _rewriter(expr: Expr, matches: Dict[DFPattern, Expr]):
        lhs = matches[lhs_pat]
        rhs = matches[rhs_pat]

        # Only the 2-D x 2-D case is rewritten; everything else passes through.
        if len(lhs.struct_info.shape) != 2 or len(rhs.struct_info.shape) != 2:
            return expr

        # Prepend a unit dimension to each operand's shape.
        lhs_shape = [1, *lhs.struct_info.shape.values]
        rhs_shape = [1, *rhs.struct_info.shape.values]
        # Shape of the original matmul, used to restore the result's rank.
        out_shape = matches[matmul_pat].struct_info.shape

        lhs_3d = R.reshape(lhs, lhs_shape)
        rhs_3d = R.reshape(rhs, rhs_shape)
        product = R.matmul(lhs_3d, rhs_3d)
        return R.reshape(product, out_shape)

    mod["main"] = rewrite_call(matmul_pat, _rewriter, mod["main"])
    return mod
def decompose_clip(mod: tvm.IRModule) -> tvm.IRModule:
    """Replace ``relax.clip`` calls in ``mod["main"]`` with ``minimum(maximum(x, lo), hi)``.

    For each matched ``relax.clip(x, lo, hi)`` the bounds are turned into
    constants of ``x``'s dtype and the call is rewritten as
    ``R.minimum(R.maximum(x, lo_const), hi_const)``.

    Parameters
    ----------
    mod : tvm.IRModule
        Module whose ``"main"`` function is rewritten.

    Returns
    -------
    tvm.IRModule
        The same ``mod`` object, with ``mod["main"]`` reassigned to the
        rewritten function.

    Notes
    -----
    The bounds are read through ``.value.value`` on the matched operands, so
    they presumably must be scalar constant values -- confirm against callers.
    """
    from typing import Dict

    from tvm.relax import Expr
    from tvm.relax.dpl import DFPattern, rewrite_call
    from tvm.relax.dpl.pattern import is_op, wildcard

    data_pat = wildcard()
    lower_pat = wildcard()
    upper_pat = wildcard()
    clip_pat = is_op("relax.clip")(data_pat, lower_pat, upper_pat)

    def _bound_const(bound, dtype):
        # Materialise a matched clip bound as a constant of the input's dtype.
        return R.const(np.array(bound.value.value).astype(dtype), dtype)

    def _rewriter(
        expr: Expr, matches: Dict[DFPattern, Expr]
    ) -> Expr:  # pylint: disable=unused-argument
        dtype = matches[data_pat].struct_info.dtype
        # Clamp from below first, then from above.
        clamped_below = R.maximum(
            matches[data_pat],
            _bound_const(matches[lower_pat], dtype),
        )
        return R.minimum(
            clamped_below,
            _bound_const(matches[upper_pat], dtype),
        )

    mod["main"] = rewrite_call(clip_pat, _rewriter, mod["main"])
    return mod
def _build(mod, enable_nnapi):
    """Compile ``mod`` for ``aarch64-linux-android``, optionally offloading to NNAPI.

    Parameters
    ----------
    mod
        Either an IRModule or a single ``tvm.relax.expr.Call``; a bare call is
        first wrapped with ``tvm.IRModule.from_expr``.
    enable_nnapi
        When truthy, the module is run through constant folding, the
        ``reshape_matmul`` and ``decompose_clip`` rewrites, NNAPI partitioning
        and ``RunCodegen`` (in that order) before compilation.

    Returns
    -------
    The result of ``tvm.compile`` for the LLVM aarch64 Android target.
    """
    if isinstance(mod, tvm.relax.expr.Call):
        mod = tvm.IRModule.from_expr(mod)

    if enable_nnapi:
        # Ordered NNAPI preparation pipeline; each stage maps module -> module.
        # The lambdas defer pass construction until the stage actually runs.
        nnapi_stages = (
            lambda m: tvm.relax.transform.FoldConstant()(m),
            reshape_matmul,
            decompose_clip,
            partition_for_nnapi,
            lambda m: tvm.relax.transform.RunCodegen()(m),
        )
        for stage in nnapi_stages:
            mod = stage(mod)

    return tvm.compile(mod, target="llvm -mtriple=aarch64-linux-android")
def _run(remote, tracker, ex, inputs):
    """Export ``ex``, run its ``"main"`` function on the remote device, return the output.

    The executable is exported to a temporary directory as ``test_mod.so``
    (built with ``ndk.create_shared``), uploaded through ``remote``, loaded on
    ``remote.cpu(0)`` and executed with the Relax virtual machine.

    Parameters
    ----------
    remote
        RPC session used to upload and load the library and to obtain the device.
    tracker
        Tracker connection; it is closed by this function once execution ends.
    ex
        Compiled executable providing ``export_library``.
    inputs
        Iterable of arrays; each is copied to the remote device before the call.

    Returns
    -------
    The output of ``"main"`` converted with ``.numpy()``.

    Raises
    ------
    Exception
        Any error raised during remote execution propagates unchanged, after
        the RPC connection and the tracker have been closed.
    """
    tmp = utils.tempdir()
    so_name = "test_mod.so"
    so_path = tmp / so_name
    ex.export_library(str(so_path), fcompile=ndk.create_shared, options=["-shared", "-fPIC", "-lm"])
    remote.upload(so_path)
    dev = remote.cpu(0)

    try:
        # Execute the model on the remote.
        remote_ex = remote.load_module(so_name)
        vm = tvm.relax.VirtualMachine(remote_ex, device=dev)
        inputs = [x.copyto(dev) for x in inputs]
        vm.set_input("main", *inputs)
        vm.invoke_stateful("main")
        output = vm.get_outputs("main").numpy()
    finally:
        # Manually close the connection.
        # See https://discuss.tvm.apache.org/t/trouble-with-rpc-session/14008/.
        #
        # TODO: Remove if it does not happen on Python 3.11.
        try:
            remote._sess.get_function("CloseRPCConnection")()
        finally:
            # Close the tracker even when closing the RPC connection fails,
            # so a failing session shutdown cannot leak the tracker.
            tracker.close()
    return output
def build_and_run(
    remote,
    tracker,
    mod,
    inputs,
    enable_nnapi=False,
):
    """Build ``mod`` with ``_build`` and execute it remotely with ``_run``.

    Parameters
    ----------
    remote, tracker
        Forwarded unchanged to ``_run``.
    mod
        Module (or call expression) handed to ``_build``.
    inputs
        Input arrays forwarded to ``_run``.
    enable_nnapi
        Forwarded to ``_build``; defaults to ``False``.

    Returns
    -------
    Whatever ``_run`` returns for the built executable.
    """
    executable = _build(mod, enable_nnapi)
    result = _run(remote, tracker, executable, inputs)
    return result