#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements. See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
from pyspark import StorageLevel
from pyspark.errors import AnalysisException
from pyspark.sql.types import StructType, StructField, IntegerType
from pyspark.testing.sqlutils import ReusedSQLTestCase


class CatalogTestsMixin:
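    # Shared test methods; a concrete suite mixes this in together with a
    # SparkSession-providing base class (see CatalogTests at the bottom).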
    def test_current_database(self):
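        # currentDatabase() should reflect setCurrentDatabase(), and setting a
        # nonexistent database should raise AnalysisException.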
        spark = self.spark
        with self.database("some_db"):
            self.assertEqual(spark.catalog.currentDatabase(), "default")
            spark.sql("CREATE DATABASE some_db")
            spark.catalog.setCurrentDatabase("some_db")
            self.assertEqual(spark.catalog.currentDatabase(), "some_db")
            self.assertRaisesRegex(
                AnalysisException,
                "does_not_exist",
                lambda: spark.catalog.setCurrentDatabase("does_not_exist"),
            )

    def test_list_databases(self):
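        # listDatabases() should pick up newly created databases and honor the
        # optional name pattern.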
        spark = self.spark
        with self.database("some_db"):
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(databases, ["default"])
            spark.sql("CREATE DATABASE some_db")
            databases = [db.name for db in spark.catalog.listDatabases()]
            self.assertEqual(sorted(databases), ["default", "some_db"])
            databases = [db.name for db in spark.catalog.listDatabases("def*")]
            self.assertEqual(sorted(databases), ["default"])
            databases = [db.name for db in spark.catalog.listDatabases("def2*")]
            self.assertEqual(sorted(databases), [])

    def test_database_exists(self):
        # SPARK-36207: check that databaseExists returns the correct boolean
        spark = self.spark
        with self.database("some_db"):
            self.assertFalse(spark.catalog.databaseExists("some_db"))
            spark.sql("CREATE DATABASE some_db")
            self.assertTrue(spark.catalog.databaseExists("some_db"))
            self.assertTrue(spark.catalog.databaseExists("spark_catalog.some_db"))
            self.assertFalse(spark.catalog.databaseExists("spark_catalog.some_db2"))

    def test_get_database(self):
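        # getDatabase() should resolve a catalog-qualified name and report both
        # the database name and the catalog it belongs to.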
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            db = spark.catalog.getDatabase("spark_catalog.some_db")
            self.assertEqual(db.name, "some_db")
            self.assertEqual(db.catalog, "spark_catalog")

    def test_list_tables(self):
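        # listTables() should cover managed tables, tables created through
        # Catalog.createTable(), and temporary views, scoped per database.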
        from pyspark.sql.catalog import Table

        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2", "tab3_via_catalog"):
                with self.tempView("temp_tab"):
                    self.assertEqual(spark.catalog.listTables(), [])
                    self.assertEqual(spark.catalog.listTables("some_db"), [])
                    spark.createDataFrame([(1, 1)]).createOrReplaceTempView("temp_tab")
                    spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                    spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                    schema = StructType([StructField("a", IntegerType(), True)])
description = "this a table created via Catalog.createTable()"
                    spark.catalog.createTable(
                        "tab3_via_catalog", schema=schema, description=description
                    )
                    tables = sorted(spark.catalog.listTables(), key=lambda t: t.name)
                    tablesDefault = sorted(
                        spark.catalog.listTables("default"), key=lambda t: t.name
                    )
                    tablesSomeDb = sorted(
                        spark.catalog.listTables("some_db"), key=lambda t: t.name
                    )
                    self.assertEqual(tables, tablesDefault)
                    self.assertEqual(len(tables), 3)
                    self.assertEqual(len(tablesSomeDb), 2)

                    # helper to build an expected Table using the legacy fields
                    # (namespace derived from the database name, no catalog)
                    def makeTable(
                        name,
                        database,
                        description,
                        tableType,
                        isTemporary,
                    ):
                        return Table(
                            name=name,
                            catalog=None,
                            namespace=[database] if database is not None else None,
                            description=description,
                            tableType=tableType,
                            isTemporary=isTemporary,
                        )

                    # compare only the legacy fields, ignoring catalog and namespace
                    def compareTables(t1, t2):
                        return (
                            t1.name == t2.name
                            and t1.database == t2.database
                            and t1.description == t2.description
                            and t1.tableType == t2.tableType
                            and t1.isTemporary == t2.isTemporary
                        )

                    self.assertTrue(
                        compareTables(
                            tables[0],
                            makeTable(
                                name="tab1",
                                database="default",
                                description=None,
                                tableType="MANAGED",
                                isTemporary=False,
                            ),
                        )
                    )
                    self.assertTrue(
                        compareTables(
                            tables[1],
                            makeTable(
                                name="tab3_via_catalog",
                                database="default",
                                description=description,
                                tableType="MANAGED",
                                isTemporary=False,
                            ),
                        )
                    )
                    self.assertTrue(
                        compareTables(
                            tables[2],
                            makeTable(
                                name="temp_tab",
                                database=None,
                                description=None,
                                tableType="TEMPORARY",
                                isTemporary=True,
                            ),
                        )
                    )
                    self.assertTrue(
                        compareTables(
                            tablesSomeDb[0],
                            makeTable(
                                name="tab2",
                                database="some_db",
                                description=None,
                                tableType="MANAGED",
                                isTemporary=False,
                            ),
                        )
                    )
                    self.assertTrue(
                        compareTables(
                            tablesSomeDb[1],
                            makeTable(
                                name="temp_tab",
                                database=None,
                                description=None,
                                tableType="TEMPORARY",
                                isTemporary=True,
                            ),
                        )
                    )
                    self.assertRaisesRegex(
                        AnalysisException,
                        "does_not_exist",
                        lambda: spark.catalog.listTables("does_not_exist"),
                    )

    def test_list_functions(self):
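        # listFunctions() should include the builtins and pick up newly
        # registered UDFs and SQL functions, scoped per database.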
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            functions = dict((f.name, f) for f in spark.catalog.listFunctions())
            functionsDefault = dict((f.name, f) for f in spark.catalog.listFunctions("default"))
            self.assertTrue(len(functions) > 200)
            self.assertTrue("+" in functions)
            self.assertTrue("like" in functions)
            self.assertTrue("month" in functions)
            self.assertTrue("to_date" in functions)
            self.assertTrue("to_timestamp" in functions)
            self.assertTrue("to_unix_timestamp" in functions)
            self.assertTrue("current_database" in functions)
            self.assertEqual(functions["+"].name, "+")
            self.assertEqual(
                functions["+"].description, "expr1 + expr2 - Returns `expr1`+`expr2`."
            )
            self.assertEqual(
                functions["+"].className, "org.apache.spark.sql.catalyst.expressions.Add"
            )
            self.assertTrue(functions["+"].isTemporary)
            self.assertEqual(functions, functionsDefault)
with self.function("func1", "some_db.func2"):
try:
spark.udf
support_udf = True
except Exception:
support_udf = False
if support_udf:
spark.udf.register("temp_func", lambda x: str(x))
spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
spark.sql("CREATE FUNCTION some_db.func2 AS 'org.apache.spark.data.bricks'")
newFunctions = dict((f.name, f) for f in spark.catalog.listFunctions())
newFunctionsSomeDb = dict(
(f.name, f) for f in spark.catalog.listFunctions("some_db")
)
self.assertTrue(set(functions).issubset(set(newFunctions)))
self.assertTrue(set(functions).issubset(set(newFunctionsSomeDb)))
if support_udf:
self.assertTrue("temp_func" in newFunctions)
self.assertTrue("func1" in newFunctions)
self.assertTrue("func2" not in newFunctions)
if support_udf:
self.assertTrue("temp_func" in newFunctionsSomeDb)
self.assertTrue("func1" not in newFunctionsSomeDb)
self.assertTrue("func2" in newFunctionsSomeDb)
self.assertRaisesRegex(
AnalysisException,
"does_not_exist",
lambda: spark.catalog.listFunctions("does_not_exist"),
)

    def test_function_exists(self):
        # SPARK-36258: check that functionExists returns the correct boolean
        spark = self.spark
        with self.function("func1"):
            self.assertFalse(spark.catalog.functionExists("func1"))
            self.assertFalse(spark.catalog.functionExists("default.func1"))
            self.assertFalse(spark.catalog.functionExists("spark_catalog.default.func1"))
            self.assertFalse(spark.catalog.functionExists("func1", "default"))
            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
            self.assertTrue(spark.catalog.functionExists("func1"))
            self.assertTrue(spark.catalog.functionExists("default.func1"))
            self.assertTrue(spark.catalog.functionExists("spark_catalog.default.func1"))
            self.assertTrue(spark.catalog.functionExists("func1", "default"))

    def test_get_function(self):
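        # getFunction() should resolve a fully qualified name and report the
        # function's namespace, catalog, and class name.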
        spark = self.spark
        with self.function("func1"):
            spark.sql("CREATE FUNCTION func1 AS 'org.apache.spark.data.bricks'")
            func1 = spark.catalog.getFunction("spark_catalog.default.func1")
            self.assertEqual(func1.name, "func1")
            self.assertEqual(func1.namespace, ["default"])
            self.assertEqual(func1.catalog, "spark_catalog")
            self.assertEqual(func1.className, "org.apache.spark.data.bricks")
            self.assertFalse(func1.isTemporary)

    def test_list_columns(self):
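        # listColumns() should return column metadata (name, type, nullability)
        # for tables in the current and other databases.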
        from pyspark.sql.catalog import Column

        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                spark.sql(
                    "CREATE TABLE some_db.tab2 (nickname STRING, tolerance FLOAT) USING parquet"
                )
                columns = sorted(
                    spark.catalog.listColumns("spark_catalog.default.tab1"), key=lambda c: c.name
                )
                columnsDefault = sorted(
                    spark.catalog.listColumns("tab1", "default"), key=lambda c: c.name
                )
                self.assertEqual(columns, columnsDefault)
                self.assertEqual(len(columns), 2)
                self.assertEqual(
                    columns[0],
                    Column(
                        name="age",
                        description=None,
                        dataType="int",
                        nullable=True,
                        isPartition=False,
                        isBucket=False,
                    ),
                )
                self.assertEqual(
                    columns[1],
                    Column(
                        name="name",
                        description=None,
                        dataType="string",
                        nullable=True,
                        isPartition=False,
                        isBucket=False,
                    ),
                )
                columns2 = sorted(
                    spark.catalog.listColumns("tab2", "some_db"), key=lambda c: c.name
                )
                self.assertEqual(len(columns2), 2)
                self.assertEqual(
                    columns2[0],
                    Column(
                        name="nickname",
                        description=None,
                        dataType="string",
                        nullable=True,
                        isPartition=False,
                        isBucket=False,
                    ),
                )
                self.assertEqual(
                    columns2[1],
                    Column(
                        name="tolerance",
                        description=None,
                        dataType="float",
                        nullable=True,
                        isPartition=False,
                        isBucket=False,
                    ),
                )
                # tab2 is not visible without qualifying it with some_db
                self.assertRaisesRegex(
                    AnalysisException, "tab2", lambda: spark.catalog.listColumns("tab2")
                )
                self.assertRaisesRegex(
                    AnalysisException,
                    "does_not_exist",
                    lambda: spark.catalog.listColumns("does_not_exist"),
                )

    def test_table_cache(self):
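        # cacheTable()/uncacheTable()/clearCache() should be observable through
        # isCached() under both qualified and unqualified table names.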
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("some_db.tab1"):
                spark.sql("CREATE TABLE some_db.tab1 (name STRING, age INT) USING parquet")

                def if_cached(x):
                    return spark.catalog.isCached(x)

                names = ["some_db.tab1", "spark_catalog.some_db.tab1"]

                def assert_cached(c: bool):
                    if c:
                        self.assertTrue(all(map(if_cached, names)))
                    else:
                        self.assertFalse(any(map(if_cached, names)))

                assert_cached(False)
                spark.catalog.cacheTable("spark_catalog.some_db.tab1")
                assert_cached(True)
                spark.catalog.uncacheTable("spark_catalog.some_db.tab1")
                assert_cached(False)
                spark.catalog.cacheTable("spark_catalog.some_db.tab1", StorageLevel.MEMORY_ONLY)
                assert_cached(True)
                spark.catalog.clearCache()
                assert_cached(False)

    def test_table_exists(self):
        # SPARK-36176: check that tableExists returns the correct boolean
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1", "some_db.tab2"):
                self.assertFalse(spark.catalog.tableExists("tab1"))
                self.assertFalse(spark.catalog.tableExists("tab2", "some_db"))
                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                self.assertTrue(spark.catalog.tableExists("tab1"))
                self.assertTrue(spark.catalog.tableExists("default.tab1"))
                self.assertTrue(spark.catalog.tableExists("spark_catalog.default.tab1"))
                self.assertTrue(spark.catalog.tableExists("tab1", "default"))
                spark.sql("CREATE TABLE some_db.tab2 (name STRING, age INT) USING parquet")
                self.assertFalse(spark.catalog.tableExists("tab2"))
                self.assertTrue(spark.catalog.tableExists("some_db.tab2"))
                self.assertTrue(spark.catalog.tableExists("spark_catalog.some_db.tab2"))
                self.assertTrue(spark.catalog.tableExists("tab2", "some_db"))

    def test_get_table(self):
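        # getTable() should accept unqualified, database-qualified, and
        # catalog-qualified names and resolve them to the same table.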
        spark = self.spark
        with self.database("some_db"):
            spark.sql("CREATE DATABASE some_db")
            with self.table("tab1"):
                spark.sql("CREATE TABLE tab1 (name STRING, age INT) USING parquet")
                self.assertEqual(spark.catalog.getTable("tab1").database, "default")
                self.assertEqual(spark.catalog.getTable("default.tab1").catalog, "spark_catalog")
                self.assertEqual(
                    spark.catalog.getTable("spark_catalog.default.tab1").name, "tab1"
                )

    def test_refresh_table(self):
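        # refreshTable() should invalidate cached data after the underlying
        # files are deleted behind Spark's back.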
        import os
        import tempfile

        spark = self.spark
        with tempfile.TemporaryDirectory() as tmp_dir:
            with self.table("my_tab"):
                spark.sql(
                    "CREATE TABLE my_tab (col STRING) USING TEXT LOCATION '{}'".format(tmp_dir)
                )
                spark.sql("INSERT INTO my_tab SELECT 'abc'")
                spark.catalog.cacheTable("my_tab")
                self.assertEqual(spark.table("my_tab").count(), 1)

                # delete the data files directly, bypassing the catalog
                os.system("rm -rf {}/*".format(tmp_dir))

                # the cached data stays visible until the table is refreshed
                self.assertEqual(spark.table("my_tab").count(), 1)

                spark.catalog.refreshTable("spark_catalog.default.my_tab")
                self.assertEqual(spark.table("my_tab").count(), 0)


class CatalogTests(CatalogTestsMixin, ReusedSQLTestCase):
    pass
if __name__ == "__main__":
import unittest
from pyspark.sql.tests.test_catalog import * # noqa: F401
try:
import xmlrunner
testRunner = xmlrunner.XMLTestRunner(output="target/test-reports", verbosity=2)
except ImportError:
testRunner = None
unittest.main(testRunner=testRunner, verbosity=2)