[examples] Add builtin_gbdt_train python example (#319)
diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh
index 653abfe..7e6119c 100755
--- a/cmake/scripts/test.sh
+++ b/cmake/scripts/test.sh
@@ -155,9 +155,12 @@
sleep 3 # wait for execution services
popd
+ pushd ${TEACLAVE_PROJECT_ROOT}/examples/python
export PYTHONPATH=${TEACLAVE_PROJECT_ROOT}/sdk/python
- python3 ${TEACLAVE_PROJECT_ROOT}/examples/python/builtin_echo.py
- python3 ${TEACLAVE_PROJECT_ROOT}/examples/python/mesapy_echo.py
+ python3 builtin_echo.py
+ python3 mesapy_echo.py
+ python3 builtin_gbdt_train.py
+ popd
# kill all background services
cleanup
diff --git a/examples/python/builtin_gbdt_train.py b/examples/python/builtin_gbdt_train.py
new file mode 100644
index 0000000..d2e03e3
--- /dev/null
+++ b/examples/python/builtin_gbdt_train.py
@@ -0,0 +1,137 @@
+#!/usr/bin/env python3
+
+import sys
+
+from teaclave import (AuthenticationService, FrontendService,
+ AuthenticationClient, FrontendClient)
+from utils import (AUTHENTICATION_SERVICE_ADDRESS, FRONTEND_SERVICE_ADDRESS,
+ AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH, USER_ID,
+ USER_PASSWORD)
+
+
+class FunctionInput:
+ def __init__(self, name, description):
+ self.name = name
+ self.description = description
+
+
+class FunctionOutput:
+ def __init__(self, name, description):
+ self.name = name
+ self.description = description
+
+
+class OwnerList:
+ def __init__(self, data_name, uids):
+ self.data_name = data_name
+ self.uids = uids
+
+
+class DataList:
+ def __init__(self, data_name, data_id):
+ self.data_name = data_name
+ self.data_id = data_id
+
+
+class BuiltinGbdtExample:
+ def __init__(self, user_id, user_password):
+ self.user_id = user_id
+ self.user_password = user_password
+
+ def gbdt(self):
+ channel = AuthenticationService(AUTHENTICATION_SERVICE_ADDRESS,
+ AS_ROOT_CA_CERT_PATH,
+ ENCLAVE_INFO_PATH).connect()
+ client = AuthenticationClient(channel)
+
+ print("[+] registering user")
+ client.user_register(self.user_id, self.user_password)
+
+ print("[+] login")
+ token = client.user_login(self.user_id, self.user_password)
+
+ channel = FrontendService(FRONTEND_SERVICE_ADDRESS,
+ AS_ROOT_CA_CERT_PATH,
+ ENCLAVE_INFO_PATH).connect()
+ metadata = {"id": self.user_id, "token": token}
+ client = FrontendClient(channel, metadata)
+
+ print("[+] registering function")
+ function_id = client.register_function(
+ name="builtin-gbdt-train",
+ description="Native Gbdt Training Function",
+ executor_type="builtin",
+ arguments=[
+ "feature_size", "max_depth", "iterations", "shrinkage",
+ "feature_sample_ratio", "data_sample_ratio", "min_leaf_size",
+ "loss", "training_optimization_level"
+ ],
+ inputs=[
+ FunctionInput("training_data", "Input traning data file.")
+ ],
+ outputs=[FunctionOutput("trained_model", "Output trained model.")])
+
+ print("[+] registering input file")
+ url = "http://localhost:6789/fixtures/functions/gbdt_training/train.enc"
+ cmac = "881adca6b0524472da0a9d0bb02b9af9"
+ schema = "teaclave-file-128"
+ key = [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0]
+ iv = []
+ training_data_id = client.register_input_file(url, schema, key, iv,
+ cmac)
+
+ print("[+] registering output file")
+ url = "http://localhost:6789/fixtures/functions/gbdt_training/e2e_output_model.enc"
+ schema = "teaclave-file-128"
+ key = [
+ 63, 195, 250, 208, 252, 127, 203, 27, 247, 168, 71, 77, 27, 47,
+ 254, 240
+ ]
+ iv = []
+ output_model_id = client.register_output_file(url, schema, key, iv)
+
+ print("[+] creating task")
+ task_id = client.create_task(
+ function_id=function_id,
+ function_arguments=({
+ "feature_size": 4,
+ "max_depth": 4,
+ "iterations": 100,
+ "shrinkage": 0.1,
+ "feature_sample_ratio": 1.0,
+ "data_sample_ratio": 1.0,
+ "min_leaf_size": 1,
+ "loss": "LAD",
+ "training_optimization_level": 2
+ }),
+ executor="builtin",
+ inputs_ownership=[OwnerList("training_data", [self.user_id])],
+ outputs_ownership=[OwnerList("trained_model", [self.user_id])])
+
+ print("[+] assigning data to task")
+ client.assign_data_to_task(
+ task_id, [DataList("training_data", training_data_id)],
+ [DataList("trained_model", output_model_id)])
+
+ print("[+] approving task")
+ client.approve_task(task_id)
+
+ print("[+] invoking task")
+ client.invoke_task(task_id)
+
+ print("[+] getting result")
+ result = client.get_task_result(task_id)
+ print("[+] done")
+
+ return bytes(result)
+
+
+def main():
+ example = BuiltinGbdtExample(USER_ID, USER_PASSWORD)
+ rt = example.gbdt()
+
+ print("[+] function return: ", rt)
+
+
+if __name__ == '__main__':
+ main()
diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py
index 1819ca0..f9cd8e3 100644
--- a/sdk/python/teaclave.py
+++ b/sdk/python/teaclave.py
@@ -105,6 +105,23 @@
self.outputs = outputs
+class RegisterInputFileRequest:
+ def __init__(self, metadata, url, cmac, crypto_info):
+ self.request = "register_input_file"
+ self.metadata = metadata
+ self.url = url
+ self.cmac = cmac
+ self.crypto_info = crypto_info
+
+
+class RegisterOutputFileRequest:
+ def __init__(self, metadata, url, crypto_info):
+ self.request = "register_output_file"
+ self.metadata = metadata
+ self.url = url
+ self.crypto_info = crypto_info
+
+
class CreateTaskRequest:
def __init__(self, metadata, function_id, function_arguments, executor,
inputs_ownership, outputs_ownership):
@@ -117,6 +134,22 @@
self.outputs_ownership = outputs_ownership
+class AssignDataRequest:
+ def __init__(self, metadata, task_id, inputs, outputs):
+ self.request = "assign_data"
+ self.metadata = metadata
+ self.task_id = task_id
+ self.inputs = inputs
+ self.outputs = outputs
+
+
+class ApproveTaskRequest:
+ def __init__(self, metadata, task_id):
+ self.request = "approve_task"
+ self.metadata = metadata
+ self.task_id = task_id
+
+
class InvokeTaskRequest:
def __init__(self, metadata, task_id):
self.request = "invoke_task"
@@ -131,6 +164,13 @@
self.task_id = task_id
+class TeaclaveFile128Key:
+ def __init__(self, schema, key, iv):
+ self.schema = schema
+ self.key = key
+ self.iv = iv
+
+
class FrontendClient:
def __init__(self, channel, metadata):
self.channel = channel
@@ -152,6 +192,20 @@
response = read_message(self.channel)
return response["content"]["function_id"]
+ def register_input_file(self, url, schema, key, iv, cmac):
+ request = RegisterInputFileRequest(self.metadata, url, cmac,
+ TeaclaveFile128Key(schema, key, iv))
+ write_message(self.channel, request)
+ response = read_message(self.channel)
+ return response["content"]["data_id"]
+
+ def register_output_file(self, url, schema, key, iv):
+ request = RegisterOutputFileRequest(
+ self.metadata, url, TeaclaveFile128Key(schema, key, iv))
+ write_message(self.channel, request)
+ response = read_message(self.channel)
+ return response["content"]["data_id"]
+
def create_task(self,
function_id,
function_arguments,
@@ -166,6 +220,18 @@
response = read_message(self.channel)
return response["content"]["task_id"]
+ def assign_data_to_task(self, task_id, inputs, outputs):
+ request = AssignDataRequest(self.metadata, task_id, inputs, outputs)
+ write_message(self.channel, request)
+ response = read_message(self.channel)
+ return
+
+ def approve_task(self, task_id):
+ request = ApproveTaskRequest(self.metadata, task_id)
+ write_message(self.channel, request)
+ response = read_message(self.channel)
+ return
+
def invoke_task(self, task_id):
request = InvokeTaskRequest(self.metadata, task_id)
write_message(self.channel, request)