blob: 1283be68ee99984b8eba3c688e48a490fd9104c7 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.flink.table.runtime.join.batch;
import org.apache.flink.configuration.ConfigConstants;
import org.apache.flink.runtime.io.disk.iomanager.IOManager;
import org.apache.flink.streaming.api.operators.TwoInputSelection;
import org.apache.flink.streaming.api.operators.TwoInputStreamOperator;
import org.apache.flink.streaming.runtime.streamrecord.StreamRecord;
import org.apache.flink.table.api.types.RowType;
import org.apache.flink.table.codegen.CodeGenUtils;
import org.apache.flink.table.codegen.GeneratedJoinConditionFunction;
import org.apache.flink.table.codegen.GeneratedProjection;
import org.apache.flink.table.codegen.JoinConditionFunction;
import org.apache.flink.table.codegen.Projection;
import org.apache.flink.table.dataformat.BaseRow;
import org.apache.flink.table.dataformat.BinaryRow;
import org.apache.flink.table.dataformat.GenericRow;
import org.apache.flink.table.dataformat.JoinedRow;
import org.apache.flink.table.runtime.AbstractStreamOperatorWithMetrics;
import org.apache.flink.table.runtime.join.batch.hashtable.BinaryHashTable;
import org.apache.flink.table.runtime.util.RowIterator;
import org.apache.flink.table.runtime.util.StreamRecordCollector;
import org.apache.flink.table.typeutils.AbstractRowSerializer;
import org.apache.flink.util.Collector;
import org.codehaus.commons.compiler.CompileException;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import java.io.Serializable;
import static org.apache.flink.util.Preconditions.checkNotNull;
/**
 * Shuffle hash join base operator.
 *
 * <p>The join operator implements the logic of a join operator at runtime. It uses a
 * hybrid-hash-join internally to match the records with equal key. The build side
 * of the hash is the first input of the match: records of input 1 are inserted into a
 * {@link BinaryHashTable}, and records of input 2 are probed against it.
 *
 * <p>The concrete join semantics (inner, outer, semi, anti, ...) are implemented by the
 * private subclasses created in {@link #newHashJoinOperator}.
 */
