// blob: 672fb399d3afff6f0de105aa7632ebe3443db8fa [file] [log] [blame]
package edu.uci.ics.asterix.runtime.evaluators.common;
import java.io.DataOutput;
import java.io.IOException;
import java.util.Arrays;
import edu.uci.ics.asterix.formats.nontagged.AqlBinaryComparatorFactoryProvider;
import edu.uci.ics.asterix.formats.nontagged.AqlBinaryHashFunctionFactoryProvider;
import edu.uci.ics.asterix.formats.nontagged.AqlSerializerDeserializerProvider;
import edu.uci.ics.asterix.om.base.AFloat;
import edu.uci.ics.asterix.om.base.AMutableFloat;
import edu.uci.ics.asterix.om.types.ATypeTag;
import edu.uci.ics.asterix.om.types.BuiltinType;
import edu.uci.ics.asterix.om.types.EnumDeserializer;
import edu.uci.ics.asterix.runtime.evaluators.functions.BinaryHashMap;
import edu.uci.ics.asterix.runtime.evaluators.functions.BinaryHashMap.BinaryEntry;
import edu.uci.ics.hyracks.algebricks.common.exceptions.AlgebricksException;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluator;
import edu.uci.ics.hyracks.algebricks.runtime.base.ICopyEvaluatorFactory;
import edu.uci.ics.hyracks.api.dataflow.value.IBinaryComparator;
import edu.uci.ics.hyracks.api.dataflow.value.IBinaryHashFunction;
import edu.uci.ics.hyracks.api.dataflow.value.ISerializerDeserializer;
import edu.uci.ics.hyracks.data.std.primitive.IntegerPointable;
import edu.uci.ics.hyracks.data.std.primitive.UTF8StringPointable;
import edu.uci.ics.hyracks.dataflow.common.data.accessors.ArrayBackedValueStorage;
import edu.uci.ics.hyracks.dataflow.common.data.accessors.IDataOutputProvider;
import edu.uci.ics.hyracks.dataflow.common.data.accessors.IFrameTupleReference;
/**
 * Evaluates the Jaccard similarity of two list arguments (ordered or
 * unordered) and writes the result as a serialized {@link AFloat}.
 *
 * <p>
 * Lists are treated as multisets: the intersection size is the sum, over all
 * distinct items, of the minimum of the item's occurrence counts in the two
 * lists, and the union size is {@code |first| + |second| - |intersection|}.
 * The counts are maintained in a {@link BinaryHashMap} whose value is a pair
 * of 4-byte integers (build-list count, probe-list count).
 *
 * <p>
 * If either argument is NULL, or either list is empty, the result is
 * {@code 0.0f}.
 */
public class SimilarityJaccardEvaluator implements ICopyEvaluator {

    // Parameters for hash table.
    protected static final int TABLE_SIZE = 100;
    protected static final int TABLE_FRAME_SIZE = 32768;

    // Assuming type indicator in serde format.
    protected static final int TYPE_INDICATOR_SIZE = 1;

    protected final DataOutput out;
    // Both argument evaluators write into this shared storage, one after the other.
    protected final ArrayBackedValueStorage argOut = new ArrayBackedValueStorage();
    protected final ICopyEvaluator firstOrdListEval;
    protected final ICopyEvaluator secondOrdListEval;

    protected final AsterixOrderedListIterator fstOrdListIter = new AsterixOrderedListIterator();
    protected final AsterixOrderedListIterator sndOrdListIter = new AsterixOrderedListIterator();
    protected final AsterixUnorderedListIterator fstUnordListIter = new AsterixUnorderedListIterator();
    protected final AsterixUnorderedListIterator sndUnordListIter = new AsterixUnorderedListIterator();

    // Point at one of the iterators above, chosen per tuple by checkArgTypes().
    protected AbstractAsterixListIterator firstListIter;
    protected AbstractAsterixListIterator secondListIter;

