#
# Licensed to the Apache Software Foundation (ASF) under one or more
# contributor license agreements.  See the NOTICE file distributed with
# this work for additional information regarding copyright ownership.
# The ASF licenses this file to You under the Apache License, Version 2.0
# (the "License"); you may not use this file except in compliance with
# the License.  You may obtain a copy of the License at
#
#    http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#

"""File-based sink."""

# pytype: skip-file

import logging
import os
import re
import time
import uuid

from apache_beam.internal import util
from apache_beam.io import iobase
from apache_beam.io.filesystem import BeamIOError
from apache_beam.io.filesystem import CompressionTypes
from apache_beam.io.filesystems import FileSystems
from apache_beam.options.value_provider import StaticValueProvider
from apache_beam.options.value_provider import ValueProvider
from apache_beam.options.value_provider import check_accessible
from apache_beam.transforms import window
from apache_beam.transforms.display import DisplayDataItem

# Shard name template FileBasedSink falls back to when no shard_name_template
# is given. "S" runs expand to the zero-padded shard number and "N" runs to the
# zero-padded total number of shards.
DEFAULT_SHARD_NAME_TEMPLATE = '-SSSSS-of-NNNNN'
# Same as above but prefixed with the window ("W" run). Not referenced in this
# file — presumably used by callers that write windowed output; confirm.
DEFAULT_WINDOW_SHARD_NAME_TEMPLATE = '-W-SSSSS-of-NNNNN'
# Value FileBasedSink stores when triggering_frequency is None.
DEFAULT_TRIGGERING_FREQUENCY = 0

__all__ = ['FileBasedSink']

_LOGGER = logging.getLogger(__name__)


