Move to dataframe writes (#132)
diff --git a/core/src/test/java/io/onetable/TestSparkDeltaTable.java b/core/src/test/java/io/onetable/TestSparkDeltaTable.java
index 9515da1..529fce0 100644
--- a/core/src/test/java/io/onetable/TestSparkDeltaTable.java
+++ b/core/src/test/java/io/onetable/TestSparkDeltaTable.java
@@ -18,7 +18,6 @@
package io.onetable;
-import static io.onetable.delta.TestDeltaHelper.DATE_TIME_FORMATTER;
import static io.onetable.delta.TestDeltaHelper.createTestDataHelper;
import java.io.Closeable;
@@ -26,8 +25,6 @@
import java.io.UncheckedIOException;
import java.nio.file.Files;
import java.nio.file.Path;
-import java.sql.Timestamp;
-import java.time.LocalDateTime;
import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
@@ -96,15 +93,15 @@
@Override
public List<Row> insertRows(int numRows) {
List<Row> rows = testDeltaHelper.generateRows(numRows);
- String insertStatement = testDeltaHelper.generateSqlForDataInsert(tableName, rows);
- sparkSession.sql(insertStatement);
+ Dataset<Row> df = sparkSession.createDataFrame(rows, testDeltaHelper.getTableStructSchema());
+ df.write().format("delta").mode("append").save(basePath);
return rows;
}
public List<Row> insertRowsForPartition(int numRows, Integer partitionValue) {
List<Row> rows = testDeltaHelper.generateRowsForSpecificPartition(numRows, partitionValue);
- String insertStatement = testDeltaHelper.generateSqlForDataInsert(tableName, rows);
- sparkSession.sql(insertStatement);
+ Dataset<Row> df = sparkSession.createDataFrame(rows, testDeltaHelper.getTableStructSchema());
+ df.write().format("delta").mode("append").save(basePath);
return rows;
}
@@ -214,18 +211,7 @@
public Map<Integer, List<Row>> getRowsByPartition(List<Row> rows) {
return rows.stream()
- .collect(
- Collectors.groupingBy(
- row -> {
- try {
- LocalDateTime parsedDateTime =
- LocalDateTime.parse(row.getString(4), DATE_TIME_FORMATTER);
- Timestamp timestamp = Timestamp.valueOf(parsedDateTime);
- return timestamp.toLocalDateTime().getYear();
- } catch (Exception e) {
- throw new RuntimeException(e);
- }
- }));
+ .collect(Collectors.groupingBy(row -> row.getTimestamp(4).toLocalDateTime().getYear()));
}
@Override
diff --git a/core/src/test/java/io/onetable/delta/TestDeltaHelper.java b/core/src/test/java/io/onetable/delta/TestDeltaHelper.java
index 4a00dde..688e728 100644
--- a/core/src/test/java/io/onetable/delta/TestDeltaHelper.java
+++ b/core/src/test/java/io/onetable/delta/TestDeltaHelper.java
@@ -25,7 +25,6 @@
import java.sql.Timestamp;
import java.time.LocalDateTime;
import java.time.YearMonth;
-import java.time.format.DateTimeFormatter;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
@@ -41,7 +40,6 @@
import org.apache.spark.sql.Row;
import org.apache.spark.sql.RowFactory;
import org.apache.spark.sql.SparkSession;
-import org.apache.spark.sql.types.DataType;
import org.apache.spark.sql.types.Metadata;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
@@ -70,24 +68,15 @@
private static final Random RANDOM = new Random();
private static final String[] GENDERS = {"Male", "Female"};
- public static final DateTimeFormatter DATE_TIME_FORMATTER =
- DateTimeFormatter.ofPattern("yyyy-MM-dd HH:mm:ss");
-
StructType tableStructSchema;
- String selectTemplateForInserts;
- String insertIntoTemplateSqlStr;
boolean tableIsPartitioned;
boolean includeAdditionalColumns;
public static TestDeltaHelper createTestDataHelper(
boolean isPartitioned, boolean includeAdditionalColumns) {
StructType tableSchema = generateDynamicSchema(isPartitioned, includeAdditionalColumns);
- String selectTemplateForInserts = generateSqlSelectForInsert(tableSchema);
- String insertIntoTemplateSqlStr = generateSqlInsertTemplate(tableSchema);
return TestDeltaHelper.builder()
.tableStructSchema(tableSchema)
- .selectTemplateForInserts(selectTemplateForInserts)
- .insertIntoTemplateSqlStr(insertIntoTemplateSqlStr)
.tableIsPartitioned(isPartitioned)
.includeAdditionalColumns(includeAdditionalColumns)
.build();
@@ -171,7 +160,7 @@
}
}
- private String generateRandomTimeGivenYear(int yearValue) {
+ private Timestamp generateRandomTimeGivenYear(int yearValue) {
int month = RANDOM.nextInt(12) + 1;
int daysInMonth = YearMonth.of(yearValue, month).lengthOfMonth();
int day = RANDOM.nextInt(daysInMonth) + 1;
@@ -179,7 +168,7 @@
LocalDateTime localDateTime =
LocalDateTime.of(
yearValue, month, day, RANDOM.nextInt(24), RANDOM.nextInt(60), RANDOM.nextInt(60));
- return DATE_TIME_FORMATTER.format(localDateTime);
+ return Timestamp.valueOf(localDateTime);
}
public static String generateRandomString() {
@@ -191,24 +180,6 @@
return name.toString();
}
- public String generateSqlForDataInsert(String tableName, List<Row> rows) {
- List<String> selectsForInsert =
- rows.stream().map(this::generateSelectForRow).collect(Collectors.toList());
- return String.format(
- insertIntoTemplateSqlStr, tableName, String.join(" UNION ALL ", selectsForInsert));
- }
-
- public String generateSelectForRow(Row row) {
- List<Object> values = new ArrayList<>();
- for (int i = 0; i < row.size(); i++) {
- values.add(row.get(i));
- }
- if (tableIsPartitioned) {
- values.add(values.get(values.size() - 1));
- }
- return String.format(selectTemplateForInserts, values.toArray());
- }
-
public List<Row> transformForUpsertsOrDeletes(List<Row> rows, boolean isUpsert) {
// Generate random values for few columns for upserts.
// For deletes, retain the same values as the original row.
@@ -225,14 +196,12 @@
}
}
if (tableIsPartitioned) {
- LocalDateTime parsedDateTime =
- LocalDateTime.parse(row.getString(row.size() - 2), DATE_TIME_FORMATTER);
- newRowData[row.size() - 2] = Timestamp.valueOf(parsedDateTime);
- newRowData[row.size() - 1] = parsedDateTime.getYear();
+ Timestamp timestampValue = row.getTimestamp(row.size() - 2);
+ newRowData[row.size() - 2] = timestampValue;
+ newRowData[row.size() - 1] = timestampValue.toLocalDateTime().getYear();
} else {
- LocalDateTime parsedDateTime =
- LocalDateTime.parse(row.getString(row.size() - 1), DATE_TIME_FORMATTER);
- newRowData[row.size() - 1] = Timestamp.valueOf(parsedDateTime);
+ Timestamp timestampValue = row.getTimestamp(row.size() - 1);
+ newRowData[row.size() - 1] = timestampValue;
}
return RowFactory.create(newRowData);
})
@@ -250,38 +219,4 @@
.mapToObj(i -> generateRandomRowForGivenYear(partitionValue))
.collect(Collectors.toList());
}
-
- private static String generateSqlInsertTemplate(StructType schema) {
- String fieldList =
- Arrays.stream(schema.fields()).map(StructField::name).collect(Collectors.joining(", "));
- return String.format("INSERT INTO `%%s` (%s) %%s", fieldList);
- }
-
- private static String generateSqlSelectForInsert(StructType schema) {
- StructField[] fields = schema.fields();
- StringBuilder sqlBuilder = new StringBuilder("SELECT ");
- for (int i = 0; i < fields.length; i++) {
- StructField field = fields[i];
- sqlBuilder.append(getFormattedField(field));
- if (i < fields.length - 1) {
- sqlBuilder.append(", ");
- }
- }
- return sqlBuilder.toString();
- }
-
- private static String getFormattedField(StructField field) {
- String fieldName = field.name();
- DataType fieldType = field.dataType();
- if (fieldName.equals("yearOfBirth")) {
- return "year(timestamp('%s')) AS yearOfBirth";
- } else if (fieldType == IntegerType) {
- return String.format("%%d AS %s", fieldName);
- } else if (fieldType == StringType) {
- return String.format("'%%s' AS %s", fieldName);
- } else if (fieldType == TimestampType) {
- return String.format("timestamp('%%s') AS %s", fieldName);
- }
- throw new IllegalArgumentException("Unsupported field type: " + fieldType);
- }
}