blob: 0315a8f11b39ceadb8cd265376481405831684e4 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Example extension package of TVM."""
from __future__ import absolute_import
import os
import ctypes
# Import TVM first to get library symbols
import tvm
from tvm import te
def load_lib():
"""Load library, the functions will be registered into TVM"""
curr_path = os.path.dirname(os.path.abspath(os.path.expanduser(__file__)))
# load in as global so the global extern symbol is visible to other dll.
lib = ctypes.CDLL(os.path.join(curr_path, "../../lib/libtvm_ext.so"), ctypes.RTLD_GLOBAL)
return lib
_LIB = load_lib()
# Expose two functions into python
bind_add = tvm.get_global_func("tvm_ext.bind_add")
sym_add = tvm.get_global_func("tvm_ext.sym_add")
ivec_create = tvm.get_global_func("tvm_ext.ivec_create")
ivec_get = tvm.get_global_func("tvm_ext.ivec_get")
@tvm.register_object("tvm_ext.IntVector")
class IntVec(tvm.Object):
"""Example for using extension class in c++ """
@property
def _tvm_handle(self):
return self.handle.value
def __getitem__(self, idx):
return ivec_get(self, idx)
nd_create = tvm.get_global_func("tvm_ext.nd_create")
nd_add_two = tvm.get_global_func("tvm_ext.nd_add_two")
nd_get_additional_info = tvm.get_global_func("tvm_ext.nd_get_additional_info")
@tvm.register_object("tvm_ext.NDSubClass")
class NDSubClass(tvm.nd.NDArrayBase):
"""Example for subclassing TVM's NDArray infrastructure.
By inheriting TVM's NDArray, external libraries could
leverage TVM's FFI without any modification.
"""
@staticmethod
def create(additional_info):
return nd_create(additional_info)
@property
def additional_info(self):
return nd_get_additional_info(self)
def __add__(self, other):
return nd_add_two(self, other)