# blob: c93118a7db1f5a2045f88619c017a54a1742965f [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# Attach the mxnet package; the mx.symbol.* functions used by the tests below
# are expected to come from it.
# NOTE(review): require() only warns and returns FALSE when a package is
# missing, so later failures would be indirect; library(mxnet) would fail
# fast instead -- confirm with the test harness before changing.
require(mxnet)
# Declare the testthat context (reporting label) for the tests in this file.
context("symbol")
# Builds a two-layer fully connected chain, composes a second network on top
# of it, groups the results, and checks the argument/output names reported at
# each step.
test_that("basic symbol operation", {
  input <- mx.symbol.Variable("data")

  # First network: data -> fc1 -> fc2.
  fc1 <- mx.symbol.FullyConnected(data = input, name = "fc1", num_hidden = 10)
  base_net <- mx.symbol.FullyConnected(
    data = fc1,
    name = "fc2",
    num_hidden = 100
  )
  base_args <- c("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")
  expect_equal(arguments(base_net), base_args)
  expect_equal(outputs(base_net), "fc2_output")

  # Second network: fc3 (created without a data input) -> relu -> fc4.
  fc3 <- mx.symbol.FullyConnected(name = "fc3", num_hidden = 10)
  relu <- mx.symbol.Activation(data = fc3, act_type = "relu")
  head_net <- mx.symbol.FullyConnected(
    data = relu,
    name = "fc4",
    num_hidden = 20
  )

  # Supply the first network as the fc3_data input of the second one.
  composed <- mx.apply(head_net, fc3_data = base_net, name = "composed")
  head_args <- c("fc3_weight", "fc3_bias", "fc4_weight", "fc4_bias")
  expect_equal(arguments(composed), c(base_args, head_args))
  expect_equal(outputs(composed), "composed_output")

  # Grouping keeps the outputs of both symbols, in the order given.
  grouped <- mx.symbol.Group(c(composed, base_net))
  expect_equal(outputs(grouped), c("composed_output", "fc2_output"))
})
# Looks up the fc1 node among the internals of a two-layer network and checks
# that it lists the same arguments as the standalone fc1 symbol.
test_that("symbol internal", {
  input <- mx.symbol.Variable("data")
  fc1_ref <- mx.symbol.FullyConnected(
    data = input,
    name = "fc1",
    num_hidden = 10
  )
  net <- mx.symbol.FullyConnected(
    data = fc1_ref,
    name = "fc2",
    num_hidden = 100
  )
  expected_args <- c(
    "data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias"
  )
  expect_equal(arguments(net), expected_args)

  # Locate the internal node by its output name, then extract it by position.
  all_nodes <- internals(net)
  fc1_pos <- match("fc1_output", all_nodes$outputs)
  fc1_internal <- all_nodes[[fc1_pos]]
  expect_equal(arguments(fc1_internal), arguments(fc1_ref))
})
# Checks the output names of a symbol's children and grandchildren, child
# lookup through get.children(), and the children of a SliceChannel symbol.
test_that("symbol children", {
  input <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data = input, name = "fc1", num_hidden = 10)
  net <- mx.symbol.FullyConnected(data = fc1, name = "fc2", num_hidden = 100)

  # One level down, then two levels down.
  first_level <- children(net)
  expect_equal(
    outputs(first_level),
    c("fc1_output", "fc2_weight", "fc2_bias")
  )
  second_level <- children(children(net))
  expect_equal(outputs(second_level), c("data", "fc1_weight", "fc1_bias"))

  # Same lookup through the method form; pick the fc2_weight child by name.
  kids <- net$get.children()
  weight_pos <- match("fc2_weight", kids$outputs)
  expect_equal(kids[[weight_pos]]$arguments, "fc2_weight")

  # A SliceChannel built directly on a variable has that variable as its child.
  slice_input <- mx.symbol.Variable("data")
  sliced <- mx.symbol.SliceChannel(
    slice_input,
    num_outputs = 3,
    name = "slice"
  )
  expect_equal(outputs(children(sliced)), "data")
})
# Checks that mx.symbol.infer.shape() returns NULL when only one of the two
# input variables has a shape supplied.
# NOTE(review): the description says "infer type" although the body exercises
# shape inference; the description string is kept unchanged so existing test
# filters and reports still match.
test_that("symbol infer type", {
  num_hidden <- 128
  num_dim <- 64
  num_sample <- 10

  data <- mx.symbol.Variable("data")
  prev <- mx.symbol.Variable("prevstate")
  x2h <- mx.symbol.FullyConnected(
    data = data,
    name = "x2h",
    num_hidden = num_hidden
  )
  h2h <- mx.symbol.FullyConnected(
    data = prev,
    name = "h2h",
    num_hidden = num_hidden
  )
  summed <- mx.symbol.elemwise_add(x2h, h2h)
  out <- mx.symbol.Activation(data = summed, name = "out", act_type = "relu")

  # shape inference will fail because information is not available for h2h
  ret <- mx.symbol.infer.shape(out, data = c(num_dim, num_sample))
  # expect_null() states the intent directly and gives a clearer failure
  # message than comparing against NULL with expect_equal().
  expect_null(ret)
})
# Saves a small symbol graph to a JSON file, loads it back, and checks that the
# reloaded symbol serialises to the same JSON as the original.
test_that("symbol save/load", {
  data <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data, num_hidden = 1)
  lro <- mx.symbol.LinearRegressionOutput(fc1)

  # Write to a unique path under the session temp directory rather than the
  # working directory, so the test cannot clobber or leave behind a file there.
  sym_path <- tempfile(pattern = "tmp_r_sym", fileext = ".json")
  tryCatch(
    {
      mx.symbol.save(lro, sym_path)
      data2 <- mx.symbol.load(sym_path)
      expect_equal(data2$as.json(), lro$as.json())
    },
    # Remove the file even if save/load raises an error part-way through.
    finally = unlink(sym_path)
  )
})
# Checks that a symbol's `__shape__` attribute can be written and read back,
# both by replacing the whole attribute list and by setting a single entry.
test_that("symbol attributes access", {
  shape_attr <- "(1, 1, 1, 1)"

  # Replace the entire attribute list in one assignment.
  sym_x <- mx.symbol.Variable("x")
  sym_x$attributes <- list(`__shape__` = shape_attr)
  expect_equal(sym_x$attributes$`__shape__`, shape_attr)

  # Assign a single named attribute in place.
  sym_y <- mx.symbol.Variable("y")
  sym_y$attributes$`__shape__` <- shape_attr
  expect_equal(sym_y$attributes$`__shape__`, shape_attr)
})
# Concatenates two variables and checks the resulting output, children and
# argument names; then builds an identical concat and checks both agree.
test_that("symbol concat", {
  left <- mx.symbol.Variable("data1")
  right <- mx.symbol.Variable("data2")
  input_names <- c("data1", "data2")

  first <- mx.symbol.concat(
    data = c(left, right),
    num.args = 2,
    name = "concat"
  )
  expect_equal(outputs(first), "concat_output")
  expect_equal(outputs(children(first)), input_names)
  expect_equal(arguments(first), input_names)

  # A second concat built from the same inputs must report the same names.
  second <- mx.symbol.concat(
    data = c(left, right),
    num.args = 2,
    name = "concat"
  )
  expect_equal(outputs(first), outputs(second))
  expect_equal(outputs(children(first)), outputs(children(second)))
  expect_equal(arguments(first), arguments(second))
})