blob: 66682296f23943271039d4515b9e98954639c591 [file] [log] [blame]
#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# =============================================================================
import json
import random
import datetime
import pandas as pd
from rest_framework.generics import GenericAPIView
from server.wsgi import registry
from api.models import (ABTest, Endpoint, GermanDataModel,
MLAlgorithm, MLAlgorithmStatus, MLRequest)
from api.serializers import (ABTestSerializer, EndpointSerializer,
GermanDataModelSerializer, GroupSerializer,
MLAlgorithmSerializer, MLAlgorithmStatusSerializer,
MLRequestSerializer, UserSerializer)
from django.contrib.auth.models import Group, User
from django.db import transaction
from django.db.models import F
from django.shortcuts import get_object_or_404, render
from rest_framework import viewsets, permissions, status, mixins, views
from rest_framework.decorators import action
from rest_framework.exceptions import APIException
from rest_framework.response import Response
from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample
from drf_spectacular.types import OpenApiTypes
# Create your views here.
class EndpointViewSet(viewsets.ReadOnlyModelViewSet):
    # Read-only viewset over every ``Endpoint`` row, rendered with
    # ``EndpointSerializer``.
    queryset = Endpoint.objects.all()
    serializer_class = EndpointSerializer
class MLAlgorithmViewSet(viewsets.ReadOnlyModelViewSet):
    # Read-only viewset over every ``MLAlgorithm`` row, rendered with
    # ``MLAlgorithmSerializer``.
    queryset = MLAlgorithm.objects.all()
    serializer_class = MLAlgorithmSerializer
def deactivate_other_statuses(instance):
    """Deactivate every older, still-active status of the same algorithm.

    Selects the ``MLAlgorithmStatus`` rows that share
    ``instance.parent_mlalgorithm``, were created strictly before
    ``instance.created_at`` and are currently ``active``, and sets their
    ``active`` flag to ``False``.

    Args:
        instance: the status that should remain active; it must expose
            ``parent_mlalgorithm`` and ``created_at``.
    """
    # One UPDATE statement instead of loading each row, mutating it in Python
    # and writing it back with ``bulk_update``.
    MLAlgorithmStatus.objects.filter(
        parent_mlalgorithm=instance.parent_mlalgorithm,
        created_at__lt=instance.created_at,
        active=True,
    ).update(active=False)
class MLAlgorithmStatusViewSet(
    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
    mixins.CreateModelMixin
):
    # Retrieve / list / create viewset for ``MLAlgorithmStatus`` rows.
    queryset = MLAlgorithmStatus.objects.all()
    serializer_class = MLAlgorithmStatusSerializer

    def perform_create(self, serializer):
        """Save the new status as active and deactivate the older ones.

        Both steps run inside one transaction. Any exception raised while
        saving or deactivating is re-raised as an ``APIException`` carrying
        the original message.
        """
        try:
            with transaction.atomic():
                new_status = serializer.save(active=True)
                # Older statuses of the same algorithm must stop being active.
                deactivate_other_statuses(new_status)
        except Exception as exc:
            raise APIException(str(exc))
class MLRequestViewSet(
    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
    mixins.UpdateModelMixin
):
    # Retrieve / list / update viewset for ``MLRequest`` rows (no create mixin
    # is included here).
    queryset = MLRequest.objects.all()
    serializer_class = MLRequestSerializer
