# | |
# Licensed to the Apache Software Foundation (ASF) under one | |
# or more contributor license agreements. See the NOTICE file | |
# distributed with this work for additional information | |
# regarding copyright ownership. The ASF licenses this file | |
# to you under the Apache License, Version 2.0 (the | |
# "License"); you may not use this file except in compliance | |
# with the License. You may obtain a copy of the License at | |
# | |
# http://www.apache.org/licenses/LICENSE-2.0 | |
# | |
# Unless required by applicable law or agreed to in writing, | |
# software distributed under the License is distributed on an | |
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY | |
# KIND, either express or implied. See the License for the | |
# specific language governing permissions and limitations | |
# under the License. | |
# | |
import json | |
import logging | |
from stats.statistical_scoring import stat_score | |
from typing import Any, Dict | |
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample, inline_serializer | |
from rest_framework import viewsets | |
from rest_framework.decorators import action | |
from rest_framework.exceptions import APIException, bad_request | |
from rest_framework.fields import CharField, FloatField, IntegerField | |
from rest_framework.response import Response | |
from api.models import Algorithm, Dataset, PredictionRequest | |
from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer | |
from ml.classifiers import RandomForestClassifier | |
from server.wsgi import registry | |
# Module-level logger for the view classes below, named after this module.
log = logging.getLogger(__name__)
class AlgorithmViewSet(viewsets.ModelViewSet):
    """Model view set for ``Algorithm`` records, extended with a ``predict`` action.

    The standard ``ModelViewSet`` actions operate on ``Algorithm.objects.all()``
    through ``AlgorithmSerializer``. The extra ``predict`` action scores the
    request body with either a statistical method or a registered classifier
    and stores the exchange as a ``PredictionRequest``.
    """

    # Permissions are not overridden here; the project-wide DRF defaults apply.
    serializer_class = AlgorithmSerializer
    queryset = Algorithm.objects.all()

    # Classifier names that are delegated to ``stat_score`` instead of being
    # looked up as an ``Algorithm`` row in the classifier registry.
    _STATISTICAL_CLASSIFIERS = (
        'manova', 'linearRegression', 'polynomialRegression'
    )

    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='The algorithm/classifier to use',
                required=True,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        # Read the name off the class itself: building an
                        # instance at import time just for its name is wasted
                        # work.
                        value=RandomForestClassifier.__name__)
                ]),
            OpenApiParameter(
                name='dataset',
                description='The name of the dataset',
                examples=[OpenApiExample('Example 1', value='german')]),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                deprecated=True,
                examples=[OpenApiExample('Example 1', value='production')]),
            OpenApiParameter(
                name='version',
                description='Algorithm version',
                required=True,
                default='0.0.1',
                examples=[OpenApiExample('Example 1', value='0.0.1')]),
        ],
        operation_id='algorithms_predict',
        request=Dict[str, Any],
        # NOTE(review): "wilkis_lambda", "hotelling_tawley" and
        # "roys_reatest_roots" look misspelled, but they may mirror the keys
        # emitted by stat_score -- confirm before renaming, since renaming
        # changes the published schema.
        responses=inline_serializer(name="PredictionResponse",
                                    fields={
                                        "probability": FloatField(),
                                        "label": CharField(),
                                        "method": CharField(),
                                        "color": CharField(),
                                        "wilkis_lambda": FloatField(),
                                        "pillais_trace": FloatField(),
                                        "hotelling_tawley": FloatField(),
                                        "roys_reatest_roots": FloatField(),
                                        "request_id": IntegerField()
                                    }))
    @action(detail=False, methods=['post'])
    def predict(self, request, format=None):
        """Score the request body and record the exchange as a PredictionRequest.

        Query parameters:
            classifier: Required. Either one of ``_STATISTICAL_CLASSIFIERS``
                (handled by ``stat_score``) or the ``classifier`` value of an
                ``Algorithm`` row.
            dataset: Dataset name used to select the algorithm
                (default ``"german"``).
            version: Algorithm version (default ``"0.0.1"``).
            status: Algorithm status (default ``"production"``).

        Args:
            request: The DRF request; ``request.data`` is passed to the
                scorer unchanged and stored JSON-encoded.
            format: Optional format suffix supplied by DRF routing; unused.

        Returns:
            A ``Response`` carrying the prediction dict with an added
            ``"request_id"`` key, or a 400 ``Response`` of the form
            ``{"error": ...}`` when ``classifier`` is missing or no matching
            algorithm exists.

        Raises:
            APIException: For any unexpected failure while scoring or saving.
                Exceptions that already are ``APIException`` instances
                propagate unchanged so their own status code is preserved.
        """
        params = request.query_params
        classifier_name = params.get("classifier")
        dataset_name = params.get("dataset", "german")
        # "version" and "status" always have a value because of the defaults,
        # so no missing-parameter check is needed for them.
        version = params.get("version", "0.0.1")
        status = params.get("status", "production")
        log.debug("predict request: %s", request)

        if not classifier_name:
            # ``bad_request`` from rest_framework.exceptions is a handler view,
            # not an exception, so it cannot be raised; answer with an explicit
            # 400 response instead.
            return Response(
                {"error": "Missing required query parameter: classifier"},
                status=400)

        try:
            if classifier_name in self._STATISTICAL_CLASSIFIERS:
                prediction = stat_score(request.data, classifier_name)
                algorithm = None
            else:
                # ``first()`` yields None when nothing matches, whereas
                # indexing the queryset with [0] raises IndexError.
                algorithm = Algorithm.objects.filter(
                    classifier=classifier_name,
                    status=status,
                    version=version,
                    dataset__name=dataset_name).first()
                if algorithm is None:
                    return Response(
                        {"error": "ML algorithm is not available"},
                        status=400)
                model = registry.classifiers[algorithm.id]
                prediction = model.compute_prediction(request.data)

            # Results without a "label" key are recorded under their "method".
            if "label" in prediction:
                label = prediction["label"]
            else:
                label = prediction['method']

            prediction_request = PredictionRequest(
                input=json.dumps(request.data),
                response=prediction,
                prediction=label,
                feedback="",
                algorithm=algorithm)
            prediction_request.save()

            prediction["request_id"] = prediction_request.id
            return Response(prediction)
        except APIException:
            # Keep the status code of DRF errors (for example a 400 raised
            # while parsing the request body) instead of rewrapping as a 500.
            raise
        except Exception as e:
            log.exception("Prediction failed")
            raise APIException(str(e)) from e
class PredictionRequestViewSet(viewsets.ModelViewSet):
    """Model view set over every stored ``PredictionRequest``.

    Records are serialized with ``PredictionRequestSerializer``.
    """

    # Permissions are not overridden here; the project-wide DRF defaults apply.
    queryset = PredictionRequest.objects.all()
    serializer_class = PredictionRequestSerializer
class DatasetViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only view set over every ``Dataset``.

    Records are serialized with ``DatasetSerializer``.
    """

    # Permissions are not overridden here; the project-wide DRF defaults apply.
    queryset = Dataset.objects.all()
    serializer_class = DatasetSerializer