# blob: 41271b1ba31208218b5f60ad40a0e92867b13eb3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import time
import ctypes
import tvm
from tvm import te
from tvm.contrib.utils import tempdir
from tvm.runtime.module import BenchmarkResult
def test_min_repeat_ms():
    """Check that ``min_repeat_ms`` makes ``time_evaluator`` invoke the function more often.

    A packed function that sleeps for 100 ms and appends one character to a
    log file on every call is timed twice: once with ``number=1, repeat=1``
    only, and once with ``min_repeat_ms=1000`` added. The length of the log's
    first line serves as the running call count.
    """
    workdir = tempdir()
    log_path = workdir.relpath("log")

    def logged_calls():
        # Characters are appended without newlines, so the length of the
        # first line is the number of calls recorded so far.
        with open(log_path, "r") as log:
            return len(log.readline())

    @tvm.register_global_func
    def my_debug(path):
        """one call lasts for 100 ms and writes one character to a file"""
        time.sleep(0.1)
        with open(path, "a") as log:
            log.write("c")

    # Scalar compute whose only work is calling the packed function above.
    probe = te.compute((), lambda: tvm.tir.call_packed("my_debug", log_path))
    built = tvm.tir.build(te.create_prim_func([probe]))
    out = tvm.runtime.empty((), dtype="int32")

    quick_timer = built.time_evaluator(built.entry_name, tvm.cpu(), number=1, repeat=1)
    quick_timer(out)
    # The test expects exactly two logged calls here (presumably one warm-up
    # run plus the single measured run -- confirm against time_evaluator docs).
    assert logged_calls() == 2

    long_timer = built.time_evaluator(
        built.entry_name, tvm.cpu(), number=1, repeat=1, min_repeat_ms=1000
    )
    long_timer(out)
    # make sure we get more than 10 calls on top of the 2 already logged
    assert logged_calls() > 10 + 2
def test_benchmark_result():
    """Check the summary statistics exposed by ``BenchmarkResult``.

    For the samples ``[1, 2, 2, 5]`` the expected values are listed below;
    note that 1.5 is the population (not sample) standard deviation of
    these four numbers.
    """
    result = BenchmarkResult([1, 2, 2, 5])
    expected_stats = {
        "mean": 2.5,
        "median": 2.0,
        "min": 1,
        "max": 5,
        "std": 1.5,
    }
    # Attributes are checked in the listed order, one assertion per statistic.
    for attr_name, expected in expected_stats.items():
        assert getattr(result, attr_name) == expected
if __name__ == "__main__":
    # Run every test in this file, in order, when executed as a script.
    for _run_test in (test_min_repeat_ms, test_benchmark_result):
        _run_test()