| # |
| # Licensed to the Apache Software Foundation (ASF) under one or more |
| # contributor license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright ownership. |
| # The ASF licenses this file to You under the Apache License, Version 2.0 |
| # (the "License"); you may not use this file except in compliance with |
| # the License. You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, software |
| # distributed under the License is distributed on an "AS IS" BASIS, |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| # See the License for the specific language governing permissions and |
| # limitations under the License. |
| # ============================================================================= |
| |
| import json |
| import random |
| import datetime |
| import pandas as pd |
| from rest_framework.generics import GenericAPIView |
| |
| from server.wsgi import registry |
| |
| from api.models import (ABTest, Endpoint, GermanDataModel, |
| MLAlgorithm, MLAlgorithmStatus, MLRequest) |
| from api.serializers import (ABTestSerializer, EndpointSerializer, |
| GermanDataModelSerializer, GroupSerializer, |
| MLAlgorithmSerializer, MLAlgorithmStatusSerializer, |
| MLRequestSerializer, UserSerializer) |
| |
| from django.contrib.auth.models import Group, User |
| from django.db import transaction |
| from django.db.models import F |
| from django.shortcuts import get_object_or_404, render |
| |
| from rest_framework import viewsets, permissions, status, mixins, views |
| from rest_framework.decorators import action |
| from rest_framework.exceptions import APIException |
| from rest_framework.response import Response |
| |
| from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample |
| from drf_spectacular.types import OpenApiTypes |
| |
| # Create your views here. |
| |
class EndpointViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only API (list and retrieve) over every ``Endpoint`` record."""

    queryset = Endpoint.objects.all()
    serializer_class = EndpointSerializer
| |
| |
class MLAlgorithmViewSet(viewsets.ReadOnlyModelViewSet):
    """Read-only API (list and retrieve) over every ``MLAlgorithm`` record."""

    queryset = MLAlgorithm.objects.all()
    serializer_class = MLAlgorithmSerializer
| |
| |
def deactivate_other_statuses(instance):
    """Mark the older active statuses of the same algorithm as inactive.

    Every ``MLAlgorithmStatus`` that belongs to ``instance.parent_mlalgorithm``,
    was created strictly before ``instance.created_at`` and is still active is
    set to ``active=False``. ``instance`` itself is not modified. Statuses with
    exactly the same ``created_at`` as ``instance`` are left untouched because
    the comparison is strict.

    Args:
        instance: The newly saved ``MLAlgorithmStatus`` that should remain the
            active status for its algorithm.

    Returns:
        None
    """
    # One UPDATE statement: no rows are loaded into Python and there is no
    # window between reading the rows and writing them back (the previous
    # index loop + bulk_update did both).
    MLAlgorithmStatus.objects.filter(
        parent_mlalgorithm=instance.parent_mlalgorithm,
        created_at__lt=instance.created_at,
        active=True,
    ).update(active=False)
| |
class MLAlgorithmStatusViewSet(
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
    mixins.CreateModelMixin,
):
    """List, retrieve and create ``MLAlgorithmStatus`` records.

    A newly created status is saved with ``active=True``. Older active statuses
    of the same algorithm are then deactivated inside the same transaction.
    """

    serializer_class = MLAlgorithmStatusSerializer
    queryset = MLAlgorithmStatus.objects.all()

    def perform_create(self, serializer):
        """Save the new status as active and deactivate the previous ones.

        Raises:
            APIException: If saving or deactivation fails. The transaction is
                rolled back. DRF ``APIException`` subclasses raised inside are
                propagated unchanged so their HTTP status code is preserved.
        """
        try:
            with transaction.atomic():
                instance = serializer.save(active=True)
                # set active=False for other statuses
                deactivate_other_statuses(instance)
        except APIException:
            # Already a proper API error (e.g. a 4xx); wrapping it in a bare
            # APIException would turn it into a 500.
            raise
        except Exception as e:
            # Chain the cause so the original traceback survives in logs.
            raise APIException(str(e)) from e
| |
class MLRequestViewSet(
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
    mixins.UpdateModelMixin,
):
    """List, retrieve and update ``MLRequest`` records (no create or delete).

    NOTE(review): update is presumably how clients attach ``feedback`` to a
    logged prediction; confirm against the serializer's writable fields.
    """

    queryset = MLRequest.objects.all()
    serializer_class = MLRequestSerializer
| |
class PredictView(GenericAPIView):
    """Serve a prediction for an endpoint and log it as an ``MLRequest``.

    NOTE(review): ``queryset``/``serializer_class`` are not used by ``post``
    (it reads ``request.data`` directly). They are presumably kept so the
    browsable API / schema shows the expected input; confirm.
    """

    queryset = GermanDataModel.objects.all()
    serializer_class = GermanDataModelSerializer

    @extend_schema(
        description='Predict credit risk for a loan',
        parameters=[
            OpenApiParameter(
                name='classifier',
                description='Name of algorithm classifier',
                required=False,
                type=str,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        summary='Random Forest Classifier',
                        value='random_forest'
                    ),
                ],
            ),
            OpenApiParameter(
                name='status',
                description='The status of the algorithm',
                required=False,
                type=str,
                examples=[
                    OpenApiExample(
                        'Example 1',
                        summary='Algorithm in production',
                        value='production'
                    ),
                ],
            ),
            # The view treats ``version`` as optional (it is only needed to
            # disambiguate), so the schema must not advertise it as required.
            OpenApiParameter(name='version', description='Algorithm version', required=False, type=str),
        ],
    )
    def post(self, request, endpoint_name: str, format=None):
        """Run the selected algorithm on ``request.data`` and return its output.

        Query parameters:
            version: Optional algorithm version filter.
            status: Status the algorithm's active status must have
                (default ``"production"``).
            classifier: Endpoint classifier filter (default
                ``"random_forest"``). It is ignored when ``status`` is
                ``"ab_testing"``.

        Selection:
            * ``status == "ab_testing"``: one of the matching algorithms is
              chosen uniformly at random.
            * otherwise: exactly one algorithm must match, after applying the
              version and classifier filters.

        Returns:
            ``Response`` with the prediction dict plus ``request_id`` (the id
            of the saved ``MLRequest``), or a 400 response when no algorithm
            matches or the selection is ambiguous.
        """
        algorithm_version = request.query_params.get("version")
        algorithm_status = request.query_params.get("status", "production")
        algorithm_classifier = request.query_params.get("classifier", "random_forest")
        is_ab_testing = algorithm_status == "ab_testing"

        algs = MLAlgorithm.objects.filter(
            parent_endpoint__name=endpoint_name,
            status__status=algorithm_status,
            status__active=True,
        )
        if algorithm_version is not None:
            algs = algs.filter(version=algorithm_version)
        if not is_ab_testing:
            # Filter *before* counting: otherwise a non-matching classifier
            # made ``algs[0]`` raise IndexError (HTTP 500), and the classifier
            # could never resolve an ambiguous selection.
            algs = algs.filter(parent_endpoint__classifier=algorithm_classifier)

        # Evaluate once. Indexing the lazy queryset repeatedly would issue
        # separate, unordered queries that need not return the same row.
        candidates = list(algs)

        if not candidates:
            return Response(
                {"status": "Error", "message": "ML algorithm is not available"},
                status=status.HTTP_400_BAD_REQUEST)

        if len(candidates) != 1 and not is_ab_testing:
            return Response(
                {"status": "Error", "message": "ML algorithm selection is ambiguous. Please specify algorithm version."},
                status=status.HTTP_400_BAD_REQUEST)

        algorithm = random.choice(candidates) if is_ab_testing else candidates[0]

        algorithm_object = registry.endpoints[algorithm.id]
        prediction = algorithm_object.compute_prediction(request.data)

        label = prediction.get("label", "error")
        ml_request = MLRequest.objects.create(
            input_data=json.dumps(request.data),
            full_response=prediction,
            response=label,
            feedback="",
            parent_mlalgorithm=algorithm,
        )

        prediction["request_id"] = ml_request.id
        return Response(prediction)
| |
| |
| |
class ABTestViewSet(
    mixins.RetrieveModelMixin,
    mixins.ListModelMixin,
    viewsets.GenericViewSet,
    mixins.CreateModelMixin,
    mixins.UpdateModelMixin,
):
    """List, retrieve, create and update ``ABTest`` records.

    Creating a test also gives both compared algorithms a new active
    ``"ab_testing"`` status, deactivating their older statuses.
    """

    serializer_class = ABTestSerializer
    queryset = ABTest.objects.all()

    def perform_create(self, serializer):
        """Save the A/B test and switch both algorithms to ``"ab_testing"``.

        Everything runs in one transaction.

        Raises:
            APIException: If any step fails; the transaction is rolled back.
                DRF ``APIException`` subclasses raised inside are propagated
                unchanged so their HTTP status code is preserved.
        """
        try:
            with transaction.atomic():
                instance = serializer.save()
                # Same order as before: algorithm #1 first, then #2; each new
                # status is saved and then older ones are deactivated.
                for algorithm in (instance.parent_mlalgorithm_1, instance.parent_mlalgorithm_2):
                    new_status = MLAlgorithmStatus.objects.create(
                        status="ab_testing",
                        created_by=instance.created_by,
                        parent_mlalgorithm=algorithm,
                        active=True,
                    )
                    deactivate_other_statuses(new_status)
        except APIException:
            # Keep proper API errors (e.g. 4xx) instead of turning them into 500s.
            raise
        except Exception as e:
            # Chain the cause so the original traceback survives in logs.
            raise APIException(str(e)) from e
| |
class StopABTestView(GenericAPIView):
    """Stop a running A/B test and promote the more accurate algorithm.

    An algorithm's accuracy is the share of its ``MLRequest`` rows, created
    strictly between the test's ``created_at`` and now, whose ``response``
    equals their ``feedback``. The winner (algorithm #1 on a tie) gets an
    active ``"production"`` status and the other an active ``"testing"``
    status. The test is then closed with ``ended_at`` and a ``summary``.
    """

    serializer_class = ABTestSerializer
    queryset = ABTest.objects.all()

    @staticmethod
    def _count_responses(algorithm, started_at, ended_at):
        """Return ``(total, correct)`` request counts for ``algorithm``.

        Both time bounds are exclusive. A request is correct when its
        ``response`` equals its ``feedback``.
        """
        requests = MLRequest.objects.filter(
            parent_mlalgorithm=algorithm,
            created_at__gt=started_at,
            created_at__lt=ended_at,
        )
        return requests.count(), requests.filter(response=F('feedback')).count()

    @staticmethod
    def _activate_status(algorithm, status_name, created_by):
        """Create an active ``status_name`` status for ``algorithm``; retire older ones."""
        new_status = MLAlgorithmStatus.objects.create(
            status=status_name,
            created_by=created_by,
            parent_mlalgorithm=algorithm,
            active=True,
        )
        deactivate_other_statuses(new_status)

    def post(self, request, ab_test_id: int, format=None):
        """Finish A/B test ``ab_test_id`` and return its accuracy summary.

        Returns:
            * 404 if the test does not exist.
            * 200 with "AB Test already finished." if it was already stopped.
            * 400 if either algorithm has no requests in the test window, so
              its accuracy is undefined, or if any other step fails. No
              changes are committed in that case.
            * 200 with the summary on success.
        """
        import logging

        # timezone.now() is aware when USE_TZ is on and naive otherwise, so it
        # compares correctly with model datetimes either way (datetime.now()
        # is always naive).
        from django.utils import timezone

        logger = logging.getLogger(__name__)

        # Resolved outside the broad ``try`` so a missing test yields 404
        # rather than a generic 400.
        ab_test = get_object_or_404(ABTest, pk=ab_test_id)
        if ab_test.ended_at is not None:
            return Response({"message": "AB Test already finished."})

        try:
            # Status changes and test closure commit together or not at all.
            with transaction.atomic():
                date_now = timezone.now()
                algorithm_1 = ab_test.parent_mlalgorithm_1
                algorithm_2 = ab_test.parent_mlalgorithm_2

                all_responses_1, correct_responses_1 = self._count_responses(
                    algorithm_1, ab_test.created_at, date_now)
                all_responses_2, correct_responses_2 = self._count_responses(
                    algorithm_2, ab_test.created_at, date_now)

                if all_responses_1 == 0 or all_responses_2 == 0:
                    empty = 1 if all_responses_1 == 0 else 2
                    return Response(
                        {"status": "Error",
                         "message": "Algorithm #{} has no requests in this AB Test; "
                                    "accuracy cannot be computed.".format(empty)},
                        status=status.HTTP_400_BAD_REQUEST,
                    )

                accuracy_1 = correct_responses_1 / all_responses_1
                accuracy_2 = correct_responses_2 / all_responses_2
                logger.debug("AB test %s: alg #1 %s/%s=%s, alg #2 %s/%s=%s", ab_test_id,
                             correct_responses_1, all_responses_1, accuracy_1,
                             correct_responses_2, all_responses_2, accuracy_2)

                # Algorithm #1 wins ties, as before.
                winner, loser = algorithm_1, algorithm_2
                if accuracy_1 < accuracy_2:
                    winner, loser = loser, winner

                self._activate_status(winner, "production", ab_test.created_by)
                self._activate_status(loser, "testing", ab_test.created_by)

                summary = "Algorithm #1 accuracy: {}, Algorithm #2 accuracy: {}".format(accuracy_1, accuracy_2)
                ab_test.ended_at = date_now
                ab_test.summary = summary
                ab_test.save()
        except Exception as e:
            logger.exception("Failed to stop AB test %s", ab_test_id)
            return Response({"status": "Error", "message": str(e)},
                            status=status.HTTP_400_BAD_REQUEST)
        return Response({"message": "AB Test finished.", "summary": summary})