From ab063847e6e893740749029a04cce1f6b7345ed5 Mon Sep 17 00:00:00 2001 From: Mike Lockwood Date: Wed, 12 Nov 2014 14:20:06 -0800 Subject: MTP: add strict bounds checking for all incoming packets Previously we did not sanity check incoming MTP packets, which could result in crashes due to reading off the edge of a packet. Now all MTP packet getter functions return a boolean result (true for OK, false for reading off the edge of the packet) and we now return errors for malformed packets. Bug: 18113092 Change-Id: Ic7623ee96f00652bdfb4f66acb16a93db5a1c105 --- media/mtp/MtpDataPacket.cpp | 160 ++++++++++++++++++++++++++++++++------------ 1 file changed, 117 insertions(+), 43 deletions(-) (limited to 'media/mtp/MtpDataPacket.cpp') diff --git a/media/mtp/MtpDataPacket.cpp b/media/mtp/MtpDataPacket.cpp index e6e19e3..052b700 100644 --- a/media/mtp/MtpDataPacket.cpp +++ b/media/mtp/MtpDataPacket.cpp @@ -51,104 +51,178 @@ void MtpDataPacket::setTransactionID(MtpTransactionID id) { MtpPacket::putUInt32(MTP_CONTAINER_TRANSACTION_ID_OFFSET, id); } -uint16_t MtpDataPacket::getUInt16() { +bool MtpDataPacket::getUInt8(uint8_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; + value = mBuffer[mOffset++]; + return true; +} + +bool MtpDataPacket::getUInt16(uint16_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint16_t result = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8); - mOffset += 2; - return result; + value = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8); + mOffset += sizeof(value); + return true; } -uint32_t MtpDataPacket::getUInt32() { +bool MtpDataPacket::getUInt32(uint32_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint32_t result = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) | + value = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) | ((uint32_t)mBuffer[offset + 2] << 16) | ((uint32_t)mBuffer[offset + 3] << 24); - mOffset += 4; - return result; + mOffset += sizeof(value); + return true; } -uint64_t MtpDataPacket::getUInt64() { +bool MtpDataPacket::getUInt64(uint64_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint64_t result = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) | + value = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) | ((uint64_t)mBuffer[offset + 2] << 16) | ((uint64_t)mBuffer[offset + 3] << 24) | ((uint64_t)mBuffer[offset + 4] << 32) | ((uint64_t)mBuffer[offset + 5] << 40) | ((uint64_t)mBuffer[offset + 6] << 48) | ((uint64_t)mBuffer[offset + 7] << 56); - mOffset += 8; - return result; + mOffset += sizeof(value); + return true; } -void MtpDataPacket::getUInt128(uint128_t& value) { - value[0] = getUInt32(); - value[1] = getUInt32(); - value[2] = getUInt32(); - value[3] = getUInt32(); +bool MtpDataPacket::getUInt128(uint128_t& value) { + return getUInt32(value[0]) && getUInt32(value[1]) && getUInt32(value[2]) && getUInt32(value[3]); } -void MtpDataPacket::getString(MtpStringBuffer& string) +bool MtpDataPacket::getString(MtpStringBuffer& string) { - string.readFromPacket(this); + return string.readFromPacket(this); } Int8List* MtpDataPacket::getAInt8() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int8List* result = new Int8List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt8()); + for (uint32_t i = 0; i < count; i++) { + int8_t value; + if (!getInt8(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt8List* MtpDataPacket::getAUInt8() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt8List* result = new UInt8List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt8()); + for (uint32_t i = 0; i < count; i++) { + uint8_t value; + if (!getUInt8(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int16List* MtpDataPacket::getAInt16() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int16List* result = new Int16List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt16()); + for (uint32_t i = 0; i < count; i++) { + int16_t value; + if (!getInt16(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt16List* MtpDataPacket::getAUInt16() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt16List* result = new UInt16List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt16()); + for (uint32_t i = 0; i < count; i++) { + uint16_t value; + if (!getUInt16(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int32List* MtpDataPacket::getAInt32() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int32List* result = new Int32List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt32()); + for (uint32_t i = 0; i < count; i++) { + int32_t value; + if (!getInt32(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt32List* MtpDataPacket::getAUInt32() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt32List* result = new UInt32List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt32()); + for (uint32_t i = 0; i < count; i++) { + uint32_t value; + if (!getUInt32(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int64List* MtpDataPacket::getAInt64() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int64List* result = new Int64List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt64()); + for (uint32_t i = 0; i < count; i++) { + int64_t value; + if (!getInt64(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt64List* MtpDataPacket::getAUInt64() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt64List* result = new UInt64List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt64()); + for (uint32_t i = 0; i < count; i++) { + uint64_t value; + if (!getUInt64(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } -- cgit v1.1