blob: 785c5f1d07046b6a5a1c111a0910ddc991519ee9 [file] [log] [blame]
/**
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.aurora.common.util;
import java.util.Arrays;
import java.util.List;
import java.util.Set;
import java.util.concurrent.locks.Lock;
import java.util.concurrent.locks.ReadWriteLock;
import java.util.concurrent.locks.ReentrantReadWriteLock;
import java.util.function.Consumer;
import com.google.common.base.Function;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.base.Predicates;
import com.google.common.collect.HashMultimap;
import com.google.common.collect.ImmutableSet;
import com.google.common.collect.Iterables;
import com.google.common.collect.Lists;
import com.google.common.collect.Multimap;
import org.apache.aurora.common.base.Consumers;
import org.apache.commons.lang.builder.HashCodeBuilder;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
import static com.google.common.base.Preconditions.checkArgument;
import static com.google.common.base.Preconditions.checkNotNull;
import static org.apache.aurora.common.base.MorePreconditions.checkNotBlank;
/**
 * Represents a state machine that is not necessarily a Finite State Machine.
 *
 * <p>A transition is permitted only if it was registered through the {@link Builder}. An attempt
 * to perform an unregistered transition either throws {@link IllegalStateTransitionException}
 * (the default) or is rejected by returning {@code false}. Which of the two happens depends on
 * {@link Builder#throwOnBadTransition(boolean)}.
 *
 * @param <T> The type of objects that the caller uses to represent states.
 *
 * TODO(William Farner): Consider merging the stats-tracking ala PipelineStats into this.
 */
public class StateMachine<T> {
  private static final Logger LOG = LoggerFactory.getLogger(StateMachine.class);

  private final String name;

  // Stores mapping from states to the states that the machine is allowed to transition into.
  // The builder hands each machine its own copy, so this map is not shared with the builder.
  private final Multimap<T, T> stateTransitions;
  private final Consumer<Transition<T>> transitionCallback;
  private final boolean throwOnBadTransition;

  // Volatile so that getState() may read it without locking; it is only written under writeLock.
  private volatile T currentState;

  private final Lock readLock;
  private final Lock writeLock;

  private StateMachine(String name,
      T initialState,
      Multimap<T, T> stateTransitions,
      Consumer<Transition<T>> transitionCallback,
      boolean throwOnBadTransition) {
    this.name = name;
    this.currentState = initialState;
    this.stateTransitions = stateTransitions;
    this.transitionCallback = transitionCallback;
    this.throwOnBadTransition = throwOnBadTransition;

    ReadWriteLock stateLock = new ReentrantReadWriteLock(true /* fair */);
    readLock = stateLock.readLock();
    writeLock = stateLock.writeLock();
  }

  /**
   * Gets the name of this state machine.
   *
   * @return The state machine name.
   */
  public String getName() {
    return name;
  }

  /**
   * Fetches the state that the machine is currently in.
   *
   * @return Current state.
   */
  public T getState() {
    return currentState;
  }

  /**
   * Checks that the current state is the {@code expectedState} and throws if it is not.
   *
   * @param expectedState The expected state.
   * @throws IllegalStateException if the current state is not the {@code expectedState}.
   */
  public void checkState(T expectedState) {
    checkState(ImmutableSet.of(expectedState));
  }

  /**
   * Checks that the current state is one of the {@code allowedStates} and throws if it is not.
   *
   * @param allowedStates The allowed states; must be non-null and non-empty.
   * @throws IllegalArgumentException if {@code allowedStates} is empty.
   * @throws IllegalStateException if the current state is not one of the {@code allowedStates}.
   */
  public void checkState(Set<T> allowedStates) {
    checkNotNull(allowedStates);
    checkArgument(!allowedStates.isEmpty(), "At least one possible state must be provided.");

    readLock.lock();
    try {
      if (!allowedStates.contains(currentState)) {
        throw new IllegalStateException(
            String.format("In state %s, expected to be in %s.", currentState, allowedStates));
      }
    } finally {
      readLock.unlock();
    }
  }