    protected final AMutableFloat aFloat = new AMutableFloat(0);
    @SuppressWarnings("unchecked")
    protected final ISerializerDeserializer<AFloat> floatSerde = AqlSerializerDeserializerProvider.INSTANCE
            .getSerializerDeserializer(BuiltinType.AFLOAT);

    // Type tags and start offsets (within argOut) of the two arguments.
    protected ATypeTag firstTypeTag;
    protected ATypeTag secondTypeTag;
    protected int firstStart = -1;
    protected int secondStart = -1;
    protected float jaccSim = 0.0f;
    // Item type of the lists currently being compared; set by prepareLists().
    protected ATypeTag itemTypeTag;

    protected BinaryHashMap hashMap;
    protected BinaryEntry keyEntry = new BinaryEntry();
    protected BinaryEntry valEntry = new BinaryEntry();

    // Item type that the hash function and comparator of hashMap were created
    // for. Stays null if hashMap was not created by setHashMap().
    private ATypeTag hashMapItemTypeTag;

    // Ignore case for strings. Defaults to true.
    protected final boolean ignoreCase = true;

    /**
     * Creates the evaluator.
     *
     * @param args
     *            factories for the two list arguments; {@code args[0]} and
     *            {@code args[1]} are used
     * @param output
     *            provider of the output the result is written to
     * @throws AlgebricksException
     *             if an argument evaluator cannot be created
     */
    public SimilarityJaccardEvaluator(ICopyEvaluatorFactory[] args, IDataOutputProvider output)
            throws AlgebricksException {
        out = output.getDataOutput();
        firstOrdListEval = args[0].createEvaluator(argOut);
        secondOrdListEval = args[1].createEvaluator(argOut);
        // Template value for new hash map entries: two 4-byte counters, both zero.
        byte[] emptyValBuf = new byte[8];
        Arrays.fill(emptyValBuf, (byte) 0);
        valEntry.set(emptyValBuf, 0, 8);
    }

    /**
     * Evaluates both arguments for the given tuple, computes their Jaccard
     * similarity and writes it to the output.
     *
     * @param tuple
     *            the input tuple
     * @throws AlgebricksException
     *             if an argument is not a list (or NULL), the lists cannot be
     *             compared, or the result cannot be written
     */
    @Override
    public void evaluate(IFrameTupleReference tuple) throws AlgebricksException {
        runArgEvals(tuple);
        // checkArgTypes() has already written the result when it returns false.
        if (!checkArgTypes(firstTypeTag, secondTypeTag)) {
            return;
        }
        if (prepareLists(argOut.getByteArray(), firstStart, secondStart, firstTypeTag)) {
            jaccSim = computeResult(argOut.getByteArray(), firstStart, secondStart, firstTypeTag);
        } else {
            jaccSim = 0.0f;
        }
        try {
            writeResult(jaccSim);
        } catch (IOException e) {
            throw new AlgebricksException(e);
        }
    }

    /**
     * Runs both argument evaluators into {@link #argOut}, recording the start
     * offset and the type tag (first byte) of each argument.
     *
     * @param tuple
     *            the input tuple
     * @throws AlgebricksException
     *             if an argument evaluator fails
     */
    protected void runArgEvals(IFrameTupleReference tuple) throws AlgebricksException {
        argOut.reset();

        firstStart = argOut.getLength();
        firstOrdListEval.evaluate(tuple);
        secondStart = argOut.getLength();
        secondOrdListEval.evaluate(tuple);

        firstTypeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(argOut.getByteArray()[firstStart]);
        secondTypeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(argOut.getByteArray()[secondStart]);
    }

