| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.transform.encode; |
| |
| import java.io.IOException; |
| import java.util.BitSet; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map.Entry; |
| |
| import org.apache.wink.json4j.JSONArray; |
| import org.apache.wink.json4j.JSONException; |
| import org.apache.wink.json4j.JSONObject; |
| import org.apache.sysds.runtime.functionobjects.CM; |
| import org.apache.sysds.runtime.functionobjects.Mean; |
| import org.apache.sysds.runtime.instructions.cp.CM_COV_Object; |
| import org.apache.sysds.runtime.instructions.cp.KahanObject; |
| import org.apache.sysds.runtime.matrix.data.FrameBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.operators.CMOperator.AggregateOperationTypes; |
| import org.apache.sysds.runtime.transform.TfUtils; |
| import org.apache.sysds.runtime.transform.TfUtils.TfMethod; |
| import org.apache.sysds.runtime.transform.meta.TfMetaUtils; |
| import org.apache.sysds.runtime.util.UtilFunctions; |
| |
| public class EncoderMVImpute extends Encoder |
| { |
| private static final long serialVersionUID = 9057868620144662194L; |
| |
| public enum MVMethod { INVALID, GLOBAL_MEAN, GLOBAL_MODE, CONSTANT } |
| |
| private MVMethod[] _mvMethodList = null; |
| private MVMethod[] _mvscMethodList = null; // scaling methods for attributes that are imputed and also scaled |
| |
| private BitSet _isMVScaled = null; |
| private CM _varFn = CM.getCMFnObject(AggregateOperationTypes.VARIANCE); // function object that understands variance computation |
| |
| // objects required to compute mean and variance of all non-missing entries |
| private Mean _meanFn = Mean.getMeanFnObject(); // function object that understands mean computation |
| private KahanObject[] _meanList = null; // column-level means, computed so far |
| private long[] _countList = null; // #of non-missing values |
| |
| private CM_COV_Object[] _varList = null; // column-level variances, computed so far (for scaling) |
| |
| private int[] _scnomvList = null; // List of attributes that are scaled but not imputed |
| private MVMethod[] _scnomvMethodList = null; // scaling methods: 0 for invalid; 1 for mean-subtraction; 2 for z-scoring |
| private KahanObject[] _scnomvMeanList = null; // column-level means, for attributes scaled but not imputed |
| private long[] _scnomvCountList = null; // #of non-missing values, for attributes scaled but not imputed |
| private CM_COV_Object[] _scnomvVarList = null; // column-level variances, computed so far |
| |
| private String[] _replacementList = null; // replacements: for global_mean, mean; and for global_mode, recode id of mode category |
| private String[] _NAstrings = null; |
| private List<Integer> _rcList = null; |
| private HashMap<Integer,HashMap<String,Long>> _hist = null; |
| |
| public String[] getReplacements() { return _replacementList; } |
| public KahanObject[] getMeans() { return _meanList; } |
| public CM_COV_Object[] getVars() { return _varList; } |
| public KahanObject[] getMeans_scnomv() { return _scnomvMeanList; } |
| public CM_COV_Object[] getVars_scnomv() { return _scnomvVarList; } |
| |
| public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, int clen) |
| throws JSONException |
| { |
| super(null, clen); |
| |
| //handle column list |
| int[] collist = TfMetaUtils.parseJsonObjectIDList(parsedSpec, colnames, TfMethod.IMPUTE.toString()); |
| initColList(collist); |
| |
| //handle method list |
| parseMethodsAndReplacments(parsedSpec); |
| |
| //create reuse histograms |
| _hist = new HashMap<>(); |
| } |
| |
| public EncoderMVImpute(JSONObject parsedSpec, String[] colnames, String[] NAstrings, int clen) |
| throws JSONException |
| { |
| super(null, clen); |
| boolean isMV = parsedSpec.containsKey(TfMethod.IMPUTE.toString()); |
| boolean isSC = parsedSpec.containsKey(TfMethod.SCALE.toString()); |
| _NAstrings = NAstrings; |
| |
| if(!isMV) { |
| // MV Impute is not applicable |
| _colList = null; |
| _mvMethodList = null; |
| _meanList = null; |
| _countList = null; |
| _replacementList = null; |
| } |
| else { |
| JSONObject mvobj = (JSONObject) parsedSpec.get(TfMethod.IMPUTE.toString()); |
| JSONArray mvattrs = (JSONArray) mvobj.get(TfUtils.JSON_ATTRS); |
| JSONArray mvmthds = (JSONArray) mvobj.get(TfUtils.JSON_MTHD); |
| int mvLength = mvattrs.size(); |
| |
| _colList = new int[mvLength]; |
| _mvMethodList = new MVMethod[mvLength]; |
| |
| _meanList = new KahanObject[mvLength]; |
| _countList = new long[mvLength]; |
| _varList = new CM_COV_Object[mvLength]; |
| |
| _isMVScaled = new BitSet(_colList.length); |
| _isMVScaled.clear(); |
| |
| for(int i=0; i < _colList.length; i++) { |
| _colList[i] = UtilFunctions.toInt(mvattrs.get(i)); |
| _mvMethodList[i] = MVMethod.values()[UtilFunctions.toInt(mvmthds.get(i))]; |
| _meanList[i] = new KahanObject(0, 0); |
| } |
| |
| _replacementList = new String[mvLength]; // contains replacements for all columns (scale and categorical) |
| |
| JSONArray constants = (JSONArray)mvobj.get(TfUtils.JSON_CONSTS); |
| for(int i=0; i < constants.size(); i++) { |
| if ( constants.get(i) == null ) |
| _replacementList[i] = "NaN"; |
| else |
| _replacementList[i] = constants.get(i).toString(); |
| } |
| } |
| |
| // Handle scaled attributes |
| if ( !isSC ) |
| { |
| // scaling is not applicable |
| _scnomvCountList = null; |
| _scnomvMeanList = null; |
| _scnomvVarList = null; |
| } |
| else |
| { |
| if ( _colList != null ) |
| _mvscMethodList = new MVMethod[_colList.length]; |
| |
| JSONObject scobj = (JSONObject) parsedSpec.get(TfMethod.SCALE.toString()); |
| JSONArray scattrs = (JSONArray) scobj.get(TfUtils.JSON_ATTRS); |
| JSONArray scmthds = (JSONArray) scobj.get(TfUtils.JSON_MTHD); |
| int scLength = scattrs.size(); |
| |
| int[] _allscaled = new int[scLength]; |
| int scnomv = 0, colID; |
| byte mthd; |
| for(int i=0; i < scLength; i++) |
| { |
| colID = UtilFunctions.toInt(scattrs.get(i)); |
| mthd = (byte) UtilFunctions.toInt(scmthds.get(i)); |
| |
| _allscaled[i] = colID; |
| |
| // check if the attribute is also MV imputed |
| int mvidx = isApplicable(colID); |
| if(mvidx != -1) |
| { |
| _isMVScaled.set(mvidx); |
| _mvscMethodList[mvidx] = MVMethod.values()[mthd]; |
| _varList[mvidx] = new CM_COV_Object(); |
| } |
| else |
| scnomv++; // count of scaled but not imputed |
| } |
| |
| if(scnomv > 0) |
| { |
| _scnomvList = new int[scnomv]; |
| _scnomvMethodList = new MVMethod[scnomv]; |
| |
| _scnomvMeanList = new KahanObject[scnomv]; |
| _scnomvCountList = new long[scnomv]; |
| _scnomvVarList = new CM_COV_Object[scnomv]; |
| |
| for(int i=0, idx=0; i < scLength; i++) |
| { |
| colID = UtilFunctions.toInt(scattrs.get(i)); |
| mthd = (byte)UtilFunctions.toInt(scmthds.get(i)); |
| |
| if(isApplicable(colID) == -1) |
| { // scaled but not imputed |
| _scnomvList[idx] = colID; |
| _scnomvMethodList[idx] = MVMethod.values()[mthd]; |
| _scnomvMeanList[idx] = new KahanObject(0, 0); |
| _scnomvVarList[idx] = new CM_COV_Object(); |
| idx++; |
| } |
| } |
| } |
| } |
| } |
| |
| private void parseMethodsAndReplacments(JSONObject parsedSpec) throws JSONException { |
| JSONArray mvspec = (JSONArray) parsedSpec.get(TfMethod.IMPUTE.toString()); |
| _mvMethodList = new MVMethod[mvspec.size()]; |
| _replacementList = new String[mvspec.size()]; |
| _meanList = new KahanObject[mvspec.size()]; |
| _countList = new long[mvspec.size()]; |
| for(int i=0; i < mvspec.size(); i++) { |
| JSONObject mvobj = (JSONObject)mvspec.get(i); |
| _mvMethodList[i] = MVMethod.valueOf(mvobj.get("method").toString().toUpperCase()); |
| if( _mvMethodList[i] == MVMethod.CONSTANT ) { |
| _replacementList[i] = mvobj.getString("value").toString(); |
| } |
| _meanList[i] = new KahanObject(0, 0); |
| } |
| } |
| |
| public void prepare(String[] words) throws IOException { |
| |
| try { |
| String w = null; |
| if(_colList != null) |
| for(int i=0; i <_colList.length; i++) { |
| int colID = _colList[i]; |
| w = UtilFunctions.unquote(words[colID-1].trim()); |
| |
| try { |
| if(!TfUtils.isNA(_NAstrings, w)) { |
| _countList[i]++; |
| |
| boolean computeMean = (_mvMethodList[i] == MVMethod.GLOBAL_MEAN || _isMVScaled.get(i) ); |
| if(computeMean) { |
| // global_mean |
| double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString); |
| _meanFn.execute2(_meanList[i], d, _countList[i]); |
| |
| if (_isMVScaled.get(i) && _mvscMethodList[i] == MVMethod.GLOBAL_MODE) |
| _varFn.execute(_varList[i], d); |
| } |
| else { |
| // global_mode or constant |
| // Nothing to do here. Mode is computed using recode maps. |
| } |
| } |
| } catch (NumberFormatException e) |
| { |
| throw new RuntimeException("Encountered \"" + w + "\" in column ID \"" + colID + "\", when expecting a numeric value. Consider adding \"" + w + "\" to na.strings, along with an appropriate imputation method."); |
| } |
| } |
| |
| // Compute mean and variance for attributes that are scaled but not imputed |
| if(_scnomvList != null) |
| for(int i=0; i < _scnomvList.length; i++) |
| { |
| int colID = _scnomvList[i]; |
| w = UtilFunctions.unquote(words[colID-1].trim()); |
| double d = UtilFunctions.parseToDouble(w, UtilFunctions.defaultNaString); |
| _scnomvCountList[i]++; // not required, this is always equal to total #records processed |
| _meanFn.execute2(_scnomvMeanList[i], d, _scnomvCountList[i]); |
| if(_scnomvMethodList[i] == MVMethod.GLOBAL_MODE) |
| _varFn.execute(_scnomvVarList[i], d); |
| } |
| } catch(Exception e) { |
| throw new IOException(e); |
| } |
| } |
| |
| public MVMethod getMethod(int colID) { |
| int idx = isApplicable(colID); |
| if(idx == -1) |
| return MVMethod.INVALID; |
| else |
| return _mvMethodList[idx]; |
| } |
| |
| public long getNonMVCount(int colID) { |
| int idx = isApplicable(colID); |
| return (idx == -1) ? 0 : _countList[idx]; |
| } |
| |
| public String getReplacement(int colID) { |
| int idx = isApplicable(colID); |
| return (idx == -1) ? null : _replacementList[idx]; |
| } |
| |
| @Override |
| public MatrixBlock encode(FrameBlock in, MatrixBlock out) { |
| build(in); |
| return apply(in, out); |
| } |
| |
| @Override |
| public void build(FrameBlock in) { |
| try { |
| for( int j=0; j<_colList.length; j++ ) { |
| int colID = _colList[j]; |
| if( _mvMethodList[j] == MVMethod.GLOBAL_MEAN ) { |
| //compute global column mean (scale) |
| long off = _countList[j]; |
| for( int i=0; i<in.getNumRows(); i++ ) |
| _meanFn.execute2(_meanList[j], UtilFunctions.objectToDouble( |
| in.getSchema()[colID-1], in.get(i, colID-1)), off+i+1); |
| _replacementList[j] = String.valueOf(_meanList[j]._sum); |
| _countList[j] += in.getNumRows(); |
| } |
| else if( _mvMethodList[j] == MVMethod.GLOBAL_MODE ) { |
| //compute global column mode (categorical), i.e., most frequent category |
| HashMap<String,Long> hist = _hist.containsKey(colID) ? |
| _hist.get(colID) : new HashMap<>(); |
| for( int i=0; i<in.getNumRows(); i++ ) { |
| String key = String.valueOf(in.get(i, colID-1)); |
| if( key != null && !key.isEmpty() ) { |
| Long val = hist.get(key); |
| hist.put(key, (val!=null) ? val+1 : 1); |
| } |
| } |
| _hist.put(colID, hist); |
| long max = Long.MIN_VALUE; |
| for( Entry<String, Long> e : hist.entrySet() ) |
| if( e.getValue() > max ) { |
| _replacementList[j] = e.getKey(); |
| max = e.getValue(); |
| } |
| } |
| } |
| } |
| catch(Exception ex) { |
| throw new RuntimeException(ex); |
| } |
| } |
| |
| @Override |
| public MatrixBlock apply(FrameBlock in, MatrixBlock out) { |
| for(int i=0; i<in.getNumRows(); i++) { |
| for(int j=0; j<_colList.length; j++) { |
| int colID = _colList[j]; |
| if( Double.isNaN(out.quickGetValue(i, colID-1)) ) |
| out.quickSetValue(i, colID-1, Double.parseDouble(_replacementList[j])); |
| } |
| } |
| return out; |
| } |
| |
| @Override |
| public FrameBlock getMetaData(FrameBlock out) { |
| for( int j=0; j<_colList.length; j++ ) { |
| out.getColumnMetadata(_colList[j]-1) |
| .setMvValue(_replacementList[j]); |
| } |
| return out; |
| } |
| |
| @Override |
| public void initMetaData(FrameBlock meta) { |
| //init replacement lists, replace recoded values to |
| //apply mv imputation potentially after recoding |
| for( int j=0; j<_colList.length; j++ ) { |
| int colID = _colList[j]; |
| String mvVal = UtilFunctions.unquote(meta.getColumnMetadata(colID-1).getMvValue()); |
| if( _rcList.contains(colID) ) { |
| Long mvVal2 = meta.getRecodeMap(colID-1).get(mvVal); |
| if( mvVal2 == null) |
| throw new RuntimeException("Missing recode value for impute value '"+mvVal+"' (colID="+colID+")."); |
| _replacementList[j] = mvVal2.toString(); |
| } |
| else { |
| _replacementList[j] = mvVal; |
| } |
| } |
| } |
| |
| public void initRecodeIDList(List<Integer> rcList) { |
| _rcList = rcList; |
| } |
| |
| /** |
| * Exposes the internal histogram after build. |
| * |
| * @param colID column ID |
| * @return histogram (map of string keys and long values) |
| */ |
| public HashMap<String,Long> getHistogram( int colID ) { |
| return _hist.get(colID); |
| } |
| } |