/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.crypto.key.kms.server;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.crypto.key.KeyProvider;
import org.apache.hadoop.crypto.key.KeyProviderCryptoExtension;
import org.apache.hadoop.fs.CommonConfigurationKeysPublic;
import org.apache.hadoop.util.ExitUtil;
import org.apache.hadoop.util.GenericOptionsParser;
import org.apache.hadoop.util.KMSUtil;
import org.apache.hadoop.util.StringUtils;
import org.apache.hadoop.util.Time;
import org.apache.hadoop.util.Tool;
import org.apache.hadoop.util.ToolRunner;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.IOException;
import java.security.GeneralSecurityException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
/**
* Main class for a series of KMS benchmarks.
*
* Each benchmark measures throughput and average execution time
* of a specific kms operation, e.g. encrypt or decrypt of
* Data Encryption Keys.
*
* The benchmark does not involve any Hadoop components other than
* the KMS. Each operation is executed by directly calling the
* respective KMS operation.
*
* For usage, please see <a href="http://hadoop.apache.org/docs/current/
* hadoop-project-dist/hadoop-common/Benchmarking.html#KMSBenchmark">
* the documentation</a>.
* Meanwhile, if you change the usage of this program, please also update the
* documentation accordingly.
*/
public class KMSBenchmark implements Tool {
// Logger shared by the benchmark and all of its nested classes.
private static final Logger LOG =
LoggerFactory.getLogger(KMSBenchmark.class);
// Usage text for the options common to every operation; printed by
// printUsage().
private static final String GENERAL_OPTIONS_USAGE = "[-threads int] |" +
" [-numops int] | [{-warmup (true|false)}]";
// Configuration handed to the constructor or setConf(). Note that this
// is static, so it is shared by every KMSBenchmark instance.
private static Configuration config;
// Key provider against which all benchmark operations are executed.
private KeyProviderCryptoExtension kp;
// Most recently generated encrypted key; written by the constructor and
// the encrypt operation, read by the decrypt operation.
private KeyProviderCryptoExtension.EncryptedKeyVersion eek = null;
// Name of the encryption key used by the operations; "-createkey"
// replaces the default.
private String encryptionKeyName = "systest";
// Whether the constructor should create the encryption key when the
// provider does not already list it.
private boolean createEncryptionKey = false;
// Whether the constructor should warm up the encrypted keys ("-warmup").
private boolean warmupKey = false;
// Key names reported by the provider; refreshed by the constructor when
// it checks whether the encryption key already exists.
private List<String> keys = new ArrayList<String>();
/**
 * Creates the benchmark: connects to the configured key provider, applies
 * the {@code -warmup} and {@code -createkey} options, and generates the
 * initial encrypted key that the decrypt operation works on.
 *
 * Failures to create, warm up or generate keys are logged and do not
 * abort construction.
 *
 * @param conf configuration used to locate the key provider
 * @param args raw command line; the first two entries are expected to be
 *             {@code -op <name>} and are skipped here
 * @throws IOException if the key provider is not configured or a provider
 *                     call fails with an I/O error
 */
KMSBenchmark(Configuration conf, String[] args)
    throws IOException {
  config = conf;
  kp = createKeyProviderCryptoExtension(config);
  // Parse the options first so that every later step uses the key name
  // the caller actually asked for.
  for (int i = 2; i < args.length; i++) {
    if (args[i].equals("-warmup")) {
      if (i + 1 == args.length) {
        // option given without a value
        printUsage();
      }
      warmupKey = Boolean.parseBoolean(args[++i]);
    } else if (args[i].equals("-createkey")) {
      if (i + 1 == args.length) {
        // option given without a value
        printUsage();
      }
      encryptionKeyName = args[++i];
      // Without this flag the creation branch below is never taken.
      createEncryptionKey = true;
    }
  }
  // create key and/or warm up
  try {
    if (createEncryptionKey) {
      keys = kp.getKeys();
      if (!keys.contains(encryptionKeyName)) {
        kp.createKey(encryptionKeyName, KeyProvider.options(conf));
      } else {
        LOG.warn("encryption key already exists: {}",
            encryptionKeyName);
      }
    }
    if (warmupKey) {
      kp.warmUpEncryptedKeys(encryptionKeyName);
    }
  } catch (GeneralSecurityException e) {
    LOG.warn(" failed to create or warmup encryption key", e);
  }
  // Generate the initial encrypted key only now, i.e. for the requested
  // key name and after that key has had a chance to be created.
  try {
    eek = kp.generateEncryptedKey(encryptionKeyName);
  } catch (GeneralSecurityException e) {
    LOG.warn("failed to generate key", e);
  }
}
/**
 * Base class for collecting operation statistics.
 *
 * Extend this class in order to run statistics for a
 * specific kms operation.
 */
abstract class OperationStatsBase {
  protected static final String OP_ALL_NAME = "all";
  protected static final String OP_ALL_USAGE =
      "-op all <other ops options>";
  // number of threads
  private int numThreads = 0;
  // number of operations requested
  private int numOpsRequired = 0;
  // number of operations executed
  private int numOpsExecuted = 0;
  // sum of times for each op
  private long cumulativeTime = 0;
  // time from start to finish
  private long elapsedTime = 0;
  // worker threads of the current benchmark() run
  private List<StatsDaemon> daemons;

