summaryrefslogtreecommitdiffstats
path: root/src/ssl/d1_both.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/ssl/d1_both.c')
-rw-r--r--src/ssl/d1_both.c112
1 files changed, 62 insertions, 50 deletions
diff --git a/src/ssl/d1_both.c b/src/ssl/d1_both.c
index 1acb3ce..ac35a66 100644
--- a/src/ssl/d1_both.c
+++ b/src/ssl/d1_both.c
@@ -111,8 +111,6 @@
* copied and put under another distribution licence
* [including the GNU Public Licence.] */
-#include <openssl/ssl.h>
-
#include <assert.h>
#include <limits.h>
#include <stdio.h>
@@ -149,44 +147,52 @@ static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
unsigned long frag_len);
static unsigned char *dtls1_write_message_header(SSL *s, unsigned char *p);
-static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) {
- hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
+static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
+ int reassembly) {
+ hm_fragment *frag = NULL;
+ uint8_t *buf = NULL;
+ uint8_t *bitmask = NULL;
+
+ frag = (hm_fragment *)OPENSSL_malloc(sizeof(hm_fragment));
if (frag == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
return NULL;
}
- memset(frag, 0, sizeof(hm_fragment));
- /* If the handshake message is empty, |frag->fragment| and |frag->reassembly|
- * are NULL. */
- if (frag_len > 0) {
- frag->fragment = OPENSSL_malloc(frag_len);
- if (frag->fragment == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
- goto err;
+ if (frag_len) {
+ buf = (uint8_t *)OPENSSL_malloc(frag_len);
+ if (buf == NULL) {
+ OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
+ OPENSSL_free(frag);
+ return NULL;
}
+ }
- if (reassembly) {
- /* Initialize reassembly bitmask. */
- if (frag_len + 7 < frag_len) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
- goto err;
- }
- size_t bitmask_len = (frag_len + 7) / 8;
- frag->reassembly = OPENSSL_malloc(bitmask_len);
- if (frag->reassembly == NULL) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
- goto err;
+ /* zero length fragment gets zero frag->fragment */
+ frag->fragment = buf;
+
+ /* Initialize reassembly bitmask if necessary */
+ if (reassembly && frag_len > 0) {
+ if (frag_len + 7 < frag_len) {
+ OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_OVERFLOW);
+ return NULL;
+ }
+ size_t bitmask_len = (frag_len + 7) / 8;
+ bitmask = (uint8_t *)OPENSSL_malloc(bitmask_len);
+ if (bitmask == NULL) {
+ OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
+ if (buf != NULL) {
+ OPENSSL_free(buf);
}
- memset(frag->reassembly, 0, bitmask_len);
+ OPENSSL_free(frag);
+ return NULL;
}
+ memset(bitmask, 0, bitmask_len);
}
- return frag;
+ frag->reassembly = bitmask;
-err:
- dtls1_hm_fragment_free(frag);
- return NULL;
+ return frag;
}
void dtls1_hm_fragment_free(hm_fragment *frag) {
@@ -320,7 +326,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
/* To make forward progress, the MTU must, at minimum, fit the handshake
* header and one byte of handshake body. */
- OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
+ OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL);
return -1;
}
@@ -338,7 +344,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
assert(type == SSL3_RT_CHANGE_CIPHER_SPEC);
/* ChangeCipherSpec cannot be fragmented. */
if (s->init_num > curr_mtu) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
+ OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL);
return -1;
}
len = s->init_num;
@@ -444,7 +450,8 @@ static hm_fragment *dtls1_get_buffered_message(
frag->msg_header.msg_len != msg_hdr->msg_len) {
/* The new fragment must be compatible with the previous fragments from
* this message. */
- OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_buffered_message,
+ SSL_R_FRAGMENT_MISMATCH);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return NULL;
}
@@ -466,7 +473,11 @@ static size_t dtls1_max_handshake_message_len(const SSL *s) {
/* dtls1_process_fragment reads a handshake fragment and processes it. It
* returns one if a fragment was successfully processed and 0 or -1 on error. */
static int dtls1_process_fragment(SSL *s) {
- /* Read handshake message header. */
+ /* Read handshake message header.
+ *
+ * TODO(davidben): ssl_read_bytes allows splitting the fragment header and
+ * body across two records. Change this interface to consume the fragment in
+ * one pass. */
uint8_t header[DTLS1_HM_HEADER_LENGTH];
int ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, header,
DTLS1_HM_HEADER_LENGTH, 0);
@@ -474,7 +485,7 @@ static int dtls1_process_fragment(SSL *s) {
return ret;
}
if (ret != DTLS1_HM_HEADER_LENGTH) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
return -1;
}
@@ -483,16 +494,14 @@ static int dtls1_process_fragment(SSL *s) {
struct hm_header_st msg_hdr;
dtls1_get_message_header(header, &msg_hdr);
- /* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There
- * should be no need to reach into |s->s3->rrec.length|. */
const size_t frag_off = msg_hdr.frag_off;
const size_t frag_len = msg_hdr.frag_len;
const size_t msg_len = msg_hdr.msg_len;
if (frag_off > msg_len || frag_off + frag_len < frag_off ||
frag_off + frag_len > msg_len ||
- msg_len > dtls1_max_handshake_message_len(s) ||
- frag_len > s->s3->rrec.length) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+ msg_len > dtls1_max_handshake_message_len(s)) {
+ OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment,
+ SSL_R_EXCESSIVE_MESSAGE_SIZE);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
}
@@ -526,8 +535,8 @@ static int dtls1_process_fragment(SSL *s) {
ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, frag->fragment + frag_off,
frag_len, 0);
if (ret != frag_len) {
- OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
- ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
+ OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE);
+ ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
return -1;
}
dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
@@ -554,7 +563,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
s->s3->tmp.reuse_message = 0;
if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
*ok = 1;
@@ -580,19 +589,22 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
assert(frag->reassembly == NULL);
if (frag->msg_header.msg_len > (size_t)max) {
- OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_EXCESSIVE_MESSAGE_SIZE);
goto err;
}
- /* Reconstruct the assembled message. */
- size_t len;
CBB cbb;
- CBB_zero(&cbb);
if (!BUF_MEM_grow(s->init_buf,
(size_t)frag->msg_header.msg_len +
DTLS1_HM_HEADER_LENGTH) ||
- !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max) ||
- !CBB_add_u8(&cbb, frag->msg_header.type) ||
+ !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max)) {
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_MALLOC_FAILURE);
+ goto err;
+ }
+
+ /* Reconstruct the assembled message. */
+ size_t len;
+ if (!CBB_add_u8(&cbb, frag->msg_header.type) ||
!CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
!CBB_add_u16(&cbb, frag->msg_header.seq) ||
!CBB_add_u24(&cbb, 0 /* frag_off */) ||
@@ -600,7 +612,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
!CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) ||
!CBB_finish(&cbb, NULL, &len)) {
CBB_cleanup(&cbb);
- OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_INTERNAL_ERROR);
goto err;
}
assert(len == (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
@@ -616,7 +628,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
- OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
if (hash_message == ssl_hash_message && !ssl3_hash_current_message(s)) {