// blob: 6bea475d456d233b10718f0d5b392e4c0b0bf905 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.functionobjects;
import java.util.EnumMap;
import java.util.HashMap;
import java.util.Locale;
import org.apache.commons.math3.distribution.AbstractRealDistribution;
import org.apache.commons.math3.distribution.ChiSquaredDistribution;
import org.apache.commons.math3.distribution.ExponentialDistribution;
import org.apache.commons.math3.distribution.FDistribution;
import org.apache.commons.math3.distribution.NormalDistribution;
import org.apache.commons.math3.distribution.TDistribution;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.util.UtilFunctions;
/**
 * Function object for builtin functions that take a list of name=value parameters.
 * <p>
 * Instances are obtained exclusively through
 * {@link #getParameterizedBuiltinFnObject(String)} and
 * {@link #getParameterizedBuiltinFnObject(String, String)}; both constructors are
 * private, so this class can not be instantiated elsewhere.
 * <p>
 * Only the {@code CDF} and {@code INVCDF} builtins are evaluated by
 * {@link #execute(HashMap)}; for every other builtin code the object merely carries
 * the code in {@link #bFunc}.
 */
public class ParameterizedBuiltin extends ValueFunction
{
	private static final long serialVersionUID = -5966242955816522697L;

	/** Codes of all supported parameterized builtin functions. */
	public enum ParameterizedBuiltinCode {
		CDF, INVCDF, RMEMPTY, REPLACE, REXPAND, LOWER_TRI, UPPER_TRI,
		TRANSFORMAPPLY, TRANSFORMDECODE, PARAMSERV }

	/** Codes of the probability distributions usable with {@code CDF} and {@code INVCDF}. */
	public enum ProbabilityDistributionCode {
		INVALID, NORMAL, EXP, CHISQ, F, T }

	/** Builtin function represented by this object. */
	public ParameterizedBuiltinCode bFunc;
	/** Distribution for CDF/INVCDF objects; {@code INVALID} for all other builtins. */
	public ProbabilityDistributionCode distFunc;

	/** Lookup from builtin function name (e.g. "cdf") to its code. */
	static public HashMap<String, ParameterizedBuiltinCode> String2ParameterizedBuiltinCode;
	static {
		String2ParameterizedBuiltinCode = new HashMap<>();
		String2ParameterizedBuiltinCode.put( "cdf", ParameterizedBuiltinCode.CDF);
		String2ParameterizedBuiltinCode.put( "invcdf", ParameterizedBuiltinCode.INVCDF);
		String2ParameterizedBuiltinCode.put( "rmempty", ParameterizedBuiltinCode.RMEMPTY);
		String2ParameterizedBuiltinCode.put( "replace", ParameterizedBuiltinCode.REPLACE);
		String2ParameterizedBuiltinCode.put( "lowertri", ParameterizedBuiltinCode.LOWER_TRI);
		String2ParameterizedBuiltinCode.put( "uppertri", ParameterizedBuiltinCode.UPPER_TRI);
		String2ParameterizedBuiltinCode.put( "rexpand", ParameterizedBuiltinCode.REXPAND);
		String2ParameterizedBuiltinCode.put( "transformapply", ParameterizedBuiltinCode.TRANSFORMAPPLY);
		String2ParameterizedBuiltinCode.put( "transformdecode", ParameterizedBuiltinCode.TRANSFORMDECODE);
		String2ParameterizedBuiltinCode.put( "paramserv", ParameterizedBuiltinCode.PARAMSERV);
	}

	/** Lookup from lower-case distribution name (e.g. "normal") to its code. */
	static public HashMap<String, ProbabilityDistributionCode> String2DistCode;
	static {
		String2DistCode = new HashMap<>();
		String2DistCode.put("normal" , ProbabilityDistributionCode.NORMAL);
		String2DistCode.put("exp" , ProbabilityDistributionCode.EXP);
		String2DistCode.put("chisq" , ProbabilityDistributionCode.CHISQ);
		String2DistCode.put("f" , ProbabilityDistributionCode.F);
		String2DistCode.put("t" , ProbabilityDistributionCode.T);
	}

	// One shared object per (cdf | invcdf, distribution) pair. The objects are created
	// eagerly during class initialization, so no lazy-initialization checks are needed.
	private static final EnumMap<ProbabilityDistributionCode, ParameterizedBuiltin> CDF_OBJECTS =
		createDistributionObjects(ParameterizedBuiltinCode.CDF);
	private static final EnumMap<ProbabilityDistributionCode, ParameterizedBuiltin> INVCDF_OBJECTS =
		createDistributionObjects(ParameterizedBuiltinCode.INVCDF);

	private ParameterizedBuiltin(ParameterizedBuiltinCode bf) {
		bFunc = bf;
		distFunc = ProbabilityDistributionCode.INVALID;
	}

	private ParameterizedBuiltin(ParameterizedBuiltinCode bf, ProbabilityDistributionCode dist) {
		bFunc = bf;
		distFunc = dist;
	}

