| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
 * (the "License"); you may not use this file except in compliance with
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.nifi.security.kms |
| |
| import org.apache.commons.lang3.SystemUtils |
| import org.bouncycastle.jce.provider.BouncyCastleProvider |
| import org.bouncycastle.util.encoders.Hex |
| import org.junit.After |
| import org.junit.AfterClass |
| import org.junit.Assume |
| import org.junit.Before |
| import org.junit.BeforeClass |
| import org.junit.ClassRule |
| import org.junit.Test |
| import org.junit.rules.TemporaryFolder |
| import org.junit.runner.RunWith |
| import org.junit.runners.JUnit4 |
| import org.slf4j.Logger |
| import org.slf4j.LoggerFactory |
| |
| import javax.crypto.Cipher |
| import java.nio.charset.StandardCharsets |
| import java.nio.file.Files |
| import java.nio.file.attribute.PosixFilePermission |
| import java.security.Security |
| |
@RunWith(JUnit4.class)
class CryptoUtilsTest {
    private static final Logger logger = LoggerFactory.getLogger(CryptoUtilsTest.class)

    private static final String KEY_ID = "K1"
    private static final String KEY_HEX_128 = "0123456789ABCDEFFEDCBA9876543210"
    // Groovy String multiplication repeats the 128-bit hex key to form a 256-bit hex key
    private static final String KEY_HEX_256 = KEY_HEX_128 * 2
    // Use the strongest key the installed JCE policy allows
    private static final String KEY_HEX = isUnlimitedStrengthCryptoAvailable() ? KEY_HEX_256 : KEY_HEX_128

    private static final Set<PosixFilePermission> ALL_POSIX_ATTRS = PosixFilePermission.values() as Set<PosixFilePermission>

    @ClassRule
    public static TemporaryFolder tempFolder = new TemporaryFolder()

    @BeforeClass
    static void setUpOnce() throws Exception {
        Security.addProvider(new BouncyCastleProvider())

        // Route any unknown logger method (e.g. logger.mock(...)) to info with the method name as a tag
        logger.metaClass.methodMissing = { String name, args ->
            logger.info("[${name?.toUpperCase()}] ${(args as List).join(" ")}")
        }
    }

    /**
     * Re-creates the temporary folder before each test so every test starts from an empty directory;
     * several tests create a file with the same name (e.g. {@code filebased.kp}).
     */
    @Before
    void setUp() throws Exception {
        tempFolder.create()
    }

    /** Deletes the per-test temporary folder and its contents. */
    @After
    void tearDown() throws Exception {
        tempFolder?.delete()
    }

    @AfterClass
    static void tearDownOnce() throws Exception {
        // Intentionally empty
    }

    /**
     * @return {@code true} if the JCE policy permits AES keys longer than 128 bits
     */
    private static boolean isUnlimitedStrengthCryptoAvailable() {
        Cipher.getMaxAllowedKeyLength("AES") > 128
    }

    /**
     * Determines whether the tests run as root by checking whether {@code id -u} prints {@code 0}.
     * Intended for non-Windows hosts only (callers skip Windows before invoking it).
     *
     * @return {@code true} if the effective user id is 0
     */
    private static boolean isRootUser() {
        Process process = new ProcessBuilder(["id", "-u"]).start()
        // Groovy's InputStream#getText reads the whole stream and closes it
        String uid = process.inputStream.getText(StandardCharsets.UTF_8.name()).trim()
        // Reap the child process rather than leaving it to finish unobserved
        process.waitFor()
        uid == "0"
    }

    @Test
    void testShouldConcatenateByteArrays() {
        // Arrange
        byte[] bytes1 = "These are some bytes".getBytes(StandardCharsets.UTF_8)
        byte[] bytes2 = "These are some other bytes".getBytes(StandardCharsets.UTF_8)
        final byte[] EXPECTED_CONCATENATED_BYTES = ((bytes1 as List) << (bytes2 as List)).flatten() as byte[]
        logger.info("Expected concatenated bytes: ${Hex.toHexString(EXPECTED_CONCATENATED_BYTES)}")

        // Act
        byte[] concat = CryptoUtils.concatByteArrays(bytes1, bytes2)
        logger.info("  Actual concatenated bytes: ${Hex.toHexString(concat)}")

        // Assert
        assert concat == EXPECTED_CONCATENATED_BYTES
    }

