blob: 363052c9c8b65be208eb50d5203ef37c0faff433 [file] [log] [blame]
/**
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.hadoop.io.serial.lib.protobuf;
import java.io.IOException;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import org.apache.hadoop.io.serial.RawComparator;
import com.google.protobuf.ByteString;
import com.google.protobuf.CodedInputStream;
import com.google.protobuf.Descriptors.Descriptor;
import com.google.protobuf.Descriptors.EnumValueDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor;
import com.google.protobuf.Descriptors.FieldDescriptor.Type;
import com.google.protobuf.Message;
/**
 * A {@link RawComparator} that orders serialized protocol buffer messages
 * without materializing them. The fields of the message are visited in
 * ascending field number order; a field that is absent from the wire is
 * compared using the default value from its descriptor, and a field whose
 * wire type does not match its declared type is treated as absent. Group
 * fields are always ignored.
 */
public class ProtoBufComparator implements RawComparator {
  static final int WIRETYPE_VARINT = 0;
  static final int WIRETYPE_FIXED64 = 1;
  static final int WIRETYPE_LENGTH_DELIMITED = 2;
  static final int WIRETYPE_START_GROUP = 3;
  static final int WIRETYPE_END_GROUP = 4;
  static final int WIRETYPE_FIXED32 = 5;

  /** Number of low bits of a tag that hold the wire type. */
  static final int TAG_TYPE_BITS = 3;
  /** Mask that extracts the wire type from a tag. */
  static final int TAG_TYPE_MASK = (1 << TAG_TYPE_BITS) - 1;

  /** The fields of every reachable message type, sorted by field id. */
  private final Map<Descriptor, List<FieldDescriptor>> fieldCache =
    new HashMap<Descriptor, List<FieldDescriptor>>();
  /** The sorted fields of the top level message type. */
  private final List<FieldDescriptor> topFields;

  /**
   * Create a comparator that will compare serialized messages of a particular
   * class.
   * @param cls the specific class to compare
   * @throws IllegalArgumentException if cls is not a protobuf message class
   *         or its descriptor can't be obtained through the static
   *         getDescriptor method
   */
  public ProtoBufComparator(Class<? extends Message> cls) {
    if (!Message.class.isAssignableFrom(cls)) {
      throw new IllegalArgumentException("Type " + cls +
                                         " must be a generated protobuf class");
    }
    try {
      Method getDescriptor = cls.getDeclaredMethod("getDescriptor");
      topFields = addToCache((Descriptor) getDescriptor.invoke(null));
    } catch (Exception e) {
      throw new IllegalArgumentException("Can't get descriptors for " + cls, e);
    }
  }

  /**
   * Define a comparator so that we can sort the fields by their field ids.
   */
  private static class FieldIdComparator implements Comparator<FieldDescriptor>{
    @Override
    public int compare(FieldDescriptor left, FieldDescriptor right) {
      int leftId = left.getNumber();
      int rightId = right.getNumber();
      if (leftId == rightId) {
        return 0;
      } else {
        return leftId < rightId ? -1 : 1;
      }
    }
  }

  private static final FieldIdComparator FIELD_COMPARE= new FieldIdComparator();

  /**
   * Add the given type and every message type reachable from it to the
   * cache. Both the types declared inside the given type and the types of
   * its message-valued fields are followed, so that a field whose message
   * type is declared elsewhere can still be compared. The fields are sorted
   * by field id.
   * @param type the type to add
   * @return the list of sorted fields for the given type
   */
  private List<FieldDescriptor> addToCache(Descriptor type) {
    List<FieldDescriptor> fields =
      new ArrayList<FieldDescriptor>(type.getFields());
    Collections.sort(fields, FIELD_COMPARE);
    // register the type before recursing so self-referential types terminate
    fieldCache.put(type, fields);
    for (Descriptor subMessage: type.getNestedTypes()) {
      if (!fieldCache.containsKey(subMessage)) {
        addToCache(subMessage);
      }
    }
    // message fields may refer to types that are not nested under this one
    for (FieldDescriptor field: fields) {
      if (field.getType() == Type.MESSAGE) {
        Descriptor fieldMessage = field.getMessageType();
        if (!fieldCache.containsKey(fieldMessage)) {
          addToCache(fieldMessage);
        }
      }
    }
    return fields;
  }

