blob: ecb98957ff227bb086ef1e20ea66cad44f5571dd [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for bitserial_dense operator"""
import os
import numpy as np
import tvm
from tvm import te
from tvm import topi
import tvm.testing
import tvm.topi.testing
from tvm.topi.utils import get_const_tuple
from tvm.contrib.pickle_memoize import memoize
# Dispatch table handed to tvm.topi.testing.dispatch: each key maps to the
# (compute, schedule) function pair that is exercised for a matching target.
_bitserial_dense_implement = {
    "generic": (
        topi.nn.bitserial_dense,
        topi.generic.schedule_bitserial_dense,
    ),
    "cpu": (
        topi.x86.bitserial_dense,
        topi.x86.schedule_bitserial_dense,
    ),
    "arm_cpu": (
        topi.arm_cpu.bitserial_dense,
        topi.arm_cpu.schedule_bitserial_dense,
    ),
}
def generate_quantized_np(shape, bits, out_dtype):
    """Return a random array of ``bits``-wide non-negative integers.

    Parameters
    ----------
    shape : tuple of int
        Shape of the generated array.
    bits : int
        Number of bits per value; samples are drawn from ``[0, 2**bits)``.
    out_dtype : str or numpy.dtype
        Dtype the sampled integers are cast to before being returned.

    Returns
    -------
    numpy.ndarray
        Array of the requested shape and dtype, drawn with ``np.random.randint``.
    """
    # ``high`` is exclusive, so the largest sample is (1 << bits) - 1.
    samples = np.random.randint(low=0, high=1 << bits, size=shape)
    return samples.astype(out_dtype)
def verify_bitserial_dense(batch, in_dim, out_dim, activation_bits, weight_bits, unipolar):
    """Build bitserial dense for each tested target and compare it with NumPy.

    For every target in ``["llvm", "llvm -device=arm_cpu"]`` the matching
    compute/schedule pair is looked up in ``_bitserial_dense_implement``,
    compiled, executed on ``tvm.cpu(0)`` and checked against a NumPy reference.
    The ``arm_cpu`` target is skipped when the host machine name does not
    contain ``"arm"``.

    Parameters
    ----------
    batch : int
        Number of rows of the activation placeholder ``A``.
    in_dim : int
        Reduction dimension shared by ``A`` and the weight placeholder ``B``.
    out_dim : int
        Number of rows of ``B``.
    activation_bits : int
        Bit width used to generate activations; forwarded to the operator.
    weight_bits : int
        Bit width used to generate weights; forwarded to the operator.
    unipolar : bool
        Forwarded to the operator. When true, the reference treats a weight
        equal to 1 as +1 and every other weight value as -1.

    Raises
    ------
    AssertionError
        If the compiled operator's output differs from the reference.
    """
    out_dtype = "int16"

    def get_ref_data(a_shape, b_shape, input_dtype):
        """Generate random inputs and the expected dense output."""
        a_np = generate_quantized_np(get_const_tuple(a_shape), activation_bits, input_dtype)
        b_np = generate_quantized_np(get_const_tuple(b_shape), weight_bits, input_dtype)
        if unipolar:
            # Vectorised form of "1 stays +1, anything else becomes -1".
            weights = np.where(b_np == 1, 1, -1)
        else:
            weights = b_np
        # Accumulate in int64: np.dot keeps the operand dtype, so a uint8
        # product (the arm_cpu input dtype) would wrap modulo 256 and yield a
        # wrong reference for any realistic reduction length.
        c_np = np.dot(a_np.astype("int64"), weights.astype("int64").T)
        return a_np, b_np, c_np

    for target in ["llvm", "llvm -device=arm_cpu"]:
        target = tvm.target.Target(target)
        if "arm_cpu" in target.keys and "arm" not in os.uname()[4]:
            print("Skipped running code, not an arm device")
            continue

        # The arm_cpu implementation is exercised with uint8 inputs, the
        # others with uint32.
        input_dtype = "uint8" if "arm_cpu" in target.keys else "uint32"
        A = te.placeholder((batch, in_dim), dtype=input_dtype, name="A")
        B = te.placeholder((out_dim, in_dim), dtype=input_dtype, name="B")

        fcompute, fschedule = tvm.topi.testing.dispatch(target, _bitserial_dense_implement)
        C = fcompute(A, B, activation_bits, weight_bits, input_dtype, out_dtype, unipolar)
        s = fschedule([C])

        a_shape = get_const_tuple(A.shape)
        b_shape = get_const_tuple(B.shape)
        a_np, b_np, c_np = get_ref_data(a_shape, b_shape, input_dtype)

        dev = tvm.cpu(0)
        a = tvm.nd.array(a_np, dev)
        b = tvm.nd.array(b_np, dev)
        c = tvm.nd.array(np.zeros(get_const_tuple(C.shape), dtype=C.dtype), dev)
        func = tvm.build(s, [A, B, C], target)
        func(a, b, c)
        tvm.testing.assert_allclose(c.numpy(), c_np, rtol=1e-5)
def test_bitserial_dense():
    """Run the bitserial dense check for 1- and 2-bit activations, both polarities.

    Every case uses batch 1, a 1024-wide input, 1000 outputs and 1-bit weights;
    the unipolar cases run first, then the non-unipolar ones.
    """
    for unipolar in (True, False):
        for activation_bits in (1, 2):
            verify_bitserial_dense(1, 1024, 1000, activation_bits, 1, unipolar)
# Allow the file to be executed directly as a script, outside a test runner.
if __name__ == "__main__":
    test_bitserial_dense()