| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| use std::collections::HashMap; |
| use std::path::{Path, PathBuf}; |
| use std::sync::mpsc; |
| use std::sync::{Arc, Mutex}; |
| use std::thread; |
| |
| use crate::task_file_manager::TaskFileManager; |
| use anyhow::Result; |
| use teaclave_proto::teaclave_common::{ExecutorCommand, ExecutorStatus}; |
| use teaclave_proto::teaclave_scheduler_service::*; |
| use teaclave_rpc::transport::{channel::Endpoint, Channel}; |
| use teaclave_types::*; |
| use teaclave_worker::Worker; |
| use uuid::Uuid; |
| |
/// Base directory where per-task working files are staged by the
/// `TaskFileManager` (one subdirectory per task id).
static WORKER_BASE_DIR: &str = "/tmp/teaclave_agent/";
| |
/// Executor service that heartbeats to the scheduler, pulls staged tasks,
/// runs them through a `Worker`, and reports results back.
#[derive(Clone)]
pub(crate) struct TeaclaveExecutionService {
    #[allow(dead_code)]
    worker: Arc<Worker>,
    /// RPC client used for heartbeats, task pulling, and status/result updates.
    scheduler_client: TeaclaveSchedulerClient<Channel>,
    /// Directory holding fusion (shared multi-party) data files.
    fusion_base: PathBuf,
    /// Random id identifying this executor instance to the scheduler.
    id: Uuid,
    /// Idle/Executing state, reported with every heartbeat.
    status: ExecutorStatus,
}
| |
impl TeaclaveExecutionService {
    /// Connect to the scheduler service endpoint and create an executor
    /// with a freshly generated random id, starting in the `Idle` state.
    pub(crate) async fn new(
        scheduler_service_endpoint: Endpoint,
        fusion_base: impl AsRef<Path>,
    ) -> Result<Self> {
        let channel = scheduler_service_endpoint.connect().await?;
        let scheduler_client = TeaclaveSchedulerClient::new_with_builtin_config(channel);

        Ok(TeaclaveExecutionService {
            worker: Arc::new(Worker::default()),
            scheduler_client,
            fusion_base: fusion_base.as_ref().to_owned(),
            id: Uuid::new_v4(),
            status: ExecutorStatus::Idle,
        })
    }

    /// Main executor loop: every 3 seconds send a heartbeat, pull a new task
    /// when idle, run it on a dedicated OS thread, then report the result
    /// back to the scheduler once the thread finishes.
    ///
    /// Loops forever; returns `Err` only when the scheduler commands a stop
    /// or a heartbeat/status update fails.
    pub(crate) async fn start(&mut self) -> Result<()> {
        // Channel through which the worker thread sends the task result back
        // to this loop (polled non-blockingly below).
        let (tx, rx) = mpsc::channel();
        // The task currently being executed, shared with the worker thread.
        let mut current_task: Arc<Option<StagedTask>> = Arc::new(None);
        let mut task_handle: Option<thread::JoinHandle<()>> = None;

        loop {
            // NOTE(review): blocking sleep inside an async fn — this stalls
            // the runtime thread for the whole interval. Presumably tolerable
            // for this single-task executor loop; confirm against runtime.
            std::thread::sleep(std::time::Duration::from_secs(3));

            match self.heartbeat().await {
                Ok(ExecutorCommand::Stop) => {
                    log::info!("Executor {} is stopped", self.id);
                    return Err(anyhow::anyhow!("EnclaveForceTermination"));
                }
                // Accept a new task only when not already executing one.
                Ok(ExecutorCommand::NewTask) if self.status == ExecutorStatus::Idle => {
                    match self.pull_task().await {
                        Ok(task) => {
                            self.status = ExecutorStatus::Executing;
                            self.update_task_status(&task.task_id, TaskStatus::Running)
                                .await?;
                            let tx_task = tx.clone();
                            let fusion_base = self.fusion_base.clone();
                            current_task = Arc::new(Some(task));
                            let task_copy = current_task.clone();
                            // Execute on a plain OS thread so the heartbeat
                            // loop keeps running while the task is in flight.
                            let handle = thread::spawn(move || {
                                let result =
                                    invoke_task(task_copy.as_ref().as_ref().unwrap(), &fusion_base);
                                tx_task.send(result).unwrap();
                            });
                            task_handle = Some(handle);
                        }
                        Err(e) => {
                            // Pull failure is non-fatal: stay idle and retry
                            // on the next heartbeat cycle.
                            log::error!("Executor {} failed to pull task: {}", self.id, e);
                        }
                    };
                }
                Err(e) => {
                    log::error!("Executor {} failed to heartbeat: {}", self.id, e);
                    return Err(e);
                }
                _ => {}
            }

            // Non-blocking poll for a finished task.
            match rx.try_recv() {
                Ok(result) => {
                    // Receiving a result implies a task was running, so the
                    // unwraps on `current_task` and `task_handle` below hold.
                    let task_unwrapped = current_task.as_ref().as_ref().unwrap();
                    match result {
                        Ok(_) => log::debug!(
                            "InvokeTask: {:?}, {:?}, success",
                            task_unwrapped.task_id,
                            task_unwrapped.function_id
                        ),
                        Err(_) => log::debug!(
                            "InvokeTask: {:?}, {:?}, failure",
                            task_unwrapped.task_id,
                            task_unwrapped.function_id
                        ),
                    }
                    log::debug!("InvokeTask result: {:?}", result);
                    let task_copy = current_task.clone();
                    match self
                        .update_task_result(&task_copy.as_ref().as_ref().unwrap().task_id, result)
                        .await
                    {
                        Ok(_) => (),
                        Err(e) => {
                            // Keep Executing state and retry the upload on
                            // the next loop iteration.
                            log::error!("UpdateResult Error: {:?}", e);
                            continue;
                        }
                    }
                    // Reset to idle so the next NewTask command is accepted.
                    current_task = Arc::new(None);
                    task_handle.unwrap().join().unwrap();
                    task_handle = None;
                    self.status = ExecutorStatus::Idle;
                }
                Err(mpsc::TryRecvError::Disconnected) => {
                    log::error!(
                        "Executor {} failed to receive, sender disconnected",
                        self.id
                    );
                }
                // received nothing
                Err(_) => {}
            }
        }
    }

