/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.gan
import org.kohsuke.args4j.{CmdLineParser, Option}
import org.slf4j.LoggerFactory
import scala.collection.JavaConverters._
import Viz._
import ml.dmlc.mxnet.Context
import ml.dmlc.mxnet.Shape
import ml.dmlc.mxnet.IO
import ml.dmlc.mxnet.NDArray
import ml.dmlc.mxnet.CustomMetric
import ml.dmlc.mxnet.Xavier
import ml.dmlc.mxnet.optimizer.Adam
import ml.dmlc.mxnet.DataBatch
import ml.dmlc.mxnet.Symbol
/**
* @author Depeng Liang
*/
object GanMnist {
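// Example invocation (a sketch; how the class is launched depends on your build
// and classpath setup, and the paths below are placeholders):
//   GanMnist --mnist-data-path /path/to/mnist --output-path ./gan_output --gpu 0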
private val logger = LoggerFactory.getLogger(classOf[GanMnist])
// a deconvolution (transposed convolution) layer that enlarges the feature map
// to the spatial size given by the last two dimensions of oShape
def deconv2D(data: Symbol, iShape: Shape, oShape: Shape,
kShape: (Int, Int), name: String, stride: (Int, Int) = (2, 2)): Symbol = {
val targetShape = (oShape(oShape.length - 2), oShape(oShape.length - 1))
val net = Symbol.Deconvolution(name)()(Map(
"data" -> data,
"kernel" -> s"$kShape",
"stride" -> s"$stride",
"target_shape" -> s"$targetShape",
"num_filter" -> oShape(0),
"no_bias" -> true))
net
}
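// deconvolution block with batch norm and ReLU, used for the intermediate generator layers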
def deconv2DBnRelu(data: Symbol, prefix: String, iShape: Shape,
oShape: Shape, kShape: (Int, Int), eps: Float = 1e-5f + 1e-12f): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.BatchNorm(s"${prefix}_bn")()(Map("data" -> net, "fix_gamma" -> true, "eps" -> eps))
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> "relu"))
net
}
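// deconvolution followed by a configurable activation, used for the generator's output layer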
def deconv2DAct(data: Symbol, prefix: String, actType: String,
iShape: Shape, oShape: Shape, kShape: (Int, Int)): Symbol = {
var net = deconv2D(data, iShape, oShape, kShape, name = s"${prefix}_deconv")
net = Symbol.Activation(s"${prefix}_act")()(Map("data" -> net, "act_type" -> actType))
net
}
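// Builds the DCGAN symbols: a generator that upsamples the random code "rand"
// into images of shape oShape, and a discriminator that scores 28 x 28 images
// with a logistic loss ("dloss"). Returns (generator output, discriminator loss).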
def makeDcganSym(oShape: Shape, ngf: Int = 128, finalAct: String = "sigmoid",
eps: Float = 1e-5f + 1e-12f): (Symbol, Symbol) = {
val code = Symbol.Variable("rand")
var net = Symbol.FullyConnected("g1")()(Map("data" -> code,
"num_hidden" -> 4 * 4 * ngf * 4, "no_bias" -> true))
net = Symbol.Activation("gact1")()(Map("data" -> net, "act_type" -> "relu"))
// 4 x 4
net = Symbol.Reshape()()(Map("data" -> net, "shape" -> s"(-1, ${ngf * 4}, 4, 4)"))
// 8 x 8
net = deconv2DBnRelu(net, prefix = "g2",
iShape = Shape(ngf * 4, 4, 4), oShape = Shape(ngf * 2, 8, 8), kShape = (3, 3))
// 14x14
net = deconv2DBnRelu(net, prefix = "g3",
iShape = Shape(ngf * 2, 8, 8), oShape = Shape(ngf, 14, 14), kShape = (4, 4))
// 28x28
val gout = deconv2DAct(net, prefix = "g4", actType = finalAct, iShape = Shape(ngf, 14, 14),
oShape = Shape(oShape.toArray.takeRight(3)), kShape = (4, 4))
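// the discriminator: a LeNet-style CNN over 28 x 28 inputs ending in a logistic loss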
val data = Symbol.Variable("data")
// 28 x 28
val conv1 = Symbol.Convolution("conv1")()(Map("data" -> data,
"kernel" -> "(5,5)", "num_filter" -> 20))
val tanh1 = Symbol.Activation()()(Map("data" -> conv1, "act_type" -> "tanh"))
val pool1 = Symbol.Pooling()()(Map("data" -> tanh1,
"pool_type" -> "max", "kernel" -> "(2,2)", "stride" -> "(2,2)"))
// second conv
val conv2 = Symbol.Convolution("conv2")()(Map("data" -> pool1,
"kernel" -> "(5,5)", "num_filter" -> 50))
val tanh2 = Symbol.Activation()()(Map("data" -> conv2, "act_type" -> "tanh"))
val pool2 = Symbol.Pooling()()(Map("data" -> tanh2, "pool_type" -> "max",
"kernel" -> "(2,2)", "stride" -> "(2,2)"))
var d5 = Symbol.Flatten()()(Map("data" -> pool2))
d5 = Symbol.FullyConnected("fc1")()(Map("data" -> d5, "num_hidden" -> 500))
d5 = Symbol.Activation()()(Map("data" -> d5, "act_type" -> "tanh"))
d5 = Symbol.FullyConnected("fc_dloss")()(Map("data" -> d5, "num_hidden" -> 1))
val dloss = Symbol.LogisticRegressionOutput("dloss")()(Map("data" -> d5))
(gout, dloss)
}
// evaluation metric: fraction of predictions that fall on the wrong side of the 0.5 threshold
def ferr(label: NDArray, pred: NDArray): Float = {
val predArr = pred.toArray.map(p => if (p > 0.5) 1f else 0f)
val labelArr = label.toArray
labelArr.zip(predArr).map { case (l, p) => Math.abs(l - p) }.sum / label.shape(0)
}
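// parses command-line options, builds and initializes the GAN modules,
// and runs adversarial training on MNIST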
def main(args: Array[String]): Unit = {
val anst = new GanMnist
val parser: CmdLineParser = new CmdLineParser(anst)
try {
parser.parseArgument(args.toList.asJava)
val dataPath = if (anst.mnistDataPath == null) System.getenv("MXNET_DATA_DIR")
else anst.mnistDataPath
assert(dataPath != null, "specify --mnist-data-path or set the MXNET_DATA_DIR environment variable")
val lr = 0.0005f
val beta1 = 0.5f
val batchSize = 100
val randShape = Shape(batchSize, 100)
val numEpoch = 100
val dataShape = Shape(batchSize, 1, 28, 28)
val context = if (anst.gpu == -1) Context.cpu() else Context.gpu(anst.gpu)
val (symGen, symDec) =
makeDcganSym(oShape = dataShape, ngf = 32, finalAct = "sigmoid")
val gMod = new GANModule(
symGen,
symDec,
context = context,
dataShape = dataShape,
codeShape = randShape)
gMod.initGParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initDParams(new Xavier(factorType = "in", magnitude = 2.34f))
gMod.initOptimizer(new Adam(learningRate = lr, wd = 0f, beta1 = beta1))
val params = Map(
"image" -> s"${dataPath}/train-images-idx3-ubyte",
"label" -> s"${dataPath}/train-labels-idx1-ubyte",
"input_shape" -> s"(1, 28, 28)",
"batch_size" -> s"$batchSize",
"shuffle" -> "True"
)
val mnistIter = IO.MNISTIter(params)
val metricAcc = new CustomMetric(ferr, "ferr")
var t = 0
var dataBatch: DataBatch = null
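// training loop: each batch drives one GANModule update, after which the metric
// accumulates the discriminator's classification error on fake (label 0) and
// real (label 1) outputs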
for (epoch <- 0 until numEpoch) {
mnistIter.reset()
metricAcc.reset()
t = 0
while (mnistIter.hasNext) {
dataBatch = mnistIter.next()
gMod.update(dataBatch)
gMod.dLabel.set(0f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsFake)
gMod.dLabel.set(1f)
metricAcc.update(Array(gMod.dLabel), gMod.outputsReal)
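// every 50 iterations, log the running metric and save snapshots of the generated
// images, the (normalized) diff buffer, and the real data batch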
if (t % 50 == 0) {
val (name, value) = metricAcc.get
logger.info(s"epoch: $epoch, iter $t, metric=$value")
Viz.imSave("gout", anst.outputPath, gMod.tempOutG(0), flip = true)
val diff = gMod.tempDiffD
val arr = diff.toArray
val mean = arr.sum / arr.length
val std = {
val tmpA = arr.map(a => (a - mean) * (a - mean))
Math.sqrt(tmpA.sum / tmpA.length).toFloat
}
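// standardize the diff buffer (zero mean, unit variance) and shift by 0.5 so it saves as a viewable image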
diff.set((diff - mean) / std + 0.5f)
Viz.imSave("diff", anst.outputPath, diff, flip = true)
Viz.imSave("data", anst.outputPath, dataBatch.data(0), flip = true)
}
t += 1
}
}
} catch {
case ex: Exception => {
logger.error(ex.getMessage, ex)
parser.printUsage(System.err)
sys.exit(1)
}
}
}
}
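// args4j command-line options; the values are read by the companion object's main after parsing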
class GanMnist {
@Option(name = "--mnist-data-path", usage = "the mnist data path")
private val mnistDataPath: String = null
@Option(name = "--output-path", usage = "the path to save the generated result")
private val outputPath: String = null
@Option(name = "--gpu", usage = "which gpu card to use, default is -1, means using cpu")
private val gpu: Int = -1
}