aboutsummaryrefslogtreecommitdiffstats
path: root/python/google/protobuf/internal/decoder.py
diff options
context:
space:
mode:
Diffstat (limited to 'python/google/protobuf/internal/decoder.py')
-rwxr-xr-xpython/google/protobuf/internal/decoder.py228
1 files changed, 209 insertions, 19 deletions
diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py
index 461a30c..a4b9060 100755
--- a/python/google/protobuf/internal/decoder.py
+++ b/python/google/protobuf/internal/decoder.py
@@ -1,6 +1,6 @@
# Protocol Buffers - Google's data interchange format
# Copyright 2008 Google Inc. All rights reserved.
-# http://code.google.com/p/protobuf/
+# https://developers.google.com/protocol-buffers/
#
# Redistribution and use in source and binary forms, with or without
# modification, are permitted provided that the following conditions are
@@ -28,6 +28,10 @@
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+#PY25 compatible for GAE.
+#
+# Copyright 2009 Google Inc. All Rights Reserved.
+
"""Code for decoding protocol buffer primitives.
This code is very similar to encoder.py -- read the docs for that module first.
@@ -81,17 +85,26 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it.
__author__ = 'kenton@google.com (Kenton Varda)'
import struct
+import sys ##PY25
+_PY2 = sys.version_info[0] < 3 ##PY25
from google.protobuf.internal import encoder
from google.protobuf.internal import wire_format
from google.protobuf import message
+# This will overflow and thus become IEEE-754 "infinity". We would use
+# "float('inf')" but it doesn't work on Windows pre-Python-2.6.
+_POS_INF = 1e10000
+_NEG_INF = -_POS_INF
+_NAN = _POS_INF * 0
+
+
# This is not for optimization, but rather to avoid conflicts with local
# variables named "message".
_DecodeError = message.DecodeError
-def _VarintDecoder(mask):
+def _VarintDecoder(mask, result_type):
"""Return an encoder for a basic varint value (does not include tag).
Decoded values will be bitwise-anded with the given mask before being
@@ -102,15 +115,18 @@ def _VarintDecoder(mask):
"""
local_ord = ord
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos])
+ b = local_ord(buffer[pos]) if py2 else buffer[pos]
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
result &= mask
+ result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
@@ -118,15 +134,17 @@ def _VarintDecoder(mask):
return DecodeVarint
-def _SignedVarintDecoder(mask):
+def _SignedVarintDecoder(mask, result_type):
"""Like _VarintDecoder() but decodes signed values."""
local_ord = ord
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
def DecodeVarint(buffer, pos):
result = 0
shift = 0
while 1:
- b = local_ord(buffer[pos])
+ b = local_ord(buffer[pos]) if py2 else buffer[pos]
result |= ((b & 0x7f) << shift)
pos += 1
if not (b & 0x80):
@@ -135,19 +153,23 @@ def _SignedVarintDecoder(mask):
result |= ~mask
else:
result &= mask
+ result = result_type(result)
return (result, pos)
shift += 7
if shift >= 64:
raise _DecodeError('Too many bytes when decoding varint.')
return DecodeVarint
+# We force 32-bit values to int and 64-bit values to long to make
+# alternate implementations where the distinction is more significant
+# (e.g. the C++ implementation) simpler.
-_DecodeVarint = _VarintDecoder((1 << 64) - 1)
-_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1)
+_DecodeVarint = _VarintDecoder((1 << 64) - 1, long)
+_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long)
# Use these versions for values which must be limited to 32 bits.
-_DecodeVarint32 = _VarintDecoder((1 << 32) - 1)
-_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1)
+_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int)
+_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int)
def ReadTag(buffer, pos):
@@ -161,8 +183,10 @@ def ReadTag(buffer, pos):
use that, but not in Python.
"""
+ py2 = _PY2 ##PY25
+##!PY25 py2 = str is bytes
start = pos
- while ord(buffer[pos]) & 0x80:
+ while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80:
pos += 1
pos += 1
return (buffer[start:pos], pos)
@@ -269,10 +293,161 @@ def _StructPackDecoder(wire_type, format):
return _SimpleDecoder(wire_type, InnerDecode)
+def _FloatDecoder():
+ """Returns a decoder for a float field.
+
+ This code works around a bug in struct.unpack for non-finite 32-bit
+ floating-point values.
+ """
+
+ local_unpack = struct.unpack
+ b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
+
+ def InnerDecode(buffer, pos):
+ # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign
+ # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand.
+ new_pos = pos + 4
+ float_bytes = buffer[pos:new_pos]
+
+ # If this value has all its exponent bits set, then it's non-finite.
+ # In Python 2.4, struct.unpack will convert it to a finite 64-bit value.
+ # To avoid that, we parse it specially.
+ if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25
+##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF')
+ and (float_bytes[2:3] >= b('\x80'))): ##PY25
+##!PY25 and (float_bytes[2:3] >= b'\x80')):
+ # If at least one significand bit is set...
+ if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25
+##!PY25 if float_bytes[0:3] != b'\x00\x00\x80':
+ return (_NAN, new_pos)
+ # If sign bit is set...
+ if float_bytes[3:4] == b('\xFF'): ##PY25
+##!PY25 if float_bytes[3:4] == b'\xFF':
+ return (_NEG_INF, new_pos)
+ return (_POS_INF, new_pos)
+
+ # Note that we expect someone up-stack to catch struct.error and convert
+ # it to _DecodeError -- this way we don't have to set up exception-
+ # handling blocks every time we parse one value.
+ result = local_unpack('<f', float_bytes)[0]
+ return (result, new_pos)
+ return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode)
+
+
+def _DoubleDecoder():
+ """Returns a decoder for a double field.
+
+ This code works around a bug in struct.unpack for not-a-number.
+ """
+
+ local_unpack = struct.unpack
+ b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25
+
+ def InnerDecode(buffer, pos):
+ # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign
+ # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand.
+ new_pos = pos + 8
+ double_bytes = buffer[pos:new_pos]
+
+ # If this value has all its exponent bits set and at least one significand
+ # bit set, it's not a number. In Python 2.4, struct.unpack will treat it
+ # as inf or -inf. To avoid that, we treat it specially.
+##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF')
+##!PY25 and (double_bytes[6:7] >= b'\xF0')
+##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')):
+ if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25
+ and (double_bytes[6:7] >= b('\xF0')) ##PY25
+ and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25
+ return (_NAN, new_pos)
+
+ # Note that we expect someone up-stack to catch struct.error and convert
+ # it to _DecodeError -- this way we don't have to set up exception-
+ # handling blocks every time we parse one value.
+ result = local_unpack('<d', double_bytes)[0]
+ return (result, new_pos)
+ return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode)
+
+
+def EnumDecoder(field_number, is_repeated, is_packed, key, new_default):
+ enum_type = key.enum_type
+ if is_packed:
+ local_DecodeVarint = _DecodeVarint
+ def DecodePackedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ (endpoint, pos) = local_DecodeVarint(buffer, pos)
+ endpoint += pos
+ if endpoint > end:
+ raise _DecodeError('Truncated message.')
+ while pos < endpoint:
+ value_start_pos = pos
+ (element, pos) = _DecodeSignedVarint32(buffer, pos)
+ if element in enum_type.values_by_number:
+ value.append(element)
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_VARINT)
+ message._unknown_fields.append(
+ (tag_bytes, buffer[value_start_pos:pos]))
+ if pos > endpoint:
+ if element in enum_type.values_by_number:
+ del value[-1] # Discard corrupt value.
+ else:
+ del message._unknown_fields[-1]
+ raise _DecodeError('Packed element was truncated.')
+ return pos
+ return DecodePackedField
+ elif is_repeated:
+ tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT)
+ tag_len = len(tag_bytes)
+ def DecodeRepeatedField(buffer, pos, end, message, field_dict):
+ value = field_dict.get(key)
+ if value is None:
+ value = field_dict.setdefault(key, new_default(message))
+ while 1:
+ (element, new_pos) = _DecodeSignedVarint32(buffer, pos)
+ if element in enum_type.values_by_number:
+ value.append(element)
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ message._unknown_fields.append(
+ (tag_bytes, buffer[pos:new_pos]))
+ # Predict that the next tag is another copy of the same repeated
+ # field.
+ pos = new_pos + tag_len
+ if buffer[new_pos:pos] != tag_bytes or new_pos >= end:
+ # Prediction failed. Return.
+ if new_pos > end:
+ raise _DecodeError('Truncated message.')
+ return new_pos
+ return DecodeRepeatedField
+ else:
+ def DecodeField(buffer, pos, end, message, field_dict):
+ value_start_pos = pos
+ (enum_value, pos) = _DecodeSignedVarint32(buffer, pos)
+ if pos > end:
+ raise _DecodeError('Truncated message.')
+ if enum_value in enum_type.values_by_number:
+ field_dict[key] = enum_value
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ tag_bytes = encoder.TagBytes(field_number,
+ wire_format.WIRETYPE_VARINT)
+ message._unknown_fields.append(
+ (tag_bytes, buffer[value_start_pos:pos]))
+ return pos
+ return DecodeField
+
+
# --------------------------------------------------------------------
-Int32Decoder = EnumDecoder = _SimpleDecoder(
+Int32Decoder = _SimpleDecoder(
wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32)
Int64Decoder = _SimpleDecoder(
@@ -294,8 +469,8 @@ Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I')
Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q')
SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i')
SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q')
-FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f')
-DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d')
+FloatDecoder = _FloatDecoder()
+DoubleDecoder = _DoubleDecoder()
BoolDecoder = _ModifiedDecoder(
wire_format.WIRETYPE_VARINT, _DecodeVarint, bool)
@@ -307,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
local_DecodeVarint = _DecodeVarint
local_unicode = unicode
+ def _ConvertToUnicode(byte_str):
+ try:
+ return local_unicode(byte_str, 'utf-8')
+ except UnicodeDecodeError, e:
+ # add more information to the error message and re-raise it.
+ e.reason = '%s in field: %s' % (e, key.full_name)
+ raise
+
assert not is_packed
if is_repeated:
tag_bytes = encoder.TagBytes(field_number,
@@ -321,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- value.append(local_unicode(buffer[pos:new_pos], 'utf-8'))
+ value.append(_ConvertToUnicode(buffer[pos:new_pos]))
# Predict that the next tag is another copy of the same repeated field.
pos = new_pos + tag_len
if buffer[new_pos:pos] != tag_bytes or new_pos == end:
@@ -334,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default):
new_pos = pos + size
if new_pos > end:
raise _DecodeError('Truncated string.')
- field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8')
+ field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos])
return new_pos
return DecodeField
@@ -503,6 +686,7 @@ def MessageSetItemDecoder(extensions_by_number):
local_SkipField = SkipField
def DecodeItem(buffer, pos, end, message, field_dict):
+ message_set_item_start = pos
type_id = -1
message_start = -1
message_end = -1
@@ -541,6 +725,11 @@ def MessageSetItemDecoder(extensions_by_number):
# The only reason _InternalParse would return early is if it encountered
# an end-group tag.
raise _DecodeError('Unexpected end-group tag.')
+ else:
+ if not message._unknown_fields:
+ message._unknown_fields = []
+ message._unknown_fields.append((MESSAGE_SET_ITEM_TAG,
+ buffer[message_set_item_start:pos]))
return pos
@@ -552,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number):
def _SkipVarint(buffer, pos, end):
"""Skip a varint value. Returns the new position."""
-
- while ord(buffer[pos]) & 0x80:
+ # Previously ord(buffer[pos]) raised IndexError when pos is out of range.
+ # With this code, ord(b'') raises TypeError. Both are handled in
+ # python_message.py to generate a 'Truncated message' error.
+ while ord(buffer[pos:pos+1]) & 0x80:
pos += 1
pos += 1
if pos > end:
@@ -620,7 +811,6 @@ def _FieldSkipper():
]
wiretype_mask = wire_format.TAG_TYPE_MASK
- local_ord = ord
def SkipField(buffer, pos, end, tag_bytes):
"""Skips a field with the specified tag.
@@ -633,7 +823,7 @@ def _FieldSkipper():
"""
# The wire type is always in the first byte since varints are little-endian.
- wire_type = local_ord(tag_bytes[0]) & wiretype_mask
+ wire_type = ord(tag_bytes[0:1]) & wiretype_mask
return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end)
return SkipField