blob: 6270f27055d95b7030892770f031f6ffd22ca415 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.matrix.operators;
import org.apache.sysds.lops.WeightedCrossEntropy.WCeMMType;
import org.apache.sysds.lops.WeightedDivMM.WDivMMType;
import org.apache.sysds.lops.WeightedSigmoid.WSigmoidType;
import org.apache.sysds.lops.WeightedSquaredLoss.WeightsType;
import org.apache.sysds.lops.WeightedUnaryMM.WUMMType;
import org.apache.sysds.runtime.functionobjects.Builtin;
import org.apache.sysds.runtime.functionobjects.Multiply2;
import org.apache.sysds.runtime.functionobjects.Power2;
import org.apache.sysds.runtime.functionobjects.ValueFunction;
/**
 * Operator descriptor for the weighted quaternary operations wsloss, wsigmoid,
 * wdivmm, wcemm and wumm.
 *
 * There is one public constructor per operation. Each fills only the matching
 * {@code wtypeN} slot and leaves the other slots {@code null}. {@link #fn} is
 * set only for wsigmoid (the "sigmoid" builtin) and wumm (the resolved unary
 * function). The epsilon value is non-zero only when the wdivmm/wcemm epsilon
 * overloads are used.
 */
public class QuaternaryOperator extends Operator
{
	private static final long serialVersionUID = -1642908613016116069L;

	/** Weights type; populated only by the wsloss constructor. */
	public final WeightsType wtype1;
	/** Sigmoid variant; populated only by the wsigmoid constructor. */
	public final WSigmoidType wtype2;
	/** WDivMM variant; populated only by the wdivmm constructors. */
	public final WDivMMType wtype3;
	/** WCeMM variant; populated only by the wcemm constructors. */
	public final WCeMMType wtype4;
	/** WUMM variant; populated only by the wumm constructor. */
	public final WUMMType wtype5;
	/** Element-wise function for wsigmoid/wumm; {@code null} for the other operations. */
	public final ValueFunction fn;
	// NOTE(review): the class declares a serialVersionUID, so it is presumably
	// Serializable; keep this field's name stable to preserve the serialized form.
	private final double eps;

	/**
	 * Shared initializer used by every public constructor.
	 */
	private QuaternaryOperator( WeightsType lossType, WSigmoidType sigmoidType, WDivMMType divType,
		WCeMMType ceType, WUMMType unaryType, ValueFunction valueFn, double epsilon )
	{
		this.wtype1 = lossType;
		this.wtype2 = sigmoidType;
		this.wtype3 = divType;
		this.wtype4 = ceType;
		this.wtype5 = unaryType;
		this.fn = valueFn;
		this.eps = epsilon;
	}

	/**
	 * Creates a wsloss operator.
	 *
	 * @param wt weights type
	 */
	public QuaternaryOperator( WeightsType wt ) {
		this(wt, null, null, null, null, null, 0);
	}

	/**
	 * Creates a wsigmoid operator whose value function is the "sigmoid" builtin.
	 *
	 * @param wt sigmoid variant
	 */
	public QuaternaryOperator( WSigmoidType wt ) {
		this(null, wt, null, null, null, Builtin.getBuiltinFnObject("sigmoid"), 0);
	}

	/**
	 * Creates a wdivmm operator with an epsilon of 0.
	 *
	 * @param wt WDivMM variant
	 */
	public QuaternaryOperator( WDivMMType wt ) {
		this(wt, 0d);
	}

	/**
	 * Creates a wdivmm operator carrying the given epsilon.
	 *
	 * @param wt WDivMM variant
	 * @param epsilon epsilon value returned by {@link #getScalar()}
	 */
	public QuaternaryOperator( WDivMMType wt, double epsilon ) {
		this(null, null, wt, null, null, null, epsilon);
	}

	/**
	 * Creates a wcemm operator with an epsilon of 0.
	 *
	 * @param wt WCeMM variant
	 */
	public QuaternaryOperator( WCeMMType wt ) {
		this(wt, 0d);
	}

	/**
	 * Creates a wcemm operator carrying the given epsilon.
	 *
	 * @param wt WCeMM variant
	 * @param epsilon epsilon value returned by {@link #getScalar()}
	 */
	public QuaternaryOperator( WCeMMType wt, double epsilon ) {
		this(null, null, null, wt, null, null, epsilon);
	}

	/**
	 * Creates a wumm operator.
	 *
	 * @param wt WUMM variant
	 * @param op operator code: "^2" and "*2" map to the Power2/Multiply2
	 *           function objects, any other code is looked up as a builtin;
	 *           must not be {@code null} (a null code raises a NullPointerException)
	 */
	public QuaternaryOperator( WUMMType wt, String op ) {
		this(null, null, null, null, wt, resolveUnaryFunction(op), 0);
	}

	/** Translates a wumm operator code into the value function applied element-wise. */
	private static ValueFunction resolveUnaryFunction( String op ) {
		switch( op ) {
			case "^2":
				return Power2.getPower2FnObject();
			case "*2":
				return Multiply2.getMultiply2FnObject();
			default:
				return Builtin.getBuiltinFnObject(op);
		}
	}

	/**
	 * Reports whether the configured operation variant takes four inputs.
	 * Only the wsloss, wdivmm and wcemm slots are consulted.
	 *
	 * @return true if any populated wsloss/wdivmm/wcemm type reports four inputs
	 */
	public boolean hasFourInputs() {
		if( wtype1 != null && wtype1.hasFourInputs() )
			return true;
		if( wtype3 != null && wtype3.hasFourInputs() )
			return true;
		return wtype4 != null && wtype4.hasFourInputs();
	}

	/**
	 * Returns the epsilon value this operator was built with (0 unless one of the
	 * epsilon overloads was used).
	 *
	 * @return epsilon
	 */
	public double getScalar() {
		return this.eps;
	}
}