/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.jdbc;
import static org.apache.beam.sdk.io.common.DatabaseTestHelper.assertRowCount;
import static org.apache.beam.sdk.io.common.DatabaseTestHelper.getTestDataToWrite;
import static org.apache.beam.sdk.io.common.IOITHelper.readIOTestPipelineOptions;
import static org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.base.MoreObjects.firstNonNull;
import static org.junit.Assert.assertNotEquals;
import com.google.cloud.Timestamp;
import java.sql.SQLException;
import java.util.HashSet;
import java.util.List;
import java.util.Optional;
import java.util.Set;
import java.util.UUID;
import java.util.function.Function;
import org.apache.beam.sdk.PipelineResult;
import org.apache.beam.sdk.coders.VarIntCoder;
import org.apache.beam.sdk.io.GenerateSequence;
import org.apache.beam.sdk.io.common.DatabaseTestHelper;
import org.apache.beam.sdk.io.common.HashingFn;
import org.apache.beam.sdk.io.common.PostgresIOTestPipelineOptions;
import org.apache.beam.sdk.io.common.TestRow;
import org.apache.beam.sdk.state.StateSpec;
import org.apache.beam.sdk.state.StateSpecs;
import org.apache.beam.sdk.state.ValueState;
import org.apache.beam.sdk.testing.PAssert;
import org.apache.beam.sdk.testing.TestPipeline;
import org.apache.beam.sdk.testutils.NamedTestResult;
import org.apache.beam.sdk.testutils.metrics.IOITMetrics;
import org.apache.beam.sdk.testutils.metrics.MetricsReader;
import org.apache.beam.sdk.testutils.metrics.TimeMonitor;
import org.apache.beam.sdk.testutils.publishing.InfluxDBSettings;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Count;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.FlatMapElements;
import org.apache.beam.sdk.transforms.Impulse;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.PeriodicSequence;
import org.apache.beam.sdk.transforms.PeriodicSequence.SequenceDefinition;
import org.apache.beam.sdk.transforms.Top;
import org.apache.beam.sdk.values.KV;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.TypeDescriptors;
import org.apache.beam.vendor.guava.v32_1_2_jre.com.google.common.collect.Lists;
import org.joda.time.Duration;
import org.joda.time.Instant;
import org.junit.BeforeClass;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runner.RunWith;
import org.junit.runners.JUnit4;
import org.postgresql.ds.PGSimpleDataSource;
/**
* A test of {@link org.apache.beam.sdk.io.jdbc.JdbcIO} on an independent Postgres instance.
*
* <p>This test requires a running instance of Postgres. Pass in connection information using
* PipelineOptions:
*
* <pre>
* ./gradlew integrationTest -p sdks/java/io/jdbc -DintegrationTestPipelineOptions='[
* "--postgresServerName=1.2.3.4",
* "--postgresUsername=postgres",
* "--postgresDatabaseName=myfancydb",
* "--postgresPassword=mypass",
* "--postgresSsl=false",
* "--numberOfRecords=1000" ]'
* --tests org.apache.beam.sdk.io.jdbc.JdbcIOIT
* -DintegrationTestRunner=direct
* </pre>
*
* <p>Please see 'build_rules.gradle' file for instructions regarding running this test using Beam
* performance testing framework.
*/
@RunWith(JUnit4.class)
public class JdbcIOIT {
private static final String NAMESPACE = JdbcIOIT.class.getName();
// The number of rows written to the table in the performance test.
private static int numberOfRows;
private static PGSimpleDataSource dataSource;
private static String tableName;
private static InfluxDBSettings settings;
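// Separate pipelines for the write and read phases, so the write finishes before the read starts.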
@Rule public TestPipeline pipelineWrite = TestPipeline.create();
@Rule public TestPipeline pipelineRead = TestPipeline.create();
@BeforeClass
public static void setup() {
PostgresIOTestPipelineOptions options;
try {
options = readIOTestPipelineOptions(PostgresIOTestPipelineOptions.class);
} catch (IllegalArgumentException e) {
options = null;
}
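// Skip (rather than fail) the test when connection options are not supplied.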
org.junit.Assume.assumeNotNull(options);
numberOfRows = options.getNumberOfRecords();
dataSource = DatabaseTestHelper.getPostgresDataSource(options);
tableName = DatabaseTestHelper.getTestTableName("IT");
settings =
InfluxDBSettings.builder()
.withHost(options.getInfluxHost())
.withDatabase(options.getInfluxDatabase())
.withMeasurement(options.getInfluxMeasurement())
.get();
}
/**
* Tests writing and then reading data from a Postgres database. Also used as a performance test
* of JdbcIO.
*/
@Test
public void testWriteThenRead() throws SQLException {
DatabaseTestHelper.createTable(dataSource, tableName);
try {
PipelineResult writeResult = runWrite();
PipelineResult.State writeState = writeResult.waitUntilFinish();
PipelineResult readResult = runRead(tableName);
PipelineResult.State readState = readResult.waitUntilFinish();
gatherAndPublishMetrics(writeResult, readResult);
// Fail the test if either pipeline failed.
assertNotEquals(PipelineResult.State.FAILED, writeState);
assertNotEquals(PipelineResult.State.FAILED, readState);
} finally {
DatabaseTestHelper.deleteTable(dataSource, tableName);
}
}
private void gatherAndPublishMetrics(PipelineResult writeResult, PipelineResult readResult) {
String uuid = UUID.randomUUID().toString();
String timestamp = Timestamp.now().toString();
Set<Function<MetricsReader, NamedTestResult>> metricSuppliers =
getWriteMetricSuppliers(uuid, timestamp);
IOITMetrics writeMetrics =
new IOITMetrics(metricSuppliers, writeResult, NAMESPACE, uuid, timestamp);
writeMetrics.publishToInflux(settings);
IOITMetrics readMetrics =
new IOITMetrics(
getReadMetricSuppliers(uuid, timestamp), readResult, NAMESPACE, uuid, timestamp);
readMetrics.publishToInflux(settings);
}
private Set<Function<MetricsReader, NamedTestResult>> getWriteMetricSuppliers(
String uuid, String timestamp) {
Set<Function<MetricsReader, NamedTestResult>> suppliers = new HashSet<>();
Optional<Long> postgresTableSize =
DatabaseTestHelper.getPostgresTableSize(dataSource, tableName);
suppliers.add(
reader -> {
long writeStart = reader.getStartTimeMetric("write_time");
long writeEnd = reader.getEndTimeMetric("write_time");
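// The recorded timestamps are in milliseconds; divide by 1e3 to report seconds.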
return NamedTestResult.create(
uuid, timestamp, "write_time", (writeEnd - writeStart) / 1e3);
});
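// Report the final table size when it can be queried from Postgres.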
postgresTableSize.ifPresent(
tableFinalSize ->
suppliers.add(
ignore -> NamedTestResult.create(uuid, timestamp, "total_size", tableFinalSize)));
return suppliers;
}
private Set<Function<MetricsReader, NamedTestResult>> getReadMetricSuppliers(
String uuid, String timestamp) {
Set<Function<MetricsReader, NamedTestResult>> suppliers = new HashSet<>();
suppliers.add(
reader -> {
long readStart = reader.getStartTimeMetric("read_time");
long readEnd = reader.getEndTimeMetric("read_time");
return NamedTestResult.create(uuid, timestamp, "read_time", (readEnd - readStart) / 1e3);
});
return suppliers;
}
/**
* Writes the test dataset to Postgres.
*
* <p>This method does not attempt to validate the data - that happens in the read test. This
* makes it harder to tell whether a test failed in the write or read phase, but the tests are
* much easier to maintain (no separate code is needed to seed the database for the read tests).
*/
private PipelineResult runWrite() {
pipelineWrite
.apply(GenerateSequence.from(0).to(numberOfRows))
.apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()))
.apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "write_time")))
.apply(
JdbcIO.<TestRow>write()
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
.withStatement(String.format("insert into %s values(?, ?)", tableName))
.withPreparedStatementSetter(new JdbcTestHelper.PrepareStatementFromTestRow()));
return pipelineWrite.run();
}
/**
* Reads the test dataset from Postgres and validates its contents.
*
* <p>When doing the validation, we wish to: 1. ensure that *all* of the rows are correct, and 2.
* provide enough information in assertion failures to make obvious errors easy to spot (e.g. all
* elements share the same mistake, or "only 5 elements were generated" and the user wants to see
* what the problem was).
*
* <p>We do not wish to generate and compare all of the expected values, so this method uses
* hashing to ensure that all expected data is present. However, hashing does not provide easy
* debugging information (failures like "every element was the empty string" are hard to see), so
* we also: 1. generate expected values for the first and last 500 rows, and 2. use
* containsInAnyOrder to verify that those values are correct. "First/last 500 rows" is well
* defined because every row has a unique id, so we can use the natural ordering of that key.
*/
private PipelineResult runRead(String tableName) {
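// Fall back to the shared test table when no table name is given.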
if (tableName == null) {
tableName = JdbcIOIT.tableName;
}
PCollection<TestRow> namesAndIds =
pipelineRead
.apply(
JdbcIO.<TestRow>read()
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
.withQuery(String.format("select name,id from %s;", tableName))
.withRowMapper(new JdbcTestHelper.CreateTestRowOfNameAndId()))
.apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "read_time")));
PAssert.thatSingleton(namesAndIds.apply("Count All", Count.globally()))
.isEqualTo((long) numberOfRows);
PCollection<String> consolidatedHashcode =
namesAndIds
.apply(ParDo.of(new TestRow.SelectNameFn()))
.apply("Hash row contents", Combine.globally(new HashingFn()).withoutDefaults());
PAssert.that(consolidatedHashcode)
.containsInAnyOrder(TestRow.getExpectedHashForRowCount(numberOfRows));
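// Spot-check the first and last 500 rows by id to make failures easier to debug than a bare hash mismatch.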
PCollection<List<TestRow>> frontOfList = namesAndIds.apply(Top.smallest(500));
Iterable<TestRow> expectedFrontOfList = TestRow.getExpectedValues(0, 500);
PAssert.thatSingletonIterable(frontOfList).containsInAnyOrder(expectedFrontOfList);
PCollection<List<TestRow>> backOfList = namesAndIds.apply(Top.largest(500));
Iterable<TestRow> expectedBackOfList =
TestRow.getExpectedValues(numberOfRows - 500, numberOfRows);
PAssert.thatSingletonIterable(backOfList).containsInAnyOrder(expectedBackOfList);
return pipelineRead.run();
}
/** An integration test of the auto-sharding functionality using a generated stream of records. */
@Test
public void testWriteWithAutosharding() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("JDBCIT_AUTOSHARD");
DatabaseTestHelper.createTable(dataSource, firstTableName);
try {
PCollection<TestRow> dataCollection =
pipelineWrite.apply(
// Emit 50_000 elements per second.
new GenerateRecordsStream(numberOfRows, 50_000, Duration.standardSeconds(1)));
dataCollection
.apply(ParDo.of(new TimeMonitor<>(NAMESPACE, "write_time")))
.apply(
JdbcIO.<TestRow>write()
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
.withStatement(String.format("insert into %s values(?, ?)", firstTableName))
.withAutoSharding()
.withPreparedStatementSetter(new JdbcTestHelper.PrepareStatementFromTestRow()));
List<String> additionalArgs = Lists.newArrayList("--streaming");
if (pipelineWrite
.getOptions()
.getRunner()
.getCanonicalName()
.startsWith("org.apache.beam.runners.dataflow")) {
// enableStreamingEngine is a GCP-specific option.
additionalArgs.add("--enableStreamingEngine");
}
pipelineWrite.runWithAdditionalOptionArgs(additionalArgs).waitUntilFinish();
assertRowCount(dataSource, firstTableName, numberOfRows);
} finally {
DatabaseTestHelper.deleteTable(dataSource, firstTableName);
}
}
/** Generates a stream of records for testing. */
private static class GenerateRecordsStream extends PTransform<PBegin, PCollection<TestRow>> {
private final long numRecords;
private final long numPerPeriod;
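// Note: periodLength is currently unused; GenerateSequenceDefinitionFn hardcodes a one-second period.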
public GenerateRecordsStream(long numRecords, long numPerPeriod, Duration periodLength) {
this.numRecords = numRecords;
this.numPerPeriod = numPerPeriod;
}
@Override
public PCollection<TestRow> expand(PBegin pBegin) {
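// Impulse kicks off the pipeline, GenerateSequenceDefinitionFn schedules one firing per second,
// and the constant key routes every firing to a single stateful EmitSequenceFn instance, which
// emits the next batch of record ids on each firing.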
PCollection<TestRow> pcoll =
pBegin
.apply(Impulse.create())
.apply(ParDo.of(new GenerateSequenceDefinitionFn(numRecords / numPerPeriod)))
.apply(PeriodicSequence.create())
.apply(
"Add dumb key",
ParDo.of(
new DoFn<Instant, KV<Integer, Instant>>() {
@ProcessElement
public void processElement(ProcessContext c) {
c.output(KV.of(0, c.element()));
}
}))
.apply(ParDo.of(new EmitSequenceFn(numRecords, numPerPeriod)))
.apply(ParDo.of(new TestRow.DeterministicallyConstructTestRowFn()));
return pcoll;
}
}
/** Sets the PeriodicSequence start time when pipeline execution begins. */
private static class GenerateSequenceDefinitionFn extends DoFn<byte[], SequenceDefinition> {
private final long numPulses;
@ProcessElement
public void processElement(ProcessContext c) {
Instant now = Instant.now();
c.output(
new SequenceDefinition(
now, now.plus(Duration.standardSeconds(numPulses)), Duration.standardSeconds(1)));
}
public GenerateSequenceDefinitionFn(long numPulses) {
this.numPulses = numPulses;
}
}
private static class EmitSequenceFn extends DoFn<KV<Integer, Instant>, Long> {
private final long numRecords;
private final long numPerPeriod;
public EmitSequenceFn(long numRecords, long numPerPeriod) {
this.numRecords = numRecords;
this.numPerPeriod = numPerPeriod;
}
@StateId("count")
@SuppressWarnings("unused")
private final StateSpec<ValueState<Integer>> countSpec = StateSpecs.value(VarIntCoder.of());
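// Per-key counter tracking how many batches have been emitted so far.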
@ProcessElement
public void processElement(ProcessContext c, @StateId("count") ValueState<Integer> count) {
int current = firstNonNull(count.read(), 0);
count.write(current + 1);
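// Emit ids in [startId, endId); the final batch is clamped to numRecords.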
long startId = current * numPerPeriod;
long endId = Math.min((current + 1) * numPerPeriod, numRecords);
for (long id = startId; id < endId; ++id) {
c.output(id);
}
}
}
/** An integration test of the write-results functionality. */
@Test
public void testWriteWithWriteResults() throws Exception {
String firstTableName = DatabaseTestHelper.getTestTableName("JDBCIT_WRITE");
DatabaseTestHelper.createTable(dataSource, firstTableName);
try {
PCollection<KV<Integer, String>> dataCollection =
pipelineWrite
.apply(GenerateSequence.from(0).to(numberOfRows))
.apply(
FlatMapElements.into(
TypeDescriptors.kvs(
TypeDescriptors.integers(), TypeDescriptors.strings()))
.via(num -> getTestDataToWrite(1)));
PCollection<JdbcTestHelper.TestDto> resultSetCollection =
dataCollection.apply(
getJdbcWriteWithReturning(firstTableName)
.withWriteResults(
resultSet -> {
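// Map the first column of each returned row to a test DTO; emit an empty marker when no row comes back.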
if (resultSet != null && resultSet.next()) {
return new JdbcTestHelper.TestDto(resultSet.getInt(1));
}
return new JdbcTestHelper.TestDto(JdbcTestHelper.TestDto.EMPTY_RESULT);
}));
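// TestDto has no default coder, so set it explicitly.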
resultSetCollection.setCoder(JdbcTestHelper.TEST_DTO_CODER);
PAssert.that(resultSetCollection.apply(Count.globally()))
.containsInAnyOrder(Long.valueOf(numberOfRows));
pipelineWrite.run().waitUntilFinish();
assertRowCount(dataSource, firstTableName, numberOfRows);
} finally {
DatabaseTestHelper.deleteTable(dataSource, firstTableName);
}
}
/**
* @return a {@link JdbcIO.Write} transform that writes to the {@code tableName} Postgres table
*     and returns all fields of the modified rows.
*/
private static JdbcIO.Write<KV<Integer, String>> getJdbcWriteWithReturning(String tableName) {
return JdbcIO.<KV<Integer, String>>write()
.withDataSourceConfiguration(JdbcIO.DataSourceConfiguration.create(dataSource))
.withStatement(String.format("insert into %s values(?, ?) returning *", tableName))
.withPreparedStatementSetter(
(element, statement) -> {
statement.setInt(1, element.getKey());
statement.setString(2, element.getValue());
});
}
}