/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.beam.sdk.io.snowflake;
import static org.apache.beam.sdk.io.TextIO.readFiles;
import static org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Preconditions.checkArgument;
import com.google.auto.value.AutoValue;
import com.opencsv.CSVParser;
import com.opencsv.CSVParserBuilder;
import java.io.IOException;
import java.io.Serializable;
import java.security.PrivateKey;
import java.sql.Connection;
import java.sql.SQLException;
import java.text.SimpleDateFormat;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
import java.util.UUID;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.sql.DataSource;
import net.snowflake.client.jdbc.SnowflakeBasicDataSource;
import org.apache.beam.sdk.annotations.Experimental;
import org.apache.beam.sdk.coders.Coder;
import org.apache.beam.sdk.coders.ListCoder;
import org.apache.beam.sdk.coders.StringUtf8Coder;
import org.apache.beam.sdk.io.Compression;
import org.apache.beam.sdk.io.FileIO;
import org.apache.beam.sdk.io.FileSystems;
import org.apache.beam.sdk.io.TextIO;
import org.apache.beam.sdk.io.WriteFilesResult;
import org.apache.beam.sdk.io.fs.MoveOptions;
import org.apache.beam.sdk.io.fs.ResourceId;
import org.apache.beam.sdk.io.snowflake.credentials.KeyPairSnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.OAuthTokenSnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.credentials.UsernamePasswordSnowflakeCredentials;
import org.apache.beam.sdk.io.snowflake.data.SnowflakeTableSchema;
import org.apache.beam.sdk.io.snowflake.enums.CreateDisposition;
import org.apache.beam.sdk.io.snowflake.enums.WriteDisposition;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeService;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceConfig;
import org.apache.beam.sdk.io.snowflake.services.SnowflakeServiceImpl;
import org.apache.beam.sdk.transforms.Combine;
import org.apache.beam.sdk.transforms.Create;
import org.apache.beam.sdk.transforms.DoFn;
import org.apache.beam.sdk.transforms.MapElements;
import org.apache.beam.sdk.transforms.PTransform;
import org.apache.beam.sdk.transforms.ParDo;
import org.apache.beam.sdk.transforms.Reify;
import org.apache.beam.sdk.transforms.Reshuffle;
import org.apache.beam.sdk.transforms.SerializableFunction;
import org.apache.beam.sdk.transforms.SimpleFunction;
import org.apache.beam.sdk.transforms.Values;
import org.apache.beam.sdk.transforms.View;
import org.apache.beam.sdk.transforms.Wait;
import org.apache.beam.sdk.transforms.display.DisplayData;
import org.apache.beam.sdk.transforms.display.HasDisplayData;
import org.apache.beam.sdk.values.PBegin;
import org.apache.beam.sdk.values.PCollection;
import org.apache.beam.sdk.values.PCollectionView;
import org.apache.beam.sdk.values.PDone;
import org.apache.beam.vendor.guava.v26_0_jre.com.google.common.base.Joiner;
import org.checkerframework.checker.nullness.qual.Nullable;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
/**
* IO to read and write data on Snowflake.
*
* <p>SnowflakeIO uses <a href="https://docs.snowflake.net/manuals/user-guide/jdbc.html">Snowflake
* JDBC</a> driver under the hood, but data isn't read/written using JDBC directly. Instead,
 * SnowflakeIO uses dedicated <b>COPY</b> operations to read/write data from/to a cloud bucket.
 * Currently, only Google Cloud Storage is supported.
*
* <p>To configure SnowflakeIO to read/write from your Snowflake instance, you have to provide a
* {@link DataSourceConfiguration} using {@link
 * DataSourceConfiguration#create(SnowflakeCredentials)}, where {@link SnowflakeCredentials} might
 * be created using {@link org.apache.beam.sdk.io.snowflake.credentials.SnowflakeCredentialsFactory}.
 * Additionally, one of {@link DataSourceConfiguration#withServerName(String)} or {@link
 * DataSourceConfiguration#withUrl(String)} must be used to tell SnowflakeIO which instance to use.
 * <br>
 * There are also other options available to configure the connection to Snowflake:
*
* <ul>
* <li>{@link DataSourceConfiguration#withWarehouse(String)} to specify which Warehouse to use
* <li>{@link DataSourceConfiguration#withDatabase(String)} to specify which Database to connect
* to
* <li>{@link DataSourceConfiguration#withSchema(String)} to specify which schema to use
* <li>{@link DataSourceConfiguration#withRole(String)} to specify which role to use
* <li>{@link DataSourceConfiguration#withLoginTimeout(Integer)} to specify the timeout for the
* login
 * <li>{@link DataSourceConfiguration#withPortNumber(Integer)} to specify a custom port for the
 * Snowflake instance
* </ul>
*
* <p>For example:
*
* <pre>{@code
* SnowflakeIO.DataSourceConfiguration dataSourceConfiguration =
* SnowflakeIO.DataSourceConfiguration.create(SnowflakeCredentialsFactory.of(options))
* .withServerName(options.getServerName())
* .withWarehouse(options.getWarehouse())
* .withDatabase(options.getDatabase())
* .withSchema(options.getSchema());
* }</pre>
*
* <h3>Reading from Snowflake</h3>
*
* <p>SnowflakeIO.Read returns a bounded collection of {@code T} as a {@code PCollection<T>}. T is
* the type returned by the provided {@link CsvMapper}.
*
* <p>For example
*
* <pre>{@code
* PCollection<GenericRecord> items = pipeline.apply(
* SnowflakeIO.<GenericRecord>read()
* .withDataSourceConfiguration(dataSourceConfiguration)
* .fromQuery(QUERY)
* .withStagingBucketName(...)
* .withStorageIntegrationName(...)
* .withCsvMapper(...)
* .withCoder(...));
* }</pre>
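 *
 * <p>A {@link CsvMapper} turns the parts of a parsed CSV line into an instance of {@code T}. As an
 * illustrative sketch only (the Avro schema {@code AVRO_SCHEMA} and the column names are assumed,
 * not provided by this module), a mapper producing {@code GenericRecord}s might look like:
 *
 * <pre>{@code
 * SnowflakeIO.CsvMapper<GenericRecord> csvMapper =
 *     parts ->
 *         new GenericRecordBuilder(AVRO_SCHEMA)
 *             .set("first_column", parts[0])
 *             .set("second_column", parts[1])
 *             .build();
 * }</pre>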
*
 * <p><b>Important:</b> When reading data from Snowflake, temporary CSV files are created in the
 * specified stagingBucketName, in a directory named {@code sf_copy_csv_[TIMESTAMP]_[RANDOM CHARS]}.
 * This directory and all its files are cleaned up automatically by default, but if the pipeline
 * fails they may remain and will have to be cleaned up manually.
*
* <h3>Writing to Snowflake</h3>
*
 * <p>SnowflakeIO.Write supports writing records into a database. It writes a {@code PCollection<T>}
 * to the database by converting each {@code T} into an {@code Object[]} via a user-provided {@link
* UserDataMapper}.
*
* <p>For example
*
* <pre>{@code
* items.apply(
* SnowflakeIO.<KV<Integer, String>>write()
* .withDataSourceConfiguration(dataSourceConfiguration)
* .withStagingBucketName(...)
* .withStorageIntegrationName(...)
 * .withUserDataMapper(mapper)
 * .to(table));
* }</pre>
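 *
 * <p>A {@link UserDataMapper} flattens each element into an {@code Object[]} whose values become
 * the columns of a CSV line. A minimal sketch matching the {@code KV<Integer, String>} elements
 * above could be:
 *
 * <pre>{@code
 * SnowflakeIO.UserDataMapper<KV<Integer, String>> mapper =
 *     element -> new Object[] {element.getKey(), element.getValue()};
 * }</pre>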
*
 * <p><b>Important:</b> When writing data to Snowflake, the data is first saved as CSV files in the
 * specified stagingBucketName, in a directory named {@code data}, and only then loaded into
 * Snowflake.
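 *
 * <p>Optionally, the write may also be configured with dispositions, e.g. {@code
 * .withCreateDisposition(CreateDisposition.CREATE_IF_NEEDED)} and {@code
 * .withWriteDisposition(WriteDisposition.APPEND)}, which are also the defaults applied by {@link
 * #write()}.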
*/
@Experimental
public class SnowflakeIO {
private static final Logger LOG = LoggerFactory.getLogger(SnowflakeIO.class);
private static final String CSV_QUOTE_CHAR = "'";
private static final String WRITE_TMP_PATH = "data";
/**
* Read data from Snowflake.
*
* @param snowflakeService user-defined {@link SnowflakeService}
* @param <T> Type of the data to be read.
*/
public static <T> Read<T> read(SnowflakeService snowflakeService) {
return new AutoValue_SnowflakeIO_Read.Builder<T>()
.setSnowflakeService(snowflakeService)
.build();
}
/**
* Read data from Snowflake.
*
* @param <T> Type of the data to be read.
*/
public static <T> Read<T> read() {
return read(new SnowflakeServiceImpl());
}
/**
 * Interface for a user-defined function mapping the parts of a CSV line into {@code T}. Used for
* SnowflakeIO.Read.
*
* @param <T> Type of data to be read.
*/
@FunctionalInterface
public interface CsvMapper<T> extends Serializable {
T mapRow(String[] parts) throws Exception;
}
/**
 * Interface for a user-defined function mapping {@code T} into an array of Objects. Used for
* SnowflakeIO.Write.
*
* @param <T> Type of data to be written.
*/
@FunctionalInterface
public interface UserDataMapper<T> extends Serializable {
Object[] mapRow(T element);
}
/**
* Write data to Snowflake via COPY statement.
*
* @param <T> Type of data to be written.
*/
public static <T> Write<T> write() {
return new AutoValue_SnowflakeIO_Write.Builder<T>()
.setFileNameTemplate("output")
.setCreateDisposition(CreateDisposition.CREATE_IF_NEEDED)
.setWriteDisposition(WriteDisposition.APPEND)
.build();
}
/** Implementation of {@link #read()}. */
@AutoValue
public abstract static class Read<T> extends PTransform<PBegin, PCollection<T>> {
abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
abstract @Nullable String getQuery();
abstract @Nullable String getTable();
abstract @Nullable String getStorageIntegrationName();
abstract @Nullable String getStagingBucketName();
abstract @Nullable CsvMapper<T> getCsvMapper();
abstract @Nullable Coder<T> getCoder();
abstract @Nullable SnowflakeService getSnowflakeService();
abstract Builder<T> toBuilder();
@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);
abstract Builder<T> setQuery(String query);
abstract Builder<T> setTable(String table);
abstract Builder<T> setStorageIntegrationName(String storageIntegrationName);
abstract Builder<T> setStagingBucketName(String stagingBucketName);
abstract Builder<T> setCsvMapper(CsvMapper<T> csvMapper);
abstract Builder<T> setCoder(Coder<T> coder);
abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
abstract Read<T> build();
}
/**
 * Sets information about the Snowflake server.
*
* @param config - An instance of {@link DataSourceConfiguration}.
*/
public Read<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
}
/**
 * Sets a function that will provide the {@link DataSource} at runtime.
*
* @param dataSourceProviderFn a {@link SerializableFunction}.
*/
public Read<T> withDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn) {
return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
}
/**
* A query to be executed in Snowflake.
*
* @param query - String with query.
*/
public Read<T> fromQuery(String query) {
return toBuilder().setQuery(query).build();
}
/**
* A table name to be read in Snowflake.
*
* @param table - String with the name of the table.
*/
public Read<T> fromTable(String table) {
return toBuilder().setTable(table).build();
}
/**
 * Name of the cloud bucket (currently only GCS) to use as a temporary location for CSV files
 * during the COPY statement.
*
* @param stagingBucketName - String with the name of the bucket.
*/
public Read<T> withStagingBucketName(String stagingBucketName) {
return toBuilder().setStagingBucketName(stagingBucketName).build();
}
/**
* Name of the Storage Integration in Snowflake to be used. See
* https://docs.snowflake.com/en/sql-reference/sql/create-storage-integration.html for
* reference.
*
* @param integrationName - String with the name of the Storage Integration.
*/
public Read<T> withStorageIntegrationName(String integrationName) {
return toBuilder().setStorageIntegrationName(integrationName).build();
}
/**
* User-defined function mapping CSV lines into user data.
*
* @param csvMapper - an instance of {@link CsvMapper}.
*/
public Read<T> withCsvMapper(CsvMapper<T> csvMapper) {
return toBuilder().setCsvMapper(csvMapper).build();
}
/**
* A Coder to be used by the output PCollection generated by the source.
*
* @param coder - an instance of {@link Coder}.
*/
public Read<T> withCoder(Coder<T> coder) {
return toBuilder().setCoder(coder).build();
}
@Override
public PCollection<T> expand(PBegin input) {
checkArguments();
String tmpDirName = makeTmpDirName();
String stagingBucketDir = String.format("%s/%s/", getStagingBucketName(), tmpDirName);
PCollection<Void> emptyCollection = input.apply(Create.of((Void) null));
PCollection<T> output =
emptyCollection
.apply(
ParDo.of(
new CopyIntoStageFn(
getDataSourceProviderFn(),
getQuery(),
getTable(),
getStorageIntegrationName(),
stagingBucketDir,
getSnowflakeService())))
.apply(Reshuffle.viaRandomKey())
.apply(FileIO.matchAll())
.apply(FileIO.readMatches())
.apply(readFiles())
.apply(ParDo.of(new MapCsvToStringArrayFn()))
.apply(ParDo.of(new MapStringArrayToUserDataFn<>(getCsvMapper())));
output.setCoder(getCoder());
emptyCollection
.apply(Wait.on(output))
.apply(ParDo.of(new CleanTmpFilesFromGcsFn(stagingBucketDir)));
return output;
}
private void checkArguments() {
// Either table or query is required. If a query is provided it is used; otherwise the table
// is used.
checkArgument(getStorageIntegrationName() != null, "withStorageIntegrationName is required");
checkArgument(getStagingBucketName() != null, "withStagingBucketName is required");
checkArgument(
getQuery() != null || getTable() != null, "fromTable() or fromQuery() is required");
checkArgument(
!(getQuery() != null && getTable() != null),
"fromTable() and fromQuery() are not allowed together");
checkArgument(getCsvMapper() != null, "withCsvMapper() is required");
checkArgument(getCoder() != null, "withCoder() is required");
checkArgument(
(getDataSourceProviderFn() != null),
"withDataSourceConfiguration() or withDataSourceProviderFn() is required");
}
private String makeTmpDirName() {
return String.format(
"sf_copy_csv_%s_%s",
new SimpleDateFormat("yyyyMMdd_HHmmss").format(new Date()),
UUID.randomUUID().toString().subSequence(0, 8) // first 8 chars of UUID should be enough
);
}
private static class CopyIntoStageFn extends DoFn<Object, String> {
private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
private final String query;
private final String table;
private final String storageIntegrationName;
private final String stagingBucketDir;
private final SnowflakeService snowflakeService;
private CopyIntoStageFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn,
String query,
String table,
String storageIntegrationName,
String stagingBucketDir,
SnowflakeService snowflakeService) {
this.dataSourceProviderFn = dataSourceProviderFn;
this.query = query;
this.table = table;
this.storageIntegrationName = storageIntegrationName;
this.stagingBucketDir =
String.format(
"%s/run_%s/", stagingBucketDir, UUID.randomUUID().toString().subSequence(0, 8));
this.snowflakeService = snowflakeService;
}
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
SnowflakeServiceConfig config =
new SnowflakeServiceConfig(
dataSourceProviderFn, table, query, storageIntegrationName, stagingBucketDir);
String output = snowflakeService.read(config);
context.output(output);
}
}
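/**
 * Parses a single CSV line from the staged files into an array of fields, using {@code '} as the
 * quote character.
 */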
public static class MapCsvToStringArrayFn extends DoFn<String, String[]> {
@ProcessElement
public void processElement(ProcessContext c) throws IOException {
String csvLine = c.element();
CSVParser parser = new CSVParserBuilder().withQuoteChar(CSV_QUOTE_CHAR.charAt(0)).build();
String[] parts = parser.parseLine(csvLine);
c.output(parts);
}
}
private static class MapStringArrayToUserDataFn<T> extends DoFn<String[], T> {
private final CsvMapper<T> csvMapper;
public MapStringArrayToUserDataFn(CsvMapper<T> csvMapper) {
this.csvMapper = csvMapper;
}
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
context.output(csvMapper.mapRow(context.element()));
}
}
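/**
 * Deletes the temporary CSV files staged under {@code stagingBucketDir}; in {@link Read} this runs
 * only after the output has been produced, by waiting on it via {@code Wait.on}.
 */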
public static class CleanTmpFilesFromGcsFn extends DoFn<Object, Object> {
private final String stagingBucketDir;
public CleanTmpFilesFromGcsFn(String stagingBucketDir) {
this.stagingBucketDir = stagingBucketDir;
}
@ProcessElement
public void processElement(ProcessContext c) throws IOException {
String combinedPath = stagingBucketDir + "/**";
List<ResourceId> paths =
FileSystems.match(combinedPath).metadata().stream()
.map(metadata -> metadata.resourceId())
.collect(Collectors.toList());
FileSystems.delete(paths, MoveOptions.StandardMoveOptions.IGNORE_MISSING_FILES);
}
}
@Override
public void populateDisplayData(DisplayData.Builder builder) {
super.populateDisplayData(builder);
if (getQuery() != null) {
builder.add(DisplayData.item("query", getQuery()));
}
if (getTable() != null) {
builder.add(DisplayData.item("table", getTable()));
}
builder.add(DisplayData.item("storageIntegrationName", getStagingBucketName()));
builder.add(DisplayData.item("stagingBucketName", getStagingBucketName()));
builder.add(DisplayData.item("csvMapper", getCsvMapper().getClass().getName()));
builder.add(DisplayData.item("coder", getCoder().getClass().getName()));
if (getDataSourceProviderFn() instanceof HasDisplayData) {
((HasDisplayData) getDataSourceProviderFn()).populateDisplayData(builder);
}
}
}
/** Implementation of {@link #write()}. */
@AutoValue
public abstract static class Write<T> extends PTransform<PCollection<T>, PDone> {
abstract @Nullable SerializableFunction<Void, DataSource> getDataSourceProviderFn();
abstract @Nullable String getTable();
abstract @Nullable String getStorageIntegrationName();
abstract @Nullable String getStagingBucketName();
abstract @Nullable String getQuery();
abstract @Nullable String getFileNameTemplate();
abstract @Nullable WriteDisposition getWriteDisposition();
abstract @Nullable CreateDisposition getCreateDisposition();
abstract @Nullable SnowflakeTableSchema getTableSchema();
abstract @Nullable UserDataMapper getUserDataMapper();
abstract @Nullable SnowflakeService getSnowflakeService();
abstract Builder<T> toBuilder();
@AutoValue.Builder
abstract static class Builder<T> {
abstract Builder<T> setDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn);
abstract Builder<T> setTable(String table);
abstract Builder<T> setStorageIntegrationName(String storageIntegrationName);
abstract Builder<T> setStagingBucketName(String stagingBucketName);
abstract Builder<T> setQuery(String query);
abstract Builder<T> setFileNameTemplate(String fileNameTemplate);
abstract Builder<T> setUserDataMapper(UserDataMapper userDataMapper);
abstract Builder<T> setWriteDisposition(WriteDisposition writeDisposition);
abstract Builder<T> setCreateDisposition(CreateDisposition createDisposition);
abstract Builder<T> setTableSchema(SnowflakeTableSchema tableSchema);
abstract Builder<T> setSnowflakeService(SnowflakeService snowflakeService);
abstract Write<T> build();
}
/**
 * Sets information about the Snowflake server.
*
* @param config - An instance of {@link DataSourceConfiguration}.
*/
public Write<T> withDataSourceConfiguration(final DataSourceConfiguration config) {
return withDataSourceProviderFn(new DataSourceProviderFromDataSourceConfiguration(config));
}
/**
 * Sets a function that will provide the {@link DataSource} at runtime.
*
* @param dataSourceProviderFn a {@link SerializableFunction}.
*/
public Write<T> withDataSourceProviderFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn) {
return toBuilder().setDataSourceProviderFn(dataSourceProviderFn).build();
}
/**
 * The name of the Snowflake table to write to.
*
* @param table - String with the name of the table.
*/
public Write<T> to(String table) {
return toBuilder().setTable(table).build();
}
/**
 * Name of the cloud bucket (currently only GCS) to use as a temporary location for CSV files
 * during the COPY statement.
*
* @param stagingBucketName - String with the name of the bucket.
*/
public Write<T> withStagingBucketName(String stagingBucketName) {
return toBuilder().setStagingBucketName(stagingBucketName).build();
}
/**
* Name of the Storage Integration in Snowflake to be used. See
* https://docs.snowflake.com/en/sql-reference/sql/create-storage-integration.html for
* reference.
*
* @param integrationName - String with the name of the Storage Integration.
*/
public Write<T> withStorageIntegrationName(String integrationName) {
return toBuilder().setStorageIntegrationName(integrationName).build();
}
/**
* A query to be executed in Snowflake.
*
* @param query - String with query.
*/
public Write<T> withQueryTransformation(String query) {
return toBuilder().setQuery(query).build();
}
/**
 * A template name for the files saved to the staging bucket.
*
* @param fileNameTemplate - String with template name for files.
*/
public Write<T> withFileNameTemplate(String fileNameTemplate) {
return toBuilder().setFileNameTemplate(fileNameTemplate).build();
}
/**
* User-defined function mapping user data into CSV lines.
*
* @param userDataMapper - an instance of {@link UserDataMapper}.
*/
public Write<T> withUserDataMapper(UserDataMapper userDataMapper) {
return toBuilder().setUserDataMapper(userDataMapper).build();
}
/**
 * The disposition to use during the write-to-table phase.
*
* @param writeDisposition - an instance of {@link WriteDisposition}.
*/
public Write<T> withWriteDisposition(WriteDisposition writeDisposition) {
return toBuilder().setWriteDisposition(writeDisposition).build();
}
/**
* A disposition to be used during table preparation.
*
* @param createDisposition - an instance of {@link CreateDisposition}.
*/
public Write<T> withCreateDisposition(CreateDisposition createDisposition) {
return toBuilder().setCreateDisposition(createDisposition).build();
}
/**
 * The table schema to use when creating the table.
*
* @param tableSchema - an instance of {@link SnowflakeTableSchema}.
*/
public Write<T> withTableSchema(SnowflakeTableSchema tableSchema) {
return toBuilder().setTableSchema(tableSchema).build();
}
/**
 * The Snowflake service to use. Note: currently {@link SnowflakeServiceImpl} is the production
 * implementation, with a corresponding {@link FakeSnowflakeServiceImpl} used for testing.
*
* @param snowflakeService - an instance of {@link SnowflakeService}.
*/
public Write<T> withSnowflakeService(SnowflakeService snowflakeService) {
return toBuilder().setSnowflakeService(snowflakeService).build();
}
@Override
public PDone expand(PCollection<T> input) {
checkArguments();
String stagingBucketDir = String.format("%s/%s/", getStagingBucketName(), WRITE_TMP_PATH);
PCollection<String> out = write(input, stagingBucketDir);
out.setCoder(StringUtf8Coder.of());
return PDone.in(out.getPipeline());
}
private void checkArguments() {
checkArgument(getStagingBucketName() != null, "withStagingBucketName is required");
checkArgument(getUserDataMapper() != null, "withUserDataMapper() is required");
checkArgument(
(getDataSourceProviderFn() != null),
"withDataSourceConfiguration() or withDataSourceProviderFn() is required");
checkArgument(getTable() != null, "to() is required");
}
private PCollection<String> write(PCollection<T> input, String stagingBucketDir) {
SnowflakeService snowflakeService =
getSnowflakeService() != null ? getSnowflakeService() : new SnowflakeServiceImpl();
PCollection<String> files = writeFiles(input, stagingBucketDir);
// Combine the PCollection of written file names into a single list via a side-input view
ListCoder<String> coder = ListCoder.of(StringUtf8Coder.of());
files =
(PCollection)
files
.getPipeline()
.apply(
Reify.viewInGlobalWindow(
(PCollectionView) files.apply(View.asList()), coder));
return (PCollection)
files.apply("Copy files to table", copyToTable(snowflakeService, stagingBucketDir));
}
private PCollection<String> writeFiles(PCollection<T> input, String stagingBucketDir) {
PCollection<String> mappedUserData =
input
.apply(
MapElements.via(
new SimpleFunction<T, Object[]>() {
@Override
public Object[] apply(T element) {
return getUserDataMapper().mapRow(element);
}
}))
.apply("Map Objects array to CSV lines", ParDo.of(new MapObjectsArrayToCsvFn()))
.setCoder(StringUtf8Coder.of());
WriteFilesResult filesResult =
mappedUserData.apply(
"Write files to specified location",
FileIO.<String>write()
.via(TextIO.sink())
.to(stagingBucketDir)
.withPrefix(getFileNameTemplate())
.withSuffix(".csv")
.withCompression(Compression.GZIP));
return (PCollection)
filesResult
.getPerDestinationOutputFilenames()
.apply("Parse KV filenames to Strings", Values.<String>create());
}
private ParDo.SingleOutput<Object, Object> copyToTable(
SnowflakeService snowflakeService, String stagingBucketDir) {
return ParDo.of(
new CopyToTableFn<>(
getDataSourceProviderFn(),
getTable(),
getQuery(),
stagingBucketDir,
getStorageIntegrationName(),
getCreateDisposition(),
getWriteDisposition(),
getTableSchema(),
snowflakeService));
}
}
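/**
 * {@link Combine.CombineFn} that gathers String elements into a single list, wrapping each element
 * in single quotes; presumably used to build the quoted list of staged file names passed to the
 * COPY statement.
 */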
public static class Concatenate extends Combine.CombineFn<String, List<String>, List<String>> {
@Override
public List<String> createAccumulator() {
return new ArrayList<>();
}
@Override
public List<String> addInput(List<String> mutableAccumulator, String input) {
mutableAccumulator.add(String.format("'%s'", input));
return mutableAccumulator;
}
@Override
public List<String> mergeAccumulators(Iterable<List<String>> accumulators) {
List<String> result = createAccumulator();
for (List<String> accumulator : accumulators) {
result.addAll(accumulator);
}
return result;
}
@Override
public List<String> extractOutput(List<String> accumulator) {
return accumulator;
}
}
/**
 * Custom DoFn that maps an {@code Object[]} into a CSV line to be saved to Snowflake.
*
* <p>Adds Snowflake-specific quotations around strings.
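 *
 * <p>For example, {@code new Object[] {"it's", 123, null}} is emitted as the CSV line {@code
 * 'it''s',123,} (strings are wrapped in single quotes with embedded quotes doubled, nulls become
 * empty fields, other values are appended as-is).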
*/
private static class MapObjectsArrayToCsvFn extends DoFn<Object[], String> {
@ProcessElement
public void processElement(ProcessContext context) {
List<Object> csvItems = new ArrayList<>();
for (Object o : context.element()) {
if (o instanceof String) {
String field = (String) o;
field = field.replace("'", "''");
field = quoteField(field);
csvItems.add(field);
} else {
csvItems.add(o);
}
}
context.output(Joiner.on(",").useForNull("").join(csvItems));
}
private String quoteField(String field) {
return quoteField(field, CSV_QUOTE_CHAR);
}
private String quoteField(String field, String quotation) {
return String.format("%s%s%s", quotation, field, quotation);
}
}
private static class CopyToTableFn<ParameterT, OutputT> extends DoFn<ParameterT, OutputT> {
private final SerializableFunction<Void, DataSource> dataSourceProviderFn;
private final String table;
private final String query;
private final SnowflakeTableSchema tableSchema;
private final String stagingBucketDir;
private final String storageIntegrationName;
private final WriteDisposition writeDisposition;
private final CreateDisposition createDisposition;
private final SnowflakeService snowflakeService;
CopyToTableFn(
SerializableFunction<Void, DataSource> dataSourceProviderFn,
String table,
String query,
String stagingBucketDir,
String storageIntegrationName,
CreateDisposition createDisposition,
WriteDisposition writeDisposition,
SnowflakeTableSchema tableSchema,
SnowflakeService snowflakeService) {
this.dataSourceProviderFn = dataSourceProviderFn;
this.table = table;
this.query = query;
this.stagingBucketDir = stagingBucketDir;
this.storageIntegrationName = storageIntegrationName;
this.writeDisposition = writeDisposition;
this.createDisposition = createDisposition;
this.tableSchema = tableSchema;
this.snowflakeService = snowflakeService;
}
@ProcessElement
public void processElement(ProcessContext context) throws Exception {
SnowflakeServiceConfig config =
new SnowflakeServiceConfig(
dataSourceProviderFn,
(List<String>) context.element(),
table,
query,
tableSchema,
createDisposition,
writeDisposition,
storageIntegrationName,
stagingBucketDir);
snowflakeService.write(config);
}
}
/**
 * A POJO describing a {@link DataSource}, providing all properties needed to create a {@link
* DataSource}.
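 *
 * <p>For example (a sketch only; the {@code options} getters below are assumed pipeline options,
 * not part of this class):
 *
 * <pre>{@code
 * SnowflakeIO.DataSourceConfiguration config =
 *     SnowflakeIO.DataSourceConfiguration.create(SnowflakeCredentialsFactory.of(options))
 *         .withServerName(options.getServerName())
 *         .withDatabase(options.getDatabase())
 *         .withSchema(options.getSchema())
 *         .withWarehouse(options.getWarehouse())
 *         .withRole(options.getRole());
 * }</pre>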
*/
@AutoValue
public abstract static class DataSourceConfiguration implements Serializable {
public abstract @Nullable String getUrl();
public abstract @Nullable String getUsername();
public abstract @Nullable String getPassword();
public abstract @Nullable PrivateKey getPrivateKey();
public abstract @Nullable String getOauthToken();
public abstract @Nullable String getDatabase();
public abstract @Nullable String getWarehouse();
public abstract @Nullable String getSchema();
public abstract @Nullable String getServerName();
public abstract @Nullable Integer getPortNumber();
public abstract @Nullable String getRole();
public abstract @Nullable Integer getLoginTimeout();
public abstract @Nullable Boolean getSsl();
public abstract @Nullable Boolean getValidate();
public abstract @Nullable DataSource getDataSource();
abstract Builder builder();
@AutoValue.Builder
abstract static class Builder {
abstract Builder setUrl(String url);
abstract Builder setUsername(String username);
abstract Builder setPassword(String password);
abstract Builder setPrivateKey(PrivateKey privateKey);
abstract Builder setOauthToken(String oauthToken);
abstract Builder setDatabase(String database);
abstract Builder setWarehouse(String warehouse);
abstract Builder setSchema(String schema);
abstract Builder setServerName(String serverName);
abstract Builder setPortNumber(Integer portNumber);
abstract Builder setRole(String role);
abstract Builder setLoginTimeout(Integer loginTimeout);
abstract Builder setSsl(Boolean ssl);
abstract Builder setValidate(Boolean validate);
abstract Builder setDataSource(DataSource dataSource);
abstract DataSourceConfiguration build();
}
/**
* Creates {@link DataSourceConfiguration} from existing instance of {@link DataSource}.
*
* @param dataSource - an instance of {@link DataSource}.
*/
public static DataSourceConfiguration create(DataSource dataSource) {
checkArgument(dataSource instanceof Serializable, "dataSource must be Serializable");
return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
.setValidate(true)
.setDataSource(dataSource)
.build();
}
/**
* Creates {@link DataSourceConfiguration} from instance of {@link SnowflakeCredentials}.
*
* @param credentials - an instance of {@link SnowflakeCredentials}.
*/
public static DataSourceConfiguration create(SnowflakeCredentials credentials) {
if (credentials instanceof UsernamePasswordSnowflakeCredentials) {
return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
.setValidate(true)
.setUsername(((UsernamePasswordSnowflakeCredentials) credentials).getUsername())
.setPassword(((UsernamePasswordSnowflakeCredentials) credentials).getPassword())
.build();
} else if (credentials instanceof OAuthTokenSnowflakeCredentials) {
return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
.setValidate(true)
.setOauthToken(((OAuthTokenSnowflakeCredentials) credentials).getToken())
.build();
} else if (credentials instanceof KeyPairSnowflakeCredentials) {
return new AutoValue_SnowflakeIO_DataSourceConfiguration.Builder()
.setValidate(true)
.setUsername(((KeyPairSnowflakeCredentials) credentials).getUsername())
.setPrivateKey(((KeyPairSnowflakeCredentials) credentials).getPrivateKey())
.build();
}
throw new IllegalArgumentException(
"Can't create DataSourceConfiguration from given credentials");
}
/**
 * Sets the URL of the Snowflake server, in the following format: {@code
 * jdbc:snowflake://<account_name>.snowflakecomputing.com}.
*
* <p>Either withUrl or withServerName is required.
*
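 * <p>For example (with an illustrative account name): {@code
 * .withUrl("jdbc:snowflake://myaccount.snowflakecomputing.com")}.
 *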
* @param url - String with URL of the Snowflake server.
*/
public DataSourceConfiguration withUrl(String url) {
checkArgument(
url.startsWith("jdbc:snowflake://"),
"url must have format: jdbc:snowflake://<account_name>.snowflakecomputing.com");
checkArgument(
url.endsWith("snowflakecomputing.com"),
"url must have format: jdbc:snowflake://<account_name>.snowflakecomputing.com");
return builder().setUrl(url).build();
}
/**
* Sets database to use.
*
* @param database - String with database name.
*/
public DataSourceConfiguration withDatabase(String database) {
return builder().setDatabase(database).build();
}
/**
* Sets Snowflake Warehouse to use.
*
* @param warehouse - String with warehouse name.
*/
public DataSourceConfiguration withWarehouse(String warehouse) {
return builder().setWarehouse(warehouse).build();
}
/**
* Sets schema to use when connecting to Snowflake.
*
* @param schema - String with schema name.
*/
public DataSourceConfiguration withSchema(String schema) {
return builder().setSchema(schema).build();
}
/**
 * Sets the name of the Snowflake server. The following format is required: {@code
 * <account_name>.snowflakecomputing.com}.
*
* <p>Either withServerName or withUrl is required.
*
* @param serverName - String with server name.
*/
public DataSourceConfiguration withServerName(String serverName) {
checkArgument(
serverName.endsWith("snowflakecomputing.com"),
"serverName must be in format <account_name>.snowflakecomputing.com");
return builder().setServerName(serverName).build();
}
/**
* Sets port number to use to connect to Snowflake.
*
* @param portNumber - Integer with port number.
*/
public DataSourceConfiguration withPortNumber(Integer portNumber) {
return builder().setPortNumber(portNumber).build();
}
/**
* Sets user's role to be used when running queries on Snowflake.
*
* @param role - String with role name.
*/
public DataSourceConfiguration withRole(String role) {
return builder().setRole(role).build();
}
/**
 * Sets the loginTimeout that will be used in {@link SnowflakeBasicDataSource#setLoginTimeout(int)}.
*
* @param loginTimeout - Integer with timeout value.
*/
public DataSourceConfiguration withLoginTimeout(Integer loginTimeout) {
return builder().setLoginTimeout(loginTimeout).build();
}
/**
* Disables validation of connection parameters prior to pipeline submission.
*/
public DataSourceConfiguration withoutValidation() {
return builder().setValidate(false).build();
}
void populateDisplayData(DisplayData.Builder builder) {
if (getDataSource() != null) {
builder.addIfNotNull(DisplayData.item("dataSource", getDataSource().getClass().getName()));
} else {
builder.addIfNotNull(DisplayData.item("jdbcUrl", getUrl()));
builder.addIfNotNull(DisplayData.item("username", getUsername()));
}
}
/** Builds {@link SnowflakeBasicDataSource} based on the current configuration. */
public DataSource buildDatasource() {
if (getDataSource() == null) {
SnowflakeBasicDataSource basicDataSource = new SnowflakeBasicDataSource();
basicDataSource.setUrl(buildUrl());
if (getUsername() != null) {
basicDataSource.setUser(getUsername());
}
if (getPassword() != null) {
basicDataSource.setPassword(getPassword());
}
if (getPrivateKey() != null) {
basicDataSource.setPrivateKey(getPrivateKey());
}
if (getDatabase() != null) {
basicDataSource.setDatabaseName(getDatabase());
}
if (getWarehouse() != null) {
basicDataSource.setWarehouse(getWarehouse());
}
if (getSchema() != null) {
basicDataSource.setSchema(getSchema());
}
if (getRole() != null) {
basicDataSource.setRole(getRole());
}
if (getLoginTimeout() != null) {
try {
basicDataSource.setLoginTimeout(getLoginTimeout());
} catch (SQLException e) {
throw new RuntimeException("Failed to setLoginTimeout");
}
}
if (getOauthToken() != null) {
basicDataSource.setOauthToken(getOauthToken());
}
return basicDataSource;
}
return getDataSource();
}
private String buildUrl() {
StringBuilder url = new StringBuilder();
if (getUrl() != null) {
url.append(getUrl());
} else {
url.append("jdbc:snowflake://");
url.append(getServerName());
}
if (getPortNumber() != null) {
url.append(":").append(getPortNumber());
}
url.append("?application=beam");
return url.toString();
}
}
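/**
 * Wraps a {@link DataSourceConfiguration} in a {@link SerializableFunction} that lazily builds a
 * {@link DataSource} and caches it per configuration in a static map, so each worker builds it at
 * most once. If validation is enabled, a test connection is opened and closed at construction
 * time.
 */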
public static class DataSourceProviderFromDataSourceConfiguration
implements SerializableFunction<Void, DataSource>, HasDisplayData {
private static final ConcurrentHashMap<DataSourceConfiguration, DataSource> instances =
new ConcurrentHashMap<>();
private final DataSourceConfiguration config;
private DataSourceProviderFromDataSourceConfiguration(DataSourceConfiguration config) {
if (config.getValidate()) {
try {
Connection connection = config.buildDatasource().getConnection();
connection.close();
} catch (SQLException e) {
throw new IllegalArgumentException(
"Invalid DataSourceConfiguration. Underlying cause: " + e, e);
}
}
this.config = config;
}
public static SerializableFunction<Void, DataSource> of(DataSourceConfiguration config) {
return new DataSourceProviderFromDataSourceConfiguration(config);
}
@Override
public DataSource apply(Void input) {
return instances.computeIfAbsent(config, (config) -> config.buildDatasource());
}
@Override
public void populateDisplayData(DisplayData.Builder builder) {
config.populateDisplayData(builder);
}
}
}