| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
 * "License"); you may not use this file except in compliance
 * with the License. You may obtain a copy of the License at
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.flink.table.runtime.join.batch.hashtable.longtable; |
| |
| import org.apache.flink.core.memory.MemorySegment; |
| import org.apache.flink.core.memory.MemorySegmentFactory; |
| import org.apache.flink.core.memory.SeekableDataInputView; |
| import org.apache.flink.runtime.io.disk.iomanager.BlockChannelWriter; |
| import org.apache.flink.runtime.io.disk.iomanager.BulkBlockChannelReader; |
| import org.apache.flink.runtime.io.disk.iomanager.ChannelReaderInputView; |
| import org.apache.flink.runtime.io.disk.iomanager.FileIOChannel; |
| import org.apache.flink.runtime.io.disk.iomanager.IOManager; |
| import org.apache.flink.runtime.memory.AbstractPagedInputView; |
| import org.apache.flink.runtime.memory.AbstractPagedOutputView; |
| import org.apache.flink.table.dataformat.BinaryRow; |
| import org.apache.flink.table.runtime.util.AbstractChannelWriterOutputView; |
| import org.apache.flink.table.runtime.util.FileChannelUtil; |
| import org.apache.flink.table.runtime.util.RowIterator; |
| import org.apache.flink.table.typeutils.BinaryRowSerializer; |
| import org.apache.flink.util.MathUtils; |
| import org.apache.flink.util.Preconditions; |
| |
| import org.slf4j.Logger; |
| import org.slf4j.LoggerFactory; |
| |
| import java.io.EOFException; |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.List; |
| import java.util.concurrent.LinkedBlockingQueue; |
| |
| import static org.apache.flink.table.runtime.join.batch.hashtable.longtable.LongHybridHashTable.hashLong; |
| import static org.apache.flink.util.Preconditions.checkArgument; |
| |
| /** |
| * Partition for {@link LongHybridHashTable}. |
| * |
| * <p>The layout of the buckets inside a memory segment is as follows:</p> |
| * |
| * <p>Hash mode: |
| * +----------------------------- Bucket area ---------------------------- |
| * | long key (8 bytes) | address (8 bytes) | |
| * | long key (8 bytes) | address (8 bytes) | |
| * | long key (8 bytes) | address (8 bytes) | |
| * | ... |
| * +----------------------------- Data area -------------------------- |
| * | size & address of next row with the same key (8bytes) | binary row | |
| * | size & address of next row with the same key (8bytes) | binary row | |
| * | size & address of next row with the same key (8bytes) | binary row | |
| * | ... |
| * |
| * <p>Dense mode: |
| * +----------------------------- Bucket area ---------------------------- |
| * | address1 (8 bytes) | address2 (8 bytes) | address3 (8 bytes) | ... |
| * Directly addressed by the index of the corresponding array of key values. |
| */ |
public class LongHashPartition extends AbstractPagedInputView implements SeekableDataInputView {

	private static final Logger LOG = LoggerFactory.getLogger(LongHashPartition.class);

	// The number of bits for size in address: a row header long packs
	// [next-row address (high bits) | row size (low SIZE_BITS bits)].
	private static final int SIZE_BITS = 28;
	// Mask extracting the low SIZE_BITS bits (the row size) from a packed header long.
	private static final int SIZE_MASK = 0xfffffff;
	// Sentinel address meaning "empty bucket slot" / "no next row in the chain".
	static final long INVALID_ADDRESS = 0x00000FFFFFFFFFL;

	// Owning hash-table context: supplies memory segments, spilling, and compression settings.
	private final LongHashContext context;

	// the number of bits in the mem segment size;
	private final int segmentSizeBits;

	// segmentSize - 1; masks a pointer into an offset within one segment (segmentSize is a power of 2).
	private final int segmentSizeMask;

	// the size of the memory segments being used
	private final int segmentSize;
	private int partitionNum;
	private final BinaryRowSerializer buildSideSerializer;
	// Reused row instance handed out by MatchIterator (zero-copy pointTo into the pages).
	private final BinaryRow buildReuseRow;
	// Recursion depth of the hybrid hash join; feeds into hashLong so re-partitioning re-hashes.
	private int recursionLevel;
	private int numBuckets;

	// The minimum key
	private long minKey = Long.MAX_VALUE;

	// The maximum key
	private long maxKey = Long.MIN_VALUE;

	// The array to store the key and offset of BinaryRow in the page.
	// Each bucket slot occupies 16 bytes: 8-byte long key + 8-byte address.
	//
	// Sparse mode: [key1 | address1] [key2 | address2] ...
	// Dense mode (see class javadoc): addresses only, directly indexed by key.
	private MemorySegment[] buckets;
	// numBuckets - 1; numBuckets is always a power of 2.
	private int numBucketsMask;

