blob: 462e2abb210083b124a84070db9663a62cb5e3b5 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.cloud;
import java.io.IOException;
import java.lang.invoke.MethodHandles;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.function.Predicate;
import com.codahale.metrics.Counter;
import com.codahale.metrics.Meter;
import com.codahale.metrics.Metric;
import com.codahale.metrics.MetricRegistry;
import com.codahale.metrics.Timer;
import org.apache.http.HttpResponse;
import org.apache.http.client.HttpClient;
import org.apache.http.client.methods.HttpGet;
import org.apache.http.message.AbstractHttpMessage;
import org.apache.http.message.BasicHeader;
import org.apache.http.util.EntityUtils;
import org.apache.solr.client.solrj.embedded.JettySolrRunner;
import org.apache.solr.common.util.Base64;
import org.apache.solr.common.util.StrUtils;
import org.apache.solr.common.util.Utils;
import org.apache.solr.util.TimeOut;
import org.jose4j.jws.JsonWebSignature;
import org.jose4j.lang.JoseException;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static java.nio.charset.StandardCharsets.UTF_8;
/**
 * Base test class for cloud tests wanting to track authentication metrics.
 * The assertions provided by this base class require a *minimum* count, not an exact count, from metrics.
 * Warning: Make sure that your test case does not break when beasting.
 */
public class SolrCloudAuthTestCase extends SolrCloudTestCase {
  private static final Logger log = LoggerFactory.getLogger(MethodHandles.lookup().lookupClass());
  /** All auth metric names read (and summed across nodes) for authentication assertions. */
  private static final List<String> AUTH_METRICS_KEYS = Arrays.asList("errors", "requests", "authenticated",
      "passThrough", "failWrongCredentials", "failMissingCredentials", "requestTimes", "totalTime");
  /** Metric names that are read as {@link Meter}s by {@link #sumCount}. */
  private static final List<String> AUTH_METRICS_METER_KEYS = Arrays.asList("errors", "count");
  /** Metric names that are read as {@link Timer}s by {@link #sumCount}; all other names are read as {@link Counter}s. */
  private static final List<String> AUTH_METRICS_TIMER_KEYS = Collections.singletonList("requestTimes");
  private static final String METRICS_PREFIX_PKI = "SECURITY./authentication/pki.";
  private static final String METRICS_PREFIX = "SECURITY./authentication.";
  /** Convenience predicate for {@code verifySecurityStatus}: matches any non-null value. */
  @SuppressWarnings({"rawtypes"})
  public static final Predicate NOT_NULL_PREDICATE = o -> o != null;
  private static final List<String> AUDIT_METRICS_KEYS = Arrays.asList("count");
  /** Subset of {@link #AUTH_METRICS_KEYS} compared against the caller-supplied minimums. */
  private static final List<String> AUTH_METRICS_TO_COMPARE = Arrays.asList("requests", "authenticated", "passThrough", "failWrongCredentials", "failMissingCredentials", "errors");
  private static final List<String> AUDIT_METRICS_TO_COMPARE = Arrays.asList("count");

  /** Turns on metrics collection for the cluster started by the subclass. */
  @BeforeClass
  public static void enableMetrics() {
    System.setProperty("metricsEnabled", "true");
  }

  /** Clears the system property set by {@link #enableMetrics()}. */
  @AfterClass
  public static void disableMetrics() {
    System.clearProperty("metricsEnabled");
  }

  /**
   * Used to check metric counts for PKI auth.
   * Each argument is a minimum; the summed cluster-wide metric must be equal or larger.
   */
  protected void assertPkiAuthMetricsMinimums(int requests, int authenticated, int passThrough, int failWrongCredentials, int failMissingCredentials, int errors) throws InterruptedException {
    assertAuthMetricsMinimums(METRICS_PREFIX_PKI, requests, authenticated, passThrough, failWrongCredentials, failMissingCredentials, errors);
  }

