/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

package org.apache.flink.api.common.functions.util;

import org.apache.flink.annotation.Internal;
import org.apache.flink.annotation.PublicEvolving;
import org.apache.flink.annotation.VisibleForTesting;
import org.apache.flink.api.common.ExecutionConfig;
import org.apache.flink.api.common.TaskInfo;
import org.apache.flink.api.common.accumulators.AbstractAccumulatorRegistry;
import org.apache.flink.api.common.accumulators.Accumulator;
import org.apache.flink.api.common.accumulators.AccumulatorHelper;
import org.apache.flink.api.common.accumulators.DoubleCounter;
import org.apache.flink.api.common.accumulators.Histogram;
import org.apache.flink.api.common.accumulators.IntCounter;
import org.apache.flink.api.common.accumulators.LongCounter;
import org.apache.flink.api.common.cache.DistributedCache;
import org.apache.flink.api.common.functions.RuntimeContext;
import org.apache.flink.api.common.state.AggregatingState;
import org.apache.flink.api.common.state.AggregatingStateDescriptor;
import org.apache.flink.api.common.state.FoldingState;
import org.apache.flink.api.common.state.FoldingStateDescriptor;
import org.apache.flink.api.common.state.ListState;
import org.apache.flink.api.common.state.ListStateDescriptor;
import org.apache.flink.api.common.state.MapState;
import org.apache.flink.api.common.state.MapStateDescriptor;
import org.apache.flink.api.common.state.ReducingState;
import org.apache.flink.api.common.state.ReducingStateDescriptor;
import org.apache.flink.api.common.state.SortedMapState;
import org.apache.flink.api.common.state.SortedMapStateDescriptor;
import org.apache.flink.api.common.state.ValueState;
import org.apache.flink.api.common.state.ValueStateDescriptor;
import org.apache.flink.core.fs.Path;
import org.apache.flink.metrics.MetricGroup;

import java.io.Serializable;
import java.util.Map;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.Future;

import static org.apache.flink.util.Preconditions.checkNotNull;

/**
 * A standalone implementation of the {@link RuntimeContext}, created by runtime UDF operators.
 *
 * <p>Task metadata is read from the supplied {@link TaskInfo}, and all accumulator operations are
 * delegated to the supplied {@link AbstractAccumulatorRegistry}. The {@link DistributedCache} is
 * built from the supplied copy tasks. Every keyed-state accessor in this class throws
 * {@link UnsupportedOperationException}; subclasses may override them.
 */
@Internal
public abstract class AbstractRuntimeUDFContext implements RuntimeContext {

	private final TaskInfo taskInfo;

	private final ClassLoader userCodeClassLoader;

	private final ExecutionConfig executionConfig;

	private final AbstractAccumulatorRegistry accumulatorRegistry;

	private final DistributedCache distributedCache;

	private final MetricGroup metrics;

	/**
	 * Creates a new runtime context.
	 *
	 * @param taskInfo metadata of the task this context belongs to; must not be null
	 * @param userCodeClassLoader class loader returned by {@link #getUserCodeClassLoader()}
	 * @param executionConfig config returned by {@link #getExecutionConfig()}
	 * @param accumulatorRegistry registry backing all accumulator methods; must not be null
	 * @param cpTasks copy tasks used to build the {@link DistributedCache}; must not be null
	 * @param metrics metric group returned by {@link #getMetricGroup()}
	 */
	public AbstractRuntimeUDFContext(
			TaskInfo taskInfo,
			ClassLoader userCodeClassLoader,
			ExecutionConfig executionConfig,
			AbstractAccumulatorRegistry accumulatorRegistry,
			Map<String, Future<Path>> cpTasks,
			MetricGroup metrics) {
		this.taskInfo = checkNotNull(taskInfo);
		this.userCodeClassLoader = userCodeClassLoader;
		this.executionConfig = executionConfig;
		this.distributedCache = new DistributedCache(checkNotNull(cpTasks));
		this.accumulatorRegistry = checkNotNull(accumulatorRegistry);
		this.metrics = metrics;
	}

	@Override
	public ExecutionConfig getExecutionConfig() {
		return executionConfig;
	}

	@Override
	public String getTaskName() {
		return taskInfo.getTaskName();
	}

	@Override
	public int getNumberOfParallelSubtasks() {
		return taskInfo.getNumberOfParallelSubtasks();
	}

	@Override
	public int getMaxNumberOfParallelSubtasks() {
		return taskInfo.getMaxNumberOfParallelSubtasks();
	}

	@Override
	public int getIndexOfThisSubtask() {
		return taskInfo.getIndexOfThisSubtask();
	}

	@Override
	public MetricGroup getMetricGroup() {
		return metrics;
	}

	@Override
	public int getAttemptNumber() {
		return taskInfo.getAttemptNumber();
	}

	@Override
	public String getTaskNameWithSubtasks() {
		return taskInfo.getTaskNameWithSubtasks();
	}

	@Override
	public IntCounter getIntCounter(String name) {
		return (IntCounter) getAccumulator(name, IntCounter.class);
	}

	@Override
	public LongCounter getLongCounter(String name) {
		return (LongCounter) getAccumulator(name, LongCounter.class);
	}