class PredictView(GenericAPIView):
    # Computes a prediction with one ML algorithm of the requested endpoint and
    # records the request/response pair as an ``MLRequest``.
    #
    # Documentation is kept in comments rather than docstrings: DRF and
    # drf-spectacular publish view docstrings as API descriptions, which would
    # change the generated output.
    queryset = GermanDataModel.objects.all()
    serializer_class = GermanDataModelSerializer

    # NOTE(review): ``version`` is declared ``required=True`` below, but
    # ``post`` treats a missing version as "do not filter by version" --
    # confirm which of the two is intended.
    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='Name of algorithm classifier',
                required=False,
                type=str,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        summary='Random Forest Classifier',
                        value='random_forest'
                    ),
                ],
            ),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                required=False,
                type=str,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        summary='Algorithm in production',
                        value='production'
                    ),
                ],
            ),
            OpenApiParameter(name='version', description='Algorithm version', required=True, type=str),
        ],
    )
    def post(self, request, endpoint_name: str, format=None):
        # Query parameters:
        #   version    -- optional filter on ``MLAlgorithm.version``
        #   status     -- active status to match, defaults to "production"
        #   classifier -- endpoint classifier, defaults to "random_forest";
        #                 only applied when status is not "ab_testing"
        #
        # Responses:
        #   400 when no algorithm matches, or when more than one matches and
        #       the status is not "ab_testing";
        #   otherwise the prediction dict extended with ``request_id``.
        params = self.request.query_params
        algorithm_version = params.get("version")
        algorithm_status = params.get("status", "production")
        algorithm_classifier = params.get("classifier", "random_forest")

        # Both status conditions stay in ONE filter() call so they apply to
        # the same related status row.
        algs = MLAlgorithm.objects.filter(
            parent_endpoint__name=endpoint_name,
            status__status=algorithm_status,
            status__active=True,
        )
        if algorithm_version is not None:
            algs = algs.filter(version=algorithm_version)

        # Materialise once; the list is reused for the count and the A/B pick.
        candidates = list(algs)
        if not candidates:
            return Response(
                {"status": "Error", "message": "ML algorithm is not available"},
                status=status.HTTP_400_BAD_REQUEST)

        if algorithm_status == "ab_testing":
            # A/B testing: pick one of the candidate algorithms at random.
            algorithm = random.choice(candidates)
        else:
            if len(candidates) != 1:
                return Response(
                    {"status": "Error", "message": "ML algorithm selection is ambiguous. Please specify algorithm version."},
                    status=status.HTTP_400_BAD_REQUEST)
            algorithm = algs.filter(
                parent_endpoint__classifier=algorithm_classifier
            ).first()
            if algorithm is None:
                # The classifier filter removed the only candidate; report it
                # instead of failing with an IndexError.
                return Response(
                    {"status": "Error", "message": "ML algorithm is not available"},
                    status=status.HTTP_400_BAD_REQUEST)

        # NOTE(review): presumably ``registry.endpoints`` maps algorithm ids to
        # loaded model objects -- an unregistered id would raise here.
        algorithm_object = registry.endpoints[algorithm.id]
        prediction = algorithm_object.compute_prediction(request.data)

        label = prediction.get("label", "error")
        ml_request = MLRequest(
            input_data=json.dumps(request.data),
            full_response=prediction,
            response=label,
            feedback="",
            # Reuse the algorithm selected above rather than re-indexing the
            # queryset.
            parent_mlalgorithm=algorithm,
        )
        ml_request.save()

        prediction["request_id"] = ml_request.id
        return Response(prediction)
class ABTestViewSet(
    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
    mixins.CreateModelMixin, mixins.UpdateModelMixin
):
    # Retrieve / list / create / update viewset for ``ABTest`` rows.
    queryset = ABTest.objects.all()
    serializer_class = ABTestSerializer

    def perform_create(self, serializer):
        """Save the A/B test and switch both of its algorithms to "ab_testing".

        Inside one transaction, for ``parent_mlalgorithm_1`` and then
        ``parent_mlalgorithm_2``: a new active ``MLAlgorithmStatus`` with
        status "ab_testing" is saved and the older statuses of that algorithm
        are deactivated. Any exception is re-raised as an ``APIException``
        carrying the original message.
        """
        try:
            with transaction.atomic():
                ab_test = serializer.save()
                # First algorithm, then second -- same order as the fields.
                for algorithm_field in ("parent_mlalgorithm_1", "parent_mlalgorithm_2"):
                    new_status = MLAlgorithmStatus(
                        status="ab_testing",
                        created_by=ab_test.created_by,
                        parent_mlalgorithm=getattr(ab_test, algorithm_field),
                        active=True,
                    )
                    new_status.save()
                    deactivate_other_statuses(new_status)
        except Exception as exc:
            raise APIException(str(exc))
