/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.apex.translation.operators;

import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkState;

import com.datatorrent.api.Context.OperatorContext;
import com.datatorrent.api.DefaultInputPort;
import com.datatorrent.api.DefaultOutputPort;
import com.datatorrent.api.annotation.InputPortFieldAnnotation;
import com.datatorrent.api.annotation.OutputPortFieldAnnotation;
import com.datatorrent.common.util.BaseOperator;
import com.esotericsoftware.kryo.serializers.FieldSerializer.Bind;
import com.esotericsoftware.kryo.serializers.JavaSerializer;
import com.google.common.collect.Iterables;
import com.google.common.collect.Maps;
import java.util.ArrayList;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.concurrent.Executors;
import java.util.concurrent.ScheduledExecutorService;
import org.apache.beam.runners.apex.ApexPipelineOptions;
import org.apache.beam.runners.apex.ApexRunner;
import org.apache.beam.runners.apex.translation.utils.ApexStateInternals.ApexStateBackend;
import org.apache.beam.runners.apex.translation.utils.ApexStreamTuple;
import org.apache.beam.runners.apex.translation.utils.NoOpStepContext;
import org.apache.beam.runners.apex.translation.utils.StateInternalsProxy;
import org.apache.beam.runners.apex.translation.utils.ValueAndCoderKryoSerializable;
import org.apache.beam.runners.core.DoFnRunner;
import org.apache.beam.runners.core.DoFnRunners;
import org.apache.beam.runners.core.DoFnRunners.OutputManager;
import org.apache.beam.runners.core.KeyedWorkItem;
import org.apache.beam.runners.core.KeyedWorkItemCoder;
import org.apache.beam.runners.core.NullSideInputReader;
import org.apache.beam.runners.core.OutputAndTimeBoundedSplittableProcessElementInvoker;
import org.apache.beam.runners.core.OutputWindowedValue;
import org.apache.beam.runners.core.PushbackSideInputDoFnRunner;
import org.apache.beam.runners.core.SideInputHandler;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.core.SimplePushbackSideInputDoFnRunner;
import org.apache.beam.runners.core.SplittableParDoViaKeyedWorkItems.ProcessFn;
import org.apache.beam.runners.core.StateInternals;
import org.apache.beam.runners.core.StateInternalsFactory;
import org.apache.beam.runners.core.StateNamespace;
import org.apache.beam.runners.core.StateNamespaces.WindowNamespace;
import org.apache.beam.runners.core.StatefulDoFnRunner;
import org.apache.beam.runners.core.TimerInternals;
import org.apache.beam.runners.core.TimerInternals.TimerData;
import org.apache.beam.runners.core.TimerInternalsFactory;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VoidCoder;
import org.apache.beam.sdk.state.TimeDomain;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.reflect.DoFnInvoker;
import org.apache.beam.sdk.transforms.reflect.DoFnInvokers;
import org.apache.beam.sdk.transforms.reflect.DoFnSignature;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.transforms.splittabledofn.RestrictionTracker;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.transforms.windowing.PaneInfo;
import org.apache.beam.sdk.util.UserCodeException;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.util.WindowedValue.FullWindowedValueCoder;
import org.apache.beam.sdk.util.WindowedValue.WindowedValueCoder;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;

/**
 * Apex operator for Beam {@link DoFn}.
 */
