Skip to content

offer: make the merkle tree signature public #3892

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
193 changes: 188 additions & 5 deletions lightning/src/offers/merkle.rs
Original file line number Diff line number Diff line change
Expand Up @@ -46,6 +46,59 @@ impl TaggedHash {
Self::from_tlv_stream(tag, tlv_stream)
}

/// Creates a tagged hash with the given parameters, validating the TLV stream.
///
/// This is a low-level function exposed for specific use cases like command-line tools
/// and testing. For production use, prefer higher-level methods like
/// [`Bolt12Invoice::try_from`] which handle validation automatically.
///
/// Returns an error if `bytes` is not a well-formed TLV stream containing at least one TLV record.
///
/// [`Bolt12Invoice::try_from`]: crate::offers::invoice::Bolt12Invoice::try_from
pub fn from_tlv_stream_bytes(tag: &'static str, bytes: &[u8]) -> Result<Self, TlvStreamError> {
// Validate the TLV stream first
if bytes.is_empty() {
return Err(TlvStreamError::EmptyStream);
}

// Try to parse the TLV stream to check validity
let mut cursor = io::Cursor::new(bytes);
let mut has_records = false;

while cursor.position() < bytes.len() as u64 {
// Try to read type
let type_result = <BigSize as Readable>::read(&mut cursor);
if type_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}

// Try to read length
let length_result = <BigSize as Readable>::read(&mut cursor);
if length_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}

let length = length_result.unwrap().0;
let end_position = cursor.position() + length;

// Check if the record extends beyond the buffer
if end_position > bytes.len() as u64 {
return Err(TlvStreamError::InvalidRecord);
}

// Skip the value
cursor.set_position(end_position);
has_records = true;
}

if !has_records {
return Err(TlvStreamError::EmptyStream);
}

// If validation passes, create the tagged hash
Ok(Self::from_valid_tlv_stream_bytes(tag, bytes))
Comment on lines +58 to +99
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The TLV stream validation implementation is missing a critical requirement from the Lightning Network specification: TLV records must appear in ascending order by type.

To ensure full compliance, consider tracking the previous type value during iteration and verifying that each new type is greater than the previous one:

let mut previous_type: Option<u64> = None;

while cursor.position() < bytes.len() as u64 {
    // Try to read type
    let type_result = <BigSize as Readable>::read(&mut cursor);
    if type_result.is_err() {
        return Err(TlvStreamError::InvalidRecord);
    }
    
    let current_type = type_result.unwrap().0;
    
    // Check ascending order
    if let Some(prev) = previous_type {
        if current_type <= prev {
            return Err(TlvStreamError::InvalidOrder);
        }
    }
    
    previous_type = Some(current_type);
    
    // Rest of the validation...
}

This would require adding an InvalidOrder variant to the TlvStreamError enum.

Suggested change
pub fn from_tlv_stream_bytes(tag: &'static str, bytes: &[u8]) -> Result<Self, TlvStreamError> {
// Validate the TLV stream first
if bytes.is_empty() {
return Err(TlvStreamError::EmptyStream);
}
// Try to parse the TLV stream to check validity
let mut cursor = io::Cursor::new(bytes);
let mut has_records = false;
while cursor.position() < bytes.len() as u64 {
// Try to read type
let type_result = <BigSize as Readable>::read(&mut cursor);
if type_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}
// Try to read length
let length_result = <BigSize as Readable>::read(&mut cursor);
if length_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}
let length = length_result.unwrap().0;
let end_position = cursor.position() + length;
// Check if the record extends beyond the buffer
if end_position > bytes.len() as u64 {
return Err(TlvStreamError::InvalidRecord);
}
// Skip the value
cursor.set_position(end_position);
has_records = true;
}
if !has_records {
return Err(TlvStreamError::EmptyStream);
}
// If validation passes, create the tagged hash
Ok(Self::from_valid_tlv_stream_bytes(tag, bytes))
pub fn from_tlv_stream_bytes(tag: &'static str, bytes: &[u8]) -> Result<Self, TlvStreamError> {
// Validate the TLV stream first
if bytes.is_empty() {
return Err(TlvStreamError::EmptyStream);
}
// Try to parse the TLV stream to check validity
let mut cursor = io::Cursor::new(bytes);
let mut has_records = false;
let mut previous_type: Option<u64> = None;
while cursor.position() < bytes.len() as u64 {
// Try to read type
let type_result = <BigSize as Readable>::read(&mut cursor);
if type_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}
let current_type = type_result.unwrap().0;
// Check ascending order
if let Some(prev) = previous_type {
if current_type <= prev {
return Err(TlvStreamError::InvalidOrder);
}
}
previous_type = Some(current_type);
// Try to read length
let length_result = <BigSize as Readable>::read(&mut cursor);
if length_result.is_err() {
return Err(TlvStreamError::InvalidRecord);
}
let length = length_result.unwrap().0;
let end_position = cursor.position() + length;
// Check if the record extends beyond the buffer
if end_position > bytes.len() as u64 {
return Err(TlvStreamError::InvalidRecord);
}
// Skip the value
cursor.set_position(end_position);
has_records = true;
}
if !has_records {
return Err(TlvStreamError::EmptyStream);
}
// If validation passes, create the tagged hash
Ok(Self::from_valid_tlv_stream_bytes(tag, bytes))

