/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.dataflow.worker.graph;

import static org.hamcrest.Matchers.containsInAnyOrder;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertThat;

import com.google.api.services.dataflow.model.InstructionOutput;
import com.google.api.services.dataflow.model.ParDoInstruction;
import com.google.api.services.dataflow.model.ParallelInstruction;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import javax.annotation.Nullable;
import org.apache.beam.model.pipeline.v1.RunnerApi;
import org.apache.beam.model.pipeline.v1.RunnerApi.ParDoPayload;
import org.apache.beam.model.pipeline.v1.RunnerApi.SdkFunctionSpec;
import org.apache.beam.runners.core.construction.PTransformTranslation;
import org.apache.beam.runners.core.construction.ParDoTranslation;
import org.apache.beam.runners.core.construction.PipelineTranslation;
import org.apache.beam.runners.core.construction.SdkComponents;
import org.apache.beam.runners.dataflow.util.CloudObject;
import org.apache.beam.runners.dataflow.util.PropertyNames;
import org.apache.beam.runners.dataflow.worker.NameContextsForTests;
import org.apache.beam.runners.dataflow.worker.graph.Edges.DefaultEdge;
import org.apache.beam.runners.dataflow.worker.graph.Edges.Edge;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ExecutionLocation;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.FetchAndFilterStreamingSideInputsNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.InstructionOutputNode;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.Node;
import org.apache.beam.runners.dataflow.worker.graph.Nodes.ParallelInstructionNode;
import org.apache.beam.sdk.Pipeline;
import org.apache.beam.sdk.options.PipelineOptionsFactory;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.vendor.grpc.v1p21p0.com.google.protobuf.InvalidProtocolBufferException;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Equivalence.Wrapper;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.Iterables;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.ImmutableNetwork;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.MutableNetwork;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.Network;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.graph.NetworkBuilder;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/** Tests for {@link InsertFetchAndFilterStreamingSideInputNodes}. */
@RunWith(JUnit4.class)
public class InsertFetchAndFilterStreamingSideInputNodesTest {

  /** With a {@code null} pipeline the network is expected to come back structurally unchanged. */
  @Test
  public void testWithoutPipeline() throws Exception {
    Node unknown = createParDoNode("parDoId");

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(unknown);

    Network<Node, Edge> inputNetwork = ImmutableNetwork.copyOf(network);
    network = InsertFetchAndFilterStreamingSideInputNodes.with(null).forNetwork(network);

    assertThatNetworksAreIdentical(inputNetwork, network);
  }

  /**
   * A ParDo whose transform declares a side input is expected to get a {@link
   * FetchAndFilterStreamingSideInputsNode} (fed by a clone of its main-input {@link
   * InstructionOutputNode}) inserted between its main input and that input's predecessor.
   */
  @Test
  public void testSdkParDoWithSideInput() throws Exception {
    Pipeline p = Pipeline.create();
    PCollection<String> pc = p.apply(Create.of("a", "b", "c"));
    PCollectionView<Iterable<String>> pcView = pc.apply(View.asIterable());
    pc.apply(ParDo.of(new TestDoFn(pcView)).withSideInputs(pcView));
    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);

    Node predecessor = createParDoNode("predecessor");
    InstructionOutputNode mainInput =
        InstructionOutputNode.create(new InstructionOutput(), "fakeId");
    Node sideInputParDo = createParDoNode(findParDoWithSideInput(pipeline));

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(predecessor);
    network.addNode(mainInput);
    network.addNode(sideInputParDo);
    network.addEdge(predecessor, mainInput, DefaultEdge.create());
    network.addEdge(mainInput, sideInputParDo, DefaultEdge.create());

    network = InsertFetchAndFilterStreamingSideInputNodes.with(pipeline).forNetwork(network);

    // The clone carries the same InstructionOutput, so NodeEquivalence matches it against the
    // node the rewrite inserts even though the two are distinct objects.
    Node mainInputClone = InstructionOutputNode.create(mainInput.getInstructionOutput(), "fakeId");
    Node fetchAndFilter =
        FetchAndFilterStreamingSideInputsNode.create(
            pcView.getWindowingStrategyInternal(),
            ImmutableMap.of(
                pcView,
                ParDoTranslation.translateWindowMappingFn(
                    pcView.getWindowMappingFn(),
                    SdkComponents.create(PipelineOptionsFactory.create()))),
            NameContextsForTests.nameContextForTest());