  /**
   * Operation name.
   */
  abstract String getOpName();

  /**
   * Parse command line arguments.
   *
   * @param args arguments
   * @throws IOException if the arguments cannot be processed
   */
  abstract void parseArguments(List<String> args) throws IOException;

  /**
   * This corresponds to the arg1 argument of
   * {@link #executeOp(int, int, String)}, which can have
   * different meanings depending on the operation performed.
   *
   * @param daemonId id of the daemon calling this method
   * @return the argument
   */
  abstract String getExecutionArgument(int daemonId);

  /**
   * Execute kms operation.
   *
   * @param daemonId id of the daemon calling this method.
   * @param inputIdx serial index of the operation called by the daemon.
   * @param arg1 operation specific argument.
   * @return time of the individual kms call.
   * @throws IOException if the operation fails
   */
  abstract long executeOp(int daemonId, int inputIdx, String arg1)
      throws IOException;

  /**
   * Print the results of the benchmarking.
   */
  abstract void printResults();

  /**
   * Sets the defaults: 10000 operations spread over 3 threads.
   */
  OperationStatsBase() {
    numOpsRequired = 10000;
    numThreads = 3;
  }

  /**
   * Runs the operation {@link #getNumOpsRequired()} times, spread over
   * {@link #getNumThreads()} daemon threads, waits for all of them to
   * finish and accumulates their counters into this object.
   *
   * Does nothing (apart from resetting the counters) when fewer than one
   * thread is configured.
   *
   * @throws IOException declared for subclasses and callers; not thrown by
   *                     this implementation directly
   */
  void benchmark() throws IOException {
    daemons = new ArrayList<StatsDaemon>();
    // Initialised up front so that an early exit yields an elapsed time
    // near zero instead of "milliseconds since the epoch".
    long start = Time.monotonicNow();
    try {
      numOpsExecuted = 0;
      cumulativeTime = 0;
      if (numThreads < 1) {
        return;
      }
      // thread index < nrThreads
      int tIdx = 0;
      int[] opsPerThread = new int[numThreads];
      // Give each thread an even share of what is still unscheduled; the
      // last thread picks up the remainder.
      for (int opsScheduled = 0; opsScheduled < numOpsRequired;
           opsScheduled += opsPerThread[tIdx++]) {
        // execute in a separate thread
        opsPerThread[tIdx] =
            (numOpsRequired - opsScheduled) / (numThreads - tIdx);
        if (opsPerThread[tIdx] == 0) {
          opsPerThread[tIdx] = 1;
        }
      }
      // if numThreads > numOpsRequired then the remaining threads
      // will do nothing
      for (; tIdx < numThreads; tIdx++) {
        opsPerThread[tIdx] = 0;
      }
      for (tIdx = 0; tIdx < numThreads; tIdx++) {
        daemons.add(new StatsDaemon(tIdx, opsPerThread[tIdx], this));
      }
      // Monotonic clock: the measurement must not be affected by
      // adjustments of the system time.
      start = Time.monotonicNow();
      LOG.info("Starting {} {}(s).", numOpsRequired, getOpName());
      for (StatsDaemon d : daemons) {
        d.start();
      }
    } finally {
      waitForDaemons();
      elapsedTime = Time.monotonicNow() - start;
      for (StatsDaemon d : daemons) {
        incrementStats(d.localNumOpsExecuted, d.localCumulativeTime);
        System.out.println(d.toString() + ": ops Exec = " +
            d.localNumOpsExecuted);
      }
    }
  }

