# blob: 25f515a106a631dcbf73d4b78b9396069067cf79 [file] [log] [blame]
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import mxnet as mx
from common import *
def detect_cycle_from(sym, visited, stack):
    """Run a depth-first search from ``sym`` and report whether it hits a cycle.

    Nodes are identified by ``sym.handle.value``.

    NOTE(review): this assumes ``handle.value`` is the same for a given graph
    node across separate ``get_children()`` / ``get_internals()`` calls --
    confirm against the MXNet Symbol implementation.

    Args:
        sym: Node to start from. Must expose ``handle.value`` and
            ``get_children()``.
        visited: Set of node ids already explored; updated in place.
        stack: Set of node ids on the current DFS path; updated in place.
            It is left populated when a cycle is found.

    Returns:
        True if a back edge (an edge to a node on the current DFS path) is
        reachable from ``sym``, False otherwise.
    """
    node_id = sym.handle.value
    visited.add(node_id)
    stack.add(node_id)

    children = sym.get_children()
    # A node without inputs may yield None instead of an empty collection.
    # Test with `is None` rather than truthiness so that objects which do not
    # support bool() are still handled.
    if children is not None:
        for child in children:
            child_id = child.handle.value
            if child_id not in visited:
                # Recurse into the *child*; recursing into `sym` again would
                # never terminate.
                if detect_cycle_from(child, visited, stack):
                    return True
            elif child_id in stack:
                # Already visited and still on the current path: back edge.
                return True

    # Leave the current path only after every child has been explored.
    stack.remove(node_id)
    return False


def has_no_cycle(sym):
    """Return True if no cycle is found among the internals of ``sym``.

    Starts a depth-first search (``detect_cycle_from``) from every node
    yielded by ``sym.get_internals()`` that has not been explored yet.

    Args:
        sym: Object exposing ``get_internals()``, whose items expose
            ``handle.value`` and ``get_children()``.

    Returns:
        False as soon as any search reports a cycle, True otherwise.
    """
    visited = set()
    stack = set()
    for node in sym.get_internals():
        # Only start a search from nodes that earlier searches did not reach.
        if node.handle.value not in visited:
            if detect_cycle_from(node, visited, stack):
                return False
    return True
def test_simple_cycle():
    """Check that the TensorRT-optimized symbol of a diamond graph has no cycle.

    The graph is: input -> A (FullyConnected); A feeds both B (FullyConnected)
    and D (sin); B and D are joined by C (elemwise_add). The symbol is bound
    on GPU 0, the optimized symbol is fetched from the executor, and the
    result is checked with ``has_no_cycle``.
    """
    source = mx.sym.Variable('input', shape=[1,10])
    fc_a = mx.sym.FullyConnected(data=source, num_hidden=10, no_bias=False, name='A')
    fc_b = mx.sym.FullyConnected(data=fc_a, num_hidden=10, no_bias=False, name='B')
    sin_d = mx.sym.sin(data=fc_a, name='D')
    joined = mx.sym.elemwise_add(lhs=fc_b, rhs=sin_d, name='C')

    # Zero-initialised buffers handed to simple_bind as `shared_buffer`.
    shared_params = dict([
        ('I_weight', mx.nd.zeros([10,10])),
        ('I_bias', mx.nd.zeros([10])),
        ('A_weight', mx.nd.zeros([10,10])),
        ('A_bias', mx.nd.zeros([10])),
        ('B_weight', mx.nd.zeros([10,10])),
        ('B_bias', mx.nd.zeros([10])),
    ])

    bound = joined.simple_bind(
        ctx=mx.gpu(0),
        data=(1,10),
        softmax_label=(1,),
        shared_buffer=shared_params,
        grad_req='null',
        force_rebind=True,
    )
    trt_symbol = mx.contrib.tensorrt.get_optimized_symbol(bound)
    assert has_no_cycle(trt_symbol), "The graph optimized by TRT contains a cycle"
# When executed as a script, let nose discover and run the tests in this module.
if __name__ == '__main__':
    from nose import runmodule
    runmodule()