/*
 *
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */
package org.apache.airavata.helix.adaptor;

import com.google.common.collect.Lists;
import net.schmizz.keepalive.KeepAliveProvider;
import net.schmizz.sshj.DefaultConfig;
import net.schmizz.sshj.connection.ConnectionException;
import net.schmizz.sshj.connection.channel.direct.Session;
import net.schmizz.sshj.sftp.*;
import net.schmizz.sshj.userauth.keyprovider.KeyProvider;
import net.schmizz.sshj.userauth.method.AuthKeyboardInteractive;
import net.schmizz.sshj.userauth.method.AuthMethod;
import net.schmizz.sshj.userauth.method.AuthPublickey;
import net.schmizz.sshj.userauth.method.ChallengeResponseProvider;
import net.schmizz.sshj.userauth.password.PasswordFinder;
import net.schmizz.sshj.userauth.password.PasswordUtils;
import net.schmizz.sshj.userauth.password.Resource;
import net.schmizz.sshj.xfer.FilePermission;
import net.schmizz.sshj.xfer.LocalDestFile;
import net.schmizz.sshj.xfer.LocalFileFilter;
import net.schmizz.sshj.xfer.LocalSourceFile;
import org.apache.airavata.agents.api.*;
import org.apache.airavata.helix.adaptor.wrapper.SCPFileTransferWrapper;
import org.apache.airavata.helix.adaptor.wrapper.SFTPClientWrapper;
import org.apache.airavata.helix.adaptor.wrapper.SessionWrapper;
import org.apache.airavata.helix.agent.ssh.StandardOutReader;
import org.apache.airavata.model.appcatalog.computeresource.ComputeResourceDescription;
import org.apache.airavata.model.appcatalog.computeresource.JobSubmissionInterface;
import org.apache.airavata.model.appcatalog.computeresource.JobSubmissionProtocol;
import org.apache.airavata.model.appcatalog.computeresource.SSHJobSubmission;
import org.apache.airavata.model.credential.store.SSHCredential;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

import java.io.*;
import java.util.*;
import java.util.stream.Collectors;

public class SSHJAgentAdaptor implements AgentAdaptor {

    private final static Logger logger = LoggerFactory.getLogger(SSHJAgentAdaptor.class);

    private PoolingSSHJClient sshjClient;

    protected void createPoolingSSHJClient(String user, String host, int port, String publicKey, String privateKey, String passphrase) throws IOException {
        DefaultConfig defaultConfig = new DefaultConfig();
        defaultConfig.setKeepAliveProvider(KeepAliveProvider.KEEP_ALIVE);

        sshjClient = new PoolingSSHJClient(defaultConfig, host, port == 0 ? 22 : port);
        sshjClient.addHostKeyVerifier((h, p, key) -> true);

        sshjClient.setMaxSessionsForConnection(1);

        PasswordFinder passwordFinder = passphrase != null ? PasswordUtils.createOneOff(passphrase.toCharArray()) : null;

        KeyProvider keyProvider = sshjClient.loadKeys(privateKey, publicKey, passwordFinder);

        final List<AuthMethod> am = new LinkedList<>();
        am.add(new AuthPublickey(keyProvider));

        am.add(new AuthKeyboardInteractive(new ChallengeResponseProvider() {
            @Override
            public List<String> getSubmethods() {
                return new ArrayList<>();
            }

            @Override
            public void init(Resource resource, String name, String instruction) {

            }

            @Override
            public char[] getResponse(String prompt, boolean echo) {
                return new char[0];
            }

            @Override
            public boolean shouldRetry() {
                return false;
            }
        }));

        sshjClient.auth(user, am);
    }

    public void init(String user, String host, int port, String publicKey, String privateKey, String passphrase) throws AgentException {
        try {
            createPoolingSSHJClient(user, host, port, publicKey, privateKey, passphrase);
        } catch (IOException e) {
            logger.error("Error while initializing sshj agent for user " + user + " host " + host + " for key starting with " + publicKey.substring(0,10), e);
            throw new AgentException("Error while initializing sshj agent for user " + user + " host " + host + " for key starting with " + publicKey.substring(0,10), e);
        }
    }

