Skip to content

Commit

Permalink
finish
Browse files Browse the repository at this point in the history
  • Loading branch information
ugoa committed Feb 7, 2025
1 parent 7db4019 commit 6b80aa7
Showing 1 changed file with 338 additions and 13 deletions.
351 changes: 338 additions & 13 deletions docs/source/library-user-guide/custom-table-providers.md
Original file line number Diff line number Diff line change
Expand Up @@ -37,9 +37,9 @@ The `ExecutionPlan` trait at its core is a way to get a stream of batches. The a

There are many different types of `SendableRecordBatchStream` implemented in DataFusion -- you can use a pre-existing one, such as `MemoryStream` (if your `RecordBatch`es are all in memory), or implement your own custom logic, depending on your use case.

Looking at the [example in this repo][ex], the execute method:
Looking at the full example below:

```fixed
```rust
use std::any::Any;
use std::sync::{Arc, Mutex};
use std::collections::{BTreeMap, HashMap};
Expand All @@ -50,7 +50,7 @@ use datafusion::physical_plan::{
ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType,
Statistics, PlanProperties
};
use datafusion::execution::context::{SessionState, TaskContext};
use datafusion::execution::context::TaskContext;
use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
use datafusion::physical_plan::memory::MemoryStream;
use datafusion::arrow::record_batch::RecordBatch;
Expand Down Expand Up @@ -148,7 +148,7 @@ impl ExecutionPlan for CustomExec {
}
```

This:
This `execute` method:

1. Gets the users from the database
2. Constructs the individual output arrays (columns)
Expand All @@ -162,7 +162,135 @@ With the `ExecutionPlan` implemented, we can now implement the `scan` method of

The `scan` method of the `TableProvider` returns a `Result<Arc<dyn ExecutionPlan>>`. We can use the `Arc` to return a reference-counted pointer to the `ExecutionPlan` we implemented. In the example, this is done by:

```tofix
```rust

# use std::any::Any;
# use std::sync::{Arc, Mutex};
# use std::collections::{BTreeMap, HashMap};
# use datafusion::common::Result;
# use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
# use datafusion::physical_plan::expressions::PhysicalSortExpr;
# use datafusion::physical_plan::{
# ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType,
# Statistics, PlanProperties
# };
# use datafusion::execution::context::TaskContext;
# use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
# use datafusion::physical_plan::memory::MemoryStream;
# use datafusion::arrow::record_batch::RecordBatch;
#
# /// A User, with an id and a bank account
# #[derive(Clone, Debug)]
# struct User {
# id: u8,
# bank_account: u64,
# }
#
# /// A custom datasource, used to represent a datastore with a single index
# #[derive(Clone, Debug)]
# pub struct CustomDataSource {
# inner: Arc<Mutex<CustomDataSourceInner>>,
# }
#
# #[derive(Debug)]
# struct CustomDataSourceInner {
# data: HashMap<u8, User>,
# bank_account_index: BTreeMap<u64, u8>,
# }
#
# #[derive(Debug)]
# struct CustomExec {
# db: CustomDataSource,
# projected_schema: SchemaRef,
# }
#
# impl DisplayAs for CustomExec {
# fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
# write!(f, "CustomExec")
# }
# }
#
# impl ExecutionPlan for CustomExec {
# fn name(&self) -> &str {
# "CustomExec"
# }
#
# fn as_any(&self) -> &dyn Any {
# self
# }
#
# fn schema(&self) -> SchemaRef {
# self.projected_schema.clone()
# }
#
#
# fn properties(&self) -> &PlanProperties {
# unreachable!()
# }
#
# fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
# Vec::new()
# }
#
# fn with_new_children(
# self: Arc<Self>,
# _: Vec<Arc<dyn ExecutionPlan>>,
# ) -> Result<Arc<dyn ExecutionPlan>> {
# Ok(self)
# }
#
# fn execute(
# &self,
# _partition: usize,
# _context: Arc<TaskContext>,
# ) -> Result<SendableRecordBatchStream> {
# let users: Vec<User> = {
# let db = self.db.inner.lock().unwrap();
# db.data.values().cloned().collect()
# };
#
# let mut id_array = UInt8Builder::with_capacity(users.len());
# let mut account_array = UInt64Builder::with_capacity(users.len());
#
# for user in users {
# id_array.append_value(user.id);
# account_array.append_value(user.bank_account);
# }
#
# Ok(Box::pin(MemoryStream::try_new(
# vec![RecordBatch::try_new(
# self.projected_schema.clone(),
# vec![
# Arc::new(id_array.finish()),
# Arc::new(account_array.finish()),
# ],
# )?],
# self.schema(),
# None,
# )?))
# }
# }

use async_trait::async_trait;
use datafusion::logical_expr::expr::Expr;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::physical_plan::project_schema;
use datafusion::catalog::Session;

impl CustomExec {
    /// Builds a `CustomExec` over `db`, applying the optional column
    /// `projections` to `schema` to compute the schema this plan produces.
    fn new(
        projections: Option<&Vec<usize>>,
        schema: SchemaRef,
        db: CustomDataSource,
    ) -> Self {
        Self {
            projected_schema: project_schema(&schema, projections).unwrap(),
            db,
        }
    }
}

impl CustomDataSource {
pub(crate) async fn create_physical_plan(
&self,
Expand All @@ -175,6 +303,21 @@ impl CustomDataSource {

#[async_trait]
impl TableProvider for CustomDataSource {
    fn as_any(&self) -> &dyn Any {
        // Allows callers holding a `dyn TableProvider` to downcast back to
        // the concrete `CustomDataSource` type.
        self
    }

fn schema(&self) -> SchemaRef {
SchemaRef::new(Schema::new(vec![
Field::new("id", DataType::UInt8, false),
Field::new("bank_account", DataType::UInt64, true),
]))
}

    fn table_type(&self) -> TableType {
        // A `Base` table, i.e. backed by actual data (as opposed to a view).
        TableType::Base
    }

async fn scan(
&self,
_state: &dyn Session,
Expand Down Expand Up @@ -210,17 +353,199 @@ For filters that can be pushed down, they'll be passed to the `scan` method as t

In order to use the custom table provider, we need to register it with DataFusion. This is done by creating a `TableProvider` and registering it with the `SessionContext`.

```tofix
let ctx = SessionContext::new();
This will allow you to use the custom table provider in DataFusion. For example, you could use it in a SQL query to get a `DataFrame`.

