diff options
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r-- | src/google/protobuf/extension_set_heavy.cc | 276 |
1 files changed, 256 insertions, 20 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc index 8555f6f..2721f15 100644 --- a/src/google/protobuf/extension_set_heavy.cc +++ b/src/google/protobuf/extension_set_heavy.cc @@ -40,11 +40,31 @@ #include <google/protobuf/message.h> #include <google/protobuf/repeated_field.h> #include <google/protobuf/wire_format.h> +#include <google/protobuf/wire_format_lite_inl.h> namespace google { namespace protobuf { namespace internal { +// Implementation of ExtensionFinder which finds extensions in a given +// DescriptorPool, using the given MessageFactory to construct sub-objects. +// This class is implemented in extension_set_heavy.cc. +class DescriptorPoolExtensionFinder : public ExtensionFinder { + public: + DescriptorPoolExtensionFinder(const DescriptorPool* pool, + MessageFactory* factory, + const Descriptor* containing_type) + : pool_(pool), factory_(factory), containing_type_(containing_type) {} + virtual ~DescriptorPoolExtensionFinder() {} + + virtual bool Find(int number, ExtensionInfo* output); + + private: + const DescriptorPool* pool_; + MessageFactory* factory_; + const Descriptor* containing_type_; +}; + void ExtensionSet::AppendToList(const Descriptor* containing_type, const DescriptorPool* pool, vector<const FieldDescriptor*>* output) const { @@ -58,12 +78,26 @@ void ExtensionSet::AppendToList(const Descriptor* containing_type, } if (has) { - output->push_back( - pool->FindExtensionByNumber(containing_type, iter->first)); + // TODO(kenton): Looking up each field by number is somewhat unfortunate. + // Is there a better way? The problem is that descriptors are lazily- + // initialized, so they might not even be constructed until + // AppendToList() is called. + + if (iter->second.descriptor == NULL) { + output->push_back(pool->FindExtensionByNumber( + containing_type, iter->first)); + } else { + output->push_back(iter->second.descriptor); + } } } } +inline FieldDescriptor::Type real_type(FieldType type) { + GOOGLE_DCHECK(type > 0 && type <= FieldDescriptor::MAX_TYPE); + return static_cast<FieldDescriptor::Type>(type); +} + inline FieldDescriptor::CppType cpp_type(FieldType type) { return FieldDescriptor::TypeToCppType( static_cast<FieldDescriptor::Type>(type)); @@ -88,16 +122,16 @@ const MessageLite& ExtensionSet::GetMessage(int number, } } -MessageLite* ExtensionSet::MutableMessage(int number, FieldType type, - const Descriptor* message_type, +MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor, MessageFactory* factory) { Extension* extension; - if (MaybeNewExtension(number, &extension)) { - extension->type = type; + if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) { + extension->type = descriptor->type(); GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE); extension->is_repeated = false; extension->is_packed = false; - const MessageLite* prototype = factory->GetPrototype(message_type); + const MessageLite* prototype = + factory->GetPrototype(descriptor->message_type()); GOOGLE_CHECK(prototype != NULL); extension->message_value = prototype->New(); } else { @@ -107,12 +141,11 @@ MessageLite* ExtensionSet::MutableMessage(int number, FieldType type, return extension->message_value; } -MessageLite* ExtensionSet::AddMessage(int number, FieldType type, - const Descriptor* message_type, +MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor, MessageFactory* factory) { Extension* extension; - if (MaybeNewExtension(number, &extension)) { - extension->type = type; + if (MaybeNewExtension(descriptor->number(), descriptor, &extension)) { + extension->type = descriptor->type(); GOOGLE_DCHECK_EQ(cpp_type(extension->type), FieldDescriptor::CPPTYPE_MESSAGE); extension->is_repeated = true; extension->repeated_message_value = @@ -128,7 +161,7 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type, if (result == NULL) { const MessageLite* prototype; if (extension->repeated_message_value->size() == 0) { - prototype = factory->GetPrototype(message_type); + prototype = factory->GetPrototype(descriptor->message_type()); GOOGLE_CHECK(prototype != NULL); } else { prototype = &extension->repeated_message_value->Get(0); @@ -139,18 +172,64 @@ MessageLite* ExtensionSet::AddMessage(int number, FieldType type, return result; } +static bool ValidateEnumUsingDescriptor(const void* arg, int number) { + return reinterpret_cast<const EnumDescriptor*>(arg) + ->FindValueByNumber(number) != NULL; +} + +bool DescriptorPoolExtensionFinder::Find(int number, ExtensionInfo* output) { + const FieldDescriptor* extension = + pool_->FindExtensionByNumber(containing_type_, number); + if (extension == NULL) { + return false; + } else { + output->type = extension->type(); + output->is_repeated = extension->is_repeated(); + output->is_packed = extension->options().packed(); + output->descriptor = extension; + if (extension->cpp_type() == FieldDescriptor::CPPTYPE_MESSAGE) { + output->message_prototype = + factory_->GetPrototype(extension->message_type()); + GOOGLE_CHECK(output->message_prototype != NULL) + << "Extension factory's GetPrototype() returned NULL for extension: " + << extension->full_name(); + } else if (extension->cpp_type() == FieldDescriptor::CPPTYPE_ENUM) { + output->enum_validity_check.func = ValidateEnumUsingDescriptor; + output->enum_validity_check.arg = extension->enum_type(); + } + + return true; + } +} + bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input, - const MessageLite* containing_type, + const Message* containing_type, UnknownFieldSet* unknown_fields) { UnknownFieldSetFieldSkipper skipper(unknown_fields); - return ParseField(tag, input, containing_type, &skipper); + if (input->GetExtensionPool() == NULL) { + GeneratedExtensionFinder finder(containing_type); + return ParseField(tag, input, &finder, &skipper); + } else { + DescriptorPoolExtensionFinder finder(input->GetExtensionPool(), + input->GetExtensionFactory(), + containing_type->GetDescriptor()); + return ParseField(tag, input, &finder, &skipper); + } } bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input, - const MessageLite* containing_type, + const Message* containing_type, UnknownFieldSet* unknown_fields) { UnknownFieldSetFieldSkipper skipper(unknown_fields); - return ParseMessageSet(input, containing_type, &skipper); + if (input->GetExtensionPool() == NULL) { + GeneratedExtensionFinder finder(containing_type); + return ParseMessageSet(input, &finder, &skipper); + } else { + DescriptorPoolExtensionFinder finder(input->GetExtensionPool(), + input->GetExtensionFactory(), + containing_type->GetDescriptor()); + return ParseMessageSet(input, &finder, &skipper); + } } int ExtensionSet::SpaceUsedExcludingSelf() const { @@ -175,7 +254,7 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const { if (is_repeated) { switch (cpp_type(type)) { #define HANDLE_TYPE(UPPERCASE, LOWERCASE) \ - case WireFormatLite::CPPTYPE_##UPPERCASE: \ + case FieldDescriptor::CPPTYPE_##UPPERCASE: \ total_size += sizeof(*repeated_##LOWERCASE##_value) + \ repeated_##LOWERCASE##_value->SpaceUsedExcludingSelf();\ break @@ -189,8 +268,9 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const { HANDLE_TYPE( BOOL, bool); HANDLE_TYPE( ENUM, enum); HANDLE_TYPE( STRING, string); +#undef HANDLE_TYPE - case WireFormatLite::CPPTYPE_MESSAGE: + case FieldDescriptor::CPPTYPE_MESSAGE: // repeated_message_value is actually a RepeatedPtrField<MessageLite>, // but MessageLite has no SpaceUsed(), so we must directly call // RepeatedPtrFieldBase::SpaceUsedExcludingSelf() with a different type @@ -201,11 +281,11 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const { } } else { switch (cpp_type(type)) { - case WireFormatLite::CPPTYPE_STRING: + case FieldDescriptor::CPPTYPE_STRING: total_size += sizeof(*string_value) + StringSpaceUsedExcludingSelf(*string_value); break; - case WireFormatLite::CPPTYPE_MESSAGE: + case FieldDescriptor::CPPTYPE_MESSAGE: total_size += down_cast<Message*>(message_value)->SpaceUsed(); break; default: @@ -216,6 +296,162 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const { return total_size; } +// The Serialize*ToArray methods are only needed in the heavy library, as +// the lite library only generates SerializeWithCachedSizes. +uint8* ExtensionSet::SerializeWithCachedSizesToArray( + int start_field_number, int end_field_number, + uint8* target) const { + map<int, Extension>::const_iterator iter; + for (iter = extensions_.lower_bound(start_field_number); + iter != extensions_.end() && iter->first < end_field_number; + ++iter) { + target = iter->second.SerializeFieldWithCachedSizesToArray(iter->first, + target); + } + return target; +} + +uint8* ExtensionSet::SerializeMessageSetWithCachedSizesToArray( + uint8* target) const { + map<int, Extension>::const_iterator iter; + for (iter = extensions_.begin(); iter != extensions_.end(); ++iter) { + target = iter->second.SerializeMessageSetItemWithCachedSizesToArray( + iter->first, target); + } + return target; +} + +uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray( + int number, uint8* target) const { + if (is_repeated) { + if (is_packed) { + if (cached_size == 0) return target; + + target = WireFormatLite::WriteTagToArray(number, + WireFormatLite::WIRETYPE_LENGTH_DELIMITED, target); + target = WireFormatLite::WriteInt32NoTagToArray(cached_size, target); + + switch (real_type(type)) { +#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \ + case FieldDescriptor::TYPE_##UPPERCASE: \ + for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \ + target = WireFormatLite::Write##CAMELCASE##NoTagToArray( \ + repeated_##LOWERCASE##_value->Get(i), target); \ + } \ + break + + HANDLE_TYPE( INT32, Int32, int32); + HANDLE_TYPE( INT64, Int64, int64); + HANDLE_TYPE( UINT32, UInt32, uint32); + HANDLE_TYPE( UINT64, UInt64, uint64); + HANDLE_TYPE( SINT32, SInt32, int32); + HANDLE_TYPE( SINT64, SInt64, int64); + HANDLE_TYPE( FIXED32, Fixed32, uint32); + HANDLE_TYPE( FIXED64, Fixed64, uint64); + HANDLE_TYPE(SFIXED32, SFixed32, int32); + HANDLE_TYPE(SFIXED64, SFixed64, int64); + HANDLE_TYPE( FLOAT, Float, float); + HANDLE_TYPE( DOUBLE, Double, double); + HANDLE_TYPE( BOOL, Bool, bool); + HANDLE_TYPE( ENUM, Enum, enum); +#undef HANDLE_TYPE + + case WireFormatLite::TYPE_STRING: + case WireFormatLite::TYPE_BYTES: + case WireFormatLite::TYPE_GROUP: + case WireFormatLite::TYPE_MESSAGE: + GOOGLE_LOG(FATAL) << "Non-primitive types can't be packed."; + break; + } + } else { + switch (real_type(type)) { +#define HANDLE_TYPE(UPPERCASE, CAMELCASE, LOWERCASE) \ + case FieldDescriptor::TYPE_##UPPERCASE: \ + for (int i = 0; i < repeated_##LOWERCASE##_value->size(); i++) { \ + target = WireFormatLite::Write##CAMELCASE##ToArray(number, \ + repeated_##LOWERCASE##_value->Get(i), target); \ + } \ + break + + HANDLE_TYPE( INT32, Int32, int32); + HANDLE_TYPE( INT64, Int64, int64); + HANDLE_TYPE( UINT32, UInt32, uint32); + HANDLE_TYPE( UINT64, UInt64, uint64); + HANDLE_TYPE( SINT32, SInt32, int32); + HANDLE_TYPE( SINT64, SInt64, int64); + HANDLE_TYPE( FIXED32, Fixed32, uint32); + HANDLE_TYPE( FIXED64, Fixed64, uint64); + HANDLE_TYPE(SFIXED32, SFixed32, int32); + HANDLE_TYPE(SFIXED64, SFixed64, int64); + HANDLE_TYPE( FLOAT, Float, float); + HANDLE_TYPE( DOUBLE, Double, double); + HANDLE_TYPE( BOOL, Bool, bool); + HANDLE_TYPE( STRING, String, string); + HANDLE_TYPE( BYTES, Bytes, string); + HANDLE_TYPE( ENUM, Enum, enum); + HANDLE_TYPE( GROUP, Group, message); + HANDLE_TYPE( MESSAGE, Message, message); +#undef HANDLE_TYPE + } + } + } else if (!is_cleared) { + switch (real_type(type)) { +#define HANDLE_TYPE(UPPERCASE, CAMELCASE, VALUE) \ + case FieldDescriptor::TYPE_##UPPERCASE: \ + target = WireFormatLite::Write##CAMELCASE##ToArray( \ + number, VALUE, target); \ + break + + HANDLE_TYPE( INT32, Int32, int32_value); + HANDLE_TYPE( INT64, Int64, int64_value); + HANDLE_TYPE( UINT32, UInt32, uint32_value); + HANDLE_TYPE( UINT64, UInt64, uint64_value); + HANDLE_TYPE( SINT32, SInt32, int32_value); + HANDLE_TYPE( SINT64, SInt64, int64_value); + HANDLE_TYPE( FIXED32, Fixed32, uint32_value); + HANDLE_TYPE( FIXED64, Fixed64, uint64_value); + HANDLE_TYPE(SFIXED32, SFixed32, int32_value); + HANDLE_TYPE(SFIXED64, SFixed64, int64_value); + HANDLE_TYPE( FLOAT, Float, float_value); + HANDLE_TYPE( DOUBLE, Double, double_value); + HANDLE_TYPE( BOOL, Bool, bool_value); + HANDLE_TYPE( STRING, String, *string_value); + HANDLE_TYPE( BYTES, Bytes, *string_value); + HANDLE_TYPE( ENUM, Enum, enum_value); + HANDLE_TYPE( GROUP, Group, *message_value); + HANDLE_TYPE( MESSAGE, Message, *message_value); +#undef HANDLE_TYPE + } + } + return target; +} + +uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray( + int number, + uint8* target) const { + if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) { + // Not a valid MessageSet extension, but serialize it the normal way. + GOOGLE_LOG(WARNING) << "Invalid message set extension."; + return SerializeFieldWithCachedSizesToArray(number, target); + } + + if (is_cleared) return target; + + // Start group. + target = io::CodedOutputStream::WriteTagToArray( + WireFormatLite::kMessageSetItemStartTag, target); + // Write type ID. + target = WireFormatLite::WriteUInt32ToArray( + WireFormatLite::kMessageSetTypeIdNumber, number, target); + // Write message. + target = WireFormatLite::WriteMessageToArray( + WireFormatLite::kMessageSetMessageNumber, *message_value, target); + // End group. + target = io::CodedOutputStream::WriteTagToArray( + WireFormatLite::kMessageSetItemEndTag, target); + return target; +} + } // namespace internal } // namespace protobuf } // namespace google |