    @Override
    public void init(String computeResource, String gatewayId, String userId, String token) throws AgentException {
        try {
            logger.info("Initializing Compute Resource SSH Adaptor for compute resource : "+ computeResource + ", gateway : " +
                    gatewayId +", user " + userId + ", token : " + token);

            ComputeResourceDescription computeResourceDescription = AgentUtils.getRegistryServiceClient().getComputeResource(computeResource);

            logger.info("Fetching job submission interfaces for compute resource " + computeResource);

            Optional<JobSubmissionInterface> jobSubmissionInterfaceOp = computeResourceDescription.getJobSubmissionInterfaces()
                    .stream().filter(iface -> iface.getJobSubmissionProtocol() == JobSubmissionProtocol.SSH).findFirst();

            JobSubmissionInterface sshInterface = jobSubmissionInterfaceOp
                    .orElseThrow(() -> new AgentException("Could not find a SSH interface for compute resource " + computeResource));

            SSHJobSubmission sshJobSubmission = AgentUtils.getRegistryServiceClient().getSSHJobSubmission(sshInterface.getJobSubmissionInterfaceId());

            logger.info("Fetching credentials for cred store token " + token);

            SSHCredential sshCredential = AgentUtils.getCredentialClient().getSSHCredential(token, gatewayId);

            if (sshCredential == null) {
                throw new AgentException("Null credential for token " + token);
            }
            logger.info("Description for token : " + token + " : " + sshCredential.getDescription());

            String alternateHostName = sshJobSubmission.getAlternativeSSHHostName();
            String selectedHostName = (alternateHostName == null || "".equals(alternateHostName))?
                    computeResourceDescription.getHostName() : alternateHostName;

            int selectedPort = sshJobSubmission.getSshPort() == 0 ? 22 : sshJobSubmission.getSshPort();

            logger.info("Using user {}, Host {}, Port {} to create ssh client for compute resource {}",
                    userId, selectedHostName, selectedPort, computeResource);

            createPoolingSSHJClient(userId, selectedHostName, selectedPort,
                    sshCredential.getPublicKey(), sshCredential.getPrivateKey(), sshCredential.getPassphrase());

        } catch (Exception e) {
            logger.error("Error while initializing ssh agent for compute resource " + computeResource + " to token " + token, e);
            throw new AgentException("Error while initializing ssh agent for compute resource " + computeResource + " to token " + token, e);
        }
    }

    @Override
    public void destroy() {
        try {
            if (sshjClient != null) {
                sshjClient.disconnect();
                sshjClient.close();
            }
        } catch (IOException e) {
            logger.warn("Failed to stop sshj client for host " + sshjClient.getHost() + " and user " +
                    sshjClient.getUsername() + " due to : " + e.getMessage());
            // ignore
        }
    }

