| use std::sync::Arc; |
| |
| extern crate mio; |
| use mio::tcp::{TcpListener, TcpStream, Shutdown}; |
| |
| #[macro_use] |
| extern crate log; |
| |
| use std::fs; |
| use std::io; |
| use std::net; |
| use std::io::{Write, Read, BufReader}; |
| use std::collections::HashMap; |
| |
| #[macro_use] |
| extern crate serde_derive; |
| extern crate docopt; |
| use docopt::Docopt; |
| |
| extern crate env_logger; |
| extern crate vecio; |
| extern crate rustls; |
| |
| use rustls::{RootCertStore, Session, NoClientAuth, AllowAnyAuthenticatedClient, |
| AllowAnyAnonymousOrAuthenticatedClient}; |
| |
| mod util; |
| use util::WriteVAdapter; |
| |
// Token for our listening socket. `TlsServer` hands out connection tokens
// starting at 2, so they never collide with this one.
const LISTENER: mio::Token = mio::Token(0);
| |
/// Which mode the server operates in.
///
/// All variants carry only `Copy` data, so the enum is `Copy`. It also
/// derives `Debug` and equality so it can be logged and compared.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ServerMode {
    /// Write back received bytes.
    Echo,

    /// Do one read, then write a bodged HTTP response and
    /// cleanly close the connection.
    Http,

    /// Forward traffic to/from given port on localhost.
    Forward(u16),
}
| |
| /// This binds together a TCP listening socket, some outstanding |
| /// connections, and a TLS server configuration. |
| struct TlsServer { |
| server: TcpListener, |
| connections: HashMap<mio::Token, Connection>, |
| next_id: usize, |
| tls_config: Arc<rustls::ServerConfig>, |
| mode: ServerMode, |
| } |
| |
| impl TlsServer { |
| fn new(server: TcpListener, mode: ServerMode, cfg: Arc<rustls::ServerConfig>) -> TlsServer { |
| TlsServer { |
| server, |
| connections: HashMap::new(), |
| next_id: 2, |
| tls_config: cfg, |
| mode, |
| } |
| } |
| |
| fn accept(&mut self, poll: &mut mio::Poll) -> bool { |
| match self.server.accept() { |
| Ok((socket, addr)) => { |
| debug!("Accepting new connection from {:?}", addr); |
| |
| let tls_session = rustls::ServerSession::new(&self.tls_config); |
| let mode = self.mode.clone(); |
| |
| let token = mio::Token(self.next_id); |
| self.next_id += 1; |
| |
| self.connections.insert(token, Connection::new(socket, token, mode, tls_session)); |
| self.connections[&token].register(poll); |
| true |
| } |
| Err(e) => { |
| println!("encountered error while accepting connection; err={:?}", e); |
| false |
| } |
| } |
| } |
| |
| fn conn_event(&mut self, poll: &mut mio::Poll, event: &mio::Event) { |
| let token = event.token(); |
| |
| if self.connections.contains_key(&token) { |
| self.connections |
| .get_mut(&token) |
| .unwrap() |
| .ready(poll, event); |
| |
| if self.connections[&token].is_closed() { |
| self.connections.remove(&token); |
| } |
| } |
| } |
| } |
| |
| /// This is a connection which has been accepted by the server, |
| /// and is currently being served. |
| /// |
| /// It has a TCP-level stream, a TLS-level session, and some |
| /// other state/metadata. |
| struct Connection { |
| socket: TcpStream, |
| token: mio::Token, |
| closing: bool, |
| closed: bool, |
| mode: ServerMode, |
| tls_session: rustls::ServerSession, |
| back: Option<TcpStream>, |
| sent_http_response: bool, |
| } |
| |
| /// Open a plaintext TCP-level connection for forwarded connections. |
| fn open_back(mode: &ServerMode) -> Option<TcpStream> { |
| match *mode { |
| ServerMode::Forward(ref port) => { |
| let addr = net::SocketAddrV4::new(net::Ipv4Addr::new(127, 0, 0, 1), *port); |
| let conn = TcpStream::connect(&net::SocketAddr::V4(addr)).unwrap(); |
| Some(conn) |
| } |
| _ => None, |
| } |
| } |
| |
/// Reshape a raw read result the way older mio used to.
///
/// A successful read of `n` bytes becomes `Ok(Some(n))`. `WouldBlock`
/// becomes `Ok(None)` ("nothing available yet"). Every other error is
/// passed through unchanged.
fn try_read(r: io::Result<usize>) -> io::Result<Option<usize>> {
    r.map(Some).or_else(|err| match err.kind() {
        io::ErrorKind::WouldBlock => Ok(None),
        _ => Err(err),
    })
}
| |
| impl Connection { |
| fn new(socket: TcpStream, |
| token: mio::Token, |
| mode: ServerMode, |
| tls_session: rustls::ServerSession) |
| -> Connection { |
| let back = open_back(&mode); |
| Connection { |
| socket, |
| token, |
| closing: false, |
| closed: false, |
| mode, |
| tls_session, |
| back, |
| sent_http_response: false, |
| } |
| } |
| |
| /// We're a connection, and we have something to do. |
| fn ready(&mut self, poll: &mut mio::Poll, ev: &mio::Event) { |
| // If we're readable: read some TLS. Then |
| // see if that yielded new plaintext. Then |
| // see if the backend is readable too. |
| if ev.readiness().is_readable() { |
| self.do_tls_read(); |
| self.try_plain_read(); |
| self.try_back_read(); |
| } |
| |
| if ev.readiness().is_writable() { |
| self.do_tls_write(); |
| } |
| |
| if self.closing && !self.tls_session.wants_write() { |
| let _ = self.socket.shutdown(Shutdown::Both); |
| self.close_back(); |
| self.closed = true; |
| } else { |
| self.reregister(poll); |
| } |
| } |
| |
| /// Close the backend connection for forwarded sessions. |
| fn close_back(&mut self) { |
| if self.back.is_some() { |
| let back = self.back.as_mut().unwrap(); |
| back.shutdown(Shutdown::Both).unwrap(); |
| } |
| self.back = None; |
| } |
| |
| fn do_tls_read(&mut self) { |
| // Read some TLS data. |
| let rc = self.tls_session.read_tls(&mut self.socket); |
| if rc.is_err() { |
| let err = rc.unwrap_err(); |
| |
| if let io::ErrorKind::WouldBlock = err.kind() { |
| return; |
| } |
| |
| error!("read error {:?}", err); |
| self.closing = true; |
| return; |
| } |
| |
| if rc.unwrap() == 0 { |
| debug!("eof"); |
| self.closing = true; |
| return; |
| } |
| |
| // Process newly-received TLS messages. |
| let processed = self.tls_session.process_new_packets(); |
| if processed.is_err() { |
| error!("cannot process packet: {:?}", processed); |
| self.closing = true; |
| return; |
| } |
| } |
| |
| fn try_plain_read(&mut self) { |
| // Read and process all available plaintext. |
| let mut buf = Vec::new(); |
| |
| let rc = self.tls_session.read_to_end(&mut buf); |
| if rc.is_err() { |
| error!("plaintext read failed: {:?}", rc); |
| self.closing = true; |
| return; |
| } |
| |
| if !buf.is_empty() { |
| debug!("plaintext read {:?}", buf.len()); |
| self.incoming_plaintext(&buf); |
| } |
| } |
| |
| fn try_back_read(&mut self) { |
| if self.back.is_none() { |
| return; |
| } |
| |
| // Try a non-blocking read. |
| let mut buf = [0u8; 1024]; |
| let back = self.back.as_mut().unwrap(); |
| let rc = try_read(back.read(&mut buf)); |
| |
| if rc.is_err() { |
| error!("backend read failed: {:?}", rc); |
| self.closing = true; |
| return; |
| } |
| |
| let maybe_len = rc.unwrap(); |
| |
| // If we have a successful but empty read, that's an EOF. |
| // Otherwise, we shove the data into the TLS session. |
| match maybe_len { |
| Some(len) if len == 0 => { |
| debug!("back eof"); |
| self.closing = true; |
| } |
| Some(len) => { |
| self.tls_session.write_all(&buf[..len]).unwrap(); |
| } |
| None => {} |
| }; |
| } |
| |
| /// Process some amount of received plaintext. |
| fn incoming_plaintext(&mut self, buf: &[u8]) { |
| match self.mode { |
| ServerMode::Echo => { |
| self.tls_session.write_all(buf).unwrap(); |
| } |
| ServerMode::Http => { |
| self.send_http_response_once(); |
| } |
| ServerMode::Forward(_) => { |
| self.back.as_mut().unwrap().write_all(buf).unwrap(); |
| } |
| } |
| } |
| |
| fn send_http_response_once(&mut self) { |
| let response = b"HTTP/1.0 200 OK\r\nConnection: close\r\n\r\nHello world from rustls tlsserver\r\n"; |
| if !self.sent_http_response { |
| self.tls_session |
| .write_all(response) |
| .unwrap(); |
| self.sent_http_response = true; |
| self.tls_session.send_close_notify(); |
| } |
| } |
| |
| fn do_tls_write(&mut self) { |
| let rc = self.tls_session.writev_tls(&mut WriteVAdapter::new(&mut self.socket)); |
| if rc.is_err() { |
| error!("write failed {:?}", rc); |
| self.closing = true; |
| return; |
| } |
| } |
| |
| fn register(&self, poll: &mut mio::Poll) { |
| poll.register(&self.socket, |
| self.token, |
| self.event_set(), |
| mio::PollOpt::level() | mio::PollOpt::oneshot()) |
| .unwrap(); |
| |
| if self.back.is_some() { |
| poll.register(self.back.as_ref().unwrap(), |
| self.token, |
| mio::Ready::readable(), |
| mio::PollOpt::level() | mio::PollOpt::oneshot()) |
| .unwrap(); |
| } |
| } |
| |
| fn reregister(&self, poll: &mut mio::Poll) { |
| poll.reregister(&self.socket, |
| self.token, |
| self.event_set(), |
| mio::PollOpt::level() | mio::PollOpt::oneshot()) |
| .unwrap(); |
| |
| if self.back.is_some() { |
| poll.reregister(self.back.as_ref().unwrap(), |
| self.token, |
| mio::Ready::readable(), |
| mio::PollOpt::level() | mio::PollOpt::oneshot()) |
| .unwrap(); |
| } |
| } |
| |
| /// What IO events we're currently waiting for, |
| /// based on wants_read/wants_write. |
| fn event_set(&self) -> mio::Ready { |
| let rd = self.tls_session.wants_read(); |
| let wr = self.tls_session.wants_write(); |
| |
| if rd && wr { |
| mio::Ready::readable() | mio::Ready::writable() |
| } else if wr { |
| mio::Ready::writable() |
| } else { |
| mio::Ready::readable() |
| } |
| } |
| |
| fn is_closed(&self) -> bool { |
| self.closed |
| } |
| } |
| |
/// docopt usage text.
///
/// docopt parses it at runtime. The "Options:" lines must separate each
/// option from its description by at least two spaces. Trailing `\` joins
/// a usage pattern onto one line.
const USAGE: &str = "
Runs a TLS server on :PORT. The default PORT is 443.

