blob: 7c7f0171d9aea59092299a3c41b40be41ec8b400 [file] [log] [blame]
#!/usr/bin/env python
# coding=utf-8
try:
import mock
except ImportError:
import unittest.mock as mock
import numpy as np
from marvin_image_classification_engine.training import Trainer
@mock.patch('marvin_image_classification_engine.training.trainer.Sequential')
def test_build_model(mocked_model):
    """Verify that ``Trainer.build_model`` instantiates ``Sequential`` exactly once.

    ``Sequential`` is replaced by ``mocked_model`` for the duration of the test,
    so no real network is constructed.
    """
    build_kwargs = dict(w=150, h=150, trainable=True)
    trainer = Trainer()
    trainer.build_model(**build_kwargs)
    # A single constructor call is expected for one build_model() invocation.
    mocked_model.assert_called_once()
@mock.patch('marvin_image_classification_engine.training.trainer.Sequential.fit_generator')
@mock.patch('marvin_image_classification_engine.training.trainer.Sequential.compile')
@mock.patch('marvin_image_classification_engine.training.trainer.cv2.imread')
def test_execute(mocked_imread, mocked_compile, mocked_fit, mocked_params=None):
    """Check that ``Trainer.execute`` compiles and fits the model exactly once.

    ``mock.patch`` decorators inject their mocks bottom-up: the innermost patch
    (``cv2.imread``) is the first argument and ``fit_generator`` is the third.

    :param mocked_imread: mock replacing ``cv2.imread`` in the trainer module.
    :param mocked_compile: mock replacing ``Sequential.compile``.
    :param mocked_fit: mock replacing ``Sequential.fit_generator``.
    :param mocked_params: optional training-parameter dict handed to
        ``Trainer.execute``; a small built-in configuration is used when omitted.
    """
    # Only three patch decorators supply mocks. A *required* fourth argument
    # would be resolved by pytest as a fixture named ``mocked_params``. That
    # fixture does not exist, so the test would error before running. pytest
    # does not treat arguments that have defaults as fixtures.
    if mocked_params is None:
        mocked_params = {
            'LEARNING_RATE': 0.001,
            'MOMENTUM': 0.09,
            'STEPS': 2,
            'EPOCHS': 1,
            'VAL_STEPS': 2,
        }
    test_data = {
        'train': ['t0'],
        'val': ['t1'],
    }
    # Every (mocked) imread call returns the same 3x3x3 integer array; the
    # placeholder names 't0'/'t1' are presumably only passed to imread.
    pixel_rows = [[0, 1, 2], [1, 2, 3], [2, 3, 4]]
    mocked_imread.return_value = np.array([pixel_rows] * 3)

    ac = Trainer(dataset=test_data)
    ac.execute(params=mocked_params)

    mocked_compile.assert_called_once()
    mocked_fit.assert_called_once()