diff --git a/Cargo.lock b/Cargo.lock index 5cb3fe2d..0efd4988 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -1834,6 +1834,15 @@ version = "1.9.5" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "7d71d3574edd2771538b901e6549113b4006ece66150fb69c0fb6d9a2adae946" +[[package]] +name = "humansize" +version = "2.1.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "6cb51c9a029ddc91b07a787f1d86b53ccfa49b0e86688c946ebe8d3555685dd7" +dependencies = [ + "libm", +] + [[package]] name = "icu_collections" version = "1.5.0" @@ -2874,6 +2883,26 @@ dependencies = [ "tracing", ] +[[package]] +name = "pgt_hover" +version = "0.0.0" +dependencies = [ + "humansize", + "pgt_query_ext", + "pgt_schema_cache", + "pgt_test_utils", + "pgt_text_size", + "pgt_treesitter", + "schemars", + "serde", + "serde_json", + "sqlx", + "tokio", + "tracing", + "tree-sitter", + "tree_sitter_sql", +] + [[package]] name = "pgt_lexer" version = "0.0.0" @@ -3103,6 +3132,7 @@ dependencies = [ "pgt_console", "pgt_diagnostics", "pgt_fs", + "pgt_hover", "pgt_lexer", "pgt_query_ext", "pgt_schema_cache", diff --git a/Cargo.toml b/Cargo.toml index a5195d2d..67942264 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -70,6 +70,7 @@ pgt_diagnostics_categories = { path = "./crates/pgt_diagnostics_categories", ver pgt_diagnostics_macros = { path = "./crates/pgt_diagnostics_macros", version = "0.0.0" } pgt_flags = { path = "./crates/pgt_flags", version = "0.0.0" } pgt_fs = { path = "./crates/pgt_fs", version = "0.0.0" } +pgt_hover = { path = "./crates/pgt_hover", version = "0.0.0" } pgt_lexer = { path = "./crates/pgt_lexer", version = "0.0.0" } pgt_lexer_codegen = { path = "./crates/pgt_lexer_codegen", version = "0.0.0" } pgt_lsp = { path = "./crates/pgt_lsp", version = "0.0.0" } diff --git a/crates/pgt_hover/Cargo.toml b/crates/pgt_hover/Cargo.toml new file mode 100644 index 00000000..eab3f70c --- /dev/null +++ b/crates/pgt_hover/Cargo.toml @@ -0,0 +1,36 @@ +[package] +authors.workspace = true +categories.workspace = true +description = "" +edition.workspace = true +homepage.workspace = true +keywords.workspace = true +license.workspace = true +name = "pgt_hover" +repository.workspace = true +version = "0.0.0" + + +[dependencies] +humansize = { version = "2.1.3" } +pgt_query_ext.workspace = true +pgt_schema_cache.workspace = true +pgt_text_size.workspace = true +pgt_treesitter.workspace = true +schemars = { workspace = true, optional = true } +serde = { workspace = true, features = ["derive"] } +serde_json = { workspace = true } +sqlx.workspace = true +tokio = { version = "1.41.1", features = ["full"] } +tracing = { workspace = true } +tree-sitter.workspace = true +tree_sitter_sql.workspace = true + +[dev-dependencies] +pgt_test_utils.workspace = true + +[lib] +doctest = false + +[features] +schema = ["dep:schemars"] diff --git a/crates/pgt_hover/src/hovered_node.rs b/crates/pgt_hover/src/hovered_node.rs new file mode 100644 index 00000000..2f1905e7 --- /dev/null +++ b/crates/pgt_hover/src/hovered_node.rs @@ -0,0 +1,50 @@ +use pgt_text_size::TextSize; +use pgt_treesitter::TreeSitterContextParams; + +#[derive(Debug)] +pub(crate) enum NodeIdentification { + Name(String), + SchemaAndName((String, String)), + #[allow(unused)] + SchemaAndTableAndName((String, String, String)), +} + +#[allow(unused)] +#[derive(Debug)] +pub(crate) enum HoveredNode { + Schema(NodeIdentification), + Table(NodeIdentification), + Function(NodeIdentification), + Column(NodeIdentification), + Policy(NodeIdentification), + Trigger(NodeIdentification), + Role(NodeIdentification), +} + +impl HoveredNode { + pub(crate) fn get(position: TextSize, text: &str, tree: &tree_sitter::Tree) -> Option { + let ctx = pgt_treesitter::context::TreesitterContext::new(TreeSitterContextParams { + position, + text, + tree, + }); + + let node_content = ctx.get_node_under_cursor_content()?; + + let under_node = ctx.node_under_cursor.as_ref()?; + + match under_node.kind() { + "identifier" if ctx.parent_matches_one_of_kind(&["object_reference", "relation"]) => { + if let Some(schema) = ctx.schema_or_alias_name { + Some(HoveredNode::Table(NodeIdentification::SchemaAndName(( + schema, + node_content, + )))) + } else { + Some(HoveredNode::Table(NodeIdentification::Name(node_content))) + } + } + _ => None, + } + } +} diff --git a/crates/pgt_hover/src/lib.rs b/crates/pgt_hover/src/lib.rs new file mode 100644 index 00000000..8a85d980 --- /dev/null +++ b/crates/pgt_hover/src/lib.rs @@ -0,0 +1,47 @@ +use pgt_schema_cache::SchemaCache; +use pgt_text_size::TextSize; + +use crate::{hovered_node::HoveredNode, to_markdown::ToHoverMarkdown}; + +mod hovered_node; +mod to_markdown; + +pub struct OnHoverParams<'a> { + pub position: TextSize, + pub schema_cache: &'a SchemaCache, + pub stmt_sql: &'a str, + pub ast: Option<&'a pgt_query_ext::NodeEnum>, + pub ts_tree: &'a tree_sitter::Tree, +} + +pub fn on_hover(params: OnHoverParams) -> Vec { + if let Some(hovered_node) = HoveredNode::get(params.position, params.stmt_sql, params.ts_tree) { + match hovered_node { + HoveredNode::Table(node_identification) => { + let table = match node_identification { + hovered_node::NodeIdentification::Name(n) => { + params.schema_cache.find_table(n.as_str(), None) + } + hovered_node::NodeIdentification::SchemaAndName((s, n)) => { + params.schema_cache.find_table(n.as_str(), Some(s.as_str())) + } + hovered_node::NodeIdentification::SchemaAndTableAndName(_) => None, + }; + + table + .map(|t| { + let mut markdown = String::new(); + match t.to_hover_markdown(&mut markdown) { + Ok(_) => vec![markdown], + Err(_) => vec![], + } + }) + .unwrap_or(vec![]) + } + + _ => todo!(), + } + } else { + Default::default() + } +} diff --git a/crates/pgt_hover/src/to_markdown.rs b/crates/pgt_hover/src/to_markdown.rs new file mode 100644 index 00000000..7ea66160 --- /dev/null +++ b/crates/pgt_hover/src/to_markdown.rs @@ -0,0 +1,90 @@ +use std::fmt::Write; + +use humansize::DECIMAL; + +pub(crate) trait ToHoverMarkdown { + fn to_hover_markdown(&self, writer: &mut W) -> Result<(), std::fmt::Error>; +} + +impl ToHoverMarkdown for pgt_schema_cache::Table { + fn to_hover_markdown(&self, writer: &mut W) -> Result<(), std::fmt::Error> { + HeadlineWriter::for_table(writer, self)?; + BodyWriter::for_table(writer, self)?; + FooterWriter::for_table(writer, self)?; + + Ok(()) + } +} + +struct HeadlineWriter; + +impl HeadlineWriter { + fn for_table( + writer: &mut W, + table: &pgt_schema_cache::Table, + ) -> Result<(), std::fmt::Error> { + let table_kind = match table.table_kind { + pgt_schema_cache::TableKind::View => " (View)", + pgt_schema_cache::TableKind::MaterializedView => " (M.View)", + pgt_schema_cache::TableKind::Partitioned => " (Partitioned)", + pgt_schema_cache::TableKind::Ordinary => "", + }; + + let locked_txt = if table.rls_enabled { + " - 🔒 RLS enabled" + } else { + " - 🔓 RLS disabled" + }; + + write!( + writer, + "### {}.{}{}{}", + table.schema, table.name, table_kind, locked_txt + )?; + + markdown_newline(writer)?; + + Ok(()) + } +} + +struct BodyWriter; + +impl BodyWriter { + fn for_table( + writer: &mut W, + table: &pgt_schema_cache::Table, + ) -> Result<(), std::fmt::Error> { + if let Some(c) = table.comment.as_ref() { + write!(writer, "{}", c)?; + markdown_newline(writer)?; + } + + Ok(()) + } +} + +struct FooterWriter; + +impl FooterWriter { + fn for_table( + writer: &mut W, + table: &pgt_schema_cache::Table, + ) -> Result<(), std::fmt::Error> { + write!( + writer, + "~{} rows, ~{} dead rows, {}", + table.live_rows_estimate, + table.dead_rows_estimate, + humansize::format_size(table.bytes as u64, DECIMAL) + )?; + + Ok(()) + } +} + +fn markdown_newline(writer: &mut W) -> Result<(), std::fmt::Error> { + write!(writer, " ")?; + writeln!(writer)?; + Ok(()) +} diff --git a/crates/pgt_lsp/src/capabilities.rs b/crates/pgt_lsp/src/capabilities.rs index 3b473eb7..8c8ff6d9 100644 --- a/crates/pgt_lsp/src/capabilities.rs +++ b/crates/pgt_lsp/src/capabilities.rs @@ -3,9 +3,10 @@ use crate::handlers::code_actions::command_id; use pgt_workspace::features::code_actions::CommandActionCategory; use strum::IntoEnumIterator; use tower_lsp::lsp_types::{ - ClientCapabilities, CompletionOptions, ExecuteCommandOptions, PositionEncodingKind, - SaveOptions, ServerCapabilities, TextDocumentSyncCapability, TextDocumentSyncKind, - TextDocumentSyncOptions, TextDocumentSyncSaveOptions, WorkDoneProgressOptions, + ClientCapabilities, CompletionOptions, ExecuteCommandOptions, HoverProviderCapability, + PositionEncodingKind, SaveOptions, ServerCapabilities, TextDocumentSyncCapability, + TextDocumentSyncKind, TextDocumentSyncOptions, TextDocumentSyncSaveOptions, + WorkDoneProgressOptions, }; /// The capabilities to send from server as part of [`InitializeResult`] @@ -62,6 +63,7 @@ pub(crate) fn server_capabilities(capabilities: &ClientCapabilities) -> ServerCa true, )), rename_provider: None, + hover_provider: Some(HoverProviderCapability::Simple(true)), ..Default::default() } } diff --git a/crates/pgt_lsp/src/handlers.rs b/crates/pgt_lsp/src/handlers.rs index 103bef2f..113e3fcc 100644 --- a/crates/pgt_lsp/src/handlers.rs +++ b/crates/pgt_lsp/src/handlers.rs @@ -1,3 +1,4 @@ pub(crate) mod code_actions; pub(crate) mod completions; +pub(crate) mod hover; pub(crate) mod text_document; diff --git a/crates/pgt_lsp/src/handlers/hover.rs b/crates/pgt_lsp/src/handlers/hover.rs new file mode 100644 index 00000000..4dd44ca6 --- /dev/null +++ b/crates/pgt_lsp/src/handlers/hover.rs @@ -0,0 +1,42 @@ +use pgt_workspace::{WorkspaceError, features::on_hover::OnHoverParams}; +use tower_lsp::lsp_types::{self, MarkedString, MarkupContent}; + +use crate::{adapters::get_cursor_position, diagnostics::LspError, session::Session}; + +pub(crate) fn on_hover( + session: &Session, + params: lsp_types::HoverParams, +) -> Result { + let url = params.text_document_position_params.text_document.uri; + let position = params.text_document_position_params.position; + let path = session.file_path(&url)?; + + match session.workspace.on_hover(OnHoverParams { + path, + position: get_cursor_position(session, &url, position)?, + }) { + Ok(result) => { + tracing::warn!("Got a result. {:#?}", result); + + Ok(lsp_types::HoverContents::Array( + result + .into_iter() + .map(MarkedString::from_markdown) + .collect(), + )) + } + + Err(e) => match e { + WorkspaceError::DatabaseConnectionError(_) => { + Ok(lsp_types::HoverContents::Markup(MarkupContent { + kind: lsp_types::MarkupKind::PlainText, + value: "Cannot connect to database.".into(), + })) + } + _ => { + tracing::error!("Received an error: {:#?}", e); + Err(e.into()) + } + }, + } +} diff --git a/crates/pgt_lsp/src/server.rs b/crates/pgt_lsp/src/server.rs index 6420c511..76d9bd9a 100644 --- a/crates/pgt_lsp/src/server.rs +++ b/crates/pgt_lsp/src/server.rs @@ -265,6 +265,17 @@ impl LanguageServer for LSPServer { } } + #[tracing::instrument(level = "trace", skip_all)] + async fn hover(&self, params: HoverParams) -> LspResult> { + match handlers::hover::on_hover(&self.session, params) { + Ok(result) => LspResult::Ok(Some(Hover { + contents: result, + range: None, + })), + Err(e) => LspResult::Err(into_lsp_error(e)), + } + } + #[tracing::instrument(level = "trace", skip_all)] async fn completion(&self, params: CompletionParams) -> LspResult> { match handlers::completions::get_completions(&self.session, params) { diff --git a/crates/pgt_schema_cache/src/schema_cache.rs b/crates/pgt_schema_cache/src/schema_cache.rs index 8fb9683b..b8bc78d4 100644 --- a/crates/pgt_schema_cache/src/schema_cache.rs +++ b/crates/pgt_schema_cache/src/schema_cache.rs @@ -60,13 +60,13 @@ impl SchemaCache { pub fn find_table(&self, name: &str, schema: Option<&str>) -> Option<&Table> { self.tables .iter() - .find(|t| t.name == name && schema.is_none() || Some(t.schema.as_str()) == schema) + .find(|t| t.name == name && schema.is_none_or(|s| s == t.schema.as_str())) } pub fn find_type(&self, name: &str, schema: Option<&str>) -> Option<&PostgresType> { self.types .iter() - .find(|t| t.name == name && schema.is_none() || Some(t.schema.as_str()) == schema) + .find(|t| t.name == name && schema.is_none_or(|s| s == t.schema.as_str())) } pub fn find_col(&self, name: &str, table: &str, schema: Option<&str>) -> Option<&Column> { @@ -80,7 +80,7 @@ impl SchemaCache { pub fn find_types(&self, name: &str, schema: Option<&str>) -> Vec<&PostgresType> { self.types .iter() - .filter(|t| t.name == name && schema.is_none() || Some(t.schema.as_str()) == schema) + .filter(|t| t.name == name && schema.is_none_or(|s| s == t.schema.as_str())) .collect() } } diff --git a/crates/pgt_workspace/Cargo.toml b/crates/pgt_workspace/Cargo.toml index 3ef4936b..1498fb08 100644 --- a/crates/pgt_workspace/Cargo.toml +++ b/crates/pgt_workspace/Cargo.toml @@ -25,6 +25,7 @@ pgt_configuration = { workspace = true } pgt_console = { workspace = true } pgt_diagnostics = { workspace = true } pgt_fs = { workspace = true, features = ["serde"] } +pgt_hover = { workspace = true } pgt_lexer = { workspace = true } pgt_query_ext = { workspace = true } pgt_schema_cache = { workspace = true } diff --git a/crates/pgt_workspace/src/features/completions.rs b/crates/pgt_workspace/src/features/completions.rs index a41dd06e..5944f14c 100644 --- a/crates/pgt_workspace/src/features/completions.rs +++ b/crates/pgt_workspace/src/features/completions.rs @@ -4,7 +4,7 @@ use pgt_completions::CompletionItem; use pgt_fs::PgTPath; use pgt_text_size::{TextRange, TextSize}; -use crate::workspace::{Document, GetCompletionsFilter, GetCompletionsMapper, StatementId}; +use crate::workspace::{Document, GetCompletionsFilter, StatementId, WithCSTMapper}; #[derive(Debug, serde::Serialize, serde::Deserialize)] #[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] @@ -32,7 +32,7 @@ impl IntoIterator for CompletionsResult { pub(crate) fn get_statement_for_completions( doc: &Document, position: TextSize, -) -> Option<(StatementId, TextRange, String, Arc)> { +) -> Option<(StatementId, TextRange, Arc)> { let count = doc.count(); // no arms no cookies if count == 0 { @@ -40,7 +40,7 @@ pub(crate) fn get_statement_for_completions( } let mut eligible_statements = doc.iter_with_filter( - GetCompletionsMapper, + WithCSTMapper, GetCompletionsFilter { cursor_position: position, }, @@ -49,7 +49,7 @@ pub(crate) fn get_statement_for_completions( if count == 1 { eligible_statements.next() } else { - let mut prev_stmt: Option<(StatementId, TextRange, String, Arc)> = None; + let mut prev_stmt: Option<(StatementId, TextRange, Arc)> = None; for current_stmt in eligible_statements { /* @@ -112,10 +112,10 @@ mod tests { let (doc, position) = get_doc_and_pos(sql.as_str()); - let (_, _, text, _) = + let (stmt, _, _) = get_statement_for_completions(&doc, position).expect("Expected Statement"); - assert_eq!(text, "update users set email = 'myemail@com';") + assert_eq!(stmt.content(), "update users set email = 'myemail@com';") } #[test] @@ -151,10 +151,10 @@ mod tests { let (doc, position) = get_doc_and_pos(sql.as_str()); - let (_, _, text, _) = + let (stmt, _, _) = get_statement_for_completions(&doc, position).expect("Expected Statement"); - assert_eq!(text, "select * from ;") + assert_eq!(stmt.content(), "select * from ;") } #[test] @@ -163,10 +163,10 @@ mod tests { let (doc, position) = get_doc_and_pos(sql.as_str()); - let (_, _, text, _) = + let (stmt, _, _) = get_statement_for_completions(&doc, position).expect("Expected Statement"); - assert_eq!(text, "select * from") + assert_eq!(stmt.content(), "select * from") } #[test] @@ -187,10 +187,10 @@ mod tests { let (doc, position) = get_doc_and_pos(sql); - let (_, _, text, _) = + let (stmt, _, _) = get_statement_for_completions(&doc, position).expect("Expected Statement"); - assert_eq!(text.trim(), "select from cool;") + assert_eq!(stmt.content().trim(), "select from cool;") } #[test] diff --git a/crates/pgt_workspace/src/features/mod.rs b/crates/pgt_workspace/src/features/mod.rs index 31013f36..7455f0be 100644 --- a/crates/pgt_workspace/src/features/mod.rs +++ b/crates/pgt_workspace/src/features/mod.rs @@ -1,3 +1,4 @@ pub mod code_actions; pub mod completions; pub mod diagnostics; +pub mod on_hover; diff --git a/crates/pgt_workspace/src/features/on_hover.rs b/crates/pgt_workspace/src/features/on_hover.rs new file mode 100644 index 00000000..3e3fcd49 --- /dev/null +++ b/crates/pgt_workspace/src/features/on_hover.rs @@ -0,0 +1,25 @@ +use pgt_fs::PgTPath; +use pgt_text_size::TextSize; + +#[derive(Debug, serde::Serialize, serde::Deserialize)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +pub struct OnHoverParams { + pub path: PgTPath, + pub position: TextSize, +} + +#[derive(Debug, serde::Serialize, serde::Deserialize, Default)] +#[cfg_attr(feature = "schema", derive(schemars::JsonSchema))] +pub struct OnHoverResult { + /// Can contain multiple blocks of markdown + /// if the hovered-on item is ambiguous. + pub(crate) markdown_blocks: Vec, +} + +impl IntoIterator for OnHoverResult { + type Item = String; + type IntoIter = as IntoIterator>::IntoIter; + fn into_iter(self) -> Self::IntoIter { + self.markdown_blocks.into_iter() + } +} diff --git a/crates/pgt_workspace/src/workspace.rs b/crates/pgt_workspace/src/workspace.rs index 9206b39d..0747c081 100644 --- a/crates/pgt_workspace/src/workspace.rs +++ b/crates/pgt_workspace/src/workspace.rs @@ -17,6 +17,7 @@ use crate::{ }, completions::{CompletionsResult, GetCompletionsParams}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, + on_hover::{OnHoverParams, OnHoverResult}, }, }; @@ -113,6 +114,8 @@ pub trait Workspace: Send + Sync + RefUnwindSafe { params: GetCompletionsParams, ) -> Result; + fn on_hover(&self, params: OnHoverParams) -> Result; + /// Register a possible workspace project folder. Returns the key of said project. Use this key when you want to switch to different projects. fn register_project_folder( &self, diff --git a/crates/pgt_workspace/src/workspace/client.rs b/crates/pgt_workspace/src/workspace/client.rs index 2bd21513..05e964f6 100644 --- a/crates/pgt_workspace/src/workspace/client.rs +++ b/crates/pgt_workspace/src/workspace/client.rs @@ -161,4 +161,11 @@ where ) -> Result { self.request("pgt/get_completions", params) } + + fn on_hover( + &self, + params: crate::features::on_hover::OnHoverParams, + ) -> Result { + self.request("pgt/on_hover", params) + } } diff --git a/crates/pgt_workspace/src/workspace/server.rs b/crates/pgt_workspace/src/workspace/server.rs index e6456afc..a6a64e69 100644 --- a/crates/pgt_workspace/src/workspace/server.rs +++ b/crates/pgt_workspace/src/workspace/server.rs @@ -35,9 +35,10 @@ use crate::{ }, completions::{CompletionsResult, GetCompletionsParams, get_statement_for_completions}, diagnostics::{PullDiagnosticsParams, PullDiagnosticsResult}, + on_hover::{OnHoverParams, OnHoverResult}, }, settings::{WorkspaceSettings, WorkspaceSettingsHandle, WorkspaceSettingsHandleMut}, - workspace::AnalyserDiagnosticsMapper, + workspace::{AnalyserDiagnosticsMapper, WithCSTandASTMapper}, }; use super::{ @@ -634,20 +635,59 @@ impl Workspace for WorkspaceServer { tracing::debug!("No statement found."); Ok(CompletionsResult::default()) } - Some((_id, range, content, cst)) => { + Some((_id, range, cst)) => { let position = params.position - range.start(); let items = pgt_completions::complete(pgt_completions::CompletionParams { position, schema: schema_cache.as_ref(), tree: &cst, - text: content, + text: _id.content().to_string(), }); Ok(CompletionsResult { items }) } } } + + fn on_hover(&self, params: OnHoverParams) -> Result { + let documents = self.documents.read().unwrap(); + let doc = documents + .get(¶ms.path) + .ok_or(WorkspaceError::not_found())?; + + let pool = self.get_current_connection(); + if pool.is_none() { + tracing::debug!("No database connection available. Skipping completions."); + return Ok(OnHoverResult::default()); + } + let pool = pool.unwrap(); + + let schema_cache = self.schema_cache.load(pool)?; + + match doc + .iter_with_filter( + WithCSTandASTMapper, + CursorPositionFilter::new(params.position), + ) + .next() + { + Some((stmt_id, range, ts_tree, maybe_ast)) => { + let position_in_stmt = params.position + range.start(); + + let markdown_blocks = pgt_hover::on_hover(pgt_hover::OnHoverParams { + ts_tree: &ts_tree, + schema_cache: &schema_cache, + ast: maybe_ast.as_ref(), + position: position_in_stmt, + stmt_sql: stmt_id.content(), + }); + + Ok(OnHoverResult { markdown_blocks }) + } + None => Ok(OnHoverResult::default()), + } + } } /// Returns `true` if `path` is a directory or diff --git a/crates/pgt_workspace/src/workspace/server/document.rs b/crates/pgt_workspace/src/workspace/server/document.rs index c9f880ec..b2e97934 100644 --- a/crates/pgt_workspace/src/workspace/server/document.rs +++ b/crates/pgt_workspace/src/workspace/server/document.rs @@ -268,14 +268,35 @@ impl<'a> StatementMapper<'a> for AnalyserDiagnosticsMapper { ) } } +pub struct WithCSTMapper; +impl<'a> StatementMapper<'a> for WithCSTMapper { + type Output = (StatementId, TextRange, Arc); -pub struct GetCompletionsMapper; -impl<'a> StatementMapper<'a> for GetCompletionsMapper { - type Output = (StatementId, TextRange, String, Arc); + fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { + let tree = parser.cst_db.get_or_cache_tree(&id); + (id.clone(), range, tree) + } +} + +pub struct WithCSTandASTMapper; +impl<'a> StatementMapper<'a> for WithCSTandASTMapper { + type Output = ( + StatementId, + TextRange, + Arc, + Option, + ); fn map(&self, parser: &'a Document, id: StatementId, range: TextRange) -> Self::Output { let tree = parser.cst_db.get_or_cache_tree(&id); - (id.clone(), range, id.content().to_string(), tree) + let ast_result = parser.ast_db.get_or_cache_ast(&id); + + let ast_option = match &*ast_result { + Ok(node) => Some(node.clone()), + Err(_) => None, + }; + + (id.clone(), range, tree, ast_option) } } @@ -555,11 +576,11 @@ $$ LANGUAGE plpgsql;"; let input = "SELECT * FROM users;"; let d = Document::new(input.to_string(), 1); - let results = d.iter(GetCompletionsMapper).collect::>(); + let results = d.iter(WithCSTMapper).collect::>(); assert_eq!(results.len(), 1); - let (_id, _range, content, tree) = &results[0]; - assert_eq!(content, "SELECT * FROM users;"); + let (id, _, tree) = &results[0]; + assert_eq!(id.content(), "SELECT * FROM users;"); assert_eq!(tree.root_node().kind(), "program"); }