blob: 6253874f723864ef76943cd5f7b099fe548e15f5 [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
"""TrainingPreparator engine action.
Use this module to add the project main code.
"""
import os
import cv2
from marvin_python_toolbox.common.data import MarvinData
from marvin_python_toolbox.engine_base import EngineBaseDataHandler
from .._compatibility import six
from .._logging import get_logger
__all__ = ['TrainingPreparator']
logger = get_logger('training_preparator')
class TrainingPreparator(EngineBaseDataHandler):
    """Resize the raw images and build the labelled train/validation lists.

    Resized images are cached under ``<MarvinData.data_path>/Images/<label>/``
    where ``<label>`` is ``0`` or ``1``.
    """

    def __init__(self, **kwargs):
        """Initialise the action and make sure the output directories exist.

        The ``Images`` directory and both label sub-directories are ensured
        independently, so a pre-existing or partially created ``Images``
        directory still ends up with ``0`` and ``1`` inside it.
        """
        super(TrainingPreparator, self).__init__(**kwargs)

        self.image_path = os.path.join(MarvinData.data_path, "Images")
        self._make_dir_if_missing(self.image_path)
        for label_dir in ('0', '1'):
            self._make_dir_if_missing(os.path.join(self.image_path, label_dir))

    @staticmethod
    def _make_dir_if_missing(path):
        """Create ``path`` (and missing parents); tolerate it already existing."""
        # EAFP instead of an exists() check: avoids the check/create race and
        # works on Python 2, where os.makedirs has no ``exist_ok`` argument.
        try:
            os.makedirs(path)
        except OSError:
            if not os.path.isdir(path):
                raise

    def convert_images(self, image_path, fnames, w=150, h=150):
        """Resize the listed images and return their cached paths with labels.

        Args:
            image_path: Directory, relative to ``MarvinData.data_path``, that
                holds the source ``<fname>.jpg`` files.
            fnames: Sized sequence of ``(fname, label)`` pairs. A label whose
                integer value is ``-1`` is mapped to class ``0``; every other
                value is mapped to class ``1``.
            w: Target width passed to ``cv2.resize``.
            h: Target height passed to ``cv2.resize``.

        Returns:
            A list of ``(resized_image_path, label)`` tuples, one per entry of
            ``fnames`` and in the same order.

        Raises:
            IOError: If a source image cannot be read or a resized image
                cannot be written.
        """
        data = []
        total = len(fnames)
        source_dir = os.path.join(MarvinData.data_path, image_path)

        logger.info("Resizing images.")
        for nn, (fname, label) in enumerate(fnames):
            if nn % 100 == 0:
                logger.info("{}/{}".format(nn, total))

            label = 0 if int(label) == -1 else 1
            imname = os.path.join(self.image_path, str(label), fname + '.jpg')

            # Already-resized images are reused rather than regenerated.
            if not os.path.exists(imname):
                source = os.path.join(source_dir, fname + '.jpg')
                image = cv2.imread(source)
                # cv2.imread signals failure by returning None, not by raising.
                if image is None:
                    raise IOError(
                        "Could not read image: {}".format(source))
                image = cv2.resize(image, (w, h))
                # cv2.imwrite reports failure through its boolean return value.
                if not cv2.imwrite(imname, image):
                    raise IOError(
                        "Could not write image: {}".format(imname))

            data.append((imname, label))

        return data

    def execute(self, params, **kwargs):
        """Convert the initial dataset and store it as ``marvin_dataset``.

        Args:
            params: Mapping providing ``'IMAGES'`` (source image directory),
                ``'W'`` and ``'H'`` (target width and height).
            **kwargs: Accepted for interface compatibility; not used here.

        ``self.marvin_initial_dataset`` is unpacked as a ``(train, val)`` pair
        and each part is passed to :meth:`convert_images`.
        """
        train, val = self.marvin_initial_dataset
        images_dir = params['IMAGES']
        width = params['W']
        height = params['H']

        training_data = self.convert_images(images_dir, train, w=width, h=height)
        validation_data = self.convert_images(images_dir, val, w=width, h=height)

        self.marvin_dataset = {'train': training_data, 'val': validation_data}