    MutableNetwork<Node, Edge> expectedNetwork = createEmptyNetwork();
    expectedNetwork.addNode(predecessor);
    expectedNetwork.addNode(mainInputClone);
    expectedNetwork.addNode(fetchAndFilter);
    expectedNetwork.addNode(mainInput);
    expectedNetwork.addNode(sideInputParDo);
    expectedNetwork.addEdge(predecessor, mainInputClone, DefaultEdge.create());
    expectedNetwork.addEdge(mainInputClone, fetchAndFilter, DefaultEdge.create());
    expectedNetwork.addEdge(fetchAndFilter, mainInput, DefaultEdge.create());
    expectedNetwork.addEdge(mainInput, sideInputParDo, DefaultEdge.create());

    assertThatNetworksAreIdentical(expectedNetwork, network);
  }

  /** A ParDo node with no matching side-input transform is expected to be left untouched. */
  @Test
  public void testSdkParDoWithoutSideInput() throws Exception {
    Pipeline p = Pipeline.create();
    PCollection<String> pc = p.apply(Create.of("a", "b", "c"));
    pc.apply(ParDo.of(new TestDoFn(null)));
    RunnerApi.Pipeline pipeline = PipelineTranslation.toProto(p);

    Node predecessor = createParDoNode("predecessor");
    Node mainInput = InstructionOutputNode.create(new InstructionOutput(), "fakeId");
    Node noSideInputParDo = createParDoNode("noSideInput");

    MutableNetwork<Node, Edge> network = createEmptyNetwork();
    network.addNode(predecessor);
    network.addNode(mainInput);
    network.addNode(noSideInputParDo);
    network.addEdge(predecessor, mainInput, DefaultEdge.create());
    network.addEdge(mainInput, noSideInputParDo, DefaultEdge.create());

    Network<Node, Edge> inputNetwork = ImmutableNetwork.copyOf(network);
    network = InsertFetchAndFilterStreamingSideInputNodes.with(pipeline).forNetwork(network);

    assertThatNetworksAreIdentical(inputNetwork, network);
  }

  /**
   * Returns the id of a ParDo transform in {@code pipeline} whose {@link ParDoPayload} declares at
   * least one side input (the first one encountered while iterating the transforms map).
   *
   * @throws IllegalStateException if a ParDo payload cannot be parsed, or if no such transform
   *     exists
   */
  private static String findParDoWithSideInput(RunnerApi.Pipeline pipeline) {
    for (Map.Entry<String, RunnerApi.PTransform> entry :
        pipeline.getComponents().getTransformsMap().entrySet()) {
      if (!PTransformTranslation.PAR_DO_TRANSFORM_URN.equals(entry.getValue().getSpec().getUrn())) {
        continue;
      }
      try {
        ParDoPayload payload = ParDoPayload.parseFrom(entry.getValue().getSpec().getPayload());
        if (!payload.getSideInputsMap().isEmpty()) {
          return entry.getKey();
        }
      } catch (InvalidProtocolBufferException e) {
        // Keep the parse failure as the cause so the underlying protobuf error is not lost.
        throw new IllegalStateException(
            String.format("Failed to parse PTransform %s", entry), e);
      }
    }
    throw new IllegalStateException("No side input ptransform found");
  }

  /** No-op DoFn; it only holds the (possibly {@code null}) side input view it was given. */
  private static class TestDoFn extends DoFn<String, Iterable<String>> {
    @Nullable private final PCollectionView<Iterable<String>> pCollectionView;

    private TestDoFn(@Nullable PCollectionView<Iterable<String>> pCollectionView) {
      this.pCollectionView = pCollectionView;
    }

    @ProcessElement
    public void processElement(ProcessContext context) {}
  }

  /**
   * Compares nodes structurally where the rewrite creates fresh instances, and by {@link
   * Node#equals} everywhere else.
   *
   * <ul>
   *   <li>{@link FetchAndFilterStreamingSideInputsNode}s: by the tag of their single side input
   *       view and its window mapping fn spec.
   *   <li>{@link InstructionOutputNode}s: by their {@link InstructionOutput}.
   *   <li>Any other pair: {@code a.equals(b)}.
   * </ul>
   *
   * <p>{@link #doHash} hashes exactly the fields {@link #doEquivalent} compares, as the {@link
   * Equivalence} contract requires.
   */
  private static final class NodeEquivalence extends Equivalence<Node> {
    static final NodeEquivalence INSTANCE = new NodeEquivalence();

    @Override
    protected boolean doEquivalent(Node a, Node b) {
      if (a instanceof FetchAndFilterStreamingSideInputsNode
          && b instanceof FetchAndFilterStreamingSideInputsNode) {
        Map.Entry<PCollectionView<?>, SdkFunctionSpec> nodeAEntry =
            onlyViewEntry((FetchAndFilterStreamingSideInputsNode) a);
        Map.Entry<PCollectionView<?>, SdkFunctionSpec> nodeBEntry =
            onlyViewEntry((FetchAndFilterStreamingSideInputsNode) b);
        return Objects.equals(
                nodeAEntry.getKey().getTagInternal(), nodeBEntry.getKey().getTagInternal())
            && Objects.equals(nodeAEntry.getValue(), nodeBEntry.getValue());
      } else if (a instanceof InstructionOutputNode && b instanceof InstructionOutputNode) {
        return Objects.equals(
            ((InstructionOutputNode) a).getInstructionOutput(),
            ((InstructionOutputNode) b).getInstructionOutput());
      } else {
        return a.equals(b); // Make sure that other nodes haven't been modified
      }
    }

    @Override
    protected int doHash(Node n) {
      // Equivalent nodes must hash equally, so hash only what doEquivalent compares.
      if (n instanceof FetchAndFilterStreamingSideInputsNode) {
        Map.Entry<PCollectionView<?>, SdkFunctionSpec> entry =
            onlyViewEntry((FetchAndFilterStreamingSideInputsNode) n);
        return Objects.hash(entry.getKey().getTagInternal(), entry.getValue());
      } else if (n instanceof InstructionOutputNode) {
        return Objects.hashCode(((InstructionOutputNode) n).getInstructionOutput());
      }
      return n.hashCode();
    }

    /** Returns the single view-to-window-mapping-fn entry; fails if there is not exactly one. */
    private static Map.Entry<PCollectionView<?>, SdkFunctionSpec> onlyViewEntry(
        FetchAndFilterStreamingSideInputsNode node) {
      return Iterables.getOnlyElement(node.getPCollectionViewsToWindowMappingFns().entrySet());
    }
  }

  /**
   * Asserts that both networks have the same number of nodes and edges, and the same multiset of
   * root-to-leaf paths, where the nodes along each path are compared with {@link NodeEquivalence}.
   */
  private void assertThatNetworksAreIdentical(
      Network<Node, Edge> oldNetwork, Network<Node, Edge> newNetwork) {
    // Assert that both networks still have same number of nodes and edges.
    assertEquals(oldNetwork.nodes().size(), newNetwork.nodes().size());
    assertEquals(oldNetwork.edges().size(), newNetwork.edges().size());

    // Assert that all paths still exist with equivalent nodes in each path.
    List<List<Wrapper<Node>>> oldPaths = allPathsWithWrappedNodes(oldNetwork);
    List<List<Wrapper<Node>>> newPaths = allPathsWithWrappedNodes(newNetwork);
    assertThat(oldPaths, containsInAnyOrder(newPaths.toArray()));
  }

  /**
   * Returns every root-to-leaf path of {@code network}, with each node wrapped by {@link
   * NodeEquivalence} so that paths can be compared with {@code equals}.
   */
  private List<List<Wrapper<Node>>> allPathsWithWrappedNodes(Network<Node, Edge> network) {
    List<List<Node>> paths = Networks.allPathsFromRootsToLeaves(network);
    List<List<Wrapper<Node>>> wrappedPaths = new ArrayList<>(paths.size());
    for (List<Node> path : paths) {
      List<Wrapper<Node>> wrappedPath = new ArrayList<>(path.size());
      for (Node node : path) {
        wrappedPath.add(NodeEquivalence.INSTANCE.wrap(node));
      }
      wrappedPaths.add(wrappedPath);
    }

    return wrappedPaths;
  }

  /** Creates an empty directed network that allows parallel edges but no self loops. */
  private static MutableNetwork<Node, Edge> createEmptyNetwork() {
    return NetworkBuilder.directed()
        .allowsSelfLoops(false)
        .allowsParallelEdges(true)
        .<Node, Edge>build();
  }

  /**
   * Creates an SDK-harness ParDo instruction node whose user fn stores {@code parDoId} under
   * {@link PropertyNames#SERIALIZED_FN}.
   */
  private static ParallelInstructionNode createParDoNode(String parDoId) {
    CloudObject userFn = CloudObject.forClassName("DoFn");
    userFn.put(PropertyNames.SERIALIZED_FN, parDoId);
    return ParallelInstructionNode.create(
        new ParallelInstruction().setParDo(new ParDoInstruction().setUserFn(userFn)),
        ExecutionLocation.SDK_HARNESS);
  }
}
