blob: d80c8da789fa3168a8c3c95e974c2e645f297aea [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.join.batch;
import org.apache.flink.api.common.typeutils.TypeSerializer;
import org.apache.flink.configuration.Configuration;
import org.apache.flink.core.memory.MemorySegment;
import org.apache.flink.metrics.Gauge;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.runtime.memory.MemoryAllocationException;
import org.apache.flink.runtime.memory.MemoryManager;
import org.apache.flink.streaming.api.operators.TwoInputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.codegen.GeneratedJoinConditionFunction;
import org.apache.flink.table.codegen.GeneratedProjection;
import org.apache.flink.table.codegen.GeneratedSorter;
import org.apache.flink.table.codegen.JoinConditionFunction;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.plan.FlinkJoinRelType;
import org.apache.flink.table.runtime.AbstractStreamOperatorWithMetrics;
import org.apache.flink.table.runtime.sort.BinaryExternalSorter;
import org.apache.flink.table.runtime.sort.NormalizedKeyComputer;
import org.apache.flink.table.runtime.sort.RecordComparator;
import org.apache.flink.table.runtime.util.ResettableExternalBuffer;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.table.typeutils.BinaryRowSerializer;
import org.apache.flink.util.Collector;
import org.apache.flink.util.MutableObjectIterator;
import org.codehaus.commons.compiler.CompileException;
import java.util.List;
import static org.apache.flink.table.codegen.CodeGenUtils.compile;
import static org.apache.flink.util.Preconditions.checkNotNull;
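/**
 * A two-input operator that joins its inputs with a sort-merge strategy: every incoming row is
 * written to a per-input {@link BinaryExternalSorter}, and once both inputs have ended the two
 * sorted streams are merged according to the configured {@link FlinkJoinRelType}.
 */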
public class SortMergeJoinOperator extends AbstractStreamOperatorWithMetrics<BaseRow>
implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow> {
// Reserved and maximum memory arguments handed to the external sorter of input 1.
private final long reservedSortMemory1;
private final long maxSortMemory1;
// Reserved and maximum memory arguments handed to the external sorter of input 2.
private final long reservedSortMemory2;
private final long maxSortMemory2;
// Passed to both sorters as their per-request memory argument.
private final long perRequestMemory;
// Byte budget of each ResettableExternalBuffer; newBuffer() divides it by the page size.
private final long externalBufferMemory;
// Join type; selects the branch taken in doSortMergeJoin().
private final FlinkJoinRelType type;
// Only consulted for INNER joins: when true, input 2 is passed first to the join iterator
// and input 1's rows go through the match buffer; when false, the roles are swapped.
private final boolean leftIsSmaller;
// Generated code to cook: compiled by cookGeneratedClasses() during open() and reset to
// null at the end of open() once everything has been instantiated.
private GeneratedJoinConditionFunction condFuncCode;
private GeneratedProjection projectionCode1;
private GeneratedProjection projectionCode2;
private GeneratedSorter gSorter1;
private GeneratedSorter gSorter2;
private GeneratedSorter keyGSorter;
// Handed unchanged to every join iterator.
// NOTE(review): presumably marks which join keys reject null values — confirm in the iterators.
private final boolean[] filterNulls;
// Runtime state below is initialized in open().
private transient CookedClasses classes;
private transient Configuration conf;
private transient MemoryManager memManager;
private transient IOManager ioManager;
// Serializers of the two inputs as configured on the operator, plus BinaryRow serializers
// built from the same field types.
private transient TypeSerializer<BaseRow> inputSerializer1;
private transient TypeSerializer<BaseRow> inputSerializer2;
private transient BinaryRowSerializer serializer1;
private transient BinaryRowSerializer serializer2;
// One external sorter per input; rows are written to them in processElement1/2.
private transient BinaryExternalSorter sorter1;
private transient BinaryExternalSorter sorter2;
// Join iterators are created lazily in doSortMergeJoin() and released in close().
private transient SortMergeJoinIterator joinIterator1;
private transient SortMergeJoinIterator joinIterator2;
private transient SortMergeFullOuterJoinIterator fullOuterJoinIterator;
private transient Collector<BaseRow> collector;
// isFinished[0] / isFinished[1] are set by endInput1() / endInput2().
private transient boolean[] isFinished;
private transient JoinConditionFunction condFunc;
private transient RecordComparator keyComparator;
// Rows with as many fields as input 1 / input 2; handed to the helper and used as the
// null side of one-side outer joins.
private transient BaseRow leftNullRow;
private transient BaseRow rightNullRow;
private transient SortMergeJoinHelper helper;
/**
 * Creates an operator whose two sorters use the same memory settings.
 *
 * <p>Delegates to the full constructor, passing {@code reservedSortMemory} and
 * {@code maxSortMemory} for both inputs.
 */
public SortMergeJoinOperator(
		long reservedSortMemory,
		long maxSortMemory,
		long perRequestMemory,
		long externalBufferMemory,
		FlinkJoinRelType type,
		boolean leftIsSmaller,
		GeneratedJoinConditionFunction condFuncCode,
		GeneratedProjection projectionCode1,
		GeneratedProjection projectionCode2,
		GeneratedSorter gSorter1,
		GeneratedSorter gSorter2,
		GeneratedSorter keyGSorter,
		boolean[] filterNulls) {
	this(
		reservedSortMemory, maxSortMemory,
		reservedSortMemory, maxSortMemory,
		perRequestMemory, externalBufferMemory,
		type, leftIsSmaller,
		condFuncCode,
		projectionCode1, projectionCode2,
		gSorter1, gSorter2, keyGSorter,
		filterNulls);
}
/**
 * Creates an operator with separate sorter memory settings per input.
 *
 * @param reservedSortMemory1 reserved memory argument for the sorter of input 1
 * @param maxSortMemory1 maximum memory argument for the sorter of input 1
 * @param reservedSortMemory2 reserved memory argument for the sorter of input 2
 * @param maxSortMemory2 maximum memory argument for the sorter of input 2
 * @param perRequestMemory per-request memory argument shared by both sorters
 * @param externalBufferMemory byte budget of each match buffer
 * @param type join type to execute
 * @param leftIsSmaller for inner joins, which input is passed first to the join iterator
 * @param condFuncCode generated join condition
 * @param projectionCode1 generated projection for input 1
 * @param projectionCode2 generated projection for input 2
 * @param gSorter1 generated sorter code for input 1, must not be null
 * @param gSorter2 generated sorter code for input 2, must not be null
 * @param keyGSorter generated sorter code for the key comparator, must not be null
 * @param filterNulls flags forwarded to the join iterators
 */
public SortMergeJoinOperator(
		long reservedSortMemory1,
		long maxSortMemory1,
		long reservedSortMemory2,
		long maxSortMemory2,
		long perRequestMemory,
		long externalBufferMemory,
		FlinkJoinRelType type,
		boolean leftIsSmaller,
		GeneratedJoinConditionFunction condFuncCode,
		GeneratedProjection projectionCode1,
		GeneratedProjection projectionCode2,
		GeneratedSorter gSorter1,
		GeneratedSorter gSorter2,
		GeneratedSorter keyGSorter,
		boolean[] filterNulls) {
	// Reject missing sorter code up front, in the same order the fields are declared.
	checkNotNull(gSorter1);
	checkNotNull(gSorter2);
	checkNotNull(keyGSorter);

	// Memory settings.
	this.reservedSortMemory1 = reservedSortMemory1;
	this.maxSortMemory1 = maxSortMemory1;
	this.reservedSortMemory2 = reservedSortMemory2;
	this.maxSortMemory2 = maxSortMemory2;
	this.perRequestMemory = perRequestMemory;
	this.externalBufferMemory = externalBufferMemory;

	// Join semantics.
	this.type = type;
	this.leftIsSmaller = leftIsSmaller;
	this.filterNulls = filterNulls;

	// Generated code, compiled later in open().
	this.condFuncCode = condFuncCode;
	this.projectionCode1 = projectionCode1;
	this.projectionCode2 = projectionCode2;
	this.gSorter1 = gSorter1;
	this.gSorter2 = gSorter2;
	this.keyGSorter = keyGSorter;
}
/**
 * Prepares the operator for processing: compiles the generated classes, builds the
 * serializers, starts both external sorters, instantiates the key comparator and join
 * condition, drops the references to the generated code and registers the sorter gauges.
 */
@Override
public void open() throws Exception {
	super.open();
	this.conf = getSqlConf();

	// Compile every piece of generated code with the task's user-code class loader.
	this.classes = cookGeneratedClasses(getContainingTask().getUserCodeClassLoader());

	// A fresh boolean array is all-false, i.e. neither input has ended yet.
	this.isFinished = new boolean[2];
	this.collector = new StreamRecordCollector<>(output);

	// Input serializers come from the operator config; the BinaryRow serializers reuse
	// their field types.
	this.inputSerializer1 = getOperatorConfig().getTypeSerializerIn1(getUserCodeClassloader());
	this.serializer1 = new BinaryRowSerializer(
		((AbstractRowSerializer) inputSerializer1).getTypes());
	this.inputSerializer2 = getOperatorConfig().getTypeSerializerIn2(getUserCodeClassloader());
	this.serializer2 = new BinaryRowSerializer(
		((AbstractRowSerializer) inputSerializer2).getTypes());

	this.memManager = this.getContainingTask().getEnvironment().getMemoryManager();
	this.ioManager = this.getContainingTask().getEnvironment().getIOManager();

	initSorter();
	initKeyComparator();

	this.condFunc = classes.condFuncClass.newInstance();
	this.leftNullRow = new GenericRow(serializer1.getNumFields());
	this.rightNullRow = new GenericRow(serializer2.getNumFields());
	this.helper = new SortMergeJoinHelper(
		collector, condFunc, leftNullRow, rightNullRow, new JoinedRow());

	// The generated code has been compiled and instantiated; release the references.
	this.condFuncCode = null;
	this.keyGSorter = null;
	this.projectionCode1 = null;
	this.projectionCode2 = null;
	this.gSorter1 = null;
	this.gSorter2 = null;

	// Each gauge reports the sum over both sorters.
	getMetricGroup().gauge(
		"memoryUsedSizeInBytes",
		(Gauge<Long>) () -> sorter1.getUsedMemoryInBytes() + sorter2.getUsedMemoryInBytes());
	getMetricGroup().gauge(
		"numSpillFiles",
		(Gauge<Long>) () -> sorter1.getNumSpillFiles() + sorter2.getNumSpillFiles());
	getMetricGroup().gauge(
		"spillInBytes",
		(Gauge<Long>) () -> sorter1.getSpillInBytes() + sorter2.getSpillInBytes());
}
/**
 * Creates and starts the external sorters for input 1 and then input 2.
 */
private void initSorter() throws Exception {
	// The field is assigned before the threads are started, so a failing start still
	// leaves the sorter reachable for close().
	this.sorter1 = newSorter(
		classes.computerClass1, classes.comparatorClass1, gSorter1,
		reservedSortMemory1, maxSortMemory1, inputSerializer1, serializer1);
	this.sorter1.startThreads();

	this.sorter2 = newSorter(
		classes.computerClass2, classes.comparatorClass2, gSorter2,
		reservedSortMemory2, maxSortMemory2, inputSerializer2, serializer2);
	this.sorter2.startThreads();
}

/**
 * Instantiates and initializes the generated key computer and comparator for one input
 * and wraps them in a not-yet-started {@link BinaryExternalSorter}.
 */
private BinaryExternalSorter newSorter(
		Class<NormalizedKeyComputer> computerClass,
		Class<RecordComparator> comparatorClass,
		GeneratedSorter generatedSorter,
		long reservedMemory,
		long maxMemory,
		TypeSerializer<BaseRow> inputSerializer,
		BinaryRowSerializer rowSerializer) throws Exception {
	NormalizedKeyComputer keyComputer = computerClass.newInstance();
	RecordComparator recordComparator = comparatorClass.newInstance();
	keyComputer.init(generatedSorter.serializers(), generatedSorter.comparators());
	recordComparator.init(generatedSorter.serializers(), generatedSorter.comparators());
	return new BinaryExternalSorter(
		this.getContainingTask(),
		memManager, reservedMemory, maxMemory, perRequestMemory,
		ioManager, inputSerializer, rowSerializer, keyComputer, recordComparator, conf);
}
/**
 * Instantiates the generated key comparator and initializes it with the serializers and
 * comparators carried by {@code keyGSorter}.
 */
private void initKeyComparator() throws Exception {
	final GeneratedSorter keySorterCode = keyGSorter;
	this.keyComparator = classes.keyComparatorClass.newInstance();
	this.keyComparator.init(keySorterCode.serializers(), keySorterCode.comparators());
}
/**
 * Compiles all generated code held by this operator and bundles the resulting classes.
 *
 * @param cl class loader handed to {@code compile}
 * @return holder of the compiled classes
 * @throws CompileException propagated from {@code compile}
 */
protected CookedClasses cookGeneratedClasses(ClassLoader cl) throws CompileException {
	// Compilation order matches the constructor argument order of CookedClasses.
	Class<JoinConditionFunction> condition =
		compile(cl, condFuncCode.name(), condFuncCode.code());
	Class<RecordComparator> keyComparatorCls =
		compile(cl, keyGSorter.comparator().name(), keyGSorter.comparator().code());
	Class<Projection> projectionCls1 =
		compile(cl, projectionCode1.name(), projectionCode1.code());
	Class<Projection> projectionCls2 =
		compile(cl, projectionCode2.name(), projectionCode2.code());
	Class<NormalizedKeyComputer> computerCls1 =
		compile(cl, gSorter1.computer().name(), gSorter1.computer().code());
	Class<NormalizedKeyComputer> computerCls2 =
		compile(cl, gSorter2.computer().name(), gSorter2.computer().code());
	Class<RecordComparator> comparatorCls1 =
		compile(cl, gSorter1.comparator().name(), gSorter1.comparator().code());
	Class<RecordComparator> comparatorCls2 =
		compile(cl, gSorter2.comparator().name(), gSorter2.comparator().code());

	return new CookedClasses(
		condition,
		keyComparatorCls,
		projectionCls1,
		projectionCls2,
		computerCls1,
		computerCls2,
		comparatorCls1,
		comparatorCls2);
}
/**
 * Returns {@link TwoInputSelection#ANY}; this operator states no preference for which
 * input is read first.
 */
@Override
public TwoInputSelection firstInputSelection() {
return TwoInputSelection.ANY;
}
/**
 * Writes the row of input 1 into its sorter; nothing is emitted here.
 *
 * @return always {@link TwoInputSelection#ANY}
 */
@Override
public TwoInputSelection processElement1(StreamRecord<BaseRow> element) throws Exception {
	BaseRow leftRow = element.getValue();
	this.sorter1.write(leftRow);
	return TwoInputSelection.ANY;
}
/**
 * Writes the row of input 2 into its sorter; nothing is emitted here.
 *
 * @return always {@link TwoInputSelection#ANY}
 */
@Override
public TwoInputSelection processElement2(StreamRecord<BaseRow> element) throws Exception {
	BaseRow rightRow = element.getValue();
	this.sorter2.write(rightRow);
	return TwoInputSelection.ANY;
}
/**
 * Marks input 1 as ended and runs the join if input 2 has ended as well.
 */
@Override
public void endInput1() throws Exception {
	isFinished[0] = true;
	if (!isAllFinished()) {
		// Input 2 is still open; the join runs when it ends.
		return;
	}
	doSortMergeJoin();
}
/**
 * Marks input 2 as ended and runs the join if input 1 has ended as well.
 */
@Override
public void endInput2() throws Exception {
	isFinished[1] = true;
	if (!isAllFinished()) {
		// Input 1 is still open; the join runs when it ends.
		return;
	}
	doSortMergeJoin();
}
/**
 * Runs the join over the two sorted inputs according to {@code type}.
 *
 * <p>INNER, LEFT, RIGHT and FULL joins are delegated to the {@code SortMergeJoinHelper};
 * SEMI and ANTI joins are evaluated here by testing each probe row of input 1 against
 * the buffered rows of input 2 with the join condition. The created iterator is stored
 * in a field so that {@link #close()} can release it.
 *
 * @throws RuntimeException if the join type is none of the handled ones
 */
private void doSortMergeJoin() throws Exception {
	Projection leftProjection = classes.projectionClass1.newInstance();
	Projection rightProjection = classes.projectionClass2.newInstance();
	MutableObjectIterator leftSorted = sorter1.getIterator();
	MutableObjectIterator rightSorted = sorter2.getIterator();

	if (type.equals(FlinkJoinRelType.INNER)) {
		if (leftIsSmaller) {
			// Input 2 is passed first; rows of input 1 go through the match buffer.
			SortMergeInnerJoinIterator innerIterator = new SortMergeInnerJoinIterator(
				serializer2, serializer1, rightProjection, leftProjection,
				keyComparator, rightSorted, leftSorted, newBuffer(serializer1), filterNulls);
			joinIterator1 = innerIterator;
			helper.innerJoin(innerIterator, true);
		} else {
			// Input 1 is passed first; rows of input 2 go through the match buffer.
			SortMergeInnerJoinIterator innerIterator = new SortMergeInnerJoinIterator(
				serializer1, serializer2, leftProjection, rightProjection,
				keyComparator, leftSorted, rightSorted, newBuffer(serializer2), filterNulls);
			joinIterator2 = innerIterator;
			helper.innerJoin(innerIterator, false);
		}
	} else if (type.equals(FlinkJoinRelType.LEFT)) {
		SortMergeOneSideOuterJoinIterator outerIterator = new SortMergeOneSideOuterJoinIterator(
			serializer1, serializer2, leftProjection, rightProjection,
			keyComparator, leftSorted, rightSorted, newBuffer(serializer2), filterNulls);
		joinIterator2 = outerIterator;
		helper.oneSideOuterJoin(outerIterator, false, rightNullRow);
	} else if (type.equals(FlinkJoinRelType.RIGHT)) {
		SortMergeOneSideOuterJoinIterator outerIterator = new SortMergeOneSideOuterJoinIterator(
			serializer2, serializer1, rightProjection, leftProjection,
			keyComparator, rightSorted, leftSorted, newBuffer(serializer1), filterNulls);
		joinIterator1 = outerIterator;
		helper.oneSideOuterJoin(outerIterator, true, leftNullRow);
	} else if (type.equals(FlinkJoinRelType.FULL)) {
		fullOuterJoinIterator = new SortMergeFullOuterJoinIterator(
			serializer1, serializer2, leftProjection, rightProjection,
			keyComparator, leftSorted, rightSorted,
			newBuffer(serializer1), newBuffer(serializer2), filterNulls);
		helper.fullOuterJoin(fullOuterJoinIterator);
	} else if (type.equals(FlinkJoinRelType.SEMI)) {
		SortMergeInnerJoinIterator semiIterator = new SortMergeInnerJoinIterator(
			serializer1, serializer2, leftProjection, rightProjection,
			keyComparator, leftSorted, rightSorted, newBuffer(serializer2), filterNulls);
		joinIterator2 = semiIterator;
		while (semiIterator.nextInnerJoin()) {
			BaseRow probeRow = semiIterator.getProbeRow();
			// Emit the probe row once if at least one buffered row passes the condition.
			if (anyBufferedRowMatches(probeRow, semiIterator.getMatchBuffer())) {
				collector.collect(probeRow);
			}
		}
	} else if (type.equals(FlinkJoinRelType.ANTI)) {
		SortMergeOneSideOuterJoinIterator antiIterator = new SortMergeOneSideOuterJoinIterator(
			serializer1, serializer2, leftProjection, rightProjection,
			keyComparator, leftSorted, rightSorted, newBuffer(serializer2), filterNulls);
		joinIterator2 = antiIterator;
		while (antiIterator.nextOuterJoin()) {
			BaseRow probeRow = antiIterator.getProbeRow();
			ResettableExternalBuffer matchBuffer = antiIterator.getMatchBuffer();
			// A null match buffer counts as "no match".
			boolean matched = matchBuffer != null && anyBufferedRowMatches(probeRow, matchBuffer);
			if (!matched) {
				collector.collect(probeRow);
			}
		}
	} else {
		throw new RuntimeException("Not support yet!");
	}
}

/**
 * Scans the buffered rows and reports whether the join condition accepts the probe row
 * together with at least one of them; stops at the first hit.
 */
private boolean anyBufferedRowMatches(BaseRow probeRow, ResettableExternalBuffer buffer) throws Exception {
	try (ResettableExternalBuffer.BufferIterator bufferedRows = buffer.newIterator()) {
		while (bufferedRows.advanceNext()) {
			BaseRow candidate = bufferedRows.getRow();
			if (condFunc.apply(probeRow, candidate)) {
				return true;
			}
		}
	}
	return false;
}
/**
 * Allocates pages worth {@code externalBufferMemory} bytes from the memory manager and
 * wraps them in a new {@link ResettableExternalBuffer} using the given serializer.
 *
 * @throws MemoryAllocationException propagated from the page allocation
 */
private ResettableExternalBuffer newBuffer(BinaryRowSerializer serializer) throws MemoryAllocationException {
	// Integer division: a budget that is not a multiple of the page size is rounded down.
	int pageCount = (int) (externalBufferMemory / memManager.getPageSize());
	List<MemorySegment> pages = memManager.allocatePages(this.getContainingTask(), pageCount);
	return new ResettableExternalBuffer(memManager, ioManager, pages, serializer);
}
/**
 * Reports whether both inputs have signalled their end.
 */
private boolean isAllFinished() {
	for (boolean inputEnded : isFinished) {
		if (!inputEnded) {
			return false;
		}
	}
	return true;
}
/**
 * Closes the operator and releases both sorters and whichever join iterators were
 * created.
 *
 * <p>Every resource is closed even if an earlier step fails, so one failing
 * {@code close()} cannot leak the remaining sorters or match buffers. The first failure
 * is rethrown once all resources have been handled; later failures are attached to it
 * as suppressed exceptions.
 *
 * @throws Exception the first failure raised while closing, if any
 */
@Override
public void close() throws Exception {
	Exception failure = null;

	try {
		super.close();
	} catch (Exception e) {
		failure = e;
	}

	if (this.sorter1 != null) {
		try {
			this.sorter1.close();
		} catch (Exception e) {
			failure = accumulateCloseFailure(failure, e);
		}
	}
	if (this.sorter2 != null) {
		try {
			this.sorter2.close();
		} catch (Exception e) {
			failure = accumulateCloseFailure(failure, e);
		}
	}
	if (this.joinIterator1 != null) {
		try {
			this.joinIterator1.close();
		} catch (Exception e) {
			failure = accumulateCloseFailure(failure, e);
		}
	}
	if (this.joinIterator2 != null) {
		try {
			this.joinIterator2.close();
		} catch (Exception e) {
			failure = accumulateCloseFailure(failure, e);
		}
	}
	if (this.fullOuterJoinIterator != null) {
		try {
			this.fullOuterJoinIterator.close();
		} catch (Exception e) {
			failure = accumulateCloseFailure(failure, e);
		}
	}

	if (failure != null) {
		throw failure;
	}
}

/**
 * Keeps the first close failure as the primary exception and records any later one on it
 * as suppressed.
 */
private static Exception accumulateCloseFailure(Exception first, Exception next) {
	if (first == null) {
		return next;
	}
	// addSuppressed rejects self-suppression, so skip it when the same instance reappears.
	if (first != next) {
		first.addSuppressed(next);
	}
	return first;
}
/**
 * Immutable holder of the classes compiled from the operator's generated code: the join
 * condition, the key comparator, and per input a projection, a normalized key computer
 * and a record comparator.
 */
protected static class CookedClasses {

	/** Compiled join condition. */
	protected final Class<JoinConditionFunction> condFuncClass;

	/** Compiled comparator for the join keys. */
	protected final Class<RecordComparator> keyComparatorClass;

	/** Compiled projections for input 1 and input 2. */
	protected final Class<Projection> projectionClass1;
	protected final Class<Projection> projectionClass2;

	/** Compiled normalized key computers for the sorters of input 1 and input 2. */
	protected final Class<NormalizedKeyComputer> computerClass1;
	protected final Class<NormalizedKeyComputer> computerClass2;

	/** Compiled record comparators for the sorters of input 1 and input 2. */
	protected final Class<RecordComparator> comparatorClass1;
	protected final Class<RecordComparator> comparatorClass2;

	protected CookedClasses(
			Class<JoinConditionFunction> condFuncClass,
			Class<RecordComparator> keyComparatorClass,
			Class<Projection> projectionClass1,
			Class<Projection> projectionClass2,
			Class<NormalizedKeyComputer> computerClass1,
			Class<NormalizedKeyComputer> computerClass2,
			Class<RecordComparator> comparatorClass1,
			Class<RecordComparator> comparatorClass2) {
		this.condFuncClass = condFuncClass;
		this.keyComparatorClass = keyComparatorClass;

		this.projectionClass1 = projectionClass1;
		this.projectionClass2 = projectionClass2;

		this.computerClass1 = computerClass1;
		this.computerClass2 = computerClass2;

		this.comparatorClass1 = comparatorClass1;
		this.comparatorClass2 = comparatorClass2;
	}
}
}