	/**
	 * Creates one function object per valid distribution for the given builtin code.
	 *
	 * @param bf builtin code (CDF or INVCDF) the objects are created for
	 * @return map from distribution code to its function object; INVALID is not contained
	 */
	private static EnumMap<ProbabilityDistributionCode, ParameterizedBuiltin> createDistributionObjects(ParameterizedBuiltinCode bf) {
		EnumMap<ProbabilityDistributionCode, ParameterizedBuiltin> ret =
			new EnumMap<>(ProbabilityDistributionCode.class);
		for( ProbabilityDistributionCode dist : ProbabilityDistributionCode.values() ) {
			if( dist != ProbabilityDistributionCode.INVALID )
				ret.put(dist, new ParameterizedBuiltin(bf, dist));
		}
		return ret;
	}

	/**
	 * Returns the function object for a builtin that does not need a distribution.
	 *
	 * @param str name of the builtin function, as registered in {@link #String2ParameterizedBuiltinCode}
	 * @return function object for the builtin
	 * @throws DMLRuntimeException if the name is unknown, or if it denotes cdf/invcdf
	 *         (which require a distribution name)
	 */
	public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str) {
		return getParameterizedBuiltinFnObject (str, null);
	}

	/**
	 * Returns the function object for the given builtin name.
	 * <p>
	 * For cdf and invcdf a shared object per distribution is returned; for all other
	 * builtins a new object is created on every call.
	 *
	 * @param str name of the builtin function, as registered in {@link #String2ParameterizedBuiltinCode}
	 * @param str2 distribution name (matched case-insensitively); only used for cdf and invcdf
	 * @return function object for the builtin
	 * @throws DMLRuntimeException if the builtin name is unknown, or if the distribution
	 *         name is missing or unknown for cdf/invcdf
	 */
	public static ParameterizedBuiltin getParameterizedBuiltinFnObject (String str, String str2) {
		ParameterizedBuiltinCode code = String2ParameterizedBuiltinCode.get(str);
		if( code == null ) {
			// switching on a null enum would otherwise fail with an uninformative NullPointerException
			throw new DMLRuntimeException("Invalid parameterized builtin code: " + str);
		}

		switch ( code )
		{
			case CDF:
				return getDistributionFnObject(CDF_OBJECTS, code, str2);
			case INVCDF:
				return getDistributionFnObject(INVCDF_OBJECTS, code, str2);
			case RMEMPTY:
			case REPLACE:
			case LOWER_TRI:
			case UPPER_TRI:
			case REXPAND:
			case TRANSFORMAPPLY:
			case TRANSFORMDECODE:
			case PARAMSERV:
				return new ParameterizedBuiltin(code);
			default:
				throw new DMLRuntimeException("Invalid parameterized builtin code: " + code);
		}
	}

	/**
	 * Resolves a distribution name to the shared cdf/invcdf function object.
	 *
	 * @param cache shared objects of the requested builtin, keyed by distribution
	 * @param code builtin code, used for error reporting only
	 * @param distName distribution name as given by the caller
	 * @return shared function object for the distribution
	 * @throws DMLRuntimeException if the distribution name is missing or unknown
	 */
	private static ParameterizedBuiltin getDistributionFnObject(
		EnumMap<ProbabilityDistributionCode, ParameterizedBuiltin> cache,
		ParameterizedBuiltinCode code, String distName)
	{
		if( distName == null )
			throw new DMLRuntimeException("Missing distribution name for parameterized builtin " + code + ".");

		// Locale.ROOT keeps the lookup independent of the default locale
		// (e.g., a Turkish locale lower-cases 'I' to a dotless 'i', breaking "CHISQ").
		ProbabilityDistributionCode dcode = String2DistCode.get(distName.toLowerCase(Locale.ROOT));
		ParameterizedBuiltin fn = (dcode != null) ? cache.get(dcode) : null;
		if( fn == null ) {
			throw new DMLRuntimeException("Invalid distribution code: "
				+ ((dcode != null) ? dcode.toString() : distName));
		}
		return fn;
	}

	/**
	 * Evaluates the cdf or inverse cdf represented by this object.
	 *
	 * @param params map of name=value parameters; "target" is required, the remaining
	 *        keys depend on the distribution (see {@code computeFromDistribution})
	 * @return cdf or inverse cdf value
	 * @throws DMLRuntimeException if this object is neither a CDF nor an INVCDF builtin,
	 *         or if its distribution is not supported
	 */
	@Override
	public double execute(HashMap<String,String> params) {
		switch(bFunc) {
			case CDF:
			case INVCDF:
				switch(distFunc) {
					case NORMAL:
					case EXP:
					case CHISQ:
					case F:
					case T:
						return computeFromDistribution(distFunc, params, (bFunc==ParameterizedBuiltinCode.INVCDF));
					default:
						throw new DMLRuntimeException("Unsupported distribution (" + distFunc + ").");
				}
			default:
				throw new DMLRuntimeException("ParameterizedBuiltin.execute(): Unknown operation: " + bFunc);
		}
	}

	/**
	 * Helper function to compute distribution-specific cdf (both lowertail and uppertail) and inverse cdf.
	 * <p>
	 * Parameters read from {@code params}:
	 * "target" (required; quantile for cdf, probability for inverse cdf),
	 * "lower.tail" (optional, default true; only used for cdf),
	 * "mean" and "sd" for NORMAL (defaults 0.0 and 1.0), "rate" for EXP (default 1.0),
	 * "df" for CHISQ and T (required), "df1" and "df2" for F (required).
	 *
	 * @param dcode probability distribution code
	 * @param params map of parameters
	 * @param inverse true if inverse
	 * @return cdf or inverse cdf
	 * @throws DMLRuntimeException if a required parameter is missing, a distribution
	 *         parameter is not positive, or the distribution code is invalid
	 */
	private static double computeFromDistribution (ProbabilityDistributionCode dcode, HashMap<String,String> params, boolean inverse ) {
		// given value is "quantile" when inverse=false, and it is "probability" when inverse=true
		String target = params.get("target");
		if( target == null ) {
			// Double.parseDouble(null) would otherwise fail with an uninformative NullPointerException
			throw new DMLRuntimeException("Missing required parameter 'target' for distribution " + dcode + ".");
		}
		double val = Double.parseDouble(target);

		String lowertail_s = params.get("lower.tail");
		boolean lowertail = (lowertail_s == null) || Boolean.parseBoolean(lowertail_s);

		AbstractRealDistribution distFunction = null;
		switch(dcode) {
			case NORMAL: {
				String mean_s = params.get("mean"), sd_s = params.get("sd");
				double mean = (mean_s != null) ? Double.parseDouble(mean_s) : 0.0; // default mean
				double sd = (sd_s != null) ? Double.parseDouble(sd_s) : 1.0; // default sd
				if ( sd <= 0 )
					throw new DMLRuntimeException("Standard deviation for Normal distribution must be positive (" + sd + ")");
				distFunction = new NormalDistribution(mean, sd);
				break;
			}
			case EXP: {
				String rate_s = params.get("rate");
				double exp_rate = (rate_s != null) ? Double.parseDouble(rate_s) : 1.0; // default rate (1/mean)
				if ( exp_rate <= 0 ) {
					throw new DMLRuntimeException("Rate for Exponential distribution must be positive (" + exp_rate + ")");
				}
				// For exponential distribution: mean = 1/rate
				distFunction = new ExponentialDistribution(1.0/exp_rate);
				break;
			}
			case CHISQ: {
				String df_s = params.get("df");
				if ( df_s == null ) {
					throw new DMLRuntimeException("" +
						"Degrees of freedom must be specified for chi-squared distribution " +
						"(e.g., q=qchisq(0.5, df=20); p=pchisq(target=q, df=1.2))");
				}
				// NOTE(review): df is converted to an int here although the example in the
				// message above uses df=1.2 -- confirm whether fractional df should be supported.
				int df = UtilFunctions.parseToInt(df_s);
				if ( df <= 0 ) {
					throw new DMLRuntimeException("Degrees of Freedom for chi-squared distribution must be positive (" + df + ")");
				}
				distFunction = new ChiSquaredDistribution(df);
				break;
			}
			case F: {
				String df1_s = params.get("df1"), df2_s = params.get("df2");
				if ( df1_s == null || df2_s == null ) {
					throw new DMLRuntimeException("" +
						"Degrees of freedom must be specified for F distribution " +
						"(e.g., q = qf(target=0.5, df1=20, df2=30); p=pf(target=q, df1=20, df2=30))");
				}
				int df1 = UtilFunctions.parseToInt(df1_s);
				int df2 = UtilFunctions.parseToInt(df2_s);
				if ( df1 <= 0 || df2 <= 0) {
					throw new DMLRuntimeException("Degrees of Freedom for F distribution must be positive (" + df1 + "," + df2 + ")");
				}
				distFunction = new FDistribution(df1, df2);
				break;
			}
			case T: {
				String df_s = params.get("df");
				if ( df_s == null ) {
					throw new DMLRuntimeException("" +
						"Degrees of freedom is needed to compute probabilities from t distribution " +
						"(e.g., q = qt(target=0.5, df=10); p = pt(target=q, df=10))");
				}
				int t_df = UtilFunctions.parseToInt(df_s);
				if ( t_df <= 0 ) {
					throw new DMLRuntimeException("Degrees of Freedom for t distribution must be positive (" + t_df + ")");
				}
				distFunction = new TDistribution(t_df);
				break;
			}
			default:
				throw new DMLRuntimeException("Invalid distribution code: " + dcode);
		}

		if(inverse) {
			// inverse cdf (the "lower.tail" parameter is not considered here)
			return distFunction.inverseCumulativeProbability(val);
		}
		if(lowertail) {
			// cdf (lowertail)
			return distFunction.cumulativeProbability(val);
		}
		// cdf (upper tail)
		// TODO: more accurate distribution-specific computation of upper tail probabilities
		return 1.0 - distFunction.cumulativeProbability(val);
	}
}