KUDU-2671: Restore custom hash schemas properly

Before this patch, ranges with custom hash schemas were not
restored properly: the table-wide hash schema was incorrectly
applied to those ranges.
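
For example, a range created with its own hash schema, as in the
sketch below (lower, upper and createOptions are illustrative; the
API calls are the ones exercised by this patch), would previously
come back from a restore with the table-wide hash schema instead of
its custom one:

  val range = new RangePartitionWithCustomHashSchema(
    lower, upper,
    RangePartitionBound.INCLUSIVE_BOUND,
    RangePartitionBound.EXCLUSIVE_BOUND)
  range.addHashPartitions(List("key").asJava, 5, 0)
  createOptions.addRangePartition(range)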

Change-Id: I8c28b306f2b630a609231a8fb2a5f5652b028d8e
Reviewed-on: http://gerrit.cloudera.org:8080/18791
Reviewed-by: Alexey Serbin <alexey@apache.org>
Reviewed-by: Abhishek Chennaka <achennaka@cloudera.com>
Tested-by: Kudu Jenkins
diff --git a/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala b/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
index 501b124..3c205be 100644
--- a/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
+++ b/java/kudu-backup-common/src/main/scala/org/apache/kudu/backup/TableMetadata.scala
@@ -189,7 +189,7 @@
     }
 
     val bounds = table
-      .getRangePartitions(table.getAsyncClient.getDefaultOperationTimeoutMs)
+      .getRangePartitionsWithTableHashSchema(table.getAsyncClient.getDefaultOperationTimeoutMs)
       .asScala
       .map { p =>
         val lowerValues = getBoundValues(p.getDecodedRangeKeyStart(table), columnNames, tableSchema)
@@ -392,6 +392,22 @@
     }
   }
 
+  def getRangeBoundsPartialRowsWithHashSchemas(
+      metadata: TableMetadataPB): Seq[RangeWithHashSchema] = {
+    val schema = getKuduSchema(metadata)
+    metadata.getPartitions.getRangeAndHashPartitionsList.asScala.map { rhp =>
+      val hashSchemas = rhp.getHashPartitionsList.asScala.map { hp =>
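+        // NOTE: these are column indexes, not column IDs, despite the name of
+        // the HashBucketSchema parameter; the restore path converts them back
+        // to column names via Schema.getColumnByIndex().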
+        val colIds = hp.getColumnNamesList.asScala.map { name =>
+          Integer.valueOf(schema.getColumnIndex(schema.getColumnId(name)))
+        }
+        new HashBucketSchema(colIds.asJava, hp.getNumBuckets, hp.getSeed)
+      }
+      val lower = getPartialRow(rhp.getBounds.getLowerBoundsList.asScala, schema)
+      val upper = getPartialRow(rhp.getBounds.getUpperBoundsList.asScala, schema)
+      new RangeWithHashSchema(lower, upper, hashSchemas.asJava)
+    }
+  }
+
   def getPartitionSchema(metadata: TableMetadataPB): PartitionSchema = {
     val colNameToId = metadata.getColumnIdsMap.asScala
     val schema = getKuduSchema(metadata)
diff --git a/java/kudu-backup/src/main/scala/org/apache/kudu/backup/KuduRestore.scala b/java/kudu-backup/src/main/scala/org/apache/kudu/backup/KuduRestore.scala
index ced54fb..87c2d84 100644
--- a/java/kudu-backup/src/main/scala/org/apache/kudu/backup/KuduRestore.scala
+++ b/java/kudu-backup/src/main/scala/org/apache/kudu/backup/KuduRestore.scala
@@ -21,6 +21,8 @@
 import org.apache.kudu.client.AlterTableOptions
 import org.apache.kudu.client.KuduPartitioner
 import org.apache.kudu.client.Partition
+import org.apache.kudu.client.RangePartitionBound
+import org.apache.kudu.client.RangePartitionWithCustomHashSchema
 import org.apache.kudu.client.SessionConfiguration.FlushMode
 import org.apache.kudu.spark.kudu.KuduContext
 import org.apache.kudu.spark.kudu.RowConverter
@@ -263,20 +265,79 @@
     // Create the table with the first range partition (or none if there are none).
     val schema = TableMetadata.getKuduSchema(metadata)
     val options = TableMetadata.getCreateTableOptionsWithoutRangePartitions(metadata, restoreOwner)
-    val bounds = TableMetadata.getRangeBoundPartialRows(metadata)
-    bounds.headOption.foreach(bound => {
-      val (lower, upper) = bound
-      options.addRangePartition(lower, upper)
-    })
-    context.createTable(restoreName, schema, options)
+    // The range bounds of the ranges that use the table-wide hash schema.
+    val boundsWithoutHashSchema = TableMetadata.getRangeBoundPartialRows(metadata)
+    // The range bounds and hash schemas of the ranges that have a custom hash schema.
+    val boundsWithCustomHashSchema =
+      TableMetadata.getRangeBoundsPartialRowsWithHashSchemas(metadata)
+    if (boundsWithoutHashSchema.nonEmpty) {
+      // Add the first range partition with the table-wide hash schema through create.
+      boundsWithoutHashSchema.headOption.foreach(bound => {
+        val (lower, upper) = bound
+        options.addRangePartition(lower, upper)
+      })
+      context.createTable(restoreName, schema, options)
 
-    // Add the rest of the range partitions through alters.
-    bounds.tail.foreach(bound => {
-      val (lower, upper) = bound
-      val options = new AlterTableOptions()
-      options.addRangePartition(lower, upper)
-      context.syncClient.alterTable(restoreName, options)
-    })
+      // Add the rest of the range partitions with the table-wide hash schema through alters.
+      boundsWithoutHashSchema.tail.foreach(bound => {
+        val (lower, upper) = bound
+        val options = new AlterTableOptions()
+        options.addRangePartition(lower, upper)
+        context.syncClient.alterTable(restoreName, options)
+      })
+
+      // Add the range partitions with custom hash schemas through alters.
+      boundsWithCustomHashSchema.foreach(bound => {
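+        // The bounds recovered from the backup metadata are assumed to be in
+        // the canonical [inclusive lower, exclusive upper) form.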
+        val rangePartition = new RangePartitionWithCustomHashSchema(
+          bound.lowerBound,
+          bound.upperBound,
+          RangePartitionBound.INCLUSIVE_BOUND,
+          RangePartitionBound.EXCLUSIVE_BOUND)
+        bound.hashSchemas.asScala.foreach { hp =>
+          val columnNames = hp.getColumnIds.asScala.map { id =>
+            schema.getColumnByIndex(id).getName
+          }
+          rangePartition.addHashPartitions(columnNames.asJava, hp.getNumBuckets, hp.getSeed)
+        }
+        val options = new AlterTableOptions()
+        options.addRangePartition(rangePartition)
+        context.syncClient.alterTable(restoreName, options)
+      })
+    } else if (boundsWithCustomHashSchema.nonEmpty) {
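+      // No ranges with the table-wide hash schema exist, so the table must be
+      // seeded with a custom-hash-schema range at create time.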
+      // Add the first range partition with a custom hash schema through create.
+      boundsWithCustomHashSchema.headOption.foreach(bound => {
+        val rangePartition = new RangePartitionWithCustomHashSchema(
+          bound.lowerBound,
+          bound.upperBound,
+          RangePartitionBound.INCLUSIVE_BOUND,
+          RangePartitionBound.EXCLUSIVE_BOUND)
+        bound.hashSchemas.asScala.foreach { hp =>
+          val columnNames = hp.getColumnIds.asScala.map { id =>
+            schema.getColumnByIndex(id).getName
+          }
+          rangePartition.addHashPartitions(columnNames.asJava, hp.getNumBuckets, hp.getSeed)
+        }
+        options.addRangePartition(rangePartition)
+      })
+      context.createTable(restoreName, schema, options)
+      // Add the rest of the range partitions with custom hash schemas through alters.
+      boundsWithCustomHashSchema.tail.foreach(bound => {
+        val rangePartition = new RangePartitionWithCustomHashSchema(
+          bound.lowerBound,
+          bound.upperBound,
+          RangePartitionBound.INCLUSIVE_BOUND,
+          RangePartitionBound.EXCLUSIVE_BOUND)
+        bound.hashSchemas.asScala.foreach { hp =>
+          val columnNames = hp.getColumnIds.asScala.map { id =>
+            schema.getColumnByIndex(id).getName
+          }
+          rangePartition.addHashPartitions(columnNames.asJava, hp.getNumBuckets, hp.getSeed)
+        }
+        val options = new AlterTableOptions()
+        options.addRangePartition(rangePartition)
+        context.syncClient.alterTable(restoreName, options)
+      })
+    }
   }
 
   /**
diff --git a/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala b/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
index 40d00c7..f1421b5 100644
--- a/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
+++ b/java/kudu-backup/src/test/scala/org/apache/kudu/backup/TestKuduBackup.scala
@@ -590,35 +590,116 @@
   }
 
   @Test
-  def testTableWithCustomHashSchemas(): Unit = {
+  def testTableWithOnlyCustomHashSchemas(): Unit = {
     // Create the initial table and load it with data.
-    val tableName = "testTableWithCustomHashSchemas"
-    var table = kuduClient.createTable(tableName, schema, tableOptionsWithCustomHashSchema)
+    val tableName = "testTableWithOnlyCustomHashSchemas"
+    val table = kuduClient.createTable(tableName, schema, tableOptionsWithCustomHashSchema)
     insertRows(table, 100)
 
     // Run and validate initial backup.
     backupAndValidateTable(tableName, 100, false)
 
-    // Rename the table and insert more rows
-    val newTableName = "impala::default.testTableWithCustomHashSchemas"
-    kuduClient.alterTable(tableName, new AlterTableOptions().renameTable(newTableName))
-    table = kuduClient.openTable(newTableName)
+    // Insert rows then run and validate an incremental backup.
     insertRows(table, 100, 100)
+    backupAndValidateTable(tableName, 100, true)
 
-    // Run and validate an incremental backup.
-    backupAndValidateTable(newTableName, 100, true)
+    // Restore the table and check the row count.
+    restoreAndValidateTable(tableName, 200)
 
-    // Create a new table with the old name.
-    val tableWithOldName =
-      kuduClient.createTable(tableName, schema, tableOptionsWithCustomHashSchema)
-    insertRows(tableWithOldName, 50)
+    // Check the range bounds and the hash schema of each range of the restored table.
+    val restoredTable = kuduClient.openTable(s"$tableName-restore")
+    assertEquals(
+        "[0 <= VALUES < 100 HASH(key) PARTITIONS 2, " +
+        "100 <= VALUES < 200 HASH(key) PARTITIONS 3]",
+      restoredTable.getFormattedRangePartitionsWithHashSchema(10000).toString
+    )
+  }
 
-    // Backup the table with the old name.
-    backupAndValidateTable(tableName, 50, false)
+  @Test
+  def testTableWithTableAndCustomHashSchemas(): Unit = {
+    // Create the initial table and load it with data.
+    val tableName = "testTableWithTableAndCustomHashSchemas"
+    val table = kuduClient.createTable(tableName, schema, tableOptionsWithTableAndCustomHashSchema)
+    insertRows(table, 100)
 
-    // Restore the tables and check the row counts.
-    restoreAndValidateTable(newTableName, 200)
-    restoreAndValidateTable(tableName, 50)
+    // Run and validate initial backup.
+    backupAndValidateTable(tableName, 100, false)
+
+    // Insert rows then run and validate an incremental backup.
+    insertRows(table, 200, 100)
+    backupAndValidateTable(tableName, 200, true)
+
+    // Restore the table and check the row count.
+    restoreAndValidateTable(tableName, 300)
+
+    // Check the range bounds and the hash schema of each range of the restored table.
+    val restoredTable = kuduClient.openTable(s"$tableName-restore")
+    assertEquals(
+        "[0 <= VALUES < 100 HASH(key) PARTITIONS 2, " +
+        "100 <= VALUES < 200 HASH(key) PARTITIONS 3, " +
+        "200 <= VALUES < 300 HASH(key) PARTITIONS 4]",
+      restoredTable.getFormattedRangePartitionsWithHashSchema(10000).toString
+    )
+  }
+
+  @Test
+  def testTableAlterWithTableAndCustomHashSchemas(): Unit = {
+    // Create the initial table and load it with data.
+    val tableName = "testTableAlterWithTableAndCustomHashSchemas"
+    var table = kuduClient.createTable(tableName, schema, tableOptionsWithTableAndCustomHashSchema)
+    insertRows(table, 100)
+
+    // Run and validate initial backup.
+    backupAndValidateTable(tableName, 100, false)
+
+    // Insert rows then run and validate an incremental backup.
+    insertRows(table, 200, 100)
+    backupAndValidateTable(tableName, 200, true)
+
+    // Drop the range partition with the table-wide hash schema, re-add the same
+    // range with a custom hash schema, and add another range partition with a
+    // custom hash schema, all in a single alter.
+    val twoHundred = createPartitionRow(200)
+    val threeHundred = createPartitionRow(300)
+    val fourHundred = createPartitionRow(400)
+    val newPartition = new RangePartitionWithCustomHashSchema(
+      twoHundred,
+      threeHundred,
+      RangePartitionBound.INCLUSIVE_BOUND,
+      RangePartitionBound.EXCLUSIVE_BOUND)
+    newPartition.addHashPartitions(List("key").asJava, 5, 0)
+    val newPartition1 = new RangePartitionWithCustomHashSchema(
+      threeHundred,
+      fourHundred,
+      RangePartitionBound.INCLUSIVE_BOUND,
+      RangePartitionBound.EXCLUSIVE_BOUND)
+    newPartition1.addHashPartitions(List("key").asJava, 6, 0)
+    kuduClient.alterTable(
+      tableName,
+      new AlterTableOptions()
+        .dropRangePartition(twoHundred, threeHundred)
+        .addRangePartition(newPartition)
+        .addRangePartition(newPartition1))
+
+    // TODO: Avoid this table refresh by updating partition schema after alter table calls.
+    // See https://issues.apache.org/jira/browse/KUDU-3388 for more details.
+    table = kuduClient.openTable(tableName)
+
+    // Insert rows then run and validate an incremental backup.
+    insertRows(table, 100, 300)
+    backupAndValidateTable(tableName, 100, true)
+
+    // Restore the table and validate.
+    assertTrue(runRestore(createRestoreOptions(Seq(tableName))))
+
+    // Check the range bounds and the hash schema of each range of the restored table.
+    val restoredTable = kuduClient.openTable(s"$tableName-restore")
+    assertEquals(
+        "[0 <= VALUES < 100 HASH(key) PARTITIONS 2, " +
+        "100 <= VALUES < 200 HASH(key) PARTITIONS 3, " +
+        "200 <= VALUES < 300 HASH(key) PARTITIONS 5, " +
+        "300 <= VALUES < 400 HASH(key) PARTITIONS 6]",
+      restoredTable.getFormattedRangePartitionsWithHashSchema(10000).toString
+    )
   }
 
   @Test
diff --git a/java/kudu-client/src/main/java/org/apache/kudu/client/KuduTable.java b/java/kudu-client/src/main/java/org/apache/kudu/client/KuduTable.java
index c5a7c26..5dd06f7 100644
--- a/java/kudu-client/src/main/java/org/apache/kudu/client/KuduTable.java
+++ b/java/kudu-client/src/main/java/org/apache/kudu/client/KuduTable.java
@@ -340,6 +340,32 @@
   public List<Partition> getRangePartitions(long timeout) throws Exception {
     // TODO: This could be moved into the RangeSchemaPB returned from server
     // to avoid an extra call to get the range partitions.
+    return getRangePartitionsHelper(timeout, false);
+  }
+
+  /**
+   * Retrieves only those of this table's range partitions that use the table-wide hash
+   * schema. The range partitions will be returned in sorted order by value and will
+   * contain no duplicates.
+   *
+   * @param timeout the timeout of the operation
+   * @return a list of the matching range partitions
+   */
+  @InterfaceAudience.Private
+  @InterfaceStability.Unstable
+  public List<Partition> getRangePartitionsWithTableHashSchema(long timeout) throws Exception {
+    return getRangePartitionsHelper(timeout, true);
+  }
+
+  /**
+   * Helper method that retrieves the table's range partitions. If onlyTableHashSchema is
+   * true, only the range partitions that use the table-wide hash schema are returned. The
+   * range partitions will be returned in sorted order by value and will contain no duplicates.
+   *
+   * @param timeout the timeout of the operation
+   * @param onlyTableHashSchema whether to filter out partitions with custom hash schemas
+   * @return a list of the range partitions
+   */
+  private List<Partition> getRangePartitionsHelper(long timeout,
+                                                   boolean onlyTableHashSchema) throws Exception {
     List<Partition> rangePartitions = new ArrayList<>();
     List<KuduScanToken> scanTokens = new KuduScanToken.KuduScanTokenBuilder(client, this)
         .setTimeout(timeout)
@@ -351,6 +377,12 @@
       if (!Iterators.all(partition.getHashBuckets().iterator(), Predicates.equalTo(0))) {
         continue;
       }
+      // If onlyTableHashSchema is true, filter out any partitions
+      // that are part of a range that contains a custom hash schema.
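+      // This relies on getHashSchemaForRange() returning the same list instance
+      // as getHashBucketSchemas() when the range uses the table-wide hash
+      // schema, so reference inequality is sufficient here.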
+      if (onlyTableHashSchema && partitionSchema.getHashSchemaForRange(partition.rangeKeyStart) !=
+          partitionSchema.getHashBucketSchemas()) {
+        continue;
+      }
       rangePartitions.add(partition);
     }
     return rangePartitions;
diff --git a/java/kudu-client/src/main/java/org/apache/kudu/client/RangePartition.java b/java/kudu-client/src/main/java/org/apache/kudu/client/RangePartition.java
index 903f46a..ff633f6 100644
--- a/java/kudu-client/src/main/java/org/apache/kudu/client/RangePartition.java
+++ b/java/kudu-client/src/main/java/org/apache/kudu/client/RangePartition.java
@@ -26,9 +26,9 @@
  *
  * See also RangePartitionWithCustomHashSchema.
  */
-@InterfaceAudience.Private
+@InterfaceAudience.LimitedPrivate({"kudu-backup", "Test"})
 @InterfaceStability.Evolving
-class RangePartition {
+public class RangePartition {
   final PartialRow lowerBound;
   final PartialRow upperBound;
   final RangePartitionBound lowerBoundType;
diff --git a/java/kudu-client/src/test/java/org/apache/kudu/client/TestKuduTable.java b/java/kudu-client/src/test/java/org/apache/kudu/client/TestKuduTable.java
index 68c5b18..ef3656b 100644
--- a/java/kudu-client/src/test/java/org/apache/kudu/client/TestKuduTable.java
+++ b/java/kudu-client/src/test/java/org/apache/kudu/client/TestKuduTable.java
@@ -1641,6 +1641,89 @@
   }
 
   @Test(timeout = 100000)
+  public void testGetRangePartitionsWithTableHashSchema() throws Exception {
+    // The test table is created with the following ranges:
+    //   (-inf, -100), [-100, 0), [0, 100), [100, +inf)
+
+    CreateTableOptions builder = getBasicCreateTableOptions();
+    // Add a table-wide hash schema with one dimension and two buckets.
+    builder.addHashPartitions(ImmutableList.of("key"), 2, 0);
+
+    // Add range partition with custom hash schema: (-inf, -100)
+    {
+      PartialRow lower = basicSchema.newPartialRow();
+      PartialRow upper = basicSchema.newPartialRow();
+      upper.addInt(0, -100);
+
+      RangePartitionWithCustomHashSchema rangePartition =
+          new RangePartitionWithCustomHashSchema(
+              lower,
+              upper,
+              RangePartitionBound.INCLUSIVE_BOUND,
+              RangePartitionBound.EXCLUSIVE_BOUND);
+      rangePartition.addHashPartitions(ImmutableList.of("key"), 2, 1);
+
+      builder.addRangePartition(rangePartition);
+    }
+
+    // Add range partition with table-wide hash schema: [-100, 0)
+    {
+      PartialRow lower = basicSchema.newPartialRow();
+      lower.addInt(0, -100);
+      PartialRow upper = basicSchema.newPartialRow();
+      upper.addInt(0, 0);
+
+      builder.addRangePartition(lower, upper);
+    }
+
+    // Add range partition with custom hash schema: [0, 100)
+    {
+      PartialRow lower = basicSchema.newPartialRow();
+      lower.addInt(0, 0);
+      PartialRow upper = basicSchema.newPartialRow();
+      upper.addInt(0, 100);
+
+      RangePartitionWithCustomHashSchema rangePartition =
+          new RangePartitionWithCustomHashSchema(
+              lower,
+              upper,
+              RangePartitionBound.INCLUSIVE_BOUND,
+              RangePartitionBound.EXCLUSIVE_BOUND);
+      rangePartition.addHashPartitions(ImmutableList.of("key"), 5, 0);
+
+      builder.addRangePartition(rangePartition);
+    }
+
+    // Add range partition with table-wide hash schema: [100, +inf)
+    {
+      PartialRow lower = basicSchema.newPartialRow();
+      lower.addInt(0, 100);
+      PartialRow upper = basicSchema.newPartialRow();
+
+      builder.addRangePartition(lower, upper);
+    }
+
+    final KuduTable table = client.createTable(tableName, basicSchema, builder);
+    List<Partition> rangePartitions =
+        table.getRangePartitionsWithTableHashSchema(client.getDefaultOperationTimeoutMs());
+    assertEquals(2, rangePartitions.size());
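+    // Only the [-100, 0) and [100, +inf) ranges use the table-wide hash schema.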
+
+    Partition lowerPartition = rangePartitions.get(0);
+    assertTrue(lowerPartition.getRangeKeyStart().length > 0);
+    assertTrue(lowerPartition.getRangeKeyEnd().length > 0);
+    PartialRow decodedLower = lowerPartition.getDecodedRangeKeyStart(table);
+    assertEquals(-100, decodedLower.getInt("key"));
+    PartialRow decodedUpper = lowerPartition.getDecodedRangeKeyEnd(table);
+    assertEquals(0, decodedUpper.getInt("key"));
+
+    Partition upperPartition = rangePartitions.get(1);
+    assertTrue(upperPartition.getRangeKeyStart().length > 0);
+    assertEquals(0, upperPartition.getRangeKeyEnd().length);
+    PartialRow decodedLowerKey = upperPartition.getDecodedRangeKeyStart(table);
+    assertEquals(100, decodedLowerKey.getInt("key"));
+  }
+
+  @Test(timeout = 100000)
   public void testAlterNoWait() throws Exception {
     client.createTable(tableName, basicSchema, getBasicCreateTableOptions());
 
diff --git a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
index fdecc22..e30dbf0 100644
--- a/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
+++ b/java/kudu-spark/src/test/scala/org/apache/kudu/spark/kudu/KuduTestSuite.scala
@@ -125,7 +125,7 @@
     val bottom = schema.newPartialRow()
     bottom.addInt("key", 0)
     val middle = schema.newPartialRow()
-    middle.addInt("key", 50)
+    middle.addInt("key", 100)
     val top = schema.newPartialRow()
     top.addInt("key", 200)
 
@@ -152,6 +152,40 @@
       .setNumReplicas(1)
   }
 
+  val tableOptionsWithTableAndCustomHashSchema: CreateTableOptions = {
+    val lowest = schema.newPartialRow()
+    lowest.addInt("key", 0)
+    val low = schema.newPartialRow()
+    low.addInt("key", 100)
+    val high = schema.newPartialRow()
+    high.addInt("key", 200)
+    val highest = schema.newPartialRow()
+    highest.addInt("key", 300)
+
+    val columns = List("key").asJava
+    val partitionFirst = new RangePartitionWithCustomHashSchema(
+      lowest,
+      low,
+      RangePartitionBound.INCLUSIVE_BOUND,
+      RangePartitionBound.EXCLUSIVE_BOUND)
+    partitionFirst.addHashPartitions(columns, 2, 0)
+    val partitionSecond = new RangePartitionWithCustomHashSchema(
+      low,
+      high,
+      RangePartitionBound.INCLUSIVE_BOUND,
+      RangePartitionBound.EXCLUSIVE_BOUND)
+    partitionSecond.addHashPartitions(columns, 3, 0)
+
+    new CreateTableOptions()
+      .setRangePartitionColumns(columns)
+      .addRangePartition(partitionFirst)
+      .addRangePartition(partitionSecond)
+      .addRangePartition(high, highest)
+      .addHashPartitions(columns, 4, 0)
+      .setOwner(owner)
+      .setNumReplicas(1)
+  }
+
   val appID: String = new Date().toString + math
     .floor(math.random * 10E4)
     .toLong