blob: 5df3bbedb0423bd48c2ab2afa61a9f4062106ef0 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
use std::convert::TryFrom;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
use teaclave_types::{FunctionArguments, FunctionRuntime};
use rusty_machine::learning::logistic_reg::LogisticRegressor;
use rusty_machine::learning::optim::grad_desc::GradientDesc;
use rusty_machine::learning::SupModel;
use rusty_machine::linalg;
const TRAINING_DATA: &str = "training_data";
const OUT_MODEL_FILE: &str = "model_file";
#[derive(Default)]
pub struct LogisticRegressionTrain;
struct LogisticRegressionTrainArguments {
alg_alpha: f64,
alg_iters: usize,
feature_size: usize,
}
impl TryFrom<FunctionArguments> for LogisticRegressionTrainArguments {
type Error = anyhow::Error;
fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
let alg_alpha = arguments.get("alg_alpha")?.as_f64()?;
let alg_iters = arguments.get("alg_iters")?.as_usize()?;
let feature_size = arguments.get("feature_size")?.as_usize()?;
Ok(Self {
alg_alpha,
alg_iters,
feature_size,
})
}
}
impl LogisticRegressionTrain {
pub const NAME: &'static str = "builtin-logistic-regression-train";
pub fn new() -> Self {
Default::default()
}
pub fn run(
&self,
arguments: FunctionArguments,
runtime: FunctionRuntime,
) -> anyhow::Result<String> {
let args = LogisticRegressionTrainArguments::try_from(arguments)?;
let input = runtime.open_input(TRAINING_DATA)?;
let (flattend_features, targets) = parse_training_data(input, args.feature_size)?;
let data_size = targets.len();
let data_matrix = linalg::Matrix::new(data_size, args.feature_size, flattend_features);
let targets = linalg::Vector::new(targets);
let gd = GradientDesc::new(args.alg_alpha, args.alg_iters);
let mut lr = LogisticRegressor::new(gd);
lr.train(&data_matrix, &targets)?;
let model_json = serde_json::to_string(&lr)?;
let mut model_file = runtime.create_output(OUT_MODEL_FILE)?;
model_file.write_all(model_json.as_bytes())?;
Ok(format!("Trained {} lines of data.", data_size))
}
}
fn parse_training_data(
input: impl io::Read,
feature_size: usize,
) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
let reader = BufReader::new(input);
let mut targets = Vec::<f64>::new();
let mut features = Vec::new();
for line_result in reader.lines() {
let line = line_result?;
let trimed_line = line.trim();
anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
let mut v: Vec<f64> = trimed_line
.split(',')
.map(|x| x.parse::<f64>())
.collect::<std::result::Result<_, _>>()?;
anyhow::ensure!(
v.len() == feature_size + 1,
"Data format error: column len = {}, expected = {}",
v.len(),
feature_size + 1
);
let label = v.swap_remove(feature_size);
targets.push(label);
features.extend(v);
}
Ok((features, targets))
}
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
use std::path::Path;
use std::untrusted::fs;
use teaclave_crypto::*;
use teaclave_runtime::*;
use teaclave_test_utils::*;
use teaclave_types::*;
pub fn run_tests() -> bool {
run_tests!(test_logistic_regression_train)
}
fn test_logistic_regression_train() {
let arguments = FunctionArguments::new(hashmap! {
"alg_alpha" => "0.3",
"alg_iters" => "100",
"feature_size" => "30"
});
let base = Path::new("fixtures/functions/logistic_regression_training");
let training_data = base.join("train.txt");
let plain_output = base.join("model.txt.out");
let expected_output = base.join("expected_model.txt");
let input_files = StagedFiles::new(hashmap!(
TRAINING_DATA =>
StagedFileInfo::new(&training_data, TeaclaveFile128Key::random(), FileAuthTag::mock()),
));
let output_files = StagedFiles::new(hashmap!(
OUT_MODEL_FILE =>
StagedFileInfo::new(&plain_output, TeaclaveFile128Key::random(), FileAuthTag::mock())
));
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
let summary = LogisticRegressionTrain::new()
.run(arguments, runtime)
.unwrap();
assert_eq!(summary, "Trained 100 lines of data.");
let _result = fs::read_to_string(&plain_output).unwrap();
let _expected = fs::read_to_string(&expected_output).unwrap();
// assert_eq!(&result[..], &expected[..]);
}
}