blob: d0f86ab027de6465630d8d833efbd22abdeb36a6 [file] [log] [blame]
From 441b804732baca7f0915c914d66710d4f5c4b67d Mon Sep 17 00:00:00 2001
From: Dean Gurvitz <deansg@gmail.com>
Date: Mon, 7 Mar 2016 20:25:33 +0200
Subject: [PATCH 2/2] Enable custom grouping logic in Lucene
---
.../AbstractFirstPassGroupingCollector.java | 48 +++++++++++++--
.../search/grouping/EqualsGroupComparer.java | 38 ++++++++++++
.../lucene/search/grouping/GroupComparer.java | 34 +++++++++++
.../grouping/LevinshtienDistanceGroupComparer.java | 70 ++++++++++++++++++++++
.../term/TermSecondPassGroupingCollector.java | 17 +++++-
5 files changed, 201 insertions(+), 6 deletions(-)
create mode 100644 lucene/grouping/src/java/org/apache/lucene/search/grouping/EqualsGroupComparer.java
create mode 100644 lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupComparer.java
create mode 100644 lucene/grouping/src/java/org/apache/lucene/search/grouping/LevinshtienDistanceGroupComparer.java
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AbstractFirstPassGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AbstractFirstPassGroupingCollector.java
index 4c386b6..1986446 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/AbstractFirstPassGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/AbstractFirstPassGroupingCollector.java
@@ -18,6 +18,7 @@ package org.apache.lucene.search.grouping;
import org.apache.lucene.index.LeafReaderContext;
import org.apache.lucene.search.*;
+import org.apache.lucene.search.grouping.EqualsGroupComparer;
import java.io.IOException;
import java.util.*;
@@ -47,6 +48,8 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
protected TreeSet<CollectedSearchGroup<GROUP_VALUE_TYPE>> orderedGroups;
private int docBase;
private int spareSlot;
+
+ protected GroupComparer<GROUP_VALUE_TYPE> groupComparer;
/**
* Create the first pass collector.
@@ -60,8 +63,13 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
* @param topNGroups How many top groups to keep.
* @throws IOException If I/O related errors occur
*/
- @SuppressWarnings({"unchecked", "rawtypes"})
public AbstractFirstPassGroupingCollector(Sort groupSort, int topNGroups) throws IOException {
+ this(groupSort, topNGroups, new EqualsGroupComparer());
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public AbstractFirstPassGroupingCollector(Sort groupSort, int topNGroups,
+ GroupComparer groupComparer) throws IOException {
if (topNGroups < 1) {
throw new IllegalArgumentException("topNGroups must be >= 1 (got " + topNGroups + ")");
}
@@ -86,7 +94,10 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
spareSlot = topNGroups;
groupMap = new HashMap<>(topNGroups);
+
+ this.groupComparer = groupComparer;
}
+
@Override
public boolean needsScores() {
@@ -183,9 +194,12 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
// under null group)?
final GROUP_VALUE_TYPE groupValue = getDocGroupValue(doc);
- final CollectedSearchGroup<GROUP_VALUE_TYPE> group = groupMap.get(groupValue);
+ // The exact-match groupMap.get(groupValue) lookup is replaced by the pluggable comparer below.
- if (group == null) {
+ final Collection<CollectedSearchGroup<GROUP_VALUE_TYPE>> groups = groupComparer
+ .getSuitableGroups(groupValue, groupMap.values());
+
+ if (groups.isEmpty()) {
// First time we are seeing this group, or, we've seen
// it before but it fell out of the top N and is now
@@ -243,6 +257,31 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
return;
}
+
+    // More than one collected group may accept this value; settle on a single one by
+    // consulting the group-sort comparators in order until one tells the two apart.
+    final Iterator<CollectedSearchGroup<GROUP_VALUE_TYPE>> groupsIterator = groups.iterator();
+    CollectedSearchGroup<GROUP_VALUE_TYPE> weakestGroup = groupsIterator.next();
+    while (groupsIterator.hasNext()) {
+      final CollectedSearchGroup<GROUP_VALUE_TYPE> cGroup = groupsIterator.next();
+      // Bounded by comparators.length: the earlier open-ended loop had no exit and
+      // indexed past the comparators array as soon as two or more groups matched.
+      for (int compIDX = 0; compIDX < comparators.length; compIDX++) {
+        final int c = reversed[compIDX]
+            * comparators[compIDX].compare(cGroup.comparatorSlot, weakestGroup.comparatorSlot);
+        if (c != 0) {
+          if (c < 0) {
+            weakestGroup = cGroup; // a negative result switches to this candidate
+          }
+          break; // first non-tie decides; lower-priority comparators are not consulted
+        }
+      }
+    }
+    updateExistingGroup(doc, weakestGroup);
+ }
+
+ private void updateExistingGroup(int doc,
+ CollectedSearchGroup<GROUP_VALUE_TYPE> group) throws IOException {
// Update existing group:
for (int compIDX = 0;; compIDX++) {
@@ -297,8 +336,9 @@ abstract public class AbstractFirstPassGroupingCollector<GROUP_VALUE_TYPE> exten
}
}
}
- }
+ }
+
private void buildSortedSet() {
final Comparator<CollectedSearchGroup<?>> comparator = new Comparator<CollectedSearchGroup<?>>() {
@Override
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/EqualsGroupComparer.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/EqualsGroupComparer.java
new file mode 100644
index 0000000..07d74db
--- /dev/null
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/EqualsGroupComparer.java
@@ -0,0 +1,38 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.grouping;
+
+import java.util.Collection;
+import java.util.stream.Collectors;
+
+import org.apache.lucene.index.SortedDocValues;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * {@link GroupComparer} for exact-match grouping: a value is suitable for a group only when
+ * it equals that group's value.
+ */
+public class EqualsGroupComparer extends GroupComparer<BytesRef> {
+
+  /**
+   * Returns the groups whose value equals {@code groupValue}.
+   *
+   * <p>Either side may be {@code null}; a {@code null} value matches only a group whose value
+   * is also {@code null}.
+   *
+   * @param groupValue group value of the current document, possibly {@code null}
+   * @param groups the groups to choose from
+   * @return a set holding the matching groups, empty when none match
+   */
+  @Override
+  public Collection<CollectedSearchGroup<BytesRef>> getSuitableGroups(BytesRef groupValue,
+      Collection<CollectedSearchGroup<BytesRef>> groups) {
+    return groups.stream()
+        // Null-safe on both sides: calling equals() on a null group value used to throw.
+        .filter(group -> group.groupValue == null
+            ? groupValue == null
+            : group.groupValue.equals(groupValue))
+        .collect(Collectors.toSet());
+  }
+
+  /**
+   * Delegates to {@link SortedDocValues#lookupTerm(BytesRef)} without any extra matching logic.
+   *
+   * @param sdv doc values to search
+   * @param key term to look up
+   * @return whatever {@code sdv.lookupTerm(key)} returns
+   */
+  @Override
+  public int lookupTerm(SortedDocValues sdv, BytesRef key) {
+    return sdv.lookupTerm(key);
+  }
+}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupComparer.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupComparer.java
new file mode 100644
index 0000000..178fb52
--- /dev/null
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/GroupComparer.java
@@ -0,0 +1,34 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.grouping;
+
+import java.util.Collection;
+import java.util.Map;
+
+import org.apache.lucene.index.SortedDocValues;
+
+/**
+ * Strategy that decides which groups a group value belongs to, so that grouping is not
+ * limited to exact equality of group values.
+ *
+ * @param <GROUP_VALUE_TYPE> type of the group values being compared
+ */
+public abstract class GroupComparer<GROUP_VALUE_TYPE> {
+
+  /**
+   * Selects, from the given groups, those that {@code groupValue} may join.
+   *
+   * @param groupValue group value of the document being collected
+   * @param groups candidate groups to choose from
+   * @return the matching groups; an empty collection means no candidate matches
+   */
+  public abstract Collection<CollectedSearchGroup<GROUP_VALUE_TYPE>> getSuitableGroups(
+      GROUP_VALUE_TYPE groupValue,
+      Collection<CollectedSearchGroup<GROUP_VALUE_TYPE>> groups);
+
+  /**
+   * Resolves {@code key} to an ordinal of {@code sdv}.
+   *
+   * @param sdv doc values whose terms are searched
+   * @param key group value to resolve
+   * @return the ordinal of a matching term, or a negative value when there is none
+   */
+  public abstract int lookupTerm(SortedDocValues sdv, GROUP_VALUE_TYPE key);
+
+  /**
+   * Configuration hook. The default implementation ignores {@code params}; subclasses may
+   * override it to read their settings.
+   *
+   * <p>The parameter is a wildcard map rather than a raw type; its erasure is unchanged, so
+   * overrides declared with a raw {@code Map} still compile.
+   *
+   * @param params configuration values
+   */
+  public void init(Map<?, ?> params) { }
+
+}
\ No newline at end of file
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/LevinshtienDistanceGroupComparer.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/LevinshtienDistanceGroupComparer.java
new file mode 100644
index 0000000..658d9d3
--- /dev/null
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/LevinshtienDistanceGroupComparer.java
@@ -0,0 +1,70 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements. See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.lucene.search.grouping;
+
+import java.util.Collection;
+import java.util.Map;
+import java.util.stream.Collectors;
+
+import org.apache.lucene.index.SortedDocValues;
+import org.apache.lucene.search.spell.LuceneLevenshteinDistance;
+import org.apache.lucene.util.BytesRef;
+
+/**
+ * {@link GroupComparer} that treats two group values as belonging together when
+ * {@link LuceneLevenshteinDistance} rates them at or below a configurable threshold.
+ */
+public class LevinshtienDistanceGroupComparer extends GroupComparer<BytesRef> {
+
+  /** Key read by {@link #init(Map)} to override the default threshold. */
+  static final String DISTANCE_PARAM = "dist";
+  final LuceneLevenshteinDistance lld = new LuceneLevenshteinDistance();
+  private int maxDistance = 2; // default
+
+  /**
+   * Returns every group whose value is within the threshold of {@code groupValue}.
+   *
+   * @param groupValue group value of the current document, possibly {@code null}
+   * @param groups the groups to choose from
+   * @return a set holding the matching groups, empty when none match
+   */
+  @Override
+  public Collection<CollectedSearchGroup<BytesRef>> getSuitableGroups(BytesRef groupValue,
+      Collection<CollectedSearchGroup<BytesRef>> groups) {
+    return groups.stream()
+        .filter(group -> isWithinDistance(groupValue, group.groupValue))
+        .collect(Collectors.toSet());
+  }
+
+  /**
+   * Finds the lowest ordinal of {@code sdv} whose term is within the threshold of {@code key}.
+   *
+   * <p>Every ordinal is examined in turn. The earlier bisection only ever moved its lower
+   * bound upwards, so ordinals below the probed midpoints were never looked at, and a fuzzy
+   * match gives no ordering to bisect on in the first place.
+   *
+   * @param sdv doc values whose terms are searched
+   * @param key group value to resolve
+   * @return the matching ordinal, or {@code -(valueCount + 1)} when no term matches
+   */
+  @Override
+  public int lookupTerm(SortedDocValues sdv, BytesRef key) {
+    final int valueCount = sdv.getValueCount();
+    for (int ord = 0; ord < valueCount; ord++) {
+      if (isWithinDistance(key, sdv.lookupOrd(ord))) {
+        return ord;
+      }
+    }
+
+    return -(valueCount + 1); // key not found.
+  }
+
+  /**
+   * Reads the threshold from {@code params}. The current threshold is kept when
+   * {@code params} is {@code null} or holds no {@code Integer} under {@link #DISTANCE_PARAM}.
+   *
+   * @param params configuration values, may be {@code null}
+   */
+  @Override
+  public void init(Map params) {
+    if (params == null) {
+      return;
+    }
+    Object distObj = params.get(DISTANCE_PARAM);
+    if (distObj instanceof Integer) {
+      maxDistance = (Integer) distObj;
+    }
+  }
+
+  /**
+   * Tells whether two group values are close enough to share a group. A {@code null} value
+   * matches only another {@code null} value.
+   */
+  private boolean isWithinDistance(BytesRef first, BytesRef second) {
+    if (first == null || second == null) {
+      return first == second;
+    }
+    // utf8ToString() decodes only the bytes this BytesRef refers to, as UTF-8.
+    // new String(ref.bytes) decoded the whole backing array with the platform charset.
+    // NOTE(review): Lucene's StringDistance contract describes getDistance as a similarity in
+    // [0, 1] rather than an edit count. If that holds for LuceneLevenshteinDistance, this test
+    // accepts every pair with the default threshold - confirm and convert before relying on it.
+    return lld.getDistance(first.utf8ToString(), second.utf8ToString()) <= maxDistance;
+  }
+
+}
diff --git a/lucene/grouping/src/java/org/apache/lucene/search/grouping/term/TermSecondPassGroupingCollector.java b/lucene/grouping/src/java/org/apache/lucene/search/grouping/term/TermSecondPassGroupingCollector.java
index 9292856..f9f81c3 100644
--- a/lucene/grouping/src/java/org/apache/lucene/search/grouping/term/TermSecondPassGroupingCollector.java
+++ b/lucene/grouping/src/java/org/apache/lucene/search/grouping/term/TermSecondPassGroupingCollector.java
@@ -24,6 +24,8 @@ import org.apache.lucene.index.DocValues;
import org.apache.lucene.index.SortedDocValues;
import org.apache.lucene.search.Sort;
import org.apache.lucene.search.grouping.AbstractSecondPassGroupingCollector;
+import org.apache.lucene.search.grouping.EqualsGroupComparer;
+import org.apache.lucene.search.grouping.GroupComparer;
import org.apache.lucene.search.grouping.SearchGroup;
import org.apache.lucene.util.BytesRef;
import org.apache.lucene.util.SentinelIntSet;
@@ -42,14 +44,24 @@ public class TermSecondPassGroupingCollector extends AbstractSecondPassGroupingC
private SortedDocValues index;
+ private final GroupComparer<BytesRef> groupComparer;
+
@SuppressWarnings({"unchecked", "rawtypes"})
public TermSecondPassGroupingCollector(String groupField, Collection<SearchGroup<BytesRef>> groups, Sort groupSort, Sort withinGroupSort,
- int maxDocsPerGroup, boolean getScores, boolean getMaxScores, boolean fillSortFields)
+ int maxDocsPerGroup, boolean getScores, boolean getMaxScores, boolean fillSortFields) throws IOException {
+ this(groupField, groups, groupSort, withinGroupSort, maxDocsPerGroup, getScores, getMaxScores, fillSortFields, new EqualsGroupComparer());
+ }
+
+ @SuppressWarnings({"unchecked", "rawtypes"})
+ public TermSecondPassGroupingCollector(String groupField, Collection<SearchGroup<BytesRef>> groups, Sort groupSort, Sort withinGroupSort,
+ int maxDocsPerGroup, boolean getScores, boolean getMaxScores, boolean fillSortFields
+ , GroupComparer<BytesRef> groupComparer)
throws IOException {
super(groups, groupSort, withinGroupSort, maxDocsPerGroup, getScores, getMaxScores, fillSortFields);
this.groupField = groupField;
this.ordSet = new SentinelIntSet(groupMap.size(), -2);
super.groupDocs = (SearchGroupDocs<BytesRef>[]) new SearchGroupDocs[ordSet.keys.length];
+ this.groupComparer = groupComparer;
}
@Override
@@ -61,7 +73,8 @@ public class TermSecondPassGroupingCollector extends AbstractSecondPassGroupingC
ordSet.clear();
for (SearchGroupDocs<BytesRef> group : groupMap.values()) {
// System.out.println(" group=" + (group.groupValue == null ? "null" : group.groupValue.utf8ToString()));
- int ord = group.groupValue == null ? -1 : index.lookupTerm(group.groupValue);
+ // Resolve the ordinal through the pluggable comparer instead of calling index.lookupTerm directly.
+ int ord = group.groupValue == null ? -1 : groupComparer.lookupTerm(index, group.groupValue);
if (group.groupValue == null || ord >= 0) {
groupDocs[ordSet.put(ord)] = group;
}
--
2.7.2.windows.1