blob: c7d9b1e45438caae4424c09de94483eb4ce058ef [file] [log] [blame]
// Copyright (C) 2017-2019 Baidu, Inc. All Rights Reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
// * Redistributions of source code must retain the above copyright
// notice, this list of conditions and the following disclaimer.
// * Redistributions in binary form must reproduce the above copyright
// notice, this list of conditions and the following disclaimer in
// the documentation and/or other materials provided with the
// distribution.
// * Neither the name of Baidu, Inc., nor the names of its
// contributors may be used to endorse or promote products derived
// from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR
// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT
// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL,
// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT
// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY
// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
#![crate_name = "tlsserver"]
#![crate_type = "staticlib"]
#![cfg_attr(not(target_env = "sgx"), no_std)]
#![cfg_attr(target_env = "sgx", feature(rustc_private))]
extern crate sgx_types;
extern crate sgx_trts;
#[cfg(not(target_env = "sgx"))]
#[macro_use]
extern crate sgx_tstd as std;
#[macro_use]
extern crate lazy_static;
use sgx_types::*;
use sgx_trts::trts::{rsgx_raw_is_outside_enclave, rsgx_lfence, rsgx_sfence};
use std::untrusted::fs;
use std::io::BufReader;
use std::ffi::CStr;
use std::os::raw::c_char;
use std::vec::Vec;
use std::boxed::Box;
use std::io::{Read, Write};
use std::slice;
use std::sync::{Arc, SgxRwLock};
use std::net::TcpStream;
use std::collections::HashMap;
use std::sync::atomic::{AtomicUsize, AtomicPtr, Ordering};
extern crate webpki;
extern crate rustls;
use rustls::{Session, NoClientAuth};
/// One TLS server connection: the TCP socket plus the rustls session state
/// layered over it.
///
/// Instances are heap-allocated in `tls_server_new` (via `Box::into_raw`) and
/// addressed from untrusted code only through the session id they are
/// registered under in `GLOBAL_CONTEXTS`.
pub struct TlsServer {
// Socket wrapping the file descriptor passed to `tls_server_new`.
socket: TcpStream,
// rustls server-side state machine; TLS records are read from / written to `socket`.
tls_session: rustls::ServerSession,
}
/// Counter from which `Sessions::new_session` draws the next session id
/// (incremented once per registered session, starting at 0).
static GLOBAL_CONTEXT_COUNT: AtomicUsize = AtomicUsize::new(0);
lazy_static! {
    /// Registry of live sessions, keyed by session id. Each value holds the raw
    /// pointer registered by `Sessions::new_session`; `Sessions::remove_session`
    /// frees it again with `Box::from_raw`.
    static ref GLOBAL_CONTEXTS: SgxRwLock<HashMap<usize, AtomicPtr<TlsServer>>> =
        SgxRwLock::new(HashMap::new());
}
impl TlsServer {
    /// Wraps the connected socket `fd` and a fresh rustls server session built
    /// from `cfg`.
    ///
    /// # Panics
    /// Panics if `TcpStream::new(fd)` returns an error.
    fn new(fd: c_int, cfg: Arc<rustls::ServerConfig>) -> TlsServer {
        TlsServer {
            socket: TcpStream::new(fd).unwrap(),
            tls_session: rustls::ServerSession::new(&cfg),
        }
    }

    /// Pulls pending TLS records off the socket and feeds them to rustls.
    ///
    /// Returns `0` on success and `-1` (after logging) when the socket read
    /// fails, the peer reached EOF, or rustls reports a protocol error.
    fn do_read(&mut self) -> c_int {
        // Read TLS data. This fails if the underlying TCP connection is broken.
        let rc = self.tls_session.read_tls(&mut self.socket);
        let n = match rc {
            Ok(n) => n,
            Err(_) => {
                println!("TLS read error: {:?}", rc);
                return -1;
            }
        };
        // If we're ready but there's no data: EOF.
        if n == 0 {
            println!("EOF");
            return -1;
        }
        // Reading some TLS data might have yielded new TLS messages to process.
        // Errors from this indicate TLS protocol problems and are fatal.
        if let Err(e) = self.tls_session.process_new_packets() {
            println!("TLS error: {:?}", e);
            return -1;
        }
        0
    }

