Merge pull request #894 from NLGithubWP/update-distributed-train-cnn-dev
Update data augmentation implementation for cifar_distributed_cnn example
diff --git a/examples/cifar_distributed_cnn/train_cnn.py b/examples/cifar_distributed_cnn/train_cnn.py
old mode 100644
new mode 100755
index 26e0403..d102623
--- a/examples/cifar_distributed_cnn/train_cnn.py
+++ b/examples/cifar_distributed_cnn/train_cnn.py
@@ -170,9 +170,6 @@
train_x, train_y, val_x,
val_y)
-
-
-
if model.dimension == 4:
tx = tensor.Tensor(
(batch_size, num_channels, model.input_size, model.input_size), dev,
@@ -192,6 +189,7 @@
dev.SetVerbosity(verbosity)
model.train()
+ # Augmentation is done only once before training
b = 0
x = train_x[idx[b * batch_size:(b + 1) * batch_size]]
if model.dimension == 4: