/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied.  See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.sysds.runtime.controlprogram.federated;

import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Map.Entry;
import java.util.TreeMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Future;
import java.util.function.BiFunction;

import org.apache.commons.lang3.tuple.ImmutablePair;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.controlprogram.caching.CacheBlock;
import org.apache.sysds.runtime.controlprogram.caching.CacheableData;
import org.apache.sysds.runtime.controlprogram.federated.FederatedRequest.RequestType;
import org.apache.sysds.runtime.instructions.cp.ScalarObject;
import org.apache.sysds.runtime.instructions.cp.VariableCPInstruction;
import org.apache.sysds.runtime.matrix.data.MatrixBlock;
import org.apache.sysds.runtime.util.CommonThreadPool;

/**
 * Mapping of {@link FederatedRange}s to the {@link FederatedData} objects associated with them, together
 * with helpers to create broadcast requests, execute federated requests on all entries, and derive new
 * or modified maps (copy, rbind, transpose, parallel mapping).
 */
public class FederationMap
{
	/** Partitioning type of the federated data described by a map. */
	public enum FType {
		ROW, //row partitioned, groups of rows
		COL, //column partitioned, groups of columns
		OTHER,
	}
	
	//variable ID of the federated data, negative if not initialized (see isInitialized)
	private long _ID = -1;
	private final Map<FederatedRange, FederatedData> _fedMap;
	private FType _type;
	
	/**
	 * Creates an uninitialized map (ID -1) of type {@link FType#OTHER}.
	 * 
	 * @param fedMap mapping of ranges to federated data, used directly (not copied)
	 */
	public FederationMap(Map<FederatedRange, FederatedData> fedMap) {
		this(-1, fedMap);
	}
	
