HBASE-23295 HBaseContext should use most recent delegation token (#47)

Signed-off-by: Balazs Meszaros <meszibalu@apache.org>
diff --git a/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala b/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
index e50a3e8..890e67f 100644
--- a/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
+++ b/spark/hbase-spark/src/main/scala/org/apache/hadoop/hbase/spark/HBaseContext.scala
@@ -65,13 +65,11 @@
                    val tmpHdfsConfgFile: String = null)
   extends Serializable with Logging {
 
-  @transient var credentials = UserGroupInformation.getCurrentUser().getCredentials()
   @transient var tmpHdfsConfiguration:Configuration = config
   @transient var appliedCredentials = false
   @transient val job = Job.getInstance(config)
   TableMapReduceUtil.initCredentials(job)
   val broadcastedConf = sc.broadcast(new SerializableWritable(config))
-  val credentialsConf = sc.broadcast(new SerializableWritable(job.getCredentials))
 
   LatestHBaseContextCache.latest = this
 
@@ -233,21 +231,12 @@
   }
 
   def applyCreds[T] (){
-    credentials = UserGroupInformation.getCurrentUser().getCredentials()
-
-    if (log.isDebugEnabled) {
-      logDebug("appliedCredentials:" + appliedCredentials + ",credentials:" + credentials)
-    }
-
-    if (!appliedCredentials && credentials != null) {
+    if (!appliedCredentials) {
       appliedCredentials = true
 
       @transient val ugi = UserGroupInformation.getCurrentUser
-      ugi.addCredentials(credentials)
       // specify that this is a proxy user
       ugi.setAuthenticationMethod(AuthenticationMethod.PROXY)
-
-      ugi.addCredentials(credentialsConf.value.value)
     }
   }
 
diff --git a/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java b/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
index 4134ee6..865a3a3 100644
--- a/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
+++ b/spark/hbase-spark/src/test/java/org/apache/hadoop/hbase/spark/TestJavaHBaseContext.java
@@ -52,8 +52,10 @@
 import org.apache.spark.api.java.JavaSparkContext;
 import org.apache.spark.api.java.function.Function;
 import org.junit.After;
+import org.junit.AfterClass;
 import org.junit.Assert;
 import org.junit.Before;
+import org.junit.BeforeClass;
 import org.junit.ClassRule;
 import org.junit.Test;
 import org.junit.experimental.categories.Category;
@@ -70,11 +72,10 @@
   public static final HBaseClassTestRule TIMEOUT =
       HBaseClassTestRule.forClass(TestJavaHBaseContext.class);
 
-  private transient JavaSparkContext jsc;
-  HBaseTestingUtility htu;
-  protected static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class);
-
-
+  private static transient JavaSparkContext JSC;
+  private static HBaseTestingUtility TEST_UTIL;
+  private static JavaHBaseContext HBASE_CONTEXT;
+  private static final Logger LOG = LoggerFactory.getLogger(TestJavaHBaseContext.class);
 
   byte[] tableName = Bytes.toBytes("t1");
   byte[] columnFamily = Bytes.toBytes("c");
@@ -82,56 +83,57 @@
   String columnFamilyStr = Bytes.toString(columnFamily);
   String columnFamilyStr1 = Bytes.toString(columnFamily1);
 
+  @BeforeClass
+  public static void setUpBeforeClass() throws Exception {
+
+    JSC = new JavaSparkContext("local", "JavaHBaseContextSuite");
+    TEST_UTIL = new HBaseTestingUtility();
+    Configuration conf = TEST_UTIL.getConfiguration();
+
+    HBASE_CONTEXT = new JavaHBaseContext(JSC, conf);
+
+    LOG.info("cleaning up test dir");
+
+    TEST_UTIL.cleanupTestDir();
+
+    LOG.info("starting minicluster");
+
+    TEST_UTIL.startMiniZKCluster();
+    TEST_UTIL.startMiniHBaseCluster(1, 1);
+
+    LOG.info(" - minicluster started");
+  }
+
+  @AfterClass
+  public static void tearDownAfterClass() throws Exception {
+    LOG.info("shuting down minicluster");
+    TEST_UTIL.shutdownMiniHBaseCluster();
+    TEST_UTIL.shutdownMiniZKCluster();
+    LOG.info(" - minicluster shut down");
+    TEST_UTIL.cleanupTestDir();
+
+    JSC.stop();
+    JSC = null;
+  }
 
   @Before
