blob: e65d90386e215d1462f560c489f4c57ace15fc87 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.extensions.sql.meta.provider.test;
import static org.apache.beam.vendor.calcite.v1_20_0.com.google.common.base.Preconditions.checkArgument;
import com.google.auto.service.AutoService;
import java.io.Serializable;
import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.concurrent.atomic.AtomicLong;
import java.util.stream.Collectors;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.extensions.sql.impl.BeamTableStatistics;
import org.apache.beam.sdk.extensions.sql.meta.BaseBeamTable;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTable;
import org.apache.beam.sdk.extensions.sql.meta.BeamSqlTableFilter;
import org.apache.beam.sdk.extensions.sql.meta.DefaultTableFilter;
import org.apache.beam.sdk.extensions.sql.meta.Table;
import org.apache.beam.sdk.extensions.sql.meta.provider.InMemoryMetaTableProvider;
import org.apache.beam.sdk.extensions.sql.meta.provider.TableProvider;
import org.apache.beam.sdk.options.PipelineOptions;
import org.apache.beam.sdk.schemas.Schema;
import org.apache.beam.sdk.schemas.SchemaCoder;
import org.apache.beam.sdk.schemas.transforms.Select;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.sdk.values.POutput;
import org.apache.beam.sdk.values.Row;
/**
* Test in-memory table provider for use in tests.
*
* <p>Keeps global state and tracks class instances. Works only in DirectRunner.
*/
@AutoService(TableProvider.class)
public class TestTableProvider extends InMemoryMetaTableProvider {
static final Map<Long, Map<String, TableWithRows>> GLOBAL_TABLES = new ConcurrentHashMap<>();
private static final AtomicLong INSTANCES = new AtomicLong(0);
private final long instanceId = INSTANCES.getAndIncrement();
public TestTableProvider() {
GLOBAL_TABLES.put(instanceId, new ConcurrentHashMap<>());
}
@Override
public String getTableType() {
return "test";
}
public Map<String, TableWithRows> tables() {
return GLOBAL_TABLES.get(instanceId);
}
@Override
public void createTable(Table table) {
tables().put(table.getName(), new TableWithRows(instanceId, table));
}
@Override
public void dropTable(String tableName) {
tables().remove(tableName);
}
@Override
public Map<String, Table> getTables() {
return tables().entrySet().stream()
.collect(Collectors.toMap(Map.Entry::getKey, e -> e.getValue().table));
}
@Override
public synchronized BeamSqlTable buildBeamSqlTable(Table table) {
return new InMemoryTable(tables().get(table.getName()));
}
public void addRows(String tableName, Row... rows) {
checkArgument(tables().containsKey(tableName), "Table not found: " + tableName);
tables().get(tableName).rows.addAll(Arrays.asList(rows));
}
public List<Row> tableRows(String tableName) {
return tables().get(tableName).rows;
}
/** TableWitRows. */
public static class TableWithRows implements Serializable {
private Table table;
private List<Row> rows;
private long tableProviderInstanceId;
public TableWithRows(long tableProviderInstanceId, Table table) {
this.tableProviderInstanceId = tableProviderInstanceId;
this.table = table;
this.rows = new CopyOnWriteArrayList<>();
}
public List<Row> getRows() {
return rows;
}
}
private static class InMemoryTable extends BaseBeamTable {
private TableWithRows tableWithRows;
@Override
public PCollection.IsBounded isBounded() {
return PCollection.IsBounded.BOUNDED;
}
public InMemoryTable(TableWithRows tableWithRows) {
this.tableWithRows = tableWithRows;
}
public Coder<Row> rowCoder() {
return SchemaCoder.of(tableWithRows.table.getSchema());
}
@Override
public BeamTableStatistics getTableStatistics(PipelineOptions options) {
return BeamTableStatistics.createBoundedTableStatistics(
(double) tableWithRows.getRows().size());
}
@Override
public PCollection<Row> buildIOReader(PBegin begin) {
TableWithRows tableWithRows =
GLOBAL_TABLES
.get(this.tableWithRows.tableProviderInstanceId)
.get(this.tableWithRows.table.getName());
return begin.apply(Create.of(tableWithRows.rows).withCoder(rowCoder()));
}
@Override
public PCollection<Row> buildIOReader(
PBegin begin, BeamSqlTableFilter filters, List<String> fieldNames) {
PCollection<Row> withAllFields = buildIOReader(begin);
if (fieldNames.isEmpty() && filters instanceof DefaultTableFilter) {
return withAllFields;
}
PCollection<Row> result = withAllFields;
if (!(filters instanceof DefaultTableFilter)) {
throw new RuntimeException("Unimplemented at the moment.");
}
if (!fieldNames.isEmpty()) {
result = result.apply(Select.fieldNames(fieldNames.toArray(new String[0])));
}
return result;
}
@Override
public POutput buildIOWriter(PCollection<Row> input) {
input.apply(ParDo.of(new CollectorFn(tableWithRows)));
return PDone.in(input.getPipeline());
}
@Override
public boolean supportsProjects() {
return true;
}
@Override
public Schema getSchema() {
return tableWithRows.table.getSchema();
}
}
private static final class CollectorFn extends DoFn<Row, Row> {
private TableWithRows tableWithRows;
CollectorFn(TableWithRows tableWithRows) {
this.tableWithRows = tableWithRows;
}
@ProcessElement
public void procesElement(ProcessContext context) {
long instanceId = tableWithRows.tableProviderInstanceId;
String tableName = tableWithRows.table.getName();
GLOBAL_TABLES.get(instanceId).get(tableName).rows.add(context.element());
context.output(context.element());
}
}
}