  /**
   * Used to check metric counts for the AuthPlugin in use (except PKI).
   * Each argument is a minimum; the summed cluster-wide metric must be equal or larger.
   *
   * TODO: many of these params have to be underspecified - this should wait a bit to see the desired params and timeout
   */
  protected void assertAuthMetricsMinimums(int requests, int authenticated, int passThrough, int failWrongCredentials, int failMissingCredentials, int errors) throws InterruptedException {
    assertAuthMetricsMinimums(METRICS_PREFIX, requests, authenticated, passThrough, failWrongCredentials, failMissingCredentials, errors);
  }

  /**
   * Common test method to be able to check security from any authentication plugin.
   * Reads the {@code "solr.node"} metric registry of every Jetty in the cluster and sums each key across nodes.
   *
   * @param cluster the MiniSolrCloudCluster to fetch metrics from
   * @param prefix the metrics key prefix, currently "SECURITY./authentication." for basic auth and "SECURITY./authentication/pki." for PKI
   * @param keys what keys to examine
   * @return a map from each key to its value aggregated over all nodes (see {@link #sumCount})
   */
  Map<String,Long> countSecurityMetrics(MiniSolrCloudCluster cluster, String prefix, List<String> keys) {
    List<Map<String, Metric>> metrics = new ArrayList<>();
    cluster.getJettySolrRunners().forEach(r -> {
      MetricRegistry registry = r.getCoreContainer().getMetricManager().registry("solr.node");
      assertNotNull(registry);
      metrics.add(registry.getMetrics());
    });
    Map<String,Long> counts = new HashMap<>();
    keys.forEach(k -> {
      counts.put(k, sumCount(prefix, k, metrics));
    });
    return counts;
  }

  /**
   * Common test method to be able to check auth metrics from any authentication plugin.
   * When at least one request was counted, also asserts that the "requestTimes" and "totalTime" metrics were populated.
   *
   * @param prefix the metrics key prefix, currently "SECURITY./authentication." for basic auth and "SECURITY./authentication/pki." for PKI
   */
  private void assertAuthMetricsMinimums(String prefix, int requests, int authenticated, int passThrough, int failWrongCredentials, int failMissingCredentials, int errors) throws InterruptedException {
    Map<String, Long> expectedCounts = new HashMap<>();
    expectedCounts.put("requests", (long) requests);
    expectedCounts.put("authenticated", (long) authenticated);
    expectedCounts.put("passThrough", (long) passThrough);
    expectedCounts.put("failWrongCredentials", (long) failWrongCredentials);
    expectedCounts.put("failMissingCredentials", (long) failMissingCredentials);
    expectedCounts.put("errors", (long) errors);
    final Map<String, Long> counts = countSecurityMetrics(cluster, prefix, AUTH_METRICS_KEYS);
    final boolean success = isMetricsEqualOrLarger(AUTH_METRICS_TO_COMPARE, expectedCounts, counts);
    assertTrue("Expected metric minimums for prefix " + prefix + ": " + expectedCounts +
            ", but got: " + counts + "(Possible cause is delay in loading modified " +
            "security.json; see SOLR-13464 for test work around)",
        success);
    if (counts.get("requests") > 0) {
      assertTrue("requestTimes count not > 1", counts.get("requestTimes") > 1);
      assertTrue("totalTime not > 0", counts.get("totalTime") > 0);
    }
  }

  /**
   * Common test method to be able to check audit metrics.
   * If the first check fails, it waits 2 seconds and checks once more before asserting.
   *
   * @param cluster the cluster to fetch metrics from
   * @param className the audit plugin's class name used to compose the prefix
   *                  {@code "SECURITY./auditlogging." + className + "."}, e.g. "SolrLogAuditLoggerPlugin"
   * @param count minimum expected value of the summed "count" metric
   * @param errors currently NOT checked; only "count" is compared
   */
  protected void assertAuditMetricsMinimums(MiniSolrCloudCluster cluster, String className, int count, int errors) throws InterruptedException {
    String prefix = "SECURITY./auditlogging." + className + ".";
    Map<String, Long> expectedCounts = new HashMap<>();
    expectedCounts.put("count", (long) count);
    Map<String, Long> counts = countSecurityMetrics(cluster, prefix, AUDIT_METRICS_KEYS);
    boolean success = isMetricsEqualOrLarger(AUDIT_METRICS_TO_COMPARE, expectedCounts, counts);
    if (!success) {
      log.info("First metrics count assert failed, pausing 2s before re-attempt");
      Thread.sleep(2000);
      counts = countSecurityMetrics(cluster, prefix, AUDIT_METRICS_KEYS);
      success = isMetricsEqualOrLarger(AUDIT_METRICS_TO_COMPARE, expectedCounts, counts);
    }
    assertTrue("Expected metric minimums for prefix " + prefix + ": " + expectedCounts + ", but got: " + counts, success);
  }

