blob: 51d203e68287375b24dc4bd6e04e12051ba8af4b [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package ml.dmlc.mxnetexamples.gan
import org.opencv.core.Core
import org.opencv.highgui.Highgui
import org.opencv.imgproc.Imgproc
import ml.dmlc.mxnet.NDArray
import org.opencv.core.Mat
import org.opencv.core.CvType
import java.util.ArrayList
import org.opencv.core.Size
/**
 * Helpers for dumping a batch of generated images to disk as a single JPEG grid.
 */
object Viz {
  // Load the OpenCV native library (via the nu.pattern wrapper) once, when Viz is first used.
  nu.pattern.OpenCV.loadShared()

  /**
   * Scales floats by 255 and clamps them to [0, 255], then converts them to bytes.
   *
   * Values in 128..255 wrap to negative `Byte`s. The bit pattern is kept, and that
   * is what the unsigned CV_8U Mats below are filled with. NaN maps to 0.
   */
  private def clip(x: Array[Float]): Array[Byte] =
    x.map { v =>
      val scaled = v * 255f
      val clamped = if (scaled < 0f) 0 else if (scaled > 255f) 255 else scaled.toInt
      clamped.toByte
    }

  /**
   * Builds one 8-bit OpenCV image from planar (channel-major) bytes.
   *
   * @param rawData  `channels` planes of `height * width` bytes each
   * @param channels 1 for gray; otherwise the first three planes are used as R, G, B
   *                 (any further planes are ignored)
   * @param flip     if true, the image is flipped around the horizontal axis (upside down)
   * @return a CV_8U (gray) or 3-channel BGR Mat owned by the caller
   */
  private def getImg(rawData: Array[Byte],
      channels: Int, height: Int, width: Int, flip: Boolean): Mat = {
    val totals = height * width
    val img = if (channels > 1) { // rgb image
      // The planes arrive as R, G, B, but OpenCV stores colour images as BGR,
      // so they are merged in reverse order.
      val planes = new ArrayList[Mat]()
      for (ch <- Seq(2, 1, 0)) {
        val plane = new Mat(height, width, CvType.CV_8U)
        plane.put(0, 0, rawData.slice(ch * totals, (ch + 1) * totals))
        planes.add(plane)
      }
      val merged = new Mat()
      Core.merge(planes, merged)
      // merge copied the data, so the per-channel planes can be freed now.
      for (i <- 0 until planes.size) planes.get(i).release()
      merged
    } else { // gray image
      val gray = new Mat(height, width, CvType.CV_8U)
      gray.put(0, 0, rawData)
      gray
    }
    if (flip) {
      val flipped = new Mat()
      Core.flip(img, flipped, 0)
      img.release()
      flipped
    } else img
  }

  /**
   * Lays `n` images out in a grid `floor(sqrt(n))` columns wide.
   *
   * Each image is `c` planes of `h * w` bytes. When `n` is not a multiple of the
   * column count, the last row is padded with black images. `Core.vconcat`
   * requires every row strip to have the same width.
   */
  private def tile(rawData: Array[Byte], n: Int, c: Int, h: Int, w: Int,
      flip: Boolean): Mat = {
    val perImage = c * h * w
    val col = math.max(1, math.sqrt(n).toInt)
    val row = (n + col - 1) / col
    // Zero bytes decode to black cells; there is no padding when n % col == 0.
    val padded = rawData.padTo(row * col * perImage, 0.toByte)

    val lineMats = new ArrayList[Mat]()
    for (line <- padded.grouped(col * perImage)) {
      val cells = new ArrayList[Mat]()
      for (cell <- line.grouped(perImage)) cells.add(getImg(cell, c, h, w, flip))
      val lineMat = new Mat()
      Core.hconcat(cells, lineMat)
      for (i <- 0 until cells.size) cells.get(i).release()
      lineMats.add(lineMat)
    }
    val result = new Mat()
    Core.vconcat(lineMats, result)
    for (i <- 0 until lineMats.size) lineMats.get(i).release()
    result
  }

  /**
   * Writes a batch of images as one grid to `outputPath/title.jpg`.
   *
   * @param title      file name, without the `.jpg` extension
   * @param outputPath directory to write into
   * @param x          4-D array of shape (n, c, h, w). Values are expected in [0, 1];
   *                   they are scaled by 255 and clamped.
   * @param flip       flip every image upside down before tiling
   * @param scale      resize factor applied to the whole grid before saving
   * @throws IllegalArgumentException if the batch is empty, has 2 channels, or
   *                                  `scale` is not positive
   */
  def imSave(title: String, outputPath: String, x: NDArray, flip: Boolean = false,
      scale: Double = 1.5): Unit = {
    val shape = x.shape
    assert(shape.length == 4, s"expected a 4-D (n, c, h, w) NDArray, got shape $shape")
    val (n, c, h, w) = (shape(0), shape(1), shape(2), shape(3))
    require(n > 0, "cannot save an empty batch")
    require(c == 1 || c >= 3, s"expected 1 (gray) or at least 3 (RGB) channels, got $c")
    require(scale > 0, s"scale must be positive, got $scale")

    val grid = tile(clip(x.toArray), n, c, h, w, flip)
    val resized = new Mat
    Imgproc.resize(grid, resized, new Size(grid.width() * scale, grid.height() * scale))
    grid.release()

    val path = s"$outputPath/$title.jpg"
    // imwrite reports failure by returning false. Warn rather than throw, so a
    // failed dump does not abort the caller.
    if (!Highgui.imwrite(path, resized)) {
      System.err.println(s"Viz.imSave: failed to write $path")
    }
    resized.release()
  }
}