// blob: 7c94702a217373344f8954ff2be468139e02d7c1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.cassandra.db.guardrails;
import java.io.Serializable;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.util.Arrays;
import java.util.Collections;
import java.util.List;
import java.util.Map;
import java.util.TreeSet;
import java.util.concurrent.CopyOnWriteArrayList;
import java.util.function.BiConsumer;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.function.Supplier;
import java.util.stream.Collectors;
import javax.annotation.Nullable;
import com.google.common.collect.ImmutableSet;
import org.junit.After;
import org.junit.Assert;
import org.junit.Before;
import org.junit.BeforeClass;
import org.apache.cassandra.auth.AuthenticatedUser;
import org.apache.cassandra.auth.CassandraRoleManager;
import org.apache.cassandra.config.DatabaseDescriptor;
import org.apache.cassandra.cql3.CQLStatement;
import org.apache.cassandra.cql3.CQLTester;
import org.apache.cassandra.cql3.QueryOptions;
import org.apache.cassandra.cql3.QueryProcessor;
import org.apache.cassandra.db.ConsistencyLevel;
import org.apache.cassandra.db.guardrails.GuardrailEvent.GuardrailEventType;
import org.apache.cassandra.db.view.View;
import org.apache.cassandra.diag.DiagnosticEventService;
import org.apache.cassandra.index.sasi.SASIIndex;
import org.apache.cassandra.service.ClientState;
import org.apache.cassandra.service.ClientWarn;
import org.apache.cassandra.service.QueryState;
import org.apache.cassandra.transport.ProtocolVersion;
import org.apache.cassandra.transport.messages.ResultMessage;
import org.apache.cassandra.utils.Clock;
import org.assertj.core.api.Assertions;
import static java.lang.String.format;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
public abstract class GuardrailTester extends CQLTester
{
// Name used when testing CREATE TABLE that should be aborted (we need to provide it as assertFails, which
// is used to assert the failure, does not know that it is a CREATE TABLE and would thus reuse the name of the
// previously created table, which is not what we want).
protected static final String FAIL_TABLE = "abort_table_creation_test";
// Credentials of the user created (if missing) before each test by beforeGuardrailTest().
private static final String USERNAME = "guardrail_user";
private static final String PASSWORD = "guardrail_password";
// Client states shared by all tests and initialised in setUpClass(): systemClientState is for internal calls,
// userClientState and superClientState are for external calls, the latter logged in as the default superuser.
protected static ClientState systemClientState, userClientState, superClientState;
/** The tested guardrail, if we are testing a specific one. */
@Nullable
protected final Guardrail guardrail;
/** A listener for emitted diagnostic events. */
protected final Listener listener;
/**
 * Creates a tester that is not bound to any specific guardrail.
 */
public GuardrailTester()
{
    this(null);
}

/**
 * Creates a tester for the given guardrail, together with its own diagnostic events listener.
 *
 * @param guardrail the guardrail under test, or {@code null} if no specific guardrail is being tested
 */
public GuardrailTester(@Nullable Guardrail guardrail)
{
    this.listener = new Listener();
    this.guardrail = guardrail;
}
/**
 * Class-level setup: runs the {@link CQLTester} class setup, requires authentication and networking, enables
 * diagnostic events and creates the client states shared by all the tests.
 */
@BeforeClass
public static void setUpClass()
{
CQLTester.setUpClass();
requireAuthentication();
requireNetwork();
// NOTE(review): presumably needed for the Listener to receive GuardrailEvents - confirm
DatabaseDescriptor.setDiagnosticEventsEnabled(true);
// client state for internal calls
systemClientState = ClientState.forInternalCalls();
// client states for external calls; within this method only the second one is logged in, as the default superuser
userClientState = ClientState.forExternalCalls(InetSocketAddress.createUnresolved("127.0.0.1", 123));
superClientState = ClientState.forExternalCalls(InetSocketAddress.createUnresolved("127.0.0.1", 321));
superClientState.login(new AuthenticatedUser(CassandraRoleManager.DEFAULT_SUPERUSER_NAME));
}
/**
 * Creates an ordinary user that is not excluded from guardrails, that is, a user that is neither a superuser
 * nor internal. It also issues a {@code USE} of the test keyspace on each of the shared client states and
 * subscribes the diagnostic events listener.
 */
@Before
public void beforeGuardrailTest() throws Throwable
{
// NOTE(review): presumably switches the network session to the superuser so the user can be created - confirm
useSuperUser();
executeNet(format("CREATE USER IF NOT EXISTS %s WITH PASSWORD '%s'", USERNAME, PASSWORD));
executeNet(format("GRANT ALL ON KEYSPACE %s TO %s", KEYSPACE, USERNAME));
useUser(USERNAME, PASSWORD);
String useKeyspaceQuery = "USE " + keyspace();
execute(userClientState, useKeyspaceQuery);
execute(systemClientState, useKeyspaceQuery);
execute(superClientState, useKeyspaceQuery);
// start recording guardrail diagnostic events; undone in afterGuardrailTest()
DiagnosticEventService.instance().subscribe(GuardrailEvent.class, listener);
}
/**
 * Unsubscribes the diagnostic events listener that was subscribed by {@link #beforeGuardrailTest()}.
 */
@After
public void afterGuardrailTest() throws Throwable
{
DiagnosticEventService.instance().unsubscribe(listener);
}
/**
 * @return {@link Guardrails#instance}, the object on which the property setters and getters under test are invoked
 */
static Guardrails guardrails()
{
return Guardrails.instance;
}
/**
 * Applies the given value through the given setter, expecting the setter to accept it.
 * <p>
 * No explicit assertion is made: a rejected value surfaces as whatever exception the setter throws.
 *
 * @param setter the property setter to invoke on {@link #guardrails()}
 * @param value the value to set
 */
protected <T> void assertValidProperty(BiConsumer<Guardrails, T> setter, T value)
{
    Guardrails target = guardrails();
    setter.accept(target, value);
}

/**
 * Applies the given value through the given setter and verifies that the getter then returns an equal value.
 *
 * @param setter the property setter to invoke on {@link #guardrails()}
 * @param getter the property getter used to read the value back
 * @param value the value to set and expect back
 */
protected <T> void assertValidProperty(BiConsumer<Guardrails, T> setter, Function<Guardrails, T> getter, T value)
{
    setter.accept(guardrails(), value);
    T readBack = getter.apply(guardrails());
    assertEquals(value, readBack);
}
/**
 * Asserts that applying the given value through the given setter is rejected with an
 * {@link IllegalArgumentException} carrying the expected message.
 *
 * @param setter the property setter to invoke on {@link #guardrails()}
 * @param value the value expected to be rejected
 * @param message the expected exception message, as a {@link String#format} pattern
 * @param messageArgs the arguments for the message pattern
 */
protected <T> void assertInvalidProperty(BiConsumer<Guardrails, T> setter,
T value,
String message,
Object... messageArgs)
{
Assertions.assertThatThrownBy(() -> setter.accept(guardrails(), value))
.isInstanceOf(IllegalArgumentException.class)
.hasMessage(format(message, messageArgs));
}
/**
 * Verifies that the given queries trigger no guardrail when run with the superuser client state and with the
 * internal client state, in that order.
 *
 * @param queries suppliers of the queries to run; each supplier is invoked once per client state
 */
@SafeVarargs
protected final void testExcludedUsers(Supplier<String>... queries) throws Throwable
{
// NOTE(review): this one-argument call presumably resolves to an execute method inherited from CQLTester - confirm
execute("USE " + keyspace());
assertSuperuserIsExcluded(queries);
assertInternalQueriesAreExcluded(queries);
}
/**
 * Runs each of the supplied queries with the internal client state, asserting that none of them triggers a
 * guardrail.
 *
 * @param queries suppliers of the queries to run, evaluated lazily and in order
 */
@SafeVarargs
private final void assertInternalQueriesAreExcluded(Supplier<String>... queries) throws Throwable
{
    for (int i = 0; i < queries.length; i++)
    {
        Supplier<String> supplier = queries[i];
        CheckedFunction asInternal = () -> execute(systemClientState, supplier.get());
        assertValid(asInternal);
    }
}
/**
 * Runs each of the supplied queries with the superuser client state, asserting that none of them triggers a
 * guardrail.
 *
 * @param queries suppliers of the queries to run, evaluated lazily and in order
 */
@SafeVarargs
private final void assertSuperuserIsExcluded(Supplier<String>... queries) throws Throwable
{
    for (int i = 0; i < queries.length; i++)
    {
        Supplier<String> supplier = queries[i];
        CheckedFunction asSuperuser = () -> execute(superClientState, supplier.get());
        assertValid(asSuperuser);
    }
}
/**
 * Asserts that the given function triggers no guardrail: no client warning (other than the usage warnings
 * filtered out by {@link #getWarnings()}), no warning or failure diagnostic event, and no
 * {@link GuardrailViolatedException}.
 * <p>
 * Any other exception thrown by the function is propagated. Captured client warnings and recorded diagnostic
 * events are always reset afterwards.
 *
 * @param function the code to run
 */
protected void assertValid(CheckedFunction function) throws Throwable
{
ClientWarn.instance.captureWarnings();
try
{
function.apply();
assertEmptyWarnings();
listener.assertNotWarned();
listener.assertNotFailed();
}
catch (GuardrailViolatedException e)
{
fail("Expected not to fail, but failed with error message: " + e.getMessage());
}
finally
{
// leave no captured warnings or recorded events behind for the next assertion
ClientWarn.instance.resetWarnings();
listener.clear();
}
}
/**
 * Asserts that the given query, run with {@code userClientState}, triggers no guardrail.
 *
 * @param query the CQL query to run
 * @see #assertValid(CheckedFunction)
 */
protected void assertValid(String query) throws Throwable
{
assertValid(() -> execute(userClientState, query));
}
/**
 * Asserts that the given query, run with {@code userClientState}, emits a single guardrail warning containing
 * the given message, both as client warning and as diagnostic event.
 */
protected void assertWarns(String query, String message) throws Throwable
{
assertWarns(query, message, message);
}
/**
 * Same as {@link #assertWarns(String, String)}, but the diagnostic event is expected to contain
 * {@code redactedMessage} instead of {@code message}.
 */
protected void assertWarns(String query, String message, String redactedMessage) throws Throwable
{
assertWarns(query, Collections.singletonList(message), Collections.singletonList(redactedMessage));
}
/**
 * Asserts that the given query, run with {@code userClientState}, emits one guardrail warning per given
 * message, in order, both as client warnings and as diagnostic events.
 */
protected void assertWarns(String query, List<String> messages) throws Throwable
{
assertWarns(() -> execute(userClientState, query), messages, messages);
}
/**
 * Same as {@link #assertWarns(String, List)}, but the diagnostic events are expected to contain
 * {@code redactedMessages} instead of {@code messages}.
 */
protected void assertWarns(String query, List<String> messages, List<String> redactedMessages) throws Throwable
{
assertWarns(() -> execute(userClientState, query), messages, redactedMessages);
}
/**
 * Asserts that the given function emits a single guardrail warning containing the given message, both as
 * client warning and as diagnostic event.
 */
protected void assertWarns(CheckedFunction function, String message) throws Throwable
{
assertWarns(function, message, message);
}
/**
 * Same as {@link #assertWarns(CheckedFunction, String)}, but the diagnostic event is expected to contain
 * {@code redactedMessage} instead of {@code message}.
 */
protected void assertWarns(CheckedFunction function, String message, String redactedMessage) throws Throwable
{
assertWarns(function, Collections.singletonList(message), Collections.singletonList(redactedMessage));
}
/**
 * Asserts that the given function emits one guardrail warning per given message, in order, both as client
 * warnings and as diagnostic events.
 */
protected void assertWarns(CheckedFunction function, List<String> messages) throws Throwable
{
assertWarns(function, messages, messages);
}
/**
 * Asserts that the given function emits guardrail warnings but no guardrail failure.
 * <p>
 * The captured client warnings must match {@code messages} and the warning diagnostic events must match
 * {@code redactedMessages}, both in number and in order; no failure diagnostic event may have been emitted.
 * Any exception thrown by the function is propagated. Captured client warnings and recorded diagnostic events
 * are always reset afterwards.
 *
 * @param function the code to run
 * @param messages the texts expected to be contained in the client warnings, in order
 * @param redactedMessages the texts expected to be contained in the warning diagnostic events, in order
 */
protected void assertWarns(CheckedFunction function, List<String> messages, List<String> redactedMessages) throws Throwable
{
// We use client warnings to check we properly warn as this is the most convenient. Technically,
// this doesn't validate we also log the warning, but that's probably fine ...
ClientWarn.instance.captureWarnings();
try
{
function.apply();
assertWarnings(messages);
listener.assertWarned(redactedMessages);
listener.assertNotFailed();
}
finally
{
// leave no captured warnings or recorded events behind for the next assertion
ClientWarn.instance.resetWarnings();
listener.clear();
}
}
/**
 * Asserts that the given query, run with {@code userClientState}, is aborted by a guardrail failure containing
 * the given message.
 */
protected void assertFails(String query, String message) throws Throwable
{
assertFails(query, message, message);
}
/**
 * Same as {@link #assertFails(String, String)}, but the diagnostic event is expected to contain
 * {@code redactedMessage} instead of {@code message}.
 */
protected void assertFails(String query, String message, String redactedMessage) throws Throwable
{
assertFails(query, Collections.singletonList(message), Collections.singletonList(redactedMessage));
}
/**
 * Asserts that the given query, run with {@code userClientState}, is aborted by a guardrail failure. The last
 * message is the expected failure, any previous messages are the expected warnings.
 */
protected void assertFails(String query, List<String> messages) throws Throwable
{
assertFails(query, messages, messages);
}
/**
 * Same as {@link #assertFails(String, List)}, but the diagnostic events are expected to contain
 * {@code redactedMessages} instead of {@code messages}.
 */
protected void assertFails(String query, List<String> messages, List<String> redactedMessages) throws Throwable
{
assertFails(() -> execute(userClientState, query), messages, redactedMessages);
}
/**
 * Asserts that the given function throws a guardrail failure containing the given message.
 */
protected void assertFails(CheckedFunction function, String message) throws Throwable
{
assertFails(function, message, message);
}
/**
 * Same as {@link #assertFails(CheckedFunction, String)}, but the diagnostic event is expected to contain
 * {@code redactedMessage} instead of {@code message}.
 */
protected void assertFails(CheckedFunction function, String message, String redactedMessage) throws Throwable
{
assertFails(function, true, message, redactedMessage);
}
/**
 * Asserts a guardrail failure with the given message, where {@code thrown} tells whether the function is
 * expected to throw a {@link GuardrailViolatedException}.
 */
protected void assertFails(CheckedFunction function, boolean thrown, String message) throws Throwable
{
assertFails(function, thrown, message, message);
}
/**
 * Same as {@link #assertFails(CheckedFunction, boolean, String)}, but the diagnostic event is expected to
 * contain {@code redactedMessage} instead of {@code message}.
 */
protected void assertFails(CheckedFunction function, boolean thrown, String message, String redactedMessage) throws Throwable
{
assertFails(function, thrown, Collections.singletonList(message), Collections.singletonList(redactedMessage));
}
/**
 * Asserts that the given function throws a guardrail failure. The last message is the expected failure, any
 * previous messages are the expected warnings.
 */
protected void assertFails(CheckedFunction function, List<String> messages) throws Throwable
{
assertFails(function, messages, messages);
}
/**
 * Same as {@link #assertFails(CheckedFunction, List)}, but the diagnostic events are expected to contain
 * {@code redactedMessages} instead of {@code messages}.
 */
protected void assertFails(CheckedFunction function, List<String> messages, List<String> redactedMessages) throws Throwable
{
assertFails(function, true, messages, redactedMessages);
}
/**
 * Asserts that the given function triggers a guardrail failure.
 * <p>
 * If a {@link GuardrailViolatedException} is thrown, this verifies that {@code thrown} is {@code true}, that the
 * exception message contains the last of {@code messages} (and starts with the guardrail prefix when a specific
 * guardrail is tested), that the captured client warnings match all of {@code messages}, and that the
 * diagnostic events match {@code redactedMessages}: all but the last as warnings, the last one as the failure.
 * If nothing is thrown while {@code thrown} is {@code true}, the assertion fails.
 * <p>
 * NOTE(review): if {@code thrown} is {@code false} and nothing is thrown, none of the messages is verified;
 * neither client warnings nor diagnostic events are inspected in that case.
 * <p>
 * Other exceptions thrown by the function are propagated. Captured client warnings and recorded diagnostic
 * events are always reset afterwards.
 *
 * @param function the code to run
 * @param thrown whether a {@link GuardrailViolatedException} is expected to be thrown
 * @param messages the expected messages; must not be empty when an exception is thrown
 * @param redactedMessages the expected diagnostic event messages; indexed using the size of {@code messages}, so
 * both lists are expected to have the same length
 */
protected void assertFails(CheckedFunction function, boolean thrown, List<String> messages, List<String> redactedMessages) throws Throwable
{
ClientWarn.instance.captureWarnings();
try
{
function.apply();
if (thrown)
fail("Expected to fail, but it did not");
}
catch (GuardrailViolatedException e)
{
assertTrue("Expect no exception thrown", thrown);
// the last message is the one raising the guardrail failure, the previous messages are warnings
String failMessage = messages.get(messages.size() - 1);
if (guardrail != null)
{
// decorating an empty message yields whatever prefix the guardrail prepends to its messages
String prefix = guardrail.decorateMessage("");
assertTrue(format("Full error message '%s' doesn't start with the prefix '%s'", e.getMessage(), prefix),
e.getMessage().startsWith(prefix));
}
assertTrue(format("Full error message '%s' does not contain expected message '%s'", e.getMessage(), failMessage),
e.getMessage().contains(failMessage));
// the failure message is expected among the client warnings too, after the actual warnings
assertWarnings(messages);
if (messages.size() > 1)
listener.assertWarned(redactedMessages.subList(0, messages.size() - 1));
else
listener.assertNotWarned();
listener.assertFailed(redactedMessages.get(messages.size() - 1));
}
finally
{
// leave no captured warnings or recorded events behind for the next assertion
ClientWarn.instance.resetWarnings();
listener.clear();
}
}
/**
 * Asserts that the given query, run with {@code userClientState}, is aborted by a guardrail failure. The last
 * message is the expected failure, any previous messages are the expected warnings.
 * <p>
 * Note that a call passing exactly one or two messages as separate arguments does not reach this method: Java
 * overload resolution prefers the fixed-arity {@link #assertFails(String, String)} and
 * {@link #assertFails(String, String, String)}, the latter treating the second message as the redacted one.
 *
 * @param query the CQL query to run
 * @param messages the expected messages, in order
 */
protected void assertFails(String query, String... messages) throws Throwable
{
assertFails(() -> execute(userClientState, query), Arrays.asList(messages));
}
/**
 * Asserts that the given function throws an exception of the expected type whose message contains the expected
 * text.
 *
 * @param function the code expected to throw
 * @param exception the expected exception type; subclasses of it are accepted too
 * @param message the text that the message of the thrown exception must contain
 */
protected void assertThrows(CheckedFunction function, Class<? extends Throwable> exception, String message)
{
    Throwable thrown = null;
    try
    {
        function.apply();
    }
    catch (Throwable e)
    {
        thrown = e;
    }

    // This is checked outside the try block so the AssertionError raised by fail() is not caught by the
    // catch-all above and misreported as an exception of an unexpected type.
    if (thrown == null)
        fail("Expected to fail, but it did not");

    if (!exception.isAssignableFrom(thrown.getClass()))
        Assert.fail(format("Expected to fail with %s but got %s", exception.getName(), thrown.getClass().getName()));

    // An exception without message is reported as a regular assertion failure rather than as a NPE
    String actualMessage = thrown.getMessage();
    assertTrue(format("Error message '%s' does not contain expected message '%s'", actualMessage, message),
               actualMessage != null && actualMessage.contains(message));
}
/**
 * Asserts that the captured client warnings match the expected messages, in number and in order. Each warning
 * must contain the corresponding message and, if a specific guardrail is tested, start with the prefix that
 * the guardrail prepends to its messages.
 *
 * @param messages the texts expected to be contained in the client warnings, in order
 */
private void assertWarnings(List<String> messages)
{
    List<String> received = getWarnings();
    boolean nothingReceived = received == null || received.isEmpty();
    assertFalse("Expected to warn, but no warning was received", nothingReceived);

    int expectedCount = messages.size();
    int receivedCount = received.size();
    assertEquals(format("Expected %d warnings but got %d: %s", expectedCount, receivedCount, received),
                 expectedCount,
                 receivedCount);

    for (int idx = 0; idx < expectedCount; idx++)
    {
        String expected = messages.get(idx);
        String warning = received.get(idx);

        if (guardrail != null)
        {
            // decorating an empty message yields whatever prefix the guardrail prepends to its messages
            String prefix = guardrail.decorateMessage("");
            assertTrue(format("Warning log message '%s' doesn't start with the prefix '%s'", warning, prefix),
                       warning.startsWith(prefix));
        }

        assertTrue(format("Warning log message '%s' does not contain expected message '%s'", warning, expected),
                   warning.contains(expected));
    }
}
/**
 * Asserts that no client warning has been captured, other than those filtered out by {@link #getWarnings()}.
 */
private void assertEmptyWarnings()
{
    List<String> received = getWarnings();
    // getWarnings() as defined in this class never returns null, but it can be overridden, so stay defensive
    List<String> warnings = received != null ? received : Collections.<String>emptyList();
    assertTrue(format("Expect no warning messages but got %s", warnings), warnings.isEmpty());
}
/**
 * Returns the client warnings captured so far, leaving out the usage warnings of materialized views and SASI
 * indexes.
 *
 * @return the remaining warnings, or an empty list if no warnings are available; never {@code null}
 */
protected List<String> getWarnings()
{
    List<String> captured = ClientWarn.instance.getWarnings();
    if (captured == null)
        return Collections.emptyList();

    return captured.stream()
                   .filter(warning -> !(warning.equals(View.USAGE_WARNING) || warning.equals(SASIIndex.USAGE_WARNING)))
                   .collect(Collectors.toList());
}
/**
 * Applies the given configuration change to {@link #guardrails()}, expecting it to be accepted.
 * <p>
 * No explicit assertion is made: a rejected change surfaces as whatever exception the consumer throws.
 *
 * @param consumer the configuration change to apply
 */
protected void assertConfigValid(Consumer<Guardrails> consumer)
{
consumer.accept(guardrails());
}
/**
 * Asserts that applying the given configuration change to {@link #guardrails()} is rejected with an
 * {@link IllegalArgumentException} whose message contains the expected text.
 * <p>
 * Exceptions of any other type are propagated.
 *
 * @param consumer the configuration change expected to be rejected
 * @param message the text that the message of the thrown exception must contain
 */
protected void assertConfigFails(Consumer<Guardrails> consumer, String message)
{
    try
    {
        consumer.accept(guardrails());
        fail("Expected failure");
    }
    catch (IllegalArgumentException e)
    {
        // An exception without message is reported as a regular assertion failure rather than as a NPE
        String actualMessage = e.getMessage();
        assertTrue(format("Failure message '%s' does not contain expected message '%s'", actualMessage, message),
                   actualMessage != null && actualMessage.contains(message));
    }
}
/**
 * Executes the given query with the given client state and no bound values.
 */
protected ResultMessage execute(ClientState state, String query)
{
return execute(state, query, Collections.emptyList());
}
/**
 * Executes the given query with the given client state and bound values, using the query options produced by
 * {@link QueryOptions#forInternalCalls} for those values.
 */
protected ResultMessage execute(ClientState state, String query, List<ByteBuffer> values)
{
QueryOptions options = QueryOptions.forInternalCalls(values);
return execute(state, query, options);
}
/**
 * Executes the given query with the given client state and consistency level, passing {@code null} as serial
 * consistency level.
 */
protected ResultMessage execute(ClientState state, String query, ConsistencyLevel cl)
{
return execute(state, query, cl, null);
}
/**
 * Executes the given query with the given client state, using query options built for the given consistency
 * levels, no bound values, the current protocol version and the test keyspace.
 *
 * @param state the client state to run the query with
 * @param query the CQL query to run
 * @param cl the consistency level
 * @param serialCl the serial consistency level, possibly {@code null}
 * @return the result of the query
 */
protected ResultMessage execute(ClientState state, String query, ConsistencyLevel cl, ConsistencyLevel serialCl)
{
    return execute(state,
                   query,
                   QueryOptions.create(cl,
                                       Collections.emptyList(),
                                       false,
                                       10,
                                       null,
                                       serialCl,
                                       ProtocolVersion.CURRENT,
                                       KEYSPACE));
}
/**
 * Parses, validates and executes the given query with the given client state and query options.
 * <p>
 * The query is passed through {@code formatQuery} before being parsed.
 *
 * @param state the client state to run the query with
 * @param query the CQL query to run
 * @param options the query options
 * @return the result of the query
 */
protected ResultMessage execute(ClientState state, String query, QueryOptions options)
{
    QueryState queryState = new QueryState(state);
    CQLStatement parsed = QueryProcessor.parseStatement(formatQuery(query), queryState.getClientState());
    parsed.validate(state);
    return parsed.execute(queryState, options, Clock.Global.nanoTime());
}
/**
 * Returns the comma-separated values of the given string sorted in their natural order, without duplicates.
 *
 * @param csv a string of comma-separated values
 * @return the distinct values of {@code csv}, sorted and joined with commas
 */
protected static String sortCSV(String csv)
{
    // A TreeSet both sorts and removes duplicates, so there is no need for an intermediate set copy
    return String.join(",", new TreeSet<>(Arrays.asList(csv.split(","))));
}
/**
 * A listener for guardrails diagnostic events.
 * <p>
 * Every received event is recorded, as the string form of its map representation, as either a warning or a
 * failure, so tests can later assert on what has been emitted. If the enclosing tester targets a specific
 * guardrail, every received event is also checked to carry the name of that guardrail.
 */
public class Listener implements Consumer<GuardrailEvent>
{
    /** The string forms of the received {@code WARNED} events, in order of reception. */
    private final List<String> warnings = new CopyOnWriteArrayList<>();

