blob: a5e15f7a5e2821bf983ab637582843e7fb7c6e2e [file] [log] [blame]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
#include <chrono> // NOLINT(build/c++11)
#include <iostream>
#include <memory>
#include <thread> // NOLINT(build/c++11)
#include <utility>
#include <vector>
#include "gflags/gflags.h"
#include "tmb/address.h"
#include "tmb/id_typedefs.h"
#include "tmb/message_bus.h"
#include "tmb/message_style.h"
#include "tmb/priority.h"
#include "tmb/tagged_message.h"
#include "tmbbench/bus_setup.h"
#include "tmbbench/messages.h"
// Command-line flags.
// How long the coordinator sleeps after distributing the receiver list and
// before broadcasting the poison message that ends the measurement.
DEFINE_uint64(test_duration, 20,
              "Test duration in seconds. May run for longer if receiver "
              "threads have trouble keeping up with send throughput.");
// Number of replies the coordinator waits for in each collection phase. The
// default of 0 is rejected at startup, so this flag must always be supplied.
DEFINE_uint64(processes, 0,
              "Number of processes running oneway_throughput_distributed. "
              "Must be at least 1.");
// The next four flags are copied into the tmbbench::RunDescription that is
// broadcast to every process in the setup message.
DEFINE_uint64(sender_threads, 1,
              "Number of sender threads per process to use. Must be at least "
              "1.");
DEFINE_uint64(receiver_threads, 1,
              "Number of receiver threads per process to use. Must be at "
              "least 1.");
DEFINE_uint64(message_bytes, 8,
              "Size of messages exchanged via the TMB in bytes. Must be at "
              "least 8.");
DEFINE_bool(delete_messages_immediately, false,
            "Whether to delete messages immediately as they are received, or "
            "to delete separately afterwards.");
// Enables the coordinator's progress messages on stdout.
DEFINE_bool(verbose, false,
            "Whether to enable verbose logging of experiments.");
