1 change: 1 addition & 0 deletions snowflake-api/Cargo.toml
@@ -21,6 +21,7 @@ polars = ["dep:polars-core", "dep:polars-io"]
[dependencies]
arrow = "53"
async-trait = "0.1"
async-stream = "0.3.5"
base64 = "0.22"
bytes = "1"
futures = "0.3"
119 changes: 112 additions & 7 deletions snowflake-api/src/lib.rs
@@ -15,8 +15,11 @@ clippy::missing_panics_doc

use std::fmt::{Display, Formatter};
use std::io;
use std::pin::Pin;
use std::sync::Arc;

use async_stream::stream;

use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::record_batch::RecordBatch;
@@ -27,7 +30,7 @@ use regex::Regex;
use reqwest_middleware::ClientWithMiddleware;
use thiserror::Error;

use responses::ExecResponse;
use responses::{ExecResponse, QueryExecResponseData};
use session::{AuthError, Session};

use crate::connection::QueryType;
@@ -36,6 +39,8 @@ use crate::requests::ExecRequest;
use crate::responses::{ExecResponseRowType, SnowflakeType};
use crate::session::AuthError::MissingEnvArgument;

use futures::{future, Stream, StreamExt};

pub mod connection;
#[cfg(feature = "polars")]
mod polars;
@@ -98,6 +103,8 @@ pub enum SnowflakeApiError {
GlobError(#[from] glob::GlobError),
}

const MAX_CHUNK_DOWNLOAD_WORKERS: usize = 10;

/// Even if Arrow is specified as a return type, non-select queries
/// will return a Json array of arrays: `[[42, "answer"], [43, "non-answer"]]`.
pub struct JsonResult {
@@ -144,24 +151,38 @@ pub enum QueryResult {
Empty,
}

pub type BytesStream = Pin<Box<dyn Stream<Item = Result<bytes::Bytes, SnowflakeApiError>> + Send>>;
pub type RecordBatchStream = Pin<Box<dyn Stream<Item = Result<RecordBatch, ArrowError>> + Send>>;
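Aside: these aliases also make it straightforward to adapt in-memory data into a byte stream, e.g. for exercising the stream-to-batches path in tests. A small hedged sketch (the helper name is hypothetical, not part of this PR):

```rust
use futures::stream;
use snowflake_api::{BytesStream, SnowflakeApiError};

// Hypothetical helper: wrap in-memory Arrow IPC payloads as a BytesStream,
// matching the Result<Bytes, SnowflakeApiError> item type of the alias.
fn bytes_stream_from_vec(payloads: Vec<bytes::Bytes>) -> BytesStream {
    Box::pin(stream::iter(
        payloads.into_iter().map(Ok::<_, SnowflakeApiError>),
    ))
}
```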

/// Raw query result
/// Can be transformed into [`QueryResult`]
pub enum RawQueryResult {
/// Arrow IPC chunks
/// see: <https://arrow.apache.org/docs/format/Columnar.html#serialization-and-interprocess-communication-ipc>
Bytes(Vec<Bytes>),
Stream(BytesStream),
/// Json payload is deserialized,
/// as it's already part of the REST response
Json(JsonResult),
Empty,
}

impl RawQueryResult {
pub fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
pub async fn deserialize_arrow(self) -> Result<QueryResult, ArrowError> {
match self {
RawQueryResult::Bytes(bytes) => {
Self::flat_bytes_to_batches(bytes).map(QueryResult::Arrow)
}
RawQueryResult::Stream(bytes_stream) => {
let arrow_records = Self::to_record_batches_stream(bytes_stream)
.collect::<Vec<Result<RecordBatch, ArrowError>>>()
.await;

// Surface the first decoding error instead of panicking on unwrap
let batches = arrow_records
.into_iter()
.collect::<Result<Vec<RecordBatch>, ArrowError>>()?;

Ok(QueryResult::Arrow(batches))
}
RawQueryResult::Json(j) => Ok(QueryResult::Json(j)),
RawQueryResult::Empty => Ok(QueryResult::Empty),
}
@@ -176,6 +197,23 @@ impl RawQueryResult {
Ok(res)
}

fn to_record_batches_stream(bytes_stream: BytesStream) -> RecordBatchStream {
let batch_stream = bytes_stream.flat_map(|bytes_result| match bytes_result {
Ok(bytes) => match Self::bytes_to_batches(bytes) {
Ok(batches) => futures::stream::iter(batches.into_iter().map(Ok)).boxed(),
Err(e) => futures::stream::once(async move { Err(e) }).boxed(),
},
Err(e) => futures::stream::once(async move {
Err(ArrowError::ParseError(format!(
"Unable to parse RecordBatch due to error in bytes stream: {e}"
)))
})
.boxed(),
});

Box::pin(batch_stream)
}

fn bytes_to_batches(bytes: Bytes) -> Result<Vec<RecordBatch>, ArrowError> {
let record_batches = StreamReader::try_new(bytes.reader(), None)?;
record_batches.into_iter().collect()
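The `Vec<RecordBatch>` return type above reflects that a single Arrow IPC payload can carry several batches. A self-contained round-trip sketch (all names below are illustrative, not taken from this PR) showing `StreamReader` recovering every batch written into one payload:

```rust
use std::sync::Arc;

use arrow::array::{ArrayRef, Int32Array};
use arrow::datatypes::{DataType, Field, Schema};
use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use arrow::ipc::writer::StreamWriter;
use arrow::record_batch::RecordBatch;
use bytes::{Buf, Bytes};

// Write two batches into one IPC payload, then read them back:
// StreamReader yields every batch contained in the payload.
fn ipc_round_trip() -> Result<(), ArrowError> {
    let schema = Arc::new(Schema::new(vec![Field::new("n", DataType::Int32, false)]));
    let column: ArrayRef = Arc::new(Int32Array::from(vec![1, 2, 3]));
    let batch = RecordBatch::try_new(schema.clone(), vec![column])?;

    let mut buf = Vec::new();
    {
        let mut writer = StreamWriter::try_new(&mut buf, &schema)?;
        writer.write(&batch)?;
        writer.write(&batch)?;
        writer.finish()?;
    }

    let reader = StreamReader::try_new(Bytes::from(buf).reader(), None)?;
    let batches = reader.collect::<Result<Vec<RecordBatch>, ArrowError>>()?;
    assert_eq!(batches.len(), 2);
    Ok(())
}
```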
@@ -380,10 +418,23 @@ impl SnowflakeApi {
/// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
pub async fn exec(&self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
let raw = self.exec_raw(sql).await?;
let res = raw.deserialize_arrow()?;
let res = raw.deserialize_arrow().await?;
Ok(res)
}

/// Executes a single query against the API and returns a stream of `RecordBatch`es.
/// Result chunks are downloaded lazily, as the stream is consumed.
pub async fn exec_streamed(&self, sql: &str) -> Result<RecordBatchStream, SnowflakeApiError> {
let raw = self.exec_arrow_raw(sql, true).await?;
match raw {
RawQueryResult::Empty => Ok(Box::pin(futures::stream::empty())),
RawQueryResult::Stream(bytes_stream) => {
let arrow_stream = RawQueryResult::to_record_batches_stream(bytes_stream);
Ok(arrow_stream)
}
_ => Err(SnowflakeApiError::UnexpectedResponse),
}
}
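Since this adds a new public entry point, a minimal consumer sketch may help (the table name and the pre-authenticated `api` handle are assumptions, not part of this diff):

```rust
use futures::StreamExt;
use snowflake_api::{SnowflakeApi, SnowflakeApiError};

// Drain the stream batch by batch; each item is a Result<RecordBatch, ArrowError>,
// and SnowflakeApiError converts from ArrowError, so `?` propagates failures.
async fn print_row_counts(api: &SnowflakeApi) -> Result<(), SnowflakeApiError> {
    let mut batches = api.exec_streamed("SELECT * FROM some_large_table").await?;
    while let Some(batch) = batches.next().await {
        let batch = batch?;
        println!("received a batch with {} rows", batch.num_rows());
    }
    Ok(())
}
```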

/// Executes a single query against API.
/// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
/// Returns raw bytes in the Arrow response
@@ -395,7 +446,7 @@ impl SnowflakeApi {
log::info!("Detected PUT query");
self.exec_put(sql).await.map(|()| RawQueryResult::Empty)
} else {
self.exec_arrow_raw(sql).await
self.exec_arrow_raw(sql, false).await
}
}

@@ -429,7 +480,11 @@ impl SnowflakeApi {
.await
}

async fn exec_arrow_raw(&self, sql: &str) -> Result<RawQueryResult, SnowflakeApiError> {
async fn exec_arrow_raw(
&self,
sql: &str,
enable_streaming: bool,
) -> Result<RawQueryResult, SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::ArrowQuery)
.await?;
@@ -459,14 +514,19 @@ impl SnowflakeApi {
value,
schema: resp.data.rowtype.into_iter().map(Into::into).collect(),
}))
} else if let Some(base64) = resp.data.rowset_base64 {
// fixme: is it possible to give streaming interface?
} else if resp.data.rowset_base64.is_some() {
if enable_streaming {
return Ok(self.chunks_to_bytes_stream(&resp.data));
}

let mut chunks = try_join_all(resp.data.chunks.iter().map(|chunk| {
self.connection
.get_chunk(&chunk.url, &resp.data.chunk_headers)
}))
.await?;

let base64 = resp.data.rowset_base64.unwrap_or_default();

// fixme: should base64 chunk go first?
// fixme: if response is chunked is it both base64 + chunks or just chunks?
if !base64.is_empty() {
@@ -510,4 +570,49 @@ impl SnowflakeApi {

Ok(resp)
}

fn chunks_to_bytes_stream(&self, data: &QueryExecResponseData) -> RawQueryResult {
let chunk_urls = data
.chunks
.iter()
.map(|chunk| chunk.url.clone())
.collect::<Vec<String>>();
let chunk_headers = data.chunk_headers.clone();
let connection = self.connection.clone();
let base64 = data.rowset_base64.clone().unwrap_or_default();

let stream = stream! {
// Download chunks in groups of up to MAX_CHUNK_DOWNLOAD_WORKERS concurrent
// requests; each group completes before the next one starts.
let chunks_iter = chunk_urls.chunks(MAX_CHUNK_DOWNLOAD_WORKERS);

for chunk in chunks_iter {
let futures_batch = chunk.iter().map(|chunk_url| {
let headers = chunk_headers.clone();
let connection_clone = connection.clone();
async move {
connection_clone.get_chunk(chunk_url, &headers).await.map_err(SnowflakeApiError::from)
}
}).collect::<Vec<_>>();

let results = future::join_all(futures_batch).await;
for result in results {
yield result;
}
}

if !base64.is_empty() {
log::debug!("Got base64 encoded response");
match base64::engine::general_purpose::STANDARD.decode(&base64) {
Ok(bytes) => {
yield Ok(Bytes::from(bytes));
}
Err(e) => {
yield Err(SnowflakeApiError::from(e));
}
}
}
};

RawQueryResult::Stream(Box::pin(stream))
}
}
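A design note on `chunks_to_bytes_stream`: it downloads URLs in fixed groups of `MAX_CHUNK_DOWNLOAD_WORKERS` and waits for a whole group before starting the next, so one slow chunk stalls its group. A hedged alternative sketch (not part of this PR) with the same concurrency cap but a rolling window, assuming the same locals as in the PR (`chunk_urls`, `chunk_headers`, `connection`):

```rust
use futures::StreamExt;

// buffered(n) keeps up to n chunk downloads in flight at once and still
// yields results in their original order.
let downloads = futures::stream::iter(chunk_urls.into_iter().map(move |url| {
    let headers = chunk_headers.clone();
    let connection = connection.clone();
    async move {
        connection
            .get_chunk(&url, &headers)
            .await
            .map_err(SnowflakeApiError::from)
    }
}))
.buffered(MAX_CHUNK_DOWNLOAD_WORKERS);
```

The resulting stream yields `Result<Bytes, SnowflakeApiError>` items in order and could be boxed into `RawQueryResult::Stream` directly, with the trailing base64 rowset appended afterwards.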
1 change: 1 addition & 0 deletions snowflake-api/src/polars.rs
@@ -27,6 +27,7 @@ impl RawQueryResult {
RawQueryResult::Bytes(bytes) => dataframe_from_bytes(bytes),
RawQueryResult::Json(json) => dataframe_from_json(&json),
RawQueryResult::Empty => Ok(DataFrame::empty()),
RawQueryResult::Stream(_) => todo!(), // converting a streamed result into a DataFrame is not supported yet
}
}
}