/**
 *
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
#include "S3Wrapper.h"

#include <memory>
#include <utility>
#include <vector>
#include <algorithm>

#include "S3ClientRequestSender.h"
#include "utils/ArrayUtils.h"
#include "utils/StringUtils.h"
#include "utils/file/FileUtils.h"
#include "minifi-cpp/utils/gsl.h"
#include "utils/RegexUtils.h"
#include "aws/core/utils/HashingUtils.h"
#include "range/v3/algorithm/any_of.hpp"
#include "utils/GeneralUtils.h"

namespace org::apache::nifi::minifi::aws::s3 {

void HeadObjectResult::setFilePaths(const std::string& key) {
  absolute_path = std::filesystem::path(key, std::filesystem::path::format::generic_format);
  path = absolute_path.parent_path();
  filename = absolute_path.filename();
}

S3Wrapper::S3Wrapper(const Aws::Auth::AWSCredentials& credentials, const Aws::Client::ClientConfiguration& client_config, bool use_virtual_addressing)
    : request_sender_(std::make_unique<S3ClientRequestSender>(credentials, client_config, use_virtual_addressing)) {
}

S3Wrapper::S3Wrapper(std::unique_ptr<S3RequestSender>&& request_sender) : request_sender_(std::move(request_sender)) {
}

Expiration S3Wrapper::getExpiration(const std::string& expiration) {
  minifi::utils::Regex expr("expiry-date=\"(.*)\", rule-id=\"(.*)\"");
  minifi::utils::SMatch matches;
  const bool matched = minifi::utils::regexMatch(expiration, matches, expr);
  if (!matched || matches.size() < 3)
    return Expiration{};
  return Expiration{matches[1], matches[2]};
}

std::string S3Wrapper::getEncryptionString(Aws::S3Crt::Model::ServerSideEncryption encryption) {
  if (encryption == Aws::S3Crt::Model::ServerSideEncryption::NOT_SET) {
    return "";
  }

  const auto it = ranges::find(SERVER_SIDE_ENCRYPTION_MAP, encryption, [](const auto& kv) { return kv.second; });
  if (it != SERVER_SIDE_ENCRYPTION_MAP.end()) {
    return std::string{it->first};
  }
  return "";
}

std::shared_ptr<Aws::StringStream> S3Wrapper::readFlowFileStream(const std::shared_ptr<io::InputStream>& stream, uint64_t read_limit, uint64_t& read_size_out) {
  std::array<std::byte, BUFFER_SIZE> buffer{};
  auto data_stream = std::make_shared<Aws::StringStream>();
  uint64_t read_size = 0;
  while (read_size < read_limit) {
    const auto next_read_size = (std::min)(read_limit - read_size, uint64_t{BUFFER_SIZE});
    const auto read_ret = stream->read(std::span(buffer).subspan(0, next_read_size));
    if (io::isError(read_ret)) {
      throw StreamReadException("Reading flow file inputstream failed!");
    }
    if (read_ret > 0) {
      data_stream->write(reinterpret_cast<char*>(buffer.data()), gsl::narrow<std::streamsize>(read_ret));
      read_size += read_ret;
    } else {
      break;
    }
  }
  read_size_out = read_size;
  return data_stream;
}

std::optional<PutObjectResult> S3Wrapper::putObject(const PutObjectRequestParameters& put_object_params, const std::shared_ptr<io::InputStream>& stream, uint64_t flow_size) {
  uint64_t read_size{};
  auto data_stream = readFlowFileStream(stream, flow_size, read_size);
  auto request = createPutObjectRequest<Aws::S3Crt::Model::PutObjectRequest>(put_object_params);
  request.SetBody(data_stream);

  auto aws_result = request_sender_->sendPutObjectRequest(request);
  if (!aws_result) {
    return std::nullopt;
  }

  return createPutObjectResult(*aws_result);
}

std::optional<S3Wrapper::UploadPartsResult> S3Wrapper::uploadParts(const PutObjectRequestParameters& put_object_params, const std::shared_ptr<io::InputStream>& stream,
    MultipartUploadState upload_state) {
  gsl_Expects(multipart_upload_storage_);
  stream->seek(upload_state.uploaded_size);
  S3Wrapper::UploadPartsResult result;
  result.upload_id = upload_state.upload_id;
  result.part_etags = upload_state.uploaded_etags;
  const auto flow_size = upload_state.full_size - upload_state.uploaded_size;
  if (upload_state.part_size == 0) {
    logger_->log_error("Invalid upload part size 0 was set for S3 object with key '{}' in bucket '{}'", put_object_params.object_key, put_object_params.bucket);
    return std::nullopt;
  }
  const size_t part_count = minifi::utils::intdiv_ceil(flow_size, upload_state.part_size);
  size_t total_read = 0;
  const size_t start_part = upload_state.uploaded_parts + 1;
  const size_t last_part = start_part + part_count - 1;
  for (size_t part_number = start_part; part_number <= last_part; ++part_number) {
    uint64_t read_size{};
    const auto remaining = flow_size - total_read;
    const auto next_read_size = std::min(remaining, upload_state.part_size);
    auto stream_ptr = readFlowFileStream(stream, next_read_size, read_size);
    total_read += read_size;

    auto upload_part_request = Aws::S3Crt::Model::UploadPartRequest{}
      .WithBucket(put_object_params.bucket)
      .WithKey(put_object_params.object_key)
      .WithPartNumber(gsl::narrow<int>(part_number))
      .WithUploadId(upload_state.upload_id)
      .WithChecksumAlgorithm(put_object_params.checksum_algorithm);
    upload_part_request.SetBody(stream_ptr);

    Aws::Utils::ByteBuffer part_md5(Aws::Utils::HashingUtils::CalculateMD5(*stream_ptr));
    upload_part_request.SetContentMD5(Aws::Utils::HashingUtils::Base64Encode(part_md5));

    auto upload_part_result = request_sender_->sendUploadPartRequest(upload_part_request);
    if (!upload_part_result) {
      logger_->log_error("Failed to upload part {} of {} of S3 object with key '{}'", part_number, last_part, put_object_params.object_key);
      return std::nullopt;
    }
    result.part_etags.push_back(upload_part_result->GetETag());
    upload_state.uploaded_etags.push_back(upload_part_result->GetETag());
    upload_state.uploaded_parts += 1;
    upload_state.uploaded_size += read_size;
    multipart_upload_storage_->storeState(put_object_params.bucket, put_object_params.object_key, upload_state);
    logger_->log_info("Uploaded part {} of {} S3 object with key '{}'", part_number, last_part, put_object_params.object_key);
  }

  multipart_upload_storage_->removeState(put_object_params.bucket, put_object_params.object_key);
  return result;
}

std::optional<Aws::S3Crt::Model::CompleteMultipartUploadResult> S3Wrapper::completeMultipartUpload(const PutObjectRequestParameters& put_object_params,
    const S3Wrapper::UploadPartsResult& upload_parts_result) {
  auto complete_multipart_upload_request = Aws::S3Crt::Model::CompleteMultipartUploadRequest{}
    .WithBucket(put_object_params.bucket)
    .WithKey(put_object_params.object_key)
    .WithUploadId(upload_parts_result.upload_id);

  Aws::S3Crt::Model::CompletedMultipartUpload completed_multipart_upload;
  for (size_t i = 0; i < upload_parts_result.part_etags.size(); ++i) {
    auto part = Aws::S3Crt::Model::CompletedPart{}
      .WithETag(upload_parts_result.part_etags[i])
      .WithPartNumber(gsl::narrow<int>(i + 1));
    completed_multipart_upload.AddParts(part);
  }

  complete_multipart_upload_request.SetMultipartUpload(completed_multipart_upload);

  return request_sender_->sendCompleteMultipartUploadRequest(complete_multipart_upload_request);
}

bool S3Wrapper::multipartUploadExistsInS3(const PutObjectRequestParameters& put_object_params) {
  ListMultipartUploadsRequestParameters params;
  params.bucket = put_object_params.bucket;
  auto pending_uploads = listMultipartUploads(params);
  if (!pending_uploads) {
    return false;
  }

  return ranges::any_of(*pending_uploads, [&](const auto& upload) { return upload.key == put_object_params.object_key; });
}

std::optional<MultipartUploadState> S3Wrapper::getMultipartUploadState(const PutObjectRequestParameters& put_object_params) {
  gsl_Expects(multipart_upload_storage_);
  auto upload_state = multipart_upload_storage_->getState(put_object_params.bucket, put_object_params.object_key);
  if (!upload_state) {
    return std::nullopt;
  }
  if (!multipartUploadExistsInS3(put_object_params)) {
    logger_->log_info("Local upload state for object '{}' in bucket '{}' not found in S3, removing it from local cache.", put_object_params.object_key, put_object_params.bucket);
    multipart_upload_storage_->removeState(put_object_params.bucket, put_object_params.object_key);
    return std::nullopt;
  }
  return upload_state;
}

std::optional<PutObjectResult> S3Wrapper::putObjectMultipart(const PutObjectRequestParameters& put_object_params, const std::shared_ptr<io::InputStream>& stream,
    uint64_t flow_size, uint64_t multipart_size) {
  gsl_Expects(multipart_upload_storage_);
  if (auto upload_state = getMultipartUploadState(put_object_params)) {
    logger_->log_info("Found previous multipart upload state for {} in bucket {}, continuing upload", put_object_params.object_key, put_object_params.bucket);
    return uploadParts(put_object_params, stream, std::move(*upload_state))
      | minifi::utils::andThen([&, this](const auto& upload_parts_result) { return completeMultipartUpload(put_object_params, upload_parts_result); })
      | minifi::utils::transform([this](const auto& complete_multipart_upload_result) { return createPutObjectResult(complete_multipart_upload_result); });
  } else {
    logger_->log_debug("No previous multipart upload state was found for {} in bucket {}", put_object_params.object_key, put_object_params.bucket);
    auto request = createPutObjectRequest<Aws::S3Crt::Model::CreateMultipartUploadRequest>(put_object_params);
    return request_sender_->sendCreateMultipartUploadRequest(request)
      | minifi::utils::andThen([&, this](const auto& create_multipart_result) { return uploadParts(put_object_params, stream,
          MultipartUploadState{create_multipart_result.GetUploadId(), multipart_size, flow_size, Aws::Utils::DateTime::Now()}); })
      | minifi::utils::andThen([&, this](const auto& upload_parts_result) { return completeMultipartUpload(put_object_params, upload_parts_result); })
      | minifi::utils::transform([this](const auto& complete_multipart_upload_result) { return createPutObjectResult(complete_multipart_upload_result); });
  }
}

bool S3Wrapper::deleteObject(const DeleteObjectRequestParameters& params) {
  auto request = Aws::S3Crt::Model::DeleteObjectRequest{}
    .WithBucket(params.bucket)
    .WithKey(params.object_key);
  if (!params.version.empty()) {
    request.SetVersionId(params.version);
  }
  return request_sender_->sendDeleteObjectRequest(request);
}

int64_t S3Wrapper::writeFetchedBody(Aws::IOStream& source, const int64_t data_size, io::OutputStream& output) {
  std::vector<uint8_t> buffer(BUFFER_SIZE);
  size_t write_size = 0;
  if (data_size < 0) return 0;
  while (write_size < gsl::narrow<uint64_t>(data_size)) {
    const auto next_write_size = (std::min)(gsl::narrow<size_t>(data_size) - write_size, BUFFER_SIZE);
    if (!source.read(reinterpret_cast<char*>(buffer.data()), gsl::narrow<std::streamsize>(next_write_size))) {
      return -1;
    }
    const auto ret = output.write(buffer.data(), next_write_size);
    if (io::isError(ret)) {
      return -1;
    }
    write_size += next_write_size;
  }
  return gsl::narrow<int64_t>(write_size);
}

std::optional<GetObjectResult> S3Wrapper::getObject(const GetObjectRequestParameters& get_object_params, io::OutputStream& out_body) {
  auto request = createFetchObjectRequest<Aws::S3Crt::Model::GetObjectRequest>(get_object_params);
  auto aws_result = request_sender_->sendGetObjectRequest(request);
  if (!aws_result) {
    return std::nullopt;
  }
  auto result = fillFetchObjectResult<Aws::S3Crt::Model::GetObjectResult, GetObjectResult>(get_object_params, *aws_result);
  result.write_size = writeFetchedBody(aws_result->GetBody(), aws_result->GetContentLength(), out_body);
  return result;
}

void S3Wrapper::addListResults(const Aws::Vector<Aws::S3Crt::Model::ObjectVersion>& content, const uint64_t min_object_age, std::vector<ListedObjectAttributes>& listed_objects) {
  for (const auto& version : content) {
    if (last_bucket_list_timestamp_ - min_object_age < gsl::narrow<uint64_t>(version.GetLastModified().Millis())) {
      logger_->log_debug("Object version '{}' of key '{}' skipped due to minimum object age filter", version.GetVersionId(), version.GetKey());
      continue;
    }

    ListedObjectAttributes attributes;
    attributes.etag = minifi::utils::string::removeFramingCharacters(version.GetETag(), '"');
    attributes.filename = version.GetKey();
    attributes.is_latest = version.GetIsLatest();
    attributes.last_modified = version.GetLastModified().UnderlyingTimestamp();
    attributes.length = version.GetSize();
    try {
      attributes.store_class = minifi::utils::at(VERSION_STORAGE_CLASS_MAP, version.GetStorageClass());
    } catch (const std::out_of_range&) {
      logger_->log_warn("Unknown storage class '{}' for object version with key '{}'", magic_enum::enum_name(version.GetStorageClass()), version.GetKey());
      attributes.store_class = "Unknown";
    }
    attributes.version = version.GetVersionId();
    listed_objects.push_back(attributes);
  }
}

void S3Wrapper::addListResults(const Aws::Vector<Aws::S3Crt::Model::Object>& content, const uint64_t min_object_age, std::vector<ListedObjectAttributes>& listed_objects) {
  for (const auto& object : content) {
    if (last_bucket_list_timestamp_ - min_object_age < gsl::narrow<uint64_t>(object.GetLastModified().Millis())) {
      logger_->log_debug("Object with key '{}' skipped due to minimum object age filter", object.GetKey());
      continue;
    }

    ListedObjectAttributes attributes;
    attributes.etag = minifi::utils::string::removeFramingCharacters(object.GetETag(), '"');
    attributes.filename = object.GetKey();
    attributes.is_latest = true;
    attributes.last_modified = object.GetLastModified().UnderlyingTimestamp();
    attributes.length = object.GetSize();
    try {
      attributes.store_class = minifi::utils::at(OBJECT_STORAGE_CLASS_MAP, object.GetStorageClass());
    } catch (const std::out_of_range&) {
      logger_->log_warn("Unknown storage class '{}' for object with key '{}'", magic_enum::enum_name(object.GetStorageClass()), object.GetKey());
      attributes.store_class = "Unknown";
    }
    listed_objects.push_back(attributes);
  }
}

std::optional<std::vector<ListedObjectAttributes>> S3Wrapper::listVersions(const ListRequestParameters& params) {
  auto request = createListRequest<Aws::S3Crt::Model::ListObjectVersionsRequest>(params);
  std::vector<ListedObjectAttributes> attribute_list;
  std::optional<Aws::S3Crt::Model::ListObjectVersionsResult> aws_result;
  do {
    aws_result = request_sender_->sendListVersionsRequest(request);
    if (!aws_result) {
      return std::nullopt;
    }
    const auto& versions = aws_result->GetVersions();
    logger_->log_debug("AWS S3 List operation returned {} versions. This result is{} truncated.", versions.size(), aws_result->GetIsTruncated() ? "" : " not");
    addListResults(versions, params.min_object_age, attribute_list);
    if (aws_result->GetIsTruncated()) {
      request.WithKeyMarker(aws_result->GetNextKeyMarker())
        .WithVersionIdMarker(aws_result->GetNextVersionIdMarker());
    }
  } while (aws_result->GetIsTruncated());

  return attribute_list;
}

std::optional<std::vector<ListedObjectAttributes>> S3Wrapper::listObjects(const ListRequestParameters& params) {
  auto request = createListRequest<Aws::S3Crt::Model::ListObjectsV2Request>(params);
  std::vector<ListedObjectAttributes> attribute_list;
  std::optional<Aws::S3Crt::Model::ListObjectsV2Result> aws_result;
  do {
    aws_result = request_sender_->sendListObjectsRequest(request);
    if (!aws_result) {
      return std::nullopt;
    }
    const auto& objects = aws_result->GetContents();
    logger_->log_debug("AWS S3 List operation returned {} objects. This result is{} truncated.", objects.size(), aws_result->GetIsTruncated() ? "" : " not");
    addListResults(objects, params.min_object_age, attribute_list);
    if (aws_result->GetIsTruncated()) {
      request.SetContinuationToken(aws_result->GetNextContinuationToken());
    }
  } while (aws_result->GetIsTruncated());

  return attribute_list;
}

std::optional<std::vector<ListedObjectAttributes>> S3Wrapper::listBucket(const ListRequestParameters& params) {
  last_bucket_list_timestamp_ = gsl::narrow<uint64_t>(Aws::Utils::DateTime::CurrentTimeMillis());
  if (params.use_versions) {
    return listVersions(params);
  }
  return listObjects(params);
}

std::optional<std::map<std::string, std::string>> S3Wrapper::getObjectTags(const GetObjectTagsParameters& params) {
  auto request = Aws::S3Crt::Model::GetObjectTaggingRequest{}
    .WithBucket(params.bucket)
    .WithKey(params.object_key);
  if (!params.version.empty()) {
    request.SetVersionId(params.version);
  }
  auto aws_result = request_sender_->sendGetObjectTaggingRequest(request);
  if (!aws_result) {
    return std::nullopt;
  }
  std::map<std::string, std::string> tags;
  for (const auto& tag : aws_result->GetTagSet()) {
    tags.emplace(tag.GetKey(), tag.GetValue());
  }
  return tags;
}

std::optional<HeadObjectResult> S3Wrapper::headObject(const HeadObjectRequestParameters& head_object_params) {
  auto request = createFetchObjectRequest<Aws::S3Crt::Model::HeadObjectRequest>(head_object_params);
  auto aws_result = request_sender_->sendHeadObjectRequest(request);
  if (!aws_result) {
    return std::nullopt;
  }
  return fillFetchObjectResult<Aws::S3Crt::Model::HeadObjectResult, HeadObjectResult>(head_object_params, aws_result.value());
}

template<typename ListRequest>
ListRequest S3Wrapper::createListRequest(const ListRequestParameters& params) {
  auto request = ListRequest{}.WithBucket(params.bucket);
  if (!params.delimiter.empty()) {
    request.SetDelimiter(params.delimiter);
  }
  if (!params.prefix.empty()) {
    request.SetPrefix(params.prefix);
  }
  return request;
}

template<typename FetchObjectRequest>
FetchObjectRequest S3Wrapper::createFetchObjectRequest(const GetObjectRequestParameters& get_object_params) {
  auto request = FetchObjectRequest{}
    .WithBucket(get_object_params.bucket)
    .WithKey(get_object_params.object_key);
  if (!get_object_params.version.empty()) {
    request.SetVersionId(get_object_params.version);
  }
  if (get_object_params.requester_pays) {
    request.SetRequestPayer(Aws::S3Crt::Model::RequestPayer::requester);
  }
  return request;
}

template<typename AwsResult, typename FetchObjectResult>
FetchObjectResult S3Wrapper::fillFetchObjectResult(const GetObjectRequestParameters& get_object_params, const AwsResult& fetch_object_result) {
  FetchObjectResult result;
  result.setFilePaths(get_object_params.object_key);
  result.mime_type = fetch_object_result.GetContentType();
  result.etag = minifi::utils::string::removeFramingCharacters(fetch_object_result.GetETag(), '"');
  result.expiration = getExpiration(fetch_object_result.GetExpiration());
  result.ssealgorithm = getEncryptionString(fetch_object_result.GetServerSideEncryption());
  result.version = fetch_object_result.GetVersionId();
  for (const auto& metadata : fetch_object_result.GetMetadata()) {
    result.user_metadata_map.emplace(metadata.first, metadata.second);
  }
  return result;
}

void S3Wrapper::addListMultipartUploadResults(const Aws::Vector<Aws::S3Crt::Model::MultipartUpload>& uploads, std::optional<std::chrono::milliseconds> age_off_limit,
    std::vector<MultipartUpload>& filtered_uploads) {
  const auto now = Aws::Utils::DateTime::Now();
  for (const auto& upload : uploads) {
    // if age_off_limit is set only list the aged off uploads
    if (age_off_limit && now - upload.GetInitiated() <= *age_off_limit) {
      logger_->log_debug("Multipart upload with key '{}' and upload id '{}' has not aged off yet", upload.GetKey(), upload.GetUploadId());
      continue;
    }

    filtered_uploads.push_back({.key = upload.GetKey(), .upload_id = upload.GetUploadId()});
  }
}

std::optional<std::vector<MultipartUpload>> S3Wrapper::listMultipartUploads(const ListMultipartUploadsRequestParameters& params) {
  std::vector<MultipartUpload> result;
  std::optional<Aws::S3Crt::Model::ListMultipartUploadsResult> aws_result;
  Aws::S3Crt::Model::ListMultipartUploadsRequest request;
  request.SetBucket(params.bucket);
  do {
    aws_result = request_sender_->sendListMultipartUploadsRequest(request);
    if (!aws_result) {
      return std::nullopt;
    }
    const auto& uploads = aws_result->GetUploads();
    logger_->log_debug("AWS S3 List operation returned {} multipart uploads. This result is{} truncated.", uploads.size(), aws_result->GetIsTruncated() ? "" : " not");
    addListMultipartUploadResults(uploads, params.age_off_limit, result);
    if (aws_result->GetIsTruncated()) {
      request.SetKeyMarker(aws_result->GetNextKeyMarker());
    }
  } while (aws_result->GetIsTruncated());

  return result;
}

bool S3Wrapper::abortMultipartUpload(const AbortMultipartUploadRequestParameters& params) {
  auto request = Aws::S3Crt::Model::AbortMultipartUploadRequest{}
    .WithBucket(params.bucket)
    .WithKey(params.key)
    .WithUploadId(params.upload_id);
  return request_sender_->sendAbortMultipartUploadRequest(request);
}

void S3Wrapper::ageOffLocalS3MultipartUploadStates(std::chrono::milliseconds multipart_upload_max_age_threshold) {
  gsl_Expects(multipart_upload_storage_);
  multipart_upload_storage_->removeAgedStates(multipart_upload_max_age_threshold);
}

void S3Wrapper::initializeMultipartUploadStateStorage(core::StateManager& state_manager) {
  multipart_upload_storage_ = std::make_unique<MultipartUploadStateStorage>(state_manager);
}

}  // namespace org::apache::nifi::minifi::aws::s3