	// The pages to store all bytes of BinaryRow and the pointer to next rows.
	// [header1][row1] [header2][row2] ... where each header is 8 bytes.
	private MemorySegment[] partitionBuffers;
	// Number of valid bytes in the last partition buffer (all others are full).
	private int finalBufferLimit;
	private int currentBufferNum;
	private BuildSideBuffer buildSideWriteBuffer;
	// Output view for probe-side records of a spilled partition; null while in memory.
	AbstractChannelWriterOutputView probeSideBuffer;
	long probeSideRecordCounter; // number of probe-side records in this partition

	// The number of unique keys.
	private long numKeys;

	// Single reused iterator returned by valueIter()/get(); not safe for concurrent scans.
	private final MatchIterator iterator;

	// the channel writer for the build side, if partition is spilled
	private BlockChannelWriter<MemorySegment> buildSideChannel;

	// number of build-side records in this partition
	private long buildSideRecordCounter;

	// Bytes used in the last probe-side spill segment, set when the probe buffer is closed.
	int probeNumBytesInLastSeg;
| |
	/**
	 * Entrance 1: Init LongHashPartition for new insert and search.
	 *
	 * @param estimatedRowCount estimated number of build-side rows, used to size the bucket area
	 * @param maxSegs upper bound on the number of bucket segments
	 */
	LongHashPartition(
			LongHashContext context, int partitionNum, BinaryRowSerializer buildSideSerializer,
			double estimatedRowCount, int maxSegs, int recursionLevel) {
		this(context, partitionNum, buildSideSerializer,
				getBucketBuffersByRowCount((long) estimatedRowCount, maxSegs, context.pageSize()),
				recursionLevel, null, 0);
		// Fresh partitions start with an empty write buffer backed by one segment.
		this.buildSideWriteBuffer = new BuildSideBuffer(context.nextSegment());
	}
| |
	/**
	 * Entrance 2: build table from spilled partition when the partition fits entirely into main
	 * memory.
	 *
	 * @param buffers pages already holding all build-side rows of the spilled partition
	 * @param lastSegmentLimit number of valid bytes in the last page
	 */
	LongHashPartition(
			LongHashContext context, int partitionNum, BinaryRowSerializer buildSideSerializer,
			int bucketNumSegs, int recursionLevel, List<MemorySegment> buffers,
			int lastSegmentLimit) {
		this(context, buildSideSerializer, listToArray(buffers));
		this.partitionNum = partitionNum;
		this.recursionLevel = recursionLevel;

		// Each bucket slot is 16 bytes (8-byte key + 8-byte address).
		int numBuckets = MathUtils.roundDownToPowerOf2(bucketNumSegs * segmentSize / 16);
		MemorySegment[] buckets = new MemorySegment[bucketNumSegs];
		for (int i = 0; i < bucketNumSegs; i++) {
			buckets[i] = context.nextSegment();
		}
		setNewBuckets(buckets, numBuckets);
		this.finalBufferLimit = lastSegmentLimit;
	}
| |
	/**
	 * Entrance 3: dense mode for just data search (bucket in LongHybridHashTable of dense mode).
	 * No bucket area is allocated here; the partition only serves reads over the given pages.
	 */
	LongHashPartition(LongHashContext context, BinaryRowSerializer buildSideSerializer,
			MemorySegment[] partitionBuffers) {
		super(0);
		this.context = context;
		this.buildSideSerializer = buildSideSerializer;
		this.buildReuseRow = buildSideSerializer.createInstance();
		this.segmentSize = context.pageSize();
		// The 16-byte bucket slot layout requires the page size to be a multiple of 16.
		Preconditions.checkArgument(segmentSize % 16 == 0);
		this.partitionBuffers = partitionBuffers;
		this.segmentSizeBits = MathUtils.log2strict(segmentSize);
		this.segmentSizeMask = segmentSize - 1;
		// All pages are considered fully valid until finalizeBuildPhase narrows the last one.
		this.finalBufferLimit = segmentSize;
		this.iterator = new MatchIterator();
	}
| |
| private static MemorySegment[] listToArray(List<MemorySegment> list) { |
| if (list != null) { |
| return list.toArray(new MemorySegment[list.size()]); |
| } |
| return null; |
| } |
| |
| private static int getBucketBuffersByRowCount(long rowCount, int maxSegs, int segmentSize) { |
| int minNumBuckets = (int) Math.ceil((rowCount / 0.5)); |
| Preconditions.checkArgument(segmentSize % 16 == 0); |
| return MathUtils.roundDownToPowerOf2((int) Math.max(1, |
| Math.min(maxSegs, Math.ceil(((double) minNumBuckets) * 16 / segmentSize)))); |
| } |
| |
| private void setNewBuckets(MemorySegment[] buckets, int numBuckets) { |
| for (MemorySegment segment : buckets) { |
| for (int i = 0; i < segmentSize; i += 16) { |
| // Maybe we don't need init key, cause always verify address |
| segment.putLong(i, 0); |
| segment.putLong(i + 8, INVALID_ADDRESS); |
| } |
| } |
| this.buckets = buckets; |
| checkArgument(MathUtils.isPowerOf2(numBuckets)); |
| this.numBuckets = numBuckets; |
| this.numBucketsMask = numBuckets - 1; |
| this.numKeys = 0; |
| } |
| |
| static long toAddrAndLen(long address, int size) { |
| return (address << SIZE_BITS) | size; |
| } |
| |
| static long toAddress(long addrAndLen) { |
| return addrAndLen >>> SIZE_BITS; |
| } |
| |
| static int toLength(long addrAndLen) { |
| return (int) (addrAndLen & SIZE_MASK); |
| } |
| |
	/**
	 * Returns an iterator of BinaryRow for multiple linked values.
	 * Note: the single MatchIterator instance is reused, so callers must fully consume
	 * it before the next lookup.
	 */
	MatchIterator valueIter(long address) {
		iterator.set(address);
		return iterator;
	}

