blob: 5519f16188bf46866da6b45ac680450a0b55071e [file] [log] [blame]
import os
import re
from datetime import datetime, timedelta, timezone
from typing import Any, List, Optional

from cryptography import x509
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric import ec, rsa
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from cryptography.hazmat.primitives.serialization import Encoding, PrivateFormat, NoEncryption, load_pem_private_key
from cryptography.x509 import ExtendedKeyUsageOID, NameOID
# Elliptic curves that _private_key() accepts by name, keyed by the
# upper-cased curve name; the values are the curve classes themselves.
EC_SUPPORTED = {
    curve.name.upper(): curve
    for curve in (ec.SECP192R1, ec.SECP224R1, ec.SECP256R1, ec.SECP384R1)
}
def _private_key(key_type):
    """Generate a new private key of the requested type.

    :param key_type: one of
        * an ``int`` or a string such as ``"2048"`` / ``"rsa2048"``
          (case-insensitive): an RSA key of that size, public exponent 65537;
        * a curve name listed in ``EC_SUPPORTED`` (case-insensitive);
        * an ``ec.EllipticCurve`` instance or ``ec.EllipticCurve`` subclass.
    :returns: the generated RSA or EC private key.
    :raises ValueError: if ``key_type`` matches none of the forms above.
    """
    if isinstance(key_type, str):
        key_type = key_type.upper()
        m = re.match(r'^(RSA)?(\d+)$', key_type)
        if m:
            key_type = int(m.group(2))
    if isinstance(key_type, int):
        return rsa.generate_private_key(
            public_exponent=65537,
            key_size=key_type,
            backend=default_backend()
        )
    curve = EC_SUPPORTED.get(key_type) if isinstance(key_type, str) else key_type
    # EC_SUPPORTED holds curve *classes*; ec.generate_private_key() expects an
    # instance (passing a class is deprecated in recent cryptography releases).
    if isinstance(curve, type) and issubclass(curve, ec.EllipticCurve):
        curve = curve()
    if not isinstance(curve, ec.EllipticCurve):
        raise ValueError(f"unsupported key type: {key_type!r}")
    return ec.generate_private_key(
        curve=curve,
        backend=default_backend()
    )
class CertificateSpec:
    """Declarative description of a certificate to be issued.

    The kind of certificate follows from the fields (see ``type``):
    non-empty ``domains`` -> server certificate, ``client=True`` -> client
    certificate, only a ``name`` -> (sub-)CA. Certificates described by
    ``sub_specs`` are issued by the certificate created from this spec.
    """

    def __init__(self, name: Optional[str] = None, domains: Optional[List[str]] = None,
                 email: Optional[str] = None,
                 key_type: Optional[str] = None, single_file: bool = False,
                 valid_from: timedelta = timedelta(days=-1),
                 valid_to: timedelta = timedelta(days=89),
                 client: bool = False,
                 sub_specs: Optional[List['CertificateSpec']] = None):
        self._name = name
        self.domains = domains          # DNS names for a server certificate
        self.client = client            # request a client certificate
        self.email = email              # RFC 822 subjectAltName of a client certificate
        self.key_type = key_type        # e.g. "rsa2048"; None means "same as the issuer"
        self.single_file = single_file  # store the private key inside the certificate file
        self.valid_from = valid_from    # offset from "now" at which validity starts
        self.valid_to = valid_to        # offset from "now" at which validity ends
        self.sub_specs = sub_specs      # certificates to issue from the resulting one

    @property
    def name(self) -> Optional[str]:
        """Explicit name if given, else the first domain, else None."""
        if self._name:
            return self._name
        elif self.domains:
            return self.domains[0]
        return None

    @property
    def type(self) -> Optional[str]:
        """Kind of certificate: "server", "client", "ca", or None if undeterminable."""
        if self.domains:
            return "server"
        elif self.client:
            return "client"
        elif self.name:
            return "ca"
        return None

    def __repr__(self) -> str:
        # Used in error messages such as "unrecognized certificate specification".
        return (f"{self.__class__.__name__}(name={self.name!r}, type={self.type!r}, "
                f"domains={self.domains!r}, client={self.client!r}, "
                f"email={self.email!r}, key_type={self.key_type!r})")