class FileBasedSink(iobase.Sink):
  """A sink that writes to GCS or local files.

  To implement a file-based sink, extend this class and override
  either :meth:`.write_record()` or :meth:`.write_encoded_record()`.

  If needed, also override :meth:`.open()` and/or :meth:`.close()` to customize
  the file handling or write headers and footers.

  Final file names are ``file_path_prefix + <shard name> + file_name_suffix``,
  where the shard name is rendered from ``shard_name_template``. In the
  template, a run of ``S`` (required) becomes the zero-padded shard number, a
  run of ``N`` the zero-padded number of shards, a run of ``U`` a freshly
  generated uuid4 formatted with ``%d``, a run of ``W`` ``str(window)`` and a
  run of ``V`` the window bounds in UTC.

  The output of this write is a :class:`~apache_beam.pvalue.PCollection` of
  all written shards.
  """

  # Max number of threads to be used for renaming.
  _MAX_RENAME_THREADS = 64
  # When more files than this are finalized, sink lineage is reported as a
  # single glob instead of one entry per file.
  _MAX_LINEAGE_FILES = 10
  # Instances are compared by value in __eq__, so they are made unhashable.
  __hash__ = None  # type: ignore[assignment]

  def __init__(
      self,
      file_path_prefix,
      coder,
      file_name_suffix='',
      num_shards=0,
      shard_name_template=None,
      mime_type='application/octet-stream',
      compression_type=CompressionTypes.AUTO,
      *,
      max_records_per_shard=None,
      max_bytes_per_shard=None,
      skip_if_empty=False,
      convert_fn=None,
      triggering_frequency=None):
    """Constructs a file-based sink.

    Args:
      file_path_prefix: str or ValueProvider; prefix of every final file name.
      coder: Coder used by :meth:`write_record` to encode each record.
      file_name_suffix: str or ValueProvider; suffix of every final file name.
      num_shards: Number of output shards. Forced to 1 when
        ``shard_name_template`` is the empty string.
      shard_name_template: Template for the shard part of file names; None
        selects ``DEFAULT_SHARD_NAME_TEMPLATE``.
      mime_type: MIME type passed to ``FileSystems.create``.
      compression_type: A ``CompressionTypes`` value passed to
        ``FileSystems.create``.
      max_records_per_shard: Optional record limit checked by
        ``FileBasedSinkWriter.at_capacity``.
      max_bytes_per_shard: Optional byte limit checked by
        ``FileBasedSinkWriter.at_capacity``; enables byte counting in
        :meth:`open`.
      skip_if_empty: Stored on the sink; not consulted in this class.
      convert_fn: Stored on the sink; not consulted in this class.
      triggering_frequency: Stored on the sink; None is replaced by
        ``DEFAULT_TRIGGERING_FREQUENCY``.

    Raises:
      TypeError: if file path parameters are not a :class:`str` or
        :class:`~apache_beam.options.value_provider.ValueProvider`, or if
        **compression_type** is not member of
        :class:`~apache_beam.io.filesystem.CompressionTypes`.
      ValueError: if **shard_name_template** is not of expected
        format.
    """
    # The one-element tuples keep %-formatting safe even if the bad argument
    # is itself a tuple.
    if not isinstance(file_path_prefix, (str, ValueProvider)):
      raise TypeError(
          'file_path_prefix must be a string or ValueProvider; '
          'got %r instead' % (file_path_prefix, ))
    if not isinstance(file_name_suffix, (str, ValueProvider)):
      raise TypeError(
          'file_name_suffix must be a string or ValueProvider; '
          'got %r instead' % (file_name_suffix, ))

    if not CompressionTypes.is_valid_compression_type(compression_type):
      raise TypeError(
          'compression_type must be CompressionType object but '
          'was %s' % type(compression_type))
    if shard_name_template is None:
      shard_name_template = DEFAULT_SHARD_NAME_TEMPLATE
    elif shard_name_template == '':
      # Without a shard placeholder every shard would get the same name, so
      # only a single shard is allowed.
      num_shards = 1
    if triggering_frequency is None:
      triggering_frequency = DEFAULT_TRIGGERING_FREQUENCY
    if isinstance(file_path_prefix, str):
      file_path_prefix = StaticValueProvider(str, file_path_prefix)
    if isinstance(file_name_suffix, str):
      file_name_suffix = StaticValueProvider(str, file_name_suffix)
    self.file_path_prefix = file_path_prefix
    self.file_name_suffix = file_name_suffix
    self.num_shards = num_shards
    self.coder = coder
    self.shard_name_template = shard_name_template
    self.shard_name_format = self._template_to_format(shard_name_template)
    self.shard_name_glob_format = self._template_to_glob_format(
        shard_name_template)
    self.compression_type = compression_type
    self.mime_type = mime_type
    self.max_records_per_shard = max_records_per_shard
    self.max_bytes_per_shard = max_bytes_per_shard
    self.skip_if_empty = skip_if_empty
    self.convert_fn = convert_fn
    self.triggering_frequency = triggering_frequency

  def display_data(self):
    """Returns display data: shard count, compression and file pattern."""
    return {
        'shards': DisplayDataItem(self.num_shards,
                                  label='Number of Shards').drop_if_default(0),
        'compression': DisplayDataItem(str(self.compression_type)),
        'file_pattern': DisplayDataItem(
            '{}{}{}'.format(
                self.file_path_prefix,
                self.shard_name_format,
                self.file_name_suffix),
            label='File Pattern')
    }

  @check_accessible(['file_path_prefix'])
  def open(self, temp_path):
    """Opens ``temp_path``, returning an opaque file handle object.

    The returned file handle is passed to ``write_[encoded_]record`` and
    ``close``. When ``max_bytes_per_shard`` is set, the handle is wrapped in a
    byte counter that is also stored as ``self.byte_counter``.
    """
    writer = FileSystems.create(
        temp_path, self.mime_type, self.compression_type)
    if self.max_bytes_per_shard:
      # FileBasedSinkWriter.at_capacity() reads the byte count from here.
      self.byte_counter = _ByteCountingWriter(writer)
      return self.byte_counter
    return writer

  def write_record(self, file_handle, value):
    """Writes a single record to the file handle returned by ``open()``.

    By default, calls ``write_encoded_record`` after encoding the record with
    this sink's Coder.
    """
    self.write_encoded_record(file_handle, self.coder.encode(value))

  def write_encoded_record(self, file_handle, encoded_value):
    """Writes a single encoded record to the file handle returned by ``open()``.

    Must be overridden by subclasses that do not override ``write_record``.
    """
    raise NotImplementedError

  def close(self, file_handle):
    """Finalize and close the file handle returned from ``open()``.

    Called after all records are written.

    By default, calls ``file_handle.close()`` iff it is not None.
    """
    if file_handle is not None:
      file_handle.close()

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def initialize_write(self):
    """Creates a fresh temporary directory and returns its path."""
    file_path_prefix = self.file_path_prefix.get()

    tmp_dir = self._create_temp_dir(file_path_prefix)
    FileSystems.mkdirs(tmp_dir)
    return tmp_dir

  def _create_temp_dir(self, file_path_prefix):
    """Returns a unique ``beam-temp-*`` directory path next to the prefix.

    The directory is not created here.

    Raises:
      ValueError: if ``file_path_prefix`` is a filesystem root.
    """
    base_path, last_component = FileSystems.split(file_path_prefix)
    if not last_component:
      # Trying to re-split the base_path to check if it's a root.
      new_base_path, _ = FileSystems.split(base_path)
      if base_path == new_base_path:
        raise ValueError(
            'Cannot create a temporary directory for root path '
            'prefix %s. Please specify a file path prefix with '
            'at least two components.' % file_path_prefix)
    path_components = [
        base_path, 'beam-temp-' + last_component + '-' + uuid.uuid1().hex
    ]
    return FileSystems.join(*path_components)

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def open_writer(self, init_result, uid):
    """Returns a FileBasedSinkWriter for temp file ``uid`` in ``init_result``.

    Args:
      init_result: Temporary directory from :meth:`initialize_write`, or an
        ``EmptySideInput`` when that result was unavailable.
      uid: Unique id used as the temporary file's base name.
    """
    from apache_beam.pvalue import EmptySideInput

    file_path_prefix = self.file_path_prefix.get()

    # Handle case where init_result is EmptySideInput (empty collection)
    # TODO: https://github.com/apache/beam/issues/36563 for Prism
    if isinstance(init_result, EmptySideInput):
      # Fall back to creating a temporary directory based on file_path_prefix
      _LOGGER.warning(
          'Initialization result collection was empty, falling back to '
          'creating temporary directory. This may indicate an issue with '
          'the pipeline initialization phase.')
      init_result = self._create_temp_dir(file_path_prefix)
      FileSystems.mkdirs(init_result)

    # A proper suffix is needed for AUTO compression detection.
    # We also ensure there will be no collisions with uid and a
    # (possibly unsharded) file_path_prefix and a (possibly empty)
    # file_name_suffix.
    file_name_suffix = self.file_name_suffix.get()
    suffix = '.' + os.path.basename(file_path_prefix) + file_name_suffix
    writer_path = FileSystems.join(init_result, uid) + suffix
    return FileBasedSinkWriter(self, writer_path)

  @staticmethod
  def _window_utc(w):
    """Returns ``'[start, end)'`` in UTC for ``w``; None for no/global window."""
    if w is None or isinstance(w, window.GlobalWindow):
      return None
    time_format = '%Y-%m-%dT%H-%M-%S'
    return '[%s, %s)' % (
        w.start.to_utc_datetime().strftime(time_format),
        w.end.to_utc_datetime().strftime(time_format))

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def _get_final_name(self, shard_num, num_shards, w=None):
    """Returns the final file name for shard ``shard_num`` in window ``w``."""
    shard_name = self.shard_name_format % dict(
        shard_num=shard_num,
        num_shards=num_shards,
        uuid=uuid.uuid4(),
        window=w,
        window_utc=self._window_utc(w))
    return ''.join([
        self.file_path_prefix.get(),
        shard_name,
        self.file_name_suffix.get(),
    ])

  @check_accessible(['file_path_prefix', 'file_name_suffix'])
  def _get_final_name_glob(self, num_shards, w=None):
    """Returns a glob over final file names for ``num_shards`` shards."""
    shard_glob = self.shard_name_glob_format % dict(
        num_shards=num_shards,
        uuid=uuid.uuid4(),
        window=w,
        window_utc=self._window_utc(w))
    return ''.join([
        self.file_path_prefix.get(),
        shard_glob,
        self.file_name_suffix.get(),
    ])

  def _delete_existing_final_files(self, writer_results, w=None):
    """Deletes files already present under this write's final-name glob."""
    num_shards = len(list(writer_results))
    dst_glob = self._get_final_name_glob(num_shards, w)
    dst_glob_files = [
        file_metadata.path for mr in FileSystems.match([dst_glob])
        for file_metadata in mr.metadata_list
    ]

    if dst_glob_files:
      _LOGGER.warning(
          'Deleting %d existing files in target path matching: %s',
          len(dst_glob_files),
          dst_glob)
      FileSystems.delete(dst_glob_files)

  def pre_finalize(self, init_result, writer_results):
    """Deletes existing files that match the final output glob."""
    self._delete_existing_final_files(writer_results)

  def pre_finalize_windowed(self, init_result, writer_results, window=None):
    """Like :meth:`pre_finalize`, for the output glob of ``window``."""
    self._delete_existing_final_files(writer_results, window)

  def _check_state_for_finalize_write(
      self, writer_results, num_shards, window=None):
    """Checks writer output files' states.

    Returns:
      src_files, dst_files: Lists of files to rename. For each i, finalize_write
        should rename(src_files[i], dst_files[i]).
      delete_files: Src files to delete. These could be leftovers from an
        incomplete (non-atomic) rename operation.
      num_skipped: Tally of writer results files already renamed, such as from
        a previous run of finalize_write().

    Raises:
      BeamIOError: if neither the source nor the destination of some shard
        exists.
    """
    if not writer_results:
      return [], [], [], 0

    # All temporary files are expected to live in the same directory as the
    # first writer result.
    src_glob = FileSystems.join(FileSystems.split(writer_results[0])[0], '*')
    dst_glob = self._get_final_name_glob(num_shards, window)
    src_glob_files = set(
        file_metadata.path for mr in FileSystems.match([src_glob])
        for file_metadata in mr.metadata_list)
    dst_glob_files = set(
        file_metadata.path for mr in FileSystems.match([dst_glob])
        for file_metadata in mr.metadata_list)

    src_files = []
    dst_files = []
    delete_files = []
    num_skipped = 0
    for shard_num, src in enumerate(writer_results):
      dst = self._get_final_name(shard_num, num_shards, window)
      src_exists = src in src_glob_files
      dst_exists = dst in dst_glob_files
      if not src_exists and not dst_exists:
        raise BeamIOError(
            'src and dst files do not exist. src: %s, dst: %s' % (src, dst))
      if not src_exists and dst_exists:
        _LOGGER.debug('src: %s -> dst: %s already renamed, skipping', src, dst)
        num_skipped += 1
        continue
      if (src_exists and dst_exists and
          FileSystems.checksum(src) == FileSystems.checksum(dst)):
        _LOGGER.debug('src: %s == dst: %s, deleting src', src, dst)
        delete_files.append(src)
        continue

      src_files.append(src)
      dst_files.append(dst)

    self._report_sink_lineage(dst_glob, dst_files)
    return src_files, dst_files, delete_files, num_skipped

  def _report_sink_lineage(self, dst_glob, dst_files):
    """Reports sink lineage.

    Reports every file when there are at most ``_MAX_LINEAGE_FILES`` of them,
    otherwise reports only the glob.
    """
    # There is rollup at the higher level, but this loses glob information.
    # Better to report multiple globs than just the parent directory.
    if len(dst_files) <= FileBasedSink._MAX_LINEAGE_FILES:
      for dst in dst_files:
        FileSystems.report_sink_lineage(dst)
    else:
      FileSystems.report_sink_lineage(dst_glob)

  @check_accessible(['file_path_prefix'])
  def finalize_write(
      self, init_result, writer_results, unused_pre_finalize_results):
    """Finalizes a non-windowed write; see :meth:`finalize_windowed_write`."""
    # Legacy finalize_write now shares the implementation with
    # finalize_windowed_write when the window is None.
    return self.finalize_windowed_write(
        init_result, writer_results, unused_pre_finalize_results, None)

  @check_accessible(['file_path_prefix'])
  def finalize_windowed_write(
      self, init_result, writer_results, unused_pre_finalize_results, w=None):
    """Renames temporary shards to their final names and yields those names.

    This is a generator: nothing happens until it is iterated. After all final
    names are yielded, the temporary directory ``init_result`` is deleted on a
    best-effort basis.

    Raises:
      Exception: if any rename fails.
    """
    writer_results = sorted(writer_results)
    num_shards = len(writer_results)

    src_files, dst_files, delete_files, num_skipped = (
        self._check_state_for_finalize_write(writer_results, num_shards, w))
    num_skipped += len(delete_files)
    FileSystems.delete(delete_files)
    num_shards_to_finalize = len(src_files)

    if num_shards_to_finalize:
      start_time = time.time()

      def _rename_batch(batch):
        """Renames one (sources, destinations) batch; returns its errors."""
        source_files, destination_files = batch
        exceptions = []
        try:
          FileSystems.rename(source_files, destination_files)
          return exceptions
        except BeamIOError as exp:
          if exp.exception_details is None:
            raise
          for (src, dst), exception in exp.exception_details.items():
            if exception:
              _LOGGER.error(
                  ('Exception in _rename_batch. src: %s, '
                   'dst: %s, err: %s'),
                  src,
                  dst,
                  exception)
              exceptions.append(exception)
            else:
              _LOGGER.debug('Rename successful: %s -> %s', src, dst)
          return exceptions

      if w is None or isinstance(w, window.GlobalWindow):
        # Bounded input (reached through finalize_write): rename in
        # filesystem-sized chunks using a thread pool.
        num_threads = max(
            1, min(num_shards_to_finalize, FileBasedSink._MAX_RENAME_THREADS))
        chunk_size = FileSystems.get_chunk_size(self.file_path_prefix.get())
        batches = [(src_files[i:i + chunk_size], dst_files[i:i + chunk_size])
                   for i in range(0, num_shards_to_finalize, chunk_size)]
        exception_batches = util.run_using_threadpool(
            _rename_batch, batches, num_threads)
        all_exceptions = [
            e for exception_batch in exception_batches for e in exception_batch
        ]
      else:
        # Unbounded (windowed) input: rename everything in a single call.
        # _rename_batch already returns a flat list of exceptions, so it must
        # not be flattened again (exceptions are not iterable).
        all_exceptions = _rename_batch((src_files, dst_files))

      if all_exceptions:
        raise Exception(
            'Encountered exceptions in finalize_write: %s' % all_exceptions)

      yield from dst_files

      _LOGGER.info(
          'Renamed %d shards in %.2f seconds.',
          num_shards_to_finalize,
          time.time() - start_time)
    else:
      _LOGGER.warning(
          'No shards found to finalize. num_shards: %d, skipped: %d',
          num_shards,
          num_skipped)

    try:
      FileSystems.delete([init_result])
    except IOError:
      # This error is not serious, we simply log it.
      _LOGGER.info('Unable to delete file: %s', init_result)

  @staticmethod
  def _template_replace_window(shard_name_template):
    """Replaces W+ with ``%(window)s`` and V+ with ``%(window_utc)s`` fields."""
    match = re.search('W+', shard_name_template)
    if match:
      shard_name_template = shard_name_template.replace(
          match.group(0), '%%(window)0%ds' % len(match.group(0)))
    match = re.search('V+', shard_name_template)
    if match:
      shard_name_template = shard_name_template.replace(
          match.group(0), '%%(window_utc)0%ds' % len(match.group(0)))
    return shard_name_template

  @staticmethod
  def _template_replace_uuid(shard_name_template):
    """Replaces U+ with a zero-padded ``%(uuid)d`` field."""
    match = re.search('U+', shard_name_template)
    if match:
      shard_name_template = shard_name_template.replace(
          match.group(0), '%%(uuid)0%dd' % len(match.group(0)))
    return shard_name_template

  @staticmethod
  def _template_replace_num_shards(shard_name_template):
    """Replaces N+ with a zero-padded ``%(num_shards)d`` field."""
    match = re.search('N+', shard_name_template)
    if match:
      shard_name_template = shard_name_template.replace(
          match.group(0), '%%(num_shards)0%dd' % len(match.group(0)))
    return shard_name_template

  @staticmethod
  def _template_replace_shard_num(shard_name_template):
    """Replaces S+ with a zero-padded ``%(shard_num)d`` field.

    Raises:
      ValueError: if the template has no S+ run.
    """
    match = re.search('S+', shard_name_template)
    if match is None:
      # shard name is required in the template.
      raise ValueError(
          "Shard number pattern S+ not found in shard_name_template: %s" %
          shard_name_template)
    return shard_name_template.replace(
        match.group(0), '%%(shard_num)0%dd' % len(match.group(0)))

  @staticmethod
  def _template_to_format(shard_name_template):
    """Converts a shard name template into a %-format string ('' if empty)."""
    if not shard_name_template:
      return ''
    # shard_num is required in the template, while others are optional.
    replace_funcs = [
        FileBasedSink._template_replace_shard_num,
        FileBasedSink._template_replace_num_shards,
        FileBasedSink._template_replace_uuid,
        FileBasedSink._template_replace_window
    ]
    for func in replace_funcs:
      shard_name_template = func(shard_name_template)
    return shard_name_template

  @staticmethod
  def _template_to_glob_format(shard_name_template):
    """Converts a shard name template into a glob %-format ('' if empty).

    S+ becomes ``*`` and N+ a ``%(num_shards)d`` field.

    NOTE(review): W+, V+ and U+ runs are left as literal letters, so globs
    for templates using them do not match the final names produced by
    ``_template_to_format``. Substituting them needs glob escaping (window
    strings contain ``[``), so this is left unchanged — confirm intent.

    Raises:
      ValueError: if the template has no S+ run.
    """
    if not shard_name_template:
      return ''
    match = re.search('S+', shard_name_template)
    if match is None:
      raise ValueError(
          "Shard number pattern S+ not found in shard_name_template: %s" %
          shard_name_template)
    shard_name_format = shard_name_template.replace(match.group(0), '*')
    return FileBasedSink._template_replace_num_shards(shard_name_format)

  def __eq__(self, other):
    # TODO: Clean up workitem_test which uses this.
    # pylint: disable=unidiomatic-typecheck
    return type(self) == type(other) and self.__dict__ == other.__dict__


