| /* |
| * Licensed to the Apache Software Foundation (ASF) under one |
| * or more contributor license agreements. See the NOTICE file |
| * distributed with this work for additional information |
| * regarding copyright ownership. The ASF licenses this file |
| * to you under the Apache License, Version 2.0 (the |
| * "License"); you may not use this file except in compliance |
| * with the License. You may obtain a copy of the License at |
| * |
| * http://www.apache.org/licenses/LICENSE-2.0 |
| * |
| * Unless required by applicable law or agreed to in writing, |
| * software distributed under the License is distributed on an |
| * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY |
| * KIND, either express or implied. See the License for the |
| * specific language governing permissions and limitations |
| * under the License. |
| */ |
| |
| package org.apache.datasketches.vector.regression; |
| |
| import static org.apache.datasketches.memory.UnsafeUtil.unsafe; |
| |
| import org.apache.datasketches.memory.Memory; |
| import org.apache.datasketches.memory.WritableMemory; |
| import org.apache.datasketches.vector.MatrixFamily; |
| import org.checkerframework.checker.nullness.qual.NonNull; |
| |
| /** |
| * Computes mean and variance for each of d dimensions of an input vector using Welford's online algorithm, |
| * as described in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance |
| * <p> |
| * For serialized images, multi-byte integers (<tt>int</tt> and <tt>long</tt>) are stored in native byte |
| * order. All <tt>byte</tt> values are treated as unsigned.</p> |
| * |
| * <p>An empty object requires 8 bytes. A non-empty sketch requires 16 bytes |
| * of preamble.</p> |
| * |
| * <pre> |
| * Long || Start Byte Adr: |
| * Adr: |
| * || 0 | 1 | 2 | 3 | 4 | 5 | 6 | 7 | |
| * 0 || Preamble_Longs | SerVer | FamID | Flags |---------Vector Dim. (d)---------| |
| * |
| * || 8 | 9 | 10 | 11 | 12 | 13 | 14 | 15 | |
| * 1 ||-------------------------Num. Vectors Processed (n)--------------------------| |
| * |
| * || 16 | 17 | 18 | 19 | 20 | 21 | 22 | 23 | |
| * 2 ||---------------------------Intercept (target mean)---------------------------| |
| * |
| * || 24 | 25 | 26 | 27 | 28 | 29 | 30 | 31 | |
| * 3 ||-----------------------------start of mean array-----------------------------| |
| * </pre> |
| * |
| * @author Jon Malkin |
| |
| */ |
| public class VectorNormalizer { |
| private final int d_; |
| private final double[] mean_; |
| private final double[] M2_; |
| private double intercept_; |
| private long n_; |
| |
| // Preamble byte Addresses |
| static final int PREAMBLE_LONGS_BYTE = 0; |
| static final int SER_VER_BYTE = 1; |
| static final int FAMILY_BYTE = 2; |
| static final int FLAGS_BYTE = 3; |
| static final int D_INT = 4; |
| static final int N_LONG = 8; |
| |
| // flag bit masks |
| static final int EMPTY_FLAG_MASK = 4; |
| |
| // Other constants |
| static final int SER_VER = 1; |
| |
| |
| /** |
| * Creates a new, empty VectorNormalizer |
| * @param d The number of dimensions the VectorNormalizer holds |
| */ |
| public VectorNormalizer(final int d) { |
| if (d < 1) |
| throw new IllegalArgumentException("d cannot be < 1. Found: " + d); |
| |
| d_ = d; |
| n_ = 0; |
| intercept_ = 0.0; |
| mean_ = new double[d_]; |
| M2_ = new double[d_]; |
| } |
| |
| /** |
| * Copy constructor |
| * @param other The VectorNormalizer to copy |
| */ |
| public VectorNormalizer(final VectorNormalizer other) { |
| d_ = other.d_; |
| n_ = other.n_; |
| intercept_ = other.intercept_; |
| mean_ = other.mean_.clone(); |
| M2_ = other.M2_.clone(); |
| } |
| |
| private VectorNormalizer(final int d, final long n, final double[] mean, final double[] M2, final double intercept) { |
| d_ = d; |
| n_ = n; |
| intercept_ = intercept; |
| mean_ = mean; |
| M2_ = M2; |
| } |
| |
| /** |
| * Instantiates a VectorNormalizer object from a serialized image |
| * @param srcMem Memory containing the serialized image of a VectorNormalizer object |
| * @return A VectorNormalizer, or null if srcMem is null |
| */ |
| static VectorNormalizer heapify(final Memory srcMem) { |
| if (srcMem == null) { return null; } |
| |
| final int preLongs = getAndCheckPreLongs(srcMem); |
| if (preLongs < MatrixFamily.VECTORNORMALIZER.getMinPreLongs() |
| || preLongs > MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) { |
| throw new IllegalArgumentException("Possible corruption: Invalid number of preamble longs: " + preLongs); |
| } |
| |
| final int serVer = extractSerVer(srcMem); |
| if (serVer != SER_VER) { |
| throw new IllegalArgumentException("Invalid serialization version: " + serVer); |
| } |
| |
| final int family = extractFamilyID(srcMem); |
| if (family != MatrixFamily.VECTORNORMALIZER.getID()) { |
| throw new IllegalArgumentException("Possible corruption: Family id (" + family + ") " |
| + "is not a VectorNormalization image"); |
| } |
| |
| final boolean empty = (extractFlags(srcMem) & EMPTY_FLAG_MASK) > 0; |
| final int d = extractD(srcMem); |
| if (d < 1) |
| throw new IllegalArgumentException("Possible corruption: d cannot be < 1. Found: " + d); |
| |
| if (empty) { |
| if (preLongs != MatrixFamily.VECTORNORMALIZER.getMinPreLongs()) { |
| throw new IllegalArgumentException("Possible corruption: Empty flag set but header indicates image has data."); |
| } |
| return new VectorNormalizer(d); |
| } |
| |
| if (preLongs == MatrixFamily.VECTORNORMALIZER.getMinPreLongs()) { |
| throw new IllegalArgumentException("Possible corruption: Non-empty image too small to contain serialized data"); |
| } |
| |
| final long n = extractN(srcMem); |
| if (n <= 0) |
| throw new IllegalArgumentException("Possible corruption: n must be positive for a non-empty sketch. Found: " + n); |
| |
| long offsetBytes = (long) preLongs * Long.BYTES; |
| |
| // check capacity for the rest |
| final long bytesNeeded = offsetBytes + (((2L * d) + 1) * Double.BYTES); |
| if (srcMem.getCapacity() < bytesNeeded) { |
| throw new IllegalArgumentException( |
| "Possible Corruption: Size of Memory not large enough: Size: " + srcMem.getCapacity() |
| + ", Required: " + bytesNeeded); |
| } |
| |
| final double intercept = srcMem.getDouble(offsetBytes); |
| offsetBytes += Double.BYTES; |
| |
| final double[] mean = new double[d]; |
| srcMem.getDoubleArray(offsetBytes, mean, 0, d); |
| offsetBytes += (long) d * Double.BYTES; |
| |
| final double[] M2 = new double[d]; |
| srcMem.getDoubleArray(offsetBytes, M2, 0, d); |
| |
| return new VectorNormalizer(d, n, mean, M2, intercept); |
| } |
| |
| /** |
| * Returns an array of bytes with a serialized image of this object. |
| * @return A <tt>byte[]</tt> containing the serialized image of this object. |
| */ |
| public byte[] toByteArray() { |
| final boolean empty = isEmpty(); |
| final int familyId = MatrixFamily.VECTORNORMALIZER.getID(); |
| |
| final int preLongs = empty |
| ? MatrixFamily.VECTORNORMALIZER.getMinPreLongs() |
| : MatrixFamily.VECTORNORMALIZER.getMaxPreLongs(); |
| |
| final int outBytes = (preLongs * Long.BYTES) + (empty ? 0 : (1 + 2 * d_) * Double.BYTES); |
| final byte[] outArr = new byte[outBytes]; |
| final WritableMemory memOut = WritableMemory.wrap(outArr); |
| final Object memObj = memOut.getArray(); |
| final long memAddr = memOut.getCumulativeOffset(0L); |
| |
| insertPreLongs(memObj, memAddr, preLongs); |
| insertSerVer(memObj, memAddr, SER_VER); |
| insertFamilyID(memObj, memAddr, familyId); |
| insertFlags(memObj, memAddr, (empty ? EMPTY_FLAG_MASK : 0)); |
| insertD(memObj, memAddr, d_); |
| |
| if (!empty) { |
| insertN(memObj, memAddr, n_); |
| long offset = (long) preLongs * Long.BYTES; |
| memOut.putDouble(offset, intercept_); |
| offset += Double.BYTES; |
| memOut.putDoubleArray(offset, mean_, 0, d_); |
| offset += (long) d_ * Double.BYTES; |
| memOut.putDoubleArray(offset, M2_, 0, d_); |
| } |
| |
| return outArr; |
| } |
| |
| /** |
| * Returns true if the object has no data, otherwise false |
| * @return True if the object has no data, otherwise false. |
| */ |
| public boolean isEmpty() { |
| return n_ == 0; |
| } |
| |
| /** |
| * Returns the number of dimensions configured for this object |
| * @return The number of dimensions |
| */ |
| public long getD() { |
| return d_; |
| } |
| |
| /** |
| * Returns the number of input vectors processed by this object |
| * @return The number of input vectors processed |
| */ |
| public long getN() { |
| return n_; |
| } |
| |
| /** |
| * Returns the array of means held by this object |
| * @return The array of means |
| */ |
| public double[] getMean() { |
| if (n_ == 0) { |
| final double[] result = new double[d_]; |
| for (int i = 0; i < d_; ++i) { |
| result[i] = Double.NaN; |
| } |
| return result; |
| } else { |
| return mean_.clone(); |
| } |
| } |
| |
| /** |
| * Returns the mean of the target value, aka the intercept |
| * @return Mean of the target value |
| */ |
| public double getIntercept() { |
| if (n_ == 0) |
| return Double.NaN; |
| else |
| return intercept_; |
| } |
| |
| /** |
| * Returns the sample variance array represented in this object. Returns an array of NaN if N = 0 and an |
| * array of zeros if N = 1. |
| * @return The sample variance array represented in this object |
| */ |
| public double[] getSampleVariance() { |
| if (n_ == 0) { |
| final double[] result = new double[d_]; |
| for (int i = 0; i < d_; ++i) { |
| result[i] = Double.NaN; |
| } |
| return result; |
| } else if (n_ == 1) { |
| return new double[d_]; // array of zeros |
| } else { // n_ > 1 |
| double[] result = M2_.clone(); |
| for (int i = 0; i < d_; ++i) { |
| result[i] = M2_[i] / n_; |
| } |
| return result; |
| } |
| } |
| |
| /** |
| * Returns the population variance array represented in this object. Returns an array of NaN if N = 0 and an |
| * array of zeros if N = 1. |
| * @return The population variance array represented in this object |
| */ |
| public double[] getPopulationVariance() { |
| if (n_ == 0) { |
| final double[] result = new double[d_]; |
| for (int i = 0; i < d_; ++i) { |
| result[i] = Double.NaN; |
| } |
| return result; |
| } else if (n_ == 1) { |
| return new double[d_]; // array of zeros |
| } else { // n_ > 1 |
| double[] result = M2_.clone(); |
| for (int i = 0; i < d_; ++i) { |
| result[i] = M2_[i] / (n_ - 1); |
| } |
| return result; |
| } |
| } |
| |
| public void update(final double[] x, final double target) { |
| if (x == null) |
| return; |
| |
| if (x.length != d_) { |
| throw new IllegalArgumentException("Input vector length must be " + d_ + ". Found: " + x.length ); |
| } |
| |
| ++n_; |
| for (int i = 0; i < d_; ++i) { |
| double d1 = x[i] - mean_[i]; // x_i - oldMean_i |
| mean_[i] += d1 / n_; |
| double d2 = x[i] - mean_[i]; // x_i - newMean_i |
| M2_[i] += d1 * d2; |
| } |
| |
| double delta = target - intercept_; |
| intercept_ += delta / n_; |
| } |
| |
| public void merge(@NonNull final VectorNormalizer other) { |
| if (other.d_ != d_) |
| throw new IllegalArgumentException("Input VectorNormalizer must have d= " + d_ + ". Found: " + other.d_); |
| |
| long combinedN = n_ + other.n_; |
| double varCountScalar = (n_ * other.n_) / (double) combinedN; // n_A * n_B / (n_A + n_B) |
| intercept_ = ((n_ * intercept_) + (other.n_ * other.intercept_)) / combinedN; |
| for (int i = 0; i < d_; ++i) { |
| double meanDiff = other.mean_[i] - mean_[i]; |
| mean_[i] = ((n_ * mean_[i]) + (other.n_ * other.mean_[i])) / combinedN; |
| M2_[i] += other.M2_[i] + meanDiff * meanDiff * varCountScalar; |
| } |
| n_ += other.n_; |
| } |
| |
| public int getSerializedSizeBytes() { |
| if (n_ == 0) { |
| return MatrixFamily.VECTORNORMALIZER.getMinPreLongs() * Long.BYTES; |
| } else { |
| return (MatrixFamily.VECTORNORMALIZER.getMaxPreLongs()) * Long.BYTES + ((1 + 2 * d_) * Double.BYTES); |
| } |
| } |
| |
| // Extraction methods |
| static int extractPreLongs(final Memory mem) { |
| return mem.getInt(PREAMBLE_LONGS_BYTE) & 0xFF; |
| } |
| |
| static int extractSerVer(final Memory mem) { |
| return mem.getInt(SER_VER_BYTE) & 0xFF; |
| } |
| |
| static int extractFamilyID(final Memory mem) { |
| return mem.getByte(FAMILY_BYTE) & 0xFF; |
| } |
| |
| static int extractFlags(final Memory mem) { |
| return mem.getByte(FLAGS_BYTE) & 0xFF; |
| } |
| |
| static int extractD(final Memory mem) { |
| return mem.getInt(D_INT); |
| } |
| |
| static long extractN(final Memory mem) { |
| return mem.getLong(N_LONG); |
| } |
| |
| |
| // Insertion methods |
| private void insertPreLongs(final Object memObj, final long memAddr, final int preLongs) { |
| unsafe.putByte(memObj, memAddr + PREAMBLE_LONGS_BYTE, (byte) preLongs); |
| } |
| |
| private void insertSerVer(final Object memObj, final long memAddr, final int serVer) { |
| unsafe.putByte(memObj, memAddr + SER_VER_BYTE, (byte) serVer); |
| } |
| |
| private void insertFamilyID(final Object memObj, final long memAddr, final int matrixFamId) { |
| unsafe.putByte(memObj, memAddr + FAMILY_BYTE, (byte) matrixFamId); |
| } |
| |
| private void insertFlags(final Object memObj, final long memAddr, final int flags) { |
| unsafe.putByte(memObj, memAddr + FLAGS_BYTE, (byte) flags); |
| } |
| |
| private void insertD(final Object memObj, final long memAddr, final int d) { |
| unsafe.putInt(memObj, memAddr + D_INT, d); |
| } |
| |
| private void insertN(final Object memObj, final long memAddr, final long n) { |
| unsafe.putLong(memObj, memAddr + N_LONG, n); |
| } |
| |
| /** |
| * Checks Memory for capacity to hold the preamble and returns the extracted preLongs. |
| * @param mem the given Memory |
| * @return the extracted prelongs value. |
| */ |
| private static int getAndCheckPreLongs(final Memory mem) { |
| final long cap = mem.getCapacity(); |
| if (cap < Long.BYTES) { throwNotBigEnough(cap, Long.BYTES); } |
| final int preLongs = extractPreLongs(mem); |
| final int required = Math.max(preLongs << 2, Long.BYTES); |
| if (cap < required) { throwNotBigEnough(cap, required); } |
| return preLongs; |
| } |
| |
| private static void throwNotBigEnough(final long cap, final int required) { |
| throw new IllegalArgumentException( |
| "Possible Corruption: Size of byte array or Memory not large enough: Size: " + cap |
| + ", Required: " + required); |
| } |
| } |