  /**
   * Returns true if, for every key in {@code metricsToCompare}, the actual value is at least the expected one.
   * Values are compared as longs so that large counts are not truncated.
   */
  private boolean isMetricsEqualOrLarger(List<String> metricsToCompare, Map<String, Long> expectedCounts, Map<String, Long> actualCounts) {
    return metricsToCompare.stream()
        .allMatch(k -> actualCounts.get(k).longValue() >= expectedCounts.get(k).longValue());
  }

  /**
   * Aggregates one metric over all shards/nodes.
   * Meters and counters are summed via their counts; timers yield 1000 times the average of each node's
   * {@link Timer#getMeanRate()}, truncated to a long.
   */
  private long sumCount(String prefix, String key, List<Map<String, Metric>> metrics) {
    final String metricName = prefix + key;
    assertTrue("No node metrics available to read " + metricName + " from", !metrics.isEmpty());
    // Check every node, not just the first, so a missing metric fails with a clear message instead of an NPE
    for (Map<String, Metric> nodeMetrics : metrics) {
      assertTrue("Metric " + metricName + " does not exist", nodeMetrics.containsKey(metricName));
    }
    if (AUTH_METRICS_METER_KEYS.contains(key)) {
      return metrics.stream().mapToLong(l -> ((Meter) l.get(metricName)).getCount()).sum();
    } else if (AUTH_METRICS_TIMER_KEYS.contains(key)) {
      return (long) (1000 * metrics.stream().mapToDouble(l -> ((Timer) l.get(metricName)).getMeanRate()).average().orElse(0.0d));
    } else {
      return metrics.stream().mapToLong(l -> ((Counter) l.get(metricName)).getCount()).sum();
    }
  }

  /** Polls {@code url} without authentication; see the authHeader variant for semantics. */
  public static void verifySecurityStatus(HttpClient cl, String url, String objPath,
                                          Object expected, int count) throws Exception {
    verifySecurityStatus(cl, url, objPath, expected, count, (String)null);
  }

  /** Polls {@code url} using HTTP Basic credentials {@code user}/{@code pwd}. */
  public static void verifySecurityStatus(HttpClient cl, String url, String objPath,
                                          Object expected, int count, String user, String pwd)
      throws Exception {
    verifySecurityStatus(cl, url, objPath, expected, count, makeBasicAuthHeader(user, pwd));
  }

  /** Polls {@code url} using a Bearer token built from the compact serialization of {@code jws}. */
  protected void verifySecurityStatus(HttpClient cl, String url, String objPath,
                                      Object expected, int count, JsonWebSignature jws) throws Exception {
    verifySecurityStatus(cl, url, objPath, expected, count, getBearerAuthHeader(jws));
  }

