// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#[cfg(feature = "mesalock_sgx")]
use std::prelude::v1::*;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
use std::convert::TryFrom;
use teaclave_types::{FunctionArguments, FunctionRuntime};
use gbdt::config::Config;
use gbdt::decision_tree::Data;
use gbdt::gradient_boost::GBDT;
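
// Runtime keys for the staged training-data input and the trained-model output.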
const IN_DATA: &str = "training_data";
const OUT_MODEL: &str = "trained_model";
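
/// Built-in function that trains a gradient boosted decision tree (GBDT) model
/// on the staged training data and writes the serialized model to the output file.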
#[derive(Default)]
pub struct GbdtTrain;
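
/// Training hyperparameters parsed from the caller-supplied `FunctionArguments`.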
struct GbdtTrainArguments {
    feature_size: usize,
    max_depth: u32,
    iterations: usize,
    shrinkage: f32,
    feature_sample_ratio: f64,
    data_sample_ratio: f64,
    min_leaf_size: usize,
    loss: String,
    training_optimization_level: u8,
}

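// Extract and type-check each hyperparameter; a missing or malformed argument
// surfaces as an `anyhow::Error`.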
impl TryFrom<FunctionArguments> for GbdtTrainArguments {
    type Error = anyhow::Error;

    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
        let feature_size = arguments.get("feature_size")?.as_usize()?;
        let max_depth = arguments.get("max_depth")?.as_u32()?;
        let iterations = arguments.get("iterations")?.as_usize()?;
        let shrinkage = arguments.get("shrinkage")?.as_f32()?;
        let feature_sample_ratio = arguments.get("feature_sample_ratio")?.as_f64()?;
        let data_sample_ratio = arguments.get("data_sample_ratio")?.as_f64()?;
        let min_leaf_size = arguments.get("min_leaf_size")?.as_usize()?;
        let loss = arguments.get("loss")?.as_str().to_owned();
        let training_optimization_level = arguments.get("training_optimization_level")?.as_u8()?;

        Ok(Self {
            feature_size,
            max_depth,
            iterations,
            shrinkage,
            feature_sample_ratio,
            data_sample_ratio,
            min_leaf_size,
            loss,
            training_optimization_level,
        })
    }
}

impl GbdtTrain {
    pub const NAME: &'static str = "builtin-gbdt-train";

    pub fn new() -> Self {
        Default::default()
    }

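    /// Train a GBDT model: parse the hyperparameters, read the staged training
    /// data, fit the model, serialize it as JSON to the output file, and return
    /// a one-line summary of how many samples were used.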
    pub fn run(
        &self,
        arguments: FunctionArguments,
        runtime: FunctionRuntime,
    ) -> anyhow::Result<String> {
        log::debug!("start training...");
        let args = GbdtTrainArguments::try_from(arguments)?;

        log::debug!("open input...");
        // read input
        let training_file = runtime.open_input(IN_DATA)?;
        let mut train_dv = parse_training_data(training_file, args.feature_size)?;
        let data_size = train_dv.len();

        // init gbdt config
        let mut cfg = Config::new();
        cfg.set_debug(false);
        cfg.set_feature_size(args.feature_size);
        cfg.set_max_depth(args.max_depth);
        cfg.set_iterations(args.iterations);
        cfg.set_shrinkage(args.shrinkage);
        cfg.set_loss(&args.loss);
        cfg.set_min_leaf_size(args.min_leaf_size);
        cfg.set_data_sample_ratio(args.data_sample_ratio);
        cfg.set_feature_sample_ratio(args.feature_sample_ratio);
        cfg.set_training_optimization_level(args.training_optimization_level);

        // start training
        let mut gbdt_train_mod = GBDT::new(&cfg);
        gbdt_train_mod.fit(&mut train_dv);
        let model_json = serde_json::to_string(&gbdt_train_mod)?;

        // save the model to output
        let mut model_file = runtime.create_output(OUT_MODEL)?;
        model_file.write_all(model_json.as_bytes())?;

        let summary = format!("Trained {} lines of data.", data_size);
        Ok(summary)
    }
}

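/// Parse one comma-separated line of training data: `feature_size` feature
/// values followed by the label in the last column.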
fn parse_data_line(line: &str, feature_size: usize) -> anyhow::Result<Data> {
    let trimmed_line = line.trim();
    anyhow::ensure!(!trimmed_line.is_empty(), "Empty line");
    let mut v: Vec<f32> = trimmed_line
        .split(',')
        .map(|x| x.parse::<f32>())
        .collect::<std::result::Result<_, _>>()?;
    anyhow::ensure!(
        v.len() == feature_size + 1,
        "Data format error: column len = {}, expected = {}",
        v.len(),
        feature_size + 1
    );

    // Last column is the label
    Ok(Data {
        label: v.swap_remove(feature_size),
        feature: v,
        target: 0.0,
        weight: 1.0,
        residual: 0.0,
        initial_guess: 0.0,
    })
}

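/// Read the staged training input line by line; the first empty or malformed
/// line aborts parsing with an error.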
fn parse_training_data(input: impl io::Read, feature_size: usize) -> anyhow::Result<Vec<Data>> {
    let mut samples: Vec<Data> = Vec::new();
    let reader = BufReader::new(input);
    for line_result in reader.lines() {
        let line = line_result?;
        let data = parse_data_line(&line, feature_size)?;
        samples.push(data);
    }
    Ok(samples)
}

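// In-enclave unit tests, compiled only with the `enclave_unit_test` feature.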
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
use std::untrusted::fs;
use teaclave_crypto::*;
use teaclave_runtime::*;
use teaclave_test_utils::*;
use teaclave_types::*;
pub fn run_tests() -> bool {
run_tests!(test_gbdt_train, test_gbdt_parse_training_data,)
}
fn test_gbdt_train() {
let arguments = FunctionArguments::new(hashmap!(
"feature_size" => "4",
"max_depth" => "4",
"iterations" => "100",
"shrinkage" => "0.1",
"feature_sample_ratio" => "1.0",
"data_sample_ratio" => "1.0",
"min_leaf_size" => "1",
"loss" => "LAD",
"training_optimization_level" => "2"
));
let plain_input = "fixtures/functions/gbdt_training/train.txt";
let plain_output = "fixtures/functions/gbdt_training/training_model.txt.out";
let expected_output = "fixtures/functions/gbdt_training/expected_model.txt";
let input_files = StagedFiles::new(hashmap!(
IN_DATA =>
StagedFileInfo::new(plain_input, TeaclaveFile128Key::random(), FileAuthTag::mock())
));
let output_files = StagedFiles::new(hashmap!(
OUT_MODEL =>
StagedFileInfo::new(plain_output, TeaclaveFile128Key::random(), FileAuthTag::mock())
));
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
let summary = GbdtTrain::new().run(arguments, runtime).unwrap();
assert_eq!(summary, "Trained 120 lines of data.");
let result = fs::read_to_string(&plain_output).unwrap();
let expected = fs::read_to_string(&expected_output).unwrap();
assert_eq!(&result[..], &expected[..]);
}
fn test_gbdt_parse_training_data() {
let line = "4.8,3.0,1.4,0.3,3.0";
let result = parse_data_line(&line, 4);
assert!(result.is_ok());
let result = parse_data_line(&line, 3);
assert!(result.is_err());
}
}