From ab063847e6e893740749029a04cce1f6b7345ed5 Mon Sep 17 00:00:00 2001 From: Mike Lockwood <lockwood@google.com> Date: Wed, 12 Nov 2014 14:20:06 -0800 Subject: MTP: add strict bounds checking for all incoming packets Previously we did not sanity check incoming MTP packets, which could result in crashes due to reading off the edge of a packet. Now all MTP packet getter functions return a boolean result (true for OK, false for reading off the edge of the packet) and we now return errors for malformed packets. Bug: 18113092 Change-Id: Ic7623ee96f00652bdfb4f66acb16a93db5a1c105 --- media/mtp/MtpDataPacket.cpp | 160 ++++++++++++++++++++++++++++++----------- media/mtp/MtpDataPacket.h | 25 +++---- media/mtp/MtpDevice.cpp | 32 ++++++--- media/mtp/MtpDevice.h | 2 +- media/mtp/MtpDeviceInfo.cpp | 33 +++++---- media/mtp/MtpDeviceInfo.h | 4 +- media/mtp/MtpObjectInfo.cpp | 42 +++++------ media/mtp/MtpObjectInfo.h | 2 +- media/mtp/MtpPacket.cpp | 2 +- media/mtp/MtpPacket.h | 8 +-- media/mtp/MtpProperty.cpp | 93 ++++++++++++++---------- media/mtp/MtpProperty.h | 14 ++-- media/mtp/MtpRequestPacket.cpp | 20 ++++-- media/mtp/MtpRequestPacket.h | 4 ++ media/mtp/MtpServer.cpp | 97 ++++++++++++++++++++----- media/mtp/MtpStorageInfo.cpp | 20 +++--- media/mtp/MtpStorageInfo.h | 2 +- media/mtp/MtpStringBuffer.cpp | 13 +++- media/mtp/MtpStringBuffer.h | 2 +- 19 files changed, 386 insertions(+), 189 deletions(-) (limited to 'media/mtp') diff --git a/media/mtp/MtpDataPacket.cpp b/media/mtp/MtpDataPacket.cpp index e6e19e3..052b700 100644 --- a/media/mtp/MtpDataPacket.cpp +++ b/media/mtp/MtpDataPacket.cpp @@ -51,104 +51,178 @@ void MtpDataPacket::setTransactionID(MtpTransactionID id) { MtpPacket::putUInt32(MTP_CONTAINER_TRANSACTION_ID_OFFSET, id); } -uint16_t MtpDataPacket::getUInt16() { +bool MtpDataPacket::getUInt8(uint8_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; + value = mBuffer[mOffset++]; + return true; +} + +bool MtpDataPacket::getUInt16(uint16_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint16_t result = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8); - mOffset += 2; - return result; + value = (uint16_t)mBuffer[offset] | ((uint16_t)mBuffer[offset + 1] << 8); + mOffset += sizeof(value); + return true; } -uint32_t MtpDataPacket::getUInt32() { +bool MtpDataPacket::getUInt32(uint32_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint32_t result = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) | + value = (uint32_t)mBuffer[offset] | ((uint32_t)mBuffer[offset + 1] << 8) | ((uint32_t)mBuffer[offset + 2] << 16) | ((uint32_t)mBuffer[offset + 3] << 24); - mOffset += 4; - return result; + mOffset += sizeof(value); + return true; } -uint64_t MtpDataPacket::getUInt64() { +bool MtpDataPacket::getUInt64(uint64_t& value) { + if (mPacketSize - mOffset < sizeof(value)) + return false; int offset = mOffset; - uint64_t result = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) | + value = (uint64_t)mBuffer[offset] | ((uint64_t)mBuffer[offset + 1] << 8) | ((uint64_t)mBuffer[offset + 2] << 16) | ((uint64_t)mBuffer[offset + 3] << 24) | ((uint64_t)mBuffer[offset + 4] << 32) | ((uint64_t)mBuffer[offset + 5] << 40) | ((uint64_t)mBuffer[offset + 6] << 48) | ((uint64_t)mBuffer[offset + 7] << 56); - mOffset += 8; - return result; + mOffset += sizeof(value); + return true; } -void MtpDataPacket::getUInt128(uint128_t& value) { - value[0] = getUInt32(); - value[1] = getUInt32(); - value[2] = getUInt32(); - value[3] = getUInt32(); +bool MtpDataPacket::getUInt128(uint128_t& value) { + return getUInt32(value[0]) && getUInt32(value[1]) && getUInt32(value[2]) && getUInt32(value[3]); } -void MtpDataPacket::getString(MtpStringBuffer& string) +bool MtpDataPacket::getString(MtpStringBuffer& string) { - string.readFromPacket(this); + return string.readFromPacket(this); } Int8List* MtpDataPacket::getAInt8() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int8List* result = new Int8List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt8()); + for (uint32_t i = 0; i < count; i++) { + int8_t value; + if (!getInt8(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt8List* MtpDataPacket::getAUInt8() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt8List* result = new UInt8List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt8()); + for (uint32_t i = 0; i < count; i++) { + uint8_t value; + if (!getUInt8(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int16List* MtpDataPacket::getAInt16() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int16List* result = new Int16List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt16()); + for (uint32_t i = 0; i < count; i++) { + int16_t value; + if (!getInt16(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt16List* MtpDataPacket::getAUInt16() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt16List* result = new UInt16List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt16()); + for (uint32_t i = 0; i < count; i++) { + uint16_t value; + if (!getUInt16(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int32List* MtpDataPacket::getAInt32() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int32List* result = new Int32List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt32()); + for (uint32_t i = 0; i < count; i++) { + int32_t value; + if (!getInt32(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt32List* MtpDataPacket::getAUInt32() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt32List* result = new UInt32List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt32()); + for (uint32_t i = 0; i < count; i++) { + uint32_t value; + if (!getUInt32(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } Int64List* MtpDataPacket::getAInt64() { + uint32_t count; + if (!getUInt32(count)) + return NULL; Int64List* result = new Int64List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getInt64()); + for (uint32_t i = 0; i < count; i++) { + int64_t value; + if (!getInt64(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } UInt64List* MtpDataPacket::getAUInt64() { + uint32_t count; + if (!getUInt32(count)) + return NULL; UInt64List* result = new UInt64List; - int count = getUInt32(); - for (int i = 0; i < count; i++) - result->push(getUInt64()); + for (uint32_t i = 0; i < count; i++) { + uint64_t value; + if (!getUInt64(value)) { + delete result; + return NULL; + } + result->push(value); + } return result; } diff --git a/media/mtp/MtpDataPacket.h b/media/mtp/MtpDataPacket.h index 2b81063..13d3bd9 100644 --- a/media/mtp/MtpDataPacket.h +++ b/media/mtp/MtpDataPacket.h @@ -30,7 +30,7 @@ class MtpStringBuffer; class MtpDataPacket : public MtpPacket { private: // current offset for get/put methods - int mOffset; + size_t mOffset; public: MtpDataPacket(); @@ -42,17 +42,18 @@ public: void setTransactionID(MtpTransactionID id); inline const uint8_t* getData() const { return mBuffer + MTP_CONTAINER_HEADER_SIZE; } - inline uint8_t getUInt8() { return (uint8_t)mBuffer[mOffset++]; } - inline int8_t getInt8() { return (int8_t)mBuffer[mOffset++]; } - uint16_t getUInt16(); - inline int16_t getInt16() { return (int16_t)getUInt16(); } - uint32_t getUInt32(); - inline int32_t getInt32() { return (int32_t)getUInt32(); } - uint64_t getUInt64(); - inline int64_t getInt64() { return (int64_t)getUInt64(); } - void getUInt128(uint128_t& value); - inline void getInt128(int128_t& value) { getUInt128((uint128_t&)value); } - void getString(MtpStringBuffer& string); + + bool getUInt8(uint8_t& value); + inline bool getInt8(int8_t& value) { return getUInt8((uint8_t&)value); } + bool getUInt16(uint16_t& value); + inline bool getInt16(int16_t& value) { return getUInt16((uint16_t&)value); } + bool getUInt32(uint32_t& value); + inline bool getInt32(int32_t& value) { return getUInt32((uint32_t&)value); } + bool getUInt64(uint64_t& value); + inline bool getInt64(int64_t& value) { return getUInt64((uint64_t&)value); } + bool getUInt128(uint128_t& value); + inline bool getInt128(int128_t& value) { return getUInt128((uint128_t&)value); } + bool getString(MtpStringBuffer& string); Int8List* getAInt8(); UInt8List* getAUInt8(); diff --git a/media/mtp/MtpDevice.cpp b/media/mtp/MtpDevice.cpp index d6d5dd5..e0d679d 100644 --- a/media/mtp/MtpDevice.cpp +++ b/media/mtp/MtpDevice.cpp @@ -313,8 +313,10 @@ MtpDeviceInfo* MtpDevice::getDeviceInfo() { MtpResponseCode ret = readResponse(); if (ret == MTP_RESPONSE_OK) { MtpDeviceInfo* info = new MtpDeviceInfo; - info->read(mData); - return info; + if (info->read(mData)) + return info; + else + delete info; } return NULL; } @@ -346,8 +348,10 @@ MtpStorageInfo* MtpDevice::getStorageInfo(MtpStorageID storageID) { MtpResponseCode ret = readResponse(); if (ret == MTP_RESPONSE_OK) { MtpStorageInfo* info = new MtpStorageInfo(storageID); - info->read(mData); - return info; + if (info->read(mData)) + return info; + else + delete info; } return NULL; } @@ -385,8 +389,10 @@ MtpObjectInfo* MtpDevice::getObjectInfo(MtpObjectHandle handle) { MtpResponseCode ret = readResponse(); if (ret == MTP_RESPONSE_OK) { MtpObjectInfo* info = new MtpObjectInfo(handle); - info->read(mData); - return info; + if (info->read(mData)) + return info; + else + delete info; } return NULL; } @@ -547,8 +553,10 @@ MtpProperty* MtpDevice::getDevicePropDesc(MtpDeviceProperty code) { MtpResponseCode ret = readResponse(); if (ret == MTP_RESPONSE_OK) { MtpProperty* property = new MtpProperty; - property->read(mData); - return property; + if (property->read(mData)) + return property; + else + delete property; } return NULL; } @@ -566,15 +574,17 @@ MtpProperty* MtpDevice::getObjectPropDesc(MtpObjectProperty code, MtpObjectForma MtpResponseCode ret = readResponse(); if (ret == MTP_RESPONSE_OK) { MtpProperty* property = new MtpProperty; - property->read(mData); - return property; + if (property->read(mData)) + return property; + else + delete property; } return NULL; } bool MtpDevice::readObject(MtpObjectHandle handle, bool (* callback)(void* data, int offset, int length, void* clientData), - int objectSize, void* clientData) { + size_t objectSize, void* clientData) { Mutex::Autolock autoLock(mMutex); bool result = false; diff --git a/media/mtp/MtpDevice.h b/media/mtp/MtpDevice.h index b69203e..9b0acbf 100644 --- a/media/mtp/MtpDevice.h +++ b/media/mtp/MtpDevice.h @@ -98,7 +98,7 @@ public: bool readObject(MtpObjectHandle handle, bool (* callback)(void* data, int offset, int length, void* clientData), - int objectSize, void* clientData); + size_t objectSize, void* clientData); bool readObject(MtpObjectHandle handle, const char* destPath, int group, int perm); diff --git a/media/mtp/MtpDeviceInfo.cpp b/media/mtp/MtpDeviceInfo.cpp index 108e2b8..3e1dff7 100644 --- a/media/mtp/MtpDeviceInfo.cpp +++ b/media/mtp/MtpDeviceInfo.cpp @@ -28,7 +28,7 @@ MtpDeviceInfo::MtpDeviceInfo() mVendorExtensionID(0), mVendorExtensionVersion(0), mVendorExtensionDesc(NULL), - mFunctionalCode(0), + mFunctionalMode(0), mOperations(NULL), mEvents(NULL), mDeviceProperties(NULL), @@ -59,39 +59,46 @@ MtpDeviceInfo::~MtpDeviceInfo() { free(mSerial); } -void MtpDeviceInfo::read(MtpDataPacket& packet) { +bool MtpDeviceInfo::read(MtpDataPacket& packet) { MtpStringBuffer string; // read the device info - mStandardVersion = packet.getUInt16(); - mVendorExtensionID = packet.getUInt32(); - mVendorExtensionVersion = packet.getUInt16(); + if (!packet.getUInt16(mStandardVersion)) return false; + if (!packet.getUInt32(mVendorExtensionID)) return false; + if (!packet.getUInt16(mVendorExtensionVersion)) return false; - packet.getString(string); + if (!packet.getString(string)) return false; mVendorExtensionDesc = strdup((const char *)string); - mFunctionalCode = packet.getUInt16(); + if (!packet.getUInt16(mFunctionalMode)) return false; mOperations = packet.getAUInt16(); + if (!mOperations) return false; mEvents = packet.getAUInt16(); + if (!mEvents) return false; mDeviceProperties = packet.getAUInt16(); + if (!mDeviceProperties) return false; mCaptureFormats = packet.getAUInt16(); + if (!mCaptureFormats) return false; mPlaybackFormats = packet.getAUInt16(); + if (!mCaptureFormats) return false; - packet.getString(string); + if (!packet.getString(string)) return false; mManufacturer = strdup((const char *)string); - packet.getString(string); + if (!packet.getString(string)) return false; mModel = strdup((const char *)string); - packet.getString(string); + if (!packet.getString(string)) return false; mVersion = strdup((const char *)string); - packet.getString(string); + if (!packet.getString(string)) return false; mSerial = strdup((const char *)string); + + return true; } void MtpDeviceInfo::print() { ALOGV("Device Info:\n\tmStandardVersion: %d\n\tmVendorExtensionID: %d\n\tmVendorExtensionVersiony: %d\n", mStandardVersion, mVendorExtensionID, mVendorExtensionVersion); - ALOGV("\tmVendorExtensionDesc: %s\n\tmFunctionalCode: %d\n\tmManufacturer: %s\n\tmModel: %s\n\tmVersion: %s\n\tmSerial: %s\n", - mVendorExtensionDesc, mFunctionalCode, mManufacturer, mModel, mVersion, mSerial); + ALOGV("\tmVendorExtensionDesc: %s\n\tmFunctionalMode: %d\n\tmManufacturer: %s\n\tmModel: %s\n\tmVersion: %s\n\tmSerial: %s\n", + mVendorExtensionDesc, mFunctionalMode, mManufacturer, mModel, mVersion, mSerial); } } // namespace android diff --git a/media/mtp/MtpDeviceInfo.h b/media/mtp/MtpDeviceInfo.h index 2abaa10..bcda9a5 100644 --- a/media/mtp/MtpDeviceInfo.h +++ b/media/mtp/MtpDeviceInfo.h @@ -29,7 +29,7 @@ public: uint32_t mVendorExtensionID; uint16_t mVendorExtensionVersion; char* mVendorExtensionDesc; - uint16_t mFunctionalCode; + uint16_t mFunctionalMode; UInt16List* mOperations; UInt16List* mEvents; MtpDevicePropertyList* mDeviceProperties; @@ -44,7 +44,7 @@ public: MtpDeviceInfo(); virtual ~MtpDeviceInfo(); - void read(MtpDataPacket& packet); + bool read(MtpDataPacket& packet); void print(); }; diff --git a/media/mtp/MtpObjectInfo.cpp b/media/mtp/MtpObjectInfo.cpp index cd15343..0573104 100644 --- a/media/mtp/MtpObjectInfo.cpp +++ b/media/mtp/MtpObjectInfo.cpp @@ -55,39 +55,41 @@ MtpObjectInfo::~MtpObjectInfo() { free(mKeywords); } -void MtpObjectInfo::read(MtpDataPacket& packet) { +bool MtpObjectInfo::read(MtpDataPacket& packet) { MtpStringBuffer string; time_t time; - mStorageID = packet.getUInt32(); - mFormat = packet.getUInt16(); - mProtectionStatus = packet.getUInt16(); - mCompressedSize = packet.getUInt32(); - mThumbFormat = packet.getUInt16(); - mThumbCompressedSize = packet.getUInt32(); - mThumbPixWidth = packet.getUInt32(); - mThumbPixHeight = packet.getUInt32(); - mImagePixWidth = packet.getUInt32(); - mImagePixHeight = packet.getUInt32(); - mImagePixDepth = packet.getUInt32(); - mParent = packet.getUInt32(); - mAssociationType = packet.getUInt16(); - mAssociationDesc = packet.getUInt32(); - mSequenceNumber = packet.getUInt32(); + if (!packet.getUInt32(mStorageID)) return false; + if (!packet.getUInt16(mFormat)) return false; + if (!packet.getUInt16(mProtectionStatus)) return false; + if (!packet.getUInt32(mCompressedSize)) return false; + if (!packet.getUInt16(mThumbFormat)) return false; + if (!packet.getUInt32(mThumbCompressedSize)) return false; + if (!packet.getUInt32(mThumbPixWidth)) return false; + if (!packet.getUInt32(mThumbPixHeight)) return false; + if (!packet.getUInt32(mImagePixWidth)) return false; + if (!packet.getUInt32(mImagePixHeight)) return false; + if (!packet.getUInt32(mImagePixDepth)) return false; + if (!packet.getUInt32(mParent)) return false; + if (!packet.getUInt16(mAssociationType)) return false; + if (!packet.getUInt32(mAssociationDesc)) return false; + if (!packet.getUInt32(mSequenceNumber)) return false; - packet.getString(string); + if (!packet.getString(string)) return false; mName = strdup((const char *)string); - packet.getString(string); + if (!packet.getString(string)) return false; if (parseDateTime((const char*)string, time)) mDateCreated = time; - packet.getString(string); + if (!packet.getString(string)) return false; if (parseDateTime((const char*)string, time)) mDateModified = time; - packet.getString(string); + if (!packet.getString(string)) return false; mKeywords = strdup((const char *)string); + + return true; } void MtpObjectInfo::print() { diff --git a/media/mtp/MtpObjectInfo.h b/media/mtp/MtpObjectInfo.h index c7a449c..86780f1 100644 --- a/media/mtp/MtpObjectInfo.h +++ b/media/mtp/MtpObjectInfo.h @@ -50,7 +50,7 @@ public: MtpObjectInfo(MtpObjectHandle handle); virtual ~MtpObjectInfo(); - void read(MtpDataPacket& packet); + bool read(MtpDataPacket& packet); void print(); }; diff --git a/media/mtp/MtpPacket.cpp b/media/mtp/MtpPacket.cpp index dd07843..bab1335 100644 --- a/media/mtp/MtpPacket.cpp +++ b/media/mtp/MtpPacket.cpp @@ -52,7 +52,7 @@ void MtpPacket::reset() { memset(mBuffer, 0, mBufferSize); } -void MtpPacket::allocate(int length) { +void MtpPacket::allocate(size_t length) { if (length > mBufferSize) { int newLength = length + mAllocationIncrement; mBuffer = (uint8_t *)realloc(mBuffer, newLength); diff --git a/media/mtp/MtpPacket.h b/media/mtp/MtpPacket.h index 0ffb1d3..037722a 100644 --- a/media/mtp/MtpPacket.h +++ b/media/mtp/MtpPacket.h @@ -28,11 +28,11 @@ class MtpPacket { protected: uint8_t* mBuffer; // current size of the buffer - int mBufferSize; + size_t mBufferSize; // number of bytes to add when resizing the buffer - int mAllocationIncrement; + size_t mAllocationIncrement; // size of the data in the packet - int mPacketSize; + size_t mPacketSize; public: MtpPacket(int bufferSize); @@ -41,7 +41,7 @@ public: // sets packet size to the default container size and sets buffer to zero virtual void reset(); - void allocate(int length); + void allocate(size_t length); void dump(); void copyFrom(const MtpPacket& src); diff --git a/media/mtp/MtpProperty.cpp b/media/mtp/MtpProperty.cpp index c500901..d58e2a4 100644 --- a/media/mtp/MtpProperty.cpp +++ b/media/mtp/MtpProperty.cpp @@ -106,15 +106,15 @@ MtpProperty::~MtpProperty() { free(mMinimumValue.str); free(mMaximumValue.str); if (mDefaultArrayValues) { - for (int i = 0; i < mDefaultArrayLength; i++) + for (uint32_t i = 0; i < mDefaultArrayLength; i++) free(mDefaultArrayValues[i].str); } if (mCurrentArrayValues) { - for (int i = 0; i < mCurrentArrayLength; i++) + for (uint32_t i = 0; i < mCurrentArrayLength; i++) free(mCurrentArrayValues[i].str); } if (mEnumValues) { - for (int i = 0; i < mEnumLength; i++) + for (uint16_t i = 0; i < mEnumLength; i++) free(mEnumValues[i].str); } } @@ -123,11 +123,14 @@ MtpProperty::~MtpProperty() { delete[] mEnumValues; } -void MtpProperty::read(MtpDataPacket& packet) { - mCode = packet.getUInt16(); +bool MtpProperty::read(MtpDataPacket& packet) { + uint8_t temp8; + + if (!packet.getUInt16(mCode)) return false; bool deviceProp = isDeviceProperty(); - mType = packet.getUInt16(); - mWriteable = (packet.getUInt8() == 1); + if (!packet.getUInt16(mType)) return false; + if (!packet.getUInt8(temp8)) return false; + mWriteable = (temp8 == 1); switch (mType) { case MTP_TYPE_AINT8: case MTP_TYPE_AUINT8: @@ -140,28 +143,36 @@ void MtpProperty::read(MtpDataPacket& packet) { case MTP_TYPE_AINT128: case MTP_TYPE_AUINT128: mDefaultArrayValues = readArrayValues(packet, mDefaultArrayLength); - if (deviceProp) + if (!mDefaultArrayValues) return false; + if (deviceProp) { mCurrentArrayValues = readArrayValues(packet, mCurrentArrayLength); + if (!mCurrentArrayValues) return false; + } break; default: - readValue(packet, mDefaultValue); - if (deviceProp) - readValue(packet, mCurrentValue); + if (!readValue(packet, mDefaultValue)) return false; + if (deviceProp) { + if (!readValue(packet, mCurrentValue)) return false; + } } - if (!deviceProp) - mGroupCode = packet.getUInt32(); - mFormFlag = packet.getUInt8(); + if (!deviceProp) { + if (!packet.getUInt32(mGroupCode)) return false; + } + if (!packet.getUInt8(mFormFlag)) return false; if (mFormFlag == kFormRange) { - readValue(packet, mMinimumValue); - readValue(packet, mMaximumValue); - readValue(packet, mStepSize); + if (!readValue(packet, mMinimumValue)) return false; + if (!readValue(packet, mMaximumValue)) return false; + if (!readValue(packet, mStepSize)) return false; } else if (mFormFlag == kFormEnum) { - mEnumLength = packet.getUInt16(); + if (!packet.getUInt16(mEnumLength)) return false; mEnumValues = new MtpPropertyValue[mEnumLength]; - for (int i = 0; i < mEnumLength; i++) - readValue(packet, mEnumValues[i]); + for (int i = 0; i < mEnumLength; i++) { + if (!readValue(packet, mEnumValues[i])) return false; + } } + + return true; } void MtpProperty::write(MtpDataPacket& packet) { @@ -409,57 +420,59 @@ void MtpProperty::print(MtpPropertyValue& value, MtpString& buffer) { } } -void MtpProperty::readValue(MtpDataPacket& packet, MtpPropertyValue& value) { +bool MtpProperty::readValue(MtpDataPacket& packet, MtpPropertyValue& value) { MtpStringBuffer stringBuffer; switch (mType) { case MTP_TYPE_INT8: case MTP_TYPE_AINT8: - value.u.i8 = packet.getInt8(); + if (!packet.getInt8(value.u.i8)) return false; break; case MTP_TYPE_UINT8: case MTP_TYPE_AUINT8: - value.u.u8 = packet.getUInt8(); + if (!packet.getUInt8(value.u.u8)) return false; break; case MTP_TYPE_INT16: case MTP_TYPE_AINT16: - value.u.i16 = packet.getInt16(); + if (!packet.getInt16(value.u.i16)) return false; break; case MTP_TYPE_UINT16: case MTP_TYPE_AUINT16: - value.u.u16 = packet.getUInt16(); + if (!packet.getUInt16(value.u.u16)) return false; break; case MTP_TYPE_INT32: case MTP_TYPE_AINT32: - value.u.i32 = packet.getInt32(); + if (!packet.getInt32(value.u.i32)) return false; break; case MTP_TYPE_UINT32: case MTP_TYPE_AUINT32: - value.u.u32 = packet.getUInt32(); + if (!packet.getUInt32(value.u.u32)) return false; break; case MTP_TYPE_INT64: case MTP_TYPE_AINT64: - value.u.i64 = packet.getInt64(); + if (!packet.getInt64(value.u.i64)) return false; break; case MTP_TYPE_UINT64: case MTP_TYPE_AUINT64: - value.u.u64 = packet.getUInt64(); + if (!packet.getUInt64(value.u.u64)) return false; break; case MTP_TYPE_INT128: case MTP_TYPE_AINT128: - packet.getInt128(value.u.i128); + if (!packet.getInt128(value.u.i128)) return false; break; case MTP_TYPE_UINT128: case MTP_TYPE_AUINT128: - packet.getUInt128(value.u.u128); + if (!packet.getUInt128(value.u.u128)) return false; break; case MTP_TYPE_STR: - packet.getString(stringBuffer); + if (!packet.getString(stringBuffer)) return false; value.str = strdup(stringBuffer); break; default: ALOGE("unknown type %04X in MtpProperty::readValue", mType); + return false; } + return true; } void MtpProperty::writeValue(MtpDataPacket& packet, MtpPropertyValue& value) { @@ -517,8 +530,9 @@ void MtpProperty::writeValue(MtpDataPacket& packet, MtpPropertyValue& value) { } } -MtpPropertyValue* MtpProperty::readArrayValues(MtpDataPacket& packet, int& length) { - length = packet.getUInt32(); +MtpPropertyValue* MtpProperty::readArrayValues(MtpDataPacket& packet, uint32_t& length) { + if (!packet.getUInt32(length)) return NULL; + // Fail if resulting array is over 2GB. This is because the maximum array // size may be less than SIZE_MAX on some platforms. if ( CC_UNLIKELY( @@ -528,14 +542,17 @@ MtpPropertyValue* MtpProperty::readArrayValues(MtpDataPacket& packet, int& lengt return NULL; } MtpPropertyValue* result = new MtpPropertyValue[length]; - for (int i = 0; i < length; i++) - readValue(packet, result[i]); + for (uint32_t i = 0; i < length; i++) + if (!readValue(packet, result[i])) { + delete result; + return NULL; + } return result; } -void MtpProperty::writeArrayValues(MtpDataPacket& packet, MtpPropertyValue* values, int length) { +void MtpProperty::writeArrayValues(MtpDataPacket& packet, MtpPropertyValue* values, uint32_t length) { packet.putUInt32(length); - for (int i = 0; i < length; i++) + for (uint32_t i = 0; i < length; i++) writeValue(packet, values[i]); } diff --git a/media/mtp/MtpProperty.h b/media/mtp/MtpProperty.h index 06ca56e..2e2ead1 100644 --- a/media/mtp/MtpProperty.h +++ b/media/mtp/MtpProperty.h @@ -49,9 +49,9 @@ public: MtpPropertyValue mCurrentValue; // for array types - int mDefaultArrayLength; + uint32_t mDefaultArrayLength; MtpPropertyValue* mDefaultArrayValues; - int mCurrentArrayLength; + uint32_t mCurrentArrayLength; MtpPropertyValue* mCurrentArrayValues; enum { @@ -70,7 +70,7 @@ public: MtpPropertyValue mStepSize; // for enum form - int mEnumLength; + uint16_t mEnumLength; MtpPropertyValue* mEnumValues; public: @@ -83,7 +83,7 @@ public: inline MtpPropertyCode getPropertyCode() const { return mCode; } - void read(MtpDataPacket& packet); + bool read(MtpDataPacket& packet); void write(MtpDataPacket& packet); void setDefaultValue(const uint16_t* string); @@ -102,11 +102,11 @@ public: } private: - void readValue(MtpDataPacket& packet, MtpPropertyValue& value); + bool readValue(MtpDataPacket& packet, MtpPropertyValue& value); void writeValue(MtpDataPacket& packet, MtpPropertyValue& value); - MtpPropertyValue* readArrayValues(MtpDataPacket& packet, int& length); + MtpPropertyValue* readArrayValues(MtpDataPacket& packet, uint32_t& length); void writeArrayValues(MtpDataPacket& packet, - MtpPropertyValue* values, int length); + MtpPropertyValue* values, uint32_t length); }; }; // namespace android diff --git a/media/mtp/MtpRequestPacket.cpp b/media/mtp/MtpRequestPacket.cpp index 0e58e01..40b11b0 100644 --- a/media/mtp/MtpRequestPacket.cpp +++ b/media/mtp/MtpRequestPacket.cpp @@ -27,7 +27,8 @@ namespace android { MtpRequestPacket::MtpRequestPacket() - : MtpPacket(512) + : MtpPacket(512), + mParameterCount(0) { } @@ -37,10 +38,21 @@ MtpRequestPacket::~MtpRequestPacket() { #ifdef MTP_DEVICE int MtpRequestPacket::read(int fd) { int ret = ::read(fd, mBuffer, mBufferSize); - if (ret >= 0) + if (ret < 0) { + // file read error + return ret; + } + + // request packet should have 12 byte header followed by 0 to 5 32-bit arguments + if (ret >= MTP_CONTAINER_HEADER_SIZE + && ret <= MTP_CONTAINER_HEADER_SIZE + 5 * sizeof(uint32_t) + && ((ret - MTP_CONTAINER_HEADER_SIZE) & 3) == 0) { mPacketSize = ret; - else - mPacketSize = 0; + mParameterCount = (ret - MTP_CONTAINER_HEADER_SIZE) / sizeof(uint32_t); + } else { + ALOGE("Malformed MTP request packet"); + ret = -1; + } return ret; } #endif diff --git a/media/mtp/MtpRequestPacket.h b/media/mtp/MtpRequestPacket.h index 1201f11..79b798d 100644 --- a/media/mtp/MtpRequestPacket.h +++ b/media/mtp/MtpRequestPacket.h @@ -43,6 +43,10 @@ public: inline MtpOperationCode getOperationCode() const { return getContainerCode(); } inline void setOperationCode(MtpOperationCode code) { return setContainerCode(code); } + inline int getParameterCount() const { return mParameterCount; } + +private: + int mParameterCount; }; }; // namespace android diff --git a/media/mtp/MtpServer.cpp b/media/mtp/MtpServer.cpp index aa43967..931a09d 100644 --- a/media/mtp/MtpServer.cpp +++ b/media/mtp/MtpServer.cpp @@ -495,6 +495,9 @@ MtpResponseCode MtpServer::doOpenSession() { mResponse.setParameter(1, mSessionID); return MTP_RESPONSE_SESSION_ALREADY_OPEN; } + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; + mSessionID = mRequest.getParameter(1); mSessionOpen = true; @@ -529,6 +532,9 @@ MtpResponseCode MtpServer::doGetStorageInfo() { if (!mSessionOpen) return MTP_RESPONSE_SESSION_NOT_OPEN; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; + MtpStorageID id = mRequest.getParameter(1); MtpStorage* storage = getStorage(id); if (!storage) @@ -550,6 +556,8 @@ MtpResponseCode MtpServer::doGetStorageInfo() { MtpResponseCode MtpServer::doGetObjectPropsSupported() { if (!mSessionOpen) return MTP_RESPONSE_SESSION_NOT_OPEN; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectFormat format = mRequest.getParameter(1); MtpObjectPropertyList* properties = mDatabase->getSupportedObjectProperties(format); mData.putAUInt16(properties); @@ -560,6 +568,8 @@ MtpResponseCode MtpServer::doGetObjectPropsSupported() { MtpResponseCode MtpServer::doGetObjectHandles() { if (!mSessionOpen) return MTP_RESPONSE_SESSION_NOT_OPEN; + if (mRequest.getParameterCount() < 3) + return MTP_RESPONSE_INVALID_PARAMETER; MtpStorageID storageID = mRequest.getParameter(1); // 0xFFFFFFFF for all storage MtpObjectFormat format = mRequest.getParameter(2); // 0 for all formats MtpObjectHandle parent = mRequest.getParameter(3); // 0xFFFFFFFF for objects with no parent @@ -577,6 +587,8 @@ MtpResponseCode MtpServer::doGetObjectHandles() { MtpResponseCode MtpServer::doGetNumObjects() { if (!mSessionOpen) return MTP_RESPONSE_SESSION_NOT_OPEN; + if (mRequest.getParameterCount() < 3) + return MTP_RESPONSE_INVALID_PARAMETER; MtpStorageID storageID = mRequest.getParameter(1); // 0xFFFFFFFF for all storage MtpObjectFormat format = mRequest.getParameter(2); // 0 for all formats MtpObjectHandle parent = mRequest.getParameter(3); // 0xFFFFFFFF for objects with no parent @@ -599,6 +611,8 @@ MtpResponseCode MtpServer::doGetObjectReferences() { return MTP_RESPONSE_SESSION_NOT_OPEN; if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); // FIXME - check for invalid object handle @@ -617,9 +631,13 @@ MtpResponseCode MtpServer::doSetObjectReferences() { return MTP_RESPONSE_SESSION_NOT_OPEN; if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpStorageID handle = mRequest.getParameter(1); MtpObjectHandleList* references = mData.getAUInt32(); + if (!references) + return MTP_RESPONSE_INVALID_PARAMETER; MtpResponseCode result = mDatabase->setObjectReferences(handle, references); delete references; return result; @@ -628,6 +646,8 @@ MtpResponseCode MtpServer::doSetObjectReferences() { MtpResponseCode MtpServer::doGetObjectPropValue() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 2) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); MtpObjectProperty property = mRequest.getParameter(2); ALOGV("GetObjectPropValue %d %s\n", handle, @@ -639,6 +659,8 @@ MtpResponseCode MtpServer::doGetObjectPropValue() { MtpResponseCode MtpServer::doSetObjectPropValue() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 2) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); MtpObjectProperty property = mRequest.getParameter(2); ALOGV("SetObjectPropValue %d %s\n", handle, @@ -648,6 +670,8 @@ MtpResponseCode MtpServer::doSetObjectPropValue() { } MtpResponseCode MtpServer::doGetDevicePropValue() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpDeviceProperty property = mRequest.getParameter(1); ALOGV("GetDevicePropValue %s\n", MtpDebug::getDevicePropCodeName(property)); @@ -656,6 +680,8 @@ MtpResponseCode MtpServer::doGetDevicePropValue() { } MtpResponseCode MtpServer::doSetDevicePropValue() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpDeviceProperty property = mRequest.getParameter(1); ALOGV("SetDevicePropValue %s\n", MtpDebug::getDevicePropCodeName(property)); @@ -664,6 +690,8 @@ MtpResponseCode MtpServer::doSetDevicePropValue() { } MtpResponseCode MtpServer::doResetDevicePropValue() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpDeviceProperty property = mRequest.getParameter(1); ALOGV("ResetDevicePropValue %s\n", MtpDebug::getDevicePropCodeName(property)); @@ -674,6 +702,8 @@ MtpResponseCode MtpServer::doResetDevicePropValue() { MtpResponseCode MtpServer::doGetObjectPropList() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 5) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); // use uint32_t so we can support 0xFFFFFFFF @@ -691,6 +721,8 @@ MtpResponseCode MtpServer::doGetObjectPropList() { MtpResponseCode MtpServer::doGetObjectInfo() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); MtpObjectInfo info(handle); MtpResponseCode result = mDatabase->getObjectInfo(handle, info); @@ -732,6 +764,8 @@ MtpResponseCode MtpServer::doGetObjectInfo() { MtpResponseCode MtpServer::doGetObject() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); MtpString pathBuf; int64_t fileLength; @@ -765,6 +799,8 @@ MtpResponseCode MtpServer::doGetObject() { } MtpResponseCode MtpServer::doGetThumb() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); size_t thumbSize; void* thumb = mDatabase->getThumbnail(handle, thumbSize); @@ -783,6 +819,8 @@ MtpResponseCode MtpServer::doGetThumb() { MtpResponseCode MtpServer::doGetPartialObject(MtpOperationCode operation) { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 4) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); uint64_t offset; uint32_t length; @@ -832,6 +870,11 @@ MtpResponseCode MtpServer::doGetPartialObject(MtpOperationCode operation) { MtpResponseCode MtpServer::doSendObjectInfo() { MtpString path; + uint16_t temp16; + uint32_t temp32; + + if (mRequest.getParameterCount() < 2) + return MTP_RESPONSE_INVALID_PARAMETER; MtpStorageID storageID = mRequest.getParameter(1); MtpStorage* storage = getStorage(storageID); MtpObjectHandle parent = mRequest.getParameter(2); @@ -853,25 +896,29 @@ MtpResponseCode MtpServer::doSendObjectInfo() { } // read only the fields we need - mData.getUInt32(); // storage ID - MtpObjectFormat format = mData.getUInt16(); - mData.getUInt16(); // protection status - mSendObjectFileSize = mData.getUInt32(); - mData.getUInt16(); // thumb format - mData.getUInt32(); // thumb compressed size - mData.getUInt32(); // thumb pix width - mData.getUInt32(); // thumb pix height - mData.getUInt32(); // image pix width - mData.getUInt32(); // image pix height - mData.getUInt32(); // image bit depth - mData.getUInt32(); // parent - uint16_t associationType = mData.getUInt16(); - uint32_t associationDesc = mData.getUInt32(); // association desc - mData.getUInt32(); // sequence number + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // storage ID + if (!mData.getUInt16(temp16)) return MTP_RESPONSE_INVALID_PARAMETER; + MtpObjectFormat format = temp16; + if (!mData.getUInt16(temp16)) return MTP_RESPONSE_INVALID_PARAMETER; // protection status + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; + mSendObjectFileSize = temp32; + if (!mData.getUInt16(temp16)) return MTP_RESPONSE_INVALID_PARAMETER; // thumb format + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // thumb compressed size + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // thumb pix width + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // thumb pix height + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // image pix width + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // image pix height + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // image bit depth + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // parent + if (!mData.getUInt16(temp16)) return MTP_RESPONSE_INVALID_PARAMETER; + uint16_t associationType = temp16; + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; + uint32_t associationDesc = temp32; // association desc + if (!mData.getUInt32(temp32)) return MTP_RESPONSE_INVALID_PARAMETER; // sequence number MtpStringBuffer name, created, modified; - mData.getString(name); // file name - mData.getString(created); // date created - mData.getString(modified); // date modified + if (!mData.getString(name)) return MTP_RESPONSE_INVALID_PARAMETER; // file name + if (!mData.getString(created)) return MTP_RESPONSE_INVALID_PARAMETER; // date created + if (!mData.getString(modified)) return MTP_RESPONSE_INVALID_PARAMETER; // date modified // keywords follow ALOGV("name: %s format: %04X\n", (const char *)name, format); @@ -1066,6 +1113,8 @@ static void deletePath(const char* path) { MtpResponseCode MtpServer::doDeleteObject() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 2) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); MtpObjectFormat format = mRequest.getParameter(2); // FIXME - support deleting all objects if handle is 0xFFFFFFFF @@ -1087,6 +1136,8 @@ MtpResponseCode MtpServer::doDeleteObject() { } MtpResponseCode MtpServer::doGetObjectPropDesc() { + if (mRequest.getParameterCount() < 2) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectProperty propCode = mRequest.getParameter(1); MtpObjectFormat format = mRequest.getParameter(2); ALOGV("GetObjectPropDesc %s %s\n", MtpDebug::getObjectPropCodeName(propCode), @@ -1100,6 +1151,8 @@ MtpResponseCode MtpServer::doGetObjectPropDesc() { } MtpResponseCode MtpServer::doGetDevicePropDesc() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpDeviceProperty propCode = mRequest.getParameter(1); ALOGV("GetDevicePropDesc %s\n", MtpDebug::getDevicePropCodeName(propCode)); MtpProperty* property = mDatabase->getDevicePropertyDesc(propCode); @@ -1113,6 +1166,8 @@ MtpResponseCode MtpServer::doGetDevicePropDesc() { MtpResponseCode MtpServer::doSendPartialObject() { if (!hasStorage()) return MTP_RESPONSE_INVALID_OBJECT_HANDLE; + if (mRequest.getParameterCount() < 4) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); uint64_t offset = mRequest.getParameter(2); uint64_t offset2 = mRequest.getParameter(3); @@ -1180,6 +1235,8 @@ MtpResponseCode MtpServer::doSendPartialObject() { } MtpResponseCode MtpServer::doTruncateObject() { + if (mRequest.getParameterCount() < 3) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); ObjectEdit* edit = getEditObject(handle); if (!edit) { @@ -1199,6 +1256,8 @@ MtpResponseCode MtpServer::doTruncateObject() { } MtpResponseCode MtpServer::doBeginEditObject() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); if (getEditObject(handle)) { ALOGE("object already open for edit in doBeginEditObject"); @@ -1223,6 +1282,8 @@ MtpResponseCode MtpServer::doBeginEditObject() { } MtpResponseCode MtpServer::doEndEditObject() { + if (mRequest.getParameterCount() < 1) + return MTP_RESPONSE_INVALID_PARAMETER; MtpObjectHandle handle = mRequest.getParameter(1); ObjectEdit* edit = getEditObject(handle); if (!edit) { diff --git a/media/mtp/MtpStorageInfo.cpp b/media/mtp/MtpStorageInfo.cpp index 2b1a9ae..5d4ebbf 100644 --- a/media/mtp/MtpStorageInfo.cpp +++ b/media/mtp/MtpStorageInfo.cpp @@ -45,21 +45,23 @@ MtpStorageInfo::~MtpStorageInfo() { free(mVolumeIdentifier); } -void MtpStorageInfo::read(MtpDataPacket& packet) { +bool MtpStorageInfo::read(MtpDataPacket& packet) { MtpStringBuffer string; // read the device info - mStorageType = packet.getUInt16(); - mFileSystemType = packet.getUInt16(); - mAccessCapability = packet.getUInt16(); - mMaxCapacity = packet.getUInt64(); - mFreeSpaceBytes = packet.getUInt64(); - mFreeSpaceObjects = packet.getUInt32(); + if (!packet.getUInt16(mStorageType)) return false; + if (!packet.getUInt16(mFileSystemType)) return false; + if (!packet.getUInt16(mAccessCapability)) return false; + if (!packet.getUInt64(mMaxCapacity)) return false; + if (!packet.getUInt64(mFreeSpaceBytes)) return false; + if (!packet.getUInt32(mFreeSpaceObjects)) return false; - packet.getString(string); + if (!packet.getString(string)) return false; mStorageDescription = strdup((const char *)string); - packet.getString(string); + if (!packet.getString(string)) return false; mVolumeIdentifier = strdup((const char *)string); + + return true; } void MtpStorageInfo::print() { diff --git a/media/mtp/MtpStorageInfo.h b/media/mtp/MtpStorageInfo.h index 2cb626e..35a8189 100644 --- a/media/mtp/MtpStorageInfo.h +++ b/media/mtp/MtpStorageInfo.h @@ -39,7 +39,7 @@ public: MtpStorageInfo(MtpStorageID id); virtual ~MtpStorageInfo(); - void read(MtpDataPacket& packet); + bool read(MtpDataPacket& packet); void print(); }; diff --git a/media/mtp/MtpStringBuffer.cpp b/media/mtp/MtpStringBuffer.cpp index f3420a4..df04694 100644 --- a/media/mtp/MtpStringBuffer.cpp +++ b/media/mtp/MtpStringBuffer.cpp @@ -123,11 +123,17 @@ void MtpStringBuffer::set(const uint16_t* src) { mByteCount = dest - mBuffer; } -void MtpStringBuffer::readFromPacket(MtpDataPacket* packet) { - int count = packet->getUInt8(); +bool MtpStringBuffer::readFromPacket(MtpDataPacket* packet) { + uint8_t count; + if (!packet->getUInt8(count)) + return false; + uint8_t* dest = mBuffer; for (int i = 0; i < count; i++) { - uint16_t ch = packet->getUInt16(); + uint16_t ch; + + if (!packet->getUInt16(ch)) + return false; if (ch >= 0x0800) { *dest++ = (uint8_t)(0xE0 | (ch >> 12)); *dest++ = (uint8_t)(0x80 | ((ch >> 6) & 0x3F)); @@ -142,6 +148,7 @@ void MtpStringBuffer::readFromPacket(MtpDataPacket* packet) { *dest++ = 0; mCharCount = count; mByteCount = dest - mBuffer; + return true; } void MtpStringBuffer::writeToPacket(MtpDataPacket* packet) const { diff --git a/media/mtp/MtpStringBuffer.h b/media/mtp/MtpStringBuffer.h index e5150df..85d91e8 100644 --- a/media/mtp/MtpStringBuffer.h +++ b/media/mtp/MtpStringBuffer.h @@ -46,7 +46,7 @@ public: void set(const char* src); void set(const uint16_t* src); - void readFromPacket(MtpDataPacket* packet); + bool readFromPacket(MtpDataPacket* packet); void writeToPacket(MtpDataPacket* packet) const; inline int getCharCount() const { return mCharCount; } -- cgit v1.1