  /**
   * Transitions the machine into state {@code nextState}.
   *
   * <p>After the attempt, the configured transition callbacks are notified with the attempted
   * {@link Transition}, unless an exception was thrown. The callbacks run after the lock has been
   * released.
   *
   * @param nextState The state to move into.
   * @throws IllegalStateTransitionException If the state transition is not allowed and the machine
   *     was configured to throw on bad transitions (the default). Callbacks are not notified in
   *     this case.
   * @return {@code true} if the transition was allowed, {@code false} otherwise.
   */
  public boolean transition(T nextState) throws IllegalStateTransitionException {
    T previousState;
    boolean transitionAllowed;

    writeLock.lock();
    try {
      // Snapshot the source state while holding the lock, so the reported transition's 'from'
      // is exactly the state the check was made against, even under concurrent transitions.
      previousState = currentState;
      transitionAllowed = stateTransitions.containsEntry(previousState, nextState);
      if (transitionAllowed) {
        currentState = nextState;
      } else if (throwOnBadTransition) {
        throw new IllegalStateTransitionException(
            String.format("State transition from %s to %s is not allowed.", previousState,
                nextState));
      }
    } finally {
      writeLock.unlock();
    }

    // Notify outside the lock so callback code never executes while the state lock is held.
    transitionCallback.accept(new Transition<T>(previousState, nextState, transitionAllowed));
    return transitionAllowed;
  }

  /**
   * Thrown by {@link StateMachine#transition(Object)} when a transition that was not registered
   * is attempted and the machine is configured to throw on bad transitions.
   */
  public static class IllegalStateTransitionException extends IllegalStateException {
    public IllegalStateTransitionException(String msg) {
      super(msg);
    }
  }

  /**
   * Convenience method to create a builder object.
   *
   * @param <T> Type of builder to create.
   * @param name Name of the state machine to create a builder for.
   * @return New builder.
   */
  public static <T> Builder<T> builder(String name) {
    return new Builder<T>(name);
  }

  /**
   * A state and its allowed transitions (if any) and (optional) callback.
   *
   * @param <T> State type.
   */
  public static class Rule<T> {
    private final T from;
    private final Set<T> to;
    private final Consumer<Transition<T>> callback;

    private Rule(T from) {
      this(from, ImmutableSet.<T>of());
    }

    private Rule(T from, Set<T> to) {
      this(from, to, Consumers.<Transition<T>>noop());
    }

    private Rule(T from, Set<T> to, Consumer<Transition<T>> callback) {
      this.from = checkNotNull(from);
      this.to = checkNotNull(to);
      this.callback = checkNotNull(callback);
    }

    /**
     * Associates a callback to be triggered after any attempt to transition from this state is
     * made.
     *
     * @param callback Callback to signal.
     * @return A new rule that is identical to this rule, but with the provided callback.
     */
    public Rule<T> withCallback(Consumer<Transition<T>> callback) {
      return new Rule<T>(from, to, callback);
    }

    /**
     * A helper class when building a transition rule, to define the allowed transitions.
     *
     * @param <T> State type.
     */
    public static class AllowedTransition<T> {
      private final Rule<T> rule;

      private AllowedTransition(Rule<T> rule) {
        this.rule = rule;
      }

      /**
       * Associates a single allowed transition with this state.
       *
       * @param state Allowed transition state.
       * @return A new rule that is identical to the original, but only allowing a transition to
       *     the provided state.
       */
      public Rule<T> to(T state) {
        return new Rule<T>(rule.from, ImmutableSet.<T>of(state), rule.callback);
      }

      /**
       * Associates multiple transitions with this state.
       *
       * @param state An allowed transition state.
       * @param additionalStates Additional states that may be transitioned to.
       * @return A new rule that is identical to the original, but only allowing a transition to
       *     the provided states.
       */
      public Rule<T> to(T state, T... additionalStates) {
        // Carry the callback over, matching the single-state overload above.
        return new Rule<T>(
            rule.from,
            ImmutableSet.copyOf(Lists.asList(state, additionalStates)),
            rule.callback);
      }

      /**
       * Allows no transitions to be performed from this state.
       *
       * @return The original rule.
       */
      public Rule<T> noTransitions() {
        return rule;
      }
    }

    /**
     * Creates a new transition rule.
     *
     * @param state State to create and associate transitions with.
     * @param <T> State type.
     * @return A new transition rule builder.
     */
    public static <T> AllowedTransition<T> from(T state) {
      return new AllowedTransition<T>(new Rule<T>(state));
    }
  }

