|  | # Licensed to the Apache Software Foundation (ASF) under one | 
|  | # or more contributor license agreements.  See the NOTICE file | 
|  | # distributed with this work for additional information | 
|  | # regarding copyright ownership.  The ASF licenses this file | 
|  | # to you under the Apache License, Version 2.0 (the | 
|  | # "License"); you may not use this file except in compliance | 
|  | # with the License.  You may obtain a copy of the License at | 
|  | # | 
|  | #   http://www.apache.org/licenses/LICENSE-2.0 | 
|  | # | 
|  | # Unless required by applicable law or agreed to in writing, | 
|  | # software distributed under the License is distributed on an | 
|  | # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | 
|  | # KIND, either express or implied.  See the License for the | 
|  | # specific language governing permissions and limitations | 
|  | # under the License. | 
|  |  | 
|  | from typing import Any | 
|  |  | 
|  | import tvm | 
|  | from tvm.relax.frontend import nn | 
|  |  | 
|  |  | 
|  | def test_mutator_naming_basic(): | 
|  | class Module0(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.param0 = nn.Parameter((32, 128), "float64") | 
|  |  | 
|  | class Module1(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.mod0 = Module0() | 
|  | self.param1 = nn.Parameter((32, 128), "float32") | 
|  |  | 
|  | class Module2(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.mod1 = Module1() | 
|  | self.param2 = nn.Parameter((32, 128), "float16") | 
|  |  | 
|  | class Module3(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.mod2 = Module2() | 
|  | self.param3 = nn.Parameter((32, 128), "float8") | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_param(self, name: str, node: nn.Parameter) -> Any: | 
|  | if node.dtype == "float8": | 
|  | assert name == "mod3.param3" | 
|  | return node | 
|  | elif node.dtype == "float16": | 
|  | assert name == "mod3.mod2.param2" | 
|  | return node | 
|  | elif node.dtype == "float32": | 
|  | assert name == "mod3.mod2.mod1.param1" | 
|  | return node | 
|  | elif node.dtype == "float64": | 
|  | assert name == "mod3.mod2.mod1.mod0.param0" | 
|  | return node | 
|  |  | 
|  | mod3 = Module3() | 
|  | mutator = Mutator() | 
|  | mutator.visit("mod3", mod3) | 
|  |  | 
|  |  | 
|  | def test_mutator_naming_modulelist(): | 
|  | class Module(nn.Module): | 
|  | def __init__(self, dtype) -> None: | 
|  | super().__init__() | 
|  | self.param = nn.Parameter((32, 128), dtype) | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_param(self, name: str, node: nn.Parameter) -> Any: | 
|  | if node.dtype == "float64": | 
|  | assert name == "mod_list.0.0.param" | 
|  | return node | 
|  | elif node.dtype == "float32": | 
|  | assert name == "mod_list.0.1.param" | 
|  | return node | 
|  | elif node.dtype == "float16": | 
|  | assert name == "mod_list.1.0.param" | 
|  | return node | 
|  | elif node.dtype == "float8": | 
|  | assert name == "mod_list.1.1.param" | 
|  | return node | 
|  |  | 
|  | mod_list = nn.ModuleList( | 
|  | [ | 
|  | nn.ModuleList([Module("float64"), Module("float32")]), | 
|  | nn.ModuleList([Module("float16"), Module("float8")]), | 
|  | ] | 
|  | ) | 
|  | mutator = Mutator() | 
|  | mutator.visit("mod_list", mod_list) | 
|  |  | 
|  |  | 
|  | def test_mutator_module(): | 
|  | class SubModule1(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  |  | 
|  | class SubModule2(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  |  | 
|  | class Module(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.mod = SubModule1() | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_module(self, name: str, node: nn.Module) -> Any: | 
|  | if isinstance(node, SubModule1): | 
|  | return SubModule2() | 
|  | else: | 
|  | return node | 
|  |  | 
|  | mutator = Mutator() | 
|  | module = Module() | 
|  | assert isinstance(module.mod, SubModule1) | 
|  | module = mutator.visit("", module) | 
|  | assert isinstance(module.mod, SubModule2) | 
|  |  | 
|  |  | 
|  | def test_mutator_modulelist(): | 
|  | class Module1(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  |  | 
|  | class Module2(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  |  | 
|  | class Module3(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_module(self, name: str, node: nn.Module) -> Any: | 
|  | if isinstance(node, Module3): | 
|  | return Module1() | 
|  | else: | 
|  | return node | 
|  |  | 
|  | mutator = Mutator() | 
|  | module_list = nn.ModuleList([Module1(), Module2(), Module3()]) | 
|  | assert isinstance(module_list[0], Module1) | 
|  | assert isinstance(module_list[1], Module2) | 
|  | assert isinstance(module_list[2], Module3) | 
|  | module_list = mutator.visit("", module_list) | 
|  | print(module_list[2]) | 
|  | assert isinstance(module_list[0], Module1) | 
|  | assert isinstance(module_list[1], Module2) | 
|  | assert isinstance(module_list[2], Module1) | 
|  |  | 
|  |  | 
|  | def test_mutator_effect(): | 
|  | class Effect1(nn.Effect): | 
|  | pass | 
|  |  | 
|  | class Effect2(nn.Effect): | 
|  | pass | 
|  |  | 
|  | class Module(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.effect = Effect1() | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_effect(self, name: str, node: nn.Effect) -> Any: | 
|  | if isinstance(node, Effect1): | 
|  | return Effect2() | 
|  |  | 
|  | mutator = Mutator() | 
|  | module = Module() | 
|  | assert isinstance(module.effect, Effect1) | 
|  | module = mutator.visit("", module) | 
|  | assert isinstance(module.effect, Effect2) | 
|  |  | 
|  |  | 
|  | def test_mutator_param(): | 
|  | class Module(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.weight = nn.Parameter((128, 64), "float16") | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_param(self, name: str, node: nn.Parameter) -> Any: | 
|  | if node.dtype == "float16": | 
|  | return nn.Parameter(node.shape, "float32") | 
|  |  | 
|  | mutator = Mutator() | 
|  | module = Module() | 
|  | assert module.weight.dtype == "float16" | 
|  | module = mutator.visit("", module) | 
|  | assert module.weight.dtype == "float32" | 
|  |  | 
|  |  | 
|  | def test_mutator_recursively(): | 
|  | class SubModule(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.weight = nn.Parameter((128, 64), "float16") | 
|  |  | 
|  | class Module(nn.Module): | 
|  | def __init__(self) -> None: | 
|  | super().__init__() | 
|  | self.mod = SubModule() | 
|  |  | 
|  | class Mutator(nn.Mutator): | 
|  | def visit_param(self, name: str, node: nn.Parameter) -> Any: | 
|  | if node.dtype == "float16": | 
|  | return nn.Parameter(node.shape, "float32") | 
|  |  | 
|  | mutator = Mutator() | 
|  | module = Module() | 
|  | assert module.mod.weight.dtype == "float16" | 
|  | module = mutator.visit("", module) | 
|  | assert module.mod.weight.dtype == "float32" | 
|  |  | 
|  |  | 
|  | if __name__ == "__main__": | 
|  | tvm.testing.main() |