let custom_table_provider = CustomDataSource::new();
ctx.register_table("custom_table", Arc::new(custom_table_provider));
```
```rust
# use std::any::Any;
# use std::sync::{Arc, Mutex};
# use std::collections::{BTreeMap, HashMap};
# use datafusion::common::Result;
# use datafusion::arrow::datatypes::{DataType, Field, Schema, SchemaRef};
# use datafusion::physical_plan::expressions::PhysicalSortExpr;
# use datafusion::physical_plan::{
# ExecutionPlan, SendableRecordBatchStream, DisplayAs, DisplayFormatType,
# Statistics, PlanProperties
# };
# use datafusion::execution::context::TaskContext;
# use datafusion::arrow::array::{UInt64Builder, UInt8Builder};
# use datafusion::physical_plan::memory::MemoryStream;
# use datafusion::arrow::record_batch::RecordBatch;
#
# /// A User, with an id and a bank account
# #[derive(Clone, Debug)]
# struct User {
# id: u8,
# bank_account: u64,
# }
#
# /// A custom datasource, used to represent a datastore with a single index
# #[derive(Clone, Debug)]
# pub struct CustomDataSource {
# inner: Arc<Mutex<CustomDataSourceInner>>,
# }
#
# #[derive(Debug)]
# struct CustomDataSourceInner {
# data: HashMap<u8, User>,
# bank_account_index: BTreeMap<u64, u8>,
# }
#
# #[derive(Debug)]
# struct CustomExec {
# db: CustomDataSource,
# projected_schema: SchemaRef,
# }
#
# impl DisplayAs for CustomExec {
# fn fmt_as(&self, _t: DisplayFormatType, f: &mut std::fmt::Formatter) -> std::fmt::Result {
# write!(f, "CustomExec")
# }
# }
#
# impl ExecutionPlan for CustomExec {
# fn name(&self) -> &str {
# "CustomExec"
# }
#
# fn as_any(&self) -> &dyn Any {
# self
# }
#
# fn schema(&self) -> SchemaRef {
# self.projected_schema.clone()
# }
#
#
# fn properties(&self) -> &PlanProperties {
# unreachable!()
# }
#
# fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
# Vec::new()
# }
#
# fn with_new_children(
# self: Arc<Self>,
# _: Vec<Arc<dyn ExecutionPlan>>,
# ) -> Result<Arc<dyn ExecutionPlan>> {
# Ok(self)
# }
#
# fn execute(
# &self,
# _partition: usize,
# _context: Arc<TaskContext>,
# ) -> Result<SendableRecordBatchStream> {
# let users: Vec<User> = {
# let db = self.db.inner.lock().unwrap();
# db.data.values().cloned().collect()
# };
#
# let mut id_array = UInt8Builder::with_capacity(users.len());
# let mut account_array = UInt64Builder::with_capacity(users.len());
#
# for user in users {
# id_array.append_value(user.id);
# account_array.append_value(user.bank_account);
# }
#
# Ok(Box::pin(MemoryStream::try_new(
# vec![RecordBatch::try_new(
# self.projected_schema.clone(),
# vec![
# Arc::new(id_array.finish()),
# Arc::new(account_array.finish()),
# ],
# )?],
# self.schema(),
# None,
# )?))
# }
# }

