# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: disable=c-extension-no-member
from typing import Union, Tuple, List
import pytest
import numpy as np
import tvm
import tvm.testing
from tvm import relax
from tvm.script import ir as I
from tvm.script import tir as T
from tvm.script import relax as R
from tvm.relax.frontend.stablehlo import from_stablehlo


def generate_np_inputs(
    input_shapes: Union[Tuple, List[Tuple]], dtype: str = "float32"
) -> List[np.ndarray]:
"""Generate numpy data as the inputs of model
Parameters
----------
input_shapes: Union[Tuple, List[Tuple]]
shapes for inputs
dtype: str
the data type of inputs
Results
-------
out: List[np.ndarray]
numpy input data
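
    Examples
    --------
    A small illustration of the two accepted input forms (only shapes are
    checked, since the data itself is random):

    >>> generate_np_inputs((2, 3))[0].shape
    (2, 3)
    >>> [a.shape for a in generate_np_inputs([(2, 3), (4,)])]
    [(2, 3), (4,)]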
"""
if not isinstance(input_shapes[0], (list, tuple)):
return [np.random.uniform(size=input_shapes).astype(dtype)]
out = []
for input_shape in input_shapes:
out.append(np.random.uniform(size=input_shape).astype(dtype))
    return out


def np2jnp(inputs_np: Union[np.ndarray, List[np.ndarray]]):
"""Convert data from numpy to jax.numpy
Parameters
----------
inputs_np: Union[np.ndarray, List[np.ndarray]]
numpy input data
Results
-------
out: Union[jnp.ndarray, List[jnp.ndarray]]
jax numpy data
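
    Examples
    --------
    A minimal illustration (assumes jax is installed, as the rest of this
    file does):

    >>> np2jnp(np.zeros((2, 3), "float32")).shape
    (2, 3)
    >>> len(np2jnp([np.zeros((2,)), np.zeros((3,))]))
    2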
"""
import jax.numpy as jnp
# Use jnp.asarray to avoid unnecessary memory copies
inputs_jnp = []
if isinstance(inputs_np, (tuple, list)):
for input_np in inputs_np:
inputs_jnp.append(jnp.asarray(input_np))
return inputs_jnp
    return jnp.asarray(inputs_np)


def check_correctness(
jax_jit_mod,
input_shapes: Union[Tuple, List[Tuple]],
dtype: str = "float32",
) -> None:
"""Run a jax model and the translated TVM IRModule,
verify the inference accuracy.
Parameters
----------
jax_jit_mod: jaxlib.xla_extension.CompiledFunction
The input jax jitted model
input_shapes: Union[Tuple, List[Tuple]]
shapes for inputs
dtype: str
the data type of inputs
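
    Examples
    --------
    A minimal usage sketch, mirroring the tests below (not executed here,
    since it compiles and runs the model)::

        import jax

        check_correctness(jax.jit(lambda x: x + 1), (2, 3))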
"""
# Generate numpy inputs
inputs_np = generate_np_inputs(input_shapes, dtype)
# Get the jax numpy data
inputs_jnp = np2jnp(inputs_np)
# lower the jitted function to StableHLO
lowered = jax_jit_mod.lower(*inputs_np)
    # lowered.as_text(dialect="stablehlo") would give the text form;
    # compiler_ir returns the corresponding jaxlib.mlir.Module
stablehlo_module = lowered.compiler_ir(dialect="stablehlo")
# Convert the StableHLO IR to Relax
ir_mod = from_stablehlo(stablehlo_module)
# Run the jax jitted model with the input jax numpy data
jax_output = jax_jit_mod(*inputs_jnp)
# TODO (yongwww): support multiple targets,
# "llvm" should be good for this check
target = tvm.target.Target("llvm", host="llvm")
# Compile and run
ex = tvm.compile(ir_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
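    # Stateful calling convention: bind the inputs, invoke, then fetch outputs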
vm.set_input("main", *inputs_np)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
    # Single output
if isinstance(tvm_output, tvm.runtime.Tensor):
tvm.testing.assert_allclose(tvm_output.numpy(), jax_output, rtol=1e-5, atol=1e-5)
return
    # Multiple outputs
    assert len(tvm_output) == len(jax_output), "mismatched number of outputs"
for tvm_out, jax_out in zip(tvm_output, jax_output):
        tvm.testing.assert_allclose(tvm_out.numpy(), jax_out, rtol=1e-5, atol=1e-5)


def get_vm_res(
ir_mod: tvm.IRModule, weights: Union[np.ndarray, List[np.ndarray]]
) -> Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]:
"""Compile and run an ir_module on Relax VM
Parameters
----------
ir_mod: tvm.IRModule
input ir module
weights: Union[np.ndarray, List[np.ndarray]]
input weights
Results
-------
out: Union[tvm.runtime.Tensor, List[tvm.runtime.Tensor]]
inference result
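
    Examples
    --------
    A minimal sketch; ``stablehlo_module`` is assumed to be a lowered
    StableHLO module whose ``main`` takes a single float32 tensor::

        ir_mod = from_stablehlo(stablehlo_module)
        out = get_vm_res(ir_mod, [np.zeros((2, 3), "float32")])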
"""
target = tvm.target.Target("llvm", host="llvm")
# Compile and run
ex = tvm.compile(ir_mod, target)
vm = relax.VirtualMachine(ex, tvm.cpu())
vm.set_input("main", *weights)
vm.invoke_stateful("main")
tvm_output = vm.get_outputs("main")
    return tvm_output


@tvm.testing.requires_gpu
def test_add_dynamic():
add_dyn = """
func.func @test(%arg0: tensor<?x?xf32>, %arg1: tensor<?x?xf32>) -> tensor<?x?xf32> {
%1 = stablehlo.add %arg0, %arg1 : (tensor<?x?xf32>, tensor<?x?xf32>) -> tensor<?x?xf32>
func.return %1 : tensor<?x?xf32>
}
"""
mod = from_stablehlo(add_dyn)
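    # Each dynamic (?) dimension is imported as a fresh symbolic var (n_0..n_3)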
@I.ir_module
class Expected:
@R.function
def main(
arg0: R.Tensor(("n_0", "n_1"), dtype="float32"),
arg1: R.Tensor(("n_2", "n_3"), dtype="float32"),
) -> R.Tensor(dtype="float32", ndim=2):
n_0 = T.int64()
n_1 = T.int64()
n_2 = T.int64()
n_3 = T.int64()
with R.dataflow():
lv: R.Tensor(dtype="float32", ndim=2) = R.add(arg0, arg1)
gv: R.Tensor(dtype="float32", ndim=2) = lv
R.output(gv)
return gv
    tvm.ir.assert_structural_equal(mod, Expected)


@tvm.testing.requires_gpu
def test_unary():
    import jax

    def _rsqrt(x):
        return jax.lax.rsqrt(x)

    def _sqrt(x):
        return jax.lax.sqrt(x)

    def _sin(x):
        return jax.lax.sin(x)

    def _sinh(x):
        return jax.lax.sinh(x)

    def _cos(x):
        return jax.lax.cos(x)

    def _cosh(x):
        return jax.lax.cosh(x)

    def _exp(x):
        return jax.lax.exp(x)

    def _round(x):
        return jax.lax.round(x)

    input_shapes = (2, 3, 4)
    for fn in [_rsqrt, _sqrt, _sin, _sinh, _cos, _cosh, _exp, _round]:
        check_correctness(jax.jit(fn), input_shapes)


@tvm.testing.requires_gpu
def test_binary():
import jax
def fn(x, y):
r1 = x + y
r2 = r1 * r1
r3 = r2 / r1
r = r2 - r3
return r
input_shape = (1, 2, 3)
input_shapes = (input_shape, input_shape)
# jit the function
jit_fn = jax.jit(fn)
# verify inference accuracy
    check_correctness(jit_fn, input_shapes)


@tvm.testing.requires_gpu
def test_const():
import jax
def fn(x):
return x + 1
    check_correctness(jax.jit(fn), (2,))


@tvm.testing.requires_gpu
def test_maximum():
import jax
import jax.numpy as jnp
def fn(x, y):
return jnp.maximum(x, y)
    check_correctness(jax.jit(fn), ((2, 3), (2, 3)))


@tvm.testing.requires_gpu
def test_minimum():
import jax
import jax.numpy as jnp
def fn(x, y):
return jnp.minimum(x, y)
    check_correctness(jax.jit(fn), ((2, 3), (2, 3)))


@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_reduce():
import jax
import jax.numpy as jnp
def fn(x):
return jnp.mean(x, axis=(1, 2))
    check_correctness(jax.jit(fn), (2, 3, 4, 5))


@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(mshr-h): may be fixed by upgrading jax to >=0.4.33
def test_reduce_window():
import jax
from flax import linen as nn
def fn(x):
return nn.max_pool(x, (3, 3), strides=(2, 2), padding="SAME")
    check_correctness(jax.jit(fn), (2, 3, 4))


@tvm.testing.requires_gpu
def test_dot_general():
import jax
def fn(x, y):
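        # Dimension numbers (([1], [0]), ([], [])) contract x's last axis with
        # y's first axis and use no batch dims, i.e. a plain matrix multiply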
return jax.lax.dot_general(x, y, (([1], [0]), ([], [])))
input_shapes = ((1, 512), (512, 2))
    check_correctness(jax.jit(fn), input_shapes)


@tvm.testing.requires_gpu
@pytest.mark.skip(
reason="jaxlib.xla_extension.XlaRuntimeError: FAILED_PRECONDITION: DNN library initialization failed."
)
# TODO(yongwww): fix flaky error of "invalid device ordinal"
def test_conv():
import jax
from flax import linen as nn
import jax.random as jrandom
conv = nn.Conv(64, (7, 7), (2, 2), padding=[(3, 3), (3, 3)], name="conv_init")
input_shape = (7, 7, 5, 64)
input_np = generate_np_inputs(input_shape)[0]
input_jnp = np2jnp(input_np)
# initialize the conv
weights = conv.init(jrandom.PRNGKey(0), input_jnp)
# get jax inference output
jax_output = conv.apply(weights, input_jnp)
# assemble numpy data using weights generated above
kernel_np = np.asarray(weights["params"]["kernel"])
bias_np = np.asarray(weights["params"]["bias"])
inputs_np = [bias_np, kernel_np, input_np]
    # jit and lower to StableHLO (conv.apply is already callable, so no
    # functools.partial wrapper is needed)
    lowered = jax.jit(conv.apply).lower(weights, input_jnp)
    stablehlo_module = lowered.compiler_ir(dialect="stablehlo")
# convert in Relax
ir_mod = from_stablehlo(stablehlo_module)
# compile and run
tvm_output = get_vm_res(ir_mod, inputs_np)
# verify accuracy
    tvm.testing.assert_allclose(tvm_output.numpy(), jax_output, rtol=1e-5, atol=1e-5)


if __name__ == "__main__":
tvm.testing.main()