diff options
Diffstat (limited to 'python/google/protobuf/reflection.py')
-rwxr-xr-x | python/google/protobuf/reflection.py | 1487 |
1 files changed, 487 insertions, 1000 deletions
diff --git a/python/google/protobuf/reflection.py b/python/google/protobuf/reflection.py index d65d8b6..5b23803 100755 --- a/python/google/protobuf/reflection.py +++ b/python/google/protobuf/reflection.py @@ -50,9 +50,13 @@ this file*. __author__ = 'robinson@google.com (Will Robinson)' -import heapq -import threading +try: + from cStringIO import StringIO +except ImportError: + from StringIO import StringIO +import struct import weakref + # We use "as" to avoid name collisions with variables. from google.protobuf.internal import containers from google.protobuf.internal import decoder @@ -139,14 +143,26 @@ class GeneratedProtocolMessageType(type): type. """ descriptor = dictionary[GeneratedProtocolMessageType._DESCRIPTOR_KEY] + + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number)) + # We act as a "friend" class of the descriptor, setting # its _concrete_class attribute the first time we use a # given descriptor to initialize a concrete protocol message - # class. + # class. We also attach stuff to each FieldDescriptor for quick + # lookup later on. concrete_class_attr_name = '_concrete_class' if not hasattr(descriptor, concrete_class_attr_name): setattr(descriptor, concrete_class_attr_name, cls) - cls._known_extensions = [] + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + _AddEnumValues(descriptor, cls) _AddInitMethod(descriptor, cls) _AddPropertiesForFields(descriptor, cls) @@ -184,30 +200,33 @@ def _PropertyName(proto_field_name): # return proto_field_name + "_" # return proto_field_name # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. return proto_field_name -def _ValueFieldName(proto_field_name): - """Returns the name of the (internal) instance attribute which objects - should use to store the current value for a given protocol message field. - - Args: - proto_field_name: The protocol message field name, exactly - as it appears (or would appear) in a .proto file. - """ - return '_value_' + proto_field_name +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + if not isinstance(extension_handle, _FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) -def _HasFieldName(proto_field_name): - """Returns the name of the (internal) instance attribute which - objects should use to store a boolean telling whether this field - is explicitly set or not. + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) - Args: - proto_field_name: The protocol message field name, exactly - as it appears (or would appear) in a .proto file. - """ - return '_has_' + proto_field_name + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) def _AddSlots(message_descriptor, dictionary): @@ -218,16 +237,57 @@ def _AddSlots(message_descriptor, dictionary): message_descriptor: A Descriptor instance describing this message type. dictionary: Class dictionary to which we'll add a '__slots__' entry. """ - field_names = [_ValueFieldName(f.name) for f in message_descriptor.fields] - field_names.extend(_HasFieldName(f.name) for f in message_descriptor.fields - if f.label != _FieldDescriptor.LABEL_REPEATED) - field_names.extend(('Extensions', - '_cached_byte_size', - '_cached_byte_size_dirty', - '_called_transition_to_nonempty', - '_listener', - '_lock', '__weakref__')) - dictionary['__slots__'] = field_names + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor)) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) def _AddClassAttributesForNestedExtensions(descriptor, dictionary): @@ -249,44 +309,51 @@ def _AddEnumValues(descriptor, cls): setattr(cls, enum_value.name, enum_value.number) -def _DefaultValueForField(message, field): - """Returns a default value for a field. +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: message: Message instance containing this field, or a weakref proxy of same. - field: FieldDescriptor object for this field. - Returns: A default value for this field. May refer back to |message| - via a weak reference. + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. """ - # TODO(robinson): Only the repeated fields need a reference to 'message' (so - # that they can set the 'has' bit on the containing Message when someone - # append()s a value). We could special-case this, and avoid an extra - # function call on __init__() and Clear() for non-repeated fields. - - # TODO(robinson): Find a better place for the default value assertion in this - # function. No need to repeat them every time the client calls Clear('foo'). - # (We should probably just assert these things once and as early as possible, - # by tightening checking in the descriptor classes.) + if field.label == _FieldDescriptor.LABEL_REPEATED: if field.default_value != []: raise ValueError('Repeated field default value not empty list: %s' % ( field.default_value)) - listener = _Listener(message, None) if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # We can't look at _concrete_class yet since it might not have # been set. (Depends on order in which we initialize the classes). - return containers.RepeatedCompositeFieldContainer( - listener, field.message_type) + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault else: - return containers.RepeatedScalarFieldContainer( - listener, type_checkers.GetTypeChecker(field.cpp_type, field.type)) + type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - assert field.default_value is None + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + result = message_type._concrete_class() + result._SetListener(message._listener_for_children) + return result + return MakeSubMessageDefault - return field.default_value + def MakeScalarDefault(message): + return field.default_value + return MakeScalarDefault def _AddInitMethod(message_descriptor, cls): @@ -295,21 +362,29 @@ def _AddInitMethod(message_descriptor, cls): def init(self, **kwargs): self._cached_byte_size = 0 self._cached_byte_size_dirty = False + self._fields = {} + self._is_present_in_parent = False self._listener = message_listener_mod.NullMessageListener() - self._called_transition_to_nonempty = False - # TODO(robinson): We should only create a lock if we really need one - # in this class. - self._lock = threading.Lock() - for field in fields: - default_value = _DefaultValueForField(self, field) - python_field_name = _ValueFieldName(field.name) - setattr(self, python_field_name, default_value) - if field.label != _FieldDescriptor.LABEL_REPEATED: - setattr(self, _HasFieldName(field.name), False) - self.Extensions = _ExtensionDict(self, cls._known_extensions) + self._listener_for_children = _Listener(self) for field_name, field_value in kwargs.iteritems(): field = _GetFieldByName(message_descriptor, field_name) - _MergeFieldOrExtension(self, field, field_value) + if field is None: + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (message_descriptor.name, field_name)) + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + for val in field_value: + copy.add().MergeFrom(val) + else: # Scalar + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + copy.MergeFrom(field_value) + self._fields[field] = copy + else: + self._fields[field] = field_value init.__module__ = None init.__doc__ = None @@ -336,6 +411,11 @@ def _AddPropertiesForFields(descriptor, cls): for field in descriptor.fields: _AddPropertiesForField(field, cls) + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + def _AddPropertiesForField(field, cls): """Adds a public property for a protocol message field. @@ -377,11 +457,22 @@ def _AddPropertiesForRepeatedField(field, cls): cls: The class we're constructing. """ proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) def getter(self): - return getattr(self, python_field_name) + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -407,21 +498,21 @@ def _AddPropertiesForNonRepeatedScalarField(field, cls): cls: The class we're constructing. """ proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) + default_value = field.default_value def getter(self): - return getattr(self, python_field_name) + return self._fields.get(field, default_value) getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name def setter(self, new_value): type_checker.CheckValue(new_value) - setattr(self, has_field_name, True) - self._MarkByteSizeDirty() - self._MaybeCallTransitionToNonemptyCallback() - setattr(self, python_field_name, new_value) + self._fields[field] = new_value + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() setter.__module__ = None setter.__doc__ = 'Setter for %s.' % proto_field_name @@ -444,25 +535,23 @@ def _AddPropertiesForNonRepeatedCompositeField(field, cls): # TODO(robinson): Remove duplication with similar method # for non-repeated scalars. proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) property_name = _PropertyName(proto_field_name) message_type = field.message_type def getter(self): - # TODO(robinson): Appropriately scary note about double-checked locking. - field_value = getattr(self, python_field_name) + field_value = self._fields.get(field) if field_value is None: - self._lock.acquire() - try: - field_value = getattr(self, python_field_name) - if field_value is None: - field_class = message_type._concrete_class - field_value = field_class() - field_value._SetListener(_Listener(self, has_field_name)) - setattr(self, python_field_name, field_value) - finally: - self._lock.release() + # Construct a new object to represent this field. + field_value = message_type._concrete_class() + field_value._SetListener(self._listener_for_children) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) return field_value getter.__module__ = None getter.__doc__ = 'Getter for %s.' % proto_field_name @@ -490,7 +579,27 @@ def _AddStaticMethods(cls): # TODO(robinson): This probably needs to be thread-safe(?) def RegisterExtension(extension_handle): extension_handle.containing_type = cls.DESCRIPTOR - cls._known_extensions.append(extension_handle) + _AttachFieldHelpers(cls, extension_handle) + + # Try to insert our extension, failing if an extension with the same number + # already exists. + actual_handle = cls._extensions_by_number.setdefault( + extension_handle.number, extension_handle) + if actual_handle is not extension_handle: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" with ' + 'field number %d.' % + (extension_handle.full_name, actual_handle.full_name, + cls.DESCRIPTOR.full_name, extension_handle.number)) + + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + handle = extension_handle # avoid line wrapping + if _IsMessageSetExtension(handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + cls.RegisterExtension = staticmethod(RegisterExtension) def FromString(s): @@ -500,115 +609,107 @@ def _AddStaticMethods(cls): cls.FromString = staticmethod(FromString) +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + def _AddListFieldsMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - # Ensure that we always list in ascending field-number order. - # For non-extension fields, we can do the sort once, here, at import-time. - # For extensions, we sort on each ListFields() call, though - # we could do better if we have to. - fields = sorted(message_descriptor.fields, key=lambda f: f.number) - has_field_names = (_HasFieldName(f.name) for f in fields) - value_field_names = (_ValueFieldName(f.name) for f in fields) - triplets = zip(has_field_names, value_field_names, fields) - def ListFields(self): - # We need to list all extension and non-extension fields - # together, in sorted order by field number. - - # Step 0: Get an iterator over all "set" non-extension fields, - # sorted by field number. - # This iterator yields (field_number, field_descriptor, value) tuples. - def SortedSetFieldsIter(): - # Note that triplets is already sorted by field number. - for has_field_name, value_field_name, field_descriptor in triplets: - if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: - value = getattr(self, _ValueFieldName(field_descriptor.name)) - if len(value) > 0: - yield (field_descriptor.number, field_descriptor, value) - elif getattr(self, _HasFieldName(field_descriptor.name)): - value = getattr(self, _ValueFieldName(field_descriptor.name)) - yield (field_descriptor.number, field_descriptor, value) - sorted_fields = SortedSetFieldsIter() - - # Step 1: Get an iterator over all "set" extension fields, - # sorted by field number. - # This iterator ALSO yields (field_number, field_descriptor, value) tuples. - # TODO(robinson): It's not necessary to repeat this with each - # serialization call. We can do better. - sorted_extension_fields = sorted( - [(f.number, f, v) for f, v in self.Extensions._ListSetExtensions()]) - - # Step 2: Create a composite iterator that merges the extension- - # and non-extension fields, and that still yields fields in - # sorted order. - all_set_fields = _ImergeSorted(sorted_fields, sorted_extension_fields) - - # Step 3: Strip off the field numbers and return. - return [field[1:] for field in all_set_fields] + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields cls.ListFields = ListFields -def _AddHasFieldMethod(cls): + +def _AddHasFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" + + singular_fields = {} + for field in message_descriptor.fields: + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field + def HasField(self, field_name): try: - return getattr(self, _HasFieldName(field_name)) - except AttributeError: - raise ValueError('Protocol message has no "%s" field.' % field_name) + field = singular_fields[field_name] + except KeyError: + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields cls.HasField = HasField -def _AddClearFieldMethod(cls): +def _AddClearFieldMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def ClearField(self, field_name): - field = _GetFieldByName(self.DESCRIPTOR, field_name) - proto_field_name = field.name - python_field_name = _ValueFieldName(proto_field_name) - has_field_name = _HasFieldName(proto_field_name) - default_value = _DefaultValueForField(self, field) - if field.label == _FieldDescriptor.LABEL_REPEATED: - self._MarkByteSizeDirty() - else: - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - old_field_value = getattr(self, python_field_name) - if old_field_value is not None: - # Snip the old object out of the object tree. - old_field_value._SetListener(None) - if getattr(self, has_field_name): - setattr(self, has_field_name, False) - # Set dirty bit on ourself and parents only if - # we're actually changing state. - self._MarkByteSizeDirty() - setattr(self, python_field_name, default_value) + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + if field in self._fields: + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + cls.ClearField = ClearField def _AddClearExtensionMethod(cls): """Helper for _AddMessageMethods().""" def ClearExtension(self, extension_handle): - self.Extensions._ClearExtension(extension_handle) + _VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() cls.ClearExtension = ClearExtension -def _AddClearMethod(cls): +def _AddClearMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def Clear(self): # Clear fields. - fields = self.DESCRIPTOR.fields - for field in fields: - self.ClearField(field.name) - # Clear extensions. - extensions = self.Extensions._ListSetExtensions() - for extension in extensions: - self.ClearExtension(extension[0]) + self._fields = {} + self._Modified() cls.Clear = Clear def _AddHasExtensionMethod(cls): """Helper for _AddMessageMethods().""" def HasExtension(self, extension_handle): - return self.Extensions._HasExtension(extension_handle) + _VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields cls.HasExtension = HasExtension @@ -622,26 +723,8 @@ def _AddEqualsMethod(message_descriptor, cls): if self is other: return True - # Compare all fields contained directly in this message. - for field_descriptor in message_descriptor.fields: - label = field_descriptor.label - property_name = _PropertyName(field_descriptor.name) - # Non-repeated field equality requires matching "has" bits as well - # as having an equal value. - if label != _FieldDescriptor.LABEL_REPEATED: - self_has = self.HasField(property_name) - other_has = other.HasField(property_name) - if self_has != other_has: - return False - if not self_has: - # If the "has" bit for this field is False, we must stop here. - # Otherwise we will recurse forever on recursively-defined protos. - continue - if getattr(self, property_name) != getattr(other, property_name): - return False + return self.ListFields() == other.ListFields() - # Compare the extensions present in both messages. - return self.Extensions == other.Extensions cls.__eq__ = __eq__ @@ -685,618 +768,202 @@ def _BytesForNonRepeatedElement(value, field_number, field_type): def _AddByteSizeMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - def BytesForField(message, field, value): - """Returns the number of bytes required to serialize a single field - in message. The field may be repeated or not, composite or not. - - Args: - message: The Message instance containing a field of the given type. - field: A FieldDescriptor describing the field of interest. - value: The value whose byte size we're interested in. - - Returns: The number of bytes required to serialize the current value - of "field" in "message", including space for tags and any other - necessary information. - """ - - if _MessageSetField(field): - return wire_format.MessageSetItemByteSize(field.number, value) - - field_number, field_type = field.number, field.type - - # Repeated fields. - if field.label == _FieldDescriptor.LABEL_REPEATED: - elements = value - else: - elements = [value] - - if field.GetOptions().packed: - content_size = _ContentBytesForPackedField(message, field, elements) - if content_size: - tag_size = wire_format.TagByteSize(field_number) - length_size = wire_format.Int32ByteSizeNoTag(content_size) - return tag_size + length_size + content_size - else: - return 0 - else: - return sum(_BytesForNonRepeatedElement(element, field_number, field_type) - for element in elements) - - def _ContentBytesForPackedField(self, field, value): - """Returns the number of bytes required to serialize the actual - content of a packed field (not including the tag or the encoding - of the length. - - Args: - self: The Message instance containing a field of the given type. - field: A FieldDescriptor describing the field of interest. - value: The value whose byte size we're interested in. - - Returns: The number of bytes required to serialize the current value - of the packed "field" in "message", excluding space for tags and the - length encoding. - """ - size = sum(_BytesForNonRepeatedElement(element, field.number, field.type) - for element in value) - # In the packed case, there are no per element tags. - return size - wire_format.TagByteSize(field.number) * len(value) - - fields = message_descriptor.fields - has_field_names = (_HasFieldName(f.name) for f in fields) - zipped = zip(has_field_names, fields) - def ByteSize(self): if not self._cached_byte_size_dirty: return self._cached_byte_size size = 0 - # Hardcoded fields first. - for has_field_name, field in zipped: - if (field.label == _FieldDescriptor.LABEL_REPEATED - or getattr(self, has_field_name)): - value = getattr(self, _ValueFieldName(field.name)) - size += BytesForField(self, field, value) - # Extensions next. - for field, value in self.Extensions._ListSetExtensions(): - size += BytesForField(self, field, value) + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) self._cached_byte_size = size self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False return size - cls._ContentBytesForPackedField = _ContentBytesForPackedField cls.ByteSize = ByteSize -def _MessageSetField(field_descriptor): - """Checks if a field should be serialized using the message set wire format. - - Args: - field_descriptor: Descriptor of the field. - - Returns: - True if the field should be serialized using the message set wire format, - false otherwise. - """ - return (field_descriptor.is_extension and - field_descriptor.label != _FieldDescriptor.LABEL_REPEATED and - field_descriptor.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and - field_descriptor.containing_type.GetOptions().message_set_wire_format) - - -def _SerializeValueToEncoder(value, field_number, field_descriptor, encoder): - """Appends the serialization of a single value to encoder. - - Args: - value: Value to serialize. - field_number: Field number of this value. - field_descriptor: Descriptor of the field to serialize. - encoder: encoder.Encoder object to which we should serialize this value. - """ - if _MessageSetField(field_descriptor): - encoder.AppendMessageSetItem(field_number, value) - return - - try: - method = type_checkers.TYPE_TO_SERIALIZE_METHOD[field_descriptor.type] - method(encoder, field_number, value) - except KeyError: - raise message_mod.EncodeError('Unrecognized field type: %d' % - field_descriptor.type) - - -def _ImergeSorted(*streams): - """Merges N sorted iterators into a single sorted iterator. - Each element in streams must be an iterable that yields - its elements in sorted order, and the elements contained - in each stream must all be comparable. - - There may be repeated elements in the component streams or - across the streams; the repeated elements will all be repeated - in the merged iterator as well. - - I believe that the heapq module at HEAD in the Python - sources has a method like this, but for now we roll our own. - """ - iters = [iter(stream) for stream in streams] - heap = [] - for index, it in enumerate(iters): - try: - heap.append((it.next(), index)) - except StopIteration: - pass - heapq.heapify(heap) - - while heap: - smallest_value, idx = heap[0] - yield smallest_value - try: - next_element = iters[idx].next() - heapq.heapreplace(heap, (next_element, idx)) - except StopIteration: - heapq.heappop(heap) - - def _AddSerializeToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" def SerializeToString(self): # Check if the message has all of its required fields set. errors = [] - if not _InternalIsInitialized(self, errors): - raise message_mod.EncodeError('\n'.join(errors)) + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message is missing required fields: ' + + ','.join(self.FindInitializationErrors())) return self.SerializePartialToString() cls.SerializeToString = SerializeToString def _AddSerializePartialToStringMethod(message_descriptor, cls): """Helper for _AddMessageMethods().""" - Encoder = encoder.Encoder def SerializePartialToString(self): - encoder = Encoder() - # We need to serialize all extension and non-extension fields - # together, in sorted order by field number. - for field_descriptor, field_value in self.ListFields(): - if field_descriptor.label == _FieldDescriptor.LABEL_REPEATED: - repeated_value = field_value - else: - repeated_value = [field_value] - if field_descriptor.GetOptions().packed: - # First, write the field number and WIRETYPE_LENGTH_DELIMITED. - field_number = field_descriptor.number - encoder.AppendTag(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) - # Next, write the number of bytes. - content_bytes = self._ContentBytesForPackedField( - field_descriptor, field_value) - encoder.AppendInt32NoTag(content_bytes) - # Finally, write the actual values. - try: - method = type_checkers.TYPE_TO_NOTAG_SERIALIZE_METHOD[ - field_descriptor.type] - for value in repeated_value: - method(encoder, value) - except KeyError: - raise message_mod.EncodeError('Unrecognized field type: %d' % - field_descriptor.type) - else: - for element in repeated_value: - _SerializeValueToEncoder(element, field_descriptor.number, - field_descriptor, encoder) - return encoder.ToString() - + out = StringIO() + self._InternalSerialize(out.write) + return out.getvalue() cls.SerializePartialToString = SerializePartialToString + def InternalSerialize(self, write_bytes): + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value) + cls._InternalSerialize = InternalSerialize -def _WireTypeForFieldType(field_type): - """Given a field type, returns the expected wire type.""" - try: - return type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_type] - except KeyError: - raise message_mod.DecodeError('Unknown field type: %d' % field_type) - - -def _WireTypeForField(field_descriptor): - """Given a field descriptor, returns the expected wire type.""" - if field_descriptor.GetOptions().packed: - return wire_format.WIRETYPE_LENGTH_DELIMITED - else: - return _WireTypeForFieldType(field_descriptor.type) - - -def _RecursivelyMerge(field_number, field_type, decoder, message): - """Decodes a message from decoder into message. - message is either a group or a nested message within some containing - protocol message. If it's a group, we use the group protocol to - deserialize, and if it's a nested message, we use the nested-message - protocol. - - Args: - field_number: The field number of message in its enclosing protocol buffer. - field_type: The field type of message. Must be either TYPE_MESSAGE - or TYPE_GROUP. - decoder: Decoder to read from. - message: Message to deserialize into. - """ - if field_type == _FieldDescriptor.TYPE_MESSAGE: - decoder.ReadMessageInto(message) - elif field_type == _FieldDescriptor.TYPE_GROUP: - decoder.ReadGroupInto(field_number, message) - else: - raise message_mod.DecodeError('Unexpected field type: %d' % field_type) - - -def _DeserializeScalarFromDecoder(field_type, decoder): - """Deserializes a scalar of the requested type from decoder. field_type must - be a scalar (non-group, non-message) FieldDescriptor.FIELD_* constant. - """ - try: - method = type_checkers.TYPE_TO_DESERIALIZE_METHOD[field_type] - return method(decoder) - except KeyError: - raise message_mod.DecodeError('Unrecognized field type: %d' % field_type) - - -def _SkipField(field_number, wire_type, decoder): - """Skips a field with the specified wire type. - - Args: - field_number: Tag number of the field to skip. - wire_type: Wire type of the field to skip. - decoder: Decoder used to deserialize the messsage. It must be positioned - just after reading the the tag and wire type of the field. - """ - if wire_type == wire_format.WIRETYPE_VARINT: - decoder.ReadUInt64() - elif wire_type == wire_format.WIRETYPE_FIXED64: - decoder.ReadFixed64() - elif wire_type == wire_format.WIRETYPE_LENGTH_DELIMITED: - decoder.SkipBytes(decoder.ReadInt32()) - elif wire_type == wire_format.WIRETYPE_START_GROUP: - _SkipGroup(field_number, decoder) - elif wire_type == wire_format.WIRETYPE_END_GROUP: - pass - elif wire_type == wire_format.WIRETYPE_FIXED32: - decoder.ReadFixed32() - else: - raise message_mod.DecodeError('Unexpected wire type: %d' % wire_type) - - -def _SkipGroup(group_number, decoder): - """Skips a nested group from the decoder. - - Args: - group_number: Tag number of the group to skip. - decoder: Decoder used to deserialize the message. It must be positioned - exactly at the beginning of the message that should be skipped. - """ - while True: - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if (wire_type == wire_format.WIRETYPE_END_GROUP and - field_number == group_number): - return - _SkipField(field_number, wire_type, decoder) - - -def _DeserializeMessageSetItem(message, decoder): - """Deserializes a message using the message set wire format. - - Args: - message: Message to be parsed to. - decoder: The decoder to be used to deserialize encoded data. Note that the - decoder should be positioned just after reading the START_GROUP tag that - began the messageset item. - """ - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_VARINT or field_number != 2: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - type_id = decoder.ReadInt32() - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_LENGTH_DELIMITED or field_number != 3: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - extension_dict = message.Extensions - extensions_by_number = extension_dict._AllExtensionsByNumber() - if type_id not in extensions_by_number: - _SkipField(field_number, wire_type, decoder) - return - - field_descriptor = extensions_by_number[type_id] - value = extension_dict[field_descriptor] - decoder.ReadMessageInto(value) - # Read the END_GROUP tag. - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - if wire_type != wire_format.WIRETYPE_END_GROUP or field_number != 1: - raise message_mod.DecodeError( - 'Incorrect message set wire format. ' - 'wire_type: %d, field_number: %d' % (wire_type, field_number)) - - -def _DeserializeOneEntity(message_descriptor, message, decoder): - """Deserializes the next wire entity from decoder into message. - - The next wire entity is either a scalar or a nested message, an - element in a repeated field (the wire encoding in this case is the - same), or a packed repeated field (in this case, the entire repeated - field is read by a single call to _DeserializeOneEntity). - - Args: - message_descriptor: A Descriptor instance describing all fields - in message. - message: The Message instance into which we're decoding our fields. - decoder: The Decoder we're using to deserialize encoded data. - - Returns: The number of bytes read from decoder during this method. - """ - initial_position = decoder.Position() - field_number, wire_type = decoder.ReadFieldNumberAndWireType() - extension_dict = message.Extensions - extensions_by_number = extension_dict._AllExtensionsByNumber() - if field_number in message_descriptor.fields_by_number: - # Non-extension field. - field_descriptor = message_descriptor.fields_by_number[field_number] - value = getattr(message, _PropertyName(field_descriptor.name)) - def nonextension_setter_fn(scalar): - setattr(message, _PropertyName(field_descriptor.name), scalar) - scalar_setter_fn = nonextension_setter_fn - elif field_number in extensions_by_number: - # Extension field. - field_descriptor = extensions_by_number[field_number] - value = extension_dict[field_descriptor] - def extension_setter_fn(scalar): - extension_dict[field_descriptor] = scalar - scalar_setter_fn = extension_setter_fn - elif wire_type == wire_format.WIRETYPE_END_GROUP: - # We assume we're being parsed as the group that's ended. - return 0 - elif (wire_type == wire_format.WIRETYPE_START_GROUP and - field_number == 1 and - message_descriptor.GetOptions().message_set_wire_format): - # A Message Set item. - _DeserializeMessageSetItem(message, decoder) - return decoder.Position() - initial_position - else: - _SkipField(field_number, wire_type, decoder) - return decoder.Position() - initial_position - - # If we reach this point, we've identified the field as either - # hardcoded or extension, and set |field_descriptor|, |scalar_setter_fn|, - # and |value| appropriately. Now actually deserialize the thing. - # - # field_descriptor: Describes the field we're deserializing. - # value: The value currently stored in the field to deserialize. - # Used only if the field is composite and/or repeated. - # scalar_setter_fn: A function F such that F(scalar) will - # set a nonrepeated scalar value for this field. Used only - # if this field is a nonrepeated scalar. - - field_number = field_descriptor.number - expected_wire_type = _WireTypeForField(field_descriptor) - if wire_type != expected_wire_type: - # Need to fill in uninterpreted_bytes. Work for the next CL. - raise RuntimeError('TODO(robinson): Wiretype mismatches not handled.') - - property_name = _PropertyName(field_descriptor.name) - label = field_descriptor.label - field_type = field_descriptor.type - cpp_type = field_descriptor.cpp_type - - # Nonrepeated scalar. Just set the field directly. - if (label != _FieldDescriptor.LABEL_REPEATED - and cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): - scalar_setter_fn(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - - # Nonrepeated composite. Recursively deserialize. - if label != _FieldDescriptor.LABEL_REPEATED: - composite = value - _RecursivelyMerge(field_number, field_type, decoder, composite) - return decoder.Position() - initial_position - - # Now we know we're dealing with a repeated field of some kind. - element_list = value - - if cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: - # Repeated scalar. - if not field_descriptor.GetOptions().packed: - element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - else: - # Packed repeated field. - length = _DeserializeScalarFromDecoder( - _FieldDescriptor.TYPE_INT32, decoder) - content_start = decoder.Position() - while decoder.Position() - content_start < length: - element_list.append(_DeserializeScalarFromDecoder(field_type, decoder)) - return decoder.Position() - initial_position - else: - # Repeated composite. - composite = element_list.add() - _RecursivelyMerge(field_number, field_type, decoder, composite) - return decoder.Position() - initial_position - - -def _FieldOrExtensionValues(message, field_or_extension): - """Retrieves the list of values for the specified field or extension. - - The target field or extension can be optional, required or repeated, but it - must have value(s) set. The assumption is that the target field or extension - is set (e.g. _HasFieldOrExtension holds true). - Args: - message: Message which contains the target field or extension. - field_or_extension: Field or extension for which the list of values is - required. Must be an instance of FieldDescriptor. +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except IndexError: + raise message_mod.DecodeError('Truncated message.') + except struct.error, e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString - Returns: - A list of values for the specified field or extension. This list will only - contain a single element if the field is non-repeated. - """ - if field_or_extension.is_extension: - value = message.Extensions[field_or_extension] - else: - value = getattr(message, _ValueFieldName(field_or_extension.name)) - if field_or_extension.label != _FieldDescriptor.LABEL_REPEATED: - return [value] - else: - # In this case value is a list or repeated values. - return value + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + self._Modified() + field_dict = self._fields + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder = decoders_by_tag.get(tag_bytes) + if field_decoder is None: + new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if new_pos == -1: + return pos + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + return pos + cls._InternalParse = InternalParse -def _HasFieldOrExtension(message, field_or_extension): - """Checks if a message has the specified field or extension set. +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" - The field or extension specified can be optional, required or repeated. If - it is repeated, this function returns True. Otherwise it checks the has bit - of the field or extension. + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] - Args: - message: Message which contains the target field or extension. - field_or_extension: Field or extension to check. This must be a - FieldDescriptor instance. + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. - Returns: - True if the message has a value set for the specified field or extension, - or if the field or extension is repeated. - """ - if field_or_extension.label == _FieldDescriptor.LABEL_REPEATED: - return True - if field_or_extension.is_extension: - return message.HasExtension(field_or_extension) - else: - return message.HasField(field_or_extension.name) + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + Returns: + True iff the specified message has all required fields set. + """ -def _IsFieldOrExtensionInitialized(message, field, errors=None): - """Checks if a message field or extension is initialized. + # Performance is critical so we avoid HasField() and ListFields(). - Args: - message: The message which contains the field or extension. - field: Field or extension to check. This must be a FieldDescriptor instance. - errors: Errors will be appended to it, if set to a meaningful value. + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False - Returns: - True if the field/extension can be considered initialized. - """ - # If the field is required and is not set, it isn't initialized. - if field.label == _FieldDescriptor.LABEL_REQUIRED: - if not _HasFieldOrExtension(message, field): - if errors is not None: - errors.append('Required field %s is not set.' % field.full_name) - return False + for field, value in self._fields.iteritems(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False - # If the field is optional and is not set, or if it - # isn't a submessage then the field is initialized. - if field.label == _FieldDescriptor.LABEL_OPTIONAL: - if not _HasFieldOrExtension(message, field): - return True - if field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE: return True - # The field is set and is either a single or a repeated submessage. - messages = _FieldOrExtensionValues(message, field) - # If all submessages in this field are initialized, the field is - # considered initialized. - for message in messages: - if not _InternalIsInitialized(message, errors): - return False - return True - + cls.IsInitialized = IsInitialized -def _InternalIsInitialized(message, errors=None): - """Checks if all required fields of a message are set. - - Args: - message: The message to check. - errors: If set, initialization errors will be appended to it. + def FindInitializationErrors(self): + """Finds required fields which are not initialized. - Returns: - True iff the specified message has all required fields set. - """ - fields_and_extensions = [] - fields_and_extensions.extend(message.DESCRIPTOR.fields) - fields_and_extensions.extend( - [extension[0] for extension in message.Extensions._ListSetExtensions()]) - for field_or_extension in fields_and_extensions: - if not _IsFieldOrExtensionInitialized(message, field_or_extension, errors): - return False - return True - - -def _AddMergeFromStringMethod(message_descriptor, cls): - """Helper for _AddMessageMethods().""" - Decoder = decoder.Decoder - def MergeFromString(self, serialized): - decoder = Decoder(serialized) - byte_count = 0 - while not decoder.EndOfStream(): - bytes_read = _DeserializeOneEntity(message_descriptor, self, decoder) - if not bytes_read: - break - byte_count += bytes_read - return byte_count - cls.MergeFromString = MergeFromString - - -def _AddIsInitializedMethod(cls): - """Adds the IsInitialized method to the protocol message class.""" - cls.IsInitialized = _InternalIsInitialized + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + errors = [] # simplify things -def _MergeFieldOrExtension(destination_msg, field, value): - """Merges a specified message field into another message.""" - property_name = _PropertyName(field.name) - is_extension = field.is_extension + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) - if not is_extension: - destination = getattr(destination_msg, property_name) - elif (field.label == _FieldDescriptor.LABEL_REPEATED or - field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): - destination = destination_msg.Extensions[field] + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = "(%s)" % field.full_name + else: + name = field.name - # Case 1 - a composite field. - if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if field.label == _FieldDescriptor.LABEL_REPEATED: - for v in value: - destination.add().MergeFrom(v) - else: - destination.MergeFrom(value) - return + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): + element = value[i] + prefix = "%s[%d]." % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + else: + prefix = name + "." + sub_errors = value.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] - # Case 2 - a repeated field. - if field.label == _FieldDescriptor.LABEL_REPEATED: - for v in value: - destination.append(v) - return + return errors - # Case 3 - a singular field. - if is_extension: - destination_msg.Extensions[field] = value - else: - setattr(destination_msg, property_name, value) + cls.FindInitializationErrors = FindInitializationErrors def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + def MergeFrom(self, msg): assert msg is not self - for field in msg.ListFields(): - _MergeFieldOrExtension(self, field[0], field[1]) + self._Modified() + + fields = self._fields + + for field, value in msg._fields.iteritems(): + if field.label == LABEL_REPEATED or field.cpp_type == CPPTYPE_MESSAGE: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value cls.MergeFrom = MergeFrom def _AddMessageMethods(message_descriptor, cls): """Adds implementations of all Message methods to cls.""" _AddListFieldsMethod(message_descriptor, cls) - _AddHasFieldMethod(cls) - _AddClearFieldMethod(cls) - _AddClearExtensionMethod(cls) - _AddClearMethod(cls) - _AddHasExtensionMethod(cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) _AddEqualsMethod(message_descriptor, cls) _AddStrMethod(message_descriptor, cls) _AddSetListenerMethod(cls) @@ -1304,31 +971,30 @@ def _AddMessageMethods(message_descriptor, cls): _AddSerializeToStringMethod(message_descriptor, cls) _AddSerializePartialToStringMethod(message_descriptor, cls) _AddMergeFromStringMethod(message_descriptor, cls) - _AddIsInitializedMethod(cls) + _AddIsInitializedMethod(message_descriptor, cls) _AddMergeFromMethod(cls) def _AddPrivateHelperMethods(cls): """Adds implementation of private helper methods to cls.""" - def MaybeCallTransitionToNonemptyCallback(self): - """Calls self._listener.TransitionToNonempty() the first time this - method is called. On all subsequent calls, this is a no-op. - """ - if not self._called_transition_to_nonempty: - self._listener.TransitionToNonempty() - self._called_transition_to_nonempty = True - cls._MaybeCallTransitionToNonemptyCallback = ( - MaybeCallTransitionToNonemptyCallback) - - def MarkByteSizeDirty(self): + def Modified(self): """Sets the _cached_byte_size_dirty bit to true, and propagates this to our listener iff this was a state change. """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. if not self._cached_byte_size_dirty: self._cached_byte_size_dirty = True - self._listener.ByteSizeDirty() - cls._MarkByteSizeDirty = MarkByteSizeDirty + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + cls._Modified = Modified + cls.SetInParent = Modified class _Listener(object): @@ -1338,22 +1004,17 @@ class _Listener(object): In order to support semantics like: - foo.bar.baz = 23 + foo.bar.baz.qux = 23 assert foo.HasField('bar') ...child objects must have back references to their parents. This helper class is at the heart of this support. """ - def __init__(self, parent_message, has_field_name): + def __init__(self, parent_message): """Args: - parent_message: The message whose _MaybeCallTransitionToNonemptyCallback() - and _MarkByteSizeDirty() methods we should call when we receive - TransitionToNonempty() and ByteSizeDirty() messages. - has_field_name: The name of the "has" field that we should set in - the parent message when we receive a TransitionToNonempty message, - or None if there's no "has" field to set. (This will be the case - for child objects in "repeated" fields). + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. """ # This listener establishes a back reference from a child (contained) object # to its parent (containing) object. We make this a weak reference to avoid @@ -1363,36 +1024,27 @@ class _Listener(object): self._parent_message_weakref = parent_message else: self._parent_message_weakref = weakref.proxy(parent_message) - self._has_field_name = has_field_name - def TransitionToNonempty(self): + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return try: - if self._has_field_name is not None: - setattr(self._parent_message_weakref, self._has_field_name, True) # Propagate the signal to our parents iff this is the first field set. - self._parent_message_weakref._MaybeCallTransitionToNonemptyCallback() + self._parent_message_weakref._Modified() except ReferenceError: # We can get here if a client has kept a reference to a child object, # and is now setting a field on it, but the child's parent has been # garbage-collected. This is not an error. pass - def ByteSizeDirty(self): - try: - self._parent_message_weakref._MarkByteSizeDirty() - except ReferenceError: - # Same as above. - pass - # TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... # TODO(robinson): Unify error handling of "unknown extension" crap. -# TODO(robinson): There's so much similarity between the way that -# extensions behave and the way that normal fields behave that it would -# be really nice to unify more code. It's not immediately obvious -# how to do this, though, and I'd rather get the full functionality -# implemented (and, crucially, get all the tests and specs fleshed out -# and passing), and then come back to this thorny unification problem. # TODO(robinson): Support iteritems()-style iteration over all # extensions with the "has" bits turned on? class _ExtensionDict(object): @@ -1404,250 +1056,85 @@ class _ExtensionDict(object): FieldDescriptors. """ - class _ExtensionListener(object): + def __init__(self, extended_message): + """extended_message: Message instance for which we are the Extensions dict. + """ - """Adapts an _ExtensionDict to behave as a MessageListener.""" + self._extended_message = extended_message - def __init__(self, extension_dict, handle_id): - self._extension_dict = extension_dict - self._handle_id = handle_id + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" - def TransitionToNonempty(self): - self._extension_dict._SubmessageTransitionedToNonempty(self._handle_id) + _VerifyExtensionHandle(self._extended_message, extension_handle) - def ByteSizeDirty(self): - self._extension_dict._SubmessageByteSizeBecameDirty() + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result - # TODO(robinson): Somewhere, we need to blow up if people - # try to register two extensions with the same field number. - # (And we need a test for this of course). + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value - def __init__(self, extended_message, known_extensions): - """extended_message: Message instance for which we are the Extensions dict. - known_extensions: Iterable of known extension handles. - These must be FieldDescriptors. - """ - # We keep a weak reference to extended_message, since - # it has a reference to this instance in turn. - self._extended_message = weakref.proxy(extended_message) - # We make a deep copy of known_extensions to avoid any - # thread-safety concerns, since the argument passed in - # is the global (class-level) dict of known extensions for - # this type of message, which could be modified at any time - # via a RegisterExtension() call. - # - # This dict maps from handle id to handle (a FieldDescriptor). - # - # XXX - # TODO(robinson): This isn't good enough. The client could - # instantiate an object in module A, then afterward import - # module B and pass the instance to B.Foo(). If B imports - # an extender of this proto and then tries to use it, B - # will get a KeyError, even though the extension *is* registered - # at the time of use. - # XXX - self._known_extensions = dict((id(e), e) for e in known_extensions) - # Read lock around self._values, which may be modified by multiple - # concurrent readers in the conceptually "const" __getitem__ method. - # So, we grab this lock in every "read-only" method to ensure - # that concurrent read access is safe without external locking. - self._lock = threading.Lock() - # Maps from extension handle ID to current value of that extension. - self._values = {} - # Maps from extension handle ID to a boolean "has" bit, but only - # for non-repeated extension fields. - keys = (id for id, extension in self._known_extensions.iteritems() - if extension.label != _FieldDescriptor.LABEL_REPEATED) - self._has_bits = dict.fromkeys(keys, False) - - self._extensions_by_number = dict( - (f.number, f) for f in self._known_extensions.itervalues()) - - self._extensions_by_name = {} - for extension in self._known_extensions.itervalues(): - if (extension.containing_type.GetOptions().message_set_wire_format and - extension.type == descriptor_mod.FieldDescriptor.TYPE_MESSAGE and - extension.message_type == extension.extension_scope and - extension.label == descriptor_mod.FieldDescriptor.LABEL_OPTIONAL): - extension_name = extension.message_type.full_name - else: - extension_name = extension.full_name - self._extensions_by_name[extension_name] = extension + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) - def __getitem__(self, extension_handle): - """Returns the current value of the given extension handle.""" - # We don't care as much about keeping critical sections short in the - # extension support, since it's presumably much less of a common case. - self._lock.acquire() - try: - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - if handle_id not in self._values: - self._AddMissingHandle(extension_handle, handle_id) - return self._values[handle_id] - finally: - self._lock.release() + return result def __eq__(self, other): - # We have to grab read locks since we're accessing _values - # in a "const" method. See the comment in the constructor. - if self is other: - return True - self._lock.acquire() - try: - other._lock.acquire() - try: - if self._has_bits != other._has_bits: - return False - # If there's a "has" bit, then only compare values where it is true. - for k, v in self._values.iteritems(): - if self._has_bits.get(k, False) and v != other._values[k]: - return False - return True - finally: - other._lock.release() - finally: - self._lock.release() + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [ field for field in my_fields if field.is_extension ] + other_fields = [ field for field in other_fields if field.is_extension ] + + return my_fields == other_fields def __ne__(self, other): return not self == other # Note that this is only meaningful for non-repeated, scalar extension - # fields. Note also that we may have to call - # MaybeCallTransitionToNonemptyCallback() when we do successfully set a field - # this way, to set any necssary "has" bits in the ancestors of the extended - # message. + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necssary "has" bits in the + # ancestors of the extended message. def __setitem__(self, extension_handle, value): """If extension_handle specifies a non-repeated, scalar extension field, sets the value of that field. """ - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - field = extension_handle # Just shorten the name. - if (field.label == _FieldDescriptor.LABEL_OPTIONAL - and field.cpp_type != _FieldDescriptor.CPPTYPE_MESSAGE): - # It's slightly wasteful to lookup the type checker each time, - # but we expect this to be a vanishingly uncommon case anyway. - type_checker = type_checkers.GetTypeChecker(field.cpp_type, field.type) - type_checker.CheckValue(value) - self._values[handle_id] = value - self._has_bits[handle_id] = True - self._extended_message._MarkByteSizeDirty() - self._extended_message._MaybeCallTransitionToNonemptyCallback() - else: - raise TypeError('Extension is repeated and/or a composite type.') - - def _AddMissingHandle(self, extension_handle, handle_id): - """Helper internal to ExtensionDict.""" - # Special handling for non-repeated message extensions, which (like - # normal fields of this kind) are initialized lazily. - # REQUIRES: _lock already held. - cpp_type = extension_handle.cpp_type - label = extension_handle.label - if (cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE - and label != _FieldDescriptor.LABEL_REPEATED): - self._AddMissingNonRepeatedCompositeHandle(extension_handle, handle_id) - else: - self._values[handle_id] = _DefaultValueForField( - self._extended_message, extension_handle) - - def _AddMissingNonRepeatedCompositeHandle(self, extension_handle, handle_id): - """Helper internal to ExtensionDict.""" - # REQUIRES: _lock already held. - value = extension_handle.message_type._concrete_class() - value._SetListener(_ExtensionDict._ExtensionListener(self, handle_id)) - self._values[handle_id] = value - - def _SubmessageTransitionedToNonempty(self, handle_id): - """Called when a submessage with a given handle id first transitions to - being nonempty. Called by _ExtensionListener. - """ - assert handle_id in self._has_bits - self._has_bits[handle_id] = True - self._extended_message._MaybeCallTransitionToNonemptyCallback() - def _SubmessageByteSizeBecameDirty(self): - """Called whenever a submessage's cached byte size becomes invalid - (goes from being "clean" to being "dirty"). Called by _ExtensionListener. - """ - self._extended_message._MarkByteSizeDirty() - - # We may wish to widen the public interface of Message.Extensions - # to expose some of this private functionality in the future. - # For now, we make all this functionality module-private and just - # implement what we need for serialization/deserialization, - # HasField()/ClearField(), etc. - - def _HasExtension(self, extension_handle): - """Method for internal use by this module. - Returns true iff we "have" this extension in the sense of the - "has" bit being set. - """ - handle_id = id(extension_handle) - # Note that this is different from the other checks. - if handle_id not in self._has_bits: - raise KeyError('Extension not known to this class, or is repeated field.') - return self._has_bits[handle_id] - - # Intentionally pretty similar to ClearField() above. - def _ClearExtension(self, extension_handle): - """Method for internal use by this module. - Clears the specified extension, unsetting its "has" bit. - """ - handle_id = id(extension_handle) - if handle_id not in self._known_extensions: - raise KeyError('Extension not known to this class') - default_value = _DefaultValueForField(self._extended_message, - extension_handle) - if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: - self._extended_message._MarkByteSizeDirty() - else: - cpp_type = extension_handle.cpp_type - if cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: - if handle_id in self._values: - # Future modifications to this object shouldn't set any - # "has" bits here. - self._values[handle_id]._SetListener(None) - if self._has_bits[handle_id]: - self._has_bits[handle_id] = False - self._extended_message._MarkByteSizeDirty() - if handle_id in self._values: - del self._values[handle_id] - - def _ListSetExtensions(self): - """Method for internal use by this module. - - Returns an sequence of all extensions that are currently "set" - in this extension dict. A "set" extension is a repeated extension, - or a non-repeated extension with its "has" bit set. - - The returned sequence contains (field_descriptor, value) pairs, - where value is the current value of the extension with the given - field descriptor. - - The sequence values are in arbitrary order. - """ - self._lock.acquire() # Read-only methods must lock around self._values. - try: - set_extensions = [] - for handle_id, value in self._values.iteritems(): - handle = self._known_extensions[handle_id] - if (handle.label == _FieldDescriptor.LABEL_REPEATED - or self._has_bits[handle_id]): - set_extensions.append((handle, value)) - return set_extensions - finally: - self._lock.release() - - def _AllExtensionsByNumber(self): - """Method for internal use by this module. - - Returns: A dict mapping field_number to (handle, field_descriptor), - for *all* registered extensions for this dict. - """ - return self._extensions_by_number + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker( + extension_handle.cpp_type, extension_handle.type) + type_checker.CheckValue(value) + self._extended_message._fields[extension_handle] = value + self._extended_message._Modified() def _FindExtensionByName(self, name): """Tries to find a known extension with the specified name. @@ -1658,4 +1145,4 @@ class _ExtensionDict(object): Returns: Extension field descriptor. """ - return self._extensions_by_name.get(name, None) + return self._extended_message._extensions_by_name.get(name, None) |