  /**
   * Compare two serialized protocol buffers using the natural sort order.
   * @param b1 the left serialized value
   * @param s1 the first byte index in b1 to compare
   * @param l1 the number of bytes in b1 to compare
   * @param b2 the right serialized value
   * @param s2 the first byte index in b2 to compare
   * @param l2 the number of bytes in b2 to compare
   * @return a negative, zero, or positive value if the left message is less,
   *         equal or greater than the right message
   * @throws IllegalArgumentException if either message can't be read
   */
  @Override
  public int compare(byte[] b1, int s1, int l1, byte[] b2, int s2, int l2) {
    CodedInputStream left = CodedInputStream.newInstance(b1,s1,l1);
    CodedInputStream right = CodedInputStream.newInstance(b2, s2, l2);
    try {
      return compareMessage(left, right, topFields);
    } catch (IOException ie) {
      throw new IllegalArgumentException("problem running compare", ie);
    }
  }

  /**
   * Advance the stream to the given fieldId or one that is larger. Fields
   * with smaller ids are skipped. The value of the field whose tag is
   * returned has not been read yet.
   * @param stream the stream to read
   * @param currentTag the last tag that was read from this stream
   * @param fieldId the id of the field we are looking for
   * @return the last tag that was read or 0 for end of stream
   * @throws IOException
   */
  private int advanceTo(CodedInputStream stream,
                        int currentTag,
                        int fieldId) throws IOException {
    // Compare field numbers rather than raw tags: the tag of a varint field
    // is exactly fieldId << TAG_TYPE_BITS, so a strict comparison of tags
    // would step over varint encoded fields.
    // if we've already seen the right tag, return it
    if ((currentTag >>> TAG_TYPE_BITS) >= fieldId) {
      return currentTag;
    }
    while (!stream.isAtEnd()) {
      currentTag = stream.readTag();
      if ((currentTag >>> TAG_TYPE_BITS) >= fieldId) {
        return currentTag;
      } else {
        stream.skipField(currentTag);
      }
    }
    return 0;
  }

  /**
   * Check compatibility between the logical type in the schema and the
   * wiretype. Incompatible fields are ignored. Numeric types accept both
   * their scalar wire type and the length delimited (packed) wire type.
   * @param tag the tag (id and wiretype) of the field
   * @param type the logical type of the field
   * @return true if we should use this field for comparing
   */
  private boolean isCompatible(int tag, Type type) {
    int wiretype = tag & TAG_TYPE_MASK;
    switch (type) {
    case BOOL:
    case ENUM:
    case INT32:
    case INT64:
    case SINT32:
    case SINT64:
    case UINT32:
    case UINT64:
      return wiretype == WIRETYPE_VARINT ||
             wiretype == WIRETYPE_LENGTH_DELIMITED;
    case BYTES:
    case MESSAGE:
    case STRING:
      return wiretype == WIRETYPE_LENGTH_DELIMITED;
    case FLOAT:
    case FIXED32:
    case SFIXED32:
      return wiretype == WIRETYPE_LENGTH_DELIMITED ||
             wiretype == WIRETYPE_FIXED32;
    case DOUBLE:
    case SFIXED64:
    case FIXED64:
      return wiretype == WIRETYPE_LENGTH_DELIMITED ||
             wiretype == WIRETYPE_FIXED64;
    case GROUP:
      // don't bother dealing with groups, since they aren't used outside of
      // the protobuf mothership.
      return false;
    default:
      throw new IllegalArgumentException("Unknown field type " + type);
    }
  }