# use async_trait::async_trait;
# use datafusion::logical_expr::expr::Expr;
# use datafusion::datasource::{TableProvider, TableType};
# use datafusion::physical_plan::project_schema;
# use datafusion::catalog::Session;
#
# impl CustomExec {
# fn new(
# projections: Option<&Vec<usize>>,
# schema: SchemaRef,
# db: CustomDataSource,
# ) -> Self {
# let projected_schema = project_schema(&schema, projections).unwrap();
# Self {
# db,
# projected_schema,
# }
# }
# }
#
# impl CustomDataSource {
# pub(crate) async fn create_physical_plan(
# &self,
# projections: Option<&Vec<usize>>,
# schema: SchemaRef,
# ) -> Result<Arc<dyn ExecutionPlan>> {
# Ok(Arc::new(CustomExec::new(projections, schema, self.clone())))
# }
# }
#
# #[async_trait]
# impl TableProvider for CustomDataSource {
# fn as_any(&self) -> &dyn Any {
# self
# }
#
# fn schema(&self) -> SchemaRef {
# SchemaRef::new(Schema::new(vec![
# Field::new("id", DataType::UInt8, false),
# Field::new("bank_account", DataType::UInt64, true),
# ]))
# }
#
# fn table_type(&self) -> TableType {
# TableType::Base
# }
#
# async fn scan(
# &self,
# _state: &dyn Session,
# projection: Option<&Vec<usize>>,
# // filters and limit can be used here to inject some push-down operations if needed
# _filters: &[Expr],
# _limit: Option<usize>,
# ) -> Result<Arc<dyn ExecutionPlan>> {
# return self.create_physical_plan(projection, self.schema()).await;
# }
# }

use datafusion::execution::context::SessionContext;

impl Default for CustomDataSource {
    /// Creates an empty data source: no users and an empty
    /// bank-account index.
    fn default() -> Self {
        let inner = CustomDataSourceInner {
            data: Default::default(),
            bank_account_index: Default::default(),
        };
        CustomDataSource {
            inner: Arc::new(Mutex::new(inner)),
        }
    }
}

This will allow you to use the custom table provider in DataFusion. For example, you could use it in a SQL query to get a `DataFrame`.
#[tokio::main]
async fn main() -> Result<()> {
    let ctx = SessionContext::new();

    // Register the provider under the name used in the query below.
    // `register_table` returns a `Result` (it can fail e.g. on an invalid
    // table reference), so propagate the error instead of dropping it.
    let custom_table_provider = CustomDataSource::default();
    ctx.register_table("customers", Arc::new(custom_table_provider))?;

    // The custom provider can now be queried through SQL like any other
    // table; `df` is a lazily-evaluated `DataFrame`.
    let df = ctx.sql("SELECT id, bank_account FROM customers").await?;

    Ok(())
}

```tofix
let df = ctx.sql("SELECT id, bank_account FROM custom_table")?;
```

## Recap
Expand Down

0 comments on commit 6b80aa7

Please sign in to comment.