﻿// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements.  See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership.  The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License.  You may obtain a copy of the License at
//
//   http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied.  See the License for the
// specific language governing permissions and limitations
// under the License.

using Google.Protobuf;
using Grpc.Core;
using Org.Apache.REEF.Bridge.Core.Common.Driver;
using Org.Apache.REEF.Bridge.Core.Proto;
using Org.Apache.REEF.Common.Exceptions;
using Org.Apache.REEF.Driver.Evaluator;
using Org.Apache.REEF.Tang.Annotations;
using Org.Apache.REEF.Tang.Formats;
using Org.Apache.REEF.Tang.Implementations.Tang;
using Org.Apache.REEF.Tang.Interface;
using Org.Apache.REEF.Utilities;
using Org.Apache.REEF.Utilities.Logging;
using System;
using System.Collections.Generic;
using System.IO;
using System.Linq;
using System.Runtime.Serialization;

namespace Org.Apache.REEF.Bridge.Core.Grpc.Driver
{
    /// <summary>
    /// gRPC client that forwards .NET driver-client operations (registration, shutdown,
    /// alarms, evaluator/context/task operations) to the driver service listening on
    /// the loopback address 127.0.0.1.
    /// </summary>
    /// <remarks>
    /// NOTE(review): the gRPC channel opened in the constructor is never shut down;
    /// this class exposes no dispose hook to release it.
    /// </remarks>
    internal class DriverServiceClient : IDriverServiceClient
    {
        private static readonly Logger Logger = Logger.GetLogger(typeof(DriverServiceClient));

        private readonly IConfigurationSerializer _configurationSerializer;

        private readonly DriverService.DriverServiceClient _driverServiceStub;

        /// <summary>
        /// Creates the client and opens an insecure channel to the driver service on 127.0.0.1.
        /// </summary>
        /// <param name="configurationSerializer">Serializer used to turn Tang configurations into strings for the wire.</param>
        /// <param name="driverServicePort">Local port of the driver service.</param>
        [Inject]
        private DriverServiceClient(
            IConfigurationSerializer configurationSerializer,
            [Parameter(Value = typeof(DriverServicePort))] int driverServicePort)
        {
            _configurationSerializer = configurationSerializer;
            Logger.Log(Level.Info, "Binding to driver service at port {0}", driverServicePort);
            var driverServiceChannel = new Channel("127.0.0.1", driverServicePort, ChannelCredentials.Insecure);
            _driverServiceStub = new DriverService.DriverServiceClient(driverServiceChannel);
            Logger.Log(Level.Info, "Channel state {0}", driverServiceChannel.State);
        }

        /// <summary>
        /// Reports to the driver service that the driver client failed to start,
        /// sending the exception in serialized form instead of a host/port.
        /// </summary>
        /// <param name="exception">The failure to report.</param>
        public void RegisterDriverClientService(Exception exception)
        {
            Logger.Log(Level.Info, "Register driver client error", exception);
            var registration = new DriverClientRegistration
            {
                Exception = GrpcUtils.SerializeException(exception)
            };
            _driverServiceStub.RegisterDriverClient(registration);
        }

        /// <summary>
        /// Registers the endpoint at which this driver client can be reached.
        /// </summary>
        /// <param name="host">Host of the driver client service.</param>
        /// <param name="port">Port of the driver client service.</param>
        public void RegisterDriverClientService(string host, int port)
        {
            Logger.Log(Level.Info, "Register driver client at host {0} port {1}", host, port);
            var registration = new DriverClientRegistration
            {
                Host = host,
                Port = port
            };
            _driverServiceStub.RegisterDriverClient(registration);
        }

        /// <summary>
        /// Requests a clean (error-free) shutdown of the driver.
        /// </summary>
        public void OnShutdown()
        {
            Logger.Log(Level.Info, "Driver clean shutdown");
            _driverServiceStub.Shutdown(new ShutdownRequest());
        }

        /// <summary>
        /// Requests a driver shutdown caused by <paramref name="ex"/>. The exception is
        /// binary-serialized into the request; if it cannot be serialized, a
        /// <see cref="NonSerializableJobException"/> wrapping it is sent instead.
        /// </summary>
        /// <param name="ex">The exception that caused the shutdown.</param>
        public void OnShutdown(Exception ex)
        {
            Logger.Log(Level.Error, "Driver shutdown with error", ex);
            byte[] errorBytes;
            try
            {
                errorBytes = ByteUtilities.SerializeToBinaryFormat(ex);
            }
            catch (SerializationException se)
            {
                // Log the serialization failure itself; ex was already logged above.
                Logger.Log(Level.Warning, "Unable to serialize exception", se);
                errorBytes = ByteUtilities.SerializeToBinaryFormat(
                    NonSerializableJobException.UnableToSerialize(ex, se));
            }

            _driverServiceStub.Shutdown(new ShutdownRequest
            {
                Exception = new ExceptionInfo
                {
                    NoError = false,
                    // Protobuf string setters reject null, and a null here would abort the
                    // shutdown request. Exception.Source in particular can be null (e.g. for an
                    // exception that was never thrown) and is not the exception's name, so the
                    // type name is used instead.
                    Message = ex.Message ?? string.Empty,
                    Name = ex.GetType().FullName ?? string.Empty,
                    Data = ByteString.CopyFrom(errorBytes)
                }
            });
        }