	/** Looks up all rows for the given key, hashing it at the current recursion level. */
	public MatchIterator get(long key) {
		return get(key, hashLong(key, recursionLevel));
	}
| |
| /** |
| * Returns an iterator for all the values for the given key, or null if no value found. |
| */ |
	/**
	 * Returns an iterator for all the values for the given key; if the key is absent, the
	 * returned iterator is empty (it is set to INVALID_ADDRESS, never null).
	 */
	public MatchIterator get(long key, int hashCode) {
		int bucket = hashCode & numBucketsMask;

		// Each bucket slot is 16 bytes: key at offset 0, address at offset 8.
		int bucketOffset = bucket << 4;
		MemorySegment segment = buckets[bucketOffset >>> segmentSizeBits];
		int segOffset = bucketOffset & segmentSizeMask;

		// Linear probing: step to the next slot on key mismatch; an empty slot ends the chain.
		while (true) {
			long address = segment.getLong(segOffset + 8);
			if (address != INVALID_ADDRESS) {
				if (segment.getLong(segOffset) == key) {
					return valueIter(address);
				} else {
					bucket = (bucket + 1) & numBucketsMask;
					if (segOffset + 16 < segmentSize) {
						// Fast path: next slot stays within the current segment.
						segOffset += 16;
					} else {
						// Crossed a segment boundary (or wrapped): recompute from the bucket index.
						bucketOffset = bucket << 4;
						segOffset = bucketOffset & segmentSizeMask;
						segment = buckets[bucketOffset >>> segmentSizeBits];
					}
				}
			} else {
				// Empty slot terminates the probe chain: key not present.
				return valueIter(INVALID_ADDRESS);
			}
		}
	}
| |
| /** |
| * Update the address in array for given key. |
| */ |
	/**
	 * Update the address in array for given key.
	 *
	 * <p>Inserts (key -> address) into the open-addressed bucket area. If the key already has
	 * rows, the new row becomes the head of the chain: the previous head's address is written
	 * into the new row's header word.
	 *
	 * @param key the join key
	 * @param hashCode hash of the key at the current recursion level
	 * @param address page address of the newly written row
	 * @param size serialized size of the row (unused on the rehash path)
	 * @param dataSegment segment holding the new row's header, or null on the rehash path
	 * @param currentPositionInSegment offset of the row header within dataSegment
	 */
	private void updateIndex(long key, int hashCode, long address, int size,
			MemorySegment dataSegment, int currentPositionInSegment) throws IOException {
		assert(numKeys <= numBuckets / 2);

		int bucket = hashCode & numBucketsMask;

		int bucketOffset = bucket << 4;
		MemorySegment segment = buckets[bucketOffset >>> segmentSizeBits];
		int segOffset = bucketOffset & segmentSizeMask;
		long currAddress;

		// Linear probing until we find this key's slot or an empty slot.
		while (true) {
			currAddress = segment.getLong(segOffset + 8);
			if (segment.getLong(segOffset) != key && currAddress != INVALID_ADDRESS) {

				// TODO test Conflict resolution:
				// now: +1 +1 +1... cache friendly but more conflict, so we set factor to 0.5
				// other1: +1 +2 +3... less conflict, factor can be 0.75
				// other2: Secondary hashCode... less and less conflict, but need compute hash again
				bucket = (bucket + 1) & numBucketsMask;
				if (segOffset + 16 < segmentSize) {
					segOffset += 16;
				} else {
					bucketOffset = bucket << 4;
					segment = buckets[bucketOffset >>> segmentSizeBits];
					segOffset = bucketOffset & segmentSizeMask;
				}
			} else {
				break;
			}
		}
		if (currAddress == INVALID_ADDRESS) {
			// this is the first value for this key, put the address in array.
			segment.putLong(segOffset, key);
			segment.putLong(segOffset + 8, address);
			numKeys += 1;
			if (dataSegment != null) {
				// First row of its chain: header says "no next row".
				dataSegment.putLong(currentPositionInSegment, toAddrAndLen(INVALID_ADDRESS, size));
			}
			// Keep the load factor at or below 0.5; grow the bucket area otherwise.
			if (numKeys * 2 > numBuckets) {
				resize();
			}
		} else {
			// there are some values for this key, put the address in the front of them.
			dataSegment.putLong(currentPositionInSegment, toAddrAndLen(currAddress, size));
			segment.putLong(segOffset + 8, address);
		}
	}
| |
	/**
	 * Doubles the number of bucket segments and re-hashes all entries. If memory runs out and
	 * the resulting spill evicts this very partition, the resize is abandoned (the partition
	 * is no longer in-memory anyway).
	 */
	private void resize() throws IOException {
		MemorySegment[] oldBuckets = this.buckets;
		int oldNumBuckets = numBuckets;
		int newNumSegs = oldBuckets.length * 2;
		int newNumBuckets = MathUtils.roundDownToPowerOf2(newNumSegs * segmentSize / 16);

		// request new buckets.
		MemorySegment[] newBuckets = new MemorySegment[newNumSegs];
		for (int i = 0; i < newNumSegs; i++) {
			MemorySegment seg = context.getNextBuffer();
			if (seg == null) {
				final int spilledPart = context.spillPartition();
				if (spilledPart == partitionNum) {
					// this bucket is no longer in-memory
					// free new segments.
					// NOTE(review): newBuckets may still contain nulls at this point — assumes
					// context.returnAll tolerates null elements; verify against the context impl.
					context.returnAll(Arrays.asList(newBuckets));
					return;
				}
				seg = context.getNextBuffer();
				if (seg == null) {
					throw new RuntimeException(
							"Bug in HybridHashJoin: No memory became available after spilling a partition.");
				}
			}
			newBuckets[i] = seg;
		}

		setNewBuckets(newBuckets, newNumBuckets);
		reHash(oldBuckets, oldNumBuckets);
	}
| |
| private void reHash(MemorySegment[] oldBuckets, int oldNumBuckets) throws IOException { |
| long reHashStartTime = System.currentTimeMillis(); |
| int bucketOffset = 0; |
| MemorySegment segment = oldBuckets[bucketOffset]; |
| int segOffset = 0; |
| for (int i = 0; i < oldNumBuckets; i++) { |
| long address = segment.getLong(segOffset + 8); |
| if (address != INVALID_ADDRESS) { |
| long key = segment.getLong(segOffset); |
| // size/dataSegment/currentPositionInSegment should never be used. |
| updateIndex(key, hashLong(key, recursionLevel), address, 0, null, 0); |
| } |
| |
| // not last bucket, move to next. |
| if (i != oldNumBuckets - 1) { |
| if (segOffset + 16 < segmentSize) { |
| segOffset += 16; |
| } else { |
| segment = oldBuckets[++bucketOffset]; |
| segOffset = 0; |
| } |
| } |
| } |
| |
| context.returnAll(Arrays.asList(oldBuckets)); |
| LOG.info("The rehash take {} ms for {} segments", (System.currentTimeMillis() - reHashStartTime), numBuckets); |
| } |
| |
	public MemorySegment[] getBuckets() {
		return buckets;
	}