-  public void setUp() {
-    jsc = new JavaSparkContext("local", "JavaHBaseContextSuite");
+  public void setUp() throws Exception {
 
-    File tempDir = Files.createTempDir();
-    tempDir.deleteOnExit();
-
-    htu = new HBaseTestingUtility();
     try {
-      LOG.info("cleaning up test dir");
-
-      htu.cleanupTestDir();
-
-      LOG.info("starting minicluster");
-
-      htu.startMiniZKCluster();
-      htu.startMiniHBaseCluster(1, 1);
-
-      LOG.info(" - minicluster started");
-
-      try {
-        htu.deleteTable(TableName.valueOf(tableName));
-      } catch (Exception e) {
-        LOG.info(" - no table " + Bytes.toString(tableName) + " found");
-      }
-
-      LOG.info(" - creating table " + Bytes.toString(tableName));
-      htu.createTable(TableName.valueOf(tableName),
-          new byte[][]{columnFamily, columnFamily1});
-      LOG.info(" - created table");
-    } catch (Exception e1) {
-      throw new RuntimeException(e1);
+      TEST_UTIL.deleteTable(TableName.valueOf(tableName));
+    } catch (Exception e) {
+      LOG.info(" - no table {} found", Bytes.toString(tableName));
     }
+
+    LOG.info(" - creating table {}", Bytes.toString(tableName));
+    TEST_UTIL.createTable(TableName.valueOf(tableName),
+        new byte[][]{columnFamily, columnFamily1});
+    LOG.info(" - created table");
   }
 
   @After
-  public void tearDown() {
-    try {
-      htu.deleteTable(TableName.valueOf(tableName));
-      LOG.info("shuting down minicluster");
-      htu.shutdownMiniHBaseCluster();
-      htu.shutdownMiniZKCluster();
-      LOG.info(" - minicluster shut down");
-      htu.cleanupTestDir();
-    } catch (Exception e) {
-      throw new RuntimeException(e);
-    }
-    jsc.stop();
-    jsc = null;
+  public void tearDown() throws Exception {
+      TEST_UTIL.deleteTable(TableName.valueOf(tableName));
   }
 
   @Test
@@ -144,11 +146,9 @@
     list.add("4," + columnFamilyStr + ",a,4");
     list.add("5," + columnFamilyStr + ",a,5");
 
-    JavaRDD<String> rdd = jsc.parallelize(list);
+    JavaRDD<String> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
-
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+    Configuration conf = TEST_UTIL.getConfiguration();
 
     Connection conn = ConnectionFactory.createConnection(conf);
     Table table = conn.getTable(TableName.valueOf(tableName));
@@ -163,7 +163,7 @@
       table.close();
     }
 
-    hbaseContext.bulkPut(rdd,
+    HBASE_CONTEXT.bulkPut(rdd,
             TableName.valueOf(tableName),
             new PutFunction());
 
@@ -212,15 +212,13 @@
     list.add(Bytes.toBytes("2"));
     list.add(Bytes.toBytes("3"));
 
-    JavaRDD<byte[]> rdd = jsc.parallelize(list);
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = TEST_UTIL.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
-    hbaseContext.bulkDelete(rdd, TableName.valueOf(tableName),
+    HBASE_CONTEXT.bulkDelete(rdd, TableName.valueOf(tableName),
             new JavaHBaseBulkDeleteExample.DeleteFunction(), 2);
 
 
@@ -248,17 +246,15 @@
 
   @Test
   public void testDistributedScan() throws IOException {
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = TEST_UTIL.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
     Scan scan = new Scan();
     scan.setCaching(100);
 
     JavaRDD<String> javaRdd =
-            hbaseContext.hbaseRDD(TableName.valueOf(tableName), scan)
+            HBASE_CONTEXT.hbaseRDD(TableName.valueOf(tableName), scan)
                     .map(new ScanConvertFunction());
 
     List<String> results = javaRdd.collect();
@@ -283,16 +279,14 @@
     list.add(Bytes.toBytes("4"));
     list.add(Bytes.toBytes("5"));
 
-    JavaRDD<byte[]> rdd = jsc.parallelize(list);
+    JavaRDD<byte[]> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
+    Configuration conf = TEST_UTIL.getConfiguration();
 
     populateTableWithMockData(conf, TableName.valueOf(tableName));
 
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
-
     final JavaRDD<String> stringJavaRDD =
-            hbaseContext.bulkGet(TableName.valueOf(tableName), 2, rdd,
+            HBASE_CONTEXT.bulkGet(TableName.valueOf(tableName), 2, rdd,
             new GetFunction(),
             new ResultFunction());
 
@@ -302,7 +296,7 @@
   @Test
   public void testBulkLoad() throws Exception {
 
-    Path output = htu.getDataTestDir("testBulkLoad");
+    Path output = TEST_UTIL.getDataTestDir("testBulkLoad");
     // Add cell as String: "row,falmily,qualifier,value"
     List<String> list= new ArrayList<String>();
     // row1
@@ -315,14 +309,11 @@
     list.add("2," + columnFamilyStr + ",a,3");
     list.add("2," + columnFamilyStr + ",b,3");
 
-    JavaRDD<String> rdd = jsc.parallelize(list);
+    JavaRDD<String> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+    Configuration conf = TEST_UTIL.getConfiguration();
 
-
-
-    hbaseContext.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(),
+    HBASE_CONTEXT.bulkLoad(rdd, TableName.valueOf(tableName), new BulkLoadFunction(),
             output.toUri().getPath(), new HashMap<byte[], FamilyHFileWriteOptions>(), false,
             HConstants.DEFAULT_MAX_FILE_SIZE);
 
@@ -369,7 +360,7 @@
 
   @Test
   public void testBulkLoadThinRows() throws Exception {
-    Path output = htu.getDataTestDir("testBulkLoadThinRows");
+    Path output = TEST_UTIL.getDataTestDir("testBulkLoadThinRows");
     // because of the limitation of scala bulkLoadThinRows API
     // we need to provide data as <row, all cells in that row>
     List<List<String>> list= new ArrayList<List<String>>();
@@ -389,12 +380,11 @@
     list2.add("2," + columnFamilyStr + ",b,3");
     list.add(list2);
 
-    JavaRDD<List<String>> rdd = jsc.parallelize(list);
+    JavaRDD<List<String>> rdd = JSC.parallelize(list);
 
-    Configuration conf = htu.getConfiguration();
-    JavaHBaseContext hbaseContext = new JavaHBaseContext(jsc, conf);
+    Configuration conf = TEST_UTIL.getConfiguration();
 
-    hbaseContext.bulkLoadThinRows(rdd, TableName.valueOf(tableName), new BulkLoadThinRowsFunction(),
+    HBASE_CONTEXT.bulkLoadThinRows(rdd, TableName.valueOf(tableName), new BulkLoadThinRowsFunction(),
             output.toString(), new HashMap<byte[], FamilyHFileWriteOptions>(), false,
             HConstants.DEFAULT_MAX_FILE_SIZE);
 
diff --git a/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala b/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
index 83e2ac6..1b35b93 100644
--- a/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
+++ b/spark/hbase-spark/src/test/scala/org/apache/hadoop/hbase/spark/HBaseContextSuite.scala
@@ -27,6 +27,7 @@
 BeforeAndAfterEach with BeforeAndAfterAll  with Logging {
 
   @transient var sc: SparkContext = null
+  var hbaseContext: HBaseContext = null
   var TEST_UTIL = new HBaseTestingUtility
 
   val tableName = "t1"
@@ -49,6 +50,9 @@
     val envMap = Map[String,String](("Xmx", "512m"))
 
     sc = new SparkContext("local", "test", null, Nil, envMap)
+
+    val config = TEST_UTIL.getConfiguration
+    hbaseContext = new HBaseContext(sc, config)
   }
 
   override def afterAll() {
@@ -73,7 +77,6 @@
       (Bytes.toBytes("5"),
         Array((Bytes.toBytes(columnFamily), Bytes.toBytes("e"), Bytes.toBytes("bar"))))))
 
-    val hbaseContext = new HBaseContext(sc, config)
     hbaseContext.bulkPut[(Array[Byte], Array[(Array[Byte], Array[Byte], Array[Byte])])](rdd,
       TableName.valueOf(tableName),
       (putRecord) => {
@@ -132,7 +135,6 @@
         Bytes.toBytes("delete1"),
         Bytes.toBytes("delete3")))
 
-      val hbaseContext = new HBaseContext(sc, config)
       hbaseContext.bulkDelete[Array[Byte]](rdd,
         TableName.valueOf(tableName),
         putRecord => new Delete(putRecord),
@@ -174,7 +176,6 @@
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     val getRdd = hbaseContext.bulkGet[Array[Byte], String](
       TableName.valueOf(tableName),
@@ -221,7 +222,6 @@
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     intercept[SparkException] {
       try {
@@ -274,7 +274,6 @@
       Bytes.toBytes("get2"),
       Bytes.toBytes("get3"),
       Bytes.toBytes("get4")))
-    val hbaseContext = new HBaseContext(sc, config)
 
     val getRdd = hbaseContext.bulkGet[Array[Byte], String](
       TableName.valueOf(tableName),
@@ -329,8 +328,6 @@
       connection.close()
     }
 
-    val hbaseContext = new HBaseContext(sc, config)
-
     val scan = new Scan()
     val filter = new FirstKeyOnlyFilter()
     scan.setCaching(100)