blob: b7e248b9b992197f4a4c2ee33a60a1782652d14b [file] [log] [blame]
#!/usr/bin/env python3
# Licensed to the Apache Software Foundation (ASF) under one
# or more contributor license agreements. See the NOTICE file
# distributed with this work for additional information
# regarding copyright ownership. The ASF licenses this file
# to you under the Apache License, Version 2.0 (the
# "License"); you may not use this file except in compliance
# with the License. You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing,
# software distributed under the License is distributed on an
# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
import unittest
import socket
import struct
import ssl
import json
import base64
import toml
import os
from cryptography import x509
from cryptography.hazmat.backends import default_backend
from OpenSSL.crypto import load_certificate, FILETYPE_PEM, FILETYPE_ASN1
from OpenSSL.crypto import X509Store, X509StoreContext
from OpenSSL import crypto
# Host and port of the authentication service exercised by these tests.
HOSTNAME = 'localhost'
AUTHENTICATION_SERVICE_ADDRESS = (HOSTNAME, 7776)

# TLS context that performs no certificate-chain or hostname validation; the
# peer certificate is instead inspected explicitly through verify_report().
CONTEXT = ssl._create_unverified_context()

# File name of the attestation-service root CA certificate: the DCAP one when
# the DCAP environment variable is set to a non-empty value, else the IAS one.
AS_ROOT_CERT_FILENAME = ("dcap_root_ca_cert.pem"
                         if os.environ.get('DCAP') else "ias_root_ca_cert.pem")

# Key and enclave-info locations are rooted at TEACLAVE_PROJECT_ROOT when that
# variable is set (and non-empty); otherwise they are taken relative to the
# current working directory, two levels up.
_PROJECT_ROOT = os.environ.get('TEACLAVE_PROJECT_ROOT') or "../.."
AS_ROOT_CA_CERT_PATH = _PROJECT_ROOT + "/keys/" + AS_ROOT_CERT_FILENAME
ENCLAVE_INFO_PATH = _PROJECT_ROOT + "/release/tests/enclave_info.toml"
def write_message(sock, message):
    """Serialize ``message`` as JSON and send it as one length-prefixed frame.

    The frame is an 8-byte big-endian unsigned length followed by the
    encoded JSON payload; both parts are passed to ``sock.write``.

    Args:
        sock: Object exposing a ``write(bytes)`` method.
        message: JSON-serializable object to send.
    """
    payload = json.dumps(message).encode()
    header = struct.pack(">Q", len(payload))
    sock.write(header)
    sock.write(payload)
def _read_exact(sock, size):
    """Read exactly ``size`` bytes from ``sock``, looping over short reads."""
    chunks = []
    remaining = size
    while remaining > 0:
        chunk = sock.read(remaining)
        if not chunk:
            # An empty read means the peer closed the stream; without this
            # check the loop would spin forever on a truncated frame.
            raise EOFError("connection closed with %d of %d bytes unread" %
                           (remaining, size))
        chunks.append(chunk)
        remaining -= len(chunk)
    return b"".join(chunks)


def read_message(sock):
    """Read one length-prefixed frame from ``sock`` and return its payload.

    A frame is an 8-byte big-endian unsigned length followed by that many
    payload bytes. ``sock.read(n)`` is allowed to return fewer than ``n``
    bytes, so both the header and the payload are read in a loop until
    complete.

    Args:
        sock: Object exposing a ``read(n)`` method that returns bytes.

    Returns:
        The raw payload bytes (not decoded or parsed).

    Raises:
        EOFError: If the stream ends before a full frame has been read.
    """
    header = _read_exact(sock, 8)
    (payload_len,) = struct.unpack(">Q", header)
    return _read_exact(sock, payload_len)