	/** Build-side block count: write-buffer blocks while building, page count once closed. */
	int getBuildSideBlockCount() {
		return this.partitionBuffers == null ? this.buildSideWriteBuffer.getBlockCount()
				: this.partitionBuffers.length;
	}

	/** Probe-side block count, or -1 if no probe-side spill buffer exists. */
	int getProbeSideBlockCount() {
		return this.probeSideBuffer == null ? -1 : this.probeSideBuffer.getBlockCount();
	}

	BlockChannelWriter<MemorySegment> getBuildSideChannel() {
		return this.buildSideChannel;
	}

	// NOTE: only valid after finalizeBuildPhase created the probe-side buffer (spilled case).
	FileIOChannel.ID getProbeSideChannelID() {
		return probeSideBuffer.getChannelID();
	}

	int getPartitionNumber() {
		return this.partitionNum;
	}

	MemorySegment[] getPartitionBuffers() {
		return partitionBuffers;
	}

	int getRecursionLevel() {
		return this.recursionLevel;
	}

	/** Total memory footprint in segments: data pages (or write buffer) plus the bucket area. */
	int getNumOccupiedMemorySegments() {
		// either the number of memory segments, or one for spilling
		final int numPartitionBuffers = this.partitionBuffers != null ?
				this.partitionBuffers.length
				: this.buildSideWriteBuffer.getNumOccupiedMemorySegments();
		return numPartitionBuffers + buckets.length;
	}
| |
	/**
	 * Spills this in-memory partition's build side to the given channel. The current (likely
	 * partially filled) write block stays in memory; all filled blocks are written out.
	 *
	 * @return the number of memory segments that became available through the spill
	 */
	int spillPartition(IOManager ioAccess, FileIOChannel.ID targetChannel,
			LinkedBlockingQueue<MemorySegment> bufferReturnQueue) throws IOException {
		// sanity checks
		if (!isInMemory()) {
			throw new RuntimeException("Bug in Hybrid Hash Join: " +
					"Request to spill a partition that has already been spilled.");
		}
		if (getNumOccupiedMemorySegments() < 2) {
			throw new RuntimeException("Bug in Hybrid Hash Join: " +
					"Request to spill a partition with less than two buffers.");
		}

		// create the channel block writer and spill the current buffers
		// that keep the build side buffers current block, as it is most likely not full, yet
		// we return the number of blocks that become available
		this.buildSideChannel = FileChannelUtil.createBlockChannelWriter(ioAccess, targetChannel, bufferReturnQueue,
				context.compressionEnable(), context.compressionCodecFactory(), context.compressionBlockSize(), segmentSize);
		return this.buildSideWriteBuffer.spill(this.buildSideChannel);
	}
| |
	/**
	 * After build phase.
	 *
	 * @return build spill return buffer count: if the partition has spilled, the current write
	 * buffer is flushed now (it was in use for the whole build phase and can only be returned
	 * at this point), so 1; otherwise 0.
	 */
	int finalizeBuildPhase(IOManager ioAccess, FileIOChannel.Enumerator probeChannelEnumerator) throws IOException {
		this.finalBufferLimit = this.buildSideWriteBuffer.getCurrentPositionInSegment();
		this.partitionBuffers = this.buildSideWriteBuffer.close();

		if (!isInMemory()) {
			// close the channel.
			this.buildSideChannel.close();

			// A spilled partition needs a spill buffer for its probe-side records too.
			this.probeSideBuffer = FileChannelUtil.createOutputView(ioAccess, probeChannelEnumerator.next(),
					context.compressionEnable(), context.compressionCodecFactory(),
					context.compressionBlockSize(), segmentSize);
			return 1;
		} else {
			return 0;
		}
	}
| |
	/**
	 * After the probe phase: an in-memory partition releases all its memory; a spilled
	 * partition either deletes its files (no probe records means no possible matches) or
	 * registers itself for a recursive join pass.
	 */
	void finalizeProbePhase(List<LongHashPartition> spilledPartitions) throws IOException {
		if (isInMemory()) {
			releaseBuckets();
			context.returnAll(partitionBuffers);
			this.partitionBuffers = null;
		} else {
			if (this.probeSideRecordCounter == 0) {
				// delete the spill files
				this.probeSideBuffer.close();
				this.buildSideChannel.deleteChannel();
				this.probeSideBuffer.deleteChannel();
			} else {
				// flush the last probe side buffer and register this partition as pending
				probeNumBytesInLastSeg = this.probeSideBuffer.close();
				spilledPartitions.add(this);
			}
		}
	}
| |
	/** Creates an iterator over all rows in this partition's pages (used when rebuilding from spill). */
	final PartitionIterator newPartitionIterator() {
		return new PartitionIterator();
	}

