blob: 20959e2a0258a07eff783f44254fffc432e20648 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.solr.encryption;
import com.carrotsearch.randomizedtesting.generators.RandomStrings;
import org.apache.solr.client.solrj.SolrQuery;
import org.apache.solr.client.solrj.SolrRequest;
import org.apache.solr.client.solrj.SolrServerException;
import org.apache.solr.client.solrj.impl.CloudSolrClient;
import org.apache.solr.client.solrj.request.CollectionAdminRequest;
import org.apache.solr.client.solrj.request.GenericSolrRequest;
import org.apache.solr.client.solrj.response.QueryResponse;
import org.apache.solr.cloud.MiniSolrCloudCluster;
import org.apache.solr.cloud.SolrCloudTestCase;
import org.apache.solr.common.SolrInputDocument;
import org.apache.solr.common.params.ModifiableSolrParams;
import org.apache.solr.common.util.NamedList;
import org.junit.After;
import org.junit.AfterClass;
import org.junit.BeforeClass;
import org.junit.Test;
import org.junit.runner.RunWith;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashSet;
import java.util.List;
import java.util.Random;
import java.util.Set;
import java.util.UUID;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicLong;
import java.util.function.Supplier;
import static org.apache.solr.encryption.EncryptionRequestHandler.*;
import static org.apache.solr.encryption.TestingKeyManager.*;
/**
* Tests the encryption handler under heavy concurrent load.
* <p>
* Sends concurrent indexing and querying requests with high throughput while
* triggering re-encryption with the handler, to verify that concurrent segment
* merging is handled correctly without stopping indexing or querying, and that
* all encrypted files are decrypted correctly when the index searcher is
* refreshed after each commit.
*/
@RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class)
public class EncryptionHeavyLoadTest extends SolrCloudTestCase {
// Change the test duration manually to run longer, e.g. 3 minutes.
// Duration of the main re-encryption loop of the test method.
private static final long TEST_DURATION_MS = TimeUnit.SECONDS.toMillis(10);
// Exclusive upper bound (ms) of the random pause an indexing thread takes before each batch.
private static final int RANDOM_DELAY_BETWEEN_INDEXING_BATCHES_MS = 50;
// Maximum number of docs in an indexing batch; the batch size is random in [1, this value].
private static final int RANDOM_NUM_DOCS_PER_BATCH = 200;
// Probability that an indexing thread commits right after sending a batch.
private static final float PROBABILITY_OF_COMMIT_PER_BATCH = 0.33f;
// Number of distinct random terms in the dictionary used for both indexing and querying.
private static final int DICTIONARY_SIZE = 5000;
// Exclusive upper bound (ms) of the random pause a querying thread takes before each query.
private static final int RANDOM_DELAY_BETWEEN_QUERIES_MS = 10;
// Number of worker threads started with the Indexer runnable.
private static final int NUM_INDEXING_THREADS = 3;
// Number of worker threads started with the Querier runnable.
private static final int NUM_QUERYING_THREADS = 2;
// Exclusive upper bound (ms) of the random pause between two re-encryption requests.
private static final int RANDOM_DELAY_BETWEEN_REENCRYPTION_MS = 2000;
// Key ids requested in round-robin order by nextKeyId().
// NOTE(review): NO_KEY_ID presumably requests a non-encrypted index - confirm in TestingKeyManager.
private static final String[] KEY_IDS = {KEY_ID_1, KEY_ID_2, KEY_ID_3, NO_KEY_ID};
// Probability that the test waits for a re-encryption to complete before continuing.
private static final float PROBABILITY_OF_WAITING_ENCRYPTION_COMPLETION = 0.5f;
// Prefix of the collection name; setUp() appends a random UUID to it.
private static final String COLLECTION_PREFIX = EncryptionHeavyLoadTest.class.getSimpleName() + "-collection-";
// Prefix of every line written by print().
private static final String SYSTEM_OUTPUT_MARKER = "*** ";
// Client used by the test thread and by all worker threads.
private volatile CloudSolrClient solrClient;
// Set to true by tearDown() and by every worker thread when it terminates; workers poll it to stop.
private volatile boolean stopTest;
// Terms picked at random by the indexing and querying threads.
private volatile Dictionary dictionary;
// Worker threads started by startThreads(); joined in tearDown().
private List<Thread> threads;
// Index in KEY_IDS of the key id that the next nextKeyId() call returns.
private int nextKeyIndex;
// Key id returned by the latest nextKeyId() call; reused by tearDown() for the final encryption.
private String keyId;
// Latest non-interruption exception caught by a worker thread; rethrown by the test method.
private volatile Exception exception;
// Start time, deadline and last "elapsed time" display time, all from System.currentTimeMillis().
private long startTimeMs;
private long endTimeMs;
private long lastDisplayTimeMs;
/**
 * Sets the install dir property and configures the shared {@code cluster}
 * with the "collection1" test config registered under the name "config".
 * The builder receives 1 and a fresh temp dir
 * (presumably node count and base directory - see MiniSolrCloudCluster.Builder).
 */
@BeforeClass
public static void beforeClass() throws Exception {
  TestUtil.setInstallDirProperty();
  MiniSolrCloudCluster.Builder clusterBuilder =
      new MiniSolrCloudCluster.Builder(1, createTempDir())
          .addConfig("config", TestUtil.getConfigPath("collection1"));
  cluster = clusterBuilder.configure();
}
/**
 * Shuts down the shared cluster.
 * <p>
 * Does nothing when the cluster was never assigned (e.g. {@code beforeClass()}
 * failed), so that a start-up failure is not followed by a NullPointerException here.
 */
@AfterClass
public static void afterClass() throws Exception {
  if (cluster != null) {
    cluster.shutdown();
  }
}
/**
 * Creates a uniquely named collection (1, 1 passed to createCollection), makes it
 * the default collection of the cluster client, waits for it to be active, then
 * builds the term dictionary and resets the worker thread list.
 */
@Override
public void setUp() throws Exception {
  super.setUp();
  final String name = COLLECTION_PREFIX + UUID.randomUUID();
  final CloudSolrClient client = cluster.getSolrClient();
  solrClient = client;
  client.setDefaultCollection(name);
  CollectionAdminRequest.createCollection(name, 1, 1).process(client);
  cluster.waitForActiveCollection(name, 1, 1);
  final Dictionary.Builder dictionaryBuilder = new Dictionary.Builder();
  dictionary = dictionaryBuilder.build(DICTIONARY_SIZE, random());
  threads = new ArrayList<>();
}
/**
 * Stops the worker threads, then waits (up to 20 seconds) for the encryption
 * with the latest requested key id to complete.
 * <p>
 * Tolerates a partially initialized test: the thread list may be null when
 * {@code setUp()} failed, and the key id may be null when no encryption was
 * ever requested; in both cases the corresponding step is skipped instead of
 * failing with a secondary error. {@code super.tearDown()} always runs.
 */
@Override
@After
public void tearDown() throws Exception {
  try {
    stopTest = true;
    if (threads != null) {
      for (Thread thread : threads) {
        try {
          thread.join(5000);
          // join() also returns when the timeout expires, so check the thread really ended.
          if (thread.isAlive()) {
            print(thread.getName() + " still running after 5000 ms");
          } else {
            print(thread.getName() + " stopped");
          }
        } catch (InterruptedException e) {
          System.err.println("Interrupted while closing " + thread.getName());
        }
      }
    }
    if (keyId != null) {
      // Reset the clock: encrypt() relies on isTimeElapsed() as its timeout.
      startTimeMs = lastDisplayTimeMs = System.currentTimeMillis();
      endTimeMs = startTimeMs + TimeUnit.SECONDS.toMillis(20);
      print("waiting for the final encryption completion");
      assertTrue("Timeout waiting for the final encryption completion", encrypt(keyId, true));
      print("final encryption complete");
    }
  } finally {
    super.tearDown();
  }
}
/**
 * Runs indexing and querying worker threads while repeatedly requesting
 * re-encryption with the next key id, for {@code TEST_DURATION_MS}.
 * <p>
 * The loop ends early, after the current iteration, once a worker thread has
 * terminated (workers raise {@code stopTest} when they exit), because the
 * remaining time would no longer exercise concurrent load. Any exception
 * recorded by a worker is rethrown at the end.
 */
@Test
public void testReencryptionUnderHeavyConcurrentLoad() throws Exception {
  print("Starting test");
  startTimeMs = lastDisplayTimeMs = System.currentTimeMillis();
  endTimeMs = startTimeMs + TEST_DURATION_MS;
  Random random = random();
  if (random.nextBoolean()) {
    // Randomly request encryption before any document is indexed.
    print("preparing empty index for encryption");
    encrypt(nextKeyId(), waitForCompletion(random));
  }
  startThreads(NUM_INDEXING_THREADS, "Indexing", Indexer::new);
  startThreads(NUM_QUERYING_THREADS, "Querying", Querier::new);
  while (!isTimeElapsed()) {
    Thread.sleep(random.nextInt(RANDOM_DELAY_BETWEEN_REENCRYPTION_MS));
    encrypt(nextKeyId(), waitForCompletion(random));
    // Checked after the encryption request so at least one key id is always requested.
    if (stopTest) {
      print("A worker thread terminated, stopping re-encryption early");
      break;
    }
  }
  if (System.currentTimeMillis() - lastDisplayTimeMs >= 1000) {
    print("elapsed time = " + ((System.currentTimeMillis() - startTimeMs) / 1000) + " s");
  }
  print("Stopping test");
  if (exception != null) {
    throw exception;
  }
}
/**
 * Creates, registers and starts {@code numThreads} daemon threads.
 *
 * @param numThreads number of threads to start
 * @param namePrefix thread name prefix; the thread index is appended after a dash
 * @param runnableSupplier called once per thread to obtain its runnable
 */
private void startThreads(int numThreads, String namePrefix, Supplier<Runnable> runnableSupplier) {
  int index = 0;
  while (index < numThreads) {
    final String threadName = namePrefix + "-" + index;
    print("Start " + threadName);
    final Runnable task = runnableSupplier.get();
    final Thread worker = new Thread(task, threadName);
    worker.setDaemon(true);
    // Registered before starting so tearDown() can join it.
    threads.add(worker);
    worker.start();
    index++;
  }
}
/**
 * Tells whether the current deadline ({@code endTimeMs}) is reached.
 * As a side effect, prints the elapsed time when at least 10 seconds passed
 * since the last display.
 *
 * @return true when the current time is at or after the deadline
 */
private boolean isTimeElapsed() {
  final long now = System.currentTimeMillis();
  final boolean displayDue = now - lastDisplayTimeMs >= 10000;
  if (displayDue) {
    final long elapsedSeconds = (now - startTimeMs) / 1000;
    print("elapsed time = " + elapsedSeconds + " s");
    lastDisplayTimeMs = now;
  }
  return endTimeMs <= now;
}
/**
 * Selects the next key id from {@code KEY_IDS} in round-robin order and
 * remembers it in {@code keyId}.
 *
 * @return the selected key id
 */
private String nextKeyId() {
  final int index = nextKeyIndex;
  // Wrap around to the first key id after the last one.
  nextKeyIndex = (index + 1) % KEY_IDS.length;
  keyId = KEY_IDS[index];
  return keyId;
}
/**
 * Requests the encryption with the provided key id and optionally waits for
 * its completion by polling the handler every 500 ms.
 *
 * @param keyId key id sent to the encryption handler
 * @param waitForCompletion whether to poll until the state is no longer pending
 * @return true if the returned state is not pending (immediately or after
 *     polling); false if it is pending and either no wait was requested or
 *     the deadline checked by {@code isTimeElapsed()} was reached
 */
private boolean encrypt(String keyId, boolean waitForCompletion) throws Exception {
  NamedList<Object> reply = sendEncryptionRequest(keyId);
  if (!reply.get(ENCRYPTION_STATE).equals(STATE_PENDING)) {
    return true;
  }
  if (!waitForCompletion) {
    return false;
  }
  print("waiting for encryption completion for keyId=" + keyId);
  do {
    if (isTimeElapsed()) {
      return false;
    }
    Thread.sleep(500);
    reply = sendEncryptionRequest(keyId);
  } while (reply.get(ENCRYPTION_STATE).equals(STATE_PENDING));
  print("encryption complete for keyId=" + keyId);
  return true;
}
/**
 * Randomly decides whether to wait for an encryption to complete, with
 * probability {@code PROBABILITY_OF_WAITING_ENCRYPTION_COMPLETION}.
 */
private boolean waitForCompletion(Random random) {
  final float draw = random.nextFloat();
  return PROBABILITY_OF_WAITING_ENCRYPTION_COMPLETION >= draw;
}
/**
 * Sends a GET request to the "/admin/encrypt" path with the key id parameter,
 * and prints the status and encryption state found in the response.
 *
 * @param keyId value of the key id request parameter
 * @return the handler response
 */
private NamedList<Object> sendEncryptionRequest(String keyId) throws SolrServerException, IOException {
  final ModifiableSolrParams requestParams = new ModifiableSolrParams();
  requestParams.set(PARAM_KEY_ID, keyId);
  final GenericSolrRequest request =
      new GenericSolrRequest(SolrRequest.METHOD.GET, "/admin/encrypt", requestParams);
  final NamedList<Object> reply = solrClient.request(request);
  final Object status = reply.get(STATUS);
  final Object state = reply.get(ENCRYPTION_STATE);
  print("encrypt keyId=" + keyId + " => response status=" + status + " state=" + state);
  return reply;
}
/** Writes the message to the standard output, prefixed with {@code SYSTEM_OUTPUT_MARKER}. */
private static void print(String message) {
  final String line = SYSTEM_OUTPUT_MARKER + message;
  System.out.println(line);
}
/** Prints the message prefixed with the name of the calling thread. */
private static void threadPrint(String message) {
  final String threadName = Thread.currentThread().getName();
  print(threadName + ": " + message);
}
/** Fixed list of terms from which callers pick a term at random. */
private static class Dictionary {

