[SEDONA-549] Fix memory bloat issue of RS_Union_Aggr when working with non-double band data (#1402)
Co-authored-by: Kristin Cowalcijk <bo@wherobots.com>
diff --git a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
index 564077a..b5fb469 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/MapAlgebra.java
@@ -78,12 +78,11 @@
throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
}
- Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
- return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass, noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues, noDataValue);
}
else {
- return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass, noDataValue, true);
+ return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues, noDataValue, true);
}
}
@@ -94,12 +93,11 @@
throw new IllegalArgumentException("Band index is out of bounds. Must be between 1 and " + (numBands + 1) + ")");
}
- Double[] bandValuesClass = Arrays.stream(bandValues).boxed().toArray(Double[]::new);
if (bandIndex == numBands + 1) {
- return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValuesClass);
+ return RasterUtils.copyRasterAndAppendBand(rasterGeom, bandValues);
}
else {
- return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValuesClass);
+ return RasterUtils.copyRasterAndReplaceBand(rasterGeom, bandIndex, bandValues);
}
}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
index 4ffac4d..41b159a 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/RasterBandEditors.java
@@ -135,16 +135,16 @@
if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
int[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (int[]) null);
if (numBands + 1 == toRasterIndex) {
- return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
} else {
- return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Integer[]::new), noDataValue, false);
+ return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
}
} else {
double[] bandValues = rasterData.getSamples(0, 0, width, height, fromBand - 1, (double[]) null);
if (numBands + 1 == toRasterIndex) {
- return RasterUtils.copyRasterAndAppendBand(toRaster, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue);
+ return RasterUtils.copyRasterAndAppendBand(toRaster, bandValues, noDataValue);
} else {
- return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, Arrays.stream(bandValues).boxed().toArray(Double[]::new), noDataValue, false);
+ return RasterUtils.copyRasterAndReplaceBand(toRaster, fromBand, bandValues, noDataValue, false);
}
}
}
diff --git a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
index e775ee4..7f67708 100644
--- a/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
+++ b/common/src/main/java/org/apache/sedona/common/raster/serde/Serde.java
@@ -32,8 +32,6 @@
import javax.media.jai.RenderedImageAdapter;
import java.awt.image.RenderedImage;
-import java.io.ByteArrayInputStream;
-import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.Serializable;
import java.net.URI;
@@ -178,22 +176,4 @@
return state.restore();
}
}
-
- public static byte[] serializeGridSampleDimension(GridSampleDimension sampleDimension) {
- Kryo kryo = kryos.get();
- ByteArrayOutputStream baos = new ByteArrayOutputStream();
- Output output = new Output(baos);
- GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
- serializer.write(kryo, output, sampleDimension);
- output.close();
- return baos.toByteArray();
- }
-
- public static GridSampleDimension deserializeGridSampleDimension(byte[] data) {
- Kryo kryo = kryos.get();
- Input input = new Input(new ByteArrayInputStream(data));
- GridSampleDimensionSerializer serializer = new GridSampleDimensionSerializer();
- return serializer.read(kryo, input, GridSampleDimension.class);
- }
-
}
diff --git a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
index 7cc6c50..ce6c3b6 100644
--- a/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
+++ b/common/src/main/java/org/apache/sedona/common/utils/RasterUtils.java
@@ -556,7 +556,7 @@
* @param bandValues
* @return
*/
- public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues, Double noDataValue) {
+ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues, Double noDataValue) {
// Get the original image and its properties
RenderedImage originalImage = gridCoverage2D.getRenderedImage();
Raster raster = getRaster(originalImage);
@@ -565,17 +565,19 @@
// Copy the raster data and append the new band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
- if (bandValues instanceof Double[]) {
+ if (bandValues instanceof double[]) {
+ double[] values = (double[]) bandValues;
double[] pixels = raster.getPixel(i, j, (double[]) null);
double[] copiedPixels = new double[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
- copiedPixels[pixels.length] = (double) bandValues[j * raster.getWidth() + i];
+ copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, copiedPixels);
- } else if (bandValues instanceof Integer[]) {
+ } else if (bandValues instanceof int[]) {
+ int[] values = (int[]) bandValues;
int[] pixels = raster.getPixel(i, j, (int[]) null);
int[] copiedPixels = new int[pixels.length + 1];
System.arraycopy(pixels, 0, copiedPixels, 0, pixels.length);
- copiedPixels[pixels.length] = (int) bandValues[j * raster.getWidth() + i];
+ copiedPixels[pixels.length] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, copiedPixels);
}
}
@@ -594,11 +596,11 @@
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
}
- public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Number[] bandValues) {
+ public static GridCoverage2D copyRasterAndAppendBand(GridCoverage2D gridCoverage2D, Object bandValues) {
return copyRasterAndAppendBand(gridCoverage2D, bandValues, null);
}
- public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues, Double noDataValue, boolean removeNoDataIfNull) {
+ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues, Double noDataValue, boolean removeNoDataIfNull) {
// Do not allow the band index to be out of bounds
ensureBand(gridCoverage2D, bandIndex);
// Get the original image and its properties
@@ -608,13 +610,15 @@
// Copy the raster data and replace the band values
for (int i = 0; i < raster.getWidth(); i++) {
for (int j = 0; j < raster.getHeight(); j++) {
- if (bandValues instanceof Double[]) {
+ if (bandValues instanceof double[]) {
+ double[] values = (double[]) bandValues;
double[] bands = raster.getPixel(i, j, (double[]) null);
- bands[bandIndex - 1] = (double) bandValues[j * raster.getWidth() + i];
+ bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
- } else if (bandValues instanceof Integer[]) {
+ } else if (bandValues instanceof int[]) {
+ int[] values = (int[]) bandValues;
int[] bands = raster.getPixel(i, j, (int[]) null);
- bands[bandIndex - 1] = (int) bandValues[j * raster.getWidth() + i];
+ bands[bandIndex - 1] = values[j * raster.getWidth() + i];
wr.setPixel(i, j, bands);
}
}
@@ -629,7 +633,7 @@
return clone(wr, gridCoverage2D.getGridGeometry(), sampleDimensions, gridCoverage2D, null, true);
}
- public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Number[] bandValues) {
+ public static GridCoverage2D copyRasterAndReplaceBand(GridCoverage2D gridCoverage2D, int bandIndex, Object bandValues) {
return copyRasterAndReplaceBand(gridCoverage2D, bandIndex, bandValues, null, false);
}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
index 3bf1326..d8a6f4b 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/raster/AggregateFunctions.scala
@@ -19,13 +19,12 @@
package org.apache.spark.sql.sedona_sql.expressions.raster
-import org.apache.sedona.common.raster.serde.Serde
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
+import org.apache.sedona.sql.utils.RasterSerializer
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
-import org.geotools.coverage.GridSampleDimension
import org.geotools.coverage.grid.GridCoverage2D
import java.awt.image.WritableRaster
@@ -33,12 +32,10 @@
import scala.collection.mutable.ArrayBuffer
case class BandData(
- var bandsData: Array[Array[Double]],
- var index: Int,
- var serializedRaster: Array[Byte],
- var serializedSampleDimensions: Array[Array[Byte]]
- )
-
+ index: Int,
+ width: Int,
+ height: Int,
+ serializedRaster: Array[Byte])
/**
* Return a raster containing bands at given indexes from all rasters in a given column
@@ -48,37 +45,32 @@
def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()
def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
- val raster = input._1
- val renderedImage = raster.getRenderedImage
- val numBands = renderedImage.getSampleModel.getNumBands
- val width = renderedImage.getWidth
- val height = renderedImage.getHeight
+ val (raster, index) = input
+ val renderedImage = raster.getRenderedImage
+ val width = renderedImage.getWidth
+ val height = renderedImage.getHeight
+ val serializedRaster = RasterSerializer.serialize(raster)
+ raster.dispose(true)
- // First check if this is the first raster to set dimensions or validate against existing dimensions
- if (buffer.nonEmpty) {
- val referenceRaster = Serde.deserialize(buffer.head.serializedRaster)
- val refWidth = RasterAccessors.getWidth(referenceRaster)
- val refHeight = RasterAccessors.getHeight(referenceRaster)
- if (width != refWidth || height != refHeight) {
- throw new IllegalArgumentException("All rasters must have the same dimensions")
- }
+ // First check if this is the first raster to set dimensions or validate against existing dimensions
+ if (buffer.nonEmpty) {
+ val refWidth = buffer.head.width
+ val refHeight = buffer.head.height
+ if (width != refWidth || height != refHeight) {
+ throw new IllegalArgumentException("All rasters must have the same dimensions")
}
-
- // Extract data for each band
- val rasterData = renderedImage.getData
- val bandsData = Array.ofDim[Double](numBands, width * height)
- val serializedSampleDimensions = new Array[Array[Byte]](numBands)
-
- for (band <- 0 until numBands) {
- bandsData(band) = rasterData.getSamples(0, 0, width, height, band, new Array[Double](width * height))
- serializedSampleDimensions(band) = Serde.serializeGridSampleDimension(raster.getSampleDimension(band))
- }
-
- buffer += BandData(bandsData, input._2, Serde.serialize(raster), serializedSampleDimensions)
- buffer
}
+ buffer += BandData(index, width, height, serializedRaster)
+ buffer
+ }
+
def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
+ if (buffer1.nonEmpty && buffer2.nonEmpty) {
+ if (buffer1.head.width != buffer2.head.width || buffer1.head.height != buffer2.head.height) {
+ throw new IllegalArgumentException("All rasters must have the same dimensions")
+ }
+ }
val combined = ArrayBuffer.concat(buffer1, buffer2)
if (combined.map(_.index).distinct.length != combined.length) {
throw new IllegalArgumentException("Indexes shouldn't be repeated.")
@@ -95,24 +87,37 @@
throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
}
- val totalBands = sortedMerged.map(_.bandsData.length).sum
- val referenceRaster = Serde.deserialize(sortedMerged.head.serializedRaster)
- val width = RasterAccessors.getWidth(referenceRaster)
- val height = RasterAccessors.getHeight(referenceRaster)
- val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
- val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
- val gridSampleDimensions = sortedMerged.flatMap(_.serializedSampleDimensions.map(Serde.deserializeGridSampleDimension)).toArray
+ val rasters = sortedMerged.map(d => RasterSerializer.deserialize(d.serializedRaster))
+ try {
+ val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
+ val totalBands = rasters.map(_.getNumSampleDimensions).sum
+ val referenceRaster = rasters.head
+ val width = RasterAccessors.getWidth(referenceRaster)
+ val height = RasterAccessors.getHeight(referenceRaster)
+ val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
+ val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)
- var currentBand = 0
- sortedMerged.foreach { bandData =>
- bandData.bandsData.foreach { band =>
- resultRaster.setSamples(0, 0, width, height, currentBand, band)
- currentBand += 1
+ var currentBand = 0
+ rasters.foreach { raster =>
+ var bandIndex = 0
+ while (bandIndex < raster.getNumSampleDimensions) {
+ if (RasterUtils.isDataTypeIntegral(dataTypeCode)) {
+ val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Int](width * height))
+ resultRaster.setSamples(0, 0, width, height, currentBand, band)
+ } else {
+ val band = RasterUtils.getRaster(raster.getRenderedImage).getSamples(0, 0, width, height, bandIndex, new Array[Double](width * height))
+ resultRaster.setSamples(0, 0, width, height, currentBand, band)
+ }
+ currentBand += 1
+ bandIndex += 1
+ }
}
- }
- val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
- RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, true)
+ val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
+ RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, false)
+ } finally {
+ rasters.foreach(_.dispose(true))
+ }
}
val serde = ExpressionEncoder[GridCoverage2D]