blob: 130f1bb01047dde56beb61943fe7fcba4a6ae1ce [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
#include "tests/VeloxShuffleWriterTestBase.h"
using namespace facebook::velox;
using namespace facebook::velox::test;
namespace gluten {
namespace {
/// Test double for RssClient that discards pushed data and only counts how
/// many times pushPartitionData() has been invoked.
class FakeBufferRssClient : public RssClient {
 public:
  FakeBufferRssClient() = default;

  /// Records one push and reports the whole payload as accepted.
  ///
  /// The partition id and the payload bytes are intentionally ignored.
  /// @param size Number of bytes offered by the caller.
  /// @return `size`, converted to the 32-bit return type of the interface.
  int32_t pushPartitionData(int32_t /*partitionId*/, const char* /*bytes*/, int64_t size) override {
    ++receiveTimes_;
    // The interface returns int32_t while the size arrives as int64_t; make
    // the narrowing conversion explicit instead of relying on an implicit one.
    return static_cast<int32_t>(size);
  }

  /// No resources are held, so stopping is a no-op.
  void stop() override {}

  /// @return Number of pushPartitionData() calls observed so far.
  uint32_t getReceiveTimes() const {
    return receiveTimes_;
  }

 private:
  // Incremented once per pushPartitionData() call.
  uint32_t receiveTimes_{0};
};
} // namespace
class VeloxSortShuffleWriterTest : public VeloxShuffleWriterTestBase, public testing::Test {
protected:
static void SetUpTestSuite() {
setUpVeloxBackend();
}
static void TearDownTestSuite() {
tearDownVeloxBackend();
}
std::shared_ptr<VeloxShuffleWriter> createShuffleWriter(
uint32_t numPartitions,
std::shared_ptr<SortShuffleWriterOptions> writeOptions,
std::shared_ptr<RssClient> rssClient) {
auto options = std::make_shared<RssPartitionWriterOptions>();
auto partitionWriter = std::make_shared<RssPartitionWriter>(
numPartitions, nullptr, getDefaultMemoryManager(), options, std::move(rssClient));
GLUTEN_ASSIGN_OR_THROW(
auto shuffleWriter,
VeloxSortShuffleWriter::create(
numPartitions, std::move(partitionWriter), std::move(writeOptions), getDefaultMemoryManager()));
return shuffleWriter;
}
};
// Writes a single 10-row batch into one partition with the smallest possible
// write buffer and checks that the RSS client receives one push per row.
TEST_F(VeloxSortShuffleWriterTest, pushCompleteRows) {
  auto rssClient = std::make_shared<FakeBufferRssClient>();
  auto writeOptions = std::make_shared<SortShuffleWriterOptions>();
  // Use the smallest buffer size so that each push contains only one row.
  writeOptions->diskWriteBufferSize = 1;

  // Ten string rows; one of them is deliberately much longer than the rest
  // (presumably to exceed StringView's inline storage -- confirm).
  auto rowVector = makeRowVector({
      makeFlatVector<StringView>(
          {"alice0",
           "bob1",
           "alice2",
           "bob3",
           "Alice4",
           "Bob5123456789098766notinline",
           "AlicE6",
           "boB7",
           "ALICE8",
           "BOB9"}),
  });
  std::shared_ptr<ColumnarBatch> cb = std::make_shared<VeloxColumnarBatch>(rowVector);

  auto shuffleWriter = createShuffleWriter(1, writeOptions, rssClient);

  // Check the write result explicitly; otherwise a failed write would go
  // unnoticed or only show up indirectly as a wrong push count below.
  auto status = shuffleWriter->write(cb, ShuffleWriter::kMinMemLimit);
  ASSERT_TRUE(status.ok());
  ASSERT_TRUE(shuffleWriter->stop().ok());

  // The number of pushes seen by the RSS client should equal the number of rows.
  EXPECT_EQ(10U, rssClient->getReceiveTimes());
}
} // namespace gluten