diff options
Diffstat (limited to 'src/ssl/test/runner')
-rw-r--r-- | src/ssl/test/runner/alert.go | 77 | ||||
-rw-r--r-- | src/ssl/test/runner/cert.pem | 15 | ||||
-rw-r--r-- | src/ssl/test/runner/channel_id_key.pem | 5 | ||||
-rw-r--r-- | src/ssl/test/runner/cipher_suites.go | 395 | ||||
-rw-r--r-- | src/ssl/test/runner/common.go | 953 | ||||
-rw-r--r-- | src/ssl/test/runner/conn.go | 1229 | ||||
-rw-r--r-- | src/ssl/test/runner/dtls.go | 342 | ||||
-rw-r--r-- | src/ssl/test/runner/ecdsa_cert.pem | 12 | ||||
-rw-r--r-- | src/ssl/test/runner/ecdsa_key.pem | 8 | ||||
-rw-r--r-- | src/ssl/test/runner/handshake_client.go | 910 | ||||
-rw-r--r-- | src/ssl/test/runner/handshake_messages.go | 1875 | ||||
-rw-r--r-- | src/ssl/test/runner/handshake_server.go | 964 | ||||
-rw-r--r-- | src/ssl/test/runner/key.pem | 15 | ||||
-rw-r--r-- | src/ssl/test/runner/key_agreement.go | 776 | ||||
-rw-r--r-- | src/ssl/test/runner/packet_adapter.go | 101 | ||||
-rw-r--r-- | src/ssl/test/runner/prf.go | 388 | ||||
-rw-r--r-- | src/ssl/test/runner/recordingconn.go | 130 | ||||
-rw-r--r-- | src/ssl/test/runner/runner.go | 2649 | ||||
-rw-r--r-- | src/ssl/test/runner/ticket.go | 221 | ||||
-rw-r--r-- | src/ssl/test/runner/tls.go | 279 |
20 files changed, 11344 insertions, 0 deletions
diff --git a/src/ssl/test/runner/alert.go b/src/ssl/test/runner/alert.go new file mode 100644 index 0000000..b48ab2a --- /dev/null +++ b/src/ssl/test/runner/alert.go @@ -0,0 +1,77 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "strconv" + +type alert uint8 + +const ( + // alert level + alertLevelWarning = 1 + alertLevelError = 2 +) + +const ( + alertCloseNotify alert = 0 + alertUnexpectedMessage alert = 10 + alertBadRecordMAC alert = 20 + alertDecryptionFailed alert = 21 + alertRecordOverflow alert = 22 + alertDecompressionFailure alert = 30 + alertHandshakeFailure alert = 40 + alertBadCertificate alert = 42 + alertUnsupportedCertificate alert = 43 + alertCertificateRevoked alert = 44 + alertCertificateExpired alert = 45 + alertCertificateUnknown alert = 46 + alertIllegalParameter alert = 47 + alertUnknownCA alert = 48 + alertAccessDenied alert = 49 + alertDecodeError alert = 50 + alertDecryptError alert = 51 + alertProtocolVersion alert = 70 + alertInsufficientSecurity alert = 71 + alertInternalError alert = 80 + alertUserCanceled alert = 90 + alertNoRenegotiation alert = 100 +) + +var alertText = map[alert]string{ + alertCloseNotify: "close notify", + alertUnexpectedMessage: "unexpected message", + alertBadRecordMAC: "bad record MAC", + alertDecryptionFailed: "decryption failed", + alertRecordOverflow: "record overflow", + alertDecompressionFailure: "decompression failure", + alertHandshakeFailure: "handshake failure", + alertBadCertificate: "bad certificate", + alertUnsupportedCertificate: "unsupported certificate", + alertCertificateRevoked: "revoked certificate", + alertCertificateExpired: "expired certificate", + alertCertificateUnknown: "unknown certificate", + alertIllegalParameter: "illegal parameter", + alertUnknownCA: "unknown certificate authority", + alertAccessDenied: "access denied", + alertDecodeError: "error decoding message", + alertDecryptError: "error decrypting message", + alertProtocolVersion: "protocol version not supported", + alertInsufficientSecurity: "insufficient security level", + alertInternalError: "internal error", + alertUserCanceled: "user canceled", + alertNoRenegotiation: "no renegotiation", +} + +func (e alert) String() string { + s, ok := alertText[e] + if ok { + return s + } + return "alert(" + strconv.Itoa(int(e)) + ")" +} + +func (e alert) Error() string { + return e.String() +} diff --git a/src/ssl/test/runner/cert.pem b/src/ssl/test/runner/cert.pem new file mode 100644 index 0000000..4de4f49 --- /dev/null +++ b/src/ssl/test/runner/cert.pem @@ -0,0 +1,15 @@ +-----BEGIN CERTIFICATE----- +MIICWDCCAcGgAwIBAgIJAPuwTC6rEJsMMA0GCSqGSIb3DQEBBQUAMEUxCzAJBgNV +BAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5ldCBX +aWRnaXRzIFB0eSBMdGQwHhcNMTQwNDIzMjA1MDQwWhcNMTcwNDIyMjA1MDQwWjBF +MQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwYSW50 +ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMIGfMA0GCSqGSIb3DQEBAQUAA4GNADCBiQKB +gQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92kWdGMdAQhLci +HnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiFKKAnHmUcrgfV +W28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQABo1AwTjAdBgNV +HQ4EFgQUi3XVrMsIvg4fZbf6Vr5sp3Xaha8wHwYDVR0jBBgwFoAUi3XVrMsIvg4f +Zbf6Vr5sp3Xaha8wDAYDVR0TBAUwAwEB/zANBgkqhkiG9w0BAQUFAAOBgQA76Hht +ldY9avcTGSwbwoiuIqv0jTL1fHFnzy3RHMLDh+Lpvolc5DSrSJHCP5WuK0eeJXhr +T5oQpHL9z/cCDLAKCKRa4uV0fhEdOWBqyR9p8y5jJtye72t6CuFUV5iqcpF4BH4f +j2VNHwsSrJwkD4QUGlUtH7vwnQmyCFxZMmWAJg== +-----END CERTIFICATE----- diff --git a/src/ssl/test/runner/channel_id_key.pem b/src/ssl/test/runner/channel_id_key.pem new file mode 100644 index 0000000..604752b --- /dev/null +++ b/src/ssl/test/runner/channel_id_key.pem @@ -0,0 +1,5 @@ +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIPwxu50c7LEhVNRYJFRWBUnoaz7JSos96T5hBp4rjyptoAoGCCqGSM49 +AwEHoUQDQgAEzFSVTE5guxJRQ0VbZ8dicPs5e/DT7xpW7Yc9hq0VOchv7cbXuI/T +CwadDjGWX/oaz0ftFqrVmfkwZu+C58ioWg== +-----END EC PRIVATE KEY----- diff --git a/src/ssl/test/runner/cipher_suites.go b/src/ssl/test/runner/cipher_suites.go new file mode 100644 index 0000000..89e75c8 --- /dev/null +++ b/src/ssl/test/runner/cipher_suites.go @@ -0,0 +1,395 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "crypto/aes" + "crypto/cipher" + "crypto/des" + "crypto/hmac" + "crypto/md5" + "crypto/rc4" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "crypto/x509" + "hash" +) + +// a keyAgreement implements the client and server side of a TLS key agreement +// protocol by generating and processing key exchange messages. +type keyAgreement interface { + // On the server side, the first two methods are called in order. + + // In the case that the key agreement protocol doesn't use a + // ServerKeyExchange message, generateServerKeyExchange can return nil, + // nil. + generateServerKeyExchange(*Config, *Certificate, *clientHelloMsg, *serverHelloMsg) (*serverKeyExchangeMsg, error) + processClientKeyExchange(*Config, *Certificate, *clientKeyExchangeMsg, uint16) ([]byte, error) + + // On the client side, the next two methods are called in order. + + // This method may not be called if the server doesn't send a + // ServerKeyExchange message. + processServerKeyExchange(*Config, *clientHelloMsg, *serverHelloMsg, *x509.Certificate, *serverKeyExchangeMsg) error + generateClientKeyExchange(*Config, *clientHelloMsg, *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) +} + +const ( + // suiteECDH indicates that the cipher suite involves elliptic curve + // Diffie-Hellman. This means that it should only be selected when the + // client indicates that it supports ECC with a curve and point format + // that we're happy with. + suiteECDHE = 1 << iota + // suiteECDSA indicates that the cipher suite involves an ECDSA + // signature and therefore may only be selected when the server's + // certificate is ECDSA. If this is not set then the cipher suite is + // RSA based. + suiteECDSA + // suiteTLS12 indicates that the cipher suite should only be advertised + // and accepted when using TLS 1.2. + suiteTLS12 + // suiteSHA384 indicates that the cipher suite uses SHA384 as the + // handshake hash. + suiteSHA384 + // suiteNoDTLS indicates that the cipher suite cannot be used + // in DTLS. + suiteNoDTLS + // suitePSK indicates that the cipher suite authenticates with + // a pre-shared key rather than a server private key. + suitePSK +) + +// A cipherSuite is a specific combination of key agreement, cipher and MAC +// function. All cipher suites currently assume RSA key agreement. +type cipherSuite struct { + id uint16 + // the lengths, in bytes, of the key material needed for each component. + keyLen int + macLen int + ivLen int + ka func(version uint16) keyAgreement + // flags is a bitmask of the suite* values, above. + flags int + cipher func(key, iv []byte, isRead bool) interface{} + mac func(version uint16, macKey []byte) macFunction + aead func(key, fixedNonce []byte) cipher.AEAD +} + +var cipherSuites = []*cipherSuite{ + // Ciphersuite order is chosen so that ECDHE comes before plain RSA + // and RC4 comes before AES (because of the Lucky13 attack). + {TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_ECDHE_RSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheRSAKA, suiteECDHE | suiteNoDTLS, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_RC4_128_SHA, 16, 20, 0, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteNoDTLS, cipherRC4, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheRSAKA, suiteECDHE | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA, 16, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384, 32, 48, 16, ecdheRSAKA, suiteECDHE | suiteTLS12 | suiteSHA384, cipherAES, macSHA384, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384, 32, 48, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA | suiteTLS12 | suiteSHA384, cipherAES, macSHA384, nil}, + {TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheRSAKA, suiteECDHE, cipherAES, macSHA1, nil}, + {TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA, 32, 20, 16, ecdheECDSAKA, suiteECDHE | suiteECDSA, cipherAES, macSHA1, nil}, + {TLS_DHE_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, dheRSAKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_DHE_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, dheRSAKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_DHE_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, dheRSAKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_DHE_RSA_WITH_AES_256_CBC_SHA256, 32, 32, 16, dheRSAKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_DHE_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, dheRSAKA, 0, cipherAES, macSHA1, nil}, + {TLS_DHE_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, dheRSAKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_128_GCM_SHA256, 16, 0, 4, rsaKA, suiteTLS12, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_AES_256_GCM_SHA384, 32, 0, 4, rsaKA, suiteTLS12 | suiteSHA384, nil, nil, aeadAESGCM}, + {TLS_RSA_WITH_RC4_128_SHA, 16, 20, 0, rsaKA, suiteNoDTLS, cipherRC4, macSHA1, nil}, + {TLS_RSA_WITH_RC4_128_MD5, 16, 16, 0, rsaKA, suiteNoDTLS, cipherRC4, macMD5, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA256, 16, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA256, 32, 32, 16, rsaKA, suiteTLS12, cipherAES, macSHA256, nil}, + {TLS_RSA_WITH_AES_128_CBC_SHA, 16, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_RSA_WITH_AES_256_CBC_SHA, 32, 20, 16, rsaKA, 0, cipherAES, macSHA1, nil}, + {TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, ecdheRSAKA, suiteECDHE, cipher3DES, macSHA1, nil}, + {TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, dheRSAKA, 0, cipher3DES, macSHA1, nil}, + {TLS_RSA_WITH_3DES_EDE_CBC_SHA, 24, 20, 8, rsaKA, 0, cipher3DES, macSHA1, nil}, + {TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256, 16, 0, 4, ecdhePSKKA, suiteECDHE | suiteTLS12 | suitePSK, nil, nil, aeadAESGCM}, + {TLS_PSK_WITH_RC4_128_SHA, 16, 20, 0, pskKA, suiteNoDTLS | suitePSK, cipherRC4, macSHA1, nil}, + {TLS_PSK_WITH_AES_128_CBC_SHA, 16, 20, 16, pskKA, suitePSK, cipherAES, macSHA1, nil}, + {TLS_PSK_WITH_AES_256_CBC_SHA, 32, 20, 16, pskKA, suitePSK, cipherAES, macSHA1, nil}, +} + +func cipherRC4(key, iv []byte, isRead bool) interface{} { + cipher, _ := rc4.NewCipher(key) + return cipher +} + +func cipher3DES(key, iv []byte, isRead bool) interface{} { + block, _ := des.NewTripleDESCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +func cipherAES(key, iv []byte, isRead bool) interface{} { + block, _ := aes.NewCipher(key) + if isRead { + return cipher.NewCBCDecrypter(block, iv) + } + return cipher.NewCBCEncrypter(block, iv) +} + +// macSHA1 returns a macFunction for the given protocol version. +func macSHA1(version uint16, key []byte) macFunction { + if version == VersionSSL30 { + mac := ssl30MAC{ + h: sha1.New(), + key: make([]byte, len(key)), + } + copy(mac.key, key) + return mac + } + return tls10MAC{hmac.New(sha1.New, key)} +} + +func macMD5(version uint16, key []byte) macFunction { + if version == VersionSSL30 { + mac := ssl30MAC{ + h: md5.New(), + key: make([]byte, len(key)), + } + copy(mac.key, key) + return mac + } + return tls10MAC{hmac.New(md5.New, key)} +} + +func macSHA256(version uint16, key []byte) macFunction { + if version == VersionSSL30 { + mac := ssl30MAC{ + h: sha256.New(), + key: make([]byte, len(key)), + } + copy(mac.key, key) + return mac + } + return tls10MAC{hmac.New(sha256.New, key)} +} + +func macSHA384(version uint16, key []byte) macFunction { + if version == VersionSSL30 { + mac := ssl30MAC{ + h: sha512.New384(), + key: make([]byte, len(key)), + } + copy(mac.key, key) + return mac + } + return tls10MAC{hmac.New(sha512.New384, key)} +} + +type macFunction interface { + Size() int + MAC(digestBuf, seq, header, length, data []byte) []byte +} + +// fixedNonceAEAD wraps an AEAD and prefixes a fixed portion of the nonce to +// each call. +type fixedNonceAEAD struct { + // sealNonce and openNonce are buffers where the larger nonce will be + // constructed. Since a seal and open operation may be running + // concurrently, there is a separate buffer for each. + sealNonce, openNonce []byte + aead cipher.AEAD +} + +func (f *fixedNonceAEAD) NonceSize() int { return 8 } +func (f *fixedNonceAEAD) Overhead() int { return f.aead.Overhead() } + +func (f *fixedNonceAEAD) Seal(out, nonce, plaintext, additionalData []byte) []byte { + copy(f.sealNonce[len(f.sealNonce)-8:], nonce) + return f.aead.Seal(out, f.sealNonce, plaintext, additionalData) +} + +func (f *fixedNonceAEAD) Open(out, nonce, plaintext, additionalData []byte) ([]byte, error) { + copy(f.openNonce[len(f.openNonce)-8:], nonce) + return f.aead.Open(out, f.openNonce, plaintext, additionalData) +} + +func aeadAESGCM(key, fixedNonce []byte) cipher.AEAD { + aes, err := aes.NewCipher(key) + if err != nil { + panic(err) + } + aead, err := cipher.NewGCM(aes) + if err != nil { + panic(err) + } + + nonce1, nonce2 := make([]byte, 12), make([]byte, 12) + copy(nonce1, fixedNonce) + copy(nonce2, fixedNonce) + + return &fixedNonceAEAD{nonce1, nonce2, aead} +} + +// ssl30MAC implements the SSLv3 MAC function, as defined in +// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 5.2.3.1 +type ssl30MAC struct { + h hash.Hash + key []byte +} + +func (s ssl30MAC) Size() int { + return s.h.Size() +} + +var ssl30Pad1 = [48]byte{0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36, 0x36} + +var ssl30Pad2 = [48]byte{0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c, 0x5c} + +func (s ssl30MAC) MAC(digestBuf, seq, header, length, data []byte) []byte { + padLength := 48 + if s.h.Size() == 20 { + padLength = 40 + } + + s.h.Reset() + s.h.Write(s.key) + s.h.Write(ssl30Pad1[:padLength]) + s.h.Write(seq) + s.h.Write(header[:1]) + s.h.Write(length) + s.h.Write(data) + digestBuf = s.h.Sum(digestBuf[:0]) + + s.h.Reset() + s.h.Write(s.key) + s.h.Write(ssl30Pad2[:padLength]) + s.h.Write(digestBuf) + return s.h.Sum(digestBuf[:0]) +} + +// tls10MAC implements the TLS 1.0 MAC function. RFC 2246, section 6.2.3. +type tls10MAC struct { + h hash.Hash +} + +func (s tls10MAC) Size() int { + return s.h.Size() +} + +func (s tls10MAC) MAC(digestBuf, seq, header, length, data []byte) []byte { + s.h.Reset() + s.h.Write(seq) + s.h.Write(header) + s.h.Write(length) + s.h.Write(data) + return s.h.Sum(digestBuf[:0]) +} + +func rsaKA(version uint16) keyAgreement { + return &rsaKeyAgreement{} +} + +func ecdheECDSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + auth: &signedKeyAgreement{ + sigType: signatureECDSA, + version: version, + }, + } +} + +func ecdheRSAKA(version uint16) keyAgreement { + return &ecdheKeyAgreement{ + auth: &signedKeyAgreement{ + sigType: signatureRSA, + version: version, + }, + } +} + +func dheRSAKA(version uint16) keyAgreement { + return &dheKeyAgreement{ + auth: &signedKeyAgreement{ + sigType: signatureRSA, + version: version, + }, + } +} + +func pskKA(version uint16) keyAgreement { + return &pskKeyAgreement{ + base: &nilKeyAgreement{}, + } +} + +func ecdhePSKKA(version uint16) keyAgreement { + return &pskKeyAgreement{ + base: &ecdheKeyAgreement{ + auth: &nilKeyAgreementAuthentication{}, + }, + } +} + +// mutualCipherSuite returns a cipherSuite given a list of supported +// ciphersuites and the id requested by the peer. +func mutualCipherSuite(have []uint16, want uint16) *cipherSuite { + for _, id := range have { + if id == want { + for _, suite := range cipherSuites { + if suite.id == want { + return suite + } + } + return nil + } + } + return nil +} + +// A list of the possible cipher suite ids. Taken from +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml +const ( + TLS_RSA_WITH_RC4_128_MD5 uint16 = 0x0004 + TLS_RSA_WITH_RC4_128_SHA uint16 = 0x0005 + TLS_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x000a + TLS_DHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0x0016 + TLS_RSA_WITH_AES_128_CBC_SHA uint16 = 0x002f + TLS_DHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0x0033 + TLS_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0035 + TLS_DHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0x0039 + TLS_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x003c + TLS_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x003d + TLS_DHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0x0067 + TLS_DHE_RSA_WITH_AES_256_CBC_SHA256 uint16 = 0x006b + TLS_PSK_WITH_RC4_128_SHA uint16 = 0x008a + TLS_PSK_WITH_AES_128_CBC_SHA uint16 = 0x008c + TLS_PSK_WITH_AES_256_CBC_SHA uint16 = 0x008d + TLS_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009c + TLS_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009d + TLS_DHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0x009e + TLS_DHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0x009f + TLS_ECDHE_ECDSA_WITH_RC4_128_SHA uint16 = 0xc007 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA uint16 = 0xc009 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA uint16 = 0xc00a + TLS_ECDHE_RSA_WITH_RC4_128_SHA uint16 = 0xc011 + TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA uint16 = 0xc012 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA uint16 = 0xc013 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA uint16 = 0xc014 + TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc023 + TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384 uint16 = 0xc024 + TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256 uint16 = 0xc027 + TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384 uint16 = 0xc028 + TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02b + TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc02c + TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256 uint16 = 0xc02f + TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384 uint16 = 0xc030 + fallbackSCSV uint16 = 0x5600 +) + +// Additional cipher suite IDs, not IANA-assigned. +const ( + TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256 uint16 = 0xcafe +) diff --git a/src/ssl/test/runner/common.go b/src/ssl/test/runner/common.go new file mode 100644 index 0000000..7aaf9a2 --- /dev/null +++ b/src/ssl/test/runner/common.go @@ -0,0 +1,953 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "container/list" + "crypto" + "crypto/ecdsa" + "crypto/rand" + "crypto/x509" + "fmt" + "io" + "math/big" + "strings" + "sync" + "time" +) + +const ( + VersionSSL30 = 0x0300 + VersionTLS10 = 0x0301 + VersionTLS11 = 0x0302 + VersionTLS12 = 0x0303 +) + +const ( + maxPlaintext = 16384 // maximum plaintext payload length + maxCiphertext = 16384 + 2048 // maximum ciphertext payload length + tlsRecordHeaderLen = 5 // record header length + dtlsRecordHeaderLen = 13 + maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB) + + minVersion = VersionSSL30 + maxVersion = VersionTLS12 +) + +// TLS record types. +type recordType uint8 + +const ( + recordTypeChangeCipherSpec recordType = 20 + recordTypeAlert recordType = 21 + recordTypeHandshake recordType = 22 + recordTypeApplicationData recordType = 23 +) + +// TLS handshake message types. +const ( + typeHelloRequest uint8 = 0 + typeClientHello uint8 = 1 + typeServerHello uint8 = 2 + typeHelloVerifyRequest uint8 = 3 + typeNewSessionTicket uint8 = 4 + typeCertificate uint8 = 11 + typeServerKeyExchange uint8 = 12 + typeCertificateRequest uint8 = 13 + typeServerHelloDone uint8 = 14 + typeCertificateVerify uint8 = 15 + typeClientKeyExchange uint8 = 16 + typeFinished uint8 = 20 + typeCertificateStatus uint8 = 22 + typeNextProtocol uint8 = 67 // Not IANA assigned + typeEncryptedExtensions uint8 = 203 // Not IANA assigned +) + +// TLS compression types. +const ( + compressionNone uint8 = 0 +) + +// TLS extension numbers +const ( + extensionServerName uint16 = 0 + extensionStatusRequest uint16 = 5 + extensionSupportedCurves uint16 = 10 + extensionSupportedPoints uint16 = 11 + extensionSignatureAlgorithms uint16 = 13 + extensionUseSRTP uint16 = 14 + extensionALPN uint16 = 16 + extensionSignedCertificateTimestamp uint16 = 18 + extensionExtendedMasterSecret uint16 = 23 + extensionSessionTicket uint16 = 35 + extensionNextProtoNeg uint16 = 13172 // not IANA assigned + extensionRenegotiationInfo uint16 = 0xff01 + extensionChannelID uint16 = 30032 // not IANA assigned +) + +// TLS signaling cipher suite values +const ( + scsvRenegotiation uint16 = 0x00ff +) + +// CurveID is the type of a TLS identifier for an elliptic curve. See +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-8 +type CurveID uint16 + +const ( + CurveP256 CurveID = 23 + CurveP384 CurveID = 24 + CurveP521 CurveID = 25 +) + +// TLS Elliptic Curve Point Formats +// http://www.iana.org/assignments/tls-parameters/tls-parameters.xml#tls-parameters-9 +const ( + pointFormatUncompressed uint8 = 0 +) + +// TLS CertificateStatusType (RFC 3546) +const ( + statusTypeOCSP uint8 = 1 +) + +// Certificate types (for certificateRequestMsg) +const ( + CertTypeRSASign = 1 // A certificate containing an RSA key + CertTypeDSSSign = 2 // A certificate containing a DSA key + CertTypeRSAFixedDH = 3 // A certificate containing a static DH key + CertTypeDSSFixedDH = 4 // A certificate containing a static DH key + + // See RFC4492 sections 3 and 5.5. + CertTypeECDSASign = 64 // A certificate containing an ECDSA-capable public key, signed with ECDSA. + CertTypeRSAFixedECDH = 65 // A certificate containing an ECDH-capable public key, signed with RSA. + CertTypeECDSAFixedECDH = 66 // A certificate containing an ECDH-capable public key, signed with ECDSA. + + // Rest of these are reserved by the TLS spec +) + +// Hash functions for TLS 1.2 (See RFC 5246, section A.4.1) +const ( + hashMD5 uint8 = 1 + hashSHA1 uint8 = 2 + hashSHA224 uint8 = 3 + hashSHA256 uint8 = 4 + hashSHA384 uint8 = 5 + hashSHA512 uint8 = 6 +) + +// Signature algorithms for TLS 1.2 (See RFC 5246, section A.4.1) +const ( + signatureRSA uint8 = 1 + signatureECDSA uint8 = 3 +) + +// signatureAndHash mirrors the TLS 1.2, SignatureAndHashAlgorithm struct. See +// RFC 5246, section A.4.1. +type signatureAndHash struct { + signature, hash uint8 +} + +// supportedSKXSignatureAlgorithms contains the signature and hash algorithms +// that the code advertises as supported in a TLS 1.2 ClientHello. +var supportedSKXSignatureAlgorithms = []signatureAndHash{ + {signatureRSA, hashSHA256}, + {signatureECDSA, hashSHA256}, + {signatureRSA, hashSHA1}, + {signatureECDSA, hashSHA1}, +} + +// supportedClientCertSignatureAlgorithms contains the signature and hash +// algorithms that the code advertises as supported in a TLS 1.2 +// CertificateRequest. +var supportedClientCertSignatureAlgorithms = []signatureAndHash{ + {signatureRSA, hashSHA256}, + {signatureECDSA, hashSHA256}, +} + +// SRTP protection profiles (See RFC 5764, section 4.1.2) +const ( + SRTP_AES128_CM_HMAC_SHA1_80 uint16 = 0x0001 + SRTP_AES128_CM_HMAC_SHA1_32 = 0x0002 +) + +// ConnectionState records basic TLS details about the connection. +type ConnectionState struct { + Version uint16 // TLS version used by the connection (e.g. VersionTLS12) + HandshakeComplete bool // TLS handshake is complete + DidResume bool // connection resumes a previous TLS connection + CipherSuite uint16 // cipher suite in use (TLS_RSA_WITH_RC4_128_SHA, ...) + NegotiatedProtocol string // negotiated next protocol (from Config.NextProtos) + NegotiatedProtocolIsMutual bool // negotiated protocol was advertised by server + NegotiatedProtocolFromALPN bool // protocol negotiated with ALPN + ServerName string // server name requested by client, if any (server side only) + PeerCertificates []*x509.Certificate // certificate chain presented by remote peer + VerifiedChains [][]*x509.Certificate // verified chains built from PeerCertificates + ChannelID *ecdsa.PublicKey // the channel ID for this connection + SRTPProtectionProfile uint16 // the negotiated DTLS-SRTP protection profile +} + +// ClientAuthType declares the policy the server will follow for +// TLS Client Authentication. +type ClientAuthType int + +const ( + NoClientCert ClientAuthType = iota + RequestClientCert + RequireAnyClientCert + VerifyClientCertIfGiven + RequireAndVerifyClientCert +) + +// ClientSessionState contains the state needed by clients to resume TLS +// sessions. +type ClientSessionState struct { + sessionId []uint8 // Session ID supplied by the server. nil if the session has a ticket. + sessionTicket []uint8 // Encrypted ticket used for session resumption with server + vers uint16 // SSL/TLS version negotiated for the session + cipherSuite uint16 // Ciphersuite negotiated for the session + masterSecret []byte // MasterSecret generated by client on a full handshake + handshakeHash []byte // Handshake hash for Channel ID purposes. + serverCertificates []*x509.Certificate // Certificate chain presented by the server + extendedMasterSecret bool // Whether an extended master secret was used to generate the session +} + +// ClientSessionCache is a cache of ClientSessionState objects that can be used +// by a client to resume a TLS session with a given server. ClientSessionCache +// implementations should expect to be called concurrently from different +// goroutines. +type ClientSessionCache interface { + // Get searches for a ClientSessionState associated with the given key. + // On return, ok is true if one was found. + Get(sessionKey string) (session *ClientSessionState, ok bool) + + // Put adds the ClientSessionState to the cache with the given key. + Put(sessionKey string, cs *ClientSessionState) +} + +// ServerSessionCache is a cache of sessionState objects that can be used by a +// client to resume a TLS session with a given server. ServerSessionCache +// implementations should expect to be called concurrently from different +// goroutines. +type ServerSessionCache interface { + // Get searches for a sessionState associated with the given session + // ID. On return, ok is true if one was found. + Get(sessionId string) (session *sessionState, ok bool) + + // Put adds the sessionState to the cache with the given session ID. + Put(sessionId string, session *sessionState) +} + +// A Config structure is used to configure a TLS client or server. +// After one has been passed to a TLS function it must not be +// modified. A Config may be reused; the tls package will also not +// modify it. +type Config struct { + // Rand provides the source of entropy for nonces and RSA blinding. + // If Rand is nil, TLS uses the cryptographic random reader in package + // crypto/rand. + // The Reader must be safe for use by multiple goroutines. + Rand io.Reader + + // Time returns the current time as the number of seconds since the epoch. + // If Time is nil, TLS uses time.Now. + Time func() time.Time + + // Certificates contains one or more certificate chains + // to present to the other side of the connection. + // Server configurations must include at least one certificate. + Certificates []Certificate + + // NameToCertificate maps from a certificate name to an element of + // Certificates. Note that a certificate name can be of the form + // '*.example.com' and so doesn't have to be a domain name as such. + // See Config.BuildNameToCertificate + // The nil value causes the first element of Certificates to be used + // for all connections. + NameToCertificate map[string]*Certificate + + // RootCAs defines the set of root certificate authorities + // that clients use when verifying server certificates. + // If RootCAs is nil, TLS uses the host's root CA set. + RootCAs *x509.CertPool + + // NextProtos is a list of supported, application level protocols. + NextProtos []string + + // ServerName is used to verify the hostname on the returned + // certificates unless InsecureSkipVerify is given. It is also included + // in the client's handshake to support virtual hosting. + ServerName string + + // ClientAuth determines the server's policy for + // TLS Client Authentication. The default is NoClientCert. + ClientAuth ClientAuthType + + // ClientCAs defines the set of root certificate authorities + // that servers use if required to verify a client certificate + // by the policy in ClientAuth. + ClientCAs *x509.CertPool + + // ClientCertificateTypes defines the set of allowed client certificate + // types. The default is CertTypeRSASign and CertTypeECDSASign. + ClientCertificateTypes []byte + + // InsecureSkipVerify controls whether a client verifies the + // server's certificate chain and host name. + // If InsecureSkipVerify is true, TLS accepts any certificate + // presented by the server and any host name in that certificate. + // In this mode, TLS is susceptible to man-in-the-middle attacks. + // This should be used only for testing. + InsecureSkipVerify bool + + // CipherSuites is a list of supported cipher suites. If CipherSuites + // is nil, TLS uses a list of suites supported by the implementation. + CipherSuites []uint16 + + // PreferServerCipherSuites controls whether the server selects the + // client's most preferred ciphersuite, or the server's most preferred + // ciphersuite. If true then the server's preference, as expressed in + // the order of elements in CipherSuites, is used. + PreferServerCipherSuites bool + + // SessionTicketsDisabled may be set to true to disable session ticket + // (resumption) support. + SessionTicketsDisabled bool + + // SessionTicketKey is used by TLS servers to provide session + // resumption. See RFC 5077. If zero, it will be filled with + // random data before the first server handshake. + // + // If multiple servers are terminating connections for the same host + // they should all have the same SessionTicketKey. If the + // SessionTicketKey leaks, previously recorded and future TLS + // connections using that key are compromised. + SessionTicketKey [32]byte + + // ClientSessionCache is a cache of ClientSessionState entries + // for TLS session resumption. + ClientSessionCache ClientSessionCache + + // ServerSessionCache is a cache of sessionState entries for TLS session + // resumption. + ServerSessionCache ServerSessionCache + + // MinVersion contains the minimum SSL/TLS version that is acceptable. + // If zero, then SSLv3 is taken as the minimum. + MinVersion uint16 + + // MaxVersion contains the maximum SSL/TLS version that is acceptable. + // If zero, then the maximum version supported by this package is used, + // which is currently TLS 1.2. + MaxVersion uint16 + + // CurvePreferences contains the elliptic curves that will be used in + // an ECDHE handshake, in preference order. If empty, the default will + // be used. + CurvePreferences []CurveID + + // ChannelID contains the ECDSA key for the client to use as + // its TLS Channel ID. + ChannelID *ecdsa.PrivateKey + + // RequestChannelID controls whether the server requests a TLS + // Channel ID. If negotiated, the client's public key is + // returned in the ConnectionState. + RequestChannelID bool + + // PreSharedKey, if not nil, is the pre-shared key to use with + // the PSK cipher suites. + PreSharedKey []byte + + // PreSharedKeyIdentity, if not empty, is the identity to use + // with the PSK cipher suites. + PreSharedKeyIdentity string + + // SRTPProtectionProfiles, if not nil, is the list of SRTP + // protection profiles to offer in DTLS-SRTP. + SRTPProtectionProfiles []uint16 + + // SignatureAndHashes, if not nil, overrides the default set of + // supported signature and hash algorithms to advertise in + // CertificateRequest. + SignatureAndHashes []signatureAndHash + + // Bugs specifies optional misbehaviour to be used for testing other + // implementations. + Bugs ProtocolBugs + + serverInitOnce sync.Once // guards calling (*Config).serverInit +} + +type BadValue int + +const ( + BadValueNone BadValue = iota + BadValueNegative + BadValueZero + BadValueLimit + BadValueLarge + NumBadValues +) + +type ProtocolBugs struct { + // InvalidSKXSignature specifies that the signature in a + // ServerKeyExchange message should be invalid. + InvalidSKXSignature bool + + // InvalidSKXCurve causes the curve ID in the ServerKeyExchange message + // to be wrong. + InvalidSKXCurve bool + + // BadECDSAR controls ways in which the 'r' value of an ECDSA signature + // can be invalid. + BadECDSAR BadValue + BadECDSAS BadValue + + // MaxPadding causes CBC records to have the maximum possible padding. + MaxPadding bool + // PaddingFirstByteBad causes the first byte of the padding to be + // incorrect. + PaddingFirstByteBad bool + // PaddingFirstByteBadIf255 causes the first byte of padding to be + // incorrect if there's a maximum amount of padding (i.e. 255 bytes). + PaddingFirstByteBadIf255 bool + + // FailIfNotFallbackSCSV causes a server handshake to fail if the + // client doesn't send the fallback SCSV value. + FailIfNotFallbackSCSV bool + + // DuplicateExtension causes an extra empty extension of bogus type to + // be emitted in either the ClientHello or the ServerHello. + DuplicateExtension bool + + // UnauthenticatedECDH causes the server to pretend ECDHE_RSA + // and ECDHE_ECDSA cipher suites are actually ECDH_anon. No + // Certificate message is sent and no signature is added to + // ServerKeyExchange. + UnauthenticatedECDH bool + + // SkipServerKeyExchange causes the server to skip sending + // ServerKeyExchange messages. + SkipServerKeyExchange bool + + // SkipChangeCipherSpec causes the implementation to skip + // sending the ChangeCipherSpec message (and adjusting cipher + // state accordingly for the Finished message). + SkipChangeCipherSpec bool + + // EarlyChangeCipherSpec causes the client to send an early + // ChangeCipherSpec message before the ClientKeyExchange. A value of + // zero disables this behavior. One and two configure variants for 0.9.8 + // and 1.0.1 modes, respectively. + EarlyChangeCipherSpec int + + // FragmentAcrossChangeCipherSpec causes the implementation to fragment + // the Finished (or NextProto) message around the ChangeCipherSpec + // messages. + FragmentAcrossChangeCipherSpec bool + + // SkipNewSessionTicket causes the server to skip sending the + // NewSessionTicket message despite promising to in ServerHello. + SkipNewSessionTicket bool + + // SendV2ClientHello causes the client to send a V2ClientHello + // instead of a normal ClientHello. + SendV2ClientHello bool + + // SendFallbackSCSV causes the client to include + // TLS_FALLBACK_SCSV in the ClientHello. + SendFallbackSCSV bool + + // MaxHandshakeRecordLength, if non-zero, is the maximum size of a + // handshake record. Handshake messages will be split into multiple + // records at the specified size, except that the client_version will + // never be fragmented. + MaxHandshakeRecordLength int + + // FragmentClientVersion will allow MaxHandshakeRecordLength to apply to + // the first 6 bytes of the ClientHello. + FragmentClientVersion bool + + // FragmentAlert will cause all alerts to be fragmented across + // two records. + FragmentAlert bool + + // SendSpuriousAlert will cause an spurious, unwanted alert to be sent. + SendSpuriousAlert bool + + // RsaClientKeyExchangeVersion, if non-zero, causes the client to send a + // ClientKeyExchange with the specified version rather than the + // client_version when performing the RSA key exchange. + RsaClientKeyExchangeVersion uint16 + + // RenewTicketOnResume causes the server to renew the session ticket and + // send a NewSessionTicket message during an abbreviated handshake. + RenewTicketOnResume bool + + // SendClientVersion, if non-zero, causes the client to send a different + // TLS version in the ClientHello than the maximum supported version. + SendClientVersion uint16 + + // SkipHelloVerifyRequest causes a DTLS server to skip the + // HelloVerifyRequest message. + SkipHelloVerifyRequest bool + + // ExpectFalseStart causes the server to, on full handshakes, + // expect the peer to False Start; the server Finished message + // isn't sent until we receive an application data record + // from the peer. + ExpectFalseStart bool + + // SSL3RSAKeyExchange causes the client to always send an RSA + // ClientKeyExchange message without the two-byte length + // prefix, as if it were SSL3. + SSL3RSAKeyExchange bool + + // SkipCipherVersionCheck causes the server to negotiate + // TLS 1.2 ciphers in earlier versions of TLS. + SkipCipherVersionCheck bool + + // ExpectServerName, if not empty, is the hostname the client + // must specify in the server_name extension. + ExpectServerName string + + // SwapNPNAndALPN switches the relative order between NPN and + // ALPN on the server. This is to test that server preference + // of ALPN works regardless of their relative order. + SwapNPNAndALPN bool + + // AllowSessionVersionMismatch causes the server to resume sessions + // regardless of the version associated with the session. + AllowSessionVersionMismatch bool + + // CorruptTicket causes a client to corrupt a session ticket before + // sending it in a resume handshake. + CorruptTicket bool + + // OversizedSessionId causes the session id that is sent with a ticket + // resumption attempt to be too large (33 bytes). + OversizedSessionId bool + + // RequireExtendedMasterSecret, if true, requires that the peer support + // the extended master secret option. + RequireExtendedMasterSecret bool + + // NoExtendedMasterSecret causes the client and server to behave as if + // they didn't support an extended master secret. + NoExtendedMasterSecret bool + + // EmptyRenegotiationInfo causes the renegotiation extension to be + // empty in a renegotiation handshake. + EmptyRenegotiationInfo bool + + // BadRenegotiationInfo causes the renegotiation extension value in a + // renegotiation handshake to be incorrect. + BadRenegotiationInfo bool + + // NoRenegotiationInfo causes the client to behave as if it + // didn't support the renegotiation info extension. + NoRenegotiationInfo bool + + // SequenceNumberIncrement, if non-zero, causes outgoing sequence + // numbers in DTLS to increment by that value rather by 1. This is to + // stress the replay bitmap window by simulating extreme packet loss and + // retransmit at the record layer. + SequenceNumberIncrement uint64 + + // RSAServerKeyExchange, if true, causes the server to send a + // ServerKeyExchange message in the plain RSA key exchange. + RSAServerKeyExchange bool + + // SRTPMasterKeyIdentifer, if not empty, is the SRTP MKI value that the + // client offers when negotiating SRTP. MKI support is still missing so + // the peer must still send none. + SRTPMasterKeyIdentifer string + + // SendSRTPProtectionProfile, if non-zero, is the SRTP profile that the + // server sends in the ServerHello instead of the negotiated one. + SendSRTPProtectionProfile uint16 + + // NoSignatureAndHashes, if true, causes the client to omit the + // signature and hashes extension. + // + // For a server, it will cause an empty list to be sent in the + // CertificateRequest message. None the less, the configured set will + // still be enforced. + NoSignatureAndHashes bool + + // RequireSameRenegoClientVersion, if true, causes the server + // to require that all ClientHellos match in offered version + // across a renego. + RequireSameRenegoClientVersion bool + + // RequireFastradioPadding, if true, requires that ClientHello messages + // be at least 1000 bytes long. + RequireFastradioPadding bool + + // ExpectInitialRecordVersion, if non-zero, is the expected + // version of the records before the version is determined. + ExpectInitialRecordVersion uint16 + + // MaxPacketLength, if non-zero, is the maximum acceptable size for a + // packet. + MaxPacketLength int + + // SendCipherSuite, if non-zero, is the cipher suite value that the + // server will send in the ServerHello. This does not affect the cipher + // the server believes it has actually negotiated. + SendCipherSuite uint16 + + // AppDataAfterChangeCipherSpec, if not null, causes application data to + // be sent immediately after ChangeCipherSpec. + AppDataAfterChangeCipherSpec []byte +} + +func (c *Config) serverInit() { + if c.SessionTicketsDisabled { + return + } + + // If the key has already been set then we have nothing to do. + for _, b := range c.SessionTicketKey { + if b != 0 { + return + } + } + + if _, err := io.ReadFull(c.rand(), c.SessionTicketKey[:]); err != nil { + c.SessionTicketsDisabled = true + } +} + +func (c *Config) rand() io.Reader { + r := c.Rand + if r == nil { + return rand.Reader + } + return r +} + +func (c *Config) time() time.Time { + t := c.Time + if t == nil { + t = time.Now + } + return t() +} + +func (c *Config) cipherSuites() []uint16 { + s := c.CipherSuites + if s == nil { + s = defaultCipherSuites() + } + return s +} + +func (c *Config) minVersion() uint16 { + if c == nil || c.MinVersion == 0 { + return minVersion + } + return c.MinVersion +} + +func (c *Config) maxVersion() uint16 { + if c == nil || c.MaxVersion == 0 { + return maxVersion + } + return c.MaxVersion +} + +var defaultCurvePreferences = []CurveID{CurveP256, CurveP384, CurveP521} + +func (c *Config) curvePreferences() []CurveID { + if c == nil || len(c.CurvePreferences) == 0 { + return defaultCurvePreferences + } + return c.CurvePreferences +} + +// mutualVersion returns the protocol version to use given the advertised +// version of the peer. +func (c *Config) mutualVersion(vers uint16) (uint16, bool) { + minVersion := c.minVersion() + maxVersion := c.maxVersion() + + if vers < minVersion { + return 0, false + } + if vers > maxVersion { + vers = maxVersion + } + return vers, true +} + +// getCertificateForName returns the best certificate for the given name, +// defaulting to the first element of c.Certificates if there are no good +// options. +func (c *Config) getCertificateForName(name string) *Certificate { + if len(c.Certificates) == 1 || c.NameToCertificate == nil { + // There's only one choice, so no point doing any work. + return &c.Certificates[0] + } + + name = strings.ToLower(name) + for len(name) > 0 && name[len(name)-1] == '.' { + name = name[:len(name)-1] + } + + if cert, ok := c.NameToCertificate[name]; ok { + return cert + } + + // try replacing labels in the name with wildcards until we get a + // match. + labels := strings.Split(name, ".") + for i := range labels { + labels[i] = "*" + candidate := strings.Join(labels, ".") + if cert, ok := c.NameToCertificate[candidate]; ok { + return cert + } + } + + // If nothing matches, return the first certificate. + return &c.Certificates[0] +} + +func (c *Config) signatureAndHashesForServer() []signatureAndHash { + if c != nil && c.SignatureAndHashes != nil { + return c.SignatureAndHashes + } + return supportedClientCertSignatureAlgorithms +} + +func (c *Config) signatureAndHashesForClient() []signatureAndHash { + if c != nil && c.SignatureAndHashes != nil { + return c.SignatureAndHashes + } + return supportedSKXSignatureAlgorithms +} + +// BuildNameToCertificate parses c.Certificates and builds c.NameToCertificate +// from the CommonName and SubjectAlternateName fields of each of the leaf +// certificates. +func (c *Config) BuildNameToCertificate() { + c.NameToCertificate = make(map[string]*Certificate) + for i := range c.Certificates { + cert := &c.Certificates[i] + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + continue + } + if len(x509Cert.Subject.CommonName) > 0 { + c.NameToCertificate[x509Cert.Subject.CommonName] = cert + } + for _, san := range x509Cert.DNSNames { + c.NameToCertificate[san] = cert + } + } +} + +// A Certificate is a chain of one or more certificates, leaf first. +type Certificate struct { + Certificate [][]byte + PrivateKey crypto.PrivateKey // supported types: *rsa.PrivateKey, *ecdsa.PrivateKey + // OCSPStaple contains an optional OCSP response which will be served + // to clients that request it. + OCSPStaple []byte + // SignedCertificateTimestampList contains an optional encoded + // SignedCertificateTimestampList structure which will be + // served to clients that request it. + SignedCertificateTimestampList []byte + // Leaf is the parsed form of the leaf certificate, which may be + // initialized using x509.ParseCertificate to reduce per-handshake + // processing for TLS clients doing client authentication. If nil, the + // leaf certificate will be parsed as needed. + Leaf *x509.Certificate +} + +// A TLS record. +type record struct { + contentType recordType + major, minor uint8 + payload []byte +} + +type handshakeMessage interface { + marshal() []byte + unmarshal([]byte) bool +} + +// lruSessionCache is a client or server session cache implementation +// that uses an LRU caching strategy. +type lruSessionCache struct { + sync.Mutex + + m map[string]*list.Element + q *list.List + capacity int +} + +type lruSessionCacheEntry struct { + sessionKey string + state interface{} +} + +// Put adds the provided (sessionKey, cs) pair to the cache. +func (c *lruSessionCache) Put(sessionKey string, cs interface{}) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + entry := elem.Value.(*lruSessionCacheEntry) + entry.state = cs + c.q.MoveToFront(elem) + return + } + + if c.q.Len() < c.capacity { + entry := &lruSessionCacheEntry{sessionKey, cs} + c.m[sessionKey] = c.q.PushFront(entry) + return + } + + elem := c.q.Back() + entry := elem.Value.(*lruSessionCacheEntry) + delete(c.m, entry.sessionKey) + entry.sessionKey = sessionKey + entry.state = cs + c.q.MoveToFront(elem) + c.m[sessionKey] = elem +} + +// Get returns the value associated with a given key. It returns (nil, +// false) if no value is found. +func (c *lruSessionCache) Get(sessionKey string) (interface{}, bool) { + c.Lock() + defer c.Unlock() + + if elem, ok := c.m[sessionKey]; ok { + c.q.MoveToFront(elem) + return elem.Value.(*lruSessionCacheEntry).state, true + } + return nil, false +} + +// lruClientSessionCache is a ClientSessionCache implementation that +// uses an LRU caching strategy. +type lruClientSessionCache struct { + lruSessionCache +} + +func (c *lruClientSessionCache) Put(sessionKey string, cs *ClientSessionState) { + c.lruSessionCache.Put(sessionKey, cs) +} + +func (c *lruClientSessionCache) Get(sessionKey string) (*ClientSessionState, bool) { + cs, ok := c.lruSessionCache.Get(sessionKey) + if !ok { + return nil, false + } + return cs.(*ClientSessionState), true +} + +// lruServerSessionCache is a ServerSessionCache implementation that +// uses an LRU caching strategy. +type lruServerSessionCache struct { + lruSessionCache +} + +func (c *lruServerSessionCache) Put(sessionId string, session *sessionState) { + c.lruSessionCache.Put(sessionId, session) +} + +func (c *lruServerSessionCache) Get(sessionId string) (*sessionState, bool) { + cs, ok := c.lruSessionCache.Get(sessionId) + if !ok { + return nil, false + } + return cs.(*sessionState), true +} + +// NewLRUClientSessionCache returns a ClientSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUClientSessionCache(capacity int) ClientSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruClientSessionCache{ + lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + }, + } +} + +// NewLRUServerSessionCache returns a ServerSessionCache with the given +// capacity that uses an LRU strategy. If capacity is < 1, a default capacity +// is used instead. +func NewLRUServerSessionCache(capacity int) ServerSessionCache { + const defaultSessionCacheCapacity = 64 + + if capacity < 1 { + capacity = defaultSessionCacheCapacity + } + return &lruServerSessionCache{ + lruSessionCache{ + m: make(map[string]*list.Element), + q: list.New(), + capacity: capacity, + }, + } +} + +// TODO(jsing): Make these available to both crypto/x509 and crypto/tls. +type dsaSignature struct { + R, S *big.Int +} + +type ecdsaSignature dsaSignature + +var emptyConfig Config + +func defaultConfig() *Config { + return &emptyConfig +} + +var ( + once sync.Once + varDefaultCipherSuites []uint16 +) + +func defaultCipherSuites() []uint16 { + once.Do(initDefaultCipherSuites) + return varDefaultCipherSuites +} + +func initDefaultCipherSuites() { + for _, suite := range cipherSuites { + if suite.flags&suitePSK == 0 { + varDefaultCipherSuites = append(varDefaultCipherSuites, suite.id) + } + } +} + +func unexpectedMessageError(wanted, got interface{}) error { + return fmt.Errorf("tls: received unexpected handshake message of type %T when waiting for %T", got, wanted) +} + +func isSupportedSignatureAndHash(sigHash signatureAndHash, sigHashes []signatureAndHash) bool { + for _, s := range sigHashes { + if s == sigHash { + return true + } + } + return false +} diff --git a/src/ssl/test/runner/conn.go b/src/ssl/test/runner/conn.go new file mode 100644 index 0000000..d4a6817 --- /dev/null +++ b/src/ssl/test/runner/conn.go @@ -0,0 +1,1229 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// TLS low level connection and record layer + +package main + +import ( + "bytes" + "crypto/cipher" + "crypto/ecdsa" + "crypto/subtle" + "crypto/x509" + "errors" + "fmt" + "io" + "net" + "sync" + "time" +) + +// A Conn represents a secured connection. +// It implements the net.Conn interface. +type Conn struct { + // constant + conn net.Conn + isDTLS bool + isClient bool + + // constant after handshake; protected by handshakeMutex + handshakeMutex sync.Mutex // handshakeMutex < in.Mutex, out.Mutex, errMutex + handshakeErr error // error resulting from handshake + vers uint16 // TLS version + haveVers bool // version has been negotiated + config *Config // configuration passed to constructor + handshakeComplete bool + didResume bool // whether this connection was a session resumption + extendedMasterSecret bool // whether this session used an extended master secret + cipherSuite uint16 + ocspResponse []byte // stapled OCSP response + peerCertificates []*x509.Certificate + // verifiedChains contains the certificate chains that we built, as + // opposed to the ones presented by the server. + verifiedChains [][]*x509.Certificate + // serverName contains the server name indicated by the client, if any. + serverName string + + clientProtocol string + clientProtocolFallback bool + usedALPN bool + + // verify_data values for the renegotiation extension. + clientVerify []byte + serverVerify []byte + + channelID *ecdsa.PublicKey + + srtpProtectionProfile uint16 + + clientVersion uint16 + + // input/output + in, out halfConn // in.Mutex < out.Mutex + rawInput *block // raw input, right off the wire + input *block // application record waiting to be read + hand bytes.Buffer // handshake record waiting to be read + + // DTLS state + sendHandshakeSeq uint16 + recvHandshakeSeq uint16 + handMsg []byte // pending assembled handshake message + handMsgLen int // handshake message length, not including the header + + tmp [16]byte +} + +func (c *Conn) init() { + c.in.isDTLS = c.isDTLS + c.out.isDTLS = c.isDTLS + c.in.config = c.config + c.out.config = c.config +} + +// Access to net.Conn methods. +// Cannot just embed net.Conn because that would +// export the struct field too. + +// LocalAddr returns the local network address. +func (c *Conn) LocalAddr() net.Addr { + return c.conn.LocalAddr() +} + +// RemoteAddr returns the remote network address. +func (c *Conn) RemoteAddr() net.Addr { + return c.conn.RemoteAddr() +} + +// SetDeadline sets the read and write deadlines associated with the connection. +// A zero value for t means Read and Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetDeadline(t time.Time) error { + return c.conn.SetDeadline(t) +} + +// SetReadDeadline sets the read deadline on the underlying connection. +// A zero value for t means Read will not time out. +func (c *Conn) SetReadDeadline(t time.Time) error { + return c.conn.SetReadDeadline(t) +} + +// SetWriteDeadline sets the write deadline on the underlying conneciton. +// A zero value for t means Write will not time out. +// After a Write has timed out, the TLS state is corrupt and all future writes will return the same error. +func (c *Conn) SetWriteDeadline(t time.Time) error { + return c.conn.SetWriteDeadline(t) +} + +// A halfConn represents one direction of the record layer +// connection, either sending or receiving. +type halfConn struct { + sync.Mutex + + err error // first permanent error + version uint16 // protocol version + isDTLS bool + cipher interface{} // cipher algorithm + mac macFunction + seq [8]byte // 64-bit sequence number + bfree *block // list of free blocks + + nextCipher interface{} // next encryption state + nextMac macFunction // next MAC algorithm + + // used to save allocating a new buffer for each MAC. + inDigestBuf, outDigestBuf []byte + + config *Config +} + +func (hc *halfConn) setErrorLocked(err error) error { + hc.err = err + return err +} + +func (hc *halfConn) error() error { + // This should be locked, but I've removed it for the renegotiation + // tests since we don't concurrently read and write the same tls.Conn + // in any case during testing. + err := hc.err + return err +} + +// prepareCipherSpec sets the encryption and MAC states +// that a subsequent changeCipherSpec will use. +func (hc *halfConn) prepareCipherSpec(version uint16, cipher interface{}, mac macFunction) { + hc.version = version + hc.nextCipher = cipher + hc.nextMac = mac +} + +// changeCipherSpec changes the encryption and MAC states +// to the ones previously passed to prepareCipherSpec. +func (hc *halfConn) changeCipherSpec(config *Config) error { + if hc.nextCipher == nil { + return alertInternalError + } + hc.cipher = hc.nextCipher + hc.mac = hc.nextMac + hc.nextCipher = nil + hc.nextMac = nil + hc.config = config + hc.incEpoch() + return nil +} + +// incSeq increments the sequence number. +func (hc *halfConn) incSeq(isOutgoing bool) { + limit := 0 + increment := uint64(1) + if hc.isDTLS { + // Increment up to the epoch in DTLS. + limit = 2 + + if isOutgoing && hc.config.Bugs.SequenceNumberIncrement != 0 { + increment = hc.config.Bugs.SequenceNumberIncrement + } + } + for i := 7; i >= limit; i-- { + increment += uint64(hc.seq[i]) + hc.seq[i] = byte(increment) + increment >>= 8 + } + + // Not allowed to let sequence number wrap. + // Instead, must renegotiate before it does. + // Not likely enough to bother. + if increment != 0 { + panic("TLS: sequence number wraparound") + } +} + +// incEpoch resets the sequence number. In DTLS, it increments the +// epoch half of the sequence number. +func (hc *halfConn) incEpoch() { + limit := 0 + if hc.isDTLS { + for i := 1; i >= 0; i-- { + hc.seq[i]++ + if hc.seq[i] != 0 { + break + } + if i == 0 { + panic("TLS: epoch number wraparound") + } + } + limit = 2 + } + seq := hc.seq[limit:] + for i := range seq { + seq[i] = 0 + } +} + +func (hc *halfConn) recordHeaderLen() int { + if hc.isDTLS { + return dtlsRecordHeaderLen + } + return tlsRecordHeaderLen +} + +// removePadding returns an unpadded slice, in constant time, which is a prefix +// of the input. It also returns a byte which is equal to 255 if the padding +// was valid and 0 otherwise. See RFC 2246, section 6.2.3.2 +func removePadding(payload []byte) ([]byte, byte) { + if len(payload) < 1 { + return payload, 0 + } + + paddingLen := payload[len(payload)-1] + t := uint(len(payload)-1) - uint(paddingLen) + // if len(payload) >= (paddingLen - 1) then the MSB of t is zero + good := byte(int32(^t) >> 31) + + toCheck := 255 // the maximum possible padding length + // The length of the padded data is public, so we can use an if here + if toCheck+1 > len(payload) { + toCheck = len(payload) - 1 + } + + for i := 0; i < toCheck; i++ { + t := uint(paddingLen) - uint(i) + // if i <= paddingLen then the MSB of t is zero + mask := byte(int32(^t) >> 31) + b := payload[len(payload)-1-i] + good &^= mask&paddingLen ^ mask&b + } + + // We AND together the bits of good and replicate the result across + // all the bits. + good &= good << 4 + good &= good << 2 + good &= good << 1 + good = uint8(int8(good) >> 7) + + toRemove := good&paddingLen + 1 + return payload[:len(payload)-int(toRemove)], good +} + +// removePaddingSSL30 is a replacement for removePadding in the case that the +// protocol version is SSLv3. In this version, the contents of the padding +// are random and cannot be checked. +func removePaddingSSL30(payload []byte) ([]byte, byte) { + if len(payload) < 1 { + return payload, 0 + } + + paddingLen := int(payload[len(payload)-1]) + 1 + if paddingLen > len(payload) { + return payload, 0 + } + + return payload[:len(payload)-paddingLen], 255 +} + +func roundUp(a, b int) int { + return a + (b-a%b)%b +} + +// cbcMode is an interface for block ciphers using cipher block chaining. +type cbcMode interface { + cipher.BlockMode + SetIV([]byte) +} + +// decrypt checks and strips the mac and decrypts the data in b. Returns a +// success boolean, the number of bytes to skip from the start of the record in +// order to get the application payload, and an optional alert value. +func (hc *halfConn) decrypt(b *block) (ok bool, prefixLen int, alertValue alert) { + recordHeaderLen := hc.recordHeaderLen() + + // pull out payload + payload := b.data[recordHeaderLen:] + + macSize := 0 + if hc.mac != nil { + macSize = hc.mac.Size() + } + + paddingGood := byte(255) + explicitIVLen := 0 + + seq := hc.seq[:] + if hc.isDTLS { + // DTLS sequence numbers are explicit. + seq = b.data[3:11] + } + + // decrypt + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case cipher.AEAD: + explicitIVLen = 8 + if len(payload) < explicitIVLen { + return false, 0, alertBadRecordMAC + } + nonce := payload[:8] + payload = payload[8:] + + var additionalData [13]byte + copy(additionalData[:], seq) + copy(additionalData[8:], b.data[:3]) + n := len(payload) - c.Overhead() + additionalData[11] = byte(n >> 8) + additionalData[12] = byte(n) + var err error + payload, err = c.Open(payload[:0], nonce, payload, additionalData[:]) + if err != nil { + return false, 0, alertBadRecordMAC + } + b.resize(recordHeaderLen + explicitIVLen + len(payload)) + case cbcMode: + blockSize := c.BlockSize() + if hc.version >= VersionTLS11 || hc.isDTLS { + explicitIVLen = blockSize + } + + if len(payload)%blockSize != 0 || len(payload) < roundUp(explicitIVLen+macSize+1, blockSize) { + return false, 0, alertBadRecordMAC + } + + if explicitIVLen > 0 { + c.SetIV(payload[:explicitIVLen]) + payload = payload[explicitIVLen:] + } + c.CryptBlocks(payload, payload) + if hc.version == VersionSSL30 { + payload, paddingGood = removePaddingSSL30(payload) + } else { + payload, paddingGood = removePadding(payload) + } + b.resize(recordHeaderLen + explicitIVLen + len(payload)) + + // note that we still have a timing side-channel in the + // MAC check, below. An attacker can align the record + // so that a correct padding will cause one less hash + // block to be calculated. Then they can iteratively + // decrypt a record by breaking each byte. See + // "Password Interception in a SSL/TLS Channel", Brice + // Canvel et al. + // + // However, our behavior matches OpenSSL, so we leak + // only as much as they do. + default: + panic("unknown cipher type") + } + } + + // check, strip mac + if hc.mac != nil { + if len(payload) < macSize { + return false, 0, alertBadRecordMAC + } + + // strip mac off payload, b.data + n := len(payload) - macSize + b.data[recordHeaderLen-2] = byte(n >> 8) + b.data[recordHeaderLen-1] = byte(n) + b.resize(recordHeaderLen + explicitIVLen + n) + remoteMAC := payload[n:] + localMAC := hc.mac.MAC(hc.inDigestBuf, seq, b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], payload[:n]) + + if subtle.ConstantTimeCompare(localMAC, remoteMAC) != 1 || paddingGood != 255 { + return false, 0, alertBadRecordMAC + } + hc.inDigestBuf = localMAC + } + hc.incSeq(false) + + return true, recordHeaderLen + explicitIVLen, 0 +} + +// padToBlockSize calculates the needed padding block, if any, for a payload. +// On exit, prefix aliases payload and extends to the end of the last full +// block of payload. finalBlock is a fresh slice which contains the contents of +// any suffix of payload as well as the needed padding to make finalBlock a +// full block. +func padToBlockSize(payload []byte, blockSize int, config *Config) (prefix, finalBlock []byte) { + overrun := len(payload) % blockSize + prefix = payload[:len(payload)-overrun] + + paddingLen := blockSize - overrun + finalSize := blockSize + if config.Bugs.MaxPadding { + for paddingLen+blockSize <= 256 { + paddingLen += blockSize + } + finalSize = 256 + } + finalBlock = make([]byte, finalSize) + for i := range finalBlock { + finalBlock[i] = byte(paddingLen - 1) + } + if config.Bugs.PaddingFirstByteBad || config.Bugs.PaddingFirstByteBadIf255 && paddingLen == 256 { + finalBlock[overrun] ^= 0xff + } + copy(finalBlock, payload[len(payload)-overrun:]) + return +} + +// encrypt encrypts and macs the data in b. +func (hc *halfConn) encrypt(b *block, explicitIVLen int) (bool, alert) { + recordHeaderLen := hc.recordHeaderLen() + + // mac + if hc.mac != nil { + mac := hc.mac.MAC(hc.outDigestBuf, hc.seq[0:], b.data[:3], b.data[recordHeaderLen-2:recordHeaderLen], b.data[recordHeaderLen+explicitIVLen:]) + + n := len(b.data) + b.resize(n + len(mac)) + copy(b.data[n:], mac) + hc.outDigestBuf = mac + } + + payload := b.data[recordHeaderLen:] + + // encrypt + if hc.cipher != nil { + switch c := hc.cipher.(type) { + case cipher.Stream: + c.XORKeyStream(payload, payload) + case cipher.AEAD: + payloadLen := len(b.data) - recordHeaderLen - explicitIVLen + b.resize(len(b.data) + c.Overhead()) + nonce := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + payload := b.data[recordHeaderLen+explicitIVLen:] + payload = payload[:payloadLen] + + var additionalData [13]byte + copy(additionalData[:], hc.seq[:]) + copy(additionalData[8:], b.data[:3]) + additionalData[11] = byte(payloadLen >> 8) + additionalData[12] = byte(payloadLen) + + c.Seal(payload[:0], nonce, payload, additionalData[:]) + case cbcMode: + blockSize := c.BlockSize() + if explicitIVLen > 0 { + c.SetIV(payload[:explicitIVLen]) + payload = payload[explicitIVLen:] + } + prefix, finalBlock := padToBlockSize(payload, blockSize, hc.config) + b.resize(recordHeaderLen + explicitIVLen + len(prefix) + len(finalBlock)) + c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen:], prefix) + c.CryptBlocks(b.data[recordHeaderLen+explicitIVLen+len(prefix):], finalBlock) + default: + panic("unknown cipher type") + } + } + + // update length to include MAC and any block padding needed. + n := len(b.data) - recordHeaderLen + b.data[recordHeaderLen-2] = byte(n >> 8) + b.data[recordHeaderLen-1] = byte(n) + hc.incSeq(true) + + return true, 0 +} + +// A block is a simple data buffer. +type block struct { + data []byte + off int // index for Read + link *block +} + +// resize resizes block to be n bytes, growing if necessary. +func (b *block) resize(n int) { + if n > cap(b.data) { + b.reserve(n) + } + b.data = b.data[0:n] +} + +// reserve makes sure that block contains a capacity of at least n bytes. +func (b *block) reserve(n int) { + if cap(b.data) >= n { + return + } + m := cap(b.data) + if m == 0 { + m = 1024 + } + for m < n { + m *= 2 + } + data := make([]byte, len(b.data), m) + copy(data, b.data) + b.data = data +} + +// readFromUntil reads from r into b until b contains at least n bytes +// or else returns an error. +func (b *block) readFromUntil(r io.Reader, n int) error { + // quick case + if len(b.data) >= n { + return nil + } + + // read until have enough. + b.reserve(n) + for { + m, err := r.Read(b.data[len(b.data):cap(b.data)]) + b.data = b.data[0 : len(b.data)+m] + if len(b.data) >= n { + // TODO(bradfitz,agl): slightly suspicious + // that we're throwing away r.Read's err here. + break + } + if err != nil { + return err + } + } + return nil +} + +func (b *block) Read(p []byte) (n int, err error) { + n = copy(p, b.data[b.off:]) + b.off += n + return +} + +// newBlock allocates a new block, from hc's free list if possible. +func (hc *halfConn) newBlock() *block { + b := hc.bfree + if b == nil { + return new(block) + } + hc.bfree = b.link + b.link = nil + b.resize(0) + return b +} + +// freeBlock returns a block to hc's free list. +// The protocol is such that each side only has a block or two on +// its free list at a time, so there's no need to worry about +// trimming the list, etc. +func (hc *halfConn) freeBlock(b *block) { + b.link = hc.bfree + hc.bfree = b +} + +// splitBlock splits a block after the first n bytes, +// returning a block with those n bytes and a +// block with the remainder. the latter may be nil. +func (hc *halfConn) splitBlock(b *block, n int) (*block, *block) { + if len(b.data) <= n { + return b, nil + } + bb := hc.newBlock() + bb.resize(len(b.data) - n) + copy(bb.data, b.data[n:]) + b.data = b.data[0:n] + return b, bb +} + +func (c *Conn) doReadRecord(want recordType) (recordType, *block, error) { + if c.isDTLS { + return c.dtlsDoReadRecord(want) + } + + recordHeaderLen := tlsRecordHeaderLen + + if c.rawInput == nil { + c.rawInput = c.in.newBlock() + } + b := c.rawInput + + // Read header, payload. + if err := b.readFromUntil(c.conn, recordHeaderLen); err != nil { + // RFC suggests that EOF without an alertCloseNotify is + // an error, but popular web sites seem to do this, + // so we can't make it an error. + // if err == io.EOF { + // err = io.ErrUnexpectedEOF + // } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return 0, nil, err + } + typ := recordType(b.data[0]) + + // No valid TLS record has a type of 0x80, however SSLv2 handshakes + // start with a uint16 length where the MSB is set and the first record + // is always < 256 bytes long. Therefore typ == 0x80 strongly suggests + // an SSLv2 client. + if want == recordTypeHandshake && typ == 0x80 { + c.sendAlert(alertProtocolVersion) + return 0, nil, c.in.setErrorLocked(errors.New("tls: unsupported SSLv2 handshake received")) + } + + vers := uint16(b.data[1])<<8 | uint16(b.data[2]) + n := int(b.data[3])<<8 | int(b.data[4]) + if c.haveVers { + if vers != c.vers { + c.sendAlert(alertProtocolVersion) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, c.vers)) + } + } else { + if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { + c.sendAlert(alertProtocolVersion) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: received record with version %x when expecting version %x", vers, expect)) + } + } + if n > maxCiphertext { + c.sendAlert(alertRecordOverflow) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: oversized record received with length %d", n)) + } + if !c.haveVers { + // First message, be extra suspicious: + // this might not be a TLS client. + // Bail out before reading a full 'body', if possible. + // The current max version is 3.1. + // If the version is >= 16.0, it's probably not real. + // Similarly, a clientHello message encodes in + // well under a kilobyte. If the length is >= 12 kB, + // it's probably not real. + if (typ != recordTypeAlert && typ != want) || vers >= 0x1000 || n >= 0x3000 { + c.sendAlert(alertUnexpectedMessage) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("tls: first record does not look like a TLS handshake")) + } + } + if err := b.readFromUntil(c.conn, recordHeaderLen+n); err != nil { + if err == io.EOF { + err = io.ErrUnexpectedEOF + } + if e, ok := err.(net.Error); !ok || !e.Temporary() { + c.in.setErrorLocked(err) + } + return 0, nil, err + } + + // Process message. + b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) + ok, off, err := c.in.decrypt(b) + if !ok { + c.in.setErrorLocked(c.sendAlert(err)) + } + b.off = off + return typ, b, nil +} + +// readRecord reads the next TLS record from the connection +// and updates the record layer state. +// c.in.Mutex <= L; c.input == nil. +func (c *Conn) readRecord(want recordType) error { + // Caller must be in sync with connection: + // handshake data if handshake not yet completed, + // else application data. + switch want { + default: + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: unknown record type requested")) + case recordTypeHandshake, recordTypeChangeCipherSpec: + if c.handshakeComplete { + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: handshake or ChangeCipherSpec requested after handshake complete")) + } + case recordTypeApplicationData: + if !c.handshakeComplete && !c.config.Bugs.ExpectFalseStart { + c.sendAlert(alertInternalError) + return c.in.setErrorLocked(errors.New("tls: application data record requested before handshake complete")) + } + } + +Again: + typ, b, err := c.doReadRecord(want) + if err != nil { + return err + } + data := b.data[b.off:] + if len(data) > maxPlaintext { + err := c.sendAlert(alertRecordOverflow) + c.in.freeBlock(b) + return c.in.setErrorLocked(err) + } + + switch typ { + default: + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + + case recordTypeAlert: + if len(data) != 2 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + if alert(data[1]) == alertCloseNotify { + c.in.setErrorLocked(io.EOF) + break + } + switch data[0] { + case alertLevelWarning: + // drop on the floor + c.in.freeBlock(b) + goto Again + case alertLevelError: + c.in.setErrorLocked(&net.OpError{Op: "remote error", Err: alert(data[1])}) + default: + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + case recordTypeChangeCipherSpec: + if typ != want || len(data) != 1 || data[0] != 1 { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + err := c.in.changeCipherSpec(c.config) + if err != nil { + c.in.setErrorLocked(c.sendAlert(err.(alert))) + } + + case recordTypeApplicationData: + if typ != want { + c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + break + } + c.input = b + b = nil + + case recordTypeHandshake: + // TODO(rsc): Should at least pick off connection close. + if typ != want { + // A client might need to process a HelloRequest from + // the server, thus receiving a handshake message when + // application data is expected is ok. Moreover, a DTLS + // peer who sends Finished second may retransmit the + // final leg. BoringSSL retrainsmits on an internal + // timer, so this may also occur in test code. + if !c.isClient && !c.isDTLS { + return c.in.setErrorLocked(c.sendAlert(alertNoRenegotiation)) + } + } + c.hand.Write(data) + } + + if b != nil { + c.in.freeBlock(b) + } + return c.in.err +} + +// sendAlert sends a TLS alert message. +// c.out.Mutex <= L. +func (c *Conn) sendAlertLocked(err alert) error { + switch err { + case alertNoRenegotiation, alertCloseNotify: + c.tmp[0] = alertLevelWarning + default: + c.tmp[0] = alertLevelError + } + c.tmp[1] = byte(err) + if c.config.Bugs.FragmentAlert { + c.writeRecord(recordTypeAlert, c.tmp[0:1]) + c.writeRecord(recordTypeAlert, c.tmp[1:2]) + } else { + c.writeRecord(recordTypeAlert, c.tmp[0:2]) + } + // closeNotify is a special case in that it isn't an error: + if err != alertCloseNotify { + return c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + } + return nil +} + +// sendAlert sends a TLS alert message. +// L < c.out.Mutex. +func (c *Conn) sendAlert(err alert) error { + c.out.Lock() + defer c.out.Unlock() + return c.sendAlertLocked(err) +} + +// writeV2Record writes a record for a V2ClientHello. +func (c *Conn) writeV2Record(data []byte) (n int, err error) { + record := make([]byte, 2+len(data)) + record[0] = uint8(len(data)>>8) | 0x80 + record[1] = uint8(len(data)) + copy(record[2:], data) + return c.conn.Write(record) +} + +// writeRecord writes a TLS record with the given type and payload +// to the connection and updates the record layer state. +// c.out.Mutex <= L. +func (c *Conn) writeRecord(typ recordType, data []byte) (n int, err error) { + if c.isDTLS { + return c.dtlsWriteRecord(typ, data) + } + + recordHeaderLen := tlsRecordHeaderLen + b := c.out.newBlock() + first := true + isClientHello := typ == recordTypeHandshake && len(data) > 0 && data[0] == typeClientHello + for len(data) > 0 { + m := len(data) + if m > maxPlaintext { + m = maxPlaintext + } + if typ == recordTypeHandshake && c.config.Bugs.MaxHandshakeRecordLength > 0 && m > c.config.Bugs.MaxHandshakeRecordLength { + m = c.config.Bugs.MaxHandshakeRecordLength + // By default, do not fragment the client_version or + // server_version, which are located in the first 6 + // bytes. + if first && isClientHello && !c.config.Bugs.FragmentClientVersion && m < 6 { + m = 6 + } + } + explicitIVLen := 0 + explicitIVIsSeq := false + first = false + + var cbc cbcMode + if c.out.version >= VersionTLS11 { + var ok bool + if cbc, ok = c.out.cipher.(cbcMode); ok { + explicitIVLen = cbc.BlockSize() + } + } + if explicitIVLen == 0 { + if _, ok := c.out.cipher.(cipher.AEAD); ok { + explicitIVLen = 8 + // The AES-GCM construction in TLS has an + // explicit nonce so that the nonce can be + // random. However, the nonce is only 8 bytes + // which is too small for a secure, random + // nonce. Therefore we use the sequence number + // as the nonce. + explicitIVIsSeq = true + } + } + b.resize(recordHeaderLen + explicitIVLen + m) + b.data[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } + b.data[1] = byte(vers >> 8) + b.data[2] = byte(vers) + b.data[3] = byte(m >> 8) + b.data[4] = byte(m) + if explicitIVLen > 0 { + explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + if explicitIVIsSeq { + copy(explicitIV, c.out.seq[:]) + } else { + if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { + break + } + } + } + copy(b.data[recordHeaderLen+explicitIVLen:], data) + c.out.encrypt(b, explicitIVLen) + _, err = c.conn.Write(b.data) + if err != nil { + break + } + n += m + data = data[m:] + } + c.out.freeBlock(b) + + if typ == recordTypeChangeCipherSpec { + err = c.out.changeCipherSpec(c.config) + if err != nil { + // Cannot call sendAlert directly, + // because we already hold c.out.Mutex. + c.tmp[0] = alertLevelError + c.tmp[1] = byte(err.(alert)) + c.writeRecord(recordTypeAlert, c.tmp[0:2]) + return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + } + } + return +} + +func (c *Conn) doReadHandshake() ([]byte, error) { + if c.isDTLS { + return c.dtlsDoReadHandshake() + } + + for c.hand.Len() < 4 { + if err := c.in.err; err != nil { + return nil, err + } + if err := c.readRecord(recordTypeHandshake); err != nil { + return nil, err + } + } + + data := c.hand.Bytes() + n := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if n > maxHandshake { + return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + for c.hand.Len() < 4+n { + if err := c.in.err; err != nil { + return nil, err + } + if err := c.readRecord(recordTypeHandshake); err != nil { + return nil, err + } + } + return c.hand.Next(4 + n), nil +} + +// readHandshake reads the next handshake message from +// the record layer. +// c.in.Mutex < L; c.out.Mutex < L. +func (c *Conn) readHandshake() (interface{}, error) { + data, err := c.doReadHandshake() + if err != nil { + return nil, err + } + + var m handshakeMessage + switch data[0] { + case typeHelloRequest: + m = new(helloRequestMsg) + case typeClientHello: + m = &clientHelloMsg{ + isDTLS: c.isDTLS, + } + case typeServerHello: + m = &serverHelloMsg{ + isDTLS: c.isDTLS, + } + case typeNewSessionTicket: + m = new(newSessionTicketMsg) + case typeCertificate: + m = new(certificateMsg) + case typeCertificateRequest: + m = &certificateRequestMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + case typeCertificateStatus: + m = new(certificateStatusMsg) + case typeServerKeyExchange: + m = new(serverKeyExchangeMsg) + case typeServerHelloDone: + m = new(serverHelloDoneMsg) + case typeClientKeyExchange: + m = new(clientKeyExchangeMsg) + case typeCertificateVerify: + m = &certificateVerifyMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + case typeNextProtocol: + m = new(nextProtoMsg) + case typeFinished: + m = new(finishedMsg) + case typeHelloVerifyRequest: + m = new(helloVerifyRequestMsg) + case typeEncryptedExtensions: + m = new(encryptedExtensionsMsg) + default: + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + + // The handshake message unmarshallers + // expect to be able to keep references to data, + // so pass in a fresh copy that won't be overwritten. + data = append([]byte(nil), data...) + + if !m.unmarshal(data) { + return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage)) + } + return m, nil +} + +// Write writes data to the connection. +func (c *Conn) Write(b []byte) (int, error) { + if err := c.Handshake(); err != nil { + return 0, err + } + + c.out.Lock() + defer c.out.Unlock() + + if err := c.out.err; err != nil { + return 0, err + } + + if !c.handshakeComplete { + return 0, alertInternalError + } + + if c.config.Bugs.SendSpuriousAlert { + c.sendAlertLocked(alertRecordOverflow) + } + + // SSL 3.0 and TLS 1.0 are susceptible to a chosen-plaintext + // attack when using block mode ciphers due to predictable IVs. + // This can be prevented by splitting each Application Data + // record into two records, effectively randomizing the IV. + // + // http://www.openssl.org/~bodo/tls-cbc.txt + // https://bugzilla.mozilla.org/show_bug.cgi?id=665814 + // http://www.imperialviolet.org/2012/01/15/beastfollowup.html + + var m int + if len(b) > 1 && c.vers <= VersionTLS10 && !c.isDTLS { + if _, ok := c.out.cipher.(cipher.BlockMode); ok { + n, err := c.writeRecord(recordTypeApplicationData, b[:1]) + if err != nil { + return n, c.out.setErrorLocked(err) + } + m, b = 1, b[1:] + } + } + + n, err := c.writeRecord(recordTypeApplicationData, b) + return n + m, c.out.setErrorLocked(err) +} + +func (c *Conn) handleRenegotiation() error { + c.handshakeComplete = false + if !c.isClient { + panic("renegotiation should only happen for a client") + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + _, ok := msg.(*helloRequestMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return alertUnexpectedMessage + } + + return c.Handshake() +} + +func (c *Conn) Renegotiate() error { + if !c.isClient { + helloReq := new(helloRequestMsg) + c.writeRecord(recordTypeHandshake, helloReq.marshal()) + } + + c.handshakeComplete = false + return c.Handshake() +} + +// Read can be made to time out and return a net.Error with Timeout() == true +// after a fixed time limit; see SetDeadline and SetReadDeadline. +func (c *Conn) Read(b []byte) (n int, err error) { + if err = c.Handshake(); err != nil { + return + } + + c.in.Lock() + defer c.in.Unlock() + + // Some OpenSSL servers send empty records in order to randomize the + // CBC IV. So this loop ignores a limited number of empty records. + const maxConsecutiveEmptyRecords = 100 + for emptyRecordCount := 0; emptyRecordCount <= maxConsecutiveEmptyRecords; emptyRecordCount++ { + for c.input == nil && c.in.err == nil { + if err := c.readRecord(recordTypeApplicationData); err != nil { + // Soft error, like EAGAIN + return 0, err + } + if c.hand.Len() > 0 && !c.isDTLS { + // We received handshake bytes, indicating the + // start of a renegotiation or a DTLS retransmit. + if err := c.handleRenegotiation(); err != nil { + return 0, err + } + continue + } + } + if err := c.in.err; err != nil { + return 0, err + } + + n, err = c.input.Read(b) + if c.input.off >= len(c.input.data) || c.isDTLS { + c.in.freeBlock(c.input) + c.input = nil + } + + // If a close-notify alert is waiting, read it so that + // we can return (n, EOF) instead of (n, nil), to signal + // to the HTTP response reading goroutine that the + // connection is now closed. This eliminates a race + // where the HTTP response reading goroutine would + // otherwise not observe the EOF until its next read, + // by which time a client goroutine might have already + // tried to reuse the HTTP connection for a new + // request. + // See https://codereview.appspot.com/76400046 + // and http://golang.org/issue/3514 + if ri := c.rawInput; ri != nil && + n != 0 && err == nil && + c.input == nil && len(ri.data) > 0 && recordType(ri.data[0]) == recordTypeAlert { + if recErr := c.readRecord(recordTypeApplicationData); recErr != nil { + err = recErr // will be io.EOF on closeNotify + } + } + + if n != 0 || err != nil { + return n, err + } + } + + return 0, io.ErrNoProgress +} + +// Close closes the connection. +func (c *Conn) Close() error { + var alertErr error + + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if c.handshakeComplete { + alertErr = c.sendAlert(alertCloseNotify) + } + + if err := c.conn.Close(); err != nil { + return err + } + return alertErr +} + +// Handshake runs the client or server handshake +// protocol if it has not yet been run. +// Most uses of this package need not call Handshake +// explicitly: the first Read or Write will call it automatically. +func (c *Conn) Handshake() error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if err := c.handshakeErr; err != nil { + return err + } + if c.handshakeComplete { + return nil + } + + if c.isClient { + c.handshakeErr = c.clientHandshake() + } else { + c.handshakeErr = c.serverHandshake() + } + return c.handshakeErr +} + +// ConnectionState returns basic TLS details about the connection. +func (c *Conn) ConnectionState() ConnectionState { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + var state ConnectionState + state.HandshakeComplete = c.handshakeComplete + if c.handshakeComplete { + state.Version = c.vers + state.NegotiatedProtocol = c.clientProtocol + state.DidResume = c.didResume + state.NegotiatedProtocolIsMutual = !c.clientProtocolFallback + state.NegotiatedProtocolFromALPN = c.usedALPN + state.CipherSuite = c.cipherSuite + state.PeerCertificates = c.peerCertificates + state.VerifiedChains = c.verifiedChains + state.ServerName = c.serverName + state.ChannelID = c.channelID + state.SRTPProtectionProfile = c.srtpProtectionProfile + } + + return state +} + +// OCSPResponse returns the stapled OCSP response from the TLS server, if +// any. (Only valid for client connections.) +func (c *Conn) OCSPResponse() []byte { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + + return c.ocspResponse +} + +// VerifyHostname checks that the peer certificate chain is valid for +// connecting to host. If so, it returns nil; if not, it returns an error +// describing the problem. +func (c *Conn) VerifyHostname(host string) error { + c.handshakeMutex.Lock() + defer c.handshakeMutex.Unlock() + if !c.isClient { + return errors.New("tls: VerifyHostname called on TLS server connection") + } + if !c.handshakeComplete { + return errors.New("tls: handshake has not yet been performed") + } + return c.peerCertificates[0].VerifyHostname(host) +} diff --git a/src/ssl/test/runner/dtls.go b/src/ssl/test/runner/dtls.go new file mode 100644 index 0000000..a395980 --- /dev/null +++ b/src/ssl/test/runner/dtls.go @@ -0,0 +1,342 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// DTLS implementation. +// +// NOTE: This is a not even a remotely production-quality DTLS +// implementation. It is the bare minimum necessary to be able to +// achieve coverage on BoringSSL's implementation. Of note is that +// this implementation assumes the underlying net.PacketConn is not +// only reliable but also ordered. BoringSSL will be expected to deal +// with simulated loss, but there is no point in forcing the test +// driver to. + +package main + +import ( + "bytes" + "crypto/cipher" + "errors" + "fmt" + "io" + "net" +) + +func versionToWire(vers uint16, isDTLS bool) uint16 { + if isDTLS { + return ^(vers - 0x0201) + } + return vers +} + +func wireToVersion(vers uint16, isDTLS bool) uint16 { + if isDTLS { + return ^vers + 0x0201 + } + return vers +} + +func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { +Again: + recordHeaderLen := dtlsRecordHeaderLen + + if c.rawInput == nil { + c.rawInput = c.in.newBlock() + } + b := c.rawInput + + // Read a new packet only if the current one is empty. + if len(b.data) == 0 { + // Pick some absurdly large buffer size. + b.resize(maxCiphertext + recordHeaderLen) + n, err := c.conn.Read(c.rawInput.data) + if err != nil { + return 0, nil, err + } + if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength { + return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length") + } + c.rawInput.resize(n) + } + + // Read out one record. + // + // A real DTLS implementation should be tolerant of errors, + // but this is test code. We should not be tolerant of our + // peer sending garbage. + if len(b.data) < recordHeaderLen { + return 0, nil, errors.New("dtls: failed to read record header") + } + typ := recordType(b.data[0]) + vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS) + if c.haveVers { + if vers != c.vers { + c.sendAlert(alertProtocolVersion) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers)) + } + } else { + if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { + c.sendAlert(alertProtocolVersion) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect)) + } + } + seq := b.data[3:11] + if !bytes.Equal(seq[:2], c.in.seq[:2]) { + // If the epoch didn't match, silently drop the record. + // BoringSSL retransmits on an internal timer, so it may flakily + // revisit the previous epoch if retransmiting ChangeCipherSpec + // and Finished. + goto Again + } + // For test purposes, we assume a reliable channel. Require + // that the explicit sequence number matches the incrementing + // one we maintain. A real implementation would maintain a + // replay window and such. + if !bytes.Equal(seq, c.in.seq[:]) { + c.sendAlert(alertIllegalParameter) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number")) + } + n := int(b.data[11])<<8 | int(b.data[12]) + if n > maxCiphertext || len(b.data) < recordHeaderLen+n { + c.sendAlert(alertRecordOverflow) + return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n)) + } + + // Process message. + b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) + ok, off, err := c.in.decrypt(b) + if !ok { + c.in.setErrorLocked(c.sendAlert(err)) + } + b.off = off + return typ, b, nil +} + +func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) { + recordHeaderLen := dtlsRecordHeaderLen + maxLen := c.config.Bugs.MaxHandshakeRecordLength + if maxLen <= 0 { + maxLen = 1024 + } + + b := c.out.newBlock() + + var header []byte + if typ == recordTypeHandshake { + // Handshake messages have to be modified to include + // fragment offset and length and with the header + // replicated. Save the header here. + // + // TODO(davidben): This assumes that data contains + // exactly one handshake message. This is incompatible + // with FragmentAcrossChangeCipherSpec. (Which is + // unfortunate because OpenSSL's DTLS implementation + // will probably accept such fragmentation and could + // do with a fix + tests.) + if len(data) < 4 { + // This should not happen. + panic(data) + } + header = data[:4] + data = data[4:] + } + + firstRun := true + for firstRun || len(data) > 0 { + firstRun = false + m := len(data) + var fragment []byte + // Handshake messages get fragmented. Other records we + // pass-through as is. DTLS should be a packet + // interface. + if typ == recordTypeHandshake { + if m > maxLen { + m = maxLen + } + + // Standard handshake header. + fragment = make([]byte, 0, 12+m) + fragment = append(fragment, header...) + // message_seq + fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq)) + // fragment_offset + fragment = append(fragment, byte(n>>16), byte(n>>8), byte(n)) + // fragment_length + fragment = append(fragment, byte(m>>16), byte(m>>8), byte(m)) + fragment = append(fragment, data[:m]...) + } else { + fragment = data[:m] + } + + // Send the fragment. + explicitIVLen := 0 + explicitIVIsSeq := false + + if cbc, ok := c.out.cipher.(cbcMode); ok { + // Block cipher modes have an explicit IV. + explicitIVLen = cbc.BlockSize() + } else if _, ok := c.out.cipher.(cipher.AEAD); ok { + explicitIVLen = 8 + // The AES-GCM construction in TLS has an + // explicit nonce so that the nonce can be + // random. However, the nonce is only 8 bytes + // which is too small for a secure, random + // nonce. Therefore we use the sequence number + // as the nonce. + explicitIVIsSeq = true + } else if c.out.cipher != nil { + panic("Unknown cipher") + } + b.resize(recordHeaderLen + explicitIVLen + len(fragment)) + b.data[0] = byte(typ) + vers := c.vers + if vers == 0 { + // Some TLS servers fail if the record version is + // greater than TLS 1.0 for the initial ClientHello. + vers = VersionTLS10 + } + vers = versionToWire(vers, c.isDTLS) + b.data[1] = byte(vers >> 8) + b.data[2] = byte(vers) + // DTLS records include an explicit sequence number. + copy(b.data[3:11], c.out.seq[0:]) + b.data[11] = byte(len(fragment) >> 8) + b.data[12] = byte(len(fragment)) + if explicitIVLen > 0 { + explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] + if explicitIVIsSeq { + copy(explicitIV, c.out.seq[:]) + } else { + if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { + break + } + } + } + copy(b.data[recordHeaderLen+explicitIVLen:], fragment) + c.out.encrypt(b, explicitIVLen) + + // TODO(davidben): A real DTLS implementation needs to + // retransmit handshake messages. For testing + // purposes, we don't actually care. + _, err = c.conn.Write(b.data) + if err != nil { + break + } + n += m + data = data[m:] + } + c.out.freeBlock(b) + + // Increment the handshake sequence number for the next + // handshake message. + if typ == recordTypeHandshake { + c.sendHandshakeSeq++ + } + + if typ == recordTypeChangeCipherSpec { + err = c.out.changeCipherSpec(c.config) + if err != nil { + // Cannot call sendAlert directly, + // because we already hold c.out.Mutex. + c.tmp[0] = alertLevelError + c.tmp[1] = byte(err.(alert)) + c.writeRecord(recordTypeAlert, c.tmp[0:2]) + return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) + } + } + return +} + +func (c *Conn) dtlsDoReadHandshake() ([]byte, error) { + // Assemble a full handshake message. For test purposes, this + // implementation assumes fragments arrive in order, but tolerates + // retransmits. It may need to be cleverer if we ever test BoringSSL's + // retransmit behavior. + for len(c.handMsg) < 4+c.handMsgLen { + // Get a new handshake record if the previous has been + // exhausted. + if c.hand.Len() == 0 { + if err := c.in.err; err != nil { + return nil, err + } + if err := c.readRecord(recordTypeHandshake); err != nil { + return nil, err + } + } + + // Read the next fragment. It must fit entirely within + // the record. + if c.hand.Len() < 12 { + return nil, errors.New("dtls: bad handshake record") + } + header := c.hand.Next(12) + fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3]) + fragSeq := uint16(header[4])<<8 | uint16(header[5]) + fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8]) + fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11]) + + if c.hand.Len() < fragLen { + return nil, errors.New("dtls: fragment length too long") + } + fragment := c.hand.Next(fragLen) + + if fragSeq < c.recvHandshakeSeq { + // BoringSSL retransmits based on an internal timer, so + // it may flakily retransmit part of a handshake + // message. Ignore those fragments. + // + // TODO(davidben): Revise this if BoringSSL's retransmit + // logic is made more deterministic. + continue + } else if fragSeq > c.recvHandshakeSeq { + return nil, errors.New("dtls: handshake messages sent out of order") + } + + // Check that the length is consistent. + if c.handMsg == nil { + c.handMsgLen = fragN + if c.handMsgLen > maxHandshake { + return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) + } + // Start with the TLS handshake header, + // without the DTLS bits. + c.handMsg = append([]byte{}, header[:4]...) + } else if fragN != c.handMsgLen { + return nil, errors.New("dtls: bad handshake length") + } + + // Add the fragment to the pending message. + if 4+fragOff != len(c.handMsg) { + return nil, errors.New("dtls: bad fragment offset") + } + if fragOff+fragLen > c.handMsgLen { + return nil, errors.New("dtls: bad fragment length") + } + c.handMsg = append(c.handMsg, fragment...) + } + c.recvHandshakeSeq++ + ret := c.handMsg + c.handMsg, c.handMsgLen = nil, 0 + return ret, nil +} + +// DTLSServer returns a new DTLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must have +// at least one certificate. +func DTLSServer(conn net.Conn, config *Config) *Conn { + c := &Conn{config: config, isDTLS: true, conn: conn} + c.init() + return c +} + +// DTLSClient returns a new DTLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerHostname or +// InsecureSkipVerify in the config. +func DTLSClient(conn net.Conn, config *Config) *Conn { + c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn} + c.init() + return c +} diff --git a/src/ssl/test/runner/ecdsa_cert.pem b/src/ssl/test/runner/ecdsa_cert.pem new file mode 100644 index 0000000..50bcbf5 --- /dev/null +++ b/src/ssl/test/runner/ecdsa_cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIBzzCCAXagAwIBAgIJANlMBNpJfb/rMAkGByqGSM49BAEwRTELMAkGA1UEBhMC +QVUxEzARBgNVBAgMClNvbWUtU3RhdGUxITAfBgNVBAoMGEludGVybmV0IFdpZGdp +dHMgUHR5IEx0ZDAeFw0xNDA0MjMyMzIxNTdaFw0xNDA1MjMyMzIxNTdaMEUxCzAJ +BgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5l +dCBXaWRnaXRzIFB0eSBMdGQwWTATBgcqhkjOPQIBBggqhkjOPQMBBwNCAATmK2ni +v2Wfl74vHg2UikzVl2u3qR4NRvvdqakendy6WgHn1peoChj5w8SjHlbifINI2xYa +HPUdfvGULUvPciLBo1AwTjAdBgNVHQ4EFgQUq4TSrKuV8IJOFngHVVdf5CaNgtEw +HwYDVR0jBBgwFoAUq4TSrKuV8IJOFngHVVdf5CaNgtEwDAYDVR0TBAUwAwEB/zAJ +BgcqhkjOPQQBA0gAMEUCIQDyoDVeUTo2w4J5m+4nUIWOcAZ0lVfSKXQA9L4Vh13E +BwIgfB55FGohg/B6dGh5XxSZmmi08cueFV7mHzJSYV51yRQ= +-----END CERTIFICATE----- diff --git a/src/ssl/test/runner/ecdsa_key.pem b/src/ssl/test/runner/ecdsa_key.pem new file mode 100644 index 0000000..b9116f0 --- /dev/null +++ b/src/ssl/test/runner/ecdsa_key.pem @@ -0,0 +1,8 @@ +-----BEGIN EC PARAMETERS----- +BggqhkjOPQMBBw== +-----END EC PARAMETERS----- +-----BEGIN EC PRIVATE KEY----- +MHcCAQEEIAcPCHJ61KBKnN1ZyU2JaHcItW/JXTB3DujRyc4Ki7RqoAoGCCqGSM49 +AwEHoUQDQgAE5itp4r9ln5e+Lx4NlIpM1Zdrt6keDUb73ampHp3culoB59aXqAoY ++cPEox5W4nyDSNsWGhz1HX7xlC1Lz3IiwQ== +-----END EC PRIVATE KEY----- diff --git a/src/ssl/test/runner/handshake_client.go b/src/ssl/test/runner/handshake_client.go new file mode 100644 index 0000000..f297fc1 --- /dev/null +++ b/src/ssl/test/runner/handshake_client.go @@ -0,0 +1,910 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "encoding/asn1" + "errors" + "fmt" + "io" + "math/big" + "net" + "strconv" +) + +type clientHandshakeState struct { + c *Conn + serverHello *serverHelloMsg + hello *clientHelloMsg + suite *cipherSuite + finishedHash finishedHash + masterSecret []byte + session *ClientSessionState +} + +func (c *Conn) clientHandshake() error { + if c.config == nil { + c.config = defaultConfig() + } + + if len(c.config.ServerName) == 0 && !c.config.InsecureSkipVerify { + return errors.New("tls: either ServerName or InsecureSkipVerify must be specified in the tls.Config") + } + + c.sendHandshakeSeq = 0 + c.recvHandshakeSeq = 0 + + nextProtosLength := 0 + for _, proto := range c.config.NextProtos { + if l := len(proto); l == 0 || l > 255 { + return errors.New("tls: invalid NextProtos value") + } else { + nextProtosLength += 1 + l + } + } + if nextProtosLength > 0xffff { + return errors.New("tls: NextProtos values too large") + } + + hello := &clientHelloMsg{ + isDTLS: c.isDTLS, + vers: c.config.maxVersion(), + compressionMethods: []uint8{compressionNone}, + random: make([]byte, 32), + ocspStapling: true, + serverName: c.config.ServerName, + supportedCurves: c.config.curvePreferences(), + supportedPoints: []uint8{pointFormatUncompressed}, + nextProtoNeg: len(c.config.NextProtos) > 0, + secureRenegotiation: []byte{}, + alpnProtocols: c.config.NextProtos, + duplicateExtension: c.config.Bugs.DuplicateExtension, + channelIDSupported: c.config.ChannelID != nil, + npnLast: c.config.Bugs.SwapNPNAndALPN, + extendedMasterSecret: c.config.maxVersion() >= VersionTLS10, + srtpProtectionProfiles: c.config.SRTPProtectionProfiles, + srtpMasterKeyIdentifier: c.config.Bugs.SRTPMasterKeyIdentifer, + } + + if c.config.Bugs.SendClientVersion != 0 { + hello.vers = c.config.Bugs.SendClientVersion + } + + if c.config.Bugs.NoExtendedMasterSecret { + hello.extendedMasterSecret = false + } + + if len(c.clientVerify) > 0 && !c.config.Bugs.EmptyRenegotiationInfo { + if c.config.Bugs.BadRenegotiationInfo { + hello.secureRenegotiation = append(hello.secureRenegotiation, c.clientVerify...) + hello.secureRenegotiation[0] ^= 0x80 + } else { + hello.secureRenegotiation = c.clientVerify + } + } + + if c.config.Bugs.NoRenegotiationInfo { + hello.secureRenegotiation = nil + } + + possibleCipherSuites := c.config.cipherSuites() + hello.cipherSuites = make([]uint16, 0, len(possibleCipherSuites)) + +NextCipherSuite: + for _, suiteId := range possibleCipherSuites { + for _, suite := range cipherSuites { + if suite.id != suiteId { + continue + } + // Don't advertise TLS 1.2-only cipher suites unless + // we're attempting TLS 1.2. + if hello.vers < VersionTLS12 && suite.flags&suiteTLS12 != 0 { + continue + } + // Don't advertise non-DTLS cipher suites on DTLS. + if c.isDTLS && suite.flags&suiteNoDTLS != 0 { + continue + } + hello.cipherSuites = append(hello.cipherSuites, suiteId) + continue NextCipherSuite + } + } + + if c.config.Bugs.SendFallbackSCSV { + hello.cipherSuites = append(hello.cipherSuites, fallbackSCSV) + } + + _, err := io.ReadFull(c.config.rand(), hello.random) + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: short read from Rand: " + err.Error()) + } + + if hello.vers >= VersionTLS12 && !c.config.Bugs.NoSignatureAndHashes { + hello.signatureAndHashes = c.config.signatureAndHashesForClient() + } + + var session *ClientSessionState + var cacheKey string + sessionCache := c.config.ClientSessionCache + + if sessionCache != nil { + hello.ticketSupported = !c.config.SessionTicketsDisabled + + // Try to resume a previously negotiated TLS session, if + // available. + cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config) + candidateSession, ok := sessionCache.Get(cacheKey) + if ok { + ticketOk := !c.config.SessionTicketsDisabled || candidateSession.sessionTicket == nil + + // Check that the ciphersuite/version used for the + // previous session are still valid. + cipherSuiteOk := false + for _, id := range hello.cipherSuites { + if id == candidateSession.cipherSuite { + cipherSuiteOk = true + break + } + } + + versOk := candidateSession.vers >= c.config.minVersion() && + candidateSession.vers <= c.config.maxVersion() + if ticketOk && versOk && cipherSuiteOk { + session = candidateSession + } + } + } + + if session != nil { + if session.sessionTicket != nil { + hello.sessionTicket = session.sessionTicket + if c.config.Bugs.CorruptTicket { + hello.sessionTicket = make([]byte, len(session.sessionTicket)) + copy(hello.sessionTicket, session.sessionTicket) + if len(hello.sessionTicket) > 0 { + offset := 40 + if offset > len(hello.sessionTicket) { + offset = len(hello.sessionTicket) - 1 + } + hello.sessionTicket[offset] ^= 0x40 + } + } + // A random session ID is used to detect when the + // server accepted the ticket and is resuming a session + // (see RFC 5077). + sessionIdLen := 16 + if c.config.Bugs.OversizedSessionId { + sessionIdLen = 33 + } + hello.sessionId = make([]byte, sessionIdLen) + if _, err := io.ReadFull(c.config.rand(), hello.sessionId); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: short read from Rand: " + err.Error()) + } + } else { + hello.sessionId = session.sessionId + } + } + + var helloBytes []byte + if c.config.Bugs.SendV2ClientHello { + // Test that the peer left-pads random. + hello.random[0] = 0 + v2Hello := &v2ClientHelloMsg{ + vers: hello.vers, + cipherSuites: hello.cipherSuites, + // No session resumption for V2ClientHello. + sessionId: nil, + challenge: hello.random[1:], + } + helloBytes = v2Hello.marshal() + c.writeV2Record(helloBytes) + } else { + helloBytes = hello.marshal() + c.writeRecord(recordTypeHandshake, helloBytes) + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + if c.isDTLS { + helloVerifyRequest, ok := msg.(*helloVerifyRequestMsg) + if ok { + if helloVerifyRequest.vers != VersionTLS10 { + // Per RFC 6347, the version field in + // HelloVerifyRequest SHOULD be always DTLS + // 1.0. Enforce this for testing purposes. + return errors.New("dtls: bad HelloVerifyRequest version") + } + + hello.raw = nil + hello.cookie = helloVerifyRequest.cookie + helloBytes = hello.marshal() + c.writeRecord(recordTypeHandshake, helloBytes) + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + } + + serverHello, ok := msg.(*serverHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverHello, msg) + } + + c.vers, ok = c.config.mutualVersion(serverHello.vers) + if !ok { + c.sendAlert(alertProtocolVersion) + return fmt.Errorf("tls: server selected unsupported protocol version %x", serverHello.vers) + } + c.haveVers = true + + suite := mutualCipherSuite(c.config.cipherSuites(), serverHello.cipherSuite) + if suite == nil { + c.sendAlert(alertHandshakeFailure) + return fmt.Errorf("tls: server selected an unsupported cipher suite") + } + + if len(c.clientVerify) > 0 && !c.config.Bugs.NoRenegotiationInfo { + var expectedRenegInfo []byte + expectedRenegInfo = append(expectedRenegInfo, c.clientVerify...) + expectedRenegInfo = append(expectedRenegInfo, c.serverVerify...) + if !bytes.Equal(serverHello.secureRenegotiation, expectedRenegInfo) { + c.sendAlert(alertHandshakeFailure) + return fmt.Errorf("tls: renegotiation mismatch") + } + } + + hs := &clientHandshakeState{ + c: c, + serverHello: serverHello, + hello: hello, + suite: suite, + finishedHash: newFinishedHash(c.vers, suite), + session: session, + } + + hs.writeHash(helloBytes, hs.c.sendHandshakeSeq-1) + hs.writeServerHash(hs.serverHello.marshal()) + + if c.config.Bugs.EarlyChangeCipherSpec > 0 { + hs.establishKeys() + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + } + + isResume, err := hs.processServerHello() + if err != nil { + return err + } + + if isResume { + if c.config.Bugs.EarlyChangeCipherSpec == 0 { + if err := hs.establishKeys(); err != nil { + return err + } + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(); err != nil { + return err + } + if err := hs.sendFinished(isResume); err != nil { + return err + } + } else { + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.sendFinished(isResume); err != nil { + return err + } + if err := hs.readSessionTicket(); err != nil { + return err + } + if err := hs.readFinished(); err != nil { + return err + } + } + + if sessionCache != nil && hs.session != nil && session != hs.session { + sessionCache.Put(cacheKey, hs.session) + } + + c.didResume = isResume + c.handshakeComplete = true + c.cipherSuite = suite.id + return nil +} + +func (hs *clientHandshakeState) doFullHandshake() error { + c := hs.c + + var leaf *x509.Certificate + if hs.suite.flags&suitePSK == 0 { + msg, err := c.readHandshake() + if err != nil { + return err + } + + certMsg, ok := msg.(*certificateMsg) + if !ok || len(certMsg.certificates) == 0 { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.writeServerHash(certMsg.marshal()) + + certs := make([]*x509.Certificate, len(certMsg.certificates)) + for i, asn1Data := range certMsg.certificates { + cert, err := x509.ParseCertificate(asn1Data) + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("tls: failed to parse certificate from server: " + err.Error()) + } + certs[i] = cert + } + leaf = certs[0] + + if !c.config.InsecureSkipVerify { + opts := x509.VerifyOptions{ + Roots: c.config.RootCAs, + CurrentTime: c.config.time(), + DNSName: c.config.ServerName, + Intermediates: x509.NewCertPool(), + } + + for i, cert := range certs { + if i == 0 { + continue + } + opts.Intermediates.AddCert(cert) + } + c.verifiedChains, err = leaf.Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return err + } + } + + switch leaf.PublicKey.(type) { + case *rsa.PublicKey, *ecdsa.PublicKey: + break + default: + c.sendAlert(alertUnsupportedCertificate) + return fmt.Errorf("tls: server's certificate contains an unsupported type of public key: %T", leaf.PublicKey) + } + + c.peerCertificates = certs + } + + if hs.serverHello.ocspStapling { + msg, err := c.readHandshake() + if err != nil { + return err + } + cs, ok := msg.(*certificateStatusMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(cs, msg) + } + hs.writeServerHash(cs.marshal()) + + if cs.statusType == statusTypeOCSP { + c.ocspResponse = cs.response + } + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + + keyAgreement := hs.suite.ka(c.vers) + + skx, ok := msg.(*serverKeyExchangeMsg) + if ok { + hs.writeServerHash(skx.marshal()) + err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, leaf, skx) + if err != nil { + c.sendAlert(alertUnexpectedMessage) + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + var chainToSend *Certificate + var certRequested bool + certReq, ok := msg.(*certificateRequestMsg) + if ok { + certRequested = true + + // RFC 4346 on the certificateAuthorities field: + // A list of the distinguished names of acceptable certificate + // authorities. These distinguished names may specify a desired + // distinguished name for a root CA or for a subordinate CA; + // thus, this message can be used to describe both known roots + // and a desired authorization space. If the + // certificate_authorities list is empty then the client MAY + // send any certificate of the appropriate + // ClientCertificateType, unless there is some external + // arrangement to the contrary. + + hs.writeServerHash(certReq.marshal()) + + var rsaAvail, ecdsaAvail bool + for _, certType := range certReq.certificateTypes { + switch certType { + case CertTypeRSASign: + rsaAvail = true + case CertTypeECDSASign: + ecdsaAvail = true + } + } + + // We need to search our list of client certs for one + // where SignatureAlgorithm is RSA and the Issuer is in + // certReq.certificateAuthorities + findCert: + for i, chain := range c.config.Certificates { + if !rsaAvail && !ecdsaAvail { + continue + } + + for j, cert := range chain.Certificate { + x509Cert := chain.Leaf + // parse the certificate if this isn't the leaf + // node, or if chain.Leaf was nil + if j != 0 || x509Cert == nil { + if x509Cert, err = x509.ParseCertificate(cert); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to parse client certificate #" + strconv.Itoa(i) + ": " + err.Error()) + } + } + + switch { + case rsaAvail && x509Cert.PublicKeyAlgorithm == x509.RSA: + case ecdsaAvail && x509Cert.PublicKeyAlgorithm == x509.ECDSA: + default: + continue findCert + } + + if len(certReq.certificateAuthorities) == 0 { + // they gave us an empty list, so just take the + // first RSA cert from c.config.Certificates + chainToSend = &chain + break findCert + } + + for _, ca := range certReq.certificateAuthorities { + if bytes.Equal(x509Cert.RawIssuer, ca) { + chainToSend = &chain + break findCert + } + } + } + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + shd, ok := msg.(*serverHelloDoneMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(shd, msg) + } + hs.writeServerHash(shd.marshal()) + + // If the server requested a certificate then we have to send a + // Certificate message, even if it's empty because we don't have a + // certificate to send. + if certRequested { + certMsg := new(certificateMsg) + if chainToSend != nil { + certMsg.certificates = chainToSend.Certificate + } + hs.writeClientHash(certMsg.marshal()) + c.writeRecord(recordTypeHandshake, certMsg.marshal()) + } + + preMasterSecret, ckx, err := keyAgreement.generateClientKeyExchange(c.config, hs.hello, leaf) + if err != nil { + c.sendAlert(alertInternalError) + return err + } + if ckx != nil { + if c.config.Bugs.EarlyChangeCipherSpec < 2 { + hs.writeClientHash(ckx.marshal()) + } + c.writeRecord(recordTypeHandshake, ckx.marshal()) + } + + if hs.serverHello.extendedMasterSecret && c.vers >= VersionTLS10 { + hs.masterSecret = extendedMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.finishedHash) + c.extendedMasterSecret = true + } else { + if c.config.Bugs.RequireExtendedMasterSecret { + return errors.New("tls: extended master secret required but not supported by peer") + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.hello.random, hs.serverHello.random) + } + + if chainToSend != nil { + var signed []byte + certVerify := &certificateVerifyMsg{ + hasSignatureAndHash: c.vers >= VersionTLS12, + } + + switch key := c.config.Certificates[0].PrivateKey.(type) { + case *ecdsa.PrivateKey: + certVerify.signatureAndHash, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.signatureAndHashes, signatureECDSA) + if err != nil { + break + } + var digest []byte + digest, _, err = hs.finishedHash.hashForClientCertificate(certVerify.signatureAndHash, hs.masterSecret) + if err != nil { + break + } + var r, s *big.Int + r, s, err = ecdsa.Sign(c.config.rand(), key, digest) + if err == nil { + signed, err = asn1.Marshal(ecdsaSignature{r, s}) + } + case *rsa.PrivateKey: + certVerify.signatureAndHash, err = hs.finishedHash.selectClientCertSignatureAlgorithm(certReq.signatureAndHashes, signatureRSA) + if err != nil { + break + } + var digest []byte + var hashFunc crypto.Hash + digest, hashFunc, err = hs.finishedHash.hashForClientCertificate(certVerify.signatureAndHash, hs.masterSecret) + if err != nil { + break + } + signed, err = rsa.SignPKCS1v15(c.config.rand(), key, hashFunc, digest) + default: + err = errors.New("unknown private key type") + } + if err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: failed to sign handshake with client certificate: " + err.Error()) + } + certVerify.signature = signed + + hs.writeClientHash(certVerify.marshal()) + c.writeRecord(recordTypeHandshake, certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *clientHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.hello.random, hs.serverHello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + if hs.suite.cipher != nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, false /* not for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, true /* for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, serverCipher, serverHash) + c.out.prepareCipherSpec(c.vers, clientCipher, clientHash) + return nil +} + +func (hs *clientHandshakeState) serverResumedSession() bool { + // If the server responded with the same sessionId then it means the + // sessionTicket is being used to resume a TLS session. + return hs.session != nil && hs.hello.sessionId != nil && + bytes.Equal(hs.serverHello.sessionId, hs.hello.sessionId) +} + +func (hs *clientHandshakeState) processServerHello() (bool, error) { + c := hs.c + + if hs.serverHello.compressionMethod != compressionNone { + c.sendAlert(alertUnexpectedMessage) + return false, errors.New("tls: server selected unsupported compression format") + } + + clientDidNPN := hs.hello.nextProtoNeg + clientDidALPN := len(hs.hello.alpnProtocols) > 0 + serverHasNPN := hs.serverHello.nextProtoNeg + serverHasALPN := len(hs.serverHello.alpnProtocol) > 0 + + if !clientDidNPN && serverHasNPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("server advertised unrequested NPN extension") + } + + if !clientDidALPN && serverHasALPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("server advertised unrequested ALPN extension") + } + + if serverHasNPN && serverHasALPN { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("server advertised both NPN and ALPN extensions") + } + + if serverHasALPN { + c.clientProtocol = hs.serverHello.alpnProtocol + c.clientProtocolFallback = false + c.usedALPN = true + } + + if !hs.hello.channelIDSupported && hs.serverHello.channelIDRequested { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("server advertised unrequested Channel ID extension") + } + + if hs.serverHello.srtpProtectionProfile != 0 { + if hs.serverHello.srtpMasterKeyIdentifier != "" { + return false, errors.New("tls: server selected SRTP MKI value") + } + + found := false + for _, p := range c.config.SRTPProtectionProfiles { + if p == hs.serverHello.srtpProtectionProfile { + found = true + break + } + } + if !found { + return false, errors.New("tls: server advertised unsupported SRTP profile") + } + + c.srtpProtectionProfile = hs.serverHello.srtpProtectionProfile + } + + if hs.serverResumedSession() { + // Restore masterSecret and peerCerts from previous state + hs.masterSecret = hs.session.masterSecret + c.peerCertificates = hs.session.serverCertificates + c.extendedMasterSecret = hs.session.extendedMasterSecret + hs.finishedHash.discardHandshakeBuffer() + return true, nil + } + return false, nil +} + +func (hs *clientHandshakeState) readFinished() error { + c := hs.c + + c.readRecord(recordTypeChangeCipherSpec) + if err := c.in.error(); err != nil { + return err + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + serverFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(serverFinished, msg) + } + + if c.config.Bugs.EarlyChangeCipherSpec == 0 { + verify := hs.finishedHash.serverSum(hs.masterSecret) + if len(verify) != len(serverFinished.verifyData) || + subtle.ConstantTimeCompare(verify, serverFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: server's Finished message was incorrect") + } + } + c.serverVerify = append(c.serverVerify[:0], serverFinished.verifyData...) + hs.writeServerHash(serverFinished.marshal()) + return nil +} + +func (hs *clientHandshakeState) readSessionTicket() error { + c := hs.c + + // Create a session with no server identifier. Either a + // session ID or session ticket will be attached. + session := &ClientSessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + handshakeHash: hs.finishedHash.server.Sum(nil), + serverCertificates: c.peerCertificates, + } + + if !hs.serverHello.ticketSupported { + if hs.session == nil && len(hs.serverHello.sessionId) > 0 { + session.sessionId = hs.serverHello.sessionId + hs.session = session + } + return nil + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + sessionTicketMsg, ok := msg.(*newSessionTicketMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(sessionTicketMsg, msg) + } + + session.sessionTicket = sessionTicketMsg.ticket + hs.session = session + + hs.writeServerHash(sessionTicketMsg.marshal()) + + return nil +} + +func (hs *clientHandshakeState) sendFinished(isResume bool) error { + c := hs.c + + var postCCSBytes []byte + seqno := hs.c.sendHandshakeSeq + if hs.serverHello.nextProtoNeg { + nextProto := new(nextProtoMsg) + proto, fallback := mutualProtocol(c.config.NextProtos, hs.serverHello.nextProtos) + nextProto.proto = proto + c.clientProtocol = proto + c.clientProtocolFallback = fallback + + nextProtoBytes := nextProto.marshal() + hs.writeHash(nextProtoBytes, seqno) + seqno++ + postCCSBytes = append(postCCSBytes, nextProtoBytes...) + } + + if hs.serverHello.channelIDRequested { + encryptedExtensions := new(encryptedExtensionsMsg) + if c.config.ChannelID.Curve != elliptic.P256() { + return fmt.Errorf("tls: Channel ID is not on P-256.") + } + var resumeHash []byte + if isResume { + resumeHash = hs.session.handshakeHash + } + r, s, err := ecdsa.Sign(c.config.rand(), c.config.ChannelID, hs.finishedHash.hashForChannelID(resumeHash)) + if err != nil { + return err + } + channelID := make([]byte, 128) + writeIntPadded(channelID[0:32], c.config.ChannelID.X) + writeIntPadded(channelID[32:64], c.config.ChannelID.Y) + writeIntPadded(channelID[64:96], r) + writeIntPadded(channelID[96:128], s) + encryptedExtensions.channelID = channelID + + c.channelID = &c.config.ChannelID.PublicKey + + encryptedExtensionsBytes := encryptedExtensions.marshal() + hs.writeHash(encryptedExtensionsBytes, seqno) + seqno++ + postCCSBytes = append(postCCSBytes, encryptedExtensionsBytes...) + } + + finished := new(finishedMsg) + if c.config.Bugs.EarlyChangeCipherSpec == 2 { + finished.verifyData = hs.finishedHash.clientSum(nil) + } else { + finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret) + } + c.clientVerify = append(c.clientVerify[:0], finished.verifyData...) + finishedBytes := finished.marshal() + hs.writeHash(finishedBytes, seqno) + postCCSBytes = append(postCCSBytes, finishedBytes...) + + if c.config.Bugs.FragmentAcrossChangeCipherSpec { + c.writeRecord(recordTypeHandshake, postCCSBytes[:5]) + postCCSBytes = postCCSBytes[5:] + } + + if !c.config.Bugs.SkipChangeCipherSpec && + c.config.Bugs.EarlyChangeCipherSpec == 0 { + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + } + + if c.config.Bugs.AppDataAfterChangeCipherSpec != nil { + c.writeRecord(recordTypeApplicationData, c.config.Bugs.AppDataAfterChangeCipherSpec) + } + + c.writeRecord(recordTypeHandshake, postCCSBytes) + return nil +} + +func (hs *clientHandshakeState) writeClientHash(msg []byte) { + // writeClientHash is called before writeRecord. + hs.writeHash(msg, hs.c.sendHandshakeSeq) +} + +func (hs *clientHandshakeState) writeServerHash(msg []byte) { + // writeServerHash is called after readHandshake. + hs.writeHash(msg, hs.c.recvHandshakeSeq-1) +} + +func (hs *clientHandshakeState) writeHash(msg []byte, seqno uint16) { + if hs.c.isDTLS { + // This is somewhat hacky. DTLS hashes a slightly different format. + // First, the TLS header. + hs.finishedHash.Write(msg[:4]) + // Then the sequence number and reassembled fragment offset (always 0). + hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0}) + // Then the reassembled fragment (always equal to the message length). + hs.finishedHash.Write(msg[1:4]) + // And then the message body. + hs.finishedHash.Write(msg[4:]) + } else { + hs.finishedHash.Write(msg) + } +} + +// clientSessionCacheKey returns a key used to cache sessionTickets that could +// be used to resume previously negotiated TLS sessions with a server. +func clientSessionCacheKey(serverAddr net.Addr, config *Config) string { + if len(config.ServerName) > 0 { + return config.ServerName + } + return serverAddr.String() +} + +// mutualProtocol finds the mutual Next Protocol Negotiation or ALPN protocol +// given list of possible protocols and a list of the preference order. The +// first list must not be empty. It returns the resulting protocol and flag +// indicating if the fallback case was reached. +func mutualProtocol(protos, preferenceProtos []string) (string, bool) { + for _, s := range preferenceProtos { + for _, c := range protos { + if s == c { + return s, false + } + } + } + + return protos[0], true +} + +// writeIntPadded writes x into b, padded up with leading zeros as +// needed. +func writeIntPadded(b []byte, x *big.Int) { + for i := range b { + b[i] = 0 + } + xb := x.Bytes() + copy(b[len(b)-len(xb):], xb) +} diff --git a/src/ssl/test/runner/handshake_messages.go b/src/ssl/test/runner/handshake_messages.go new file mode 100644 index 0000000..ce214fd --- /dev/null +++ b/src/ssl/test/runner/handshake_messages.go @@ -0,0 +1,1875 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import "bytes" + +type clientHelloMsg struct { + raw []byte + isDTLS bool + vers uint16 + random []byte + sessionId []byte + cookie []byte + cipherSuites []uint16 + compressionMethods []uint8 + nextProtoNeg bool + serverName string + ocspStapling bool + supportedCurves []CurveID + supportedPoints []uint8 + ticketSupported bool + sessionTicket []uint8 + signatureAndHashes []signatureAndHash + secureRenegotiation []byte + alpnProtocols []string + duplicateExtension bool + channelIDSupported bool + npnLast bool + extendedMasterSecret bool + srtpProtectionProfiles []uint16 + srtpMasterKeyIdentifier string + sctListSupported bool +} + +func (m *clientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*clientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.isDTLS == m1.isDTLS && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + bytes.Equal(m.cookie, m1.cookie) && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.compressionMethods, m1.compressionMethods) && + m.nextProtoNeg == m1.nextProtoNeg && + m.serverName == m1.serverName && + m.ocspStapling == m1.ocspStapling && + eqCurveIDs(m.supportedCurves, m1.supportedCurves) && + bytes.Equal(m.supportedPoints, m1.supportedPoints) && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.sessionTicket, m1.sessionTicket) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + (m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && + eqStrings(m.alpnProtocols, m1.alpnProtocols) && + m.duplicateExtension == m1.duplicateExtension && + m.channelIDSupported == m1.channelIDSupported && + m.npnLast == m1.npnLast && + m.extendedMasterSecret == m1.extendedMasterSecret && + eqUint16s(m.srtpProtectionProfiles, m1.srtpProtectionProfiles) && + m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier && + m.sctListSupported == m1.sctListSupported +} + +func (m *clientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 32 + 1 + len(m.sessionId) + 2 + len(m.cipherSuites)*2 + 1 + len(m.compressionMethods) + if m.isDTLS { + length += 1 + len(m.cookie) + } + numExtensions := 0 + extensionsLength := 0 + if m.nextProtoNeg { + numExtensions++ + } + if m.ocspStapling { + extensionsLength += 1 + 2 + 2 + numExtensions++ + } + if len(m.serverName) > 0 { + extensionsLength += 5 + len(m.serverName) + numExtensions++ + } + if len(m.supportedCurves) > 0 { + extensionsLength += 2 + 2*len(m.supportedCurves) + numExtensions++ + } + if len(m.supportedPoints) > 0 { + extensionsLength += 1 + len(m.supportedPoints) + numExtensions++ + } + if m.ticketSupported { + extensionsLength += len(m.sessionTicket) + numExtensions++ + } + if len(m.signatureAndHashes) > 0 { + extensionsLength += 2 + 2*len(m.signatureAndHashes) + numExtensions++ + } + if m.secureRenegotiation != nil { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if m.duplicateExtension { + numExtensions += 2 + } + if m.channelIDSupported { + numExtensions++ + } + if len(m.alpnProtocols) > 0 { + extensionsLength += 2 + for _, s := range m.alpnProtocols { + if l := len(s); l == 0 || l > 255 { + panic("invalid ALPN protocol") + } + extensionsLength++ + extensionsLength += len(s) + } + numExtensions++ + } + if m.extendedMasterSecret { + numExtensions++ + } + if len(m.srtpProtectionProfiles) > 0 { + extensionsLength += 2 + 2*len(m.srtpProtectionProfiles) + extensionsLength += 1 + len(m.srtpMasterKeyIdentifier) + numExtensions++ + } + if m.sctListSupported { + numExtensions++ + } + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeClientHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + vers := versionToWire(m.vers, m.isDTLS) + x[4] = uint8(vers >> 8) + x[5] = uint8(vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + y := x[39+len(m.sessionId):] + if m.isDTLS { + y[0] = uint8(len(m.cookie)) + copy(y[1:], m.cookie) + y = y[1+len(m.cookie):] + } + y[0] = uint8(len(m.cipherSuites) >> 7) + y[1] = uint8(len(m.cipherSuites) << 1) + for i, suite := range m.cipherSuites { + y[2+i*2] = uint8(suite >> 8) + y[3+i*2] = uint8(suite) + } + z := y[2+len(m.cipherSuites)*2:] + z[0] = uint8(len(m.compressionMethods)) + copy(z[1:], m.compressionMethods) + + z = z[1+len(m.compressionMethods):] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.duplicateExtension { + // Add a duplicate bogus extension at the beginning and end. + z[0] = 0xff + z[1] = 0xff + z = z[4:] + } + if m.nextProtoNeg && !m.npnLast { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] + } + if len(m.serverName) > 0 { + z[0] = byte(extensionServerName >> 8) + z[1] = byte(extensionServerName & 0xff) + l := len(m.serverName) + 5 + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + // RFC 3546, section 3.1 + // + // struct { + // NameType name_type; + // select (name_type) { + // case host_name: HostName; + // } name; + // } ServerName; + // + // enum { + // host_name(0), (255) + // } NameType; + // + // opaque HostName<1..2^16-1>; + // + // struct { + // ServerName server_name_list<1..2^16-1> + // } ServerNameList; + + z[0] = byte((len(m.serverName) + 3) >> 8) + z[1] = byte(len(m.serverName) + 3) + z[3] = byte(len(m.serverName) >> 8) + z[4] = byte(len(m.serverName)) + copy(z[5:], []byte(m.serverName)) + z = z[l:] + } + if m.ocspStapling { + // RFC 4366, section 3.6 + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z[2] = 0 + z[3] = 5 + z[4] = 1 // OCSP type + // Two zero valued uint16s for the two lengths. + z = z[9:] + } + if len(m.supportedCurves) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + z[0] = byte(extensionSupportedCurves >> 8) + z[1] = byte(extensionSupportedCurves) + l := 2 + 2*len(m.supportedCurves) + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + z = z[6:] + for _, curve := range m.supportedCurves { + z[0] = byte(curve >> 8) + z[1] = byte(curve) + z = z[2:] + } + } + if len(m.supportedPoints) > 0 { + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + z[0] = byte(extensionSupportedPoints >> 8) + z[1] = byte(extensionSupportedPoints) + l := 1 + len(m.supportedPoints) + z[2] = byte(l >> 8) + z[3] = byte(l) + l-- + z[4] = byte(l) + z = z[5:] + for _, pointFormat := range m.supportedPoints { + z[0] = byte(pointFormat) + z = z[1:] + } + } + if m.ticketSupported { + // http://tools.ietf.org/html/rfc5077#section-3.2 + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + l := len(m.sessionTicket) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + copy(z, m.sessionTicket) + z = z[len(m.sessionTicket):] + } + if len(m.signatureAndHashes) > 0 { + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + z[0] = byte(extensionSignatureAlgorithms >> 8) + z[1] = byte(extensionSignatureAlgorithms) + l := 2 + 2*len(m.signatureAndHashes) + z[2] = byte(l >> 8) + z[3] = byte(l) + z = z[4:] + + l -= 2 + z[0] = byte(l >> 8) + z[1] = byte(l) + z = z[2:] + for _, sigAndHash := range m.signatureAndHashes { + z[0] = sigAndHash.hash + z[1] = sigAndHash.signature + z = z[2:] + } + } + if m.secureRenegotiation != nil { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(1 + len(m.secureRenegotiation)) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if len(m.alpnProtocols) > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + lengths := z[2:] + z = z[6:] + + stringsLength := 0 + for _, s := range m.alpnProtocols { + l := len(s) + z[0] = byte(l) + copy(z[1:], s) + z = z[1+l:] + stringsLength += 1 + l + } + + lengths[2] = byte(stringsLength >> 8) + lengths[3] = byte(stringsLength) + stringsLength += 2 + lengths[0] = byte(stringsLength >> 8) + lengths[1] = byte(stringsLength) + } + if m.channelIDSupported { + z[0] = byte(extensionChannelID >> 8) + z[1] = byte(extensionChannelID & 0xff) + z = z[4:] + } + if m.nextProtoNeg && m.npnLast { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + // The length is always 0 + z = z[4:] + } + if m.duplicateExtension { + // Add a duplicate bogus extension at the beginning and end. + z[0] = 0xff + z[1] = 0xff + z = z[4:] + } + if m.extendedMasterSecret { + // https://tools.ietf.org/html/draft-ietf-tls-session-hash-01 + z[0] = byte(extensionExtendedMasterSecret >> 8) + z[1] = byte(extensionExtendedMasterSecret & 0xff) + z = z[4:] + } + if len(m.srtpProtectionProfiles) > 0 { + z[0] = byte(extensionUseSRTP >> 8) + z[1] = byte(extensionUseSRTP & 0xff) + + profilesLen := 2 * len(m.srtpProtectionProfiles) + mkiLen := len(m.srtpMasterKeyIdentifier) + l := 2 + profilesLen + 1 + mkiLen + z[2] = byte(l >> 8) + z[3] = byte(l & 0xff) + + z[4] = byte(profilesLen >> 8) + z[5] = byte(profilesLen & 0xff) + z = z[6:] + for _, p := range m.srtpProtectionProfiles { + z[0] = byte(p >> 8) + z[1] = byte(p & 0xff) + z = z[2:] + } + + z[0] = byte(mkiLen) + copy(z[1:], []byte(m.srtpMasterKeyIdentifier)) + z = z[1+mkiLen:] + } + if m.sctListSupported { + z[0] = byte(extensionSignedCertificateTimestamp >> 8) + z[1] = byte(extensionSignedCertificateTimestamp & 0xff) + z = z[4:] + } + + m.raw = x + + return x +} + +func (m *clientHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data + m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return false + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + if m.isDTLS { + if len(data) < 1 { + return false + } + cookieLen := int(data[0]) + if cookieLen > 32 || len(data) < 1+cookieLen { + return false + } + m.cookie = data[1 : 1+cookieLen] + data = data[1+cookieLen:] + } + if len(data) < 2 { + return false + } + // cipherSuiteLen is the number of bytes of cipher suite numbers. Since + // they are uint16s, the number must be even. + cipherSuiteLen := int(data[0])<<8 | int(data[1]) + if cipherSuiteLen%2 == 1 || len(data) < 2+cipherSuiteLen { + return false + } + numCipherSuites := cipherSuiteLen / 2 + m.cipherSuites = make([]uint16, numCipherSuites) + for i := 0; i < numCipherSuites; i++ { + m.cipherSuites[i] = uint16(data[2+2*i])<<8 | uint16(data[3+2*i]) + if m.cipherSuites[i] == scsvRenegotiation { + m.secureRenegotiation = []byte{} + } + } + data = data[2+cipherSuiteLen:] + if len(data) < 1 { + return false + } + compressionMethodsLen := int(data[0]) + if len(data) < 1+compressionMethodsLen { + return false + } + m.compressionMethods = data[1 : 1+compressionMethodsLen] + + data = data[1+compressionMethodsLen:] + + m.nextProtoNeg = false + m.serverName = "" + m.ocspStapling = false + m.ticketSupported = false + m.sessionTicket = nil + m.signatureAndHashes = nil + m.alpnProtocols = nil + m.extendedMasterSecret = false + + if len(data) == 0 { + // ClientHello is optionally followed by extension data + return true + } + if len(data) < 2 { + return false + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if extensionsLength != len(data) { + return false + } + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionServerName: + if length < 2 { + return false + } + numNames := int(data[0])<<8 | int(data[1]) + d := data[2:] + for i := 0; i < numNames; i++ { + if len(d) < 3 { + return false + } + nameType := d[0] + nameLen := int(d[1])<<8 | int(d[2]) + d = d[3:] + if len(d) < nameLen { + return false + } + if nameType == 0 { + m.serverName = string(d[0:nameLen]) + break + } + d = d[nameLen:] + } + case extensionNextProtoNeg: + if length > 0 { + return false + } + m.nextProtoNeg = true + case extensionStatusRequest: + m.ocspStapling = length > 0 && data[0] == statusTypeOCSP + case extensionSupportedCurves: + // http://tools.ietf.org/html/rfc4492#section-5.5.1 + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l%2 == 1 || length != l+2 { + return false + } + numCurves := l / 2 + m.supportedCurves = make([]CurveID, numCurves) + d := data[2:] + for i := 0; i < numCurves; i++ { + m.supportedCurves[i] = CurveID(d[0])<<8 | CurveID(d[1]) + d = d[2:] + } + case extensionSupportedPoints: + // http://tools.ietf.org/html/rfc4492#section-5.5.2 + if length < 1 { + return false + } + l := int(data[0]) + if length != l+1 { + return false + } + m.supportedPoints = make([]uint8, l) + copy(m.supportedPoints, data[1:]) + case extensionSessionTicket: + // http://tools.ietf.org/html/rfc5077#section-3.2 + m.ticketSupported = true + m.sessionTicket = data[:length] + case extensionSignatureAlgorithms: + // https://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + if length < 2 || length&1 != 0 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + n := l / 2 + d := data[2:] + m.signatureAndHashes = make([]signatureAndHash, n) + for i := range m.signatureAndHashes { + m.signatureAndHashes[i].hash = d[0] + m.signatureAndHashes[i].signature = d[1] + d = d[2:] + } + case extensionRenegotiationInfo: + if length < 1 || length != int(data[0])+1 { + return false + } + m.secureRenegotiation = data[1:length] + case extensionALPN: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != length-2 { + return false + } + d := data[2:length] + for len(d) != 0 { + stringLen := int(d[0]) + d = d[1:] + if stringLen == 0 || stringLen > len(d) { + return false + } + m.alpnProtocols = append(m.alpnProtocols, string(d[:stringLen])) + d = d[stringLen:] + } + case extensionChannelID: + if length > 0 { + return false + } + m.channelIDSupported = true + case extensionExtendedMasterSecret: + if length != 0 { + return false + } + m.extendedMasterSecret = true + case extensionUseSRTP: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l > length-2 || l%2 != 0 { + return false + } + n := l / 2 + m.srtpProtectionProfiles = make([]uint16, n) + d := data[2:length] + for i := 0; i < n; i++ { + m.srtpProtectionProfiles[i] = uint16(d[0])<<8 | uint16(d[1]) + d = d[2:] + } + if len(d) < 1 || int(d[0]) != len(d)-1 { + return false + } + m.srtpMasterKeyIdentifier = string(d[1:]) + case extensionSignedCertificateTimestamp: + if length != 0 { + return false + } + m.sctListSupported = true + } + data = data[length:] + } + + return true +} + +type serverHelloMsg struct { + raw []byte + isDTLS bool + vers uint16 + random []byte + sessionId []byte + cipherSuite uint16 + compressionMethod uint8 + nextProtoNeg bool + nextProtos []string + ocspStapling bool + ticketSupported bool + secureRenegotiation []byte + alpnProtocol string + duplicateExtension bool + channelIDRequested bool + extendedMasterSecret bool + srtpProtectionProfile uint16 + srtpMasterKeyIdentifier string + sctList []byte +} + +func (m *serverHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*serverHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.isDTLS == m1.isDTLS && + m.vers == m1.vers && + bytes.Equal(m.random, m1.random) && + bytes.Equal(m.sessionId, m1.sessionId) && + m.cipherSuite == m1.cipherSuite && + m.compressionMethod == m1.compressionMethod && + m.nextProtoNeg == m1.nextProtoNeg && + eqStrings(m.nextProtos, m1.nextProtos) && + m.ocspStapling == m1.ocspStapling && + m.ticketSupported == m1.ticketSupported && + bytes.Equal(m.secureRenegotiation, m1.secureRenegotiation) && + (m.secureRenegotiation == nil) == (m1.secureRenegotiation == nil) && + m.alpnProtocol == m1.alpnProtocol && + m.duplicateExtension == m1.duplicateExtension && + m.channelIDRequested == m1.channelIDRequested && + m.extendedMasterSecret == m1.extendedMasterSecret && + m.srtpProtectionProfile == m1.srtpProtectionProfile && + m.srtpMasterKeyIdentifier == m1.srtpMasterKeyIdentifier && + bytes.Equal(m.sctList, m1.sctList) +} + +func (m *serverHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 38 + len(m.sessionId) + numExtensions := 0 + extensionsLength := 0 + + nextProtoLen := 0 + if m.nextProtoNeg { + numExtensions++ + for _, v := range m.nextProtos { + nextProtoLen += len(v) + } + nextProtoLen += len(m.nextProtos) + extensionsLength += nextProtoLen + } + if m.ocspStapling { + numExtensions++ + } + if m.ticketSupported { + numExtensions++ + } + if m.secureRenegotiation != nil { + extensionsLength += 1 + len(m.secureRenegotiation) + numExtensions++ + } + if m.duplicateExtension { + numExtensions += 2 + } + if m.channelIDRequested { + numExtensions++ + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + if alpnLen >= 256 { + panic("invalid ALPN protocol") + } + extensionsLength += 2 + 1 + alpnLen + numExtensions++ + } + if m.extendedMasterSecret { + numExtensions++ + } + if m.srtpProtectionProfile != 0 { + extensionsLength += 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier) + numExtensions++ + } + if m.sctList != nil { + extensionsLength += len(m.sctList) + numExtensions++ + } + + if numExtensions > 0 { + extensionsLength += 4 * numExtensions + length += 2 + extensionsLength + } + + x := make([]byte, 4+length) + x[0] = typeServerHello + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + vers := versionToWire(m.vers, m.isDTLS) + x[4] = uint8(vers >> 8) + x[5] = uint8(vers) + copy(x[6:38], m.random) + x[38] = uint8(len(m.sessionId)) + copy(x[39:39+len(m.sessionId)], m.sessionId) + z := x[39+len(m.sessionId):] + z[0] = uint8(m.cipherSuite >> 8) + z[1] = uint8(m.cipherSuite) + z[2] = uint8(m.compressionMethod) + + z = z[3:] + if numExtensions > 0 { + z[0] = byte(extensionsLength >> 8) + z[1] = byte(extensionsLength) + z = z[2:] + } + if m.duplicateExtension { + // Add a duplicate bogus extension at the beginning and end. + z[0] = 0xff + z[1] = 0xff + z = z[4:] + } + if m.nextProtoNeg { + z[0] = byte(extensionNextProtoNeg >> 8) + z[1] = byte(extensionNextProtoNeg & 0xff) + z[2] = byte(nextProtoLen >> 8) + z[3] = byte(nextProtoLen) + z = z[4:] + + for _, v := range m.nextProtos { + l := len(v) + if l > 255 { + l = 255 + } + z[0] = byte(l) + copy(z[1:], []byte(v[0:l])) + z = z[1+l:] + } + } + if m.ocspStapling { + z[0] = byte(extensionStatusRequest >> 8) + z[1] = byte(extensionStatusRequest) + z = z[4:] + } + if m.ticketSupported { + z[0] = byte(extensionSessionTicket >> 8) + z[1] = byte(extensionSessionTicket) + z = z[4:] + } + if m.secureRenegotiation != nil { + z[0] = byte(extensionRenegotiationInfo >> 8) + z[1] = byte(extensionRenegotiationInfo & 0xff) + z[2] = 0 + z[3] = byte(1 + len(m.secureRenegotiation)) + z[4] = byte(len(m.secureRenegotiation)) + z = z[5:] + copy(z, m.secureRenegotiation) + z = z[len(m.secureRenegotiation):] + } + if alpnLen := len(m.alpnProtocol); alpnLen > 0 { + z[0] = byte(extensionALPN >> 8) + z[1] = byte(extensionALPN & 0xff) + l := 2 + 1 + alpnLen + z[2] = byte(l >> 8) + z[3] = byte(l) + l -= 2 + z[4] = byte(l >> 8) + z[5] = byte(l) + l -= 1 + z[6] = byte(l) + copy(z[7:], []byte(m.alpnProtocol)) + z = z[7+alpnLen:] + } + if m.channelIDRequested { + z[0] = byte(extensionChannelID >> 8) + z[1] = byte(extensionChannelID & 0xff) + z = z[4:] + } + if m.duplicateExtension { + // Add a duplicate bogus extension at the beginning and end. + z[0] = 0xff + z[1] = 0xff + z = z[4:] + } + if m.extendedMasterSecret { + z[0] = byte(extensionExtendedMasterSecret >> 8) + z[1] = byte(extensionExtendedMasterSecret & 0xff) + z = z[4:] + } + if m.srtpProtectionProfile != 0 { + z[0] = byte(extensionUseSRTP >> 8) + z[1] = byte(extensionUseSRTP & 0xff) + l := 2 + 2 + 1 + len(m.srtpMasterKeyIdentifier) + z[2] = byte(l >> 8) + z[3] = byte(l & 0xff) + z[4] = 0 + z[5] = 2 + z[6] = byte(m.srtpProtectionProfile >> 8) + z[7] = byte(m.srtpProtectionProfile & 0xff) + l = len(m.srtpMasterKeyIdentifier) + z[8] = byte(l) + copy(z[9:], []byte(m.srtpMasterKeyIdentifier)) + z = z[9+l:] + } + if m.sctList != nil { + z[0] = byte(extensionSignedCertificateTimestamp >> 8) + z[1] = byte(extensionSignedCertificateTimestamp & 0xff) + l := len(m.sctList) + z[2] = byte(l >> 8) + z[3] = byte(l & 0xff) + copy(z[4:], m.sctList) + z = z[4+l:] + } + + m.raw = x + + return x +} + +func (m *serverHelloMsg) unmarshal(data []byte) bool { + if len(data) < 42 { + return false + } + m.raw = data + m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), m.isDTLS) + m.random = data[6:38] + sessionIdLen := int(data[38]) + if sessionIdLen > 32 || len(data) < 39+sessionIdLen { + return false + } + m.sessionId = data[39 : 39+sessionIdLen] + data = data[39+sessionIdLen:] + if len(data) < 3 { + return false + } + m.cipherSuite = uint16(data[0])<<8 | uint16(data[1]) + m.compressionMethod = data[2] + data = data[3:] + + m.nextProtoNeg = false + m.nextProtos = nil + m.ocspStapling = false + m.ticketSupported = false + m.alpnProtocol = "" + m.extendedMasterSecret = false + + if len(data) == 0 { + // ServerHello is optionally followed by extension data + return true + } + if len(data) < 2 { + return false + } + + extensionsLength := int(data[0])<<8 | int(data[1]) + data = data[2:] + if len(data) != extensionsLength { + return false + } + + for len(data) != 0 { + if len(data) < 4 { + return false + } + extension := uint16(data[0])<<8 | uint16(data[1]) + length := int(data[2])<<8 | int(data[3]) + data = data[4:] + if len(data) < length { + return false + } + + switch extension { + case extensionNextProtoNeg: + m.nextProtoNeg = true + d := data[:length] + for len(d) > 0 { + l := int(d[0]) + d = d[1:] + if l == 0 || l > len(d) { + return false + } + m.nextProtos = append(m.nextProtos, string(d[:l])) + d = d[l:] + } + case extensionStatusRequest: + if length > 0 { + return false + } + m.ocspStapling = true + case extensionSessionTicket: + if length > 0 { + return false + } + m.ticketSupported = true + case extensionRenegotiationInfo: + if length < 1 || length != int(data[0])+1 { + return false + } + m.secureRenegotiation = data[1:length] + case extensionALPN: + d := data[:length] + if len(d) < 3 { + return false + } + l := int(d[0])<<8 | int(d[1]) + if l != len(d)-2 { + return false + } + d = d[2:] + l = int(d[0]) + if l != len(d)-1 { + return false + } + d = d[1:] + m.alpnProtocol = string(d) + case extensionChannelID: + if length > 0 { + return false + } + m.channelIDRequested = true + case extensionExtendedMasterSecret: + if length != 0 { + return false + } + m.extendedMasterSecret = true + case extensionUseSRTP: + if length < 2+2+1 { + return false + } + if data[0] != 0 || data[1] != 2 { + return false + } + m.srtpProtectionProfile = uint16(data[2])<<8 | uint16(data[3]) + d := data[4:length] + l := int(d[0]) + if l != len(d)-1 { + return false + } + m.srtpMasterKeyIdentifier = string(d[1:]) + case extensionSignedCertificateTimestamp: + if length < 2 { + return false + } + l := int(data[0])<<8 | int(data[1]) + if l != len(data)-2 { + return false + } + m.sctList = data[2:length] + } + data = data[length:] + } + + return true +} + +type certificateMsg struct { + raw []byte + certificates [][]byte +} + +func (m *certificateMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + eqByteSlices(m.certificates, m1.certificates) +} + +func (m *certificateMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + var i int + for _, slice := range m.certificates { + i += len(slice) + } + + length := 3 + 3*len(m.certificates) + i + x = make([]byte, 4+length) + x[0] = typeCertificate + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + certificateOctets := length - 3 + x[4] = uint8(certificateOctets >> 16) + x[5] = uint8(certificateOctets >> 8) + x[6] = uint8(certificateOctets) + + y := x[7:] + for _, slice := range m.certificates { + y[0] = uint8(len(slice) >> 16) + y[1] = uint8(len(slice) >> 8) + y[2] = uint8(len(slice)) + copy(y[3:], slice) + y = y[3+len(slice):] + } + + m.raw = x + return +} + +func (m *certificateMsg) unmarshal(data []byte) bool { + if len(data) < 7 { + return false + } + + m.raw = data + certsLen := uint32(data[4])<<16 | uint32(data[5])<<8 | uint32(data[6]) + if uint32(len(data)) != certsLen+7 { + return false + } + + numCerts := 0 + d := data[7:] + for certsLen > 0 { + if len(d) < 4 { + return false + } + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + if uint32(len(d)) < 3+certLen { + return false + } + d = d[3+certLen:] + certsLen -= 3 + certLen + numCerts++ + } + + m.certificates = make([][]byte, numCerts) + d = data[7:] + for i := 0; i < numCerts; i++ { + certLen := uint32(d[0])<<16 | uint32(d[1])<<8 | uint32(d[2]) + m.certificates[i] = d[3 : 3+certLen] + d = d[3+certLen:] + } + + return true +} + +type serverKeyExchangeMsg struct { + raw []byte + key []byte +} + +func (m *serverKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*serverKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.key, m1.key) +} + +func (m *serverKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.key) + x := make([]byte, length+4) + x[0] = typeServerKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.key) + + m.raw = x + return x +} + +func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.key = data[4:] + return true +} + +type certificateStatusMsg struct { + raw []byte + statusType uint8 + response []byte +} + +func (m *certificateStatusMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateStatusMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.statusType == m1.statusType && + bytes.Equal(m.response, m1.response) +} + +func (m *certificateStatusMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + var x []byte + if m.statusType == statusTypeOCSP { + x = make([]byte, 4+4+len(m.response)) + x[0] = typeCertificateStatus + l := len(m.response) + 4 + x[1] = byte(l >> 16) + x[2] = byte(l >> 8) + x[3] = byte(l) + x[4] = statusTypeOCSP + + l -= 4 + x[5] = byte(l >> 16) + x[6] = byte(l >> 8) + x[7] = byte(l) + copy(x[8:], m.response) + } else { + x = []byte{typeCertificateStatus, 0, 0, 1, m.statusType} + } + + m.raw = x + return x +} + +func (m *certificateStatusMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 5 { + return false + } + m.statusType = data[4] + + m.response = nil + if m.statusType == statusTypeOCSP { + if len(data) < 8 { + return false + } + respLen := uint32(data[5])<<16 | uint32(data[6])<<8 | uint32(data[7]) + if uint32(len(data)) != 4+4+respLen { + return false + } + m.response = data[8:] + } + return true +} + +type serverHelloDoneMsg struct{} + +func (m *serverHelloDoneMsg) equal(i interface{}) bool { + _, ok := i.(*serverHelloDoneMsg) + return ok +} + +func (m *serverHelloDoneMsg) marshal() []byte { + x := make([]byte, 4) + x[0] = typeServerHelloDone + return x +} + +func (m *serverHelloDoneMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +type clientKeyExchangeMsg struct { + raw []byte + ciphertext []byte +} + +func (m *clientKeyExchangeMsg) equal(i interface{}) bool { + m1, ok := i.(*clientKeyExchangeMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ciphertext, m1.ciphertext) +} + +func (m *clientKeyExchangeMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + length := len(m.ciphertext) + x := make([]byte, length+4) + x[0] = typeClientKeyExchange + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + copy(x[4:], m.ciphertext) + + m.raw = x + return x +} + +func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + l := int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + if l != len(data)-4 { + return false + } + m.ciphertext = data[4:] + return true +} + +type finishedMsg struct { + raw []byte + verifyData []byte +} + +func (m *finishedMsg) equal(i interface{}) bool { + m1, ok := i.(*finishedMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.verifyData, m1.verifyData) +} + +func (m *finishedMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + x = make([]byte, 4+len(m.verifyData)) + x[0] = typeFinished + x[3] = byte(len(m.verifyData)) + copy(x[4:], m.verifyData) + m.raw = x + return +} + +func (m *finishedMsg) unmarshal(data []byte) bool { + m.raw = data + if len(data) < 4 { + return false + } + m.verifyData = data[4:] + return true +} + +type nextProtoMsg struct { + raw []byte + proto string +} + +func (m *nextProtoMsg) equal(i interface{}) bool { + m1, ok := i.(*nextProtoMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.proto == m1.proto +} + +func (m *nextProtoMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + l := len(m.proto) + if l > 255 { + l = 255 + } + + padding := 32 - (l+2)%32 + length := l + padding + 2 + x := make([]byte, length+4) + x[0] = typeNextProtocol + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + y := x[4:] + y[0] = byte(l) + copy(y[1:], []byte(m.proto[0:l])) + y = y[1+l:] + y[0] = byte(padding) + + m.raw = x + + return x +} + +func (m *nextProtoMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + data = data[4:] + protoLen := int(data[0]) + data = data[1:] + if len(data) < protoLen { + return false + } + m.proto = string(data[0:protoLen]) + data = data[protoLen:] + + if len(data) < 1 { + return false + } + paddingLen := int(data[0]) + data = data[1:] + if len(data) != paddingLen { + return false + } + + return true +} + +type certificateRequestMsg struct { + raw []byte + // hasSignatureAndHash indicates whether this message includes a list + // of signature and hash functions. This change was introduced with TLS + // 1.2. + hasSignatureAndHash bool + + certificateTypes []byte + signatureAndHashes []signatureAndHash + certificateAuthorities [][]byte +} + +func (m *certificateRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.certificateTypes, m1.certificateTypes) && + eqByteSlices(m.certificateAuthorities, m1.certificateAuthorities) && + eqSignatureAndHashes(m.signatureAndHashes, m1.signatureAndHashes) +} + +func (m *certificateRequestMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.4 + length := 1 + len(m.certificateTypes) + 2 + casLength := 0 + for _, ca := range m.certificateAuthorities { + casLength += 2 + len(ca) + } + length += casLength + + if m.hasSignatureAndHash { + length += 2 + 2*len(m.signatureAndHashes) + } + + x = make([]byte, 4+length) + x[0] = typeCertificateRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + + x[4] = uint8(len(m.certificateTypes)) + + copy(x[5:], m.certificateTypes) + y := x[5+len(m.certificateTypes):] + + if m.hasSignatureAndHash { + n := len(m.signatureAndHashes) * 2 + y[0] = uint8(n >> 8) + y[1] = uint8(n) + y = y[2:] + for _, sigAndHash := range m.signatureAndHashes { + y[0] = sigAndHash.hash + y[1] = sigAndHash.signature + y = y[2:] + } + } + + y[0] = uint8(casLength >> 8) + y[1] = uint8(casLength) + y = y[2:] + for _, ca := range m.certificateAuthorities { + y[0] = uint8(len(ca) >> 8) + y[1] = uint8(len(ca)) + y = y[2:] + copy(y, ca) + y = y[len(ca):] + } + + m.raw = x + return +} + +func (m *certificateRequestMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 5 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + numCertTypes := int(data[4]) + data = data[5:] + if numCertTypes == 0 || len(data) <= numCertTypes { + return false + } + + m.certificateTypes = make([]byte, numCertTypes) + if copy(m.certificateTypes, data) != numCertTypes { + return false + } + + data = data[numCertTypes:] + + if m.hasSignatureAndHash { + if len(data) < 2 { + return false + } + sigAndHashLen := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if sigAndHashLen&1 != 0 { + return false + } + if len(data) < int(sigAndHashLen) { + return false + } + numSigAndHash := sigAndHashLen / 2 + m.signatureAndHashes = make([]signatureAndHash, numSigAndHash) + for i := range m.signatureAndHashes { + m.signatureAndHashes[i].hash = data[0] + m.signatureAndHashes[i].signature = data[1] + data = data[2:] + } + } + + if len(data) < 2 { + return false + } + casLength := uint16(data[0])<<8 | uint16(data[1]) + data = data[2:] + if len(data) < int(casLength) { + return false + } + cas := make([]byte, casLength) + copy(cas, data) + data = data[casLength:] + + m.certificateAuthorities = nil + for len(cas) > 0 { + if len(cas) < 2 { + return false + } + caLen := uint16(cas[0])<<8 | uint16(cas[1]) + cas = cas[2:] + + if len(cas) < int(caLen) { + return false + } + + m.certificateAuthorities = append(m.certificateAuthorities, cas[:caLen]) + cas = cas[caLen:] + } + if len(data) > 0 { + return false + } + + return true +} + +type certificateVerifyMsg struct { + raw []byte + hasSignatureAndHash bool + signatureAndHash signatureAndHash + signature []byte +} + +func (m *certificateVerifyMsg) equal(i interface{}) bool { + m1, ok := i.(*certificateVerifyMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.hasSignatureAndHash == m1.hasSignatureAndHash && + m.signatureAndHash.hash == m1.signatureAndHash.hash && + m.signatureAndHash.signature == m1.signatureAndHash.signature && + bytes.Equal(m.signature, m1.signature) +} + +func (m *certificateVerifyMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc4346#section-7.4.8 + siglength := len(m.signature) + length := 2 + siglength + if m.hasSignatureAndHash { + length += 2 + } + x = make([]byte, 4+length) + x[0] = typeCertificateVerify + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + y := x[4:] + if m.hasSignatureAndHash { + y[0] = m.signatureAndHash.hash + y[1] = m.signatureAndHash.signature + y = y[2:] + } + y[0] = uint8(siglength >> 8) + y[1] = uint8(siglength) + copy(y[2:], m.signature) + + m.raw = x + + return +} + +func (m *certificateVerifyMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 6 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + data = data[4:] + if m.hasSignatureAndHash { + m.signatureAndHash.hash = data[0] + m.signatureAndHash.signature = data[1] + data = data[2:] + } + + if len(data) < 2 { + return false + } + siglength := int(data[0])<<8 + int(data[1]) + data = data[2:] + if len(data) != siglength { + return false + } + + m.signature = data + + return true +} + +type newSessionTicketMsg struct { + raw []byte + ticket []byte +} + +func (m *newSessionTicketMsg) equal(i interface{}) bool { + m1, ok := i.(*newSessionTicketMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.ticket, m1.ticket) +} + +func (m *newSessionTicketMsg) marshal() (x []byte) { + if m.raw != nil { + return m.raw + } + + // See http://tools.ietf.org/html/rfc5077#section-3.3 + ticketLen := len(m.ticket) + length := 2 + 4 + ticketLen + x = make([]byte, 4+length) + x[0] = typeNewSessionTicket + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[8] = uint8(ticketLen >> 8) + x[9] = uint8(ticketLen) + copy(x[10:], m.ticket) + + m.raw = x + + return +} + +func (m *newSessionTicketMsg) unmarshal(data []byte) bool { + m.raw = data + + if len(data) < 10 { + return false + } + + length := uint32(data[1])<<16 | uint32(data[2])<<8 | uint32(data[3]) + if uint32(len(data))-4 != length { + return false + } + + ticketLen := int(data[8])<<8 + int(data[9]) + if len(data)-10 != ticketLen { + return false + } + + m.ticket = data[10:] + + return true +} + +type v2ClientHelloMsg struct { + raw []byte + vers uint16 + cipherSuites []uint16 + sessionId []byte + challenge []byte +} + +func (m *v2ClientHelloMsg) equal(i interface{}) bool { + m1, ok := i.(*v2ClientHelloMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + eqUint16s(m.cipherSuites, m1.cipherSuites) && + bytes.Equal(m.sessionId, m1.sessionId) && + bytes.Equal(m.challenge, m1.challenge) +} + +func (m *v2ClientHelloMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 1 + 2 + 2 + 2 + 2 + len(m.cipherSuites)*3 + len(m.sessionId) + len(m.challenge) + + x := make([]byte, length) + x[0] = 1 + x[1] = uint8(m.vers >> 8) + x[2] = uint8(m.vers) + x[3] = uint8((len(m.cipherSuites) * 3) >> 8) + x[4] = uint8(len(m.cipherSuites) * 3) + x[5] = uint8(len(m.sessionId) >> 8) + x[6] = uint8(len(m.sessionId)) + x[7] = uint8(len(m.challenge) >> 8) + x[8] = uint8(len(m.challenge)) + y := x[9:] + for i, spec := range m.cipherSuites { + y[i*3] = 0 + y[i*3+1] = uint8(spec >> 8) + y[i*3+2] = uint8(spec) + } + y = y[len(m.cipherSuites)*3:] + copy(y, m.sessionId) + y = y[len(m.sessionId):] + copy(y, m.challenge) + + m.raw = x + + return x +} + +type helloVerifyRequestMsg struct { + raw []byte + vers uint16 + cookie []byte +} + +func (m *helloVerifyRequestMsg) equal(i interface{}) bool { + m1, ok := i.(*helloVerifyRequestMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + m.vers == m1.vers && + bytes.Equal(m.cookie, m1.cookie) +} + +func (m *helloVerifyRequestMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 1 + len(m.cookie) + + x := make([]byte, 4+length) + x[0] = typeHelloVerifyRequest + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + vers := versionToWire(m.vers, true) + x[4] = uint8(vers >> 8) + x[5] = uint8(vers) + x[6] = uint8(len(m.cookie)) + copy(x[7:7+len(m.cookie)], m.cookie) + + return x +} + +func (m *helloVerifyRequestMsg) unmarshal(data []byte) bool { + if len(data) < 4+2+1 { + return false + } + m.raw = data + m.vers = wireToVersion(uint16(data[4])<<8|uint16(data[5]), true) + cookieLen := int(data[6]) + if cookieLen > 32 || len(data) != 7+cookieLen { + return false + } + m.cookie = data[7 : 7+cookieLen] + + return true +} + +type encryptedExtensionsMsg struct { + raw []byte + channelID []byte +} + +func (m *encryptedExtensionsMsg) equal(i interface{}) bool { + m1, ok := i.(*encryptedExtensionsMsg) + if !ok { + return false + } + + return bytes.Equal(m.raw, m1.raw) && + bytes.Equal(m.channelID, m1.channelID) +} + +func (m *encryptedExtensionsMsg) marshal() []byte { + if m.raw != nil { + return m.raw + } + + length := 2 + 2 + len(m.channelID) + + x := make([]byte, 4+length) + x[0] = typeEncryptedExtensions + x[1] = uint8(length >> 16) + x[2] = uint8(length >> 8) + x[3] = uint8(length) + x[4] = uint8(extensionChannelID >> 8) + x[5] = uint8(extensionChannelID & 0xff) + x[6] = uint8(len(m.channelID) >> 8) + x[7] = uint8(len(m.channelID) & 0xff) + copy(x[8:], m.channelID) + + return x +} + +func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool { + if len(data) != 4+2+2+128 { + return false + } + m.raw = data + if (uint16(data[4])<<8)|uint16(data[5]) != extensionChannelID { + return false + } + if int(data[6])<<8|int(data[7]) != 128 { + return false + } + m.channelID = data[4+2+2:] + + return true +} + +type helloRequestMsg struct { +} + +func (*helloRequestMsg) marshal() []byte { + return []byte{typeHelloRequest, 0, 0, 0} +} + +func (*helloRequestMsg) unmarshal(data []byte) bool { + return len(data) == 4 +} + +func eqUint16s(x, y []uint16) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqCurveIDs(x, y []CurveID) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqStrings(x, y []string) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if y[i] != v { + return false + } + } + return true +} + +func eqByteSlices(x, y [][]byte) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + if !bytes.Equal(v, y[i]) { + return false + } + } + return true +} + +func eqSignatureAndHashes(x, y []signatureAndHash) bool { + if len(x) != len(y) { + return false + } + for i, v := range x { + v2 := y[i] + if v.hash != v2.hash || v.signature != v2.signature { + return false + } + } + return true +} diff --git a/src/ssl/test/runner/handshake_server.go b/src/ssl/test/runner/handshake_server.go new file mode 100644 index 0000000..1234a57 --- /dev/null +++ b/src/ssl/test/runner/handshake_server.go @@ -0,0 +1,964 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/rsa" + "crypto/subtle" + "crypto/x509" + "encoding/asn1" + "errors" + "fmt" + "io" + "math/big" +) + +// serverHandshakeState contains details of a server handshake in progress. +// It's discarded once the handshake has completed. +type serverHandshakeState struct { + c *Conn + clientHello *clientHelloMsg + hello *serverHelloMsg + suite *cipherSuite + ellipticOk bool + ecdsaOk bool + sessionState *sessionState + finishedHash finishedHash + masterSecret []byte + certsFromClient [][]byte + cert *Certificate +} + +// serverHandshake performs a TLS handshake as a server. +func (c *Conn) serverHandshake() error { + config := c.config + + // If this is the first server handshake, we generate a random key to + // encrypt the tickets with. + config.serverInitOnce.Do(config.serverInit) + + c.sendHandshakeSeq = 0 + c.recvHandshakeSeq = 0 + + hs := serverHandshakeState{ + c: c, + } + isResume, err := hs.readClientHello() + if err != nil { + return err + } + + // For an overview of TLS handshaking, see https://tools.ietf.org/html/rfc5246#section-7.3 + if isResume { + // The client has included a session ticket and so we do an abbreviated handshake. + if err := hs.doResumeHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if c.config.Bugs.RenewTicketOnResume { + if err := hs.sendSessionTicket(); err != nil { + return err + } + } + if err := hs.sendFinished(); err != nil { + return err + } + if err := hs.readFinished(isResume); err != nil { + return err + } + c.didResume = true + } else { + // The client didn't include a session ticket, or it wasn't + // valid so we do a full handshake. + if err := hs.doFullHandshake(); err != nil { + return err + } + if err := hs.establishKeys(); err != nil { + return err + } + if err := hs.readFinished(isResume); err != nil { + return err + } + if c.config.Bugs.ExpectFalseStart { + if err := c.readRecord(recordTypeApplicationData); err != nil { + return err + } + } + if err := hs.sendSessionTicket(); err != nil { + return err + } + if err := hs.sendFinished(); err != nil { + return err + } + } + c.handshakeComplete = true + + return nil +} + +// readClientHello reads a ClientHello message from the client and decides +// whether we will perform session resumption. +func (hs *serverHandshakeState) readClientHello() (isResume bool, err error) { + config := hs.c.config + c := hs.c + + msg, err := c.readHandshake() + if err != nil { + return false, err + } + var ok bool + hs.clientHello, ok = msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return false, unexpectedMessageError(hs.clientHello, msg) + } + if config.Bugs.RequireFastradioPadding && len(hs.clientHello.raw) < 1000 { + return false, errors.New("tls: ClientHello record size should be larger than 1000 bytes when padding enabled.") + } + + if c.isDTLS && !config.Bugs.SkipHelloVerifyRequest { + // Per RFC 6347, the version field in HelloVerifyRequest SHOULD + // be always DTLS 1.0 + helloVerifyRequest := &helloVerifyRequestMsg{ + vers: VersionTLS10, + cookie: make([]byte, 32), + } + if _, err := io.ReadFull(c.config.rand(), helloVerifyRequest.cookie); err != nil { + c.sendAlert(alertInternalError) + return false, errors.New("dtls: short read from Rand: " + err.Error()) + } + c.writeRecord(recordTypeHandshake, helloVerifyRequest.marshal()) + + msg, err := c.readHandshake() + if err != nil { + return false, err + } + newClientHello, ok := msg.(*clientHelloMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return false, unexpectedMessageError(hs.clientHello, msg) + } + if !bytes.Equal(newClientHello.cookie, helloVerifyRequest.cookie) { + return false, errors.New("dtls: invalid cookie") + } + + // Apart from the cookie, the two ClientHellos must + // match. Note that clientHello.equal compares the + // serialization, so we make a copy. + oldClientHelloCopy := *hs.clientHello + oldClientHelloCopy.raw = nil + oldClientHelloCopy.cookie = nil + newClientHelloCopy := *newClientHello + newClientHelloCopy.raw = nil + newClientHelloCopy.cookie = nil + if !oldClientHelloCopy.equal(&newClientHelloCopy) { + return false, errors.New("dtls: retransmitted ClientHello does not match") + } + hs.clientHello = newClientHello + } + + if config.Bugs.RequireSameRenegoClientVersion && c.clientVersion != 0 { + if c.clientVersion != hs.clientHello.vers { + return false, fmt.Errorf("tls: client offered different version on renego") + } + } + c.clientVersion = hs.clientHello.vers + + // Reject < 1.2 ClientHellos with signature_algorithms. + if c.clientVersion < VersionTLS12 && len(hs.clientHello.signatureAndHashes) > 0 { + return false, fmt.Errorf("tls: client included signature_algorithms before TLS 1.2") + } + + c.vers, ok = config.mutualVersion(hs.clientHello.vers) + if !ok { + c.sendAlert(alertProtocolVersion) + return false, fmt.Errorf("tls: client offered an unsupported, maximum protocol version of %x", hs.clientHello.vers) + } + c.haveVers = true + + hs.hello = new(serverHelloMsg) + hs.hello.isDTLS = c.isDTLS + + supportedCurve := false + preferredCurves := config.curvePreferences() +Curves: + for _, curve := range hs.clientHello.supportedCurves { + for _, supported := range preferredCurves { + if supported == curve { + supportedCurve = true + break Curves + } + } + } + + supportedPointFormat := false + for _, pointFormat := range hs.clientHello.supportedPoints { + if pointFormat == pointFormatUncompressed { + supportedPointFormat = true + break + } + } + hs.ellipticOk = supportedCurve && supportedPointFormat + + foundCompression := false + // We only support null compression, so check that the client offered it. + for _, compression := range hs.clientHello.compressionMethods { + if compression == compressionNone { + foundCompression = true + break + } + } + + if !foundCompression { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: client does not support uncompressed connections") + } + + hs.hello.vers = c.vers + hs.hello.random = make([]byte, 32) + _, err = io.ReadFull(config.rand(), hs.hello.random) + if err != nil { + c.sendAlert(alertInternalError) + return false, err + } + + if !bytes.Equal(c.clientVerify, hs.clientHello.secureRenegotiation) { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: renegotiation mismatch") + } + + if len(c.clientVerify) > 0 && !c.config.Bugs.EmptyRenegotiationInfo { + hs.hello.secureRenegotiation = append(hs.hello.secureRenegotiation, c.clientVerify...) + hs.hello.secureRenegotiation = append(hs.hello.secureRenegotiation, c.serverVerify...) + if c.config.Bugs.BadRenegotiationInfo { + hs.hello.secureRenegotiation[0] ^= 0x80 + } + } else { + hs.hello.secureRenegotiation = hs.clientHello.secureRenegotiation + } + + hs.hello.compressionMethod = compressionNone + hs.hello.duplicateExtension = c.config.Bugs.DuplicateExtension + if len(hs.clientHello.serverName) > 0 { + c.serverName = hs.clientHello.serverName + } + + if len(hs.clientHello.alpnProtocols) > 0 { + if selectedProto, fallback := mutualProtocol(hs.clientHello.alpnProtocols, c.config.NextProtos); !fallback { + hs.hello.alpnProtocol = selectedProto + c.clientProtocol = selectedProto + c.usedALPN = true + } + } else { + // Although sending an empty NPN extension is reasonable, Firefox has + // had a bug around this. Best to send nothing at all if + // config.NextProtos is empty. See + // https://code.google.com/p/go/issues/detail?id=5445. + if hs.clientHello.nextProtoNeg && len(config.NextProtos) > 0 { + hs.hello.nextProtoNeg = true + hs.hello.nextProtos = config.NextProtos + } + } + hs.hello.extendedMasterSecret = c.vers >= VersionTLS10 && hs.clientHello.extendedMasterSecret && !c.config.Bugs.NoExtendedMasterSecret + + if len(config.Certificates) == 0 { + c.sendAlert(alertInternalError) + return false, errors.New("tls: no certificates configured") + } + hs.cert = &config.Certificates[0] + if len(hs.clientHello.serverName) > 0 { + hs.cert = config.getCertificateForName(hs.clientHello.serverName) + } + if expected := c.config.Bugs.ExpectServerName; expected != "" && expected != hs.clientHello.serverName { + return false, errors.New("tls: unexpected server name") + } + + if hs.clientHello.channelIDSupported && config.RequestChannelID { + hs.hello.channelIDRequested = true + } + + if hs.clientHello.srtpProtectionProfiles != nil { + SRTPLoop: + for _, p1 := range c.config.SRTPProtectionProfiles { + for _, p2 := range hs.clientHello.srtpProtectionProfiles { + if p1 == p2 { + hs.hello.srtpProtectionProfile = p1 + c.srtpProtectionProfile = p1 + break SRTPLoop + } + } + } + } + + if c.config.Bugs.SendSRTPProtectionProfile != 0 { + hs.hello.srtpProtectionProfile = c.config.Bugs.SendSRTPProtectionProfile + } + + _, hs.ecdsaOk = hs.cert.PrivateKey.(*ecdsa.PrivateKey) + + if hs.checkForResumption() { + return true, nil + } + + var scsvFound bool + + for _, cipherSuite := range hs.clientHello.cipherSuites { + if cipherSuite == fallbackSCSV { + scsvFound = true + break + } + } + + if !scsvFound && config.Bugs.FailIfNotFallbackSCSV { + return false, errors.New("tls: no fallback SCSV found when expected") + } else if scsvFound && !config.Bugs.FailIfNotFallbackSCSV { + return false, errors.New("tls: fallback SCSV found when not expected") + } + + var preferenceList, supportedList []uint16 + if c.config.PreferServerCipherSuites { + preferenceList = c.config.cipherSuites() + supportedList = hs.clientHello.cipherSuites + } else { + preferenceList = hs.clientHello.cipherSuites + supportedList = c.config.cipherSuites() + } + + for _, id := range preferenceList { + if hs.suite = c.tryCipherSuite(id, supportedList, c.vers, hs.ellipticOk, hs.ecdsaOk); hs.suite != nil { + break + } + } + + if hs.suite == nil { + c.sendAlert(alertHandshakeFailure) + return false, errors.New("tls: no cipher suite supported by both client and server") + } + + return false, nil +} + +// checkForResumption returns true if we should perform resumption on this connection. +func (hs *serverHandshakeState) checkForResumption() bool { + c := hs.c + + if len(hs.clientHello.sessionTicket) > 0 { + if c.config.SessionTicketsDisabled { + return false + } + + var ok bool + if hs.sessionState, ok = c.decryptTicket(hs.clientHello.sessionTicket); !ok { + return false + } + } else { + if c.config.ServerSessionCache == nil { + return false + } + + var ok bool + sessionId := string(hs.clientHello.sessionId) + if hs.sessionState, ok = c.config.ServerSessionCache.Get(sessionId); !ok { + return false + } + } + + // Never resume a session for a different SSL version. + if !c.config.Bugs.AllowSessionVersionMismatch && c.vers != hs.sessionState.vers { + return false + } + + cipherSuiteOk := false + // Check that the client is still offering the ciphersuite in the session. + for _, id := range hs.clientHello.cipherSuites { + if id == hs.sessionState.cipherSuite { + cipherSuiteOk = true + break + } + } + if !cipherSuiteOk { + return false + } + + // Check that we also support the ciphersuite from the session. + hs.suite = c.tryCipherSuite(hs.sessionState.cipherSuite, c.config.cipherSuites(), hs.sessionState.vers, hs.ellipticOk, hs.ecdsaOk) + if hs.suite == nil { + return false + } + + sessionHasClientCerts := len(hs.sessionState.certificates) != 0 + needClientCerts := c.config.ClientAuth == RequireAnyClientCert || c.config.ClientAuth == RequireAndVerifyClientCert + if needClientCerts && !sessionHasClientCerts { + return false + } + if sessionHasClientCerts && c.config.ClientAuth == NoClientCert { + return false + } + + return true +} + +func (hs *serverHandshakeState) doResumeHandshake() error { + c := hs.c + + hs.hello.cipherSuite = hs.suite.id + // We echo the client's session ID in the ServerHello to let it know + // that we're doing a resumption. + hs.hello.sessionId = hs.clientHello.sessionId + hs.hello.ticketSupported = c.config.Bugs.RenewTicketOnResume + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.finishedHash.discardHandshakeBuffer() + hs.writeClientHash(hs.clientHello.marshal()) + hs.writeServerHash(hs.hello.marshal()) + + c.writeRecord(recordTypeHandshake, hs.hello.marshal()) + + if len(hs.sessionState.certificates) > 0 { + if _, err := hs.processCertsFromClient(hs.sessionState.certificates); err != nil { + return err + } + } + + hs.masterSecret = hs.sessionState.masterSecret + c.extendedMasterSecret = hs.sessionState.extendedMasterSecret + + return nil +} + +func (hs *serverHandshakeState) doFullHandshake() error { + config := hs.c.config + c := hs.c + + isPSK := hs.suite.flags&suitePSK != 0 + if !isPSK && hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0 { + hs.hello.ocspStapling = true + } + + if hs.clientHello.sctListSupported && len(hs.cert.SignedCertificateTimestampList) > 0 { + hs.hello.sctList = hs.cert.SignedCertificateTimestampList + } + + hs.hello.ticketSupported = hs.clientHello.ticketSupported && !config.SessionTicketsDisabled && c.vers > VersionSSL30 + hs.hello.cipherSuite = hs.suite.id + if config.Bugs.SendCipherSuite != 0 { + hs.hello.cipherSuite = config.Bugs.SendCipherSuite + } + c.extendedMasterSecret = hs.hello.extendedMasterSecret + + // Generate a session ID if we're to save the session. + if !hs.hello.ticketSupported && config.ServerSessionCache != nil { + hs.hello.sessionId = make([]byte, 32) + if _, err := io.ReadFull(config.rand(), hs.hello.sessionId); err != nil { + c.sendAlert(alertInternalError) + return errors.New("tls: short read from Rand: " + err.Error()) + } + } + + hs.finishedHash = newFinishedHash(c.vers, hs.suite) + hs.writeClientHash(hs.clientHello.marshal()) + hs.writeServerHash(hs.hello.marshal()) + + c.writeRecord(recordTypeHandshake, hs.hello.marshal()) + + if !isPSK { + certMsg := new(certificateMsg) + certMsg.certificates = hs.cert.Certificate + if !config.Bugs.UnauthenticatedECDH { + hs.writeServerHash(certMsg.marshal()) + c.writeRecord(recordTypeHandshake, certMsg.marshal()) + } + } + + if hs.hello.ocspStapling { + certStatus := new(certificateStatusMsg) + certStatus.statusType = statusTypeOCSP + certStatus.response = hs.cert.OCSPStaple + hs.writeServerHash(certStatus.marshal()) + c.writeRecord(recordTypeHandshake, certStatus.marshal()) + } + + keyAgreement := hs.suite.ka(c.vers) + skx, err := keyAgreement.generateServerKeyExchange(config, hs.cert, hs.clientHello, hs.hello) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if skx != nil && !config.Bugs.SkipServerKeyExchange { + hs.writeServerHash(skx.marshal()) + c.writeRecord(recordTypeHandshake, skx.marshal()) + } + + if config.ClientAuth >= RequestClientCert { + // Request a client certificate + certReq := &certificateRequestMsg{ + certificateTypes: config.ClientCertificateTypes, + } + if certReq.certificateTypes == nil { + certReq.certificateTypes = []byte{ + byte(CertTypeRSASign), + byte(CertTypeECDSASign), + } + } + if c.vers >= VersionTLS12 { + certReq.hasSignatureAndHash = true + if !config.Bugs.NoSignatureAndHashes { + certReq.signatureAndHashes = config.signatureAndHashesForServer() + } + } + + // An empty list of certificateAuthorities signals to + // the client that it may send any certificate in response + // to our request. When we know the CAs we trust, then + // we can send them down, so that the client can choose + // an appropriate certificate to give to us. + if config.ClientCAs != nil { + certReq.certificateAuthorities = config.ClientCAs.Subjects() + } + hs.writeServerHash(certReq.marshal()) + c.writeRecord(recordTypeHandshake, certReq.marshal()) + } + + helloDone := new(serverHelloDoneMsg) + hs.writeServerHash(helloDone.marshal()) + c.writeRecord(recordTypeHandshake, helloDone.marshal()) + + var pub crypto.PublicKey // public key for client auth, if any + + msg, err := c.readHandshake() + if err != nil { + return err + } + + var ok bool + // If we requested a client certificate, then the client must send a + // certificate message, even if it's empty. + if config.ClientAuth >= RequestClientCert { + var certMsg *certificateMsg + if certMsg, ok = msg.(*certificateMsg); !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certMsg, msg) + } + hs.writeClientHash(certMsg.marshal()) + + if len(certMsg.certificates) == 0 { + // The client didn't actually send a certificate + switch config.ClientAuth { + case RequireAnyClientCert, RequireAndVerifyClientCert: + c.sendAlert(alertBadCertificate) + return errors.New("tls: client didn't provide a certificate") + } + } + + pub, err = hs.processCertsFromClient(certMsg.certificates) + if err != nil { + return err + } + + msg, err = c.readHandshake() + if err != nil { + return err + } + } + + // Get client key exchange + ckx, ok := msg.(*clientKeyExchangeMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(ckx, msg) + } + hs.writeClientHash(ckx.marshal()) + + preMasterSecret, err := keyAgreement.processClientKeyExchange(config, hs.cert, ckx, c.vers) + if err != nil { + c.sendAlert(alertHandshakeFailure) + return err + } + if c.extendedMasterSecret { + hs.masterSecret = extendedMasterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.finishedHash) + } else { + if c.config.Bugs.RequireExtendedMasterSecret { + return errors.New("tls: extended master secret required but not supported by peer") + } + hs.masterSecret = masterFromPreMasterSecret(c.vers, hs.suite, preMasterSecret, hs.clientHello.random, hs.hello.random) + } + + // If we received a client cert in response to our certificate request message, + // the client will send us a certificateVerifyMsg immediately after the + // clientKeyExchangeMsg. This message is a digest of all preceding + // handshake-layer messages that is signed using the private key corresponding + // to the client's certificate. This allows us to verify that the client is in + // possession of the private key of the certificate. + if len(c.peerCertificates) > 0 { + msg, err = c.readHandshake() + if err != nil { + return err + } + certVerify, ok := msg.(*certificateVerifyMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(certVerify, msg) + } + + // Determine the signature type. + var signatureAndHash signatureAndHash + if certVerify.hasSignatureAndHash { + signatureAndHash = certVerify.signatureAndHash + if !isSupportedSignatureAndHash(signatureAndHash, config.signatureAndHashesForServer()) { + return errors.New("tls: unsupported hash function for client certificate") + } + } else { + // Before TLS 1.2 the signature algorithm was implicit + // from the key type, and only one hash per signature + // algorithm was possible. Leave the hash as zero. + switch pub.(type) { + case *ecdsa.PublicKey: + signatureAndHash.signature = signatureECDSA + case *rsa.PublicKey: + signatureAndHash.signature = signatureRSA + } + } + + switch key := pub.(type) { + case *ecdsa.PublicKey: + if signatureAndHash.signature != signatureECDSA { + err = errors.New("tls: bad signature type for client's ECDSA certificate") + break + } + ecdsaSig := new(ecdsaSignature) + if _, err = asn1.Unmarshal(certVerify.signature, ecdsaSig); err != nil { + break + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + err = errors.New("ECDSA signature contained zero or negative values") + break + } + var digest []byte + digest, _, err = hs.finishedHash.hashForClientCertificate(signatureAndHash, hs.masterSecret) + if err != nil { + break + } + if !ecdsa.Verify(key, digest, ecdsaSig.R, ecdsaSig.S) { + err = errors.New("ECDSA verification failure") + break + } + case *rsa.PublicKey: + if signatureAndHash.signature != signatureRSA { + err = errors.New("tls: bad signature type for client's RSA certificate") + break + } + var digest []byte + var hashFunc crypto.Hash + digest, hashFunc, err = hs.finishedHash.hashForClientCertificate(signatureAndHash, hs.masterSecret) + if err != nil { + break + } + err = rsa.VerifyPKCS1v15(key, hashFunc, digest, certVerify.signature) + } + if err != nil { + c.sendAlert(alertBadCertificate) + return errors.New("could not validate signature of connection nonces: " + err.Error()) + } + + hs.writeClientHash(certVerify.marshal()) + } + + hs.finishedHash.discardHandshakeBuffer() + + return nil +} + +func (hs *serverHandshakeState) establishKeys() error { + c := hs.c + + clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV := + keysFromMasterSecret(c.vers, hs.suite, hs.masterSecret, hs.clientHello.random, hs.hello.random, hs.suite.macLen, hs.suite.keyLen, hs.suite.ivLen) + + var clientCipher, serverCipher interface{} + var clientHash, serverHash macFunction + + if hs.suite.aead == nil { + clientCipher = hs.suite.cipher(clientKey, clientIV, true /* for reading */) + clientHash = hs.suite.mac(c.vers, clientMAC) + serverCipher = hs.suite.cipher(serverKey, serverIV, false /* not for reading */) + serverHash = hs.suite.mac(c.vers, serverMAC) + } else { + clientCipher = hs.suite.aead(clientKey, clientIV) + serverCipher = hs.suite.aead(serverKey, serverIV) + } + + c.in.prepareCipherSpec(c.vers, clientCipher, clientHash) + c.out.prepareCipherSpec(c.vers, serverCipher, serverHash) + + return nil +} + +func (hs *serverHandshakeState) readFinished(isResume bool) error { + c := hs.c + + c.readRecord(recordTypeChangeCipherSpec) + if err := c.in.error(); err != nil { + return err + } + + if hs.hello.nextProtoNeg { + msg, err := c.readHandshake() + if err != nil { + return err + } + nextProto, ok := msg.(*nextProtoMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(nextProto, msg) + } + hs.writeClientHash(nextProto.marshal()) + c.clientProtocol = nextProto.proto + } + + if hs.hello.channelIDRequested { + msg, err := c.readHandshake() + if err != nil { + return err + } + encryptedExtensions, ok := msg.(*encryptedExtensionsMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(encryptedExtensions, msg) + } + x := new(big.Int).SetBytes(encryptedExtensions.channelID[0:32]) + y := new(big.Int).SetBytes(encryptedExtensions.channelID[32:64]) + r := new(big.Int).SetBytes(encryptedExtensions.channelID[64:96]) + s := new(big.Int).SetBytes(encryptedExtensions.channelID[96:128]) + if !elliptic.P256().IsOnCurve(x, y) { + return errors.New("tls: invalid channel ID public key") + } + channelID := &ecdsa.PublicKey{elliptic.P256(), x, y} + var resumeHash []byte + if isResume { + resumeHash = hs.sessionState.handshakeHash + } + if !ecdsa.Verify(channelID, hs.finishedHash.hashForChannelID(resumeHash), r, s) { + return errors.New("tls: invalid channel ID signature") + } + c.channelID = channelID + + hs.writeClientHash(encryptedExtensions.marshal()) + } + + msg, err := c.readHandshake() + if err != nil { + return err + } + clientFinished, ok := msg.(*finishedMsg) + if !ok { + c.sendAlert(alertUnexpectedMessage) + return unexpectedMessageError(clientFinished, msg) + } + + verify := hs.finishedHash.clientSum(hs.masterSecret) + if len(verify) != len(clientFinished.verifyData) || + subtle.ConstantTimeCompare(verify, clientFinished.verifyData) != 1 { + c.sendAlert(alertHandshakeFailure) + return errors.New("tls: client's Finished message is incorrect") + } + c.clientVerify = append(c.clientVerify[:0], clientFinished.verifyData...) + + hs.writeClientHash(clientFinished.marshal()) + return nil +} + +func (hs *serverHandshakeState) sendSessionTicket() error { + c := hs.c + state := sessionState{ + vers: c.vers, + cipherSuite: hs.suite.id, + masterSecret: hs.masterSecret, + certificates: hs.certsFromClient, + handshakeHash: hs.finishedHash.server.Sum(nil), + } + + if !hs.hello.ticketSupported || hs.c.config.Bugs.SkipNewSessionTicket { + if c.config.ServerSessionCache != nil && len(hs.hello.sessionId) != 0 { + c.config.ServerSessionCache.Put(string(hs.hello.sessionId), &state) + } + return nil + } + + m := new(newSessionTicketMsg) + + var err error + m.ticket, err = c.encryptTicket(&state) + if err != nil { + return err + } + + hs.writeServerHash(m.marshal()) + c.writeRecord(recordTypeHandshake, m.marshal()) + + return nil +} + +func (hs *serverHandshakeState) sendFinished() error { + c := hs.c + + finished := new(finishedMsg) + finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret) + c.serverVerify = append(c.serverVerify[:0], finished.verifyData...) + postCCSBytes := finished.marshal() + hs.writeServerHash(postCCSBytes) + + if c.config.Bugs.FragmentAcrossChangeCipherSpec { + c.writeRecord(recordTypeHandshake, postCCSBytes[:5]) + postCCSBytes = postCCSBytes[5:] + } + + if !c.config.Bugs.SkipChangeCipherSpec { + c.writeRecord(recordTypeChangeCipherSpec, []byte{1}) + } + + if c.config.Bugs.AppDataAfterChangeCipherSpec != nil { + c.writeRecord(recordTypeApplicationData, c.config.Bugs.AppDataAfterChangeCipherSpec) + } + + c.writeRecord(recordTypeHandshake, postCCSBytes) + + c.cipherSuite = hs.suite.id + + return nil +} + +// processCertsFromClient takes a chain of client certificates either from a +// Certificates message or from a sessionState and verifies them. It returns +// the public key of the leaf certificate. +func (hs *serverHandshakeState) processCertsFromClient(certificates [][]byte) (crypto.PublicKey, error) { + c := hs.c + + hs.certsFromClient = certificates + certs := make([]*x509.Certificate, len(certificates)) + var err error + for i, asn1Data := range certificates { + if certs[i], err = x509.ParseCertificate(asn1Data); err != nil { + c.sendAlert(alertBadCertificate) + return nil, errors.New("tls: failed to parse client certificate: " + err.Error()) + } + } + + if c.config.ClientAuth >= VerifyClientCertIfGiven && len(certs) > 0 { + opts := x509.VerifyOptions{ + Roots: c.config.ClientCAs, + CurrentTime: c.config.time(), + Intermediates: x509.NewCertPool(), + KeyUsages: []x509.ExtKeyUsage{x509.ExtKeyUsageClientAuth}, + } + + for _, cert := range certs[1:] { + opts.Intermediates.AddCert(cert) + } + + chains, err := certs[0].Verify(opts) + if err != nil { + c.sendAlert(alertBadCertificate) + return nil, errors.New("tls: failed to verify client's certificate: " + err.Error()) + } + + ok := false + for _, ku := range certs[0].ExtKeyUsage { + if ku == x509.ExtKeyUsageClientAuth { + ok = true + break + } + } + if !ok { + c.sendAlert(alertHandshakeFailure) + return nil, errors.New("tls: client's certificate's extended key usage doesn't permit it to be used for client authentication") + } + + c.verifiedChains = chains + } + + if len(certs) > 0 { + var pub crypto.PublicKey + switch key := certs[0].PublicKey.(type) { + case *ecdsa.PublicKey, *rsa.PublicKey: + pub = key + default: + c.sendAlert(alertUnsupportedCertificate) + return nil, fmt.Errorf("tls: client's certificate contains an unsupported public key of type %T", certs[0].PublicKey) + } + c.peerCertificates = certs + return pub, nil + } + + return nil, nil +} + +func (hs *serverHandshakeState) writeServerHash(msg []byte) { + // writeServerHash is called before writeRecord. + hs.writeHash(msg, hs.c.sendHandshakeSeq) +} + +func (hs *serverHandshakeState) writeClientHash(msg []byte) { + // writeClientHash is called after readHandshake. + hs.writeHash(msg, hs.c.recvHandshakeSeq-1) +} + +func (hs *serverHandshakeState) writeHash(msg []byte, seqno uint16) { + if hs.c.isDTLS { + // This is somewhat hacky. DTLS hashes a slightly different format. + // First, the TLS header. + hs.finishedHash.Write(msg[:4]) + // Then the sequence number and reassembled fragment offset (always 0). + hs.finishedHash.Write([]byte{byte(seqno >> 8), byte(seqno), 0, 0, 0}) + // Then the reassembled fragment (always equal to the message length). + hs.finishedHash.Write(msg[1:4]) + // And then the message body. + hs.finishedHash.Write(msg[4:]) + } else { + hs.finishedHash.Write(msg) + } +} + +// tryCipherSuite returns a cipherSuite with the given id if that cipher suite +// is acceptable to use. +func (c *Conn) tryCipherSuite(id uint16, supportedCipherSuites []uint16, version uint16, ellipticOk, ecdsaOk bool) *cipherSuite { + for _, supported := range supportedCipherSuites { + if id == supported { + var candidate *cipherSuite + + for _, s := range cipherSuites { + if s.id == id { + candidate = s + break + } + } + if candidate == nil { + continue + } + // Don't select a ciphersuite which we can't + // support for this client. + if (candidate.flags&suiteECDHE != 0) && !ellipticOk { + continue + } + if (candidate.flags&suiteECDSA != 0) != ecdsaOk { + continue + } + if !c.config.Bugs.SkipCipherVersionCheck && version < VersionTLS12 && candidate.flags&suiteTLS12 != 0 { + continue + } + if c.isDTLS && candidate.flags&suiteNoDTLS != 0 { + continue + } + return candidate + } + } + + return nil +} diff --git a/src/ssl/test/runner/key.pem b/src/ssl/test/runner/key.pem new file mode 100644 index 0000000..e9107bf --- /dev/null +++ b/src/ssl/test/runner/key.pem @@ -0,0 +1,15 @@ +-----BEGIN RSA PRIVATE KEY----- +MIICXgIBAAKBgQDYK8imMuRi/03z0K1Zi0WnvfFHvwlYeyK9Na6XJYaUoIDAtB92 +kWdGMdAQhLciHnAjkXLI6W15OoV3gA/ElRZ1xUpxTMhjP6PyY5wqT5r6y8FxbiiF +KKAnHmUcrgfVW28tQ+0rkLGMryRtrukXOgXBv7gcrmU7G1jC2a7WqmeI8QIDAQAB +AoGBAIBy09Fd4DOq/Ijp8HeKuCMKTHqTW1xGHshLQ6jwVV2vWZIn9aIgmDsvkjCe +i6ssZvnbjVcwzSoByhjN8ZCf/i15HECWDFFh6gt0P5z0MnChwzZmvatV/FXCT0j+ +WmGNB/gkehKjGXLLcjTb6dRYVJSCZhVuOLLcbWIV10gggJQBAkEA8S8sGe4ezyyZ +m4e9r95g6s43kPqtj5rewTsUxt+2n4eVodD+ZUlCULWVNAFLkYRTBCASlSrm9Xhj +QpmWAHJUkQJBAOVzQdFUaewLtdOJoPCtpYoY1zd22eae8TQEmpGOR11L6kbxLQsk +aMly/DOnOaa82tqAGTdqDEZgSNmCeKKknmECQAvpnY8GUOVAubGR6c+W90iBuQLj +LtFp/9ihd2w/PoDwrHZaoUYVcT4VSfJQog/k7kjE4MYXYWL8eEKg3WTWQNECQQDk +104Wi91Umd1PzF0ijd2jXOERJU1wEKe6XLkYYNHWQAe5l4J4MWj9OdxFXAxIuuR/ +tfDwbqkta4xcux67//khAkEAvvRXLHTaa6VFzTaiiO8SaFsHV3lQyXOtMrBpB5jd +moZWgjHvB2W9Ckn7sDqsPB+U2tyX0joDdQEyuiMECDY8oQ== +-----END RSA PRIVATE KEY----- diff --git a/src/ssl/test/runner/key_agreement.go b/src/ssl/test/runner/key_agreement.go new file mode 100644 index 0000000..116dfd8 --- /dev/null +++ b/src/ssl/test/runner/key_agreement.go @@ -0,0 +1,776 @@ +// Copyright 2010 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "crypto" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/md5" + "crypto/rand" + "crypto/rsa" + "crypto/sha1" + "crypto/x509" + "encoding/asn1" + "errors" + "io" + "math/big" +) + +var errClientKeyExchange = errors.New("tls: invalid ClientKeyExchange message") +var errServerKeyExchange = errors.New("tls: invalid ServerKeyExchange message") + +// rsaKeyAgreement implements the standard TLS key agreement where the client +// encrypts the pre-master secret to the server's public key. +type rsaKeyAgreement struct { + clientVersion uint16 +} + +func (ka *rsaKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + // Save the client version for comparison later. + ka.clientVersion = versionToWire(clientHello.vers, clientHello.isDTLS) + + if config.Bugs.RSAServerKeyExchange { + // Send an empty ServerKeyExchange message. + return &serverKeyExchangeMsg{}, nil + } + + return nil, nil +} + +func (ka *rsaKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + preMasterSecret := make([]byte, 48) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, err + } + + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + + ciphertext := ckx.ciphertext + if version != VersionSSL30 { + ciphertextLen := int(ckx.ciphertext[0])<<8 | int(ckx.ciphertext[1]) + if ciphertextLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + ciphertext = ckx.ciphertext[2:] + } + + err = rsa.DecryptPKCS1v15SessionKey(config.rand(), cert.PrivateKey.(*rsa.PrivateKey), ciphertext, preMasterSecret) + if err != nil { + return nil, err + } + // This check should be done in constant-time, but this is a testing + // implementation. See the discussion at the end of section 7.4.7.1 of + // RFC 4346. + vers := uint16(preMasterSecret[0])<<8 | uint16(preMasterSecret[1]) + if ka.clientVersion != vers { + return nil, errors.New("tls: invalid version in RSA premaster") + } + return preMasterSecret, nil +} + +func (ka *rsaKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + return errors.New("tls: unexpected ServerKeyExchange") +} + +func (ka *rsaKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + preMasterSecret := make([]byte, 48) + vers := clientHello.vers + if config.Bugs.RsaClientKeyExchangeVersion != 0 { + vers = config.Bugs.RsaClientKeyExchangeVersion + } + vers = versionToWire(vers, clientHello.isDTLS) + preMasterSecret[0] = byte(vers >> 8) + preMasterSecret[1] = byte(vers) + _, err := io.ReadFull(config.rand(), preMasterSecret[2:]) + if err != nil { + return nil, nil, err + } + + encrypted, err := rsa.EncryptPKCS1v15(config.rand(), cert.PublicKey.(*rsa.PublicKey), preMasterSecret) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + if clientHello.vers != VersionSSL30 && !config.Bugs.SSL3RSAKeyExchange { + ckx.ciphertext = make([]byte, len(encrypted)+2) + ckx.ciphertext[0] = byte(len(encrypted) >> 8) + ckx.ciphertext[1] = byte(len(encrypted)) + copy(ckx.ciphertext[2:], encrypted) + } else { + ckx.ciphertext = encrypted + } + return preMasterSecret, ckx, nil +} + +// sha1Hash calculates a SHA1 hash over the given byte slices. +func sha1Hash(slices [][]byte) []byte { + hsha1 := sha1.New() + for _, slice := range slices { + hsha1.Write(slice) + } + return hsha1.Sum(nil) +} + +// md5SHA1Hash implements TLS 1.0's hybrid hash function which consists of the +// concatenation of an MD5 and SHA1 hash. +func md5SHA1Hash(slices [][]byte) []byte { + md5sha1 := make([]byte, md5.Size+sha1.Size) + hmd5 := md5.New() + for _, slice := range slices { + hmd5.Write(slice) + } + copy(md5sha1, hmd5.Sum(nil)) + copy(md5sha1[md5.Size:], sha1Hash(slices)) + return md5sha1 +} + +// hashForServerKeyExchange hashes the given slices and returns their digest +// and the identifier of the hash function used. The hashFunc argument is only +// used for >= TLS 1.2 and precisely identifies the hash function to use. +func hashForServerKeyExchange(sigType, hashFunc uint8, version uint16, slices ...[]byte) ([]byte, crypto.Hash, error) { + if version >= VersionTLS12 { + hash, err := lookupTLSHash(hashFunc) + if err != nil { + return nil, 0, err + } + h := hash.New() + for _, slice := range slices { + h.Write(slice) + } + return h.Sum(nil), hash, nil + } + if sigType == signatureECDSA { + return sha1Hash(slices), crypto.SHA1, nil + } + return md5SHA1Hash(slices), crypto.MD5SHA1, nil +} + +// pickTLS12HashForSignature returns a TLS 1.2 hash identifier for signing a +// ServerKeyExchange given the signature type being used and the client's +// advertized list of supported signature and hash combinations. +func pickTLS12HashForSignature(sigType uint8, clientSignatureAndHashes []signatureAndHash) (uint8, error) { + if len(clientSignatureAndHashes) == 0 { + // If the client didn't specify any signature_algorithms + // extension then we can assume that it supports SHA1. See + // http://tools.ietf.org/html/rfc5246#section-7.4.1.4.1 + return hashSHA1, nil + } + + for _, sigAndHash := range clientSignatureAndHashes { + if sigAndHash.signature != sigType { + continue + } + switch sigAndHash.hash { + case hashSHA1, hashSHA256: + return sigAndHash.hash, nil + } + } + + return 0, errors.New("tls: client doesn't support any common hash functions") +} + +func curveForCurveID(id CurveID) (elliptic.Curve, bool) { + switch id { + case CurveP256: + return elliptic.P256(), true + case CurveP384: + return elliptic.P384(), true + case CurveP521: + return elliptic.P521(), true + default: + return nil, false + } + +} + +// keyAgreementAuthentication is a helper interface that specifies how +// to authenticate the ServerKeyExchange parameters. +type keyAgreementAuthentication interface { + signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) + verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, params []byte, sig []byte) error +} + +// nilKeyAgreementAuthentication does not authenticate the key +// agreement parameters. +type nilKeyAgreementAuthentication struct{} + +func (ka *nilKeyAgreementAuthentication) signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) { + skx := new(serverKeyExchangeMsg) + skx.key = params + return skx, nil +} + +func (ka *nilKeyAgreementAuthentication) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, params []byte, sig []byte) error { + return nil +} + +// signedKeyAgreement signs the ServerKeyExchange parameters with the +// server's private key. +type signedKeyAgreement struct { + version uint16 + sigType uint8 +} + +func (ka *signedKeyAgreement) signParameters(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg, params []byte) (*serverKeyExchangeMsg, error) { + var tls12HashId uint8 + var err error + if ka.version >= VersionTLS12 { + if tls12HashId, err = pickTLS12HashForSignature(ka.sigType, clientHello.signatureAndHashes); err != nil { + return nil, err + } + } + + digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, tls12HashId, ka.version, clientHello.random, hello.random, params) + if err != nil { + return nil, err + } + + if config.Bugs.InvalidSKXSignature { + digest[0] ^= 0x80 + } + + var sig []byte + switch ka.sigType { + case signatureECDSA: + privKey, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + return nil, errors.New("ECDHE ECDSA requires an ECDSA server private key") + } + r, s, err := ecdsa.Sign(config.rand(), privKey, digest) + if err != nil { + return nil, errors.New("failed to sign ECDHE parameters: " + err.Error()) + } + order := privKey.Curve.Params().N + r = maybeCorruptECDSAValue(r, config.Bugs.BadECDSAR, order) + s = maybeCorruptECDSAValue(s, config.Bugs.BadECDSAS, order) + sig, err = asn1.Marshal(ecdsaSignature{r, s}) + case signatureRSA: + privKey, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + return nil, errors.New("ECDHE RSA requires a RSA server private key") + } + sig, err = rsa.SignPKCS1v15(config.rand(), privKey, hashFunc, digest) + if err != nil { + return nil, errors.New("failed to sign ECDHE parameters: " + err.Error()) + } + default: + return nil, errors.New("unknown ECDHE signature algorithm") + } + + skx := new(serverKeyExchangeMsg) + if config.Bugs.UnauthenticatedECDH { + skx.key = params + } else { + sigAndHashLen := 0 + if ka.version >= VersionTLS12 { + sigAndHashLen = 2 + } + skx.key = make([]byte, len(params)+sigAndHashLen+2+len(sig)) + copy(skx.key, params) + k := skx.key[len(params):] + if ka.version >= VersionTLS12 { + k[0] = tls12HashId + k[1] = ka.sigType + k = k[2:] + } + k[0] = byte(len(sig) >> 8) + k[1] = byte(len(sig)) + copy(k[2:], sig) + } + + return skx, nil +} + +func (ka *signedKeyAgreement) verifyParameters(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, params []byte, sig []byte) error { + if len(sig) < 2 { + return errServerKeyExchange + } + + var tls12HashId uint8 + if ka.version >= VersionTLS12 { + // handle SignatureAndHashAlgorithm + var sigAndHash []uint8 + sigAndHash, sig = sig[:2], sig[2:] + if sigAndHash[1] != ka.sigType { + return errServerKeyExchange + } + tls12HashId = sigAndHash[0] + if len(sig) < 2 { + return errServerKeyExchange + } + + if !isSupportedSignatureAndHash(signatureAndHash{ka.sigType, tls12HashId}, config.signatureAndHashesForClient()) { + return errors.New("tls: unsupported hash function for ServerKeyExchange") + } + } + sigLen := int(sig[0])<<8 | int(sig[1]) + if sigLen+2 != len(sig) { + return errServerKeyExchange + } + sig = sig[2:] + + digest, hashFunc, err := hashForServerKeyExchange(ka.sigType, tls12HashId, ka.version, clientHello.random, serverHello.random, params) + if err != nil { + return err + } + switch ka.sigType { + case signatureECDSA: + pubKey, ok := cert.PublicKey.(*ecdsa.PublicKey) + if !ok { + return errors.New("ECDHE ECDSA requires a ECDSA server public key") + } + ecdsaSig := new(ecdsaSignature) + if _, err := asn1.Unmarshal(sig, ecdsaSig); err != nil { + return err + } + if ecdsaSig.R.Sign() <= 0 || ecdsaSig.S.Sign() <= 0 { + return errors.New("ECDSA signature contained zero or negative values") + } + if !ecdsa.Verify(pubKey, digest, ecdsaSig.R, ecdsaSig.S) { + return errors.New("ECDSA verification failure") + } + case signatureRSA: + pubKey, ok := cert.PublicKey.(*rsa.PublicKey) + if !ok { + return errors.New("ECDHE RSA requires a RSA server public key") + } + if err := rsa.VerifyPKCS1v15(pubKey, hashFunc, digest, sig); err != nil { + return err + } + default: + return errors.New("unknown ECDHE signature algorithm") + } + + return nil +} + +// ecdheRSAKeyAgreement implements a TLS key agreement where the server +// generates a ephemeral EC public/private key pair and signs it. The +// pre-master secret is then calculated using ECDH. The signature may +// either be ECDSA or RSA. +type ecdheKeyAgreement struct { + auth keyAgreementAuthentication + privateKey []byte + curve elliptic.Curve + x, y *big.Int +} + +func maybeCorruptECDSAValue(n *big.Int, typeOfCorruption BadValue, limit *big.Int) *big.Int { + switch typeOfCorruption { + case BadValueNone: + return n + case BadValueNegative: + return new(big.Int).Neg(n) + case BadValueZero: + return big.NewInt(0) + case BadValueLimit: + return limit + case BadValueLarge: + bad := new(big.Int).Set(limit) + return bad.Lsh(bad, 20) + default: + panic("unknown BadValue type") + } +} + +func (ka *ecdheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + var curveid CurveID + preferredCurves := config.curvePreferences() + +NextCandidate: + for _, candidate := range preferredCurves { + for _, c := range clientHello.supportedCurves { + if candidate == c { + curveid = c + break NextCandidate + } + } + } + + if curveid == 0 { + return nil, errors.New("tls: no supported elliptic curves offered") + } + + var ok bool + if ka.curve, ok = curveForCurveID(curveid); !ok { + return nil, errors.New("tls: preferredCurves includes unsupported curve") + } + + var x, y *big.Int + var err error + ka.privateKey, x, y, err = elliptic.GenerateKey(ka.curve, config.rand()) + if err != nil { + return nil, err + } + ecdhePublic := elliptic.Marshal(ka.curve, x, y) + + // http://tools.ietf.org/html/rfc4492#section-5.4 + serverECDHParams := make([]byte, 1+2+1+len(ecdhePublic)) + serverECDHParams[0] = 3 // named curve + serverECDHParams[1] = byte(curveid >> 8) + serverECDHParams[2] = byte(curveid) + if config.Bugs.InvalidSKXCurve { + serverECDHParams[2] ^= 0xff + } + serverECDHParams[3] = byte(len(ecdhePublic)) + copy(serverECDHParams[4:], ecdhePublic) + + return ka.auth.signParameters(config, cert, clientHello, hello, serverECDHParams) +} + +func (ka *ecdheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) == 0 || int(ckx.ciphertext[0]) != len(ckx.ciphertext)-1 { + return nil, errClientKeyExchange + } + x, y := elliptic.Unmarshal(ka.curve, ckx.ciphertext[1:]) + if x == nil { + return nil, errClientKeyExchange + } + x, _ = ka.curve.ScalarMult(x, y, ka.privateKey) + preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3) + xBytes := x.Bytes() + copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) + + return preMasterSecret, nil +} + +func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 4 { + return errServerKeyExchange + } + if skx.key[0] != 3 { // named curve + return errors.New("tls: server selected unsupported curve") + } + curveid := CurveID(skx.key[1])<<8 | CurveID(skx.key[2]) + + var ok bool + if ka.curve, ok = curveForCurveID(curveid); !ok { + return errors.New("tls: server selected unsupported curve") + } + + publicLen := int(skx.key[3]) + if publicLen+4 > len(skx.key) { + return errServerKeyExchange + } + ka.x, ka.y = elliptic.Unmarshal(ka.curve, skx.key[4:4+publicLen]) + if ka.x == nil { + return errServerKeyExchange + } + serverECDHParams := skx.key[:4+publicLen] + sig := skx.key[4+publicLen:] + + return ka.auth.verifyParameters(config, clientHello, serverHello, cert, serverECDHParams, sig) +} + +func (ka *ecdheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.curve == nil { + return nil, nil, errors.New("missing ServerKeyExchange message") + } + priv, mx, my, err := elliptic.GenerateKey(ka.curve, config.rand()) + if err != nil { + return nil, nil, err + } + x, _ := ka.curve.ScalarMult(ka.x, ka.y, priv) + preMasterSecret := make([]byte, (ka.curve.Params().BitSize+7)>>3) + xBytes := x.Bytes() + copy(preMasterSecret[len(preMasterSecret)-len(xBytes):], xBytes) + + serialized := elliptic.Marshal(ka.curve, mx, my) + + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, 1+len(serialized)) + ckx.ciphertext[0] = byte(len(serialized)) + copy(ckx.ciphertext[1:], serialized) + + return preMasterSecret, ckx, nil +} + +// dheRSAKeyAgreement implements a TLS key agreement where the server generates +// an ephemeral Diffie-Hellman public/private key pair and signs it. The +// pre-master secret is then calculated using Diffie-Hellman. +type dheKeyAgreement struct { + auth keyAgreementAuthentication + p, g *big.Int + yTheirs *big.Int + xOurs *big.Int +} + +func (ka *dheKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + // 2048-bit MODP Group with 256-bit Prime Order Subgroup (RFC + // 5114, Section 2.3) + ka.p, _ = new(big.Int).SetString("87A8E61DB4B6663CFFBBD19C651959998CEEF608660DD0F25D2CEED4435E3B00E00DF8F1D61957D4FAF7DF4561B2AA3016C3D91134096FAA3BF4296D830E9A7C209E0C6497517ABD5A8A9D306BCF67ED91F9E6725B4758C022E0B1EF4275BF7B6C5BFC11D45F9088B941F54EB1E59BB8BC39A0BF12307F5C4FDB70C581B23F76B63ACAE1CAA6B7902D52526735488A0EF13C6D9A51BFA4AB3AD8347796524D8EF6A167B5A41825D967E144E5140564251CCACB83E6B486F6B3CA3F7971506026C0B857F689962856DED4010ABD0BE621C3A3960A54E710C375F26375D7014103A4B54330C198AF126116D2276E11715F693877FAD7EF09CADB094AE91E1A1597", 16) + ka.g, _ = new(big.Int).SetString("3FB32C9B73134D0B2E77506660EDBD484CA7B18F21EF205407F4793A1A0BA12510DBC15077BE463FFF4FED4AAC0BB555BE3A6C1B0C6B47B1BC3773BF7E8C6F62901228F8C28CBB18A55AE31341000A650196F931C77A57F2DDF463E5E9EC144B777DE62AAAB8A8628AC376D282D6ED3864E67982428EBC831D14348F6F2F9193B5045AF2767164E1DFC967C1FB3F2E55A4BD1BFFE83B9C80D052B985D182EA0ADB2A3B7313D3FE14C8484B1E052588B9B7D2BBD2DF016199ECD06E1557CD0915B3353BBB64E0EC377FD028370DF92B52C7891428CDC67EB6184B523D1DB246C32F63078490F00EF8D647D148D47954515E2327CFEF98C582664B4C0F6CC41659", 16) + q, _ := new(big.Int).SetString("8CF83642A709A097B447997640129DA299B1A47D1EB3750BA308B0FE64F5FBD3", 16) + + var err error + ka.xOurs, err = rand.Int(config.rand(), q) + if err != nil { + return nil, err + } + yOurs := new(big.Int).Exp(ka.g, ka.xOurs, ka.p) + + // http://tools.ietf.org/html/rfc5246#section-7.4.3 + pBytes := ka.p.Bytes() + gBytes := ka.g.Bytes() + yBytes := yOurs.Bytes() + serverDHParams := make([]byte, 0, 2+len(pBytes)+2+len(gBytes)+2+len(yBytes)) + serverDHParams = append(serverDHParams, byte(len(pBytes)>>8), byte(len(pBytes))) + serverDHParams = append(serverDHParams, pBytes...) + serverDHParams = append(serverDHParams, byte(len(gBytes)>>8), byte(len(gBytes))) + serverDHParams = append(serverDHParams, gBytes...) + serverDHParams = append(serverDHParams, byte(len(yBytes)>>8), byte(len(yBytes))) + serverDHParams = append(serverDHParams, yBytes...) + + return ka.auth.signParameters(config, cert, clientHello, hello, serverDHParams) +} + +func (ka *dheKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + yLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1]) + if yLen != len(ckx.ciphertext)-2 { + return nil, errClientKeyExchange + } + yTheirs := new(big.Int).SetBytes(ckx.ciphertext[2:]) + if yTheirs.Sign() <= 0 || yTheirs.Cmp(ka.p) >= 0 { + return nil, errClientKeyExchange + } + return new(big.Int).Exp(yTheirs, ka.xOurs, ka.p).Bytes(), nil +} + +func (ka *dheKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + // Read dh_p + k := skx.key + if len(k) < 2 { + return errServerKeyExchange + } + pLen := (int(k[0]) << 8) | int(k[1]) + k = k[2:] + if len(k) < pLen { + return errServerKeyExchange + } + ka.p = new(big.Int).SetBytes(k[:pLen]) + k = k[pLen:] + + // Read dh_g + if len(k) < 2 { + return errServerKeyExchange + } + gLen := (int(k[0]) << 8) | int(k[1]) + k = k[2:] + if len(k) < gLen { + return errServerKeyExchange + } + ka.g = new(big.Int).SetBytes(k[:gLen]) + k = k[gLen:] + + // Read dh_Ys + if len(k) < 2 { + return errServerKeyExchange + } + yLen := (int(k[0]) << 8) | int(k[1]) + k = k[2:] + if len(k) < yLen { + return errServerKeyExchange + } + ka.yTheirs = new(big.Int).SetBytes(k[:yLen]) + k = k[yLen:] + if ka.yTheirs.Sign() <= 0 || ka.yTheirs.Cmp(ka.p) >= 0 { + return errServerKeyExchange + } + + sig := k + serverDHParams := skx.key[:len(skx.key)-len(sig)] + + return ka.auth.verifyParameters(config, clientHello, serverHello, cert, serverDHParams, sig) +} + +func (ka *dheKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + if ka.p == nil || ka.g == nil || ka.yTheirs == nil { + return nil, nil, errors.New("missing ServerKeyExchange message") + } + + xOurs, err := rand.Int(config.rand(), ka.p) + if err != nil { + return nil, nil, err + } + preMasterSecret := new(big.Int).Exp(ka.yTheirs, xOurs, ka.p).Bytes() + + yOurs := new(big.Int).Exp(ka.g, xOurs, ka.p) + yBytes := yOurs.Bytes() + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = make([]byte, 2+len(yBytes)) + ckx.ciphertext[0] = byte(len(yBytes) >> 8) + ckx.ciphertext[1] = byte(len(yBytes)) + copy(ckx.ciphertext[2:], yBytes) + + return preMasterSecret, ckx, nil +} + +// nilKeyAgreement is a fake key agreement used to implement the plain PSK key +// exchange. +type nilKeyAgreement struct{} + +func (ka *nilKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + return nil, nil +} + +func (ka *nilKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + if len(ckx.ciphertext) != 0 { + return nil, errClientKeyExchange + } + + // Although in plain PSK, otherSecret is all zeros, the base key + // agreement does not access to the length of the pre-shared + // key. pskKeyAgreement instead interprets nil to mean to use all zeros + // of the appropriate length. + return nil, nil +} + +func (ka *nilKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) != 0 { + return errServerKeyExchange + } + return nil +} + +func (ka *nilKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + // Although in plain PSK, otherSecret is all zeros, the base key + // agreement does not access to the length of the pre-shared + // key. pskKeyAgreement instead interprets nil to mean to use all zeros + // of the appropriate length. + return nil, &clientKeyExchangeMsg{}, nil +} + +// makePSKPremaster formats a PSK pre-master secret based on otherSecret from +// the base key exchange and psk. +func makePSKPremaster(otherSecret, psk []byte) []byte { + out := make([]byte, 0, 2+len(otherSecret)+2+len(psk)) + out = append(out, byte(len(otherSecret)>>8), byte(len(otherSecret))) + out = append(out, otherSecret...) + out = append(out, byte(len(psk)>>8), byte(len(psk))) + out = append(out, psk...) + return out +} + +// pskKeyAgreement implements the PSK key agreement. +type pskKeyAgreement struct { + base keyAgreement + identityHint string +} + +func (ka *pskKeyAgreement) generateServerKeyExchange(config *Config, cert *Certificate, clientHello *clientHelloMsg, hello *serverHelloMsg) (*serverKeyExchangeMsg, error) { + // Assemble the identity hint. + bytes := make([]byte, 2+len(config.PreSharedKeyIdentity)) + bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8) + bytes[1] = byte(len(config.PreSharedKeyIdentity)) + copy(bytes[2:], []byte(config.PreSharedKeyIdentity)) + + // If there is one, append the base key agreement's + // ServerKeyExchange. + baseSkx, err := ka.base.generateServerKeyExchange(config, cert, clientHello, hello) + if err != nil { + return nil, err + } + + if baseSkx != nil { + bytes = append(bytes, baseSkx.key...) + } else if config.PreSharedKeyIdentity == "" { + // ServerKeyExchange is optional if the identity hint is empty + // and there would otherwise be no ServerKeyExchange. + return nil, nil + } + + skx := new(serverKeyExchangeMsg) + skx.key = bytes + return skx, nil +} + +func (ka *pskKeyAgreement) processClientKeyExchange(config *Config, cert *Certificate, ckx *clientKeyExchangeMsg, version uint16) ([]byte, error) { + // First, process the PSK identity. + if len(ckx.ciphertext) < 2 { + return nil, errClientKeyExchange + } + identityLen := (int(ckx.ciphertext[0]) << 8) | int(ckx.ciphertext[1]) + if 2+identityLen > len(ckx.ciphertext) { + return nil, errClientKeyExchange + } + identity := string(ckx.ciphertext[2 : 2+identityLen]) + + if identity != config.PreSharedKeyIdentity { + return nil, errors.New("tls: unexpected identity") + } + + if config.PreSharedKey == nil { + return nil, errors.New("tls: pre-shared key not configured") + } + + // Process the remainder of the ClientKeyExchange to compute the base + // pre-master secret. + newCkx := new(clientKeyExchangeMsg) + newCkx.ciphertext = ckx.ciphertext[2+identityLen:] + otherSecret, err := ka.base.processClientKeyExchange(config, cert, newCkx, version) + if err != nil { + return nil, err + } + + if otherSecret == nil { + // Special-case for the plain PSK key exchanges. + otherSecret = make([]byte, len(config.PreSharedKey)) + } + return makePSKPremaster(otherSecret, config.PreSharedKey), nil +} + +func (ka *pskKeyAgreement) processServerKeyExchange(config *Config, clientHello *clientHelloMsg, serverHello *serverHelloMsg, cert *x509.Certificate, skx *serverKeyExchangeMsg) error { + if len(skx.key) < 2 { + return errServerKeyExchange + } + identityLen := (int(skx.key[0]) << 8) | int(skx.key[1]) + if 2+identityLen > len(skx.key) { + return errServerKeyExchange + } + ka.identityHint = string(skx.key[2 : 2+identityLen]) + + // Process the remainder of the ServerKeyExchange. + newSkx := new(serverKeyExchangeMsg) + newSkx.key = skx.key[2+identityLen:] + return ka.base.processServerKeyExchange(config, clientHello, serverHello, cert, newSkx) +} + +func (ka *pskKeyAgreement) generateClientKeyExchange(config *Config, clientHello *clientHelloMsg, cert *x509.Certificate) ([]byte, *clientKeyExchangeMsg, error) { + // The server only sends an identity hint but, for purposes of + // test code, the server always sends the hint and it is + // required to match. + if ka.identityHint != config.PreSharedKeyIdentity { + return nil, nil, errors.New("tls: unexpected identity") + } + + // Serialize the identity. + bytes := make([]byte, 2+len(config.PreSharedKeyIdentity)) + bytes[0] = byte(len(config.PreSharedKeyIdentity) >> 8) + bytes[1] = byte(len(config.PreSharedKeyIdentity)) + copy(bytes[2:], []byte(config.PreSharedKeyIdentity)) + + // Append the base key exchange's ClientKeyExchange. + otherSecret, baseCkx, err := ka.base.generateClientKeyExchange(config, clientHello, cert) + if err != nil { + return nil, nil, err + } + ckx := new(clientKeyExchangeMsg) + ckx.ciphertext = append(bytes, baseCkx.ciphertext...) + + if config.PreSharedKey == nil { + return nil, nil, errors.New("tls: pre-shared key not configured") + } + if otherSecret == nil { + otherSecret = make([]byte, len(config.PreSharedKey)) + } + return makePSKPremaster(otherSecret, config.PreSharedKey), ckx, nil +} diff --git a/src/ssl/test/runner/packet_adapter.go b/src/ssl/test/runner/packet_adapter.go new file mode 100644 index 0000000..671b413 --- /dev/null +++ b/src/ssl/test/runner/packet_adapter.go @@ -0,0 +1,101 @@ +// Copyright 2014 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "encoding/binary" + "errors" + "net" +) + +type packetAdaptor struct { + net.Conn +} + +// newPacketAdaptor wraps a reliable streaming net.Conn into a +// reliable packet-based net.Conn. Every packet is encoded with a +// 32-bit length prefix as a framing layer. +func newPacketAdaptor(conn net.Conn) net.Conn { + return &packetAdaptor{conn} +} + +func (p *packetAdaptor) Read(b []byte) (int, error) { + var length uint32 + if err := binary.Read(p.Conn, binary.BigEndian, &length); err != nil { + return 0, err + } + out := make([]byte, length) + n, err := p.Conn.Read(out) + if err != nil { + return 0, err + } + if n != int(length) { + return 0, errors.New("internal error: length mismatch!") + } + return copy(b, out), nil +} + +func (p *packetAdaptor) Write(b []byte) (int, error) { + length := uint32(len(b)) + if err := binary.Write(p.Conn, binary.BigEndian, length); err != nil { + return 0, err + } + n, err := p.Conn.Write(b) + if err != nil { + return 0, err + } + if n != len(b) { + return 0, errors.New("internal error: length mismatch!") + } + return len(b), nil +} + +type replayAdaptor struct { + net.Conn + prevWrite []byte +} + +// newReplayAdaptor wraps a packeted net.Conn. It transforms it into +// one which, after writing a packet, always replays the previous +// write. +func newReplayAdaptor(conn net.Conn) net.Conn { + return &replayAdaptor{Conn: conn} +} + +func (r *replayAdaptor) Write(b []byte) (int, error) { + n, err := r.Conn.Write(b) + + // Replay the previous packet and save the current one to + // replay next. + if r.prevWrite != nil { + r.Conn.Write(r.prevWrite) + } + r.prevWrite = append(r.prevWrite[:0], b...) + + return n, err +} + +type damageAdaptor struct { + net.Conn + damage bool +} + +// newDamageAdaptor wraps a packeted net.Conn. It transforms it into one which +// optionally damages the final byte of every Write() call. +func newDamageAdaptor(conn net.Conn) *damageAdaptor { + return &damageAdaptor{Conn: conn} +} + +func (d *damageAdaptor) setDamage(damage bool) { + d.damage = damage +} + +func (d *damageAdaptor) Write(b []byte) (int, error) { + if d.damage && len(b) > 0 { + b = append([]byte{}, b...) + b[len(b)-1]++ + } + return d.Conn.Write(b) +} diff --git a/src/ssl/test/runner/prf.go b/src/ssl/test/runner/prf.go new file mode 100644 index 0000000..75a8933 --- /dev/null +++ b/src/ssl/test/runner/prf.go @@ -0,0 +1,388 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "crypto" + "crypto/hmac" + "crypto/md5" + "crypto/sha1" + "crypto/sha256" + "crypto/sha512" + "errors" + "hash" +) + +// Split a premaster secret in two as specified in RFC 4346, section 5. +func splitPreMasterSecret(secret []byte) (s1, s2 []byte) { + s1 = secret[0 : (len(secret)+1)/2] + s2 = secret[len(secret)/2:] + return +} + +// pHash implements the P_hash function, as defined in RFC 4346, section 5. +func pHash(result, secret, seed []byte, hash func() hash.Hash) { + h := hmac.New(hash, secret) + h.Write(seed) + a := h.Sum(nil) + + j := 0 + for j < len(result) { + h.Reset() + h.Write(a) + h.Write(seed) + b := h.Sum(nil) + todo := len(b) + if j+todo > len(result) { + todo = len(result) - j + } + copy(result[j:j+todo], b) + j += todo + + h.Reset() + h.Write(a) + a = h.Sum(nil) + } +} + +// prf10 implements the TLS 1.0 pseudo-random function, as defined in RFC 2246, section 5. +func prf10(result, secret, label, seed []byte) { + hashSHA1 := sha1.New + hashMD5 := md5.New + + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + s1, s2 := splitPreMasterSecret(secret) + pHash(result, s1, labelAndSeed, hashMD5) + result2 := make([]byte, len(result)) + pHash(result2, s2, labelAndSeed, hashSHA1) + + for i, b := range result2 { + result[i] ^= b + } +} + +// prf12 implements the TLS 1.2 pseudo-random function, as defined in RFC 5246, section 5. +func prf12(hashFunc func() hash.Hash) func(result, secret, label, seed []byte) { + return func(result, secret, label, seed []byte) { + labelAndSeed := make([]byte, len(label)+len(seed)) + copy(labelAndSeed, label) + copy(labelAndSeed[len(label):], seed) + + pHash(result, secret, labelAndSeed, hashFunc) + } +} + +// prf30 implements the SSL 3.0 pseudo-random function, as defined in +// www.mozilla.org/projects/security/pki/nss/ssl/draft302.txt section 6. +func prf30(result, secret, label, seed []byte) { + hashSHA1 := sha1.New() + hashMD5 := md5.New() + + done := 0 + i := 0 + // RFC5246 section 6.3 says that the largest PRF output needed is 128 + // bytes. Since no more ciphersuites will be added to SSLv3, this will + // remain true. Each iteration gives us 16 bytes so 10 iterations will + // be sufficient. + var b [11]byte + for done < len(result) { + for j := 0; j <= i; j++ { + b[j] = 'A' + byte(i) + } + + hashSHA1.Reset() + hashSHA1.Write(b[:i+1]) + hashSHA1.Write(secret) + hashSHA1.Write(seed) + digest := hashSHA1.Sum(nil) + + hashMD5.Reset() + hashMD5.Write(secret) + hashMD5.Write(digest) + + done += copy(result[done:], hashMD5.Sum(nil)) + i++ + } +} + +const ( + tlsRandomLength = 32 // Length of a random nonce in TLS 1.1. + masterSecretLength = 48 // Length of a master secret in TLS 1.1. + finishedVerifyLength = 12 // Length of verify_data in a Finished message. +) + +var masterSecretLabel = []byte("master secret") +var extendedMasterSecretLabel = []byte("extended master secret") +var keyExpansionLabel = []byte("key expansion") +var clientFinishedLabel = []byte("client finished") +var serverFinishedLabel = []byte("server finished") +var channelIDLabel = []byte("TLS Channel ID signature\x00") +var channelIDResumeLabel = []byte("Resumption\x00") + +func prfForVersion(version uint16, suite *cipherSuite) func(result, secret, label, seed []byte) { + switch version { + case VersionSSL30: + return prf30 + case VersionTLS10, VersionTLS11: + return prf10 + case VersionTLS12: + if suite.flags&suiteSHA384 != 0 { + return prf12(sha512.New384) + } + return prf12(sha256.New) + default: + panic("unknown version") + } +} + +// masterFromPreMasterSecret generates the master secret from the pre-master +// secret. See http://tools.ietf.org/html/rfc5246#section-8.1 +func masterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret, clientRandom, serverRandom []byte) []byte { + var seed [tlsRandomLength * 2]byte + copy(seed[0:len(clientRandom)], clientRandom) + copy(seed[len(clientRandom):], serverRandom) + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, masterSecretLabel, seed[0:]) + return masterSecret +} + +// extendedMasterFromPreMasterSecret generates the master secret from the +// pre-master secret when the Triple Handshake fix is in effect. See +// https://tools.ietf.org/html/draft-ietf-tls-session-hash-01 +func extendedMasterFromPreMasterSecret(version uint16, suite *cipherSuite, preMasterSecret []byte, h finishedHash) []byte { + masterSecret := make([]byte, masterSecretLength) + prfForVersion(version, suite)(masterSecret, preMasterSecret, extendedMasterSecretLabel, h.Sum()) + return masterSecret +} + +// keysFromMasterSecret generates the connection keys from the master +// secret, given the lengths of the MAC key, cipher key and IV, as defined in +// RFC 2246, section 6.3. +func keysFromMasterSecret(version uint16, suite *cipherSuite, masterSecret, clientRandom, serverRandom []byte, macLen, keyLen, ivLen int) (clientMAC, serverMAC, clientKey, serverKey, clientIV, serverIV []byte) { + var seed [tlsRandomLength * 2]byte + copy(seed[0:len(clientRandom)], serverRandom) + copy(seed[len(serverRandom):], clientRandom) + + n := 2*macLen + 2*keyLen + 2*ivLen + keyMaterial := make([]byte, n) + prfForVersion(version, suite)(keyMaterial, masterSecret, keyExpansionLabel, seed[0:]) + clientMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + serverMAC = keyMaterial[:macLen] + keyMaterial = keyMaterial[macLen:] + clientKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + serverKey = keyMaterial[:keyLen] + keyMaterial = keyMaterial[keyLen:] + clientIV = keyMaterial[:ivLen] + keyMaterial = keyMaterial[ivLen:] + serverIV = keyMaterial[:ivLen] + return +} + +// lookupTLSHash looks up the corresponding crypto.Hash for a given +// TLS hash identifier. +func lookupTLSHash(hash uint8) (crypto.Hash, error) { + switch hash { + case hashMD5: + return crypto.MD5, nil + case hashSHA1: + return crypto.SHA1, nil + case hashSHA224: + return crypto.SHA224, nil + case hashSHA256: + return crypto.SHA256, nil + case hashSHA384: + return crypto.SHA384, nil + case hashSHA512: + return crypto.SHA512, nil + default: + return 0, errors.New("tls: unsupported hash algorithm") + } +} + +func newFinishedHash(version uint16, cipherSuite *cipherSuite) finishedHash { + if version >= VersionTLS12 { + newHash := sha256.New + if cipherSuite.flags&suiteSHA384 != 0 { + newHash = sha512.New384 + } + + return finishedHash{newHash(), newHash(), nil, nil, []byte{}, version, prf12(newHash)} + } + return finishedHash{sha1.New(), sha1.New(), md5.New(), md5.New(), []byte{}, version, prf10} +} + +// A finishedHash calculates the hash of a set of handshake messages suitable +// for including in a Finished message. +type finishedHash struct { + client hash.Hash + server hash.Hash + + // Prior to TLS 1.2, an additional MD5 hash is required. + clientMD5 hash.Hash + serverMD5 hash.Hash + + // In TLS 1.2 (and SSL 3 for implementation convenience), a + // full buffer is required. + buffer []byte + + version uint16 + prf func(result, secret, label, seed []byte) +} + +func (h *finishedHash) Write(msg []byte) (n int, err error) { + h.client.Write(msg) + h.server.Write(msg) + + if h.version < VersionTLS12 { + h.clientMD5.Write(msg) + h.serverMD5.Write(msg) + } + + if h.buffer != nil { + h.buffer = append(h.buffer, msg...) + } + + return len(msg), nil +} + +func (h finishedHash) Sum() []byte { + if h.version >= VersionTLS12 { + return h.client.Sum(nil) + } + + out := make([]byte, 0, md5.Size+sha1.Size) + out = h.clientMD5.Sum(out) + return h.client.Sum(out) +} + +// finishedSum30 calculates the contents of the verify_data member of a SSLv3 +// Finished message given the MD5 and SHA1 hashes of a set of handshake +// messages. +func finishedSum30(md5, sha1 hash.Hash, masterSecret []byte, magic []byte) []byte { + md5.Write(magic) + md5.Write(masterSecret) + md5.Write(ssl30Pad1[:]) + md5Digest := md5.Sum(nil) + + md5.Reset() + md5.Write(masterSecret) + md5.Write(ssl30Pad2[:]) + md5.Write(md5Digest) + md5Digest = md5.Sum(nil) + + sha1.Write(magic) + sha1.Write(masterSecret) + sha1.Write(ssl30Pad1[:40]) + sha1Digest := sha1.Sum(nil) + + sha1.Reset() + sha1.Write(masterSecret) + sha1.Write(ssl30Pad2[:40]) + sha1.Write(sha1Digest) + sha1Digest = sha1.Sum(nil) + + ret := make([]byte, len(md5Digest)+len(sha1Digest)) + copy(ret, md5Digest) + copy(ret[len(md5Digest):], sha1Digest) + return ret +} + +var ssl3ClientFinishedMagic = [4]byte{0x43, 0x4c, 0x4e, 0x54} +var ssl3ServerFinishedMagic = [4]byte{0x53, 0x52, 0x56, 0x52} + +// clientSum returns the contents of the verify_data member of a client's +// Finished message. +func (h finishedHash) clientSum(masterSecret []byte) []byte { + if h.version == VersionSSL30 { + return finishedSum30(h.clientMD5, h.client, masterSecret, ssl3ClientFinishedMagic[:]) + } + + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, clientFinishedLabel, h.Sum()) + return out +} + +// serverSum returns the contents of the verify_data member of a server's +// Finished message. +func (h finishedHash) serverSum(masterSecret []byte) []byte { + if h.version == VersionSSL30 { + return finishedSum30(h.serverMD5, h.server, masterSecret, ssl3ServerFinishedMagic[:]) + } + + out := make([]byte, finishedVerifyLength) + h.prf(out, masterSecret, serverFinishedLabel, h.Sum()) + return out +} + +// selectClientCertSignatureAlgorithm returns a signatureAndHash to sign a +// client's CertificateVerify with, or an error if none can be found. +func (h finishedHash) selectClientCertSignatureAlgorithm(serverList []signatureAndHash, sigType uint8) (signatureAndHash, error) { + if h.version < VersionTLS12 { + // Nothing to negotiate before TLS 1.2. + return signatureAndHash{signature: sigType}, nil + } + + for _, v := range serverList { + if v.signature == sigType && v.hash == hashSHA256 { + return v, nil + } + } + return signatureAndHash{}, errors.New("tls: no supported signature algorithm found for signing client certificate") +} + +// hashForClientCertificate returns a digest, hash function, and TLS 1.2 hash +// id suitable for signing by a TLS client certificate. +func (h finishedHash) hashForClientCertificate(signatureAndHash signatureAndHash, masterSecret []byte) ([]byte, crypto.Hash, error) { + if h.version == VersionSSL30 { + if signatureAndHash.signature != signatureRSA { + return nil, 0, errors.New("tls: unsupported signature type for client certificate") + } + + md5Hash := md5.New() + md5Hash.Write(h.buffer) + sha1Hash := sha1.New() + sha1Hash.Write(h.buffer) + return finishedSum30(md5Hash, sha1Hash, masterSecret, nil), crypto.MD5SHA1, nil + } + if h.version >= VersionTLS12 { + hashAlg, err := lookupTLSHash(signatureAndHash.hash) + if err != nil { + return nil, 0, err + } + hash := hashAlg.New() + hash.Write(h.buffer) + return hash.Sum(nil), hashAlg, nil + } + if signatureAndHash.signature == signatureECDSA { + return h.server.Sum(nil), crypto.SHA1, nil + } + + return h.Sum(), crypto.MD5SHA1, nil +} + +// hashForChannelID returns the hash to be signed for TLS Channel +// ID. If a resumption, resumeHash has the previous handshake +// hash. Otherwise, it is nil. +func (h finishedHash) hashForChannelID(resumeHash []byte) []byte { + hash := sha256.New() + hash.Write(channelIDLabel) + if resumeHash != nil { + hash.Write(channelIDResumeLabel) + hash.Write(resumeHash) + } + hash.Write(h.server.Sum(nil)) + return hash.Sum(nil) +} + +// discardHandshakeBuffer is called when there is no more need to +// buffer the entirety of the handshake messages. +func (h *finishedHash) discardHandshakeBuffer() { + h.buffer = nil +} diff --git a/src/ssl/test/runner/recordingconn.go b/src/ssl/test/runner/recordingconn.go new file mode 100644 index 0000000..a67fa48 --- /dev/null +++ b/src/ssl/test/runner/recordingconn.go @@ -0,0 +1,130 @@ +package main + +import ( + "bufio" + "encoding/hex" + "errors" + "fmt" + "io" + "net" + "strconv" + "strings" + "sync" +) + +// recordingConn is a net.Conn that records the traffic that passes through it. +// WriteTo can be used to produce output that can be later be loaded with +// ParseTestData. +type recordingConn struct { + net.Conn + sync.Mutex + flows [][]byte + reading bool +} + +func (r *recordingConn) Read(b []byte) (n int, err error) { + if n, err = r.Conn.Read(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || !r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = true + return +} + +func (r *recordingConn) Write(b []byte) (n int, err error) { + if n, err = r.Conn.Write(b); n == 0 { + return + } + b = b[:n] + + r.Lock() + defer r.Unlock() + + if l := len(r.flows); l == 0 || r.reading { + buf := make([]byte, len(b)) + copy(buf, b) + r.flows = append(r.flows, buf) + } else { + r.flows[l-1] = append(r.flows[l-1], b[:n]...) + } + r.reading = false + return +} + +// WriteTo writes hex dumps to w that contains the recorded traffic. +func (r *recordingConn) WriteTo(w io.Writer) { + // TLS always starts with a client to server flow. + clientToServer := true + + for i, flow := range r.flows { + source, dest := "client", "server" + if !clientToServer { + source, dest = dest, source + } + fmt.Fprintf(w, ">>> Flow %d (%s to %s)\n", i+1, source, dest) + dumper := hex.Dumper(w) + dumper.Write(flow) + dumper.Close() + clientToServer = !clientToServer + } +} + +func parseTestData(r io.Reader) (flows [][]byte, err error) { + var currentFlow []byte + + scanner := bufio.NewScanner(r) + for scanner.Scan() { + line := scanner.Text() + // If the line starts with ">>> " then it marks the beginning + // of a new flow. + if strings.HasPrefix(line, ">>> ") { + if len(currentFlow) > 0 || len(flows) > 0 { + flows = append(flows, currentFlow) + currentFlow = nil + } + continue + } + + // Otherwise the line is a line of hex dump that looks like: + // 00000170 fc f5 06 bf (...) |.....X{&?......!| + // (Some bytes have been omitted from the middle section.) + + if i := strings.IndexByte(line, ' '); i >= 0 { + line = line[i:] + } else { + return nil, errors.New("invalid test data") + } + + if i := strings.IndexByte(line, '|'); i >= 0 { + line = line[:i] + } else { + return nil, errors.New("invalid test data") + } + + hexBytes := strings.Fields(line) + for _, hexByte := range hexBytes { + val, err := strconv.ParseUint(hexByte, 16, 8) + if err != nil { + return nil, errors.New("invalid hex byte in test data: " + err.Error()) + } + currentFlow = append(currentFlow, byte(val)) + } + } + + if len(currentFlow) > 0 { + flows = append(flows, currentFlow) + } + + return flows, nil +} diff --git a/src/ssl/test/runner/runner.go b/src/ssl/test/runner/runner.go new file mode 100644 index 0000000..137a87c --- /dev/null +++ b/src/ssl/test/runner/runner.go @@ -0,0 +1,2649 @@ +package main + +import ( + "bytes" + "crypto/ecdsa" + "crypto/elliptic" + "crypto/x509" + "encoding/base64" + "encoding/pem" + "flag" + "fmt" + "io" + "io/ioutil" + "net" + "os" + "os/exec" + "path" + "runtime" + "strconv" + "strings" + "sync" + "syscall" +) + +var ( + useValgrind = flag.Bool("valgrind", false, "If true, run code under valgrind") + useGDB = flag.Bool("gdb", false, "If true, run BoringSSL code under gdb") + flagDebug *bool = flag.Bool("debug", false, "Hexdump the contents of the connection") + mallocTest *int64 = flag.Int64("malloc-test", -1, "If non-negative, run each test with each malloc in turn failing from the given number onwards.") + mallocTestDebug *bool = flag.Bool("malloc-test-debug", false, "If true, ask bssl_shim to abort rather than fail a malloc. This can be used with a specific value for --malloc-test to identity the malloc failing that is causing problems.") +) + +const ( + rsaCertificateFile = "cert.pem" + ecdsaCertificateFile = "ecdsa_cert.pem" +) + +const ( + rsaKeyFile = "key.pem" + ecdsaKeyFile = "ecdsa_key.pem" + channelIDKeyFile = "channel_id_key.pem" +) + +var rsaCertificate, ecdsaCertificate Certificate +var channelIDKey *ecdsa.PrivateKey +var channelIDBytes []byte + +var testOCSPResponse = []byte{1, 2, 3, 4} +var testSCTList = []byte{5, 6, 7, 8} + +func initCertificates() { + var err error + rsaCertificate, err = LoadX509KeyPair(rsaCertificateFile, rsaKeyFile) + if err != nil { + panic(err) + } + rsaCertificate.OCSPStaple = testOCSPResponse + rsaCertificate.SignedCertificateTimestampList = testSCTList + + ecdsaCertificate, err = LoadX509KeyPair(ecdsaCertificateFile, ecdsaKeyFile) + if err != nil { + panic(err) + } + ecdsaCertificate.OCSPStaple = testOCSPResponse + ecdsaCertificate.SignedCertificateTimestampList = testSCTList + + channelIDPEMBlock, err := ioutil.ReadFile(channelIDKeyFile) + if err != nil { + panic(err) + } + channelIDDERBlock, _ := pem.Decode(channelIDPEMBlock) + if channelIDDERBlock.Type != "EC PRIVATE KEY" { + panic("bad key type") + } + channelIDKey, err = x509.ParseECPrivateKey(channelIDDERBlock.Bytes) + if err != nil { + panic(err) + } + if channelIDKey.Curve != elliptic.P256() { + panic("bad curve") + } + + channelIDBytes = make([]byte, 64) + writeIntPadded(channelIDBytes[:32], channelIDKey.X) + writeIntPadded(channelIDBytes[32:], channelIDKey.Y) +} + +var certificateOnce sync.Once + +func getRSACertificate() Certificate { + certificateOnce.Do(initCertificates) + return rsaCertificate +} + +func getECDSACertificate() Certificate { + certificateOnce.Do(initCertificates) + return ecdsaCertificate +} + +type testType int + +const ( + clientTest testType = iota + serverTest +) + +type protocol int + +const ( + tls protocol = iota + dtls +) + +const ( + alpn = 1 + npn = 2 +) + +type testCase struct { + testType testType + protocol protocol + name string + config Config + shouldFail bool + expectedError string + // expectedLocalError, if not empty, contains a substring that must be + // found in the local error. + expectedLocalError string + // expectedVersion, if non-zero, specifies the TLS version that must be + // negotiated. + expectedVersion uint16 + // expectedResumeVersion, if non-zero, specifies the TLS version that + // must be negotiated on resumption. If zero, expectedVersion is used. + expectedResumeVersion uint16 + // expectChannelID controls whether the connection should have + // negotiated a Channel ID with channelIDKey. + expectChannelID bool + // expectedNextProto controls whether the connection should + // negotiate a next protocol via NPN or ALPN. + expectedNextProto string + // expectedNextProtoType, if non-zero, is the expected next + // protocol negotiation mechanism. + expectedNextProtoType int + // expectedSRTPProtectionProfile is the DTLS-SRTP profile that + // should be negotiated. If zero, none should be negotiated. + expectedSRTPProtectionProfile uint16 + // messageLen is the length, in bytes, of the test message that will be + // sent. + messageLen int + // certFile is the path to the certificate to use for the server. + certFile string + // keyFile is the path to the private key to use for the server. + keyFile string + // resumeSession controls whether a second connection should be tested + // which attempts to resume the first session. + resumeSession bool + // resumeConfig, if not nil, points to a Config to be used on + // resumption. Unless newSessionsOnResume is set, + // SessionTicketKey, ServerSessionCache, and + // ClientSessionCache are copied from the initial connection's + // config. If nil, the initial connection's config is used. + resumeConfig *Config + // newSessionsOnResume, if true, will cause resumeConfig to + // use a different session resumption context. + newSessionsOnResume bool + // sendPrefix sends a prefix on the socket before actually performing a + // handshake. + sendPrefix string + // shimWritesFirst controls whether the shim sends an initial "hello" + // message before doing a roundtrip with the runner. + shimWritesFirst bool + // renegotiate indicates the the connection should be renegotiated + // during the exchange. + renegotiate bool + // renegotiateCiphers is a list of ciphersuite ids that will be + // switched in just before renegotiation. + renegotiateCiphers []uint16 + // replayWrites, if true, configures the underlying transport + // to replay every write it makes in DTLS tests. + replayWrites bool + // damageFirstWrite, if true, configures the underlying transport to + // damage the final byte of the first application data write. + damageFirstWrite bool + // flags, if not empty, contains a list of command-line flags that will + // be passed to the shim program. + flags []string +} + +var testCases = []testCase{ + { + name: "BadRSASignature", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + InvalidSKXSignature: true, + }, + }, + shouldFail: true, + expectedError: ":BAD_SIGNATURE:", + }, + { + name: "BadECDSASignature", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + InvalidSKXSignature: true, + }, + Certificates: []Certificate{getECDSACertificate()}, + }, + shouldFail: true, + expectedError: ":BAD_SIGNATURE:", + }, + { + name: "BadECDSACurve", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + InvalidSKXCurve: true, + }, + Certificates: []Certificate{getECDSACertificate()}, + }, + shouldFail: true, + expectedError: ":WRONG_CURVE:", + }, + { + testType: serverTest, + name: "BadRSAVersion", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + Bugs: ProtocolBugs{ + RsaClientKeyExchangeVersion: VersionTLS11, + }, + }, + shouldFail: true, + expectedError: ":DECRYPTION_FAILED_OR_BAD_RECORD_MAC:", + }, + { + name: "NoFallbackSCSV", + config: Config{ + Bugs: ProtocolBugs{ + FailIfNotFallbackSCSV: true, + }, + }, + shouldFail: true, + expectedLocalError: "no fallback SCSV found", + }, + { + name: "SendFallbackSCSV", + config: Config{ + Bugs: ProtocolBugs{ + FailIfNotFallbackSCSV: true, + }, + }, + flags: []string{"-fallback-scsv"}, + }, + { + name: "ClientCertificateTypes", + config: Config{ + ClientAuth: RequestClientCert, + ClientCertificateTypes: []byte{ + CertTypeDSSSign, + CertTypeRSASign, + CertTypeECDSASign, + }, + }, + flags: []string{ + "-expect-certificate-types", + base64.StdEncoding.EncodeToString([]byte{ + CertTypeDSSSign, + CertTypeRSASign, + CertTypeECDSASign, + }), + }, + }, + { + name: "NoClientCertificate", + config: Config{ + ClientAuth: RequireAnyClientCert, + }, + shouldFail: true, + expectedLocalError: "client didn't provide a certificate", + }, + { + name: "UnauthenticatedECDH", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + UnauthenticatedECDH: true, + }, + }, + shouldFail: true, + expectedError: ":UNEXPECTED_MESSAGE:", + }, + { + name: "SkipServerKeyExchange", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + SkipServerKeyExchange: true, + }, + }, + shouldFail: true, + expectedError: ":UNEXPECTED_MESSAGE:", + }, + { + name: "SkipChangeCipherSpec-Client", + config: Config{ + Bugs: ProtocolBugs{ + SkipChangeCipherSpec: true, + }, + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + testType: serverTest, + name: "SkipChangeCipherSpec-Server", + config: Config{ + Bugs: ProtocolBugs{ + SkipChangeCipherSpec: true, + }, + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + testType: serverTest, + name: "SkipChangeCipherSpec-Server-NPN", + config: Config{ + NextProtos: []string{"bar"}, + Bugs: ProtocolBugs{ + SkipChangeCipherSpec: true, + }, + }, + flags: []string{ + "-advertise-npn", "\x03foo\x03bar\x03baz", + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + name: "FragmentAcrossChangeCipherSpec-Client", + config: Config{ + Bugs: ProtocolBugs{ + FragmentAcrossChangeCipherSpec: true, + }, + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + testType: serverTest, + name: "FragmentAcrossChangeCipherSpec-Server", + config: Config{ + Bugs: ProtocolBugs{ + FragmentAcrossChangeCipherSpec: true, + }, + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + testType: serverTest, + name: "FragmentAcrossChangeCipherSpec-Server-NPN", + config: Config{ + NextProtos: []string{"bar"}, + Bugs: ProtocolBugs{ + FragmentAcrossChangeCipherSpec: true, + }, + }, + flags: []string{ + "-advertise-npn", "\x03foo\x03bar\x03baz", + }, + shouldFail: true, + expectedError: ":HANDSHAKE_RECORD_BEFORE_CCS:", + }, + { + testType: serverTest, + name: "FragmentAlert", + config: Config{ + Bugs: ProtocolBugs{ + FragmentAlert: true, + SendSpuriousAlert: true, + }, + }, + shouldFail: true, + expectedError: ":BAD_ALERT:", + }, + { + testType: serverTest, + name: "EarlyChangeCipherSpec-server-1", + config: Config{ + Bugs: ProtocolBugs{ + EarlyChangeCipherSpec: 1, + }, + }, + shouldFail: true, + expectedError: ":CCS_RECEIVED_EARLY:", + }, + { + testType: serverTest, + name: "EarlyChangeCipherSpec-server-2", + config: Config{ + Bugs: ProtocolBugs{ + EarlyChangeCipherSpec: 2, + }, + }, + shouldFail: true, + expectedError: ":CCS_RECEIVED_EARLY:", + }, + { + name: "SkipNewSessionTicket", + config: Config{ + Bugs: ProtocolBugs{ + SkipNewSessionTicket: true, + }, + }, + shouldFail: true, + expectedError: ":CCS_RECEIVED_EARLY:", + }, + { + testType: serverTest, + name: "FallbackSCSV", + config: Config{ + MaxVersion: VersionTLS11, + Bugs: ProtocolBugs{ + SendFallbackSCSV: true, + }, + }, + shouldFail: true, + expectedError: ":INAPPROPRIATE_FALLBACK:", + }, + { + testType: serverTest, + name: "FallbackSCSV-VersionMatch", + config: Config{ + Bugs: ProtocolBugs{ + SendFallbackSCSV: true, + }, + }, + }, + { + testType: serverTest, + name: "FragmentedClientVersion", + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: 1, + FragmentClientVersion: true, + }, + }, + expectedVersion: VersionTLS12, + }, + { + testType: serverTest, + name: "MinorVersionTolerance", + config: Config{ + Bugs: ProtocolBugs{ + SendClientVersion: 0x03ff, + }, + }, + expectedVersion: VersionTLS12, + }, + { + testType: serverTest, + name: "MajorVersionTolerance", + config: Config{ + Bugs: ProtocolBugs{ + SendClientVersion: 0x0400, + }, + }, + expectedVersion: VersionTLS12, + }, + { + testType: serverTest, + name: "VersionTooLow", + config: Config{ + Bugs: ProtocolBugs{ + SendClientVersion: 0x0200, + }, + }, + shouldFail: true, + expectedError: ":UNSUPPORTED_PROTOCOL:", + }, + { + testType: serverTest, + name: "HttpGET", + sendPrefix: "GET / HTTP/1.0\n", + shouldFail: true, + expectedError: ":HTTP_REQUEST:", + }, + { + testType: serverTest, + name: "HttpPOST", + sendPrefix: "POST / HTTP/1.0\n", + shouldFail: true, + expectedError: ":HTTP_REQUEST:", + }, + { + testType: serverTest, + name: "HttpHEAD", + sendPrefix: "HEAD / HTTP/1.0\n", + shouldFail: true, + expectedError: ":HTTP_REQUEST:", + }, + { + testType: serverTest, + name: "HttpPUT", + sendPrefix: "PUT / HTTP/1.0\n", + shouldFail: true, + expectedError: ":HTTP_REQUEST:", + }, + { + testType: serverTest, + name: "HttpCONNECT", + sendPrefix: "CONNECT www.google.com:443 HTTP/1.0\n", + shouldFail: true, + expectedError: ":HTTPS_PROXY_REQUEST:", + }, + { + testType: serverTest, + name: "Garbage", + sendPrefix: "blah", + shouldFail: true, + expectedError: ":UNKNOWN_PROTOCOL:", + }, + { + name: "SkipCipherVersionCheck", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + MaxVersion: VersionTLS11, + Bugs: ProtocolBugs{ + SkipCipherVersionCheck: true, + }, + }, + shouldFail: true, + expectedError: ":WRONG_CIPHER_RETURNED:", + }, + { + name: "RSAServerKeyExchange", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + RSAServerKeyExchange: true, + }, + }, + shouldFail: true, + expectedError: ":UNEXPECTED_MESSAGE:", + }, + { + name: "DisableEverything", + flags: []string{"-no-tls12", "-no-tls11", "-no-tls1", "-no-ssl3"}, + shouldFail: true, + expectedError: ":WRONG_SSL_VERSION:", + }, + { + protocol: dtls, + name: "DisableEverything-DTLS", + flags: []string{"-no-tls12", "-no-tls1"}, + shouldFail: true, + expectedError: ":WRONG_SSL_VERSION:", + }, + { + name: "NoSharedCipher", + config: Config{ + CipherSuites: []uint16{}, + }, + shouldFail: true, + expectedError: ":HANDSHAKE_FAILURE_ON_CLIENT_HELLO:", + }, + { + protocol: dtls, + testType: serverTest, + name: "MTU", + config: Config{ + Bugs: ProtocolBugs{ + MaxPacketLength: 256, + }, + }, + flags: []string{"-mtu", "256"}, + }, + { + protocol: dtls, + testType: serverTest, + name: "MTUExceeded", + config: Config{ + Bugs: ProtocolBugs{ + MaxPacketLength: 255, + }, + }, + flags: []string{"-mtu", "256"}, + shouldFail: true, + expectedLocalError: "dtls: exceeded maximum packet length", + }, + { + name: "CertMismatchRSA", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + Certificates: []Certificate{getECDSACertificate()}, + Bugs: ProtocolBugs{ + SendCipherSuite: TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256, + }, + }, + shouldFail: true, + expectedError: ":WRONG_CERTIFICATE_TYPE:", + }, + { + name: "CertMismatchECDSA", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + Certificates: []Certificate{getRSACertificate()}, + Bugs: ProtocolBugs{ + SendCipherSuite: TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256, + }, + }, + shouldFail: true, + expectedError: ":WRONG_CERTIFICATE_TYPE:", + }, + { + name: "TLSFatalBadPackets", + damageFirstWrite: true, + shouldFail: true, + expectedError: ":DECRYPTION_FAILED_OR_BAD_RECORD_MAC:", + }, + { + protocol: dtls, + name: "DTLSIgnoreBadPackets", + damageFirstWrite: true, + }, + { + protocol: dtls, + name: "DTLSIgnoreBadPackets-Async", + damageFirstWrite: true, + flags: []string{"-async"}, + }, + { + name: "AppDataAfterChangeCipherSpec", + config: Config{ + Bugs: ProtocolBugs{ + AppDataAfterChangeCipherSpec: []byte("TEST MESSAGE"), + }, + }, + shouldFail: true, + expectedError: ":DATA_BETWEEN_CCS_AND_FINISHED:", + }, + { + protocol: dtls, + name: "AppDataAfterChangeCipherSpec-DTLS", + config: Config{ + Bugs: ProtocolBugs{ + AppDataAfterChangeCipherSpec: []byte("TEST MESSAGE"), + }, + }, + }, +} + +func doExchange(test *testCase, config *Config, conn net.Conn, messageLen int, isResume bool) error { + var connDebug *recordingConn + var connDamage *damageAdaptor + if *flagDebug { + connDebug = &recordingConn{Conn: conn} + conn = connDebug + defer func() { + connDebug.WriteTo(os.Stdout) + }() + } + + if test.protocol == dtls { + conn = newPacketAdaptor(conn) + if test.replayWrites { + conn = newReplayAdaptor(conn) + } + } + + if test.damageFirstWrite { + connDamage = newDamageAdaptor(conn) + conn = connDamage + } + + if test.sendPrefix != "" { + if _, err := conn.Write([]byte(test.sendPrefix)); err != nil { + return err + } + } + + var tlsConn *Conn + if test.testType == clientTest { + if test.protocol == dtls { + tlsConn = DTLSServer(conn, config) + } else { + tlsConn = Server(conn, config) + } + } else { + config.InsecureSkipVerify = true + if test.protocol == dtls { + tlsConn = DTLSClient(conn, config) + } else { + tlsConn = Client(conn, config) + } + } + + if err := tlsConn.Handshake(); err != nil { + return err + } + + // TODO(davidben): move all per-connection expectations into a dedicated + // expectations struct that can be specified separately for the two + // legs. + expectedVersion := test.expectedVersion + if isResume && test.expectedResumeVersion != 0 { + expectedVersion = test.expectedResumeVersion + } + if vers := tlsConn.ConnectionState().Version; expectedVersion != 0 && vers != expectedVersion { + return fmt.Errorf("got version %x, expected %x", vers, expectedVersion) + } + + if test.expectChannelID { + channelID := tlsConn.ConnectionState().ChannelID + if channelID == nil { + return fmt.Errorf("no channel ID negotiated") + } + if channelID.Curve != channelIDKey.Curve || + channelIDKey.X.Cmp(channelIDKey.X) != 0 || + channelIDKey.Y.Cmp(channelIDKey.Y) != 0 { + return fmt.Errorf("incorrect channel ID") + } + } + + if expected := test.expectedNextProto; expected != "" { + if actual := tlsConn.ConnectionState().NegotiatedProtocol; actual != expected { + return fmt.Errorf("next proto mismatch: got %s, wanted %s", actual, expected) + } + } + + if test.expectedNextProtoType != 0 { + if (test.expectedNextProtoType == alpn) != tlsConn.ConnectionState().NegotiatedProtocolFromALPN { + return fmt.Errorf("next proto type mismatch") + } + } + + if p := tlsConn.ConnectionState().SRTPProtectionProfile; p != test.expectedSRTPProtectionProfile { + return fmt.Errorf("SRTP profile mismatch: got %d, wanted %d", p, test.expectedSRTPProtectionProfile) + } + + if test.shimWritesFirst { + var buf [5]byte + _, err := io.ReadFull(tlsConn, buf[:]) + if err != nil { + return err + } + if string(buf[:]) != "hello" { + return fmt.Errorf("bad initial message") + } + } + + if test.renegotiate { + if test.renegotiateCiphers != nil { + config.CipherSuites = test.renegotiateCiphers + } + if err := tlsConn.Renegotiate(); err != nil { + return err + } + } else if test.renegotiateCiphers != nil { + panic("renegotiateCiphers without renegotiate") + } + + if test.damageFirstWrite { + connDamage.setDamage(true) + tlsConn.Write([]byte("DAMAGED WRITE")) + connDamage.setDamage(false) + } + + if messageLen < 0 { + if test.protocol == dtls { + return fmt.Errorf("messageLen < 0 not supported for DTLS tests") + } + // Read until EOF. + _, err := io.Copy(ioutil.Discard, tlsConn) + return err + } + + var testMessage []byte + if config.Bugs.AppDataAfterChangeCipherSpec != nil { + // We've already sent a message. Expect the shim to echo it + // back. + testMessage = config.Bugs.AppDataAfterChangeCipherSpec + } else { + if messageLen == 0 { + messageLen = 32 + } + testMessage = make([]byte, messageLen) + for i := range testMessage { + testMessage[i] = 0x42 + } + tlsConn.Write(testMessage) + } + + buf := make([]byte, len(testMessage)) + if test.protocol == dtls { + bufTmp := make([]byte, len(buf)+1) + n, err := tlsConn.Read(bufTmp) + if err != nil { + return err + } + if n != len(buf) { + return fmt.Errorf("bad reply; length mismatch (%d vs %d)", n, len(buf)) + } + copy(buf, bufTmp) + } else { + _, err := io.ReadFull(tlsConn, buf) + if err != nil { + return err + } + } + + for i, v := range buf { + if v != testMessage[i]^0xff { + return fmt.Errorf("bad reply contents at byte %d", i) + } + } + + return nil +} + +func valgrindOf(dbAttach bool, path string, args ...string) *exec.Cmd { + valgrindArgs := []string{"--error-exitcode=99", "--track-origins=yes", "--leak-check=full"} + if dbAttach { + valgrindArgs = append(valgrindArgs, "--db-attach=yes", "--db-command=xterm -e gdb -nw %f %p") + } + valgrindArgs = append(valgrindArgs, path) + valgrindArgs = append(valgrindArgs, args...) + + return exec.Command("valgrind", valgrindArgs...) +} + +func gdbOf(path string, args ...string) *exec.Cmd { + xtermArgs := []string{"-e", "gdb", "--args"} + xtermArgs = append(xtermArgs, path) + xtermArgs = append(xtermArgs, args...) + + return exec.Command("xterm", xtermArgs...) +} + +func openSocketPair() (shimEnd *os.File, conn net.Conn) { + socks, err := syscall.Socketpair(syscall.AF_UNIX, syscall.SOCK_STREAM, 0) + if err != nil { + panic(err) + } + + syscall.CloseOnExec(socks[0]) + syscall.CloseOnExec(socks[1]) + shimEnd = os.NewFile(uintptr(socks[0]), "shim end") + connFile := os.NewFile(uintptr(socks[1]), "our end") + conn, err = net.FileConn(connFile) + if err != nil { + panic(err) + } + connFile.Close() + if err != nil { + panic(err) + } + return shimEnd, conn +} + +type moreMallocsError struct{} + +func (moreMallocsError) Error() string { + return "child process did not exhaust all allocation calls" +} + +var errMoreMallocs = moreMallocsError{} + +func runTest(test *testCase, buildDir string, mallocNumToFail int64) error { + if !test.shouldFail && (len(test.expectedError) > 0 || len(test.expectedLocalError) > 0) { + panic("Error expected without shouldFail in " + test.name) + } + + shimEnd, conn := openSocketPair() + shimEndResume, connResume := openSocketPair() + + shim_path := path.Join(buildDir, "ssl/test/bssl_shim") + var flags []string + if test.testType == serverTest { + flags = append(flags, "-server") + + flags = append(flags, "-key-file") + if test.keyFile == "" { + flags = append(flags, rsaKeyFile) + } else { + flags = append(flags, test.keyFile) + } + + flags = append(flags, "-cert-file") + if test.certFile == "" { + flags = append(flags, rsaCertificateFile) + } else { + flags = append(flags, test.certFile) + } + } + + if test.protocol == dtls { + flags = append(flags, "-dtls") + } + + if test.resumeSession { + flags = append(flags, "-resume") + } + + if test.shimWritesFirst { + flags = append(flags, "-shim-writes-first") + } + + flags = append(flags, test.flags...) + + var shim *exec.Cmd + if *useValgrind { + shim = valgrindOf(false, shim_path, flags...) + } else if *useGDB { + shim = gdbOf(shim_path, flags...) + } else { + shim = exec.Command(shim_path, flags...) + } + shim.ExtraFiles = []*os.File{shimEnd, shimEndResume} + shim.Stdin = os.Stdin + var stdoutBuf, stderrBuf bytes.Buffer + shim.Stdout = &stdoutBuf + shim.Stderr = &stderrBuf + if mallocNumToFail >= 0 { + shim.Env = []string{"MALLOC_NUMBER_TO_FAIL=" + strconv.FormatInt(mallocNumToFail, 10)} + if *mallocTestDebug { + shim.Env = append(shim.Env, "MALLOC_ABORT_ON_FAIL=1") + } + shim.Env = append(shim.Env, "_MALLOC_CHECK=1") + } + + if err := shim.Start(); err != nil { + panic(err) + } + shimEnd.Close() + shimEndResume.Close() + + config := test.config + config.ClientSessionCache = NewLRUClientSessionCache(1) + config.ServerSessionCache = NewLRUServerSessionCache(1) + if test.testType == clientTest { + if len(config.Certificates) == 0 { + config.Certificates = []Certificate{getRSACertificate()} + } + } + + err := doExchange(test, &config, conn, test.messageLen, + false /* not a resumption */) + conn.Close() + + if err == nil && test.resumeSession { + var resumeConfig Config + if test.resumeConfig != nil { + resumeConfig = *test.resumeConfig + if len(resumeConfig.Certificates) == 0 { + resumeConfig.Certificates = []Certificate{getRSACertificate()} + } + if !test.newSessionsOnResume { + resumeConfig.SessionTicketKey = config.SessionTicketKey + resumeConfig.ClientSessionCache = config.ClientSessionCache + resumeConfig.ServerSessionCache = config.ServerSessionCache + } + } else { + resumeConfig = config + } + err = doExchange(test, &resumeConfig, connResume, test.messageLen, + true /* resumption */) + } + connResume.Close() + + childErr := shim.Wait() + if exitError, ok := childErr.(*exec.ExitError); ok { + if exitError.Sys().(syscall.WaitStatus).ExitStatus() == 88 { + return errMoreMallocs + } + } + + stdout := string(stdoutBuf.Bytes()) + stderr := string(stderrBuf.Bytes()) + failed := err != nil || childErr != nil + correctFailure := len(test.expectedError) == 0 || strings.Contains(stdout, test.expectedError) + localError := "none" + if err != nil { + localError = err.Error() + } + if len(test.expectedLocalError) != 0 { + correctFailure = correctFailure && strings.Contains(localError, test.expectedLocalError) + } + + if failed != test.shouldFail || failed && !correctFailure { + childError := "none" + if childErr != nil { + childError = childErr.Error() + } + + var msg string + switch { + case failed && !test.shouldFail: + msg = "unexpected failure" + case !failed && test.shouldFail: + msg = "unexpected success" + case failed && !correctFailure: + msg = "bad error (wanted '" + test.expectedError + "' / '" + test.expectedLocalError + "')" + default: + panic("internal error") + } + + return fmt.Errorf("%s: local error '%s', child error '%s', stdout:\n%s\nstderr:\n%s", msg, localError, childError, string(stdoutBuf.Bytes()), stderr) + } + + if !*useValgrind && len(stderr) > 0 { + println(stderr) + } + + return nil +} + +var tlsVersions = []struct { + name string + version uint16 + flag string + hasDTLS bool +}{ + {"SSL3", VersionSSL30, "-no-ssl3", false}, + {"TLS1", VersionTLS10, "-no-tls1", true}, + {"TLS11", VersionTLS11, "-no-tls11", false}, + {"TLS12", VersionTLS12, "-no-tls12", true}, +} + +var testCipherSuites = []struct { + name string + id uint16 +}{ + {"3DES-SHA", TLS_RSA_WITH_3DES_EDE_CBC_SHA}, + {"AES128-GCM", TLS_RSA_WITH_AES_128_GCM_SHA256}, + {"AES128-SHA", TLS_RSA_WITH_AES_128_CBC_SHA}, + {"AES128-SHA256", TLS_RSA_WITH_AES_128_CBC_SHA256}, + {"AES256-GCM", TLS_RSA_WITH_AES_256_GCM_SHA384}, + {"AES256-SHA", TLS_RSA_WITH_AES_256_CBC_SHA}, + {"AES256-SHA256", TLS_RSA_WITH_AES_256_CBC_SHA256}, + {"DHE-RSA-AES128-GCM", TLS_DHE_RSA_WITH_AES_128_GCM_SHA256}, + {"DHE-RSA-AES128-SHA", TLS_DHE_RSA_WITH_AES_128_CBC_SHA}, + {"DHE-RSA-AES128-SHA256", TLS_DHE_RSA_WITH_AES_128_CBC_SHA256}, + {"DHE-RSA-AES256-GCM", TLS_DHE_RSA_WITH_AES_256_GCM_SHA384}, + {"DHE-RSA-AES256-SHA", TLS_DHE_RSA_WITH_AES_256_CBC_SHA}, + {"DHE-RSA-AES256-SHA256", TLS_DHE_RSA_WITH_AES_256_CBC_SHA256}, + {"ECDHE-ECDSA-AES128-GCM", TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + {"ECDHE-ECDSA-AES128-SHA", TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA}, + {"ECDHE-ECDSA-AES128-SHA256", TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256}, + {"ECDHE-ECDSA-AES256-GCM", TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384}, + {"ECDHE-ECDSA-AES256-SHA", TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA}, + {"ECDHE-ECDSA-AES256-SHA384", TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384}, + {"ECDHE-ECDSA-RC4-SHA", TLS_ECDHE_ECDSA_WITH_RC4_128_SHA}, + {"ECDHE-PSK-WITH-AES-128-GCM-SHA256", TLS_ECDHE_PSK_WITH_AES_128_GCM_SHA256}, + {"ECDHE-RSA-AES128-GCM", TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + {"ECDHE-RSA-AES128-SHA", TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + {"ECDHE-RSA-AES128-SHA256", TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256}, + {"ECDHE-RSA-AES256-GCM", TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384}, + {"ECDHE-RSA-AES256-SHA", TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA}, + {"ECDHE-RSA-AES256-SHA384", TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384}, + {"ECDHE-RSA-RC4-SHA", TLS_ECDHE_RSA_WITH_RC4_128_SHA}, + {"PSK-AES128-CBC-SHA", TLS_PSK_WITH_AES_128_CBC_SHA}, + {"PSK-AES256-CBC-SHA", TLS_PSK_WITH_AES_256_CBC_SHA}, + {"PSK-RC4-SHA", TLS_PSK_WITH_RC4_128_SHA}, + {"RC4-MD5", TLS_RSA_WITH_RC4_128_MD5}, + {"RC4-SHA", TLS_RSA_WITH_RC4_128_SHA}, +} + +func hasComponent(suiteName, component string) bool { + return strings.Contains("-"+suiteName+"-", "-"+component+"-") +} + +func isTLS12Only(suiteName string) bool { + return hasComponent(suiteName, "GCM") || + hasComponent(suiteName, "SHA256") || + hasComponent(suiteName, "SHA384") +} + +func isDTLSCipher(suiteName string) bool { + return !hasComponent(suiteName, "RC4") +} + +func addCipherSuiteTests() { + for _, suite := range testCipherSuites { + const psk = "12345" + const pskIdentity = "luggage combo" + + var cert Certificate + var certFile string + var keyFile string + if hasComponent(suite.name, "ECDSA") { + cert = getECDSACertificate() + certFile = ecdsaCertificateFile + keyFile = ecdsaKeyFile + } else { + cert = getRSACertificate() + certFile = rsaCertificateFile + keyFile = rsaKeyFile + } + + var flags []string + if hasComponent(suite.name, "PSK") { + flags = append(flags, + "-psk", psk, + "-psk-identity", pskIdentity) + } + + for _, ver := range tlsVersions { + if ver.version < VersionTLS12 && isTLS12Only(suite.name) { + continue + } + + testCases = append(testCases, testCase{ + testType: clientTest, + name: ver.name + "-" + suite.name + "-client", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + CipherSuites: []uint16{suite.id}, + Certificates: []Certificate{cert}, + PreSharedKey: []byte(psk), + PreSharedKeyIdentity: pskIdentity, + }, + flags: flags, + resumeSession: true, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + name: ver.name + "-" + suite.name + "-server", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + CipherSuites: []uint16{suite.id}, + Certificates: []Certificate{cert}, + PreSharedKey: []byte(psk), + PreSharedKeyIdentity: pskIdentity, + }, + certFile: certFile, + keyFile: keyFile, + flags: flags, + resumeSession: true, + }) + + if ver.hasDTLS && isDTLSCipher(suite.name) { + testCases = append(testCases, testCase{ + testType: clientTest, + protocol: dtls, + name: "D" + ver.name + "-" + suite.name + "-client", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + CipherSuites: []uint16{suite.id}, + Certificates: []Certificate{cert}, + PreSharedKey: []byte(psk), + PreSharedKeyIdentity: pskIdentity, + }, + flags: flags, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + protocol: dtls, + name: "D" + ver.name + "-" + suite.name + "-server", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + CipherSuites: []uint16{suite.id}, + Certificates: []Certificate{cert}, + PreSharedKey: []byte(psk), + PreSharedKeyIdentity: pskIdentity, + }, + certFile: certFile, + keyFile: keyFile, + flags: flags, + resumeSession: true, + }) + } + } + } +} + +func addBadECDSASignatureTests() { + for badR := BadValue(1); badR < NumBadValues; badR++ { + for badS := BadValue(1); badS < NumBadValues; badS++ { + testCases = append(testCases, testCase{ + name: fmt.Sprintf("BadECDSA-%d-%d", badR, badS), + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256}, + Certificates: []Certificate{getECDSACertificate()}, + Bugs: ProtocolBugs{ + BadECDSAR: badR, + BadECDSAS: badS, + }, + }, + shouldFail: true, + expectedError: "SIGNATURE", + }) + } + } +} + +func addCBCPaddingTests() { + testCases = append(testCases, testCase{ + name: "MaxCBCPadding", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + MaxPadding: true, + }, + }, + messageLen: 12, // 20 bytes of SHA-1 + 12 == 0 % block size + }) + testCases = append(testCases, testCase{ + name: "BadCBCPadding", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + PaddingFirstByteBad: true, + }, + }, + shouldFail: true, + expectedError: "DECRYPTION_FAILED_OR_BAD_RECORD_MAC", + }) + // OpenSSL previously had an issue where the first byte of padding in + // 255 bytes of padding wasn't checked. + testCases = append(testCases, testCase{ + name: "BadCBCPadding255", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + MaxPadding: true, + PaddingFirstByteBadIf255: true, + }, + }, + messageLen: 12, // 20 bytes of SHA-1 + 12 == 0 % block size + shouldFail: true, + expectedError: "DECRYPTION_FAILED_OR_BAD_RECORD_MAC", + }) +} + +func addCBCSplittingTests() { + testCases = append(testCases, testCase{ + name: "CBCRecordSplitting", + config: Config{ + MaxVersion: VersionTLS10, + MinVersion: VersionTLS10, + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + }, + messageLen: -1, // read until EOF + flags: []string{ + "-async", + "-write-different-record-sizes", + "-cbc-record-splitting", + }, + }) + testCases = append(testCases, testCase{ + name: "CBCRecordSplittingPartialWrite", + config: Config{ + MaxVersion: VersionTLS10, + MinVersion: VersionTLS10, + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA}, + }, + messageLen: -1, // read until EOF + flags: []string{ + "-async", + "-write-different-record-sizes", + "-cbc-record-splitting", + "-partial-write", + }, + }) +} + +func addClientAuthTests() { + // Add a dummy cert pool to stress certificate authority parsing. + // TODO(davidben): Add tests that those values parse out correctly. + certPool := x509.NewCertPool() + cert, err := x509.ParseCertificate(rsaCertificate.Certificate[0]) + if err != nil { + panic(err) + } + certPool.AddCert(cert) + + for _, ver := range tlsVersions { + testCases = append(testCases, testCase{ + testType: clientTest, + name: ver.name + "-Client-ClientAuth-RSA", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + ClientAuth: RequireAnyClientCert, + ClientCAs: certPool, + }, + flags: []string{ + "-cert-file", rsaCertificateFile, + "-key-file", rsaKeyFile, + }, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: ver.name + "-Server-ClientAuth-RSA", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + Certificates: []Certificate{rsaCertificate}, + }, + flags: []string{"-require-any-client-certificate"}, + }) + if ver.version != VersionSSL30 { + testCases = append(testCases, testCase{ + testType: serverTest, + name: ver.name + "-Server-ClientAuth-ECDSA", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + Certificates: []Certificate{ecdsaCertificate}, + }, + flags: []string{"-require-any-client-certificate"}, + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: ver.name + "-Client-ClientAuth-ECDSA", + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + ClientAuth: RequireAnyClientCert, + ClientCAs: certPool, + }, + flags: []string{ + "-cert-file", ecdsaCertificateFile, + "-key-file", ecdsaKeyFile, + }, + }) + } + } +} + +func addExtendedMasterSecretTests() { + const expectEMSFlag = "-expect-extended-master-secret" + + for _, with := range []bool{false, true} { + prefix := "No" + var flags []string + if with { + prefix = "" + flags = []string{expectEMSFlag} + } + + for _, isClient := range []bool{false, true} { + suffix := "-Server" + testType := serverTest + if isClient { + suffix = "-Client" + testType = clientTest + } + + for _, ver := range tlsVersions { + test := testCase{ + testType: testType, + name: prefix + "ExtendedMasterSecret-" + ver.name + suffix, + config: Config{ + MinVersion: ver.version, + MaxVersion: ver.version, + Bugs: ProtocolBugs{ + NoExtendedMasterSecret: !with, + RequireExtendedMasterSecret: with, + }, + }, + flags: flags, + shouldFail: ver.version == VersionSSL30 && with, + } + if test.shouldFail { + test.expectedLocalError = "extended master secret required but not supported by peer" + } + testCases = append(testCases, test) + } + } + } + + // When a session is resumed, it should still be aware that its master + // secret was generated via EMS and thus it's safe to use tls-unique. + testCases = append(testCases, testCase{ + name: "ExtendedMasterSecret-Resume", + config: Config{ + Bugs: ProtocolBugs{ + RequireExtendedMasterSecret: true, + }, + }, + flags: []string{expectEMSFlag}, + resumeSession: true, + }) +} + +// Adds tests that try to cover the range of the handshake state machine, under +// various conditions. Some of these are redundant with other tests, but they +// only cover the synchronous case. +func addStateMachineCoverageTests(async, splitHandshake bool, protocol protocol) { + var suffix string + var flags []string + var maxHandshakeRecordLength int + if protocol == dtls { + suffix = "-DTLS" + } + if async { + suffix += "-Async" + flags = append(flags, "-async") + } else { + suffix += "-Sync" + } + if splitHandshake { + suffix += "-SplitHandshakeRecords" + maxHandshakeRecordLength = 1 + } + + // Basic handshake, with resumption. Client and server, + // session ID and session ticket. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "Basic-Client" + suffix, + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + name: "Basic-Client-RenewTicket" + suffix, + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + RenewTicketOnResume: true, + }, + }, + flags: flags, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + name: "Basic-Client-NoTicket" + suffix, + config: Config{ + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "Basic-Server" + suffix, + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "Basic-Server-NoTickets" + suffix, + config: Config{ + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + resumeSession: true, + }) + + // TLS client auth. + testCases = append(testCases, testCase{ + protocol: protocol, + testType: clientTest, + name: "ClientAuth-Client" + suffix, + config: Config{ + ClientAuth: RequireAnyClientCert, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-cert-file", rsaCertificateFile, + "-key-file", rsaKeyFile), + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "ClientAuth-Server" + suffix, + config: Config{ + Certificates: []Certificate{rsaCertificate}, + }, + flags: append(flags, "-require-any-client-certificate"), + }) + + // No session ticket support; server doesn't send NewSessionTicket. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "SessionTicketsDisabled-Client" + suffix, + config: Config{ + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "SessionTicketsDisabled-Server" + suffix, + config: Config{ + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: flags, + }) + + // Skip ServerKeyExchange in PSK key exchange if there's no + // identity hint. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "EmptyPSKHint-Client" + suffix, + config: Config{ + CipherSuites: []uint16{TLS_PSK_WITH_AES_128_CBC_SHA}, + PreSharedKey: []byte("secret"), + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, "-psk", "secret"), + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "EmptyPSKHint-Server" + suffix, + config: Config{ + CipherSuites: []uint16{TLS_PSK_WITH_AES_128_CBC_SHA}, + PreSharedKey: []byte("secret"), + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, "-psk", "secret"), + }) + + if protocol == tls { + // NPN on client and server; results in post-handshake message. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "NPN-Client" + suffix, + config: Config{ + NextProtos: []string{"foo"}, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, "-select-next-proto", "foo"), + expectedNextProto: "foo", + expectedNextProtoType: npn, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "NPN-Server" + suffix, + config: Config{ + NextProtos: []string{"bar"}, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-advertise-npn", "\x03foo\x03bar\x03baz", + "-expect-next-proto", "bar"), + expectedNextProto: "bar", + expectedNextProtoType: npn, + }) + + // Client does False Start and negotiates NPN. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "FalseStart" + suffix, + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + NextProtos: []string{"foo"}, + Bugs: ProtocolBugs{ + ExpectFalseStart: true, + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-false-start", + "-select-next-proto", "foo"), + shimWritesFirst: true, + resumeSession: true, + }) + + // Client does False Start and negotiates ALPN. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "FalseStart-ALPN" + suffix, + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + NextProtos: []string{"foo"}, + Bugs: ProtocolBugs{ + ExpectFalseStart: true, + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-false-start", + "-advertise-alpn", "\x03foo"), + shimWritesFirst: true, + resumeSession: true, + }) + + // False Start without session tickets. + testCases = append(testCases, testCase{ + name: "FalseStart-SessionTicketsDisabled", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + NextProtos: []string{"foo"}, + SessionTicketsDisabled: true, + Bugs: ProtocolBugs{ + ExpectFalseStart: true, + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-false-start", + "-select-next-proto", "foo", + ), + shimWritesFirst: true, + }) + + // Server parses a V2ClientHello. + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "SendV2ClientHello" + suffix, + config: Config{ + // Choose a cipher suite that does not involve + // elliptic curves, so no extensions are + // involved. + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + SendV2ClientHello: true, + }, + }, + flags: flags, + }) + + // Client sends a Channel ID. + testCases = append(testCases, testCase{ + protocol: protocol, + name: "ChannelID-Client" + suffix, + config: Config{ + RequestChannelID: true, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-send-channel-id", channelIDKeyFile, + ), + resumeSession: true, + expectChannelID: true, + }) + + // Server accepts a Channel ID. + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "ChannelID-Server" + suffix, + config: Config{ + ChannelID: channelIDKey, + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, + "-expect-channel-id", + base64.StdEncoding.EncodeToString(channelIDBytes), + ), + resumeSession: true, + expectChannelID: true, + }) + } else { + testCases = append(testCases, testCase{ + protocol: protocol, + name: "SkipHelloVerifyRequest" + suffix, + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + SkipHelloVerifyRequest: true, + }, + }, + flags: flags, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + protocol: protocol, + name: "CookieExchange" + suffix, + config: Config{ + Bugs: ProtocolBugs{ + MaxHandshakeRecordLength: maxHandshakeRecordLength, + }, + }, + flags: append(flags, "-cookie-exchange"), + }) + } +} + +func addVersionNegotiationTests() { + for i, shimVers := range tlsVersions { + // Assemble flags to disable all newer versions on the shim. + var flags []string + for _, vers := range tlsVersions[i+1:] { + flags = append(flags, vers.flag) + } + + for _, runnerVers := range tlsVersions { + protocols := []protocol{tls} + if runnerVers.hasDTLS && shimVers.hasDTLS { + protocols = append(protocols, dtls) + } + for _, protocol := range protocols { + expectedVersion := shimVers.version + if runnerVers.version < shimVers.version { + expectedVersion = runnerVers.version + } + + suffix := shimVers.name + "-" + runnerVers.name + if protocol == dtls { + suffix += "-DTLS" + } + + shimVersFlag := strconv.Itoa(int(versionToWire(shimVers.version, protocol == dtls))) + + clientVers := shimVers.version + if clientVers > VersionTLS10 { + clientVers = VersionTLS10 + } + testCases = append(testCases, testCase{ + protocol: protocol, + testType: clientTest, + name: "VersionNegotiation-Client-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + Bugs: ProtocolBugs{ + ExpectInitialRecordVersion: clientVers, + }, + }, + flags: flags, + expectedVersion: expectedVersion, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: clientTest, + name: "VersionNegotiation-Client2-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + Bugs: ProtocolBugs{ + ExpectInitialRecordVersion: clientVers, + }, + }, + flags: []string{"-max-version", shimVersFlag}, + expectedVersion: expectedVersion, + }) + + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "VersionNegotiation-Server-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + Bugs: ProtocolBugs{ + ExpectInitialRecordVersion: expectedVersion, + }, + }, + flags: flags, + expectedVersion: expectedVersion, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "VersionNegotiation-Server2-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + Bugs: ProtocolBugs{ + ExpectInitialRecordVersion: expectedVersion, + }, + }, + flags: []string{"-max-version", shimVersFlag}, + expectedVersion: expectedVersion, + }) + } + } + } +} + +func addMinimumVersionTests() { + for i, shimVers := range tlsVersions { + // Assemble flags to disable all older versions on the shim. + var flags []string + for _, vers := range tlsVersions[:i] { + flags = append(flags, vers.flag) + } + + for _, runnerVers := range tlsVersions { + protocols := []protocol{tls} + if runnerVers.hasDTLS && shimVers.hasDTLS { + protocols = append(protocols, dtls) + } + for _, protocol := range protocols { + suffix := shimVers.name + "-" + runnerVers.name + if protocol == dtls { + suffix += "-DTLS" + } + shimVersFlag := strconv.Itoa(int(versionToWire(shimVers.version, protocol == dtls))) + + var expectedVersion uint16 + var shouldFail bool + var expectedError string + var expectedLocalError string + if runnerVers.version >= shimVers.version { + expectedVersion = runnerVers.version + } else { + shouldFail = true + expectedError = ":UNSUPPORTED_PROTOCOL:" + if runnerVers.version > VersionSSL30 { + expectedLocalError = "remote error: protocol version not supported" + } else { + expectedLocalError = "remote error: handshake failure" + } + } + + testCases = append(testCases, testCase{ + protocol: protocol, + testType: clientTest, + name: "MinimumVersion-Client-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + }, + flags: flags, + expectedVersion: expectedVersion, + shouldFail: shouldFail, + expectedError: expectedError, + expectedLocalError: expectedLocalError, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: clientTest, + name: "MinimumVersion-Client2-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + }, + flags: []string{"-min-version", shimVersFlag}, + expectedVersion: expectedVersion, + shouldFail: shouldFail, + expectedError: expectedError, + expectedLocalError: expectedLocalError, + }) + + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "MinimumVersion-Server-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + }, + flags: flags, + expectedVersion: expectedVersion, + shouldFail: shouldFail, + expectedError: expectedError, + expectedLocalError: expectedLocalError, + }) + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "MinimumVersion-Server2-" + suffix, + config: Config{ + MaxVersion: runnerVers.version, + }, + flags: []string{"-min-version", shimVersFlag}, + expectedVersion: expectedVersion, + shouldFail: shouldFail, + expectedError: expectedError, + expectedLocalError: expectedLocalError, + }) + } + } + } +} + +func addD5BugTests() { + testCases = append(testCases, testCase{ + testType: serverTest, + name: "D5Bug-NoQuirk-Reject", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + SSL3RSAKeyExchange: true, + }, + }, + shouldFail: true, + expectedError: ":TLS_RSA_ENCRYPTED_VALUE_LENGTH_IS_WRONG:", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "D5Bug-Quirk-Normal", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + }, + flags: []string{"-tls-d5-bug"}, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "D5Bug-Quirk-Bug", + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_GCM_SHA256}, + Bugs: ProtocolBugs{ + SSL3RSAKeyExchange: true, + }, + }, + flags: []string{"-tls-d5-bug"}, + }) +} + +func addExtensionTests() { + testCases = append(testCases, testCase{ + testType: clientTest, + name: "DuplicateExtensionClient", + config: Config{ + Bugs: ProtocolBugs{ + DuplicateExtension: true, + }, + }, + shouldFail: true, + expectedLocalError: "remote error: error decoding message", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "DuplicateExtensionServer", + config: Config{ + Bugs: ProtocolBugs{ + DuplicateExtension: true, + }, + }, + shouldFail: true, + expectedLocalError: "remote error: error decoding message", + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: "ServerNameExtensionClient", + config: Config{ + Bugs: ProtocolBugs{ + ExpectServerName: "example.com", + }, + }, + flags: []string{"-host-name", "example.com"}, + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: "ServerNameExtensionClient", + config: Config{ + Bugs: ProtocolBugs{ + ExpectServerName: "mismatch.com", + }, + }, + flags: []string{"-host-name", "example.com"}, + shouldFail: true, + expectedLocalError: "tls: unexpected server name", + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: "ServerNameExtensionClient", + config: Config{ + Bugs: ProtocolBugs{ + ExpectServerName: "missing.com", + }, + }, + shouldFail: true, + expectedLocalError: "tls: unexpected server name", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "ServerNameExtensionServer", + config: Config{ + ServerName: "example.com", + }, + flags: []string{"-expect-server-name", "example.com"}, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + testType: clientTest, + name: "ALPNClient", + config: Config{ + NextProtos: []string{"foo"}, + }, + flags: []string{ + "-advertise-alpn", "\x03foo\x03bar\x03baz", + "-expect-alpn", "foo", + }, + expectedNextProto: "foo", + expectedNextProtoType: alpn, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "ALPNServer", + config: Config{ + NextProtos: []string{"foo", "bar", "baz"}, + }, + flags: []string{ + "-expect-advertised-alpn", "\x03foo\x03bar\x03baz", + "-select-alpn", "foo", + }, + expectedNextProto: "foo", + expectedNextProtoType: alpn, + resumeSession: true, + }) + // Test that the server prefers ALPN over NPN. + testCases = append(testCases, testCase{ + testType: serverTest, + name: "ALPNServer-Preferred", + config: Config{ + NextProtos: []string{"foo", "bar", "baz"}, + }, + flags: []string{ + "-expect-advertised-alpn", "\x03foo\x03bar\x03baz", + "-select-alpn", "foo", + "-advertise-npn", "\x03foo\x03bar\x03baz", + }, + expectedNextProto: "foo", + expectedNextProtoType: alpn, + resumeSession: true, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "ALPNServer-Preferred-Swapped", + config: Config{ + NextProtos: []string{"foo", "bar", "baz"}, + Bugs: ProtocolBugs{ + SwapNPNAndALPN: true, + }, + }, + flags: []string{ + "-expect-advertised-alpn", "\x03foo\x03bar\x03baz", + "-select-alpn", "foo", + "-advertise-npn", "\x03foo\x03bar\x03baz", + }, + expectedNextProto: "foo", + expectedNextProtoType: alpn, + resumeSession: true, + }) + // Resume with a corrupt ticket. + testCases = append(testCases, testCase{ + testType: serverTest, + name: "CorruptTicket", + config: Config{ + Bugs: ProtocolBugs{ + CorruptTicket: true, + }, + }, + resumeSession: true, + flags: []string{"-expect-session-miss"}, + }) + // Resume with an oversized session id. + testCases = append(testCases, testCase{ + testType: serverTest, + name: "OversizedSessionId", + config: Config{ + Bugs: ProtocolBugs{ + OversizedSessionId: true, + }, + }, + resumeSession: true, + shouldFail: true, + expectedError: ":DECODE_ERROR:", + }) + // Basic DTLS-SRTP tests. Include fake profiles to ensure they + // are ignored. + testCases = append(testCases, testCase{ + protocol: dtls, + name: "SRTP-Client", + config: Config{ + SRTPProtectionProfiles: []uint16{40, SRTP_AES128_CM_HMAC_SHA1_80, 42}, + }, + flags: []string{ + "-srtp-profiles", + "SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32", + }, + expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80, + }) + testCases = append(testCases, testCase{ + protocol: dtls, + testType: serverTest, + name: "SRTP-Server", + config: Config{ + SRTPProtectionProfiles: []uint16{40, SRTP_AES128_CM_HMAC_SHA1_80, 42}, + }, + flags: []string{ + "-srtp-profiles", + "SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32", + }, + expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80, + }) + // Test that the MKI is ignored. + testCases = append(testCases, testCase{ + protocol: dtls, + testType: serverTest, + name: "SRTP-Server-IgnoreMKI", + config: Config{ + SRTPProtectionProfiles: []uint16{SRTP_AES128_CM_HMAC_SHA1_80}, + Bugs: ProtocolBugs{ + SRTPMasterKeyIdentifer: "bogus", + }, + }, + flags: []string{ + "-srtp-profiles", + "SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32", + }, + expectedSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_80, + }) + // Test that SRTP isn't negotiated on the server if there were + // no matching profiles. + testCases = append(testCases, testCase{ + protocol: dtls, + testType: serverTest, + name: "SRTP-Server-NoMatch", + config: Config{ + SRTPProtectionProfiles: []uint16{100, 101, 102}, + }, + flags: []string{ + "-srtp-profiles", + "SRTP_AES128_CM_SHA1_80:SRTP_AES128_CM_SHA1_32", + }, + expectedSRTPProtectionProfile: 0, + }) + // Test that the server returning an invalid SRTP profile is + // flagged as an error by the client. + testCases = append(testCases, testCase{ + protocol: dtls, + name: "SRTP-Client-NoMatch", + config: Config{ + Bugs: ProtocolBugs{ + SendSRTPProtectionProfile: SRTP_AES128_CM_HMAC_SHA1_32, + }, + }, + flags: []string{ + "-srtp-profiles", + "SRTP_AES128_CM_SHA1_80", + }, + shouldFail: true, + expectedError: ":BAD_SRTP_PROTECTION_PROFILE_LIST:", + }) + // Test OCSP stapling and SCT list. + testCases = append(testCases, testCase{ + name: "OCSPStapling", + flags: []string{ + "-enable-ocsp-stapling", + "-expect-ocsp-response", + base64.StdEncoding.EncodeToString(testOCSPResponse), + }, + }) + testCases = append(testCases, testCase{ + name: "SignedCertificateTimestampList", + flags: []string{ + "-enable-signed-cert-timestamps", + "-expect-signed-cert-timestamps", + base64.StdEncoding.EncodeToString(testSCTList), + }, + }) +} + +func addResumptionVersionTests() { + for _, sessionVers := range tlsVersions { + for _, resumeVers := range tlsVersions { + protocols := []protocol{tls} + if sessionVers.hasDTLS && resumeVers.hasDTLS { + protocols = append(protocols, dtls) + } + for _, protocol := range protocols { + suffix := "-" + sessionVers.name + "-" + resumeVers.name + if protocol == dtls { + suffix += "-DTLS" + } + + testCases = append(testCases, testCase{ + protocol: protocol, + name: "Resume-Client" + suffix, + resumeSession: true, + config: Config{ + MaxVersion: sessionVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + AllowSessionVersionMismatch: true, + }, + }, + expectedVersion: sessionVers.version, + resumeConfig: &Config{ + MaxVersion: resumeVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + Bugs: ProtocolBugs{ + AllowSessionVersionMismatch: true, + }, + }, + expectedResumeVersion: resumeVers.version, + }) + + testCases = append(testCases, testCase{ + protocol: protocol, + name: "Resume-Client-NoResume" + suffix, + flags: []string{"-expect-session-miss"}, + resumeSession: true, + config: Config{ + MaxVersion: sessionVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + }, + expectedVersion: sessionVers.version, + resumeConfig: &Config{ + MaxVersion: resumeVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + }, + newSessionsOnResume: true, + expectedResumeVersion: resumeVers.version, + }) + + var flags []string + if sessionVers.version != resumeVers.version { + flags = append(flags, "-expect-session-miss") + } + testCases = append(testCases, testCase{ + protocol: protocol, + testType: serverTest, + name: "Resume-Server" + suffix, + flags: flags, + resumeSession: true, + config: Config{ + MaxVersion: sessionVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + }, + expectedVersion: sessionVers.version, + resumeConfig: &Config{ + MaxVersion: resumeVers.version, + CipherSuites: []uint16{TLS_RSA_WITH_AES_128_CBC_SHA}, + }, + expectedResumeVersion: resumeVers.version, + }) + } + } + } +} + +func addRenegotiationTests() { + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server", + flags: []string{"-renegotiate"}, + shimWritesFirst: true, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server-EmptyExt", + config: Config{ + Bugs: ProtocolBugs{ + EmptyRenegotiationInfo: true, + }, + }, + flags: []string{"-renegotiate"}, + shimWritesFirst: true, + shouldFail: true, + expectedError: ":RENEGOTIATION_MISMATCH:", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server-BadExt", + config: Config{ + Bugs: ProtocolBugs{ + BadRenegotiationInfo: true, + }, + }, + flags: []string{"-renegotiate"}, + shimWritesFirst: true, + shouldFail: true, + expectedError: ":RENEGOTIATION_MISMATCH:", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server-ClientInitiated", + renegotiate: true, + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server-ClientInitiated-NoExt", + renegotiate: true, + config: Config{ + Bugs: ProtocolBugs{ + NoRenegotiationInfo: true, + }, + }, + shouldFail: true, + expectedError: ":UNSAFE_LEGACY_RENEGOTIATION_DISABLED:", + }) + testCases = append(testCases, testCase{ + testType: serverTest, + name: "Renegotiate-Server-ClientInitiated-NoExt-Allowed", + renegotiate: true, + config: Config{ + Bugs: ProtocolBugs{ + NoRenegotiationInfo: true, + }, + }, + flags: []string{"-allow-unsafe-legacy-renegotiation"}, + }) + // TODO(agl): test the renegotiation info SCSV. + testCases = append(testCases, testCase{ + name: "Renegotiate-Client", + renegotiate: true, + }) + testCases = append(testCases, testCase{ + name: "Renegotiate-Client-EmptyExt", + renegotiate: true, + config: Config{ + Bugs: ProtocolBugs{ + EmptyRenegotiationInfo: true, + }, + }, + shouldFail: true, + expectedError: ":RENEGOTIATION_MISMATCH:", + }) + testCases = append(testCases, testCase{ + name: "Renegotiate-Client-BadExt", + renegotiate: true, + config: Config{ + Bugs: ProtocolBugs{ + BadRenegotiationInfo: true, + }, + }, + shouldFail: true, + expectedError: ":RENEGOTIATION_MISMATCH:", + }) + testCases = append(testCases, testCase{ + name: "Renegotiate-Client-SwitchCiphers", + renegotiate: true, + config: Config{ + CipherSuites: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + }, + renegotiateCiphers: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + }) + testCases = append(testCases, testCase{ + name: "Renegotiate-Client-SwitchCiphers2", + renegotiate: true, + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + }, + renegotiateCiphers: []uint16{TLS_RSA_WITH_RC4_128_SHA}, + }) + testCases = append(testCases, testCase{ + name: "Renegotiate-SameClientVersion", + renegotiate: true, + config: Config{ + MaxVersion: VersionTLS10, + Bugs: ProtocolBugs{ + RequireSameRenegoClientVersion: true, + }, + }, + }) +} + +func addDTLSReplayTests() { + // Test that sequence number replays are detected. + testCases = append(testCases, testCase{ + protocol: dtls, + name: "DTLS-Replay", + replayWrites: true, + }) + + // Test the outgoing sequence number skipping by values larger + // than the retransmit window. + testCases = append(testCases, testCase{ + protocol: dtls, + name: "DTLS-Replay-LargeGaps", + config: Config{ + Bugs: ProtocolBugs{ + SequenceNumberIncrement: 127, + }, + }, + replayWrites: true, + }) +} + +func addFastRadioPaddingTests() { + testCases = append(testCases, testCase{ + protocol: tls, + name: "FastRadio-Padding", + config: Config{ + Bugs: ProtocolBugs{ + RequireFastradioPadding: true, + }, + }, + flags: []string{"-fastradio-padding"}, + }) + testCases = append(testCases, testCase{ + protocol: dtls, + name: "FastRadio-Padding", + config: Config{ + Bugs: ProtocolBugs{ + RequireFastradioPadding: true, + }, + }, + flags: []string{"-fastradio-padding"}, + }) +} + +var testHashes = []struct { + name string + id uint8 +}{ + {"SHA1", hashSHA1}, + {"SHA224", hashSHA224}, + {"SHA256", hashSHA256}, + {"SHA384", hashSHA384}, + {"SHA512", hashSHA512}, +} + +func addSigningHashTests() { + // Make sure each hash works. Include some fake hashes in the list and + // ensure they're ignored. + for _, hash := range testHashes { + testCases = append(testCases, testCase{ + name: "SigningHash-ClientAuth-" + hash.name, + config: Config{ + ClientAuth: RequireAnyClientCert, + SignatureAndHashes: []signatureAndHash{ + {signatureRSA, 42}, + {signatureRSA, hash.id}, + {signatureRSA, 255}, + }, + }, + flags: []string{ + "-cert-file", rsaCertificateFile, + "-key-file", rsaKeyFile, + }, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + name: "SigningHash-ServerKeyExchange-Sign-" + hash.name, + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + SignatureAndHashes: []signatureAndHash{ + {signatureRSA, 42}, + {signatureRSA, hash.id}, + {signatureRSA, 255}, + }, + }, + }) + } + + // Test that hash resolution takes the signature type into account. + testCases = append(testCases, testCase{ + name: "SigningHash-ClientAuth-SignatureType", + config: Config{ + ClientAuth: RequireAnyClientCert, + SignatureAndHashes: []signatureAndHash{ + {signatureECDSA, hashSHA512}, + {signatureRSA, hashSHA384}, + {signatureECDSA, hashSHA1}, + }, + }, + flags: []string{ + "-cert-file", rsaCertificateFile, + "-key-file", rsaKeyFile, + }, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + name: "SigningHash-ServerKeyExchange-SignatureType", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + SignatureAndHashes: []signatureAndHash{ + {signatureECDSA, hashSHA512}, + {signatureRSA, hashSHA384}, + {signatureECDSA, hashSHA1}, + }, + }, + }) + + // Test that, if the list is missing, the peer falls back to SHA-1. + testCases = append(testCases, testCase{ + name: "SigningHash-ClientAuth-Fallback", + config: Config{ + ClientAuth: RequireAnyClientCert, + SignatureAndHashes: []signatureAndHash{ + {signatureRSA, hashSHA1}, + }, + Bugs: ProtocolBugs{ + NoSignatureAndHashes: true, + }, + }, + flags: []string{ + "-cert-file", rsaCertificateFile, + "-key-file", rsaKeyFile, + }, + }) + + testCases = append(testCases, testCase{ + testType: serverTest, + name: "SigningHash-ServerKeyExchange-Fallback", + config: Config{ + CipherSuites: []uint16{TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256}, + SignatureAndHashes: []signatureAndHash{ + {signatureRSA, hashSHA1}, + }, + Bugs: ProtocolBugs{ + NoSignatureAndHashes: true, + }, + }, + }) +} + +func worker(statusChan chan statusMsg, c chan *testCase, buildDir string, wg *sync.WaitGroup) { + defer wg.Done() + + for test := range c { + var err error + + if *mallocTest < 0 { + statusChan <- statusMsg{test: test, started: true} + err = runTest(test, buildDir, -1) + } else { + for mallocNumToFail := int64(*mallocTest); ; mallocNumToFail++ { + statusChan <- statusMsg{test: test, started: true} + if err = runTest(test, buildDir, mallocNumToFail); err != errMoreMallocs { + if err != nil { + fmt.Printf("\n\nmalloc test failed at %d: %s\n", mallocNumToFail, err) + } + break + } + } + } + statusChan <- statusMsg{test: test, err: err} + } +} + +type statusMsg struct { + test *testCase + started bool + err error +} + +func statusPrinter(doneChan chan struct{}, statusChan chan statusMsg, total int) { + var started, done, failed, lineLen int + defer close(doneChan) + + for msg := range statusChan { + if msg.started { + started++ + } else { + done++ + } + + fmt.Printf("\x1b[%dD\x1b[K", lineLen) + + if msg.err != nil { + fmt.Printf("FAILED (%s)\n%s\n", msg.test.name, msg.err) + failed++ + } + line := fmt.Sprintf("%d/%d/%d/%d", failed, done, started, total) + lineLen = len(line) + os.Stdout.WriteString(line) + } +} + +func main() { + var flagTest *string = flag.String("test", "", "The name of a test to run, or empty to run all tests") + var flagNumWorkers *int = flag.Int("num-workers", runtime.NumCPU(), "The number of workers to run in parallel.") + var flagBuildDir *string = flag.String("build-dir", "../../../build", "The build directory to run the shim from.") + + flag.Parse() + + addCipherSuiteTests() + addBadECDSASignatureTests() + addCBCPaddingTests() + addCBCSplittingTests() + addClientAuthTests() + addVersionNegotiationTests() + addMinimumVersionTests() + addD5BugTests() + addExtensionTests() + addResumptionVersionTests() + addExtendedMasterSecretTests() + addRenegotiationTests() + addDTLSReplayTests() + addSigningHashTests() + addFastRadioPaddingTests() + for _, async := range []bool{false, true} { + for _, splitHandshake := range []bool{false, true} { + for _, protocol := range []protocol{tls, dtls} { + addStateMachineCoverageTests(async, splitHandshake, protocol) + } + } + } + + var wg sync.WaitGroup + + numWorkers := *flagNumWorkers + + statusChan := make(chan statusMsg, numWorkers) + testChan := make(chan *testCase, numWorkers) + doneChan := make(chan struct{}) + + go statusPrinter(doneChan, statusChan, len(testCases)) + + for i := 0; i < numWorkers; i++ { + wg.Add(1) + go worker(statusChan, testChan, *flagBuildDir, &wg) + } + + for i := range testCases { + if len(*flagTest) == 0 || *flagTest == testCases[i].name { + testChan <- &testCases[i] + } + } + + close(testChan) + wg.Wait() + close(statusChan) + <-doneChan + + fmt.Printf("\n") +} diff --git a/src/ssl/test/runner/ticket.go b/src/ssl/test/runner/ticket.go new file mode 100644 index 0000000..8355822 --- /dev/null +++ b/src/ssl/test/runner/ticket.go @@ -0,0 +1,221 @@ +// Copyright 2012 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bytes" + "crypto/aes" + "crypto/cipher" + "crypto/hmac" + "crypto/sha256" + "crypto/subtle" + "errors" + "io" +) + +// sessionState contains the information that is serialized into a session +// ticket in order to later resume a connection. +type sessionState struct { + vers uint16 + cipherSuite uint16 + masterSecret []byte + handshakeHash []byte + certificates [][]byte + extendedMasterSecret bool +} + +func (s *sessionState) equal(i interface{}) bool { + s1, ok := i.(*sessionState) + if !ok { + return false + } + + if s.vers != s1.vers || + s.cipherSuite != s1.cipherSuite || + !bytes.Equal(s.masterSecret, s1.masterSecret) || + !bytes.Equal(s.handshakeHash, s1.handshakeHash) || + s.extendedMasterSecret != s1.extendedMasterSecret { + return false + } + + if len(s.certificates) != len(s1.certificates) { + return false + } + + for i := range s.certificates { + if !bytes.Equal(s.certificates[i], s1.certificates[i]) { + return false + } + } + + return true +} + +func (s *sessionState) marshal() []byte { + length := 2 + 2 + 2 + len(s.masterSecret) + 2 + len(s.handshakeHash) + 2 + for _, cert := range s.certificates { + length += 4 + len(cert) + } + length++ + + ret := make([]byte, length) + x := ret + x[0] = byte(s.vers >> 8) + x[1] = byte(s.vers) + x[2] = byte(s.cipherSuite >> 8) + x[3] = byte(s.cipherSuite) + x[4] = byte(len(s.masterSecret) >> 8) + x[5] = byte(len(s.masterSecret)) + x = x[6:] + copy(x, s.masterSecret) + x = x[len(s.masterSecret):] + + x[0] = byte(len(s.handshakeHash) >> 8) + x[1] = byte(len(s.handshakeHash)) + x = x[2:] + copy(x, s.handshakeHash) + x = x[len(s.handshakeHash):] + + x[0] = byte(len(s.certificates) >> 8) + x[1] = byte(len(s.certificates)) + x = x[2:] + + for _, cert := range s.certificates { + x[0] = byte(len(cert) >> 24) + x[1] = byte(len(cert) >> 16) + x[2] = byte(len(cert) >> 8) + x[3] = byte(len(cert)) + copy(x[4:], cert) + x = x[4+len(cert):] + } + + if s.extendedMasterSecret { + x[0] = 1 + } + x = x[1:] + + return ret +} + +func (s *sessionState) unmarshal(data []byte) bool { + if len(data) < 8 { + return false + } + + s.vers = uint16(data[0])<<8 | uint16(data[1]) + s.cipherSuite = uint16(data[2])<<8 | uint16(data[3]) + masterSecretLen := int(data[4])<<8 | int(data[5]) + data = data[6:] + if len(data) < masterSecretLen { + return false + } + + s.masterSecret = data[:masterSecretLen] + data = data[masterSecretLen:] + + if len(data) < 2 { + return false + } + + handshakeHashLen := int(data[0])<<8 | int(data[1]) + data = data[2:] + if len(data) < handshakeHashLen { + return false + } + + s.handshakeHash = data[:handshakeHashLen] + data = data[handshakeHashLen:] + + if len(data) < 2 { + return false + } + + numCerts := int(data[0])<<8 | int(data[1]) + data = data[2:] + + s.certificates = make([][]byte, numCerts) + for i := range s.certificates { + if len(data) < 4 { + return false + } + certLen := int(data[0])<<24 | int(data[1])<<16 | int(data[2])<<8 | int(data[3]) + data = data[4:] + if certLen < 0 { + return false + } + if len(data) < certLen { + return false + } + s.certificates[i] = data[:certLen] + data = data[certLen:] + } + + if len(data) < 1 { + return false + } + + s.extendedMasterSecret = false + if data[0] == 1 { + s.extendedMasterSecret = true + } + data = data[1:] + + if len(data) > 0 { + return false + } + + return true +} + +func (c *Conn) encryptTicket(state *sessionState) ([]byte, error) { + serialized := state.marshal() + encrypted := make([]byte, aes.BlockSize+len(serialized)+sha256.Size) + iv := encrypted[:aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + if _, err := io.ReadFull(c.config.rand(), iv); err != nil { + return nil, err + } + block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) + if err != nil { + return nil, errors.New("tls: failed to create cipher while encrypting ticket: " + err.Error()) + } + cipher.NewCTR(block, iv).XORKeyStream(encrypted[aes.BlockSize:], serialized) + + mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + mac.Sum(macBytes[:0]) + + return encrypted, nil +} + +func (c *Conn) decryptTicket(encrypted []byte) (*sessionState, bool) { + if len(encrypted) < aes.BlockSize+sha256.Size { + return nil, false + } + + iv := encrypted[:aes.BlockSize] + macBytes := encrypted[len(encrypted)-sha256.Size:] + + mac := hmac.New(sha256.New, c.config.SessionTicketKey[16:32]) + mac.Write(encrypted[:len(encrypted)-sha256.Size]) + expected := mac.Sum(nil) + + if subtle.ConstantTimeCompare(macBytes, expected) != 1 { + return nil, false + } + + block, err := aes.NewCipher(c.config.SessionTicketKey[:16]) + if err != nil { + return nil, false + } + ciphertext := encrypted[aes.BlockSize : len(encrypted)-sha256.Size] + plaintext := make([]byte, len(ciphertext)) + cipher.NewCTR(block, iv).XORKeyStream(plaintext, ciphertext) + + state := new(sessionState) + ok := state.unmarshal(plaintext) + return state, ok +} diff --git a/src/ssl/test/runner/tls.go b/src/ssl/test/runner/tls.go new file mode 100644 index 0000000..6b637c8 --- /dev/null +++ b/src/ssl/test/runner/tls.go @@ -0,0 +1,279 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// Package tls partially implements TLS 1.2, as specified in RFC 5246. +package main + +import ( + "crypto" + "crypto/ecdsa" + "crypto/rsa" + "crypto/x509" + "encoding/pem" + "errors" + "io/ioutil" + "net" + "strings" + "time" +) + +// Server returns a new TLS server side connection +// using conn as the underlying transport. +// The configuration config must be non-nil and must have +// at least one certificate. +func Server(conn net.Conn, config *Config) *Conn { + c := &Conn{conn: conn, config: config} + c.init() + return c +} + +// Client returns a new TLS client side connection +// using conn as the underlying transport. +// The config cannot be nil: users must set either ServerHostname or +// InsecureSkipVerify in the config. +func Client(conn net.Conn, config *Config) *Conn { + c := &Conn{conn: conn, config: config, isClient: true} + c.init() + return c +} + +// A listener implements a network listener (net.Listener) for TLS connections. +type listener struct { + net.Listener + config *Config +} + +// Accept waits for and returns the next incoming TLS connection. +// The returned connection c is a *tls.Conn. +func (l *listener) Accept() (c net.Conn, err error) { + c, err = l.Listener.Accept() + if err != nil { + return + } + c = Server(c, l.config) + return +} + +// NewListener creates a Listener which accepts connections from an inner +// Listener and wraps each connection with Server. +// The configuration config must be non-nil and must have +// at least one certificate. +func NewListener(inner net.Listener, config *Config) net.Listener { + l := new(listener) + l.Listener = inner + l.config = config + return l +} + +// Listen creates a TLS listener accepting connections on the +// given network address using net.Listen. +// The configuration config must be non-nil and must have +// at least one certificate. +func Listen(network, laddr string, config *Config) (net.Listener, error) { + if config == nil || len(config.Certificates) == 0 { + return nil, errors.New("tls.Listen: no certificates in configuration") + } + l, err := net.Listen(network, laddr) + if err != nil { + return nil, err + } + return NewListener(l, config), nil +} + +type timeoutError struct{} + +func (timeoutError) Error() string { return "tls: DialWithDialer timed out" } +func (timeoutError) Timeout() bool { return true } +func (timeoutError) Temporary() bool { return true } + +// DialWithDialer connects to the given network address using dialer.Dial and +// then initiates a TLS handshake, returning the resulting TLS connection. Any +// timeout or deadline given in the dialer apply to connection and TLS +// handshake as a whole. +// +// DialWithDialer interprets a nil configuration as equivalent to the zero +// configuration; see the documentation of Config for the defaults. +func DialWithDialer(dialer *net.Dialer, network, addr string, config *Config) (*Conn, error) { + // We want the Timeout and Deadline values from dialer to cover the + // whole process: TCP connection and TLS handshake. This means that we + // also need to start our own timers now. + timeout := dialer.Timeout + + if !dialer.Deadline.IsZero() { + deadlineTimeout := dialer.Deadline.Sub(time.Now()) + if timeout == 0 || deadlineTimeout < timeout { + timeout = deadlineTimeout + } + } + + var errChannel chan error + + if timeout != 0 { + errChannel = make(chan error, 2) + time.AfterFunc(timeout, func() { + errChannel <- timeoutError{} + }) + } + + rawConn, err := dialer.Dial(network, addr) + if err != nil { + return nil, err + } + + colonPos := strings.LastIndex(addr, ":") + if colonPos == -1 { + colonPos = len(addr) + } + hostname := addr[:colonPos] + + if config == nil { + config = defaultConfig() + } + // If no ServerName is set, infer the ServerName + // from the hostname we're connecting to. + if config.ServerName == "" { + // Make a copy to avoid polluting argument or default. + c := *config + c.ServerName = hostname + config = &c + } + + conn := Client(rawConn, config) + + if timeout == 0 { + err = conn.Handshake() + } else { + go func() { + errChannel <- conn.Handshake() + }() + + err = <-errChannel + } + + if err != nil { + rawConn.Close() + return nil, err + } + + return conn, nil +} + +// Dial connects to the given network address using net.Dial +// and then initiates a TLS handshake, returning the resulting +// TLS connection. +// Dial interprets a nil configuration as equivalent to +// the zero configuration; see the documentation of Config +// for the defaults. +func Dial(network, addr string, config *Config) (*Conn, error) { + return DialWithDialer(new(net.Dialer), network, addr, config) +} + +// LoadX509KeyPair reads and parses a public/private key pair from a pair of +// files. The files must contain PEM encoded data. +func LoadX509KeyPair(certFile, keyFile string) (cert Certificate, err error) { + certPEMBlock, err := ioutil.ReadFile(certFile) + if err != nil { + return + } + keyPEMBlock, err := ioutil.ReadFile(keyFile) + if err != nil { + return + } + return X509KeyPair(certPEMBlock, keyPEMBlock) +} + +// X509KeyPair parses a public/private key pair from a pair of +// PEM encoded data. +func X509KeyPair(certPEMBlock, keyPEMBlock []byte) (cert Certificate, err error) { + var certDERBlock *pem.Block + for { + certDERBlock, certPEMBlock = pem.Decode(certPEMBlock) + if certDERBlock == nil { + break + } + if certDERBlock.Type == "CERTIFICATE" { + cert.Certificate = append(cert.Certificate, certDERBlock.Bytes) + } + } + + if len(cert.Certificate) == 0 { + err = errors.New("crypto/tls: failed to parse certificate PEM data") + return + } + + var keyDERBlock *pem.Block + for { + keyDERBlock, keyPEMBlock = pem.Decode(keyPEMBlock) + if keyDERBlock == nil { + err = errors.New("crypto/tls: failed to parse key PEM data") + return + } + if keyDERBlock.Type == "PRIVATE KEY" || strings.HasSuffix(keyDERBlock.Type, " PRIVATE KEY") { + break + } + } + + cert.PrivateKey, err = parsePrivateKey(keyDERBlock.Bytes) + if err != nil { + return + } + + // We don't need to parse the public key for TLS, but we so do anyway + // to check that it looks sane and matches the private key. + x509Cert, err := x509.ParseCertificate(cert.Certificate[0]) + if err != nil { + return + } + + switch pub := x509Cert.PublicKey.(type) { + case *rsa.PublicKey: + priv, ok := cert.PrivateKey.(*rsa.PrivateKey) + if !ok { + err = errors.New("crypto/tls: private key type does not match public key type") + return + } + if pub.N.Cmp(priv.N) != 0 { + err = errors.New("crypto/tls: private key does not match public key") + return + } + case *ecdsa.PublicKey: + priv, ok := cert.PrivateKey.(*ecdsa.PrivateKey) + if !ok { + err = errors.New("crypto/tls: private key type does not match public key type") + return + + } + if pub.X.Cmp(priv.X) != 0 || pub.Y.Cmp(priv.Y) != 0 { + err = errors.New("crypto/tls: private key does not match public key") + return + } + default: + err = errors.New("crypto/tls: unknown public key algorithm") + return + } + + return +} + +// Attempt to parse the given private key DER block. OpenSSL 0.9.8 generates +// PKCS#1 private keys by default, while OpenSSL 1.0.0 generates PKCS#8 keys. +// OpenSSL ecparam generates SEC1 EC private keys for ECDSA. We try all three. +func parsePrivateKey(der []byte) (crypto.PrivateKey, error) { + if key, err := x509.ParsePKCS1PrivateKey(der); err == nil { + return key, nil + } + if key, err := x509.ParsePKCS8PrivateKey(der); err == nil { + switch key := key.(type) { + case *rsa.PrivateKey, *ecdsa.PrivateKey: + return key, nil + default: + return nil, errors.New("crypto/tls: found unknown private key type in PKCS#8 wrapping") + } + } + if key, err := x509.ParseECPrivateKey(der); err == nil { + return key, nil + } + + return nil, errors.New("crypto/tls: failed to parse private key") +} |