blob: 480c625d515e6e81bff136d9d12f8a2e3410b5b1 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.groovy.ginq.provider.collection.runtime;
import groovy.lang.Tuple2;
import org.apache.groovy.util.ReversedList;
import java.math.BigDecimal;
import java.math.RoundingMode;
import java.util.Comparator;
import java.util.List;
import java.util.function.Function;
import java.util.stream.Collectors;
import static java.util.Collections.binarySearch;
import static org.codehaus.groovy.runtime.dgmimpl.NumberNumberPlus.plus;
import static org.codehaus.groovy.runtime.typehandling.NumberMath.toBigDecimal;
/**
* Represents a window that stores the elements used by window functions.
*
* @param <T> the type of {@link Queryable} element
* @param <U> the type of field to sort
* @since 4.0.0
*/
class WindowImpl<T, U extends Comparable<? super U>> extends QueryableCollection<T> implements Window<T> {
    /**
     * Creates a window over {@code list} for the given current record.
     *
     * @param currentRecord the record the window functions are evaluated for, paired with a {@code Long}
     * @param index the position used for {@link #rowNumber()} and as the base offset for {@code lead}/{@code lag}
     * @param value the sort-field value associated with the current record (stored only)
     * @param list the elements that make up this window
     * @param order the ordering used by the ranking functions; may be {@code null}, in which case
     *              the ranking functions return {@code null}
     */
    WindowImpl(Tuple2<T, Long> currentRecord, int index, U value, List<T> list, Order<? super T, ? extends U> order) {
        super(list);
        this.currentRecord = currentRecord;
        this.order = order;
        this.comparator = null == order ? null : makeComparator(order);
        this.index = index;
        this.value = value;
        this.list = list;
    }

    /**
     * Returns the index that was supplied when this window was constructed.
     */
    @Override
    public long rowNumber() {
        return index;
    }

    /**
     * Extracts a field from the element located {@code lead} positions after the current index.
     *
     * @param extractor the function producing the field from an element
     * @param lead the offset relative to the current index; {@code 0} reads the current record itself
     * @param def the value returned when the target position lies outside the window list
     * @return the extracted field, or {@code def} if the target position is out of range
     */
    @Override
    public <V> V lead(Function<? super T, ? extends V> extractor, long lead, V def) {
        V field;
        if (0 == lead) {
            field = extractor.apply(currentRecord.getV1());
        } else if (0 <= index + lead && index + lead < list.size()) {
            // the range check above is evaluated in long arithmetic, so the int cast is safe here
            field = extractor.apply(list.get(index + (int) lead));
        } else {
            field = def;
        }
        return field;
    }

    /**
     * Same as {@link #lead(Function, long, Object)} with the offset negated.
     */
    @Override
    public <V> V lag(Function<? super T, ? extends V> extractor, long lag, V def) {
        return lead(extractor, -lag, def);
    }

    /**
     * Extracts a field from the first element of the window, or returns {@code null} if the window is empty.
     */
    @Override
    public <V> V firstValue(Function<? super T, ? extends V> extractor) {
        return nthValue(extractor, 0);
    }

    /**
     * Extracts a field from the last element of the window, or returns {@code null} if the window is empty.
     */
    @Override
    public <V> V lastValue(Function<? super T, ? extends V> extractor) {
        final int size = list.size();
        if (0 == size) {
            return null;
        }
        return nthValue(extractor, size - 1);
    }

    /**
     * Extracts a field from the element at the given zero-based position of the window list.
     *
     * @param extractor the function producing the field from an element
     * @param index the zero-based position within the window list
     * @return the extracted field, or {@code null} if {@code index} is beyond the last element
     * @throws IllegalArgumentException if {@code index} is negative
     */
    @Override
    public <V> V nthValue(Function<? super T, ? extends V> extractor, long index) {
        if (index < 0) {
            throw new IllegalArgumentException("index cannot be less than 0: " + index);
        }
        if (index > list.size() - 1) {
            return null;
        }
        return extractor.apply(list.get((int) index));
    }

    /**
     * Returns one plus the number of window elements that the current record compares greater than.
     *
     * @return the rank, or {@code null} if no order was supplied
     */
    @Override
    public Long rank() {
        if (null == order) {
            return null;
        }
        final T current = currentRecord.getV1();
        long result = 1L;
        for (T t : list) {
            if (comparator.compare(current, t) > 0) {
                result++;
            }
        }
        return result;
    }

    /**
     * Returns one plus the number of runs of equal elements that the current record compares greater than.
     * Only adjacent elements are compared with each other, so equal elements are counted once
     * when they are adjacent in the window list.
     *
     * @return the dense rank, or {@code null} if no order was supplied
     */
    @Override
    public Long denseRank() {
        if (null == order) {
            return null;
        }
        final T current = currentRecord.getV1();
        long result = 1L;
        T latest = null;
        // the first element has no predecessor; never hand a missing predecessor to the comparator
        boolean hasLatest = false;
        for (T t : list) {
            if (comparator.compare(current, t) > 0 && (!hasLatest || comparator.compare(latest, t) != 0)) {
                result++;
            }
            latest = t;
            hasLatest = true;
        }
        return result;
    }

