| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.sysds.hops.rewrite; |
| |
| import java.util.ArrayList; |
| |
| import org.apache.sysds.hops.Hop; |
| import org.apache.sysds.hops.UnaryOp; |
| import org.apache.sysds.common.Types.OpOp1; |
| import org.apache.sysds.common.Types.ValueType; |
| |
| /** |
| * Rule: RemoveUnnecessaryCasts. For all value type casts check |
| * if they are really necessary. If both cast input and output |
| * type are the same, the cast itself is redundant. |
| * |
| * There are two use case where this can arise: (1) automatically |
| * inserted casts on function inlining (in case of unknown value |
| * types), and (2) explicit script-level value type casts, that |
| * might be redundant according to the read input data. |
| * |
| * The benefit of this rewrite is negligible for scalars. However, |
| * when we support matrices with different value types, those casts |
| * might refer to matrices and with that incur large costs. |
| * |
| */ |
| public class RewriteRemoveUnnecessaryCasts extends HopRewriteRule |
| { |
| @Override |
| public ArrayList<Hop> rewriteHopDAGs(ArrayList<Hop> roots, ProgramRewriteStatus state) { |
| if( roots == null ) |
| return null; |
| for( Hop h : roots ) |
| rule_RemoveUnnecessaryCasts( h ); |
| return roots; |
| } |
| |
| @Override |
| public Hop rewriteHopDAG(Hop root, ProgramRewriteStatus state) { |
| if( root == null ) |
| return root; |
| rule_RemoveUnnecessaryCasts( root ); |
| return root; |
| } |
| |
| @SuppressWarnings("unchecked") |
| private void rule_RemoveUnnecessaryCasts( Hop hop ) |
| { |
| //check mark processed |
| if( hop.isVisited() ) |
| return; |
| |
| //recursively process childs |
| ArrayList<Hop> inputs = hop.getInput(); |
| for( int i=0; i<inputs.size(); i++ ) |
| rule_RemoveUnnecessaryCasts( inputs.get(i) ); |
| |
| //remove unnecessary value type cast |
| if( hop instanceof UnaryOp && HopRewriteUtils.isValueTypeCast(((UnaryOp)hop).getOp()) ) |
| { |
| Hop in = hop.getInput().get(0); |
| ValueType vtIn = in.getValueType(); //type cast input |
| ValueType vtOut = hop.getValueType(); //type cast output |
| |
| //if input/output types match, no need to cast |
| if( vtIn == vtOut && vtIn != ValueType.UNKNOWN ) |
| { |
| ArrayList<Hop> parents = hop.getParent(); |
| for( int i=0; i<parents.size(); i++ ) //for all parents |
| { |
| Hop p = parents.get(i); |
| ArrayList<Hop> pin = p.getInput(); |
| for( int j=0; j<pin.size(); j++ ) //for all parent childs |
| { |
| Hop pinj = pin.get(j); |
| if( pinj == hop ) //found parent ref |
| { |
| //rehang cast input as child of cast consumer |
| pin.remove( j ); //remove cast ref |
| pin.add(j, in); //add ref to cast input |
| in.getParent().remove(hop); //remove cast from cast input parents |
| in.getParent().add( p ); //add parent to cast input parents |
| } |
| } |
| } |
| parents.clear(); |
| } |
| } |
| |
| //remove unnecessary data type casts |
| if( hop instanceof UnaryOp && hop.getInput().get(0) instanceof UnaryOp ) { |
| UnaryOp uop1 = (UnaryOp) hop; |
| UnaryOp uop2 = (UnaryOp) hop.getInput().get(0); |
| if( (uop1.getOp()==OpOp1.CAST_AS_MATRIX && uop2.getOp()==OpOp1.CAST_AS_SCALAR) |
| || (uop1.getOp()==OpOp1.CAST_AS_SCALAR && uop2.getOp()==OpOp1.CAST_AS_MATRIX) ) { |
| Hop input = uop2.getInput().get(0); |
| //rewire parents |
| ArrayList<Hop> parents = (ArrayList<Hop>) hop.getParent().clone(); |
| for( Hop p : parents ) |
| HopRewriteUtils.replaceChildReference(p, hop, input); |
| } |
| } |
| |
| //mark processed |
| hop.setVisited(); |
| } |
| } |