`echo' mode means the server echoes received data on each connection.

`http' mode means the server blindly sends a HTTP response on each
connection.

`forward' means the server forwards plaintext to a connection made to
localhost:fport.

`--certs' names the full certificate chain, `--key' provides the
private key (RSA or PKCS8).

Usage:
  tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
     [--proto PROTO ...] [options] echo
  tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
     [--proto PROTO ...] [options] http
  tlsserver --certs CERTFILE --key KEYFILE [--suite SUITE ...] \
     [--proto PROTO ...] [options] forward <fport>
  tlsserver (--version | -v)
  tlsserver (--help | -h)

Options:
    -p, --port PORT     Listen on PORT [default: 443].
    --certs CERTFILE    Read server certificates from CERTFILE.
                        This should contain PEM-format certificates
                        in the right order (the first certificate should
                        certify KEYFILE, the last should be a root CA).
    --key KEYFILE       Read private key from KEYFILE.  This should be an RSA
                        private key or PKCS8-encoded private key, in PEM format.
    --ocsp OCSPFILE     Read DER-encoded OCSP response from OCSPFILE and staple
                        to certificate.  Optional.
    --auth CERTFILE     Enable client authentication, and accept certificates
                        signed by those roots provided in CERTFILE.
    --require-auth      Send a fatal alert if the client does not complete client
                        authentication.
    --resumption        Support session resumption.
    --tickets           Support tickets.
    --protover VERSION  Disable default TLS version list, and use
                        VERSION instead.  May be used multiple times.
    --suite SUITE       Disable default cipher suite list, and use
                        SUITE instead.  May be used multiple times.
    --proto PROTOCOL    Negotiate PROTOCOL using ALPN.
                        May be used multiple times.
    --verbose           Emit log output.
    --version, -v       Show tool version.
    --help, -h          Show this screen.
