| # Licensed to the Apache Software Foundation (ASF) under one |
| # or more contributor license agreements. See the NOTICE file |
| # distributed with this work for additional information |
| # regarding copyright ownership. The ASF licenses this file |
| # to you under the Apache License, Version 2.0 (the |
| # "License"); you may not use this file except in compliance |
| # with the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import collections |
| import os |
| import numbers |
| from PIL import Image |
| |
| import numpy as np |
| import mxnet as mx |
| import mxnet.ndarray as F |
| |
| |
def tensor_load_rgbimage(filename, ctx, size=None, scale=None, keep_asp=False):
    """Load an image file as a batched, channels-first NDArray.

    Args:
        filename: Path of the image to open with PIL.
        ctx: MXNet context the resulting array is created on.
        size: If given, the image is resized to ``size`` x ``size`` pixels,
            or to width ``size`` with a proportionally scaled height when
            ``keep_asp`` is true. Takes precedence over ``scale``.
        scale: If given (and ``size`` is None), both edges are divided by
            this factor.
        keep_asp: Preserve the aspect ratio when resizing to ``size``.

    Returns:
        An NDArray of shape (1, 3, H, W) holding the RGB pixel values.
    """
    # Image.ANTIALIAS was only an alias of the Lanczos filter and has been
    # removed in Pillow 10; Image.LANCZOS selects the same resampling.
    resample = Image.LANCZOS
    img = Image.open(filename).convert('RGB')
    if size is not None:
        if keep_asp:
            # Scale the height by the same factor that maps the width to `size`.
            size2 = int(size * 1.0 / img.size[0] * img.size[1])
            img = img.resize((size, size2), resample)
        else:
            img = img.resize((size, size), resample)
    elif scale is not None:
        img = img.resize((int(img.size[0] / scale), int(img.size[1] / scale)), resample)
    # PIL gives (H, W, C); move the channel axis to the front.
    img = np.array(img).transpose(2, 0, 1).astype(float)
    # Add the leading batch axis.
    img = F.expand_dims(mx.nd.array(img, ctx=ctx), 0)
    return img
| |
| |
def tensor_save_rgbimage(img, filename, cuda=False):
    """Write an NDArray image to ``filename`` via PIL.

    Values are clipped to [0, 255] and truncated to ``uint8`` before saving.

    Args:
        img: 3-D NDArray; the axes are reordered with ``transpose(1, 2, 0)``,
            so presumably it is laid out as (C, H, W) -- confirm with callers.
        filename: Destination path; PIL picks the format from it.
        cuda: Unused by this function.
    """
    pixels = F.clip(img, 0, 255).asnumpy()
    # Move the leading axis to the end, which is the layout PIL expects.
    pixels = pixels.transpose(1, 2, 0).astype('uint8')
    Image.fromarray(pixels).save(filename)
| |
| |
def tensor_save_bgrimage(tensor, filename, cuda=False):
    """Save an image whose three leading-axis slices are ordered B, G, R.

    The tensor is split into three parts along axis 0, the parts are
    re-concatenated in reverse order and the result is handed to
    ``tensor_save_rgbimage``.

    Args:
        tensor: NDArray split into three equal parts along axis 0.
        filename: Destination path.
        cuda: Forwarded unchanged to ``tensor_save_rgbimage``.
    """
    blue, green, red = F.split(tensor, num_outputs=3, axis=0)
    rgb = F.concat(red, green, blue, dim=0)
    tensor_save_rgbimage(rgb, filename, cuda)
| |
| |
def subtract_imagenet_mean_batch(batch):
    """Subtract fixed per-channel means from a batch, keeping channel order.

    The batch is split into three equal parts along axis 1; 123.680, 116.779
    and 103.939 are subtracted from the first, second and third part
    respectively, and the parts are joined again in the same order.

    NOTE(review): the previous docstring described the input as BGR, yet the
    first part receives 123.680 here while ``add_imagenet_mean_batch`` adds
    103.939 to the first part -- confirm which channel order callers pass in.

    Args:
        batch: NDArray, presumably of shape (N, 3, H, W).

    Returns:
        A new NDArray with the same layout as ``batch``.
    """
    # Bring the channel axis to the front so it can be split along axis 0.
    channels_first = F.swapaxes(batch, 0, 1)
    first, second, third = F.split(channels_first, num_outputs=3, axis=0)
    first = first - 123.680
    second = second - 116.779
    third = third - 103.939
    centered = F.concat(first, second, third, dim=0)
    return F.swapaxes(centered, 0, 1)
