/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.jet;

import static com.hazelcast.jet.impl.util.ExceptionUtil.rethrow;
import static java.util.stream.Collectors.toList;
import static org.apache.beam.vendor.guava.v20_0.com.google.common.base.Preconditions.checkState;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.util.Collection;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.function.Function;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.TransformInputs;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.runners.AppliedPTransform;
import org.apache.beam.sdk.runners.TransformHierarchy;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.reflect.DoFnSignatures;
import org.apache.beam.sdk.util.CoderUtils;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;

/**
 * Various common helper methods used by the Jet based runner: navigating a pipeline's transform
 * hierarchy (inputs, outputs, side inputs), deriving coders, inspecting {@code ParDo}s, and a few
 * generic utilities (round-robin partitioning, serialization-based deep cloning, windowed value
 * encoding).
 */
public class Utils {

  /**
   * Returns the id of the single {@link TupleTag} in the expansion of {@code value}.
   *
   * @throws RuntimeException (from {@code Iterables.getOnlyElement}) if the expansion does not
   *     contain exactly one entry
   */
  public static String getTupleTagId(PValue value) {
    Map<TupleTag<?>, PValue> expansion = value.expand();
    return Iterables.getOnlyElement(expansion.keySet()).getId();
  }

  /**
   * Returns the single main (non-additional) input of {@code node}, or {@code null} when the node
   * has no transform.
   *
   * @throws RuntimeException (from {@code Iterables.getOnlyElement}) if there is not exactly one
   *     main input
   */
  static PValue getMainInput(Pipeline pipeline, TransformHierarchy.Node node) {
    Collection<PValue> mainInputs = getMainInputs(pipeline, node);
    return mainInputs == null ? null : Iterables.getOnlyElement(mainInputs);
  }

  /**
   * Returns the main (non-additional) inputs of {@code node}, or {@code null} when the node has no
   * transform.
   */
  static Collection<PValue> getMainInputs(Pipeline pipeline, TransformHierarchy.Node node) {
    if (node.getTransform() == null) {
      return null;
    }
    return TransformInputs.nonAdditionalInputs(node.toAppliedPTransform(pipeline));
  }

  /** Returns all inputs recorded for {@code appliedTransform}, keyed by tuple tag. */
  static Map<TupleTag<?>, PValue> getInputs(AppliedPTransform<?, ?, ?> appliedTransform) {
    return appliedTransform.getInputs();
  }

  /**
   * Returns the additional inputs declared by the node's transform, or {@code null} when the node
   * has no transform.
   */
  static Map<TupleTag<?>, PValue> getAdditionalInputs(TransformHierarchy.Node node) {
    return node.getTransform() != null ? node.getTransform().getAdditionalInputs() : null;
  }

  /**
   * Returns the single main (non-additional) input of {@code appliedTransform} as a {@link
   * PCollection}, or {@code null} when there is no transform.
   *
   * @throws RuntimeException if there is not exactly one main input, or it is not a PCollection
   */
  static PCollection getInput(AppliedPTransform<?, ?, ?> appliedTransform) {
    if (appliedTransform.getTransform() == null) {
      return null;
    }
    return (PCollection)
        Iterables.getOnlyElement(TransformInputs.nonAdditionalInputs(appliedTransform));
  }

  /**
   * Returns the outputs of {@code appliedTransform} keyed by tuple tag, or {@code null} when there
   * is no transform.
   */
  static Map<TupleTag<?>, PValue> getOutputs(AppliedPTransform<?, ?, ?> appliedTransform) {
    if (appliedTransform.getTransform() == null) {
      return null;
    }
    return appliedTransform.getOutputs();
  }

  /**
   * Returns the single output entry of {@code appliedTransform}.
   *
   * <p>Note: when the transform is {@code null}, {@link #getOutputs} returns {@code null} and this
   * method fails with a {@link NullPointerException}.
   *
   * @throws RuntimeException (from {@code Iterables.getOnlyElement}) if there is not exactly one
   *     output
   */
  static Map.Entry<TupleTag<?>, PValue> getOutput(AppliedPTransform<?, ?, ?> appliedTransform) {
    return Iterables.getOnlyElement(getOutputs(appliedTransform).entrySet());
  }

  /**
   * Returns whether the single output {@link PCollection} of {@code appliedTransform} is {@link
   * PCollection.IsBounded#BOUNDED}.
   *
   * <p>The type parameter {@code T} is unused; it is kept only for source compatibility.
   */
  static <T> boolean isBounded(AppliedPTransform<?, ?, ?> appliedTransform) {
    return ((PCollection) getOutput(appliedTransform).getValue())
        .isBounded()
        .equals(PCollection.IsBounded.BOUNDED);
  }