  /**
   * Compare two serialized messages of the same type. The fields are
   * compared in the order given and the first difference decides the result.
   * @param left the left message
   * @param right the right message
   * @param fields the fields of the message sorted by field id
   * @return -1, 0, or 1 if left is less, equal or greater than right
   * @throws IOException
   */
  private int compareMessage(CodedInputStream left, CodedInputStream right,
                             List<FieldDescriptor> fields
                             ) throws IOException {
    int leftTag = 0;
    int rightTag = 0;
    for (FieldDescriptor field: fields) {
      int fieldId = field.getNumber();
      Type fieldType = field.getType();
      leftTag = advanceTo(left, leftTag, fieldId);
      rightTag = advanceTo(right, rightTag, fieldId);
      boolean leftDefault = (leftTag >>> TAG_TYPE_BITS) != fieldId;
      boolean rightDefault = (rightTag >>> TAG_TYPE_BITS) != fieldId;
      // if the fieldType and wiretypes aren't compatible, skip field
      if (!leftDefault && !isCompatible(leftTag, fieldType)) {
        leftDefault = true;
        left.skipField(leftTag);
      }
      if (!rightDefault && !isCompatible(rightTag, fieldType)) {
        rightDefault = true;
        // each stream must be skipped with the tag that was read from it
        right.skipField(rightTag);
      }
      // nothing to compare if neither side has the field
      if (leftDefault && rightDefault) {
        continue;
      }
      // ensure both sides use the same representation
      if (!leftDefault && !rightDefault && leftTag != rightTag) {
        return leftTag < rightTag ? -1 : 1;
      }
      // The default value of a repeated field is an empty list, which can't
      // be cast to a scalar, so the side without any elements sorts first.
      if (leftDefault != rightDefault && field.isRepeated()) {
        return leftDefault ? -1 : 1;
      }
      // take the encoding from whichever side actually has the field
      int wireFormat = (leftDefault ? rightTag : leftTag) & TAG_TYPE_MASK;
      int result;
      switch (wireFormat) {
      case WIRETYPE_LENGTH_DELIMITED:
        switch (fieldType) {
        case STRING: {
          String leftStr =
            leftDefault ? (String) field.getDefaultValue() : left.readString();
          String rightStr =
            rightDefault ? (String) field.getDefaultValue() :right.readString();
          result = leftStr.compareTo(rightStr);
          if (result != 0) {
            return result;
          }
          break;
        }
        case BYTES: {
          ByteString leftBytes = leftDefault ?
            (ByteString) field.getDefaultValue() : left.readBytes();
          ByteString rightBytes = rightDefault ?
            (ByteString) field.getDefaultValue() : right.readBytes();
          result = compareBytes(leftBytes, rightBytes);
          if (result != 0) {
            return result;
          }
          break;
        }
        default: {
          // handle nested messages and packed fields
          if (leftDefault) {
            return -1;
          } else if (rightDefault) {
            return 1;
          }
          int leftLimit = left.readRawVarint32();
          int rightLimit = right.readRawVarint32();
          int oldLeft = left.pushLimit(leftLimit);
          int oldRight = right.pushLimit(rightLimit);
          // NOTE(review): for a nested message compareMessage may return 0
          // without consuming trailing unknown fields; confirm that this
          // loop always makes progress on such input.
          while (left.getBytesUntilLimit() > 0 &&
                 right.getBytesUntilLimit() > 0) {
            result = compareField(field, left, right, false, false);
            if (result != 0) {
              return result;
            }
          }
          // when one side is a prefix of the other, the shorter one is less
          if (right.getBytesUntilLimit() > 0) {
            return -1;
          } else if (left.getBytesUntilLimit() > 0) {
            return 1;
          }
          left.popLimit(oldLeft);
          right.popLimit(oldRight);
        }
        }
        break;
      case WIRETYPE_VARINT:
      case WIRETYPE_FIXED32:
      case WIRETYPE_FIXED64:
        result = compareField(field, left, right, leftDefault, rightDefault);
        if (result != 0) {
          return result;
        }
        break;
      default:
        throw new IllegalArgumentException("Unknown field encoding " +
                                           wireFormat);
      }
    }
    return 0;
  }

