Merge pull request #12501 from TobKed/BEAM-10662-fix-GCP-variable-check-in-build-python-wheels-workflow
[BEAM-10662] Fix GCP variable check in build python wheels workflow
diff --git a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ValidateRunnerXlangTest.java b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ValidateRunnerXlangTest.java
index 3a0f60b..b96c2c3 100644
--- a/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ValidateRunnerXlangTest.java
+++ b/runners/core-construction-java/src/test/java/org/apache/beam/runners/core/construction/ValidateRunnerXlangTest.java
@@ -54,7 +54,30 @@
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
-/** Test External transforms. */
+/**
+ * Runner Validation Test Suite for Cross-language Transforms.
+ *
+ * <p>As per Beam's Portability Framework design, cross-language transforms should work out of the
+ * box. Even so, rough edges may remain, caused by unpolished implementation of any part of the
+ * execution code path, for example:
+ *
+ * <ul>
+ *   <li>Transform expansion [SDK]
+ *   <li>Pipeline construction [SDK]
+ *   <li>Cross-language artifact staging [Runner]
+ *   <li>Language-specific serialization/deserialization of PCollection (and other data types)
+ *       [Runner/SDK]
+ * </ul>
+ *
+ * <p>In an effort to improve developer visibility into potential problems, this test suite
+ * validates correct execution of the core Beam transforms when used as cross-language transforms
+ * within the Java SDK from any foreign SDK:
+ *
+ * <ul>
+ *   <li>ParDo (https://beam.apache.org/documentation/programming-guide/#pardo)
+ *   <li>GroupByKey (https://beam.apache.org/documentation/programming-guide/#groupbykey)
+ *   <li>CoGroupByKey (https://beam.apache.org/documentation/programming-guide/#cogroupbykey)
+ *   <li>Combine (https://beam.apache.org/documentation/programming-guide/#combine)
+ *   <li>Flatten (https://beam.apache.org/documentation/programming-guide/#flatten)
+ *   <li>Partition (https://beam.apache.org/documentation/programming-guide/#partition)
+ * </ul>
+ *
+ * <p>See the Runner Validation Test Plan for Cross-language transforms
+ * (https://docs.google.com/document/d/1xQp0ElIV84b8OCVz8CD2hvbiWdR8w4BvWxPTZJZA6NA) for further
+ * details.
+ */
@RunWith(JUnit4.class)
public class ValidateRunnerXlangTest implements Serializable {
@Rule public transient TestPipeline testPipeline = TestPipeline.create();
@@ -110,6 +133,14 @@
}
}
+ /**
+ * Motivation behind singleInputOutputTest.
+ *
+ * <p>Target transform: ParDo (https://beam.apache.org/documentation/programming-guide/#pardo)
+ *
+ * <p>Test scenario: mapping elements from a single input collection to a single output collection.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollection<?> to external transforms
+ *   <li>PCollection<?> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void singleInputOutputTest() throws IOException {
@@ -120,6 +151,14 @@
PAssert.that(col).containsInAnyOrder("01", "02", "03");
}
+ /**
+ * Motivation behind multiInputOutputWithSideInputTest.
+ *
+ * <p>Target transform: ParDo (https://beam.apache.org/documentation/programming-guide/#pardo)
+ *
+ * <p>Test scenario: mapping elements from multiple input collections (main and side) to multiple
+ * output collections (main and side).
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollectionTuple to external transforms
+ *   <li>PCollectionTuple from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void multiInputOutputWithSideInputTest() {
@@ -135,6 +174,15 @@
PAssert.that(pTuple.get("side")).containsInAnyOrder("ss");
}
+ /**
+ * Motivation behind groupByKeyTest.
+ *
+ * <p>Target transform: GroupByKey
+ * (https://beam.apache.org/documentation/programming-guide/#groupbykey)
+ *
+ * <p>Test scenario: grouping a collection of KV<K,V> to a collection of KV<K, Iterable<V>> by key.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollection<KV<?, ?>> to external transforms
+ *   <li>PCollection<KV<?, Iterable<?>>> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void groupByKeyTest() {
@@ -154,6 +202,15 @@
PAssert.that(col).containsInAnyOrder("0:1,2", "1:3");
}
+ /**
+ * Motivation behind coGroupByKeyTest.
+ *
+ * <p>Target transform: CoGroupByKey
+ * (https://beam.apache.org/documentation/programming-guide/#cogroupbykey)
+ *
+ * <p>Test scenario: grouping multiple input collections with keys to a collection of
+ * KV<K, CoGbkResult> by key.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>KeyedPCollectionTuple<?> to external transforms
+ *   <li>PCollection<KV<?, Iterable<?>>> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void coGroupByKeyTest() {
@@ -177,6 +234,14 @@
PAssert.that(col).containsInAnyOrder("0:1,2,4", "1:3,5,6");
}
+ /**
+ * Motivation behind combineGloballyTest.
+ *
+ * <p>Target transform: Combine
+ * (https://beam.apache.org/documentation/programming-guide/#combine)
+ *
+ * <p>Test scenario: combining elements globally with a predefined simple CombineFn.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollection<?> to external transforms
+ *   <li>PCollection<?> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void combineGloballyTest() {
@@ -187,6 +252,14 @@
PAssert.that(col).containsInAnyOrder(6L);
}
+ /**
+ * Motivation behind combinePerKeyTest.
+ *
+ * <p>Target transform: Combine
+ * (https://beam.apache.org/documentation/programming-guide/#combine)
+ *
+ * <p>Test scenario: combining elements per key with a predefined simple merging function.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollection<?> to external transforms
+ *   <li>PCollection<?> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void combinePerKeyTest() {
@@ -197,6 +270,14 @@
PAssert.that(col).containsInAnyOrder(KV.of("a", 3L), KV.of("b", 3L));
}
+ /**
+ * Motivation behind flattenTest.
+ *
+ * <p>Target transform: Flatten
+ * (https://beam.apache.org/documentation/programming-guide/#flatten)
+ *
+ * <p>Test scenario: merging multiple collections into a single collection.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollectionList<?> to external transforms
+ *   <li>PCollection<?> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void flattenTest() {
@@ -209,6 +290,15 @@
PAssert.that(col).containsInAnyOrder(1L, 2L, 3L, 4L, 5L, 6L);
}
+ /**
+ * Motivation behind partitionTest.
+ *
+ * <p>Target transform: Partition
+ * (https://beam.apache.org/documentation/programming-guide/#partition)
+ *
+ * <p>Test scenario: splitting a single collection into multiple collections with a predefined
+ * simple PartitionFn.
+ *
+ * <p>Boundary conditions checked:
+ *
+ * <ul>
+ *   <li>PCollection<?> to external transforms
+ *   <li>PCollectionList<?> from external transforms
+ * </ul>
+ */
@Test
@Category({ValidatesRunner.class, UsesCrossLanguageTransforms.class})
public void partitionTest() {
diff --git a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/bigquery/BeamSqlUnparseContext.java b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/bigquery/BeamSqlUnparseContext.java
index f107dc3..ac65eb5 100644
--- a/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/bigquery/BeamSqlUnparseContext.java
+++ b/sdks/java/extensions/sql/src/main/java/org/apache/beam/sdk/extensions/sql/meta/provider/bigquery/BeamSqlUnparseContext.java
@@ -19,6 +19,9 @@
import static org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rel2sql.SqlImplementor.POS;
+import java.util.HashMap;
+import java.util.Map;
+import java.util.Objects;
import java.util.function.IntFunction;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.text.translate.CharSequenceTranslator;
import org.apache.beam.repackaged.core.org.apache.commons.lang3.text.translate.EntityArrays;
@@ -27,9 +30,12 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.ByteString;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.avatica.util.TimeUnitRange;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rel2sql.SqlImplementor;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexDynamicParam;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexProgram;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlDynamicParam;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlKind;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.sql.SqlNode;
@@ -62,10 +68,16 @@
// Unicode (only 4 hex digits)
.with(JavaUnicodeEscaper.outsideOf(32, 0x7f));
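+  // Named placeholders generated for NULL dynamic parameters ("null_param_<index>"), mapped to
+  // their Calcite types so callers can later bind them as typed NULL values.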
+ private Map<String, RelDataType> nullParams = new HashMap<>();
+
public BeamSqlUnparseContext(IntFunction<SqlNode> field) {
super(BeamBigQuerySqlDialect.DEFAULT, field);
}
+ public Map<String, RelDataType> getNullParams() {
+ return nullParams;
+ }
+
@Override
public SqlNode toSql(RexProgram program, RexNode rex) {
if (rex.getKind().equals(SqlKind.LITERAL)) {
@@ -92,6 +104,12 @@
return new ReplaceLiteral(literal, POS, "ISOWEEK");
}
}
+ } else if (rex.getKind().equals(SqlKind.DYNAMIC_PARAM)) {
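+      // NULL query parameters arrive here as dynamic params (created by ExpressionConverter);
+      // unparse them as named parameters ("@null_param_<index>") and record their types so they
+      // can be bound as typed NULLs when the expression is evaluated.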
+ final RexDynamicParam param = (RexDynamicParam) rex;
+ final int index = param.getIndex();
+ final String name = "null_param_" + index;
+ nullParams.put(name, param.getType());
+ return new NamedDynamicParam(index, POS, name);
}
return super.toSql(program, rex);
@@ -110,6 +128,26 @@
public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
writer.literal("DATETIME '" + timestampString.toString() + "'");
}
+
+ @Override
+ public boolean equals(Object o) {
+ if (this == o) {
+ return true;
+ }
+ if (o == null || getClass() != o.getClass()) {
+ return false;
+ }
+ if (!super.equals(o)) {
+ return false;
+ }
+ SqlDateTimeLiteral that = (SqlDateTimeLiteral) o;
+ return Objects.equals(timestampString, that.timestampString);
+ }
+
+ @Override
+ public int hashCode() {
+ return Objects.hash(super.hashCode(), timestampString);
+ }
}
private static class SqlByteStringLiteral extends SqlLiteral {
@@ -167,4 +205,18 @@
return super.hashCode();
}
}
+
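+  /** A {@link SqlDynamicParam} that unparses to a named parameter reference such as {@code @null_param_0}. */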
+ private static class NamedDynamicParam extends SqlDynamicParam {
+ private final String newName;
+
+ NamedDynamicParam(int index, SqlParserPos pos, String newName) {
+ super(index, pos);
+ this.newName = newName;
+ }
+
+ @Override
+ public void unparse(SqlWriter writer, int leftPrec, int rightPrec) {
+ writer.literal("@" + newName);
+ }
+ }
}
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
index 70153bd..50c5336 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/BeamZetaSqlCalcRel.java
@@ -20,7 +20,6 @@
import com.google.zetasql.AnalyzerOptions;
import com.google.zetasql.PreparedExpression;
import com.google.zetasql.Value;
-import java.util.Collections;
import java.util.HashMap;
import java.util.Map;
import java.util.function.IntFunction;
@@ -42,7 +41,7 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.plan.RelTraitSet;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.core.Calc;
-import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.rel2sql.SqlImplementor;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.type.RelDataType;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexBuilder;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexProgram;
@@ -61,7 +60,7 @@
public class BeamZetaSqlCalcRel extends AbstractBeamCalcRel {
private static final SqlDialect DIALECT = BeamBigQuerySqlDialect.DEFAULT;
- private final SqlImplementor.Context context;
+ private final BeamSqlUnparseContext context;
private static String columnName(int i) {
return "_" + i;
@@ -110,6 +109,7 @@
CalcFn calcFn =
new CalcFn(
context.toSql(getProgram(), rex).toSqlString(DIALECT).getSql(),
+ createNullParams(context.getNullParams()),
upstream.getSchema(),
outputSchema,
options.getZetaSqlDefaultTimezone(),
@@ -122,12 +122,23 @@
}
}
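+  /** Converts the collected NULL-parameter types into ZetaSQL NULL {@link Value}s keyed by parameter name. */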
+ private static Map<String, Value> createNullParams(Map<String, RelDataType> input) {
+ Map<String, Value> result = new HashMap<>();
+ for (Map.Entry<String, RelDataType> entry : input.entrySet()) {
+ result.put(
+ entry.getKey(),
+ Value.createNullValue(ZetaSqlCalciteTranslationUtils.toZetaType(entry.getValue())));
+ }
+ return result;
+ }
+
/**
* {@code CalcFn} is the executor for a {@link BeamZetaSqlCalcRel} step. The implementation is
* based on the {@code ZetaSQL} expression evaluator.
*/
private static class CalcFn extends DoFn<Row, Row> {
private final String sql;
+ private final Map<String, Value> nullParams;
private final Schema inputSchema;
private final Schema outputSchema;
private final String defaultTimezone;
@@ -136,11 +147,13 @@
CalcFn(
String sql,
+ Map<String, Value> nullParams,
Schema inputSchema,
Schema outputSchema,
String defaultTimezone,
boolean verifyRowValues) {
this.sql = sql;
+ this.nullParams = nullParams;
this.inputSchema = inputSchema;
this.outputSchema = outputSchema;
this.defaultTimezone = defaultTimezone;
@@ -149,10 +162,8 @@
@Setup
public void setup() {
- // TODO[BEAM-9182]: support parameters in expression evaluation
- // Query parameters are not set because they have already been substituted.
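+      // The NULL parameters collected during unparsing are registered with the analyzer here,
+      // since the generated SQL references them as named parameters.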
AnalyzerOptions options =
- SqlAnalyzer.getAnalyzerOptions(QueryParameters.ofNone(), defaultTimezone);
+ SqlAnalyzer.getAnalyzerOptions(QueryParameters.ofNamed(nullParams), defaultTimezone);
for (int i = 0; i < inputSchema.getFieldCount(); i++) {
options.addExpressionColumn(
columnName(i),
@@ -175,11 +186,7 @@
row.getBaseValue(i, Object.class), inputSchema.getField(i).getType()));
}
- // TODO[BEAM-9182]: support parameters in expression evaluation
- // The map is empty because parameters in the query string have already been substituted.
- Map<String, Value> params = Collections.emptyMap();
-
- Value v = exp.execute(columns, params);
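+      // Bind the same NULL parameters by name when the prepared expression is evaluated.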
+ Value v = exp.execute(columns, nullParams);
if (!v.isNull()) {
Row outputRow =
ZetaSqlBeamTranslationUtils.zetaSqlStructValueToBeamRow(
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlCalciteTranslationUtils.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlCalciteTranslationUtils.java
index 81bc142..df7f4e2 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlCalciteTranslationUtils.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlCalciteTranslationUtils.java
@@ -136,9 +136,12 @@
// -1 cardinality means unlimited array size.
// TODO: is unlimited array size right for general case?
// TODO: whether isNullable should be ArrayType's nullablity (not its element type's?)
- return rexBuilder
- .getTypeFactory()
- .createArrayType(toRelDataType(rexBuilder, arrayType.getElementType(), isNullable), -1);
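+    // Apply the requested nullability to the array type itself, not only to its element type.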
+ return nullable(
+ rexBuilder,
+ rexBuilder
+ .getTypeFactory()
+ .createArrayType(toRelDataType(rexBuilder, arrayType.getElementType(), isNullable), -1),
+ isNullable);
}
private static List<String> toNameList(List<StructField> fields) {
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
index fd5651f..d0b4112 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/ExpressionConverter.java
@@ -193,6 +193,7 @@
private final RelOptCluster cluster;
private final QueryParameters queryParams;
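+  // Number of NULL parameters replaced with dynamic-param placeholders so far; used to assign
+  // each placeholder a unique index.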
+ private int nullParamCount = 0;
private final Map<String, ResolvedCreateFunctionStmt> userDefinedFunctions;
public ExpressionConverter(
@@ -1217,7 +1218,17 @@
throw new IllegalArgumentException("Found unexpected parameter " + parameter);
}
Preconditions.checkState(parameter.getType().equals(value.getType()));
- return convertValueToRexNode(value.getType(), value);
+ if (value.isNull()) {
+ // In some cases NULL parameter cannot be substituted with NULL literal
+ // Therefore we create a dynamic parameter placeholder here for each NULL parameter
+ return rexBuilder()
+ .makeDynamicParam(
+ ZetaSqlCalciteTranslationUtils.toRelDataType(rexBuilder(), value.getType(), true),
+ nullParamCount++);
+ } else {
+ // Substitute non-NULL parameter with literal
+ return convertValueToRexNode(value.getType(), value);
+ }
}
private RexNode convertResolvedArgumentRef(
diff --git a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/LimitOffsetScanToLimitConverter.java b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/LimitOffsetScanToLimitConverter.java
index 3052d6f..b58c40a 100644
--- a/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/LimitOffsetScanToLimitConverter.java
+++ b/sdks/java/extensions/sql/zetasql/src/main/java/org/apache/beam/sdk/extensions/sql/zetasql/translation/LimitOffsetScanToLimitConverter.java
@@ -26,6 +26,7 @@
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelCollations;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.RelNode;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rel.logical.LogicalSort;
+import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexDynamicParam;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexLiteral;
import org.apache.beam.vendor.calcite.v1_20_0.org.apache.calcite.rex.RexNode;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableList;
@@ -64,7 +65,11 @@
input.getRowType().getFieldList(),
ImmutableMap.of());
- if (RexLiteral.isNullLiteral(offset) || RexLiteral.isNullLiteral(fetch)) {
+ // offset or fetch being RexDynamicParam means it is NULL (the only param supported currently)
+ if (offset instanceof RexDynamicParam
+ || RexLiteral.isNullLiteral(offset)
+ || fetch instanceof RexDynamicParam
+ || RexLiteral.isNullLiteral(fetch)) {
throw new UnsupportedOperationException("Limit requires non-null count and offset");
}
diff --git a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
index d148012..9a67da6 100644
--- a/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
+++ b/sdks/java/extensions/sql/zetasql/src/test/java/org/apache/beam/sdk/extensions/sql/zetasql/ZetaSqlDialectSpecTest.java
@@ -252,7 +252,6 @@
}
@Test
- @Ignore("[BEAM-9182] NULL parameters do not work in BeamZetaSqlCalcRel")
public void testEQ1() {
String sql = "SELECT @p0 = @p1 AS ColA";
@@ -294,7 +293,6 @@
}
@Test
- @Ignore("[BEAM-9182] NULL parameters do not work in BeamZetaSqlCalcRel")
public void testEQ3() {
String sql = "SELECT @p0 = @p1 AS ColA";
@@ -523,7 +521,6 @@
}
@Test
- @Ignore("[BEAM-9182] NULL parameters do not work in BeamZetaSqlCalcRel")
public void testNullIfCoercion() {
String sql = "SELECT NULLIF(@p0, @p1) AS ColA";
ImmutableMap<String, Value> params =
@@ -733,9 +730,8 @@
}
@Test
- @Ignore("[BEAM-9182] NULL parameters do not work in BeamZetaSqlCalcRel")
public void testLikeNullPattern() {
- String sql = "SELECT @p0 LIKE @p1 AS ColA";
+ String sql = "SELECT @p0 LIKE @p1 AS ColA";
ImmutableMap<String, Value> params =
ImmutableMap.of(
"p0",
diff --git a/sdks/python/apache_beam/dataframe/expressions.py b/sdks/python/apache_beam/dataframe/expressions.py
index 34e01cca..7e8b782 100644
--- a/sdks/python/apache_beam/dataframe/expressions.py
+++ b/sdks/python/apache_beam/dataframe/expressions.py
@@ -18,7 +18,6 @@
import contextlib
import threading
-
from typing import Any
from typing import Callable
from typing import Iterable
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager.py b/sdks/python/apache_beam/runners/interactive/cache_manager.py
index 9894535..48f1fc5 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager.py
@@ -67,13 +67,17 @@
"""Returns the latest version number of the PCollection cache."""
raise NotImplementedError
- def read(self, *labels):
- # type (*str) -> Tuple[str, Generator[Any]]
+ def read(self, *labels, **args):
+ # type (*str, Dict[str, Any]) -> Tuple[str, Generator[Any]]
"""Return the PCollection as a list as well as the version number.
Args:
*labels: List of labels for PCollection instance.
+      **args: Dict of additional arguments. Currently only supports 'limiters'
+        as a list of ElementLimiters and 'tail' as a boolean. Limiters restrict
+        how many elements are read and how much processing time is covered by
+        the read.
Returns:
A tuple containing an iterator for the items in the PCollection and the
@@ -97,6 +101,17 @@
"""
raise NotImplementedError
+  def clear(self, *labels):
+    # type (*str) -> bool
+
+    """Clears the cache entry of the given labels and returns True on success.
+
+    Args:
+      *labels: List of labels for PCollection instance.
+    """
+    raise NotImplementedError
+
def source(self, *labels):
# type (*str) -> ptransform.PTransform
@@ -196,17 +211,35 @@
self._default_pcoder if self._default_pcoder is not None else
self._saved_pcoders[self._path(*labels)])
- def read(self, *labels):
+ def read(self, *labels, **args):
# Return an iterator to an empty list if it doesn't exist.
if not self.exists(*labels):
return iter([]), -1
+ limiters = args.pop('limiters', [])
+
# Otherwise, return a generator to the cached PCollection.
source = self.source(*labels)._source
range_tracker = source.get_range_tracker(None, None)
reader = source.read(range_tracker)
version = self._latest_version(*labels)
- return reader, version
+
+ # The return type is a generator, so in order to implement the limiter for
+ # the FileBasedCacheManager we wrap the original generator with the logic
+ # to limit yielded elements.
+ def limit_reader(r):
+ for e in r:
+ # Update the limiters and break early out of reading from cache if any
+ # are triggered.
+ for l in limiters:
+ l.update(e)
+
+ if any(l.is_triggered() for l in limiters):
+ break
+
+ yield e
+
+ return limit_reader(reader), version
def write(self, values, *labels):
sink = self.sink(labels)._sink
@@ -218,6 +251,12 @@
writer.write(v)
writer.close()
+ def clear(self, *labels):
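+    # Delete every cached file that matches the given labels, if any exist.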
+ if self.exists(*labels):
+ filesystems.FileSystems.delete(self._match(*labels))
+ return True
+ return False
+
def source(self, *labels):
return self._reader_class(
self._glob_path(*labels), coder=self.load_pcoder(*labels))
diff --git a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
index 7868e90..e7dc936 100644
--- a/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
+++ b/sdks/python/apache_beam/runners/interactive/cache_manager_test.py
@@ -30,6 +30,7 @@
from apache_beam import coders
from apache_beam.io import filesystems
from apache_beam.runners.interactive import cache_manager as cache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
class FileBasedCacheManagerTest(object):
@@ -91,6 +92,18 @@
self.mock_write_cache(cache_version_one, prefix, cache_label)
self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+ def test_clear(self):
+    """Test that the CacheManager can correctly clear a written cache entry."""
+ prefix = 'full'
+ cache_label = 'some-cache-label'
+ cache_version_one = ['cache', 'version', 'one']
+
+ self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+ self.mock_write_cache(cache_version_one, prefix, cache_label)
+ self.assertTrue(self.cache_manager.exists(prefix, cache_label))
+ self.assertTrue(self.cache_manager.clear(prefix, cache_label))
+ self.assertFalse(self.cache_manager.exists(prefix, cache_label))
+
def test_read_basic(self):
"""Test the condition where the cache is read once after written once."""
prefix = 'full'
@@ -180,6 +193,21 @@
self.assertTrue(
self.cache_manager.is_latest_version(version, prefix, cache_label))
+ def test_read_with_count_limiter(self):
+    """Test that reading with a CountLimiter stops after the requested count."""
+ prefix = 'full'
+ cache_label = 'some-cache-label'
+ cache_version_one = ['cache', 'version', 'one']
+
+ self.mock_write_cache(cache_version_one, prefix, cache_label)
+ reader, version = self.cache_manager.read(
+ prefix, cache_label, limiters=[CountLimiter(2)])
+ pcoll_list = list(reader)
+ self.assertListEqual(pcoll_list, ['cache', 'version'])
+ self.assertEqual(version, 0)
+ self.assertTrue(
+ self.cache_manager.is_latest_version(version, prefix, cache_label))
+
class TextFileBasedCacheManagerTest(
FileBasedCacheManagerTest,
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
index b2204cf..77f976d 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache.py
@@ -153,13 +153,23 @@
cache_dir,
labels,
is_cache_complete=None,
- coder=SafeFastPrimitivesCoder()):
+ coder=None,
+ limiters=None):
+ if not coder:
+ coder = SafeFastPrimitivesCoder()
+
+ if not is_cache_complete:
+ is_cache_complete = lambda _: True
+
+ if not limiters:
+ limiters = []
+
self._cache_dir = cache_dir
self._coder = coder
self._labels = labels
self._path = os.path.join(self._cache_dir, *self._labels)
- self._is_cache_complete = (
- is_cache_complete if is_cache_complete else lambda _: True)
+ self._is_cache_complete = is_cache_complete
+ self._limiters = limiters
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
self._pipeline_id = CacheKey.from_str(labels[-1]).pipeline_id
@@ -193,7 +203,8 @@
# Check if we are at EOF or if we have an incomplete line.
if not line or (line and line[-1] != b'\n'[0]):
- if not tail:
+ # Read at least the first line to get the header.
+ if not tail and pos != 0:
break
# Complete reading only when the cache is complete.
@@ -210,10 +221,16 @@
proto_cls = TestStreamFileHeader if pos == 0 else TestStreamFileRecord
msg = self._try_parse_as(proto_cls, to_decode)
if msg:
- yield msg
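+          # Feed the parsed record to each limiter and stop reading as soon as any limiter
+          # triggers; the triggering record itself is not yielded.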
+ for l in self._limiters:
+ l.update(msg)
+
+ if any(l.is_triggered() for l in self._limiters):
+ break
else:
break
+ yield msg
+
def _try_parse_as(self, proto_cls, to_decode):
try:
msg = proto_cls()
@@ -285,7 +302,7 @@
return os.path.exists(path)
# TODO(srohde): Modify this to return the correct version.
- def read(self, *labels):
+ def read(self, *labels, **args):
"""Returns a generator to read all records from file.
Does not tail.
@@ -293,8 +310,12 @@
if not self.exists(*labels):
return iter([]), -1
+ limiters = args.pop('limiters', [])
+ tail = args.pop('tail', False)
+
reader = StreamingCacheSource(
- self._cache_dir, labels, self._is_cache_complete).read(tail=False)
+ self._cache_dir, labels, self._is_cache_complete,
+ limiters=limiters).read(tail=tail)
# Return an empty iterator if there is nothing in the file yet. This can
# only happen when tail is False.
@@ -304,7 +325,7 @@
return iter([]), -1
return StreamingCache.Reader([header], [reader]).read(), 1
- def read_multiple(self, labels):
+ def read_multiple(self, labels, limiters=None, tail=True):
"""Returns a generator to read all records from file.
Does tail until the cache is complete. This is because it is used in the
@@ -312,9 +333,9 @@
pipeline runtime which needs to block.
"""
readers = [
- StreamingCacheSource(self._cache_dir, l,
- self._is_cache_complete).read(tail=True)
- for l in labels
+ StreamingCacheSource(
+ self._cache_dir, l, self._is_cache_complete,
+ limiters=limiters).read(tail=tail) for l in labels
]
headers = [next(r) for r in readers]
return StreamingCache.Reader(headers, readers).read()
@@ -334,6 +355,14 @@
val = v
f.write(self._default_pcoder.encode(val) + b'\n')
+ def clear(self, *labels):
+ directory = os.path.join(self._cache_dir, *labels[:-1])
+ filepath = os.path.join(directory, labels[-1])
+ if os.path.exists(filepath):
+ os.remove(filepath)
+ return True
+ return False
+
def source(self, *labels):
"""Returns the StreamingCacheManager source.
diff --git a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
index a56b851..23390cc 100644
--- a/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
+++ b/sdks/python/apache_beam/runners/interactive/caching/streaming_cache_test.py
@@ -28,6 +28,8 @@
from apache_beam.portability.api.beam_runner_api_pb2 import TestStreamPayload
from apache_beam.runners.interactive.cache_manager import SafeFastPrimitivesCoder
from apache_beam.runners.interactive.caching.streaming_cache import StreamingCache
+from apache_beam.runners.interactive.options.capture_limiters import CountLimiter
+from apache_beam.runners.interactive.options.capture_limiters import ProcessingTimeLimiter
from apache_beam.runners.interactive.pipeline_instrument import CacheKey
from apache_beam.runners.interactive.testing.test_cache_manager import FileRecordsBuilder
from apache_beam.testing.test_pipeline import TestPipeline
@@ -64,6 +66,14 @@
# Assert that an empty reader returns an empty list.
self.assertFalse([e for e in reader])
+ def test_clear(self):
+ cache = StreamingCache(cache_dir=None)
+ self.assertFalse(cache.exists('my_label'))
+ cache.write([TestStreamFileRecord()], 'my_label')
+ self.assertTrue(cache.exists('my_label'))
+ self.assertTrue(cache.clear('my_label'))
+ self.assertFalse(cache.exists('my_label'))
+
def test_single_reader(self):
"""Tests that we expect to see all the correctly emitted TestStreamPayloads.
"""
@@ -403,6 +413,106 @@
self.assertListEqual(actual_events, expected_events)
+ def test_single_reader_with_count_limiter(self):
+    """Tests that a CountLimiter truncates the emitted TestStreamPayloads as expected.
+    """
+ CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+ values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+ .add_element(element=0, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=1, event_time_secs=1)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=2)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(values, CACHED_PCOLLECTION_KEY)
+
+ reader, _ = cache.read(CACHED_PCOLLECTION_KEY, limiters=[CountLimiter(2)])
+ coder = coders.FastPrimitivesCoder()
+ events = list(reader)
+
+ # Units here are in microseconds.
+ # These are a slice of the original values such that we only get two
+ # elements.
+ expected = [
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(0), timestamp=0)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(1), timestamp=1 * 10**6)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ ]
+ self.assertSequenceEqual(events, expected)
+
+ def test_single_reader_with_processing_time_limiter(self):
+    """Tests that a ProcessingTimeLimiter truncates the emitted TestStreamPayloads as expected.
+    """
+ CACHED_PCOLLECTION_KEY = repr(CacheKey('arbitrary_key', '', '', ''))
+
+ values = (FileRecordsBuilder(tag=CACHED_PCOLLECTION_KEY)
+ .advance_processing_time(1e-6)
+ .add_element(element=0, event_time_secs=0)
+ .advance_processing_time(1)
+ .add_element(element=1, event_time_secs=1)
+ .advance_processing_time(1)
+ .add_element(element=2, event_time_secs=2)
+ .advance_processing_time(1)
+ .add_element(element=3, event_time_secs=2)
+ .advance_processing_time(1)
+ .add_element(element=4, event_time_secs=2)
+ .build()) # yapf: disable
+
+ cache = StreamingCache(cache_dir=None)
+ cache.write(values, CACHED_PCOLLECTION_KEY)
+
+ reader, _ = cache.read(
+ CACHED_PCOLLECTION_KEY, limiters=[ProcessingTimeLimiter(2)])
+ coder = coders.FastPrimitivesCoder()
+ events = list(reader)
+
+ # Units here are in microseconds.
+ # Expects that the elements are a slice of the original values where all
+ # processing time is less than the duration.
+ expected = [
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(0), timestamp=0)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ TestStreamPayload.Event(
+ processing_time_event=TestStreamPayload.Event.AdvanceProcessingTime(
+ advance_duration=1 * 10**6)),
+ TestStreamPayload.Event(
+ element_event=TestStreamPayload.Event.AddElements(
+ elements=[
+ TestStreamPayload.TimestampedElement(
+ encoded_element=coder.encode(1), timestamp=1 * 10**6)
+ ],
+ tag=CACHED_PCOLLECTION_KEY)),
+ ]
+ self.assertSequenceEqual(events, expected)
+
if __name__ == '__main__':
unittest.main()
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
index 2c84f80..c9888ab 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters.py
@@ -109,12 +109,14 @@
self._count += 1
def is_triggered(self):
- return self._count >= self._max_count
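+    # Trigger only after more than max_count elements have been counted, so readers that
+    # update-then-check before yielding still emit exactly max_count elements.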
+ return self._count > self._max_count
class ProcessingTimeLimiter(ElementLimiter):
"""Limits by how long the ProcessingTime passed in the element stream.
+ Reads all elements from the timespan [start, start + duration).
+
This measures the duration from the first element in the stream. Each
subsequent element has a delta "advance_duration" that moves the internal
clock forward. This triggers when the duration from the internal clock and
diff --git a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
index 850c56e2c..347cb8e 100644
--- a/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
+++ b/sdks/python/apache_beam/runners/interactive/options/capture_limiters_test.py
@@ -28,7 +28,7 @@
def test_count_limiter(self):
limiter = CountLimiter(5)
- for e in range(4):
+ for e in range(5):
limiter.update(e)
self.assertFalse(limiter.is_triggered())
diff --git a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
index f39f016..098f249 100644
--- a/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
+++ b/sdks/python/apache_beam/runners/interactive/testing/test_cache_manager.py
@@ -45,11 +45,23 @@
def _latest_version(self, *labels):
return True
- def read(self, *labels):
+ def read(self, *labels, **args):
if not self.exists(*labels):
return itertools.chain([]), -1
- ret = itertools.chain(self._cached[self._key(*labels)])
- return ret, None
+
+ limiters = args.pop('limiters', [])
+
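+    # Element-limiting wrapper mirroring FileBasedCacheManager.read: stop yielding as soon as
+    # any limiter is triggered.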
+ def limit_reader(r):
+ for e in r:
+ for l in limiters:
+ l.update(e)
+
+ if any(l.is_triggered() for l in limiters):
+ break
+
+ yield e
+
+ return limit_reader(itertools.chain(self._cached[self._key(*labels)])), None
def write(self, value, *labels):
if not self.exists(*labels):
diff --git a/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py b/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py
index 4d6b56b..d6ff006 100644
--- a/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py
+++ b/sdks/python/apache_beam/transforms/validate_runner_xlang_test.py
@@ -14,6 +14,42 @@
# See the License for the specific language governing permissions and
# limitations under the License.
#
+
+"""
+###########################################################
+Runner Validation Test Suite for Cross-language Transforms
+###########################################################
+ As per Beam's Portability Framework design, cross-language transforms
+ should work out of the box. Even so, rough edges may remain, caused by
+ unpolished implementation of any part of the execution code path, for
+ example:
+ - Transform expansion [SDK]
+ - Pipeline construction [SDK]
+ - Cross-language artifact staging [Runner]
+ - Language specific serialization/deserialization of PCollection (and
+ other data types) [Runner/SDK]
+
+ In an effort to improve developer visibility into potential problems,
+ this test suite validates correct execution of core Beam transforms when
+ used as cross-language transforms within the Python SDK from any foreign SDK:
+ - ParDo
+ (https://beam.apache.org/documentation/programming-guide/#pardo)
+ - GroupByKey
+ (https://beam.apache.org/documentation/programming-guide/#groupbykey)
+ - CoGroupByKey
+ (https://beam.apache.org/documentation/programming-guide/#cogroupbykey)
+ - Combine
+ (https://beam.apache.org/documentation/programming-guide/#combine)
+ - Flatten
+ (https://beam.apache.org/documentation/programming-guide/#flatten)
+ - Partition
+ (https://beam.apache.org/documentation/programming-guide/#partition)
+
+ See Runner Validation Test Plan for Cross-language transforms at
+https://docs.google.com/document/d/1xQp0ElIV84b8OCVz8CD2hvbiWdR8w4BvWxPTZJZA6NA
+ for further details.
+"""
+
from __future__ import absolute_import
import logging
@@ -46,6 +82,15 @@
'localhost:%s' % os.environ.get('EXPANSION_PORT'))
def run_prefix(self, pipeline):
+ """
+ Target transform - ParDo
+ (https://beam.apache.org/documentation/programming-guide/#pardo)
+ Test scenario - Mapping elements from a single input collection to a
+ single output collection
+ Boundary conditions checked -
+ - PCollection<?> to external transforms
+ - PCollection<?> from external transforms
+ """
with pipeline as p:
res = (
p
@@ -57,6 +102,15 @@
assert_that(res, equal_to(['0a', '0b']))
def run_multi_input_output_with_sideinput(self, pipeline):
+ """
+ Target transform - ParDo
+ (https://beam.apache.org/documentation/programming-guide/#pardo)
+ Test scenario - Mapping elements from multiple input collections (main
+ and side) to multiple output collections (main and side)
+ Boundary conditions checked -
+ - PCollectionTuple to external transforms
+ - PCollectionTuple from external transforms
+ """
with pipeline as p:
main1 = p | 'Main1' >> beam.Create(
['a', 'bb'], reshuffle=False).with_output_types(unicode)
@@ -70,6 +124,15 @@
assert_that(res['side'], equal_to(['ss']), label='CheckSide')
def run_group_by_key(self, pipeline):
+ """
+ Target transform - GroupByKey
+ (https://beam.apache.org/documentation/programming-guide/#groupbykey)
+ Test scenario - Grouping a collection of KV<K,V> to a collection of
+ KV<K, Iterable<V>> by key
+ Boundary conditions checked -
+ - PCollection<KV<?, ?>> to external transforms
+ - PCollection<KV<?, Iterable<?>>> from external transforms
+ """
with pipeline as p:
res = (
p
@@ -81,6 +144,15 @@
assert_that(res, equal_to(['0:1,2', '1:3']))
def run_cogroup_by_key(self, pipeline):
+ """
+ Target transform - CoGroupByKey
+ (https://beam.apache.org/documentation/programming-guide/#cogroupbykey)
+ Test scenario - Grouping multiple input collections with keys to a
+ collection of KV<K, CoGbkResult> by key
+ Boundary conditions checked -
+ - KeyedPCollectionTuple<?> to external transforms
+ - PCollection<KV<?, Iterable<?>>> from external transforms
+ """
with pipeline as p:
col1 = p | 'create_col1' >> beam.Create(
[(0, "1"), (0, "2"), (1, "3")], reshuffle=False).with_output_types(
@@ -95,6 +167,15 @@
assert_that(res, equal_to(['0:1,2,4', '1:3,5,6']))
def run_combine_globally(self, pipeline):
+ """
+ Target transform - Combine
+ (https://beam.apache.org/documentation/programming-guide/#combine)
+ Test scenario - Combining elements globally with a predefined simple
+ CombineFn
+ Boundary conditions checked -
+ - PCollection<?> to external transforms
+ - PCollection<?> from external transforms
+ """
with pipeline as p:
res = (
p
@@ -104,6 +185,15 @@
assert_that(res, equal_to([6]))
def run_combine_per_key(self, pipeline):
+ """
+ Target transform - Combine
+ (https://beam.apache.org/documentation/programming-guide/#combine)
+ Test scenario - Combining elements per key with a predefined simple
+ merging function
+ Boundary conditions checked -
+ - PCollection<?> to external transforms
+ - PCollection<?> from external transforms
+ """
with pipeline as p:
res = (
p
@@ -114,6 +204,14 @@
assert_that(res, equal_to([('a', 3), ('b', 3)]))
def run_flatten(self, pipeline):
+ """
+ Target transform - Flatten
+ (https://beam.apache.org/documentation/programming-guide/#flatten)
+ Test scenario - Merging multiple collections into a single collection
+ Boundary conditions checked -
+ - PCollectionList<?> to external transforms
+ - PCollection<?> from external transforms
+ """
with pipeline as p:
col1 = p | 'col1' >> beam.Create([1, 2, 3]).with_output_types(int)
col2 = p | 'col2' >> beam.Create([4, 5, 6]).with_output_types(int)
@@ -123,6 +221,15 @@
assert_that(res, equal_to([1, 2, 3, 4, 5, 6]))
def run_partition(self, pipeline):
+ """
+ Target transform - Partition
+ (https://beam.apache.org/documentation/programming-guide/#partition)
+ Test scenario - Splitting a single collection into multiple collections
+ with a predefined simple PartitionFn
+ Boundary conditions checked -
+ - PCollection<?> to external transforms
+ - PCollectionList<?> from external transforms
+ """
with pipeline as p:
res = (
p