// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::collections::HashMap;
use std::path::{Path, PathBuf};
use std::sync::mpsc;
use std::sync::{Arc, Mutex};
use std::thread;
use crate::task_file_manager::TaskFileManager;
use anyhow::Result;
use teaclave_proto::teaclave_common::{ExecutorCommand, ExecutorStatus};
use teaclave_proto::teaclave_scheduler_service::*;
use teaclave_rpc::transport::{channel::Endpoint, Channel};
use teaclave_types::*;
use teaclave_worker::Worker;
use uuid::Uuid;
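// Base directory under which TaskFileManager stages per-task input/output files.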
static WORKER_BASE_DIR: &str = "/tmp/teaclave_agent/";
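// Execution service running inside the enclave: it keeps a client connection
// to the scheduler, periodically reports its status, pulls staged tasks, and
// executes them with a worker.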
#[derive(Clone)]
pub(crate) struct TeaclaveExecutionService {
#[allow(dead_code)]
worker: Arc<Worker>,
scheduler_client: TeaclaveSchedulerClient<Channel>,
fusion_base: PathBuf,
id: Uuid,
status: ExecutorStatus,
}
impl TeaclaveExecutionService {
pub(crate) async fn new(
scheduler_service_endpoint: Endpoint,
fusion_base: impl AsRef<Path>,
) -> Result<Self> {
let channel = scheduler_service_endpoint.connect().await?;
let scheduler_client = TeaclaveSchedulerClient::new_with_builtin_config(channel);
Ok(TeaclaveExecutionService {
worker: Arc::new(Worker::default()),
scheduler_client,
fusion_base: fusion_base.as_ref().to_owned(),
id: Uuid::new_v4(),
status: ExecutorStatus::Idle,
})
}
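// Main executor loop: every 3 seconds send a heartbeat, and depending on the
// scheduler's command either stop, stay idle, or pull a new task and run it on
// a dedicated thread. Results come back over the mpsc channel and are reported
// to the scheduler before the executor returns to Idle.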
pub(crate) async fn start(&mut self) -> Result<()> {
let (tx, rx) = mpsc::channel();
let mut current_task: Arc<Option<StagedTask>> = Arc::new(None);
let mut task_handle: Option<thread::JoinHandle<()>> = None;
loop {
std::thread::sleep(std::time::Duration::from_secs(3));
match self.heartbeat().await {
Ok(ExecutorCommand::Stop) => {
log::info!("Executor {} is stopped", self.id);
return Err(anyhow::anyhow!("EnclaveForceTermination"));
}
Ok(ExecutorCommand::NewTask) if self.status == ExecutorStatus::Idle => {
match self.pull_task().await {
Ok(task) => {
self.status = ExecutorStatus::Executing;
self.update_task_status(&task.task_id, TaskStatus::Running)
.await?;
let tx_task = tx.clone();
let fusion_base = self.fusion_base.clone();
current_task = Arc::new(Some(task));
let task_copy = current_task.clone();
let handle = thread::spawn(move || {
let result =
invoke_task(task_copy.as_ref().as_ref().unwrap(), &fusion_base);
tx_task.send(result).unwrap();
});
task_handle = Some(handle);
}
Err(e) => {
log::error!("Executor {} failed to pull task: {}", self.id, e);
}
};
}
Err(e) => {
log::error!("Executor {} failed to heartbeat: {}", self.id, e);
return Err(e);
}
_ => {}
}
match rx.try_recv() {
Ok(result) => {
let task_unwrapped = current_task.as_ref().as_ref().unwrap();
match result {
Ok(_) => log::debug!(
"InvokeTask: {:?}, {:?}, success",
task_unwrapped.task_id,
task_unwrapped.function_id
),
Err(_) => log::debug!(
"InvokeTask: {:?}, {:?}, failure",
task_unwrapped.task_id,
task_unwrapped.function_id
),
}
log::debug!("InvokeTask result: {:?}", result);
let task_copy = current_task.clone();
match self
.update_task_result(&task_copy.as_ref().as_ref().unwrap().task_id, result)
.await
{
Ok(_) => (),
Err(e) => {
log::error!("UpdateResult Error: {:?}", e);
continue;
}
}
current_task = Arc::new(None);
task_handle.unwrap().join().unwrap();
task_handle = None;
self.status = ExecutorStatus::Idle;
}
Err(mpsc::TryRecvError::Disconnected) => {
log::error!(
"Executor {} failed to receive, sender disconnected",
self.id
);
}
// received nothing
Err(_) => {}
}
}
}
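// Request a staged task assigned to this executor from the scheduler and
// deserialize it from the response.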
async fn pull_task(&mut self) -> Result<StagedTask> {
let request = PullTaskRequest {
executor_id: self.id.to_string(),
};
let response = self.scheduler_client.pull_task(request).await?.into_inner();
log::debug!("pull_stask response: {:?}", response);
let staged_task = StagedTask::from_slice(&response.staged_task)?;
Ok(staged_task)
}
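// Report the executor's current status and receive the scheduler's command.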
async fn heartbeat(&mut self) -> Result<ExecutorCommand> {
let request = HeartbeatRequest::new(self.id, self.status);
let response = self.scheduler_client.heartbeat(request).await?.into_inner();
log::debug!("heartbeat_with_result response: {:?}", response);
response.command.try_into()
}
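// Send the task's final result (outputs or error) back to the scheduler.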
async fn update_task_result(
&mut self,
task_id: &Uuid,
task_result: Result<TaskOutputs>,
) -> Result<()> {
let request = UpdateTaskResultRequest::new(*task_id, task_result);
let _response = self.scheduler_client.update_task_result(request).await?;
Ok(())
}
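// Notify the scheduler of a task status change (e.g. Running).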
async fn update_task_status(&mut self, task_id: &Uuid, task_status: TaskStatus) -> Result<()> {
let request = UpdateTaskStatusRequest::new(task_id.to_owned(), task_status);
let _response = self.scheduler_client.update_task_status(request).await?;
Ok(())
}
}
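// Execute one staged task: stage its input/output files, run the function in a
// fresh worker, upload the outputs, and return the summary, output tags, and
// optional log as TaskOutputs. Runs on the thread spawned in `start`.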
fn invoke_task(task: &StagedTask, fusion_base: &Path) -> Result<TaskOutputs> {
let save_log = task
.function_arguments
.get("save_log")
.ok()
.and_then(|v| v.as_str().and_then(|s| s.parse().ok()))
.unwrap_or(false);
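// Optional per-task log buffer. When `save_log` is set, the buffer's address is
// exposed to the logging layer through the `buffer` key so task log lines can be
// appended to it; logging `buffer = 0` after execution detaches it again.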
let log_arc = Arc::new(Mutex::new(Vec::<String>::new()));
if save_log {
let log_arc = Arc::into_raw(log_arc.clone());
log::info!(buffer = log_arc.expose_addr(); "");
}
let file_mgr = TaskFileManager::new(
WORKER_BASE_DIR,
fusion_base,
&task.task_id,
&task.input_data,
&task.output_data,
)?;
let invocation = prepare_task(task, &file_mgr)?;
log::debug!("Invoke function: {:?}", invocation);
let worker = Worker::default();
let summary = worker.invoke_function(invocation)?;
let outputs_tag = finalize_task(&file_mgr)?;
if save_log {
log::info!(buffer = 0; "");
}
let log = Arc::try_unwrap(log_arc)
.map_err(|_| anyhow::anyhow!("log buffer is referenced more than once"))?
.into_inner()?;
let task_outputs = TaskOutputs::new(summary.as_bytes(), outputs_tag, log);
Ok(task_outputs)
}
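// Stage input/output files locally and build the StagedFunction for the worker.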
fn prepare_task(task: &StagedTask, file_mgr: &TaskFileManager) -> Result<StagedFunction> {
let input_files = file_mgr.prepare_staged_inputs()?;
let output_files = file_mgr.prepare_staged_outputs()?;
let staged_function = StagedFunctionBuilder::new()
.executor_type(task.executor_type)
.executor(task.executor)
.name(&task.function_name)
.arguments(task.function_arguments.clone())
.payload(task.function_payload.clone())
.input_files(input_files)
.output_files(output_files)
.runtime_name("default")
.build();
Ok(staged_function)
}
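// Upload output files and return their authentication tags keyed by output name.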
fn finalize_task(file_mgr: &TaskFileManager) -> Result<HashMap<String, FileAuthTag>> {
file_mgr.upload_outputs()
}
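// In-enclave unit tests, compiled only with the `enclave_unit_test` feature.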
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
use serde_json::json;
use std::format;
use teaclave_crypto::*;
use url::Url;
use uuid::Uuid;
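// Runs the builtin echo function end-to-end through the staging pipeline and
// checks that the input message is echoed back.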
pub fn test_invoke_echo() {
let task_id = Uuid::new_v4();
let function_arguments =
FunctionArguments::from_json(json!({"message": "Hello, Teaclave!"})).unwrap();
let staged_task = StagedTaskBuilder::new()
.task_id(task_id)
.executor(Executor::Builtin)
.function_name("builtin-echo")
.function_arguments(function_arguments)
.build();
let file_mgr = TaskFileManager::new(
WORKER_BASE_DIR,
"/tmp/fusion_base",
&staged_task.task_id,
&staged_task.input_data,
&staged_task.output_data,
)
.unwrap();
let invocation = prepare_task(&staged_task, &file_mgr).unwrap();
let worker = Worker::default();
let result = worker.invoke_function(invocation);
if result.is_ok() {
finalize_task(&file_mgr).unwrap();
}
assert_eq!(result.unwrap(), "Hello, Teaclave!");
}
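// Trains a GBDT model on an encrypted fixture dataset and writes the encrypted
// model next to the fixtures; only checks that the invocation succeeds.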
pub fn test_invoke_gbdt_train() {
let task_id = Uuid::new_v4();
let function_arguments = FunctionArguments::from_json(json!({
"feature_size": 4,
"max_depth": 4,
"iterations": 100,
"shrinkage": 0.1,
"feature_sample_ratio": 1.0,
"data_sample_ratio": 1.0,
"min_leaf_size": 1,
"loss": "LAD",
"training_optimization_level": 2,
}))
.unwrap();
let fixture_dir = format!(
"file:///{}/fixtures/functions/gbdt_training",
env!("TEACLAVE_TEST_INSTALL_DIR")
);
let input_url = Url::parse(&format!("{}/train.enc", fixture_dir)).unwrap();
let output_url = Url::parse(&format!("{}/model-{}.enc.out", fixture_dir, task_id)).unwrap();
let crypto = TeaclaveFile128Key::new(&[0; 16]).unwrap();
let input_cmac = FileAuthTag::from_hex("860030495909b84864b991865e9ad94f").unwrap();
let training_input_data = FunctionInputFile::new(input_url, input_cmac, crypto);
let model_output_data = FunctionOutputFile::new(output_url, crypto);
let input_data = hashmap!("training_data" => training_input_data);
let output_data = hashmap!("trained_model" => model_output_data);
let staged_task = StagedTaskBuilder::new()
.task_id(task_id)
.executor(Executor::Builtin)
.function_name("builtin-gbdt-train")
.function_arguments(function_arguments)
.input_data(input_data)
.output_data(output_data)
.build();
let file_mgr = TaskFileManager::new(
WORKER_BASE_DIR,
"/tmp/fusion_base",
&staged_task.task_id,
&staged_task.input_data,
&staged_task.output_data,
)
.unwrap();
let invocation = prepare_task(&staged_task, &file_mgr).unwrap();
let worker = Worker::default();
let result = worker.invoke_function(invocation);
if result.is_ok() {
finalize_task(&file_mgr).unwrap();
}
log::debug!("summary: {:?}", result);
assert!(result.is_ok());
}
}