/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker;

import static org.apache.beam.runners.dataflow.util.Structs.getBytes;
import static org.apache.beam.runners.dataflow.util.Structs.getString;

import com.google.api.services.dataflow.model.SideInputInfo;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import javax.annotation.Nullable;
import org.apache.beam.runners.core.GlobalCombineFnRunner;
import org.apache.beam.runners.core.GlobalCombineFnRunners;
import org.apache.beam.runners.core.SideInputReader;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.util.WorkerPropertyNames;
import org.apache.beam.runners.dataflow.worker.util.common.worker.ParDoFn;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.IterableCoder;
import org.apache.beam.sdk.coders.KvCoder;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.transforms.Combine.CombineFn;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.DoFnSchemaInformation;
import org.apache.beam.sdk.transforms.windowing.BoundedWindow;
import org.apache.beam.sdk.util.AppliedCombineFn;
import org.apache.beam.sdk.util.DoFnInfo;
import org.apache.beam.sdk.util.SerializableUtils;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions;

/**
 * A {@link ParDoFnFactory} to create instances of user {@link CombineFn} according to
 * specifications from the Dataflow service.
 */
class CombineValuesFnFactory implements ParDoFnFactory {

  /**
   * Creates the {@link ParDoFn} that executes one phase of a user {@link CombineFn}.
   *
   * <p>The serialized {@link AppliedCombineFn} is read from {@code cloudUserFn}, a {@link
   * SideInputReader} is obtained for the views the combine fn declares, and the combine phase is
   * read from {@code cloudUserFn} (defaulting to {@link CombinePhase#ALL} when absent). The
   * resulting phase-specific {@link DoFn} is wrapped in a {@link SimpleParDoFn}.
   *
   * @param options pipeline options handed to the created {@link SimpleParDoFn}
   * @param cloudUserFn specification holding the serialized combine fn and, optionally, the phase
   * @param sideInputInfos side input descriptions passed to the execution context; may be null
   * @param mainOutputTag tag of the main output
   * @param outputTupleTagsToReceiverIndices output tags to receiver indices; must have one entry
   * @param executionContext context supplying the side input reader and the step context
   * @param operationContext operation context for this step
   * @return a {@link ParDoFn} running the requested combine phase
   * @throws IllegalArgumentException if there is not exactly one output, if the deserialized fn is
   *     not an {@link AppliedCombineFn}, or if the phase is not recognized
   */
  @Override
  public ParDoFn create(
      PipelineOptions options,
      CloudObject cloudUserFn,
      @Nullable List<SideInputInfo> sideInputInfos,
      TupleTag<?> mainOutputTag,
      Map<TupleTag<?>, Integer> outputTupleTagsToReceiverIndices,
      DataflowExecutionContext<?> executionContext,
      DataflowOperationContext operationContext)
      throws Exception {

    Preconditions.checkArgument(
        outputTupleTagsToReceiverIndices.size() == 1,
        "expected exactly one output for CombineValuesFn");

    Object deserializedFn =
        SerializableUtils.deserializeFromByteArray(
            getBytes(cloudUserFn, PropertyNames.SERIALIZED_FN), "serialized user fn");
    // Name the offending type in the failure so a bad payload is diagnosable; the class name is
    // computed null-safely because instanceof is also false for null.
    Preconditions.checkArgument(
        deserializedFn instanceof AppliedCombineFn,
        "expected serialized user fn to be an AppliedCombineFn, but was: %s",
        deserializedFn == null ? null : deserializedFn.getClass().getName());
    AppliedCombineFn<?, ?, ?, ?> combineFn = (AppliedCombineFn<?, ?, ?, ?>) deserializedFn;
    Iterable<PCollectionView<?>> sideInputViews = combineFn.getSideInputViews();
    SideInputReader sideInputReader =
        executionContext.getSideInputReader(sideInputInfos, sideInputViews, operationContext);

    // Get the combine phase, default to ALL. (The implementation
    // doesn't have to split the combiner).
    String phase = getString(cloudUserFn, WorkerPropertyNames.PHASE, CombinePhase.ALL);

    DoFnInfo<?, ?> doFnInfo = getDoFnInfo(combineFn, sideInputReader, phase);
    return new SimpleParDoFn(
        options,
        DoFnInstanceManagers.singleInstance(doFnInfo),
        sideInputReader,
        mainOutputTag,
        outputTupleTagsToReceiverIndices,
        executionContext.getStepContext(operationContext),
        operationContext,
        doFnInfo.getDoFnSchemaInformation(),
        doFnInfo.getSideInputMapping(),
        SimpleDoFnRunnerFactory.INSTANCE);
  }