  /**
   * Builder to create a state machine.
   *
   * @param <T> State type.
   */
  public static class Builder<T> {
    private final String name;
    private T initialState;
    private final Multimap<T, T> stateTransitions = HashMultimap.create();
    private final List<Consumer<Transition<T>>> transitionCallbacks = Lists.newArrayList();
    private boolean throwOnBadTransition = true;

    public Builder(String name) {
      this.name = checkNotBlank(name);
    }

    /**
     * Sets the initial state for the state machine.
     *
     * @param state Initial state.
     * @return A reference to the builder.
     */
    public Builder<T> initialState(T state) {
      checkNotNull(state);
      initialState = state;
      return this;
    }

    /**
     * Adds a state and its allowed transitions.
     *
     * @param rule The state and transition rule to add.
     * @return A reference to the builder.
     */
    public Builder<T> addState(Rule<T> rule) {
      return addState(rule.callback, rule.from, rule.to);
    }

    /**
     * Adds a state and its allowed transitions.
     *
     * <p>At least one transition must be added to the builder before {@link #build()} is called.
     * States that have no allowed transitions (terminal states) do not need to be added
     * explicitly.
     *
     * @param callback Callback to notify of any transition attempted from the state.
     * @param state State to add.
     * @param transitionStates Allowed transitions from {@code state}; must not contain nulls.
     * @return A reference to the builder.
     */
    public Builder<T> addState(Consumer<Transition<T>> callback, T state,
        Set<T> transitionStates) {
      checkNotNull(callback);
      checkNotNull(state);
      Preconditions.checkArgument(Iterables.all(transitionStates, Predicates.notNull()));

      stateTransitions.putAll(state, transitionStates);

      @SuppressWarnings("unchecked")  // Generic varargs array creation for a single state.
      Predicate<Transition<T>> filter = Transition.from(state);
      onTransition(filter, callback);
      return this;
    }

    /**
     * Varargs version of {@link #addState(Consumer, Object, java.util.Set)}.
     *
     * @param callback Callback to notify of any transition attempted from the state.
     * @param state State to add.
     * @param transitionStates Allowed transitions from {@code state}.
     * @return A reference to the builder.
     */
    public Builder<T> addState(Consumer<Transition<T>> callback, T state,
        T... transitionStates) {
      Set<T> states = ImmutableSet.copyOf(transitionStates);
      Preconditions.checkArgument(Iterables.all(states, Predicates.notNull()));

      return addState(callback, state, states);
    }

    /**
     * Adds a state and its allowed transitions, without a state-specific callback.
     *
     * <p>At least one transition must be added to the builder before {@link #build()} is called.
     * States that have no allowed transitions (terminal states) do not need to be added
     * explicitly.
     *
     * @param state State to add.
     * @param transitionStates Allowed transitions from {@code state}.
     * @return A reference to the builder.
     */
    public Builder<T> addState(T state, T... transitionStates) {
      return addState(Consumers.<Transition<T>>noop(), state, transitionStates);
    }

    private void onTransition(Predicate<Transition<T>> transitionFilter,
        Consumer<Transition<T>> handler) {
      onAnyTransition(Consumers.filter(transitionFilter, handler));
    }

    /**
     * Adds a callback to be executed for every state transition, including invalid transitions
     * that are attempted. When the machine throws on bad transitions, invalid attempts raise an
     * exception instead, and callbacks are not notified of them.
     *
     * @param handler Callback to notify of transition attempts.
     * @return A reference to the builder.
     */
    public Builder<T> onAnyTransition(Consumer<Transition<T>> handler) {
      transitionCallbacks.add(handler);
      return this;
    }

    /**
     * Adds a log message for every state transition that is attempted.
     *
     * @return A reference to the builder.
     */
    public Builder<T> logTransitions() {
      return onAnyTransition(
          transition -> LOG.info("{} state machine transition {}", name, transition));
    }

