Update the data types and tensor operations for training
diff --git a/examples/cnn_ms/train_cnn.py b/examples/cnn_ms/train_cnn.py
index 445301e..d7f8f70 100644
--- a/examples/cnn_ms/train_cnn.py
+++ b/examples/cnn_ms/train_cnn.py
@@ -414,11 +414,11 @@
synflow_flag = True
### step 1: all one input
# Copy the patch data into input tensors
- tx.copy_from_numpy(np.ones(x.shape))
+ tx.copy_from_numpy(np.ones(x.shape, dtype=np.float32))
ty.copy_from_numpy(y)
### step 2: all weights turned to positive (done)
### step 3: new loss (done)
- pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
+ pn_p_g_list, out, loss = model(tx, ty,dist_option, spars, synflow_flag)
### step 4: calculate the multiplication of weights
synflow_score = 0.0
for pn_p_g_item in pn_p_g_list:
@@ -430,13 +430,13 @@
# Copy the patch data into input tensors
tx.copy_from_numpy(x)
ty.copy_from_numpy(y)
- pn_p_g_list, out, loss = model(tx, ty, synflow_flag, dist_option, spars)
+ pn_p_g_list, out, loss = model(tx, ty, dist_option, spars, synflow_flag)
train_correct += accuracy(tensor.to_numpy(out), y)
train_loss += tensor.to_numpy(loss)[0]
# all params turned to positive
for pn_p_g_item in pn_p_g_list:
print ("absolute value parameter name: \n", pn_p_g_item[0])
- pn_p_g_item[1].data = tensor.abs(pn_p_g_item[1].data)
+ pn_p_g_item[1] = tensor.abs(pn_p_g_item[1]) # tensor variables
else: # normal train steps
# Copy the patch data into input tensors
tx.copy_from_numpy(x)
@@ -491,7 +491,7 @@
description='Training using the autograd and graph.')
parser.add_argument(
'model',
- choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'alexnet'],
+ choices=['cnn', 'resnet', 'xceptionnet', 'mlp', 'msmlp', 'alexnet'],
default='cnn')
parser.add_argument('data',
choices=['mnist', 'cifar10', 'cifar100'],