class Credentials:
    """A certificate with its private key and the credentials that issued it.

    Optionally bound to a ``CertStore`` (used for lookups and to persist
    certificates issued from here) and to the PEM files it was saved to.
    """

    def __init__(self, name: str, cert: Any, pkey: Any, issuer: Optional['Credentials'] = None):
        self._name = name
        self._cert = cert
        self._pkey = pkey
        self._issuer = issuer  # None for a root created without an issuer
        self._cert_file = None
        self._pkey_file = None
        self._store = None

    @property
    def name(self) -> str:
        return self._name

    @property
    def subject(self) -> x509.Name:
        """Subject name of the certificate."""
        return self._cert.subject

    @property
    def key_type(self) -> str:
        """``"rsa<bits>"`` for RSA keys, the curve's name for EC keys.

        :raises Exception: for any other kind of private key.
        """
        if isinstance(self._pkey, RSAPrivateKey):
            return f"rsa{self._pkey.key_size}"
        if isinstance(self._pkey, EllipticCurvePrivateKey):
            return f"{self._pkey.curve.name}"
        raise Exception(f"unknown key type: {self._pkey}")

    @property
    def private_key(self) -> Any:
        return self._pkey

    @property
    def certificate(self) -> Any:
        return self._cert

    @property
    def cert_pem(self) -> bytes:
        """The certificate, PEM-encoded."""
        return self._cert.public_bytes(Encoding.PEM)

    @property
    def pkey_pem(self) -> bytes:
        """The unencrypted private key, PEM-encoded.

        RSA keys use the traditional OpenSSL format, all others PKCS#8.
        """
        return self._pkey.private_bytes(
            Encoding.PEM,
            PrivateFormat.TraditionalOpenSSL if self.key_type.startswith('rsa') else PrivateFormat.PKCS8,
            NoEncryption())

    @property
    def issuer(self) -> Optional['Credentials']:
        return self._issuer

    def set_store(self, store: 'CertStore'):
        self._store = store

    def set_files(self, cert_file: str, pkey_file: Optional[str] = None):
        """Record the files these credentials were saved to / loaded from."""
        self._cert_file = cert_file
        self._pkey_file = pkey_file

    @property
    def cert_file(self) -> Optional[str]:
        """Certificate file path, or None until ``set_files`` was called."""
        return self._cert_file

    @property
    def pkey_file(self) -> Optional[str]:
        """Private key file path, or None if unknown/not separately stored."""
        return self._pkey_file

    def get_first(self, name) -> Optional['Credentials']:
        """First credentials known to the bound store under ``name``, or None."""
        creds = self.get_credentials_for_name(name)
        return creds[0] if creds else None

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """All credentials known to the bound store under ``name`` ([] without a store)."""
        return self._store.get_credentials_for_name(name) if self._store else []

    def issue_certs(self, specs: List[CertificateSpec],
                    chain: Optional[List['Credentials']] = None) -> List['Credentials']:
        """Issue one set of credentials per spec; see ``issue_cert``."""
        return [self.issue_cert(spec=spec, chain=chain) for spec in specs]

    def issue_cert(self, spec: CertificateSpec, chain: Optional[List['Credentials']] = None) -> 'Credentials':
        """Return credentials for ``spec`` issued by ``self``.

        With a bound store, credentials already saved under the spec's name
        and key type are reused; otherwise new ones are created and, with a
        store, saved (a CA spec additionally gets its chain file, root
        included). If the spec has sub-specs, the resulting credentials get
        a store in a sub-directory named after them and issue those in turn.

        :param spec: what to issue; its ``key_type`` defaults to ours.
        :param chain: issuers above ``self``; extended by ``self`` for sub-specs.

        NOTE(review): newly created credentials only get a store when they
        have sub-specs, whereas reloaded ones are bound to the parent store
        by ``CertStore.load_credentials`` - confirm which is intended.
        """
        key_type = spec.key_type if spec.key_type else self.key_type
        creds = None
        if self._store:
            creds = self._store.load_credentials(
                name=spec.name, key_type=key_type, single_file=spec.single_file, issuer=self)
        if creds is None:
            creds = HttpdTestCA.create_credentials(spec=spec, issuer=self, key_type=key_type,
                                                   valid_from=spec.valid_from, valid_to=spec.valid_to)
            if self._store:
                self._store.save(creds, single_file=spec.single_file)
                if spec.type == "ca":
                    self._store.save_chain(creds, "ca", with_root=True)
        if spec.sub_specs:
            if self._store:
                sub_store = CertStore(fpath=os.path.join(self._store.path, creds.name))
                creds.set_store(sub_store)
            subchain = [*(chain or []), self]
            creds.issue_certs(spec.sub_specs, chain=subchain)
        return creds
