MINOR: Use constant-time algorithms when comparing passwords or keys (#10978)
Author: Randall Hauch <rhauch@gmail.com>
Reviewers: Manikumar Reddy <manikumar@confluent.io>, Rajini Sivaram <rajinisivaram@gmail.com>, Mickael Maison <mickael.maison@gmail.com>, Ismael Juma <ijuma@apache.org>
diff --git a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java
index 842f986..10f5817 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/plain/internals/PlainServerCallbackHandler.java
@@ -22,9 +22,9 @@
import org.apache.kafka.common.KafkaException;
import org.apache.kafka.common.security.plain.PlainAuthenticateCallback;
import org.apache.kafka.common.security.plain.PlainLoginModule;
+import org.apache.kafka.common.utils.Utils;
import java.io.IOException;
-import java.util.Arrays;
import java.util.List;
import java.util.Map;
@@ -65,7 +65,7 @@
String expectedPassword = JaasContext.configEntryOption(jaasConfigEntries,
JAAS_USER_PREFIX + username,
PlainLoginModule.class.getName());
- return expectedPassword != null && Arrays.equals(password, expectedPassword.toCharArray());
+ return expectedPassword != null && Utils.isEqualConstantTime(password, expectedPassword.toCharArray());
}
}
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java
index c21a52e..2e6191b 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslClient.java
@@ -18,6 +18,7 @@
import java.nio.charset.StandardCharsets;
import java.security.InvalidKeyException;
+import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Collection;
@@ -204,7 +205,7 @@
try {
byte[] serverKey = formatter.serverKey(saltedPassword);
byte[] serverSignature = formatter.serverSignature(serverKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
- if (!Arrays.equals(signature, serverSignature))
+ if (!MessageDigest.isEqual(signature, serverSignature))
throw new SaslException("Invalid server signature in server final message");
} catch (InvalidKeyException e) {
throw new SaslException("Sasl server signature verification failed", e);
diff --git a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
index f6286a6..3cc8ff0 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/scram/internals/ScramSaslServer.java
@@ -17,6 +17,7 @@
package org.apache.kafka.common.security.scram.internals;
import java.security.InvalidKeyException;
+import java.security.MessageDigest;
import java.security.NoSuchAlgorithmException;
import java.util.Arrays;
import java.util.Collection;
@@ -226,7 +227,7 @@
byte[] expectedStoredKey = scramCredential.storedKey();
byte[] clientSignature = formatter.clientSignature(expectedStoredKey, clientFirstMessage, serverFirstMessage, clientFinalMessage);
byte[] computedStoredKey = formatter.storedKey(clientSignature, clientFinalMessage.proof());
- if (!Arrays.equals(computedStoredKey, expectedStoredKey))
+ if (!MessageDigest.isEqual(computedStoredKey, expectedStoredKey))
throw new SaslException("Invalid client credentials");
} catch (InvalidKeyException e) {
throw new SaslException("Sasl client verification failed", e);
diff --git a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java
index b389a19..a2141b5 100644
--- a/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java
+++ b/clients/src/main/java/org/apache/kafka/common/security/token/delegation/DelegationToken.java
@@ -18,6 +18,7 @@
import org.apache.kafka.common.annotation.InterfaceStability;
+import java.security.MessageDigest;
import java.util.Arrays;
import java.util.Base64;
import java.util.Objects;
@@ -59,7 +60,7 @@
DelegationToken token = (DelegationToken) o;
- return Objects.equals(tokenInformation, token.tokenInformation) && Arrays.equals(hmac, token.hmac);
+ return Objects.equals(tokenInformation, token.tokenInformation) && MessageDigest.isEqual(hmac, token.hmac);
}
@Override
diff --git a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
index 5fa32b7..921ce3c 100755
--- a/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
+++ b/clients/src/main/java/org/apache/kafka/common/utils/Utils.java
@@ -299,6 +299,42 @@
}
/**
+ * Compares two character arrays for equality using a constant-time algorithm, which is needed
+ * for comparing passwords. Two arrays are equal if they have the same length and all
+ * characters at corresponding positions are equal; two {@code null} references are equal.
+ *
+ * <p>Unless either array is {@code null} or the second array is empty, all characters in the first
+ * array are examined without data-dependent branching, so the calculation time depends only on the
+ * length of the first array and not on the length of the second array or the contents of either.
+ *
+ * @param first the first array to compare; may be {@code null}
+ * @param second the second array to compare; may be {@code null}
+ * @return true if the arrays are equal, or false otherwise
+ */
+ public static boolean isEqualConstantTime(char[] first, char[] second) {
+ if (first == second) {
+ return true;
+ }
+ if (first == null || second == null) {
+ return false;
+ }
+
+ if (second.length == 0) {
+ return first.length == 0;
+ }
+
+ // constant-time comparison: examine every character of the first array and OR the differences
+ // together instead of branching on them; a length mismatch is folded in the same way
+ int diff = first.length ^ second.length;
+ for (int i = 0; i < first.length; ++i) {
+ // stay within bounds when second is shorter; that mismatch is already recorded via the lengths
+ int j = i < second.length ? i : 0;
+ diff |= first[i] ^ second[j];
+ }
+ return diff == 0;
+ }
+
+ /**
* Sleep for a bit
* @param ms The duration of the sleep
*/
diff --git a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
index 172a992..5762be4 100755
--- a/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
+++ b/clients/src/test/java/org/apache/kafka/common/utils/UtilsTest.java
@@ -47,6 +47,8 @@
import static org.junit.Assert.assertArrayEquals;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
+import static org.junit.Assert.assertNotNull;
+import static org.junit.Assert.assertNull;
import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.ArgumentMatchers.any;
@@ -487,4 +489,46 @@
} catch (IllegalArgumentException e) {
}
}
+
+ @Test
+ public void testCharacterArrayEquality() {
+ assertCharacterArraysAreNotEqual(null, "abc");
+ assertCharacterArraysAreNotEqual(null, "");
+ assertCharacterArraysAreNotEqual("abc", null);
+ assertCharacterArraysAreNotEqual("", null);
+ assertCharacterArraysAreNotEqual("", "abc");
+ assertCharacterArraysAreNotEqual("abc", "abC");
+ assertCharacterArraysAreNotEqual("abc", "abcd");
+ assertCharacterArraysAreNotEqual("abc", "abcdefg");
+ assertCharacterArraysAreNotEqual("abcdefg", "abc");
+ assertCharacterArraysAreNotEqual("aaa", "a"); // every char of "aaa" matches "a"[0]; only the lengths differ
+ assertCharacterArraysAreEqual("abc", "abc");
+ assertCharacterArraysAreEqual("a", "a");
+ assertCharacterArraysAreEqual("", "");
+ assertCharacterArraysAreEqual(null, null);
+ }
+
+ private void assertCharacterArraysAreNotEqual(String a, String b) {
+ // the strings must differ (at most one may be null), and neither comparison order may report equality
+ char[] left = a == null ? null : a.toCharArray();
+ char[] right = b == null ? null : b.toCharArray();
+ if (a != null) {
+ assertFalse(a.equals(b));
+ } else {
+ assertNotNull(b);
+ }
+ assertFalse(Utils.isEqualConstantTime(left, right) || Utils.isEqualConstantTime(right, left));
+ }
+
+ private void assertCharacterArraysAreEqual(String a, String b) {
+ // the strings must match (or both be null), and both comparison orders must report equality
+ char[] left = a == null ? null : a.toCharArray();
+ char[] right = b == null ? null : b.toCharArray();
+ if (a != null) {
+ assertTrue(a.equals(b));
+ } else {
+ assertNull(b);
+ }
+ assertTrue(Utils.isEqualConstantTime(left, right) && Utils.isEqualConstantTime(right, left));
+ }
}