blob: e105299bbd5da8269e1d05fdc292a9c9d235093f [file]
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.
//! Session-wide memory accounting wrapper around any `MemoryPool`.
use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, Mutex, OnceLock};
use datafusion::common::Result;
use datafusion::execution::memory_pool::{MemoryLimit, MemoryPool, MemoryReservation};
use jni::sys::jlong;
/// A `MemoryPool` wrapper that keeps two counters alongside an inner pool:
/// the total bytes currently recorded as held, and the highest total observed
/// since the wrapper was constructed.
///
/// Both counters are `AtomicU64`s, so the bookkeeping takes no lock of its
/// own. Use `snapshot()` to read them as a consistent pair.
#[derive(Debug)]
pub struct TrackingMemoryPool {
/// The wrapped pool; every `MemoryPool` trait call is forwarded to it.
inner: Arc<dyn MemoryPool>,
/// Bytes currently recorded as held: raised on grow, lowered on shrink.
current_bytes: AtomicU64,
/// High-water mark of `current_bytes`; only ever raised, never lowered.
peak_bytes: AtomicU64,
}
impl TrackingMemoryPool {
    /// Wraps `inner`, starting with both the current and the peak byte count
    /// at zero.
    pub fn new(inner: Arc<dyn MemoryPool>) -> Self {
        Self {
            inner,
            current_bytes: AtomicU64::new(0),
            peak_bytes: AtomicU64::new(0),
        }
    }

    // Note: there are intentionally no separate `current_bytes()` /
    // `peak_bytes()` accessors. Reading the two counters independently allows
    // a torn observation where `current > peak` while a concurrent
    // `record_grow` is mid-flight. Callers must use `snapshot()`, which clamps.

    /// Returns `(current_bytes, peak_bytes)` as a consistent snapshot.
    ///
    /// Two invariants are maintained:
    ///
    /// 1. **Within a snapshot:** `peak >= current`. This matters when a
    ///    concurrent `record_grow` is in flight between its `fetch_add` on
    ///    `current_bytes` and its `fetch_max` on `peak_bytes` -- a naive
    ///    snapshot could observe the higher current but the stale peak.
    /// 2. **Across snapshots:** `peak` is monotonically non-decreasing for
    ///    every observer. If reader R1 ever sees peak = X, no later reader R2
    ///    (on any thread) sees peak < X.
    ///
    /// The read-side clamp `max(stored_peak, observed_current)` alone gives
    /// (1) but not (2): R1 could *synthesize* a peak of X from a
    /// transiently high current, after which a `record_shrink` lands before
    /// the grower's own `fetch_max`, leaving the stored peak below X while R2
    /// polls. R2 would then report a smaller peak than R1 did.
    ///
    /// To rule that out, the clamped value is written back into `peak_bytes`
    /// with `fetch_max` *before* returning, so every peak ever returned is
    /// also reflected in the shared atomic.
    ///
    /// Both atomics use `Relaxed` ordering; slightly stale values are
    /// tolerated in exchange for cheaper updates.
    pub fn snapshot(&self) -> (u64, u64) {
        // `Relaxed` loads of two different atomics carry no cross-variable
        // ordering guarantee, so neither invariant relies on the order of
        // these two loads; the clamp and write-back below handle whichever
        // interleaving occurs.
        let peak_observed = self.peak_bytes.load(Ordering::Relaxed);
        let current = self.current_bytes.load(Ordering::Relaxed);
        let clamped_peak = peak_observed.max(current);
        // `fetch_max` returns the value stored *before* the update, so the
        // value now in the atomic is the larger of that and `clamped_peak`.
        let stored_peak = self.peak_bytes.fetch_max(clamped_peak, Ordering::Relaxed);
        (current, stored_peak.max(clamped_peak))
    }

    /// Adds `additional` bytes to the running total and raises the peak if the
    /// new total exceeds it.
    fn record_grow(&self, additional: usize) {
        // Lossless widening: `usize` is at most 64 bits on supported targets.
        let additional = additional as u64;
        // `fetch_add` returns the previous total; add the delta again to get
        // the total this call produced.
        let now = self
            .current_bytes
            .fetch_add(additional, Ordering::Relaxed)
            .saturating_add(additional);
        // `fetch_max` is monotonic, so concurrent grows still leave the
        // largest total in `peak_bytes` regardless of which one lands last.
        self.peak_bytes.fetch_max(now, Ordering::Relaxed);
    }