	/** Number of valid bytes in the last partition buffer. */
	final int getLastSegmentLimit() {
		return this.finalBufferLimit;
	}
| |
| // ------------------ PagedInputView for read -------------------- |
| |
	@Override
	public void setReadPosition(long pointer) {
		// A pointer encodes (buffer number << segmentSizeBits) + offset within the buffer.
		final int bufferNum = (int) (pointer >>> this.segmentSizeBits);
		final int offset = (int) (pointer & segmentSizeMask);

		this.currentBufferNum = bufferNum;

		// Only the last buffer is limited to finalBufferLimit; all others are full.
		seekInput(this.partitionBuffers[bufferNum], offset,
				bufferNum < partitionBuffers.length - 1 ? segmentSize : finalBufferLimit);
	}

	/** Advances to the next partition buffer; signals EOF once all buffers are consumed. */
	@Override
	protected MemorySegment nextSegment(MemorySegment current) throws IOException {
		this.currentBufferNum++;
		if (this.currentBufferNum < this.partitionBuffers.length) {
			return this.partitionBuffers[this.currentBufferNum];
		} else {
			throw new EOFException();
		}
	}

	@Override
	protected int getLimitForSegment(MemorySegment segment) {
		// Identity comparison: partitionBuffers holds the exact segment instances in use.
		return segment == partitionBuffers[partitionBuffers.length - 1] ? finalBufferLimit : segmentSize;
	}
| |
	/** A partition is in memory until a build-side spill channel has been created. */
	boolean isInMemory() {
		return buildSideChannel == null;
	}

