/**
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.hadoop.yarn.ipc;

import java.io.Closeable;
import java.io.DataInput;
import java.io.DataOutput;
import java.io.IOException;
import java.lang.reflect.InvocationHandler;
import java.lang.reflect.Method;
import java.lang.reflect.Proxy;
import java.net.InetSocketAddress;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

import javax.net.SocketFactory;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.Writable;
import org.apache.hadoop.ipc.Client;
import org.apache.hadoop.ipc.ProtocolProxy;
import org.apache.hadoop.ipc.RPC;
import org.apache.hadoop.ipc.RpcEngine;
import org.apache.hadoop.ipc.ClientCache;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.apache.hadoop.yarn.exceptions.impl.pb.YarnRemoteExceptionPBImpl;
import org.apache.hadoop.yarn.ipc.RpcProtos.ProtoSpecificRpcRequest;
import org.apache.hadoop.yarn.ipc.RpcProtos.ProtoSpecificRpcResponse;

import com.google.protobuf.BlockingService;
import com.google.protobuf.Descriptors.MethodDescriptor;
import com.google.protobuf.Message;
import com.google.protobuf.ServiceException;
import org.apache.hadoop.classification.InterfaceAudience;
import org.apache.hadoop.classification.InterfaceStability;



@InterfaceStability.Evolving
public class ProtoOverHadoopRpcEngine implements RpcEngine {
  // NOTE(review): the logger is registered under RPC.class rather than this
  // class, so messages from this engine appear under RPC's log category --
  // presumably intentional; confirm before changing.
  private static final Log LOG = LogFactory.getLog(RPC.class);
  
  // Cache from which Invoker and getClient(Configuration) obtain their Client,
  // and to which Invoker.close() hands the Client back via stopClient().
  private static final ClientCache CLIENTS=new ClientCache();
  
  /**
   * Creates a dynamic proxy implementing {@code protocol} whose calls are
   * handled by a new {@link Invoker} bound to the given address, user,
   * configuration, socket factory and timeout.
   *
   * @param protocol      the interface the returned proxy implements
   * @param clientVersion not used by this engine
   * @param addr          remote address handed to the {@link Invoker}
   * @param ticket        user information handed to the {@link Invoker}
   * @param conf          configuration handed to the {@link Invoker}
   * @param factory       socket factory handed to the {@link Invoker}
   * @param rpcTimeout    timeout handed to the {@link Invoker}
   * @return a {@link ProtocolProxy} wrapping the dynamic proxy; it is built
   *         with {@code false} as its third constructor argument
   * @throws IOException if the {@link Invoker} cannot be constructed
   */
  @Override
  @SuppressWarnings("unchecked")
  public <T> ProtocolProxy<T> getProxy(Class<T> protocol, long clientVersion,
      InetSocketAddress addr, UserGroupInformation ticket, Configuration conf,
      SocketFactory factory, int rpcTimeout) throws IOException {
    // Resolve the class loader and interface list before creating the handler
    // so that the order of side effects matches a single nested expression.
    ClassLoader protocolLoader = protocol.getClassLoader();
    Class[] proxiedInterfaces = new Class[] { protocol };
    InvocationHandler handler =
        new Invoker(protocol, addr, ticket, conf, factory, rpcTimeout);

    T stub = (T) Proxy.newProxyInstance(protocolLoader, proxiedInterfaces,
        handler);
    return new ProtocolProxy<T>(protocol, stub, false);
  }

  /**
   * Stops a proxy previously returned by {@code getProxy} by closing the
   * {@link Invoker} that backs it. An {@link IOException} raised while
   * closing is logged as a warning and not propagated.
   *
   * @param proxy a dynamic proxy whose invocation handler is an
   *              {@link Invoker}
   */
  @Override
  public void stopProxy(Object proxy) {
    try {
      ((Invoker) Proxy.getInvocationHandler(proxy)).close();
    } catch (IOException e) {
      // Never concatenate the proxy itself into the message: toString() on a
      // dynamic proxy is dispatched to Invoker.invoke(), which does not
      // implement Object methods and fails, masking the original error.
      // Describe the proxy by its interface (or class) name instead.
      Class<?>[] proxied = proxy.getClass().getInterfaces();
      String description = proxied.length > 0
          ? proxied[0].getName()
          : proxy.getClass().getName();
      LOG.warn("Error while stopping " + description, e);
    }
  }

  /**
   * Client-side invocation handler: turns each call on the proxied interface
   * into a {@link ProtoSpecificRpcRequest}, sends it through a {@link Client}
   * taken from {@code CLIENTS}, and converts the
   * {@link ProtoSpecificRpcResponse} back into the method's protobuf return
   * type.
   *
   * Methods inherited from {@code Object} (such as {@code toString}) are not
   * special-cased; they are rejected like any other call that does not carry
   * exactly two arguments.
   */
  private static class Invoker implements InvocationHandler, Closeable {
    // Default instance of each return message type, keyed by method name.
    private final Map<String, Message> returnTypes =
        new ConcurrentHashMap<String, Message>();
    private boolean isClosed = false;
    private final Client.ConnectionId remoteId;
    private final Client client;

    /**
     * Resolves the connection id for the target and obtains a client from
     * the shared cache.
     *
     * @param protocol   protocol class used to build the connection id
     * @param addr       remote address used to build the connection id
     * @param ticket     user information used to build the connection id
     * @param conf       configuration used for both the connection id and
     *                   the client lookup
     * @param factory    socket factory used for the client lookup
     * @param rpcTimeout timeout used to build the connection id
     * @throws IOException if the connection id cannot be created
     */
    public Invoker(Class<?> protocol, InetSocketAddress addr,
        UserGroupInformation ticket, Configuration conf, SocketFactory factory,
        int rpcTimeout) throws IOException {
      this.remoteId = Client.ConnectionId.getConnectionId(addr, protocol,
          ticket, rpcTimeout, conf);
      this.client = CLIENTS.getClient(conf, factory,
          ProtoSpecificResponseWritable.class);
    }

    /**
     * Packs the method name and the serialized request message into a
     * {@link ProtoSpecificRpcRequest}.
     *
     * @param method the interface method being invoked
     * @param params the invocation arguments; must be exactly two, with the
     *               request {@link Message} in second position. May be
     *               {@code null}, which is how a dynamic proxy reports a
     *               zero-argument call
     * @return the request to send
     * @throws ServiceException if the argument count is not two or the
     *                          second argument is {@code null}
     */
    private ProtoSpecificRpcRequest constructRpcRequest(Method method,
        Object[] params) throws ServiceException {
      // A proxy passes null (not an empty array) for zero-argument methods.
      int paramCount = (params == null) ? 0 : params.length;
      if (paramCount != 2) { // RpcController + Message
        throw new ServiceException(
            "Wrong number of parameters for request. Method: ["
            + method.getName() + "]" + ", Expected: 2, Actual: "
            + paramCount);
      }
      if (params[1] == null) {
        throw new ServiceException("null param while calling Method: ["
            + method.getName() + "]");
      }

      Message param = (Message) params[1];
      return ProtoSpecificRpcRequest.newBuilder()
          .setMethodName(method.getName())
          .setRequestProto(param.toByteString())
          .build();
    }

    /**
     * Performs the remote call for {@code method}.
     *
     * @param proxy  the proxy instance the method was invoked on (unused)
     * @param method the interface method being invoked
     * @param args   the invocation arguments, see
     *               {@link #constructRpcRequest(Method, Object[])}
     * @return the response message, parsed with the default instance of the
     *         method's return type
     * @throws ServiceException if the request cannot be built, the
     *         underlying client call fails, the response is flagged as an
     *         error (the cause is then a {@link YarnRemoteExceptionPBImpl}),
     *         or the return prototype cannot be obtained
     */
    @Override
    public Object invoke(Object proxy, Method method, Object[] args)
        throws Throwable {
      long startTime = 0;
      if (LOG.isDebugEnabled()) {
        startTime = System.currentTimeMillis();
      }

      ProtoSpecificRpcRequest rpcRequest = constructRpcRequest(method, args);
      ProtoSpecificResponseWritable val = null;
      try {
        val = (ProtoSpecificResponseWritable) client.call(
            new ProtoSpecificRequestWritable(rpcRequest), remoteId);
      } catch (Exception e) {
        throw new ServiceException(e);
      }

      ProtoSpecificRpcResponse response = val.message;

      if (LOG.isDebugEnabled()) {
        long callTime = System.currentTimeMillis() - startTime;
        LOG.debug("Call: " + method.getName() + " " + callTime);
      }

      if (response.hasIsError() && response.getIsError()) {
        YarnRemoteExceptionPBImpl exception =
            new YarnRemoteExceptionPBImpl(response.getException());
        exception.fillInStackTrace();
        throw new ServiceException(exception);
      }

      Message prototype = null;
      try {
        prototype = getReturnProtoType(method);
      } catch (Exception e) {
        throw new ServiceException(e);
      }
      return prototype.newBuilderForType()
          .mergeFrom(response.getResponseProto()).build();
    }

    /**
     * Returns this invoker's client to the shared cache. Only the first
     * call has an effect.
     */
    @Override
    public void close() throws IOException {
      if (!isClosed) {
        isClosed = true;
        CLIENTS.stopClient(client);
      }
    }

    /**
     * Returns the default instance of {@code method}'s return type, looked
     * up reflectively through its static {@code getDefaultInstance} method
     * and cached by method name.
     *
     * @param method the interface method whose return prototype is wanted
     * @return the default instance of the return message type
     * @throws Exception if the reflective lookup or invocation fails
     */
    private Message getReturnProtoType(Method method) throws Exception {
      String methodName = method.getName();
      Message cached = returnTypes.get(methodName);
      if (cached != null) {
        return cached;
      }

      Class<?> returnType = method.getReturnType();
      Method newInstMethod = returnType.getMethod("getDefaultInstance");
      newInstMethod.setAccessible(true);
      Message prototype = (Message) newInstMethod.invoke(null,
          (Object[]) null);
      returnTypes.put(methodName, prototype);
      return prototype;
    }
  }
  
  /**
   * Writable wrapper for protocol buffer requests. The wire form is a
   * 4-byte length followed by that many bytes of the serialized
   * {@link ProtoSpecificRpcRequest}.
   */
  private static class ProtoSpecificRequestWritable implements Writable {
    ProtoSpecificRpcRequest message;

    /** No-argument constructor; {@code message} is set by {@code readFields}. */
    @SuppressWarnings("unused")
    public ProtoSpecificRequestWritable() {
    }
    
    ProtoSpecificRequestWritable(ProtoSpecificRpcRequest message) {
      this.message = message;
    }

    /**
     * Writes the length-prefixed serialized message.
     *
     * @param out destination of the length and payload
     * @throws IOException if writing to {@code out} fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
      // Serialize once and reuse the bytes for both the length and payload.
      byte[] serialized = message.toByteArray();
      out.writeInt(serialized.length);
      out.write(serialized);
    }

    /**
     * Reads a length-prefixed serialized message and parses it.
     *
     * @param in source of the length and payload
     * @throws IOException if the length prefix is negative, the payload is
     *                     truncated, or the bytes cannot be parsed
     */
    @Override
    public void readFields(DataInput in) throws IOException {
      int length = in.readInt();
      if (length < 0) {
        // Report a corrupt prefix as an I/O problem instead of letting the
        // array allocation throw NegativeArraySizeException.
        throw new IOException("Invalid request length: " + length);
      }
      byte[] bytes = new byte[length];
      in.readFully(bytes);
      message = ProtoSpecificRpcRequest.parseFrom(bytes);
    }
  }
  
  /**
   * Writable wrapper for protocol buffer responses. The wire form is a
   * 4-byte length followed by that many bytes of the serialized
   * {@link ProtoSpecificRpcResponse}.
   */
  public static class ProtoSpecificResponseWritable implements Writable {
    ProtoSpecificRpcResponse message;

    /** No-argument constructor; {@code message} is set by {@code readFields}. */
    public ProtoSpecificResponseWritable() {
    }
    
    public ProtoSpecificResponseWritable(ProtoSpecificRpcResponse message) {
      this.message = message;
    }

    /**
     * Writes the length-prefixed serialized message.
     *
     * @param out destination of the length and payload
     * @throws IOException if writing to {@code out} fails
     */
    @Override
    public void write(DataOutput out) throws IOException {
      // Serialize once and reuse the bytes for both the length and payload.
      byte[] serialized = message.toByteArray();
      out.writeInt(serialized.length);
      out.write(serialized);
    }

    /**
     * Reads a length-prefixed serialized message and parses it.
     *
     * @param in source of the length and payload
     * @throws IOException if the length prefix is negative, the payload is
     *                     truncated, or the bytes cannot be parsed
     */
    @Override
    public void readFields(DataInput in) throws IOException {
      int length = in.readInt();
      if (length < 0) {
        // Report a corrupt prefix as an I/O problem instead of letting the
        // array allocation throw NegativeArraySizeException.
        throw new IOException("Invalid response length: " + length);
      }
      byte[] bytes = new byte[length];
      in.readFully(bytes);
      message = ProtoSpecificRpcResponse.parseFrom(bytes);
    }
  }
  
  /**
   * Multi-address call entry point of {@link RpcEngine}; this engine does not
   * implement it.
   *
   * @throws UnsupportedOperationException always
   */
  @Override
  public Object[] call(Method method, Object[][] params,
      InetSocketAddress[] addrs, UserGroupInformation ticket, Configuration conf)
      throws IOException, InterruptedException {
    throw new UnsupportedOperationException();
  }

  /**
   * For unit testing only: fetches a {@link Client} from the shared cache
   * using the default socket factory and this engine's response wrapper
   * class.
   *
   * @param conf configuration passed to the cache lookup
   * @return the client supplied by the cache
   */
  @InterfaceAudience.Private
  @InterfaceStability.Unstable
  static Client getClient(Configuration conf) {
    final ClientCache cache = CLIENTS;
    final SocketFactory defaultFactory = SocketFactory.getDefault();
    return cache.getClient(conf, defaultFactory,
        ProtoSpecificResponseWritable.class);
  }

  /**
   * Server side of the engine. Each incoming
   * {@link ProtoSpecificRequestWritable} is resolved by method name against
   * the wrapped {@link BlockingService}, invoked, and answered with a
   * {@link ProtoSpecificResponseWritable} that carries either the serialized
   * result or an error flag plus a serialized exception.
   */
  public static class Server extends RPC.Server {

    private final BlockingService service;
    private final boolean verbose;

    /**
     * Returns the part of {@code className} after its last {@code '.'}, or
     * the whole string when it contains no {@code '.'}.
     */
    private static String classNameBase(String className) {
      return className.substring(className.lastIndexOf('.') + 1);
    }

    /**
     * Construct an RPC server.
     * 
     * @param instance
     *          the instance whose methods will be called; must be a
     *          {@link BlockingService}
     * @param conf
     *          the configuration to use
     * @param bindAddress
     *          the address to bind on to listen for connection
     * @param port
     *          the port to listen for connections on
     * @param numHandlers
     *          the number of method handler threads to run
     * @param numReaders
     *          passed through to the {@link RPC.Server} constructor
     * @param queueSizePerHandler
     *          passed through to the {@link RPC.Server} constructor
     * @param verbose
     *          whether each call should be logged
     * @param secretManager
     *          passed through to the {@link RPC.Server} constructor
     * @throws IOException
     *           if the underlying {@link RPC.Server} cannot be constructed
     */
    public Server(Object instance, Configuration conf, String bindAddress,
        int port, int numHandlers, int numReaders, 
        int queueSizePerHandler, boolean verbose,
        SecretManager<? extends TokenIdentifier> secretManager)
        throws IOException {
      super(bindAddress, port, ProtoSpecificRequestWritable.class, numHandlers,
          numReaders, queueSizePerHandler, conf,
          classNameBase(instance.getClass().getName()), secretManager);
      this.service = (BlockingService) instance;
      this.verbose = verbose;
    }

    /**
     * Dispatches one request to the wrapped service.
     *
     * @param protocol        the protocol class, used only in log and error
     *                        messages
     * @param writableRequest the incoming
     *                        {@link ProtoSpecificRequestWritable}
     * @param receiveTime     not used by this implementation
     * @return a {@link ProtoSpecificResponseWritable} holding the result, or
     *         an error response when the method is unknown or the service
     *         throws
     * @throws IOException if the request payload cannot be parsed into the
     *                     method's request type
     */
    @Override
    public Writable call(Class<?> protocol, Writable writableRequest,
        long receiveTime) throws IOException {
      ProtoSpecificRequestWritable request =
          (ProtoSpecificRequestWritable) writableRequest;
      ProtoSpecificRpcRequest rpcRequest = request.message;
      String methodName = rpcRequest.getMethodName();
      if (verbose) {
        log("Call: protocol=" + protocol.getCanonicalName() + ", method="
            + methodName);
      }
      MethodDescriptor methodDescriptor = service.getDescriptorForType()
          .findMethodByName(methodName);
      if (methodDescriptor == null) {
        String msg = "Unknown method " + methodName + " called on "
            + protocol + " protocol.";
        LOG.warn(msg);
        return handleException(new IOException(msg));
      }
      Message prototype = service.getRequestPrototype(methodDescriptor);
      Message param = prototype.newBuilderForType()
          .mergeFrom(rpcRequest.getRequestProto()).build();
      Message result;
      try {
        result = service.callBlockingMethod(methodDescriptor, null, param);
      } catch (ServiceException e) {
        // Route the stack trace through the logger rather than stderr; the
        // failure itself is still reported to the caller in the response.
        LOG.info("Service exception while invoking " + methodName, e);
        return handleException(e);
      } catch (Exception e) {
        return handleException(e);
      }

      ProtoSpecificRpcResponse response =
          constructProtoSpecificRpcSuccessResponse(result);
      return new ProtoSpecificResponseWritable(response);
    }

    /**
     * Builds an error response for {@code e}. When the cause of {@code e} is
     * already a {@link YarnRemoteExceptionPBImpl} its proto is sent as is;
     * otherwise {@code e} itself is wrapped in a new one.
     */
    private ProtoSpecificResponseWritable handleException(Throwable e) {
      ProtoSpecificRpcResponse.Builder builder = ProtoSpecificRpcResponse
          .newBuilder();
      builder.setIsError(true);
      if (e.getCause() instanceof YarnRemoteExceptionPBImpl) {
        builder.setException(((YarnRemoteExceptionPBImpl) e.getCause())
            .getProto());
      } else {
        builder.setException(new YarnRemoteExceptionPBImpl(e).getProto());
      }
      ProtoSpecificRpcResponse response = builder.build();
      return new ProtoSpecificResponseWritable(response);
    }

    /**
     * Builds a success response whose payload is the serialized
     * {@code message}.
     */
    private ProtoSpecificRpcResponse constructProtoSpecificRpcSuccessResponse(
        Message message) {
      return ProtoSpecificRpcResponse.newBuilder()
          .setResponseProto(message.toByteString()).build();
    }
  }

  /**
   * Logs {@code value} at info level, cutting it to its first 55 characters
   * followed by "..." when it is longer than that. A {@code null} value is
   * passed to the logger unchanged.
   */
  private static void log(String value) {
    final boolean needsTruncation = value != null && value.length() > 55;
    final String shown;
    if (needsTruncation) {
      shown = value.substring(0, 55) + "...";
    } else {
      shown = value;
    }
    LOG.info(shown);
  }

  /**
   * Creates a {@link Server} that serves {@code instance} with the given
   * settings.
   *
   * @param protocol            not used by this engine
   * @param instance            the service object handed to the
   *                            {@link Server}
   * @param bindAddress         address handed to the {@link Server}
   * @param port                port handed to the {@link Server}
   * @param numHandlers         handler count handed to the {@link Server}
   * @param numReaders          reader count handed to the {@link Server}
   * @param queueSizePerHandler queue size handed to the {@link Server}
   * @param verbose             verbosity flag handed to the {@link Server}
   * @param conf                configuration handed to the {@link Server}
   * @param secretManager       secret manager handed to the {@link Server}
   * @return the newly constructed server
   * @throws IOException if the {@link Server} cannot be constructed
   */
  @Override
  public RPC.Server getServer(Class<?> protocol, Object instance,
      String bindAddress, int port, int numHandlers,int numReaders, 
      int queueSizePerHandler, boolean verbose,
      Configuration conf, SecretManager<? extends TokenIdentifier> secretManager)
      throws IOException {
    final RPC.Server created = new Server(
        instance,
        conf,
        bindAddress,
        port,
        numHandlers,
        numReaders,
        queueSizePerHandler,
        verbose,
        secretManager);
    return created;
  }
}
