(TWILL-163) Fix FileContextLocationFactory to honor UGI for home location

- Also include removal on usage of the deprecated HDFSLocationFactory

This closes #78 on Github

Signed-off-by: Terence Yim <chtyim@apache.org>
diff --git a/twill-yarn/src/main/java/org/apache/twill/filesystem/FileContextLocationFactory.java b/twill-yarn/src/main/java/org/apache/twill/filesystem/FileContextLocationFactory.java
index d64be71..b8453bc 100644
--- a/twill-yarn/src/main/java/org/apache/twill/filesystem/FileContextLocationFactory.java
+++ b/twill-yarn/src/main/java/org/apache/twill/filesystem/FileContextLocationFactory.java
@@ -50,8 +50,19 @@
    * @param pathBase base path for all non-absolute location created through this {@link LocationFactory}.
    */
   public FileContextLocationFactory(Configuration configuration, String pathBase) {
+    this(configuration, createFileContext(configuration), pathBase);
+  }
+
+  /**
+   * Creates a new instance with the given {@link FileContext} created from the given {@link Configuration}.
+   *
+   * @param configuration the hadoop configuration
+   * @param fc {@link FileContext} instance created from the given configuration
+   * @param pathBase base path for all non-absolute location created through this (@link LocationFactory}.
+   */
+  public FileContextLocationFactory(Configuration configuration, FileContext fc, String pathBase) {
     this.configuration = configuration;
-    this.fc = createFileContext(configuration);
+    this.fc = fc;
     this.pathBase = new Path(pathBase.startsWith("/") ? pathBase : "/" + pathBase);
   }
 
@@ -92,7 +103,9 @@
 
   @Override
   public Location getHomeLocation() {
-    return new FileContextLocation(this, fc, fc.getHomeDirectory());
+    // Fix for TWILL-163. FileContext.getHomeDirectory() uses System.getProperty("user.name") instead of UGI
+    return new FileContextLocation(this, fc,
+                                   new Path(fc.getHomeDirectory().getParent(), fc.getUgi().getShortUserName()));
   }
 
   /**
diff --git a/twill-yarn/src/main/java/org/apache/twill/internal/ServiceMain.java b/twill-yarn/src/main/java/org/apache/twill/internal/ServiceMain.java
index cafd375..a6d9132 100644
--- a/twill-yarn/src/main/java/org/apache/twill/internal/ServiceMain.java
+++ b/twill-yarn/src/main/java/org/apache/twill/internal/ServiceMain.java
@@ -27,10 +27,12 @@
 import com.google.common.util.concurrent.ListenableFuture;
 import com.google.common.util.concurrent.Service;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
+import org.apache.hadoop.fs.FileContext;
 import org.apache.hadoop.fs.FileSystem;
 import org.apache.hadoop.security.UserGroupInformation;
 import org.apache.twill.api.RunId;
-import org.apache.twill.filesystem.HDFSLocationFactory;
+import org.apache.twill.filesystem.FileContextLocationFactory;
 import org.apache.twill.filesystem.LocalLocationFactory;
 import org.apache.twill.filesystem.Location;
 import org.apache.twill.internal.logging.KafkaAppender;
@@ -50,6 +52,8 @@
 import java.io.File;
 import java.io.StringReader;
 import java.net.URI;
+import java.security.PrivilegedAction;
+import java.security.PrivilegedExceptionAction;
 import java.util.concurrent.ExecutionException;
 import java.util.concurrent.TimeUnit;
 
@@ -127,27 +131,35 @@
   /**
    * Returns the {@link Location} for the application based on the env {@link EnvKeys#TWILL_APP_DIR}.
    */