class StopABTestView(GenericAPIView):
    # Finishes an A/B test: compares the accuracy of its two algorithms over
    # the test period, gives the more accurate one the "production" status and
    # the other one the "testing" status, and stores a summary on the test.
    #
    # Documentation is kept in comments rather than docstrings on the view and
    # its handler: DRF and drf-spectacular publish view docstrings as API
    # descriptions, which would change the generated output.
    serializer_class = ABTestSerializer
    queryset = ABTest.objects.all()

    @staticmethod
    def _accuracy(algorithm, started_at, ended_at):
        """Return ``(total, correct, accuracy)`` for one algorithm.

        ``total`` counts the ``MLRequest`` rows of ``algorithm`` created
        strictly between ``started_at`` and ``ended_at``; ``correct`` counts
        those whose ``response`` equals their ``feedback``. ``accuracy`` is
        ``correct / total``, or ``0.0`` when there were no requests.
        """
        requests = MLRequest.objects.filter(
            parent_mlalgorithm=algorithm,
            created_at__gt=started_at,
            created_at__lt=ended_at,
        )
        total = requests.count()
        correct = requests.filter(response=F('feedback')).count()
        # No requests means no evidence: report 0.0 rather than dividing by 0.
        accuracy = correct / total if total else 0.0
        return total, correct, accuracy

    @staticmethod
    def _activate_status(status_name, algorithm, created_by):
        """Save a new active status for ``algorithm`` and deactivate older ones."""
        new_status = MLAlgorithmStatus(
            status=status_name,
            created_by=created_by,
            parent_mlalgorithm=algorithm,
            active=True,
        )
        new_status.save()
        deactivate_other_statuses(new_status)

    def post(self, request, ab_test_id: int, format=None):
        # Responses:
        #   200 {"message": "AB Test already finished."} when ``ended_at`` is set
        #   200 {"message": "AB Test finished.", "summary": ...} on success
        #   400 {"status": "Error", "message": ...} on any exception
        #       (including an unknown ``ab_test_id``)
        #
        # Function-scope imports keep this block self-contained.
        import logging
        from django.utils import timezone

        logger = logging.getLogger(__name__)
        try:
            ab_test = ABTest.objects.get(pk=ab_test_id)
            if ab_test.ended_at is not None:
                return Response({"message": "AB Test already finished."})

            # timezone.now() is aware or naive according to USE_TZ, so it is
            # comparable with the stored ``created_at`` values either way.
            date_now = timezone.now()

            all_responses_1, correct_responses_1, accuracy_1 = self._accuracy(
                ab_test.parent_mlalgorithm_1, ab_test.created_at, date_now)
            logger.debug("AB test %s alg #1: total=%s correct=%s accuracy=%s",
                         ab_test_id, all_responses_1, correct_responses_1, accuracy_1)

            all_responses_2, correct_responses_2, accuracy_2 = self._accuracy(
                ab_test.parent_mlalgorithm_2, ab_test.created_at, date_now)
            logger.debug("AB test %s alg #2: total=%s correct=%s accuracy=%s",
                         ab_test_id, all_responses_2, correct_responses_2, accuracy_2)

            # The more accurate algorithm wins; on a tie algorithm #1 wins.
            winner, loser = ab_test.parent_mlalgorithm_1, ab_test.parent_mlalgorithm_2
            if accuracy_1 < accuracy_2:
                winner, loser = loser, winner

            summary = "Algorithm #1 accuracy: {}, Algorithm #2 accuracy: {}".format(accuracy_1, accuracy_2)

            # All writes succeed or fail together, matching
            # ABTestViewSet.perform_create.
            with transaction.atomic():
                self._activate_status("production", winner, ab_test.created_by)
                self._activate_status("testing", loser, ab_test.created_by)
                ab_test.ended_at = date_now
                ab_test.summary = summary
                ab_test.save()
        except Exception as e:
            return Response({"status": "Error", "message": str(e)},
                            status=status.HTTP_400_BAD_REQUEST
                            )
        return Response({"message": "AB Test finished.", "summary": summary})