    /**
     * Returns {@code (rank - 1) / (size - 1)} with a scale of 16, rounded half up.
     *
     * @return the percent rank, {@code 0} for a single-element window,
     *         or {@code null} if no order was supplied
     */
    @Override
    public BigDecimal percentRank() {
        if (null == order) {
            return null;
        }
        final int size = list.size();
        if (1 == size) {
            // a lone row has rank 1, so (rank - 1) is 0; this also avoids dividing by zero
            return BigDecimal.ZERO;
        }
        Long r = rank();
        if (null == r) {
            return null;
        }
        return toBigDecimal(r - 1).divide(toBigDecimal(size - 1), 16, RoundingMode.HALF_UP);
    }

    /**
     * Returns the fraction of window elements that the current record compares greater than or equal to,
     * with a scale of 16, rounded half up.
     *
     * @return the cumulative distribution, or {@code null} if no order was supplied
     * @throws ArithmeticException if the window list is empty (division by zero)
     */
    @Override
    public BigDecimal cumeDist() {
        if (null == order) {
            return null;
        }
        final T current = currentRecord.getV1();
        long cnt = list.stream()
                .filter(e -> comparator.compare(current, e) >= 0)
                .count();
        return toBigDecimal(cnt).divide(toBigDecimal(list.size()), 16, RoundingMode.HALF_UP);
    }

    /**
     * Returns {@code bucketCnt * rowNumber() / size} using integer division.
     *
     * @param bucketCnt the number of buckets
     * @throws ArithmeticException if the window list is empty (division by zero)
     */
    @Override
    public long ntile(long bucketCnt) {
        return bucketCnt * rowNumber() / list.size();
    }

    /**
     * Resolves the lower row bound of the window definition relative to {@code index};
     * a {@code null} or {@code Long.MIN_VALUE} lower bound resolves to {@code 0}.
     */
    private static <T, U extends Comparable<? super U>> long getFirstIndex(WindowDefinition<T, U> windowDefinition, int index) {
        RowBound rowBound = windowDefinition.rows();
        final Long lower = rowBound.getLower();
        return null == lower || Long.MIN_VALUE == lower ? 0 : index + lower;
    }

    /**
     * Resolves the upper row bound of the window definition relative to {@code index};
     * a {@code null} or {@code Long.MAX_VALUE} upper bound resolves to {@code size - 1}.
     */
    private static <T, U extends Comparable<? super U>> long getLastIndex(WindowDefinition<T, U> windowDefinition, int index, long size) {
        RowBound rowBound = windowDefinition.rows();
        final Long upper = rowBound.getUpper();
        return null == upper || Long.MAX_VALUE == upper ? size - 1 : index + upper;
    }

    /**
     * Computes the row bound of the window for the record at {@code index}, clamped to the list.
     * <p>
     * A non-default {@code rows} definition is resolved relative to {@code index}. Otherwise a
     * {@code range} definition is resolved by value, but only when the value and bounds are numeric
     * and there is exactly one order. In every other case the bound spans the whole list.
     *
     * @param windowDefinition the window definition supplying {@code rows}, {@code range} and {@code orderBy}
     * @param index the position of the current record in {@code listWithIndex}
     * @param value the sort-field value of the current record
     * @param listWithIndex the candidate records; the range branch binary-searches their sort keys
     * @return the clamped row bound, or {@code null} if the window contains no rows
     */
    static <T, U extends Comparable<? super U>> RowBound getValidRowBound(WindowDefinition<T, U> windowDefinition, int index, U value, List<Tuple2<T, Long>> listWithIndex) {
        int size = listWithIndex.size();
        long firstIndex = 0;
        long lastIndex = size - 1;
        if (null != windowDefinition.rows() && RowBound.DEFAULT != windowDefinition.rows()) {
            firstIndex = getFirstIndex(windowDefinition, index);
            lastIndex = getLastIndex(windowDefinition, index, size);
        } else if (null != windowDefinition.range() && null != windowDefinition.orderBy()) {
            ValueBound<? extends U> valueBound = windowDefinition.range();
            U lower = valueBound.getLower();
            U upper = valueBound.getUpper();
            if (value instanceof Number && (lower instanceof Number || null == lower) && (upper instanceof Number || null == upper)) {
                final List<Order<? super T, ? extends U>> orderList = windowDefinition.orderBy();
                if (orderList.size() == 1) {
                    Order<? super T, ? extends U> order = orderList.get(0);
                    if (listWithIndex.isEmpty()) {
                        return null;
                    }
                    // flag flips the comparisons below for descending order
                    int flag = order.isAsc() ? 1 : -1;
                    BigDecimal firstElement = toBigDecimal((Number) order.getKeyExtractor().apply(listWithIndex.get(0).getV1()));
                    BigDecimal lastElement = toBigDecimal((Number) order.getKeyExtractor().apply(listWithIndex.get(size - 1).getV1()));
                    BigDecimal lowerValue = null == lower ? MIN_VALUE : toBigDecimal(plus((Number) value, (Number) lower));
                    BigDecimal upperValue = null == upper ? MAX_VALUE : toBigDecimal(plus((Number) value, (Number) upper));
                    // the whole value range lies before the first key or after the last key: empty window
                    if ((flag * lowerValue.compareTo(firstElement) < 0 && flag * upperValue.compareTo(firstElement) < 0)
                            || (flag * lowerValue.compareTo(lastElement) > 0 && flag * upperValue.compareTo(lastElement) > 0)) {
                        return null;
                    }
                    List<U> list =
                            listWithIndex.stream()
                                    .map(e -> order.getKeyExtractor().apply(e.getV1()))
                                    .collect(Collectors.toList());
                    if (order.isAsc()) {
                        firstIndex = getIndexByValue(lowerValue, true, list);
                        lastIndex = getIndexByValue(upperValue, false, list);
                    } else {
                        // search the reversed view, then mirror the positions back
                        final List<U> reversedList = new ReversedList<>(list);
                        lastIndex = size - 1 - getIndexByValue(lowerValue, true, reversedList);
                        firstIndex = size - 1 - getIndexByValue(upperValue, false, reversedList);
                    }
                }
            }
        }
        if ((firstIndex < 0 && lastIndex < 0) || (firstIndex >= size && lastIndex >= size)) {
            return null;
        }
        return new RowBound(Math.max(firstIndex, 0), Math.min(lastIndex, size - 1));
    }