public class ApexParDoOperator<InputT, OutputT> extends BaseOperator implements OutputManager,
    ApexTimerInternals.TimerProcessor<Object> {
  private static final Logger LOG = LoggerFactory.getLogger(ApexParDoOperator.class);
  // Whether incoming/outgoing tuples are written to the debug log. Starts out enabled and is
  // re-evaluated from the pipeline options in setup().
  private boolean traceTuples = true;

  // Construction-time configuration. Each of these fields is bound to Kryo's JavaSerializer
  // through the @Bind annotation.
  @Bind(JavaSerializer.class)
  private final SerializablePipelineOptions pipelineOptions;
  @Bind(JavaSerializer.class)
  private final DoFn<InputT, OutputT> doFn;
  @Bind(JavaSerializer.class)
  private final TupleTag<OutputT> mainOutputTag;
  @Bind(JavaSerializer.class)
  private final List<TupleTag<?>> additionalOutputTags;
  @Bind(JavaSerializer.class)
  private final WindowingStrategy<?, ?> windowingStrategy;
  @Bind(JavaSerializer.class)
  private final List<PCollectionView<?>> sideInputs;
  @Bind(JavaSerializer.class)
  private final Coder<WindowedValue<InputT>> inputCoder;

  // Keyed state access. Only assigned by the public constructor when the DoFn is a ProcessFn or
  // its signature uses state; otherwise it stays null and no key context is set per element.
  private StateInternalsProxy<?> currentKeyStateInternals;
  // Timer internals whose key context is set before each element / timer callback.
  private final ApexTimerInternals<Object> currentKeyTimerInternals;

  // State internals handed to the SideInputHandler created in setup().
  private final StateInternals sideInputStateInternals;
  // Elements handed back by the pushback runner; they are retried whenever side input data
  // arrives on sideInput1.
  private final ValueAndCoderKryoSerializable<List<WindowedValue<InputT>>> pushedBack;
  // Minimum timestamp (millis) over the currently pushed back elements; Long.MAX_VALUE if none.
  private LongMin pushedBackWatermark = new LongMin();
  // Timestamp of the last watermark handled by processWatermark().
  private long currentInputWatermark = Long.MIN_VALUE;
  // Last watermark emitted by processWatermark() on the side-input code path.
  private long currentOutputWatermark = currentInputWatermark;

  // Transient runtime objects, (re)created in setup().
  private transient PushbackSideInputDoFnRunner<InputT, OutputT> pushbackDoFnRunner;
  private transient SideInputHandler sideInputHandler;
  // Maps each additional output tag to the output port assigned to it in setup().
  private transient Map<TupleTag<?>, DefaultOutputPort<ApexStreamTuple<?>>>
      additionalOutputPortMapping = Maps.newHashMapWithExpectedSize(5);
  private transient DoFnInvoker<InputT, OutputT> doFnInvoker;

  /**
   * Creates the operator for the given {@link DoFn}.
   *
   * @param pipelineOptions options, kept in serializable form
   * @param doFn the user function to execute
   * @param mainOutputTag tag of the main output
   * @param additionalOutputTags tags of the additional outputs; at most as many as there are
   *     additional output ports
   * @param windowingStrategy windowing strategy; its window coder is used for the element and
   *     timer coders
   * @param sideInputs side input views
   * @param linputCoder coder for the element values; must be a {@link KvCoder} when the
   *     {@link DoFn} uses state
   * @param stateBackend factory for the side input state and (if needed) the keyed state
   * @throws UnsupportedOperationException if there are more additional output tags than ports
   * @throws IllegalArgumentException if a stateful {@link DoFn} is given a non-KV input coder
   */
  public ApexParDoOperator(
      ApexPipelineOptions pipelineOptions,
      DoFn<InputT, OutputT> doFn,
      TupleTag<OutputT> mainOutputTag,
      List<TupleTag<?>> additionalOutputTags,
      WindowingStrategy<?, ?> windowingStrategy,
      List<PCollectionView<?>> sideInputs,
      Coder<InputT> linputCoder,
      ApexStateBackend stateBackend
      ) {
    this.pipelineOptions = new SerializablePipelineOptions(pipelineOptions);
    this.doFn = doFn;
    this.mainOutputTag = mainOutputTag;
    this.additionalOutputTags = additionalOutputTags;
    this.sideInputs = sideInputs;
    this.windowingStrategy = windowingStrategy;
    this.sideInputStateInternals = new StateInternalsProxy<>(
        stateBackend.newStateInternalsFactory(VoidCoder.of()));

    // Only a fixed number of additional output ports is declared on this operator.
    final int supportedAdditionalOutputs = additionalOutputPorts.length;
    if (additionalOutputTags.size() > supportedAdditionalOutputs) {
      throw new UnsupportedOperationException(String.format(
          "Too many additional outputs (currently only supporting %s).",
          supportedAdditionalOutputs));
    }

    // Coder for the input elements, reused for the list of pushed back elements.
    WindowedValueCoder<InputT> windowedInputCoder =
        FullWindowedValueCoder.of(
            linputCoder, this.windowingStrategy.getWindowFn().windowCoder());
    this.inputCoder = windowedInputCoder;
    Coder<List<WindowedValue<InputT>>> pushedBackCoder = ListCoder.of(windowedInputCoder);
    this.pushedBack = new ValueAndCoderKryoSerializable<>(
        new ArrayList<WindowedValue<InputT>>(), pushedBackCoder);

    this.currentKeyTimerInternals = new ApexTimerInternals<>(
        TimerInternals.TimerDataCoder.of(windowingStrategy.getWindowFn().windowCoder()));

    // Determine the key coder, if keyed state is needed at all.
    final Coder<?> keyCoder;
    if (doFn instanceof ProcessFn) {
      // we know that it is keyed on String
      keyCoder = StringUtf8Coder.of();
    } else if (DoFnSignatures.getSignature(doFn.getClass()).usesState()) {
      checkArgument(linputCoder instanceof KvCoder, "keyed input required for stateful DoFn");
      @SuppressWarnings("rawtypes")
      KvCoder kvCoder = (KvCoder) linputCoder;
      keyCoder = kvCoder.getKeyCoder();
    } else {
      keyCoder = null;
    }
    if (keyCoder != null) {
      this.currentKeyStateInternals = new StateInternalsProxy<>(
          stateBackend.newStateInternalsFactory(keyCoder));
    }
  }

  /**
   * No-argument constructor that leaves every final field {@code null}; present for Kryo only.
   */
  @SuppressWarnings("unused") // for Kryo
  private ApexParDoOperator() {
    // user function and its configuration
    this.doFn = null;
    this.pipelineOptions = null;
    this.windowingStrategy = null;
    this.inputCoder = null;
    // outputs and side inputs
    this.mainOutputTag = null;
    this.additionalOutputTags = null;
    this.sideInputs = null;
    // state, timers and pushed back elements
    this.sideInputStateInternals = null;
    this.currentKeyTimerInternals = null;
    this.pushedBack = null;
  }

  /**
   * Main input port. Watermark tuples are forwarded to {@code processWatermark}; data tuples are
   * run through the DoFn runner, and whatever the runner hands back is remembered as pushed back.
   */
  public final transient DefaultInputPort<ApexStreamTuple<WindowedValue<InputT>>> input =
      new DefaultInputPort<ApexStreamTuple<WindowedValue<InputT>>>() {
    @Override
    public void process(ApexStreamTuple<WindowedValue<InputT>> tuple) {
      if (tuple instanceof ApexStreamTuple.WatermarkTuple) {
        processWatermark((ApexStreamTuple.WatermarkTuple<?>) tuple);
        return;
      }

      if (traceTuples) {
        LOG.debug("\ninput {}\n", tuple.getValue());
      }
      for (WindowedValue<InputT> notReady : processElementInReadyWindows(tuple.getValue())) {
        // track the minimum timestamp of everything that is still waiting
        pushedBackWatermark.add(notReady.getTimestamp().getMillis());
        pushedBack.get().add(notReady);
      }
    }
  };

  /**
   * Optional side input port. Every data tuple is added to the side input handler for the view
   * selected by the tuple's union tag, after which all pushed back elements are retried and the
   * current input watermark is re-processed.
   */
  @InputPortFieldAnnotation(optional = true)
  public final transient DefaultInputPort<ApexStreamTuple<WindowedValue<Iterable<?>>>> sideInput1 =
      new DefaultInputPort<ApexStreamTuple<WindowedValue<Iterable<?>>>>() {
    @Override
    public void process(ApexStreamTuple<WindowedValue<Iterable<?>>> tuple) {
      if (tuple instanceof ApexStreamTuple.WatermarkTuple) {
        // ignore side input watermarks
        return;
      }

      // the union tag of a data tuple selects the view; anything else maps to the first one
      final int viewIndex = (tuple instanceof ApexStreamTuple.DataTuple)
          ? ((ApexStreamTuple.DataTuple<?>) tuple).getUnionTag()
          : 0;

      if (traceTuples) {
        LOG.debug("\nsideInput {} {}\n", viewIndex, tuple.getValue());
      }

      PCollectionView<?> view = sideInputs.get(viewIndex);
      sideInputHandler.addSideInputValue(view, tuple.getValue());

      // retry everything that was waiting; collect what is still not ready
      List<WindowedValue<InputT>> stillWaiting = new ArrayList<>();
      for (WindowedValue<InputT> waiting : pushedBack.get()) {
        Iterables.addAll(stillWaiting, processElementInReadyWindows(waiting));
      }

      // rebuild the pushed back list and its minimum timestamp from scratch
      pushedBack.get().clear();
      pushedBackWatermark.clear();
      for (WindowedValue<InputT> waiting : stillWaiting) {
        pushedBackWatermark.add(waiting.getTimestamp().getMillis());
        pushedBack.get().add(waiting);
      }

      // potentially emit watermark
      processWatermark(ApexStreamTuple.WatermarkTuple.of(currentInputWatermark));
    }
  };

  // Main output port; also receives tuples whose tag has no additional output port mapping.
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> output = new DefaultOutputPort<>();

  // Fixed set of optional ports for additional outputs.
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> additionalOutput1 =
      new DefaultOutputPort<>();
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> additionalOutput2 =
      new DefaultOutputPort<>();
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> additionalOutput3 =
      new DefaultOutputPort<>();
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> additionalOutput4 =
      new DefaultOutputPort<>();
  @OutputPortFieldAnnotation(optional = true)
  public final transient DefaultOutputPort<ApexStreamTuple<?>> additionalOutput5 =
      new DefaultOutputPort<>();

  // The additional output ports by position: setup() assigns the i-th additional output tag to
  // the i-th entry, and the public constructor rejects more tags than there are entries.
  public final transient DefaultOutputPort<?>[] additionalOutputPorts = {
    additionalOutput1, additionalOutput2, additionalOutput3, additionalOutput4, additionalOutput5
  };

  /**
   * Emits the value as a data tuple on the port mapped to {@code tag}, falling back to the main
   * output port when the tag has no additional output port.
   *
   * @param tag the output tag the value belongs to
   * @param tuple the value to emit
   */
  @Override
  public <T> void output(TupleTag<T> tag, WindowedValue<T> tuple) {
    DefaultOutputPort<ApexStreamTuple<?>> targetPort = additionalOutputPortMapping.get(tag);
    if (targetPort == null) {
      // not an additional output: use the main output port
      targetPort = output;
    }
    targetPort.emit(ApexStreamTuple.DataTuple.of(tuple));

    if (traceTuples) {
      LOG.debug("\nemitting {}\n", tuple);
    }
  }

  /**
   * Runs one element through the pushback runner inside its own bundle. When keyed state is in
   * use, the state and timer internals are first pointed at the element's key.
   *
   * <p>If user code fails with an {@link AssertionError}, the error is recorded in
   * {@link ApexRunner#ASSERTION_ERROR} before the exception is rethrown.
   *
   * @param elem the element to process
   * @return what the runner's {@code processElementInReadyWindows} handed back for this element
   * @throws UserCodeException rethrown as-is when user code fails
   */
  private Iterable<WindowedValue<InputT>> processElementInReadyWindows(WindowedValue<InputT> elem) {
    try {
      pushbackDoFnRunner.startBundle();
      if (currentKeyStateInternals != null) {
        bindKeyContext(elem.getValue());
      }
      Iterable<WindowedValue<InputT>> notReady =
          pushbackDoFnRunner.processElementInReadyWindows(elem);
      pushbackDoFnRunner.finishBundle();
      return notReady;
    } catch (UserCodeException userFailure) {
      Throwable cause = userFailure.getCause();
      if (cause instanceof AssertionError) {
        ApexRunner.ASSERTION_ERROR.set((AssertionError) cause);
      }
      throw userFailure;
    }
  }

  /**
   * Extracts the key (and its coder) from a {@link KeyedWorkItem} or {@link KV} value and makes
   * it the current key of the state and timer internals.
   */
  private void bindKeyContext(InputT value) {
    @SuppressWarnings({ "rawtypes", "unchecked" })
    WindowedValueCoder<InputT> windowedCoder = (WindowedValueCoder) inputCoder;

    final Object key;
    final Coder<Object> keyCoder;
    if (value instanceof KeyedWorkItem) {
      key = ((KeyedWorkItem) value).key();
      @SuppressWarnings({ "rawtypes", "unchecked" })
      KeyedWorkItemCoder<Object, ?> workItemCoder =
          (KeyedWorkItemCoder) windowedCoder.getValueCoder();
      keyCoder = workItemCoder.getKeyCoder();
    } else {
      // anything that is not a keyed work item is treated as a KV
      key = ((KV) value).getKey();
      @SuppressWarnings({ "rawtypes", "unchecked" })
      KvCoder<Object, ?> kvCoder = (KvCoder) windowedCoder.getValueCoder();
      keyCoder = kvCoder.getKeyCoder();
    }

    ((StateInternalsProxy) currentKeyStateInternals).setKey(key);
    Instant inputMark = new Instant(this.currentInputWatermark);
    Instant outputMark = new Instant(this.currentOutputWatermark);
    currentKeyTimerInternals.setContext(key, keyCoder, inputMark, outputMark);
  }

  /**
   * Delivers the given timers of one key to the DoFn runner within a single bundle, after
   * pointing the state and timer internals at that key.
   *
   * @param key the key the timers belong to
   * @param timerDataSet the timers to fire; each must live in a window namespace
   * @throws IllegalArgumentException if a timer's namespace is not a {@link WindowNamespace}
   */
  @Override
  public void fireTimer(Object key, Collection<TimerData> timerDataSet) {
    pushbackDoFnRunner.startBundle();

    @SuppressWarnings("unchecked")
    Coder<Object> keyCoder = (Coder) currentKeyStateInternals.getKeyCoder();
    ((StateInternalsProxy) currentKeyStateInternals).setKey(key);
    Instant inputMark = new Instant(this.currentInputWatermark);
    Instant outputMark = new Instant(this.currentOutputWatermark);
    currentKeyTimerInternals.setContext(key, keyCoder, inputMark, outputMark);

    for (TimerData timer : timerDataSet) {
      StateNamespace timerNamespace = timer.getNamespace();
      checkArgument(timerNamespace instanceof WindowNamespace);
      BoundedWindow timerWindow = ((WindowNamespace<?>) timerNamespace).getWindow();
      pushbackDoFnRunner.onTimer(
          timer.getTimerId(), timerWindow, timer.getTimestamp(), timer.getDomain());
    }

    pushbackDoFnRunner.finishBundle();
  }

  /**
   * Advances the input watermark, fires the timers that became ready and emits a watermark
   * downstream.
   *
   * <p>Without side inputs the incoming watermark is forwarded unchanged. With side inputs the
   * emitted watermark is held back to the minimum timestamp of the pushed back elements and is
   * only emitted when it advances past the previously emitted one.
   *
   * @param mark the watermark to process
   * @throws IllegalStateException if firing event time timers leaves a timer behind the watermark
   */
  private void processWatermark(ApexStreamTuple.WatermarkTuple<?> mark) {
    this.currentInputWatermark = mark.getTimestamp();

    long minEventTimeTimer = currentKeyTimerInternals.fireReadyTimers(
        this.currentInputWatermark, this, TimeDomain.EVENT_TIME);
    checkState(minEventTimeTimer >= currentInputWatermark,
        "Event time timer processing generates new timer(s) behind watermark.");

    // TODO: is this the right way to trigger processing time timers?
    // drain all timers below current watermark, including those that result from firing
    long minProcessingTimeTimer = Long.MIN_VALUE;
    while (minProcessingTimeTimer < currentInputWatermark) {
      minProcessingTimeTimer = currentKeyTimerInternals.fireReadyTimers(
          this.currentInputWatermark, this, TimeDomain.PROCESSING_TIME);
      if (minProcessingTimeTimer < currentInputWatermark) {
        LOG.info("Processing time timer {} registered behind watermark {}", minProcessingTimeTimer,
            currentInputWatermark);
      }
    }

    if (sideInputs.isEmpty()) {
      // nothing can be pushed back, so the watermark passes straight through
      outputWatermark(mark);
    } else {
      long heldBackWatermark = Math.min(pushedBackWatermark.get(), currentInputWatermark);
      if (heldBackWatermark > currentOutputWatermark) {
        currentOutputWatermark = heldBackWatermark;
        outputWatermark(ApexStreamTuple.WatermarkTuple.of(currentOutputWatermark));
      }
    }
  }

  /**
   * Emits the watermark on the main output port and on every mapped additional output port.
   *
   * @param mark the watermark tuple to emit
   */
  private void outputWatermark(ApexStreamTuple.WatermarkTuple<?> mark) {
    if (traceTuples) {
      LOG.debug("\nemitting {}\n", mark);
    }
    output.emit(mark);
    // iterating an empty mapping is a no-op, so no emptiness check is needed
    for (DefaultOutputPort<ApexStreamTuple<?>> port : additionalOutputPortMapping.values()) {
      port.emit(mark);
    }
  }

  /**
   * Executor handed to the splittable DoFn process element invoker. Only created in
   * {@link #setup} when the DoFn is a {@link ProcessFn}; shut down again in {@link #teardown}.
   */
  private transient ScheduledExecutorService splittableDoFnExecutor;

  /**
   * Builds the transient runtime objects of this operator: the side input handler (when side
   * inputs are present), the mapping from additional output tags to ports, the DoFn runner
   * (wrapped in a stateful runner when keyed state is in use) and, for a {@link ProcessFn}, the
   * factories and invoker required for splittable processing. Also invokes the DoFn's setup
   * method.
   *
   * @param context the Apex operator context (not used here)
   */
  @Override
  public void setup(OperatorContext context) {
    this.traceTuples =
        ApexStreamTuple.Logging.isDebugEnabled(
            pipelineOptions.get().as(ApexPipelineOptions.class), this);
    SideInputReader sideInputReader = NullSideInputReader.of(sideInputs);
    if (!sideInputs.isEmpty()) {
      sideInputHandler = new SideInputHandler(sideInputs, sideInputStateInternals);
      sideInputReader = sideInputHandler;
    }

    // the i-th additional output tag is served by the i-th additional output port
    for (int i = 0; i < additionalOutputTags.size(); i++) {
      @SuppressWarnings("unchecked")
      DefaultOutputPort<ApexStreamTuple<?>> port = (DefaultOutputPort<ApexStreamTuple<?>>)
          additionalOutputPorts[i];
      additionalOutputPortMapping.put(additionalOutputTags.get(i), port);
    }

    // step context exposing this operator's keyed state and timer internals to the runner
    NoOpStepContext stepContext = new NoOpStepContext() {

      @Override
      public StateInternals stateInternals() {
        return currentKeyStateInternals;
      }

      @Override
      public TimerInternals timerInternals() {
        return currentKeyTimerInternals;
      }

    };
    DoFnRunner<InputT, OutputT> doFnRunner = DoFnRunners.simpleRunner(
        pipelineOptions.get(),
        doFn,
        sideInputReader,
        this,
        mainOutputTag,
        additionalOutputTags,
        stepContext,
        windowingStrategy
        );

    doFnInvoker = DoFnInvokers.invokerFor(doFn);
    doFnInvoker.invokeSetup();

    if (this.currentKeyStateInternals != null) {

      StatefulDoFnRunner.CleanupTimer cleanupTimer =
          new StatefulDoFnRunner.TimeInternalsCleanupTimer(
              stepContext.timerInternals(), windowingStrategy);

      @SuppressWarnings({"rawtypes"})
      Coder windowCoder = windowingStrategy.getWindowFn().windowCoder();

      @SuppressWarnings({"unchecked"})
      StatefulDoFnRunner.StateCleaner<?> stateCleaner =
          new StatefulDoFnRunner.StateInternalsStateCleaner<>(
              doFn, stepContext.stateInternals(), windowCoder);

      doFnRunner = DoFnRunners.defaultStatefulDoFnRunner(
          doFn,
          doFnRunner,
          windowingStrategy,
          cleanupTimer,
          stateCleaner);
    }

    // sideInputHandler is null here when there are no side inputs
    pushbackDoFnRunner =
        SimplePushbackSideInputDoFnRunner.create(doFnRunner, sideInputs, sideInputHandler);

    if (doFn instanceof ProcessFn) {

      @SuppressWarnings("unchecked")
      StateInternalsFactory<String> stateInternalsFactory =
          (StateInternalsFactory<String>) this.currentKeyStateInternals.getFactory();

      @SuppressWarnings({ "rawtypes", "unchecked" })
      ProcessFn<InputT, OutputT, Object, RestrictionTracker<Object>>
        splittableDoFn = (ProcessFn) doFn;
      splittableDoFn.setStateInternalsFactory(stateInternalsFactory);
      // every key shares the operator's timer internals; the key context is set per element
      TimerInternalsFactory<String> timerInternalsFactory = new TimerInternalsFactory<String>() {
         @Override
         public TimerInternals timerInternalsForKey(String key) {
           return currentKeyTimerInternals;
          }
        };
      splittableDoFn.setTimerInternalsFactory(timerInternalsFactory);
      // keep a reference to the executor so that teardown() can stop its thread
      splittableDoFnExecutor =
          Executors.newSingleThreadScheduledExecutor(Executors.defaultThreadFactory());
      splittableDoFn.setProcessElementInvoker(
          new OutputAndTimeBoundedSplittableProcessElementInvoker<>(
              doFn,
              pipelineOptions.get(),
              new OutputWindowedValue<OutputT>() {
                @Override
                public void outputWindowedValue(
                    OutputT output,
                    Instant timestamp,
                    Collection<? extends BoundedWindow> windows,
                    PaneInfo pane) {
                  output(
                      mainOutputTag,
                      WindowedValue.of(output, timestamp, windows, pane));
                }

                @Override
                public <AdditionalOutputT> void outputWindowedValue(TupleTag<AdditionalOutputT> tag,
                    AdditionalOutputT output, Instant timestamp,
                    Collection<? extends BoundedWindow> windows, PaneInfo pane) {
                  output(tag, WindowedValue.of(output, timestamp, windows, pane));
                }
              },
              sideInputReader,
              splittableDoFnExecutor,
              // NOTE(review): presumably the output count and time bounds per process call;
              // confirm against the invoker's constructor documentation
              10000,
              Duration.standardSeconds(10)));
    }

  }

  /**
   * Invokes the DoFn's teardown method and releases the resources created in {@link #setup}.
   *
   * <p>The invoker is checked for {@code null} because it is only assigned part way through
   * {@link #setup}. The executor shutdown and the superclass teardown run even when the user
   * teardown throws.
   */
  @Override
  public void teardown() {
    try {
      if (doFnInvoker != null) {
        doFnInvoker.invokeTeardown();
      }
    } finally {
      if (splittableDoFnExecutor != null) {
        // the executor's thread would otherwise outlive this operator
        splittableDoFnExecutor.shutdownNow();
        splittableDoFnExecutor = null;
      }
      super.teardown();
    }
  }

  /**
   * Intentionally empty: this operator performs no work at the start of a streaming window.
   */
  @Override
  public void beginWindow(long windowId) {
  }

  /**
   * Fires the processing time timers that are ready as of the timer internals' current
   * processing time.
   */
  @Override
  public void endWindow() {
    long processingTimeMillis = currentKeyTimerInternals.currentProcessingTime().getMillis();
    currentKeyTimerInternals.fireReadyTimers(
        processingTimeMillis, this, TimeDomain.PROCESSING_TIME);
  }

  /**
   * Running minimum over {@code long} values; holds {@link Long#MAX_VALUE} while empty.
   */
  private static class LongMin {
    long state = Long.MAX_VALUE;

    /** Lowers the minimum to {@code l} if {@code l} is smaller than the current value. */
    public void add(long l) {
      if (l < state) {
        state = l;
      }
    }

    /** Returns the smallest value added since the last reset, or {@link Long#MAX_VALUE}. */
    public long get() {
      return state;
    }

    /** Resets the minimum to the empty value. */
    public void clear() {
      state = Long.MAX_VALUE;
    }

  }

}
