blob: cc885f99830965b6cfd746d3bec489bbb9dd972e [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
m4_changequote(`<!', `!>')
import sys
import numpy as np
import os
from os import path
# Add dl module to the pythonpath.
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
import unittest
from mock import *
import plpy_mock as plpy
class GPUInfoTestCase(unittest.TestCase):
    """Unit tests for ``madlib_keras_gpu_info.gpu_configuration``.

    The module under test is imported while a mocked ``plpy`` is installed
    in ``sys.modules``. Each test then inspects the calls recorded on the
    mocked ``plpy.execute`` to see which GPU source name shows up in them.
    """

    def setUp(self):
        """Install the ``plpy`` mock and import the module under test."""
        self.plpy_mock = Mock(spec='error')

        # Record every plpy.execute() call so the tests can inspect the
        # arguments that were passed to it.
        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        # The mock has to be registered under the name 'plpy' before the
        # module under test is imported.
        self.module_patcher = patch.dict('sys.modules', {'plpy': plpy})
        self.module_patcher.start()

        import madlib_keras_gpu_info
        self.module = madlib_keras_gpu_info
        # Replace output_tbl_valid with a Mock (presumably the output table
        # validation, which these tests do not exercise -- confirm).
        self.module.output_tbl_valid = Mock()

        self.table_name = 'does_not_matter'

    def tearDown(self):
        """Undo the ``sys.modules`` patch installed by ``setUp``."""
        self.module_patcher.stop()

    def _check_source(self, source, expected, unexpected):
        """Call gpu_configuration and check the recorded execute calls.

        Args:
            source: Value passed as the third argument of
                ``gpu_configuration``.
            expected: Substring that must appear in the recorded
                ``plpy.execute`` calls.
            unexpected: Substring that must not appear in them.
        """
        self.module.gpu_configuration('schema', self.table_name, source)
        self.assertIn(expected, str(self.plpy_mock_execute.call_args_list))
        self.assertNotIn(unexpected,
                         str(self.plpy_mock_execute.call_args_list))

    def test_gpu_configuration_invalid_source(self):
        """An unrecognized source raises ``plpy.PLPYException``."""
        self.assertRaises(
            plpy.PLPYException,
            self.module.gpu_configuration,
            'schema', self.table_name, 'invalid_source')

    def test_gpu_configuration_none_source(self):
        """A ``None`` source produces tensorflow calls, not nvidia ones."""
        self._check_source(None, 'tensorflow', 'nvidia')

    def test_gpu_configuration_tensorflow(self):
        """The 'tensorflow' source produces only tensorflow calls."""
        self._check_source('tensorflow', 'tensorflow', 'nvidia')

    def test_gpu_configuration_nvidia(self):
        """The 'nvidia' source produces only nvidia calls."""
        self._check_source('nvidia', 'nvidia', 'tensorflow')

    def test_gpu_configuration_tensorflow_uppercase(self):
        """An upper-case 'TENSORFLOW' source is accepted as tensorflow."""
        self._check_source('TENSORFLOW', 'tensorflow', 'nvidia')

    def test_gpu_configuration_nvidia_mixedcase(self):
        """A mixed-case 'NVidIa' source is accepted as nvidia."""
        self._check_source('NVidIa', 'nvidia', 'tensorflow')
# Run the test suite only when this file is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# ---------------------------------------------------------------------