-  protected static Location createAppLocation(Configuration conf) {
+  protected static Location createAppLocation(final Configuration conf) {
     // Note: It's a little bit hacky based on the uri schema to create the LocationFactory, refactor it later.
-    URI appDir = URI.create(System.getenv(EnvKeys.TWILL_APP_DIR));
+    final URI appDir = URI.create(System.getenv(EnvKeys.TWILL_APP_DIR));
 
     try {
       if ("file".equals(appDir.getScheme())) {
         return new LocalLocationFactory().create(appDir);
       }
 
-      // If not file, assuming it is a FileSystem, hence construct with HDFSLocationFactory which wraps
-      // a FileSystem created from the Configuration
+      // If not file, assuming it is a FileSystem, hence construct with FileContextLocationFactory
+      UserGroupInformation ugi;
       if (UserGroupInformation.isSecurityEnabled()) {
-        return new HDFSLocationFactory(FileSystem.get(appDir, conf)).create(appDir);
+        ugi = UserGroupInformation.getCurrentUser();
+      } else {
+        String fsUser = System.getenv(EnvKeys.TWILL_FS_USER);
+        if (fsUser == null) {
+          throw new IllegalStateException("Missing environment variable " + EnvKeys.TWILL_FS_USER);
+        }
+        ugi = UserGroupInformation.createRemoteUser(fsUser);
       }
-
-      String fsUser = System.getenv(EnvKeys.TWILL_FS_USER);
-      if (fsUser == null) {
-        throw new IllegalStateException("Missing environment variable " + EnvKeys.TWILL_FS_USER);
-      }
-      return new HDFSLocationFactory(FileSystem.get(appDir, conf, fsUser)).create(appDir);
-
+      return ugi.doAs(new PrivilegedExceptionAction<Location>() {
+        @Override
+        public Location run() throws Exception {
+          Configuration hConf = new Configuration(conf);
+          URI defaultURI = new URI(appDir.getScheme(), appDir.getAuthority(), null, null, null);
+          hConf.set(CommonConfigurationKeysPublic.FS_DEFAULT_NAME_KEY, defaultURI.toString());
+          return new FileContextLocationFactory(hConf).create(appDir);
+        }
+      });
     } catch (Exception e) {
       LOG.error("Failed to create application location for {}.", appDir);
       throw Throwables.propagate(e);
diff --git a/twill-yarn/src/main/java/org/apache/twill/internal/yarn/YarnUtils.java b/twill-yarn/src/main/java/org/apache/twill/internal/yarn/YarnUtils.java
index 1c65591..e63deed 100644
--- a/twill-yarn/src/main/java/org/apache/twill/internal/yarn/YarnUtils.java
+++ b/twill-yarn/src/main/java/org/apache/twill/internal/yarn/YarnUtils.java
@@ -17,14 +17,13 @@
  */
 package org.apache.twill.internal.yarn;
 
-import com.google.common.base.Function;
 import com.google.common.base.Preconditions;
 import com.google.common.base.Throwables;
 import com.google.common.collect.ImmutableList;
-import com.google.common.collect.Lists;
-import com.google.common.collect.Maps;
 import org.apache.hadoop.conf.Configuration;
+import org.apache.hadoop.fs.FileContext;
 import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.Path;
 import org.apache.hadoop.io.DataInputByteBuffer;
 import org.apache.hadoop.io.DataOutputBuffer;
 import org.apache.hadoop.security.Credentials;
@@ -39,6 +38,7 @@
 import org.apache.hadoop.yarn.util.ConverterUtils;
 import org.apache.hadoop.yarn.util.Records;
 import org.apache.twill.api.LocalFile;
+import org.apache.twill.filesystem.FileContextLocationFactory;
 import org.apache.twill.filesystem.ForwardingLocationFactory;
 import org.apache.twill.filesystem.HDFSLocationFactory;
 import org.apache.twill.filesystem.LocationFactory;
@@ -50,7 +50,6 @@
 import java.net.InetSocketAddress;
 import java.nio.ByteBuffer;
 import java.util.List;
-import java.util.Map;
 import java.util.concurrent.atomic.AtomicReference;
 
 /**
@@ -68,7 +67,7 @@
   }
 
   private static final Logger LOG = LoggerFactory.getLogger(YarnUtils.class);
-  private static final AtomicReference<HadoopVersions> HADOOP_VERSION = new AtomicReference<HadoopVersions>();
+  private static final AtomicReference<HadoopVersions> HADOOP_VERSION = new AtomicReference<>();
 
   public static YarnLocalResource createLocalResource(LocalFile localFile) {
     Preconditions.checkArgument(localFile.getLastModified() >= 0, "Last modified time should be >= 0.");
@@ -155,19 +154,31 @@
       return ImmutableList.of();
     }
 
-    FileSystem fileSystem = getFileSystem(locationFactory);
+    LocationFactory factory = unwrap(locationFactory);
+    String renewer = getYarnTokenRenewer(config);
+    List<Token<?>> tokens = ImmutableList.of();
 
-    if (fileSystem == null) {
-      LOG.debug("LocationFactory is not HDFS");
-      return ImmutableList.of();
+    if (factory instanceof HDFSLocationFactory) {
+      FileSystem fs = ((HDFSLocationFactory) factory).getFileSystem();
+      Token<?>[] fsTokens = fs.addDelegationTokens(renewer, credentials);
+      if (fsTokens != null) {
+        tokens = ImmutableList.copyOf(fsTokens);
+      }
+    } else if (factory instanceof FileContextLocationFactory) {
+      FileContext fc = ((FileContextLocationFactory) locationFactory).getFileContext();
+      tokens = fc.getDelegationTokens(new Path(locationFactory.create("/").toURI()), renewer);
     }
 
-    String renewer = getYarnTokenRenewer(config);
+    for (Token<?> token : tokens) {
+      credentials.addToken(token.getService(), token);
+    }
 
-    Token<?>[] tokens = fileSystem.addDelegationTokens(renewer, credentials);
-    return tokens == null ? ImmutableList.<Token<?>>of() : ImmutableList.copyOf(tokens);
+    return ImmutableList.copyOf(tokens);
   }
 
+  /**
+   * Encodes the given {@link Credentials} as bytes.
+   */
   public static ByteBuffer encodeCredentials(Credentials credentials) {
     try {
       DataOutputBuffer out = new DataOutputBuffer();
@@ -268,26 +279,15 @@
     return localResource;
   }
 
-  private static <T> Map<String, T> transformResource(Map<String, YarnLocalResource> from) {
-    return Maps.transformValues(from, new Function<YarnLocalResource, T>() {
-      @Override
-      public T apply(YarnLocalResource resource) {
-        return resource.getLocalResource();
-      }
-    });
-  }
-
   /**
-   * Gets the Hadoop FileSystem from LocationFactory.
+   * Unwraps the given {@link LocationFactory} and returns the inner most {@link LocationFactory} which is not
+   * a {@link ForwardingLocationFactory}.
    */
-  private static FileSystem getFileSystem(LocationFactory locationFactory) {
-    if (locationFactory instanceof HDFSLocationFactory) {
-      return ((HDFSLocationFactory) locationFactory).getFileSystem();
+  private static LocationFactory unwrap(LocationFactory locationFactory) {
+    while (locationFactory instanceof ForwardingLocationFactory) {
+      locationFactory = ((ForwardingLocationFactory) locationFactory).getDelegate();
     }
-    if (locationFactory instanceof ForwardingLocationFactory) {
-      return getFileSystem(((ForwardingLocationFactory) locationFactory).getDelegate());
-    }
-    return null;
+    return locationFactory;
   }
 
   private YarnUtils() {
diff --git a/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillPreparer.java b/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillPreparer.java
index d04cdab..f7cb388 100644
--- a/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillPreparer.java
+++ b/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillPreparer.java
@@ -83,8 +83,6 @@
 import org.apache.twill.internal.yarn.YarnUtils;
 import org.apache.twill.launcher.FindFreePort;
 import org.apache.twill.launcher.TwillLauncher;
-import org.apache.twill.zookeeper.ZKClient;
-import org.apache.twill.zookeeper.ZKClients;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
@@ -131,7 +129,6 @@
   private final List<String> applicationClassPaths = Lists.newArrayList();
   private final Credentials credentials;
   private final int reservedMemory;
-  private String user;
   private String schedulerQueue;
   private String extraOptions;
   private JvmOptions.DebugOptions debugOptions = JvmOptions.DebugOptions.NO_DEBUG;
@@ -152,7 +149,6 @@
     this.credentials = createCredentials();
     this.reservedMemory = yarnConfig.getInt(Configs.Keys.JAVA_RESERVED_MEMORY_MB,
                                             Configs.Defaults.JAVA_RESERVED_MEMORY_MB);
-    this.user = System.getProperty("user.name");
     this.extraOptions = extraOptions;
     this.logLevel = logLevel;
     this.classAcceptor = new ClassAcceptor();
@@ -166,7 +162,6 @@
 
   @Override
   public TwillPreparer setUser(String user) {
-    this.user = user;
     return this;
   }
 
diff --git a/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillRunnerService.java b/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillRunnerService.java
index c5853d6..67ee2ac 100644
--- a/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillRunnerService.java
+++ b/twill-yarn/src/main/java/org/apache/twill/yarn/YarnTwillRunnerService.java
@@ -42,7 +42,7 @@
 import com.google.gson.JsonElement;
 import com.google.gson.JsonObject;
 import org.apache.hadoop.conf.Configuration;
-import org.apache.hadoop.fs.FileSystem;
+import org.apache.hadoop.fs.FileContext;
 import org.apache.hadoop.hdfs.DFSConfigKeys;
 import org.apache.hadoop.security.Credentials;
 import org.apache.hadoop.security.UserGroupInformation;
@@ -62,7 +62,7 @@
 import org.apache.twill.api.logging.LogHandler;
 import org.apache.twill.common.Cancellable;
 import org.apache.twill.common.Threads;
-import org.apache.twill.filesystem.HDFSLocationFactory;
+import org.apache.twill.filesystem.FileContextLocationFactory;
 import org.apache.twill.filesystem.Location;
 import org.apache.twill.filesystem.LocationFactory;
 import org.apache.twill.internal.Constants;
@@ -139,7 +139,7 @@
   private volatile String jvmOptions = null;
 
   /**
-   * Creates an instance with a {@link HDFSLocationFactory} created base on the given configuration with the
+   * Creates an instance with a {@link FileContextLocationFactory} created base on the given configuration with the
    * user home directory as the location factory namespace.
    *
    * @param config Configuration of the yarn cluster
@@ -612,8 +612,8 @@
 
   private static LocationFactory createDefaultLocationFactory(Configuration configuration) {
     try {
-      FileSystem fs = FileSystem.get(configuration);
-      return new HDFSLocationFactory(fs, fs.getHomeDirectory().toUri().getPath());
+      FileContext fc = FileContext.getFileContext(configuration);
+      return new FileContextLocationFactory(configuration, fc, fc.getHomeDirectory().toUri().getPath());
     } catch (IOException e) {
       throw Throwables.propagate(e);
     }
diff --git a/twill-yarn/src/test/java/org/apache/twill/filesystem/LocalLocationTest.java b/twill-yarn/src/test/java/org/apache/twill/filesystem/LocalLocationTest.java
index ba21beb..6bdba27 100644
--- a/twill-yarn/src/test/java/org/apache/twill/filesystem/LocalLocationTest.java
+++ b/twill-yarn/src/test/java/org/apache/twill/filesystem/LocalLocationTest.java
@@ -17,6 +17,8 @@
  */
 package org.apache.twill.filesystem;
 
+import org.junit.Assert;
+
 import java.io.File;
 
 /**
@@ -30,4 +32,10 @@
     basePath.mkdirs();
     return new LocalLocationFactory(basePath);
   }
+
+  @Override
+  public void testHomeLocation() throws Exception {
+    // For Local location, UGI won't take an effect.
+    Assert.assertEquals(System.getProperty("user.name"), createLocationFactory("/").getHomeLocation().getName());
+  }
 }
diff --git a/twill-yarn/src/test/java/org/apache/twill/filesystem/LocationTestBase.java b/twill-yarn/src/test/java/org/apache/twill/filesystem/LocationTestBase.java
index e01115b..af485dc 100644
--- a/twill-yarn/src/test/java/org/apache/twill/filesystem/LocationTestBase.java
+++ b/twill-yarn/src/test/java/org/apache/twill/filesystem/LocationTestBase.java
@@ -22,6 +22,7 @@
 import com.google.common.cache.CacheLoader;
 import com.google.common.cache.LoadingCache;
 import com.google.common.io.CharStreams;
+import org.apache.hadoop.security.UserGroupInformation;
 import org.junit.Assert;
 import org.junit.ClassRule;
 import org.junit.Test;
@@ -33,7 +34,7 @@
 import java.io.Reader;
 import java.io.Writer;
 import java.net.URI;
-import java.net.URISyntaxException;
+import java.security.PrivilegedExceptionAction;
 import java.util.List;
 
 /**
@@ -96,6 +97,25 @@
   }
 
   @Test
+  public void testHomeLocation() throws Exception {
+    LocationFactory locationFactory = createLocationFactory("/");
+
+    // Without UGI, the home location should be the same as the user
+    Assert.assertEquals(System.getProperty("user.name"), locationFactory.getHomeLocation().getName());
+
+    // With UGI, the home location should be based on the UGI current user
+    UserGroupInformation ugi = UserGroupInformation.createRemoteUser(System.getProperty("user.name") + "1");
+    locationFactory = ugi.doAs(new PrivilegedExceptionAction<LocationFactory>() {
+      @Override
+      public LocationFactory run() throws Exception {
+        return createLocationFactory("/");
+      }
+    });
+
+    Assert.assertEquals(ugi.getUserName(), locationFactory.getHomeLocation().getName());
+  }
+
+  @Test
   public void testDelete() throws IOException {
     LocationFactory factory = locationFactoryCache.getUnchecked("delete");
 
diff --git a/twill-yarn/src/test/java/org/apache/twill/yarn/BaseYarnTest.java b/twill-yarn/src/test/java/org/apache/twill/yarn/BaseYarnTest.java
index a9cf2ed..5d67dfa 100644
--- a/twill-yarn/src/test/java/org/apache/twill/yarn/BaseYarnTest.java
+++ b/twill-yarn/src/test/java/org/apache/twill/yarn/BaseYarnTest.java
@@ -17,6 +17,7 @@
  */
 package org.apache.twill.yarn;
 
+import com.google.common.base.Stopwatch;
 import com.google.common.collect.Iterables;
 import org.apache.hadoop.yarn.api.records.NodeReport;
 import org.apache.twill.api.TwillController;
@@ -28,6 +29,7 @@
 import org.slf4j.LoggerFactory;
 
 import java.util.List;
+import java.util.concurrent.Callable;
 import java.util.concurrent.TimeUnit;
 import java.util.concurrent.atomic.AtomicInteger;
 
@@ -101,6 +103,24 @@
     return trial < limit;
   }
 
+  /**
+   * Waits for a task returns the expected value.
+   *
+   * @param expected the expected value
+   * @param callable the task to execute
+   * @param timeout timeout of the wait
+   * @param delay delay between calls to the task to poll for the latest value
+   * @param unit unit for the timeout and delay
+   * @param <T> type of the expected value
+   * @throws Exception if the task through exception or timeout.
+   */
+  public <T> void waitFor(T expected, Callable<T> callable, long timeout, long delay, TimeUnit unit) throws Exception {
+    Stopwatch stopwatch = new Stopwatch().start();
+    while (callable.call() != expected && stopwatch.elapsedTime(unit) < timeout) {
+      unit.sleep(delay);
+    }
+  }
+
   @SuppressWarnings("unchecked")
   public <T extends TwillRunner> T getTwillRunner() {
     return (T) TWILL_TESTER.getTwillRunner();
diff --git a/twill-yarn/src/test/java/org/apache/twill/yarn/EchoServerTestRun.java b/twill-yarn/src/test/java/org/apache/twill/yarn/EchoServerTestRun.java
index 13c07b1..02075a7 100644
--- a/twill-yarn/src/test/java/org/apache/twill/yarn/EchoServerTestRun.java
+++ b/twill-yarn/src/test/java/org/apache/twill/yarn/EchoServerTestRun.java
@@ -32,6 +32,7 @@
 import org.apache.twill.common.Threads;
 import org.apache.twill.discovery.Discoverable;
 import org.apache.twill.zookeeper.ZKClientService;
+import org.apache.zookeeper.data.Stat;
 import org.junit.Assert;
 import org.junit.Test;
 import org.slf4j.Logger;
@@ -45,6 +46,7 @@
 import java.util.Collection;
 import java.util.List;
 import java.util.Map;
+import java.util.concurrent.Callable;
 import java.util.concurrent.CountDownLatch;
 import java.util.concurrent.TimeUnit;
 import javax.annotation.Nullable;
@@ -158,7 +160,7 @@
 
   @Test
   public void testZKCleanup() throws Exception {
-    ZKClientService zkClient = ZKClientService.Builder.of(getZKConnectionString() + "/twill").build();
+    final ZKClientService zkClient = ZKClientService.Builder.of(getZKConnectionString() + "/twill").build();
     zkClient.startAndWait();
 
     try {
@@ -177,7 +179,12 @@
       controller.terminate().get();
 
       // Verify the ZK node gets cleanup
-      Assert.assertNull(zkClient.exists("/EchoServer").get());
+      waitFor(null, new Callable<Stat>() {
+        @Override
+        public Stat call() throws Exception {
+          return zkClient.exists("/EchoServer").get();
+        }
+      }, 10000, 100, TimeUnit.MILLISECONDS);
 
       // Start two instances of the application and stop one of it
       List<TwillController> controllers = new ArrayList<>();
@@ -207,7 +214,12 @@
       controllers.get(1).terminate().get();
 
       // Verify the ZK node gets cleanup
-      Assert.assertNull(zkClient.exists("/EchoServer").get());
+      waitFor(null, new Callable<Stat>() {
+        @Override
+        public Stat call() throws Exception {
+          return zkClient.exists("/EchoServer").get();
+        }
+      }, 10000, 100, TimeUnit.MILLISECONDS);
 
     } finally {
       zkClient.stopAndWait();