/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
#include "VeloxBatchResizer.h"

#include <algorithm>
#include <utility>
| |
| namespace gluten { |
| namespace { |
| |
| class SliceRowVector : public ColumnarBatchIterator { |
| public: |
| SliceRowVector(int32_t maxOutputBatchSize, facebook::velox::RowVectorPtr in) |
| : maxOutputBatchSize_(maxOutputBatchSize), in_(in) {} |
| |
| std::shared_ptr<ColumnarBatch> next() override { |
| int32_t remainingLength = in_->size() - cursor_; |
| GLUTEN_CHECK(remainingLength >= 0, "Invalid state"); |
| if (remainingLength == 0) { |
| return nullptr; |
| } |
| int32_t sliceLength = std::min(maxOutputBatchSize_, remainingLength); |
| auto out = std::dynamic_pointer_cast<facebook::velox::RowVector>(in_->slice(cursor_, sliceLength)); |
| cursor_ += sliceLength; |
| GLUTEN_CHECK(out != nullptr, "Invalid state"); |
| return std::make_shared<VeloxColumnarBatch>(out); |
| } |
| |
| private: |
| int32_t maxOutputBatchSize_; |
| facebook::velox::RowVectorPtr in_; |
| int32_t cursor_ = 0; |
| }; |
| } // namespace |
| |
| gluten::VeloxBatchResizer::VeloxBatchResizer( |
| facebook::velox::memory::MemoryPool* pool, |
| int32_t minOutputBatchSize, |
| int32_t maxOutputBatchSize, |
| int64_t preferredBatchBytes, |
| std::unique_ptr<ColumnarBatchIterator> in) |
| : pool_(pool), |
| minOutputBatchSize_(minOutputBatchSize), |
| maxOutputBatchSize_(maxOutputBatchSize), |
| preferredBatchBytes_(static_cast<uint64_t>(preferredBatchBytes)), |
| in_(std::move(in)) { |
| GLUTEN_CHECK( |
| minOutputBatchSize_ > 0 && maxOutputBatchSize_ > 0, |
| "Either minOutputBatchSize or maxOutputBatchSize should be larger than 0"); |
| } |
| |
| std::shared_ptr<ColumnarBatch> VeloxBatchResizer::next() { |
| if (next_) { |
| auto next = next_->next(); |
| if (next != nullptr) { |
| return next; |
| } |
| // Cached output was drained. Continue reading data from input iterator. |
| next_ = nullptr; |
| } |
| |
| auto cb = in_->next(); |
| if (cb == nullptr) { |
| // Input iterator was drained. |
| return nullptr; |
| } |
| |
| uint64_t numBytes = cb->numBytes(); |
| if (cb->numRows() < minOutputBatchSize_ && numBytes <= preferredBatchBytes_) { |
| auto vb = VeloxColumnarBatch::from(pool_, cb); |
| auto rv = vb->getRowVector(); |
| auto buffer = facebook::velox::RowVector::createEmpty(rv->type(), pool_); |
| buffer->append(rv.get()); |
| |
| for (cb = in_->next(); cb != nullptr; cb = in_->next()) { |
| vb = VeloxColumnarBatch::from(pool_, cb); |
| rv = vb->getRowVector(); |
| uint64_t addedBytes = cb->numBytes(); |
| if (buffer->size() + rv->size() > maxOutputBatchSize_ || |
| numBytes + addedBytes > static_cast<uint64_t>(preferredBatchBytes_)) { |
| GLUTEN_CHECK(next_ == nullptr, "Invalid state"); |
| next_ = std::make_unique<SliceRowVector>(maxOutputBatchSize_, rv); |
| return std::make_shared<VeloxColumnarBatch>(buffer); |
| } |
| numBytes += addedBytes; |
| buffer->append(rv.get()); |
| if (buffer->size() >= minOutputBatchSize_) { |
| // Buffer is full. |
| break; |
| } |
| // Call reset manully to potentially release memory |
| rv.reset(); |
| vb.reset(); |
| cb.reset(); |
| } |
| return std::make_shared<VeloxColumnarBatch>(buffer); |
| } |
| |
| if (cb->numRows() > maxOutputBatchSize_) { |
| auto vb = VeloxColumnarBatch::from(pool_, cb); |
| auto rv = vb->getRowVector(); |
| GLUTEN_CHECK(next_ == nullptr, "Invalid state"); |
| next_ = std::make_unique<SliceRowVector>(maxOutputBatchSize_, rv); |
| auto next = next_->next(); |
| GLUTEN_CHECK(next != nullptr, "Invalid state"); |
| return next; |
| } |
| |
| // Fast flush path. |
| return cb; |
| } |
| |
| int64_t VeloxBatchResizer::spillFixedSize(int64_t size) { |
| return in_->spillFixedSize(size); |
| } |
| |
| } // namespace gluten |