| # SPDX-License-Identifier: Apache-2.0 |
| # |
| # Modifications by Apache Solr contributors; see git log for details. |
| # Licensed under the Apache License, Version 2.0. |
| # |
| # The OpenSearch Contributors require contributions made to |
| # this file be licensed under the Apache-2.0 license or a |
| # compatible open source license. |
| # Modifications Copyright OpenSearch Contributors. See |
| # GitHub history for details. |
| # Licensed to Elasticsearch B.V. under one or more contributor |
| # license agreements. See the NOTICE file distributed with |
| # this work for additional information regarding copyright |
| # ownership. Elasticsearch B.V. licenses this file to you under |
| # the Apache License, Version 2.0 (the "License"); you may |
| # not use this file except in compliance with the License. |
| # You may obtain a copy of the License at |
| # |
| # http://www.apache.org/licenses/LICENSE-2.0 |
| # |
| # Unless required by applicable law or agreed to in writing, |
| # software distributed under the License is distributed on an |
| # "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| # KIND, either express or implied. See the License for the |
| # specific language governing permissions and limitations |
| # under the License. |
| |
| import bz2 |
| import gzip |
| import logging |
| import os |
| import mmap |
| import shutil |
| import subprocess |
| import tarfile |
| import zipfile |
| import urllib.error |
| from contextlib import suppress |
| import zstandard as zstd |
| |
| from solrorbit import exceptions |
| from solrorbit.utils import console |
| |
| |
| class FileSource: |
| """ |
| FileSource is a wrapper around a plain file which simplifies testing of file I/O calls. |
| """ |
| def __init__(self, file_name, mode, encoding="utf-8"): |
| self.file_name = file_name |
| self.mode = mode |
| self.encoding = encoding |
| self.f = None |
| |
| def open(self): |
| self.f = open(self.file_name, mode=self.mode, encoding=self.encoding) |
| # allow for chaining |
| return self |
| |
| def seek(self, offset): |
| self.f.seek(offset) |
| |
| def read(self): |
| return self.f.read() |
| |
| def readline(self): |
| return self.f.readline() |
| |
| def readlines(self, num_lines): |
| lines = [] |
| f = self.f |
| for _ in range(num_lines): |
| line = f.readline() |
| if len(line) == 0: |
| break |
| lines.append(line) |
| return lines |
| |
| def close(self): |
| self.f.close() |
| self.f = None |
| |
| def __enter__(self): |
| self.open() |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.close() |
| return False |
| |
| def __str__(self, *args, **kwargs): |
| return self.file_name |
| |
| |
| class MmapSource: |
| """ |
| MmapSource is a wrapper around a memory-mapped file which simplifies testing of file I/O calls. |
| """ |
| def __init__(self, file_name, mode, encoding="utf-8"): |
| self.file_name = file_name |
| self.mode = mode |
| self.encoding = encoding |
| self.f = None |
| self.mm = None |
| |
| def open(self): |
| self.f = open(self.file_name, mode="r+b") |
| self.mm = mmap.mmap(self.f.fileno(), 0, access=mmap.ACCESS_READ) |
| # madvise is available in Python 3.8+ |
| with suppress(AttributeError): |
| self.mm.madvise(mmap.MADV_SEQUENTIAL) |
| |
| # allow for chaining |
| return self |
| |
| def seek(self, offset): |
| self.mm.seek(offset) |
| |
| def read(self): |
| return self.mm.read() |
| |
| def readline(self): |
| return self.mm.readline() |
| |
| def readlines(self, num_lines): |
| lines = [] |
| mm = self.mm |
| for _ in range(num_lines): |
| line = mm.readline() |
| if line == b"": |
| break |
| lines.append(line) |
| return lines |
| |
| def close(self): |
| self.mm.close() |
| self.mm = None |
| self.f.close() |
| self.f = None |
| |
| def __enter__(self): |
| self.open() |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.close() |
| return False |
| |
| def __str__(self, *args, **kwargs): |
| return self.file_name |
| |
| |
| class DictStringFileSourceFactory: |
| """ |
| Factory that can create `StringAsFileSource` for tests. Based on the provided dict, it will create a proper `StringAsFileSource`. |
| |
| It is intended for scenarios where multiple files may be read by client code. |
| """ |
| def __init__(self, name_to_contents): |
| self.name_to_contents = name_to_contents |
| |
| def __call__(self, name, mode, encoding="utf-8"): |
| return StringAsFileSource(self.name_to_contents[name], mode, encoding) |
| |
| |
| class StringAsFileSource: |
| """ |
| Implementation of ``FileSource`` intended for tests. It's kept close to ``FileSource`` to simplify maintenance but it is not meant to |
| be used in production code. |
| """ |
| def __init__(self, contents, mode, encoding="utf-8"): |
| """ |
| :param contents: The file contents as an array of strings. Each item in the array should correspond to one line. |
| :param mode: The file mode. It is ignored in this implementation but kept to implement the same interface as ``FileSource``. |
| :param encoding: The file encoding. It is ignored in this implementation but kept to implement the same interface as ``FileSource``. |
| """ |
| self.contents = contents |
| self.current_index = 0 |
| self.opened = False |
| |
| def open(self): |
| self.opened = True |
| return self |
| |
| def seek(self, offset): |
| self._assert_opened() |
| if offset != 0: |
| raise AssertionError("StringAsFileSource does not support random seeks") |
| |
| def read(self): |
| self._assert_opened() |
| return "\n".join(self.contents) |
| |
| def readline(self): |
| self._assert_opened() |
| if self.current_index >= len(self.contents): |
| return "" |
| line = self.contents[self.current_index] |
| self.current_index += 1 |
| return line |
| |
| def readlines(self, num_lines): |
| lines = [] |
| for _ in range(num_lines): |
| line = self.readline() |
| if len(line) == 0: |
| break |
| lines.append(line) |
| return lines |
| |
| def close(self): |
| self._assert_opened() |
| self.contents = None |
| self.opened = False |
| |
| def _assert_opened(self): |
| assert self.opened |
| |
| def __enter__(self): |
| self.open() |
| return self |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.close() |
| return False |
| |
| def __str__(self, *args, **kwargs): |
| return "StringAsFileSource" |
| |
| |
| def ensure_dir(directory, mode=0o777): |
| """ |
| Ensure that the provided directory and all of its parent directories exist. |
| This function is safe to execute on existing directories (no op). |
| |
| :param directory: The directory to create (if it does not exist). |
| :param mode: The permission flags to use (if it does not exist). |
| """ |
| if directory: |
| os.makedirs(directory, mode, exist_ok=True) |
| |
| def ensure_symlink(source, link_name): |
| """ |
| Ensure that a symlink exists from link_name to source. |
| If link_name already exists, it will be updated or replaced as necessary. |
| |
| :param source: The target of the symlink |
| :param link_name: The path where the symlink should be created |
| """ |
| logger = logging.getLogger(__name__) |
| if os.path.exists(link_name): |
| if os.path.islink(link_name): |
| if os.readlink(link_name) != source: |
| os.remove(link_name) |
| os.symlink(source, link_name) |
| logger.info("Updated symlink: %s -> %s", link_name, source) |
| elif os.path.isdir(link_name): |
| shutil.rmtree(link_name) |
| os.symlink(source, link_name) |
| logger.info("Replaced directory with symlink: %s -> %s", link_name, source) |
| else: |
| os.remove(link_name) |
| os.symlink(source, link_name) |
| logger.info("Replaced file with symlink: %s -> %s", link_name, source) |
| else: |
| os.symlink(source, link_name) |
| logger.info("Created symlink: %s -> %s", link_name, source) |
| |
| def _zipdir(source_directory, archive): |
| for root, _, files in os.walk(source_directory): |
| for file in files: |
| archive.write( |
| filename=os.path.join(root, file), |
| arcname=os.path.relpath(os.path.join(root, file), os.path.join(source_directory, ".."))) |
| |
| |
| def is_archive(name): |
| """ |
| :param name: File name to check. Can be either just the file name or optionally also an absolute path. |
| :return: True iff the given file name is an archive that is also recognized for decompression by Solr Orbit. |
| """ |
| _, ext = splitext(name) |
| return ext in [".zip", ".bz2", ".gz", ".tar", ".tar.gz", ".tgz", ".tar.bz2", ".zst"] |
| |
| |
| def is_executable(name): |
| """ |
| :param name: File name to check. |
| :return: True iff given file name is executable and in PATH, all other cases False. |
| """ |
| |
| return shutil.which(name) is not None |
| |
| |
| def compress(source_directory, archive_name): |
| """ |
| Compress a directory tree. |
| |
| :param source_directory: The source directory to compress. Must be readable. |
| :param archive_name: The absolute path including the file name of the archive. Must have the extension .zip. |
| """ |
| archive = zipfile.ZipFile(archive_name, "w", zipfile.ZIP_DEFLATED) |
| _zipdir(source_directory, archive) |
| |
| |
| def compress_zstd(source_directory, archive_name): |
| """ |
| Compress a directory tree using Zstandard compression. |
| :param source_directory: The source directory to compress. Must be readable. |
| :param archive_name: The absolute path including the file name of the archive. Must have the extension .zst. |
| """ |
| zstc = zstd.ZstdCompressor() |
| |
| with open(archive_name, "wb") as archive_file: |
| with zstc.stream_writer(archive_file) as compressor: |
| for root, _, files in os.walk(source_directory): |
| for file in files: |
| file_path = os.path.join(root, file) |
| rel_path = os.path.relpath(file_path, source_directory) |
| # Write the file path (relative) to the archive to recreate the directory structure |
| compressor.write(rel_path.encode("utf-8")) |
| with open(file_path, "rb") as source_file: |
| # Write the content of the file to the archive |
| for chunk in source_file: |
| compressor.write(chunk) |
| |
| |
| def decompress(zip_name, target_directory): |
| """ |
| Decompresses the provided archive to the target directory. The following file extensions are supported: |
| |
| * zip |
| * bz2 |
| * gz |
| * tar |
| * tar.gz |
| * tgz |
| * tar.bz2 |
| * zst |
| |
| The decompression method is chosen based on the file extension. |
| |
| :param zip_name: The full path name to the file that should be decompressed. |
| :param target_directory: The directory to which files should be decompressed. May or may not exist prior to calling |
| this function. |
| """ |
| _, extension = splitext(zip_name) |
| if extension == ".zip": |
| _do_decompress(target_directory, zipfile.ZipFile(zip_name)) |
| elif extension == ".bz2": |
| decompressor_args = ["pbzip2", "-d", "-k", "-m10000", "-c"] |
| decompressor_lib = bz2.open |
| _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) |
| elif extension == ".gz": |
| decompressor_args = ["pigz", "-d", "-k", "-c"] |
| decompressor_lib = gzip.open |
| _do_decompress_manually(target_directory, zip_name, decompressor_args, decompressor_lib) |
| elif extension in [".tar", ".tar.gz", ".tgz", ".tar.bz2"]: |
| _do_decompress(target_directory, tarfile.open(zip_name)) |
| elif extension == ".zst": |
| _do_decompress_zstd(target_directory, zip_name) |
| else: |
| raise RuntimeError("Unsupported file extension [%s]. Cannot decompress [%s]" % (extension, zip_name)) |
| |
| |
| def _do_decompress_manually(target_directory, filename, decompressor_args, decompressor_lib): |
| decompressor_bin = decompressor_args[0] |
| base_path_without_extension = basename(splitext(filename)[0]) |
| |
| if is_executable(decompressor_bin): |
| if _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): |
| return |
| else: |
| logging.getLogger(__name__).warning("%s not found in PATH. Using standard library, decompression will take longer.", |
| decompressor_bin) |
| |
| _do_decompress_manually_with_lib(target_directory, filename, decompressor_lib(filename)) |
| |
| |
| def _do_decompress_manually_external(target_directory, filename, base_path_without_extension, decompressor_args): |
| with open(os.path.join(target_directory, base_path_without_extension), "wb") as new_file: |
| try: |
| subprocess.run(decompressor_args + [filename], stdout=new_file, stderr=subprocess.PIPE, check=True) |
| except subprocess.CalledProcessError as err: |
| logging.getLogger(__name__).warning("Failed to decompress [%s] with [%s]. Error [%s]. Falling back to standard library.", |
| filename, err.cmd, err.stderr) |
| return False |
| return True |
| |
| |
| def _do_decompress_manually_with_lib(target_directory, filename, compressed_file): |
| path_without_extension = basename(splitext(filename)[0]) |
| |
| ensure_dir(target_directory) |
| try: |
| with open(os.path.join(target_directory, path_without_extension), "wb") as new_file: |
| for data in iter(lambda: compressed_file.read(100 * 1024), b""): |
| new_file.write(data) |
| finally: |
| compressed_file.close() |
| |
| |
| def _do_decompress_zstd(target_directory, filename): |
| path_without_extension = basename(splitext(filename)[0]) |
| try: |
| with open(filename, 'rb') as compressed_file: |
| zstd_decompressor = zstd.ZstdDecompressor() |
| with open(os.path.join(target_directory, path_without_extension), "wb") as new_file: |
| for chunk in zstd_decompressor.read_to_iter(compressed_file): |
| new_file.write(chunk) |
| finally: |
| compressed_file.close() |
| |
| |
| def _do_decompress(target_directory, compressed_file): |
| try: |
| compressed_file.extractall(path=target_directory) |
| except BaseException: |
| raise RuntimeError("Could not decompress provided archive [%s]" % compressed_file.filename) |
| finally: |
| compressed_file.close() |
| |
| |
| # just in a dedicated method to ease mocking |
| def dirname(path): |
| return os.path.dirname(path) |
| |
| |
| def basename(path): |
| return os.path.basename(path) |
| |
| |
| def exists(path): |
| return os.path.exists(path) |
| |
| |
| def normalize_path(path, cwd="."): |
| """ |
| Normalizes a path by removing redundant "../" and also expanding the "~" character to the user home directory. |
| :param path: A possibly non-normalized path. |
| :param cwd: The current working directory. "." by default. |
| :return: A normalized path. |
| """ |
| normalized = os.path.normpath(os.path.expanduser(path)) |
| # user specified only a file name? -> treat as relative to the current directory |
| if dirname(normalized) == "": |
| return os.path.join(cwd, normalized) |
| else: |
| return normalized |
| |
| |
| def escape_path(path): |
| """ |
| Escapes any characters that might be problematic in shell interactions. |
| |
| :param path: The original path. |
| :return: A potentially modified version of the path with all problematic characters escaped. |
| """ |
| return path.replace("\\", "\\\\") |
| |
| |
| def splitext(file_name): |
| if file_name.endswith(".tar.gz"): |
| return file_name[0:-7], file_name[-7:] |
| elif file_name.endswith(".tar.bz2"): |
| return file_name[0:-8], file_name[-8:] |
| else: |
| return os.path.splitext(file_name) |
| |
| |
| def is_plain_text(file): |
| _, ext = splitext(file) |
| return ext in [".ini", ".txt", ".json", ".yml", ".yaml", ".options", ".properties"] |
| |
| |
| def has_extension(file_name, extension): |
| """ |
| Checks whether the given file name has the given extension. |
| |
| :param file_name: A file name to check (either just the name or an absolute path name). |
| :param extension: The extension including the leading dot (i.e. it is ".txt", not "txt"). |
| :return: True iff the given ``file_name`` has the given ``extension``. |
| """ |
| _, ext = splitext(file_name) |
| return ext == extension |
| |
| |
| class FileOffsetTable: |
| """ |
| The FileOffsetTable represents a persistent mapping from lines in a data file to their offset in bytes in the |
| data file. This helps bulk-indexing clients to advance quickly to a certain position in a large data file. |
| """ |
| def __init__(self, data_file_path, offset_table_path, mode): |
| """ |
| Creates a new FileOffsetTable instance. The constructor should not be called directly but instead the |
| respective factory methods should be used. |
| |
| :param data_file_path: The absolute path to the data file. This file is assumed to exist at this point. |
| :param offset_table_path: The absolute path to the corresponding offset table file. Only required to exist |
| for read operations on the data file. |
| :param mode: The mode in which the file offset table should be opened. |
| """ |
| self.data_file_path = data_file_path |
| self.offset_table_path = offset_table_path |
| self.mode = mode |
| self.offset_file = None |
| |
| def exists(self): |
| """ |
| :return: True iff the file offset table already exists. |
| """ |
| return os.path.exists(self.offset_table_path) |
| |
| def is_valid(self): |
| """ |
| :return: True iff the file offset table exists and it is up-to-date. |
| """ |
| return self.exists() and os.path.getmtime(self.offset_table_path) >= os.path.getmtime(self.data_file_path) |
| |
| def __enter__(self): |
| self.offset_file = open(self.offset_table_path, self.mode) |
| return self |
| |
| def add_offset(self, line_number, offset): |
| """ |
| Adds a new offset mapping to the file offset table. This method has to be called inside a context-manager block. |
| |
| :param line_number: A line number to add. |
| :param offset: The corresponding offset in bytes. |
| """ |
| print(f"{line_number};{offset}", file=self.offset_file) |
| |
| def find_closest_offset(self, target_line_number): |
| """ |
| Determines the offset in bytes for the line L in the corresponding data file with the following properties: |
| |
| * L <= target_line_number |
| * For any line M, where M != L and M <= target_line_number: M > L (i.e. L is the closest match) |
| |
| :param target_line_number: A positive number representing a line number in the data file. |
| :return: A tuple of file offset in bytes to the line with the closest match and the number of lines that |
| still need to be skipped. |
| """ |
| prior_offset = 0 |
| prior_remaining_lines = target_line_number |
| |
| for line in self.offset_file: |
| line_number, offset_in_bytes = [int(i) for i in line.strip().split(";")] |
| if line_number <= target_line_number: |
| prior_offset = offset_in_bytes |
| prior_remaining_lines = target_line_number - line_number |
| else: |
| break |
| |
| return prior_offset, prior_remaining_lines |
| |
| def __exit__(self, exc_type, exc_val, exc_tb): |
| self.offset_file.close() |
| self.offset_file = None |
| return False |
| |
| @classmethod |
| def create_for_data_file(cls, data_file_path): |
| """ |
| Factory method to create a new file offset table. |
| |
| :param data_file_path: The absolute path to the data file for which a file offset table should be created. |
| """ |
| return cls(data_file_path, f"{data_file_path}.offset", "wt") |
| |
| @classmethod |
| def read_for_data_file(cls, data_file_path): |
| """ |
| |
| Factory method to read from an existing file offset table. |
| |
| :param data_file_path: The absolute path to the data file for which the file offset table should be read. |
| """ |
| return cls(data_file_path, f"{data_file_path}.offset", "rt") |
| |
| @staticmethod |
| def remove(data_file_path): |
| """ |
| Removes a file offset table for the provided data path. |
| |
| :param data_file_path: The absolute path to the data file for which the file offset table should be deleted. |
| """ |
| os.remove(f"{data_file_path}.offset") |
| |
| |
| def prepare_file_offset_table(data_file_path, base_url, source_url, downloader): |
| """ |
| Creates a file that contains a mapping from line numbers to file offsets for the provided path. This file is used internally by |
| #skip_lines(data_file_path, data_file) to speed up line skipping. |
| |
| :param data_file_path: The path to a text file that is readable by this process. |
| :return The number of lines read or ``None`` if it did not have to build the file offset table. |
| """ |
| file_offset_table = FileOffsetTable.create_for_data_file(data_file_path) |
| if not file_offset_table.is_valid(): |
| if not source_url: |
| try: |
| downloader.download(base_url, None, data_file_path + '.offset', None) |
| except exceptions.DataError as e: |
| if isinstance(e.cause, urllib.error.HTTPError) and (e.cause.code == 403 or e.cause.code == 404): |
| logging.getLogger(__name__).info("Pre-generated offset file not found, will generate from corpus data") |
| |
| if not file_offset_table.is_valid(): |
| console.info("Preparing file offset table for [%s] ... " % data_file_path, end="", flush=True) |
| line_number = 0 |
| with file_offset_table: |
| with open(data_file_path, mode="rt", encoding="utf-8") as data_file: |
| while True: |
| line = data_file.readline() |
| if len(line) == 0: |
| break |
| line_number += 1 |
| if line_number % 50000 == 0: |
| file_offset_table.add_offset(line_number, data_file.tell()) |
| console.println("[OK]") |
| return line_number |
| else: |
| return None |
| |
| |
| def remove_file_offset_table(data_file_path): |
| """ |
| |
| Attempts to remove the file offset table for the provided data path. |
| |
| :param data_file_path: The path to a text file that is readable by this process. |
| """ |
| FileOffsetTable.remove(data_file_path) |
| |
| |
| def skip_lines(data_file_path, data_file, number_of_lines_to_skip): |
| """ |
| Skips the first `number_of_lines_to_skip` lines in `data_file` as a side effect. |
| |
| :param data_file_path: The full path to the data file. |
| :param data_file: The data file. It is assumed that this file is already open for reading and its file pointer is at position zero. |
| :param number_of_lines_to_skip: A non-negative number of lines that should be skipped. |
| """ |
| if number_of_lines_to_skip == 0: |
| return |
| |
| file_offset_table = FileOffsetTable.read_for_data_file(data_file_path) |
| # can we fast forward? |
| if file_offset_table.exists(): |
| with file_offset_table: |
| offset, remaining_lines = file_offset_table.find_closest_offset(number_of_lines_to_skip) |
| else: |
| offset = 0 |
| remaining_lines = number_of_lines_to_skip |
| |
| # fast forward to the last known file offset |
| data_file.seek(offset) |
| # forward the last remaining lines if needed |
| if remaining_lines > 0: |
| for _ in range(remaining_lines): |
| data_file.readline() |
| |
| |
| def get_size(start_path="."): |
| total_size = 0 |
| for dirpath, _, filenames in os.walk(start_path): |
| for f in filenames: |
| fp = os.path.join(dirpath, f) |
| total_size += os.path.getsize(fp) |
| return total_size |