        /// <summary>
        /// Asks the driver service to fire alarm <paramref name="alarmId"/> after the given timeout.
        /// </summary>
        /// <param name="alarmId">Identifier of the alarm.</param>
        /// <param name="timeoutMs">Timeout in milliseconds; must fit in a 32-bit signed integer.</param>
        /// <exception cref="ArgumentOutOfRangeException">
        /// <paramref name="timeoutMs"/> is outside the range of <see cref="int"/>.
        /// </exception>
        public void OnSetAlarm(string alarmId, long timeoutMs)
        {
            // The request field takes an int; an unchecked narrowing cast would silently wrap
            // an out-of-range timeout into an unrelated (possibly negative) value.
            if (timeoutMs > int.MaxValue || timeoutMs < int.MinValue)
            {
                throw new ArgumentOutOfRangeException(
                    nameof(timeoutMs),
                    timeoutMs,
                    "Alarm timeout must fit in a 32-bit signed integer (milliseconds).");
            }

            _driverServiceStub.SetAlarm(new AlarmRequest
            {
                AlarmId = alarmId,
                TimeoutMs = (int)timeoutMs
            });
        }

        /// <summary>
        /// Forwards an evaluator (resource) request to the driver service.
        /// </summary>
        /// <param name="evaluatorRequest">Requested evaluator count, cores, memory, locality and labels.</param>
        public void OnEvaluatorRequest(IEvaluatorRequest evaluatorRequest)
        {
            var request = new ResourceRequest
            {
                ResourceCount = evaluatorRequest.Number,
                Cores = evaluatorRequest.VirtualCore,
                MemorySize = evaluatorRequest.MemoryMegaBytes,
                RelaxLocality = evaluatorRequest.RelaxLocality,
                // Protobuf string setters reject null; map a missing value to the proto default.
                RuntimeName = evaluatorRequest.RuntimeName ?? string.Empty,
                NodeLabel = evaluatorRequest.NodeLabelExpression ?? string.Empty
            };
            if (!string.IsNullOrEmpty(evaluatorRequest.Rack))
            {
                request.RackNameList.Add(evaluatorRequest.Rack);
            }

            // RepeatedField.Add throws on a null collection or null elements; skip blank
            // names the same way a blank rack is skipped above.
            if (evaluatorRequest.NodeNames != null)
            {
                request.NodeNameList.Add(evaluatorRequest.NodeNames.Where(name => !string.IsNullOrEmpty(name)));
            }

            _driverServiceStub.RequestResources(request);
        }

        /// <summary>
        /// Asks the driver service to close the allocated evaluator with the given id.
        /// </summary>
        /// <param name="evalautorId">Id of the evaluator to close.</param>
        public void OnEvaluatorClose(string evalautorId)
        {
            _driverServiceStub.AllocatedEvaluatorOp(new AllocatedEvaluatorRequest
            {
                EvaluatorId = evalautorId,
                CloseEvaluator = true
            });
        }

        /// <summary>
        /// Submits an allocated evaluator as a .NET process with the given context, optional
        /// service and optional task configurations, plus extra files and libraries.
        /// </summary>
        /// <param name="evaluatorId">Id of the allocated evaluator.</param>
        /// <param name="contextConfiguration">Root context configuration.</param>
        /// <param name="serviceConfiguration">Optional service configuration; sent as "" when absent.</param>
        /// <param name="taskConfiguration">Optional task configuration; sent as "" when absent.</param>
        /// <param name="addFileList">Files to ship with the evaluator; null is treated as none.</param>
        /// <param name="addLibraryList">Libraries to ship with the evaluator; null is treated as none.</param>
        public void OnEvaluatorSubmit(
            string evaluatorId,
            IConfiguration contextConfiguration,
            Optional<IConfiguration> serviceConfiguration,
            Optional<IConfiguration> taskConfiguration,
            List<FileInfo> addFileList, List<FileInfo> addLibraryList)
        {
            Logger.Log(Level.Info, "Submitting allocated evaluator {0}", evaluatorId);

            // The evaluator-level configuration is always an empty Tang configuration.
            var evaluatorConf =
                _configurationSerializer.ToString(TangFactory.GetTang().NewConfigurationBuilder().Build());
            var contextConf = _configurationSerializer.ToString(contextConfiguration);
            var serviceConf = serviceConfiguration.IsPresent()
                ? _configurationSerializer.ToString(serviceConfiguration.Value)
                : string.Empty;
            var taskConf = taskConfiguration.IsPresent()
                ? _configurationSerializer.ToString(taskConfiguration.Value)
                : string.Empty;
            var request = new AllocatedEvaluatorRequest
            {
                EvaluatorId = evaluatorId,
                EvaluatorConfiguration = evaluatorConf,
                ServiceConfiguration = serviceConf,
                ContextConfiguration = contextConf,
                TaskConfiguration = taskConf,
                SetProcess = new AllocatedEvaluatorRequest.Types.EvaluatorProcessRequest
                {
                    ProcessType = AllocatedEvaluatorRequest.Types.EvaluatorProcessRequest.Types.Type.Dotnet
                }
            };

            // NOTE(review): FileInfo.ToString() returns the path the FileInfo was constructed
            // with, which may be relative; confirm the driver service resolves it as intended.
            if (addFileList != null)
            {
                request.AddFiles.Add(addFileList.Select(f => f.ToString()));
            }

            if (addLibraryList != null)
            {
                request.AddLibraries.Add(addLibraryList.Select(f => f.ToString()));
            }

            _driverServiceStub.AllocatedEvaluatorOp(request);
        }

