[SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with non-double band data (#1402)

Co-authored-by: Kristin Cowalcijk <bo@wherobots.com>
diff --git a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
index 564077a..b5fb469 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
@@ -78,12 +78,11 @@
             throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
         }
 
-        Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
         if (bandIndex == numBands + 1) {
-            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass, noDataValue);
+            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues, noDataValue);
         }
         else {
-            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass, noDataValue, true);
+            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues, noDataValue, true);
         }
     }
 
@@ -94,12 +93,11 @@
             throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
         }
 
-        Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
         if (bandIndex == numBands + 1) {
-            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass);
+            return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues);
         }
         else {
-            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass);
+            return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues);
         }
     }
 
diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
index 4ffac4d..41b159a 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
@@ -135,16 +135,16 @@
         if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
             int[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (int[]) null);
             if (numBands + 1 == toRasterIndex) {
-                return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue);
+                return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
             } else {
-                return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue, false);
+                return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
             }
         } else {
             double[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (double[]) null);
             if (numBands + 1 == toRasterIndex) {
-                return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue);
+                return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
             } else {
-                return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue, false);
+                return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
             }
         }
     }
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index e775ee4..7f67708 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,8 +32,6 @@
 
 import javax.media.jai.RenderedImageAdapter;
 import java.awt.image.RenderedImage;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
 import java.io.IOException;
 import java.io.Serializable;
 import java.net.URI;
@@ -178,22 +176,4 @@
             return state.restore();
         }
     }
-
-    public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) {
-        Kryo kryo = kryos.get();
-        ByteArrayOutputStream baos = new ByteArrayOutputStream();
-        Output output = new Output(baos);
-        GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
-        serializer.write(kryo, output, sampleDimension);
-        output.close();
-        return baos.toByteArray();
-    }
-
-    public static GridSampleDimension deserializeGridSampleDimension(byte[] data) {
-        Kryo kryo = kryos.get();
-        Input input = new Input(new ByteArrayInputStream(data));
-        GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
-        return serializer.read(kryo, input, GridSampleDimension.class);
-    }
-
 }
diff --git a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
index 7cc6c50..ce6c3b6 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
@@ -556,7 +556,7 @@
      * @param bandValues
      * @return
      */
-    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues, Double noDataValue) {
+    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues, Double noDataValue) {
         // Get the original image and its properties
         RenderedImage originalImage = gridCoverage2D.getRenderedImage();
         Raster raster = getRaster(originalImage);
@@ -565,17 +565,19 @@
         // Copy the raster data and append the new band values
         for (int i = 0; i < raster.getWidth(); i++) {
             for (int j = 0; j < raster.getHeight(); j++) {
-                if (bandValues instanceof Double[]) {
+                if (bandValues instanceof double[]) {
+                    double[] values = (double[]) bandValues;
                     double[] pixels = raster.getPixel(i, j, (double[]) null);
                     double[] copiedPixels = new double[pixels.length + 1];
                     System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
-                    copiedPixels[pixels.length] = (double) bandValues[j * raster.getWidth() + i];
+                    copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, copiedPixels);
-                } else if (bandValues instanceof Integer[]) {
+                } else if (bandValues instanceof int[]) {
+                    int[] values = (int[]) bandValues;
                     int[] pixels = raster.getPixel(i, j, (int[]) null);
                     int[] copiedPixels = new int[pixels.length + 1];
                     System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
-                    copiedPixels[pixels.length] = (int) bandValues[j * raster.getWidth() + i];
+                    copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, copiedPixels);
                 }
             }
@@ -594,11 +596,11 @@
         return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
     }
 
-    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues) {
+    public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues) {
         return copyRasterAndAppendBand(gridCoverage2D, bandValues, null);
     }
 
-    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean removeNoDataIfNull) {
+    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean removeNoDataIfNull) {
         // Do not allow the band index to be out of bounds
         ensureBand(gridCoverage2D, bandIndex);
         // Get the original image and its properties
@@ -608,13 +610,15 @@
         // Copy the raster data and replace the band values
         for (int i = 0; i < raster.getWidth(); i++) {
             for (int j = 0; j < raster.getHeight(); j++) {
-                if (bandValues instanceof Double[]) {
+                if (bandValues instanceof double[]) {
+                    double[] values = (double[]) bandValues;
                     double[] bands = raster.getPixel(i, j, (double[]) null);
-                    bands[bandIndex - 1] = (double) bandValues[j * raster.getWidth() + i];
+                    bands[bandIndex - 1] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, bands);
-                } else if (bandValues instanceof Integer[]) {
+                } else if (bandValues instanceof int[]) {
+                    int[] values = (int[]) bandValues;
                     int[] bands = raster.getPixel(i, j, (int[]) null);
-                    bands[bandIndex - 1] = (int) bandValues[j * raster.getWidth() + i];
+                    bands[bandIndex - 1] = values[j * raster.getWidth() + i];
                     wr.setPixel(i, j, bands);
                 }
             }
@@ -629,7 +633,7 @@
         return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
     }
 
-    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues) {
+    public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues) {
         return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues, null, false);
     }
 
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 3bf1326..d8a6f4b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,13 +19,12 @@
 
 package org.apache.spark.sql.sedona_sql.expressions.raster
 
