blob: 828eaef0c1a40f07e7a7485a3951614de8b60287 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.sysds.lops.rewrite;
import org.apache.sysds.conf.ConfigurationManager;
import org.apache.sysds.hops.rewrite.HopRewriteUtils;
import org.apache.sysds.lops.Lop;
import org.apache.sysds.parser.ForStatement;
import org.apache.sysds.parser.StatementBlock;
import org.apache.sysds.parser.WhileStatement;
import org.apache.sysds.parser.WhileStatementBlock;
import java.util.List;
/**
 * Lop rewrite that renumbers the Lops of a statement block after other Lop
 * rewrites may have inserted new nodes.
 * <p>
 * IDs are reassigned by a depth-first post-order walk: every input of a Lop
 * receives its new ID before the Lop itself does. The pass does nothing unless
 * prefetch, broadcast or checkpoint is enabled, because those are the options
 * checked here as the ones that add new Lop nodes.
 */
public class RewriteFixIDs extends LopRewriteRule
{
	@Override
	public List<StatementBlock> rewriteLOPinStatementBlock(StatementBlock sb)
	{
		// Only the prefetch, broadcast and checkpoint rewrites introduce new
		// Lop nodes; when none of them is enabled the existing IDs are kept.
		boolean lopsMayHaveBeenAdded = ConfigurationManager.isPrefetchEnabled()
			|| ConfigurationManager.isBroadcastEnabled()
			|| ConfigurationManager.isCheckpointEnabled();
		if (lopsMayHaveBeenAdded) {
			// Some rewrites place their new Lops inside the body of a
			// last-level loop, so renumber that body rather than the loop block.
			StatementBlock target = HopRewriteUtils.isLastLevelLoopStatementBlock(sb)
				? firstBodyBlock(sb)
				: sb;
			assignNewIDStatementBlock(target);
		}
		return List.of(sb);
	}

	/** Renumbering happens per statement block; the list-level hook passes the input through unchanged. */
	@Override
	public List<StatementBlock> rewriteLOPinStatementBlocks(List<StatementBlock> sbs) {
		return sbs;
	}

	/**
	 * Returns the first statement block of a loop's body. For a
	 * {@link WhileStatementBlock} this is the body of its {@link WhileStatement};
	 * for any other block, statement 0 is treated as a {@link ForStatement}.
	 */
	private static StatementBlock firstBodyBlock(StatementBlock loop) {
		if (loop instanceof WhileStatementBlock) {
			WhileStatement whileStmt = (WhileStatement) loop.getStatement(0);
			return whileStmt.getBody().get(0);
		}
		ForStatement forStmt = (ForStatement) loop.getStatement(0);
		return forStmt.getBody().get(0);
	}

	/**
	 * Renumbers every Lop reachable from the root Lops of {@code sb}, then
	 * calls {@code resetVisitStatus} on each root. A block with a null or
	 * empty Lop list is left untouched.
	 */
	private void assignNewIDStatementBlock(StatementBlock sb) {
		List<Lop> roots = sb.getLops();
		if (roots == null || roots.isEmpty())
			return;
		for (Lop dagRoot : roots)
			assignNewIDLop(dagRoot);
		// NOTE(review): assumes resetVisitStatus clears the flags of the whole
		// DAG below each root, so later passes start clean. Confirm in Lop.
		roots.forEach(Lop::resetVisitStatus);
	}

	/**
	 * Post-order DFS step. Recursively renumbers every input of {@code lop},
	 * then {@code lop} itself. Each node is marked visited, so a Lop shared by
	 * several consumers is renumbered only once.
	 */
	private void assignNewIDLop(Lop lop) {
		if (!lop.isVisited()) {
			// A leaf has no inputs, so the loop body never runs and the leaf
			// is renumbered immediately.
			for (Lop operand : lop.getInputs())
				assignNewIDLop(operand);
			lop.setNewID();
			lop.setVisited();
		}
	}
}