# blob: 969beb7d28f7c4dfc87a0c2e47d0f1a8b78f1bd3 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
"""Test code for loss operators."""
import numpy as np
import pytest
import tvm
from tvm import te
from tvm import topi
import tvm.topi.testing
import tvm.testing
# Parameter sets consumed by test_nll_loss. Each tuple is laid out as
# (prediction_shape, reduction, ignore_index, dtype).
prediction_shape, reduction, ignore_index, dtype = tvm.testing.parameters(
    ((10, 5), "mean", -100, "float32"),
    ((10, 5, 2, 2), "mean", -100, "float32"),  # 4-D predictions
    ((10, 5), "sum", -100, "float32"),
    ((10, 5), "none", -100, "float32"),
    ((10, 5), "mean", 3, "float32"),  # ignore_index falls inside the generated target range
    ((10, 5), "mean", -100, "float64"),
    ((5,), "mean", -100, "float32"),  # 1-D predictions
    ((5,), "mean", 3, "float32"),
    ((5,), "none", -100, "float32"),
)
def test_nll_loss(target, dev, prediction_shape, reduction, ignore_index, dtype):
    """Compare topi.nn.nll_loss with the NumPy reference for one parameter set.

    The TOPI op is scheduled and built for `target`, executed on `dev` with
    randomly generated inputs, and its output is checked against
    tvm.topi.testing.nll_loss using rtol=1e-4 and atol=1e-5.

    Parameters
    ----------
    target, dev
        Build target and the device the arrays are placed on
        (presumably supplied by the tvm.testing pytest fixtures -- confirm).
    prediction_shape : tuple of int
        Shape of the predictions tensor.
    reduction : str
        Reduction mode forwarded unchanged to both implementations.
    ignore_index : int
        Target value forwarded unchanged to both implementations.
    dtype : str
        Data type of the predictions and weights.
    """
    # With a 1-D prediction the only axis sizes the weights and the target
    # shape is empty; otherwise axis 1 sizes the weights and the target shape
    # is the prediction shape with that axis dropped.
    one_dim = len(prediction_shape) == 1
    num_classes = prediction_shape[0] if one_dim else prediction_shape[1]
    target_shape = [] if one_dim else prediction_shape[:1] + prediction_shape[2:]

    # Symbolic inputs and the op under test.
    pred_ph = te.placeholder(shape=prediction_shape, name="predictions", dtype=dtype)
    tgt_ph = te.placeholder(shape=target_shape, name="targets", dtype="int32")
    wgt_ph = te.placeholder(shape=(num_classes,), name="weights", dtype=dtype)
    loss = topi.nn.nll_loss(pred_ph, tgt_ph, wgt_ph, reduction, ignore_index)

    with tvm.target.Target(target):
        schedule = tvm.topi.testing.get_reduce_schedule(target)([loss])
    func = tvm.build(schedule, [pred_ph, tgt_ph, wgt_ph, loss], target, name="nll_loss")

    # Random host-side inputs and the expected result from the reference.
    pred_np = np.random.uniform(size=prediction_shape).astype(dtype)
    tgt_np = np.random.randint(0, num_classes, target_shape).astype("int32")
    wgt_np = np.random.uniform(size=(num_classes,)).astype(dtype)
    expected = tvm.topi.testing.nll_loss(pred_np, tgt_np, wgt_np, reduction, ignore_index)

    # The output buffer takes the reference's shape and the op's dtype; it is
    # uploaded last so it lines up with the final argument of the built function.
    out_buf = np.empty(expected.shape).astype(loss.dtype)
    device_args = [tvm.nd.array(host, dev) for host in (pred_np, tgt_np, wgt_np, out_buf)]
    func(*device_args)

    actual = device_args[-1].numpy()
    tvm.testing.assert_allclose(actual, expected, rtol=1e-4, atol=1e-5)
if __name__ == "__main__":
    # Executed directly as a script: hand off to tvm.testing.main().
    tvm.testing.main()