/*
 * Licensed to the Apache Software Foundation (ASF) under one or more
 * contributor license agreements.  See the NOTICE file distributed with
 * this work for additional information regarding copyright ownership.
 * The ASF licenses this file to You under the Apache License, Version 2.0
 * (the "License"); you may not use this file except in compliance with
 * the License.  You may obtain a copy of the License at
 *
 *      http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */
| |
| package org.apache.commons.math4.legacy.ode; |
| |
import java.io.Serializable;
import java.util.Arrays;

import org.apache.commons.math4.legacy.core.MathArrays;
import org.apache.commons.math4.legacy.core.RealFieldElement;
import org.apache.commons.math4.legacy.exception.DimensionMismatchException;
import org.apache.commons.math4.legacy.exception.MathIllegalArgumentException;
import org.apache.commons.math4.legacy.exception.util.LocalizedFormats;
| |
| /** |
| * Class mapping the part of a complete state or derivative that pertains |
| * to a set of differential equations. |
| * <p> |
| * Instances of this class are guaranteed to be immutable. |
| * </p> |
| * @see FieldExpandableODE |
| * @param <T> the type of the field elements |
| * @since 3.6 |
| */ |
| public class FieldEquationsMapper<T extends RealFieldElement<T>> implements Serializable { |
| |
| /** Serializable UID. */ |
| private static final long serialVersionUID = 20151114L; |
| |
| /** Start indices of the components. */ |
| private final int[] start; |
| |
| /** Create a mapper by adding a new equation to another mapper. |
| * <p> |
| * The new equation will have index {@code mapper.}{@link #getNumberOfEquations()}, |
| * or 0 if {@code mapper} is null. |
| * </p> |
| * @param mapper former mapper, with one equation less (null for first equation) |
| * @param dimension dimension of the equation state vector |
| */ |
| FieldEquationsMapper(final FieldEquationsMapper<T> mapper, final int dimension) { |
| final int index = (mapper == null) ? 0 : mapper.getNumberOfEquations(); |
| this.start = new int[index + 2]; |
| if (mapper == null) { |
| start[0] = 0; |
| } else { |
| System.arraycopy(mapper.start, 0, start, 0, index + 1); |
| } |
| start[index + 1] = start[index] + dimension; |
| } |
| |
| /** Get the number of equations mapped. |
| * @return number of equations mapped |
| */ |
| public int getNumberOfEquations() { |
| return start.length - 1; |
| } |
| |
| /** Return the dimension of the complete set of equations. |
| * <p> |
| * The complete set of equations correspond to the primary set plus all secondary sets. |
| * </p> |
| * @return dimension of the complete set of equations |
| */ |
| public int getTotalDimension() { |
| return start[start.length - 1]; |
| } |
| |
| /** Map a state to a complete flat array. |
| * @param state state to map |
| * @return flat array containing the mapped state, including primary and secondary components |
| */ |
| public T[] mapState(final FieldODEState<T> state) { |
| final T[] y = MathArrays.buildArray(state.getTime().getField(), getTotalDimension()); |
| int index = 0; |
| insertEquationData(index, state.getState(), y); |
| while (++index < getNumberOfEquations()) { |
| insertEquationData(index, state.getSecondaryState(index), y); |
| } |
| return y; |
| } |
| |
| /** Map a state derivative to a complete flat array. |
| * @param state state to map |
| * @return flat array containing the mapped state derivative, including primary and secondary components |
| */ |
| public T[] mapDerivative(final FieldODEStateAndDerivative<T> state) { |
| final T[] yDot = MathArrays.buildArray(state.getTime().getField(), getTotalDimension()); |
| int index = 0; |
| insertEquationData(index, state.getDerivative(), yDot); |
| while (++index < getNumberOfEquations()) { |
| insertEquationData(index, state.getSecondaryDerivative(index), yDot); |
| } |
| return yDot; |
| } |
| |
| /** Map flat arrays to a state and derivative. |
| * @param t time |
| * @param y state array to map, including primary and secondary components |
| * @param yDot state derivative array to map, including primary and secondary components |
| * @return mapped state |
| * @exception DimensionMismatchException if an array does not match total dimension |
| */ |
| public FieldODEStateAndDerivative<T> mapStateAndDerivative(final T t, final T[] y, final T[] yDot) |
| throws DimensionMismatchException { |
| |
| if (y.length != getTotalDimension()) { |
| throw new DimensionMismatchException(y.length, getTotalDimension()); |
| } |
| |
| if (yDot.length != getTotalDimension()) { |
| throw new DimensionMismatchException(yDot.length, getTotalDimension()); |
| } |
| |
| final int n = getNumberOfEquations(); |
| int index = 0; |
| final T[] state = extractEquationData(index, y); |
| final T[] derivative = extractEquationData(index, yDot); |
| if (n < 2) { |
| return new FieldODEStateAndDerivative<>(t, state, derivative); |
| } else { |
| final T[][] secondaryState = MathArrays.buildArray(t.getField(), n - 1, -1); |
| final T[][] secondaryDerivative = MathArrays.buildArray(t.getField(), n - 1, -1); |
| while (++index < getNumberOfEquations()) { |
| secondaryState[index - 1] = extractEquationData(index, y); |
| secondaryDerivative[index - 1] = extractEquationData(index, yDot); |
| } |
| return new FieldODEStateAndDerivative<>(t, state, derivative, secondaryState, secondaryDerivative); |
| } |
| } |
| |
| /** Extract equation data from a complete state or derivative array. |
| * @param index index of the equation, must be between 0 included and |
| * {@link #getNumberOfEquations()} (excluded) |
| * @param complete complete state or derivative array from which |
| * equation data should be retrieved |
| * @return equation data |
| * @exception MathIllegalArgumentException if index is out of range |
| * @exception DimensionMismatchException if complete state has not enough elements |
| */ |
| public T[] extractEquationData(final int index, final T[] complete) |
| throws MathIllegalArgumentException, DimensionMismatchException { |
| checkIndex(index); |
| final int begin = start[index]; |
| final int end = start[index + 1]; |
| if (complete.length < end) { |
| throw new DimensionMismatchException(complete.length, end); |
| } |
| final int dimension = end - begin; |
| final T[] equationData = MathArrays.buildArray(complete[0].getField(), dimension); |
| System.arraycopy(complete, begin, equationData, 0, dimension); |
| return equationData; |
| } |
| |
| /** Insert equation data into a complete state or derivative array. |
| * @param index index of the equation, must be between 0 included and |
| * {@link #getNumberOfEquations()} (excluded) |
| * @param equationData equation data to be inserted into the complete array |
| * @param complete placeholder where to put equation data (only the |
| * part corresponding to the equation will be overwritten) |
| * @exception DimensionMismatchException if either array has not enough elements |
| */ |
| public void insertEquationData(final int index, T[] equationData, T[] complete) |
| throws DimensionMismatchException { |
| checkIndex(index); |
| final int begin = start[index]; |
| final int end = start[index + 1]; |
| final int dimension = end - begin; |
| if (complete.length < end) { |
| throw new DimensionMismatchException(complete.length, end); |
| } |
| if (equationData.length != dimension) { |
| throw new DimensionMismatchException(equationData.length, dimension); |
| } |
| System.arraycopy(equationData, 0, complete, begin, dimension); |
| } |
| |
| /** Check equation index. |
| * @param index index of the equation, must be between 0 included and |
| * {@link #getNumberOfEquations()} (excluded) |
| * @exception MathIllegalArgumentException if index is out of range |
| */ |
| private void checkIndex(final int index) throws MathIllegalArgumentException { |
| if (index < 0 || index > start.length - 2) { |
| throw new MathIllegalArgumentException(LocalizedFormats.ARGUMENT_OUTSIDE_DOMAIN, |
| index, 0, start.length - 2); |
| } |
| } |
| |
| } |