Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions snowflake-api/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,15 +22,16 @@ serde_json = "1"
serde = { version = "1", features = ["derive"] }
url = "2"
uuid = { version = "1.4", features = ["v4"] }
arrow = "42"
arrow = "45"
base64 = "0.21"
regex = "1"
object_store = { version = "0.6", features = ["aws"] }
object_store = { version = "0.7", features = ["aws"] }
async-trait = "0.1"

[dev-dependencies]
anyhow = "1"
pretty_env_logger = "0.5.0"
clap = { version = "4", features = ["derive"] }
arrow = { version = "42", features = ["prettyprint"] }
tokio = { version = "1", features=["macros", "rt-multi-thread"] }
parquet = { version = "45", features = ["arrow", "snap"] }
arrow = { version = "45", features = ["prettyprint"] }
36 changes: 36 additions & 0 deletions snowflake-api/examples/filetransfer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,8 @@ use arrow::util::pretty::pretty_format_batches;
use clap::Parser;
use snowflake_api::{QueryResult, SnowflakeApi};
use std::fs;
use std::fs::File;
use parquet::arrow::arrow_reader::ParquetRecordBatchReaderBuilder;

extern crate snowflake_api;

Expand Down Expand Up @@ -43,6 +45,9 @@ struct Args {

#[arg(long)]
csv_path: String,

#[arg(long)]
output_path: String,
}

#[tokio::main]
Expand Down Expand Up @@ -111,7 +116,38 @@ async fn main() -> Result<()> {
}
}

log::info!("Copy table contents into a stage");
api.exec(
"COPY INTO @%OSCAR_AGE_MALE/output/ FROM OSCAR_AGE_MALE FILE_FORMAT = (TYPE = parquet COMPRESSION = NONE) HEADER = TRUE OVERWRITE = TRUE SINGLE = TRUE;"
).await?;

log::info!("Downloading Parquet files");
api.exec(&format!(
"GET @%OSCAR_AGE_MALE/output/ file://{}",
&args.output_path
))
.await?;

log::info!("Closing Snowflake session");
api.close_session().await?;

log::info!("Reading downloaded files");
let parquet_dir = format!("{}output", &args.output_path);
let paths = fs::read_dir(&parquet_dir).unwrap();

for path in paths {
let path = path?.path();
log::info!("Reading {:?}", path);
let file = File::open(path)?;

let builder = ParquetRecordBatchReaderBuilder::try_new(file)?;
let reader = builder.build()?;
let mut batches = Vec::default();
for batch in reader {
batches.push(batch?);
}
println!("{}", pretty_format_batches(batches.as_slice()).unwrap());
}

Ok(())
}
95 changes: 91 additions & 4 deletions snowflake-api/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@
issue_tracker_base_url = "https://github.com/mycelial/snowflake-rs/issues",
test(no_crate_inject)
)]
#![doc = include_str ! ("../README.md")]
#![doc = include_str!("../README.md")]

use std::io;
use std::path::Path;
Expand Down Expand Up @@ -169,24 +169,109 @@ impl SnowflakeApi {
/// Execute a single query against API.
/// If statement is PUT, then file will be uploaded to the Snowflake-managed storage
pub async fn exec(&mut self, sql: &str) -> Result<QueryResult, SnowflakeApiError> {
// fixme: can go without regex? but needs different accept-mime for those still
let put_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*put\s+").unwrap();
let get_re = Regex::new(r"(?i)^(?:/\*.*\*/\s*)*get\s+").unwrap();

// put commands go through a different flow and result is side-effect
// put/get commands go through a different flow and result is side-effect
if put_re.is_match(sql) {
log::info!("Detected PUT query");

self.exec_put(sql).await.map(|_| QueryResult::Empty)
} else if get_re.is_match(sql) {
log::info!("Detected GET query");

self.exec_get(sql).await.map(|_| QueryResult::Empty)
} else {
self.exec_arrow(sql).await
}
}

async fn exec_get(&mut self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
log::debug!("Got GET response: {:?}", resp);

match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.get(pg).await,
ExecResponse::Error(e) => Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
)),
}
}

/// Dispatches a parsed GET response to the backend-specific download path.
/// Only AWS-backed stages are implemented; Azure and GCS stages report
/// `Unimplemented`.
async fn get(&self, resp: PutGetExecResponse) -> Result<(), SnowflakeApiError> {
    let data = resp.data;
    match data.stage_info {
        PutGetStageInfo::Aws(info) => {
            // A GET response must say where the files go locally.
            let local_location = data
                .local_location
                .ok_or(SnowflakeApiError::BrokenResponse)?;
            self.get_from_s3(local_location, &data.src_locations, info)
                .await
        }
        PutGetStageInfo::Azure(_) => Err(SnowflakeApiError::Unimplemented(
            "GET local file requests for Azure".to_string(),
        )),
        PutGetStageInfo::Gcs(_) => Err(SnowflakeApiError::Unimplemented(
            "GET local file requests for GCS".to_string(),
        )),
    }
}

