blob: 4c47d597e92a55ca19e71856ad4a0bf909af6c48 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.runtime.functionobjects;
import org.apache.sysds.runtime.DMLRuntimeException;
import org.apache.sysds.runtime.instructions.cp.CM_COV_Object;
import org.apache.sysds.runtime.instructions.cp.Data;
import org.apache.sysds.runtime.instructions.cp.KahanObject;
import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes;
/**
* GENERAL NOTE:
* * 05/28/2014: We decided to handle weights consistently with SPSS in an operation-specific manner,
* i.e., we (1) round instead of casting where required (e.g. count), and (2) consistently use
* fractional weight values elsewhere. In case a count-based interpretation of weights is needed, just
* ensure rounding before calling CM/COV/KahanPlus.
*
*/
public class CM extends ValueFunction
{
	private static final long serialVersionUID = 9177194651533064123L;

	/** Aggregate computed by this object; execute supports COUNT, MEAN, CM2, CM3, CM4 and VARIANCE. */
	private AggregateOperationTypes _type = null;

	//helper function objects for specific types
	/** Kahan adder; allocated for MEAN, VARIANCE, CM2, CM3 and CM4. */
	private KahanPlus _plus = null;
	/**
	 * Scratch buffer for the new 2nd moment (CM2, CM3, CM4). The new value is built here
	 * because the old cm1.m2 is still read by the 3rd/4th moment update terms.
	 */
	private KahanObject _buff2 = null;
	/**
	 * Scratch buffer for the new 3rd moment (CM3, CM4). The new value is built here
	 * because the old cm1.m3 is still read by the 4th moment update term.
	 */
	private KahanObject _buff3 = null;

	private CM( AggregateOperationTypes type )
	{
		_type = type;
		//allocate helper objects on demand; the case fall-through is intentional:
		//higher moments need every helper of the lower moments as well
		switch( _type )
		{
			case COUNT:
				break;
			case CM4:
			case CM3:
				_buff3 = new KahanObject(0, 0);
				//fall through
			case CM2:
				_buff2 = new KahanObject(0, 0);
				//fall through
			case VARIANCE:
			case MEAN:
				_plus = KahanPlus.getKahanPlusFnObject();
				break;
			default:
				//do nothing
		}
	}

	/**
	 * Creates a new CM function object for the given aggregate type.
	 *
	 * @param type aggregate type to compute
	 * @return a fresh (non-shared) CM function object
	 */
	public static CM getCMFnObject( AggregateOperationTypes type ) {
		//return new obj, required for correctness in multi-threaded
		//execution due to state in cm object (buff2, buff3)
		return new CM( type );
	}

	/**
	 * @return the aggregate type this object computes
	 */
	public AggregateOperationTypes getAggOpType() {
		return _type;
	}

	/**
	 * Special case for weights w2==1: adds a single observation with unit weight
	 * to the running statistics.
	 *
	 * The running state is updated in place:
	 * <ul>
	 *   <li>w is the accumulated weight;</li>
	 *   <li>mean is the running mean;</li>
	 *   <li>m2/m3/m4 are the accumulated (unnormalized) 2nd/3rd/4th central-moment sums,
	 *       as implied by the update terms below.</li>
	 * </ul>
	 *
	 * @param in1 running statistics; must be a CM_COV_Object, which is mutated and returned
	 * @param in2 the new observation
	 * @return in1 after the update
	 * @throws DMLRuntimeException if the aggregate type is not supported
	 */
	@Override
	public Data execute(Data in1, double in2) {
		CM_COV_Object cm1=(CM_COV_Object) in1;
		if(cm1.isCMAllZeros()) {
			//first observation initializes the state
			cm1.w=1;
			cm1.mean.set(in2, 0);
			cm1.m2.set(0,0);
			cm1.m3.set(0,0);
			cm1.m4.set(0,0);
			return cm1;
		}
		switch( _type )
		{
			case COUNT:
			{
				cm1.w = cm1.w + 1;
				break;
			}
			case MEAN:
			{
				double w= cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w= cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double lt1=t1*d;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + 1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
				double f2=1.0/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				//m3 update reads the OLD m2, so m2 is only written back afterwards
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w=cm1.w+1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1.0-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1.0-Math.pow(t2, 3));
				double f2=1.0/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				//m4 update reads the OLD m2 and m3, so both are written back afterwards
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				double w=cm1.w+1;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, d/w);
				double t1=cm1.w/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		return cm1;
	}

