From 92ce301c552394c6a5771df23daf67637b7d0401 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Marko=20Milenkovi=C4=87?= <milenkovicm@users.noreply.github.com> Date: Mon, 21 Oct 2024 14:54:37 +0100 Subject: [PATCH] Replace BallistaContext with SessionContext (#1088) * Initial SessionContextExt skeleton relates to #1081 * add few more tests ... to find missing functionalities, and verify it `SessionContextExt` will not fail any of the tests for `BallistaContext` * Detect if LogicalPlan is scanning information schema ... it does, we will use `DefaultPhysicalPlanner` and execute query locally. * change extension interface, simplifying it * Change SessionContextExt interface ... ... add more tests * update rustdocs * remote methods accept `url` ... ... it would be easier to add security later. * remove config option for now ... ... would add them in next commits, once i get better idea about them. * debug failed windows test * remove `standalone` from default features in client * fix clippy in tests * fix formatting as well --- .gitignore | 1 + ballista/client/Cargo.toml | 6 + ballista/client/src/context.rs | 1 + ballista/client/src/extension.rs | 203 +++++++ ballista/client/src/lib.rs | 1 + ballista/client/src/prelude.rs | 3 +- ballista/client/tests/common/mod.rs | 147 +++++ ballista/client/tests/context_standalone.rs | 500 ++++++++++++++++++ ballista/client/tests/remote.rs | 145 +++++ ballista/client/tests/standalone.rs | 444 ++++++++++++++++ ballista/core/src/config.rs | 68 ++- ballista/core/src/error.rs | 4 +- ballista/core/src/serde/mod.rs | 89 +++- ballista/core/src/utils.rs | 196 ++++++- .../scheduler/src/scheduler_server/grpc.rs | 5 +- 15 files changed, 1779 insertions(+), 34 deletions(-) create mode 100644 ballista/client/src/extension.rs create mode 100644 ballista/client/tests/common/mod.rs create mode 100644 ballista/client/tests/context_standalone.rs create mode 100644 ballista/client/tests/remote.rs create mode 100644 ballista/client/tests/standalone.rs diff --git a/.gitignore b/.gitignore index d9e136a18..5e3a6a2bc 100644 --- a/.gitignore +++ b/.gitignore @@ -21,6 +21,7 @@ filtered_rat.txt arrow-src.tar arrow-src.tar.gz CHANGELOG.md.bak +Cargo.toml.bak # Compiled source *.a diff --git a/ballista/client/Cargo.toml b/ballista/client/Cargo.toml index f26f73eb4..a8de27362 100644 --- a/ballista/client/Cargo.toml +++ b/ballista/client/Cargo.toml @@ -28,6 +28,7 @@ edition = "2021" rust-version = "1.72" [dependencies] +async-trait = { workspace = true } ballista-core = { path = "../core", version = "0.12.0" } ballista-executor = { path = "../executor", version = "0.12.0", optional = true } ballista-scheduler = { path = "../scheduler", version = "0.12.0", optional = true } @@ -39,6 +40,11 @@ parking_lot = { workspace = true } sqlparser = { workspace = true } tempfile = { workspace = true } tokio = { workspace = true } +url = { version = "2.5" } + +[dev-dependencies] +ctor = { version = "0.2" } +env_logger = { workspace = true } [features] azure = ["ballista-core/azure"] diff --git a/ballista/client/src/context.rs b/ballista/client/src/context.rs index 269afc64d..b09e1d65b 100644 --- a/ballista/client/src/context.rs +++ b/ballista/client/src/context.rs @@ -76,6 +76,7 @@ impl BallistaContextState { } } +// #[deprecated] pub struct BallistaContext { state: Arc<Mutex<BallistaContextState>>, context: Arc<SessionContext>, diff --git a/ballista/client/src/extension.rs b/ballista/client/src/extension.rs new file mode 100644 index 000000000..ca104d3b1 --- /dev/null +++ b/ballista/client/src/extension.rs 
@@ -0,0 +1,203 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use ballista_core::{ + config::BallistaConfig, + serde::protobuf::{ + scheduler_grpc_client::SchedulerGrpcClient, CreateSessionParams, KeyValuePair, + }, + utils::{create_df_ctx_with_ballista_query_planner, create_grpc_client_connection}, +}; +use datafusion::{error::DataFusionError, prelude::SessionContext}; +use datafusion_proto::protobuf::LogicalPlanNode; +use url::Url; + +const DEFAULT_SCHEDULER_PORT: u16 = 50050; + +/// This module provides [SessionContextExt], which adds `standalone*` and `remote*` +/// methods to [SessionContext]. +/// +/// The provided methods set up a [SessionContext] with a [BallistaQueryPlanner](ballista_core::utils), which +/// handles running plans on Ballista clusters. +/// +///```no_run +/// use ballista::prelude::SessionContextExt; +/// use datafusion::prelude::SessionContext; +/// +/// # #[tokio::main] +/// # async fn main() -> datafusion::error::Result<()> { +/// let ctx: SessionContext = SessionContext::remote("df://localhost:50050").await?; +/// # Ok(()) +/// # } +///``` +/// +/// [SessionContextExt::standalone()] provides an easy way to start up a +/// local cluster. It is optional and must be enabled +/// with the `standalone` feature. +/// +///```no_run +/// use ballista::prelude::SessionContextExt; +/// use datafusion::prelude::SessionContext; +/// +/// # #[tokio::main] +/// # async fn main() -> datafusion::error::Result<()> { +/// let ctx: SessionContext = SessionContext::standalone().await?; +/// # Ok(()) +/// # } +///``` +/// +/// There are still a few limitations on query distribution, so not all +/// [SessionContext] functionality is supported. +/// +#[async_trait::async_trait] +pub trait SessionContextExt { + /// Create a context for executing queries against a standalone Ballista scheduler instance. + /// It will start a local Ballista cluster with a scheduler and an executor.
+ #[cfg(feature = "standalone")] + async fn standalone() -> datafusion::error::Result<SessionContext>; + + /// Create a context for executing queries against a remote Ballista scheduler instance + async fn remote(url: &str) -> datafusion::error::Result<SessionContext>; +} + +#[async_trait::async_trait] +impl SessionContextExt for SessionContext { + async fn remote(url: &str) -> datafusion::error::Result<SessionContext> { + let url = + Url::parse(url).map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let host = url.host().ok_or(DataFusionError::Configuration( + "hostname should be provided".to_string(), + ))?; + let port = url.port().unwrap_or(DEFAULT_SCHEDULER_PORT); + let scheduler_url = format!("http://{}:{}", &host, port); + log::info!( + "Connecting to Ballista scheduler at {}", + scheduler_url.clone() + ); + let connection = create_grpc_client_connection(scheduler_url.clone()) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))?; + + let config = BallistaConfig::builder() + .build() + .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + + let limit = config.default_grpc_client_max_message_size(); + let mut scheduler = SchedulerGrpcClient::new(connection) + .max_encoding_message_size(limit) + .max_decoding_message_size(limit); + + let remote_session_id = scheduler + .create_session(CreateSessionParams { + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::<Vec<_>>(), + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))? + .into_inner() + .session_id; + + log::info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>( + scheduler_url, + remote_session_id, + &config, + ) + }; + + Ok(ctx) + } + + #[cfg(feature = "standalone")] + async fn standalone() -> datafusion::error::Result<Self> { + use ballista_core::serde::BallistaCodec; + use datafusion_proto::protobuf::PhysicalPlanNode; + + log::info!("Running in local mode. Scheduler will be run in-proc"); + + let addr = ballista_scheduler::standalone::new_standalone_scheduler() + .await + .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + + let scheduler_url = format!("http://localhost:{}", addr.port()); + let mut scheduler = loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { + Err(_) => { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + log::info!("Attempting to connect to in-proc scheduler..."); + } + Ok(scheduler) => break scheduler, + } + }; + let config = BallistaConfig::builder() + .build() + .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + let remote_session_id = scheduler + .create_session(CreateSessionParams { + settings: config + .settings() + .iter() + .map(|(k, v)| KeyValuePair { + key: k.to_owned(), + value: v.to_owned(), + }) + .collect::<Vec<_>>(), + }) + .await + .map_err(|e| DataFusionError::Execution(format!("{e:?}")))? 
+ .into_inner() + .session_id; + + log::info!( + "Server side SessionContext created with session id: {}", + remote_session_id + ); + + let ctx = { + create_df_ctx_with_ballista_query_planner::<LogicalPlanNode>( + scheduler_url, + remote_session_id, + &config, + ) + }; + + let default_codec: BallistaCodec<LogicalPlanNode, PhysicalPlanNode> = + BallistaCodec::default(); + + let concurrent_tasks = config.default_standalone_parallelism(); + ballista_executor::new_standalone_executor( + scheduler, + concurrent_tasks, + default_codec, + ) + .await + .map_err(|e| DataFusionError::Configuration(e.to_string()))?; + + Ok(ctx) + } +} diff --git a/ballista/client/src/lib.rs b/ballista/client/src/lib.rs index e61dfef28..76bd0c940 100644 --- a/ballista/client/src/lib.rs +++ b/ballista/client/src/lib.rs @@ -18,4 +18,5 @@ #![doc = include_str!("../README.md")] pub mod context; +pub mod extension; pub mod prelude; diff --git a/ballista/client/src/prelude.rs b/ballista/client/src/prelude.rs index acab66529..1b7988770 100644 --- a/ballista/client/src/prelude.rs +++ b/ballista/client/src/prelude.rs @@ -23,7 +23,7 @@ pub use ballista_core::{ BALLISTA_DEFAULT_SHUFFLE_PARTITIONS, BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE, BALLISTA_JOB_NAME, BALLISTA_PARQUET_PRUNING, BALLISTA_REPARTITION_AGGREGATIONS, BALLISTA_REPARTITION_JOINS, BALLISTA_REPARTITION_WINDOWS, - BALLISTA_WITH_INFORMATION_SCHEMA, + BALLISTA_STANDALONE_PARALLELISM, BALLISTA_WITH_INFORMATION_SCHEMA, }, error::{BallistaError, Result}, }; @@ -31,3 +31,4 @@ pub use ballista_core::{ pub use futures::StreamExt; pub use crate::context::BallistaContext; +pub use crate::extension::SessionContextExt; diff --git a/ballista/client/tests/common/mod.rs b/ballista/client/tests/common/mod.rs new file mode 100644 index 000000000..02f25d7be --- /dev/null +++ b/ballista/client/tests/common/mod.rs @@ -0,0 +1,147 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use std::env; +use std::error::Error; +use std::path::PathBuf; + +use ballista::prelude::BallistaConfig; +use ballista_core::serde::{ + protobuf::scheduler_grpc_client::SchedulerGrpcClient, BallistaCodec, +}; + +// /// Remote ballista cluster to be used for local testing. +// static BALLISTA_CLUSTER: tokio::sync::OnceCell<(String, u16)> = +// tokio::sync::OnceCell::const_new(); + +/// Returns the parquet test data directory, which is by default +/// stored in a git submodule rooted at +/// `examples/testdata`. +/// +/// The default can be overridden by the optional environment variable +/// `EXAMPLES_TEST_DATA` +/// +/// panics when the directory can not be found. 
+/// +/// Example: +/// ``` +/// use ballista_examples::test_util; +/// let testdata = test_util::examples_test_data(); +/// let filename = format!("{testdata}/aggregate_test_100.csv"); +/// assert!(std::path::PathBuf::from(filename).exists()); +/// ``` +#[allow(dead_code)] +pub fn example_test_data() -> String { + match get_data_dir("EXAMPLES_TEST_DATA", "testdata") { + Ok(pb) => pb.display().to_string(), + Err(err) => panic!("failed to get examples test data dir: {err}"), + } +} + +/// Returns a directory path for finding test data. +/// +/// udf_env: name of an environment variable +/// +/// submodule_dir: fallback path (relative to CARGO_MANIFEST_DIR) +/// +/// Returns either: +/// The path referred to in `udf_env` if that variable is set and refers to a directory +/// The submodule_data directory relative to CARGO_MANIFEST_PATH +#[allow(dead_code)] +fn get_data_dir(udf_env: &str, submodule_data: &str) -> Result<PathBuf, Box<dyn Error>> { + // Try user defined env. + if let Ok(dir) = env::var(udf_env) { + let trimmed = dir.trim().to_string(); + if !trimmed.is_empty() { + let pb = PathBuf::from(trimmed); + if pb.is_dir() { + return Ok(pb); + } else { + return Err(format!( + "the data dir `{}` defined by env {udf_env} not found", + pb.display() + ) + .into()); + } + } + } + + // The env is undefined or its value is trimmed to empty, let's try default dir. + + // env "CARGO_MANIFEST_DIR" is "the directory containing the manifest of your package", + // set by `cargo run` or `cargo test`, see: + // https://doc.rust-lang.org/cargo/reference/environment-variables.html + let dir = env!("CARGO_MANIFEST_DIR"); + + let pb = PathBuf::from(dir).join(submodule_data); + if pb.is_dir() { + Ok(pb) + } else { + Err(format!( + "env `{udf_env}` is undefined or has empty value, and the pre-defined data dir `{}` not found\n\ + HINT: try running `git submodule update --init`", + pb.display(), + ).into()) + } +} + +/// starts a ballista cluster for integration tests +#[allow(dead_code)] +pub async fn setup_test_cluster() -> (String, u16) { + let config = BallistaConfig::builder().build().unwrap(); + let default_codec = BallistaCodec::default(); + + let addr = ballista_scheduler::standalone::new_standalone_scheduler() + .await + .expect("scheduler to be created"); + + let host = "localhost".to_string(); + + let scheduler_url = format!("http://{}:{}", host, addr.port()); + + let scheduler = loop { + match SchedulerGrpcClient::connect(scheduler_url.clone()).await { + Err(_) => { + tokio::time::sleep(tokio::time::Duration::from_millis(100)).await; + log::info!("Attempting to connect to test scheduler..."); + } + Ok(scheduler) => break scheduler, + } + }; + + ballista_executor::new_standalone_executor( + scheduler, + config.default_standalone_parallelism(), + default_codec, + ) + .await + .expect("executor to be created"); + + log::info!("test scheduler created at: {}:{}", host, addr.port()); + + (host, addr.port()) +} + +#[ctor::ctor] +fn init() { + // Enable RUST_LOG logging configuration for test + let _ = env_logger::builder() + .filter_level(log::LevelFilter::Info) + .parse_filters("ballista=debug,ballista_scheduler-rs=debug,ballista_executor=debug,datafusion=debug") + .is_test(true) + .try_init(); +} diff --git a/ballista/client/tests/context_standalone.rs b/ballista/client/tests/context_standalone.rs new file mode 100644 index 000000000..bd83d5276 --- /dev/null +++ b/ballista/client/tests/context_standalone.rs @@ -0,0 +1,500 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more 
contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +// +// The tests are extracted from context.rs where `BallistaContext` lives. +// to be checked if `SessionContextExt` has same functionality like `BallistaContext` +// +#[cfg(test)] +#[cfg(feature = "standalone")] +mod standalone_tests { + use ballista::extension::SessionContextExt; + use ballista_core::error::Result; + use datafusion::arrow; + use datafusion::arrow::util::pretty::pretty_format_batches; + use datafusion::config::TableParquetOptions; + use datafusion::dataframe::DataFrameWriteOptions; + use datafusion::prelude::ParquetReadOptions; + use datafusion::prelude::SessionContext; + use std::fs::File; + use std::io::Write; + use tempfile::TempDir; + + #[tokio::test] + async fn test_standalone_mode() { + let context = SessionContext::standalone().await.unwrap(); + let df = context.sql("SELECT 1;").await.unwrap(); + df.collect().await.unwrap(); + } + + #[tokio::test] + async fn test_write_parquet() -> Result<()> { + let context = SessionContext::standalone().await?; + let df = context.sql("SELECT 1;").await?; + let tmp_dir = TempDir::new().unwrap(); + let file_path = format!( + "{}", + tmp_dir.path().join("test_write_parquet.parquet").display() + ); + df.write_parquet( + &file_path, + DataFrameWriteOptions::default(), + Some(TableParquetOptions::default()), + ) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_write_csv() -> Result<()> { + let context = SessionContext::standalone().await?; + let df = context.sql("SELECT 1;").await?; + let tmp_dir = TempDir::new().unwrap(); + let file_path = + format!("{}", tmp_dir.path().join("test_write_csv.csv").display()); + df.write_csv(&file_path, DataFrameWriteOptions::default(), None) + .await?; + Ok(()) + } + + #[tokio::test] + async fn test_ballista_show_tables() { + let context = SessionContext::standalone().await.unwrap(); + + let data = "Jorge,2018-12-13T12:12:10.011Z\n\ + Andrew,2018-11-13T17:11:10.011Z"; + + let tmp_dir = TempDir::new().unwrap(); + let file_path = tmp_dir.path().join("timestamps.csv"); + + // scope to ensure the file is closed and written + { + File::create(&file_path) + .expect("creating temp file") + .write_all(data.as_bytes()) + .expect("writing data"); + } + + let sql = format!( + "CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP + ) + STORED AS CSV + LOCATION '{}' + OPTIONS ('has_header' 'false', 'delimiter' ',') + ", + file_path.to_str().expect("path is utf8") + ); + + context.sql(sql.as_str()).await.unwrap(); + + let df = context.sql("show columns from csv_with_timestamps;").await; + + // used to fail with ballista context + // assert!(df.is_err()); + assert!(df.is_ok()); + + let result = df.unwrap().collect().await.unwrap(); + + let expected = 
["+---------------+--------------+---------------------+-------------+-----------------------------+-------------+", + "| table_catalog | table_schema | table_name | column_name | data_type | is_nullable |", + "+---------------+--------------+---------------------+-------------+-----------------------------+-------------+", + "| datafusion | public | csv_with_timestamps | name | Utf8 | YES |", + "| datafusion | public | csv_with_timestamps | ts | Timestamp(Nanosecond, None) | YES |", + "+---------------+--------------+---------------------+-------------+-----------------------------+-------------+"]; + datafusion::assert_batches_eq!(expected, &result); + } + + #[tokio::test] + async fn test_show_tables_not_with_information_schema() { + let context = SessionContext::standalone().await.unwrap(); + + let data = "Jorge,2018-12-13T12:12:10.011Z\n\ + Andrew,2018-11-13T17:11:10.011Z"; + + let tmp_dir = TempDir::new().unwrap(); + let file_path = tmp_dir.path().join("timestamps.csv"); + + // scope to ensure the file is closed and written + { + File::create(&file_path) + .expect("creating temp file") + .write_all(data.as_bytes()) + .expect("writing data"); + } + + let sql = format!( + "CREATE EXTERNAL TABLE csv_with_timestamps ( + name VARCHAR, + ts TIMESTAMP + ) + STORED AS CSV + LOCATION '{}' + ", + file_path.to_str().expect("path is utf8") + ); + + context.sql(sql.as_str()).await.unwrap(); + let df = context.sql("show tables;").await; + assert!(df.is_ok()); + } + #[tokio::test] + async fn test_empty_exec_with_one_row() { + let context = SessionContext::standalone().await.unwrap(); + + let sql = "select EXTRACT(year FROM to_timestamp('2020-09-08T12:13:14+00:00'));"; + + let df = context.sql(sql).await.unwrap(); + assert!(!df.collect().await.unwrap().is_empty()); + } + + #[tokio::test] + async fn test_union_and_union_all() { + let context = SessionContext::standalone().await.unwrap(); + + let df = context + .sql("SELECT 1 as NUMBER union SELECT 1 as NUMBER;") + .await + .unwrap(); + let res1 = df.collect().await.unwrap(); + let expected1 = vec![ + "+--------+", + "| number |", + "+--------+", + "| 1 |", + "+--------+", + ]; + assert_eq!( + expected1, + pretty_format_batches(&res1) + .unwrap() + .to_string() + .trim() + .lines() + .collect::<Vec<&str>>() + ); + let expected2 = vec![ + "+--------+", + "| number |", + "+--------+", + "| 1 |", + "| 1 |", + "+--------+", + ]; + let df = context + .sql("SELECT 1 as NUMBER union all SELECT 1 as NUMBER;") + .await + .unwrap(); + let res2 = df.collect().await.unwrap(); + assert_eq!( + expected2, + pretty_format_batches(&res2) + .unwrap() + .to_string() + .trim() + .lines() + .collect::<Vec<&str>>() + ); + } + + #[tokio::test] + async fn test_aggregate_min_max() { + let context = create_test_context().await; + + let df = context.sql("select min(\"id\") from test").await.unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------+", + "| min(test.id) |", + "+--------------+", + "| 0 |", + "+--------------+", + ]; + assert_result_eq(expected, &res); + + let df = context.sql("select max(\"id\") from test").await.unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------+", + "| max(test.id) |", + "+--------------+", + "| 7 |", + "+--------------+", + ]; + assert_result_eq(expected, &res); + } + + #[tokio::test] + async fn test_aggregate_sum() { + let context = create_test_context().await; + + let df = context.sql("select SUM(\"id\") from test").await.unwrap(); + let res = 
df.collect().await.unwrap(); + let expected = vec![ + "+--------------+", + "| sum(test.id) |", + "+--------------+", + "| 28 |", + "+--------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_avg() { + let context = create_test_context().await; + + let df = context.sql("select AVG(\"id\") from test").await.unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------+", + "| avg(test.id) |", + "+--------------+", + "| 3.5 |", + "+--------------+", + ]; + assert_result_eq(expected, &res); + } + + #[tokio::test] + async fn test_aggregate_count() { + let context = create_test_context().await; + + let df = context.sql("select COUNT(\"id\") from test").await.unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+----------------+", + "| count(test.id) |", + "+----------------+", + "| 8 |", + "+----------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_approx_distinct() { + let context = create_test_context().await; + + let df = context + .sql("select approx_distinct(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------------+", + "| approx_distinct(test.id) |", + "+--------------------------+", + "| 8 |", + "+--------------------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_array_agg() { + let context = create_test_context().await; + + let df = context + .sql("select ARRAY_AGG(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------------+", + "| array_agg(test.id) |", + "+--------------------------+", + "| [4, 5, 6, 7, 2, 3, 0, 1] |", + "+--------------------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_var() { + let context = create_test_context().await; + + let df = context.sql("select VAR(\"id\") from test").await.unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+-------------------+", + "| var(test.id) |", + "+-------------------+", + "| 6.000000000000001 |", + "+-------------------+", + ]; + assert_result_eq(expected, &res); + + let df = context + .sql("select VAR_POP(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+-------------------+", + "| var_pop(test.id) |", + "+-------------------+", + "| 5.250000000000001 |", + "+-------------------+", + ]; + assert_result_eq(expected, &res); + + let df = context + .sql("select VAR_SAMP(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+-------------------+", + "| var(test.id) |", + "+-------------------+", + "| 6.000000000000001 |", + "+-------------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_stddev() { + let context = create_test_context().await; + + let df = context + .sql("select STDDEV(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------+", + "| stddev(test.id) |", + "+--------------------+", + "| 2.4494897427831783 |", + "+--------------------+", + ]; + assert_result_eq(expected, &res); + + let df = context + .sql("select STDDEV_SAMP(\"id\") from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------+", + 
"| stddev(test.id) |", + "+--------------------+", + "| 2.4494897427831783 |", + "+--------------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_covar() { + let context = create_test_context().await; + + let df = context + .sql("select COVAR(id, tinyint_col) from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------------------------+", + "| covar_samp(test.id,test.tinyint_col) |", + "+--------------------------------------+", + "| 0.28571428571428586 |", + "+--------------------------------------+", + ]; + assert_result_eq(expected, &res); + } + #[tokio::test] + async fn test_aggregate_correlation() { + let context = create_test_context().await; + + let df = context + .sql("select CORR(id, tinyint_col) from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+--------------------------------+", + "| corr(test.id,test.tinyint_col) |", + "+--------------------------------+", + "| 0.21821789023599245 |", + "+--------------------------------+", + ]; + assert_result_eq(expected, &res); + } + // enable when upgrading Datafusion to > 42 + #[ignore] + #[tokio::test] + async fn test_aggregate_approx_percentile() { + let context = create_test_context().await; + + let df = context + .sql("select approx_percentile_cont_with_weight(id, 2, 0.5) from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+-------------------------------------------------------------------+", + "| approx_percentile_cont_with_weight(test.id,Int64(2),Float64(0.5)) |", + "+-------------------------------------------------------------------+", + "| 1 |", + "+-------------------------------------------------------------------+", + ]; + assert_result_eq(expected, &res); + + let df = context + .sql("select approx_percentile_cont(\"double_col\", 0.5) from test") + .await + .unwrap(); + let res = df.collect().await.unwrap(); + let expected = vec![ + "+------------------------------------------------------+", + "| approx_percentile_cont(test.double_col,Float64(0.5)) |", + "+------------------------------------------------------+", + "| 7.574999999999999 |", + "+------------------------------------------------------+", + ]; + + assert_result_eq(expected, &res); + } + + fn assert_result_eq( + expected: Vec<&str>, + results: &[arrow::record_batch::RecordBatch], + ) { + assert_eq!( + expected, + pretty_format_batches(results) + .unwrap() + .to_string() + .trim() + .lines() + .collect::<Vec<&str>>() + ); + } + async fn create_test_context() -> SessionContext { + let context = SessionContext::standalone().await.unwrap(); + + context + .register_parquet( + "test", + "testdata/alltypes_plain.parquet", + ParquetReadOptions::default(), + ) + .await + .unwrap(); + context + } +} diff --git a/ballista/client/tests/remote.rs b/ballista/client/tests/remote.rs new file mode 100644 index 000000000..619c4cd62 --- /dev/null +++ b/ballista/client/tests/remote.rs @@ -0,0 +1,145 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. 
You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +#[cfg(test)] +mod remote { + use ballista::extension::SessionContextExt; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + + #[tokio::test] + async fn should_execute_sql_show() -> datafusion::error::Result<()> { + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + + let test_data = crate::common::example_test_data(); + let ctx: SessionContext = SessionContext::remote(&url).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_execute_sql_write() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + + let ctx: SessionContext = SessionContext::remote(&url).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + log::info!("writing to parquet .. {}", write_dir_path); + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + + log::info!("registering parquet .. {}", write_dir_path); + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + log::info!("reading from written parquet .."); + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + log::info!("reading from written parquet .. 
DONE"); + assert_batches_eq!(expected, &result); + Ok(()) + } + + #[tokio::test] + async fn should_execute_show_tables() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let (host, port) = crate::common::setup_test_cluster().await; + let url = format!("df://{host}:{port}"); + + let ctx: SessionContext = SessionContext::remote(&url).await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx.sql("show tables").await?.collect().await?; + // + let expected = [ + "+---------------+--------------------+-------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+-------------+------------+", + "| datafusion | public | test | BASE TABLE |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | views | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | df_settings | VIEW |", + "| datafusion | information_schema | schemata | VIEW |", + "+---------------+--------------------+-------------+------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } +} diff --git a/ballista/client/tests/standalone.rs b/ballista/client/tests/standalone.rs new file mode 100644 index 000000000..b483a7c21 --- /dev/null +++ b/ballista/client/tests/standalone.rs @@ -0,0 +1,444 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +mod common; + +#[cfg(test)] +#[cfg(feature = "standalone")] +mod standalone { + use ballista::{extension::SessionContextExt, prelude::*}; + use datafusion::prelude::*; + use datafusion::{assert_batches_eq, prelude::SessionContext}; + + #[tokio::test] + async fn should_execute_sql_show() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("select string_col, timestamp_col from test where id > 4") + .await? 
+ .collect() + .await?; + let expected = [ + "+------------+---------------------+", + "| string_col | timestamp_col |", + "+------------+---------------------+", + "| 31 | 2009-03-01T00:01:00 |", + "| 30 | 2009-04-01T00:00:00 |", + "| 31 | 2009-04-01T00:01:00 |", + "+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_execute_sql_show_configs() -> datafusion::error::Result<()> { + let ctx: SessionContext = SessionContext::standalone().await?; + + let result = ctx + .sql("select name from information_schema.df_settings where name like 'datafusion.%' order by name limit 5") + .await? + .collect() + .await?; + // + let expected = [ + "+------------------------------------------------------+", + "| name |", + "+------------------------------------------------------+", + "| datafusion.catalog.create_default_catalog_and_schema |", + "| datafusion.catalog.default_catalog |", + "| datafusion.catalog.default_schema |", + "| datafusion.catalog.format |", + "| datafusion.catalog.has_header |", + "+------------------------------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_execute_sql_show_configs_ballista() -> datafusion::error::Result<()> { + let ctx: SessionContext = SessionContext::standalone().await?; + let state = ctx.state(); + let ballista_config_extension = + state.config().options().extensions.get::<BallistaConfig>(); + + // ballista configuration should be registered with + // session state + assert!(ballista_config_extension.is_some()); + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.%' order by name limit 5") + .await? + .collect() + .await?; + + let expected = [ + "+---------------------------------------------------------+----------+", + "| name | value |", + "+---------------------------------------------------------+----------+", + "| ballista.batch.size | 8192 |", + "| ballista.collect_statistics | false |", + "| ballista.grpc_client_max_message_size | 16777216 |", + "| ballista.job.name | |", + "| ballista.optimizer.hash_join_single_partition_threshold | 1048576 |", + "+---------------------------------------------------------+----------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_execute_sql_set_configs() -> datafusion::error::Result<()> { + let ctx: SessionContext = SessionContext::standalone().await?; + + ctx.sql("SET ballista.job.name = 'Super Cool Ballista App'") + .await? + .show() + .await?; + + let result = ctx + .sql("select name, value from information_schema.df_settings where name like 'ballista.job.name' order by name limit 1") + .await? 
+ .collect() + .await?; + + let expected = [ + "+-------------------+-------------------------+", + "| name | value |", + "+-------------------+-------------------------+", + "| ballista.job.name | Super Cool Ballista App |", + "+-------------------+-------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + // select from ballista config + // check for SET = + + #[tokio::test] + async fn should_execute_show_tables() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx.sql("show tables").await?.collect().await?; + // + let expected = [ + "+---------------+--------------------+-------------+------------+", + "| table_catalog | table_schema | table_name | table_type |", + "+---------------+--------------------+-------------+------------+", + "| datafusion | public | test | BASE TABLE |", + "| datafusion | information_schema | tables | VIEW |", + "| datafusion | information_schema | views | VIEW |", + "| datafusion | information_schema | columns | VIEW |", + "| datafusion | information_schema | df_settings | VIEW |", + "| datafusion | information_schema | schemata | VIEW |", + "+---------------+--------------------+-------------+------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + // + // TODO: It calls scheduler to generate the plan, but no + // but there is no ShuffleRead/Write in physical_plan + // + // ShuffleWriterExec: None, metrics=[output_rows=2, input_rows=2, write_time=1.782295ms, repart_time=1ns] + // ExplainExec, metrics=[] + // + #[tokio::test] + #[ignore = "It uses local files, will fail in CI"] + async fn should_execute_sql_explain() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + + let result = ctx + .sql("EXPLAIN select count(*), id from test where id > 4 group by id") + .await? 
+ .collect() + .await?; + + let expected = vec![ + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| plan_type | plan |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + "| logical_plan | Projection: count(*), test.id |", + "| | Aggregate: groupBy=[[test.id]], aggr=[[count(Int64(1)) AS count(*)]] |", + "| | Filter: test.id > Int32(4) |", + "| | TableScan: test projection=[id], partial_filters=[test.id > Int32(4)] |", + "| physical_plan | ProjectionExec: expr=[count(*)@1 as count(*), id@0 as id] |", + "| | AggregateExec: mode=FinalPartitioned, gby=[id@0 as id], aggr=[count(*)] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | RepartitionExec: partitioning=Hash([id@0], 16), input_partitions=1 |", + "| | AggregateExec: mode=Partial, gby=[id@0 as id], aggr=[count(*)] |", + "| | CoalesceBatchesExec: target_batch_size=8192 |", + "| | FilterExec: id@0 > 4 |", + "| | ParquetExec: file_groups={1 group: [[Users/ballista/git/arrow-ballista/ballista/client/testdata/alltypes_plain.parquet]]}, projection=[id], predicate=id@0 > 4, pruning_predicate=CASE WHEN id_null_count@1 = id_row_count@2 THEN false ELSE id_max@0 > 4 END, required_guarantees=[] |", + "| | |", + "+---------------+--------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + async fn should_execute_sql_create_external_table() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.sql(&format!("CREATE EXTERNAL TABLE tbl_test STORED AS PARQUET LOCATION '{}/alltypes_plain.parquet'", test_data, )).await?.show().await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from tbl_test where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] + async fn should_execute_sql_create_table() -> datafusion::error::Result<()> { + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.sql("CREATE TABLE tbl_test (id INT, value INT)") + .await? + .show() + .await?; + + // it does create table but it can't be queried + let _result = ctx + .sql("select * from tbl_test where id > 0") + .await? 
+ .collect() + .await?; + + Ok(()) + } + + #[tokio::test] + async fn should_execute_dataframe() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await? + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(5)))?; + + let result = df.collect().await?; + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + #[ignore = "Error serializing custom table - NotImplemented(LogicalExtensionCodec is not provided))"] + async fn should_execute_dataframe_cache() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + + let df = ctx + .read_parquet( + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await? + .select_columns(&["id", "bool_col", "timestamp_col"])? + .filter(col("id").gt(lit(5)))?; + + let cached_df = df.cache().await?; + let result = cached_df.collect().await?; + + let expected = [ + "+----+----------+---------------------+", + "| id | bool_col | timestamp_col |", + "+----+----------+---------------------+", + "| 6 | true | 2009-04-01T00:00:00 |", + "| 7 | false | 2009-04-01T00:01:00 |", + "+----+----------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + #[ignore = "Error: Internal(failed to serialize logical plan: Internal(LogicalPlan serde is not yet implemented for Dml))"] + async fn should_execute_sql_insert() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let _ = ctx + .sql("INSERT INTO written_table select * from written_table") + .await? + .collect() + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4 order by id") + .await? 
+ .collect() + .await?; + + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + + Ok(()) + } + + #[tokio::test] + #[cfg(not(windows))] // test is failing at windows, can't debug it + async fn should_execute_sql_write() -> datafusion::error::Result<()> { + let test_data = crate::common::example_test_data(); + + let ctx: SessionContext = SessionContext::standalone().await?; + ctx.register_parquet( + "test", + &format!("{test_data}/alltypes_plain.parquet"), + Default::default(), + ) + .await?; + let write_dir = tempfile::tempdir().expect("temporary directory to be created"); + let write_dir_path = write_dir + .path() + .to_str() + .expect("path to be converted to str"); + + ctx.sql("select * from test") + .await? + .write_parquet(write_dir_path, Default::default(), Default::default()) + .await?; + ctx.register_parquet("written_table", write_dir_path, Default::default()) + .await?; + + let result = ctx + .sql("select id, string_col, timestamp_col from written_table where id > 4") + .await? + .collect() + .await?; + let expected = [ + "+----+------------+---------------------+", + "| id | string_col | timestamp_col |", + "+----+------------+---------------------+", + "| 5 | 31 | 2009-03-01T00:01:00 |", + "| 6 | 30 | 2009-04-01T00:00:00 |", + "| 7 | 31 | 2009-04-01T00:01:00 |", + "+----+------------+---------------------+", + ]; + + assert_batches_eq!(expected, &result); + Ok(()) + } +} diff --git a/ballista/core/src/config.rs b/ballista/core/src/config.rs index db47a7e5b..88cba1d9a 100644 --- a/ballista/core/src/config.rs +++ b/ballista/core/src/config.rs @@ -25,7 +25,12 @@ use std::result; use crate::error::{BallistaError, Result}; -use datafusion::arrow::datatypes::DataType; +use datafusion::{ + arrow::datatypes::DataType, common::config_err, config::ConfigExtension, +}; + +// TODO: to be revisited, do we need all of them or +// we can reuse datafusion properties pub const BALLISTA_JOB_NAME: &str = "ballista.job.name"; pub const BALLISTA_DEFAULT_SHUFFLE_PARTITIONS: &str = "ballista.shuffle.partitions"; @@ -37,6 +42,7 @@ pub const BALLISTA_REPARTITION_AGGREGATIONS: &str = "ballista.repartition.aggreg pub const BALLISTA_REPARTITION_WINDOWS: &str = "ballista.repartition.windows"; pub const BALLISTA_PARQUET_PRUNING: &str = "ballista.parquet.pruning"; pub const BALLISTA_COLLECT_STATISTICS: &str = "ballista.collect_statistics"; +pub const BALLISTA_STANDALONE_PARALLELISM: &str = "ballista.standalone.parallelism"; pub const BALLISTA_WITH_INFORMATION_SCHEMA: &str = "ballista.with_information_schema"; @@ -198,13 +204,14 @@ impl BallistaConfig { "Sets whether enable information_schema".to_string(), DataType::Boolean, Some("false".to_string())), ConfigEntry::new(BALLISTA_HASH_JOIN_SINGLE_PARTITION_THRESHOLD.to_string(), - "Sets threshold in bytes for collecting the smaller side of the hash join in memory".to_string(), - DataType::UInt64, Some((1024 * 1024).to_string())), + "Sets threshold in bytes for collecting the smaller side of the hash join in memory".to_string(), + DataType::UInt64, Some((1024 * 1024).to_string())), ConfigEntry::new(BALLISTA_COLLECT_STATISTICS.to_string(), - 
"Configuration for collecting statistics during scan".to_string(), - DataType::Boolean, Some("false".to_string()) - ), - + "Configuration for collecting statistics during scan".to_string(), + DataType::Boolean, Some("false".to_string())), + ConfigEntry::new(BALLISTA_STANDALONE_PARALLELISM.to_string(), + "Standalone processing parallelism ".to_string(), + DataType::UInt16, Some(std::thread::available_parallelism().map(|v| v.get()).unwrap_or(1).to_string())), ConfigEntry::new(BALLISTA_GRPC_CLIENT_MAX_MESSAGE_SIZE.to_string(), "Configuration for max message size in gRPC clients".to_string(), DataType::UInt64, @@ -256,6 +263,10 @@ impl BallistaConfig { self.get_bool_setting(BALLISTA_COLLECT_STATISTICS) } + pub fn default_standalone_parallelism(&self) -> usize { + self.get_usize_setting(BALLISTA_STANDALONE_PARALLELISM) + } + pub fn default_with_information_schema(&self) -> bool { self.get_bool_setting(BALLISTA_WITH_INFORMATION_SCHEMA) } @@ -297,6 +308,49 @@ impl BallistaConfig { } } +impl datafusion::config::ExtensionOptions for BallistaConfig { + fn as_any(&self) -> &dyn std::any::Any { + self + } + + fn as_any_mut(&mut self) -> &mut dyn std::any::Any { + self + } + + fn cloned(&self) -> Box<dyn datafusion::config::ExtensionOptions> { + Box::new(self.clone()) + } + + fn set(&mut self, key: &str, value: &str) -> datafusion::error::Result<()> { + // TODO: this is just temporary until i figure it out + // what to do with it + let entries = Self::valid_entries(); + let k = format!("{}.{key}", BallistaConfig::PREFIX); + + if entries.contains_key(&k) { + self.settings.insert(k, value.to_string()); + Ok(()) + } else { + config_err!("configuration key `{}` does not exist", key) + } + } + + fn entries(&self) -> Vec<datafusion::config::ConfigEntry> { + Self::valid_entries() + .into_iter() + .map(|(key, value)| datafusion::config::ConfigEntry { + key: key.clone(), + value: self.settings.get(&key).cloned().or(value.default_value), + description: "", + }) + .collect() + } +} + +impl datafusion::config::ConfigExtension for BallistaConfig { + const PREFIX: &'static str = "ballista"; +} + // an enum used to configure the scheduler policy // needs to be visible to code generated by configure_me #[derive(Clone, ValueEnum, Copy, Debug, serde::Deserialize)] diff --git a/ballista/core/src/error.rs b/ballista/core/src/error.rs index 1ef795dfe..95bee2bf1 100644 --- a/ballista/core/src/error.rs +++ b/ballista/core/src/error.rs @@ -188,9 +188,9 @@ impl Display for BallistaError { BallistaError::General(ref desc) => write!(f, "General error: {desc}"), BallistaError::ArrowError(ref desc) => write!(f, "Arrow error: {desc}"), BallistaError::DataFusionError(ref desc) => { - write!(f, "DataFusion error: {desc:?}") + write!(f, "DataFusion error: {desc}") } - BallistaError::SqlError(ref desc) => write!(f, "SQL error: {desc:?}"), + BallistaError::SqlError(ref desc) => write!(f, "SQL error: {desc}"), BallistaError::IoError(ref desc) => write!(f, "IO error: {desc}"), // BallistaError::ReqwestError(ref desc) => write!(f, "Reqwest error: {}", desc), // BallistaError::HttpError(ref desc) => write!(f, "HTTP error: {}", desc), diff --git a/ballista/core/src/serde/mod.rs b/ballista/core/src/serde/mod.rs index fce4a399e..0381ba6f7 100644 --- a/ballista/core/src/serde/mod.rs +++ b/ballista/core/src/serde/mod.rs @@ -125,14 +125,22 @@ pub struct BallistaLogicalExtensionCodec { } impl BallistaLogicalExtensionCodec { + // looks for a codec which can operate on this node + // returns a position of codec in the list. 
+ // + // position is important with encoding process + // as there is a need to remember which codec + // in the list was used to encode message, + // so we can use it for decoding as well + fn try_any<T>( &self, mut f: impl FnMut(&dyn LogicalExtensionCodec) -> Result<T>, - ) -> Result<T> { + ) -> Result<(u8, T)> { let mut last_err = None; - for codec in &self.file_format_codecs { + for (position, codec) in self.file_format_codecs.iter().enumerate() { match f(codec.as_ref()) { - Ok(node) => return Ok(node), + Ok(node) => return Ok((position as u8, node)), Err(err) => last_err = Some(err), } } @@ -202,7 +210,19 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { buf: &[u8], ctx: &datafusion::prelude::SessionContext, ) -> Result<Arc<dyn datafusion::datasource::file_format::FileFormatFactory>> { - self.try_any(|codec| codec.try_decode_file_format(buf, ctx)) + if !buf.is_empty() { + // gets codec id from input buffer + let codec_number = buf[0]; + let codec = self.file_format_codecs.get(codec_number as usize).ok_or( + DataFusionError::NotImplemented("Can't find required codex".to_owned()), + )?; + + codec.try_decode_file_format(&buf[1..], ctx) + } else { + Err(DataFusionError::NotImplemented( + "File format blob should have more than 0 bytes".to_owned(), + )) + } } fn try_encode_file_format( @@ -210,7 +230,18 @@ impl LogicalExtensionCodec for BallistaLogicalExtensionCodec { buf: &mut Vec<u8>, node: Arc<dyn datafusion::datasource::file_format::FileFormatFactory>, ) -> Result<()> { - self.try_any(|codec| codec.try_encode_file_format(buf, node.clone())) + let mut encoded_format = vec![]; + let (codec_number, _) = self.try_any(|codec| { + codec.try_encode_file_format(&mut encoded_format, node.clone()) + })?; + // we need to remember which codec in the list was used to + // encode this node. + buf.push(codec_number); + + // save actual encoded node + buf.append(&mut encoded_format); + + Ok(()) } } @@ -397,3 +428,51 @@ impl PhysicalExtensionCodec for BallistaPhysicalExtensionCodec { } } } + +#[cfg(test)] +mod test { + use datafusion::{ + common::DFSchema, + datasource::file_format::{parquet::ParquetFormatFactory, DefaultFileType}, + logical_expr::{dml::CopyTo, EmptyRelation, LogicalPlan}, + prelude::SessionContext, + }; + use datafusion_proto::{logical_plan::AsLogicalPlan, protobuf::LogicalPlanNode}; + use std::sync::Arc; + + #[tokio::test] + async fn file_format_serialization_roundtrip() { + let ctx = SessionContext::new(); + let empty = EmptyRelation { + produce_one_row: false, + schema: Arc::new(DFSchema::empty()), + }; + let file_type = + Arc::new(DefaultFileType::new(Arc::new(ParquetFormatFactory::new()))); + let original_plan = LogicalPlan::Copy(CopyTo { + input: Arc::new(LogicalPlan::EmptyRelation(empty)), + output_url: "/tmp/file".to_string(), + partition_by: vec![], + file_type, + options: Default::default(), + }); + + let codec = crate::serde::BallistaLogicalExtensionCodec::default(); + let plan_message = + LogicalPlanNode::try_from_logical_plan(&original_plan, &codec).unwrap(); + + let mut buf: Vec<u8> = vec![]; + plan_message.try_encode(&mut buf).unwrap(); + println!("{}", original_plan.display_indent()); + + let decoded_message = LogicalPlanNode::try_decode(&buf).unwrap(); + let decoded_plan = decoded_message.try_into_logical_plan(&ctx, &codec).unwrap(); + + println!("{}", decoded_plan.display_indent()); + let o = original_plan.display_indent(); + let d = decoded_plan.display_indent(); + + assert_eq!(o.to_string(), d.to_string()) + //logical_plan. 
diff --git a/ballista/core/src/utils.rs b/ballista/core/src/utils.rs
index 7e88ffaf3..eceb9d447 100644
--- a/ballista/core/src/utils.rs
+++ b/ballista/core/src/utils.rs
@@ -22,6 +22,7 @@ use crate::execution_plans::{
 };
 use crate::object_store_registry::with_object_store_registry;
 use crate::serde::scheduler::PartitionStats;
+use crate::serde::BallistaLogicalExtensionCodec;
 use async_trait::async_trait;
 
 use datafusion::arrow::datatypes::Schema;
@@ -29,6 +30,7 @@ use datafusion::arrow::ipc::writer::IpcWriteOptions;
 use datafusion::arrow::ipc::writer::StreamWriter;
 use datafusion::arrow::ipc::CompressionType;
 use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::common::tree_node::{TreeNode, TreeNodeVisitor};
 use datafusion::datasource::physical_plan::{CsvExec, ParquetExec};
 use datafusion::error::DataFusionError;
 use datafusion::execution::context::{
@@ -36,7 +38,7 @@ use datafusion::execution::context::{
 };
 use datafusion::execution::runtime_env::{RuntimeConfig, RuntimeEnv};
 use datafusion::execution::session_state::SessionStateBuilder;
-use datafusion::logical_expr::{DdlStatement, LogicalPlan};
+use datafusion::logical_expr::{DdlStatement, LogicalPlan, TableScan};
 use datafusion::physical_plan::aggregates::AggregateExec;
 use datafusion::physical_plan::coalesce_batches::CoalesceBatchesExec;
 use datafusion::physical_plan::coalesce_partitions::CoalescePartitionsExec;
@@ -47,9 +49,8 @@ use datafusion::physical_plan::metrics::MetricsSet;
 use datafusion::physical_plan::projection::ProjectionExec;
 use datafusion::physical_plan::sorts::sort::SortExec;
 use datafusion::physical_plan::{metrics, ExecutionPlan, RecordBatchStream};
-use datafusion_proto::logical_plan::{
-    AsLogicalPlan, DefaultLogicalExtensionCodec, LogicalExtensionCodec,
-};
+use datafusion::physical_planner::{DefaultPhysicalPlanner, PhysicalPlanner};
+use datafusion_proto::logical_plan::{AsLogicalPlan, LogicalExtensionCodec};
 use futures::StreamExt;
 use log::error;
 use std::io::{BufWriter, Write};
@@ -248,12 +249,18 @@ pub fn create_df_ctx_with_ballista_query_planner<T: 'static + AsLogicalPlan>(
     session_id: String,
     config: &BallistaConfig,
 ) -> SessionContext {
+    // TODO: make the ballista configuration part of the session state,
+    // so that the planner can get it from there.
+    // This would make it changeable at runtime
+    // using the SQL SET statement.
     let planner: Arc<BallistaQueryPlanner<T>> =
         Arc::new(BallistaQueryPlanner::new(scheduler_url, config.clone()));
 
     let session_config = SessionConfig::new()
         .with_target_partitions(config.default_shuffle_partitions())
-        .with_information_schema(true);
+        .with_information_schema(true)
+        .with_option_extension(config.clone());
+
     let session_state = SessionStateBuilder::new()
         .with_default_features()
         .with_config(session_config)
@@ -272,6 +279,7 @@ pub struct BallistaQueryPlanner<T: AsLogicalPlan> {
     scheduler_url: String,
     config: BallistaConfig,
     extension_codec: Arc<dyn LogicalExtensionCodec>,
+    local_planner: DefaultPhysicalPlanner,
     plan_repr: PhantomData<T>,
 }
 
@@ -280,7 +288,8 @@ impl<T: 'static + AsLogicalPlan> BallistaQueryPlanner<T> {
         Self {
             scheduler_url,
             config,
-            extension_codec: Arc::new(DefaultLogicalExtensionCodec {}),
+            extension_codec: Arc::new(BallistaLogicalExtensionCodec::default()),
+            local_planner: DefaultPhysicalPlanner::default(),
             plan_repr: PhantomData,
         }
     }
@@ -294,6 +303,7 @@ impl<T: 'static + AsLogicalPlan> BallistaQueryPlanner<T> {
             scheduler_url,
             config,
             extension_codec,
+            local_planner: DefaultPhysicalPlanner::default(),
             plan_repr: PhantomData,
         }
     }
@@ -309,6 +319,7 @@ impl<T: 'static + AsLogicalPlan> BallistaQueryPlanner<T> {
             config,
             extension_codec,
             plan_repr,
+            local_planner: DefaultPhysicalPlanner::default(),
         }
     }
 }
@@ -320,19 +331,43 @@ impl<T: 'static + AsLogicalPlan> QueryPlanner for BallistaQueryPlanner<T> {
         logical_plan: &LogicalPlan,
         session_state: &SessionState,
     ) -> std::result::Result<Arc<dyn ExecutionPlan>, DataFusionError> {
-        match logical_plan {
-            LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_)) => {
-                // table state is managed locally in the BallistaContext, not in the scheduler
-                Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
+        log::debug!("create_physical_plan - plan: {:?}", logical_plan);
+        // we inspect whether the plan scans only local tables,
+        // such as tables located in information_schema;
+        // if that is the case, we run the plan
+        // on this same context, not on the cluster
+        let mut local_run = LocalRun::default();
+        let _ = logical_plan.visit(&mut local_run);
+
+        if local_run.can_be_local {
+            log::debug!("create_physical_plan - local run");
+
+            self.local_planner
+                .create_physical_plan(logical_plan, session_state)
+                .await
+        } else {
+            match logical_plan {
+                LogicalPlan::Ddl(DdlStatement::CreateExternalTable(_t)) => {
+                    log::debug!("create_physical_plan - handling ddl statement");
+                    Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
+                }
+                LogicalPlan::EmptyRelation(_) => {
+                    log::debug!("create_physical_plan - handling empty exec");
+                    Ok(Arc::new(EmptyExec::new(Arc::new(Schema::empty()))))
+                }
+                _ => {
+                    log::debug!("create_physical_plan - handling general statement");
+
+                    Ok(Arc::new(DistributedQueryExec::with_repr(
+                        self.scheduler_url.clone(),
+                        self.config.clone(),
+                        logical_plan.clone(),
+                        self.extension_codec.clone(),
+                        self.plan_repr,
+                        session_state.session_id().to_string(),
+                    )))
+                }
             }
-            _ => Ok(Arc::new(DistributedQueryExec::with_repr(
-                self.scheduler_url.clone(),
-                self.config.clone(),
-                logical_plan.clone(),
-                self.extension_codec.clone(),
-                self.plan_repr,
-                session_state.session_id().to_string(),
-            ))),
         }
     }
 }
@@ -389,3 +424,128 @@ pub fn get_time_before(interval_seconds: u64) -> u64 {
         .unwrap_or_else(|| Duration::from_secs(0))
         .as_secs()
 }
+
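Note that `create_df_ctx_with_ballista_query_planner` above already registers the `BallistaConfig` on the `SessionConfig` via `with_option_extension`, so one way to address the TODO would be to let the planner read the options back from the `SessionState` it receives instead of holding its own copy. A rough sketch of that lookup, assuming DataFusion's config-extension accessors (illustrative only, not part of this patch):

    use crate::config::BallistaConfig;
    use datafusion::execution::context::SessionState;

    // Illustrative only: fetch the Ballista options that were registered on the
    // session config as a DataFusion config extension (prefix `ballista`).
    fn ballista_config_from_state(state: &SessionState) -> Option<&BallistaConfig> {
        state.config().options().extensions.get::<BallistaConfig>()
    }

With the options registered this way and the information schema enabled, the `ballista.*` keys also surface in `information_schema.df_settings`, which is exactly what the `should_detect_select_from_information_schema_as_local_plan` test below queries.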
+/// A Visitor which detects if a query is using local tables,
+/// such as tables located in `information_schema`, and returns true
+/// only if all scans are from local tables
+#[derive(Debug, Default)]
+struct LocalRun {
+    can_be_local: bool,
+}
+
+impl<'n> TreeNodeVisitor<'n> for LocalRun {
+    type Node = LogicalPlan;
+
+    fn f_down(
+        &mut self,
+        node: &'n Self::Node,
+    ) -> datafusion::error::Result<datafusion::common::tree_node::TreeNodeRecursion> {
+        match node {
+            LogicalPlan::TableScan(TableScan { table_name, .. }) => match table_name {
+                datafusion::sql::TableReference::Partial { schema, .. }
+                | datafusion::sql::TableReference::Full { schema, .. }
+                    if schema.as_ref() == "information_schema" =>
+                {
+                    self.can_be_local = true;
+                    Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue)
+                }
+                _ => {
+                    self.can_be_local = false;
+                    Ok(datafusion::common::tree_node::TreeNodeRecursion::Stop)
+                }
+            },
+            _ => Ok(datafusion::common::tree_node::TreeNodeRecursion::Continue),
+        }
+    }
+}
+
+#[cfg(test)]
+mod test {
+    use datafusion::{
+        common::tree_node::TreeNode,
+        error::Result,
+        execution::{
+            runtime_env::{RuntimeConfig, RuntimeEnv},
+            SessionStateBuilder,
+        },
+        prelude::{SessionConfig, SessionContext},
+    };
+
+    use crate::utils::LocalRun;
+
+    fn context() -> SessionContext {
+        let runtime_environment = RuntimeEnv::new(RuntimeConfig::new()).unwrap();
+
+        let session_config = SessionConfig::new().with_information_schema(true);
+
+        let state = SessionStateBuilder::new()
+            .with_config(session_config)
+            .with_runtime_env(runtime_environment.into())
+            .with_default_features()
+            .build();
+
+        SessionContext::new_with_state(state)
+    }
+
+    #[tokio::test]
+    async fn should_detect_show_table_as_local_plan() -> Result<()> {
+        let ctx = context();
+        let df = ctx.sql("SHOW TABLES").await?;
+        let lp = df.logical_plan();
+        let mut local_run = LocalRun::default();
+
+        lp.visit(&mut local_run).unwrap();
+
+        assert!(local_run.can_be_local);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn should_detect_select_from_information_schema_as_local_plan() -> Result<()> {
+        let ctx = context();
+        let df = ctx.sql("SELECT * FROM information_schema.df_settings WHERE NAME LIKE 'ballista%'").await?;
+        let lp = df.logical_plan();
+        let mut local_run = LocalRun::default();
+
+        lp.visit(&mut local_run).unwrap();
+
+        assert!(local_run.can_be_local);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn should_not_detect_local_table() -> Result<()> {
+        let ctx = context();
+        ctx.sql("CREATE TABLE tt (c0 INT, c1 INT)")
+            .await?
+            .show()
+            .await?;
+        let df = ctx.sql("SELECT * FROM tt").await?;
+        let lp = df.logical_plan();
+        let mut local_run = LocalRun::default();
+
+        lp.visit(&mut local_run).unwrap();
+
+        assert!(!local_run.can_be_local);
+
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn should_not_detect_external_table() -> Result<()> {
+        let ctx = context();
+        ctx.register_csv("tt", "tests/customer.csv", Default::default())
+            .await?;
+        let df = ctx.sql("SELECT * FROM tt").await?;
+        let lp = df.logical_plan();
+        let mut local_run = LocalRun::default();
+
+        lp.visit(&mut local_run).unwrap();
+
+        assert!(!local_run.can_be_local);
+
+        Ok(())
+    }
+}
diff --git a/ballista/scheduler/src/scheduler_server/grpc.rs b/ballista/scheduler/src/scheduler_server/grpc.rs
index 0d7d5e366..653bda834 100644
--- a/ballista/scheduler/src/scheduler_server/grpc.rs
+++ b/ballista/scheduler/src/scheduler_server/grpc.rs
@@ -512,7 +512,10 @@ impl<T: 'static + AsLogicalPlan, U: 'static + AsExecutionPlan> SchedulerGrpc
             }
         };
 
-        debug!("Received plan for execution: {:?}", plan);
+        debug!(
+            "Decoded logical plan for execution:\n{}",
+            plan.display_indent()
+        );
 
         let job_id = self.state.task_manager.generate_job_id();
         let job_name = query_settings