| // Copyright Istio Authors |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package util |
| |
| import ( |
| "crypto/ecdsa" |
| "crypto/rsa" |
| "crypto/x509" |
| "encoding/pem" |
| "fmt" |
| "reflect" |
| "sort" |
| "strings" |
| "time" |
| ) |
| |
| // VerifyFields contains the certificate fields to verify in the test. |
| type VerifyFields struct { |
| NotBefore time.Time |
| TTL time.Duration // NotAfter - NotBefore |
| ExtKeyUsage []x509.ExtKeyUsage |
| KeyUsage x509.KeyUsage |
| IsCA bool |
| Org string |
| CommonName string |
| Host string |
| } |
| |
| // VerifyCertificate verifies a given PEM encoded certificate by |
| // - building one or more chains from the certificate to a root certificate; |
| // - checking fields are set as expected. |
| func VerifyCertificate(privPem []byte, certChainPem []byte, rootCertPem []byte, expectedFields *VerifyFields) error { |
| roots := x509.NewCertPool() |
| if rootCertPem != nil { |
| if ok := roots.AppendCertsFromPEM(rootCertPem); !ok { |
| return fmt.Errorf("failed to parse root certificate") |
| } |
| } |
| |
| intermediates := x509.NewCertPool() |
| if ok := intermediates.AppendCertsFromPEM(certChainPem); !ok { |
| return fmt.Errorf("failed to parse certificate chain") |
| } |
| |
| cert, err := ParsePemEncodedCertificate(certChainPem) |
| if err != nil { |
| return err |
| } |
| |
| opts := x509.VerifyOptions{ |
| Intermediates: intermediates, |
| Roots: roots, |
| } |
| host := "" |
| if expectedFields != nil { |
| host = expectedFields.Host |
| san := host |
| // uri scheme is currently not supported in go VerifyOptions. We verify |
| // this uri at the end as a special case. |
| if strings.HasPrefix(host, "spiffe") { |
| san = "" |
| } |
| opts.DNSName = san |
| } |
| opts.KeyUsages = append(opts.KeyUsages, x509.ExtKeyUsageAny) |
| |
| if _, err = cert.Verify(opts); err != nil { |
| return fmt.Errorf("failed to verify certificate: " + err.Error()) |
| } |
| if privPem != nil { |
| priv, err := ParsePemEncodedKey(privPem) |
| if err != nil { |
| return err |
| } |
| |
| privRSAKey, privRSAOk := priv.(*rsa.PrivateKey) |
| pubRSAKey, pubRSAOk := cert.PublicKey.(*rsa.PublicKey) |
| |
| privECKey, privECOk := priv.(*ecdsa.PrivateKey) |
| pubECKey, pubECOk := cert.PublicKey.(*ecdsa.PublicKey) |
| |
| rsaMatch := privRSAOk && pubRSAOk |
| ecMatch := privECOk && pubECOk |
| |
| if rsaMatch { |
| if !reflect.DeepEqual(privRSAKey.PublicKey, *pubRSAKey) { |
| return fmt.Errorf("the generated private RSA key and cert doesn't match") |
| } |
| } else if ecMatch { |
| if !reflect.DeepEqual(privECKey.PublicKey, *pubECKey) { |
| return fmt.Errorf("the generated private EC key and cert doesn't match") |
| } |
| } else { |
| return fmt.Errorf("algorithms for private key and cert do not match") |
| } |
| } |
| if strings.HasPrefix(host, "spiffe") { |
| matchHost := false |
| ids, err := ExtractIDs(cert.Extensions) |
| if err != nil { |
| return err |
| } |
| for _, id := range ids { |
| if strings.HasSuffix(id, host) { |
| matchHost = true |
| break |
| } |
| } |
| if !matchHost { |
| return fmt.Errorf("the certificate doesn't have the expected SAN for: %s", host) |
| } |
| } |
| if expectedFields != nil { |
| if nb := expectedFields.NotBefore; !nb.IsZero() && !nb.Equal(cert.NotBefore) { |
| return fmt.Errorf("unexpected value for 'NotBefore' field: want %v but got %v", nb, cert.NotBefore) |
| } |
| |
| if ttl := expectedFields.TTL; ttl != 0 && ttl != (cert.NotAfter.Sub(cert.NotBefore)) { |
| return fmt.Errorf("unexpected value for 'NotAfter' - 'NotBefore': want %v but got %v", ttl, cert.NotAfter.Sub(cert.NotBefore)) |
| } |
| |
| if eku := sortExtKeyUsage(expectedFields.ExtKeyUsage); !reflect.DeepEqual(eku, sortExtKeyUsage(cert.ExtKeyUsage)) { |
| return fmt.Errorf("unexpected value for 'ExtKeyUsage' field: want %v but got %v", eku, cert.ExtKeyUsage) |
| } |
| |
| if ku := expectedFields.KeyUsage; ku != cert.KeyUsage { |
| return fmt.Errorf("unexpected value for 'KeyUsage' field: want %v but got %v", ku, cert.KeyUsage) |
| } |
| |
| if isCA := expectedFields.IsCA; isCA != cert.IsCA { |
| return fmt.Errorf("unexpected value for 'IsCA' field: want %t but got %t", isCA, cert.IsCA) |
| } |
| |
| if org := expectedFields.Org; org != "" && !reflect.DeepEqual([]string{org}, cert.Issuer.Organization) { |
| return fmt.Errorf("unexpected value for 'Organization' field: want %v but got %v", |
| []string{org}, cert.Issuer.Organization) |
| } |
| |
| if cn := expectedFields.CommonName; cn != cert.Subject.CommonName { |
| return fmt.Errorf("unexpected value for 'CommonName' field: want %v but got %v", |
| cn, cert.Subject.CommonName) |
| } |
| } |
| return nil |
| } |
| |
| func sortExtKeyUsage(extKeyUsage []x509.ExtKeyUsage) []int { |
| data := make([]int, len(extKeyUsage)) |
| for i := range extKeyUsage { |
| data[i] = int(extKeyUsage[i]) |
| } |
| sort.Ints(data) |
| return data |
| } |
| |
| // FindRootCertFromCertificateChainBytes find the root cert from cert chain |
| func FindRootCertFromCertificateChainBytes(certBytes []byte) ([]byte, error) { |
| var block *pem.Block |
| cert := []byte{} |
| for { |
| block, certBytes = pem.Decode(certBytes) |
| if len(certBytes) == 0 { |
| break |
| } |
| if block == nil { |
| return nil, fmt.Errorf("error decoding certificate") |
| } |
| _, err := x509.ParseCertificate(block.Bytes) |
| if err != nil { |
| return nil, fmt.Errorf("error parsing TLS certificate: %s", err.Error()) |
| } |
| cert = certBytes |
| } |
| rootBlock, _ := pem.Decode(cert) |
| if rootBlock == nil { |
| return nil, nil |
| } |
| rootCert, err := x509.ParseCertificate(rootBlock.Bytes) |
| if err != nil { |
| return nil, fmt.Errorf("error parsing root certificate: %s", err.Error()) |
| } |
| if !rootCert.IsCA { |
| return nil, fmt.Errorf("found root cert is not a ca type cert: %v", rootCert) |
| } |
| |
| return cert, nil |
| } |