// Coordinator for the distributed one-way throughput benchmark.
//
// Protocol, as implemented below (numbers are TMB tagged-message types):
//   1. Broadcast a tmbbench::RunDescription (type 2) to all clients.
//   2. Wait for --processes type-3 replies, each carrying a packed array of
//      receiver client_ids. Send the concatenated list (type 3) back to the
//      clients that replied.
//   3. Sleep for --test_duration seconds, then broadcast the poison message
//      (type 1) at a priority one above tmb::kDefaultPriority.
//   4. Wait for --processes type-4 replies (tmbbench::ThroughputResult) and
//      print the summed send and receive throughput.
//
// Returns 0 on success. Returns 1 on a bad flag, a bus setup failure, a failed
// send, or an unexpected/malformed reply.
int main(int argc, char *argv[]) {
  gflags::ParseCommandLineFlags(&argc, &argv, true);
  if (argc != 1) {
    std::cerr << "Unrecognized command-line arguments.\n";
    return 1;
  }
  if (FLAGS_message_bytes < 8) {
    std::cerr << "message_bytes must be at least 8.\n";
    return 1;
  }
  if (FLAGS_processes == 0) {
    std::cerr << "processes must be at least 1.\n";
    return 1;
  }
  // The flag help text promises these minimums. Enforce them here instead of
  // broadcasting a run description that asks for zero threads.
  if (FLAGS_sender_threads == 0) {
    std::cerr << "sender_threads must be at least 1.\n";
    return 1;
  }
  if (FLAGS_receiver_threads == 0) {
    std::cerr << "receiver_threads must be at least 1.\n";
    return 1;
  }

  std::unique_ptr<tmb::MessageBus> message_bus(
      tmbbench::SetupBusAllInOneDistributed());
  if (!message_bus) {
    return 1;
  }

  // The coordinator sends types 1 (poison), 2 (run description) and
  // 3 (receiver list). It receives types 3 (receiver ids) and
  // 4 (throughput results).
  const tmb::client_id coordinator_id = message_bus->Connect();
  message_bus->RegisterClientAsSender(coordinator_id, 1);
  message_bus->RegisterClientAsSender(coordinator_id, 2);
  message_bus->RegisterClientAsSender(coordinator_id, 3);
  message_bus->RegisterClientAsReceiver(coordinator_id, 3);
  message_bus->RegisterClientAsReceiver(coordinator_id, 4);
  if (FLAGS_verbose) {
    std::cout << "Coordinator connected and registered.\n";
  }

  tmb::Address address_all;
  address_all.All(true);
  tmb::MessageStyle broadcast_style;
  broadcast_style.Broadcast(true);

  // Phase 1: tell every process how to configure its test threads.
  tmbbench::RunDescription run_desc;
  run_desc.num_senders = FLAGS_sender_threads;
  run_desc.num_receivers = FLAGS_receiver_threads;
  run_desc.message_bytes = FLAGS_message_bytes;
  run_desc.delete_immediately = FLAGS_delete_messages_immediately;
  tmb::TaggedMessage run_desc_msg(&run_desc, sizeof(run_desc), 2);
  if (FLAGS_verbose) {
    std::cout << "Sending setup message to other processes... ";
  }
  tmb::MessageBus::SendStatus status = message_bus->Send(
      coordinator_id,
      address_all,
      broadcast_style,
      std::move(run_desc_msg));
  if (status != tmb::MessageBus::SendStatus::kOK) {
    std::cerr << "ERROR sending setup message\n";
    return 1;
  }
  if (FLAGS_verbose) {
    std::cout << "OK\n";
  }

  // Phase 2: collect each process's receiver ids and redistribute the union.
  std::vector<tmb::AnnotatedMessage> responses;
  while (responses.size() < FLAGS_processes) {
    message_bus->ReceiveBatch(coordinator_id, &responses, 0, 0, true);
  }
  if (FLAGS_verbose) {
    std::cout << "Received replies from other processes.\n";
  }

  tmb::Address controllers_address;
  std::vector<tmb::client_id> global_receiver_list;
  for (const tmb::AnnotatedMessage &response : responses) {
    if (response.tagged_message.message_type() != 3) {
      std::cerr << "ERROR: unexpected response type\n";
      return 1;
    }
    // The payload is a packed array of client_ids. A byte count that is not
    // a whole multiple would otherwise yield a misaligned end pointer and a
    // partial trailing id.
    const auto payload_bytes = response.tagged_message.message_bytes();
    if (payload_bytes % sizeof(tmb::client_id) != 0) {
      std::cerr << "ERROR: malformed receiver list response\n";
      return 1;
    }
    controllers_address.AddRecipient(response.sender);
    const tmb::client_id *ids = static_cast<const tmb::client_id*>(
        response.tagged_message.message());
    global_receiver_list.insert(global_receiver_list.end(),
                                ids,
                                ids + payload_bytes / sizeof(tmb::client_id));
  }
  if (FLAGS_verbose) {
    std::cout << "Constructed list of " << global_receiver_list.size()
              << " total receivers\n";
  }

  tmb::TaggedMessage receiver_list_msg(
      global_receiver_list.data(),
      global_receiver_list.size() * sizeof(tmb::client_id),
      3);
  if (FLAGS_verbose) {
    std::cout << "Sending receiver list to other processes... ";
  }
  status = message_bus->Send(coordinator_id,
                             controllers_address,
                             broadcast_style,
                             std::move(receiver_list_msg));
  if (status != tmb::MessageBus::SendStatus::kOK) {
    std::cerr << "ERROR sending receiver list\n";
    return 1;
  }
  if (FLAGS_verbose) {
    std::cout << "OK\n";
  }

  // Phase 3: let the test run, then stop every test thread.
  if (FLAGS_verbose) {
    std::cout << "Sleeping for " << FLAGS_test_duration << " seconds\n";
  }
  std::this_thread::sleep_for(std::chrono::seconds(FLAGS_test_duration));

  tmb::TaggedMessage poison(&tmbbench::kPoisonMessage,
                            sizeof(tmbbench::kPoisonMessage),
                            1);
  status = message_bus->Send(coordinator_id,
                             address_all,
                             broadcast_style,
                             std::move(poison),
                             tmb::kDefaultPriority + 1);
  if (status != tmb::MessageBus::SendStatus::kOK) {
    // The result replies below are only collected after the poison is sent.
    // If it was never delivered, presumably no process would report, and the
    // wait loop would not terminate.
    std::cerr << "ERROR sending poison message\n";
    return 1;
  }
  if (FLAGS_verbose) {
    std::cout << "Sent poison message to all test threads.\n"
              << "Awaiting responses from other processes...\n";
  }

  // Phase 4: gather and sum the per-process throughput results.
  responses.clear();
  while (responses.size() < FLAGS_processes) {
    message_bus->ReceiveBatch(coordinator_id, &responses, 0, 0, true);
  }
  message_bus->Disconnect(coordinator_id);

  double total_send_throughput = 0.0;
  double total_receive_throughput = 0.0;
  for (const tmb::AnnotatedMessage &response : responses) {
    if (response.tagged_message.message_type() != 4) {
      std::cerr << "ERROR: unexpected response type\n";
      return 1;
    }
    // Guard against reading past the end of a short payload.
    if (response.tagged_message.message_bytes()
        < sizeof(tmbbench::ThroughputResult)) {
      std::cerr << "ERROR: malformed throughput result response\n";
      return 1;
    }
    const tmbbench::ThroughputResult *result
        = static_cast<const tmbbench::ThroughputResult*>(
            response.tagged_message.message());
    total_send_throughput += result->send_throughput;
    total_receive_throughput += result->receive_throughput;
  }

  std::cout << "Send throughput: " << total_send_throughput
            << " messages/s\n";
  std::cout << "Receive throughput: " << total_receive_throughput
            << " messages/s\n";
  return 0;
}