blob: 34cb79f67ac7d002e98da677f33094a643ff8a92 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Add hexagon thread pool test"""
import numpy as np
import tvm
import tvm.contrib.hexagon
import tvm.script
import tvm.testing
from tvm.contrib.hexagon.session import Session
from tvm.script import tir as T
from .infrastructure import get_hexagon_target
@tvm.script.ir_module
class ElemwiseSumIRModule:
"""IRModule definition for elementwise sum"""
# pylint: disable=no-self-argument,invalid-name,missing-function-docstring
@T.prim_func
def elemwise_sum_serial(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
T.func_attr({"global_symbol": "elemwise_sum_serial", "tir.noalias": True})
A = T.match_buffer(a, (n,), dtype="float32")
B = T.match_buffer(b, (n,), dtype="float32")
C = T.match_buffer(c, (n,), dtype="float32")
for i in T.serial(n):
with T.sblock("C"):
vi = T.axis.spatial(n, i)
C[vi] = A[vi] + B[vi]
@T.prim_func
def elemwise_sum_parallel(a: T.handle, b: T.handle, c: T.handle, n: T.int32):
T.func_attr({"global_symbol": "elemwise_sum_parallel", "tir.noalias": True})
A = T.match_buffer(a, (n,), dtype="float32")
B = T.match_buffer(b, (n,), dtype="float32")
C = T.match_buffer(c, (n,), dtype="float32")
for i in T.parallel(n):
with T.sblock("C"):
vi = T.axis.spatial(n, i)
C[vi] = A[vi] + B[vi]
# pylint: enable=no-self-argument,invalid-name,missing-function-docstring
def generate_add_test_data(hexagon_session: Session, n=128 * 1024):
a = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
b = tvm.runtime.tensor(np.random.uniform(size=n).astype("float32"), hexagon_session.device)
c = tvm.runtime.tensor(np.zeros(n, dtype="float32"), hexagon_session.device)
return (a, b, c, n)
def benchmark_func(mod, name, args, hexagon_session):
(a, b, c, n) = args
evaluator = mod.time_evaluator(name, hexagon_session.device, number=100)
return evaluator(a, b, c, n).mean
@tvm.testing.requires_hexagon
def test_speedup(hexagon_session: Session, capsys):
"""Test speedup"""
func = tvm.compile(
ElemwiseSumIRModule,
target=get_hexagon_target("v68"),
)
mod = hexagon_session.load_module(func)
args = generate_add_test_data(hexagon_session)
parallel_mean = benchmark_func(mod, "elemwise_sum_parallel", args, hexagon_session)
serial_mean = benchmark_func(mod, "elemwise_sum_serial", args, hexagon_session)
with capsys.disabled():
print("... speedup of {:.2f}".format(serial_mean / parallel_mean), end=" ")
@tvm.testing.requires_hexagon
def test_elemwise_sum_parallel(hexagon_session: Session):
"""Test parallel elementwise sum"""
func = tvm.compile(
ElemwiseSumIRModule,
target=get_hexagon_target("v68"),
)
mod = hexagon_session.load_module(func)
(a, b, c, n) = generate_add_test_data(hexagon_session)
mod["elemwise_sum_parallel"](a, b, c, n)
tvm.testing.assert_allclose(c.numpy(), a.numpy() + b.numpy())
if __name__ == "__main__":
tvm.testing.main()