diff --git a/NBitcoin.Tests/PSBTTests.cs b/NBitcoin.Tests/PSBTTests.cs index bd8f8752d..5ef238ec1 100644 --- a/NBitcoin.Tests/PSBTTests.cs +++ b/NBitcoin.Tests/PSBTTests.cs @@ -8,8 +8,6 @@ using System.Linq; using static NBitcoin.Tests.Comparer; using Xunit.Abstractions; -using System.Net.Http; -using System.Threading.Tasks; namespace NBitcoin.Tests { @@ -41,6 +39,27 @@ public static void ShouldThrowExceptionForInvalidData() } } + [Fact] + [Trait("UnitTest", "UnitTest")] + public static void CanDeriveHDKey() + { + var k = new ExtKey(); + var kwif = k.GetWif(Network.Main); + var pk = k.Neuter(); + var pkwif = pk.GetWif(Network.Main); + static void AssertEqualKey(IHDKey a, IHDKey b) + { + Assert.Equal(a.GetPublicKey(), b.GetPublicKey()); + } + AssertEqualKey(((IHDKey)k).Derive(new KeyPath("1")), ((IHDKey)pk).Derive(new KeyPath("1"))); + AssertEqualKey(((IHDKey)kwif).Derive(new KeyPath("1")), ((IHDKey)pkwif).Derive(new KeyPath("1"))); + AssertEqualKey(((IHDKey)k).Derive(new KeyPath("1")), ((IHDKey)kwif).Derive(new KeyPath("1"))); + Assert.Null(((IHDKey)pk).Derive(new KeyPath("1'"))); + Assert.Null(((IHDKey)pkwif).Derive(new KeyPath("1'"))); + Assert.NotNull(((IHDKey)k).Derive(new KeyPath("1'"))); + Assert.NotNull(((IHDKey)kwif).Derive(new KeyPath("1'"))); + } + [Theory] [InlineData(PSBTVersion.PSBTv0)] [InlineData(PSBTVersion.PSBTv2)] @@ -52,7 +71,7 @@ public static void ShouldCalculateBalanceOfHDKey(PSBTVersion version) var aliceMaster = new ExtKey(); var bobMaster = new ExtKey(); - var alice = aliceMaster.Derive(new KeyPath("1/2/3")); + var alice = aliceMaster.Derive(new KeyPath("1'/2/3")); var bob = bobMaster.Derive(new KeyPath("4/5/6")); var funding = network.CreateTransaction(); @@ -83,18 +102,21 @@ public static void ShouldCalculateBalanceOfHDKey(PSBTVersion version) builder.SendFees(Money.Coins(0.001m)); var psbt = builder.BuildPSBT(false, version); - psbt.AddKeyPath(aliceMaster, new KeyPath("1/2/3")); + psbt.AddKeyPath(aliceMaster, new KeyPath("1'/2/3")); psbt.AddKeyPath(bobMaster, new KeyPath("4/5/6")); var actualBalance = psbt.GetBalance(ScriptPubKeyType.Legacy, aliceMaster); var expectedChange = aliceCoin.Amount - (Money.Coins(0.2m) + Money.Coins(0.1m) + Money.Coins(0.123m)); var expectedBalance = -aliceCoin.Amount + expectedChange; Assert.Equal(expectedBalance, actualBalance); + // We can't derive Alice's balance from the xpub only + Assert.Equal(Money.Zero, psbt.GetBalance(ScriptPubKeyType.Legacy, aliceMaster.Neuter())); actualBalance = psbt.GetBalance(ScriptPubKeyType.Legacy, bobMaster); expectedChange = bobCoin.Amount - (Money.Coins(0.25m) + Money.Coins(0.01m) + Money.Coins(0.001m)) + Money.Coins(0.123m); expectedBalance = -bobCoin.Amount + expectedChange; Assert.Equal(expectedBalance, actualBalance); + Assert.Equal(expectedBalance, psbt.GetBalance(ScriptPubKeyType.Legacy, bobMaster.Neuter())); Assert.False(psbt.TryGetFee(out _)); Assert.False(psbt.IsReadyToSign()); diff --git a/NBitcoin/BIP174/HDKeyCache.cs b/NBitcoin/BIP174/HDKeyCache.cs index 961a4e387..9ce77ae59 100644 --- a/NBitcoin/BIP174/HDKeyCache.cs +++ b/NBitcoin/BIP174/HDKeyCache.cs @@ -1,4 +1,5 @@ -using System; +#nullable enable +using System; using System.Collections.Concurrent; using System.Collections.Generic; using System.Text; @@ -9,7 +10,7 @@ class HDKeyCache : IHDKey { private readonly IHDKey hdKey; private readonly KeyPath _PathFromRoot; - private readonly ConcurrentDictionary derivationCache; + private readonly ConcurrentDictionary derivationCache; public IHDKey Inner { get @@ -21,16 +22,16 @@ internal HDKeyCache(IHDKey masterKey) { this.hdKey = masterKey; _PathFromRoot = new KeyPath(); - derivationCache = new ConcurrentDictionary(); + derivationCache = new ConcurrentDictionary(); } - HDKeyCache(IHDKey hdKey, KeyPath childPath, ConcurrentDictionary cache) + HDKeyCache(IHDKey hdKey, KeyPath childPath, ConcurrentDictionary cache) { this.derivationCache = cache; _PathFromRoot = childPath; this.hdKey = hdKey; } - public IHDKey Derive(KeyPath keyPath) + public IHDKey? Derive(KeyPath keyPath) { if (keyPath == null) throw new ArgumentNullException(nameof(keyPath)); @@ -39,7 +40,11 @@ public IHDKey Derive(KeyPath keyPath) foreach (var index in keyPath.Indexes) { childPath = childPath.Derive(index); - key = derivationCache.GetOrAdd(childPath, _ => key.Derive(new KeyPath(index))); + if (childPath is null) + return null; + key = derivationCache.GetOrAdd(childPath, _ => key?.Derive(new KeyPath(index))); + if (key is null) + return null; } return new HDKeyCache(key, childPath, derivationCache); } @@ -50,17 +55,12 @@ public PubKey GetPublicKey() { return this.hdKey.GetPublicKey(); } - - public bool CanDeriveHardenedPath() - { - return Inner.CanDeriveHardenedPath(); - } } class HDScriptPubKeyCache : IHDScriptPubKey { private readonly IHDScriptPubKey hdKey; private readonly KeyPath _PathFromRoot; - private readonly ConcurrentDictionary derivationCache; + private readonly ConcurrentDictionary derivationCache; public IHDScriptPubKey Inner { get @@ -72,16 +72,16 @@ internal HDScriptPubKeyCache(IHDScriptPubKey masterKey) { this.hdKey = masterKey; _PathFromRoot = new KeyPath(); - derivationCache = new ConcurrentDictionary(); + derivationCache = new ConcurrentDictionary(); } - HDScriptPubKeyCache(IHDScriptPubKey hdKey, KeyPath childPath, ConcurrentDictionary cache) + HDScriptPubKeyCache(IHDScriptPubKey hdKey, KeyPath childPath, ConcurrentDictionary cache) { this.derivationCache = cache; _PathFromRoot = childPath; this.hdKey = hdKey; } - public IHDScriptPubKey Derive(KeyPath keyPath) + public IHDScriptPubKey? Derive(KeyPath keyPath) { if (keyPath == null) throw new ArgumentNullException(nameof(keyPath)); @@ -90,7 +90,9 @@ public IHDScriptPubKey Derive(KeyPath keyPath) foreach (var index in keyPath.Indexes) { childPath = childPath.Derive(index); - key = derivationCache.GetOrAdd(childPath, _ => key.Derive(new KeyPath(index))); + key = derivationCache.GetOrAdd(childPath, _ => key?.Derive(new KeyPath(index))); + if (key is null) + return null; } return new HDScriptPubKeyCache(key, childPath, derivationCache); } @@ -98,10 +100,5 @@ public IHDScriptPubKey Derive(KeyPath keyPath) internal int Cached => derivationCache.Count; public Script ScriptPubKey => Inner.ScriptPubKey; - - public bool CanDeriveHardenedPath() - { - return Inner.CanDeriveHardenedPath(); - } } } diff --git a/NBitcoin/BIP174/PSBTCoin.cs b/NBitcoin/BIP174/PSBTCoin.cs index 6a10df18e..99f69f4bc 100644 --- a/NBitcoin/BIP174/PSBTCoin.cs +++ b/NBitcoin/BIP174/PSBTCoin.cs @@ -125,10 +125,17 @@ bool Match(KeyValuePair hdKey, HDFingerprint? expectedMa if (expectedMasterFp is not null && hdKey.Value.MasterFingerprint != expectedMasterFp.Value) return false; - if (accountHDScriptPubKey is not null && - accountHDScriptPubKey.Derive(addressPath).ScriptPubKey != coinScriptPubKey) - return false; + if (accountHDScriptPubKey is not null) + { + var derivedScript = accountHDScriptPubKey.Derive(addressPath); + if (derivedScript is null) + return false; + if (derivedScript.ScriptPubKey != coinScriptPubKey) + return false; + } var derived = accountKey.Derive(addressPath); + if (derived is null) + return false; return hdKey.Key switch { PubKey pk => derived.GetPublicKey().Equals(pk), @@ -142,20 +149,14 @@ bool Match(KeyValuePair hdKey, HDFingerprint? expectedMa accountHDScriptPubKey = accountHDScriptPubKey?.AsHDKeyCache(); foreach (var hdKey in EnumerateKeyPaths()) { - bool matched = false; - var canDeriveHardenedPath = (accountKey.CanDeriveHardenedPath() && (accountHDScriptPubKey == null || accountHDScriptPubKey.CanDeriveHardenedPath())); // The case where the fingerprint of the hdkey is exactly equal to the accountKey - if (!hdKey.Value.KeyPath.IsHardenedPath || canDeriveHardenedPath) + if (Match(hdKey, accountFingerprint, accountKey, hdKey.Value.KeyPath)) { - if (Match(hdKey, accountFingerprint, accountKey, hdKey.Value.KeyPath)) - { - yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey); - matched = true; - } + yield return CreateHDKeyMatch(accountKey, hdKey.Value.KeyPath, hdKey); } // The typical case where accountkey is based on an hardened derivation (eg. 49'/0'/0') - if (!matched && accountKeyPath?.MasterFingerprint is HDFingerprint mp) + else if (accountKeyPath?.MasterFingerprint is HDFingerprint mp) { var addressPath = hdKey.Value.KeyPath.GetAddressKeyPath(); // The cases where addresses are generated on a non-hardened path below it (eg. 49'/0'/0'/0/1) @@ -164,12 +165,11 @@ bool Match(KeyValuePair hdKey, HDFingerprint? expectedMa if (Match(hdKey, mp, accountKey, addressPath)) { yield return CreateHDKeyMatch(accountKey, addressPath, hdKey); - matched = true; } } // in some cases addresses are generated on a hardened path below the account key (eg. 49'/0'/0'/0'/1') in which case we // need to brute force what the address key path is - else if (canDeriveHardenedPath) // We can only do this if we can derive hardened paths + else { int addressPathSize = 0; var hdKeyIndexes = hdKey.Value.KeyPath.Indexes; @@ -181,7 +181,6 @@ bool Match(KeyValuePair hdKey, HDFingerprint? expectedMa if (Match(hdKey, null, accountKey, addressPath)) { yield return CreateHDKeyMatch(accountKey, addressPath, hdKey); - matched = true; break; } addressPathSize++; diff --git a/NBitcoin/BIP174/PSBTInput.cs b/NBitcoin/BIP174/PSBTInput.cs index 18c4721fa..0b15bae55 100644 --- a/NBitcoin/BIP174/PSBTInput.cs +++ b/NBitcoin/BIP174/PSBTInput.cs @@ -513,7 +513,10 @@ internal void TrySign(IHDScriptPubKey? accountHDScriptPubKey, IHDKey accountKey, var cache = accountKey.AsHDKeyCache(); foreach (var hdk in this.HDKeysFor(accountHDScriptPubKey, cache, accountKeyPath)) { - if (((HDKeyCache)cache.Derive(hdk.AddressKeyPath)).Inner is ISecret k) + var key = cache.Derive(hdk.AddressKeyPath); + if (key is null) + continue; + if (((HDKeyCache)key).Inner is ISecret k) Sign(k.PrivateKey, signingOptions); else throw new ArgumentException(paramName: nameof(accountKey), message: "This should be a private key"); diff --git a/NBitcoin/BIP174/PartiallySignedTransaction.cs b/NBitcoin/BIP174/PartiallySignedTransaction.cs index cbb50dbf7..8dacd35ca 100644 --- a/NBitcoin/BIP174/PartiallySignedTransaction.cs +++ b/NBitcoin/BIP174/PartiallySignedTransaction.cs @@ -1163,6 +1163,8 @@ public PSBT AddKeyPath(IHDKey masterKey, params Tuple[] paths) foreach (var path in paths) { var key = masterKey.Derive(path.Item1); + if (key is null) + continue; AddKeyPath(key.GetPublicKey(), new RootedKeyPath(masterKeyFP, path.Item1), path.Item2); } return this; diff --git a/NBitcoin/BIP32/BitcoinExtKey.cs b/NBitcoin/BIP32/BitcoinExtKey.cs index 6b0ad3b66..f18326486 100644 --- a/NBitcoin/BIP32/BitcoinExtKey.cs +++ b/NBitcoin/BIP32/BitcoinExtKey.cs @@ -1,4 +1,6 @@ -using System; +#nullable enable +using System; +using System.Diagnostics.CodeAnalysis; namespace NBitcoin { @@ -60,7 +62,7 @@ protected override bool IsValid } } - ExtKey _Key; + ExtKey? _Key; /// /// Gets the extended key, converting from the Base58 representation. @@ -113,10 +115,7 @@ public BitcoinExtKey Derive(uint index) return new BitcoinExtKey(ExtKey.Derive(index), Network); } - IHDKey IHDKey.Derive(KeyPath keyPath) - { - return Derive(keyPath); - } + IHDKey? IHDKey.Derive(KeyPath keyPath) => Derive(keyPath); public BitcoinExtKey Derive(KeyPath keyPath) { @@ -138,11 +137,6 @@ public PubKey GetPublicKey() return ExtKey.PrivateKey.PubKey; } - bool IHDKey.CanDeriveHardenedPath() - { - return true; - } - #region ISecret Members /// @@ -158,15 +152,10 @@ public Key PrivateKey #endregion - /// - /// Implicit cast from BitcoinExtKey to ExtKey. - /// - public static implicit operator ExtKey(BitcoinExtKey key) - { - if (key == null) - return null; - return key.ExtKey; - } + +#nullable disable + public static implicit operator ExtKey(BitcoinExtKey key) => key?.ExtKey; +#nullable enable } /// @@ -190,7 +179,7 @@ public BitcoinExtPubKey(ExtPubKey key, Network network) { } - ExtPubKey _PubKey; + ExtPubKey? _PubKey; /// /// Gets the extended public key, converting from the Base58 representation. @@ -246,20 +235,11 @@ public override Script ScriptPubKey } } - /// - /// Implicit cast from BitcoinExtPubKey to ExtPubKey. - /// - public static implicit operator ExtPubKey(BitcoinExtPubKey key) - { - if (key == null) - return null; - return key.ExtPubKey; - } +#nullable disable + public static implicit operator ExtPubKey(BitcoinExtPubKey key) => key?.ExtPubKey; +#nullable enable - IHDKey IHDKey.Derive(KeyPath keyPath) - { - return Derive(keyPath); - } + IHDKey? IHDKey.Derive(KeyPath keyPath) => keyPath?.IsHardenedPath is true ? null : Derive(keyPath!); public BitcoinExtPubKey Derive(uint index) { @@ -277,10 +257,5 @@ public PubKey GetPublicKey() { return ExtPubKey.pubkey; } - - bool IHDKey.CanDeriveHardenedPath() - { - return false; - } } } diff --git a/NBitcoin/BIP32/ExtKey.cs b/NBitcoin/BIP32/ExtKey.cs index 88feb2b03..c4612bcfa 100644 --- a/NBitcoin/BIP32/ExtKey.cs +++ b/NBitcoin/BIP32/ExtKey.cs @@ -33,15 +33,12 @@ public HDKeyScriptPubKey(IHDKey hdKey, ScriptPubKeyType type) Script? _ScriptPubKey; public Script ScriptPubKey => _ScriptPubKey = _ScriptPubKey ?? hdKey.GetPublicKey().GetScriptPubKey(type); - public IHDScriptPubKey Derive(KeyPath keyPath) + public IHDScriptPubKey? Derive(KeyPath keyPath) + => this.hdKey.Derive(keyPath) switch { - return new HDKeyScriptPubKey(this.hdKey.Derive(keyPath), type); - } - - public bool CanDeriveHardenedPath() - { - return this.hdKey.CanDeriveHardenedPath(); - } + { } k => new HDKeyScriptPubKey(k, type), + _ => null + }; } /// @@ -526,21 +523,13 @@ public bool Equals(ExtKey? other) StructuralComparisons.StructuralEqualityComparer.Equals(vchChainCode, other.vchChainCode); } - IHDKey IHDKey.Derive(KeyPath keyPath) - { - return this.Derive(keyPath); - } + IHDKey? IHDKey.Derive(KeyPath keyPath) => this.Derive(keyPath); public PubKey GetPublicKey() { return PrivateKey.PubKey; } - bool IHDKey.CanDeriveHardenedPath() - { - return true; - } - public override bool Equals(object? obj) { if (obj is ExtKey other) diff --git a/NBitcoin/BIP32/ExtPubKey.cs b/NBitcoin/BIP32/ExtPubKey.cs index 919ce4b39..ae8993ca4 100644 --- a/NBitcoin/BIP32/ExtPubKey.cs +++ b/NBitcoin/BIP32/ExtPubKey.cs @@ -258,21 +258,13 @@ public string ToString(Network network) return new BitcoinExtPubKey(this, network).ToString(); } - IHDKey IHDKey.Derive(KeyPath keyPath) - { - return this.Derive(keyPath); - } + IHDKey? IHDKey.Derive(KeyPath keyPath) => keyPath?.IsHardenedPath is true ? null : Derive(keyPath!); public PubKey GetPublicKey() { return this.pubkey; } - bool IHDKey.CanDeriveHardenedPath() - { - return false; - } - public bool Equals(ExtPubKey? other) { if (other is null) diff --git a/NBitcoin/BIP32/IHDKey.cs b/NBitcoin/BIP32/IHDKey.cs index f3e40847c..796dc044f 100644 --- a/NBitcoin/BIP32/IHDKey.cs +++ b/NBitcoin/BIP32/IHDKey.cs @@ -1,4 +1,5 @@ -using System; +#nullable enable +using System; using System.Collections.Generic; using System.Text; @@ -6,8 +7,7 @@ namespace NBitcoin { public interface IHDKey { - IHDKey Derive(KeyPath keyPath); + IHDKey? Derive(KeyPath keyPath); PubKey GetPublicKey(); - bool CanDeriveHardenedPath(); } } diff --git a/NBitcoin/IHDScriptPubKey.cs b/NBitcoin/IHDScriptPubKey.cs index 41bb68608..bce2f9441 100644 --- a/NBitcoin/IHDScriptPubKey.cs +++ b/NBitcoin/IHDScriptPubKey.cs @@ -1,4 +1,5 @@ -using System; +#nullable enable +using System; using System.Collections.Generic; using System.Text; @@ -9,8 +10,7 @@ namespace NBitcoin /// public interface IHDScriptPubKey { - IHDScriptPubKey Derive(KeyPath keyPath); - bool CanDeriveHardenedPath(); + IHDScriptPubKey? Derive(KeyPath keyPath); Script ScriptPubKey { get; } } } diff --git a/NBitcoin/Utils.cs b/NBitcoin/Utils.cs index a5a2d4ccd..4bbb1028e 100644 --- a/NBitcoin/Utils.cs +++ b/NBitcoin/Utils.cs @@ -91,7 +91,7 @@ public static IHDScriptPubKey AsHDScriptPubKey(this IHDKey hdKey, ScriptPubKeyTy return new HDKeyScriptPubKey(hdKey, type); } - public static IHDKey Derive(this IHDKey hdkey, uint index) + public static IHDKey? Derive(this IHDKey hdkey, uint index) { if (hdkey == null) throw new ArgumentNullException(nameof(hdkey)); @@ -104,13 +104,13 @@ public static IHDKey Derive(this IHDKey hdkey, uint index) /// The hdKey to derive /// keyPaths to derive /// An array of keyPaths.Length size with the derived keys - public static IHDKey[] Derive(this IHDKey hdkey, KeyPath[] keyPaths) + public static IHDKey?[] Derive(this IHDKey hdkey, KeyPath[] keyPaths) { if (hdkey == null) throw new ArgumentNullException(nameof(hdkey)); if (keyPaths == null) throw new ArgumentNullException(nameof(keyPaths)); - var result = new IHDKey[keyPaths.Length]; + var result = new IHDKey?[keyPaths.Length]; var cache = (HDKeyCache)hdkey.AsHDKeyCache(); #if !NOPARALLEL Parallel.For(0, keyPaths.Length, i => diff --git a/NBitcoin/WalletPolicies/DeriveParameters.cs b/NBitcoin/WalletPolicies/DeriveParameters.cs index 0c796c97a..2f494db27 100644 --- a/NBitcoin/WalletPolicies/DeriveParameters.cs +++ b/NBitcoin/WalletPolicies/DeriveParameters.cs @@ -10,7 +10,7 @@ namespace NBitcoin.WalletPolicies { - public class DerivationCache : ConcurrentDictionary<(IHDKey, int), Lazy> + public class DerivationCache : ConcurrentDictionary<(IHDKey, int), Lazy> { } public class DeriveParameters diff --git a/NBitcoin/WalletPolicies/Visitors/DeriveVisitor.cs b/NBitcoin/WalletPolicies/Visitors/DeriveVisitor.cs index 00e03a036..b30b3a0af 100644 --- a/NBitcoin/WalletPolicies/Visitors/DeriveVisitor.cs +++ b/NBitcoin/WalletPolicies/Visitors/DeriveVisitor.cs @@ -86,8 +86,8 @@ public override MiniscriptNode Visit(MiniscriptNode node) private (KeyPath KeyPath, Value Pubkey) GetPublicKey(MiniscriptNode.MultipathNode mki, IHDKey k, HDKeyNode? source = null) { var type = mki.GetTypeIndex(Intent); - k = DeriveIntent(k, type); - k = k.Derive((uint)idx); + k = DeriveIntent(k, type) ?? throw new InvalidOperationException($"Unable to derive the key for {type}"); + k = k.Derive((uint)idx) ?? throw new InvalidOperationException($"Unable to derive the key for {type}:{idx}"); var keyType = _nestedMusig ? KeyType.Classic : KeyType; return ( new KeyPath([(uint)type, (uint)idx]), @@ -101,10 +101,10 @@ public override MiniscriptNode Visit(MiniscriptNode node) public bool _nestedMusig = false; - private IHDKey DeriveIntent(IHDKey k, int typeIndex) + private IHDKey? DeriveIntent(IHDKey k, int typeIndex) { // When we derive 0/1/*, "0/1" is common to multiple derivations, so we cache it - return DerivationCache.GetOrAdd((k, typeIndex), new Lazy(() => k.Derive((uint)typeIndex))).Value; + return DerivationCache.GetOrAdd((k, typeIndex), new Lazy(() => k.Derive((uint)typeIndex))).Value; } } #endif