// Copyright 2014 The Go Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. // DTLS implementation. // // NOTE: This is a not even a remotely production-quality DTLS // implementation. It is the bare minimum necessary to be able to // achieve coverage on BoringSSL's implementation. Of note is that // this implementation assumes the underlying net.PacketConn is not // only reliable but also ordered. BoringSSL will be expected to deal // with simulated loss, but there is no point in forcing the test // driver to. package main import ( "bytes" "errors" "fmt" "io" "math/rand" "net" ) func versionToWire(vers uint16, isDTLS bool) uint16 { if isDTLS { return ^(vers - 0x0201) } return vers } func wireToVersion(vers uint16, isDTLS bool) uint16 { if isDTLS { return ^vers + 0x0201 } return vers } func (c *Conn) dtlsDoReadRecord(want recordType) (recordType, *block, error) { recordHeaderLen := dtlsRecordHeaderLen if c.rawInput == nil { c.rawInput = c.in.newBlock() } b := c.rawInput // Read a new packet only if the current one is empty. if len(b.data) == 0 { // Pick some absurdly large buffer size. b.resize(maxCiphertext + recordHeaderLen) n, err := c.conn.Read(c.rawInput.data) if err != nil { return 0, nil, err } if c.config.Bugs.MaxPacketLength != 0 && n > c.config.Bugs.MaxPacketLength { return 0, nil, fmt.Errorf("dtls: exceeded maximum packet length") } c.rawInput.resize(n) } // Read out one record. // // A real DTLS implementation should be tolerant of errors, // but this is test code. We should not be tolerant of our // peer sending garbage. if len(b.data) < recordHeaderLen { return 0, nil, errors.New("dtls: failed to read record header") } typ := recordType(b.data[0]) vers := wireToVersion(uint16(b.data[1])<<8|uint16(b.data[2]), c.isDTLS) if c.haveVers { if vers != c.vers { c.sendAlert(alertProtocolVersion) return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, c.vers)) } } else { if expect := c.config.Bugs.ExpectInitialRecordVersion; expect != 0 && vers != expect { c.sendAlert(alertProtocolVersion) return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: received record with version %x when expecting version %x", vers, expect)) } } seq := b.data[3:11] // For test purposes, we assume a reliable channel. Require // that the explicit sequence number matches the incrementing // one we maintain. A real implementation would maintain a // replay window and such. if !bytes.Equal(seq, c.in.seq[:]) { c.sendAlert(alertIllegalParameter) return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: bad sequence number")) } n := int(b.data[11])<<8 | int(b.data[12]) if n > maxCiphertext || len(b.data) < recordHeaderLen+n { c.sendAlert(alertRecordOverflow) return 0, nil, c.in.setErrorLocked(fmt.Errorf("dtls: oversized record received with length %d", n)) } // Process message. b, c.rawInput = c.in.splitBlock(b, recordHeaderLen+n) ok, off, err := c.in.decrypt(b) if !ok { c.in.setErrorLocked(c.sendAlert(err)) } b.off = off return typ, b, nil } func (c *Conn) makeFragment(header, data []byte, fragOffset, fragLen int) []byte { fragment := make([]byte, 0, 12+fragLen) fragment = append(fragment, header...) fragment = append(fragment, byte(c.sendHandshakeSeq>>8), byte(c.sendHandshakeSeq)) fragment = append(fragment, byte(fragOffset>>16), byte(fragOffset>>8), byte(fragOffset)) fragment = append(fragment, byte(fragLen>>16), byte(fragLen>>8), byte(fragLen)) fragment = append(fragment, data[fragOffset:fragOffset+fragLen]...) return fragment } func (c *Conn) dtlsWriteRecord(typ recordType, data []byte) (n int, err error) { if typ != recordTypeHandshake { // Only handshake messages are fragmented. return c.dtlsWriteRawRecord(typ, data) } maxLen := c.config.Bugs.MaxHandshakeRecordLength if maxLen <= 0 { maxLen = 1024 } // Handshake messages have to be modified to include fragment // offset and length and with the header replicated. Save the // TLS header here. // // TODO(davidben): This assumes that data contains exactly one // handshake message. This is incompatible with // FragmentAcrossChangeCipherSpec. (Which is unfortunate // because OpenSSL's DTLS implementation will probably accept // such fragmentation and could do with a fix + tests.) header := data[:4] data = data[4:] isFinished := header[0] == typeFinished if c.config.Bugs.SendEmptyFragments { fragment := c.makeFragment(header, data, 0, 0) c.pendingFragments = append(c.pendingFragments, fragment) } firstRun := true fragOffset := 0 for firstRun || fragOffset < len(data) { firstRun = false fragLen := len(data) - fragOffset if fragLen > maxLen { fragLen = maxLen } fragment := c.makeFragment(header, data, fragOffset, fragLen) if c.config.Bugs.FragmentMessageTypeMismatch && fragOffset > 0 { fragment[0]++ } if c.config.Bugs.FragmentMessageLengthMismatch && fragOffset > 0 { fragment[3]++ } // Buffer the fragment for later. They will be sent (and // reordered) on flush. c.pendingFragments = append(c.pendingFragments, fragment) if c.config.Bugs.ReorderHandshakeFragments { // Don't duplicate Finished to avoid the peer // interpreting it as a retransmit request. if !isFinished { c.pendingFragments = append(c.pendingFragments, fragment) } if fragLen > (maxLen+1)/2 { // Overlap each fragment by half. fragLen = (maxLen + 1) / 2 } } fragOffset += fragLen n += fragLen } if !isFinished && c.config.Bugs.MixCompleteMessageWithFragments { fragment := c.makeFragment(header, data, 0, len(data)) c.pendingFragments = append(c.pendingFragments, fragment) } // Increment the handshake sequence number for the next // handshake message. c.sendHandshakeSeq++ return } func (c *Conn) dtlsFlushHandshake() error { if !c.isDTLS { return nil } // This is a test-only DTLS implementation, so there is no need to // retain |c.pendingFragments| for a future retransmit. var fragments [][]byte fragments, c.pendingFragments = c.pendingFragments, fragments if c.config.Bugs.ReorderHandshakeFragments { perm := rand.New(rand.NewSource(0)).Perm(len(fragments)) tmp := make([][]byte, len(fragments)) for i := range tmp { tmp[i] = fragments[perm[i]] } fragments = tmp } maxRecordLen := c.config.Bugs.PackHandshakeFragments maxPacketLen := c.config.Bugs.PackHandshakeRecords // Pack handshake fragments into records. var records [][]byte for _, fragment := range fragments { if c.config.Bugs.SplitFragmentHeader { records = append(records, fragment[:2]) records = append(records, fragment[2:]) } else if c.config.Bugs.SplitFragmentBody { if len(fragment) > 12 { records = append(records, fragment[:13]) records = append(records, fragment[13:]) } else { records = append(records, fragment) } } else if i := len(records) - 1; len(records) > 0 && len(records[i])+len(fragment) <= maxRecordLen { records[i] = append(records[i], fragment...) } else { // The fragment will be appended to, so copy it. records = append(records, append([]byte{}, fragment...)) } } // Format them into packets. var packets [][]byte for _, record := range records { b, err := c.dtlsSealRecord(recordTypeHandshake, record) if err != nil { return err } if i := len(packets) - 1; len(packets) > 0 && len(packets[i])+len(b.data) <= maxPacketLen { packets[i] = append(packets[i], b.data...) } else { // The sealed record will be appended to and reused by // |c.out|, so copy it. packets = append(packets, append([]byte{}, b.data...)) } c.out.freeBlock(b) } // Send all the packets. for _, packet := range packets { if _, err := c.conn.Write(packet); err != nil { return err } } return nil } // dtlsSealRecord seals a record into a block from |c.out|'s pool. func (c *Conn) dtlsSealRecord(typ recordType, data []byte) (b *block, err error) { recordHeaderLen := dtlsRecordHeaderLen maxLen := c.config.Bugs.MaxHandshakeRecordLength if maxLen <= 0 { maxLen = 1024 } b = c.out.newBlock() explicitIVLen := 0 explicitIVIsSeq := false if cbc, ok := c.out.cipher.(cbcMode); ok { // Block cipher modes have an explicit IV. explicitIVLen = cbc.BlockSize() } else if aead, ok := c.out.cipher.(*tlsAead); ok { if aead.explicitNonce { explicitIVLen = 8 // The AES-GCM construction in TLS has an explicit nonce so that // the nonce can be random. However, the nonce is only 8 bytes // which is too small for a secure, random nonce. Therefore we // use the sequence number as the nonce. explicitIVIsSeq = true } } else if c.out.cipher != nil { panic("Unknown cipher") } b.resize(recordHeaderLen + explicitIVLen + len(data)) b.data[0] = byte(typ) vers := c.vers if vers == 0 { // Some TLS servers fail if the record version is greater than // TLS 1.0 for the initial ClientHello. vers = VersionTLS10 } vers = versionToWire(vers, c.isDTLS) b.data[1] = byte(vers >> 8) b.data[2] = byte(vers) // DTLS records include an explicit sequence number. copy(b.data[3:11], c.out.seq[0:]) b.data[11] = byte(len(data) >> 8) b.data[12] = byte(len(data)) if explicitIVLen > 0 { explicitIV := b.data[recordHeaderLen : recordHeaderLen+explicitIVLen] if explicitIVIsSeq { copy(explicitIV, c.out.seq[:]) } else { if _, err = io.ReadFull(c.config.rand(), explicitIV); err != nil { return } } } copy(b.data[recordHeaderLen+explicitIVLen:], data) c.out.encrypt(b, explicitIVLen) return } func (c *Conn) dtlsWriteRawRecord(typ recordType, data []byte) (n int, err error) { b, err := c.dtlsSealRecord(typ, data) if err != nil { return } _, err = c.conn.Write(b.data) if err != nil { return } n = len(data) c.out.freeBlock(b) if typ == recordTypeChangeCipherSpec { err = c.out.changeCipherSpec(c.config) if err != nil { // Cannot call sendAlert directly, // because we already hold c.out.Mutex. c.tmp[0] = alertLevelError c.tmp[1] = byte(err.(alert)) c.writeRecord(recordTypeAlert, c.tmp[0:2]) return n, c.out.setErrorLocked(&net.OpError{Op: "local error", Err: err}) } } return } func (c *Conn) dtlsDoReadHandshake() ([]byte, error) { // Assemble a full handshake message. For test purposes, this // implementation assumes fragments arrive in order. It may // need to be cleverer if we ever test BoringSSL's retransmit // behavior. for len(c.handMsg) < 4+c.handMsgLen { // Get a new handshake record if the previous has been // exhausted. if c.hand.Len() == 0 { if err := c.in.err; err != nil { return nil, err } if err := c.readRecord(recordTypeHandshake); err != nil { return nil, err } } // Read the next fragment. It must fit entirely within // the record. if c.hand.Len() < 12 { return nil, errors.New("dtls: bad handshake record") } header := c.hand.Next(12) fragN := int(header[1])<<16 | int(header[2])<<8 | int(header[3]) fragSeq := uint16(header[4])<<8 | uint16(header[5]) fragOff := int(header[6])<<16 | int(header[7])<<8 | int(header[8]) fragLen := int(header[9])<<16 | int(header[10])<<8 | int(header[11]) if c.hand.Len() < fragLen { return nil, errors.New("dtls: fragment length too long") } fragment := c.hand.Next(fragLen) // Check it's a fragment for the right message. if fragSeq != c.recvHandshakeSeq { return nil, errors.New("dtls: bad handshake sequence number") } // Check that the length is consistent. if c.handMsg == nil { c.handMsgLen = fragN if c.handMsgLen > maxHandshake { return nil, c.in.setErrorLocked(c.sendAlert(alertInternalError)) } // Start with the TLS handshake header, // without the DTLS bits. c.handMsg = append([]byte{}, header[:4]...) } else if fragN != c.handMsgLen { return nil, errors.New("dtls: bad handshake length") } // Add the fragment to the pending message. if 4+fragOff != len(c.handMsg) { return nil, errors.New("dtls: bad fragment offset") } if fragOff+fragLen > c.handMsgLen { return nil, errors.New("dtls: bad fragment length") } c.handMsg = append(c.handMsg, fragment...) } c.recvHandshakeSeq++ ret := c.handMsg c.handMsg, c.handMsgLen = nil, 0 return ret, nil } // DTLSServer returns a new DTLS server side connection // using conn as the underlying transport. // The configuration config must be non-nil and must have // at least one certificate. func DTLSServer(conn net.Conn, config *Config) *Conn { c := &Conn{config: config, isDTLS: true, conn: conn} c.init() return c } // DTLSClient returns a new DTLS client side connection // using conn as the underlying transport. // The config cannot be nil: users must set either ServerHostname or // InsecureSkipVerify in the config. func DTLSClient(conn net.Conn, config *Config) *Conn { c := &Conn{config: config, isClient: true, isDTLS: true, conn: conn} c.init() return c }