| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| package streamer; |
| |
| import static streamer.debug.MockServer.Packet.PacketType.CLIENT; |
| import static streamer.debug.MockServer.Packet.PacketType.SERVER; |
| |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.net.InetSocketAddress; |
| import java.net.Socket; |
| import java.util.HashMap; |
| |
| import javax.net.SocketFactory; |
| import javax.net.ssl.SSLContext; |
| import javax.net.ssl.SSLSocket; |
| import javax.net.ssl.SSLSocketFactory; |
| import javax.net.ssl.TrustManager; |
| |
| import org.apache.log4j.Logger; |
| |
| import org.apache.cloudstack.utils.security.SSLUtils; |
| import org.apache.cloudstack.utils.security.SecureSSLSocketFactory; |
| |
| import streamer.debug.MockServer; |
| import streamer.debug.MockServer.Packet; |
| import streamer.ssl.SSLState; |
| import streamer.ssl.TrustAllX509TrustManager; |
| |
| public class SocketWrapperImpl extends PipelineImpl implements SocketWrapper { |
| private static final Logger s_logger = Logger.getLogger(SocketWrapperImpl.class); |
| |
| protected InputStreamSource source; |
| protected OutputStreamSink sink; |
| protected Socket socket; |
| protected InetSocketAddress address; |
| |
| protected SSLSocket sslSocket; |
| |
| protected SSLState sslState; |
| |
| public SocketWrapperImpl(String id, SSLState sslState) { |
| super(id); |
| this.sslState = sslState; |
| } |
| |
| @Override |
| protected HashMap<String, Element> initElementMap(String id) { |
| HashMap<String, Element> map = new HashMap<String, Element>(); |
| |
| source = new InputStreamSource(id + "." + OUT, this); |
| sink = new OutputStreamSink(id + "." + IN, this); |
| |
| // Pass requests to read data to socket input stream |
| map.put(OUT, source); |
| |
| // All incoming data, which is sent to this socket wrapper, will be sent |
| // to socket remote |
| map.put(IN, sink); |
| |
| return map; |
| } |
| |
| /** |
| * Connect this socket wrapper to remote server and start main loop on |
| * IputStreamSource stdout link, to watch for incoming data, and |
| * OutputStreamSink stdin link, to pull for outgoing data. |
| * |
| * @param address |
| * @throws IOException |
| */ |
| @Override |
| public void connect(InetSocketAddress address) throws IOException { |
| this.address = address; |
| |
| // Connect socket to server |
| socket = SocketFactory.getDefault().createSocket(); |
| try { |
| socket.connect(address); |
| |
| InputStream is = socket.getInputStream(); |
| source.setInputStream(is); |
| |
| OutputStream os = socket.getOutputStream(); |
| sink.setOutputStream(os); |
| |
| // Start polling for data to send to remote sever |
| runMainLoop(IN, STDIN, true, true); |
| |
| // Push incoming data from server to handlers |
| runMainLoop(OUT, STDOUT, false, false); |
| |
| } finally { |
| socket.close(); |
| } |
| } |
| |
| @Override |
| public void handleEvent(Event event, Direction direction) { |
| switch (event) { |
| case SOCKET_UPGRADE_TO_SSL: |
| upgradeToSsl(); |
| break; |
| default: |
| super.handleEvent(event, direction); |
| break; |
| } |
| } |
| |
| @Override |
| public void upgradeToSsl() { |
| |
| if (sslSocket != null) |
| // Already upgraded |
| return; |
| |
| if (verbose) |
| System.out.println("[" + this + "] INFO: Upgrading socket to SSL."); |
| |
| try { |
| // Use most secure implementation of SSL available now. |
| // JVM will try to negotiate TLS1.2, then will fallback to TLS1.0, if |
| // TLS1.2 is not supported. |
| SSLContext sslContext = SSLUtils.getSSLContext(); |
| |
| // Trust all certificates (FIXME: insecure) |
| sslContext.init(null, new TrustManager[] {new TrustAllX509TrustManager(sslState)}, null); |
| |
| SSLSocketFactory sslSocketFactory = new SecureSSLSocketFactory(sslContext); |
| sslSocket = (SSLSocket)sslSocketFactory.createSocket(socket, address.getHostName(), address.getPort(), true); |
| sslSocket.setEnabledProtocols(SSLUtils.getSupportedProtocols(sslSocket.getEnabledProtocols())); |
| |
| sslSocket.startHandshake(); |
| |
| InputStream sis = sslSocket.getInputStream(); |
| source.setInputStream(sis); |
| |
| OutputStream sos = sslSocket.getOutputStream(); |
| sink.setOutputStream(sos); |
| |
| } catch (Exception e) { |
| throw new RuntimeException("Cannot upgrade socket to SSL: " + e.getMessage(), e); |
| } |
| |
| } |
| |
| @Override |
| public void validate() { |
| for (Element element : elements.values()) |
| element.validate(); |
| |
| if (get(IN).getPads(Direction.IN).size() == 0) |
| throw new RuntimeException("[ " + this + "] Input of socket is not connected."); |
| |
| if (get(OUT).getPads(Direction.OUT).size() == 0) |
| throw new RuntimeException("[ " + this + "] Output of socket is not connected."); |
| |
| } |
| |
| @Override |
| public void shutdown() { |
| try { |
| handleEvent(Event.STREAM_CLOSE, Direction.IN); |
| } catch (Exception e) { |
| s_logger.info("[ignored]" |
| + "error sending input close event: " + e.getLocalizedMessage()); |
| } |
| try { |
| handleEvent(Event.STREAM_CLOSE, Direction.OUT); |
| } catch (Exception e) { |
| s_logger.info("[ignored]" |
| + "error sending output close event: " + e.getLocalizedMessage()); |
| } |
| try { |
| if (sslSocket != null) |
| sslSocket.close(); |
| } catch (Exception e) { |
| s_logger.info("[ignored]" |
| + "error closing ssl socket: " + e.getLocalizedMessage()); |
| } |
| try { |
| socket.close(); |
| } catch (Exception e) { |
| s_logger.info("[ignored]" |
| + "error closing socket: " + e.getLocalizedMessage()); |
| } |
| } |
| |
| @Override |
| public String toString() { |
| return "SocketWrapper(" + id + ")"; |
| } |
| |
| /** |
| * Example. |
| */ |
| public static void main(String args[]) { |
| |
| try { |
| System.setProperty("streamer.Link.debug", "true"); |
| System.setProperty("streamer.Element.debug", "true"); |
| System.setProperty("rdpclient.MockServer.debug", "true"); |
| |
| Pipeline pipeline = new PipelineImpl("echo client"); |
| |
| SocketWrapperImpl socketWrapper = new SocketWrapperImpl("socket", null); |
| |
| pipeline.add(socketWrapper); |
| pipeline.add(new BaseElement("echo")); |
| pipeline.add(new Queue("queue")); // To decouple input and output |
| |
| pipeline.link("socket", "echo", "queue", "socket"); |
| |
| final byte[] mockData = new byte[] {0x01, 0x02, 0x03}; |
| MockServer server = new MockServer(new Packet[] {new Packet("Server hello") { |
| { |
| type = SERVER; |
| data = mockData; |
| } |
| }, new Packet("Client hello") { |
| { |
| type = CLIENT; |
| data = mockData; |
| } |
| }, new Packet("Server hello") { |
| { |
| type = SERVER; |
| data = mockData; |
| } |
| }, new Packet("Client hello") { |
| { |
| type = CLIENT; |
| data = mockData; |
| } |
| }}); |
| server.start(); |
| InetSocketAddress address = server.getAddress(); |
| |
| /*DEBUG*/System.out.println("Address: " + address); |
| socketWrapper.connect(address); |
| |
| } catch (Exception e) { |
| e.printStackTrace(System.err); |
| } |
| |
| } |
| } |