blob: 50c8b37c7b4cb7067d726954177c1a9a919ce264 [file] [log] [blame]
/*
* Copyright 2009-2010 by The Regents of the University of California
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* you may obtain a copy of the License from
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package edu.uci.ics.asterix.runtime.pointables.cast;
import java.io.DataOutput;
import java.io.DataOutputStream;
import java.io.IOException;
import java.util.ArrayList;
import java.util.List;
import edu.uci.ics.asterix.builders.RecordBuilder;
import edu.uci.ics.asterix.common.exceptions.AsterixException;
import edu.uci.ics.asterix.dataflow.data.nontagged.AqlNullWriterFactory;
import edu.uci.ics.asterix.om.types.ARecordType;
import edu.uci.ics.asterix.om.types.ATypeTag;
import edu.uci.ics.asterix.om.types.AUnionType;
import edu.uci.ics.asterix.om.types.EnumDeserializer;
import edu.uci.ics.asterix.om.types.IAType;
import edu.uci.ics.asterix.om.util.NonTaggedFormatUtil;
import edu.uci.ics.asterix.runtime.pointables.ARecordPointable;
import edu.uci.ics.asterix.runtime.pointables.PointableAllocator;
import edu.uci.ics.asterix.runtime.pointables.base.DefaultOpenFieldType;
import edu.uci.ics.asterix.runtime.pointables.base.IVisitablePointable;
import edu.uci.ics.asterix.runtime.util.ResettableByteArrayOutputStream;
import edu.uci.ics.hyracks.algebricks.common.utils.Triple;
import edu.uci.ics.hyracks.api.dataflow.value.IBinaryComparator;
import edu.uci.ics.hyracks.api.dataflow.value.INullWriter;
import edu.uci.ics.hyracks.data.std.accessors.PointableBinaryComparatorFactory;
import edu.uci.ics.hyracks.data.std.api.IValueReference;
import edu.uci.ics.hyracks.data.std.primitive.UTF8StringPointable;
import edu.uci.ics.hyracks.dataflow.common.comm.io.ByteArrayAccessibleOutputStream;
/**
* This class is to do the runtime type cast for a record. It is ONLY visible to
* ACastVisitor.
*/
class ARecordCaster {
// pointable allocator
private final PointableAllocator allocator = new PointableAllocator();
private final List<IVisitablePointable> reqFieldNames = new ArrayList<IVisitablePointable>();
private final List<IVisitablePointable> reqFieldTypeTags = new ArrayList<IVisitablePointable>();
private ARecordType cachedReqType = null;
private final ResettableByteArrayOutputStream bos = new ResettableByteArrayOutputStream();
private final DataOutputStream dos = new DataOutputStream(bos);
private final RecordBuilder recBuilder = new RecordBuilder();
private final IVisitablePointable nullReference = allocator.allocateEmpty();
private final IVisitablePointable nullTypeTag = allocator.allocateEmpty();
private final IBinaryComparator fieldNameComparator = PointableBinaryComparatorFactory.of(
UTF8StringPointable.FACTORY).createBinaryComparator();
private final ByteArrayAccessibleOutputStream outputBos = new ByteArrayAccessibleOutputStream();
private final DataOutputStream outputDos = new DataOutputStream(outputBos);
private final IVisitablePointable fieldTempReference = allocator.allocateEmpty();
private final Triple<IVisitablePointable, IAType, Boolean> nestedVisitorArg = new Triple<IVisitablePointable, IAType, Boolean>(
fieldTempReference, null, null);
private int numInputFields = 0;
// describe closed fields in the required type
private int[] fieldPermutation;
private boolean[] optionalFields;
// describe fields (open or not) in the input records
private boolean[] openFields;
private int[] fieldNamesSortedIndex;
private int[] reqFieldNamesSortedIndex;
public ARecordCaster() {
try {
bos.reset();
int start = bos.size();
INullWriter nullWriter = AqlNullWriterFactory.INSTANCE.createNullWriter();
nullWriter.writeNull(dos);
int end = bos.size();
nullReference.set(bos.getByteArray(), start, end - start);
start = bos.size();
dos.write(ATypeTag.NULL.serialize());
end = bos.size();
nullTypeTag.set(bos.getByteArray(), start, end);
} catch (IOException e) {
throw new IllegalStateException(e);
}
}
public void castRecord(ARecordPointable recordAccessor, IVisitablePointable resultAccessor, ARecordType reqType,
ACastVisitor visitor) throws IOException, AsterixException {
List<IVisitablePointable> fieldNames = recordAccessor.getFieldNames();
List<IVisitablePointable> fieldTypeTags = recordAccessor.getFieldTypeTags();
List<IVisitablePointable> fieldValues = recordAccessor.getFieldValues();
numInputFields = fieldNames.size();
if (openFields == null || numInputFields > openFields.length) {
openFields = new boolean[numInputFields];
fieldNamesSortedIndex = new int[numInputFields];
}
if (cachedReqType == null || !reqType.equals(cachedReqType)) {
loadRequiredType(reqType);
}
// clear the previous states
reset();
matchClosedPart(fieldNames, fieldTypeTags, fieldValues);
writeOutput(fieldNames, fieldTypeTags, fieldValues, outputDos, visitor);
resultAccessor.set(outputBos.getByteArray(), 0, outputBos.size());
}
private void reset() {
for (int i = 0; i < numInputFields; i++)
openFields[i] = true;
for (int i = 0; i < fieldPermutation.length; i++)
fieldPermutation[i] = -1;
for (int i = 0; i < numInputFields; i++)
fieldNamesSortedIndex[i] = i;
outputBos.reset();
}
private void loadRequiredType(ARecordType reqType) throws IOException {
reqFieldNames.clear();
reqFieldTypeTags.clear();
cachedReqType = reqType;
int numSchemaFields = reqType.getFieldTypes().length;
IAType[] fieldTypes = reqType.getFieldTypes();
String[] fieldNames = reqType.getFieldNames();
fieldPermutation = new int[numSchemaFields];
optionalFields = new boolean[numSchemaFields];
for (int i = 0; i < optionalFields.length; i++)
optionalFields[i] = false;
bos.reset(nullReference.getStartOffset() + nullReference.getLength());
for (int i = 0; i < numSchemaFields; i++) {
ATypeTag ftypeTag = fieldTypes[i].getTypeTag();
String fname = fieldNames[i];
// add type tag pointable
if (fieldTypes[i].getTypeTag() == ATypeTag.UNION
&& NonTaggedFormatUtil.isOptionalField((AUnionType) fieldTypes[i])) {
// optional field: add the embedded non-null type tag
ftypeTag = ((AUnionType) fieldTypes[i]).getUnionList()
.get(NonTaggedFormatUtil.OPTIONAL_TYPE_INDEX_IN_UNION_LIST).getTypeTag();
optionalFields[i] = true;
}
int tagStart = bos.size();
dos.writeByte(ftypeTag.serialize());
int tagEnd = bos.size();
IVisitablePointable typeTagPointable = allocator.allocateEmpty();
typeTagPointable.set(bos.getByteArray(), tagStart, tagEnd - tagStart);
reqFieldTypeTags.add(typeTagPointable);
// add type name pointable (including a string type tag)
int nameStart = bos.size();
dos.write(ATypeTag.STRING.serialize());
dos.writeUTF(fname);
int nameEnd = bos.size();
IVisitablePointable typeNamePointable = allocator.allocateEmpty();
typeNamePointable.set(bos.getByteArray(), nameStart, nameEnd - nameStart);
reqFieldNames.add(typeNamePointable);
}
reqFieldNamesSortedIndex = new int[reqFieldNames.size()];
for (int i = 0; i < reqFieldNamesSortedIndex.length; i++)
reqFieldNamesSortedIndex[i] = i;
// sort the field name index
quickSort(reqFieldNamesSortedIndex, reqFieldNames, 0, reqFieldNamesSortedIndex.length - 1);
}
private void matchClosedPart(List<IVisitablePointable> fieldNames, List<IVisitablePointable> fieldTypeTags,
List<IVisitablePointable> fieldValues) {
// sort-merge based match
quickSort(fieldNamesSortedIndex, fieldNames, 0, numInputFields - 1);
int fnStart = 0;
int reqFnStart = 0;
while (fnStart < numInputFields && reqFnStart < reqFieldNames.size()) {
int fnPos = fieldNamesSortedIndex[fnStart];
int reqFnPos = reqFieldNamesSortedIndex[reqFnStart];
int c = compare(fieldNames.get(fnPos), reqFieldNames.get(reqFnPos));
if (c == 0) {
IVisitablePointable fieldTypeTag = fieldTypeTags.get(fnPos);
IVisitablePointable reqFieldTypeTag = reqFieldTypeTags.get(reqFnPos);
if (fieldTypeTag.equals(reqFieldTypeTag) || (
// match the null type of optional field
optionalFields[reqFnPos] && fieldTypeTag.equals(nullTypeTag))) {
fieldPermutation[reqFnPos] = fnPos;
openFields[fnPos] = false;
}
fnStart++;
reqFnStart++;
}
if (c > 0)
reqFnStart++;
if (c < 0)
fnStart++;
}
// check unmatched fields in the input type
for (int i = 0; i < openFields.length; i++) {
if (openFields[i] == true && !cachedReqType.isOpen())
throw new IllegalStateException("type mismatch: including extra closed fields");
}
// check unmatched fields in the required type
for (int i = 0; i < fieldPermutation.length; i++) {
if (fieldPermutation[i] < 0) {
IAType t = cachedReqType.getFieldTypes()[i];
if (!(t.getTypeTag() == ATypeTag.UNION && NonTaggedFormatUtil.isOptionalField((AUnionType) t))) {
// no matched field in the input for a required closed field
throw new IllegalStateException("type mismatch: miss a required closed field");
}
}
}
}
private void writeOutput(List<IVisitablePointable> fieldNames, List<IVisitablePointable> fieldTypeTags,
List<IVisitablePointable> fieldValues, DataOutput output, ACastVisitor visitor) throws IOException,
AsterixException {
// reset the states of the record builder
recBuilder.reset(cachedReqType);
recBuilder.init();
// write the closed part
for (int i = 0; i < fieldPermutation.length; i++) {
int pos = fieldPermutation[i];
IVisitablePointable field;
if (pos >= 0) {
field = fieldValues.get(pos);
} else {
field = nullReference;
}
IAType fType = cachedReqType.getFieldTypes()[i];
nestedVisitorArg.second = fType;
// recursively casting, the result of casting can always be thought
// as flat
if (optionalFields[i]) {
nestedVisitorArg.second = ((AUnionType) fType).getUnionList().get(
NonTaggedFormatUtil.OPTIONAL_TYPE_INDEX_IN_UNION_LIST);
}
field.accept(visitor, nestedVisitorArg);
recBuilder.addField(i, nestedVisitorArg.first);
}
// write the open part
for (int i = 0; i < numInputFields; i++) {
if (openFields[i]) {
IVisitablePointable name = fieldNames.get(i);
IVisitablePointable field = fieldValues.get(i);
IVisitablePointable fieldTypeTag = fieldTypeTags.get(i);
ATypeTag typeTag = EnumDeserializer.ATYPETAGDESERIALIZER
.deserialize(fieldTypeTag.getByteArray()[fieldTypeTag.getStartOffset()]);
nestedVisitorArg.second = DefaultOpenFieldType.getDefaultOpenFieldType(typeTag);
field.accept(visitor, nestedVisitorArg);
recBuilder.addField(name, nestedVisitorArg.first);
}
}
recBuilder.write(output, true);
}
private void quickSort(int[] index, List<IVisitablePointable> names, int start, int end) {
if (end <= start)
return;
int i = partition(index, names, start, end);
quickSort(index, names, start, i - 1);
quickSort(index, names, i + 1, end);
}
private int partition(int[] index, List<IVisitablePointable> names, int left, int right) {
int i = left - 1;
int j = right;
while (true) {
// grow from the left
while (compare(names.get(index[++i]), names.get(index[right])) < 0)
;
// lower from the right
while (compare(names.get(index[right]), names.get(index[--j])) < 0)
if (j == left)
break;
if (i >= j)
break;
// swap i and j
swap(index, i, j);
}
// swap i and right
swap(index, i, right); // swap with partition element
return i;
}
private void swap(int[] array, int i, int j) {
int temp = array[i];
array[i] = array[j];
array[j] = temp;
}
private int compare(IValueReference a, IValueReference b) {
// start+1 and len-1 due to the type tag
return fieldNameComparator.compare(a.getByteArray(), a.getStartOffset() + 1, a.getLength() - 1,
b.getByteArray(), b.getStartOffset() + 1, b.getLength() - 1);
}
}