From 5bc2fa6d5ccc62aadfa5264ae65812e1a3e7e00f Mon Sep 17 00:00:00 2001 From: Jamie Cansdale Date: Tue, 8 Jan 2019 09:35:44 +0000 Subject: [PATCH] Detect the PR associated with the current branch This implementation uses ListReferences rather than a GitHub API. --- src/GitHub.App/Services/GitClient.cs | 18 ++++++ src/GitHub.App/Services/PullRequestService.cs | 57 ++++++++++++++++++- .../Services/IGitClient.cs | 3 + 3 files changed, 76 insertions(+), 2 deletions(-) diff --git a/src/GitHub.App/Services/GitClient.cs b/src/GitHub.App/Services/GitClient.cs index 6114d82169..499eefb51a 100644 --- a/src/GitHub.App/Services/GitClient.cs +++ b/src/GitHub.App/Services/GitClient.cs @@ -20,6 +20,7 @@ public class GitClient : IGitClient const string defaultOriginName = "origin"; static readonly ILogger log = LogManager.ForContext(); readonly IGitService gitService; + readonly IGitHubCredentialProvider credentialProvider; readonly PullOptions pullOptions; readonly PushOptions pushOptions; readonly FetchOptions fetchOptions; @@ -31,6 +32,7 @@ public GitClient(IGitHubCredentialProvider credentialProvider, IGitService gitSe Guard.ArgumentNotNull(gitService, nameof(gitService)); this.gitService = gitService; + this.credentialProvider = credentialProvider; pushOptions = new PushOptions { CredentialsProvider = credentialProvider.HandleCredentials }; fetchOptions = new FetchOptions { CredentialsProvider = credentialProvider.HandleCredentials }; @@ -168,6 +170,22 @@ public Task Fetch(IRepository repository, string remoteName, params string[] ref }); } + public Task> ListReferences(IRepository repo, string remoteName) + { + return Task.Run>(() => + { + var dictionary = new Dictionary(); + var remote = repo.Network.Remotes[remoteName]; + var refs = repo.Network.ListReferences(remote, credentialProvider.HandleCredentials); + foreach (var reference in refs) + { + dictionary[reference.CanonicalName] = reference.TargetIdentifier; + } + + return dictionary; + }); + } + public Task Checkout(IRepository repository, string branchName) { Guard.ArgumentNotNull(repository, nameof(repository)); diff --git a/src/GitHub.App/Services/PullRequestService.cs b/src/GitHub.App/Services/PullRequestService.cs index fd2f9bc006..130ab7eb6e 100644 --- a/src/GitHub.App/Services/PullRequestService.cs +++ b/src/GitHub.App/Services/PullRequestService.cs @@ -62,7 +62,7 @@ public class PullRequestService : IPullRequestService, IStaticReviewFileMap readonly IUsageTracker usageTracker; readonly IDictionary tempFileMappings; - + [ImportingConstructor] public PullRequestService( IGitClient gitClient, @@ -738,11 +738,64 @@ public IObservable SwitchToBranch(LocalRepositoryModel repository, PullReq repo.Head.FriendlyName, SettingGHfVSPullRequest); var value = await gitClient.GetConfig(repo, configKey); - return Observable.Return(ParseGHfVSConfigKeyValue(value)); + var pr = ParseGHfVSConfigKeyValue(value); + if (pr != default((string, int))) + { + return Observable.Return(pr); + } + + pr = await FindPullRequestForBranchAsync(repo, repo.Head, "origin"); + return Observable.Return(pr); } }); } + async Task<(string owner, int number)> FindPullRequestForBranchAsync( + IRepository repo, Branch branch, string upstreamRemoteName = "origin") + { + if (!branch.IsTracking) + { + return default((string, int)); + } + + var remoteReferences = await gitClient.ListReferences(repo, branch.RemoteName); + if (!remoteReferences.TryGetValue(branch.UpstreamBranchCanonicalName, out var sha)) + { + return default((string, int)); + } + + if (branch.RemoteName != upstreamRemoteName) + { + remoteReferences = await gitClient.ListReferences(repo, upstreamRemoteName); + } + + var prs = remoteReferences + .Where(kv => kv.Value == sha) + .Select(kv => FindPullRequestForCanonicalName(kv.Key)) + .Where(p => p != -1) + .ToList(); + if (prs.Count == 0) + { + return default((string, int)); + } + + var owner = gitService.GetRemoteUri(repo, upstreamRemoteName).Owner; + var number = prs[0]; + + return (owner, number); + } + + static int FindPullRequestForCanonicalName(string canonicalName) + { + var match = Regex.Match(canonicalName, "^refs/pull/([0-9]+)/head$"); + if (match.Success && int.TryParse(match.Groups[1].Value, out var number)) + { + return number; + } + + return -1; + } + public async Task ExtractToTempFile( LocalRepositoryModel repository, PullRequestDetailModel pullRequest, diff --git a/src/GitHub.Exports.Reactive/Services/IGitClient.cs b/src/GitHub.Exports.Reactive/Services/IGitClient.cs index 81fbbafebe..3ab58c9adb 100644 --- a/src/GitHub.Exports.Reactive/Services/IGitClient.cs +++ b/src/GitHub.Exports.Reactive/Services/IGitClient.cs @@ -56,6 +56,9 @@ public interface IGitClient /// Task Fetch(IRepository repository, UriString remoteUri, params string[] refspecs); + // blar! + Task> ListReferences(IRepository repo, string remoteName); + /// /// Checks out a branch. ///