  /**
   * GETs {@code url} up to {@code count} times, pausing 50ms between attempts, until the JSON value at
   * {@code objPath} (a '/'-separated path) matches {@code expected}, then asserts that a match was found.
   *
   * @param expected either a {@link Predicate} tested against the raw value, or a value compared with
   *                 {@link Objects#equals} to {@code String.valueOf(value)} (so pass a String, or null to match absence)
   * @param authHeader full Authorization header value, or null to send none
   */
  @SuppressWarnings({"unchecked"})
  private static void verifySecurityStatus(HttpClient cl, String url, String objPath,
                                           Object expected, int count, String authHeader) throws IOException, InterruptedException {
    boolean success = false;
    String s = null;
    List<String> hierarchy = StrUtils.splitSmart(objPath, '/');
    for (int i = 0; i < count; i++) {
      HttpGet get = new HttpGet(url);
      if (authHeader != null) setAuthorizationHeader(get, authHeader);
      HttpResponse rsp = cl.execute(get);
      @SuppressWarnings({"rawtypes"})
      Map m = null;
      try {
        s = EntityUtils.toString(rsp.getEntity());
        try {
          m = (Map) Utils.fromJSONString(s);
        } catch (Exception e) {
          // Keep the parse failure as the cause so the assertion is diagnosable
          throw new AssertionError("Invalid json " + s, e);
        }
      } finally {
        // Make sure the entity is consumed even if reading or parsing fails
        Utils.consumeFully(rsp.getEntity());
      }
      Object actual = Utils.getObjectByPath(m, true, hierarchy);
      final boolean matched;
      if (expected instanceof Predicate) {
        @SuppressWarnings({"rawtypes"})
        Predicate predicate = (Predicate) expected;
        matched = predicate.test(actual);
      } else {
        matched = Objects.equals(actual == null ? null : String.valueOf(actual), expected);
      }
      if (matched) {
        success = true;
        break;
      }
      // No point in sleeping after the final attempt
      if (i + 1 < count) {
        Thread.sleep(50);
      }
    }
    assertTrue("No match for " + objPath + " = " + expected + ", full response = " + s, success);
  }

  /** Builds an HTTP Basic Authorization header value from the UTF-8 bytes of {@code user:pwd}. */
  protected static String makeBasicAuthHeader(String user, String pwd) {
    String userPass = user + ":" + pwd;
    return "Basic " + Base64.byteArrayToBase64(userPass.getBytes(UTF_8));
  }

  /** Builds a Bearer Authorization header value from the compact serialization of {@code jws}. */
  static String getBearerAuthHeader(JsonWebSignature jws) throws JoseException {
    return "Bearer " + jws.getCompactSerialization();
  }

  /**
   * Sets the Authorization header on {@code httpMsg}.
   * Only the auth scheme is logged, because the remainder of the header carries credentials or tokens.
   */
  public static void setAuthorizationHeader(AbstractHttpMessage httpMsg, String headerString) {
    httpMsg.setHeader(new BasicHeader("Authorization", headerString));
    String scheme = headerString == null ? null : headerString.split(" ", 2)[0];
    log.info("Added Authorization Header with scheme {}", scheme);
  }

  /**
   * This helper method can be used by tests to monitor the current state of either
   * <code>"authentication"</code> or <code>"authorization"</code> plugins in use on each
   * node of the current cluster.
   * <p>
   * This can be useful in a {@link TimeOut#waitFor} loop to monitor a cluster and "wait for"
   * a change in security settings to affect all nodes by comparing the objects in the current
   * Map with the one in use prior to executing some test command (providing a work around
   * for the security user experience limitations identified in
   * <a href="https://issues.apache.org/jira/browse/SOLR-13464">SOLR-13464</a>).
   * </p>
   *
   * @param url A REST url (or any arbitrary String) ending in
   *            <code>"authentication"</code> or <code>"authorization"</code> used to specify the type of
   *            plugins to introspect
   * @return A Map from <code>nodeName</code> to auth plugin
   */
  public static Map<String,Object> getAuthPluginsInUseForCluster(String url) {
    Map<String,Object> plugins = new HashMap<>();
    if (url.endsWith("authentication")) {
      for (JettySolrRunner r : cluster.getJettySolrRunners()) {
        plugins.put(r.getNodeName(), r.getCoreContainer().getAuthenticationPlugin());
      }
    } else if (url.endsWith("authorization")) {
      for (JettySolrRunner r : cluster.getJettySolrRunners()) {
        plugins.put(r.getNodeName(), r.getCoreContainer().getAuthorizationPlugin());
      }
    } else {
      fail("Test helper method assumptions broken: " + url);
    }
    return plugins;
  }
}