/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements.  See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership.  The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
package org.apache.beam.runners.core.construction;

import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;

import java.util.Map;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.runners.PTransformOverrideFactory.ReplacementOutput;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollection.IsBounded;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.PValue;
import org.apache.beam.sdk.values.TaggedPValue;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.WindowingStrategy;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.ImmutableMap;
import org.apache.beam.vendor.guava.v20_0.com.google.common.collect.Iterables;
import org.hamcrest.Matchers;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;

/**
 * Tests for {@link ReplacementOutputs}.
 *
 * <p>Each test builds "original" outputs (standing in for the outputs of a transform being
 * replaced) and "replacement" outputs, then checks how {@link ReplacementOutputs#singleton} and
 * {@link ReplacementOutputs#tagged} pair them up, or that they reject mismatched inputs.
 */
@RunWith(JUnit4.class)
public class ReplacementOutputsTest {
  @Rule public final ExpectedException thrown = ExpectedException.none();

  // Owner of the PCollections below; the tests never run this pipeline.
  private final TestPipeline p = TestPipeline.create();

  // Tags used to assemble the PCollectionTuples in the tagged* tests.
  private final TupleTag<Integer> intsTag = new TupleTag<>();
  private final TupleTag<Integer> moreIntsTag = new TupleTag<>();
  private final TupleTag<String> strsTag = new TupleTag<>();

  // Outputs playing the role of the original transform's outputs.
  private final PCollection<Integer> ints =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
  private final PCollection<Integer> moreInts =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
  private final PCollection<String> strs =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());

  // Outputs playing the role of the replacement transform's outputs.
  private final PCollection<Integer> replacementInts =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
  private final PCollection<Integer> moreReplacementInts =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, VarIntCoder.of());
  private final PCollection<String> replacementStrs =
      PCollection.createPrimitiveOutputInternal(
          p, WindowingStrategy.globalDefault(), IsBounded.BOUNDED, StringUtf8Coder.of());

  /** A single original output maps to a single entry keyed by the replacement output. */
  @Test
  public void singletonSucceeds() {
    Map<PValue, ReplacementOutput> replacements =
        ReplacementOutputs.singleton(ints.expand(), replacementInts);

    // hasKey alone would not notice extra entries, so pin the size as well.
    assertThat(replacements.size(), equalTo(1));
    assertThat(replacements, Matchers.hasKey(replacementInts));

    ReplacementOutput replacement = replacements.get(replacementInts);
    Map.Entry<TupleTag<?>, PValue> taggedInts = Iterables.getOnlyElement(ints.expand().entrySet());
    assertThat(replacement.getOriginal().getTag(), equalTo(taggedInts.getKey()));
    assertThat(replacement.getOriginal().getValue(), equalTo(taggedInts.getValue()));
    assertThat(replacement.getReplacement().getValue(), equalTo(replacementInts));
  }

  /** {@code singleton} rejects an original output map with more than one entry. */
  @Test
  public void singletonMultipleOriginalsThrows() {
    thrown.expect(IllegalArgumentException.class);
    ReplacementOutputs.singleton(
        ImmutableMap.<TupleTag<?>, PValue>builder()
            .putAll(ints.expand())
            .putAll(moreInts.expand())
            .build(),
        replacementInts);
  }

  /** Originals and replacements sharing a tag are paired, one entry per replacement output. */
  @Test
  public void taggedSucceeds() {
    PCollectionTuple original =
        PCollectionTuple.of(intsTag, ints).and(strsTag, strs).and(moreIntsTag, moreInts);

    // The replacement tuple lists the tags in a different order than the original, so a passing
    // test shows the pairing is done by tag rather than by position.
    Map<PValue, ReplacementOutput> replacements =
        ReplacementOutputs.tagged(
            original.expand(),
            PCollectionTuple.of(strsTag, replacementStrs)
                .and(moreIntsTag, moreReplacementInts)
                .and(intsTag, replacementInts));
    assertThat(
        replacements.keySet(),
        Matchers.containsInAnyOrder(replacementStrs, replacementInts, moreReplacementInts));
    ReplacementOutput intsReplacement = replacements.get(replacementInts);
    ReplacementOutput strsReplacement = replacements.get(replacementStrs);
    ReplacementOutput moreIntsReplacement = replacements.get(moreReplacementInts);

    assertThat(
        intsReplacement,
        equalTo(
            ReplacementOutput.of(
                TaggedPValue.of(intsTag, ints), TaggedPValue.of(intsTag, replacementInts))));
    assertThat(
        strsReplacement,
        equalTo(
            ReplacementOutput.of(
                TaggedPValue.of(strsTag, strs), TaggedPValue.of(strsTag, replacementStrs))));
    assertThat(
        moreIntsReplacement,
        equalTo(
            ReplacementOutput.of(
                TaggedPValue.of(moreIntsTag, moreInts),
                TaggedPValue.of(moreIntsTag, moreReplacementInts))));
  }

  /** An original output (here {@code intsTag}) with no replacement under its tag is rejected. */
  @Test
  public void taggedMissingReplacementThrows() {
    PCollectionTuple original =
        PCollectionTuple.of(intsTag, ints).and(strsTag, strs).and(moreIntsTag, moreInts);

    thrown.expect(IllegalArgumentException.class);
    thrown.expectMessage("Missing replacement");
    thrown.expectMessage(intsTag.toString());
    thrown.expectMessage(ints.toString());
    ReplacementOutputs.tagged(
        original.expand(),
        PCollectionTuple.of(strsTag, replacementStrs).and(moreIntsTag, moreReplacementInts));
  }

  /** A replacement output (here {@code moreIntsTag}) with no original under its tag is rejected. */
  @Test
  public void taggedExtraReplacementThrows() {
    PCollectionTuple original = PCollectionTuple.of(intsTag, ints).and(strsTag, strs);

    thrown.expect(IllegalArgumentException.class);
    thrown.expectMessage("Missing original output");
    thrown.expectMessage(moreIntsTag.toString());
    thrown.expectMessage(moreReplacementInts.toString());
    ReplacementOutputs.tagged(
        original.expand(),
        PCollectionTuple.of(strsTag, replacementStrs)
            .and(moreIntsTag, moreReplacementInts)
            .and(intsTag, replacementInts));
  }
}
