| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.ratis.examples.arithmetic; |
| |
| import org.apache.ratis.examples.arithmetic.expression.Expression; |
| import org.apache.ratis.io.MD5Hash; |
| import org.apache.ratis.proto.RaftProtos.LogEntryProto; |
| import org.apache.ratis.proto.RaftProtos.RaftPeerRole; |
| import org.apache.ratis.protocol.Message; |
| import org.apache.ratis.protocol.RaftGroupId; |
| import org.apache.ratis.server.RaftServer; |
| import org.apache.ratis.server.protocol.TermIndex; |
| import org.apache.ratis.server.raftlog.RaftLog; |
| import org.apache.ratis.server.storage.FileInfo; |
| import org.apache.ratis.server.storage.RaftStorage; |
| import org.apache.ratis.statemachine.StateMachineStorage; |
| import org.apache.ratis.statemachine.TransactionContext; |
| import org.apache.ratis.statemachine.impl.BaseStateMachine; |
| import org.apache.ratis.statemachine.impl.SimpleStateMachineStorage; |
| import org.apache.ratis.statemachine.impl.SingleFileSnapshotInfo; |
| import org.apache.ratis.util.AutoCloseableLock; |
| import org.apache.ratis.util.FileUtils; |
| import org.apache.ratis.util.JavaUtils; |
| import org.apache.ratis.util.MD5FileUtil; |
| |
| import java.io.BufferedInputStream; |
| import java.io.BufferedOutputStream; |
| import java.io.File; |
| import java.io.IOException; |
| import java.io.ObjectInputStream; |
| import java.io.ObjectOutputStream; |
| import java.util.HashMap; |
| import java.util.Map; |
| import java.util.concurrent.CompletableFuture; |
| import java.util.concurrent.ConcurrentHashMap; |
| import java.util.concurrent.locks.ReentrantReadWriteLock; |
| |
| public class ArithmeticStateMachine extends BaseStateMachine { |
| private final Map<String, Double> variables = new ConcurrentHashMap<>(); |
| |
| private final SimpleStateMachineStorage storage = new SimpleStateMachineStorage(); |
| |
| private final ReentrantReadWriteLock lock = new ReentrantReadWriteLock(true); |
| |
| private AutoCloseableLock readLock() { |
| return AutoCloseableLock.acquire(lock.readLock()); |
| } |
| |
| private AutoCloseableLock writeLock() { |
| return AutoCloseableLock.acquire(lock.writeLock()); |
| } |
| |
| void reset() { |
| variables.clear(); |
| setLastAppliedTermIndex(null); |
| } |
| |
| @Override |
| public void initialize(RaftServer server, RaftGroupId groupId, |
| RaftStorage raftStorage) throws IOException { |
| super.initialize(server, groupId, raftStorage); |
| this.storage.init(raftStorage); |
| loadSnapshot(storage.getLatestSnapshot()); |
| } |
| |
| @Override |
| public void reinitialize() throws IOException { |
| close(); |
| loadSnapshot(storage.getLatestSnapshot()); |
| } |
| |
| @Override |
| public long takeSnapshot() { |
| final Map<String, Double> copy; |
| final TermIndex last; |
| try(AutoCloseableLock readLock = readLock()) { |
| copy = new HashMap<>(variables); |
| last = getLastAppliedTermIndex(); |
| } |
| |
| final File snapshotFile = storage.getSnapshotFile(last.getTerm(), last.getIndex()); |
| LOG.info("Taking a snapshot to file {}", snapshotFile); |
| |
| try(ObjectOutputStream out = new ObjectOutputStream(new BufferedOutputStream( |
| FileUtils.newOutputStream(snapshotFile)))) { |
| out.writeObject(copy); |
| } catch(IOException ioe) { |
| LOG.warn("Failed to write snapshot file \"" + snapshotFile |
| + "\", last applied index=" + last); |
| } |
| |
| final MD5Hash md5 = MD5FileUtil.computeAndSaveMd5ForFile(snapshotFile); |
| final FileInfo info = new FileInfo(snapshotFile.toPath(), md5); |
| storage.updateLatestSnapshot(new SingleFileSnapshotInfo(info, last)); |
| return last.getIndex(); |
| } |
| |
| public long loadSnapshot(SingleFileSnapshotInfo snapshot) throws IOException { |
| if (snapshot == null) { |
| LOG.warn("The snapshot info is null."); |
| return RaftLog.INVALID_LOG_INDEX; |
| } |
| final File snapshotFile = snapshot.getFile().getPath().toFile(); |
| if (!snapshotFile.exists()) { |
| LOG.warn("The snapshot file {} does not exist for snapshot {}", snapshotFile, snapshot); |
| return RaftLog.INVALID_LOG_INDEX; |
| } |
| |
| // verify md5 |
| final MD5Hash md5 = snapshot.getFile().getFileDigest(); |
| if (md5 != null) { |
| MD5FileUtil.verifySavedMD5(snapshotFile, md5); |
| } |
| |
| final TermIndex last = SimpleStateMachineStorage.getTermIndexFromSnapshotFile(snapshotFile); |
| try(AutoCloseableLock writeLock = writeLock(); |
| ObjectInputStream in = new ObjectInputStream(new BufferedInputStream( |
| FileUtils.newInputStream(snapshotFile)))) { |
| reset(); |
| setLastAppliedTermIndex(last); |
| variables.putAll(JavaUtils.cast(in.readObject())); |
| } catch (ClassNotFoundException e) { |
| throw new IllegalStateException("Failed to load " + snapshot, e); |
| } |
| return last.getIndex(); |
| } |
| |
| @Override |
| public StateMachineStorage getStateMachineStorage() { |
| return storage; |
| } |
| |
| @Override |
| public CompletableFuture<Message> query(Message request) { |
| final Expression q = Expression.Utils.bytes2Expression(request.getContent().toByteArray(), 0); |
| final Double result; |
| try(AutoCloseableLock readLock = readLock()) { |
| result = q.evaluate(variables); |
| } |
| final Expression r = Expression.Utils.double2Expression(result); |
| LOG.debug("QUERY: {} = {}", q, r); |
| return CompletableFuture.completedFuture(Expression.Utils.toMessage(r)); |
| } |
| |
| @Override |
| public void close() { |
| reset(); |
| } |
| |
| @Override |
| public CompletableFuture<Message> applyTransaction(TransactionContext trx) { |
| final LogEntryProto entry = trx.getLogEntryUnsafe(); |
| final AssignmentMessage assignment = new AssignmentMessage(entry.getStateMachineLogEntry().getLogData()); |
| |
| final long index = entry.getIndex(); |
| final Double result; |
| try(AutoCloseableLock writeLock = writeLock()) { |
| result = assignment.evaluate(variables); |
| updateLastAppliedTermIndex(entry.getTerm(), index); |
| } |
| final Expression r = Expression.Utils.double2Expression(result); |
| final CompletableFuture<Message> f = CompletableFuture.completedFuture(Expression.Utils.toMessage(r)); |
| |
| final RaftPeerRole role = trx.getServerRole(); |
| if (role == RaftPeerRole.LEADER) { |
| LOG.info("{}:{}-{}: {} = {}", role, getId(), index, assignment, r); |
| } else { |
| LOG.debug("{}:{}-{}: {} = {}", role, getId(), index, assignment, r); |
| } |
| if (LOG.isTraceEnabled()) { |
| LOG.trace("{}-{}: variables={}", getId(), index, variables); |
| } |
| return f; |
| } |
| } |