| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.nlpcraft.probe.mgrs.sentence; |
| |
| import java.util.ArrayList; |
| import java.util.Collection; |
| import java.util.Collections; |
| import java.util.Comparator; |
| import java.util.List; |
| import java.util.Set; |
| import java.util.concurrent.ForkJoinPool; |
| import java.util.concurrent.RecursiveTask; |
| |
| import static java.util.stream.Collectors.toList; |
| |
| /** |
| * It is not converted to scala because scala and java long values implicit conversion performance problems. |
| */ |
| class NCSentenceHelper extends RecursiveTask<List<Long>> { |
| private static final long THRESHOLD = (long)Math.pow(2, 20); |
| |
| private final long lo; |
| private final long hi; |
| private final long[] wordBits; |
| private final int[] wordCounts; |
| |
| private NCSentenceHelper(long lo, long hi, long[] wordBits, int[] wordCounts) { |
| this.lo = lo; |
| this.hi = hi; |
| this.wordBits = wordBits; |
| this.wordCounts = wordCounts; |
| } |
| |
| private List<Long> computeLocal() { |
| List<Long> res = new ArrayList<>(); |
| |
| for (long comboBits = lo; comboBits < hi; comboBits++) { |
| boolean match = true; |
| |
| // For each input row we check if subtracting the current combination of words |
| // from the input row would give us the expected result. |
| for (int j = 0; j < wordBits.length; j++) { |
| // Get bitmask of how many words can be subtracted from the row. |
| // Check if there is more than 1 word remaining after subtraction. |
| if (wordCounts[j] - Long.bitCount(wordBits[j] & comboBits) > 1) { |
| // Skip this combination. |
| match = false; |
| |
| break; |
| } |
| } |
| |
| if (match && excludes(comboBits, res)) |
| res.add(comboBits); |
| } |
| |
| return res; |
| } |
| |
| private List<Long> forkJoin() { |
| long mid = lo + hi >>> 1L; |
| |
| NCSentenceHelper t1 = new NCSentenceHelper(lo, mid, wordBits, wordCounts); |
| NCSentenceHelper t2 = new NCSentenceHelper(mid, hi, wordBits, wordCounts); |
| |
| t2.fork(); |
| |
| return merge(t1.compute(), t2.join()); |
| } |
| |
| private static List<Long> merge(List<Long> l1, List<Long> l2) { |
| if (l1.isEmpty()) |
| return l2; |
| else if (l2.isEmpty()) |
| return l1; |
| |
| int size1 = l1.size(); |
| int size2 = l2.size(); |
| |
| if (size1 == 1 && size2 > 1 || size2 == 1 && size1 > 1) { |
| // Minor optimization in case if one of the lists has only one element. |
| List<Long> res = size1 == 1 ? l2 : l1; |
| Long val = size1 == 1 ? l1.get(0) : l2.get(0); |
| |
| if (excludes(val, res)) |
| res.add(val); |
| |
| return res; |
| } |
| |
| List<Long> res = new ArrayList<>(size1 + size2); |
| |
| for (int i = 0, max = Math.max(size1, size2); i < max; i++) { |
| Long v1 = i < size1 ? l1.get(i) : null; |
| Long v2 = i < size2 ? l2.get(i) : null; |
| |
| if (v1 != null && v2 != null) { |
| if (containsAllBits(v1, v2)) |
| v1 = null; |
| else if (containsAllBits(v2, v1)) |
| v2 = null; |
| } |
| |
| if (v1 != null && excludes(v1, res)) |
| res.add(v1); |
| |
| if (v2 != null && excludes(v2, res)) |
| res.add(v2); |
| } |
| |
| return res; |
| } |
| |
| private static boolean excludes(long bits, List<Long> allBits) { |
| for (Long allBit : allBits) |
| if (containsAllBits(bits, allBit)) |
| return false; |
| |
| return true; |
| } |
| |
| private static boolean containsAllBits(long bitSet1, long bitSet2) { |
| return (bitSet1 & bitSet2) == bitSet2; |
| } |
| |
| private static <T> long wordsToBits(Set<T> words, List<T> dict) { |
| long bits = 0; |
| |
| for (int i = 0, n = dict.size(); i < n; i++) |
| if (words.contains(dict.get(i))) |
| bits |= 1L << i; |
| |
| return bits; |
| } |
| |
| private static <T> List<T> bitsToWords(long bits, List<T> dict) { |
| List<T> words = new ArrayList<>(Long.bitCount(bits)); |
| |
| for (int i = 0, n = dict.size(); i < n; i++) |
| if ((bits & 1L << i) != 0) |
| words.add(dict.get(i)); |
| |
| return words; |
| } |
| |
| @Override |
| protected List<Long> compute() { |
| return hi - lo <= THRESHOLD ? computeLocal() : forkJoin(); |
| } |
| |
| /** |
| * |
| * @param words |
| * @param pool |
| * @param <T> |
| * @return |
| */ |
| static <T> List<List<T>> findCombinations(List<Set<T>> words, ForkJoinPool pool) { |
| assert words != null && !words.isEmpty(); |
| assert pool != null; |
| |
| if (words.stream().allMatch(p -> p.size() == 1)) |
| return Collections.singletonList(Collections.emptyList()); |
| |
| // Build dictionary of unique words. |
| List<T> dict = words.stream().flatMap(Collection::stream).distinct().collect(toList()); |
| |
| if (dict.size() > Long.SIZE) |
| // Note: Power set of 64 words results in 9223372036854775807 combinations. |
| throw new IllegalArgumentException("Dictionary is too long: " + dict.size()); |
| |
| // Convert words to bitmasks (each bit corresponds to an index in the dictionary). |
| long[] wordBits = words.stream().sorted(Comparator.comparingInt(Set::size)).mapToLong(row -> wordsToBits(row, dict)).toArray(); |
| // Cache words count per row. |
| int[] wordCounts = words.stream().sorted(Comparator.comparingInt(Set::size)).mapToInt(Set::size).toArray(); |
| |
| // Prepare Fork/Join task to iterate over the power set of all combinations. |
| return |
| pool.invoke(new NCSentenceHelper(1, (long)Math.pow(2, dict.size()), wordBits, wordCounts)). |
| stream().map(bits -> bitsToWords(bits, dict)).collect(toList()); |
| } |
| } |