blob: 1e345070cee0a2fd822fab8c1be8de2c2a8f9539 [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
"""TrainingPreparator engine action.
Use this module to add the project main code.
"""
from .._compatibility import six
from .._logging import get_logger
from keras.utils import np_utils
from marvin_python_toolbox.engine_base import EngineBaseDataHandler
__all__ = ['TrainingPreparator']
logger = get_logger('training_preparator')
class TrainingPreparator(EngineBaseDataHandler):
def __init__(self, **kwargs):
super(TrainingPreparator, self).__init__(**kwargs)
def execute(self, params, **kwargs):
X_train = self.marvin_initial_dataset["X_train"]
X_train = X_train.reshape(X_train.shape[0], 1, 28, 28)
X_train = X_train.astype('float32')
X_train /= 255
X_test = self.marvin_initial_dataset["X_test"]
X_test = X_test.reshape(X_test.shape[0], 1, 28, 28)
X_test = X_test.astype('float32')
X_test /= 255
nb_classes = 10
y_train = np_utils.to_categorical(self.marvin_initial_dataset["y_train"], nb_classes)
y_test = np_utils.to_categorical(self.marvin_initial_dataset["y_test"], nb_classes)
self.marvin_dataset = {
"X_train": X_train,
"y_train": y_train,
"X_test": X_test,
"y_test": y_test
}