  final List<String> terms;

  Dictionary(List<String> terms) {
    this.terms = terms;
  }

  /** Returns a term chosen uniformly with the provided random. */
  String getTerm(Random random) {
    final int index = random.nextInt(terms.size());
    return terms.get(index);
  }

  /** Builds a dictionary of distinct random unicode terms. */
  static class Builder {

    /**
     * @param size number of distinct terms to generate
     * @param random source of randomness for the terms
     * @return a dictionary of {@code size} distinct terms of 4 to 12 code points
     */
    Dictionary build(int size, Random random) {
      final Set<String> uniqueTerms = new HashSet<>();
      // Duplicates are rejected by the set, so keep generating until enough distinct terms exist.
      while (uniqueTerms.size() < size) {
        uniqueTerms.add(RandomStrings.randomUnicodeOfCodepointLengthBetween(random, 4, 12));
      }
      return new Dictionary(new ArrayList<>(uniqueTerms));
    }
  }
}
/**
 * Worker that sends batches of random documents, and randomly commits, until
 * {@code stopTest} is set. Raises {@code stopTest} itself when it terminates,
 * and records any non-interruption exception in {@code exception}.
 */
private class Indexer implements Runnable {

  /** Seed of the thread-local random, drawn from the test random at construction. */
  final long seed;
  // NOTE(review): the counter is per Indexer instance, so different indexing threads
  // produce the same "id" values; presumably they overwrite each other's docs - confirm intended.
  final AtomicLong docNum = new AtomicLong();