  /**
   * Compare a single field inside of a message. This is used for both packed
   * and unpacked fields. It assumes the wire type has already been checked.
   * @param field the type of the field that we are comparing
   * @param left the left value
   * @param right the right value
   * @param leftDefault use the default value instead of the left value
   * @param rightDefault use the default value instead of the right value
   * @return -1, 0, 1 depending on whether left is less, equal or greater than
   *         right
   * @throws IOException
   */
  private int compareField(FieldDescriptor field,
                           CodedInputStream left,
                           CodedInputStream right,
                           boolean leftDefault,
                           boolean rightDefault) throws IOException {
    switch (field.getType()) {
    case BOOL:
      boolean leftBool = leftDefault ?
        (Boolean) field.getDefaultValue() : left.readBool();
      boolean rightBool = rightDefault ?
        (Boolean) field.getDefaultValue() : right.readBool();
      if (leftBool == rightBool) {
        return 0;
      } else {
        // false sorts before true
        return rightBool ? -1 : 1;
      }
    case DOUBLE:
      return
        Double.compare(leftDefault ?
                       (Double) field.getDefaultValue(): left.readDouble(),
                       rightDefault ?
                       (Double) field.getDefaultValue() :right.readDouble());
    case ENUM:
      return compareInt32(leftDefault ? intDefault(field) : left.readEnum(),
                          rightDefault ? intDefault(field) : right.readEnum());
    case FIXED32:
      return compareUInt32(leftDefault ? intDefault(field) : left.readFixed32(),
                           rightDefault?intDefault(field):right.readFixed32());
    case FIXED64:
      return compareUInt64(leftDefault? longDefault(field) : left.readFixed64(),
                           rightDefault?longDefault(field):right.readFixed64());
    case FLOAT:
      return Float.compare(leftDefault ?
                           (Float) field.getDefaultValue() : left.readFloat(),
                           rightDefault ?
                           (Float) field.getDefaultValue():right.readFloat());
    case INT32:
      return compareInt32(leftDefault?intDefault(field):left.readInt32(),
                          rightDefault?intDefault(field):right.readInt32());
    case INT64:
      return compareInt64(leftDefault?longDefault(field):left.readInt64(),
                          rightDefault?longDefault(field):right.readInt64());
    case MESSAGE:
      return compareMessage(left, right,
                            fieldCache.get(field.getMessageType()));
    case SFIXED32:
      return compareInt32(leftDefault ?intDefault(field):left.readSFixed32(),
                          rightDefault ?intDefault(field):right.readSFixed32());
    case SFIXED64:
      return compareInt64(leftDefault ? longDefault(field) :left.readSFixed64(),
                          rightDefault?longDefault(field):right.readSFixed64());
    case SINT32:
      return compareInt32(leftDefault?intDefault(field):left.readSInt32(),
                          rightDefault?intDefault(field):right.readSInt32());
    case SINT64:
      return compareInt64(leftDefault ? longDefault(field) : left.readSInt64(),
                          rightDefault?longDefault(field):right.readSInt64());
    case UINT32:
      return compareUInt32(leftDefault ? intDefault(field) :left.readUInt32(),
                           rightDefault ? intDefault(field):right.readUInt32());
    case UINT64:
      return compareUInt64(leftDefault ? longDefault(field) :left.readUInt64(),
                           rightDefault?longDefault(field) :right.readUInt64());
    default:
      throw new IllegalArgumentException("unknown field type " + field);
    }
  }

  /**
   * Compare 32 bit signed integers.
   * @param left
   * @param right
   * @return -1, 0 ,1 if left is less, equal, or greater to right
   */
  private static int compareInt32(int left, int right) {
    if (left == right) {
      return 0;
    } else {
      return left < right ? -1 : 1;
    }
  }

  /**
   * Compare 64 bit signed integers.
   * @param left
   * @param right
   * @return -1, 0 ,1 if left is less, equal, or greater to right
   */
  private static int compareInt64(long left, long right) {
    if (left == right) {
      return 0;
    } else {
      return left < right ? -1 : 1;
    }
  }

  /**
   * Compare 32 bit logically unsigned integers.
   * @param left
   * @param right
   * @return -1, 0 ,1 if left is less, equal, or greater to right
   */
  private static int compareUInt32(int left, int right) {
    if (left == right) {
      return 0;
    } else {
      // flipping the sign bit maps the unsigned order onto the signed order
      return left + Integer.MIN_VALUE < right + Integer.MIN_VALUE ? -1 : 1;
    }
  }

  /**
   * Compare two byte strings using memcmp semantics
   * @param left
   * @param right
   * @return -1, 0, 1 if left is less, equal, or greater than right
   */
  private static int compareBytes(ByteString left, ByteString right) {
    int size = Math.min(left.size(), right.size());
    for (int i = 0; i < size; ++i) {
      // compare the bytes as unsigned values
      int leftByte = left.byteAt(i) & 0xff;
      int rightByte = right.byteAt(i) & 0xff;
      if (leftByte != rightByte) {
        return leftByte < rightByte ? -1 : 1;
      }
    }
    // with a common prefix, the shorter byte string is less
    if (left.size() != right.size()) {
      return left.size() < right.size() ? -1 : 1;
    }
    return 0;
  }

  /**
   * Compare 64 bit logically unsigned integers.
   * @param left
   * @param right
   * @return -1, 0 ,1 if left is less, equal, or greater to right
   */
  private static int compareUInt64(long left, long right) {
    if (left == right) {
      return 0;
    } else {
      // flipping the sign bit maps the unsigned order onto the signed order
      return left + Long.MIN_VALUE < right + Long.MIN_VALUE ? -1 : 1;
    }
  }

  /**
   * Get the integer default, including dereferencing enum values.
   * @param field the field
   * @return the integer default value
   */
  private static int intDefault(FieldDescriptor field) {
    if (field.getType() == Type.ENUM) {
      return ((EnumValueDescriptor) field.getDefaultValue()).getNumber();
    } else {
      return (Integer) field.getDefaultValue();
    }
  }

  /**
   * Get the long default value for the given field.
   * @param field the field
   * @return the default value
   */
  private static long longDefault(FieldDescriptor field) {
    return (Long) field.getDefaultValue();
  }
}