| /* |
| * Licensed to the Apache Software Foundation (ASF) under one or more |
| * contributor license agreements. See the NOTICE file distributed with |
| * this work for additional information regarding copyright ownership. |
| * The ASF licenses this file to You under the Apache License, Version 2.0 |
| * (the "License"); you may not use this file except in compliance with |
| * the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| #include "shuffle/RoundRobinPartitioner.h" |
| #include <glog/logging.h> |
| #include <gtest/gtest.h> |
| #include <numeric> |
| |
| namespace gluten { |
| class RoundRobinPartitionerTest : public ::testing::Test { |
| protected: |
| void prepareData(int numPart, int seed) { |
| partitioner_ = std::make_shared<RoundRobinPartitioner>(numPart, seed); |
| row2Partition_.clear(); |
| } |
| |
| void checkResult(const std::vector<uint32_t>& expectRow2Part) const { |
| ASSERT_EQ(row2Partition_, expectRow2Part); |
| } |
| |
| void traceCheckResult(const std::vector<uint32_t>& expectRow2Part) const { |
| toString(expectRow2Part, "expectRow2Part"); |
| toString(row2Partition_, "row2Partition_"); |
| ASSERT_EQ(row2Partition_, expectRow2Part); |
| } |
| |
| template <typename T> |
| void toString(const std::vector<T>& vec, const std::string& name) const { |
| std::stringstream ss; |
| ss << name << " = ["; |
| std::copy(vec.cbegin(), vec.cend(), std::ostream_iterator<T>(ss, ",")); |
| ss << " ]"; |
| LOG(INFO) << ss.str(); |
| } |
| |
| int32_t getPidSelection() const { |
| return partitioner_->pidSelection_; |
| } |
| |
| std::vector<uint32_t> row2Partition_; |
| std::shared_ptr<RoundRobinPartitioner> partitioner_; |
| }; |
| |
| TEST_F(RoundRobinPartitionerTest, TestInit) { |
| int numPart = 2; |
| prepareData(numPart, 3); |
| ASSERT_NE(partitioner_, nullptr); |
| int32_t pidSelection = getPidSelection(); |
| ASSERT_EQ(pidSelection, 1); |
| } |
| |
| TEST_F(RoundRobinPartitionerTest, TestComoputeNormal) { |
| // numRows equal numPart |
| { |
| int numPart = 10; |
| prepareData(numPart, 0); |
| int numRows = 10; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 0); |
| std::vector<uint32_t> row2Part(numRows); |
| std::iota(row2Part.begin(), row2Part.end(), 0); |
| checkResult(row2Part); |
| } |
| |
| // numRows less than numPart |
| { |
| int numPart = 10; |
| prepareData(numPart, 0); |
| int numRows = 8; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 8); |
| std::vector<uint32_t> row2Part(numRows); |
| std::iota(row2Part.begin(), row2Part.end(), 0); |
| checkResult(row2Part); |
| } |
| |
| // numRows greater than numPart |
| { |
| int numPart = 10; |
| prepareData(numPart, 0); |
| int numRows = 12; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 2); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 0, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| |
| // numRows greater than 2*numPart |
| { |
| int numPart = 10; |
| prepareData(numPart, 0); |
| int numRows = 22; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 2); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 0, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| } |
| |
| TEST_F(RoundRobinPartitionerTest, TestComoputeContinuous) { |
| int numPart = 10; |
| prepareData(numPart, 0); |
| |
| { |
| int numRows = 8; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 8); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 0, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| |
| { |
| int numRows = 10; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 8); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 8, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| |
| { |
| int numRows = 12; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 0); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 8, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| |
| { |
| int numRows = 22; |
| ASSERT_TRUE(partitioner_->compute(nullptr, numRows, row2Partition_).ok()); |
| ASSERT_EQ(getPidSelection(), 2); |
| std::vector<uint32_t> row2Part(numRows); |
| std::generate_n(row2Part.begin(), numRows, [n = 0, numPart]() mutable { return (n++) % numPart; }); |
| checkResult(row2Part); |
| } |
| } |
| |
| } // namespace gluten |