diff options
4 files changed, 183 insertions, 50 deletions
diff --git a/java/src/test/java/com/google/protobuf/NanoTest.java b/java/src/test/java/com/google/protobuf/NanoTest.java index bb95a17..687bc16 100644 --- a/java/src/test/java/com/google/protobuf/NanoTest.java +++ b/java/src/test/java/com/google/protobuf/NanoTest.java @@ -2886,13 +2886,6 @@ public class NanoTest extends TestCase { TestAllTypesNano.BAR, TestAllTypesNano.BAZ }; - // We set the _nan fields to something other than nan, because equality - // is defined for nan such that Float.NaN != Float.NaN, which makes any - // instance of TestAllTypesNano unequal to any other instance unless - // these fields are set. This is also the behavior of the regular java - // generator when the value of a field is NaN. - message.defaultFloatNan = 1.0f; - message.defaultDoubleNan = 1.0; return message; } @@ -2915,7 +2908,6 @@ public class NanoTest extends TestCase { TestAllTypesNano.BAR, TestAllTypesNano.BAZ }; - message.defaultFloatNan = 1.0f; return message; } @@ -2924,8 +2916,7 @@ public class NanoTest extends TestCase { .setOptionalInt32(5) .setOptionalString("Hello") .setOptionalBytes(new byte[] {1, 2, 3}) - .setOptionalNestedEnum(TestNanoAccessors.BAR) - .setDefaultFloatNan(1.0f); + .setOptionalNestedEnum(TestNanoAccessors.BAR); message.optionalNestedMessage = new TestNanoAccessors.NestedMessage().setBb(27); message.repeatedInt32 = new int[] { 5, 6, 7, 8 }; message.repeatedString = new String[] { "One", "Two" }; @@ -2973,6 +2964,126 @@ public class NanoTest extends TestCase { return message; } + public void testEqualsWithSpecialFloatingPointValues() throws Exception { + // Checks that the nano implementation complies with Object.equals() when treating + // floating point numbers, i.e. NaN == NaN and +0.0 != -0.0. + // This test assumes that the generated equals() implementations are symmetric, so + // there will only be one direction for each equality check. + + TestAllTypesNano m1 = new TestAllTypesNano(); + m1.optionalFloat = Float.NaN; + m1.optionalDouble = Double.NaN; + TestAllTypesNano m2 = new TestAllTypesNano(); + m2.optionalFloat = Float.NaN; + m2.optionalDouble = Double.NaN; + assertTrue(m1.equals(m2)); + assertTrue(m1.equals( + MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1)))); + + m1.optionalFloat = +0f; + m2.optionalFloat = -0f; + assertFalse(m1.equals(m2)); + + m1.optionalFloat = -0f; + m1.optionalDouble = +0d; + m2.optionalDouble = -0d; + assertFalse(m1.equals(m2)); + + m1.optionalDouble = -0d; + assertTrue(m1.equals(m2)); + assertFalse(m1.equals(new TestAllTypesNano())); // -0 does not equals() the default +0 + assertTrue(m1.equals( + MessageNano.mergeFrom(new TestAllTypesNano(), MessageNano.toByteArray(m1)))); + + // ------- + + TestAllTypesNanoHas m3 = new TestAllTypesNanoHas(); + m3.optionalFloat = Float.NaN; + m3.hasOptionalFloat = true; + m3.optionalDouble = Double.NaN; + m3.hasOptionalDouble = true; + TestAllTypesNanoHas m4 = new TestAllTypesNanoHas(); + m4.optionalFloat = Float.NaN; + m4.hasOptionalFloat = true; + m4.optionalDouble = Double.NaN; + m4.hasOptionalDouble = true; + assertTrue(m3.equals(m4)); + assertTrue(m3.equals( + MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3)))); + + m3.optionalFloat = +0f; + m4.optionalFloat = -0f; + assertFalse(m3.equals(m4)); + + m3.optionalFloat = -0f; + m3.optionalDouble = +0d; + m4.optionalDouble = -0d; + assertFalse(m3.equals(m4)); + + m3.optionalDouble = -0d; + m3.hasOptionalFloat = false; // -0 does not equals() the default +0, + m3.hasOptionalDouble = false; // so these incorrect 'has' flags should be disregarded. + assertTrue(m3.equals(m4)); // note: m4 has the 'has' flags set. + assertFalse(m3.equals(new TestAllTypesNanoHas())); // note: the new message has +0 defaults + assertTrue(m3.equals( + MessageNano.mergeFrom(new TestAllTypesNanoHas(), MessageNano.toByteArray(m3)))); + // note: the deserialized message has the 'has' flags set. + + // ------- + + TestNanoAccessors m5 = new TestNanoAccessors(); + m5.setOptionalFloat(Float.NaN); + m5.setOptionalDouble(Double.NaN); + TestNanoAccessors m6 = new TestNanoAccessors(); + m6.setOptionalFloat(Float.NaN); + m6.setOptionalDouble(Double.NaN); + assertTrue(m5.equals(m6)); + assertTrue(m5.equals( + MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6)))); + + m5.setOptionalFloat(+0f); + m6.setOptionalFloat(-0f); + assertFalse(m5.equals(m6)); + + m5.setOptionalFloat(-0f); + m5.setOptionalDouble(+0d); + m6.setOptionalDouble(-0d); + assertFalse(m5.equals(m6)); + + m5.setOptionalDouble(-0d); + assertTrue(m5.equals(m6)); + assertFalse(m5.equals(new TestNanoAccessors())); + assertTrue(m5.equals( + MessageNano.mergeFrom(new TestNanoAccessors(), MessageNano.toByteArray(m6)))); + + // ------- + + NanoReferenceTypes.TestAllTypesNano m7 = new NanoReferenceTypes.TestAllTypesNano(); + m7.optionalFloat = Float.NaN; + m7.optionalDouble = Double.NaN; + NanoReferenceTypes.TestAllTypesNano m8 = new NanoReferenceTypes.TestAllTypesNano(); + m8.optionalFloat = Float.NaN; + m8.optionalDouble = Double.NaN; + assertTrue(m7.equals(m8)); + assertTrue(m7.equals(MessageNano.mergeFrom( + new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7)))); + + m7.optionalFloat = +0f; + m8.optionalFloat = -0f; + assertFalse(m7.equals(m8)); + + m7.optionalFloat = -0f; + m7.optionalDouble = +0d; + m8.optionalDouble = -0d; + assertFalse(m7.equals(m8)); + + m7.optionalDouble = -0d; + assertTrue(m7.equals(m8)); + assertFalse(m7.equals(new NanoReferenceTypes.TestAllTypesNano())); + assertTrue(m7.equals(MessageNano.mergeFrom( + new NanoReferenceTypes.TestAllTypesNano(), MessageNano.toByteArray(m7)))); + } + public void testNullRepeatedFields() throws Exception { // Check that serialization after explicitly setting a repeated field // to null doesn't NPE. diff --git a/src/google/protobuf/compiler/javanano/javanano_primitive_field.cc b/src/google/protobuf/compiler/javanano/javanano_primitive_field.cc index b6c98b4..9898aaf 100644 --- a/src/google/protobuf/compiler/javanano/javanano_primitive_field.cc +++ b/src/google/protobuf/compiler/javanano/javanano_primitive_field.cc @@ -175,38 +175,6 @@ int FixedSize(FieldDescriptor::Type type) { return -1; } -// Returns true if the field has a default value equal to NaN. -bool IsDefaultNaN(const FieldDescriptor* field) { - switch (field->type()) { - case FieldDescriptor::TYPE_INT32 : return false; - case FieldDescriptor::TYPE_UINT32 : return false; - case FieldDescriptor::TYPE_SINT32 : return false; - case FieldDescriptor::TYPE_FIXED32 : return false; - case FieldDescriptor::TYPE_SFIXED32: return false; - case FieldDescriptor::TYPE_INT64 : return false; - case FieldDescriptor::TYPE_UINT64 : return false; - case FieldDescriptor::TYPE_SINT64 : return false; - case FieldDescriptor::TYPE_FIXED64 : return false; - case FieldDescriptor::TYPE_SFIXED64: return false; - case FieldDescriptor::TYPE_FLOAT : - return isnan(field->default_value_float()); - case FieldDescriptor::TYPE_DOUBLE : - return isnan(field->default_value_double()); - case FieldDescriptor::TYPE_BOOL : return false; - case FieldDescriptor::TYPE_STRING : return false; - case FieldDescriptor::TYPE_BYTES : return false; - case FieldDescriptor::TYPE_ENUM : return false; - case FieldDescriptor::TYPE_GROUP : return false; - case FieldDescriptor::TYPE_MESSAGE : return false; - - // No default because we want the compiler to complain if any new - // types are added. - } - - GOOGLE_LOG(FATAL) << "Can't get here."; - return false; -} - // Return true if the type is a that has variable length // for instance String's. bool IsVariableLenType(JavaType type) { @@ -384,15 +352,21 @@ GenerateSerializationConditional(io::Printer* printer) const { printer->Print(variables_, "if ("); } - if (IsArrayType(GetJavaType(descriptor_))) { + JavaType java_type = GetJavaType(descriptor_); + if (IsArrayType(java_type)) { printer->Print(variables_, "!java.util.Arrays.equals(this.$name$, $default$)) {\n"); - } else if (IsReferenceType(GetJavaType(descriptor_))) { + } else if (IsReferenceType(java_type)) { printer->Print(variables_, "!this.$name$.equals($default$)) {\n"); - } else if (IsDefaultNaN(descriptor_)) { + } else if (java_type == JAVATYPE_FLOAT) { printer->Print(variables_, - "!$capitalized_type$.isNaN(this.$name$)) {\n"); + "java.lang.Float.floatToIntBits(this.$name$)\n" + " != java.lang.Float.floatToIntBits($default$)) {\n"); + } else if (java_type == JAVATYPE_DOUBLE) { + printer->Print(variables_, + "java.lang.Double.doubleToLongBits(this.$name$)\n" + " != java.lang.Double.doubleToLongBits($default$)) {\n"); } else { printer->Print(variables_, "this.$name$ != $default$) {\n"); @@ -464,6 +438,36 @@ GenerateEqualsCode(io::Printer* printer) const { printer->Print(") {\n" " return false;\n" "}\n"); + } else if (java_type == JAVATYPE_FLOAT) { + printer->Print(variables_, + "{\n" + " int bits = java.lang.Float.floatToIntBits(this.$name$);\n" + " if (bits != java.lang.Float.floatToIntBits(other.$name$)"); + if (params_.generate_has()) { + printer->Print(variables_, + "\n" + " || (bits == java.lang.Float.floatToIntBits($default$)\n" + " && this.has$capitalized_name$ != other.has$capitalized_name$)"); + } + printer->Print(") {\n" + " return false;\n" + " }\n" + "}\n"); + } else if (java_type == JAVATYPE_DOUBLE) { + printer->Print(variables_, + "{\n" + " long bits = java.lang.Double.doubleToLongBits(this.$name$);\n" + " if (bits != java.lang.Double.doubleToLongBits(other.$name$)"); + if (params_.generate_has()) { + printer->Print(variables_, + "\n" + " || (bits == java.lang.Double.doubleToLongBits($default$)\n" + " && this.has$capitalized_name$ != other.has$capitalized_name$)"); + } + printer->Print(") {\n" + " return false;\n" + " }\n" + "}\n"); } else { printer->Print(variables_, "if (this.$name$ != other.$name$"); @@ -623,12 +627,26 @@ GenerateSerializedSizeCode(io::Printer* printer) const { void AccessorPrimitiveFieldGenerator:: GenerateEqualsCode(io::Printer* printer) const { switch (GetJavaType(descriptor_)) { - // For all Java primitive types below, the hash codes match the - // results of BoxedType.valueOf(primitiveValue).hashCode(). - case JAVATYPE_INT: - case JAVATYPE_LONG: + // For all Java primitive types below, the equality checks match the + // results of BoxedType.valueOf(primitiveValue).equals(otherValue). case JAVATYPE_FLOAT: + printer->Print(variables_, + "if ($different_has$\n" + " || java.lang.Float.floatToIntBits($name$_)\n" + " != java.lang.Float.floatToIntBits(other.$name$_)) {\n" + " return false;\n" + "}\n"); + break; case JAVATYPE_DOUBLE: + printer->Print(variables_, + "if ($different_has$\n" + " || java.lang.Double.doubleToLongBits($name$_)\n" + " != java.lang.Double.doubleToLongBits(other.$name$_)) {\n" + " return false;\n" + "}\n"); + break; + case JAVATYPE_INT: + case JAVATYPE_LONG: case JAVATYPE_BOOLEAN: printer->Print(variables_, "if ($different_has$\n" diff --git a/src/google/protobuf/unittest_accessors_nano.proto b/src/google/protobuf/unittest_accessors_nano.proto index 875af25..f1d4d34 100644 --- a/src/google/protobuf/unittest_accessors_nano.proto +++ b/src/google/protobuf/unittest_accessors_nano.proto @@ -49,6 +49,8 @@ message TestNanoAccessors { // Singular optional int32 optional_int32 = 1; + optional float optional_float = 11; + optional double optional_double = 12; optional string optional_string = 14; optional bytes optional_bytes = 15; diff --git a/src/google/protobuf/unittest_has_nano.proto b/src/google/protobuf/unittest_has_nano.proto index 61800f2..289d08a 100644 --- a/src/google/protobuf/unittest_has_nano.proto +++ b/src/google/protobuf/unittest_has_nano.proto @@ -49,6 +49,8 @@ message TestAllTypesNanoHas { // Singular optional int32 optional_int32 = 1; + optional float optional_float = 11; + optional double optional_double = 12; optional string optional_string = 14; optional bytes optional_bytes = 15; |