blob: a63f0ecb5a4d6129e550ac5fb7912ae391b138c1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.druid.math.expr;
import com.google.common.collect.ImmutableMap;
import org.apache.druid.common.config.NullHandling;
import org.apache.druid.testing.InitializedNullHandlingTest;
import org.junit.Assert;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
public class ApplyFunctionTest extends InitializedNullHandlingTest
{
private Expr.ObjectBinding bindings;
@Rule
public ExpectedException expectedException = ExpectedException.none();
@Before
public void setup()
{
ImmutableMap.Builder<String, Object> builder = ImmutableMap.builder();
builder.put("x", "foo");
builder.put("y", 2);
builder.put("z", 3.1);
builder.put("a", new String[] {"foo", "bar", "baz", "foobar"});
builder.put("b", new Long[] {1L, 2L, 3L, 4L, 5L});
builder.put("c", new Double[] {3.1, 4.2, 5.3});
builder.put("d", new String[] {null});
builder.put("e", new String[] {null, "foo", "bar"});
builder.put("f", new String[0]);
bindings = InputBindings.withMap(builder.build());
}
@Test
public void testMap()
{
assertExpr("map((x) -> concat(x, 'foo'), ['foo', 'bar', 'baz', 'foobar'])", new String[] {"foofoo", "barfoo", "bazfoo", "foobarfoo"});
assertExpr("map((x) -> concat(x, 'foo'), a)", new String[] {"foofoo", "barfoo", "bazfoo", "foobarfoo"});
assertExpr("map((x) -> x + 1, [1, 2, 3, 4, 5])", new Long[] {2L, 3L, 4L, 5L, 6L});
assertExpr("map((x) -> x + 1, b)", new Long[] {2L, 3L, 4L, 5L, 6L});
assertExpr("map((c) -> c + z, [3.1, 4.2, 5.3])", new Double[]{6.2, 7.3, 8.4});
assertExpr("map((c) -> c + z, c)", new Double[]{6.2, 7.3, 8.4});
assertExpr("map((x) -> x + 1, map((x) -> x + 1, [1, 2, 3, 4, 5]))", new Long[] {3L, 4L, 5L, 6L, 7L});
assertExpr("map((x) -> x + 1, map((x) -> x + 1, b))", new Long[] {3L, 4L, 5L, 6L, 7L});
assertExpr("map(() -> 1, [1, 2, 3, 4, 5])", new Long[] {1L, 1L, 1L, 1L, 1L});
}
@Test
public void testCartesianMap()
{
assertExpr("cartesian_map((x, y) -> concat(x, y), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'])", new String[] {"foobar", "foobaz", "barbar", "barbaz", "bazbar", "bazbaz", "foobarbar", "foobarbaz"});
assertExpr("cartesian_map((x, y, z) -> concat(concat(x, y), z), ['foo', 'bar', 'baz', 'foobar'], ['bar', 'baz'], ['omg'])", new String[] {"foobaromg", "foobazomg", "barbaromg", "barbazomg", "bazbaromg", "bazbazomg", "foobarbaromg", "foobarbazomg"});
assertExpr("cartesian_map(() -> 1, [1, 2], [1, 2, 3])", new Long[] {1L, 1L, 1L, 1L, 1L, 1L});
assertExpr("cartesian_map((x, y) -> concat(x, y), d, d)", new String[] {null});
assertExpr("cartesian_map((x, y) -> concat(x, y), d, f)", new String[0]);
if (NullHandling.replaceWithDefault()) {
assertExpr("cartesian_map((x, y) -> concat(x, y), d, e)", new String[]{null, "foo", "bar"});
assertExpr("cartesian_map((x, y) -> concat(x, y), e, e)", new String[] {null, "foo", "bar", "foo", "foofoo", "foobar", "bar", "barfoo", "barbar"});
} else {
assertExpr("cartesian_map((x, y) -> concat(x, y), d, e)", new String[]{null, null, null});
assertExpr("cartesian_map((x, y) -> concat(x, y), e, e)", new String[] {null, null, null, null, "foofoo", "foobar", null, "barfoo", "barbar"});
}
}
@Test
public void testFilter()
{
assertExpr("filter((x) -> strlen(x) > 3, ['foo', 'bar', 'baz', 'foobar'])", new String[] {"foobar"});
assertExpr("filter((x) -> strlen(x) > 3, a)", new String[] {"foobar"});
assertExpr("filter((x) -> x > 2, [1, 2, 3, 4, 5])", new Long[] {3L, 4L, 5L});
assertExpr("filter((x) -> x > 2, b)", new Long[] {3L, 4L, 5L});
}
@Test
public void testFold()
{
assertExpr("fold((x, y) -> x + y, [1, 1, 1, 1, 1], 0)", 5L);
assertExpr("fold((b, acc) -> b * acc, map((b) -> b * 2, filter(b -> b > 3, b)), 1)", 80L);
assertExpr("fold((a, acc) -> concat(a, acc), a, '')", "foobarbazbarfoo");
assertExpr("fold((a, acc) -> array_append(acc, a), a, [])", new String[]{"foo", "bar", "baz", "foobar"});
assertExpr("fold((a, acc) -> array_append(acc, a), b, <LONG>[])", new Long[]{1L, 2L, 3L, 4L, 5L});
}
@Test
public void testCartesianFold()
{
assertExpr("cartesian_fold((x, y, acc) -> x + y + acc, [1, 1, 1, 1, 1], [1, 1], 0)", 20L);
}
@Test
public void testAnyMatch()
{
assertExpr("any(x -> x > 3, [1, 2, 3, 4])", 1L);
assertExpr("any(x -> x > 3, [1, 2, 3])", 0L);
assertExpr("any(x -> x, map(x -> x > 3, [1, 2, 3, 4]))", 1L);
assertExpr("any(x -> x, map(x -> x > 3, [1, 2, 3]))", 0L);
}
@Test
public void testAllMatch()
{
assertExpr("all(x -> x > 0, [1, 2, 3, 4])", 1L);
assertExpr("all(x -> x > 1, [1, 2, 3, 4])", 0L);
assertExpr("all(x -> x, map(x -> x > 0, [1, 2, 3, 4]))", 1L);
assertExpr("all(x -> x, map(x -> x > 1, [1, 2, 3, 4]))", 0L);
}
@Test
public void testScoping()
{
assertExpr("map(b -> b + 1, b)", new Long[]{2L, 3L, 4L, 5L, 6L});
assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), 0)", 20L);
assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), fold((b, acc) -> acc + b, map(b -> b + 1, b), 0))", 40L);
assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), 0) + fold((b, acc) -> acc + b, map(b -> b + 1, b), 0)", 40L);
assertExpr("fold((b, acc) -> acc + b, map(b -> b + 1, b), fold((b, acc) -> acc + b, map(b -> b + 1, b), 0) + fold((b, acc) -> acc + b, map(b -> b + 1, b), 0))", 60L);
}
@Test
public void testInvalidArgCount()
{
expectedException.expect(IllegalArgumentException.class);
expectedException.expectMessage("lambda expression argument count does not match fold argument count");
assertExpr("fold(() -> 1, [1, 1, 1, 1, 1], 0)", null);
expectedException.expectMessage("lambda expression argument count does not match cartesian_fold argument count");
assertExpr("cartesian_fold(() -> 1, [1, 1, 1, 1, 1], [1, 1], 0)", null);
expectedException.expectMessage("lambda expression argument count does not match any argument count");
assertExpr("any(() -> 1, [1, 2, 3, 4])", null);
expectedException.expectMessage("lambda expression argument count does not match all argument count");
assertExpr("all(() -> 0, [1, 2, 3, 4])", null);
}
private void assertExpr(final String expression, final Object expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
Assert.assertEquals(expression, expectedResult, expr.eval(bindings).value());
final Expr exprNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), ExprMacroTable.nil());
Assert.assertEquals(expr.stringify(), expectedResult, roundTrip.eval(bindings).value());
final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil());
Assert.assertEquals(expr.stringify(), expectedResult, roundTripFlatten.eval(bindings).value());
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
Assert.assertArrayEquals(expr.getCacheKey(), roundTripFlatten.getCacheKey());
}
private void assertExpr(final String expression, final Object[] expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
final Object[] result = expr.eval(bindings).asArray();
if (expectedResult.length != 0 || result == null || result.length != 0) {
Assert.assertArrayEquals(expression, expectedResult, result);
}
final Expr exprNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), ExprMacroTable.nil());
final Object[] resultRoundTrip = roundTrip.eval(bindings).asArray();
if (expectedResult.length != 0 || resultRoundTrip == null || resultRoundTrip.length != 0) {
Assert.assertArrayEquals(expr.stringify(), expectedResult, resultRoundTrip);
}
final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil());
final Object[] resultRoundTripFlatten = roundTripFlatten.eval(bindings).asArray();
if (expectedResult.length != 0 || resultRoundTripFlatten == null || resultRoundTripFlatten.length != 0) {
Assert.assertArrayEquals(expr.stringify(), expectedResult, resultRoundTripFlatten);
}
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
Assert.assertArrayEquals(expr.getCacheKey(), roundTripFlatten.getCacheKey());
}
private void assertExpr(final String expression, final Double[] expectedResult)
{
final Expr expr = Parser.parse(expression, ExprMacroTable.nil());
Double[] result = (Double[]) expr.eval(bindings).value();
Assert.assertEquals(expectedResult.length, result.length);
for (int i = 0; i < result.length; i++) {
Assert.assertEquals(expression, expectedResult[i], result[i], 0.00001); // something is lame somewhere..
}
final Expr exprNoFlatten = Parser.parse(expression, ExprMacroTable.nil(), false);
final Expr roundTrip = Parser.parse(exprNoFlatten.stringify(), ExprMacroTable.nil());
Double[] resultRoundTrip = (Double[]) roundTrip.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTrip.length);
for (int i = 0; i < resultRoundTrip.length; i++) {
Assert.assertEquals(expression, expectedResult[i], resultRoundTrip[i], 0.00001);
}
final Expr roundTripFlatten = Parser.parse(expr.stringify(), ExprMacroTable.nil());
Double[] resultRoundTripFlatten = (Double[]) roundTripFlatten.eval(bindings).value();
Assert.assertEquals(expectedResult.length, resultRoundTripFlatten.length);
for (int i = 0; i < resultRoundTripFlatten.length; i++) {
Assert.assertEquals(expression, expectedResult[i], resultRoundTripFlatten[i], 0.00001);
}
Assert.assertEquals(expr.stringify(), roundTrip.stringify());
Assert.assertEquals(expr.stringify(), roundTripFlatten.stringify());
Assert.assertArrayEquals(expr.getCacheKey(), roundTrip.getCacheKey());
Assert.assertArrayEquals(expr.getCacheKey(), roundTripFlatten.getCacheKey());
}
}