# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Definition of generic operator strategy."""
# pylint: disable=invalid-name,unused-argument
import logging
import re
from tvm import _ffi, ir, te, topi
from tvm.target import generic_func, override_native_generic_func
from tvm.topi.utils import get_const_float, get_const_int, get_const_tuple, get_float_tuple
from .. import op as _op
logger = logging.getLogger("strategy")
def naive_schedule(_, outs, target):
"""Return the naive default schedule.
    This function acts as a placeholder for op implementations that use the auto-scheduler.
    Implementations using this function should only be used together with the auto-scheduler.
"""
if "gpu" in target.keys:
# For GPU, we at least need thread binding to make a valid schedule.
# So the naive schedule cannot be compiled.
logger.debug(
"Cannot compile for GPU targets if no tuned schedule is found. "
"Please see the warning messages above for more information about the failed workloads."
)
return te.create_schedule(outs[-1].op)
def wrap_topi_schedule(topi_schedule):
"""Wrap TOPI schedule which doesn't use attrs"""
def wrapper(attrs, outs, target):
with target:
return topi_schedule(outs)
return wrapper
def wrap_topi_compute(topi_compute):
"""Wrap TOPI compute which doesn't use attrs"""
def wrapper(attrs, inputs, out_type):
return [topi_compute(*inputs)]
return wrapper
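# Illustrative sketch (hypothetical op, not part of this file): the two wrappers
# above adapt plain TOPI functions to the fixed relay strategy signatures, e.g.
#   strategy = _op.OpStrategy()
#   strategy.add_implementation(
#       wrap_topi_compute(topi.nn.relu),
#       wrap_topi_schedule(topi.generic.schedule_injective),
#       name="relu.generic",
#   )
# so the compile engine can call the compute as f(attrs, inputs, out_type)
# and the schedule as f(attrs, outs, target).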
def get_conv2d_in_channels(data_shape, data_layout):
"""Get conv2d input channels"""
data_shape = get_const_tuple(data_shape)
if len(data_shape) == 4:
idx = data_layout.find("C")
assert idx >= 0, "Invalid conv2d data layout {}".format(data_layout)
return data_shape[idx]
if re.match(r"NCHW\d*c", data_layout):
# NCHW[8]c
return data_shape[1] * data_shape[4]
raise ValueError("Unknown conv2d data layout {}".format(data_layout))
def get_conv2d_out_channels(kernel_shape, kernel_layout):
"""Get conv2d output channels"""
kernel_shape = get_const_tuple(kernel_shape)
if len(kernel_shape) == 4:
idx = kernel_layout.find("O")
assert idx >= 0, "Invalid conv2d kernel layout {}".format(kernel_layout)
return kernel_shape[idx]
if re.match(r"OIHW\d*i\d*o", kernel_layout):
return kernel_shape[0] * kernel_shape[5]
if re.match(r"OIHW\d*o", kernel_layout):
return kernel_shape[0] * kernel_shape[4]
raise ValueError("Unknown conv2d kernel layout {}".format(kernel_layout))
def is_depthwise_conv2d(data_shape, data_layout, kernel_shape, kernel_layout, groups):
    """Check if the conv2d workload is depthwise (in_channels == out_channels == groups)."""
ic = get_conv2d_in_channels(data_shape, data_layout)
oc = get_conv2d_out_channels(kernel_shape, kernel_layout)
return ic == oc == groups
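# Worked example (illustrative values): a 32-channel depthwise workload in
# NCHW/OIHW layout,
#   is_depthwise_conv2d((1, 32, 56, 56), "NCHW", (32, 1, 3, 3), "OIHW", groups=32)
# returns True because in_channels == out_channels == groups == 32. For packed
# layouts such as NCHW8c, get_conv2d_in_channels multiplies the outer and inner
# channel dims, e.g. a (1, 4, 56, 56, 8) data shape gives 4 * 8 = 32 channels.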
@generic_func
def schedule_injective(attrs, outs, target):
"""Schedule injective ops"""
with target:
return topi.generic.schedule_injective(outs)
@generic_func
def schedule_reduce(attrs, outs, target):
"""Schedule reduction ops"""
with target:
return topi.generic.schedule_reduce(outs)
_op._schedule_injective = schedule_injective
_op._schedule_reduce = schedule_reduce
# concatenate
@generic_func
def schedule_concatenate(attrs, outs, target):
"""Schedule concatenate op"""
with target:
return topi.generic.schedule_injective(outs)
# pool
@generic_func
def schedule_pool(attrs, outs, target):
"""Schedule pooling ops"""
with target:
return topi.generic.schedule_pool(outs, attrs.layout)
# pool_grad
@generic_func
def schedule_pool_grad(attrs, outs, target):
"""Schedule pooling gradient ops"""
with target:
return topi.generic.schedule_pool_grad(outs)
# adaptive pool
@generic_func
def schedule_adaptive_pool(attrs, outs, target):
"""Schedule adaptive pooling ops"""
with target:
return topi.generic.schedule_adaptive_pool(outs)
# softmax
def wrap_compute_softmax(topi_compute):
"""Wrap softmax topi compute"""
def _compute_softmax(attrs, inputs, out_type):
axis = attrs.get_int("axis")
return [topi_compute(inputs[0], axis)]
return _compute_softmax
@override_native_generic_func("softmax_strategy")
def softmax_strategy(attrs, inputs, out_type, target):
"""softmax generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.softmax),
wrap_topi_schedule(topi.generic.schedule_softmax),
name="softmax.generic",
)
return strategy
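# Backends override this generic entry point by registering a target-specific
# strategy on the generic function; a minimal sketch of how the per-target
# strategy files do this:
#   @softmax_strategy.register(["cuda", "gpu"])
#   def softmax_strategy_cuda(attrs, inputs, out_type, target):
#       strategy = _op.OpStrategy()
#       strategy.add_implementation(
#           wrap_compute_softmax(topi.nn.softmax),
#           wrap_topi_schedule(topi.cuda.schedule_softmax),
#           name="softmax.cuda",
#       )
#       return strategy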
@override_native_generic_func("fast_softmax_strategy")
def fast_softmax_strategy(attrs, inputs, out_type, target):
"""fast softmax generic strategy"""
# NOTE: This op does not have an optimized manual schedule,
# so it should only be used together with auto-scheduler.
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_softmax(topi.nn.fast_softmax),
naive_schedule,
name="fast_softmax.generic",
)
return strategy
# log_softmax
@generic_func
def schedule_log_softmax(attrs, outs, target):
"""Schedule log_softmax op"""
with target:
return topi.generic.schedule_softmax(outs)
# lrn
@generic_func
def schedule_lrn(attrs, outs, target):
"""Schedule LRN op"""
with target:
return topi.generic.schedule_lrn(outs)
# bitpack
@generic_func
def schedule_bitpack(attrs, outs, target):
"""Schedule bitpack"""
with target:
return topi.generic.schedule_bitpack(outs)
get_auto_scheduler_rewritten_layout = _ffi.get_global_func(
"relay.attrs.get_auto_scheduler_rewritten_layout"
)
# conv2d
def wrap_compute_conv2d(
topi_compute,
need_data_layout=False,
need_out_layout=False,
has_groups=False,
need_auto_scheduler_layout=False,
):
"""Wrap conv2d topi compute"""
def _compute_conv2d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
data_layout = attrs.get_str("data_layout")
out_layout = attrs.get_str("out_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilation]
if has_groups:
args.append(attrs.groups)
if need_data_layout:
args.append(data_layout)
if need_out_layout:
args.append(out_layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]
return _compute_conv2d
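# The flags above control which positional arguments are handed to the TOPI
# compute. A sketch of the resulting call for the NCHWc variant used below in
# conv2d_NCHWc_strategy:
#   wrap_compute_conv2d(topi.nn.conv2d_NCHWc, need_data_layout=True, need_out_layout=True)
# invokes
#   topi.nn.conv2d_NCHWc(data, kernel, strides, padding, dilation,
#                        data_layout, out_layout, out_dtype)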
@override_native_generic_func("conv2d_strategy")
def conv2d_strategy(attrs, inputs, out_type, target):
"""conv2d generic strategy"""
logger.warning("conv2d is not optimized for this platform.")
strategy = _op.OpStrategy()
data, kernel = inputs
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
(dilation_h, dilation_w) = dilation
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if groups == 1:
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_nchw),
name="conv2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_conv2d_nhwc),
name="conv2d_nhwc.generic",
)
elif layout == "HWCN":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_hwcn),
wrap_topi_schedule(topi.generic.schedule_conv2d_hwcn),
name="conv2d_hwcn.generic",
)
else:
raise RuntimeError("Unsupported conv2d layout {}".format(layout))
elif is_depthwise_conv2d(data.shape, layout, kernel.shape, kernel_layout, groups):
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nchw),
name="depthwise_conv2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWOI"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_nhwc),
name="depthwise_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Unsupported depthwise_conv2d layout {}".format(layout))
else: # group_conv2d
if layout == "NCHW":
assert kernel_layout == "OIHW"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nchw, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nchw),
name="group_conv2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWIO"
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.group_conv2d_nhwc, has_groups=True),
wrap_topi_schedule(topi.generic.schedule_group_conv2d_nhwc),
name="group_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Unsupported group_conv2d layout {}".format(layout))
return strategy
# conv2d_NCHWc
@override_native_generic_func("conv2d_NCHWc_strategy")
def conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
"""conv2d_NCHWc generic strategy"""
logger.warning("conv2d_NCHWc is not optimized for this platform.")
strategy = _op.OpStrategy()
if inputs[0].dtype == "int8" or inputs[0].dtype == "uint8":
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc_int8, True, True),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc_int8),
name="conv2d_NCHWc_int8.generic",
)
else:
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.generic.schedule_conv2d_NCHWc),
name="conv2d_NCHWc.generic",
)
return strategy
# depthwise_conv2d_NCHWc
@override_native_generic_func("depthwise_conv2d_NCHWc_strategy")
def depthwise_conv2d_NCHWc_strategy(attrs, inputs, out_type, target):
"""depthwise_conv2d generic strategy"""
logger.warning("depthwise_conv2d_NCHWc is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d(topi.nn.depthwise_conv2d_NCHWc, True, True),
wrap_topi_schedule(topi.generic.schedule_depthwise_conv2d_NCHWc),
name="depthwise_conv2d_NCHWc.generic",
)
return strategy
# conv2d_winograd_without_weight_transform
@override_native_generic_func("conv2d_winograd_without_weight_transform_strategy")
def conv2d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
"""conv2d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_winograd_without_weight_transform")
# conv2d_gemm_without_weight_transform
@override_native_generic_func("conv2d_gemm_without_weight_transform_strategy")
def conv2d_gemm_without_weight_transform_strategy(attrs, inputs, out_type, target):
"""conv2d_gemm_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv2d_gemm_without_weight_transform")
# conv2d_winograd_weight_transform
@generic_func
def schedule_conv2d_winograd_weight_transform(attrs, outs, target):
"""Schedule conv2d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_weight_transform(outs)
# conv2d_winograd_nnpack_weight_transform
@generic_func
def schedule_conv2d_winograd_nnpack_weight_transform(attrs, outs, target):
"""Schedule conv2d_winograd_nnpack_weight_transform"""
with target:
return topi.generic.schedule_conv2d_winograd_nnpack_weight_transform(outs)
# conv2d_gemm_weight_transform
@generic_func
def schedule_conv2d_gemm_weight_transform(attrs, outs, target):
"""Schedule conv2d_gemm_weight_transform"""
with target:
return topi.generic.schedule_conv2d_gemm_weight_transform(outs)
# deformable_conv2d
def wrap_compute_deformable_conv2d(topi_compute):
"""wrap deformable_conv2d topi compute"""
def _compute_deformable_conv2d(attrs, inputs, out_dtype):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
deformable_groups = attrs.deformable_groups
groups = attrs.groups
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
out = topi_compute(
inputs[0],
inputs[1],
inputs[2],
strides,
padding,
dilation,
deformable_groups,
groups,
out_dtype,
)
return [out]
return _compute_deformable_conv2d
@override_native_generic_func("deformable_conv2d_strategy")
def deformable_conv2d_strategy(attrs, inputs, out_type, target):
"""deformable_conv2d generic strategy"""
layout = attrs.data_layout
strategy = _op.OpStrategy()
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_deformable_conv2d_nchw),
name="deformable_conv2d_nchw.generic",
)
elif layout == "NHWC":
# This implementation should never be picked by autotvm
strategy.add_implementation(
wrap_compute_deformable_conv2d(topi.nn.deformable_conv2d_nhwc),
naive_schedule,
name="deformable_conv2d_nhwc.generic",
)
else:
raise RuntimeError("Layout %s is not supported in deformable conv2d" % layout)
return strategy
# conv2d_transpose
def wrap_compute_conv2d_transpose(topi_compute):
"""wrap conv2d_transpose topi compute"""
def compute_conv2d_transpose(attrs, inputs, out_dtype):
"""Compute definition of conv2d_transpose"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
return [out]
return compute_conv2d_transpose
@override_native_generic_func("conv2d_transpose_strategy")
def conv2d_transpose_strategy(attrs, inputs, out_type, target):
"""conv2d_transpose generic strategy"""
logger.warning("conv2d_transpose is not optimized for this platform.")
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCHW", "only support nchw for now"
assert dilation == (1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv2d_transpose(topi.nn.conv2d_transpose_nchw),
wrap_topi_schedule(topi.generic.schedule_conv2d_transpose_nchw),
name="conv2d_transpose_nchw.generic",
)
return strategy
# conv3d_transpose
def wrap_compute_conv3d_transpose(topi_compute):
"""wrap conv3d_transpose topi compute"""
def compute_conv3d_transpose(attrs, inputs, out_dtype):
"""Compute definition of conv3d_transpose"""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
output_padding = get_const_tuple(attrs.output_padding)
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
return [out]
return compute_conv3d_transpose
@override_native_generic_func("conv3d_transpose_strategy")
def conv3d_transpose_strategy(attrs, inputs, out_type, target):
"""conv3d_transpose generic strategy"""
logger.warning("conv3d_transpose is not optimized for this platform.")
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCDHW", "only support ncdhw for now"
assert dilation == (1, 1, 1), "not support dilate now"
assert groups == 1, "only support groups == 1 for now"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_conv3d_transpose(topi.nn.conv3d_transpose_ncdhw),
wrap_topi_schedule(topi.generic.schedule_conv3d_transpose_ncdhw),
name="conv3d_transpose_ncdhw.generic",
)
return strategy
# conv3d
def wrap_compute_conv3d(topi_compute, need_layout=False, need_auto_scheduler_layout=False):
"""wrap conv3d topi compute"""
def _compute_conv3d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
layout = attrs.data_layout
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
(dilation_d, dilation_h, dilation_w) = dilation
if dilation_d < 1 or dilation_h < 1 or dilation_w < 1:
raise ValueError("Dilation should be positive value")
if groups != 1:
raise ValueError("Not support arbitrary group number for conv3d")
args = [inputs[0], inputs[1], strides, padding, dilation]
if need_layout:
args.append(layout)
args.append(out_dtype)
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]
return _compute_conv3d
@override_native_generic_func("conv3d_strategy")
def conv3d_strategy(attrs, inputs, out_type, target):
"""conv3d generic strategy"""
logger.warning("conv3d is not optimized for this platform.")
strategy = _op.OpStrategy()
layout = attrs.data_layout
if layout == "NCDHW":
strategy.add_implementation(
wrap_compute_conv3d(topi.nn.conv3d_ncdhw),
wrap_topi_schedule(topi.generic.schedule_conv3d_ncdhw),
name="conv3d_ncdhw.generic",
)
elif layout == "NDHWC":
strategy.add_implementation(
wrap_compute_conv3d(topi.nn.conv3d_ndhwc),
wrap_topi_schedule(topi.generic.schedule_conv3d_ndhwc),
name="conv3d_ndhwc.generic",
)
else:
raise ValueError("Not support this layout {} yet".format(layout))
return strategy
# conv3d_winograd_without_weight_transform
@override_native_generic_func("conv3d_winograd_without_weight_transform_strategy")
def conv3d_winograd_without_weight_transfrom_strategy(attrs, inputs, out_type, target):
"""conv3d_winograd_without_weight_transfrom generic strategy"""
raise ValueError("No generic implemenation for conv3d_winograd_without_weight_transform")
# conv3d_winograd_weight_transform
@generic_func
def schedule_conv3d_winograd_weight_transform(attrs, outs, target):
"""Schedule conv3d_winograd_weight_transform"""
with target:
return topi.generic.schedule_conv3d_winograd_weight_transform(outs)
# conv1d
def wrap_compute_conv1d(topi_compute):
"""wrap conv1d topi compute"""
def _compute_conv1d(attrs, inputs, out_type):
"""Compute definition of conv1d"""
strides = get_const_tuple(attrs.strides)
padding = get_const_tuple(attrs.padding)
dilation = get_const_tuple(attrs.dilation)
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
return [topi_compute(inputs[0], inputs[1], strides, padding, dilation, out_dtype)]
return _compute_conv1d
@override_native_generic_func("conv1d_strategy")
def conv1d_strategy(attrs, inputs, out_type, target):
"""conv1d generic strategy"""
logger.warning("conv1d is not optimized for this platform.")
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
if dilation[0] < 1:
raise ValueError("dilation should be a positive value")
strategy = _op.OpStrategy()
if layout == "NCW":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_ncw),
wrap_topi_schedule(topi.generic.schedule_conv1d_ncw),
name="conv1d_ncw.generic",
)
elif layout == "NWC":
strategy.add_implementation(
wrap_compute_conv1d(topi.nn.conv1d_nwc),
wrap_topi_schedule(topi.generic.schedule_conv1d_nwc),
name="conv1d_nwc.generic",
)
else:
raise ValueError("Unsupported conv1d layout {}".format(layout))
return strategy
# conv1d_transpose
def wrap_compute_conv1d_transpose(topi_compute):
"""wrap conv1d_transpose topi compute"""
    def _compute_conv1d_transpose(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
output_padding = get_const_tuple(attrs.output_padding)
out = topi_compute(inputs[0], inputs[1], strides, padding, out_dtype, output_padding)
return [out]
    return _compute_conv1d_transpose
@override_native_generic_func("conv1d_transpose_strategy")
def conv1d_transpose_strategy(attrs, inputs, out_type, target):
"""conv1d_transpose generic strategy"""
logger.warning("conv1d_transpose is not optimized for this platform.")
strategy = _op.OpStrategy()
layout = attrs.data_layout
dilation = get_const_tuple(attrs.dilation)
groups = attrs.groups
assert layout == "NCW", "conv1d_transpose ncw only supported"
assert dilation == (1,), "conv1d_transpose dilation is not supported"
assert groups == 1, "conv1d_transpose groups == 1 only supported"
strategy.add_implementation(
wrap_compute_conv1d_transpose(topi.nn.conv1d_transpose_ncw),
wrap_topi_schedule(topi.generic.schedule_conv1d_transpose_ncw),
name="conv1d_transpose_ncw.generic",
)
return strategy
# dilation2d
def wrap_compute_dilation2d(topi_compute, need_data_layout=False):
"""Wrap dilation2d topi compute"""
def _compute_dilation2d(attrs, inputs, out_type):
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
dilations = get_const_tuple(attrs.dilations)
data_layout = attrs.get_str("data_layout")
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype in ("same", "") else out_dtype
args = [inputs[0], inputs[1], strides, padding, dilations]
if need_data_layout:
args.append(data_layout)
args.append(out_dtype)
return [topi_compute(*args)]
return _compute_dilation2d
@override_native_generic_func("dilation2d_strategy")
def dilation2d_strategy(attrs, inputs, out_type, target):
"""dilation2d_strategy generic strategy"""
logger.warning("dilation2d_strategy is not optimized for this platform.")
strategy = _op.OpStrategy()
dilations = get_const_tuple(attrs.dilations)
layout = attrs.data_layout
kernel_layout = attrs.kernel_layout
assert layout in ["NCHW", "NHWC"]
(dilation_h, dilation_w) = dilations
if dilation_h < 1 or dilation_w < 1:
raise ValueError("dilation should be positive value")
if layout == "NCHW":
assert kernel_layout == "IHW"
strategy.add_implementation(
wrap_compute_dilation2d(topi.image.dilation2d_nchw),
wrap_topi_schedule(topi.generic.schedule_dilation2d_nchw),
name="dilation2d_nchw.generic",
)
elif layout == "NHWC":
assert kernel_layout == "HWI"
strategy.add_implementation(
wrap_compute_dilation2d(topi.image.dilation2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_dilation2d_nhwc),
name="dilation2d_nhwc.generic",
)
else:
raise RuntimeError("Unsupported dilation2d layout {}".format(layout))
return strategy
# dense
def wrap_compute_dense(topi_compute, need_auto_scheduler_layout=False):
"""wrap dense topi compute"""
def _compute_dense(attrs, inputs, out_type):
"""Compute definition of dense"""
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
args = [inputs[0], inputs[1], None, out_dtype]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
return [topi_compute(*args)]
return _compute_dense
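# Note: the None in the argument list above fills the optional bias slot of
# topi.nn.dense, so bias addition is left to a separate op. A sketch of the
# resulting call:
#   topi.nn.dense(data, weight, None, out_dtype)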
@override_native_generic_func("dense_strategy")
def dense_strategy(attrs, inputs, out_type, target):
"""dense generic strategy"""
logger.warning("dense is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense.generic",
)
return strategy
@override_native_generic_func("dense_pack_strategy")
def dense_pack_strategy(attrs, inputs, out_type, target):
"""dense_pack generic strategy"""
logger.warning("dense_pack is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_dense(topi.nn.dense_pack),
wrap_topi_schedule(topi.generic.schedule_dense),
name="dense_pack.generic",
)
return strategy
# batch_matmul
def wrap_compute_batch_matmul(topi_compute, need_auto_scheduler_layout=False, need_out_dtype=False):
"""wrap batch_matmul topi compute"""
def _compute_batch_matmul(attrs, inputs, out_type):
args = [inputs[0], inputs[1], out_type.shape]
if need_auto_scheduler_layout:
args.append(get_auto_scheduler_rewritten_layout(attrs))
if need_out_dtype:
args.append(out_type.dtype)
return [topi_compute(*args)]
return _compute_batch_matmul
@override_native_generic_func("batch_matmul_strategy")
def batch_matmul_strategy(attrs, inputs, out_type, target):
"""batch_matmul generic strategy"""
logger.warning("batch_matmul is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_batch_matmul(topi.nn.batch_matmul),
wrap_topi_schedule(topi.generic.schedule_batch_matmul),
name="batch_matmul.generic",
)
return strategy
# sparse dense
def wrap_compute_sparse_dense(topi_compute):
"""wrap sparse dense topi compute"""
def _compute_sparse_dense(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3], attrs["sparse_lhs"])]
return _compute_sparse_dense
@override_native_generic_func("sparse_dense_strategy")
def sparse_dense_strategy(attrs, inputs, out_type, target):
"""sparse dense generic strategy"""
logger.warning("sparse dense is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_dense(topi.nn.sparse_dense),
wrap_topi_schedule(topi.generic.schedule_sparse_dense),
name="sparse_dense.generic",
)
return strategy
@override_native_generic_func("sparse_dense_padded_strategy")
def sparse_dense_padded_strategy(attrs, inputs, out_type, target):
"""sparse dense padded generic strategy"""
raise NotImplementedError("sparse_dense_padded is only implemented for cuda")
# sparse_add
def wrap_compute_sparse_add(topi_compute):
"""wrap sparse add topi compute"""
def _compute_sparse_add(attrs, inputs, out_type):
return [topi_compute(inputs[0], inputs[1], inputs[2], inputs[3])]
return _compute_sparse_add
@override_native_generic_func("sparse_add_strategy")
def sparse_add_strategy(attrs, inputs, out_type, target):
"""sparse add generic strategy"""
logger.warning("sparse add is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_add(topi.nn.sparse_add),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_add.generic",
)
return strategy
# sparse_transpose
@generic_func
def schedule_sparse_transpose(attrs, outs, target):
"""schedule sparse_transpose"""
with target:
return topi.generic.schedule_sparse_transpose(outs)
# sort
def wrap_compute_sort(topi_compute):
"""Wrap sort topi compute"""
def _compute_sort(attrs, inputs, _):
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend)]
return _compute_sort
@override_native_generic_func("sort_strategy")
def sort_strategy(attrs, inputs, out_type, target):
"""sort generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sort(topi.sort),
wrap_topi_schedule(topi.generic.schedule_sort),
name="sort.generic",
)
return strategy
# argsort
def wrap_compute_argsort(topi_compute):
"""Wrap argsort topi compute"""
def _compute_argsort(attrs, inputs, _):
axis = get_const_int(attrs.axis)
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
return [topi_compute(inputs[0], axis=axis, is_ascend=is_ascend, dtype=dtype)]
return _compute_argsort
@override_native_generic_func("argsort_strategy")
def argsort_strategy(attrs, inputs, out_type, target):
"""argsort generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argsort(topi.argsort),
wrap_topi_schedule(topi.generic.schedule_argsort),
name="argsort.generic",
)
return strategy
# topk
def wrap_compute_topk(topi_compute):
"""Wrap topk compute"""
def _compute_topk(attrs, inputs, out_type):
if attrs.k is not None:
k = attrs.k
else:
k = inputs[1]
axis = get_const_int(attrs.axis)
ret_type = attrs.ret_type
is_ascend = bool(get_const_int(attrs.is_ascend))
dtype = attrs.dtype
out = topi_compute(inputs[0], k, axis, ret_type, is_ascend, dtype)
out = out if isinstance(out, list) else [out]
return out
return _compute_topk
@override_native_generic_func("topk_strategy")
def topk_strategy(attrs, inputs, out_type, target):
"""topk generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_topk(topi.topk),
wrap_topi_schedule(topi.generic.schedule_topk),
name="topk.generic",
)
return strategy
# multibox_prior
def wrap_compute_multibox_prior(topi_compute):
"""Wrap multibox_prior compute"""
def _compute_multibox_prior(attrs, inputs, _):
"""Compute definition of multibox_prior"""
sizes = get_float_tuple(attrs.sizes)
ratios = get_float_tuple(attrs.ratios)
steps = get_float_tuple(attrs.steps)
offsets = get_float_tuple(attrs.offsets)
clip = bool(get_const_int(attrs.clip))
return [topi_compute(inputs[0], sizes, ratios, steps, offsets, clip)]
return _compute_multibox_prior
@override_native_generic_func("multibox_prior_strategy")
def multibox_prior_strategy(attrs, inputs, out_type, target):
"""multibox_prior generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_multibox_prior(topi.vision.ssd.multibox_prior),
wrap_topi_schedule(topi.generic.schedule_multibox_prior),
name="multibox_prior.generic",
)
return strategy
# multibox_transform_loc
def wrap_compute_multibox_transform_loc(topi_compute):
"""Wrap multibox_transform_loc compute"""
def _compute_multibox_transform_loc(attrs, inputs, _):
"""Compute definition of multibox_detection"""
clip = bool(get_const_int(attrs.clip))
threshold = get_const_float(attrs.threshold)
variances = get_float_tuple(attrs.variances)
return topi_compute(inputs[0], inputs[1], inputs[2], clip, threshold, variances)
return _compute_multibox_transform_loc
@override_native_generic_func("multibox_transform_loc_strategy")
def multibox_transform_loc_strategy(attrs, inputs, out_type, target):
"""schedule multibox_transform_loc"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_multibox_transform_loc(topi.vision.ssd.multibox_transform_loc),
wrap_topi_schedule(topi.generic.schedule_multibox_transform_loc),
name="multibox_transform_loc.generic",
)
return strategy
# get_valid_counts
def wrap_compute_get_valid_counts(topi_compute):
"""wrap get_valid_counts topi compute"""
def _compute_get_valid_counts(attrs, inputs, out_type):
score_threshold = inputs[1]
id_index = get_const_int(attrs.id_index)
score_index = get_const_int(attrs.score_index)
if attrs.score_threshold is not None:
score_threshold = get_const_float(attrs.score_threshold)
return topi_compute(inputs[0], score_threshold, id_index, score_index)
return _compute_get_valid_counts
@override_native_generic_func("get_valid_counts_strategy")
def get_valid_counts_strategy(attrs, inputs, out_type, target):
"""get_valid_counts generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_get_valid_counts(topi.vision.get_valid_counts),
wrap_topi_schedule(topi.generic.schedule_get_valid_counts),
name="get_valid_counts.generic",
)
return strategy
# non-maximum suppression
def wrap_compute_nms(topi_compute):
"""wrap nms topi compute"""
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[3]
iou_threshold = inputs[4]
return_indices = bool(get_const_int(attrs.return_indices))
force_suppress = bool(get_const_int(attrs.force_suppress))
top_k = get_const_int(attrs.top_k)
coord_start = get_const_int(attrs.coord_start)
score_index = get_const_int(attrs.score_index)
id_index = get_const_int(attrs.id_index)
invalid_to_bottom = bool(get_const_int(attrs.invalid_to_bottom))
if return_indices:
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
max_output_size,
iou_threshold,
force_suppress,
top_k,
coord_start,
score_index,
id_index,
return_indices,
invalid_to_bottom,
)
return [
topi_compute(
inputs[0],
inputs[1],
inputs[2],
max_output_size,
iou_threshold,
force_suppress,
top_k,
coord_start,
score_index,
id_index,
return_indices,
invalid_to_bottom,
)
]
return _compute_nms
@override_native_generic_func("non_max_suppression_strategy")
def nms_strategy(attrs, inputs, out_type, target):
"""nms generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_nms(topi.vision.non_max_suppression),
wrap_topi_schedule(topi.generic.schedule_nms),
name="nms.generic",
)
return strategy
def wrap_compute_all_class_nms(topi_compute):
"""wrap all class nms topi compute"""
def _compute_nms(attrs, inputs, out_type):
max_output_size = inputs[2]
iou_threshold = inputs[3]
score_threshold = inputs[4]
return topi_compute(inputs[0], inputs[1], max_output_size, iou_threshold, score_threshold)
return _compute_nms
@override_native_generic_func("all_class_non_max_suppression_strategy")
def all_class_nms_strategy(attrs, inputs, out_type, target):
"""all class nms generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_all_class_nms(topi.vision.all_class_non_max_suppression),
wrap_topi_schedule(topi.generic.schedule_nms),
name="all_class_nms.generic",
)
return strategy
# roi_align
def wrap_compute_roi_align(topi_compute):
"""wrap roi_align topi compute"""
def _compute_roi_align(attrs, inputs, out_type):
pooled_size = get_const_tuple(attrs.pooled_size)
mode = bytes(attrs.mode, "utf-8")
return [
topi_compute(
inputs[0],
inputs[1],
pooled_size=pooled_size,
spatial_scale=attrs.spatial_scale,
sample_ratio=attrs.sample_ratio,
mode=mode,
)
]
return _compute_roi_align
@override_native_generic_func("roi_align_strategy")
def roi_align_strategy(attrs, inputs, out_type, target):
"""roi_align generic strategy"""
strategy = _op.OpStrategy()
layout = attrs.layout
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nchw),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.generic",
)
else:
assert layout == "NHWC", "layout must be NCHW or NHWC."
strategy.add_implementation(
wrap_compute_roi_align(topi.vision.rcnn.roi_align_nhwc),
wrap_topi_schedule(topi.generic.schedule_roi_align),
name="roi_align.generic",
)
return strategy
# sparse_fill_empty_rows
@override_native_generic_func("sparse_fill_empty_rows_strategy")
def sparse_fill_empty_rows_strategy(attrs, outs, out_type, target):
    """sparse_fill_empty_rows generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_fill_empty_rows(topi.sparse_fill_empty_rows),
wrap_topi_schedule(topi.generic.schedule_sparse_fill_empty_rows),
name="sparse_fill_empty_rows.generic",
)
return strategy
def wrap_compute_sparse_fill_empty_rows(topi_compute):
"""Wrap sparse_fill_empty_rows compute"""
def _compute_sparse_fill_empty_rows(attrs, inputs, output_type):
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
inputs[3],
output_type.fields[0].shape,
output_type.fields[1].shape,
output_type.fields[2].shape,
)
return _compute_sparse_fill_empty_rows
# sparse_reshape
@override_native_generic_func("sparse_reshape_strategy")
def sparse_reshape_strategy(attrs, outs, out_type, target):
    """sparse_reshape generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_sparse_reshape(topi.sparse_reshape),
wrap_topi_schedule(topi.generic.schedule_extern),
name="sparse_reshape.generic",
)
return strategy
def wrap_compute_sparse_reshape(topi_compute):
"""Wrap sparse_reshape compute"""
def _compute_sparse_reshape(attrs, inputs, output_type):
return topi_compute(
inputs[0],
inputs[1],
inputs[2],
output_type.fields[0].shape,
output_type.fields[1].shape,
)
return _compute_sparse_reshape
# roi_pool
@generic_func
def schedule_roi_pool(attrs, outs, target):
"""schedule roi_pool"""
with target:
return topi.generic.schedule_roi_pool(outs)
# proposal
def wrap_compute_proposal(topi_compute):
"""wrap proposal topi compute"""
def _compute_proposal(attrs, inputs, out_type):
scales = get_float_tuple(attrs.scales)
ratios = get_float_tuple(attrs.ratios)
feature_stride = attrs.feature_stride
threshold = attrs.threshold
rpn_pre_nms_top_n = attrs.rpn_pre_nms_top_n
rpn_post_nms_top_n = attrs.rpn_post_nms_top_n
rpn_min_size = attrs.rpn_min_size
iou_loss = bool(get_const_int(attrs.iou_loss))
return [
topi_compute(
inputs[0],
inputs[1],
inputs[2],
scales,
ratios,
feature_stride,
threshold,
rpn_pre_nms_top_n,
rpn_post_nms_top_n,
rpn_min_size,
iou_loss,
)
]
return _compute_proposal
@override_native_generic_func("proposal_strategy")
def proposal_strategy(attrs, inputs, out_type, target):
"""proposal generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_proposal(topi.vision.rcnn.proposal),
wrap_topi_schedule(topi.generic.schedule_proposal),
name="proposal.generic",
)
return strategy
# scatter
@override_native_generic_func("scatter_strategy")
def scatter_strategy(attrs, outs, out_type, target):
    """scatter generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.scatter),
wrap_topi_schedule(topi.generic.schedule_scatter),
name="scatter.generic",
)
return strategy
def wrap_compute_scatter(topi_compute):
"""Wrap scatter topi compute"""
def _compute_scatter(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], inputs[2], attrs.axis)]
return _compute_scatter
@override_native_generic_func("scatter_add_strategy")
def scatter_add_strategy(attrs, outs, out_type, target):
    """scatter_add generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter(topi.scatter_add),
wrap_topi_schedule(topi.generic.schedule_scatter),
name="scatter_add.generic",
)
return strategy
# scatter_nd
@override_native_generic_func("scatter_nd_strategy")
def scatter_nd_strategy(attrs, inputs, out_type, target):
"""scatter_nd generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scatter_nd(topi.scatter_nd),
wrap_topi_schedule(topi.generic.schedule_extern),
name="scatter_nd.generic",
)
return strategy
def wrap_compute_scatter_nd(topi_compute):
"""Wrap scatter_nd topi compute"""
def _compute_scatter_nd(attrs, inputs, _):
return [topi_compute(inputs[0], inputs[1], attrs.out_shape)]
return _compute_scatter_nd
# bitserial_conv2d
def wrap_compute_bitserial_conv2d(topi_compute):
"""wrap bitserial_conv2d topi compute"""
def compute_bitserial_conv2d(attrs, inputs, out_dtype):
"""Compute definition for bitserial conv2d."""
padding = get_const_tuple(attrs.padding)
strides = get_const_tuple(attrs.strides)
activation_bits = attrs.activation_bits
weight_bits = attrs.weight_bits
pack_dtype = attrs.pack_dtype
out_dtype = attrs.out_dtype
unipolar = attrs.unipolar
return [
topi_compute(
inputs[0],
inputs[1],
strides,
padding,
activation_bits,
weight_bits,
pack_dtype,
out_dtype,
unipolar,
)
]
return compute_bitserial_conv2d
@override_native_generic_func("bitserial_conv2d_strategy")
def bitserial_conv2d_strategy(attrs, inputs, out_type, target):
"""bitserial_conv2d generic strategy"""
logger.warning("bitserial_conv2d is not optimized for this platform.")
strategy = _op.OpStrategy()
layout = attrs.data_layout
if layout == "NCHW":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nchw),
wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nchw),
name="bitserial_conv2d_nchw.generic",
)
elif layout == "NHWC":
strategy.add_implementation(
wrap_compute_bitserial_conv2d(topi.nn.bitserial_conv2d_nhwc),
wrap_topi_schedule(topi.generic.schedule_bitserial_conv2d_nhwc),
name="bitserial_conv2d_nhwc.generic",
)
else:
raise ValueError("Data layout {} not supported.".format(layout))
return strategy
# bitserial_dense
def wrap_compute_bitserial_dense(topi_compute):
"""wrap bitserial_dense topi compute"""
def compute_bitserial_dense(attrs, inputs, out_type):
"""Compute definition of bitserial dense"""
data_bits = attrs.data_bits
weight_bits = attrs.weight_bits
pack_dtype = attrs.pack_dtype
out_dtype = attrs.out_dtype
out_dtype = inputs[0].dtype if out_dtype == "" else out_dtype
unipolar = attrs.unipolar
return [
topi_compute(
inputs[0], inputs[1], data_bits, weight_bits, pack_dtype, out_dtype, unipolar
)
]
return compute_bitserial_dense
@override_native_generic_func("bitserial_dense_strategy")
def bitserial_dense_strategy(attrs, inputs, out_type, target):
"""bitserial_dense generic strategy"""
logger.warning("bitserial_dense is not optimized for this platform.")
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_bitserial_dense(topi.nn.bitserial_dense),
wrap_topi_schedule(topi.generic.schedule_bitserial_dense),
name="bitserial_dense.generic",
)
return strategy
# correlation
def wrap_compute_correlation(topi_compute):
"""wrap correlation topi compute"""
def _compute_correlation(attrs, inputs, out_type):
kernel_size = attrs.kernel_size
max_displacement = attrs.max_displacement
stride1 = attrs.stride1
stride2 = attrs.stride2
padding = get_const_tuple(attrs.padding)
is_multiply = attrs.is_multiply
return [
topi_compute(
inputs[0],
inputs[1],
kernel_size,
max_displacement,
stride1,
stride2,
padding,
is_multiply,
)
]
return _compute_correlation
@override_native_generic_func("correlation_strategy")
def correlation_strategy(attrs, inputs, out_type, target):
"""correlation generic strategy"""
logger.warning("correlation is not optimized for this platform.")
layout = attrs.layout
assert layout == "NCHW", "Only support NCHW layout"
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_correlation(topi.nn.correlation_nchw),
wrap_topi_schedule(topi.generic.schedule_correlation_nchw),
name="correlation.generic",
)
return strategy
# argwhere
def wrap_compute_argwhere(topi_compute):
"""wrap argwhere topi compute"""
def _compute_argwhere(attrs, inputs, out_type):
output_shape = []
for s in out_type.shape:
if hasattr(s, "value"):
output_shape.append(s)
else:
output_shape.append(te.var("any_dim", "int32"))
new_output_type = ir.TensorType(output_shape, "int32")
return [topi_compute(new_output_type, inputs[0])]
return _compute_argwhere
@override_native_generic_func("argwhere_strategy")
def argwhere_strategy(attrs, inputs, out_type, target):
"""argwhere generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_argwhere(topi.argwhere),
wrap_topi_schedule(topi.generic.schedule_argwhere),
name="argwhere.generic",
)
return strategy
# threefry_generate
def wrap_compute_threefry_generate(topi_compute):
"""Wrap threefry_generate topi compute"""
def _compute_threefry_generate(attrs, inputs, _):
return topi_compute(inputs[0], attrs.out_shape)
return _compute_threefry_generate
@override_native_generic_func("threefry_generate_strategy")
def threefry_generate_strategy(attrs, inputs, out_type, target):
"""threefry_generate generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_threefry_generate(topi.random.threefry_generate),
wrap_topi_schedule(topi.generic.schedule_extern),
name="threefry_generate.generic",
)
return strategy
# threefry_split
def wrap_compute_threefry_split(topi_compute):
"""Wrap threefry_split topi compute"""
def _compute_threefry_split(attrs, inputs, _):
return topi_compute(inputs[0])
return _compute_threefry_split
@override_native_generic_func("threefry_split_strategy")
def threefry_split_strategy(attrs, inputs, out_type, target):
"""threefry_split generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_threefry_split(topi.random.threefry_split),
wrap_topi_schedule(topi.generic.schedule_extern),
name="threefry_split.generic",
)
return strategy
def wrap_compute_scanop(topi_compute):
"""Wrap scanop style topi compute"""
def _compute_scanop(attrs, inputs, _):
return [topi_compute(inputs[0], attrs.axis, attrs.dtype, attrs.exclusive)]
return _compute_scanop
@override_native_generic_func("cumsum_strategy")
def cumsum_strategy(attrs, inputs, out_type, target):
"""cumsum generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scanop(topi.cumsum),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumsum.generic",
)
return strategy
@override_native_generic_func("cumprod_strategy")
def cumprod_strategy(attrs, inputs, out_type, target):
"""cumprod generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_scanop(topi.cumprod),
wrap_topi_schedule(topi.generic.schedule_extern),
name="cumprod.generic",
)
return strategy
def wrap_compute_unique(topi_compute):
"""Wrap unique topi compute"""
def _compute_unique(attrs, inputs, _):
return topi_compute(inputs[0], attrs.sorted, attrs.return_counts)
return _compute_unique
@override_native_generic_func("unique_strategy")
def unique_strategy(attrs, inputs, out_type, target):
"""unique generic strategy"""
strategy = _op.OpStrategy()
strategy.add_implementation(
wrap_compute_unique(topi.unique),
wrap_topi_schedule(topi.generic.schedule_unique),
name="unique.generic",
)
return strategy