    /// Ask the scheduler for a staged task assigned to this executor and
    /// deserialize it from the response payload.
    async fn pull_task(&mut self) -> Result<StagedTask> {
        let request = PullTaskRequest {
            executor_id: self.id.to_string(),
        };
        let response = self.scheduler_client.pull_task(request).await?.into_inner();

        log::debug!("pull_stask response: {:?}", response);
        let staged_task = StagedTask::from_slice(&response.staged_task)?;
        Ok(staged_task)
    }

    /// Report this executor's id and current status to the scheduler and
    /// return the command it replies with (e.g. `NewTask`, `Stop`).
    async fn heartbeat(&mut self) -> Result<ExecutorCommand> {
        let request = HeartbeatRequest::new(self.id, self.status);
        let response = self.scheduler_client.heartbeat(request).await?.into_inner();

        log::debug!("heartbeat_with_result response: {:?}", response);
        response.command.try_into()
    }

    /// Upload the final outcome (outputs or error) of a finished task.
    async fn update_task_result(
        &mut self,
        task_id: &Uuid,
        task_result: Result<TaskOutputs>,
    ) -> Result<()> {
        let request = UpdateTaskResultRequest::new(*task_id, task_result);

        let _response = self.scheduler_client.update_task_result(request).await?;

        Ok(())
    }

