/*
 * Licensed to the Apache Software Foundation (ASF) under one
 * or more contributor license agreements. See the NOTICE file
 * distributed with this work for additional information
 * regarding copyright ownership. The ASF licenses this file
 * to you under the Apache License, Version 2.0 (the
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
 *
 *   http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing,
 * software distributed under the License is distributed on an
 * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
 * KIND, either express or implied. See the License for the
 * specific language governing permissions and limitations
 * under the License.
 */

package org.apache.iceberg.spark;

import java.util.List;
import java.util.Map;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import org.apache.hadoop.hive.conf.HiveConf;
import org.apache.iceberg.catalog.Namespace;
import org.apache.iceberg.exceptions.AlreadyExistsException;
import org.apache.iceberg.hive.HiveCatalog;
import org.apache.iceberg.hive.TestHiveMetastore;
import org.apache.iceberg.relocated.com.google.common.collect.ImmutableList;
import org.apache.iceberg.relocated.com.google.common.collect.Iterables;
import org.apache.iceberg.relocated.com.google.common.collect.Maps;
import org.apache.spark.sql.Row;
import org.apache.spark.sql.SparkSession;
import org.apache.spark.sql.internal.SQLConf;
import org.junit.AfterClass;
import org.junit.Assert;
import org.junit.BeforeClass;

import static org.apache.hadoop.hive.conf.HiveConf.ConfVars.METASTOREURIS;

public abstract class SparkTestBase {

  protected static final Object ANY = new Object();

  protected static TestHiveMetastore metastore = null;
  protected static HiveConf hiveConf = null;
  protected static SparkSession spark = null;
  protected static HiveCatalog catalog = null;
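
  /**
   * Starts an embedded Hive metastore and a local SparkSession wired to it, then creates the
   * {@code default} namespace in the Hive catalog if it does not already exist. Runs once before
   * all tests in the class.
   */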
  @BeforeClass
  public static void startMetastoreAndSpark() {
    SparkTestBase.metastore = new TestHiveMetastore();
    metastore.start();

    SparkTestBase.hiveConf = metastore.hiveConf();

    SparkTestBase.spark = SparkSession.builder()
        .master("local[2]")
        .config(SQLConf.PARTITION_OVERWRITE_MODE().key(), "dynamic")
        .config("spark.hadoop." + METASTOREURIS.varname, hiveConf.get(METASTOREURIS.varname))
        .enableHiveSupport()
        .getOrCreate();

    SparkTestBase.catalog = new HiveCatalog(spark.sessionState().newHadoopConf());

    try {
      catalog.createNamespace(Namespace.of("default"));
    } catch (AlreadyExistsException ignored) {
      // the default namespace already exists. ignore the create error
    }
  }
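
  /**
   * Closes the catalog and stops the Spark session and the embedded metastore after all tests in
   * the class have run, clearing the static references so the next test class starts clean.
   */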
  @AfterClass
  public static void stopMetastoreAndSpark() {
    if (catalog != null) {
      catalog.close();
    }
    SparkTestBase.catalog = null;

    // guard against a partially failed startMetastoreAndSpark
    if (metastore != null) {
      metastore.stop();
    }
    SparkTestBase.metastore = null;

    if (spark != null) {
      spark.stop();
    }
    SparkTestBase.spark = null;
  }
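
  /**
   * Busy-waits until {@link System#currentTimeMillis()} is strictly after the given timestamp and
   * returns the new current time. Useful when a test needs two events to land on distinct
   * millisecond timestamps.
   */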
  protected long waitUntilAfter(long timestampMillis) {
    long current = System.currentTimeMillis();
    while (current <= timestampMillis) {
      current = System.currentTimeMillis();
    }
    return current;
  }
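
  /**
   * Runs a SQL statement built with {@link String#format(String, Object...)} and returns the
   * result as a list of rows, mapping SQL NULLs to Java nulls. Illustrative call, assuming some
   * existing table identifier {@code tableName}: {@code sql("SELECT * FROM %s", tableName)}.
   */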
  protected List<Object[]> sql(String query, Object... args) {
    List<Row> rows = spark.sql(String.format(query, args)).collectAsList();
    if (rows.isEmpty()) {
      return ImmutableList.of();
    }

    return rows.stream()
        .map(row -> IntStream.range(0, row.size())
            .mapToObj(pos -> row.isNullAt(pos) ? null : row.get(pos))
            .toArray(Object[]::new)
        ).collect(Collectors.toList());
  }
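
  /**
   * Runs a query that must produce exactly one row with exactly one column and returns that
   * single value, e.g. {@code scalarSql("SELECT count(*) FROM %s", tableName)} for some table
   * {@code tableName}.
   */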
  protected Object scalarSql(String query, Object... args) {
    List<Object[]> rows = sql(query, args);
    Assert.assertEquals("Scalar SQL should return one row", 1, rows.size());
    Object[] row = Iterables.getOnlyElement(rows);
    Assert.assertEquals("Scalar SQL should return one value", 1, row.length);
    return row[0];
  }
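
  /** Shorthand for building an expected row, e.g. {@code row(1L, "a")}. */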
  protected Object[] row(Object... values) {
    return values;
  }
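
  /**
   * Asserts that the actual rows match the expected rows cell by cell. An expected cell set to
   * {@link #ANY} acts as a wildcard and skips the comparison at that position.
   */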
  protected void assertEquals(String context, List<Object[]> expectedRows, List<Object[]> actualRows) {
    Assert.assertEquals(context + ": number of results should match", expectedRows.size(), actualRows.size());

    for (int row = 0; row < expectedRows.size(); row += 1) {
      Object[] expected = expectedRows.get(row);
      Object[] actual = actualRows.get(row);
      Assert.assertEquals(context + ": number of columns should match", expected.length, actual.length);
      for (int col = 0; col < actual.length; col += 1) {
        if (expected[col] != ANY) {
          Assert.assertEquals(context + ": row " + row + " col " + col + " contents should match",
              expected[col], actual[col]);
        }
      }
    }
  }
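
  /** Returns the filesystem path of the given database in the embedded metastore warehouse. */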
  protected static String dbPath(String dbName) {
    return metastore.getDatabasePath(dbName);
  }
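
  /**
   * Runs the given action with the provided SQL conf overrides applied to the current session's
   * {@link SQLConf}, then restores the previous values (unsetting keys that had none), even if
   * the action throws. Overriding a static conf key fails fast. Illustrative usage (the conf key
   * and value here are only an example, not fixed by this class):
   * <pre>
   *   withSQLConf(ImmutableMap.of(SQLConf.SHUFFLE_PARTITIONS().key(), "1"), () -> {
   *     sql("SELECT * FROM %s", tableName);
   *   });
   * </pre>
   */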
  protected void withSQLConf(Map<String, String> conf, Action action) {
    SQLConf sqlConf = SQLConf.get();

    // remember the current values of the keys that are about to be overridden
    Map<String, String> currentConfValues = Maps.newHashMap();
    conf.keySet().forEach(confKey -> {
      if (sqlConf.contains(confKey)) {
        String currentConfValue = sqlConf.getConfString(confKey);
        currentConfValues.put(confKey, currentConfValue);
      }
    });

    conf.forEach((confKey, confValue) -> {
      if (SQLConf.staticConfKeys().contains(confKey)) {
        throw new RuntimeException("Cannot modify the value of a static config: " + confKey);
      }
      sqlConf.setConfString(confKey, confValue);
    });

    try {
      action.invoke();
    } finally {
      // restore the previous values, unsetting keys that were not set before
      conf.forEach((confKey, confValue) -> {
        if (currentConfValues.containsKey(confKey)) {
          sqlConf.setConfString(confKey, currentConfValues.get(confKey));
        } else {
          sqlConf.unsetConf(confKey);
        }
      });
    }
  }
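
  /** A block of test code to run inside {@link #withSQLConf(Map, Action)}. */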
  @FunctionalInterface
  protected interface Action {
    void invoke();
  }
}