diff options
Diffstat (limited to 'python/google/protobuf/internal')
38 files changed, 6946 insertions, 400 deletions
diff --git a/python/google/protobuf/internal/api_implementation.cc b/python/google/protobuf/internal/api_implementation.cc new file mode 100644 index 0000000..83db40b --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.cc @@ -0,0 +1,139 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +#include <Python.h> + +namespace google { +namespace protobuf { +namespace python { + +// Version constant. +// This is either 0 for python, 1 for CPP V1, 2 for CPP V2. +// +// 0 is default and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=python +// +// 1 is set with -DPYTHON_PROTO2_CPP_IMPL_V1 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=1 +// +// 2 is set with -DPYTHON_PROTO2_CPP_IMPL_V2 and is equivalent to +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION=cpp +// and +// PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION=2 +#ifdef PYTHON_PROTO2_CPP_IMPL_V1 +#if PY_MAJOR_VERSION >= 3 +#error "PYTHON_PROTO2_CPP_IMPL_V1 is not supported under Python 3." +#endif +static int kImplVersion = 1; +#else +#ifdef PYTHON_PROTO2_CPP_IMPL_V2 +static int kImplVersion = 2; +#else +#ifdef PYTHON_PROTO2_PYTHON_IMPL +static int kImplVersion = 0; +#else + +// The defaults are set here. Python 3 uses the fast C++ APIv2 by default. +// Python 2 still uses the Python version by default until some compatibility +// issues can be worked around. +#if PY_MAJOR_VERSION >= 3 +static int kImplVersion = 2; +#else +static int kImplVersion = 0; +#endif + +#endif // PYTHON_PROTO2_PYTHON_IMPL +#endif // PYTHON_PROTO2_CPP_IMPL_V2 +#endif // PYTHON_PROTO2_CPP_IMPL_V1 + +static const char* kImplVersionName = "api_version"; + +static const char* kModuleName = "_api_implementation"; +static const char kModuleDocstring[] = +"_api_implementation is a module that exposes compile-time constants that\n" +"determine the default API implementation to use for Python proto2.\n" +"\n" +"It complements api_implementation.py by setting defaults using compile-time\n" +"constants defined in C, such that one can set defaults at compilation\n" +"(e.g. with blaze flag --copt=-DPYTHON_PROTO2_CPP_IMPL_V2)."; + +#if PY_MAJOR_VERSION >= 3 +static struct PyModuleDef _module = { + PyModuleDef_HEAD_INIT, + kModuleName, + kModuleDocstring, + -1, + NULL, + NULL, + NULL, + NULL, + NULL +}; +#define INITFUNC PyInit__api_implementation +#define INITFUNC_ERRORVAL NULL +#else +#define INITFUNC init_api_implementation +#define INITFUNC_ERRORVAL +#endif + +extern "C" { + PyMODINIT_FUNC INITFUNC() { +#if PY_MAJOR_VERSION >= 3 + PyObject *module = PyModule_Create(&_module); +#else + PyObject *module = Py_InitModule3( + const_cast<char*>(kModuleName), + NULL, + const_cast<char*>(kModuleDocstring)); +#endif + if (module == NULL) { + return INITFUNC_ERRORVAL; + } + + // Adds the module variable "api_version". + if (PyModule_AddIntConstant( + module, + const_cast<char*>(kImplVersionName), + kImplVersion)) +#if PY_MAJOR_VERSION < 3 + return; +#else + { Py_DECREF(module); return NULL; } + + return module; +#endif + } +} + +} // namespace python +} // namespace protobuf +} // namespace google diff --git a/python/google/protobuf/internal/api_implementation.py b/python/google/protobuf/internal/api_implementation.py new file mode 100755 index 0000000..f7926c1 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Determine which implementation of the protobuf API is used in this process. +""" + +import os +import sys + +try: + # pylint: disable=g-import-not-at-top + from google.protobuf.internal import _api_implementation + # The compile-time constants in the _api_implementation module can be used to + # switch to a certain implementation of the Python API at build time. + _api_version = _api_implementation.api_version + del _api_implementation +except ImportError: + _api_version = 0 + +_default_implementation_type = ( + 'python' if _api_version == 0 else 'cpp') +_default_version_str = ( + '1' if _api_version <= 1 else '2') + +# This environment variable can be used to switch to a certain implementation +# of the Python API, overriding the compile-time constants in the +# _api_implementation module. Right now only 'python' and 'cpp' are valid +# values. Any other value will be ignored. +_implementation_type = os.getenv('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', + _default_implementation_type) + +if _implementation_type != 'python': + _implementation_type = 'cpp' + +# This environment variable can be used to switch between the two +# 'cpp' implementations, overriding the compile-time constants in the +# _api_implementation module. Right now only 1 and 2 are valid values. Any other +# value will be ignored. +_implementation_version_str = os.getenv( + 'PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', + _default_version_str) + +if _implementation_version_str not in ('1', '2'): + raise ValueError( + "unsupported PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION: '" + + _implementation_version_str + "' (supported versions: 1, 2)" + ) + +_implementation_version = int(_implementation_version_str) + + +# Usage of this function is discouraged. Clients shouldn't care which +# implementation of the API is in use. Note that there is no guarantee +# that differences between APIs will be maintained. +# Please don't use this function if possible. +def Type(): + return _implementation_type + + +# See comment on 'Type' above. +def Version(): + return _implementation_version diff --git a/python/google/protobuf/internal/api_implementation_default_test.py b/python/google/protobuf/internal/api_implementation_default_test.py new file mode 100644 index 0000000..78d5cf2 --- /dev/null +++ b/python/google/protobuf/internal/api_implementation_default_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test that the api_implementation defaults are what we expect.""" + +import os +import sys +# Clear environment implementation settings before the google3 imports. +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION', None) +os.environ.pop('PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION_VERSION', None) + +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation + + +class ApiImplementationDefaultTest(basetest.TestCase): + + if sys.version_info.major <= 2: + + def testThatPythonIsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('python', api_implementation.Type()) + + else: + + def testThatCppApiV2IsTheDefault(self): + """If -DPYTHON_PROTO_*IMPL* was given at build time, this may fail.""" + self.assertEqual('cpp', api_implementation.Type()) + self.assertEqual(2, api_implementation.Version()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/containers.py b/python/google/protobuf/internal/containers.py index 5cc7d6d..20bfa85 100755 --- a/python/google/protobuf/internal/containers.py +++ b/python/google/protobuf/internal/containers.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -72,9 +72,20 @@ class BaseContainer(object): # The concrete classes should define __eq__. return not self == other + def __hash__(self): + raise TypeError('unhashable object') + def __repr__(self): return repr(self._values) + def sort(self, *args, **kwargs): + # Continue to support the old sort_function keyword argument. + # This is expected to be a rare occurrence, so use LBYL to avoid + # the overhead of actually catching KeyError. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._values.sort(*args, **kwargs) + class RepeatedScalarFieldContainer(BaseContainer): @@ -97,15 +108,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def append(self, value): """Appends an item to the list. Similar to list.append().""" - self._type_checker.CheckValue(value) - self._values.append(value) + self._values.append(self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() def insert(self, key, value): """Inserts the item at the specified position. Similar to list.insert().""" - self._type_checker.CheckValue(value) - self._values.insert(key, value) + self._values.insert(key, self._type_checker.CheckValue(value)) if not self._message_listener.dirty: self._message_listener.Modified() @@ -116,8 +125,7 @@ class RepeatedScalarFieldContainer(BaseContainer): new_values = [] for elem in elem_seq: - self._type_checker.CheckValue(elem) - new_values.append(elem) + new_values.append(self._type_checker.CheckValue(elem)) self._values.extend(new_values) self._message_listener.Modified() @@ -135,9 +143,13 @@ class RepeatedScalarFieldContainer(BaseContainer): def __setitem__(self, key, value): """Sets the item on the specified position.""" - self._type_checker.CheckValue(value) - self._values[key] = value - self._message_listener.Modified() + if isinstance(key, slice): # PY3 + if key.step is not None: + raise ValueError('Extended slices not supported') + self.__setslice__(key.start, key.stop, value) + else: + self._values[key] = self._type_checker.CheckValue(value) + self._message_listener.Modified() def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" @@ -147,8 +159,7 @@ class RepeatedScalarFieldContainer(BaseContainer): """Sets the subset of items from between the specified indices.""" new_values = [] for value in values: - self._type_checker.CheckValue(value) - new_values.append(value) + new_values.append(self._type_checker.CheckValue(value)) self._values[start:stop] = new_values self._message_listener.Modified() @@ -198,28 +209,42 @@ class RepeatedCompositeFieldContainer(BaseContainer): super(RepeatedCompositeFieldContainer, self).__init__(message_listener) self._message_descriptor = message_descriptor - def add(self): - new_element = self._message_descriptor._concrete_class() + def add(self, **kwargs): + """Adds a new element at the end of the list and returns it. Keyword + arguments may be used to initialize the element. + """ + new_element = self._message_descriptor._concrete_class(**kwargs) new_element._SetListener(self._message_listener) self._values.append(new_element) if not self._message_listener.dirty: self._message_listener.Modified() return new_element - def MergeFrom(self, other): - """Appends the contents of another repeated field of the same type to this - one, copying each individual message. + def extend(self, elem_seq): + """Extends by appending the given sequence of elements of the same type + as this one, copying each individual message. """ message_class = self._message_descriptor._concrete_class listener = self._message_listener values = self._values - for message in other._values: + for message in elem_seq: new_element = message_class() new_element._SetListener(listener) new_element.MergeFrom(message) values.append(new_element) listener.Modified() + def MergeFrom(self, other): + """Appends the contents of another repeated field of the same type to this + one, copying each individual message. + """ + self.extend(other._values) + + def remove(self, elem): + """Removes an item from the list. Similar to list.remove().""" + self._values.remove(elem) + self._message_listener.Modified() + def __getslice__(self, start, stop): """Retrieves the subset of items from between the specified indices.""" return self._values[start:stop] diff --git a/python/google/protobuf/internal/cpp_message.py b/python/google/protobuf/internal/cpp_message.py new file mode 100755 index 0000000..0313cb0 --- /dev/null +++ b/python/google/protobuf/internal/cpp_message.py @@ -0,0 +1,663 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Contains helper functions used to create protocol message classes from +Descriptor objects at runtime backed by the protocol buffer C++ API. +""" + +__author__ = 'petar@google.com (Petar Petrov)' + +import copy_reg +import operator +from google.protobuf.internal import _net_proto2___python +from google.protobuf.internal import enum_type_wrapper +from google.protobuf import message + + +_LABEL_REPEATED = _net_proto2___python.LABEL_REPEATED +_LABEL_OPTIONAL = _net_proto2___python.LABEL_OPTIONAL +_CPPTYPE_MESSAGE = _net_proto2___python.CPPTYPE_MESSAGE +_TYPE_MESSAGE = _net_proto2___python.TYPE_MESSAGE + + +def GetDescriptorPool(): + """Creates a new DescriptorPool C++ object.""" + return _net_proto2___python.NewCDescriptorPool() + + +_pool = GetDescriptorPool() + + +def GetFieldDescriptor(full_field_name): + """Searches for a field descriptor given a full field name.""" + return _pool.FindFieldByName(full_field_name) + + +def BuildFile(content): + """Registers a new proto file in the underlying C++ descriptor pool.""" + _net_proto2___python.BuildFile(content) + + +def GetExtensionDescriptor(full_extension_name): + """Searches for extension descriptor given a full field name.""" + return _pool.FindExtensionByName(full_extension_name) + + +def NewCMessage(full_message_name): + """Creates a new C++ protocol message by its name.""" + return _net_proto2___python.NewCMessage(full_message_name) + + +def ScalarProperty(cdescriptor): + """Returns a scalar property for the given descriptor.""" + + def Getter(self): + return self._cmsg.GetScalar(cdescriptor) + + def Setter(self, value): + self._cmsg.SetScalar(cdescriptor, value) + + return property(Getter, Setter) + + +def CompositeProperty(cdescriptor, message_type): + """Returns a Python property the given composite field.""" + + def Getter(self): + sub_message = self._composite_fields.get(cdescriptor.name, None) + if sub_message is None: + cmessage = self._cmsg.NewSubMessage(cdescriptor) + sub_message = message_type._concrete_class(__cmessage=cmessage) + self._composite_fields[cdescriptor.name] = sub_message + return sub_message + + return property(Getter) + + +class RepeatedScalarContainer(object): + """Container for repeated scalar fields.""" + + __slots__ = ['_message', '_cfield_descriptor', '_cmsg'] + + def __init__(self, msg, cfield_descriptor): + self._message = msg + self._cmsg = msg._cmsg + self._cfield_descriptor = cfield_descriptor + + def append(self, value): + self._cmsg.AddRepeatedScalar( + self._cfield_descriptor, value) + + def extend(self, sequence): + for element in sequence: + self.append(element) + + def insert(self, key, value): + values = self[slice(None, None, None)] + values.insert(key, value) + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def remove(self, value): + values = self[slice(None, None, None)] + values.remove(value) + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def __setitem__(self, key, value): + values = self[slice(None, None, None)] + values[key] = value + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, values) + + def __getitem__(self, key): + return self._cmsg.GetRepeatedScalar(self._cfield_descriptor, key) + + def __delitem__(self, key): + self._cmsg.DeleteRepeatedField(self._cfield_descriptor, key) + + def __len__(self): + return len(self[slice(None, None, None)]) + + def __eq__(self, other): + if self is other: + return True + if not operator.isSequenceType(other): + raise TypeError( + 'Can only compare repeated scalar fields against sequences.') + # We are presumably comparing against some other sequence type. + return other == self[slice(None, None, None)] + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + def sort(self, *args, **kwargs): + # Maintain compatibility with the previous interface. + if 'sort_function' in kwargs: + kwargs['cmp'] = kwargs.pop('sort_function') + self._cmsg.AssignRepeatedScalar(self._cfield_descriptor, + sorted(self, *args, **kwargs)) + + +def RepeatedScalarProperty(cdescriptor): + """Returns a Python property the given repeated scalar field.""" + + def Getter(self): + container = self._composite_fields.get(cdescriptor.name, None) + if container is None: + container = RepeatedScalarContainer(self, cdescriptor) + self._composite_fields[cdescriptor.name] = container + return container + + def Setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % cdescriptor.name) + + doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name + return property(Getter, Setter, doc=doc) + + +class RepeatedCompositeContainer(object): + """Container for repeated composite fields.""" + + __slots__ = ['_message', '_subclass', '_cfield_descriptor', '_cmsg'] + + def __init__(self, msg, cfield_descriptor, subclass): + self._message = msg + self._cmsg = msg._cmsg + self._subclass = subclass + self._cfield_descriptor = cfield_descriptor + + def add(self, **kwargs): + cmessage = self._cmsg.AddMessage(self._cfield_descriptor) + return self._subclass(__cmessage=cmessage, __owner=self._message, **kwargs) + + def extend(self, elem_seq): + """Extends by appending the given sequence of elements of the same type + as this one, copying each individual message. + """ + for message in elem_seq: + self.add().MergeFrom(message) + + def remove(self, value): + # TODO(protocol-devel): This is inefficient as it needs to generate a + # message pointer for each message only to do index(). Move this to a C++ + # extension function. + self.__delitem__(self[slice(None, None, None)].index(value)) + + def MergeFrom(self, other): + for message in other[:]: + self.add().MergeFrom(message) + + def __getitem__(self, key): + cmessages = self._cmsg.GetRepeatedMessage( + self._cfield_descriptor, key) + subclass = self._subclass + if not isinstance(cmessages, list): + return subclass(__cmessage=cmessages, __owner=self._message) + + return [subclass(__cmessage=m, __owner=self._message) for m in cmessages] + + def __delitem__(self, key): + self._cmsg.DeleteRepeatedField( + self._cfield_descriptor, key) + + def __len__(self): + return self._cmsg.FieldLength(self._cfield_descriptor) + + def __eq__(self, other): + """Compares the current instance with another one.""" + if self is other: + return True + if not isinstance(other, self.__class__): + raise TypeError('Can only compare repeated composite fields against ' + 'other repeated composite fields.') + messages = self[slice(None, None, None)] + other_messages = other[slice(None, None, None)] + return messages == other_messages + + def __hash__(self): + raise TypeError('unhashable object') + + def sort(self, cmp=None, key=None, reverse=False, **kwargs): + # Maintain compatibility with the old interface. + if cmp is None and 'sort_function' in kwargs: + cmp = kwargs.pop('sort_function') + + # The cmp function, if provided, is passed the results of the key function, + # so we only need to wrap one of them. + if key is None: + index_key = self.__getitem__ + else: + index_key = lambda i: key(self[i]) + + # Sort the list of current indexes by the underlying object. + indexes = range(len(self)) + indexes.sort(cmp=cmp, key=index_key, reverse=reverse) + + # Apply the transposition. + for dest, src in enumerate(indexes): + if dest == src: + continue + self._cmsg.SwapRepeatedFieldElements(self._cfield_descriptor, dest, src) + # Don't swap the same value twice. + indexes[src] = src + + +def RepeatedCompositeProperty(cdescriptor, message_type): + """Returns a Python property for the given repeated composite field.""" + + def Getter(self): + container = self._composite_fields.get(cdescriptor.name, None) + if container is None: + container = RepeatedCompositeContainer( + self, cdescriptor, message_type._concrete_class) + self._composite_fields[cdescriptor.name] = container + return container + + def Setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % cdescriptor.name) + + doc = 'Magic attribute generated for "%s" proto field.' % cdescriptor.name + return property(Getter, Setter, doc=doc) + + +class ExtensionDict(object): + """Extension dictionary added to each protocol message.""" + + def __init__(self, msg): + self._message = msg + self._cmsg = msg._cmsg + self._values = {} + + def __setitem__(self, extension, value): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_OPTIONAL or + cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + raise TypeError('Extension %r is repeated and/or a composite type.' % ( + extension.full_name,)) + self._cmsg.SetScalar(cdescriptor, value) + self._values[extension] = value + + def __getitem__(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_REPEATED and + cdescriptor.cpp_type != _CPPTYPE_MESSAGE): + return self._cmsg.GetScalar(cdescriptor) + + ext = self._values.get(extension, None) + if ext is not None: + return ext + + ext = self._CreateNewHandle(extension) + self._values[extension] = ext + return ext + + def ClearExtension(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + self._cmsg.ClearFieldByDescriptor(extension._cdescriptor) + if extension in self._values: + del self._values[extension] + + def HasExtension(self, extension): + from google.protobuf import descriptor + if not isinstance(extension, descriptor.FieldDescriptor): + raise KeyError('Bad extension %r.' % (extension,)) + return self._cmsg.HasFieldByDescriptor(extension._cdescriptor) + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._message._extensions_by_name.get(name, None) + + def _CreateNewHandle(self, extension): + cdescriptor = extension._cdescriptor + if (cdescriptor.label != _LABEL_REPEATED and + cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + cmessage = self._cmsg.NewSubMessage(cdescriptor) + return extension.message_type._concrete_class(__cmessage=cmessage) + + if cdescriptor.label == _LABEL_REPEATED: + if cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + return RepeatedCompositeContainer( + self._message, cdescriptor, extension.message_type._concrete_class) + else: + return RepeatedScalarContainer(self._message, cdescriptor) + # This shouldn't happen! + assert False + return None + + +def NewMessage(bases, message_descriptor, dictionary): + """Creates a new protocol message *class*.""" + _AddClassAttributesForNestedExtensions(message_descriptor, dictionary) + _AddEnumValues(message_descriptor, dictionary) + _AddDescriptors(message_descriptor, dictionary) + return bases + + +def InitMessage(message_descriptor, cls): + """Constructs a new message instance (called before instance's __init__).""" + cls._extensions_by_name = {} + _AddInitMethod(message_descriptor, cls) + _AddMessageMethods(message_descriptor, cls) + _AddPropertiesForExtensions(message_descriptor, cls) + copy_reg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + +def _AddDescriptors(message_descriptor, dictionary): + """Sets up a new protocol message class dictionary. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__descriptors'] = {} + for field in message_descriptor.fields: + dictionary['__descriptors'][field.name] = GetFieldDescriptor( + field.full_name) + + dictionary['__slots__'] = list(dictionary['__descriptors'].iterkeys()) + [ + '_cmsg', '_owner', '_composite_fields', 'Extensions', '_HACK_REFCOUNTS'] + + +def _AddEnumValues(message_descriptor, dictionary): + """Sets class-level attributes for all enum fields defined in this message. + + Args: + message_descriptor: Descriptor object for this message type. + dictionary: Class dictionary that should be populated. + """ + for enum_type in message_descriptor.enum_types: + dictionary[enum_type.name] = enum_type_wrapper.EnumTypeWrapper(enum_type) + for enum_value in enum_type.values: + dictionary[enum_value.name] = enum_value.number + + +def _AddClassAttributesForNestedExtensions(message_descriptor, dictionary): + """Adds class attributes for the nested extensions.""" + extension_dict = message_descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + + # Create and attach message field properties to the message class. + # This can be done just once per message class, since property setters and + # getters are passed the message instance. + # This makes message instantiation extremely fast, and at the same time it + # doesn't require the creation of property objects for each message instance, + # which saves a lot of memory. + for field in message_descriptor.fields: + field_cdescriptor = cls.__descriptors[field.name] + if field.label == _LABEL_REPEATED: + if field.cpp_type == _CPPTYPE_MESSAGE: + value = RepeatedCompositeProperty(field_cdescriptor, field.message_type) + else: + value = RepeatedScalarProperty(field_cdescriptor) + elif field.cpp_type == _CPPTYPE_MESSAGE: + value = CompositeProperty(field_cdescriptor, field.message_type) + else: + value = ScalarProperty(field_cdescriptor) + setattr(cls, field.name, value) + + # Attach a constant with the field number. + constant_name = field.name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, field.number) + + def Init(self, **kwargs): + """Message constructor.""" + cmessage = kwargs.pop('__cmessage', None) + if cmessage: + self._cmsg = cmessage + else: + self._cmsg = NewCMessage(message_descriptor.full_name) + + # Keep a reference to the owner, as the owner keeps a reference to the + # underlying protocol buffer message. + owner = kwargs.pop('__owner', None) + if owner: + self._owner = owner + + if message_descriptor.is_extendable: + self.Extensions = ExtensionDict(self) + else: + # Reference counting in the C++ code is broken and depends on + # the Extensions reference to keep this object alive during unit + # tests (see b/4856052). Remove this once b/4945904 is fixed. + self._HACK_REFCOUNTS = self + self._composite_fields = {} + + for field_name, field_value in kwargs.iteritems(): + field_cdescriptor = self.__descriptors.get(field_name, None) + if not field_cdescriptor: + raise ValueError('Protocol message has no "%s" field.' % field_name) + if field_cdescriptor.label == _LABEL_REPEATED: + if field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + field_name = getattr(self, field_name) + for val in field_value: + field_name.add().MergeFrom(val) + else: + getattr(self, field_name).extend(field_value) + elif field_cdescriptor.cpp_type == _CPPTYPE_MESSAGE: + getattr(self, field_name).MergeFrom(field_value) + else: + setattr(self, field_name, field_value) + + Init.__module__ = None + Init.__doc__ = None + cls.__init__ = Init + + +def _IsMessageSetExtension(field): + """Checks if a field is a message set extension.""" + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _LABEL_OPTIONAL) + + +def _AddMessageMethods(message_descriptor, cls): + """Adds the methods to a protocol message class.""" + if message_descriptor.is_extendable: + + def ClearExtension(self, extension): + self.Extensions.ClearExtension(extension) + + def HasExtension(self, extension): + return self.Extensions.HasExtension(extension) + + def HasField(self, field_name): + return self._cmsg.HasField(field_name) + + def ClearField(self, field_name): + child_cmessage = None + if field_name in self._composite_fields: + child_field = self._composite_fields[field_name] + del self._composite_fields[field_name] + + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + child_cmessage = child_field._cmsg + + if child_cmessage is not None: + self._cmsg.ClearField(field_name, child_cmessage) + else: + self._cmsg.ClearField(field_name) + + def Clear(self): + cmessages_to_release = [] + for field_name, child_field in self._composite_fields.iteritems(): + child_cdescriptor = self.__descriptors[field_name] + # TODO(anuraag): Support clearing repeated message fields as well. + if (child_cdescriptor.label != _LABEL_REPEATED and + child_cdescriptor.cpp_type == _CPPTYPE_MESSAGE): + child_field._owner = None + cmessages_to_release.append((child_cdescriptor, child_field._cmsg)) + self._composite_fields.clear() + self._cmsg.Clear(cmessages_to_release) + + def IsInitialized(self, errors=None): + if self._cmsg.IsInitialized(): + return True + if errors is not None: + errors.extend(self.FindInitializationErrors()); + return False + + def SerializeToString(self): + if not self.IsInitialized(): + raise message.EncodeError( + 'Message %s is missing required fields: %s' % ( + self._cmsg.full_name, ','.join(self.FindInitializationErrors()))) + return self._cmsg.SerializeToString() + + def SerializePartialToString(self): + return self._cmsg.SerializePartialToString() + + def ParseFromString(self, serialized): + self.Clear() + self.MergeFromString(serialized) + + def MergeFromString(self, serialized): + byte_size = self._cmsg.MergeFromString(serialized) + if byte_size < 0: + raise message.DecodeError('Unable to merge from string.') + return byte_size + + def MergeFrom(self, msg): + if not isinstance(msg, cls): + raise TypeError( + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) + self._cmsg.MergeFrom(msg._cmsg) + + def CopyFrom(self, msg): + self._cmsg.CopyFrom(msg._cmsg) + + def ByteSize(self): + return self._cmsg.ByteSize() + + def SetInParent(self): + return self._cmsg.SetInParent() + + def ListFields(self): + all_fields = [] + field_list = self._cmsg.ListFields() + fields_by_name = cls.DESCRIPTOR.fields_by_name + for is_extension, field_name in field_list: + if is_extension: + extension = cls._extensions_by_name[field_name] + all_fields.append((extension, self.Extensions[extension])) + else: + field_descriptor = fields_by_name[field_name] + all_fields.append( + (field_descriptor, getattr(self, field_name))) + all_fields.sort(key=lambda item: item[0].number) + return all_fields + + def FindInitializationErrors(self): + return self._cmsg.FindInitializationErrors() + + def __str__(self): + return str(self._cmsg) + + def __eq__(self, other): + if self is other: + return True + if not isinstance(other, self.__class__): + return False + return self.ListFields() == other.ListFields() + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + def __unicode__(self): + # Lazy import to prevent circular import when text_format imports this file. + from google.protobuf import text_format + return text_format.MessageToString(self, as_utf8=True).decode('utf-8') + + # Attach the local methods to the message class. + for key, value in locals().copy().iteritems(): + if key not in ('key', 'value', '__builtins__', '__name__', '__doc__'): + setattr(cls, key, value) + + # Static methods: + + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + if _IsMessageSetExtension(extension_handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(string): + msg = cls() + msg.MergeFromString(string) + return msg + cls.FromString = staticmethod(FromString) + + + +def _AddPropertiesForExtensions(message_descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = message_descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + '_FIELD_NUMBER' + setattr(cls, constant_name, extension_field.number) diff --git a/python/google/protobuf/internal/decoder.py b/python/google/protobuf/internal/decoder.py index 461a30c..a4b9060 100755 --- a/python/google/protobuf/internal/decoder.py +++ b/python/google/protobuf/internal/decoder.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for decoding protocol buffer primitives. This code is very similar to encoder.py -- read the docs for that module first. @@ -81,17 +85,26 @@ we repeatedly read a tag, look up the corresponding decoder, and invoke it. __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import encoder from google.protobuf.internal import wire_format from google.protobuf import message +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF +_NAN = _POS_INF * 0 + + # This is not for optimization, but rather to avoid conflicts with local # variables named "message". _DecodeError = message.DecodeError -def _VarintDecoder(mask): +def _VarintDecoder(mask, result_type): """Return an encoder for a basic varint value (does not include tag). Decoded values will be bitwise-anded with the given mask before being @@ -102,15 +115,18 @@ def _VarintDecoder(mask): """ local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: @@ -118,15 +134,17 @@ def _VarintDecoder(mask): return DecodeVarint -def _SignedVarintDecoder(mask): +def _SignedVarintDecoder(mask, result_type): """Like _VarintDecoder() but decodes signed values.""" local_ord = ord + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes def DecodeVarint(buffer, pos): result = 0 shift = 0 while 1: - b = local_ord(buffer[pos]) + b = local_ord(buffer[pos]) if py2 else buffer[pos] result |= ((b & 0x7f) << shift) pos += 1 if not (b & 0x80): @@ -135,19 +153,23 @@ def _SignedVarintDecoder(mask): result |= ~mask else: result &= mask + result = result_type(result) return (result, pos) shift += 7 if shift >= 64: raise _DecodeError('Too many bytes when decoding varint.') return DecodeVarint +# We force 32-bit values to int and 64-bit values to long to make +# alternate implementations where the distinction is more significant +# (e.g. the C++ implementation) simpler. -_DecodeVarint = _VarintDecoder((1 << 64) - 1) -_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1) +_DecodeVarint = _VarintDecoder((1 << 64) - 1, long) +_DecodeSignedVarint = _SignedVarintDecoder((1 << 64) - 1, long) # Use these versions for values which must be limited to 32 bits. -_DecodeVarint32 = _VarintDecoder((1 << 32) - 1) -_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1) +_DecodeVarint32 = _VarintDecoder((1 << 32) - 1, int) +_DecodeSignedVarint32 = _SignedVarintDecoder((1 << 32) - 1, int) def ReadTag(buffer, pos): @@ -161,8 +183,10 @@ def ReadTag(buffer, pos): use that, but not in Python. """ + py2 = _PY2 ##PY25 +##!PY25 py2 = str is bytes start = pos - while ord(buffer[pos]) & 0x80: + while (ord(buffer[pos]) if py2 else buffer[pos]) & 0x80: pos += 1 pos += 1 return (buffer[start:pos], pos) @@ -269,10 +293,161 @@ def _StructPackDecoder(wire_type, format): return _SimpleDecoder(wire_type, InnerDecode) +def _FloatDecoder(): + """Returns a decoder for a float field. + + This code works around a bug in struct.unpack for non-finite 32-bit + floating-point values. + """ + + local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 + + def InnerDecode(buffer, pos): + # We expect a 32-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-9 represent the exponent, and bits 10-32 are the significand. + new_pos = pos + 4 + float_bytes = buffer[pos:new_pos] + + # If this value has all its exponent bits set, then it's non-finite. + # In Python 2.4, struct.unpack will convert it to a finite 64-bit value. + # To avoid that, we parse it specially. + if ((float_bytes[3:4] in b('\x7F\xFF')) ##PY25 +##!PY25 if ((float_bytes[3:4] in b'\x7F\xFF') + and (float_bytes[2:3] >= b('\x80'))): ##PY25 +##!PY25 and (float_bytes[2:3] >= b'\x80')): + # If at least one significand bit is set... + if float_bytes[0:3] != b('\x00\x00\x80'): ##PY25 +##!PY25 if float_bytes[0:3] != b'\x00\x00\x80': + return (_NAN, new_pos) + # If sign bit is set... + if float_bytes[3:4] == b('\xFF'): ##PY25 +##!PY25 if float_bytes[3:4] == b'\xFF': + return (_NEG_INF, new_pos) + return (_POS_INF, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<f', float_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED32, InnerDecode) + + +def _DoubleDecoder(): + """Returns a decoder for a double field. + + This code works around a bug in struct.unpack for not-a-number. + """ + + local_unpack = struct.unpack + b = (lambda x:x) if _PY2 else lambda x:x.encode('latin1') ##PY25 + + def InnerDecode(buffer, pos): + # We expect a 64-bit value in little-endian byte order. Bit 1 is the sign + # bit, bits 2-12 represent the exponent, and bits 13-64 are the significand. + new_pos = pos + 8 + double_bytes = buffer[pos:new_pos] + + # If this value has all its exponent bits set and at least one significand + # bit set, it's not a number. In Python 2.4, struct.unpack will treat it + # as inf or -inf. To avoid that, we treat it specially. +##!PY25 if ((double_bytes[7:8] in b'\x7F\xFF') +##!PY25 and (double_bytes[6:7] >= b'\xF0') +##!PY25 and (double_bytes[0:7] != b'\x00\x00\x00\x00\x00\x00\xF0')): + if ((double_bytes[7:8] in b('\x7F\xFF')) ##PY25 + and (double_bytes[6:7] >= b('\xF0')) ##PY25 + and (double_bytes[0:7] != b('\x00\x00\x00\x00\x00\x00\xF0'))): ##PY25 + return (_NAN, new_pos) + + # Note that we expect someone up-stack to catch struct.error and convert + # it to _DecodeError -- this way we don't have to set up exception- + # handling blocks every time we parse one value. + result = local_unpack('<d', double_bytes)[0] + return (result, new_pos) + return _SimpleDecoder(wire_format.WIRETYPE_FIXED64, InnerDecode) + + +def EnumDecoder(field_number, is_repeated, is_packed, key, new_default): + enum_type = key.enum_type + if is_packed: + local_DecodeVarint = _DecodeVarint + def DecodePackedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + (endpoint, pos) = local_DecodeVarint(buffer, pos) + endpoint += pos + if endpoint > end: + raise _DecodeError('Truncated message.') + while pos < endpoint: + value_start_pos = pos + (element, pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + if pos > endpoint: + if element in enum_type.values_by_number: + del value[-1] # Discard corrupt value. + else: + del message._unknown_fields[-1] + raise _DecodeError('Packed element was truncated.') + return pos + return DecodePackedField + elif is_repeated: + tag_bytes = encoder.TagBytes(field_number, wire_format.WIRETYPE_VARINT) + tag_len = len(tag_bytes) + def DecodeRepeatedField(buffer, pos, end, message, field_dict): + value = field_dict.get(key) + if value is None: + value = field_dict.setdefault(key, new_default(message)) + while 1: + (element, new_pos) = _DecodeSignedVarint32(buffer, pos) + if element in enum_type.values_by_number: + value.append(element) + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append( + (tag_bytes, buffer[pos:new_pos])) + # Predict that the next tag is another copy of the same repeated + # field. + pos = new_pos + tag_len + if buffer[new_pos:pos] != tag_bytes or new_pos >= end: + # Prediction failed. Return. + if new_pos > end: + raise _DecodeError('Truncated message.') + return new_pos + return DecodeRepeatedField + else: + def DecodeField(buffer, pos, end, message, field_dict): + value_start_pos = pos + (enum_value, pos) = _DecodeSignedVarint32(buffer, pos) + if pos > end: + raise _DecodeError('Truncated message.') + if enum_value in enum_type.values_by_number: + field_dict[key] = enum_value + else: + if not message._unknown_fields: + message._unknown_fields = [] + tag_bytes = encoder.TagBytes(field_number, + wire_format.WIRETYPE_VARINT) + message._unknown_fields.append( + (tag_bytes, buffer[value_start_pos:pos])) + return pos + return DecodeField + + # -------------------------------------------------------------------- -Int32Decoder = EnumDecoder = _SimpleDecoder( +Int32Decoder = _SimpleDecoder( wire_format.WIRETYPE_VARINT, _DecodeSignedVarint32) Int64Decoder = _SimpleDecoder( @@ -294,8 +469,8 @@ Fixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<I') Fixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<Q') SFixed32Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<i') SFixed64Decoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<q') -FloatDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED32, '<f') -DoubleDecoder = _StructPackDecoder(wire_format.WIRETYPE_FIXED64, '<d') +FloatDecoder = _FloatDecoder() +DoubleDecoder = _DoubleDecoder() BoolDecoder = _ModifiedDecoder( wire_format.WIRETYPE_VARINT, _DecodeVarint, bool) @@ -307,6 +482,14 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): local_DecodeVarint = _DecodeVarint local_unicode = unicode + def _ConvertToUnicode(byte_str): + try: + return local_unicode(byte_str, 'utf-8') + except UnicodeDecodeError, e: + # add more information to the error message and re-raise it. + e.reason = '%s in field: %s' % (e, key.full_name) + raise + assert not is_packed if is_repeated: tag_bytes = encoder.TagBytes(field_number, @@ -321,7 +504,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - value.append(local_unicode(buffer[pos:new_pos], 'utf-8')) + value.append(_ConvertToUnicode(buffer[pos:new_pos])) # Predict that the next tag is another copy of the same repeated field. pos = new_pos + tag_len if buffer[new_pos:pos] != tag_bytes or new_pos == end: @@ -334,7 +517,7 @@ def StringDecoder(field_number, is_repeated, is_packed, key, new_default): new_pos = pos + size if new_pos > end: raise _DecodeError('Truncated string.') - field_dict[key] = local_unicode(buffer[pos:new_pos], 'utf-8') + field_dict[key] = _ConvertToUnicode(buffer[pos:new_pos]) return new_pos return DecodeField @@ -503,6 +686,7 @@ def MessageSetItemDecoder(extensions_by_number): local_SkipField = SkipField def DecodeItem(buffer, pos, end, message, field_dict): + message_set_item_start = pos type_id = -1 message_start = -1 message_end = -1 @@ -541,6 +725,11 @@ def MessageSetItemDecoder(extensions_by_number): # The only reason _InternalParse would return early is if it encountered # an end-group tag. raise _DecodeError('Unexpected end-group tag.') + else: + if not message._unknown_fields: + message._unknown_fields = [] + message._unknown_fields.append((MESSAGE_SET_ITEM_TAG, + buffer[message_set_item_start:pos])) return pos @@ -552,8 +741,10 @@ def MessageSetItemDecoder(extensions_by_number): def _SkipVarint(buffer, pos, end): """Skip a varint value. Returns the new position.""" - - while ord(buffer[pos]) & 0x80: + # Previously ord(buffer[pos]) raised IndexError when pos is out of range. + # With this code, ord(b'') raises TypeError. Both are handled in + # python_message.py to generate a 'Truncated message' error. + while ord(buffer[pos:pos+1]) & 0x80: pos += 1 pos += 1 if pos > end: @@ -620,7 +811,6 @@ def _FieldSkipper(): ] wiretype_mask = wire_format.TAG_TYPE_MASK - local_ord = ord def SkipField(buffer, pos, end, tag_bytes): """Skips a field with the specified tag. @@ -633,7 +823,7 @@ def _FieldSkipper(): """ # The wire type is always in the first byte since varints are little-endian. - wire_type = local_ord(tag_bytes[0]) & wiretype_mask + wire_type = ord(tag_bytes[0:1]) & wiretype_mask return WIRETYPE_TO_SKIPPER[wire_type](buffer, pos, end) return SkipField diff --git a/python/google/protobuf/internal/descriptor_database_test.py b/python/google/protobuf/internal/descriptor_database_test.py new file mode 100644 index 0000000..fc65b69 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_database_test.py @@ -0,0 +1,63 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_database.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.apputils import basetest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database + + +class DescriptorDatabaseTest(basetest.TestCase): + + def testAdd(self): + db = descriptor_database.DescriptorDatabase() + file_desc_proto = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + db.Add(file_desc_proto) + + self.assertEquals(file_desc_proto, db.FindFileByName( + 'google/protobuf/internal/factory_test2.proto')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Enum')) + self.assertEquals(file_desc_proto, db.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum')) + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test.py b/python/google/protobuf/internal/descriptor_pool_test.py new file mode 100644 index 0000000..d2f8557 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test.py @@ -0,0 +1,564 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.descriptor_pool.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +import os +import unittest + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import api_implementation +from google.protobuf.internal import descriptor_pool_test1_pb2 +from google.protobuf.internal import descriptor_pool_test2_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool + + +class DescriptorPoolTest(basetest.TestCase): + + def setUp(self): + self.pool = descriptor_pool.DescriptorPool() + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(self.factory_test1_fd) + self.pool.Add(self.factory_test2_fd) + + def testFindFileByName(self): + name1 = 'google/protobuf/internal/factory_test1.proto' + file_desc1 = self.pool.FindFileByName(name1) + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals(name1, file_desc1.name) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + name2 = 'google/protobuf/internal/factory_test2.proto' + file_desc2 = self.pool.FindFileByName(name2) + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals(name2, file_desc2.name) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindFileByName('Does not exist') + + def testFindFileContainingSymbol(self): + file_desc1 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory1Message') + self.assertIsInstance(file_desc1, descriptor.FileDescriptor) + self.assertEquals('google/protobuf/internal/factory_test1.proto', + file_desc1.name) + self.assertEquals('google.protobuf.python.internal', file_desc1.package) + self.assertIn('Factory1Message', file_desc1.message_types_by_name) + + file_desc2 = self.pool.FindFileContainingSymbol( + 'google.protobuf.python.internal.Factory2Message') + self.assertIsInstance(file_desc2, descriptor.FileDescriptor) + self.assertEquals('google/protobuf/internal/factory_test2.proto', + file_desc2.name) + self.assertEquals('google.protobuf.python.internal', file_desc2.package) + self.assertIn('Factory2Message', file_desc2.message_types_by_name) + + def testFindFileContainingSymbolFailure(self): + with self.assertRaises(KeyError): + self.pool.FindFileContainingSymbol('Does not exist') + + def testFindMessageTypeByName(self): + msg1 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory1Message') + self.assertIsInstance(msg1, descriptor.Descriptor) + self.assertEquals('Factory1Message', msg1.name) + self.assertEquals('google.protobuf.python.internal.Factory1Message', + msg1.full_name) + self.assertEquals(None, msg1.containing_type) + + nested_msg1 = msg1.nested_types[0] + self.assertEquals('NestedFactory1Message', nested_msg1.name) + self.assertEquals(msg1, nested_msg1.containing_type) + + nested_enum1 = msg1.enum_types[0] + self.assertEquals('NestedFactory1Enum', nested_enum1.name) + self.assertEquals(msg1, nested_enum1.containing_type) + + self.assertEquals(nested_msg1, msg1.fields_by_name[ + 'nested_factory_1_message'].message_type) + self.assertEquals(nested_enum1, msg1.fields_by_name[ + 'nested_factory_1_enum'].enum_type) + + msg2 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message') + self.assertIsInstance(msg2, descriptor.Descriptor) + self.assertEquals('Factory2Message', msg2.name) + self.assertEquals('google.protobuf.python.internal.Factory2Message', + msg2.full_name) + self.assertIsNone(msg2.containing_type) + + nested_msg2 = msg2.nested_types[0] + self.assertEquals('NestedFactory2Message', nested_msg2.name) + self.assertEquals(msg2, nested_msg2.containing_type) + + nested_enum2 = msg2.enum_types[0] + self.assertEquals('NestedFactory2Enum', nested_enum2.name) + self.assertEquals(msg2, nested_enum2.containing_type) + + self.assertEquals(nested_msg2, msg2.fields_by_name[ + 'nested_factory_2_message'].message_type) + self.assertEquals(nested_enum2, msg2.fields_by_name[ + 'nested_factory_2_enum'].enum_type) + + self.assertTrue(msg2.fields_by_name['int_with_default'].has_default_value) + self.assertEquals( + 1776, msg2.fields_by_name['int_with_default'].default_value) + + self.assertTrue( + msg2.fields_by_name['double_with_default'].has_default_value) + self.assertEquals( + 9.99, msg2.fields_by_name['double_with_default'].default_value) + + self.assertTrue( + msg2.fields_by_name['string_with_default'].has_default_value) + self.assertEquals( + 'hello world', msg2.fields_by_name['string_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['bool_with_default'].has_default_value) + self.assertFalse(msg2.fields_by_name['bool_with_default'].default_value) + + self.assertTrue(msg2.fields_by_name['enum_with_default'].has_default_value) + self.assertEquals( + 1, msg2.fields_by_name['enum_with_default'].default_value) + + msg3 = self.pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Message') + self.assertEquals(nested_msg2, msg3) + + self.assertTrue(msg2.fields_by_name['bytes_with_default'].has_default_value) + self.assertEquals( + b'a\xfb\x00c', + msg2.fields_by_name['bytes_with_default'].default_value) + + self.assertEqual(1, len(msg2.oneofs)) + self.assertEqual(1, len(msg2.oneofs_by_name)) + self.assertEqual(2, len(msg2.oneofs[0].fields)) + for name in ['oneof_int', 'oneof_string']: + self.assertEqual(msg2.oneofs[0], + msg2.fields_by_name[name].containing_oneof) + self.assertIn(msg2.fields_by_name[name], msg2.oneofs[0].fields) + + def testFindMessageTypeByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindMessageTypeByName('Does not exist') + + def testFindEnumTypeByName(self): + enum1 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory1Enum') + self.assertIsInstance(enum1, descriptor.EnumDescriptor) + self.assertEquals(0, enum1.values_by_name['FACTORY_1_VALUE_0'].number) + self.assertEquals(1, enum1.values_by_name['FACTORY_1_VALUE_1'].number) + + nested_enum1 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory1Message.NestedFactory1Enum') + self.assertIsInstance(nested_enum1, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_0'].number) + self.assertEquals( + 1, nested_enum1.values_by_name['NESTED_FACTORY_1_VALUE_1'].number) + + enum2 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory2Enum') + self.assertIsInstance(enum2, descriptor.EnumDescriptor) + self.assertEquals(0, enum2.values_by_name['FACTORY_2_VALUE_0'].number) + self.assertEquals(1, enum2.values_by_name['FACTORY_2_VALUE_1'].number) + + nested_enum2 = self.pool.FindEnumTypeByName( + 'google.protobuf.python.internal.Factory2Message.NestedFactory2Enum') + self.assertIsInstance(nested_enum2, descriptor.EnumDescriptor) + self.assertEquals( + 0, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_0'].number) + self.assertEquals( + 1, nested_enum2.values_by_name['NESTED_FACTORY_2_VALUE_1'].number) + + def testFindEnumTypeByNameFailure(self): + with self.assertRaises(KeyError): + self.pool.FindEnumTypeByName('Does not exist') + + def testUserDefinedDB(self): + db = descriptor_database.DescriptorDatabase() + self.pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + self.testFindMessageTypeByName() + + def testComplexNesting(self): + test1_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test1_pb2.DESCRIPTOR.serialized_pb) + test2_desc = descriptor_pb2.FileDescriptorProto.FromString( + descriptor_pool_test2_pb2.DESCRIPTOR.serialized_pb) + self.pool.Add(test1_desc) + self.pool.Add(test2_desc) + TEST1_FILE.CheckFile(self, self.pool) + TEST2_FILE.CheckFile(self, self.pool) + + + +class ProtoFile(object): + + def __init__(self, name, package, messages, dependencies=None): + self.name = name + self.package = package + self.messages = messages + self.dependencies = dependencies or [] + + def CheckFile(self, test, pool): + file_desc = pool.FindFileByName(self.name) + test.assertEquals(self.name, file_desc.name) + test.assertEquals(self.package, file_desc.package) + dependencies_names = [f.name for f in file_desc.dependencies] + test.assertEqual(self.dependencies, dependencies_names) + for name, msg_type in self.messages.items(): + msg_type.CheckType(test, None, name, file_desc) + + +class EnumType(object): + + def __init__(self, values): + self.values = values + + def CheckType(self, test, msg_desc, name, file_desc): + enum_desc = msg_desc.enum_types_by_name[name] + test.assertEqual(name, enum_desc.name) + expected_enum_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_enum_full_name, enum_desc.full_name) + test.assertEqual(msg_desc, enum_desc.containing_type) + test.assertEqual(file_desc, enum_desc.file) + for index, (value, number) in enumerate(self.values): + value_desc = enum_desc.values_by_name[value] + test.assertEqual(value, value_desc.name) + test.assertEqual(index, value_desc.index) + test.assertEqual(number, value_desc.number) + test.assertEqual(enum_desc, value_desc.type) + test.assertIn(value, msg_desc.enum_values_by_name) + + +class MessageType(object): + + def __init__(self, type_dict, field_list, is_extendable=False, + extensions=None): + self.type_dict = type_dict + self.field_list = field_list + self.is_extendable = is_extendable + self.extensions = extensions or [] + + def CheckType(self, test, containing_type_desc, name, file_desc): + if containing_type_desc is None: + desc = file_desc.message_types_by_name[name] + expected_full_name = '.'.join([file_desc.package, name]) + else: + desc = containing_type_desc.nested_types_by_name[name] + expected_full_name = '.'.join([containing_type_desc.full_name, name]) + + test.assertEqual(name, desc.name) + test.assertEqual(expected_full_name, desc.full_name) + test.assertEqual(containing_type_desc, desc.containing_type) + test.assertEqual(desc.file, file_desc) + test.assertEqual(self.is_extendable, desc.is_extendable) + for name, subtype in self.type_dict.items(): + subtype.CheckType(test, desc, name, file_desc) + + for index, (name, field) in enumerate(self.field_list): + field.CheckField(test, desc, name, index) + + for index, (name, field) in enumerate(self.extensions): + field.CheckField(test, desc, name, index) + + +class EnumField(object): + + def __init__(self, number, type_name, default_value): + self.number = number + self.type_name = type_name + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + enum_desc = msg_desc.enum_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_ENUM, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_ENUM, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(enum_desc.values_by_name[self.default_value].index, + field_desc.default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(enum_desc, field_desc.enum_type) + + +class MessageField(object): + + def __init__(self, number, type_name): + self.number = number + self.type_name = type_name + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + field_type_desc = msg_desc.nested_types_by_name[self.type_name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertEqual(msg_desc, field_desc.containing_type) + test.assertEqual(field_type_desc, field_desc.message_type) + + +class StringField(object): + + def __init__(self, number, default_value): + self.number = number + self.default_value = default_value + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.fields_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(index, field_desc.index) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(descriptor.FieldDescriptor.TYPE_STRING, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_STRING, + field_desc.cpp_type) + test.assertTrue(field_desc.has_default_value) + test.assertEqual(self.default_value, field_desc.default_value) + + +class ExtensionField(object): + + def __init__(self, number, extended_type): + self.number = number + self.extended_type = extended_type + + def CheckField(self, test, msg_desc, name, index): + field_desc = msg_desc.extensions_by_name[name] + test.assertEqual(name, field_desc.name) + expected_field_full_name = '.'.join([msg_desc.full_name, name]) + test.assertEqual(expected_field_full_name, field_desc.full_name) + test.assertEqual(self.number, field_desc.number) + test.assertEqual(index, field_desc.index) + test.assertEqual(descriptor.FieldDescriptor.TYPE_MESSAGE, field_desc.type) + test.assertEqual(descriptor.FieldDescriptor.CPPTYPE_MESSAGE, + field_desc.cpp_type) + test.assertFalse(field_desc.has_default_value) + test.assertTrue(field_desc.is_extension) + test.assertEqual(msg_desc, field_desc.extension_scope) + test.assertEqual(msg_desc, field_desc.message_type) + test.assertEqual(self.extended_type, field_desc.containing_type.name) + + +class AddDescriptorTest(basetest.TestCase): + + def _TestMessage(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddDescriptor(unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes').full_name) + + # AddDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage') + + pool.AddDescriptor(unittest_pb2.TestAllTypes.NestedMessage.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + pool.FindMessageTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + # Files are implicitly also indexed when messages are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedMessage').name) + + def testMessage(self): + self._TestMessage('') + self._TestMessage('.') + + def _TestEnum(self, prefix): + pool = descriptor_pool.DescriptorPool() + pool.AddEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum').full_name) + + # AddEnumDescriptor is not recursive. + with self.assertRaises(KeyError): + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.ForeignEnum.NestedEnum') + + pool.AddEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + pool.FindEnumTypeByName( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + # Files are implicitly also indexed when enums are added. + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileContainingSymbol( + prefix + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + + def testEnum(self): + self._TestEnum('') + self._TestEnum('.') + + def testFile(self): + pool = descriptor_pool.DescriptorPool() + pool.AddFileDescriptor(unittest_pb2.DESCRIPTOR) + self.assertEquals( + 'google/protobuf/unittest.proto', + pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + # AddFileDescriptor is not recursive; messages and enums within files must + # be explicitly registered. + with self.assertRaises(KeyError): + pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes') + + +TEST1_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test1.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest1': MessageType({ + 'NestedEnum': EnumType([('ALPHA', 1), ('BETA', 2)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('EPSILON', 5), ('ZETA', 6)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('ETA', 7), ('THETA', 8)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ETA')), + ('nested_field', StringField(2, 'theta')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'ZETA')), + ('nested_field', StringField(2, 'beta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'BETA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], is_extendable=True), + + 'DescriptorPoolTest2': MessageType({ + 'NestedEnum': EnumType([('GAMMA', 3), ('DELTA', 4)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('IOTA', 9), ('KAPPA', 10)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('LAMBDA', 11), ('MU', 12)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'MU')), + ('nested_field', StringField(2, 'lambda')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'IOTA')), + ('nested_field', StringField(2, 'delta')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'GAMMA')), + ('nested_message', MessageField(2, 'NestedMessage')), + ]), + }) + + +TEST2_FILE = ProtoFile( + 'google/protobuf/internal/descriptor_pool_test2.proto', + 'google.protobuf.python.internal', + { + 'DescriptorPoolTest3': MessageType({ + 'NestedEnum': EnumType([('NU', 13), ('XI', 14)]), + 'NestedMessage': MessageType({ + 'NestedEnum': EnumType([('OMICRON', 15), ('PI', 16)]), + 'DeepNestedMessage': MessageType({ + 'NestedEnum': EnumType([('RHO', 17), ('SIGMA', 18)]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'RHO')), + ('nested_field', StringField(2, 'sigma')), + ]), + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'PI')), + ('nested_field', StringField(2, 'nu')), + ('deep_nested_message', MessageField(3, 'DeepNestedMessage')), + ]) + }, [ + ('nested_enum', EnumField(1, 'NestedEnum', 'XI')), + ('nested_message', MessageField(2, 'NestedMessage')), + ], extensions=[ + ('descriptor_pool_test', + ExtensionField(1001, 'DescriptorPoolTest1')), + ]), + }, + dependencies=['google/protobuf/internal/descriptor_pool_test1.proto']) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_pool_test1.proto b/python/google/protobuf/internal/descriptor_pool_test1.proto new file mode 100644 index 0000000..6dfe4ef --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test1.proto @@ -0,0 +1,94 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + + +message DescriptorPoolTest1 { + extensions 1000 to max; + + enum NestedEnum { + ALPHA = 1; + BETA = 2; + } + + optional NestedEnum nested_enum = 1 [default = BETA]; + + message NestedMessage { + enum NestedEnum { + EPSILON = 5; + ZETA = 6; + } + optional NestedEnum nested_enum = 1 [default = ZETA]; + optional string nested_field = 2 [default = "beta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + ETA = 7; + THETA = 8; + } + optional NestedEnum nested_enum = 1 [default = ETA]; + optional string nested_field = 2 [default = "theta"]; + } + } + + optional NestedMessage nested_message = 2; +} + +message DescriptorPoolTest2 { + enum NestedEnum { + GAMMA = 3; + DELTA = 4; + } + + optional NestedEnum nested_enum = 1 [default = GAMMA]; + + message NestedMessage { + enum NestedEnum { + IOTA = 9; + KAPPA = 10; + } + optional NestedEnum nested_enum = 1 [default = IOTA]; + optional string nested_field = 2 [default = "delta"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + LAMBDA = 11; + MU = 12; + } + optional NestedEnum nested_enum = 1 [default = MU]; + optional string nested_field = 2 [default = "lambda"]; + } + } + + optional NestedMessage nested_message = 2; +} diff --git a/python/google/protobuf/internal/descriptor_pool_test2.proto b/python/google/protobuf/internal/descriptor_pool_test2.proto new file mode 100644 index 0000000..fbc8438 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_pool_test2.proto @@ -0,0 +1,70 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +import "google/protobuf/internal/descriptor_pool_test1.proto"; + + +message DescriptorPoolTest3 { + + extend DescriptorPoolTest1 { + optional DescriptorPoolTest3 descriptor_pool_test = 1001; + } + + enum NestedEnum { + NU = 13; + XI = 14; + } + + optional NestedEnum nested_enum = 1 [default = XI]; + + message NestedMessage { + enum NestedEnum { + OMICRON = 15; + PI = 16; + } + optional NestedEnum nested_enum = 1 [default = PI]; + optional string nested_field = 2 [default = "nu"]; + optional DeepNestedMessage deep_nested_message = 3; + + message DeepNestedMessage { + enum NestedEnum { + RHO = 17; + SIGMA = 18; + } + optional NestedEnum nested_enum = 1 [default = RHO]; + optional string nested_field = 2 [default = "sigma"]; + } + } + + optional NestedMessage nested_message = 2; +} + diff --git a/python/google/protobuf/internal/descriptor_python_test.py b/python/google/protobuf/internal/descriptor_python_test.py new file mode 100644 index 0000000..5471ae0 --- /dev/null +++ b/python/google/protobuf/internal/descriptor_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Unittest for descriptor.py for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.descriptor_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/descriptor_test.py b/python/google/protobuf/internal/descriptor_test.py index 05c2745..b3777e3 100755 --- a/python/google/protobuf/internal/descriptor_test.py +++ b/python/google/protobuf/internal/descriptor_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -34,7 +34,8 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 from google.protobuf import descriptor_pb2 @@ -47,7 +48,7 @@ name: 'TestEmptyMessage' """ -class DescriptorTest(unittest.TestCase): +class DescriptorTest(basetest.TestCase): def setUp(self): self.my_file = descriptor.FileDescriptor( @@ -101,6 +102,15 @@ class DescriptorTest(unittest.TestCase): self.my_method ]) + def testEnumValueName(self): + self.assertEqual(self.my_message.EnumValueName('ForeignEnum', 4), + 'FOREIGN_FOO') + + self.assertEqual( + self.my_message.enum_types_by_name[ + 'ForeignEnum'].values_by_number[4].name, + self.my_message.EnumValueName('ForeignEnum', 4)) + def testEnumFixups(self): self.assertEqual(self.my_enum, self.my_enum.values[0].type) @@ -125,6 +135,257 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_service.GetOptions(), descriptor_pb2.ServiceOptions()) + def testSimpleCustomOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.TestMessageWithCustomOptions.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["field1"] + enum_descriptor = message_descriptor.enum_types_by_name["AnEnum"] + enum_value_descriptor =\ + message_descriptor.enum_values_by_name["ANENUM_VAL2"] + service_descriptor =\ + unittest_custom_options_pb2.TestServiceWithCustomOptions.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Foo") + + file_options = file_descriptor.GetOptions() + file_opt1 = unittest_custom_options_pb2.file_opt1 + self.assertEqual(9876543210, file_options.Extensions[file_opt1]) + message_options = message_descriptor.GetOptions() + message_opt1 = unittest_custom_options_pb2.message_opt1 + self.assertEqual(-56, message_options.Extensions[message_opt1]) + field_options = field_descriptor.GetOptions() + field_opt1 = unittest_custom_options_pb2.field_opt1 + self.assertEqual(8765432109, field_options.Extensions[field_opt1]) + field_opt2 = unittest_custom_options_pb2.field_opt2 + self.assertEqual(42, field_options.Extensions[field_opt2]) + enum_options = enum_descriptor.GetOptions() + enum_opt1 = unittest_custom_options_pb2.enum_opt1 + self.assertEqual(-789, enum_options.Extensions[enum_opt1]) + enum_value_options = enum_value_descriptor.GetOptions() + enum_value_opt1 = unittest_custom_options_pb2.enum_value_opt1 + self.assertEqual(123, enum_value_options.Extensions[enum_value_opt1]) + + service_options = service_descriptor.GetOptions() + service_opt1 = unittest_custom_options_pb2.service_opt1 + self.assertEqual(-9876543210, service_options.Extensions[service_opt1]) + method_options = method_descriptor.GetOptions() + method_opt1 = unittest_custom_options_pb2.method_opt1 + self.assertEqual(unittest_custom_options_pb2.METHODOPT1_VAL2, + method_options.Extensions[method_opt1]) + + def testDifferentCustomOptionTypes(self): + kint32min = -2**31 + kint64min = -2**63 + kint32max = 2**31 - 1 + kint64max = 2**63 - 1 + kuint32max = 2**32 - 1 + kuint64max = 2**64 - 1 + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMinIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(False, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(0, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64min, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionMaxIntegerValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(True, message_options.Extensions[ + unittest_custom_options_pb2.bool_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.int64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.uint32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.uint64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sint32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sint64_opt]) + self.assertEqual(kuint32max, message_options.Extensions[ + unittest_custom_options_pb2.fixed32_opt]) + self.assertEqual(kuint64max, message_options.Extensions[ + unittest_custom_options_pb2.fixed64_opt]) + self.assertEqual(kint32max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed32_opt]) + self.assertEqual(kint64max, message_options.Extensions[ + unittest_custom_options_pb2.sfixed64_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.CustomOptionOtherValues.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertEqual(-100, message_options.Extensions[ + unittest_custom_options_pb2.int32_opt]) + self.assertAlmostEqual(12.3456789, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(1.234567890123456789, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + self.assertEqual("Hello, \"World\"", message_options.Extensions[ + unittest_custom_options_pb2.string_opt]) + self.assertEqual(b"Hello\0World", message_options.Extensions[ + unittest_custom_options_pb2.bytes_opt]) + dummy_enum = unittest_custom_options_pb2.DummyMessageContainingEnum + self.assertEqual( + dummy_enum.TEST_OPTION_ENUM_TYPE2, + message_options.Extensions[unittest_custom_options_pb2.enum_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromPositiveInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + message_descriptor =\ + unittest_custom_options_pb2.SettingRealsFromNegativeInts.DESCRIPTOR + message_options = message_descriptor.GetOptions() + self.assertAlmostEqual(-12, message_options.Extensions[ + unittest_custom_options_pb2.float_opt], 6) + self.assertAlmostEqual(-154, message_options.Extensions[ + unittest_custom_options_pb2.double_opt]) + + def testComplexExtensionOptions(self): + descriptor =\ + unittest_custom_options_pb2.VariousComplexOptions.DESCRIPTOR + options = descriptor.GetOptions() + self.assertEqual(42, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].foo) + self.assertEqual(324, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(876, options.Extensions[ + unittest_custom_options_pb2.complex_opt1].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(987, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].baz) + self.assertEqual(654, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.grault]) + self.assertEqual(743, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.foo) + self.assertEqual(1999, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2008, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].bar.Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(741, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].foo) + self.assertEqual(1998, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.quux]) + self.assertEqual(2121, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].Extensions[ + unittest_custom_options_pb2.garply].Extensions[ + unittest_custom_options_pb2.corge].qux) + self.assertEqual(1971, options.Extensions[ + unittest_custom_options_pb2.ComplexOptionType2 + .ComplexOptionType4.complex_opt4].waldo) + self.assertEqual(321, options.Extensions[ + unittest_custom_options_pb2.complex_opt2].fred.waldo) + self.assertEqual(9, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].qux) + self.assertEqual(22, options.Extensions[ + unittest_custom_options_pb2.complex_opt3].complexoptiontype5.plugh) + self.assertEqual(24, options.Extensions[ + unittest_custom_options_pb2.complexopt6].xyzzy) + + # Check that aggregate options were parsed and saved correctly in + # the appropriate descriptors. + def testAggregateOptions(self): + file_descriptor = unittest_custom_options_pb2.DESCRIPTOR + message_descriptor =\ + unittest_custom_options_pb2.AggregateMessage.DESCRIPTOR + field_descriptor = message_descriptor.fields_by_name["fieldname"] + enum_descriptor = unittest_custom_options_pb2.AggregateEnum.DESCRIPTOR + enum_value_descriptor = enum_descriptor.values_by_name["VALUE"] + service_descriptor =\ + unittest_custom_options_pb2.AggregateService.DESCRIPTOR + method_descriptor = service_descriptor.FindMethodByName("Method") + + # Tests for the different types of data embedded in fileopt + file_options = file_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fileopt] + self.assertEqual(100, file_options.i) + self.assertEqual("FileAnnotation", file_options.s) + self.assertEqual("NestedFileAnnotation", file_options.sub.s) + self.assertEqual("FileExtensionAnnotation", file_options.file.Extensions[ + unittest_custom_options_pb2.fileopt].s) + self.assertEqual("EmbeddedMessageSetElement", file_options.mset.Extensions[ + unittest_custom_options_pb2.AggregateMessageSetElement + .message_set_extension].s) + + # Simple tests for all the other types of annotations + self.assertEqual( + "MessageAnnotation", + message_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.msgopt].s) + self.assertEqual( + "FieldAnnotation", + field_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.fieldopt].s) + self.assertEqual( + "EnumAnnotation", + enum_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumopt].s) + self.assertEqual( + "EnumValueAnnotation", + enum_value_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.enumvalopt].s) + self.assertEqual( + "ServiceAnnotation", + service_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.serviceopt].s) + self.assertEqual( + "MethodAnnotation", + method_descriptor.GetOptions().Extensions[ + unittest_custom_options_pb2.methodopt].s) + + def testNestedOptions(self): + nested_message =\ + unittest_custom_options_pb2.NestedOptionType.NestedMessage.DESCRIPTOR + self.assertEqual(1001, nested_message.GetOptions().Extensions[ + unittest_custom_options_pb2.message_opt1]) + nested_field = nested_message.fields_by_name["nested_field"] + self.assertEqual(1002, nested_field.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt1]) + outer_message =\ + unittest_custom_options_pb2.NestedOptionType.DESCRIPTOR + nested_enum = outer_message.enum_types_by_name["NestedEnum"] + self.assertEqual(1003, nested_enum.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_opt1]) + nested_enum_value = outer_message.enum_values_by_name["NESTED_ENUM_VALUE"] + self.assertEqual(1004, nested_enum_value.GetOptions().Extensions[ + unittest_custom_options_pb2.enum_value_opt1]) + nested_extension = outer_message.extensions_by_name["nested_extension"] + self.assertEqual(1005, nested_extension.GetOptions().Extensions[ + unittest_custom_options_pb2.field_opt2]) + def testFileDescriptorReferences(self): self.assertEqual(self.my_enum.file, self.my_file) self.assertEqual(self.my_message.file, self.my_file) @@ -134,7 +395,7 @@ class DescriptorTest(unittest.TestCase): self.assertEqual(self.my_file.package, 'protobuf_unittest') -class DescriptorCopyToProtoTest(unittest.TestCase): +class DescriptorCopyToProtoTest(basetest.TestCase): """Tests for CopyTo functions of Descriptor.""" def _AssertProtoEqual(self, actual_proto, expected_class, expected_ascii): @@ -269,45 +530,49 @@ class DescriptorCopyToProtoTest(unittest.TestCase): descriptor_pb2.DescriptorProto, TEST_MESSAGE_WITH_SEVERAL_EXTENSIONS_ASCII) - def testCopyToProto_FileDescriptor(self): - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" - name: 'google/protobuf/unittest_import.proto' - package: 'protobuf_unittest_import' - message_type: < - name: 'ImportMessage' - field: < - name: 'd' - number: 1 - label: 1 # Optional - type: 5 # TYPE_INT32 - > - > - """ + - """enum_type: < - name: 'ImportEnum' - value: < - name: 'IMPORT_FOO' - number: 7 - > - value: < - name: 'IMPORT_BAR' - number: 8 - > - value: < - name: 'IMPORT_BAZ' - number: 9 - > - > - options: < - java_package: 'com.google.protobuf.test' - optimize_for: 1 # SPEED - > - """) - - self._InternalTestCopyToProto( - unittest_import_pb2.DESCRIPTOR, - descriptor_pb2.FileDescriptorProto, - UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) + # Disable this test so we can make changes to the proto file. + # TODO(xiaofeng): Enable this test after cl/55530659 is submitted. + # + # def testCopyToProto_FileDescriptor(self): + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII = (""" + # name: 'google/protobuf/unittest_import.proto' + # package: 'protobuf_unittest_import' + # dependency: 'google/protobuf/unittest_import_public.proto' + # message_type: < + # name: 'ImportMessage' + # field: < + # name: 'd' + # number: 1 + # label: 1 # Optional + # type: 5 # TYPE_INT32 + # > + # > + # """ + + # """enum_type: < + # name: 'ImportEnum' + # value: < + # name: 'IMPORT_FOO' + # number: 7 + # > + # value: < + # name: 'IMPORT_BAR' + # number: 8 + # > + # value: < + # name: 'IMPORT_BAZ' + # number: 9 + # > + # > + # options: < + # java_package: 'com.google.protobuf.test' + # optimize_for: 1 # SPEED + # > + # public_dependency: 0 + # """) + # self._InternalTestCopyToProto( + # unittest_import_pb2.DESCRIPTOR, + # descriptor_pb2.FileDescriptorProto, + # UNITTEST_IMPORT_FILE_DESCRIPTOR_ASCII) def testCopyToProto_ServiceDescriptor(self): TEST_SERVICE_ASCII = """ @@ -323,12 +588,82 @@ class DescriptorCopyToProtoTest(unittest.TestCase): output_type: '.protobuf_unittest.BarResponse' > """ - self._InternalTestCopyToProto( unittest_pb2.TestService.DESCRIPTOR, descriptor_pb2.ServiceDescriptorProto, TEST_SERVICE_ASCII) +class MakeDescriptorTest(basetest.TestCase): + + def testMakeDescriptorWithNestedFields(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo2' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + nested_type = message_type.nested_type.add() + nested_type.name = 'Sub' + enum_type = nested_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + field = message_type.field.add() + field.number = 2 + field.name = 'nested_message_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_MESSAGE + field.type_name = 'Sub' + enum_field = nested_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo2.Sub.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + self.assertEqual(result.fields[1].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_MESSAGE) + self.assertEqual(result.fields[1].message_type.containing_type, + result) + self.assertEqual(result.nested_types[0].fields[0].full_name, + 'Foo2.Sub.bar_field') + self.assertEqual(result.nested_types[0].fields[0].enum_type, + result.nested_types[0].enum_types[0]) + + def testMakeDescriptorWithUnsignedIntField(self): + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = 'Foo' + message_type = file_descriptor_proto.message_type.add() + message_type.name = file_descriptor_proto.name + enum_type = message_type.enum_type.add() + enum_type.name = 'FOO' + enum_type_val = enum_type.value.add() + enum_type_val.name = 'BAR' + enum_type_val.number = 3 + field = message_type.field.add() + field.number = 1 + field.name = 'uint64_field' + field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + field.type = descriptor.FieldDescriptor.TYPE_UINT64 + enum_field = message_type.field.add() + enum_field.number = 2 + enum_field.name = 'bar_field' + enum_field.label = descriptor.FieldDescriptor.LABEL_REQUIRED + enum_field.type = descriptor.FieldDescriptor.TYPE_ENUM + enum_field.type_name = 'Foo.FOO' + + result = descriptor.MakeDescriptor(message_type) + self.assertEqual(result.fields[0].cpp_type, + descriptor.FieldDescriptor.CPPTYPE_UINT64) + + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/encoder.py b/python/google/protobuf/internal/encoder.py index aa05d5b..38a5138 100755 --- a/python/google/protobuf/internal/encoder.py +++ b/python/google/protobuf/internal/encoder.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2009 Google Inc. All Rights Reserved. + """Code for encoding protocol message primitives. Contains the logic for encoding every logical protocol field type @@ -67,9 +71,17 @@ sizer rather than when calling them. In particular: __author__ = 'kenton@google.com (Kenton Varda)' import struct +import sys ##PY25 +_PY2 = sys.version_info[0] < 3 ##PY25 from google.protobuf.internal import wire_format +# This will overflow and thus become IEEE-754 "infinity". We would use +# "float('inf')" but it doesn't work on Windows pre-Python-2.6. +_POS_INF = 1e10000 +_NEG_INF = -_POS_INF + + def _VarintSize(value): """Compute the size of a varint value.""" if value <= 0x7f: return 1 @@ -334,7 +346,8 @@ def MessageSetItemSizer(field_number): def _VarintEncoder(): """Return an encoder for a basic varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeVarint(write, value): bits = value & 0x7f value >>= 7 @@ -351,7 +364,8 @@ def _SignedVarintEncoder(): """Return an encoder for a basic signed varint value (does not include tag).""" - local_chr = chr + local_chr = _PY2 and chr or (lambda x: bytes((x,))) ##PY25 +##!PY25 local_chr = chr if bytes is str else lambda x: bytes((x,)) def EncodeSignedVarint(write, value): if value < 0: value += (1 << 64) @@ -376,7 +390,8 @@ def _VarintBytes(value): pieces = [] _EncodeVarint(pieces.append, value) - return "".join(pieces) + return "".encode("latin1").join(pieces) ##PY25 +##!PY25 return b"".join(pieces) def TagBytes(field_number, wire_type): @@ -502,6 +517,90 @@ def _StructPackEncoder(wire_type, format): return SpecificEncoder +def _FloatingPointEncoder(wire_type, format): + """Return a constructor for an encoder for float fields. + + This is like StructPackEncoder, but catches errors that may be due to + passing non-finite floating-point values to struct.pack, and makes a + second attempt to encode those values. + + Args: + wire_type: The field's wire type, for encoding tags. + format: The format string to pass to struct.pack(). + """ + + b = _PY2 and (lambda x:x) or (lambda x:x.encode('latin1')) ##PY25 + value_size = struct.calcsize(format) + if value_size == 4: + def EncodeNonFiniteOrRaise(write, value): + # Remember that the serialized form uses little-endian byte order. + if value == _POS_INF: + write(b('\x00\x00\x80\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x80\x7F') + elif value == _NEG_INF: + write(b('\x00\x00\x80\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x80\xFF') + elif value != value: # NaN + write(b('\x00\x00\xC0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\xC0\x7F') + else: + raise + elif value_size == 8: + def EncodeNonFiniteOrRaise(write, value): + if value == _POS_INF: + write(b('\x00\x00\x00\x00\x00\x00\xF0\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\x7F') + elif value == _NEG_INF: + write(b('\x00\x00\x00\x00\x00\x00\xF0\xFF')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF0\xFF') + elif value != value: # NaN + write(b('\x00\x00\x00\x00\x00\x00\xF8\x7F')) ##PY25 +##!PY25 write(b'\x00\x00\x00\x00\x00\x00\xF8\x7F') + else: + raise + else: + raise ValueError('Can\'t encode floating-point values that are ' + '%d bytes long (only 4 or 8)' % value_size) + + def SpecificEncoder(field_number, is_repeated, is_packed): + local_struct_pack = struct.pack + if is_packed: + tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) + local_EncodeVarint = _EncodeVarint + def EncodePackedField(write, value): + write(tag_bytes) + local_EncodeVarint(write, len(value) * value_size) + for element in value: + # This try/except block is going to be faster than any code that + # we could write to check whether element is finite. + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodePackedField + elif is_repeated: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeRepeatedField(write, value): + for element in value: + write(tag_bytes) + try: + write(local_struct_pack(format, element)) + except SystemError: + EncodeNonFiniteOrRaise(write, element) + return EncodeRepeatedField + else: + tag_bytes = TagBytes(field_number, wire_type) + def EncodeField(write, value): + write(tag_bytes) + try: + write(local_struct_pack(format, value)) + except SystemError: + EncodeNonFiniteOrRaise(write, value) + return EncodeField + + return SpecificEncoder + + # ==================================================================== # Here we declare an encoder constructor for each field type. These work # very similarly to sizer constructors, described earlier. @@ -525,15 +624,17 @@ Fixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<I') Fixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<Q') SFixed32Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<i') SFixed64Encoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<q') -FloatEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED32, '<f') -DoubleEncoder = _StructPackEncoder(wire_format.WIRETYPE_FIXED64, '<d') +FloatEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED32, '<f') +DoubleEncoder = _FloatingPointEncoder(wire_format.WIRETYPE_FIXED64, '<d') def BoolEncoder(field_number, is_repeated, is_packed): """Returns an encoder for a boolean field.""" - false_byte = chr(0) - true_byte = chr(1) +##!PY25 false_byte = b'\x00' +##!PY25 true_byte = b'\x01' + false_byte = '\x00'.encode('latin1') ##PY25 + true_byte = '\x01'.encode('latin1') ##PY25 if is_packed: tag_bytes = TagBytes(field_number, wire_format.WIRETYPE_LENGTH_DELIMITED) local_EncodeVarint = _EncodeVarint @@ -669,7 +770,8 @@ def MessageSetItemEncoder(field_number): } } """ - start_bytes = "".join([ + start_bytes = "".encode("latin1").join([ ##PY25 +##!PY25 start_bytes = b"".join([ TagBytes(1, wire_format.WIRETYPE_START_GROUP), TagBytes(2, wire_format.WIRETYPE_VARINT), _VarintBytes(field_number), diff --git a/python/google/protobuf/internal/enum_type_wrapper.py b/python/google/protobuf/internal/enum_type_wrapper.py new file mode 100644 index 0000000..1cffe35 --- /dev/null +++ b/python/google/protobuf/internal/enum_type_wrapper.py @@ -0,0 +1,89 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""A simple wrapper around enum types to expose utility functions. + +Instances are created as properties with the same name as the enum they wrap +on proto classes. For usage, see: + reflection_test.py +""" + +__author__ = 'rabsatt@google.com (Kevin Rabsatt)' + + +class EnumTypeWrapper(object): + """A utility for finding the names of enum values.""" + + DESCRIPTOR = None + + def __init__(self, enum_type): + """Inits EnumTypeWrapper with an EnumDescriptor.""" + self._enum_type = enum_type + self.DESCRIPTOR = enum_type; + + def Name(self, number): + """Returns a string containing the name of an enum value.""" + if number in self._enum_type.values_by_number: + return self._enum_type.values_by_number[number].name + raise ValueError('Enum %s has no name defined for value %d' % ( + self._enum_type.name, number)) + + def Value(self, name): + """Returns the value coresponding to the given enum name.""" + if name in self._enum_type.values_by_name: + return self._enum_type.values_by_name[name].number + raise ValueError('Enum %s has no value defined for name %s' % ( + self._enum_type.name, name)) + + def keys(self): + """Return a list of the string names in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.name + for value_descriptor in self._enum_type.values] + + def values(self): + """Return a list of the integer values in the enum. + + These are returned in the order they were defined in the .proto file. + """ + + return [value_descriptor.number + for value_descriptor in self._enum_type.values] + + def items(self): + """Return a list of the (name, value) pairs of the enum. + + These are returned in the order they were defined in the .proto file. + """ + return [(value_descriptor.name, value_descriptor.number) + for value_descriptor in self._enum_type.values] diff --git a/python/google/protobuf/internal/factory_test1.proto b/python/google/protobuf/internal/factory_test1.proto new file mode 100644 index 0000000..9f5a391 --- /dev/null +++ b/python/google/protobuf/internal/factory_test1.proto @@ -0,0 +1,57 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + + +enum Factory1Enum { + FACTORY_1_VALUE_0 = 0; + FACTORY_1_VALUE_1 = 1; +} + +message Factory1Message { + optional Factory1Enum factory_1_enum = 1; + enum NestedFactory1Enum { + NESTED_FACTORY_1_VALUE_0 = 0; + NESTED_FACTORY_1_VALUE_1 = 1; + } + optional NestedFactory1Enum nested_factory_1_enum = 2; + message NestedFactory1Message { + optional string value = 1; + } + optional NestedFactory1Message nested_factory_1_message = 3; + optional int32 scalar_value = 4; + repeated string list_value = 5; + + extensions 1000 to max; +} diff --git a/python/google/protobuf/internal/factory_test2.proto b/python/google/protobuf/internal/factory_test2.proto new file mode 100644 index 0000000..27feb6c --- /dev/null +++ b/python/google/protobuf/internal/factory_test2.proto @@ -0,0 +1,92 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: matthewtoia@google.com (Matt Toia) + + +package google.protobuf.python.internal; + +import "google/protobuf/internal/factory_test1.proto"; + + +enum Factory2Enum { + FACTORY_2_VALUE_0 = 0; + FACTORY_2_VALUE_1 = 1; +} + +message Factory2Message { + required int32 mandatory = 1; + optional Factory2Enum factory_2_enum = 2; + enum NestedFactory2Enum { + NESTED_FACTORY_2_VALUE_0 = 0; + NESTED_FACTORY_2_VALUE_1 = 1; + } + optional NestedFactory2Enum nested_factory_2_enum = 3; + message NestedFactory2Message { + optional string value = 1; + } + optional NestedFactory2Message nested_factory_2_message = 4; + optional Factory1Message factory_1_message = 5; + optional Factory1Enum factory_1_enum = 6; + optional Factory1Message.NestedFactory1Enum nested_factory_1_enum = 7; + optional Factory1Message.NestedFactory1Message nested_factory_1_message = 8; + optional Factory2Message circular_message = 9; + optional string scalar_value = 10; + repeated string list_value = 11; + repeated group Grouped = 12 { + optional string part_1 = 13; + optional string part_2 = 14; + } + optional LoopMessage loop = 15; + optional int32 int_with_default = 16 [default = 1776]; + optional double double_with_default = 17 [default = 9.99]; + optional string string_with_default = 18 [default = "hello world"]; + optional bool bool_with_default = 19 [default = false]; + optional Factory2Enum enum_with_default = 20 [default = FACTORY_2_VALUE_1]; + optional bytes bytes_with_default = 21 [default = "a\373\000c"]; + + + extend Factory1Message { + optional string one_more_field = 1001; + } + + oneof oneof_field { + int32 oneof_int = 22; + string oneof_string = 23; + } +} + +message LoopMessage { + optional Factory2Message loop = 1; +} + +extend Factory1Message { + optional string another_field = 1002; +} diff --git a/python/google/protobuf/internal/generator_test.py b/python/google/protobuf/internal/generator_test.py index 78360b5..422fa9a 100755 --- a/python/google/protobuf/internal/generator_test.py +++ b/python/google/protobuf/internal/generator_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -41,17 +41,21 @@ further ensures that we can use Python protocol message objects as we expect. __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest +from google.protobuf.internal import test_bad_identifiers_pb2 +from google.protobuf import unittest_custom_options_pb2 from google.protobuf import unittest_import_pb2 +from google.protobuf import unittest_import_public_pb2 from google.protobuf import unittest_mset_pb2 -from google.protobuf import unittest_pb2 from google.protobuf import unittest_no_generic_services_pb2 - +from google.protobuf import unittest_pb2 +from google.protobuf import service +from google.protobuf import symbol_database MAX_EXTENSION = 536870912 -class GeneratorTest(unittest.TestCase): +class GeneratorTest(basetest.TestCase): def testNestedMessageDescriptor(self): field_name = 'optional_nested_message' @@ -99,6 +103,7 @@ class GeneratorTest(unittest.TestCase): self.assertTrue(isinf(message.neg_inf_float)) self.assertTrue(message.neg_inf_float < 0) self.assertTrue(isnan(message.nan_float)) + self.assertEqual("? ? ?? ?? ??? ??/ ??-", message.cpp_trigraph) def testHasDefaultValues(self): desc = unittest_pb2.TestAllTypes.DESCRIPTOR @@ -140,6 +145,13 @@ class GeneratorTest(unittest.TestCase): proto = unittest_mset_pb2.TestMessageSet() self.assertTrue(proto.DESCRIPTOR.GetOptions().message_set_wire_format) + def testMessageWithCustomOptions(self): + proto = unittest_custom_options_pb2.TestMessageWithCustomOptions() + enum_options = proto.DESCRIPTOR.enum_types_by_name['AnEnum'].GetOptions() + self.assertTrue(enum_options is not None) + # TODO(gps): We really should test for the presense of the enum_opt1 + # extension and for its value to be set to -789. + def testNestedTypes(self): self.assertEquals( set(unittest_pb2.TestAllTypes.DESCRIPTOR.nested_types), @@ -206,15 +218,126 @@ class GeneratorTest(unittest.TestCase): 'google/protobuf/unittest.proto') self.assertEqual(unittest_pb2.DESCRIPTOR.package, 'protobuf_unittest') self.assertFalse(unittest_pb2.DESCRIPTOR.serialized_pb is None) + self.assertEqual(unittest_pb2.DESCRIPTOR.dependencies, + [unittest_import_pb2.DESCRIPTOR]) + self.assertEqual(unittest_import_pb2.DESCRIPTOR.dependencies, + [unittest_import_public_pb2.DESCRIPTOR]) def testNoGenericServices(self): - # unittest_no_generic_services.proto should contain defs for everything - # except services. self.assertTrue(hasattr(unittest_no_generic_services_pb2, "TestMessage")) self.assertTrue(hasattr(unittest_no_generic_services_pb2, "FOO")) self.assertTrue(hasattr(unittest_no_generic_services_pb2, "test_extension")) - self.assertFalse(hasattr(unittest_no_generic_services_pb2, "TestService")) + # Make sure unittest_no_generic_services_pb2 has no services subclassing + # Proto2 Service class. + if hasattr(unittest_no_generic_services_pb2, "TestService"): + self.assertFalse(issubclass(unittest_no_generic_services_pb2.TestService, + service.Service)) + + def testMessageTypesByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2._TESTALLTYPES, + file_type.message_types_by_name[unittest_pb2._TESTALLTYPES.name]) + + # Nested messages shouldn't be included in the message_types_by_name + # dictionary (like in the C++ API). + self.assertFalse( + unittest_pb2._TESTALLTYPES_NESTEDMESSAGE.name in + file_type.message_types_by_name) + + def testEnumTypesByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2._FOREIGNENUM, + file_type.enum_types_by_name[unittest_pb2._FOREIGNENUM.name]) + + def testExtensionsByName(self): + file_type = unittest_pb2.DESCRIPTOR + self.assertEqual( + unittest_pb2.my_extension_string, + file_type.extensions_by_name[unittest_pb2.my_extension_string.name]) + + def testPublicImports(self): + # Test public imports as embedded message. + all_type_proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, all_type_proto.optional_public_import_message.e) + + # PublicImportMessage is actually defined in unittest_import_public_pb2 + # module, and is public imported by unittest_import_pb2 module. + public_import_proto = unittest_import_pb2.PublicImportMessage() + self.assertEqual(0, public_import_proto.e) + self.assertTrue(unittest_import_public_pb2.PublicImportMessage is + unittest_import_pb2.PublicImportMessage) + + def testBadIdentifiers(self): + # We're just testing that the code was imported without problems. + message = test_bad_identifiers_pb2.TestBadIdentifiers() + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.message], + "foo") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.descriptor], + "bar") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.reflection], + "baz") + self.assertEqual(message.Extensions[test_bad_identifiers_pb2.service], + "qux") + + def testOneof(self): + desc = unittest_pb2.TestAllTypes.DESCRIPTOR + self.assertEqual(1, len(desc.oneofs)) + self.assertEqual('oneof_field', desc.oneofs[0].name) + self.assertEqual(0, desc.oneofs[0].index) + self.assertIs(desc, desc.oneofs[0].containing_type) + self.assertIs(desc.oneofs[0], desc.oneofs_by_name['oneof_field']) + nested_names = set(['oneof_uint32', 'oneof_nested_message', + 'oneof_string', 'oneof_bytes']) + self.assertSameElements( + nested_names, + [field.name for field in desc.oneofs[0].fields]) + for field_name, field_desc in desc.fields_by_name.iteritems(): + if field_name in nested_names: + self.assertIs(desc.oneofs[0], field_desc.containing_oneof) + else: + self.assertIsNone(field_desc.containing_oneof) + + +class SymbolDatabaseRegistrationTest(basetest.TestCase): + """Checks that messages, enums and files are correctly registered.""" + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + with self.assertRaises(KeyError): + symbol_database.Default().GetSymbol('protobuf_unittest.NestedMessage') + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, + symbol_database.Default().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + symbol_database.Default().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + symbol_database.Default().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_python_test.py b/python/google/protobuf/internal/message_factory_python_test.py new file mode 100644 index 0000000..85e02b2 --- /dev/null +++ b/python/google/protobuf/internal/message_factory_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for ..public.message_factory for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_factory_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_factory_test.py b/python/google/protobuf/internal/message_factory_test.py new file mode 100644 index 0000000..fcf1341 --- /dev/null +++ b/python/google/protobuf/internal/message_factory_test.py @@ -0,0 +1,131 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.message_factory.""" + +__author__ = 'matthewtoia@google.com (Matt Toia)' + +from google.apputils import basetest +from google.protobuf import descriptor_pb2 +from google.protobuf.internal import factory_test1_pb2 +from google.protobuf.internal import factory_test2_pb2 +from google.protobuf import descriptor_database +from google.protobuf import descriptor_pool +from google.protobuf import message_factory + + +class MessageFactoryTest(basetest.TestCase): + + def setUp(self): + self.factory_test1_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test1_pb2.DESCRIPTOR.serialized_pb) + self.factory_test2_fd = descriptor_pb2.FileDescriptorProto.FromString( + factory_test2_pb2.DESCRIPTOR.serialized_pb) + + def _ExerciseDynamicClass(self, cls): + msg = cls() + msg.mandatory = 42 + msg.nested_factory_2_enum = 0 + msg.nested_factory_2_message.value = 'nested message value' + msg.factory_1_message.factory_1_enum = 1 + msg.factory_1_message.nested_factory_1_enum = 0 + msg.factory_1_message.nested_factory_1_message.value = ( + 'nested message value') + msg.factory_1_message.scalar_value = 22 + msg.factory_1_message.list_value.extend([u'one', u'two', u'three']) + msg.factory_1_message.list_value.append(u'four') + msg.factory_1_enum = 1 + msg.nested_factory_1_enum = 0 + msg.nested_factory_1_message.value = 'nested message value' + msg.circular_message.mandatory = 1 + msg.circular_message.circular_message.mandatory = 2 + msg.circular_message.scalar_value = 'one deep' + msg.scalar_value = 'zero deep' + msg.list_value.extend([u'four', u'three', u'two']) + msg.list_value.append(u'one') + msg.grouped.add() + msg.grouped[0].part_1 = 'hello' + msg.grouped[0].part_2 = 'world' + msg.grouped.add(part_1='testing', part_2='123') + msg.loop.loop.mandatory = 2 + msg.loop.loop.loop.loop.mandatory = 4 + serialized = msg.SerializeToString() + converted = factory_test2_pb2.Factory2Message.FromString(serialized) + reserialized = converted.SerializeToString() + self.assertEquals(serialized, reserialized) + result = cls.FromString(reserialized) + self.assertEquals(msg, result) + + def testGetPrototype(self): + db = descriptor_database.DescriptorDatabase() + pool = descriptor_pool.DescriptorPool(db) + db.Add(self.factory_test1_fd) + db.Add(self.factory_test2_fd) + factory = message_factory.MessageFactory() + cls = factory.GetPrototype(pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message')) + self.assertIsNot(cls, factory_test2_pb2.Factory2Message) + self._ExerciseDynamicClass(cls) + cls2 = factory.GetPrototype(pool.FindMessageTypeByName( + 'google.protobuf.python.internal.Factory2Message')) + self.assertIs(cls, cls2) + + def testGetMessages(self): + # performed twice because multiple calls with the same input must be allowed + for _ in range(2): + messages = message_factory.GetMessages([self.factory_test2_fd, + self.factory_test1_fd]) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message', + 'google.protobuf.python.internal.Factory1Message'], + messages.keys()) + self._ExerciseDynamicClass( + messages['google.protobuf.python.internal.Factory2Message']) + self.assertContainsSubset( + ['google.protobuf.python.internal.Factory2Message.one_more_field', + 'google.protobuf.python.internal.another_field'], + (messages['google.protobuf.python.internal.Factory1Message'] + ._extensions_by_name.keys())) + factory_msg1 = messages['google.protobuf.python.internal.Factory1Message'] + msg1 = messages['google.protobuf.python.internal.Factory1Message']() + ext1 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.Factory2Message.one_more_field'] + ext2 = factory_msg1._extensions_by_name[ + 'google.protobuf.python.internal.another_field'] + msg1.Extensions[ext1] = 'test1' + msg1.Extensions[ext2] = 'test2' + self.assertEquals('test1', msg1.Extensions[ext1]) + self.assertEquals('test2', msg1.Extensions[ext2]) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_listener.py b/python/google/protobuf/internal/message_listener.py index 1080234..0fc255a 100755 --- a/python/google/protobuf/internal/message_listener.py +++ b/python/google/protobuf/internal/message_listener.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/python/google/protobuf/internal/message_python_test.py b/python/google/protobuf/internal/message_python_test.py new file mode 100644 index 0000000..c40623a --- /dev/null +++ b/python/google/protobuf/internal/message_python_test.py @@ -0,0 +1,54 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for ..public.message for the pure Python implementation.""" + +import os +os.environ['PROTOCOL_BUFFERS_PYTHON_IMPLEMENTATION'] = 'python' + +# We must set the implementation version above before the google3 imports. +# pylint: disable=g-import-not-at-top +from google.apputils import basetest +from google.protobuf.internal import api_implementation +# Run all tests from the original module by putting them in our namespace. +# pylint: disable=wildcard-import +from google.protobuf.internal.message_test import * + + +class ConfirmPurePythonTest(basetest.TestCase): + + def testImplementationSetting(self): + self.assertEqual('python', api_implementation.Type()) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/message_test.py b/python/google/protobuf/internal/message_test.py index 73a9a3a..48b7ffd 100755 --- a/python/google/protobuf/internal/message_test.py +++ b/python/google/protobuf/internal/message_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -43,47 +43,639 @@ abstract interface. __author__ = 'gps@google.com (Gregory P. Smith)' -import unittest -from google.protobuf import unittest_import_pb2 +import copy +import math +import operator +import pickle +import sys + +from google.apputils import basetest from google.protobuf import unittest_pb2 +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util +from google.protobuf import message + +# Python pre-2.6 does not have isinf() or isnan() functions, so we have +# to provide our own. +def isnan(val): + # NaN is never equal to itself. + return val != val +def isinf(val): + # Infinity times zero equals NaN. + return not isnan(val) and isnan(val * 0) +def IsPosInf(val): + return isinf(val) and (val > 0) +def IsNegInf(val): + return isinf(val) and (val < 0) -class MessageTest(unittest.TestCase): +class MessageTest(basetest.TestCase): + + def testBadUtf8String(self): + if api_implementation.Type() != 'python': + self.skipTest("Skipping testBadUtf8String, currently only the python " + "api implementation raises UnicodeDecodeError when a " + "string field contains bad utf-8.") + bad_utf8_data = test_util.GoldenFileData('bad_utf8_string') + with self.assertRaises(UnicodeDecodeError) as context: + unittest_pb2.TestAllTypes.FromString(bad_utf8_data) + self.assertIn('field: protobuf_unittest.TestAllTypes.optional_string', + str(context.exception)) def testGoldenMessage(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData( + 'golden_message_oneof_implemented') golden_message = unittest_pb2.TestAllTypes() golden_message.ParseFromString(golden_data) test_util.ExpectAllFieldsSet(self, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenExtensions(self): - golden_data = test_util.GoldenFile('golden_message').read() + golden_data = test_util.GoldenFileData('golden_message') golden_message = unittest_pb2.TestAllExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(golden_message.SerializeToString() == golden_data) + self.assertEqual(golden_data, golden_message.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedMessage(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedTypes() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) def testGoldenPackedExtensions(self): - golden_data = test_util.GoldenFile('golden_packed_fields_message').read() + golden_data = test_util.GoldenFileData('golden_packed_fields_message') golden_message = unittest_pb2.TestPackedExtensions() golden_message.ParseFromString(golden_data) all_set = unittest_pb2.TestPackedExtensions() test_util.SetAllPackedExtensions(all_set) self.assertEquals(all_set, golden_message) - self.assertTrue(all_set.SerializeToString() == golden_data) + self.assertEqual(golden_data, all_set.SerializeToString()) + golden_copy = copy.deepcopy(golden_message) + self.assertEqual(golden_data, golden_copy.SerializeToString()) + + def testPickleSupport(self): + golden_data = test_util.GoldenFileData('golden_message') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + + + def testPickleIncompleteProto(self): + golden_message = unittest_pb2.TestRequired(a=1) + pickled_message = pickle.dumps(golden_message) + + unpickled_message = pickle.loads(pickled_message) + self.assertEquals(unpickled_message, golden_message) + self.assertEquals(unpickled_message.a, 1) + # This is still an incomplete proto - so serializing should fail + self.assertRaises(message.EncodeError, unpickled_message.SerializeToString) + + def testPositiveInfinity(self): + golden_data = (b'\x5D\x00\x00\x80\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\x7F' + b'\xCD\x02\x00\x00\x80\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsPosInf(golden_message.optional_float)) + self.assertTrue(IsPosInf(golden_message.optional_double)) + self.assertTrue(IsPosInf(golden_message.repeated_float[0])) + self.assertTrue(IsPosInf(golden_message.repeated_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNegativeInfinity(self): + golden_data = (b'\x5D\x00\x00\x80\xFF' + b'\x61\x00\x00\x00\x00\x00\x00\xF0\xFF' + b'\xCD\x02\x00\x00\x80\xFF' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsNegInf(golden_message.optional_float)) + self.assertTrue(IsNegInf(golden_message.optional_double)) + self.assertTrue(IsNegInf(golden_message.repeated_float[0])) + self.assertTrue(IsNegInf(golden_message.repeated_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNotANumber(self): + golden_data = (b'\x5D\x00\x00\xC0\x7F' + b'\x61\x00\x00\x00\x00\x00\x00\xF8\x7F' + b'\xCD\x02\x00\x00\xC0\x7F' + b'\xD1\x02\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_message = unittest_pb2.TestAllTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(isnan(golden_message.optional_float)) + self.assertTrue(isnan(golden_message.optional_double)) + self.assertTrue(isnan(golden_message.repeated_float[0])) + self.assertTrue(isnan(golden_message.repeated_double[0])) + + # The protocol buffer may serialize to any one of multiple different + # representations of a NaN. Rather than verify a specific representation, + # verify the serialized string can be converted into a correctly + # behaving protocol buffer. + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestAllTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.optional_float)) + self.assertTrue(isnan(message.optional_double)) + self.assertTrue(isnan(message.repeated_float[0])) + self.assertTrue(isnan(message.repeated_double[0])) + + def testPositiveInfinityPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\x80\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\x7F') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsPosInf(golden_message.packed_float[0])) + self.assertTrue(IsPosInf(golden_message.packed_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNegativeInfinityPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\x80\xFF' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF0\xFF') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(IsNegInf(golden_message.packed_float[0])) + self.assertTrue(IsNegInf(golden_message.packed_double[0])) + self.assertEqual(golden_data, golden_message.SerializeToString()) + + def testNotANumberPacked(self): + golden_data = (b'\xA2\x06\x04\x00\x00\xC0\x7F' + b'\xAA\x06\x08\x00\x00\x00\x00\x00\x00\xF8\x7F') + golden_message = unittest_pb2.TestPackedTypes() + golden_message.ParseFromString(golden_data) + self.assertTrue(isnan(golden_message.packed_float[0])) + self.assertTrue(isnan(golden_message.packed_double[0])) + + serialized = golden_message.SerializeToString() + message = unittest_pb2.TestPackedTypes() + message.ParseFromString(serialized) + self.assertTrue(isnan(message.packed_float[0])) + self.assertTrue(isnan(message.packed_double[0])) + + def testExtremeFloatValues(self): + message = unittest_pb2.TestAllTypes() + + # Most positive exponent, no significand bits set. + kMostPosExponentNoSigBits = math.pow(2, 127) + message.optional_float = kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostPosExponentNoSigBits) + + # Most positive exponent, one significand bit set. + kMostPosExponentOneSigBit = 1.5 * math.pow(2, 127) + message.optional_float = kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostPosExponentOneSigBit) + + # Repeat last two cases with values of same magnitude, but negative. + message.optional_float = -kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostPosExponentNoSigBits) + + message.optional_float = -kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostPosExponentOneSigBit) + + # Most negative exponent, no significand bits set. + kMostNegExponentNoSigBits = math.pow(2, -127) + message.optional_float = kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostNegExponentNoSigBits) + + # Most negative exponent, one significand bit set. + kMostNegExponentOneSigBit = 1.5 * math.pow(2, -127) + message.optional_float = kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == kMostNegExponentOneSigBit) + + # Repeat last two cases with values of the same magnitude, but negative. + message.optional_float = -kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostNegExponentNoSigBits) + + message.optional_float = -kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_float == -kMostNegExponentOneSigBit) + + def testExtremeDoubleValues(self): + message = unittest_pb2.TestAllTypes() + + # Most positive exponent, no significand bits set. + kMostPosExponentNoSigBits = math.pow(2, 1023) + message.optional_double = kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostPosExponentNoSigBits) + + # Most positive exponent, one significand bit set. + kMostPosExponentOneSigBit = 1.5 * math.pow(2, 1023) + message.optional_double = kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostPosExponentOneSigBit) + + # Repeat last two cases with values of same magnitude, but negative. + message.optional_double = -kMostPosExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostPosExponentNoSigBits) + + message.optional_double = -kMostPosExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostPosExponentOneSigBit) + + # Most negative exponent, no significand bits set. + kMostNegExponentNoSigBits = math.pow(2, -1023) + message.optional_double = kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostNegExponentNoSigBits) + + # Most negative exponent, one significand bit set. + kMostNegExponentOneSigBit = 1.5 * math.pow(2, -1023) + message.optional_double = kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == kMostNegExponentOneSigBit) + + # Repeat last two cases with values of the same magnitude, but negative. + message.optional_double = -kMostNegExponentNoSigBits + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostNegExponentNoSigBits) + + message.optional_double = -kMostNegExponentOneSigBit + message.ParseFromString(message.SerializeToString()) + self.assertTrue(message.optional_double == -kMostNegExponentOneSigBit) + + def testFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_float = 2.0 + self.assertEqual(str(message), 'optional_float: 2.0\n') + + def testHighPrecisionFloatPrinting(self): + message = unittest_pb2.TestAllTypes() + message.optional_double = 0.12345678912345678 + if sys.version_info.major >= 3: + self.assertEqual(str(message), 'optional_double: 0.12345678912345678\n') + else: + self.assertEqual(str(message), 'optional_double: 0.123456789123\n') + + def testUnknownFieldPrinting(self): + populated = unittest_pb2.TestAllTypes() + test_util.SetAllNonLazyFields(populated) + empty = unittest_pb2.TestEmptyMessage() + empty.ParseFromString(populated.SerializeToString()) + self.assertEqual(str(empty), '') + + def testSortingRepeatedScalarFieldsDefaultComparator(self): + """Check some different types with the default comparator.""" + message = unittest_pb2.TestAllTypes() + + # TODO(mattp): would testing more scalar types strengthen test? + message.repeated_int32.append(1) + message.repeated_int32.append(3) + message.repeated_int32.append(2) + message.repeated_int32.sort() + self.assertEqual(message.repeated_int32[0], 1) + self.assertEqual(message.repeated_int32[1], 2) + self.assertEqual(message.repeated_int32[2], 3) + + message.repeated_float.append(1.1) + message.repeated_float.append(1.3) + message.repeated_float.append(1.2) + message.repeated_float.sort() + self.assertAlmostEqual(message.repeated_float[0], 1.1) + self.assertAlmostEqual(message.repeated_float[1], 1.2) + self.assertAlmostEqual(message.repeated_float[2], 1.3) + + message.repeated_string.append('a') + message.repeated_string.append('c') + message.repeated_string.append('b') + message.repeated_string.sort() + self.assertEqual(message.repeated_string[0], 'a') + self.assertEqual(message.repeated_string[1], 'b') + self.assertEqual(message.repeated_string[2], 'c') + + message.repeated_bytes.append(b'a') + message.repeated_bytes.append(b'c') + message.repeated_bytes.append(b'b') + message.repeated_bytes.sort() + self.assertEqual(message.repeated_bytes[0], b'a') + self.assertEqual(message.repeated_bytes[1], b'b') + self.assertEqual(message.repeated_bytes[2], b'c') + + def testSortingRepeatedScalarFieldsCustomComparator(self): + """Check some different types with custom comparator.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(message.repeated_int32[0], -1) + self.assertEqual(message.repeated_int32[1], -2) + self.assertEqual(message.repeated_int32[2], -3) + + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(message.repeated_string[0], 'c') + self.assertEqual(message.repeated_string[1], 'bb') + self.assertEqual(message.repeated_string[2], 'aaa') + + def testSortingRepeatedCompositeFieldsCustomComparator(self): + """Check passing a custom comparator to sort a repeated composite field.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=operator.attrgetter('bb')) + self.assertEqual(message.repeated_nested_message[0].bb, 1) + self.assertEqual(message.repeated_nested_message[1].bb, 2) + self.assertEqual(message.repeated_nested_message[2].bb, 3) + self.assertEqual(message.repeated_nested_message[3].bb, 4) + self.assertEqual(message.repeated_nested_message[4].bb, 5) + self.assertEqual(message.repeated_nested_message[5].bb, 6) + + def testRepeatedCompositeFieldSortArguments(self): + """Check sorting a repeated composite field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + get_bb = operator.attrgetter('bb') + cmp_bb = lambda a, b: cmp(a.bb, b.bb) + message.repeated_nested_message.add().bb = 1 + message.repeated_nested_message.add().bb = 3 + message.repeated_nested_message.add().bb = 2 + message.repeated_nested_message.add().bb = 6 + message.repeated_nested_message.add().bb = 5 + message.repeated_nested_message.add().bb = 4 + message.repeated_nested_message.sort(key=get_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(key=get_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + if sys.version_info.major >= 3: return # No cmp sorting in PY3. + message.repeated_nested_message.sort(sort_function=cmp_bb) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [1, 2, 3, 4, 5, 6]) + message.repeated_nested_message.sort(cmp=cmp_bb, reverse=True) + self.assertEqual([k.bb for k in message.repeated_nested_message], + [6, 5, 4, 3, 2, 1]) + + def testRepeatedScalarFieldSortArguments(self): + """Check sorting a scalar field using list.sort() arguments.""" + message = unittest_pb2.TestAllTypes() + + message.repeated_int32.append(-3) + message.repeated_int32.append(-2) + message.repeated_int32.append(-1) + message.repeated_int32.sort(key=abs) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(key=abs, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + if sys.version_info.major < 3: # No cmp sorting in PY3. + abs_cmp = lambda a, b: cmp(abs(a), abs(b)) + message.repeated_int32.sort(sort_function=abs_cmp) + self.assertEqual(list(message.repeated_int32), [-1, -2, -3]) + message.repeated_int32.sort(cmp=abs_cmp, reverse=True) + self.assertEqual(list(message.repeated_int32), [-3, -2, -1]) + + message.repeated_string.append('aaa') + message.repeated_string.append('bb') + message.repeated_string.append('c') + message.repeated_string.sort(key=len) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(key=len, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + if sys.version_info.major < 3: # No cmp sorting in PY3. + len_cmp = lambda a, b: cmp(len(a), len(b)) + message.repeated_string.sort(sort_function=len_cmp) + self.assertEqual(list(message.repeated_string), ['c', 'bb', 'aaa']) + message.repeated_string.sort(cmp=len_cmp, reverse=True) + self.assertEqual(list(message.repeated_string), ['aaa', 'bb', 'c']) + + def testRepeatedFieldsComparable(self): + m1 = unittest_pb2.TestAllTypes() + m2 = unittest_pb2.TestAllTypes() + m1.repeated_int32.append(0) + m1.repeated_int32.append(1) + m1.repeated_int32.append(2) + m2.repeated_int32.append(0) + m2.repeated_int32.append(1) + m2.repeated_int32.append(2) + m1.repeated_nested_message.add().bb = 1 + m1.repeated_nested_message.add().bb = 2 + m1.repeated_nested_message.add().bb = 3 + m2.repeated_nested_message.add().bb = 1 + m2.repeated_nested_message.add().bb = 2 + m2.repeated_nested_message.add().bb = 3 + + if sys.version_info.major >= 3: return # No cmp() in PY3. + + # These comparisons should not raise errors. + _ = m1 < m2 + _ = m1.repeated_nested_message < m2.repeated_nested_message + + # Make sure cmp always works. If it wasn't defined, these would be + # id() comparisons and would all fail. + self.assertEqual(cmp(m1, m2), 0) + self.assertEqual(cmp(m1.repeated_int32, m2.repeated_int32), 0) + self.assertEqual(cmp(m1.repeated_int32, [0, 1, 2]), 0) + self.assertEqual(cmp(m1.repeated_nested_message, + m2.repeated_nested_message), 0) + with self.assertRaises(TypeError): + # Can't compare repeated composite containers to lists. + cmp(m1.repeated_nested_message, m2.repeated_nested_message[:]) + + # TODO(anuraag): Implement extensiondict comparison in C++ and then add test + + def testParsingMerge(self): + """Check the merge behavior when a required or optional field appears + multiple times in the input.""" + messages = [ + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes(), + unittest_pb2.TestAllTypes() ] + messages[0].optional_int32 = 1 + messages[1].optional_int64 = 2 + messages[2].optional_int32 = 3 + messages[2].optional_string = 'hello' + + merged_message = unittest_pb2.TestAllTypes() + merged_message.optional_int32 = 3 + merged_message.optional_int64 = 2 + merged_message.optional_string = 'hello' + + generator = unittest_pb2.TestParsingMerge.RepeatedFieldsGenerator() + generator.field1.extend(messages) + generator.field2.extend(messages) + generator.field3.extend(messages) + generator.ext1.extend(messages) + generator.ext2.extend(messages) + generator.group1.add().field1.MergeFrom(messages[0]) + generator.group1.add().field1.MergeFrom(messages[1]) + generator.group1.add().field1.MergeFrom(messages[2]) + generator.group2.add().field1.MergeFrom(messages[0]) + generator.group2.add().field1.MergeFrom(messages[1]) + generator.group2.add().field1.MergeFrom(messages[2]) + + data = generator.SerializeToString() + parsing_merge = unittest_pb2.TestParsingMerge() + parsing_merge.ParseFromString(data) + + # Required and optional fields should be merged. + self.assertEqual(parsing_merge.required_all_types, merged_message) + self.assertEqual(parsing_merge.optional_all_types, merged_message) + self.assertEqual(parsing_merge.optionalgroup.optional_group_all_types, + merged_message) + self.assertEqual(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.optional_ext], + merged_message) + + # Repeated fields should not be merged. + self.assertEqual(len(parsing_merge.repeated_all_types), 3) + self.assertEqual(len(parsing_merge.repeatedgroup), 3) + self.assertEqual(len(parsing_merge.Extensions[ + unittest_pb2.TestParsingMerge.repeated_ext]), 3) + + def ensureNestedMessageExists(self, msg, attribute): + """Make sure that a nested message object exists. + + As soon as a nested message attribute is accessed, it will be present in the + _fields dict, without being marked as actually being set. + """ + getattr(msg, attribute) + self.assertFalse(msg.HasField(attribute)) + + def testOneofGetCaseNonexistingField(self): + m = unittest_pb2.TestAllTypes() + self.assertRaises(ValueError, m.WhichOneof, 'no_such_oneof_field') + + def testOneofSemantics(self): + m = unittest_pb2.TestAllTypes() + self.assertIs(None, m.WhichOneof('oneof_field')) + + m.oneof_uint32 = 11 + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + + m.oneof_string = u'foo' + self.assertEqual('oneof_string', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertTrue(m.HasField('oneof_string')) + + m.oneof_nested_message.bb = 11 + self.assertEqual('oneof_nested_message', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_string')) + self.assertTrue(m.HasField('oneof_nested_message')) + + m.oneof_bytes = b'bb' + self.assertEqual('oneof_bytes', m.WhichOneof('oneof_field')) + self.assertFalse(m.HasField('oneof_nested_message')) + self.assertTrue(m.HasField('oneof_bytes')) + + def testOneofCompositeFieldReadAccess(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + + self.ensureNestedMessageExists(m, 'oneof_nested_message') + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + self.assertEqual(11, m.oneof_uint32) + + def testOneofHasField(self): + m = unittest_pb2.TestAllTypes() + self.assertFalse(m.HasField('oneof_field')) + m.oneof_uint32 = 11 + self.assertTrue(m.HasField('oneof_field')) + m.oneof_bytes = b'bb' + self.assertTrue(m.HasField('oneof_field')) + m.ClearField('oneof_bytes') + self.assertFalse(m.HasField('oneof_field')) + + def testOneofClearField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_field') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearSetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m.ClearField('oneof_uint32') + self.assertFalse(m.HasField('oneof_field')) + self.assertFalse(m.HasField('oneof_uint32')) + self.assertIs(None, m.WhichOneof('oneof_field')) + + def testOneofClearUnsetField(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + self.ensureNestedMessageExists(m, 'oneof_nested_message') + m.ClearField('oneof_nested_message') + self.assertEqual(11, m.oneof_uint32) + self.assertTrue(m.HasField('oneof_field')) + self.assertTrue(m.HasField('oneof_uint32')) + self.assertEqual('oneof_uint32', m.WhichOneof('oneof_field')) + + def testOneofDeserialize(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m2 = unittest_pb2.TestAllTypes() + m2.ParseFromString(m.SerializeToString()) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + def testSortEmptyRepeatedCompositeContainer(self): + """Exercise a scenario that has led to segfaults in the past. + """ + m = unittest_pb2.TestAllTypes() + m.repeated_nested_message.sort() + + def testHasFieldOnRepeatedField(self): + """Using HasField on a repeated field should raise an exception. + """ + m = unittest_pb2.TestAllTypes() + with self.assertRaises(ValueError) as _: + m.HasField('repeated_int32') + + +class ValidTypeNamesTest(basetest.TestCase): + + def assertImportFromName(self, msg, base_name): + # Parse <type 'module.class_name'> to extra 'some.name' as a string. + tp_name = str(type(msg)).split("'")[1] + valid_names = ('Repeated%sContainer' % base_name, + 'Repeated%sFieldContainer' % base_name) + self.assertTrue(any(tp_name.endswith(v) for v in valid_names), + '%r does end with any of %r' % (tp_name, valid_names)) + + parts = tp_name.split('.') + class_name = parts[-1] + module_name = '.'.join(parts[:-1]) + __import__(module_name, fromlist=[class_name]) + + def testTypeNamesCanBeImported(self): + # If import doesn't work, pickling won't work either. + pb = unittest_pb2.TestAllTypes() + self.assertImportFromName(pb.repeated_int32, 'Scalar') + self.assertImportFromName(pb.repeated_nested_message, 'Composite') + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/missing_enum_values.proto b/python/google/protobuf/internal/missing_enum_values.proto new file mode 100644 index 0000000..e90f0cd --- /dev/null +++ b/python/google/protobuf/internal/missing_enum_values.proto @@ -0,0 +1,50 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package google.protobuf.python.internal; + +message TestEnumValues { + enum NestedEnum { + ZERO = 0; + ONE = 1; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} + +message TestMissingEnumValues { + enum NestedEnum { + TWO = 2; + } + optional NestedEnum optional_nested_enum = 1; + repeated NestedEnum repeated_nested_enum = 2; + repeated NestedEnum packed_nested_enum = 3 [packed = true]; +} diff --git a/python/google/protobuf/internal/more_extensions.proto b/python/google/protobuf/internal/more_extensions.proto index e2d9701..c04e597 100644 --- a/python/google/protobuf/internal/more_extensions.proto +++ b/python/google/protobuf/internal/more_extensions.proto @@ -1,6 +1,6 @@ // Protocol Buffers - Google's data interchange format // Copyright 2008 Google Inc. All rights reserved. -// http://code.google.com/p/protobuf/ +// https://developers.google.com/protocol-buffers/ // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are diff --git a/python/google/protobuf/internal/more_extensions_dynamic.proto b/python/google/protobuf/internal/more_extensions_dynamic.proto new file mode 100644 index 0000000..88bd9c1 --- /dev/null +++ b/python/google/protobuf/internal/more_extensions_dynamic.proto @@ -0,0 +1,49 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: jasonh@google.com (Jason Hsueh) +// +// This file is used to test a corner case in the CPP implementation where the +// generated C++ type is available for the extendee, but the extension is +// defined in a file whose C++ type is not in the binary. + + +import "google/protobuf/internal/more_extensions.proto"; + +package google.protobuf.internal; + +message DynamicMessageType { + optional int32 a = 1; +} + +extend ExtendedMessage { + optional int32 dynamic_int32_extension = 100; + optional DynamicMessageType dynamic_message_extension = 101; +} diff --git a/python/google/protobuf/internal/more_messages.proto b/python/google/protobuf/internal/more_messages.proto index c701b44..61db66c 100644 --- a/python/google/protobuf/internal/more_messages.proto +++ b/python/google/protobuf/internal/more_messages.proto @@ -1,6 +1,6 @@ // Protocol Buffers - Google's data interchange format // Copyright 2008 Google Inc. All rights reserved. -// http://code.google.com/p/protobuf/ +// https://developers.google.com/protocol-buffers/ // // Redistribution and use in source and binary forms, with or without // modification, are permitted provided that the following conditions are diff --git a/python/google/protobuf/internal/python_message.py b/python/google/protobuf/internal/python_message.py new file mode 100755 index 0000000..a5c26f4 --- /dev/null +++ b/python/google/protobuf/internal/python_message.py @@ -0,0 +1,1251 @@ +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Keep it Python2.5 compatible for GAE. +# +# Copyright 2007 Google Inc. All Rights Reserved. +# +# This code is meant to work on Python 2.4 and above only. +# +# TODO(robinson): Helpers for verbose, common checks like seeing if a +# descriptor's cpp_type is CPPTYPE_MESSAGE. + +"""Contains a metaclass and helper functions used to create +protocol message classes from Descriptor objects at runtime. + +Recall that a metaclass is the "type" of a class. +(A class is to a metaclass what an instance is to a class.) + +In this case, we use the GeneratedProtocolMessageType metaclass +to inject all the useful functionality into the classes +output by the protocol compiler at compile-time. + +The upshot of all this is that the real implementation +details for ALL pure-Python protocol buffers are *here in +this file*. +""" + +__author__ = 'robinson@google.com (Will Robinson)' + +import sys +if sys.version_info[0] < 3: + try: + from cStringIO import StringIO as BytesIO + except ImportError: + from StringIO import StringIO as BytesIO + import copy_reg as copyreg +else: + from io import BytesIO + import copyreg +import struct +import weakref + +# We use "as" to avoid name collisions with variables. +from google.protobuf.internal import containers +from google.protobuf.internal import decoder +from google.protobuf.internal import encoder +from google.protobuf.internal import enum_type_wrapper +from google.protobuf.internal import message_listener as message_listener_mod +from google.protobuf.internal import type_checkers +from google.protobuf.internal import wire_format +from google.protobuf import descriptor as descriptor_mod +from google.protobuf import message as message_mod +from google.protobuf import text_format + +_FieldDescriptor = descriptor_mod.FieldDescriptor + + +def NewMessage(bases, descriptor, dictionary): + _AddClassAttributesForNestedExtensions(descriptor, dictionary) + _AddSlots(descriptor, dictionary) + return bases + + +def InitMessage(descriptor, cls): + cls._decoders_by_tag = {} + cls._extensions_by_name = {} + cls._extensions_by_number = {} + if (descriptor.has_options and + descriptor.GetOptions().message_set_wire_format): + cls._decoders_by_tag[decoder.MESSAGE_SET_ITEM_TAG] = ( + decoder.MessageSetItemDecoder(cls._extensions_by_number), None) + + # Attach stuff to each FieldDescriptor for quick lookup later on. + for field in descriptor.fields: + _AttachFieldHelpers(cls, field) + + _AddEnumValues(descriptor, cls) + _AddInitMethod(descriptor, cls) + _AddPropertiesForFields(descriptor, cls) + _AddPropertiesForExtensions(descriptor, cls) + _AddStaticMethods(cls) + _AddMessageMethods(descriptor, cls) + _AddPrivateHelperMethods(descriptor, cls) + copyreg.pickle(cls, lambda obj: (cls, (), obj.__getstate__())) + + +# Stateless helpers for GeneratedProtocolMessageType below. +# Outside clients should not access these directly. +# +# I opted not to make any of these methods on the metaclass, to make it more +# clear that I'm not really using any state there and to keep clients from +# thinking that they have direct access to these construction helpers. + + +def _PropertyName(proto_field_name): + """Returns the name of the public property attribute which + clients can use to get and (in some cases) set the value + of a protocol message field. + + Args: + proto_field_name: The protocol message field name, exactly + as it appears (or would appear) in a .proto file. + """ + # TODO(robinson): Escape Python keywords (e.g., yield), and test this support. + # nnorwitz makes my day by writing: + # """ + # FYI. See the keyword module in the stdlib. This could be as simple as: + # + # if keyword.iskeyword(proto_field_name): + # return proto_field_name + "_" + # return proto_field_name + # """ + # Kenton says: The above is a BAD IDEA. People rely on being able to use + # getattr() and setattr() to reflectively manipulate field values. If we + # rename the properties, then every such user has to also make sure to apply + # the same transformation. Note that currently if you name a field "yield", + # you can still access it just fine using getattr/setattr -- it's not even + # that cumbersome to do so. + # TODO(kenton): Remove this method entirely if/when everyone agrees with my + # position. + return proto_field_name + + +def _VerifyExtensionHandle(message, extension_handle): + """Verify that the given extension handle is valid.""" + + if not isinstance(extension_handle, _FieldDescriptor): + raise KeyError('HasExtension() expects an extension handle, got: %s' % + extension_handle) + + if not extension_handle.is_extension: + raise KeyError('"%s" is not an extension.' % extension_handle.full_name) + + if not extension_handle.containing_type: + raise KeyError('"%s" is missing a containing_type.' + % extension_handle.full_name) + + if extension_handle.containing_type is not message.DESCRIPTOR: + raise KeyError('Extension "%s" extends message type "%s", but this ' + 'message is of type "%s".' % + (extension_handle.full_name, + extension_handle.containing_type.full_name, + message.DESCRIPTOR.full_name)) + + +def _AddSlots(message_descriptor, dictionary): + """Adds a __slots__ entry to dictionary, containing the names of all valid + attributes for this message type. + + Args: + message_descriptor: A Descriptor instance describing this message type. + dictionary: Class dictionary to which we'll add a '__slots__' entry. + """ + dictionary['__slots__'] = ['_cached_byte_size', + '_cached_byte_size_dirty', + '_fields', + '_unknown_fields', + '_is_present_in_parent', + '_listener', + '_listener_for_children', + '__weakref__', + '_oneofs'] + + +def _IsMessageSetExtension(field): + return (field.is_extension and + field.containing_type.has_options and + field.containing_type.GetOptions().message_set_wire_format and + field.type == _FieldDescriptor.TYPE_MESSAGE and + field.message_type == field.extension_scope and + field.label == _FieldDescriptor.LABEL_OPTIONAL) + + +def _AttachFieldHelpers(cls, field_descriptor): + is_repeated = (field_descriptor.label == _FieldDescriptor.LABEL_REPEATED) + is_packed = (field_descriptor.has_options and + field_descriptor.GetOptions().packed) + + if _IsMessageSetExtension(field_descriptor): + field_encoder = encoder.MessageSetItemEncoder(field_descriptor.number) + sizer = encoder.MessageSetItemSizer(field_descriptor.number) + else: + field_encoder = type_checkers.TYPE_TO_ENCODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + sizer = type_checkers.TYPE_TO_SIZER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed) + + field_descriptor._encoder = field_encoder + field_descriptor._sizer = sizer + field_descriptor._default_constructor = _DefaultValueConstructorForField( + field_descriptor) + + def AddDecoder(wiretype, is_packed): + tag_bytes = encoder.TagBytes(field_descriptor.number, wiretype) + cls._decoders_by_tag[tag_bytes] = ( + type_checkers.TYPE_TO_DECODER[field_descriptor.type]( + field_descriptor.number, is_repeated, is_packed, + field_descriptor, field_descriptor._default_constructor), + field_descriptor if field_descriptor.containing_oneof is not None + else None) + + AddDecoder(type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type], + False) + + if is_repeated and wire_format.IsTypePackable(field_descriptor.type): + # To support wire compatibility of adding packed = true, add a decoder for + # packed values regardless of the field's options. + AddDecoder(wire_format.WIRETYPE_LENGTH_DELIMITED, True) + + +def _AddClassAttributesForNestedExtensions(descriptor, dictionary): + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + assert extension_name not in dictionary + dictionary[extension_name] = extension_field + + +def _AddEnumValues(descriptor, cls): + """Sets class-level attributes for all enum fields defined in this message. + + Also exporting a class-level object that can name enum values. + + Args: + descriptor: Descriptor object for this message type. + cls: Class we're constructing for this message type. + """ + for enum_type in descriptor.enum_types: + setattr(cls, enum_type.name, enum_type_wrapper.EnumTypeWrapper(enum_type)) + for enum_value in enum_type.values: + setattr(cls, enum_value.name, enum_value.number) + + +def _DefaultValueConstructorForField(field): + """Returns a function which returns a default value for a field. + + Args: + field: FieldDescriptor object for this field. + + The returned function has one argument: + message: Message instance containing this field, or a weakref proxy + of same. + + That function in turn returns a default value for this field. The default + value may refer back to |message| via a weak reference. + """ + + if field.label == _FieldDescriptor.LABEL_REPEATED: + if field.has_default_value and field.default_value != []: + raise ValueError('Repeated field default value not empty list: %s' % ( + field.default_value)) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # We can't look at _concrete_class yet since it might not have + # been set. (Depends on order in which we initialize the classes). + message_type = field.message_type + def MakeRepeatedMessageDefault(message): + return containers.RepeatedCompositeFieldContainer( + message._listener_for_children, field.message_type) + return MakeRepeatedMessageDefault + else: + type_checker = type_checkers.GetTypeChecker(field) + def MakeRepeatedScalarDefault(message): + return containers.RepeatedScalarFieldContainer( + message._listener_for_children, type_checker) + return MakeRepeatedScalarDefault + + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + # _concrete_class may not yet be initialized. + message_type = field.message_type + def MakeSubMessageDefault(message): + result = message_type._concrete_class() + result._SetListener(message._listener_for_children) + return result + return MakeSubMessageDefault + + def MakeScalarDefault(message): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return field.default_value + return MakeScalarDefault + + +def _AddInitMethod(message_descriptor, cls): + """Adds an __init__ method to cls.""" + fields = message_descriptor.fields + def init(self, **kwargs): + self._cached_byte_size = 0 + self._cached_byte_size_dirty = len(kwargs) > 0 + self._fields = {} + # Contains a mapping from oneof field descriptors to the descriptor + # of the currently set field in that oneof field. + self._oneofs = {} + + # _unknown_fields is () when empty for efficiency, and will be turned into + # a list if fields are added. + self._unknown_fields = () + self._is_present_in_parent = False + self._listener = message_listener_mod.NullMessageListener() + self._listener_for_children = _Listener(self) + for field_name, field_value in kwargs.iteritems(): + field = _GetFieldByName(message_descriptor, field_name) + if field is None: + raise TypeError("%s() got an unexpected keyword argument '%s'" % + (message_descriptor.name, field_name)) + if field.label == _FieldDescriptor.LABEL_REPEATED: + copy = field._default_constructor(self) + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: # Composite + for val in field_value: + copy.add().MergeFrom(val) + else: # Scalar + copy.extend(field_value) + self._fields[field] = copy + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + copy = field._default_constructor(self) + copy.MergeFrom(field_value) + self._fields[field] = copy + else: + setattr(self, field_name, field_value) + + init.__module__ = None + init.__doc__ = None + cls.__init__ = init + + +def _GetFieldByName(message_descriptor, field_name): + """Returns a field descriptor by field name. + + Args: + message_descriptor: A Descriptor describing all fields in message. + field_name: The name of the field to retrieve. + Returns: + The field descriptor associated with the field name. + """ + try: + return message_descriptor.fields_by_name[field_name] + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + +def _AddPropertiesForFields(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + for field in descriptor.fields: + _AddPropertiesForField(field, cls) + + if descriptor.is_extendable: + # _ExtensionDict is just an adaptor with no state so we allocate a new one + # every time it is accessed. + cls.Extensions = property(lambda self: _ExtensionDict(self)) + + +def _AddPropertiesForField(field, cls): + """Adds a public property for a protocol message field. + Clients can use this property to get and (in the case + of non-repeated scalar fields) directly set the value + of a protocol message field. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # Catch it if we add other types that we should + # handle specially here. + assert _FieldDescriptor.MAX_CPPTYPE == 10 + + constant_name = field.name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, field.number) + + if field.label == _FieldDescriptor.LABEL_REPEATED: + _AddPropertiesForRepeatedField(field, cls) + elif field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + _AddPropertiesForNonRepeatedCompositeField(field, cls) + else: + _AddPropertiesForNonRepeatedScalarField(field, cls) + + +def _AddPropertiesForRepeatedField(field, cls): + """Adds a public property for a "repeated" protocol message field. Clients + can use this property to get the value of the field, which will be either a + _RepeatedScalarFieldContainer or _RepeatedCompositeFieldContainer (see + below). + + Note that when clients add values to these containers, we perform + type-checking in the case of repeated scalar fields, and we also set any + necessary "has" bits as a side-effect. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to repeated field ' + '"%s" in protocol message object.' % proto_field_name) + + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedScalarField(field, cls): + """Adds a public property for a nonrepeated, scalar protocol message field. + Clients can use this property to get and directly set the value of the field. + Note that when the client sets the value of a field by using this property, + all necessary "has" bits are set as a side-effect, and we also perform + type-checking. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + type_checker = type_checkers.GetTypeChecker(field) + default_value = field.default_value + valid_values = set() + + def getter(self): + # TODO(protobuf-team): This may be broken since there may not be + # default_value. Combine with has_default_value somehow. + return self._fields.get(field, default_value) + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + def field_setter(self, new_value): + # pylint: disable=protected-access + self._fields[field] = type_checker.CheckValue(new_value) + # Check _cached_byte_size_dirty inline to improve performance, since scalar + # setters are called frequently. + if not self._cached_byte_size_dirty: + self._Modified() + + if field.containing_oneof is not None: + def setter(self, new_value): + field_setter(self, new_value) + self._UpdateOneofState(field) + else: + setter = field_setter + + setter.__module__ = None + setter.__doc__ = 'Setter for %s.' % proto_field_name + + # Add a property to encapsulate the getter/setter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForNonRepeatedCompositeField(field, cls): + """Adds a public property for a nonrepeated, composite protocol message field. + A composite field is a "group" or "message" field. + + Clients can use this property to get the value of the field, but cannot + assign to the property directly. + + Args: + field: A FieldDescriptor for this field. + cls: The class we're constructing. + """ + # TODO(robinson): Remove duplication with similar method + # for non-repeated scalars. + proto_field_name = field.name + property_name = _PropertyName(proto_field_name) + + # TODO(komarek): Can anyone explain to me why we cache the message_type this + # way, instead of referring to field.message_type inside of getter(self)? + # What if someone sets message_type later on (which makes for simpler + # dyanmic proto descriptor and class creation code). + message_type = field.message_type + + def getter(self): + field_value = self._fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = message_type._concrete_class() # use field.message_type? + field_value._SetListener( + _OneofListener(self, field) + if field.containing_oneof is not None + else self._listener_for_children) + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + field_value = self._fields.setdefault(field, field_value) + return field_value + getter.__module__ = None + getter.__doc__ = 'Getter for %s.' % proto_field_name + + # We define a setter just so we can throw an exception with a more + # helpful error message. + def setter(self, new_value): + raise AttributeError('Assignment not allowed to composite field ' + '"%s" in protocol message object.' % proto_field_name) + + # Add a property to encapsulate the getter. + doc = 'Magic attribute generated for "%s" proto field.' % proto_field_name + setattr(cls, property_name, property(getter, setter, doc=doc)) + + +def _AddPropertiesForExtensions(descriptor, cls): + """Adds properties for all fields in this protocol message type.""" + extension_dict = descriptor.extensions_by_name + for extension_name, extension_field in extension_dict.iteritems(): + constant_name = extension_name.upper() + "_FIELD_NUMBER" + setattr(cls, constant_name, extension_field.number) + + +def _AddStaticMethods(cls): + # TODO(robinson): This probably needs to be thread-safe(?) + def RegisterExtension(extension_handle): + extension_handle.containing_type = cls.DESCRIPTOR + _AttachFieldHelpers(cls, extension_handle) + + # Try to insert our extension, failing if an extension with the same number + # already exists. + actual_handle = cls._extensions_by_number.setdefault( + extension_handle.number, extension_handle) + if actual_handle is not extension_handle: + raise AssertionError( + 'Extensions "%s" and "%s" both try to extend message type "%s" with ' + 'field number %d.' % + (extension_handle.full_name, actual_handle.full_name, + cls.DESCRIPTOR.full_name, extension_handle.number)) + + cls._extensions_by_name[extension_handle.full_name] = extension_handle + + handle = extension_handle # avoid line wrapping + if _IsMessageSetExtension(handle): + # MessageSet extension. Also register under type name. + cls._extensions_by_name[ + extension_handle.message_type.full_name] = extension_handle + + cls.RegisterExtension = staticmethod(RegisterExtension) + + def FromString(s): + message = cls() + message.MergeFromString(s) + return message + cls.FromString = staticmethod(FromString) + + +def _IsPresent(item): + """Given a (FieldDescriptor, value) tuple from _fields, return true if the + value should be included in the list returned by ListFields().""" + + if item[0].label == _FieldDescriptor.LABEL_REPEATED: + return bool(item[1]) + elif item[0].cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + return item[1]._is_present_in_parent + else: + return True + + +def _AddListFieldsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ListFields(self): + all_fields = [item for item in self._fields.iteritems() if _IsPresent(item)] + all_fields.sort(key = lambda item: item[0].number) + return all_fields + + cls.ListFields = ListFields + + +def _AddHasFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + singular_fields = {} + for field in message_descriptor.fields: + if field.label != _FieldDescriptor.LABEL_REPEATED: + singular_fields[field.name] = field + # Fields inside oneofs are never repeated (enforced by the compiler). + for field in message_descriptor.oneofs: + singular_fields[field.name] = field + + def HasField(self, field_name): + try: + field = singular_fields[field_name] + except KeyError: + raise ValueError( + 'Protocol message has no singular "%s" field.' % field_name) + + if isinstance(field, descriptor_mod.OneofDescriptor): + try: + return HasField(self, self._oneofs[field].name) + except KeyError: + return False + else: + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(field) + return value is not None and value._is_present_in_parent + else: + return field in self._fields + + cls.HasField = HasField + + +def _AddClearFieldMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def ClearField(self, field_name): + try: + field = message_descriptor.fields_by_name[field_name] + except KeyError: + try: + field = message_descriptor.oneofs_by_name[field_name] + if field in self._oneofs: + field = self._oneofs[field] + else: + return + except KeyError: + raise ValueError('Protocol message has no "%s" field.' % field_name) + + if field in self._fields: + # Note: If the field is a sub-message, its listener will still point + # at us. That's fine, because the worst than can happen is that it + # will call _Modified() and invalidate our byte size. Big deal. + del self._fields[field] + + if self._oneofs.get(field.containing_oneof, None) is field: + del self._oneofs[field.containing_oneof] + + # Always call _Modified() -- even if nothing was changed, this is + # a mutating method, and thus calling it should cause the field to become + # present in the parent message. + self._Modified() + + cls.ClearField = ClearField + + +def _AddClearExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def ClearExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + + # Similar to ClearField(), above. + if extension_handle in self._fields: + del self._fields[extension_handle] + self._Modified() + cls.ClearExtension = ClearExtension + + +def _AddClearMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def Clear(self): + # Clear fields. + self._fields = {} + self._unknown_fields = () + self._Modified() + cls.Clear = Clear + + +def _AddHasExtensionMethod(cls): + """Helper for _AddMessageMethods().""" + def HasExtension(self, extension_handle): + _VerifyExtensionHandle(self, extension_handle) + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + raise KeyError('"%s" is repeated.' % extension_handle.full_name) + + if extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + value = self._fields.get(extension_handle) + return value is not None and value._is_present_in_parent + else: + return extension_handle in self._fields + cls.HasExtension = HasExtension + + +def _AddEqualsMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __eq__(self, other): + if (not isinstance(other, message_mod.Message) or + other.DESCRIPTOR != self.DESCRIPTOR): + return False + + if self is other: + return True + + if not self.ListFields() == other.ListFields(): + return False + + # Sort unknown fields because their order shouldn't affect equality test. + unknown_fields = list(self._unknown_fields) + unknown_fields.sort() + other_unknown_fields = list(other._unknown_fields) + other_unknown_fields.sort() + + return unknown_fields == other_unknown_fields + + cls.__eq__ = __eq__ + + +def _AddStrMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def __str__(self): + return text_format.MessageToString(self) + cls.__str__ = __str__ + + +def _AddUnicodeMethod(unused_message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def __unicode__(self): + return text_format.MessageToString(self, as_utf8=True).decode('utf-8') + cls.__unicode__ = __unicode__ + + +def _AddSetListenerMethod(cls): + """Helper for _AddMessageMethods().""" + def SetListener(self, listener): + if listener is None: + self._listener = message_listener_mod.NullMessageListener() + else: + self._listener = listener + cls._SetListener = SetListener + + +def _BytesForNonRepeatedElement(value, field_number, field_type): + """Returns the number of bytes needed to serialize a non-repeated element. + The returned byte count includes space for tag information and any + other additional space associated with serializing value. + + Args: + value: Value we're serializing. + field_number: Field number of this value. (Since the field number + is stored as part of a varint-encoded tag, this has an impact + on the total bytes required to serialize the value). + field_type: The type of the field. One of the TYPE_* constants + within FieldDescriptor. + """ + try: + fn = type_checkers.TYPE_TO_BYTE_SIZE_FN[field_type] + return fn(field_number, value) + except KeyError: + raise message_mod.EncodeError('Unrecognized field type: %d' % field_type) + + +def _AddByteSizeMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def ByteSize(self): + if not self._cached_byte_size_dirty: + return self._cached_byte_size + + size = 0 + for field_descriptor, field_value in self.ListFields(): + size += field_descriptor._sizer(field_value) + + for tag_bytes, value_bytes in self._unknown_fields: + size += len(tag_bytes) + len(value_bytes) + + self._cached_byte_size = size + self._cached_byte_size_dirty = False + self._listener_for_children.dirty = False + return size + + cls.ByteSize = ByteSize + + +def _AddSerializeToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializeToString(self): + # Check if the message has all of its required fields set. + errors = [] + if not self.IsInitialized(): + raise message_mod.EncodeError( + 'Message %s is missing required fields: %s' % ( + self.DESCRIPTOR.full_name, ','.join(self.FindInitializationErrors()))) + return self.SerializePartialToString() + cls.SerializeToString = SerializeToString + + +def _AddSerializePartialToStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + + def SerializePartialToString(self): + out = BytesIO() + self._InternalSerialize(out.write) + return out.getvalue() + cls.SerializePartialToString = SerializePartialToString + + def InternalSerialize(self, write_bytes): + for field_descriptor, field_value in self.ListFields(): + field_descriptor._encoder(write_bytes, field_value) + for tag_bytes, value_bytes in self._unknown_fields: + write_bytes(tag_bytes) + write_bytes(value_bytes) + cls._InternalSerialize = InternalSerialize + + +def _AddMergeFromStringMethod(message_descriptor, cls): + """Helper for _AddMessageMethods().""" + def MergeFromString(self, serialized): + length = len(serialized) + try: + if self._InternalParse(serialized, 0, length) != length: + # The only reason _InternalParse would return early is if it + # encountered an end-group tag. + raise message_mod.DecodeError('Unexpected end-group tag.') + except (IndexError, TypeError): + # Now ord(buf[p:p+1]) == ord('') gets TypeError. + raise message_mod.DecodeError('Truncated message.') + except struct.error, e: + raise message_mod.DecodeError(e) + return length # Return this for legacy reasons. + cls.MergeFromString = MergeFromString + + local_ReadTag = decoder.ReadTag + local_SkipField = decoder.SkipField + decoders_by_tag = cls._decoders_by_tag + + def InternalParse(self, buffer, pos, end): + self._Modified() + field_dict = self._fields + unknown_field_list = self._unknown_fields + while pos != end: + (tag_bytes, new_pos) = local_ReadTag(buffer, pos) + field_decoder, field_desc = decoders_by_tag.get(tag_bytes, (None, None)) + if field_decoder is None: + value_start_pos = new_pos + new_pos = local_SkipField(buffer, new_pos, end, tag_bytes) + if new_pos == -1: + return pos + if not unknown_field_list: + unknown_field_list = self._unknown_fields = [] + unknown_field_list.append((tag_bytes, buffer[value_start_pos:new_pos])) + pos = new_pos + else: + pos = field_decoder(buffer, new_pos, end, self, field_dict) + if field_desc: + self._UpdateOneofState(field_desc) + return pos + cls._InternalParse = InternalParse + + +def _AddIsInitializedMethod(message_descriptor, cls): + """Adds the IsInitialized and FindInitializationError methods to the + protocol message class.""" + + required_fields = [field for field in message_descriptor.fields + if field.label == _FieldDescriptor.LABEL_REQUIRED] + + def IsInitialized(self, errors=None): + """Checks if all required fields of a message are set. + + Args: + errors: A list which, if provided, will be populated with the field + paths of all missing required fields. + + Returns: + True iff the specified message has all required fields set. + """ + + # Performance is critical so we avoid HasField() and ListFields(). + + for field in required_fields: + if (field not in self._fields or + (field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE and + not self._fields[field]._is_present_in_parent)): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + for field, value in list(self._fields.items()): # dict can change size! + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.label == _FieldDescriptor.LABEL_REPEATED: + for element in value: + if not element.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + elif value._is_present_in_parent and not value.IsInitialized(): + if errors is not None: + errors.extend(self.FindInitializationErrors()) + return False + + return True + + cls.IsInitialized = IsInitialized + + def FindInitializationErrors(self): + """Finds required fields which are not initialized. + + Returns: + A list of strings. Each string is a path to an uninitialized field from + the top-level message, e.g. "foo.bar[5].baz". + """ + + errors = [] # simplify things + + for field in required_fields: + if not self.HasField(field.name): + errors.append(field.name) + + for field, value in self.ListFields(): + if field.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + if field.is_extension: + name = "(%s)" % field.full_name + else: + name = field.name + + if field.label == _FieldDescriptor.LABEL_REPEATED: + for i in xrange(len(value)): + element = value[i] + prefix = "%s[%d]." % (name, i) + sub_errors = element.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + else: + prefix = name + "." + sub_errors = value.FindInitializationErrors() + errors += [ prefix + error for error in sub_errors ] + + return errors + + cls.FindInitializationErrors = FindInitializationErrors + + +def _AddMergeFromMethod(cls): + LABEL_REPEATED = _FieldDescriptor.LABEL_REPEATED + CPPTYPE_MESSAGE = _FieldDescriptor.CPPTYPE_MESSAGE + + def MergeFrom(self, msg): + if not isinstance(msg, cls): + raise TypeError( + "Parameter to MergeFrom() must be instance of same class: " + "expected %s got %s." % (cls.__name__, type(msg).__name__)) + + assert msg is not self + self._Modified() + + fields = self._fields + + for field, value in msg._fields.iteritems(): + if field.label == LABEL_REPEATED: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + elif field.cpp_type == CPPTYPE_MESSAGE: + if value._is_present_in_parent: + field_value = fields.get(field) + if field_value is None: + # Construct a new object to represent this field. + field_value = field._default_constructor(self) + fields[field] = field_value + field_value.MergeFrom(value) + else: + self._fields[field] = value + + if msg._unknown_fields: + if not self._unknown_fields: + self._unknown_fields = [] + self._unknown_fields.extend(msg._unknown_fields) + + cls.MergeFrom = MergeFrom + + +def _AddWhichOneofMethod(message_descriptor, cls): + def WhichOneof(self, oneof_name): + """Returns the name of the currently set field inside a oneof, or None.""" + try: + field = message_descriptor.oneofs_by_name[oneof_name] + except KeyError: + raise ValueError( + 'Protocol message has no oneof "%s" field.' % oneof_name) + + nested_field = self._oneofs.get(field, None) + if nested_field is not None and self.HasField(nested_field.name): + return nested_field.name + else: + return None + + cls.WhichOneof = WhichOneof + + +def _AddMessageMethods(message_descriptor, cls): + """Adds implementations of all Message methods to cls.""" + _AddListFieldsMethod(message_descriptor, cls) + _AddHasFieldMethod(message_descriptor, cls) + _AddClearFieldMethod(message_descriptor, cls) + if message_descriptor.is_extendable: + _AddClearExtensionMethod(cls) + _AddHasExtensionMethod(cls) + _AddClearMethod(message_descriptor, cls) + _AddEqualsMethod(message_descriptor, cls) + _AddStrMethod(message_descriptor, cls) + _AddUnicodeMethod(message_descriptor, cls) + _AddSetListenerMethod(cls) + _AddByteSizeMethod(message_descriptor, cls) + _AddSerializeToStringMethod(message_descriptor, cls) + _AddSerializePartialToStringMethod(message_descriptor, cls) + _AddMergeFromStringMethod(message_descriptor, cls) + _AddIsInitializedMethod(message_descriptor, cls) + _AddMergeFromMethod(cls) + _AddWhichOneofMethod(message_descriptor, cls) + +def _AddPrivateHelperMethods(message_descriptor, cls): + """Adds implementation of private helper methods to cls.""" + + def Modified(self): + """Sets the _cached_byte_size_dirty bit to true, + and propagates this to our listener iff this was a state change. + """ + + # Note: Some callers check _cached_byte_size_dirty before calling + # _Modified() as an extra optimization. So, if this method is ever + # changed such that it does stuff even when _cached_byte_size_dirty is + # already true, the callers need to be updated. + if not self._cached_byte_size_dirty: + self._cached_byte_size_dirty = True + self._listener_for_children.dirty = True + self._is_present_in_parent = True + self._listener.Modified() + + def _UpdateOneofState(self, field): + """Sets field as the active field in its containing oneof. + + Will also delete currently active field in the oneof, if it is different + from the argument. Does not mark the message as modified. + """ + other_field = self._oneofs.setdefault(field.containing_oneof, field) + if other_field is not field: + del self._fields[other_field] + self._oneofs[field.containing_oneof] = field + + cls._Modified = Modified + cls.SetInParent = Modified + cls._UpdateOneofState = _UpdateOneofState + + +class _Listener(object): + + """MessageListener implementation that a parent message registers with its + child message. + + In order to support semantics like: + + foo.bar.baz.qux = 23 + assert foo.HasField('bar') + + ...child objects must have back references to their parents. + This helper class is at the heart of this support. + """ + + def __init__(self, parent_message): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + """ + # This listener establishes a back reference from a child (contained) object + # to its parent (containing) object. We make this a weak reference to avoid + # creating cyclic garbage when the client finishes with the 'parent' object + # in the tree. + if isinstance(parent_message, weakref.ProxyType): + self._parent_message_weakref = parent_message + else: + self._parent_message_weakref = weakref.proxy(parent_message) + + # As an optimization, we also indicate directly on the listener whether + # or not the parent message is dirty. This way we can avoid traversing + # up the tree in the common case. + self.dirty = False + + def Modified(self): + if self.dirty: + return + try: + # Propagate the signal to our parents iff this is the first field set. + self._parent_message_weakref._Modified() + except ReferenceError: + # We can get here if a client has kept a reference to a child object, + # and is now setting a field on it, but the child's parent has been + # garbage-collected. This is not an error. + pass + + +class _OneofListener(_Listener): + """Special listener implementation for setting composite oneof fields.""" + + def __init__(self, parent_message, field): + """Args: + parent_message: The message whose _Modified() method we should call when + we receive Modified() messages. + field: The descriptor of the field being set in the parent message. + """ + super(_OneofListener, self).__init__(parent_message) + self._field = field + + def Modified(self): + """Also updates the state of the containing oneof in the parent message.""" + try: + self._parent_message_weakref._UpdateOneofState(self._field) + super(_OneofListener, self).Modified() + except ReferenceError: + pass + + +# TODO(robinson): Move elsewhere? This file is getting pretty ridiculous... +# TODO(robinson): Unify error handling of "unknown extension" crap. +# TODO(robinson): Support iteritems()-style iteration over all +# extensions with the "has" bits turned on? +class _ExtensionDict(object): + + """Dict-like container for supporting an indexable "Extensions" + field on proto instances. + + Note that in all cases we expect extension handles to be + FieldDescriptors. + """ + + def __init__(self, extended_message): + """extended_message: Message instance for which we are the Extensions dict. + """ + + self._extended_message = extended_message + + def __getitem__(self, extension_handle): + """Returns the current value of the given extension handle.""" + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + result = self._extended_message._fields.get(extension_handle) + if result is not None: + return result + + if extension_handle.label == _FieldDescriptor.LABEL_REPEATED: + result = extension_handle._default_constructor(self._extended_message) + elif extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE: + result = extension_handle.message_type._concrete_class() + try: + result._SetListener(self._extended_message._listener_for_children) + except ReferenceError: + pass + else: + # Singular scalar -- just return the default without inserting into the + # dict. + return extension_handle.default_value + + # Atomically check if another thread has preempted us and, if not, swap + # in the new object we just created. If someone has preempted us, we + # take that object and discard ours. + # WARNING: We are relying on setdefault() being atomic. This is true + # in CPython but we haven't investigated others. This warning appears + # in several other locations in this file. + result = self._extended_message._fields.setdefault( + extension_handle, result) + + return result + + def __eq__(self, other): + if not isinstance(other, self.__class__): + return False + + my_fields = self._extended_message.ListFields() + other_fields = other._extended_message.ListFields() + + # Get rid of non-extension fields. + my_fields = [ field for field in my_fields if field.is_extension ] + other_fields = [ field for field in other_fields if field.is_extension ] + + return my_fields == other_fields + + def __ne__(self, other): + return not self == other + + def __hash__(self): + raise TypeError('unhashable object') + + # Note that this is only meaningful for non-repeated, scalar extension + # fields. Note also that we may have to call _Modified() when we do + # successfully set a field this way, to set any necssary "has" bits in the + # ancestors of the extended message. + def __setitem__(self, extension_handle, value): + """If extension_handle specifies a non-repeated, scalar extension + field, sets the value of that field. + """ + + _VerifyExtensionHandle(self._extended_message, extension_handle) + + if (extension_handle.label == _FieldDescriptor.LABEL_REPEATED or + extension_handle.cpp_type == _FieldDescriptor.CPPTYPE_MESSAGE): + raise TypeError( + 'Cannot assign to extension "%s" because it is a repeated or ' + 'composite type.' % extension_handle.full_name) + + # It's slightly wasteful to lookup the type checker each time, + # but we expect this to be a vanishingly uncommon case anyway. + type_checker = type_checkers.GetTypeChecker( + extension_handle) + # pylint: disable=protected-access + self._extended_message._fields[extension_handle] = ( + type_checker.CheckValue(value)) + self._extended_message._Modified() + + def _FindExtensionByName(self, name): + """Tries to find a known extension with the specified name. + + Args: + name: Extension full name. + + Returns: + Extension field descriptor. + """ + return self._extended_message._extensions_by_name.get(name, None) diff --git a/python/google/protobuf/internal/reflection_test.py b/python/google/protobuf/internal/reflection_test.py index 2c9fa30..d59815d 100755 --- a/python/google/protobuf/internal/reflection_test.py +++ b/python/google/protobuf/internal/reflection_test.py @@ -3,7 +3,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -37,12 +37,12 @@ pure-Python protocol compiler. __author__ = 'robinson@google.com (Will Robinson)' +import copy +import gc import operator import struct -import unittest -# TODO(robinson): When we split this test in two, only some of these imports -# will be necessary in each test. +from google.apputils import basetest from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_mset_pb2 from google.protobuf import unittest_pb2 @@ -50,6 +50,8 @@ from google.protobuf import descriptor_pb2 from google.protobuf import descriptor from google.protobuf import message from google.protobuf import reflection +from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import more_extensions_pb2 from google.protobuf.internal import more_messages_pb2 from google.protobuf.internal import wire_format @@ -102,12 +104,12 @@ class _MiniDecoder(object): return self._pos == len(self._bytes) -class ReflectionTest(unittest.TestCase): +class ReflectionTest(basetest.TestCase): - def assertIs(self, values, others): + def assertListsEqual(self, values, others): self.assertEqual(len(values), len(others)) for i in range(len(values)): - self.assertTrue(values[i] is others[i]) + self.assertEqual(values[i], others[i]) def testScalarConstructor(self): # Constructor with only scalar types should succeed. @@ -200,6 +202,41 @@ class ReflectionTest(unittest.TestCase): unittest_pb2.ForeignMessage(c=12)], list(proto.repeated_foreign_message)) + def testConstructorTypeError(self): + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_int32="foo") + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, optional_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_int32=["foo"]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_string=[1234]) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=1234) + self.assertRaises( + TypeError, unittest_pb2.TestAllTypes, repeated_nested_message=[1234]) + + def testConstructorInvalidatesCachedByteSize(self): + message = unittest_pb2.TestAllTypes(optional_int32 = 12) + self.assertEquals(2, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + optional_nested_message = unittest_pb2.TestAllTypes.NestedMessage()) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes(repeated_int32 = [12]) + self.assertEquals(3, message.ByteSize()) + + message = unittest_pb2.TestAllTypes( + repeated_nested_message = [unittest_pb2.TestAllTypes.NestedMessage()]) + self.assertEquals(3, message.ByteSize()) + def testSimpleHasBits(self): # Test a scalar. proto = unittest_pb2.TestAllTypes() @@ -284,12 +321,6 @@ class ReflectionTest(unittest.TestCase): # ...and ensure that the scalar field has returned to its default. self.assertEqual(0, getattr(composite_field, scalar_field_name)) - # Finally, ensure that modifications to the old composite field object - # don't have any effect on the parent. - # - # (NOTE that when we clear the composite field in the parent, we actually - # don't recursively clear down the tree. Instead, we just disconnect the - # cleared composite from the tree.) self.assertTrue(old_composite_field is not composite_field) setattr(old_composite_field, scalar_field_name, new_val) self.assertTrue(not composite_field.HasField(scalar_field_name)) @@ -319,6 +350,64 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not proto.HasField('optional_nested_message')) self.assertEqual(0, proto.optional_nested_message.bb) + def testGetDefaultMessageAfterDisconnectingDefaultMessage(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.ClearField('optional_nested_message') + del proto + del nested + # Force a garbage collect so that the underlying CMessages are freed along + # with the Messages they point to. This is to make sure we're not deleting + # default message instances. + gc.collect() + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + + def testDisconnectingNestedMessageAfterSettingField(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + self.assertTrue(proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') # Should disconnect from parent + self.assertEqual(5, nested.bb) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + def testDisconnectingNestedMessageBeforeGettingField(self): + proto = unittest_pb2.TestAllTypes() + self.assertTrue(not proto.HasField('optional_nested_message')) + proto.ClearField('optional_nested_message') + self.assertTrue(not proto.HasField('optional_nested_message')) + + def testDisconnectingNestedMessageAfterMerge(self): + # This test exercises the code path that does not use ReleaseMessage(). + # The underlying fear is that if we use ReleaseMessage() incorrectly, + # we will have memory leaks. It's hard to check that that doesn't happen, + # but at least we can exercise that code path to make sure it works. + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllTypes() + proto2.optional_nested_message.bb = 5 + proto1.MergeFrom(proto2) + self.assertTrue(proto1.HasField('optional_nested_message')) + proto1.ClearField('optional_nested_message') + self.assertTrue(not proto1.HasField('optional_nested_message')) + + def testDisconnectingLazyNestedMessage(self): + # This test exercises releasing a nested message that is lazy. This test + # only exercises real code in the C++ implementation as Python does not + # support lazy parsing, but the current C++ implementation results in + # memory corruption and a crash. + if api_implementation.Type() != 'python': + return + proto = unittest_pb2.TestAllTypes() + proto.optional_lazy_message.bb = 5 + proto.ClearField('optional_lazy_message') + del proto + gc.collect() + def testHasBitsWhenModifyingRepeatedFields(self): # Test nesting when we add an element to a repeated field in a submessage. proto = unittest_pb2.TestNestedMessageHasBits() @@ -446,7 +535,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(0.0, proto.optional_double) self.assertEqual(False, proto.optional_bool) self.assertEqual('', proto.optional_string) - self.assertEqual('', proto.optional_bytes) + self.assertEqual(b'', proto.optional_bytes) self.assertEqual(41, proto.default_int32) self.assertEqual(42, proto.default_int64) @@ -462,7 +551,7 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(52e3, proto.default_double) self.assertEqual(True, proto.default_bool) self.assertEqual('hello', proto.default_string) - self.assertEqual('world', proto.default_bytes) + self.assertEqual(b'world', proto.default_bytes) self.assertEqual(unittest_pb2.TestAllTypes.BAR, proto.default_nested_enum) self.assertEqual(unittest_pb2.FOREIGN_BAR, proto.default_foreign_enum) self.assertEqual(unittest_import_pb2.IMPORT_BAR, @@ -479,6 +568,17 @@ class ReflectionTest(unittest.TestCase): proto = unittest_pb2.TestAllTypes() self.assertRaises(ValueError, proto.ClearField, 'nonexistent_field') + def testClearRemovesChildren(self): + # Make sure there aren't any implementation bugs that are only partially + # clearing the message (which can happen in the more complex C++ + # implementation which has parallel message lists). + proto = unittest_pb2.TestRequiredForeign() + for i in range(10): + proto.repeated_message.add() + proto2 = unittest_pb2.TestRequiredForeign() + proto.CopyFrom(proto2) + self.assertRaises(IndexError, lambda: proto.repeated_message[5]) + def testDisallowedAssignments(self): # It's illegal to assign values directly to repeated fields # or to nonrepeated composite fields. Ensure that this fails. @@ -500,7 +600,6 @@ class ReflectionTest(unittest.TestCase): # proto.nonexistent_field = 23 should fail as well. self.assertRaises(AttributeError, setattr, proto, 'nonexistent_field', 23) - # TODO(robinson): Add type-safety check for enums. def testSingleScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() self.assertRaises(TypeError, setattr, proto, 'optional_int32', 1.1) @@ -508,11 +607,37 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(TypeError, setattr, proto, 'optional_string', 10) self.assertRaises(TypeError, setattr, proto, 'optional_bytes', 10) + def testIntegerTypes(self): + def TestGetAndDeserialize(field_name, value, expected_type): + proto = unittest_pb2.TestAllTypes() + setattr(proto, field_name, value) + self.assertTrue(isinstance(getattr(proto, field_name), expected_type)) + proto2 = unittest_pb2.TestAllTypes() + proto2.ParseFromString(proto.SerializeToString()) + self.assertTrue(isinstance(getattr(proto2, field_name), expected_type)) + + TestGetAndDeserialize('optional_int32', 1, int) + TestGetAndDeserialize('optional_int32', 1 << 30, int) + TestGetAndDeserialize('optional_uint32', 1 << 30, int) + if struct.calcsize('L') == 4: + # Python only has signed ints, so 32-bit python can't fit an uint32 + # in an int. + TestGetAndDeserialize('optional_uint32', 1 << 31, long) + else: + # 64-bit python can fit uint32 inside an int + TestGetAndDeserialize('optional_uint32', 1 << 31, int) + TestGetAndDeserialize('optional_int64', 1 << 30, long) + TestGetAndDeserialize('optional_int64', 1 << 60, long) + TestGetAndDeserialize('optional_uint64', 1 << 30, long) + TestGetAndDeserialize('optional_uint64', 1 << 60, long) + def testSingleScalarBoundsChecking(self): def TestMinAndMaxIntegers(field_name, expected_min, expected_max): pb = unittest_pb2.TestAllTypes() setattr(pb, field_name, expected_min) + self.assertEqual(expected_min, getattr(pb, field_name)) setattr(pb, field_name, expected_max) + self.assertEqual(expected_max, getattr(pb, field_name)) self.assertRaises(ValueError, setattr, pb, field_name, expected_min - 1) self.assertRaises(ValueError, setattr, pb, field_name, expected_max + 1) @@ -520,7 +645,10 @@ class ReflectionTest(unittest.TestCase): TestMinAndMaxIntegers('optional_uint32', 0, 0xffffffff) TestMinAndMaxIntegers('optional_int64', -(1 << 63), (1 << 63) - 1) TestMinAndMaxIntegers('optional_uint64', 0, 0xffffffffffffffff) - TestMinAndMaxIntegers('optional_nested_enum', -(1 << 31), (1 << 31) - 1) + + pb = unittest_pb2.TestAllTypes() + pb.optional_nested_enum = 1 + self.assertEqual(1, pb.optional_nested_enum) def testRepeatedScalarTypeSafety(self): proto = unittest_pb2.TestAllTypes() @@ -534,11 +662,19 @@ class ReflectionTest(unittest.TestCase): self.assertRaises(IndexError, proto.repeated_int32.__setitem__, 500, 23) self.assertRaises(TypeError, proto.repeated_int32.__setitem__, 0, 'abc') + # Repeated enums tests. + #proto.repeated_nested_enum.append(0) + def testSingleScalarGettersAndSetters(self): proto = unittest_pb2.TestAllTypes() self.assertEqual(0, proto.optional_int32) proto.optional_int32 = 1 self.assertEqual(1, proto.optional_int32) + + proto.optional_uint64 = 0xffffffffffff + self.assertEqual(0xffffffffffff, proto.optional_uint64) + proto.optional_uint64 = 0xffffffffffffffff + self.assertEqual(0xffffffffffffffff, proto.optional_uint64) # TODO(robinson): Test all other scalar field types. def testSingleScalarClearField(self): @@ -561,6 +697,77 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(3, proto.BAZ) self.assertEqual(3, unittest_pb2.TestAllTypes.BAZ) + def testEnum_Name(self): + self.assertEqual('FOREIGN_FOO', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_FOO)) + self.assertEqual('FOREIGN_BAR', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAR)) + self.assertEqual('FOREIGN_BAZ', + unittest_pb2.ForeignEnum.Name(unittest_pb2.FOREIGN_BAZ)) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Name, 11312) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual('FOO', + proto.NestedEnum.Name(proto.FOO)) + self.assertEqual('FOO', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.FOO)) + self.assertEqual('BAR', + proto.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAR', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAR)) + self.assertEqual('BAZ', + proto.NestedEnum.Name(proto.BAZ)) + self.assertEqual('BAZ', + unittest_pb2.TestAllTypes.NestedEnum.Name(proto.BAZ)) + self.assertRaises(ValueError, + proto.NestedEnum.Name, 11312) + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Name, 11312) + + def testEnum_Value(self): + self.assertEqual(unittest_pb2.FOREIGN_FOO, + unittest_pb2.ForeignEnum.Value('FOREIGN_FOO')) + self.assertEqual(unittest_pb2.FOREIGN_BAR, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAR')) + self.assertEqual(unittest_pb2.FOREIGN_BAZ, + unittest_pb2.ForeignEnum.Value('FOREIGN_BAZ')) + self.assertRaises(ValueError, + unittest_pb2.ForeignEnum.Value, 'FO') + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(proto.FOO, + proto.NestedEnum.Value('FOO')) + self.assertEqual(proto.FOO, + unittest_pb2.TestAllTypes.NestedEnum.Value('FOO')) + self.assertEqual(proto.BAR, + proto.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAR, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAR')) + self.assertEqual(proto.BAZ, + proto.NestedEnum.Value('BAZ')) + self.assertEqual(proto.BAZ, + unittest_pb2.TestAllTypes.NestedEnum.Value('BAZ')) + self.assertRaises(ValueError, + proto.NestedEnum.Value, 'Foo') + self.assertRaises(ValueError, + unittest_pb2.TestAllTypes.NestedEnum.Value, 'Foo') + + def testEnum_KeysAndValues(self): + self.assertEqual(['FOREIGN_FOO', 'FOREIGN_BAR', 'FOREIGN_BAZ'], + unittest_pb2.ForeignEnum.keys()) + self.assertEqual([4, 5, 6], + unittest_pb2.ForeignEnum.values()) + self.assertEqual([('FOREIGN_FOO', 4), ('FOREIGN_BAR', 5), + ('FOREIGN_BAZ', 6)], + unittest_pb2.ForeignEnum.items()) + + proto = unittest_pb2.TestAllTypes() + self.assertEqual(['FOO', 'BAR', 'BAZ', 'NEG'], proto.NestedEnum.keys()) + self.assertEqual([1, 2, 3, -1], proto.NestedEnum.values()) + self.assertEqual([('FOO', 1), ('BAR', 2), ('BAZ', 3), ('NEG', -1)], + proto.NestedEnum.items()) + def testRepeatedScalars(self): proto = unittest_pb2.TestAllTypes() @@ -619,11 +826,38 @@ class ReflectionTest(unittest.TestCase): del proto.repeated_int32[2:] self.assertEqual([5, 35], proto.repeated_int32) + # Test extending. + proto.repeated_int32.extend([3, 13]) + self.assertEqual([5, 35, 3, 13], proto.repeated_int32) + # Test clearing. proto.ClearField('repeated_int32') self.assertTrue(not proto.repeated_int32) self.assertEqual(0, len(proto.repeated_int32)) + proto.repeated_int32.append(1) + self.assertEqual(1, proto.repeated_int32[-1]) + # Test assignment to a negative index. + proto.repeated_int32[-1] = 2 + self.assertEqual(2, proto.repeated_int32[-1]) + + # Test deletion at negative indices. + proto.repeated_int32[:] = [0, 1, 2, 3] + del proto.repeated_int32[-1] + self.assertEqual([0, 1, 2], proto.repeated_int32) + + del proto.repeated_int32[-2] + self.assertEqual([0, 2], proto.repeated_int32) + + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, -3) + self.assertRaises(IndexError, proto.repeated_int32.__delitem__, 300) + + del proto.repeated_int32[-2:-1] + self.assertEqual([2], proto.repeated_int32) + + del proto.repeated_int32[100:10000] + self.assertEqual([2], proto.repeated_int32) + def testRepeatedScalarsRemove(self): proto = unittest_pb2.TestAllTypes() @@ -661,7 +895,7 @@ class ReflectionTest(unittest.TestCase): m1 = proto.repeated_nested_message.add() self.assertTrue(proto.repeated_nested_message) self.assertEqual(2, len(proto.repeated_nested_message)) - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) self.assertTrue(isinstance(m0, unittest_pb2.TestAllTypes.NestedMessage)) # Test out-of-bounds indices. @@ -680,32 +914,86 @@ class ReflectionTest(unittest.TestCase): m2 = proto.repeated_nested_message.add() m3 = proto.repeated_nested_message.add() m4 = proto.repeated_nested_message.add() - self.assertIs([m1, m2, m3], proto.repeated_nested_message[1:4]) - self.assertIs([m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m1, m2, m3], proto.repeated_nested_message[1:4]) + self.assertListsEqual( + [m0, m1, m2, m3, m4], proto.repeated_nested_message[:]) + self.assertListsEqual( + [m0, m1], proto.repeated_nested_message[:2]) + self.assertListsEqual( + [m2, m3, m4], proto.repeated_nested_message[2:]) + self.assertEqual( + m0, proto.repeated_nested_message[0]) + self.assertListsEqual( + [m0], proto.repeated_nested_message[:1]) # Test that we can use the field as an iterator. result = [] for i in proto.repeated_nested_message: result.append(i) - self.assertIs([m0, m1, m2, m3, m4], result) + self.assertListsEqual([m0, m1, m2, m3, m4], result) # Test single deletion. del proto.repeated_nested_message[2] - self.assertIs([m0, m1, m3, m4], proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m3, m4], proto.repeated_nested_message) # Test slice deletion. del proto.repeated_nested_message[2:] - self.assertIs([m0, m1], proto.repeated_nested_message) + self.assertListsEqual([m0, m1], proto.repeated_nested_message) + + # Test extending. + n1 = unittest_pb2.TestAllTypes.NestedMessage(bb=1) + n2 = unittest_pb2.TestAllTypes.NestedMessage(bb=2) + proto.repeated_nested_message.extend([n1,n2]) + self.assertEqual(4, len(proto.repeated_nested_message)) + self.assertEqual(n1, proto.repeated_nested_message[2]) + self.assertEqual(n2, proto.repeated_nested_message[3]) # Test clearing. proto.ClearField('repeated_nested_message') self.assertTrue(not proto.repeated_nested_message) self.assertEqual(0, len(proto.repeated_nested_message)) + # Test constructing an element while adding it. + proto.repeated_nested_message.add(bb=23) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(23, proto.repeated_nested_message[0].bb) + + def testRepeatedCompositeRemove(self): + proto = unittest_pb2.TestAllTypes() + + self.assertEqual(0, len(proto.repeated_nested_message)) + m0 = proto.repeated_nested_message.add() + # Need to set some differentiating variable so m0 != m1 != m2: + m0.bb = len(proto.repeated_nested_message) + m1 = proto.repeated_nested_message.add() + m1.bb = len(proto.repeated_nested_message) + self.assertTrue(m0 != m1) + m2 = proto.repeated_nested_message.add() + m2.bb = len(proto.repeated_nested_message) + self.assertListsEqual([m0, m1, m2], proto.repeated_nested_message) + + self.assertEqual(3, len(proto.repeated_nested_message)) + proto.repeated_nested_message.remove(m0) + self.assertEqual(2, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + self.assertEqual(m2, proto.repeated_nested_message[1]) + + # Removing m0 again or removing None should raise error + self.assertRaises(ValueError, proto.repeated_nested_message.remove, m0) + self.assertRaises(ValueError, proto.repeated_nested_message.remove, None) + self.assertEqual(2, len(proto.repeated_nested_message)) + + proto.repeated_nested_message.remove(m2) + self.assertEqual(1, len(proto.repeated_nested_message)) + self.assertEqual(m1, proto.repeated_nested_message[0]) + def testHandWrittenReflection(self): - # TODO(robinson): We probably need a better way to specify - # protocol types by hand. But then again, this isn't something - # we expect many people to do. Hmm. + # Hand written extensions are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + FieldDescriptor = descriptor.FieldDescriptor foo_field_descriptor = FieldDescriptor( name='foo_field', full_name='MyProto.foo_field', @@ -730,6 +1018,68 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(23, myproto_instance.foo_field) self.assertTrue(myproto_instance.HasField('foo_field')) + def testDescriptorProtoSupport(self): + # Hand written descriptors/reflection are only supported by the pure-Python + # implementation of the API. + if api_implementation.Type() != 'python': + return + + def AddDescriptorField(proto, field_name, field_type): + AddDescriptorField.field_index += 1 + new_field = proto.field.add() + new_field.name = field_name + new_field.type = field_type + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_OPTIONAL + + AddDescriptorField.field_index = 0 + + desc_proto = descriptor_pb2.DescriptorProto() + desc_proto.name = 'Car' + fdp = descriptor_pb2.FieldDescriptorProto + AddDescriptorField(desc_proto, 'name', fdp.TYPE_STRING) + AddDescriptorField(desc_proto, 'year', fdp.TYPE_INT64) + AddDescriptorField(desc_proto, 'automatic', fdp.TYPE_BOOL) + AddDescriptorField(desc_proto, 'price', fdp.TYPE_DOUBLE) + # Add a repeated field + AddDescriptorField.field_index += 1 + new_field = desc_proto.field.add() + new_field.name = 'owners' + new_field.type = fdp.TYPE_STRING + new_field.number = AddDescriptorField.field_index + new_field.label = descriptor_pb2.FieldDescriptorProto.LABEL_REPEATED + + desc = descriptor.MakeDescriptor(desc_proto) + self.assertTrue(desc.fields_by_name.has_key('name')) + self.assertTrue(desc.fields_by_name.has_key('year')) + self.assertTrue(desc.fields_by_name.has_key('automatic')) + self.assertTrue(desc.fields_by_name.has_key('price')) + self.assertTrue(desc.fields_by_name.has_key('owners')) + + class CarMessage(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = desc + + prius = CarMessage() + prius.name = 'prius' + prius.year = 2010 + prius.automatic = True + prius.price = 25134.75 + prius.owners.extend(['bob', 'susan']) + + serialized_prius = prius.SerializeToString() + new_prius = reflection.ParseMessage(desc, serialized_prius) + self.assertTrue(new_prius is not prius) + self.assertEqual(prius, new_prius) + + # these are unnecessary assuming message equality works as advertised but + # explicitly check to be safe since we're mucking about in metaclass foo + self.assertEqual(prius.name, new_prius.name) + self.assertEqual(prius.year, new_prius.year) + self.assertEqual(prius.automatic, new_prius.automatic) + self.assertEqual(prius.price, new_prius.price) + self.assertEqual(prius.owners, new_prius.owners) + def testTopLevelExtensionsForOptionalScalar(self): extendee_proto = unittest_pb2.TestAllExtensions() extension = unittest_pb2.optional_int32_extension @@ -819,6 +1169,14 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(required is not extendee_proto.Extensions[extension]) self.assertTrue(not extendee_proto.HasExtension(extension)) + def testRegisteredExtensions(self): + self.assertTrue('protobuf_unittest.optional_int32_extension' in + unittest_pb2.TestAllExtensions._extensions_by_name) + self.assertTrue(1 in unittest_pb2.TestAllExtensions._extensions_by_number) + # Make sure extensions haven't been registered into types that shouldn't + # have any. + self.assertEquals(0, len(unittest_pb2.TestAllTypes._extensions_by_name)) + # If message A directly contains message B, and # a.HasField('b') is currently False, then mutating any # extension in B should change a.HasField('b') to True @@ -868,7 +1226,7 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(not toplevel.HasField('submessage')) foreign = toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension].add() - self.assertTrue(foreign is toplevel.submessage.Extensions[ + self.assertEqual(foreign, toplevel.submessage.Extensions[ more_extensions_pb2.repeated_message_extension][0]) self.assertTrue(toplevel.HasField('submessage')) @@ -971,6 +1329,12 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(123, proto2.repeated_nested_message[1].bb) self.assertEqual(321, proto2.repeated_nested_message[2].bb) + proto3 = unittest_pb2.TestAllTypes() + proto3.repeated_nested_message.MergeFrom(proto2.repeated_nested_message) + self.assertEqual(999, proto3.repeated_nested_message[0].bb) + self.assertEqual(123, proto3.repeated_nested_message[1].bb) + self.assertEqual(321, proto3.repeated_nested_message[2].bb) + def testMergeFromAllFields(self): # With all fields set. proto1 = unittest_pb2.TestAllTypes() @@ -1035,6 +1399,19 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(222, ext2[1].bb) self.assertEqual(333, ext2[2].bb) + def testMergeFromBug(self): + message1 = unittest_pb2.TestAllTypes() + message2 = unittest_pb2.TestAllTypes() + + # Cause optional_nested_message to be instantiated within message1, even + # though it is not considered to be "present". + message1.optional_nested_message + self.assertFalse(message1.HasField('optional_nested_message')) + + # Merge into message2. This should not instantiate the field is message2. + message2.MergeFrom(message1) + self.assertFalse(message2.HasField('optional_nested_message')) + def testCopyFromSingularField(self): # Test copy with just a singular field. proto1 = unittest_pb2.TestAllTypes() @@ -1087,9 +1464,36 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(2, proto1.optional_int32) self.assertEqual('important-text', proto1.optional_string) + def testCopyFromBadType(self): + # The python implementation doesn't raise an exception in this + # case. In theory it should. + if api_implementation.Type() == 'python': + return + proto1 = unittest_pb2.TestAllTypes() + proto2 = unittest_pb2.TestAllExtensions() + self.assertRaises(TypeError, proto1.CopyFrom, proto2) + + def testDeepCopy(self): + proto1 = unittest_pb2.TestAllTypes() + proto1.optional_int32 = 1 + proto2 = copy.deepcopy(proto1) + self.assertEqual(1, proto2.optional_int32) + + proto1.repeated_int32.append(2) + proto1.repeated_int32.append(3) + container = copy.deepcopy(proto1.repeated_int32) + self.assertEqual([2, 3], container) + + # TODO(anuraag): Implement deepcopy for repeated composite / extension dict + def testClear(self): proto = unittest_pb2.TestAllTypes() - test_util.SetAllFields(proto) + # C++ implementation does not support lazy fields right now so leave it + # out for now. + if api_implementation.Type() == 'python': + test_util.SetAllFields(proto) + else: + test_util.SetAllNonLazyFields(proto) # Clear the message. proto.Clear() self.assertEquals(proto.ByteSize(), 0) @@ -1105,6 +1509,45 @@ class ReflectionTest(unittest.TestCase): empty_proto = unittest_pb2.TestAllExtensions() self.assertEquals(proto, empty_proto) + def testDisconnectingBeforeClear(self): + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + nested.bb = 23 + self.assertTrue(not proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + + proto = unittest_pb2.TestAllTypes() + nested = proto.optional_nested_message + nested.bb = 5 + foreign = proto.optional_foreign_message + foreign.c = 6 + + proto.Clear() + self.assertTrue(nested is not proto.optional_nested_message) + self.assertTrue(foreign is not proto.optional_foreign_message) + self.assertEqual(5, nested.bb) + self.assertEqual(6, foreign.c) + nested.bb = 15 + foreign.c = 16 + self.assertFalse(proto.HasField('optional_nested_message')) + self.assertEqual(0, proto.optional_nested_message.bb) + self.assertFalse(proto.HasField('optional_foreign_message')) + self.assertEqual(0, proto.optional_foreign_message.c) + + def testOneOf(self): + proto = unittest_pb2.TestAllTypes() + proto.oneof_uint32 = 10 + proto.oneof_nested_message.bb = 11 + self.assertEqual(11, proto.oneof_nested_message.bb) + self.assertFalse(proto.HasField('oneof_uint32')) + nested = proto.oneof_nested_message + proto.oneof_string = 'abc' + self.assertEqual('abc', proto.oneof_string) + self.assertEqual(11, nested.bb) + self.assertFalse(proto.HasField('oneof_nested_message')) + def assertInitialized(self, proto): self.assertTrue(proto.IsInitialized()) # Neither method should raise an exception. @@ -1175,6 +1618,40 @@ class ReflectionTest(unittest.TestCase): self.assertFalse(proto.IsInitialized(errors)) self.assertEqual(errors, ['a', 'b', 'c']) + @basetest.unittest.skipIf( + api_implementation.Type() != 'cpp' or api_implementation.Version() != 2, + 'Errors are only available from the most recent C++ implementation.') + def testFileDescriptorErrors(self): + file_name = 'test_file_descriptor_errors.proto' + package_name = 'test_file_descriptor_errors.proto' + file_descriptor_proto = descriptor_pb2.FileDescriptorProto() + file_descriptor_proto.name = file_name + file_descriptor_proto.package = package_name + m1 = file_descriptor_proto.message_type.add() + m1.name = 'msg1' + # Compiles the proto into the C++ descriptor pool + descriptor.FileDescriptor( + file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + # Add a FileDescriptorProto that has duplicate symbols + another_file_name = 'another_test_file_descriptor_errors.proto' + file_descriptor_proto.name = another_file_name + m2 = file_descriptor_proto.message_type.add() + m2.name = 'msg2' + with self.assertRaises(TypeError) as cm: + descriptor.FileDescriptor( + another_file_name, + package_name, + serialized_pb=file_descriptor_proto.SerializeToString()) + self.assertTrue(hasattr(cm, 'exception'), '%s not raised' % + getattr(cm.expected, '__name__', cm.expected)) + self.assertIn('test_file_descriptor_errors.proto', str(cm.exception)) + # Error message will say something about this definition being a + # duplicate, though we don't check the message exactly to avoid a + # dependency on the C++ logging code. + self.assertIn('test_file_descriptor_errors.msg1', str(cm.exception)) + def testStringUTF8Encoding(self): proto = unittest_pb2.TestAllTypes() @@ -1192,16 +1669,15 @@ class ReflectionTest(unittest.TestCase): proto.optional_string = str('Testing') self.assertEqual(proto.optional_string, unicode('Testing')) - # Values of type 'str' are also accepted as long as they can be encoded in - # UTF-8. - self.assertEqual(type(proto.optional_string), str) - # Try to assign a 'str' value which contains bytes that aren't 7-bit ASCII. self.assertRaises(ValueError, - setattr, proto, 'optional_string', str('a\x80a')) - # Assign a 'str' object which contains a UTF-8 encoded string. - self.assertRaises(ValueError, - setattr, proto, 'optional_string', 'Тест') + setattr, proto, 'optional_string', b'a\x80a') + if str is bytes: # PY2 + # Assign a 'str' object which contains a UTF-8 encoded string. + self.assertRaises(ValueError, + setattr, proto, 'optional_string', 'Тест') + else: + proto.optional_string = 'Тест' # No exception thrown. proto.optional_string = 'abc' @@ -1224,7 +1700,8 @@ class ReflectionTest(unittest.TestCase): self.assertEqual(proto.ByteSize(), len(serialized)) raw = unittest_mset_pb2.RawMessageSet() - raw.MergeFromString(serialized) + bytes_read = raw.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_read) message2 = unittest_mset_pb2.TestMessageSetExtension2() @@ -1232,18 +1709,37 @@ class ReflectionTest(unittest.TestCase): # Check that the type_id is the same as the tag ID in the .proto file. self.assertEqual(raw.item[0].type_id, 1547769) - # Check the actually bytes on the wire. + # Check the actual bytes on the wire. self.assertTrue( raw.item[0].message.endswith(test_utf8_bytes)) - message2.MergeFromString(raw.item[0].message) + bytes_read = message2.MergeFromString(raw.item[0].message) + self.assertEqual(len(raw.item[0].message), bytes_read) self.assertEqual(type(message2.str), unicode) self.assertEqual(message2.str, test_utf8) - # How about if the bytes on the wire aren't a valid UTF-8 encoded string. - bytes = raw.item[0].message.replace( - test_utf8_bytes, len(test_utf8_bytes) * '\xff') - self.assertRaises(UnicodeDecodeError, message2.MergeFromString, bytes) + # The pure Python API throws an exception on MergeFromString(), + # if any of the string fields of the message can't be UTF-8 decoded. + # The C++ implementation of the API has no way to check that on + # MergeFromString and thus has no way to throw the exception. + # + # The pure Python API always returns objects of type 'unicode' (UTF-8 + # encoded), or 'bytes' (in 7 bit ASCII). + badbytes = raw.item[0].message.replace( + test_utf8_bytes, len(test_utf8_bytes) * b'\xff') + + unicode_decode_failed = False + try: + message2.MergeFromString(badbytes) + except UnicodeDecodeError: + unicode_decode_failed = True + string_field = message2.str + self.assertTrue(unicode_decode_failed or type(string_field) is bytes) + + def testBytesInTextFormat(self): + proto = unittest_pb2.TestAllTypes(optional_bytes=b'\x00\x7f\x80\xff') + self.assertEqual(u'optional_bytes: "\\000\\177\\200\\377"\n', + unicode(proto)) def testEmptyNestedMessage(self): proto = unittest_pb2.TestAllTypes() @@ -1257,16 +1753,19 @@ class ReflectionTest(unittest.TestCase): self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.MergeFromString('') + bytes_read = proto.optional_nested_message.MergeFromString(b'') + self.assertEqual(0, bytes_read) self.assertTrue(proto.HasField('optional_nested_message')) proto = unittest_pb2.TestAllTypes() - proto.optional_nested_message.ParseFromString('') + proto.optional_nested_message.ParseFromString(b'') self.assertTrue(proto.HasField('optional_nested_message')) serialized = proto.SerializeToString() proto2 = unittest_pb2.TestAllTypes() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertTrue(proto2.HasField('optional_nested_message')) def testSetInParent(self): @@ -1280,12 +1779,15 @@ class ReflectionTest(unittest.TestCase): # into separate TestCase classes. -class TestAllTypesEqualityTest(unittest.TestCase): +class TestAllTypesEqualityTest(basetest.TestCase): def setUp(self): self.first_proto = unittest_pb2.TestAllTypes() self.second_proto = unittest_pb2.TestAllTypes() + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testSelfEquality(self): self.assertEqual(self.first_proto, self.first_proto) @@ -1293,7 +1795,7 @@ class TestAllTypesEqualityTest(unittest.TestCase): self.assertEqual(self.first_proto, self.second_proto) -class FullProtosEqualityTest(unittest.TestCase): +class FullProtosEqualityTest(basetest.TestCase): """Equality tests using completely-full protos as a starting point.""" @@ -1303,6 +1805,9 @@ class FullProtosEqualityTest(unittest.TestCase): test_util.SetAllFields(self.first_proto) test_util.SetAllFields(self.second_proto) + def testNotHashable(self): + self.assertRaises(TypeError, hash, self.first_proto) + def testNoneNotEqual(self): self.assertNotEqual(self.first_proto, None) self.assertNotEqual(None, self.second_proto) @@ -1371,15 +1876,12 @@ class FullProtosEqualityTest(unittest.TestCase): self.first_proto.ClearField('optional_nested_message') self.second_proto.optional_nested_message.ClearField('bb') self.assertNotEqual(self.first_proto, self.second_proto) - # TODO(robinson): Replace next two lines with method - # to set the "has" bit without changing the value, - # if/when such a method exists. self.first_proto.optional_nested_message.bb = 0 self.first_proto.optional_nested_message.ClearField('bb') self.assertEqual(self.first_proto, self.second_proto) -class ExtensionEqualityTest(unittest.TestCase): +class ExtensionEqualityTest(basetest.TestCase): def testExtensionEquality(self): first_proto = unittest_pb2.TestAllExtensions() @@ -1412,7 +1914,7 @@ class ExtensionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class MutualRecursionEqualityTest(unittest.TestCase): +class MutualRecursionEqualityTest(basetest.TestCase): def testEqualityWithMutualRecursion(self): first_proto = unittest_pb2.TestMutualRecursionA() @@ -1424,7 +1926,7 @@ class MutualRecursionEqualityTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) -class ByteSizeTest(unittest.TestCase): +class ByteSizeTest(basetest.TestCase): def setUp(self): self.proto = unittest_pb2.TestAllTypes() @@ -1438,6 +1940,14 @@ class ByteSizeTest(unittest.TestCase): def testEmptyMessage(self): self.assertEqual(0, self.proto.ByteSize()) + def testSizedOnKwargs(self): + # Use a separate message to ensure testing right after creation. + proto = unittest_pb2.TestAllTypes() + self.assertEqual(0, proto.ByteSize()) + proto_kwargs = unittest_pb2.TestAllTypes(optional_int64 = 1) + # One byte for the tag, one to encode varint 1. + self.assertEqual(2, proto_kwargs.ByteSize()) + def testVarints(self): def Test(i, expected_varint_size): self.proto.Clear() @@ -1629,10 +2139,13 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(3, self.proto.ByteSize()) self.proto.ClearField('optional_foreign_message') self.assertEqual(0, self.proto.ByteSize()) - child = self.proto.optional_foreign_message - self.proto.ClearField('optional_foreign_message') - child.c = 128 - self.assertEqual(0, self.proto.ByteSize()) + + if api_implementation.Type() == 'python': + # This is only possible in pure-Python implementation of the API. + child = self.proto.optional_foreign_message + self.proto.ClearField('optional_foreign_message') + child.c = 128 + self.assertEqual(0, self.proto.ByteSize()) # Test within extension. extension = more_extensions_pb2.optional_message_extension @@ -1698,7 +2211,6 @@ class ByteSizeTest(unittest.TestCase): self.assertEqual(19, self.packed_extended_proto.ByteSize()) -# TODO(robinson): We need cross-language serialization consistency tests. # Issues to be sure to cover include: # * Handling of unrecognized tags ("uninterpreted_bytes"). # * Handling of MessageSets. @@ -1710,14 +2222,16 @@ class ByteSizeTest(unittest.TestCase): # * Handling of empty submessages (with and without "has" # bits set). -class SerializationTest(unittest.TestCase): +class SerializationTest(basetest.TestCase): def testSerializeEmtpyMessage(self): first_proto = unittest_pb2.TestAllTypes() second_proto = unittest_pb2.TestAllTypes() serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllFields(self): @@ -1726,7 +2240,9 @@ class SerializationTest(unittest.TestCase): test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() self.assertEqual(first_proto.ByteSize(), len(serialized)) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeAllExtensions(self): @@ -1734,7 +2250,19 @@ class SerializationTest(unittest.TestCase): second_proto = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(first_proto) serialized = first_proto.SerializeToString() - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) + self.assertEqual(first_proto, second_proto) + + def testSerializeWithOptionalGroup(self): + first_proto = unittest_pb2.TestAllTypes() + second_proto = unittest_pb2.TestAllTypes() + first_proto.optionalgroup.a = 242 + serialized = first_proto.SerializeToString() + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual(first_proto, second_proto) def testSerializeNegativeValues(self): @@ -1753,6 +2281,10 @@ class SerializationTest(unittest.TestCase): self.assertEqual(first_proto, second_proto) def testParseTruncated(self): + # This test is only applicable for the Python implementation of the API. + if api_implementation.Type() != 'python': + return + first_proto = unittest_pb2.TestAllTypes() test_util.SetAllFields(first_proto) serialized = first_proto.SerializeToString() @@ -1822,7 +2354,9 @@ class SerializationTest(unittest.TestCase): second_proto.optional_int32 = 100 second_proto.optional_nested_message.bb = 999 - second_proto.MergeFromString(serialized) + bytes_parsed = second_proto.MergeFromString(serialized) + self.assertEqual(len(serialized), bytes_parsed) + # Ensure that we append to repeated fields. self.assertEqual(['baz', 'foobar'], list(second_proto.repeated_string)) # Ensure that we overwrite nonrepeatd scalars. @@ -1847,20 +2381,28 @@ class SerializationTest(unittest.TestCase): raw = unittest_mset_pb2.RawMessageSet() self.assertEqual(False, raw.DESCRIPTOR.GetOptions().message_set_wire_format) - raw.MergeFromString(serialized) + self.assertEqual( + len(serialized), + raw.MergeFromString(serialized)) self.assertEqual(2, len(raw.item)) message1 = unittest_mset_pb2.TestMessageSetExtension1() - message1.MergeFromString(raw.item[0].message) + self.assertEqual( + len(raw.item[0].message), + message1.MergeFromString(raw.item[0].message)) self.assertEqual(123, message1.i) message2 = unittest_mset_pb2.TestMessageSetExtension2() - message2.MergeFromString(raw.item[1].message) + self.assertEqual( + len(raw.item[1].message), + message2.MergeFromString(raw.item[1].message)) self.assertEqual('foo', message2.str) # Deserialize using the MessageSet wire format. proto2 = unittest_mset_pb2.TestMessageSet() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(123, proto2.Extensions[extension1].i) self.assertEqual('foo', proto2.Extensions[extension2].str) @@ -1900,7 +2442,9 @@ class SerializationTest(unittest.TestCase): # Parse message using the message set wire format. proto = unittest_mset_pb2.TestMessageSet() - proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto.MergeFromString(serialized)) # Check that the message parsed well. extension_message1 = unittest_mset_pb2.TestMessageSetExtension1 @@ -1918,7 +2462,9 @@ class SerializationTest(unittest.TestCase): proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) # Now test with a int64 field set. proto = unittest_pb2.TestAllTypes() @@ -1928,13 +2474,15 @@ class SerializationTest(unittest.TestCase): # unknown. proto2 = unittest_pb2.TestEmptyMessage() # Parsing this message should succeed. - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) def _CheckRaises(self, exc_class, callable_obj, exception): """This method checks if the excpetion type and message are as expected.""" try: callable_obj() - except exc_class, ex: + except exc_class as ex: # Check if the exception message is the right one. self.assertEqual(exception, str(ex)) return @@ -1946,15 +2494,22 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: a,b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: ' + 'a,b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() + proto2 = unittest_pb2.TestRequired() + self.assertFalse(proto2.HasField('a')) + # proto2 ParseFromString does not check that required fields are set. + proto2.ParseFromString(partial) + self.assertFalse(proto2.HasField('a')) + proto.a = 1 self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: b,c') + 'Message protobuf_unittest.TestRequired is missing required fields: b,c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1962,7 +2517,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: c') + 'Message protobuf_unittest.TestRequired is missing required fields: c') # Shouldn't raise exceptions. partial = proto.SerializePartialToString() @@ -1972,11 +2527,15 @@ class SerializationTest(unittest.TestCase): partial = proto.SerializePartialToString() proto2 = unittest_pb2.TestRequired() - proto2.MergeFromString(serialized) + self.assertEqual( + len(serialized), + proto2.MergeFromString(serialized)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) - proto2.ParseFromString(partial) + self.assertEqual( + len(partial), + proto2.MergeFromString(partial)) self.assertEqual(1, proto2.a) self.assertEqual(2, proto2.b) self.assertEqual(3, proto2.c) @@ -1991,7 +2550,8 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign ' + 'is missing required fields: ' 'optional_message.b,optional_message.c') proto.optional_message.b = 2 @@ -2003,7 +2563,7 @@ class SerializationTest(unittest.TestCase): self._CheckRaises( message.EncodeError, proto.SerializeToString, - 'Message is missing required fields: ' + 'Message protobuf_unittest.TestRequiredForeign is missing required fields: ' 'repeated_message[0].b,repeated_message[0].c,' 'repeated_message[1].a,repeated_message[1].c') @@ -2043,7 +2603,9 @@ class SerializationTest(unittest.TestCase): second_proto.packed_double.extend([1.0, 2.0]) second_proto.packed_sint32.append(4) - second_proto.MergeFromString(serialized) + self.assertEqual( + len(serialized), + second_proto.MergeFromString(serialized)) self.assertEqual([3, 1, 2], second_proto.packed_int32) self.assertEqual([1.0, 2.0, 3.0], second_proto.packed_double) self.assertEqual([4], second_proto.packed_sint32) @@ -2076,7 +2638,10 @@ class SerializationTest(unittest.TestCase): unpacked = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(unpacked) packed = unittest_pb2.TestPackedTypes() - packed.MergeFromString(unpacked.SerializeToString()) + serialized = unpacked.SerializeToString() + self.assertEqual( + len(serialized), + packed.MergeFromString(serialized)) expected = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(expected) self.assertEqual(expected, packed) @@ -2085,7 +2650,10 @@ class SerializationTest(unittest.TestCase): packed = unittest_pb2.TestPackedTypes() test_util.SetAllPackedFields(packed) unpacked = unittest_pb2.TestUnpackedTypes() - unpacked.MergeFromString(packed.SerializeToString()) + serialized = packed.SerializeToString() + self.assertEqual( + len(serialized), + unpacked.MergeFromString(serialized)) expected = unittest_pb2.TestUnpackedTypes() test_util.SetAllUnpackedFields(expected) self.assertEqual(expected, unpacked) @@ -2137,7 +2705,7 @@ class SerializationTest(unittest.TestCase): optional_int32=1, optional_string='foo', optional_bool=True, - optional_bytes='bar', + optional_bytes=b'bar', optional_nested_message=unittest_pb2.TestAllTypes.NestedMessage(bb=1), optional_foreign_message=unittest_pb2.ForeignMessage(c=1), optional_nested_enum=unittest_pb2.TestAllTypes.FOO, @@ -2155,7 +2723,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(1, proto.optional_int32) self.assertEqual('foo', proto.optional_string) self.assertEqual(True, proto.optional_bool) - self.assertEqual('bar', proto.optional_bytes) + self.assertEqual(b'bar', proto.optional_bytes) self.assertEqual(1, proto.optional_nested_message.bb) self.assertEqual(1, proto.optional_foreign_message.c) self.assertEqual(unittest_pb2.TestAllTypes.FOO, @@ -2205,7 +2773,7 @@ class SerializationTest(unittest.TestCase): self.assertEqual(3, proto.repeated_int32[2]) -class OptionsTest(unittest.TestCase): +class OptionsTest(basetest.TestCase): def testMessageOptions(self): proto = unittest_mset_pb2.TestMessageSet() @@ -2232,5 +2800,135 @@ class OptionsTest(unittest.TestCase): +class ClassAPITest(basetest.TestCase): + + def testMakeClassWithNestedDescriptor(self): + leaf_desc = descriptor.Descriptor('leaf', 'package.parent.child.leaf', '', + containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + child_desc = descriptor.Descriptor('child', 'package.parent.child', '', + containing_type=None, fields=[], + nested_types=[leaf_desc], enum_types=[], + extensions=[]) + sibling_desc = descriptor.Descriptor('sibling', 'package.parent.sibling', + '', containing_type=None, fields=[], + nested_types=[], enum_types=[], + extensions=[]) + parent_desc = descriptor.Descriptor('parent', 'package.parent', '', + containing_type=None, fields=[], + nested_types=[child_desc, sibling_desc], + enum_types=[], extensions=[]) + message_class = reflection.MakeClass(parent_desc) + self.assertIn('child', message_class.__dict__) + self.assertIn('sibling', message_class.__dict__) + self.assertIn('leaf', message_class.child.__dict__) + + def _GetSerializedFileDescriptor(self, name): + """Get a serialized representation of a test FileDescriptorProto. + + Args: + name: All calls to this must use a unique message name, to avoid + collisions in the cpp descriptor pool. + Returns: + A string containing the serialized form of a test FileDescriptorProto. + """ + file_descriptor_str = ( + 'message_type {' + ' name: "' + name + '"' + ' field {' + ' name: "flat"' + ' number: 1' + ' label: LABEL_REPEATED' + ' type: TYPE_UINT32' + ' }' + ' field {' + ' name: "bar"' + ' number: 2' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Bar"' + ' }' + ' nested_type {' + ' name: "Bar"' + ' field {' + ' name: "baz"' + ' number: 3' + ' label: LABEL_OPTIONAL' + ' type: TYPE_MESSAGE' + ' type_name: "Baz"' + ' }' + ' nested_type {' + ' name: "Baz"' + ' enum_type {' + ' name: "deep_enum"' + ' value {' + ' name: "VALUE_A"' + ' number: 0' + ' }' + ' }' + ' field {' + ' name: "deep"' + ' number: 4' + ' label: LABEL_OPTIONAL' + ' type: TYPE_UINT32' + ' }' + ' }' + ' }' + '}') + file_descriptor = descriptor_pb2.FileDescriptorProto() + text_format.Merge(file_descriptor_str, file_descriptor) + return file_descriptor.SerializeToString() + + def testParsingFlatClassWithExplicitClassDeclaration(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('A')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + + class MessageClass(message.Message): + __metaclass__ = reflection.GeneratedProtocolMessageType + DESCRIPTOR = msg_descriptor + msg = MessageClass() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingFlatClass(self): + """Test that the generated class can parse a flat message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('B')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'flat: 0 ' + 'flat: 1 ' + 'flat: 2 ') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.flat, [0, 1, 2]) + + def testParsingNestedClass(self): + """Test that the generated class can parse a nested message.""" + file_descriptor = descriptor_pb2.FileDescriptorProto() + file_descriptor.ParseFromString(self._GetSerializedFileDescriptor('C')) + msg_descriptor = descriptor.MakeDescriptor( + file_descriptor.message_type[0]) + msg_class = reflection.MakeClass(msg_descriptor) + msg = msg_class() + msg_str = ( + 'bar {' + ' baz {' + ' deep: 4' + ' }' + '}') + text_format.Merge(msg_str, msg) + self.assertEqual(msg.bar.baz.deep, 4) + if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/service_reflection_test.py b/python/google/protobuf/internal/service_reflection_test.py index e04f825..07dcf44 100755 --- a/python/google/protobuf/internal/service_reflection_test.py +++ b/python/google/protobuf/internal/service_reflection_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -34,13 +34,13 @@ __author__ = 'petar@google.com (Petar Petrov)' -import unittest +from google.apputils import basetest from google.protobuf import unittest_pb2 from google.protobuf import service_reflection from google.protobuf import service -class FooUnitTest(unittest.TestCase): +class FooUnitTest(basetest.TestCase): def testService(self): class MockRpcChannel(service.RpcChannel): @@ -133,4 +133,4 @@ class FooUnitTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/symbol_database_test.py b/python/google/protobuf/internal/symbol_database_test.py new file mode 100644 index 0000000..47572d5 --- /dev/null +++ b/python/google/protobuf/internal/symbol_database_test.py @@ -0,0 +1,120 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.symbol_database.""" + +from google.apputils import basetest +from google.protobuf import unittest_pb2 +from google.protobuf import symbol_database + + +class SymbolDatabaseTest(basetest.TestCase): + + def _Database(self): + db = symbol_database.SymbolDatabase() + # Register representative types from unittest_pb2. + db.RegisterFileDescriptor(unittest_pb2.DESCRIPTOR) + db.RegisterMessage(unittest_pb2.TestAllTypes) + db.RegisterMessage(unittest_pb2.TestAllTypes.NestedMessage) + db.RegisterMessage(unittest_pb2.TestAllTypes.OptionalGroup) + db.RegisterMessage(unittest_pb2.TestAllTypes.RepeatedGroup) + db.RegisterEnumDescriptor(unittest_pb2.ForeignEnum.DESCRIPTOR) + db.RegisterEnumDescriptor(unittest_pb2.TestAllTypes.NestedEnum.DESCRIPTOR) + return db + + def testGetPrototype(self): + instance = self._Database().GetPrototype( + unittest_pb2.TestAllTypes.DESCRIPTOR) + self.assertTrue(instance is unittest_pb2.TestAllTypes) + + def testGetMessages(self): + messages = self._Database().GetMessages( + ['google/protobuf/unittest.proto']) + self.assertTrue( + unittest_pb2.TestAllTypes is + messages['protobuf_unittest.TestAllTypes']) + + def testGetSymbol(self): + self.assertEquals( + unittest_pb2.TestAllTypes, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes')) + self.assertEquals( + unittest_pb2.TestAllTypes.NestedMessage, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.NestedMessage')) + self.assertEquals( + unittest_pb2.TestAllTypes.OptionalGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.OptionalGroup')) + self.assertEquals( + unittest_pb2.TestAllTypes.RepeatedGroup, self._Database().GetSymbol( + 'protobuf_unittest.TestAllTypes.RepeatedGroup')) + + def testEnums(self): + # Check registration of types in the pool. + self.assertEquals( + 'protobuf_unittest.ForeignEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.ForeignEnum').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedEnum', + self._Database().pool.FindEnumTypeByName( + 'protobuf_unittest.TestAllTypes.NestedEnum').full_name) + + def testFindMessageTypeByName(self): + self.assertEquals( + 'protobuf_unittest.TestAllTypes', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes').full_name) + self.assertEquals( + 'protobuf_unittest.TestAllTypes.NestedMessage', + self._Database().pool.FindMessageTypeByName( + 'protobuf_unittest.TestAllTypes.NestedMessage').full_name) + + def testFindFindContainingSymbol(self): + # Lookup based on either enum or message. + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes.NestedEnum').name) + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileContainingSymbol( + 'protobuf_unittest.TestAllTypes').name) + + def testFindFileByName(self): + self.assertEquals( + 'google/protobuf/unittest.proto', + self._Database().pool.FindFileByName( + 'google/protobuf/unittest.proto').name) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/test_bad_identifiers.proto b/python/google/protobuf/internal/test_bad_identifiers.proto new file mode 100644 index 0000000..9eb18cb --- /dev/null +++ b/python/google/protobuf/internal/test_bad_identifiers.proto @@ -0,0 +1,52 @@ +// Protocol Buffers - Google's data interchange format +// Copyright 2008 Google Inc. All rights reserved. +// https://developers.google.com/protocol-buffers/ +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are +// met: +// +// * Redistributions of source code must retain the above copyright +// notice, this list of conditions and the following disclaimer. +// * Redistributions in binary form must reproduce the above +// copyright notice, this list of conditions and the following disclaimer +// in the documentation and/or other materials provided with the +// distribution. +// * Neither the name of Google Inc. nor the names of its +// contributors may be used to endorse or promote products derived from +// this software without specific prior written permission. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +// A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +// OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +// SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +// LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +// DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +// THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +// OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Author: kenton@google.com (Kenton Varda) + + +package protobuf_unittest; + +option py_generic_services = true; + +message TestBadIdentifiers { + extensions 100 to max; +} + +// Make sure these reasonable extension names don't conflict with internal +// variables. +extend TestBadIdentifiers { + optional string message = 100 [default="foo"]; + optional string descriptor = 101 [default="bar"]; + optional string reflection = 102 [default="baz"]; + optional string service = 103 [default="qux"]; +} + +message AnotherMessage {} +service AnotherService {} diff --git a/python/google/protobuf/internal/test_util.py b/python/google/protobuf/internal/test_util.py index 1df1619..787f465 100755 --- a/python/google/protobuf/internal/test_util.py +++ b/python/google/protobuf/internal/test_util.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -42,8 +42,8 @@ from google.protobuf import unittest_import_pb2 from google.protobuf import unittest_pb2 -def SetAllFields(message): - """Sets every field in the message to a unique value. +def SetAllNonLazyFields(message): + """Sets every non-lazy field in the message to a unique value. Args: message: A unittest_pb2.TestAllTypes instance. @@ -66,26 +66,21 @@ def SetAllFields(message): message.optional_float = 111 message.optional_double = 112 message.optional_bool = True - # TODO(robinson): Firmly spec out and test how - # protos interact with unicode. One specific example: - # what happens if we change the literal below to - # u'115'? What *should* happen? Still some discussion - # to finish with Kenton about bytes vs. strings - # and forcing everything to be utf8. :-/ - message.optional_string = '115' - message.optional_bytes = '116' + message.optional_string = u'115' + message.optional_bytes = b'116' message.optionalgroup.a = 117 message.optional_nested_message.bb = 118 message.optional_foreign_message.c = 119 message.optional_import_message.d = 120 + message.optional_public_import_message.e = 126 message.optional_nested_enum = unittest_pb2.TestAllTypes.BAZ message.optional_foreign_enum = unittest_pb2.FOREIGN_BAZ message.optional_import_enum = unittest_import_pb2.IMPORT_BAZ - message.optional_string_piece = '124' - message.optional_cord = '125' + message.optional_string_piece = u'124' + message.optional_cord = u'125' # # Repeated fields. @@ -104,20 +99,21 @@ def SetAllFields(message): message.repeated_float.append(211) message.repeated_double.append(212) message.repeated_bool.append(True) - message.repeated_string.append('215') - message.repeated_bytes.append('216') + message.repeated_string.append(u'215') + message.repeated_bytes.append(b'216') message.repeatedgroup.add().a = 217 message.repeated_nested_message.add().bb = 218 message.repeated_foreign_message.add().c = 219 message.repeated_import_message.add().d = 220 + message.repeated_lazy_message.add().bb = 227 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAR) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAR) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAR) - message.repeated_string_piece.append('224') - message.repeated_cord.append('225') + message.repeated_string_piece.append(u'224') + message.repeated_cord.append(u'225') # Add a second one of each field. message.repeated_int32.append(301) @@ -133,20 +129,21 @@ def SetAllFields(message): message.repeated_float.append(311) message.repeated_double.append(312) message.repeated_bool.append(False) - message.repeated_string.append('315') - message.repeated_bytes.append('316') + message.repeated_string.append(u'315') + message.repeated_bytes.append(b'316') message.repeatedgroup.add().a = 317 message.repeated_nested_message.add().bb = 318 message.repeated_foreign_message.add().c = 319 message.repeated_import_message.add().d = 320 + message.repeated_lazy_message.add().bb = 327 message.repeated_nested_enum.append(unittest_pb2.TestAllTypes.BAZ) message.repeated_foreign_enum.append(unittest_pb2.FOREIGN_BAZ) message.repeated_import_enum.append(unittest_import_pb2.IMPORT_BAZ) - message.repeated_string_piece.append('324') - message.repeated_cord.append('325') + message.repeated_string_piece.append(u'324') + message.repeated_cord.append(u'325') # # Fields that have defaults. @@ -166,7 +163,7 @@ def SetAllFields(message): message.default_double = 412 message.default_bool = False message.default_string = '415' - message.default_bytes = '416' + message.default_bytes = b'416' message.default_nested_enum = unittest_pb2.TestAllTypes.FOO message.default_foreign_enum = unittest_pb2.FOREIGN_FOO @@ -175,6 +172,16 @@ def SetAllFields(message): message.default_string_piece = '424' message.default_cord = '425' + message.oneof_uint32 = 601 + message.oneof_nested_message.bb = 602 + message.oneof_string = '603' + message.oneof_bytes = b'604' + + +def SetAllFields(message): + SetAllNonLazyFields(message) + message.optional_lazy_message.bb = 127 + def SetAllExtensions(message): """Sets every extension in the message to a unique value. @@ -204,21 +211,23 @@ def SetAllExtensions(message): extensions[pb2.optional_float_extension] = 111 extensions[pb2.optional_double_extension] = 112 extensions[pb2.optional_bool_extension] = True - extensions[pb2.optional_string_extension] = '115' - extensions[pb2.optional_bytes_extension] = '116' + extensions[pb2.optional_string_extension] = u'115' + extensions[pb2.optional_bytes_extension] = b'116' extensions[pb2.optionalgroup_extension].a = 117 extensions[pb2.optional_nested_message_extension].bb = 118 extensions[pb2.optional_foreign_message_extension].c = 119 extensions[pb2.optional_import_message_extension].d = 120 + extensions[pb2.optional_public_import_message_extension].e = 126 + extensions[pb2.optional_lazy_message_extension].bb = 127 extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_nested_enum_extension] = pb2.TestAllTypes.BAZ extensions[pb2.optional_foreign_enum_extension] = pb2.FOREIGN_BAZ extensions[pb2.optional_import_enum_extension] = import_pb2.IMPORT_BAZ - extensions[pb2.optional_string_piece_extension] = '124' - extensions[pb2.optional_cord_extension] = '125' + extensions[pb2.optional_string_piece_extension] = u'124' + extensions[pb2.optional_cord_extension] = u'125' # # Repeated fields. @@ -237,20 +246,21 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(211) extensions[pb2.repeated_double_extension].append(212) extensions[pb2.repeated_bool_extension].append(True) - extensions[pb2.repeated_string_extension].append('215') - extensions[pb2.repeated_bytes_extension].append('216') + extensions[pb2.repeated_string_extension].append(u'215') + extensions[pb2.repeated_bytes_extension].append(b'216') extensions[pb2.repeatedgroup_extension].add().a = 217 extensions[pb2.repeated_nested_message_extension].add().bb = 218 extensions[pb2.repeated_foreign_message_extension].add().c = 219 extensions[pb2.repeated_import_message_extension].add().d = 220 + extensions[pb2.repeated_lazy_message_extension].add().bb = 227 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAR) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAR) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAR) - extensions[pb2.repeated_string_piece_extension].append('224') - extensions[pb2.repeated_cord_extension].append('225') + extensions[pb2.repeated_string_piece_extension].append(u'224') + extensions[pb2.repeated_cord_extension].append(u'225') # Append a second one of each field. extensions[pb2.repeated_int32_extension].append(301) @@ -266,20 +276,21 @@ def SetAllExtensions(message): extensions[pb2.repeated_float_extension].append(311) extensions[pb2.repeated_double_extension].append(312) extensions[pb2.repeated_bool_extension].append(False) - extensions[pb2.repeated_string_extension].append('315') - extensions[pb2.repeated_bytes_extension].append('316') + extensions[pb2.repeated_string_extension].append(u'315') + extensions[pb2.repeated_bytes_extension].append(b'316') extensions[pb2.repeatedgroup_extension].add().a = 317 extensions[pb2.repeated_nested_message_extension].add().bb = 318 extensions[pb2.repeated_foreign_message_extension].add().c = 319 extensions[pb2.repeated_import_message_extension].add().d = 320 + extensions[pb2.repeated_lazy_message_extension].add().bb = 327 extensions[pb2.repeated_nested_enum_extension].append(pb2.TestAllTypes.BAZ) extensions[pb2.repeated_foreign_enum_extension].append(pb2.FOREIGN_BAZ) extensions[pb2.repeated_import_enum_extension].append(import_pb2.IMPORT_BAZ) - extensions[pb2.repeated_string_piece_extension].append('324') - extensions[pb2.repeated_cord_extension].append('325') + extensions[pb2.repeated_string_piece_extension].append(u'324') + extensions[pb2.repeated_cord_extension].append(u'325') # # Fields with defaults. @@ -298,16 +309,21 @@ def SetAllExtensions(message): extensions[pb2.default_float_extension] = 411 extensions[pb2.default_double_extension] = 412 extensions[pb2.default_bool_extension] = False - extensions[pb2.default_string_extension] = '415' - extensions[pb2.default_bytes_extension] = '416' + extensions[pb2.default_string_extension] = u'415' + extensions[pb2.default_bytes_extension] = b'416' extensions[pb2.default_nested_enum_extension] = pb2.TestAllTypes.FOO extensions[pb2.default_foreign_enum_extension] = pb2.FOREIGN_FOO extensions[pb2.default_import_enum_extension] = import_pb2.IMPORT_FOO - extensions[pb2.default_string_piece_extension] = '424' + extensions[pb2.default_string_piece_extension] = u'424' extensions[pb2.default_cord_extension] = '425' + extensions[pb2.oneof_uint32_extension] = 601 + extensions[pb2.oneof_nested_message_extension].bb = 602 + extensions[pb2.oneof_string_extension] = u'603' + extensions[pb2.oneof_bytes_extension] = b'604' + def SetAllFieldsAndExtensions(message): """Sets every field and extension in the message to a unique value. @@ -346,7 +362,7 @@ def ExpectAllFieldsAndExtensionsInOrder(serialized): message.my_float = 1.0 expected_strings.append(message.SerializeToString()) message.Clear() - expected = ''.join(expected_strings) + expected = b''.join(expected_strings) if expected != serialized: raise ValueError('Expected %r, found %r' % (expected, serialized)) @@ -401,12 +417,14 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(112, message.optional_double) test_case.assertEqual(True, message.optional_bool) test_case.assertEqual('115', message.optional_string) - test_case.assertEqual('116', message.optional_bytes) + test_case.assertEqual(b'116', message.optional_bytes) test_case.assertEqual(117, message.optionalgroup.a) test_case.assertEqual(118, message.optional_nested_message.bb) test_case.assertEqual(119, message.optional_foreign_message.c) test_case.assertEqual(120, message.optional_import_message.d) + test_case.assertEqual(126, message.optional_public_import_message.e) + test_case.assertEqual(127, message.optional_lazy_message.bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.optional_nested_enum) @@ -458,12 +476,13 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(212, message.repeated_double[0]) test_case.assertEqual(True, message.repeated_bool[0]) test_case.assertEqual('215', message.repeated_string[0]) - test_case.assertEqual('216', message.repeated_bytes[0]) + test_case.assertEqual(b'216', message.repeated_bytes[0]) test_case.assertEqual(217, message.repeatedgroup[0].a) test_case.assertEqual(218, message.repeated_nested_message[0].bb) test_case.assertEqual(219, message.repeated_foreign_message[0].c) test_case.assertEqual(220, message.repeated_import_message[0].d) + test_case.assertEqual(227, message.repeated_lazy_message[0].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAR, message.repeated_nested_enum[0]) @@ -486,12 +505,13 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(312, message.repeated_double[1]) test_case.assertEqual(False, message.repeated_bool[1]) test_case.assertEqual('315', message.repeated_string[1]) - test_case.assertEqual('316', message.repeated_bytes[1]) + test_case.assertEqual(b'316', message.repeated_bytes[1]) test_case.assertEqual(317, message.repeatedgroup[1].a) test_case.assertEqual(318, message.repeated_nested_message[1].bb) test_case.assertEqual(319, message.repeated_foreign_message[1].c) test_case.assertEqual(320, message.repeated_import_message[1].d) + test_case.assertEqual(327, message.repeated_lazy_message[1].bb) test_case.assertEqual(unittest_pb2.TestAllTypes.BAZ, message.repeated_nested_enum[1]) @@ -536,7 +556,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(412, message.default_double) test_case.assertEqual(False, message.default_bool) test_case.assertEqual('415', message.default_string) - test_case.assertEqual('416', message.default_bytes) + test_case.assertEqual(b'416', message.default_bytes) test_case.assertEqual(unittest_pb2.TestAllTypes.FOO, message.default_nested_enum) @@ -545,6 +565,7 @@ def ExpectAllFieldsSet(test_case, message): test_case.assertEqual(unittest_import_pb2.IMPORT_FOO, message.default_import_enum) + def GoldenFile(filename): """Finds the given golden file and returns a file object representing it.""" @@ -558,9 +579,15 @@ def GoldenFile(filename): path = os.path.join(path, '..') raise RuntimeError( - 'Could not find golden files. This test must be run from within the ' - 'protobuf source package so that it can read test data files from the ' - 'C++ source tree.') + 'Could not find golden files. This test must be run from within the ' + 'protobuf source package so that it can read test data files from the ' + 'C++ source tree.') + + +def GoldenFileData(filename): + """Finds the given golden file and returns its contents.""" + with GoldenFile(filename) as f: + return f.read() def SetAllPackedFields(message): diff --git a/python/google/protobuf/internal/text_encoding_test.py b/python/google/protobuf/internal/text_encoding_test.py new file mode 100755 index 0000000..db0222b --- /dev/null +++ b/python/google/protobuf/internal/text_encoding_test.py @@ -0,0 +1,68 @@ +#! /usr/bin/python +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Tests for google.protobuf.text_encoding.""" + +from google.apputils import basetest +from google.protobuf import text_encoding + +TEST_VALUES = [ + ("foo\\rbar\\nbaz\\t", + "foo\\rbar\\nbaz\\t", + b"foo\rbar\nbaz\t"), + ("\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + "\\'full of \\\"sound\\\" and \\\"fury\\\"\\'", + b"'full of \"sound\" and \"fury\"'"), + ("signi\\\\fying\\\\ nothing\\\\", + "signi\\\\fying\\\\ nothing\\\\", + b"signi\\fying\\ nothing\\"), + ("\\010\\t\\n\\013\\014\\r", + "\x08\\t\\n\x0b\x0c\\r", + b"\010\011\012\013\014\015")] + + +class TextEncodingTestCase(basetest.TestCase): + def testCEscape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(escaped, + text_encoding.CEscape(unescaped, as_utf8=False)) + self.assertEquals(escaped_utf8, + text_encoding.CEscape(unescaped, as_utf8=True)) + + def testCUnescape(self): + for escaped, escaped_utf8, unescaped in TEST_VALUES: + self.assertEquals(unescaped, text_encoding.CUnescape(escaped)) + self.assertEquals(unescaped, text_encoding.CUnescape(escaped_utf8)) + + +if __name__ == "__main__": + basetest.main() diff --git a/python/google/protobuf/internal/text_format_test.py b/python/google/protobuf/internal/text_format_test.py index e0991cb..b0a3a5f 100755 --- a/python/google/protobuf/internal/text_format_test.py +++ b/python/google/protobuf/internal/text_format_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -34,48 +34,71 @@ __author__ = 'kenton@google.com (Kenton Varda)' -import difflib +import re -import unittest +from google.apputils import basetest from google.protobuf import text_format +from google.protobuf.internal import api_implementation from google.protobuf.internal import test_util from google.protobuf import unittest_pb2 from google.protobuf import unittest_mset_pb2 +class TextFormatTest(basetest.TestCase): -class TextFormatTest(unittest.TestCase): def ReadGolden(self, golden_filename): - f = test_util.GoldenFile(golden_filename) - golden_lines = f.readlines() - f.close() - return golden_lines + with test_util.GoldenFile(golden_filename) as f: + return (f.readlines() if str is bytes else # PY3 + [golden_line.decode('utf-8') for golden_line in f]) def CompareToGoldenFile(self, text, golden_filename): golden_lines = self.ReadGolden(golden_filename) - self.CompareToGoldenLines(text, golden_lines) + self.assertMultiLineEqual(text, ''.join(golden_lines)) def CompareToGoldenText(self, text, golden_text): - self.CompareToGoldenLines(text, golden_text.splitlines(1)) - - def CompareToGoldenLines(self, text, golden_lines): - actual_lines = text.splitlines(1) - self.assertEqual(golden_lines, actual_lines, - "Text doesn't match golden. Diff:\n" + - ''.join(difflib.ndiff(golden_lines, actual_lines))) + self.assertMultiLineEqual(text, golden_text) def testPrintAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_data_oneof_implemented.txt') + + def testPrintInIndexOrder(self): + message = unittest_pb2.TestFieldOrderings() + message.my_string = '115' + message.my_int = 101 + message.my_float = 111 + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message, use_index_order=True)), + 'my_string: \"115\"\nmy_int: 101\nmy_float: 111\n') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString( + message)), 'my_int: 101\nmy_string: \"115\"\nmy_float: 111\n') def testPrintAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.CompareToGoldenFile( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'text_format_unittest_extensions_data.txt') + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'text_format_unittest_extensions_data.txt') + + def testPrintAllFieldsPointy(self): + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros( + text_format.MessageToString(message, pointy_brackets=True)), + 'text_format_unittest_data_pointy_oneof.txt') + + def testPrintAllExtensionsPointy(self): + message = unittest_pb2.TestAllExtensions() + test_util.SetAllExtensions(message) + self.CompareToGoldenFile( + self.RemoveRedundantZeros(text_format.MessageToString( + message, pointy_brackets=True)), + 'text_format_unittest_extensions_data_pointy.txt') def testPrintMessageSet(self): message = unittest_mset_pb2.TestMessageSetContainer() @@ -83,33 +106,179 @@ class TextFormatTest(unittest.TestCase): ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension message.message_set.Extensions[ext1].i = 23 message.message_set.Extensions[ext2].str = 'foo' - self.CompareToGoldenText(text_format.MessageToString(message), - 'message_set {\n' - ' [protobuf_unittest.TestMessageSetExtension1] {\n' - ' i: 23\n' - ' }\n' - ' [protobuf_unittest.TestMessageSetExtension2] {\n' - ' str: \"foo\"\n' - ' }\n' - '}\n') + self.CompareToGoldenText( + text_format.MessageToString(message), + 'message_set {\n' + ' [protobuf_unittest.TestMessageSetExtension1] {\n' + ' i: 23\n' + ' }\n' + ' [protobuf_unittest.TestMessageSetExtension2] {\n' + ' str: \"foo\"\n' + ' }\n' + '}\n') def testPrintExotic(self): message = unittest_pb2.TestAllTypes() - message.repeated_int64.append(-9223372036854775808); - message.repeated_uint64.append(18446744073709551615); - message.repeated_double.append(123.456); - message.repeated_double.append(1.23e22); - message.repeated_double.append(1.23e-18); - message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'\"'); + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_format.MessageToString(message)), + 'repeated_int64: -9223372036854775808\n' + 'repeated_uint64: 18446744073709551615\n' + 'repeated_double: 123.456\n' + 'repeated_double: 1.23e+22\n' + 'repeated_double: 1.23e-18\n' + 'repeated_string:' + ' "\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintExoticUnicodeSubclass(self): + class UnicodeSub(unicode): + pass + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(UnicodeSub(u'\u00fc\ua71f')) + self.CompareToGoldenText( + text_format.MessageToString(message), + 'repeated_string: "\\303\\274\\352\\234\\237"\n') + + def testPrintNestedMessageAsOneLine(self): + message = unittest_pb2.TestAllTypes() + msg = message.repeated_nested_message.add() + msg.bb = 42 + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'repeated_nested_message { bb: 42 }') + + def testPrintRepeatedFieldsAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int32.append(1) + message.repeated_int32.append(1) + message.repeated_int32.append(3) + message.repeated_string.append("Google") + message.repeated_string.append("Zurich") + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'repeated_int32: 1 repeated_int32: 1 repeated_int32: 3 ' + 'repeated_string: "Google" repeated_string: "Zurich"') + + def testPrintNestedNewLineInStringAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.optional_string = "a\nnew\nline" + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'optional_string: "a\\nnew\\nline"') + + def testPrintMessageSetAsOneLine(self): + message = unittest_mset_pb2.TestMessageSetContainer() + ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension + ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension + message.message_set.Extensions[ext1].i = 23 + message.message_set.Extensions[ext2].str = 'foo' + self.CompareToGoldenText( + text_format.MessageToString(message, as_one_line=True), + 'message_set {' + ' [protobuf_unittest.TestMessageSetExtension1] {' + ' i: 23' + ' }' + ' [protobuf_unittest.TestMessageSetExtension2] {' + ' str: \"foo\"' + ' }' + ' }') + + def testPrintExoticAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + self.CompareToGoldenText( + self.RemoveRedundantZeros( + text_format.MessageToString(message, as_one_line=True)), + 'repeated_int64: -9223372036854775808' + ' repeated_uint64: 18446744073709551615' + ' repeated_double: 123.456' + ' repeated_double: 1.23e+22' + ' repeated_double: 1.23e-18' + ' repeated_string: ' + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""' + ' repeated_string: "\\303\\274\\352\\234\\237"') + + def testRoundTripExoticAsOneLine(self): + message = unittest_pb2.TestAllTypes() + message.repeated_int64.append(-9223372036854775808) + message.repeated_uint64.append(18446744073709551615) + message.repeated_double.append(123.456) + message.repeated_double.append(1.23e22) + message.repeated_double.append(1.23e-18) + message.repeated_string.append('\000\001\a\b\f\n\r\t\v\\\'"') + message.repeated_string.append(u'\u00fc\ua71f') + + # Test as_utf8 = False. + wire_text = text_format.MessageToString( + message, as_one_line=True, as_utf8=False) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message) + + # Test as_utf8 = True. + wire_text = text_format.MessageToString( + message, as_one_line=True, as_utf8=True) + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.Parse(wire_text, parsed_message) + self.assertIs(r, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintRawUtf8String(self): + message = unittest_pb2.TestAllTypes() + message.repeated_string.append(u'\u00fc\ua71f') + text = text_format.MessageToString(message, as_utf8=True) + self.CompareToGoldenText(text, 'repeated_string: "\303\274\352\234\237"\n') + parsed_message = unittest_pb2.TestAllTypes() + text_format.Parse(text, parsed_message) + self.assertEquals(message, parsed_message, + '\n%s != %s' % (message, parsed_message)) + + def testPrintFloatFormat(self): + # Check that float_format argument is passed to sub-message formatting. + message = unittest_pb2.NestedTestAllTypes() + # We use 1.25 as it is a round number in binary. The proto 32-bit float + # will not gain additional imprecise digits as a 64-bit Python float and + # show up in its str. 32-bit 1.2 is noisy when extended to 64-bit: + # >>> struct.unpack('f', struct.pack('f', 1.2))[0] + # 1.2000000476837158 + # >>> struct.unpack('f', struct.pack('f', 1.25))[0] + # 1.25 + message.payload.optional_float = 1.25 + # Check rounding at 15 significant digits + message.payload.optional_double = -.000003456789012345678 + # Check no decimal point. + message.payload.repeated_float.append(-5642) + # Check no trailing zeros. + message.payload.repeated_double.append(.000078900) + formatted_fields = ['optional_float: 1.25', + 'optional_double: -3.45678901234568e-6', + 'repeated_float: -5642', + 'repeated_double: 7.89e-5'] + text_message = text_format.MessageToString(message, float_format='.15g') self.CompareToGoldenText( - self.RemoveRedundantZeros(text_format.MessageToString(message)), - 'repeated_int64: -9223372036854775808\n' - 'repeated_uint64: 18446744073709551615\n' - 'repeated_double: 123.456\n' - 'repeated_double: 1.23e+22\n' - 'repeated_double: 1.23e-18\n' - 'repeated_string: ' - '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n') + self.RemoveRedundantZeros(text_message), + 'payload {{\n {}\n {}\n {}\n {}\n}}\n'.format(*formatted_fields)) + # as_one_line=True is a separate code branch where float_format is passed. + text_message = text_format.MessageToString(message, as_one_line=True, + float_format='.15g') + self.CompareToGoldenText( + self.RemoveRedundantZeros(text_message), + 'payload {{ {} {} {} {} }}'.format(*formatted_fields)) def testMessageToString(self): message = unittest_pb2.ForeignMessage() @@ -119,52 +288,57 @@ class TextFormatTest(unittest.TestCase): def RemoveRedundantZeros(self, text): # Some platforms print 1e+5 as 1e+005. This is fine, but we need to remove # these zeros in order to match the golden file. - return text.replace('e+0','e+').replace('e+0','e+') \ + text = text.replace('e+0','e+').replace('e+0','e+') \ .replace('e-0','e-').replace('e-0','e-') + # Floating point fields are printed with .0 suffix even if they are + # actualy integer numbers. + text = re.compile('\.0$', re.MULTILINE).sub('', text) + return text - def testMergeGolden(self): + def testParseGolden(self): golden_text = '\n'.join(self.ReadGolden('text_format_unittest_data.txt')) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(golden_text, parsed_message) + r = text_format.Parse(golden_text, parsed_message) + self.assertIs(r, parsed_message) message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) self.assertEquals(message, parsed_message) - def testMergeGoldenExtensions(self): + def testParseGoldenExtensions(self): golden_text = '\n'.join(self.ReadGolden( 'text_format_unittest_extensions_data.txt')) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(golden_text, parsed_message) + text_format.Parse(golden_text, parsed_message) message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) self.assertEquals(message, parsed_message) - def testMergeAllFields(self): + def testParseAllFields(self): message = unittest_pb2.TestAllTypes() test_util.SetAllFields(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllTypes() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) test_util.ExpectAllFieldsSet(self, message) - def testMergeAllExtensions(self): + def testParseAllExtensions(self): message = unittest_pb2.TestAllExtensions() test_util.SetAllExtensions(message) ascii_text = text_format.MessageToString(message) parsed_message = unittest_pb2.TestAllExtensions() - text_format.Merge(ascii_text, parsed_message) + text_format.Parse(ascii_text, parsed_message) self.assertEqual(message, parsed_message) - def testMergeMessageSet(self): + def testParseMessageSet(self): message = unittest_pb2.TestAllTypes() text = ('repeated_uint64: 1\n' 'repeated_uint64: 2\n') - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertEqual(1, message.repeated_uint64[0]) self.assertEqual(2, message.repeated_uint64[1]) @@ -177,13 +351,13 @@ class TextFormatTest(unittest.TestCase): ' str: \"foo\"\n' ' }\n' '}\n') - text_format.Merge(text, message) + text_format.Parse(text, message) ext1 = unittest_mset_pb2.TestMessageSetExtension1.message_set_extension ext2 = unittest_mset_pb2.TestMessageSetExtension2.message_set_extension self.assertEquals(23, message.message_set.Extensions[ext1].i) self.assertEquals('foo', message.message_set.Extensions[ext2].str) - def testMergeExotic(self): + def testParseExotic(self): message = unittest_pb2.TestAllTypes() text = ('repeated_int64: -9223372036854775808\n' 'repeated_uint64: 18446744073709551615\n' @@ -191,9 +365,12 @@ class TextFormatTest(unittest.TestCase): 'repeated_double: 1.23e+22\n' 'repeated_double: 1.23e-18\n' 'repeated_string: \n' - '\"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\\"\"\n' - 'repeated_string: "foo" \'corge\' "grault"') - text_format.Merge(text, message) + '"\\000\\001\\007\\010\\014\\n\\r\\t\\013\\\\\\\'\\""\n' + 'repeated_string: "foo" \'corge\' "grault"\n' + 'repeated_string: "\\303\\274\\352\\234\\237"\n' + 'repeated_string: "\\xc3\\xbc"\n' + 'repeated_string: "\xc3\xbc"\n') + text_format.Parse(text, message) self.assertEqual(-9223372036854775808, message.repeated_int64[0]) self.assertEqual(18446744073709551615, message.repeated_uint64[0]) @@ -201,95 +378,224 @@ class TextFormatTest(unittest.TestCase): self.assertEqual(1.23e22, message.repeated_double[1]) self.assertEqual(1.23e-18, message.repeated_double[2]) self.assertEqual( - '\000\001\a\b\f\n\r\t\v\\\'\"', message.repeated_string[0]) + '\000\001\a\b\f\n\r\t\v\\\'"', message.repeated_string[0]) self.assertEqual('foocorgegrault', message.repeated_string[1]) + self.assertEqual(u'\u00fc\ua71f', message.repeated_string[2]) + self.assertEqual(u'\u00fc', message.repeated_string[3]) - def testMergeUnknownField(self): + def testParseTrailingCommas(self): + message = unittest_pb2.TestAllTypes() + text = ('repeated_int64: 100;\n' + 'repeated_int64: 200;\n' + 'repeated_int64: 300,\n' + 'repeated_string: "one",\n' + 'repeated_string: "two";\n') + text_format.Parse(text, message) + + self.assertEqual(100, message.repeated_int64[0]) + self.assertEqual(200, message.repeated_int64[1]) + self.assertEqual(300, message.repeated_int64[2]) + self.assertEqual(u'one', message.repeated_string[0]) + self.assertEqual(u'two', message.repeated_string[1]) + + def testParseEmptyText(self): + message = unittest_pb2.TestAllTypes() + text = '' + text_format.Parse(text, message) + self.assertEquals(unittest_pb2.TestAllTypes(), message) + + def testParseInvalidUtf8(self): + message = unittest_pb2.TestAllTypes() + text = 'repeated_string: "\\xc3\\xc3"' + self.assertRaises(text_format.ParseError, text_format.Parse, text, message) + + def testParseSingleWord(self): + message = unittest_pb2.TestAllTypes() + text = 'foo' + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' + '"foo".'), + text_format.Parse, text, message) + + def testParseUnknownField(self): message = unittest_pb2.TestAllTypes() text = 'unknown_field: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:1 : Message type "protobuf_unittest.TestAllTypes" has no field named ' '"unknown_field".'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeBadExtension(self): + def testParseBadExtension(self): message = unittest_pb2.TestAllExtensions() text = '[unknown_extension]: 8\n' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:2 : Extension "unknown_extension" not registered.', - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:2 : Message type "protobuf_unittest.TestAllTypes" does not have ' 'extensions.'), - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeGroupNotClosed(self): + def testParseGroupNotClosed(self): message = unittest_pb2.TestAllTypes() text = 'RepeatedGroup: <' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected ">".', - text_format.Merge, text, message) + text_format.Parse, text, message) text = 'RepeatedGroup: {' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, '1:16 : Expected "}".', - text_format.Merge, text, message) + text_format.Parse, text, message) - def testMergeEmptyGroup(self): + def testParseEmptyGroup(self): message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: {}' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) message.Clear() message = unittest_pb2.TestAllTypes() text = 'OptionalGroup: <>' - text_format.Merge(text, message) + text_format.Parse(text, message) self.assertTrue(message.HasField('optionalgroup')) - def testMergeBadEnumValue(self): + def testParseBadEnumValue(self): message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: BARR' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value named BARR.'), - text_format.Merge, text, message) + text_format.Parse, text, message) message = unittest_pb2.TestAllTypes() text = 'optional_nested_enum: 100' - self.assertRaisesWithMessage( + self.assertRaisesWithLiteralMatch( text_format.ParseError, ('1:23 : Enum type "protobuf_unittest.TestAllTypes.NestedEnum" ' 'has no value with number 100.'), - text_format.Merge, text, message) - - def assertRaisesWithMessage(self, e_class, e, func, *args, **kwargs): - """Same as assertRaises, but also compares the exception message.""" - if hasattr(e_class, '__name__'): - exc_name = e_class.__name__ - else: - exc_name = str(e_class) - - try: - func(*args, **kwargs) - except e_class, expr: - if str(expr) != e: - msg = '%s raised, but with wrong message: "%s" instead of "%s"' - raise self.failureException(msg % (exc_name, - str(expr).encode('string_escape'), - e.encode('string_escape'))) - return - else: - raise self.failureException('%s not raised' % exc_name) - - -class TokenizerTest(unittest.TestCase): + text_format.Parse, text, message) + + def testParseBadIntValue(self): + message = unittest_pb2.TestAllTypes() + text = 'optional_int32: bork' + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:17 : Couldn\'t parse integer: bork'), + text_format.Parse, text, message) + + def testParseStringFieldUnescape(self): + message = unittest_pb2.TestAllTypes() + text = r'''repeated_string: "\xf\x62" + repeated_string: "\\xf\\x62" + repeated_string: "\\\xf\\\x62" + repeated_string: "\\\\xf\\\\x62" + repeated_string: "\\\\\xf\\\\\x62" + repeated_string: "\x5cx20"''' + text_format.Parse(text, message) + + SLASH = '\\' + self.assertEqual('\x0fb', message.repeated_string[0]) + self.assertEqual(SLASH + 'xf' + SLASH + 'x62', message.repeated_string[1]) + self.assertEqual(SLASH + '\x0f' + SLASH + 'b', message.repeated_string[2]) + self.assertEqual(SLASH + SLASH + 'xf' + SLASH + SLASH + 'x62', + message.repeated_string[3]) + self.assertEqual(SLASH + SLASH + '\x0f' + SLASH + SLASH + 'b', + message.repeated_string[4]) + self.assertEqual(SLASH + 'x20', message.repeated_string[5]) + + def testMergeRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + r = text_format.Merge(text, message) + self.assertIs(r, message) + self.assertEqual(67, message.optional_int32) + + def testParseRepeatedScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_int32: 42 ' + 'optional_int32: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:36 : Message type "protobuf_unittest.TestAllTypes" should not ' + 'have multiple "optional_int32" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + r = text_format.Merge(text, message) + self.assertTrue(r is message) + self.assertEqual(2, message.optional_nested_message.bb) + + def testParseRepeatedNestedMessageScalars(self): + message = unittest_pb2.TestAllTypes() + text = ('optional_nested_message { bb: 1 } ' + 'optional_nested_message { bb: 2 }') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:65 : Message type "protobuf_unittest.TestAllTypes.NestedMessage" ' + 'should not have multiple "bb" fields.'), + text_format.Parse, text, message) + + def testMergeRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + text_format.Merge(text, message) + self.assertEqual( + 67, + message.Extensions[unittest_pb2.optional_int32_extension]) + + def testParseRepeatedExtensionScalars(self): + message = unittest_pb2.TestAllExtensions() + text = ('[protobuf_unittest.optional_int32_extension]: 42 ' + '[protobuf_unittest.optional_int32_extension]: 67') + self.assertRaisesWithLiteralMatch( + text_format.ParseError, + ('1:96 : Message type "protobuf_unittest.TestAllExtensions" ' + 'should not have multiple ' + '"protobuf_unittest.optional_int32_extension" extensions.'), + text_format.Parse, text, message) + + def testParseLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.ParseLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEquals(message, parsed_message) + + def testMergeLinesGolden(self): + opened = self.ReadGolden('text_format_unittest_data.txt') + parsed_message = unittest_pb2.TestAllTypes() + r = text_format.MergeLines(opened, parsed_message) + self.assertIs(r, parsed_message) + + message = unittest_pb2.TestAllTypes() + test_util.SetAllFields(message) + self.assertEqual(message, parsed_message) + + def testParseOneof(self): + m = unittest_pb2.TestAllTypes() + m.oneof_uint32 = 11 + m2 = unittest_pb2.TestAllTypes() + text_format.Parse(text_format.MessageToString(m), m2) + self.assertEqual('oneof_uint32', m2.WhichOneof('oneof_field')) + + +class TokenizerTest(basetest.TestCase): def testSimpleTokenCases(self): text = ('identifier1:"string1"\n \n\n' @@ -297,8 +603,9 @@ class TokenizerTest(unittest.TestCase): 'identifiER_4 : 1.1e+2 ID5:-0.23 ID6:\'aaaa\\\'bbbb\'\n' 'ID7 : "aa\\"bb"\n\n\n\n ID8: {A:inf B:-inf C:true D:false}\n' 'ID9: 22 ID10: -111111111111111111 ID11: -22\n' - 'ID12: 2222222222222222222') - tokenizer = text_format._Tokenizer(text) + 'ID12: 2222222222222222222 ID13: 1.23456f ID14: 1.2e+2f ' + 'false_bool: 0 true_BOOL:t \n true_bool1: 1 false_BOOL1:f ') + tokenizer = text_format._Tokenizer(text.splitlines()) methods = [(tokenizer.ConsumeIdentifier, 'identifier1'), ':', (tokenizer.ConsumeString, 'string1'), @@ -325,10 +632,10 @@ class TokenizerTest(unittest.TestCase): '{', (tokenizer.ConsumeIdentifier, 'A'), ':', - (tokenizer.ConsumeFloat, text_format._INFINITY), + (tokenizer.ConsumeFloat, float('inf')), (tokenizer.ConsumeIdentifier, 'B'), ':', - (tokenizer.ConsumeFloat, -text_format._INFINITY), + (tokenizer.ConsumeFloat, -float('inf')), (tokenizer.ConsumeIdentifier, 'C'), ':', (tokenizer.ConsumeBool, True), @@ -347,7 +654,25 @@ class TokenizerTest(unittest.TestCase): (tokenizer.ConsumeInt32, -22), (tokenizer.ConsumeIdentifier, 'ID12'), ':', - (tokenizer.ConsumeUint64, 2222222222222222222)] + (tokenizer.ConsumeUint64, 2222222222222222222), + (tokenizer.ConsumeIdentifier, 'ID13'), + ':', + (tokenizer.ConsumeFloat, 1.23456), + (tokenizer.ConsumeIdentifier, 'ID14'), + ':', + (tokenizer.ConsumeFloat, 1.2e+2), + (tokenizer.ConsumeIdentifier, 'false_bool'), + ':', + (tokenizer.ConsumeBool, False), + (tokenizer.ConsumeIdentifier, 'true_BOOL'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'true_bool1'), + ':', + (tokenizer.ConsumeBool, True), + (tokenizer.ConsumeIdentifier, 'false_BOOL1'), + ':', + (tokenizer.ConsumeBool, False)] i = 0 while not tokenizer.AtEnd(): @@ -366,7 +691,7 @@ class TokenizerTest(unittest.TestCase): int64_max = (1 << 63) - 1 uint32_max = (1 << 32) - 1 text = '-1 %d %d' % (uint32_max + 1, int64_max + 1) - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint32) self.assertRaises(text_format.ParseError, tokenizer.ConsumeUint64) self.assertEqual(-1, tokenizer.ConsumeInt32()) @@ -380,7 +705,7 @@ class TokenizerTest(unittest.TestCase): self.assertTrue(tokenizer.AtEnd()) text = '-0 -0 0 0' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertEqual(0, tokenizer.ConsumeUint32()) self.assertEqual(0, tokenizer.ConsumeUint64()) self.assertEqual(0, tokenizer.ConsumeUint32()) @@ -389,40 +714,30 @@ class TokenizerTest(unittest.TestCase): def testConsumeByteString(self): text = '"string1\'' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = 'string1"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\xt"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) text = '\n"\\x"' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeByteString) def testConsumeBool(self): text = 'not-a-bool' - tokenizer = text_format._Tokenizer(text) + tokenizer = text_format._Tokenizer(text.splitlines()) self.assertRaises(text_format.ParseError, tokenizer.ConsumeBool) - def testInfNan(self): - # Make sure our infinity and NaN definitions are sound. - self.assertEquals(float, type(text_format._INFINITY)) - self.assertEquals(float, type(text_format._NAN)) - self.assertTrue(text_format._NAN != text_format._NAN) - - inf_times_zero = text_format._INFINITY * 0 - self.assertTrue(inf_times_zero != inf_times_zero) - self.assertTrue(text_format._INFINITY > 0) - if __name__ == '__main__': - unittest.main() + basetest.main() diff --git a/python/google/protobuf/internal/type_checkers.py b/python/google/protobuf/internal/type_checkers.py index 2b3cd4d..56d2646 100755 --- a/python/google/protobuf/internal/type_checkers.py +++ b/python/google/protobuf/internal/type_checkers.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -28,6 +28,10 @@ # (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +#PY25 compatible for GAE. +# +# Copyright 2008 Google Inc. All Rights Reserved. + """Provides type checking routines. This module defines type checking utilities in the forms of dictionaries: @@ -45,6 +49,9 @@ TYPE_TO_DESERIALIZE_METHOD: A dictionary with field types and deserialization __author__ = 'robinson@google.com (Will Robinson)' +import sys ##PY25 +if sys.version < '2.6': bytes = str ##PY25 +from google.protobuf.internal import api_implementation from google.protobuf.internal import decoder from google.protobuf.internal import encoder from google.protobuf.internal import wire_format @@ -53,21 +60,22 @@ from google.protobuf import descriptor _FieldDescriptor = descriptor.FieldDescriptor -def GetTypeChecker(cpp_type, field_type): +def GetTypeChecker(field): """Returns a type checker for a message field of the specified types. Args: - cpp_type: C++ type of the field (see descriptor.py). - field_type: Protocol message field type (see descriptor.py). + field: FieldDescriptor object for this field. Returns: An instance of TypeChecker which can be used to verify the types of values assigned to a field of the specified type. """ - if (cpp_type == _FieldDescriptor.CPPTYPE_STRING and - field_type == _FieldDescriptor.TYPE_STRING): + if (field.cpp_type == _FieldDescriptor.CPPTYPE_STRING and + field.type == _FieldDescriptor.TYPE_STRING): return UnicodeValueChecker() - return _VALUE_CHECKERS[cpp_type] + if field.cpp_type == _FieldDescriptor.CPPTYPE_ENUM: + return EnumValueChecker(field.enum_type) + return _VALUE_CHECKERS[field.cpp_type] # None of the typecheckers below make any attempt to guard against people @@ -85,10 +93,15 @@ class TypeChecker(object): self._acceptable_types = acceptable_types def CheckValue(self, proposed_value): + """Type check the provided value and return it. + + The returned value might have been normalized to another type. + """ if not isinstance(proposed_value, self._acceptable_types): message = ('%.1024r has type %s, but expected one of: %s' % (proposed_value, type(proposed_value), self._acceptable_types)) raise TypeError(message) + return proposed_value # IntValueChecker and its subclasses perform integer type-checks @@ -104,28 +117,54 @@ class IntValueChecker(object): raise TypeError(message) if not self._MIN <= proposed_value <= self._MAX: raise ValueError('Value out of range: %d' % proposed_value) + # We force 32-bit values to int and 64-bit values to long to make + # alternate implementations where the distinction is more significant + # (e.g. the C++ implementation) simpler. + proposed_value = self._TYPE(proposed_value) + return proposed_value + + +class EnumValueChecker(object): + + """Checker used for enum fields. Performs type-check and range check.""" + + def __init__(self, enum_type): + self._enum_type = enum_type + + def CheckValue(self, proposed_value): + if not isinstance(proposed_value, (int, long)): + message = ('%.1024r has type %s, but expected one of: %s' % + (proposed_value, type(proposed_value), (int, long))) + raise TypeError(message) + if proposed_value not in self._enum_type.values_by_number: + raise ValueError('Unknown enum value: %d' % proposed_value) + return proposed_value class UnicodeValueChecker(object): - """Checker used for string fields.""" + """Checker used for string fields. + + Always returns a unicode value, even if the input is of type str. + """ def CheckValue(self, proposed_value): - if not isinstance(proposed_value, (str, unicode)): + if not isinstance(proposed_value, (bytes, unicode)): message = ('%.1024r has type %s, but expected one of: %s' % - (proposed_value, type(proposed_value), (str, unicode))) + (proposed_value, type(proposed_value), (bytes, unicode))) raise TypeError(message) - # If the value is of type 'str' make sure that it is in 7-bit ASCII + # If the value is of type 'bytes' make sure that it is in 7-bit ASCII # encoding. - if isinstance(proposed_value, str): + if isinstance(proposed_value, bytes): try: - unicode(proposed_value, 'ascii') + proposed_value = proposed_value.decode('ascii') except UnicodeDecodeError: - raise ValueError('%.1024r has type str, but isn\'t in 7-bit ASCII ' + raise ValueError('%.1024r has type bytes, but isn\'t in 7-bit ASCII ' 'encoding. Non-ASCII strings must be converted to ' 'unicode objects before being added.' % (proposed_value)) + return proposed_value class Int32ValueChecker(IntValueChecker): @@ -133,21 +172,25 @@ class Int32ValueChecker(IntValueChecker): # efficient. _MIN = -2147483648 _MAX = 2147483647 + _TYPE = int class Uint32ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 32) - 1 + _TYPE = int class Int64ValueChecker(IntValueChecker): _MIN = -(1 << 63) _MAX = (1 << 63) - 1 + _TYPE = long class Uint64ValueChecker(IntValueChecker): _MIN = 0 _MAX = (1 << 64) - 1 + _TYPE = long # Type-checkers for all scalar CPPTYPEs. @@ -161,8 +204,7 @@ _VALUE_CHECKERS = { _FieldDescriptor.CPPTYPE_FLOAT: TypeChecker( float, int, long), _FieldDescriptor.CPPTYPE_BOOL: TypeChecker(bool, int), - _FieldDescriptor.CPPTYPE_ENUM: Int32ValueChecker(), - _FieldDescriptor.CPPTYPE_STRING: TypeChecker(str), + _FieldDescriptor.CPPTYPE_STRING: TypeChecker(bytes), } diff --git a/python/google/protobuf/internal/unknown_fields_test.py b/python/google/protobuf/internal/unknown_fields_test.py new file mode 100755 index 0000000..7177560 --- /dev/null +++ b/python/google/protobuf/internal/unknown_fields_test.py @@ -0,0 +1,231 @@ +#! /usr/bin/python +# -*- coding: utf-8 -*- +# +# Protocol Buffers - Google's data interchange format +# Copyright 2008 Google Inc. All rights reserved. +# https://developers.google.com/protocol-buffers/ +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are +# met: +# +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above +# copyright notice, this list of conditions and the following disclaimer +# in the documentation and/or other materials provided with the +# distribution. +# * Neither the name of Google Inc. nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS +# "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT +# LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR +# A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT +# OWNER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, +# SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT +# LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, +# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY +# THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +"""Test for preservation of unknown fields in the pure Python implementation.""" + +__author__ = 'bohdank@google.com (Bohdan Koval)' + +from google.apputils import basetest +from google.protobuf import unittest_mset_pb2 +from google.protobuf import unittest_pb2 +from google.protobuf.internal import encoder +from google.protobuf.internal import missing_enum_values_pb2 +from google.protobuf.internal import test_util +from google.protobuf.internal import type_checkers + + +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = unittest_pb2.TestAllTypes.DESCRIPTOR + self.all_fields = unittest_pb2.TestAllTypes() + test_util.SetAllFields(self.all_fields) + self.all_fields_data = self.all_fields.SerializeToString() + self.empty_message = unittest_pb2.TestEmptyMessage() + self.empty_message.ParseFromString(self.all_fields_data) + self.unknown_fields = self.empty_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = unittest_pb2.TestAllTypes._decoders_by_tag[tag_bytes][0] + decoder(value, 0, len(value), self.all_fields, result_dict) + return result_dict[field_descriptor] + + def testEnum(self): + value = self.GetField('optional_nested_enum') + self.assertEqual(self.all_fields.optional_nested_enum, value) + + def testRepeatedEnum(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.all_fields.repeated_nested_enum, value) + + def testVarint(self): + value = self.GetField('optional_int32') + self.assertEqual(self.all_fields.optional_int32, value) + + def testFixed32(self): + value = self.GetField('optional_fixed32') + self.assertEqual(self.all_fields.optional_fixed32, value) + + def testFixed64(self): + value = self.GetField('optional_fixed64') + self.assertEqual(self.all_fields.optional_fixed64, value) + + def testLengthDelimited(self): + value = self.GetField('optional_string') + self.assertEqual(self.all_fields.optional_string, value) + + def testGroup(self): + value = self.GetField('optionalgroup') + self.assertEqual(self.all_fields.optionalgroup, value) + + def testSerialize(self): + data = self.empty_message.SerializeToString() + + # Don't use assertEqual because we don't want to dump raw binary data to + # stdout. + self.assertTrue(data == self.all_fields_data) + + def testCopyFrom(self): + message = unittest_pb2.TestEmptyMessage() + message.CopyFrom(self.empty_message) + self.assertEqual(self.unknown_fields, message._unknown_fields) + + def testMergeFrom(self): + message = unittest_pb2.TestAllTypes() + message.optional_int32 = 1 + message.optional_uint32 = 2 + source = unittest_pb2.TestEmptyMessage() + source.ParseFromString(message.SerializeToString()) + + message.ClearField('optional_int32') + message.optional_int64 = 3 + message.optional_uint32 = 4 + destination = unittest_pb2.TestEmptyMessage() + destination.ParseFromString(message.SerializeToString()) + unknown_fields = destination._unknown_fields[:] + + destination.MergeFrom(source) + self.assertEqual(unknown_fields + source._unknown_fields, + destination._unknown_fields) + + def testClear(self): + self.empty_message.Clear() + self.assertEqual(0, len(self.empty_message._unknown_fields)) + + def testByteSize(self): + self.assertEqual(self.all_fields.ByteSize(), self.empty_message.ByteSize()) + + def testUnknownExtensions(self): + message = unittest_pb2.TestEmptyMessageWithExtensions() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message._unknown_fields, + message._unknown_fields) + + def testListFields(self): + # Make sure ListFields doesn't return unknown fields. + self.assertEqual(0, len(self.empty_message.ListFields())) + + def testSerializeMessageSetWireFormatUnknownExtension(self): + # Create a message using the message set wire format with an unknown + # message. + raw = unittest_mset_pb2.RawMessageSet() + + # Add an unknown extension. + item = raw.item.add() + item.type_id = 1545009 + message1 = unittest_mset_pb2.TestMessageSetExtension1() + message1.i = 12345 + item.message = message1.SerializeToString() + + serialized = raw.SerializeToString() + + # Parse message using the message set wire format. + proto = unittest_mset_pb2.TestMessageSet() + proto.MergeFromString(serialized) + + # Verify that the unknown extension is serialized unchanged + reserialized = proto.SerializeToString() + new_raw = unittest_mset_pb2.RawMessageSet() + new_raw.MergeFromString(reserialized) + self.assertEqual(raw, new_raw) + + def testEquals(self): + message = unittest_pb2.TestEmptyMessage() + message.ParseFromString(self.all_fields_data) + self.assertEqual(self.empty_message, message) + + self.all_fields.ClearField('optional_string') + message.ParseFromString(self.all_fields.SerializeToString()) + self.assertNotEqual(self.empty_message, message) + + +class UnknownFieldsTest(basetest.TestCase): + + def setUp(self): + self.descriptor = missing_enum_values_pb2.TestEnumValues.DESCRIPTOR + + self.message = missing_enum_values_pb2.TestEnumValues() + self.message.optional_nested_enum = ( + missing_enum_values_pb2.TestEnumValues.ZERO) + self.message.repeated_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message.packed_nested_enum.extend([ + missing_enum_values_pb2.TestEnumValues.ZERO, + missing_enum_values_pb2.TestEnumValues.ONE, + ]) + self.message_data = self.message.SerializeToString() + self.missing_message = missing_enum_values_pb2.TestMissingEnumValues() + self.missing_message.ParseFromString(self.message_data) + self.unknown_fields = self.missing_message._unknown_fields + + def GetField(self, name): + field_descriptor = self.descriptor.fields_by_name[name] + wire_type = type_checkers.FIELD_TYPE_TO_WIRE_TYPE[field_descriptor.type] + field_tag = encoder.TagBytes(field_descriptor.number, wire_type) + result_dict = {} + for tag_bytes, value in self.unknown_fields: + if tag_bytes == field_tag: + decoder = missing_enum_values_pb2.TestEnumValues._decoders_by_tag[ + tag_bytes][0] + decoder(value, 0, len(value), self.message, result_dict) + return result_dict[field_descriptor] + + def testUnknownEnumValue(self): + self.assertFalse(self.missing_message.HasField('optional_nested_enum')) + value = self.GetField('optional_nested_enum') + self.assertEqual(self.message.optional_nested_enum, value) + + def testUnknownRepeatedEnumValue(self): + value = self.GetField('repeated_nested_enum') + self.assertEqual(self.message.repeated_nested_enum, value) + + def testUnknownPackedEnumValue(self): + value = self.GetField('packed_nested_enum') + self.assertEqual(self.message.packed_nested_enum, value) + + def testRoundTrip(self): + new_message = missing_enum_values_pb2.TestEnumValues() + new_message.ParseFromString(self.missing_message.SerializeToString()) + self.assertEqual(self.message, new_message) + + +if __name__ == '__main__': + basetest.main() diff --git a/python/google/protobuf/internal/wire_format.py b/python/google/protobuf/internal/wire_format.py index c941fe1..883f525 100755 --- a/python/google/protobuf/internal/wire_format.py +++ b/python/google/protobuf/internal/wire_format.py @@ -1,6 +1,6 @@ # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are diff --git a/python/google/protobuf/internal/wire_format_test.py b/python/google/protobuf/internal/wire_format_test.py index 7600778..f39035c 100755 --- a/python/google/protobuf/internal/wire_format_test.py +++ b/python/google/protobuf/internal/wire_format_test.py @@ -2,7 +2,7 @@ # # Protocol Buffers - Google's data interchange format # Copyright 2008 Google Inc. All rights reserved. -# http://code.google.com/p/protobuf/ +# https://developers.google.com/protocol-buffers/ # # Redistribution and use in source and binary forms, with or without # modification, are permitted provided that the following conditions are @@ -34,12 +34,12 @@ __author__ = 'robinson@google.com (Will Robinson)' -import unittest +from google.apputils import basetest from google.protobuf import message from google.protobuf.internal import wire_format -class WireFormatTest(unittest.TestCase): +class WireFormatTest(basetest.TestCase): def testPackTag(self): field_number = 0xabc @@ -195,7 +195,7 @@ class WireFormatTest(unittest.TestCase): # Test UTF-8 string byte size calculation. # 1 byte for tag, 1 byte for length, 8 bytes for content. self.assertEqual(10, wire_format.StringByteSize( - 5, unicode('\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82', 'utf-8'))) + 5, b'\xd0\xa2\xd0\xb5\xd1\x81\xd1\x82'.decode('utf-8'))) class MockMessage(object): def __init__(self, byte_size): @@ -250,4 +250,4 @@ class WireFormatTest(unittest.TestCase): if __name__ == '__main__': - unittest.main() + basetest.main() |