blob: 2e1bf2e9883140940e2bd9f2489002056e97e283 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use futures::future::join_all;
use futures::TryFutureExt;
use tokio::io::AsyncWriteExt;
use tokio_util::codec;
use url::Url;
use std::path::{Component, Path, PathBuf};
use teaclave_types::{FileAgentRequest, HandleFileCommand, HandleFileInfo};
/// Fetches `presigned_url` with an HTTP GET and streams the response body,
/// chunk by chunk, into a newly created file at `dest`.
///
/// # Errors
///
/// Returns an error if the request fails, if the response carries an error
/// status, if `dest` cannot be created, or if reading a chunk or writing it
/// to the file fails.
async fn download_remote_input_to_file(
    presigned_url: Url,
    dest: impl AsRef<std::path::Path>,
) -> anyhow::Result<()> {
    let response = reqwest::get(presigned_url.as_str()).await?;
    let mut body = response.error_for_status()?;
    let mut sink = tokio::fs::File::create(dest).await?;

    loop {
        match body.chunk().await? {
            Some(bytes) => sink.write_all(&bytes).await?,
            None => break,
        }
    }

    // Flush explicitly before returning instead of relying on drop to push
    // out any pending writes of the async file handle.
    sink.flush().await?;
    Ok(())
}
async fn copy_file(
src: impl AsRef<std::path::Path>,
dst: impl AsRef<std::path::Path>,
) -> anyhow::Result<()> {
tokio::fs::copy(src, dst).await?;
Ok(())
}
/// Uploads the local file `src` to `presigned_url` with an HTTP `PUT`.
///
/// The file is streamed as the request body rather than read into memory
/// first. The `Content-Length` header is set from the file's metadata and the
/// `Content-Type` header is set to `application/x-binary`.
///
/// # Errors
///
/// Returns an error if the file's metadata cannot be read, if the request
/// cannot be sent, or if the server answers with a non-success (non-2xx)
/// status; in the last case the error message is the status itself.
async fn upload_output_file_to_remote(
    src: impl AsRef<std::path::Path>,
    presigned_url: Url,
) -> anyhow::Result<()> {
    // Own the path so the streamed request body does not borrow from `src`.
    let path = src.as_ref().to_path_buf();

    // Use the async metadata call so the executor thread is not blocked by a
    // synchronous filesystem call inside this async function.
    let file_len = tokio::fs::metadata(&path).await?.len();

    let stream = tokio::fs::File::open(path)
        .map_ok(|file| codec::FramedRead::new(file, codec::BytesCodec::new()))
        .try_flatten_stream();
    let body = reqwest::Body::wrap_stream(stream);

    let client = reqwest::Client::new();
    let res = client
        .put(presigned_url.as_str())
        .header(reqwest::header::CONTENT_TYPE, "application/x-binary")
        .header(reqwest::header::CONTENT_LENGTH, file_len.to_string())
        .body(body)
        .send()
        .await?;

    // A successful PUT may legitimately be answered with 200, 201 or 204, so
    // accept the whole 2xx class instead of only `200 OK`.
    let status = res.status();
    anyhow::ensure!(status.is_success(), "{}", status);
    Ok(())
}
async fn handle_download(
info: HandleFileInfo,
fusion_base: impl AsRef<Path>,
) -> anyhow::Result<()> {
anyhow::ensure!(
!info.local.exists(),
"[Download] Dest local file: {:?} already exists.",
info.local
);
let dst = info.local;
let remote = info.remote;
match remote.scheme() {
"https" | "http" => {
download_remote_input_to_file(remote, dst).await?;
}
"file" => {
let src = remote
.to_file_path()
.map_err(|e| anyhow::anyhow!("Cannot convert file:// to path: {:?}", e))?;
anyhow::ensure!(
src.exists(),
"[Download] Src local file: {:?} doesn't exist.",
src
);
copy_file(src, dst).await?;
}
"fusion" => {
let path = remote
.to_file_path()
.map_err(|e| anyhow::anyhow!("Cannot convert fusion:// to path: {:?}", e))?;
let components = path.components().collect::<Vec<_>>();
anyhow::ensure!(
(components[0] == Component::RootDir)
&& (components[1] == Component::Normal("TEACLAVE_FUSION_BASE".as_ref())),
"[Download] Fusion data format error: {:?}",
components
);
let relative_path: PathBuf = components[2..].iter().collect();
let src: PathBuf = fusion_base.as_ref().join(relative_path);
anyhow::ensure!(
src.exists(),
"[Download] Src local file: {:?} doesn't exist.",
src
);
copy_file(src, dst).await?;
}
"data" => {
let data = remote.path().split(',').collect::<Vec<&str>>();
if data.len() == 2 && data[0] == "text/plain;base64" {
let bytes = base64::decode(data[1])?;
tokio::fs::write(dst, bytes).await?;
} else {
anyhow::bail!("Scheme format not supported")
}
}
_ => anyhow::bail!("Scheme not supported"),
}
Ok(())
}
async fn handle_upload(info: HandleFileInfo, fusion_base: impl AsRef<Path>) -> anyhow::Result<()> {
anyhow::ensure!(
info.local.exists(),
"[Upload] Src local file: {:?} doesn't exist.",
info.local
);
let src = info.local;
match info.remote.scheme() {
"https" | "http" => {
upload_output_file_to_remote(src, info.remote).await?;
}
"file" => {
let dst = info
.remote
.to_file_path()
.map_err(|e| anyhow::anyhow!("Cannot convert to path: {:?}", e))?;
anyhow::ensure!(!dst.exists(), "[Upload] Dest local file: {:?} exist.", dst);
copy_file(src, dst).await?;
}
"fusion" => {
let path = info
.remote
.to_file_path()
.map_err(|e| anyhow::anyhow!("Cannot convert fusion:// to path: {:?}", e))?;
let components = path.components().collect::<Vec<_>>();
anyhow::ensure!(
(components[0] == Component::RootDir)
&& (components[1] == Component::Normal("TEACLAVE_FUSION_BASE".as_ref())),
"[Upload] Fusion data format error: {:?}",
components
);
let relative_path: PathBuf = components[2..].iter().collect();
let dst: PathBuf = fusion_base.as_ref().join(relative_path);
anyhow::ensure!(
!dst.exists(),
"[Upload] Dest fusion file: {:?} exists.",
dst
);
copy_file(src, dst).await?;
}
_ => anyhow::bail!("Scheme not supported"),
}
Ok(())
}
fn handle_file_request(bytes: &[u8]) -> anyhow::Result<()> {
let req: FileAgentRequest = serde_json::from_slice(bytes)?;
let results = tokio::runtime::Builder::new()
.threaded_scheduler()
.enable_all()
.build()?
.block_on(async {
let fusion_base = req.fusion_base.clone();
match req.cmd {
HandleFileCommand::Download => {
let futures: Vec<_> = req
.info
.into_iter()
.map(|info| {
let fusion_base = fusion_base.clone();
tokio::spawn(async { handle_download(info, fusion_base).await })
})
.collect();
join_all(futures).await
}
HandleFileCommand::Upload => {
let futures: Vec<_> = req
.info
.into_iter()
.map(|info| {
let fusion_base = fusion_base.clone();
tokio::spawn(async { handle_upload(info, fusion_base).await })
})
.collect();
join_all(futures).await
}
}
});
let (task_results, errs): (Vec<_>, Vec<_>) = results.into_iter().partition(Result::is_ok);
debug!("{:?}, errs: {:?}", task_results, errs);
if !errs.is_empty() {
anyhow::bail!("Spawned task join error!");
}
anyhow::ensure!(
task_results.iter().all(|x| x.as_ref().unwrap().is_ok()),
format!("Some handle file task failed {:?}", task_results)
);
Ok(())
}
#[no_mangle]
#[allow(clippy::not_unsafe_ptr_arg_deref)]
pub extern "C" fn ocall_handle_file_request(in_buf: *const u8, in_len: u32) -> u32 {
let input_buf: &[u8] = unsafe { std::slice::from_raw_parts(in_buf, in_len as usize) };
match handle_file_request(input_buf) {
Ok(_) => 0,
Err(e) => {
log::error!("error: {:?}. in_buf: {:?}", e, in_buf);
1
}
}
}
// Tests for the file agent.
//
// The HTTP tests talk to `http://localhost:6789/fixtures/...`, so they
// presumably need a fixture server listening there. All tests use fixed paths
// under `/tmp`. Because the handlers refuse to overwrite existing
// destinations, each test removes leftovers of an earlier (possibly failed)
// run before it starts; otherwise a single failure would make every later run
// fail as well.
#[cfg(test)]
mod tests {
    use super::*;
    use std::io::Write;
    use std::path::PathBuf;
    use url::Url;

