blob: e04de94ac8027a8d31605d77f955f2b1d885f24f [file] [log] [blame]
# coding=utf-8
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import sys
import numpy as np
from os import path
# Add modules to the pythonpath.
sys.path.append(path.dirname(path.dirname(path.dirname(path.dirname(path.abspath(__file__))))))
sys.path.append(path.dirname(path.dirname(path.dirname(path.abspath(__file__)))))
import unittest
from mock import *
import plpy_mock as plpy
m4_changequote(`<!', `!>')
class DBSCANTestCase(unittest.TestCase):
    """Unit tests for the argument validation in the ``dbscan.dbscan`` module.

    ``plpy`` is replaced in ``sys.modules`` by the ``plpy_mock`` module for
    the duration of each test, so ``dbscan.dbscan`` is imported against the
    mock.
    """

    def setUp(self):
        """Install the plpy mock, build default arguments and import dbscan."""
        self.plpy_mock = Mock(spec='error')
        patches = {
            'plpy': plpy
        }

        # we need to use MagicMock() instead of Mock() for the plpy.execute mock
        # to be able to iterate on the return value
        self.plpy_mock_execute = MagicMock()
        plpy.execute = self.plpy_mock_execute

        self.module_patcher = patch.dict('sys.modules', patches)
        self.module_patcher.start()
        # Register the undo step right away instead of relying on tearDown():
        # unittest skips tearDown() when setUp() raises (for example if the
        # import below fails), but registered cleanups still run, so the
        # sys.modules patch cannot leak into subsequent tests.
        self.addCleanup(self.module_patcher.stop)

        # Keyword arguments passed to _validate_dbscan by the tests.
        self.default_dict = {
            'schema_madlib' : "madlib",
            'source_table' : "source",
            'output_table' : "output",
            'id_column' : "id_in",
            'expr_point' : "data",
            'eps' : 20,
            'min_samples' : 4,
            'metric' : 'squared_dist_norm2',
            'algorithm' : 'brute'
        }

        # Imported here, after the patch is active, so that the module is
        # loaded while 'plpy' in sys.modules refers to the mock.
        import dbscan.dbscan
        self.module = dbscan.dbscan

    def test_validate_dbscan(self):
        # Smoke test: the helper validators on the module are stubbed out, and
        # the test passes as long as _validate_dbscan accepts the default
        # arguments without raising.
        self.module.input_tbl_valid = Mock()
        self.module.output_tbl_valid = Mock()
        self.module.cols_in_tbl_valid = Mock()
        self.module.is_var_valid = Mock(return_value = True)
        self.module.get_expr_type = Mock(return_value = 'double precision[]')

        self.module._validate_dbscan(**self.default_dict)
# Run the test cases in this file when it is executed directly as a script.
if __name__ == '__main__':
    unittest.main()
# ---------------------------------------------------------------------