// blob: 397b44438a48555ed57a8239b89511c1856a56a4 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.core.construction;
import com.google.auto.service.AutoService;
import java.io.IOException;
import java.io.Serializable;
import java.nio.charset.StandardCharsets;
import java.util.Map;
import org.apache.beam.runners.core.construction.expansion.ExpansionService;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testing.UsesCrossLanguageTransforms;
import org.apache.beam.sdk.testing.ValidatesRunner;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.Filter;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionTuple;
import org.apache.beam.sdk.values.TupleTag;
import org.apache.beam.sdk.values.TupleTagList;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ConnectivityState;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ManagedChannel;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ManagedChannelBuilder;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.Server;
import org.apache.beam.vendor.grpc.v1p21p0.io.grpc.ServerBuilder;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableMap;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.experimental.categories.Category;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
/**
 * Test External transforms.
 *
 * <p>{@link #setUp()} starts an in-process {@link ExpansionService} on {@code expansionPort + 100}
 * that serves the transforms registered by {@link TestTransforms}. {@link #expandPythonTest()}
 * instead talks to an expansion service expected to be listening on {@code expansionPort} itself
 * (presumably a Python one, judging by the test name; it is started outside this class).
 */
@RunWith(JUnit4.class)
public class ExternalTest implements Serializable {
  @Rule public transient TestPipeline testPipeline = TestPipeline.create();

  // URNs under which TestTransforms registers its transform providers.
  private static final String TEST_URN_SIMPLE = "simple";
  private static final String TEST_URN_LE = "le";
  private static final String TEST_URN_MULTI = "multi";

  /** Maximum number of times the external service's channel state is re-polled. */
  private static final int CHANNEL_READY_MAX_RETRIES = 30;

  /** Pause between two consecutive channel state polls, in milliseconds. */
  private static final long CHANNEL_READY_POLL_MILLIS = 500;

  private static Integer expansionPort;
  private static String localExpansionAddr;
  private static Server localExpansionServer;

  /**
   * Reads the {@code expansionPort} system property and starts the local expansion server on
   * {@code expansionPort + 100}.
   *
   * @throws IOException if the local expansion server fails to start
   * @throws IllegalStateException if the {@code expansionPort} system property is not set
   * @throws NumberFormatException if the {@code expansionPort} system property is not an integer
   */
  @BeforeClass
  public static void setUp() throws IOException {
    String portProperty = System.getProperty("expansionPort");
    if (portProperty == null) {
      // Fail with an actionable message instead of the bare NumberFormatException that
      // Integer.valueOf(null) would raise.
      throw new IllegalStateException(
          "System property 'expansionPort' must be set to the external expansion service port");
    }
    expansionPort = Integer.valueOf(portProperty);
    int localExpansionPort = expansionPort + 100;
    localExpansionAddr = String.format("localhost:%s", localExpansionPort);
    localExpansionServer =
        ServerBuilder.forPort(localExpansionPort).addService(new ExpansionService()).build();
    localExpansionServer.start();
  }

  /** Stops the local expansion server, if {@link #setUp()} got far enough to create it. */
  @AfterClass
  public static void tearDown() {
    // JUnit still invokes @AfterClass when @BeforeClass failed; guard so that a setup failure is
    // not compounded by a NullPointerException here.
    if (localExpansionServer != null) {
      localExpansionServer.shutdownNow();
    }
  }

  /** Expands a single external transform and checks every element is wrapped as "Simple(x)". */
  @Test
  @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
  public void expandSingleTest() {
    PCollection<String> col =
        testPipeline
            .apply(Create.of("1", "2", "3"))
            .apply(External.of(TEST_URN_SIMPLE, new byte[] {}, localExpansionAddr));
    PAssert.that(col).containsInAnyOrder("Simple(1)", "Simple(2)", "Simple(3)");
    testPipeline.run();
  }

  /**
   * Chains two external transforms in one pipeline: a filter whose threshold ("3") is passed as
   * the UTF-8 payload, followed by the "simple" wrapping transform.
   */
  @Test
  @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
  public void expandMultipleTest() {
    PCollection<String> pcol =
        testPipeline
            .apply(Create.of(1, 2, 3, 4, 5, 6))
            .apply(
                "filter <=3",
                External.of(TEST_URN_LE, "3".getBytes(StandardCharsets.UTF_8), localExpansionAddr))
            .apply(MapElements.into(TypeDescriptors.strings()).via(Object::toString))
            .apply("put simple", External.of(TEST_URN_SIMPLE, new byte[] {}, localExpansionAddr));
    PAssert.that(pcol).containsInAnyOrder("Simple(1)", "Simple(2)", "Simple(3)");
    testPipeline.run();
  }