  /**
   * Blocks until every daemon thread has ended.
   *
   * Joining the threads (instead of polling their progress counters)
   * returns as soon as the last thread ends, and also returns when a
   * daemon ended early without completing all of its operations. An
   * interrupt received while waiting does not cut the wait short; it is
   * re-asserted on the current thread once all daemons are done.
   */
  private void waitForDaemons() {
    boolean interrupted = false;
    for (StatsDaemon d : daemons) {
      // isAlive() is false for daemons that were never started.
      while (d.isAlive()) {
        try {
          d.join();
        } catch (InterruptedException e) {
          interrupted = true;
        }
      }
    }
    if (interrupted) {
      Thread.currentThread().interrupt();
    }
  }

  /**
   * Hook invoked after {@link #benchmark()}; does nothing by default.
   *
   * @throws IOException if a subclass fails to clean up
   */
  void cleanUp() throws IOException {
  }

  int getNumOpsExecuted() {
    return numOpsExecuted;
  }

  long getCumulativeTime() {
    return cumulativeTime;
  }

  long getElapsedTime() {
    return elapsedTime;
  }

  /**
   * @return cumulative time divided by the number of executed operations
   *         (integer division), or 0 if nothing was executed
   */
  long getAverageTime() {
    LOG.info("getAverageTime, cumulativeTime = {}", cumulativeTime);
    LOG.info("getAverageTime, numOpsExecuted = {}", numOpsExecuted);
    return numOpsExecuted == 0 ? 0 : cumulativeTime / numOpsExecuted;
  }

  /**
   * @return executed operations per second of elapsed time, or 0 if the
   *         elapsed time is 0
   */
  double getOpsPerSecond() {
    return elapsedTime == 0 ?
        0 : 1000 * (double) numOpsExecuted / elapsedTime;
  }

  String getClientName(int idx) {
    return getOpName() + "-client-" + idx;
  }

  /**
   * Adds one daemon's counters to the totals of this operation.
   *
   * @param ops number of operations the daemon executed
   * @param time sum of the execution times the daemon recorded
   */
  void incrementStats(int ops, long time) {
    numOpsExecuted += ops;
    cumulativeTime += time;
  }

  int getNumThreads() {
    return numThreads;
  }

  void setNumThreads(int num) {
    numThreads = num;
  }

  int getNumOpsRequired() {
    return numOpsRequired;
  }

  void setNumOpsRequired(int num) {
    numOpsRequired = num;
  }

  /**
   * Parse first 2 arguments, corresponding to the "-op" option.
   *
   * @param args argument list
   * @return true if operation is all, which means that options not
   *         related to this operation should be ignored, or false
   *         otherwise, meaning that usage should be printed when an
   *         unrelated option is encountered.
   */
  protected boolean verifyOpArgument(List<String> args) {
    if (args.size() < 2 || !args.get(0).startsWith("-op")) {
      printUsage();
    }
    // process common options
    String type = args.get(1);
    if (OP_ALL_NAME.equals(type)) {
      return true;
    }
    if (!getOpName().equals(type)) {
      printUsage();
    }
    return false;
  }

