blob: c0f7b136e11bbe7015962cade80700712c69e4f6 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import tvm
import ctypes
import os.path as osp
from sys import platform
def get_ext():
"""Return shared library extension"""
return ".dylib" if platform == "darwin" else ".so"
def load_dll(dll):
"""Load shared library
Parameters
------------
dll : str
Path for shared library
Returns
------------
The shared library
"""
try:
return [ctypes.CDLL(dll, ctypes.RTLD_GLOBAL)]
except OSError:
return []
def load_sw():
"""Load all software shared libraries"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
sw_libname = "libsw" + get_ext()
sw_lib = osp.join(cur_path, "..", "build", sw_libname)
load_dll(sw_lib)
def init(hw_backend):
"""Init hardware and software shared library for accelerator
Parameters
------------
hw_backend : str
Hardware backend can be verilog or chisel
"""
cur_path = osp.dirname(osp.abspath(osp.expanduser(__file__)))
hw_libname = "libhw" + get_ext()
if hw_backend in ("verilog", "chisel"):
hw_lib = osp.join(cur_path, "..", "hardware", hw_backend, "build", hw_libname)
load_sw()
m = tvm.runtime.load_module(hw_lib, "vta-tsim")
f = tvm.get_global_func("tvm.vta.tsim.init")
f(m)
def load_module():
"""Return driver function"""
load_sw()
return tvm.get_global_func("tvm.vta.driver")