Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
5 changes: 5 additions & 0 deletions native/core/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,11 @@ hdfs = ["datafusion-comet-objectstore-hdfs"]
hdfs-opendal = ["opendal", "object_store_opendal", "hdfs-sys"]
jemalloc = ["tikv-jemallocator", "tikv-jemalloc-ctl"]

# Allocator-level OOM circuit breaker. When enabled, the global allocator is
# wrapped to track the bytes actually allocated and to panic any query-worker
# thread that goes over budget (the panic is caught at the task boundary).
# Off by default; zero overhead when off.
oom-guard = []

# exclude optional packages from cargo machete verifications
[package.metadata.cargo-machete]
ignored = ["hdfs-sys", "paste"]
Expand Down
199 changes: 140 additions & 59 deletions native/core/src/execution/jni_api.rs
Original file line number Diff line number Diff line change
Expand Up @@ -112,9 +112,13 @@ use crate::execution::spark_config::{
SparkConfig, COMET_DEBUG_ENABLED, COMET_DEBUG_MEMORY, COMET_EXPLAIN_NATIVE_ENABLED,
COMET_MAX_TEMP_DIRECTORY_SIZE, COMET_TRACING_ENABLED, SPARK_EXECUTOR_CORES,
};
#[cfg(feature = "oom-guard")]
use crate::execution::spark_config::{COMET_MEMORY_GUARD_ENABLED, COMET_MEMORY_GUARD_SIZE};
use crate::parquet::encryption_support::{CometEncryptionFactory, ENCRYPTION_FACTORY_ID};
use datafusion_comet_proto::spark_operator::operator::OpStruct;
use log::info;
#[cfg(feature = "oom-guard")]
use log::warn;
use std::sync::OnceLock;
#[cfg(feature = "jemalloc")]
use tikv_jemalloc_ctl::{epoch, stats};
Expand Down Expand Up @@ -192,6 +196,8 @@ fn parse_usize_env_var(name: &str) -> Option<usize> {

fn build_runtime(default_worker_threads: Option<usize>) -> Runtime {
let mut builder = tokio::runtime::Builder::new_multi_thread();
#[cfg(feature = "oom-guard")]
builder.on_thread_start(|| crate::execution::memory_pools::oom_guard::stamp_current_thread());
if let Some(n) = parse_usize_env_var("COMET_WORKER_THREADS") {
info!("Comet tokio runtime: using COMET_WORKER_THREADS={n}");
builder.worker_threads(n);
Expand Down Expand Up @@ -369,6 +375,24 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_createPlan(
spark_config.get_u64(COMET_MAX_TEMP_DIRECTORY_SIZE, 100 * 1024 * 1024 * 1024);
let logging_memory_pool = spark_config.get_bool(COMET_DEBUG_MEMORY);

#[cfg(feature = "oom-guard")]
{
if spark_config.get_bool(COMET_MEMORY_GUARD_ENABLED) {
// Default to the executor off-heap memory limit (`memory_limit`);
// allow an explicit override.
let default_limit = memory_limit.max(0) as u64;
let limit = spark_config.get_u64(COMET_MEMORY_GUARD_SIZE, default_limit);
if limit == 0 {
warn!(
"spark.comet.exec.memoryGuard.enabled is true but the effective limit \
is 0 (memory_limit={memory_limit}); the guard will not trip. Set \
spark.comet.exec.memoryGuard.size explicitly."
);
}
crate::execution::memory_pools::oom_guard::arm(limit as usize);
}
}

with_trace("createPlan", tracing_enabled, || {
// Init JVM classes
JVMClasses::init(env);
Expand Down Expand Up @@ -715,6 +739,8 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
schema_addrs: JLongArray,
) -> jlong {
try_unwrap_or_throw(&e, |env| {
#[cfg(feature = "oom-guard")]
crate::execution::memory_pools::oom_guard::stamp_current_thread();
// Retrieve the query
let exec_context = get_execution_context(exec_context);

Expand Down Expand Up @@ -786,6 +812,17 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
.await;

if let Err(panic) = result {
#[cfg(feature = "oom-guard")]
if let Some(e) =
crate::execution::memory_pools::oom_guard::map_panic_to_error(
panic.as_ref(),
)
{
// Runs on the tokio worker thread that panicked, so this clears
// that worker's UNWINDING flag (not the blocked JNI caller thread's).
let _ = tx.send(Err(e)).await;
return;
}
let msg = match panic.downcast_ref::<&str>() {
Some(s) => s.to_string(),
None => match panic.downcast_ref::<String>() {
Expand All @@ -810,76 +847,120 @@ pub unsafe extern "system" fn Java_org_apache_comet_Native_executePlan(
pull_input_batches(exec_context)?;
}

if let Some(rx) = &mut exec_context.batch_receiver {
match rx.blocking_recv() {
Some(Ok(batch)) => {
update_metrics(env, exec_context)?;
return prepare_output(
env,
array_addrs,
schema_addrs,
batch,
exec_context.debug_native,
);
}
Some(Err(e)) => {
return Err(e.into());
}
None => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
if exec_context.batch_receiver.is_some() {
let recv_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(
|| -> CometResult<jlong> {
// Scope the rx borrow to just the blocking_recv call so that
// exec_context is free for update_metrics / prepare_output below.
let recv = exec_context
.batch_receiver
.as_mut()
.unwrap()
.blocking_recv();
match recv {
Some(Ok(batch)) => {
update_metrics(env, exec_context)?;
prepare_output(
env,
array_addrs,
schema_addrs,
batch,
exec_context.debug_native,
)
}
Some(Err(e)) => Err(e.into()),
None => {
log_plan_metrics(exec_context, stage_id, partition);
Ok(-1)
}
}
},
));

match recv_result {
Ok(r) => return r,
Err(_panic) => {
#[cfg(feature = "oom-guard")]
if let Some(e) =
crate::execution::memory_pools::oom_guard::map_panic_to_error(
_panic.as_ref(),
)
{
// Drop the receiver so any re-entry re-initializes.
exec_context.batch_receiver = None;
return Err(e.into());
}
std::panic::resume_unwind(_panic);
}
}
}

// ScanExec path: busy-poll to interleave JVM batch pulls with stream polling
get_runtime().block_on(async {
loop {
let next_item = exec_context.stream.as_mut().unwrap().next();
let poll_output = poll!(next_item);

// Only check time/tracing every 100 polls to reduce overhead
exec_context.poll_count_since_metrics_check += 1;
if exec_context.poll_count_since_metrics_check >= 100 {
exec_context.poll_count_since_metrics_check = 0;
if let Some(interval) = exec_context.metrics_update_interval {
let now = Instant::now();
if now - exec_context.metrics_last_update_time >= interval {
update_metrics(env, exec_context)?;
exec_context.metrics_last_update_time = now;
let poll_result = std::panic::catch_unwind(std::panic::AssertUnwindSafe(|| {
get_runtime().block_on(async {
loop {
let next_item = exec_context.stream.as_mut().unwrap().next();
let poll_output = poll!(next_item);

// Only check time/tracing every 100 polls to reduce overhead
exec_context.poll_count_since_metrics_check += 1;
if exec_context.poll_count_since_metrics_check >= 100 {
exec_context.poll_count_since_metrics_check = 0;
if let Some(interval) = exec_context.metrics_update_interval {
let now = Instant::now();
if now - exec_context.metrics_last_update_time >= interval {
update_metrics(env, exec_context)?;
exec_context.metrics_last_update_time = now;
}
}
if exec_context.tracing_enabled {
log_memory_usage(
&exec_context.tracing_memory_metric_name,
total_reserved_for_thread(exec_context.rust_thread_id) as u64,
);
}
}
if exec_context.tracing_enabled {
log_memory_usage(
&exec_context.tracing_memory_metric_name,
total_reserved_for_thread(exec_context.rust_thread_id) as u64,
);
}
}

match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(
env,
array_addrs,
schema_addrs,
output?,
exec_context.debug_native,
);
}
Poll::Ready(None) => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
}
Poll::Pending => {
// JNI call to pull batches from JVM into ScanExec operators.
// block_in_place lets tokio move other tasks off this worker
// while we wait for JVM data.
tokio::task::block_in_place(|| pull_input_batches(exec_context))?;
match poll_output {
Poll::Ready(Some(output)) => {
return prepare_output(
env,
array_addrs,
schema_addrs,
output?,
exec_context.debug_native,
);
}
Poll::Ready(None) => {
log_plan_metrics(exec_context, stage_id, partition);
return Ok(-1);
}
Poll::Pending => {
// JNI call to pull batches from JVM into ScanExec operators.
// block_in_place lets tokio move other tasks off this worker
// while we wait for JVM data.
tokio::task::block_in_place(|| pull_input_batches(exec_context))?;
}
}
}
})
}));

match poll_result {
Ok(r) => r,
Err(_panic) => {
#[cfg(feature = "oom-guard")]
if let Some(e) = crate::execution::memory_pools::oom_guard::map_panic_to_error(
_panic.as_ref(),
) {
// The block_on future was dropped mid-poll; null the stream so any
// inadvertent re-entry re-initializes rather than polling a half-consumed one.
exec_context.stream = None;
return Err(e.into());
}
std::panic::resume_unwind(_panic);
}
})
}
});

if exec_context.tracing_enabled {
Expand Down
2 changes: 2 additions & 0 deletions native/core/src/execution/memory_pools/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
mod config;
mod fair_pool;
pub mod logging_pool;
#[cfg(feature = "oom-guard")]
pub mod oom_guard;
mod task_shared;
mod unified_pool;

Expand Down
Loading