/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.runners.spark.structuredstreaming.translation.batch;
import static org.apache.beam.runners.spark.structuredstreaming.Constants.BEAM_SOURCE_OPTION;
import static org.apache.beam.runners.spark.structuredstreaming.Constants.DEFAULT_PARALLELISM;
import static org.apache.beam.runners.spark.structuredstreaming.Constants.PIPELINE_OPTIONS;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import java.io.IOException;
import java.io.Serializable;
import java.util.List;
import java.util.Map;
import java.util.Set;
import org.apache.beam.runners.core.construction.SerializablePipelineOptions;
import org.apache.beam.runners.core.serialization.Base64Serializer;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.RowHelpers;
import org.apache.beam.runners.spark.structuredstreaming.translation.helpers.SchemaHelpers;
import org.apache.beam.sdk.io.BoundedSource;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.util.WindowedValue;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Strings;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.collect.ImmutableSet;
import org.apache.spark.sql.catalyst.InternalRow;
import org.apache.spark.sql.connector.catalog.SupportsRead;
import org.apache.spark.sql.connector.catalog.Table;
import org.apache.spark.sql.connector.catalog.TableCapability;
import org.apache.spark.sql.connector.catalog.TableProvider;
import org.apache.spark.sql.connector.expressions.Transform;
import org.apache.spark.sql.connector.read.Batch;
import org.apache.spark.sql.connector.read.InputPartition;
import org.apache.spark.sql.connector.read.PartitionReader;
import org.apache.spark.sql.connector.read.PartitionReaderFactory;
import org.apache.spark.sql.connector.read.Scan;
import org.apache.spark.sql.connector.read.ScanBuilder;
import org.apache.spark.sql.types.StructType;
import org.apache.spark.sql.util.CaseInsensitiveStringMap;
/**
 * The Spark DataSourceV2 API was removed in Spark 3. This class is a Beam {@link BoundedSource}
 * wrapper built on the Spark 3 connector API that replaced it.
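 *
 * <p>A typical invocation looks roughly like the following (a sketch; the option values shown are
 * illustrative placeholders, not literal strings):
 *
 * <pre>{@code
 * Dataset<Row> rows =
 *     sparkSession
 *         .read()
 *         .format(DatasetSourceBatch.class.getCanonicalName())
 *         .option(BEAM_SOURCE_OPTION, "<base64-serialized BoundedSource>")
 *         .option(DEFAULT_PARALLELISM, "4")
 *         .option(PIPELINE_OPTIONS, "<serialized PipelineOptions>")
 *         .load();
 * }</pre>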
*/
public class DatasetSourceBatch implements TableProvider {
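  /**
   * Fixed one-column binary schema: each row carries a {@link WindowedValue} encoded with the
   * source's output coder (see {@link SchemaHelpers#binarySchema()}).
   */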
private static final StructType BINARY_SCHEMA = SchemaHelpers.binarySchema();
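  /** No-arg constructor required because Spark instantiates {@link TableProvider}s reflectively. */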
public DatasetSourceBatch() {}
@Override
public StructType inferSchema(CaseInsensitiveStringMap options) {
return BINARY_SCHEMA;
}
@Override
public boolean supportsExternalMetadata() {
return true;
}
@Override
public Table getTable(
StructType schema, Transform[] partitioning, Map<String, String> properties) {
return new DatasetSourceBatchTable();
}
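  /**
   * Read-only {@link Table} over the fixed binary schema; its {@link Scan} delegates all batch
   * reading to {@link BeamBatch}.
   */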
private static class DatasetSourceBatchTable implements SupportsRead {
@Override
public ScanBuilder newScanBuilder(CaseInsensitiveStringMap options) {
return new ScanBuilder() {
@Override
public Scan build() {
return new Scan() { // scan for Batch reading
@Override
public StructType readSchema() {
return BINARY_SCHEMA;
}
@Override
public Batch toBatch() {
return new BeamBatch<>(options);
}
};
}
};
}
@Override
public String name() {
return "BeamSource";
}
@Override
public StructType schema() {
return BINARY_SCHEMA;
}
@Override
public Set<TableCapability> capabilities() {
      return ImmutableSet.of(TableCapability.BATCH_READ);
}
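    /**
     * {@link Batch} that splits the Beam {@link BoundedSource} into Spark partitions. It must be
     * {@link Serializable} because the anonymous {@link PartitionReaderFactory} below retains a
     * reference to its enclosing instance and is shipped to executors.
     */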
private static class BeamBatch<T> implements Batch, Serializable {
private final int numPartitions;
private final BoundedSource<T> source;
private final SerializablePipelineOptions serializablePipelineOptions;
private BeamBatch(CaseInsensitiveStringMap options) {
if (Strings.isNullOrEmpty(options.get(BEAM_SOURCE_OPTION))) {
throw new RuntimeException("Beam source was not set in DataSource options");
}
this.source =
Base64Serializer.deserializeUnchecked(
options.get(BEAM_SOURCE_OPTION), BoundedSource.class);
        if (Strings.isNullOrEmpty(options.get(DEFAULT_PARALLELISM))) {
          throw new RuntimeException("Spark default parallelism was not set in DataSource options");
}
this.numPartitions = Integer.parseInt(options.get(DEFAULT_PARALLELISM));
checkArgument(numPartitions > 0, "Number of partitions must be greater than zero.");
if (Strings.isNullOrEmpty(options.get(PIPELINE_OPTIONS))) {
throw new RuntimeException("Beam pipelineOptions were not set in DataSource options");
}
this.serializablePipelineOptions =
new SerializablePipelineOptions(options.get(PIPELINE_OPTIONS));
}
@Override
public InputPartition[] planInputPartitions() {
PipelineOptions options = serializablePipelineOptions.get();
long desiredSizeBytes;
try {
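          // One split per Spark partition: ask the source to split itself into bundles of roughly
          // (estimated total size / default parallelism) bytes. The source may legitimately return
          // more or fewer splits than requested.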
desiredSizeBytes = source.getEstimatedSizeBytes(options) / numPartitions;
List<? extends BoundedSource<T>> splits = source.split(desiredSizeBytes, options);
InputPartition[] result = new InputPartition[splits.size()];
int i = 0;
for (BoundedSource<T> split : splits) {
result[i++] = new BeamInputPartition<>(split);
}
return result;
} catch (Exception e) {
throw new RuntimeException(
"Error in splitting BoundedSource " + source.getClass().getCanonicalName(), e);
}
}
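      /**
       * The returned factory is serialized to executors; each {@link InputPartition} gets its own
       * {@link BeamPartitionReader}.
       */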
@Override
public PartitionReaderFactory createReaderFactory() {
return new PartitionReaderFactory() {
@Override
public PartitionReader<InternalRow> createReader(InputPartition partition) {
return new BeamPartitionReader<T>(
((BeamInputPartition<T>) partition).getSource(), serializablePipelineOptions);
}
};
}
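      /**
       * Wraps one {@link BoundedSource} split for shipping to an executor. {@link InputPartition}
       * extends {@link Serializable}, and Beam sources are serializable by contract.
       */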
private static class BeamInputPartition<T> implements InputPartition {
private final BoundedSource<T> source;
private BeamInputPartition(BoundedSource<T> source) {
this.source = source;
}
public BoundedSource<T> getSource() {
return source;
}
}
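      /**
       * Reads one source split on an executor, emitting each element as a one-column binary
       * {@link InternalRow} holding the coder-encoded {@link WindowedValue}.
       */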
private static class BeamPartitionReader<T> implements PartitionReader<InternalRow> {
private final BoundedSource<T> source;
private final BoundedSource.BoundedReader<T> reader;
private boolean started;
private boolean closed;
BeamPartitionReader(
BoundedSource<T> source, SerializablePipelineOptions serializablePipelineOptions) {
this.started = false;
this.closed = false;
this.source = source;
          // BoundedReader is not serializable, so it is created here on the executor rather than
          // shipped from the driver.
try {
reader =
source.createReader(serializablePipelineOptions.get().as(PipelineOptions.class));
} catch (IOException e) {
throw new RuntimeException("Error creating BoundedReader ", e);
}
}
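        /**
         * Beam readers distinguish {@link BoundedSource.BoundedReader#start()} (move to the first
         * record) from {@code advance()}; Spark exposes a single {@code next()}, so the first call
         * maps to {@code start()} and every subsequent call to {@code advance()}.
         */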
@Override
public boolean next() throws IOException {
if (!started) {
started = true;
return reader.start();
} else {
return !closed && reader.advance();
}
}
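        /**
         * Wraps the current element and its timestamp in the global window and stores the encoded
         * {@link WindowedValue} in the binary row.
         */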
@Override
public InternalRow get() {
WindowedValue<T> windowedValue =
WindowedValue.timestampedValueInGlobalWindow(
reader.getCurrent(), reader.getCurrentTimestamp());
return RowHelpers.storeWindowedValueInRow(windowedValue, source.getOutputCoder());
}
@Override
public void close() throws IOException {
closed = true;
reader.close();
}
}
}
}
}