Implement and test API endpoints for API models
diff --git a/api/tests.py b/api/tests.py
index 9f4bf12..5e13c78 100644
--- a/api/tests.py
+++ b/api/tests.py
@@ -13,8 +13,32 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+# =============================================================================
 
 from django.test import TestCase
+from rest_framework.test import APIClient
 
-# Create your tests here.
+
+test_data = {
+    "Age": 22,
+    "Sex": "female",
+    "Job": 2,
+    "Housing": "own",
+    "Credit amount": 5951,
+    "Duration": 48,
+    "Purpose": "radio/TV"
+}
+
+expected_output = 'bad'
+
+class EndpointTests(TestCase):
+
+    def test_predict_view(self):
+        client = APIClient()
+        
+        classifier_url = "/api/v1/credit_scoring/predict"
+        response = client.post(classifier_url, test_data, format='json')
+        self.assertEqual(response.status_code, 200)
+        self.assertEqual(response.data["label"], expected_output)
+        self.assertTrue("request_id" in response.data)
+        self.assertTrue("status" in response.data)
diff --git a/api/urls.py b/api/urls.py
index 3fb1471..cd4c90c 100644
--- a/api/urls.py
+++ b/api/urls.py
@@ -13,22 +13,28 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+# =============================================================================
 
 """
-Definition of urls for scorecardapp.
+Definition of urls for api resource.
 """
 
+from django.conf.urls import url
 from django.urls import path, include
 from rest_framework import routers
 from drf_spectacular.views import SpectacularAPIView, SpectacularRedocView, SpectacularSwaggerView
-from api import views as apiViews
+
+from api.views import (ABTestViewSet, EndpointViewSet,
+                       MLAlgorithmStatusViewSet, MLAlgorithmViewSet, 
+                       MLRequestViewSet, PredictView, StopABTestView)
 
 
-router = routers.DefaultRouter()
-router.register(r'users', apiViews.UserViewSet)
-router.register(r'groups', apiViews.GroupViewSet)
-router.register(r'scorecard', apiViews.ScorecardViewSet, basename='scorecard')
+router = routers.DefaultRouter(trailing_slash=False)
+router.register(r"endpoints", EndpointViewSet, basename="endpoints")
+router.register(r"mlalgorithms", MLAlgorithmViewSet, basename="mlalgorithms")
+router.register(r"mlalgorithmstatuses", MLAlgorithmStatusViewSet, basename="mlalgorithmstatuses")
+router.register(r"mlrequests", MLRequestViewSet, basename="mlrequests")
+router.register(r"abtests", ABTestViewSet, basename="abtests")
 
 
 urlpatterns = [
@@ -39,8 +45,10 @@
     path('api/docs/redoc', SpectacularRedocView.as_view(url_name='schema'), name='redoc'),
 
     # API Views
-    # Wire up our API using automatic URL routing.
-    # Additionally, we include login URLs for the browsable API.
     path('api/', include(router.urls)),
-    path('api-auth/', include('rest_framework.urls', namespace='rest_framework'))
+    path('api-auth/', include('rest_framework.urls', namespace='rest_framework')),
+
+    url(r"^api/v1/(?P<endpoint_name>.+)/predict$", PredictView.as_view(), name="predict"),
+    url(r"^api/v1/stop_ab_test/(?P<ab_test_id>.+)", StopABTestView.as_view(), name="stop_ab"),
 ]
+ 
\ No newline at end of file
diff --git a/api/views.py b/api/views.py
index fbdb921..6668229 100644
--- a/api/views.py
+++ b/api/views.py
@@ -13,61 +13,251 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
-#
+# =============================================================================
 
-from api.models import GermanDataModel
-from api.serializers import GermanDataModelSerializer, GroupSerializer, UserSerializer
+import json
+import random
+import datetime
+import pandas as pd
+from rest_framework.generics import GenericAPIView
+
+from server.wsgi import registry
+
+from api.models import (ABTest, Endpoint, GermanDataModel,
+                        MLAlgorithm, MLAlgorithmStatus, MLRequest)
+from api.serializers import (ABTestSerializer, EndpointSerializer,
+                             GermanDataModelSerializer, GroupSerializer,
+                             MLAlgorithmSerializer, MLAlgorithmStatusSerializer,
+                             MLRequestSerializer, UserSerializer)
+
 from django.contrib.auth.models import Group, User
-from rest_framework.decorators import action
-from rest_framework.response import Response
-from rest_framework import viewsets, permissions, status
+from django.db import transaction
+from django.db.models import F
 from django.shortcuts import get_object_or_404, render
 
+from rest_framework import viewsets, permissions, status, mixins, views
+from rest_framework.decorators import action
+from rest_framework.exceptions import APIException
+from rest_framework.response import Response
+
+from drf_spectacular.utils import extend_schema, OpenApiParameter, OpenApiExample
+from drf_spectacular.types import OpenApiTypes
+
 # Create your views here.
 
