blob: 49c66a9f8aed61ff80422d30587906746da59e5c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnet
import scala.util.parsing.json._
import java.io.File
import java.io.PrintWriter
import scala.collection.mutable.ArrayBuffer
object Visualization {
/**
* A simplified implementation of the python-Graphviz library functionality
* based on: https://github.com/xflr6/graphviz/tree/master/graphviz
*/
class Dot(name: String) {
  // Layout engines accepted by `render`.
  // http://www.graphviz.org/cgi-bin/man?dot
  private val ENGINES = Set(
    "dot", "neato", "twopi", "circo", "fdp", "sfdp", "patchwork", "osage"
  )
  // Output formats accepted by `render`.
  // http://www.graphviz.org/doc/info/output.html
  private val FORMATS = Set(
    "bmp",
    "canon", "dot", "gv", "xdot", "xdot1.2", "xdot1.4",
    "cgimage",
    "cmap",
    "eps",
    "exr",
    "fig",
    "gd", "gd2",
    "gif",
    "gtk",
    "ico",
    "imap", "cmapx",
    "imap_np", "cmapx_np",
    "ismap",
    "jp2",
    "jpg", "jpeg", "jpe",
    "pct", "pict",
    "pdf",
    "pic",
    "plain", "plain-ext",
    "png",
    "pov",
    "ps",
    "ps2",
    "psd",
    "sgi",
    "svg", "svgz",
    "tga",
    "tif", "tiff",
    "tk",
    "vml", "vmlz",
    "vrml",
    "wbmp",
    "webp",
    "xlib",
    "x11"
  )
  // First line of the generated source; the graph name is emitted verbatim (unquoted).
  private val _head = "digraph %s{".format(name)
  // Line templates: "<id> <attribute list>" and "<tail> -> <head> <attribute list>".
  private val _node = "\t%s %s"
  private val _edge = "\t\t%s -> %s %s"
  private val _tail = "}"
  // Node and edge statements, in the order they were added.
  private val _body = ArrayBuffer[String]()

  /**
   * Build a bracketed DOT attribute list.
   * The label and all attribute values are emitted verbatim, so any value that needs
   * quoting must already be quoted by the caller.
   * @param label Optional caption; omitted from the list when null.
   * @param attrs Attribute name -> value pairs.
   * @return The attribute list, e.g. `[label=x  shape=box]`.
   */
  private def attribute(label: String = null, attrs: Map[String, String]): String = {
    // Each pair carries its own leading space, which keeps the output identical to
    // what earlier versions of this class produced.
    val rendered = attrs.map { case (key, value) => s" $key=$value" }.mkString
    Option(label) match {
      case Some(caption) => s"[label=$caption $rendered]"
      case None => s"[$rendered]"
    }
  }

  /**
   * Create a node.
   * @param name Unique identifier for the node inside the source.
   * @param label Caption to be displayed (defaults to the node name).
   * @param attrs Any additional node attributes (must be strings).
   */
  def node(name: String, label: String = null, attrs: Map[String, String]): Unit = {
    _body += _node.format(name, attribute(label, attrs))
  }

  /**
   * Create an edge between two nodes.
   * @param tailName Start node identifier.
   * @param headName End node identifier.
   * @param label Caption to be displayed near the edge.
   * @param attrs Any additional edge attributes (must be strings).
   */
  def edge(tailName: String, headName: String,
      label: String = null, attrs: Map[String, String]): Unit = {
    _body += _edge.format(tailName, headName, attribute(label, attrs))
  }

  /**
   * Write the DOT source (head, body statements, tail) to `directory/filename`.
   * Missing parent directories are created first.
   * @return The path of the written file.
   */
  private def save(filename: String, directory: String): String = {
    val path = s"$directory${File.separator}$filename"
    // PrintWriter does not create directories, so make sure the target one exists.
    Option(new File(path).getParentFile).foreach(_.mkdirs())
    // Graphviz reads its input as UTF-8 unless told otherwise, so do not depend on
    // the platform default charset.
    val writer = new PrintWriter(path, "UTF-8")
    try {
      // scalastyle:off println
      writer.println(_head)
      _body.foreach { line => writer.println(line) }
      writer.println(_tail)
      writer.flush()
      // scalastyle:on println
    } finally {
      writer.close()
    }
    path
  }

  /**
   * Build the argument vector for a Graphviz invocation.
   * The command is returned as separate arguments rather than one string, because
   * `scala.sys.process` splits a plain string on whitespace and would therefore break
   * file paths that contain spaces.
   */
  private def command(engine: String, format: String, filepath: String): Seq[String] = {
    require(ENGINES.contains(engine), s"unknown engine: $engine")
    require(FORMATS.contains(format), s"unknown format: $format")
    Seq(engine, s"-T$format", "-O", filepath)
  }