  Indexer() {
    seed = random().nextLong();
  }

  @Override
  public void run() {
    long batchCount = 0;
    long docCount = 0;
    long commitCount = 0;
    try {
      final Random rnd = new Random(seed);
      while (!stopTest) {
        Thread.sleep(rnd.nextInt(RANDOM_DELAY_BETWEEN_INDEXING_BATCHES_MS));
        // At least one doc per batch.
        final int batchSize = rnd.nextInt(RANDOM_NUM_DOCS_PER_BATCH) + 1;
        final Collection<SolrInputDocument> batch = new ArrayList<>();
        for (int n = 0; n < batchSize; n++) {
          batch.add(createDoc(rnd));
        }
        docCount += batch.size();
        solrClient.add(batch);
        final boolean commitNow = rnd.nextFloat() <= PROBABILITY_OF_COMMIT_PER_BATCH;
        if (commitNow) {
          commitCount++;
          solrClient.commit();
        }
        batchCount++;
        if (batchCount % 10 == 0) {
          threadPrint(progress(batchCount, docCount, commitCount));
        }
      }
    } catch (InterruptedException e) {
      threadPrint("Indexing interrupted");
      e.printStackTrace(System.err);
    } catch (Exception e) {
      exception = e;
      threadPrint("Indexing stopped by exception");
      e.printStackTrace(System.err);
    } finally {
      threadPrint("Stop indexing");
      threadPrint(progress(batchCount, docCount, commitCount));
      // Make every other worker stop too.
      stopTest = true;
    }
  }