  /**
   * Builds the {@link DoFnInfo} for the {@link DoFn} that implements the requested combine phase.
   *
   * @param combineFn the applied combine fn to execute
   * @param sideInputReader reader handed to the phase-specific {@link DoFn}
   * @param phase one of the {@link CombinePhase} constants
   * @return the {@link DoFnInfo} produced by the matching phase's {@code createDoFnInfo}
   * @throws IllegalArgumentException if {@code phase} matches none of the known phases
   */
  private static <K, InputT, AccumT, OutputT> DoFnInfo<?, ?> getDoFnInfo(
      AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn,
      SideInputReader sideInputReader,
      String phase) {
    switch (phase) {
      case CombinePhase.ALL:
        return CombineValuesDoFn.createDoFnInfo(combineFn, sideInputReader);
      case CombinePhase.ADD:
        return AddInputsDoFn.createDoFnInfo(combineFn, sideInputReader);
      case CombinePhase.MERGE:
        return MergeAccumulatorsDoFn.createDoFnInfo(combineFn, sideInputReader);
      case CombinePhase.EXTRACT:
        return ExtractOutputDoFn.createDoFnInfo(combineFn, sideInputReader);
      default:
        // Report the rejected value; the original wording is kept as the message prefix.
        throw new IllegalArgumentException(
            "phase must be one of 'all', 'add', 'merge', 'extract', but was: '" + phase + "'");
    }
  }

  /**
   * The ALL phase is the unsplit combiner, in case combiner lifting is disabled or the optimizer
   * chose not to lift this combiner.
   */
  private static class CombineValuesDoFn<K, InputT, OutputT>
      extends DoFn<KV<K, Iterable<InputT>>, KV<K, OutputT>> {

    private static <K, InputT, AccumT, OutputT> DoFnInfo<?, ?> createDoFnInfo(
        AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn, SideInputReader sideInputReader) {
      GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFnRunner =
          GlobalCombineFnRunners.create(combineFn.getFn());
      DoFn<KV<K, Iterable<InputT>>, KV<K, OutputT>> doFn =
          new CombineValuesDoFn<>(combineFnRunner, sideInputReader);

      Coder<KV<K, Iterable<InputT>>> inputCoder = null;
      if (combineFn.getKvCoder() != null) {
        inputCoder =
            KvCoder.of(
                combineFn.getKvCoder().getKeyCoder(),
                IterableCoder.of(combineFn.getKvCoder().getValueCoder()));
      }
      return DoFnInfo.forFn(
          doFn,
          combineFn.getWindowingStrategy(),
          combineFn.getSideInputViews(),
          inputCoder,
          Collections.emptyMap(), // Not needed here.
          new TupleTag<>(PropertyNames.OUTPUT),
          DoFnSchemaInformation.create(),
          Collections.emptyMap());
    }

    private final GlobalCombineFnRunner<InputT, ?, OutputT> combineFnRunner;
    private final SideInputReader sideInputReader;

    private CombineValuesDoFn(
        GlobalCombineFnRunner<InputT, ?, OutputT> combineFnRunner,
        SideInputReader sideInputReader) {
      this.combineFnRunner = combineFnRunner;
      this.sideInputReader = sideInputReader;
    }

    @ProcessElement
    public void processElement(ProcessContext c, BoundedWindow window) {
      KV<K, Iterable<InputT>> kv = c.element();
      c.output(
          KV.of(
              kv.getKey(),
              applyCombineFn(combineFnRunner, kv.getValue(), window, c.getPipelineOptions())));
    }

    private <AccumT> OutputT applyCombineFn(
        GlobalCombineFnRunner<InputT, AccumT, OutputT> combineFnRunner,
        Iterable<InputT> inputs,
        BoundedWindow window,
        PipelineOptions options) {
      List<BoundedWindow> windows = Collections.singletonList(window);
      AccumT accum = combineFnRunner.createAccumulator(options, sideInputReader, windows);
      for (InputT input : inputs) {
        accum = combineFnRunner.addInput(accum, input, options, sideInputReader, windows);
      }
      return combineFnRunner.extractOutput(accum, options, sideInputReader, windows);
    }
  }

