Merge branch 'HORN-7' of https://github.com/edwardyoon/incubator-horn This closes #11
diff --git a/README.md b/README.md
index cf96ca5..fd3ca9b 100644
--- a/README.md
+++ b/README.md
@@ -23,3 +23,4 @@
## Getting Involved
Horn is an open source volunteer project under the Apache Software Foundation. We encourage you to learn about the project and contribute your expertise.
+
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMerger.java b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
new file mode 100644
index 0000000..709331b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMerger.java
@@ -0,0 +1,42 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.bsp;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.apache.hama.ipc.VersionedProtocol;
+
+/**
+ * RPC protocol through which a slave worker sends its locally computed weight
+ * updates to the master worker and receives the merged model in return.
+ */
+public interface ParameterMerger extends VersionedProtocol {
+  /** Protocol version; clients pass it to {@code RPC.getProxy}. */
+  long versionID = 1L;
+
+  /**
+   * Merges one slave's updates into the master's model.
+   *
+   * @param trainingError the training error reported by the slave
+   * @param weightUpdates the slave's weight updates, one entry per weight matrix
+   * @param prevWeightUpdates the slave's previous weight updates; must have the
+   *          same length as {@code weightUpdates}
+   * @return the master's weights, previous weight updates and convergence flag
+   */
+  SmallLayeredNeuralNetworkMessage merge(double trainingError,
+      DoubleMatrix[] weightUpdates, DoubleMatrix[] prevWeightUpdates);
+}
diff --git a/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
new file mode 100644
index 0000000..54caf2b
--- /dev/null
+++ b/src/main/java/org/apache/horn/bsp/ParameterMergerServer.java
@@ -0,0 +1,160 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.bsp;
+
+import com.google.common.base.Preconditions;
+
+import org.apache.hama.commons.math.DoubleMatrix;
+import org.mortbay.log.Log;
+
+import java.io.IOException;
+import java.util.concurrent.atomic.AtomicBoolean;
+
+/**
+ * {@link ParameterMerger} implementation hosted by the master worker.
+ *
+ * <p>Every {@link #merge} call scales one slave's updates by {@code 1 / slaveCount},
+ * applies them to the shared in-memory model and tracks the reported training
+ * errors to decide when training has converged. Model updates and the
+ * convergence bookkeeping run under the monitor of {@code inMemoryModel}.
+ */
+public class ParameterMergerServer implements ParameterMerger {
+  /* The parameter merge base. */
+  protected SmallLayeredNeuralNetwork inMemoryModel;
+
+  /* To terminate or not to terminate; shared with the trainer that created this server. */
+  protected AtomicBoolean isConverge;
+
+  /* The number of slave workers that request commits. */
+  protected int SlaveCount;
+
+  /* After mergeLimit merges, terminate whether the result is converging or not. */
+  protected int mergeLimit;
+
+  /* Last n training errors; convergence is decided on their average. */
+  protected double[] trainingErrors;
+
+  /* Average of the previous full window of training errors. If the next window's
+   * average is larger, training is considered converged. */
+  protected double prevAvgTrainingError = Double.MAX_VALUE;
+
+  /* Next write index into trainingErrors. */
+  protected int curTrainingError = 0;
+
+  /* How many merges have been applied. */
+  protected int mergeCount = 0;
+
+  /**
+   * Creates a merger over {@code inMemoryModel}.
+   *
+   * @param inMemoryModel the model that all slave updates are merged into
+   * @param isConverge flag that is set to true once training should stop
+   * @param slaveCount number of slave workers; every update is divided by it
+   * @param mergeLimit number of merges after which training stops regardless of convergence
+   * @param convergenceCheckInterval number of training errors averaged per convergence check
+   * @throws IllegalArgumentException if {@code slaveCount} or
+   *           {@code convergenceCheckInterval} is not positive
+   */
+  public ParameterMergerServer(SmallLayeredNeuralNetwork inMemoryModel, AtomicBoolean isConverge,
+      int slaveCount, int mergeLimit, int convergenceCheckInterval) {
+    Preconditions.checkArgument(slaveCount > 0, "slaveCount must be positive: %s", slaveCount);
+    Preconditions.checkArgument(convergenceCheckInterval > 0,
+        "convergenceCheckInterval must be positive: %s", convergenceCheckInterval);
+    this.inMemoryModel = Preconditions.checkNotNull(inMemoryModel, "inMemoryModel");
+    this.isConverge = Preconditions.checkNotNull(isConverge, "isConverge");
+    this.SlaveCount = slaveCount;
+    this.mergeLimit = mergeLimit;
+    this.trainingErrors = new double[convergenceCheckInterval];
+  }
+
+  @Override
+  public long getProtocolVersion(String s, long l) throws IOException {
+    return ParameterMerger.versionID;
+  }
+
+  /**
+   * Applies one slave's updates to the shared model (unless training has already
+   * converged) and returns the resulting model state.
+   *
+   * <p>Note: the elements of {@code weightUpdates} and {@code prevWeightUpdates}
+   * are replaced in place by their scaled copies.
+   */
+  @Override
+  public SmallLayeredNeuralNetworkMessage merge(double trainingError, DoubleMatrix[] weightUpdates,
+      DoubleMatrix[] prevWeightUpdates) {
+    Preconditions.checkArgument(weightUpdates.length == prevWeightUpdates.length);
+
+    Log.info(String.format("Start merging: %d.\n", this.mergeCount));
+
+    if (!this.isConverge.get()) {
+      // Scale this slave's contribution outside the lock; the arrays are this call's arguments.
+      for (int i = 0; i < weightUpdates.length; ++i) {
+        weightUpdates[i] = weightUpdates[i].divide(this.SlaveCount);
+        // NOTE(review): prev updates are *set* (not accumulated) below, so dividing them
+        // by the slave count shrinks them; confirm this is the intended momentum term.
+        prevWeightUpdates[i] = prevWeightUpdates[i].divide(this.SlaveCount);
+      }
+    }
+
+    synchronized (inMemoryModel) {
+      // Re-check under the lock: another handler thread may have declared convergence
+      // while this one was waiting, in which case the model must not change any more.
+      if (!this.isConverge.get()) {
+        this.inMemoryModel.updateWeightMatrices(weightUpdates);
+        this.inMemoryModel.setPrevWeightMatrices(prevWeightUpdates);
+        recordTrainingError(trainingError);
+        if (++this.mergeCount == this.mergeLimit) {
+          this.isConverge.set(true);
+        }
+      }
+
+      // Build the reply while still holding the lock so weights, previous updates and
+      // the convergence flag all describe the same merge step.
+      return new SmallLayeredNeuralNetworkMessage(
+          0, this.isConverge.get(), this.inMemoryModel.getWeightMatrices(),
+          this.inMemoryModel.getPrevMatricesUpdates());
+    }
+  }
+
+  /**
+   * Adds {@code trainingError} to the current window and, once the window is full,
+   * marks training as converged if its average is larger than the previous window's.
+   * Must be called while holding the monitor of {@code inMemoryModel}.
+   */
+  private void recordTrainingError(double trainingError) {
+    this.trainingErrors[this.curTrainingError++] = trainingError;
+    if (this.curTrainingError < this.trainingErrors.length) {
+      return;
+    }
+
+    double curAvgTrainingError = 0.0;
+    for (double error : this.trainingErrors) {
+      curAvgTrainingError += error;
+    }
+    curAvgTrainingError /= this.trainingErrors.length;
+
+    if (this.prevAvgTrainingError < curAvgTrainingError) {
+      // The error stopped decreasing.
+      this.isConverge.set(true);
+    } else {
+      this.prevAvgTrainingError = curAvgTrainingError;
+    }
+    // Always start a new window so the write index can never run past the array.
+    this.curTrainingError = 0;
+  }
+}
diff --git a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
index 132ec8c..9e3d02f 100644
--- a/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
+++ b/src/main/java/org/apache/horn/bsp/SmallLayeredNeuralNetworkTrainer.java
@@ -17,7 +17,7 @@
*/
package org.apache.horn.bsp;
-import java.io.IOException;
+import com.google.common.base.Preconditions;
import org.apache.hadoop.conf.Configuration;
import org.apache.hadoop.io.LongWritable;
@@ -29,8 +29,13 @@
import org.apache.hama.commons.math.DenseDoubleMatrix;
import org.apache.hama.commons.math.DoubleMatrix;
import org.apache.hama.commons.math.DoubleVector;
+import org.apache.hama.ipc.RPC;
import org.mortbay.log.Log;
+import java.io.IOException;
+import java.net.InetSocketAddress;
+import java.util.concurrent.atomic.AtomicBoolean;
+
/**
* The trainer that train the {@link SmallLayeredNeuralNetwork} based on BSP
* framework.
@@ -39,21 +44,36 @@
public final class SmallLayeredNeuralNetworkTrainer
extends
BSP<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> {
-
+ /* When given peer is master worker: base of parameter merge */
+ /* When given peer is slave worker: neural network for training */
private SmallLayeredNeuralNetwork inMemoryModel;
+
+ /* Job configuration */
private Configuration conf;
+
/* Default batch size */
private int batchSize;
- /* check the interval between intervals */
- private double prevAvgTrainingError;
- private double curAvgTrainingError;
- private long convergenceCheckInterval;
- private long iterations;
- private long maxIterations;
- private boolean isConverge;
+ /* whether it is converging or not */
+ private AtomicBoolean isConverge;
- private String modelPath;
+ /* When given peer is master worker: Asynchronous parameter merger */
+ /* When given peer is slave worker: null */
+ private RPC.Server merger;
+
+ /* When given peer is master worker: null */
+ /* When given peer is slave worker: proxy to Asynchronous parameter merger */
+ private ParameterMerger proxy;
+
+ /**
+ * Returns true if this worker is the master worker, i.e. the peer with index 0.
+ *
+ * @param peer the BSP peer to check
+ */
+ private boolean isMaster(
+ BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
+ return peer.getPeerIndex() == 0;
+ }
@Override
/**
@@ -61,20 +81,40 @@
    */
   public void setup(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
-    if (peer.getPeerIndex() == 0) {
-      Log.info("Begin to train");
-    }
-    this.isConverge = false;
-    this.conf = peer.getConfiguration();
-    this.iterations = 0;
-    this.modelPath = conf.get("modelPath");
-    this.maxIterations = conf.getLong("training.max.iterations", 100000);
-    this.convergenceCheckInterval = conf.getLong("convergence.check.interval",
-        2000);
-    this.modelPath = conf.get("modelPath");
+    // At least one master & slave worker exist.
+    Preconditions.checkArgument(peer.getNumPeers() >= 2, "need a master and at least one slave");
+
+    String modelPath = peer.getConfiguration().get("modelPath"); // this.conf is not assigned yet
     this.inMemoryModel = new SmallLayeredNeuralNetwork(modelPath);
-    this.prevAvgTrainingError = Integer.MAX_VALUE;
+    this.conf = peer.getConfiguration();
     this.batchSize = conf.getInt("training.batch.size", 50);
+    this.isConverge = new AtomicBoolean(false);
+
+    int slaveCount = peer.getNumPeers() - 1;
+    int mergeLimit = conf.getInt("training.max.iterations", 100000);
+    int convergenceCheckInterval = peer.getNumPeers() * conf.getInt("convergence.check.interval",
+        2000);
+    String master = peer.getPeerName(0); // every worker must resolve the master (peer 0), not itself
+    String masterAddr = master.substring(0, master.indexOf(':'));
+    int port = conf.getInt("sync.server.port", 40042);
+
+    if (isMaster(peer)) {
+      try {
+        this.merger = RPC.getServer(new ParameterMergerServer(inMemoryModel, isConverge, slaveCount,
+            mergeLimit, convergenceCheckInterval), masterAddr, port, conf);
+        merger.start();
+      } catch (IOException e) {
+        throw new IllegalStateException("Failed to start the parameter merger server", e);
+      }
+      Log.info("Begin to train");
+    } else {
+      InetSocketAddress addr = new InetSocketAddress(masterAddr, port);
+      try {
+        this.proxy = (ParameterMerger) RPC.getProxy(ParameterMerger.class, ParameterMerger.versionID, addr, conf);
+      } catch (IOException e) {
+        throw new IllegalStateException("Failed to connect to the parameter merger at " + addr, e);
+      }
+    }
   }
