| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #pragma once |
| |
| #include <Columns/ColumnNullable.h> |
| #include <Core/Block.h> |
| #include <Core/Field.h> |
| #include <Interpreters/Context.h> |
| #include <Parsers/ASTFunction.h> |
| #include <Parsers/ASTIdentifier.h> |
| #include <Parsers/ASTLiteral.h> |
| #include <Processors/Chunk.h> |
| #include <Processors/Executors/PushingPipelineExecutor.h> |
| #include <Processors/ISimpleTransform.h> |
| #include <Processors/Sinks/SinkToStorage.h> |
| #include <Storages/NativeOutputWriter.h> |
| #include <Storages/Output/OutputFormatFile.h> |
| #include <Storages/PartitionedSink.h> |
| #include <base/types.h> |
| #include <Common/ArenaUtils.h> |
| #include <Common/BlockTypeUtils.h> |
| #include <Common/CHUtil.h> |
| #include <Common/FieldAccurateComparison.h> |
| |
| namespace local_engine |
| { |
| |
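/// Writes blocks to a single output file in the format suggested by `format_hint`,
/// casting each block to the file's preferred schema before writing.
///
/// Typical usage (illustrative sketch; URI, schema and format are placeholders):
///     auto writer = NormalFileWriter::create(context, "hdfs://host/path/part-0.parquet", schema, "parquet");
///     writer->write(block);
///     writer->close();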
| class NormalFileWriter : public NativeOutputWriter |
| { |
| public: |
| static std::unique_ptr<NativeOutputWriter> create( |
| const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint); |
| |
| NormalFileWriter(const OutputFormatFilePtr & file_, const DB::ContextPtr & context_); |
| ~NormalFileWriter() override = default; |
| |
| void write(const DB::Block & block) override; |
| void close() override; |
| |
| private: |
| DB::Block castBlock(const DB::Block & block) const; |
| |
| OutputFormatFilePtr file; |
| DB::ContextPtr context; |
| |
| OutputFormatFile::OutputFormatPtr output_format; |
| std::unique_ptr<DB::QueryPipeline> pipeline; |
| std::unique_ptr<DB::PushingPipelineExecutor> writer; |
| }; |
| |
| OutputFormatFilePtr createOutputFormatFile( |
| const DB::ContextPtr & context, const std::string & file_uri, const DB::Block & preferred_schema, const std::string & format_hint); |
| |
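/// Accumulates per-file column statistics (min/max/null count) for Delta-style
/// data skipping. Partition columns are excluded, and at most `n_stats_cols`
/// leading non-partition columns are tracked.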
| struct DeltaStats |
| { |
| size_t row_count; |
    // TODO: Support delta.dataSkippingStatsColumns; see https://docs.databricks.com/aws/en/delta/data-skipping for details
| size_t n_stats_cols; |
| std::vector<DB::Field> min; |
| std::vector<DB::Field> max; |
| std::vector<Int64> null_count; |
| std::set<size_t> partition_index; |
| |
| static DeltaStats create(const DB::Block & output, const DB::Names & partition) |
| { |
| size_t size = output.columns() - partition.size(); |
| std::set<size_t> partition_index; |
| std::ranges::transform( |
| partition, |
| std::inserter(partition_index, partition_index.end()), |
| [&](const auto & name) { return output.getPositionByName(name); }); |
| assert(partition_index.size() == partition.size()); |
| return DeltaStats(size, partition_index); |
| } |
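
    /// Builds the header of the stats output block: the `statsHeaderBase` columns,
    /// followed by min_<col>, max_<col> and null_count_<col> for each tracked column.
    /// Illustrative example: with base columns (filename, partition_id, record_count)
    /// and data columns (a, b), the header is
    ///   filename, partition_id, record_count, min_a, min_b, max_a, max_b, null_count_a, null_count_b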
| static DB::Block statsHeader(const DB::Block & output, const DB::Names & partition, DB::ColumnsWithTypeAndName && statsHeaderBase) |
| { |
        std::set<std::string> partition_names(partition.begin(), partition.end());

        assert(partition_names.size() == partition.size());
| |
| size_t num_stats_cols = numStatsCols(output.columns() - partition.size()); |
| auto appendBase = [&](const std::string & prefix) |
| { |
| for (size_t i = 0, n = 0; i < output.columns() && n < num_stats_cols; i++) |
| { |
| const auto & column = output.getByPosition(i); |
                if (!partition_names.contains(column.name))
| { |
| statsHeaderBase.emplace_back(wrapNullableType(column.type), prefix + column.name); |
| ++n; |
| } |
| } |
| }; |
| appendBase("min_"); |
| appendBase("max_"); |
| for (size_t i = 0, n = 0; i < output.columns() && n < num_stats_cols; i++) |
| { |
| const auto & column = output.getByPosition(i); |
            if (!partition_names.contains(column.name))
| { |
| statsHeaderBase.emplace_back(BIGINT(), "null_count_" + column.name); |
| ++n; |
| } |
| } |
| |
| return DB::Block{statsHeaderBase}; |
| } |
| |
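    /// Caps the number of tracked columns at SparkSQLConfig::deltaDataSkippingNumIndexedCols
    /// (presumably mirroring Delta's dataSkippingNumIndexedCols setting) when a query
    /// context is available; otherwise all `origin` columns are tracked.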
| static size_t numStatsCols(size_t origin) |
| { |
| if (DB::CurrentThread::isInitialized()) |
| { |
| const DB::ContextPtr query_context = DB::CurrentThread::get().getQueryContext(); |
| if (query_context) |
| { |
| SparkSQLConfig config = SparkSQLConfig::loadFromContext(query_context); |
| return std::min(config.deltaDataSkippingNumIndexedCols, origin); |
| } |
| } |
| return origin; |
| } |
| |
| explicit DeltaStats(size_t size, const std::set<size_t> & partition_index_ = {}) |
| : row_count(0) |
| , n_stats_cols(numStatsCols(size)) |
| , min(n_stats_cols) |
| , max(n_stats_cols) |
| , null_count(n_stats_cols, 0) |
| , partition_index(partition_index_) |
| { |
| assert(size > 0); |
| } |
| |
| bool initialized() const { return row_count > 0; } |
| |
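    /// Accumulates stats from one chunk. For the first chunk (row_count is still 0,
    /// since it is only bumped at the end) the column extremes are taken as-is; for
    /// later chunks they are folded in via accurate Field comparisons.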
| void update(const DB::Chunk & chunk) |
| { |
| assert(chunk.getNumRows() > 0); |
| const auto & columns = chunk.getColumns(); |
| assert(columns.size() - partition_index.size() >= n_stats_cols); |
| for (size_t i = 0, col = 0; i < n_stats_cols && col < columns.size(); ++col) |
| { |
| if (partition_index.contains(col)) |
| continue; |
| |
| const auto & column = columns[col]; |
            Int64 chunk_null_count = 0;
            if (const auto * nullable_column = typeid_cast<const DB::ColumnNullable *>(column.get()))
            {
                const auto & null_map = nullable_column->getNullMapData();
                chunk_null_count = std::ranges::count_if(null_map, [](UInt8 value) { return value != 0; });
            }
            null_count[i] += chunk_null_count;
| |
| DB::Field min_value, max_value; |
| column->getExtremes(min_value, max_value); |
| |
| assert(min[i].isNull() || min_value.getType() == min[i].getType()); |
| assert(max[i].isNull() || max_value.getType() == max[i].getType()); |
| |
| if (!initialized()) |
| { |
| min[i] = min_value; |
| max[i] = max_value; |
| } |
| else |
| { |
| min[i] = accurateLess(min[i], min_value) ? min[i] : min_value; |
| max[i] = accurateLess(max[i], max_value) ? max_value : max[i]; |
| } |
| ++i; |
| } |
| |
| row_count += chunk.getNumRows(); |
| } |
| |
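    /// Folds stats collected for the same schema and partition layout into this one:
    /// counts are summed and per-column extremes combined.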
| void merge(const DeltaStats & right) |
| { |
| assert(n_stats_cols == right.n_stats_cols); |
        assert(partition_index == right.partition_index);

        row_count += right.row_count;
| for (size_t i = 0; i < n_stats_cols; ++i) |
| { |
| null_count[i] += right.null_count[i]; |
| min[i] = std::min(min[i], right.min[i]); |
| max[i] = std::max(max[i], right.max[i]); |
| } |
| } |
| }; |
| |
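/// Transform that swallows all data chunks and, once the input is exhausted,
/// emits a single chunk containing the collected write statistics.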
| class WriteStatsBase : public DB::ISimpleTransform |
| { |
| public: |
| /// visible for UTs |
| static const std::string NO_PARTITION_ID; |
| |
| protected: |
| bool all_chunks_processed_ = false; /// flag to determine if we have already processed all chunks |
| virtual DB::Chunk final_result() = 0; |
| |
| public: |
| WriteStatsBase(const DB::SharedHeader & input_header_, const DB::SharedHeader & output_header_) |
| : ISimpleTransform(input_header_, output_header_, true) |
| { |
| } |
| |
| Status prepare() override |
| { |
| if (input.isFinished() && !output.isFinished() && !has_input && !all_chunks_processed_) |
| { |
| all_chunks_processed_ = true; |
            /// Return Ready so transform() is called once more to emit the final stats row after the last chunk.
| return Status::Ready; |
| } |
| |
| return ISimpleTransform::prepare(); |
| } |
| |
| void transform(DB::Chunk & chunk) override |
| { |
| if (all_chunks_processed_) |
| chunk = final_result(); |
| else |
| chunk = {}; |
| } |
| }; |
| |
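/// Default stats collector. Each output row is: filename, partition_id, record_count,
/// followed by the per-column min/max/null_count columns from DeltaStats::statsHeader().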
| class WriteStats : public WriteStatsBase |
| { |
| DB::MutableColumns columns_; |
| |
| enum ColumnIndex |
| { |
| filename, |
| partition_id, |
| record_count, |
| stats_column_start = record_count + 1 |
| }; |
| static DB::ColumnsWithTypeAndName statsHeaderBase() |
| { |
| return {{STRING(), "filename"}, {STRING(), "partition_id"}, {BIGINT(), "record_count"}}; |
| } |
| |
| protected: |
| DB::Chunk final_result() override |
| { |
| size_t rows = columns_[filename]->size(); |
| return DB::Chunk(std::move(columns_), rows); |
| } |
| |
| public: |
| WriteStats(const DB::SharedHeader & input_header_, const DB::SharedHeader & output_header_) |
| : WriteStatsBase(input_header_, output_header_), columns_(output_header_->cloneEmptyColumns()) |
| { |
| } |
| static std::shared_ptr<WriteStats> create(const DB::SharedHeader & input, const DB::Names & partition) |
| { |
| return std::make_shared<WriteStats>(input, toShared(DeltaStats::statsHeader(*input, partition, statsHeaderBase()))); |
| } |
| |
| String getName() const override { return "WriteStats"; } |
| |
| void collectStats(const String & filename, const String & partition_dir, const DeltaStats & stats) const |
| { |
| const std::string & partition = partition_dir.empty() ? WriteStatsBase::NO_PARTITION_ID : partition_dir; |
| size_t columnSize = stats.n_stats_cols; |
| assert(columns_.size() == stats_column_start + columnSize * 3); |
| |
| columns_[ColumnIndex::filename]->insertData(filename.c_str(), filename.size()); |
| columns_[partition_id]->insertData(partition.c_str(), partition.size()); |
| auto & countColData = static_cast<DB::ColumnVector<Int64> &>(*columns_[record_count]).getData(); |
| countColData.emplace_back(stats.row_count); |
| |
        for (size_t i = 0; i < columnSize; ++i)
| { |
| size_t offset = stats_column_start + i; |
| columns_[offset]->insert(stats.min[i]); |
| columns_[columnSize + offset]->insert(stats.max[i]); |
| auto & nullCountData = static_cast<DB::ColumnVector<Int64> &>(*columns_[(columnSize * 2) + offset]).getData(); |
| nullCountData.emplace_back(stats.null_count[i]); |
| } |
| } |
| }; |
| |
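/// Expands a Spark file-name pattern by substituting the supported placeholders.
/// Illustrative example (the actual pattern comes from Spark):
///   "part-{id}_{bucket}.parquet" -> "part-<uuid4>_00042.parquet"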
| struct FileNameGenerator |
| { |
| // Align with org.apache.spark.sql.execution.FileNamePlaceHolder |
| static const std::vector<std::string> SUPPORT_PLACEHOLDERS; |
    // Aligned one-to-one with SUPPORT_PLACEHOLDERS above
| const std::vector<bool> need_to_replace; |
| const std::string file_pattern; |
| |
    FileNameGenerator(const std::string & file_pattern)
        : need_to_replace(compute_need_to_replace(file_pattern)), file_pattern(file_pattern)
    {
    }

    static std::vector<bool> compute_need_to_replace(const std::string & file_pattern)
| { |
| std::vector<bool> result; |
| for (const std::string & placeholder : SUPPORT_PLACEHOLDERS) |
| if (file_pattern.find(placeholder) != std::string::npos) |
| result.push_back(true); |
| else |
| result.push_back(false); |
| return result; |
| } |
| |
    std::string generate(const std::string & bucket = "") const
    {
        std::string result = file_pattern;
        if (need_to_replace[0]) // {id}
            result = pattern_format(result, SUPPORT_PLACEHOLDERS[0], toString(DB::UUIDHelpers::generateV4()));
        if (need_to_replace[1]) // {bucket}
            result = pattern_format(result, SUPPORT_PLACEHOLDERS[1], bucket);
        return result;
    }

    /// Replaces every occurrence of `arg` in `input` with `replacement`.
    static std::string pattern_format(const std::string & input, const std::string & arg, const std::string & replacement)
    {
        std::string format_str = input;
        size_t pos = format_str.find(arg);
        while (pos != std::string::npos)
        {
            format_str.replace(pos, arg.length(), replacement);
            pos = format_str.find(arg, pos + replacement.length());
        }
        return format_str;
    }
| }; |
| |
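/// Sink that writes one output file (optionally a single bucket of a partition).
/// It updates DeltaStats as chunks arrive and reports them to WriteStats on finish;
/// for bucketed writes the trailing bucket column is stripped before writing.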
| class SubstraitFileSink final : public DB::SinkToStorage |
| { |
| const std::string partition_id_; |
| const bool bucketed_write_; |
| const std::string relative_path_; |
| OutputFormatFilePtr format_file_; |
| OutputFormatFile::OutputFormatPtr output_format_; |
| std::shared_ptr<WriteStats> stats_; |
| DeltaStats delta_stats_; |
| |
| static std::string makeAbsoluteFilename(const std::string & base_path, const std::string & partition_id, const std::string & relative) |
| { |
| if (partition_id.empty()) |
| return fmt::format("{}/{}", base_path, relative); |
| return fmt::format("{}/{}/{}", base_path, partition_id, relative); |
| } |
| |
| public: |
| explicit SubstraitFileSink( |
| const DB::ContextPtr & context, |
| const std::string & base_path, |
| const std::string & partition_id, |
| const bool bucketed_write, |
| const std::string & relative, |
| const std::string & format_hint, |
| const DB::SharedHeader header, |
| const std::shared_ptr<WriteStatsBase> & stats, |
| const DeltaStats & delta_stats) |
| : SinkToStorage(header) |
| , partition_id_(partition_id) |
| , bucketed_write_(bucketed_write) |
| , relative_path_(relative) |
| , format_file_(createOutputFormatFile(context, makeAbsoluteFilename(base_path, partition_id, relative), *header, format_hint)) |
| , stats_(std::dynamic_pointer_cast<WriteStats>(stats)) |
| , delta_stats_(delta_stats) |
| { |
| } |
| |
| String getName() const override { return "SubstraitFileSink"; } |
| |
| protected: |
| void consume(DB::Chunk & chunk) override |
| { |
| delta_stats_.update(chunk); |
| if (!output_format_) [[unlikely]] |
| output_format_ = format_file_->createOutputFormat(); |
| |
| const DB::Block & input_header = getHeader(); |
| if (bucketed_write_) |
| { |
| chunk.erase(input_header.columns() - 1); |
| const DB::ColumnsWithTypeAndName & cols = input_header.getColumnsWithTypeAndName(); |
| DB::ColumnsWithTypeAndName without_bucket_cols(cols.begin(), cols.end() - 1); |
| DB::Block without_bucket_header = DB::Block(without_bucket_cols); |
| output_format_->output->write(materializeBlock(without_bucket_header.cloneWithColumns(chunk.detachColumns()))); |
| } |
| else |
| output_format_->output->write(materializeBlock(input_header.cloneWithColumns(chunk.detachColumns()))); |
| } |
| void onFinish() override |
| { |
| if (output_format_) |
| { |
| output_format_->finalizeOutput(); |
            /// We must reset output_format_ here, before returning to Spark, because the
            /// underlying file is only closed in ~WriteBufferFromHDFSImpl(); the Spark
            /// commit protocol can move the file safely only once it is closed.
| output_format_.reset(); |
| assert(delta_stats_.row_count > 0); |
| if (stats_) |
| stats_->collectStats(relative_path_, partition_id_, delta_stats_); |
| } |
| } |
| void onCancel() noexcept override |
| { |
| if (output_format_) |
| { |
| output_format_->cancel(); |
| output_format_.reset(); |
| } |
| } |
| }; |
| |
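/// Base sink for Spark-style partitioned writes. Rows are routed to per-partition
/// sinks via a path expression of the form `c1=v1/c2=v2`, with values escaped by
/// sparkPartitionEscape and nulls replaced by DEFAULT_PARTITION_NAME.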
| class SparkPartitionedBaseSink : public DB::PartitionedSink |
| { |
| public: |
| static const std::string DEFAULT_PARTITION_NAME; |
| static const std::string BUCKET_COLUMN_NAME; |
| |
| static bool isBucketedWrite(const DB::Block & input_header) |
| { |
| return input_header.has(BUCKET_COLUMN_NAME) && input_header.getPositionByName(BUCKET_COLUMN_NAME) == input_header.columns() - 1; |
| } |
| |
| /// visible for UTs |
| static DB::ASTPtr make_partition_expression(const DB::Names & partition_columns, const DB::Block & input_header) |
| { |
        /// Build the partition path expression as an AST. For columns (c1, c2) it is:
        ///   concat(sparkPartitionEscape('c1'), '=', ifNull(toString(sparkPartitionEscape(toString(c1))), DEFAULT_PARTITION_NAME),
        ///          '/', sparkPartitionEscape('c2'), '=', ...)
| bool add_slash = false; |
| DB::ASTs arguments; |
| for (const auto & column : partition_columns) |
| { |
| // partition_column= |
| auto column_name = std::make_shared<DB::ASTLiteral>(column); |
| auto escaped_name = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_name}); |
| if (add_slash) |
| arguments.emplace_back(std::make_shared<DB::ASTLiteral>("/")); |
| add_slash = true; |
| arguments.emplace_back(escaped_name); |
| arguments.emplace_back(std::make_shared<DB::ASTLiteral>("=")); |
| |
| // ifNull(toString(partition_column), DEFAULT_PARTITION_NAME) |
| // FIXME if toString(partition_column) is empty |
| auto column_ast = makeASTFunction("toString", DB::ASTs{std::make_shared<DB::ASTIdentifier>(column)}); |
| auto escaped_value = makeASTFunction("sparkPartitionEscape", DB::ASTs{column_ast}); |
| DB::ASTs if_null_args{ |
| makeASTFunction("toString", DB::ASTs{escaped_value}), std::make_shared<DB::ASTLiteral>(DEFAULT_PARTITION_NAME)}; |
| arguments.emplace_back(makeASTFunction("ifNull", std::move(if_null_args))); |
| } |
| |
| if (isBucketedWrite(input_header)) |
| { |
| DB::ASTs args{std::make_shared<DB::ASTLiteral>("%05d"), std::make_shared<DB::ASTIdentifier>(BUCKET_COLUMN_NAME)}; |
| arguments.emplace_back(DB::makeASTFunction("printf", std::move(args))); |
| } |
| assert(!arguments.empty()); |
| if (arguments.size() == 1) |
| return arguments[0]; |
| return DB::makeASTFunction("concat", std::move(arguments)); |
| } |
| |
| private: |
| static std::shared_ptr<DB::IPartitionStrategy> |
| make_partition_strategy(const DB::ContextPtr & context, const DB::Names & partition_columns, const DB::Block & input_header) |
| { |
| DB::ASTPtr partition_by = make_partition_expression(partition_columns, input_header); |
| return DB::PartitionStrategyFactory::get( |
| DB::PartitionStrategyFactory::StrategyType::WILDCARD, |
| partition_by, |
| input_header.getNamesAndTypesList(), |
| context, |
| "", // format_name => no need |
| false, // globbed_path => no need |
| true, |
| true); |
| } |
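
    /// For bucketed writes the partition expression appends a printf("%05d") bucket id,
    /// so the last five characters of `partition_id` carry the bucket value.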
| DB::SinkPtr createSinkForPartition(const String & partition_id) override |
| { |
| if (bucketed_write_) |
| { |
| std::string bucket_val = partition_id.substr(partition_id.length() - 5, 5); |
| std::string real_partition_id = partition_id.substr(0, partition_id.length() - 5); |
| return createSinkForPartition(real_partition_id, bucket_val); |
| } |
| return createSinkForPartition(partition_id, ""); |
| } |
| |
| virtual DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) = 0; |
| |
| protected: |
| DB::ContextPtr context_; |
| std::shared_ptr<WriteStatsBase> stats_; |
| DeltaStats empty_delta_stats_; |
| bool bucketed_write_; |
| |
| public: |
| SparkPartitionedBaseSink( |
| const DB::ContextPtr & context, |
| const DB::Names & partition_by, |
| const DB::SharedHeader & input_header, |
| const std::shared_ptr<WriteStatsBase> & stats) |
| : PartitionedSink(make_partition_strategy(context, partition_by, *input_header), context, input_header) |
| , context_(context) |
| , stats_(stats) |
| , empty_delta_stats_(DeltaStats::create(*input_header, partition_by)) |
| , bucketed_write_(isBucketedWrite(*input_header)) |
| { |
| } |
| }; |
| |
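/// Partitioned sink that creates one SubstraitFileSink per partition (and bucket),
/// generating file names with FileNameGenerator under `base_path/partition_id/`.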
| class SubstraitPartitionedFileSink final : public SparkPartitionedBaseSink |
| { |
| const std::string base_path_; |
| const FileNameGenerator generator_; |
| const DB::SharedHeader input_header_; |
| const DB::SharedHeader sample_block_; |
| const std::string format_hint_; |
| |
| public: |
| SubstraitPartitionedFileSink( |
| const DB::ContextPtr & context, |
| const DB::Names & partition_by, |
| const DB::SharedHeader & input_header, |
| const DB::SharedHeader & sample_block, |
| const std::string & base_path, |
| const FileNameGenerator & generator, |
| const std::string & format_hint, |
| const std::shared_ptr<WriteStatsBase> & stats) |
| : SparkPartitionedBaseSink(context, partition_by, input_header, stats) |
| , base_path_(base_path) |
| , generator_(generator) |
        , input_header_(input_header)
        , sample_block_(sample_block)
| , format_hint_(format_hint) |
| { |
| } |
| |
| DB::SinkPtr createSinkForPartition(const String & partition_id, const String & bucket) override |
| { |
| assert(stats_); |
| bool bucketed_write = !bucket.empty(); |
| std::string filename = bucketed_write ? generator_.generate(bucket) : generator_.generate(); |
| const auto partition_path = fmt::format("{}/{}", partition_id, filename); |
| validatePartitionKey(partition_path, true); |
| return std::make_shared<SubstraitFileSink>( |
| context_, base_path_, partition_id, bucketed_write, filename, format_hint_, sample_block_, stats_, empty_delta_stats_); |
| } |
| String getName() const override { return "SubstraitPartitionedFileSink"; } |
| }; |
} // namespace local_engine