# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from itertools import zip_longest, combinations
import json
import os
import warnings
import numpy as np
import tvm
import tvm.testing  # provides tvm.testing.assert_allclose used in verify()
from tvm import relay
from tvm import rpc
from tvm.contrib import graph_executor
from tvm.relay.op.contrib import arm_compute_lib
from tvm.contrib import utils
from tvm.autotvm.measure import request_remote
class Device:
"""
Configuration for Arm Compute Library tests.
    Check tests/python/contrib/test_arm_compute_lib/ for the presence of a test_config.json file.
    This file can be used to override the default configuration defined here, which attempts to
    run the Arm Compute Library runtime tests locally when the runtime is available. Changing the
    configuration allows these runtime tests to be offloaded to a remote Arm device, for example
    via an RPC tracker.
Notes
-----
The test configuration will be loaded once when the class is created. If the configuration
changes between tests, any changes will not be picked up.
Parameters
----------
device : RPCSession
        Allows tests to connect to and use a remote device.
Attributes
----------
connection_type : str
Details the type of RPC connection to use. Options:
local - Use the local device,
tracker - Connect to a tracker to request a remote device,
remote - Connect to a remote device directly.
host : str
Specify IP address or hostname of remote target.
port : int
Specify port number of remote target.
target : str
The compilation target.
device_key : str
The device key of the remote target. Use when connecting to a remote device via a tracker.
cross_compile : str
        Specify the path to the cross compiler to use when connecting to a remote device from a
        non-Arm platform.
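    Examples
    --------
    A minimal sketch of a ``test_config.json`` targeting a remote device through an RPC tracker
    (the values below are illustrative assumptions, not defaults)::

        {
            "connection_type": "tracker",
            "host": "127.0.0.1",
            "port": 9090,
            "target": "llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
            "device_key": "my-arm-device",
            "cross_compile": "aarch64-linux-gnu-g++"
        }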
"""
connection_type = "local"
host = "127.0.0.1"
port = 9090
target = "llvm -mtriple=aarch64-linux-gnu -mattr=+neon"
device_key = ""
cross_compile = ""
def __init__(self):
"""Keep remote device for lifetime of object."""
self.device = self._get_remote()
@classmethod
def _get_remote(cls):
"""Get a remote (or local) device to use for testing."""
if cls.connection_type == "tracker":
device = request_remote(cls.device_key, cls.host, cls.port, timeout=1000)
elif cls.connection_type == "remote":
device = rpc.connect(cls.host, cls.port)
elif cls.connection_type == "local":
device = rpc.LocalSession()
else:
raise ValueError(
"connection_type in test_config.json should be one of: " "local, tracker, remote."
)
return device
@classmethod
def load(cls, file_name):
"""Load test config
Load the test configuration by looking for file_name relative
to the test_arm_compute_lib directory.
"""
location = os.path.realpath(os.path.join(os.getcwd(), os.path.dirname(__file__)))
config_file = os.path.join(location, file_name)
if not os.path.exists(config_file):
warnings.warn(
"Config file doesn't exist, resuming Arm Compute Library tests with default config."
)
return
with open(config_file, mode="r") as config:
test_config = json.load(config)
cls.connection_type = test_config["connection_type"]
cls.host = test_config["host"]
cls.port = test_config["port"]
cls.target = test_config["target"]
cls.device_key = test_config.get("device_key") or ""
cls.cross_compile = test_config.get("cross_compile") or ""
def get_cpu_op_count(mod):
"""Traverse graph counting ops offloaded to TVM."""
class Counter(tvm.relay.ExprVisitor):
def __init__(self):
super().__init__()
self.count = 0
def visit_call(self, call):
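            # Calls to partitioned Arm Compute Library functions target a GlobalVar rather
            # than a tvm.ir.Op, so only operators left on the TVM side are counted here.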
if isinstance(call.op, tvm.ir.Op):
self.count += 1
super().visit_call(call)
c = Counter()
c.visit(mod["main"])
return c.count
def skip_runtime_test():
"""Skip test if it requires the runtime and it's not present."""
# ACL codegen not present.
if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
print("Skip because Arm Compute Library codegen is not available.")
return True
# Remote device is in use or ACL runtime not present
# Note: Ensure that the device config has been loaded before this check
    if (
        Device.connection_type == "local"
        and not arm_compute_lib.is_arm_compute_runtime_enabled()
    ):
        print("Skip because the ACL runtime isn't present and no remote device is in use.")
        return True
def skip_codegen_test():
"""Skip test if it requires the ACL codegen and it's not present."""
if not tvm.get_global_func("relay.ext.arm_compute_lib", True):
print("Skip because Arm Compute Library codegen is not available.")
return True
def build_module(
mod,
target,
params=None,
enable_acl=True,
tvm_ops=0,
acl_partitions=1,
disabled_ops=["concatenate"],
):
"""Build module with option to build for ACL."""
if isinstance(mod, tvm.relay.expr.Call):
mod = tvm.IRModule.from_expr(mod)
with tvm.transform.PassContext(opt_level=3, disabled_pass=["AlterOpLayout"]):
if enable_acl:
mod = arm_compute_lib.partition_for_arm_compute_lib(
mod, params, disabled_ops=disabled_ops
)
tvm_op_count = get_cpu_op_count(mod)
assert tvm_op_count == tvm_ops, "Got {} TVM operators, expected {}".format(
tvm_op_count, tvm_ops
)
partition_count = 0
for global_var in mod.get_global_vars():
if "arm_compute_lib" in global_var.name_hint:
partition_count += 1
assert (
acl_partitions == partition_count
), "Got {} Arm Compute Library partitions, expected {}".format(
partition_count, acl_partitions
)
relay.backend.te_compiler.get().clear()
return relay.build(mod, target=target, params=params)
def build_and_run(
mod,
inputs,
outputs,
params,
device,
enable_acl=True,
no_runs=1,
tvm_ops=0,
acl_partitions=1,
config=None,
disabled_ops=["concatenate"],
):
"""Build and run the relay module."""
if config is None:
config = {}
try:
lib = build_module(
mod, device.target, params, enable_acl, tvm_ops, acl_partitions, disabled_ops
)
except Exception as e:
err_msg = "The module could not be built.\n"
if config:
err_msg += f"The test failed with the following parameters: {config}\n"
err_msg += str(e)
raise Exception(err_msg)
lib = update_lib(lib, device.device, device.cross_compile)
gen_module = graph_executor.GraphModule(lib["default"](device.device.cpu(0)))
gen_module.set_input(**inputs)
out = []
for _ in range(no_runs):
gen_module.run()
out.append([gen_module.get_output(i) for i in range(outputs)])
return out
def update_lib(lib, device, cross_compile):
"""Export the library to the remote/local device."""
lib_name = "mod.so"
temp = utils.tempdir()
lib_path = temp.relpath(lib_name)
if cross_compile:
lib.export_library(lib_path, cc=cross_compile)
else:
lib.export_library(lib_path)
device.upload(lib_path)
lib = device.load_module(lib_name)
return lib
def verify(answers, atol, rtol, verify_saturation=False, config=None):
"""Compare the array of answers. Each entry is a list of outputs."""
if config is None:
config = {}
if len(answers) < 2:
raise RuntimeError(f"No results to compare: expected at least two, found {len(answers)}")
for answer in zip_longest(*answers):
for outs in combinations(answer, 2):
try:
if verify_saturation:
assert (
np.count_nonzero(outs[0].numpy() == 255) < 0.25 * outs[0].numpy().size
), "Output is saturated: {}".format(outs[0])
assert (
np.count_nonzero(outs[0].numpy() == 0) < 0.25 * outs[0].numpy().size
), "Output is saturated: {}".format(outs[0])
tvm.testing.assert_allclose(outs[0].numpy(), outs[1].numpy(), rtol=rtol, atol=atol)
except AssertionError as e:
err_msg = "Results not within the acceptable tolerance.\n"
if config:
err_msg += f"The test failed with the following parameters: {config}\n"
err_msg += str(e)
raise AssertionError(err_msg)
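# A minimal usage sketch (illustrative only, not an actual test) showing how the helpers above
# are typically combined: build the same Relay expression with and without ACL offloading, run
# both, and check that the results agree. The shapes, tolerances, and the choice of reshape as
# the offloaded operator are assumptions made for the sake of example.
#
#     device = Device()
#     if not skip_runtime_test():
#         a = relay.var("a", shape=(1, 1, 1000), dtype="float32")
#         func = relay.reshape(a, newshape=(1, 1000))
#         inputs = {"a": tvm.nd.array(np.random.uniform(size=(1, 1, 1000)).astype("float32"))}
#         outputs = [
#             build_and_run(func, inputs, 1, None, device, enable_acl=acl)[0]
#             for acl in (True, False)
#         ]
#         verify(outputs, atol=1e-7, rtol=1e-7)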
def extract_acl_modules(module):
"""Get the ACL module(s) from llvm module."""
return list(
filter(lambda mod: mod.type_key == "arm_compute_lib", module.get_lib().imported_modules)
)
def verify_codegen(
module,
known_good_codegen,
num_acl_modules=1,
tvm_ops=0,
target="llvm -mtriple=aarch64-linux-gnu -mattr=+neon",
disabled_ops=["concatenate"],
):
"""Check acl codegen against a known good output."""
module = build_module(
module,
target,
tvm_ops=tvm_ops,
acl_partitions=num_acl_modules,
disabled_ops=disabled_ops,
)
acl_modules = extract_acl_modules(module)
assert len(acl_modules) == num_acl_modules, (
f"The number of Arm Compute Library modules produced ({len(acl_modules)}) does not "
f"match the expected value ({num_acl_modules})."
)
for mod in acl_modules:
source = mod.get_source("json")
codegen = json.loads(source)["nodes"]
# remove input and const names as these cannot be predetermined
        for node in codegen:
            if node["op"] in ("input", "const"):
                node["name"] = ""
codegen_str = json.dumps(codegen, sort_keys=True, indent=2)
known_good_codegen_str = json.dumps(known_good_codegen, sort_keys=True, indent=2)
assert codegen_str == known_good_codegen_str, (
f"The JSON produced by codegen does not match the expected result. \n"
f"Actual={codegen_str} \n"
f"Expected={known_good_codegen_str}"
)