@Override
@@ -84,10 +124,8 @@
public void cleanup(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer) {
// write model to modelPath
- if (peer.getPeerIndex() == 0) {
+ if (isMaster(peer)) {
try {
- Log.info(String.format("End of training, number of iterations: %d.\n",
- this.iterations));
Log.info(String.format("Write model back to %s\n",
inMemoryModel.getModelPath()));
this.inMemoryModel.writeModelToFile();
@@ -101,18 +139,17 @@
   public void bsp(
       BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
       throws IOException, SyncException, InterruptedException {
-    while (this.iterations++ < maxIterations) {
-      // each groom calculate the matrices updates according to local data
-      calculateUpdates(peer);
-      peer.sync();
-
-      // master merge the updates model
-      if (peer.getPeerIndex() == 0) {
-        mergeUpdates(peer);
-      }
-      peer.sync();
-      if (this.isConverge) {
-        break;
+    if (!isMaster(peer)) {
+      while (!this.isConverge.get()) {
+        // each slave-worker calculate the matrices updates according to local data
+        // and merge them with master
+        calculateUpdates(peer);
       }
     }
+    // Barrier: the master waits here until every slave has seen convergence, so its
+    // cleanup() writes the trained model and its merge server outlives all merge calls.
+    peer.sync();
+    if (isMaster(peer) && this.merger != null) {
+      this.merger.stop();
+    }
   }
@@ -126,20 +157,6 @@
private void calculateUpdates(
BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
throws IOException {
- // receive update information from master
- if (peer.getNumCurrentMessages() != 0) {
- SmallLayeredNeuralNetworkMessage inMessage = peer.getCurrentMessage();
- DoubleMatrix[] newWeights = inMessage.getCurMatrices();
- DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
- this.inMemoryModel.setWeightMatrices(newWeights);
- this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
- this.isConverge = inMessage.isConverge();
- // check converge
- if (isConverge) {
- return;
- }
- }
-
DoubleMatrix[] weightUpdates = new DoubleMatrix[this.inMemoryModel.weightMatrixList
.size()];
for (int i = 0; i < weightUpdates.length; ++i) {
@@ -169,76 +186,14 @@
weightUpdates[i] = weightUpdates[i].divide(batchSize);
}
- DoubleMatrix[] prevWeightUpdates = this.inMemoryModel
- .getPrevMatricesUpdates();
- SmallLayeredNeuralNetworkMessage outMessage = new SmallLayeredNeuralNetworkMessage(
- avgTrainingError, false, weightUpdates, prevWeightUpdates);
- peer.send(peer.getPeerName(0), outMessage);
- }
-
- /**
- * Merge the updates according to the updates of the grooms.
- *
- * @param peer
- * @throws IOException
- */
- private void mergeUpdates(
- BSPPeer<LongWritable, VectorWritable, NullWritable, NullWritable, SmallLayeredNeuralNetworkMessage> peer)
- throws IOException {
- int numMessages = peer.getNumCurrentMessages();
- boolean isConverge = false;
- if (numMessages == 0) { // converges
- isConverge = true;
- return;
- }
-
- double avgTrainingError = 0;
- DoubleMatrix[] matricesUpdates = null;
- DoubleMatrix[] prevMatricesUpdates = null;
-
- while (peer.getNumCurrentMessages() > 0) {
- SmallLayeredNeuralNetworkMessage message = peer.getCurrentMessage();
- if (matricesUpdates == null) {
- matricesUpdates = message.getCurMatrices();
- prevMatricesUpdates = message.getPrevMatrices();
- } else {
- SmallLayeredNeuralNetwork.matricesAdd(matricesUpdates,
- message.getCurMatrices());
- SmallLayeredNeuralNetwork.matricesAdd(prevMatricesUpdates,
- message.getPrevMatrices());
- }
- avgTrainingError += message.getTrainingError();
- }
-
- if (numMessages != 1) {
- avgTrainingError /= numMessages;
- for (int i = 0; i < matricesUpdates.length; ++i) {
- matricesUpdates[i] = matricesUpdates[i].divide(numMessages);
- prevMatricesUpdates[i] = prevMatricesUpdates[i].divide(numMessages);
- }
- }
- this.inMemoryModel.updateWeightMatrices(matricesUpdates);
- this.inMemoryModel.setPrevWeightMatrices(prevMatricesUpdates);
-
- // check convergence
- if (iterations % convergenceCheckInterval == 0) {
- if (prevAvgTrainingError < curAvgTrainingError) {
- // error cannot decrease any more
- isConverge = true;
- }
- // update
- prevAvgTrainingError = curAvgTrainingError;
- curAvgTrainingError = 0;
- }
- curAvgTrainingError += avgTrainingError / convergenceCheckInterval;
-
- // broadcast updated weight matrices
- for (String peerName : peer.getAllPeerNames()) {
- SmallLayeredNeuralNetworkMessage msg = new SmallLayeredNeuralNetworkMessage(
- 0, isConverge, this.inMemoryModel.getWeightMatrices(),
- this.inMemoryModel.getPrevMatricesUpdates());
- peer.send(peerName, msg);
- }
+    // exchange parameter update with master; the 3rd argument is the previous *updates*
+    SmallLayeredNeuralNetworkMessage inMessage = proxy.merge(avgTrainingError, weightUpdates,
+        this.inMemoryModel.getPrevMatricesUpdates());
+    DoubleMatrix[] newWeights = inMessage.getCurMatrices();
+    DoubleMatrix[] preWeightUpdates = inMessage.getPrevMatrices();
+    this.inMemoryModel.setWeightMatrices(newWeights);
+    this.inMemoryModel.setPrevWeightMatrices(preWeightUpdates);
+    this.isConverge.set(inMessage.isConverge());
   }
}
diff --git a/src/main/java/org/apache/horn/funcs/CrossEntropy.java b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
new file mode 100644
index 0000000..567db29
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
@@ -0,0 +1,69 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * The cross entropy cost function.
+ *
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ *
+ * <p>Targets and estimates that are exactly 0 or 1 are nudged into the open interval
+ * (0, 1), so inputs in [0, 1] never reach log(0) or a division by zero.
+ */
+public class CrossEntropy extends DoubleDoubleFunction {
+
+  @Override
+  public double apply(double target, double actual) {
+    // Clamp exact 0 and exact 1 in one step; two overwriting assignments would drop the
+    // 0-clamp and let a 0 input reach Math.log(0), producing NaN or Infinity.
+    double adjustedTarget = adjust(target, 0.000001, 0.999999);
+    double adjustedActual = adjust(actual, 0.000001, 0.999999);
+    return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget)
+        * Math.log(1 - adjustedActual);
+  }
+
+  @Override
+  public double applyDerivative(double target, double actual) {
+    // NOTE(review): the derivative clamps to 0.001/0.999 while apply() uses
+    // 0.000001/0.999999; confirm whether the two should agree.
+    double adjustedTarget = adjust(target, 0.001, 0.999);
+    double adjustedActual = adjust(actual, 0.001, 0.999);
+    return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+        / (1 - adjustedActual);
+  }
+
+  /**
+   * Replaces an exact 0 by {@code low} and an exact 1 by {@code high}; any other
+   * value is returned unchanged.
+   */
+  private static double adjust(double value, double low, double high) {
+    if (value == 0) {
+      return low;
+    }
+    if (value == 1) {
+      return high;
+    }
+    return value;
+  }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Identity.java b/src/main/java/org/apache/horn/funcs/Identity.java
new file mode 100644
index 0000000..d8c8380
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Identity.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The identity function {@code f(x) = x}, whose derivative is the constant 1.
+ */
+public class Identity extends DoubleFunction {
+
+  /** Returns {@code value} unchanged. */
+  @Override
+  public double apply(double value) {
+    return value;
+  }
+
+  /** Returns 1 for every input: the slope of {@code f(x) = x}. */
+  @Override
+  public double applyDerivative(double value) {
+    return 1.0;
+  }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/ReLU.java b/src/main/java/org/apache/horn/funcs/ReLU.java
new file mode 100644
index 0000000..425137f
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/ReLU.java
@@ -0,0 +1,45 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The rectifier function
+ *
+ * <pre>
+ * f(x) = max(0, x)
+ * </pre>
+ *
+ * <p>Its derivative is 1 for {@code x > 0} and 0 otherwise (including {@code x = 0}).
+ */
+public class ReLU extends DoubleFunction {
+
+  @Override
+  public double apply(double value) {
+    return Math.max(0, value);
+  }
+
+  @Override
+  public double applyDerivative(double value) {
+    // Compare against 0, not Double.MIN_VALUE: MIN_VALUE is the smallest *positive*
+    // double, so "> MIN_VALUE" treated the positive input MIN_VALUE as non-positive.
+    return (value > 0) ? 1 : 0;
+  }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Sigmoid.java b/src/main/java/org/apache/horn/funcs/Sigmoid.java
new file mode 100644
index 0000000..cc393e3
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Sigmoid.java
@@ -0,0 +1,46 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The logistic sigmoid function
+ *
+ * <pre>
+ * f(x) = 1 / (1 + e^{-x})
+ * </pre>
+ *
+ * <p>Its derivative is expressed through the function itself:
+ * {@code f'(x) = f(x) * (1 - f(x))}.
+ */
+public class Sigmoid extends DoubleFunction {
+
+  @Override
+  public double apply(double value) {
+    double denominator = 1.0 + Math.exp(-value);
+    return 1.0 / denominator;
+  }
+
+  @Override
+  public double applyDerivative(double value) {
+    double activation = apply(value);
+    return activation * (1 - activation);
+  }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/SquaredError.java b/src/main/java/org/apache/horn/funcs/SquaredError.java
new file mode 100644
index 0000000..081c53d
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/SquaredError.java
@@ -0,0 +1,51 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * Square error cost function.
+ *
+ * <pre>
+ * cost(t, y) = 0.5 * (t - y) ^ 2
+ * </pre>
+ *
+ * <p>where t denotes the target value and y the estimated value.
+ */
+public class SquaredError extends DoubleDoubleFunction {
+
+  /**
+   * Returns {@code 0.5 * (target - actual)^2}.
+   */
+  @Override
+  public double apply(double target, double actual) {
+    double diff = target - actual;
+    return 0.5 * diff * diff;
+  }
+
+  /**
+   * Returns the derivative of the cost with respect to the estimate,
+   * {@code actual - target}.
+   */
+  @Override
+  public double applyDerivative(double target, double actual) {
+    return actual - target;
+  }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Tanh.java b/src/main/java/org/apache/horn/funcs/Tanh.java
new file mode 100644
index 0000000..c7ced33
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Tanh.java
@@ -0,0 +1,39 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * Hyperbolic tangent activation, {@code f(x) = tanh(x)}, with derivative
+ * {@code f'(x) = 1 - tanh(x)^2}.
+ */
+public class Tanh extends DoubleFunction {
+
+  @Override
+  public double apply(double value) {
+    return Math.tanh(value);
+  }
+
+  @Override
+  public double applyDerivative(double value) {
+    double activation = Math.tanh(value);
+    return 1 - Math.pow(activation, 2);
+  }
+
+}