// fixme: refactor s3 put/get into a single function?
/// Downloads every file in `src_locations` from the AWS stage described by
/// `info` into the local directory `local_location`, using the short-lived
/// credentials Snowflake returned with the GET response.
///
/// Note: `local_location`/`bucket_path` are concatenated with each source
/// path without inserting a separator, so both are assumed to already end
/// with `/` as delivered by the server — TODO confirm against responses.
async fn get_from_s3(
    &self,
    local_location: String,
    src_locations: &[String],
    info: AwsPutGetStageInfo,
) -> Result<(), SnowflakeApiError> {
    // todo: use path parser?
    // Stage location is "<bucket>/<prefix...>"; anything without a slash is malformed.
    let (bucket_name, bucket_path) = info
        .location
        .split_once('/')
        .ok_or(SnowflakeApiError::InvalidBucketPath(info.location.clone()))?;

    let s3 = AmazonS3Builder::new()
        .with_region(info.region)
        .with_bucket_name(bucket_name)
        .with_access_key_id(info.creds.aws_key_id)
        .with_secret_access_key(info.creds.aws_secret_key)
        .with_token(info.creds.aws_token)
        .build()?;

    // todo: implement parallelism for small files
    // todo: security vulnerability, the external system dictates which remote files
    //       are fetched and where they are written on the local filesystem
    for src_path in src_locations {
        // Local destination: download dir + server-provided relative path.
        let dest_path = format!("{}{}", local_location, src_path);
        let dest_path = object_store::path::Path::parse(dest_path)?;

        // Remote source: stage prefix + the same relative path.
        let src_path = format!("{}{}", bucket_path, src_path);
        let src_path = object_store::path::Path::parse(src_path)?;

        // fixme: can we stream the thing or multipart?
        // Whole object is buffered in memory before being written out.
        let bytes = s3.get(&src_path).await?;
        LocalFileSystem::new()
            .put(&dest_path, bytes.bytes().await?)
            .await?;
    }

    Ok(())
}

async fn exec_put(&mut self, sql: &str) -> Result<(), SnowflakeApiError> {
let resp = self
.run_sql::<ExecResponse>(sql, QueryType::JsonQuery)
.await?;
// fixme: don't log secrets maybe?
log::debug!("Got PUT response: {:?}", resp);

// fixme: support PUT for external stage
match resp {
ExecResponse::Query(_) => Err(SnowflakeApiError::UnexpectedResponse),
ExecResponse::PutGet(pg) => self.put(pg).await,
Expand Down Expand Up @@ -227,21 +312,23 @@ impl SnowflakeApi {
.with_token(info.creds.aws_token)
.build()?;

// todo: implement parallelism for small files
// todo: security vulnerability, external system tells you which local files to upload
for src_path in src_locations {
let path = Path::new(src_path);
let filename = path
.file_name()
.ok_or(SnowflakeApiError::InvalidLocalPath(src_path.clone()))?;

// fixme: nicer way to join paths?
// fixme: unwrap
let dest_path = format!("{}{}", bucket_path, filename.to_str().unwrap());
let dest_path = object_store::path::Path::parse(dest_path)?;

let src_path = object_store::path::Path::parse(src_path)?;

// fixme: can we stream the thing or multipart?
let fs = LocalFileSystem::new().get(&src_path).await?;

s3.put(&dest_path, fs.bytes().await?).await?;
}

Expand Down Expand Up @@ -276,7 +363,7 @@ impl SnowflakeApi {
return Err(SnowflakeApiError::ApiError(
e.data.error_code,
e.message.unwrap_or_default(),
))
));
}
};

Expand Down
9 changes: 5 additions & 4 deletions snowflake-api/src/responses.rs
Original file line number Diff line number Diff line change
Expand Up @@ -198,18 +198,19 @@ pub struct PutGetResponseData {
// file upload parallelism
pub parallel: i32,
// file size threshold, small ones are should be uploaded with given parallelism
pub threshold: i64,
pub threshold: Option<i64>,
// doesn't need compression if source is already compressed
pub auto_compress: bool,
pub auto_compress: Option<bool>,
pub overwrite: bool,
// maps to one of the predefined compression algos
// todo: support different compression formats?
pub source_compression: String,
pub source_compression: Option<String>,
pub stage_info: PutGetStageInfo,
pub encryption_material: EncryptionMaterialVariant,
// GCS specific. If you request multiple files?
// might return a [ null ] for AWS responses
#[serde(default)]
pub presigned_urls: Vec<String>,
pub presigned_urls: Vec<Option<String>>,
#[serde(default)]
pub parameters: Vec<NameValueParameter>,
pub statement_type_id: Option<i64>,
Expand Down