| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.submarine.spark.security |
| |
| import org.apache.spark.sql.SubmarineSparkUtils |
| import org.apache.spark.sql.catalyst.expressions.codegen.{CodeFormatter, CodeGenerator} |
| import org.apache.spark.sql.catalyst.plans.logical.{SubmarineDataMasking, SubmarineRowFilter} |
| import org.apache.spark.sql.catalyst.util.resourceToString |
| import org.apache.spark.sql.execution.{SparkPlan, WholeStageCodegenExec} |
| import org.apache.spark.sql.hive.test.TestHive |
| import org.apache.spark.sql.internal.SQLConf |
| import org.scalatest.{BeforeAndAfterAll, FunSuite} |
| |
| class TPCDSTest extends FunSuite with BeforeAndAfterAll { |
| import org.apache.spark.sql.SubmarineSparkUtils._ |
| private val spark = TestHive.sparkSession |
| private lazy val sql = spark.sql _ |
| |
| override def beforeAll(): Unit = { |
| super.beforeAll() |
| spark.conf.set(SQLConf.CROSS_JOINS_ENABLED.key, "true") |
| |
| sql( |
| """ |
| |CREATE TABLE `catalog_page` ( |
| |`cp_catalog_page_sk` INT, `cp_catalog_page_id` STRING, `cp_start_date_sk` INT, |
| |`cp_end_date_sk` INT, `cp_department` STRING, `cp_catalog_number` INT, |
| |`cp_catalog_page_number` INT, `cp_description` STRING, `cp_type` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `catalog_returns` ( |
| |`cr_returned_date_sk` INT, `cr_returned_time_sk` INT, `cr_item_sk` INT, |
| |`cr_refunded_customer_sk` INT, `cr_refunded_cdemo_sk` INT, `cr_refunded_hdemo_sk` INT, |
| |`cr_refunded_addr_sk` INT, `cr_returning_customer_sk` INT, `cr_returning_cdemo_sk` INT, |
| |`cr_returning_hdemo_sk` INT, `cr_returning_addr_sk` INT, `cr_call_center_sk` INT, |
| |`cr_catalog_page_sk` INT, `cr_ship_mode_sk` INT, `cr_warehouse_sk` INT, `cr_reason_sk` INT, |
| |`cr_order_number` INT, `cr_return_quantity` INT, `cr_return_amount` DECIMAL(7,2), |
| |`cr_return_tax` DECIMAL(7,2), `cr_return_amt_inc_tax` DECIMAL(7,2), `cr_fee` DECIMAL(7,2), |
| |`cr_return_ship_cost` DECIMAL(7,2), `cr_refunded_cash` DECIMAL(7,2), |
| |`cr_reversed_charge` DECIMAL(7,2), `cr_store_credit` DECIMAL(7,2), |
| |`cr_net_loss` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `customer` ( |
| |`c_customer_sk` INT, `c_customer_id` STRING, `c_current_cdemo_sk` INT, |
| |`c_current_hdemo_sk` INT, `c_current_addr_sk` INT, `c_first_shipto_date_sk` INT, |
| |`c_first_sales_date_sk` INT, `c_salutation` STRING, `c_first_name` STRING, |
| |`c_last_name` STRING, `c_preferred_cust_flag` STRING, `c_birth_day` INT, |
| |`c_birth_month` INT, `c_birth_year` INT, `c_birth_country` STRING, `c_login` STRING, |
| |`c_email_address` STRING, `c_last_review_date` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `customer_address` ( |
| |`ca_address_sk` INT, `ca_address_id` STRING, `ca_street_number` STRING, |
| |`ca_street_name` STRING, `ca_street_type` STRING, `ca_suite_number` STRING, |
| |`ca_city` STRING, `ca_county` STRING, `ca_state` STRING, `ca_zip` STRING, |
| |`ca_country` STRING, `ca_gmt_offset` DECIMAL(5,2), `ca_location_type` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `customer_demographics` ( |
| |`cd_demo_sk` INT, `cd_gender` STRING, `cd_marital_status` STRING, |
| |`cd_education_status` STRING, `cd_purchase_estimate` INT, `cd_credit_rating` STRING, |
| |`cd_dep_count` INT, `cd_dep_employed_count` INT, `cd_dep_college_count` INT) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `date_dim` ( |
| |`d_date_sk` INT, `d_date_id` STRING, `d_date` STRING, |
| |`d_month_seq` INT, `d_week_seq` INT, `d_quarter_seq` INT, `d_year` INT, `d_dow` INT, |
| |`d_moy` INT, `d_dom` INT, `d_qoy` INT, `d_fy_year` INT, `d_fy_quarter_seq` INT, |
| |`d_fy_week_seq` INT, `d_day_name` STRING, `d_quarter_name` STRING, `d_holiday` STRING, |
| |`d_weekend` STRING, `d_following_holiday` STRING, `d_first_dom` INT, `d_last_dom` INT, |
| |`d_same_day_ly` INT, `d_same_day_lq` INT, `d_current_day` STRING, `d_current_week` STRING, |
| |`d_current_month` STRING, `d_current_quarter` STRING, `d_current_year` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `household_demographics` ( |
| |`hd_demo_sk` INT, `hd_income_band_sk` INT, `hd_buy_potential` STRING, `hd_dep_count` INT, |
| |`hd_vehicle_count` INT) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `inventory` (`inv_date_sk` INT, `inv_item_sk` INT, `inv_warehouse_sk` INT, |
| |`inv_quantity_on_hand` INT) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `item` (`i_item_sk` INT, `i_item_id` STRING, `i_rec_start_date` STRING, |
| |`i_rec_end_date` STRING, `i_item_desc` STRING, `i_current_price` DECIMAL(7,2), |
| |`i_wholesale_cost` DECIMAL(7,2), `i_brand_id` INT, `i_brand` STRING, `i_class_id` INT, |
| |`i_class` STRING, `i_category_id` INT, `i_category` STRING, `i_manufact_id` INT, |
| |`i_manufact` STRING, `i_size` STRING, `i_formulation` STRING, `i_color` STRING, |
| |`i_units` STRING, `i_container` STRING, `i_manager_id` INT, `i_product_name` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `promotion` ( |
| |`p_promo_sk` INT, `p_promo_id` STRING, `p_start_date_sk` INT, `p_end_date_sk` INT, |
| |`p_item_sk` INT, `p_cost` DECIMAL(15,2), `p_response_target` INT, `p_promo_name` STRING, |
| |`p_channel_dmail` STRING, `p_channel_email` STRING, `p_channel_catalog` STRING, |
| |`p_channel_tv` STRING, `p_channel_radio` STRING, `p_channel_press` STRING, |
| |`p_channel_event` STRING, `p_channel_demo` STRING, `p_channel_details` STRING, |
| |`p_purpose` STRING, `p_discount_active` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `store` ( |
| |`s_store_sk` INT, `s_store_id` STRING, `s_rec_start_date` STRING, |
| |`s_rec_end_date` STRING, `s_closed_date_sk` INT, `s_store_name` STRING, |
| |`s_number_employees` INT, `s_floor_space` INT, `s_hours` STRING, `s_manager` STRING, |
| |`s_market_id` INT, `s_geography_class` STRING, `s_market_desc` STRING, |
| |`s_market_manager` STRING, `s_division_id` INT, `s_division_name` STRING, |
| |`s_company_id` INT, `s_company_name` STRING, `s_street_number` STRING, |
| |`s_street_name` STRING, `s_street_type` STRING, `s_suite_number` STRING, `s_city` STRING, |
| |`s_county` STRING, `s_state` STRING, `s_zip` STRING, `s_country` STRING, |
| |`s_gmt_offset` DECIMAL(5,2), `s_tax_precentage` DECIMAL(5,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `store_returns` ( |
| |`sr_returned_date_sk` BIGINT, `sr_return_time_sk` BIGINT, `sr_item_sk` BIGINT, |
| |`sr_customer_sk` BIGINT, `sr_cdemo_sk` BIGINT, `sr_hdemo_sk` BIGINT, `sr_addr_sk` BIGINT, |
| |`sr_store_sk` BIGINT, `sr_reason_sk` BIGINT, `sr_ticket_number` BIGINT, |
| |`sr_return_quantity` BIGINT, `sr_return_amt` DECIMAL(7,2), `sr_return_tax` DECIMAL(7,2), |
| |`sr_return_amt_inc_tax` DECIMAL(7,2), `sr_fee` DECIMAL(7,2), |
| |`sr_return_ship_cost` DECIMAL(7,2), `sr_refunded_cash` DECIMAL(7,2), |
| |`sr_reversed_charge` DECIMAL(7,2), `sr_store_credit` DECIMAL(7,2), |
| |`sr_net_loss` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `catalog_sales` ( |
| |`cs_sold_date_sk` INT, `cs_sold_time_sk` INT, `cs_ship_date_sk` INT, |
| |`cs_bill_customer_sk` INT, `cs_bill_cdemo_sk` INT, `cs_bill_hdemo_sk` INT, |
| |`cs_bill_addr_sk` INT, `cs_ship_customer_sk` INT, `cs_ship_cdemo_sk` INT, |
| |`cs_ship_hdemo_sk` INT, `cs_ship_addr_sk` INT, `cs_call_center_sk` INT, |
| |`cs_catalog_page_sk` INT, `cs_ship_mode_sk` INT, `cs_warehouse_sk` INT, |
| |`cs_item_sk` INT, `cs_promo_sk` INT, `cs_order_number` INT, `cs_quantity` INT, |
| |`cs_wholesale_cost` DECIMAL(7,2), `cs_list_price` DECIMAL(7,2), |
| |`cs_sales_price` DECIMAL(7,2), `cs_ext_discount_amt` DECIMAL(7,2), |
| |`cs_ext_sales_price` DECIMAL(7,2), `cs_ext_wholesale_cost` DECIMAL(7,2), |
| |`cs_ext_list_price` DECIMAL(7,2), `cs_ext_tax` DECIMAL(7,2), `cs_coupon_amt` DECIMAL(7,2), |
| |`cs_ext_ship_cost` DECIMAL(7,2), `cs_net_paid` DECIMAL(7,2), |
| |`cs_net_paid_inc_tax` DECIMAL(7,2), `cs_net_paid_inc_ship` DECIMAL(7,2), |
| |`cs_net_paid_inc_ship_tax` DECIMAL(7,2), `cs_net_profit` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `web_sales` ( |
| |`ws_sold_date_sk` INT, `ws_sold_time_sk` INT, `ws_ship_date_sk` INT, `ws_item_sk` INT, |
| |`ws_bill_customer_sk` INT, `ws_bill_cdemo_sk` INT, `ws_bill_hdemo_sk` INT, |
| |`ws_bill_addr_sk` INT, `ws_ship_customer_sk` INT, `ws_ship_cdemo_sk` INT, |
| |`ws_ship_hdemo_sk` INT, `ws_ship_addr_sk` INT, `ws_web_page_sk` INT, `ws_web_site_sk` INT, |
| |`ws_ship_mode_sk` INT, `ws_warehouse_sk` INT, `ws_promo_sk` INT, `ws_order_number` INT, |
| |`ws_quantity` INT, `ws_wholesale_cost` DECIMAL(7,2), `ws_list_price` DECIMAL(7,2), |
| |`ws_sales_price` DECIMAL(7,2), `ws_ext_discount_amt` DECIMAL(7,2), |
| |`ws_ext_sales_price` DECIMAL(7,2), `ws_ext_wholesale_cost` DECIMAL(7,2), |
| |`ws_ext_list_price` DECIMAL(7,2), `ws_ext_tax` DECIMAL(7,2), |
| |`ws_coupon_amt` DECIMAL(7,2), `ws_ext_ship_cost` DECIMAL(7,2), `ws_net_paid` DECIMAL(7,2), |
| |`ws_net_paid_inc_tax` DECIMAL(7,2), `ws_net_paid_inc_ship` DECIMAL(7,2), |
| |`ws_net_paid_inc_ship_tax` DECIMAL(7,2), `ws_net_profit` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `store_sales` ( |
| |`ss_sold_date_sk` INT, `ss_sold_time_sk` INT, `ss_item_sk` INT, `ss_customer_sk` INT, |
| |`ss_cdemo_sk` INT, `ss_hdemo_sk` INT, `ss_addr_sk` INT, `ss_store_sk` INT, |
| |`ss_promo_sk` INT, `ss_ticket_number` INT, `ss_quantity` INT, |
| |`ss_wholesale_cost` DECIMAL(7,2), `ss_list_price` DECIMAL(7,2), |
| |`ss_sales_price` DECIMAL(7,2), `ss_ext_discount_amt` DECIMAL(7,2), |
| |`ss_ext_sales_price` DECIMAL(7,2), `ss_ext_wholesale_cost` DECIMAL(7,2), |
| |`ss_ext_list_price` DECIMAL(7,2), `ss_ext_tax` DECIMAL(7,2), |
| |`ss_coupon_amt` DECIMAL(7,2), `ss_net_paid` DECIMAL(7,2), |
| |`ss_net_paid_inc_tax` DECIMAL(7,2), `ss_net_profit` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `web_returns` ( |
| |`wr_returned_date_sk` BIGINT, `wr_returned_time_sk` BIGINT, `wr_item_sk` BIGINT, |
| |`wr_refunded_customer_sk` BIGINT, `wr_refunded_cdemo_sk` BIGINT, |
| |`wr_refunded_hdemo_sk` BIGINT, `wr_refunded_addr_sk` BIGINT, |
| |`wr_returning_customer_sk` BIGINT, `wr_returning_cdemo_sk` BIGINT, |
| |`wr_returning_hdemo_sk` BIGINT, `wr_returning_addr_sk` BIGINT, `wr_web_page_sk` BIGINT, |
| |`wr_reason_sk` BIGINT, `wr_order_number` BIGINT, `wr_return_quantity` BIGINT, |
| |`wr_return_amt` DECIMAL(7,2), `wr_return_tax` DECIMAL(7,2), |
| |`wr_return_amt_inc_tax` DECIMAL(7,2), `wr_fee` DECIMAL(7,2), |
| |`wr_return_ship_cost` DECIMAL(7,2), `wr_refunded_cash` DECIMAL(7,2), |
| |`wr_reversed_charge` DECIMAL(7,2), `wr_account_credit` DECIMAL(7,2), |
| |`wr_net_loss` DECIMAL(7,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `web_site` ( |
| |`web_site_sk` INT, `web_site_id` STRING, `web_rec_start_date` DATE, |
| |`web_rec_end_date` DATE, `web_name` STRING, `web_open_date_sk` INT, |
| |`web_close_date_sk` INT, `web_class` STRING, `web_manager` STRING, `web_mkt_id` INT, |
| |`web_mkt_class` STRING, `web_mkt_desc` STRING, `web_market_manager` STRING, |
| |`web_company_id` INT, `web_company_name` STRING, `web_street_number` STRING, |
| |`web_street_name` STRING, `web_street_type` STRING, `web_suite_number` STRING, |
| |`web_city` STRING, `web_county` STRING, `web_state` STRING, `web_zip` STRING, |
| |`web_country` STRING, `web_gmt_offset` STRING, `web_tax_percentage` DECIMAL(5,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `reason` ( |
| |`r_reason_sk` INT, `r_reason_id` STRING, `r_reason_desc` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `call_center` ( |
| |`cc_call_center_sk` INT, `cc_call_center_id` STRING, `cc_rec_start_date` DATE, |
| |`cc_rec_end_date` DATE, `cc_closed_date_sk` INT, `cc_open_date_sk` INT, `cc_name` STRING, |
| |`cc_class` STRING, `cc_employees` INT, `cc_sq_ft` INT, `cc_hours` STRING, |
| |`cc_manager` STRING, `cc_mkt_id` INT, `cc_mkt_class` STRING, `cc_mkt_desc` STRING, |
| |`cc_market_manager` STRING, `cc_division` INT, `cc_division_name` STRING, `cc_company` INT, |
| |`cc_company_name` STRING, `cc_street_number` STRING, `cc_street_name` STRING, |
| |`cc_street_type` STRING, `cc_suite_number` STRING, `cc_city` STRING, `cc_county` STRING, |
| |`cc_state` STRING, `cc_zip` STRING, `cc_country` STRING, `cc_gmt_offset` DECIMAL(5,2), |
| |`cc_tax_percentage` DECIMAL(5,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `warehouse` ( |
| |`w_warehouse_sk` INT, `w_warehouse_id` STRING, `w_warehouse_name` STRING, |
| |`w_warehouse_sq_ft` INT, `w_street_number` STRING, `w_street_name` STRING, |
| |`w_street_type` STRING, `w_suite_number` STRING, `w_city` STRING, `w_county` STRING, |
| |`w_state` STRING, `w_zip` STRING, `w_country` STRING, `w_gmt_offset` DECIMAL(5,2)) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `ship_mode` ( |
| |`sm_ship_mode_sk` INT, `sm_ship_mode_id` STRING, `sm_type` STRING, `sm_code` STRING, |
| |`sm_carrier` STRING, `sm_contract` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `income_band` ( |
| |`ib_income_band_sk` INT, `ib_lower_bound` INT, `ib_upper_bound` INT) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `time_dim` ( |
| |`t_time_sk` INT, `t_time_id` STRING, `t_time` INT, `t_hour` INT, `t_minute` INT, |
| |`t_second` INT, `t_am_pm` STRING, `t_shift` STRING, `t_sub_shift` STRING, |
| |`t_meal_time` STRING) |
| |USING parquet |
| """.stripMargin) |
| |
| sql( |
| """ |
| |CREATE TABLE `web_page` (`wp_web_page_sk` INT, `wp_web_page_id` STRING, |
| |`wp_rec_start_date` DATE, `wp_rec_end_date` DATE, `wp_creation_date_sk` INT, |
| |`wp_access_date_sk` INT, `wp_autogen_flag` STRING, `wp_customer_sk` INT, |
| |`wp_url` STRING, `wp_type` STRING, `wp_char_count` INT, `wp_link_count` INT, |
| |`wp_image_count` INT, `wp_max_ad_count` INT) |
| |USING parquet |
| """.stripMargin) |
| |
| SubmarineSparkUtils.enableAll(spark) |
| } |
| |
| private val tpcdsQueries = Seq( |
| "q1", "q2", "q3", "q4", "q5", "q6", "q7", "q8", "q9", "q10", "q11", |
| "q12", "q13", "q14a", "q14b", "q15", "q16", "q17", "q18", "q19", "q20", |
| "q21", "q22", "q23a", "q23b", "q24a", "q24b", "q25", "q26", "q27", "q28", "q29", "q30", |
| "q31", "q32", "q33", "q34", "q35", "q36", "q37", "q38", "q39a", "q39b", "q40", |
| "q41", "q42", "q43", "q44", "q45", "q46", "q47", "q48", "q49", "q50", |
| "q51", "q52", "q53", "q54", "q55", "q56", "q57", "q58", "q59", "q60", |
| "q61", "q62", "q63", "q64", "q65", "q66", "q67", "q68", "q69", "q70", |
| "q71", "q72", "q73", "q74", "q75", "q76", "q77", "q78", "q79", "q80", |
| "q81", "q82", "q83", "q84", "q85", "q86", "q87", "q88", "q89", "q90", |
| "q91", "q92", "q93", "q94", "q95", "q96", "q97", "q98", "q99") |
| |
| tpcdsQueries.foreach { name => |
| val queryString = resourceToString(s"tpcds/$name.sql", |
| classLoader = Thread.currentThread().getContextClassLoader) |
| test(name) { |
| withUser("bob") { |
| val queryExecution = sql(queryString).queryExecution |
| val optimized = queryExecution.optimizedPlan |
| // println(optimized) |
| assert(optimized.find(_.isInstanceOf[SubmarineRowFilter]).nonEmpty) |
| assert(optimized.find(_.isInstanceOf[SubmarineDataMasking]).nonEmpty) |
| val plan = queryExecution.executedPlan |
| checkGeneratedCode(plan) |
| } |
| } |
| } |
| |
| /** |
| * Check whether the Modified queries can be properly compiled |
| */ |
| def checkGeneratedCode(plan: SparkPlan): Unit = { |
| val codegenSubtrees = new collection.mutable.HashSet[WholeStageCodegenExec]() |
| plan foreach { |
| case s: WholeStageCodegenExec => |
| codegenSubtrees += s |
| case s => s |
| } |
| codegenSubtrees.toSeq.foreach { subtree => |
| val code = subtree.doCodeGen()._2 |
| try { |
| // Just check the generated code can be properly compiled |
| CodeGenerator.compile(code) |
| } catch { |
| case e: Exception => |
| val msg = |
| s""" |
| |failed to compile: |
| |Subtree: |
| |$subtree |
| |Generated code: |
| |${CodeFormatter.format(code)} |
| """.stripMargin |
| throw new Exception(msg, e) |
| } |
| } |
| } |
| } |