release version 3.0.0
diff --git a/.github/workflows/build-ce7-releases.yml b/.github/workflows/build-ce7-releases.yml index cf8fda8..fb51a3e 100644 --- a/.github/workflows/build-ce7-releases.yml +++ b/.github/workflows/build-ce7-releases.yml
@@ -12,7 +12,7 @@ strategy: matrix: sparkver: [spark303, spark333] - blazever: [2.0.9.1] + blazever: [3.0.0] steps: - uses: actions/checkout@v4
diff --git a/.github/workflows/tpcds.yml b/.github/workflows/tpcds.yml index 98722c4..2b280d9 100644 --- a/.github/workflows/tpcds.yml +++ b/.github/workflows/tpcds.yml
@@ -34,19 +34,18 @@ with: {version: "21.7"} - uses: actions-rust-lang/setup-rust-toolchain@v1 - with: {rustflags: --allow warnings -C target-cpu=native} + with: + toolchain: nightly + rustflags: --allow warnings -C target-feature=+aes + components: + cargo + rustfmt - name: Rustfmt Check uses: actions-rust-lang/rustfmt@v1 - ## - name: Rust Clippy Check - ## uses: actions-rs/clippy-check@v1 - ## with: - ## token: ${{ secrets.GITHUB_TOKEN }} - ## args: --all-features - - name: Cargo test - run: cargo test --workspace --all-features + run: cargo +nightly test --workspace --all-features - name: Build Spark303 run: mvn package -Ppre -Pspark303
diff --git a/Cargo.lock b/Cargo.lock index ef23afd..e270c47 100644 --- a/Cargo.lock +++ b/Cargo.lock
@@ -97,7 +97,7 @@ [[package]] name = "arrow" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-arith", "arrow-array", @@ -117,7 +117,7 @@ [[package]] name = "arrow-arith" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -131,7 +131,7 @@ [[package]] name = "arrow-array" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-buffer", @@ -147,7 +147,7 @@ [[package]] name = "arrow-buffer" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "bytes", "half", @@ -157,7 +157,7 @@ [[package]] name = "arrow-cast" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ 
"arrow-array", "arrow-buffer", @@ -175,7 +175,7 @@ [[package]] name = "arrow-csv" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -193,7 +193,7 @@ [[package]] name = "arrow-data" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-buffer", "arrow-schema", @@ -204,7 +204,7 @@ [[package]] name = "arrow-ipc" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -218,7 +218,7 @@ [[package]] name = "arrow-json" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -237,7 +237,7 @@ [[package]] name = "arrow-ord" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = 
"git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -251,7 +251,7 @@ [[package]] name = "arrow-row" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array", @@ -265,7 +265,7 @@ [[package]] name = "arrow-schema" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "bitflags 2.5.0", "serde", @@ -274,7 +274,7 @@ [[package]] name = "arrow-select" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array", @@ -287,7 +287,7 @@ [[package]] name = "arrow-string" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "arrow-array", "arrow-buffer", @@ -751,7 +751,7 @@ [[package]] name = "datafusion" version = "36.0.0" -source = 
"git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -800,7 +800,7 @@ [[package]] name = "datafusion-common" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -819,7 +819,7 @@ [[package]] name = "datafusion-execution" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "chrono", @@ -839,7 +839,7 @@ [[package]] name = "datafusion-expr" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -923,6 +923,7 @@ "arrow", "async-trait", "base64 0.22.1", + "bitvec", "blaze-jni-bridge", "byteorder", "bytes", @@ -957,7 +958,7 @@ [[package]] name = "datafusion-functions" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = 
"git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "base64 0.21.7", @@ -971,7 +972,7 @@ [[package]] name = "datafusion-functions-array" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "datafusion-common", @@ -984,7 +985,7 @@ [[package]] name = "datafusion-optimizer" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "async-trait", @@ -1001,7 +1002,7 @@ [[package]] name = "datafusion-physical-expr" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -1036,7 +1037,7 @@ [[package]] name = "datafusion-physical-plan" version = "36.0.0" -source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "ahash", "arrow", @@ -1066,7 +1067,7 @@ [[package]] name = "datafusion-sql" version = "36.0.0" -source = 
"git+https://github.com/blaze-init/arrow-datafusion.git?rev=71433f743b2c399ea1728531b0e56fd7c6ef5282#71433f743b2c399ea1728531b0e56fd7c6ef5282" +source = "git+https://github.com/blaze-init/arrow-datafusion.git?rev=17b1ad3c7432391b94dd54e48a60db6d5712a7ef#17b1ad3c7432391b94dd54e48a60db6d5712a7ef" dependencies = [ "arrow", "arrow-schema", @@ -1858,7 +1859,7 @@ [[package]] name = "parquet" version = "50.0.0" -source = "git+https://github.com/blaze-init/arrow-rs.git?rev=2c39d9a251f7e3f8f15312bdd0c38759e465e8bc#2c39d9a251f7e3f8f15312bdd0c38759e465e8bc" +source = "git+https://github.com/blaze-init/arrow-rs.git?rev=7471d70f7ae6edd5d4da82b7d966a8ede720e499#7471d70f7ae6edd5d4da82b7d966a8ede720e499" dependencies = [ "ahash", "arrow-array",
diff --git a/Cargo.toml b/Cargo.toml index 5052eab..ad86c08 100644 --- a/Cargo.toml +++ b/Cargo.toml
@@ -64,26 +64,26 @@ [patch.crates-io] # datafusion: branch=v36-blaze -datafusion = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-common = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-execution = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-optimizer = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} -datafusion-physical-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "71433f743b2c399ea1728531b0e56fd7c6ef5282"} +datafusion = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-common = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-execution = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-optimizer = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} +datafusion-physical-expr = { git = "https://github.com/blaze-init/arrow-datafusion.git", rev = "17b1ad3c7432391b94dd54e48a60db6d5712a7ef"} # arrow: branch=v50-blaze -arrow = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-arith = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-array = { git = "https://github.com/blaze-init/arrow-rs.git", 
rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-buffer = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-cast = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-data = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-ord = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-row = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-schema = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-select = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -arrow-string = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} -parquet = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "2c39d9a251f7e3f8f15312bdd0c38759e465e8bc"} +arrow = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-arith = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-array = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-buffer = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-cast = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-data = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-ord = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-row = { git = 
"https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-schema = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-select = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +arrow-string = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} +parquet = { git = "https://github.com/blaze-init/arrow-rs.git", rev = "7471d70f7ae6edd5d4da82b7d966a8ede720e499"} # serde_json: branch=v1.0.96-blaze serde_json = { git = "https://github.com/blaze-init/json", branch = "v1.0.96-blaze" }
diff --git a/native-engine/blaze-jni-bridge/src/conf.rs b/native-engine/blaze-jni-bridge/src/conf.rs index dd476ed..9eccc0e 100644 --- a/native-engine/blaze-jni-bridge/src/conf.rs +++ b/native-engine/blaze-jni-bridge/src/conf.rs
@@ -41,6 +41,8 @@ define_conf!(BooleanConf, PARTIAL_AGG_SKIPPING_ENABLE); define_conf!(DoubleConf, PARTIAL_AGG_SKIPPING_RATIO); define_conf!(IntConf, PARTIAL_AGG_SKIPPING_MIN_ROWS); +define_conf!(BooleanConf, PARQUET_ENABLE_PAGE_FILTERING); +define_conf!(BooleanConf, PARQUET_ENABLE_BLOOM_FILTER); pub trait BooleanConf { fn key(&self) -> &'static str;
diff --git a/native-engine/blaze-serde/proto/blaze.proto b/native-engine/blaze-serde/proto/blaze.proto index 9818a13..b424f1f 100644 --- a/native-engine/blaze-serde/proto/blaze.proto +++ b/native-engine/blaze-serde/proto/blaze.proto
@@ -35,19 +35,19 @@ FilterExecNode filter = 8; UnionExecNode union = 9; SortMergeJoinExecNode sort_merge_join = 10; - BroadcastJoinExecNode broadcast_join = 11; - RenameColumnsExecNode rename_columns = 12; - EmptyPartitionsExecNode empty_partitions = 13; - AggExecNode agg = 14; - LimitExecNode limit = 15; - FFIReaderExecNode ffi_reader = 16; - CoalesceBatchesExecNode coalesce_batches = 17; - ExpandExecNode expand = 18; - RssShuffleWriterExecNode rss_shuffle_writer= 19; - WindowExecNode window = 20; - GenerateExecNode generate = 21; - ParquetSinkExecNode parquet_sink = 22; - BroadcastNestedLoopJoinExecNode broadcast_nested_loop_join = 23; + BroadcastJoinBuildHashMapExecNode broadcast_join_build_hash_map = 11; + BroadcastJoinExecNode broadcast_join = 12; + RenameColumnsExecNode rename_columns = 13; + EmptyPartitionsExecNode empty_partitions = 14; + AggExecNode agg = 15; + LimitExecNode limit = 16; + FFIReaderExecNode ffi_reader = 17; + CoalesceBatchesExecNode coalesce_batches = 18; + ExpandExecNode expand = 19; + RssShuffleWriterExecNode rss_shuffle_writer= 20; + WindowExecNode window = 21; + GenerateExecNode generate = 22; + ParquetSinkExecNode parquet_sink = 23; } } @@ -398,27 +398,28 @@ } message SortMergeJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - repeated SortOptions sort_options = 4; - JoinType join_type = 5; - JoinFilter join_filter = 6; + Schema schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + repeated SortOptions sort_options = 5; + JoinType join_type = 6; + JoinFilter join_filter = 7; +} + +message BroadcastJoinBuildHashMapExecNode { + PhysicalPlanNode input = 1; + repeated PhysicalExprNode keys =2; } message BroadcastJoinExecNode { - PhysicalPlanNode left = 1; - PhysicalPlanNode right = 2; - repeated JoinOn on = 3; - JoinType join_type = 4; - JoinFilter join_filter = 5; -} - -message BroadcastNestedLoopJoinExecNode { - PhysicalPlanNode left = 1; - 
PhysicalPlanNode right = 2; - JoinType join_type = 3; - JoinFilter join_filter = 4; + Schema schema = 1; + PhysicalPlanNode left = 2; + PhysicalPlanNode right = 3; + repeated JoinOn on = 4; + JoinType join_type = 5; + JoinSide broadcast_side = 6; + string cached_build_hash_map_id = 7; } message RenameColumnsExecNode { @@ -438,6 +439,7 @@ FULL = 3; SEMI = 4; ANTI = 5; + EXISTENCE = 6; } message SortOptions { @@ -456,8 +458,8 @@ } message JoinOn { - PhysicalColumn left = 1; - PhysicalColumn right = 2; + PhysicalExprNode left = 1; + PhysicalExprNode right = 2; } message ProjectionExecNode {
diff --git a/native-engine/blaze-serde/src/from_proto.rs b/native-engine/blaze-serde/src/from_proto.rs index 1f4e824..cc89de0 100644 --- a/native-engine/blaze-serde/src/from_proto.rs +++ b/native-engine/blaze-serde/src/from_proto.rs
@@ -45,7 +45,6 @@ BinaryExpr, CaseExpr, CastExpr, Column, IsNotNullExpr, IsNullExpr, Literal, NegativeExpr, NotExpr, PhysicalSortExpr, }, - joins::utils::{ColumnIndex, JoinFilter}, union::UnionExec, ColumnStatistics, ExecutionPlan, Partitioning, PhysicalExpr, Statistics, }, @@ -61,8 +60,8 @@ use datafusion_ext_plans::{ agg::{create_agg, AggExecMode, AggExpr, AggFunction, AggMode, GroupingExpr}, agg_exec::AggExec, + broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec, broadcast_join_exec::BroadcastJoinExec, - broadcast_nested_loop_join_exec::BroadcastNestedLoopJoinExec, debug_exec::DebugExec, empty_partitions_exec::EmptyPartitionsExec, expand_exec::ExpandExec, @@ -89,7 +88,7 @@ use crate::{ convert_box_required, convert_required, error::PlanSerDeError, - from_proto_binary_op, into_required, proto_error, protobuf, + from_proto_binary_op, proto_error, protobuf, protobuf::{ physical_expr_node::ExprType, physical_plan_node::PhysicalPlanType, GenerateFunction, }, @@ -182,19 +181,20 @@ ))) } PhysicalPlanType::SortMergeJoin(sort_merge_join) => { + let schema = Arc::new(convert_required!(sort_merge_join.schema)?); let left: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.left)?; let right: Arc<dyn ExecutionPlan> = convert_box_required!(sort_merge_join.right)?; let on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = sort_merge_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + 
try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, right_key_binded)) }) .collect::<Result<_, Self::Error>>()?; @@ -210,38 +210,14 @@ let join_type = protobuf::JoinType::try_from(sort_merge_join.join_type) .expect("invalid JoinType"); - let join_filter = sort_merge_join - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::<Result<Vec<_>, PlanSerDeError>>()?; - - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; Ok(Arc::new(SortMergeJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, sort_options, )?)) } @@ -306,7 +282,7 @@ self )) })?; - if let protobuf::physical_expr_node::ExprType::Sort(sort_expr) = expr { + if let ExprType::Sort(sort_expr) = expr { let expr = sort_expr .expr .as_ref() @@ -342,97 +318,58 @@ sort.fetch_limit.as_ref().map(|limit| limit.limit as usize), ))) } + PhysicalPlanType::BroadcastJoinBuildHashMap(bhm) => { + let input: Arc<dyn ExecutionPlan> = convert_box_required!(bhm.input)?; + let keys = bhm + .keys + .iter() + .map(|expr| { + Ok(bind( + try_parse_physical_expr(expr, &input.schema())?, + &input.schema(), + )?) 
+ }) + .collect::<Result<Vec<Arc<dyn PhysicalExpr>>, Self::Error>>()?; + Ok(Arc::new(BroadcastJoinBuildHashMapExec::new(input, keys))) + } PhysicalPlanType::BroadcastJoin(broadcast_join) => { + let schema = Arc::new(convert_required!(broadcast_join.schema)?); let left: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.left)?; let right: Arc<dyn ExecutionPlan> = convert_box_required!(broadcast_join.right)?; let on: Vec<(Arc<dyn PhysicalExpr>, Arc<dyn PhysicalExpr>)> = broadcast_join .on .iter() .map(|col| { - let left_col: Column = into_required!(col.left)?; - let left_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(left_col.name(), &left.schema())?); - let right_col: Column = into_required!(col.right)?; - let right_col_binded: Arc<dyn PhysicalExpr> = - Arc::new(Column::new_with_schema(right_col.name(), &right.schema())?); - Ok((left_col_binded, right_col_binded)) + let left_key = + try_parse_physical_expr(&col.left.as_ref().unwrap(), &left.schema())?; + let left_key_binded = bind(left_key, &left.schema())?; + let right_key = + try_parse_physical_expr(&col.right.as_ref().unwrap(), &right.schema())?; + let right_key_binded = bind(right_key, &right.schema())?; + Ok((left_key_binded, right_key_binded)) }) .collect::<Result<_, Self::Error>>()?; let join_type = protobuf::JoinType::try_from(broadcast_join.join_type) .expect("invalid JoinType"); - let join_filter = broadcast_join - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::<Result<Vec<_>, PlanSerDeError>>()?; - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - 
.map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; + let broadcast_side = protobuf::JoinSide::try_from(broadcast_join.broadcast_side) + .expect("invalid BroadcastSide"); + + let cached_build_hash_map_id = broadcast_join.cached_build_hash_map_id.clone(); Ok(Arc::new(BroadcastJoinExec::try_new( + schema, left, right, on, - join_type.into(), - join_filter, - )?)) - } - PhysicalPlanType::BroadcastNestedLoopJoin(bnlj) => { - let left: Arc<dyn ExecutionPlan> = convert_box_required!(bnlj.left)?; - let right: Arc<dyn ExecutionPlan> = convert_box_required!(bnlj.right)?; - let join_type = - protobuf::JoinType::try_from(bnlj.join_type).expect("invalid JoinType"); - let join_filter = bnlj - .join_filter - .as_ref() - .map(|f| { - let schema = Arc::new(convert_required!(f.schema)?); - let expression = try_parse_physical_expr_required(&f.expression, &schema)?; - let column_indices = f - .column_indices - .iter() - .map(|i| { - let side = - protobuf::JoinSide::try_from(i.side).expect("invalid JoinSide"); - Ok(ColumnIndex { - index: i.index as usize, - side: side.into(), - }) - }) - .collect::<Result<Vec<_>, PlanSerDeError>>()?; - - Ok(JoinFilter::new( - bind(expression, &schema)?, - column_indices, - schema.as_ref().clone(), - )) - }) - .map_or(Ok(None), |v: Result<_, PlanSerDeError>| v.map(Some))?; - - Ok(Arc::new(BroadcastNestedLoopJoinExec::try_new( - left, - right, - join_type.into(), - join_filter, + join_type + .try_into() + .map_err(|_| proto_error("invalid JoinType"))?, + broadcast_side + .try_into() + .map_err(|_| proto_error("invalid BroadcastSide"))?, + Some(cached_build_hash_map_id), )?)) } PhysicalPlanType::Union(union) => {
diff --git a/native-engine/blaze-serde/src/lib.rs b/native-engine/blaze-serde/src/lib.rs index 30bd4c2..56cd4a6 100644 --- a/native-engine/blaze-serde/src/lib.rs +++ b/native-engine/blaze-serde/src/lib.rs
@@ -15,10 +15,8 @@ use std::sync::Arc; use arrow::datatypes::{DataType, Field, Fields, IntervalUnit, Schema, TimeUnit}; -use datafusion::{ - common::JoinSide, logical_expr::Operator, prelude::JoinType, scalar::ScalarValue, -}; -use datafusion_ext_plans::agg::AggFunction; +use datafusion::{common::JoinSide, logical_expr::Operator, scalar::ScalarValue}; +use datafusion_ext_plans::{agg::AggFunction, joins::join_utils::JoinType}; use crate::error::PlanSerDeError; @@ -111,6 +109,7 @@ protobuf::JoinType::Full => JoinType::Full, protobuf::JoinType::Semi => JoinType::LeftSemi, protobuf::JoinType::Anti => JoinType::LeftAnti, + protobuf::JoinType::Existence => JoinType::Existence, } } }
diff --git a/native-engine/datafusion-ext-commons/src/lib.rs b/native-engine/datafusion-ext-commons/src/lib.rs index 72f6223..ece6438 100644 --- a/native-engine/datafusion-ext-commons/src/lib.rs +++ b/native-engine/datafusion-ext-commons/src/lib.rs
@@ -13,7 +13,6 @@ // limitations under the License. #![feature(new_uninit)] -#![feature(io_error_other)] #![feature(slice_swap_unchecked)] #![feature(vec_into_raw_parts)] @@ -85,9 +84,9 @@ batch_size } -// for better cache usage +// bigger for better radix sort performance pub const fn staging_mem_size_for_partial_sort() -> usize { - 4194304 * 8 / 10 + 8388608 } // use bigger batch memory size writing shuffling data
diff --git a/native-engine/datafusion-ext-commons/src/spark_hash.rs b/native-engine/datafusion-ext-commons/src/spark_hash.rs index 6a76bb9..85dac30 100644 --- a/native-engine/datafusion-ext-commons/src/spark_hash.rs +++ b/native-engine/datafusion-ext-commons/src/spark_hash.rs
@@ -77,10 +77,8 @@ // avoid boundary checking in performance critical codes. // all operations are garenteed to be safe unsafe { - let mut h1 = hash_bytes_by_int( - std::slice::from_raw_parts(data.get_unchecked(0), len_aligned), - seed, - ); + let mut h1 = + hash_bytes_by_int(std::slice::from_raw_parts(data.as_ptr(), len_aligned), seed); for i in len_aligned..len { let half_word = *data.get_unchecked(i) as i8 as i32;
diff --git a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs index 966b2f6..ede47a4 100644 --- a/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs +++ b/native-engine/datafusion-ext-functions/src/spark_get_json_object.rs
@@ -194,8 +194,8 @@ #[derive(Debug)] enum HiveGetJsonObjectError { - InvalidJsonPath(String), - InvalidInput(String), + InvalidJsonPath, + InvalidInput, } struct HiveGetJsonObjectEvaluator { @@ -212,15 +212,11 @@ evaluator.matchers.push(matcher); } if evaluator.matchers.first() != Some(&HiveGetJsonObjectMatcher::Root) { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "json path missing root".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } evaluator.matchers.remove(0); // remove root matcher if evaluator.matchers.contains(&HiveGetJsonObjectMatcher::Root) { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "json path has more than one root".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } Ok(evaluator) } @@ -240,9 +236,7 @@ return Ok(v); } } - Err(HiveGetJsonObjectError::InvalidInput( - "invalid json string".to_string(), - )) + Err(HiveGetJsonObjectError::InvalidInput) } fn evaluate_with_value_serde_json( @@ -296,7 +290,7 @@ serde_json::Value::Bool(b) => Ok(Some(b.to_string())), serde_json::Value::Array(_) | serde_json::Value::Object(_) => serde_json::to_string(value) .map(Some) - .map_err(|_| HiveGetJsonObjectError::InvalidInput("array to json error".to_string())), + .map_err(|_| HiveGetJsonObjectError::InvalidInput), } } @@ -310,7 +304,7 @@ sonic_rs::JsonType::Boolean => Ok(value.as_bool().map(|v| v.to_string())), _ => sonic_rs::to_string(value) .map(Some) - .map_err(|_| HiveGetJsonObjectError::InvalidInput("array to json error".to_string())), + .map_err(|_| HiveGetJsonObjectError::InvalidInput), } } @@ -352,9 +346,7 @@ } } if child_name.is_empty() { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "empty child name".to_string(), - )); + return Err(HiveGetJsonObjectError::InvalidJsonPath); } Ok(Some(Self::Child(child_name))) } @@ -372,24 +364,18 @@ chars.next(); } None => { - return Err(HiveGetJsonObjectError::InvalidJsonPath( - "unterminated subscript".to_string(), - )); + return 
Err(HiveGetJsonObjectError::InvalidJsonPath); } } } if index_str.is_empty() || index_str == "*" { return Ok(Some(Self::SubscriptAll)); } - let index = str::parse::<usize>(&index_str).map_err(|_| { - HiveGetJsonObjectError::InvalidJsonPath("invalid subscript index".to_string()) - })?; + let index = str::parse::<usize>(&index_str) + .map_err(|_| HiveGetJsonObjectError::InvalidJsonPath)?; Ok(Some(Self::Subscript(index))) } - Some(c) => Err(HiveGetJsonObjectError::InvalidJsonPath(format!( - "unexpected char in json path: {}", - c - ))), + Some(_) => Err(HiveGetJsonObjectError::InvalidJsonPath), } }
diff --git a/native-engine/datafusion-ext-functions/src/spark_null_if.rs b/native-engine/datafusion-ext-functions/src/spark_null_if.rs index af753d5..4845a93 100644 --- a/native-engine/datafusion-ext-functions/src/spark_null_if.rs +++ b/native-engine/datafusion-ext-functions/src/spark_null_if.rs
@@ -16,10 +16,7 @@ use arrow::{ array::*, - compute::{ - kernels::{cmp::eq, nullif::nullif}, - *, - }, + compute::kernels::{cmp::eq, nullif::nullif}, datatypes::*, }; use datafusion::{ @@ -87,7 +84,8 @@ ($dt:ident) => {{ type T = paste::paste! {arrow::datatypes::[<$dt Type>]}; let array = as_primitive_array::<T>(array); - let eq_zeros = eq_scalar(array, T::default_value())?; + let _0 = PrimitiveArray::<T>::new_scalar(Default::default()); + let eq_zeros = eq(array, &_0)?; Arc::new(nullif(array, &eq_zeros)?) as ArrayRef }}; }
diff --git a/native-engine/datafusion-ext-functions/src/spark_strings.rs b/native-engine/datafusion-ext-functions/src/spark_strings.rs index 6eb5d5e..6deaa7c 100644 --- a/native-engine/datafusion-ext-functions/src/spark_strings.rs +++ b/native-engine/datafusion-ext-functions/src/spark_strings.rs
@@ -223,7 +223,9 @@ None => return Ok(Arg::Ignore), } } - if let ScalarValue::List(l) = scalar && l.data_type() == &DataType::Utf8 { + if let ScalarValue::List(l) = scalar + && l.data_type() == &DataType::Utf8 + { if l.is_null(0) { return Ok(Arg::Ignore); }
diff --git a/native-engine/datafusion-ext-plans/Cargo.toml b/native-engine/datafusion-ext-plans/Cargo.toml index a233274..82fad59 100644 --- a/native-engine/datafusion-ext-plans/Cargo.toml +++ b/native-engine/datafusion-ext-plans/Cargo.toml
@@ -11,6 +11,7 @@ arrow = { workspace = true } async-trait = "0.1.80" base64 = "0.22.1" +bitvec = "1.0.1" byteorder = "1.5.0" bytes = "1.6.0" blaze-jni-bridge = { workspace = true }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs new file mode 100644 index 0000000..3f1ca6d --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_build_hash_map_exec.rs
@@ -0,0 +1,150 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + any::Any, + fmt::{Debug, Formatter}, + sync::Arc, +}; + +use arrow::{compute::concat_batches, datatypes::SchemaRef}; +use datafusion::{ + common::Result, + execution::{SendableRecordBatchStream, TaskContext}, + physical_expr::{Partitioning, PhysicalExpr, PhysicalSortExpr}, + physical_plan::{ + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + stream::RecordBatchStreamAdapter, + DisplayAs, DisplayFormatType, ExecutionPlan, + }, +}; +use futures::{stream::once, TryStreamExt}; + +use crate::{ + common::output::{NextBatchWithTimer, TaskOutputter}, + joins::join_hash_map::{join_hash_map_schema, JoinHashMap}, +}; + +pub struct BroadcastJoinBuildHashMapExec { + input: Arc<dyn ExecutionPlan>, + keys: Vec<Arc<dyn PhysicalExpr>>, + metrics: ExecutionPlanMetricsSet, +} + +impl BroadcastJoinBuildHashMapExec { + pub fn new(input: Arc<dyn ExecutionPlan>, keys: Vec<Arc<dyn PhysicalExpr>>) -> Self { + Self { + input, + keys, + metrics: ExecutionPlanMetricsSet::new(), + } + } +} + +impl Debug for BroadcastJoinBuildHashMapExec { + fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMap [{:?}]", self.keys) + } +} + +impl DisplayAs for BroadcastJoinBuildHashMapExec { + fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { + write!(f, "BroadcastJoinBuildHashMapExec [{:?}]", 
self.keys) + } +} + +impl ExecutionPlan for BroadcastJoinBuildHashMapExec { + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> SchemaRef { + join_hash_map_schema(&self.input.schema()) + } + + fn output_partitioning(&self) -> Partitioning { + Partitioning::UnknownPartitioning(self.input.output_partitioning().partition_count()) + } + + fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { + None + } + + fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { + vec![self.input.clone()] + } + + fn with_new_children( + self: Arc<Self>, + children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + Ok(Arc::new(Self::new(children[0].clone(), self.keys.clone()))) + } + + fn execute( + &self, + partition: usize, + context: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let baseline_metrics = BaselineMetrics::new(&self.metrics, partition); + let input = self.input.execute(partition, context.clone())?; + Ok(Box::pin(RecordBatchStreamAdapter::new( + self.schema(), + once(execute_build_hash_map( + context, + input, + self.keys.clone(), + baseline_metrics, + )) + .try_flatten(), + ))) + } + + fn metrics(&self) -> Option<MetricsSet> { + Some(self.metrics.clone_inner()) + } +} + +async fn execute_build_hash_map( + context: Arc<TaskContext>, + mut input: SendableRecordBatchStream, + keys: Vec<Arc<dyn PhysicalExpr>>, + metrics: BaselineMetrics, +) -> Result<SendableRecordBatchStream> { + let elapsed_compute = metrics.elapsed_compute().clone(); + let mut timer = elapsed_compute.timer(); + + let mut data_batches = vec![]; + let data_schema = input.schema(); + + // collect all input batches + while let Some(batch) = input.next_batch(Some(&mut timer)).await? 
{ + data_batches.push(batch); + } + let data_batch = concat_batches(&data_schema, data_batches.iter())?; + + // build hash map + let hash_map_schema = join_hash_map_schema(&data_schema); + let hash_map = JoinHashMap::try_from_data_batch(data_batch, &keys)?; + drop(timer); + + // output hash map batches as stream + context.output_with_sender("BuildHashMap", hash_map_schema, move |sender| async move { + let mut timer = elapsed_compute.timer(); + sender + .send(Ok(hash_map.into_hash_map_batch()?), Some(&mut timer)) + .await; + Ok(()) + }) +}
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs index 201173c..de160af 100644 --- a/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/broadcast_join_exec.rs
@@ -15,90 +15,203 @@ use std::{ any::Any, fmt::{Debug, Formatter}, - sync::Arc, - task::Poll, - time::Duration, + future::Future, + pin::Pin, + sync::{Arc, Weak}, + time::{Duration, Instant}, }; -use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; -use blaze_jni_bridge::{ - conf, - conf::{BooleanConf, IntConf}, +use arrow::{ + array::RecordBatch, + compute::SortOptions, + datatypes::{DataType, SchemaRef}, }; +use async_trait::async_trait; use datafusion::{ - common::{Result, Statistics}, + common::{JoinSide, Result, Statistics}, execution::context::TaskContext, - logical_expr::JoinType, - physical_expr::PhysicalSortExpr, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - expressions::Column, - joins::{ - utils::{build_join_schema, check_join_is_valid, JoinFilter, JoinOn}, - HashJoinExec, PartitionMode, - }, - memory::MemoryStream, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, Time}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, }, }; -use datafusion_ext_commons::{df_execution_err, downcast_any}; -use futures::{stream::once, StreamExt, TryStreamExt}; +use datafusion_ext_commons::{ + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, +}; +use futures::{StreamExt, TryStreamExt}; +use hashbrown::HashMap; +use once_cell::sync::OnceCell; use parking_lot::Mutex; -use crate::{sort_exec::SortExec, sort_merge_join_exec::SortMergeJoinExec}; +use crate::{ + common::{ + batch_statisitcs::{stat_input, InputBatchStatistics}, + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + }, + joins::{ + bhj::{ + full_join::{ + LProbedFullOuterJoiner, LProbedInnerJoiner, LProbedLeftJoiner, LProbedRightJoiner, + RProbedFullOuterJoiner, RProbedInnerJoiner, RProbedLeftJoiner, RProbedRightJoiner, + }, + semi_join::{ + 
LProbedExistenceJoiner, LProbedLeftAntiJoiner, LProbedLeftSemiJoiner, + LProbedRightAntiJoiner, LProbedRightSemiJoiner, RProbedExistenceJoiner, + RProbedLeftAntiJoiner, RProbedLeftSemiJoiner, RProbedRightAntiJoiner, + RProbedRightSemiJoiner, + }, + }, + join_hash_map::{join_data_schema, JoinHashMap}, + join_utils::{JoinType, JoinType::*}, + JoinParams, JoinProjection, + }, +}; #[derive(Debug)] pub struct BroadcastJoinExec { - /// Left sorted joining execution plan left: Arc<dyn ExecutionPlan>, - /// Right sorting joining execution plan right: Arc<dyn ExecutionPlan>, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: Option<JoinFilter>, - /// The schema once the join is applied + broadcast_side: JoinSide, schema: SchemaRef, - /// Execution metrics + cached_build_hash_map_id: Option<String>, metrics: ExecutionPlanMetricsSet, } impl BroadcastJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: JoinType, - join_filter: Option<JoinFilter>, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option<String>, ) -> Result<Self> { - if matches!( - join_type, - JoinType::LeftSemi | JoinType::LeftAnti | JoinType::RightSemi | JoinType::RightAnti, - ) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - let left_schema = left.schema(); - let right_schema = right.schema(); - - check_join_is_valid(&left_schema, &right_schema, &on)?; - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); - Ok(Self { left, right, on, join_type, - join_filter, + broadcast_side, schema, + cached_build_hash_map_id, metrics: ExecutionPlanMetricsSet::new(), }) } + + fn create_join_params(&self, projection: &[usize]) -> Result<JoinParams> { + let left_schema = self.left.schema(); + let right_schema = 
self.right.schema(); + let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) = + self.on.iter().cloned().unzip(); + let key_data_types: Vec<DataType> = self + .on + .iter() + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::<Result<_>>()?; + + let projection = JoinProjection::try_new( + self.join_type, + &self.schema, + &match self.broadcast_side { + JoinSide::Left => join_data_schema(&left_schema), + JoinSide::Right => left_schema.clone(), + }, + &match self.broadcast_side { + JoinSide::Left => right_schema.clone(), + JoinSide::Right => join_data_schema(&right_schema), + }, + projection, + )?; + + Ok(JoinParams { + join_type: self.join_type, + left_schema, + right_schema, + output_schema: self.schema(), + left_keys, + right_keys, + batch_size: batch_size(), + sort_options: vec![SortOptions::default(); self.on.len()], + projection, + key_data_types, + }) + } + + fn execute_with_projection( + &self, + partition: usize, + context: Arc<TaskContext>, + projection: Vec<usize>, + ) -> Result<SendableRecordBatchStream> { + let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); + let join_params = self.create_join_params(&projection)?; + let left = self.left.execute(partition, context.clone())?; + let right = self.right.execute(partition, context.clone())?; + let broadcast_side = self.broadcast_side; + let cached_build_hash_map_id = self.cached_build_hash_map_id.clone(); + + // stat probed side + let input_batch_stat = + InputBatchStatistics::from_metrics_set_and_blaze_conf(&self.metrics, partition)?; + let (left, right) = match broadcast_side { + JoinSide::Left => (left, stat_input(input_batch_stat, right)?), + JoinSide::Right => (stat_input(input_batch_stat, left)?, right), + }; + + let 
metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + join_params.projection.schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender( + "BroadcastJoin", + join_params.projection.schema.clone(), + move |sender| { + execute_join( + left, + right, + join_params, + broadcast_side, + cached_build_hash_map_id, + metrics_cloned, + sender, + ) + }, + ) + }) + .try_flatten(), + )); + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) + } +} + +impl ExecuteWithColumnPruning for BroadcastJoinExec { + fn execute_projected( + &self, + partition: usize, + context: Arc<TaskContext>, + projection: &[usize], + ) -> Result<SendableRecordBatchStream> { + self.execute_with_projection(partition, context, projection.to_vec()) + } } impl ExecutionPlan for BroadcastJoinExec { @@ -111,7 +224,10 @@ } fn output_partitioning(&self) -> Partitioning { - self.right.output_partitioning() + match self.broadcast_side { + JoinSide::Left => self.right.output_partitioning(), + JoinSide::Right => self.left.output_partitioning(), + } } fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { @@ -127,11 +243,13 @@ children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { Ok(Arc::new(Self::try_new( + self.schema.clone(), children[0].clone(), children[1].clone(), self.on.iter().cloned().collect(), self.join_type, - self.join_filter.clone(), + self.broadcast_side, + None, )?)) } @@ -140,21 +258,8 @@ partition: usize, context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { - let stream = execute_broadcast_join( - self.left.clone(), - self.right.clone(), - partition, - context, - self.on.clone(), - self.join_type, - self.join_filter.clone(), - BaselineMetrics::new(&self.metrics, partition), - ); - - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(stream).try_flatten(), - ))) + let projection = 
(0..self.schema.fields().len()).collect(); + self.execute_with_projection(partition, context, projection) } fn metrics(&self) -> Option<MetricsSet> { @@ -172,221 +277,188 @@ } } -async fn execute_broadcast_join( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - partition: usize, - context: Arc<TaskContext>, - on: JoinOn, - join_type: JoinType, - join_filter: Option<JoinFilter>, - metrics: BaselineMetrics, -) -> Result<SendableRecordBatchStream> { - let enabled_fallback_to_smj = conf::BHJ_FALLBACKS_TO_SMJ_ENABLE.value()?; - let bhj_num_rows_limit = conf::BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD.value()? as usize; - let bhj_mem_size_limit = conf::BHJ_FALLBACKS_TO_SMJ_MEM_THRESHOLD.value()? as usize; +async fn execute_join( + left: SendableRecordBatchStream, + right: SendableRecordBatchStream, + join_params: JoinParams, + broadcast_side: JoinSide, + cached_build_hash_map_id: Option<String>, + metrics: Arc<BaselineMetrics>, + sender: Arc<WrappedRecordBatchSender>, +) -> Result<()> { + let start_time = Instant::now(); + let mut excluded_time_ns = 0; + let poll_time = Time::new(); - // if broadcasted size is small enough, use hash join - // otherwise use sort-merge join - #[derive(Debug)] - enum JoinMode { - Hash, - SortMerge, - } - let mut join_mode = JoinMode::Hash; - - let left_schema = left.schema(); - let mut left = left; - - if enabled_fallback_to_smj { - let mut left_stream = left.execute(0, context.clone())?.fuse(); - let mut left_cached: Vec<RecordBatch> = vec![]; - let mut left_num_rows = 0; - let mut left_mem_size = 0; - - // read and cache batches from broadcasted side until reached limits - while let Some(batch) = left_stream.next().await.transpose()? 
{ - left_num_rows += batch.num_rows(); - left_mem_size += batch.get_array_memory_size(); - left_cached.push(batch); - if left_num_rows > bhj_num_rows_limit || left_mem_size > bhj_mem_size_limit { - join_mode = JoinMode::SortMerge; - break; - } + let (mut probed, _keys, mut joiner): (_, _, Pin<Box<dyn Joiner + Send>>) = match broadcast_side + { + JoinSide::Left => { + let right_schema = right.schema(); + let mut right_peeked = Box::pin(right.peekable()); + let (_, lmap_result) = futures::join!( + // fetch two sides asynchronously + async { + let timer = poll_time.timer(); + right_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + left, + &join_params.left_keys, + poll_time.clone(), + ), + ); + let lmap = lmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(right_schema, right_peeked)), + join_params.right_keys.clone(), + match join_params.join_type { + Inner => Box::pin(RProbedInnerJoiner::new(join_params, lmap, sender)), + Left => Box::pin(RProbedLeftJoiner::new(join_params, lmap, sender)), + Right => Box::pin(RProbedRightJoiner::new(join_params, lmap, sender)), + Full => Box::pin(RProbedFullOuterJoiner::new(join_params, lmap, sender)), + LeftSemi => Box::pin(RProbedLeftSemiJoiner::new(join_params, lmap, sender)), + LeftAnti => Box::pin(RProbedLeftAntiJoiner::new(join_params, lmap, sender)), + RightSemi => Box::pin(RProbedRightSemiJoiner::new(join_params, lmap, sender)), + RightAnti => Box::pin(RProbedRightAntiJoiner::new(join_params, lmap, sender)), + Existence => Box::pin(RProbedExistenceJoiner::new(join_params, lmap, sender)), + }, + ) } - - // convert left cached and rest batches into execution plan - let left_cached_stream: SendableRecordBatchStream = Box::pin(MemoryStream::try_new( - left_cached, - left_schema.clone(), - None, - )?); - let left_rest_stream: SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_stream, - )); - let left_stream: 
SendableRecordBatchStream = Box::pin(RecordBatchStreamAdapter::new( - left_schema.clone(), - left_cached_stream.chain(left_rest_stream), - )); - left = Arc::new(RecordBatchStreamsWrapperExec { - schema: left_schema.clone(), - stream: Mutex::new(Some(left_stream)), - output_partitioning: right.output_partitioning(), - }); - } - - match join_mode { - JoinMode::Hash => { - let join = Arc::new(HashJoinExec::try_new( - left.clone(), - right.clone(), - on, - join_filter, - &join_type, - PartitionMode::CollectLeft, - false, - )?); - log::info!("BroadcastJoin is using hash join mode: {:?}", &join); - - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? - .chain(futures::stream::poll_fn(move |_| { - // update metrics - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - join_metrics - .sum_by_name("build_time") - .map(|v| v.as_usize() as u64), - join_metrics - .sum_by_name("join_time") - .map(|v| v.as_usize() as u64), - ] - .into_iter() - .flatten() - .sum(), - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) + JoinSide::Right => { + let left_schema = left.schema(); + let mut left_peeked = Box::pin(left.peekable()); + let (_, rmap_result) = futures::join!( + // fetch two sides asynchronizely + async { + let timer = poll_time.timer(); + left_peeked.as_mut().peek().await; + drop(timer); + }, + collect_join_hash_map( + cached_build_hash_map_id, + right, + &join_params.right_keys, + poll_time.clone(), + ), + ); + let rmap = rmap_result?; + ( + Box::pin(RecordBatchStreamAdapter::new(left_schema, left_peeked)), + join_params.left_keys.clone(), + match join_params.join_type { + Inner => Box::pin(LProbedInnerJoiner::new(join_params, rmap, sender)), + Left => Box::pin(LProbedLeftJoiner::new(join_params, rmap, sender)), + Right => 
Box::pin(LProbedRightJoiner::new(join_params, rmap, sender)), + Full => Box::pin(LProbedFullOuterJoiner::new(join_params, rmap, sender)), + LeftSemi => Box::pin(LProbedLeftSemiJoiner::new(join_params, rmap, sender)), + LeftAnti => Box::pin(LProbedLeftAntiJoiner::new(join_params, rmap, sender)), + RightSemi => Box::pin(LProbedRightSemiJoiner::new(join_params, rmap, sender)), + RightAnti => Box::pin(LProbedRightAntiJoiner::new(join_params, rmap, sender)), + Existence => Box::pin(LProbedExistenceJoiner::new(join_params, rmap, sender)), + }, + ) } - JoinMode::SortMerge => { - let sort_exprs: Vec<PhysicalSortExpr> = on - .iter() - .map(|(_col_left, col_right)| PhysicalSortExpr { - expr: Arc::new(Column::new( - "", - downcast_any!(col_right, Column) - .expect("requires column") - .index(), - )), - options: Default::default(), - }) - .collect(); + }; - let right_sorted = Arc::new(SortExec::new(right, sort_exprs.clone(), None)); - let join = Arc::new(SortMergeJoinExec::try_new( - left.clone(), - right_sorted.clone(), - on, - join_type, - join_filter, - sort_exprs.into_iter().map(|se| se.options).collect(), - )?); - log::info!("BroadcastJoin is using sort-merge join mode: {:?}", &join); + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = probed.next().await.transpose()?; + drop(timer); + batch + } { + joiner.as_mut().join(batch).await?; + } + joiner.as_mut().finish().await?; + metrics.record_output(joiner.num_output_rows()); - let join_schema = join.schema(); - let completed = join - .execute(partition, context)? 
- .chain(futures::stream::poll_fn(move |_| { - // update metrics - let right_sorted_metrics = right_sorted.metrics().unwrap(); - let join_metrics = join.metrics().unwrap(); - metrics.record_output(join_metrics.output_rows().unwrap_or(0)); - metrics.elapsed_compute().add_duration(Duration::from_nanos( - [ - right_sorted_metrics.elapsed_compute(), - join_metrics.elapsed_compute(), - ] - .into_iter() - .flatten() - .sum::<usize>() as u64, - )); - Poll::Ready(None) - })); - Ok(Box::pin(RecordBatchStreamAdapter::new( - join_schema, - completed, - ))) + excluded_time_ns += poll_time.value(); + excluded_time_ns += joiner.total_send_output_time(); + + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= excluded_time_ns as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); + Ok(()) +} + +async fn collect_join_hash_map( + cached_build_hash_map_id: Option<String>, + input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result<Arc<JoinHashMap>> { + Ok(match cached_build_hash_map_id { + Some(cached_id) => { + get_cached_join_hash_map(&cached_id, || async { + collect_join_hash_map_without_caching(input, key_exprs, poll_time).await + }) + .await? 
} + None => { + let map = collect_join_hash_map_without_caching(input, key_exprs, poll_time).await?; + Arc::new(map) + } + }) +} + +async fn collect_join_hash_map_without_caching( + mut input: SendableRecordBatchStream, + key_exprs: &[PhysicalExprRef], + poll_time: Time, +) -> Result<JoinHashMap> { + let mut hash_map_batches = vec![]; + while let Some(batch) = { + let timer = poll_time.timer(); + let batch = input.next().await.transpose()?; + drop(timer); + batch + } { + hash_map_batches.push(batch); + } + match hash_map_batches.len() { + 0 => Ok(JoinHashMap::try_new_empty(input.schema(), key_exprs)?), + 1 => Ok(JoinHashMap::try_from_hash_map_batch( + hash_map_batches[0].clone(), + key_exprs, + )?), + n => df_execution_err!("expect zero or one hash map batch, got {n}"), } } -pub struct RecordBatchStreamsWrapperExec { - pub schema: SchemaRef, - pub stream: Mutex<Option<SendableRecordBatchStream>>, - pub output_partitioning: Partitioning, +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()>; + async fn finish(self: Pin<&mut Self>) -> Result<()>; + + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; } -impl Debug for RecordBatchStreamsWrapperExec { - fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result { - write!(f, "RecordBatchStreamsWrapper") - } -} +async fn get_cached_join_hash_map<Fut: Future<Output = Result<JoinHashMap>> + Send>( + cached_id: &str, + init: impl FnOnce() -> Fut, +) -> Result<Arc<JoinHashMap>> { + type Slot = Arc<tokio::sync::Mutex<Weak<JoinHashMap>>>; + static CACHED_JOIN_HASH_MAP: OnceCell<Arc<Mutex<HashMap<String, Slot>>>> = OnceCell::new(); -impl DisplayAs for RecordBatchStreamsWrapperExec { - fn fmt_as(&self, _t: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "RecordBatchStreamsWrapper") - } -} + // TODO: remove expired keys from cached join hash map + let cached_join_hash_map = CACHED_JOIN_HASH_MAP.get_or_init(|| 
Arc::default()); + let slot = cached_join_hash_map + .lock() + .entry(cached_id.to_string()) + .or_default() + .clone(); -impl ExecutionPlan for RecordBatchStreamsWrapperExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - self.output_partitioning.clone() - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { - vec![] - } - - fn with_new_children( - self: Arc<Self>, - _: Vec<Arc<dyn ExecutionPlan>>, - ) -> Result<Arc<dyn ExecutionPlan>> { - unimplemented!() - } - - fn execute( - &self, - _partition: usize, - _context: Arc<TaskContext>, - ) -> Result<SendableRecordBatchStream> { - let stream = std::mem::take(&mut *self.stream.lock()); - Ok(Box::pin(RecordBatchStreamAdapter::new( - self.schema.clone(), - Box::pin(futures::stream::iter(stream).flatten()), - ))) - } - - fn statistics(&self) -> Result<Statistics> { - unimplemented!() + let mut slot = slot.lock().await; + if let Some(cached) = slot.upgrade() { + Ok(cached) + } else { + let new = Arc::new(init().await?); + *slot = Arc::downgrade(&new); + Ok(new) } }
diff --git a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs b/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs deleted file mode 100644 index b52e77f..0000000 --- a/native-engine/datafusion-ext-plans/src/broadcast_nested_loop_join_exec.rs +++ /dev/null
@@ -1,252 +0,0 @@ -// Copyright 2022 The Blaze Authors -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -use std::{any::Any, fmt::Formatter, sync::Arc}; - -use arrow::datatypes::SchemaRef; -use datafusion::{ - common::{JoinType, Result, Statistics}, - execution::{SendableRecordBatchStream, TaskContext}, - physical_expr::{Partitioning, PhysicalSortExpr}, - physical_plan::{ - joins::{ - utils::{build_join_schema, check_join_is_valid, JoinFilter}, - NestedLoopJoinExec, - }, - memory::MemoryExec, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, - stream::RecordBatchStreamAdapter, - DisplayAs, DisplayFormatType, ExecutionPlan, - }, -}; -use datafusion_ext_commons::batch_size; -use futures::{stream::once, StreamExt, TryStreamExt}; -use parking_lot::Mutex; - -use crate::broadcast_join_exec::RecordBatchStreamsWrapperExec; - -#[derive(Debug)] -pub struct BroadcastNestedLoopJoinExec { - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - join_type: JoinType, - filter: Option<JoinFilter>, - schema: SchemaRef, - metrics: ExecutionPlanMetricsSet, -} - -impl BroadcastNestedLoopJoinExec { - pub fn try_new( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - join_type: JoinType, - filter: Option<JoinFilter>, - ) -> Result<Self> { - let left_schema = left.schema(); - let right_schema = right.schema(); - check_join_is_valid(&left_schema, &right_schema, &[])?; - let (schema, _column_indices) = build_join_schema(&left_schema, 
&right_schema, &join_type); - - Ok(Self { - left, - right, - filter, - join_type, - schema: Arc::new(schema), - metrics: ExecutionPlanMetricsSet::new(), - }) - } -} - -impl DisplayAs for BroadcastNestedLoopJoinExec { - fn fmt_as(&self, _: DisplayFormatType, f: &mut Formatter) -> std::fmt::Result { - write!(f, "BroadcastNestedLoopJoin") - } -} - -impl ExecutionPlan for BroadcastNestedLoopJoinExec { - fn as_any(&self) -> &dyn Any { - self - } - - fn schema(&self) -> SchemaRef { - self.schema.clone() - } - - fn output_partitioning(&self) -> Partitioning { - if left_is_build_side(self.join_type) { - self.right.output_partitioning() - } else { - self.left.output_partitioning() - } - } - - fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { - None - } - - fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> { - vec![self.left.clone(), self.right.clone()] - } - - fn with_new_children( - self: Arc<Self>, - children: Vec<Arc<dyn ExecutionPlan>>, - ) -> Result<Arc<dyn ExecutionPlan>> { - Ok(Arc::new(Self::try_new( - children[0].clone(), - children[1].clone(), - self.join_type, - self.filter.clone(), - )?)) - } - - fn execute( - &self, - partition: usize, - context: Arc<TaskContext>, - ) -> Result<SendableRecordBatchStream> { - let joined = Box::pin(RecordBatchStreamAdapter::new( - self.schema(), - once(execute_join( - partition, - context, - self.left.clone(), - self.right.clone(), - self.join_type, - self.filter.clone(), - self.metrics.clone(), - )) - .try_flatten(), - )); - Ok(joined) - } - - fn metrics(&self) -> Option<MetricsSet> { - Some(self.metrics.clone_inner()) - } - - fn statistics(&self) -> Result<Statistics> { - todo!() - } -} - -async fn execute_join( - partition: usize, - context: Arc<TaskContext>, - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - join_type: JoinType, - filter: Option<JoinFilter>, - metrics: ExecutionPlanMetricsSet, -) -> Result<SendableRecordBatchStream> { - // inner side - let mut inner_stream = if 
left_is_build_side(join_type) { - left.execute(partition, context.clone())? - } else { - right.execute(partition, context.clone())? - }; - let inner_schema = inner_stream.schema(); - let mut inner_batches = vec![]; - while let Some(batch) = inner_stream.next().await.transpose()? { - inner_batches.push(batch); - } - - let inner_batch_max_num_rows = inner_batches - .iter() - .map(|batch| batch.num_rows()) - .max() - .unwrap_or(0); - let inner_batch_max_mem_size = inner_batches - .iter() - .map(|batch| batch.get_array_memory_size()) - .max() - .unwrap_or(0); - - let target_output_num_rows = batch_size(); - let target_output_mem_size = 1 << 26; // 64MB - let inner_exec: Arc<dyn ExecutionPlan> = - Arc::new(MemoryExec::try_new(&[inner_batches], inner_schema, None)?); - - // outer side - let (outer_schema, outer_partitioning, outer_stream) = if left_is_build_side(join_type) { - ( - right.schema(), - right.output_partitioning(), - right.execute(partition, context.clone())?, - ) - } else { - ( - left.schema(), - left.output_partitioning(), - left.execute(partition, context.clone())?, - ) - }; - let chunked_outer_stream = Box::pin(RecordBatchStreamAdapter::new( - outer_schema.clone(), - outer_stream.flat_map(move |batch_result| match batch_result { - Ok(batch) => { - let batch_num_rows = batch.num_rows(); - let batch_mem_size = batch.get_array_memory_size(); - let output_num_rows = batch_num_rows * inner_batch_max_num_rows; - let output_mem_size = batch_num_rows * inner_batch_max_mem_size - + batch_mem_size * inner_batch_max_num_rows; - let chunk_count = std::cmp::min( - (output_num_rows / target_output_num_rows).max(1), - (output_mem_size / target_output_mem_size).max(1), - ); - let chunk_len = (batch_num_rows / chunk_count).max(1); - - let mut chunks = vec![]; - for beg in (0..batch.num_rows()).step_by(chunk_len) { - chunks.push(Ok(batch.slice(beg, chunk_len.min(batch.num_rows() - beg)))); - } - futures::stream::iter(chunks) - } - Err(err) => 
futures::stream::iter(vec![Err(err)]), - }), - )); - let outer_exec: Arc<dyn ExecutionPlan> = Arc::new(RecordBatchStreamsWrapperExec { - schema: outer_schema, - stream: Mutex::new(Some(chunked_outer_stream)), - output_partitioning: outer_partitioning, - }); - - // join with datafusion's builtin NestedLoopJoinExec - let nlj = if left_is_build_side(join_type) { - NestedLoopJoinExec::try_new(inner_exec, outer_exec, filter, &join_type)? - } else { - NestedLoopJoinExec::try_new(outer_exec, inner_exec, filter, &join_type)? - }; - let joined = nlj.execute(partition, context)?; - - let baseline_metrics = BaselineMetrics::new(&metrics, partition); - let output_stream = Box::pin(RecordBatchStreamAdapter::new( - joined.schema(), - joined.map(move |batch_result| { - if let Ok(batch) = &batch_result { - baseline_metrics.record_output(batch.num_rows()); - } - batch_result - }), - )); - Ok(output_stream) -} - -fn left_is_build_side(join_type: JoinType) -> bool { - matches!( - join_type, - JoinType::Right | JoinType::RightSemi | JoinType::RightAnti | JoinType::Full - ) -}
diff --git a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs index 6aa8395..a5e789c 100644 --- a/native-engine/datafusion-ext-plans/src/common/batch_selection.rs +++ b/native-engine/datafusion-ext-plans/src/common/batch_selection.rs
@@ -41,16 +41,33 @@ take_batch_internal(batch, indices) } +pub fn take_cols<T: num::PrimInt>( + cols: &[ArrayRef], + indices: impl IntoIterator<Item = T>, +) -> Result<Vec<ArrayRef>> { + let indices: UInt32Array = + PrimitiveArray::from_iter(indices.into_iter().map(|idx| idx.to_u32().unwrap())); + take_cols_internal(cols, &indices) +} + +pub fn take_cols_opt<T: num::PrimInt>( + cols: &[ArrayRef], + indices: impl IntoIterator<Item = Option<T>>, +) -> Result<Vec<ArrayRef>> { + let indices: UInt32Array = PrimitiveArray::from_iter( + indices + .into_iter() + .map(|opt| opt.map(|idx| idx.to_u32().unwrap())), + ); + take_cols_internal(cols, &indices) +} + fn take_batch_internal(batch: RecordBatch, indices: UInt32Array) -> Result<RecordBatch> { let taken_num_batch_rows = indices.len(); let schema = batch.schema(); - let cols = batch.columns().to_vec(); - drop(batch); // we would like to release batch as soon as possible + let cols = batch.columns(); - let cols = cols - .into_iter() - .map(|c| Ok(arrow::compute::take(&c, &indices, None)?)) - .collect::<Result<_>>()?; + let cols = take_cols_internal(cols, &indices)?; drop(indices); let taken = RecordBatch::try_new_with_options( @@ -61,6 +78,14 @@ Ok(taken) } +fn take_cols_internal(cols: &[ArrayRef], indices: &UInt32Array) -> Result<Vec<ArrayRef>> { + let cols = cols + .into_iter() + .map(|c| Ok(arrow::compute::take(&c, indices, None)?)) + .collect::<Result<_>>()?; + Ok(cols) +} + pub fn interleave_batches( schema: SchemaRef, batches: &[RecordBatch],
diff --git a/native-engine/datafusion-ext-plans/src/common/output.rs b/native-engine/datafusion-ext-plans/src/common/output.rs index b1a0a28..d888026 100644 --- a/native-engine/datafusion-ext-plans/src/common/output.rs +++ b/native-engine/datafusion-ext-plans/src/common/output.rs
@@ -20,6 +20,7 @@ }; use arrow::{datatypes::SchemaRef, record_batch::RecordBatch}; +use async_trait::async_trait; use blaze_jni_bridge::is_task_running; use datafusion::{ common::Result, @@ -221,3 +222,34 @@ WrappedRecordBatchSender::cancel_task(self); } } + +#[async_trait] +pub trait NextBatchWithTimer { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result<Option<RecordBatch>>; +} + +#[async_trait] +impl NextBatchWithTimer for SendableRecordBatchStream { + async fn next_batch( + &mut self, + stop_timer: Option<&mut ScopedTimerGuard<'_>>, + ) -> Result<Option<RecordBatch>> { + struct StopScopedTimerGuard<'a, 'z>(&'a mut ScopedTimerGuard<'z>); + impl<'a, 'z> StopScopedTimerGuard<'a, 'z> { + fn new(timer: &'a mut ScopedTimerGuard<'z>) -> Self { + timer.stop(); + Self(timer) + } + } + impl Drop for StopScopedTimerGuard<'_, '_> { + fn drop(&mut self) { + self.0.restart(); + } + } + let _stop_timer = stop_timer.map(|timer| StopScopedTimerGuard::new(timer)); + self.next().await.transpose() + } +}
diff --git a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs index fb8651c..9688e5a 100644 --- a/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs +++ b/native-engine/datafusion-ext-plans/src/ipc_reader_exec.rs
@@ -192,8 +192,8 @@ })); while let Some(batch) = { - let reader_cloned = reader.clone(); - tokio::task::spawn_blocking(move || reader_cloned.clone().lock().read_batch()) + let reader = reader.clone(); + tokio::task::spawn_blocking(move || reader.lock().read_batch()) .await .or_else(|err| df_execution_err!("{err}"))?? } {
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs new file mode 100644 index 0000000..ca51b56 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/full_join.rs
@@ -0,0 +1,348 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use std::{ + pin::Pin, + sync::{ + atomic::{AtomicUsize, Ordering::Relaxed}, + Arc, + }, +}; + +use arrow::array::{new_null_array, ArrayRef, RecordBatch}; +use async_trait::async_trait; +use bitvec::{bitvec, prelude::BitVec}; +use datafusion::{common::Result, physical_plan::metrics::Time}; + +use crate::{ + broadcast_join_exec::Joiner, + common::{batch_selection::take_cols, output::WrappedRecordBatchSender}, + joins::{ + bhj::{ + filter_joined_indices, + full_join::ProbeSide::{L, R}, + ProbeSide, + }, + join_hash_map::{join_create_hashes, JoinHashMap}, + JoinParams, + }, +}; + +/// Compile-time configuration of a [`FullJoiner`]: which side is probed, and +/// whether unjoined probed / build rows are emitted padded with nulls. +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub struct JoinerParams { + probe_side: ProbeSide, + probe_side_outer: bool, + build_side_outer: bool, +} + +impl JoinerParams { + const fn new(probe_side: ProbeSide, probe_side_outer: bool, build_side_outer: bool) -> Self { + Self { + probe_side, + probe_side_outer, + build_side_outer, + } + } +} + +const LEFT_PROBED_INNER: JoinerParams = JoinerParams::new(L, false, false); +const LEFT_PROBED_LEFT: JoinerParams = JoinerParams::new(L, true, false); +const LEFT_PROBED_RIGHT: JoinerParams = JoinerParams::new(L, false, true); +const LEFT_PROBED_OUTER: JoinerParams = JoinerParams::new(L, true, true); + +const RIGHT_PROBED_INNER: JoinerParams = JoinerParams::new(R, false, false); +const RIGHT_PROBED_LEFT: JoinerParams =
JoinerParams::new(R, false, true); +const RIGHT_PROBED_RIGHT: JoinerParams = JoinerParams::new(R, true, false); +const RIGHT_PROBED_OUTER: JoinerParams = JoinerParams::new(R, true, true); + +pub type LProbedInnerJoiner = FullJoiner<LEFT_PROBED_INNER>; +pub type LProbedLeftJoiner = FullJoiner<LEFT_PROBED_LEFT>; +pub type LProbedRightJoiner = FullJoiner<LEFT_PROBED_RIGHT>; +pub type LProbedFullOuterJoiner = FullJoiner<LEFT_PROBED_OUTER>; +pub type RProbedInnerJoiner = FullJoiner<RIGHT_PROBED_INNER>; +pub type RProbedLeftJoiner = FullJoiner<RIGHT_PROBED_LEFT>; +pub type RProbedRightJoiner = FullJoiner<RIGHT_PROBED_RIGHT>; +pub type RProbedFullOuterJoiner = FullJoiner<RIGHT_PROBED_OUTER>; + +/// Hash joiner that joins probed batches against a prebuilt [`JoinHashMap`]. +/// Covers inner, left/right outer and full outer joins, as selected by `P`. +pub struct FullJoiner<const P: JoinerParams> { + join_params: JoinParams, + output_sender: Arc<WrappedRecordBatchSender>, + map: Arc<JoinHashMap>, + // one bit per hash map row, set once that row has joined any probed row + map_joined: BitVec, + send_output_time: Time, + output_rows: AtomicUsize, +} + +impl<const P: JoinerParams> FullJoiner<P> { + /// Creates a joiner over `map`; output batches are sent to `output_sender`. + pub fn new( + join_params: JoinParams, + map: Arc<JoinHashMap>, + output_sender: Arc<WrappedRecordBatchSender>, + ) -> Self { + let map_joined = bitvec![0; map.data_batch().num_rows()]; + Self { + join_params, + output_sender, + map, + map_joined, + send_output_time: Time::default(), + output_rows: AtomicUsize::new(0), + } + } + + /// Evaluates the probe side's join key expressions against `probed_batch`. + fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result<Vec<ArrayRef>> { + let probed_key_exprs = match P.probe_side { + L => &self.join_params.left_keys, + R => &self.join_params.right_keys, + }; + let probed_key_columns: Vec<ArrayRef> = probed_key_exprs + .iter() + .map(|expr| { + Ok(expr + .evaluate(probed_batch)? + .into_array(probed_batch.num_rows())?)
+ }) + .collect::<Result<_>>()?; + Ok(probed_key_columns) + } + + /// Assembles probe and build columns (left side first) into one output + /// batch and sends it, updating the output row count and send time. + async fn flush(&self, probe_cols: Vec<ArrayRef>, build_cols: Vec<ArrayRef>) -> Result<()> { + let output_batch = RecordBatch::try_new( + self.join_params.output_schema.clone(), + match P.probe_side { + L => [probe_cols, build_cols].concat(), + R => [build_cols, probe_cols].concat(), + }, + )?; + self.output_rows.fetch_add(output_batch.num_rows(), Relaxed); + + let timer = self.send_output_time.timer(); + self.output_sender.send(Ok(output_batch), None).await; + drop(timer); + Ok(()) + } + + /// Keeps only candidate pairs whose key values are equal and non-null, marks + /// the surviving rows as joined on both sides and emits them. + async fn flush_hash_joined( + mut self: Pin<&mut Self>, + probed_batch: &RecordBatch, + probed_key_columns: &[ArrayRef], + probed_joined: &mut BitVec, + mut hash_joined_probe_indices: Vec<u32>, + mut hash_joined_build_indices: Vec<u32>, + ) -> Result<()> { + filter_joined_indices( + probed_key_columns, + self.map.key_columns(), + &mut hash_joined_probe_indices, + &mut hash_joined_build_indices, + )?; + let probe_indices = hash_joined_probe_indices; + let build_indices = hash_joined_build_indices; + + let pprojected = match P.probe_side { + L => self + .join_params + .projection + .project_left(probed_batch.columns()), + R => self + .join_params + .projection + .project_right(probed_batch.columns()), + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + for &idx in &probe_indices { + probed_joined.set(idx as usize, true); + } + // fast path: reuse the projected columns as-is when this flush selects + // every probed row exactly once, in order. `probed_joined` cannot be used + // for this check because it accumulates over all flushes of the batch. + let is_identity = probe_indices.iter().copied().eq(0..probed_batch.num_rows() as u32); + let pcols = if is_identity { + pprojected + } else { + take_cols(&pprojected, probe_indices)?
+ }; + + for &idx in &build_indices { + self.map_joined.set(idx as usize, true); + } + let bcols = take_cols(&mprojected, build_indices)?; + + self.flush(pcols, bcols).await?; + Ok(()) + } +} + +#[async_trait] +impl<const P: JoinerParams> Joiner for FullJoiner<P> { + /// Joins one probed batch against the hash map; for probe-side outer joins, + /// also emits the batch's unjoined rows padded with nulls on the build side. + async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> { + let mut hash_joined_probe_indices: Vec<u32> = vec![]; + let mut hash_joined_build_indices: Vec<u32> = vec![]; + let mut probed_joined = bitvec![0; probed_batch.num_rows()]; + let batch_size = self.join_params.batch_size.max(probed_batch.num_rows()); + + let probed_key_columns = self.create_probed_key_columns(&probed_batch)?; + let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?; + + // join by hash code + for (row_idx, &hash) in probed_hashes.iter().enumerate() { + let mut maybe_joined = false; + if let Some(entries) = self.map.entry_indices(hash) { + for map_idx in entries { + hash_joined_probe_indices.push(row_idx as u32); + hash_joined_build_indices.push(map_idx); + } + maybe_joined = true; + } + + if maybe_joined && hash_joined_probe_indices.len() > batch_size { + self.as_mut() + .flush_hash_joined( + &probed_batch, + &probed_key_columns, + &mut probed_joined, + std::mem::take(&mut hash_joined_probe_indices), + std::mem::take(&mut hash_joined_build_indices), + ) + .await?; + } + } + if !hash_joined_probe_indices.is_empty() { + self.as_mut() + .flush_hash_joined( + &probed_batch, + &probed_key_columns, + &mut probed_joined, + hash_joined_probe_indices, + hash_joined_build_indices, + ) + .await?; + } + + // output unjoined rows of probed side + if P.probe_side_outer { + let probed_unjoined_indices = probed_joined + .iter() + .enumerate() + .filter(|(_, joined)| !**joined) + .map(|(idx, _)| idx as u32) + .collect::<Vec<_>>(); + + let pprojected = match P.probe_side { + L => self + .join_params + .projection + .project_left(probed_batch.columns()), + R => self + .join_params +
.projection + .project_right(probed_batch.columns()), + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + + let bcols = mprojected + .iter() + .map(|col| new_null_array(col.data_type(), probed_unjoined_indices.len())) + .collect::<Vec<_>>(); + + let pcols = take_cols(&pprojected, probed_unjoined_indices)?; + self.as_mut().flush(pcols, bcols).await?; + } + Ok(()) + } + + /// Emits, for build-side outer joins, the hash map rows that no probed batch + /// joined with, padded with nulls on the probe side. + async fn finish(mut self: Pin<&mut Self>) -> Result<()> { + // output unjoined rows of build side + let map_joined = std::mem::take(&mut self.map_joined); + if P.build_side_outer { + let map_unjoined_indices = map_joined + .into_iter() + .enumerate() + .filter(|(_, joined)| !joined) + .map(|(idx, _)| idx as u32) + .collect::<Vec<_>>(); + + let pschema = match P.probe_side { + L => &self.join_params.left_schema, + R => &self.join_params.right_schema, + }; + let mprojected = match P.probe_side { + L => self + .join_params + .projection + .project_right(self.map.data_batch().columns()), + R => self + .join_params + .projection + .project_left(self.map.data_batch().columns()), + }; + + // null-pad the probe side through the same projection that `join` applies, + // so the column count and order match the output schema + let pnulls = pschema + .fields() + .iter() + .map(|field| new_null_array(field.data_type(), map_unjoined_indices.len())) + .collect::<Vec<_>>(); + let pcols = match P.probe_side { + L => self.join_params.projection.project_left(&pnulls), + R => self.join_params.projection.project_right(&pnulls), + }; + let bcols = take_cols(&mprojected, map_unjoined_indices)?; + self.as_mut().flush(pcols, bcols).await?; + } + Ok(()) + } + + fn total_send_output_time(&self) -> usize { + self.send_output_time.value() + } + + fn num_output_rows(&self) -> usize { + self.output_rows.load(Relaxed) + } +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs new file mode 100644 index 0000000..57d934c --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/mod.rs
@@ -0,0 +1,155 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +use arrow::{ + array::*, + datatypes::{DataType, IntervalUnit, TimeUnit}, +}; +use datafusion::common::Result; +use datafusion_ext_commons::{df_execution_err, downcast_any}; + +pub mod full_join; +pub mod semi_join; + +/// Side of the join whose batches are probed against the hash map. +#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)] +pub enum ProbeSide { + L, + R, +} + +/// Filters candidate row pairs `(indices1[i], indices2[i])` in place, keeping +/// only pairs whose keys are non-null and equal in every key column (struct +/// keys are compared field by field). Kept pairs preserve their relative order. +/// +/// Returns an error if paired key columns differ in type or the type is unsupported. +fn filter_joined_indices( + key_columns1: &[ArrayRef], + key_columns2: &[ArrayRef], + indices1: &mut Vec<u32>, + indices2: &mut Vec<u32>, +) -> Result<()> { + /// Applies the filter for a single pair of key columns. + fn filter_one( + key_column1: &ArrayRef, + key_column2: &ArrayRef, + indices1: &mut Vec<u32>, + indices2: &mut Vec<u32>, + ) -> Result<()> { + macro_rules!
filter_atomic { + ($cast_type:ty) => {{ + let col1 = downcast_any!(key_column1, $cast_type)?; + let col2 = downcast_any!(key_column2, $cast_type)?; + // compact the kept pairs to the front, then truncate both vectors + let mut valid_count = 0; + for i in 0..indices1.len() { + let idx1 = indices1[i] as usize; + let idx2 = indices2[i] as usize; + if col1.is_valid(idx1) && col2.is_valid(idx2) && { + let v1 = col1.value(idx1); + let v2 = col2.value(idx2); + v1 == v2 + } { + indices1[valid_count] = indices1[i]; + indices2[valid_count] = indices2[i]; + valid_count += 1; + } + } + indices1.truncate(valid_count); + indices2.truncate(valid_count); + }}; + } + + let dt1 = key_column1.data_type(); + let dt2 = key_column2.data_type(); + if dt1 != dt2 { + return df_execution_err!("join key data type not matched: {dt1:?} <-> {dt2:?}"); + } + match dt1 { + DataType::Null => { + // a null-typed key never matches anything + indices1.clear(); + indices2.clear(); + } + DataType::Boolean => filter_atomic!(BooleanArray), + DataType::Int8 => filter_atomic!(Int8Array), + DataType::Int16 => filter_atomic!(Int16Array), + DataType::Int32 => filter_atomic!(Int32Array), + DataType::Int64 => filter_atomic!(Int64Array), + DataType::UInt8 => filter_atomic!(UInt8Array), + DataType::UInt16 => filter_atomic!(UInt16Array), + DataType::UInt32 => filter_atomic!(UInt32Array), + DataType::UInt64 => filter_atomic!(UInt64Array), + DataType::Float16 => filter_atomic!(Float16Array), + DataType::Float32 => filter_atomic!(Float32Array), + DataType::Float64 => filter_atomic!(Float64Array), + DataType::Timestamp(unit, _) => match unit { + TimeUnit::Second => filter_atomic!(TimestampSecondArray), + TimeUnit::Millisecond => filter_atomic!(TimestampMillisecondArray), + TimeUnit::Microsecond => filter_atomic!(TimestampMicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(TimestampNanosecondArray), + }, + DataType::Date32 => filter_atomic!(Date32Array), + DataType::Date64 => filter_atomic!(Date64Array), + DataType::Time32(unit) => match unit { + TimeUnit::Second => filter_atomic!(Time32SecondArray), +
TimeUnit::Millisecond => filter_atomic!(Time32MillisecondArray), + // arrow defines Time32 only for second and millisecond units + _ => return df_execution_err!("unsupported time32 unit: {unit:?}"), + }, + DataType::Time64(unit) => match unit { + TimeUnit::Microsecond => filter_atomic!(Time64MicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(Time64NanosecondArray), + _ => return df_execution_err!("unsupported time64 unit: {unit:?}"), + }, + DataType::Duration(unit) => match unit { + TimeUnit::Second => filter_atomic!(DurationSecondArray), + TimeUnit::Millisecond => filter_atomic!(DurationMillisecondArray), + TimeUnit::Microsecond => filter_atomic!(DurationMicrosecondArray), + TimeUnit::Nanosecond => filter_atomic!(DurationNanosecondArray), + }, + DataType::Interval(unit) => match unit { + IntervalUnit::YearMonth => filter_atomic!(IntervalYearMonthArray), + IntervalUnit::DayTime => filter_atomic!(IntervalDayTimeArray), + IntervalUnit::MonthDayNano => filter_atomic!(IntervalMonthDayNanoArray), + }, + DataType::Binary => filter_atomic!(BinaryArray), + DataType::FixedSizeBinary(_) => filter_atomic!(FixedSizeBinaryArray), + DataType::LargeBinary => filter_atomic!(LargeBinaryArray), + DataType::Utf8 => filter_atomic!(StringArray), + DataType::LargeUtf8 => filter_atomic!(LargeStringArray), + DataType::List(_) => filter_atomic!(ListArray), + DataType::FixedSizeList(..) => filter_atomic!(FixedSizeListArray), + DataType::LargeList(_) => filter_atomic!(LargeListArray), + DataType::Struct(_) => filter_joined_indices( + key_column1.as_struct().columns(), + key_column2.as_struct().columns(), + indices1, + indices2, + )?, + DataType::Decimal128(..) => filter_atomic!(Decimal128Array), + DataType::Decimal256(..) => filter_atomic!(Decimal256Array), + DataType::Map(..)
=> filter_atomic!(MapArray), + dt => { + return df_execution_err!("unsupported data type: {dt:?}"); + } + } + Ok(()) + } + + for (key_column1, key_column2) in key_columns1.iter().zip(key_columns2) { + filter_one(key_column1, key_column2, indices1, indices2)?; + } + Ok(()) +}
diff --git a/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs new file mode 100644 index 0000000..8c168f0 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/bhj/semi_join.rs
@@ -0,0 +1,315 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    pin::Pin,
+    sync::{
+        atomic::{AtomicUsize, Ordering::Relaxed},
+        Arc,
+    },
+};
+
+use arrow::array::{ArrayRef, BooleanArray, RecordBatch, RecordBatchOptions};
+use async_trait::async_trait;
+use bitvec::{bitvec, prelude::BitVec};
+use datafusion::{common::Result, physical_plan::metrics::Time};
+
+use crate::{
+    broadcast_join_exec::Joiner,
+    common::{batch_selection::take_cols, output::WrappedRecordBatchSender},
+    joins::{
+        bhj::{
+            filter_joined_indices,
+            semi_join::{
+                ProbeSide::{L, R},
+                SemiMode::{Anti, Existence, Semi},
+            },
+            ProbeSide,
+        },
+        join_hash_map::{join_create_hashes, JoinHashMap},
+        JoinParams,
+    },
+};
+
+/// Which rows a [`SemiJoiner`] outputs.
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum SemiMode {
+    /// Output-side rows that matched at least one row of the other side.
+    Semi,
+    /// Output-side rows that matched no row of the other side.
+    Anti,
+    /// All output-side rows, followed by a boolean column telling whether each matched.
+    Existence,
+}
+
+/// Compile-time configuration of a [`SemiJoiner`].
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+    /// Side whose batches are streamed and probed against the hash map.
+    probe_side: ProbeSide,
+    /// `true`: output rows come from the probed side and are emitted per batch in `join`;
+    /// `false`: output rows come from the hash map side and are emitted in `finish`.
+    probe_is_join_side: bool,
+    mode: SemiMode,
+}
+
+impl JoinerParams {
+    const fn new(probe_side: ProbeSide, probe_is_join_side: bool, mode: SemiMode) -> Self {
+        Self {
+            probe_side,
+            probe_is_join_side,
+            mode,
+        }
+    }
+}
+
+const LEFT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(L, true, Semi);
+const LEFT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(L, true, Anti);
+const LEFT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(L, false, Semi);
+const LEFT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(L, false, Anti);
+const LEFT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(L, true, Existence);
+const RIGHT_PROBED_LEFT_SEMI: JoinerParams = JoinerParams::new(R, false, Semi);
+const RIGHT_PROBED_LEFT_ANTI: JoinerParams = JoinerParams::new(R, false, Anti);
+const RIGHT_PROBED_RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true, Semi);
+const RIGHT_PROBED_RIGHT_ANTI: JoinerParams = JoinerParams::new(R, true, Anti);
+const RIGHT_PROBED_EXISTENCE: JoinerParams = JoinerParams::new(R, false, Existence);
+
+pub type LProbedLeftSemiJoiner = SemiJoiner<LEFT_PROBED_LEFT_SEMI>;
+pub type LProbedLeftAntiJoiner = SemiJoiner<LEFT_PROBED_LEFT_ANTI>;
+pub type LProbedRightSemiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_SEMI>;
+pub type LProbedRightAntiJoiner = SemiJoiner<LEFT_PROBED_RIGHT_ANTI>;
+pub type LProbedExistenceJoiner = SemiJoiner<LEFT_PROBED_EXISTENCE>;
+pub type RProbedLeftSemiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_SEMI>;
+pub type RProbedLeftAntiJoiner = SemiJoiner<RIGHT_PROBED_LEFT_ANTI>;
+pub type RProbedRightSemiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_SEMI>;
+pub type RProbedRightAntiJoiner = SemiJoiner<RIGHT_PROBED_RIGHT_ANTI>;
+pub type RProbedExistenceJoiner = SemiJoiner<RIGHT_PROBED_EXISTENCE>;
+
+/// Broadcast-hash-join joiner for semi, anti and existence joins.
+///
+/// Each probed batch is matched against `map` by join-key hash; candidate
+/// pairs are then passed through `filter_joined_indices` before matched rows
+/// are recorded in per-batch and per-map-row bitmaps.
+pub struct SemiJoiner<const P: JoinerParams> {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    /// One bit per hash map row; set once the row has been joined.
+    map_joined: BitVec,
+    map: Arc<JoinHashMap>,
+    send_output_time: Time,
+    output_rows: AtomicUsize,
+}
+
+impl<const P: JoinerParams> SemiJoiner<P> {
+    pub fn new(
+        join_params: JoinParams,
+        map: Arc<JoinHashMap>,
+        output_sender: Arc<WrappedRecordBatchSender>,
+    ) -> Self {
+        let map_joined = bitvec![0; map.data_batch().num_rows()];
+        Self {
+            join_params,
+            output_sender,
+            map,
+            map_joined,
+            send_output_time: Time::new(),
+            output_rows: AtomicUsize::new(0),
+        }
+    }
+
+    /// Evaluates the probe side's join key expressions on `probed_batch`.
+    fn create_probed_key_columns(&self, probed_batch: &RecordBatch) -> Result<Vec<ArrayRef>> {
+        let probed_key_exprs = match P.probe_side {
+            L => &self.join_params.left_keys,
+            R => &self.join_params.right_keys,
+        };
+        let probed_key_columns: Vec<ArrayRef> = probed_key_exprs
+            .iter()
+            .map(|expr| {
+                Ok(expr
+                    .evaluate(probed_batch)?
+                    .into_array(probed_batch.num_rows())?)
+            })
+            .collect::<Result<_>>()?;
+        Ok(probed_key_columns)
+    }
+
+    /// Sends `cols` downstream as one output batch of `num_rows` rows.
+    ///
+    /// The row count is passed explicitly so that an empty projection (no
+    /// columns) still produces a valid batch. Empty batches are not sent.
+    async fn flush(&self, cols: Vec<ArrayRef>, num_rows: usize) -> Result<()> {
+        if num_rows == 0 {
+            return Ok(());
+        }
+        let output_batch = RecordBatch::try_new_with_options(
+            self.join_params.output_schema.clone(),
+            cols,
+            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+        )?;
+        self.output_rows.fetch_add(num_rows, Relaxed);
+
+        let timer = self.send_output_time.timer();
+        self.output_sender.send(Ok(output_batch), None).await;
+        drop(timer);
+        Ok(())
+    }
+
+    /// Filters candidate (probe, build) index pairs with `filter_joined_indices`
+    /// (presumably dropping hash collisions whose key values differ), then marks
+    /// the surviving probed rows and hash map rows as joined.
+    fn flush_hash_joined(
+        mut self: Pin<&mut Self>,
+        probed_key_columns: &[ArrayRef],
+        probed_joined: &mut BitVec,
+        mut hash_joined_probe_indices: Vec<u32>,
+        mut hash_joined_build_indices: Vec<u32>,
+    ) -> Result<()> {
+        filter_joined_indices(
+            probed_key_columns,
+            self.map.key_columns(),
+            &mut hash_joined_probe_indices,
+            &mut hash_joined_build_indices,
+        )?;
+        for &idx in &hash_joined_probe_indices {
+            probed_joined.set(idx as usize, true);
+        }
+        for &idx in &hash_joined_build_indices {
+            self.map_joined.set(idx as usize, true);
+        }
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
+    async fn join(mut self: Pin<&mut Self>, probed_batch: RecordBatch) -> Result<()> {
+        let mut hash_joined_probe_indices: Vec<u32> = vec![];
+        let mut hash_joined_build_indices: Vec<u32> = vec![];
+        let mut probed_joined = bitvec![0; probed_batch.num_rows()];
+
+        let probed_key_columns = self.create_probed_key_columns(&probed_batch)?;
+        let probed_hashes = join_create_hashes(probed_batch.num_rows(), &probed_key_columns)?;
+
+        // join by hash code; candidate pairs are verified roughly every `batch_size` pairs
+        for (row_idx, &hash) in probed_hashes.iter().enumerate() {
+            let mut maybe_joined = false;
+            if let Some(entries) = self.map.entry_indices(hash) {
+                for map_idx in entries {
+                    hash_joined_probe_indices.push(row_idx as u32);
+                    hash_joined_build_indices.push(map_idx);
+                }
+                maybe_joined = true;
+            }
+
+            if maybe_joined && hash_joined_probe_indices.len() >= self.join_params.batch_size {
+                self.as_mut().flush_hash_joined(
+                    &probed_key_columns,
+                    &mut probed_joined,
+                    std::mem::take(&mut hash_joined_probe_indices),
+                    std::mem::take(&mut hash_joined_build_indices),
+                )?;
+            }
+        }
+        if !hash_joined_probe_indices.is_empty() {
+            self.as_mut().flush_hash_joined(
+                &probed_key_columns,
+                &mut probed_joined,
+                hash_joined_probe_indices,
+                hash_joined_build_indices,
+            )?;
+        }
+
+        // probed side is the output side: emit this batch's result now
+        if P.probe_is_join_side {
+            let pprojected = match P.probe_side {
+                L => self
+                    .join_params
+                    .projection
+                    .project_left(probed_batch.columns()),
+                R => self
+                    .join_params
+                    .projection
+                    .project_right(probed_batch.columns()),
+            };
+            let (pcols, num_rows) = match P.mode {
+                Semi | Anti => {
+                    // Semi keeps joined rows, Anti keeps unjoined rows
+                    let probed_indices = probed_joined
+                        .into_iter()
+                        .enumerate()
+                        .filter(|(_, joined)| (P.mode == Semi) ^ !joined)
+                        .map(|(idx, _)| idx as u32)
+                        .collect::<Vec<_>>();
+                    let num_rows = probed_indices.len();
+                    (take_cols(&pprojected, probed_indices)?, num_rows)
+                }
+                Existence => {
+                    let exists_col: ArrayRef = Arc::new(BooleanArray::from(
+                        probed_joined.into_iter().collect::<Vec<_>>(),
+                    ));
+                    ([pprojected, vec![exists_col]].concat(), probed_batch.num_rows())
+                }
+            };
+            self.as_mut().flush(pcols, num_rows).await?;
+        }
+        Ok(())
+    }
+
+    async fn finish(mut self: Pin<&mut Self>) -> Result<()> {
+        // hash map side is the output side: emit its result after all probing
+        if !P.probe_is_join_side {
+            let mprojected = match P.probe_side {
+                L => self
+                    .join_params
+                    .projection
+                    .project_right(self.map.data_batch().columns()),
+                R => self
+                    .join_params
+                    .projection
+                    .project_left(self.map.data_batch().columns()),
+            };
+            let map_joined = std::mem::take(&mut self.map_joined);
+            let num_map_rows = map_joined.len();
+            let (pcols, num_rows) = match P.mode {
+                Semi | Anti => {
+                    let map_indices = map_joined
+                        .into_iter()
+                        .enumerate()
+                        .filter(|(_, joined)| (P.mode == Semi) ^ !joined)
+                        .map(|(idx, _)| idx as u32)
+                        .collect::<Vec<_>>();
+                    let num_rows = map_indices.len();
+                    (take_cols(&mprojected, map_indices)?, num_rows)
+                }
+                Existence => {
+                    let exists_col: ArrayRef = Arc::new(BooleanArray::from(
+                        map_joined.into_iter().collect::<Vec<_>>(),
+                    ));
+                    ([mprojected, vec![exists_col]].concat(), num_map_rows)
+                }
+            };
+            self.as_mut().flush(pcols, num_rows).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows.load(Relaxed)
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs new file mode 100644 index 0000000..8bd1a57 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_hash_map.rs
@@ -0,0 +1,380 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{
+    io::{Cursor, Read, Write},
+    slice::{from_raw_parts, from_raw_parts_mut},
+    sync::Arc,
+};
+
+use arrow::{
+    array::{ArrayRef, AsArray, BinaryBuilder, RecordBatch},
+    datatypes::{DataType, Field, FieldRef, Schema, SchemaRef},
+};
+use byteorder::{NativeEndian, ReadBytesExt, WriteBytesExt};
+use datafusion::{common::Result, physical_expr::PhysicalExprRef};
+use datafusion_ext_commons::spark_hash::create_hashes;
+use hashbrown::HashMap;
+use itertools::Itertools;
+use once_cell::sync::OnceCell;
+
+use crate::common::batch_selection::take_batch;
+
+/// Bucketed hash table mapping 32-bit join-key hashes to row indices of a
+/// data batch.
+///
+/// Bucket `e` owns `item_indices[entry_offsets[e]..][..entry_lens[e]]`, i.e.
+/// the rows whose hash `h` satisfies `h % num_entries == e`. `item_hashes[i]`
+/// is the hash of row `i`, used to skip rows that merely share a bucket.
+pub struct Table {
+    entry_offsets: Vec<u32>,
+    entry_lens: Vec<u32>,
+    item_indices: Vec<u32>,
+    item_hashes: Vec<u32>,
+}
+
+impl Table {
+    /// Creates a table with no rows (a single empty bucket).
+    pub fn new_empty() -> Self {
+        let num_entries = Self::num_entries_of_rows(0);
+        Self {
+            entry_offsets: vec![0; num_entries],
+            entry_lens: vec![0; num_entries],
+            item_indices: vec![],
+            item_hashes: vec![],
+        }
+    }
+
+    /// Builds a table over `key_columns`, which must be row-aligned with `data_batch`.
+    ///
+    /// Returns the table together with `data_batch` reordered by ascending key
+    /// hash; the row indices stored in the table refer to the reordered batch,
+    /// not to the input order of `data_batch` / `key_columns`.
+    ///
+    /// # Panics
+    /// Panics if `num_rows` is 2^30 or more.
+    pub fn try_from_key_columns(
+        num_rows: usize,
+        data_batch: RecordBatch,
+        key_columns: &[ArrayRef],
+    ) -> Result<(Self, RecordBatch)> {
+        assert!(
+            num_rows < 1073741824,
+            "join hash table: number of rows exceeded 2^30: {num_rows}"
+        );
+
+        let num_entries = Self::num_entries_of_rows(num_rows) as u32;
+        let item_hashes = join_create_hashes(num_rows, key_columns)?;
+
+        // sort record batch by hashes for better compression and data locality
+        let (indices, item_hashes): (Vec<usize>, Vec<u32>) = item_hashes
+            .into_iter()
+            .enumerate()
+            .sorted_unstable_by_key(|(_idx, hash)| *hash)
+            .unzip();
+        let data_batch = take_batch(data_batch, indices)?;
+
+        // group (sorted) row indices by bucket
+        let mut entries_to_row_indices: HashMap<u32, Vec<u32>> = HashMap::new();
+        for (row_idx, hash) in item_hashes.iter().enumerate() {
+            let entry = hash % num_entries;
+            entries_to_row_indices
+                .entry(entry)
+                .or_default()
+                .push(row_idx as u32);
+        }
+
+        // lay buckets out contiguously in bucket order; empty buckets get length 0
+        let mut entry_offsets = Vec::with_capacity(num_entries as usize);
+        let mut entry_lens = Vec::with_capacity(num_entries as usize);
+        let mut item_indices = Vec::with_capacity(num_rows);
+        for entry in 0..num_entries {
+            entry_offsets.push(item_indices.len() as u32);
+            match entries_to_row_indices.get(&entry) {
+                Some(row_indices) => {
+                    entry_lens.push(row_indices.len() as u32);
+                    item_indices.extend_from_slice(row_indices);
+                }
+                None => entry_lens.push(0),
+            }
+        }
+        let new = Self {
+            entry_offsets,
+            entry_lens,
+            item_indices,
+            item_hashes,
+        };
+        Ok((new, data_batch))
+    }
+
+    /// Deserializes a table written by [`Table::try_into_raw_bytes`].
+    ///
+    /// Layout (all native-endian `u32`): row count `n`, then `entry_offsets`
+    /// and `entry_lens` (`3n + 1` values each), then `item_indices` and
+    /// `item_hashes` (`n` values each).
+    pub fn try_from_raw_bytes(raw_bytes: &[u8]) -> Result<Self> {
+        let mut cursor = Cursor::new(raw_bytes);
+        let num_rows = cursor.read_u32::<NativeEndian>()? as usize;
+        let num_entries = Self::num_entries_of_rows(num_rows);
+
+        let mut new = Self {
+            entry_offsets: vec![0; num_entries],
+            entry_lens: vec![0; num_entries],
+            item_indices: vec![0; num_rows],
+            item_hashes: vec![0; num_rows],
+        };
+
+        unsafe {
+            // safety: each Vec<u32> is fully allocated above, so viewing it as
+            // `len * 4` bytes stays in bounds; any bit pattern is a valid u32
+            cursor.read_exact(from_raw_parts_mut(
+                new.entry_offsets.as_mut_ptr() as *mut u8,
+                num_entries * 4,
+            ))?;
+            cursor.read_exact(from_raw_parts_mut(
+                new.entry_lens.as_mut_ptr() as *mut u8,
+                num_entries * 4,
+            ))?;
+            cursor.read_exact(from_raw_parts_mut(
+                new.item_indices.as_mut_ptr() as *mut u8,
+                num_rows * 4,
+            ))?;
+            cursor.read_exact(from_raw_parts_mut(
+                new.item_hashes.as_mut_ptr() as *mut u8,
+                num_rows * 4,
+            ))?;
+        }
+        Ok(new)
+    }
+
+    /// Serializes the table into the layout read by [`Table::try_from_raw_bytes`].
+    pub fn try_into_raw_bytes(self) -> Result<Vec<u8>> {
+        let num_entries = self.entry_offsets.len();
+        let num_rows = self.item_indices.len();
+        let mut raw_bytes = Vec::with_capacity(num_entries * 8 + num_rows * 8 + 4);
+
+        raw_bytes.write_u32::<NativeEndian>(num_rows as u32)?;
+        unsafe {
+            // safety: write integer arrays as raw bytes
+            raw_bytes.write_all(from_raw_parts(
+                self.entry_offsets.as_ptr() as *const u8,
+                num_entries * 4,
+            ))?;
+            raw_bytes.write_all(from_raw_parts(
+                self.entry_lens.as_ptr() as *const u8,
+                num_entries * 4,
+            ))?;
+            raw_bytes.write_all(from_raw_parts(
+                self.item_indices.as_ptr() as *const u8,
+                num_rows * 4,
+            ))?;
+            raw_bytes.write_all(from_raw_parts(
+                self.item_hashes.as_ptr() as *const u8,
+                num_rows * 4,
+            ))?;
+        }
+        Ok(raw_bytes)
+    }
+
+    /// Returns the rows whose stored hash equals `hash`, or `None` if the
+    /// bucket `hash` falls into is empty. The iterator may still yield nothing
+    /// when the bucket only holds rows with other hashes.
+    pub fn entry<'a>(&'a self, hash: u32) -> Option<impl Iterator<Item = u32> + 'a> {
+        let entry = hash % (self.entry_offsets.len() as u32);
+        let len = self.entry_lens[entry as usize] as usize;
+        if len > 0 {
+            let offset = self.entry_offsets[entry as usize] as usize;
+            Some(
+                self.item_indices[offset..][..len]
+                    .iter()
+                    .cloned()
+                    .filter(move |&idx| self.item_hashes[idx as usize] == hash),
+            )
+        } else {
+            None
+        }
+    }
+
+    /// Number of buckets used for `num_rows` rows (always at least one).
+    fn num_entries_of_rows(num_rows: usize) -> usize {
+        num_rows * 3 + 1
+    }
+}
+
+/// Build side of a hash join: the data rows, their evaluated join key
+/// columns (row-aligned with `data_batch`), and a [`Table`] indexing them.
+pub struct JoinHashMap {
+    data_batch: RecordBatch,
+    key_columns: Vec<ArrayRef>,
+    table: Table,
+}
+
+impl JoinHashMap {
+    /// Builds a hash map from `data_batch`, keyed by `key_exprs`.
+    ///
+    /// The stored data batch is reordered by key hash (see
+    /// [`Table::try_from_key_columns`]).
+    pub fn try_from_data_batch(
+        data_batch: RecordBatch,
+        key_exprs: &[PhysicalExprRef],
+    ) -> Result<JoinHashMap> {
+        let key_columns = evaluate_key_columns(&data_batch, key_exprs)?;
+        let (table, data_batch) =
+            Table::try_from_key_columns(data_batch.num_rows(), data_batch, &key_columns)?;
+
+        // the table's row indices refer to the reordered batch, so re-evaluate
+        // the keys on it; reusing `key_columns` would misalign keys and rows
+        let key_columns = evaluate_key_columns(&data_batch, key_exprs)?;
+        Ok(JoinHashMap {
+            data_batch,
+            key_columns,
+            table,
+        })
+    }
+
+    /// Rebuilds a hash map from a batch produced by [`JoinHashMap::into_hash_map_batch`].
+    pub fn try_from_hash_map_batch(
+        hash_map_batch: RecordBatch,
+        key_exprs: &[PhysicalExprRef],
+    ) -> Result<Self> {
+        let mut data_batch = hash_map_batch;
+        // the last column carries the serialized table in its first row
+        let table_col = data_batch.remove_column(data_batch.num_columns() - 1);
+        let table = if data_batch.num_rows() == 0 {
+            // an empty map is serialized as an empty batch with no table bytes
+            Table::new_empty()
+        } else {
+            Table::try_from_raw_bytes(table_col.as_binary::<i32>().value(0))?
+        };
+        let key_columns = evaluate_key_columns(&data_batch, key_exprs)?;
+        Ok(Self {
+            data_batch,
+            key_columns,
+            table,
+        })
+    }
+
+    /// Creates an empty hash map.
+    ///
+    /// NOTE(review): the data batch is built from `hash_map_schema` as given,
+    /// unlike `try_from_hash_map_batch`, which strips the trailing table
+    /// column — confirm which schema callers pass here.
+    pub fn try_new_empty(
+        hash_map_schema: SchemaRef,
+        key_exprs: &[PhysicalExprRef],
+    ) -> Result<Self> {
+        let table = Table::new_empty();
+        let data_batch = RecordBatch::new_empty(hash_map_schema);
+        let key_columns = evaluate_key_columns(&data_batch, key_exprs)?;
+        Ok(Self {
+            data_batch,
+            key_columns,
+            table,
+        })
+    }
+
+    pub fn data_schema(&self) -> SchemaRef {
+        self.data_batch().schema()
+    }
+
+    pub fn data_batch(&self) -> &RecordBatch {
+        &self.data_batch
+    }
+
+    /// Join key columns, row-aligned with [`JoinHashMap::data_batch`].
+    pub fn key_columns(&self) -> &[ArrayRef] {
+        &self.key_columns
+    }
+
+    /// Row indices (into `data_batch`) whose key hash equals `hash`; see [`Table::entry`].
+    pub fn entry_indices<'a>(&'a self, hash: u32) -> Option<impl Iterator<Item = u32> + 'a> {
+        self.table.entry(hash)
+    }
+
+    /// Serializes the map as its data columns plus a trailing nullable binary
+    /// column holding the table bytes in row 0 (null in all other rows).
+    /// An empty map yields an empty batch.
+    pub fn into_hash_map_batch(self) -> Result<RecordBatch> {
+        let schema = join_hash_map_schema(&self.data_batch.schema());
+        if self.data_batch.num_rows() == 0 {
+            return Ok(RecordBatch::new_empty(schema));
+        }
+        let mut table_col_builder = BinaryBuilder::new();
+        table_col_builder.append_value(&self.table.try_into_raw_bytes()?);
+        for _ in 1..self.data_batch.num_rows() {
+            table_col_builder.append_null();
+        }
+        let table_col: ArrayRef = Arc::new(table_col_builder.finish());
+        Ok(RecordBatch::try_new(
+            schema,
+            vec![self.data_batch.columns().to_vec(), vec![table_col]].concat(),
+        )?)
+    }
+}
+
+/// Evaluates `key_exprs` on `batch`, materializing one array per expression
+/// with `batch.num_rows()` rows.
+fn evaluate_key_columns(
+    batch: &RecordBatch,
+    key_exprs: &[PhysicalExprRef],
+) -> Result<Vec<ArrayRef>> {
+    key_exprs
+        .iter()
+        .map(|expr| Ok(expr.evaluate(batch)?.into_array(batch.num_rows())?))
+        .collect::<Result<_>>()
+}
+
+/// Drops the trailing table column from a hash map batch schema.
+#[inline]
+pub fn join_data_schema(hash_map_schema: &SchemaRef) -> SchemaRef {
+    Arc::new(Schema::new(
+        hash_map_schema
+            .fields()
+            .iter()
+            .take(hash_map_schema.fields().len() - 1) // exclude hash map column
+            .cloned()
+            .collect::<Vec<_>>(),
+    ))
+}
+
+/// Schema of a serialized hash map: every data field made nullable, followed
+/// by the binary `~TABLE` column.
+#[inline]
+pub fn join_hash_map_schema(data_schema: &SchemaRef) -> SchemaRef {
+    Arc::new(Schema::new(
+        data_schema
+            .fields()
+            .iter()
+            .map(|field| Arc::new(field.as_ref().clone().with_nullable(true)))
+            .chain(std::iter::once(join_table_field()))
+            .collect::<Vec<_>>(),
+    ))
+}
+
+/// Computes one `u32` hash per row from the join key columns, seeding every row with a constant.
+#[inline]
+pub fn join_create_hashes(num_rows: usize, key_columns: &[ArrayRef]) -> Result<Vec<u32>> {
+    const JOIN_HASH_RANDOM_SEED: u32 = 0x90ec4058;
+    let mut hashes = vec![JOIN_HASH_RANDOM_SEED; num_rows];
+    create_hashes(key_columns, &mut hashes)?;
+    Ok(hashes)
+}
+
+#[inline]
+fn join_table_field() -> FieldRef {
+    static BHJ_KEY_FIELD: OnceCell<FieldRef> = OnceCell::new();
+    BHJ_KEY_FIELD
+        .get_or_init(|| Arc::new(Field::new("~TABLE", DataType::Binary, true)))
+        .clone()
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/join_utils.rs b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs new file mode 100644 index 0000000..076cfa1 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/join_utils.rs
@@ -0,0 +1,76 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use datafusion::common::{DataFusionError, Result};
+use datafusion_ext_commons::df_execution_err;
+
+/// Join types understood by the native join operators.
+///
+/// Mirrors DataFusion's `JoinType`, plus `Existence`, which has no DataFusion
+/// counterpart.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)]
+pub enum JoinType {
+    Inner,
+    Left,
+    Right,
+    Full,
+    LeftAnti,
+    RightAnti,
+    LeftSemi,
+    RightSemi,
+    Existence,
+}
+
+/// Maps to DataFusion's join type; `Existence` has no equivalent and is rejected.
+impl TryFrom<JoinType> for datafusion::prelude::JoinType {
+    type Error = DataFusionError;
+
+    fn try_from(value: JoinType) -> Result<Self> {
+        use datafusion::prelude::JoinType as DfJoinType;
+
+        let converted = match value {
+            JoinType::Inner => DfJoinType::Inner,
+            JoinType::Left => DfJoinType::Left,
+            JoinType::Right => DfJoinType::Right,
+            JoinType::Full => DfJoinType::Full,
+            JoinType::LeftAnti => DfJoinType::LeftAnti,
+            JoinType::RightAnti => DfJoinType::RightAnti,
+            JoinType::LeftSemi => DfJoinType::LeftSemi,
+            JoinType::RightSemi => DfJoinType::RightSemi,
+            other => return df_execution_err!("unsupported join type: {other:?}"),
+        };
+        Ok(converted)
+    }
+}
+
+/// Maps from DataFusion's join type; every variant has an equivalent, so this never fails.
+impl TryFrom<datafusion::prelude::JoinType> for JoinType {
+    type Error = DataFusionError;
+
+    fn try_from(value: datafusion::prelude::JoinType) -> Result<Self> {
+        use datafusion::prelude::JoinType as DfJoinType;
+
+        let converted = match value {
+            DfJoinType::Inner => JoinType::Inner,
+            DfJoinType::Left => JoinType::Left,
+            DfJoinType::Right => JoinType::Right,
+            DfJoinType::Full => JoinType::Full,
+            DfJoinType::LeftAnti => JoinType::LeftAnti,
+            DfJoinType::RightAnti => JoinType::RightAnti,
+            DfJoinType::LeftSemi => JoinType::LeftSemi,
+            DfJoinType::RightSemi => JoinType::RightSemi,
+        };
+        Ok(converted)
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/mod.rs b/native-engine/datafusion-ext-plans/src/joins/mod.rs new file mode 100644 index 0000000..3505a9a --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/mod.rs
@@ -0,0 +1,135 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::ArrayRef,
+    compute::SortOptions,
+    datatypes::{DataType, SchemaRef},
+};
+use datafusion::{common::Result, physical_expr::PhysicalExprRef};
+
+use crate::joins::{join_utils::JoinType, stream_cursor::StreamCursor};
+
+pub mod join_hash_map;
+pub mod join_utils;
+pub mod stream_cursor;
+
+// join implementations
+pub mod bhj;
+pub mod smj;
+mod test;
+
+/// Parameters shared by the join implementations.
+#[derive(Debug, Clone)]
+pub struct JoinParams {
+    pub join_type: JoinType,
+    pub left_schema: SchemaRef,
+    pub right_schema: SchemaRef,
+    pub output_schema: SchemaRef,
+    pub left_keys: Vec<PhysicalExprRef>,
+    pub right_keys: Vec<PhysicalExprRef>,
+    pub key_data_types: Vec<DataType>,
+    pub sort_options: Vec<SortOptions>,
+    pub projection: JoinProjection,
+    pub batch_size: usize,
+}
+
+/// Maps an output projection over a join's combined output schema onto the
+/// columns of each input side.
+#[derive(Debug, Clone)]
+pub struct JoinProjection {
+    /// The projected output schema.
+    pub schema: SchemaRef,
+    /// Left input schema restricted to `left`.
+    pub left_schema: SchemaRef,
+    /// Right input schema restricted to `right`.
+    pub right_schema: SchemaRef,
+    /// Left input column indices that appear in the output, in output order.
+    pub left: Vec<usize>,
+    /// Right input column indices that appear in the output, in output order.
+    pub right: Vec<usize>,
+}
+
+impl JoinProjection {
+    /// Splits `projection` (indices into `schema`) into per-side indices.
+    ///
+    /// For inner/outer joins `schema` is treated as the left fields followed by
+    /// the right fields. Semi/anti joins project only their own side. For
+    /// `Existence`, only indices that fall within the left schema are kept.
+    pub fn try_new(
+        join_type: JoinType,
+        schema: &SchemaRef,
+        left_schema: &SchemaRef,
+        right_schema: &SchemaRef,
+        projection: &[usize],
+    ) -> Result<Self> {
+        let projected_schema = Arc::new(schema.project(projection)?);
+        let num_left_fields = left_schema.fields().len();
+        let num_right_fields = right_schema.fields().len();
+
+        let (left, right): (Vec<usize>, Vec<usize>) = match join_type {
+            JoinType::Inner | JoinType::Left | JoinType::Right | JoinType::Full => {
+                let left: Vec<usize> = projection
+                    .iter()
+                    .copied()
+                    .filter(|&i| i < num_left_fields)
+                    .collect();
+                let right: Vec<usize> = projection
+                    .iter()
+                    .filter_map(|&i| i.checked_sub(num_left_fields))
+                    .filter(|&i| i < num_right_fields)
+                    .collect();
+                (left, right)
+            }
+            JoinType::LeftAnti | JoinType::LeftSemi => (projection.to_vec(), vec![]),
+            JoinType::RightAnti | JoinType::RightSemi => (vec![], projection.to_vec()),
+            JoinType::Existence => {
+                let left: Vec<usize> = projection
+                    .iter()
+                    .copied()
+                    .filter(|&i| i < num_left_fields)
+                    .collect();
+                (left, vec![])
+            }
+        };
+
+        Ok(Self {
+            schema: projected_schema,
+            left_schema: Arc::new(left_schema.project(&left)?),
+            right_schema: Arc::new(right_schema.project(&right)?),
+            left,
+            right,
+        })
+    }
+
+    /// Picks this projection's left-side columns out of `cols`.
+    pub fn project_left(&self, cols: &[ArrayRef]) -> Vec<ArrayRef> {
+        Self::pick(cols, &self.left)
+    }
+
+    /// Picks this projection's right-side columns out of `cols`.
+    pub fn project_right(&self, cols: &[ArrayRef]) -> Vec<ArrayRef> {
+        Self::pick(cols, &self.right)
+    }
+
+    /// Clones the columns at `indices` (panics if an index is out of range).
+    fn pick(cols: &[ArrayRef], indices: &[usize]) -> Vec<ArrayRef> {
+        indices.iter().map(|&i| Arc::clone(&cols[i])).collect()
+    }
+}
+
+pub type Idx = (usize, usize);
+pub type StreamCursors = (StreamCursor, StreamCursor);
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs new file mode 100644 index 0000000..5749eb0 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/existence_join.rs
@@ -0,0 +1,191 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{cmp::Ordering, pin::Pin, sync::Arc};
+
+use arrow::array::{ArrayRef, RecordBatch, RecordBatchOptions};
+use async_trait::async_trait;
+use datafusion::{common::Result, physical_plan::metrics::Time};
+use datafusion_ext_commons::suggested_output_batch_mem_size;
+
+use crate::{
+    common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
+    compare_cursor, cur_forward,
+    joins::{Idx, JoinParams, StreamCursors},
+    sort_merge_join_exec::Joiner,
+};
+
+/// Sort-merge joiner for existence joins.
+///
+/// Outputs every left row exactly once, in left cursor order, followed by a
+/// boolean column that is `true` when `compare_cursor!` found a right row
+/// with an equal key.
+pub struct ExistenceJoiner {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    /// Left row indices pending output.
+    indices: Vec<Idx>,
+    /// Exists flag for each entry of `indices`.
+    exists: Vec<bool>,
+    send_output_time: Time,
+    output_rows: usize,
+}
+
+impl ExistenceJoiner {
+    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
+        Self {
+            join_params,
+            output_sender,
+            indices: vec![],
+            exists: vec![],
+            send_output_time: Time::new(),
+            output_rows: 0,
+        }
+    }
+
+    /// Whether pending rows should be flushed: `batch_size` rows are pending, or
+    /// the cursors hold at least 6 batches and more than the suggested output
+    /// memory while the first pending row lies in an earlier left batch.
+    fn should_flush(&self, curs: &StreamCursors) -> bool {
+        if self.indices.len() >= self.join_params.batch_size {
+            return true;
+        }
+
+        if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6
+            && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
+        {
+            if let Some(first_idx) = self.indices.first() {
+                if first_idx.0 < curs.0.cur_idx.0 {
+                    return true;
+                }
+            }
+        }
+        false
+    }
+
+    /// Sends pending rows downstream: left columns gathered from the left
+    /// cursor's buffered batches, plus the boolean exists column.
+    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        let indices = std::mem::take(&mut self.indices);
+        let num_rows = indices.len();
+        let cols = interleave_batches(
+            curs.0.projected_batch_schema.clone(),
+            &curs.0.projected_batches,
+            &indices,
+        )?;
+
+        let exists = std::mem::take(&mut self.exists);
+        let exists_col: ArrayRef = Arc::new(arrow::array::BooleanArray::from(exists));
+
+        let output_batch = RecordBatch::try_new_with_options(
+            self.join_params.output_schema.clone(),
+            [cols.columns().to_vec(), vec![exists_col]].concat(),
+            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+        )?;
+
+        if output_batch.num_rows() > 0 {
+            self.output_rows += output_batch.num_rows();
+
+            let timer = self.send_output_time.timer();
+            self.output_sender.send(Ok(output_batch), None).await;
+            drop(timer);
+        }
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl Joiner for ExistenceJoiner {
+    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        while !curs.0.finished && !curs.1.finished {
+            let mut lidx = curs.0.cur_idx;
+            let mut ridx = curs.1.cur_idx;
+
+            match compare_cursor!(curs) {
+                Ordering::Less => {
+                    // left key is smaller: this left row has no matching right row
+                    self.indices.push(curs.0.cur_idx);
+                    self.exists.push(false);
+                    cur_forward!(curs.0);
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.0
+                        .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
+                }
+                Ordering::Greater => {
+                    cur_forward!(curs.1);
+                    // only left rows are output, so the right cursor need not keep rows behind
+                    // its current position (`self.indices` holds left-cursor positions)
+                    curs.1.set_min_reserved_idx(curs.1.cur_idx);
+                }
+                Ordering::Equal => {
+                    // output every left row sharing this key as existing
+                    loop {
+                        self.indices.push(lidx);
+                        self.exists.push(true);
+                        cur_forward!(curs.0);
+                        if self.should_flush(curs) {
+                            self.as_mut().flush(curs).await?;
+                        }
+                        curs.0
+                            .set_min_reserved_idx(*self.indices.first().unwrap_or(&lidx));
+
+                        if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
+                            lidx = curs.0.cur_idx;
+                            continue;
+                        }
+                        break;
+                    }
+
+                    // skip all right equal rows
+                    loop {
+                        cur_forward!(curs.1);
+                        curs.1.set_min_reserved_idx(ridx);
+
+                        if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
+                            ridx = curs.1.cur_idx;
+                            continue;
+                        }
+                        break;
+                    }
+                }
+            }
+        }
+
+        // right side exhausted: remaining left rows have no match
+        while !curs.0.finished {
+            self.indices.push(curs.0.cur_idx);
+            self.exists.push(false);
+            cur_forward!(curs.0);
+            if self.should_flush(curs) {
+                self.as_mut().flush(curs).await?;
+            }
+            curs.0
+                .set_min_reserved_idx(*self.indices.first().unwrap_or(&curs.0.cur_idx));
+        }
+        if !self.indices.is_empty() {
+            self.flush(curs).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs new file mode 100644 index 0000000..55967f4 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/full_join.rs
@@ -0,0 +1,248 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{cmp::Ordering, pin::Pin, sync::Arc};
+
+use arrow::array::{RecordBatch, RecordBatchOptions};
+use async_trait::async_trait;
+use datafusion::{common::Result, physical_plan::metrics::Time};
+use datafusion_ext_commons::suggested_output_batch_mem_size;
+use smallvec::{smallvec, SmallVec};
+
+use crate::{
+    common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
+    compare_cursor, cur_forward,
+    joins::{Idx, JoinParams, StreamCursors},
+    sort_merge_join_exec::Joiner,
+};
+
+pub struct FullJoiner<const L_OUTER: bool, const R_OUTER: bool> {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    lindices: Vec<Idx>, // left side of each pending output row; Idx::default() pads
+    rindices: Vec<Idx>, // right side, pushed in lockstep with lindices
+    send_output_time: Time,
+    output_rows: usize,
+}
+
+pub type InnerJoiner = FullJoiner<false, false>;
+pub type LeftOuterJoiner = FullJoiner<true, false>;
+pub type RightOuterJoiner = FullJoiner<false, true>;
+pub type FullOuterJoiner = FullJoiner<true, true>;
+
+impl<const L_OUTER: bool, const R_OUTER: bool> FullJoiner<L_OUTER, R_OUTER> {
+    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
+        Self {
+            join_params,
+            output_sender,
+            lindices: vec![],
+            rindices: vec![],
+            send_output_time: Time::new(),
+            output_rows: 0,
+        }
+    }
+
+    fn should_flush(&self, curs: &StreamCursors) -> bool { // batch full or memory pressure
+        if self.lindices.len() >= self.join_params.batch_size {
+            return true;
+        }
+
+        if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6
+            && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
+        { // ...and a pending row still references an earlier batch
+            if let Some(first_lidx) = self.lindices.first() {
+                if first_lidx.0 < curs.0.cur_idx.0 {
+                    return true;
+                }
+            }
+            if let Some(first_ridx) = self.rindices.first() {
+                if first_ridx.0 < curs.1.cur_idx.0 {
+                    return true;
+                }
+            }
+        }
+        false
+    }
+
+    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        let lindices = std::mem::take(&mut self.lindices);
+        let rindices = std::mem::take(&mut self.rindices);
+        let num_rows = lindices.len();
+        assert_eq!(lindices.len(), rindices.len()); // always pushed pairwise
+
+        let lcols = interleave_batches(
+            curs.0.projected_batch_schema.clone(),
+            &curs.0.projected_batches,
+            &lindices,
+        )?;
+        let rcols = interleave_batches(
+            curs.1.projected_batch_schema.clone(),
+            &curs.1.projected_batches,
+            &rindices,
+        )?;
+        let output_batch = RecordBatch::try_new_with_options(
+            self.join_params.projection.schema.clone(),
+            [lcols.columns(), rcols.columns()].concat(),
+            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+        )?;
+
+        if output_batch.num_rows() > 0 {
+            self.output_rows += output_batch.num_rows();
+
+            let timer = self.send_output_time.timer();
+            self.output_sender.send(Ok(output_batch), None).await;
+            drop(timer);
+        }
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<const L_OUTER: bool, const R_OUTER: bool> Joiner for FullJoiner<L_OUTER, R_OUTER> {
+    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        while !curs.0.finished && !curs.1.finished {
+            let mut lidx = curs.0.cur_idx;
+            let mut ridx = curs.1.cur_idx;
+            match compare_cursor!(curs) {
+                Ordering::Less => { // presumably left key < right key: lidx is unmatched
+                    if L_OUTER {
+                        self.lindices.push(lidx);
+                        self.rindices.push(Idx::default());
+                    }
+                    cur_forward!(curs.0);
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.0
+                        .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx));
+                }
+                Ordering::Greater => { // presumably right key < left key: ridx is unmatched
+                    if R_OUTER {
+                        self.lindices.push(Idx::default());
+                        self.rindices.push(ridx);
+                    }
+                    cur_forward!(curs.1);
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.1
+                        .set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx));
+                }
+                Ordering::Equal => { // emit the cross product of both equal-key runs
+                    cur_forward!(curs.0);
+                    cur_forward!(curs.1);
+                    self.lindices.push(lidx);
+                    self.rindices.push(ridx);
+
+                    let mut equal_lindices: SmallVec<[Idx; 16]> = smallvec![lidx];
+                    let mut equal_rindices: SmallVec<[Idx; 16]> = smallvec![ridx];
+                    let mut last_lidx = lidx;
+                    let mut last_ridx = ridx;
+                    lidx = curs.0.cur_idx;
+                    ridx = curs.1.cur_idx;
+                    let mut l_equal = !curs.0.finished && curs.0.key(lidx) == curs.0.key(last_lidx);
+                    let mut r_equal = !curs.1.finished && curs.1.key(ridx) == curs.1.key(last_ridx);
+
+                    while l_equal || r_equal { // extend whichever run still repeats the key
+                        if l_equal {
+                            for &ridx in &equal_rindices {
+                                self.lindices.push(lidx);
+                                self.rindices.push(ridx);
+                            }
+                            if r_equal {
+                                equal_lindices.push(lidx);
+                            }
+                            cur_forward!(curs.0);
+                            last_lidx = lidx;
+                            lidx = curs.0.cur_idx;
+                        } else { // left run ended: equal_rindices is no longer read
+                            curs.1
+                                .set_min_reserved_idx(*self.rindices.first().unwrap_or(&last_ridx));
+                        }
+
+                        if r_equal {
+                            for &lidx in &equal_lindices {
+                                self.lindices.push(lidx);
+                                self.rindices.push(ridx);
+                            }
+                            if l_equal {
+                                equal_rindices.push(ridx);
+                            }
+                            cur_forward!(curs.1);
+                            last_ridx = ridx;
+                            ridx = curs.1.cur_idx;
+                        } else { // right run ended: equal_lindices is no longer read
+                            curs.0
+                                .set_min_reserved_idx(*self.lindices.first().unwrap_or(&last_lidx));
+                        }
+
+                        if self.should_flush(curs) {
+                            self.as_mut().flush(curs).await?;
+                        }
+                        l_equal = l_equal
+                            && !curs.0.finished
+                            && curs.0.key(lidx) == curs.0.key(last_lidx);
+                        r_equal = r_equal
+                            && !curs.1.finished
+                            && curs.1.key(ridx) == curs.1.key(last_ridx);
+                    }
+
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.0
+                        .set_min_reserved_idx(*self.lindices.first().unwrap_or(&curs.0.cur_idx));
+                    curs.1
+                        .set_min_reserved_idx(*self.rindices.first().unwrap_or(&curs.1.cur_idx));
+                }
+            }
+        }
+
+        // at least one side is finished, consume the other side if it is an outer side
+        while L_OUTER && !curs.0.finished {
+            let lidx = curs.0.cur_idx;
+            self.lindices.push(lidx);
+            self.rindices.push(Idx::default());
+            cur_forward!(curs.0);
+            if self.should_flush(curs) {
+                self.as_mut().flush(curs).await?;
+            }
+            curs.0
+                .set_min_reserved_idx(*self.lindices.first().unwrap_or(&lidx));
+        }
+        while R_OUTER && !curs.1.finished {
+            let ridx = curs.1.cur_idx;
+            self.lindices.push(Idx::default());
+            self.rindices.push(ridx);
+            cur_forward!(curs.1);
+            if self.should_flush(curs) {
+                self.as_mut().flush(curs).await?;
+            }
+            curs.1
+                .set_min_reserved_idx(*self.rindices.first().unwrap_or(&ridx));
+        }
+        if !self.lindices.is_empty() {
+            self.flush(curs).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs new file mode 100644 index 0000000..8bcdadf --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/mod.rs
@@ -0,0 +1,17 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +pub mod existence_join; +pub mod full_join; +pub mod semi_join;
diff --git a/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs new file mode 100644 index 0000000..fd5f935 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/smj/semi_join.rs
@@ -0,0 +1,252 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::{cmp::Ordering, pin::Pin, sync::Arc};
+
+use arrow::array::{RecordBatch, RecordBatchOptions};
+use async_trait::async_trait;
+use datafusion::{common::Result, physical_plan::metrics::Time};
+use datafusion_ext_commons::suggested_output_batch_mem_size;
+
+use crate::{
+    common::{batch_selection::interleave_batches, output::WrappedRecordBatchSender},
+    compare_cursor, cur_forward,
+    joins::{
+        smj::semi_join::SemiJoinSide::{L, R},
+        Idx, JoinParams, StreamCursors,
+    },
+    sort_merge_join_exec::Joiner,
+};
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub enum SemiJoinSide { // which input's rows are emitted
+    L,
+    R,
+}
+
+#[derive(std::marker::ConstParamTy, Clone, Copy, PartialEq, Eq)]
+pub struct JoinerParams {
+    join_side: SemiJoinSide,
+    semi: bool, // true: emit matched rows (semi); false: emit unmatched rows (anti)
+}
+
+impl JoinerParams {
+    const fn new(join_side: SemiJoinSide, semi: bool) -> Self {
+        Self { join_side, semi }
+    }
+}
+pub struct SemiJoiner<const P: JoinerParams> {
+    join_params: JoinParams,
+    output_sender: Arc<WrappedRecordBatchSender>,
+    indices: Vec<Idx>, // pending rows of the emitted side, awaiting flush
+    send_output_time: Time,
+    output_rows: usize,
+}
+
+const LEFT_SEMI: JoinerParams = JoinerParams::new(L, true);
+const LEFT_ANTI: JoinerParams = JoinerParams::new(L, false);
+const RIGHT_SEMI: JoinerParams = JoinerParams::new(R, true);
+const RIGHT_ANTI: JoinerParams = JoinerParams::new(R, false);
+
+pub type LeftSemiJoiner = SemiJoiner<LEFT_SEMI>;
+pub type LeftAntiJoiner = SemiJoiner<LEFT_ANTI>;
+pub type RightSemiJoiner = SemiJoiner<RIGHT_SEMI>;
+pub type RightAntiJoiner = SemiJoiner<RIGHT_ANTI>;
+
+impl<const P: JoinerParams> SemiJoiner<P> {
+    pub fn new(join_params: JoinParams, output_sender: Arc<WrappedRecordBatchSender>) -> Self {
+        Self {
+            join_params,
+            output_sender,
+            indices: vec![],
+            send_output_time: Time::new(),
+            output_rows: 0,
+        }
+    }
+
+    fn should_flush(&self, curs: &StreamCursors) -> bool { // batch full or memory pressure
+        if self.indices.len() >= self.join_params.batch_size {
+            return true;
+        }
+
+        if curs.0.num_buffered_batches() + curs.1.num_buffered_batches() >= 6
+            && curs.0.mem_size() + curs.1.mem_size() > suggested_output_batch_mem_size()
+        {
+            if let Some(first_idx) = self.indices.first() {
+                let cur_idx = match P.join_side { // cursor of the emitted side
+                    L => curs.0.cur_idx,
+                    R => curs.1.cur_idx,
+                };
+                if first_idx.0 < cur_idx.0 {
+                    return true;
+                }
+            }
+        }
+        false
+    }
+
+    async fn flush(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        let indices = std::mem::take(&mut self.indices);
+        let num_rows = indices.len();
+
+        let cols = match P.join_side {
+            L => interleave_batches(
+                curs.0.projected_batch_schema.clone(),
+                &curs.0.projected_batches,
+                &indices,
+            )?,
+            R => interleave_batches(
+                curs.1.projected_batch_schema.clone(),
+                &curs.1.projected_batches,
+                &indices,
+            )?,
+        };
+        let output_batch = RecordBatch::try_new_with_options(
+            self.join_params.projection.schema.clone(),
+            cols.columns().to_vec(),
+            &RecordBatchOptions::new().with_row_count(Some(num_rows)),
+        )?;
+
+        if output_batch.num_rows() > 0 {
+            self.output_rows += output_batch.num_rows();
+
+            let timer = self.send_output_time.timer();
+            self.output_sender.send(Ok(output_batch), None).await;
+            drop(timer);
+        }
+        Ok(())
+    }
+}
+
+#[async_trait]
+impl<const P: JoinerParams> Joiner for SemiJoiner<P> {
+    async fn join(mut self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()> {
+        while !curs.0.finished && !curs.1.finished {
+            let mut lidx = curs.0.cur_idx;
+            let mut ridx = curs.1.cur_idx;
+
+            match compare_cursor!(curs) {
+                Ordering::Less => { // presumably left key < right key: lidx is unmatched
+                    if P.join_side == L && !P.semi {
+                        self.indices.push(lidx);
+                    }
+                    cur_forward!(curs.0);
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.0.set_min_reserved_idx(match P.join_side {
+                        L => *self.indices.first().unwrap_or(&lidx),
+                        R => lidx,
+                    });
+                }
+                Ordering::Greater => { // presumably right key < left key: ridx is unmatched
+                    if P.join_side == R && !P.semi {
+                        self.indices.push(ridx);
+                    }
+                    cur_forward!(curs.1);
+                    if self.should_flush(curs) {
+                        self.as_mut().flush(curs).await?;
+                    }
+                    curs.1.set_min_reserved_idx(match P.join_side {
+                        L => ridx,
+                        R => *self.indices.first().unwrap_or(&ridx),
+                    });
+                }
+                Ordering::Equal => { // matched key: consume the equal run on both sides
+                    // output/skip left equal rows
+                    loop {
+                        if P.join_side == L && P.semi {
+                            self.indices.push(lidx);
+                            if self.should_flush(curs) {
+                                self.as_mut().flush(curs).await?;
+                            }
+                        }
+                        cur_forward!(curs.0);
+                        curs.0.set_min_reserved_idx(match P.join_side {
+                            L => *self.indices.first().unwrap_or(&lidx),
+                            R => lidx,
+                        });
+
+                        if !curs.0.finished && curs.0.key(curs.0.cur_idx) == curs.0.key(lidx) {
+                            lidx = curs.0.cur_idx;
+                            continue;
+                        }
+                        break;
+                    }
+
+                    // output/skip right equal rows
+                    loop {
+                        if P.join_side == R && P.semi {
+                            self.indices.push(ridx);
+                            if self.should_flush(curs) {
+                                self.as_mut().flush(curs).await?;
+                            }
+                        }
+                        cur_forward!(curs.1);
+                        curs.1.set_min_reserved_idx(match P.join_side {
+                            L => ridx,
+                            R => *self.indices.first().unwrap_or(&ridx),
+                        });
+
+                        if !curs.1.finished && curs.1.key(curs.1.cur_idx) == curs.1.key(ridx) {
+                            ridx = curs.1.cur_idx;
+                            continue;
+                        }
+                        break;
+                    }
+                }
+            }
+        }
+
+        // at least one side is finished, consume the other side if it is an anti side
+        if !P.semi {
+            while P.join_side == L && !P.semi && !curs.0.finished {
+                let lidx = curs.0.cur_idx;
+                self.indices.push(lidx);
+                cur_forward!(curs.0);
+                if self.should_flush(curs) {
+                    self.as_mut().flush(curs).await?;
+                }
+                curs.0.set_min_reserved_idx(match P.join_side {
+                    L => *self.indices.first().unwrap_or(&lidx),
+                    R => lidx,
+                });
+            }
+            while P.join_side == R && !P.semi && !curs.1.finished {
+                let ridx = curs.1.cur_idx;
+                self.indices.push(ridx);
+                cur_forward!(curs.1);
+                if self.should_flush(curs) {
+                    self.as_mut().flush(curs).await?;
+                }
+                curs.1.set_min_reserved_idx(match P.join_side {
+                    L => ridx,
+                    R => *self.indices.first().unwrap_or(&ridx),
+                });
+            }
+        }
+        if !self.indices.is_empty() {
+            self.flush(curs).await?;
+        }
+        Ok(())
+    }
+
+    fn total_send_output_time(&self) -> usize {
+        self.send_output_time.value()
+    }
+
+    fn num_output_rows(&self) -> usize {
+        self.output_rows
+    }
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs new file mode 100644 index 0000000..c105bb8 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/stream_cursor.rs
@@ -0,0 +1,235 @@
+// Copyright 2022 The Blaze Authors
+//
+// Licensed under the Apache License, Version 2.0 (the "License");
+// you may not use this file except in compliance with the License.
+// You may obtain a copy of the License at
+//
+//     http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing, software
+// distributed under the License is distributed on an "AS IS" BASIS,
+// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+// See the License for the specific language governing permissions and
+// limitations under the License.
+
+use std::sync::Arc;
+
+use arrow::{
+    array::{RecordBatch, RecordBatchOptions},
+    buffer::NullBuffer,
+    datatypes::{Schema, SchemaRef},
+    row::{Row, RowConverter, Rows, SortField},
+};
+use datafusion::{
+    common::{JoinSide, Result},
+    execution::SendableRecordBatchStream,
+    physical_expr::PhysicalExprRef,
+    physical_plan::metrics::Time,
+};
+use datafusion_ext_commons::array_size::ArraySize;
+use futures::{Future, StreamExt};
+use parking_lot::Mutex;
+
+use crate::{
+    common::batch_selection::take_batch_opt,
+    joins::{Idx, JoinParams},
+};
+
+pub struct StreamCursor {
+    stream: SendableRecordBatchStream,
+    key_converter: Arc<Mutex<RowConverter>>,
+    key_exprs: Vec<PhysicalExprRef>,
+    poll_time: Time,
+
+    // IMPORTANT:
+    // batches/rows/null_buffers always contains a `null batch` in the front
+    projection: Vec<usize>,
+    pub projected_batch_schema: SchemaRef,
+    pub projected_batches: Vec<RecordBatch>,
+    pub cur_idx: Idx,
+    min_reserved_idx: Idx, // batches before `.0` are swapped for the null batch on next load
+    keys: Vec<Arc<Rows>>,
+    key_has_nulls: Vec<Option<NullBuffer>>,
+    num_null_batches: usize, // leading slots currently holding the null batch
+    mem_size: usize, // approx. bytes held by live batches, their keys and null masks
+    pub finished: bool,
+}
+
+impl StreamCursor {
+    pub fn try_new(
+        stream: SendableRecordBatchStream,
+        join_params: &JoinParams,
+        join_side: JoinSide,
+        projection: &[usize],
+    ) -> Result<Self> {
+        let key_converter = Arc::new(Mutex::new(RowConverter::new(
+            join_params
+                .key_data_types
+                .iter()
+                .cloned()
+                .zip(&join_params.sort_options)
+                .map(|(dt, options)| SortField::new_with_options(dt, *options))
+                .collect(),
+        )?));
+        let key_exprs = match join_side {
+            JoinSide::Left => join_params.left_keys.clone(),
+            JoinSide::Right => join_params.right_keys.clone(),
+        };
+
+        let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new(
+            stream
+                .schema()
+                .fields()
+                .iter()
+                .map(|f| f.as_ref().clone().with_nullable(true))
+                .collect::<Vec<_>>(),
+        )));
+        let empty_keys = Arc::new(
+            key_converter.lock().convert_columns(
+                &key_exprs
+                    .iter()
+                    .map(|key| Ok(key.evaluate(&empty_batch)?.into_array(0)?))
+                    .collect::<Result<Vec<_>>>()?,
+            )?,
+        );
+        let null_batch = take_batch_opt(empty_batch, [Option::<usize>::None])?;
+        let projected_null_batch = null_batch.project(projection)?;
+        let null_nb = NullBuffer::new_null(1);
+
+        Ok(Self {
+            stream,
+            key_exprs,
+            key_converter,
+            poll_time: Time::new(),
+            projection: projection.to_vec(),
+            projected_batch_schema: projected_null_batch.schema(),
+            projected_batches: vec![projected_null_batch],
+            cur_idx: (0, 0),
+            min_reserved_idx: (0, 0),
+            keys: vec![empty_keys],
+            key_has_nulls: vec![Some(null_nb)],
+            num_null_batches: 1,
+            mem_size: 0,
+            finished: false,
+        })
+    }
+
+    pub fn next(&mut self) -> Option<impl Future<Output = Result<()>> + '_> {
+        self.cur_idx.1 += 1;
+        if self.cur_idx.1 >= self.projected_batches[self.cur_idx.0].num_rows() {
+            self.cur_idx.0 += 1;
+            self.cur_idx.1 = 0;
+        }
+
+        let should_load_next_batch = self.cur_idx.0 >= self.projected_batches.len();
+        if should_load_next_batch {
+            Some(async move {
+                while let Some(batch) = {
+                    let timer = self.poll_time.timer();
+                    let batch = self.stream.next().await.transpose()?;
+                    drop(timer);
+                    batch
+                } {
+                    if batch.num_rows() == 0 {
+                        continue;
+                    }
+                    let key_columns = self
+                        .key_exprs
+                        .iter()
+                        .map(|key| Ok(key.evaluate(&batch)?.into_array(batch.num_rows())?))
+                        .collect::<Result<Vec<_>>>()?;
+                    let key_has_nulls = key_columns
+                        .iter()
+                        .map(|c| c.nulls().cloned())
+                        .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref()))
+                        .unwrap_or(None);
+                    let keys = Arc::new(self.key_converter.lock().convert_columns(&key_columns)?);
+                    // count only the projected columns: that is what is retained (and released below)
+                    self.mem_size += batch.project(&self.projection)?.get_array_mem_size();
+                    self.mem_size += key_has_nulls
+                        .as_ref()
+                        .map(|nb| nb.buffer().len())
+                        .unwrap_or_default();
+                    self.mem_size += keys.size();
+
+                    self.projected_batches
+                        .push(RecordBatch::try_new_with_options(
+                            self.projected_batches[0].schema(),
+                            self.projection
+                                .iter()
+                                .map(|&i| batch.column(i).clone())
+                                .collect(),
+                            &RecordBatchOptions::new().with_row_count(Some(batch.num_rows())),
+                        )?);
+                    self.key_has_nulls.push(key_has_nulls);
+                    self.keys.push(keys);
+
+                    // fill out-dated batches with null batches
+                    if self.num_null_batches < self.min_reserved_idx.0 {
+                        for i in self.num_null_batches..self.min_reserved_idx.0 {
+                            self.mem_size -= self.projected_batches[i].get_array_mem_size();
+                            self.mem_size -= self.key_has_nulls[i]
+                                .as_ref()
+                                .map(|nb| nb.buffer().len())
+                                .unwrap_or_default();
+                            self.mem_size -= self.keys[i].size();
+
+                            self.projected_batches[i] = self.projected_batches[0].clone();
+                            self.keys[i] = self.keys[0].clone();
+                            self.key_has_nulls[i] = self.key_has_nulls[0].clone();
+                            self.num_null_batches += 1;
+                        }
+                    }
+                    return Ok(());
+                }
+                self.finished = true;
+                return Ok(());
+            })
+        } else {
+            None
+        }
+    }
+
+    #[inline]
+    pub fn is_null_key(&self, idx: Idx) -> bool {
+        self.key_has_nulls[idx.0]
+            .as_ref()
+            .map(|nb| nb.is_null(idx.1))
+            .unwrap_or(false)
+    }
+
+    #[inline]
+    pub fn key<'a>(&'a self, idx: Idx) -> Row<'a> {
+        let keys = &self.keys[idx.0];
+        keys.row(idx.1)
+    }
+
+    #[inline]
+    pub fn num_buffered_batches(&self) -> usize {
+        self.projected_batches.len() - self.num_null_batches
+    }
+
+    #[inline]
+    pub fn mem_size(&self) -> usize {
+        self.mem_size
+    }
+
+    #[inline]
+    pub fn set_min_reserved_idx(&mut self, idx: Idx) {
+        self.min_reserved_idx = idx;
+    }
+
+    #[inline]
+    pub fn total_poll_time(&self) -> usize {
+        self.poll_time.value()
+    }
+}
+
+#[macro_export]
+macro_rules! cur_forward {
+    ($cur:expr) => {{
+        if let Some(fut) = $cur.next() {
+            fut.await?;
+        }
+    }};
+}
diff --git a/native-engine/datafusion-ext-plans/src/joins/test.rs b/native-engine/datafusion-ext-plans/src/joins/test.rs new file mode 100644 index 0000000..e0826e7 --- /dev/null +++ b/native-engine/datafusion-ext-plans/src/joins/test.rs
@@ -0,0 +1,947 @@ +// Copyright 2022 The Blaze Authors +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +#[cfg(test)] +mod tests { + use std::sync::Arc; + + use arrow::{ + self, + array::*, + compute::SortOptions, + datatypes::{DataType, Field, Schema, SchemaRef}, + record_batch::RecordBatch, + }; + use datafusion::{ + assert_batches_sorted_eq, + common::JoinSide, + error::Result, + physical_expr::expressions::Column, + physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, + prelude::SessionContext, + }; + use TestType::*; + + use crate::{ + broadcast_join_build_hash_map_exec::BroadcastJoinBuildHashMapExec, + broadcast_join_exec::BroadcastJoinExec, + joins::join_utils::{JoinType, JoinType::*}, + sort_merge_join_exec::SortMergeJoinExec, + }; + + #[derive(Clone, Copy)] + enum TestType { + SMJ, + BHJLeftProbed, + BHJRightProbed, + } + + fn columns(schema: &Schema) -> Vec<String> { + schema.fields().iter().map(|f| f.name().clone()).collect() + } + + fn build_table_i32( + a: (&str, &Vec<i32>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> RecordBatch { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Int32, false), + Field::new(b.0, DataType::Int32, false), + Field::new(c.0, DataType::Int32, false), + ]); + + RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap() + } + + fn 
build_table( + a: (&str, &Vec<i32>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let batch = build_table_i32(a, b, c); + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> { + let schema = batches.first().unwrap().schema(); + Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) + } + + fn build_date_table( + a: (&str, &Vec<i32>), + b: (&str, &Vec<i32>), + c: (&str, &Vec<i32>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date32, false), + Field::new(b.0, DataType::Date32, false), + Field::new(c.0, DataType::Date32, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date32Array::from(a.1.clone())), + Arc::new(Date32Array::from(b.1.clone())), + Arc::new(Date32Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn build_date64_table( + a: (&str, &Vec<i64>), + b: (&str, &Vec<i64>), + c: (&str, &Vec<i64>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Schema::new(vec![ + Field::new(a.0, DataType::Date64, false), + Field::new(b.0, DataType::Date64, false), + Field::new(c.0, DataType::Date64, false), + ]); + + let batch = RecordBatch::try_new( + Arc::new(schema), + vec![ + Arc::new(Date64Array::from(a.1.clone())), + Arc::new(Date64Array::from(b.1.clone())), + Arc::new(Date64Array::from(c.1.clone())), + ], + ) + .unwrap(); + + let schema = batch.schema(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + /// returns a table with 3 columns of i32 in memory + pub fn build_table_i32_nullable( + a: (&str, &Vec<Option<i32>>), + b: (&str, &Vec<Option<i32>>), + c: (&str, &Vec<Option<i32>>), + ) -> Arc<dyn ExecutionPlan> { + let schema = Arc::new(Schema::new(vec![ + 
Field::new(a.0, DataType::Int32, true), + Field::new(b.0, DataType::Int32, true), + Field::new(c.0, DataType::Int32, true), + ])); + let batch = RecordBatch::try_new( + schema.clone(), + vec![ + Arc::new(Int32Array::from(a.1.clone())), + Arc::new(Int32Array::from(b.1.clone())), + Arc::new(Int32Array::from(c.1.clone())), + ], + ) + .unwrap(); + Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) + } + + fn build_join_schema_for_test( + left: &Schema, + right: &Schema, + join_type: JoinType, + ) -> Result<SchemaRef> { + if join_type == Existence { + let exists_field = Arc::new(Field::new("exists#0", DataType::Boolean, false)); + return Ok(Arc::new(Schema::new( + [left.fields().to_vec(), vec![exists_field]].concat(), + ))); + } + Ok(Arc::new( + build_join_schema(left, right, &join_type.try_into()?).0, + )) + } + + async fn join_collect( + test_type: TestType, + left: Arc<dyn ExecutionPlan>, + right: Arc<dyn ExecutionPlan>, + on: JoinOn, + join_type: JoinType, + ) -> Result<(Vec<String>, Vec<RecordBatch>)> { + let session_ctx = SessionContext::new(); + let task_ctx = session_ctx.task_ctx(); + let schema = build_join_schema_for_test(&left.schema(), &right.schema(), join_type)?; + + let join: Arc<dyn ExecutionPlan> = match test_type { + SMJ => { + let sort_options = vec![SortOptions::default(); on.len()]; + Arc::new(SortMergeJoinExec::try_new( + schema, + left, + right, + on, + join_type, + sort_options, + )?) + } + BHJLeftProbed => { + let right = Arc::new(BroadcastJoinBuildHashMapExec::new( + right, + on.iter().map(|(_, right_key)| right_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Right, + None, + )?) 
+ } + BHJRightProbed => { + let left = Arc::new(BroadcastJoinBuildHashMapExec::new( + left, + on.iter().map(|(left_key, _)| left_key.clone()).collect(), + )); + Arc::new(BroadcastJoinExec::try_new( + schema, + left, + right, + on, + join_type, + JoinSide::Left, + None, + )?) + } + }; + let columns = columns(&join.schema()); + let stream = join.execute(0, task_ctx)?; + let batches = common::collect(stream).await?; + Ok((columns, batches)) + } + + #[tokio::test] + async fn join_inner_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 5]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 5 | 9 | 20 | 5 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_two() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2]), + ("b2", &vec![1, 2, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 2, 3]), + ("b2", &vec![1, 2, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", 
&left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_two_two() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 1, 2]), + ("b2", &vec![1, 1, 2]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a1", &vec![1, 1, 3]), + ("b2", &vec![1, 1, 2]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_columns, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | 7 | 1 | 1 | 70 |", + "| 1 | 1 | 7 | 1 | 1 | 80 |", + "| 1 | 1 | 8 | 1 | 1 | 70 |", + "| 1 | 1 | 8 | 1 | 1 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_inner_with_nulls() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), + ("b2", &vec![None, Some(1), Some(2), 
Some(2)]), // null in key field + ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field + ); + let right = build_table_i32_nullable( + ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), + ("b2", &vec![None, Some(1), Some(2), Some(2)]), + ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), + ); + let on: JoinOn = vec![ + ( + Arc::new(Column::new_with_schema("a1", &left.schema())?), + Arc::new(Column::new_with_schema("a1", &right.schema())?), + ), + ( + Arc::new(Column::new_with_schema("b2", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + ), + ]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b2 | c1 | a1 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 1 | | 1 | 1 | 70 |", + "| 2 | 2 | 8 | 2 | 2 | 80 |", + "| 2 | 2 | 9 | 2 | 2 | 80 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as 
SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 6 does not exist on the left + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+----+----+----+----+----+----+", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| | | | 30 | 6 | 90 |", + "+----+----+----+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_full_one() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![4, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b2", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Full).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 30 | 6 | 90 |", + "| 1 | 4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| 3 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + 
]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_anti() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2, 3, 5]), + ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9, 11]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, LeftAnti).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 3 | 7 | 9 |", + "| 5 | 7 | 11 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_semi() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![1, 2, 2, 3]), + ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right + ("c1", &vec![7, 8, 8, 9]), + ); + let right = build_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![4, 5, 6]), // 5 is double on the right + ("c2", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, LeftSemi).await?; + let expected = vec![ + "+----+----+----+", + "| a1 | b1 | c1 |", + "+----+----+----+", + "| 1 | 4 | 7 |", + "| 2 | 5 | 8 |", + "| 2 | 5 | 8 |", + "+----+----+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_with_duplicated_column_names() -> 
Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a", &vec![1, 2, 3]), + ("b", &vec![4, 5, 7]), + ("c", &vec![7, 8, 9]), + ); + let right = build_table( + ("a", &vec![10, 20, 30]), + ("b", &vec![1, 2, 7]), + ("c", &vec![70, 80, 90]), + ); + let on: JoinOn = vec![( + // join on a=b so there are duplicate column names on unjoined columns + Arc::new(Column::new_with_schema("a", &left.schema())?), + Arc::new(Column::new_with_schema("b", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+---+---+---+----+---+----+", + "| a | b | c | a | b | c |", + "+---+---+---+----+---+----+", + "| 1 | 4 | 7 | 10 | 1 | 70 |", + "| 2 | 5 | 8 | 20 | 2 | 80 |", + "+---+---+---+----+---+----+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_date32() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_date_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![19107, 19108, 19108]), // this has a repetition + ("c1", &vec![7, 8, 9]), + ); + let right = build_date_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![19107, 19108, 19109]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + + let expected = vec![ + "+------------+------------+------------+------------+------------+------------+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+------------+------------+------------+------------+------------+------------+", + "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", + "| 1970-01-03 | 2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", + "| 1970-01-04 
| 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", + "+------------+------------+------------+------------+------------+------------+", + ]; + // The output order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_date64() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_date64_table( + ("a1", &vec![1, 2, 3]), + ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), /* this has a + * repetition */ + ("c1", &vec![7, 8, 9]), + ); + let right = build_date64_table( + ("a2", &vec![10, 20, 30]), + ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), + ("c2", &vec![70, 80, 90]), + ); + + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b1", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Inner).await?; + let expected = vec![ + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + "| a1 | b1 | c1 | a2 | b1 | c2 |", + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", + "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", + "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", + "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", + ]; + + // The output 
order is important as SMJ preserves sortedness + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_sort_order() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![0, 1, 2, 3, 4, 5]), + ("b1", &vec![3, 4, 5, 6, 6, 7]), + ("c1", &vec![4, 5, 6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30, 40]), + ("b2", &vec![2, 4, 6, 6, 8]), + ("c2", &vec![50, 60, 70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_sort_order() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left = build_table( + ("a1", &vec![0, 1, 2, 3]), + ("b1", &vec![3, 4, 5, 7]), + ("c1", &vec![6, 7, 8, 9]), + ); + let right = build_table( + ("a2", &vec![0, 10, 20, 30]), + ("b2", &vec![2, 4, 5, 6]), + ("c2", &vec![60, 70, 80, 90]), + ); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 60 |", + "| 1 | 
4 | 7 | 10 | 4 | 70 |", + "| 2 | 5 | 8 | 20 | 5 | 80 |", + "| | | | 30 | 6 | 90 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_left_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Left).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "| 6 | 9 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_right_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 1, 2]), + ("b2", &vec![3, 4, 5]), + ("c2", &vec![4, 5, 6]), + ); + let right_batch_2 = 
build_table_i32( + ("a2", &vec![3, 4, 5, 6]), + ("b2", &vec![6, 6, 7, 9]), + ("c2", &vec![7, 8, 9, 9]), + ); + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 10, 20]), + ("b1", &vec![2, 4, 6]), + ("c1", &vec![50, 60, 70]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![30, 40]), + ("b1", &vec![6, 8]), + ("c1", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Right).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 3 | 4 |", + "| 10 | 4 | 60 | 1 | 4 | 5 |", + "| | | | 2 | 5 | 6 |", + "| 20 | 6 | 70 | 3 | 6 | 7 |", + "| 30 | 6 | 80 | 3 | 6 | 7 |", + "| 20 | 6 | 70 | 4 | 6 | 8 |", + "| 30 | 6 | 80 | 4 | 6 | 8 |", + "| | | | 5 | 7 | 9 |", + "| | | | 6 | 9 | 9 |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_full_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, 
right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Full).await?; + let expected = vec![ + "+----+----+----+----+----+----+", + "| a1 | b1 | c1 | a2 | b2 | c2 |", + "+----+----+----+----+----+----+", + "| | | | 0 | 2 | 50 |", + "| | | | 40 | 8 | 90 |", + "| 0 | 3 | 4 | | | |", + "| 1 | 4 | 5 | 10 | 4 | 60 |", + "| 2 | 5 | 6 | | | |", + "| 3 | 6 | 7 | 20 | 6 | 70 |", + "| 3 | 6 | 7 | 30 | 6 | 80 |", + "| 4 | 6 | 8 | 20 | 6 | 70 |", + "| 4 | 6 | 8 | 30 | 6 | 80 |", + "| 5 | 7 | 9 | | | |", + "| 6 | 9 | 9 | | | |", + "+----+----+----+----+----+----+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } + + #[tokio::test] + async fn join_existence_multiple_batches() -> Result<()> { + for test_type in [SMJ, BHJLeftProbed, BHJRightProbed] { + let left_batch_1 = build_table_i32( + ("a1", &vec![0, 1, 2]), + ("b1", &vec![3, 4, 5]), + ("c1", &vec![4, 5, 6]), + ); + let left_batch_2 = build_table_i32( + ("a1", &vec![3, 4, 5, 6]), + ("b1", &vec![6, 6, 7, 9]), + ("c1", &vec![7, 8, 9, 9]), + ); + let right_batch_1 = build_table_i32( + ("a2", &vec![0, 10, 20]), + ("b2", &vec![2, 4, 6]), + ("c2", &vec![50, 60, 70]), + ); + let right_batch_2 = build_table_i32( + ("a2", &vec![30, 40]), + ("b2", &vec![6, 8]), + ("c2", &vec![80, 90]), + ); + let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); + let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); + let on: JoinOn = vec![( + Arc::new(Column::new_with_schema("b1", &left.schema())?), + Arc::new(Column::new_with_schema("b2", &right.schema())?), + )]; + + let (_, batches) = join_collect(test_type, left, right, on, Existence).await?; + let expected = vec![ + "+----+----+----+----------+", + "| a1 | b1 | c1 | exists#0 |", + "+----+----+----+----------+", + "| 0 | 3 | 4 | false |", + "| 1 | 4 | 5 | true 
|", + "| 2 | 5 | 6 | false |", + "| 3 | 6 | 7 | true |", + "| 4 | 6 | 8 | true |", + "| 5 | 7 | 9 | false |", + "| 6 | 9 | 9 | false |", + "+----+----+----+----------+", + ]; + assert_batches_sorted_eq!(expected, &batches); + } + Ok(()) + } +}
diff --git a/native-engine/datafusion-ext-plans/src/lib.rs b/native-engine/datafusion-ext-plans/src/lib.rs index a0797fb..a48fb56 100644 --- a/native-engine/datafusion-ext-plans/src/lib.rs +++ b/native-engine/datafusion-ext-plans/src/lib.rs
@@ -13,32 +13,38 @@ // limitations under the License. #![feature(get_mut_unchecked)] -#![feature(io_error_other)] +#![feature(adt_const_params)] -pub mod agg; +// execution plan implementations pub mod agg_exec; +pub mod broadcast_join_build_hash_map_exec; pub mod broadcast_join_exec; -pub mod broadcast_nested_loop_join_exec; -pub mod common; pub mod debug_exec; pub mod empty_partitions_exec; pub mod expand_exec; pub mod ffi_reader_exec; pub mod filter_exec; -pub mod generate; pub mod generate_exec; pub mod ipc_reader_exec; pub mod ipc_writer_exec; pub mod limit_exec; -pub mod memmgr; pub mod parquet_exec; pub mod parquet_sink_exec; pub mod project_exec; pub mod rename_columns_exec; pub mod rss_shuffle_writer_exec; -mod shuffle; pub mod shuffle_writer_exec; pub mod sort_exec; pub mod sort_merge_join_exec; -pub mod window; pub mod window_exec; + +// memory management +pub mod memmgr; + +// helper modules +pub mod agg; +pub mod common; +pub mod generate; +pub mod joins; +mod shuffle; +pub mod window;
diff --git a/native-engine/datafusion-ext-plans/src/parquet_exec.rs b/native-engine/datafusion-ext-plans/src/parquet_exec.rs index 8fd5f57..f7c2062 100644 --- a/native-engine/datafusion-ext-plans/src/parquet_exec.rs +++ b/native-engine/datafusion-ext-plans/src/parquet_exec.rs
@@ -20,7 +20,7 @@ use std::{any::Any, fmt, fmt::Formatter, ops::Range, sync::Arc}; use arrow::{ - array::ArrayRef, + array::{Array, ArrayRef, AsArray, ListArray}, datatypes::{DataType, SchemaRef}, }; use base64::{prelude::BASE64_URL_SAFE_NO_PAD, Engine}; @@ -56,7 +56,6 @@ use datafusion_ext_commons::{ batch_size, df_execution_err, hadoop_fs::{FsDataInputStream, FsProvider}, - streams::coalesce_stream::CoalesceInput, }; use fmt::Debug; use futures::{future::BoxFuture, stream::once, FutureExt, StreamExt, TryStreamExt}; @@ -71,7 +70,61 @@ col: &ArrayRef, data_type: &DataType, ) -> Result<ArrayRef, DataFusionError> { - datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type) + macro_rules! handle_decimal { + ($s:ident, $t:ident, $tnative:ty, $prec:expr, $scale:expr) => {{ + use arrow::{array::*, datatypes::*}; + type DecimalBuilder = paste::paste! {[<$t Builder>]}; + type IntType = paste::paste! {[<$s Type>]}; + + let col = col.as_primitive::<IntType>(); + let mut decimal_builder = DecimalBuilder::new(); + for i in 0..col.len() { + if col.is_valid(i) { + decimal_builder.append_value(col.value(i) as $tnative); + } else { + decimal_builder.append_null(); + } + } + Ok(Arc::new( + decimal_builder + .finish() + .with_precision_and_scale($prec, $scale)?, + )) + }}; + } + match data_type { + DataType::Decimal128(prec, scale) => match col.data_type() { + DataType::Int8 => handle_decimal!(Int8, Decimal128, i128, *prec, *scale), + DataType::Int16 => handle_decimal!(Int16, Decimal128, i128, *prec, *scale), + DataType::Int32 => handle_decimal!(Int32, Decimal128, i128, *prec, *scale), + DataType::Int64 => handle_decimal!(Int64, Decimal128, i128, *prec, *scale), + DataType::Decimal128(p, s) if p == prec && s == scale => Ok(col.clone()), + _ => df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + DataType::List(to_field) => match col.data_type() { + DataType::List(_from_field) => { + let col 
= col.as_list::<i32>(); + let from_inner = col.values(); + let to_inner = schema_adapter_cast_column(from_inner, to_field.data_type())?; + Ok(Arc::new(ListArray::try_new( + to_field.clone(), + col.offsets().clone(), + to_inner, + col.nulls().cloned(), + )?)) + } + _ => df_execution_err!( + "schema_adapter_cast_column unsupported type: {:?} => {:?}", + col.data_type(), + data_type, + ), + }, + _ => datafusion_ext_commons::cast::cast_scan_input_array(col.as_ref(), data_type), + } } /// Execution plan for scanning one or more Parquet partitions @@ -231,6 +284,9 @@ None => (0..self.base_config.file_schema.fields().len()).collect(), }; + let page_filtering_enabled = conf::PARQUET_ENABLE_PAGE_FILTERING.value()?; + let bloom_filter_enabled = conf::PARQUET_ENABLE_BLOOM_FILTER.value()?; + let opener = ParquetOpener { partition_index, projection: Arc::from(projection), @@ -243,10 +299,10 @@ metadata_size_hint: None, metrics: self.metrics.clone(), parquet_file_reader_factory: Arc::new(FsReaderFactory::new(fs_provider)), - pushdown_filters: false, - reorder_filters: false, - enable_page_index: false, - enable_bloom_filter: false, + pushdown_filters: page_filtering_enabled, + reorder_filters: page_filtering_enabled, + enable_page_index: page_filtering_enabled, + enable_bloom_filter: bloom_filter_enabled, }; let baseline_metrics_cloned = baseline_metrics.clone(); @@ -274,7 +330,7 @@ }) .try_flatten(), )); - context.coalesce_with_default_batch_size(timed_stream, &baseline_metrics) + Ok(timed_stream) } fn metrics(&self) -> Option<MetricsSet> {
diff --git a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs index f2dff1d..69b46cf 100644 --- a/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs +++ b/native-engine/datafusion-ext-plans/src/rename_columns_exec.rs
@@ -35,7 +35,6 @@ SendableRecordBatchStream, Statistics, }, }; -use datafusion_ext_commons::df_execution_err; use futures::{Stream, StreamExt}; use crate::agg::AGG_BUF_COLUMN_NAME; @@ -56,7 +55,12 @@ let input_schema = input.schema(); let mut new_names = vec![]; - for (i, field) in input_schema.fields().iter().enumerate() { + for (i, field) in input_schema + .fields() + .iter() + .take(renamed_column_names.len()) + .enumerate() + { if field.name() != AGG_BUF_COLUMN_NAME { new_names.push(renamed_column_names[i].clone()); } else { @@ -64,11 +68,9 @@ break; } } - if new_names.len() != input_schema.fields().len() { - df_execution_err!( - "renamed_column_names length not matched with input schema, \ - renames: {renamed_column_names:?}, input schema: {input_schema}", - )?; + + while new_names.len() < input_schema.fields().len() { + new_names.push(input_schema.field(new_names.len()).name().clone()); } let renamed_column_names = new_names; let renamed_schema = Arc::new(Schema::new(
diff --git a/native-engine/datafusion-ext-plans/src/sort_exec.rs b/native-engine/datafusion-ext-plans/src/sort_exec.rs index 33a2818..5479940 100644 --- a/native-engine/datafusion-ext-plans/src/sort_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_exec.rs
@@ -49,7 +49,6 @@ downcast_any, ds::loser_tree::{ComparableForLoserTree, LoserTree}, io::{read_len, read_one_batch, write_len, write_one_batch}, - staging_mem_size_for_partial_sort, streams::coalesce_stream::CoalesceInput, }; use futures::{lock::Mutex, stream::once, StreamExt, TryStreamExt}; @@ -70,6 +69,7 @@ MemConsumer, MemConsumerInfo, MemManager, }, }; +use crate::common::batch_selection::take_batch; // reserve memory for each spill // estimated size: bufread=64KB + lz4dec.src=64KB + lz4dec.dest=64KB @@ -242,11 +242,9 @@ #[derive(Default)] struct BufferedData { - staging_batches: Vec<RecordBatch>, sorted_key_stores: Vec<Box<[u8]>>, sorted_key_stores_mem_used: usize, sorted_batches: Vec<RecordBatch>, - staging_mem_used: usize, sorted_batches_mem_used: usize, num_rows: usize, } @@ -271,34 +269,15 @@ } fn mem_used(&self) -> usize { - self.staging_mem_used + self.sorted_batches_mem_used + self.sorted_key_stores_mem_used + self.sorted_batches_mem_used + self.sorted_key_stores_mem_used } fn add_batch(&mut self, batch: RecordBatch, sorter: &ExternalSorter) -> Result<()> { self.num_rows += batch.num_rows(); - self.staging_mem_used += batch.get_array_mem_size(); - self.staging_batches.push(batch); - if self.staging_mem_used >= staging_mem_size_for_partial_sort() { - self.flush_staging_batches(sorter)?; - } - Ok(()) - } - - fn flush_staging_batches(&mut self, sorter: &ExternalSorter) -> Result<()> { - let staging_batches = std::mem::take(&mut self.staging_batches); - self.staging_mem_used = 0; - - let schema = sorter.prune_sort_keys_from_batch.pruned_schema.clone(); - let (key_rows, batches): (Vec<Rows>, Vec<RecordBatch>) = staging_batches - .into_iter() - .map(|batch| sorter.prune_sort_keys_from_batch.prune(batch)) - .collect::<Result<Vec<_>>>()? 
- .into_iter() - .unzip(); + let (key_rows, batch) = sorter.prune_sort_keys_from_batch.prune(batch)?; // sort the batch and append to sorter - let mut sorted_key_store = - Vec::with_capacity(key_rows.iter().map(|rows| rows.size()).sum::<usize>()); + let mut sorted_key_store = Vec::with_capacity(key_rows.size()); let mut key_writer = SortedKeysWriter::default(); let mut num_rows = 0; let sorted_batch; @@ -307,32 +286,28 @@ let cur_sorted_indices = key_rows .iter() .enumerate() - .flat_map(|(batch_idx, rows)| { - rows.iter() - .map(|key| unsafe { - // safety: keys have the same lifetime with key_rows - std::mem::transmute::<_, &'static [u8]>(key.as_ref()) - }) - .enumerate() - .map(move |(row_idx, key)| (key, batch_idx as u32, row_idx as u32)) + .map(|(row_idx, key)| { + let key = unsafe { + // safety: keys have the same lifetime with key_rows + std::mem::transmute::<_, &'static [u8]>(key.as_ref()) + }; + (key, row_idx as u32) }) .sorted_unstable_by_key(|&(key, ..)| key) .take(sorter.limit) - .map(|(key, batch_idx, row_idx)| { + .map(|(key, row_idx)| { num_rows += 1; key_writer.write_key(key, &mut sorted_key_store).unwrap(); - (batch_idx as usize, row_idx as usize) + row_idx as usize }) .collect::<Vec<_>>(); - sorted_batch = interleave_batches(schema, &batches, &cur_sorted_indices)?; + sorted_batch = take_batch(batch, cur_sorted_indices)?; } else { key_rows .iter() - .flat_map(|rows| { - rows.iter().map(|key| unsafe { - // safety: keys have the same lifetime with key_rows - std::mem::transmute::<_, &'static [u8]>(key.as_ref()) - }) + .map(|key| unsafe { + // safety: keys have the same lifetime with key_rows + std::mem::transmute::<_, &'static [u8]>(key.as_ref()) }) .sorted_unstable() .take(sorter.limit) @@ -351,13 +326,10 @@ } fn into_sorted_batches<'a, KC: KeyCollector>( - mut self, + self, batch_size: usize, sorter: &ExternalSorter, ) -> Result<impl Iterator<Item = (KC, RecordBatch)>> { - if !self.staging_batches.is_empty() { - self.flush_staging_batches(sorter)?; 
- } struct Cursor { idx: usize, row_idx: usize,
diff --git a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs index 8459d47..d4e5cda 100644 --- a/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs +++ b/native-engine/datafusion-ext-plans/src/sort_merge_join_exec.rs
@@ -12,135 +12,151 @@ // See the License for the specific language governing permissions and // limitations under the License. -use std::{any::Any, cmp::Ordering, fmt::Formatter, sync::Arc}; - -use arrow::{ - array::*, - buffer::NullBuffer, - compute::{prep_null_mask_filter, SortOptions}, - datatypes::{DataType, Schema, SchemaRef}, - record_batch::{RecordBatch, RecordBatchOptions}, - row::{Row, RowConverter, Rows, SortField}, +use std::{ + any::Any, + fmt::Formatter, + pin::Pin, + sync::Arc, + time::{Duration, Instant}, }; + +use arrow::{compute::SortOptions, datatypes::SchemaRef}; +use async_trait::async_trait; use datafusion::{ - common::JoinSide, + common::{DataFusionError, JoinSide}, error::Result, execution::context::TaskContext, - logical_expr::{JoinType, JoinType::*}, - physical_expr::{expressions::Column, PhysicalSortExpr}, + physical_expr::{PhysicalExprRef, PhysicalSortExpr}, physical_plan::{ - joins::utils::{build_join_schema, check_join_is_valid, ColumnIndex, JoinFilter, JoinOn}, - metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet, ScopedTimerGuard}, + joins::utils::JoinOn, + metrics::{BaselineMetrics, ExecutionPlanMetricsSet, MetricsSet}, stream::RecordBatchStreamAdapter, DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream, Statistics, }, }; use datafusion_ext_commons::{ - array_size::ArraySize, batch_size, df_execution_err, downcast_any, - streams::coalesce_stream::CoalesceInput, suggested_output_batch_mem_size, + batch_size, df_execution_err, streams::coalesce_stream::CoalesceInput, }; -use futures::{StreamExt, TryStreamExt}; -use parking_lot::Mutex as SyncMutex; +use futures::TryStreamExt; -use crate::common::{ - batch_selection::{interleave_batches, take_batch_opt}, - column_pruning::ExecuteWithColumnPruning, - output::{TaskOutputter, WrappedRecordBatchSender}, +use crate::{ + common::{ + column_pruning::ExecuteWithColumnPruning, + output::{TaskOutputter, WrappedRecordBatchSender}, + }, + cur_forward, + 
joins::{ + join_utils::{JoinType, JoinType::*}, + smj::{ + existence_join::ExistenceJoiner, + full_join::{FullOuterJoiner, InnerJoiner, LeftOuterJoiner, RightOuterJoiner}, + semi_join::{LeftAntiJoiner, LeftSemiJoiner, RightAntiJoiner, RightSemiJoiner}, + }, + stream_cursor::StreamCursor, + JoinParams, JoinProjection, StreamCursors, + }, }; #[derive(Debug)] pub struct SortMergeJoinExec { - /// Left sorted joining execution plan left: Arc<dyn ExecutionPlan>, - /// Right sorting joining execution plan right: Arc<dyn ExecutionPlan>, - /// Set of common columns used to join on on: JoinOn, - /// How the join is performed join_type: JoinType, - /// Optional filter before outputting - join_filter: Option<JoinFilter>, - /// The schema once the join is applied - schema: SchemaRef, - /// Execution metrics - metrics: ExecutionPlanMetricsSet, - /// Sort options of join columns used in sorting left and right execution - /// plans sort_options: Vec<SortOptions>, + schema: SchemaRef, + metrics: ExecutionPlanMetricsSet, } impl SortMergeJoinExec { pub fn try_new( + schema: SchemaRef, left: Arc<dyn ExecutionPlan>, right: Arc<dyn ExecutionPlan>, on: JoinOn, join_type: JoinType, - join_filter: Option<JoinFilter>, sort_options: Vec<SortOptions>, ) -> Result<Self> { - let left_schema = left.schema(); - let right_schema = right.schema(); - - if matches!(join_type, LeftSemi | LeftAnti | RightSemi | RightAnti,) { - if join_filter.is_some() { - df_execution_err!("Semi/Anti join with filter is not supported yet")?; - } - } - - check_join_is_valid(&left_schema, &right_schema, &on)?; - if sort_options.len() != on.len() { - df_execution_err!( - "Expected number of sort options: {}, actual: {}", - on.len(), - sort_options.len(), - )?; - } - - let schema = Arc::new(build_join_schema(&left_schema, &right_schema, &join_type).0); Ok(Self { + schema, left, right, on, join_type, - join_filter, - schema, - metrics: ExecutionPlanMetricsSet::new(), sort_options, + metrics: ExecutionPlanMetricsSet::new(), 
}) } - fn create_join_params(&self, batch_size: usize) -> JoinParams { - let on_left: Vec<usize> = self + fn create_join_params(&self, projection: &[usize]) -> Result<JoinParams> { + let left_schema = self.left.schema(); + let right_schema = self.right.schema(); + let (left_keys, right_keys): (Vec<PhysicalExprRef>, Vec<PhysicalExprRef>) = + self.on.iter().cloned().unzip(); + let key_data_types = self .on .iter() - .map(|on| downcast_any!(on.0, Column).unwrap().index()) - .collect(); - let on_right: Vec<usize> = self - .on - .iter() - .map(|on| downcast_any!(on.1, Column).unwrap().index()) - .collect(); - let on_data_types = on_left - .iter() - .map(|&i| self.left.schema().field(i).data_type().clone()) - .collect::<Vec<_>>(); - let sub_batch_size = batch_size / batch_size.ilog10() as usize; + .map(|(left_key, right_key)| { + Ok({ + let left_dt = left_key.data_type(&left_schema)?; + let right_dt = right_key.data_type(&right_schema)?; + if left_dt != right_dt { + df_execution_err!( + "join key data type differs {left_dt:?} <-> {right_dt:?}" + )?; + } + left_dt + }) + }) + .collect::<Result<_>>()?; - // use smaller batch size and coalesce batches at the end, to avoid buffer - // overflowing - JoinParams { + let projection = JoinProjection::try_new( + self.join_type, + &self.schema, + &left_schema, + &right_schema, + projection, + )?; + Ok(JoinParams { join_type: self.join_type, + left_schema, + right_schema, output_schema: self.schema(), - on_left, - on_right, - on_data_types, - join_filter: self.join_filter.clone(), + left_keys, + right_keys, + key_data_types, sort_options: self.sort_options.clone(), - batch_size: sub_batch_size, - left_output_projection: (0..self.left.schema().fields().len()).collect(), - right_output_projection: (0..self.right.schema().fields().len()).collect(), - } + projection, + batch_size: batch_size(), + }) + } + + fn execute_with_projection( + &self, + partition: usize, + context: Arc<TaskContext>, + projection: Vec<usize>, + ) -> 
Result<SendableRecordBatchStream> { + let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); + let join_params = self.create_join_params(&projection)?; + let left = self.left.execute(partition, context.clone())?; + let right = self.right.execute(partition, context.clone())?; + + let metrics_cloned = metrics.clone(); + let context_cloned = context.clone(); + let output_stream = Box::pin(RecordBatchStreamAdapter::new( + join_params.projection.schema.clone(), + futures::stream::once(async move { + context_cloned.output_with_sender( + "SortMergeJoin", + join_params.projection.schema.clone(), + move |sender| execute_join(left, right, join_params, metrics_cloned, sender), + ) + }) + .try_flatten(), + )); + Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) } } @@ -154,6 +170,17 @@ } } +impl ExecuteWithColumnPruning for SortMergeJoinExec { + fn execute_projected( + &self, + partition: usize, + context: Arc<TaskContext>, + projection: &[usize], + ) -> Result<SendableRecordBatchStream> { + self.execute_with_projection(partition, context, projection.to_vec()) + } +} + impl ExecutionPlan for SortMergeJoinExec { fn as_any(&self) -> &dyn Any { self @@ -169,7 +196,7 @@ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> { match self.join_type { - Left | LeftSemi | LeftAnti => self.left.output_ordering(), + Left | LeftSemi | LeftAnti | Existence => self.left.output_ordering(), Right | RightSemi | RightAnti => self.right.output_ordering(), Inner => self.left.output_ordering(), Full => None, @@ -185,11 +212,11 @@ children: Vec<Arc<dyn ExecutionPlan>>, ) -> Result<Arc<dyn ExecutionPlan>> { Ok(Arc::new(SortMergeJoinExec::try_new( + self.schema(), children[0].clone(), children[1].clone(), self.on.clone(), self.join_type, - self.join_filter.clone(), self.sort_options.clone(), )?)) } @@ -199,12 +226,8 @@ partition: usize, context: Arc<TaskContext>, ) -> Result<SendableRecordBatchStream> { - let metrics = 
Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - let join_params = self.create_join_params(batch_size); - let left = self.left.execute(partition, context.clone())?; - let right = self.right.execute(partition, context.clone())?; - execute_with_join_params(context, join_params, left, right, metrics) + let projection = (0..self.schema.fields().len()).collect(); + self.execute_with_projection(partition, context, projection) } fn metrics(&self) -> Option<MetricsSet> { @@ -216,1549 +239,76 @@ } } -impl ExecuteWithColumnPruning for SortMergeJoinExec { - fn execute_projected( - &self, - partition: usize, - context: Arc<TaskContext>, - projection: &[usize], - ) -> Result<SendableRecordBatchStream> { - let metrics = Arc::new(BaselineMetrics::new(&self.metrics, partition)); - let batch_size = batch_size(); - - let (join_params, left_projection, right_projection) = - self.create_join_params(batch_size).project(projection)?; - let left = self - .left - .execute_projected(partition, context.clone(), &left_projection)?; - let right = self - .right - .execute_projected(partition, context.clone(), &right_projection)?; - execute_with_join_params(context, join_params, left, right, metrics) - } -} - -#[derive(Clone)] -struct JoinParams { - join_type: JoinType, - output_schema: SchemaRef, - on_left: Vec<usize>, - on_right: Vec<usize>, - on_data_types: Vec<DataType>, - sort_options: Vec<SortOptions>, - join_filter: Option<JoinFilter>, - left_output_projection: Vec<usize>, - right_output_projection: Vec<usize>, - batch_size: usize, -} - -impl JoinParams { - fn project(&self, projection: &[usize]) -> Result<(Self, Vec<usize>, Vec<usize>)> { - let num_left_fields = self.left_output_projection.len(); - let mut left_projection = vec![]; - let mut right_projection = vec![]; - - for &i in projection { - match self.join_type { - Inner | Left | Right | Full => { - if i < num_left_fields { - left_projection.push(i); - } else { - right_projection.push(i 
- num_left_fields); - } - } - LeftSemi | LeftAnti => { - left_projection.push(i); - } - RightSemi | RightAnti => { - right_projection.push(i); - } - } - } - let num_left_output_columns = left_projection.len(); - let num_right_output_columns = right_projection.len(); - - let mut on_left_projected = vec![]; - let mut on_right_projected = vec![]; - for &l in &self.on_left { - on_left_projected.push(left_projection.iter().position(|&i| i == l).unwrap_or_else( - || { - left_projection.push(l); - left_projection.len() - 1 - }, - )); - } - for &r in &self.on_right { - on_right_projected.push( - right_projection - .iter() - .position(|&i| i == r) - .unwrap_or_else(|| { - right_projection.push(r); - right_projection.len() - 1 - }), - ); - } - - let mut join_filter_projected = None; - if let Some(join_filter) = &self.join_filter { - join_filter_projected = Some(JoinFilter::new( - join_filter.expression().clone(), - join_filter - .column_indices() - .iter() - .map(|ci| { - let projected_index = match ci.side { - JoinSide::Left => left_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - left_projection.push(ci.index); - left_projection.len() - 1 - }), - JoinSide::Right => right_projection - .iter() - .position(|&i| i == ci.index) - .unwrap_or_else(|| { - right_projection.push(ci.index); - right_projection.len() - 1 - }), - }; - ColumnIndex { - index: projected_index, - side: ci.side, - } - }) - .collect(), - join_filter.schema().clone(), - )); - } - - let projected = Self { - join_type: self.join_type, - output_schema: Arc::new(self.output_schema.project(projection)?), - on_left: on_left_projected, - on_right: on_right_projected, - on_data_types: self.on_data_types.clone(), - sort_options: self.sort_options.clone(), - join_filter: join_filter_projected, - batch_size: self.batch_size, - left_output_projection: (0..num_left_output_columns).collect(), - right_output_projection: (0..num_right_output_columns).collect(), - }; - Ok((projected, 
left_projection, right_projection)) - } -} - -fn execute_with_join_params( - context: Arc<TaskContext>, - join_params: JoinParams, - left: SendableRecordBatchStream, - right: SendableRecordBatchStream, - metrics: Arc<BaselineMetrics>, -) -> Result<SendableRecordBatchStream> { - let metrics_cloned = metrics.clone(); - let context_cloned = context.clone(); - let output_schema = join_params.output_schema.clone(); - let output_stream = Box::pin(RecordBatchStreamAdapter::new( - join_params.output_schema.clone(), - futures::stream::once(async move { - context_cloned.output_with_sender("SortMergeJoin", output_schema, move |sender| { - execute_join(left, right, join_params, metrics_cloned, sender) - }) - }) - .try_flatten(), - )); - Ok(context.coalesce_with_default_batch_size(output_stream, &metrics)?) -} - -async fn execute_join( +pub async fn execute_join( lstream: SendableRecordBatchStream, rstream: SendableRecordBatchStream, join_params: JoinParams, metrics: Arc<BaselineMetrics>, sender: Arc<WrappedRecordBatchSender>, ) -> Result<()> { - let elapsed_time = metrics.elapsed_compute().clone(); - let mut timer = elapsed_time.timer(); + let start_time = Instant::now(); - let on_row_converter = Arc::new(SyncMutex::new(RowConverter::new( - join_params - .on_data_types - .iter() - .zip(&join_params.sort_options) - .map(|(data_type, sort_option)| { - SortField::new_with_options(data_type.clone(), *sort_option) - }) - .collect(), - )?)); + let mut curs = ( + StreamCursor::try_new( + lstream, + &join_params, + JoinSide::Left, + &join_params.projection.left, + )?, + StreamCursor::try_new( + rstream, + &join_params, + JoinSide::Right, + &join_params.projection.right, + )?, + ); - let mut lcur = StreamCursor::try_new( - lstream, - on_row_converter.clone(), - join_params.on_left.clone(), - join_params.left_output_projection.clone(), + // start first batches of both side asynchronously + tokio::try_join!( + async { Ok::<_, DataFusionError>(cur_forward!(curs.0)) }, + async { Ok::<_, 
DataFusionError>(cur_forward!(curs.1)) }, )?; - let mut rcur = StreamCursor::try_new( - rstream, - on_row_converter.clone(), - join_params.on_right.clone(), - join_params.right_output_projection.clone(), - )?; - - macro_rules! forward { - ($cur:expr) => {{ - if $cur.next() == NextAction::LoadNextBatch { - $cur.next_batch(&mut timer).await?; - } - }}; - } - - // load first record - forward!(lcur); - forward!(rcur); let join_type = join_params.join_type; - let mut joiner = Joiner::new(); - let mut leqs = vec![]; - let mut reqs = vec![]; + let mut joiner: Pin<Box<dyn Joiner + Send>> = match join_type { + Inner => Box::pin(InnerJoiner::new(join_params, sender)), + Left => Box::pin(LeftOuterJoiner::new(join_params, sender)), + Right => Box::pin(RightOuterJoiner::new(join_params, sender)), + Full => Box::pin(FullOuterJoiner::new(join_params, sender)), + LeftSemi => Box::pin(LeftSemiJoiner::new(join_params, sender)), + RightSemi => Box::pin(RightSemiJoiner::new(join_params, sender)), + LeftAnti => Box::pin(LeftAntiJoiner::new(join_params, sender)), + RightAnti => Box::pin(RightAntiJoiner::new(join_params, sender)), + Existence => Box::pin(ExistenceJoiner::new(join_params, sender)), + }; + joiner.as_mut().join(&mut curs).await?; + metrics.record_output(joiner.num_output_rows()); - macro_rules! 
joiner_accept_pair { - ($lidx:expr, $ridx:expr) => {{ - let lidx = $lidx; - let ridx = $ridx; - let r = joiner.accept_pair(&join_params, &mut lcur, &mut rcur, lidx, ridx)?; - if let Some(batch) = r { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - }}; - } - - // process records until one side is exhausted - while !lcur.finished && !rcur.finished { - let r = compare_cursor(&lcur, lcur.cur_idx, &rcur, rcur.cur_idx); - match r { - Ordering::Less => { - if matches!(join_type, Left | LeftAnti | Full) { - joiner_accept_pair!(Some(lcur.cur_idx), None); - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - Ordering::Greater => { - if matches!(join_type, Right | RightAnti | Full) { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - Ordering::Equal => { - let lidx0 = lcur.cur_idx; - let ridx0 = rcur.cur_idx; - leqs.push(lidx0); - reqs.push(ridx0); - forward!(lcur); - forward!(rcur); - - let mut leq = true; - let mut req = true; - while leq && req { - if leq && !lcur.finished && lcur.row(lcur.cur_idx) == lcur.row(lidx0) { - leqs.push(lcur.cur_idx); - forward!(lcur); - } else { - leq = false; - } - if req && !rcur.finished && rcur.row(rcur.cur_idx) == rcur.row(ridx0) { - reqs.push(rcur.cur_idx); - forward!(rcur); - } else { - req = false; - } - } - - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - for &r in &reqs { - joiner_accept_pair!(Some(l), Some(r)); - } - } - } - LeftSemi => { - for &l in &leqs { - joiner_accept_pair!(Some(l), None); - } - } - RightSemi => { - for &r in &reqs { - joiner_accept_pair!(None, Some(r)); - } - } - LeftAnti | RightAnti => {} - } - - if leq { - while !lcur.finished && lcur.row(lcur.cur_idx) == rcur.row(ridx0) { - match join_type { - Inner | Left | Right | Full => { - for &r in &reqs { - joiner_accept_pair!(Some(lcur.cur_idx), Some(r)); - } - } - LeftSemi => { - 
joiner_accept_pair!(Some(lcur.cur_idx), None); - } - RightSemi | LeftAnti | RightAnti => {} - } - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if req { - while !rcur.finished && rcur.row(rcur.cur_idx) == lcur.row(lidx0) { - match join_type { - Inner | Left | Right | Full => { - for &l in &leqs { - joiner_accept_pair!(Some(l), Some(rcur.cur_idx)); - } - } - RightSemi => { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - } - LeftSemi | LeftAnti | RightAnti => {} - } - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - leqs.clear(); - reqs.clear(); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner if cursors buffered too many batches - if !joiner.is_empty() && (lcur.num_buffered_batches() + rcur.num_buffered_batches() > 5) - || (lcur.mem_size() + rcur.mem_size() > suggested_output_batch_mem_size() - && lcur.num_buffered_batches() > 1 - && rcur.num_buffered_batches() > 1) - { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? { - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } - } - - // process rest records in inexhausted side - if matches!(join_type, Left | LeftAnti | Full) { - while !lcur.finished { - joiner_accept_pair!(Some(lcur.cur_idx), None); - forward!(lcur); - lcur.clear_outdated(joiner.l_min_reserved_bidx); - } - } - if matches!(join_type, Right | RightAnti | Full) { - while !rcur.finished { - joiner_accept_pair!(None, Some(rcur.cur_idx)); - forward!(rcur); - rcur.clear_outdated(joiner.r_min_reserved_bidx); - } - } - - // flush joiner - if !joiner.is_empty() { - if let Some(batch) = joiner.flush_pairs(&join_params, &mut lcur, &mut rcur)? 
{ - metrics.record_output(batch.num_rows()); - sender.send(Ok(batch), Some(&mut timer)).await; - } - } + // discount poll input and send output batch time + let mut join_time_ns = (Instant::now() - start_time).as_nanos() as u64; + join_time_ns -= joiner.total_send_output_time() as u64; + join_time_ns -= curs.0.total_poll_time() as u64; + join_time_ns -= curs.1.total_poll_time() as u64; + metrics + .elapsed_compute() + .add_duration(Duration::from_nanos(join_time_ns)); Ok(()) } -struct StreamCursor { - stream: SendableRecordBatchStream, - on_row_converter: Arc<SyncMutex<RowConverter>>, - on_columns: Vec<usize>, - - // IMPORTANT: - // batches/rows/null_buffers always contains a `null batch` in the front - batches: Vec<RecordBatch>, - projected_batches: Vec<RecordBatch>, - projection: Vec<usize>, - on_rows: Vec<Arc<Rows>>, - on_row_null_buffers: Vec<Option<NullBuffer>>, - cur_idx: (usize, usize), - num_null_batches: usize, - mem_size: usize, - finished: bool, +#[macro_export] +macro_rules! 
compare_cursor { + ($curs:expr) => {{ + match ($curs.0.cur_idx, $curs.1.cur_idx) { + (lidx, _) if $curs.0.is_null_key(lidx) => Ordering::Less, + (_, ridx) if $curs.1.is_null_key(ridx) => Ordering::Greater, + (lidx, ridx) => $curs.0.key(lidx).cmp(&$curs.1.key(ridx)), + } + }}; } -#[derive(Clone, Copy, PartialEq, Eq)] -enum NextAction { - None, - LoadNextBatch, -} - -impl StreamCursor { - fn try_new( - stream: SendableRecordBatchStream, - on_row_converter: Arc<SyncMutex<RowConverter>>, - on_columns: Vec<usize>, - projection: Vec<usize>, - ) -> Result<Self> { - let empty_batch = RecordBatch::new_empty(Arc::new(Schema::new( - stream - .schema() - .fields() - .iter() - .map(|f| f.as_ref().clone().with_nullable(true)) - .collect::<Vec<_>>(), - ))); - let null_batch = take_batch_opt(empty_batch, [Option::<usize>::None])?; - let null_on_rows = Arc::new( - on_row_converter - .lock() - .convert_columns(null_batch.project(&on_columns)?.columns())?, - ); - let null_nb = NullBuffer::new_null(1); - - Ok(Self { - stream, - on_row_converter, - on_columns, - projected_batches: vec![null_batch.project(&projection)?], - batches: vec![null_batch], - projection, - on_rows: vec![null_on_rows], - on_row_null_buffers: vec![Some(null_nb)], - cur_idx: (0, 0), - num_null_batches: 1, - mem_size: 0, - finished: false, - }) - } - - fn next(&mut self) -> NextAction { - let mut next_action = NextAction::None; - let mut cur_idx = self.cur_idx; - - if cur_idx.1 + 1 < self.batches[cur_idx.0].num_rows() { - cur_idx.1 += 1; - } else { - cur_idx.0 += 1; - cur_idx.1 = 0; - next_action = NextAction::LoadNextBatch; - } - self.cur_idx = cur_idx; - next_action - } - - async fn next_batch(&mut self, stop_timer: &mut ScopedTimerGuard<'_>) -> Result<bool> { - stop_timer.stop(); - if let Some(batch) = self.stream.next().await.transpose()? 
{ - stop_timer.restart(); - let on_columns = batch.project(&self.on_columns)?.columns().to_vec(); - let on_row_null_buffer = on_columns - .iter() - .map(|c| c.nulls().cloned()) - .reduce(|lhs, rhs| NullBuffer::union(lhs.as_ref(), rhs.as_ref())) - .unwrap_or(None); - let on_rows = Arc::new(self.on_row_converter.lock().convert_columns(&on_columns)?); - - self.mem_size += batch.get_array_mem_size(); - self.mem_size += on_row_null_buffer - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size += on_rows.size(); - - self.projected_batches - .push(batch.project(&self.projection)?); - self.batches.push(batch); - self.on_row_null_buffers.push(on_row_null_buffer); - self.on_rows.push(on_rows); - return Ok(true); - } else { - stop_timer.restart(); - } - self.finished = true; - Ok(false) - } - - #[inline] - fn row<'a>(&'a self, idx: (usize, usize)) -> Row<'a> { - let bidx = idx.0; - let ridx = idx.1; - self.on_rows[bidx].row(ridx) - } - - #[inline] - fn num_buffered_batches(&self) -> usize { - self.batches.len() - self.num_null_batches - } - - #[inline] - fn mem_size(&self) -> usize { - self.mem_size - } - - #[inline] - fn clear_outdated(&mut self, min_reserved_bidx: usize) { - // fill out-dated batches with null batches - for i in self.num_null_batches..min_reserved_bidx.min(self.cur_idx.0) { - self.mem_size -= self.batches[i].get_array_mem_size(); - self.mem_size -= self.on_row_null_buffers[i] - .as_ref() - .map(|nb| nb.buffer().len()) - .unwrap_or_default(); - self.mem_size -= self.on_rows[i].size(); - - self.projected_batches[i] = self.projected_batches[0].clone(); - self.batches[i] = self.batches[0].clone(); - self.on_rows[i] = self.on_rows[0].clone(); - self.on_row_null_buffers[i] = self.on_row_null_buffers[0].clone(); - self.num_null_batches += 1; - } - } -} - -#[derive(Default)] -struct Joiner { - ljoins: Vec<(usize, usize)>, - rjoins: Vec<(usize, usize)>, - l_min_reserved_bidx: usize, - r_min_reserved_bidx: usize, -} - -impl Joiner { - fn 
new() -> Self { - Self { - ljoins: vec![], - rjoins: vec![], - l_min_reserved_bidx: usize::MAX, - r_min_reserved_bidx: usize::MAX, - } - } - - fn accept_pair( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - l: Option<(usize, usize)>, - r: Option<(usize, usize)>, - ) -> Result<Option<RecordBatch>> { - if let Some((bidx, ridx)) = l { - self.ljoins.push((bidx, ridx)); - self.l_min_reserved_bidx = self.l_min_reserved_bidx.min(bidx); - } else { - self.ljoins.push((0, 0)); - } - - if let Some((bidx, ridx)) = r { - self.rjoins.push((bidx, ridx)); - self.r_min_reserved_bidx = self.r_min_reserved_bidx.min(bidx); - } else { - self.rjoins.push((0, 0)); - } - - let batch_size = join_params.batch_size; - if self.ljoins.len() >= batch_size || self.rjoins.len() >= batch_size { - return self.flush_pairs(join_params, lcur, rcur); - } - Ok(None) - } - - fn is_empty(&self) -> bool { - self.ljoins.is_empty() && self.rjoins.is_empty() - } - - fn flush_pairs( - &mut self, - join_params: &JoinParams, - lcur: &mut StreamCursor, - rcur: &mut StreamCursor, - ) -> Result<Option<RecordBatch>> { - self.l_min_reserved_bidx = usize::MAX; - self.r_min_reserved_bidx = usize::MAX; - - if let Some(join_filter) = &join_params.join_filter { - let num_intermediate_rows = std::cmp::max(self.ljoins.len(), self.rjoins.len()); - - // get intermediate batch - let intermediate_columns = join_filter - .column_indices() - .iter() - .map(|ci| { - let (cur, joins) = match ci.side { - JoinSide::Left => (&lcur, &self.ljoins), - JoinSide::Right => (&rcur, &self.rjoins), - }; - let arrays = cur - .batches - .iter() - .map(|b| b.column(ci.index).as_ref()) - .collect::<Vec<_>>(); - Ok(arrow::compute::interleave(&arrays, joins)?) 
- }) - .collect::<Result<Vec<_>>>()?; - - let intermediate_batch = RecordBatch::try_new_with_options( - Arc::new(join_filter.schema().clone()), - intermediate_columns, - &RecordBatchOptions::new().with_row_count(Some(num_intermediate_rows)), - )?; - - // evalute filter - let filtered_array = join_filter - .expression() - .evaluate(&intermediate_batch)? - .into_array(intermediate_batch.num_rows())?; - let filtered = as_boolean_array(&filtered_array); - let filtered = if filtered.null_count() > 0 { - prep_null_mask_filter(filtered) - } else { - filtered.clone() - }; - - // apply filter - let mut retained = 0; - for (i, selected) in filtered.values().iter().enumerate() { - if selected { - self.ljoins[retained] = self.ljoins[i]; - self.rjoins[retained] = self.rjoins[i]; - retained += 1; - } - } - self.ljoins.truncate(retained); - self.rjoins.truncate(retained); - if retained == 0 { - return Ok(None); - } - } - - let lcols = || -> Result<Vec<ArrayRef>> { - Ok(if !lcur.projection.is_empty() { - interleave_batches( - lcur.projected_batches[0].schema(), - &lcur.projected_batches, - &self.ljoins, - )? - .columns() - .to_vec() - } else { - vec![] - }) - }; - let rcols = || -> Result<Vec<ArrayRef>> { - Ok(if !rcur.projection.is_empty() { - interleave_batches( - rcur.projected_batches[0].schema(), - &rcur.projected_batches, - &self.rjoins, - )? 
- .columns() - .to_vec() - } else { - vec![] - }) - }; - - let output_columns = match join_params.join_type { - LeftSemi | LeftAnti => lcols()?, - RightSemi | RightAnti => rcols()?, - _ => [lcols()?, rcols()?].concat(), - }; - let num_output_records = std::cmp::max(self.ljoins.len(), self.rjoins.len()); - self.ljoins.clear(); - self.rjoins.clear(); - let batch = RecordBatch::try_new_with_options( - join_params.output_schema.clone(), - output_columns, - &RecordBatchOptions::new().with_row_count(Some(num_output_records)), - )?; - Ok(Some(batch)) - } -} - -fn compare_cursor( - lcur: &StreamCursor, - lidx: (usize, usize), - rcur: &StreamCursor, - ridx: (usize, usize), -) -> Ordering { - match (&lcur.on_rows.get(lidx.0), &rcur.on_rows.get(ridx.0)) { - (None, _) => Ordering::Greater, - (_, None) => Ordering::Less, - (Some(lrows), Some(rrows)) => { - let lkey = &lrows.row(lidx.1); - let rkey = &rrows.row(ridx.1); - match lkey.cmp(rkey) { - Ordering::Greater => Ordering::Greater, - Ordering::Less => Ordering::Less, - _ => { - if let Some(nb) = &lcur.on_row_null_buffers[lidx.0] { - if nb.is_null(lidx.1) { - return Ordering::Less; - } - } - Ordering::Equal - } - } - } - } -} - -#[cfg(test)] -mod tests { - use std::sync::Arc; - - use arrow::{ - self, - array::*, - compute::SortOptions, - datatypes::{DataType, Field, Schema}, - record_batch::RecordBatch, - }; - use datafusion::{ - assert_batches_sorted_eq, - error::Result, - logical_expr::{JoinType, JoinType::*}, - physical_expr::expressions::Column, - physical_plan::{common, joins::utils::*, memory::MemoryExec, ExecutionPlan}, - prelude::SessionContext, - }; - - use crate::sort_merge_join_exec::SortMergeJoinExec; - - fn columns(schema: &Schema) -> Vec<String> { - schema.fields().iter().map(|f| f.name().clone()).collect() - } - - fn build_table_i32( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> RecordBatch { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Int32, false), - 
Field::new(b.0, DataType::Int32, false), - Field::new(c.0, DataType::Int32, false), - ]); - - RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap() - } - - fn build_table( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> Arc<dyn ExecutionPlan> { - let batch = build_table_i32(a, b, c); - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_table_from_batches(batches: Vec<RecordBatch>) -> Arc<dyn ExecutionPlan> { - let schema = batches.first().unwrap().schema(); - Arc::new(MemoryExec::try_new(&[batches], schema, None).unwrap()) - } - - fn build_date_table( - a: (&str, &Vec<i32>), - b: (&str, &Vec<i32>), - c: (&str, &Vec<i32>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date32, false), - Field::new(b.0, DataType::Date32, false), - Field::new(c.0, DataType::Date32, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date32Array::from(a.1.clone())), - Arc::new(Date32Array::from(b.1.clone())), - Arc::new(Date32Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn build_date64_table( - a: (&str, &Vec<i64>), - b: (&str, &Vec<i64>), - c: (&str, &Vec<i64>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Schema::new(vec![ - Field::new(a.0, DataType::Date64, false), - Field::new(b.0, DataType::Date64, false), - Field::new(c.0, DataType::Date64, false), - ]); - - let batch = RecordBatch::try_new( - Arc::new(schema), - vec![ - Arc::new(Date64Array::from(a.1.clone())), - Arc::new(Date64Array::from(b.1.clone())), - Arc::new(Date64Array::from(c.1.clone())), - ], - ) - .unwrap(); - - let schema = batch.schema(); - 
Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - /// returns a table with 3 columns of i32 in memory - pub fn build_table_i32_nullable( - a: (&str, &Vec<Option<i32>>), - b: (&str, &Vec<Option<i32>>), - c: (&str, &Vec<Option<i32>>), - ) -> Arc<dyn ExecutionPlan> { - let schema = Arc::new(Schema::new(vec![ - Field::new(a.0, DataType::Int32, true), - Field::new(b.0, DataType::Int32, true), - Field::new(c.0, DataType::Int32, true), - ])); - let batch = RecordBatch::try_new( - schema.clone(), - vec![ - Arc::new(Int32Array::from(a.1.clone())), - Arc::new(Int32Array::from(b.1.clone())), - Arc::new(Int32Array::from(c.1.clone())), - ], - ) - .unwrap(); - Arc::new(MemoryExec::try_new(&[vec![batch]], schema, None).unwrap()) - } - - fn join_with_options( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - sort_options: Vec<SortOptions>, - ) -> Result<SortMergeJoinExec> { - SortMergeJoinExec::try_new(left, right, on, join_type, None, sort_options) - } - - async fn join_collect( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - ) -> Result<(Vec<String>, Vec<RecordBatch>)> { - let sort_options = vec![SortOptions::default(); on.len()]; - join_collect_with_options(left, right, on, join_type, sort_options).await - } - - async fn join_collect_with_options( - left: Arc<dyn ExecutionPlan>, - right: Arc<dyn ExecutionPlan>, - on: JoinOn, - join_type: JoinType, - sort_options: Vec<SortOptions>, - ) -> Result<(Vec<String>, Vec<RecordBatch>)> { - let session_ctx = SessionContext::new(); - let task_ctx = session_ctx.task_ctx(); - let join = join_with_options(left, right, on, join_type, sort_options)?; - let columns = columns(&join.schema()); - - let stream = join.execute(0, task_ctx)?; - let batches = common::collect(stream).await?; - Ok((columns, batches)) - } - - #[tokio::test] - async fn join_inner_one() -> Result<()> { - let left = build_table( - ("a1", 
&vec![1, 2, 3]), - ("b1", &vec![4, 5, 5]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 5 | 9 | 20 | 5 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2]), - ("b2", &vec![1, 2, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 2, 3]), - ("b2", &vec![1, 2, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_two_two() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 
1, 2]), - ("b2", &vec![1, 1, 2]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a1", &vec![1, 1, 3]), - ("b2", &vec![1, 1, 2]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_columns, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | 7 | 1 | 1 | 70 |", - "| 1 | 1 | 7 | 1 | 1 | 80 |", - "| 1 | 1 | 8 | 1 | 1 | 70 |", - "| 1 | 1 | 8 | 1 | 1 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(2)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), // null in key field - ("c1", &vec![Some(1), None, Some(8), Some(9)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(1), Some(1), Some(2), Some(3)]), - ("b2", &vec![None, Some(1), Some(2), Some(2)]), - ("c2", &vec![Some(10), Some(70), Some(80), Some(90)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 1 | | 1 
| 1 | 70 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_inner_with_nulls_with_options() -> Result<()> { - let left = build_table_i32_nullable( - ("a1", &vec![Some(2), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), // null in key field - ("c1", &vec![Some(9), Some(8), None, Some(1)]), // null in non-key field - ); - let right = build_table_i32_nullable( - ("a1", &vec![Some(3), Some(2), Some(1), Some(1)]), - ("b2", &vec![Some(2), Some(2), Some(1), None]), - ("c2", &vec![Some(90), Some(80), Some(70), Some(10)]), - ); - let on: JoinOn = vec![ - ( - Arc::new(Column::new_with_schema("a1", &left.schema())?), - Arc::new(Column::new_with_schema("a1", &right.schema())?), - ), - ( - Arc::new(Column::new_with_schema("b2", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - ), - ]; - let (_, batches) = join_collect_with_options( - left, - right, - on, - Inner, - vec![ - SortOptions { - descending: true, - nulls_first: false - }; - 2 - ], - // null_equals_null=false - ) - .await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b2 | c1 | a1 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 2 | 2 | 9 | 2 | 2 | 80 |", - "| 2 | 2 | 8 | 2 | 2 | 80 |", - "| 1 | 1 | | 1 | 1 | 70 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn 
= vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 6 does not exist on the left - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+----+----+----+----+----+----+", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_one() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![4, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b2", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - 
)]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 30 | 6 | 90 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| 3 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_anti() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3, 5]), - ("b1", &vec![4, 5, 5, 7, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9, 11]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftAnti).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 3 | 7 | 9 |", - "| 5 | 7 | 11 |", - "+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_semi() -> Result<()> { - let left = build_table( - ("a1", &vec![1, 2, 2, 3]), - ("b1", &vec![4, 5, 5, 7]), // 7 does not exist on the right - ("c1", &vec![7, 8, 8, 9]), - ); - let right = build_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![4, 5, 6]), // 5 is double on the right - ("c2", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, LeftSemi).await?; - let expected = vec![ - "+----+----+----+", - "| a1 | b1 | c1 |", - "+----+----+----+", - "| 1 | 4 | 7 |", - "| 2 | 5 | 8 |", - "| 2 | 5 | 8 |", - 
"+----+----+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_with_duplicated_column_names() -> Result<()> { - let left = build_table( - ("a", &vec![1, 2, 3]), - ("b", &vec![4, 5, 7]), - ("c", &vec![7, 8, 9]), - ); - let right = build_table( - ("a", &vec![10, 20, 30]), - ("b", &vec![1, 2, 7]), - ("c", &vec![70, 80, 90]), - ); - let on: JoinOn = vec![( - // join on a=b so there are duplicate column names on unjoined columns - Arc::new(Column::new_with_schema("a", &left.schema())?), - Arc::new(Column::new_with_schema("b", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+---+---+---+----+---+----+", - "| a | b | c | a | b | c |", - "+---+---+---+----+---+----+", - "| 1 | 4 | 7 | 10 | 1 | 70 |", - "| 2 | 5 | 8 | 20 | 2 | 80 |", - "+---+---+---+----+---+----+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date32() -> Result<()> { - let left = build_date_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![19107, 19108, 19108]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![19107, 19108, 19109]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - - let expected = vec![ - "+------------+------------+------------+------------+------------+------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+------------+------------+------------+------------+------------+------------+", - "| 1970-01-02 | 2022-04-25 | 1970-01-08 | 1970-01-11 | 2022-04-25 | 1970-03-12 |", - "| 1970-01-03 | 
2022-04-26 | 1970-01-09 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "| 1970-01-04 | 2022-04-26 | 1970-01-10 | 1970-01-21 | 2022-04-26 | 1970-03-22 |", - "+------------+------------+------------+------------+------------+------------+", - ]; - // The output order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_date64() -> Result<()> { - let left = build_date64_table( - ("a1", &vec![1, 2, 3]), - ("b1", &vec![1650703441000, 1650903441000, 1650903441000]), // this has a repetition - ("c1", &vec![7, 8, 9]), - ); - let right = build_date64_table( - ("a2", &vec![10, 20, 30]), - ("b1", &vec![1650703441000, 1650503441000, 1650903441000]), - ("c2", &vec![70, 80, 90]), - ); - - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b1", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Inner).await?; - let expected = vec![ - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| a1 | b1 | c1 | a2 | b1 | c2 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - "| 1970-01-01T00:00:00.001 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.007 | 1970-01-01T00:00:00.010 | 2022-04-23T08:44:01 | 1970-01-01T00:00:00.070 |", - "| 1970-01-01T00:00:00.002 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.008 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "| 1970-01-01T00:00:00.003 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.009 | 1970-01-01T00:00:00.030 | 2022-04-25T16:17:21 | 1970-01-01T00:00:00.090 |", - "+-------------------------+---------------------+-------------------------+-------------------------+---------------------+-------------------------+", - ]; - - // The output 
order is important as SMJ preserves sortedness - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3, 4, 5]), - ("b1", &vec![3, 4, 5, 6, 6, 7]), - ("c1", &vec![4, 5, 6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30, 40]), - ("b2", &vec![2, 4, 6, 6, 8]), - ("c2", &vec![50, 60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_sort_order() -> Result<()> { - let left = build_table( - ("a1", &vec![0, 1, 2, 3]), - ("b1", &vec![3, 4, 5, 7]), - ("c1", &vec![6, 7, 8, 9]), - ); - let right = build_table( - ("a2", &vec![0, 10, 20, 30]), - ("b2", &vec![2, 4, 5, 6]), - ("c2", &vec![60, 70, 80, 90]), - ); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 2 | 60 |", - "| 1 | 4 | 7 | 10 | 4 | 70 |", - "| 2 | 5 | 8 | 20 | 5 | 80 |", - "| | | | 30 | 6 | 90 |", - "+----+----+----+----+----+----+", - ]; - 
assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_left_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Left).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_right_multiple_batches() -> Result<()> { - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 1, 2]), - ("b2", &vec![3, 4, 5]), - ("c2", &vec![4, 5, 6]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![3, 4, 5, 6]), - ("b2", &vec![6, 6, 7, 9]), - ("c2", &vec![7, 8, 9, 9]), - ); - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 10, 20]), - ("b1", &vec![2, 4, 6]), - ("c1", &vec![50, 60, 70]), - ); - let left_batch_2 = build_table_i32( - 
("a1", &vec![30, 40]), - ("b1", &vec![6, 8]), - ("c1", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Right).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - "+----+----+----+----+----+----+", - "| | | | 0 | 3 | 4 |", - "| 10 | 4 | 60 | 1 | 4 | 5 |", - "| | | | 2 | 5 | 6 |", - "| 20 | 6 | 70 | 3 | 6 | 7 |", - "| 30 | 6 | 80 | 3 | 6 | 7 |", - "| 20 | 6 | 70 | 4 | 6 | 8 |", - "| 30 | 6 | 80 | 4 | 6 | 8 |", - "| | | | 5 | 7 | 9 |", - "| | | | 6 | 9 | 9 |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } - - #[tokio::test] - async fn join_full_multiple_batches() -> Result<()> { - let left_batch_1 = build_table_i32( - ("a1", &vec![0, 1, 2]), - ("b1", &vec![3, 4, 5]), - ("c1", &vec![4, 5, 6]), - ); - let left_batch_2 = build_table_i32( - ("a1", &vec![3, 4, 5, 6]), - ("b1", &vec![6, 6, 7, 9]), - ("c1", &vec![7, 8, 9, 9]), - ); - let right_batch_1 = build_table_i32( - ("a2", &vec![0, 10, 20]), - ("b2", &vec![2, 4, 6]), - ("c2", &vec![50, 60, 70]), - ); - let right_batch_2 = build_table_i32( - ("a2", &vec![30, 40]), - ("b2", &vec![6, 8]), - ("c2", &vec![80, 90]), - ); - let left = build_table_from_batches(vec![left_batch_1, left_batch_2]); - let right = build_table_from_batches(vec![right_batch_1, right_batch_2]); - let on: JoinOn = vec![( - Arc::new(Column::new_with_schema("b1", &left.schema())?), - Arc::new(Column::new_with_schema("b2", &right.schema())?), - )]; - - let (_, batches) = join_collect(left, right, on, Full).await?; - let expected = vec![ - "+----+----+----+----+----+----+", - "| a1 | b1 | c1 | a2 | b2 | c2 |", - 
"+----+----+----+----+----+----+", - "| | | | 0 | 2 | 50 |", - "| | | | 40 | 8 | 90 |", - "| 0 | 3 | 4 | | | |", - "| 1 | 4 | 5 | 10 | 4 | 60 |", - "| 2 | 5 | 6 | | | |", - "| 3 | 6 | 7 | 20 | 6 | 70 |", - "| 3 | 6 | 7 | 30 | 6 | 80 |", - "| 4 | 6 | 8 | 20 | 6 | 70 |", - "| 4 | 6 | 8 | 30 | 6 | 80 |", - "| 5 | 7 | 9 | | | |", - "| 6 | 9 | 9 | | | |", - "+----+----+----+----+----+----+", - ]; - assert_batches_sorted_eq!(expected, &batches); - Ok(()) - } +#[async_trait] +pub trait Joiner { + async fn join(self: Pin<&mut Self>, curs: &mut StreamCursors) -> Result<()>; + fn total_send_output_time(&self) -> usize; + fn num_output_rows(&self) -> usize; }
diff --git a/pom.xml b/pom.xml index 8598cfd..f5d8fa2 100644 --- a/pom.xml +++ b/pom.xml
@@ -13,7 +13,7 @@ </modules> <properties> - <revision>2.0.9.1-SNAPSHOT</revision> + <revision>3.0.0-SNAPSHOT</revision> <project.build.sourceEncoding>UTF-8</project.build.sourceEncoding> <arrowVersion>15.0.2</arrowVersion> <protobufVersion>3.21.9</protobufVersion> @@ -107,6 +107,13 @@ </compilerPlugin> </compilerPlugins> </configuration> + <dependencies> + <dependency> + <groupId>com.google.code.findbugs</groupId> + <artifactId>jsr305</artifactId> + <version>2.0.2</version> + </dependency> + </dependencies> <executions> <execution> <id>scala-compile-first</id>
diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 42d866a..1e25db4 100755 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml
@@ -16,5 +16,5 @@ # under the License. [toolchain] -channel = "nightly-2023-08-01" -components = ["cargo", "rustfmt", "clippy"] +channel = "nightly-2024-06-27" +components = ["rust-src", "cargo", "rustfmt", "clippy"]
diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 823a2bc..8dd9ea8 100644 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -79,8 +79,6 @@ import org.apache.spark.sql.execution.blaze.plan.NativeAggExec import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinExec -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinBase -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinExec import org.apache.spark.sql.execution.blaze.plan.NativeExpandBase import org.apache.spark.sql.execution.blaze.plan.NativeExpandExec import org.apache.spark.sql.execution.blaze.plan.NativeFilterBase @@ -114,6 +112,7 @@ import org.apache.spark.sql.types.DataType import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.catalog.CatalogTable +import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeParquetSinkBase import org.apache.spark.sql.execution.blaze.plan.NativeParquetSinkExec import org.blaze.{protobuf => pb} @@ -153,7 +152,7 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase = + buildSide: BroadcastSide): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -161,14 +160,7 @@ leftKeys, rightKeys, joinType, - condition) - - override def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase = - NativeBroadcastNestedLoopJoinExec(left, right, joinType, condition) + buildSide) override def createNativeSortMergeJoinExec( left: SparkPlan,
diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala index 75d1b7c..3101587 100644 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinExec.scala
@@ -19,8 +19,9 @@ import org.apache.spark.sql.catalyst.plans.JoinType import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.joins import org.apache.spark.sql.execution.joins.BuildLeft +import org.apache.spark.sql.execution.joins.BuildRight +import org.apache.spark.sql.execution.joins.BuildSide import org.apache.spark.sql.execution.joins.HashJoin case class NativeBroadcastJoinExec( @@ -30,7 +31,7 @@ override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - override val condition: Option[Expression]) + broadcastSide: BroadcastSide) extends NativeBroadcastJoinBase( left, right, @@ -38,10 +39,15 @@ leftKeys, rightKeys, joinType, - condition) + broadcastSide) with HashJoin { - override val buildSide: joins.BuildSide = BuildLeft + override val condition: Option[Expression] = None + + override val buildSide: BuildSide = broadcastSide match { + case BroadcastLeft => BuildLeft + case BroadcastRight => BuildRight + } override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = copy(left = newChildren(0), right = newChildren(1))
diff --git a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala b/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala deleted file mode 100644 index 7b215ce..0000000 --- a/spark-extension-shims-spark303/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala +++ /dev/null
@@ -1,31 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.blaze.plan - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.execution.SparkPlan - -case class NativeBroadcastNestedLoopJoinExec( - override val left: SparkPlan, - override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends NativeBroadcastNestedLoopJoinBase(left, right, joinType, condition) { - - override def withNewChildren(newChildren: Seq[SparkPlan]): SparkPlan = - copy(left = newChildren(0), right = newChildren(1)) -}
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala index 1d867a5..a394cf9 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/blaze/ShimsImpl.scala
@@ -105,7 +105,6 @@ import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.CoalescedMapperPartitionSpec import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastJoinExec -import org.apache.spark.sql.execution.joins.blaze.plan.NativeBroadcastNestedLoopJoinExec import org.apache.spark.sql.execution.joins.blaze.plan.NativeSortMergeJoinExec import org.apache.spark.sql.hive.execution.InsertIntoHiveTable import org.apache.spark.sql.types.DataType @@ -150,7 +149,7 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase = + broadcastSide: BroadcastSide): NativeBroadcastJoinBase = NativeBroadcastJoinExec( left, right, @@ -158,14 +157,7 @@ leftKeys, rightKeys, joinType, - condition) - - override def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase = - NativeBroadcastNestedLoopJoinExec(left, right, joinType, condition) + broadcastSide) override def createNativeSortMergeJoinExec( left: SparkPlan,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala index 292f233..fdd3a24 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeBlockStoreShuffleReader.scala
@@ -20,7 +20,6 @@ import org.apache.spark.MapOutputTracker import org.apache.spark.SparkEnv import org.apache.spark.TaskContext - import org.apache.spark.internal.Logging import org.apache.spark.internal.config import org.apache.spark.io.CompressionCodec @@ -28,30 +27,21 @@ import org.apache.spark.shuffle.ShuffleReadMetricsReporter import org.apache.spark.storage.BlockId import org.apache.spark.storage.BlockManager +import org.apache.spark.storage.BlockManagerId import org.apache.spark.storage.ShuffleBlockFetcherIterator class BlazeBlockStoreShuffleReader[K, C]( handle: BaseShuffleHandle[K, _, C], - startPartition: Int, - endPartition: Int, + blocksByAddress: Iterator[(BlockManagerId, Seq[(BlockId, Long, Int)])], context: TaskContext, readMetrics: ShuffleReadMetricsReporter, blockManager: BlockManager = SparkEnv.get.blockManager, mapOutputTracker: MapOutputTracker = SparkEnv.get.mapOutputTracker, - startMapId: Option[Int] = None, - endMapId: Option[Int] = None, shouldBatchFetch: Boolean = false) extends BlazeBlockStoreShuffleReaderBase[K, C](handle, context) with Logging { override def readBlocks(): Iterator[(BlockId, InputStream)] = { - val blocksByAddress = mapOutputTracker.getMapSizesByExecutorId( - handle.shuffleId, - startMapId.getOrElse(0), - endMapId.getOrElse(Int.MaxValue), - startPartition, - endPartition) - new ShuffleBlockFetcherIterator( context, blockManager.blockStoreClient,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala index a7390ee..83decb3 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/blaze/shuffle/BlazeShuffleManager.scala
@@ -22,6 +22,7 @@ import org.apache.spark.internal.Logging import org.apache.spark.shuffle._ import org.apache.spark.shuffle.sort.SortShuffleManager +import org.apache.spark.shuffle.sort.SortShuffleManager.canUseBatchFetch import org.apache.spark.sql.execution.blaze.shuffle.BlazeShuffleDependency.isArrowShuffle class BlazeShuffleManager(conf: SparkConf) extends ShuffleManager with Logging { @@ -54,16 +55,27 @@ metrics: ShuffleReadMetricsReporter): ShuffleReader[K, C] = { if (isArrowShuffle(handle)) { + val baseShuffleHandle = handle.asInstanceOf[BaseShuffleHandle[K, _, C]] + val (blocksByAddress, canEnableBatchFetch) = + if (baseShuffleHandle.dependency.isShuffleMergeFinalizedMarked) { + val res = SparkEnv.get.mapOutputTracker.getPushBasedShuffleMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (res.iter, res.enableBatchFetch) + } else { + val address = SparkEnv.get.mapOutputTracker.getMapSizesByExecutorId( + handle.shuffleId, startMapIndex, endMapIndex, startPartition, endPartition) + (address, true) + } + new BlazeBlockStoreShuffleReader( handle.asInstanceOf[BaseShuffleHandle[K, _, C]], - startPartition, - endPartition, + blocksByAddress, context, metrics, SparkEnv.get.blockManager, SparkEnv.get.mapOutputTracker, - startMapId = Some(startMapIndex), - endMapId = Some(endMapIndex)) + shouldBatchFetch = + canEnableBatchFetch && canUseBatchFetch(startPartition, endPartition, context)) } else { sortShuffleManager.getReader( handle,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala index 3fc6649..de3f5f8 100644 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala +++ b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastJoinExec.scala
@@ -21,12 +21,16 @@ import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.codegen.CodegenContext import org.apache.spark.sql.catalyst.optimizer.BuildLeft +import org.apache.spark.sql.catalyst.optimizer.BuildRight import org.apache.spark.sql.catalyst.optimizer.BuildSide import org.apache.spark.sql.catalyst.plans.physical.BroadcastDistribution import org.apache.spark.sql.catalyst.plans.physical.Distribution import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.catalyst.plans.physical.UnspecifiedDistribution import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.blaze.plan.BroadcastLeft +import org.apache.spark.sql.execution.blaze.plan.BroadcastRight +import org.apache.spark.sql.execution.blaze.plan.BroadcastSide import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastJoinBase import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.joins.HashedRelationInfo @@ -39,7 +43,7 @@ override val leftKeys: Seq[Expression], override val rightKeys: Seq[Expression], override val joinType: JoinType, - override val condition: Option[Expression]) + broadcastSide: BroadcastSide) extends NativeBroadcastJoinBase( left, right, @@ -47,9 +51,11 @@ leftKeys, rightKeys, joinType, - condition) + broadcastSide) with HashJoin { + override def condition: Option[Expression] = None + override def requiredChildDistribution: Seq[Distribution] = { val mode = HashedRelationBroadcastMode(buildBoundKeys, isNullAware = false) BroadcastDistribution(mode) :: UnspecifiedDistribution :: Nil @@ -65,7 +71,10 @@ throw new NotImplementedError("NativeBroadcastJoin dose not support codegen") } - override def buildSide: BuildSide = BuildLeft + override def buildSide: BuildSide = broadcastSide match { + case BroadcastLeft => BuildLeft + case BroadcastRight => BuildRight + } override protected def withNewChildrenInternal( newLeft: SparkPlan,
diff --git a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala b/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala deleted file mode 100644 index a129e91..0000000 --- a/spark-extension-shims-spark333/src/main/scala/org/apache/spark/sql/execution/joins/blaze/plan/NativeBroadcastNestedLoopJoinExec.scala +++ /dev/null
@@ -1,34 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.joins.blaze.plan - -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.blaze.plan.NativeBroadcastNestedLoopJoinBase - -case class NativeBroadcastNestedLoopJoinExec( - override val left: SparkPlan, - override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends NativeBroadcastNestedLoopJoinBase(left, right, joinType, condition) { - - override protected def withNewChildrenInternal( - newLeft: SparkPlan, - newRight: SparkPlan): SparkPlan = - copy(left = newLeft, right = newRight) -}
diff --git a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java index 31c3b9a..f7b1f97 100644 --- a/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java +++ b/spark-extension/src/main/java/org/apache/spark/sql/blaze/BlazeConf.java
@@ -27,15 +27,13 @@ /// actual off-heap memory usage is expected to be spark.executor.memoryOverhead * fraction. MEMORY_FRACTION("spark.blaze.memoryFraction", 0.6), - /// translates inequality smj to native. improves performance in most cases, however some - /// issues are found in special cases, like tpcds q72. - SMJ_INEQUALITY_JOIN_ENABLE("spark.blaze.enable.smjInequalityJoin", false), - /// fallbacks to SortMergeJoin when executing BroadcastHashJoin with big broadcasted table. - BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", true), + /// not available in blaze 3.0+ + BHJ_FALLBACKS_TO_SMJ_ENABLE("spark.blaze.enable.bhjFallbacksToSmj", false), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with rows more /// than this threshold. requires spark.blaze.enable.bhjFallbacksToSmj = true. + /// not available in blaze 3.0+ BHJ_FALLBACKS_TO_SMJ_ROWS_THRESHOLD("spark.blaze.bhjFallbacksToSmj.rows", 1000000), /// fallbacks to SortMergeJoin when BroadcastHashJoin has a broadcasted table with memory usage @@ -44,7 +42,7 @@ /// enable converting upper/lower functions to native, special cases may provide different /// outputs from spark due to different unicode versions. - CASE_CONVERT_FUNCTIONS_ENABLE("spark.blaze.enable.caseconvert.functions", false), + CASE_CONVERT_FUNCTIONS_ENABLE("spark.blaze.enable.caseconvert.functions", true), /// number of threads evaluating UDFs /// improves performance for special case that UDF concurrency matters @@ -64,6 +62,12 @@ /// mininum number of rows to trigger partial aggregate skipping PARTIAL_AGG_SKIPPING_MIN_ROWS("spark.blaze.partialAggSkipping.minRows", BATCH_SIZE.intConf() * 2), + + // parquet enable page filtering + PARQUET_ENABLE_PAGE_FILTERING("spark.blaze.parquet.enable.pageFiltering", false), + + // parquet enable bloom filter + PARQUET_ENABLE_BLOOM_FILTER("spark.blaze.parquet.enable.bloomFilter", false), ; private String key;
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala index 09bf85e..c24888d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeCallNativeWrapper.scala
@@ -28,7 +28,6 @@ import org.apache.arrow.c.ArrowSchema import org.apache.arrow.c.CDataDictionaryProvider import org.apache.arrow.c.Data -import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.types.pojo.Schema import org.apache.spark.Partition
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala index 9f7eb61..a7ab81d 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConvertStrategy.scala
@@ -46,7 +46,8 @@ val convertibleTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.convertible") val convertStrategyTag: TreeNodeTag[ConvertStrategy] = TreeNodeTag("blaze.convert.strategy") - val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag("blaze.child.ordering.required") + val childOrderingRequiredTag: TreeNodeTag[Boolean] = TreeNodeTag( + "blaze.child.ordering.required") def apply(exec: SparkPlan): Unit = { exec.foreach(_.setTagValue(convertibleTag, true))
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala index 99d0172..7e76b43 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/BlazeConverters.scala
@@ -15,11 +15,8 @@ */ package org.apache.spark.sql.blaze -import java.util.UUID - import scala.annotation.tailrec import scala.collection.mutable -import scala.collection.mutable.ArrayBuffer import org.apache.hadoop.hive.ql.io.parquet.MapredParquetOutputFormat import org.apache.spark.SparkEnv @@ -57,7 +54,6 @@ import org.apache.spark.sql.execution.aggregate.ObjectHashAggregateExec import org.apache.spark.sql.execution.blaze.plan._ import org.apache.spark.sql.execution.blaze.plan.NativeAggBase -import org.apache.spark.sql.execution.blaze.plan.NativeProjectBase import org.apache.spark.sql.execution.blaze.plan.NativeUnionBase import org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.command.DataWritingCommandExec @@ -128,7 +124,9 @@ var newExec = exec.withNewChildren(newChildren) exec.getTagValue(convertibleTag).foreach(newExec.setTagValue(convertibleTag, _)) exec.getTagValue(convertStrategyTag).foreach(newExec.setTagValue(convertStrategyTag, _)) - exec.getTagValue(childOrderingRequiredTag).foreach(newExec.setTagValue(childOrderingRequiredTag, _)) + exec + .getTagValue(childOrderingRequiredTag) + .foreach(newExec.setTagValue(childOrderingRequiredTag, _)) if (!isNeverConvert(newExec)) { newExec = convertSparkPlan(newExec) } @@ -333,45 +331,14 @@ val (leftKeys, rightKeys, joinType, condition, left, right) = (exec.leftKeys, exec.rightKeys, exec.joinType, exec.condition, exec.left, exec.right) logDebug(s"Converting SortMergeJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - var nativeLeft = convertToNative(left) - var nativeRight = convertToNative(right) - var modifiedLeftKeys = leftKeys - var modifiedRightKeys = rightKeys - var needPostProject = false - if (leftKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeLeft, leftKeys) - modifiedLeftKeys = keys - nativeLeft = exec - needPostProject = true - } - if (rightKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) 
= buildJoinColumnsProject(nativeRight, rightKeys) - modifiedRightKeys = keys - nativeRight = exec - needPostProject = true - } - - val smjOrig = SortMergeJoinExec( - modifiedLeftKeys, - modifiedRightKeys, + Shims.get.createNativeSortMergeJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + leftKeys, + rightKeys, joinType, - condition, - addRenameColumnsExec(nativeLeft), - addRenameColumnsExec(nativeRight)) - val smj = Shims.get.createNativeSortMergeJoinExec( - smjOrig.left, - smjOrig.right, - smjOrig.leftKeys, - smjOrig.rightKeys, - smjOrig.joinType, - smjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(smj, exec.output) - } else { - smj - } + condition) } def convertBroadcastHashJoinExec(exec: BroadcastHashJoinExec): SparkPlan = { @@ -385,84 +352,33 @@ exec.left, exec.right) logDebug(s"Converting BroadcastHashJoinExec: ${Shims.get.simpleStringWithNodeId(exec)}") - logDebug(s" leftKeys: ${exec.leftKeys}") - logDebug(s" rightKeys: ${exec.rightKeys}") - logDebug(s" joinType: ${exec.joinType}") - logDebug(s" buildSide: ${exec.buildSide}") - logDebug(s" condition: ${exec.condition}") - var (hashed, hashedKeys, nativeProbed, probedKeys) = buildSide match { + logDebug(s" leftKeys: $leftKeys") + logDebug(s" rightKeys: $rightKeys") + logDebug(s" joinType: $joinType") + logDebug(s" buildSide: $buildSide") + logDebug(s" condition: $condition") + assert(condition.isEmpty, "join condition is not supported") + + // verify build side is native + buildSide match { case BuildRight => assert(NativeHelper.isNative(right), "broadcast join build side is not native") - val convertedLeft = convertToNative(left) - (right, rightKeys, convertedLeft, leftKeys) - case BuildLeft => assert(NativeHelper.isNative(left), "broadcast join build side is not native") - val convertedRight = convertToNative(right) - (left, leftKeys, convertedRight, rightKeys) - - case _ => - // scalastyle:off throwerror - throw new 
NotImplementedError( - "Ignore BroadcastHashJoin with unsupported children structure") } - var modifiedHashedKeys = hashedKeys - var modifiedProbedKeys = probedKeys - var needPostProject = false + Shims.get.createNativeBroadcastJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + exec.outputPartitioning, + leftKeys, + rightKeys, + joinType, + buildSide match { + case BuildLeft => BroadcastLeft + case BuildRight => BroadcastRight + }) - if (hashedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(hashed, hashedKeys) - modifiedHashedKeys = keys - hashed = exec - needPostProject = true - } - if (probedKeys.exists(!_.isInstanceOf[AttributeReference])) { - val (keys, exec) = buildJoinColumnsProject(nativeProbed, probedKeys) - modifiedProbedKeys = keys - nativeProbed = exec - needPostProject = true - } - - val modifiedJoinType = buildSide match { - case BuildLeft => joinType - case BuildRight => - needPostProject = true - val modifiedJoinType = joinType match { // reverse join type - case Inner => Inner - case FullOuter => FullOuter - case LeftOuter => RightOuter - case RightOuter => LeftOuter - case _ => - throw new NotImplementedError( - "BHJ Semi/Anti join with BuildRight is not yet supported") - } - modifiedJoinType - } - - val bhjOrig = BroadcastHashJoinExec( - modifiedHashedKeys, - modifiedProbedKeys, - modifiedJoinType, - BuildLeft, - condition, - addRenameColumnsExec(hashed), - addRenameColumnsExec(nativeProbed)) - - val bhj = Shims.get.createNativeBroadcastJoinExec( - bhjOrig.left, - bhjOrig.right, - bhjOrig.outputPartitioning, - bhjOrig.leftKeys, - bhjOrig.rightKeys, - bhjOrig.joinType, - bhjOrig.condition) - - if (needPostProject) { - buildPostJoinProject(bhj, exec.output) - } else { - bhj - } } catch { case e @ (_: NotImplementedError | _: Exception) => val underlyingBroadcast = exec.buildSide match { @@ -483,60 +399,29 @@ logDebug(s" joinType: 
${exec.joinType}") logDebug(s" buildSide: ${exec.buildSide}") logDebug(s" condition: ${exec.condition}") - val (broadcasted, nativeProbed) = buildSide match { + assert(condition.isEmpty, "join condition is not supported") + + // verify build side is native + buildSide match { case BuildRight => assert(NativeHelper.isNative(right), "broadcast join build side is not native") - val convertedLeft = convertToNative(left) - (right, convertedLeft) - case BuildLeft => assert(NativeHelper.isNative(left), "broadcast join build side is not native") - val convertedRight = convertToNative(right) - (left, convertedRight) - - case _ => - // scalastyle:off throwerror - throw new NotImplementedError( - "Ignore BroadcastNestedLoopJoin with unsupported children structure") } - // the in-memory inner table is not the same in different join types - // reference: https://docs.rs/datafusion/latest/datafusion/physical_plan/joins/struct.NestedLoopJoinExec.html - var needPostProject = false - val (modifiedLeft, modifiedRight, modifiedJoinType) = (buildSide, joinType) match { - case (BuildLeft, RightOuter | FullOuter) => - (broadcasted, nativeProbed, joinType) // RightOuter, FullOuter => BuildLeft - case (BuildRight, Inner | LeftOuter | LeftSemi | LeftAnti) => - ( - nativeProbed, - broadcasted, - joinType - ) // Inner, LeftOuter, LeftSemi, LeftAnti => BuildRight - case _ => - needPostProject = true - val modifiedJoinType = joinType match { - case Inner => - (nativeProbed, broadcasted, Inner) // Inner + BuildLeft => BuildRight - case FullOuter => - (broadcasted, nativeProbed, FullOuter) // FullOuter + BuildRight => BuildLeft - case _ => - throw new NotImplementedError( - s"BNLJ $joinType with $buildSide is not yet supported") - } - modifiedJoinType - } + // reuse NativeBroadcastJoin with empty equility keys + Shims.get.createNativeBroadcastJoinExec( + addRenameColumnsExec(convertToNative(left)), + addRenameColumnsExec(convertToNative(right)), + exec.outputPartitioning, + Nil, + Nil, + 
joinType, + buildSide match { + case BuildLeft => BroadcastLeft + case BuildRight => BroadcastRight + }) - val bnlj = Shims.get.createNativeBroadcastNestedLoopJoinExec( - addRenameColumnsExec(modifiedLeft), - addRenameColumnsExec(modifiedRight), - modifiedJoinType, - condition) - - if (needPostProject) { - buildPostJoinProject(bnlj, exec.output) - } else { - bnlj - } } catch { case e @ (_: NotImplementedError | _: Exception) => val underlyingBroadcast = exec.buildSide match { @@ -851,44 +736,6 @@ exec } - private def buildJoinColumnsProject( - child: SparkPlan, - joinKeys: Seq[Expression]): (Seq[AttributeReference], NativeProjectBase) = { - val extraProjectList = ArrayBuffer[NamedExpression]() - val transformedKeys = ArrayBuffer[AttributeReference]() - - joinKeys.foreach { - case attr: AttributeReference => transformedKeys.append(attr) - case expr => - val aliasExpr = - Alias(expr, s"JOIN_KEY:${expr.toString()} (${UUID.randomUUID().toString})")() - extraProjectList.append(aliasExpr) - - val attr = AttributeReference( - aliasExpr.name, - aliasExpr.dataType, - aliasExpr.nullable, - aliasExpr.metadata)(aliasExpr.exprId, aliasExpr.qualifier) - transformedKeys.append(attr) - } - ( - transformedKeys, - Shims.get - .createNativeProjectExec(child.output ++ extraProjectList, addRenameColumnsExec(child))) - } - - private def buildPostJoinProject( - child: SparkPlan, - output: Seq[Attribute]): NativeProjectBase = { - val projectList = output - .filter(!_.name.startsWith("JOIN_KEY:")) - .map(attr => - AttributeReference(attr.name, attr.dataType, attr.nullable, attr.metadata)( - attr.exprId, - attr.qualifier)) - Shims.get.createNativeProjectExec(projectList, child) - } - private def getPartialAggProjection( aggregateExprs: Seq[AggregateExpression], groupingExprs: Seq[NamedExpression])
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala index 1cbfcc8..355444a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/NativeConverters.scala
@@ -52,12 +52,11 @@ import org.apache.spark.sql.catalyst.util.ArrayData import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.LeafExpression +import org.apache.spark.sql.catalyst.plans.ExistenceJoin import org.apache.spark.sql.execution.blaze.plan.Util import org.apache.spark.sql.execution.ScalarSubquery -import org.apache.spark.sql.execution.aggregate.ScalaUDAF import org.apache.spark.sql.hive.blaze.HiveUDFUtil import org.apache.spark.sql.hive.blaze.HiveUDFUtil.getFunctionClassName -import org.apache.spark.sql.hive.blaze.HiveUDFUtil.isHiveSimpleUDF import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.ArrayType import org.apache.spark.sql.types.AtomicType @@ -1110,6 +1109,7 @@ case FullOuter => pb.JoinType.FULL case LeftSemi => pb.JoinType.SEMI case LeftAnti => pb.JoinType.ANTI + case _: ExistenceJoin => pb.JoinType.EXISTENCE case _ => throw new NotImplementedError(s"unsupported join type: ${joinType}") } }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala index a8aaad2..12deb83 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/Shims.scala
@@ -79,13 +79,7 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]): NativeBroadcastJoinBase - - def createNativeBroadcastNestedLoopJoinExec( - left: SparkPlan, - right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]): NativeBroadcastNestedLoopJoinBase + broadcastSide: BroadcastSide): NativeBroadcastJoinBase def createNativeSortMergeJoinExec( left: SparkPlan,
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala index ff27e53..e949813 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/SparkUDFWrapperContext.scala
@@ -16,7 +16,6 @@ package org.apache.spark.sql.blaze import java.nio.ByteBuffer - import org.apache.arrow.c.ArrowArray import org.apache.arrow.c.Data import org.apache.arrow.vector.VectorSchemaRoot
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala index b78eb08..b103969 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/blaze/util/Using.scala
@@ -19,15 +19,14 @@ import scala.util.Try /** - * A utility for performing automatic resource management. It can be used to perform an - * operation using resources, after which it releases the resources in reverse order - * of their creation. + * A utility for performing automatic resource management. It can be used to perform an operation + * using resources, after which it releases the resources in reverse order of their creation. * * ==Usage== * - * There are multiple ways to automatically manage resources with `Using`. If you only need - * to manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the - * resource opening, operation, and resource releasing in a `Try`. + * There are multiple ways to automatically manage resources with `Using`. If you only need to + * manage a single resource, the [[Using.apply `apply`]] method is easiest; it wraps the resource + * opening, operation, and resource releasing in a `Try`. * * Example: * {{{ @@ -37,9 +36,9 @@ * } * }}} * - * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should - * be used. It allows the managing of arbitrarily many resources, whose creation, use, and - * release are all wrapped in a `Try`. + * If you need to manage multiple resources, [[Using.Manager$.apply `Using.Manager`]] should be + * used. It allows the managing of arbitrarily many resources, whose creation, use, and release + * are all wrapped in a `Try`. * * Example: * {{{ @@ -70,43 +69,44 @@ * * ==Suppression Behavior== * - * If two exceptions are thrown (e.g., by an operation and closing a resource), - * one of them is re-thrown, and the other is - * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. - * If the two exceptions are of different 'severities' (see below), the one of a higher - * severity is re-thrown, and the one of a lower severity is added to it as a suppressed - * exception. 
If the two exceptions are of the same severity, the one thrown first is - * re-thrown, and the one thrown second is added to it as a suppressed exception. - * If an exception is a [[scala.util.control.ControlThrowable `ControlThrowable`]], or - * if it does not support suppression (see - * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), - * an exception that would have been suppressed is instead discarded. + * If two exceptions are thrown (e.g., by an operation and closing a resource), one of them is + * re-thrown, and the other is + * [[java.lang.Throwable.addSuppressed(Throwable) added to it as a suppressed exception]]. If the + * two exceptions are of different 'severities' (see below), the one of a higher severity is + * re-thrown, and the one of a lower severity is added to it as a suppressed exception. If the two + * exceptions are of the same severity, the one thrown first is re-thrown, and the one thrown + * second is added to it as a suppressed exception. If an exception is a + * [[scala.util.control.ControlThrowable `ControlThrowable`]], or if it does not support + * suppression (see + * [[java.lang.Throwable `Throwable`'s constructor with an `enableSuppression` parameter]]), an + * exception that would have been suppressed is instead discarded. * * Exceptions are ranked from highest to lowest severity as follows: * - `java.lang.VirtualMachineError` * - `java.lang.LinkageError` * - `java.lang.InterruptedException` and `java.lang.ThreadDeath` - * - [[scala.util.control.NonFatal fatal exceptions]], excluding `scala.util.control.ControlThrowable` + * - [[scala.util.control.NonFatal fatal exceptions]], excluding + * `scala.util.control.ControlThrowable` * - `scala.util.control.ControlThrowable` * - all other exceptions * - * When more than two exceptions are thrown, the first two are combined and - * re-thrown as described above, and each successive exception thrown is combined - * as it is thrown. 
+ * When more than two exceptions are thrown, the first two are combined and re-thrown as described + * above, and each successive exception thrown is combined as it is thrown. * - * @define suppressionBehavior See the main doc for [[Using `Using`]] for full details of - * suppression behavior. + * @define suppressionBehavior + * See the main doc for [[Using `Using`]] for full details of suppression behavior. */ object Using { /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. * * $suppressionBehavior * - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[R: Releasable, A](resource: => R)(f: R => A): Try[A] = Try { Using.resource(resource)(f) @@ -115,20 +115,20 @@ /** * A resource manager. * - * Resources can be registered with the manager by calling [[acquire `acquire`]]; - * such resources will be released in reverse order of their acquisition - * when the manager is closed, regardless of any exceptions thrown - * during use. + * Resources can be registered with the manager by calling [[acquire `acquire`]]; such resources + * will be released in reverse order of their acquisition when the manager is closed, regardless + * of any exceptions thrown during use. * * $suppressionBehavior * - * @note It is recommended for API designers to require an implicit `Manager` - * for the creation of custom resources, and to call `acquire` during those - * resources' construction. Doing so guarantees that the resource ''must'' be - * automatically managed, and makes it impossible to forget to do so. 
+ * @note + * It is recommended for API designers to require an implicit `Manager` for the creation of + * custom resources, and to call `acquire` during those resources' construction. Doing so + * guarantees that the resource ''must'' be automatically managed, and makes it impossible to + * forget to do so. * - * Example: - * {{{ + * Example: + * {{{ * class SafeFileReader(file: File)(implicit manager: Using.Manager) * extends BufferedReader(new FileReader(file)) { * @@ -136,7 +136,7 @@ * * manager.acquire(this) * } - * }}} + * }}} */ final class Manager private { import Manager._ @@ -145,9 +145,8 @@ private[this] var resources: List[Resource[_]] = Nil /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed, and then - * returns the (unmodified) resource. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed, and then returns the (unmodified) resource. */ def apply[R: Releasable](resource: R): R = { acquire(resource) @@ -155,8 +154,8 @@ } /** - * Registers the specified resource with this manager, so that - * the resource is released when the manager is closed. + * Registers the specified resource with this manager, so that the resource is released when + * the manager is closed. */ def acquire[R: Releasable](resource: R): Unit = { if (resource == null) throw new NullPointerException("null resource") @@ -194,8 +193,8 @@ object Manager { /** - * Performs an operation using a `Manager`, then closes the `Manager`, - * releasing its resources (in reverse order of acquisition). + * Performs an operation using a `Manager`, then closes the `Manager`, releasing its resources + * (in reverse order of acquisition). 
* * Example: * {{{ @@ -204,9 +203,8 @@ * } * }}} * - * If using resources which require an implicit `Manager` as a parameter, - * this method should be invoked with an `implicit` modifier before the function - * parameter: + * If using resources which require an implicit `Manager` as a parameter, this method should + * be invoked with an `implicit` modifier before the function parameter: * * Example: * {{{ @@ -217,10 +215,13 @@ * * See the main doc for [[Using `Using`]] for full details of suppression behavior. * - * @param op the operation to perform using the manager - * @tparam A the return type of the operation - * @return a [[Try]] containing an exception if one or more were thrown, - * or the result of the operation if no exceptions were thrown + * @param op + * the operation to perform using the manager + * @tparam A + * the return type of the operation + * @return + * a [[Try]] containing an exception if one or more were thrown, or the result of the + * operation if no exceptions were thrown */ def apply[A](op: Manager => A): Try[A] = Try { (new Manager).manage(op) } @@ -247,18 +248,21 @@ } /** - * Performs an operation using a resource, and then releases the resource, - * even if the operation throws an exception. This method behaves similarly - * to Java's try-with-resources. + * Performs an operation using a resource, and then releases the resource, even if the operation + * throws an exception. This method behaves similarly to Java's try-with-resources. 
* * $suppressionBehavior * - * @param resource the resource - * @param body the operation to perform with the resource - * @tparam R the type of the resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resource throws + * @param resource + * the resource + * @param body + * the operation to perform with the resource + * @tparam R + * the type of the resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resource throws */ def resource[R, A](resource: R)(body: R => A)(implicit releasable: Releasable[R]): A = { if (resource == null) throw new NullPointerException("null resource") @@ -281,20 +285,26 @@ } /** - * Performs an operation using two resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using two resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, A](resource1: R1, resource2: => R2)( body: (R1, R2) => A): A = @@ -305,22 +315,30 @@ } /** - * Performs an operation using three resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using three resources, and then releases the resources in reverse + * order, even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, A]( resource1: R1, @@ -335,24 +353,34 @@ } /** - * Performs an operation using four resources, and then releases the resources - * in reverse order, even if the operation throws an exception. This method - * behaves similarly to Java's try-with-resources. + * Performs an operation using four resources, and then releases the resources in reverse order, + * even if the operation throws an exception. This method behaves similarly to Java's + * try-with-resources. 
* * $suppressionBehavior * - * @param resource1 the first resource - * @param resource2 the second resource - * @param resource3 the third resource - * @param resource4 the fourth resource - * @param body the operation to perform using the resources - * @tparam R1 the type of the first resource - * @tparam R2 the type of the second resource - * @tparam R3 the type of the third resource - * @tparam R4 the type of the fourth resource - * @tparam A the return type of the operation - * @return the result of the operation, if neither the operation nor - * releasing the resources throws + * @param resource1 + * the first resource + * @param resource2 + * the second resource + * @param resource3 + * the third resource + * @param resource4 + * the fourth resource + * @param body + * the operation to perform using the resources + * @tparam R1 + * the type of the first resource + * @tparam R2 + * the type of the second resource + * @tparam R3 + * the type of the third resource + * @tparam R4 + * the type of the fourth resource + * @tparam A + * the return type of the operation + * @return + * the result of the operation, if neither the operation nor releasing the resources throws */ def resources[R1: Releasable, R2: Releasable, R3: Releasable, R4: Releasable, A]( resource1: R1, @@ -372,17 +400,18 @@ /** * A typeclass describing how to release a particular type of resource. * - * A resource is anything which needs to be released, closed, or otherwise cleaned up - * in some way after it is finished being used, and for which waiting for the object's - * garbage collection to be cleaned up would be unacceptable. For example, an instance of - * [[java.io.OutputStream]] would be considered a resource, because it is important to close - * the stream after it is finished being used. 
+ * A resource is anything which needs to be released, closed, or otherwise cleaned up in some + * way after it is finished being used, and for which waiting for the object's garbage + * collection to be cleaned up would be unacceptable. For example, an instance of + * [[java.io.OutputStream]] would be considered a resource, because it is important to close the + * stream after it is finished being used. * - * An instance of `Releasable` is needed in order to automatically manage a resource - * with [[Using `Using`]]. An implicit instance is provided for all types extending + * An instance of `Releasable` is needed in order to automatically manage a resource with + * [[Using `Using`]]. An implicit instance is provided for all types extending * [[java.lang.AutoCloseable]]. * - * @tparam R the type of the resource + * @tparam R + * the type of the resource */ trait Releasable[-R] {
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala index 6e18f47..f6ddfc6 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/arrowio/util/ArrowUtils.scala
@@ -15,7 +15,8 @@ */ package org.apache.spark.sql.execution.blaze.arrowio.util -import scala.collection.JavaConverters._ +import scala.collection.JavaConverters.asScalaBufferConverter +import scala.collection.JavaConverters.seqAsJavaListConverter import org.apache.arrow.memory.BufferAllocator import org.apache.arrow.memory.RootAllocator @@ -31,7 +32,6 @@ import org.apache.spark.util.ShutdownHookManager object ArrowUtils { - val rootAllocator = new RootAllocator(Long.MaxValue) ShutdownHookManager.addShutdownHook(() => rootAllocator.close()) @@ -128,7 +128,7 @@ ArrayType(elementType, containsNull = elementField.isNullable) case ArrowType.Struct.INSTANCE => - val fields = field.getChildren().asScala.map { child => + val fields = field.getChildren.asScala.map { child => val dt = fromArrowField(child) StructField(child.getName, dt, child.isNullable) }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala index 0522623..852e533 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/ConvertToNativeBase.scala
@@ -34,7 +34,6 @@ import org.apache.spark.sql.execution.blaze.arrowio.ArrowFFIExportIterator import org.apache.spark.sql.execution.metric.SQLMetrics import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.blaze.protobuf.FFIReaderExecNode import org.blaze.protobuf.PhysicalPlanNode import org.blaze.protobuf.Schema
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala index 5525fb4..6fcbd47 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastExchangeBase.scala
@@ -24,22 +24,19 @@ import java.util.concurrent.TimeoutException import java.util.concurrent.TimeUnit -import scala.collection.mutable.ArrayBuffer import scala.collection.JavaConverters._ import scala.collection.immutable.SortedMap import scala.concurrent.Promise +import org.apache.commons.lang3.reflect.MethodUtils import org.apache.spark.OneToOneDependency import org.apache.spark.Partition import org.apache.spark.SparkException import org.apache.spark.TaskContext import org.apache.spark.broadcast -import org.blaze.{protobuf => pb} import org.apache.spark.broadcast.Broadcast import org.apache.spark.rdd.RDD import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.blaze.BlazeCallNativeWrapper -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.JniBridge import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters @@ -49,7 +46,10 @@ import org.apache.spark.sql.blaze.Shims import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.Attribute +import org.apache.spark.sql.catalyst.expressions.AttributeReference +import org.apache.spark.sql.catalyst.expressions.BoundReference import org.apache.spark.sql.catalyst.expressions.Expression +import org.apache.spark.sql.catalyst.expressions.InterpretedUnsafeProjection import org.apache.spark.sql.catalyst.plans.physical.BroadcastMode import org.apache.spark.sql.catalyst.plans.physical.BroadcastPartitioning import org.apache.spark.sql.catalyst.plans.physical.IdentityBroadcastMode @@ -63,6 +63,8 @@ import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.metric.SQLMetrics +import org.apache.spark.sql.types.BinaryType +import org.blaze.{protobuf => pb} abstract class NativeBroadcastExchangeBase(mode: BroadcastMode, override val child: SparkPlan) extends BroadcastExchangeLike @@ -71,10 +73,15 @@ override def output: 
Seq[Attribute] = child.output override def outputPartitioning: Partitioning = BroadcastPartitioning(mode) + def broadcastMode: BroadcastMode = this.mode + + protected val hashMapOutput: Seq[Attribute] = output + .map(_.withNullability(true)) :+ AttributeReference("~TABLE", BinaryType, nullable = true)() + protected val nativeSchema: pb.Schema = Util.getNativeSchema(output) + protected val nativeHashMapSchema: pb.Schema = Util.getNativeSchema(hashMapOutput) def getRunId: UUID - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) @@ -93,9 +100,6 @@ override def doPrepare(): Unit = { // Materialize the future. relationFuture - relationFuture - relationFuture - relationFuture } override def doExecuteBroadcast[T](): Broadcast[T] = { @@ -103,17 +107,31 @@ override def index: Int = 0 } val broadcastReadNativePlan = doExecuteNative().nativePlan(singlePartition, null) - val rows = NativeHelper.executeNativePlan( + val rowsIter = NativeHelper.executeNativePlan( broadcastReadNativePlan, MetricNode(Map(), Nil, None), singlePartition, None) - val v = mode.transform(rows.toArray) + val pruneKeyField = new InterpretedUnsafeProjection( + output.zipWithIndex + .map(v => BoundReference(v._2, v._1.dataType, v._1.nullable)) + .toArray) + val dataRows = rowsIter + .map(pruneKeyField) + .map(_.copy()) + .toArray + + val broadcast = relationFuture.get // bloadcast must be resolved + val v = mode.transform(dataRows) val dummyBroadcasted = new Broadcast[Any](-1) { override protected def getValue(): Any = v - override protected def doUnpersist(blocking: Boolean): Unit = {} - override protected def doDestroy(blocking: Boolean): Unit = {} + override protected def doUnpersist(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doUnpersist", Array(blocking)) + } + override protected def doDestroy(blocking: Boolean): Unit = { + MethodUtils.invokeMethod(broadcast, true, "doDestroy", 
Array(blocking)) + } } dummyBroadcasted.asInstanceOf[Broadcast[T]] } @@ -154,13 +172,14 @@ Channels.newChannel(new ByteArrayInputStream(bytes)) }) } + JniBridge.resourcesMap.put(resourceId, () => provideIpcIterator()) pb.PhysicalPlanNode .newBuilder() .setIpcReader( pb.IpcReaderExecNode .newBuilder() - .setSchema(nativeSchema) + .setSchema(nativeHashMapSchema) .setNumPartitions(1) .setIpcProviderResourceId(resourceId) .build()) @@ -267,39 +286,21 @@ keys: Seq[Expression], nativeSchema: pb.Schema): Array[Array[Byte]] = { - if (!BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || keys.isEmpty) { - return collectedData // no need to sort data in driver side - } - val readerIpcProviderResourceId = s"BuildBroadcastDataReader:${UUID.randomUUID()}" val readerExec = pb.IpcReaderExecNode .newBuilder() .setSchema(nativeSchema) .setIpcProviderResourceId(readerIpcProviderResourceId) - val sortExec = pb.SortExecNode + val buildHashMapExec = pb.BroadcastJoinBuildHashMapExecNode .newBuilder() .setInput(pb.PhysicalPlanNode.newBuilder().setIpcReader(readerExec)) - .addAllExpr( - keys - .map(key => { - pb.PhysicalExprNode - .newBuilder() - .setSort( - pb.PhysicalSortExprNode - .newBuilder() - .setExpr(NativeConverters.convertExpr(key)) - .setAsc(true) - .setNullsFirst(true) - .build()) - .build() - }) - .asJava) + .addAllKeys(keys.map(key => NativeConverters.convertExpr(key)).asJava) val writerIpcProviderResourceId = s"BuildBroadcastDataWriter:${UUID.randomUUID()}" val writerExec = pb.IpcWriterExecNode .newBuilder() - .setInput(pb.PhysicalPlanNode.newBuilder().setSort(sortExec)) + .setInput(pb.PhysicalPlanNode.newBuilder().setBroadcastJoinBuildHashMap(buildHashMapExec)) .setIpcConsumerResourceId(writerIpcProviderResourceId) // build native sorter
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala index ec13b8f..dc27d2a 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastJoinBase.scala
@@ -20,21 +20,24 @@ import org.apache.spark.OneToOneDependency import org.apache.spark.Partition -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports +import org.apache.spark.sql.blaze.Shims +import org.apache.spark.sql.catalyst.expressions.Cast import org.apache.spark.sql.catalyst.expressions.Expression import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftSemi import org.apache.spark.sql.catalyst.plans.physical.Partitioning import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.execution.BinaryExecNode +import org.apache.spark.sql.execution.exchange.BroadcastExchangeExec +import org.apache.spark.sql.execution.joins.HashedRelationBroadcastMode +import org.apache.spark.sql.types.LongType import org.blaze.{protobuf => pb} +import org.blaze.protobuf.JoinOn abstract class NativeBroadcastJoinBase( override val left: SparkPlan, @@ -43,82 +46,114 @@ leftKeys: Seq[Expression], rightKeys: Seq[Expression], joinType: JoinType, - condition: Option[Expression]) + broadcastSide: BroadcastSide) extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - !BlazeConf.BHJ_FALLBACKS_TO_SMJ_ENABLE.booleanConf() || BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE - .booleanConf() || condition.isEmpty, - "Join filter is not supported when BhjFallbacksToSmj and SmjInequalityJoin both enabled") - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper .getDefaultNativeMetrics(sparkContext) .toSeq: _*) + private val 
isLongHashRelation = { + val baseBroadcast = broadcastSide match { + case BroadcastLeft => Shims.get.getUnderlyingBroadcast(left) + case BroadcastRight => Shims.get.getUnderlyingBroadcast(right) + } + val mode = baseBroadcast match { + case b: BroadcastExchangeExec => b.mode + case b: NativeBroadcastExchangeBase => b.broadcastMode + } + mode match { + case mode: HashedRelationBroadcastMode + if mode.key.length == 1 && mode.key.head.dataType == LongType => + true + case _ => false + } + } + + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ leftKey is not column: ${leftKey}") - case column => column + val leftKeyExpr = leftKey match { + case k if !isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"BHJ rightKey is not column: ${rightKey}") - case column => column + val rightKeyExpr = rightKey match { + case k if !isLongHashRelation || k.dataType == LongType => k + case k => Cast(k, LongType) } - pb.JoinOn + JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(NativeConverters.convertExpr(leftKeyExpr)) + .setRight(NativeConverters.convertExpr(rightKeyExpr)) .build() } private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) + private def nativeBroadcastSide = broadcastSide match { + case BroadcastLeft => pb.JoinSide.LEFT_SIDE + case BroadcastRight => pb.JoinSide.RIGHT_SIDE + } // check whether native converting is supported + nativeSchema nativeJoinType - nativeJoinFilter + nativeJoinOn 
+ nativeBroadcastSide override def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) val rightRDD = NativeHelper.executeNative(right) val nativeMetrics = MetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil) + val nativeSchema = this.nativeSchema val nativeJoinType = this.nativeJoinType val nativeJoinOn = this.nativeJoinOn - val nativeJoinFilter = this.nativeJoinFilter - val partitions = rightRDD.partitions + + val (probedRDD, builtRDD) = broadcastSide match { + case BroadcastLeft => (rightRDD, leftRDD) + case BroadcastRight => (leftRDD, rightRDD) + } new NativeRDD( sparkContext, nativeMetrics, - partitions, - rddDependencies = new OneToOneDependency(rightRDD) :: Nil, - rightRDD.isShuffleReadFull, + probedRDD.partitions, + rddDependencies = new OneToOneDependency(probedRDD) :: Nil, + probedRDD.isShuffleReadFull, (partition, context) => { val partition0 = new Partition() { override def index: Int = 0 } - val leftChild = leftRDD.nativePlan(partition0, context) - val rightChild = rightRDD.nativePlan(rightRDD.partitions(partition.index), context) + val (leftChild, rightChild) = broadcastSide match { + case BroadcastLeft => + ( + leftRDD.nativePlan(partition0, context), + rightRDD.nativePlan(rightRDD.partitions(partition.index), context)) + case BroadcastRight => + ( + leftRDD.nativePlan(leftRDD.partitions(partition.index), context), + rightRDD.nativePlan(partition0, context)) + } + val cachedBuildHashMapId = s"bhm_stage${context.stageId}_rdd${builtRDD.id}" + val broadcastJoinExec = pb.BroadcastJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) + .setBroadcastSide(nativeBroadcastSide) + .setCachedBuildHashMapId(cachedBuildHashMapId) .addAllOn(nativeJoinOn.asJava) - nativeJoinFilter.foreach(joinFilter => broadcastJoinExec.setJoinFilter(joinFilter)) pb.PhysicalPlanNode.newBuilder().setBroadcastJoin(broadcastJoinExec).build() }, friendlyName = 
"NativeRDD.BroadcastJoin") } } + +class BroadcastSide {} +case object BroadcastLeft extends BroadcastSide {} +case object BroadcastRight extends BroadcastSide {}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala deleted file mode 100644 index bfcf747..0000000 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeBroadcastNestedLoopJoinBase.scala +++ /dev/null
@@ -1,144 +0,0 @@ -/* - * Copyright 2022 The Blaze Authors - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package org.apache.spark.sql.execution.blaze.plan - -import scala.collection.immutable.SortedMap - -import org.apache.spark.OneToOneDependency -import org.apache.spark.Partition -import org.apache.spark.sql.blaze.MetricNode -import org.apache.spark.sql.blaze.NativeConverters -import org.apache.spark.sql.blaze.NativeHelper -import org.apache.spark.sql.blaze.NativeRDD -import org.apache.spark.sql.blaze.NativeSupports -import org.apache.spark.sql.catalyst.expressions.Attribute -import org.apache.spark.sql.catalyst.expressions.Expression -import org.apache.spark.sql.catalyst.plans.ExistenceJoin -import org.apache.spark.sql.catalyst.plans.FullOuter -import org.apache.spark.sql.catalyst.plans.Inner -import org.apache.spark.sql.catalyst.plans.InnerLike -import org.apache.spark.sql.catalyst.plans.JoinType -import org.apache.spark.sql.catalyst.plans.LeftAnti -import org.apache.spark.sql.catalyst.plans.LeftExistence -import org.apache.spark.sql.catalyst.plans.LeftOuter -import org.apache.spark.sql.catalyst.plans.LeftSemi -import org.apache.spark.sql.catalyst.plans.RightOuter -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.execution.BinaryExecNode -import org.blaze.{protobuf => pb} - -abstract class NativeBroadcastNestedLoopJoinBase( - override val left: SparkPlan, - 
override val right: SparkPlan, - joinType: JoinType, - condition: Option[Expression]) - extends BinaryExecNode - with NativeSupports { - - override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( - NativeHelper - .getDefaultNativeMetrics(sparkContext) - .filterKeys( - Set( - "stage_id", - "output_rows", - "elapsed_compute", - "input_batch_count", - "input_batch_mem_size", - "input_row_count")) - .toSeq: _*) - - override def output: Seq[Attribute] = { - joinType match { - case _: InnerLike => - left.output ++ right.output - case LeftOuter => - left.output ++ right.output.map(_.withNullability(true)) - case RightOuter => - left.output.map(_.withNullability(true)) ++ right.output - case FullOuter => - left.output.map(_.withNullability(true)) ++ right.output.map(_.withNullability(true)) - case j: ExistenceJoin => - left.output :+ j.exists - case LeftExistence(_) => - left.output - case x => - throw new IllegalArgumentException( - s"BroadcastNestedLoopJoin should not take $x as the JoinType") - } - } - - private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) - - // check whether native converting is supported - nativeJoinType - nativeJoinFilter - - private val probedSide = joinType match { - case Inner | LeftOuter | LeftSemi | LeftAnti => "left" - case RightOuter | FullOuter => "right" - case other => s"NativeBroadcastNestedLoopJoin does not support join type $other" - } - - override def doExecuteNative(): NativeRDD = { - val leftRDD = NativeHelper.executeNative(left) - val rightRDD = NativeHelper.executeNative(right) - val nativeMetrics = MetricNode(metrics, leftRDD.metrics :: rightRDD.metrics :: Nil) - val nativeJoinType = this.nativeJoinType - val nativeJoinFilter = this.nativeJoinFilter - val partitions = probedSide match { - case "left" => leftRDD.partitions - case "right" => rightRDD.partitions - } - 
- new NativeRDD( - sparkContext, - nativeMetrics, - partitions, - rddDependencies = probedSide match { - case "left" => new OneToOneDependency(leftRDD) :: Nil - case "right" => new OneToOneDependency(rightRDD) :: Nil - }, - rightRDD.isShuffleReadFull, - (partition, context) => { - val partition0 = new Partition() { - override def index: Int = 0 - } - val (leftChild, rightChild) = probedSide match { - case "left" => - ( - leftRDD.nativePlan(leftRDD.partitions(partition.index), context), - rightRDD.nativePlan(partition0, context)) - case "right" => - ( - leftRDD.nativePlan(partition0, context), - rightRDD.nativePlan(rightRDD.partitions(partition.index), context)) - } - val bnlj = pb.BroadcastNestedLoopJoinExecNode - .newBuilder() - .setLeft(leftChild) - .setRight(rightChild) - .setJoinType(nativeJoinType) - - nativeJoinFilter.foreach(joinFilter => bnlj.setJoinFilter(joinFilter)) - pb.PhysicalPlanNode.newBuilder().setBroadcastNestedLoopJoin(bnlj).build() - }, - friendlyName = "NativeRDD.BroadcastNestedLoopJoin") - } -}
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala index 2349cc9..dc0e371 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeGenerateBase.scala
@@ -22,7 +22,6 @@ import org.apache.spark.OneToOneDependency import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters -import org.apache.spark.sql.blaze.NativeConverters.convertExprWithFallback import org.apache.spark.sql.blaze.NativeHelper import org.apache.spark.sql.blaze.NativeRDD import org.apache.spark.sql.blaze.NativeSupports
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala index 1bbfdc3..276fd53 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetScanBase.scala
@@ -167,7 +167,7 @@ partitions.asInstanceOf[Array[Partition]], Nil, rddShuffleReadFull = true, - (partition, context) => { + (partition, _context) => { val resourceId = s"NativeParquetScanExec:${UUID.randomUUID().toString}" val sharedConf = broadcastedHadoopConf.value.value JniBridge.resourcesMap.put(
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala index 8e81d43..cd2e5a3 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeParquetSinkBase.scala
@@ -151,6 +151,6 @@ "ParquetSink") } - protected def newHadoopConf(tableDesc: TableDesc): Configuration = + protected def newHadoopConf(_tableDesc: TableDesc): Configuration = sparkSession.sessionState.newHadoopConf() }
diff --git a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala index 52efbcd..831211b 100644 --- a/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala +++ b/spark-extension/src/main/scala/org/apache/spark/sql/execution/blaze/plan/NativeSortMergeJoinBase.scala
@@ -22,7 +22,6 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.OneToOneDependency -import org.apache.spark.sql.blaze.BlazeConf import org.apache.spark.sql.blaze.MetricNode import org.apache.spark.sql.blaze.NativeConverters import org.apache.spark.sql.blaze.NativeHelper @@ -52,13 +51,7 @@ extends BinaryExecNode with NativeSupports { - assert( - (joinType != LeftSemi && joinType != LeftAnti) || condition.isEmpty, - "Semi/Anti join with filter is not supported yet") - - assert( - BlazeConf.SMJ_INEQUALITY_JOIN_ENABLE.booleanConf() || condition.isEmpty, - "inequality sort-merge join is not enabled") + assert(condition.isEmpty, "inequality join is not supported") override lazy val metrics: Map[String, SQLMetric] = SortedMap[String, SQLMetric]() ++ Map( NativeHelper @@ -81,21 +74,15 @@ keys.map(SortOrder(_, Ascending)) } + private def nativeSchema = Util.getNativeSchema(output) + private def nativeJoinOn = leftKeys.zip(rightKeys).map { case (leftKey, rightKey) => - val leftColumn = NativeConverters.convertExpr(leftKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ leftKey is not column: ${leftKey}") - case column => column - } - val rightColumn = NativeConverters.convertExpr(rightKey).getColumn match { - case column if column.getName.isEmpty => - throw new NotImplementedError(s"SMJ rightKey is not column: ${rightKey}") - case column => column - } + val leftKeyExpr = NativeConverters.convertExpr(leftKey) + val rightKeyExpr = NativeConverters.convertExpr(rightKey) JoinOn .newBuilder() - .setLeft(leftColumn) - .setRight(rightColumn) + .setLeft(leftKeyExpr) + .setRight(rightKeyExpr) .build() } @@ -109,14 +96,11 @@ private def nativeJoinType = NativeConverters.convertJoinType(joinType) - private def nativeJoinFilter = - condition.map(NativeConverters.convertJoinFilter(_, left.output, right.output)) - // check whether native converting is 
supported + nativeSchema nativeSortOptions nativeJoinOn nativeJoinType - nativeJoinFilter override def doExecuteNative(): NativeRDD = { val leftRDD = NativeHelper.executeNative(left) @@ -125,7 +109,6 @@ val nativeSortOptions = this.nativeSortOptions val nativeJoinOn = this.nativeJoinOn val nativeJoinType = this.nativeJoinType - val nativeJoinFilter = this.nativeJoinFilter val partitions = if (joinType != RightOuter) { leftRDD.partitions @@ -161,13 +144,12 @@ val sortMergeJoinExec = SortMergeJoinExecNode .newBuilder() + .setSchema(nativeSchema) .setLeft(leftChild) .setRight(rightChild) .setJoinType(nativeJoinType) .addAllOn(nativeJoinOn.asJava) .addAllSortOptions(nativeSortOptions.asJava) - - nativeJoinFilter.foreach(joinFilter => sortMergeJoinExec.setJoinFilter(joinFilter)) PhysicalPlanNode.newBuilder().setSortMergeJoin(sortMergeJoinExec).build() }, friendlyName = "NativeRDD.SortMergeJoin")