-class UserViewSet(viewsets.ModelViewSet):
-    """
-    API endpoint that allows users to be viewed or edited.
-    """
-    queryset = User.objects.all().order_by('-date_joined')
-    serializer_class = UserSerializer
-    permission_classes = []
+class EndpointViewSet(viewsets.ReadOnlyModelViewSet):
+    serializer_class = EndpointSerializer
+    queryset = Endpoint.objects.all()
 
 
-class GroupViewSet(viewsets.ModelViewSet):
-    """
-    API endpoint that allows groups to be viewed or edited.
-    """
-    queryset = Group.objects.all()
-    serializer_class = GroupSerializer
-    permission_classes = [permissions.IsAuthenticated]
+class MLAlgorithmViewSet(viewsets.ReadOnlyModelViewSet):
+    serializer_class = MLAlgorithmSerializer
+    queryset = MLAlgorithm.objects.all()
 
-class ScorecardViewSet(viewsets.ViewSet):
-    """
-    List all german data, or create a new GermanDataModel.
-    """
-    def list(self, request):
-        queryset = GermanDataModel.objects.all()
-        serializer = GermanDataModelSerializer(queryset, many=True)
-        return Response(serializer.data)
 
-    def retrieve(self, request, pk=None):
-        queryset = GermanDataModel.objects.all()
-        user = get_object_or_404(queryset, pk=pk)
-        serializer = GermanDataModelSerializer(user)
-        return Response(serializer.data)
+def deactivate_other_statuses(instance):
+    old_statuses = MLAlgorithmStatus.objects.filter(parent_mlalgorithm = instance.parent_mlalgorithm,
+                                                        created_at__lt=instance.created_at,
+                                                        active=True)
+    for i in range(len(old_statuses)):
+        old_statuses[i].active = False
+    MLAlgorithmStatus.objects.bulk_update(old_statuses, ["active"])
+
+class MLAlgorithmStatusViewSet(
+    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+    mixins.CreateModelMixin
+):
+    serializer_class = MLAlgorithmStatusSerializer
+    queryset = MLAlgorithmStatus.objects.all()
+    def perform_create(self, serializer):
+        try:
+            with transaction.atomic():
+                instance = serializer.save(active=True)
+                # set active=False for other statuses
+                deactivate_other_statuses(instance)
+
+
+
+        except Exception as e:
+            raise APIException(str(e))
+
+class MLRequestViewSet(
+    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+    mixins.UpdateModelMixin
+):
+    serializer_class = MLRequestSerializer
+    queryset = MLRequest.objects.all()
+
+class PredictView(GenericAPIView):
+    queryset = GermanDataModel.objects.all()
+    serializer_class = GermanDataModelSerializer
+    
+    @extend_schema(
+        description='Predict credit risk for a loan',        
+        parameters=[
+            OpenApiParameter(
+                name='classifier',
+                description='Name of algorithm classifier',
+                required=False,
+                type=str,
+                examples=[
+                    OpenApiExample(
+                        'Example 1',
+                        summary='Random Forest Classifier',
+                        value='random_forest'
+                    ),
+                ],
+            ),
+            OpenApiParameter(
+                name='status',
+                description='The status of the algorithm',
+                required=False,
+                type=str,
+                examples=[
+                    OpenApiExample(
+                        'Example 1',
+                        summary='Algorithm in production',
+                        value='production'
+                    ),
+                ],
+            ),
+            OpenApiParameter(name='version', description='Algorithm version', required=True, type=str),
+        ],
+    )
+    def post(self, request, endpoint_name: str, format=None):
+        algorithm_version = self.request.query_params.get("version")
+        algorithm_status = self.request.query_params.get("status", "production")
+        algorithm_classifier = self.request.query_params.get("classifier", "random_forest")
         
-    # def get(self, request, format=None):
-    #     snippets = GermanDataModel.objects.all()
-    #     serializer = GermanDataModelSerializer(snippets, many=True)
-    #     return Response(serializer.data)
+        algs = MLAlgorithm.objects.filter(parent_endpoint__name = endpoint_name,                                          
+                                          status__status = algorithm_status,
+                                          status__active = True)
 
