/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.sentry.service.thrift;
import com.codahale.metrics.Counter;
import com.google.common.collect.ImmutableMap;
import com.google.common.util.concurrent.ThreadFactoryBuilder;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.hive.metastore.api.Database;
import org.apache.hadoop.hive.metastore.api.Partition;
import org.apache.hadoop.hive.metastore.api.Table;
import org.apache.sentry.hdfs.PathsUpdate;
import org.apache.sentry.hdfs.SentryMalformedPathException;
import org.apache.sentry.hdfs.ServiceConstants.ServerConfig;
import org.apache.sentry.api.service.thrift.SentryMetrics;
import org.apache.thrift.TException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
import java.util.Deque;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
import java.util.Map;
import java.util.HashMap;
import java.util.concurrent.Callable;
import java.util.concurrent.ConcurrentLinkedDeque;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.TimeUnit;
import static com.codahale.metrics.MetricRegistry.name;
/**
* Manages fetching a full snapshot from HMS.
* The snapshot is represented as a map from the Hive object name to
* the set of paths for that object.
* The Hive object name is either the Hive database name or
* the Hive database name joined with the Hive table name as {@code dbName.tableName}.
* All table partitions are stored under the table object.
* <p>
* Once a {@link FullUpdateInitializer} is constructed, the
* {@link FullUpdateInitializer#getFullHMSSnapshot()} method should be called to get
* the initial update.
* <p>
* It is important to close the {@link FullUpdateInitializer} object to prevent resource
* leaks.
* <p>
* The usual way of using {@link FullUpdateInitializer} is
* <pre>
* {@code
* try (FullUpdateInitializer updateInitializer =
*     new FullUpdateInitializer(clientFactory, authzConf)) {
*   Map<String, Collection<String>> pathsUpdate = updateInitializer.getFullHMSSnapshot();
*   return pathsUpdate;
* }
* }
* </pre>
*/
public final class FullUpdateInitializer implements AutoCloseable {
/*
* Implementation note.
*
* The snapshot is obtained using an executor. We follow the map/reduce model.
* Each executor thread (mapper) obtains and returns a partial snapshot; the partial
* snapshots are then reduced to a single combined snapshot by getFullHMSSnapshot().
*
* Synchronization between the getFullHMSSnapshot() and executors is done using the
* 'results' queue. The queue holds the futures for each scheduled task.
* It is initially populated by getFullHMSSnapshot and each task may add new future
* results to it. Only getFullHMSSnapshot() removes entries from the results queue.
* This guarantees that once the results queue is empty there are no pending jobs.
*
* Since there is no other shared data, the implementation is safe without
* any additional synchronization. It is not thread-safe for concurrent calls
* to getFullHMSSnapshot().
*
*/
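/** Name format for threads in the HMS fetch pool. */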
private static final String FULL_UPDATE_INITIALIZER_THREAD_NAME = "hms-fetch-%d";
private final ExecutorService threadPool;
private final int maxPartitionsPerCall;
private final int maxTablesPerCall;
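/** Futures for all scheduled tasks; drained only by getFullHMSSnapshot(). */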
private final Deque<Future<CallResult>> results = new ConcurrentLinkedDeque<>();
private final int maxRetries;
private final int waitDurationMillis;
private static final Logger LOGGER = LoggerFactory.getLogger(FullUpdateInitializer.class);
private static final ObjectMapping emptyObjectMapping =
new ObjectMapping(Collections.<String, Set<String>>emptyMap());
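/** Factory used to obtain HMS client connections. */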
private final HiveConnectionFactory clientFactory;
/** Total number of database objects */
private final Counter databaseCount = SentryMetrics.getInstance()
.getCounter(name(FullUpdateInitializer.class, "total", "db"));
/** Total number of table objects */
private final Counter tableCount = SentryMetrics.getInstance()
.getCounter(name(FullUpdateInitializer.class, "total", "tables"));
/** Total number of partition objects */
private final Counter partitionCount = SentryMetrics.getInstance()
.getCounter(name(FullUpdateInitializer.class, "total", "partitions"));
/**
* Extract path (not starting with "/") from the full URI
* @param uri - resource URI (usually with scheme)
* @return the path if the URI is valid, null otherwise
*/
static String pathFromURI(String uri) {
try {
return PathsUpdate.parsePath(uri);
} catch (SentryMalformedPathException e) {
LOGGER.warn(String.format("Ignoring invalid uri %s: %s",
uri, e.getReason()));
return null;
}
}
/**
* Mapping of object to set of paths.
* Used to represent partial results from executor threads. Multiple
* ObjectMapping objects are combined in a single mapping
* to get the final result.
*/
private static final class ObjectMapping {
private final Map<String, Set<String>> objects;
ObjectMapping(Map<String, Set<String>> objects) {
this.objects = objects;
}
ObjectMapping(String authObject, String path) {
Set<String> values = Collections.singleton(safeIntern(path));
objects = ImmutableMap.of(authObject, values);
}
ObjectMapping(String authObject, Collection<String> paths) {
Set<String> values = new HashSet<>(paths);
objects = ImmutableMap.of(authObject, values);
}
Map<String, Set<String>> getObjects() {
return objects;
}
}
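/**
* Result of a single task: either a successful {@link ObjectMapping} or the exception
* that caused the task to fail.
*/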
private static final class CallResult {
private final Exception failure;
private final boolean successStatus;
private final ObjectMapping objectMapping;
CallResult(Exception ex) {
failure = ex;
successStatus = false;
objectMapping = emptyObjectMapping;
}
CallResult(ObjectMapping objectMapping) {
failure = null;
successStatus = true;
this.objectMapping = objectMapping;
}
boolean success() {
return successStatus;
}
ObjectMapping getObjectMapping() {
return objectMapping;
}
public Exception getFailure() {
return failure;
}
}
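/**
* Base class for tasks submitted to the thread pool. Subclasses implement
* {@link #doTask()}; Thrift failures are retried according to the configured
* retry strategy.
*/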
private abstract class BaseTask implements Callable<CallResult> {
/**
* Class represents retry strategy for BaseTask.
*/
private final class RetryStrategy {
private int retryStrategyMaxRetries = 0;
private final int retryStrategyWaitDurationMillis;
private RetryStrategy(int retryStrategyMaxRetries, int retryStrategyWaitDurationMillis) {
this.retryStrategyMaxRetries = retryStrategyMaxRetries;
// Assign the default wait duration if a non-positive value is provided.
this.retryStrategyWaitDurationMillis = (retryStrategyWaitDurationMillis > 0) ?
retryStrategyWaitDurationMillis : 1000;
}
@SuppressWarnings({"squid:S1141", "squid:S2142"})
public CallResult exec() {
// Retry logic happens inside the callable/task to avoid
// blocking synchronously on the result.
// Retry a failed task until the maximum number of retries is reached,
// waiting a configurable duration between attempts.
//
// Only thrift exceptions are retried.
// Other exceptions are propagated up the stack.
Exception exception = null;
try {
// Thrift exceptions are retried; any other exception is caught below and fails the task.
for (int i = 0; i < retryStrategyMaxRetries; i++) {
//noinspection NestedTryStatement
try {
return new CallResult(doTask());
} catch (TException ex) {
LOGGER.debug("Failed to execute task on " + (i + 1) + " attempts." +
" Sleeping for " + retryStrategyWaitDurationMillis + " ms. Exception: " +
ex.toString(), ex);
exception = ex;
try {
Thread.sleep(retryStrategyWaitDurationMillis);
} catch (InterruptedException ignored) {
// Skip the remaining retries if an InterruptedException is received.
LOGGER.warn("Interrupted during update fetch on attempt " + (i + 1));
break;
}
}
}
} catch (Exception ex) {
exception = ex;
}
LOGGER.error("Failed to execute task", exception);
// We will fail in the end, so we are shutting down the pool to prevent
// new tasks from being scheduled.
threadPool.shutdown();
return new CallResult(exception);
}
}
private final RetryStrategy retryStrategy;
BaseTask() {
retryStrategy = new RetryStrategy(maxRetries, waitDurationMillis);
}
@Override
public CallResult call() throws Exception {
return retryStrategy.exec();
}
abstract ObjectMapping doTask() throws Exception;
}
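/**
* Task that fetches the locations of a batch of partitions for a single table and
* returns them as a mapping from the table's authorizable name to its partition paths.
*/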
private class PartitionTask extends BaseTask {
private final String dbName;
private final String tblName;
private final String authName;
private final List<String> partNames;
PartitionTask(String dbName, String tblName, String authName,
List<String> partNames) {
this.dbName = safeIntern(dbName);
this.tblName = safeIntern(tblName);
this.authName = safeIntern(authName);
this.partNames = partNames;
}
@Override
ObjectMapping doTask() throws Exception {
List<Partition> tblParts;
HMSClient c = null;
try (HMSClient client = clientFactory.connect()) {
c = client;
tblParts = client.getClient().getPartitionsByNames(dbName, tblName, partNames);
} catch (Exception e) {
if (c != null) {
c.invalidate();
}
throw e;
}
LOGGER.debug("Fetched partitions for db = {}, table = {}",
dbName, tblName);
Collection<String> partitionNames = new ArrayList<>(tblParts.size());
for (Partition part : tblParts) {
if (part != null && part.getSd() != null) {
String partPath = pathFromURI(part.getSd().getLocation());
if (partPath != null) {
partitionNames.add(partPath.intern());
}
} else {
LOGGER.info("Partition or its storage descriptor is null while fetching partitions for db = {} table = {}", dbName, tblName);
}
}
return new ObjectMapping(authName, partitionNames);
}
}
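/**
* Task that fetches metadata for a batch of tables in a single database, collects
* their locations and schedules {@link PartitionTask}s for their partitions.
*/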
private class TableTask extends BaseTask {
private final String dbName;
private final List<String> tableNames;
TableTask(Database db, List<String> tableNames) {
dbName = safeIntern(db.getName());
this.tableNames = tableNames;
}
@Override
@SuppressWarnings({"squid:S2629", "squid:S135"})
ObjectMapping doTask() throws Exception {
HMSClient c = null;
try (HMSClient client = clientFactory.connect()) {
c = client;
List<Table> tables = client.getClient().getTableObjectsByName(dbName, tableNames);
LOGGER.debug("Fetching tables for db = {}, tables = {}", dbName, tableNames);
Map<String, Set<String>> objectMapping = new HashMap<>(tables.size());
for (Table tbl : tables) {
// Database and table names are case insensitive
if (!tbl.getDbName().equalsIgnoreCase(dbName)) {
// Inconsistency in HMS data
LOGGER.warn(String.format("DB name %s for table %s does not match %s",
tbl.getDbName(), tbl.getTableName(), dbName));
continue;
}
String tableName = safeIntern(tbl.getTableName().toLowerCase());
String authzObject = (dbName + "." + tableName).intern();
List<String> tblPartNames =
client.getClient().listPartitionNames(dbName, tableName, (short) -1);
// Count total number of partitions
partitionCount.inc(tblPartNames.size());
for (int i = 0; i < tblPartNames.size(); i += maxPartitionsPerCall) {
List<String> partsToFetch = tblPartNames.subList(i,
Math.min(i + maxPartitionsPerCall, tblPartNames.size()));
Callable<CallResult> partTask = new PartitionTask(dbName,
tableName, authzObject, partsToFetch);
results.add(threadPool.submit(partTask));
}
String tblPath = safeIntern(pathFromURI(tbl.getSd().getLocation()));
if (tblPath == null) {
continue;
}
Set<String> paths = objectMapping.get(authzObject);
if (paths == null) {
paths = new HashSet<>(1);
objectMapping.put(authzObject, paths);
}
paths.add(tblPath);
}
return new ObjectMapping(Collections.unmodifiableMap(objectMapping));
} catch (Exception e) {
if (c != null) {
c.invalidate();
}
throw e;
}
}
}
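/**
* Task that fetches metadata for a single database, collects its location and
* schedules {@link TableTask}s for its tables.
*/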
private class DbTask extends BaseTask {
private final String dbName;
DbTask(String dbName) {
// Database names are case insensitive
this.dbName = safeIntern(dbName.toLowerCase());
databaseCount.inc();
}
@Override
ObjectMapping doTask() throws Exception {
HMSClient c = null;
try (HMSClient client = clientFactory.connect()) {
c = client;
Database db = client.getClient().getDatabase(dbName);
if (!dbName.equalsIgnoreCase(db.getName())) {
LOGGER.warn("Database name {} does not match {}", db.getName(), dbName);
return emptyObjectMapping;
}
List<String> allTblStr = client.getClient().getAllTables(dbName);
// Count total number of tables
tableCount.inc(allTblStr.size());
for (int i = 0; i < allTblStr.size(); i += maxTablesPerCall) {
List<String> tablesToFetch = allTblStr.subList(i,
Math.min(i + maxTablesPerCall, allTblStr.size()));
Callable<CallResult> tableTask = new TableTask(db, tablesToFetch);
results.add(threadPool.submit(tableTask));
}
String dbPath = safeIntern(pathFromURI(db.getLocationUri()));
return (dbPath != null) ? new ObjectMapping(dbName, dbPath) :
emptyObjectMapping;
} catch (Exception e) {
if (c != null) {
c.invalidate();
}
throw e;
}
}
}
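/**
* Create a new FullUpdateInitializer.
* @param clientFactory factory for HMS client connections
* @param conf configuration used to size the thread pool, RPC batch sizes and retry policy
*/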
FullUpdateInitializer(HiveConnectionFactory clientFactory, Configuration conf) {
this.clientFactory = clientFactory;
maxPartitionsPerCall = conf.getInt(
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_MAX_PART_PER_RPC,
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_MAX_PART_PER_RPC_DEFAULT);
maxTablesPerCall = conf.getInt(
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_MAX_TABLES_PER_RPC,
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_MAX_TABLES_PER_RPC_DEFAULT);
maxRetries = conf.getInt(
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_RETRY_MAX_NUM,
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_RETRY_MAX_NUM_DEFAULT);
waitDurationMillis = conf.getInt(
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_RETRY_WAIT_DURAION_IN_MILLIS,
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_RETRY_WAIT_DURAION_IN_MILLIS_DEFAULT);
ThreadFactory fullUpdateInitThreadFactory = new ThreadFactoryBuilder()
.setNameFormat(FULL_UPDATE_INITIALIZER_THREAD_NAME)
.setDaemon(false)
.build();
threadPool = Executors.newFixedThreadPool(conf.getInt(
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_INIT_THREADS,
ServerConfig.SENTRY_HDFS_SYNC_METASTORE_CACHE_INIT_THREADS_DEFAULT),
fullUpdateInitThreadFactory);
}
/**
* Get Full HMS snapshot.
* @return Full snapshot of HMS objects.
* @throws TException if a Thrift error occurred
* @throws ExecutionException if there was a scheduling error
* @throws InterruptedException if processing was interrupted
*/
@SuppressWarnings("squid:S00112")
Map<String, Collection<String>> getFullHMSSnapshot() throws Exception {
// Get list of all HMS databases
List<String> allDbStr;
HMSClient c = null;
try (HMSClient client = clientFactory.connect()) {
c = client;
allDbStr = client.getClient().getAllDatabases();
} catch (Exception e) {
if (c != null) {
c.invalidate();
}
throw e;
}
// Schedule async task for each database responsible for fetching per-database
// objects.
for (String dbName : allDbStr) {
results.add(threadPool.submit(new DbTask(dbName)));
}
// Resulting full snapshot
Map<String, Collection<String>> fullSnapshot = new HashMap<>();
// As async tasks complete, merge their results into full snapshot.
while (!results.isEmpty()) {
// This is the only thread that takes elements off the results queue - all other threads
// only add to it. Once the queue is empty it can't become non-empty again.
// This means that if results is non-empty we can safely call pop() and
// know that its result is not null.
Future<CallResult> result = results.pop();
// Wait for the task to complete
CallResult callResult = result.get();
// Fail if we got errors
if (!callResult.success()) {
throw callResult.getFailure();
}
// Merge values into fullSnapshot
Map<String, Set<String>> objectMapping =
callResult.getObjectMapping().getObjects();
for (Map.Entry<String, Set<String>> entry: objectMapping.entrySet()) {
String key = entry.getKey();
Set<String> val = entry.getValue();
Collection<String> existingSet = fullSnapshot.get(key);
if (existingSet == null) {
fullSnapshot.put(key, val);
continue;
}
existingSet.addAll(val);
}
}
return fullSnapshot;
}
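/**
* Stop all running tasks and shut down the thread pool.
*/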
@Override
public void close() {
threadPool.shutdownNow();
try {
threadPool.awaitTermination(1, TimeUnit.SECONDS);
} catch (InterruptedException ignored) {
LOGGER.warn("Interrupted shutdown");
Thread.currentThread().interrupt();
}
}
/**
* Intern a string but only if it is not null
* @param arg String to be interned, may be null
* @return interned string or null
*/
static String safeIntern(String arg) {
return (arg != null) ? arg.intern() : null;
}
}