QPID-8302: [Broker-J] Validate whether private key matches certificate on NonJavaKeystore certificate update
diff --git a/broker-core/src/main/java/org/apache/qpid/server/security/NonJavaKeyStoreImpl.java b/broker-core/src/main/java/org/apache/qpid/server/security/NonJavaKeyStoreImpl.java
index 3a92a6c..a9dbb29 100644
--- a/broker-core/src/main/java/org/apache/qpid/server/security/NonJavaKeyStoreImpl.java
+++ b/broker-core/src/main/java/org/apache/qpid/server/security/NonJavaKeyStoreImpl.java
@@ -22,15 +22,19 @@
 
 import java.io.File;
 import java.io.IOException;
+import java.math.BigInteger;
 import java.net.MalformedURLException;
 import java.net.URL;
 import java.nio.ByteBuffer;
 import java.nio.charset.StandardCharsets;
 import java.security.GeneralSecurityException;
 import java.security.PrivateKey;
+import java.security.PublicKey;
 import java.security.SecureRandom;
 import java.security.cert.Certificate;
 import java.security.cert.X509Certificate;
+import java.security.interfaces.RSAPrivateKey;
+import java.security.interfaces.RSAPublicKey;
 import java.util.ArrayList;
 import java.util.Arrays;
 import java.util.Collection;
@@ -38,6 +42,7 @@
 import java.util.Date;
 import java.util.List;
 import java.util.Map;
+import java.util.Objects;
 import java.util.Set;
 
 import javax.naming.InvalidNameException;