    @Test
    void testShouldValidateStaticKeyProvider() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldValidateLegacyStaticKeyProvider() {
        // Arrange
        // The legacy class name lived in the provenance package
        String staticProvider = StaticKeyProvider.class.name.replaceFirst("security.kms", "provenance")
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderMissingKeyId() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, null, [(KEY_ID): KEY_HEX])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${null} / ${KEY_HEX} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderMissingKey() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateStaticKeyProviderWithInvalidKey() {
        // Arrange
        String staticProvider = StaticKeyProvider.class.name
        String providerLocation = null

        // Act
        // KEY_HEX[0..<-2] drops the last two hex characters, producing an invalid key length
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(staticProvider, providerLocation, KEY_ID, [(KEY_ID): KEY_HEX[0..<-2]])
        logger.info("Key Provider ${staticProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${KEY_HEX[0..<-2]} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldValidateFileBasedKeyProvider() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldValidateLegacyFileBasedKeyProvider() {
        // Arrange
        // The legacy class name lived in the provenance package
        String fileBasedProvider = FileBasedKeyProvider.class.name.replaceFirst("security.kms", "provenance")
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert keyProviderIsValid
    }

    @Test
    void testShouldNotValidateMissingFileBasedKeyProvider() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        // Only builds the path; the file itself is never created
        File fileBasedProviderFile = new File(tempFolder.root, "filebased_missing.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created (no actual file) temporary file based key provider: ${providerLocation}")

        // Act
        String missingLocation = providerLocation
        boolean missingKeyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, missingLocation, KEY_ID, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${missingLocation} and keyId ${KEY_ID} / ${null} is ${missingKeyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !missingKeyProviderIsValid
    }

    @Test
    void testShouldNotValidateUnreadableFileBasedKeyProvider() {
        // Arrange
        Assume.assumeFalse("This test does not run on Windows", SystemUtils.IS_OS_WINDOWS)
        // Root can read files regardless of permission bits, so the scenario cannot be reproduced
        Assume.assumeFalse("This test does not run for root users", isRootUser())

        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("filebased.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Make it unreadable
        markFileUnreadable(fileBasedProviderFile)
        try {
            // Act
            boolean unreadableKeyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, KEY_ID, null)
            logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${unreadableKeyProviderIsValid ? "valid" : "invalid"}")

            // Assert
            assert !unreadableKeyProviderIsValid
        } finally {
            // Restore permissions even when the assertion fails so the temporary folder can be cleaned up
            markFileReadable(fileBasedProviderFile)
        }
    }

    /**
     * Grants read access to the file: on Windows via {@link File#setReadable(boolean, boolean)}, elsewhere by
     * granting every POSIX permission.
     */
    private static void markFileReadable(File fileBasedProviderFile) {
        if (SystemUtils.IS_OS_WINDOWS) {
            fileBasedProviderFile.setReadable(true, false)
        } else {
            Files.setPosixFilePermissions(fileBasedProviderFile.toPath(), ALL_POSIX_ATTRS)
        }
    }

    /**
     * Removes read access from the file: on Windows via {@link File#setReadable(boolean, boolean)}, elsewhere by
     * clearing every POSIX permission.
     */
    private static void markFileUnreadable(File fileBasedProviderFile) {
        if (SystemUtils.IS_OS_WINDOWS) {
            fileBasedProviderFile.setReadable(false, false)
        } else {
            Files.setPosixFilePermissions(fileBasedProviderFile.toPath(), [] as Set<PosixFilePermission>)
        }
    }

    @Test
    void testShouldNotValidateFileBasedKeyProviderMissingKeyId() {
        // Arrange
        String fileBasedProvider = FileBasedKeyProvider.class.name
        File fileBasedProviderFile = tempFolder.newFile("missing_key_id.kp")
        String providerLocation = fileBasedProviderFile.path
        logger.info("Created temporary file based key provider: ${providerLocation}")

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(fileBasedProvider, providerLocation, null, null)
        logger.info("Key Provider ${fileBasedProvider} with location ${providerLocation} and keyId ${null} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldNotValidateUnknownKeyProvider() {
        // Arrange
        String providerImplementation = "org.apache.nifi.provenance.ImaginaryKeyProvider"
        String providerLocation = null

        // Act
        boolean keyProviderIsValid = CryptoUtils.isValidKeyProvider(providerImplementation, providerLocation, KEY_ID, null)
        logger.info("Key Provider ${providerImplementation} with location ${providerLocation} and keyId ${KEY_ID} / ${null} is ${keyProviderIsValid ? "valid" : "invalid"}")

        // Assert
        assert !keyProviderIsValid
    }

    @Test
    void testShouldValidateKey() {
        // Arrange
        String validKey = KEY_HEX
        String validLowercaseKey = KEY_HEX.toLowerCase()

        String tooShortKey = KEY_HEX[0..<-2]
        String tooLongKey = KEY_HEX + KEY_HEX // Guaranteed to be 2x the max valid key length
        String nonHexKey = KEY_HEX.replaceFirst(/A/, "X")

        def validKeys = [validKey, validLowercaseKey]
        def invalidKeys = [tooShortKey, tooLongKey, nonHexKey]

        // With unlimited strength crypto, 128 and 192 bit keys (32 and 48 hex chars) are also valid;
        // without it, 192 and 256 bit keys must be rejected
        if (isUnlimitedStrengthCryptoAvailable()) {
            validKeys << KEY_HEX_128
            validKeys << KEY_HEX_256[0..<48]
        } else {
            invalidKeys << KEY_HEX_256[0..<48]
            invalidKeys << KEY_HEX_256
        }

        // Act
        def validResults = validKeys.collect { String key ->
            logger.info("Validating ${key}")
            CryptoUtils.keyIsValid(key)
        }

        def invalidResults = invalidKeys.collect { String key ->
            logger.info("Validating ${key}")
            CryptoUtils.keyIsValid(key)
        }

        // Assert
        assert validResults.every()
        assert invalidResults.every { !it }
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForStrings() {
        // Arrange
        String plaintext = "This is a short string."
        String firstCharOff = "this is a short string."
        String lastCharOff = "This is a short string,"
        def scenarios = ["identical": plaintext, "first off": firstCharOff, "last off": lastCharOff]

        // Act
        Map<String, Boolean> results = evaluateConstantTimeEqualsScenarios(plaintext, scenarios)

        // Assert
        assert results["identical"]
        assert !results["first off"]
        assert !results["last off"]
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForBytes() {
        // Arrange
        String plaintext = "This is a short string."
        String firstCharOff = "this is a short string."
        String lastCharOff = "This is a short string,"
        byte[] plaintextBytes = plaintext.getBytes(StandardCharsets.UTF_8)
        def scenarios = ["identical": plaintextBytes,
                         "first off": firstCharOff.getBytes(StandardCharsets.UTF_8),
                         "last off" : lastCharOff.getBytes(StandardCharsets.UTF_8)]

        // Act
        Map<String, Boolean> results = evaluateConstantTimeEqualsScenarios(plaintextBytes, scenarios)

        // Assert
        assert results["identical"]
        assert !results["first off"]
        assert !results["last off"]
    }

    @Test
    void testShouldEvaluateConstantTimeEqualsForChars() {
        // Arrange
        String plaintext = "This is a short string."
        String firstCharOff = "this is a short string."
        String lastCharOff = "This is a short string,"
        char[] plaintextChars = plaintext.toCharArray()
        def scenarios = ["identical": plaintextChars,
                         "first off": firstCharOff.toCharArray(),
                         "last off" : lastCharOff.toCharArray()]

        // Act
        Map<String, Boolean> results = evaluateConstantTimeEqualsScenarios(plaintextChars, scenarios)

        // Assert
        assert results["identical"]
        assert !results["first off"]
        assert !results["last off"]
    }

    /**
     * Shared driver for the constant-time-equals tests. It first warms up the JVM with repeated self-comparisons.
     * It then times, for each scenario, Groovy's short-circuiting {@code ==} ("sc") against
     * {@code CryptoUtils.constantTimeEquals}, logging total and average nanoseconds for both.
     *
     * The overload of {@code CryptoUtils.constantTimeEquals} is chosen at runtime from the argument types
     * (String, byte[] or char[]).
     *
     * @param plaintextValue the reference value every scenario is compared against
     * @param scenarios      scenario name mapped to the value compared with {@code plaintextValue}
     * @return scenario name mapped to the last result of the constant-time comparison for that scenario
     */
    private static Map<String, Boolean> evaluateConstantTimeEqualsScenarios(Object plaintextValue, Map<String, ?> scenarios) {
        final int ITERATIONS = 10_000
        final int WARM_UP_ITERATIONS = 1_000 * ITERATIONS

        long nanos = 0
        long scNanos = 0

        // Prepare the JVM
        WARM_UP_ITERATIONS.times { int i ->
            def scIterationNanos = time("warm up sc") {
                assert plaintextValue == plaintextValue
            }
            scNanos += scIterationNanos
            def iterationNanos = time("warm up") {
                assert CryptoUtils.constantTimeEquals(plaintextValue, plaintextValue)
            }
            nanos += iterationNanos
        }
        logger.info("${"warm up sc".padLeft(10)}: ${scNanos} ns (avg: ${scNanos / (WARM_UP_ITERATIONS)} ns)")
        logger.info("${"warm up".padLeft(10)}: ${nanos} ns (avg: ${nanos / (WARM_UP_ITERATIONS)} ns)")

        final int scenarioWidth = 16
        Map<String, Boolean> results = [:]
        scenarios.each { String scenario, def value ->
            boolean isEqual = true
            scNanos = 0
            nanos = 0
            ITERATIONS.times { int i ->
                def scIterationNanos = time(scenario + " sc") {
                    (plaintextValue == value)
                }
                scNanos += scIterationNanos
                def iterationNanos = time(scenario) {
                    isEqual = CryptoUtils.constantTimeEquals(plaintextValue, value)
                }
                nanos += iterationNanos
            }
            logger.info("${(scenario + " sc").padLeft(scenarioWidth)}: ${scNanos} ns (avg: ${scNanos / ITERATIONS} ns)")
            logger.info("${scenario.padLeft(scenarioWidth)}: ${nanos} ns (avg: ${nanos / ITERATIONS} ns)")
            results[scenario] = isEqual
        }

        // TODO: Assert timings are within std dev?
        results
    }

    /**
     * Runs the closure once and measures its wall-clock duration.
     *
     * @param name    descriptive label (currently unused by the measurement)
     * @param closure the work to time
     * @return elapsed time in nanoseconds as measured by {@link System#nanoTime()}
     */
    private static long time(String name = "closure", Closure closure) {
        long start = System.nanoTime()
        closure.run()
        long end = System.nanoTime()
        end - start
    }
}