    /// Notify the scheduler of a task's status transition (e.g. `Running`).
    async fn update_task_status(&mut self, task_id: &Uuid, task_status: TaskStatus) -> Result<()> {
        let request = UpdateTaskStatusRequest::new(task_id.to_owned(), task_status);
        let _response = self.scheduler_client.update_task_status(request).await?;

        Ok(())
    }
}
| |
| fn invoke_task(task: &StagedTask, fusion_base: &PathBuf) -> Result<TaskOutputs> { |
| let save_log = task |
| .function_arguments |
| .get("save_log") |
| .ok() |
| .and_then(|v| v.as_str().and_then(|s| s.parse().ok())) |
| .unwrap_or(false); |
| let log_arc = Arc::new(Mutex::new(Vec::<String>::new())); |
| |
| if save_log { |
| let log_arc = Arc::into_raw(log_arc.clone()); |
| log::info!(buffer = log_arc.expose_addr(); ""); |
| } |
| |
| let file_mgr = TaskFileManager::new( |
| WORKER_BASE_DIR, |
| fusion_base, |
| &task.task_id, |
| &task.input_data, |
| &task.output_data, |
| )?; |
| let invocation = prepare_task(task, &file_mgr)?; |
| |
| log::debug!("Invoke function: {:?}", invocation); |
| let worker = Worker::default(); |
| let summary = worker.invoke_function(invocation)?; |
| |
| let outputs_tag = finalize_task(&file_mgr)?; |
| if save_log { |
| log::info!(buffer = 0; ""); |
| } |
| |
| let log = Arc::try_unwrap(log_arc) |
| .map_err(|_| anyhow::anyhow!("log buffer is referenced more than once"))? |
| .into_inner()?; |
| let task_outputs = TaskOutputs::new(summary.as_bytes(), outputs_tag, log); |
| |
| Ok(task_outputs) |
| } |
| |
| fn prepare_task(task: &StagedTask, file_mgr: &TaskFileManager) -> Result<StagedFunction> { |
| let input_files = file_mgr.prepare_staged_inputs()?; |
| let output_files = file_mgr.prepare_staged_outputs()?; |
| |
| let staged_function = StagedFunctionBuilder::new() |
| .executor_type(task.executor_type) |
| .executor(task.executor) |
| .name(&task.function_name) |
| .arguments(task.function_arguments.clone()) |
| .payload(task.function_payload.clone()) |
| .input_files(input_files) |
| .output_files(output_files) |
| .runtime_name("default") |
| .build(); |
| Ok(staged_function) |
| } |
| |
| fn finalize_task(file_mgr: &TaskFileManager) -> Result<HashMap<String, FileAuthTag>> { |
| file_mgr.upload_outputs() |
| } |
| |
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
    use super::*;
    use serde_json::json;
    use std::format;
    use teaclave_crypto::*;
    use url::Url;
    use uuid::Uuid;

    /// End-to-end smoke test of the builtin echo function: stage, invoke,
    /// finalize, and check that the summary equals the input message.
    /// NOTE(review): run by the enclave test harness (not a `#[test]` fn).
    pub fn test_invoke_echo() {
        let task_id = Uuid::new_v4();
        let function_arguments =
            FunctionArguments::from_json(json!({"message": "Hello, Teaclave!"})).unwrap();
        let staged_task = StagedTaskBuilder::new()
            .task_id(task_id)
            .executor(Executor::Builtin)
            .function_name("builtin-echo")
            .function_arguments(function_arguments)
            .build();

        let file_mgr = TaskFileManager::new(
            WORKER_BASE_DIR,
            "/tmp/fusion_base",
            &staged_task.task_id,
            &staged_task.input_data,
            &staged_task.output_data,
        )
        .unwrap();
        let invocation = prepare_task(&staged_task, &file_mgr).unwrap();

        let worker = Worker::default();
        let result = worker.invoke_function(invocation);
        // Upload outputs only on success; echo has no outputs but this
        // mirrors the production flow in `invoke_task`.
        if result.is_ok() {
            finalize_task(&file_mgr).unwrap();
        }
        assert_eq!(result.unwrap(), "Hello, Teaclave!");
    }

    /// End-to-end test of the builtin GBDT training function using the
    /// encrypted fixture data installed under TEACLAVE_TEST_INSTALL_DIR.
    pub fn test_invoke_gbdt_train() {
        let task_id = Uuid::new_v4();
        let function_arguments = FunctionArguments::from_json(json!({
            "feature_size": 4,
            "max_depth": 4,
            "iterations": 100,
            "shrinkage": 0.1,
            "feature_sample_ratio": 1.0,
            "data_sample_ratio": 1.0,
            "min_leaf_size": 1,
            "loss": "LAD",
            "training_optimization_level": 2,
        }))
        .unwrap();
        let fixture_dir = format!(
            "file:///{}/fixtures/functions/gbdt_training",
            env!("TEACLAVE_TEST_INSTALL_DIR")
        );
        let input_url = Url::parse(&format!("{}/train.enc", fixture_dir)).unwrap();
        // Unique output name per run so repeated test runs don't collide.
        let output_url = Url::parse(&format!("{}/model-{}.enc.out", fixture_dir, task_id)).unwrap();
        // Fixed all-zero test key; the cmac below is the fixture's known tag.
        let crypto = TeaclaveFile128Key::new(&[0; 16]).unwrap();
        let input_cmac = FileAuthTag::from_hex("860030495909b84864b991865e9ad94f").unwrap();
        let training_input_data = FunctionInputFile::new(input_url, input_cmac, crypto);
        let model_output_data = FunctionOutputFile::new(output_url, crypto);

        let input_data = hashmap!("training_data" => training_input_data);
        let output_data = hashmap!("trained_model" => model_output_data);

        let staged_task = StagedTaskBuilder::new()
            .task_id(task_id)
            .executor(Executor::Builtin)
            .function_name("builtin-gbdt-train")
            .function_arguments(function_arguments)
            .input_data(input_data)
            .output_data(output_data)
            .build();

        let file_mgr = TaskFileManager::new(
            WORKER_BASE_DIR,
            "/tmp/fusion_base",
            &staged_task.task_id,
            &staged_task.input_data,
            &staged_task.output_data,
        )
        .unwrap();
        let invocation = prepare_task(&staged_task, &file_mgr).unwrap();

        let worker = Worker::default();
        let result = worker.invoke_function(invocation);
        if result.is_ok() {
            finalize_task(&file_mgr).unwrap();
        }
        log::debug!("summary: {:?}", result);
        assert!(result.is_ok());
    }
}