    @Override
    public CommandOutput executeCommand(String command, String workingDirectory) throws AgentException {
        SessionWrapper session = null;
        try {
            session = sshjClient.startSessionWrapper();
            Session.Command exec = session.exec((workingDirectory != null ? "cd " + workingDirectory + "; " : "") + command);
            StandardOutReader standardOutReader = new StandardOutReader();

            try {
                standardOutReader.readStdOutFromStream(exec.getInputStream());
                standardOutReader.readStdErrFromStream(exec.getErrorStream());
            } finally {
                exec.close(); // closing the channel before getting the exit status
                standardOutReader.setExitCode(Optional.ofNullable(exec.getExitStatus()).orElseThrow(() -> new Exception("Exit status received as null")));
            }
            return standardOutReader;

        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(session).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(session).ifPresent(ss -> {
                try {
                    ss.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public void createDirectory(String path) throws AgentException {
        createDirectory(path, false);
    }

    @Override
    public void createDirectory(String path, boolean recursive) throws AgentException {
        SFTPClientWrapper sftpClient = null;
        try {
            sftpClient = sshjClient.newSFTPClientWrapper();
            if (recursive) {
                sftpClient.mkdirs(path);
            } else {
                sftpClient.mkdir(path);
            }
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(sftpClient).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(sftpClient).ifPresent(client -> {
                try {
                    client.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public void uploadFile(String localFile, String remoteFile) throws AgentException {
        SCPFileTransferWrapper fileTransfer = null;
        try {
            fileTransfer = sshjClient.newSCPFileTransferWrapper();
            fileTransfer.upload(localFile, remoteFile);

        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(fileTransfer).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(fileTransfer).ifPresent(scpFileTransferWrapper -> {
                try {
                    scpFileTransferWrapper.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public void uploadFile(InputStream localInStream, FileMetadata metadata, String remoteFile) throws AgentException {
        SCPFileTransferWrapper fileTransfer = null;

        try {
            fileTransfer = sshjClient.newSCPFileTransferWrapper();
            fileTransfer.upload(new LocalSourceFile() {
                @Override
                public String getName() {
                    return metadata.getName();
                }

                @Override
                public long getLength() {
                    return metadata.getSize();
                }

                @Override
                public InputStream getInputStream() throws IOException {
                    return localInStream;
                }

                @Override
                public int getPermissions() throws IOException {
                    return 420; //metadata.getPermissions();
                }

                @Override
                public boolean isFile() {
                    return true;
                }

                @Override
                public boolean isDirectory() {
                    return false;
                }

                @Override
                public Iterable<? extends LocalSourceFile> getChildren(LocalFileFilter filter) throws IOException {
                    return null;
                }

                @Override
                public boolean providesAtimeMtime() {
                    return false;
                }

                @Override
                public long getLastAccessTime() throws IOException {
                    return 0;
                }

                @Override
                public long getLastModifiedTime() throws IOException {
                    return 0;
                }
            }, remoteFile);
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(fileTransfer).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(fileTransfer).ifPresent(scpFileTransferWrapper -> {
                try {
                    scpFileTransferWrapper.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public void downloadFile(String remoteFile, String localFile) throws AgentException {
        SCPFileTransferWrapper fileTransfer = null;
        try {
            fileTransfer = sshjClient.newSCPFileTransferWrapper();
            fileTransfer.download(remoteFile, localFile);

        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(fileTransfer).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(fileTransfer).ifPresent(scpFileTransferWrapper -> {
                try {
                    scpFileTransferWrapper.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public void downloadFile(String remoteFile, OutputStream localOutStream, FileMetadata metadata) throws AgentException {
        SCPFileTransferWrapper fileTransfer = null;
        try {
            fileTransfer = sshjClient.newSCPFileTransferWrapper();
            fileTransfer.download(remoteFile, new LocalDestFile() {
                @Override
                public OutputStream getOutputStream() throws IOException {
                    return localOutStream;
                }

                @Override
                public LocalDestFile getChild(String name) {
                    return null;
                }

                @Override
                public LocalDestFile getTargetFile(String filename) throws IOException {
                    return this;
                }

                @Override
                public LocalDestFile getTargetDirectory(String dirname) throws IOException {
                    return null;
                }

                @Override
                public void setPermissions(int perms) throws IOException {

                }

                @Override
                public void setLastAccessedTime(long t) throws IOException {

                }

                @Override
                public void setLastModifiedTime(long t) throws IOException {

                }
            });
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(fileTransfer).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(fileTransfer).ifPresent(scpFileTransferWrapper -> {
                try {
                    scpFileTransferWrapper.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public List<String> listDirectory(String path) throws AgentException {
        SFTPClientWrapper sftpClient = null;
        try {
            sftpClient = sshjClient.newSFTPClientWrapper();
            List<RemoteResourceInfo> ls = sftpClient.ls(path);
            return ls.stream().map(RemoteResourceInfo::getName).collect(Collectors.toList());
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(sftpClient).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(sftpClient).ifPresent(client -> {
                try {
                    client.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public Boolean doesFileExist(String filePath) throws AgentException {
        SFTPClientWrapper sftpClient = null;
        try {
            sftpClient = sshjClient.newSFTPClientWrapper();
            return sftpClient.statExistence(filePath) != null;
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(sftpClient).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(sftpClient).ifPresent(client -> {
                try {
                    client.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    @Override
    public List<String> getFileNameFromExtension(String fileName, String parentPath) throws AgentException {

        /*try (SFTPClient sftpClient = sshjClient.newSFTPClientWrapper()) {
            List<RemoteResourceInfo> ls = sftpClient.ls(parentPath, resource -> isMatch(resource.getName(), fileName));
            return ls.stream().map(RemoteResourceInfo::getPath).collect(Collectors.toList());
        } catch (Exception e) {
            throw new AgentException(e);
        }*/
        /*
        if (fileName.endsWith("*")) {
            throw new AgentException("Wildcards that ends with * does not support for security reasons. Specify an extension");
        }
        */

        CommandOutput commandOutput = executeCommand("ls " + fileName, parentPath); // This has a risk of returning folders also
        String[] filesTmp = commandOutput.getStdOut().split("\n");
        List<String> files = new ArrayList<>();
        for (String f: filesTmp) {
            if (!f.isEmpty()) {
                files.add(f);
            }
        }
        return files;
    }

    @Override
    public FileMetadata getFileMetadata(String remoteFile) throws AgentException {
        SFTPClientWrapper sftpClient = null;
        try {
            sftpClient = sshjClient.newSFTPClientWrapper();
            FileAttributes stat = sftpClient.stat(remoteFile);
            FileMetadata metadata = new FileMetadata();
            metadata.setName(new File(remoteFile).getName());
            metadata.setSize(stat.getSize());
            metadata.setPermissions(FilePermission.toMask(stat.getPermissions()));
            return metadata;
        } catch (Exception e) {
            if (e instanceof ConnectionException) {
                Optional.ofNullable(sftpClient).ifPresent(ft -> ft.setErrored(true));
            }
            throw new AgentException(e);

        } finally {
            Optional.ofNullable(sftpClient).ifPresent(scpFileTransferWrapper -> {
                try {
                    scpFileTransferWrapper.close();
                } catch (IOException e) {
                    //Ignore
                }
            });
        }
    }

    private boolean isMatch(String s, String p) {
        int i = 0;
        int j = 0;
        int starIndex = -1;
        int iIndex = -1;

        while (i < s.length()) {
            if (j < p.length() && (p.charAt(j) == '?' || p.charAt(j) == s.charAt(i))) {
                ++i;
                ++j;
            } else if (j < p.length() && p.charAt(j) == '*') {
                starIndex = j;
                iIndex = i;
                j++;
            } else if (starIndex != -1) {
                j = starIndex + 1;
                i = iIndex+1;
                iIndex++;
            } else {
                return false;
            }
        }
        while (j < p.length() && p.charAt(j) == '*') {
            ++j;
        }
        return j == p.length();
    }
}
