SANTUARIO-523 - Applying additional fix. Thanks to Peter De Maeyer


git-svn-id: https://svn.apache.org/repos/asf/santuario/xml-security-java/trunk@1874113 13f79535-47bb-0310-9956-ffa450edef68
diff --git a/src/main/java/org/apache/xml/security/stax/impl/XMLSecurityStreamReader.java b/src/main/java/org/apache/xml/security/stax/impl/XMLSecurityStreamReader.java
index 24d4a7f..e384b4b 100644
--- a/src/main/java/org/apache/xml/security/stax/impl/XMLSecurityStreamReader.java
+++ b/src/main/java/org/apache/xml/security/stax/impl/XMLSecurityStreamReader.java
@@ -85,7 +85,9 @@
                 // We only skip the event itself.
                 StartDocument startDocument = (StartDocument) currentXMLSecEvent;
                 version = startDocument.getVersion();
-                characterEncodingScheme = startDocument.getCharacterEncodingScheme();
+                if (startDocument.encodingSet()) {
+                    characterEncodingScheme = startDocument.getCharacterEncodingScheme();
+                }
                 standalone = startDocument.isStandalone();
                 standaloneSet = startDocument.standaloneSet();
                 if (skipDocumentEvents) {
diff --git a/src/test/java/org/apache/xml/security/test/stax/XMLSecurityStreamReaderTest.java b/src/test/java/org/apache/xml/security/test/stax/XMLSecurityStreamReaderTest.java
index 81dd37e..83b4213 100644
--- a/src/test/java/org/apache/xml/security/test/stax/XMLSecurityStreamReaderTest.java
+++ b/src/test/java/org/apache/xml/security/test/stax/XMLSecurityStreamReaderTest.java
@@ -55,6 +55,7 @@
 import static org.hamcrest.MatcherAssert.assertThat;
 import static org.hamcrest.Matchers.equalTo;
 import static org.hamcrest.Matchers.is;
+import static org.hamcrest.Matchers.nullValue;
 import static org.junit.jupiter.api.Assertions.assertEquals;
 import static org.junit.jupiter.api.Assertions.assertFalse;
 import static org.junit.jupiter.api.Assertions.assertTrue;
@@ -141,6 +142,27 @@
     }
 
     @Test
+    public void testDocumentDeclarationWithoutOptionalAttributes() throws Exception {
+        String xml = "<?xml version='1.1'?>\n"
+                + "<Document/>";
+        ByteArrayInputStream xmlInput = new ByteArrayInputStream(xml.getBytes(StandardCharsets.ISO_8859_1));
+        XMLInputFactory xmlInputFactory = XMLInputFactory.newInstance();
+        XMLStreamReader stdXmlStreamReader = xmlInputFactory.createXMLStreamReader(xmlInput);
+        InboundSecurityContextImpl securityContext = new InboundSecurityContextImpl();
+        InputProcessorChainImpl inputProcessorChain = new InputProcessorChainImpl(securityContext);
+        inputProcessorChain.addProcessor(new EventReaderProcessor(stdXmlStreamReader));
+        XMLSecurityProperties securityProperties = new XMLSecurityProperties();
+        securityProperties.setSkipDocumentEvents(false);
+        XMLSecurityStreamReader xmlSecurityStreamReader = new XMLSecurityStreamReader(inputProcessorChain, securityProperties);
+        advanceToFirstEvent(xmlSecurityStreamReader);
+        assertThat(xmlSecurityStreamReader.getEventType(), is(XMLStreamConstants.START_DOCUMENT));
+        assertThat(xmlSecurityStreamReader.getVersion(), is(equalTo("1.1")));
+        assertThat(xmlSecurityStreamReader.getCharacterEncodingScheme(), is(nullValue()));
+        assertThat(xmlSecurityStreamReader.isStandalone(), is(false));
+        assertThat(xmlSecurityStreamReader.standaloneSet(), is(false));
+    }
+
+    @Test
     public void testDocumentDeclarationWhenSkipDocumentEvents() throws Exception {
         String xml = "<?xml version='1.1' encoding='ISO-8859-1' standalone='yes'?>\n"
                 + "<Document/>";
@@ -284,7 +306,7 @@
                     assertEquals(stdXmlStreamReader.getTextLength(), xmlSecurityStreamReader.getTextLength());
                     break;
                 case XMLStreamConstants.START_DOCUMENT:
-                    assertEquals(StandardCharsets.UTF_8.name(), xmlSecurityStreamReader.getCharacterEncodingScheme());
+                    assertEquals(stdXmlStreamReader.getCharacterEncodingScheme(), xmlSecurityStreamReader.getCharacterEncodingScheme());
                     assertEquals(stdXmlStreamReader.getEncoding(), xmlSecurityStreamReader.getEncoding());
                     assertEquals(stdXmlStreamReader.getVersion(), xmlSecurityStreamReader.getVersion());
                     break;