| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.nifi.security.util.crypto |
| |
| import org.apache.commons.codec.binary.Hex |
| import org.apache.nifi.security.util.EncryptionMethod |
| import org.apache.nifi.security.util.KeyDerivationFunction |
| import org.bouncycastle.jce.provider.BouncyCastleProvider |
| import org.junit.jupiter.api.BeforeAll |
| import org.junit.jupiter.api.Test |
| import org.slf4j.Logger |
| import org.slf4j.LoggerFactory |
| |
| import java.security.Security |
| |
/**
 * Unit tests for {@link CipherUtility}: parsing the cipher and key length out of algorithm names,
 * validating key lengths (with and without an algorithm), enumerating valid key lengths, searching
 * for a byte sequence, and extracting raw salts for each {@link KeyDerivationFunction}.
 *
 * <p>This is a JUnit 5 (Jupiter) test class. It intentionally does not extend {@code GroovyTestCase}.
 * That type is a JUnit 3 {@code TestCase}, and combining it with Jupiter annotations lets a JUnit 3 /
 * Vintage runner execute the {@code test*} methods without invoking the {@code @BeforeAll} setup.</p>
 */
class CipherUtilityGroovyTest {
    private static final Logger logger = LoggerFactory.getLogger(CipherUtilityGroovyTest.class)

    // TripleDES must precede DES for automatic grouping precedence: the cipher is resolved with find(),
    // which returns the first entry contained in the algorithm name, and "TRIPLEDES" also contains "DES"
    private static final List<String> CIPHERS = ["AES", "TRIPLEDES", "DES", "RC2", "RC4", "RC5", "TWOFISH"]

    // Every EncryptionMethod algorithm name starting with "PBE" or "AES"
    private static final List<String> SYMMETRIC_ALGORITHMS = EncryptionMethod.values().findAll {
        it.algorithm.startsWith("PBE") || it.algorithm.startsWith("AES")
    }*.algorithm

    // Symmetric algorithms grouped by the cipher name CipherUtility.parseCipherFromAlgorithm() is expected
    // to return. TRIPLEDES algorithms are keyed as "DESede". The mapping is done here, once, so the shared
    // fixture is never mutated during test execution.
    private static final Map<String, List<String>> ALGORITHMS_MAPPED_BY_CIPHER = SYMMETRIC_ALGORITHMS.groupBy { String algorithm ->
        String cipher = CIPHERS.find { algorithm.contains(it) }
        cipher == "TRIPLEDES" ? "DESede" : cipher
    }

    // Manually mapped as of 03/21/21 1.13.0
    private static final Map<Integer, List<String>> ALGORITHMS_MAPPED_BY_KEY_LENGTH = [
            (40) : ["PBEWITHSHAAND40BITRC2-CBC",
                    "PBEWITHSHAAND40BITRC4"],
            (64) : ["PBEWITHMD5ANDDES",
                    "PBEWITHSHA1ANDDES"],
            (112): ["PBEWITHSHAAND2-KEYTRIPLEDES-CBC",
                    "PBEWITHSHAAND3-KEYTRIPLEDES-CBC"],
            (128): ["PBEWITHMD5AND128BITAES-CBC-OPENSSL",
                    "PBEWITHMD5ANDRC2",
                    "PBEWITHSHA1ANDRC2",
                    "PBEWITHSHA256AND128BITAES-CBC-BC",
                    "PBEWITHSHAAND128BITAES-CBC-BC",
                    "PBEWITHSHAAND128BITRC2-CBC",
                    "PBEWITHSHAAND128BITRC4",
                    "PBEWITHSHAANDTWOFISH-CBC",
                    "AES/CBC/NoPadding",
                    "AES/CBC/PKCS7Padding",
                    "AES/CTR/NoPadding",
                    "AES/GCM/NoPadding"],
            (192): ["PBEWITHMD5AND192BITAES-CBC-OPENSSL",
                    "PBEWITHSHA256AND192BITAES-CBC-BC",
                    "PBEWITHSHAAND192BITAES-CBC-BC",
                    "AES/CBC/NoPadding",
                    "AES/CBC/PKCS7Padding",
                    "AES/CTR/NoPadding",
                    "AES/GCM/NoPadding"],
            (256): ["PBEWITHMD5AND256BITAES-CBC-OPENSSL",
                    "PBEWITHSHA256AND256BITAES-CBC-BC",
                    "PBEWITHSHAAND256BITAES-CBC-BC",
                    "AES/CBC/NoPadding",
                    "AES/CBC/PKCS7Padding",
                    "AES/CTR/NoPadding",
                    "AES/GCM/NoPadding"]
    ]

    /**
     * Registers the BouncyCastle provider once for the whole class and logs the cipher fixture.
     */
    @BeforeAll
    static void setUpOnce() {
        Security.addProvider(new BouncyCastleProvider())

        logger.info("Mapped algorithms: ${ALGORITHMS_MAPPED_BY_CIPHER}")
    }