	@Override
	public Histogram getHistogram(String name) {
		return (Histogram) getAccumulator(name, Histogram.class);
	}

	@Override
	public DoubleCounter getDoubleCounter(String name) {
		return (DoubleCounter) getAccumulator(name, DoubleCounter.class);
	}

	@Override
	public <V, A extends Serializable> void addAccumulator(String name, Accumulator<V, A> accumulator) {
		accumulatorRegistry.addAccumulator(name, accumulator);
	}

	@SuppressWarnings("unchecked")
	@Override
	public <V, A extends Serializable> Accumulator<V, A> getAccumulator(String name) {
		return (Accumulator<V, A>) accumulatorRegistry.getAccumulators().get(name);
	}

	@Override
	public Map<String, Accumulator<?, ?>> getAllAccumulators() {
		return accumulatorRegistry.getAccumulators();
	}

	@Override
	public <V, A extends Serializable> void addPreAggregatedAccumulator(String name, Accumulator<V, A> accumulator) {
		accumulatorRegistry.addPreAggregatedAccumulator(name, accumulator);
	}

	@SuppressWarnings("unchecked")
	@Override
	public <V, A extends Serializable> Accumulator<V, A> getPreAggregatedAccumulator(String name) {
		return (Accumulator<V, A>) accumulatorRegistry.getPreAggregatedAccumulators().get(name);
	}

	@Override
	public void commitPreAggregatedAccumulator(String name) {
		accumulatorRegistry.commitPreAggregatedAccumulator(name);
	}

	@Override
	public <V, A extends Serializable> CompletableFuture<Accumulator<V, A>> queryPreAggregatedAccumulator(String name) {
		return accumulatorRegistry.queryPreAggregatedAccumulator(name);
	}

	@Override
	public ClassLoader getUserCodeClassLoader() {
		return this.userCodeClassLoader;
	}

	@Override
	public DistributedCache getDistributedCache() {
		return this.distributedCache;
	}

	// --------------------------------------------------------------------------------------------

	/**
	 * Returns the accumulator registered under {@code name}, creating and registering a new one
	 * if none exists yet.
	 *
	 * <p>If an accumulator is already registered, its class is checked against
	 * {@code accumulatorClass} via {@link AccumulatorHelper#compareAccumulatorTypes}. Otherwise a
	 * new instance is created through the class's no-argument constructor and added to the
	 * registry.
	 *
	 * @param name the accumulator name
	 * @param accumulatorClass the expected accumulator type; must have a no-argument constructor
	 * @return the existing or newly registered accumulator
	 * @throws RuntimeException if a new accumulator instance cannot be created (the reflective
	 *     failure is attached as the cause)
	 */
	@SuppressWarnings("unchecked")
	private <V, A extends Serializable> Accumulator<V, A> getAccumulator(
			String name,
			Class<? extends Accumulator<V, A>> accumulatorClass) {
		Accumulator<?, ?> accumulator = accumulatorRegistry.getAccumulators().get(name);

		if (accumulator != null) {
			AccumulatorHelper.compareAccumulatorTypes(name, accumulator.getClass(), accumulatorClass);
		} else {
			// Create new accumulator. Constructor#newInstance wraps constructor failures instead of
			// rethrowing checked exceptions unchecked, as the deprecated Class#newInstance does.
			try {
				accumulator = accumulatorClass.getDeclaredConstructor().newInstance();
			} catch (Exception e) {
				// Preserve the cause so reflection/constructor failures remain diagnosable.
				throw new RuntimeException("Cannot create accumulator " + accumulatorClass.getName(), e);
			}
			accumulatorRegistry.addAccumulator(name, accumulator);
		}
		// Safe: the type was either verified above or the instance was created from accumulatorClass.
		return (Accumulator<V, A>) accumulator;
	}

	@Override
	@PublicEvolving
	public <T> ValueState<T> getState(ValueStateDescriptor<T> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	@PublicEvolving
	public <T> ListState<T> getListState(ListStateDescriptor<T> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	@PublicEvolving
	public <T> ReducingState<T> getReducingState(ReducingStateDescriptor<T> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	@PublicEvolving
	public <IN, ACC, OUT> AggregatingState<IN, OUT> getAggregatingState(AggregatingStateDescriptor<IN, ACC, OUT> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	@PublicEvolving
	@Deprecated
	public <T, ACC> FoldingState<T, ACC> getFoldingState(FoldingStateDescriptor<T, ACC> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	@PublicEvolving
	public <UK, UV> MapState<UK, UV> getMapState(MapStateDescriptor<UK, UV> stateProperties) {
		throw new UnsupportedOperationException(
				"This state is only accessible by functions executed on a KeyedStream");
	}

	@Override
	public <UK, UV> SortedMapState<UK, UV> getSortedMapState(SortedMapStateDescriptor<UK, UV> stateProperties) {
		throw new UnsupportedOperationException(
			"This state is only accessible by functions executed on a KeyedStream");
	}

	/**
	 * Returns the allocation ID string reported by this context's {@link TaskInfo}.
	 *
	 * @return the value of {@link TaskInfo#getAllocationIDAsString()}
	 */
	@Internal
	@VisibleForTesting
	public String getAllocationIDAsString() {
		return taskInfo.getAllocationIDAsString();
	}
}
