blob: d5a027910f97a3f707a6d420cc1c355275dded32 [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.apache.ratis.examples.counter.server;
import org.apache.ratis.examples.counter.CounterCommand;
import org.apache.ratis.proto.RaftProtos.LogEntryProto;
import org.apache.ratis.proto.RaftProtos.RaftPeerRole;
import org.apache.ratis.protocol.Message;
import org.apache.ratis.protocol.RaftGroupId;
import org.apache.ratis.server.RaftServer;
import org.apache.ratis.server.protocol.TermIndex;
import org.apache.ratis.server.raftlog.RaftLog;
import org.apache.ratis.server.storage.RaftStorage;
import org.apache.ratis.statemachine.TransactionContext;
import org.apache.ratis.statemachine.impl.BaseStateMachine;
import org.apache.ratis.statemachine.impl.SimpleStateMachineStorage;
import org.apache.ratis.statemachine.impl.SingleFileSnapshotInfo;
import org.apache.ratis.util.JavaUtils;
import java.io.BufferedInputStream;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;
import java.nio.file.Files;
import java.nio.file.Path;
import java.util.concurrent.CompletableFuture;
import java.util.concurrent.atomic.AtomicInteger;
/**
* A {@link org.apache.ratis.statemachine.StateMachine} implementation for the {@link CounterServer}.
* This class maintain a {@link AtomicInteger} object as a state and accept two commands:
*
* - {@link CounterCommand#GET} is a readonly command
* which is handled by the {@link #query(Message)} method.
*
* - {@link CounterCommand#INCREMENT} is a transactional command
* which is handled by the {@link #applyTransaction(TransactionContext)} method.
*/
public class CounterStateMachine extends BaseStateMachine {
/** The state of the {@link CounterStateMachine}. */
static class CounterState {
private final TermIndex applied;
private final int counter;
CounterState(TermIndex applied, int counter) {
this.applied = applied;
this.counter = counter;
}
TermIndex getApplied() {
return applied;
}
int getCounter() {
return counter;
}
}
private final SimpleStateMachineStorage storage = new SimpleStateMachineStorage();
private final AtomicInteger counter = new AtomicInteger(0);
/** @return the current state. */
private synchronized CounterState getState() {
return new CounterState(getLastAppliedTermIndex(), counter.get());
}
private synchronized void updateState(TermIndex applied, int counterValue) {
updateLastAppliedTermIndex(applied);
counter.set(counterValue);
}
private synchronized int incrementCounter(TermIndex termIndex) {
updateLastAppliedTermIndex(termIndex);
return counter.incrementAndGet();
}
/**
* Initialize the state machine storage and then load the state.
*
* @param server the server running this state machine
* @param groupId the id of the {@link org.apache.ratis.protocol.RaftGroup}
* @param raftStorage the storage of the server
* @throws IOException if it fails to load the state.
*/
@Override
public void initialize(RaftServer server, RaftGroupId groupId, RaftStorage raftStorage) throws IOException {
super.initialize(server, groupId, raftStorage);
storage.init(raftStorage);
reinitialize();
}
/**
* Simply load the latest snapshot.
*
* @throws IOException if it fails to load the state.
*/
@Override
public void reinitialize() throws IOException {
load(storage.getLatestSnapshot());
}
/**
* Store the current state as a snapshot file in the {@link #storage}.
*
* @return the index of the snapshot
*/
@Override
public long takeSnapshot() {
//get the current state
final CounterState state = getState();
final long index = state.getApplied().getIndex();
//create a file with a proper name to store the snapshot
final File snapshotFile = storage.getSnapshotFile(state.getApplied().getTerm(), index);
//write the counter value into the snapshot file
try (ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream(
Files.newOutputStream(snapshotFile.toPath())))) {
out.writeInt(state.getCounter());
} catch (IOException ioe) {
LOG.warn("Failed to write snapshot file \"" + snapshotFile
+ "\", last applied index=" + state.getApplied());
}
//return the index of the stored snapshot (which is the last applied one)
return index;
}
/**
* Load the state of the state machine from the {@link #storage}.
*
* @param snapshot the information of the snapshot being loaded
* @return the index of the snapshot or -1 if snapshot is invalid
* @throws IOException if it failed to read from storage
*/
private long load(SingleFileSnapshotInfo snapshot) throws IOException {
//check null
if (snapshot == null) {
LOG.warn("The snapshot info is null.");
return RaftLog.INVALID_LOG_INDEX;
}
//check if the snapshot file exists.
final Path snapshotPath = snapshot.getFile().getPath();
if (!Files.exists(snapshotPath)) {
LOG.warn("The snapshot file {} does not exist for snapshot {}", snapshotPath, snapshot);
return RaftLog.INVALID_LOG_INDEX;
}
//read the TermIndex from the snapshot file name
final TermIndex last = SimpleStateMachineStorage.getTermIndexFromSnapshotFile(snapshotPath.toFile());
//read the counter value from the snapshot file
final int counterValue;
try (ObjectInputStream in = new ObjectInputStream(new BufferedInputStream(Files.newInputStream(snapshotPath)))) {
counterValue = in.readInt();
}
//update state
updateState(last, counterValue);
return last.getIndex();
}
/**
* Process {@link CounterCommand#GET}, which gets the counter value.
*
* @param request the GET request
* @return a {@link Message} containing the current counter value as a {@link String}.
*/
@Override
public CompletableFuture<Message> query(Message request) {
final String command = request.getContent().toStringUtf8();
if (!CounterCommand.GET.matches(command)) {
return JavaUtils.completeExceptionally(new IllegalArgumentException("Invalid Command: " + command));
}
return CompletableFuture.completedFuture(Message.valueOf(counter.toString()));
}
/**
* Apply the {@link CounterCommand#INCREMENT} by incrementing the counter object.
*
* @param trx the transaction context
* @return the message containing the updated counter value
*/
@Override
public CompletableFuture<Message> applyTransaction(TransactionContext trx) {
final LogEntryProto entry = trx.getLogEntry();
//check if the command is valid
final String command = entry.getStateMachineLogEntry().getLogData().toStringUtf8();
if (!CounterCommand.INCREMENT.matches(command)) {
return JavaUtils.completeExceptionally(new IllegalArgumentException("Invalid Command: " + command));
}
//increment the counter and update term-index
final TermIndex termIndex = TermIndex.valueOf(entry);
final long incremented = incrementCounter(termIndex);
//if leader, log the incremented value and the term-index
if (trx.getServerRole() == RaftPeerRole.LEADER) {
LOG.info("{}: Increment to {}", termIndex, incremented);
}
//return the new value of the counter to the client
return CompletableFuture.completedFuture(Message.valueOf(String.valueOf(incremented)));
}
}