blob: c6eb47ea1028c89b8e88d9958eb328e18599770c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.submarine.spark.security
import org.apache.spark.sql.SubmarineSparkUtils._
import org.apache.spark.sql.catalyst.plans.logical.{Filter, SubmarineRowFilter}
import org.apache.spark.sql.hive.test.TestHive
import org.scalatest.{BeforeAndAfterAll, FunSuite}
/**
 * End-to-end SQL tests for Submarine's Ranger-backed row-level filtering.
 *
 * Each test runs a statement as user "bob" (who is subject to row-filter
 * policies) and, where applicable, as user "alice" (who is not), then checks
 * that the optimized plan contains the injected [[SubmarineRowFilter]] /
 * [[Filter]] nodes and that the returned rows respect the policy.
 *
 * The scratch tables rangertbl1..rangertbl6 are clones of default.src; each
 * one is assumed to carry a different row-filter policy (equality, IN-set,
 * subquery, self-joined subquery, UDF, multi-expression) — the expected
 * values in the per-table tests encode those policies.
 */
class RowFilterSQLTest extends FunSuite with BeforeAndAfterAll {

  // Fresh session so row-filter extensions don't leak into other suites.
  private val spark = TestHive.sparkSession.newSession()
  private lazy val sql = spark.sql _

  // Number of scratch tables cloned from default.src for per-policy tests.
  private val numRangerTables = 6

  override def beforeAll(): Unit = {
    super.beforeAll()
    // Create rangertbl1..rangertbl6 as copies of default.src. The tables are
    // identical; only the attached Ranger policies differ per table.
    (1 to numRangerTables).foreach { i =>
      sql(
        s"""
           |CREATE TABLE IF NOT EXISTS default.rangertbl$i AS SELECT * FROM default.src
         """.stripMargin)
    }
    enableRowFilter(spark)
  }

  override def afterAll(): Unit = {
    super.afterAll()
    // Drop session state and scratch tables created above.
    spark.reset()
  }

  test("simple query") {
    val statement = "select * from default.src"
    withUser("bob") {
      val df = sql(statement)
      // The optimizer must have injected a SubmarineRowFilter node for bob.
      assert(df.queryExecution.optimizedPlan.find(_.isInstanceOf[SubmarineRowFilter]).nonEmpty)
      val row = df.take(1)(0)
      assert(row.getInt(0) < 20, "keys above 20 should be filtered automatically")
      assert(df.count() === 20, "keys above 20 should be filtered automatically")
    }
    withUser("alice") {
      // alice has no row-filter policy and sees the full table.
      val df = sql(statement)
      assert(df.count() === 500)
    }
  }

