blob: 037925c7d033fc8812f54a0e0c9baf19af7e1444 [file]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from typing import Any
import numpy as np
import pytest
import tvm
import tvm.testing
from tvm.relax.frontend import nn
class ParamContainerModule(nn.Module):
def __init__(self):
self.list_params = nn.ParameterList(
[
nn.Parameter((4,), "float32"),
nn.Parameter((4,), "float32"),
]
)
self.dict_params = nn.ParameterDict(
{
"foo": nn.Parameter((4,), "float32"),
"bar": nn.Parameter((4,), "float32"),
}
)
def test_parameter_list_basic_behavior():
p0 = nn.Parameter((4,), "float32")
p1 = nn.Parameter((4,), "float32")
params = nn.ParameterList([p0])
params.append(p1)
assert len(params) == 2
assert params[0] is p0
assert list(params) == [p0, p1]
p2 = nn.Parameter((4,), "float32")
params[1] = p2
assert params[1] is p2
p3 = nn.Parameter((4,), "float32")
params.extend([p3])
assert list(params) == [p0, p2, p3]
def test_parameter_dict_basic_behavior():
p0 = nn.Parameter((4,), "float32")
p1 = nn.Parameter((4,), "float32")
params = nn.ParameterDict({"foo": p0})
params["bar"] = p1
assert len(params) == 2
assert params["foo"] is p0
assert "bar" in params
assert list(params) == ["foo", "bar"]
assert list(params.keys()) == ["foo", "bar"]
assert list(params.values()) == [p0, p1]
assert list(params.items()) == [("foo", p0), ("bar", p1)]
assert params.get("foo") is p0
p2 = nn.Parameter((4,), "float32")
params.update({"baz": p2})
assert list(params.keys()) == ["foo", "bar", "baz"]
assert params.pop("baz") is p2
params.clear()
assert len(params) == 0
def test_type_validation():
with pytest.raises(TypeError):
nn.ParameterList([object()])
with pytest.raises(TypeError):
nn.ParameterDict({"bad": object()})
with pytest.raises(TypeError):
nn.ParameterDict({1: nn.Parameter((4,), "float32")})
with pytest.raises(TypeError):
nn.ParameterList()[0] = object()
def test_named_parameters_parameters_and_state_dict():
m = ParamContainerModule()
expected = [
"list_params.0",
"list_params.1",
"dict_params.foo",
"dict_params.bar",
]
assert list(m.state_dict().keys()) == expected
assert [name for name, _ in m.named_parameters()] == expected
assert len(list(m.parameters())) == 4
def test_nested_traversal_through_module_dict():
class Inner(nn.Module):
def __init__(self):
self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
class Outer(nn.Module):
def __init__(self):
self.blocks = nn.ModuleDict({"inner": Inner()})
m = Outer()
assert list(m.state_dict().keys()) == ["blocks.inner.params.0"]
def test_nested_traversal_through_module_list():
class Inner(nn.Module):
def __init__(self):
self.params = nn.ParameterList([nn.Parameter((4,), "float32")])
class Outer(nn.Module):
def __init__(self):
self.blocks = nn.ModuleList([Inner()])
m = Outer()
assert list(m.state_dict().keys()) == ["blocks.0.params.0"]
def test_to_dtype():
m = ParamContainerModule()
m.to(dtype="float16")
assert m.list_params[0].dtype == "float16"
assert m.list_params[1].dtype == "float16"
assert m.dict_params["foo"].dtype == "float16"
assert m.dict_params["bar"].dtype == "float16"
def test_load_state_dict():
m = ParamContainerModule()
p0 = nn.Parameter((4,), "float32")
p0.data = np.full((4,), 1.0, dtype="float32")
p1 = nn.Parameter((4,), "float32")
p1.data = np.full((4,), 2.0, dtype="float32")
p2 = nn.Parameter((4,), "float32")
p2.data = np.full((4,), 3.0, dtype="float32")
p3 = nn.Parameter((4,), "float32")
p3.data = np.full((4,), 4.0, dtype="float32")
state_dict = {
"list_params.0": p0,
"list_params.1": p1,
"dict_params.foo": p2,
"dict_params.bar": p3,
}
missing_keys, unexpected_keys = m.load_state_dict(state_dict)
assert missing_keys == []
assert unexpected_keys == []
tvm.testing.assert_allclose(m.list_params[0].data.numpy(), np.full((4,), 1.0, "float32"))
tvm.testing.assert_allclose(m.list_params[1].data.numpy(), np.full((4,), 2.0, "float32"))
tvm.testing.assert_allclose(m.dict_params["foo"].data.numpy(), np.full((4,), 3.0, "float32"))
tvm.testing.assert_allclose(m.dict_params["bar"].data.numpy(), np.full((4,), 4.0, "float32"))
def test_export_tvm_parameter_names():
class M(nn.Module):
def __init__(self):
self.biases = nn.ParameterList(
[
nn.Parameter((4,), "float32"),
nn.Parameter((4,), "float32"),
]
)
self.scales = nn.ParameterDict({"main": nn.Parameter((4,), "float32")})
def forward(self, x):
return x + self.biases[0] + self.biases[1] + self.scales["main"]
_, params = M().export_tvm(
spec={"forward": {"x": nn.spec.Tensor((4,), "float32")}},
debug=False,
)
assert [name for name, _ in params] == ["biases.0", "biases.1", "scales.main"]
def test_mutator_parameter_container_names():
seen = []
class Recorder(nn.Mutator):
def visit_param(self, name: str, node: nn.Parameter) -> Any:
seen.append(name)
return node
m = ParamContainerModule()
Recorder().visit_module("", m)
assert seen == [
"list_params.0",
"list_params.1",
"dict_params.foo",
"dict_params.bar",
]
if __name__ == "__main__":
tvm.testing.main()