class FileBasedSinkWriter(iobase.Writer):
  """The writer for FileBasedSink.

  Opens a temporary shard file through the sink on construction, forwards
  each record to ``sink.write_record`` and returns the temporary path from
  :meth:`close`.
  """
  def __init__(self, sink, temp_shard_path):
    self.sink = sink
    self.temp_shard_path = temp_shard_path
    self.temp_handle = self.sink.open(temp_shard_path)
    # FileBasedSink.open() stores the byte counter for the handle it just
    # created on the sink itself, so a later open on the same sink replaces
    # it. Capture it now so at_capacity() checks the bytes written to *this*
    # writer's file rather than those of the most recently opened writer.
    if getattr(self.sink, 'max_bytes_per_shard', None):
      self._byte_counter = getattr(self.sink, 'byte_counter', None)
    else:
      self._byte_counter = None
    self.num_records_written = 0

  def write(self, value):
    """Writes one record to the temporary shard and counts it."""
    self.num_records_written += 1
    self.sink.write_record(self.temp_handle, value)

  def at_capacity(self):
    """Returns True once this shard reaches the sink's record or byte limit."""
    max_records = self.sink.max_records_per_shard
    if max_records and self.num_records_written >= max_records:
      return True
    max_bytes = self.sink.max_bytes_per_shard
    if max_bytes:
      counter = self._byte_counter
      if counter is None:
        # Fall back to the sink-level counter (original behaviour), e.g. for
        # subclasses whose open() did not set it before construction ended.
        counter = self.sink.byte_counter
      return counter.bytes_written >= max_bytes
    return False

  def close(self):
    """Closes the temporary shard via the sink and returns its path."""
    self.sink.close(self.temp_handle)
    return self.temp_shard_path


class _ByteCountingWriter:
  def __init__(self, writer):
    self.writer = writer
    self.bytes_written = 0

  def write(self, bs):
    self.bytes_written += len(bs)
    self.writer.write(bs)

  def flush(self):
    self.writer.flush()

  def close(self):
    self.writer.close()