public abstract class HashJoinOperator extends AbstractStreamOperatorWithMetrics<BaseRow>
		implements TwoInputStreamOperator<BaseRow, BaseRow, BaseRow> {

	private static final Logger LOG = LoggerFactory.getLogger(HashJoinOperator.class);

	/** Serializable join configuration; its generated-code fields are cleared once compiled. */
	private final HashJoinParameter parameter;

	/** If true, output rows are assembled as (probe, build) instead of (build, probe). */
	private final boolean reverseJoinFunction;

	final HashJoinType type;

	// cooked classes, compiled from the generated code in open()
	transient Class<JoinConditionFunction> condFuncClass;
	transient Class<Projection<BaseRow, BinaryRow>> buildProjectionClass;
	transient Class<Projection<BaseRow, BinaryRow>> probeProjectionClass;

	/** The hybrid hash table; created in open() and released (then nulled) in close(). */
	private transient BinaryHashTable table;

	/** Forwards joined rows to the operator output. */
	transient Collector<BaseRow> collector;

	/** Row with the build side's arity, emitted in place of a missing build row (outer joins). */
	transient BaseRow buildSideNullRow;

	/** Row with the probe side's arity, emitted in place of a missing probe row (outer joins). */
	transient BaseRow probeSideNullRow;

	/** Single output wrapper reused for every emitted record via {@code replace}. */
	private transient JoinedRow joinedRow;

	HashJoinOperator(HashJoinParameter parameter) {
		this.parameter = parameter;
		this.type = parameter.type;
		this.reverseJoinFunction = parameter.reverseJoinFunction;
	}

	/**
	 * Compiles the generated classes, creates the {@link BinaryHashTable}, sets up the output
	 * helpers, and registers the memory/spill gauges.
	 */
	@Override
	public void open() throws Exception {
		super.open();

		// code gen function and projection classes.
		cookGeneratedClasses(getContainingTask().getUserCodeClassLoader());

		IOManager ioManager = getContainingTask().getEnvironment().getIOManager();

		final AbstractRowSerializer buildSerializer = (AbstractRowSerializer) getOperatorConfig()
				.getTypeSerializerIn1(getUserCodeClassloader());
		final AbstractRowSerializer probeSerializer = (AbstractRowSerializer) getOperatorConfig()
				.getTypeSerializerIn2(getUserCodeClassloader());

		boolean hashJoinUseBitMaps =
				getContainingTask().getEnvironment().getTaskConfiguration().getBoolean(
						ConfigConstants.RUNTIME_HASH_JOIN_BLOOM_FILTERS_KEY,
						ConfigConstants.DEFAULT_RUNTIME_HASH_JOIN_BLOOM_FILTERS);

		// buildRowCount is divided by the parallelism before being handed to the table;
		// presumably it is an estimate for the whole build input (TODO confirm with planner).
		int parallel = getRuntimeContext().getNumberOfParallelSubtasks();

		this.table = new BinaryHashTable(
				getSqlConf(),
				getContainingTask(),
				buildSerializer, probeSerializer,
				buildProjectionClass.newInstance(), probeProjectionClass.newInstance(),
				getContainingTask().getEnvironment().getMemoryManager(),
				parameter.reservedMemorySize,
				parameter.maxMemorySize,
				parameter.perRequestMemorySize,
				ioManager, parameter.buildRowSize, parameter.buildRowCount / parallel,
				hashJoinUseBitMaps, type,
				condFuncClass.newInstance(), reverseJoinFunction, parameter.filterNullKeys,
				parameter.tryDistinctBuildRow);

		this.collector = new StreamRecordCollector<>(output);

		this.buildSideNullRow = new GenericRow(buildSerializer.getNumFields());
		this.probeSideNullRow = new GenericRow(probeSerializer.getNumFields());
		this.joinedRow = new JoinedRow();

		getMetricGroup().gauge("memoryUsedSizeInBytes", table::getUsedMemoryInBytes);
		getMetricGroup().gauge("numSpillFiles", table::getNumSpillFiles);
		getMetricGroup().gauge("spillInBytes", table::getSpillInBytes);
	}

	/**
	 * Compiles the generated join condition and the build/probe projections.
	 *
	 * <p>The generated source held by {@link #parameter} is cleared afterwards. As a result,
	 * this method can only succeed once per operator instance.
	 *
	 * @param cl class loader used to compile and load the generated classes
	 * @throws CompileException if any of the generated sources fails to compile
	 */
	protected void cookGeneratedClasses(ClassLoader cl) throws CompileException {
		long startTime = System.currentTimeMillis();
		condFuncClass = CodeGenUtils.compile(
				cl, parameter.condFuncCode.name(), parameter.condFuncCode.code());
		buildProjectionClass = CodeGenUtils.compile(
				cl, parameter.buildProjectionCode.name(), parameter.buildProjectionCode.code());
		probeProjectionClass = CodeGenUtils.compile(
				cl, parameter.probeProjectionCode.name(), parameter.probeProjectionCode.code());
		// The source strings are no longer needed once compiled; drop the references.
		parameter.condFuncCode = null;
		parameter.buildProjectionCode = null;
		parameter.probeProjectionCode = null;
		long endTime = System.currentTimeMillis();
		LOG.info("Compiling generated codes, used time: {}ms.", endTime - startTime);
	}

	/** Starts with the build (first) input. */
	@Override
	public TwoInputSelection firstInputSelection() {
		return TwoInputSelection.FIRST;
	}

	/** Inserts a build-side record into the hash table and keeps reading the first input. */
	@Override
	public TwoInputSelection processElement1(StreamRecord<BaseRow> element) throws Exception {
		this.table.putBuildRow(element.getValue());
		return TwoInputSelection.FIRST;
	}

	/**
	 * Probes a record against the hash table. If the table reports that the record can be
	 * handled immediately ({@code tryProbe} returns true), it is joined right away. Otherwise,
	 * presumably the table defers it (e.g. to a spilled partition) and it is joined in
	 * {@link #endInput2()}.
	 */
	@Override
	public TwoInputSelection processElement2(StreamRecord<BaseRow> element) throws Exception {
		if (this.table.tryProbe(element.getValue())) {
			joinWithNextKey();
		}
		return TwoInputSelection.SECOND;
	}

	/** Marks the end of the build phase on the hash table. */
	@Override
	public void endInput1() throws Exception {
		LOG.info("Finish build phase.");
		this.table.endBuild();
	}

	/** Joins every remaining match the hash table produces after the probe input has ended. */
	@Override
	public void endInput2() throws Exception {
		LOG.info("Finish probe phase.");
		while (this.table.nextMatching()) {
			joinWithNextKey();
		}
		LOG.info("Finish rebuild phase.");
	}

	private void joinWithNextKey() throws Exception {
		// we have a next record, get the iterators to the probe and build side values
		join(table.getBuildSideIterator(), table.getCurrentProbeRow());
	}

	/**
	 * Applies the join semantics of the concrete subclass to one probe row and its build-side
	 * matches.
	 *
	 * @param buildIter iterator over the matching build rows; it is not yet advanced, so
	 *                  implementations must call {@code advanceNext()} before {@code getRow()}
	 * @param probeRow  the current probe row; {@code null} when the hash table delivers build
	 *                  rows without a probe counterpart
	 */
	public abstract void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception;

	/**
	 * Emits {@code probeRow} joined with the current build row and all remaining build rows.
	 * The iterator must already be positioned on its first row.
	 */
	void innerJoin(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
		collect(buildIter.getRow(), probeRow);
		while (buildIter.advanceNext()) {
			collect(buildIter.getRow(), probeRow);
		}
	}

	/**
	 * Emits the current build row and all remaining build rows, each padded with
	 * {@link #probeSideNullRow}. The iterator must already be positioned on its first row.
	 */
	void buildOuterJoin(RowIterator<BinaryRow> buildIter) throws Exception {
		collect(buildIter.getRow(), probeSideNullRow);
		while (buildIter.advanceNext()) {
			collect(buildIter.getRow(), probeSideNullRow);
		}
	}

	/**
	 * Emits one joined row. The pair is ordered (row1, row2), or (row2, row1) when
	 * {@link #reverseJoinFunction} is set.
	 */
	void collect(BaseRow row1, BaseRow row2) throws Exception {
		if (reverseJoinFunction) {
			collector.collect(joinedRow.replace(row2, row1));
		} else {
			collector.collect(joinedRow.replace(row1, row2));
		}
	}

	/**
	 * Closes the operator and releases the hash table.
	 *
	 * <p>The table is released in a {@code finally} block, so its memory and spill files are
	 * not leaked when {@code super.close()} throws. The field is cleared before release, so a
	 * second call cannot release the same table twice.
	 */
	@Override
	public void close() throws Exception {
		try {
			super.close();
		} finally {
			BinaryHashTable toRelease = this.table;
			this.table = null;
			if (toRelease != null) {
				toRelease.close();
				toRelease.free();
			}
		}
	}

	/**
	 * Creates the hash join operator subclass matching {@code type}.
	 *
	 * @param minMemorySize         memory reserved for the hash table (stored as
	 *                              {@code reservedMemorySize})
	 * @param maxMemorySize         upper bound on the hash table's memory
	 * @param eachRequestMemorySize memory requested per allocation step (stored as
	 *                              {@code perRequestMemorySize})
	 * @param type                  join type that selects the subclass
	 * @param condFuncCode          generated join condition
	 * @param reverseJoinFunction   whether output rows are ordered (probe, build)
	 * @param filterNullKeys        per-key null filtering flags, passed to the hash table
	 * @param buildProjectionCode   generated build-side key projection
	 * @param probeProjectionCode   generated probe-side key projection
	 * @param tryDistinctBuildRow   passed through to the hash table
	 * @param buildRowSize          build row size estimate, passed to the hash table
	 * @param buildRowCount         build row count estimate; divided by parallelism in open()
	 * @param probeRowCount         stored in the parameter; not read by this class
	 * @param keyType               stored in the parameter; not read by this class
	 * @return a new operator implementing the requested join type
	 * @throws IllegalArgumentException if {@code type} is not supported
	 */
	public static HashJoinOperator newHashJoinOperator(
			long minMemorySize,
			long maxMemorySize,
			long eachRequestMemorySize,
			HashJoinType type,
			GeneratedJoinConditionFunction condFuncCode,
			boolean reverseJoinFunction,
			boolean[] filterNullKeys,
			GeneratedProjection buildProjectionCode,
			GeneratedProjection probeProjectionCode,
			boolean tryDistinctBuildRow,
			int buildRowSize,
			long buildRowCount,
			long probeRowCount,
			RowType keyType) {
		HashJoinParameter parameter = new HashJoinParameter(minMemorySize, maxMemorySize, eachRequestMemorySize,
				type, condFuncCode, reverseJoinFunction, filterNullKeys, buildProjectionCode, probeProjectionCode,
				tryDistinctBuildRow, buildRowSize, buildRowCount, probeRowCount, keyType);
		switch (type) {
			case INNER:
				return new InnerHashJoinOperator(parameter);
			case BUILD_OUTER:
				return new BuildOuterHashJoinOperator(parameter);
			case PROBE_OUTER:
				return new ProbeOuterHashJoinOperator(parameter);
			case FULL_OUTER:
				return new FullOuterHashJoinOperator(parameter);
			case SEMI:
				return new SemiHashJoinOperator(parameter);
			case ANTI:
				return new AntiHashJoinOperator(parameter);
			case BUILD_LEFT_SEMI:
			case BUILD_LEFT_ANTI:
				return new BuildLeftSemiOrAntiHashJoinOperator(parameter);
			default:
				throw new IllegalArgumentException("invalid: " + type);
		}
	}

	/**
	 * Serializable bundle of the settings shipped with the operator. The generated-code fields
	 * are set to {@code null} by {@link HashJoinOperator#cookGeneratedClasses} once compiled.
	 */
	static class HashJoinParameter implements Serializable {
		long reservedMemorySize;
		long maxMemorySize;
		long perRequestMemorySize;
		HashJoinType type;
		GeneratedJoinConditionFunction condFuncCode;
		boolean reverseJoinFunction;
		boolean[] filterNullKeys;
		GeneratedProjection buildProjectionCode;
		GeneratedProjection probeProjectionCode;
		boolean tryDistinctBuildRow;
		int buildRowSize;
		long buildRowCount;
		// not read by HashJoinOperator itself
		long probeRowCount;
		// not read by HashJoinOperator itself
		RowType keyType;

		HashJoinParameter(
				long reservedMemorySize, long maxMemorySize, long perRequestMemorySize, HashJoinType type,
				GeneratedJoinConditionFunction condFuncCode, boolean reverseJoinFunction,
				boolean[] filterNullKeys,
				GeneratedProjection buildProjectionCode,
				GeneratedProjection probeProjectionCode, boolean tryDistinctBuildRow,
				int buildRowSize, long buildRowCount, long probeRowCount, RowType keyType) {
			this.reservedMemorySize = reservedMemorySize;
			this.maxMemorySize = maxMemorySize;
			this.perRequestMemorySize = perRequestMemorySize;
			this.type = type;
			this.condFuncCode = condFuncCode;
			this.reverseJoinFunction = reverseJoinFunction;
			this.filterNullKeys = filterNullKeys;
			this.buildProjectionCode = buildProjectionCode;
			this.probeProjectionCode = probeProjectionCode;
			this.tryDistinctBuildRow = tryDistinctBuildRow;
			this.buildRowSize = buildRowSize;
			this.buildRowCount = buildRowCount;
			this.probeRowCount = probeRowCount;
			this.keyType = keyType;
		}
	}

	/**
	 * Inner join: emits one row per (build match, probe row) pair. Emits nothing when there is
	 * no match or no probe row.
	 */
	private static class InnerHashJoinOperator extends HashJoinOperator {

		InnerHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			if (buildIter.advanceNext()) {
				if (probeRow != null) {
					innerJoin(buildIter, probeRow);
				}
			}
		}
	}

	/**
	 * BuildOuter join: matched pairs are emitted as in an inner join. Build rows delivered
	 * without a probe row are padded with {@code probeSideNullRow}.
	 */
	private static class BuildOuterHashJoinOperator extends HashJoinOperator {

		BuildOuterHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			if (buildIter.advanceNext()) {
				if (probeRow != null) {
					innerJoin(buildIter, probeRow);
				} else {
					buildOuterJoin(buildIter);
				}
			}
		}
	}

	/**
	 * ProbeOuter join: matched pairs are emitted as in an inner join. A probe row without any
	 * build match is padded with {@code buildSideNullRow}.
	 */
	private static class ProbeOuterHashJoinOperator extends HashJoinOperator {

		ProbeOuterHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			if (buildIter.advanceNext()) {
				if (probeRow != null) {
					innerJoin(buildIter, probeRow);
				}
			} else if (probeRow != null) {
				collect(buildSideNullRow, probeRow);
			}
		}
	}

	/**
	 * FullOuter join: combines the BuildOuter and ProbeOuter behaviors. Unmatched rows on
	 * either side are padded with the other side's null row.
	 */
	private static class FullOuterHashJoinOperator extends HashJoinOperator {

		FullOuterHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			if (buildIter.advanceNext()) {
				if (probeRow != null) {
					innerJoin(buildIter, probeRow);
				} else {
					buildOuterJoin(buildIter);
				}
			} else if (probeRow != null) {
				collect(buildSideNullRow, probeRow);
			}
		}
	}

	/**
	 * Semi join: emits the probe row once if it has at least one build match. The probe row
	 * must be non-null.
	 */
	private static class SemiHashJoinOperator extends HashJoinOperator {

		SemiHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			checkNotNull(probeRow);
			if (buildIter.advanceNext()) {
				collector.collect(probeRow);
			}
		}
	}

	/**
	 * Anti join: emits the probe row if it has no build match. The probe row must be non-null.
	 */
	private static class AntiHashJoinOperator extends HashJoinOperator {

		AntiHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			checkNotNull(probeRow);
			if (!buildIter.advanceNext()) {
				collector.collect(probeRow);
			}
		}
	}

	/**
	 * BuildLeftSemiOrAnti join: emits build-side rows only.
	 *
	 * <p>During the probe phase, matches are consumed without output. After the probe phase,
	 * the build rows the table delivers with a {@code null} probe row are emitted. Choosing
	 * which build rows those are (semi vs. anti) is presumably done by the hash table
	 * according to {@code type} (verify in {@link BinaryHashTable}).
	 */
	private static class BuildLeftSemiOrAntiHashJoinOperator extends HashJoinOperator {

		BuildLeftSemiOrAntiHashJoinOperator(HashJoinParameter parameter) {
			super(parameter);
		}

		@Override
		public void join(RowIterator<BinaryRow> buildIter, BaseRow probeRow) throws Exception {
			if (buildIter.advanceNext()) {
				if (probeRow != null) {
					// Probe phase: the iterator must be fully drained so that the hash table
					// records the matched build rows (its probed set).
					while (buildIter.advanceNext()) {
						// intentionally empty: iteration itself is the side effect
					}
				} else {
					// End of probe phase: emit the build-side rows handed to us.
					collector.collect(buildIter.getRow());
					while (buildIter.advanceNext()) {
						collector.collect(buildIter.getRow());
					}
				}
			}
		}
	}
}