| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.thrift.transport.sasl; |
| |
| import java.nio.channels.SelectionKey; |
| import java.nio.charset.StandardCharsets; |
| |
| import javax.security.sasl.SaslServer; |
| |
| import org.apache.thrift.TByteArrayOutputStream; |
| import org.apache.thrift.TProcessor; |
| import org.apache.thrift.protocol.TProtocol; |
| import org.apache.thrift.protocol.TProtocolFactory; |
| import org.apache.thrift.server.ServerContext; |
| import org.apache.thrift.server.TServerEventHandler; |
| import org.apache.thrift.transport.TMemoryTransport; |
| import org.apache.thrift.transport.TNonblockingTransport; |
| import org.apache.thrift.transport.TTransportException; |
| import org.apache.thrift.transport.sasl.TSaslNegotiationException.ErrorType; |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import static org.apache.thrift.transport.sasl.NegotiationStatus.COMPLETE; |
| import static org.apache.thrift.transport.sasl.NegotiationStatus.OK; |
| |
| /** |
| * State machine managing one sasl connection in a nonblocking way. |
| */ |
| public class NonblockingSaslHandler { |
| private static final Logger LOGGER = LoggerFactory.getLogger(NonblockingSaslHandler.class); |
| |
| private static final int INTEREST_NONE = 0; |
| private static final int INTEREST_READ = SelectionKey.OP_READ; |
| private static final int INTEREST_WRITE = SelectionKey.OP_WRITE; |
| |
| // Tracking the current running phase |
| private Phase currentPhase = Phase.INITIIALIIZING; |
| // Tracking the next phase on the next invocation of the state machine. |
| // It should be the same as current phase if current phase is not yet finished. |
| // Otherwise, if it is different from current phase, the statemachine is in a transition state: |
| // current phase is done, and next phase is not yet started. |
| private Phase nextPhase = currentPhase; |
| |
| // Underlying nonblocking transport |
| private SelectionKey selectionKey; |
| private TNonblockingTransport underlyingTransport; |
| |
| // APIs for intercepting event / customizing behaviors: |
| // Factories (decorating the base implementations) & EventHandler (intercepting) |
| private TSaslServerFactory saslServerFactory; |
| private TSaslProcessorFactory processorFactory; |
| private TProtocolFactory inputProtocolFactory; |
| private TProtocolFactory outputProtocolFactory; |
| private TServerEventHandler eventHandler; |
| private ServerContext serverContext; |
| // It turns out the event handler implementation in hive sometimes creates a null ServerContext. |
| // In order to know whether TServerEventHandler#createContext is called we use such a flag. |
| private boolean serverContextCreated = false; |
| |
| // Wrapper around sasl server |
| private ServerSaslPeer saslPeer; |
| |
| // Sasl negotiation io |
| private SaslNegotiationFrameReader saslResponse; |
| private SaslNegotiationFrameWriter saslChallenge; |
| // IO for request from and response to the socket |
| private DataFrameReader requestReader; |
| private DataFrameWriter responseWriter; |
| // If sasl is negotiated for integrity/confidentiality protection |
| private boolean dataProtected; |
| |
| public NonblockingSaslHandler(SelectionKey selectionKey, TNonblockingTransport underlyingTransport, |
| TSaslServerFactory saslServerFactory, TSaslProcessorFactory processorFactory, |
| TProtocolFactory inputProtocolFactory, TProtocolFactory outputProtocolFactory, |
| TServerEventHandler eventHandler) { |
| this.selectionKey = selectionKey; |
| this.underlyingTransport = underlyingTransport; |
| this.saslServerFactory = saslServerFactory; |
| this.processorFactory = processorFactory; |
| this.inputProtocolFactory = inputProtocolFactory; |
| this.outputProtocolFactory = outputProtocolFactory; |
| this.eventHandler = eventHandler; |
| |
| saslResponse = new SaslNegotiationFrameReader(); |
| saslChallenge = new SaslNegotiationFrameWriter(); |
| requestReader = new DataFrameReader(); |
| responseWriter = new DataFrameWriter(); |
| } |
| |
| /** |
| * Get current phase of the state machine. |
| * |
| * @return current phase. |
| */ |
| public Phase getCurrentPhase() { |
| return currentPhase; |
| } |
| |
| /** |
| * Get next phase of the state machine. |
| * It is different from current phase iff current phase is done (and next phase not yet started). |
| * |
| * @return next phase. |
| */ |
| public Phase getNextPhase() { |
| return nextPhase; |
| } |
| |
| /** |
| * |
| * @return underlying nonblocking socket |
| */ |
| public TNonblockingTransport getUnderlyingTransport() { |
| return underlyingTransport; |
| } |
| |
| /** |
| * |
| * @return SaslServer instance |
| */ |
| public SaslServer getSaslServer() { |
| return saslPeer.getSaslServer(); |
| } |
| |
| /** |
| * |
| * @return true if current phase is done. |
| */ |
| public boolean isCurrentPhaseDone() { |
| return currentPhase != nextPhase; |
| } |
| |
| /** |
| * Run state machine. |
| * |
| * @throws IllegalStateException if current state is already done. |
| */ |
| public void runCurrentPhase() { |
| currentPhase.runStateMachine(this); |
| } |
| |
| /** |
| * When current phase is intrested in read selection, calling this will run the current phase and |
| * its following phases if the following ones are interested to read, until there is nothing |
| * available in the underlying transport. |
| * |
| * @throws IllegalStateException if is called in an irrelevant phase. |
| */ |
| public void handleRead() { |
| handleOps(INTEREST_READ); |
| } |
| |
| /** |
| * Similiar to handleRead. But it is for write ops. |
| * |
| * @throws IllegalStateException if it is called in an irrelevant phase. |
| */ |
| public void handleWrite() { |
| handleOps(INTEREST_WRITE); |
| } |
| |
| private void handleOps(int interestOps) { |
| if (currentPhase.selectionInterest != interestOps) { |
| throw new IllegalStateException("Current phase " + currentPhase + " but got interest " + |
| interestOps); |
| } |
| runCurrentPhase(); |
| if (isCurrentPhaseDone() && nextPhase.selectionInterest == interestOps) { |
| stepToNextPhase(); |
| handleOps(interestOps); |
| } |
| } |
| |
| /** |
| * When current phase is finished, it's expected to call this method first before running the |
| * state machine again. |
| * By calling this, "next phase" is marked as started (and not done), thus is ready to run. |
| * |
| * @throws IllegalArgumentException if current phase is not yet done. |
| */ |
| public void stepToNextPhase() { |
| if (!isCurrentPhaseDone()) { |
| throw new IllegalArgumentException("Not yet done with current phase: " + currentPhase); |
| } |
| LOGGER.debug("Switch phase {} to {}", currentPhase, nextPhase); |
| switch (nextPhase) { |
| case INITIIALIIZING: |
| throw new IllegalStateException("INITIALIZING cannot be the next phase of " + currentPhase); |
| default: |
| } |
| // If next phase's interest is not the same as current, nor the same as the selection key, |
| // we need to change interest on the selector. |
| if (!(nextPhase.selectionInterest == currentPhase.selectionInterest || |
| nextPhase.selectionInterest == selectionKey.interestOps())) { |
| changeSelectionInterest(nextPhase.selectionInterest); |
| } |
| currentPhase = nextPhase; |
| } |
| |
| private void changeSelectionInterest(int selectionInterest) { |
| selectionKey.interestOps(selectionInterest); |
| } |
| |
| // sasl negotiaion failure handling |
| private void failSaslNegotiation(TSaslNegotiationException e) { |
| LOGGER.error("Sasl negotiation failed", e); |
| String errorMsg = e.getDetails(); |
| saslChallenge.withHeaderAndPayload(new byte[]{e.getErrorType().code.getValue()}, |
| errorMsg.getBytes(StandardCharsets.UTF_8)); |
| nextPhase = Phase.WRITING_FAILURE_MESSAGE; |
| } |
| |
| private void fail(Exception e) { |
| LOGGER.error("Failed io in " + currentPhase, e); |
| nextPhase = Phase.CLOSING; |
| } |
| |
| private void failIO(TTransportException e) { |
| StringBuilder errorMsg = new StringBuilder("IO failure ") |
| .append(e.getType()) |
| .append(" in ") |
| .append(currentPhase); |
| if (e.getMessage() != null) { |
| errorMsg.append(": ").append(e.getMessage()); |
| } |
| LOGGER.error(errorMsg.toString(), e); |
| nextPhase = Phase.CLOSING; |
| } |
| |
| // Read handlings |
| |
| private void handleInitializing() { |
| try { |
| saslResponse.read(underlyingTransport); |
| if (saslResponse.isComplete()) { |
| SaslNegotiationHeaderReader startHeader = saslResponse.getHeader(); |
| if (startHeader.getStatus() != NegotiationStatus.START) { |
| throw new TInvalidSaslFrameException("Expecting START status but got " + startHeader.getStatus()); |
| } |
| String mechanism = new String(saslResponse.getPayload(), StandardCharsets.UTF_8); |
| saslPeer = saslServerFactory.getSaslPeer(mechanism); |
| saslResponse.clear(); |
| nextPhase = Phase.READING_SASL_RESPONSE; |
| } |
| } catch (TSaslNegotiationException e) { |
| failSaslNegotiation(e); |
| } catch (TTransportException e) { |
| failIO(e); |
| } |
| } |
| |
| private void handleReadingSaslResponse() { |
| try { |
| saslResponse.read(underlyingTransport); |
| if (saslResponse.isComplete()) { |
| nextPhase = Phase.EVALUATING_SASL_RESPONSE; |
| } |
| } catch (TSaslNegotiationException e) { |
| failSaslNegotiation(e); |
| } catch (TTransportException e) { |
| failIO(e); |
| } |
| } |
| |
| private void handleReadingRequest() { |
| try { |
| requestReader.read(underlyingTransport); |
| if (requestReader.isComplete()) { |
| nextPhase = Phase.PROCESSING; |
| } |
| } catch (TTransportException e) { |
| failIO(e); |
| } |
| } |
| |
| // Computation executions |
| |
| private void executeEvaluatingSaslResponse() { |
| if (!(saslResponse.getHeader().getStatus() == OK || saslResponse.getHeader().getStatus() == COMPLETE)) { |
| String error = "Expect status OK or COMPLETE, but got " + saslResponse.getHeader().getStatus(); |
| failSaslNegotiation(new TSaslNegotiationException(ErrorType.PROTOCOL_ERROR, error)); |
| return; |
| } |
| try { |
| byte[] response = saslResponse.getPayload(); |
| saslResponse.clear(); |
| byte[] newChallenge = saslPeer.evaluate(response); |
| if (saslPeer.isAuthenticated()) { |
| dataProtected = saslPeer.isDataProtected(); |
| saslChallenge.withHeaderAndPayload(new byte[]{COMPLETE.getValue()}, newChallenge); |
| nextPhase = Phase.WRITING_SUCCESS_MESSAGE; |
| } else { |
| saslChallenge.withHeaderAndPayload(new byte[]{OK.getValue()}, newChallenge); |
| nextPhase = Phase.WRITING_SASL_CHALLENGE; |
| } |
| } catch (TSaslNegotiationException e) { |
| failSaslNegotiation(e); |
| } |
| } |
| |
| private void executeProcessing() { |
| try { |
| byte[] inputPayload = requestReader.getPayload(); |
| requestReader.clear(); |
| byte[] rawInput = dataProtected ? saslPeer.unwrap(inputPayload) : inputPayload; |
| TMemoryTransport memoryTransport = new TMemoryTransport(rawInput); |
| TProtocol requestProtocol = inputProtocolFactory.getProtocol(memoryTransport); |
| TProtocol responseProtocol = outputProtocolFactory.getProtocol(memoryTransport); |
| |
| if (eventHandler != null) { |
| if (!serverContextCreated) { |
| serverContext = eventHandler.createContext(requestProtocol, responseProtocol); |
| serverContextCreated = true; |
| } |
| eventHandler.processContext(serverContext, memoryTransport, memoryTransport); |
| } |
| |
| TProcessor processor = processorFactory.getProcessor(this); |
| processor.process(requestProtocol, responseProtocol); |
| TByteArrayOutputStream rawOutput = memoryTransport.getOutput(); |
| if (rawOutput.len() == 0) { |
| // This is a oneway request, no response to send back. Waiting for next incoming request. |
| nextPhase = Phase.READING_REQUEST; |
| return; |
| } |
| if (dataProtected) { |
| byte[] outputPayload = saslPeer.wrap(rawOutput.get(), 0, rawOutput.len()); |
| responseWriter.withOnlyPayload(outputPayload); |
| } else { |
| responseWriter.withOnlyPayload(rawOutput.get(), 0 ,rawOutput.len()); |
| } |
| nextPhase = Phase.WRITING_RESPONSE; |
| } catch (TTransportException e) { |
| failIO(e); |
| } catch (Exception e) { |
| fail(e); |
| } |
| } |
| |
| // Write handlings |
| |
| private void handleWritingSaslChallenge() { |
| try { |
| saslChallenge.write(underlyingTransport); |
| if (saslChallenge.isComplete()) { |
| saslChallenge.clear(); |
| nextPhase = Phase.READING_SASL_RESPONSE; |
| } |
| } catch (TTransportException e) { |
| fail(e); |
| } |
| } |
| |
| private void handleWritingSuccessMessage() { |
| try { |
| saslChallenge.write(underlyingTransport); |
| if (saslChallenge.isComplete()) { |
| LOGGER.debug("Authentication is done."); |
| saslChallenge = null; |
| saslResponse = null; |
| nextPhase = Phase.READING_REQUEST; |
| } |
| } catch (TTransportException e) { |
| fail(e); |
| } |
| } |
| |
| private void handleWritingFailureMessage() { |
| try { |
| saslChallenge.write(underlyingTransport); |
| if (saslChallenge.isComplete()) { |
| nextPhase = Phase.CLOSING; |
| } |
| } catch (TTransportException e) { |
| fail(e); |
| } |
| } |
| |
| private void handleWritingResponse() { |
| try { |
| responseWriter.write(underlyingTransport); |
| if (responseWriter.isComplete()) { |
| responseWriter.clear(); |
| nextPhase = Phase.READING_REQUEST; |
| } |
| } catch (TTransportException e) { |
| fail(e); |
| } |
| } |
| |
| /** |
| * Release all the resources managed by this state machine (connection, selection and sasl server). |
| * To avoid being blocked, this should be invoked in the network thread that manages the selector. |
| */ |
| public void close() { |
| underlyingTransport.close(); |
| selectionKey.cancel(); |
| if (saslPeer != null) { |
| saslPeer.dispose(); |
| } |
| if (serverContextCreated) { |
| eventHandler.deleteContext(serverContext, |
| inputProtocolFactory.getProtocol(underlyingTransport), |
| outputProtocolFactory.getProtocol(underlyingTransport)); |
| } |
| nextPhase = Phase.CLOSED; |
| currentPhase = Phase.CLOSED; |
| LOGGER.trace("Connection closed: {}", underlyingTransport); |
| } |
| |
| public enum Phase { |
| INITIIALIIZING(INTEREST_READ) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleInitializing(); |
| } |
| }, |
| READING_SASL_RESPONSE(INTEREST_READ) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleReadingSaslResponse(); |
| } |
| }, |
| EVALUATING_SASL_RESPONSE(INTEREST_NONE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.executeEvaluatingSaslResponse(); |
| } |
| }, |
| WRITING_SASL_CHALLENGE(INTEREST_WRITE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleWritingSaslChallenge(); |
| } |
| }, |
| WRITING_SUCCESS_MESSAGE(INTEREST_WRITE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleWritingSuccessMessage(); |
| } |
| }, |
| WRITING_FAILURE_MESSAGE(INTEREST_WRITE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleWritingFailureMessage(); |
| } |
| }, |
| READING_REQUEST(INTEREST_READ) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleReadingRequest(); |
| } |
| }, |
| PROCESSING(INTEREST_NONE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.executeProcessing(); |
| } |
| }, |
| WRITING_RESPONSE(INTEREST_WRITE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.handleWritingResponse(); |
| } |
| }, |
| CLOSING(INTEREST_NONE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| statemachine.close(); |
| } |
| }, |
| CLOSED(INTEREST_NONE) { |
| @Override |
| void unsafeRun(NonblockingSaslHandler statemachine) { |
| // Do nothing. |
| } |
| } |
| ; |
| |
| // The interest on the selection key during the phase |
| private int selectionInterest; |
| |
| Phase(int selectionInterest) { |
| this.selectionInterest = selectionInterest; |
| } |
| |
| /** |
| * Provide the execution to run for the state machine in current phase. The execution should |
| * return the next phase after running on the state machine. |
| * |
| * @param statemachine The state machine to run. |
| * @throws IllegalArgumentException if the state machine's current phase is different. |
| * @throws IllegalStateException if the state machine' current phase is already done. |
| */ |
| void runStateMachine(NonblockingSaslHandler statemachine) { |
| if (statemachine.currentPhase != this) { |
| throw new IllegalArgumentException("State machine is " + statemachine.currentPhase + |
| " but is expected to be " + this); |
| } |
| if (statemachine.isCurrentPhaseDone()) { |
| throw new IllegalStateException("State machine should step into " + statemachine.nextPhase); |
| } |
| unsafeRun(statemachine); |
| } |
| |
| // Run the state machine without checkiing its own phase |
| // It should not be called direcly by users. |
| abstract void unsafeRun(NonblockingSaslHandler statemachine); |
| } |
| } |