HORN-9: Add activations package and ReLU function
Rename activations to funcs
diff --git a/src/main/java/org/apache/horn/funcs/CrossEntropy.java b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
new file mode 100644
index 0000000..567db29
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/CrossEntropy.java
@@ -0,0 +1,60 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * The cross entropy cost function.
+ *
+ * <pre>
+ * cost(t, y) = - t * log(y) - (1 - t) * log(1 - y),
+ * where t denotes the target value, y denotes the estimated value.
+ * </pre>
+ */
+public class CrossEntropy extends DoubleDoubleFunction {
+
+ @Override
+ public double apply(double target, double actual) {
+ double adjustedTarget = (target == 0 ? 0.000001 : target);
+ adjustedTarget = (target == 1.0 ? 0.999999 : target);
+ double adjustedActual = (actual == 0 ? 0.000001 : actual);
+ adjustedActual = (actual == 1 ? 0.999999 : actual);
+ return -adjustedTarget * Math.log(adjustedActual) - (1 - adjustedTarget)
+ * Math.log(1 - adjustedActual);
+ }
+
+ @Override
+ public double applyDerivative(double target, double actual) {
+ double adjustedTarget = target;
+ double adjustedActual = actual;
+ if (adjustedActual == 1) {
+ adjustedActual = 0.999;
+ } else if (actual == 0) {
+ adjustedActual = 0.001;
+ }
+ if (adjustedTarget == 1) {
+ adjustedTarget = 0.999;
+ } else if (adjustedTarget == 0) {
+ adjustedTarget = 0.001;
+ }
+ return -adjustedTarget / adjustedActual + (1 - adjustedTarget)
+ / (1 - adjustedActual);
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Identity.java b/src/main/java/org/apache/horn/funcs/Identity.java
new file mode 100644
index 0000000..d8c8380
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Identity.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The identity function f(x) = x.
+ *
+ */
+public class Identity extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return value;
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return 1;
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/ReLU.java b/src/main/java/org/apache/horn/funcs/ReLU.java
new file mode 100644
index 0000000..425137f
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/ReLU.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The rectifier function
+ *
+ * <pre>
+ * f(x) = max(0, x)
+ * </pre>
+ */
+public class ReLU extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return Math.max(0, value);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return (value > Double.MIN_VALUE) ? 1 : 0;
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Sigmoid.java b/src/main/java/org/apache/horn/funcs/Sigmoid.java
new file mode 100644
index 0000000..4472b8a
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Sigmoid.java
@@ -0,0 +1,41 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * The Sigmoid function
+ *
+ * <pre>
+ * f(x) = 1 / (1 + e^{-x})
+ * </pre>
+ */
+public class Sigmoid extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return 1.0 / (1 + Math.exp(-value));
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return value * (1 - value);
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/SquaredError.java b/src/main/java/org/apache/horn/funcs/SquaredError.java
new file mode 100644
index 0000000..081c53d
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/SquaredError.java
@@ -0,0 +1,48 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleDoubleFunction;
+
+/**
+ * Square error cost function.
+ *
+ * <pre>
+ * cost(t, y) = 0.5 * (t - y) ˆ 2
+ * </pre>
+ */
+public class SquaredError extends DoubleDoubleFunction {
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public double apply(double target, double actual) {
+ double diff = target - actual;
+ return 0.5 * diff * diff;
+ }
+
+ @Override
+ /**
+ * {@inheritDoc}
+ */
+ public double applyDerivative(double target, double actual) {
+ return actual - target;
+ }
+
+}
diff --git a/src/main/java/org/apache/horn/funcs/Tanh.java b/src/main/java/org/apache/horn/funcs/Tanh.java
new file mode 100644
index 0000000..c7ced33
--- /dev/null
+++ b/src/main/java/org/apache/horn/funcs/Tanh.java
@@ -0,0 +1,38 @@
+/**
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+package org.apache.horn.funcs;
+
+import org.apache.hama.commons.math.DoubleFunction;
+
+/**
+ * Tanh function.
+ *
+ */
+public class Tanh extends DoubleFunction {
+
+ @Override
+ public double apply(double value) {
+ return Math.tanh(value);
+ }
+
+ @Override
+ public double applyDerivative(double value) {
+ return 1 - Math.pow(Math.tanh(value), 2);
+ }
+
+}