blob: ab1a1d0cda28695a00ccf5c26a263688c60f0499 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Support a Relay partitioning target using Tensor Expressions."""
from typing import Callable, List, Dict
import tvm
import tvm.ir
from tvm import relay
from tvm import te
# Signature of a lowering callback: it receives the composite function's inner
# Relay call together with one TE placeholder per composite parameter, and
# returns the TE tensor that computes the result.
_LowerFunc = Callable[[relay.Call, List[te.Tensor]], te.Tensor]

# Registry from composite function name to its lowering callback. Entries are
# added by the ``lower_composite`` decorator and looked up when a partition is
# compiled.
_LOWER_MAP: Dict[str, _LowerFunc] = dict()
def lower_composite(comp_name: str) -> Callable[[_LowerFunc], _LowerFunc]:
    """Register a lowering function for a given composite function name.

    The decorated callable is stored in ``_LOWER_MAP`` under ``comp_name``,
    replacing any earlier entry with the same name. It is returned unchanged,
    so it remains usable as an ordinary function.

    Parameters
    ----------
    comp_name : str
        Value of a composite function's ``Composite`` attribute that this
        lowering handles.

    Returns
    -------
    Callable[[_LowerFunc], _LowerFunc]
        A decorator that performs the registration.
    """

    def _decorator(lower_fn: _LowerFunc) -> _LowerFunc:
        _LOWER_MAP.update({comp_name: lower_fn})
        return lower_fn

    return _decorator
def relay_to_runtime(target: tvm.target.Target) -> Callable[[relay.Function], tvm.runtime.Module]:
    """Create a Relay to runtime module lowering function using Tensor Expressions for lowering.

    Parameters
    ----------
    target : tvm.target.Target
        Target passed to ``tvm.build`` when compiling each partition.

    Returns
    -------
    Callable[[relay.Function], tvm.runtime.Module]
        A function that compiles one partitioned Relay function into a
        runtime module.
    """

    def _relay_to_runtime(partition: relay.Function) -> tvm.runtime.Module:
        """Compile Relay functions to a runtime module using Tensor Expressions.

        ``partition`` must have a specific structure:

        * its body is a single call to an inner composite function;
        * that inner function carries a ``Composite`` attribute naming a
          lowering registered through ``lower_composite``;
        * the inner function's own body is a single call.

        The module is built under the partition's ``global_symbol``.

        Raises
        ------
        AssertionError
            If ``partition`` does not have the structure above, or if no
            lowering is registered for its composite name.
        """
        # Message expressions are evaluated only when an assertion fails, so
        # the detailed diagnostics cost nothing on the success path.
        assert isinstance(partition, relay.Function), (
            f"expected a relay.Function partition, got {type(partition).__name__}"
        )
        assert isinstance(partition.body, relay.Call), (
            "expected the partition body to be a call to a composite function, "
            f"got {type(partition.body).__name__}"
        )
        assert isinstance(partition.body.op, relay.Function), (
            "expected the partition body to call an inner composite function, "
            f"got a call to {type(partition.body.op).__name__}"
        )
        global_name = str(partition.attrs.global_symbol)
        comp_func = partition.body.op
        comp_name = comp_func.attrs["Composite"]
        assert comp_name in _LOWER_MAP, (
            f"no TE lowering registered for composite '{comp_name}'; "
            f"registered composites: {sorted(_LOWER_MAP)}"
        )
        assert isinstance(comp_func.body, relay.Call), (
            f"expected the body of composite '{comp_name}' to be a call, "
            f"got {type(comp_func.body).__name__}"
        )
        op = comp_func.body

        # One TE placeholder per composite parameter, mirroring the
        # parameter's checked shape and dtype.
        inputs = [
            te.placeholder(
                param.checked_type.shape,
                name=f"input_{i}",
                dtype=param.checked_type.dtype,
            )
            for i, param in enumerate(comp_func.params)
        ]
        output = _LOWER_MAP[comp_name](op, inputs)
        # The PrimFunc's arguments are the inputs followed by the output.
        prim_func = te.create_prim_func(inputs + [output])
        return tvm.build(prim_func, target=target, name=global_name)

    return _relay_to_runtime