IMPALA-9126: part 3: move more logic to PhjBuilder

The general flavour of this patch is to move code
that orchestrates the top-level spilling hash join
algorithm to PhjBuilder, and better encapsulate
state in PhjBuilder by reducing the number of
public methods.

Specific changes include:
* Move HashJoinState to PhjBuilder, which is necessary
  for the shared join build since the builder will
  be orchestrating the spilling.
* Reduce public methods of PhjBuilder. The goal is
  for the builder to hand off pointers to hash tables,
  partitions, etc only during transitions of the
  state machine (i.e. synchronization points when
  we have a shared build).
* Highlight methods of PhjBuilder that will be
  synchronization points for the shared join build
  where hash tables are built and destroyed.
* Highlight other future changes to PhjBuilder.
* Move some of the output build partition logic
  from NextSpilledProbeRowBatch() to the builder,
  which required a few other changes - e.g. explicit
  *eos output arguments for some functions.

This does *not* include a change to have the builder pick
the spilled partition to process. That will be a follow-on,
because it requires more refactoring of the relationship
between PhjBuilder::Partition and ProbePartition.

Testing:
* Earlier version of patch passed exhaustive tests.

Change-Id: I0e233468de1eeae86651ab96df207de19e091053
Reviewed-on: http://gerrit.cloudera.org:8080/14787
Reviewed-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
Tested-by: Impala Public Jenkins <impala-public-jenkins@cloudera.com>
diff --git a/be/src/exec/partitioned-hash-join-builder.cc b/be/src/exec/partitioned-hash-join-builder.cc
index 8496703..8c1023a 100644
--- a/be/src/exec/partitioned-hash-join-builder.cc
+++ b/be/src/exec/partitioned-hash-join-builder.cc
@@ -63,23 +63,7 @@
     join_op_(join_op),
     buffer_pool_client_(buffer_pool_client),
     spillable_buffer_size_(spillable_buffer_size),
