blob: 0a0ece0ed288713775b42f236f15cb0a2f58ccad [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.shardingsphere.sharding.merge.dql.groupby;
import com.google.common.collect.ImmutableMap;
import org.apache.shardingsphere.infra.database.type.DatabaseTypeRegistry;
import org.apache.shardingsphere.infra.executor.sql.execute.result.query.QueryResult;
import org.apache.shardingsphere.infra.merge.result.MergedResult;
import org.apache.shardingsphere.sharding.merge.dql.ShardingDQLResultMerger;
import org.apache.shardingsphere.infra.metadata.schema.model.ColumnMetaData;
import org.apache.shardingsphere.infra.metadata.schema.ShardingSphereSchema;
import org.apache.shardingsphere.infra.metadata.schema.model.TableMetaData;
import org.apache.shardingsphere.infra.binder.segment.select.groupby.GroupByContext;
import org.apache.shardingsphere.infra.binder.segment.select.orderby.OrderByContext;
import org.apache.shardingsphere.infra.binder.segment.select.orderby.OrderByItem;
import org.apache.shardingsphere.infra.binder.segment.select.pagination.PaginationContext;
import org.apache.shardingsphere.infra.binder.segment.select.projection.ProjectionsContext;
import org.apache.shardingsphere.infra.binder.segment.select.projection.impl.AggregationProjection;
import org.apache.shardingsphere.infra.binder.statement.dml.SelectStatementContext;
import org.apache.shardingsphere.sql.parser.sql.common.constant.AggregationType;
import org.apache.shardingsphere.sql.parser.sql.common.constant.OrderDirection;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.item.ProjectionsSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.dml.order.item.IndexOrderByItemSegment;
import org.apache.shardingsphere.sql.parser.sql.common.segment.generic.table.SimpleTableSegment;
import org.apache.shardingsphere.sql.parser.sql.dialect.statement.mysql.dml.MySQLSelectStatement;
import org.apache.shardingsphere.sql.parser.sql.common.value.identifier.IdentifierValue;
import org.junit.Test;
import java.math.BigDecimal;
import java.sql.SQLException;
import java.util.Arrays;
import java.util.Calendar;
import java.util.Collections;
import java.util.Date;
import static org.hamcrest.CoreMatchers.is;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.RETURNS_DEEP_STUBS;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
public final class GroupByStreamMergedResultTest {
@Test
public void assertNextForResultSetsAllEmpty() throws SQLException {
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(DatabaseTypeRegistry.getActualDatabaseType("MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(mockQueryResult(), mockQueryResult(), mockQueryResult()), createSelectStatementContext(), buildSchema());
assertFalse(actual.next());
}
@Test
public void assertNextForSomeResultSetsEmpty() throws SQLException {
QueryResult queryResult1 = mockQueryResult();
when(queryResult1.next()).thenReturn(true, false);
when(queryResult1.getValue(1, Object.class)).thenReturn(20);
when(queryResult1.getValue(2, Object.class)).thenReturn(0);
when(queryResult1.getValue(3, Object.class)).thenReturn(2);
when(queryResult1.getValue(4, Object.class)).thenReturn(new Date(0L));
when(queryResult1.getValue(5, Object.class)).thenReturn(2);
when(queryResult1.getValue(6, Object.class)).thenReturn(20);
QueryResult queryResult2 = mockQueryResult();
QueryResult queryResult3 = mockQueryResult();
when(queryResult3.next()).thenReturn(true, true, false);
when(queryResult3.getValue(1, Object.class)).thenReturn(20, 30);
when(queryResult3.getValue(2, Object.class)).thenReturn(0);
when(queryResult3.getValue(3, Object.class)).thenReturn(2, 2, 3);
when(queryResult3.getValue(4, Object.class)).thenReturn(new Date(0L));
when(queryResult3.getValue(5, Object.class)).thenReturn(2, 2, 3);
when(queryResult3.getValue(6, Object.class)).thenReturn(20, 20, 30);
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(DatabaseTypeRegistry.getActualDatabaseType("MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(queryResult1, queryResult2, queryResult3), createSelectStatementContext(), buildSchema());
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(40)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(2));
assertThat(actual.getCalendarValue(4, Date.class, Calendar.getInstance()), is(new Date(0L)));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(4)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(40)));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(30)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(3));
assertThat(actual.getCalendarValue(4, Date.class, Calendar.getInstance()), is(new Date(0L)));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(3)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(30)));
assertFalse(actual.next());
}
@Test
public void assertNextForMix() throws SQLException {
QueryResult queryResult1 = mockQueryResult();
when(queryResult1.next()).thenReturn(true, false);
when(queryResult1.getValue(1, Object.class)).thenReturn(20);
when(queryResult1.getValue(2, Object.class)).thenReturn(0);
when(queryResult1.getValue(3, Object.class)).thenReturn(2);
when(queryResult1.getValue(5, Object.class)).thenReturn(2);
when(queryResult1.getValue(6, Object.class)).thenReturn(20);
QueryResult queryResult2 = mockQueryResult();
when(queryResult2.next()).thenReturn(true, true, true, false);
when(queryResult2.getValue(1, Object.class)).thenReturn(20, 30, 40);
when(queryResult2.getValue(2, Object.class)).thenReturn(0);
when(queryResult2.getValue(3, Object.class)).thenReturn(2, 2, 3, 3, 3, 4);
when(queryResult2.getValue(5, Object.class)).thenReturn(2, 2, 3, 3, 3, 4);
when(queryResult2.getValue(6, Object.class)).thenReturn(20, 20, 30, 30, 30, 40);
QueryResult queryResult3 = mockQueryResult();
when(queryResult3.next()).thenReturn(true, true, false);
when(queryResult3.getValue(1, Object.class)).thenReturn(10, 30);
when(queryResult3.getValue(2, Object.class)).thenReturn(10);
when(queryResult3.getValue(3, Object.class)).thenReturn(1, 1, 1, 1, 3);
when(queryResult3.getValue(5, Object.class)).thenReturn(1, 1, 3);
when(queryResult3.getValue(6, Object.class)).thenReturn(10, 10, 30);
ShardingDQLResultMerger resultMerger = new ShardingDQLResultMerger(DatabaseTypeRegistry.getActualDatabaseType("MySQL"));
MergedResult actual = resultMerger.merge(Arrays.asList(queryResult1, queryResult2, queryResult3), createSelectStatementContext(), buildSchema());
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(10)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(1));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(1)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(10)));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(40)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(2));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(4)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(40)));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(60)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(3));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(6)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(60)));
assertTrue(actual.next());
assertThat(actual.getValue(1, Object.class), is(new BigDecimal(40)));
assertThat(((BigDecimal) actual.getValue(2, Object.class)).intValue(), is(10));
assertThat(actual.getValue(3, Object.class), is(4));
assertThat(actual.getValue(5, Object.class), is(new BigDecimal(4)));
assertThat(actual.getValue(6, Object.class), is(new BigDecimal(40)));
assertFalse(actual.next());
}
private SelectStatementContext createSelectStatementContext() {
AggregationProjection aggregationProjection1 = new AggregationProjection(AggregationType.COUNT, "(*)", null);
aggregationProjection1.setIndex(1);
AggregationProjection aggregationProjection2 = new AggregationProjection(AggregationType.AVG, "(num)", null);
aggregationProjection2.setIndex(2);
AggregationProjection derivedAggregationProjection1 = new AggregationProjection(AggregationType.COUNT, "(num)", "AVG_DERIVED_COUNT_0");
aggregationProjection2.setIndex(5);
aggregationProjection2.getDerivedAggregationProjections().add(derivedAggregationProjection1);
AggregationProjection derivedAggregationProjection2 = new AggregationProjection(AggregationType.SUM, "(num)", "AVG_DERIVED_SUM_0");
aggregationProjection2.setIndex(6);
aggregationProjection2.getDerivedAggregationProjections().add(derivedAggregationProjection2);
SimpleTableSegment tableSegment = new SimpleTableSegment(10, 13, new IdentifierValue("tbl"));
MySQLSelectStatement selectStatement = new MySQLSelectStatement();
selectStatement.setFrom(tableSegment);
ProjectionsContext projectionsContext = new ProjectionsContext(0, 0, false, Arrays.asList(aggregationProjection1, aggregationProjection2));
ProjectionsSegment projectionsSegment = new ProjectionsSegment(0, 0);
selectStatement.setProjections(projectionsSegment);
return new SelectStatementContext(selectStatement,
new GroupByContext(Collections.singletonList(new OrderByItem(new IndexOrderByItemSegment(0, 0, 3, OrderDirection.ASC, OrderDirection.ASC))), 0),
new OrderByContext(Collections.singletonList(new OrderByItem(new IndexOrderByItemSegment(0, 0, 3, OrderDirection.ASC, OrderDirection.ASC))), false),
projectionsContext, new PaginationContext(null, null, Collections.emptyList()));
}
private ShardingSphereSchema buildSchema() {
ColumnMetaData columnMetaData1 = new ColumnMetaData("col1", 0, false, false, false);
ColumnMetaData columnMetaData2 = new ColumnMetaData("col2", 0, false, false, false);
ColumnMetaData columnMetaData3 = new ColumnMetaData("col3", 0, false, false, false);
TableMetaData tableMetaData = new TableMetaData(Arrays.asList(columnMetaData1, columnMetaData2, columnMetaData3), Collections.emptyList());
return new ShardingSphereSchema(ImmutableMap.of("tbl", tableMetaData));
}
private QueryResult mockQueryResult() throws SQLException {
QueryResult result = mock(QueryResult.class, RETURNS_DEEP_STUBS);
when(result.getMetaData().getColumnCount()).thenReturn(6);
when(result.getMetaData().getColumnLabel(1)).thenReturn("COUNT(*)");
when(result.getMetaData().getColumnLabel(2)).thenReturn("AVG(num)");
when(result.getMetaData().getColumnLabel(3)).thenReturn("id");
when(result.getMetaData().getColumnLabel(4)).thenReturn("date");
when(result.getMetaData().getColumnLabel(5)).thenReturn("AVG_DERIVED_COUNT_0");
when(result.getMetaData().getColumnLabel(6)).thenReturn("AVG_DERIVED_SUM_0");
when(result.getMetaData().getColumnName(1)).thenReturn("col1");
when(result.getMetaData().getColumnName(2)).thenReturn("col2");
when(result.getMetaData().getColumnName(3)).thenReturn("col3");
return result;
}
}