    /// Appends all plaintext rustls currently has buffered to `plaintext`.
    ///
    /// Returns the resulting total length of `plaintext` as a `c_int`, or `-1`
    /// (after logging) when the plaintext read fails.
    fn read(&mut self, plaintext: &mut Vec<u8>) -> c_int {
        // An error here may mean the peer started a clean TLS-level closure.
        match self.tls_session.read_to_end(plaintext) {
            Ok(_) => plaintext.len() as c_int,
            Err(err) => {
                println!("Plaintext read error: {:?}", err);
                -1
            }
        }
    }

    /// Hands `plaintext` to rustls for encryption; the resulting records are
    /// sent by a later `do_write`.
    ///
    /// Returns the number of bytes rustls accepted, or `-1` (after logging) on
    /// error. Previously this panicked, which aborts the enclave.
    fn write(&mut self, plaintext: &[u8]) -> c_int {
        match self.tls_session.write(plaintext) {
            // `plaintext` comes from a `c_int`-sized buffer, so the count fits.
            Ok(n) => n as c_int,
            Err(e) => {
                println!("Plaintext write error: {:?}", e);
                -1
            }
        }
    }

    /// Writes pending TLS records from rustls to the socket.
    ///
    /// A socket error (e.g. `WouldBlock` on a non-blocking socket) is logged
    /// instead of panicking. NOTE(review): this assumes rustls keeps unsent
    /// records buffered on error so a later call can retry; confirm against the
    /// rustls version in use.
    fn do_write(&mut self) {
        if let Err(e) = self.tls_session.write_tls(&mut self.socket) {
            println!("TLS write error: {:?}", e);
        }
    }
}
/// Parses every PEM certificate in `filename` (opened through the untrusted
/// file system).
///
/// # Panics
/// Panics if the file cannot be opened or the PEM parser returns an error.
fn load_certs(filename: &str) -> Vec<rustls::Certificate> {
    let pem_file = fs::File::open(filename).expect("cannot open certificate file");
    rustls::internal::pemfile::certs(&mut BufReader::new(pem_file)).unwrap()
}
fn load_private_key(filename: &str) -> rustls::PrivateKey {
let rsa_keys = {
let keyfile = fs::File::open(filename)
.expect("cannot open private key file");
let mut reader = BufReader::new(keyfile);
rustls::internal::pemfile::rsa_private_keys(&mut reader)
.expect("file contains invalid rsa private key")
};
let pkcs8_keys = {
let keyfile = fs::File::open(filename)
.expect("cannot open private key file");
let mut reader = BufReader::new(keyfile);
rustls::internal::pemfile::pkcs8_private_keys(&mut reader)
.expect("file contains invalid pkcs8 private key (encrypted keys not supported)")
};
// prefer to load pkcs8 keys
if !pkcs8_keys.is_empty() {
pkcs8_keys[0].clone()
} else {
assert!(!rsa_keys.is_empty());
rsa_keys[0].clone()
}
}
fn make_config(cert: &str, key: &str) -> Arc<rustls::ServerConfig> {
let mut config = rustls::ServerConfig::new(NoClientAuth::new());
let certs = load_certs(cert);
let privkey = load_private_key(key);
config.set_single_cert_with_ocsp_and_sct(certs, privkey, vec![], vec![]).unwrap();
Arc::new(config)
}
/// Namespace for operations on the global session registry (`GLOBAL_CONTEXTS`).
struct Sessions;
impl Sessions {
    /// Registers `svr_ptr` under a newly drawn id and returns that id, or
    /// `None` (after logging) if the registry lock is poisoned.
    ///
    /// The registry takes ownership of the pointer; `remove_session` frees it
    /// with `Box::from_raw`, so it must come from `Box::into_raw`.
    fn new_session(svr_ptr: *mut TlsServer) -> Option<usize> {
        match GLOBAL_CONTEXTS.write() {
            Ok(mut gctxts) => {
                let curr_id = GLOBAL_CONTEXT_COUNT.fetch_add(1, Ordering::SeqCst);
                gctxts.insert(curr_id, AtomicPtr::new(svr_ptr));
                Some(curr_id)
            }
            Err(x) => {
                println!("Locking global context SgxRwLock failed! {:?}", x);
                None
            }
        }
    }