	/**
	 * Creates a map with the given ID and type {@link FType#OTHER}.
	 * 
	 * @param ID     variable ID of the federated data
	 * @param fedMap mapping of ranges to federated data, used directly (not copied)
	 */
	public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap) {
		this(ID, fedMap, FType.OTHER);
	}
	
	/**
	 * Creates a map with the given ID and partitioning type.
	 * 
	 * @param ID     variable ID of the federated data
	 * @param fedMap mapping of ranges to federated data, used directly (not copied)
	 * @param type   partitioning type
	 */
	public FederationMap(long ID, Map<FederatedRange, FederatedData> fedMap, FType type) {
		_ID = ID;
		_fedMap = fedMap;
		_type = type;
	}
	
	/**
	 * @return the variable ID of this map
	 */
	public long getID() {
		return _ID;
	}
	
	/**
	 * @return the partitioning type of this map
	 */
	public FType getType() {
		return _type;
	}
	
	/**
	 * @return true if the ID of this map is non-negative
	 */
	public boolean isInitialized() {
		return _ID >= 0;
	}
	
	/**
	 * @param type the new partitioning type of this map
	 */
	public void setType(FType type) {
		_type = type;
	}
	
	/**
	 * @return a new array with all federated ranges (keys) of this map
	 */
	public FederatedRange[] getFederatedRanges() {
		return _fedMap.keySet().toArray(new FederatedRange[0]);
	}
	
	/**
	 * Prepares a single PUT_VAR request, usable for all federated data, which carries the
	 * cache block of the given data under a newly generated ID. The request is only created,
	 * not executed.
	 * 
	 * @param data the data to broadcast
	 * @return the prepared request
	 */
	public FederatedRequest broadcast(CacheableData<?> data) {
		//prepare single request for all federated data
		long id = FederationUtils.getNextFedDataID();
		CacheBlock cb = data.acquireReadAndRelease();
		return new FederatedRequest(RequestType.PUT_VAR, id, cb);
	}
	
	/**
	 * Prepares a single PUT_VAR request, usable for all federated data, which carries the
	 * given scalar under a newly generated ID. The request is only created, not executed.
	 * 
	 * @param scalar the scalar to broadcast
	 * @return the prepared request
	 */
	public FederatedRequest broadcast(ScalarObject scalar) {
		//prepare single request for all federated data
		long id = FederationUtils.getNextFedDataID();
		return new FederatedRequest(RequestType.PUT_VAR, id, scalar);
	}
	
	/**
	 * Prepares one PUT_VAR request per federated range, each carrying the slice of the given data
	 * that corresponds to the range. All requests share the same newly generated ID and are returned
	 * in the iteration order of this map. The requests are only created, not executed.
	 * 
	 * @param data       the data to slice and broadcast
	 * @param transposed if false, the rows given by the first dimension of each range are sliced
	 *                   (all columns); if true, the first dimension of each range selects the columns
	 *                   (all rows)
	 * @return the prepared requests, one per federated range
	 */
	public FederatedRequest[] broadcastSliced(CacheableData<?> data, boolean transposed) {
		//prepare separate requests for different slices
		long id = FederationUtils.getNextFedDataID();
		CacheBlock cb = data.acquireReadAndRelease();
		List<FederatedRequest> ret = new ArrayList<>(_fedMap.size());
		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
			//begin (inclusive) and end (exclusive) of the first range dimension
			int begin = e.getKey().getBeginDimsInt()[0];
			int end = e.getKey().getEndDimsInt()[0];
			int rl = transposed ? 0 : begin;
			int ru = transposed ? cb.getNumRows()-1 : end-1;
			int cl = transposed ? begin : 0;
			int cu = transposed ? end-1 : cb.getNumColumns()-1;
			CacheBlock tmp = cb.slice(rl, ru, cl, cu, new MatrixBlock());
			ret.add(new FederatedRequest(RequestType.PUT_VAR, id, tmp));
		}
		return ret.toArray(new FederatedRequest[0]);
	}
	
	/**
	 * Determines if the two federated data are aligned row/column partitions
	 * at the same federated site (which allows for purely federated operation).
	 * Every range of this map (optionally transposed) needs a counterpart with equal
	 * address in the other map; a range without counterpart makes the maps unaligned.
	 * 
	 * @param that       the other federation map
	 * @param transposed if true, the ranges of this map are transposed before the lookup
	 * @return true if all ranges of this map have an entry with equal address in that map
	 */
	public boolean isAligned(FederationMap that, boolean transposed) {
		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet()) {
			FederatedRange range = !transposed ? e.getKey() :
				new FederatedRange(e.getKey()).transpose();
			FederatedData dat2 = that._fedMap.get(range);
			//missing counterpart or different site -> not aligned (explicit null
			//check to not rely on the null handling of equalAddress)
			if( dat2 == null || !e.getValue().equalAddress(dat2) )
				return false;
		}
		return true;
	}
	
	/**
	 * Executes the given requests on all federated data without waiting for the responses.
	 * 
	 * @param tid thread ID set on all requests
	 * @param fr  requests to execute on every federated data
	 * @return the future responses, one per federated data
	 */
	public Future<FederatedResponse>[] execute(long tid, FederatedRequest... fr) {
		return execute(tid, false, fr);
	}
	
	/**
	 * Executes the given requests on all federated data, optionally waiting for the responses.
	 * 
	 * @param tid  thread ID set on all requests
	 * @param wait if true, block until all responses are available
	 * @param fr   requests to execute on every federated data
	 * @return the future responses, one per federated data
	 */
	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest... fr) {
		return execute(tid, wait, null, fr);
	}
	
	/**
	 * Executes a data-specific request followed by the given common requests on all federated
	 * data without waiting for the responses.
	 * 
	 * @param tid      thread ID set on all requests
	 * @param frSlices data-specific requests, one per federated data in map iteration order
	 * @param fr       requests to execute on every federated data
	 * @return the future responses, one per federated data
	 */
	public Future<FederatedResponse>[] execute(long tid, FederatedRequest[] frSlices, FederatedRequest... fr) {
		return execute(tid, false, frSlices, fr);
	}
	
	/**
	 * Executes the given requests on all federated data. If slice requests are given, the i-th
	 * federated data (in map iteration order) receives the i-th slice request as first request,
	 * followed by the common requests.
	 * 
	 * @param tid      thread ID set on all requests
	 * @param wait     if true, block until all responses are available
	 * @param frSlices data-specific requests, one per federated data (may be null)
	 * @param fr       requests to execute on every federated data
	 * @return the future responses, one per federated data in map iteration order
	 */
	@SuppressWarnings("unchecked")
	public Future<FederatedResponse>[] execute(long tid, boolean wait, FederatedRequest[] frSlices, FederatedRequest... fr) {
		// executes step1[] - step 2 - ... step4 (only first step federated-data-specific)
		setThreadID(tid, frSlices, fr);
		List<Future<FederatedResponse>> ret = new ArrayList<>(_fedMap.size());
		int pos = 0;
		for(Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
			ret.add(e.getValue().executeFederatedOperation(
				(frSlices!=null) ? addAll(frSlices[pos++], fr) : fr));
		
		// prepare results (future federated responses), with optional wait to ensure the 
		// order of requests without data dependencies (e.g., cleanup RPCs)
		if( wait )
			waitFor(ret);
		return ret.toArray(new Future[0]);
	}
	
	/**
	 * Issues a GET_VAR request for the ID of this map to every federated data.
	 * 
	 * @return pairs of federated range and future response, in map iteration order
	 * @throws DMLRuntimeException if this map is not initialized
	 */
	public List<Pair<FederatedRange, Future<FederatedResponse>>> requestFederatedData() {
		if( !isInitialized() )
			throw new DMLRuntimeException("Federated matrix read only supported on initialized FederatedData");
		
		List<Pair<FederatedRange, Future<FederatedResponse>>> readResponses = new ArrayList<>(_fedMap.size());
		FederatedRequest request = new FederatedRequest(RequestType.GET_VAR, _ID);
		for(Map.Entry<FederatedRange, FederatedData> e : _fedMap.entrySet())
			readResponses.add(new ImmutablePair<>(e.getKey(), 
				e.getValue().executeFederatedOperation(request)));
		return readResponses;
	}
	
	/**
	 * Sends a remove instruction for the given variable IDs to every federated data and
	 * waits until all responses are available.
	 * 
	 * @param tid thread ID set on the request
	 * @param id  IDs of the variables to remove
	 */
	public void cleanup(long tid, long... id) {
		FederatedRequest request = new FederatedRequest(RequestType.EXEC_INST, -1,
			VariableCPInstruction.prepareRemoveInstruction(id).toString());
		request.setTID(tid);
		List<Future<FederatedResponse>> tmp = new ArrayList<>(_fedMap.size());
		for(FederatedData fd : _fedMap.values())
			tmp.add(fd.executeFederatedOperation(request));
		//wait to avoid interference w/ following requests
		waitFor(tmp);
	}
	
	/**
	 * Blocks until all given responses are available.
	 * 
	 * @param responses the future responses to wait for
	 * @throws DMLRuntimeException wrapping any exception raised while waiting
	 */
	private static void waitFor(List<Future<FederatedResponse>> responses) {
		try {
			for(Future<FederatedResponse> fr : responses)
				fr.get();
		}
		catch(InterruptedException ex) {
			//restore the interrupt status so callers can still observe the interruption
			Thread.currentThread().interrupt();
			throw new DMLRuntimeException(ex);
		}
		catch(Exception ex) {
			throw new DMLRuntimeException(ex);
		}
	}
	
	/**
	 * Creates a new array with request a as first element followed by all requests of b.
	 * 
	 * @param a the request to prepend
	 * @param b the remaining requests
	 * @return the concatenated array of length b.length + 1
	 */
	private static FederatedRequest[] addAll(FederatedRequest a, FederatedRequest[] b) {
		FederatedRequest[] ret = new FederatedRequest[b.length + 1];
		ret[0] = a;
		System.arraycopy(b, 0, ret, 1, b.length);
		return ret;
	}
	
	/**
	 * Copies this map under a newly generated ID.
	 * 
	 * @return the copied map
	 */
	public FederationMap copyWithNewID() {
		return copyWithNewID(FederationUtils.getNextFedDataID());
	}
	
	/**
	 * Copies this map, including copies of all ranges and federated data, under the given ID.
	 * The type of this map is retained.
	 * 
	 * @param id the ID of the new map and its federated data
	 * @return the copied map
	 */
	public FederationMap copyWithNewID(long id) {
		Map<FederatedRange, FederatedData> map = new TreeMap<>();
		//TODO handling of file path, but no danger as never written
		for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
			map.put(new FederatedRange(e.getKey()), new FederatedData(e.getValue(), id));
		return new FederationMap(id, map, _type);
	}
	
	/**
	 * Copies this map under the given ID, creating the new ranges from the existing ranges
	 * and the given number of columns. Note that the type is not retained, the resulting map
	 * has type {@link FType#OTHER}.
	 * 
	 * @param id   the ID of the new map and its federated data
	 * @param clen the column length passed to the range copies
	 * @return the copied map
	 */
	public FederationMap copyWithNewID(long id, long clen) {
		Map<FederatedRange, FederatedData> map = new TreeMap<>();
		//TODO handling of file path, but no danger as never written
		for( Entry<FederatedRange, FederatedData> e : _fedMap.entrySet() )
			map.put(new FederatedRange(e.getKey(), clen), new FederatedData(e.getValue(), id));
		return new FederationMap(id, map);
	}

	/**
	 * Adds copies of all entries of the given map to this map (in-place), with the ranges
	 * shifted by the given offset in the first dimension and the data bound to the ID of this map.
	 * 
	 * @param offset the shift applied to the first dimension of the added ranges
	 * @param that   the map whose entries are appended
	 * @return this map
	 */
	public FederationMap rbind(long offset, FederationMap that) {
		for( Entry<FederatedRange, FederatedData> e : that._fedMap.entrySet() ) {
			_fedMap.put(
				new FederatedRange(e.getKey()).shift(offset, 0),
				new FederatedData(e.getValue(), _ID));
		}
		return this;
	}
	
	/**
	 * Transposes this map in-place: all entries are replaced by copies with transposed ranges,
	 * and the type is switched between ROW and COL (any other type becomes OTHER).
	 * 
	 * @return this map
	 */
	public FederationMap transpose() {
		//snapshot of the current entries, as the backing map is rebuilt below
		Map<FederatedRange, FederatedData> tmp = new TreeMap<>(_fedMap);
		_fedMap.clear();
		for( Entry<FederatedRange, FederatedData> e : tmp.entrySet() ) {
			_fedMap.put(
				new FederatedRange(e.getKey()).transpose(),
				new FederatedData(e.getValue(), _ID));
		}
		//derive output type
		switch(_type) {
			case ROW: _type = FType.COL; break;
			case COL: _type = FType.ROW; break;
			default: _type = FType.OTHER;
		}
		return this;
	}


	/**
	 * Execute a function for each {@code FederatedRange} + {@code FederatedData} pair. The function should
	 * not change any data of the pair and instead use {@code mapParallel} if that is a necessity. Note that this
	 * operation is parallel and necessary synchronisation has to be performed.
	 * 
	 * @param forEachFunction function to execute for each pair
	 */
	public void forEachParallel(BiFunction<FederatedRange, FederatedData, Void> forEachFunction) {
		ArrayList<MappingTask> mappingTasks = new ArrayList<>(_fedMap.size());
		for(Map.Entry<FederatedRange, FederatedData> fedMap : _fedMap.entrySet())
			mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), forEachFunction, _ID));
		
		//acquire the pool directly before use, such that it is always handed to invokeAndShutdown
		ExecutorService pool = CommonThreadPool.get(_fedMap.size());
		CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
	}

	/**
	 * Execute a function for each {@code FederatedRange} + {@code FederatedData} pair mapping the pairs to
	 * their new form by directly changing both {@code FederatedRange} and {@code FederatedData}. The varIDs
	 * don't have to be changed by the {@code mappingFunction} as that will be done by this method. Note that this
	 * operation is parallel and necessary synchronisation has to be performed. The function is applied to
	 * the entries of a copy, the entries of this map are not passed to it.
	 *
	 * @param newVarID        the new varID to be used by the new FederationMap
	 * @param mappingFunction the function directly changing ranges and data
	 * @return the new {@code FederationMap}
	 */
	public FederationMap mapParallel(long newVarID, BiFunction<FederatedRange, FederatedData, Void> mappingFunction) {
		FederationMap fedMapCopy = copyWithNewID(_ID);
		ArrayList<MappingTask> mappingTasks = new ArrayList<>(_fedMap.size());
		for(Map.Entry<FederatedRange, FederatedData> fedMap : fedMapCopy._fedMap.entrySet())
			mappingTasks.add(new MappingTask(fedMap.getKey(), fedMap.getValue(), mappingFunction, newVarID));
		
		//acquire the pool after the copy, such that a failing copy cannot leave a pool without shutdown
		ExecutorService pool = CommonThreadPool.get(_fedMap.size());
		CommonThreadPool.invokeAndShutdown(pool, mappingTasks);
		fedMapCopy._ID = newVarID;
		return fedMapCopy;
	}
	
	/**
	 * Sets the given thread ID on all requests of all non-null request arrays.
	 * 
	 * @param tid    the thread ID
	 * @param frsets arrays of requests, individual arrays may be null
	 */
	private static void setThreadID(long tid, FederatedRequest[]... frsets) {
		for( FederatedRequest[] frset : frsets )
			if( frset != null )
				Arrays.stream(frset).forEach(fr -> fr.setTID(tid));
	}

	/**
	 * Task applying a function to a single range/data pair and subsequently
	 * setting the given varID on the data.
	 */
	private static class MappingTask implements Callable<Void> {
		private final FederatedRange _range;
		private final FederatedData _data;
		private final BiFunction<FederatedRange, FederatedData, Void> _mappingFunction;
		private final long _varID;

		public MappingTask(FederatedRange range, FederatedData data,
			BiFunction<FederatedRange, FederatedData, Void> mappingFunction, long varID) {
			_range = range;
			_data = data;
			_mappingFunction = mappingFunction;
			_varID = varID;
		}

		@Override
		public Void call() throws Exception {
			_mappingFunction.apply(_range, _data);
			//varID is set after the function, overriding any ID set by it
			_data.setVarID(_varID);
			return null;
		}
	}
}
