| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * https://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.avro.ipc; |
| |
| import java.io.IOException; |
| import java.io.EOFException; |
| import java.net.SocketAddress; |
| import java.nio.channels.SocketChannel; |
| import java.nio.charset.StandardCharsets; |
| import java.nio.Buffer; |
| import java.nio.ByteBuffer; |
| import java.util.ArrayList; |
| import java.util.List; |
| |
| import javax.security.sasl.Sasl; |
| import javax.security.sasl.SaslException; |
| import javax.security.sasl.SaslClient; |
| import javax.security.sasl.SaslServer; |
| |
| import org.apache.avro.Protocol; |
| import org.apache.avro.util.ByteBufferOutputStream; |
| |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| /** |
| * A {@link Transceiver} that uses {@link javax.security.sasl} for |
| * authentication and encryption. |
| */ |
| public class SaslSocketTransceiver extends Transceiver { |
| private static final Logger LOG = LoggerFactory.getLogger(SaslSocketTransceiver.class); |
| |
| private static final ByteBuffer EMPTY = ByteBuffer.allocate(0); |
| |
| private static enum Status { |
| START, CONTINUE, FAIL, COMPLETE |
| } |
| |
| private SaslParticipant sasl; |
| private SocketChannel channel; |
| private boolean dataIsWrapped; |
| private boolean saslResponsePiggybacked; |
| |
| private Protocol remote; |
| |
| private ByteBuffer readHeader = ByteBuffer.allocate(4); |
| private ByteBuffer writeHeader = ByteBuffer.allocate(4); |
| private ByteBuffer zeroHeader = ByteBuffer.allocate(4).putInt(0); |
| |
| /** |
| * Create using SASL's anonymous |
| * (<a href="https://www.ietf.org/rfc/rfc2245.txt">RFC 2245) mechanism. |
| */ |
| public SaslSocketTransceiver(SocketAddress address) throws IOException { |
| this(address, new AnonymousClient()); |
| } |
| |
| /** Create using the specified {@link SaslClient}. */ |
| public SaslSocketTransceiver(SocketAddress address, SaslClient saslClient) throws IOException { |
| this.sasl = new SaslParticipant(saslClient); |
| this.channel = SocketChannel.open(address); |
| this.channel.socket().setTcpNoDelay(true); |
| LOG.debug("open to {}", getRemoteName()); |
| open(true); |
| } |
| |
| /** Create using the specified {@link SaslServer}. */ |
| public SaslSocketTransceiver(SocketChannel channel, SaslServer saslServer) throws IOException { |
| this.sasl = new SaslParticipant(saslServer); |
| this.channel = channel; |
| LOG.debug("open from {}", getRemoteName()); |
| open(false); |
| } |
| |
| @Override |
| public boolean isConnected() { |
| return remote != null; |
| } |
| |
| @Override |
| public void setRemote(Protocol remote) { |
| this.remote = remote; |
| } |
| |
| @Override |
| public Protocol getRemote() { |
| return remote; |
| } |
| |
| @Override |
| public String getRemoteName() { |
| return channel.socket().getRemoteSocketAddress().toString(); |
| } |
| |
| @Override |
| public synchronized List<ByteBuffer> transceive(List<ByteBuffer> request) throws IOException { |
| if (saslResponsePiggybacked) { // still need to read response |
| saslResponsePiggybacked = false; |
| Status status = readStatus(); |
| ByteBuffer frame = readFrame(); |
| switch (status) { |
| case COMPLETE: |
| break; |
| case FAIL: |
| throw new SaslException("Fail: " + toString(frame)); |
| default: |
| throw new IOException("Unexpected SASL status: " + status); |
| } |
| } |
| return super.transceive(request); |
| } |
| |
| private void open(boolean isClient) throws IOException { |
| LOG.debug("beginning SASL negotiation"); |
| |
| if (isClient) { |
| ByteBuffer response = EMPTY; |
| if (sasl.client.hasInitialResponse()) |
| response = ByteBuffer.wrap(sasl.evaluate(response.array())); |
| write(Status.START, sasl.getMechanismName(), response); |
| if (sasl.isComplete()) |
| saslResponsePiggybacked = true; |
| } |
| |
| while (!sasl.isComplete()) { |
| Status status = readStatus(); |
| ByteBuffer frame = readFrame(); |
| switch (status) { |
| case START: |
| String mechanism = toString(frame); |
| frame = readFrame(); |
| if (!mechanism.equalsIgnoreCase(sasl.getMechanismName())) { |
| write(Status.FAIL, "Wrong mechanism: " + mechanism); |
| throw new SaslException("Wrong mechanism: " + mechanism); |
| } |
| case CONTINUE: |
| byte[] response; |
| try { |
| response = sasl.evaluate(frame.array()); |
| status = sasl.isComplete() ? Status.COMPLETE : Status.CONTINUE; |
| } catch (SaslException e) { |
| response = e.toString().getBytes(StandardCharsets.UTF_8); |
| status = Status.FAIL; |
| } |
| write(status, response != null ? ByteBuffer.wrap(response) : EMPTY); |
| break; |
| case COMPLETE: |
| sasl.evaluate(frame.array()); |
| if (!sasl.isComplete()) |
| throw new SaslException("Expected completion!"); |
| break; |
| case FAIL: |
| throw new SaslException("Fail: " + toString(frame)); |
| default: |
| throw new IOException("Unexpected SASL status: " + status); |
| } |
| } |
| LOG.debug("SASL opened"); |
| |
| String qop = (String) sasl.getNegotiatedProperty(Sasl.QOP); |
| LOG.debug("QOP = {}", qop); |
| dataIsWrapped = (qop != null && !qop.equalsIgnoreCase("auth")); |
| } |
| |
| private String toString(ByteBuffer buffer) { |
| return new String(buffer.array(), StandardCharsets.UTF_8); |
| } |
| |
| @Override |
| public synchronized List<ByteBuffer> readBuffers() throws IOException { |
| List<ByteBuffer> buffers = new ArrayList<>(); |
| while (true) { |
| ByteBuffer buffer = readFrameAndUnwrap(); |
| if (((Buffer) buffer).remaining() == 0) |
| return buffers; |
| buffers.add(buffer); |
| } |
| } |
| |
| private Status readStatus() throws IOException { |
| ByteBuffer buffer = ByteBuffer.allocate(1); |
| read(buffer); |
| int status = buffer.get(); |
| if (status > Status.values().length) |
| throw new IOException("Unexpected SASL status byte: " + status); |
| return Status.values()[status]; |
| } |
| |
| private ByteBuffer readFrameAndUnwrap() throws IOException { |
| ByteBuffer frame = readFrame(); |
| if (!dataIsWrapped) |
| return frame; |
| ByteBuffer unwrapped = ByteBuffer.wrap(sasl.unwrap(frame.array())); |
| LOG.debug("unwrapped data of length: {}", unwrapped.remaining()); |
| return unwrapped; |
| } |
| |
| private ByteBuffer readFrame() throws IOException { |
| read(readHeader); |
| ByteBuffer buffer = ByteBuffer.allocate(readHeader.getInt()); |
| LOG.debug("about to read: {} bytes", buffer.capacity()); |
| read(buffer); |
| return buffer; |
| } |
| |
| private void read(ByteBuffer buffer) throws IOException { |
| ((Buffer) buffer).clear(); |
| while (buffer.hasRemaining()) |
| if (channel.read(buffer) == -1) |
| throw new EOFException(); |
| ((Buffer) buffer).flip(); |
| } |
| |
| @Override |
| public synchronized void writeBuffers(List<ByteBuffer> buffers) throws IOException { |
| if (buffers == null) |
| return; // no data to write |
| List<ByteBuffer> writes = new ArrayList<>(buffers.size() * 2 + 1); |
| int currentLength = 0; |
| ByteBuffer currentHeader = writeHeader; |
| for (ByteBuffer buffer : buffers) { // gather writes |
| if (buffer.remaining() == 0) |
| continue; // ignore empties |
| if (dataIsWrapped) { |
| LOG.debug("wrapping data of length: {}", buffer.remaining()); |
| buffer = ByteBuffer.wrap(sasl.wrap(buffer.array(), buffer.position(), buffer.remaining())); |
| } |
| int length = buffer.remaining(); |
| if (!dataIsWrapped // can append buffers on wire |
| && (currentLength + length) <= ByteBufferOutputStream.BUFFER_SIZE) { |
| if (currentLength == 0) |
| writes.add(currentHeader); |
| currentLength += length; |
| ((Buffer) currentHeader).clear(); |
| currentHeader.putInt(currentLength); |
| LOG.debug("adding {} to write, total now {}", length, currentLength); |
| } else { |
| currentLength = length; |
| currentHeader = ByteBuffer.allocate(4).putInt(length); |
| writes.add(currentHeader); |
| LOG.debug("planning write of {}", length); |
| } |
| ((Buffer) currentHeader).flip(); |
| writes.add(buffer); |
| } |
| ((Buffer) zeroHeader).flip(); // zero-terminate |
| writes.add(zeroHeader); |
| |
| writeFully(writes.toArray(new ByteBuffer[0])); |
| } |
| |
| private void write(Status status, String prefix, ByteBuffer response) throws IOException { |
| LOG.debug("write status: {} {}", status, prefix); |
| write(status, prefix); |
| write(response); |
| } |
| |
| private void write(Status status, String response) throws IOException { |
| write(status, ByteBuffer.wrap(response.getBytes(StandardCharsets.UTF_8))); |
| } |
| |
| private void write(Status status, ByteBuffer response) throws IOException { |
| LOG.debug("write status: {}", status); |
| ByteBuffer statusBuffer = ByteBuffer.allocate(1); |
| ((Buffer) statusBuffer).clear(); |
| ((Buffer) statusBuffer.put((byte) (status.ordinal()))).flip(); |
| writeFully(statusBuffer); |
| write(response); |
| } |
| |
| private void write(ByteBuffer response) throws IOException { |
| LOG.debug("writing: {}", response.remaining()); |
| ((Buffer) writeHeader).clear(); |
| ((Buffer) writeHeader.putInt(response.remaining())).flip(); |
| writeFully(writeHeader, response); |
| } |
| |
| private void writeFully(ByteBuffer... buffers) throws IOException { |
| int length = buffers.length; |
| int start = 0; |
| do { |
| channel.write(buffers, start, length - start); |
| while (buffers[start].remaining() == 0) { |
| start++; |
| if (start == length) |
| return; |
| } |
| } while (true); |
| } |
| |
| @Override |
| public void close() throws IOException { |
| if (channel.isOpen()) { |
| LOG.info("closing to " + getRemoteName()); |
| channel.close(); |
| } |
| sasl.dispose(); |
| } |
| |
| /** |
| * Used to abstract over the <code>SaslServer</code> and <code>SaslClient</code> |
| * classes, which share a lot of their interface, but unfortunately don't share |
| * a common superclass. |
| */ |
| private static class SaslParticipant { |
| // One of these will always be null. |
| public SaslServer server; |
| public SaslClient client; |
| |
| public SaslParticipant(SaslServer server) { |
| this.server = server; |
| } |
| |
| public SaslParticipant(SaslClient client) { |
| this.client = client; |
| } |
| |
| public String getMechanismName() { |
| if (client != null) |
| return client.getMechanismName(); |
| else |
| return server.getMechanismName(); |
| } |
| |
| public boolean isComplete() { |
| if (client != null) |
| return client.isComplete(); |
| else |
| return server.isComplete(); |
| } |
| |
| public void dispose() throws SaslException { |
| if (client != null) |
| client.dispose(); |
| else |
| server.dispose(); |
| } |
| |
| public byte[] unwrap(byte[] buf) throws SaslException { |
| if (client != null) |
| return client.unwrap(buf, 0, buf.length); |
| else |
| return server.unwrap(buf, 0, buf.length); |
| } |
| |
| public byte[] wrap(byte[] buf, int start, int len) throws SaslException { |
| if (client != null) |
| return client.wrap(buf, start, len); |
| else |
| return server.wrap(buf, start, len); |
| } |
| |
| public Object getNegotiatedProperty(String propName) { |
| if (client != null) |
| return client.getNegotiatedProperty(propName); |
| else |
| return server.getNegotiatedProperty(propName); |
| } |
| |
| public byte[] evaluate(byte[] buf) throws SaslException { |
| if (client != null) |
| return client.evaluateChallenge(buf); |
| else |
| return server.evaluateResponse(buf); |
| } |
| |
| } |
| |
| private static class AnonymousClient implements SaslClient { |
| @Override |
| public String getMechanismName() { |
| return "ANONYMOUS"; |
| } |
| |
| @Override |
| public boolean hasInitialResponse() { |
| return true; |
| } |
| |
| @Override |
| public byte[] evaluateChallenge(byte[] challenge) throws SaslException { |
| return System.getProperty("user.name").getBytes(StandardCharsets.UTF_8); |
| } |
| |
| @Override |
| public boolean isComplete() { |
| return true; |
| } |
| |
| @Override |
| public byte[] unwrap(byte[] incoming, int offset, int len) { |
| throw new UnsupportedOperationException(); |
| } |
| |
| @Override |
| public byte[] wrap(byte[] outgoing, int offset, int len) { |
| throw new UnsupportedOperationException(); |
| } |
| |
| @Override |
| public Object getNegotiatedProperty(String propName) { |
| return null; |
| } |
| |
| @Override |
| public void dispose() { |
| } |
| } |
| } |