    /// Subtracts `shrink` bytes from the running total, saturating at zero.
    ///
    /// The subtraction is a single atomic read-modify-write. A wrapping
    /// `fetch_sub` followed by a separate `store(0)` would briefly expose a
    /// value near `u64::MAX`; a concurrent `snapshot()` or `record_grow()`
    /// could copy that into `peak_bytes` permanently, and the late `store(0)`
    /// could erase a grow that landed in between.
    fn record_shrink(&self, shrink: usize) {
        // Lossless widening: `usize` is at most 64 bits on supported targets.
        let shrink = shrink as u64;
        // The closure always returns `Some`, so the update cannot fail; the
        // returned previous value is not needed.
        let _ = self
            .current_bytes
            .fetch_update(Ordering::Relaxed, Ordering::Relaxed, |held| {
                Some(held.saturating_sub(shrink))
            });
    }
}
/// `MemoryPool` implementation that forwards every call to the wrapped pool
/// and mirrors each size change that took effect into the tracking counters.
impl MemoryPool for TrackingMemoryPool {
    /// Passes consumer registration straight through to the wrapped pool.
    fn register(&self, consumer: &datafusion::execution::memory_pool::MemoryConsumer) {
        let wrapped = &self.inner;
        wrapped.register(consumer);
    }

    /// Passes consumer de-registration straight through to the wrapped pool.
    fn unregister(&self, consumer: &datafusion::execution::memory_pool::MemoryConsumer) {
        let wrapped = &self.inner;
        wrapped.unregister(consumer);
    }

    /// Grows the reservation in the wrapped pool, then records the added bytes.
    fn grow(&self, reservation: &MemoryReservation, additional: usize) {
        let wrapped = &self.inner;
        wrapped.grow(reservation, additional);
        self.record_grow(additional);
    }

    /// Shrinks the reservation in the wrapped pool, then records the released
    /// bytes.
    fn shrink(&self, reservation: &MemoryReservation, shrink: usize) {
        let wrapped = &self.inner;
        wrapped.shrink(reservation, shrink);
        self.record_shrink(shrink);
    }

    /// Attempts to grow the reservation in the wrapped pool. The added bytes
    /// are recorded only when the wrapped pool accepts the request; on error
    /// the counters are left untouched and the error is returned as-is.
    fn try_grow(&self, reservation: &MemoryReservation, additional: usize) -> Result<()> {
        let outcome = self.inner.try_grow(reservation, additional);
        if outcome.is_ok() {
            self.record_grow(additional);
        }
        outcome
    }

    /// Reports the wrapped pool's own notion of reserved bytes.
    fn reserved(&self) -> usize {
        let wrapped = &self.inner;
        wrapped.reserved()
    }

    /// Reports the wrapped pool's memory limit.
    fn memory_limit(&self) -> MemoryLimit {
        let wrapped = &self.inner;
        wrapped.memory_limit()
    }
}
/// Returns the process-wide table that maps a `SessionContext`'s native
/// handle to the `TrackingMemoryPool` created for it.
///
/// Entries are added at session creation and removed at session close. The
/// table exists because `Arc<dyn MemoryPool>` cannot be downcast back to the
/// concrete type without an `Any` bound (which the trait does not require);
/// a parallel `Arc`, keyed by the JNI handle Java already threads through
/// every call, is kept instead.
///
/// The table is created empty on first use and lives for the rest of the
/// process.
fn registry() -> &'static Mutex<HashMap<jlong, Arc<TrackingMemoryPool>>> {
    type PoolsByHandle = HashMap<jlong, Arc<TrackingMemoryPool>>;

    static POOLS: OnceLock<Mutex<PoolsByHandle>> = OnceLock::new();
    POOLS.get_or_init(Mutex::default)
}
/// Associates `handle` with `pool` in the process-wide registry, replacing any
/// pool previously stored under the same handle.
///
/// A poisoned registry lock is recovered rather than treated as fatal: the
/// map holds independent entries with no multi-step invariant, and panicking
/// here would make every later registry call panic as well. Any displaced
/// pool is dropped only after the lock has been released, so a destructor
/// never runs inside the critical section.
pub fn register(handle: jlong, pool: Arc<TrackingMemoryPool>) {
    // The guard is a temporary of this statement, so the lock is released at
    // the end of the `let`; `displaced` outlives it.
    let displaced = registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .insert(handle, pool);
    drop(displaced);
}
/// Returns a new `Arc` to the pool registered under `handle`, or `None` when
/// no pool is registered for it.
///
/// A poisoned registry lock is recovered rather than treated as fatal: the
/// map holds independent entries with no multi-step invariant, so reading it
/// after another thread panicked while holding the lock is still meaningful.
pub fn lookup(handle: jlong) -> Option<Arc<TrackingMemoryPool>> {
    let pools = registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner);
    pools.get(&handle).cloned()
}
/// Removes the pool registered under `handle`, if any. Unknown handles are a
/// no-op.
///
/// A poisoned registry lock is recovered rather than treated as fatal, so a
/// session can still be cleaned up after another thread panicked while
/// holding the lock. The removed pool is dropped only after the lock has been
/// released, so a destructor never runs inside the critical section.
pub fn unregister(handle: jlong) {
    // The guard is a temporary of this statement, so the lock is released at
    // the end of the `let`; `removed` outlives it.
    let removed = registry()
        .lock()
        .unwrap_or_else(std::sync::PoisonError::into_inner)
        .remove(&handle);
    drop(removed);
}