    /**
     * Returns the key lengths that must be rejected for {@code algorithm}.
     *
     * <p>The list always contains -1, 0 and 1. RC2/RC4/RC5 algorithms accept the contiguous range
     * 40..2048 (see {@link #testShouldGetValidKeyLengthsForAlgorithm}), so the lengths just outside
     * that range are probed. Any other algorithm gets {@code keyLength + 1}.</p>
     *
     * @param keyLength the valid key length the algorithm is mapped to
     * @param algorithm the algorithm name
     * @return a new mutable list of invalid key lengths
     */
    private static List<Integer> invalidKeyLengthsFor(int keyLength, String algorithm) {
        List<Integer> invalidKeyLengths = [-1, 0, 1]
        if (algorithm =~ "RC\\d") {
            invalidKeyLengths += [39, 2049]
        } else {
            invalidKeyLengths += keyLength + 1
        }
        invalidKeyLengths
    }

    @Test
    void testShouldParseCipherFromAlgorithm() {
        // Arrange
        final def EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_CIPHER

        // Act
        SYMMETRIC_ALGORITHMS.each { String algorithm ->
            String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm)
            logger.info("Extracted ${cipher} from ${algorithm}")

            // Assert (null-safe so an unexpected cipher fails the assertion instead of throwing an NPE)
            assert EXPECTED_ALGORITHMS.get(cipher)?.contains(algorithm)
        }
    }

    @Test
    void testShouldParseKeyLengthFromAlgorithm() {
        // Arrange
        final def EXPECTED_ALGORITHMS = ALGORITHMS_MAPPED_BY_KEY_LENGTH

        // Act
        SYMMETRIC_ALGORITHMS.each { String algorithm ->
            int keyLength = CipherUtility.parseKeyLengthFromAlgorithm(algorithm)
            logger.info("Extracted ${keyLength} from ${algorithm}")

            // Assert (null-safe so an unexpected key length fails the assertion instead of throwing an NPE)
            assert EXPECTED_ALGORITHMS.get(keyLength)?.contains(algorithm)
        }
    }

    @Test
    void testShouldDetermineValidKeyLength() {
        // Act
        ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
            algorithms.each { String algorithm ->
                logger.info("Checking ${keyLength} for ${algorithm}")

                // Assert
                assert CipherUtility.isValidKeyLength(keyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))
            }
        }
    }

    @Test
    void testShouldDetermineInvalidKeyLength() {
        // Act
        ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
            algorithms.each { String algorithm ->
                def invalidKeyLengths = invalidKeyLengthsFor(keyLength, algorithm)
                logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")

                // Assert
                invalidKeyLengths.each { int invalidKeyLength ->
                    assert !CipherUtility.isValidKeyLength(invalidKeyLength, CipherUtility.parseCipherFromAlgorithm(algorithm))
                }
            }
        }
    }

    @Test
    void testShouldDetermineValidKeyLengthForAlgorithm() {
        // Act
        ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
            algorithms.each { String algorithm ->
                logger.info("Checking ${keyLength} for ${algorithm}")

                // Assert
                assert CipherUtility.isValidKeyLengthForAlgorithm(keyLength, algorithm)
            }
        }
    }

    @Test
    void testShouldDetermineInvalidKeyLengthForAlgorithm() {
        // Act
        ALGORITHMS_MAPPED_BY_KEY_LENGTH.each { int keyLength, List<String> algorithms ->
            algorithms.each { String algorithm ->
                def invalidKeyLengths = invalidKeyLengthsFor(keyLength, algorithm)
                logger.info("Checking ${invalidKeyLengths.join(", ")} for ${algorithm}")

                // Assert
                invalidKeyLengths.each { int invalidKeyLength ->
                    assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)
                }
            }
        }

        // Extra hard-coded checks: 192 is a valid AES length, but not for this 256-bit-only algorithm
        String algorithm = "PBEWITHSHA256AND256BITAES-CBC-BC"
        int invalidKeyLength = 192
        logger.info("Checking ${invalidKeyLength} for ${algorithm}")
        assert !CipherUtility.isValidKeyLengthForAlgorithm(invalidKeyLength, algorithm)
    }

    @Test
    void testShouldGetValidKeyLengthsForAlgorithm() {
        // Arrange
        def rcKeyLengths = (40..2048).asList()
        def CIPHER_KEY_SIZES = [
                AES    : [128, 192, 256],
                DES    : [56, 64],
                DESede : [56, 64, 112, 128, 168, 192],
                RC2    : rcKeyLengths,
                RC4    : rcKeyLengths,
                RC5    : rcKeyLengths,
                TWOFISH: [128, 192, 256]
        ]

        // Algorithms whose name pins a single key length (parseActualKeyLengthFromAlgorithm != -1)
        def SINGLE_KEY_SIZE_ALGORITHMS = EncryptionMethod.values()*.algorithm.findAll { CipherUtility.parseActualKeyLengthFromAlgorithm(it) != -1 }
        logger.info("Single key size algorithms: ${SINGLE_KEY_SIZE_ALGORITHMS}")
        // Everything else except PGP, whose key sizes are checked against the cipher's full set
        def MULTIPLE_KEY_SIZE_ALGORITHMS = EncryptionMethod.values()*.algorithm - SINGLE_KEY_SIZE_ALGORITHMS
        MULTIPLE_KEY_SIZE_ALGORITHMS.removeAll { it.contains("PGP") }
        logger.info("Multiple key size algorithms: ${MULTIPLE_KEY_SIZE_ALGORITHMS}")

        // Act
        SINGLE_KEY_SIZE_ALGORITHMS.each { String algorithm ->
            def EXPECTED_KEY_SIZES = [CipherUtility.parseKeyLengthFromAlgorithm(algorithm)]

            def validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm)
            logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")

            // Assert
            assert validKeySizes == EXPECTED_KEY_SIZES
        }

        // Act
        MULTIPLE_KEY_SIZE_ALGORITHMS.each { String algorithm ->
            String cipher = CipherUtility.parseCipherFromAlgorithm(algorithm)
            def EXPECTED_KEY_SIZES = CIPHER_KEY_SIZES[cipher]

            def validKeySizes = CipherUtility.getValidKeyLengthsForAlgorithm(algorithm)
            logger.info("Checking ${algorithm} ${validKeySizes} against expected ${EXPECTED_KEY_SIZES}")

            // Assert
            assert validKeySizes == EXPECTED_KEY_SIZES
        }
    }

    @Test
    void testShouldFindSequence() {
        // Arrange (only the first line of the text determines the expected indices below)
        byte[] license = """Licensed to the Apache Software Foundation (ASF) under one or more
contributor license agreements. See the NOTICE file distributed with
this work for additional information regarding copyright ownership.
The ASF licenses this file to You under the Apache License, Version 2.0
(the "License"); you may not use this file except in compliance with
the License. You may obtain a copy of the License at
""".bytes

        byte[] apache = "Apache".bytes
        byte[] software = "Software".bytes
        byte[] asf = "ASF".bytes
        byte[] kafka = "Kafka".bytes

        // Act
        int apacheIndex = CipherUtility.findSequence(license, apache)
        logger.info("Looking for ${Hex.encodeHexString(apache)}; found at ${apacheIndex}")

        int softwareIndex = CipherUtility.findSequence(license, software)
        logger.info("Looking for ${Hex.encodeHexString(software)}; found at ${softwareIndex}")

        int asfIndex = CipherUtility.findSequence(license, asf)
        logger.info("Looking for ${Hex.encodeHexString(asf)}; found at ${asfIndex}")

        int kafkaIndex = CipherUtility.findSequence(license, kafka)
        logger.info("Looking for ${Hex.encodeHexString(kafka)}; found at ${kafkaIndex}")

        // Assert
        assert apacheIndex == 16
        assert softwareIndex == 23
        assert asfIndex == 44
        assert kafkaIndex == -1
    }

    @Test
    void testShouldExtractRawSalt() {
        // Arrange
        byte[] PLAIN_SALT = [0xab] * 16

        String ARGON2_SALT = Argon2CipherProvider.formSalt(PLAIN_SALT, 8, 1, 1)
        String BCRYPT_SALT = BcryptCipherProvider.formatSaltForBcrypt(PLAIN_SALT, 10)
        String SCRYPT_SALT = ScryptCipherProvider.formatSaltForScrypt(PLAIN_SALT, 10, 1, 1)

        // Act: non-strong KDFs are given the plain salt directly
        def results = KeyDerivationFunction.values().findAll { !it.isStrongKDF() }.collectEntries { KeyDerivationFunction weakKdf ->
            [weakKdf, CipherUtility.extractRawSalt(PLAIN_SALT, weakKdf)]
        }

        // Strong KDFs are given their formatted salt strings (PBKDF2 takes the plain salt)
        results.put(KeyDerivationFunction.ARGON2, CipherUtility.extractRawSalt(ARGON2_SALT.bytes, KeyDerivationFunction.ARGON2))
        results.put(KeyDerivationFunction.BCRYPT, CipherUtility.extractRawSalt(BCRYPT_SALT.bytes, KeyDerivationFunction.BCRYPT))
        results.put(KeyDerivationFunction.SCRYPT, CipherUtility.extractRawSalt(SCRYPT_SALT.bytes, KeyDerivationFunction.SCRYPT))
        results.put(KeyDerivationFunction.PBKDF2, CipherUtility.extractRawSalt(PLAIN_SALT, KeyDerivationFunction.PBKDF2))

        // Assert each KDF individually so a failure identifies the offending KDF
        results.each { kdf, rawSalt ->
            assert rawSalt == PLAIN_SALT : "Extracted raw salt for ${kdf} did not match the plain salt"
        }
    }
}