	/** Spills one probe-side record of this (already spilled) partition for a later join pass. */
	final void insertIntoProbeBuffer(BinaryRowSerializer probeSer, BinaryRow record) throws IOException {
		probeSer.serialize(record, this.probeSideBuffer);
		this.probeSideRecordCounter++;
	}
| |
| long getBuildSideRecordCount() { |
| return buildSideRecordCounter; |
| } |
| |
| long getMinKey() { |
| return minKey; |
| } |
| |
| long getMaxKey() { |
| return maxKey; |
| } |
| |
| private void updateMinMax(long key) { |
| if (key < minKey) { |
| minKey = key; |
| } |
| if (key > maxKey) { |
| maxKey = key; |
| } |
| } |
| |
	/**
	 * Indexes an already-written row (used when rebuilding the table from a spilled partition):
	 * updates counters, min/max, and links the row into the bucket area.
	 */
	void insertIntoBucket(long key, int hashCode, int size, long address) throws IOException {
		this.buildSideRecordCounter++;
		updateMinMax(key);

		// Locate the row header inside the pages so updateIndex can chain it.
		final int bufferNum = (int) (address >>> this.segmentSizeBits);
		final int offset = (int) (address & (this.segmentSize - 1));
		updateIndex(key, hashCode, address, size, partitionBuffers[bufferNum], offset);
	}
| |
	/**
	 * Appends a build-side row: writes the 8-byte [next-pointer | size] header plus the row
	 * bytes into the partition pages and, while still in memory, links it into the bucket index.
	 */
	void insertIntoTable(long key, int hashCode, BinaryRow row) throws IOException {
		this.buildSideRecordCounter++;
		updateMinMax(key);
		int sizeInBytes = row.getSizeInBytes();
		// The size must fit into the low SIZE_BITS bits of the header word.
		if (sizeInBytes >= (1 << SIZE_BITS)) {
			throw new UnsupportedOperationException("Does not support row that is larger than 256M");
		}
		if (isInMemory()) {
			// May trigger a spill of this very partition — hence the re-check below.
			checkWriteAdvance();

			if (isInMemory()) {
				updateIndex(key, hashCode, buildSideWriteBuffer.getCurrentPointer(), sizeInBytes,
						buildSideWriteBuffer.getCurrentSegment(), buildSideWriteBuffer.getCurrentPositionInSegment());
			} else {
				// Spilled during the advance: write a chain-terminating header, skip indexing.
				buildSideWriteBuffer.getCurrentSegment().putLong(
						buildSideWriteBuffer.getCurrentPositionInSegment(),
						toAddrAndLen(INVALID_ADDRESS, sizeInBytes));
			}

			buildSideWriteBuffer.skipBytesToWrite(8);

			if (row.getAllSegments().length == 1) {
				// Fast path: single-segment row, raw byte copy.
				buildSideWriteBuffer.write(row.getMemorySegment(), row.getBaseOffset(), sizeInBytes);
			} else {
				buildSideSerializer.serializeRowToPagesSlow(row, buildSideWriteBuffer);
			}
		} else {
			serializeToPages(row);
		}
	}
| |
| public void serializeToPages(BinaryRow row) throws IOException { |
| |
| int sizeInBytes = row.getSizeInBytes(); |
| checkWriteAdvance(); |
| |
| buildSideWriteBuffer.getCurrentSegment().putLong( |
| buildSideWriteBuffer.getCurrentPositionInSegment(), |
| toAddrAndLen(INVALID_ADDRESS, row.getSizeInBytes())); |
| buildSideWriteBuffer.skipBytesToWrite(8); |
| |
| if (row.getAllSegments().length == 1) { |
| buildSideWriteBuffer.write(row.getMemorySegment(), row.getBaseOffset(), sizeInBytes); |
| } else { |
| buildSideSerializer.serializeRowToPagesSlow(row, buildSideWriteBuffer); |
| } |
| } |
| |
| void releaseBuckets() { |
| if (buckets != null) { |
| context.returnAll(buckets); |
| buckets = null; |
| } |
| } |
| |
	/** Convenience entry point: inserts the row, hashing the key at the current recursion level. */
	public void append(long key, BinaryRow row) throws IOException {
		insertIntoTable(key, hashLong(key, recursionLevel), row);
	}
| |
| // ------------------ PagedInputView for read end -------------------- |
| |
| /** |
| * Write Buffer. |
| */ |
| private class BuildSideBuffer extends AbstractPagedOutputView { |
| |
| private final ArrayList<MemorySegment> targetList; |
| private int currentBlockNumber; |
| private BlockChannelWriter<MemorySegment> writer; |
| |
| private BuildSideBuffer(MemorySegment segment) { |
| super(segment, segment.size(), 0); |
| this.targetList = new ArrayList<>(); |
| } |
| |
| @Override |
| protected MemorySegment nextSegment(MemorySegment current, |
| int positionInCurrent) throws IOException { |
| final MemorySegment next; |
| if (this.writer == null) { |
| // Must first add current segment: |
| // This may happen when you need to spill: |
| // A partition called nextSegment, can not get memory, need to spill, the result |
| // give itself to the spill, Since it is switching currentSeg, it is necessary |
| // to give the previous currSeg to spill. |
| this.targetList.add(current); |
| next = context.nextSegment(); |
| } else { |
| this.writer.writeBlock(current); |
| try { |
| next = this.writer.getReturnQueue().take(); |
| } catch (InterruptedException iex) { |
| throw new IOException("Hash Join Partition was interrupted while " + |
| "grabbing a new write-behind buffer."); |
| } |
| } |
| |
| this.currentBlockNumber++; |
| return next; |
| } |
| |
| long getCurrentPointer() { |
| return (((long) this.currentBlockNumber) << segmentSizeBits) |
| + getCurrentPositionInSegment(); |
| } |
| |
| int getBlockCount() { |
| return this.currentBlockNumber + 1; |
| } |
| |
| int getNumOccupiedMemorySegments() { |
| // return the current segment + all filled segments |
| return this.targetList.size() + 1; |
| } |
| |
| int spill(BlockChannelWriter<MemorySegment> writer) throws IOException { |
| this.writer = writer; |
| final int numSegments = this.targetList.size(); |
| for (MemorySegment segment : this.targetList) { |
| this.writer.writeBlock(segment); |
| } |
| this.targetList.clear(); |
| return numSegments; |
| } |
| |
| MemorySegment[] close() throws IOException { |
| final MemorySegment current = getCurrentSegment(); |
| if (current == null) { |
| throw new IllegalStateException("Illegal State in LongHashTable: " + |
| "No current buffer when finalizing build side."); |
| } |
| clear(); |
| |
| if (this.writer == null) { |
| this.targetList.add(current); |
| MemorySegment[] buffers = |
| this.targetList.toArray(new MemorySegment[this.targetList.size()]); |
| this.targetList.clear(); |
| return buffers; |
| } else { |
| writer.writeBlock(current); |
| return null; |
| } |
| } |
| } |
| |
	/**
	 * Iterator for probe match. Walks the chain of rows that share one key: each row header
	 * packs [address of next row | size of this row], and the chain ends at INVALID_ADDRESS.
	 *
	 * <p>The returned row is a reused instance that is only valid until the next
	 * {@link #advanceNext()} call on this partition.
	 */
	public class MatchIterator implements RowIterator<BinaryRow> {
		private long address;

