| /** |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| package org.apache.pinot.core.query.optimizer; |
| |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.Set; |
| import org.apache.pinot.common.request.Expression; |
| import org.apache.pinot.common.request.Function; |
| import org.apache.pinot.common.request.PinotQuery; |
| import org.apache.pinot.common.utils.request.RequestUtils; |
| import org.apache.pinot.spi.data.FieldSpec.DataType; |
| import org.apache.pinot.spi.data.Schema; |
| import org.apache.pinot.spi.utils.CommonConstants.Query.Range; |
| import org.apache.pinot.sql.FilterKind; |
| import org.apache.pinot.sql.parsers.CalciteSqlParser; |
| import org.testng.annotations.Test; |
| |
| import static org.testng.Assert.*; |
| |
| |
| public class QueryOptimizerTest { |
| private static final QueryOptimizer OPTIMIZER = new QueryOptimizer(); |
| private static final Schema SCHEMA = |
| new Schema.SchemaBuilder().setSchemaName("testTable").addSingleValueDimension("int", DataType.INT) |
| .addSingleValueDimension("long", DataType.LONG).addSingleValueDimension("float", DataType.FLOAT) |
| .addSingleValueDimension("double", DataType.DOUBLE).addSingleValueDimension("string", DataType.STRING) |
| .addSingleValueDimension("bytes", DataType.BYTES).addMultiValueDimension("mvInt", DataType.INT).build(); |
| |
| @Test |
| public void testNoFilter() { |
| String query = "SELECT * FROM testTable"; |
| PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query); |
| OPTIMIZER.optimize(pinotQuery, SCHEMA); |
| assertNull(pinotQuery.getFilterExpression()); |
| } |
| |
| @Test |
| public void testFlattenAndOrFilter() { |
| String query = |
| "SELECT * FROM testTable WHERE ((int = 4 OR (long = 5 AND (float = 9 AND double = 7.5))) OR string = 'foo') " |
| + "OR bytes = 'abc'"; |
| PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query); |
| OPTIMIZER.optimize(pinotQuery, SCHEMA); |
| Function filterFunction = pinotQuery.getFilterExpression().getFunctionCall(); |
| assertEquals(filterFunction.getOperator(), FilterKind.OR.name()); |
| List<Expression> children = filterFunction.getOperands(); |
| assertEquals(children.size(), 4); |
| assertEquals(children.get(0), getEqFilterExpression("int", 4)); |
| assertEquals(children.get(2), getEqFilterExpression("string", "foo")); |
| assertEquals(children.get(3), getEqFilterExpression("bytes", "abc")); |
| |
| Function secondChildFunction = children.get(1).getFunctionCall(); |
| assertEquals(secondChildFunction.getOperator(), FilterKind.AND.name()); |
| List<Expression> secondChildChildren = secondChildFunction.getOperands(); |
| assertEquals(secondChildChildren.size(), 3); |
| assertEquals(secondChildChildren.get(0), getEqFilterExpression("long", 5L)); |
| assertEquals(secondChildChildren.get(1), getEqFilterExpression("float", 9f)); |
| assertEquals(secondChildChildren.get(2), getEqFilterExpression("double", 7.5)); |
| } |
| |
| private static Expression getEqFilterExpression(String column, Object value) { |
| Expression eqFilterExpression = RequestUtils.getFunctionExpression(FilterKind.EQUALS.name()); |
| eqFilterExpression.getFunctionCall().setOperands( |
| Arrays.asList(RequestUtils.getIdentifierExpression(column), RequestUtils.getLiteralExpression(value))); |
| return eqFilterExpression; |
| } |
| |
| @Test |
| public void testMergeEqInFilter() { |
| String query = |
| "SELECT * FROM testTable WHERE int IN (1, 1) AND (long IN (2, 3) OR long IN (3, 4) OR long = 2) AND (float = " |
| + "3.5 OR double IN (1.1, 1.2) OR float = 4.5 OR float > 5.5 OR double = 1.3)"; |
| PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query); |
| OPTIMIZER.optimize(pinotQuery, SCHEMA); |
| Function filterFunction = pinotQuery.getFilterExpression().getFunctionCall(); |
| assertEquals(filterFunction.getOperator(), FilterKind.AND.name()); |
| List<Expression> children = filterFunction.getOperands(); |
| assertEquals(children.size(), 3); |
| assertEquals(children.get(0), getEqFilterExpression("int", 1)); |
| checkInFilterFunction(children.get(1).getFunctionCall(), "long", Arrays.asList(2L, 3L, 4L)); |
| |
| Function thirdChildFunction = children.get(2).getFunctionCall(); |
| assertEquals(thirdChildFunction.getOperator(), FilterKind.OR.name()); |
| List<Expression> thirdChildChildren = thirdChildFunction.getOperands(); |
| assertEquals(thirdChildChildren.size(), 3); |
| assertEquals(thirdChildChildren.get(0).getFunctionCall().getOperator(), FilterKind.GREATER_THAN.name()); |
| |
| // Order of second and third child is not deterministic |
| Function secondGrandChildFunction = thirdChildChildren.get(1).getFunctionCall(); |
| assertEquals(secondGrandChildFunction.getOperator(), FilterKind.IN.name()); |
| Function thirdGrandChildFunction = thirdChildChildren.get(2).getFunctionCall(); |
| assertEquals(thirdGrandChildFunction.getOperator(), FilterKind.IN.name()); |
| if (secondGrandChildFunction.getOperands().get(0).getIdentifier().getName().equals("float")) { |
| checkInFilterFunction(secondGrandChildFunction, "float", Arrays.asList(3.5, 4.5)); |
| checkInFilterFunction(thirdGrandChildFunction, "double", Arrays.asList(1.1, 1.2, 1.3)); |
| } else { |
| checkInFilterFunction(secondGrandChildFunction, "double", Arrays.asList(1.1, 1.2, 1.3)); |
| checkInFilterFunction(thirdGrandChildFunction, "float", Arrays.asList(3.5, 4.5)); |
| } |
| } |
| |
| private static void checkInFilterFunction(Function inFilterFunction, String column, List<Object> values) { |
| assertEquals(inFilterFunction.getOperator(), FilterKind.IN.name()); |
| List<Expression> operands = inFilterFunction.getOperands(); |
| int numOperands = operands.size(); |
| assertEquals(numOperands, values.size() + 1); |
| assertEquals(operands.get(0).getIdentifier().getName(), column); |
| Set<Expression> valueExpressions = new HashSet<>(); |
| for (Object value : values) { |
| valueExpressions.add(RequestUtils.getLiteralExpression(value)); |
| } |
| for (int i = 1; i < numOperands; i++) { |
| assertTrue(valueExpressions.contains(operands.get(i))); |
| } |
| } |
| |
| @Test |
| public void testMergeRangeFilter() { |
| String query = |
| "SELECT * FROM testTable WHERE (int > 10 AND int <= 100 AND int BETWEEN 10 AND 20) OR (float BETWEEN 5.5 AND " |
| + "7.5 AND float = 6 AND float < 6.5 AND float BETWEEN 6 AND 8) OR (string > '123' AND string > '23') OR " |
| + "(mvInt > 5 AND mvInt < 0)"; |
| PinotQuery pinotQuery = CalciteSqlParser.compileToPinotQuery(query); |
| OPTIMIZER.optimize(pinotQuery, SCHEMA); |
| Function filterFunction = pinotQuery.getFilterExpression().getFunctionCall(); |
| assertEquals(filterFunction.getOperator(), FilterKind.OR.name()); |
| List<Expression> operands = filterFunction.getOperands(); |
| assertEquals(operands.size(), 4); |
| assertEquals(operands.get(0), getRangeFilterExpression("int", "(10\00020]")); |
| // Alphabetical order for STRING column ('23' > '123') |
| assertEquals(operands.get(2), getRangeFilterExpression("string", "(23\000*)")); |
| |
| Function secondChildFunction = operands.get(1).getFunctionCall(); |
| assertEquals(secondChildFunction.getOperator(), FilterKind.AND.name()); |
| List<Expression> secondChildChildren = secondChildFunction.getOperands(); |
| assertEquals(secondChildChildren.size(), 2); |
| assertEquals(secondChildChildren.get(0), getEqFilterExpression("float", 6f)); |
| assertEquals(secondChildChildren.get(1), getRangeFilterExpression("float", "[6.0\0006.5)")); |
| |
| // Range filter on multi-value column should not be merged ([-5, 10] can match this filter) |
| Function fourthChildFunction = operands.get(3).getFunctionCall(); |
| assertEquals(fourthChildFunction.getOperator(), FilterKind.AND.name()); |
| List<Expression> fourthChildChildren = fourthChildFunction.getOperands(); |
| assertEquals(fourthChildChildren.size(), 2); |
| assertEquals(fourthChildChildren.get(0).getFunctionCall().getOperator(), FilterKind.GREATER_THAN.name()); |
| assertEquals(fourthChildChildren.get(1).getFunctionCall().getOperator(), FilterKind.LESS_THAN.name()); |
| } |
| |
| private static Expression getRangeFilterExpression(String column, String rangeString) { |
| Expression rangeFilterExpression = RequestUtils.getFunctionExpression(FilterKind.RANGE.name()); |
| rangeFilterExpression.getFunctionCall().setOperands( |
| Arrays.asList(RequestUtils.getIdentifierExpression(column), RequestUtils.getLiteralExpression(rangeString))); |
| return rangeFilterExpression; |
| } |
| |
| @Test |
| public void testQueries() { |
| // MergeEqInFilter |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3", |
| "SELECT * FROM testTable WHERE int IN (1, 2, 3)"); |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3 AND long = 4", |
| "SELECT * FROM testTable WHERE int IN (1, 2) OR (int = 3 AND long = 4)"); |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 2 OR int = 3 OR long = 4 OR long = 5 OR long = 6", |
| "SELECT * FROM testTable WHERE int IN (1, 2, 3) OR long IN (4, 5, 6)"); |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR long = 4 OR int = 2 OR long = 5 OR int = 3 OR long = 6", |
| "SELECT * FROM testTable WHERE int IN (1, 2, 3) OR long IN (4, 5, 6)"); |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR int = 1", "SELECT * FROM testTable WHERE int = 1"); |
| testQuery("SELECT * FROM testTable WHERE (int = 1 OR int = 1) AND long = 2", |
| "SELECT * FROM testTable WHERE int = 1 AND long = 2"); |
| testQuery("SELECT * FROM testTable WHERE int = 1 OR int IN (2, 3, 4, 5)", |
| "SELECT * FROM testTable WHERE int IN (1, 2, 3, 4, 5)"); |
| testQuery("SELECT * FROM testTable WHERE int IN (1, 1) OR int = 1", "SELECT * FROM testTable WHERE int = 1"); |
| testQuery("SELECT * FROM testTable WHERE string = 'foo' OR string = 'bar' OR string = 'foobar'", |
| "SELECT * FROM testTable WHERE string IN ('foo', 'bar', 'foobar')"); |
| testQuery("SELECT * FROM testTable WHERE bytes = 'dead' OR bytes = 'beef' OR bytes = 'deadbeef'", |
| "SELECT * FROM testTable WHERE bytes IN ('dead', 'beef', 'deadbeef')"); |
| |
| // MergeRangeFilter |
| testQuery("SELECT * FROM testTable WHERE int >= 10 AND int <= 20", |
| "SELECT * FROM testTable WHERE int BETWEEN 10 AND 20"); |
| testQuery("SELECT * FROM testTable WHERE int BETWEEN 10 AND 20 AND int > 7 AND int <= 17 OR int > 20", |
| "SELECT * FROM testTable WHERE int BETWEEN 10 AND 17 OR int > 20"); |
| testQuery("SELECT * FROM testTable WHERE long BETWEEN 10 AND 20 AND long > 7 AND long <= 17 OR long > 20", |
| "SELECT * FROM testTable WHERE long BETWEEN 10 AND 17 OR long > 20"); |
| testQuery("SELECT * FROM testTable WHERE float BETWEEN 10.5 AND 20 AND float > 7 AND float <= 17.5 OR float > 20", |
| "SELECT * FROM testTable WHERE float BETWEEN 10.5 AND 17.5 OR float > 20"); |
| testQuery( |
| "SELECT * FROM testTable WHERE double BETWEEN 10.5 AND 20 AND double > 7 AND double <= 17.5 OR double > 20", |
| "SELECT * FROM testTable WHERE double BETWEEN 10.5 AND 17.5 OR double > 20"); |
| testQuery( |
| "SELECT * FROM testTable WHERE string BETWEEN '10' AND '20' AND string > '7' AND string <= '17' OR string > " |
| + "'20'", "SELECT * FROM testTable WHERE string > '7' AND string <= '17' OR string > '20'"); |
| testQuery( |
| "SELECT * FROM testTable WHERE bytes BETWEEN '10' AND '20' AND bytes > '07' AND bytes <= '17' OR bytes > '20'", |
| "SELECT * FROM testTable WHERE bytes BETWEEN '10' AND '17' OR bytes > '20'"); |
| testQuery( |
| "SELECT * FROM testTable WHERE int > 10 AND long > 20 AND int <= 30 AND long <= 40 AND int >= 15 AND long >= " |
| + "25", "SELECT * FROM testTable WHERE int BETWEEN 15 AND 30 AND long BETWEEN 25 AND 40"); |
| testQuery("SELECT * FROM testTable WHERE int > 10 AND int > 20 OR int < 30 AND int < 40", |
| "SELECT * FROM testTable WHERE int > 20 OR int < 30"); |
| testQuery("SELECT * FROM testTable WHERE int > 10 AND int > 20 OR long < 30 AND long < 40", |
| "SELECT * FROM testTable WHERE int > 20 OR long < 30"); |
| |
| // Mixed |
| testQuery( |
| "SELECT * FROM testTable WHERE int >= 20 AND (int > 10 AND (int IN (1, 2) OR (int = 2 OR int = 3)) AND int <=" |
| + " 30)", "SELECT * FROM testTable WHERE int BETWEEN 20 AND 30 AND int IN (1, 2, 3)"); |
| } |
| |
| private static void testQuery(String actual, String expected) { |
| PinotQuery actualPinotQuery = CalciteSqlParser.compileToPinotQuery(actual); |
| OPTIMIZER.optimize(actualPinotQuery, SCHEMA); |
| // Also optimize the expected query because the expected range can only be generate via optimizer |
| PinotQuery expectedPinotQuery = CalciteSqlParser.compileToPinotQuery(expected); |
| OPTIMIZER.optimize(expectedPinotQuery, SCHEMA); |
| comparePinotQuery(actualPinotQuery, expectedPinotQuery); |
| } |
| |
| private static void comparePinotQuery(PinotQuery actual, PinotQuery expected) { |
| if (expected.getFilterExpression() == null) { |
| assertNull(actual.getFilterExpression()); |
| return; |
| } |
| compareFilterExpression(actual.getFilterExpression(), expected.getFilterExpression()); |
| } |
| |
| private static void compareFilterExpression(Expression actual, Expression expected) { |
| Function actualFilterFunction = actual.getFunctionCall(); |
| Function expectedFilterFunction = expected.getFunctionCall(); |
| FilterKind actualFilterKind = FilterKind.valueOf(actualFilterFunction.getOperator()); |
| FilterKind expectedFilterKind = FilterKind.valueOf(expectedFilterFunction.getOperator()); |
| List<Expression> actualOperands = actualFilterFunction.getOperands(); |
| List<Expression> expectedOperands = expectedFilterFunction.getOperands(); |
| if (!actualFilterKind.isRange()) { |
| assertEquals(actualFilterKind, expectedFilterKind); |
| assertEquals(actualOperands.size(), expectedOperands.size()); |
| if (actualFilterKind == FilterKind.AND || actualFilterKind == FilterKind.OR) { |
| compareFilterExpressionChildren(actualOperands, expectedOperands); |
| } else { |
| assertEquals(actualOperands.get(0), expectedOperands.get(0)); |
| if (actualFilterKind == FilterKind.IN || actualFilterKind == FilterKind.NOT_IN) { |
| // Handle different order of values |
| assertEqualsNoOrder(actualOperands.toArray(), expectedOperands.toArray()); |
| } else { |
| assertEquals(actualOperands, expectedOperands); |
| } |
| } |
| } else { |
| assertTrue(expectedFilterKind.isRange()); |
| assertEquals(getRangeString(actualFilterKind, actualOperands), |
| getRangeString(expectedFilterKind, expectedOperands)); |
| } |
| } |
| |
| /** |
| * Handles different order of children under AND/OR filter. |
| */ |
| private static void compareFilterExpressionChildren(List<Expression> actual, List<Expression> expected) { |
| assertEquals(actual.size(), expected.size()); |
| List<Expression> unmatchedExpectedChildren = new ArrayList<>(expected); |
| for (Expression actualChild : actual) { |
| Iterator<Expression> iterator = unmatchedExpectedChildren.iterator(); |
| boolean findMatchingChild = false; |
| while (iterator.hasNext()) { |
| try { |
| compareFilterExpression(actualChild, iterator.next()); |
| iterator.remove(); |
| findMatchingChild = true; |
| break; |
| } catch (AssertionError e) { |
| // Ignore |
| } |
| } |
| if (!findMatchingChild) { |
| fail("Failed to find matching child"); |
| } |
| } |
| } |
| |
| private static String getRangeString(FilterKind filterKind, List<Expression> operands) { |
| switch (filterKind) { |
| case GREATER_THAN: |
| return Range.LOWER_EXCLUSIVE + operands.get(1).getLiteral().getFieldValue().toString() + Range.UPPER_UNBOUNDED; |
| case GREATER_THAN_OR_EQUAL: |
| return Range.LOWER_INCLUSIVE + operands.get(1).getLiteral().getFieldValue().toString() + Range.UPPER_UNBOUNDED; |
| case LESS_THAN: |
| return Range.LOWER_UNBOUNDED + operands.get(1).getLiteral().getFieldValue().toString() + Range.UPPER_EXCLUSIVE; |
| case LESS_THAN_OR_EQUAL: |
| return Range.LOWER_UNBOUNDED + operands.get(1).getLiteral().getFieldValue().toString() + Range.UPPER_INCLUSIVE; |
| case BETWEEN: |
| return Range.LOWER_INCLUSIVE + operands.get(1).getLiteral().getFieldValue().toString() + Range.DELIMITER |
| + operands.get(2).getLiteral().getFieldValue().toString() + Range.UPPER_INCLUSIVE; |
| case RANGE: |
| return operands.get(1).getLiteral().getStringValue(); |
| default: |
| throw new IllegalStateException(); |
| } |
| } |
| } |