blob: 32124111a7f9d9a3ff4a391a9f2bc2f68c441b3f [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
use anyhow::bail;
use std::cmp;
use std::convert::TryFrom;
use std::format;
use std::io::{self, BufRead, BufReader, Write};
use teaclave_types::{FunctionArguments, FunctionRuntime};
extern crate hex;
// Input data should be a list of sorted hash values: each input is read line
// by line, every line is hex-decoded, and the lines must follow the sort
// direction selected by the `order` argument.
// Keys passed to `runtime.open_input` for the two input lists.
const IN_DATA1: &str = "input_data1";
const IN_DATA2: &str = "input_data2";
// Keys passed to `runtime.create_output`; each output receives one `0`/`1`
// digit per entry of the corresponding input.
const OUT_RESULT1: &str = "output_result1";
const OUT_RESULT2: &str = "output_result2";
/// Function that, given two ordered lists of hex-encoded values, reports for
/// every entry of each list whether the same value also occurs in the other
/// list. The type carries no state; all work happens in `run`.
#[derive(Default)]
pub struct OrderedSetIntersect;
/// Arguments of `OrderedSetIntersect`, deserialized from a JSON object.
#[derive(serde::Deserialize)]
pub struct OrderedSetIntersectArguments {
/// Sort direction of both input lists. The value is validated in
/// `OrderedSetIntersect::run`, which fails with "Invalid order" for
/// anything it does not recognise.
order: String,
}
/// Builds the typed arguments from the generic `FunctionArguments` container.
impl TryFrom<FunctionArguments> for OrderedSetIntersectArguments {
    type Error = anyhow::Error;

    /// Serializes `arguments` to its string form and parses that string as
    /// JSON into `OrderedSetIntersectArguments`.
    ///
    /// # Errors
    ///
    /// Returns the JSON error wrapped with the context
    /// "Cannot deserialize arguments" when parsing fails.
    fn try_from(arguments: FunctionArguments) -> Result<Self, Self::Error> {
        use anyhow::Context;

        let raw_json = arguments.into_string();
        let parsed = serde_json::from_str(&raw_json);
        parsed.context("Cannot deserialize arguments")
    }
}
impl OrderedSetIntersect {
    /// Identifier string of this function.
    pub const NAME: &'static str = "builtin-ordered-set-intersect";

    /// Creates a new instance (equivalent to `Default::default()`).
    pub fn new() -> Self {
        Default::default()
    }

    /// Intersects the two ordered input lists and writes one membership mask
    /// per input.
    ///
    /// Both inputs (`input_data1`, `input_data2`) are parsed as lists of
    /// hex-encoded lines sorted in the direction given by the `order`
    /// argument. For every entry of an input, a single digit is written to
    /// the corresponding output (`output_result1`, `output_result2`): `1` if
    /// the entry was matched against an entry of the other list, `0`
    /// otherwise. The digits are written back to back without separators.
    ///
    /// `order` must be `"ascending"` or `"descending"`. The misspelling
    /// `"desending"` is also accepted for backward compatibility.
    ///
    /// # Returns
    ///
    /// A summary of the form `"<n> common items"`, where `n` is the number of
    /// matched entries of the first input.
    ///
    /// # Errors
    ///
    /// Fails if an input or output cannot be opened, the arguments cannot be
    /// deserialized, `order` is not recognised ("Invalid order"), an input
    /// line is not valid hex or is out of order, or writing an output fails.
    pub fn run(
        &self,
        arguments: FunctionArguments,
        runtime: FunctionRuntime,
    ) -> anyhow::Result<String> {
        let input1 = runtime.open_input(IN_DATA1)?;
        let input2 = runtime.open_input(IN_DATA2)?;
        let mut output1 = runtime.create_output(OUT_RESULT1)?;
        let mut output2 = runtime.create_output(OUT_RESULT2)?;

        let args = OrderedSetIntersectArguments::try_from(arguments)?;
        let ascending_order = match args.order.as_str() {
            "ascending" => true,
            // "desending" was the only spelling accepted historically; keep
            // it so existing callers continue to work alongside the correct
            // "descending".
            "descending" | "desending" => false,
            _ => bail!("Invalid order"),
        };

        let vec1 = parse_input_data(input1, ascending_order)?;
        let vec2 = parse_input_data(input2, ascending_order)?;
        let (result1, result2) = intersection_ordered_vec(&vec1, &vec2, ascending_order)?;

        // Every match sets exactly one flag in each mask, so counting the
        // flags of the first mask gives the number of common items.
        let mut common_sets = 0;
        for item in result1 {
            write!(&mut output1, "{}", item as u32)?;
            if item {
                common_sets += 1;
            }
        }
        for item in result2 {
            write!(&mut output2, "{}", item as u32)?;
        }

        log::trace!("{}", common_sets);
        Ok(format!("{} common items", common_sets))
    }
}
/// Reads `input` line by line, hex-decodes every line and verifies that the
/// decoded values follow the requested sort direction.
///
/// Equal neighbouring values are allowed in either direction; only a value
/// that is strictly on the wrong side of its predecessor is rejected.
///
/// # Errors
///
/// Fails if a line cannot be read, if a line is not valid hex, or with
/// "Invalid ordering" if a value is out of order relative to the previous one.
fn parse_input_data(input: impl io::Read, ascending_order: bool) -> anyhow::Result<Vec<Vec<u8>>> {
    let mut decoded: Vec<Vec<u8>> = Vec::new();
    for line in BufReader::new(input).lines() {
        let current = hex::decode(line?)?;
        // Only entries after the first have a predecessor to compare against.
        if let Some(previous) = decoded.last() {
            let out_of_order = if ascending_order {
                current < *previous
            } else {
                current > *previous
            };
            if out_of_order {
                bail!("Invalid ordering");
            }
        }
        decoded.push(current);
    }
    Ok(decoded)
}
/// Marks, for two lists sorted in the same direction, which entries have a
/// counterpart in the other list.
///
/// The lists are walked in lockstep with one cursor each. When the entries
/// under both cursors are equal, both are flagged and both cursors advance,
/// so a repeated value is matched at most as many times as it occurs in the
/// other list. Otherwise the cursor whose entry comes earlier in the sort
/// direction advances.
///
/// Returns `(mask1, mask2)`, where `mask1[k]` is `true` iff `input1[k]` was
/// matched and `mask2[k]` likewise for `input2`. The masks have the same
/// lengths as their inputs. This function never returns `Err`.
fn intersection_ordered_vec(
    input1: &[Vec<u8>],
    input2: &[Vec<u8>],
    ascending_order: bool,
) -> anyhow::Result<(Vec<bool>, Vec<bool>)> {
    let mut matched1 = std::vec![false; input1.len()];
    let mut matched2 = std::vec![false; input2.len()];
    let (mut left, mut right) = (0usize, 0usize);

    // Stop as soon as either list is exhausted.
    while let (Some(lhs), Some(rhs)) = (input1.get(left), input2.get(right)) {
        let ordering = lhs.cmp(rhs);
        if ordering == cmp::Ordering::Equal {
            matched1[left] = true;
            matched2[right] = true;
            left += 1;
            right += 1;
            continue;
        }

        // The left entry is "behind" when it is smaller in ascending order or
        // larger in descending order; the cursor that is behind catches up.
        let left_is_behind = (ordering == cmp::Ordering::Less) == ascending_order;
        if left_is_behind {
            left += 1;
        } else {
            right += 1;
        }
    }

    Ok((matched1, matched2))
}
// Tests compiled only when the `enclave_unit_test` feature is enabled.
#[cfg(feature = "enclave_unit_test")]
pub mod tests {
use super::*;
use serde_json::json;
use std::path::Path;
use std::untrusted::fs;
use teaclave_crypto::*;
use teaclave_runtime::*;
use teaclave_test_utils::*;
use teaclave_types::*;
// Entry point that runs every test of this module through `run_tests!`
// and returns its boolean result.
pub fn run_tests() -> bool {
run_tests!(test_ordered_set_intersect)
}
// Runs the function in ascending mode on the two fixture lists and checks
// both written masks and the returned summary.
fn test_ordered_set_intersect() {
let arguments = FunctionArguments::from_json(json!({
"order": "ascending"
}))
.unwrap();
// Fixture inputs and the paths the function writes its outputs to.
let base = Path::new("fixtures/functions/ordered_set_intersect");
let user1_input = base.join("psi0.txt");
let user1_output = base.join("output_psi0.txt");
let user2_input = base.join("psi1.txt");
let user2_output = base.join("output_psi1.txt");
// Map the input/output keys used by `run` onto the fixture paths.
let input_files = StagedFiles::new(hashmap!(
IN_DATA1 =>
StagedFileInfo::new(&user1_input, TeaclaveFile128Key::random(), FileAuthTag::mock()),
IN_DATA2 =>
StagedFileInfo::new(&user2_input, TeaclaveFile128Key::random(), FileAuthTag::mock()),
));
let output_files = StagedFiles::new(hashmap!(
OUT_RESULT1 =>
StagedFileInfo::new(&user1_output, TeaclaveFile128Key::random(), FileAuthTag::mock()),
OUT_RESULT2 =>
StagedFileInfo::new(&user2_output, TeaclaveFile128Key::random(), FileAuthTag::mock()),
));
let runtime = Box::new(RawIoRuntime::new(input_files, output_files));
let summary = OrderedSetIntersect::new().run(arguments, runtime).unwrap();
// Read the outputs back as plain text.
let user1_result = fs::read_to_string(&user1_output).unwrap();
let user2_result = fs::read_to_string(&user2_output).unwrap();
// One digit per input entry: `1` marks an entry matched in the other list.
assert_eq!(&user1_result[..], "0101010");
assert_eq!(&user2_result[..], "01101");
assert_eq!(summary, "3 common items");
}
}