| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.cassandra.spark.bulkwriter; |
| |
| import java.io.Serializable; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.LinkedHashSet; |
| import java.util.List; |
| import java.util.Set; |
| import java.util.stream.Collectors; |
| import java.util.stream.Stream; |
| |
| import com.google.common.base.Preconditions; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import org.apache.cassandra.spark.common.schema.ColumnType; |
| import org.apache.spark.sql.types.StructType; |
| |
| public class TableSchema implements Serializable |
| { |
| private static final Logger LOGGER = LoggerFactory.getLogger(TableSchema.class); |
| |
| final String createStatement; |
| final String modificationStatement; |
| final List<String> partitionKeyColumns; |
| final List<ColumnType<?>> partitionKeyColumnTypes; |
| final List<SqlToCqlTypeConverter.Converter<?>> converters; |
| private final List<Integer> keyFieldPositions; |
| private final WriteMode writeMode; |
| |
| public TableSchema(StructType dfSchema, TableInfoProvider tableInfo, WriteMode writeMode) |
| { |
| this.writeMode = writeMode; |
| |
| validateDataFrameCompatibility(dfSchema, tableInfo); |
| validateNoSecondaryIndexes(tableInfo); |
| |
| this.createStatement = getCreateStatement(tableInfo); |
| this.modificationStatement = getModificationStatement(dfSchema, tableInfo); |
| this.partitionKeyColumns = getPartitionKeyColumnNames(tableInfo); |
| this.partitionKeyColumnTypes = getPartitionKeyColumnTypes(tableInfo); |
| this.converters = getConverters(dfSchema, tableInfo); |
| LOGGER.info("Converters: {}", converters); |
| this.keyFieldPositions = getKeyFieldPositions(dfSchema, tableInfo.getColumnNames(), getRequiredKeyColumns(tableInfo)); |
| } |
| |
| private List<String> getRequiredKeyColumns(TableInfoProvider tableInfo) |
| { |
| switch (writeMode) |
| { |
| case INSERT: |
| // Inserts require all primary key columns |
| return tableInfo.getPrimaryKeyColumnNames(); |
| case DELETE_PARTITION: |
| // To delete a partition, we only need the partition key columns, not all primary key columns |
| return tableInfo.getPartitionKeyColumnNames(); |
| default: |
| throw new UnsupportedOperationException("Unknown WriteMode provided"); |
| } |
| } |
| |
| public Object[] normalize(Object[] row) |
| { |
| for (int index = 0; index < row.length; index++) |
| { |
| row[index] = converters.get(index).convert(row[index]); |
| } |
| return row; |
| } |
| |
| public Object[] getKeyColumns(Object[] allColumns) |
| { |
| Object[] result = new Object[keyFieldPositions.size()]; |
| for (int keyFieldPosition = 0; keyFieldPosition < keyFieldPositions.size(); keyFieldPosition++) |
| { |
| Object colVal = allColumns[keyFieldPositions.get(keyFieldPosition)]; |
| Preconditions.checkNotNull(colVal, "Found a null primary or composite key column in source data. All key columns must be non-null."); |
| result[keyFieldPosition] = colVal; |
| } |
| return result; |
| } |
| |
| private static List<SqlToCqlTypeConverter.Converter<?>> getConverters(StructType dfSchema, |
| TableInfoProvider tableInfo) |
| { |
| return Arrays.stream(dfSchema.fieldNames()) |
| .map(tableInfo::getColumnType) |
| .map(SqlToCqlTypeConverter::getConverter) |
| .collect(Collectors.toList()); |
| } |
| |
| private static List<ColumnType<?>> getPartitionKeyColumnTypes(TableInfoProvider tableInfo) |
| { |
| return tableInfo.getPartitionKeyTypes(); |
| } |
| |
| private static List<String> getPartitionKeyColumnNames(TableInfoProvider tableInfo) |
| { |
| return tableInfo.getPartitionKeyColumnNames(); |
| } |
| |
| private static String getCreateStatement(TableInfoProvider tableInfo) |
| { |
| String createStatement = tableInfo.getCreateStatement(); |
| LOGGER.info("CQL create statement for the table {}", createStatement); |
| return createStatement; |
| } |
| |
| private String getModificationStatement(StructType dfSchema, TableInfoProvider tableInfo) |
| { |
| switch (writeMode) |
| { |
| case INSERT: |
| return getInsertStatement(dfSchema, tableInfo); |
| case DELETE_PARTITION: |
| return getDeleteStatement(dfSchema, tableInfo); |
| default: |
| throw new UnsupportedOperationException("Unknown WriteMode provided"); |
| } |
| } |
| |
| private static String getInsertStatement(StructType dfSchema, TableInfoProvider tableInfo) |
| { |
| String insertStatement = String.format("INSERT INTO %s.%s (%s) VALUES (%s);", |
| tableInfo.getKeyspaceName(), |
| tableInfo.getName(), |
| String.join(",", dfSchema.fieldNames()), |
| Arrays.stream(dfSchema.fieldNames()) |
| .map(field -> "?") |
| .collect(Collectors.joining(","))); |
| |
| LOGGER.info("CQL insert statement for the RDD {}", insertStatement); |
| return insertStatement; |
| } |
| |
| private String getDeleteStatement(StructType dfSchema, TableInfoProvider tableInfo) |
| { |
| Stream<String> fieldEqualityStatements = Arrays.stream(dfSchema.fieldNames()).map(key -> key + "=?"); |
| String deleteStatement = String.format("DELETE FROM %s.%s where %s;", |
| tableInfo.getKeyspaceName(), |
| tableInfo.getName(), |
| fieldEqualityStatements.collect(Collectors.joining(" AND "))); |
| |
| LOGGER.info("CQL delete statement for the RDD {}", deleteStatement); |
| return deleteStatement; |
| } |
| |
| private void validateDataFrameCompatibility(StructType dfSchema, TableInfoProvider tableInfo) |
| { |
| Set<String> dfFields = new LinkedHashSet<>(); |
| Collections.addAll(dfFields, dfSchema.fieldNames()); |
| |
| validatePrimaryKeyColumnsProvided(tableInfo, dfFields); |
| |
| switch (writeMode) |
| { |
| case INSERT: |
| validateDataframeFieldsInTable(tableInfo, dfFields); |
| return; |
| case DELETE_PARTITION: |
| validateOnlyPartitionKeyColumnsInDataframe(tableInfo, dfFields); |
| return; |
| default: |
| LOGGER.warn("Unrecognized write mode {}", writeMode); |
| } |
| } |
| |
| private void validateOnlyPartitionKeyColumnsInDataframe(TableInfoProvider tableInfo, Set<String> dfFields) |
| { |
| Set<String> requiredKeyColumns = new LinkedHashSet<>(getRequiredKeyColumns(tableInfo)); |
| Preconditions.checkArgument(requiredKeyColumns.equals(dfFields), |
| String.format("Only partition key columns (%s) are supported in the input Dataframe" |
| + " when WRITE_MODE=DELETE_PARTITION but (%s) columns were provided", |
| String.join(",", requiredKeyColumns), String.join(",", dfFields))); |
| } |
| |
| private void validatePrimaryKeyColumnsProvided(TableInfoProvider tableInfo, Set<String> dfFields) |
| { |
| // Make sure all primary key columns are provided |
| List<String> requiredKeyColumns = getRequiredKeyColumns(tableInfo); |
| Preconditions.checkArgument(dfFields.containsAll(requiredKeyColumns), |
| "Missing some required key components in DataFrame => " + requiredKeyColumns |
| .stream() |
| .filter(column -> !dfFields.contains(column)) |
| .collect(Collectors.joining(","))); |
| } |
| |
| private static void validateDataframeFieldsInTable(TableInfoProvider tableInfo, Set<String> dfFields) |
| { |
| // Make sure all fields in DF schema are part of table |
| String unknownFields = dfFields.stream() |
| .filter(columnName -> !tableInfo.columnExists(columnName)) |
| .collect(Collectors.joining(",")); |
| Preconditions.checkArgument(unknownFields.isEmpty(), "Unknown fields in data frame => " + unknownFields); |
| } |
| |
| private static void validateNoSecondaryIndexes(TableInfoProvider tableInfo) |
| { |
| if (tableInfo.hasSecondaryIndex()) |
| { |
| throw new RuntimeException("Bulkwriter doesn't support secondary indexes"); |
| } |
| } |
| |
| private static List<Integer> getKeyFieldPositions(StructType dfSchema, |
| List<String> columnNames, |
| List<String> keyFieldNames) |
| { |
| List<String> dfFieldNames = Arrays.asList(dfSchema.fieldNames()); |
| return columnNames.stream() |
| .filter(keyFieldNames::contains) |
| .map(dfFieldNames::indexOf) |
| .collect(Collectors.toList()); |
| } |
| } |