Check dataset shape correctness
diff --git a/examples/largedataset_cnn/train_largedata.py b/examples/largedataset_cnn/train_largedata.py
index adba9ce..69097a4 100755
--- a/examples/largedataset_cnn/train_largedata.py
+++ b/examples/largedataset_cnn/train_largedata.py
@@ -143,7 +143,7 @@
train_x, train_y, val_x,
val_y)
'''
- # check dataset shape correctness
+ # Check dataset shape correctness
if global_rank == 0:
print("Check the shape of dataset:")
print(train_x.shape)
@@ -183,7 +183,7 @@
if global_rank == 0:
print('Starting Epoch %d:' % (epoch))
- # Training Phase
+ # Training phase
train_correct = np.zeros(shape=[1], dtype=np.float32)
test_correct = np.zeros(shape=[1], dtype=np.float32)
train_loss = np.zeros(shape=[1], dtype=np.float32)
@@ -218,7 +218,7 @@
(num_train_batch * batch_size * world_size)),
flush=True)
- # Evaluation Phase
+ # Evaluation phase
model.eval()
for b in range(num_val_batch):
x = val_x[b * batch_size:(b + 1) * batch_size]