    /// Returns the pointer registered under `sess_id`, or `None` (after
    /// logging) if the id is unknown or the lock is poisoned.
    ///
    /// NOTE(review): the read lock is released before the caller dereferences
    /// the pointer, so a concurrent `remove_session` for the same id could free
    /// it while still in use. Confirm callers never race on one id.
    fn get_session(sess_id: size_t) -> Option<*mut TlsServer> {
        match GLOBAL_CONTEXTS.read() {
            Ok(gctxts) => match gctxts.get(&sess_id) {
                Some(s) => Some(s.load(Ordering::SeqCst)),
                None => {
                    println!("Global contexts cannot find session id = {}", sess_id);
                    None
                }
            },
            Err(x) => {
                println!("Locking global context SgxRwLock failed on get_session! {:?}", x);
                None
            }
        }
    }

    /// Unregisters `sess_id` and frees its `TlsServer`. Unknown ids are ignored.
    /// A poisoned lock is now logged instead of being silently swallowed.
    fn remove_session(sess_id: size_t) {
        match GLOBAL_CONTEXTS.write() {
            Ok(mut gctxts) => {
                if let Some(slot) = gctxts.remove(&sess_id) {
                    // SAFETY: pointers enter the registry only through
                    // `new_session`, whose contract is that they come from
                    // `Box::into_raw`. Removing the entry under the write lock
                    // makes this the single release of that allocation.
                    drop(unsafe { Box::from_raw(slot.into_inner()) });
                }
            }
            Err(x) => {
                println!("Locking global context SgxRwLock failed on remove_session! {:?}", x);
            }
        }
    }
}
/// Creates a TLS server session over the connected socket `fd`.
///
/// `cert` and `key` are NUL-terminated paths to the PEM certificate chain and
/// private key. Returns the new session id, or `usize::max_value()` (i.e.
/// `0xFFFF_FFFF_FFFF_FFFF` on 64-bit targets) when:
/// - either path pointer is null,
/// - a path is not valid UTF-8, or
/// - the session cannot be registered.
///
/// In the last case the freshly allocated session is freed rather than leaked.
///
/// # Panics
/// `make_config` and `TlsServer::new` panic if the key material cannot be
/// loaded or the socket cannot be wrapped.
#[no_mangle]
pub extern "C" fn tls_server_new(fd: c_int, cert: * const c_char, key: * const c_char) -> usize {
    // `usize::max_value()` instead of a 64-bit literal so this also builds on
    // 32-bit targets.
    let invalid_session = usize::max_value();
    // `CStr::from_ptr` on a null pointer is undefined behaviour.
    if cert.is_null() || key.is_null() {
        return invalid_session;
    }
    let certfile = match unsafe { CStr::from_ptr(cert) }.to_str() {
        Ok(path) => path,
        Err(_) => return invalid_session,
    };
    let keyfile = match unsafe { CStr::from_ptr(key) }.to_str() {
        Ok(path) => path,
        Err(_) => return invalid_session,
    };
    let config = make_config(certfile, keyfile);
    let p: *mut TlsServer = Box::into_raw(Box::new(TlsServer::new(fd, config)));
    match Sessions::new_session(p) {
        Some(s) => s,
        None => {
            // SAFETY: `p` was produced by `Box::into_raw` just above and was
            // never registered, so this is its only owner.
            drop(unsafe { Box::from_raw(p) });
            invalid_session
        }
    }
}
/// Reads from the session registered under `session_id`.
///
/// When `buf` is null or `cnt == 0`, this only pulls TLS records off the
/// socket and processes them (`TlsServer::do_read`), returning that status.
///
/// Otherwise it drains the available plaintext and copies at most `cnt` bytes
/// into `buf`, whose whole `cnt`-byte range must lie outside the enclave. It
/// returns the number of bytes copied.
///
/// Returns `-1` when:
/// - the session id is unknown,
/// - `cnt` is negative,
/// - `buf` is not entirely outside the enclave, or
/// - the plaintext read fails.
///
/// The copy is bounded by the destination length. Previously,
/// `copy_from_slice` panicked whenever more than `cnt` bytes of plaintext were
/// available.
///
/// NOTE(review): plaintext beyond `cnt` bytes is drained from the session and
/// discarded; callers should size `buf` accordingly.
#[no_mangle]
pub extern "C" fn tls_server_read(session_id: size_t, buf: * mut c_char, cnt: c_int) -> c_int {
    let session_ptr = match Sessions::get_session(session_id) {
        Some(p) => p,
        None => return -1,
    };
    let session = unsafe { &mut *session_ptr };
    if buf.is_null() || cnt == 0 {
        // just read_tls
        return session.do_read();
    }
    // A negative length must not be reinterpreted as a huge `usize`.
    if cnt < 0 {
        return -1;
    }
    let capacity = cnt as usize;
    if !rsgx_raw_is_outside_enclave(buf as *const u8, capacity) {
        return -1;
    }
    // read plain buffer
    let mut plaintext = Vec::new();
    if session.read(&mut plaintext) == -1 {
        return -1;
    }
    let len = plaintext.len().min(capacity);
    // Order enclave stores before writing into untrusted memory.
    rsgx_sfence();
    let raw_buf = unsafe { slice::from_raw_parts_mut(buf as *mut u8, len) };
    raw_buf.copy_from_slice(&plaintext[..len]);
    // `len <= capacity`, which came from a non-negative `c_int`, so this fits.
    len as c_int
}
/// Writes to the session registered under `session_id`.
///
/// When `buf` is null or `cnt == 0`, this flushes pending TLS records to the
/// socket (`TlsServer::do_write`) and returns `0`. Otherwise it hands the `cnt`
/// bytes at `buf` to the session for encryption and returns the result of
/// `TlsServer::write`. The records are sent by a later flush.
///
/// Returns `-1` for an unknown session or a negative `cnt`. A negative `cnt`
/// used to reach `slice::from_raw_parts` as a huge length, which is undefined
/// behaviour.
#[no_mangle]
pub extern "C" fn tls_server_write(session_id: usize, buf: * const c_char, cnt: c_int) -> c_int {
    let session_ptr = match Sessions::get_session(session_id) {
        Some(p) => p,
        None => return -1,
    };
    let session = unsafe { &mut *session_ptr };
    // no buffer, just write_tls.
    if buf.is_null() || cnt == 0 {
        session.do_write();
        return 0;
    }
    if cnt < 0 {
        return -1;
    }
    // NOTE(review): unlike `tls_server_read`, `buf` is not checked with
    // `rsgx_raw_is_outside_enclave`. Whether that check is required depends on
    // this parameter's EDL attribute (`[in]` vs `[user_check]`), which is not
    // visible here. Confirm.
    rsgx_lfence();
    // cache buffer, waiting for next write_tls
    let plaintext = unsafe { slice::from_raw_parts(buf as *const u8, cnt as usize) };
    session.write(plaintext)
}
/// Reports whether the session under `session_id` wants more TLS input.
///
/// Returns `1` or `0` from rustls's `wants_read()`, or `-1` if the session id
/// is unknown.
#[no_mangle]
pub extern "C" fn tls_server_wants_read(session_id: usize) -> c_int {
    match Sessions::get_session(session_id) {
        Some(ptr) => {
            let server = unsafe { &*ptr };
            c_int::from(server.tls_session.wants_read())
        }
        None => -1,
    }
}
/// Reports whether the session under `session_id` has TLS output to send.
///
/// Returns `1` or `0` from rustls's `wants_write()`, or `-1` if the session id
/// is unknown.
#[no_mangle]
pub extern "C" fn tls_server_wants_write(session_id: usize) -> c_int {
    Sessions::get_session(session_id).map_or(-1, |ptr| {
        let server = unsafe { &*ptr };
        c_int::from(server.tls_session.wants_write())
    })
}
/// Unregisters the session under `session_id` and frees it
/// (`Sessions::remove_session`). Unknown ids are ignored.
#[no_mangle]
pub extern "C" fn tls_server_close(session_id: usize) {
    let id = session_id;
    Sessions::remove_session(id);
}
/// Calls rustls's `send_close_notify` on the session under `session_id`.
/// Unknown ids are ignored.
///
/// Presumably the alert is only transmitted on a later flush
/// (`tls_server_write` with no buffer). Verify against the rustls version in
/// use.
#[no_mangle]
pub extern "C" fn tls_server_send_close(session_id: usize) {
    let session_ptr = match Sessions::get_session(session_id) {
        Some(p) => p,
        None => return,
    };
    let server = unsafe { &mut *session_ptr };
    server.tls_session.send_close_notify();
}