  /** Expands an external transform with two tagged outputs ("even" and "odd"). */
  @Test
  @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
  public void expandMultiOutputTest() {
    PCollectionTuple pTuple =
        testPipeline
            .apply(Create.of(1, 2, 3, 4, 5, 6))
            .apply(
                External.of(TEST_URN_MULTI, new byte[] {}, localExpansionAddr).withMultiOutputs());
    PAssert.that(pTuple.get(new TupleTag<Integer>("even") {})).containsInAnyOrder(2, 4, 6);
    PAssert.that(pTuple.get(new TupleTag<Integer>("odd") {})).containsInAnyOrder(1, 3, 5);
    testPipeline.run();
  }

  /**
   * Expands the {@code beam:transforms:xlang:count} transform through the expansion service at
   * {@code localhost:expansionPort} and expects each distinct input element paired with its
   * number of occurrences.
   *
   * <p>Before building the pipeline the test waits (bounded) for that service's channel to become
   * ready. If the wait runs out the test proceeds anyway and lets the expansion itself fail.
   */
  @Test
  @Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
  public void expandPythonTest() {
    String target = String.format("localhost:%s", expansionPort);
    try {
      waitForChannelReady(target);
    } catch (InterruptedException e) {
      // Restore the interrupt flag for the test runner and keep the original cause attached.
      Thread.currentThread().interrupt();
      throw new RuntimeException("interrupted.", e);
    }

    PCollection<String> pCol =
        testPipeline
            .apply(Create.of("1", "2", "2", "3", "3", "3"))
            .apply(
                External.<KV<String, Integer>>of(
                    "beam:transforms:xlang:count", new byte[] {}, target))
            .apply(
                "toString",
                MapElements.into(TypeDescriptors.strings())
                    .via(x -> String.format("%s->%s", x.getKey(), x.getValue())));
    PAssert.that(pCol).containsInAnyOrder("1->1", "2->2", "3->3");
    testPipeline.run();
  }

  /**
   * Polls a throwaway channel to {@code target} until it reports {@link ConnectivityState#READY}
   * or {@link #CHANNEL_READY_MAX_RETRIES} polls have elapsed, whichever comes first.
   *
   * <p>The channel is always shut down before returning, including when the wait is interrupted.
   *
   * @param target address of the service to probe, e.g. {@code localhost:1234}
   * @throws InterruptedException if the thread is interrupted while sleeping between polls
   */
  private static void waitForChannelReady(String target) throws InterruptedException {
    ManagedChannel channel = ManagedChannelBuilder.forTarget(target).build();
    try {
      // getState(true) also asks an idle channel to start connecting.
      ConnectivityState state = channel.getState(true);
      for (int retry = 0;
          retry < CHANNEL_READY_MAX_RETRIES && state != ConnectivityState.READY;
          retry++) {
        Thread.sleep(CHANNEL_READY_POLL_MILLIS);
        state = channel.getState(true);
      }
    } finally {
      channel.shutdownNow();
    }
  }

  /**
   * Test TransformProvider.
   *
   * <p>Registers the three transforms served by the local expansion service:
   *
   * <ul>
   *   <li>{@code "simple"}: maps each string {@code x} to {@code "Simple(x)"};
   *   <li>{@code "le"}: keeps integers less than or equal to the integer parsed from the UTF-8
   *       payload;
   *   <li>{@code "multi"}: routes even integers to the main ("even") output and odd integers to
   *       the "odd" output.
   * </ul>
   */
  @AutoService(ExpansionService.ExpansionServiceRegistrar.class)
  public static class TestTransforms
      implements ExpansionService.ExpansionServiceRegistrar, Serializable {
    // The anonymous DoFn below references `odd` on this instance; presumably that is why the
    // registrar itself is Serializable.
    private final TupleTag<Integer> even = new TupleTag<Integer>("even") {};
    private final TupleTag<Integer> odd = new TupleTag<Integer>("odd") {};

    /** Returns the URN-to-provider mapping for the test transforms described on this class. */
    @Override
    public Map<String, ExpansionService.TransformProvider> knownTransforms() {
      return ImmutableMap.of(
          TEST_URN_SIMPLE,
          spec ->
              MapElements.into(TypeDescriptors.strings())
                  .via((String x) -> String.format("Simple(%s)", x)),
          TEST_URN_LE,
          spec -> Filter.lessThanEq(Integer.parseInt(spec.getPayload().toStringUtf8())),
          TEST_URN_MULTI,
          spec ->
              ParDo.of(
                      new DoFn<Integer, Integer>() {
                        @ProcessElement
                        public void processElement(ProcessContext c) {
                          if (c.element() % 2 == 0) {
                            // Main output, tagged `even` via withOutputTags below.
                            c.output(c.element());
                          } else {
                            c.output(odd, c.element());
                          }
                        }
                      })
                  .withOutputTags(even, TupleTagList.of(odd)));
    }
  }
}