blob: d69770e047579464ef2ec6c56b6e7e56fac3c7cf [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package opennlp.tools.ml;
import java.util.List;
import opennlp.tools.ml.model.Context;
/**
* Utility class for simple vector arithmetic.
*/
public class ArrayMath {
public static double innerProduct(double[] vecA, double[] vecB) {
if (vecA == null || vecB == null || vecA.length != vecB.length)
return Double.NaN;
double product = 0.0;
for (int i = 0; i < vecA.length; i++) {
product += vecA[i] * vecB[i];
}
return product;
}
/**
* L1-norm
*/
public static double l1norm(double[] v) {
double norm = 0;
for (int i = 0; i < v.length; i++)
norm += StrictMath.abs(v[i]);
return norm;
}
/**
* L2-norm
*/
public static double l2norm(double[] v) {
return StrictMath.sqrt(innerProduct(v, v));
}
/**
* Inverse L2-norm
*/
public static double invL2norm(double[] v) {
return 1 / l2norm(v);
}
/**
* Computes \log(\sum_{i=1}^n e^{x_i}) using a maximum-element trick
* to avoid arithmetic overflow.
*
* @param x input vector
* @return log-sum of exponentials of vector elements
*/
public static double logSumOfExps(double[] x) {
double max = max(x);
double sum = 0.0;
for (int i = 0; i < x.length; i++) {
if (x[i] != Double.NEGATIVE_INFINITY)
sum += StrictMath.exp(x[i] - max);
}
return max + StrictMath.log(sum);
}
public static double max(double[] x) {
int maxIdx = argmax(x);
return x[maxIdx];
}
/**
* Find index of maximum element in the vector x
* @param x input vector
* @return index of the maximum element. Index of the first
* maximum element is returned if multiple maximums are found.
*/
public static int argmax(double[] x) {
if (x == null || x.length == 0) {
throw new IllegalArgumentException("Vector x is null or empty");
}
int maxIdx = 0;
for (int i = 1; i < x.length; i++) {
if (x[maxIdx] < x[i])
maxIdx = i;
}
return maxIdx;
}
public static void sumFeatures(Context[] context, float[] values, double[] prior) {
for (int ci = 0; ci < context.length; ci++) {
if (context[ci] != null) {
Context predParams = context[ci];
int[] activeOutcomes = predParams.getOutcomes();
double[] activeParameters = predParams.getParameters();
double value = 1;
if (values != null) {
value = values[ci];
}
for (int ai = 0; ai < activeOutcomes.length; ai++) {
int oid = activeOutcomes[ai];
prior[oid] += activeParameters[ai] * value;
}
}
}
}
// === Not really related to math ===
/**
* Convert a list of Double objects into an array of primitive doubles
*/
public static double[] toDoubleArray(List<Double> list) {
double[] arr = new double[list.size()];
for (int i = 0; i < arr.length; i++) {
arr[i] = list.get(i);
}
return arr;
}
/**
* Convert a list of Integer objects into an array of primitive integers
*/
public static int[] toIntArray(List<Integer> list) {
int[] arr = new int[list.size()];
for (int i = 0; i < arr.length; i++) {
arr[i] = list.get(i);
}
return arr;
}
}