# -------------------------------------------------------------
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
# -------------------------------------------------------------
from random import randint
from typing import List

import jinja2
import onnx

import systemds.onnx_systemds.onnx_helper as onnx_helper
from systemds.onnx_systemds import util
class GeneratedScriptPart:
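    """
    A part of the generated DML script, consisting of the script snippet itself,
    the module imports it requires and any sub-graphs (e.g. the branches of an
    If node) that still need to be converted separately.
    """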
    def __init__(self, dml_script: str, imports: List[str] = None, sub_graphs: List[onnx.GraphProto] = None):
if sub_graphs is None:
sub_graphs = []
if imports is None:
imports = []
self.dml_script = dml_script
self.imports = imports
self.sub_graphs = sub_graphs
def gen_simple_function_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
node: onnx.NodeProto) -> GeneratedScriptPart:
"""
Generates a simple function call by directly providing the node inputs as arguments
and node outputs as outputs to a function call. Additionally adds the required imports.
:param env: Jinja environment to load the template files
:param graph: the onnx-graph for which the script shall be generated
:param node: the onnx-node for which the script shall be generated
:return: The generated script part
"""
    operator_template = env.get_template("operators/function_call.dml.jinja")
import_template = env.get_template("module_import.dml.jinja")
    if len(list(node.output)) != 1:
        raise Exception("Function call needs exactly one output")
    if len(node.attribute) != 0:
        raise Exception("Attributes are not supported by this generator")
required_import = {
"Relu": {"path": "/nn/layers/relu.dml", "import_name": "relu_layer", "function_name": "forward"},
"Tanh": {"path": "/nn/layers/tanh.dml", "import_name": "tanh_layer", "function_name": "forward"},
"Sigmoid": {"path": "/nn/layers/sigmoid.dml", "import_name": "sigmoid_layer", "function_name": "forward"},
"Softmax": {"path": "/nn/layers/softmax.dml", "import_name": "softmax_layer", "function_name": "forward"}
}
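    # For illustration: a "Relu" node with input ["X"] and output ["Y"] is expected
    # to render (via the templates above) roughly as the DML lines
    #   source("/nn/layers/relu.dml") as relu_layer
    #   Y = relu_layer::forward(X)
    # (the exact text is produced by the jinja templates, not fixed here)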
import_render = ""
function_name = node.op_type
function_namespace = ""
    if node.op_type in required_import:
module_import = required_import[node.op_type]
import_render = import_template.render(
path=module_import["path"],
name=module_import["import_name"]
)
function_name = module_import["function_name"]
function_namespace = module_import["import_name"]
node_render = operator_template.render(
function_namespace=function_namespace,
function=function_name,
arguments=list(node.input),
outputs=list(node.output),
doc_string=node.doc_string
)
return GeneratedScriptPart(imports=[import_render], dml_script=node_render)
def gen_2input_1output_operator(env: jinja2.environment.Environment, graph: onnx.GraphProto,
node: onnx.NodeProto) -> GeneratedScriptPart:
"""
Generates simple operator calls like 'z = x + y' which have two inputs (left and right) and one output.
:param env: Jinja environment to load the template files
:param graph: the onnx-graph for which the script shall be generated
:param node: the onnx-node for which the script shall be generated
:return: The generated script part
"""
operator = {
"Add": "+",
"Sub": "-",
"MatMul": "%*%",
"And": "&",
"Or": "|"
}
operator_template = env.get_template("operators/2input_1output_operator.dml.jinja")
    if len(node.attribute) != 0:
        raise Exception("Attributes are not supported for this operator")
    if len(list(node.input)) != 2 or len(list(node.output)) != 1:
        raise Exception("Operator needs exactly 2 inputs and 1 output")
node_render = operator_template.render(
input_0=list(node.input)[0],
input_1=list(node.input)[1],
output=list(node.output)[0],
operator=operator[node.op_type],
doc_string=node.doc_string
)
return GeneratedScriptPart(node_render)
def gen_1input_1output_mat_operator(env: jinja2.environment.Environment, graph: onnx.GraphProto,
node: onnx.NodeProto) -> GeneratedScriptPart:
"""
Generates simple operators like 'y = -x' which have one input and one output.
:param env: Jinja environment to load the template files
:param graph: the onnx-graph for which the script shall be generated
:param node: the onnx-node for which the script shall be generated
:return: The generated script part
"""
template_for_operator = {
"Neg": "neg.dml.jinja",
}
operator_template = env.get_template("operators/" + template_for_operator[node.op_type])
    if len(node.attribute) != 0:
        raise Exception("Attributes are not supported for this operator")
    if len(list(node.input)) != 1 or len(list(node.output)) != 1:
        raise Exception("Operator needs exactly 1 input and 1 output")
node_render = operator_template.render(
input=list(node.input)[0],
output=list(node.output)[0],
doc_string=node.doc_string
)
return GeneratedScriptPart(dml_script=node_render)
def gen_dropout_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
node: onnx.NodeProto) -> GeneratedScriptPart:
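    """
    Generates a dropout layer call for the onnx Dropout node using the forward
    function of the dropout layer of the SystemDS nn library.
    :param env: Jinja environment to load the template files
    :param graph: the onnx-graph for which the script shall be generated
    :param node: the onnx-node for which the script shall be generated
    :return: The generated script part
    """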
    operator_template = env.get_template("operators/function_call.dml.jinja")
import_template = env.get_template("module_import.dml.jinja")
function_namespace = "dropout_layer"
function_name = "forward"
path = "/nn/layers/dropout.dml"
# * Inputs:
# * - X: Inputs, of shape (any, any).
# * - p: Probability of keeping a neuron output.
# * - seed: [Optional: -1] Random number generator seed to allow for
# * deterministic evaluation. Set to -1 for a random seed.
# * Outputs:
# * - out: Outputs, of same shape as `X`.
# * - mask: Dropout mask used to compute the output.
X = list(node.input)[0]
p = 0.5
seed = -1
if len(list(node.attribute)) > 0:
attributes = list(node.attribute)
        if attributes[0].name != "ratio" or len(attributes) > 1:
            raise Exception("Invalid attributes for dropout call: " + str(attributes))
        # onnx 'ratio' is the probability of dropping a value, while the DML layer's
        # 'p' is the probability of keeping one, so the attribute has to be inverted
        p = 1 - attributes[0].f
import_render = import_template.render(
path=path,
name=function_namespace
)
node_render = operator_template.render(
function_namespace=function_namespace,
function=function_name,
arguments=[X, p, seed],
outputs=list(node.output),
doc_string=node.doc_string
)
return GeneratedScriptPart(imports=[import_render], dml_script=node_render)
def __compute_pad(auto_pad: str, Hf: int, Wf: int, strides: List[int], pads: List[int], Hin: int, Win: int):
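    """
    Computes the symmetric padding (padh, padw) from the given onnx padding attributes.
    Only symmetric padding is supported, therefore SAME_UPPER and SAME_LOWER are
    treated identically.
    :param auto_pad: value of the onnx auto_pad attribute
    :param Hf: filter height
    :param Wf: filter width
    :param strides: strides over (height, width)
    :param pads: explicit paddings [top, left, bottom, right]
    :param Hin: input height
    :param Win: input width
    :return: the padding over height (padh) and width (padw)
    """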
strideh = strides[0]
stridew = strides[1]
if auto_pad == "NOTSET":
padh = pads[0]
padw = pads[1]
if pads[0] != pads[2] or pads[1] != pads[3]:
raise Exception("Only support symmetric pads")
    elif auto_pad in ("SAME_UPPER", "SAME_LOWER"):
        # pad such that the output size matches the input size;
        # integer division keeps the (symmetric) padding integral
        padh = (Hin * (strideh - 1) + Hf - strideh) // 2
        padw = (Win * (stridew - 1) + Wf - stridew) // 2
elif auto_pad == "VALID":
# no padding
padh = 0
padw = 0
else:
raise Exception("Invalid auto_pad value")
return padh, padw
def gen_maxpool_call(env: jinja2.environment.Environment, graph: onnx.GraphProto,
node: onnx.NodeProto) -> GeneratedScriptPart:
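    """
    Generates a max-pooling call for the onnx MaxPool node (2D only) using the
    forward function of the max_pool2d layer of the SystemDS nn library.
    :param env: Jinja environment to load the template files
    :param graph: the onnx-graph for which the script shall be generated
    :param node: the onnx-node for which the script shall be generated
    :return: The generated script part
    """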
    operator_template = env.get_template("operators/function_call.dml.jinja")
import_template = env.get_template("module_import.dml.jinja")
function_namespace = "maxpool_layer"
function_name = "forward"
path = "/nn/layers/max_pool2d.dml"
# * Inputs:
# * - X: Inputs, of shape (N, C*Hin*Win).
# * - C: Number of input channels (dimensionality of input depth).
# * - Hin: Input height.
# * - Win: Input width.
# * - Hf: Filter height.
# * - Wf: Filter width.
# * - strideh: Stride over height.
# * - stridew: Stride over width.
# * - padh: Padding for top and bottom sides.
# * A typical value is 0.
# * - padw: Padding for left and right sides.
# * A typical value is 0.
# *
# * Outputs:
# * - out: Outputs, of shape (N, C*Hout*Wout).
# * - Hout: Output height.
# * - Wout: Output width.
if len(node.input) != 1:
raise Exception("Invalid number of inputs")
if len(node.output) < 1 or len(node.output) > 2:
raise Exception("Invalid number of outputs")
# Inputs
x = onnx_helper.get_value_info(graph, node.input[0])
# dimensions are (N x C x H x W), where N is the batch size, C is the number of channels,
# and H and W are the height and the width
x_shape = onnx_helper.get_valueinfo_dimensions(x)
    if len(x_shape) != 4:
raise NotImplementedError("Currently only MaxPool-2D supported")
batch_size = x_shape[0] # TODO: currently not used
C = x_shape[1]
Hin = x_shape[2]
Win = x_shape[3]
# Attributes
auto_pad = "NOTSET"
ceil_mode = 0
dilations = [1, 1]
kernel_shape = None
pads = [0, 0, 0, 0]
storage_order = 0
strides = [1, 1]
    for attribute in node.attribute:
        if attribute.name == "auto_pad":
            # auto_pad is a single STRING attribute, stored in the 's' field as bytes
            auto_pad = attribute.s.decode()
        elif attribute.name == "ceil_mode":
            ceil_mode = attribute.i
            if ceil_mode != 0:
                raise NotImplementedError("Currently no support for ceil_mode")
        elif attribute.name == "dilations":
            raise NotImplementedError("Currently no support for dilations")
        elif attribute.name == "kernel_shape":
            kernel_shape = attribute.ints
        elif attribute.name == "pads":
            pads = attribute.ints
        elif attribute.name == "storage_order":
            raise NotImplementedError("Currently no support for storage_order")
        elif attribute.name == "strides":
            strides = attribute.ints
        else:
            raise Exception("Invalid attribute: " + attribute.name)
if kernel_shape is None:
raise Exception("kernel_shape attribute is required")
Hf = kernel_shape[0]
Wf = kernel_shape[1]
strideh = strides[0]
stridew = strides[1]
padh, padw = __compute_pad(auto_pad, Hf, Wf, strides, pads, Hin, Win)
# Create render
node_render = operator_template.render(
function_namespace=function_namespace,
function=function_name,
arguments=[x.name, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw],
outputs=list(node.output),
doc_string=node.doc_string
)
import_render = import_template.render(
path=path,
name=function_namespace
)
return GeneratedScriptPart(imports=[import_render], dml_script=node_render)
def gen_conv_call(env: jinja2.environment.Environment, graph: onnx.GraphProto, node: onnx.NodeProto) \
-> GeneratedScriptPart:
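    """
    Generates a convolution call for the onnx Conv node (2D only) using the forward
    function of the conv2d layer of the SystemDS nn library. If the node provides
    no bias, a zero bias is generated.
    :param env: Jinja environment to load the template files
    :param graph: the onnx-graph for which the script shall be generated
    :param node: the onnx-node for which the script shall be generated
    :return: The generated script part
    """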
    operator_template = env.get_template("operators/function_call.dml.jinja")
import_template = env.get_template("module_import.dml.jinja")
function_namespace = "conv_layer"
function_name = "forward"
path = "/nn/layers/conv2d.dml"
# * Inputs:
# * - X: Inputs, of shape (N, C*Hin*Win).
# * - W: Weights, of shape (F, C*Hf*Wf).
# * - b: Biases, of shape (F, 1).
# * - C: Number of input channels (dimensionality of input depth).
# * - Hin: Input height.
# * - Win: Input width.
# * - Hf: Filter height.
# * - Wf: Filter width.
# * - strideh: Stride over height.
# * - stridew: Stride over width.
# * - padh: Padding for top and bottom sides.
# * - padw: Padding for left and right sides.
# *
# * Outputs:
# * - out: Outputs, of shape (N, F*Hout*Wout).
# * - Hout: Output height.
# * - Wout: Output width.
if len(node.input) < 2 or len(node.input) > 3:
raise Exception("Invalid number of inputs")
    if len(node.output) != 1:
        raise Exception("Invalid number of outputs")
# Inputs
x = onnx_helper.get_value_info(graph, node.input[0])
# size (N x C x H x W), where N is the batch size, C is the number of channels, and H and W are the height and width
x_shape = onnx_helper.get_valueinfo_dimensions(x)
    if len(x_shape) != 4:
raise NotImplementedError("Currently only Conv-2D supported")
batch_size = x_shape[0] # TODO: Batch size unused?
C = x_shape[1]
Hin = x_shape[2]
Win = x_shape[3]
w = onnx_helper.get_value_info(graph, node.input[1])
W_shape = onnx_helper.get_valueinfo_dimensions(w)
M = W_shape[0]
C_group = W_shape[1] # TODO Channels/group unused?
Hf = W_shape[2]
Wf = W_shape[3]
bias = None
bias_initializer_render = ""
if len(node.input) == 2:
# Generate 0-bias if no bias given
generated_bias_identifier = "gen_bias"
        while onnx_helper.get_value_info(graph, generated_bias_identifier) is not None:
            # append a random suffix to create a unique identifier if it already exists
            generated_bias_identifier += str(randint(0, 1000000))
bias_init_template = env.get_template("matrix_initialize.dml.jinja")
bias_initializer_render = bias_init_template.render(
identifier_name=generated_bias_identifier,
initializer_values=[0] * M,
rows=M,
cols=1
)
bias = generated_bias_identifier
    elif len(node.input) == 3:
        bias = node.input[2]
# Attributes
auto_pad = "NOTSET"
dilations = [1, 1]
group = 1
pads = [0, 0, 0, 0]
strides = [1, 1]
    for attribute in node.attribute:
        if attribute.name == "auto_pad":
            # auto_pad is a single STRING attribute, stored in the 's' field as bytes
            auto_pad = attribute.s.decode()
        elif attribute.name == "dilations":
            raise NotImplementedError("Currently no support for dilations")
        elif attribute.name == "group":
            group = attribute.i
        elif attribute.name == "kernel_shape":
            kernel_shape = attribute.ints
            if kernel_shape[0] != Hf or kernel_shape[1] != Wf:
                raise Exception("Invalid kernel shape")
        elif attribute.name == "pads":
            pads = attribute.ints
        elif attribute.name == "strides":
            strides = attribute.ints
        else:
            raise Exception("Invalid attribute: " + attribute.name)
strideh = strides[0]
stridew = strides[1]
padh, padw = __compute_pad(auto_pad, Hf, Wf, strides, pads, Hin, Win)
node_render = operator_template.render(
function_namespace=function_namespace,
function=function_name,
arguments=[x.name, w.name, bias, C, Hin, Win, Hf, Wf, strideh, stridew, padh, padw],
outputs=list(node.output),
doc_string=node.doc_string
)
import_render = import_template.render(
path=path,
name=function_namespace
)
return GeneratedScriptPart(imports=[import_render], dml_script=bias_initializer_render + "\n" + node_render)
def gen_if_call(env: jinja2.environment.Environment, graph: onnx.GraphProto, node: onnx.NodeProto) \
-> GeneratedScriptPart:
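    """
    Generates an if/else block for the onnx If node. The then- and else-branch
    sub-graphs are returned with the script part such that they can be converted
    into separate functions, which the generated block calls.
    :param env: Jinja environment to load the template files
    :param graph: the onnx-graph for which the script shall be generated
    :param node: the onnx-node for which the script shall be generated
    :return: The generated script part
    """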
operator_template = env.get_template("operators/if_operator.dml.jinja")
function_call_template = env.get_template("operators/function_call.dml.jinja")
if len(node.input) != 1:
raise Exception("Wrong number of inputs")
    if len(node.attribute) != 2:
        raise Exception("Wrong number of attributes")
    # look the branches up by name instead of relying on the attribute order
    attributes = {attribute.name: attribute for attribute in node.attribute}
    if "else_branch" not in attributes or "then_branch" not in attributes:
        raise Exception("Wrong attributes")
    else_graph = attributes["else_branch"].g
    then_graph = attributes["then_branch"].g
else_call = function_call_template.render(
doc_string="",
function_namespace="",
function=util.generate_function_name(else_graph.name),
arguments=[i.name for i in list(else_graph.input)],
outputs=[o.name for o in list(else_graph.output)],
)
then_call = function_call_template.render(
doc_string="",
function_namespace="",
function=util.generate_function_name(then_graph.name),
arguments=[i.name for i in list(then_graph.input)],
outputs=[o.name for o in list(then_graph.output)],
)
sub_graphs = [else_graph, then_graph]
node_render = operator_template.render(
cond=node.input[0],
then_function_call=then_call,
else_function_call=else_call
)
return GeneratedScriptPart(dml_script=node_render, sub_graphs=sub_graphs)