#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
#   http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import inspect

from django.test import TestCase

from ml.classifiers import GradientBoostClassifier, MLP, RandomForestClassifier, SVC
from ml.registry import MLRegistry
# Sample input record handed to each classifier's ``compute_prediction``
# by the tests in this module.
# NOTE(review): the field set presumably mirrors the training dataset's
# columns -- confirm against the preprocessing code in ``ml.classifiers``.
test_data = {
    "age": 22,
    "sex": "female",
    "job": 2,
    "housing": "own",
    "credit_amount": 5951,
    "duration": 48,
    "purpose": "radio/TV",
}

# Label every classifier under test is expected to return for ``test_data``.
expected_output = 'bad'
class MLTests(TestCase):
    """Tests for the ML classifiers and the classifier registry.

    Each classifier test feeds the module-level ``test_data`` record to a
    freshly constructed classifier and checks that the returned response
    carries the label stored in ``expected_output``.
    """

    def _assert_predicts_expected_label(self, classifier):
        """Check that ``classifier`` predicts ``expected_output`` for ``test_data``.

        The leading underscore keeps this helper from being collected as a
        test on its own.

        Args:
            classifier: Object exposing ``compute_prediction(input_data)``;
                the returned response must support ``in`` and ``['label']``.
        """
        response = classifier.compute_prediction(test_data)
        # NOTE(review): this status check was already disabled in every
        # test; kept for reference until it is confirmed whether responses
        # are meant to carry a 'status' key.
        # self.assertEqual('OK', response['status'])
        # assertIn shows the actual response on failure, whereas
        # assertTrue('label' in response) only reports "False is not true".
        self.assertIn('label', response)
        self.assertEqual(expected_output, response['label'])

    def test_rf_algorithm(self):
        """RandomForestClassifier returns the expected label."""
        self._assert_predicts_expected_label(RandomForestClassifier())

    # NOTE(review): the SVC test was disabled in the original source and the
    # reason is not recorded here -- confirm before re-enabling or deleting.
    # def test_svc_algorithm(self):
    #     my_alg = SVC()
    #     response = my_alg.compute_prediction(test_data)
    #     self.assertEqual('OK', response['status'])
    #     self.assertTrue('label' in response)
    #     self.assertEqual(expected_output, response['label'])

    def test_mlp_algorithm(self):
        """MLP returns the expected label."""
        self._assert_predicts_expected_label(MLP())

    def test_gb_algorithm(self):
        """GradientBoostClassifier returns the expected label."""
        self._assert_predicts_expected_label(GradientBoostClassifier())

    def test_registry(self):
        """Adding one algorithm grows an empty registry to one classifier."""
        registry = MLRegistry()
        # A freshly constructed registry must start out empty.
        self.assertEqual(len(registry.classifiers), 0)

        # Random Forest classifier
        rf_algo = {
            'classifier': RandomForestClassifier(),
            'description': "Random Forest with simple pre and post-processing",
            'status': "production",
            'version': "0.0.1",
            'dataset': 'German',
            'region': 'Germany',
            'created_by': "xurror",
        }

        # add to registry
        registry.add_algorithms([rf_algo])

        # exactly one classifier should now be registered
        self.assertEqual(len(registry.classifiers), 1)