-    @action(detail=True, methods=['post'])
-    def predict(self, request, format=None):
-        data = self.get_object()
-        queryset = GermanDataModel.objects.all()
-        serializer = GermanDataModelSerializer(data=request.data)
-        if serializer.is_valid():
-            serializer.save()
-            return Response(serializer.data, status=status.HTTP_201_CREATED)
-        return Response(serializer.errors, status=status.HTTP_400_BAD_REQUEST)
+        
+        if algorithm_version is not None:
+            algs = algs.filter(version = algorithm_version)
+
+        num_algs = len(algs)
+        
+        if num_algs == 0:
+            return Response(
+                {"status": "Error", "message": "ML algorithm is not available"},
+                status=status.HTTP_400_BAD_REQUEST)
+            
+        if num_algs != 1 and algorithm_status != "ab_testing":
+            return Response(
+                {"status": "Error", "message": "ML algorithm selection is ambiguous. Please specify algorithm version."},
+                status=status.HTTP_400_BAD_REQUEST)
+            
+        alg_index = 0
+        if algorithm_status == "ab_testing":
+            alg_index = random.randrange(num_algs)
+        else:
+            algs = algs.filter(parent_endpoint__classifier = algorithm_classifier)
+
+        algorithm = algs[alg_index]
+        algorithm_object = registry.endpoints[algorithm.id]
+        prediction = algorithm_object.compute_prediction(request.data)
+
+        label = prediction["label"] if "label" in prediction else "error"
+        ml_request = MLRequest(
+            input_data=json.dumps(request.data),
+            full_response=prediction,
+            response=label,
+            feedback="",
+            parent_mlalgorithm=algs[alg_index],
+        )
+        ml_request.save()
+
+        prediction["request_id"] = ml_request.id
+
+        return Response(prediction)
+
+
+
+class ABTestViewSet(
+    mixins.RetrieveModelMixin, mixins.ListModelMixin, viewsets.GenericViewSet,
+    mixins.CreateModelMixin, mixins.UpdateModelMixin
+):
+    serializer_class = ABTestSerializer
+    queryset = ABTest.objects.all()
+
+    def perform_create(self, serializer):
+        try:
+            with transaction.atomic():
+                instance = serializer.save()
+                # update status for first algorithm
+
+                status_1 = MLAlgorithmStatus(status = "ab_testing",
+                                created_by=instance.created_by,
+                                parent_mlalgorithm = instance.parent_mlalgorithm_1,
+                                active=True)
+                status_1.save()
+                deactivate_other_statuses(status_1)
+                # update status for second algorithm
+                status_2 = MLAlgorithmStatus(status = "ab_testing",
+                                created_by=instance.created_by,
+                                parent_mlalgorithm = instance.parent_mlalgorithm_2,
+                                active=True)
+                status_2.save()
+                deactivate_other_statuses(status_2)
+
+        except Exception as e:
+            raise APIException(str(e))
+
+class StopABTestView(GenericAPIView):
+    serializer_class = ABTestSerializer
+    queryset = ABTest.objects.all()
+
+    def post(self, request, ab_test_id: int, format=None):
+
+        try:
+            ab_test = ABTest.objects.get(pk=ab_test_id)
+
+            if ab_test.ended_at is not None:
+                return Response({"message": "AB Test already finished."})
+
+            date_now = datetime.datetime.now()
+            # alg #1 accuracy
+            all_responses_1 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_1, created_at__gt = ab_test.created_at, created_at__lt = date_now).count()
+            correct_responses_1 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_1, created_at__gt = ab_test.created_at, created_at__lt = date_now, response=F('feedback')).count()
+            accuracy_1 = correct_responses_1 / float(all_responses_1)
+            print(all_responses_1, correct_responses_1, accuracy_1)
+
+            # alg #2 accuracy
+            all_responses_2 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_2, created_at__gt = ab_test.created_at, created_at__lt = date_now).count()
+            correct_responses_2 = MLRequest.objects.filter(parent_mlalgorithm=ab_test.parent_mlalgorithm_2, created_at__gt = ab_test.created_at, created_at__lt = date_now, response=F('feedback')).count()
+            accuracy_2 = correct_responses_2 / float(all_responses_2)
+            print(all_responses_2, correct_responses_2, accuracy_2)
+
+            # select algorithm with higher accuracy
+            alg_id_1, alg_id_2 = ab_test.parent_mlalgorithm_1, ab_test.parent_mlalgorithm_2
+            # swap
+            if accuracy_1 < accuracy_2:
+                alg_id_1, alg_id_2 = alg_id_2, alg_id_1
+
+            status_1 = MLAlgorithmStatus(status = "production",
+                            created_by=ab_test.created_by,
+                            parent_mlalgorithm = alg_id_1,
+                            active=True)
+            status_1.save()
+            deactivate_other_statuses(status_1)
+            # update status for second algorithm
+            status_2 = MLAlgorithmStatus(status = "testing",
+                            created_by=ab_test.created_by,
+                            parent_mlalgorithm = alg_id_2,
+                            active=True)
+            status_2.save()
+            deactivate_other_statuses(status_2)
+
+
+            summary = "Algorithm #1 accuracy: {}, Algorithm #2 accuracy: {}".format(accuracy_1, accuracy_2)
+            ab_test.ended_at = date_now
+            ab_test.summary = summary
+            ab_test.save()
+
+        except Exception as e:
+            return Response({"status": "Error", "message": str(e)},
+                            status=status.HTTP_400_BAD_REQUEST
+            )
+        return Response({"message": "AB Test finished.", "summary": summary})