# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.

import pytest

import tvm
import tvm.testing
from tvm import relax, TVMError
from tvm.ir.base import assert_structural_equal
from tvm.relax.training import SetupTrainer
from tvm.relax.training.optimizer import SGD, MomentumSGD
from tvm.relax.training.loss import MSELoss
from tvm.script import ir as I, relax as R


def test_simple():
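    """Check SetupTrainer on a minimal backbone: one input, one parameter,
    MSE loss with sum reduction, and plain SGD (lr = 0.1)."""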
    # fmt: off
    @I.ir_module
    class Backbone:
        I.module_attrs({"param_num": 1, "state_num": 0})

        @R.function
        def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64")):
            with R.dataflow():
                x1 = x + y
                R.output(x1)
            return x1

    @I.ir_module
    class Expected:
        I.module_attrs({"input_num": 1, "param_num": 1, "state_num": 0})

        @R.function
        def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((2, 2), dtype="float64"):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                R.output(x1)
            return x1
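
        # `backbone_loss` is the backbone with the MSE loss appended:
        # gv = sum((x1 - targets) * (x1 - targets)).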
        @R.function
        def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tensor((), dtype="float64"):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
                lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
                gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
                R.output(gv)
            return gv
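
        # `backbone_loss_adjoint` is produced by reverse-mode AD over
        # `backbone_loss`: it returns the loss together with the gradient
        # of the single parameter `y`.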
        @R.function
        def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
                lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
                gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
                gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
                lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
                lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
                lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
                lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
                x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
                y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
                y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
                R.output(gv, y_adjoint_out)
            return (gv, (y_adjoint_out,))
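
        # One SGD step: y_new = y - lr * y_grad, with the step counter
        # carried in `optim_states`.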
        @R.function
        def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"))):
            with R.dataflow():
                num_steps: R.Tensor((), dtype="int64") = optim_states[0]
                num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
                y: R.Tensor((2, 2), dtype="float64") = params[0]
                y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
                lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_grad)
                y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv)
                params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,)
                optim_states_new: R.Tuple(R.Tensor((), dtype="int64")) = (num_steps_new,)
                R.output(params_new, optim_states_new)
            return (params_new, optim_states_new)
    # fmt: on

    sinfo = relax.TensorStructInfo((2, 2), "float64")
    setup_trainer = SetupTrainer(MSELoss(reduction="sum"), SGD(0.1), [sinfo, sinfo], legalize=False)
    train_mod = setup_trainer(Backbone)
    assert_structural_equal(train_mod.without_attr("optim_state"), Expected)


def test_states():
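    """Same as test_simple, but the backbone also carries one model state (z)
    that is updated on every forward pass, and the optimizer is MomentumSGD."""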
    # fmt: off
    @I.ir_module
    class Backbone:
        I.module_attrs({"param_num": 1, "state_num": 1})

        @R.function
        def backbone(x: R.Tensor((2, 2), "float64"), y: R.Tensor((2, 2), "float64"), z: R.Tensor((2, 2), "float64")):
            with R.dataflow():
                x1 = x + y
                z1 = z + R.const(1, "float64")
                R.output(x1, z1)
            return x1, z1

    @I.ir_module
    class Expected:
        I.module_attrs({"input_num": 1, "param_num": 1, "state_num": 1})

        @R.function
        def backbone(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((2, 2), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1, "float64"))
                R.output(x1, z1)
            return (x1, z1)
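
        # The loss function also threads the updated state `z1` through,
        # returning (loss, new_state).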
        @R.function
        def backbone_loss(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1, "float64"))
                lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
                lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
                gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
                R.output(z1, gv)
            return (gv, z1)
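
        # The adjoint returns ((loss, new_state), (gradients,)); no gradient
        # is computed for the state `z`.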
        @R.function
        def backbone_loss_adjoint(x: R.Tensor((2, 2), dtype="float64"), y: R.Tensor((2, 2), dtype="float64"), z: R.Tensor((2, 2), dtype="float64"), targets: R.Tensor((2, 2), dtype="float64")) -> R.Tuple(R.Tuple(R.Tensor((), dtype="float64"), R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((2, 2), dtype="float64"))):
            with R.dataflow():
                x1: R.Tensor((2, 2), dtype="float64") = R.add(x, y)
                z1: R.Tensor((2, 2), dtype="float64") = R.add(z, R.const(1, "float64"))
                lv: R.Tensor((2, 2), dtype="float64") = R.subtract(x1, targets)
                lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv, lv)
                gv: R.Tensor((), dtype="float64") = R.sum(lv1, axis=None, keepdims=False)
                gv_adjoint: R.Tensor((), dtype="float64") = R.ones(R.shape([]), dtype="float64")
                lv1_adjoint: R.Tensor((2, 2), dtype="float64") = R.broadcast_to(gv_adjoint, R.shape([2, 2]))
                lv_adjoint: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
                lv_1: R.Tensor((2, 2), dtype="float64") = R.multiply(lv1_adjoint, lv)
                lv_adjoint1: R.Tensor((2, 2), dtype="float64") = R.add(lv_adjoint, lv_1)
                x1_adjoint: R.Tensor((2, 2), dtype="float64") = lv_adjoint1
                y_adjoint: R.Tensor((2, 2), dtype="float64") = x1_adjoint
                y_adjoint_out: R.Tensor((2, 2), dtype="float64") = y_adjoint
                R.output(z1, gv, y_adjoint_out)
            return ((gv, z1), (y_adjoint_out,))
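
        # One momentum SGD step: y_v_new = momentum * y_v + y_grad, then
        # y_new = y - lr * y_v_new (lr = momentum = 0.1 here).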
        @R.function
        def optimizer(params: R.Tuple(R.Tensor((2, 2), dtype="float64")), gradients: R.Tuple(R.Tensor((2, 2), dtype="float64")), optim_states: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))) -> R.Tuple(R.Tuple(R.Tensor((2, 2), dtype="float64")), R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64"))):
            with R.dataflow():
                num_steps: R.Tensor((), dtype="int64") = optim_states[0]
                num_steps_new: R.Tensor((), dtype="int64") = R.add(num_steps, R.const(1, "int64"))
                y: R.Tensor((2, 2), dtype="float64") = params[0]
                y_grad: R.Tensor((2, 2), dtype="float64") = gradients[0]
                y_v: R.Tensor((2, 2), dtype="float64") = optim_states[1]
                lv: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v)
                y_v_new: R.Tensor((2, 2), dtype="float64") = R.add(lv, y_grad)
                lv1: R.Tensor((2, 2), dtype="float64") = R.multiply(R.const(0.10000000000000001, "float64"), y_v_new)
                y_new: R.Tensor((2, 2), dtype="float64") = R.subtract(y, lv1)
                params_new: R.Tuple(R.Tensor((2, 2), dtype="float64")) = (y_new,)
                optim_states_new: R.Tuple(R.Tensor((), dtype="int64"), R.Tensor((2, 2), dtype="float64")) = (num_steps_new, y_v_new)
                R.output(params_new, optim_states_new)
            return (params_new, optim_states_new)
    # fmt: on

    sinfo = relax.TensorStructInfo((2, 2), "float64")
    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"), MomentumSGD(0.1, 0.1), [sinfo, sinfo], legalize=False
    )
    train_mod = setup_trainer(Backbone)
    assert_structural_equal(train_mod.without_attr("optim_state"), Expected)


def test_invalid_mod():
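    """SetupTrainer should reject a module that lacks the required
    param_num/state_num attributes or a function named `backbone`."""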
    @I.ir_module
    class NoAttr:
        @R.function
        def backbone(
            w0: R.Tensor((10, 5), "float32"),
            b0: R.Tensor((5,), "float32"),
            x: R.Tensor((1, 10), "float32"),
        ):
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                gv = R.add(lv0, b0)
                out = R.nn.relu(gv)
                R.output(gv, out)
            return gv, out

    pred_sinfo = relax.TensorStructInfo((1, 5), "float32")
    setup_trainer = SetupTrainer(
        MSELoss(reduction="sum"),
        SGD(0.001),
        [pred_sinfo, pred_sinfo],
    )
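
    # NoAttr is missing the param_num/state_num module attributes.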
    with pytest.raises((TVMError, ValueError)):
        setup_trainer(NoAttr)

    @I.ir_module
    class WrongFuncName:
        @R.function
        def main(
            w0: R.Tensor((10, 5), "float32"),
            b0: R.Tensor((5,), "float32"),
            x: R.Tensor((1, 10), "float32"),
        ):
            with R.dataflow():
                lv0 = R.matmul(x, w0)
                lv1 = R.add(lv0, b0)
                out = R.nn.relu(lv1)
                R.output(out)
            return out
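
    # There is no function named `backbone` here, only `main`, so setup
    # should fail.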
    with pytest.raises(ValueError):
        setup_trainer(WrongFuncName)


if __name__ == "__main__":
    tvm.testing.main()