blob: 147279115cb0cb576e844303f823f18252cc7147 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License") you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.nifi.security.kms
import org.apache.commons.lang3.SystemUtils
import org.bouncycastle.jce.provider.BouncyCastleProvider
import org.bouncycastle.util.encoders.Hex
import org.junit.After
import org.junit.AfterClass
import org.junit.Assume
import org.junit.Before
import org.junit.BeforeClass
import org.junit.ClassRule
import org.junit.Test
import org.junit.rules.TemporaryFolder
import org.junit.runner.RunWith
import org.junit.runners.JUnit4
import org.slf4j.Logger
import org.slf4j.LoggerFactory
import javax.crypto.Cipher
import java.nio.charset.StandardCharsets
import java.nio.file.Files
import java.nio.file.attribute.PosixFilePermission
import java.security.Security
@RunWith(JUnit4.class)
class CryptoUtilsTest {
    private static final Logger logger = LoggerFactory.getLogger(CryptoUtilsTest.class)

    private static final String KEY_ID = "K1"
    private static final String KEY_HEX_128 = "0123456789ABCDEFFEDCBA9876543210"
    // Groovy String multiplication repeats the 32 hex chars (128 bits) to give 64 hex chars (256 bits)
    private static final String KEY_HEX_256 = KEY_HEX_128 * 2
    // Use the longest key the installed JCE policy permits for AES
    private static final String KEY_HEX = isUnlimitedStrengthCryptoAvailable() ? KEY_HEX_256 : KEY_HEX_128

    private static final Set<PosixFilePermission> ALL_POSIX_ATTRS = PosixFilePermission.values() as Set<PosixFilePermission>

    // Inputs shared by the constant-time equality tests
    private static final String CTE_PLAINTEXT = "This is a short string."
    private static final Map<String, String> CTE_SCENARIOS = [
            "identical": CTE_PLAINTEXT,
            "first off": "this is a short string.",
            "last off" : "This is a short string,"
    ]
    private static final int CTE_ITERATIONS = 10_000
    private static final int CTE_WARM_UP_ITERATIONS = 1_000 * CTE_ITERATIONS

    @ClassRule
    public static TemporaryFolder tempFolder = new TemporaryFolder()

    @BeforeClass
    static void setUpOnce() throws Exception {
        Security.addProvider(new BouncyCastleProvider())

        // Route calls to logger methods that do not exist to info(), prefixed with the upper-cased method name
        logger.metaClass.methodMissing = { String name, args ->
            logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
        }
    }

    @Before
    void setUp() throws Exception {
        // Give every test a fresh folder so tests creating the same file name (e.g. "filebased.kp") do not collide
        tempFolder.create()
    }

    @After
    void tearDown() throws Exception {
        tempFolder?.delete()
    }

    @AfterClass
    static void tearDownOnce() throws Exception {
    }

    /**
     * Returns {@code true} if the JCE policy allows AES keys longer than 128 bits.
     */
    private static boolean isUnlimitedStrengthCryptoAvailable() {
        Cipher.getMaxAllowedKeyLength("AES") > 128
    }

    /**
     * Returns {@code true} when {@code id -u} reports uid 0 for the current process.
     * Only meaningful on POSIX systems; the caller skips Windows before invoking it.
     */
    private static boolean isRootUser() {
        Process process = new ProcessBuilder(["id", "-u"]).start()
        try {
            // getText(charset) reads stdout to the end and closes the stream
            String uid = process.inputStream.getText(StandardCharsets.UTF_8.name()).trim()
            process.waitFor()
            return uid == "0"
        } finally {
            process.destroy()
        }
    }

    @Test
    void testShouldConcatenateByteArrays() {
        // Arrange
        byte[] bytes1 = "These are some bytes".getBytes(StandardCharsets.UTF_8)
        byte[] bytes2 = "These are some other bytes".getBytes(StandardCharsets.UTF_8)

        final byte[] EXPECTED_CONCATENATED_BYTES = ((bytes1 as List) << (bytes2 as List)).flatten() as byte[]
        logger.info("Expected concatenated bytes: ${Hex.toHexString(EXPECTED_CONCATENATED_BYTES)}")

        // Act
        byte[] concat = CryptoUtils.concatByteArrays(bytes1, bytes2)
        logger.info(" Actual concatenated bytes: ${Hex.toHexString(concat)}")

        // Assert
        assert concat == EXPECTED_CONCATENATED_BYTES
    }