";
| |
/// Parsed command line, filled in by docopt from `USAGE`.
///
/// Field names follow docopt's mapping: `cmd_*` for the `echo`/`http`/
/// `forward` commands, `flag_*` for `--options`, and `arg_*` for
/// `<positional>` arguments. Renaming a field breaks the mapping.
#[derive(Debug, Deserialize)]
struct Args {
    /// `echo` command given.
    cmd_echo: bool,
    /// `http` command given.
    cmd_http: bool,
    /// `forward <fport>` command given.
    cmd_forward: bool,
    /// `--port`; USAGE declares `[default: 443]`.
    flag_port: Option<u16>,
    /// `--verbose`: enable trace-level logging.
    flag_verbose: bool,
    /// `--protover`, repeatable: TLS versions to enable ("1.2" / "1.3").
    flag_protover: Vec<String>,
    /// `--suite`, repeatable: cipher suite names to enable.
    flag_suite: Vec<String>,
    /// `--proto`, repeatable: ALPN protocol names to offer.
    flag_proto: Vec<String>,
    /// `--certs`: PEM certificate chain file.
    flag_certs: Option<String>,
    /// `--key`: PEM private key file.
    flag_key: Option<String>,
    /// `--ocsp`: optional DER-encoded OCSP response to staple.
    flag_ocsp: Option<String>,
    /// `--auth`: PEM roots for client authentication; presence enables it.
    flag_auth: Option<String>,
    /// `--require-auth`: make client authentication mandatory.
    flag_require_auth: bool,
    /// `--resumption`: enable an in-memory server session cache.
    flag_resumption: bool,
    /// `--tickets`: enable session tickets.
    flag_tickets: bool,
    /// `<fport>`: localhost backend port for `forward` mode.
    arg_fport: Option<u16>,
}
| |
| fn find_suite(name: &str) -> Option<&'static rustls::SupportedCipherSuite> { |
| for suite in &rustls::ALL_CIPHERSUITES { |
| let sname = format!("{:?}", suite.suite).to_lowercase(); |
| |
| if sname == name.to_string().to_lowercase() { |
| return Some(suite); |
| } |
| } |
| |
| None |
| } |
| |
| fn lookup_suites(suites: &[String]) -> Vec<&'static rustls::SupportedCipherSuite> { |
| let mut out = Vec::new(); |
| |
| for csname in suites { |
| let scs = find_suite(csname); |
| match scs { |
| Some(s) => out.push(s), |
| None => panic!("cannot look up ciphersuite '{}'", csname), |
| } |
| } |
| |
| out |
| } |
| |
| /// Make a vector of protocol versions named in `versions` |
| fn lookup_versions(versions: &[String]) -> Vec<rustls::ProtocolVersion> { |
| let mut out = Vec::new(); |
| |
| for vname in versions { |
| let version = match vname.as_ref() { |
| "1.2" => rustls::ProtocolVersion::TLSv1_2, |
| "1.3" => rustls::ProtocolVersion::TLSv1_3, |
| _ => panic!("cannot look up version '{}', valid are '1.2' and '1.3'", vname), |
| }; |
| out.push(version); |
| } |
| |
| out |
| } |
| |
| fn load_certs(filename: &str) -> Vec<rustls::Certificate> { |
| let certfile = fs::File::open(filename).expect("cannot open certificate file"); |
| let mut reader = BufReader::new(certfile); |
| rustls::internal::pemfile::certs(&mut reader).unwrap() |
| } |
| |
| fn load_private_key(filename: &str) -> rustls::PrivateKey { |
| let rsa_keys = { |
| let keyfile = fs::File::open(filename) |
| .expect("cannot open private key file"); |
| let mut reader = BufReader::new(keyfile); |
| rustls::internal::pemfile::rsa_private_keys(&mut reader) |
| .expect("file contains invalid rsa private key") |
| }; |
| |
| let pkcs8_keys = { |
| let keyfile = fs::File::open(filename) |
| .expect("cannot open private key file"); |
| let mut reader = BufReader::new(keyfile); |
| rustls::internal::pemfile::pkcs8_private_keys(&mut reader) |
| .expect("file contains invalid pkcs8 private key (encrypted keys not supported)") |
| }; |
| |
| // prefer to load pkcs8 keys |
| if !pkcs8_keys.is_empty() { |
| pkcs8_keys[0].clone() |
| } else { |
| assert!(!rsa_keys.is_empty()); |
| rsa_keys[0].clone() |
| } |
| } |
| |
/// Read the DER-encoded OCSP response named by `filename`, if one was given.
///
/// Returns an empty vector when `filename` is `None`.
///
/// # Panics
/// Panics if a file is named but cannot be read.
fn load_ocsp(filename: &Option<String>) -> Vec<u8> {
    match filename.as_ref() {
        Some(name) => fs::read(name).expect("cannot open ocsp file"),
        None => Vec::new(),
    }
}
| |
| fn make_config(args: &Args) -> Arc<rustls::ServerConfig> { |
| let client_auth = if args.flag_auth.is_some() { |
| let roots = load_certs(args.flag_auth.as_ref().unwrap()); |
| let mut client_auth_roots = RootCertStore::empty(); |
| for root in roots { |
| client_auth_roots.add(&root).unwrap(); |
| } |
| if args.flag_require_auth { |
| AllowAnyAuthenticatedClient::new(client_auth_roots) |
| } else { |
| AllowAnyAnonymousOrAuthenticatedClient::new(client_auth_roots) |
| } |
| } else { |
| NoClientAuth::new() |
| }; |
| |
| let mut config = rustls::ServerConfig::new(client_auth); |
| config.key_log = Arc::new(rustls::KeyLogFile::new()); |
| |
| let certs = load_certs(args.flag_certs.as_ref().expect("--certs option missing")); |
| let privkey = load_private_key(args.flag_key.as_ref().expect("--key option missing")); |
| let ocsp = load_ocsp(&args.flag_ocsp); |
| config.set_single_cert_with_ocsp_and_sct(certs, privkey, ocsp, vec![]) |
| .expect("bad certificates/private key"); |
| |
| if !args.flag_suite.is_empty() { |
| config.ciphersuites = lookup_suites(&args.flag_suite); |
| } |
| |
| if !args.flag_protover.is_empty() { |
| config.versions = lookup_versions(&args.flag_protover); |
| } |
| |
| if args.flag_resumption { |
| config.set_persistence(rustls::ServerSessionMemoryCache::new(256)); |
| } |
| |
| if args.flag_tickets { |
| config.ticketer = rustls::Ticketer::new(); |
| } |
| |
| config.set_protocols(&args.flag_proto); |
| |
| Arc::new(config) |
| } |
| |
| fn main() { |
| let version = env!("CARGO_PKG_NAME").to_string() + ", version: " + env!("CARGO_PKG_VERSION"); |
| |
| let args: Args = Docopt::new(USAGE) |
| .and_then(|d| Ok(d.help(true))) |
| .and_then(|d| Ok(d.version(Some(version)))) |
| .and_then(|d| d.deserialize()) |
| .unwrap_or_else(|e| e.exit()); |
| |
| if args.flag_verbose { |
| env_logger::Builder::new() |
| .parse("trace") |
| .init(); |
| } |
| |
| let mut addr: net::SocketAddr = "0.0.0.0:443".parse().unwrap(); |
| addr.set_port(args.flag_port.unwrap_or(443)); |
| |
| let config = make_config(&args); |
| |
| let listener = TcpListener::bind(&addr).expect("cannot listen on port"); |
| let mut poll = mio::Poll::new() |
| .unwrap(); |
| poll.register(&listener, |
| LISTENER, |
| mio::Ready::readable(), |
| mio::PollOpt::level()) |
| .unwrap(); |
| |
| let mode = if args.cmd_echo { |
| ServerMode::Echo |
| } else if args.cmd_http { |
| ServerMode::Http |
| } else { |
| ServerMode::Forward(args.arg_fport.expect("fport required")) |
| }; |
| |
| let mut tlsserv = TlsServer::new(listener, mode, config); |
| |
| let mut events = mio::Events::with_capacity(256); |
| loop { |
| poll.poll(&mut events, None) |
| .unwrap(); |
| |
| for event in events.iter() { |
| match event.token() { |
| LISTENER => { |
| if !tlsserv.accept(&mut poll) { |
| break; |
| } |
| } |
| _ => tlsserv.conn_event(&mut poll, &event) |
| } |
| } |
| } |
| } |