Update python sdk and examples
diff --git a/.gitignore b/.gitignore
index 4190671..ba070e5 100644
--- a/.gitignore
+++ b/.gitignore
@@ -32,3 +32,6 @@
examples/c/builtin_echo
examples/c/builtin_ordered_set_intersect
examples/python/out.jpg
+# ignore grpc files during building and testing
+sdk/python/*_pb2.py
+sdk/python/*_grpc.py
diff --git a/cmake/scripts/test.sh b/cmake/scripts/test.sh
index e61e531..0d2474d 100755
--- a/cmake/scripts/test.sh
+++ b/cmake/scripts/test.sh
@@ -79,6 +79,14 @@
done
}
+generate_python_grpc_stubs() {
+ python3 -m grpc_tools.protoc \
+ --proto_path=${TEACLAVE_PROJECT_ROOT}/services/proto/src/proto \
+ --python_out=${TEACLAVE_PROJECT_ROOT}/sdk/python \
+ --grpclib_python_out=${TEACLAVE_PROJECT_ROOT}/sdk/python \
+ ${TEACLAVE_PROJECT_ROOT}/services/proto/src/proto/*.proto
+}
+
run_integration_tests() {
trap cleanup INT TERM ERR
@@ -152,7 +160,9 @@
./teaclave_functional_tests -t end_to_end
- # Run script tests
+ generate_python_grpc_stubs
+
+ export PYTHONPATH=${TEACLAVE_PROJECT_ROOT}/sdk/python
./scripts/functional_tests.py -v
popd
@@ -283,6 +293,8 @@
sleep 3 # wait for execution services
popd
+ generate_python_grpc_stubs
+
# run builtin examples
builtin_examples
@@ -328,6 +340,8 @@
sleep 3 # wait for execution services
popd
+ generate_python_grpc_stubs
+
# run builtin examples
builtin_examples
@@ -374,6 +388,8 @@
echo "executor 1 pid: $exe_pid1"
echo "executor 2 pid: $exe_pid2"
+ generate_python_grpc_stubs
+
pushd ${TEACLAVE_PROJECT_ROOT}/examples/python
export PYTHONPATH=${TEACLAVE_PROJECT_ROOT}/sdk/python
python3 mesapy_deadloop_cancel.py
diff --git a/examples/README.md b/examples/README.md
index 7cba2cf..8eeda5d 100644
--- a/examples/README.md
+++ b/examples/README.md
@@ -10,8 +10,14 @@
Before trying these examples, please make sure all services in the Teaclave
platform has been properly launched. Also, for examples implemented in Python,
-don't forget to set the `PYTHONPATH` to the `sdk` path so that the scripts can
-successfully import the `teaclave` module.
+don't forget to generate protocol stub files and set the `PYTHONPATH` to the
+`sdk` path so that the scripts can successfully import the `teaclave` module.
+
+Generate stub files by grpcio-tools and grpclib.
+
+```
+python3 -m grpc_tools.protoc --proto_path=../../services/proto/src/proto --python_out=. --grpclib_python_out=. ../../services/proto/src/proto/{teaclave_authentication_service.proto,teaclave_frontend_service.proto,teaclave_common.proto}
+```
For instance, use the following command to invoke an echo function in Teaclave:
diff --git a/examples/python/builtin_face_detection.py b/examples/python/builtin_face_detection.py
index a074373..570f9f9 100644
--- a/examples/python/builtin_face_detection.py
+++ b/examples/python/builtin_face_detection.py
@@ -17,13 +17,9 @@
# specific language governing permissions and limitations
# under the License.
-import os
-import sys
import json
from PIL import Image, ImageDraw
-import requests
-
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service
from teaclave import FunctionArgument
diff --git a/examples/python/builtin_gbdt_train.py b/examples/python/builtin_gbdt_train.py
index 9b12510..bc93210 100644
--- a/examples/python/builtin_gbdt_train.py
+++ b/examples/python/builtin_gbdt_train.py
@@ -17,8 +17,6 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service
diff --git a/examples/python/builtin_online_decrypt.py b/examples/python/builtin_online_decrypt.py
index 9a30b3b..302857d 100644
--- a/examples/python/builtin_online_decrypt.py
+++ b/examples/python/builtin_online_decrypt.py
@@ -17,7 +17,6 @@
# specific language governing permissions and limitations
# under the License.
-import sys
import base64
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service
diff --git a/examples/python/builtin_ordered_set_intersect.py b/examples/python/builtin_ordered_set_intersect.py
index 277cef7..d084bf2 100644
--- a/examples/python/builtin_ordered_set_intersect.py
+++ b/examples/python/builtin_ordered_set_intersect.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
# In the example, user 0 creates the task and user 0, 1, upload their private data.
# Then user 0 invokes the task and user 0, 1 get the result.
@@ -63,13 +61,6 @@
], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
-class DataList:
-
- def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
-
-
class Client:
def __init__(self, user_id, user_password):
@@ -143,8 +134,8 @@
output_id = client.register_output_file(url, schema, key, iv)
print(f"[+] {self.user_id} assigning data to task")
- client.assign_data_to_task(task_id, [DataList(input_label, input_id)],
- [DataList(output_label, output_id)])
+ client.assign_data_to_task(task_id, [DataMap(input_label, input_id)],
+ [DataMap(output_label, output_id)])
def approve_task(self, task_id):
client = self.client
diff --git a/examples/python/builtin_password_check.py b/examples/python/builtin_password_check.py
index 1715be0..dac6864 100644
--- a/examples/python/builtin_password_check.py
+++ b/examples/python/builtin_password_check.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
-from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from teaclave import FunctionInput, OwnerList, DataMap
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
# In the example, user 0 creates the task and user 0, 1, upload their private data.
# Then user 0 invokes the task and user 0, 1 get the result.
@@ -69,13 +67,6 @@
], [], [0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
-class DataList:
-
- def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
-
-
class Client:
def __init__(self, user_id, user_password):
@@ -134,7 +125,7 @@
input_id = client.register_input_file(url, schema, key, iv, cmac)
print(f"[+] {self.user_id} assigning data to task")
- client.assign_data_to_task(task_id, [DataList(input_label, input_id)],
+ client.assign_data_to_task(task_id, [DataMap(input_label, input_id)],
[])
def approve_task(self, task_id):
diff --git a/examples/python/builtin_private_join_and_compute.py b/examples/python/builtin_private_join_and_compute.py
index 0b7c89b..a5e77c9 100644
--- a/examples/python/builtin_private_join_and_compute.py
+++ b/examples/python/builtin_private_join_and_compute.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
# In the example, user 3 creates the task and user 0, 1, 2 upload their private data.
# Then user 3 invokes the task and user 0, 1, 2 get the result.
@@ -70,13 +68,6 @@
USER_DATA_3 = UserData("user3", "password")
-class DataList:
-
- def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
-
-
class ConfigClient:
def __init__(self, user_id, user_password):
@@ -173,8 +164,8 @@
output_id = client.register_output_file(url, schema, key, iv)
print(f"[+] {self.user_id} assigning data to task")
- client.assign_data_to_task(task_id, [DataList(input_label, input_id)],
- [DataList(output_label, output_id)])
+ client.assign_data_to_task(task_id, [DataMap(input_label, input_id)],
+ [DataMap(output_label, output_id)])
def approve_task(self, task_id):
client = self.client
diff --git a/examples/python/builtin_rsa_sign.py b/examples/python/builtin_rsa_sign.py
index 3ba494a..63128d6 100644
--- a/examples/python/builtin_rsa_sign.py
+++ b/examples/python/builtin_rsa_sign.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
-from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from teaclave import FunctionInput, FunctionArgument, OwnerList, DataMap
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
def get_client(user_id, user_password):
diff --git a/examples/python/mesapy_deadloop_cancel.py b/examples/python/mesapy_deadloop_cancel.py
index a4b0eaa..c4311da 100644
--- a/examples/python/mesapy_deadloop_cancel.py
+++ b/examples/python/mesapy_deadloop_cancel.py
@@ -17,10 +17,9 @@
# specific language governing permissions and limitations
# under the License.
-import sys
import time
-from teaclave import FunctionInput, FunctionOutput, OwnerList, DataMap, TaskStatus
+from teaclave import TaskStatus
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
diff --git a/examples/python/mesapy_echo.py b/examples/python/mesapy_echo.py
index b60df00..7109374 100644
--- a/examples/python/mesapy_echo.py
+++ b/examples/python/mesapy_echo.py
@@ -19,7 +19,7 @@
import sys
-from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
+from teaclave import FunctionArgument
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
diff --git a/examples/python/mesapy_logistic_reg.py b/examples/python/mesapy_logistic_reg.py
index 2b8c8b8..409275e 100644
--- a/examples/python/mesapy_logistic_reg.py
+++ b/examples/python/mesapy_logistic_reg.py
@@ -20,11 +20,9 @@
An example about Logistic Regression in MesaPy.
"""
-import sys
-import binascii
from typing import List
from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from utils import connect_authentication_service, connect_frontend_service
from enum import Enum
@@ -79,13 +77,6 @@
self.label = label
-class DataList:
-
- def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
-
-
class ConfigClient:
def __init__(self, user_id, user_password):
@@ -161,7 +152,7 @@
key = da.file_key
iv = da.iv
input_id = client.register_input_file(url, schema, key, iv, cmac)
- input_data_list.append(DataList(da.label, input_id))
+ input_data_list.append(DataMap(da.label, input_id))
print(f"[+] {self.user_id} registering output file")
output_data_list = []
for out_data in outputs:
@@ -170,7 +161,7 @@
key = out_data.file_key
iv = out_data.iv
output_id = client.register_output_file(out_url, schema, key, iv)
- output_data_list.append(DataList(out_data.label, output_id))
+ output_data_list.append(DataMap(out_data.label, output_id))
print(f"[+] {self.user_id} assigning data to task")
client.assign_data_to_task(task_id, input_data_list, output_data_list)
diff --git a/examples/python/requirements.txt b/examples/python/requirements.txt
index 6f3b748..82d36f2 100644
--- a/examples/python/requirements.txt
+++ b/examples/python/requirements.txt
@@ -3,3 +3,6 @@
cryptography
requests
Pillow
+grpclib
+grpcio
+grpcio-tools
diff --git a/examples/python/test_disable_function.py b/examples/python/test_disable_function.py
index dbea609..9923f6e 100644
--- a/examples/python/test_disable_function.py
+++ b/examples/python/test_disable_function.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
-from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from teaclave import FunctionArgument
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
class UserData:
diff --git a/examples/python/utils.py b/examples/python/utils.py
index 41299fa..7011564 100644
--- a/examples/python/utils.py
+++ b/examples/python/utils.py
@@ -45,7 +45,7 @@
def __init__(self, user_id: str, user_password: str):
self.client = AuthenticationService(AUTHENTICATION_SERVICE_ADDRESS,
AS_ROOT_CA_CERT_PATH,
- ENCLAVE_INFO_PATH).connect()
+ ENCLAVE_INFO_PATH)
token = self.client.user_login(user_id, user_password)
self.client.metadata = {"id": user_id, "token": token}
@@ -55,10 +55,9 @@
def connect_authentication_service():
return AuthenticationService(AUTHENTICATION_SERVICE_ADDRESS,
- AS_ROOT_CA_CERT_PATH,
- ENCLAVE_INFO_PATH).connect()
+ AS_ROOT_CA_CERT_PATH, ENCLAVE_INFO_PATH)
def connect_frontend_service():
return FrontendService(FRONTEND_SERVICE_ADDRESS, AS_ROOT_CA_CERT_PATH,
- ENCLAVE_INFO_PATH).connect()
+ ENCLAVE_INFO_PATH)
diff --git a/examples/python/wasm_c_simple_add.py b/examples/python/wasm_c_simple_add.py
index 867cba3..ee6762b 100644
--- a/examples/python/wasm_c_simple_add.py
+++ b/examples/python/wasm_c_simple_add.py
@@ -19,7 +19,7 @@
import sys
-from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
+from teaclave import FunctionArgument
from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
diff --git a/examples/python/wasm_rust_psi.py b/examples/python/wasm_rust_psi.py
index 61fc1c5..60176e8 100644
--- a/examples/python/wasm_rust_psi.py
+++ b/examples/python/wasm_rust_psi.py
@@ -17,10 +17,8 @@
# specific language governing permissions and limitations
# under the License.
-import sys
-
from teaclave import FunctionInput, FunctionOutput, FunctionArgument, OwnerList, DataMap
-from utils import USER_ID, USER_PASSWORD, connect_authentication_service, connect_frontend_service, PlatformAdmin
+from utils import connect_authentication_service, connect_frontend_service, PlatformAdmin
class UserData:
@@ -66,13 +64,6 @@
[0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0])
-class DataList:
-
- def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
-
-
class Client:
def __init__(self, user_id, user_password):
@@ -168,8 +159,8 @@
output_id = client.register_output_file(url, schema, key, iv)
print(f"[+] {self.user_id} assigning data to task")
- client.assign_data_to_task(task_id, [DataList(input_label, input_id)],
- [DataList(output_label, output_id)])
+ client.assign_data_to_task(task_id, [DataMap(input_label, input_id)],
+ [DataMap(output_label, output_id)])
def approve_task(self, task_id):
client = self.client
diff --git a/sdk/python/teaclave.py b/sdk/python/teaclave.py
index 2fec9cd..c19d325 100644
--- a/sdk/python/teaclave.py
+++ b/sdk/python/teaclave.py
@@ -23,26 +23,33 @@
authentication service and frontend service) through RPC protocols.
"""
-import struct
import json
import base64
import toml
-import os
import time
+import os
import ssl
-import socket
-
-from typing import Tuple, Dict, List, Any
-from enum import IntEnum
import cryptography
from cryptography import x509
from cryptography.hazmat.backends import default_backend
+from google.protobuf.json_format import MessageToDict
+from grpclib.client import Channel, _ChannelState
+from grpclib.protocol import H2Protocol
+
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, FILETYPE_ASN1
from OpenSSL.crypto import X509Store, X509StoreContext
from OpenSSL import crypto
+import teaclave_authentication_service_pb2 as auth
+import teaclave_frontend_service_pb2 as fe
+from teaclave_authentication_service_grpc import TeaclaveAuthenticationApiStub
+from teaclave_frontend_service_grpc import TeaclaveFrontendStub
+from teaclave_common_pb2 import TaskStatus, FileCryptoInfo
+
+from typing import Tuple, Dict, List, Any
+
__all__ = [
'FrontendService', 'AuthenticationService', 'FunctionArgument',
'FunctionInput', 'FunctionOutput', 'OwnerList', 'DataMap'
@@ -51,19 +58,13 @@
Metadata = Dict[str, str]
-class TaskStatus(IntEnum):
- Created = 0
- DataAssigned = 1
- Approved = 2
- Staged = 3
- Running = 4
- Finished = 10
- Canceled = 20
- Failed = 99
-
-
class Request:
- pass
+ message = None
+
+ def __init__(self, method, response, metadata=dict()):
+ self.method = method
+ self.metadata = metadata
+ self.response = response
class TeaclaveException(Exception):
@@ -71,8 +72,8 @@
class TeaclaveService:
- channel = None
metadata = None
+ stub = None
def __init__(self,
name: str,
@@ -80,38 +81,79 @@
as_root_ca_cert_path: str,
enclave_info_path: str,
dump_report=False):
- self._context = ssl._create_unverified_context()
self._name = name
self._address = address
self._as_root_ca_cert_path = as_root_ca_cert_path
self._enclave_info_path = enclave_info_path
- self._closed = False
self._dump_report = dump_report
+ self._channel = TeaclaveChannel(self._name, self._address,
+ self._as_root_ca_cert_path,
+ self._enclave_info_path)
+ self._loop = self._channel._loop
+
+ def call_method(self, request):
+ return self._loop.run_until_complete(
+ getattr(self.stub, request.method)(request.message,
+ metadata=request.metadata))
+
def __enter__(self):
return self
def __exit__(self, *exc):
- if not self._closed:
- self.close()
+ self.close()
def close(self):
- self._closed = True
- if self.channel: self.channel.close()
+ if self._channel: self._channel.close()
- def check_channel(self):
- if not self.channel: raise TeaclaveException("Channel is None")
+ def __del__(self) -> None:
+ self.close()
def check_metadata(self):
if not self.metadata: raise TeaclaveException("Metadata is None")
- def connect(self):
- """Establish trusted connection and verify remote attestation report.
- """
- sock = socket.create_connection(self._address)
- channel = self._context.wrap_socket(sock,
- server_hostname=self._address[0])
- cert = channel.getpeercert(binary_form=True)
+ def check_channel(self):
+ self._channel.check_channel()
+
+ def get_metadata(self):
+ return self.metadata
+
+
+def create_context() -> ssl.SSLContext:
+ ctx = ssl._create_unverified_context()
+ ctx.options |= ssl.OP_NO_TLSv1 | ssl.OP_NO_TLSv1_1
+ ctx.set_ciphers('ECDHE+AESGCM:ECDHE+CHACHA20:DHE+AESGCM:DHE+CHACHA20')
+ ctx.set_alpn_protocols(['h2'])
+ try:
+ ctx.set_npn_protocols(['h2'])
+ except NotImplementedError:
+ pass
+ return ctx
+
+
+class TeaclaveChannel(Channel):
+
+ def __init__(self,
+ name: str,
+ address: Tuple[str, int],
+ as_root_ca_cert_path: str,
+ enclave_info_path: str,
+ dump_report=False):
+ context = create_context()
+ super().__init__(host=address[0], port=address[1], ssl=context)
+ self._name = name
+ self._as_root_ca_cert_path = as_root_ca_cert_path
+ self._enclave_info_path = enclave_info_path
+ self._dump_report = dump_report
+
+ def check_channel(self):
+ if self._state == _ChannelState.TRANSIENT_FAILURE:
+ raise TeaclaveException("Channel is None")
+
+ async def __connect__(self) -> H2Protocol:
+ protocol = await super().__connect__()
+ sslobj = protocol.connection._transport.get_extra_info('ssl_object')
+ cert = sslobj.getpeercert(binary_form=True)
if not cert: raise TeaclaveException("Peer cert is None")
try:
self._verify_report(self._as_root_ca_cert_path,
@@ -119,9 +161,7 @@
except Exception as e:
raise TeaclaveException(
f"Failed to verify attestation report: {e}")
- self.channel = channel
-
- return self
+ return protocol
def _verify_report(self, as_root_ca_cert_path: str, enclave_info_path: str,
cert: Dict[str, Any], endpoint_name: str):
@@ -244,9 +284,9 @@
"""
def __init__(self, name: str, description: str, optional=False):
- self.name = name
- self.description = description
- self.optional = optional
+ self.message = fe.FunctionInput(name=name,
+ description=description,
+ optional=optional)
class FunctionOutput:
@@ -260,9 +300,9 @@
"""
def __init__(self, name: str, description: str, optional=False):
- self.name = name
- self.description = description
- self.optional = optional
+ self.message = fe.FunctionOutput(name=name,
+ description=description,
+ optional=optional)
class FunctionArgument:
@@ -280,9 +320,9 @@
key: str,
default_value: str = "",
allow_overwrite=True):
- self.key = key
- self.default_value = default_value
- self.allow_overwrite = allow_overwrite
+ self.message = fe.FunctionArgument(key=key,
+ default_value=default_value,
+ allow_overwrite=allow_overwrite)
class OwnerList:
@@ -295,8 +335,7 @@
"""
def __init__(self, data_name: str, uids: List[str]):
- self.data_name = data_name
- self.uids = uids
+ self.message = fe.OwnerList(data_name=data_name, uids=uids)
class DataMap:
@@ -309,8 +348,7 @@
"""
def __init__(self, data_name, data_id):
- self.data_name = data_name
- self.data_id = data_id
+ self.message = fe.DataMap(data_name=data_name, data_id=data_id)
class CryptoInfo:
@@ -324,73 +362,70 @@
"""
def __init__(self, schema: str, key: List[int], iv: List[int]):
- self.schema = schema
- self.key = key
- self.iv = iv
+
+ self.message = FileCryptoInfo(schema=schema,
+ key=bytes(key),
+ iv=bytes(iv))
class UserRegisterRequest(Request):
def __init__(self, metadata: Metadata, user_id: str, user_password: str,
role: str, attribute: str):
- self.request = "user_register"
- self.metadata = metadata
- self.id = user_id
- self.password = user_password
- self.role = role
- self.attribute = attribute
+ super().__init__("UserRegister", auth.UserRegisterResponse, metadata)
+ self.message = auth.UserRegisterRequest(id=user_id,
+ password=user_password,
+ role=role,
+ attribute=attribute)
class UserUpdateRequest(Request):
def __init__(self, metadata: Metadata, user_id: str, user_password: str,
role: str, attribute: str):
- self.request = "user_update"
- self.metadata = metadata
- self.id = user_id
- self.password = user_password
- self.role = role
- self.attribute = attribute
+ super().__init__("UserUpdate", auth.UserUpdateResponse)
+ self.message = auth.UserUpdateRequest(id=user_id,
+ password=user_password,
+ role=role,
+ attribute=attribute)
class UserLoginRequest(Request):
def __init__(self, user_id: str, user_password: str):
- self.request = "user_login"
- self.id = user_id
- self.password = user_password
+ super().__init__("UserLogin", auth.UserLoginResponse)
+ self.message = auth.UserLoginRequest(id=user_id,
+ password=user_password)
class UserChangePasswordRequest(Request):
def __init__(self, metadata: Metadata, password: str):
- self.request = "user_change_password"
- self.metadata = metadata
- self.password = password
+ super().__init__("UserChangePassword", auth.UserChangePasswordResponse,
+ metadata)
+ self.message = auth.UserChangePasswordRequest(password=password)
class ResetUserPasswordRequest(Request):
def __init__(self, metadata: Metadata, user_id: str):
- self.request = "reset_user_password"
- self.metadata = metadata
- self.id = user_id
+ super().__init__("ResetUserPassword", auth.ResetUserPasswordResponse,
+ metadata)
+ self.message = auth.ResetUserPasswordRequest(id=user_id)
class DeleteUserRequest(Request):
def __init__(self, metadata: Metadata, user_id: str):
- self.request = "delete_user"
- self.metadata = metadata
- self.id = user_id
+ super().__init__("DeleteUser", auth.DeleteUserResponse, metadata)
+ self.message = auth.DeleteUserRequest(id=user_id)
class ListUsersRequest(Request):
def __init__(self, metadata: Metadata, user_id: str):
- self.request = "list_users"
- self.metadata = metadata
- self.id = user_id
+ super().__init__("ListUsers", auth.ListUsersResponse, metadata)
+ self.message = auth.ListUsersRequest(id=user_id)
class AuthenticationService(TeaclaveService):
@@ -414,9 +449,13 @@
dump_report=False):
super().__init__("authentication", address, as_root_ca_cert_path,
enclave_info_path, dump_report)
+ self.stub = TeaclaveAuthenticationApiStub(self._channel)
- def user_register(self, user_id: str, user_password: str, role: str,
- attribute: str):
+ def user_register(self,
+ user_id: str,
+ user_password: str,
+ role="",
+ attribute=""):
"""Register a new user.
Args:
@@ -430,18 +469,17 @@
self.check_metadata()
request = UserRegisterRequest(self.metadata, user_id, user_password,
role, attribute)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to register user ({reason})")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ raise TeaclaveException(f"Failed to register user {str(e)}")
- def user_update(self, user_id: str, user_password: str, role: str,
- attribute: str):
+ def user_update(self,
+ user_id: str,
+ user_password: str,
+ role: str,
+ attribute=""):
"""Update an existing user.
Args:
@@ -455,15 +493,11 @@
self.check_metadata()
request = UserUpdateRequest(self.metadata, user_id, user_password,
role, attribute)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to update user ({reason})")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ raise TeaclaveException(f"Failed to update user {str(e)}")
def user_login(self, user_id: str, user_password: str) -> str:
"""Login and get a session token.
@@ -477,17 +511,14 @@
str: User login token.
"""
- self.check_channel()
+ self._channel.check_channel()
request = UserLoginRequest(user_id, user_password)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["token"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to login user ({reason})")
+ try:
+ response = self.call_method(request)
+ self.metadata = {"id": user_id, "token": response.token}
+ return response.token
+ except Exception as e:
+ raise TeaclaveException(f"Failed to login user {str(e)}")
def user_change_password(self, user_password: str):
"""Change password.
@@ -499,15 +530,11 @@
self.check_channel()
self.check_metadata()
request = UserChangePasswordRequest(self.metadata, user_password)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to change password ({reason})")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ raise TeaclaveException(f"Failed to change password {str(e)}")
def reset_user_password(self, user_id: str) -> str:
"""Reset password of a managed user.
@@ -523,15 +550,12 @@
self.check_channel()
self.check_metadata()
request = ResetUserPasswordRequest(self.metadata, user_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["password"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to reset password ({reason})")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
+ raise TeaclaveException(f"Failed to reset password {reason}")
def delete_user(self, user_id: str) -> str:
"""Delete a user.
@@ -543,14 +567,11 @@
self.check_channel()
self.check_metadata()
request = DeleteUserRequest(self.metadata, user_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to delete user ({reason})")
def list_users(self, user_id: str) -> str:
@@ -565,15 +586,13 @@
str: User list
"""
self.check_channel()
+ self.check_metadata()
request = ListUsersRequest(self.metadata, user_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["ids"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to list user ({reason})")
@@ -584,18 +603,23 @@
arguments: List[FunctionArgument],
inputs: List[FunctionInput], outputs: List[FunctionOutput],
user_allowlist: List[str], usage_quota: int):
- self.request = "register_function"
- self.metadata = metadata
- self.name = name
- self.description = description
- self.executor_type = executor_type
- self.public = public
- self.payload = payload
- self.arguments = arguments
- self.inputs = inputs
- self.outputs = outputs
- self.user_allowlist = user_allowlist
- self.usage_quota = usage_quota
+ super().__init__("RegisterFunction", fe.RegisterFunctionResponse,
+ metadata)
+ arguments = [x.message for x in arguments]
+ inputs = [x.message for x in inputs]
+ outputs = [x.message for x in outputs]
+
+ self.message = fe.RegisterFunctionRequest(
+ name=name,
+ description=description,
+ executor_type=executor_type,
+ public=public,
+ payload=bytes(payload),
+ arguments=arguments,
+ inputs=inputs,
+ outputs=outputs,
+ user_allowlist=user_allowlist,
+ usage_quota=usage_quota)
class UpdateFunctionRequest(Request):
@@ -605,97 +629,87 @@
payload: List[int], arguments: List[FunctionArgument],
inputs: List[FunctionInput], outputs: List[FunctionOutput],
user_allowlist: List[str], usage_quota: int):
- self.request = "update_function"
- self.metadata = metadata
- self.function_id = function_id
- self.name = name
- self.description = description
- self.executor_type = executor_type
- self.public = public
- self.payload = payload
- self.arguments = arguments
- self.inputs = inputs
- self.outputs = outputs
- self.user_allowlist = user_allowlist
- self.usage_quota = usage_quota
+ super().__init__("UpdateFunction", fe.UpdateFunctionResponse, metadata)
+ arguments = [x.message for x in arguments]
+ inputs = [x.message for x in inputs]
+ outputs = [x.message for x in outputs]
+
+ self.message = fe.UpdateFunctionRequest(function_id, name, description,
+ executor_type, public, payload,
+ arguments, inputs, outputs,
+ user_allowlist, usage_quota)
class ListFunctionsRequest(Request):
def __init__(self, metadata: Metadata, user_id: str):
- self.request = "list_functions"
- self.metadata = metadata
- self.user_id = user_id
+ super().__init__("ListFunctions", fe.ListFunctionsResponse, metadata)
+ self.message = fe.ListFunctionsRequest(user_id=user_id)
class DeleteFunctionRequest(Request):
def __init__(self, metadata: Metadata, function_id: str):
- self.request = "delete_function"
- self.metadata = metadata
- self.function_id = function_id
+ super().__init__("ListFunctions", fe.DeleteFunctionResponse, metadata)
+ self.message = fe.DeleteFunctionRequest(function_id=function_id)
class DisableFunctionRequest(Request):
def __init__(self, metadata: Metadata, function_id: str):
- self.request = "disable_function"
- self.metadata = metadata
- self.function_id = function_id
+ super().__init__("DisableFunction", fe.DisableFunctionResponse,
+ metadata)
+ self.message = fe.DisableFunctionRequest(function_id=function_id)
class GetFunctionRequest(Request):
def __init__(self, metadata: Metadata, function_id: str):
- self.request = "get_function"
- self.metadata = metadata
- self.function_id = function_id
+ super().__init__("GetFunction", fe.GetFunctionResponse, metadata)
+ self.message = fe.GetFunctionRequest(function_id=function_id)
class GetFunctionUsageStatsRequest(Request):
def __init__(self, metadata: Metadata, function_id: str):
- self.request = "get_function_usage_stats"
- self.metadata = metadata
- self.function_id = function_id
+ super().__init__("GetFunctionUsageStats",
+ fe.GetFunctionUsageStatsResponse, metadata)
+ self.message = fe.GetFunctionUsageStatsRequest(function_id=function_id)
class RegisterInputFileRequest(Request):
def __init__(self, metadata: Metadata, url: str, cmac: List[int],
crypto_info: CryptoInfo):
- self.request = "register_input_file"
- self.metadata = metadata
- self.url = url
- self.cmac = cmac
- self.crypto_info = crypto_info
+ super().__init__("RegisterInputFile", fe.RegisterInputFileResponse,
+ metadata)
+ self.message = fe.RegisterInputFileRequest(
+ url=url, cmac=bytes(cmac), crypto_info=crypto_info.message)
class RegisterOutputFileRequest(Request):
def __init__(self, metadata: Metadata, url: str, crypto_info: CryptoInfo):
- self.request = "register_output_file"
- self.metadata = metadata
- self.url = url
- self.crypto_info = crypto_info
+ super().__init__("RegisterOutputFile", fe.RegisterOutputFileResponse,
+ metadata)
+ self.message = fe.RegisterOutputFileRequest(
+ url=url, crypto_info=crypto_info.message)
class UpdateInputFileRequest(Request):
def __init__(self, metadata: Metadata, data_id: str, url: str):
- self.request = "update_input_file"
- self.metadata = metadata
- self.data_id = data_id
- self.url = url
+ super().__init__("UpdateInputFile", fe.UpdateInputFileResponse,
+ metadata)
+ self.message = fe.UpdateInputFileRequest(data_id=data_id, url=url)
class UpdateOutputFileRequest(Request):
def __init__(self, metadata: Metadata, data_id: str, url: str):
- self.request = "update_output_file"
- self.metadata = metadata
- self.data_id = data_id
- self.url = url
+ super().__init__("UpdateInputFile", fe.UpdateOutputFileResponse,
+ metadata)
+ self.message = fe.UpdateOutputFileRequest(data_id=data_id, url=url)
class CreateTaskRequest(Request):
@@ -704,56 +718,56 @@
function_arguments: Dict[str, Any], executor: str,
inputs_ownership: List[OwnerList],
outputs_ownership: List[OwnerList]):
- self.request = "create_task"
- self.metadata = metadata
- self.function_id = function_id
- self.function_arguments = function_arguments
- self.executor = executor
- self.inputs_ownership = inputs_ownership
- self.outputs_ownership = outputs_ownership
+ super().__init__("CreateTask", fe.CreateTaskResponse, metadata)
+ inputs_ownership = [x.message for x in inputs_ownership]
+ outputs_ownership = [x.message for x in outputs_ownership]
+
+ self.message = fe.CreateTaskRequest(
+ function_id=function_id,
+ function_arguments=function_arguments,
+ executor=executor,
+ inputs_ownership=inputs_ownership,
+ outputs_ownership=outputs_ownership)
class AssignDataRequest(Request):
def __init__(self, metadata: Metadata, task_id: str, inputs: List[DataMap],
outputs: List[DataMap]):
- self.request = "assign_data"
- self.metadata = metadata
- self.task_id = task_id
- self.inputs = inputs
- self.outputs = outputs
+ super().__init__("AssignData", fe.AssignDataResponse, metadata)
+ inputs = [x.message for x in inputs]
+ outputs = [x.message for x in outputs]
+ self.message = fe.AssignDataRequest(task_id=task_id,
+ inputs=inputs,
+ outputs=outputs)
class ApproveTaskRequest(Request):
def __init__(self, metadata: Metadata, task_id: str):
- self.request = "approve_task"
- self.metadata = metadata
- self.task_id = task_id
+ super().__init__("ApproveTask", fe.ApproveTaskResponse, metadata)
+ self.message = fe.ApproveTaskRequest(task_id=task_id)
class InvokeTaskRequest(Request):
def __init__(self, metadata: Metadata, task_id: str):
- self.request = "invoke_task"
- self.metadata = metadata
- self.task_id = task_id
+ super().__init__("InvokeTask", fe.InvokeTaskResponse, metadata)
+ self.message = fe.InvokeTaskRequest(task_id=task_id)
class CancelTaskRequest(Request):
def __init__(self, metadata: Metadata, task_id: str):
- self.request = "cancel_task"
- self.metadata = metadata
- self.task_id = task_id
+ super().__init__("CancelTask", fe.CancelTaskResponse, metadata)
+ self.message = fe.CancelTaskRequest(task_id=task_id)
class GetTaskRequest(Request):
def __init__(self, metadata: Metadata, task_id: str):
- self.request = "get_task"
- self.metadata = metadata
- self.task_id = task_id
+ super().__init__("GetTask", fe.GetTaskResponse, metadata)
+ self.message = fe.GetTaskRequest(task_id=task_id)
class FrontendService(TeaclaveService):
@@ -776,6 +790,7 @@
dump_report=False):
super().__init__("frontend", address, as_root_ca_cert_path,
enclave_info_path, dump_report)
+ self.stub = TeaclaveFrontendStub(self._channel)
def register_function(
self,
@@ -796,14 +811,11 @@
executor_type, public, payload,
arguments, inputs, outputs,
user_allowlist, usage_quota)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["function_id"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response.function_id
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to register function ({reason})")
def update_function(
@@ -826,53 +838,45 @@
description, executor_type, public,
payload, arguments, inputs, outputs,
user_allowlist, usage_quota)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["function_id"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
- raise TeaclaveException(f"Failed to update function ({reason})")
+ try:
+ response = self.call_method(request)
+ return response.function_id
+ except Exception as e:
+ reason = str(e)
+ raise TeaclaveException(f"Failed to register function ({reason})")
def list_functions(self, user_id: str):
self.check_metadata()
self.check_channel()
request = ListFunctionsRequest(self.metadata, user_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]
- else:
- raise TeaclaveException("Failed to list functions")
+ try:
+ response = self.call_method(request)
+ except Exception as e:
+ raise TeaclaveException(f"Failed to list functions ({str(e)})")
+ return MessageToDict(response,
+ preserving_proto_field_name=True,
+ use_integers_for_enums=True)
def get_function(self, function_id: str):
self.check_metadata()
self.check_channel()
request = GetFunctionRequest(self.metadata, function_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to get function ({reason})")
- def get_function_usage_stats(self, user_id: str, function_id: str):
+ def get_function_usage_stats(self, function_id: str):
self.check_metadata()
self.check_channel()
request = GetFunctionUsageStatsRequest(self.metadata, function_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
f"Failed to get function usage statistics ({reason})")
@@ -880,23 +884,23 @@
self.check_metadata()
self.check_channel()
request = DeleteFunctionRequest(self.metadata, function_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]
- else:
- raise TeaclaveException("Failed to delete function")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
+ raise TeaclaveException(f"Failed to delete function ({reason})")
def disable_function(self, function_id: str):
self.check_metadata()
self.check_channel()
request = DisableFunctionRequest(self.metadata, function_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]
- else:
- raise TeaclaveException("Failed to disable function")
+ try:
+ response = self.call_method(request)
+ return response
+ except Exception as e:
+ reason = str(e)
+ raise TeaclaveException(f"Failed to disable function ({reason})")
def register_input_file(self, url: str, schema: str, key: List[int],
iv: List[int], cmac: List[int]):
@@ -904,14 +908,11 @@
self.check_channel()
request = RegisterInputFileRequest(self.metadata, url, cmac,
CryptoInfo(schema, key, iv))
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["data_id"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response.data_id
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
f"Failed to register input file ({reason})")
@@ -921,14 +922,11 @@
self.check_channel()
request = RegisterOutputFileRequest(self.metadata, url,
CryptoInfo(schema, key, iv))
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["data_id"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response.data_id
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
f"Failed to register output file ({reason})")
@@ -944,14 +942,11 @@
request = CreateTaskRequest(self.metadata, function_id,
function_arguments, executor,
inputs_ownership, outputs_ownership)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- return response["content"]["task_id"]
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return response.task_id
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to create task ({reason})")
def assign_data_to_task(self, task_id: str, inputs: List[DataMap],
@@ -959,14 +954,10 @@
self.check_metadata()
self.check_channel()
request = AssignDataRequest(self.metadata, task_id, inputs, outputs)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ self.call_method(request)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
f"Failed to assign data to task ({reason})")
@@ -974,137 +965,83 @@
self.check_metadata()
self.check_channel()
request = ApproveTaskRequest(self.metadata, task_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ self.call_method(request)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to approve task ({reason})")
def invoke_task(self, task_id: str):
self.check_metadata()
self.check_channel()
request = InvokeTaskRequest(self.metadata, task_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ self.call_method(request)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to invoke task ({reason})")
def cancel_task(self, task_id: str):
self.check_metadata()
self.check_channel()
request = CancelTaskRequest(self.metadata, task_id)
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] == "ok":
- pass
- else:
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ self.call_method(request)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to cancel task ({reason})")
- def get_task(self, task_id: str) -> dict:
+ def get_task(self, task_id: str):
self.check_metadata()
self.check_channel()
request = GetTaskRequest(self.metadata, task_id)
-
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] != "ok":
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ response = self.call_method(request)
+ return MessageToDict(response,
+ preserving_proto_field_name=True,
+ use_integers_for_enums=True)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(f"Failed to get task result ({reason})")
- return response["content"]
def get_task_result(self, task_id: str):
self.check_metadata()
self.check_channel()
request = GetTaskRequest(self.metadata, task_id)
-
while True:
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] != "ok":
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ time.sleep(1)
+ response = self.call_method(request)
+ if response.status == TaskStatus.Finished:
+ break
+ elif response.status == TaskStatus.Canceled:
+ raise TeaclaveException("Task Canceled, Error: " +
+ response.result.Err.reason)
+ elif response.status == TaskStatus.Failed:
+ raise TeaclaveException("Task Failed, Error: " +
+ response.result.Err.reason)
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
f"Failed to get task result ({reason})")
- time.sleep(1)
- if response["content"]["status"] == TaskStatus.Finished:
- break
- elif response["content"]["status"] == TaskStatus.Canceled:
- raise TeaclaveException(
- "Task Canceled, Error: " +
- response["content"]["result"]["result"]["Err"]["reason"])
- elif response["content"]["status"] == TaskStatus.Failed:
- raise TeaclaveException(
- "Task Failed, Error: " +
- response["content"]["result"]["result"]["Err"]["reason"])
- return response["content"]["result"]["result"]["Ok"]["return_value"]
+ return response.result.Ok.return_value
def get_output_cmac_by_tag(self, task_id: str, tag: str):
self.check_metadata()
self.check_channel()
request = GetTaskRequest(self.metadata, task_id)
while True:
- _write_message(self.channel, request)
- response = _read_message(self.channel)
- if response["result"] != "ok":
- reason = "unknown"
- if "request_error" in response:
- reason = response["request_error"]
+ try:
+ time.sleep(1)
+ response = self.call_method(request)
+ if response.status == TaskStatus.Finished:
+ break
+ except Exception as e:
+ reason = str(e)
raise TeaclaveException(
- f"Failed to get output cmac by tag ({reason})")
- time.sleep(1)
- if response["content"]["status"] == TaskStatus.Finished:
- break
-
- return response["content"]["result"]["result"]["Ok"]["tags_map"][tag]
-
-
-def _write_message(sock: ssl.SSLSocket, message: Any):
-
- class RequestEncoder(json.JSONEncoder):
-
- def default(self, o):
- if isinstance(o, Request):
- request = o.__dict__["request"]
- j = {}
- j["message"] = {}
- j["message"][request] = {}
- for k, v in o.__dict__.items():
- if k == "metadata": j[k] = v
- elif k == "request": continue
- else: j["message"][request][k] = v
- return j
- else:
- return o.__dict__
-
- message = json.dumps(message, cls=RequestEncoder,
- separators=(',', ':')).encode()
- sock.sendall(struct.pack(">Q", len(message)))
- sock.sendall(message)
-
-
-def _read_message(sock: ssl.SSLSocket):
- response_len = struct.unpack(">Q", sock.read(8))
- raw = bytearray()
- total_recv = 0
- while total_recv < response_len[0]:
- data = sock.recv()
- total_recv += len(data)
- raw += data
- response = json.loads(raw)
- return response
+ f"Failed to get task result ({reason})")
+ response = MessageToDict(response,
+ preserving_proto_field_name=True,
+ use_integers_for_enums=True)
+ return base64.b64decode(response["result"]["Ok"]["tags_map"][tag])
diff --git a/tests/scripts/functional_tests.py b/tests/scripts/functional_tests.py
index b7e248b..04bb223 100755
--- a/tests/scripts/functional_tests.py
+++ b/tests/scripts/functional_tests.py
@@ -33,6 +33,14 @@
from OpenSSL.crypto import X509Store, X509StoreContext
from OpenSSL import crypto
+import h2.connection
+import h2.events
+
+from io import BytesIO
+from h2.config import H2Configuration
+from urllib.parse import unquote
+from teaclave_authentication_service_pb2 import UserLoginRequest, UserLoginResponse
+
HOSTNAME = 'localhost'
AUTHENTICATION_SERVICE_ADDRESS = (HOSTNAME, 7776)
CONTEXT = ssl._create_unverified_context()
@@ -52,19 +60,6 @@
ENCLAVE_INFO_PATH = "../../release/tests/enclave_info.toml"
-def write_message(sock, message):
- message = json.dumps(message)
- message = message.encode()
- sock.write(struct.pack(">Q", len(message)))
- sock.write(message)
-
-
-def read_message(sock):
- response_len = struct.unpack(">Q", sock.read(8))
- response = sock.read(response_len[0])
- return response
-
-
def verify_report(cert, endpoint_name):
def load_certificates(pem_bytes):
@@ -121,50 +116,114 @@
raise Exception("mr_signer error")
+def encode_message(message):
+ message_bin = message.SerializeToString()
+ header = struct.pack('?', False) + struct.pack('>I', len(message_bin))
+ return header + message_bin
+
+
+def decode_message(message_bin, message_type):
+ f = BytesIO(message_bin)
+ meta = f.read(5)
+ message_len = struct.unpack('>I', meta[1:])[0]
+ message_body = f.read(message_len)
+ message = message_type.FromString(message_body)
+ return message
+
+
class TestAuthenticationService(unittest.TestCase):
def setUp(self):
sock = socket.create_connection(AUTHENTICATION_SERVICE_ADDRESS)
+ CONTEXT.set_alpn_protocols(['h2'])
self.socket = CONTEXT.wrap_socket(sock, server_hostname=HOSTNAME)
cert = self.socket.getpeercert(binary_form=True)
verify_report(cert, "authentication")
+ config = H2Configuration(client_side=True, header_encoding='ascii')
+ self.connection = h2.connection.H2Connection(config)
+ self.connection.initiate_connection()
+ self.socket.sendall(self.connection.data_to_send())
+ self.stream_id = 1
+
+ def set_headers(self, method_path):
+ headers = [(':method', 'POST'), (':path', method_path),
+ (':authority', HOSTNAME), (':scheme', 'https'),
+ ('content-type', 'application/grpc')]
+ return headers
+
+ def send_message(self, message, method_path):
+ headers = self.set_headers(method_path)
+ self.connection.send_headers(self.stream_id, headers)
+ message_data = encode_message(message)
+ self.connection.send_data(self.stream_id,
+ message_data,
+ end_stream=True)
+ self.socket.sendall(self.connection.data_to_send())
+
+ def recv_message(self):
+ body = None
+ headers = None
+ response_stream_ended = False
+ max_frame_size = self.connection.max_outbound_frame_size
+ print(max_frame_size)
+ while not response_stream_ended:
+ # read raw data from the socket
+ data = self.socket.recv(max_frame_size)
+ if not data:
+ break
+
+ # feed raw data into h2, and process resulting events
+ events = self.connection.receive_data(data)
+ for event in events:
+ if isinstance(event, h2.events.ResponseReceived):
+ headers = dict(event.headers)
+ if isinstance(event, h2.events.DataReceived):
+ # update flow control so the server doesn't starve us
+ self.connection.acknowledge_received_data(
+ event.flow_controlled_length, event.stream_id)
+ # more response body data received
+ body += event.data
+ if isinstance(event, h2.events.StreamEnded):
+ # response body completed, let's exit the loop
+ response_stream_ended = True
+ break
+ # send any pending data to the server
+ self.socket.sendall(self.connection.data_to_send())
+ return (headers, body)
def tearDown(self):
+ self.connection.close_connection()
+ self.socket.sendall(self.connection.data_to_send())
self.socket.close()
def test_invalid_request(self):
+ path = '/teaclave_authentication_service_proto.TeaclaveAuthenticationApi/InvalidRequest'
user_id = "invalid_id"
user_password = "invalid_password"
- message = {
- "invalid_request": "user_login",
- "id": user_id,
- "password": user_password
- }
- write_message(self.socket, message)
+ message = UserLoginRequest(id=user_id, password=user_password)
+ self.send_message(message, path)
- response = read_message(self.socket)
- self.assertEqual(
- response, b'{"result":"err","request_error":"invalid request"}')
+ (headers, response) = self.recv_message()
+ self.assertEqual(response, None)
+ # https://grpc.github.io/grpc/core/md_doc_statuscodes.html
+ # grpc status UNIMPLEMENTED: 12
+ self.assertEqual(headers['grpc-status'], '12')
def test_login_permission_denied(self):
+ path = '/teaclave_authentication_service_proto.TeaclaveAuthenticationApi/UserLogin'
user_id = "invalid_id"
user_password = "invalid_password"
- message = {
- "message": {
- "user_login": {
- "id": user_id,
- "password": user_password
- }
- }
- }
- write_message(self.socket, message)
-
- response = read_message(self.socket)
- self.assertEqual(
- response,
- b'{"result":"err","request_error":"authentication failed"}')
+ message = UserLoginRequest(id=user_id, password=user_password)
+ self.send_message(message, path)
+ (headers, body) = self.recv_message()
+ self.assertEqual(body, None)
+ self.assertEqual(headers['grpc-status'], '16')
+ message = unquote(headers['grpc-message'],
+ encoding='utf-8',
+ errors='replace')
+ self.assertEqual(message, 'authentication failed')
if __name__ == '__main__':