blob: 039f91aa059ed39a0b5f948cff3b6db3edbe094a [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for space to batch"""
import numpy as np
import tvm
from tvm import te
from tvm import topi
import tvm.testing
import tvm.topi.testing
def verify_space_to_batch_nd(input_shape, block_shape, pad_before, pad_after, pad_value=0):
    """Build ``topi.nn.space_to_batch_nd`` on every enabled target and compare to NumPy.

    Parameters
    ----------
    input_shape : list of int
        Shape of the float32 input tensor.
    block_shape : list of int
        Block size per spatial dimension. Its length M makes input
        dimensions 1..M the spatial dimensions; dimensions after M are
        carried through unchanged.
    pad_before : list of int
        Padding before each spatial dimension (one entry per block dim).
    pad_after : list of int
        Padding after each spatial dimension (one entry per block dim).
    pad_value : scalar, optional
        Padding value forwarded to both the TOPI operator and the NumPy
        reference implementation.
    """
    num_spatial = len(block_shape)

    # Expected output shape: the batch dim is multiplied by the product of the
    # block sizes, each padded spatial dim is divided by its block size, and
    # any remaining trailing dims are copied through as-is.
    out_shape = [int(input_shape[0] * np.prod(block_shape))]
    for i in range(1, num_spatial + 1):
        pad = pad_before[i - 1] + pad_after[i - 1]
        out_shape.append(int((input_shape[i] + pad) // block_shape[i - 1]))
    out_shape.extend(input_shape[num_spatial + 1 :])

    A = te.placeholder(input_shape, name="A", dtype="float32")
    dtype = A.dtype
    a_np = np.random.uniform(size=input_shape).astype(dtype)

    B = topi.nn.space_to_batch_nd(A, block_shape, pad_before, pad_after, pad_value)

    # NumPy reference result that every target's output is checked against.
    b_np = tvm.topi.testing.space_to_batch_nd_python(
        a_np, block_shape, pad_before, pad_after, pad_value
    )

    def check_target(target, dev):
        """Schedule, build, and run the op on one (target, device) pair."""
        print("Running on target: %s" % target)
        # tvm.target.create() is deprecated; construct the Target directly and
        # use it as the scope for picking the target-specific schedule.
        with tvm.target.Target(target):
            s = tvm.topi.testing.get_injective_schedule(target)(B)
        a = tvm.nd.array(a_np, dev)
        # Allocating the output with the independently computed out_shape
        # also checks that the op produces the expected shape.
        b = tvm.nd.array(np.zeros(out_shape, dtype=dtype), dev)
        f = tvm.build(s, [A, B], target)
        f(a, b)
        tvm.testing.assert_allclose(b.numpy(), b_np, rtol=1e-3, atol=1e-3)

    for target, dev in tvm.testing.enabled_targets():
        check_target(target, dev)
@tvm.testing.uses_gpu
def test_space_to_batch():
    """Check ``space_to_batch_nd`` across padding, spatial-rank, and pad-value variants."""
    # Without paddings
    verify_space_to_batch_nd([3, 3, 2, 1], [3], [0], [0])
    # With paddings
    verify_space_to_batch_nd([3, 3, 2, 1], [3], [1], [2])
    # Multiple spatial dims
    verify_space_to_batch_nd([3, 3, 4, 5, 2], [3, 4, 2], [1, 0, 3], [2, 0, 0])
    # No remaining dims
    verify_space_to_batch_nd([3, 3, 4, 5, 2], [3, 4, 2, 2], [1, 4, 0, 0], [2, 0, 1, 0])
    # Non-zero pad value: every case above uses the default pad_value=0, so
    # this is the only case that checks the padding value is actually honoured.
    verify_space_to_batch_nd([3, 3, 2, 1], [3], [1], [2], pad_value=1.5)
if __name__ == "__main__":
    # Allow running this module directly as a script, outside of pytest.
    test_space_to_batch()