| // Licensed to the Apache Software Foundation (ASF) under one |
| // or more contributor license agreements. See the NOTICE file |
| // distributed with this work for additional information |
| // regarding copyright ownership. The ASF licenses this file |
| // to you under the Apache License, Version 2.0 (the |
| // "License"); you may not use this file except in compliance |
| // with the License. You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, |
| // software distributed under the License is distributed on an |
| // "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| // KIND, either express or implied. See the License for the |
| // specific language governing permissions and limitations |
| // under the License. |
| |
| #[cfg(feature = "mesalock_sgx")] |
| use std::prelude::v1::*; |
| |
| use std::format; |
| use std::io::{self, BufRead, BufReader, Write}; |
| |
| use teaclave_types::{FunctionArguments, FunctionRuntime}; |
| |
| use rusty_machine::learning::logistic_reg::LogisticRegressor; |
| use rusty_machine::learning::optim::grad_desc::GradientDesc; |
| use rusty_machine::learning::SupModel; |
| use rusty_machine::linalg; |
| |
/// Key of the staged input holding the JSON-serialized model; passed to
/// `open_input` by `LogisticRegressionPredict::run`.
const MODEL_FILE: &str = "model_file";
/// Key of the staged input holding the comma-separated samples to predict.
const INPUT_DATA: &str = "data_file";
/// Key of the staged output that receives one prediction per line.
const RESULT: &str = "result_file";
| |
| #[derive(Default)] |
| pub struct LogisticRegressionPredict; |
| |
| impl LogisticRegressionPredict { |
| pub const NAME: &'static str = "builtin-logistic-regression-predict"; |
| |
| pub fn new() -> Self { |
| Default::default() |
| } |
| |
| pub fn run( |
| &self, |
| _arguments: FunctionArguments, |
| runtime: FunctionRuntime, |
| ) -> anyhow::Result<String> { |
| let mut model_json = String::new(); |
| let mut f = runtime.open_input(MODEL_FILE)?; |
| f.read_to_string(&mut model_json)?; |
| |
| let lr: LogisticRegressor<GradientDesc> = serde_json::from_str(&model_json)?; |
| let feature_size = lr |
| .parameters() |
| .ok_or_else(|| anyhow::anyhow!("Model parameter is None"))? |
| .size() |
| - 1; |
| |
| let input = runtime.open_input(INPUT_DATA)?; |
| let data_matrix = parse_input_data(input, feature_size)?; |
| |
| let result = lr.predict(&data_matrix)?; |
| |
| let mut output = runtime.create_output(RESULT)?; |
| let result_cnt = result.data().len(); |
| for c in result.data().iter() { |
| writeln!(&mut output, "{:.4}", c)?; |
| } |
| Ok(format!("Predicted {} lines of data.", result_cnt)) |
| } |
| } |
| |
| fn parse_input_data( |
| input: impl io::Read, |
| feature_size: usize, |
| ) -> anyhow::Result<linalg::Matrix<f64>> { |
| let mut flattened_data = Vec::new(); |
| let mut count = 0; |
| |
| let reader = BufReader::new(input); |
| for line_result in reader.lines() { |
| let line = line_result?; |
| let trimed_line = line.trim(); |
| anyhow::ensure!(!trimed_line.is_empty(), "Empty line"); |
| |
| let v: Vec<f64> = trimed_line |
| .split(',') |
| .map(|x| x.parse::<f64>()) |
| .collect::<std::result::Result<_, _>>()?; |
| |
| anyhow::ensure!( |
| v.len() == feature_size, |
| "Data format error: column len = {}, expected = {}", |
| v.len(), |
| feature_size |
| ); |
| |
| flattened_data.extend(v); |
| count += 1; |
| } |
| |
| Ok(linalg::Matrix::new(count, feature_size, flattened_data)) |
| } |
| |
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
    use super::*;
    use std::path::Path;
    use std::untrusted::fs;
    use teaclave_crypto::*;
    use teaclave_runtime::*;
    use teaclave_test_utils::*;
    use teaclave_types::*;

    /// Entry point that hands the tests of this module to the `run_tests!` macro.
    pub fn run_tests() -> bool {
        run_tests!(test_logistic_regression_predict)
    }

    /// Runs the prediction function on the fixture model and input, then
    /// checks both the returned summary and the content of the written result
    /// file against the expected fixture.
    fn test_logistic_regression_predict() {
        let fixture_dir = Path::new("fixtures/functions/logistic_regression_prediction");
        let model_path = fixture_dir.join("model.txt");
        let input_path = fixture_dir.join("predict_input.txt");
        let output_path = fixture_dir.join("predict_result.txt.out");
        let expected_path = fixture_dir.join("expected_result.txt");

        let staged_inputs = StagedFiles::new(hashmap!(
            MODEL_FILE =>
            StagedFileInfo::new(&model_path, TeaclaveFile128Key::random(), FileAuthTag::mock()),
            INPUT_DATA =>
            StagedFileInfo::new(&input_path, TeaclaveFile128Key::random(), FileAuthTag::mock()),
        ));

        let staged_outputs = StagedFiles::new(hashmap!(
            RESULT =>
            StagedFileInfo::new(&output_path, TeaclaveFile128Key::random(), FileAuthTag::mock())
        ));

        let runtime = Box::new(RawIoRuntime::new(staged_inputs, staged_outputs));
        let function = LogisticRegressionPredict::new();

        let summary = function
            .run(FunctionArguments::default(), runtime)
            .unwrap();
        assert_eq!(summary, "Predicted 5 lines of data.");

        // The file written by the function must match the expected fixture exactly.
        let actual = fs::read_to_string(&output_path).unwrap();
        let expected = fs::read_to_string(&expected_path).unwrap();
        assert_eq!(actual, expected);
    }
}