blob: c834f551be6933407bf1f25b689e0b854eb48362 [file] [log] [blame]
package deliveryservice
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
import (
"bytes"
"context"
"crypto/ecdsa"
"crypto/rsa"
"crypto/x509"
"database/sql"
"encoding/base64"
"encoding/pem"
"errors"
"fmt"
"net/http"
"strconv"
"strings"
"time"
"github.com/apache/trafficcontrol/lib/go-tc"
"github.com/apache/trafficcontrol/lib/go-util"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/api"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/dbhelpers"
"github.com/apache/trafficcontrol/traffic_ops/traffic_ops_golang/tenant"
)
const (
// PemCertEndMarker is the closing delimiter line of a PEM-encoded
// certificate block. verifyCertKeyPair splits a certificate bundle after
// each occurrence of this marker to separate the individual certificates.
PemCertEndMarker = "-----END CERTIFICATE-----"
)
// AddSSLKeys adds the given ssl keys to the given delivery service.
//
// It is an HTTP handler. In order, it:
//  1. rejects the request if Traffic Vault is disabled in the configuration,
//  2. parses the request body into a tc.DeliveryServiceAddSSLKeysReq,
//  3. runs the tenant check and the "can modify CDN" check for the current user,
//  4. verifies that the certificate and private key belong together
//     (see verifyCertKeyPair),
//  5. base64-encodes the certificate material and stores it through inf.Vault,
//  6. updates the delivery service's ssl_key_version and records a change log entry.
//
// A certificate signed by an unknown authority, or one whose verified chain
// differs from the submitted chain, is still stored; the response then carries
// a warning-level alert instead of the plain success message.
func AddSSLKeys(w http.ResponseWriter, r *http.Request) {
inf, userErr, sysErr, errCode := api.NewInfo(r, nil, nil)
if userErr != nil || sysErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
return
}
defer inf.Close()
// Keys are stored in Traffic Vault, so there is nothing to do when it is disabled.
if !inf.Config.TrafficVaultEnabled {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("adding SSL keys to Traffic Vault for delivery service: Traffic Vault is not configured"))
return
}
req := tc.DeliveryServiceAddSSLKeysReq{}
// NOTE(review): the pointer fields of req are dereferenced below without nil
// checks; presumably api.Parse validates that they are present - confirm.
if err := api.Parse(r.Body, inf.Tx.Tx, &req); err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("parsing request: "+err.Error()), nil)
return
}
if userErr, sysErr, errCode := tenant.Check(inf.User, *req.DeliveryService, inf.Tx.Tx); userErr != nil || sysErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
return
}
// Resolve the delivery service by xml_id; a missing delivery service is a 404.
dsID, cdnID, ok, err := getDSIDAndCDNIDFromName(inf.Tx.Tx, *req.DeliveryService)
if err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.AddSSLKeys: getting DS ID and CDN ID from name "+err.Error()))
return
} else if !ok {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+*req.DeliveryService), nil)
return
}
userErr, sysErr, statusCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(cdnID), inf.User.UserName)
if userErr != nil || sysErr != nil {
api.HandleErr(w, r, inf.Tx.Tx, statusCode, userErr, sysErr)
return
}
// ECDSA keys are only permitted for DNS delivery services; Traffic Router
// (HTTP* delivery service types) does not support ECDSA keys.
// The type lookup is keyed by req.Key rather than req.DeliveryService.
// NOTE(review): confirm both always carry the same xml_id.
// An error from getDSType is not reported to the client: it only leaves
// allowEC false, so ECDSA certificates are rejected in that case.
dsType, dsFound, err := getDSType(inf.Tx.Tx, *req.Key)
allowEC := false
if err == nil && dsFound && dsType.IsDNS() {
allowEC = true
}
// No root CA is supplied (""), so verifyCertKeyPair leaves VerifyOptions.Roots unset.
certChain, certPrivateKey, isUnknownAuth, isVerifiedChainNotEqual, err := verifyCertKeyPair(req.Certificate.Crt, req.Certificate.Key, "", allowEC)
if err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusBadRequest, errors.New("verifying certificate: "+err.Error()), nil)
return
}
// Store the values returned by verification (the key has surrounding
// whitespace trimmed), then base64-encode CSR, Crt and Key in place.
req.Certificate.Crt = certChain
req.Certificate.Key = certPrivateKey
base64EncodeCertificate(req.Certificate)
// AuthType is optional; an absent value is stored as the empty string.
authType := ""
if req.AuthType != nil {
authType = *req.AuthType
}
dsSSLKeys := tc.DeliveryServiceSSLKeys{
CDN: *req.CDN,
DeliveryService: *req.DeliveryService,
Hostname: *req.HostName,
Key: *req.Key,
Version: *req.Version,
Certificate: *req.Certificate,
AuthType: authType,
}
if err := inf.Vault.PutDeliveryServiceSSLKeys(dsSSLKeys, inf.Tx.Tx, r.Context()); err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("putting SSL keys in Traffic Vault for delivery service '"+*req.DeliveryService+"': "+err.Error()))
return
}
// Record the new key version on the delivery service row only after the Vault write succeeded.
if err := updateSSLKeyVersion(*req.DeliveryService, req.Version.ToInt64(), inf.Tx.Tx); err != nil {
api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("adding SSL keys to delivery service '"+*req.DeliveryService+"': "+err.Error()))
return
}
api.CreateChangeLogRawTx(api.ApiChange, "DS: "+*req.DeliveryService+", ID: "+strconv.Itoa(dsID)+", ACTION: Added/Updated SSL keys", inf.User, inf.Tx.Tx)
// The keys are already stored at this point; the two flags only change the
// response from a success message to a warning alert.
if isUnknownAuth {
api.WriteRespAlert(w, r, tc.WarnLevel, "WARNING: SSL keys were successfully added for '"+*req.DeliveryService+"', but the input certificate may be invalid (certificate is signed by an unknown authority)")
return
}
if isVerifiedChainNotEqual {
api.WriteRespAlert(w, r, tc.WarnLevel, "WARNING: SSL keys were successfully added for '"+*req.DeliveryService+"', but the input certificate may be invalid (certificate verification produced a different chain)")
return
}
api.WriteResp(w, r, "Successfully added ssl keys for "+*req.DeliveryService)
}
// GetSSlKeyExpirationInformation gets expiration information for all SSL certificates.
//
// The optional integer query parameter "days" is forwarded to Traffic Vault;
// when it is absent, 0 is forwarded. The handler responds with a system error
// if Traffic Vault is disabled or if the lookup fails.
func GetSSlKeyExpirationInformation(w http.ResponseWriter, r *http.Request) {
	inf, userErr, sysErr, errCode := api.NewInfo(r, nil, []string{"days"})
	if userErr != nil || sysErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
		return
	}
	defer inf.Close()

	if !inf.Config.TrafficVaultEnabled {
		notConfigured := errors.New("getting SSL keys expiration information from Traffic Vault: Traffic Vault is not configured")
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, notConfigured)
		return
	}

	// Indexing a map with a missing key yields the zero value, so an absent
	// "days" parameter is passed on as 0.
	dayWindow := inf.IntParams["days"]

	infos, lookupErr := inf.Vault.GetExpirationInformation(inf.Tx.Tx, r.Context(), dayWindow)
	if lookupErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("getting SSL keys expiration information from Traffic Vault: "+lookupErr.Error()))
		return
	}
	api.WriteResp(w, r, infos)
}
// GetSSLKeysByXMLID fetches the deliveryservice ssl keys by the specified xmlID. V15 includes expiration date.
//
// Errors from the tenant check and from the key lookup are written as alerts.
// For API major versions below 4 only the embedded V15 structure is returned;
// otherwise the full V4 structure is returned.
func GetSSLKeysByXMLID(w http.ResponseWriter, r *http.Request) {
	inf, userErr, sysErr, errCode := api.NewInfo(r, []string{"xmlid"}, nil)
	if userErr != nil || sysErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
		return
	}
	defer inf.Close()

	if !inf.Config.TrafficVaultEnabled {
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("getting SSL keys from Traffic Vault by xml id: Traffic Vault is not configured"))
		return
	}

	alerts := tc.Alerts{}

	tenantUserErr, tenantSysErr, tenantCode := tenant.Check(inf.User, inf.Params["xmlid"], inf.Tx.Tx)
	if tenantUserErr != nil || tenantSysErr != nil {
		loggedErr := api.LogErr(r, tenantCode, tenantUserErr, tenantSysErr)
		alerts.AddNewAlert(tc.ErrorLevel, loggedErr.Error())
		api.WriteAlerts(w, r, tenantCode, alerts)
		return
	}

	keysV4, err := getSslKeys(inf, r.Context())
	if err != nil {
		loggedErr := api.LogErr(r, http.StatusInternalServerError, nil, err)
		alerts.AddNewAlert(tc.ErrorLevel, loggedErr.Error())
		api.WriteAlerts(w, r, http.StatusInternalServerError, alerts)
		return
	}

	// Older API versions only know the V15 shape.
	var payload interface{} = keysV4
	if inf.Version.Major < 4 {
		payload = keysV4.DeliveryServiceSSLKeysV15
	}

	// Attach alerts to the response only when some were collected.
	if len(alerts.Alerts) > 0 {
		api.WriteAlertsObj(w, r, http.StatusOK, alerts, payload)
		return
	}
	api.WriteResp(w, r, payload)
}
// getSslKeys loads the SSL keys for the delivery service named by the "xmlid"
// request parameter (optionally at the "version" parameter) from Traffic Vault.
//
// When no keys are stored, a zero-value result and a nil error are returned.
// The stored certificate material must always be valid base64, but the decoded
// form is only placed in the result when the "decode" parameter is neither
// empty nor "0". If the stored record has a certificate but no expiration,
// the expiration and SANs are parsed from the decoded certificate.
func getSslKeys(inf *api.APIInfo, ctx context.Context) (tc.DeliveryServiceSSLKeysV4, error) {
	xmlID := inf.Params["xmlid"]

	stored, found, err := inf.Vault.GetDeliveryServiceSSLKeys(xmlID, inf.Params["version"], inf.Tx.Tx, ctx)
	if err != nil {
		return tc.DeliveryServiceSSLKeysV4{}, errors.New("getting ssl keys: " + err.Error())
	}

	result := tc.DeliveryServiceSSLKeysV4{}
	if !found {
		return result, nil
	}
	result.DeliveryServiceSSLKeysV15 = stored

	// Decode a copy so the base64 form stays in the result unless decoding was requested.
	decoded := result.Certificate
	if err := Base64DecodeCertificate(&decoded); err != nil {
		return tc.DeliveryServiceSSLKeysV4{}, errors.New("getting SSL keys for XMLID '" + xmlID + "': " + err.Error())
	}
	// the Perl version checked the decode string as: if ( $decode )
	if flag := inf.Params["decode"]; flag != "" && flag != "0" {
		result.Certificate = decoded
	}

	// Nothing to derive when there is no certificate or the expiration is already known.
	if result.Certificate.Crt == "" || !result.Expiration.IsZero() {
		return result, nil
	}

	expiry, sans, err := ParseExpirationAndSansFromCert([]byte(decoded.Crt), result.Hostname)
	if err != nil {
		return tc.DeliveryServiceSSLKeysV4{}, errors.New(xmlID + ": " + err.Error())
	}
	result.Expiration = expiry
	result.Sans = sans
	return result, nil
}
// ParseExpirationAndSansFromCert returns the expiration and SANs from a certificate.
//
// cert must contain a PEM block; only the first block is examined. The
// expiration is the certificate's NotAfter time. The SAN list is the
// certificate's DNS names passed through util.RemoveStrFromArray together
// with commonName. On failure a zero time and an empty slice are returned
// alongside the error.
func ParseExpirationAndSansFromCert(cert []byte, commonName string) (time.Time, []string, error) {
	pemBlock, _ := pem.Decode(cert)
	if pemBlock == nil {
		return time.Time{}, []string{}, errors.New("Error decoding cert to parse expiration")
	}

	parsed, parseErr := x509.ParseCertificate(pemBlock.Bytes)
	if parseErr != nil {
		return time.Time{}, []string{}, errors.New("Error parsing cert to get expiration - " + parseErr.Error())
	}

	return parsed.NotAfter, util.RemoveStrFromArray(parsed.DNSNames, commonName), nil
}
// Base64DecodeCertificate replaces the CSR, Crt and Key fields of cert, in
// that order, with their standard-base64-decoded contents.
//
// Decoding stops at the first field that is not valid base64 and an error
// naming that field is returned; fields decoded before the failure remain
// modified.
func Base64DecodeCertificate(cert *tc.DeliveryServiceSSLKeysCertificate) error {
	fields := []struct {
		errPrefix string
		value     *string
	}{
		{"base64 decoding csr: ", &cert.CSR},
		{"base64 decoding crt: ", &cert.Crt},
		{"base64 decoding key: ", &cert.Key},
	}
	for _, field := range fields {
		decoded, err := base64.StdEncoding.DecodeString(*field.value)
		if err != nil {
			return errors.New(field.errPrefix + err.Error())
		}
		*field.value = string(decoded)
	}
	return nil
}
// base64EncodeCertificate replaces the CSR, Crt and Key fields of cert, in
// place, with their standard base64 encodings.
func base64EncodeCertificate(cert *tc.DeliveryServiceSSLKeysCertificate) {
	for _, field := range []*string{&cert.CSR, &cert.Crt, &cert.Key} {
		*field = base64.StdEncoding.EncodeToString([]byte(*field))
	}
}
// DeleteSSLKeys deletes a Delivery Service's sslkeys via a DELETE method
//
// The delivery service is identified by the "xmlid" parameter; the optional
// "version" parameter is passed through to Traffic Vault. The checks run in
// this order: Traffic Vault enabled, delivery service exists, current user may
// modify the CDN, tenant check. A change log entry is recorded after a
// successful delete.
func DeleteSSLKeys(w http.ResponseWriter, r *http.Request) {
	inf, userErr, sysErr, errCode := api.NewInfo(r, []string{"xmlid"}, nil)
	if userErr != nil || sysErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, errCode, userErr, sysErr)
		return
	}
	defer inf.Close()

	if !inf.Config.TrafficVaultEnabled {
		// No user-facing error applies here; only the system error is reported.
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.DeleteSSLKeys: Traffic Vault is not configured"))
		return
	}

	xmlID := inf.Params["xmlid"]

	dsID, cdnID, found, err := getDSIDAndCDNIDFromName(inf.Tx.Tx, xmlID)
	switch {
	case err != nil:
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.DeleteSSLKeys: getting DS ID and CDN ID from name "+err.Error()))
		return
	case !found:
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusNotFound, errors.New("no DS with name "+xmlID), nil)
		return
	}

	lockUserErr, lockSysErr, lockCode := dbhelpers.CheckIfCurrentUserCanModifyCDNWithID(inf.Tx.Tx, int64(cdnID), inf.User.UserName)
	if lockUserErr != nil || lockSysErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, lockCode, lockUserErr, lockSysErr)
		return
	}

	tenantUserErr, tenantSysErr, tenantCode := tenant.Check(inf.User, xmlID, inf.Tx.Tx)
	if tenantUserErr != nil || tenantSysErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, tenantCode, tenantUserErr, tenantSysErr)
		return
	}

	deleteErr := inf.Vault.DeleteDeliveryServiceSSLKeys(xmlID, inf.Params["version"], inf.Tx.Tx, r.Context())
	if deleteErr != nil {
		api.HandleErr(w, r, inf.Tx.Tx, http.StatusInternalServerError, nil, errors.New("deliveryservice.DeleteSSLKeys: deleting SSL keys: "+deleteErr.Error()))
		return
	}

	api.CreateChangeLogRawTx(api.ApiChange, "DS: "+xmlID+", ID: "+strconv.Itoa(dsID)+", ACTION: Deleted SSL keys", inf.User, inf.Tx.Tx)
	api.WriteResp(w, r, "Successfully deleted ssl keys for "+xmlID)
}
// updateSSLKeyVersion sets ssl_key_version to version on the deliveryservice
// row whose xml_id equals xmlID, using the given transaction. Any error from
// executing the statement is returned with a descriptive prefix.
func updateSSLKeyVersion(xmlID string, version int64, tx *sql.Tx) error {
	const updateQuery = `UPDATE deliveryservice SET ssl_key_version = $1 WHERE xml_id = $2`

	_, execErr := tx.Exec(updateQuery, version, xmlID)
	if execErr == nil {
		return nil
	}
	return errors.New("updating delivery service ssl_key_version: " + execErr.Error())
}
// getCDNIDByDomainname looks up the id of the cdn row whose domain_name equals
// domainName. It returns the id and true when a row is found, 0 and false with
// a nil error when no row matches, and 0, false and the error for any other
// failure.
func getCDNIDByDomainname(domainName string, tx *sql.Tx) (int64, bool, error) {
	var id int64
	scanErr := tx.QueryRow(`SELECT id from cdn WHERE domain_name = $1`, domainName).Scan(&id)
	switch scanErr {
	case nil:
		return id, true, nil
	case sql.ErrNoRows:
		// A missing CDN is reported through the boolean, not as an error.
		return 0, false, nil
	default:
		return 0, false, scanErr
	}
}
// getDSIDAndCDNIDFromName loads the DeliveryService's ID and CDN ID from the
// database, from the xml_id. It returns the delivery service ID, the CDN ID,
// whether the delivery service was found, and any error. A missing delivery
// service yields false with a nil error.
func getDSIDAndCDNIDFromName(tx *sql.Tx, xmlID string) (int, int, bool, error) {
	var dsID, cdnID int
	row := tx.QueryRow(`SELECT id, cdn_id FROM deliveryservice WHERE xml_id = $1`, xmlID)
	switch err := row.Scan(&dsID, &cdnID); err {
	case nil:
		return dsID, cdnID, true, nil
	case sql.ErrNoRows:
		return dsID, cdnID, false, nil
	default:
		return dsID, cdnID, false, fmt.Errorf("querying ID for delivery service ID '%v': %v", xmlID, err)
	}
}
// verifyCertKeyPair checks that the first certificate in pemCertificate (the
// server certificate) is acceptable, that pemPrivateKey is the private key
// belonging to it, and that the certificate verifies against the chain that
// accompanies it.
//
// Parameters:
//   - pemCertificate: one or more concatenated PEM certificates, server
//     certificate first.
//   - pemPrivateKey: the PEM-encoded private key for the server certificate.
//   - rootCA: optional PEM root certificates; when empty, VerifyOptions.Roots
//     is left unset.
//   - allowEC: whether ECDSA certificates are accepted.
//
// Returns, in order:
//   - the certificate PEM. On every successful path this is the input
//     pemCertificate unchanged; the re-encoded verified chain is only used for
//     comparison and is not returned.
//   - the private key PEM with surrounding whitespace trimmed.
//   - true if verification failed with x509.UnknownAuthorityError (e.g. a
//     self-signed certificate); this is not treated as an error.
//   - true if the PEM encoding of the first verified chain differs from
//     pemCertificate.
//   - an error for any other failure; the other results are then zero values.
func verifyCertKeyPair(pemCertificate string, pemPrivateKey string, rootCA string, allowEC bool) (string, string, bool, bool, error) {
// decode, verify, and order certs for storage
cleanPemPrivateKey := ""
// SplitAfter keeps the end marker on each certificate; the final element is
// whatever follows the last marker. A length of 1 therefore means the input
// contains no end marker at all.
certs := strings.SplitAfter(pemCertificate, PemCertEndMarker)
if len(certs) <= 1 {
return "", "", false, false, errors.New("no certificate chain to verify")
}
// decode and verify the server certificate
block, _ := pem.Decode([]byte(certs[0]))
if block == nil {
return "", "", false, false, errors.New("could not decode pem-encoded server certificate")
}
cert, err := x509.ParseCertificate(block.Bytes)
if err != nil {
return "", "", false, false, errors.New("could not parse the server certificate: " + err.Error())
}
// Common x509 certificate validation
err = commonX509CertificateValidation(cert)
if err != nil {
return "", "", false, false, err
}
// Match the private key against the certificate according to the
// certificate's public key algorithm.
switch cert.PublicKeyAlgorithm {
case x509.RSA:
var rsaPrivateKey *rsa.PrivateKey
// RSA is both a digital signature and encryption algorithm, hence the key encipherment
// usage must be indicated in the certificate.
// The keyUsage and extended key usage do not exist in version 1 of the x509 specification.
if cert.Version > 1 && !(cert.KeyUsage&x509.KeyUsageKeyEncipherment > 0) {
return "", "", false, false, errors.New("cert/key (rsa) validation: no keyEncipherment keyUsage extension present in x509v3 server certificate")
}
// Extract the RSA public key from the x509 certificate
certPublicKey, ok := cert.PublicKey.(*rsa.PublicKey)
if !ok || certPublicKey == nil {
return "", "", false, false, errors.New("cert/key (rsa) validation error: could not extract public RSA key from certificate")
}
// Attempt to decode the RSA private key
rsaPrivateKey, cleanPemPrivateKey, err = decodeRSAPrivateKey(pemPrivateKey)
if err != nil {
return "", "", false, false, err
}
// Check RSA private key modulus against the x509 RSA public key modulus.
// Only the modulus N is compared.
if rsaPrivateKey != nil && certPublicKey != nil && !bytes.Equal(rsaPrivateKey.N.Bytes(), certPublicKey.N.Bytes()) {
return "", "", false, false, errors.New("cert/key (rsa) mismatch error: RSA public N modulus value mismatch")
}
case x509.ECDSA:
var ecdsaPrivateKey *ecdsa.PrivateKey
// Only permit ECDSA support for DNS* DSTypes until the Traffic Router can support it
if !allowEC {
return "", "", false, false, errors.New("cert/key validation error: ECDSA public key algorithm unsupported for non-DNS delivery service type")
}
// DSA and ECDSA are signing algorithms only, not encryption algorithms, hence the
// certificate only needs to have the DigitalSignature KeyUsage indicated.
if cert.Version > 1 && !(cert.KeyUsage&x509.KeyUsageDigitalSignature > 0) {
return "", "", false, false, errors.New("cert/key (ecdsa) validation error: no digitalSignature keyUsage extension present in x509v3 server certificate")
}
// Attempt to decode the ECDSA private key
ecdsaPrivateKey, cleanPemPrivateKey, err = decodeECDSAPrivateKey(pemPrivateKey)
if err != nil {
return "", "", false, false, err
}
// Extract the ECDSA public key from the x509 certificate
certPublicKey, ok := cert.PublicKey.(*ecdsa.PublicKey)
if !ok || certPublicKey == nil {
return "", "", false, false, errors.New("cert/key (ecdsa) validation error: could not get extract public ECDSA key from certificate")
}
// Compare the ECDSA curve name contained within the x509.PublicKey against the curve name indicated in the private key
if certPublicKey.Params().Name != ecdsaPrivateKey.Params().Name {
return "", "", false, false, errors.New("cert/key (ecdsa) mismatch error: ECDSA curve name in cert does not match curve name in private key")
}
// Verify that ECDSA public value X matches in both the cert.PublicKey and the private key.
if !bytes.Equal(certPublicKey.X.Bytes(), ecdsaPrivateKey.X.Bytes()) {
return "", "", false, false, errors.New("cert/key (ecdsa) mismatch error: ECDSA public X value mismatch")
}
// Verify that ECDSA public value Y matches in both the cert.PublicKey and the private key.
if !bytes.Equal(certPublicKey.Y.Bytes(), ecdsaPrivateKey.Y.Bytes()) {
return "", "", false, false, errors.New("cert/key (ecdsa) mismatch error: ECDSA public Y value mismatch")
}
case x509.DSA:
return "", "", false, false, errors.New("cert/key validation error: DSA public key algorithm unsupported")
case x509.UnknownPublicKeyAlgorithm:
fallthrough
default:
return "", "", false, false, errors.New("cert/key validation error: Unknown public key algorithm")
}
// Rebuild the bundle from every split element except the last one (the text
// after the final end marker). The server certificate itself is included,
// so the intermediate pool holds all supplied certificates.
bundle := ""
for i := 0; i < len(certs)-1; i++ {
bundle += certs[i]
}
intermediatePool := x509.NewCertPool()
if !intermediatePool.AppendCertsFromPEM([]byte(bundle)) {
return "", "", false, false, errors.New("certificate CA bundle is empty")
}
opts := x509.VerifyOptions{
Intermediates: intermediatePool,
}
if rootCA != "" {
// verify the certificate chain against the supplied roots.
rootPool := x509.NewCertPool()
if !rootPool.AppendCertsFromPEM([]byte(rootCA)) {
return "", "", false, false, errors.New("unable to parse root CA certificate")
}
opts.Roots = rootPool
}
chain, err := cert.Verify(opts)
if err != nil {
// An unknown authority is tolerated and reported through the third
// return value; the input certificate is returned as given.
if _, ok := err.(x509.UnknownAuthorityError); ok {
return pemCertificate, cleanPemPrivateKey, true, false, nil
}
return "", "", false, false, errors.New("could not verify the certificate chain: " + err.Error())
}
if len(chain) < 1 {
return "", "", false, false, errors.New("can't find valid chain for cert in file in request")
}
// Re-encode the first verified chain as PEM so it can be compared with the input.
pemEncodedChain := ""
for _, link := range chain[0] {
// Include all certificates in the chain, since verification was successful.
block := &pem.Block{Type: "CERTIFICATE", Bytes: link.Raw}
pemEncodedChain += string(pem.EncodeToMemory(block))
}
if len(pemEncodedChain) < 1 {
return "", "", false, false, errors.New("invalid empty certificate chain in request")
}
// The comparison is an exact string match, so any difference in content,
// order or formatting sets the "chain not equal" flag.
if pemEncodedChain != pemCertificate {
return pemCertificate, cleanPemPrivateKey, false, true, nil
}
return pemCertificate, cleanPemPrivateKey, false, false, nil
}
// commonX509CertificateValidation runs the checks shared by all supported
// public key algorithms and returns the first failure, or nil.
//
// Checks, in order:
//   - a certificate with Version greater than 1 must list the serverAuth
//     extended key usage (a certificate with no extended key usages at all is
//     rejected as well);
//   - a public key must be present and its algorithm must be known;
//   - a signature must be present and its algorithm must be known.
func commonX509CertificateValidation(cert *x509.Certificate) error {
	// Extended key usage only exists from x509 version 2 onwards in this check.
	if cert.Version > 1 {
		hasServerAuth := false
		for i := 0; i < len(cert.ExtKeyUsage) && !hasServerAuth; i++ {
			hasServerAuth = cert.ExtKeyUsage[i] == x509.ExtKeyUsageServerAuth
		}
		if !hasServerAuth {
			return errors.New("certificate (x509v3) validation error: server certificate missing 'serverAuth' extended key usage")
		}
	}

	switch {
	case cert.PublicKey == nil:
		return errors.New("certificate validation error: no PKI public key found")
	case cert.PublicKeyAlgorithm == x509.UnknownPublicKeyAlgorithm:
		return errors.New("certificate validation error: unknown PKI algorithm")
	case len(cert.Signature) == 0:
		return errors.New("certificate validation error: no signature found")
	case cert.SignatureAlgorithm == x509.UnknownSignatureAlgorithm:
		return errors.New("certificate validation error: unknown signature algorithm")
	}
	return nil
}
// commonPrivateKeyValidation holds the private key checks shared by all
// supported key algorithms. It rejects a nil block (the PEM could not be
// decoded) and any block that appears to be encrypted, which is detected by
// the text "ENCRYPTED" in the block type or in any header value. It returns
// nil when the block passes.
func commonPrivateKeyValidation(block *pem.Block) error {
	const encryptedMarker = "ENCRYPTED"

	switch {
	case block == nil:
		return errors.New("private key validation error: could not decode pem-encoded private key")
	case strings.Contains(block.Type, encryptedMarker):
		// e.g. a block type that names the key as encrypted
		return errors.New("private key validation error: encrypted private key not supported - block type: " + block.Type)
	}

	// Encryption may also be declared in the block headers.
	for _, headerValue := range block.Headers {
		if !strings.Contains(headerValue, encryptedMarker) {
			continue
		}
		return errors.New("private key validation error: encrypted private key not supported - header: " + headerValue)
	}
	return nil
}
// decodeRSAPrivateKey decodes a PEM-encoded RSA private key.
//
// The input is trimmed of surrounding whitespace and must contain exactly one
// PEM block (counted by its END marker). The block is rejected if it cannot be
// decoded or looks encrypted (see commonPrivateKeyValidation). The key bytes
// are parsed first as PKCS#8 and, if that does not yield an RSA key, as PKCS#1.
//
// It returns the private key object and the trimmed private key PEM, or an
// error. When both parse attempts fail, the individual failures are joined
// into a single error.
func decodeRSAPrivateKey(pemPrivateKey string) (*rsa.PrivateKey, string, error) {
	// Remove any white space before decoding
	trimmedPrivateKey := strings.TrimSpace(pemPrivateKey)

	// Check for proper key count before attempting to decode.
	blockCount := strings.Count(trimmedPrivateKey, "\n-----END")
	if blockCount < 1 {
		return nil, "", errors.New("private key validation error: no RSA private key PEM blocks found")
	}
	if blockCount > 1 {
		return nil, "", errors.New("private key validation error: multiple private key PEM blocks found")
	}

	// Attempt to decode pem encoded text into PEM block, then check that the
	// key was decoded and is not encrypted.
	block, _ := pem.Decode([]byte(trimmedPrivateKey))
	if err := commonPrivateKeyValidation(block); err != nil {
		return nil, "", err
	}

	// Capture all key decode errors and collapse them at the end
	decodeErrors := make([]error, 0)

	// Decode PKCS#8 - RSA Private Key
	privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
	if err != nil {
		decodeErrors = append(decodeErrors, errors.New("private key validation error: parse pkcs#8 error: "+err.Error()))
	}
	// Determine if the privateKey is of the correct type
	if pkcs8Key, ok := privateKey.(*rsa.PrivateKey); ok && pkcs8Key != nil {
		return pkcs8Key, trimmedPrivateKey, nil
	}
	decodeErrors = append(decodeErrors, fmt.Errorf("private key validation error: incorrect private key type: %T", privateKey))

	// Decode PKCS#1 - RSA Private Key
	rsaPrivateKey, err := x509.ParsePKCS1PrivateKey(block.Bytes)
	if err != nil {
		decodeErrors = append(decodeErrors, errors.New("private key validation error: parse pkcs#1 error: "+err.Error()))
		return nil, "", util.JoinErrsSep(decodeErrors, ", ")
	}
	if rsaPrivateKey == nil {
		// Defensive: a nil key without an error must not be dereferenced by
		// callers, and there is no parser error to report in this case.
		decodeErrors = append(decodeErrors, errors.New("private key validation error: parse pkcs#1 error: no private key returned"))
		return nil, "", util.JoinErrsSep(decodeErrors, ", ")
	}
	return rsaPrivateKey, trimmedPrivateKey, nil
}
// decodeECDSAPrivateKey decodes a PEM-encoded ECDSA private key.
//
// The input is trimmed of surrounding whitespace and must contain one or two
// PEM blocks (counted by their END markers); two blocks occur when an
// 'EC PARAM' block accompanies the key. Each decoded block is rejected if it
// looks encrypted (see commonPrivateKeyValidation). Blocks whose type does not
// contain "KEY" are skipped. The first key block is parsed as an EC private
// key and, if that fails, as PKCS#8; the outcome of that block is final.
//
// It returns the private key object and the trimmed private key PEM, or an
// error. When both parse attempts fail, the individual failures are joined
// into a single error.
func decodeECDSAPrivateKey(pemPrivateKey string) (*ecdsa.PrivateKey, string, error) {
	// Remove any white space before decoding
	trimmedPrivateKey := strings.TrimSpace(pemPrivateKey)
	// Capture all key decode errors and collapse them at the end
	decodeErrors := make([]error, 0)

	// Check for proper key count before attempting to decode.
	// ECDSA keys can have 1 or 2 PEM blocks if the 'EC PARAM' block is included.
	blockCount := strings.Count(trimmedPrivateKey, "\n-----END")
	if blockCount < 1 {
		return nil, "", errors.New("private key validation error: no EC private key PEM blocks found")
	}
	if blockCount > 2 {
		return nil, "", errors.New("private key validation error: too many EC related PEM blocks found")
	}

	// Attempt to decode pem encoded text into PEM blocks, one at a time.
	pemData := []byte(trimmedPrivateKey)
	for len(pemData) > 0 {
		// Stop once no END marker remains in the unread data.
		if !strings.Contains(string(pemData), "\n-----END") {
			break
		}

		// Attempt to decode the next PEM block
		var block *pem.Block
		block, pemData = pem.Decode(pemData)
		if block == nil {
			return nil, "", errors.New("private key validation error: could not decode pem-encoded block")
		}

		// Check that the key was decoded and validate key isn't encrypted and
		// other common validation shared between PKI algorithms
		if err := commonPrivateKeyValidation(block); err != nil {
			return nil, "", err
		}

		// Only blocks with 'KEY' in their type are candidates for decoding.
		if !strings.Contains(block.Type, "KEY") {
			continue
		}

		// First try to parse an EC key the normal way, before attempting PKCS8
		ecKey, err := x509.ParseECPrivateKey(block.Bytes)
		if err == nil && ecKey != nil {
			return ecKey, trimmedPrivateKey, nil
		}
		if err != nil {
			decodeErrors = append(decodeErrors, errors.New("private key validation error: failed to parse EC ANSI X9.62: "+err.Error()))
		} else {
			// Defensive: a nil key without an error has no parser message to report.
			decodeErrors = append(decodeErrors, errors.New("private key validation error: failed to parse EC ANSI X9.62: no private key returned"))
		}

		// Second, try to parse PEM block as a PKCS#8 formatted EC private key.
		privateKey, err := x509.ParsePKCS8PrivateKey(block.Bytes)
		if err != nil {
			decodeErrors = append(decodeErrors, errors.New("private key validation error: parse pkcs#8 error: "+err.Error()))
			return nil, "", util.JoinErrsSep(decodeErrors, ", ")
		}

		// Make sure the privateKey is of the correct type (ecdsa.PrivateKey)
		pkcs8Key, ok := privateKey.(*ecdsa.PrivateKey)
		if !ok || pkcs8Key == nil {
			decodeErrors = append(decodeErrors, fmt.Errorf("private key validation error: incorrect private key type: %T", privateKey))
			return nil, "", util.JoinErrsSep(decodeErrors, ", ")
		}
		return pkcs8Key, trimmedPrivateKey, nil
	}
	return nil, "", errors.New("private key validation error: no ECDSA private keys found")
}