WSS-659 SecurityContextToken validator fixing QName (#2)

* WSS-659 SecurityContextToken validator fixing QName

* WSS-659 Fixed handlers identifier access, Added tests

* WSS-659 Removed redundant method, sharpen test
diff --git a/ws-security-stax/src/main/java/org/apache/wss4j/stax/impl/processor/input/SecurityContextTokenInputHandler.java b/ws-security-stax/src/main/java/org/apache/wss4j/stax/impl/processor/input/SecurityContextTokenInputHandler.java
index 1998bc0..b7bd2d0 100644
--- a/ws-security-stax/src/main/java/org/apache/wss4j/stax/impl/processor/input/SecurityContextTokenInputHandler.java
+++ b/ws-security-stax/src/main/java/org/apache/wss4j/stax/impl/processor/input/SecurityContextTokenInputHandler.java
@@ -47,18 +47,17 @@
     public void handle(InputProcessorChain inputProcessorChain, final XMLSecurityProperties securityProperties,
                        Deque<XMLSecEvent> eventQueue, Integer index) throws XMLSecurityException {
 
-        @SuppressWarnings("unchecked")
         JAXBElement<AbstractSecurityContextTokenType> securityContextTokenTypeJAXBElement =
-                (JAXBElement<AbstractSecurityContextTokenType>) parseStructure(eventQueue, index, securityProperties);
+                parseStructure(eventQueue, index, securityProperties);
         final AbstractSecurityContextTokenType securityContextTokenType = securityContextTokenTypeJAXBElement.getValue();
         if (securityContextTokenType.getId() == null) {
             securityContextTokenType.setId(IDGenerator.generateID(null));
         }
 
-        final QName elementName = new QName(securityContextTokenTypeJAXBElement.getName().getNamespaceURI(),
+        final QName identifierElementName = new QName(securityContextTokenTypeJAXBElement.getName().getNamespaceURI(),
                 WSSConstants.TAG_WSC0502_IDENTIFIER.getLocalPart());
-        final String identifier = (String) XMLSecurityUtils.getQNameType(securityContextTokenType.getAny(),
-                elementName);
+        final String identifier = XMLSecurityUtils.getQNameType(securityContextTokenType.getAny(),
+                identifierElementName);
 
         final WSInboundSecurityContext wsInboundSecurityContext = 
             (WSInboundSecurityContext) inputProcessorChain.getSecurityContext();
@@ -69,6 +68,7 @@
         final TokenContext tokenContext = 
             new TokenContext(wssSecurityProperties, wsInboundSecurityContext, xmlSecEvents, elementPath);
 
+        final QName elementName = securityContextTokenTypeJAXBElement.getName();
         SecurityContextTokenValidator securityContextTokenValidator = wssSecurityProperties.getValidator(elementName);
         if (securityContextTokenValidator == null) {
             securityContextTokenValidator = new SecurityContextTokenValidatorImpl();
@@ -108,9 +108,17 @@
         wsInboundSecurityContext.registerSecurityTokenProvider(identifier, securityTokenProviderDirectReference);
 
         //fire a tokenSecurityEvent
+        SecurityContextTokenSecurityEvent securityEvent = createTokenSecurityEvent(securityContextTokenType, securityTokenProvider);
+        wsInboundSecurityContext.registerSecurityEvent(securityEvent);
+    }
+
+    private SecurityContextTokenSecurityEvent createTokenSecurityEvent(AbstractSecurityContextTokenType securityContextTokenType,
+                                                                       SecurityTokenProvider<InboundSecurityToken> securityTokenProvider)
+            throws XMLSecurityException {
         SecurityContextTokenSecurityEvent securityContextTokenSecurityEvent = new SecurityContextTokenSecurityEvent();
         securityContextTokenSecurityEvent.setSecurityToken(securityTokenProvider.getSecurityToken());
         securityContextTokenSecurityEvent.setCorrelationID(securityContextTokenType.getId());
-        wsInboundSecurityContext.registerSecurityEvent(securityContextTokenSecurityEvent);
+        return securityContextTokenSecurityEvent;
     }
+
 }
diff --git a/ws-security-stax/src/test/java/org/apache/wss4j/stax/test/SecurityContextTokenTest.java b/ws-security-stax/src/test/java/org/apache/wss4j/stax/test/SecurityContextTokenTest.java
index c429422..5751f75 100644
--- a/ws-security-stax/src/test/java/org/apache/wss4j/stax/test/SecurityContextTokenTest.java
+++ b/ws-security-stax/src/test/java/org/apache/wss4j/stax/test/SecurityContextTokenTest.java
@@ -34,10 +34,12 @@
 import javax.xml.transform.dom.DOMSource;
 import javax.xml.transform.stream.StreamResult;
 
+import org.apache.wss4j.binding.wssc.AbstractSecurityContextTokenType;
 import org.apache.wss4j.common.bsp.BSPRule;
 import org.apache.wss4j.common.crypto.Crypto;
 import org.apache.wss4j.common.crypto.CryptoFactory;
 import org.apache.wss4j.common.derivedKey.ConversationConstants;
+import org.apache.wss4j.common.ext.WSSecurityException;
 import org.apache.wss4j.dom.WSConstants;
 import org.apache.wss4j.dom.engine.WSSConfig;
 import org.apache.wss4j.dom.handler.WSHandlerConstants;
@@ -60,8 +62,12 @@
 import org.apache.wss4j.stax.test.utils.SecretKeyCallbackHandler;
 import org.apache.wss4j.stax.test.utils.StAX2DOM;
 import org.apache.wss4j.stax.test.utils.XmlReaderToWriter;
+import org.apache.wss4j.stax.validate.SecurityContextTokenValidator;
+import org.apache.wss4j.stax.validate.SecurityContextTokenValidatorImpl;
+import org.apache.wss4j.stax.validate.TokenContext;
 import org.apache.xml.security.stax.securityEvent.SecurityEvent;
 import org.apache.xml.security.stax.securityEvent.SignatureValueSecurityEvent;
+import org.apache.xml.security.stax.securityToken.InboundSecurityToken;
 import org.junit.Assert;
 import org.junit.BeforeClass;
 import org.junit.Test;
@@ -752,9 +758,10 @@
             };
             final TestSecurityEventListener securityEventListener = new TestSecurityEventListener(expectedSecurityEvents);
 
-            XMLStreamReader xmlStreamReader = wsSecIn.processInMessage(xmlInputFactory.createXMLStreamReader(new ByteArrayInputStream(baos.toByteArray())), null, securityEventListener);
+            XMLStreamReader xmlStreamReader = xmlInputFactory.createXMLStreamReader(new ByteArrayInputStream(baos.toByteArray()));
+            XMLStreamReader secureXmlStreamReader = wsSecIn.processInMessage(xmlStreamReader, null, securityEventListener);
 
-            Document document = StAX2DOM.readDoc(documentBuilderFactory.newDocumentBuilder(), xmlStreamReader);
+            Document document = StAX2DOM.readDoc(documentBuilderFactory.newDocumentBuilder(), secureXmlStreamReader);
 
             NodeList nodeList = document.getElementsByTagNameNS(WSSConstants.TAG_xenc_EncryptedData.getNamespaceURI(), WSSConstants.TAG_xenc_EncryptedData.getLocalPart());
             Assert.assertEquals(nodeList.getLength(), 0);
@@ -1077,9 +1084,10 @@
             };
             final TestSecurityEventListener securityEventListener = new TestSecurityEventListener(expectedSecurityEvents);
 
