| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.lucene.util.fst; |
| |
| import java.io.BufferedReader; |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.nio.charset.StandardCharsets; |
| import java.nio.file.Files; |
| import java.nio.file.Path; |
| import java.nio.file.Paths; |
| import java.util.ArrayList; |
| import java.util.Arrays; |
| import java.util.Collections; |
| import java.util.HashSet; |
| import java.util.List; |
| import java.util.Locale; |
| import java.util.Set; |
| import java.util.zip.GZIPInputStream; |
| import org.apache.lucene.store.ByteArrayDataInput; |
| import org.apache.lucene.store.DataInput; |
| import org.apache.lucene.store.InputStreamDataInput; |
| import org.apache.lucene.util.BytesRef; |
| import org.apache.lucene.util.CharsRef; |
| import org.apache.lucene.util.IntsRefBuilder; |
| import org.apache.lucene.util.LuceneTestCase; |
| |
| public class TestFSTDirectAddressing extends LuceneTestCase { |
| |
| public void testDenseWithGap() throws Exception { |
| List<String> words = Arrays.asList("ah", "bi", "cj", "dk", "fl", "gm"); |
| List<BytesRef> entries = new ArrayList<>(); |
| for (String word : words) { |
| entries.add(new BytesRef(word.getBytes(StandardCharsets.US_ASCII))); |
| } |
| final BytesRefFSTEnum<Object> fstEnum = new BytesRefFSTEnum<>(buildFST(entries)); |
| for (BytesRef entry : entries) { |
| assertNotNull(entry.utf8ToString() + " not found", fstEnum.seekExact(entry)); |
| } |
| } |
| |
| public void testDeDupTails() throws Exception { |
| List<BytesRef> entries = new ArrayList<>(); |
| for (int i = 0; i < 1000000; i += 4) { |
| byte[] b = new byte[3]; |
| int val = i; |
| for (int j = b.length - 1; j >= 0; --j) { |
| b[j] = (byte) (val & 0xff); |
| val >>= 8; |
| } |
| entries.add(new BytesRef(b)); |
| } |
| long size = buildFST(entries).ramBytesUsed(); |
| // Size is 1648 when we use only list-encoding. We were previously failing to ever de-dup |
| // direct addressing, which led this case to blow up. |
| // This test will fail if there is more than 1% size increase with direct addressing. |
| assertTrue("FST size = " + size + " B", size <= 1648 * 1.01d); |
| } |
| |
| @Nightly |
| public void testWorstCaseForDirectAddressing() throws Exception { |
| // This test will fail if there is more than 1% memory increase with direct addressing in this |
| // worst case. |
| final double MEMORY_INCREASE_LIMIT_PERCENT = 1d; |
| final int NUM_WORDS = 1000000; |
| |
| // Generate words with specially crafted bytes. |
| Set<BytesRef> wordSet = new HashSet<>(); |
| for (int i = 0; i < NUM_WORDS; ++i) { |
| byte[] b = new byte[5]; |
| random().nextBytes(b); |
| for (int j = 0; j < b.length; ++j) { |
| b[j] &= 0xfc; // Make this byte a multiple of 4. |
| } |
| wordSet.add(new BytesRef(b)); |
| } |
| List<BytesRef> wordList = new ArrayList<>(wordSet); |
| Collections.sort(wordList); |
| |
| // Disable direct addressing and measure the FST size. |
| FSTCompiler<Object> fstCompiler = createFSTCompiler(-1f); |
| FST<Object> fst = buildFST(wordList, fstCompiler); |
| long ramBytesUsedNoDirectAddressing = fst.ramBytesUsed(); |
| |
| // Enable direct addressing and measure the FST size. |
| fstCompiler = createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR); |
| fst = buildFST(wordList, fstCompiler); |
| long ramBytesUsed = fst.ramBytesUsed(); |
| |
| // Compute the size increase in percents. |
| double directAddressingMemoryIncreasePercent = |
| ((double) ramBytesUsed / ramBytesUsedNoDirectAddressing - 1) * 100; |
| |
| // printStats(builder, ramBytesUsed, directAddressingMemoryIncreasePercent); |
| |
| // Verify the FST size does not exceed the limit. |
| assertTrue( |
| "FST size exceeds limit, size = " |
| + ramBytesUsed |
| + ", increase = " |
| + directAddressingMemoryIncreasePercent |
| + " %" |
| + ", limit = " |
| + MEMORY_INCREASE_LIMIT_PERCENT |
| + " %", |
| directAddressingMemoryIncreasePercent < MEMORY_INCREASE_LIMIT_PERCENT); |
| } |
| |
| private static void printStats( |
| FSTCompiler<Object> fstCompiler, |
| long ramBytesUsed, |
| double directAddressingMemoryIncreasePercent) { |
| System.out.println( |
| "directAddressingMaxOversizingFactor = " |
| + fstCompiler.getDirectAddressingMaxOversizingFactor()); |
| System.out.println( |
| "ramBytesUsed = " |
| + String.format(Locale.ENGLISH, "%.2f MB", ramBytesUsed / 1024d / 1024d) |
| + String.format( |
| Locale.ENGLISH, |
| " (%.2f %% increase with direct addressing)", |
| directAddressingMemoryIncreasePercent)); |
| System.out.println("num nodes = " + fstCompiler.nodeCount); |
| long fixedLengthArcNodeCount = |
| fstCompiler.directAddressingNodeCount + fstCompiler.binarySearchNodeCount; |
| System.out.println( |
| "num fixed-length-arc nodes = " |
| + fixedLengthArcNodeCount |
| + String.format( |
| Locale.ENGLISH, |
| " (%.2f %% of all nodes)", |
| ((double) fixedLengthArcNodeCount / fstCompiler.nodeCount * 100))); |
| System.out.println( |
| "num binary-search nodes = " |
| + (fstCompiler.binarySearchNodeCount) |
| + String.format( |
| Locale.ENGLISH, |
| " (%.2f %% of fixed-length-arc nodes)", |
| ((double) (fstCompiler.binarySearchNodeCount) / fixedLengthArcNodeCount * 100))); |
| System.out.println( |
| "num direct-addressing nodes = " |
| + (fstCompiler.directAddressingNodeCount) |
| + String.format( |
| Locale.ENGLISH, |
| " (%.2f %% of fixed-length-arc nodes)", |
| ((double) (fstCompiler.directAddressingNodeCount) |
| / fixedLengthArcNodeCount |
| * 100))); |
| } |
| |
| private static FSTCompiler<Object> createFSTCompiler(float directAddressingMaxOversizingFactor) { |
| return new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE1, NoOutputs.getSingleton()) |
| .directAddressingMaxOversizingFactor(directAddressingMaxOversizingFactor) |
| .build(); |
| } |
| |
| private FST<Object> buildFST(List<BytesRef> entries) throws Exception { |
| return buildFST( |
| entries, createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR)); |
| } |
| |
| private static FST<Object> buildFST(List<BytesRef> entries, FSTCompiler<Object> fstCompiler) |
| throws Exception { |
| BytesRef last = null; |
| for (BytesRef entry : entries) { |
| if (entry.equals(last) == false) { |
| fstCompiler.add( |
| Util.toIntsRef(entry, new IntsRefBuilder()), NoOutputs.getSingleton().getNoOutput()); |
| } |
| last = entry; |
| } |
| return fstCompiler.compile(); |
| } |
| |
| public static void main(String... args) throws Exception { |
| if (args.length < 2) { |
| throw new IllegalArgumentException("Missing argument"); |
| } |
| switch (args[0]) { |
| case "-countFSTArcs": |
| countFSTArcs(args[1]); |
| break; |
| case "-measureFSTOversizing": |
| measureFSTOversizing(args[1]); |
| break; |
| case "-recompileAndWalk": |
| recompileAndWalk(args[1]); |
| break; |
| default: |
| throw new IllegalArgumentException("Invalid argument " + args[0]); |
| } |
| } |
| |
| private static void countFSTArcs(String fstFilePath) throws IOException { |
| byte[] buf = Files.readAllBytes(Paths.get(fstFilePath)); |
| DataInput in = new ByteArrayDataInput(buf); |
| FST<BytesRef> fst = new FST<>(in, in, ByteSequenceOutputs.getSingleton()); |
| BytesRefFSTEnum<BytesRef> fstEnum = new BytesRefFSTEnum<>(fst); |
| int binarySearchArcCount = 0, directAddressingArcCount = 0, listArcCount = 0; |
| while (fstEnum.next() != null) { |
| if (fstEnum.arcs[fstEnum.upto].bytesPerArc() == 0) { |
| listArcCount++; |
| } else if (fstEnum.arcs[fstEnum.upto].nodeFlags() == FST.ARCS_FOR_DIRECT_ADDRESSING) { |
| directAddressingArcCount++; |
| } else { |
| binarySearchArcCount++; |
| } |
| } |
| System.out.println( |
| "direct addressing arcs = " |
| + directAddressingArcCount |
| + ", binary search arcs = " |
| + binarySearchArcCount |
| + " list arcs = " |
| + listArcCount); |
| } |
| |
| private static void measureFSTOversizing(String wordsFilePath) throws Exception { |
| final int MAX_NUM_WORDS = 1000000; |
| |
| // Read real english words. |
| List<BytesRef> wordList = new ArrayList<>(); |
| try (BufferedReader reader = Files.newBufferedReader(Paths.get(wordsFilePath))) { |
| while (wordList.size() < MAX_NUM_WORDS) { |
| String word = reader.readLine(); |
| if (word == null) { |
| break; |
| } |
| wordList.add(new BytesRef(word)); |
| } |
| } |
| Collections.sort(wordList); |
| |
| // Disable direct addressing and measure the FST size. |
| FSTCompiler<Object> fstCompiler = createFSTCompiler(-1f); |
| FST<Object> fst = buildFST(wordList, fstCompiler); |
| long ramBytesUsedNoDirectAddressing = fst.ramBytesUsed(); |
| |
| // Enable direct addressing and measure the FST size. |
| fstCompiler = createFSTCompiler(FSTCompiler.DIRECT_ADDRESSING_MAX_OVERSIZING_FACTOR); |
| fst = buildFST(wordList, fstCompiler); |
| long ramBytesUsed = fst.ramBytesUsed(); |
| |
| // Compute the size increase in percents. |
| double directAddressingMemoryIncreasePercent = |
| ((double) ramBytesUsed / ramBytesUsedNoDirectAddressing - 1) * 100; |
| |
| printStats(fstCompiler, ramBytesUsed, directAddressingMemoryIncreasePercent); |
| } |
| |
| private static void recompileAndWalk(String fstFilePath) throws IOException { |
| try (InputStreamDataInput in = |
| new InputStreamDataInput(newInputStream(Paths.get(fstFilePath)))) { |
| |
| System.out.println("Reading FST"); |
| long startTimeMs = System.currentTimeMillis(); |
| FST<CharsRef> originalFst = new FST<>(in, in, CharSequenceOutputs.getSingleton()); |
| long endTimeMs = System.currentTimeMillis(); |
| System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); |
| |
| for (float oversizingFactor : List.of(0f, 0f, 0f, 1f, 1f, 1f)) { |
| System.out.println("\nFST construction (oversizingFactor=" + oversizingFactor + ")"); |
| startTimeMs = System.currentTimeMillis(); |
| FST<CharsRef> fst = recompile(originalFst, oversizingFactor); |
| endTimeMs = System.currentTimeMillis(); |
| System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); |
| System.out.println("FST RAM = " + fst.ramBytesUsed() + " B"); |
| |
| System.out.println("FST enum"); |
| startTimeMs = System.currentTimeMillis(); |
| walk(fst); |
| endTimeMs = System.currentTimeMillis(); |
| System.out.println("time = " + (endTimeMs - startTimeMs) + " ms"); |
| } |
| } |
| } |
| |
| private static InputStream newInputStream(Path path) throws IOException { |
| InputStream in = Files.newInputStream(path); |
| String fileName = path.getFileName().toString(); |
| if (fileName.endsWith("gz") || fileName.endsWith("zip")) { |
| in = new GZIPInputStream(in); |
| } |
| return in; |
| } |
| |
| private static FST<CharsRef> recompile(FST<CharsRef> fst, float oversizingFactor) |
| throws IOException { |
| FSTCompiler<CharsRef> fstCompiler = |
| new FSTCompiler.Builder<>(FST.INPUT_TYPE.BYTE4, CharSequenceOutputs.getSingleton()) |
| .directAddressingMaxOversizingFactor(oversizingFactor) |
| .build(); |
| IntsRefFSTEnum<CharsRef> fstEnum = new IntsRefFSTEnum<>(fst); |
| IntsRefFSTEnum.InputOutput<CharsRef> inputOutput; |
| while ((inputOutput = fstEnum.next()) != null) { |
| fstCompiler.add(inputOutput.input, CharsRef.deepCopyOf(inputOutput.output)); |
| } |
| return fstCompiler.compile(); |
| } |
| |
| private static int walk(FST<CharsRef> read) throws IOException { |
| IntsRefFSTEnum<CharsRef> fstEnum = new IntsRefFSTEnum<>(read); |
| IntsRefFSTEnum.InputOutput<CharsRef> inputOutput; |
| int terms = 0; |
| while ((inputOutput = fstEnum.next()) != null) { |
| terms += inputOutput.input.length; |
| terms += inputOutput.output.length; |
| } |
| return terms; |
| } |
| } |