  test("projection with ranger filter key") {
    val statement = "select key from default.src"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getInt(0) < 20)
    }
    withUser("alice") {
      val df = sql(statement)
      assert(df.count() === 500)
    }
  }

  test("projection without ranger filter key") {
    // The filter must apply even when the filtered column is not projected.
    val statement = "select value from default.src"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      // Values look like "val_<key>"; recover the key to check the policy.
      assert(row.getString(0).split("_")(1).toInt < 20)
    }
    withUser("alice") {
      val df = sql(statement)
      assert(df.count() === 500)
    }
  }

  test("filter with ranger filter key") {
    // User predicates must be AND-ed with the row-filter policy, not replace it.
    val statement = "select key from default.src where key = 0"
    val statement2 = "select key from default.src where key >= 20"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getInt(0) === 0)
      val df2 = sql(statement2)
      assert(df2.count() === 0, "all keys should be filtered")
    }
    withUser("alice") {
      val df = sql(statement)
      assert(df.count() === 3)
      val df2 = sql(statement2)
      assert(df2.count() === 480)
    }
  }

  test("WITH alias") {
    // Column aliases must not defeat the row filter.
    val statement = "select key as k1, value v1 from default.src"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getInt(0) < 20, "keys above 20 should be filtered automatically")
      assert(df.count() === 20, "keys above 20 should be filtered automatically")
    }
    withUser("alice") {
      val df = sql(statement)
      assert(df.count() === 500)
    }
  }

  test("aggregate") {
    // The filter must be applied below the aggregation.
    val statement = "select sum(key) as k1, value v1 from default.src group by v1"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getString(1).split("_")(1).toInt < 20)
    }
    withUser("alice") {
      val df = sql(statement)
      assert(df.count() === 309)
    }
  }

  test("with equal expression") {
    val statement = "select * from default.rangertbl1"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getInt(0) === 0, "rangertbl1 has an internal expression key=0")
    }
  }

  test("with in set") {
    val statement = "select * from default.rangertbl2"
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getInt(0) === 0, "rangertbl2 has an internal expression key in (0, 1, 2)")
    }
  }

  test("with in subquery") {
    val statement = "select * from default.rangertbl3"
    withUser("bob") {
      val df = sql(statement)
      val rows = df.collect()
      assert(rows.forall(_.getInt(0) < 100), "rangertbl3 has an internal expression key in (query)")
    }
  }

  test("with in subquery self joined") {
    // A self-joined IN-subquery policy matches every row, so nothing is filtered.
    val statement = "select * from default.rangertbl4"
    withUser("bob") {
      val df = sql(statement)
      val rows = df.collect()
      assert(rows.length === 500)
    }
  }

  test("with udf") {
    // rangertbl5's UDF-based policy matches no rows for bob.
    val statement = "select * from default.rangertbl5"
    withUser("bob") {
      val df = sql(statement)
      val rows = df.collect()
      assert(rows.length === 0)
    }
  }

  test("with multiple expressions") {
    val statement = "select * from default.rangertbl6"
    withUser("bob") {
      val df = sql(statement)
      val rows = df.collect()
      // Policy on rangertbl6: (1 < key < 10) OR key = 500.
      assert(rows.forall { r => val x = r.getInt(0); x > 1 && x < 10 || x == 500 })
    }
  }

  test("applying filters with uncorrelated subquery") {
    val statement =
      s"""
         |select
         |  *
         |from default.rangertbl1 outer
         |where value in (select value from default.rangertbl2)
         |""".stripMargin
    withUser("bob") {
      val df = sql(statement)
      val plan = df.queryExecution.optimizedPlan
      // Both the outer table and the subquery's table must get a filter.
      assert(plan.collect { case _: Filter => true }.size === 2, "tbl 1 and 2 have 2 filters")
      val row = df.take(1)(0)
      assert(row.getInt(0) === 0, "tbl 1 and 2 have 2 filters")
    }
  }

  test("applying filters with correlated subquery") {
    val statement =
      s"""
         |select
         |  *
         |from default.rangertbl1 outer
         |where key =
         |  (select max(key) from default.rangertbl2 where value = outer.value)
         |""".stripMargin
    withUser("bob") {
      val df = sql(statement)
      val plan = df.queryExecution.optimizedPlan
      // Every leaf relation should be wrapped by a SubmarineRowFilter.
      assert(plan.collectLeaves().size <= plan.collect { case _: SubmarineRowFilter => true }.size)
      val row = df.take(1)(0)
      assert(row.getString(1) === "val_0", "tbl 1 and 2 have 2 filters")
    }
  }

  test("CTE") {
    // The row filter must apply inside a CTE referenced multiple times.
    val statement =
      s"""
         |with myCTE as
         |(select
         |  *
         |from default.rangertbl1)
         |select t1.value, t2.value from myCTE t1 join myCTE t2 on t1.key = t2.key
         |
         |""".stripMargin
    withUser("bob") {
      val df = sql(statement)
      val row = df.take(1)(0)
      assert(row.getString(0) === "val_0", "rangertbl1 has an internal expression key=0")
    }
  }
}