| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, software |
| * distributed under the License is distributed on an "AS IS" BASIS, |
| * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| * See the License for the specific language governing permissions and |
| * limitations under the License. |
| */ |
| package org.apache.ratis.examples.arithmetic; |
| |
| import org.apache.ratis.server.impl.MiniRaftCluster; |
| import org.apache.ratis.client.RaftClient; |
| import org.apache.ratis.examples.ParameterizedBaseTest; |
| import org.apache.ratis.examples.arithmetic.expression.DoubleValue; |
| import org.apache.ratis.examples.arithmetic.expression.Expression; |
| import org.apache.ratis.examples.arithmetic.expression.NullValue; |
| import org.apache.ratis.examples.arithmetic.expression.Variable; |
| import org.apache.ratis.protocol.RaftClientReply; |
| import org.apache.ratis.util.Slf4jUtils; |
| import org.apache.ratis.util.Preconditions; |
| import org.junit.jupiter.api.Assertions; |
| import org.junit.jupiter.params.ParameterizedTest; |
| import org.junit.jupiter.params.provider.MethodSource; |
| import org.slf4j.event.Level; |
| |
| import java.io.IOException; |
| import java.util.Collection; |
| |
| import static org.apache.ratis.examples.arithmetic.expression.BinaryExpression.Op.*; |
| import static org.apache.ratis.examples.arithmetic.expression.UnaryExpression.Op.SQRT; |
| import static org.apache.ratis.examples.arithmetic.expression.UnaryExpression.Op.SQUARE; |
| |
| public class TestArithmetic extends ParameterizedBaseTest { |
| { |
| Slf4jUtils.setLogLevel(ArithmeticStateMachine.LOG, Level.DEBUG); |
| } |
| |
| public static Collection<Object[]> data() { |
| return getMiniRaftClusters(ArithmeticStateMachine.class, 3); |
| } |
| |
| @ParameterizedTest |
| @MethodSource("data") |
| public void testPythagorean(MiniRaftCluster cluster) throws Exception { |
| setAndStart(cluster); |
| try (final RaftClient client = cluster.createClient()) { |
| runTestPythagorean(client, 3, 10); |
| } |
| } |
| |
| public static void runTestPythagorean( |
| RaftClient client, int start, int count) throws IOException { |
| Preconditions.assertTrue(count > 0, () -> "count = " + count + " <= 0"); |
| Preconditions.assertTrue(start >= 2, () -> "start = " + start + " < 2"); |
| |
| final Variable a = new Variable("a"); |
| final Variable b = new Variable("b"); |
| final Variable c = new Variable("c"); |
| final Expression pythagorean = SQRT.apply(ADD.apply(SQUARE.apply(a), SQUARE.apply(b))); |
| |
| final int end = start + 2*count; |
| for(int n = (start & 1) == 0? start + 1: start; n < end; n += 2) { |
| int n2 = n*n; |
| int half_n2 = n2/2; |
| |
| assign(client, a, n); |
| assign(client, b, half_n2); |
| assign(client, c, pythagorean, (double)half_n2 + 1); |
| |
| assignNull(client, a); |
| assignNull(client, b); |
| assignNull(client, c); |
| } |
| } |
| |
| @ParameterizedTest |
| @MethodSource("data") |
| public void testGaussLegendre(MiniRaftCluster cluster) throws Exception { |
| setAndStart(cluster); |
| try (final RaftClient client = cluster.createClient()) { |
| runGaussLegendre(client); |
| } |
| } |
| |
| void runGaussLegendre(RaftClient client) throws IOException { |
| defineVariable(client, "a0", 1); |
| defineVariable(client, "b0", DIV.apply(1, SQRT.apply(2))); |
| defineVariable(client, "t0", DIV.apply(1, 4)); |
| defineVariable(client, "p0", 1); |
| |
| double previous = 0; |
| boolean converged = false; |
| for(int i = 1; i < 8; i++) { |
| final int i_1 = i - 1; |
| final Variable a0 = new Variable("a" + i_1); |
| final Variable b0 = new Variable("b" + i_1); |
| final Variable t0 = new Variable("t" + i_1); |
| final Variable p0 = new Variable("p" + i_1); |
| final Variable a1 = defineVariable(client, "a"+i, DIV.apply(ADD.apply(a0, b0), 2)); |
| final Variable b1 = defineVariable(client, "b"+i, SQRT.apply(MULT.apply(a0, b0))); |
| final Variable t1 = defineVariable(client, "t"+i, |
| SUBTRACT.apply(t0, MULT.apply(p0, SQUARE.apply(SUBTRACT.apply(a0, a1))))); |
| final Variable p1 = defineVariable(client, "p"+i, MULT.apply(2, p0)); |
| |
| final Variable pi_i = new Variable("pi_"+i); |
| final Expression e = assign(client, pi_i, DIV.apply(SQUARE.apply(a1), t0)); |
| final double pi = e.evaluate(null); |
| |
| if (converged) { |
| Assertions.assertEquals(pi, previous); |
| } else if (pi == previous) { |
| converged = true; |
| } |
| LOG.info("{} = {}, converged? {}", pi_i, pi, converged); |
| previous = pi; |
| } |
| Assertions.assertTrue(converged); |
| } |
| |
| static Variable defineVariable(RaftClient client, String name, double value) throws IOException { |
| final Variable x = new Variable(name); |
| assign(client, x, value); |
| return x; |
| } |
| |
| static Variable defineVariable(RaftClient client, String name, Expression e) throws IOException { |
| final Variable x = new Variable(name); |
| assign(client, x, e, null); |
| return x; |
| } |
| |
| static Expression assign(RaftClient client, Variable x, double value) throws IOException { |
| return assign(client, x, new DoubleValue(value), value); |
| } |
| |
| static void assignNull(RaftClient client, Variable x) throws IOException { |
| final Expression e = assign(client, x, NullValue.getInstance()); |
| Assertions.assertEquals(NullValue.getInstance(), e); |
| } |
| |
| static Expression assign(RaftClient client, Variable x, Expression e) throws IOException { |
| return assign(client, x, e, null); |
| } |
| |
| static Expression assign(RaftClient client, Variable x, Expression e, Double expected) throws IOException { |
| final RaftClientReply r = client.io().send(x.assign(e)); |
| return assertRaftClientReply(r, expected); |
| } |
| |
| static Expression assertRaftClientReply(RaftClientReply reply, Double expected) { |
| Assertions.assertTrue(reply.isSuccess()); |
| final Expression e = Expression.Utils.bytes2Expression( |
| reply.getMessage().getContent().toByteArray(), 0); |
| if (expected != null) { |
| Assertions.assertEquals(expected, e.evaluate(null)); |
| } |
| return e; |
| } |
| } |