Merge pull request #1170 from dcslin/feature/ms_model_mlp
Update MSMLP: make hidden layer widths configurable via a new `layer_hidden_list` parameter (default `[10,10,10,10]`) and deepen the network from 2 to 5 linear layers with ReLU activations between them.
diff --git a/examples/ms_model_mlp/model.py b/examples/ms_model_mlp/model.py
index b3fe116..454b382 100644
--- a/examples/ms_model_mlp/model.py
+++ b/examples/ms_model_mlp/model.py
@@ -83,21 +83,30 @@
class MSMLP(model.Model):
- def __init__(self, data_size=10, perceptron_size=100, num_classes=10):
+ def __init__(self, data_size=10, perceptron_size=100, num_classes=10, layer_hidden_list=[10,10,10,10]):
super(MSMLP, self).__init__()
self.num_classes = num_classes
self.dimension = 2
self.relu = layer.ReLU()
- self.linear1 = layer.Linear(perceptron_size)
- self.linear2 = layer.Linear(num_classes)
+ self.linear1 = layer.Linear(layer_hidden_list[0])
+ self.linear2 = layer.Linear(layer_hidden_list[1])
+ self.linear3 = layer.Linear(layer_hidden_list[2])
+ self.linear4 = layer.Linear(layer_hidden_list[3])
+ self.linear5 = layer.Linear(num_classes)
self.softmax_cross_entropy = layer.SoftMaxCrossEntropy()
self.sum_error = SumErrorLayer()
-
+
def forward(self, inputs):
y = self.linear1(inputs)
y = self.relu(y)
y = self.linear2(y)
+ y = self.relu(y)
+ y = self.linear3(y)
+ y = self.relu(y)
+ y = self.linear4(y)
+ y = self.relu(y)
+ y = self.linear5(y)
return y
def train_one_batch(self, x, y, dist_option, spars, synflow_flag):
@@ -144,6 +153,7 @@
def create_model(pretrained=False, **kwargs):
"""Constructs a CNN model.
+
Args:
pretrained (bool): If True, returns a pre-trained model.