  /**
   * Logs the counters gathered by the last {@link #benchmark()} run.
   */
  void printStats() {
    LOG.info("--- {} stats ---", getOpName());
    LOG.info("# operations: {}", getNumOpsExecuted());
    LOG.info("Elapsed Time: {}", getElapsedTime());
    LOG.info(" Ops per sec: {}", getOpsPerSecond());
    LOG.info("Average Time: {}", getAverageTime());
  }
}
/**
 * One of the threads that perform stats operations.
 *
 * Each daemon executes its share of operations sequentially and keeps its
 * own counters, which the coordinating thread reads once the daemon is no
 * longer in progress.
 */
private class StatsDaemon extends Thread {
  private final int daemonId;
  // volatile: lowered by terminate() on the daemon thread and read by the
  // coordinating thread through isInProgress().
  private volatile int opsPerThread;
  private String arg1; // argument passed to executeOp()
  // Only the daemon thread writes these two counters; volatile makes the
  // latest values visible to the coordinating thread.
  private volatile int localNumOpsExecuted = 0;
  private volatile long localCumulativeTime = 0;
  private final OperationStatsBase statsOp;

  /**
   * @param daemonId id of this daemon, also used in the thread name
   * @param nOps number of operations this daemon has to execute
   * @param op operation to execute
   */
  StatsDaemon(int daemonId, int nOps, OperationStatsBase op) {
    this.daemonId = daemonId;
    this.opsPerThread = nOps;
    this.statsOp = op;
    setName(toString());
  }

  /**
   * Executes the assigned operations. A failure is logged and ends the
   * daemon early; in every case the daemon is marked as finished so that
   * {@link #isInProgress()} cannot stay true after the thread has ended.
   */
  @Override
  public void run() {
    localNumOpsExecuted = 0;
    localCumulativeTime = 0;
    try {
      arg1 = statsOp.getExecutionArgument(daemonId);
      benchmarkOne();
    } catch (IOException ex) {
      LOG.error("StatsDaemon " + daemonId + " failed: \n"
          + StringUtils.stringifyException(ex));
    } finally {
      // After a failure fewer operations than assigned were executed;
      // without this a caller polling isInProgress() would wait forever.
      terminate();
    }
  }

  @Override
  public String toString() {
    return "StatsDaemon-" + daemonId;
  }

  /**
   * Executes the operations one after another, counting each completed
   * call and summing up the time it reported.
   *
   * @throws IOException if an operation fails; the remaining operations
   *                     are not executed
   */
  void benchmarkOne() throws IOException {
    for (int idx = 0; idx < opsPerThread; idx++) {
      long stat = statsOp.executeOp(daemonId, idx, arg1);
      localNumOpsExecuted++;
      localCumulativeTime += stat;
    }
  }

  /**
   * @return true while this daemon still has operations left to execute
   */
  boolean isInProgress() {
    return localNumOpsExecuted < opsPerThread;
  }

  /**
   * Schedule to stop this daemon.
   */
  void terminate() {
    opsPerThread = localNumOpsExecuted;
  }
}
/**
 * Encrypt key statistics.
 *
 * Each thread repeatedly asks the key provider to generate an encrypted
 * key for the benchmark's encryption key.
 */
class EncryptKeyStats extends OperationStatsBase {
  // Operation types
  static final String OP_ENCRYPT_KEY = "encrypt";
  static final String OP_ENCRYPT_USAGE =
      "-op encrypt [-threads T -numops N -warmup F]";

  /**
   * @param args command line, starting with {@code -op <name>}
   */
  EncryptKeyStats(List<String> args) {
    super();
    parseArguments(args);
  }

  @Override
  String getOpName() {
    return OP_ENCRYPT_KEY;
  }

  /**
   * Reads {@code -threads} and {@code -numops}; other options are ignored
   * here ({@code -warmup} is handled by the enclosing benchmark's
   * constructor). Usage is printed when an option has no value.
   *
   * @param args command line, starting with {@code -op <name>}
   */
  @Override
  void parseArguments(List<String> args) {
    verifyOpArgument(args);
    // parse command line
    for (int i = 2; i < args.size(); i++) {
      if (args.get(i).equals("-threads")) {
        if (i + 1 == args.size()) {
          printUsage();
        }
        setNumThreads(Integer.parseInt(args.get(++i)));
      } else if (args.get(i).equals("-numops")) {
        // same missing-value check as for -threads
        if (i + 1 == args.size()) {
          printUsage();
        }
        setNumOpsRequired(Integer.parseInt(args.get(++i)));
      }
    }
  }

