# blob: 293279919f1d99cb12d0f6d88aac9c2ae91ac7d5 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# coding: utf-8
import tvm
# Floating-point dtype names.
RealTypes = ["float32", "float64", "float16"]
# Every dtype name listed by this module: the floating-point names first,
# followed by the unsigned and then the signed integer names.
AllTypes = RealTypes + [
    "uint8", "uint16", "uint32", "uint64",
    "int8", "int16", "int32", "int64",
]
def assign_by_req(a, req, otype=None):
    """Build the output stage that writes ``a`` according to ``req``.

    A placeholder ``b`` with the same shape and dtype as ``a`` (named
    ``'assign_by_req_b'``) is always created. The computed tensor ``c`` has
    the shape of ``a`` and is defined element-wise as:

    * ``req == "kAddTo"``: ``a[idx] + b[idx]``
    * any other ``req``:   ``a[idx]`` (``b`` is not read)

    In both cases ``a[idx]`` is first cast with ``.astype(otype)`` when
    ``otype`` is truthy.

    Parameters
    ----------
    a : tensor expression providing ``shape``, ``dtype`` and indexing
    req : str
        Request kind; only the value ``"kAddTo"`` is treated specially.
    otype : str or None
        Optional dtype that ``a``'s elements are cast to.

    Returns
    -------
    (b, c) : the placeholder and the computed tensor.
    """
    b = tvm.te.placeholder(a.shape, name='assign_by_req_b', dtype=a.dtype)
    accumulate = req == "kAddTo"

    def _element(*idx):
        # Optionally cast the source element before it is used.
        value = a[idx].astype(otype) if otype else a[idx]
        if accumulate:
            return value + b[idx]
        return value

    c = tvm.te.compute(a.shape, _element)
    return b, c
def reduce_axes(X, axes, reducer, atype=None):
    """Reduce ``X`` over the dimensions flagged in ``axes``.

    Parameters
    ----------
    X : tensor expression providing ``shape`` and indexing
    axes : sequence of int flags
        One flag per indexed dimension of ``X``. A flag of ``0`` means the
        dimension is addressed by the next output index; any other flag means
        it is addressed by the next reduce axis. Reduce axes are created only
        for flags equal to ``1``.
    reducer : callable
        Called as ``reducer(element, axis=reduce_axes)``.
    atype : str or None
        Optional dtype the element is cast to (via ``.astype``) before being
        passed to ``reducer``.

    Returns
    -------
    The tensor produced by ``tvm.te.compute`` (named ``'ret'``), whose shape
    is a list of fresh ``tvm.te.size_var()`` values.
    """
    in_shape = X.shape
    # NOTE(review): this matches the number of 0-flags only when the flags in
    # ``axes`` alternate between 0 and 1 and ``len(axes) == len(X.shape)`` --
    # presumably guaranteed by callers; confirm.
    out_ndim = (len(in_shape) + 1 - axes[0]) // 2
    out_shape = [tvm.te.size_var() for _ in range(out_ndim)]
    reduce_iters = [
        tvm.te.reduce_axis((0, in_shape[pos]))
        for pos, flag in enumerate(axes)
        if flag == 1
    ]

    def _source_index(out_idx, red_idx):
        # Interleave output indices and reduce axes in the order given by
        # ``axes`` to form a full index into X.
        picked = []
        kept_pos = 0
        red_pos = 0
        for flag in axes:
            if flag == 0:
                picked.append(out_idx[kept_pos])
                kept_pos += 1
            else:
                picked.append(red_idx[red_pos])
                red_pos += 1
        return tuple(picked)

    def _reduced_element(*idx):
        element = X[_source_index(idx, reduce_iters)]
        if atype:
            element = element.astype(atype)
        return reducer(element, axis=reduce_iters)

    return tvm.te.compute(out_shape, _reduced_element, name='ret')