blob: 9100cc1875a210f89068f2a39e328b2005647bc0 [file] [log] [blame]
import sys
sys.path.insert(0, "../../mxnet/python")
import mxnet as mx
import numpy as np
import basic
import data_processing
import gen_v3
import gen_v4
# params
vgg_params = mx.nd.load("./vgg19.params")
style_weight = 1.2
content_weight = 10
dshape = (1, 3, 384, 384)
clip_norm = 0.05 * np.prod(dshape)
model_prefix = "v3"
ctx = mx.gpu(0)
# init style
style_np = data_processing.PreprocessStyleImage("../starry_night.jpg", shape=dshape)
style_mod = basic.get_style_module("style", dshape, ctx, vgg_params)
style_mod.forward(mx.io.DataBatch([mx.nd.array(style_np)], [0]), is_train=False)
style_array = [arr.copyto(mx.cpu()) for arr in style_mod.get_outputs()]
del style_mod
# content
content_mod = basic.get_content_module("content", dshape, ctx, vgg_params)
# loss
loss, gscale = basic.get_loss_module("loss", dshape, ctx, vgg_params)
extra_args = {"target_gram_%d" % i : style_array[i] for i in range(len(style_array))}
loss.set_params(extra_args, {}, True, True)
grad_array = []
for i in range(len(style_array)):
grad_array.append(mx.nd.ones((1,), ctx) * (float(style_weight) / gscale[i]))
grad_array.append(mx.nd.ones((1,), ctx) * (float(content_weight)))
# generator
gens = [gen_v4.get_module("g0", dshape, ctx),
gen_v3.get_module("g1", dshape, ctx),
gen_v3.get_module("g2", dshape, ctx),
gen_v4.get_module("g3", dshape, ctx)]
for gen in gens:
gen.init_optimizer(
optimizer='sgd',
optimizer_params={
'learning_rate': 1e-4,
'momentum' : 0.9,
'wd': 5e-3,
'clip_gradient' : 5.0
})
# tv-loss
def get_tv_grad_executor(img, ctx, tv_weight):
"""create TV gradient executor with input binded on img
"""
if tv_weight <= 0.0:
return None
nchannel = img.shape[1]
simg = mx.sym.Variable("img")
skernel = mx.sym.Variable("kernel")
channels = mx.sym.SliceChannel(simg, num_outputs=nchannel)
out = mx.sym.Concat(*[
mx.sym.Convolution(data=channels[i], weight=skernel,
num_filter=1,
kernel=(3, 3), pad=(1,1),
no_bias=True, stride=(1,1))
for i in range(nchannel)])
kernel = mx.nd.array(np.array([[0, -1, 0],
[-1, 4, -1],
[0, -1, 0]])
.reshape((1, 1, 3, 3)),
ctx) / 8.0
out = out * tv_weight
return out.bind(ctx, args={"img": img,
"kernel": kernel})
tv_weight = 1e-2
start_epoch = 0
end_epoch = 3
# data
import os
import random
import logging
data_root = "../data/"
file_list = os.listdir(data_root)
num_image = len(file_list)
logging.info("Dataset size: %d" % num_image)
# train
for i in range(start_epoch, end_epoch):
random.shuffle(file_list)
for idx in range(num_image):
loss_grad_array = []
data_array = []
path = data_root + file_list[idx]
content_np = data_processing.PreprocessContentImage(path, min(dshape[2:]), dshape)
data = mx.nd.array(content_np)
data_array.append(data)
# get content
content_mod.forward(mx.io.DataBatch([data], [0]), is_train=False)
content_array = content_mod.get_outputs()[0].copyto(mx.cpu())
# set target content
loss.set_params({"target_content" : content_array}, {}, True, True)
# gen_forward
for k in range(len(gens)):
gens[k].forward(mx.io.DataBatch([data_array[-1]], [0]), is_train=True)
data_array.append(gens[k].get_outputs()[0].copyto(mx.cpu()))
# loss forward
loss.forward(mx.io.DataBatch([data_array[-1]], [0]), is_train=True)
loss.backward(grad_array)
grad = loss.get_input_grads()[0]
loss_grad_array.append(grad.copyto(mx.cpu()))
grad = mx.nd.zeros(data.shape)
for k in range(len(gens) - 1, -1, -1):
tv_grad_executor = get_tv_grad_executor(gens[k].get_outputs()[0],
ctx, tv_weight)
tv_grad_executor.forward()
grad[:] += loss_grad_array[k] + tv_grad_executor.outputs[0].copyto(mx.cpu())
gnorm = mx.nd.norm(grad).asscalar()
if gnorm > clip_norm:
grad[:] *= clip_norm / gnorm
gens[k].backward([grad])
gens[k].update()
if idx % 20 == 0:
logging.info("Epoch %d: Image %d" % (i, idx))
for k in range(len(gens)):
logging.info("Data Norm :%.5f" %\
(mx.nd.norm(gens[k].get_input_grads()[0]).asscalar() / np.prod(dshape)))
if idx % 1000 == 0:
for k in range(len(gens)):
gens[k].save_params("./model/%d/%s_%04d-%07d.params" % (k, model_prefix, i, idx))