  /**
   * Returns the coder to use for elements of {@code pCollection}.
   *
   * @return {@code null} for a {@code null} collection; the collection's own coder if it has no
   *     windowing strategy; otherwise a {@link WindowedValue.FullWindowedValueCoder} combining the
   *     collection's coder with its window coder
   */
  static Coder getCoder(PCollection pCollection) {
    if (pCollection == null) {
      return null;
    }

    if (pCollection.getWindowingStrategy() == null) {
      return pCollection.getCoder();
    } else {
      return WindowedValue.FullWindowedValueCoder.of(
          pCollection.getCoder(), pCollection.getWindowingStrategy().getWindowFn().windowCoder());
    }
  }

  /**
   * Builds a map from a key extracted from each entry to the coder returned by {@link #getCoder}
   * for that entry's value.
   *
   * <p>Every value is cast to {@link PCollection}. Per {@link Collectors#toMap}, duplicate keys
   * produced by {@code tupleTagExtractor} cause an {@link IllegalStateException}.
   */
  static <T> Map<T, Coder> getCoders(
      Map<TupleTag<?>, PValue> pCollections,
      Function<Map.Entry<TupleTag<?>, PValue>, T> tupleTagExtractor) {
    return pCollections.entrySet().stream()
        .collect(Collectors.toMap(tupleTagExtractor, e -> getCoder((PCollection) e.getValue())));
  }

  /**
   * Returns the plain value coders (not windowed) of all outputs of {@code appliedTransform} that
   * are {@link PCollection}s. Other outputs are skipped.
   */
  static Map<TupleTag<?>, Coder<?>> getOutputValueCoders(
      AppliedPTransform<?, ?, ?> appliedTransform) {
    return appliedTransform.getOutputs().entrySet().stream()
        .filter(e -> e.getValue() instanceof PCollection)
        .collect(Collectors.toMap(Map.Entry::getKey, e -> ((PCollection) e.getValue()).getCoder()));
  }

  /**
   * Returns the side inputs of a {@code ParDo} transform (single- or multi-output), or an empty
   * list for any other transform.
   */
  static List<PCollectionView<?>> getSideInputs(AppliedPTransform<?, ?, ?> appliedTransform) {
    PTransform<?, ?> transform = appliedTransform.getTransform();
    if (transform instanceof ParDo.MultiOutput) {
      return ((ParDo.MultiOutput<?, ?>) transform).getSideInputs();
    } else if (transform instanceof ParDo.SingleOutput) {
      return ((ParDo.SingleOutput<?, ?>) transform).getSideInputs();
    }
    return Collections.emptyList();
  }