| |
| |
def subtract_imagenet_mean_preprocess_batch(batch):
    """Subtract per-channel means and reverse the channel order.

    The batch is split into three equal parts along axis 1, treated as
    R, G, B; 123.680, 116.779 and 103.939 are subtracted from them
    respectively, and the parts are joined in B, G, R order.

    Args:
        batch: NDArray, presumably an RGB batch of shape (N, 3, H, W).

    Returns:
        A new NDArray with the channel axis reversed and the means removed.
    """
    # Bring the channel axis to the front so it can be split along axis 0.
    channels_first = F.swapaxes(batch, 0, 1)
    red, green, blue = F.split(channels_first, num_outputs=3, axis=0)
    red = red - 123.680
    green = green - 116.779
    blue = blue - 103.939
    # Joining as blue, green, red flips the channel order.
    reordered = F.concat(blue, green, red, dim=0)
    return F.swapaxes(reordered, 0, 1)
| |
| |
def add_imagenet_mean_batch(batch):
    """Add fixed per-channel means to a batch, keeping channel order.

    The batch is split into three equal parts along axis 1, treated as
    B, G, R; 103.939, 116.779 and 123.680 are added to them respectively,
    and the parts are joined again in the same order.

    Args:
        batch: NDArray, presumably a BGR batch of shape (N, 3, H, W).

    Returns:
        A new NDArray with the same layout as ``batch``.
    """
    # Bring the channel axis to the front so it can be split along axis 0.
    channels_first = F.swapaxes(batch, 0, 1)
    blue, green, red = F.split(channels_first, num_outputs=3, axis=0)
    red = red + 123.680
    green = green + 116.779
    blue = blue + 103.939
    restored = F.concat(blue, green, red, dim=0)
    return F.swapaxes(restored, 0, 1)
| |
| |
def imagenet_clamp_batch(batch, low, high):
    """Clamp each channel of a mean-subtracted batch in place.

    Channel ``c`` along axis 1 is limited to ``[low - mean_c, high - mean_c]``
    with means 123.680, 116.779 and 103.939 for channels 0, 1 and 2.

    ``F.clip`` returns a new array, so its result has to be written back
    into ``batch``; calling it and dropping the result leaves the batch
    untouched.

    Args:
        batch: 4-D NDArray indexed as ``batch[:, c, :, :]``; modified in place.
        low: Lower bound before the channel mean is subtracted.
        high: Upper bound before the channel mean is subtracted.

    Returns:
        None. The batch is updated in place.
    """
    for channel, mean in enumerate((123.680, 116.779, 103.939)):
        batch[:, channel, :, :] = F.clip(batch[:, channel, :, :], low - mean, high - mean)
| |
| |
def preprocess_batch(batch):
    """Reverse the order of the three channel groups along axis 1.

    The batch is split into three equal parts along axis 1 and the parts
    are joined in reverse order (R, G, B becomes B, G, R).

    Args:
        batch: NDArray, presumably of shape (N, 3, H, W).

    Returns:
        A new NDArray with the channel axis reversed.
    """
    # Bring the channel axis to the front so it can be split along axis 0.
    channels_first = F.swapaxes(batch, 0, 1)
    red, green, blue = F.split(channels_first, num_outputs=3, axis=0)
    flipped = F.concat(blue, green, red, dim=0)
    return F.swapaxes(flipped, 0, 1)
| |
| |
class ToTensor(object):
    """Convert an image into a channels-first float32 NDArray.

    Args:
        ctx: MXNet context the produced arrays are created on.
    """

    def __init__(self, ctx):
        self.ctx = ctx

    def __call__(self, img):
        """Return ``img`` as an NDArray with its last axis moved to the front.

        Args:
            img: Anything ``np.array`` turns into a 3-D array, presumably a
                PIL image laid out as (H, W, C).
        """
        pixels = np.array(img)
        channels_first = pixels.transpose(2, 0, 1).astype('float32')
        return mx.nd.array(channels_first, ctx=self.ctx)
