# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
from __future__ import print_function

import argparse
import math
import os
import shutil
import sys
import zipfile
from os import path

import numpy as np

import mxnet as mx
from mxnet import gluon, autograd as ag
from mxnet.gluon import nn
from mxnet.image import CenterCropAug, ResizeAug
from mxnet.io import PrefetchingIter
from mxnet.test_utils import download

# make the shared `data` module (one directory up) importable
this_dir = path.abspath(path.dirname(__file__))
sys.path.append(path.join(this_dir, path.pardir))
from data import ImagePairIter

# CLI
parser = argparse.ArgumentParser(description='Super-resolution using an efficient sub-pixel convolution neural network.')
parser.add_argument('--upscale_factor', type=int, default=3, help="super resolution upscale factor. default is 3.")
parser.add_argument('--batch_size', type=int, default=4, help='training batch size, per device. default is 4.')
parser.add_argument('--test_batch_size', type=int, default=100, help='test batch size. default is 100.')
parser.add_argument('--epochs', type=int, default=30, help='number of training epochs. default is 30.')
parser.add_argument('--lr', type=float, default=0.001, help='learning rate. default is 0.001.')
parser.add_argument('--use-gpu', action='store_true', help='whether to use GPU.')
parser.add_argument('--seed', type=int, default=123, help='random seed to use. default is 123.')
parser.add_argument('--resolve_img', type=str, help='path of an image to super-resolve; if given, inference is run instead of training.')
opt = parser.parse_args()
print(opt)
upscale_factor = opt.upscale_factor
batch_size, test_batch_size = opt.batch_size, opt.test_batch_size
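# 0 -> load images as single-channel (grayscale); at inference time the network
# is applied to the luminance plane only (see resolve() below)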
color_flag = 0
# Get data from https://github.com/BIDS/BSDS500/
# The BSDS500 Dataset is copyright Berkeley Computer Vision Group
# For more details, see https://www2.eecs.berkeley.edu/Research/Projects/CS/vision/grouping/resources.html#bsds500
datasets_dir = path.expanduser(path.join("~", ".mxnet", "datasets"))
datasets_tmpdir = path.join(datasets_dir, "tmp")
dataset_url = "https://github.com/BIDS/BSDS500/archive/master.zip"
data_dir = path.expanduser(path.join(datasets_dir, "BSDS500"))
tmp_dir = path.join(data_dir, "tmp")
def get_dataset(prefetch=False):
    """Download the BSDS500 dataset and return train and test iters."""

    if path.exists(data_dir):
        print(
            "Directory {} already exists, skipping.\n"
            "To force download and extraction, delete the directory and re-run."
            "".format(data_dir),
            file=sys.stderr,
        )
    else:
        print("Downloading dataset...", file=sys.stderr)
        downloaded_file = download(dataset_url, dirname=datasets_tmpdir)
        print("done", file=sys.stderr)

        print("Extracting files...", end="", file=sys.stderr)
        os.makedirs(data_dir)
        os.makedirs(tmp_dir)
        with zipfile.ZipFile(downloaded_file) as archive:
            archive.extractall(tmp_dir)
        shutil.rmtree(datasets_tmpdir)

        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "images"),
            path.join(data_dir, "images"),
        )
        shutil.copytree(
            path.join(tmp_dir, "BSDS500-master", "BSDS500", "data", "groundTruth"),
            path.join(data_dir, "groundTruth"),
        )
        shutil.rmtree(tmp_dir)
        print("done", file=sys.stderr)
    crop_size = 256
    crop_size -= crop_size % upscale_factor
    input_crop_size = crop_size // upscale_factor

    input_transform = [CenterCropAug((crop_size, crop_size)), ResizeAug(input_crop_size)]
    target_transform = [CenterCropAug((crop_size, crop_size))]

    iters = (
        ImagePairIter(
            path.join(data_dir, "images", "train"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
        ImagePairIter(
            path.join(data_dir, "images", "test"),
            (input_crop_size, input_crop_size),
            (crop_size, crop_size),
            test_batch_size,
            color_flag,
            input_transform,
            target_transform,
        ),
    )
    return [PrefetchingIter(i) for i in iters] if prefetch else iters
train_data, val_data = get_dataset()
mx.np.random.seed(opt.seed)
device = [mx.gpu(0)] if opt.use_gpu else [mx.cpu()]
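
# The model below is the efficient sub-pixel CNN of Shi et al. (2016), which the
# CLI description refers to: all convolutions run at low resolution, the last
# conv emits upscale_factor**2 feature maps, and PixelShuffle2D rearranges them
# into the upscaled image, (N, C*r^2, H, W) -> (N, C, H*r, W*r).
#
# A minimal shape sanity check, as a sketch (the 1x1x64x64 input is only an
# illustration, not part of the training pipeline):
#
#   net = SuperResolutionNet(upscale_factor=3)
#   net.initialize()
#   out = net(mx.np.zeros((1, 1, 64, 64)))
#   assert out.shape == (1, 1, 192, 192)  # conv4's 9 channels shuffled into 3x H and W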
class SuperResolutionNet(gluon.HybridBlock):
    def __init__(self, upscale_factor):
        super(SuperResolutionNet, self).__init__()
        self.conv1 = nn.Conv2D(64, (5, 5), strides=(1, 1), padding=(2, 2), activation='relu')
        self.conv2 = nn.Conv2D(64, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
        self.conv3 = nn.Conv2D(32, (3, 3), strides=(1, 1), padding=(1, 1), activation='relu')
        self.conv4 = nn.Conv2D(upscale_factor ** 2, (3, 3), strides=(1, 1), padding=(1, 1))
        self.pxshuf = nn.PixelShuffle2D(upscale_factor)

    def forward(self, x):
        x = self.conv1(x)
        x = self.conv2(x)
        x = self.conv3(x)
        x = self.conv4(x)
        x = self.pxshuf(x)
        return x
net = SuperResolutionNet(upscale_factor)
metric = mx.gluon.metric.MSE()
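
# test() reports average PSNR over the validation set. PSNR is derived from the
# MSE metric as PSNR = 10 * log10(MAX^2 / MSE); the 1 / MSE form below is PSNR
# with MAX = 1, i.e. it assumes pixel values scaled to [0, 1].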
def test(device):
    val_data.reset()
    avg_psnr = 0
    batches = 0
    for batch in val_data:
        batches += 1
        data = gluon.utils.split_and_load(batch.data[0], device_list=device, batch_axis=0)
        label = gluon.utils.split_and_load(batch.label[0], device_list=device, batch_axis=0)
        outputs = []
        for x in data:
            outputs.append(net(x))
        metric.update(label, outputs)
        avg_psnr += 10 * math.log10(1 / metric.get()[1])
        metric.reset()
    avg_psnr /= batches
    print(f'validation avg psnr: {avg_psnr}')
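
# Standard Gluon data-parallel training loop: each batch is split across the
# available devices, the L2 loss is backpropagated per slice inside
# autograd.record(), and trainer.step(batch_size) applies one Adam update
# normalized by the total batch size.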
def train(epoch, device):
    if isinstance(device, mx.Device):
        device = [device]
    net.initialize(mx.init.Orthogonal(), device=device)
    # re-initialize conv4's weight with Orthogonal(scale=1): the default scale
    # (~1.414) assumes a following ReLU, which conv4 does not have
    net.conv4.initialize(mx.init.Orthogonal(scale=1), force_reinit=True, device=device)
    trainer = gluon.Trainer(net.collect_params(), 'adam', {'learning_rate': opt.lr})
    loss = gluon.loss.L2Loss()

    for i in range(epoch):
        train_data.reset()
        for batch in train_data:
            data = gluon.utils.split_and_load(batch.data[0], device_list=device, batch_axis=0)
            label = gluon.utils.split_and_load(batch.label[0], device_list=device, batch_axis=0)
            outputs = []
            with ag.record():
                for x, y in zip(data, label):
                    z = net(x)
                    L = loss(z, y)
                    L.backward()
                    outputs.append(z)
            trainer.step(batch.data[0].shape[0])
            metric.update(label, outputs)

        name, acc = metric.get()
        metric.reset()
        print(f'training mse at epoch {i}: {name}={acc}')
        test(device)

    net.save_parameters(path.join(this_dir, 'superres.params'))
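
# Single-image inference: convert to YCbCr, super-resolve only the luminance
# (Y) channel with the network, upscale the chroma channels with bicubic
# interpolation, then merge and convert back to RGB.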
def resolve(device):
    from PIL import Image

    if isinstance(device, list):
        # inference runs on a single device
        device = [device[0]]
    img_basename = path.splitext(path.basename(opt.resolve_img))[0]
    img_dirname = path.dirname(opt.resolve_img)
    net.load_parameters(path.join(this_dir, 'superres.params'), device=device)

    img = Image.open(opt.resolve_img).convert('YCbCr')
    y, cb, cr = img.split()
    # add batch and channel axes: (H, W) -> (1, 1, H, W)
    data = mx.np.expand_dims(mx.np.expand_dims(mx.np.array(y), axis=0), axis=0)
    # (-3, -2) are MXNet's extended reshape codes: -3 merges the batch and
    # channel axes, -2 keeps the remaining (H, W) axes unchanged
    out_img_y = mx.np.reshape(net(data), shape=(-3, -2)).asnumpy()
    out_img_y = out_img_y.clip(0, 255)
    out_img_y = Image.fromarray(np.uint8(out_img_y[0]), mode='L')

    # upscale the chroma channels with plain bicubic interpolation
    out_img_cb = cb.resize(out_img_y.size, Image.BICUBIC)
    out_img_cr = cr.resize(out_img_y.size, Image.BICUBIC)
    out_img = Image.merge('YCbCr', [out_img_y, out_img_cb, out_img_cr]).convert('RGB')
    out_img.save(path.join(img_dirname, '{}-resolved.png'.format(img_basename)))
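
# Example invocations (the script name here is illustrative):
#   python super_resolution.py --use-gpu               # train on GPU 0
#   python super_resolution.py --resolve_img input.png # writes input-resolved.png beside the input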
if opt.resolve_img:
    resolve(device)
else:
    train(opt.epochs, device)