blob: 8ec5158561d97e3e42d78b78091040fafaa08989 [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use std::convert::TryFrom;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
use teaclave_types::{FunctionArguments, FunctionRuntime};
use rusty_machine::learning::pca::PCA;
use rusty_machine::learning::UnSupModel;
use rusty_machine::linalg;
use rusty_machine::linalg::BaseMatrix;
const IN_DATA: &str = "input_data";
const OUT_RESULT: &str = "output_data";
#[derive(Default)]
pub struct PrincipalComponentsAnalysis;
#[derive(serde::Deserialize)]
struct PrincipalComponentsAnalysisArguments {
n: usize,
center: bool,
feature_size: usize,
}
impl TryFrom<FunctionArguments> for PrincipalComponentsAnalysisArguments {
type Error = anyhow::Error;
fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
use anyhow::Context;
serde_json::from_str(&arguments.into_string()).context("Cannot deserialize arguments")
}
}
impl PrincipalComponentsAnalysis {
pub const NAME: &'static str = "builtin_principal_components_analysis";
pub fn new() -> Self {
Default::default()
}
pub fn run(
&self,
arguments: FunctionArguments,
runtime: FunctionRuntime,
) -> anyhow::Result<String> {
let args = PrincipalComponentsAnalysisArguments::try_from(arguments)?;
let input = runtime.open_input(IN_DATA)?;
let (flattend_features, targets) = parse_input_data(input, args.feature_size)?;
let data_size = targets.len();
let input_features = linalg::Matrix::new(data_size, args.feature_size, flattend_features);
let mut model = PCA::new(args.n, args.center);
model.train(&input_features)?;
let predict_result = model.predict(&input_features)?;
let mut output = runtime.create_output(OUT_RESULT)?;
for i in 0..predict_result.rows() {
for j in 0..predict_result.cols() {
if j == predict_result.cols() - 1 {
write!(&mut output, "{:?}", predict_result[[i, j]])?;
} else {
write!(&mut output, "{:?},", predict_result[[i, j]])?;
}
}
writeln!(&mut output)?;
}
Ok(format!(
"transform {} rows * {} cols lines of data.",
predict_result.rows(),
predict_result.cols()
))
}
}
fn parse_input_data(
input: impl io::Read,
feature_size: usize,
) -> anyhow::Result<(Vec<f64>, Vec<f64>)> {
let reader = BufReader::new(input);
let mut targets = Vec::<f64>::new();
let mut features = Vec::new();
for line_result in reader.lines() {
let line = line_result?;
let trimed_line = line.trim();
anyhow::ensure!(!trimed_line.is_empty(), "Empty line");
let mut v: Vec<f64> = trimed_line
.split(',')
.map(|x| x.parse::<f64>())
.collect::<std::result::Result<_, _>>()?;
anyhow::ensure!(
v.len() == feature_size + 1,
"Data format error: column len = {}, expected = {}",
v.len(),
feature_size + 1
);
let label = v.swap_remove(feature_size);
targets.push(label);
features.extend(v);
}
Ok((features, targets))
}
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
use serde_json::json;
use std::path::Path;
use std::untrusted::fs;
use teaclave_crypto::*;
use teaclave_runtime::*;
use teaclave_test_utils::*;
use teaclave_types::*;
pub fn run_tests() -> bool {
run_tests!(test_pca_predict)
}
fn test_pca_predict() {
let args = FunctionArguments::from_json(json!({
"n": 2,
"feature_size": 4,
"center":true
}))
.unwrap();
let base = Path::new("fixtures/functions/princopal_components_analysis");
let input_data_file = base.join("input.txt");
let output_data_file = base.join("result.txt");
let expected_output = base.join("expected_result.txt");
let input_files = StagedFiles::new(hashmap!(
IN_DATA =>
StagedFileInfo::new(&input_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()),
));
let output_files = StagedFiles::new(hashmap!(
OUT_RESULT =>
StagedFileInfo::new(&output_data_file, TeaclaveFile128Key::random(), FileAuthTag::mock()),
));
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
let summary = PrincipalComponentsAnalysis::new()
.run(args, runtime)
.unwrap();
assert_eq!(summary, "transform 90 rows * 2 cols lines of data.");
let result = fs::read_to_string(&output_data_file).unwrap();
let expected = fs::read_to_string(&expected_output).unwrap();
assert_eq!(&result[..], &expected[..]);
}
}