    /// Checks how `Url` exposes and percent-encodes `file://` paths.
    #[test]
    fn test_file_url() {
        let url = Url::parse("file:///tmp/abc.txt").unwrap();
        assert_eq!(url.scheme(), "file");
        assert_eq!(url.host(), None);
        assert_eq!(url.path(), "/tmp/abc.txt");

        let url = Url::parse("file:///countries/việt nam").unwrap();
        assert_eq!(url.path(), "/countries/vi%E1%BB%87t%20nam");
        let file_path = url.to_file_path().unwrap();
        assert_eq!(file_path, PathBuf::from("/countries/việt nam"));
    }

    /// Downloads a single file over HTTP.
    #[test]
    fn test_get_single_file() {
        let s = "http://localhost:6789/fixtures/functions/mesapy/input.txt";
        let url = Url::parse(s).unwrap();
        let dest = PathBuf::from("/tmp/input_test_get_single_file.txt");
        // Best effort: drop a stale destination left by an earlier failed run.
        let _ = std::fs::remove_file(&dest);

        let info = HandleFileInfo::new(&dest, &url);
        let req = FileAgentRequest::new(HandleFileCommand::Download, vec![info], "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        std::fs::remove_file(&dest).unwrap();
    }

    /// Uploads a single file over HTTP.
    #[test]
    fn test_put_single_file() {
        let src = PathBuf::from("/tmp/output_single_test.txt");
        {
            let mut file = std::fs::File::create(&src).unwrap();
            file.write_all(b"Hello Teaclave Results!").unwrap();
        }

        let s = "http://localhost:6789/fixtures/functions/mesapy/result.txt";
        let url = Url::parse(s).unwrap();
        let info = HandleFileInfo::new(&src, &url);
        let req = FileAgentRequest::new(HandleFileCommand::Upload, vec![info], "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        std::fs::remove_file(&src).unwrap();
    }

    /// Downloads the same remote file to several destinations in one request.
    #[test]
    fn test_get_multiple_files() {
        let s = "http://localhost:6789/fixtures/functions/gbdt_training/train.txt";
        let url = Url::parse(s).unwrap();
        let base = PathBuf::from("/tmp/file_agent_test_base");
        let fnames = vec!["a.txt", "b.txt", "c.txt", "d.txt"];
        // Best effort: start from an empty directory even after a failed run.
        let _ = std::fs::remove_dir_all(&base);
        std::fs::create_dir_all(&base).unwrap();

        let info_list: Vec<_> = fnames
            .iter()
            .map(|fname| HandleFileInfo::new(base.join(fname), &url))
            .collect();
        let req = FileAgentRequest::new(HandleFileCommand::Download, info_list, "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        std::fs::remove_dir_all(&base).unwrap();
    }

    /// Uploads one local file to several remote names in one request.
    #[test]
    fn test_put_multiple_files() {
        let src = PathBuf::from("/tmp/output_multiple_test.txt");
        {
            let mut file = std::fs::File::create(&src).unwrap();
            file.write_all(b"Hello Teaclave Results!").unwrap();
        }

        let s = "http://localhost:6789/fixtures/functions/gbdt_training";
        let url = Url::parse(s).unwrap();
        let fnames = vec!["a.txt", "b.txt", "c.txt", "d.txt"];
        let info_list: Vec<_> = fnames
            .iter()
            .map(|fname| {
                let mut url = url.clone();
                url.path_segments_mut().unwrap().push(fname);
                HandleFileInfo::new(&src, &url)
            })
            .collect();
        let req = FileAgentRequest::new(HandleFileCommand::Upload, info_list, "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        std::fs::remove_file(&src).unwrap();
    }

    /// Exercises the `file://` scheme for both upload and download.
    #[test]
    fn test_local_copy_file() {
        let base_str = "/tmp/file_agent_local_copy";
        let base = PathBuf::from(&base_str);
        // Best effort: start from an empty directory even after a failed run.
        let _ = std::fs::remove_dir_all(&base);
        std::fs::create_dir_all(&base).unwrap();

        let src = base.join("src.txt");
        {
            let mut file = std::fs::File::create(&src).unwrap();
            file.write_all(b"Hello Teaclave Results!").unwrap();
        }

        // test local upload
        let s = format!("file://{}/d1.txt", base_str);
        let url = Url::parse(&s).unwrap();
        let info = HandleFileInfo::new(&src, &url);
        let req = FileAgentRequest::new(HandleFileCommand::Upload, vec![info], "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();

        // test local download
        let dest = base.join("d2.txt");
        let info = HandleFileInfo::new(&dest, &url);
        let req = FileAgentRequest::new(HandleFileCommand::Download, vec![info], "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        std::fs::remove_dir_all(&base).unwrap();
    }

    /// Decodes a base64 `data:` URL into a local file.
    #[test]
    fn test_data_scheme() {
        let url = Url::parse("data:text/plain;base64,SGVsbG8sIFdvcmxkIQ==").unwrap();
        let dest = PathBuf::from("/tmp/input_test_data_scheme.txt");
        // Best effort: drop a stale destination left by an earlier failed run.
        let _ = std::fs::remove_file(&dest);

        let info = HandleFileInfo::new(&dest, &url);
        let req = FileAgentRequest::new(HandleFileCommand::Download, vec![info], "");
        let bytes = serde_json::to_vec(&req).unwrap();
        handle_file_request(&bytes).unwrap();
        assert_eq!(std::fs::read_to_string(&dest).unwrap(), "Hello, World!");
        std::fs::remove_file(&dest).unwrap();
    }
}