/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.utils;

import java.io.IOException;

import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.apache.log4j.Level;
import org.apache.log4j.Logger;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.conf.DMLConfig;
import org.apache.sysds.hops.OptimizerUtils;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.io.IOUtilFunctions;

import java.util.Vector;
import java.io.InputStream;
import java.io.OutputStream;
import java.nio.FloatBuffer;
import java.io.File;

import org.apache.commons.io.FileUtils;
import org.apache.commons.io.IOUtils;
import org.apache.commons.lang.SystemUtils;

/**
 * Loads the native BLAS backend of SystemDS and declares the JNI entry points implemented by it.
 * <p>
 * With the setting {@code auto}, Intel MKL is tried first and OpenBLAS second. After the BLAS library itself
 * has been loaded, the matching SystemDS JNI library (packaged under {@code /lib} in the jar) is loaded as well.
 * If any step fails, the state is set to {@link NativeBlasState#ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY}
 * and {@link #isNativeLibraryLoaded()} returns false.
 */
public class NativeHelper {
	
	public enum NativeBlasState {
		NOT_ATTEMPTED_LOADING_NATIVE_BLAS,
		SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE,
		SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE,
		ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY
	}
	
	// Both fields are read without holding the lock (first check of the double-checked loading in
	// performLoading), hence they are volatile so that updates made under the lock are visible.
	public static volatile NativeBlasState CURRENT_NATIVE_BLAS_STATE = NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS;
	private static volatile String blasType;

	// Useful for deciding whether to use native BLAS in parfor environment.
	private static int maxNumThreads = -1;
	private static boolean setMaxNumThreads = false;

	private static final Log LOG = LogFactory.getLog(NativeHelper.class.getName());

	/**
	 * Called by Statistics to print the loaded BLAS.
	 * 
	 * @return empty string or the BLAS that is loaded
	 */
	public static String getCurrentBLAS() {
		String loaded = blasType;
		return (loaded == null) ? "" : loaded;
	}
	
	/**
	 * Called by runtime to check if the BLAS is available for exploitation.
	 * If no BLAS has been loaded yet, a loading attempt is made with the settings of the current DML
	 * configuration ("auto" and no custom directory if there is no configuration).
	 * 
	 * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE else false
	 */
	public static boolean isNativeLibraryLoaded() {
		if(!isBLASLoaded()) {
			DMLConfig dmlConfig = ConfigurationManager.getDMLConfig();
			String userSpecifiedBLAS = (dmlConfig == null) ? "auto" : dmlConfig.getTextValue(DMLConfig.NATIVE_BLAS)
					.trim().toLowerCase();
			String customLibPath = (dmlConfig == null) ? "none" : dmlConfig.getTextValue(DMLConfig.NATIVE_BLAS_DIR).trim();
			performLoading(customLibPath, userSpecifiedBLAS);
		}

		if(maxNumThreads == -1)
			maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);

