blob: d8a6f4be461c9ffef442ddc7adcfab7edde818cd [file] [log] [blame]
/*
* Licensed to the Apache Software Foundation (ASF) under one
* or more contributor license agreements. See the NOTICE file
* distributed with this work for additional information
* regarding copyright ownership. The ASF licenses this file
* to you under the Apache License, Version 2.0 (the
* "License"); you may not use this file except in compliance
* with the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing,
* software distributed under the License is distributed on an
* "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
* KIND, either express or implied. See the License for the
* specific language governing permissions and limitations
* under the License.
*/
package org.apache.spark.sql.sedona_sql.expressions.raster
import org.apache.sedona.common.raster.{RasterAccessors, RasterBandAccessors}
import org.apache.sedona.common.utils.RasterUtils
import org.apache.sedona.sql.utils.RasterSerializer
import org.apache.spark.sql.Encoder
import org.apache.spark.sql.catalyst.encoders.ExpressionEncoder
import org.apache.spark.sql.expressions.Aggregator
import org.geotools.coverage.grid.GridCoverage2D
import java.awt.image.WritableRaster
import javax.media.jai.RasterFactory
import scala.collection.mutable.ArrayBuffer
/**
 * Intermediate aggregation record for one input raster of [[RS_Union_Aggr]].
 *
 * @param index            the integer paired with the raster in the aggregator input; used to
 *                         order the rasters when the result is assembled
 * @param width            width of the raster's rendered image
 * @param height           height of the raster's rendered image
 * @param serializedRaster the raster as produced by `RasterSerializer.serialize`
 */
case class BandData(
    index: Int,
    width: Int,
    height: Int,
    serializedRaster: Array[Byte])
/**
 * Return a raster containing bands at given indexes from all rasters in a given column.
 *
 * Every input row is a `(raster, index)` pair. The rasters are ordered by `index` and all of
 * their bands are appended, in that order, to a single banded raster. The aggregation fails with
 * an `IllegalArgumentException` when
 *  - the rasters do not all share the same width and height,
 *  - an index is repeated,
 *  - the sorted indexes are not evenly spaced (an arithmetic sequence), or
 *  - there is no input raster at all.
 *
 * The data type, grid geometry and band no-data value used to build the result are read from the
 * raster with the smallest index.
 */
class RS_Union_Aggr extends Aggregator[(GridCoverage2D, Int), ArrayBuffer[BandData], GridCoverage2D] {

  /** Initial aggregation state: no rasters collected yet. */
  def zero: ArrayBuffer[BandData] = ArrayBuffer[BandData]()

  /**
   * Adds one `(raster, index)` input to the buffer.
   *
   * The raster is serialized into a [[BandData]] entry and then disposed; it is disposed even if
   * reading its size or serializing it fails.
   *
   * @param buffer rasters collected so far; mutated and returned
   * @param input  the raster and the index that positions it in the result
   * @return `buffer`, with the new entry appended
   * @throws IllegalArgumentException if the raster's width or height differs from the first
   *                                  entry already in the buffer
   */
  def reduce(buffer: ArrayBuffer[BandData], input: (GridCoverage2D, Int)): ArrayBuffer[BandData] = {
    val (raster, index) = input
    val (width, height, serializedRaster) =
      try {
        val renderedImage = raster.getRenderedImage
        (renderedImage.getWidth, renderedImage.getHeight, RasterSerializer.serialize(raster))
      } finally {
        // Release the input raster on the failure path as well as the success path.
        raster.dispose(true)
      }

    // The first buffered raster fixes the dimensions every later raster must match.
    buffer.headOption.foreach { reference =>
      if (width != reference.width || height != reference.height) {
        throw new IllegalArgumentException("All rasters must have the same dimensions")
      }
    }

    buffer += BandData(index, width, height, serializedRaster)
    buffer
  }

  /**
   * Combines two partial buffers into a new one.
   *
   * @param buffer1 first partial buffer
   * @param buffer2 second partial buffer
   * @return a new buffer holding the entries of `buffer1` followed by those of `buffer2`
   * @throws IllegalArgumentException if the two buffers hold rasters of different dimensions, or
   *                                  if the same index occurs more than once in the combined buffer
   */
  def merge(buffer1: ArrayBuffer[BandData], buffer2: ArrayBuffer[BandData]): ArrayBuffer[BandData] = {
    // Comparing the heads is enough: reduce keeps each buffer internally consistent.
    for (left <- buffer1.headOption; right <- buffer2.headOption) {
      if (left.width != right.width || left.height != right.height) {
        throw new IllegalArgumentException("All rasters must have the same dimensions")
      }
    }

    val combined = ArrayBuffer.concat(buffer1, buffer2)
    if (combined.map(_.index).distinct.length != combined.length) {
      throw new IllegalArgumentException("Indexes shouldn't be repeated.")
    }
    combined
  }