  /**
   * Returns whether the {@code ParDo} behind {@code appliedTransform} uses state or timers, as
   * reported by {@link ParDoTranslation#usesStateOrTimers}.
   *
   * @throws RuntimeException wrapping any {@link IOException} raised during inspection
   */
  static boolean usesStateOrTimers(AppliedPTransform<?, ?, ?> appliedTransform) {
    try {
      return ParDoTranslation.usesStateOrTimers(appliedTransform);
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Extracts the {@link DoFn} of the {@code ParDo} behind {@code appliedTransform}.
   *
   * @throws IllegalStateException if the DoFn is splittable (such DoFns are expected to have been
   *     replaced by an override before translation)
   * @throws RuntimeException wrapping any {@link IOException} raised while extracting the DoFn
   */
  static DoFn<?, ?> getDoFn(AppliedPTransform<?, ?, ?> appliedTransform) {
    try {
      DoFn<?, ?> doFn = ParDoTranslation.getDoFn(appliedTransform);
      if (DoFnSignatures.signatureForDoFn(doFn).processElement().isSplittable()) {
        // TODO(review): originally marked "todo" — presumably direct translation of splittable
        // DoFns is not supported yet; confirm the intended follow-up.
        throw new IllegalStateException(
            "Not expected to directly translate splittable DoFn, should have been overridden: "
                + doFn);
      }
      return doFn;
    } catch (IOException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Returns the windowing strategy of the first output of {@code appliedTransform}. All outputs
   * are assumed to share the same windowing strategy.
   *
   * @throws IllegalStateException if there are no outputs, or the first output is not a {@link
   *     PCollection}
   */
  static WindowingStrategy<?, ?> getWindowingStrategy(AppliedPTransform<?, ?, ?> appliedTransform) {
    // assume that the windowing strategy is the same for all outputs

    Map<TupleTag<?>, PValue> outputs = getOutputs(appliedTransform);

    if (outputs == null || outputs.isEmpty()) {
      throw new IllegalStateException("No outputs defined.");
    }

    PValue taggedValue = outputs.values().iterator().next();
    checkState(
        taggedValue instanceof PCollection,
        "Within ParDo, got a non-PCollection output %s of type %s",
        taggedValue,
        taggedValue.getClass().getSimpleName());
    PCollection<?> coll = (PCollection<?>) taggedValue;
    return coll.getWindowingStrategy();
  }

  /**
   * Assigns the {@code list} to {@code count} sublists in a round-robin fashion. One call returns
   * the {@code index}-th sublist.
   *
   * <p>For example, for a 7-element list where {@code count == 3}, it would respectively return for
   * indices 0..2:
   *
   * <pre>
   *   0, 3, 6
   *   1, 4
   *   2, 5
   * </pre>
   *
   * @param list the list to partition; not modified
   * @param index which sublist to return, {@code 0 <= index < count}
   * @param count the number of sublists
   * @return a new list holding the elements at positions {@code index, index + count, ...}
   * @throws IllegalArgumentException if {@code index} is outside {@code [0, count)} (which also
   *     rejects {@code count <= 0})
   */
  public static <T> List<T> roundRobinSubList(List<T> list, int index, int count) {
    if (index < 0 || index >= count) {
      throw new IllegalArgumentException("index=" + index + ", count=" + count);
    }
    int size = list.size();
    // Number of positions index, index + count, ... that are < size. Computing it up front means
    // only the selected indices are visited, and index + k * count never exceeds size - 1, so the
    // arithmetic cannot overflow.
    int resultSize = index < size ? (size - index - 1) / count + 1 : 0;
    return IntStream.range(0, resultSize)
        .mapToObj(k -> list.get(index + k * count))
        .collect(toList());
  }

  /**
   * Returns a deep clone of an object by serializing and deserializing it (ser-de) with Java
   * serialization.
   *
   * @throws RuntimeException wrapping an {@link IOException} (e.g. a non-serializable object) or a
   *     {@link ClassNotFoundException} raised during the round trip
   */
  @SuppressWarnings("unchecked")
  public static <T> T serde(T object) {
    try {
      ByteArrayOutputStream baos = new ByteArrayOutputStream();
      // The object stream must be closed (and so flushed) before the bytes are read back.
      try (ObjectOutputStream oos = new ObjectOutputStream(baos)) {
        oos.writeObject(object);
      }
      try (ObjectInputStream ois =
          new ObjectInputStream(new ByteArrayInputStream(baos.toByteArray()))) {
        return (T) ois.readObject();
      }
    } catch (IOException | ClassNotFoundException e) {
      throw new RuntimeException(e);
    }
  }

  /**
   * Encodes {@code windowedValue} to bytes using {@code coder}. Any {@link IOException} is
   * rethrown via Hazelcast Jet's {@code rethrow}.
   */
  @SuppressWarnings("unchecked") // coder is raw; callers must pass a coder matching the value
  public static <T> byte[] encodeWindowedValue(WindowedValue<T> windowedValue, Coder coder) {
    try {
      return CoderUtils.encodeToByteArray(coder, windowedValue);
    } catch (IOException e) {
      throw rethrow(e);
    }
  }

  /**
   * Decodes a {@link WindowedValue} from {@code item} using {@code coder}. Any {@link
   * IOException} is rethrown via Hazelcast Jet's {@code rethrow}.
   */
  @SuppressWarnings("unchecked") // caller guarantees coder produces WindowedValue<T>
  public static <T> WindowedValue<T> decodeWindowedValue(byte[] item, Coder coder) {
    try {
      return (WindowedValue<T>) CoderUtils.decodeFromByteArray(coder, item);
    } catch (IOException e) {
      throw rethrow(e);
    }
  }

  /**
   * Derives a coder for windowed lists of elements from {@code elementCoder}. The value coder is
   * wrapped in a {@link ListCoder}, and the window coder is kept as is.
   */
  public static WindowedValue.FullWindowedValueCoder deriveIterableValueCoder(
      WindowedValue.FullWindowedValueCoder elementCoder) {
    return WindowedValue.FullWindowedValueCoder.of(
        ListCoder.of(elementCoder.getValueCoder()), elementCoder.getWindowCoder());
  }
}