-import org.apache.sedona.common.raster.serde.Serde
 import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
 import org.apache.sedona.common.utils.RasterUtils
+import org.apache.sedona.sql.utils.RasterSerializer
 import org.apache.spark.sql.Encoder
 import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
 import org.apache.spark.sql.expressions.Aggregator
-import org.geotools.coverage.GridSampleDimension
 import org.geotools.coverage.grid.GridCoverage2D
 
 import java.awt.image.WritableRaster
@@ -33,12 +32,10 @@
 import scala.collection.mutable.ArrayBuffer
 
 case class BandData(
-                     var bandsData: Array[Array[Double]],
-                     var index: Int,
-                     var serializedRaster: Array[Byte],
-                     var serializedSampleDimensions: Array[Array[Byte]]
-                   )
-
+  index: Int,
+  width: Int,
+  height: Int,
+  serializedRaster: Array[Byte])
 
 /**
  * Return a raster containing bands at given indexes from all rasters in a given column
@@ -48,37 +45,32 @@
   def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
 
   def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
-      val raster = input._1
-      val renderedImage = raster.getRenderedImage
-      val numBands = renderedImage.getSampleModel.getNumBands
-      val width = renderedImage.getWidth
-      val height = renderedImage.getHeight
+    val (raster, index) = input
+    val renderedImage = raster.getRenderedImage
+    val width = renderedImage.getWidth
+    val height = renderedImage.getHeight
+    val serializedRaster = RasterSerializer.serialize(raster)
+    raster.dispose(true)
 
-      // First check if this is the first raster to set dimensions or validate against existing dimensions
-      if (buffer.nonEmpty) {
-        val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
-        val refWidth = RasterAccessors.getWidth(referenceRaster)
-        val refHeight = RasterAccessors.getHeight(referenceRaster)
-        if (width != refWidth || height != refHeight) {
-          throw new IllegalArgumentException("All rasters must have the same dimensions")
-        }
+    // First check if this is the first raster to set dimensions or validate against existing dimensions
+    if (buffer.nonEmpty) {
+      val refWidth = buffer.head.width
+      val refHeight = buffer.head.height
+      if (width != refWidth || height != refHeight) {
+        throw new IllegalArgumentException("All rasters must have the same dimensions")
       }
-
-      // Extract data for each band
-      val rasterData = renderedImage.getData
-      val bandsData = Array.ofDim[Double](numBands, width * height)
-      val serializedSampleDimensions = new Array[Array[Byte]](numBands)
-
-      for (band <- 0 until numBands) {
-        bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
-        serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
-      }
-
-      buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
-      buffer
     }
 
+    buffer += BandData(index, width, height, serializedRaster)
+    buffer
+  }
+
   def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
+    if (buffer1.nonEmpty && buffer2.nonEmpty) {
+      if (buffer1.head.width != buffer2.head.width || buffer1.head.height != buffer2.head.height) {
+        throw new IllegalArgumentException("All rasters must have the same dimensions")
+      }
+    }
     val combined = ArrayBuffer.concat(buffer1, buffer2)
     if (combined.map(_.index).distinct.length != combined.length) {
       throw new IllegalArgumentException("Indexes shouldn't be repeated.")
@@ -95,24 +87,37 @@
       throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
     }
 
-    val totalBands = sortedMerged.map(_.bandsData.length).sum
-    val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
-    val width = RasterAccessors.getWidth(referenceRaster)
-    val height = RasterAccessors.getHeight(referenceRaster)
-    val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
-    val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
-    val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray
+    val rasters = sortedMerged.map(d => RasterSerializer.deserialize(d.serializedRaster))
+    try {
+      val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
+      val totalBands = rasters.map(_.getNumSampleDimensions).sum
+      val referenceRaster = rasters.head
+      val width = RasterAccessors.getWidth(referenceRaster)
+      val height = RasterAccessors.getHeight(referenceRaster)
+      val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
+      val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
 
-    var currentBand = 0
-    sortedMerged.foreach { bandData =>
-      bandData.bandsData.foreach { band =>
-        resultRaster.setSamples(0, 0, width, height, currentBand, band)
-        currentBand += 1
+      var currentBand = 0
+      rasters.foreach { raster =>
+        var bandIndex = 0
+        while (bandIndex < raster.getNumSampleDimensions) {
+          if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
+            val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Int](width * height))
+            resultRaster.setSamples(0, 0, width, height, currentBand, band)
+          } else {
+            val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Double](width * height))
+            resultRaster.setSamples(0, 0, width, height, currentBand, band)
+          }
+          currentBand += 1
+          bandIndex += 1
+        }
       }
-    }
 
-    val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
-    RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
+      val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
+      RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, false)
+    } finally {
+      rasters.foreach(_.dispose(true))
+    }
   }
 
   val serde = ExpressionEncoder[GridCoverage2D]