class CertStore:
    """A directory of PEM certificate/key files plus an in-memory index by name.

    Files are named ``<name>[.<key_type>].cert.pem`` and
    ``<name>[.<key_type>].pkey.pem`` inside the store directory.
    """

    def __init__(self, fpath: str):
        """Use ``fpath`` as the store directory, creating it (and parents) if needed."""
        self._store_dir = fpath
        # exist_ok avoids the race of a separate exists() check followed by makedirs().
        os.makedirs(self._store_dir, exist_ok=True)
        self._creds_by_name = {}  # name -> list of credentials saved/loaded under that name

    @property
    def path(self) -> str:
        return self._store_dir

    def save(self, creds: 'Credentials', name: Optional[str] = None,
             chain: Optional[List['Credentials']] = None,
             single_file: bool = False) -> None:
        """Write ``creds`` to disk under ``name`` (default: ``creds.name``).

        The certificate file holds the certificate followed by the ``chain``
        certificates; with ``single_file`` the private key is appended to it
        and no key file is written (``creds.pkey_file`` becomes None).
        """
        name = name if name is not None else creds.name
        cert_file = self.get_cert_file(name=name, key_type=creds.key_type)
        pkey_file = None if single_file else self.get_pkey_file(name=name, key_type=creds.key_type)
        with open(cert_file, "wb") as fd:
            fd.write(creds.cert_pem)
            for c in chain or []:
                fd.write(c.cert_pem)
            if pkey_file is None:
                fd.write(creds.pkey_pem)
        if pkey_file is not None:
            with open(pkey_file, "wb") as fd:
                fd.write(creds.pkey_pem)
        creds.set_files(cert_file, pkey_file)
        self._add_credentials(name, creds)

    def save_chain(self, creds: 'Credentials', infix: str, with_root=False):
        """Write ``creds`` and all its issuers to ``<name>-<infix>.pem``.

        The topmost issuer is omitted unless ``with_root`` is set; a chain
        consisting of ``creds`` alone is always written.
        """
        name = creds.name
        chain = [creds]
        current = creds
        while current.issuer is not None:
            current = current.issuer
            chain.append(current)
        if not with_root and len(chain) > 1:
            chain = chain[:-1]
        chain_file = os.path.join(self._store_dir, f'{name}-{infix}.pem')
        with open(chain_file, "wb") as fd:
            for c in chain:
                fd.write(c.cert_pem)

    def _add_credentials(self, name: str, creds: 'Credentials'):
        self._creds_by_name.setdefault(name, []).append(creds)

    def get_credentials_for_name(self, name) -> List['Credentials']:
        """Credentials saved or loaded under ``name`` ([] if none)."""
        return self._creds_by_name.get(name, [])

    def get_cert_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.cert.pem')

    def get_pkey_file(self, name: str, key_type=None) -> str:
        key_infix = f".{key_type}" if key_type is not None else ""
        return os.path.join(self._store_dir, f'{name}{key_infix}.pkey.pem')

    def load_pem_cert(self, fpath: str) -> 'x509.Certificate':
        """Parse the first PEM certificate in ``fpath``."""
        # PEM is read as bytes; no text decoding/re-encoding round trip.
        with open(fpath, "rb") as fd:
            return x509.load_pem_x509_certificate(fd.read())

    def load_pem_pkey(self, fpath: str):
        """Parse the unencrypted PEM private key in ``fpath``."""
        with open(fpath, "rb") as fd:
            return load_pem_private_key(fd.read(), password=None)

    def load_credentials(self, name: str, key_type=None, single_file: bool = False,
                         issuer: Optional['Credentials'] = None) -> Optional['Credentials']:
        """Load credentials previously saved under ``name``/``key_type``.

        With ``single_file`` the private key is read from the certificate
        file. Loaded credentials are bound to this store and indexed.

        :returns: the credentials, or None if a required file is missing.

        NOTE(review): ``save(single_file=True)`` records ``pkey_file=None``
        while this records the certificate file; validity dates of loaded
        certificates are not checked.
        """
        cert_file = self.get_cert_file(name=name, key_type=key_type)
        pkey_file = cert_file if single_file else self.get_pkey_file(name=name, key_type=key_type)
        if not (os.path.isfile(cert_file) and os.path.isfile(pkey_file)):
            return None
        cert = self.load_pem_cert(cert_file)
        pkey = self.load_pem_pkey(pkey_file)
        creds = Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)
        creds.set_store(self)
        creds.set_files(cert_file, pkey_file)
        self._add_credentials(name, creds)
        return creds