-    max_row_buffer_size_(max_row_buffer_size),
-    non_empty_build_(false),
-    partitions_created_(NULL),
-    largest_partition_percent_(NULL),
-    max_partition_level_(NULL),
-    num_build_rows_partitioned_(NULL),
-    num_spilled_partitions_(NULL),
-    num_repartitions_(NULL),
-    partition_build_rows_timer_(NULL),
-    build_hash_table_timer_(NULL),
-    repartition_timer_(NULL),
-    null_aware_partition_(NULL),
-    probe_stream_reservation_(),
-    process_build_batch_fn_(NULL),
-    process_build_batch_fn_level0_(NULL),
-    insert_batch_fn_(NULL),
-    insert_batch_fn_level0_(NULL) {}
+    max_row_buffer_size_(max_row_buffer_size) {}
 
 Status PhjBuilder::InitExprsAndFilters(RuntimeState* state,
     const vector<TEqJoinCondition>& eq_join_conjuncts,
@@ -144,6 +128,8 @@
   num_repartitions_ = ADD_COUNTER(profile(), "NumRepartitions", TUnit::UNIT);
   partition_build_rows_timer_ = ADD_TIMER(profile(), "BuildRowsPartitionTime");
   build_hash_table_timer_ = ADD_TIMER(profile(), "HashTablesBuildTime");
+  num_hash_table_builds_skipped_ =
+      ADD_COUNTER(profile(), "NumHashTableBuildsSkipped", TUnit::UNIT);
   repartition_timer_ = ADD_TIMER(profile(), "RepartitionTime");
   state->CheckAndAddCodegenDisabledMessage(profile());
   return Status::OK();
@@ -177,12 +163,12 @@
 Status PhjBuilder::Send(RuntimeState* state, RowBatch* batch) {
   SCOPED_TIMER(partition_build_rows_timer_);
   bool build_filters = ht_ctx_->level() == 0 && filter_ctxs_.size() > 0;
-  if (process_build_batch_fn_ == NULL) {
-      RETURN_IF_ERROR(ProcessBuildBatch(batch, ht_ctx_.get(), build_filters,
-          join_op_ == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN));
+  if (process_build_batch_fn_ == nullptr) {
+    RETURN_IF_ERROR(ProcessBuildBatch(batch, ht_ctx_.get(), build_filters,
+        join_op_ == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN));
 
   } else {
-    DCHECK(process_build_batch_fn_level0_ != NULL);
+    DCHECK(process_build_batch_fn_level0_ != nullptr);
     if (ht_ctx_->level() == 0) {
       RETURN_IF_ERROR(
           process_build_batch_fn_level0_(this, batch, ht_ctx_.get(), build_filters,
@@ -244,6 +230,12 @@
   }
 
   RETURN_IF_ERROR(BuildHashTablesAndReserveProbeBuffers());
+  if (state_ == HashJoinState::PARTITIONING_BUILD) {
+    UpdateState(HashJoinState::PARTITIONING_PROBE);
+  } else {
+    DCHECK_ENUM_EQ(state_, HashJoinState::REPARTITIONING_BUILD);
+    UpdateState(HashJoinState::REPARTITIONING_PROBE);
+  }
   return Status::OK();
 }
 
@@ -265,11 +257,52 @@
 
 void PhjBuilder::Reset(RowBatch* row_batch) {
   DCHECK_EQ(0, probe_stream_reservation_.GetReservation());
+  state_ = HashJoinState::PARTITIONING_BUILD;
   expr_results_pool_->Clear();
   non_empty_build_ = false;
   CloseAndDeletePartitions(row_batch);
 }
 
+void PhjBuilder::UpdateState(HashJoinState next_state) {
+  // Validate the state transition.
+  switch (state_) {
+    case HashJoinState::PARTITIONING_BUILD:
+      DCHECK_ENUM_EQ(next_state, HashJoinState::PARTITIONING_PROBE);
+      break;
+    case HashJoinState::PARTITIONING_PROBE:
+    case HashJoinState::REPARTITIONING_PROBE:
+    case HashJoinState::PROBING_SPILLED_PARTITION:
+      DCHECK(next_state == HashJoinState::REPARTITIONING_BUILD
+          || next_state == HashJoinState::PROBING_SPILLED_PARTITION);
+      break;
+    case HashJoinState::REPARTITIONING_BUILD:
+      DCHECK_ENUM_EQ(next_state, HashJoinState::REPARTITIONING_PROBE);
+      break;
+    default:
+      DCHECK(false) << "Invalid state " << static_cast<int>(state_);
+  }
+  state_ = next_state;
+  VLOG(2) << "Transitioned State:" << endl << DebugString();
+}
+
+string PhjBuilder::PrintState() const {
+  switch (state_) {
+    case HashJoinState::PARTITIONING_BUILD:
+      return "PartitioningBuild";
+    case HashJoinState::PARTITIONING_PROBE:
+      return "PartitioningProbe";
+    case HashJoinState::PROBING_SPILLED_PARTITION:
+      return "ProbingSpilledPartition";
+    case HashJoinState::REPARTITIONING_BUILD:
+      return "RepartitioningBuild";
+    case HashJoinState::REPARTITIONING_PROBE:
+      return "RepartitioningProbe";
+    default:
+      DCHECK(false);
+  }
+  return "";
+}
+
 Status PhjBuilder::CreateAndPreparePartition(int level, Partition** partition) {
   all_partitions_.emplace_back(new Partition(runtime_state_, this, level));
   *partition = all_partitions_.back().get();
@@ -322,7 +355,7 @@
     for (Partition* candidate : hash_partitions_) {
       if (!candidate->CanSpill()) continue;
       int64_t mem = candidate->build_rows()->BytesPinned(false);
-      if (candidate->hash_tbl() != NULL) {
+      if (candidate->hash_tbl() != nullptr) {
         // The hash table should not have matches, since we have not probed it yet.
         // Losing match info would lead to incorrect results (IMPALA-1488).
         DCHECK(!candidate->hash_tbl()->HasMatches());
@@ -370,7 +403,7 @@
     Partition* partition = hash_partitions_[i];
     if (partition->build_rows()->num_rows() == 0) {
       // This partition is empty, no need to do anything else.
-      partition->Close(NULL);
+      partition->Close(nullptr);
     } else if (partition->is_spilled()) {
       // We don't need any build-side data for spilled partitions in memory.
       RETURN_IF_ERROR(
@@ -412,7 +445,8 @@
 
   // We need a write buffer for probe rows for each spilled partition, and a read buffer
   // if the input is a spilled partition (i.e. that we are repartitioning the input).
-  int num_probe_streams = GetNumSpilledHashPartitions() + (input_is_spilled ? 1 : 0);
+  int num_probe_streams =
+      GetNumSpilledPartitions(hash_partitions_) + (input_is_spilled ? 1 : 0);
   int64_t per_stream_reservation = spillable_buffer_size_;
   int64_t addtl_reservation = num_probe_streams * per_stream_reservation
       - probe_stream_reservation_.GetReservation();
@@ -432,27 +466,38 @@
   return Status::OK();
 }
 
-void PhjBuilder::TransferProbeStreamReservation(BufferPool::ClientHandle* dst) {
-  int num_streams = GetNumSpilledHashPartitions();
-  int64_t saved_reservation = probe_stream_reservation_.GetReservation();
-  DCHECK_GE(saved_reservation, spillable_buffer_size_ * num_streams);
-
-  // TODO: in future we may need to support different clients for the probe.
-  DCHECK_EQ(dst, buffer_pool_client_);
-  dst->RestoreReservation(&probe_stream_reservation_, saved_reservation);
+PhjBuilder::HashPartitions PhjBuilder::BeginInitialProbe(
+    BufferPool::ClientHandle* probe_client) {
+  DCHECK_ENUM_EQ(state_, HashJoinState::PARTITIONING_PROBE);
+  DCHECK_EQ(PARTITION_FANOUT, hash_partitions_.size());
+  TransferProbeStreamReservation(probe_client);
+  return HashPartitions(ht_ctx_->level(), &hash_partitions_, non_empty_build_);
 }
 
-int PhjBuilder::GetNumSpilledHashPartitions() const {
+void PhjBuilder::TransferProbeStreamReservation(BufferPool::ClientHandle* probe_client) {
+  // An extra buffer is needed for reading spilled input stream, unless we're doing the
+  // initial partitioning of rows from the left child.
+  int num_buffers = GetNumSpilledPartitions(hash_partitions_)
+      + (state_ == HashJoinState::PARTITIONING_PROBE ? 0 : 1);
+  int64_t saved_reservation = probe_stream_reservation_.GetReservation();
+  DCHECK_GE(saved_reservation, spillable_buffer_size_ * num_buffers);
+
+  // TODO: in future we may need to support different clients for the probe.
+  DCHECK_EQ(probe_client, buffer_pool_client_);
+  probe_client->RestoreReservation(&probe_stream_reservation_, saved_reservation);
+}
+
+int PhjBuilder::GetNumSpilledPartitions(const vector<Partition*>& partitions) {
   int num_spilled = 0;
-  for (int i = 0; i < hash_partitions_.size(); ++i) {
-    Partition* partition = hash_partitions_[i];
+  for (int i = 0; i < partitions.size(); ++i) {
+    Partition* partition = partitions[i];
     DCHECK(partition != nullptr);
     if (!partition->IsClosed() && partition->is_spilled()) ++num_spilled;
   }
   return num_spilled;
 }
 
-void PhjBuilder::DoneProbing(const bool retain_partition[PARTITION_FANOUT],
+void PhjBuilder::DoneProbingHashPartitions(const bool retain_partition[PARTITION_FANOUT],
     list<Partition*>* output_partitions, RowBatch* batch) {
   DCHECK(output_partitions->empty());
   for (int i = 0; i < PARTITION_FANOUT; ++i) {
@@ -463,7 +508,10 @@
       DCHECK_EQ(partition->build_rows()->BytesPinned(false), 0)
           << "Build was fully unpinned in BuildHashTablesAndPrepareProbeStreams()";
       // Release resources associated with completed partitions.
-      if (!retain_partition[i]) partition->Close(nullptr);
+      if (!retain_partition[i]) {
+        COUNTER_ADD(num_hash_table_builds_skipped_, 1);
+        partition->Close(nullptr);
+      }
     } else if (NeedToProcessUnmatchedBuildRows(join_op_)) {
       output_partitions->push_back(partition);
     } else {
@@ -474,18 +522,31 @@
   hash_partitions_.clear();
 }
 
+void PhjBuilder::DoneProbingSinglePartition(
+    Partition* partition, std::list<Partition*>* output_partitions, RowBatch* batch) {
+  if (NeedToProcessUnmatchedBuildRows(join_op_)) {
+    // If the build partition was in memory, we are done probing this partition.
+    // In case of right-outer, right-anti and full-outer joins, we move this partition
+    // to the list of partitions whose unmatched build rows we need to output.
+    output_partitions->push_back(partition);
+  } else {
+    // In any other case, just close the input build partition.
+    partition->Close(IsLeftSemiJoin(join_op_) ? nullptr : batch);
+  }
+}
+
 void PhjBuilder::CloseAndDeletePartitions(RowBatch* row_batch) {
   // Close all the partitions and clean up all references to them.
   for (unique_ptr<Partition>& partition : all_partitions_) partition->Close(row_batch);
   all_partitions_.clear();
   hash_partitions_.clear();
-  null_aware_partition_ = NULL;
+  null_aware_partition_ = nullptr;
 }
 
 void PhjBuilder::AllocateRuntimeFilters() {
   DCHECK(join_op_ != TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN || filter_ctxs_.size() == 0)
       << "Runtime filters not supported with NULL_AWARE_LEFT_ANTI_JOIN";
-  DCHECK(ht_ctx_ != NULL);
+  DCHECK(ht_ctx_ != nullptr);
   for (int i = 0; i < filter_ctxs_.size(); ++i) {
     if (filter_ctxs_[i].filter->is_bloom_filter()) {
       filter_ctxs_[i].local_bloom_filter =
@@ -548,6 +609,81 @@
   }
 }
 
+Status PhjBuilder::BeginSpilledProbe(bool empty_probe, Partition* partition,
+    BufferPool::ClientHandle* probe_client, bool* repartitioned, int* level,
+    HashPartitions* new_partitions) {
+  DCHECK(partition->is_spilled());
+  DCHECK_EQ(0, hash_partitions_.size());
+
+  if (empty_probe) {
+    // If there are no probe rows, there's no need to build the hash table, and
+    // only partitions with NeedToProcessUnmatchedBuildRows() will have been added
+    // to 'spilled_partitions_' in DoneProbingHashPartitions().
+    DCHECK(NeedToProcessUnmatchedBuildRows(join_op_));
+    bool got_read_buffer = false;
+    RETURN_IF_ERROR(partition->build_rows()->PrepareForRead(true, &got_read_buffer));
+    if (!got_read_buffer) {
+      return mem_tracker()->MemLimitExceeded(
+          runtime_state_, Substitute(PREPARE_FOR_READ_FAILED_ERROR_MSG, join_node_id_));
+    }
+    COUNTER_ADD(num_hash_table_builds_skipped_, 1);
+    UpdateState(HashJoinState::PROBING_SPILLED_PARTITION);
+    *repartitioned = false;
+    *level = partition->level();
+    return Status::OK();
+  }
+
+  // Set aside memory required for reading the probe stream, so that we don't use
+  // it for the hash table.
+  buffer_pool_client_->SaveReservation(
+      &probe_stream_reservation_, spillable_buffer_size_);
+
+  // Try to build a hash table for the spilled build partition.
+  bool built;
+  RETURN_IF_ERROR(partition->BuildHashTable(&built));
+  if (built) {
+    TransferProbeStreamReservation(probe_client);
+    UpdateState(HashJoinState::PROBING_SPILLED_PARTITION);
+    *repartitioned = false;
+    *level = partition->level();
+    return Status::OK();
+  }
+  // This build partition still does not fit in memory, repartition.
+  UpdateState(HashJoinState::REPARTITIONING_BUILD);
+
+  int next_partition_level = partition->level() + 1;
+  if (UNLIKELY(next_partition_level >= MAX_PARTITION_DEPTH)) {
+    return Status(TErrorCode::PARTITIONED_HASH_JOIN_MAX_PARTITION_DEPTH, join_node_id_,
+        MAX_PARTITION_DEPTH);
+  }
+
+  // Spill to free memory from hash tables and pinned streams for use in new partitions.
+  RETURN_IF_ERROR(partition->Spill(BufferedTupleStream::UNPIN_ALL));
+  // Temporarily free up the probe reservation to use when repartitioning. Repartitioning
+  // will reserve as much memory as needed for the probe streams.
+  buffer_pool_client_->RestoreReservation(
+      &probe_stream_reservation_, spillable_buffer_size_);
+
+  DCHECK_EQ(partition->build_rows()->BytesPinned(false), 0) << DebugString();
+  int64_t num_input_rows = partition->build_rows()->num_rows();
+  RETURN_IF_ERROR(RepartitionBuildInput(partition));
+
+  // Check if there was any reduction in the size of partitions after repartitioning.
+  int64_t largest_partition_rows = LargestPartitionRows();
+  DCHECK_GE(num_input_rows, largest_partition_rows) << "Cannot have a partition with "
+                                                       "more rows than the input";
+  if (UNLIKELY(num_input_rows == largest_partition_rows)) {
+    return Status(TErrorCode::PARTITIONED_HASH_JOIN_REPARTITION_FAILS, join_node_id_,
+        next_partition_level, num_input_rows, DebugString(),
+        buffer_pool_client_->DebugString());
+  }
+  TransferProbeStreamReservation(probe_client);
+  *repartitioned = true;
+  *level = ht_ctx_->level();
+  *new_partitions = HashPartitions(ht_ctx_->level(), &hash_partitions_, non_empty_build_);
+  return Status::OK();
+}
+
 Status PhjBuilder::RepartitionBuildInput(Partition* input_partition) {
   int new_level = input_partition->level() + 1;
   DCHECK_GE(new_level, 1);
@@ -588,7 +724,7 @@
   int64_t max_rows = 0;
   for (int i = 0; i < hash_partitions_.size(); ++i) {
     Partition* partition = hash_partitions_[i];
-    DCHECK(partition != NULL);
+    DCHECK(partition != nullptr);
     if (partition->IsClosed()) continue;
     int64_t rows = partition->build_rows()->num_rows();
     if (rows > max_rows) max_rows = rows;
@@ -621,14 +757,14 @@
 void PhjBuilder::Partition::Close(RowBatch* batch) {
   if (IsClosed()) return;
 
-  if (hash_tbl_ != NULL) {
+  if (hash_tbl_ != nullptr) {
     hash_tbl_->StatsCountersAdd(parent_->ht_stats_profile_.get());
     hash_tbl_->Close();
   }
 
   // Transfer ownership of 'build_rows_' memory to 'batch' if 'batch' is not NULL.
   // Flush out the resources to free up memory for subsequent partitions.
-  if (build_rows_ != NULL) {
+  if (build_rows_ != nullptr) {
     build_rows_->Close(batch, RowBatch::FlushMode::FLUSH_RESOURCES);
     build_rows_.reset();
   }
@@ -638,7 +774,7 @@
   DCHECK(!IsClosed());
   RETURN_IF_ERROR(parent_->runtime_state_->StartSpilling(parent_->mem_tracker()));
   // Close the hash table and unpin the stream backing it to free memory.
-  if (hash_tbl() != NULL) {
+  if (hash_tbl() != nullptr) {
     hash_tbl_->Close();
     hash_tbl_.reset();
   }
@@ -655,7 +791,7 @@
 
 Status PhjBuilder::Partition::BuildHashTable(bool* built) {
   SCOPED_TIMER(parent_->build_hash_table_timer_);
-  DCHECK(build_rows_ != NULL);
+  DCHECK(build_rows_ != nullptr);
   *built = false;
 
   // Before building the hash table, we need to pin the rows in memory.
@@ -698,14 +834,14 @@
     DCHECK_EQ(batch.num_rows(), flat_rows.size());
     DCHECK_LE(batch.num_rows(), hash_tbl_->EmptyBuckets());
     TPrefetchMode::type prefetch_mode = state->query_options().prefetch_mode;
-    if (parent_->insert_batch_fn_ != NULL) {
+    if (parent_->insert_batch_fn_ != nullptr) {
       InsertBatchFn insert_batch_fn;
       if (level() == 0) {
         insert_batch_fn = parent_->insert_batch_fn_level0_;
       } else {
         insert_batch_fn = parent_->insert_batch_fn_;
       }
-      DCHECK(insert_batch_fn != NULL);
+      DCHECK(insert_batch_fn != nullptr);
       if (UNLIKELY(
               !insert_batch_fn(this, prefetch_mode, ctx, &batch, flat_rows, &status))) {
         goto not_built;
@@ -722,7 +858,7 @@
 
   // The hash table fits in memory and is built.
   DCHECK(*built);
-  DCHECK(hash_tbl_ != NULL);
+  DCHECK(hash_tbl_ != nullptr);
   is_spilled_ = false;
   COUNTER_ADD(parent_->ht_stats_profile_->num_hash_buckets_,
       hash_tbl_->num_buckets());
@@ -730,7 +866,7 @@
 
 not_built:
   *built = false;
-  if (hash_tbl_ != NULL) {
+  if (hash_tbl_ != nullptr) {
     hash_tbl_->Close();
     hash_tbl_.reset();
   }
@@ -752,7 +888,7 @@
      << "    Build Rows: " << build_rows_->num_rows()
      << " (Bytes pinned: " << build_rows_->BytesPinned(false) << ")"
      << endl;
-  if (hash_tbl_ != NULL) {
+  if (hash_tbl_ != nullptr) {
     ss << "    Hash Table Rows: " << hash_tbl_->size();
   }
   return ss.str();
@@ -794,7 +930,8 @@
 
 string PhjBuilder::DebugString() const {
   stringstream ss;
-  ss << "Hash partitions: " << hash_partitions_.size() << ":" << endl;
+  ss << " PhjBuilder state=" << PrintState()
+     << " Hash partitions: " << hash_partitions_.size() << ":" << endl;
   for (int i = 0; i < hash_partitions_.size(); ++i) {
     ss << " Hash partition " << i << " " << hash_partitions_[i]->DebugString() << endl;
   }
@@ -809,7 +946,7 @@
     llvm::Function* insert_filters_fn) {
   llvm::Function* process_build_batch_fn =
       codegen->GetFunction(IRFunction::PHJ_PROCESS_BUILD_BATCH, true);
-  DCHECK(process_build_batch_fn != NULL);
+  DCHECK(process_build_batch_fn != nullptr);
 
   // Replace call sites
   int replaced =
@@ -864,14 +1001,14 @@
 
   // Finalize ProcessBuildBatch functions
   process_build_batch_fn = codegen->FinalizeFunction(process_build_batch_fn);
-  if (process_build_batch_fn == NULL) {
+  if (process_build_batch_fn == nullptr) {
     return Status(
         "Codegen'd PhjBuilder::ProcessBuildBatch() function "
         "failed verification, see log");
   }
   process_build_batch_fn_level0 =
       codegen->FinalizeFunction(process_build_batch_fn_level0);
-  if (process_build_batch_fn == NULL) {
+  if (process_build_batch_fn == nullptr) {
     return Status(
         "Codegen'd level-zero PhjBuilder::ProcessBuildBatch() "
         "function failed verification, see log");
@@ -928,13 +1065,13 @@
   DCHECK_REPLACE_COUNT(replaced, 1);
 
   insert_batch_fn = codegen->FinalizeFunction(insert_batch_fn);
-  if (insert_batch_fn == NULL) {
+  if (insert_batch_fn == nullptr) {
     return Status(
         "PartitionedHashJoinNode::CodegenInsertBatch(): codegen'd "
         "InsertBatch() function failed verification, see log");
   }
   insert_batch_fn_level0 = codegen->FinalizeFunction(insert_batch_fn_level0);
-  if (insert_batch_fn_level0 == NULL) {
+  if (insert_batch_fn_level0 == nullptr) {
     return Status(
         "PartitionedHashJoinNode::CodegenInsertBatch(): codegen'd zero-level "
         "InsertBatch() function failed verification, see log");
diff --git a/be/src/exec/partitioned-hash-join-builder.h b/be/src/exec/partitioned-hash-join-builder.h
index 6397756..6e614e8 100644
--- a/be/src/exec/partitioned-hash-join-builder.h
+++ b/be/src/exec/partitioned-hash-join-builder.h
@@ -40,6 +40,31 @@
 class ScalarExpr;
 class ScalarExprEvaluator;
 
+/// See partitioned-hash-join-node.h for explanation of the top-level algorithm and how
+/// these states fit in it.
+enum class HashJoinState {
+  /// Partitioning the build (right) child's input into the builder's hash partitions.
+  PARTITIONING_BUILD,
+
+  /// Processing the probe (left) child's input, probing hash tables and
+  /// spilling probe rows into 'probe_hash_partitions_' if necessary.
+  PARTITIONING_PROBE,
+
+  /// Processing the spilled probe rows of a single spilled partition
+  /// ('input_partition_') that fits in memory.
+  PROBING_SPILLED_PARTITION,
+
+  /// Repartitioning the build rows of a single spilled partition ('input_partition_')
+  /// into the builder's hash partitions.
+  /// Corresponds to PARTITIONING_BUILD but reading from a spilled partition.
+  REPARTITIONING_BUILD,
+
+  /// Probing the repartitioned hash partitions of a single spilled partition
+  /// ('input_partition_') with the probe rows of that partition.
+  /// Corresponds to PARTITIONING_PROBE but reading from a spilled partition.
+  REPARTITIONING_PROBE,
+};
+
 /// The build side for the PartitionedHashJoinNode. Build-side rows are hash-partitioned
 /// into PARTITION_FANOUT partitions, with partitions spilled if the full build side
 /// does not fit in memory. Spilled partitions can be repartitioned with a different
@@ -115,20 +140,88 @@
   /// Reset the builder to the same state as it was in after calling Open().
   void Reset(RowBatch* row_batch);
 
-  /// Transfer reservation for probe streams to 'dst'. Memory for one stream was reserved
-  /// per spilled partition in FlushFinal().
-  void TransferProbeStreamReservation(BufferPool::ClientHandle* dst);
+  /// Represents a set of hash partitions to be handed off to the probe side.
+  struct HashPartitions {
+    HashPartitions() { Reset(); }
+    HashPartitions(
+        int level, const std::vector<Partition*>* hash_partitions, bool non_empty_build)
+      : level(level),
+        hash_partitions(hash_partitions),
+        non_empty_build(non_empty_build) {}
 
-  /// Called after probing of the partitions is done. Appends in-memory partitions that
-  /// may contain build rows to output to 'output_partitions' for build modes like
-  /// right outer join that output unmatched rows. Close other in-memory partitions,
-  /// attaching any tuple data to 'batch' if 'batch' is non-NULL. Closes spilled
-  /// partitions if 'retain_spilled_partition' is false for that partition index.
-  /// Invalid to call hash_partition() after this is called.
-  void DoneProbing(const bool retain_spilled_partition[PARTITION_FANOUT],
+    void Reset() {
+      level = -1;
+      hash_partitions = nullptr;
+      non_empty_build = false;
+    }
+
+    // The partitioning level of this set of partitions. -1 indicates that this is
+    // invalid.
+    int level;
+
+    // The current set of hash partitions. Always contains PARTITION_FANOUT partitions.
+    // The partitions may be in-memory, spilled, or closed. Valid until
+    // DoneProbingHashPartitions() is called.
+    const std::vector<Partition*>* hash_partitions;
+
+    // True iff the build side had at least one row in a partition.
+    bool non_empty_build;
+  };
+
+  /// Get hash partitions and reservation for the initial partitioning of the probe
+  /// side. Only valid to call once when in state PARTITIONING_PROBE.
+  /// When this function returns successfully, 'probe_client' will have enough
+  /// reservation for a write buffer for each spilled partition.
+  /// Return the current set of hash partitions.
+  /// TODO: IMPALA-9156: this will be a synchronization point for shared join build.
+  HashPartitions BeginInitialProbe(BufferPool::ClientHandle* probe_client);
+
+  /// Prepare to process the probe side of 'partition', either by building a hash
+  /// table over 'partition', or if it does not fit in memory, by repartitioning into
+  /// PARTITION_FANOUT new partitions.
+  ///
+  /// When this function returns successfully, 'probe_client' will have enough
+  /// reservation for a read buffer for the input probe stream and, if repartitioning,
+  /// a write buffer for each spilled partition.
+  ///
+  /// If repartitioning, creates new hash partitions and repartitions 'partition' into
+  /// PARTITION_FANOUT new partitions with level partition->level() + 1. The
+  /// previous hash partitions must have been cleared with DoneProbingHashPartitions().
+  /// The new hash partitions are returned in 'new_partitions'.
+  /// TODO: IMPALA-9156: this will be a synchronization point for shared join build.
+  Status BeginSpilledProbe(bool empty_probe, Partition* partition,
+      BufferPool::ClientHandle* probe_client, bool* repartitioned, int* level,
+      HashPartitions* new_partitions);
+
+  /// Called after probing of the hash partitions returned by BeginInitialProbe() or
+  /// BeginSpilledProbe() (when *repartitioned is true) is complete,
+  /// i.e. all of the corresponding probe rows have been processed by
+  /// PartitionedHashJoinNode. Appends in-memory partitions that may contain build
+  /// rows to output to 'output_partitions' for build modes like right outer join
+  /// that output unmatched rows. Close other in-memory partitions, attaching any
+  /// tuple data to 'batch' if 'batch' is non-NULL. Closes spilled partitions if
+  /// 'retain_spilled_partition' is false for that partition index.
+  /// TODO: IMPALA-9156: this will be a synchronization point for shared join build.
+  void DoneProbingHashPartitions(const bool retain_spilled_partition[PARTITION_FANOUT],
       std::list<Partition*>* output_partitions, RowBatch* batch);
 
+  /// Called after probing of a single spilled partition returned by
+  /// BeginSpilledProbe() when *repartitioned is false.
+  ///
+  /// If the join op requires outputting unmatched build rows and the partition
+  /// may have build rows to return, it is appended to 'output_partitions'. Partitions
+  /// returned via 'output_partitions' are ready for the caller to read from - either
+  /// they are in-memory with a hash table built or have build_rows() prepared for
+  /// reading.
+  ///
+  /// If no build rows need to be returned, closes the build partition and attaches any
+  /// tuple data to 'batch' if 'batch' is non-NULL.
+  /// TODO: IMPALA-9156: this will be a synchronization point for shared join build.
+  void DoneProbingSinglePartition(
+      Partition* partition, std::list<Partition*>* output_partitions, RowBatch* batch);
+
   /// Close the null aware partition (if there is one) and set it to NULL.
+  /// TODO: IMPALA-9176: improve the encapsulation of the null-aware partition.
   void CloseNullAwarePartition() {
     if (null_aware_partition_ != nullptr) {
       // We don't need to pass in a batch because the anti-join only returns tuple data
@@ -139,33 +232,24 @@
     }
   }
 
-  /// Creates new hash partitions and repartitions 'input_partition' into PARTITION_FANOUT
-  /// new partitions with level input_partition->level() + 1. The previous hash partitions
-  /// must have been cleared with ClearHashPartitions(). This function reserves enough
-  /// memory for a read buffer for the input probe stream and a write buffer for each
-  /// spilled partition after repartitioning.
-  Status RepartitionBuildInput(Partition* input_partition) WARN_UNUSED_RESULT;
-
-  /// Returns the largest build row count out of the current hash partitions.
-  int64_t LargestPartitionRows() const;
-
   /// True if the hash table may contain rows with one or more NULL join keys. This
   /// depends on the join type and the equijoin conjuncts.
+  /// Valid to call after InitExprsAndFilters(). Thread-safe.
   bool HashTableStoresNulls() const;
 
   void AddHashTableStatsToProfile(RuntimeProfile* profile);
 
-  /// Accessor functions, mainly required to expose state to PartitionedHashJoinNode.
-  inline bool non_empty_build() const { return non_empty_build_; }
+  /// TODO: IMPALA-9156: document thread safety for accessing this from
+  /// multiple PartitionedHashJoinNodes.
+  HashJoinState state() const { return state_; }
+
+  /// Valid to call after InitExprsAndFilters(). Thread-safe.
   inline const std::vector<bool>& is_not_distinct_from() const {
     return is_not_distinct_from_;
   }
-  inline int num_hash_partitions() const { return hash_partitions_.size(); }
-  inline Partition* hash_partition(int partition_idx) const {
-    DCHECK_GE(partition_idx, 0);
-    DCHECK_LT(partition_idx, hash_partitions_.size());
-    return hash_partitions_[partition_idx];
-  }
+
+  /// Accessor to allow PartitionedHashJoinNode to access null_aware_partition_.
+  /// TODO: IMPALA-9176: improve the encapsulation of the null-aware partition.
   inline Partition* null_aware_partition() const { return null_aware_partition_; }
 
   std::string DebugString() const;
@@ -217,7 +301,7 @@
 
     std::string DebugString();
 
-    bool ALWAYS_INLINE IsClosed() const { return build_rows_ == NULL; }
+    bool ALWAYS_INLINE IsClosed() const { return build_rows_ == nullptr; }
     BufferedTupleStream* ALWAYS_INLINE build_rows() { return build_rows_.get(); }
     HashTable* ALWAYS_INLINE hash_tbl() const { return hash_tbl_.get(); }
     bool ALWAYS_INLINE is_spilled() const { return is_spilled_; }
@@ -279,8 +363,14 @@
   static const char* LLVM_CLASS_NAME;
 
  private:
+  /// Updates 'state_' to 'next_state', logging the transition.
+  void UpdateState(HashJoinState next_state);
+
+  /// Returns the current 'state_' as a string.
+  std::string PrintState() const;
+
   /// Create and initialize a set of hash partitions for partitioning level 'level'.
-  /// The previous hash partitions must have been cleared with ClearHashPartitions().
+  /// The previous hash partitions must have been cleared by DoneProbingHashPartitions().
   /// After calling this, batches are added to the new partitions by calling Send().
   Status CreateHashPartitions(int level) WARN_UNUSED_RESULT;
 
@@ -341,8 +431,23 @@
   /// is encountered or if it runs out of partitions to spill.
   Status ReserveProbeBuffers(bool input_is_spilled) WARN_UNUSED_RESULT;
 
-  /// Returns the number of partitions in 'hash_partitions_' that are spilled.
-  int GetNumSpilledHashPartitions() const;
+  /// Returns the number of partitions in 'partitions' that are spilled.
+  static int GetNumSpilledPartitions(const std::vector<Partition*>& partitions);
+
+  /// Transfer reservation for probe streams to 'probe_client'. Memory for one stream was
+  /// reserved per spilled partition in FlushFinal(), plus the input stream if the input
+  /// partition was spilled.
+  void TransferProbeStreamReservation(BufferPool::ClientHandle* probe_client);
+
+  /// Creates new hash partitions and repartitions 'input_partition' into PARTITION_FANOUT
+  /// new partitions with level input_partition->level() + 1. The previous hash partitions
+  /// must have been cleared with DoneProbingHashPartitions(). This function reserves
+  /// enough memory for a read buffer for the input probe stream and a write buffer for
+  /// each spilled partition after repartitioning.
+  Status RepartitionBuildInput(Partition* input_partition) WARN_UNUSED_RESULT;
+
+  /// Returns the largest build row count out of the current hash partitions.
+  int64_t LargestPartitionRows() const;
 
   /// Calls Close() on every Partition, deletes them, and cleans up any pointers that
   /// may reference them. If 'row_batch' if not NULL, transfers the ownership of all
@@ -412,17 +517,16 @@
   /// Allocator for hash table memory.
   boost::scoped_ptr<Suballocator> ht_allocator_;
 
-  /// If true, the build side has at least one row.
-  bool non_empty_build_;
-
   /// Expressions over input rows for hash table build.
   std::vector<ScalarExpr*> build_exprs_;
 
   /// is_not_distinct_from_[i] is true if and only if the ith equi-join predicate is IS
   /// NOT DISTINCT FROM, rather than equality.
+  /// Set in InitExprsAndFilters() and constant thereafter.
   std::vector<bool> is_not_distinct_from_;
 
   /// Expressions for evaluating input rows for insertion into runtime filters.
+  /// Only includes exprs for filters produced by this builder.
   std::vector<ScalarExpr*> filter_exprs_;
 
   /// List of filters to build. One-to-one correspondence with exprs in 'filter_exprs_'.
@@ -436,37 +540,49 @@
   std::unique_ptr<HashTableStatsProfile> ht_stats_profile_;
 
   /// Total number of partitions created.
-  RuntimeProfile::Counter* partitions_created_;
+  RuntimeProfile::Counter* partitions_created_ = nullptr;
 
   /// The largest fraction (of build side) after repartitioning. This is expected to be
   /// 1 / PARTITION_FANOUT. A value much larger indicates skew.
-  RuntimeProfile::HighWaterMarkCounter* largest_partition_percent_;
+  RuntimeProfile::HighWaterMarkCounter* largest_partition_percent_ = nullptr;
 
   /// Level of max partition (i.e. number of repartitioning steps).
-  RuntimeProfile::HighWaterMarkCounter* max_partition_level_;
+  RuntimeProfile::HighWaterMarkCounter* max_partition_level_ = nullptr;
 
   /// Number of build rows that have been partitioned.
-  RuntimeProfile::Counter* num_build_rows_partitioned_;
+  RuntimeProfile::Counter* num_build_rows_partitioned_ = nullptr;
 
   /// Number of partitions that have been spilled.
-  RuntimeProfile::Counter* num_spilled_partitions_;
+  RuntimeProfile::Counter* num_spilled_partitions_ = nullptr;
 
   /// Number of partitions that have been repartitioned.
-  RuntimeProfile::Counter* num_repartitions_;
+  RuntimeProfile::Counter* num_repartitions_ = nullptr;
 
   /// Time spent partitioning build rows.
-  RuntimeProfile::Counter* partition_build_rows_timer_;
+  RuntimeProfile::Counter* partition_build_rows_timer_ = nullptr;
 
   /// Time spent building hash tables.
-  RuntimeProfile::Counter* build_hash_table_timer_;
+  RuntimeProfile::Counter* build_hash_table_timer_ = nullptr;
+
+  /// Number of partitions that had zero probe rows, for which we therefore skipped
+  /// building the hash table.
+  RuntimeProfile::Counter* num_hash_table_builds_skipped_ = nullptr;
 
   /// Time spent repartitioning and building hash tables of any resulting partitions
   /// that were not spilled.
-  RuntimeProfile::Counter* repartition_timer_;
+  RuntimeProfile::Counter* repartition_timer_ = nullptr;
 
   /////////////////////////////////////////
   /// BEGIN: Members that must be Reset()
 
+  /// State of the partitioned hash join algorithm. See HashJoinState for more
+  /// information.
+  HashJoinState state_ = HashJoinState::PARTITIONING_BUILD;
+
+  /// If true, the build side has at least one row.
+  /// Set in FlushFinal() and not modified until Reset().
+  bool non_empty_build_ = false;
+
   /// Vector that owns all of the Partition objects.
   std::vector<std::unique_ptr<Partition>> all_partitions_;
 
@@ -483,7 +599,7 @@
   /// NULL if the join is not null aware or we are done processing this partition.
   /// Always NULL once we are done processing the level 0 partitions.
   /// This partition starts off in memory but can be spilled.
-  Partition* null_aware_partition_;
+  Partition* null_aware_partition_ = nullptr;
 
   /// Populated during the hash table building phase if any partitions spilled.
   /// Reservation for one probe stream write buffer per spilled partition is
@@ -511,14 +627,14 @@
   typedef Status (*ProcessBuildBatchFn)(
       PhjBuilder*, RowBatch*, HashTableCtx*, bool build_filters, bool is_null_aware);
   /// Jitted ProcessBuildBatch function pointers.  NULL if codegen is disabled.
-  ProcessBuildBatchFn process_build_batch_fn_;
-  ProcessBuildBatchFn process_build_batch_fn_level0_;
+  ProcessBuildBatchFn process_build_batch_fn_ = nullptr;
+  ProcessBuildBatchFn process_build_batch_fn_level0_ = nullptr;
 
   typedef bool (*InsertBatchFn)(Partition*, TPrefetchMode::type, HashTableCtx*, RowBatch*,
       const std::vector<BufferedTupleStream::FlatRowPtr>&, Status*);
   /// Jitted Partition::InsertBatch() function pointers. NULL if codegen is disabled.
-  InsertBatchFn insert_batch_fn_;
-  InsertBatchFn insert_batch_fn_level0_;
+  InsertBatchFn insert_batch_fn_ = nullptr;
+  InsertBatchFn insert_batch_fn_level0_ = nullptr;
 };
 }
 
diff --git a/be/src/exec/partitioned-hash-join-node-ir.cc b/be/src/exec/partitioned-hash-join-node-ir.cc
index f316928..5038189 100644
--- a/be/src/exec/partitioned-hash-join-node-ir.cc
+++ b/be/src/exec/partitioned-hash-join-node-ir.cc
@@ -272,7 +272,8 @@
 
     // Fetch the hash and expr values' nullness for this row.
     if (expr_vals_cache->IsRowNull()) {
-      if (JoinOp == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN && builder_->non_empty_build()) {
+      if (JoinOp == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN
+          && build_hash_partitions_.non_empty_build) {
         const int num_other_join_conjuncts = other_join_conjunct_evals_.size();
         // For NAAJ, we need to treat NULLs on the probe carefully. The logic is:
         // 1. No build rows -> Return this row. The check for 'non_empty_build_'
@@ -302,7 +303,8 @@
         hash_tbl_iterator_ = hash_tbl->FindProbeRow(ht_ctx);
       } else {
         // The build partition is either empty or spilled.
-        PhjBuilder::Partition* build_partition = builder_->hash_partition(partition_idx);
+        PhjBuilder::Partition* build_partition =
+            (*build_hash_partitions_.hash_partitions)[partition_idx];
         ProbePartition* probe_partition = probe_hash_partitions_[partition_idx].get();
         DCHECK((build_partition->IsClosed() && probe_partition == NULL)
             || (build_partition->is_spilled() && probe_partition != NULL));
@@ -365,9 +367,9 @@
 template <int const JoinOp>
 int PartitionedHashJoinNode::ProcessProbeBatch(TPrefetchMode::type prefetch_mode,
     RowBatch* out_batch, HashTableCtx* __restrict__ ht_ctx, Status* __restrict__ status) {
-  DCHECK(state_ == HashJoinState::PARTITIONING_PROBE
-      || state_ == HashJoinState::PROBING_SPILLED_PARTITION
-      || state_ == HashJoinState::REPARTITIONING_PROBE);
+  DCHECK(builder_->state() == HashJoinState::PARTITIONING_PROBE
+      || builder_->state() == HashJoinState::PROBING_SPILLED_PARTITION
+      || builder_->state() == HashJoinState::REPARTITIONING_PROBE);
   ScalarExprEvaluator* const* other_join_conjunct_evals =
       other_join_conjunct_evals_.data();
   const int num_other_join_conjuncts = other_join_conjunct_evals_.size();
diff --git a/be/src/exec/partitioned-hash-join-node.cc b/be/src/exec/partitioned-hash-join-node.cc
index 293171d..cd66e8c 100644
--- a/be/src/exec/partitioned-hash-join-node.cc
+++ b/be/src/exec/partitioned-hash-join-node.cc
@@ -126,8 +126,6 @@
 
   num_probe_rows_partitioned_ =
       ADD_COUNTER(runtime_profile(), "ProbeRowsPartitioned", TUnit::UNIT);
-  num_hash_table_builds_skipped_ =
-      ADD_COUNTER(runtime_profile(), "NumHashTableBuildsSkipped", TUnit::UNIT);
   state->CheckAndAddCodegenDisabledMessage(runtime_profile());
   return Status::OK();
 }
@@ -168,9 +166,10 @@
   probe_expr_results_pool_->Clear();
 
   RETURN_IF_ERROR(BlockingJoinNode::ProcessBuildInputAndOpenProbe(state, builder_.get()));
-  RETURN_IF_ERROR(PrepareForProbe());
 
-  UpdateState(HashJoinState::PARTITIONING_PROBE);
+  build_hash_partitions_ = builder_->BeginInitialProbe(buffer_pool_client());
+  RETURN_IF_ERROR(PrepareForPartitionedProbe());
+
   RETURN_IF_ERROR(BlockingJoinNode::GetFirstProbeRow(state));
   ResetForProbe();
   probe_state_ = ProbeState::PROBING_IN_BATCH;
@@ -198,7 +197,6 @@
     null_probe_output_idx_ = -1;
     matched_null_probe_.clear();
   }
-  state_ = HashJoinState::PARTITIONING_BUILD;
   ht_ctx_->set_level(0);
   CloseAndDeletePartitions(row_batch);
   builder_->Reset(IsLeftSemiJoin(join_op_) ? nullptr : row_batch);
@@ -293,19 +291,19 @@
 }
 
 Status PartitionedHashJoinNode::NextProbeRowBatch(
-    RuntimeState* state, RowBatch* out_batch) {
+    RuntimeState* state, RowBatch* out_batch, bool* eos) {
   DCHECK_ENUM_EQ(probe_state_, ProbeState::PROBING_END_BATCH);
   DCHECK(probe_batch_pos_ == probe_batch_->num_rows() || probe_batch_pos_ == -1);
-  if (state_ == HashJoinState::PARTITIONING_PROBE) {
+  if (builder_->state() == HashJoinState::PARTITIONING_PROBE) {
     DCHECK(input_partition_ == nullptr);
-    RETURN_IF_ERROR(NextProbeRowBatchFromChild(state, out_batch));
+    RETURN_IF_ERROR(NextProbeRowBatchFromChild(state, out_batch, eos));
   } else {
-    DCHECK(state_ == HashJoinState::REPARTITIONING_PROBE
-        || state_ == HashJoinState::PROBING_SPILLED_PARTITION)
-        << PrintState();
+    DCHECK(builder_->state() == HashJoinState::REPARTITIONING_PROBE
+        || builder_->state() == HashJoinState::PROBING_SPILLED_PARTITION)
+        << builder_->DebugString();
     DCHECK(probe_side_eos_);
     DCHECK(input_partition_ != nullptr);
-    RETURN_IF_ERROR(NextSpilledProbeRowBatch(state, out_batch));
+    RETURN_IF_ERROR(NextSpilledProbeRowBatch(state, out_batch, eos));
   }
   // Free expr result allocations of the probe side expressions only after
   // ExprValuesCache has been reset.
@@ -315,10 +313,11 @@
 }
 
 Status PartitionedHashJoinNode::NextProbeRowBatchFromChild(
-    RuntimeState* state, RowBatch* out_batch) {
-  DCHECK_ENUM_EQ(state_, HashJoinState::PARTITIONING_PROBE);
+    RuntimeState* state, RowBatch* out_batch, bool* eos) {
+  DCHECK_ENUM_EQ(builder_->state(), HashJoinState::PARTITIONING_PROBE);
   DCHECK_ENUM_EQ(probe_state_, ProbeState::PROBING_END_BATCH);
   DCHECK(probe_batch_pos_ == probe_batch_->num_rows() || probe_batch_pos_ == -1);
+  *eos = false;
   do {
     // Loop until we find a non-empty row batch.
     probe_batch_->TransferResourceOwnership(out_batch);
@@ -330,6 +329,7 @@
     if (probe_side_eos_) {
       current_probe_row_ = NULL;
       probe_batch_pos_ = -1;
+      *eos = true;
       return Status::OK();
     }
     RETURN_IF_ERROR(child(0)->GetNext(state, probe_batch_.get(), &probe_side_eos_));
@@ -341,11 +341,12 @@
 }
 
 Status PartitionedHashJoinNode::NextSpilledProbeRowBatch(
-    RuntimeState* state, RowBatch* out_batch) {
+    RuntimeState* state, RowBatch* out_batch, bool* eos) {
   DCHECK(input_partition_ != NULL);
-  DCHECK(state_ == HashJoinState::PROBING_SPILLED_PARTITION
-      || state_ == HashJoinState::REPARTITIONING_PROBE);
+  DCHECK(builder_->state() == HashJoinState::PROBING_SPILLED_PARTITION
+      || builder_->state() == HashJoinState::REPARTITIONING_PROBE);
   DCHECK_ENUM_EQ(probe_state_, ProbeState::PROBING_END_BATCH);
+  *eos = false;
   probe_batch_->TransferResourceOwnership(out_batch);
   if (out_batch->AtCapacity()) {
     // The out_batch has resources associated with it that will be recycled on the
@@ -356,48 +357,22 @@
   BufferedTupleStream* probe_rows = input_partition_->probe_rows();
   if (LIKELY(probe_rows->rows_returned() < probe_rows->num_rows())) {
     // Continue from the current probe stream.
-    bool eos = false;
-    RETURN_IF_ERROR(probe_rows->GetNext(probe_batch_.get(), &eos));
+    RETURN_IF_ERROR(probe_rows->GetNext(probe_batch_.get(), eos));
     DCHECK_GT(probe_batch_->num_rows(), 0);
     ResetForProbe();
   } else {
     // Finished processing spilled probe rows from this partition.
-    if (state_ == HashJoinState::PROBING_SPILLED_PARTITION
-        && NeedToProcessUnmatchedBuildRows(join_op_)) {
-      // If the build partition was in memory, we are done probing this partition.
-      // In case of right-outer, right-anti and full-outer joins, we move this partition
-      // to the list of partitions that we need to output their unmatched build rows.
-      DCHECK(output_build_partitions_.empty());
-      DCHECK(output_unmatched_batch_iter_.get() == NULL);
-      if (input_partition_->build_partition()->hash_tbl() != NULL) {
-        hash_tbl_iterator_ =
-            input_partition_->build_partition()->hash_tbl()->FirstUnmatched(
-                ht_ctx_.get());
-      } else {
-        output_unmatched_batch_.reset(new RowBatch(
-            child(1)->row_desc(), runtime_state_->batch_size(), builder_->mem_tracker()));
-        output_unmatched_batch_iter_.reset(
-            new RowBatch::Iterator(output_unmatched_batch_.get(), 0));
-      }
-      output_build_partitions_.push_back(input_partition_->build_partition());
-    } else {
-      // In any other case, just close the input build partition.
-      input_partition_->build_partition()->Close(out_batch);
-    }
-    input_partition_->Close(out_batch);
-    input_partition_.reset();
-    current_probe_row_ = NULL;
+    current_probe_row_ = nullptr;
     probe_batch_pos_ = -1;
+    *eos = true;
   }
   return Status::OK();
 }
 
-// TODO: refactor this method to better separate the logic operating on the builder
-// vs probe data structures.
-Status PartitionedHashJoinNode::PrepareSpilledPartitionForProbe() {
-  VLOG(2) << "PrepareSpilledPartitionForProbe\n" << NodeDebugString();
-  DCHECK(input_partition_ == NULL);
-  DCHECK_EQ(builder_->num_hash_partitions(), 0);
+Status PartitionedHashJoinNode::BeginSpilledProbe() {
+  VLOG(2) << "BeginSpilledProbe\n" << NodeDebugString();
+  DCHECK(input_partition_ == nullptr);
+  DCHECK(build_hash_partitions_.hash_partitions == nullptr);
   DCHECK(probe_hash_partitions_.empty());
   DCHECK(!spilled_partitions_.empty());
 
@@ -406,77 +381,22 @@
   spilled_partitions_.pop_front();
   PhjBuilder::Partition* build_partition = input_partition_->build_partition();
   DCHECK(build_partition->is_spilled());
-  if (input_partition_->probe_rows()->num_rows() == 0) {
-    // If there are no probe rows, there's no need to build the hash table, and
-    // only partitions with NeedToProcessUnmatcheBuildRows() will have been added
-    // to 'spilled_partitions_' in DoneProbing().
-    DCHECK(NeedToProcessUnmatchedBuildRows(join_op_));
-    bool got_read_buffer = false;
-    RETURN_IF_ERROR(input_partition_->build_partition()->build_rows()->PrepareForRead(
-        true, &got_read_buffer));
-    if (!got_read_buffer) {
-      return mem_tracker()->MemLimitExceeded(
-          runtime_state_, Substitute(PREPARE_FOR_READ_FAILED_ERROR_MSG, id_));
-    }
+  DCHECK_EQ(input_partition_->probe_rows()->BytesPinned(false), 0) << NodeDebugString();
 
-    UpdateState(HashJoinState::PROBING_SPILLED_PARTITION);
-    COUNTER_ADD(num_hash_table_builds_skipped_, 1);
+  bool empty_probe = input_partition_->probe_rows()->num_rows() == 0;
+  bool repartitioned;
+  int level;
+  RETURN_IF_ERROR(builder_->BeginSpilledProbe(empty_probe, build_partition,
+      buffer_pool_client(), &repartitioned, &level, &build_hash_partitions_));
+
+  ht_ctx_->set_level(level);
+  if (empty_probe) {
     return Status::OK();
-  }
-
-  // Make sure we have a buffer to read the probe rows before we build the hash table.
-  // TODO: we should set aside the reservation without allocating the buffer, then
-  // move the repartitioning logic into the builder.
-  RETURN_IF_ERROR(input_partition_->PrepareForRead());
-  ht_ctx_->set_level(build_partition->level());
-
-  // Try to build a hash table for the spilled build partition.
-  bool built;
-  RETURN_IF_ERROR(build_partition->BuildHashTable(&built));
-
-  if (!built) {
-    // This build partition still does not fit in memory, repartition.
-    UpdateState(HashJoinState::REPARTITIONING_BUILD);
-
-    int next_partition_level = build_partition->level() + 1;
-    if (UNLIKELY(next_partition_level >= MAX_PARTITION_DEPTH)) {
-      return Status(TErrorCode::PARTITIONED_HASH_JOIN_MAX_PARTITION_DEPTH, id(),
-          MAX_PARTITION_DEPTH);
-    }
-    ht_ctx_->set_level(next_partition_level);
-
-    // Spill to free memory from hash tables and pinned streams for use in new partitions.
-    RETURN_IF_ERROR(build_partition->Spill(BufferedTupleStream::UNPIN_ALL));
-    // Temporarily free up the probe buffer to use when repartitioning.
-    RETURN_IF_ERROR(
-        input_partition_->probe_rows()->UnpinStream(BufferedTupleStream::UNPIN_ALL));
-    DCHECK_EQ(build_partition->build_rows()->BytesPinned(false), 0) << NodeDebugString();
-    DCHECK_EQ(input_partition_->probe_rows()->BytesPinned(false), 0) << NodeDebugString();
-    int64_t num_input_rows = build_partition->build_rows()->num_rows();
-    RETURN_IF_ERROR(builder_->RepartitionBuildInput(build_partition));
-
-    // Check if there was any reduction in the size of partitions after repartitioning.
-    int64_t largest_partition_rows = builder_->LargestPartitionRows();
-    DCHECK_GE(num_input_rows, largest_partition_rows) << "Cannot have a partition with "
-                                                         "more rows than the input";
-    if (UNLIKELY(num_input_rows == largest_partition_rows)) {
-      return Status(TErrorCode::PARTITIONED_HASH_JOIN_REPARTITION_FAILS, id_,
-          next_partition_level, num_input_rows, NodeDebugString(),
-          buffer_pool_client()->DebugString());
-    }
-    RETURN_IF_ERROR(PrepareForProbe());
-    UpdateState(HashJoinState::REPARTITIONING_PROBE);
+  } else if (repartitioned) {
+    RETURN_IF_ERROR(PrepareForPartitionedProbe());
   } else {
-    DCHECK(!input_partition_->build_partition()->is_spilled());
-    DCHECK(input_partition_->build_partition()->hash_tbl() != NULL);
-    // In this case, we did not have to partition the build again, we just built
-    // a hash table. This means the probe does not have to be partitioned either.
-    for (int i = 0; i < PARTITION_FANOUT; ++i) {
-      hash_tbls_[i] = input_partition_->build_partition()->hash_tbl();
-    }
-    UpdateState(HashJoinState::PROBING_SPILLED_PARTITION);
+    RETURN_IF_ERROR(PrepareForUnpartitionedProbe());
   }
-
   COUNTER_ADD(num_probe_rows_partitioned_, input_partition_->probe_rows()->num_rows());
   return Status::OK();
 }
@@ -569,8 +489,8 @@
   // See the definition of ProbeState for description of the state machine and states.
   do {
     DCHECK(status.ok());
-    DCHECK(state_ != HashJoinState::PARTITIONING_BUILD)
-        << "Should not be in GetNext() " << static_cast<int>(state_);
+    DCHECK(builder_->state() != HashJoinState::PARTITIONING_BUILD)
+        << "Should not be in GetNext() " << static_cast<int>(builder_->state());
     RETURN_IF_CANCELLED(state);
     RETURN_IF_ERROR(QueryMaintenance(state));
     switch (probe_state_) {
@@ -586,12 +506,12 @@
       }
       case ProbeState::PROBING_END_BATCH: {
         // Try to get the next row batch from the current probe input.
-        RETURN_IF_ERROR(NextProbeRowBatch(state, out_batch));
-
+        bool probe_eos;
+        RETURN_IF_ERROR(NextProbeRowBatch(state, out_batch, &probe_eos));
         if (probe_batch_pos_ == 0) {
           // Got a batch, need to process it.
           probe_state_ = ProbeState::PROBING_IN_BATCH;
-        } else if (probe_side_eos_ && input_partition_ == nullptr) {
+        } else if (probe_eos) {
           DCHECK_EQ(probe_batch_pos_, -1);
           // Finished processing all the probe rows for the current hash partitions.
           // There may be some partitions that need to outpt their unmatched build rows.
@@ -608,7 +528,7 @@
       }
       case ProbeState::OUTPUTTING_UNMATCHED: {
         DCHECK(!output_build_partitions_.empty());
-        DCHECK_EQ(builder_->num_hash_partitions(), 0);
+        DCHECK(build_hash_partitions_.hash_partitions == nullptr);
         DCHECK(probe_hash_partitions_.empty());
         DCHECK(NeedToProcessUnmatchedBuildRows(join_op_));
         // Output the remaining batch of build rows from the current partition.
@@ -620,7 +540,7 @@
       case ProbeState::PROBE_COMPLETE: {
         if (!spilled_partitions_.empty()) {
           // Move to the next spilled partition.
-          RETURN_IF_ERROR(PrepareSpilledPartitionForProbe());
+          RETURN_IF_ERROR(BeginSpilledProbe());
           probe_state_ = ProbeState::PROBING_END_BATCH;
         } else if (join_op_ == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN
             && builder_->null_aware_partition() != nullptr) {
@@ -993,19 +913,21 @@
   return Status::OK();
 }
 
-Status PartitionedHashJoinNode::PrepareForProbe() {
-  DCHECK_EQ(builder_->num_hash_partitions(), PARTITION_FANOUT);
+Status PartitionedHashJoinNode::PrepareForPartitionedProbe() {
+  DCHECK(builder_->state() == HashJoinState::PARTITIONING_PROBE
+      || builder_->state() == HashJoinState::REPARTITIONING_PROBE)
+      << builder_->DebugString();
+  DCHECK_EQ(PARTITION_FANOUT, build_hash_partitions_.hash_partitions->size());
   DCHECK(probe_hash_partitions_.empty());
-
-  // Initialize the probe partitions, providing them with probe streams.
+  // Initialize the probe partitions, providing them with probe streams. The reservation
+  // for the probe streams was obtained from 'builder_' when BeginInitialProbe()
+  // or BeginSpilledProbe() was called.
   vector<unique_ptr<BufferedTupleStream>> probe_streams;
-  builder_->TransferProbeStreamReservation(buffer_pool_client());
   if (input_partition_ != nullptr) {
+    DCHECK_ENUM_EQ(builder_->state(), HashJoinState::REPARTITIONING_PROBE);
     // This is a spilled partition - we need to read the probe rows. Memory was reserved
     // in RepartitionBuildInput() for the input stream's read buffer.
-    bool got_buffer;
-    RETURN_IF_ERROR(input_partition_->probe_rows()->PrepareForRead(true, &got_buffer));
-    DCHECK(got_buffer) << "Memory should have been reserved by builder";
+    RETURN_IF_ERROR(input_partition_->PrepareForRead());
   }
 
   bool have_spilled_hash_partitions;
@@ -1024,12 +946,12 @@
 
   // Initialize the hash_tbl_ caching array.
   for (int i = 0; i < PARTITION_FANOUT; ++i) {
-    hash_tbls_[i] = builder_->hash_partition(i)->hash_tbl();
+    hash_tbls_[i] = (*build_hash_partitions_.hash_partitions)[i]->hash_tbl();
   }
 
   // Validate the state of the partitions.
   for (int i = 0; i < PARTITION_FANOUT; ++i) {
-    PhjBuilder::Partition* build_partition = builder_->hash_partition(i);
+    PhjBuilder::Partition* build_partition = (*build_hash_partitions_.hash_partitions)[i];
     ProbePartition* probe_partition = probe_hash_partitions_[i].get();
     if (build_partition->IsClosed()) {
       DCHECK(hash_tbls_[i] == NULL);
@@ -1047,22 +969,43 @@
 
 Status PartitionedHashJoinNode::CreateProbeHashPartitions(
     bool* have_spilled_hash_partitions) {
+  DCHECK_EQ(PARTITION_FANOUT, build_hash_partitions_.hash_partitions->size());
   *have_spilled_hash_partitions = false;
   probe_hash_partitions_.resize(PARTITION_FANOUT);
   for (int i = 0; i < PARTITION_FANOUT; ++i) {
-    PhjBuilder::Partition* build_partition = builder_->hash_partition(i);
+    PhjBuilder::Partition* build_partition = (*build_hash_partitions_.hash_partitions)[i];
     if (build_partition->IsClosed() || !build_partition->is_spilled()) continue;
     *have_spilled_hash_partitions = true;
     DCHECK(probe_hash_partitions_[i] == nullptr);
     // Put partition into vector so it will be cleaned up in CloseAndDeletePartitions()
     // if Init() fails.
     probe_hash_partitions_[i] =
-        make_unique<ProbePartition>(runtime_state_, this, builder_->hash_partition(i));
+        make_unique<ProbePartition>(runtime_state_, this, build_partition);
     RETURN_IF_ERROR(probe_hash_partitions_[i]->PrepareForWrite(this, false));
   }
   return Status::OK();
 }
 
+Status PartitionedHashJoinNode::PrepareForUnpartitionedProbe() {
+  DCHECK_ENUM_EQ(builder_->state(), HashJoinState::PROBING_SPILLED_PARTITION);
+  DCHECK(build_hash_partitions_.hash_partitions == nullptr);
+  DCHECK(probe_hash_partitions_.empty());
+  DCHECK(input_partition_ != nullptr);
+  DCHECK(!input_partition_->build_partition()->is_spilled());
+  DCHECK(input_partition_->build_partition()->hash_tbl() != nullptr);
+
+  // This is a spilled partition - we need to read the probe rows. Memory was reserved
+  // in builder_->BeginSpilledProbe() for the input stream's read buffer.
+  RETURN_IF_ERROR(input_partition_->PrepareForRead());
+
+  // In this case, we did not have to partition the build again, we just built
+  // a hash table. This means the probe does not have to be partitioned either.
+  for (int i = 0; i < PARTITION_FANOUT; ++i) {
+    hash_tbls_[i] = input_partition_->build_partition()->hash_tbl();
+  }
+  return Status::OK();
+}
+
 bool PartitionedHashJoinNode::AppendProbeRowSlow(
     BufferedTupleStream* stream, TupleRow* row, Status* status) {
   if (!status->ok()) return false; // Check if AddRow() set status.
@@ -1129,64 +1072,78 @@
 Status PartitionedHashJoinNode::DoneProbing(RuntimeState* state, RowBatch* batch) {
   DCHECK_ENUM_EQ(probe_state_, ProbeState::PROBING_END_BATCH);
   DCHECK_EQ(probe_batch_pos_, -1);
+  DCHECK(output_build_partitions_.empty());
   // At this point all the rows have been read from the probe side for all partitions in
   // hash_partitions_.
   VLOG(2) << "Probe Side Consumed\n" << NodeDebugString();
-  if (builder_->num_hash_partitions() == 0) {
-    // No hash partitions, so no cleanup required. This can only happen when we are
-    // processing a single spilled partition.
-    DCHECK_ENUM_EQ(state_, HashJoinState::PROBING_SPILLED_PARTITION);
-    return Status::OK();
-  }
-
-  // Walk the partitions that had hash tables built for the probe phase and close them.
-  // In the case of right outer and full outer joins, instead of closing those partitions,
-  // add them to the list of partitions that need to output any unmatched build rows.
-  // This partition will be closed by the function that actually outputs unmatched build
-  // rows.
-  DCHECK_EQ(builder_->num_hash_partitions(), PARTITION_FANOUT);
-  DCHECK_EQ(probe_hash_partitions_.size(), PARTITION_FANOUT);
-  // The build partitions we need to retain for further processing.
-  bool retain_spilled_partition[PARTITION_FANOUT] = {false};
-  for (int i = 0; i < PARTITION_FANOUT; ++i) {
-    ProbePartition* probe_partition = probe_hash_partitions_[i].get();
-    if (probe_partition == nullptr) {
-      // Partition was not spilled.
-      if (join_op_ == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
-        // For NAAJ, we need to try to match the NULL probe rows with this build partition
-        // before we are done with it.
-        PhjBuilder::Partition* build_partition = builder_->hash_partition(i);
-        if (!build_partition->IsClosed()) {
-          RETURN_IF_ERROR(EvaluateNullProbe(state, build_partition->build_rows()));
+  if (builder_->state() == HashJoinState::PROBING_SPILLED_PARTITION) {
+    // Need to clean up single in-memory build partition instead of hash partitions.
+    DCHECK(build_hash_partitions_.hash_partitions == nullptr);
+    DCHECK(input_partition_ != nullptr);
+    builder_->DoneProbingSinglePartition(input_partition_->build_partition(),
+        &output_build_partitions_, IsLeftSemiJoin(join_op_) ? nullptr : batch);
+  } else {
+    // Walk the partitions that had hash tables built for the probe phase and close them.
+    // In the case of right outer and full outer joins, instead of closing those
+    // partitions, add them to the list of partitions that need to output any unmatched
+    // build rows. This partition will be closed by the function that actually outputs
+    // unmatched build rows.
+    DCHECK_EQ(build_hash_partitions_.hash_partitions->size(), PARTITION_FANOUT);
+    DCHECK_EQ(probe_hash_partitions_.size(), PARTITION_FANOUT);
+    // The build partitions we need to retain for further processing.
+    bool retain_spilled_partition[PARTITION_FANOUT] = {false};
+    for (int i = 0; i < PARTITION_FANOUT; ++i) {
+      ProbePartition* probe_partition = probe_hash_partitions_[i].get();
+      if (probe_partition == nullptr) {
+        // Partition was not spilled.
+        if (join_op_ == TJoinOp::NULL_AWARE_LEFT_ANTI_JOIN) {
+          // For NAAJ, we need to try to match the NULL probe rows with this build
+          // partition before we are done with it.
+          PhjBuilder::Partition* build_partition =
+              (*build_hash_partitions_.hash_partitions)[i];
+          if (!build_partition->IsClosed()) {
+            RETURN_IF_ERROR(EvaluateNullProbe(state, build_partition->build_rows()));
+          }
         }
+      } else if (probe_partition->probe_rows()->num_rows() != 0
+          || NeedToProcessUnmatchedBuildRows(join_op_)) {
+        retain_spilled_partition[i] = true;
+        // Unpin the probe stream to free up more memory. We need to free all memory so we
+        // can recurse the algorithm and create new hash partitions from spilled
+        // partitions.
+        RETURN_IF_ERROR(
+            probe_partition->probe_rows()->UnpinStream(BufferedTupleStream::UNPIN_ALL));
+        // Push newly created partitions at the front. This means a depth first walk
+        // (more finely partitioned partitions are processed first). This allows us
+        // to delete blocks earlier and bottom out the recursion earlier.
+        spilled_partitions_.push_front(std::move(probe_hash_partitions_[i]));
+      } else {
+        // There's no more processing to do for this partition, and since there were no
+        // probe rows we didn't return any rows that reference memory from these
+        // partitions, so just free the resources.
+        probe_partition->Close(nullptr);
       }
-    } else if (probe_partition->probe_rows()->num_rows() != 0 ||
-        NeedToProcessUnmatchedBuildRows(join_op_)) {
-      retain_spilled_partition[i] = true;
-      // Unpin the probe stream to free up more memory. We need to free all memory so we
-      // can recurse the algorithm and create new hash partitions from spilled partitions.
-      RETURN_IF_ERROR(
-          probe_partition->probe_rows()->UnpinStream(BufferedTupleStream::UNPIN_ALL));
-      // Push newly created partitions at the front. This means a depth first walk
-      // (more finely partitioned partitions are processed first). This allows us
-      // to delete blocks earlier and bottom out the recursion earlier.
-      spilled_partitions_.push_front(std::move(probe_hash_partitions_[i]));
-    } else {
-      // There's no more processing to do for this partition, and since there were no
-      // probe rows we didn't return any rows that reference memory from these
-      // partitions, so just free the resources.
-      probe_partition->Close(nullptr);
-      COUNTER_ADD(num_hash_table_builds_skipped_, 1);
     }
+    probe_hash_partitions_.clear();
+    build_hash_partitions_.Reset();
+    builder_->DoneProbingHashPartitions(retain_spilled_partition,
+        &output_build_partitions_, IsLeftSemiJoin(join_op_) ? nullptr : batch);
   }
-  probe_hash_partitions_.clear();
-
-  builder_->DoneProbing(retain_spilled_partition, &output_build_partitions_,
-      IsLeftSemiJoin(join_op_) ? nullptr : batch);
+  if (input_partition_ != nullptr) {
+    input_partition_->Close(batch);
+    input_partition_.reset();
+  }
   if (!output_build_partitions_.empty()) {
     DCHECK(output_unmatched_batch_iter_.get() == nullptr);
-    hash_tbl_iterator_ =
-        output_build_partitions_.front()->hash_tbl()->FirstUnmatched(ht_ctx_.get());
+    PhjBuilder::Partition* output_partition = output_build_partitions_.front();
+    if (output_partition->hash_tbl() != nullptr) {
+      hash_tbl_iterator_ = output_partition->hash_tbl()->FirstUnmatched(ht_ctx_.get());
+    } else {
+      output_unmatched_batch_.reset(new RowBatch(
+          child(1)->row_desc(), runtime_state_->batch_size(), builder_->mem_tracker()));
+      output_unmatched_batch_iter_.reset(
+          new RowBatch::Iterator(output_unmatched_batch_.get(), 0));
+    }
   }
   return Status::OK();
 }
@@ -1200,49 +1157,9 @@
   *out << ")";
 }
 
-void PartitionedHashJoinNode::UpdateState(HashJoinState next_state) {
-  // Validate the state transition.
-  switch (state_) {
-    case HashJoinState::PARTITIONING_BUILD:
-      DCHECK_ENUM_EQ(next_state, HashJoinState::PARTITIONING_PROBE);
-      break;
-    case HashJoinState::PARTITIONING_PROBE:
-    case HashJoinState::REPARTITIONING_PROBE:
-    case HashJoinState::PROBING_SPILLED_PARTITION:
-      DCHECK(next_state == HashJoinState::REPARTITIONING_BUILD
-          || next_state == HashJoinState::PROBING_SPILLED_PARTITION);
-      break;
-    case HashJoinState::REPARTITIONING_BUILD:
-      DCHECK_ENUM_EQ(next_state, HashJoinState::REPARTITIONING_PROBE);
-      break;
-    default:
-      DCHECK(false) << "Invalid state " << static_cast<int>(state_);
-  }
-  state_ = next_state;
-  VLOG(2) << "Transitioned State:" << endl << NodeDebugString();
-}
-
-string PartitionedHashJoinNode::PrintState() const {
-  switch (state_) {
-    case HashJoinState::PARTITIONING_BUILD:
-      return "PartitioningBuild";
-    case HashJoinState::PARTITIONING_PROBE:
-      return "PartitioningProbe";
-    case HashJoinState::PROBING_SPILLED_PARTITION:
-      return "ProbingSpilledPartition";
-    case HashJoinState::REPARTITIONING_BUILD:
-      return "RepartitioningBuild";
-    case HashJoinState::REPARTITIONING_PROBE:
-      return "RepartitioningProbe";
-    default: DCHECK(false);
-  }
-  return "";
-}
-
 string PartitionedHashJoinNode::NodeDebugString() const {
   stringstream ss;
   ss << "PartitionedHashJoinNode (id=" << id() << " op=" << join_op_
-     << " state=" << PrintState()
      << " #spilled_partitions=" << spilled_partitions_.size() << ")" << endl;
 
   if (builder_ != NULL) {
diff --git a/be/src/exec/partitioned-hash-join-node.h b/be/src/exec/partitioned-hash-join-node.h
index a134d54..ba8f141 100644
--- a/be/src/exec/partitioned-hash-join-node.h
+++ b/be/src/exec/partitioned-hash-join-node.h
@@ -98,6 +98,13 @@
 ///      This phase has sub-states (see ProbeState) that are used in GetNext() to drive
 ///      progress.
 ///
+///
+/// TODO: when IMPALA-9156 is implemented, HashJoinState of the builder will drive the
+/// hash join algorithm across all the PartitionedHashJoinNode implementations sharing
+/// the builder. Each PartitionedHashJoinNode implementation will independently execute
+/// its ProbeState state machine, synchronizing via the builder for transitions of the
+/// HashJoinState state machine.
+///
 /// Null aware anti-join (NAAJ) extends the above algorithm by accumulating rows with
 /// NULLs into several different streams, which are processed in a separate step to
 /// produce additional output rows. The NAAJ algorithm is documented in more detail in
@@ -127,30 +134,7 @@
  private:
   class ProbePartition;
 
-  enum class HashJoinState {
-    /// Partitioning the build (right) child's input into the builder's hash partitions.
-    PARTITIONING_BUILD,
-
-    /// Processing the probe (left) child's input, probing hash tables and
-    /// spilling probe rows into 'probe_hash_partitions_' if necessary.
-    PARTITIONING_PROBE,
-
-    /// Processing the spilled probe rows of a single spilled partition
-    /// ('input_partition_') that fits in memory.
-    PROBING_SPILLED_PARTITION,
-
-    /// Repartitioning the build rows of a single spilled partition ('input_partition_')
-    /// into the builder's hash partitions.
-    /// Corresponds to PARTITIONING_BUILD but reading from a spilled partition.
-    REPARTITIONING_BUILD,
-
-    /// Probing the repartitioned hash partitions of a single spilled partition
-    /// ('input_partition_') with the probe rows of that partition.
-    /// Corresponds to PARTITIONING_PROBE but reading from a spilled partition.
-    REPARTITIONING_PROBE,
-  };
-
-  // This enum represents a sub-state of the PARTITIONING_PROBE,
+  // This enum drives a different state machine within the PARTITIONING_PROBE,
   // PROBING_SPILLED_PARTITION and REPARTITIONING_PROBE states.
   // This drives the state machine in GetNext() that processes probe batches and generates
   // output rows. This state machine executes within a HashJoinState state, starting with
@@ -225,11 +209,21 @@
   /// Initialize 'probe_hash_partitions_' and 'hash_tbls_' before probing. One probe
   /// partition is created per spilled build partition, and 'hash_tbls_' is initialized
   /// with pointers to the hash tables of in-memory partitions and NULL pointers for
-  /// spilled or closed partitions.
+  /// spilled or closed partitions. The builder's hash partitions must be initialized
+  /// and present in 'build_hash_partitions_', i.e. the state must be
+  /// PARTITIONING_PROBE or REPARTITIONING_PROBE.
+  ///
+  /// If we are probing a spilled partition (i.e. the state is REPARTITIONING_PROBE), this
+  /// also prepares 'input_partition_' for reading.
+  ///
   /// Called after the builder has partitioned the build rows and built hash tables,
   /// either in the initial build step, or after repartitioning a spilled partition.
   /// After this function returns, all partitions are ready to process probe rows.
-  Status PrepareForProbe() WARN_UNUSED_RESULT;
+  Status PrepareForPartitionedProbe();
+
+  /// Initialize 'hash_tbls_' and 'input_partition_' so that we can read probe rows
+  /// from 'input_partition_' and probe 'hash_tbls_'.
+  Status PrepareForUnpartitionedProbe();
 
   // Initialize 'probe_hash_partitions_'. Each spilled build partition gets a
   // corresponding probe partition. Closed or in-memory build partitions do
@@ -442,39 +436,45 @@
   ///    unmatched rows.
   ///  - If the build partition did not have a hash table, meaning both build and probe
   ///    rows were spilled, move the partition to 'spilled_partitions_'.
+  /// Also cleans up 'input_partition_' (if processing a spilled partition).
   Status DoneProbing(RuntimeState* state, RowBatch* batch) WARN_UNUSED_RESULT;
 
   /// Get the next row batch from the probe (left) side (child(0)), if we are still
   /// doing the first pass over the input (i.e. state_ is PARTITIONING_PROBE) or
   /// from the spilled 'input_partition_' if state_ is REPARTITIONING_PROBE.
   //. If we are done consuming the input, sets 'probe_batch_pos_' to -1, otherwise,
-  /// sets it to 0.  'probe_state_' must be PROBING_END_BATCH.
+  /// sets it to 0.  'probe_state_' must be PROBING_END_BATCH. *eos is true iff
+  /// 'out_batch' contains the last rows from the child or spilled partition.
   Status NextProbeRowBatch(
-      RuntimeState* state, RowBatch* out_batch) WARN_UNUSED_RESULT;
+      RuntimeState* state, RowBatch* out_batch, bool* eos) WARN_UNUSED_RESULT;
 
   /// Get the next row batch from the probe (left) side (child(0)). If we are done
   /// consuming the input, sets 'probe_batch_pos_' to -1, otherwise, sets it to 0.
-  /// 'probe_state_' must be PROBING_END_BATCH.
-  Status NextProbeRowBatchFromChild(
-      RuntimeState* state, RowBatch* out_batch) WARN_UNUSED_RESULT;
+  /// 'probe_state_' must be PROBING_END_BATCH. *eos is true iff 'out_batch'
+  /// contains the last rows from the child.
+  Status NextProbeRowBatchFromChild(RuntimeState* state, RowBatch* out_batch, bool* eos);
 
   /// Get the next probe row batch from 'input_partition_'. If we are done consuming the
   /// input, sets 'probe_batch_pos_' to -1, otherwise, sets it to 0.
-  /// 'probe_state_' must be PROBING_END_BATCH.
-  Status NextSpilledProbeRowBatch(
-      RuntimeState* state, RowBatch* out_batch) WARN_UNUSED_RESULT;
+  /// 'probe_state_' must be PROBING_END_BATCH. *eos is true iff 'out_batch'
+  /// contains the last rows from 'input_partition_'.
+  Status NextSpilledProbeRowBatch(RuntimeState* state, RowBatch* out_batch, bool* eos);
 
   /// Called when 'probe_state_' is PROBE_COMPLETE to start processing the next spilled
   /// partition. This function sets 'input_partition_' to the chosen partition, then
-  /// processes the entire build side of 'input_partition_'. When this function returns
-  /// function returns, we are ready to consume probe rows in 'input_partition_'.
-  /// If the build side's hash table fits in memory and there are probe rows, we will
-  /// construct input_partition_'s hash table. If it does not fit, meaning we need to
-  /// repartition, this function will repartition the build rows into
-  /// 'builder->hash_partitions_' and prepare for repartitioning the partition's probe
+  /// delegates to 'builder_' to bring all or part of the spilled build side into
+  /// memory and sets up this node to probe the partition.
+  ///
+  /// If the build side's hash table fits in memory and there are probe rows, then there
+  /// will be a single in-memory partition. If it does not fit, meaning we need to
+  /// repartition, this function will repartition the build rows into PARTITION_FANOUT
+  /// hash partitions and prepare for repartitioning the partition's probe
   /// rows. If there are no probe rows, we just prepare the build side to be read by
   /// OutputUnmatchedBuild().
-  Status PrepareSpilledPartitionForProbe() WARN_UNUSED_RESULT;
+  ///
+  /// When this function returns, we are ready to start reading probe
+  /// rows from 'input_partition_'.
+  Status BeginSpilledProbe() WARN_UNUSED_RESULT;
 
   /// Construct an error status for the null-aware anti-join when it could not fit 'rows'
   /// from the build side in memory.
@@ -498,12 +498,6 @@
   Status CodegenProcessProbeBatch(
       LlvmCodeGen* codegen, TPrefetchMode::type prefetch_mode) WARN_UNUSED_RESULT;
 
-  /// Returns the current state of the partition as a string.
-  std::string PrintState() const;
-
-  /// Updates 'state_' to 'next_state', logging the transition.
-  void UpdateState(HashJoinState next_state);
-
   std::string NodeDebugString() const;
 
   RuntimeState* runtime_state_;
@@ -536,32 +530,28 @@
   /// Time spent evaluating other_join_conjuncts for NAAJ.
   RuntimeProfile::Counter* null_aware_eval_timer_ = nullptr;
 
-  /// Number of partitions which had zero probe rows and we therefore didn't build the
-  /// hash table.
-  RuntimeProfile::Counter* num_hash_table_builds_skipped_ = nullptr;
-
   /////////////////////////////////////////
   /// BEGIN: Members that must be Reset()
 
-  /// State of the partitioned hash join algorithm. See HashJoinState for more
-  /// information.
-  HashJoinState state_ = HashJoinState::PARTITIONING_BUILD;
-
   /// State of the probing algorithm. Used to drive the state machine in GetNext().
   ProbeState probe_state_ = ProbeState::PROBE_COMPLETE;
 
   /// The build-side of the join. Initialized in Init().
   boost::scoped_ptr<PhjBuilder> builder_;
 
-  /// Cache of the per partition hash table to speed up ProcessProbeBatch.
+  /// Last set of hash partitions obtained from builder_. Only valid when the
+  /// builder's state is PARTITIONING_PROBE or REPARTITIONING_PROBE.
+  PhjBuilder::HashPartitions build_hash_partitions_;
+
+  /// Cache of the per partition hash table to speed up ProcessProbeBatch().
   /// In the case where we need to partition the probe:
-  ///  hash_tbls_[i] = builder_->hash_partitions_[i]->hash_tbl();
+  ///  hash_tbls_[i] = (*build_hash_partitions_.hash_partitions)[i]->hash_tbl();
   /// In the case where we don't need to partition the probe:
   ///  hash_tbls_[i] = input_partition_->hash_tbl();
   HashTable* hash_tbls_[PARTITION_FANOUT];
 
   /// Probe partitions, with indices corresponding to the build partitions in
-  /// builder_->hash_partitions(). This is non-empty only in the PARTITIONING_PROBE or
+  /// build_hash_partitions_. This is non-empty only in the PARTITIONING_PROBE or
   /// REPARTITIONING_PROBE states, in which case it has NULL entries for in-memory
   /// build partitions and non-NULL entries for spilled build partitions (so that we
   /// have somewhere to spill the probe rows for the spilled partition).