blob: ca87b3f6028f0fbd1333a464800fa6ad4118151d [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to you under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.calcite.linq4j.test;
import org.apache.calcite.linq4j.Enumerable;
import org.apache.calcite.linq4j.Enumerator;
import org.apache.calcite.linq4j.ExtendedEnumerable;
import org.apache.calcite.linq4j.JoinType;
import org.apache.calcite.linq4j.Linq4j;
import org.apache.calcite.linq4j.function.Function1;
import org.apache.calcite.linq4j.function.Function2;
import com.google.common.collect.ImmutableList;
import org.junit.jupiter.api.Test;
import java.util.ArrayList;
import java.util.List;
import static org.junit.jupiter.api.Assertions.assertArrayEquals;
/**
 * Tests {@link ExtendedEnumerable#correlateJoin(JoinType, Function1, Function2)}.
 *
 * <p>Every test joins the fixed left input {@code 1, 2, 3, 10, 20, 30}
 * against a correlated right side that yields zero, one or two rows
 * depending on the left value (see {@link #testJoin}), and compares the
 * produced {@code {left, right}} pairs with a hard-coded expectation.
 */
class CorrelateJoinTest {
  /** Result selector that packs the left and right values into a two-element array. */
  static final Function2<Integer, Integer, Integer[]> SELECT_BOTH =
      (v0, v1) -> new Integer[]{v0, v1};

  /** Inner join: left values with an empty right side (1 and 10) are absent. */
  @Test void testInner() {
    testJoin(JoinType.INNER, new Integer[][]{
        {2, 20},
        {3, -30},
        {3, -60},
        {20, 200},
        {30, -300},
        {30, -600}});
  }

  /** Left join: left values with an empty right side are kept with a null right value. */
  @Test void testLeft() {
    testJoin(JoinType.LEFT, new Integer[][]{
        {1, null},
        {2, 20},
        {3, -30},
        {3, -60},
        {10, null},
        {20, 200},
        {30, -300},
        {30, -600}});
  }

  /**
   * Semi join: each left value with a non-empty right side is expected
   * exactly once, paired with a null right value.
   */
  @Test void testSemi() {
    testJoin(JoinType.SEMI, new Integer[][]{
        {2, null},
        {3, null},
        {20, null},
        {30, null}});
  }

  /**
   * Anti join: only left values with an empty right side are expected,
   * paired with a null right value.
   */
  @Test void testAnti() {
    testJoin(JoinType.ANTI, new Integer[][]{
        {1, null},
        {10, null}});
  }

  /**
   * Runs a correlate join of the given type over the fixed test input and
   * asserts that it yields exactly the expected rows, in order.
   *
   * <p>The correlated right side is:
   * <ul>
   * <li>empty for 1 and 10;
   * <li>the single value {@code a0 * 10} for 2 and 20;
   * <li>the two values {@code -a0 * 10} and {@code -a0 * 20} for 3 and 30.
   * </ul>
   *
   * @param joinType join type passed to {@code correlateJoin}
   * @param expected expected {@code {left, right}} rows, in iteration order
   */
  public void testJoin(JoinType joinType, Integer[][] expected) {
    Enumerable<Integer[]> join =
        Linq4j.asEnumerable(ImmutableList.of(1, 2, 3, 10, 20, 30))
            .correlateJoin(joinType, a0 -> {
              if (a0 == 1 || a0 == 10) {
                return Linq4j.emptyEnumerable();
              }
              if (a0 == 2 || a0 == 20) {
                return Linq4j.singletonEnumerable(a0 * 10);
              }
              if (a0 == 3 || a0 == 30) {
                return Linq4j.asEnumerable(
                    ImmutableList.of(-a0 * 10, -a0 * 20));
              }
              // Guards against the left input being changed without
              // updating this function.
              throw new IllegalArgumentException(
                  "Unexpected input " + a0);
            }, SELECT_BOTH);
    // Check two independent enumerators of the same enumerable; presumably
    // this verifies that the join result can be enumerated more than once.
    for (int i = 0; i < 2; i++) {
      // try-with-resources closes the enumerator even when the check fails;
      // previously close() was skipped if checkResults threw.
      try (Enumerator<Integer[]> e = join.enumerator()) {
        checkResults(e, expected);
      }
    }
  }

  /**
   * Drains the enumerator and asserts that the rows it produced equal
   * {@code expected}, element by element and in the same order.
   *
   * @param e enumerator to drain; it is not closed by this method
   * @param expected expected rows
   */
  private void checkResults(Enumerator<Integer[]> e, Integer[][] expected) {
    List<Integer[]> res = new ArrayList<>();
    while (e.moveNext()) {
      res.add(e.current());
    }
    Integer[][] actual = res.toArray(new Integer[res.size()][]);
    assertArrayEquals(expected, actual);
  }
}