Implement and test API endpoints for API models
diff --git a/api/tests.py b/api/tests.py
index 9f4bf12..5e13c78 100644
--- a/api/tests.py
+++ b/api/tests.py
@@ -13,8 +13,32 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+# =============================================================================
from django.test import TestCase
+from rest_framework.test import APIClient
-# Create your tests here.
+
+test_data = {
+ "Age": 22,
+ "Sex": "female",
+ "Job": 2,
+ "Housing": "own",
+ "Credit amount": 5951,
+ "Duration": 48,
+ "Purpose": "radio/TV"
+}
+
+expected_output = 'bad'
+
+class EndpointTests(TestCase):
+
+ def test_predict_view(self):
+ client = APIClient()
+
+ classifier_url = "/api/v1/credit_scoring/predict"
+ response = client.post(classifier_url, test_data, format='json')
+ self.assertEqual(response.status_code, 200)
+ self.assertEqual(response.data["label"], expected_output)
+ self.assertTrue("request_id" in response.data)
+ self.assertTrue("status" in response.data)
diff --git a/api/urls.py b/api/urls.py
index 3fb1471..cd4c90c 100644
--- a/api/urls.py
+++ b/api/urls.py
@@ -13,22 +13,28 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+# =============================================================================
"""
-Definition of urls for scorecardapp.
+Definition of urls for api resource.
"""
+from django.conf.urls import url
from django.urls import path, include
from rest_framework import routers
from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
-from api import views as apiViews
+
+from api.views import (ABTestViewSet, EndpointViewSet,
+ MLAlgorithmStatusViewSet, MLAlgorithmViewSet,
+ MLRequestViewSet, PredictView, StopABTestView)
-router = routers.DefaultRouter()
-router.register(r'users', apiViews.UserViewSet)
-router.register(r'groups', apiViews.GroupViewSet)
-router.register(r'scorecard', apiViews.ScorecardViewSet, basename='scorecard')
+router = routers.DefaultRouter(trailing_slash=False)
+router.register(r"endpoints", EndpointViewSet, basename="endpoints")
+router.register(r"mlalgorithms", MLAlgorithmViewSet, basename="mlalgorithms")
+router.register(r"mlalgorithmstatuses", MLAlgorithmStatusViewSet, basename="mlalgorithmstatuses")
+router.register(r"mlrequests", MLRequestViewSet, basename="mlrequests")
+router.register(r"abtests", ABTestViewSet, basename="abtests")
urlpatterns = [
@@ -39,8 +45,10 @@
path('api/docs/redoc', SpectacularRedocView.as_view(url_name='schema'), name='redoc'),
# API Views
- # Wire up our API using automatic URL routing.
- # Additionally, we include login URLs for the browsable API.
path('api/', include(router.urls)),
- path('api-auth/', include('rest_framework.urls', namespace='rest_framework'))
+ path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),
+
+ url(r"^api/v1/(?P<endpoint_name>.+)/predict$", PredictView.as_view(), name="predict"),
+ url(r"^api/v1/stop_ab_test/(?P<ab_test_id>.+)", StopABTestView.as_view(), name="stop_ab"),
]
+
\ No newline at end of file
diff --git a/api/views.py b/api/views.py
index fbdb921..6668229 100644
--- a/api/views.py
+++ b/api/views.py
@@ -13,61 +13,251 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
-#
+# =============================================================================
-from api.models import GermanDataModel
-from api.serializers import GermanDataModelSerializer, GroupSerializer, UserSerializer
+import json
+import random
+import datetime
+import pandas as pd
+from rest_framework.generics import GenericAPIView
+
+from server.wsgi import registry
+
+from api.models import (ABTest, Endpoint, GermanDataModel,
+ MLAlgorithm, MLAlgorithmStatus, MLRequest)
+from api.serializers import (ABTestSerializer, EndpointSerializer,
+ GermanDataModelSerializer, GroupSerializer,
+ MLAlgorithmSerializer, MLAlgorithmStatusSerializer,
+ MLRequestSerializer, UserSerializer)
+
from django.contrib.auth.models import Group, User
-from rest_framework.decorators import action
-from rest_framework.response import Response
-from rest_framework import viewsets, permissions, status
+from django.db import transaction
+from django.db.models import F
from django.shortcuts import get_object_or_404, render
+from rest_framework import viewsets, permissions, status, mixins, views
+from rest_framework.decorators import action
+from rest_framework.exceptions import APIException
+from rest_framework.response import Response
+
+from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample
+from drf_spectacular.types import OpenApiTypes
+
# Create your views here.
-class UserViewSet(viewsets.ModelViewSet):
- """
- API endpoint that allows users to be viewed or edited.
- """
- queryset = User.objects.all().order_by('-date_joined')
- serializer_class = UserSerializer
- permission_classes = []
+class EndpointViewSet(viewsets.ReadOnlyModelViewSet):
+ serializer_class = EndpointSerializer
+ queryset = Endpoint.objects.all()
-class GroupViewSet(viewsets.ModelViewSet):
- """
- API endpoint that allows groups to be viewed or edited.
- """
- queryset = Group.objects.all()
- serializer_class = GroupSerializer
- permission_classes = [permissions.IsAuthenticated]
+class MLAlgorithmViewSet(viewsets.ReadOnlyModelViewSet):
+ serializer_class = MLAlgorithmSerializer
+ queryset = MLAlgorithm.objects.all()
-class ScorecardViewSet(viewsets.ViewSet):
- """
- List all german data, or create a new GermanDataModel.
- """
- def list(self, request):
- queryset = GermanDataModel.objects.all()
- serializer = GermanDataModelSerializer(queryset, many=True)
- return Response(serializer.data)
- def retrieve(self, request, pk=None):
- queryset = GermanDataModel.objects.all()
- user = get_object_or_404(queryset, pk=pk)
- serializer = GermanDataModelSerializer(user)
- return Response(serializer.data)
+def deactivate_other_statuses(instance):
+ old_statuses = MLAlgorithmStatus.objects.filter(parent_mlalgorithm = instance.parent_mlalgorithm,
+ created_at__lt=instance.created_at,
+ active=True)
+ for i in range(len(old_statuses)):
+ old_statuses[i].active = False
+ MLAlgorithmStatus.objects.bulk_update(old_statuses, ["active"])
+
+class MLAlgorithmStatusViewSet(
+ mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+ mixins.CreateModelMixin
+):
+ serializer_class = MLAlgorithmStatusSerializer
+ queryset = MLAlgorithmStatus.objects.all()
+ def perform_create(self, serializer):
+ try:
+ with transaction.atomic():
+ instance = serializer.save(active=True)
+ # set active=False for other statuses
+ deactivate_other_statuses(instance)
+
+
+
+ except Exception as e:
+ raise APIException(str(e))
+
+class MLRequestViewSet(
+ mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+ mixins.UpdateModelMixin
+):
+ serializer_class = MLRequestSerializer
+ queryset = MLRequest.objects.all()
+
+class PredictView(GenericAPIView):
+ queryset = GermanDataModel.objects.all()
+ serializer_class = GermanDataModelSerializer
+
+ @extend_schema(
+ description='Predict credit risk for a loan',
+ parameters=[
+ OpenApiParameter(
+ name='classifier',
+ description='Name of algorithm classifier',
+ required=False,
+ type=str,
+ examples=[
+ OpenApiExample(
+ 'Example 1',
+ summary='Random Forest Classifier',
+ value='random_forest'
+ ),
+ ],
+ ),
+ OpenApiParameter(
+ name='status',
+ description='The status of the algorithm',
+ required=False,
+ type=str,
+ examples=[
+ OpenApiExample(
+ 'Example 1',
+ summary='Algorithm in production',
+ value='production'
+ ),
+ ],
+ ),
+ OpenApiParameter(name='version', description='Algorithm version', required=True, type=str),
+ ],
+ )
+ def post(self, request, endpoint_name: str, format=None):
+ algorithm_version = self.request.query_params.get("version")
+ algorithm_status = self.request.query_params.get("status", "production")
+ algorithm_classifier = self.request.query_params.get("classifier", "random_forest")
- # def get(self, request, format=None):
- # snippets = GermanDataModel.objects.all()
- # serializer = GermanDataModelSerializer(snippets, many=True)
- # return Response(serializer.data)
+ algs = MLAlgorithm.objects.filter(parent_endpoint__name = endpoint_name,
+ status__status = algorithm_status,
+ status__active = True)
- @action(detail=True, methods=['post'])
- def predict(self, request, format=None):
- data = self.get_object()
- queryset = GermanDataModel.objects.all()
- serializer = GermanDataModelSerializer(data=request.data)
- if serializer.is_valid():
- serializer.save()
- return Response(serializer.data, status=status.HTTP_201_CREATED)
- return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+
+ if algorithm_version is not None:
+ algs = algs.filter(version = algorithm_version)
+
+ num_algs = len(algs)
+
+ if num_algs == 0:
+ return Response(
+ {"status": "Error", "message": "ML algorithm is not available"},
+ status=status.HTTP_400_BAD_REQUEST)
+
+ if num_algs != 1 and algorithm_status != "ab_testing":
+ return Response(
+ {"status": "Error", "message": "ML algorithm selection is ambiguous. Please specify algorithm version."},
+ status=status.HTTP_400_BAD_REQUEST)
+
+ alg_index = 0
+ if algorithm_status == "ab_testing":
+ alg_index = random.randrange(num_algs)
+ else:
+ algs = algs.filter(parent_endpoint__classifier = algorithm_classifier)
+
+ algorithm = algs[alg_index]
+ algorithm_object = registry.endpoints[algorithm.id]
+ prediction = algorithm_object.compute_prediction(request.data)
+
+ label = prediction["label"] if "label" in prediction else "error"
+ ml_request = MLRequest(
+ input_data=json.dumps(request.data),
+ full_response=prediction,
+ response=label,
+ feedback="",
+ parent_mlalgorithm=algs[alg_index],
+ )
+ ml_request.save()
+
+ prediction["request_id"] = ml_request.id
+
+ return Response(prediction)
+
+
+
+class ABTestViewSet(
+ mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+ mixins.CreateModelMixin, mixins.UpdateModelMixin
+):
+ serializer_class = ABTestSerializer
+ queryset = ABTest.objects.all()
+
+ def perform_create(self, serializer):
+ try:
+ with transaction.atomic():
+ instance = serializer.save()
+ # update status for first algorithm
+
+ status_1 = MLAlgorithmStatus(status = "ab_testing",
+ created_by=instance.created_by,
+ parent_mlalgorithm = instance.parent_mlalgorithm_1,
+ active=True)
+ status_1.save()
+ deactivate_other_statuses(status_1)
+ # update status for second algorithm
+ status_2 = MLAlgorithmStatus(status = "ab_testing",
+ created_by=instance.created_by,
+ parent_mlalgorithm = instance.parent_mlalgorithm_2,
+ active=True)
+ status_2.save()
+ deactivate_other_statuses(status_2)
+
+ except Exception as e:
+ raise APIException(str(e))
+
+class StopABTestView(GenericAPIView):
+ serializer_class = ABTestSerializer
+ queryset = ABTest.objects.all()
+
+ def post(self, request, ab_test_id: int, format=None):
+
+ try:
+ ab_test = ABTest.objects.get(pk=ab_test_id)
+
+ if ab_test.ended_at is not None:
+ return Response({"message": "AB Test already finished."})
+
+ date_now = datetime.datetime.now()
+ # alg #1 accuracy
+ all_responses_1 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_1, created_at__gt = ab_test.created_at, created_at__lt = date_now).count()
+ correct_responses_1 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_1, created_at__gt = ab_test.created_at, created_at__lt = date_now, response=F('feedback')).count()
+ accuracy_1 = correct_responses_1 / float(all_responses_1)
+ print(all_responses_1, correct_responses_1, accuracy_1)
+
+ # alg #2 accuracy
+ all_responses_2 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_2, created_at__gt = ab_test.created_at, created_at__lt = date_now).count()
+ correct_responses_2 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_2, created_at__gt = ab_test.created_at, created_at__lt = date_now, response=F('feedback')).count()
+ accuracy_2 = correct_responses_2 / float(all_responses_2)
+ print(all_responses_2, correct_responses_2, accuracy_2)
+
+ # select algorithm with higher accuracy
+ alg_id_1, alg_id_2 = ab_test.parent_mlalgorithm_1, ab_test.parent_mlalgorithm_2
+ # swap
+ if accuracy_1 < accuracy_2:
+ alg_id_1, alg_id_2 = alg_id_2, alg_id_1
+
+ status_1 = MLAlgorithmStatus(status = "production",
+ created_by=ab_test.created_by,
+ parent_mlalgorithm = alg_id_1,
+ active=True)
+ status_1.save()
+ deactivate_other_statuses(status_1)
+ # update status for second algorithm
+ status_2 = MLAlgorithmStatus(status = "testing",
+ created_by=ab_test.created_by,
+ parent_mlalgorithm = alg_id_2,
+ active=True)
+ status_2.save()
+ deactivate_other_statuses(status_2)
+
+
+ summary = "Algorithm #1 accuracy: {}, Algorithm #2 accuracy: {}".format(accuracy_1, accuracy_2)
+ ab_test.ended_at = date_now
+ ab_test.summary = summary
+ ab_test.save()
+
+ except Exception as e:
+ return Response({"status": "Error", "message": str(e)},
+ status=status.HTTP_400_BAD_REQUEST
+ )
+ return Response({"message": "AB Test finished.", "summary": summary})