| |
| |
class Compose(object):
    """Chain several transforms into a single callable.

    Args:
        transforms (list of callables): applied in list order, each one
            receiving the previous one's result.

    Example:
        >>> pipeline = Compose([
        >>>     CenterCrop(10),
        >>>     ToTensor(ctx),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, img):
        """Run ``img`` through every transform and return the final result."""
        result = img
        for transform in self.transforms:
            result = transform(result)
        return result
| |
| |
class Scale(object):
    """Rescale the input PIL.Image to the given size.

    Args:
        size (sequence or int): Desired output size. If size is a sequence
            like (w, h), the output size will be matched to this. If size is
            an int, the smaller edge of the image will be matched to this
            number, i.e. if height > width the image is rescaled to width
            ``size`` and height ``size * height / width``.
        interpolation (int, optional): Desired interpolation. Default is
            ``PIL.Image.BILINEAR``.
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The ABC aliases in ``collections`` were removed in Python 3.10, so
        # ``collections.Iterable`` raises AttributeError there. Import the
        # ABC from its real home and fall back for Python 2.
        try:
            from collections.abc import Iterable
        except ImportError:
            from collections import Iterable
        assert isinstance(size, int) or (isinstance(size, Iterable) and len(size) == 2)
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be scaled.
        Returns:
            PIL.Image: Rescaled image. The input object itself is returned
            when its smaller edge already equals an int ``size``.
        """
        if isinstance(self.size, int):
            w, h = img.size
            # Nothing to do when the smaller edge already has the target length.
            if (w <= h and w == self.size) or (h <= w and h == self.size):
                return img
            if w < h:
                ow = self.size
                oh = int(self.size * h / w)
            else:
                oh = self.size
                ow = int(self.size * w / h)
            return img.resize((ow, oh), self.interpolation)
        return img.resize(self.size, self.interpolation)
| |
| |
class CenterCrop(object):
    """Cut a centered rectangle out of a PIL.Image.

    Args:
        size (sequence or int): Output size of the crop as (h, w). A single
            number yields a square crop of (size, size).
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img):
        """
        Args:
            img (PIL.Image): Image to be cropped.
        Returns:
            PIL.Image: Cropped image.
        """
        width, height = img.size
        crop_h, crop_w = self.size
        # Top-left corner that centers the crop window in the image.
        left = int(round((width - crop_w) / 2.))
        top = int(round((height - crop_h) / 2.))
        box = (left, top, left + crop_w, top + crop_h)
        return img.crop(box)
| |
| |
class StyleLoader(object):
    """Serve the images of a folder as style tensors, addressed by index.

    Args:
        style_folder: Directory whose entries are used as style images.
            It must contain at least one entry.
        style_size: Edge length passed to ``tensor_load_rgbimage`` as ``size``.
        ctx: MXNet context the loaded images are created on.
    """

    def __init__(self, style_folder, style_size, ctx):
        self.folder = style_folder
        self.style_size = style_size
        # os.listdir returns entries in arbitrary order; sort them so that a
        # given index maps to the same file on every run and platform.
        self.files = sorted(os.listdir(style_folder))
        assert(len(self.files) > 0)
        self.ctx = ctx

    def get(self, i):
        """Load the style image for index ``i``, wrapping around the file list."""
        idx = i % len(self.files)
        filepath = os.path.join(self.folder, self.files[idx])
        style = tensor_load_rgbimage(filepath, self.ctx, self.style_size)
        return style

    def size(self):
        """Return the number of entries found in the style folder."""
        return len(self.files)
| |
def init_vgg_params(vgg, model_folder, ctx):
    """Load the pre-trained VGG weights into ``vgg`` and freeze them.

    The weights are read from ``mxvgg.params`` inside ``model_folder``; the
    file is downloaded with ``wget`` first when it does not exist yet.
    Afterwards every parameter's ``grad_req`` is set to ``'null'``.

    Args:
        vgg: Block whose ``collect_params()`` receives the weights.
        model_folder: Directory holding (or receiving) ``mxvgg.params``.
        ctx: Context passed to the parameter loader.

    Raises:
        OSError: If ``wget`` cannot be started.
        subprocess.CalledProcessError: If ``wget`` exits with an error.
    """
    import subprocess

    params_path = os.path.join(model_folder, 'mxvgg.params')
    if not os.path.exists(params_path):
        url = 'https://www.dropbox.com/s/7c92s0guekwrwzf/mxvgg.params?dl=1'
        try:
            # An argument list avoids the shell, so the '?' in the URL and
            # spaces in the path need no quoting, and a failed download
            # raises instead of being silently ignored.
            subprocess.check_call(['wget', url, '-O', params_path])
        except (OSError, subprocess.CalledProcessError):
            # wget -O creates the target even on failure; remove the partial
            # file so the next call retries the download.
            if os.path.exists(params_path):
                os.remove(params_path)
            raise
    vgg.collect_params().load(params_path, ctx=ctx)
    # The network is only used as a fixed feature extractor.
    for param in vgg.collect_params().values():
        param.grad_req = 'null'