blob: 85c2f8a6422cf8677a5f31249525a2ab2971b0be [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import json
import logging
from statistical_scripts.statistical_scoring import stat_score
from typing import Any, Dict
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample, inline_serializer
from rest_framework import viewsets
from rest_framework.decorators import action
from rest_framework.exceptions import APIException, bad_request
from rest_framework.fields import CharField, FloatField, IntegerField
from rest_framework.response import Response
# from rest_framework import permissions
# from rest_framework_api_key.permissions import HasAPIKey
from api.models import Algorithm, Dataset, PredictionRequest
from api.serializers import AlgorithmSerializer, PredictionRequestSerializer, DatasetSerializer
from ml.classifiers import RandomForestClassifier
from server.wsgi import registry
# Create your views here.
log = logging.getLogger(__name__)
class AlgorithmViewSet(viewsets.ModelViewSet):
# permission_classes = []
serializer_class = AlgorithmSerializer
queryset = Algorithm.objects.all()
@extend_schema(
description='Predict credit risk for a loan',
parameters=[
OpenApiParameter(name='classifier',
description='The algorithm/classifier to use',
required=True,
examples=[OpenApiExample('Example 1',
value=RandomForestClassifier().__class__.__name__)]),
OpenApiParameter(name='dataset',
description='The name of the dataset',
examples=[OpenApiExample('Example 1', value='german')]),
OpenApiParameter(name='status',
description='The status of the algorithm',
deprecated=True,
examples=[OpenApiExample('Example 1', value='production')]),
OpenApiParameter(name='version',
description='Algorithm version',
required=True,
default='0.0.1',
examples=[OpenApiExample('Example 1', value='0.0.1')]),
],
operation_id='algorithms_predict',
request=Dict[str, Any],
responses=inline_serializer(name="PredictionResponse",
fields={"probability": FloatField(),
"label": CharField(),
"method": CharField(),
"color": CharField(),
"wilkis_lambda": FloatField(),
"pillais_trace": FloatField(),
"hotelling_tawley": FloatField(),
"roys_reatest_roots": FloatField(),
"request_id": IntegerField()})
)
@action(detail=False, methods=['post'])
def predict(self, request, format=None):
try:
classifier = self.request.query_params.get("classifier")
region = self.request.query_params.get("dataset", "german")
version = self.request.query_params.get("version", "0.0.1")
status = self.request.query_params.get("status", "production")
print(request)
if version is None:
raise bad_request(request=request,
data={"error": "Missing required query parameter: version"})
if classifier is None:
raise bad_request(request=request,
data={"error": "Missing required query parameter: classifier"})
if classifier in ['manova', 'linearRegression', 'polynomialRegression']:
prediction = stat_score(request.data, classifier)
algorithm = None
else:
algorithm: Algorithm = Algorithm.objects.filter(classifier=classifier,
status=status,
version=version,
dataset__name=region)[0]
if algorithm is None:
raise bad_request(request=request,
data={"error": "ML algorithm is not available"})
classifier = registry.classifiers[algorithm.id]
prediction = classifier.compute_prediction(request.data)
if "label" in prediction:
label = prediction["label"]
else:
label = prediction['method']
prediction_request = PredictionRequest(input=json.dumps(request.data),
response=prediction,
prediction=label,
feedback="",
algorithm=algorithm)
prediction_request.save()
prediction["request_id"] = prediction_request.id
return Response(prediction)
except Exception as e:
raise APIException(str(e))
class PredictionRequestViewSet(viewsets.ModelViewSet):
# permission_classes = []
serializer_class = PredictionRequestSerializer
queryset = PredictionRequest.objects.all()
class DatasetViewSet(viewsets.ReadOnlyModelViewSet):
# permission_classes = []
serializer_class = DatasetSerializer
queryset = Dataset.objects.all()