  /** Formats the progress counters for display. */
  private String progress(long numBatches, long totalDocs, long numCommits) {
    return "sent " + numBatches + " indexing batches, totalDocs=" + totalDocs + ", numCommits=" + numCommits;
  }

  /** Creates a doc with the next sequential id and one random dictionary term as text. */
  SolrInputDocument createDoc(Random random) {
    final String id = String.valueOf(docNum.getAndIncrement());
    final SolrInputDocument document = new SolrInputDocument();
    document.addField("id", id);
    document.addField("text", dictionary.getTerm(random));
    return document;
  }
}
/**
 * Worker that sends single-term queries with random dictionary terms until
 * {@code stopTest} is set. Raises {@code stopTest} itself when it terminates,
 * and records any non-interruption exception in {@code exception}.
 * <p>
 * A failing query is retried with another term. The retry loop also stops as
 * soon as {@code stopTest} is set, so the thread cannot spin forever when
 * queries keep failing (for example while the test is being torn down).
 */
private class Querier implements Runnable {

  /** Seed of the thread-local random, drawn from the test random at construction. */
  final long seed;

  Querier() {
    seed = random().nextLong();
  }

  @Override
  public void run() {
    long totalResults = 0;
    long numQueries = 0;
    long numConsecutiveNoResults = 0;
    try {
      Random random = new Random(seed);
      while (!stopTest) {
        Thread.sleep(random.nextInt(RANDOM_DELAY_BETWEEN_QUERIES_MS));
        QueryResponse response = null;
        do {
          try {
            response = solrClient.query(new SolrQuery(dictionary.getTerm(random)));
          } catch (Exception ignored) {
            // Some queries might not be parseable due to the random terms. Just retry with another term.
          }
        } while (response == null && !stopTest);
        if (response == null) {
          // The test was stopped while retrying; there is no result to account for.
          break;
        }
        int numResults = response.getResults().size();
        totalResults += numResults;
        numQueries++;
        if (numResults == 0) {
          numConsecutiveNoResults++;
        } else {
          numConsecutiveNoResults = 0;
        }
        if (numQueries % 500 == 0) {
          threadPrint("sent " + numQueries + " queries, totalResults=" + totalResults + ", numConsecutiveNoResults=" + numConsecutiveNoResults);
        }
      }
    } catch (InterruptedException e) {
      threadPrint("Querying interrupted");
      e.printStackTrace(System.err);
    } catch (Exception e) {
      exception = e;
      threadPrint("Querying stopped by exception");
      e.printStackTrace(System.err);
    } finally {
      threadPrint("Stop querying");
      threadPrint("sent " + numQueries + " queries, totalResults=" + totalResults + ", numConsecutiveNoResults=" + numConsecutiveNoResults);
      // Make every other worker stop too.
      stopTest = true;
    }
  }
}
}