    /**
     * Resets the list iterators onto the serialized arguments and determines
     * the common item type of the two lists.
     *
     * @param bytes
     *            buffer holding both serialized lists
     * @param firstStart
     *            offset of the first list in {@code bytes}
     * @param secondStart
     *            offset of the second list in {@code bytes}
     * @param argType
     *            type tag of the first argument (unused here)
     * @return {@code false} if either list is empty, in which case the
     *         similarity is 0; {@code true} if the lists are ready for
     *         {@link #computeResult}
     * @throws AlgebricksException
     *             if a list is not homogeneous or the item types of the two
     *             lists differ
     */
    protected boolean prepareLists(byte[] bytes, int firstStart, int secondStart, ATypeTag argType)
            throws AlgebricksException {
        firstListIter.reset(bytes, firstStart);
        secondListIter.reset(bytes, secondStart);
        // Check for special case where one of the lists is empty, since list
        // types won't match.
        if (firstListIter.size() == 0 || secondListIter.size() == 0) {
            return false;
        }
        // The byte following the list's own type tag holds the item type tag.
        ATypeTag firstItemTypeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(bytes[firstStart + 1]);
        ATypeTag secondItemTypeTag = EnumDeserializer.ATYPETAGDESERIALIZER.deserialize(bytes[secondStart + 1]);
        // The homogeneity check must look at the item types: the list-level tags
        // alone are never ANY once checkArgTypes() has accepted the arguments.
        if (firstTypeTag == ATypeTag.ANY || secondTypeTag == ATypeTag.ANY || firstItemTypeTag == ATypeTag.ANY
                || secondItemTypeTag == ATypeTag.ANY) {
            throw new AlgebricksException("\n Jaccard can only be called on homogenous lists");
        }
        // Items of both lists are hashed and compared with a single item type,
        // so differing item types cannot be compared meaningfully.
        if (firstItemTypeTag != secondItemTypeTag) {
            throw new AlgebricksException("Incompatible item types " + firstItemTypeTag + " and "
                    + secondItemTypeTag + " given as arguments to jaccard.");
        }
        itemTypeTag = firstItemTypeTag;
        return true;
    }

    /**
     * Computes the Jaccard similarity of the two prepared lists.
     *
     * @param bytes
     *            buffer holding both serialized lists
     * @param firstStart
     *            offset of the first list in {@code bytes}
     * @param secondStart
     *            offset of the second list in {@code bytes}
     * @param argType
     *            type tag of the first argument (unused here)
     * @return the similarity, or {@code -1} if {@link #probeHashMap} reports a
     *         negative intersection size
     * @throws AlgebricksException
     *             declared for overriding implementations
     */
    protected float computeResult(byte[] bytes, int firstStart, int secondStart, ATypeTag argType)
            throws AlgebricksException {
        setHashMap(bytes, firstStart, secondStart);

        // We will subtract the intersection size later to get the real union size.
        int firstListSize = firstListIter.size();
        int secondListSize = secondListIter.size();
        int unionSize = firstListSize + secondListSize;

        // Choose smaller list as build, and larger one as probe.
        AbstractAsterixListIterator buildList = (firstListSize < secondListSize) ? firstListIter : secondListIter;
        AbstractAsterixListIterator probeList = (buildList == firstListIter) ? secondListIter : firstListIter;
        int buildListSize = (buildList == firstListIter) ? firstListSize : secondListSize;
        int probeListSize = (probeList == firstListIter) ? firstListSize : secondListSize;

        buildHashMap(buildList);
        int intersectionSize = probeHashMap(probeList, buildListSize, probeListSize);
        // Special indicator for the "check" version of jaccard.
        if (intersectionSize < 0) {
            return -1;
        }
        unionSize -= intersectionSize;
        return (float) intersectionSize / (float) unionSize;
    }

    /**
     * Build phase: inserts every item of the build list into the hash map,
     * counting occurrences in the first integer of the entry's value.
     *
     * @param buildIter
     *            iterator over the build list
     */
    protected void buildHashMap(AbstractAsterixListIterator buildIter) {
        // Value in map is a pair of integers. Set first integer to 1.
        IntegerPointable.setInteger(valEntry.buf, 0, 1);
        while (buildIter.hasNext()) {
            byte[] buf = buildIter.getData();
            int off = buildIter.getPos();
            int len = getItemLen(buf, off);
            keyEntry.set(buf, off, len);
            // A non-null result is treated as the entry already stored for this
            // key (presumably put() leaves an existing entry untouched).
            BinaryEntry entry = hashMap.put(keyEntry, valEntry);
            if (entry != null) {
                // Increment value.
                int firstValInt = IntegerPointable.getInteger(entry.buf, entry.off);
                IntegerPointable.setInteger(entry.buf, entry.off, firstValInt + 1);
            }
            buildIter.next();
        }
    }

