| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| |
| package org.apache.wayang.core.mapping; |
| |
import java.util.Collection;
import java.util.LinkedList;
import java.util.Objects;
import java.util.function.Predicate;
import org.apache.wayang.core.plan.wayangplan.InputSlot;
import org.apache.wayang.core.plan.wayangplan.Operator;
import org.apache.wayang.core.plan.wayangplan.OperatorBase;
import org.apache.wayang.core.plan.wayangplan.OutputSlot;
import org.apache.wayang.core.plan.wayangplan.Slot;
import org.apache.wayang.core.plan.wayangplan.TopDownPlanVisitor;
import org.apache.wayang.core.types.DataSetType;
| |
| /** |
| * An operator pattern matches to a class of operator instances. |
| */ |
| public class OperatorPattern<T extends Operator> extends OperatorBase { |
| |
| /** |
| * Identifier for this instance to identify {@link OperatorMatch}es. |
| */ |
| private final String name; |
| |
| /** |
| * {@link Operator} type matched by this instance. |
| */ |
| private final Class<?> operatorClass; |
| |
| /** |
| * Whether subclasses of {@link #operatorClass} also match. |
| */ |
| private final boolean isMatchSubclasses; |
| |
| /** |
| * Whether broadcast {@link InputSlot}s are allowed. |
| */ |
| private final boolean isAllowBroadcasts; |
| |
| /** |
| * Additional predicates to test in order to establish a match. |
| */ |
| private final Collection<Predicate<T>> additionalTests = new LinkedList<>(); |
| |
| /** |
| * Creates a new instance. |
| * |
| * @param name used to identify the new instance (e.g., in {@link SubplanMatch}es) |
| * @param exampleOperator serves as template of the {@link Operator}s to match; use {@link DataSetType#none()} to |
| * state that the {@link DataSetType} of a certain {@link Slot} are not to be matched |
| * @param isMatchSubclasses whether to match subclasses of the {@code exampleOperator} |
| */ |
| public OperatorPattern(String name, |
| T exampleOperator, |
| boolean isMatchSubclasses) { |
| |
| super(exampleOperator.getNumInputs(), exampleOperator.getNumOutputs(), |
| exampleOperator.isSupportingBroadcastInputs()); |
| |
| this.name = name; |
| InputSlot.mock(exampleOperator, this); |
| OutputSlot.mock(exampleOperator, this); |
| |
| this.operatorClass = exampleOperator.getClass(); |
| this.isAllowBroadcasts = exampleOperator.isSupportingBroadcastInputs(); |
| this.isMatchSubclasses = isMatchSubclasses; |
| } |
| |
| /** |
| * Test whether this pattern matches a given operator. |
| * |
| * @param operator the operator to match or {@code null}, which represents the absence of an operator to match |
| * @return whether the operator matches |
| */ |
| @SuppressWarnings("unchecked") |
| public OperatorMatch match(Operator operator) { |
| if (operator == null) return null; |
| |
| // Only match by the class so far. |
| if (this.matchOperatorClass(operator) && this.matchSlots(operator) && this.matchAdditionalTests((T) operator)) { |
| this.checkSanity(operator); |
| return new OperatorMatch(this, operator); |
| } |
| |
| return null; |
| } |
| |
| /** |
| * Checks whether the {@link Operator} {@link Class} that of the given {@link Operator}. |
| * |
| * @param operator that should be matched with |
| * @return whether this instance and the {@code operator} match |
| */ |
| private boolean matchOperatorClass(Operator operator) { |
| return this.isMatchSubclasses ? |
| this.operatorClass.isAssignableFrom(operator.getClass()) : |
| this.operatorClass.equals(operator.getClass()); |
| } |
| |
| /** |
| * Checks whether the {@link Operator} {@link Class} that of the given {@link Operator}. TODO |
| * |
| * @param operator that should be matched with |
| * @return whether this instance and the {@code operator} match |
| */ |
| private boolean matchSlots(Operator operator) { |
| // Check whether the InputSlots match. |
| int inputIndex; |
| for (inputIndex = 0; inputIndex < this.getNumInputs(); inputIndex++) { |
| InputSlot<?> slotPattern = this.getInput(inputIndex); |
| final InputSlot<?> testSlot = operator.getInput(slotPattern.getIndex()); |
| if (!this.matchSlot(slotPattern, testSlot)) { |
| return false; |
| } |
| } |
| // Take special care for broadcasts. |
| for (; inputIndex < operator.getNumInputs(); inputIndex++) { |
| if (operator.getInput(inputIndex).isBroadcast() && !this.isAllowBroadcasts) return false; |
| } |
| |
| // Check whether the OutputSlots match. |
| for (int outputIndex = 0; outputIndex < this.getNumOutputs(); outputIndex++) { |
| OutputSlot<?> slotPattern = this.getOutput(outputIndex); |
| final OutputSlot<?> testSlot = operator.getOutput(slotPattern.getIndex()); |
| if (!this.matchSlot(slotPattern, testSlot)) { |
| return false; |
| } |
| } |
| return true; |
| } |
| |
| /** |
| * Test whether a given test {@link Slot} matches a pattern {@link Slot}. |
| * |
| * @param slotPattern that should be matched against |
| * @param testSlot will be matched |
| * @return whether the {@code testSlot} matches the {@code slotPattern} |
| */ |
| private boolean matchSlot(Slot<?> slotPattern, Slot<?> testSlot) { |
| return slotPattern.getType().isNone() || slotPattern.getType().isSupertypeOf(testSlot.getType()); |
| } |
| |
| /** |
| * Test whether the {@link #additionalTests} are satisfied. |
| * |
| * @param operator that should be tested |
| * @return the tests are satisfied |
| */ |
| private boolean matchAdditionalTests(T operator) { |
| return this.additionalTests.stream().allMatch(test -> test.test((T) operator)); |
| } |
| |
| private void checkSanity(Operator operator) { |
| if (this.getNumRegularInputs() != operator.getNumRegularInputs()) { |
| throw new IllegalStateException(String.format("%s expected %d inputs, but matched %s with %d inputs.", |
| this, this.getNumRegularInputs(), operator, operator.getNumRegularInputs())); |
| } |
| if (this.getNumOutputs() != operator.getNumOutputs()) { |
| throw new IllegalStateException("Matched an operator with different numbers of outputs."); |
| } |
| } |
| |
| /** |
| * Add an additional {@link Predicate} that must be satisfied in order to establish matches with {@link Operator}s. |
| * |
| * @param additionalTest the {@link Predicate} |
| * @return this instance |
| */ |
| public OperatorPattern<T> withAdditionalTest(Predicate<T> additionalTest) { |
| this.additionalTests.add(additionalTest); |
| return this; |
| } |
| |
| public String getName() { |
| return this.name; |
| } |
| |
| @Override |
| public <Payload, Return> Return accept(TopDownPlanVisitor<Payload, Return> visitor, OutputSlot<?> outputSlot, Payload payload) { |
| throw new RuntimeException("Pattern does not accept visitors."); |
| } |
| |
| @Override |
| public String toString() { |
| return String.format("%s[%d->%d, %s, id=%x]", |
| this.getClass().getSimpleName(), |
| this.getNumInputs(), |
| this.getNumOutputs(), |
| this.operatorClass.getSimpleName(), |
| this.hashCode()); |
| } |
| } |