# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
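"""Tests of low-precision (bfloat16) AMP conversion and the ONEDNN_AMP fuse pass on oneDNN subgraphs."""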
import json
import mxnet as mx
from mxnet import amp
from mxnet.gluon import nn
from mxnet.test_utils import assert_almost_equal
from subgraph_common import SG_PASS_NAME, QUANTIZE_SG_PASS_NAME
from test_matmul_subgraph import MultiHeadAttention
import os
import sys
from pathlib import Path
curr_path = Path(__file__).resolve().parent
sys.path.insert(0, str(curr_path.parent.parent))
from amp.common import check_amp_net_stats
AMP_SG_PASS_NAME = 'ONEDNN_AMP'
AMP_DTYPE = 'bfloat16'


def same_graph_structure(symnet_observed, symnet_expected, expected):
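    """Compare the (op, inputs) structure of two symbol graphs.

    `expected` states whether the two graphs are supposed to match.
    """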
nodes_obs = json.loads(symnet_observed.tojson(remove_amp_cast=False))['nodes']
nodes_exp = json.loads(symnet_expected.tojson(remove_amp_cast=False))['nodes']
nodes_obs = [(node['op'], node['inputs']) for node in nodes_obs]
nodes_exp = [(node['op'], node['inputs']) for node in nodes_exp]
assert (len(nodes_obs) == len(nodes_exp)) == expected
for node_obs, node_exp in zip(nodes_obs, nodes_exp):
if node_obs != node_exp:
assert expected == False, '\n'.join([f'{n1} vs {n2}' for n1, n2 in zip(nodes_obs, nodes_exp)])
break


def check_amp_fuse(net, data_example, expected_sym=None, quantized_nodes=[], rtol=0.05):
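    """Run `net` in float32 and after AMP conversion to AMP_DTYPE (plus the ONEDNN_AMP backend pass),
    then compare the outputs. If `expected_sym` is given, the structure of the converted graph is
    checked against it. `quantized_nodes` lists symbol names to exclude from the AMP conversion.
    """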
net.hybridize()
out_ref = net(*data_example)
net.optimize_for(data_example, backend=SG_PASS_NAME) # amp pass works only on oneDNN nodes
lp_net = amp.convert_hybrid_block(net, data_example, target_dtype=AMP_DTYPE,
excluded_sym_names=quantized_nodes, cast_params_offline=True,
device=mx.current_context())
lp_net.optimize_for(data_example, backend=AMP_SG_PASS_NAME)
out_lp_net = lp_net(*data_example)
# check outputs
out_ref = [out_ref] if not isinstance(out_ref, list) else out_ref
    out_lp_net = [out_lp_net] if not isinstance(out_lp_net, list) else out_lp_net
for ref_out, lp_out in zip(out_ref, out_lp_net):
assert_almost_equal(ref_out, lp_out, rtol=rtol, atol=1.0)
# check graph
if expected_sym is not None:
lp_symnet = lp_net.export(None, remove_amp_cast=False)[0]
same_graph_structure(lp_symnet, expected_sym, True)


@mx.util.use_np
def test_amp_fc():
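    """Check AMP fusion of two stacked Dense layers, with and without excluding the second FullyConnected node."""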
class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
self.fc2 = nn.Dense(16)
def forward(self, x):
x = self.fc1(x)
x = self.fc2(x)
return x
net = TestNet()
net.initialize()
exp_data = mx.symbol.Variable('data')
exp_weight = [mx.symbol.Variable('weight{}'.format(i)) for i in range(2)]
exp_bias = [mx.symbol.Variable('bias{}'.format(i)) for i in range(2)]
exp_sym = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
for weight, bias in zip(exp_weight, exp_bias):
exp_sym = mx.symbol.FullyConnected(exp_sym, weight, bias, num_hidden=1)
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
data_example = mx.np.random.uniform(-1, 1, (1, 8))
check_amp_fuse(net, [data_example], exp_sym)
check_amp_fuse(net, [data_example], exp_sym, ['sg_onednn_fully_connected_1'])


@mx.util.use_np
def test_amp_conv():
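    """Check AMP fusion of two stacked Conv2D layers, with and without excluding the second convolution node."""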
class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.conv1 = nn.Conv2D(16, (3, 3))
self.conv2 = nn.Conv2D(16, (3, 3))
def forward(self, x):
x = self.conv1(x)
x = self.conv2(x)
return x
net = TestNet()
net.initialize()
data_example = mx.np.random.uniform(-1, 1, (1, 3, 8, 8))
exp_data = mx.symbol.Variable('data')
exp_weight = [mx.symbol.Variable('weight{}'.format(i)) for i in range(2)]
exp_bias = [mx.symbol.Variable('bias{}'.format(i)) for i in range(2)]
exp_sym = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
for weight, bias in zip(exp_weight, exp_bias):
exp_sym = mx.symbol.Convolution(exp_sym, weight, bias, kernel=(3, 3), num_filter=1)
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym)
exp_sym = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
for weight, bias in zip(exp_weight, exp_bias):
exp_sym = mx.symbol.Convolution(exp_sym, weight, bias, kernel=(3, 3), num_filter=1)
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym, ['sg_onednn_conv_1'])


@mx.util.use_np
def test_amp_transformers():
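    """Check AMP conversion of a multi-head attention block, with and without excluding the first
    FullyConnected node (output comparison only; no expected graph)."""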
batch_size = 16
seq_length = 32
units = 8
num_heads = 8
in_data = mx.np.random.uniform(size=(batch_size, seq_length, units), dtype='float32')
mask = mx.np.random.randint(0, 2, (batch_size, seq_length, seq_length), dtype='int32')
net = MultiHeadAttention(units, num_heads)
net.initialize()
check_amp_fuse(net, [in_data, mask], None)
check_amp_fuse(net, [in_data, mask], None, ['sg_onednn_fully_connected_0'])


@mx.util.use_np
def test_amp_concat():
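    """Check AMP fusion of two Dense branches joined by a concat, with and without excluding the
    second FullyConnected node."""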
class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
self.fc2 = nn.Dense(16)
def forward(self, x):
x1 = self.fc1(x)
x2 = self.fc2(x)
            x = mx.np.concatenate((x1, x2), axis=1)
return x
net = TestNet()
net.initialize()
data_example = mx.np.random.uniform(-1, 1, (1, 16))
exp_data = mx.symbol.Variable('data')
exp_amp_data = mx.symbol.amp_cast(exp_data, dtype=AMP_DTYPE)
exp_weight = [mx.symbol.Variable(f"fc{i}_weight") for i in range(2)]
exp_bias = [mx.symbol.Variable(f"fc{i}_bias") for i in range(2)]
exp_fc = [mx.symbol.FullyConnected(exp_amp_data, exp_weight[i], exp_bias[i], num_hidden=1)
for i in range(2)]
exp_sym = mx.symbol.Concat(*exp_fc)
exp_sym = mx.symbol.amp_cast(exp_sym, dtype='float32')
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym)
exp_fc[0] = mx.symbol.FullyConnected(exp_amp_data, exp_weight[0], exp_bias[0], num_hidden=1)
exp_fc[1] = mx.symbol.FullyConnected(exp_data, exp_weight[1], exp_bias[1], num_hidden=1)
exp_sym = mx.symbol.Concat(*exp_fc)
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym, ['sg_onednn_fully_connected_1'])


@mx.util.use_np
def test_amp_fuse_with_branch():
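    """Check amp_cast placement when an lp16 output feeds both another lp16 node and a float32 branch."""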
class TestNet(nn.HybridBlock):
def __init__(self, **kwargs):
super(TestNet, self).__init__(**kwargs)
self.fc1 = nn.Dense(16)
self.fc2 = nn.Dense(16)
def forward(self, x, *args):
out = self.fc1(x)
out1 = self.fc2(out)
out1 = nn.Activation('relu')(out1)
with nn.HybridBlock.OptConstraint.disable_amp():
out2 = mx.npx.softmax(out)
return out1, out2
net = TestNet()
net.initialize()
data_example = mx.np.ones((10,))
    #              |---> lp16_op_2
    # lp16_op_1 ---|
    #              |---> f32_amp_cast ---> f32_op
#
# `lp16_op_1` cannot fuse the `f32_amp_cast` node, since `lp16_op_2` already uses its lp16 output
exp_data = mx.sym.Variable('data')
exp_weight = [mx.symbol.Variable('weight{}'.format(i)) for i in range(2)]
exp_bias = [mx.symbol.Variable('bias{}'.format(i)) for i in range(2)]
amp_data = mx.sym.amp_cast(exp_data, dtype=AMP_DTYPE)
lp16_op_1 = mx.sym.FullyConnected(amp_data, exp_weight[0], exp_bias[0], num_hidden=16)
lp16_op_2 = mx.sym.FullyConnected(lp16_op_1, exp_weight[1], exp_bias[1], num_hidden=16)
f32_amp_cast = mx.sym.amp_cast(lp16_op_1, dtype='float32')
f32_op = mx.sym.softmax(f32_amp_cast)
exp_sym = mx.sym.Group([lp16_op_2, f32_op])
exp_sym = exp_sym.get_backend_symbol(SG_PASS_NAME)
check_amp_fuse(net, [data_example], exp_sym)


@mx.util.use_np
def test_amp_excluding_after_graph_pass():
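    """Check that nodes excluded via OptConstraint.disable_amp() stay excluded from AMP conversion
    even after a backend graph pass has introduced new (fused) nodes."""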
class TestNet(nn.HybridBlock):
def __init__(self):
super(TestNet, self).__init__()
self.fc1 = nn.Dense(16)
self.fc2 = nn.Dense(16)
def forward(self, x):
x = self.fc1(x)
with nn.HybridBlock.OptConstraint.disable_amp():
x = self.fc2(x)
return x
data_example = mx.np.random.uniform(-1, 1, (1, 8))
net = TestNet()
net.initialize()
net_before = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
check_amp_net_stats(AMP_DTYPE, net_before, data_example, lp16_tensors_num=2,
lp16_casts_num=1, other_casts_num=1)
net.optimize_for(data_example, backend=SG_PASS_NAME) # introduces new nodes
net_after = amp.convert_hybrid_block(net, data_example, AMP_DTYPE, cast_params_offline=True)
check_amp_net_stats(AMP_DTYPE, net_after, data_example, lp16_tensors_num=2,
lp16_casts_num=1, other_casts_num=1)