    /**
     * Locates the boundary position for {@code value} in {@code list} via binary search,
     * comparing elements by their {@code BigDecimal} value.
     * <p>
     * When the value is present, the result is widened to the first (lower bound) or last
     * (upper bound) element that is numerically equal to it. When it is absent, the lower bound
     * is the insertion point and the upper bound is the position just before it (never below 0).
     *
     * @param value the boundary value to look up
     * @param isLower {@code true} to resolve a lower bound, {@code false} for an upper bound
     * @param list the keys to search; elements are cast to {@code Number}
     * @return the resolved position
     */
    private static <T, U extends Comparable<? super U>> long getIndexByValue(BigDecimal value, boolean isLower, List<U> list) {
        int tmpIndex = binarySearch(list, value, Comparator.comparing(u -> toBigDecimal((Number) u)));
        int valueIndex;
        if (tmpIndex >= 0) {
            valueIndex = tmpIndex;
        } else {
            valueIndex = -tmpIndex - 1;
            if (!isLower) {
                valueIndex = valueIndex - 1;
                if (valueIndex < 0) {
                    valueIndex = 0;
                }
            }
        }
        // compareTo rather than equals: BigDecimal.equals is scale-sensitive (2.0 vs 2) whereas the
        // binary search above is not, so equals would miss numerically equal neighbours
        if (isLower) {
            int i = valueIndex - 1;
            for (; i >= 0; i--) {
                if (value.compareTo(toBigDecimal((Number) list.get(i))) != 0) {
                    break;
                }
            }
            valueIndex = i + 1;
        } else {
            int i = valueIndex + 1;
            for (int n = list.size(); i < n; i++) {
                if (value.compareTo(toBigDecimal((Number) list.get(i))) != 0) {
                    break;
                }
            }
            valueIndex = i - 1;
        }
        return valueIndex;
    }

    /**
     * Adapts orders defined on {@code T} so that they apply to {@code Tuple2<T, Long>} records,
     * keeping each order's direction and nulls-last setting.
     */
    static <T, U extends Comparable<? super U>> List<Order<Tuple2<T, Long>, U>> composeOrders(List<Queryable.Order<? super T, ? extends U>> orderList) {
        return orderList.stream()
                .map(order -> new Order<Tuple2<T, Long>, U>(t -> order.getKeyExtractor().apply(t.getV1()), order.isAsc(), order.isNullsLast()))
                .collect(Collectors.toList());
    }

    /**
     * Adapts the {@code orderBy} list of the window definition; see {@link #composeOrders(List)}.
     */
    static <T, U extends Comparable<? super U>> List<Order<Tuple2<T, Long>, U>> composeOrders(WindowDefinition<T, U> windowDefinition) {
        return composeOrders(windowDefinition.orderBy());
    }

    private final Tuple2<T, Long> currentRecord;
    private final Order<? super T, ? extends U> order;
    private final Comparator<? super T> comparator;
    private final int index;
    private final U value;
    private final List<T> list;
    // stand-ins for an unbounded lower/upper range value
    private static final BigDecimal MIN_VALUE = toBigDecimal(Long.MIN_VALUE);
    private static final BigDecimal MAX_VALUE = toBigDecimal(Long.MAX_VALUE);
    private static final long serialVersionUID = -3458969297047398621L;
}