  /**
   * Returns client name.
   */
  @Override
  String getExecutionArgument(int daemonId) {
    return getClientName(daemonId);
  }

  /**
   * Execute key encryption.
   *
   * A {@link GeneralSecurityException} is logged and the call is still
   * timed and counted.
   *
   * @return duration of the call in milliseconds
   */
  @Override
  long executeOp(int daemonId, int inputIdx, String clientName)
      throws IOException {
    // Monotonic clock: a duration must not be affected by adjustments of
    // the system time.
    long start = Time.monotonicNow();
    try {
      // Also refreshes the shared key that the decrypt operation uses.
      eek = kp.generateEncryptedKey(encryptionKeyName);
    } catch (GeneralSecurityException e) {
      LOG.warn("failed to generate encrypted key", e);
    }
    long end = Time.monotonicNow();
    return end - start;
  }

  @Override
  void printResults() {
    LOG.info("--- {} inputs ---", getOpName());
    LOG.info("nOps = {}", getNumOpsRequired());
    LOG.info("nThreads = {}", getNumThreads());
    printStats();
  }
}
/**
* Decrypt key statistics.
*
* Each thread decrypts the key.
*/
class DecryptKeyStats extends OperationStatsBase {
// Operation types
static final String OP_DECRYPT_KEY = "decrypt";
static final String OP_DECRYPT_USAGE =
"-op decrypt [-threads T -numops N -warmup F]";
DecryptKeyStats(List<String> args) {
super();
parseArguments(args);
}
@Override
String getOpName() {
return OP_DECRYPT_KEY;
}
@Override
void parseArguments(List<String> args) {
verifyOpArgument(args);
// parse command line
for (int i = 2; i < args.size(); i++) {
if (args.get(i).equals("-threads")) {
if (i+1 == args.size()) {
printUsage();
}
setNumThreads(Integer.parseInt(args.get(++i)));
} else if (args.get(i).equals("-numops")) {
setNumOpsRequired(Integer.parseInt(args.get(++i)));
}
}
}
/**
* returns client name.
*/
@Override
String getExecutionArgument(int daemonId) {
return getClientName(daemonId);
}
/**
* Execute key decryption.
*/
@Override
long executeOp(int daemonId, int inputIdx, String clientName)
throws IOException {
long start = Time.now();
try {
kp.decryptEncryptedKey(eek);
} catch (GeneralSecurityException e) {
LOG.warn("failed to generate and/or decrypt key", e);
}
long end = Time.now();
return end - start;
}
@Override
void printResults() {
LOG.info("--- " + getOpName() + " inputs ---");
LOG.info("nrOps = " + getNumOpsRequired());
LOG.info("nrThreads = " + getNumThreads());
printStats();
}
}
/**
 * Writes the usage of every operation, followed by the generic Hadoop
 * command options, to standard error and then terminates with status -1.
 */
static void printUsage() {
  final String alternative = " | \n\t";
  final StringBuilder usage = new StringBuilder("Usage: KMSBenchmark");
  usage.append("\n\t").append(OperationStatsBase.OP_ALL_USAGE);
  usage.append(alternative).append(EncryptKeyStats.OP_ENCRYPT_USAGE);
  usage.append(alternative).append(DecryptKeyStats.OP_DECRYPT_USAGE);
  usage.append(alternative).append(GENERAL_OPTIONS_USAGE);
  System.err.println(usage.toString());
  System.err.println();
  GenericOptionsParser.printGenericCommandUsage(System.err);
  ExitUtil.terminate(-1);
}
/**
 * Creates the key provider named by the
 * {@link CommonConfigurationKeysPublic#HADOOP_SECURITY_KEY_PROVIDER_PATH}
 * setting and wraps it in a {@link KeyProviderCryptoExtension}.
 *
 * @param conf configuration to read the provider path from
 * @return the crypto extension wrapping the created provider
 * @throws IOException if no key provider could be created from the
 *                     configuration
 */
public static KeyProviderCryptoExtension createKeyProviderCryptoExtension(
    final Configuration conf) throws IOException {
  final String providerPathKey =
      CommonConfigurationKeysPublic.HADOOP_SECURITY_KEY_PROVIDER_PATH;
  final KeyProvider provider =
      KMSUtil.createKeyProvider(conf, providerPathKey);
  if (provider != null) {
    return KeyProviderCryptoExtension
        .createKeyProviderCryptoExtension(provider);
  }
  throw new IOException("Key provider was not configured.");
}
/**
 * Builds a benchmark for the given configuration and runs it through
 * {@link ToolRunner}. The key provider opened by the benchmark is closed
 * once the run has ended, whether it succeeded or not.
 *
 * @param conf configuration used to locate the key provider
 * @param args command line arguments
 * @throws Exception if the benchmark cannot be created or fails to run
 */
public static void runBenchmark(Configuration conf, String[] args)
    throws Exception {
  KMSBenchmark bench = null;
  try {
    bench = new KMSBenchmark(conf, args);
    ToolRunner.run(bench, args);
  } finally {
    if (bench != null && bench.kp != null) {
      try {
        bench.kp.close();
      } catch (IOException e) {
        // Only logged, so that a failure from the run itself is not
        // replaced by a failure to close.
        LOG.warn("failed to close key provider", e);
      }
    }
    LOG.info("runBenchmark finished.");
  }
}
/**
 * Runs the benchmark(s) selected by the {@code -op} option and prints
 * their results once all of them have completed.
 *
 * @param aArgs command line parameters
 * @return 0 when all selected benchmarks have run
 * @throws Exception any failure raised while running the benchmarks; it
 *                   is logged before being rethrown
 */
@Override // Tool
public int run(String[] aArgs) throws Exception {
  final List<String> argList = new ArrayList<String>(Arrays.asList(aArgs));
  final boolean startsWithOp =
      argList.size() >= 2 && argList.get(0).startsWith("-op");
  if (!startsWithOp) {
    printUsage();
  }
  final String requestedOp = argList.get(1);
  final boolean everyOp =
      OperationStatsBase.OP_ALL_NAME.equals(requestedOp);
  final List<OperationStatsBase> selected =
      new ArrayList<OperationStatsBase>();
  try {
    if (everyOp || EncryptKeyStats.OP_ENCRYPT_KEY.equals(requestedOp)) {
      selected.add(new EncryptKeyStats(argList));
    }
    if (everyOp || DecryptKeyStats.OP_DECRYPT_KEY.equals(requestedOp)) {
      selected.add(new DecryptKeyStats(argList));
    }
    if (selected.isEmpty()) {
      printUsage();
    }
    // first pass: execute every selected benchmark
    for (OperationStatsBase stats : selected) {
      LOG.info("Starting benchmark: {}", stats.getOpName());
      stats.benchmark();
      stats.cleanUp();
    }
    // second pass: report, with an empty log line before each report
    for (OperationStatsBase stats : selected) {
      LOG.info("");
      stats.printResults();
    }
    return 0;
  } catch (Exception e) {
    LOG.error("failed to run benchmarks", e);
    throw e;
  }
}
/**
 * Command line entry point: runs the benchmark with a freshly created
 * configuration.
 *
 * @param args command line parameters
 * @throws Exception if the benchmark fails
 */
public static void main(String[] args) throws Exception {
  final Configuration defaultConf = new Configuration();
  runBenchmark(defaultConf, args);
}
/**
 * Stores the configuration. The backing field is static, so the value is
 * shared by all benchmark instances.
 *
 * @param conf configuration to use
 */
@Override // Configurable
public void setConf(Configuration conf) {
  KMSBenchmark.config = conf;
}
/**
 * Returns the configuration most recently stored by the constructor or
 * {@link #setConf(Configuration)}.
 *
 * @return the shared configuration
 */
@Override // Configurable
public Configuration getConf() {
  return KMSBenchmark.config;
}
}