def verify_report(cert, endpoint_name):
    """Verify the attestation report embedded in a service's TLS certificate.

    The value of the certificate's first extension is parsed as JSON with the
    keys ``report``, ``signature`` and ``certs``. The first entry of ``certs``
    is treated as the report-signing certificate: it is validated against the
    attestation-service root CA, the report signature is checked with it, and
    the enclave measurements found in the report's quote are compared with
    the expected values from the enclave info file.

    Verification is skipped entirely when the ``SGX_MODE`` environment
    variable equals ``SW``.

    Args:
        cert: DER-encoded peer certificate, as bytes.
        endpoint_name: Service name; the expected measurements are looked up
            under ``teaclave_<endpoint_name>_service`` in the enclave info.

    Raises:
        Exception: If no signing certificate is supplied, or if mr_enclave
            or mr_signer does not match the expected value.
        Errors raised by the certificate/signature checks (pyOpenSSL) and by
        JSON/TOML parsing propagate unchanged.
    """
    if os.environ.get('SGX_MODE') == 'SW':
        return

    peer_cert = x509.load_der_x509_certificate(cert, default_backend())
    ext = json.loads(peer_cert.extensions[0].value.value)

    report = bytes(ext["report"])
    signature = bytes(ext["signature"])
    certs = [load_certificate(FILETYPE_ASN1, bytes(c)) for c in ext["certs"]]
    if not certs:
        raise Exception("attestation report carries no signing certificate")
    signing_cert, intermediates = certs[0], certs[1:]

    # verify signing cert with AS root cert
    with open(AS_ROOT_CA_CERT_PATH, "rb") as f:
        as_root_ca_cert = load_certificate(FILETYPE_PEM, f.read())

    # Only the locally stored root CA is a trust anchor. Certificates that
    # arrived with the report are peer-controlled, so they are offered as an
    # untrusted chain rather than added to the store, and the certificate
    # being verified is the signing certificate, not the root itself.
    # NOTE(review): the ``chain`` argument needs pyOpenSSL >= 20.0 -- confirm
    # against the version pinned for this project.
    store = X509Store()
    store.add_cert(as_root_ca_cert)
    store_ctx = X509StoreContext(store, signing_cert,
                                 chain=intermediates or None)
    store_ctx.verify_certificate()

    # verify report's signature
    crypto.verify(signing_cert, signature, report, 'sha256')

    quote = base64.b64decode(json.loads(report)['isvEnclaveQuoteBody'])

    # get mr_enclave and mr_signer from the quote (32 bytes each, at fixed
    # offsets 112 and 176 of the decoded quote body)
    mr_enclave = quote[112:112 + 32].hex()
    mr_signer = quote[176:176 + 32].hex()

    # get enclave_info
    enclave_info = toml.load(ENCLAVE_INFO_PATH)

    # verify mr_enclave and mr_signer
    enclave_name = "teaclave_" + endpoint_name + "_service"
    if mr_enclave != enclave_info[enclave_name]["mr_enclave"]:
        raise Exception("mr_enclave error")
    if mr_signer != enclave_info[enclave_name]["mr_signer"]:
        raise Exception("mr_signer error")
class TestAuthenticationService(unittest.TestCase):
    """Request/response tests against a running authentication service.

    Every test gets a fresh TLS connection to AUTHENTICATION_SERVICE_ADDRESS
    whose peer certificate has been checked with verify_report().
    """

    def setUp(self):
        sock = socket.create_connection(AUTHENTICATION_SERVICE_ADDRESS)
        self.socket = CONTEXT.wrap_socket(sock, server_hostname=HOSTNAME)
        try:
            cert = self.socket.getpeercert(binary_form=True)
            verify_report(cert, "authentication")
        except BaseException:
            # unittest does not call tearDown() when setUp() raises, so the
            # connection has to be closed here to avoid leaking it.
            self.socket.close()
            raise

    def tearDown(self):
        self.socket.close()

    def test_invalid_request(self):
        # A message without the expected top-level "message" key must be
        # rejected as an invalid request.
        user_id = "invalid_id"
        user_password = "invalid_password"

        message = {
            "invalid_request": "user_login",
            "id": user_id,
            "password": user_password
        }

        write_message(self.socket, message)
        response = read_message(self.socket)

        self.assertEqual(
            response, b'{"result":"err","request_error":"invalid request"}')

    def test_login_permission_denied(self):
        # A well-formed login with unknown credentials must fail
        # authentication.
        user_id = "invalid_id"
        user_password = "invalid_password"

        message = {
            "message": {
                "user_login": {
                    "id": user_id,
                    "password": user_password
                }
            }
        }

        write_message(self.socket, message)
        response = read_message(self.socket)

        self.assertEqual(
            response,
            b'{"result":"err","request_error":"authentication failed"}')
# Run the test cases with unittest's command-line runner when this file is
# executed directly rather than imported.
if __name__ == "__main__":
    unittest.main()