    /**
     * Probe phase: looks up every item of the probe list in the hash map and
     * accumulates the multiset intersection size.
     *
     * <p>
     * The two size parameters are not used by this implementation. Their names
     * follow the argument order used by {@link #computeResult}.
     *
     * @param probeIter
     *            iterator over the probe list
     * @param buildListSize
     *            number of items in the build list
     * @param probeListSize
     *            number of items in the probe list
     * @return the intersection size
     */
    protected int probeHashMap(AbstractAsterixListIterator probeIter, int buildListSize, int probeListSize) {
        int intersectionSize = 0;
        while (probeIter.hasNext()) {
            byte[] buf = probeIter.getData();
            int off = probeIter.getPos();
            int len = getItemLen(buf, off);
            keyEntry.set(buf, off, len);
            BinaryEntry entry = hashMap.get(keyEntry);
            if (entry != null) {
                int firstValInt = IntegerPointable.getInteger(entry.buf, entry.off);
                // Entries with a zero build count are irrelevant for the
                // intersection size. They are skipped by this guard rather than
                // with 'continue', so that the iterator below still advances.
                if (firstValInt != 0) {
                    // Increment second value.
                    int secondValInt = IntegerPointable.getInteger(entry.buf, entry.off + 4);
                    // Subtract old min value.
                    intersectionSize -= (firstValInt < secondValInt) ? firstValInt : secondValInt;
                    secondValInt++;
                    // Add new min value.
                    intersectionSize += (firstValInt < secondValInt) ? firstValInt : secondValInt;
                    IntegerPointable.setInteger(entry.buf, entry.off + 4, secondValInt);
                }
            }
            probeIter.next();
        }
        return intersectionSize;
    }

    /**
     * Makes {@link #hashMap} ready for a new computation: clears the existing
     * map when it can be reused, otherwise creates one with the hash function
     * and comparator that match {@link #itemTypeTag}.
     *
     * @param bytes
     *            buffer holding both serialized lists (unused here)
     * @param firstStart
     *            offset of the first list (unused here)
     * @param secondStart
     *            offset of the second list (unused here)
     * @throws IllegalArgumentException
     *             if the item type is not one of INT32, FLOAT, DOUBLE or STRING
     */
    protected void setHashMap(byte[] bytes, int firstStart, int secondStart) {
        // Reuse the map only if its hash function and comparator still match the
        // item type. A map that was not created here (no recorded item type) is
        // reused as-is.
        if (hashMap != null && (hashMapItemTypeTag == null || hashMapItemTypeTag == itemTypeTag)) {
            hashMap.clear();
            return;
        }
        IBinaryHashFunction hashFunc = null;
        IBinaryComparator cmp = null;
        switch (itemTypeTag) {
            case INT32: {
                hashFunc = AqlBinaryHashFunctionFactoryProvider.INTEGER_POINTABLE_INSTANCE.createBinaryHashFunction();
                cmp = AqlBinaryComparatorFactoryProvider.INTEGER_POINTABLE_INSTANCE.createBinaryComparator();
                break;
            }
            case FLOAT: {
                hashFunc = AqlBinaryHashFunctionFactoryProvider.FLOAT_POINTABLE_INSTANCE.createBinaryHashFunction();
                cmp = AqlBinaryComparatorFactoryProvider.FLOAT_POINTABLE_INSTANCE.createBinaryComparator();
                break;
            }
            case DOUBLE: {
                hashFunc = AqlBinaryHashFunctionFactoryProvider.DOUBLE_POINTABLE_INSTANCE.createBinaryHashFunction();
                cmp = AqlBinaryComparatorFactoryProvider.DOUBLE_POINTABLE_INSTANCE.createBinaryComparator();
                break;
            }
            case STRING: {
                if (ignoreCase) {
                    // Ignore case in comparisons and hashing.
                    hashFunc = AqlBinaryHashFunctionFactoryProvider.UTF8STRING_LOWERCASE_POINTABLE_INSTANCE
                            .createBinaryHashFunction();
                    cmp = AqlBinaryComparatorFactoryProvider.UTF8STRING_LOWERCASE_POINTABLE_INSTANCE
                            .createBinaryComparator();
                } else {
                    hashFunc = AqlBinaryHashFunctionFactoryProvider.UTF8STRING_POINTABLE_INSTANCE
                            .createBinaryHashFunction();
                    cmp = AqlBinaryComparatorFactoryProvider.UTF8STRING_POINTABLE_INSTANCE.createBinaryComparator();
                }
                break;
            }
            default: {
                // Fail with a clear message instead of creating a map without a
                // hash function and comparator.
                throw new IllegalArgumentException("Unsupported item type " + itemTypeTag
                        + " given as argument to jaccard.");
            }
        }
        hashMap = new BinaryHashMap(TABLE_SIZE, TABLE_FRAME_SIZE, hashFunc, cmp);
        hashMapItemTypeTag = itemTypeTag;
    }

