diff --git a/AGENTS.md b/AGENTS.md index 3f6a48b..21dd626 100644 --- a/AGENTS.md +++ b/AGENTS.md @@ -110,6 +110,11 @@ Be cautious about: ## Verification Notes +For formatting, use `rustfmt` from the Rust toolchain managed by `rustup`, not a Homebrew-installed formatter. CI installs `rustfmt` on the stable toolchain and runs `cargo fmt --all --check`, so local verification should use the same toolchain, for example: + +- `rustup component add rustfmt` +- `rustup run stable cargo fmt --all` + Before closing a session that changes behavior: - run `cargo check` diff --git a/README.md b/README.md index 4e1d584..f7105da 100644 --- a/README.md +++ b/README.md @@ -68,6 +68,25 @@ Inspect the stack at any time: dig tree ``` +Create or adopt a GitHub pull request for the current tracked branch: + +```bash +dig pr --title "feat: auth" --body "Implements authentication." --draft +``` + +Open the current branch's pull request in the browser: + +```bash +dig pr --view +``` + +List tracked open pull requests in stack order: + +```bash +dig pr list +dig pr list --view +``` + ### Common commands ```bash @@ -77,6 +96,11 @@ dig branch -p # create a tracked branch under a specific paren dig tree # show the full tracked branch tree dig tree --branch # show one branch and its descendants dig commit -m "message" # commit and restack tracked descendants if needed +dig pr # create or adopt a GitHub PR for the current tracked branch +dig pr --title "title" --body "body" --draft +dig pr --view # open the current branch PR in the browser +dig pr list # list open GitHub PRs that dig is tracking +dig pr list --view # list tracked PRs, then open them in the browser dig sync # reconcile local dig state, restack stale stacks, then offer cleanup dig sync --continue # continue a paused restack after resolving conflicts dig merge # merge a tracked branch into its tracked parent @@ -105,6 +129,22 @@ If cleanup finds merged branches, `dig sync` reuses the same delete prompt as `d Remote sync is intentionally out of scope for now. Future GitHub and `gh` integration can extend `dig sync`, but the current command only reconciles local branches and local dig state. +### Track GitHub pull requests + +`dig pr` uses the GitHub CLI (`gh`) to create a pull request for the current tracked branch, or to adopt the existing open pull request for that branch if one already exists on GitHub. + +By default, dig targets the branch's tracked parent as the PR base. Root branches target trunk, child branches target their tracked parent branch, and the tracked PR number is stored locally in `.git/dig/state.json`. + +If the branch is not pushed to a resolvable remote yet, `dig pr` prompts before running `git push -u ` and then continues with PR creation if you confirm. + +When dig creates a pull request, it prints both the creation summary and the GitHub link. + +`dig tree` annotates tracked branches that have a PR with `(#123)`. + +`dig pr --view` opens the current branch's pull request in the browser. If you combine `--view` with a mutating PR command, dig opens the browser after the command completes. + +`dig pr list` shows only open pull requests that are both open on GitHub and currently tracked by dig, rendered in dig's stack order. Each line includes `#: ` and the GitHub URL. + ### Resolve paused commands Some commands, including `dig commit`, `dig adopt`, `dig reparent`, `dig merge`, `dig clean`, `dig orphan`, and `dig sync`, may pause if `dig` hits a rebase conflict while restacking tracked descendants. diff --git a/src/cli/mod.rs b/src/cli/mod.rs index 78c6920..815370e 100644 --- a/src/cli/mod.rs +++ b/src/cli/mod.rs @@ -7,6 +7,7 @@ mod init; mod merge; mod operation; mod orphan; +mod pr; mod reparent; mod sync; mod tree; @@ -48,6 +49,9 @@ enum Commands { /// Stop tracking a branch in dig while keeping the local branch Orphan(orphan::OrphanArgs), + /// Create or adopt a GitHub pull request for the current tracked branch + Pr(pr::PrArgs), + /// Change a tracked branch's parent and restack it onto the new base Reparent(reparent::ReparentArgs), @@ -74,6 +78,7 @@ pub fn run() -> ExitCode { Commands::Commit(args) => commit::execute(args), Commands::Merge(args) => merge::execute(args), Commands::Orphan(args) => orphan::execute(args), + Commands::Pr(args) => pr::execute(args), Commands::Reparent(args) => reparent::execute(args), Commands::Sync(args) => sync::execute(args), Commands::Tree(args) => tree::execute(args), diff --git a/src/cli/pr/mod.rs b/src/cli/pr/mod.rs new file mode 100644 index 0000000..4816bc1 --- /dev/null +++ b/src/cli/pr/mod.rs @@ -0,0 +1,227 @@ +use std::io; + +use clap::{Args, Subcommand}; + +use crate::core::git; +use crate::core::pr::{ + self, PrOptions, PrOutcomeKind, TrackedPullRequestListNode, TrackedPullRequestListView, +}; + +use super::CommandOutcome; +use super::common; + +#[derive(Args, Debug, Clone)] +pub struct PrArgs { + #[command(subcommand)] + pub command: Option<PrCommand>, + + /// Title for the pull request + #[arg(long = "title", value_name = "TITLE")] + pub title: Option<String>, + + /// Body for the pull request + #[arg(long = "body", value_name = "BODY")] + pub body: Option<String>, + + /// Mark the pull request as a draft + #[arg(long = "draft")] + pub draft: bool, + + /// Open the pull request in the browser + #[arg(long = "view")] + pub view: bool, +} + +#[derive(Subcommand, Debug, Clone)] +pub enum PrCommand { + /// List open pull requests that are tracked by dig + List(PrListArgs), +} + +#[derive(Args, Debug, Clone, Default)] +pub struct PrListArgs { + /// Open each listed pull request in the browser + #[arg(long = "view")] + pub view: bool, +} + +pub fn execute(args: PrArgs) -> io::Result<CommandOutcome> { + match args.command.clone() { + Some(PrCommand::List(list_args)) => execute_list(list_args), + None => execute_current(args), + } +} + +fn execute_current(args: PrArgs) -> io::Result<CommandOutcome> { + let create_requested = args.title.is_some() || args.body.is_some() || args.draft; + if args.view && !create_requested { + pr::open_current_pull_request_in_browser()?; + return Ok(CommandOutcome { + status: git::success_status()?, + }); + } + + let mut options: PrOptions = args.clone().into(); + if let Some(push_target) = pr::current_branch_push_target_for_create()? { + let confirmed = common::confirm_yes_no(&format!( + "Branch '{}' is not pushed to '{}'. Push it and create the pull request? [y/N] ", + push_target.branch_name, push_target.remote_name + ))?; + + if !confirmed { + println!( + "Did not create pull request because '{}' is not pushed to '{}'.", + push_target.branch_name, push_target.remote_name + ); + return Ok(CommandOutcome { + status: git::success_status()?, + }); + } + + options.push_if_needed = true; + } + + let outcome = pr::run(&options)?; + match outcome.kind { + PrOutcomeKind::AlreadyTracked => { + println!( + "Branch '{}' already tracks pull request #{}.", + outcome.branch_name, outcome.pull_request.number + ); + } + PrOutcomeKind::Created => { + println!( + "Created pull request #{} for '{}' into '{}'.", + outcome.pull_request.number, outcome.branch_name, outcome.base_branch_name + ); + } + PrOutcomeKind::Adopted => { + println!( + "Tracking existing pull request #{} for '{}' into '{}'.", + outcome.pull_request.number, outcome.branch_name, outcome.base_branch_name + ); + } + } + + if args.view { + pr::open_pull_request_in_browser(outcome.pull_request.number)?; + } + + Ok(CommandOutcome { + status: outcome.status, + }) +} + +fn execute_list(args: PrListArgs) -> io::Result<CommandOutcome> { + let outcome = pr::list_open_tracked_pull_requests()?; + + if outcome.pull_requests.is_empty() { + println!("No open tracked pull requests."); + } else { + println!("{}", render_pull_request_list(&outcome.view)); + } + + if args.view { + pr::open_pull_requests_in_browser(&outcome.pull_requests)?; + } + + Ok(CommandOutcome { + status: outcome.status, + }) +} + +fn render_pull_request_list(view: &TrackedPullRequestListView) -> String { + common::render_tree( + view.root_label.clone(), + &view.roots, + &format_pull_request_label, + &|node| node.children.as_slice(), + ) +} + +fn format_pull_request_label(node: &TrackedPullRequestListNode) -> String { + format!( + "#{}: {} - {}", + node.pull_request.number, node.pull_request.title, node.pull_request.url + ) +} + +impl From<PrArgs> for PrOptions { + fn from(args: PrArgs) -> Self { + Self { + title: args.title, + body: args.body, + draft: args.draft, + push_if_needed: false, + } + } +} + +#[cfg(test)] +mod tests { + use super::{PrArgs, PrCommand, PrListArgs, render_pull_request_list}; + use crate::core::pr::PrOptions; + use crate::core::pr::{TrackedPullRequestListNode, TrackedPullRequestListView}; + + #[test] + fn converts_cli_args_into_core_pr_options() { + let options = PrOptions::from(PrArgs { + command: None, + title: Some("feat: auth".into()), + body: Some("Implements auth.".into()), + draft: true, + view: true, + }); + + assert_eq!(options.title.as_deref(), Some("feat: auth")); + assert_eq!(options.body.as_deref(), Some("Implements auth.")); + assert!(options.draft); + } + + #[test] + fn preserves_pr_list_subcommand_args() { + match (PrArgs { + command: Some(PrCommand::List(PrListArgs { view: true })), + title: None, + body: None, + draft: false, + view: false, + }) + .command + .unwrap() + { + PrCommand::List(args) => assert!(args.view), + } + } + + #[test] + fn renders_pull_request_list_as_tree() { + let rendered = render_pull_request_list(&TrackedPullRequestListView { + root_label: Some("main".into()), + roots: vec![TrackedPullRequestListNode { + pull_request: crate::core::gh::PullRequestDetails { + number: 123, + title: "Auth".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }, + children: vec![TrackedPullRequestListNode { + pull_request: crate::core::gh::PullRequestDetails { + number: 124, + title: "Auth UI".into(), + url: "https://github.com/acme/dig/pull/124".into(), + }, + children: vec![], + }], + }], + }); + + assert_eq!( + rendered, + concat!( + "main\n", + "└── #123: Auth - https://github.com/acme/dig/pull/123\n", + " └── #124: Auth UI - https://github.com/acme/dig/pull/124" + ) + ); + } +} diff --git a/src/cli/tree/render.rs b/src/cli/tree/render.rs index a432495..9b058d3 100644 --- a/src/cli/tree/render.rs +++ b/src/cli/tree/render.rs @@ -1,13 +1,14 @@ use crate::cli::common; +use crate::core::graph::BranchLineageNode; use crate::core::tree::{TreeLabel, TreeView}; use crate::ui::markers; use crate::ui::palette::Accent; -pub fn render_branch_lineage(lineage: &[String]) -> String { +pub fn render_branch_lineage(lineage: &[BranchLineageNode]) -> String { let mut lines = Vec::new(); - for (index, branch_name) in lineage.iter().enumerate() { - lines.push(format_lineage_branch(branch_name, index == 0)); + for (index, branch) in lineage.iter().enumerate() { + lines.push(format_lineage_branch(branch, index == 0)); if index + 1 < lineage.len() { lines.push(format!("{} ", markers::LINEAGE_PIPE)); @@ -21,50 +22,79 @@ pub fn render_stack_tree(view: &TreeView) -> String { common::render_tree( view.root_label.as_ref().map(format_tree_label), &view.roots, - &|node| format_branch_label(&node.branch_name, node.is_current), + &|node| format_branch_label(&node.branch_name, node.is_current, node.pull_request_number), &|node| node.children.as_slice(), ) } fn format_tree_label(root_label: &TreeLabel) -> String { - format_branch_label(&root_label.branch_name, root_label.is_current) + format_branch_label( + &root_label.branch_name, + root_label.is_current, + root_label.pull_request_number, + ) +} + +fn format_branch_text(branch_name: &str, pull_request_number: Option<u64>) -> String { + match pull_request_number { + Some(number) => format!("{branch_name} (#{number})"), + None => branch_name.to_string(), + } } -fn format_branch_label(branch_name: &str, is_current: bool) -> String { +fn format_branch_label( + branch_name: &str, + is_current: bool, + pull_request_number: Option<u64>, +) -> String { + let label = format_branch_text(branch_name, pull_request_number); + if is_current { format!( "{} {}", Accent::BranchRef.paint_ansi(markers::CURRENT_BRANCH), - Accent::BranchRef.paint_ansi(branch_name) + Accent::BranchRef.paint_ansi(&label) ) } else { - branch_name.to_string() + label } } -fn format_lineage_branch(branch_name: &str, is_current: bool) -> String { +fn format_lineage_branch(branch: &BranchLineageNode, is_current: bool) -> String { + let label = format_branch_text(&branch.branch_name, branch.pull_request_number); + if is_current { format!( "{} {}", Accent::BranchRef.paint_ansi(markers::CURRENT_BRANCH), - Accent::BranchRef.paint_ansi(branch_name) + Accent::BranchRef.paint_ansi(&label) ) } else { - format!("{} {}", markers::NON_CURRENT_BRANCH, branch_name) + format!("{} {}", markers::NON_CURRENT_BRANCH, label) } } #[cfg(test)] mod tests { use super::{render_branch_lineage, render_stack_tree}; + use crate::core::graph::BranchLineageNode; use crate::core::tree::{TreeLabel, TreeNode, TreeView}; #[test] fn renders_linear_branch_lineage_as_vertical_path() { let tree = render_branch_lineage(&[ - "feature/api-followup".into(), - "feature/api".into(), - "main".into(), + BranchLineageNode { + branch_name: "feature/api-followup".into(), + pull_request_number: None, + }, + BranchLineageNode { + branch_name: "feature/api".into(), + pull_request_number: None, + }, + BranchLineageNode { + branch_name: "main".into(), + pull_request_number: None, + }, ]); assert_eq!( @@ -81,35 +111,72 @@ mod tests { #[test] fn renders_single_branch_lineage_without_connectors() { - let tree = render_branch_lineage(&["main".into()]); + let tree = render_branch_lineage(&[BranchLineageNode { + branch_name: "main".into(), + pull_request_number: None, + }]); assert_eq!(tree, "\u{1b}[32m✓\u{1b}[0m \u{1b}[32mmain\u{1b}[0m"); } + #[test] + fn renders_pull_request_numbers_in_lineage_view() { + let tree = render_branch_lineage(&[ + BranchLineageNode { + branch_name: "feature/api-followup".into(), + pull_request_number: Some(43), + }, + BranchLineageNode { + branch_name: "feature/api".into(), + pull_request_number: Some(42), + }, + BranchLineageNode { + branch_name: "main".into(), + pull_request_number: None, + }, + ]); + + assert_eq!( + tree, + concat!( + "\u{1b}[32m✓\u{1b}[0m \u{1b}[32mfeature/api-followup (#43)\u{1b}[0m\n", + "│ \n", + "* feature/api (#42)\n", + "│ \n", + "* main" + ) + ); + } + #[test] fn renders_shared_root_stack_tree() { let rendered = render_stack_tree(&TreeView { root_label: Some(TreeLabel { branch_name: "main".into(), is_current: false, + pull_request_number: None, }), roots: vec![ TreeNode { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: None, children: vec![ TreeNode { branch_name: "feat/auth-api".into(), is_current: false, + pull_request_number: None, children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: None, children: vec![], }], }, TreeNode { branch_name: "feat/auth-ui".into(), is_current: true, + pull_request_number: None, children: vec![], }, ], @@ -117,15 +184,18 @@ mod tests { TreeNode { branch_name: "feat/billing".into(), is_current: false, + pull_request_number: None, children: vec![TreeNode { branch_name: "feat/billing-retry".into(), is_current: false, + pull_request_number: None, children: vec![], }], }, TreeNode { branch_name: "docs/readme".into(), is_current: false, + pull_request_number: None, children: vec![], }, ], @@ -152,20 +222,24 @@ mod tests { root_label: Some(TreeLabel { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: None, }), roots: vec![ TreeNode { branch_name: "feat/auth-api".into(), is_current: false, + pull_request_number: None, children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: None, children: vec![], }], }, TreeNode { branch_name: "feat/auth-ui".into(), is_current: true, + pull_request_number: None, children: vec![], }, ], @@ -181,4 +255,38 @@ mod tests { ) ); } + + #[test] + fn renders_pull_request_numbers_for_normal_current_and_filtered_root_labels() { + let rendered = render_stack_tree(&TreeView { + root_label: Some(TreeLabel { + branch_name: "feat/auth".into(), + is_current: false, + pull_request_number: Some(42), + }), + roots: vec![ + TreeNode { + branch_name: "feat/auth-api".into(), + is_current: false, + pull_request_number: Some(43), + children: vec![], + }, + TreeNode { + branch_name: "feat/auth-ui".into(), + is_current: true, + pull_request_number: Some(44), + children: vec![], + }, + ], + }); + + assert_eq!( + rendered, + concat!( + "feat/auth (#42)\n", + "├── feat/auth-api (#43)\n", + "└── \u{1b}[32m✓\u{1b}[0m \u{1b}[32mfeat/auth-ui (#44)\u{1b}[0m" + ) + ); + } } diff --git a/src/core/adopt.rs b/src/core/adopt.rs index 87dcec7..846ad15 100644 --- a/src/core/adopt.rs +++ b/src/core/adopt.rs @@ -181,6 +181,7 @@ pub fn apply(plan: &AdoptPlan) -> io::Result<AdoptOutcome> { fork_point_oid: parent_head_oid, head_oid_at_creation: branch_head_oid, created_at_unix_secs: now_unix_timestamp_secs(), + pull_request: None, archived: false, }; @@ -263,6 +264,7 @@ pub(crate) fn resume_after_sync( fork_point_oid: parent_head_oid, head_oid_at_creation: branch_head_oid, created_at_unix_secs: now_unix_timestamp_secs(), + pull_request: None, archived: false, }; diff --git a/src/core/branch.rs b/src/core/branch.rs index 422e3db..68825c4 100644 --- a/src/core/branch.rs +++ b/src/core/branch.rs @@ -5,6 +5,7 @@ use uuid::Uuid; use crate::core::git; use crate::core::graph::BranchGraph; +use crate::core::graph::BranchLineageNode; use crate::core::store::types::DigState; use crate::core::store::{ BranchNode, DigConfig, ParentRef, now_unix_timestamp_secs, open_or_initialize, @@ -21,7 +22,7 @@ pub struct BranchOptions { pub struct BranchOutcome { pub status: ExitStatus, pub created_node: Option<BranchNode>, - pub lineage: Vec<String>, + pub lineage: Vec<BranchLineageNode>, } pub fn run(options: &BranchOptions) -> io::Result<BranchOutcome> { @@ -78,6 +79,7 @@ pub fn run(options: &BranchOptions) -> io::Result<BranchOutcome> { fork_point_oid: parent_head_oid.clone(), head_oid_at_creation: parent_head_oid, created_at_unix_secs: now_unix_timestamp_secs(), + pull_request: None, archived: false, }; @@ -87,7 +89,10 @@ pub fn run(options: &BranchOptions) -> io::Result<BranchOutcome> { return Ok(BranchOutcome { status, created_node: None, - lineage: vec![branch_name.to_string()], + lineage: vec![BranchLineageNode { + branch_name: branch_name.to_string(), + pull_request_number: None, + }], }); } @@ -181,6 +186,7 @@ mod tests { fork_point_oid: "abc123".into(), head_oid_at_creation: "abc123".into(), created_at_unix_secs: 1, + pull_request: None, archived: false, }], }; diff --git a/src/core/gh.rs b/src/core/gh.rs new file mode 100644 index 0000000..a648de4 --- /dev/null +++ b/src/core/gh.rs @@ -0,0 +1,416 @@ +use std::io; +use std::io::{Read, Write}; +use std::process::{Command, ExitStatus, Output, Stdio}; +use std::thread; + +use serde::Deserialize; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PullRequestSummary { + pub number: u64, + pub base_ref_name: String, + pub url: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct PullRequestDetails { + pub number: u64, + pub title: String, + pub url: String, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreatePullRequestOptions { + pub base_branch_name: String, + pub title: Option<String>, + pub body: Option<String>, + pub draft: bool, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct CreatedPullRequest { + pub number: u64, + pub url: String, +} + +#[derive(Debug)] +struct GhCommandOutput { + status: ExitStatus, + stdout: String, + stderr: String, +} + +impl GhCommandOutput { + fn combined_output(&self) -> String { + let stdout = self.stdout.trim(); + let stderr = self.stderr.trim(); + + match (stdout.is_empty(), stderr.is_empty()) { + (true, true) => String::new(), + (false, true) => stdout.to_string(), + (true, false) => stderr.to_string(), + (false, false) => format!("{stdout}\n{stderr}"), + } + } +} + +#[derive(Debug, Deserialize)] +struct PullRequestSummaryRecord { + number: u64, + #[serde(rename = "baseRefName")] + base_ref_name: String, + url: String, +} + +#[derive(Debug, Deserialize)] +struct PullRequestViewRecord { + number: u64, + url: String, +} + +#[derive(Debug, Deserialize)] +struct PullRequestDetailsRecord { + number: u64, + title: String, + url: String, +} + +pub fn list_open_pull_requests_for_head(branch_name: &str) -> io::Result<Vec<PullRequestSummary>> { + let output = run_gh_capture_output(&[ + "pr".to_string(), + "list".to_string(), + "--head".to_string(), + branch_name.to_string(), + "--state".to_string(), + "open".to_string(), + "--json".to_string(), + "number,baseRefName,url".to_string(), + ])?; + + parse_open_pull_requests(&output.stdout) +} + +pub fn create_pull_request(options: &CreatePullRequestOptions) -> io::Result<CreatedPullRequest> { + let mut args = vec![ + "pr".to_string(), + "create".to_string(), + "--base".to_string(), + options.base_branch_name.clone(), + ]; + + if let Some(title) = &options.title { + args.push("--title".to_string()); + args.push(title.clone()); + } + + if let Some(body) = &options.body { + args.push("--body".to_string()); + args.push(body.clone()); + } + + if options.draft { + args.push("--draft".to_string()); + } + + let output = run_gh_with_live_output(&args)?; + if !output.status.success() { + return Err(gh_command_failed( + "gh pr create", + &output.stdout, + &output.stderr, + )); + } + + let url = find_pull_request_url(&output.combined_output()).ok_or_else(|| { + io::Error::other("gh pr create succeeded but did not report a pull request URL") + })?; + + if let Some(number) = pull_request_number_from_url(&url) { + return Ok(CreatedPullRequest { number, url }); + } + + view_pull_request_by_url(&url) +} + +pub fn list_open_pull_requests() -> io::Result<Vec<PullRequestDetails>> { + let output = run_gh_capture_output(&[ + "pr".to_string(), + "list".to_string(), + "--state".to_string(), + "open".to_string(), + "--json".to_string(), + "number,title,url".to_string(), + ])?; + + parse_open_pull_request_details(&output.stdout) +} + +pub fn open_current_pull_request_in_browser() -> io::Result<()> { + run_gh_command( + "gh pr view --web", + &["pr".to_string(), "view".to_string(), "--web".to_string()], + ) +} + +pub fn open_pull_request_in_browser(number: u64) -> io::Result<()> { + run_gh_command( + "gh pr view --web", + &[ + "pr".to_string(), + "view".to_string(), + number.to_string(), + "--web".to_string(), + ], + ) +} + +fn view_pull_request_by_url(url: &str) -> io::Result<CreatedPullRequest> { + let output = run_gh_capture_output(&[ + "pr".to_string(), + "view".to_string(), + url.to_string(), + "--json".to_string(), + "number,url".to_string(), + ])?; + let record: PullRequestViewRecord = serde_json::from_str(&output.stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + Ok(CreatedPullRequest { + number: record.number, + url: record.url, + }) +} + +fn parse_open_pull_requests(stdout: &str) -> io::Result<Vec<PullRequestSummary>> { + let records: Vec<PullRequestSummaryRecord> = serde_json::from_str(stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + Ok(records + .into_iter() + .map(|record| PullRequestSummary { + number: record.number, + base_ref_name: record.base_ref_name, + url: record.url, + }) + .collect()) +} + +fn parse_open_pull_request_details(stdout: &str) -> io::Result<Vec<PullRequestDetails>> { + let records: Vec<PullRequestDetailsRecord> = serde_json::from_str(stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + Ok(records + .into_iter() + .map(|record| PullRequestDetails { + number: record.number, + title: record.title, + url: record.url, + }) + .collect()) +} + +fn find_pull_request_url(output: &str) -> Option<String> { + output + .split_whitespace() + .find(|token| token.contains("/pull/")) + .map(|token| { + token.trim_matches(|ch: char| matches!(ch, '"' | '\'' | '(' | ')' | '[' | ']')) + }) + .map(str::to_string) +} + +fn pull_request_number_from_url(url: &str) -> Option<u64> { + let (_, suffix) = url.rsplit_once("/pull/")?; + let digits = suffix + .chars() + .take_while(|ch| ch.is_ascii_digit()) + .collect::<String>(); + + (!digits.is_empty()).then(|| digits.parse().ok()).flatten() +} + +fn run_gh_capture_output(args: &[String]) -> io::Result<GhCommandOutput> { + let output = Command::new("gh") + .args(args) + .output() + .map_err(normalize_gh_spawn_error)?; + + output_to_gh_command_output(output) +} + +fn run_gh_command(command_name: &str, args: &[String]) -> io::Result<()> { + let output = run_gh_capture_output(args)?; + if output.status.success() { + Ok(()) + } else { + Err(gh_command_failed( + command_name, + &output.stdout, + &output.stderr, + )) + } +} + +fn run_gh_with_live_output(args: &[String]) -> io::Result<GhCommandOutput> { + let mut child = Command::new("gh") + .args(args) + .stdin(Stdio::inherit()) + .stdout(Stdio::piped()) + .stderr(Stdio::piped()) + .spawn() + .map_err(normalize_gh_spawn_error)?; + + let stdout = child + .stdout + .take() + .ok_or_else(|| io::Error::other("failed to capture gh stdout"))?; + let stderr = child + .stderr + .take() + .ok_or_else(|| io::Error::other("failed to capture gh stderr"))?; + + let stdout_handle = thread::spawn(move || stream_and_capture(stdout, false)); + let stderr_handle = thread::spawn(move || stream_and_capture(stderr, true)); + + let status = child.wait()?; + let stdout = join_capture_thread(stdout_handle, "stdout")?; + let stderr = join_capture_thread(stderr_handle, "stderr")?; + + Ok(GhCommandOutput { + status, + stdout, + stderr, + }) +} + +fn stream_and_capture<R>(mut reader: R, use_stderr: bool) -> io::Result<String> +where + R: Read, +{ + let mut buffer = [0_u8; 4096]; + let mut captured = Vec::new(); + + loop { + let read = reader.read(&mut buffer)?; + if read == 0 { + break; + } + + if use_stderr { + let mut stderr = io::stderr(); + stderr.write_all(&buffer[..read])?; + stderr.flush()?; + } else { + let mut stdout = io::stdout(); + stdout.write_all(&buffer[..read])?; + stdout.flush()?; + } + + captured.extend_from_slice(&buffer[..read]); + } + + String::from_utf8(captured).map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err)) +} + +fn join_capture_thread( + handle: thread::JoinHandle<io::Result<String>>, + stream_name: &str, +) -> io::Result<String> { + handle + .join() + .map_err(|_| io::Error::other(format!("failed to join gh {stream_name} capture thread")))? +} + +fn output_to_gh_command_output(output: Output) -> io::Result<GhCommandOutput> { + Ok(GhCommandOutput { + status: output.status, + stdout: String::from_utf8(output.stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?, + stderr: String::from_utf8(output.stderr) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?, + }) +} + +fn normalize_gh_spawn_error(err: io::Error) -> io::Error { + if err.kind() == io::ErrorKind::NotFound { + io::Error::other("gh CLI is not installed or not found on PATH") + } else { + err + } +} + +fn gh_command_failed(command_name: &str, stdout: &str, stderr: &str) -> io::Error { + let combined = match (stdout.trim().is_empty(), stderr.trim().is_empty()) { + (true, true) => String::new(), + (false, true) => stdout.trim().to_string(), + (true, false) => stderr.trim().to_string(), + (false, false) => format!("{}\n{}", stdout.trim(), stderr.trim()), + }; + + if looks_like_auth_error(&combined) { + return io::Error::other("gh authentication failed; run 'gh auth login'"); + } + + if combined.is_empty() { + io::Error::other(format!("{command_name} failed")) + } else { + io::Error::other(format!("{command_name} failed: {combined}")) + } +} + +fn looks_like_auth_error(message: &str) -> bool { + let normalized = message.to_ascii_lowercase(); + + normalized.contains("gh auth login") + || normalized.contains("not logged into any github hosts") + || normalized.contains("authentication") +} + +#[cfg(test)] +mod tests { + use super::{ + find_pull_request_url, parse_open_pull_request_details, parse_open_pull_requests, + pull_request_number_from_url, + }; + + #[test] + fn parses_open_pull_request_list_output() { + let pull_requests = parse_open_pull_requests( + r#"[{"number":123,"baseRefName":"main","url":"https://github.com/acme/dig/pull/123"}]"#, + ) + .unwrap(); + + assert_eq!(pull_requests.len(), 1); + assert_eq!(pull_requests[0].number, 123); + assert_eq!(pull_requests[0].base_ref_name, "main"); + } + + #[test] + fn parses_open_pull_request_details_output() { + let pull_requests = parse_open_pull_request_details( + r#"[{"number":123,"title":"Auth PR","url":"https://github.com/acme/dig/pull/123"}]"#, + ) + .unwrap(); + + assert_eq!(pull_requests[0].title, "Auth PR"); + } + + #[test] + fn extracts_pull_request_url_and_number_from_create_output() { + let url = find_pull_request_url( + "Creating pull request for feat/auth into main in acme/dig.\nhttps://github.com/acme/dig/pull/456\n", + ) + .unwrap(); + + assert_eq!(url, "https://github.com/acme/dig/pull/456"); + assert_eq!(pull_request_number_from_url(&url), Some(456)); + } + + #[test] + fn ignores_non_pull_request_urls_when_extracting_number() { + assert_eq!( + pull_request_number_from_url("https://github.com/acme/dig/issues/456"), + None + ); + } +} diff --git a/src/core/git.rs b/src/core/git.rs index 2c73230..5075100 100644 --- a/src/core/git.rs +++ b/src/core/git.rs @@ -45,6 +45,12 @@ pub struct CommitMetadata { pub body: String, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BranchPushTarget { + pub remote_name: String, + pub branch_name: String, +} + #[derive(Debug, Clone)] pub struct RepoContext { pub git_dir: PathBuf, @@ -333,6 +339,110 @@ pub fn commit_metadata_in_range(range_spec: &str) -> io::Result<Vec<CommitMetada Ok(parse_commit_metadata_records(&stdout)) } +pub fn branch_push_target_if_needed(branch_name: &str) -> io::Result<Option<BranchPushTarget>> { + let Some(remote_name) = resolve_push_remote_name(branch_name)? else { + return Ok(None); + }; + + if branch_head_is_pushed_to_remote(branch_name, &remote_name)? { + return Ok(None); + } + + Ok(Some(BranchPushTarget { + remote_name, + branch_name: branch_name.to_string(), + })) +} + +pub fn push_branch_to_remote(target: &BranchPushTarget) -> io::Result<GitCommandOutput> { + let output = Command::new("git") + .args(["push", "-u", &target.remote_name, &target.branch_name]) + .output()?; + + output_to_git_command_output(output) +} + +fn resolve_push_remote_name(branch_name: &str) -> io::Result<Option<String>> { + if let Some(remote_name) = configured_branch_remote_name(branch_name)? { + return Ok(Some(remote_name)); + } + + let remote_names = git_remote_names()?; + match remote_names.as_slice() { + [] => Ok(None), + [remote_name] => Ok(Some(remote_name.clone())), + _ if remote_names + .iter() + .any(|remote_name| remote_name == "origin") => + { + Ok(Some("origin".to_string())) + } + _ => Ok(None), + } +} + +fn configured_branch_remote_name(branch_name: &str) -> io::Result<Option<String>> { + let key = format!("branch.{branch_name}.remote"); + let output = Command::new("git") + .args(["config", "--get", &key]) + .output()?; + + if !output.status.success() { + return Ok(None); + } + + let stdout = String::from_utf8(output.stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + let remote_name = stdout.trim(); + + if remote_name.is_empty() { + Ok(None) + } else { + Ok(Some(remote_name.to_string())) + } +} + +fn git_remote_names() -> io::Result<Vec<String>> { + let output = Command::new("git").arg("remote").output()?; + + if !output.status.success() { + return Err(git_command_failed(&output)); + } + + let stdout = String::from_utf8(output.stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + + Ok(stdout + .lines() + .map(str::trim) + .filter(|line| !line.is_empty()) + .map(str::to_string) + .collect()) +} + +fn branch_head_is_pushed_to_remote(branch_name: &str, remote_name: &str) -> io::Result<bool> { + let output = Command::new("git") + .args([ + "ls-remote", + "--heads", + remote_name, + &format!("refs/heads/{branch_name}"), + ]) + .output()?; + + if !output.status.success() { + return Err(git_command_failed(&output)); + } + + let stdout = String::from_utf8(output.stdout) + .map_err(|err| io::Error::new(io::ErrorKind::InvalidData, err))?; + let Some(remote_oid) = stdout.split_whitespace().next() else { + return Ok(false); + }; + + Ok(remote_oid == ref_oid(branch_name)?) +} + fn run_git_capture_output<const N: usize>(args: [&str; N]) -> io::Result<GitCommandOutput> { let output = Command::new("git").args(args).output()?; diff --git a/src/core/graph.rs b/src/core/graph.rs index 4f95aac..76fba01 100644 --- a/src/core/graph.rs +++ b/src/core/graph.rs @@ -12,6 +12,12 @@ pub struct BranchTreeNode { pub children: Vec<BranchTreeNode>, } +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct BranchLineageNode { + pub branch_name: String, + pub pull_request_number: Option<u64>, +} + pub struct BranchGraph<'a> { state: &'a DigState, } @@ -21,18 +27,27 @@ impl<'a> BranchGraph<'a> { Self { state } } - pub fn lineage(&self, branch_name: &str, trunk_branch: &str) -> Vec<String> { + pub fn lineage(&self, branch_name: &str, trunk_branch: &str) -> Vec<BranchLineageNode> { let Some(mut current_node) = self.state.find_branch_by_name(branch_name) else { - return vec![branch_name.to_string()]; + return vec![BranchLineageNode { + branch_name: branch_name.to_string(), + pull_request_number: None, + }]; }; - let mut lineage = vec![current_node.branch_name.clone()]; + let mut lineage = vec![BranchLineageNode { + branch_name: current_node.branch_name.clone(), + pull_request_number: current_node.pull_request.as_ref().map(|pr| pr.number), + }]; loop { match ¤t_node.parent { ParentRef::Trunk => { if current_node.branch_name != trunk_branch { - lineage.push(trunk_branch.to_string()); + lineage.push(BranchLineageNode { + branch_name: trunk_branch.to_string(), + pull_request_number: None, + }); } break; } @@ -41,7 +56,10 @@ impl<'a> BranchGraph<'a> { break; }; - lineage.push(parent_node.branch_name.clone()); + lineage.push(BranchLineageNode { + branch_name: parent_node.branch_name.clone(), + pull_request_number: parent_node.pull_request.as_ref().map(|pr| pr.number), + }); current_node = parent_node; } } @@ -141,9 +159,9 @@ impl<'a> BranchGraph<'a> { #[cfg(test)] mod tests { - use super::{BranchGraph, BranchTreeNode}; + use super::{BranchGraph, BranchLineageNode, BranchTreeNode}; use crate::core::store::types::{DIG_STATE_VERSION, DigState}; - use crate::core::store::{BranchNode, ParentRef}; + use crate::core::store::{BranchNode, ParentRef, TrackedPullRequest}; use uuid::Uuid; fn fixture_state() -> (DigState, Uuid, Uuid, Uuid) { @@ -163,6 +181,7 @@ mod tests { fork_point_oid: "abc123".into(), head_oid_at_creation: "abc123".into(), created_at_unix_secs: 1, + pull_request: Some(TrackedPullRequest { number: 42 }), archived: false, }, BranchNode { @@ -173,6 +192,7 @@ mod tests { fork_point_oid: "def456".into(), head_oid_at_creation: "def456".into(), created_at_unix_secs: 2, + pull_request: None, archived: false, }, BranchNode { @@ -183,6 +203,7 @@ mod tests { fork_point_oid: "fedcba".into(), head_oid_at_creation: "fedcba".into(), created_at_unix_secs: 3, + pull_request: None, archived: false, }, ], @@ -201,9 +222,18 @@ mod tests { assert_eq!( graph.lineage("feature/api-followup", "main"), vec![ - "feature/api-followup".to_string(), - "feature/api".to_string(), - "main".to_string() + BranchLineageNode { + branch_name: "feature/api-followup".to_string(), + pull_request_number: None, + }, + BranchLineageNode { + branch_name: "feature/api".to_string(), + pull_request_number: Some(42), + }, + BranchLineageNode { + branch_name: "main".to_string(), + pull_request_number: None, + } ] ); } diff --git a/src/core/init.rs b/src/core/init.rs index a43eac6..4842f51 100644 --- a/src/core/init.rs +++ b/src/core/init.rs @@ -3,6 +3,7 @@ use std::process::ExitStatus; use crate::core::git; use crate::core::graph::BranchGraph; +use crate::core::graph::BranchLineageNode; use crate::core::store::{StoreInitialization, open_or_initialize}; #[derive(Debug, Clone, Default, PartialEq, Eq)] @@ -12,7 +13,7 @@ pub struct InitOptions {} pub struct InitOutcome { pub status: ExitStatus, pub created_git_repo: bool, - pub lineage: Vec<String>, + pub lineage: Vec<BranchLineageNode>, pub store_initialization: StoreInitialization, } @@ -43,7 +44,7 @@ pub fn run(_: &InitOptions) -> io::Result<InitOutcome> { #[cfg(test)] mod tests { - use crate::core::graph::BranchGraph; + use crate::core::graph::{BranchGraph, BranchLineageNode}; use crate::core::store::StoreInitialization; use crate::core::store::types::DigState; use std::process::{Command, Stdio}; @@ -63,6 +64,12 @@ mod tests { }; assert!(outcome.created_git_repo); - assert_eq!(outcome.lineage, vec!["main".to_string()]); + assert_eq!( + outcome.lineage, + vec![BranchLineageNode { + branch_name: "main".to_string(), + pull_request_number: None, + }] + ); } } diff --git a/src/core/mod.rs b/src/core/mod.rs index 2451488..cc78ce3 100644 --- a/src/core/mod.rs +++ b/src/core/mod.rs @@ -3,11 +3,13 @@ pub(crate) mod branch; pub(crate) mod clean; pub(crate) mod commit; pub(crate) mod deleted_local; +pub(crate) mod gh; pub(crate) mod git; pub(crate) mod graph; pub(crate) mod init; pub(crate) mod merge; pub(crate) mod orphan; +pub(crate) mod pr; pub(crate) mod reparent; pub(crate) mod restack; pub(crate) mod store; diff --git a/src/core/pr.rs b/src/core/pr.rs new file mode 100644 index 0000000..5b5b85e --- /dev/null +++ b/src/core/pr.rs @@ -0,0 +1,469 @@ +use std::collections::HashMap; +use std::io; +use std::process::ExitStatus; + +use crate::core::gh::{self, CreatePullRequestOptions, PullRequestDetails, PullRequestSummary}; +use crate::core::git; +use crate::core::graph::BranchGraph; +use crate::core::store::{ + BranchPullRequestTrackedSource, TrackedPullRequest, open_initialized, + record_branch_pull_request_tracked, +}; +use crate::core::tree::{self, TreeOptions}; +use crate::core::workflow; + +#[derive(Debug, Clone, Default, PartialEq, Eq)] +pub struct PrOptions { + pub title: Option<String>, + pub body: Option<String>, + pub draft: bool, + pub push_if_needed: bool, +} + +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum PrOutcomeKind { + AlreadyTracked, + Created, + Adopted, +} + +#[derive(Debug)] +pub struct PrOutcome { + pub status: ExitStatus, + pub kind: PrOutcomeKind, + pub branch_name: String, + pub base_branch_name: String, + pub pull_request: TrackedPullRequest, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TrackedPullRequestListNode { + pub pull_request: PullRequestDetails, + pub children: Vec<TrackedPullRequestListNode>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct TrackedPullRequestListView { + pub root_label: Option<String>, + pub roots: Vec<TrackedPullRequestListNode>, +} + +#[derive(Debug)] +pub struct PrListOutcome { + pub status: ExitStatus, + pub view: TrackedPullRequestListView, + pub pull_requests: Vec<PullRequestDetails>, +} + +#[derive(Debug, Clone, PartialEq, Eq)] +enum PrTrackingAction { + Create, + Adopt(PullRequestSummary), +} + +pub fn run(options: &PrOptions) -> io::Result<PrOutcome> { + let mut session = open_initialized("dig is not initialized; run 'dig init' first")?; + workflow::ensure_no_pending_operation(&session.paths, "pr")?; + git::ensure_no_in_progress_operations(&session.repo, "pr")?; + + let branch_name = git::current_branch_name_if_any()?.ok_or_else(|| { + io::Error::other("dig pr requires a named branch; detached HEAD is not supported") + })?; + let node = session + .state + .find_branch_by_name(&branch_name) + .cloned() + .ok_or_else(|| { + io::Error::other(format!("branch '{}' is not tracked by dig", branch_name)) + })?; + + let base_branch_name = BranchGraph::new(&session.state) + .parent_branch_name(&node, &session.config.trunk_branch) + .ok_or_else(|| { + io::Error::other(format!( + "tracked parent branch for '{}' was not found", + branch_name + )) + })?; + + if let Some(pull_request) = node.pull_request.clone() { + return Ok(PrOutcome { + status: git::success_status()?, + kind: PrOutcomeKind::AlreadyTracked, + branch_name, + base_branch_name, + pull_request, + }); + } + + let open_pull_requests = gh::list_open_pull_requests_for_head(&branch_name)?; + match resolve_tracking_action(&branch_name, &base_branch_name, &open_pull_requests)? { + PrTrackingAction::Create => { + if let Some(push_target) = git::branch_push_target_if_needed(&branch_name)? { + if !options.push_if_needed { + return Err(io::Error::other(format!( + "branch '{}' is not pushed to '{}'", + push_target.branch_name, push_target.remote_name + ))); + } + + let push_output = git::push_branch_to_remote(&push_target)?; + if !push_output.status.success() { + let combined_output = push_output.combined_output(); + return Err(io::Error::other(if combined_output.is_empty() { + format!( + "git push to '{}' failed for branch '{}'", + push_target.remote_name, push_target.branch_name + ) + } else { + format!( + "git push to '{}' failed for branch '{}': {}", + push_target.remote_name, push_target.branch_name, combined_output + ) + })); + } + } + + let created_pull_request = gh::create_pull_request(&CreatePullRequestOptions { + base_branch_name: base_branch_name.clone(), + title: options.title.clone(), + body: options.body.clone(), + draft: options.draft, + })?; + let pull_request = TrackedPullRequest { + number: created_pull_request.number, + }; + + record_branch_pull_request_tracked( + &mut session, + node.id, + node.branch_name.clone(), + pull_request.clone(), + BranchPullRequestTrackedSource::Created, + )?; + + Ok(PrOutcome { + status: git::success_status()?, + kind: PrOutcomeKind::Created, + branch_name, + base_branch_name, + pull_request, + }) + } + PrTrackingAction::Adopt(existing_pull_request) => { + let pull_request = TrackedPullRequest { + number: existing_pull_request.number, + }; + + record_branch_pull_request_tracked( + &mut session, + node.id, + node.branch_name.clone(), + pull_request.clone(), + BranchPullRequestTrackedSource::Adopted, + )?; + + Ok(PrOutcome { + status: git::success_status()?, + kind: PrOutcomeKind::Adopted, + branch_name, + base_branch_name, + pull_request, + }) + } + } +} + +pub fn current_branch_push_target_for_create() -> io::Result<Option<git::BranchPushTarget>> { + let session = open_initialized("dig is not initialized; run 'dig init' first")?; + workflow::ensure_no_pending_operation(&session.paths, "pr")?; + git::ensure_no_in_progress_operations(&session.repo, "pr")?; + + let branch_name = git::current_branch_name_if_any()?.ok_or_else(|| { + io::Error::other("dig pr requires a named branch; detached HEAD is not supported") + })?; + let node = session + .state + .find_branch_by_name(&branch_name) + .ok_or_else(|| { + io::Error::other(format!("branch '{}' is not tracked by dig", branch_name)) + })?; + + if node.pull_request.is_some() { + return Ok(None); + } + + let base_branch_name = BranchGraph::new(&session.state) + .parent_branch_name(node, &session.config.trunk_branch) + .ok_or_else(|| { + io::Error::other(format!( + "tracked parent branch for '{}' was not found", + branch_name + )) + })?; + + let open_pull_requests = gh::list_open_pull_requests_for_head(&branch_name)?; + match resolve_tracking_action(&branch_name, &base_branch_name, &open_pull_requests)? { + PrTrackingAction::Create => git::branch_push_target_if_needed(&branch_name), + PrTrackingAction::Adopt(_) => Ok(None), + } +} + +pub fn list_open_tracked_pull_requests() -> io::Result<PrListOutcome> { + open_initialized("dig is not initialized; run 'dig init' first")?; + let open_pull_requests = gh::list_open_pull_requests()?; + let pull_request_lookup = open_pull_requests + .into_iter() + .map(|pull_request| (pull_request.number, pull_request)) + .collect::<HashMap<_, _>>(); + let tree_outcome = tree::run(&TreeOptions::default())?; + let roots = tree_outcome + .view + .roots + .iter() + .flat_map(|node| build_pull_request_list_nodes(node, &pull_request_lookup)) + .collect::<Vec<_>>(); + let mut ordered_pull_requests = Vec::new(); + collect_pull_requests_in_order(&roots, &mut ordered_pull_requests); + + Ok(PrListOutcome { + status: tree_outcome.status, + view: TrackedPullRequestListView { + root_label: tree_outcome.view.root_label.map(|label| label.branch_name), + roots, + }, + pull_requests: ordered_pull_requests, + }) +} + +pub fn open_current_pull_request_in_browser() -> io::Result<()> { + let session = open_initialized("dig is not initialized; run 'dig init' first")?; + let branch_name = git::current_branch_name_if_any()?.ok_or_else(|| { + io::Error::other("dig pr requires a named branch; detached HEAD is not supported") + })?; + + if let Some(pull_request) = session + .state + .find_branch_by_name(&branch_name) + .and_then(|node| node.pull_request.as_ref()) + { + return gh::open_pull_request_in_browser(pull_request.number); + } + + gh::open_current_pull_request_in_browser() +} + +pub fn open_pull_request_in_browser(number: u64) -> io::Result<()> { + gh::open_pull_request_in_browser(number) +} + +pub fn open_pull_requests_in_browser(pull_requests: &[PullRequestDetails]) -> io::Result<()> { + for pull_request in pull_requests { + gh::open_pull_request_in_browser(pull_request.number)?; + } + + Ok(()) +} + +fn resolve_tracking_action( + branch_name: &str, + base_branch_name: &str, + open_pull_requests: &[PullRequestSummary], +) -> io::Result<PrTrackingAction> { + match open_pull_requests { + [] => Ok(PrTrackingAction::Create), + [pull_request] if pull_request.base_ref_name == base_branch_name => { + Ok(PrTrackingAction::Adopt(pull_request.clone())) + } + [pull_request] => Err(io::Error::other(format!( + "branch '{}' already has open pull request #{} into '{}', but dig expects base '{}'", + branch_name, pull_request.number, pull_request.base_ref_name, base_branch_name + ))), + _ => Err(io::Error::other(format!( + "branch '{}' has multiple open pull requests on GitHub; dig pr cannot choose automatically", + branch_name + ))), + } +} + +fn build_pull_request_list_nodes( + node: &crate::core::tree::TreeNode, + pull_request_lookup: &HashMap<u64, PullRequestDetails>, +) -> Vec<TrackedPullRequestListNode> { + let children = node + .children + .iter() + .flat_map(|child| build_pull_request_list_nodes(child, pull_request_lookup)) + .collect::<Vec<_>>(); + + let Some(number) = node.pull_request_number else { + return children; + }; + let Some(pull_request) = pull_request_lookup.get(&number) else { + return children; + }; + + vec![TrackedPullRequestListNode { + pull_request: pull_request.clone(), + children, + }] +} + +fn collect_pull_requests_in_order( + nodes: &[TrackedPullRequestListNode], + ordered_pull_requests: &mut Vec<PullRequestDetails>, +) { + for node in nodes { + ordered_pull_requests.push(node.pull_request.clone()); + collect_pull_requests_in_order(&node.children, ordered_pull_requests); + } +} + +#[cfg(test)] +mod tests { + use super::{ + PrTrackingAction, TrackedPullRequestListNode, build_pull_request_list_nodes, + collect_pull_requests_in_order, resolve_tracking_action, + }; + use crate::core::gh::PullRequestDetails; + use crate::core::gh::PullRequestSummary; + use crate::core::tree::TreeNode; + use std::collections::HashMap; + + #[test] + fn resolves_create_when_no_open_pull_requests_exist() { + let action = resolve_tracking_action("feat/auth", "main", &[]).unwrap(); + + assert_eq!(action, PrTrackingAction::Create); + } + + #[test] + fn resolves_adopt_for_matching_open_pull_request() { + let action = resolve_tracking_action( + "feat/auth", + "main", + &[PullRequestSummary { + number: 123, + base_ref_name: "main".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }], + ) + .unwrap(); + + assert_eq!( + action, + PrTrackingAction::Adopt(PullRequestSummary { + number: 123, + base_ref_name: "main".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }) + ); + } + + #[test] + fn rejects_open_pull_request_with_mismatched_base() { + let error = resolve_tracking_action( + "feat/auth", + "main", + &[PullRequestSummary { + number: 123, + base_ref_name: "develop".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }], + ) + .unwrap_err(); + + assert!(error.to_string().contains("expects base 'main'")); + } + + #[test] + fn rejects_multiple_open_pull_requests() { + let error = resolve_tracking_action( + "feat/auth", + "main", + &[ + PullRequestSummary { + number: 123, + base_ref_name: "main".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }, + PullRequestSummary { + number: 124, + base_ref_name: "main".into(), + url: "https://github.com/acme/dig/pull/124".into(), + }, + ], + ) + .unwrap_err(); + + assert!(error.to_string().contains("multiple open pull requests")); + } + + #[test] + fn builds_pull_request_list_nodes_by_collapsing_non_pr_branches() { + let nodes = build_pull_request_list_nodes( + &TreeNode { + branch_name: "feat/auth".into(), + is_current: false, + pull_request_number: None, + children: vec![TreeNode { + branch_name: "feat/auth-ui".into(), + is_current: false, + pull_request_number: Some(123), + children: vec![], + }], + }, + &HashMap::from([( + 123, + PullRequestDetails { + number: 123, + title: "Auth UI".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }, + )]), + ); + + assert_eq!( + nodes, + vec![TrackedPullRequestListNode { + pull_request: PullRequestDetails { + number: 123, + title: "Auth UI".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }, + children: vec![], + }] + ); + } + + #[test] + fn collects_pull_requests_in_tree_order() { + let mut ordered_pull_requests = Vec::new(); + collect_pull_requests_in_order( + &[TrackedPullRequestListNode { + pull_request: PullRequestDetails { + number: 123, + title: "Auth".into(), + url: "https://github.com/acme/dig/pull/123".into(), + }, + children: vec![TrackedPullRequestListNode { + pull_request: PullRequestDetails { + number: 124, + title: "Auth UI".into(), + url: "https://github.com/acme/dig/pull/124".into(), + }, + children: vec![], + }], + }], + &mut ordered_pull_requests, + ); + + assert_eq!( + ordered_pull_requests + .iter() + .map(|pr| pr.number) + .collect::<Vec<_>>(), + vec![123, 124] + ); + } +} diff --git a/src/core/restack.rs b/src/core/restack.rs index baf6deb..fc72f6b 100644 --- a/src/core/restack.rs +++ b/src/core/restack.rs @@ -421,6 +421,7 @@ mod tests { fork_point_oid: "root".into(), head_oid_at_creation: "root".into(), created_at_unix_secs: 1, + pull_request: None, archived: false, }, BranchNode { @@ -431,6 +432,7 @@ mod tests { fork_point_oid: "auth".into(), head_oid_at_creation: "auth".into(), created_at_unix_secs: 2, + pull_request: None, archived: false, }, BranchNode { @@ -441,6 +443,7 @@ mod tests { fork_point_oid: "api".into(), head_oid_at_creation: "api".into(), created_at_unix_secs: 3, + pull_request: None, archived: false, }, ], @@ -504,6 +507,7 @@ mod tests { fork_point_oid: "root".into(), head_oid_at_creation: "root".into(), created_at_unix_secs: 1, + pull_request: None, archived: false, }, BranchNode { @@ -514,6 +518,7 @@ mod tests { fork_point_oid: "auth".into(), head_oid_at_creation: "auth".into(), created_at_unix_secs: 2, + pull_request: None, archived: false, }, BranchNode { @@ -524,6 +529,7 @@ mod tests { fork_point_oid: "api".into(), head_oid_at_creation: "api".into(), created_at_unix_secs: 3, + pull_request: None, archived: false, }, BranchNode { @@ -534,6 +540,7 @@ mod tests { fork_point_oid: "platform-root".into(), head_oid_at_creation: "platform-root".into(), created_at_unix_secs: 4, + pull_request: None, archived: false, }, ], diff --git a/src/core/store/mod.rs b/src/core/store/mod.rs index 334199a..3a4b7b4 100644 --- a/src/core/store/mod.rs +++ b/src/core/store/mod.rs @@ -13,16 +13,18 @@ pub(crate) use config::load_config; pub(crate) use events::append_event; pub(crate) use fs::dig_paths; pub(crate) use mutations::{ - record_branch_adopted, record_branch_archived, record_branch_created, record_branch_reparented, + record_branch_adopted, record_branch_archived, record_branch_created, + record_branch_pull_request_tracked, record_branch_reparented, }; pub(crate) use operation::{clear_operation, load_operation, save_operation}; pub(crate) use session::{StoreSession, open_initialized, open_or_initialize}; pub(crate) use state::{load_state, save_state}; pub(crate) use types::{ BranchAdoptedEvent, BranchArchiveReason, BranchArchivedEvent, BranchCreatedEvent, BranchNode, - BranchReparentedEvent, DigConfig, DigEvent, ParentRef, PendingAdoptOperation, - PendingCleanCandidate, PendingCleanCandidateKind, PendingCleanOperation, PendingCommitEntry, - PendingCommitOperation, PendingMergeOperation, PendingOperationKind, PendingOperationState, - PendingOrphanOperation, PendingReparentOperation, PendingSyncOperation, PendingSyncPhase, + BranchPullRequestTrackedEvent, BranchPullRequestTrackedSource, BranchReparentedEvent, + DigConfig, DigEvent, ParentRef, PendingAdoptOperation, PendingCleanCandidate, + PendingCleanCandidateKind, PendingCleanOperation, PendingCommitEntry, PendingCommitOperation, + PendingMergeOperation, PendingOperationKind, PendingOperationState, PendingOrphanOperation, + PendingReparentOperation, PendingSyncOperation, PendingSyncPhase, TrackedPullRequest, now_unix_timestamp_secs, }; diff --git a/src/core/store/mutations.rs b/src/core/store/mutations.rs index 64bdc7c..f32be46 100644 --- a/src/core/store/mutations.rs +++ b/src/core/store/mutations.rs @@ -4,7 +4,8 @@ use uuid::Uuid; use super::{ BranchAdoptedEvent, BranchArchiveReason, BranchArchivedEvent, BranchCreatedEvent, BranchNode, - BranchReparentedEvent, DigEvent, ParentRef, now_unix_timestamp_secs, save_state, + BranchPullRequestTrackedEvent, BranchPullRequestTrackedSource, BranchReparentedEvent, DigEvent, + ParentRef, TrackedPullRequest, now_unix_timestamp_secs, save_state, }; use crate::core::store::append_event; use crate::core::store::session::StoreSession; @@ -75,3 +76,26 @@ pub fn record_branch_archived( }), ) } + +pub fn record_branch_pull_request_tracked( + session: &mut StoreSession, + branch_id: Uuid, + branch_name: String, + pull_request: TrackedPullRequest, + source: BranchPullRequestTrackedSource, +) -> io::Result<()> { + session + .state + .track_pull_request(branch_id, pull_request.clone())?; + save_state(&session.paths, &session.state)?; + append_event( + &session.paths, + &DigEvent::BranchPullRequestTracked(BranchPullRequestTrackedEvent { + occurred_at_unix_secs: now_unix_timestamp_secs(), + branch_id, + branch_name, + pull_request, + source, + }), + ) +} diff --git a/src/core/store/types.rs b/src/core/store/types.rs index 56fdeaa..34fcfa6 100644 --- a/src/core/store/types.rs +++ b/src/core/store/types.rs @@ -99,6 +99,20 @@ impl DigState { Ok((old_parent, old_base_ref)) } + + pub fn track_pull_request( + &mut self, + node_id: Uuid, + pull_request: TrackedPullRequest, + ) -> io::Result<()> { + let node = self.find_branch_by_id_mut(node_id).ok_or_else(|| { + io::Error::new(io::ErrorKind::NotFound, "tracked branch was not found") + })?; + + node.pull_request = Some(pull_request); + + Ok(()) + } } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -280,9 +294,16 @@ pub struct BranchNode { pub fork_point_oid: String, pub head_oid_at_creation: String, pub created_at_unix_secs: u64, + #[serde(default)] + pub pull_request: Option<TrackedPullRequest>, pub archived: bool, } +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct TrackedPullRequest { + pub number: u64, +} + #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] #[serde(tag = "kind", rename_all = "snake_case")] pub enum ParentRef { @@ -297,6 +318,7 @@ pub enum DigEvent { BranchAdopted(BranchAdoptedEvent), BranchArchived(BranchArchivedEvent), BranchReparented(BranchReparentedEvent), + BranchPullRequestTracked(BranchPullRequestTrackedEvent), } #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] @@ -338,6 +360,22 @@ pub struct BranchReparentedEvent { pub new_base_ref: String, } +#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)] +#[serde(rename_all = "snake_case")] +pub enum BranchPullRequestTrackedSource { + Created, + Adopted, +} + +#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)] +pub struct BranchPullRequestTrackedEvent { + pub occurred_at_unix_secs: u64, + pub branch_id: Uuid, + pub branch_name: String, + pub pull_request: TrackedPullRequest, + pub source: BranchPullRequestTrackedSource, +} + pub fn now_unix_timestamp_secs() -> u64 { SystemTime::now() .duration_since(UNIX_EPOCH) @@ -348,10 +386,11 @@ pub fn now_unix_timestamp_secs() -> u64 { #[cfg(test)] mod tests { use super::{ - BranchAdoptedEvent, BranchArchiveReason, BranchArchivedEvent, BranchNode, DigConfig, - DigEvent, DigState, ParentRef, PendingCommitOperation, PendingOperationKind, - PendingOperationState, PendingOrphanOperation, PendingReparentOperation, - PendingSyncOperation, PendingSyncPhase, + BranchAdoptedEvent, BranchArchiveReason, BranchArchivedEvent, BranchNode, + BranchPullRequestTrackedEvent, BranchPullRequestTrackedSource, DigConfig, DigEvent, + DigState, ParentRef, PendingCommitOperation, PendingOperationKind, PendingOperationState, + PendingOrphanOperation, PendingReparentOperation, PendingSyncOperation, PendingSyncPhase, + TrackedPullRequest, }; use crate::core::restack::RestackAction; use uuid::Uuid; @@ -366,6 +405,7 @@ mod tests { fork_point_oid: "abc123".into(), head_oid_at_creation: "abc123".into(), created_at_unix_secs: 1, + pull_request: None, archived: false, }; @@ -424,6 +464,7 @@ mod tests { fork_point_oid: "abc123".into(), head_oid_at_creation: "def456".into(), created_at_unix_secs: 1, + pull_request: None, archived: false, }, }); @@ -434,6 +475,42 @@ mod tests { assert!(serialized.contains("\"branch_name\":\"feature/api\"")); } + #[test] + fn serializes_branch_pull_request_tracked_event() { + let event = DigEvent::BranchPullRequestTracked(BranchPullRequestTrackedEvent { + occurred_at_unix_secs: 1, + branch_id: Uuid::nil(), + branch_name: "feature/api".into(), + pull_request: TrackedPullRequest { number: 123 }, + source: BranchPullRequestTrackedSource::Created, + }); + + let serialized = serde_json::to_string(&event).unwrap(); + + assert!(serialized.contains("\"type\":\"branch_pull_request_tracked\"")); + assert!(serialized.contains("\"source\":\"created\"")); + assert!(serialized.contains("\"number\":123")); + } + + #[test] + fn deserializes_legacy_branch_node_without_pull_request() { + let node = serde_json::from_str::<BranchNode>( + r#"{ + "id":"00000000-0000-0000-0000-000000000000", + "branch_name":"feature/api", + "parent":{"kind":"trunk"}, + "base_ref":"main", + "fork_point_oid":"abc123", + "head_oid_at_creation":"abc123", + "created_at_unix_secs":1, + "archived":false + }"#, + ) + .unwrap(); + + assert_eq!(node.pull_request, None); + } + #[test] fn advances_pending_operation_queue_after_success() { let first_action = RestackAction { diff --git a/src/core/tree.rs b/src/core/tree.rs index 03108c6..f5cc0ec 100644 --- a/src/core/tree.rs +++ b/src/core/tree.rs @@ -17,12 +17,14 @@ pub struct TreeOptions { pub struct TreeLabel { pub branch_name: String, pub is_current: bool, + pub pull_request_number: Option<u64>, } #[derive(Debug, Clone, PartialEq, Eq)] pub struct TreeNode { pub branch_name: String, pub is_current: bool, + pub pull_request_number: Option<u64>, pub children: Vec<TreeNode>, } @@ -97,6 +99,7 @@ fn build_tree_view(state: &DigState, trunk_branch: &str, current_branch: Option< root_label: Some(TreeLabel { branch_name: trunk_branch.to_string(), is_current: current_branch == Some(trunk_branch), + pull_request_number: None, }), roots: root_nodes .into_iter() @@ -142,6 +145,7 @@ fn filter_tree_view(view: TreeView, requested_branch: Option<&str>) -> io::Resul root_label: Some(TreeLabel { branch_name: selected_node.branch_name.clone(), is_current: selected_node.is_current, + pull_request_number: selected_node.pull_request_number, }), roots: selected_node.children.clone(), }) @@ -205,6 +209,7 @@ fn build_tree_node( TreeNode { branch_name: node.branch_name.clone(), is_current: current_branch == Some(node.branch_name.as_str()), + pull_request_number: node.pull_request.as_ref().map(|pr| pr.number), children, } } @@ -228,6 +233,7 @@ fn prune_to_branch_path(node: &TreeNode, branch_name: &str) -> Option<TreeNode> prune_to_branch_path(child, branch_name).map(|pruned_child| TreeNode { branch_name: node.branch_name.clone(), is_current: node.is_current, + pull_request_number: node.pull_request_number, children: vec![pruned_child], }) }) @@ -290,6 +296,7 @@ mod tests { fork_point_oid: "1".into(), head_oid_at_creation: "1".into(), created_at_unix_secs: 1, + pull_request: Some(crate::core::store::TrackedPullRequest { number: 101 }), archived: false, }, BranchNode { @@ -300,6 +307,7 @@ mod tests { fork_point_oid: "2".into(), head_oid_at_creation: "2".into(), created_at_unix_secs: 2, + pull_request: Some(crate::core::store::TrackedPullRequest { number: 102 }), archived: false, }, BranchNode { @@ -310,6 +318,7 @@ mod tests { fork_point_oid: "3".into(), head_oid_at_creation: "3".into(), created_at_unix_secs: 3, + pull_request: None, archived: false, }, ], @@ -321,20 +330,24 @@ mod tests { root_label: Some(TreeLabel { branch_name: "main".into(), is_current: false, + pull_request_number: None, }), roots: vec![ TreeNode { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: Some(101), children: vec![TreeNode { branch_name: "feat/auth-api".into(), is_current: true, + pull_request_number: Some(102), children: vec![], }], }, TreeNode { branch_name: "feat/billing".into(), is_current: false, + pull_request_number: None, children: vec![], }, ], @@ -348,23 +361,28 @@ mod tests { root_label: Some(TreeLabel { branch_name: "main".into(), is_current: false, + pull_request_number: None, }), roots: vec![TreeNode { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: Some(101), children: vec![ TreeNode { branch_name: "feat/auth-api".into(), is_current: false, + pull_request_number: Some(102), children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: Some(103), children: vec![], }], }, TreeNode { branch_name: "feat/auth-ui".into(), is_current: true, + pull_request_number: None, children: vec![], }, ], @@ -377,20 +395,24 @@ mod tests { root_label: Some(TreeLabel { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: Some(101), }), roots: vec![ TreeNode { branch_name: "feat/auth-api".into(), is_current: false, + pull_request_number: Some(102), children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: Some(103), children: vec![], }], }, TreeNode { branch_name: "feat/auth-ui".into(), is_current: true, + pull_request_number: None, children: vec![], }, ], @@ -404,24 +426,29 @@ mod tests { root_label: Some(TreeLabel { branch_name: "main".into(), is_current: false, + pull_request_number: None, }), roots: vec![ TreeNode { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: Some(101), children: vec![ TreeNode { branch_name: "feat/auth-api".into(), is_current: false, + pull_request_number: Some(102), children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: Some(103), children: vec![], }], }, TreeNode { branch_name: "feat/auth-ui".into(), is_current: false, + pull_request_number: None, children: vec![], }, ], @@ -429,6 +456,7 @@ mod tests { TreeNode { branch_name: "feat/billing".into(), is_current: false, + pull_request_number: None, children: vec![], }, ], @@ -440,16 +468,20 @@ mod tests { root_label: Some(TreeLabel { branch_name: "main".into(), is_current: false, + pull_request_number: None, }), roots: vec![TreeNode { branch_name: "feat/auth".into(), is_current: false, + pull_request_number: Some(101), children: vec![TreeNode { branch_name: "feat/auth-api".into(), is_current: true, + pull_request_number: Some(102), children: vec![TreeNode { branch_name: "feat/auth-api-tests".into(), is_current: false, + pull_request_number: Some(103), children: vec![], }], }], diff --git a/tests/branch.rs b/tests/branch.rs index 5142ca3..54b2586 100644 --- a/tests/branch.rs +++ b/tests/branch.rs @@ -1,9 +1,21 @@ mod support; +use std::path::{Path, PathBuf}; + use support::{ - dig_ok, find_node, initialize_main_repo, load_state_json, strip_ansi, with_temp_repo, + dig_ok, dig_ok_with_env, find_node, initialize_main_repo, install_fake_executable, + load_state_json, path_with_prepend, strip_ansi, with_temp_repo, }; +fn install_fake_gh(repo: &Path, script: &str) -> (PathBuf, String) { + let bin_dir = repo.join("fake-bin"); + install_fake_executable(&bin_dir, "gh", script); + + let path = path_with_prepend(&bin_dir); + + (bin_dir, path) +} + #[test] fn branch_command_renders_marked_lineage_and_tracks_parent() { with_temp_repo("dig-branch-cli", |repo| { @@ -38,3 +50,36 @@ fn init_reuses_marked_lineage_output_for_current_branch() { assert!(stdout.contains("✓ feat/auth\n│ \n* main")); }); } + +#[test] +fn init_lineage_shows_tracked_pull_request_numbers() { + with_temp_repo("dig-branch-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/123\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + dig_ok_with_env(repo, &["pr"], &[("PATH", path.as_str())]); + + let output = dig_ok(repo, &["init"]); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("✓ feat/auth (#123)\n│ \n* main")); + }); +} diff --git a/tests/pr.rs b/tests/pr.rs new file mode 100644 index 0000000..12e732c --- /dev/null +++ b/tests/pr.rs @@ -0,0 +1,762 @@ +mod support; + +use std::fs; +use std::path::{Path, PathBuf}; + +use support::{ + dig_ok, dig_ok_with_env, dig_with_input_and_env, find_node, git_binary_path, git_ok, + git_stdout, initialize_main_repo, install_fake_executable, load_events_json, load_state_json, + path_with_prepend, strip_ansi, with_temp_repo, +}; + +fn install_fake_gh(repo: &Path, script: &str) -> (PathBuf, String, String) { + let bin_dir = repo.join("fake-bin"); + install_fake_executable(&bin_dir, "gh", script); + + let path = path_with_prepend(&bin_dir); + let log_path = repo.join("gh.log").display().to_string(); + + (bin_dir, path, log_path) +} + +fn clear_log(path: &str) { + fs::write(path, "").unwrap(); +} + +fn initialize_origin_remote(repo: &Path) { + git_ok(repo, &["init", "--bare", "origin.git"]); + git_ok(repo, &["remote", "add", "origin", "origin.git"]); +} + +#[test] +fn pr_creates_root_pull_request_tracks_number_and_updates_tree() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/123\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &[ + "pr", + "--title", + "feat-auth", + "--body", + "body-text", + "--draft", + ], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("Created pull request #123 for 'feat/auth' into 'main'.")); + assert_eq!( + stdout + .matches("https://github.com/acme/dig/pull/123") + .count(), + 1 + ); + + let state = load_state_json(repo); + let node = find_node(&state, "feat/auth").unwrap(); + assert_eq!(node["pull_request"]["number"], 123); + + let tree_output = dig_ok(repo, &["tree"]); + let tree_stdout = strip_ansi(&String::from_utf8(tree_output.stdout).unwrap()); + assert!(tree_stdout.contains("feat/auth (#123)")); + + let events = load_events_json(repo); + assert!(events.iter().any(|event| { + event["type"].as_str() == Some("branch_pull_request_tracked") + && event["branch_name"].as_str() == Some("feat/auth") + && event["pull_request"]["number"].as_u64() == Some(123) + && event["source"].as_str() == Some("created") + })); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert!( + gh_log.contains("pr list --head feat/auth --state open --json number,baseRefName,url") + ); + assert!( + gh_log.contains("pr create --base main --title feat-auth --body body-text --draft") + ); + }); +} + +#[test] +fn pr_creates_child_pull_request_against_tracked_parent() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + dig_ok(repo, &["branch", "feat/auth-api"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/234\n' + exit 0 +fi +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("Created pull request #234 for 'feat/auth-api' into 'feat/auth'.")); + assert_eq!( + stdout + .matches("https://github.com/acme/dig/pull/234") + .count(), + 1 + ); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert!(gh_log.contains("pr create --base feat/auth")); + }); +} + +#[test] +fn pr_adopts_matching_open_pull_request_without_creating_another() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[{"number":345,"baseRefName":"main","url":"https://github.com/acme/dig/pull/345"}]\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!( + stdout.contains("Tracking existing pull request #345 for 'feat/auth' into 'main'.") + ); + + let state = load_state_json(repo); + let node = find_node(&state, "feat/auth").unwrap(); + assert_eq!(node["pull_request"]["number"], 345); + + let events = load_events_json(repo); + assert!(events.iter().any(|event| { + event["type"].as_str() == Some("branch_pull_request_tracked") + && event["source"].as_str() == Some("adopted") + })); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert!(!gh_log.contains("pr create")); + }); +} + +#[test] +fn pr_is_idempotent_when_branch_already_tracks_pull_request() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (bin_dir, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/456\n' + exit 0 +fi +exit 1 +"#, + ); + + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + install_fake_executable( + &bin_dir, + "gh", + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +echo "gh should not have been called" >&2 +exit 99 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("Branch 'feat/auth' already tracks pull request #456.")); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert_eq!(gh_log.lines().count(), 3); + }); +} + +#[test] +fn pr_with_view_only_opens_tracked_pull_request_in_browser() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (bin_dir, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/456\n' + exit 0 +fi +exit 1 +"#, + ); + + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + clear_log(&log_path); + install_fake_executable( + &bin_dir, + "gh", + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "view" ] && [ "$3" = "456" ] && [ "$4" = "--web" ]; then + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr", "--view"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + assert!(String::from_utf8(output.stdout).unwrap().trim().is_empty()); + let gh_log = fs::read_to_string(log_path).unwrap(); + assert_eq!(gh_log.trim(), "pr view 456 --web"); + }); +} + +#[test] +fn pr_with_create_and_view_opens_browser_after_tracking() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/123\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "view" ] && [ "$3" = "123" ] && [ "$4" = "--web" ]; then + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr", "--title", "feat-auth", "--view"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("Created pull request #123 for 'feat/auth' into 'main'.")); + assert_eq!( + stdout + .matches("https://github.com/acme/dig/pull/123") + .count(), + 1 + ); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert_eq!( + gh_log.lines().collect::<Vec<_>>(), + vec![ + "pr list --head feat/auth --state open --json number,baseRefName,url", + "pr list --head feat/auth --state open --json number,baseRefName,url", + "pr create --base main --title feat-auth", + "pr view 123 --web", + ] + ); + }); +} + +#[test] +fn pr_prompts_to_push_branch_before_creating_pull_request() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + initialize_origin_remote(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/777\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_with_input_and_env( + repo, + &["pr"], + "y\n", + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + assert!(output.status.success()); + + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + assert!(stdout.contains("Push it and create the pull request? [y/N]")); + assert!(stdout.contains("Created pull request #777 for 'feat/auth' into 'main'.")); + assert_eq!( + stdout + .matches("https://github.com/acme/dig/pull/777") + .count(), + 1 + ); + + let remote_ref = git_stdout( + repo, + &["ls-remote", "--heads", "origin", "refs/heads/feat/auth"], + ); + assert!(remote_ref.contains("refs/heads/feat/auth")); + + let gh_log = fs::read_to_string(log_path).unwrap(); + assert_eq!( + gh_log.lines().collect::<Vec<_>>(), + vec![ + "pr list --head feat/auth --state open --json number,baseRefName,url", + "pr list --head feat/auth --state open --json number,baseRefName,url", + "pr create --base main", + ] + ); + }); +} + +#[test] +fn pr_declining_push_skips_pull_request_creation() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + initialize_origin_remote(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_with_input_and_env( + repo, + &["pr"], + "n\n", + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + assert!(output.status.success()); + + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + assert!(stdout.contains( + "Did not create pull request because 'feat/auth' is not pushed to 'origin'." + )); + + assert!( + git_stdout( + repo, + &["ls-remote", "--heads", "origin", "refs/heads/feat/auth"] + ) + .is_empty() + ); + assert_eq!( + fs::read_to_string(log_path) + .unwrap() + .lines() + .collect::<Vec<_>>(), + vec!["pr list --head feat/auth --state open --json number,baseRefName,url"] + ); + }); +} + +#[test] +fn pr_list_renders_open_tracked_pull_requests_in_lineage_order() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + + let (bin_dir, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + current_branch="$(git branch --show-current)" + if [ "$current_branch" = "feat/auth" ]; then + printf 'https://github.com/acme/dig/pull/101\n' + exit 0 + fi + if [ "$current_branch" = "feat/auth-ui" ]; then + printf 'https://github.com/acme/dig/pull/102\n' + exit 0 + fi +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + dig_ok(repo, &["branch", "feat/auth"]); + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + dig_ok(repo, &["branch", "feat/auth-ui"]); + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + clear_log(&log_path); + install_fake_executable( + &bin_dir, + "gh", + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ] && [ "$3" = "--state" ] && [ "$4" = "open" ]; then + printf '[{"number":101,"title":"Auth PR","url":"https://github.com/acme/dig/pull/101"},{"number":102,"title":"Auth UI PR","url":"https://github.com/acme/dig/pull/102"},{"number":999,"title":"External PR","url":"https://github.com/acme/dig/pull/999"}]\n' + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + let output = dig_ok_with_env( + repo, + &["pr", "list"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + let stdout = strip_ansi(&String::from_utf8(output.stdout).unwrap()); + + assert!(stdout.contains("main")); + assert!(stdout.contains("#101: Auth PR - https://github.com/acme/dig/pull/101")); + assert!(stdout.contains("#102: Auth UI PR - https://github.com/acme/dig/pull/102")); + assert!(!stdout.contains("#999: External PR")); + }); +} + +#[test] +fn pr_list_with_view_opens_each_listed_pull_request() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (bin_dir, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/301\n' + exit 0 +fi +exit 1 +"#, + ); + + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + dig_ok(repo, &["branch", "feat/auth-ui"]); + install_fake_executable( + &bin_dir, + "gh", + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + current_branch="$(git branch --show-current)" + if [ "$current_branch" = "feat/auth-ui" ] && [ "$3" = "--head" ]; then + printf '[]\n' + exit 0 + fi + printf '[{"number":301,"title":"Auth PR","url":"https://github.com/acme/dig/pull/301"},{"number":302,"title":"Auth UI PR","url":"https://github.com/acme/dig/pull/302"}]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "create" ]; then + printf 'https://github.com/acme/dig/pull/302\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "view" ] && [ "$4" = "--web" ]; then + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + dig_ok_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + clear_log(&log_path); + install_fake_executable( + &bin_dir, + "gh", + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ] && [ "$3" = "--state" ] && [ "$4" = "open" ]; then + printf '[{"number":301,"title":"Auth PR","url":"https://github.com/acme/dig/pull/301"},{"number":302,"title":"Auth UI PR","url":"https://github.com/acme/dig/pull/302"}]\n' + exit 0 +fi +if [ "$1" = "pr" ] && [ "$2" = "view" ] && [ "$4" = "--web" ]; then + exit 0 +fi +echo "unexpected gh args: $*" >&2 +exit 1 +"#, + ); + + dig_ok_with_env( + repo, + &["pr", "list", "--view"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + assert_eq!( + fs::read_to_string(log_path) + .unwrap() + .lines() + .collect::<Vec<_>>(), + vec![ + "pr list --state open --json number,title,url", + "pr view 301 --web", + "pr view 302 --web", + ] + ); + }); +} + +#[test] +fn pr_rejects_existing_open_pull_request_with_wrong_base() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let (_, path, log_path) = install_fake_gh( + repo, + r#"#!/bin/sh +set -eu +printf '%s\n' "$*" >> "$DIG_TEST_GH_LOG" +if [ "$1" = "pr" ] && [ "$2" = "list" ]; then + printf '[{"number":567,"baseRefName":"develop","url":"https://github.com/acme/dig/pull/567"}]\n' + exit 0 +fi +exit 1 +"#, + ); + + let output = support::dig_with_env( + repo, + &["pr"], + &[ + ("PATH", path.as_str()), + ("DIG_TEST_GH_LOG", log_path.as_str()), + ], + ); + + assert!(!output.status.success()); + let stderr = String::from_utf8(output.stderr).unwrap(); + assert!(stderr.contains("expects base 'main'")); + + let state = load_state_json(repo); + let node = find_node(&state, "feat/auth").unwrap(); + assert!(node["pull_request"].is_null()); + }); +} + +#[test] +fn pr_reports_missing_gh_cli() { + with_temp_repo("dig-pr-cli", |repo| { + initialize_main_repo(repo); + dig_ok(repo, &["init"]); + dig_ok(repo, &["branch", "feat/auth"]); + + let bin_dir = repo.join("fake-bin"); + let git_path = git_binary_path(); + install_fake_executable( + &bin_dir, + "git", + &format!("#!/bin/sh\nset -eu\nexec \"{}\" \"$@\"\n", git_path), + ); + let path = bin_dir.display().to_string(); + + let output = support::dig_with_env(repo, &["pr"], &[("PATH", path.as_str())]); + + assert!(!output.status.success()); + let stderr = String::from_utf8(output.stderr).unwrap(); + assert!(stderr.contains("gh CLI is not installed or not found on PATH")); + }); +} diff --git a/tests/support/mod.rs b/tests/support/mod.rs index 4e2eaca..84da0b5 100644 --- a/tests/support/mod.rs +++ b/tests/support/mod.rs @@ -5,6 +5,9 @@ use std::io::Write; use std::path::Path; use std::process::{Command, Output, Stdio}; +#[cfg(unix)] +use std::os::unix::fs::PermissionsExt; + use serde_json::Value; use uuid::Uuid; @@ -76,17 +79,32 @@ pub fn write_file(repo: &Path, file_name: &str, contents: &str) { } pub fn dig(repo: &Path, args: &[&str]) -> Output { + dig_with_env(repo, args, &[]) +} + +pub fn dig_with_env(repo: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output { Command::new(env!("CARGO_BIN_EXE_dig")) .current_dir(repo) .args(args) + .envs(envs.iter().copied()) .output() .unwrap() } pub fn dig_with_input(repo: &Path, args: &[&str], input: &str) -> Output { + dig_with_input_and_env(repo, args, input, &[]) +} + +pub fn dig_with_input_and_env( + repo: &Path, + args: &[&str], + input: &str, + envs: &[(&str, &str)], +) -> Output { let mut child = Command::new(env!("CARGO_BIN_EXE_dig")) .current_dir(repo) .args(args) + .envs(envs.iter().copied()) .stdin(Stdio::piped()) .stdout(Stdio::piped()) .stderr(Stdio::piped()) @@ -104,7 +122,11 @@ pub fn dig_with_input(repo: &Path, args: &[&str], input: &str) -> Output { } pub fn dig_ok(repo: &Path, args: &[&str]) -> Output { - let output = dig(repo, args); + dig_ok_with_env(repo, args, &[]) +} + +pub fn dig_ok_with_env(repo: &Path, args: &[&str], envs: &[(&str, &str)]) -> Output { + let output = dig_with_env(repo, args, envs); assert!( output.status.success(), "dig {:?} failed\nstdout:\n{}\nstderr:\n{}", @@ -137,6 +159,34 @@ pub fn git_stdout(repo: &Path, args: &[&str]) -> String { String::from_utf8(output.stdout).unwrap().trim().to_string() } +pub fn install_fake_executable(bin_dir: &Path, name: &str, script: &str) { + fs::create_dir_all(bin_dir).unwrap(); + let path = bin_dir.join(name); + fs::write(&path, script).unwrap(); + #[cfg(unix)] + { + let mut permissions = fs::metadata(&path).unwrap().permissions(); + permissions.set_mode(0o755); + fs::set_permissions(path, permissions).unwrap(); + } +} + +pub fn path_with_prepend(dir: &Path) -> String { + let existing_path = std::env::var("PATH").unwrap_or_default(); + if existing_path.is_empty() { + dir.display().to_string() + } else { + format!("{}:{existing_path}", dir.display()) + } +} + +pub fn git_binary_path() -> String { + let output = Command::new("which").arg("git").output().unwrap(); + assert!(output.status.success(), "which git failed"); + + String::from_utf8(output.stdout).unwrap().trim().to_string() +} + pub fn load_state_json(repo: &Path) -> Value { serde_json::from_str(&fs::read_to_string(repo.join(".git/dig/state.json")).unwrap()).unwrap() }