  /**
   * Builds the final raster from the collected entries.
   *
   * Entries are sorted by index, deserialized, and their bands are copied one after another into
   * a new banded raster. Every deserialized raster is disposed before this method returns,
   * whether it succeeds or throws.
   *
   * @param merged all collected entries
   * @return the raster produced by `RasterUtils.clone` from the assembled bands
   * @throws IllegalArgumentException if `merged` is empty, if the sorted indexes are not evenly
   *                                  spaced, or if an index is repeated
   */
  def finish(merged: ArrayBuffer[BandData]): GridCoverage2D = {
    if (merged.isEmpty) {
      // Without this guard the failure surfaces later as an opaque NoSuchElementException.
      throw new IllegalArgumentException("At least one raster is required to build the union.")
    }

    val sortedMerged = merged.sortBy(_.index)
    if (sortedMerged.length > 1) {
      // The gap between the first two indexes is the step every later gap must equal.
      val step = sortedMerged(1).index - sortedMerged(0).index
      val unevenlySpaced = (2 until sortedMerged.length).exists { i =>
        sortedMerged(i).index - sortedMerged(i - 1).index != step
      }
      if (unevenlySpaced) {
        throw new IllegalArgumentException("Index should be in an arithmetic sequence.")
      }
      // merge rejects repeated indexes only when two buffers are combined; a step of zero means
      // every index is equal, so repeat the check here for buffers that never went through merge.
      if (step == 0) {
        throw new IllegalArgumentException("Indexes shouldn't be repeated.")
      }
    }

    // Filled inside the try block so that rasters already deserialized are still disposed when a
    // later deserialization fails.
    val rasters = ArrayBuffer.empty[GridCoverage2D]
    try {
      sortedMerged.foreach(band => rasters += RasterSerializer.deserialize(band.serializedRaster))

      val gridSampleDimensions = rasters.flatMap(_.getSampleDimensions).toArray
      val totalBands = rasters.map(_.getNumSampleDimensions).sum
      val referenceRaster = rasters.head
      val width = RasterAccessors.getWidth(referenceRaster)
      val height = RasterAccessors.getHeight(referenceRaster)
      val dataTypeCode = RasterUtils.getRaster(referenceRaster.getRenderedImage).getDataBuffer.getDataType
      // Decided once: the same sample type is used for every band of every raster.
      val integralSamples = RasterUtils.isDataTypeIntegral(dataTypeCode)
      val resultRaster: WritableRaster = RasterFactory.createBandedRaster(dataTypeCode, width, height, totalBands, null)

      var currentBand = 0
      rasters.foreach { raster =>
        // Resolve the backing java.awt raster and the band count once per input, not once per band.
        val sourceRaster = RasterUtils.getRaster(raster.getRenderedImage)
        val bandCount = raster.getNumSampleDimensions
        var bandIndex = 0
        while (bandIndex < bandCount) {
          if (integralSamples) {
            val band = sourceRaster.getSamples(0, 0, width, height, bandIndex, new Array[Int](width * height))
            resultRaster.setSamples(0, 0, width, height, currentBand, band)
          } else {
            val band = sourceRaster.getSamples(0, 0, width, height, bandIndex, new Array[Double](width * height))
            resultRaster.setSamples(0, 0, width, height, currentBand, band)
          }
          currentBand += 1
          bandIndex += 1
        }
      }

      val noDataValue = RasterBandAccessors.getBandNoDataValue(referenceRaster)
      RasterUtils.clone(resultRaster, referenceRaster.getGridGeometry, gridSampleDimensions, referenceRaster, noDataValue, false)
    } finally {
      rasters.foreach(_.dispose(true))
    }
  }

  /** Encoder for the aggregation result. */
  val serde = ExpressionEncoder[GridCoverage2D]

  /** Encoder for the intermediate buffer. */
  val bufferSerde = ExpressionEncoder[ArrayBuffer[BandData]]

  def outputEncoder: ExpressionEncoder[GridCoverage2D] = serde

  def bufferEncoder: Encoder[ArrayBuffer[BandData]] = bufferSerde
}