		boolean inUse = CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
		if(inUse && !setMaxNumThreads && maxNumThreads != -1) {
			// Pass the thread limit to the native library once (see the comment on the native method
			// setMaxNumThreads for why this call is kept).
			setMaxNumThreads(maxNumThreads);
			setMaxNumThreads = true;
		}
		return CURRENT_NATIVE_BLAS_STATE == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
	}
	
	/**
	 * Initialize the native library before executing the DML program.
	 * <p>
	 * If a BLAS is already loaded: "none" marks it as not in use, the same BLAS type or "auto" marks it as
	 * in use, and any other supported type is rejected because a loaded library cannot be replaced.
	 * If no BLAS is loaded and the requested type is supported, a loading attempt is made.
	 * Unsupported values other than "none" leave the state untouched.
	 * 
	 * @param customLibPath specified by sysds.native.blas.directory
	 * @param userSpecifiedBLAS specified by sysds.native.blas
	 * @throws DMLRuntimeException if a different BLAS than the requested one has already been loaded
	 */
	public static void initialize(String customLibPath, String userSpecifiedBLAS) {
		if(isBLASLoaded()) {
			String loaded = blasType;
			if("none".equalsIgnoreCase(userSpecifiedBLAS)) {
				CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
			}
			else if("auto".equalsIgnoreCase(userSpecifiedBLAS)
					|| (loaded != null && loaded.equalsIgnoreCase(userSpecifiedBLAS))) {
				// "auto" accepts whichever BLAS is available, so an already loaded one satisfies it.
				CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
			}
			else if(isSupportedBLAS(userSpecifiedBLAS)) {
				throw new DMLRuntimeException("Cannot replace previously loaded blas \"" + loaded + "\" with \"" +
						userSpecifiedBLAS + "\".");
			}
		}
		else if(isSupportedBLAS(userSpecifiedBLAS)) {
			performLoading(customLibPath, userSpecifiedBLAS);
		}
	}
	
	/**
	 * Return true if the given BLAS type is supported.
	 * 
	 * @param userSpecifiedBLAS BLAS type specified via sysds.native.blas property (null is treated as unsupported)
	 * @return true if the userSpecifiedBLAS is auto | mkl | openblas, else false
	 */
	private static boolean isSupportedBLAS(String userSpecifiedBLAS) {
		return "auto".equalsIgnoreCase(userSpecifiedBLAS) || 
				"mkl".equalsIgnoreCase(userSpecifiedBLAS) || 
				"openblas".equalsIgnoreCase(userSpecifiedBLAS);
	}
	
	/**
	 * Note: we only support 64 bit Java on x86 and AMD machine
	 * 
	 * @return true if the hardware architecture is supported
	 */
	private static boolean isSupportedArchitecture() {
		String arch = SystemUtils.OS_ARCH;
		if(arch.equals("x86_64") || arch.equals("amd64")) {
			return true;
		}
		LOG.info("Unsupported architecture for native BLAS:" + SystemUtils.OS_ARCH);
		return false;
	}
	
	/**
	 * Check if native BLAS libraries have been successfully loaded
	 * @return true if CURRENT_NATIVE_BLAS_STATE is SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE or
	 * 		   SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE
	 */
	private static boolean isBLASLoaded() {
		// read the volatile field once so that both comparisons see the same value
		NativeBlasState state = CURRENT_NATIVE_BLAS_STATE;
		return state == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE || 
				state == NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_NOT_IN_USE;
	}
	
	/**
	 * Check if we should attempt to perform loading.
	 * If custom library path is provided, we should attempt to load again if not already loaded.
	 * 
	 * @param customLibPath custom library path 
	 * @return true if we should attempt to load blas again
	 */
	private static boolean shouldReload(String customLibPath) {
		if(CURRENT_NATIVE_BLAS_STATE == NativeBlasState.NOT_ATTEMPTED_LOADING_NATIVE_BLAS)
			return true;
		boolean isValidBLASDirectory = customLibPath != null && !customLibPath.equalsIgnoreCase("none");
		return isValidBLASDirectory && !isBLASLoaded();
	}

	/**
	 * Attempts to load the BLAS library and the matching SystemDS JNI library and updates
	 * CURRENT_NATIVE_BLAS_STATE accordingly. Nothing is attempted if the BLAS type or the hardware
	 * architecture is unsupported, or if {@link #shouldReload(String)} denies another attempt.
	 * <p>
	 * Performing loading in a method instead of a static block will throw a detailed stack trace in case of
	 * fatal errors.
	 * 
	 * @param customLibPath custom library directory, null or "none" if not specified
	 * @param userSpecifiedBLAS requested BLAS type (auto | mkl | openblas)
	 */
	private static void performLoading(String customLibPath, String userSpecifiedBLAS) {
		if((customLibPath != null) && customLibPath.equalsIgnoreCase("none"))
			customLibPath = null;

		// shouldReload (driven by CURRENT_NATIVE_BLAS_STATE) ensures that we don't try to load SystemDS and
		// other dependencies again and again especially in the parfor; it is checked once without and once
		// with the class lock.
		if(shouldReload(customLibPath) && isSupportedBLAS(userSpecifiedBLAS) && isSupportedArchitecture()) {
			long start = System.nanoTime();
			synchronized(NativeHelper.class) {
				if(shouldReload(customLibPath)) {
					// Set attempted loading unsuccessful in case of exception
					CURRENT_NATIVE_BLAS_STATE = NativeBlasState.ATTEMPTED_LOADING_NATIVE_BLAS_UNSUCCESSFULLY;
					String [] blas = userSpecifiedBLAS.equalsIgnoreCase("auto") ?
							new String[] { "mkl", "openblas" } : new String[] { userSpecifiedBLAS };

					// The BLAS library has to be loaded before the SystemDS library that links against it.
					if(checkAndLoadBLAS(customLibPath, blas) && loadSystemDSLibrary(customLibPath)) {
						LOG.info("Using native blas: " + blasType + getNativeBLASPath());
						CURRENT_NATIVE_BLAS_STATE = NativeBlasState.SUCCESSFULLY_LOADED_NATIVE_BLAS_AND_IN_USE;
					}
				}
			}
			double timeToLoadInMilliseconds = (System.nanoTime()-start)*1e-6;
			if(timeToLoadInMilliseconds > 1000) 
				LOG.warn("Time to load native blas: " + timeToLoadInMilliseconds + " milliseconds.");
		}
		else if(LOG.isDebugEnabled() && !isSupportedBLAS(userSpecifiedBLAS)) {
			LOG.debug("Using internal Java BLAS as native BLAS does not support the configuration 'sysds.native.blas'=" +
					userSpecifiedBLAS + ".");
		}
	}

	/**
	 * Loads the SystemDS JNI library that matches the BLAS type stored in blasType.
	 * On Windows the library packaged in the jar is tried first and the custom library path / library
	 * search path second; on all other platforms only the Linux library packaged in the jar is tried.
	 * 
	 * @param customLibPath custom library directory or null
	 * @return true if the SystemDS library was loaded
	 */
	private static boolean loadSystemDSLibrary(String customLibPath) {
		if(SystemUtils.IS_OS_WINDOWS) {
			String winLib = "systemds_" + blasType + "-Windows-AMD64";
			return loadLibraryHelper(winLib) || loadBLAS(customLibPath, winLib, null);
		}
		return loadLibraryHelper("libsystemds_" + blasType + "-Linux-x86_64.so");
	}
	
	/**
	 * Tries the given BLAS types in order and stops at the first one that can be loaded.
	 * On success the type is remembered in blasType.
	 * 
	 * @param customLibPath custom library directory, null or "none" if not specified
	 * @param listBLAS BLAS types to try (mkl and/or openblas); other values are skipped
	 * @return true if one of the BLAS libraries was loaded
	 */
	private static boolean checkAndLoadBLAS(String customLibPath, String [] listBLAS) {
		if(customLibPath != null && customLibPath.equalsIgnoreCase("none"))
			customLibPath = null;

		for (String blas : listBLAS) {
			boolean isLoaded = false;
			if (blas.equalsIgnoreCase("mkl")) {
				isLoaded = loadBLAS(customLibPath, "mkl_rt", null);
			}
			else if (blas.equalsIgnoreCase("openblas")) {
				// no need for gomp on windows
				if (SystemUtils.IS_OS_WINDOWS || loadBLAS(customLibPath, "gomp",
						"gomp required for loading OpenBLAS-enabled SystemDS library")) {
					isLoaded = loadBLAS(customLibPath, "openblas", null);
				}
			}
			if (isLoaded) {
				blasType = blas;
				return true;
			}
		}
		return false;
	}
	
	/**
	 * Useful method for debugging. Looks up the loaded native libraries via reflection on a JDK-internal
	 * field of ClassLoader; this is best effort and any failure is only logged.
	 * 
	 * @return empty string (if !LOG.isDebugEnabled() or the lookup fails) or the path from where openblas
	 *         or mkl is loaded.
	 */
	private static String getNativeBLASPath() {
		String blasPathAndHint = "";
		if(LOG.isDebugEnabled()) {
			// Only perform the checking of library paths when DEBUG is enabled to avoid runtime overhead.
			try {
				java.lang.reflect.Field loadedLibraryNamesField = ClassLoader.class.getDeclaredField("loadedLibraryNames");
				loadedLibraryNamesField.setAccessible(true);
				Object libraries = loadedLibraryNamesField.get(ClassLoader.getSystemClassLoader());
				LOG.debug("List of native libraries loaded:" + libraries);
				// Do not assume a concrete collection type of the internal field.
				if(libraries instanceof Iterable) {
					for(Object entry : (Iterable<?>) libraries) {
						String library = String.valueOf(entry);
						if(library.contains("mkl_rt") || library.contains("libopenblas")) {
							blasPathAndHint = " from the path " + library;
							break;
						}
					}
				}
			}
			catch (ReflectiveOperationException | RuntimeException e) {
				// This runs right after a successful load, so a failed diagnostic lookup must not
				// propagate and turn the loading into a failure.
				LOG.debug("Error while finding list of native libraries:" + e.getMessage());
			}
		}
		return blasPathAndHint;
	}

	/**
	 * Returns the thread limit used for native BLAS, obtaining it from the optimizer on first use.
	 * 
	 * @return maximum number of threads
	 */
	public static int getMaxNumThreads() {
		if(maxNumThreads == -1)
			maxNumThreads = OptimizerUtils.getConstrainedNumThreads(-1);
		return maxNumThreads;
	}

	/**
	 * Attempts to load native BLAS
	 * 
	 * @param customLibPath can be null or "none" (if we only want to use the library search path of the JVM,
	 *                      i.e. java.library.path), else the directory that is tried first
	 * @param blas can be gomp, openblas or mkl_rt
	 * @param optionalMsg message for debugging
	 * @return true if successfully loaded BLAS
	 */
	private static boolean loadBLAS(String customLibPath, String blas, String optionalMsg) {
		// First attempt to load from custom library path
		if((customLibPath != null) && (!customLibPath.equalsIgnoreCase("none"))) {
			String libPath = customLibPath + File.separator + System.mapLibraryName(blas);
			try {
				System.load(libPath);
				// Print to stdout as this feature is intended for cloud environment
				System.out.println("Loaded the library:" + libPath);
				return true;
			}
			catch (UnsatisfiedLinkError e1) { 
				// Print to stdout as this feature is intended for cloud environment
				System.out.println("Unable to load " + libPath + ":" + e1.getMessage());
			}
		}

		// Then try loading using loadLibrary
		try {
			System.loadLibrary(blas);
			return true;
		}
		catch (UnsatisfiedLinkError e) {
			// The search path is part of the debug message instead of being printed to stdout.
			String hint = (optionalMsg != null) ? "(" + optionalMsg + ")" : "";
			LOG.debug("Unable to load " + blas + hint + ":" + e.getMessage()
					+ " (java.library.path=" + System.getProperty("java.library.path") + ")");
			return false;
		}
	}


	/**
	 * Loads a native library that is packaged as the resource /lib/&lt;path&gt;.
	 * The resource is copied into a temporary file (deleted on JVM exit) because Java does not allow
	 * to load a library directly from a resource.
	 * 
	 * @param path file name of the library below /lib
	 * @return true if the library was loaded; false if the resource is missing, cannot be copied,
	 *         or cannot be linked
	 */
	private static boolean loadLibraryHelper(String path)  {
		try(InputStream in = NativeHelper.class.getResourceAsStream("/lib/"+path)) {
			if(in == null) {
				LOG.warn("No lib available in the jar:" + path);
				return false;
			}
			File temp = File.createTempFile(path, "");
			temp.deleteOnExit();
			// Close the copy before loading so that the library file is complete and no longer open for writing.
			try(OutputStream out = FileUtils.openOutputStream(temp)) {
				IOUtils.copy(in, out);
			}
			System.load(temp.getAbsolutePath());
			return true;
		}
		catch(IOException e) {
			LOG.warn("Unable to load library " + path + " from resource:" + e.getMessage());
		}
		catch(UnsatisfiedLinkError e) {
			// Report a link failure through the return value so that callers can fall back.
			LOG.warn("Unable to load library " + path + " from resource:" + e.getMessage());
		}
		return false;
	}

	// TODO: Add pmm, wsloss, mmchain, etc.
	
	//double-precision matrix multiply dense-dense
	public static native boolean dmmdd(double [] m1, double [] m2, double [] ret, int m1rlen, int m1clen, int m2clen,
									   int numThreads);
	//single-precision matrix multiply dense-dense
	public static native boolean smmdd(FloatBuffer m1, FloatBuffer m2, FloatBuffer ret, int m1rlen, int m1clen, int m2clen,
									   int numThreads);
	//transpose-self matrix multiply
	public static native boolean tsmm(double[] m1, double[] ret, int m1rlen, int m1clen, boolean leftTrans, int numThreads);

	// ----------------------------------------------------------------------------------------------------------------
	// LibMatrixDNN operations:
	// N = number of images, C = number of channels, H = image height, W = image width
	// K = number of filters, R = filter height, S = filter width
	// TODO: case not handled: sparse filters (which will only be executed in Java). Since filters are relatively smaller,
	//  this is a low priority.

	// Returns -1 if failures or returns number of nonzeros
	// Called by DnnCPInstruction if both input and filter are dense
	public static native int conv2dDense(double [] input, double [] filter, double [] ret, int N, int C, int H, int W, 
			int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q, int numThreads);

	public static native int dconv2dBiasAddDense(double [] input, double [] bias, double [] filter, double [] ret, int N,
		int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
												 int numThreads);

	public static native int sconv2dBiasAddDense(FloatBuffer input, FloatBuffer bias, FloatBuffer filter, FloatBuffer ret,
		int N, int C, int H, int W, int K, int R, int S, int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
												 int numThreads);

	// Called by DnnCPInstruction if both input and filter are dense
	public static native int conv2dBackwardFilterDense(double [] input, double [] dout, double [] ret, int N, int C,
													   int H, int W, int K, int R, int S, int stride_h, int stride_w,
													   int pad_h, int pad_w, int P, int Q, int numThreads);

	// If both filter and dout are dense, then called by DnnCPInstruction
	// Else, called by LibMatrixDNN's thread if filter is dense. dout[n] is converted to dense if sparse.
	public static native int conv2dBackwardDataDense(double [] filter, double [] dout, double [] ret, int N, int C,
													 int H, int W, int K, int R, int S, int stride_h, int stride_w,
													 int pad_h, int pad_w, int P, int Q, int numThreads);

	// Currently only supported with numThreads = 1 and sparse input
	// Called by LibMatrixDNN's thread if input is sparse. dout[n] is converted to dense if sparse.
	public static native boolean conv2dBackwardFilterSparseDense(int apos, int alen, int[] aix, double[] avals,
																 double [] rotatedDoutPtr, double [] ret, int N, int C,
																 int H, int W, int K, int R, int S, int stride_h,
																 int stride_w, int pad_h, int pad_w, int P, int Q,
																 int numThreads);

	// Called by LibMatrixDNN's thread if input is sparse and filter is dense
	public static native boolean conv2dSparse(int apos, int alen, int[] aix, double[] avals, double [] filter,
											  double [] ret, int N, int C, int H, int W, int K, int R, int S,
											  int stride_h, int stride_w, int pad_h, int pad_w, int P, int Q,
											  int numThreads);
	// ----------------------------------------------------------------------------------------------------------------

	// This method helps us decide whether to use GetPrimitiveArrayCritical or GetDoubleArrayElements in JNI as each has
	// different tradeoffs. In current implementation, we always use GetPrimitiveArrayCritical as it has proven to be
	// fastest. We can revisit this decision later and hence I would not recommend removing this method.
	private static native void setMaxNumThreads(int numThreads);
}