-            XMLStreamReader xmlStreamReader = wsSecIn.processInMessage(xmlInputFactory.createXMLStreamReader(new ByteArrayInputStream(baos.toByteArray())), null, securityEventListener);
+            XMLStreamReader xmlStreamReader = xmlInputFactory.createXMLStreamReader(new ByteArrayInputStream(baos.toByteArray()));
+            XMLStreamReader secureXmlStreamReader = wsSecIn.processInMessage(xmlStreamReader, null, securityEventListener);
 
-            StAX2DOM.readDoc(documentBuilderFactory.newDocumentBuilder(), xmlStreamReader);
+            StAX2DOM.readDoc(documentBuilderFactory.newDocumentBuilder(), secureXmlStreamReader);
 
             securityEventListener.compare();
 
@@ -1114,4 +1122,70 @@
             );
         }
     }
+
+    @Test
+    public void testSCTCustomValidator() throws Exception {
+        byte[] tempSecret = WSSecurityUtil.generateNonce(16);
+        ByteArrayOutputStream baos = new ByteArrayOutputStream();
+        {
+            Document doc = SOAPUtil.toSOAPPart(SOAPUtil.SAMPLE_SOAP_MSG);
+            WSSecHeader secHeader = new WSSecHeader(doc);
+            secHeader.insertSecurityHeader();
+
+            WSSecSecurityContextToken sctBuilder = new WSSecSecurityContextToken(secHeader, null);
+            sctBuilder.setWscVersion(version);
+            Crypto crypto = CryptoFactory.getInstance("transmitter-crypto.properties");
+            sctBuilder.prepare(crypto);
+
+            // Store the secret
+            SecretKeyCallbackHandler callbackHandler = new SecretKeyCallbackHandler();
+            callbackHandler.addSecretKey(sctBuilder.getIdentifier(), tempSecret);
+
+            String tokenId = sctBuilder.getSctId();
+
+            WSSecSignature builder = new WSSecSignature(secHeader);
+            builder.setSecretKey(tempSecret);
+            builder.setKeyIdentifierType(WSConstants.CUSTOM_SYMM_SIGNING);
+            builder.setCustomTokenValueType(WSConstants.WSC_SCT);
+            builder.setCustomTokenId(tokenId);
+            builder.setSignatureAlgorithm(SignatureMethod.HMAC_SHA1);
+            builder.build(crypto);
+
+            sctBuilder.prependSCTElementToHeader();
+
+            javax.xml.transform.Transformer transformer = TRANSFORMER_FACTORY.newTransformer();
+            transformer.transform(new DOMSource(doc), new StreamResult(baos));
+        }
+
+        {
+            WSSSecurityProperties securityProperties = new WSSSecurityProperties();
+            securityProperties.loadSignatureVerificationKeystore(this.getClass().getClassLoader().getResource("receiver.jks"), "default".toCharArray());
+            CallbackHandlerImpl callbackHandler = new CallbackHandlerImpl(tempSecret);
+            securityProperties.setCallbackHandler(callbackHandler);
+
+            final boolean[] validatorCalled = {false};
+            SecurityContextTokenValidator validator = new SecurityContextTokenValidatorImpl() {
+                @Override
+                public InboundSecurityToken validate(AbstractSecurityContextTokenType securityContextTokenType, String identifier, TokenContext tokenContext) throws WSSecurityException {
+                    validatorCalled[0] = true;
+                    return super.validate(securityContextTokenType, identifier, tokenContext);
+                }
+            };
+
+            if (version == ConversationConstants.VERSION_05_02) {
+                securityProperties.addValidator(WSSConstants.TAG_WSC0502_SCT, validator);
+            } else {
+                securityProperties.addValidator(WSSConstants.TAG_WSC0512_SCT, validator);
+            }
+
+            InboundWSSec wsSecIn = WSSec.getInboundWSSec(securityProperties);
+
+            XMLStreamReader xmlStreamReader = xmlInputFactory.createXMLStreamReader(new ByteArrayInputStream(baos.toByteArray()));
+            XMLStreamReader secureXmlStreamReader = wsSecIn.processInMessage(xmlStreamReader);
+
+            StAX2DOM.readDoc(documentBuilderFactory.newDocumentBuilder(), secureXmlStreamReader);
+
+            Assert.assertTrue("Validator should be called when configured", validatorCalled[0]);
+        }
+    }
 }