| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
# Attach mxnet with library() so that a missing installation stops the run
# immediately; require() would only return FALSE with a warning and every
# test below would then fail with an unrelated "could not find function".
library(mxnet)

context("symbol")
| |
# Checks argument/output listing for a stacked network, for a network
# composed onto another with mx.apply(), and for a grouped symbol.
test_that("basic symbol operation", {
  input <- mx.symbol.Variable("data")
  stacked <- mx.symbol.FullyConnected(
    data = input,
    name = "fc1",
    num_hidden = 10
  )
  stacked <- mx.symbol.FullyConnected(
    data = stacked,
    name = "fc2",
    num_hidden = 100
  )

  stacked_args <- c("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")
  expect_equal(arguments(stacked), stacked_args)
  expect_equal(outputs(stacked), "fc2_output")

  # Second network: "fc3" is created without a data input, which is supplied
  # afterwards through the fc3_data argument of mx.apply().
  head_net <- mx.symbol.FullyConnected(name = "fc3", num_hidden = 10)
  head_net <- mx.symbol.Activation(data = head_net, act_type = "relu")
  head_net <- mx.symbol.FullyConnected(
    data = head_net,
    name = "fc4",
    num_hidden = 20
  )

  composed <- mx.apply(head_net, fc3_data = stacked, name = "composed")

  expect_equal(
    arguments(composed),
    c(stacked_args, "fc3_weight", "fc3_bias", "fc4_weight", "fc4_bias")
  )
  expect_equal(outputs(composed), "composed_output")

  # Grouping keeps the outputs of both members, in the order given.
  grouped <- mx.symbol.Group(c(composed, stacked))
  expect_equal(outputs(grouped), c("composed_output", "fc2_output"))
})
| |
# Checks that an intermediate layer looked up through internals() lists the
# same arguments as the symbol it was originally built from.
test_that("symbol internal", {
  input <- mx.symbol.Variable("data")
  first_fc <- mx.symbol.FullyConnected(
    data = input,
    name = "fc1",
    num_hidden = 10
  )
  full_net <- mx.symbol.FullyConnected(
    data = first_fc,
    name = "fc2",
    num_hidden = 100
  )

  expect_equal(
    arguments(full_net),
    c("data", "fc1_weight", "fc1_bias", "fc2_weight", "fc2_bias")
  )

  # Locate the fc1 node by its output name among the internal symbols.
  all_nodes <- internals(full_net)
  fc1_index <- match("fc1_output", all_nodes$outputs)
  recovered_fc <- all_nodes[[fc1_index]]

  expect_equal(arguments(recovered_fc), arguments(first_fc))
})
| |
# Checks children() one and two levels below a network, the get.children()
# accessor, and the children of a multi-output SliceChannel symbol.
test_that("symbol children", {
  input <- mx.symbol.Variable("data")
  first_fc <- mx.symbol.FullyConnected(
    data = input,
    name = "fc1",
    num_hidden = 10
  )
  full_net <- mx.symbol.FullyConnected(
    data = first_fc,
    name = "fc2",
    num_hidden = 100
  )

  expect_equal(
    outputs(children(full_net)),
    c("fc1_output", "fc2_weight", "fc2_bias")
  )
  expect_equal(
    outputs(children(children(full_net))),
    c("data", "fc1_weight", "fc1_bias")
  )

  # Same lookup through the method form, then pick one child by output name.
  kids <- full_net$get.children()
  weight_index <- match("fc2_weight", kids$outputs)
  expect_equal(kids[[weight_index]]$arguments, "fc2_weight")

  slice_input <- mx.symbol.Variable("data")
  sliced <- mx.symbol.SliceChannel(slice_input, num_outputs = 3, name = "slice")
  expect_equal(outputs(children(sliced)), "data")
})
| |
# Checks that mx.symbol.infer.shape() yields NULL when only some of the
# input shapes are supplied.
test_that("symbol infer type", {
  hidden_units <- 128
  input_dim <- 64
  sample_count <- 10

  input <- mx.symbol.Variable("data")
  prev_state <- mx.symbol.Variable("prevstate")
  x2h <- mx.symbol.FullyConnected(
    data = input,
    name = "x2h",
    num_hidden = hidden_units
  )
  h2h <- mx.symbol.FullyConnected(
    data = prev_state,
    name = "h2h",
    num_hidden = hidden_units
  )

  summed <- mx.symbol.elemwise_add(x2h, h2h)
  out <- mx.symbol.Activation(data = summed, name = "out", act_type = "relu")

  # Only the shape of "data" is given; nothing is supplied for "prevstate",
  # the input of h2h, so shape inference is expected to fail.
  inferred <- mx.symbol.infer.shape(out, data = c(input_dim, sample_count))

  expect_equal(inferred, NULL)
})
| |
# Checks that a symbol written with mx.symbol.save() and read back with
# mx.symbol.load() serialises to the same JSON as the original.
test_that("symbol save/load", {
  data <- mx.symbol.Variable("data")
  fc1 <- mx.symbol.FullyConnected(data, num_hidden = 1)
  lro <- mx.symbol.LinearRegressionOutput(fc1)

  # Use a session-temporary path instead of a fixed name in the working
  # directory, and remove it in `finally` so the file is cleaned up even when
  # saving, loading or the comparison raises an error.
  sym_file <- tempfile(pattern = "tmp_r_sym", fileext = ".json")
  tryCatch(
    {
      mx.symbol.save(lro, sym_file)
      reloaded <- mx.symbol.load(sym_file)

      expect_equal(reloaded$as.json(), lro$as.json())
    },
    finally = unlink(sym_file)
  )
})
| |
# Checks that the `__shape__` attribute can be set both by replacing the
# whole attribute list and by assigning a single entry, and read back.
test_that("symbol attributes access", {
  shape_text <- "(1, 1, 1, 1)"

  # Replace the complete attribute list in one assignment.
  sym_x <- mx.symbol.Variable("x")
  sym_x$attributes <- list("__shape__" = shape_text)

  expect_equal(sym_x$attributes$`__shape__`, shape_text)

  # Assign a single attribute entry in place.
  sym_y <- mx.symbol.Variable("y")
  sym_y$attributes$`__shape__` <- shape_text

  expect_equal(sym_y$attributes$`__shape__`, shape_text)
})
| |
# Checks outputs, children and arguments of a concat symbol, and that a
# second concat built from the same inputs reports the same values.
test_that("symbol concat", {
  left <- mx.symbol.Variable("data1")
  right <- mx.symbol.Variable("data2")
  input_names <- c("data1", "data2")

  first <- mx.symbol.concat(
    data = c(left, right),
    num.args = 2,
    name = "concat"
  )
  expect_equal(outputs(first), "concat_output")
  expect_equal(outputs(children(first)), input_names)
  expect_equal(arguments(first), input_names)

  second <- mx.symbol.concat(
    data = c(left, right),
    num.args = 2,
    name = "concat"
  )
  expect_equal(outputs(first), outputs(second))
  expect_equal(outputs(children(first)), outputs(children(second)))
  expect_equal(arguments(first), arguments(second))
})