aboutsummaryrefslogtreecommitdiffstats
path: root/src/google/protobuf/extension_set_heavy.cc
diff options
context:
space:
mode:
Diffstat (limited to 'src/google/protobuf/extension_set_heavy.cc')
-rw-r--r--src/google/protobuf/extension_set_heavy.cc301
1 files changed, 289 insertions, 12 deletions
diff --git a/src/google/protobuf/extension_set_heavy.cc b/src/google/protobuf/extension_set_heavy.cc
index 2721f15..1d971fd 100644
--- a/src/google/protobuf/extension_set_heavy.cc
+++ b/src/google/protobuf/extension_set_heavy.cc
@@ -35,17 +35,43 @@
// Contains methods defined in extension_set.h which cannot be part of the
// lite library because they use descriptors or reflection.
-#include <google/protobuf/extension_set.h>
+#include <google/protobuf/io/zero_copy_stream_impl_lite.h>
#include <google/protobuf/descriptor.h>
+#include <google/protobuf/extension_set.h>
#include <google/protobuf/message.h>
#include <google/protobuf/repeated_field.h>
#include <google/protobuf/wire_format.h>
#include <google/protobuf/wire_format_lite_inl.h>
namespace google {
+
namespace protobuf {
namespace internal {
+// A FieldSkipper used to store unknown MessageSet fields into UnknownFieldSet.
+class MessageSetFieldSkipper
+ : public UnknownFieldSetFieldSkipper {
+ public:
+ explicit MessageSetFieldSkipper(UnknownFieldSet* unknown_fields)
+ : UnknownFieldSetFieldSkipper(unknown_fields) {}
+ virtual ~MessageSetFieldSkipper() {}
+
+ virtual bool SkipMessageSetField(io::CodedInputStream* input,
+ int field_number);
+};
+bool MessageSetFieldSkipper::SkipMessageSetField(
+ io::CodedInputStream* input, int field_number) {
+ uint32 length;
+ if (!input->ReadVarint32(&length)) return false;
+ if (unknown_fields_ == NULL) {
+ return input->Skip(length);
+ } else {
+ return input->ReadString(
+ unknown_fields_->AddLengthDelimited(field_number), length);
+ }
+}
+
+
// Implementation of ExtensionFinder which finds extensions in a given
// DescriptorPool, using the given MessageFactory to construct sub-objects.
// This class is implemented in extension_set_heavy.cc.
@@ -67,7 +93,7 @@ class DescriptorPoolExtensionFinder : public ExtensionFinder {
void ExtensionSet::AppendToList(const Descriptor* containing_type,
const DescriptorPool* pool,
- vector<const FieldDescriptor*>* output) const {
+ std::vector<const FieldDescriptor*>* output) const {
for (map<int, Extension>::const_iterator iter = extensions_.begin();
iter != extensions_.end(); ++iter) {
bool has = false;
@@ -103,6 +129,11 @@ inline FieldDescriptor::CppType cpp_type(FieldType type) {
static_cast<FieldDescriptor::Type>(type));
}
+inline WireFormatLite::FieldType field_type(FieldType type) {
+ GOOGLE_DCHECK(type > 0 && type <= WireFormatLite::MAX_FIELD_TYPE);
+ return static_cast<WireFormatLite::FieldType>(type);
+}
+
#define GOOGLE_DCHECK_TYPE(EXTENSION, LABEL, CPPTYPE) \
GOOGLE_DCHECK_EQ((EXTENSION).is_repeated ? FieldDescriptor::LABEL_REPEATED \
: FieldDescriptor::LABEL_OPTIONAL, \
@@ -118,7 +149,12 @@ const MessageLite& ExtensionSet::GetMessage(int number,
return *factory->GetPrototype(message_type);
} else {
GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
- return *iter->second.message_value;
+ if (iter->second.is_lazy) {
+ return iter->second.lazymessage_value->GetMessage(
+ *factory->GetPrototype(message_type));
+ } else {
+ return *iter->second.message_value;
+ }
}
}
@@ -132,13 +168,41 @@ MessageLite* ExtensionSet::MutableMessage(const FieldDescriptor* descriptor,
extension->is_packed = false;
const MessageLite* prototype =
factory->GetPrototype(descriptor->message_type());
- GOOGLE_CHECK(prototype != NULL);
+ extension->is_lazy = false;
extension->message_value = prototype->New();
+ extension->is_cleared = false;
+ return extension->message_value;
} else {
GOOGLE_DCHECK_TYPE(*extension, OPTIONAL, MESSAGE);
+ extension->is_cleared = false;
+ if (extension->is_lazy) {
+ return extension->lazymessage_value->MutableMessage(
+ *factory->GetPrototype(descriptor->message_type()));
+ } else {
+ return extension->message_value;
+ }
+ }
+}
+
+MessageLite* ExtensionSet::ReleaseMessage(const FieldDescriptor* descriptor,
+ MessageFactory* factory) {
+ map<int, Extension>::iterator iter = extensions_.find(descriptor->number());
+ if (iter == extensions_.end()) {
+ // Not present. Return NULL.
+ return NULL;
+ } else {
+ GOOGLE_DCHECK_TYPE(iter->second, OPTIONAL, MESSAGE);
+ MessageLite* ret = NULL;
+ if (iter->second.is_lazy) {
+ ret = iter->second.lazymessage_value->ReleaseMessage(
+ *factory->GetPrototype(descriptor->message_type()));
+ delete iter->second.lazymessage_value;
+ } else {
+ ret = iter->second.message_value;
+ }
+ extensions_.erase(descriptor->number());
+ return ret;
}
- extension->is_cleared = false;
- return extension->message_value;
}
MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
@@ -157,7 +221,7 @@ MessageLite* ExtensionSet::AddMessage(const FieldDescriptor* descriptor,
// RepeatedPtrField<Message> does not know how to Add() since it cannot
// allocate an abstract object, so we have to be tricky.
MessageLite* result = extension->repeated_message_value
- ->AddFromCleared<internal::GenericTypeHandler<MessageLite> >();
+ ->AddFromCleared<GenericTypeHandler<MessageLite> >();
if (result == NULL) {
const MessageLite* prototype;
if (extension->repeated_message_value->size() == 0) {
@@ -220,7 +284,7 @@ bool ExtensionSet::ParseField(uint32 tag, io::CodedInputStream* input,
bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
const Message* containing_type,
UnknownFieldSet* unknown_fields) {
- UnknownFieldSetFieldSkipper skipper(unknown_fields);
+ MessageSetFieldSkipper skipper(unknown_fields);
if (input->GetExtensionPool() == NULL) {
GeneratedExtensionFinder finder(containing_type);
return ParseMessageSet(input, &finder, &skipper);
@@ -286,7 +350,11 @@ int ExtensionSet::Extension::SpaceUsedExcludingSelf() const {
StringSpaceUsedExcludingSelf(*string_value);
break;
case FieldDescriptor::CPPTYPE_MESSAGE:
- total_size += down_cast<Message*>(message_value)->SpaceUsed();
+ if (is_lazy) {
+ total_size += lazymessage_value->SpaceUsed();
+ } else {
+ total_size += down_cast<Message*>(message_value)->SpaceUsed();
+ }
break;
default:
// No extra storage costs for primitive types.
@@ -419,8 +487,15 @@ uint8* ExtensionSet::Extension::SerializeFieldWithCachedSizesToArray(
HANDLE_TYPE( BYTES, Bytes, *string_value);
HANDLE_TYPE( ENUM, Enum, enum_value);
HANDLE_TYPE( GROUP, Group, *message_value);
- HANDLE_TYPE( MESSAGE, Message, *message_value);
#undef HANDLE_TYPE
+ case FieldDescriptor::TYPE_MESSAGE:
+ if (is_lazy) {
+ target = lazymessage_value->WriteMessageToArray(number, target);
+ } else {
+ target = WireFormatLite::WriteMessageToArray(
+ number, *message_value, target);
+ }
+ break;
}
}
return target;
@@ -444,14 +519,216 @@ uint8* ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizesToArray(
target = WireFormatLite::WriteUInt32ToArray(
WireFormatLite::kMessageSetTypeIdNumber, number, target);
// Write message.
- target = WireFormatLite::WriteMessageToArray(
- WireFormatLite::kMessageSetMessageNumber, *message_value, target);
+ if (is_lazy) {
+ target = lazymessage_value->WriteMessageToArray(
+ WireFormatLite::kMessageSetMessageNumber, target);
+ } else {
+ target = WireFormatLite::WriteMessageToArray(
+ WireFormatLite::kMessageSetMessageNumber, *message_value, target);
+ }
// End group.
target = io::CodedOutputStream::WriteTagToArray(
WireFormatLite::kMessageSetItemEndTag, target);
return target;
}
+
+bool ExtensionSet::ParseFieldMaybeLazily(
+ int wire_type, int field_number, io::CodedInputStream* input,
+ ExtensionFinder* extension_finder,
+ MessageSetFieldSkipper* field_skipper) {
+ return ParseField(WireFormatLite::MakeTag(
+ field_number, static_cast<WireFormatLite::WireType>(wire_type)),
+ input, extension_finder, field_skipper);
+}
+
+bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
+ ExtensionFinder* extension_finder,
+ MessageSetFieldSkipper* field_skipper) {
+ while (true) {
+ const uint32 tag = input->ReadTag();
+ switch (tag) {
+ case 0:
+ return true;
+ case WireFormatLite::kMessageSetItemStartTag:
+ if (!ParseMessageSetItem(input, extension_finder, field_skipper)) {
+ return false;
+ }
+ break;
+ default:
+ if (!ParseField(tag, input, extension_finder, field_skipper)) {
+ return false;
+ }
+ break;
+ }
+ }
+}
+
+bool ExtensionSet::ParseMessageSet(io::CodedInputStream* input,
+ const MessageLite* containing_type) {
+ MessageSetFieldSkipper skipper(NULL);
+ GeneratedExtensionFinder finder(containing_type);
+ return ParseMessageSet(input, &finder, &skipper);
+}
+
+bool ExtensionSet::ParseMessageSetItem(io::CodedInputStream* input,
+ ExtensionFinder* extension_finder,
+ MessageSetFieldSkipper* field_skipper) {
+ // TODO(kenton): It would be nice to share code between this and
+ // WireFormatLite::ParseAndMergeMessageSetItem(), but I think the
+ // differences would be hard to factor out.
+
+ // This method parses a group which should contain two fields:
+ // required int32 type_id = 2;
+ // required data message = 3;
+
+ uint32 last_type_id = 0;
+
+ // If we see message data before the type_id, we'll append it to this so
+ // we can parse it later.
+ string message_data;
+
+ while (true) {
+ const uint32 tag = input->ReadTag();
+ if (tag == 0) return false;
+
+ switch (tag) {
+ case WireFormatLite::kMessageSetTypeIdTag: {
+ uint32 type_id;
+ if (!input->ReadVarint32(&type_id)) return false;
+ last_type_id = type_id;
+
+ if (!message_data.empty()) {
+ // We saw some message data before the type_id. Have to parse it
+ // now.
+ io::CodedInputStream sub_input(
+ reinterpret_cast<const uint8*>(message_data.data()),
+ message_data.size());
+ if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
+ last_type_id, &sub_input,
+ extension_finder, field_skipper)) {
+ return false;
+ }
+ message_data.clear();
+ }
+
+ break;
+ }
+
+ case WireFormatLite::kMessageSetMessageTag: {
+ if (last_type_id == 0) {
+ // We haven't seen a type_id yet. Append this data to message_data.
+ string temp;
+ uint32 length;
+ if (!input->ReadVarint32(&length)) return false;
+ if (!input->ReadString(&temp, length)) return false;
+ io::StringOutputStream output_stream(&message_data);
+ io::CodedOutputStream coded_output(&output_stream);
+ coded_output.WriteVarint32(length);
+ coded_output.WriteString(temp);
+ } else {
+ // Already saw type_id, so we can parse this directly.
+ if (!ParseFieldMaybeLazily(WireFormatLite::WIRETYPE_LENGTH_DELIMITED,
+ last_type_id, input,
+ extension_finder, field_skipper)) {
+ return false;
+ }
+ }
+
+ break;
+ }
+
+ case WireFormatLite::kMessageSetItemEndTag: {
+ return true;
+ }
+
+ default: {
+ if (!field_skipper->SkipField(input, tag)) return false;
+ }
+ }
+ }
+}
+
+void ExtensionSet::Extension::SerializeMessageSetItemWithCachedSizes(
+ int number,
+ io::CodedOutputStream* output) const {
+ if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
+ // Not a valid MessageSet extension, but serialize it the normal way.
+ SerializeFieldWithCachedSizes(number, output);
+ return;
+ }
+
+ if (is_cleared) return;
+
+ // Start group.
+ output->WriteTag(WireFormatLite::kMessageSetItemStartTag);
+
+ // Write type ID.
+ WireFormatLite::WriteUInt32(WireFormatLite::kMessageSetTypeIdNumber,
+ number,
+ output);
+ // Write message.
+ if (is_lazy) {
+ lazymessage_value->WriteMessage(
+ WireFormatLite::kMessageSetMessageNumber, output);
+ } else {
+ WireFormatLite::WriteMessageMaybeToArray(
+ WireFormatLite::kMessageSetMessageNumber,
+ *message_value,
+ output);
+ }
+
+ // End group.
+ output->WriteTag(WireFormatLite::kMessageSetItemEndTag);
+}
+
+int ExtensionSet::Extension::MessageSetItemByteSize(int number) const {
+ if (type != WireFormatLite::TYPE_MESSAGE || is_repeated) {
+ // Not a valid MessageSet extension, but compute the byte size for it the
+ // normal way.
+ return ByteSize(number);
+ }
+
+ if (is_cleared) return 0;
+
+ int our_size = WireFormatLite::kMessageSetItemTagsSize;
+
+ // type_id
+ our_size += io::CodedOutputStream::VarintSize32(number);
+
+ // message
+ int message_size = 0;
+ if (is_lazy) {
+ message_size = lazymessage_value->ByteSize();
+ } else {
+ message_size = message_value->ByteSize();
+ }
+
+ our_size += io::CodedOutputStream::VarintSize32(message_size);
+ our_size += message_size;
+
+ return our_size;
+}
+
+void ExtensionSet::SerializeMessageSetWithCachedSizes(
+ io::CodedOutputStream* output) const {
+ for (map<int, Extension>::const_iterator iter = extensions_.begin();
+ iter != extensions_.end(); ++iter) {
+ iter->second.SerializeMessageSetItemWithCachedSizes(iter->first, output);
+ }
+}
+
+int ExtensionSet::MessageSetByteSize() const {
+ int total_size = 0;
+
+ for (map<int, Extension>::const_iterator iter = extensions_.begin();
+ iter != extensions_.end(); ++iter) {
+ total_size += iter->second.MessageSetItemByteSize(iter->first);
+ }
+
+ return total_size;
+}
+
} // namespace internal
} // namespace protobuf
} // namespace google