  /**
   * {@link DoFn} for the ADD phase: {@code KV<K, Iterable<InputT>> -> KV<K, AccumT>}. Every value
   * of an element is added to a fresh accumulator, which is emitted without extracting an output.
   */
  private static class AddInputsDoFn<K, InputT, AccumT>
      extends DoFn<KV<K, Iterable<InputT>>, KV<K, AccumT>> {

    private final SideInputReader sideInputReader;
    private final GlobalCombineFnRunner<InputT, AccumT, ?> combineFnRunner;

    private AddInputsDoFn(
        GlobalCombineFnRunner<InputT, AccumT, ?> runner, SideInputReader sideInputs) {
      combineFnRunner = runner;
      sideInputReader = sideInputs;
    }

    /**
     * Wraps a new {@link AddInputsDoFn} for {@code combineFn} in a {@link DoFnInfo}. The input
     * coder is derived from the combine fn's KV coder when it has one and is left null otherwise.
     */
    private static <K, InputT, AccumT, OutputT> DoFnInfo<?, ?> createDoFnInfo(
        AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn, SideInputReader sideInputReader) {
      GlobalCombineFnRunner<InputT, AccumT, OutputT> runner =
          GlobalCombineFnRunners.create(combineFn.getFn());
      DoFn<KV<K, Iterable<InputT>>, KV<K, AccumT>> fn =
          new AddInputsDoFn<>(runner, sideInputReader);

      // Key coder paired with an iterable of the value coder, since values arrive grouped.
      Coder<KV<K, Iterable<InputT>>> groupedInputCoder =
          combineFn.getKvCoder() == null
              ? null
              : KvCoder.of(
                  combineFn.getKvCoder().getKeyCoder(),
                  IterableCoder.of(combineFn.getKvCoder().getValueCoder()));

      return DoFnInfo.forFn(
          fn,
          combineFn.getWindowingStrategy(),
          combineFn.getSideInputViews(),
          groupedInputCoder,
          Collections.emptyMap(), // Not needed here.
          new TupleTag<>(PropertyNames.OUTPUT),
          DoFnSchemaInformation.create(),
          Collections.emptyMap());
    }

    /** Accumulates all values of the element and emits the accumulator under the same key. */
    @ProcessElement
    public void processElement(ProcessContext context, BoundedWindow window) {
      KV<K, Iterable<InputT>> element = context.element();
      K key = element.getKey();

      List<BoundedWindow> singleWindow = Collections.singletonList(window);
      AccumT accumulator =
          combineFnRunner.createAccumulator(
              context.getPipelineOptions(), sideInputReader, singleWindow);

      for (InputT value : element.getValue()) {
        accumulator =
            combineFnRunner.addInput(
                accumulator, value, context.getPipelineOptions(), sideInputReader, singleWindow);
      }

      context.output(KV.of(key, accumulator));
    }
  }

