summaryrefslogtreecommitdiffstats
path: root/src/ssl/d1_both.c
diff options
context:
space:
mode:
Diffstat (limited to 'src/ssl/d1_both.c')
-rw-r--r--src/ssl/d1_both.c112
1 files changed, 50 insertions, 62 deletions
diff --git a/src/ssl/d1_both.c b/src/ssl/d1_both.c
index ac35a66..1acb3ce 100644
--- a/src/ssl/d1_both.c
+++ b/src/ssl/d1_both.c
@@ -111,6 +111,8 @@
* copied and put under another distribution licence
* [including the GNU Public Licence.] */
+#include <openssl/ssl.h>
+
#include <assert.h>
#include <limits.h>
#include <stdio.h>
@@ -147,52 +149,44 @@ static void dtls1_fix_message_header(SSL *s, unsigned long frag_off,
unsigned long frag_len);
static unsigned char *dtls1_write_message_header(SSL *s, unsigned char *p);
-static hm_fragment *dtls1_hm_fragment_new(unsigned long frag_len,
- int reassembly) {
- hm_fragment *frag = NULL;
- uint8_t *buf = NULL;
- uint8_t *bitmask = NULL;
-
- frag = (hm_fragment *)OPENSSL_malloc(sizeof(hm_fragment));
+static hm_fragment *dtls1_hm_fragment_new(size_t frag_len, int reassembly) {
+ hm_fragment *frag = OPENSSL_malloc(sizeof(hm_fragment));
if (frag == NULL) {
- OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
+ OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
return NULL;
}
+ memset(frag, 0, sizeof(hm_fragment));
- if (frag_len) {
- buf = (uint8_t *)OPENSSL_malloc(frag_len);
- if (buf == NULL) {
- OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
- OPENSSL_free(frag);
- return NULL;
+ /* If the handshake message is empty, |frag->fragment| and |frag->reassembly|
+ * are NULL. */
+ if (frag_len > 0) {
+ frag->fragment = OPENSSL_malloc(frag_len);
+ if (frag->fragment == NULL) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+ goto err;
}
- }
-
- /* zero length fragment gets zero frag->fragment */
- frag->fragment = buf;
- /* Initialize reassembly bitmask if necessary */
- if (reassembly && frag_len > 0) {
- if (frag_len + 7 < frag_len) {
- OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_OVERFLOW);
- return NULL;
- }
- size_t bitmask_len = (frag_len + 7) / 8;
- bitmask = (uint8_t *)OPENSSL_malloc(bitmask_len);
- if (bitmask == NULL) {
- OPENSSL_PUT_ERROR(SSL, dtls1_hm_fragment_new, ERR_R_MALLOC_FAILURE);
- if (buf != NULL) {
- OPENSSL_free(buf);
+ if (reassembly) {
+ /* Initialize reassembly bitmask. */
+ if (frag_len + 7 < frag_len) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_OVERFLOW);
+ goto err;
}
- OPENSSL_free(frag);
- return NULL;
+ size_t bitmask_len = (frag_len + 7) / 8;
+ frag->reassembly = OPENSSL_malloc(bitmask_len);
+ if (frag->reassembly == NULL) {
+ OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
+ goto err;
+ }
+ memset(frag->reassembly, 0, bitmask_len);
}
- memset(bitmask, 0, bitmask_len);
}
- frag->reassembly = bitmask;
-
return frag;
+
+err:
+ dtls1_hm_fragment_free(frag);
+ return NULL;
}
void dtls1_hm_fragment_free(hm_fragment *frag) {
@@ -326,7 +320,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
if (curr_mtu <= DTLS1_HM_HEADER_LENGTH) {
/* To make forward progress, the MTU must, at minimum, fit the handshake
* header and one byte of handshake body. */
- OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
return -1;
}
@@ -344,7 +338,7 @@ int dtls1_do_write(SSL *s, int type, enum dtls1_use_epoch_t use_epoch) {
assert(type == SSL3_RT_CHANGE_CIPHER_SPEC);
/* ChangeCipherSpec cannot be fragmented. */
if (s->init_num > curr_mtu) {
- OPENSSL_PUT_ERROR(SSL, dtls1_do_write, SSL_R_MTU_TOO_SMALL);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_MTU_TOO_SMALL);
return -1;
}
len = s->init_num;
@@ -450,8 +444,7 @@ static hm_fragment *dtls1_get_buffered_message(
frag->msg_header.msg_len != msg_hdr->msg_len) {
/* The new fragment must be compatible with the previous fragments from
* this message. */
- OPENSSL_PUT_ERROR(SSL, dtls1_get_buffered_message,
- SSL_R_FRAGMENT_MISMATCH);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_FRAGMENT_MISMATCH);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return NULL;
}
@@ -473,11 +466,7 @@ static size_t dtls1_max_handshake_message_len(const SSL *s) {
/* dtls1_process_fragment reads a handshake fragment and processes it. It
* returns one if a fragment was successfully processed and 0 or -1 on error. */
static int dtls1_process_fragment(SSL *s) {
- /* Read handshake message header.
- *
- * TODO(davidben): ssl_read_bytes allows splitting the fragment header and
- * body across two records. Change this interface to consume the fragment in
- * one pass. */
+ /* Read handshake message header. */
uint8_t header[DTLS1_HM_HEADER_LENGTH];
int ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, header,
DTLS1_HM_HEADER_LENGTH, 0);
@@ -485,7 +474,7 @@ static int dtls1_process_fragment(SSL *s) {
return ret;
}
if (ret != DTLS1_HM_HEADER_LENGTH) {
- OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
return -1;
}
@@ -494,14 +483,16 @@ static int dtls1_process_fragment(SSL *s) {
struct hm_header_st msg_hdr;
dtls1_get_message_header(header, &msg_hdr);
+ /* TODO(davidben): dtls1_read_bytes is the wrong abstraction for DTLS. There
+ * should be no need to reach into |s->s3->rrec.length|. */
const size_t frag_off = msg_hdr.frag_off;
const size_t frag_len = msg_hdr.frag_len;
const size_t msg_len = msg_hdr.msg_len;
if (frag_off > msg_len || frag_off + frag_len < frag_off ||
frag_off + frag_len > msg_len ||
- msg_len > dtls1_max_handshake_message_len(s)) {
- OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment,
- SSL_R_EXCESSIVE_MESSAGE_SIZE);
+ msg_len > dtls1_max_handshake_message_len(s) ||
+ frag_len > s->s3->rrec.length) {
+ OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_ILLEGAL_PARAMETER);
return -1;
}
@@ -535,8 +526,8 @@ static int dtls1_process_fragment(SSL *s) {
ret = dtls1_read_bytes(s, SSL3_RT_HANDSHAKE, frag->fragment + frag_off,
frag_len, 0);
if (ret != frag_len) {
- OPENSSL_PUT_ERROR(SSL, dtls1_process_fragment, SSL_R_UNEXPECTED_MESSAGE);
- ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, ERR_R_INTERNAL_ERROR);
+ ssl3_send_alert(s, SSL3_AL_FATAL, SSL_AD_INTERNAL_ERROR);
return -1;
}
dtls1_hm_fragment_mark(frag, frag_off, frag_off + frag_len);
@@ -563,7 +554,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
s->s3->tmp.reuse_message = 0;
if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
- OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
*ok = 1;
@@ -589,22 +580,19 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
assert(frag->reassembly == NULL);
if (frag->msg_header.msg_len > (size_t)max) {
- OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_EXCESSIVE_MESSAGE_SIZE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_EXCESSIVE_MESSAGE_SIZE);
goto err;
}
+ /* Reconstruct the assembled message. */
+ size_t len;
CBB cbb;
+ CBB_zero(&cbb);
if (!BUF_MEM_grow(s->init_buf,
(size_t)frag->msg_header.msg_len +
DTLS1_HM_HEADER_LENGTH) ||
- !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max)) {
- OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_MALLOC_FAILURE);
- goto err;
- }
-
- /* Reconstruct the assembled message. */
- size_t len;
- if (!CBB_add_u8(&cbb, frag->msg_header.type) ||
+ !CBB_init_fixed(&cbb, (uint8_t *)s->init_buf->data, s->init_buf->max) ||
+ !CBB_add_u8(&cbb, frag->msg_header.type) ||
!CBB_add_u24(&cbb, frag->msg_header.msg_len) ||
!CBB_add_u16(&cbb, frag->msg_header.seq) ||
!CBB_add_u24(&cbb, 0 /* frag_off */) ||
@@ -612,7 +600,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
!CBB_add_bytes(&cbb, frag->fragment, frag->msg_header.msg_len) ||
!CBB_finish(&cbb, NULL, &len)) {
CBB_cleanup(&cbb);
- OPENSSL_PUT_ERROR(SSL, dtls1_get_message, ERR_R_INTERNAL_ERROR);
+ OPENSSL_PUT_ERROR(SSL, ERR_R_MALLOC_FAILURE);
goto err;
}
assert(len == (size_t)frag->msg_header.msg_len + DTLS1_HM_HEADER_LENGTH);
@@ -628,7 +616,7 @@ long dtls1_get_message(SSL *s, int st1, int stn, int msg_type, long max,
if (msg_type >= 0 && s->s3->tmp.message_type != msg_type) {
al = SSL_AD_UNEXPECTED_MESSAGE;
- OPENSSL_PUT_ERROR(SSL, dtls1_get_message, SSL_R_UNEXPECTED_MESSAGE);
+ OPENSSL_PUT_ERROR(SSL, SSL_R_UNEXPECTED_MESSAGE);
goto f_err;
}
if (hash_message == ssl_hash_message && !ssl3_hash_current_message(s)) {