@@ -194,11 +199,23 @@
     {
         try
         {
-            SSLUtil.readPrivateKey(getUrlFromString(keyStore.getPrivateKeyUrl()));
-            SSLUtil.readCertificates(getUrlFromString(keyStore.getCertificateUrl()));
+            final PrivateKey privateKey = SSLUtil.readPrivateKey(getUrlFromString(keyStore.getPrivateKeyUrl()));
+            X509Certificate[] certs = SSLUtil.readCertificates(getUrlFromString(keyStore.getCertificateUrl()));
+            final List<X509Certificate> allCerts = new ArrayList<>(Arrays.asList(certs));
             if(keyStore.getIntermediateCertificateUrl() != null)
             {
-                SSLUtil.readCertificates(getUrlFromString(keyStore.getIntermediateCertificateUrl()));
+                allCerts.addAll(Arrays.asList(SSLUtil.readCertificates(getUrlFromString(keyStore.getIntermediateCertificateUrl()))));
+                certs = allCerts.toArray(new X509Certificate[allCerts.size()]);
+            }
+            final PublicKey publicKey = certs[0].getPublicKey();
+            if (privateKey instanceof RSAPrivateKey && publicKey instanceof RSAPublicKey)
+            {
+                final BigInteger privateModulus = ((RSAPrivateKey) privateKey).getModulus();
+                final BigInteger publicModulus = ((RSAPublicKey)publicKey).getModulus();
+                if (!Objects.equals(privateModulus, publicModulus))
+                {
+                    throw new IllegalConfigurationException("Private key does not match certificate");
+                }
             }
         }
         catch (IOException | GeneralSecurityException e )
diff --git a/broker-core/src/test/java/org/apache/qpid/server/security/NonJavaKeyStoreTest.java b/broker-core/src/test/java/org/apache/qpid/server/security/NonJavaKeyStoreTest.java
index 2352591..d4d6390 100644
--- a/broker-core/src/test/java/org/apache/qpid/server/security/NonJavaKeyStoreTest.java
+++ b/broker-core/src/test/java/org/apache/qpid/server/security/NonJavaKeyStoreTest.java
@@ -20,11 +20,14 @@
 package org.apache.qpid.server.security;
 
 
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.apache.qpid.test.utils.TestSSLConstants.JAVA_KEYSTORE_TYPE;
 import static org.apache.qpid.test.utils.TestSSLConstants.KEYSTORE_PASSWORD;
+import static org.hamcrest.CoreMatchers.is;
 import static org.junit.Assert.assertEquals;
 import static org.junit.Assert.assertNotNull;
 import static org.junit.Assert.fail;
+import static org.junit.Assume.assumeThat;
 import static org.mockito.ArgumentMatchers.any;
 import static org.mockito.ArgumentMatchers.anyLong;
 import static org.mockito.ArgumentMatchers.argThat;
@@ -40,9 +43,11 @@
 import java.security.Key;
 import java.security.cert.Certificate;
 import java.security.cert.X509Certificate;
+import java.time.Duration;
+import java.time.Instant;
+import java.time.temporal.ChronoUnit;
 import java.util.ArrayList;
 import java.util.Arrays;
-import java.util.Base64;
 import java.util.Collections;
 import java.util.HashMap;
 import java.util.List;
@@ -66,7 +71,10 @@
 import org.apache.qpid.server.model.BrokerTestHelper;
 import org.apache.qpid.server.model.ConfiguredObjectFactory;
 import org.apache.qpid.server.model.KeyStore;
+import org.apache.qpid.server.transport.network.security.ssl.SSLUtil;
+import org.apache.qpid.server.util.DataUrlUtils;
 import org.apache.qpid.test.utils.TestFileUtils;
+import org.apache.qpid.test.utils.TestSSLUtils;
 import org.apache.qpid.test.utils.UnitTestBase;
 
 public class NonJavaKeyStoreTest extends UnitTestBase
@@ -118,17 +126,7 @@
             Key pvt = ks.getKey("java-broker", KEYSTORE_PASSWORD.toCharArray());
             if (pem)
             {
-                kos.write("-----BEGIN PRIVATE KEY-----\n".getBytes());
-                String base64encoded = Base64.getEncoder().encodeToString(pvt.getEncoded());
-                while(base64encoded.length() > 76)
-                {
-                    kos.write(base64encoded.substring(0,76).getBytes());
-                    kos.write("\n".getBytes());
-                    base64encoded = base64encoded.substring(76);
-                }
-
-                kos.write(base64encoded.getBytes());
-                kos.write("\n-----END PRIVATE KEY-----".getBytes());
+                kos.write(TestSSLUtils.privateKeyToPEM(pvt).getBytes(UTF_8));
             }
             else
             {
@@ -141,20 +139,10 @@
 
         try(FileOutputStream cos = new FileOutputStream(certificateFile))
         {
-            Certificate pub = ks.getCertificate("rootca");
+            Certificate pub = ks.getCertificate("java-broker");
             if (pem)
             {
-                cos.write("-----BEGIN CERTIFICATE-----\n".getBytes());
-                String base64encoded = Base64.getEncoder().encodeToString(pub.getEncoded());
-                while(base64encoded.length() > 76)
-                {
-                    cos.write(base64encoded.substring(0,76).getBytes());
-                    cos.write("\n".getBytes());
-                    base64encoded = base64encoded.substring(76);
-                }
-                cos.write(base64encoded.getBytes());
-
-                cos.write("\n-----END CERTIFICATE-----".getBytes());
+                cos.write(TestSSLUtils.certificateToPEM(pub).getBytes(UTF_8));
             }
             else
             {
@@ -293,6 +281,76 @@
         _factory.create(KeyStore.class, attributes, _broker);
     }
 
+    @Test
+    public void testCreationOfKeyStoreWithNonMatchingPrivateKeyAndCertificate()throws Exception
+    {
+        assumeThat(SSLUtil.canGenerateCerts(), is(true));
+
+        final SSLUtil.KeyCertPair keyCertPair = generateSelfSignedCertificate();
+        final SSLUtil.KeyCertPair keyCertPair2 = generateSelfSignedCertificate();
+
+        final Map<String,Object> attributes = new HashMap<>();
+        attributes.put(NonJavaKeyStore.NAME, "myTestTrustStore");
+        attributes.put(NonJavaKeyStore.PRIVATE_KEY_URL,
+                       DataUrlUtils.getDataUrlForBytes(TestSSLUtils.privateKeyToPEM(keyCertPair.getPrivateKey()).getBytes(UTF_8)));
+        attributes.put(NonJavaKeyStore.CERTIFICATE_URL,
+                       DataUrlUtils.getDataUrlForBytes(TestSSLUtils.certificateToPEM(keyCertPair2.getCertificate()).getBytes(UTF_8)));
+        attributes.put(NonJavaKeyStore.TYPE, "NonJavaKeyStore");
+
+        try
+        {
+            _factory.create(KeyStore.class, attributes, _broker);
+            fail("Created key store from invalid certificate");
+        }
+        catch(IllegalConfigurationException e)
+        {
+            // pass
+        }
+    }
+
+    @Test
+    public void testUpdateKeyStoreToNonMatchingCertificate()throws Exception
+    {
+        assumeThat(SSLUtil.canGenerateCerts(), is(true));
+
+        final SSLUtil.KeyCertPair keyCertPair = generateSelfSignedCertificate();
+        final SSLUtil.KeyCertPair keyCertPair2 = generateSelfSignedCertificate();
+
+        final Map<String,Object> attributes = new HashMap<>();
+        attributes.put(NonJavaKeyStore.NAME, getTestName());
+        attributes.put(NonJavaKeyStore.PRIVATE_KEY_URL,
+                       DataUrlUtils.getDataUrlForBytes(TestSSLUtils.privateKeyToPEM(keyCertPair.getPrivateKey()).getBytes(UTF_8)));
+        attributes.put(NonJavaKeyStore.CERTIFICATE_URL,
+                       DataUrlUtils.getDataUrlForBytes(TestSSLUtils.certificateToPEM(keyCertPair.getCertificate()).getBytes(UTF_8)));
+        attributes.put(NonJavaKeyStore.TYPE, "NonJavaKeyStore");
+
+        final KeyStore trustStore = _factory.create(KeyStore.class, attributes, _broker);
+        try
+        {
+            final String certUrl = DataUrlUtils.getDataUrlForBytes(TestSSLUtils.certificateToPEM(keyCertPair2.getCertificate()).getBytes(UTF_8));
+            trustStore.setAttributes(Collections.singletonMap("certificateUrl", certUrl));
+            fail("Created key store from invalid certificate");
+        }
+        catch(IllegalConfigurationException e)
+        {
+            // pass
+        }
+    }
+
+    private SSLUtil.KeyCertPair generateSelfSignedCertificate() throws Exception
+    {
+        return SSLUtil.generateSelfSignedCertificate("RSA",
+                                                     "SHA256WithRSA",
+                                                     2048,
+                                                     Instant.now()
+                                                            .minus(1, ChronoUnit.DAYS)
+                                                            .toEpochMilli(),
+                                                     Duration.of(365, ChronoUnit.DAYS)
+                                                             .getSeconds(),
+                                                     "CN=foo",
+                                                     Collections.emptySet(),
+                                                     Collections.emptySet());
+    }
 
     private static class LogMessageArgumentMatcher implements ArgumentMatcher<LogMessage>
     {
diff --git a/qpid-test-utils/src/main/java/org/apache/qpid/test/utils/TestSSLUtils.java b/qpid-test-utils/src/main/java/org/apache/qpid/test/utils/TestSSLUtils.java
new file mode 100644
index 0000000..fedf4ca
--- /dev/null
+++ b/qpid-test-utils/src/main/java/org/apache/qpid/test/utils/TestSSLUtils.java
@@ -0,0 +1,52 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.qpid.test.utils;
+
+import java.security.Key;
+import java.security.cert.Certificate;
+import java.security.cert.CertificateEncodingException;
+import java.util.Base64;
+
+public class TestSSLUtils
+{
+    public static String certificateToPEM(final Certificate pub) throws CertificateEncodingException
+    {
+        return toPEM(pub.getEncoded(), "-----BEGIN CERTIFICATE-----", "-----END CERTIFICATE-----");
+    }
+
+    public static String privateKeyToPEM(final Key key)
+    {
+        return toPEM(key.getEncoded(), "-----BEGIN PRIVATE KEY-----", "-----END PRIVATE KEY-----");
+    }
+
+    private static String toPEM(final byte[] bytes, final String header, final String footer)
+    {
+        StringBuilder pem = new StringBuilder();
+        pem.append(header).append("\n");
+        String base64encoded = Base64.getEncoder().encodeToString(bytes);
+        while (base64encoded.length() > 76)
+        {
+            pem.append(base64encoded, 0, 76).append("\n");
+            base64encoded = base64encoded.substring(76);
+        }
+        pem.append(base64encoded).append("\n");
+        pem.append(footer).append("\n");
+        return pem.toString();
+    }
+}
diff --git a/systests/qpid-systests-jms_1.1/src/test/java/org/apache/qpid/systests/jms_1_1/extensions/tls/TlsTest.java b/systests/qpid-systests-jms_1.1/src/test/java/org/apache/qpid/systests/jms_1_1/extensions/tls/TlsTest.java
index bb81620..71d5e3c 100644
--- a/systests/qpid-systests-jms_1.1/src/test/java/org/apache/qpid/systests/jms_1_1/extensions/tls/TlsTest.java
+++ b/systests/qpid-systests-jms_1.1/src/test/java/org/apache/qpid/systests/jms_1_1/extensions/tls/TlsTest.java
@@ -20,6 +20,7 @@
  */
 package org.apache.qpid.systests.jms_1_1.extensions.tls;
 
+import static java.nio.charset.StandardCharsets.UTF_8;
 import static org.apache.qpid.test.utils.TestSSLConstants.JAVA_KEYSTORE_TYPE;
 import static org.apache.qpid.test.utils.TestSSLConstants.BROKER_KEYSTORE_PASSWORD;
 import static org.apache.qpid.test.utils.TestSSLConstants.BROKER_TRUSTSTORE_PASSWORD;
@@ -66,6 +67,7 @@
 import org.apache.qpid.systests.ConnectionBuilder;
 import org.apache.qpid.systests.JmsTestBase;
 import org.apache.qpid.test.utils.TestSSLConstants;
+import org.apache.qpid.test.utils.TestSSLUtils;
 import org.apache.qpid.tests.utils.BrokerAdmin;
 
 public class TlsTest extends JmsTestBase
@@ -704,18 +706,7 @@
         try (FileOutputStream kos = new FileOutputStream(privateKeyFile))
         {
             Key pvt = ks.getKey(TestSSLConstants.CERT_ALIAS_APP1, KEYSTORE_PASSWORD.toCharArray());
-            kos.write("-----BEGIN PRIVATE KEY-----\n".getBytes());
-            String base64encoded = Base64.getEncoder().encodeToString(pvt.getEncoded());
-            while (base64encoded.length() > 76)
-            {
-                kos.write(base64encoded.substring(0, 76).getBytes());
-                kos.write("\n".getBytes());
-                base64encoded = base64encoded.substring(76);
-            }
-
-            kos.write(base64encoded.getBytes());
-            kos.write("\n-----END PRIVATE KEY-----".getBytes());
-            kos.flush();
+            kos.write(TestSSLUtils.privateKeyToPEM(pvt).getBytes(UTF_8));
         }
 
         File certificateFile = Files.createTempFile(getTestName(), ".certificate.der").toFile();
@@ -724,17 +715,7 @@
             Certificate[] chain = ks.getCertificateChain(TestSSLConstants.CERT_ALIAS_APP1);
             for (Certificate pub : chain)
             {
-                cos.write("-----BEGIN CERTIFICATE-----\n".getBytes());
-                String base64encoded = Base64.getEncoder().encodeToString(pub.getEncoded());
-                while (base64encoded.length() > 76)
-                {
-                    cos.write(base64encoded.substring(0, 76).getBytes());
-                    cos.write("\n".getBytes());
-                    base64encoded = base64encoded.substring(76);
-                }
-                cos.write(base64encoded.getBytes());
-
-                cos.write("\n-----END CERTIFICATE-----\n".getBytes());
+                cos.write(TestSSLUtils.certificateToPEM(pub).getBytes(UTF_8));
             }
             cos.flush();
         }