Spotted by Diamond

Is this helpful? React 👍 or 👎 to let us know.

}

/// Creates a tagged hash with the given parameters.
///
/// Panics if `tlv_stream` is not a well-formed TLV stream containing at least one TLV record.
Expand Down Expand Up @@ -93,8 +146,17 @@ pub enum SignError {
Verification(secp256k1::Error),
}

/// Error when parsing TLV streams.
#[derive(Debug, PartialEq)]
pub enum TlvStreamError {
/// The TLV stream is empty (contains no records).
EmptyStream,
/// The TLV stream contains an invalid record.
InvalidRecord,
}

/// A function for signing a [`TaggedHash`].
pub(super) trait SignFn<T: AsRef<TaggedHash>> {
pub trait SignFn<T: AsRef<TaggedHash>> {
/// Signs a [`TaggedHash`] computed over the merkle root of `message`'s TLV stream.
fn sign(&self, message: &T) -> Result<Signature, ()>;
}
Expand All @@ -111,15 +173,17 @@ where
/// Signs a [`TaggedHash`] computed over the merkle root of `message`'s TLV stream, checking if it
/// can be verified with the supplied `pubkey`.
///
/// This is a low-level function exposed for specific use cases like command-line tools
/// and testing. For production use, prefer higher-level methods on invoice types that handle
/// signing automatically.
///
/// Since `message` is any type that implements [`AsRef<TaggedHash>`], `sign` may be a closure that
/// takes a message such as [`Bolt12Invoice`] or [`InvoiceRequest`]. This allows further message
/// verification before signing its [`TaggedHash`].
///
/// [`Bolt12Invoice`]: crate::offers::invoice::Bolt12Invoice
/// [`InvoiceRequest`]: crate::offers::invoice_request::InvoiceRequest
pub(super) fn sign_message<F, T>(
f: F, message: &T, pubkey: PublicKey,
) -> Result<Signature, SignError>
pub fn sign_message<F, T>(f: F, message: &T, pubkey: PublicKey) -> Result<Signature, SignError>
where
F: SignFn<T>,
T: AsRef<TaggedHash>,
Expand All @@ -136,7 +200,13 @@ where

/// Verifies the signature with a pubkey over the given message using a tagged hash as the message
/// digest.
pub(super) fn verify_signature(
///
/// This is a low-level function exposed for specific use cases like command-line tools
/// and testing. For production use, prefer higher-level methods like
/// [`Bolt12Invoice::try_from`] which handle signature verification automatically.
///
/// [`Bolt12Invoice::try_from`]: crate::offers::invoice::Bolt12Invoice::try_from
pub fn verify_signature(
signature: &Signature, message: &TaggedHash, pubkey: PublicKey,
) -> Result<(), secp256k1::Error> {
let digest = message.as_digest();
Expand Down Expand Up @@ -481,6 +551,119 @@ mod tests {
assert_eq!(tlv_stream, invoice_request.bytes);
}

#[test]
fn validates_tlv_stream_bytes() {
// Test with valid TLV stream
const VALID_HEX: &'static str = "010203e8";
let valid_bytes = <Vec<u8>>::from_hex(VALID_HEX).unwrap();
let result = super::TaggedHash::from_tlv_stream_bytes("test", &valid_bytes);
assert!(result.is_ok());

// Test with empty stream
let empty_bytes = Vec::new();
let result = super::TaggedHash::from_tlv_stream_bytes("test", &empty_bytes);
assert_eq!(result, Err(super::TlvStreamError::EmptyStream));

// Test with invalid TLV stream (truncated)
let invalid_bytes = vec![0x01, 0x02]; // Type and length but no value
let result = super::TaggedHash::from_tlv_stream_bytes("test", &invalid_bytes);
assert_eq!(result, Err(super::TlvStreamError::InvalidRecord));
}

#[test]
fn consistent_results_between_validating_and_non_validating_functions() {
// Test vectors from BOLT 12
let test_vectors = vec![
"010203e8",
"010203e802080000010000020003",
"010203e802080000010000020003", // Using same as above for simplicity
];

for hex_data in test_vectors {
let bytes = <Vec<u8>>::from_hex(hex_data).unwrap();
let tag = "test_tag";

// Create tagged hash using the validating function
let validating_result = super::TaggedHash::from_tlv_stream_bytes(tag, &bytes);
assert!(
validating_result.is_ok(),
"Validating function should succeed for valid TLV stream"
);
let validating_hash = validating_result.unwrap();

// Create tagged hash using the non-validating function
let non_validating_hash = super::TaggedHash::from_valid_tlv_stream_bytes(tag, &bytes);

// Both should produce identical results
assert_eq!(
validating_hash.tag(),
non_validating_hash.tag(),
"Tags should be identical"
);
assert_eq!(
validating_hash.merkle_root(),
non_validating_hash.merkle_root(),
"Merkle roots should be identical"
);
assert_eq!(
validating_hash.as_digest(),
non_validating_hash.as_digest(),
"Digests should be identical"
);
assert_eq!(validating_hash, non_validating_hash, "Tagged hashes should be identical");
}
}

#[test]
fn regression_test_with_invoice_request_data() {
// Use real invoice request data to ensure no regression
let expanded_key = ExpandedKey::new([42; 32]);
let nonce = Nonce([0u8; 16]);
let secp_ctx = Secp256k1::new();
let payment_id = PaymentId([1; 32]);

let recipient_pubkey = {
let secret_key = SecretKey::from_slice(&[41; 32]).unwrap();
Keypair::from_secret_key(&secp_ctx, &secret_key).public_key()
};

let invoice_request = OfferBuilder::new(recipient_pubkey)
.amount_msats(100)
.build_unchecked()
.request_invoice(&expanded_key, nonce, &secp_ctx, payment_id)
.unwrap()
.build_and_sign()
.unwrap();

// Extract bytes without signature for testing
let mut bytes_without_signature = Vec::new();
let tlv_stream_without_signatures = TlvStream::new(&invoice_request.bytes)
.filter(|record| !SIGNATURE_TYPES.contains(&record.r#type));
for record in tlv_stream_without_signatures {
record.write(&mut bytes_without_signature).unwrap();
}

let tag = "invoice_request";

// Test both functions produce the same result
let validating_result =
super::TaggedHash::from_tlv_stream_bytes(tag, &bytes_without_signature);
assert!(
validating_result.is_ok(),
"Should successfully validate real invoice request data"
);
let validating_hash = validating_result.unwrap();

let non_validating_hash =
super::TaggedHash::from_valid_tlv_stream_bytes(tag, &bytes_without_signature);

// Verify they produce identical results
assert_eq!(
validating_hash, non_validating_hash,
"Both functions should produce identical results for real data"
);
}

impl AsRef<[u8]> for InvoiceRequest {
fn as_ref(&self) -> &[u8] {
&self.bytes
Expand Down
Loading