blob: 66b65ba03f0447fe07836b98eae4b2f6009d146c [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.hbase.security.provider;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import java.io.ByteArrayInputStream;
import java.io.DataInput;
import java.io.DataInputStream;
import java.io.DataOutput;
import java.io.File;
import java.io.IOException;
import java.net.InetAddress;
import java.security.PrivilegedExceptionAction;
import java.util.Arrays;
import java.util.Collection;
import java.util.List;
import java.util.Map;
import java.util.Optional;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import javax.security.auth.callback.Callback;
import javax.security.auth.callback.CallbackHandler;
import javax.security.auth.callback.NameCallback;
import javax.security.auth.callback.PasswordCallback;
import javax.security.auth.callback.UnsupportedCallbackException;
import javax.security.sasl.AuthorizeCallback;
import javax.security.sasl.RealmCallback;
import javax.security.sasl.RealmChoiceCallback;
import javax.security.sasl.Sasl;
import javax.security.sasl.SaslClient;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.fs.Path;
import org.apache.hadoop.hbase.Cell;
import org.apache.hadoop.hbase.CellUtil;
import org.apache.hadoop.hbase.HBaseTestingUtil;
import org.apache.hadoop.hbase.HConstants;
import org.apache.hadoop.hbase.LocalHBaseCluster;
import org.apache.hadoop.hbase.TableName;
import org.apache.hadoop.hbase.TableNameTestRule;
import org.apache.hadoop.hbase.client.Admin;
import org.apache.hadoop.hbase.client.ColumnFamilyDescriptorBuilder;
import org.apache.hadoop.hbase.client.Connection;
import org.apache.hadoop.hbase.client.ConnectionFactory;
import org.apache.hadoop.hbase.client.Get;
import org.apache.hadoop.hbase.client.Put;
import org.apache.hadoop.hbase.client.Result;
import org.apache.hadoop.hbase.client.RetriesExhaustedException;
import org.apache.hadoop.hbase.client.Table;
import org.apache.hadoop.hbase.client.TableDescriptorBuilder;
import org.apache.hadoop.hbase.coprocessor.CoprocessorHost;
import org.apache.hadoop.hbase.exceptions.MasterRegistryFetchException;
import org.apache.hadoop.hbase.ipc.BlockingRpcClient;
import org.apache.hadoop.hbase.ipc.NettyRpcClient;
import org.apache.hadoop.hbase.ipc.RpcClientFactory;
import org.apache.hadoop.hbase.ipc.RpcServerFactory;
import org.apache.hadoop.hbase.security.AccessDeniedException;
import org.apache.hadoop.hbase.security.HBaseKerberosUtils;
import org.apache.hadoop.hbase.security.SaslUtil;
import org.apache.hadoop.hbase.security.User;
import org.apache.hadoop.hbase.security.token.SecureTestCluster;
import org.apache.hadoop.hbase.security.token.TokenProvider;
import org.apache.hadoop.hbase.util.Bytes;
import org.apache.hadoop.hbase.util.CommonFSUtils;
import org.apache.hadoop.hbase.util.Pair;
import org.apache.hadoop.io.Text;
import org.apache.hadoop.io.WritableUtils;
import org.apache.hadoop.minikdc.MiniKdc;
import org.apache.hadoop.security.UserGroupInformation;
import org.apache.hadoop.security.UserGroupInformation.AuthenticationMethod;
import org.apache.hadoop.security.token.SecretManager;
import org.apache.hadoop.security.token.SecretManager.InvalidToken;
import org.apache.hadoop.security.token.Token;
import org.apache.hadoop.security.token.TokenIdentifier;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.runners.Parameterized.Parameter;
import org.junit.runners.Parameterized.Parameters;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import org.apache.hbase.thirdparty.com.google.common.base.Throwables;
import org.apache.hadoop.hbase.shaded.protobuf.generated.RPCProtos.UserInformation;
/**
 * Tests the pluggable authentication framework with SASL using a contrived authentication system.
 * This test holds a "user database" in memory as a hashmap. Clients present their password via a
 * custom {@link Token} ({@link PasswordAuthTokenIdentifier}) stored in their UGI; the servers
 * validate this password against the "user database".
 * <p>
 * Tests are parameterized over the RPC client implementation (see {@link #parameters()}). This
 * class does not start the cluster itself; subclasses presumably call
 * {@link #startCluster(String)} from their own class-level setup with the RPC server
 * implementation under test.
 */
public abstract class CustomSaslAuthenticationProviderTestBase {
  private static final Logger LOG =
    LoggerFactory.getLogger(CustomSaslAuthenticationProviderTestBase.class);

  // Declared before USER_DATABASE so the initializer below never depends on forward references.
  private static final String USER1_PASSWORD = "foobarbaz";
  private static final String USER2_PASSWORD = "bazbarfoo";

  /** In-memory "user database": username to password. */
  private static final Map<String, String> USER_DATABASE = createUserDatabase();

  /**
   * Supplies the RPC client implementations each test is run against.
   * @return one single-element parameter array per RPC client class name
   */
  @Parameters
  public static Collection<Object[]> parameters() {
    return Arrays.asList(new Object[] { BlockingRpcClient.class.getName() },
      new Object[] { NettyRpcClient.class.getName() });
  }

  /** Fully-qualified class name of the RPC client implementation for the current run. */
  @Parameter
  public String rpcClientImpl;

  private static Map<String, String> createUserDatabase() {
    Map<String, String> db = new ConcurrentHashMap<>();
    db.put("user1", USER1_PASSWORD);
    db.put("user2", USER2_PASSWORD);
    return db;
  }

  /**
   * Looks up a user's password in the in-memory user database.
   * @param user the username to look up
   * @return the stored password
   * @throws IllegalStateException if the user is not in the database
   */
  public static String getPassword(String user) {
    String password = USER_DATABASE.get(user);
    if (password == null) {
      throw new IllegalStateException("Cannot request password for a user that doesn't exist");
    }
    return password;
  }

  /**
   * A custom token identifier for our custom auth'n method. Unique from the TokenIdentifier used
   * for delegation tokens. Its serialized form is just the username.
   */
  public static class PasswordAuthTokenIdentifier extends TokenIdentifier {
    public static final Text PASSWORD_AUTH_TOKEN = new Text("HBASE_PASSWORD_TEST_TOKEN");
    private String username;

    /** No-arg constructor used before {@link #readFields(DataInput)} deserialization. */
    public PasswordAuthTokenIdentifier() {
    }

    public PasswordAuthTokenIdentifier(String username) {
      this.username = username;
    }

    @Override
    public void readFields(DataInput in) throws IOException {
      this.username = WritableUtils.readString(in);
    }

    @Override
    public void write(DataOutput out) throws IOException {
      WritableUtils.writeString(out, username);
    }

    @Override
    public Text getKind() {
      return PASSWORD_AUTH_TOKEN;
    }

    /**
     * @return a remote UGI for the username, or {@code null} when the username is null or empty
     */
    @Override
    public UserGroupInformation getUser() {
      if (username == null || "".equals(username)) {
        return null;
      }
      return UserGroupInformation.createRemoteUser(username);
    }
  }

  /**
   * Builds a password token for {@code username}. The token identifier carries the username and
   * the token password carries {@code password}; the service is set to {@code clusterId}.
   */
  public static Token<? extends TokenIdentifier> createPasswordToken(String username,
    String password, String clusterId) {
    PasswordAuthTokenIdentifier id = new PasswordAuthTokenIdentifier(username);
    return new Token<>(id.getBytes(), Bytes.toBytes(password), id.getKind(), new Text(clusterId));
  }

  /**
   * Client provider that finds custom Token in the user's UGI and authenticates with the server via
   * DIGEST-MD5 using that password.
   */
  public static class InMemoryClientProvider extends AbstractSaslClientAuthenticationProvider {
    public static final String MECHANISM = "DIGEST-MD5";
    public static final SaslAuthMethod SASL_AUTH_METHOD =
      new SaslAuthMethod("IN_MEMORY", (byte) 42, MECHANISM, AuthenticationMethod.TOKEN);

    @Override
    public SaslClient createClient(Configuration conf, InetAddress serverAddr,
      String serverPrincipal, Token<? extends TokenIdentifier> token, boolean fallbackAllowed,
      Map<String, String> saslProps) throws IOException {
      return Sasl.createSaslClient(new String[] { MECHANISM }, null, null,
        SaslUtil.SASL_DEFAULT_REALM, saslProps, new InMemoryClientProviderCallbackHandler(token));
    }

    /**
     * Finds this provider's token among the user's tokens.
     * @return the single {@link PasswordAuthTokenIdentifier#PASSWORD_AUTH_TOKEN} token, or empty
     * @throws IllegalStateException if the user holds more than one such token
     */
    public Optional<Token<? extends TokenIdentifier>> findToken(User user) {
      List<Token<? extends TokenIdentifier>> tokens = user.getTokens().stream()
        .filter((token) -> token.getKind().equals(PasswordAuthTokenIdentifier.PASSWORD_AUTH_TOKEN))
        .collect(Collectors.toList());
      if (tokens.isEmpty()) {
        return Optional.empty();
      }
      if (tokens.size() > 1) {
        throw new IllegalStateException("Cannot handle more than one PasswordAuthToken");
      }
      return Optional.of(tokens.get(0));
    }

    @Override
    public SaslAuthMethod getSaslAuthMethod() {
      return SASL_AUTH_METHOD;
    }

    /**
     * Sasl CallbackHandler which extracts information from our custom token and places it into the
     * Sasl objects.
     */
    public class InMemoryClientProviderCallbackHandler implements CallbackHandler {
      private final Token<? extends TokenIdentifier> token;

      public InMemoryClientProviderCallbackHandler(Token<? extends TokenIdentifier> token) {
        this.token = token;
      }

      @Override
      public void handle(Callback[] callbacks) throws UnsupportedCallbackException {
        NameCallback nc = null;
        PasswordCallback pc = null;
        RealmCallback rc = null;
        for (Callback callback : callbacks) {
          if (callback instanceof RealmChoiceCallback) {
            continue;
          } else if (callback instanceof NameCallback) {
            nc = (NameCallback) callback;
          } else if (callback instanceof PasswordCallback) {
            pc = (PasswordCallback) callback;
          } else if (callback instanceof RealmCallback) {
            rc = (RealmCallback) callback;
          } else {
            throw new UnsupportedCallbackException(callback, "Unrecognized SASL client callback");
          }
        }
        if (nc != null) {
          // The SASL username is the encoded token identifier, i.e. the serialized username.
          nc.setName(SaslUtil.encodeIdentifier(token.getIdentifier()));
        }
        if (pc != null) {
          pc.setPassword(SaslUtil.encodePassword(token.getPassword()));
        }
        if (rc != null) {
          rc.setText(rc.getDefaultText());
        }
      }
    }

    @Override
    public UserInformation getUserInfo(User user) {
      return null;
    }
  }

  /**
   * Server provider which validates credentials from an in-memory database.
   */
  public static class InMemoryServerProvider extends InMemoryClientProvider
    implements SaslServerAuthenticationProvider {

    @Override
    public AttemptingUserProvidingSaslServer
      createServer(SecretManager<TokenIdentifier> secretManager, Map<String, String> saslProps)
        throws IOException {
      return new AttemptingUserProvidingSaslServer(
        Sasl.createSaslServer(getSaslAuthMethod().getSaslMechanism(), null,
          SaslUtil.SASL_DEFAULT_REALM, saslProps, new InMemoryServerProviderCallbackHandler()),
        () -> null);
    }

    /**
     * Pulls the correct password for the user who started the SASL handshake so that SASL can
     * validate that the user provided the right password. Static because it needs no state from
     * the enclosing provider.
     */
    private static class InMemoryServerProviderCallbackHandler implements CallbackHandler {
      @Override
      public void handle(Callback[] callbacks) throws InvalidToken, UnsupportedCallbackException {
        NameCallback nc = null;
        PasswordCallback pc = null;
        AuthorizeCallback ac = null;
        for (Callback callback : callbacks) {
          if (callback instanceof AuthorizeCallback) {
            ac = (AuthorizeCallback) callback;
          } else if (callback instanceof NameCallback) {
            nc = (NameCallback) callback;
          } else if (callback instanceof PasswordCallback) {
            pc = (PasswordCallback) callback;
          } else if (callback instanceof RealmCallback) {
            continue; // realm is ignored
          } else {
            throw new UnsupportedCallbackException(callback, "Unrecognized SASL Callback");
          }
        }
        if (nc != null && pc != null) {
          byte[] encodedId = SaslUtil.decodeIdentifier(nc.getDefaultName());
          PasswordAuthTokenIdentifier id = new PasswordAuthTokenIdentifier();
          try {
            id.readFields(new DataInputStream(new ByteArrayInputStream(encodedId)));
          } catch (IOException e) {
            throw (InvalidToken) new InvalidToken("Can't de-serialize tokenIdentifier")
              .initCause(e);
          }
          // getUser() returns null for an empty username; reject instead of NPE-ing.
          UserGroupInformation tokenUser = id.getUser();
          if (tokenUser == null) {
            throw new InvalidToken("Token identifier does not contain a username");
          }
          char[] actualPassword =
            SaslUtil.encodePassword(Bytes.toBytes(getPassword(tokenUser.getUserName())));
          pc.setPassword(actualPassword);
        }
        if (ac != null) {
          // Only allow a client to act as itself.
          String authid = ac.getAuthenticationID();
          String authzid = ac.getAuthorizationID();
          boolean authorized = authid.equals(authzid);
          ac.setAuthorized(authorized);
          if (authorized) {
            ac.setAuthorizedID(authzid);
          }
        }
      }
    }

    @Override
    public boolean supportsProtocolAuthentication() {
      return false;
    }

    /**
     * Reconstructs the authenticated user from the encoded token identifier.
     * @throws IOException if the identifier cannot be deserialized
     * @throws AccessDeniedException if the identifier carries no username
     */
    @Override
    public UserGroupInformation getAuthorizedUgi(String authzId,
      SecretManager<TokenIdentifier> secretManager) throws IOException {
      UserGroupInformation authorizedUgi;
      byte[] encodedId = SaslUtil.decodeIdentifier(authzId);
      PasswordAuthTokenIdentifier tokenId = new PasswordAuthTokenIdentifier();
      try {
        tokenId.readFields(new DataInputStream(new ByteArrayInputStream(encodedId)));
      } catch (IOException e) {
        throw new IOException("Can't de-serialize PasswordAuthTokenIdentifier", e);
      }
      authorizedUgi = tokenId.getUser();
      if (authorizedUgi == null) {
        throw new AccessDeniedException("Can't retrieve username from tokenIdentifier.");
      }
      authorizedUgi.addTokenIdentifier(tokenId);
      authorizedUgi.setAuthenticationMethod(getSaslAuthMethod().getAuthMethod());
      return authorizedUgi;
    }
  }

  /**
   * Custom provider which can select our custom provider, amongst other tokens which may be
   * available.
   */
  public static class InMemoryProviderSelector extends BuiltInProviderSelector {
    private InMemoryClientProvider inMemoryProvider;

    @Override
    public void configure(Configuration conf,
      Collection<SaslClientAuthenticationProvider> providers) {
      super.configure(conf, providers);
      Optional<SaslClientAuthenticationProvider> o =
        providers.stream().filter((p) -> p instanceof InMemoryClientProvider).findAny();
      inMemoryProvider = (InMemoryClientProvider) o.orElseThrow(() -> new RuntimeException(
        "InMemoryClientProvider not found in available providers: " + providers));
    }

    /**
     * Prefers the in-memory provider whenever the user holds a password token; otherwise falls
     * back to the built-in selection.
     */
    @Override
    public Pair<SaslClientAuthenticationProvider, Token<? extends TokenIdentifier>>
      selectProvider(String clusterId, User user) {
      Pair<SaslClientAuthenticationProvider, Token<? extends TokenIdentifier>> superPair =
        super.selectProvider(clusterId, user);
      Optional<Token<? extends TokenIdentifier>> optional = inMemoryProvider.findToken(user);
      if (optional.isPresent()) {
        LOG.info("Using InMemoryClientProvider");
        return new Pair<>(inMemoryProvider, optional.get());
      }
      LOG.info("InMemoryClientProvider not usable, falling back to {}", superPair);
      return superPair;
    }
  }

  /**
   * Creates the service principal, starts mini ZK and DFS clusters, and applies secure (Kerberos
   * and SSL) configuration plus the TokenProvider coprocessor to {@code util}'s configuration.
   */
  private static void createBaseCluster(HBaseTestingUtil util, File keytabFile, MiniKdc kdc)
    throws Exception {
    String servicePrincipal = "hbase/localhost";
    String spnegoPrincipal = "HTTP/localhost";
    kdc.createPrincipal(keytabFile, servicePrincipal);
    util.startMiniZKCluster();
    HBaseKerberosUtils.setSecuredConfiguration(util.getConfiguration(),
      servicePrincipal + "@" + kdc.getRealm(), spnegoPrincipal + "@" + kdc.getRealm());
    HBaseKerberosUtils.setSSLConfiguration(util, SecureTestCluster.class);
    util.getConfiguration().setStrings(CoprocessorHost.REGION_COPROCESSOR_CONF_KEY,
      TokenProvider.class.getName());
    util.startMiniDFSCluster(1);
    Path rootdir = util.getDataTestDirOnTestFS("TestCustomSaslAuthenticationProvider");
    CommonFSUtils.setRootDir(util.getConfiguration(), rootdir);
  }

  private static final HBaseTestingUtil UTIL = new HBaseTestingUtil();
  private static final Configuration CONF = UTIL.getConfiguration();
  private static LocalHBaseCluster CLUSTER;
  private static File KEYTAB_FILE;

  /**
   * Starts a MiniKdc, the base mini clusters and a single-node {@link LocalHBaseCluster} using the
   * given RPC server implementation, with the in-memory SASL providers registered.
   * @param rpcServerImpl fully-qualified class name of the RPC server implementation
   */
  protected static void startCluster(String rpcServerImpl) throws Exception {
    KEYTAB_FILE = new File(UTIL.getDataTestDir("keytab").toUri().getPath());
    final MiniKdc kdc = UTIL.setupMiniKdc(KEYTAB_FILE);
    // Adds our test impls instead of creating service loader entries which
    // might inadvertently get them loaded on a real cluster.
    CONF.setStrings(SaslClientAuthenticationProviders.EXTRA_PROVIDERS_KEY,
      InMemoryClientProvider.class.getName());
    CONF.setStrings(SaslServerAuthenticationProviders.EXTRA_PROVIDERS_KEY,
      InMemoryServerProvider.class.getName());
    CONF.set(SaslClientAuthenticationProviders.SELECTOR_KEY,
      InMemoryProviderSelector.class.getName());
    createBaseCluster(UTIL, KEYTAB_FILE, kdc);
    CONF.set(RpcServerFactory.CUSTOM_RPC_SERVER_IMPL_CONF_KEY, rpcServerImpl);
    CLUSTER = new LocalHBaseCluster(CONF, 1);
    CLUSTER.startup();
  }

  /**
   * Stops the HBase cluster, then the DFS and ZK mini clusters, then removes the test directory.
   * Each step runs even if an earlier one throws, so a failure does not leak mini clusters.
   */
  @AfterClass
  public static void shutdownCluster() throws Exception {
    try {
      if (CLUSTER != null) {
        CLUSTER.shutdown();
      }
    } finally {
      CLUSTER = null;
      try {
        UTIL.shutdownMiniDFSCluster();
      } finally {
        try {
          UTIL.shutdownMiniZKCluster();
        } finally {
          UTIL.cleanupTestDir();
        }
      }
    }
  }

  @Rule
  public TableNameTestRule name = new TableNameTestRule();

  private TableName tableName;
  private String clusterId;

  @Before
  public void setUp() throws Exception {
    createTable();
  }

  @After
  public void tearDown() throws IOException {
    UTIL.deleteTable(name.getTableName());
  }

  /**
   * Creates the per-test table, writes row {@code r1} with {@code f1:q1 = "1"} as the Kerberos
   * service user, and records the cluster id used as the password token's service.
   */
  private void createTable() throws Exception {
    tableName = name.getTableName();
    // Create a table and write a record as the service user (hbase)
    UserGroupInformation serviceUgi = UserGroupInformation
      .loginUserFromKeytabAndReturnUGI("hbase/localhost", KEYTAB_FILE.getAbsolutePath());
    clusterId = serviceUgi.doAs(new PrivilegedExceptionAction<String>() {
      @Override
      public String run() throws Exception {
        try (Connection conn = ConnectionFactory.createConnection(CONF);
          Admin admin = conn.getAdmin()) {
          admin.createTable(TableDescriptorBuilder.newBuilder(tableName)
            .setColumnFamily(ColumnFamilyDescriptorBuilder.of("f1")).build());
          UTIL.waitTableAvailable(tableName);
          try (Table t = conn.getTable(tableName)) {
            Put p = new Put(Bytes.toBytes("r1"));
            p.addColumn(Bytes.toBytes("f1"), Bytes.toBytes("q1"), Bytes.toBytes("1"));
            t.put(p);
          }
          return admin.getClusterMetrics().getClusterId();
        }
      }
    });
    assertNotNull(clusterId);
  }

  /** @return a copy of the cluster configuration using the parameterized RPC client. */
  private Configuration getClientConf() {
    Configuration conf = new Configuration(CONF);
    conf.set(RpcClientFactory.CUSTOM_RPC_CLIENT_IMPL_CONF_KEY, rpcClientImpl);
    return conf;
  }

  /** @return true when {@code message} is non-null and mentions a SaslException. */
  private static boolean mentionsSaslException(String message) {
    return message != null && message.contains("SaslException");
  }

  @Test
  public void testPositiveAuthentication() throws Exception {
    // Validate that we can read that record back out as the user with our custom auth'n
    UserGroupInformation user1 = UserGroupInformation.createUserForTesting("user1", new String[0]);
    user1.addToken(createPasswordToken("user1", USER1_PASSWORD, clusterId));
    user1.doAs(new PrivilegedExceptionAction<Void>() {
      @Override
      public Void run() throws Exception {
        try (Connection conn = ConnectionFactory.createConnection(getClientConf());
          Table t = conn.getTable(tableName)) {
          Result r = t.get(new Get(Bytes.toBytes("r1")));
          assertNotNull(r);
          assertFalse("Should have read a non-empty Result", r.isEmpty());
          final Cell cell = r.getColumnLatestCell(Bytes.toBytes("f1"), Bytes.toBytes("q1"));
          assertNotNull("Expected a cell for f1:q1", cell);
          assertTrue("Unexpected value", CellUtil.matchingValue(cell, Bytes.toBytes("1")));
          return null;
        }
      }
    });
  }

  @Test
  public void testNegativeAuthentication() throws Exception {
    // Validate that a user presenting the wrong password is rejected by the server
    UserGroupInformation user1 = UserGroupInformation.createUserForTesting("user1", new String[0]);
    user1.addToken(createPasswordToken("user1", "definitely not the password", clusterId));
    user1.doAs(new PrivilegedExceptionAction<Void>() {
      @Override
      public Void run() throws Exception {
        Configuration clientConf = getClientConf();
        clientConf.setInt(HConstants.HBASE_CLIENT_RETRIES_NUMBER, 1);
        // Depending on the registry in use, the following code can throw exceptions at different
        // places. Master registry fails at the createConnection() step because the RPC to the
        // master fails with sasl auth. With ZK registry, connection creation succeeds (since there
        // is no RPC to HBase services involved) but the subsequent get() fails. The root cause
        // should still be a SaslException in both the cases.
        try (Connection conn = ConnectionFactory.createConnection(clientConf);
          Table t = conn.getTable(tableName)) {
          t.get(new Get(Bytes.toBytes("r1")));
          fail("Should not successfully authenticate with HBase");
        } catch (MasterRegistryFetchException mfe) {
          Throwable cause = mfe.getCause();
          assertNotNull(Throwables.getStackTraceAsString(mfe), cause);
          assertTrue(String.valueOf(cause.getMessage()),
            mentionsSaslException(cause.getMessage()));
        } catch (RetriesExhaustedException re) {
          assertTrue(String.valueOf(re.getMessage()), mentionsSaslException(re.getMessage()));
        } catch (Exception e) {
          // Any other exception is unexpected.
          fail("Unexpected exception caught, was expecting a authentication error: "
            + Throwables.getStackTraceAsString(e));
        }
        return null;
      }
    });
  }
}