    /** The string forms of the received {@code FAILED} events, in order of reception. */
    private final List<String> failures = new CopyOnWriteArrayList<>();

    /**
     * Records the given event according to its type, failing on events of any unexpected type.
     *
     * @param event the emitted diagnostic event, must not be {@code null}
     */
    @Override
    public void accept(GuardrailEvent event)
    {
        assertNotNull(event);
        Map<String, Serializable> map = event.toMap();

        if (guardrail != null)
            assertEquals(guardrail.name, map.get("name"));

        GuardrailEventType type = (GuardrailEventType) event.getType();
        String message = map.toString();

        switch (type)
        {
            case WARNED:
                warnings.add(message);
                break;
            case FAILED:
                failures.add(message);
                break;
            default:
                fail("Unexpected diagnostic event:" + type);
        }
    }

    /**
     * Forgets all the events recorded so far.
     */
    public void clear()
    {
        warnings.clear();
        failures.clear();
    }

    /**
     * Asserts that no warning event has been recorded.
     */
    public void assertNotWarned()
    {
        assertTrue(format("Expect no warning diagnostic events but got %s", warnings), warnings.isEmpty());
    }

    /**
     * Asserts that exactly one warning event has been recorded, and that it contains the given message.
     *
     * @param message the text expected to be contained in the recorded warning event
     */
    public void assertWarned(String message)
    {
        assertWarned(Collections.singletonList(message));
    }

