blob: e4b1be1e21ec01d8dca302e3ea3525eaa5621310 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
# pylint: skip-file
import mxnet as mx
import os
import sys
from common import with_environment
from mxnet.test_utils import environment
# Size shared by the tests below: passed as ``num_hidden`` to the
# FullyConnected layers and used as the trailing dimension of the input
# shapes handed to ``_simple_bind``.
num_hidden = 4096
@with_environment('MXNET_MEMORY_OPT', '1')
def test_rnn_cell():
    """Bind an add-then-tanh graph fed by two fully connected branches.

    The test contains no assertions: it passes when ``_simple_bind``
    completes on CPU. It is decorated with
    ``with_environment('MXNET_MEMORY_OPT', '1')``, which presumably sets
    that environment variable for the duration of the test (confirm in
    ``common``).
    """
    # x →→→ + →→→ tanh ⇒⇒⇒
    #       ↑
    # y →→→→
    #
    # ⇒⇒⇒ : Backward Dependency
    # Mirroring neither the elementwise-add operator nor the tanh operator
    # brings any benefit in this example.

    # Symbols are created in the order x-variable, x-FC, y-variable, y-FC
    # so that automatically generated operator names stay the same.
    x_branch = mx.sym.FullyConnected(mx.sym.Variable("x"),
                                     num_hidden=num_hidden)
    y_branch = mx.sym.FullyConnected(mx.sym.Variable("y"),
                                     num_hidden=num_hidden)
    summed = mx.sym._internal._plus(x_branch, y_branch)
    activated = mx.sym.Activation(summed, act_type='tanh')

    input_shapes = {'x': (num_hidden,), 'y': (num_hidden,)}
    executor = activated._simple_bind(mx.cpu(), 'write', **input_shapes)
@with_environment('MXNET_MEMORY_OPT', '1')
def test_mlp_attn():
    """Bind a group of tanh(broadcast_add(x, y_t<i>)) outputs sharing ``x``.

    Five step inputs ``y_t0`` .. ``y_t4`` of shape ``(1, num_hidden)`` are
    each broadcast-added to the single input ``x`` of shape
    ``(num_steps, num_hidden)`` and passed through tanh; the resulting
    activations are grouped and bound on CPU. The test contains no
    assertions: it passes when ``_simple_bind`` completes. It is decorated
    with ``with_environment('MXNET_MEMORY_OPT', '1')``, which presumably
    sets that environment variable for the duration of the test (confirm
    in ``common``).
    """
    # x →→→ + →→→ tanh ⇒⇒⇒
    #       ↑ + →→→ tanh ⇒⇒⇒
    # y_1 →→  ↑ + →→→ tanh ⇒⇒⇒
    # y_2 →→→→  ↑  ⋱
    # y_3 →→→→→→     + →→→ tanh ⇒⇒⇒
    #                ↑
    # y_n →→→→→→→→→→→
    #
    # (In the code the y inputs are named y_t0 .. y_t4.)
    x = mx.sym.Variable("x")
    num_steps = 5
    in_arg_shapes = {'x': (num_steps, num_hidden)}
    activations = []
    for step in range(num_steps):
        y_name = f"y_t{step}"
        y_step = mx.sym.Variable(y_name)
        added = mx.sym.broadcast_add(x, y_step, name=f"broadcast_add{step}")
        activations.append(
            mx.sym.Activation(added, act_type='tanh',
                              name=f"activation{step}"))
        in_arg_shapes[y_name] = (1, num_hidden)

    grouped = mx.sym.Group(activations)
    executor = grouped._simple_bind(mx.cpu(), 'write', **in_arg_shapes)
@with_environment('MXNET_MEMORY_OPT', '1')
def test_fc():
    """Bind two chained tanh activations followed by a fully connected layer.

    The test contains no assertions: it passes when ``_simple_bind``
    completes on CPU. It is decorated with
    ``with_environment('MXNET_MEMORY_OPT', '1')``, which presumably sets
    that environment variable for the duration of the test (confirm in
    ``common``).
    """
    # x →→→ tanh ⇒⇒⇒ tanh ⇒⇒⇒ FC
    #                →→→ tanh_ →→→
    #                             ↓
    #                            FC'
    #
    # NOTE(review): tanh_ and FC' presumably denote a recomputed tanh and
    # the backward pass of FC; the diagram alignment is approximate —
    # confirm against the upstream file.
    data = mx.sym.Variable("x")
    inner_tanh = mx.sym.Activation(data, act_type='tanh', name='y')
    outer_tanh = mx.sym.Activation(inner_tanh, act_type='tanh', name='z')
    projected = mx.sym.FullyConnected(outer_tanh, num_hidden=num_hidden)

    input_shapes = {'x': (num_hidden,)}
    executor = projected._simple_bind(mx.cpu(), 'write', **input_shapes)