/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/

package org.apache.cassandra.spark.bulkwriter;

import java.util.Arrays;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.stream.Collectors;
import com.google.common.base.Preconditions;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.Lists;
import org.apache.commons.lang3.tuple.Pair;
import org.apache.cassandra.spark.common.schema.ColumnType;
import org.apache.cassandra.spark.data.CqlField;
import org.apache.spark.sql.types.StructType;
import org.jetbrains.annotations.NotNull;

import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.CUSTOM;
import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.LIST;
import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.MAP;
import static org.apache.cassandra.spark.bulkwriter.SqlToCqlTypeConverter.SET;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
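
/**
 * Shared helpers for bulk writer schema tests: factories for mocked {@link CqlField.CqlType}
 * instances, a fluent {@link TableSchema} builder, and an in-memory {@link TableInfoProvider}.
 * The mocks stub only the methods the schema-building code actually calls.
 */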
public final class TableSchemaTestCommon
{
private TableSchemaTestCommon()
{
        throw new IllegalStateException(getClass() + " is a static utility class and shall not be instantiated");
}
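
    /**
     * Builds a Spark {@link StructType} and a matching column-name-to-CQL-type map from three
     * parallel arrays; index {@code i} of each array describes the same column.
     */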
public static Pair<StructType, ImmutableMap<String, CqlField.CqlType>> buildMatchedDataframeAndCqlColumns(
String[] fieldNames,
org.apache.spark.sql.types.DataType[] sparkTypes,
CqlField.CqlType[] cqlTypes)
{
StructType dataFrameSchema = new StructType();
ImmutableMap.Builder<String, CqlField.CqlType> cqlColumnsBuilder = ImmutableMap.builder();
for (int field = 0; field < fieldNames.length; field++)
{
dataFrameSchema = dataFrameSchema.add(fieldNames[field], sparkTypes[field]);
cqlColumnsBuilder.put(fieldNames[field], cqlTypes[field]);
}
ImmutableMap<String, CqlField.CqlType> cqlColumns = cqlColumnsBuilder.build();
return Pair.of(dataFrameSchema, cqlColumns);
}
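
    /**
     * Returns a mocked {@link CqlField.CqlType} whose {@code name()} reports the given CQL type name.
     */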
@NotNull
public static CqlField.CqlType mockCqlType(String cqlName)
{
CqlField.CqlType mock = mock(CqlField.CqlType.class);
when(mock.name()).thenReturn(cqlName);
return mock;
}
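
    /**
     * Returns a mocked custom CQL type reporting the given custom type implementation class name.
     */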
@NotNull
public static CqlField.CqlCustom mockCqlCustom(String customTypeClassName)
{
CqlField.CqlCustom mock = mock(CqlField.CqlCustom.class);
when(mock.name()).thenReturn(CUSTOM);
when(mock.customTypeClassName()).thenReturn(customTypeClassName);
return mock;
}
@NotNull
public static CqlField.CqlCollection mockSetCqlType(String collectionCqlType)
{
return mockCollectionCqlType(SET, mockCqlType(collectionCqlType));
}
@NotNull
public static CqlField.CqlCollection mockListCqlType(String collectionCqlType)
{
return mockCollectionCqlType(LIST, mockCqlType(collectionCqlType));
}
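
    /**
     * Returns a mocked collection type ({@code set}, {@code list}, ...) wrapping the given element type.
     */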
@NotNull
public static CqlField.CqlCollection mockCollectionCqlType(String cqlName, CqlField.CqlType collectionType)
{
CqlField.CqlCollection mock = mock(CqlField.CqlCollection.class);
when(mock.name()).thenReturn(cqlName);
when(mock.type()).thenReturn(collectionType);
return mock;
}
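
    /**
     * Returns a mocked {@code map} type whose key and value types are themselves mocked by name.
     */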
@NotNull
public static CqlField.CqlType mockMapCqlType(String keyCqlName, String valueCqlName)
{
return mockMapCqlType(mockCqlType(keyCqlName), mockCqlType(valueCqlName));
}
@NotNull
public static CqlField.CqlMap mockMapCqlType(CqlField.CqlType keyType, CqlField.CqlType valueType)
{
CqlField.CqlMap mock = mock(CqlField.CqlMap.class);
when(mock.name()).thenReturn(MAP);
when(mock.keyType()).thenReturn(keyType);
when(mock.valueType()).thenReturn(valueType);
return mock;
}
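
    /**
     * One-call helper that wires the factories above into a complete {@link TableSchema} with a
     * fixed Cassandra version and {@code INSERT} write mode. A minimal usage sketch follows; the
     * column names, Spark types, and the {@code ColumnTypes.INT} constant are illustrative
     * assumptions, not taken from any particular test:
     * <pre>{@code
     * TableSchema schema = TableSchemaTestCommon.buildSchema(
     *         new String[]{"id", "marks"},
     *         new org.apache.spark.sql.types.DataType[]{DataTypes.IntegerType, DataTypes.IntegerType},
     *         new CqlField.CqlType[]{mockCqlType("int"), mockCqlType("int")},
     *         new String[]{"id"},                    // partition key columns
     *         new ColumnType<?>[]{ColumnTypes.INT},  // partition key column types (assumed constant)
     *         new String[]{"id"});                   // primary key columns
     * }</pre>
     */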
public static TableSchema buildSchema(String[] fieldNames,
org.apache.spark.sql.types.DataType[] sparkTypes,
CqlField.CqlType[] driverTypes,
String[] partitionKeyColumns,
ColumnType<?>[] partitionKeyColumnTypes,
String[] primaryKeyColumnNames)
{
Pair<StructType, ImmutableMap<String, CqlField.CqlType>> pair = buildMatchedDataframeAndCqlColumns(fieldNames, sparkTypes, driverTypes);
ImmutableMap<String, CqlField.CqlType> cqlColumns = pair.getValue();
StructType dataFrameSchema = pair.getKey();
return
new MockTableSchemaBuilder()
.withCqlColumns(cqlColumns)
.withPartitionKeyColumns(partitionKeyColumns)
.withPrimaryKeyColumnNames(primaryKeyColumnNames)
.withCassandraVersion("3.0.24.8")
.withPartitionKeyColumnTypes(partitionKeyColumnTypes)
.withWriteMode(WriteMode.INSERT)
.withDataFrameSchema(dataFrameSchema)
.build();
}
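
    /**
     * Fluent builder for {@link TableSchema} test fixtures. Each {@code with*} setter validates its
     * argument eagerly; {@link #build()} fails fast, naming the setter that was not called.
     */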
public static class MockTableSchemaBuilder
{
private ImmutableMap<String, CqlField.CqlType> cqlColumns;
private String[] partitionKeyColumns;
private String[] primaryKeyColumnNames;
private String cassandraVersion;
        private ColumnType<?>[] partitionKeyColumnTypes;
private StructType dataFrameSchema;
private WriteMode writeMode = null;
public MockTableSchemaBuilder withCqlColumns(@NotNull Map<String, CqlField.CqlType> cqlColumns)
{
Preconditions.checkNotNull(cqlColumns, "cqlColumns cannot be null");
Preconditions.checkArgument(cqlColumns.size() > 0, "cqlColumns cannot be empty");
this.cqlColumns = ImmutableMap.copyOf(cqlColumns);
return this;
}
public MockTableSchemaBuilder withPartitionKeyColumns(@NotNull String... partitionKeyColumns)
{
Preconditions.checkNotNull(partitionKeyColumns, "partitionKeyColumns cannot be null");
Preconditions.checkArgument(partitionKeyColumns.length > 0, "partitionKeyColumns cannot be empty");
this.partitionKeyColumns = partitionKeyColumns;
return this;
}
public MockTableSchemaBuilder withPrimaryKeyColumnNames(@NotNull String... primaryKeyColumnNames)
{
Preconditions.checkNotNull(primaryKeyColumnNames, "primaryKeyColumnNames cannot be null");
Preconditions.checkArgument(primaryKeyColumnNames.length > 0, "primaryKeyColumnNames cannot be empty");
this.primaryKeyColumnNames = primaryKeyColumnNames;
return this;
}
public MockTableSchemaBuilder withCassandraVersion(@NotNull String cassandraVersion)
{
Preconditions.checkNotNull(cassandraVersion, "cassandraVersion cannot be null");
Preconditions.checkArgument(cassandraVersion.length() > 0, "cassandraVersion cannot be an empty string");
this.cassandraVersion = cassandraVersion;
return this;
}
public MockTableSchemaBuilder withPartitionKeyColumnTypes(@NotNull ColumnType<?>... partitionKeyColumnTypes)
{
Preconditions.checkNotNull(partitionKeyColumnTypes, "partitionKeyColumnTypes cannot be null");
Preconditions.checkArgument(partitionKeyColumnTypes.length > 0, "partitionKeyColumnTypes cannot be empty");
this.partitionKeyColumnTypes = Arrays.copyOf(partitionKeyColumnTypes, partitionKeyColumnTypes.length);
return this;
}
public MockTableSchemaBuilder withWriteMode(@NotNull WriteMode writeMode)
{
Preconditions.checkNotNull(writeMode, "writeMode cannot be null");
this.writeMode = writeMode;
return this;
}
public MockTableSchemaBuilder withDataFrameSchema(StructType dataFrameSchema)
{
Preconditions.checkNotNull(dataFrameSchema, "dataFrameSchema cannot be null");
Preconditions.checkArgument(dataFrameSchema.nonEmpty(), "dataFrameSchema cannot be empty");
this.dataFrameSchema = dataFrameSchema;
return this;
}
public TableSchema build()
{
            Objects.requireNonNull(cqlColumns,
                                   "cqlColumns cannot be null. Please provide a map of columns by calling #withCqlColumns");
            Objects.requireNonNull(partitionKeyColumns,
                                   "partitionKeyColumns cannot be null. Please provide a list of columns by calling #withPartitionKeyColumns");
            Objects.requireNonNull(primaryKeyColumnNames,
                                   "primaryKeyColumnNames cannot be null. Please provide a list of columns by calling #withPrimaryKeyColumnNames");
            Objects.requireNonNull(cassandraVersion,
                                   "cassandraVersion cannot be null. Please provide the Cassandra version by calling #withCassandraVersion");
            Objects.requireNonNull(partitionKeyColumnTypes,
                                   "partitionKeyColumnTypes cannot be null. Please provide the column types by calling #withPartitionKeyColumnTypes");
            Objects.requireNonNull(writeMode,
                                   "writeMode cannot be null. Please provide the write mode by calling #withWriteMode");
            Objects.requireNonNull(dataFrameSchema,
                                   "dataFrameSchema cannot be null. Please provide the data frame schema by calling #withDataFrameSchema");
MockTableInfoProvider tableInfoProvider = new MockTableInfoProvider(cqlColumns,
partitionKeyColumns,
partitionKeyColumnTypes,
primaryKeyColumnNames,
cassandraVersion);
return new TableSchema(dataFrameSchema, tableInfoProvider, writeMode);
}
}
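
    /**
     * {@link TableInfoProvider} backed entirely by the column metadata handed to the constructor;
     * it reports a fixed {@code test.test} keyspace and table and no secondary indexes.
     */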
public static class MockTableInfoProvider implements TableInfoProvider
{
private final ImmutableMap<String, CqlField.CqlType> cqlColumns;
private final String[] partitionKeyColumns;
        private final ColumnType<?>[] partitionKeyColumnTypes;
private final String[] primaryKeyColumnNames;
private final String cassandraVersion;
public MockTableInfoProvider(ImmutableMap<String, CqlField.CqlType> cqlColumns,
String[] partitionKeyColumns,
                                     ColumnType<?>[] partitionKeyColumnTypes,
String[] primaryKeyColumnNames,
String cassandraVersion)
{
this.cqlColumns = cqlColumns;
this.partitionKeyColumns = partitionKeyColumns;
this.partitionKeyColumnTypes = partitionKeyColumnTypes;
this.primaryKeyColumnNames = primaryKeyColumnNames;
this.cassandraVersion = cassandraVersion.replaceAll("(\\w+-)*cassandra-", "");
}
@Override
public CqlField.CqlType getColumnType(String columnName)
{
            return cqlColumns.get(columnName);
}
@Override
public List<ColumnType<?>> getPartitionKeyTypes()
{
return Lists.newArrayList(partitionKeyColumnTypes);
}
@Override
public boolean columnExists(String columnName)
{
            return cqlColumns.containsKey(columnName);
}
@Override
public List<String> getPartitionKeyColumnNames()
{
return Arrays.asList(partitionKeyColumns);
}
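
        /**
         * Renders a CREATE TABLE statement for the mocked columns, picking the compressor from the
         * major Cassandra version.
         */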
@Override
public String getCreateStatement()
{
String keyDef = getKeyDef();
String createStatement = "CREATE TABLE test.test (" + cqlColumns.entrySet()
.stream()
.map(column -> column.getKey() + " " + column.getValue().name())
.collect(Collectors.joining(",\n")) + ", " + keyDef + ") "
+ "WITH COMPRESSION = {'class': '" + getCompression() + "'};";
            System.out.println("Create Table: " + createStatement);
return createStatement;
}
private String getCompression()
{
switch (cassandraVersion.charAt(0))
{
case '4':
return "ZstdCompressor";
case '3':
return "LZ4Compressor";
default:
return "LZ4Compressor";
}
}
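
        /**
         * Builds the PRIMARY KEY clause: the partition key columns wrapped in parentheses, followed
         * by any clustering columns (primary key columns that are not part of the partition key).
         */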
private String getKeyDef()
{
List<String> partitionColumns = Lists.newArrayList(partitionKeyColumns);
List<String> primaryColumns = Lists.newArrayList(primaryKeyColumnNames);
primaryColumns.removeAll(partitionColumns);
            String partitionKey = "(" + String.join(",", partitionKeyColumns) + ")";
            String clusteringKey = String.join(",", primaryColumns);
            // Append the clustering columns only when they exist, separated from the partition key
            // by a comma, so the generated CQL is valid in both cases
            return "PRIMARY KEY (" + partitionKey + (clusteringKey.isEmpty() ? "" : "," + clusteringKey) + ")";
}
@Override
public List<String> getPrimaryKeyColumnNames()
{
return Arrays.asList(primaryKeyColumnNames);
}
@Override
public String getName()
{
return "test";
}
@Override
public String getKeyspaceName()
{
return "test";
}
@Override
public boolean hasSecondaryIndex()
{
return false;
}
@Override
public List<String> getColumnNames()
{
return cqlColumns.keySet().asList();
}
}
}