		public void set(long address) {
			this.address = address;
		}

		@Override
		public boolean advanceNext() {
			if (address != INVALID_ADDRESS) {
				setReadPosition(address);
				// Header word: next-row address in the high bits, row size in the low bits.
				long addrAndLen = getCurrentSegment().getLong(getCurrentPositionInSegment());
				this.address = toAddress(addrAndLen);
				int size = toLength(addrAndLen);
				try {
					skipBytesToRead(8);
					// Zero-copy: point the reused row at its bytes inside the partition pages.
					buildSideSerializer.pointTo(size, buildReuseRow, LongHashPartition.this);
				} catch (IOException e) {
					throw new RuntimeException(e);
				}
				return true;
			}
			return false;
		}

		@Override
		public BinaryRow getRow() {
			return buildReuseRow;
		}
	}
| |
	/**
	 * Releases all in-memory resources of this partition into {@code target} and deletes any
	 * spill files. The partition must not be used afterwards.
	 */
	void clearAllMemory(List<MemorySegment> target) {
		// return current buffers from build side and probe side
		if (this.buildSideWriteBuffer != null) {
			if (this.buildSideWriteBuffer.getCurrentSegment() != null) {
				target.add(this.buildSideWriteBuffer.getCurrentSegment());
			}
			target.addAll(this.buildSideWriteBuffer.targetList);
			this.buildSideWriteBuffer.targetList.clear();
			this.buildSideWriteBuffer = null;
		}
		releaseBuckets();

		// return the partition buffers
		if (this.partitionBuffers != null) {
			Collections.addAll(target, this.partitionBuffers);
			this.partitionBuffers = null;
		}

		// clear the channels
		try {
			if (this.buildSideChannel != null) {
				this.buildSideChannel.close();
				this.buildSideChannel.deleteChannel();
			}
			if (this.probeSideBuffer != null) {
				this.probeSideBuffer.closeAndDelete();
				this.probeSideBuffer = null;
			}
		} catch (IOException ioex) {
			throw new RuntimeException("Error deleting the partition files. " +
					"Some temporary files might not be removed.", ioex);
		}
	}
| |
	/**
	 * For spilled partition to rebuild index and hashcode when memory can
	 * store all the build side data.
	 * (After bulk load to memory, see {@link BulkBlockChannelReader}).
	 */
	final class PartitionIterator implements RowIterator<BinaryRow> {

		// Page address of the row most recently returned by advanceNext().
		private long currentPointer;

		private BinaryRow reuse;

		private PartitionIterator() {
			this.reuse = buildSideSerializer.createInstance();
			setReadPosition(0);
		}

		@Override
		public boolean advanceNext() {
			try {
				checkReadAdvance();

				// Remember where this row starts so callers can re-index it via getPointer().
				int pos = getCurrentPositionInSegment();
				this.currentPointer = (((long) currentBufferNum) << segmentSizeBits) + pos;

				long addrAndLen = getCurrentSegment().getLong(pos);
				skipBytesToRead(8);
				buildSideSerializer.pointTo(toLength(addrAndLen), reuse, LongHashPartition.this);
				return true;
			} catch (EOFException e) {
				// End of the partition pages — iteration is complete.
				return false;
			} catch (IOException e) {
				throw new RuntimeException(e);
			}
		}