	/**
	 * General case for arbitrary weights w2: adds a single observation with weight w2
	 * to the running statistics (updated in place).
	 *
	 * Observations with weight zero contribute nothing and are skipped. Processing them
	 * would otherwise evaluate 0*Infinity (=NaN) in the CM3/CM4 update terms, because
	 * those terms divide by a power of w2 or by the accumulated weight.
	 *
	 * @param in1 running statistics; must be a CM_COV_Object, which is mutated and returned
	 * @param in2 the new observation
	 * @param w2 weight of the new observation (fractional weights are used as-is,
	 *        except for COUNT, which rounds the accumulated weight)
	 * @return in1 after the update
	 * @throws DMLRuntimeException if the aggregate type is not supported
	 */
	@Override
	public Data execute(Data in1, double in2, double w2) {
		CM_COV_Object cm1=(CM_COV_Object) in1;
		//zero-weight observation: no contribution; this check must precede the
		//initialization below, which would otherwise leave w==0 with a nonzero mean
		//and make the next CM3/CM4 update compute t2=-1/0
		if( w2 == 0 )
			return cm1;
		if(cm1.isCMAllZeros())
		{
			//first observation initializes the state
			cm1.w=w2;
			cm1.mean.set(in2, 0);
			cm1.m2.set(0,0);
			cm1.m3.set(0,0);
			cm1.m4.set(0,0);
			return cm1;
		}
		switch( _type )
		{
			case COUNT:
			{
				cm1.w = Math.round(cm1.w + w2);
				break;
			}
			case MEAN:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double lt1=t1*d;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
				double f2=w2/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				//m3 update reads the OLD m2, so m2 is only written back afterwards
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(w2, 2)-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1/Math.pow(w2, 3)-Math.pow(t2, 3));
				double f2=w2/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, lt2-3*cm1.m2._sum*f2*d);
				//m4 update reads the OLD m2 and m3, so both are written back afterwards
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 6*cm1.m2._sum*Math.pow(-f2*d, 2) + lt3-4*cm1.m3._sum*f2*d);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				double w = cm1.w + w2;
				double d=in2-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, w2*d/w);
				double t1=cm1.w*w2/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		return cm1;
	}

	/**
	 * Combining stats from two partitions of the data. The stats of in2 are merged
	 * into in1 in place.
	 *
	 * A partition that is all zeros, or whose accumulated weight is zero, carries
	 * no data and is treated as empty:
	 * <ul>
	 *   <li>if in1 is empty, in2 is copied into in1;</li>
	 *   <li>if in2 is empty, in1 is returned unchanged.</li>
	 * </ul>
	 * This also avoids 0*Infinity (=NaN) in the CM3/CM4 terms, which divide by powers
	 * of both partition weights.
	 *
	 * @param in1 stats of the first partition; must be a CM_COV_Object, which is mutated and returned
	 * @param in2 stats of the second partition; must be a CM_COV_Object (only read)
	 * @return in1 after merging in2
	 * @throws DMLRuntimeException if the aggregate type is not supported
	 */
	@Override
	public Data execute(Data in1, Data in2)
	{
		CM_COV_Object cm1=(CM_COV_Object) in1;
		CM_COV_Object cm2=(CM_COV_Object) in2;
		if(cm1.isCMAllZeros() || cm1.w == 0)
		{
			//empty left partition: result is the right partition
			cm1.w=cm2.w;
			cm1.mean.set(cm2.mean);
			cm1.m2.set(cm2.m2);
			cm1.m3.set(cm2.m3);
			cm1.m4.set(cm2.m4);
			return cm1;
		}
		if(cm2.isCMAllZeros() || cm2.w == 0)
			return cm1; //empty right partition: nothing to merge
		switch( _type )
		{
			case COUNT:
			{
				cm1.w = Math.round(cm1.w + cm2.w);
				break;
			}
			case MEAN:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				cm1.w=w;
				break;
			}
			case CM2:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double lt1=t1*d;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				cm1.m2.set(_buff2);
				cm1.w=w;
				break;
			}
			case CM3:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
				double f1=cm1.w/w;
				double f2=cm2.w/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				//m3 update reads the OLD m2 of both partitions, so m2 is written back afterwards
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
				_buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case CM4:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double t2=-1/cm1.w;
				double lt1=t1*d;
				double lt2=Math.pow(t1, 3)*(1/Math.pow(cm2.w, 2)-Math.pow(t2, 2));
				double lt3=Math.pow(t1, 4)*(1/Math.pow(cm2.w, 3)-Math.pow(t2, 3));
				double f1=cm1.w/w;
				double f2=cm2.w/w;
				_buff2.set(cm1.m2);
				_buff2=(KahanObject) _plus.execute(_buff2, cm2.m2._sum, cm2.m2._correction);
				_buff2=(KahanObject) _plus.execute(_buff2, lt1);
				_buff3.set(cm1.m3);
				_buff3=(KahanObject) _plus.execute(_buff3, cm2.m3._sum, cm2.m3._correction);
				_buff3=(KahanObject) _plus.execute(_buff3, 3*(-f2*cm1.m2._sum+f1*cm2.m2._sum)*d + lt2);
				//m4 update reads the OLD m2 and m3, so both are written back afterwards
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, cm2.m4._sum, cm2.m4._correction);
				cm1.m4=(KahanObject) _plus.execute(cm1.m4, 4*(-f2*cm1.m3._sum+f1*cm2.m3._sum)*d
					+ 6*(Math.pow(-f2, 2)*cm1.m2._sum+Math.pow(f1, 2)*cm2.m2._sum)*Math.pow(d, 2) + lt3);
				cm1.m2.set(_buff2);
				cm1.m3.set(_buff3);
				cm1.w=w;
				break;
			}
			case VARIANCE:
			{
				double w = cm1.w + cm2.w;
				double d=cm2.mean._sum-cm1.mean._sum;
				cm1.mean=(KahanObject) _plus.execute(cm1.mean, cm2.w*d/w);
				double t1=cm1.w*cm2.w/w*d;
				double lt1=t1*d;
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, cm2.m2._sum, cm2.m2._correction);
				cm1.m2=(KahanObject) _plus.execute(cm1.m2, lt1);
				cm1.w=w;
				break;
			}
			default:
				throw new DMLRuntimeException("Unsupported operation type: "+_type);
		}
		return cm1;
	}
}