blob: 3344a274e8275c5ed1c0bdb6c06ee77762df0eba [file] [log] [blame]
import find_mxnet
import mxnet as mx
import os, sys
from collections import namedtuple
# Immutable record returned by get_executor():
#   executor  - the bound MXNet executor
#   data      - input-image argument array (arg_dict["data"])
#   data_grad - gradient buffer for the input image
#   style     - executor outputs used as style features (all but the last)
#   content   - executor output used as the content feature (the last one)
#   arg_dict  - mapping of every argument name to its bound array
ConvExecutor = namedtuple(
    typename='ConvExecutor',
    field_names=['executor', 'data', 'data_grad', 'style', 'content', 'arg_dict'],
)
def get_symbol():
    """Build the truncated VGG-19 feature network used for style transfer.

    The network is a stack of 3x3 convolutions (each followed by a ReLU)
    arranged in five stages, ending at ``relu5_1``. Every pooling layer is a
    2x2, stride-2 average pool (``pool_type='avg'``).

    Returns:
        ``(style, content)``: two grouped symbols. ``style`` groups
        ``relu1_1``, ``relu2_1``, ``relu3_1``, ``relu4_1`` and ``relu5_1``;
        ``content`` groups ``relu4_2`` on its own.
    """
    # One entry per stage: (stage number, conv layers in stage, filter count).
    # A pooling layer follows every stage except the final one.
    stages = [(1, 2, 64), (2, 2, 128), (3, 4, 256), (4, 4, 512), (5, 1, 512)]
    last_stage = stages[-1][0]

    relus = {}
    net = mx.sym.Variable("data")
    for stage, num_convs, num_filter in stages:
        for layer in range(1, num_convs + 1):
            tag = '%d_%d' % (stage, layer)
            net = mx.symbol.Convolution(
                name='conv' + tag, data=net, num_filter=num_filter,
                pad=(1, 1), kernel=(3, 3), stride=(1, 1),
                no_bias=False, workspace=1024)
            net = mx.symbol.Activation(
                name='relu' + tag, data=net, act_type='relu')
            relus[tag] = net
        if stage != last_stage:
            net = mx.symbol.Pooling(
                name='pool%d' % stage, data=net, pad=(0, 0),
                kernel=(2, 2), stride=(2, 2), pool_type='avg')

    # style and content layers
    style_tags = ('1_1', '2_1', '3_1', '4_1', '5_1')
    style = mx.sym.Group([relus[tag] for tag in style_tags])
    content = mx.sym.Group([relus['4_2']])
    return style, content
def get_executor(style, content, input_size, ctx,
                 model_path="./model/vgg19.params"):
    """Bind one executor that computes both style and content features.

    All arguments are allocated as zero arrays on ``ctx``. Every non-``data``
    argument whose ``"arg:" + name`` key exists in the parameter file is then
    overwritten with the pretrained values. Arguments missing from the file
    are reported and left as zeros. A gradient buffer is allocated only for
    the input image (``data``), so only the image receives gradients.

    Args:
        style: grouped symbol whose outputs are the style feature maps.
        content: symbol producing the content feature map. NOTE(review): the
            output split below assumes it contributes exactly one output.
        input_size: indexable whose ``[0]`` and ``[1]`` become the last two
            dims of the ``(1, 3, input_size[0], input_size[1])`` data shape
            (presumably height and width; confirm against callers).
        ctx: MXNet context on which all arrays are allocated and bound.
        model_path: parameter file read with ``mx.nd.load``. Defaults to the
            previously hard-coded ``"./model/vgg19.params"``, which is
            resolved relative to the current working directory.

    Returns:
        ConvExecutor holding the bound executor, the input array and its
        gradient buffer, the style outputs (all executor outputs except the
        last), the content output (the last one), and the full ``arg_dict``.
    """
    out = mx.sym.Group([style, content])
    # make executor; only argument shapes are needed from shape inference
    arg_shapes, _, _ = out.infer_shape(
        data=(1, 3, input_size[0], input_size[1]))
    arg_names = out.list_arguments()
    arg_dict = {name: mx.nd.zeros(shape, ctx=ctx)
                for name, shape in zip(arg_names, arg_shapes)}
    # Gradient storage only for the input image; weights get no grad buffer.
    grad_dict = {"data": arg_dict["data"].copyto(ctx)}

    # init with pretrained weight
    pretrained = mx.nd.load(model_path)
    for name in arg_names:
        if name == "data":
            continue
        key = "arg:" + name
        if key in pretrained:
            pretrained[key].copyto(arg_dict[name])
        else:
            print("Skip argument %s" % name)

    executor = out.bind(ctx=ctx, args=arg_dict, args_grad=grad_dict,
                        grad_req="write")
    # Group([style, content]) lists style outputs first, then content.
    return ConvExecutor(executor=executor,
                        data=arg_dict["data"],
                        data_grad=grad_dict["data"],
                        style=executor.outputs[:-1],
                        content=executor.outputs[-1],
                        arg_dict=arg_dict)
def get_model(input_size, ctx):
    """Build the VGG-19 style/content symbols and bind them on ``ctx``.

    Convenience wrapper that chains ``get_symbol()`` into ``get_executor``.

    Args:
        input_size: forwarded unchanged to ``get_executor``.
        ctx: MXNet context forwarded unchanged to ``get_executor``.

    Returns:
        The ConvExecutor produced by ``get_executor``.
    """
    style_sym, content_sym = get_symbol()
    bound = get_executor(style_sym, content_sym, input_size, ctx)
    return bound