        /// <summary>
        /// Asks the driver service to close the active context with the given id.
        /// </summary>
        /// <param name="contextId">Id of the context to close.</param>
        public void OnContextClose(string contextId)
        {
            Logger.Log(Level.Info, "Close context {0}", contextId);
            _driverServiceStub.ActiveContextOp(new ActiveContextRequest
            {
                ContextId = contextId,
                CloseContext = true
            });
        }

        /// <summary>
        /// Submits a new child context on top of the given active context.
        /// </summary>
        /// <param name="contextId">Id of the parent active context.</param>
        /// <param name="contextConfiguration">Configuration of the new context.</param>
        public void OnContextSubmitContext(string contextId, IConfiguration contextConfiguration)
        {
            _driverServiceStub.ActiveContextOp(new ActiveContextRequest
            {
                ContextId = contextId,
                NewContextRequest = _configurationSerializer.ToString(contextConfiguration)
            });
        }

        /// <summary>
        /// Submits a task on the given active context.
        /// </summary>
        /// <param name="contextId">Id of the active context.</param>
        /// <param name="taskConfiguration">Configuration of the task.</param>
        public void OnContextSubmitTask(string contextId, IConfiguration taskConfiguration)
        {
            _driverServiceStub.ActiveContextOp(new ActiveContextRequest
            {
                ContextId = contextId,
                NewTaskRequest = _configurationSerializer.ToString(taskConfiguration)
            });
        }

        /// <summary>
        /// Sends a message to the given active context.
        /// </summary>
        /// <param name="contextId">Id of the active context.</param>
        /// <param name="message">Message payload.</param>
        public void OnContextMessage(string contextId, byte[] message)
        {
            _driverServiceStub.ActiveContextOp(new ActiveContextRequest
            {
                ContextId = contextId,
                Message = ByteString.CopyFrom(message)
            });
        }

        /// <summary>
        /// Asks the running task to suspend, optionally attaching a message.
        /// </summary>
        /// <param name="taskId">Id of the running task.</param>
        /// <param name="message">Optional message delivered with the suspend request.</param>
        public void OnTaskSuspend(string taskId, Optional<byte[]> message)
        {
            SendRunningTaskOperation(taskId, RunningTaskRequest.Types.Operation.Suspend, message);
        }

        /// <summary>
        /// Asks the running task to close, optionally attaching a message.
        /// </summary>
        /// <param name="taskId">Id of the running task.</param>
        /// <param name="message">Optional message delivered with the close request.</param>
        public void OnTaskClose(string taskId, Optional<byte[]> message)
        {
            SendRunningTaskOperation(taskId, RunningTaskRequest.Types.Operation.Close, message);
        }

        /// <summary>
        /// Sends a message to the running task.
        /// </summary>
        /// <param name="taskId">Id of the running task.</param>
        /// <param name="message">Message payload.</param>
        public void OnTaskMessage(string taskId, byte[] message)
        {
            _driverServiceStub.RunningTaskOp(new RunningTaskRequest
            {
                TaskId = taskId,
                Message = ByteString.CopyFrom(message),
                Operation = RunningTaskRequest.Types.Operation.SendMessage
            });
        }

        // Builds and sends a running-task operation, attaching the message only when present.
        private void SendRunningTaskOperation(
            string taskId,
            RunningTaskRequest.Types.Operation operation,
            Optional<byte[]> message)
        {
            var op = new RunningTaskRequest
            {
                TaskId = taskId,
                Operation = operation
            };
            if (message.IsPresent())
            {
                op.Message = ByteString.CopyFrom(message.Value);
            }

            _driverServiceStub.RunningTaskOp(op);
        }
    }
}