    /**
     * Returns the serialized length in bytes of the item at the given offset,
     * based on {@link #itemTypeTag}.
     *
     * @param bytes
     *            buffer holding the item
     * @param itemOff
     *            offset of the item in {@code bytes}
     * @return the item length, or {@code -1} for an unsupported item type
     */
    protected int getItemLen(byte[] bytes, int itemOff) {
        switch (itemTypeTag) {
            case INT32: {
                return 4;
            }
            case FLOAT: {
                return 4;
            }
            case DOUBLE: {
                return 8;
            }
            case STRING: {
                // 2 bytes for the UTF8 len, plus the string data.
                return 2 + UTF8StringPointable.getUTFLength(bytes, itemOff);
            }
            default: {
                return -1;
            }
        }
    }

    /**
     * Validates the argument type tags and selects the matching list iterators.
     *
     * @param typeTag1
     *            type tag of the first argument
     * @param typeTag2
     *            type tag of the second argument
     * @return {@code false} if either argument is NULL, in which case a result
     *         of {@code 0.0f} has already been written; {@code true} if both
     *         arguments are lists
     * @throws AlgebricksException
     *             if an argument is neither NULL nor a list, or the result
     *             cannot be written
     */
    protected boolean checkArgTypes(ATypeTag typeTag1, ATypeTag typeTag2) throws AlgebricksException {
        // Jaccard between null and anything else is 0
        if (typeTag1 == ATypeTag.NULL || typeTag2 == ATypeTag.NULL) {
            try {
                writeResult(0.0f);
            } catch (IOException e) {
                throw new AlgebricksException(e);
            }
            return false;
        }
        switch (typeTag1) {
            case ORDEREDLIST: {
                firstListIter = fstOrdListIter;
                break;
            }
            case UNORDEREDLIST: {
                firstListIter = fstUnordListIter;
                break;
            }
            default: {
                throw new AlgebricksException("Invalid types " + typeTag1 + " given as arguments to jaccard.");
            }
        }
        switch (typeTag2) {
            case ORDEREDLIST: {
                secondListIter = sndOrdListIter;
                break;
            }
            case UNORDEREDLIST: {
                secondListIter = sndUnordListIter;
                break;
            }
            default: {
                throw new AlgebricksException("Invalid types " + typeTag2 + " given as arguments to jaccard.");
            }
        }
        return true;
    }

    /**
     * Serializes the given similarity value as an {@link AFloat} to the output.
     *
     * @param jacc
     *            the similarity value to write
     * @throws IOException
     *             if serialization fails
     */
    protected void writeResult(float jacc) throws IOException {
        aFloat.setValue(jacc);
        floatSerde.serialize(aFloat, out);
    }
}