		final long getPointer() {
			return this.currentPointer;
		}

		@Override
		public BinaryRow getRow() {
			return this.reuse;
		}
	}
| |
	/** Advances the write buffer when the current segment cannot fit a header plus fixed-length part. */
	private void checkWriteAdvance() throws IOException {
		if (shouldAdvance(
				buildSideWriteBuffer.getSegmentSize() - buildSideWriteBuffer.getCurrentPositionInSegment(),
				buildSideSerializer)) {
			buildSideWriteBuffer.advance();
		}
	}

	/** Advances the read view when the current segment cannot fit a header plus fixed-length part. */
	private void checkReadAdvance() throws IOException {
		if (shouldAdvance(getCurrentSegmentLimit() - getCurrentPositionInSegment(),
				buildSideSerializer)) {
			advance();
		}
	}

	// Invariant: the 8-byte row header plus the fixed-length part is never split across segments.
	private static boolean shouldAdvance(int available, BinaryRowSerializer serializer) {
		return available < 8 + serializer.getFixedLengthPartSize();
	}
| |
	/**
	 * Reads one row (8-byte header + payload) from a spilled-partition input view into the
	 * given reuse row, replacing the reuse row's backing segment if it is too small.
	 */
	static void deserializeFromPages(BinaryRow reuse, ChannelReaderInputView inView,
			BinaryRowSerializer buildSideSerializer) throws IOException {
		if (shouldAdvance(inView.getCurrentSegmentLimit() - inView.getCurrentPositionInSegment(),
				buildSideSerializer)) {
			inView.advance();
		}
		MemorySegment segment = reuse.getMemorySegment();

		// Only the size part of the header is used; the next-pointer is meaningless on disk.
		int length = toLength(inView.getCurrentSegment().getLong(
				inView.getCurrentPositionInSegment()));
		inView.skipBytesToRead(8);

		if (segment == null || segment.size() < length) {
			segment = MemorySegmentFactory.wrap(new byte[length]);
		}
		// NOTE(review): getHeapMemory assumes the segment is heap-backed — looks like reuse
		// rows on this path always are; verify for off-heap configurations.
		inView.readFully(segment.getHeapMemory(), 0, length);
		reuse.pointTo(segment, 0, length);
	}
| |
	/**
	 * Copies all occupied bucket entries into the dense bucket array, where a key is addressed
	 * directly by (key - globalMinKey), 8 bytes per slot. Row addresses are shifted by
	 * {@code addressOffset} to account for this partition's position in the merged page array.
	 */
	void iteratorToDenseBucket(MemorySegment[] denseBuckets, long addressOffset, long globalMinKey) {
		int bucketOffset = 0;
		MemorySegment segment = buckets[bucketOffset];
		int segOffset = 0;
		for (int i = 0; i < numBuckets; i++) {
			long address = segment.getLong(segOffset + 8);
			if (address != INVALID_ADDRESS) {
				long key = segment.getLong(segOffset);
				// Dense slot index = key - globalMinKey; each slot is one long (<< 3).
				long denseBucket = key - globalMinKey;
				long denseBucketOffset = denseBucket << 3;
				int denseSegIndex = (int) (denseBucketOffset >>> segmentSizeBits);
				int denseSegOffset = (int) (denseBucketOffset & segmentSizeMask);
				denseBuckets[denseSegIndex].putLong(denseSegOffset, address + addressOffset);
			}

			// not last bucket, move to next.
			if (i != numBuckets - 1) {
				if (segOffset + 16 < segmentSize) {
					segOffset += 16;
				} else {
					segment = buckets[++bucketOffset];
					segOffset = 0;
				}
			}
		}
	}
| |
	/**
	 * Rewrites the next-pointer in every row header, shifting it by {@code addressOffset}.
	 * Needed in dense mode when this partition's pages are appended after other partitions'
	 * pages, which displaces every in-page address.
	 */
	void updateDenseAddressOffset(long addressOffset) {
		if (addressOffset != 0) {
			setReadPosition(0);
			// Sequentially scan every row; EOF terminates the walk.
			while (true) {
				try {
					checkReadAdvance();
					long addrAndLen = getCurrentSegment().getLong(getCurrentPositionInSegment());
					long address = LongHashPartition.toAddress(addrAndLen);
					int len = LongHashPartition.toLength(addrAndLen);
					if (address != INVALID_ADDRESS) {
						// Only rows with a real successor carry an address that must shift.
						getCurrentSegment().putLong(getCurrentPositionInSegment(),
								LongHashPartition.toAddrAndLen(address + addressOffset, len));
					}
					skipBytesToRead(8 + len);
				} catch (EOFException e) {
					break;
				} catch (IOException e) {
					throw new RuntimeException(e);
				}
			}
		}
	}
| } |