    /**
     * Asserts that the recorded warning events match the given messages, in number and in order.
     *
     * @param messages the texts expected to be contained in the recorded warning events, in order
     */
    public void assertWarned(List<String> messages)
    {
        assertFalse("Expected to emit warning diagnostic event, but no warning was emitted", warnings.isEmpty());
        assertEquals(format("Expected %d warning diagnostic events but got %d: %s", messages.size(), warnings.size(), warnings),
                     messages.size(), warnings.size());

        for (int i = 0; i < messages.size(); i++)
        {
            String message = messages.get(i);
            String warning = warnings.get(i);
            assertTrue(format("Warning diagnostic event '%s' does not contain expected message '%s'", warning, message),
                       warning.contains(message));
        }
    }

    /**
     * Asserts that no failure event has been recorded.
     */
    public void assertNotFailed()
    {
        assertTrue(format("Expect no failure diagnostic events but got %s", failures), failures.isEmpty());
    }

    /**
     * Asserts that the recorded failure events match the given messages, in number and in order.
     *
     * @param messages the texts expected to be contained in the recorded failure events, in order
     */
    public void assertFailed(String... messages)
    {
        assertFalse("Expected to emit failure diagnostic event, but no failure was emitted", failures.isEmpty());
        assertEquals(format("Expected %d failure diagnostic events but got %d: %s", messages.length, failures.size(), failures),
                     messages.length, failures.size());

        for (int i = 0; i < messages.length; i++)
        {
            String message = messages[i];
            String failure = failures.get(i);
            assertTrue(format("Failure diagnostic event '%s' does not contain expected message '%s'", failure, message),
                       failure.contains(message));
        }
    }
}
}