    /**
     * Allows the caller to specify whether {@link IllegalStateTransitionException} should be thrown
     * when a bad state transition is attempted (the default behavior).
     *
     * @param throwOnBadTransition Whether an exception should be thrown when a bad state transition
     *     is attempted.
     * @return A reference to the builder.
     */
    public Builder<T> throwOnBadTransition(boolean throwOnBadTransition) {
      this.throwOnBadTransition = throwOnBadTransition;
      return this;
    }

    /**
     * Builds the state machine.
     *
     * @return A reference to the prepared state machine.
     * @throws IllegalStateException if no initial state was specified.
     * @throws IllegalArgumentException if no state transitions were specified.
     */
    public StateMachine<T> build() {
      Preconditions.checkState(initialState != null, "Initial state must be specified.");
      checkArgument(!stateTransitions.isEmpty(), "No state transitions were specified.");
      // Give the machine its own copy of the transition table. Otherwise later addState() calls
      // on this builder would silently mutate an already-built (and possibly shared) machine.
      return new StateMachine<T>(name,
          initialState,
          HashMultimap.create(stateTransitions),
          Consumers.combine(transitionCallbacks),
          throwOnBadTransition);
    }
  }

  /**
   * Representation of a state transition.
   *
   * @param <T> State type.
   */
  public static class Transition<T> {
    private final T from;
    private final T to;
    private final boolean allowed;

    public Transition(T from, T to, boolean allowed) {
      this.from = checkNotNull(from);
      this.to = checkNotNull(to);
      this.allowed = allowed;
    }

    private static <T> Function<Transition<T>, T> from() {
      return transition -> transition.from;
    }

    private static <T> Function<Transition<T>, T> to() {
      return transition -> transition.to;
    }

    private static <T> Predicate<Transition<T>> oneSideFilter(
        Function<Transition<T>, T> extractor, final T... states) {
      checkArgument(Iterables.all(Arrays.asList(states), Predicates.notNull()));

      return Predicates.compose(Predicates.in(ImmutableSet.copyOf(states)), extractor);
    }

    /**
     * Creates a predicate that returns {@code true} for transitions from the given states.
     *
     * @param states States to filter on.
     * @param <T> State type.
     * @return A from-state filter.
     */
    public static <T> Predicate<Transition<T>> from(final T... states) {
      return oneSideFilter(Transition.<T>from(), states);
    }

    /**
     * Creates a predicate that returns {@code true} for transitions to the given states.
     *
     * @param states States to filter on.
     * @param <T> State type.
     * @return A to-state filter.
     */
    public static <T> Predicate<Transition<T>> to(final T... states) {
      return oneSideFilter(Transition.<T>to(), states);
    }

    /**
     * Creates a predicate that returns {@code true} for a specific state transition.
     *
     * @param from From state.
     * @param to To state.
     * @param <T> State type.
     * @return A state transition filter.
     */
    public static <T> Predicate<Transition<T>> transition(final T from, final T to) {
      @SuppressWarnings("unchecked")
      Predicate<Transition<T>> fromFilter = from(from);
      @SuppressWarnings("unchecked")
      Predicate<Transition<T>> toFilter = to(to);
      return Predicates.and(fromFilter, toFilter);
    }

    public T getFrom() {
      return from;
    }

    public T getTo() {
      return to;
    }

    public boolean isAllowed() {
      return allowed;
    }

    /**
     * Checks whether this transition represents a state change, which means that the 'to' state is
     * not equal to the 'from' state, and the transition is allowed.
     *
     * @return {@code true} if the state was changed, {@code false} otherwise.
     */
    public boolean isValidStateChange() {
      return isAllowed() && !from.equals(to);
    }

    /**
     * Two transitions are equal when their 'from' and 'to' states are equal. The 'allowed' flag
     * does not participate in equality, consistent with {@link #hashCode()}.
     */
    @Override
    public boolean equals(Object o) {
      if (o == this) {
        return true;
      }
      if (!(o instanceof Transition)) {
        return false;
      }

      Transition<?> other = (Transition<?>) o;
      return from.equals(other.from) && to.equals(other.to);
    }

    @Override
    public int hashCode() {
      return new HashCodeBuilder()
          .append(from)
          .append(to)
          .toHashCode();
    }

    @Override
    public String toString() {
      String str = from.toString() + " -> " + to.toString();
      if (!isAllowed()) {
        str += " (not allowed)";
      }
      return str;
    }
  }
}