  /**
   * Render file with Graphviz engine into format.
   * The DOT source is saved first, then the engine is run on it.
   * @param engine The layout command used for rendering ('dot', 'neato', ...).
   * @param format The output format used for rendering ('pdf', 'png', ...).
   * @param fileName Name of the DOT source file to render.
   * @param path Path to save the Dot source file.
   * @throws IllegalArgumentException if the engine or the format is unknown.
   * @throws RuntimeException if the command cannot be started or exits with a
   *                          non-zero status.
   */
  def render(engine: String = "dot", format: String = "pdf",
      fileName: String, path: String): Unit = {
    import scala.sys.process._
    import scala.util.control.NonFatal
    val filePath = this.save(fileName, path)
    val args = command(engine, format, filePath)
    val shown = args.mkString(" ")
    val exitCode =
      try {
        args.!
      } catch {
        // Only non-fatal failures (typically an IOException because the executable
        // is missing) are translated; the original cause is kept for diagnosis.
        case NonFatal(cause) =>
          val errorMsg = s"""failed to execute "$shown", """ +
            "make sure the Graphviz executables are on your system's path"
          throw new RuntimeException(errorMsg, cause)
      }
    // The process started but reported a failure; do not let it pass silently.
    if (exitCode != 0) {
      throw new RuntimeException(s"""command "$shown" exited with code $exitCode""")
    }
  }
}
/**
* convert shape string to list, internal use only
* @param str shape string
* @return list of string to represent shape
*/
def str2Tuple(str: String): List[String] = {
  // Collect every maximal run of decimal digits, in order of appearance;
  // all other characters (brackets, commas, spaces, ...) are ignored.
  val digitRun = """\d+""".r
  (for (found <- digitRun.findAllMatchIn(str)) yield found.matched).toList
}
/**
* convert symbol to Dot object for visualization
* @param symbol symbol to be visualized
* @param title title of the dot graph
* @param shape Map from input name to its shape; when given, edges are labelled with the inferred shapes
* @param nodeAttrs Map of node's attributes
* for example:
* nodeAttrs = Map("shape" -> "oval", "fixedsize" -> "false")
* means to plot the network in "oval"
* @param hideWeights
* if true (default) then inputs with names like `*_weight`
* or `*_bias` will be hidden
* @return Dot object of symbol
*/
def plotNetwork(symbol: Symbol,
    title: String = "plot", shape: Map[String, Shape] = null,
    nodeAttrs: Map[String, String] = Map[String, String](),
    hideWeights: Boolean = true): Dot = {
  // Output name -> inferred shape for every internal output of the symbol.
  // None when no input shapes were supplied; edges then carry no shape label.
  val shapeDict: Option[Map[String, Shape]] = Option(shape).map { inputShapes =>
    val internals = symbol.getInternals()
    val (_, outShapes, _) = internals.inferShape(inputShapes)
    require(outShapes != null, "Input shape is incomplete")
    internals.listOutputs().zip(outShapes).toMap
  }

  val conf = JSON.parseFull(symbol.toJson)
    .map(_.asInstanceOf[Map[String, Any]])
    .getOrElse(throw new IllegalArgumentException(
      "requirement failed: the symbol's JSON representation could not be parsed"))
  require(conf.contains("nodes"), "the symbol's JSON representation has no \"nodes\" entry")
  // Edges refer to their inputs by position, so keep the nodes in an indexed
  // sequence (a List would make every lookup linear).
  val nodes: Vector[Map[String, Any]] =
    conf("nodes").asInstanceOf[List[Any]].map(_.asInstanceOf[Map[String, Any]]).toVector

  // default attributes of a node, overridden by whatever the caller supplied
  val baseAttrs: Map[String, String] = Map("shape" -> "box", "fixedsize" -> "true",
    "width" -> "1.3", "height" -> "0.8034", "style" -> "filled") ++ nodeAttrs
  val dot = new Dot(name = title)
  // color map; the values are already quoted because Dot emits them verbatim
  val cm = Vector("\"#8dd3c7\"", "\"#fb8072\"", "\"#ffffb3\"",
    "\"#bebada\"", "\"#80b1d3\"", "\"#fdb462\"",
    "\"#b3de69\"", "\"#fccde5\"")
  def fill(index: Int): Map[String, String] = Map("fillcolor" -> cm(index))

  // Build a quoted, possibly multi-line DOT label. Lines are joined with the DOT
  // escape sequence `\n` (a backslash followed by 'n'), so every statement of the
  // generated source stays on a single physical line.
  def quoted(lines: String*): String = lines.mkString("\"", "\\n", "\"")

  // Internal helper to figure out if node should be hidden with hideWeights
  val weightSuffixes =
    Seq("_weight", "_bias", "_beta", "_gamma", "_moving_var", "_moving_mean")
  def looksLikeWeight(name: String): Boolean =
    weightSuffixes.exists(suffix => name.endsWith(suffix))

  // make nodes
  val hiddenNodes = scala.collection.mutable.Set[String]()
  nodes.foreach { params =>
    val op = params("op").asInstanceOf[String]
    val name = params("name").asInstanceOf[String]
    val attrs = params.get("attr")
      .map(_.asInstanceOf[Map[String, String]])
      .getOrElse(Map.empty[String, String])
    if (op == "null" && looksLikeWeight(name)) {
      // Weight-like inputs never get an explicit node. They are only recorded as
      // hidden (which also suppresses their edges) when hideWeights is set;
      // otherwise their edges are still emitted further down.
      if (hideWeights) hiddenNodes += name
    } else {
      val (label, extra): (String, Map[String, String]) = op match {
        case "null" =>
          // inputs get their own shape
          (name, Map("shape" -> "oval") ++ fill(0))
        case "Convolution" =>
          val kernel = str2Tuple(attrs("kernel"))
          val stride = if (attrs.contains("stride")) str2Tuple(attrs("stride")) else List("1")
          val detail =
            s"${kernel.mkString("x")}/${stride.mkString("x")}, ${attrs("num_filter")}"
          (quoted("Convolution", detail), fill(1))
        case "FullyConnected" =>
          (quoted("FullyConnected", attrs("num_hidden")), fill(1))
        case "BatchNorm" =>
          (name, fill(3))
        case "Activation" =>
          (quoted(op, attrs("act_type")), fill(2))
        case "LeakyReLU" =>
          // NOTE(review): act_type appears to be optional for LeakyReLU, so a symbol
          // built without it has no such attribute; fall back to a generic caption
          // instead of failing. Confirm against the operator documentation.
          (quoted(op, attrs.getOrElse("act_type", "Leaky")), fill(2))
        case "Pooling" =>
          val kernel = str2Tuple(attrs("kernel"))
          val stride = if (attrs.contains("stride")) str2Tuple(attrs("stride")) else List("1")
          val detail =
            s"${attrs("pool_type")}, ${kernel.mkString("x")}/${stride.mkString("x")}"
          (quoted("Pooling", detail), fill(4))
        case "Concat" | "Flatten" | "Reshape" =>
          (name, fill(5))
        case "Softmax" =>
          (name, fill(6))
        case _ =>
          // custom operators are captioned with their registered type
          (if (op == "Custom") attrs("op_type") else name, fill(7))
      }
      dot.node(name = name, label = label, attrs = baseAttrs ++ extra)
    }
  }

  // add edges
  val edgeAttrs = Map("dir" -> "back", "arrowtail" -> "open")
  nodes.foreach { params =>
    val op = params("op").asInstanceOf[String]
    val name = params("name").asInstanceOf[String]
    if (op != "null") {
      // JSON numbers are parsed as Double; each entry starts with the position of
      // the producing node, followed by the index of the output that is consumed.
      val inputs = params("inputs").asInstanceOf[List[List[Double]]]
      for (item <- inputs) {
        val inputNode = nodes(item(0).toInt)
        val inputName = inputNode("name").asInstanceOf[String]
        if (!hiddenNodes.contains(inputName)) {
          // add shapes
          val shapeLabel = shapeDict.map { shapes =>
            val key =
              if (inputNode("op").asInstanceOf[String] == "null") {
                inputName
              } else {
                val multiOutput = inputNode.get("attr")
                  .exists(_.asInstanceOf[Map[String, String]].contains("num_outputs"))
                // Pick the output by the index recorded in the input entry. The
                // earlier per-consumer countdown could select the wrong output, or
                // run below zero when one node consumed several multi-output nodes.
                if (multiOutput) {
                  s"${inputName}_output${item.lift(1).map(_.toInt).getOrElse(0)}"
                } else {
                  s"${inputName}_output"
                }
              }
            // the leading dimension is left out of the caption
            // (presumably the batch axis -- TODO confirm)
            "label" -> shapes(key).toArray.drop(1).mkString("\"", "x", "\"")
          }
          dot.edge(tailName = name, headName = inputName,
            attrs = edgeAttrs ++ shapeLabel.toList)
        }
      }
    }
  }
  dot
}
}