Switch TLS client and server samples to an enclave-side session manager
diff --git a/samplecode/tls/tlsclient/app/src/main.rs b/samplecode/tls/tlsclient/app/src/main.rs
index f7eb96c..7310cfb 100644
--- a/samplecode/tls/tlsclient/app/src/main.rs
+++ b/samplecode/tls/tlsclient/app/src/main.rs
@@ -43,7 +43,6 @@
use std::path;
use std::net::SocketAddr;
use std::str;
-use std::ptr;
use std::io::{self, Read, Write};
const BUFFER_SIZE: usize = 1024;
@@ -52,18 +51,18 @@
static ENCLAVE_TOKEN: &'static str = "enclave.token";
extern {
- fn tls_client_new(eid: sgx_enclave_id_t, retval: *mut *const c_void,
+ fn tls_client_new(eid: sgx_enclave_id_t, retval: *mut usize,
fd: c_int, hostname: *const c_char, cert: *const c_char) -> sgx_status_t;
fn tls_client_read(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void, buf: *mut c_void, cnt: c_int) -> sgx_status_t;
+ session_id: usize, buf: *mut c_void, cnt: c_int) -> sgx_status_t;
fn tls_client_write(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void, buf: *const c_void, cnt: c_int) -> sgx_status_t;
+ session_id: usize, buf: *const c_void, cnt: c_int) -> sgx_status_t;
fn tls_client_wants_read(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void) -> sgx_status_t;
+ session_id: usize) -> sgx_status_t;
fn tls_client_wants_write(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void) -> sgx_status_t;
+ session_id: usize) -> sgx_status_t;
fn tls_client_close(eid: sgx_enclave_id_t,
- session: *const c_void) -> sgx_status_t;
+ session_id: usize) -> sgx_status_t;
}
fn init_enclave() -> SgxResult<SgxEnclave> {
@@ -142,7 +141,7 @@
enclave_id: sgx_enclave_id_t,
socket: TcpStream,
closing: bool,
- tlsclient: *const c_void,
+ tlsclient_id: usize,
}
impl TlsClient {
@@ -181,13 +180,13 @@
println!("[+] TlsClient new {} {}", hostname, cert);
- let mut tlsclient: *const c_void = ptr::null();
+ let mut tlsclient_id: usize = 0xFFFF_FFFF_FFFF_FFFF;
let c_host = CString::new(hostname.to_string()).unwrap();
let c_cert = CString::new(cert.to_string()).unwrap();
let retval = unsafe {
tls_client_new(enclave_id,
- &mut tlsclient as *mut *const c_void,
+ &mut tlsclient_id,
sock.as_raw_fd(),
c_host.as_ptr() as *const c_char,
c_cert.as_ptr() as *const c_char)
@@ -198,7 +197,7 @@
return Option::None;
}
- if tlsclient.is_null() {
+ if tlsclient_id == 0xFFFF_FFFF_FFFF_FFFF {
println!("[-] New enclave tlsclient error");
return Option::None;
}
@@ -208,14 +207,14 @@
enclave_id: enclave_id,
socket: sock,
closing: false,
- tlsclient: tlsclient as *const c_void,
+ tlsclient_id: tlsclient_id,
})
}
fn close(&self) {
let retval = unsafe {
- tls_client_close(self.enclave_id, self.tlsclient)
+ tls_client_close(self.enclave_id, self.tlsclient_id)
};
if retval != sgx_status_t::SGX_SUCCESS {
@@ -228,7 +227,7 @@
let result = unsafe {
tls_client_read(self.enclave_id,
&mut retval,
- self.tlsclient,
+ self.tlsclient_id,
buf.as_mut_ptr() as * mut c_void,
buf.len() as c_int)
};
@@ -247,7 +246,7 @@
let result = unsafe {
tls_client_write(self.enclave_id,
&mut retval,
- self.tlsclient,
+ self.tlsclient_id,
buf.as_ptr() as * const c_void,
buf.len() as c_int)
};
@@ -302,7 +301,7 @@
let result = unsafe {
tls_client_wants_read(self.enclave_id,
&mut retval,
- self.tlsclient)
+ self.tlsclient_id)
};
match result {
@@ -323,7 +322,7 @@
let result = unsafe {
tls_client_wants_write(self.enclave_id,
&mut retval,
- self.tlsclient)
+ self.tlsclient_id)
};
match result {
diff --git a/samplecode/tls/tlsclient/enclave/Cargo.toml b/samplecode/tls/tlsclient/enclave/Cargo.toml
index 080548f..c99b6b1 100644
--- a/samplecode/tls/tlsclient/enclave/Cargo.toml
+++ b/samplecode/tls/tlsclient/enclave/Cargo.toml
@@ -18,6 +18,7 @@
[dependencies]
rustls = { git = "https://github.com/mesalock-linux/rustls", branch = "mesalock_sgx" }
webpki = { git = "https://github.com/mesalock-linux/webpki", branch = "mesalock_sgx" }
+lazy_static = { version = "1.4.0", default-features = false, features = ["spin_no_std"] }
# Comment out these following lines to use rust-sgx-sdk from git
[patch.'https://github.com/baidu/rust-sgx-sdk.git']
diff --git a/samplecode/tls/tlsclient/enclave/Enclave.edl b/samplecode/tls/tlsclient/enclave/Enclave.edl
index 777e883..c71ff81 100644
--- a/samplecode/tls/tlsclient/enclave/Enclave.edl
+++ b/samplecode/tls/tlsclient/enclave/Enclave.edl
@@ -41,10 +41,10 @@
/* define ECALLs here. */
public size_t tls_client_new(int fd, [in, string]char* hostname, [in, string] char* cert);
- public int tls_client_read([user_check] void* session, [out, size=cnt] char* buf, int cnt);
- public int tls_client_write([user_check] void* session, [in, size=cnt] char* buf, int cnt);
- public int tls_client_wants_read([user_check] void* session);
- public int tls_client_wants_write([user_check] void* session);
- public void tls_client_close([user_check] void* session);
+ public int tls_client_read(size_t session_id, [out, size=cnt] char* buf, int cnt);
+ public int tls_client_write(size_t session_id, [in, size=cnt] char* buf, int cnt);
+ public int tls_client_wants_read(size_t session_id);
+ public int tls_client_wants_write(size_t session_id);
+ public void tls_client_close(size_t session_id);
};
};
diff --git a/samplecode/tls/tlsclient/enclave/src/lib.rs b/samplecode/tls/tlsclient/enclave/src/lib.rs
index e65dad3..43d4606 100644
--- a/samplecode/tls/tlsclient/enclave/src/lib.rs
+++ b/samplecode/tls/tlsclient/enclave/src/lib.rs
@@ -38,11 +38,13 @@
#[macro_use]
extern crate sgx_tstd as std;
-use sgx_trts::trts::{rsgx_raw_is_outside_enclave, rsgx_lfence};
+#[macro_use]
+extern crate lazy_static;
+
+use sgx_trts::trts::{rsgx_lfence, rsgx_sfence};
use sgx_types::*;
use std::collections;
-use std::mem;
use std::untrusted::fs;
use std::io::BufReader;
@@ -50,14 +52,15 @@
use std::ffi::CStr;
use std::os::raw::c_char;
-use std::ptr;
use std::string::String;
use std::vec::Vec;
use std::boxed::Box;
use std::io::{Read, Write};
use std::slice;
-use std::sync::{Arc, SgxMutex};
+use std::sync::{Arc, SgxMutex, SgxRwLock};
use std::net::TcpStream;
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicUsize, AtomicPtr, Ordering};
extern crate webpki;
extern crate rustls;
@@ -68,6 +71,14 @@
tls_session: rustls::ClientSession,
}
+/// Source of session ids: every `Sessions::new_session` call takes the next value
+/// with `fetch_add`, so ids are never reused while the enclave is loaded.
+static GLOBAL_CONTEXT_COUNT: AtomicUsize = AtomicUsize::new(0);
+
+lazy_static! {
+    /// Registry of live client sessions keyed by session id. Each value holds a
+    /// pointer produced by `Box::into_raw` in `tls_client_new`. `AtomicPtr` is
+    /// used because a raw pointer is not `Sync` and so cannot sit in a static.
+    static ref GLOBAL_CONTEXTS: SgxRwLock<HashMap<usize, AtomicPtr<TlsClient>>> = {
+        SgxRwLock::new(HashMap::new())
+    };
+}
+
impl TlsClient {
fn new(fd: c_int, hostname: &str, cfg: Arc<rustls::ClientConfig>) -> TlsClient {
TlsClient {
@@ -236,123 +247,142 @@
Arc::new(config)
}
+/// Registry that maps opaque session ids (the only handle the untrusted app
+/// sees) to `TlsClient` instances allocated inside the enclave.
+///
+/// NOTE(review): `get_session` hands out a raw pointer after releasing the read
+/// lock. A `remove_session` racing with another ECALL on the same id could
+/// therefore free the client while it is still in use. This is only sound if the
+/// app never closes a session concurrently with other calls on it.
+struct Sessions;
+
+impl Sessions {
+    /// Registers `client_ptr` (from `Box::into_raw`) under a fresh id.
+    ///
+    /// Returns `None` if the registry lock cannot be acquired. In that case the
+    /// pointer is not stored and ownership stays with the caller.
+    fn new_session(client_ptr : *mut TlsClient) -> Option<usize> {
+        match GLOBAL_CONTEXTS.write() {
+            Ok(mut gctxts) => {
+                let curr_id = GLOBAL_CONTEXT_COUNT.fetch_add(1, Ordering::SeqCst);
+                gctxts.insert(curr_id, AtomicPtr::new(client_ptr));
+                Some(curr_id)
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed! {:?}", x);
+                None
+            },
+        }
+    }
+
+    /// Returns the pointer registered for `sess_id`, or `None` (after logging)
+    /// if the id is unknown or the lock cannot be acquired.
+    fn get_session(sess_id: size_t) -> Option<*mut TlsClient> {
+        match GLOBAL_CONTEXTS.read() {
+            Ok(gctxts) => {
+                match gctxts.get(&sess_id) {
+                    Some(s) => {
+                        Some(s.load(Ordering::SeqCst))
+                    },
+                    None => {
+                        println!("Global contexts cannot find session id = {}", sess_id);
+                        None
+                    }
+                }
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed on get_session! {:?}", x);
+                None
+            },
+        }
+    }
+
+    /// Unregisters `sess_id` and frees its `TlsClient`. Unknown ids are ignored.
+    fn remove_session(sess_id: size_t) {
+        match GLOBAL_CONTEXTS.write() {
+            Ok(mut gctxts) => {
+                // Take the entry out of the map first. The Box is then reclaimed
+                // exactly once, and no `&mut` to soon-to-be-freed memory is created.
+                if let Some(session_ptr) = gctxts.remove(&sess_id) {
+                    drop(unsafe { Box::<TlsClient>::from_raw(session_ptr.into_inner()) });
+                }
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed on remove_session! {:?}", x);
+            },
+        }
+    }
+}
+/// ECALL: creates a `TlsClient` on socket `fd` for `hostname`. The client is
+/// configured by `make_config` from the `cert` path and registered with `Sessions`.
+///
+/// Returns the new session id, or `0xFFFF_FFFF_FFFF_FFFF` on failure. The
+/// untrusted app compares the result against that sentinel.
#[no_mangle]
-pub extern "C" fn tls_client_new(fd: c_int, hostname: * const c_char, cert: * const c_char) -> *const c_void {
+pub extern "C" fn tls_client_new(fd: c_int, hostname: * const c_char, cert: * const c_char) -> usize {
+    // The EDL marks both strings `[in, string]`. Reject NULL before calling
+    // `CStr::from_ptr`, which requires a non-null pointer.
+    if cert.is_null() || hostname.is_null() {
+        return 0xFFFF_FFFF_FFFF_FFFF;
+    }
    let certfile = unsafe { CStr::from_ptr(cert).to_str() };
    if certfile.is_err() {
-        return ptr::null();
+        return 0xFFFF_FFFF_FFFF_FFFF;
    }
    let config = make_config(certfile.unwrap());
    let name = unsafe { CStr::from_ptr(hostname).to_str() };
    let name = match name {
        Ok(n) => n,
        Err(_) => {
-            return ptr::null();
+            return 0xFFFF_FFFF_FFFF_FFFF;
        }
    };
-    Box::into_raw(Box::new(TlsClient::new(fd, name, config))) as *const c_void
+    let p: *mut TlsClient = Box::into_raw(Box::new(TlsClient::new(fd, name, config)));
+    match Sessions::new_session(p) {
+        Some(s) => s,
+        None => {
+            // The registry did not take `p`, so reclaim it here instead of leaking it.
+            drop(unsafe { Box::from_raw(p) });
+            0xFFFF_FFFF_FFFF_FFFF
+        },
+    }
}
+/// ECALL: reads plaintext from TLS session `session_id` (via `do_read`) into `buf`.
+///
+/// At most `cnt` bytes are copied, and the number of copied bytes is returned.
+/// Returns -1 if `buf` is null, `cnt` is negative or the id is unknown. A
+/// negative `do_read` result is returned unchanged.
+///
+/// NOTE(review): plaintext beyond `cnt` bytes is dropped, not kept for the next call.
#[no_mangle]
-pub extern "C" fn tls_client_read(session: *const c_void, buf: * mut c_char, cnt: c_int) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsClient>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
+pub extern "C" fn tls_client_read(session_id: usize, buf: * mut c_char, cnt: c_int) -> c_int {
    if buf.is_null() {
        return -1;
    }
-    let session= unsafe { &mut *(session as *mut TlsClient) };
+    // Reject a negative count, which the `usize` casts below would turn into a huge length.
+    if cnt < 0 {
+        return -1;
+    }
+    rsgx_sfence();
-    let mut plaintext = Vec::new();
-    let mut result = session.do_read(&mut plaintext);
+    if let Some(session_ptr) = Sessions::get_session(session_id) {
+        let session= unsafe { &mut *session_ptr };
-    if result == -1 {
-        return result;
-    }
-    if cnt < result {
-        result = cnt;
-    }
+        let mut plaintext = Vec::new();
+        let mut result = session.do_read(&mut plaintext);
-    let raw_buf = unsafe { slice::from_raw_parts_mut(buf as * mut u8, result as usize) };
-    raw_buf.copy_from_slice(plaintext.as_slice());
-    result
-}
-
-#[no_mangle]
-pub extern "C" fn tls_client_write(session: *const c_void, buf: * const c_char, cnt: c_int) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsClient>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session= unsafe { &mut *(session as *mut TlsClient) };
-
-    // no buffer, just write_tls.
-    if buf.is_null() || cnt == 0 {
-        session.do_write();
-        0
-    } else {
-        let cnt = cnt as usize;
-        let plaintext = unsafe { slice::from_raw_parts(buf as * mut u8, cnt) };
-        let result = session.write(plaintext);
-
-        result
-    }
-}
-
-#[no_mangle]
-pub extern "C" fn tls_client_wants_read(session: *const c_void) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsClient>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session= unsafe { &mut *(session as *mut TlsClient) };
-    let result = session.tls_session.wants_read() as c_int;
-    result
-}
-
-#[no_mangle]
-pub extern "C" fn tls_client_wants_write(session: *const c_void) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsClient>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session= unsafe { &mut *(session as *mut TlsClient) };
-    let result = session.tls_session.wants_write() as c_int;
-    result
-}
-
-#[no_mangle]
-pub extern "C" fn tls_client_close(session: * const c_void) {
-    if !session.is_null() {
-
-        if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsClient>()) {
-            return;
+        // Treat any negative status as an error, and never let it reach the `usize` cast.
+        if result < 0 {
+            return result;
     }
-    rsgx_lfence();
+        if cnt < result {
+            result = cnt;
+        }
-        let _ = unsafe { Box::<TlsClient>::from_raw(session as *mut _) };
-    }
+        let raw_buf = unsafe { slice::from_raw_parts_mut(buf as * mut u8, result as usize) };
+        // `result` may have been clamped to `cnt`, so copy only that prefix.
+        // `copy_from_slice` panics when the two lengths differ.
+        // NOTE(review): assumes `do_read` returns the number of bytes in `plaintext`.
+        raw_buf.copy_from_slice(&plaintext[..result as usize]);
+        result
+    } else { -1 }
+}
+
+/// ECALL: hands plaintext to TLS session `session_id`, or flushes it.
+///
+/// With a null `buf` or `cnt == 0`, calls `do_write` and returns 0. Otherwise
+/// passes the `cnt` bytes at `buf` to `write` and returns its result. Returns -1
+/// for an unknown session id or a negative `cnt`.
+#[no_mangle]
+pub extern "C" fn tls_client_write(session_id: usize, buf: * const c_char, cnt: c_int) -> c_int {
+    if let Some(session_ptr) = Sessions::get_session(session_id) {
+        let session = unsafe { &mut *session_ptr };
+
+        // no buffer, just write_tls.
+        if buf.is_null() || cnt == 0 {
+            session.do_write();
+            0
+        } else if cnt < 0 {
+            // A negative count would become a huge slice length in the `usize` cast.
+            -1
+        } else {
+            rsgx_lfence();
+            let cnt = cnt as usize;
+            let plaintext = unsafe { slice::from_raw_parts(buf as * mut u8, cnt) };
+            session.write(plaintext)
+        }
+    } else { -1 }
+}
+
+/// ECALL: asks TLS session `session_id` whether it wants more input.
+///
+/// Returns `wants_read()` converted to `c_int`, or -1 for an unknown id.
+#[no_mangle]
+pub extern "C" fn tls_client_wants_read(session_id: usize) -> c_int {
+    match Sessions::get_session(session_id) {
+        Some(client_ptr) => {
+            let client = unsafe { &mut *client_ptr };
+            client.tls_session.wants_read() as c_int
+        },
+        None => -1,
+    }
+}
+
+/// ECALL: asks TLS session `session_id` whether it has output to send.
+///
+/// Returns `wants_write()` converted to `c_int`, or -1 for an unknown id.
+#[no_mangle]
+pub extern "C" fn tls_client_wants_write(session_id: usize) -> c_int {
+    match Sessions::get_session(session_id) {
+        Some(client_ptr) => {
+            let client = unsafe { &mut *client_ptr };
+            client.tls_session.wants_write() as c_int
+        },
+        None => -1,
+    }
+}
+
+/// ECALL: drops TLS session `session_id` via `Sessions::remove_session`.
+/// Unknown ids are ignored.
+#[no_mangle]
+pub extern "C" fn tls_client_close(session_id: usize) {
+    Sessions::remove_session(session_id);
}
diff --git a/samplecode/tls/tlsserver/app/src/main.rs b/samplecode/tls/tlsserver/app/src/main.rs
index 82ce503..1a52cbe 100644
--- a/samplecode/tls/tlsserver/app/src/main.rs
+++ b/samplecode/tls/tlsserver/app/src/main.rs
@@ -44,7 +44,6 @@
use std::path;
use std::net;
use std::str;
-use std::ptr;
use std::io;
use std::io::{Read, Write};
use std::collections::HashMap;
@@ -55,20 +54,20 @@
static ENCLAVE_TOKEN: &'static str = "enclave.token";
extern {
- fn tls_server_new(eid: sgx_enclave_id_t, retval: *mut * const c_void,
+ fn tls_server_new(eid: sgx_enclave_id_t, retval: *mut size_t,
fd: c_int, cert: *const c_char, key: *const c_char) -> sgx_status_t;
fn tls_server_read(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void, buf: *mut c_void, cnt: c_int) -> sgx_status_t;
+ session_id: size_t, buf: *mut c_void, cnt: c_int) -> sgx_status_t;
fn tls_server_write(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void, buf: *const c_void, cnt: c_int) -> sgx_status_t;
+ session_id: size_t, buf: *const c_void, cnt: c_int) -> sgx_status_t;
fn tls_server_wants_read(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void) -> sgx_status_t;
+ session_id: size_t) -> sgx_status_t;
fn tls_server_wants_write(eid: sgx_enclave_id_t, retval: *mut c_int,
- session: *const c_void) -> sgx_status_t;
+ session_id: size_t) -> sgx_status_t;
fn tls_server_close(eid: sgx_enclave_id_t,
- session: *const c_void) -> sgx_status_t;
+ session_id: size_t) -> sgx_status_t;
fn tls_server_send_close(edi: sgx_enclave_id_t,
- session: *const c_void) -> sgx_status_t;
+ session_id: size_t) -> sgx_status_t;
}
fn init_enclave() -> SgxResult<SgxEnclave> {
@@ -190,10 +189,10 @@
println!("Accepting new connection from {:?}", addr);
- let mut tlsserver: *const c_void = ptr::null();
+ let mut tlsserver_id: usize = 0xFFFF_FFFF_FFFF_FFFF;
let retval = unsafe {
tls_server_new(self.enclave_id,
- &mut tlsserver as *mut *const c_void,
+ &mut tlsserver_id,
socket.as_raw_fd(),
self.cert.as_bytes_with_nul().as_ptr() as * const c_char,
self.key.as_bytes_with_nul().as_ptr() as * const c_char)
@@ -204,7 +203,7 @@
return false;
}
- if tlsserver.is_null() {
+ if tlsserver_id == 0xFFFF_FFFF_FFFF_FFFF {
println!("[-] New enclave tlsserver error");
return false;
}
@@ -216,7 +215,7 @@
socket,
token,
mode,
- tlsserver));
+ tlsserver_id));
self.connections[&token].register(poll);
true
}
@@ -254,7 +253,7 @@
token: mio::Token,
closing: bool,
mode: ServerMode,
- tlsserver: *const c_void,
+ tlsserver_id: usize,
back: Option<TcpStream>,
sent_http_response: bool,
}
@@ -291,7 +290,7 @@
socket: TcpStream,
token: mio::Token,
mode: ServerMode,
- tlsserver: * const c_void)
+ tlsserver_id: usize)
-> Connection {
let back = open_back(&mode);
Connection {
@@ -300,7 +299,7 @@
token: token,
closing: false,
mode: mode,
- tlsserver: tlsserver,
+ tlsserver_id: tlsserver_id,
back: back,
sent_http_response: false,
}
@@ -311,7 +310,7 @@
let result = unsafe {
tls_server_read(self.enclave_id,
&mut retval,
- self.tlsserver,
+ self.tlsserver_id,
buf.as_ptr() as * mut c_void,
buf.len() as c_int)
};
@@ -329,7 +328,7 @@
let result = unsafe {
tls_server_write(self.enclave_id,
&mut retval,
- self.tlsserver,
+ self.tlsserver_id,
buf.as_ptr() as * const c_void,
buf.len() as c_int)
};
@@ -347,7 +346,7 @@
let result = unsafe {
tls_server_wants_read(self.enclave_id,
&mut retval,
- self.tlsserver)
+ self.tlsserver_id)
};
match result {
sgx_status_t::SGX_SUCCESS => {},
@@ -368,7 +367,7 @@
let result = unsafe {
tls_server_wants_write(self.enclave_id,
&mut retval,
- self.tlsserver)
+ self.tlsserver_id)
};
match result {
@@ -387,13 +386,13 @@
fn tls_close(&self) {
unsafe {
- tls_server_close(self.enclave_id, self.tlsserver)
+ tls_server_close(self.enclave_id, self.tlsserver_id)
};
}
fn send_close_notify(&self) {
unsafe {
- tls_server_send_close(self.enclave_id, self.tlsserver);
+ tls_server_send_close(self.enclave_id, self.tlsserver_id);
}
}
diff --git a/samplecode/tls/tlsserver/enclave/Cargo.toml b/samplecode/tls/tlsserver/enclave/Cargo.toml
index 99e7a0b..78e0835 100644
--- a/samplecode/tls/tlsserver/enclave/Cargo.toml
+++ b/samplecode/tls/tlsserver/enclave/Cargo.toml
@@ -18,6 +18,7 @@
[dependencies]
rustls = { git = "https://github.com/mesalock-linux/rustls", branch = "mesalock_sgx" }
webpki = { git = "https://github.com/mesalock-linux/webpki", branch = "mesalock_sgx" }
+lazy_static = { version = "1.4.0", default-features = false, features = ["spin_no_std"] }
[patch.'https://github.com/baidu/rust-sgx-sdk.git']
sgx_alloc = { path = "../../../../sgx_alloc" }
diff --git a/samplecode/tls/tlsserver/enclave/Enclave.edl b/samplecode/tls/tlsserver/enclave/Enclave.edl
index 1fd61c9..df94eff 100644
--- a/samplecode/tls/tlsserver/enclave/Enclave.edl
+++ b/samplecode/tls/tlsserver/enclave/Enclave.edl
@@ -39,13 +39,12 @@
trusted {
/* define ECALLs here. */
-
public size_t tls_server_new(int fd, [in, string]char* cert, [in, string] char* key);
- public int tls_server_read([user_check] void* session, [user_check] char* buf, int cnt);
- public int tls_server_write([user_check] void* session, [in, size=cnt] char* buf, int cnt);
- public int tls_server_wants_read([user_check] void* session);
- public int tls_server_wants_write([user_check] void* session);
- public void tls_server_close([user_check] void* session);
- public void tls_server_send_close([user_check] void* session);
+ public int tls_server_read(size_t session_id, [user_check] char* buf, int cnt);
+ public int tls_server_write(size_t session_id, [in, size=cnt] char* buf, int cnt);
+ public int tls_server_wants_read(size_t session_id);
+ public int tls_server_wants_write(size_t session_id);
+ public void tls_server_close(size_t session_id);
+ public void tls_server_send_close(size_t session_id);
};
};
diff --git a/samplecode/tls/tlsserver/enclave/src/lib.rs b/samplecode/tls/tlsserver/enclave/src/lib.rs
index 11ef8ed..c7d9b1e 100644
--- a/samplecode/tls/tlsserver/enclave/src/lib.rs
+++ b/samplecode/tls/tlsserver/enclave/src/lib.rs
@@ -38,9 +38,11 @@
#[macro_use]
extern crate sgx_tstd as std;
+#[macro_use]
+extern crate lazy_static;
+
use sgx_types::*;
-use sgx_trts::trts::{rsgx_raw_is_outside_enclave, rsgx_lfence};
-use std::mem;
+use sgx_trts::trts::{rsgx_raw_is_outside_enclave, rsgx_lfence, rsgx_sfence};
use std::untrusted::fs;
use std::io::BufReader;
@@ -48,13 +50,14 @@
use std::ffi::CStr;
use std::os::raw::c_char;
-use std::ptr;
use std::vec::Vec;
use std::boxed::Box;
use std::io::{Read, Write};
use std::slice;
-use std::sync::Arc;
+use std::sync::{Arc, SgxRwLock};
use std::net::TcpStream;
+use std::collections::HashMap;
+use std::sync::atomic::{AtomicUsize, AtomicPtr, Ordering};
extern crate webpki;
extern crate rustls;
@@ -65,6 +68,14 @@
tls_session: rustls::ServerSession,
}
+/// Source of session ids: every `Sessions::new_session` call takes the next value
+/// with `fetch_add`, so ids are never reused while the enclave is loaded.
+static GLOBAL_CONTEXT_COUNT: AtomicUsize = AtomicUsize::new(0);
+
+lazy_static! {
+    /// Registry of live server sessions keyed by session id. Each value holds a
+    /// pointer produced by `Box::into_raw` in `tls_server_new`. `AtomicPtr` is
+    /// used because a raw pointer is not `Sync` and so cannot sit in a static.
+    static ref GLOBAL_CONTEXTS: SgxRwLock<HashMap<usize, AtomicPtr<TlsServer>>> = {
+        SgxRwLock::new(HashMap::new())
+    };
+}
+
impl TlsServer {
fn new(fd: c_int, cfg: Arc<rustls::ServerConfig>) -> TlsServer {
TlsServer {
@@ -172,142 +183,150 @@
Arc::new(config)
}
+/// Registry that maps opaque session ids (the only handle the untrusted app
+/// sees) to `TlsServer` instances allocated inside the enclave.
+///
+/// NOTE(review): `get_session` hands out a raw pointer after releasing the read
+/// lock. A `remove_session` racing with another ECALL on the same id could
+/// therefore free the server while it is still in use. This is only sound if the
+/// app never closes a session concurrently with other calls on it.
+struct Sessions;
+
+impl Sessions {
+    /// Registers `svr_ptr` (from `Box::into_raw`) under a fresh id.
+    ///
+    /// Returns `None` if the registry lock cannot be acquired. In that case the
+    /// pointer is not stored and ownership stays with the caller.
+    fn new_session(svr_ptr : *mut TlsServer) -> Option<usize> {
+        match GLOBAL_CONTEXTS.write() {
+            Ok(mut gctxts) => {
+                let curr_id = GLOBAL_CONTEXT_COUNT.fetch_add(1, Ordering::SeqCst);
+                gctxts.insert(curr_id, AtomicPtr::new(svr_ptr));
+                Some(curr_id)
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed! {:?}", x);
+                None
+            },
+        }
+    }
+
+    /// Returns the pointer registered for `sess_id`, or `None` (after logging)
+    /// if the id is unknown or the lock cannot be acquired.
+    fn get_session(sess_id: size_t) -> Option<*mut TlsServer> {
+        match GLOBAL_CONTEXTS.read() {
+            Ok(gctxts) => {
+                match gctxts.get(&sess_id) {
+                    Some(s) => {
+                        Some(s.load(Ordering::SeqCst))
+                    },
+                    None => {
+                        println!("Global contexts cannot find session id = {}", sess_id);
+                        None
+                    }
+                }
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed on get_session! {:?}", x);
+                None
+            },
+        }
+    }
+
+    /// Unregisters `sess_id` and frees its `TlsServer`. Unknown ids are ignored.
+    fn remove_session(sess_id: size_t) {
+        match GLOBAL_CONTEXTS.write() {
+            Ok(mut gctxts) => {
+                // Take the entry out of the map first. The Box is then reclaimed
+                // exactly once, and no `&mut` to soon-to-be-freed memory is created.
+                if let Some(session_ptr) = gctxts.remove(&sess_id) {
+                    drop(unsafe { Box::<TlsServer>::from_raw(session_ptr.into_inner()) });
+                }
+            },
+            Err(x) => {
+                println!("Locking global context SgxRwLock failed on remove_session! {:?}", x);
+            },
+        }
+    }
+}
+
+/// ECALL: creates a `TlsServer` on socket `fd`. The server is configured by
+/// `make_config` from the `cert` and `key` paths and registered with `Sessions`.
+///
+/// Returns the new session id, or `0xFFFF_FFFF_FFFF_FFFF` on failure. The
+/// untrusted app compares the result against that sentinel.
#[no_mangle]
-pub extern "C" fn tls_server_new(fd: c_int, cert: * const c_char, key: * const c_char) -> *const c_void {
+pub extern "C" fn tls_server_new(fd: c_int, cert: * const c_char, key: * const c_char) -> usize {
+    // The EDL marks both strings `[in, string]`. Reject NULL before calling
+    // `CStr::from_ptr`, which requires a non-null pointer.
+    if cert.is_null() || key.is_null() {
+        return 0xFFFF_FFFF_FFFF_FFFF;
+    }
    let certfile = unsafe { CStr::from_ptr(cert).to_str() };
    if certfile.is_err() {
-        return ptr::null();
+        return 0xFFFF_FFFF_FFFF_FFFF;
    }
    let keyfile = unsafe { CStr::from_ptr(key).to_str() };
    if keyfile.is_err() {
-        return ptr::null();
+        return 0xFFFF_FFFF_FFFF_FFFF;
    }
    let config = make_config(certfile.unwrap(), keyfile.unwrap());
-    Box::into_raw(Box::new(TlsServer::new(fd, config))) as *const c_void
+    let p: *mut TlsServer = Box::into_raw(Box::new(TlsServer::new(fd, config)));
+    match Sessions::new_session(p) {
+        Some(s) => s,
+        None => {
+            // The registry did not take `p`, so reclaim it here instead of leaking it.
+            drop(unsafe { Box::from_raw(p) });
+            0xFFFF_FFFF_FFFF_FFFF
+        },
+    }
}
+/// ECALL: drives TLS session `session_id` for reading.
+///
+/// With a null `buf` or `cnt == 0`, only calls `do_read` and returns its result.
+/// Otherwise `buf` must be `cnt` bytes outside the enclave. Up to `cnt` bytes of
+/// plaintext from `read` are copied there and the copied length is returned.
+/// Returns -1 for an unknown id or an invalid `buf`/`cnt`. A negative `read`
+/// result is returned unchanged.
+///
+/// NOTE(review): plaintext beyond `cnt` bytes is dropped, not kept for the next call.
#[no_mangle]
-pub extern "C" fn tls_server_read(session: *const c_void, buf: * mut c_char, cnt: c_int) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session = unsafe { &mut *(session as *mut TlsServer) };
-
-    if buf.is_null() || cnt == 0 {
-        // just read_tls
-        session.do_read()
-    } else {
-
-        if !rsgx_raw_is_outside_enclave(buf as * const u8, cnt as usize) {
-            return -1;
+pub extern "C" fn tls_server_read(session_id: size_t, buf: * mut c_char, cnt: c_int) -> c_int {
+    if let Some(session_ptr) = Sessions::get_session(session_id) {
+        let session = unsafe { &mut *(session_ptr) };
+        if buf.is_null() || cnt == 0 {
+            // just read_tls
+            session.do_read()
+        } else {
+            // `buf` is `[user_check]` in the EDL, so validate it here.
+            if cnt < 0 || !rsgx_raw_is_outside_enclave(buf as * const u8, cnt as usize) {
+                return -1;
+            }
+            // Speculation barrier: do not run ahead of the check above.
+            rsgx_lfence();
+            // read plain buffer
+            let mut plaintext = Vec::new();
+            let mut result = session.read(&mut plaintext);
+            // Treat any negative status as an error, and never let it reach the `usize` cast.
+            if result < 0 {
+                return result;
+            }
+            if cnt < result {
+                result = cnt;
+            }
+            rsgx_sfence();
+            let raw_buf = unsafe { slice::from_raw_parts_mut(buf as * mut u8, result as usize) };
+            // `result` may have been clamped to `cnt`, so copy only that prefix.
+            // `copy_from_slice` panics when the two lengths differ.
+            // NOTE(review): assumes `read` returns the number of bytes in `plaintext`.
+            raw_buf.copy_from_slice(&plaintext[..result as usize]);
+            result
         }
+    } else { -1 }
+}
+
+/// ECALL: hands plaintext to TLS session `session_id`, or flushes it.
+///
+/// With a null `buf` or `cnt == 0`, calls `do_write` and returns 0. Otherwise
+/// passes the `cnt` bytes at `buf` to `write` and returns its result. Returns -1
+/// for an unknown session id or a negative `cnt`.
+#[no_mangle]
+pub extern "C" fn tls_server_write(session_id: usize, buf: * const c_char, cnt: c_int) -> c_int {
+    if let Some(session_ptr) = Sessions::get_session(session_id) {
+        let session = unsafe { &mut *(session_ptr) };
+
+        // no buffer, just write_tls.
+        if buf.is_null() || cnt == 0 {
+            session.do_write();
+            return 0;
+        }
+
+        // A negative count would become a huge slice length in the `usize` cast.
+        if cnt < 0 {
+            return -1;
+        }
+
        rsgx_lfence();
+        // cache buffer, waiting for next write_tls
+        let cnt = cnt as usize;
+        let plaintext = unsafe { slice::from_raw_parts(buf as * mut u8, cnt) };
+        let result = session.write(plaintext);
-        // read plain buffer
-        let mut plaintext = Vec::new();
-        let mut result = session.read(&mut plaintext);
-        if result == -1 {
-            return result;
-        }
-
-        if cnt < result {
-            result = cnt;
-        }
-        let raw_buf = unsafe { slice::from_raw_parts_mut(buf as * mut u8, result as usize) };
-        raw_buf.copy_from_slice(plaintext.as_slice());
         result
-    }
+    } else { -1 }
}
+/// ECALL: asks TLS session `session_id` whether it wants more input.
+///
+/// Returns `wants_read()` converted to `c_int`, or -1 for an unknown id.
#[no_mangle]
-pub extern "C" fn tls_server_write(session: *const c_void, buf: * const c_char, cnt: c_int) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session = unsafe { &mut *(session as *mut TlsServer) };
-
-    // no buffer, just write_tls.
-    if buf.is_null() || cnt == 0 {
-        session.do_write();
-        return 0;
-    }
-
-    // cache buffer, waitting for next write_tls
-    let cnt = cnt as usize;
-    let plaintext = unsafe { slice::from_raw_parts(buf as * mut u8, cnt) };
-    let result = session.write(plaintext);
-
-    result
+pub extern "C" fn tls_server_wants_read(session_id: usize) -> c_int {
+    match Sessions::get_session(session_id) {
+        Some(server_ptr) => {
+            let server = unsafe { &mut *server_ptr };
+            server.tls_session.wants_read() as c_int
+        },
+        None => -1,
+    }
}
+/// ECALL: asks TLS session `session_id` whether it has output to send.
+///
+/// Returns `wants_write()` converted to `c_int`, or -1 for an unknown id.
#[no_mangle]
-pub extern "C" fn tls_server_wants_read(session: *const c_void) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session = unsafe { &mut *(session as *mut TlsServer) };
-    let result = session.tls_session.wants_read() as c_int;
-    result
+pub extern "C" fn tls_server_wants_write(session_id: usize) -> c_int {
+    match Sessions::get_session(session_id) {
+        Some(server_ptr) => {
+            let server = unsafe { &mut *server_ptr };
+            server.tls_session.wants_write() as c_int
+        },
+        None => -1,
+    }
}
+/// ECALL: drops TLS session `session_id` via `Sessions::remove_session`.
+/// Unknown ids are ignored.
#[no_mangle]
-pub extern "C" fn tls_server_wants_write(session: *const c_void) -> c_int {
-    if session.is_null() {
-        return -1;
-    }
-
-    if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-        return -1;
-    }
-    rsgx_lfence();
-
-    let session = unsafe { &mut *(session as *mut TlsServer) };
-    let result = session.tls_session.wants_write() as c_int;
-    result
+pub extern "C" fn tls_server_close(session_id: usize) {
+    Sessions::remove_session(session_id);
}
+/// ECALL: calls `send_close_notify()` on the TLS state of session `session_id`.
+/// If `get_session` does not find the id, it logs it and nothing else happens.
#[no_mangle]
-pub extern "C" fn tls_server_close(session: * const c_void) {
-    if !session.is_null() {
-
-        if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-            return;
-        }
-        rsgx_lfence();
-
-        let _ = unsafe { Box::<TlsServer>::from_raw(session as *mut _) };
-    }
-}
-
-#[no_mangle]
-pub extern "C" fn tls_server_send_close(session: * const c_void) {
-    if !session.is_null() {
-
-        if rsgx_raw_is_outside_enclave(session as * const u8, mem::size_of::<TlsServer>()) {
-            return;
-        }
-        rsgx_lfence();
-
-        let session = unsafe { &mut *(session as *mut TlsServer) };
+pub extern "C" fn tls_server_send_close(session_id: usize) {
+    if let Some(session_ptr) = Sessions::get_session(session_id) {
+        let session = unsafe { &mut *session_ptr };
        session.tls_session.send_close_notify();
    }
}