class HttpdTestCA:
    """Factory for test PKI material: a root CA, sub-CAs, server and client
    certificates, all signed with SHA-256."""

    @classmethod
    def create_root(cls, name: str, store_dir: str, key_type: str = "rsa2048") -> Credentials:
        """Return the root CA kept in ``store_dir``, creating it if needed.

        Existing ``ca`` files for ``key_type`` are reused; otherwise a new
        self-signed CA with organization ``name`` is generated and saved as
        ``ca``. The returned credentials are bound to the store.

        NOTE(review): reloaded credentials are named "ca" while freshly
        created ones are named ``name``.
        """
        store = CertStore(fpath=store_dir)
        creds = store.load_credentials(name="ca", key_type=key_type, issuer=None)
        if creds is None:
            creds = HttpdTestCA._make_ca_credentials(name=name, key_type=key_type)
            store.save(creds, name="ca")
        creds.set_store(store)
        return creds

    @staticmethod
    def create_credentials(spec: CertificateSpec, issuer: Credentials, key_type: Any,
                           valid_from: timedelta = timedelta(days=-1),
                           valid_to: timedelta = timedelta(days=89),
                           ) -> Credentials:
        """Create new credentials for ``spec``, signed by ``issuer``.

        A server certificate is made when ``spec.domains`` is non-empty,
        else a client certificate when ``spec.client`` is set, else a CA
        certificate named ``spec.name``.

        :param key_type: key type for the new private key (see ``_private_key``).
        :param valid_from: offset from now at which the certificate becomes valid.
        :param valid_to: offset from now at which the certificate expires.
        :returns: the new in-memory credentials (nothing is written to disk).
        :raises Exception: if the spec matches none of the kinds above.
        """
        if spec.domains:
            return HttpdTestCA._make_server_credentials(name=spec.name, domains=spec.domains,
                                                        issuer=issuer, valid_from=valid_from,
                                                        valid_to=valid_to, key_type=key_type)
        if spec.client:
            return HttpdTestCA._make_client_credentials(name=spec.name, issuer=issuer,
                                                        email=spec.email, valid_from=valid_from,
                                                        valid_to=valid_to, key_type=key_type)
        if spec.name:
            return HttpdTestCA._make_ca_credentials(name=spec.name, issuer=issuer,
                                                    valid_from=valid_from, valid_to=valid_to,
                                                    key_type=key_type)
        raise Exception(f"unrecognized certificate specification: {spec}")

    @staticmethod
    def _make_x509_name(org_name: str = None, common_name: str = None, parent: x509.Name = None) -> x509.Name:
        """Build a name from one new attribute followed by ``parent``'s attributes.

        ``org_name`` becomes O (no parent) or OU (with a parent); otherwise
        ``common_name`` becomes CN. ``org_name`` wins if both are given.
        """
        name_pieces = []
        if org_name:
            oid = NameOID.ORGANIZATIONAL_UNIT_NAME if parent else NameOID.ORGANIZATION_NAME
            name_pieces.append(x509.NameAttribute(oid, org_name))
        elif common_name:
            name_pieces.append(x509.NameAttribute(NameOID.COMMON_NAME, common_name))
        if parent:
            name_pieces.extend(parent)
        return x509.Name(name_pieces)

    @staticmethod
    def _make_csr(
            subject: x509.Name,
            pkey: Any,
            issuer_subject: Optional[x509.Name],
            valid_from_delta: timedelta = None,
            valid_until_delta: timedelta = None
    ):
        """Start a certificate builder for ``pkey``'s public key.

        :param issuer_subject: issuer name; None means self-issued (``subject``).
        :param valid_from_delta: offset from now for notBefore (None: now).
        :param valid_until_delta: offset from now for notAfter (None: now).
        :returns: an unsigned ``x509.CertificateBuilder`` with a random serial
            and a SubjectKeyIdentifier.
        """
        pubkey = pkey.public_key()
        issuer_subject = issuer_subject if issuer_subject is not None else subject
        # cryptography interprets naive datetimes as UTC, so use an aware UTC
        # "now" (local time would shift validity by the host's UTC offset).
        now = datetime.now(timezone.utc)
        valid_from = now + valid_from_delta if valid_from_delta is not None else now
        valid_until = now + valid_until_delta if valid_until_delta is not None else now
        return (
            x509.CertificateBuilder()
            .subject_name(subject)
            .issuer_name(issuer_subject)
            .public_key(pubkey)
            .not_valid_before(valid_from)
            .not_valid_after(valid_until)
            .serial_number(x509.random_serial_number())
            .add_extension(
                x509.SubjectKeyIdentifier.from_public_key(pubkey),
                critical=False,
            )
        )

    @staticmethod
    def _add_ca_usages(csr: Any) -> Any:
        """Add CA extensions: basicConstraints CA (path length 9), certificate/CRL
        signing key usage, and client/server/code-signing extended key usage."""
        return csr.add_extension(
            x509.BasicConstraints(ca=True, path_length=9),
            critical=True,
        ).add_extension(
            x509.KeyUsage(
                digital_signature=True,
                content_commitment=False,
                key_encipherment=False,
                data_encipherment=False,
                key_agreement=False,
                key_cert_sign=True,
                crl_sign=True,
                encipher_only=False,
                decipher_only=False),
            critical=True
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
                ExtendedKeyUsageOID.SERVER_AUTH,
                ExtendedKeyUsageOID.CODE_SIGNING,
            ]),
            critical=True
        )

    @staticmethod
    def _add_leaf_usages(csr: Any, domains: List[str], issuer: Credentials) -> Any:
        """Add server-leaf extensions: non-CA, authority key id of ``issuer``,
        DNS subjectAltNames for ``domains``, and serverAuth usage."""
        return csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        ).add_extension(
            x509.SubjectAlternativeName([x509.DNSName(domain) for domain in domains]),
            critical=True,
        ).add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.SERVER_AUTH,
            ]),
            critical=True
        )

    @staticmethod
    def _add_client_usages(csr: Any, issuer: Credentials, rfc82name: str = None) -> Any:
        """Add client-leaf extensions: non-CA, authority key id of ``issuer``,
        an RFC 822 subjectAltName if ``rfc82name`` is given, and clientAuth usage."""
        # CertificateBuilder is immutable: add_extension() returns a new
        # builder, so every result must be re-bound (previously the SAN and
        # EKU extensions were silently dropped).
        cert = csr.add_extension(
            x509.BasicConstraints(ca=False, path_length=None),
            critical=True,
        ).add_extension(
            x509.AuthorityKeyIdentifier.from_issuer_subject_key_identifier(
                issuer.certificate.extensions.get_extension_for_class(
                    x509.SubjectKeyIdentifier).value),
            critical=False
        )
        if rfc82name:
            cert = cert.add_extension(
                x509.SubjectAlternativeName([x509.RFC822Name(rfc82name)]),
                critical=True,
            )
        cert = cert.add_extension(
            x509.ExtendedKeyUsage([
                ExtendedKeyUsageOID.CLIENT_AUTH,
            ]),
            critical=True
        )
        return cert

    @staticmethod
    def _make_ca_credentials(name, key_type: Any,
                             issuer: Credentials = None,
                             valid_from: timedelta = timedelta(days=-1),
                             valid_to: timedelta = timedelta(days=89),
                             ) -> Credentials:
        """Create CA credentials; self-signed when ``issuer`` is None."""
        pkey = _private_key(key_type=key_type)
        if issuer is not None:
            issuer_subject = issuer.certificate.subject
            issuer_key = issuer.private_key
        else:
            issuer_subject = None
            issuer_key = pkey
        subject = HttpdTestCA._make_x509_name(org_name=name, parent=issuer.subject if issuer else None)
        csr = HttpdTestCA._make_csr(subject=subject,
                                    issuer_subject=issuer_subject, pkey=pkey,
                                    valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = HttpdTestCA._add_ca_usages(csr)
        cert = csr.sign(private_key=issuer_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_server_credentials(name: str, domains: List[str], issuer: Credentials,
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create server credentials for ``domains``, signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = HttpdTestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = HttpdTestCA._make_csr(subject=subject,
                                    issuer_subject=issuer.certificate.subject, pkey=pkey,
                                    valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = HttpdTestCA._add_leaf_usages(csr, domains=domains, issuer=issuer)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)

    @staticmethod
    def _make_client_credentials(name: str,
                                 issuer: Credentials, email: Optional[str],
                                 key_type: Any,
                                 valid_from: timedelta = timedelta(days=-1),
                                 valid_to: timedelta = timedelta(days=89),
                                 ) -> Credentials:
        """Create client credentials (optionally with ``email`` SAN), signed by ``issuer``."""
        pkey = _private_key(key_type=key_type)
        subject = HttpdTestCA._make_x509_name(common_name=name, parent=issuer.subject)
        csr = HttpdTestCA._make_csr(subject=subject,
                                    issuer_subject=issuer.certificate.subject, pkey=pkey,
                                    valid_from_delta=valid_from, valid_until_delta=valid_to)
        csr = HttpdTestCA._add_client_usages(csr, issuer=issuer, rfc82name=email)
        cert = csr.sign(private_key=issuer.private_key,
                        algorithm=hashes.SHA256(),
                        backend=default_backend())
        return Credentials(name=name, cert=cert, pkey=pkey, issuer=issuer)