    @Test
    void testShouldValidateStaticKeyProvider() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldValidateLegacyStaticKeyProvider() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name.replaceFirst("security.kms", "provenance")
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderMissingKeyId() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, null, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${null} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderMissingKey() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderWithInvalidKey() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        // KEY_HEX[0..<-2] drops the last hex character, producing an invalid key length
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX[0..<-2]])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX[0..<-2]} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldValidateFileBasedKeyProvider() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldValidateLegacyFileBasedKeyProvider() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name.replaceFirst("security.kms", "provenance")
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldNotValidateMissingFileBasedKeyProvider() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = new File(tempFolder.root, "filebased_missing.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created (no actual file) temporary file based key provider: ${providerLocation}")

        // Act
        String missingLocation = providerLocation
        boolean missingKeyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, missingLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${missingLocation} and keyId ${KEY_ID} / ${null} is ${missingKeyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !missingKeyProviderIsValid
    }

    @Test
    void testShouldNotValidateUnreadableFileBasedKeyProvider() {
        // Arrange
        Assume.assumeFalse("This test does not run on Windows", SystemUtils.IS_OS_WINDOWS)
        // root can read files regardless of permissions, so the scenario cannot be reproduced
        Assume.assumeFalse("This test does not run for root users", isRootUser())

        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Make it unreadable
        markFileUnreadable(fileBasedProviderFile)

        try {
            // Act
            boolean unreadableKeyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
            logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${unreadableKeyProviderIsValid ? "valid" : "invalid"}")

            // Assert
            assert !unreadableKeyProviderIsValid
        } finally {
            // Restore permissions even if the assertion fails so TemporaryFolder cleanup can delete the file
            markFileReadable(fileBasedProviderFile)
        }
    }

    /**
     * Grants all permissions on the file (POSIX) or makes it readable by everyone (Windows).
     */
    private static void markFileReadable(File fileBasedProviderFile) {
        if (SystemUtils.IS_OS_WINDOWS) {
            fileBasedProviderFile.setReadable(true, false)
        } else {
            Files.setPosixFilePermissions(fileBasedProviderFile.toPath(), ALL_POSIX_ATTRS)
        }
    }

    /**
     * Removes all permissions from the file (POSIX) or revokes read access for everyone (Windows).
     */
    private static void markFileUnreadable(File fileBasedProviderFile) {
        if (SystemUtils.IS_OS_WINDOWS) {
            fileBasedProviderFile.setReadable(false, false)
        } else {
            Files.setPosixFilePermissions(fileBasedProviderFile.toPath(), [] as Set<PosixFilePermission>)
        }
    }

    @Test
    void testShouldNotValidateFileBasedKeyProviderMissingKeyId() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("missing_key_id.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, null, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${null} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateUnknownKeyProvider() {
        // Arrange
        String providerImplementation = "org.apache.nifi.provenance.ImaginaryKeyProvider"
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(providerImplementation, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${providerImplementation} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldValidateKey() {
        // Arrange
        String validKey = KEY_HEX
        String validLowercaseKey = KEY_HEX.toLowerCase()

        String tooShortKey = KEY_HEX[0..<-2]
        String tooLongKey = KEY_HEX + KEY_HEX // Guaranteed to be 2x the max valid key length
        String nonHexKey = KEY_HEX.replaceFirst(/A/, "X")

        def validKeys = [validKey, validLowercaseKey]
        def invalidKeys = [tooShortKey, tooLongKey, nonHexKey]

        // If unlimited strength is available, also validate 128 and 192 bit keys (48 hex chars = 192 bits)
        if (isUnlimitedStrengthCryptoAvailable()) {
            validKeys << KEY_HEX_128
            validKeys << KEY_HEX_256[0..<48]
        } else {
            invalidKeys << KEY_HEX_256[0..<48]
            invalidKeys << KEY_HEX_256
        }

        // Act
        def validResults = validKeys.collect { String key ->
            logger.info("Validating ${key}")
            CryptoUtils.keyIsValid(key)
        }

        def invalidResults = invalidKeys.collect { String key ->
            logger.info("Validating ${key}")
            CryptoUtils.keyIsValid(key)
        }

        // Assert
        assert validResults.every()
        assert invalidResults.every { !it }
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForStrings() {
        // Arrange / Act
        Map<String, Boolean> results = timeConstantTimeEquals(CTE_PLAINTEXT, CTE_SCENARIOS)

        // Assert
        assertConstantTimeEqualsResults(results)
        // TODO: Assert timings are within std dev?
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForBytes() {
        // Arrange
        byte[] plaintextBytes = CTE_PLAINTEXT.getBytes(StandardCharsets.UTF_8)
        Map<String, Object> scenarios = CTE_SCENARIOS.collectEntries { String scenario, String value ->
            [(scenario): value.getBytes(StandardCharsets.UTF_8)]
        }

        // Act
        Map<String, Boolean> results = timeConstantTimeEquals(plaintextBytes, scenarios)

        // Assert
        assertConstantTimeEqualsResults(results)
        // TODO: Assert timings are within std dev?
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForChars() {
        // Arrange
        char[] plaintextChars = CTE_PLAINTEXT.toCharArray()
        Map<String, Object> scenarios = CTE_SCENARIOS.collectEntries { String scenario, String value ->
            [(scenario): value.toCharArray()]
        }

        // Act
        Map<String, Boolean> results = timeConstantTimeEquals(plaintextChars, scenarios)

        // Assert
        assertConstantTimeEqualsResults(results)
        // TODO: Assert timings are within std dev?
    }

    /**
     * Warms up the JVM, then times {@code CryptoUtils.constantTimeEquals} against Groovy's
     * short-circuiting {@code ==} for every scenario, logging total and average nanoseconds.
     *
     * Groovy dispatches {@code CryptoUtils.constantTimeEquals} on the runtime argument types, so the
     * same helper serves String, byte[] and char[] inputs.
     *
     * @param reference the value every scenario value is compared against
     * @param scenarios scenario label -> value to compare with {@code reference} (iteration order preserved)
     * @return scenario label -> result of {@code CryptoUtils.constantTimeEquals(reference, value)}
     */
    private static Map<String, Boolean> timeConstantTimeEquals(Object reference, Map<String, Object> scenarios) {
        long nanos = 0
        long scNanos = 0

        // Prepare the JVM with the same value type that is measured below
        CTE_WARM_UP_ITERATIONS.times {
            long scIterationNanos = time("warm up sc") {
                assert reference == reference
            }
            scNanos += scIterationNanos

            long iterationNanos = time("warm up") {
                assert CryptoUtils.constantTimeEquals(reference, reference)
            }
            nanos += iterationNanos
        }
        logger.info("${"warm up sc".padLeft(10)}: ${scNanos} ns (avg: ${scNanos / (CTE_WARM_UP_ITERATIONS)} ns)")
        logger.info("${"warm up".padLeft(10)}: ${nanos} ns (avg: ${nanos / (CTE_WARM_UP_ITERATIONS)} ns)")

        final int scenarioWidth = 16
        Map<String, Boolean> results = [:]
        scenarios.each { String scenario, Object value ->
            boolean isEqual = true
            scNanos = 0
            nanos = 0

            CTE_ITERATIONS.times {
                long scIterationNanos = time(scenario + " sc") {
                    (reference == value)
                }
                scNanos += scIterationNanos

                long iterationNanos = time(scenario) {
                    isEqual = CryptoUtils.constantTimeEquals(reference, value)
                }
                nanos += iterationNanos
            }

            logger.info("${(scenario + " sc").padLeft(scenarioWidth)}: ${scNanos} ns (avg: ${scNanos / CTE_ITERATIONS} ns)")
            logger.info("${scenario.padLeft(scenarioWidth)}: ${nanos} ns (avg: ${nanos / CTE_ITERATIONS} ns)")

            results[scenario] = isEqual
        }
        results
    }

    /**
     * Asserts the expected outcome for each entry in {@code CTE_SCENARIOS}: only "identical" compares equal.
     */
    private static void assertConstantTimeEqualsResults(Map<String, Boolean> results) {
        assert results["identical"]
        assert !results["first off"]
        assert !results["last off"]
    }

    /**
     * Runs the closure once and returns the elapsed wall-clock time in nanoseconds.
     *
     * @param name descriptive label for the timed operation (currently not logged)
     * @param closure the operation to time
     * @return elapsed nanoseconds measured with {@link System#nanoTime()}
     */
    private static long time(String name = "closure", Closure closure) {
        long start = System.nanoTime()
        closure.run()
        long end = System.nanoTime()
        end - start
    }
}