blob: 33ee8c89fd219c13ffdfcdd3ba86673773b4eda9 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
#
import json
import logging
from stats.statistical_scoring import stat_score
from typing import Any, Dict
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample, inline_serializer
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import APIException, bad_request
from rest_framework.fields import CharField, FloatField, IntegerField
from rest_framework.response import Response
from api.models import Algorithm, Dataset, PredictionRequest
from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer
from ml.classifiers import RandomForestClassifier
from server.wsgi import registry
# REST API views for algorithms, prediction requests and datasets.
# Module-level logger for this views module.
log = logging.getLogger(__name__)
class AlgorithmViewSet(viewsets.ModelViewSet):
    """CRUD endpoints for ``Algorithm`` records plus a ``predict`` action.

    ``predict`` scores the request body either with one of the statistical
    methods handled by ``stat_score`` or with the ML classifier registered in
    ``registry.classifiers`` for the matching ``Algorithm`` row, and records
    the call as a ``PredictionRequest``.
    """

    # permission_classes = []
    serializer_class = AlgorithmSerializer
    queryset = Algorithm.objects.all()

    # Classifier names that are scored by ``stat_score`` rather than by a
    # registered ML model; these requests are stored without an algorithm.
    _STATISTICAL_METHODS = frozenset(
        {'manova', 'linearRegression', 'polynomialRegression'})

    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='The algorithm/classifier to use',
                required=True,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        value=RandomForestClassifier().__class__.__name__)
                ]),
            OpenApiParameter(
                name='dataset',
                description='The name of the dataset',
                examples=[OpenApiExample('Example 1', value='german')]),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                deprecated=True,
                examples=[OpenApiExample('Example 1', value='production')]),
            OpenApiParameter(
                name='version',
                description='Algorithm version',
                required=True,
                default='0.0.1',
                examples=[OpenApiExample('Example 1', value='0.0.1')]),
        ],
        operation_id='algorithms_predict',
        request=Dict[str, Any],
        # NOTE(review): several field names below look misspelled
        # (wilks_lambda, hotelling_lawley, roys_greatest_roots?). They must
        # match the keys stat_score actually returns -- confirm before renaming.
        responses=inline_serializer(name="PredictionResponse",
                                    fields={
                                        "probability": FloatField(),
                                        "label": CharField(),
                                        "method": CharField(),
                                        "color": CharField(),
                                        "wilkis_lambda": FloatField(),
                                        "pillais_trace": FloatField(),
                                        "hotelling_tawley": FloatField(),
                                        "roys_reatest_roots": FloatField(),
                                        "request_id": IntegerField()
                                    }))
    @action(detail=False, methods=['post'])
    def predict(self, request, format=None):
        """Score ``request.data`` and persist the call as a ``PredictionRequest``.

        Query parameters:
            classifier: required. Either a statistical method name handled by
                ``stat_score`` or the ``classifier`` value of an ``Algorithm``.
            dataset: dataset name used to select the algorithm (default
                ``"german"``).
            version: algorithm version (default ``"0.0.1"``).
            status: algorithm status (default ``"production"``).

        Returns:
            ``Response`` holding the prediction dict plus ``request_id``, or a
            400 ``Response`` with an ``"error"`` key when ``classifier`` is
            missing or no matching ``Algorithm`` exists.

        Raises:
            APIException: DRF API errors (e.g. a malformed request body) are
                re-raised unchanged; any other failure while scoring or
                persisting is wrapped in a generic ``APIException`` (HTTP 500).
        """
        params = request.query_params
        classifier_name = params.get("classifier")
        dataset_name = params.get("dataset", "german")
        # ``version`` falls back to a default, so it is never missing.
        version = params.get("version", "0.0.1")
        algorithm_status = params.get("status", "production")
        log.debug(
            "predict: classifier=%s dataset=%s version=%s status=%s",
            classifier_name, dataset_name, version, algorithm_status)

        if classifier_name is None:
            return Response(
                {"error": "Missing required query parameter: classifier"},
                status=400)

        try:
            if classifier_name in self._STATISTICAL_METHODS:
                prediction = stat_score(request.data, classifier_name)
                algorithm = None
            else:
                # ``.first()`` yields None on no match instead of raising
                # IndexError, so the "not available" branch is reachable.
                algorithm = Algorithm.objects.filter(
                    classifier=classifier_name,
                    status=algorithm_status,
                    version=version,
                    dataset__name=dataset_name).first()
                if algorithm is None:
                    return Response(
                        {"error": "ML algorithm is not available"},
                        status=400)
                model = registry.classifiers[algorithm.id]
                prediction = model.compute_prediction(request.data)

            # Statistical results may carry only a "method" key, not "label".
            label = (prediction["label"] if "label" in prediction
                     else prediction["method"])
            prediction_request = PredictionRequest(
                input=json.dumps(request.data),
                response=prediction,
                prediction=label,
                feedback="",
                algorithm=algorithm)
            prediction_request.save()
            prediction["request_id"] = prediction_request.id
            return Response(prediction)
        except APIException:
            # Already a well-formed DRF error (e.g. ParseError): keep its status.
            raise
        except Exception as e:
            log.exception("Prediction failed for classifier %r",
                          classifier_name)
            raise APIException(str(e)) from e
class PredictionRequestViewSet(viewsets.ModelViewSet):
    """Full CRUD endpoints over every stored ``PredictionRequest``."""

    queryset = PredictionRequest.objects.all()
    serializer_class = PredictionRequestSerializer
    # No ``permission_classes`` override: DRF's DEFAULT_PERMISSION_CLASSES
    # setting applies.
class DatasetViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only (list/retrieve) endpoints over every ``Dataset``."""

    queryset = Dataset.objects.all()
    serializer_class = DatasetSerializer
    # No ``permission_classes`` override: DRF's DEFAULT_PERMISSION_CLASSES
    # setting applies.