[Fix] Update ShapeView use in nccl.cc (#18352)
This PR updates nccl.cc, which was still calling `Shape()->Product()`,
to call `Shape().Product()` instead, matching the API change that came
with the introduction of ShapeView.
diff --git a/src/runtime/disco/nccl/nccl.cc b/src/runtime/disco/nccl/nccl.cc
index 2eb0c33..fd4ad06 100644
--- a/src/runtime/disco/nccl/nccl.cc
+++ b/src/runtime/disco/nccl/nccl.cc
@@ -150,13 +150,13 @@
const void* send_data = [&]() -> const void* {
if (is_sender) {
CHECK(send.defined());
- CHECK(send.value().Shape()->Product() == recv.Shape()->Product());
+ CHECK(send.value().Shape().Product() == recv.Shape().Product());
return send.value()->data;
} else {
return nullptr;
}
}();
- int64_t numel = recv.Shape()->Product();
+ int64_t numel = recv.Shape().Product();
deviceStream_t stream = ctx->GetDefaultStream();
NCCL_CALL(ncclBroadcast(send_data, recv->data, numel,
@@ -176,7 +176,7 @@
if (is_sender) {
CHECK(send.defined()) << "ValueError: buffer `send` must be provided when worker_id == 0.";
Tensor buffer = send.value();
- int64_t numel = buffer.Shape()->Product();
+ int64_t numel = buffer.Shape().Product();
CHECK_EQ(numel % num_receiver, 0) << "ValueError: Scattering evenly requires that the number "
"of elements in the buffer to be "
"divisible by the number of workers, but got numel = "
@@ -184,11 +184,11 @@
DataType dtype(buffer->dtype);
int64_t numel_per_shard = numel / num_receiver;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
- CHECK_EQ(numel_per_shard, recv.Shape()->Product())
+ CHECK_EQ(numel_per_shard, recv.Shape().Product())
<< "ValueError: The number of elements in buffer `recv` must be the same as each shard "
"of "
"buffer `send`. `send.size` is "
- << numel << ", but `recv.size` is " << recv.Shape()->Product() << ".";
+ << numel << ", but `recv.size` is " << recv.Shape().Product() << ".";
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
for (int i = 0; i < num_receiver; ++i) {
@@ -204,7 +204,7 @@
}
NCCL_CALL(ncclGroupStart());
}
- int64_t numel = recv.Shape()->Product();
+ int64_t numel = recv.Shape().Product();
DataType dtype(recv->dtype);
NCCL_CALL(ncclRecv(recv->data, numel, AsNCCLDataType(dtype), 0,
in_group ? ctx->group_comm : ctx->global_comm, stream));
@@ -223,7 +223,7 @@
if (is_sender) {
CHECK(recv.defined()) << "ValueError: buffer `recv` must be provided when worker_id == 0.";
Tensor buffer = recv.value();
- int64_t numel = buffer.Shape()->Product();
+ int64_t numel = buffer.Shape().Product();
CHECK_EQ(numel % num_receiver, 0) << "ValueError: Gathering evenly requires that the number "
"of elements in the buffer to be "
"divisible by the number of workers, but got numel = "
@@ -231,11 +231,11 @@
DataType dtype(buffer->dtype);
int64_t numel_per_shard = numel / num_receiver;
int64_t bytes_per_shard = numel_per_shard * dtype.bytes();
- CHECK_EQ(numel_per_shard, send.Shape()->Product())
+ CHECK_EQ(numel_per_shard, send.Shape().Product())
<< "ValueError: The number of elements in buffer `send` must be the same as each shard "
"of "
"buffer `recv`. `recv.size` is "
- << numel << ", but `send.size` is " << send.Shape()->Product() << ".";
+ << numel << ", but `send.size` is " << send.Shape().Product() << ".";
NCCL_CALL(ncclGroupStart());
uint8_t* data = static_cast<uint8_t*>(buffer->data);
for (int i = 0; i < num_receiver; ++i) {
@@ -251,7 +251,7 @@
}
NCCL_CALL(ncclGroupStart());
}
- int64_t numel = send.Shape()->Product();
+ int64_t numel = send.Shape().Product();
DataType dtype(send->dtype);
NCCL_CALL(ncclSend(send->data, numel, AsNCCLDataType(dtype), 0,
in_group ? ctx->group_comm : ctx->global_comm, stream));
@@ -264,7 +264,7 @@
CHECK_NE(ctx->worker->worker_id, 0)
<< "ValueError: Worker 0 is not allowed to call RecvFromWorker0.";
NCCL_CALL(ncclGroupStart());
- NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()), 0,
+ NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()), 0,
ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
@@ -278,7 +278,7 @@
CHECK_LT(receiver_id, ctx->worker->num_workers)
<< "The current group is already the last group and there is no such a next group.";
NCCL_CALL(ncclGroupStart());
- NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+ NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
receiver_id, ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
@@ -292,7 +292,7 @@
CHECK_GE(sender_id, 0)
<< "The current group is already the first group and there is no such a previous group.";
NCCL_CALL(ncclGroupStart());
- NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+ NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
sender_id, ctx->global_comm, stream));
NCCL_CALL(ncclGroupEnd());
}
@@ -305,7 +305,7 @@
<< "Invalid receiver id " << receiver_id << ". The world size is "
<< ctx->worker->num_workers;
CHECK_NE(worker_id, receiver_id) << "Cannot send to worker itself.";
- NCCL_CALL(ncclSend(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+ NCCL_CALL(ncclSend(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
receiver_id, ctx->global_comm, stream));
}
@@ -316,7 +316,7 @@
CHECK(sender_id >= 0 && sender_id < ctx->worker->num_workers)
<< "Invalid sender id " << sender_id << ". The world size is " << ctx->worker->num_workers;
CHECK_NE(worker_id, sender_id) << "Cannot receive from the worker itself.";
- NCCL_CALL(ncclRecv(buffer->data, buffer.Shape()->Product(), AsNCCLDataType(buffer.DataType()),
+ NCCL_CALL(ncclRecv(buffer->data, buffer.Shape().Product(), AsNCCLDataType(buffer.DataType()),
sender_id, ctx->global_comm, stream));
}