  /**
   * {@link DoFn} for the MERGE phase: {@code KV<K, Iterable<AccumT>> -> KV<K, AccumT>}. The
   * accumulators grouped under a key are merged into a single accumulator.
   */
  private static class MergeAccumulatorsDoFn<K, AccumT>
      extends DoFn<KV<K, Iterable<AccumT>>, KV<K, AccumT>> {

    private final SideInputReader sideInputReader;
    private final GlobalCombineFnRunner<?, AccumT, ?> combineFnRunner;

    private MergeAccumulatorsDoFn(
        GlobalCombineFnRunner<?, AccumT, ?> runner, SideInputReader sideInputs) {
      combineFnRunner = runner;
      sideInputReader = sideInputs;
    }

    /**
     * Wraps a new {@link MergeAccumulatorsDoFn} for {@code combineFn} in a {@link DoFnInfo}. The
     * input coder pairs the key coder with an iterable of the accumulator coder; it is left null
     * when the combine fn has no KV coder.
     */
    private static <K, InputT, AccumT, OutputT> DoFnInfo<?, ?> createDoFnInfo(
        AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn, SideInputReader sideInputReader) {
      GlobalCombineFnRunner<InputT, AccumT, OutputT> runner =
          GlobalCombineFnRunners.create(combineFn.getFn());
      DoFn<KV<K, Iterable<AccumT>>, KV<K, AccumT>> fn =
          new MergeAccumulatorsDoFn<>(runner, sideInputReader);

      KvCoder<K, Iterable<AccumT>> groupedAccumCoder =
          combineFn.getKvCoder() == null
              ? null
              : KvCoder.of(
                  combineFn.getKvCoder().getKeyCoder(),
                  IterableCoder.of(combineFn.getAccumulatorCoder()));

      return DoFnInfo.forFn(
          fn,
          combineFn.getWindowingStrategy(),
          combineFn.getSideInputViews(),
          groupedAccumCoder,
          Collections.emptyMap(), // Not needed here.
          new TupleTag<>(PropertyNames.OUTPUT),
          DoFnSchemaInformation.create(),
          Collections.emptyMap());
    }

    /** Merges the element's accumulators and emits the merged accumulator under the same key. */
    @ProcessElement
    public void processElement(ProcessContext context, BoundedWindow window) {
      KV<K, Iterable<AccumT>> element = context.element();
      K key = element.getKey();
      Iterable<AccumT> accumulators = element.getValue();
      AccumT merged =
          combineFnRunner.mergeAccumulators(
              accumulators,
              context.getPipelineOptions(),
              sideInputReader,
              Collections.singletonList(window));
      context.output(KV.of(key, merged));
    }
  }

  /**
   * {@link DoFn} for the EXTRACT phase: {@code KV<K, AccumT> -> KV<K, OutputT>}. The output is
   * extracted from the accumulator carried by each element.
   */
  private static class ExtractOutputDoFn<K, AccumT, OutputT>
      extends DoFn<KV<K, AccumT>, KV<K, OutputT>> {

    private final GlobalCombineFnRunner<?, AccumT, OutputT> combineFnRunner;
    private final SideInputReader sideInputReader;

    private ExtractOutputDoFn(
        GlobalCombineFnRunner<?, AccumT, OutputT> runner, SideInputReader sideInputs) {
      combineFnRunner = runner;
      sideInputReader = sideInputs;
    }

    /**
     * Wraps a new {@link ExtractOutputDoFn} for {@code combineFn} in a {@link DoFnInfo}. The input
     * coder pairs the key coder with the accumulator coder; it is left null when the combine fn
     * has no KV coder.
     */
    private static <K, InputT, AccumT, OutputT> DoFnInfo<?, ?> createDoFnInfo(
        AppliedCombineFn<K, InputT, AccumT, OutputT> combineFn, SideInputReader sideInputReader) {
      GlobalCombineFnRunner<InputT, AccumT, OutputT> runner =
          GlobalCombineFnRunners.create(combineFn.getFn());
      DoFn<KV<K, AccumT>, KV<K, OutputT>> fn = new ExtractOutputDoFn<>(runner, sideInputReader);

      KvCoder<K, AccumT> accumCoder =
          combineFn.getKvCoder() == null
              ? null
              : KvCoder.of(combineFn.getKvCoder().getKeyCoder(), combineFn.getAccumulatorCoder());

      return DoFnInfo.forFn(
          fn,
          combineFn.getWindowingStrategy(),
          combineFn.getSideInputViews(),
          accumCoder,
          Collections.emptyMap(), // Not needed here.
          new TupleTag<>(PropertyNames.OUTPUT),
          DoFnSchemaInformation.create(),
          Collections.emptyMap());
    }

    /** Extracts the output from the element's accumulator and emits it under the same key. */
    @ProcessElement
    public void processElement(ProcessContext context, BoundedWindow window) {
      KV<K, AccumT> element = context.element();
      K key = element.getKey();
      AccumT accumulator = element.getValue();
      OutputT extracted =
          combineFnRunner.extractOutput(
              accumulator,
              context.getPipelineOptions(),
              sideInputReader,
              Collections.singletonList(window));
      context.output(KV.of(key, extracted));
    }
  }
}
