| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.runtime.instructions.spark; |
| |
| import org.apache.spark.api.java.JavaPairRDD; |
| import org.apache.spark.api.java.JavaRDD; |
| import org.apache.spark.api.java.JavaSparkContext; |
| import org.apache.spark.api.java.function.FlatMapFunction; |
| import org.apache.spark.api.java.function.PairFlatMapFunction; |
| import org.apache.spark.broadcast.Broadcast; |
| import org.apache.spark.util.AccumulatorV2; |
| import org.apache.sysds.common.Types.DataType; |
| import org.apache.sysds.common.Types.FileFormat; |
| import org.apache.sysds.common.Types.ValueType; |
| import org.apache.sysds.runtime.DMLRuntimeException; |
| import org.apache.sysds.runtime.controlprogram.caching.FrameObject; |
| import org.apache.sysds.runtime.controlprogram.context.ExecutionContext; |
| import org.apache.sysds.runtime.controlprogram.context.SparkExecutionContext; |
| import org.apache.sysds.runtime.functionobjects.KahanPlus; |
| import org.apache.sysds.runtime.instructions.InstructionUtils; |
| import org.apache.sysds.runtime.instructions.cp.CPOperand; |
| import org.apache.sysds.runtime.instructions.cp.KahanObject; |
| import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyFunction; |
| import org.apache.sysds.runtime.instructions.spark.ParameterizedBuiltinSPInstruction.RDDTransformApplyOffsetFunction; |
| import org.apache.sysds.runtime.instructions.spark.utils.FrameRDDConverterUtils; |
| import org.apache.sysds.runtime.instructions.spark.utils.SparkUtils; |
| import org.apache.sysds.runtime.io.FrameReader; |
| import org.apache.sysds.runtime.io.FrameReaderFactory; |
| import org.apache.sysds.runtime.matrix.data.FrameBlock; |
| import org.apache.sysds.runtime.matrix.data.FrameBlock.ColumnMetadata; |
| import org.apache.sysds.runtime.matrix.data.MatrixBlock; |
| import org.apache.sysds.runtime.matrix.data.MatrixIndexes; |
| import org.apache.sysds.runtime.matrix.operators.Operator; |
| import org.apache.sysds.runtime.meta.DataCharacteristics; |
| import org.apache.sysds.runtime.transform.encode.Encoder; |
| import org.apache.sysds.runtime.transform.encode.EncoderComposite; |
| import org.apache.sysds.runtime.transform.encode.EncoderFactory; |
| import org.apache.sysds.runtime.transform.encode.EncoderMVImpute; |
| import org.apache.sysds.runtime.transform.encode.EncoderMVImpute.MVMethod; |
| import org.apache.sysds.runtime.transform.encode.EncoderRecode; |
| import org.apache.sysds.runtime.transform.meta.TfMetaUtils; |
| import org.apache.sysds.runtime.transform.meta.TfOffsetMap; |
| import scala.Tuple2; |
| |
| import java.io.IOException; |
| import java.util.ArrayList; |
| import java.util.HashMap; |
| import java.util.HashSet; |
| import java.util.Iterator; |
| import java.util.Map.Entry; |
| |
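| /** |
| * Spark instruction for the multi-return builtin transformencode, which builds the |
| * transform meta data (e.g., recode maps and impute statistics) in a distributed |
| * manner and then applies the resulting encoders, returning both the encoded matrix |
| * and the consolidated meta data frame. |
| */ |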
| public class MultiReturnParameterizedBuiltinSPInstruction extends ComputationSPInstruction { |
| protected ArrayList<CPOperand> _outputs; |
| |
| private MultiReturnParameterizedBuiltinSPInstruction(Operator op, CPOperand input1, CPOperand input2, |
| ArrayList<CPOperand> outputs, String opcode, String istr) { |
| super(SPType.MultiReturnBuiltin, op, input1, input2, outputs.get(0), opcode, istr); |
| _outputs = outputs; |
| } |
| |
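| /** |
| * Parses a transformencode instruction string into its two inputs (input frame and |
| * transform specification) and two outputs (encoded matrix and meta data frame). |
| */ |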
| public static MultiReturnParameterizedBuiltinSPInstruction parseInstruction( String str ) { |
| String[] parts = InstructionUtils.getInstructionPartsWithValueType(str); |
| ArrayList<CPOperand> outputs = new ArrayList<>(); |
| String opcode = parts[0]; |
| |
| if ( opcode.equalsIgnoreCase("transformencode") ) { |
| // two inputs and two outputs |
| CPOperand in1 = new CPOperand(parts[1]); |
| CPOperand in2 = new CPOperand(parts[2]); |
| outputs.add ( new CPOperand(parts[3], ValueType.FP64, DataType.MATRIX) ); |
| outputs.add ( new CPOperand(parts[4], ValueType.STRING, DataType.FRAME) ); |
| return new MultiReturnParameterizedBuiltinSPInstruction(null, in1, in2, outputs, opcode, str); |
| } |
| else { |
| throw new DMLRuntimeException("Invalid opcode in MultiReturnBuiltin instruction: " + opcode); |
| } |
| |
| } |
| |
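| /** |
| * Builds the transform meta data over the distributed input frame, consolidates it |
| * into a single meta data frame, applies the encoders to the input, and sets both |
| * the encoded matrix and the meta data frame as outputs. |
| */ |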
| @Override |
| @SuppressWarnings("unchecked") |
| public void processInstruction(ExecutionContext ec) { |
| SparkExecutionContext sec = (SparkExecutionContext) ec; |
| |
| try |
| { |
| //get input RDD and meta data |
| FrameObject fo = sec.getFrameObject(input1.getName()); |
| FrameObject fometa = sec.getFrameObject(_outputs.get(1).getName()); |
| JavaPairRDD<Long,FrameBlock> in = (JavaPairRDD<Long,FrameBlock>) |
| sec.getRDDHandleForFrameObject(fo, FileFormat.BINARY); |
| String spec = ec.getScalarInput(input2).getStringValue(); |
| DataCharacteristics mcIn = sec.getDataCharacteristics(input1.getName()); |
| DataCharacteristics mcOut = sec.getDataCharacteristics(output.getName()); |
| String[] colnames = !TfMetaUtils.isIDSpec(spec) ? |
| in.lookup(1L).get(0).getColumnNames() : null; |
| |
| //step 1: build transform meta data |
| Encoder encoderBuild = EncoderFactory.createEncoder(spec, colnames, |
| fo.getSchema(), (int)fo.getNumColumns(), null); |
| |
| MaxLongAccumulator accMax = registerMaxLongAccumulator(sec.getSparkContext()); |
| JavaRDD<String> rcMaps = in |
| .mapPartitionsToPair(new TransformEncodeBuildFunction(encoderBuild)) |
| .distinct().groupByKey() |
| .flatMap(new TransformEncodeGroupFunction(accMax)); |
| if( containsMVImputeEncoder(encoderBuild) ) { |
| EncoderMVImpute mva = getMVImputeEncoder(encoderBuild); |
| rcMaps = rcMaps.union( |
| in.mapPartitionsToPair(new TransformEncodeBuild2Function(mva)) |
| .groupByKey().flatMap(new TransformEncodeGroup2Function(mva)) ); |
| } |
| rcMaps.saveAsTextFile(fometa.getFileName()); //trigger eval |
| |
| //consolidate meta data frame (reuse multi-threaded reader, with special handling of missing values) |
| FrameReader reader = FrameReaderFactory.createFrameReader(FileFormat.TEXT); |
| FrameBlock meta = reader.readFrameFromHDFS(fometa.getFileName(), accMax.value(), fo.getNumColumns()); |
| meta.recomputeColumnCardinality(); //recompute num distinct items per column |
| meta.setColumnNames((colnames!=null)?colnames:meta.getColumnNames()); |
| |
| //step 2: transform apply (similar to spark transformapply) |
| //compute omit offset map for block shifts |
| TfOffsetMap omap = null; |
| if( TfMetaUtils.containsOmitSpec(spec, colnames) ) { |
| omap = new TfOffsetMap(SparkUtils.toIndexedLong(in.mapToPair( |
| new RDDTransformApplyOffsetFunction(spec, colnames)).collect())); |
| } |
| |
| //create encoder broadcast (avoiding replication per task) |
| Encoder encoder = EncoderFactory.createEncoder(spec, colnames, |
| fo.getSchema(), (int)fo.getNumColumns(), meta); |
| mcOut.setDimension(mcIn.getRows()-((omap!=null)?omap.getNumRmRows():0), encoder.getNumCols()); |
| Broadcast<Encoder> bmeta = sec.getSparkContext().broadcast(encoder); |
| Broadcast<TfOffsetMap> bomap = (omap!=null) ? sec.getSparkContext().broadcast(omap) : null; |
| |
| //execute transform apply |
| JavaPairRDD<Long,FrameBlock> tmp = in |
| .mapToPair(new RDDTransformApplyFunction(bmeta, bomap)); |
| JavaPairRDD<MatrixIndexes,MatrixBlock> out = FrameRDDConverterUtils |
| .binaryBlockToMatrixBlock(tmp, mcOut, mcOut); |
| |
| //set output and maintain lineage/output characteristics |
| sec.setRDDHandleForVariable(_outputs.get(0).getName(), out); |
| sec.addLineageRDD(_outputs.get(0).getName(), input1.getName()); |
| sec.setFrameOutput(_outputs.get(1).getName(), meta); |
| } |
| catch(IOException ex) { |
| throw new RuntimeException(ex); |
| } |
| } |
| |
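| //checks whether a composite encoder contains a missing-value impute encoder |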
| private static boolean containsMVImputeEncoder(Encoder encoder) { |
| if( encoder instanceof EncoderComposite ) |
| for( Encoder cencoder : ((EncoderComposite)encoder).getEncoders() ) |
| if( cencoder instanceof EncoderMVImpute ) |
| return true; |
| return false; |
| } |
| |
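| //returns the missing-value impute encoder of a composite encoder, or null if none exists |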
| private static EncoderMVImpute getMVImputeEncoder(Encoder encoder) { |
| if( encoder instanceof EncoderComposite ) |
| for( Encoder cencoder : ((EncoderComposite)encoder).getEncoders() ) |
| if( cencoder instanceof EncoderMVImpute ) |
| return (EncoderMVImpute) cencoder; |
| return null; |
| } |
| |
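| //creates and registers an accumulator that tracks the maximum recode-map size, i.e., the number of rows of the consolidated meta data frame |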
| private static MaxLongAccumulator registerMaxLongAccumulator(JavaSparkContext sc) { |
| MaxLongAccumulator acc = new MaxLongAccumulator(Long.MIN_VALUE); |
| sc.sc().register(acc, "max"); |
| return acc; |
| } |
| |
| |
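| /** |
| * Spark accumulator that maintains the maximum of all added long values, |
| * used to determine the number of rows of the consolidated meta data frame. |
| */ |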
| private static class MaxLongAccumulator extends AccumulatorV2<Long,Long> |
| { |
| private static final long serialVersionUID = -3739727823287550826L; |
| |
| private long _value = Long.MIN_VALUE; |
| |
| public MaxLongAccumulator(long value) { |
| _value = value; |
| } |
| |
| @Override |
| public void add(Long arg0) { |
| _value = Math.max(_value, arg0); |
| } |
| |
| @Override |
| public AccumulatorV2<Long, Long> copy() { |
| return new MaxLongAccumulator(_value); |
| } |
| |
| @Override |
| public boolean isZero() { |
| return _value == Long.MIN_VALUE; |
| } |
| |
| @Override |
| public void merge(AccumulatorV2<Long, Long> arg0) { |
| _value = Math.max(_value, arg0.value()); |
| } |
| |
| @Override |
| public void reset() { |
| _value = Long.MIN_VALUE; |
| } |
| |
| @Override |
| public Long value() { |
| return _value; |
| } |
| } |
| |
| /** |
| * This function pre-aggregates distinct values of recoded columns per partition |
| * (part of distributed recode map construction, used for recoding, binning and |
| * dummy coding). We operate directly over schema-specific objects to avoid |
| * unnecessary string conversion, as well as to reduce memory overhead and shuffle. |
| */ |
| public static class TransformEncodeBuildFunction |
| implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, Object> |
| { |
| private static final long serialVersionUID = 6336375833412029279L; |
| |
| private EncoderRecode _raEncoder = null; |
| |
| public TransformEncodeBuildFunction(Encoder encoder) { |
| for( Encoder cEncoder : ((EncoderComposite)encoder).getEncoders() ) |
| if( cEncoder instanceof EncoderRecode ) |
| _raEncoder = (EncoderRecode)cEncoder; |
| } |
| |
| @Override |
| public Iterator<Tuple2<Integer, Object>> call(Iterator<Tuple2<Long, FrameBlock>> iter) |
| throws Exception |
| { |
| //build meta data (e.g., recode maps) |
| if( _raEncoder != null ) |
| while( iter.hasNext() ) |
| _raEncoder.buildPartial(iter.next()._2()); |
| |
| //output recode maps as columnID - token pairs |
| ArrayList<Tuple2<Integer,Object>> ret = new ArrayList<>(); |
| if( _raEncoder != null ) { //guard against missing recode encoder |
| HashMap<Integer,HashSet<Object>> tmp = _raEncoder.getCPRecodeMapsPartial(); |
| for( Entry<Integer,HashSet<Object>> e1 : tmp.entrySet() ) |
| for( Object token : e1.getValue() ) |
| ret.add(new Tuple2<>(e1.getKey(), token)); |
| _raEncoder.getCPRecodeMapsPartial().clear(); |
| } |
| |
| return ret.iterator(); |
| } |
| } |
| |
| /** |
| * This function assigns codes to globally distinct values of recoded columns and |
| * writes the resulting column map in textcell (IJV) format to the output (part of |
| * distributed recode map construction, used for recoding, binning, and dummy coding). |
| * We operate directly over schema-specific objects to avoid unnecessary string |
| * conversion, as well as to reduce memory overhead and shuffle. |
| */ |
| public static class TransformEncodeGroupFunction |
| implements FlatMapFunction<Tuple2<Integer, Iterable<Object>>, String> |
| { |
| private static final long serialVersionUID = -1034187226023517119L; |
| |
| private MaxLongAccumulator _accMax = null; |
| |
| public TransformEncodeGroupFunction( MaxLongAccumulator accMax ) { |
| _accMax = accMax; |
| } |
| |
| @Override |
| public Iterator<String> call(Tuple2<Integer, Iterable<Object>> arg0) |
| throws Exception |
| { |
| String colID = String.valueOf(arg0._1()); |
| Iterator<Object> iter = arg0._2().iterator(); |
| |
| ArrayList<String> ret = new ArrayList<>(); |
| StringBuilder sb = new StringBuilder(); |
| long rowID = 1; |
| while( iter.hasNext() ) { |
| sb.append(rowID); |
| sb.append(' '); |
| sb.append(colID); |
| sb.append(' '); |
| sb.append(EncoderRecode.constructRecodeMapEntry( |
| iter.next().toString(), rowID)); |
| ret.add(sb.toString()); |
| sb.setLength(0); |
| rowID++; |
| } |
| _accMax.add(rowID-1); |
| |
| return ret.iterator(); |
| } |
| } |
| |
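| /** |
| * This function builds the missing-value impute meta data (e.g., histograms and |
| * means) per partition and outputs it as columnID - column metadata pairs. |
| */ |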
| public static class TransformEncodeBuild2Function implements PairFlatMapFunction<Iterator<Tuple2<Long, FrameBlock>>, Integer, ColumnMetadata> |
| { |
| private static final long serialVersionUID = 6336375833412029279L; |
| |
| private EncoderMVImpute _encoder = null; |
| |
| public TransformEncodeBuild2Function(EncoderMVImpute encoder) { |
| _encoder = encoder; |
| } |
| |
| @Override |
| public Iterator<Tuple2<Integer, ColumnMetadata>> call(Iterator<Tuple2<Long, FrameBlock>> iter) |
| throws Exception |
| { |
| //build meta data (e.g., histograms and means) |
| while( iter.hasNext() ) { |
| FrameBlock block = iter.next()._2(); |
| _encoder.build(block); |
| } |
| |
| //extract meta data |
| ArrayList<Tuple2<Integer,ColumnMetadata>> ret = new ArrayList<>(); |
| int[] collist = _encoder.getColList(); |
| for( int j=0; j<collist.length; j++ ) { |
| if( _encoder.getMethod(collist[j]) == MVMethod.GLOBAL_MODE ) { |
| HashMap<String,Long> hist = _encoder.getHistogram(collist[j]); |
| for( Entry<String,Long> e : hist.entrySet() ) |
| ret.add(new Tuple2<>(collist[j], |
| new ColumnMetadata(e.getValue(), e.getKey()))); |
| } |
| else if( _encoder.getMethod(collist[j]) == MVMethod.GLOBAL_MEAN ) { |
| ret.add(new Tuple2<>(collist[j], |
| new ColumnMetadata(_encoder.getNonMVCount(collist[j]), String.valueOf(_encoder.getMeans()[j]._sum)))); |
| } |
| else if( _encoder.getMethod(collist[j]) == MVMethod.CONSTANT ) { |
| ret.add(new Tuple2<>(collist[j], |
| new ColumnMetadata(0, _encoder.getReplacement(collist[j])))); |
| } |
| } |
| |
| return ret.iterator(); |
| } |
| } |
| |
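| /** |
| * This function aggregates the partial impute meta data per column and writes the |
| * resulting global mode, global mean, or constant replacement value as textcell |
| * entries (with row index -2) to the output. |
| */ |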
| public static class TransformEncodeGroup2Function implements FlatMapFunction<Tuple2<Integer, Iterable<ColumnMetadata>>, String> |
| { |
| private static final long serialVersionUID = 702100641492347459L; |
| |
| private EncoderMVImpute _encoder = null; |
| |
| public TransformEncodeGroup2Function(EncoderMVImpute encoder) { |
| _encoder = encoder; |
| } |
| |
| @Override |
| public Iterator<String> call(Tuple2<Integer, Iterable<ColumnMetadata>> arg0) |
| throws Exception |
| { |
| int colix = arg0._1(); |
| Iterator<ColumnMetadata> iter = arg0._2().iterator(); |
| ArrayList<String> ret = new ArrayList<>(); |
| |
| //compute global mode of categorical feature, i.e., value with highest frequency |
| if( _encoder.getMethod(colix) == MVMethod.GLOBAL_MODE ) { |
| HashMap<String, Long> hist = new HashMap<>(); |
| while( iter.hasNext() ) { |
| ColumnMetadata cmeta = iter.next(); |
| Long tmp = hist.get(cmeta.getMvValue()); |
| hist.put(cmeta.getMvValue(), cmeta.getNumDistinct() + ((tmp!=null)?tmp:0)); |
| } |
| long max = Long.MIN_VALUE; String mode = null; |
| for( Entry<String, Long> e : hist.entrySet() ) |
| if( e.getValue() > max ) { |
| mode = e.getKey(); |
| max = e.getValue(); |
| } |
| ret.add("-2 " + colix + " " + mode); |
| } |
| //compute global mean of numerical feature |
| else if( _encoder.getMethod(colix) == MVMethod.GLOBAL_MEAN ) { |
| KahanObject kbuff = new KahanObject(0, 0); |
| KahanPlus kplus = KahanPlus.getKahanPlusFnObject(); |
| int count = 0; |
| while( iter.hasNext() ) { |
| ColumnMetadata cmeta = iter.next(); |
| kplus.execute2(kbuff, Double.parseDouble(cmeta.getMvValue())); |
| count += cmeta.getNumDistinct(); |
| } |
| if( count > 0 ) |
| ret.add("-2 " + colix + " " + String.valueOf(kbuff._sum/count)); |
| } |
| //pass-through constant replacement value |
| else if( _encoder.getMethod(colix) == MVMethod.CONSTANT ) { |
| if( iter.hasNext() ) |
| ret.add("-2 " + colix + " " + iter.next().getMvValue()); |
| } |
| |
| return ret.iterator(); |
| } |
| } |
| } |