Skip to content
Closed
16 changes: 8 additions & 8 deletions electrum/commands.py
Original file line number Diff line number Diff line change
Expand Up @@ -1405,18 +1405,18 @@ async def add_hold_invoice(
assert inbound_capacity > satoshis(amount or 0), \
f"Not enough inbound capacity [{inbound_capacity} sat] to receive this payment"

wallet.lnworker.add_payment_info_for_hold_invoice(
bfh(payment_hash),
lightning_amount_sat=satoshis(amount) if amount else None,
min_final_cltv_delta=min_final_cltv_expiry_delta,
exp_delay=expiry,
)
info = wallet.lnworker.get_payment_info(bfh(payment_hash))
lnaddr, invoice = wallet.lnworker.get_bolt11_invoice(
payment_hash=bfh(payment_hash),
amount_msat=satoshis(amount) * 1000 if amount else None,
payment_info=info,
message=memo,
expiry=expiry,
min_final_cltv_expiry_delta=min_final_cltv_expiry_delta,
fallback_address=None
)
wallet.lnworker.add_payment_info_for_hold_invoice(
bfh(payment_hash),
satoshis(amount) if amount else None,
)
wallet.lnworker.dont_settle_htlcs[payment_hash] = None
wallet.set_label(payment_hash, memo)
result = {
Expand Down
36 changes: 27 additions & 9 deletions electrum/lnpeer.py
Original file line number Diff line number Diff line change
Expand Up @@ -2498,16 +2498,15 @@ def maybe_fulfill_htlc(
Decide what to do with an HTLC: return preimage if it can be fulfilled, forwarding callback if it can be forwarded.
Return (preimage, (payment_key, callback)) with at most a single element not None.
"""
htlc_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
if not processed_onion.are_we_final:
if not self.lnworker.enable_htlc_forwarding:
return None, None
# use the htlc key if we are forwarding
payment_key = serialize_htlc_key(chan.get_scid_or_local_alias(), htlc.htlc_id)
callback = lambda: self.maybe_forward_htlc(
incoming_chan=chan,
htlc=htlc,
processed_onion=processed_onion)
return None, (payment_key, callback)
return None, (htlc_key, callback) # use the htlc key if we are forwarding

def log_fail_reason(reason: str):
self.logger.info(
Expand Down Expand Up @@ -2544,10 +2543,10 @@ def log_fail_reason(reason: str):
):
return None, None

# TODO check against actual min_final_cltv_expiry_delta from invoice (and give 2-3 blocks of leeway?)
blocks_to_expiry = htlc.cltv_abs - local_height
# note: payment_bundles might get split here, e.g. one payment is "already forwarded" and the other is not.
# In practice, for the swap prepayment use case, this does not matter.
if local_height + MIN_FINAL_CLTV_DELTA_ACCEPTED > htlc.cltv_abs and not already_forwarded:
if blocks_to_expiry < MIN_FINAL_CLTV_DELTA_ACCEPTED and not already_forwarded:
log_fail_reason(f"htlc.cltv_abs is unreasonably close")
raise exc_incorrect_or_unknown_pd

Expand Down Expand Up @@ -2581,10 +2580,6 @@ def log_fail_reason(reason: str):
fw_payment_key=payment_key)
return None, (payment_key, callback)

# TODO don't accept payments twice for same invoice
# note: we don't check invoice expiry (bolt11 'x' field) on the receiver-side.
# - semantics are weird: would make sense for simple-payment-receives, but not
# if htlc is expected to be pending for a while, e.g. for a hold-invoice.
info = self.lnworker.get_payment_info(payment_hash)
if info is None:
log_fail_reason(f"no payment_info found for RHASH {htlc.payment_hash.hex()}")
Expand All @@ -2605,6 +2600,27 @@ def log_fail_reason(reason: str):
log_fail_reason(f"total_msat={total_msat} too different from invoice_msat={invoice_msat}")
raise exc_incorrect_or_unknown_pd

if htlc_key not in self.lnworker.verified_pending_htlcs:
# these checks against the PaymentInfo have to be done only once after
# receiving the htlc
valid_expiry = htlc.timestamp < info.expiration_ts
if not valid_expiry and not already_forwarded:
log_fail_reason(f"invoice already expired: {info.expiration_ts=}")
raise exc_incorrect_or_unknown_pd

valid_cltv = blocks_to_expiry >= info.min_final_cltv_delta
will_settle = preimage is not None and payment_hash.hex() not in self.lnworker.dont_settle_htlcs
if not valid_cltv and not will_settle and not already_forwarded:
# this check only really matters for htlcs which don't get settled right away
log_fail_reason(f"remaining locktime smaller than requested {blocks_to_expiry=} < {info.min_final_cltv_delta=}")
raise exc_incorrect_or_unknown_pd

if info.status == PR_PAID:
log_fail_reason(f"invoice has already been paid")
raise exc_incorrect_or_unknown_pd

self.lnworker.verified_pending_htlcs[htlc_key] = None

hold_invoice_callback = self.lnworker.hold_invoice_callbacks.get(payment_hash)
if hold_invoice_callback and not preimage:
callback = lambda: hold_invoice_callback(payment_hash)
Expand Down Expand Up @@ -3099,6 +3115,8 @@ async def htlc_switch(self):
self.lnworker.maybe_cleanup_mpp(chan.get_scid_or_local_alias(), htlc)
if forwarding_key:
self.lnworker.maybe_cleanup_forwarding(forwarding_key)
htlc_key = serialize_htlc_key(chan.short_channel_id, htlc_id)
self.lnworker.verified_pending_htlcs.pop(htlc_key, None)
done.add(htlc_id)
continue
if onion_packet_hex is None:
Expand Down
8 changes: 5 additions & 3 deletions electrum/lnutil.py
Original file line number Diff line number Diff line change
Expand Up @@ -503,9 +503,11 @@ class LNProtocolWarning(Exception):
# the minimum cltv_expiry accepted for newly received HTLCs
# note: when changing, consider Blockchain.is_tip_stale()
MIN_FINAL_CLTV_DELTA_ACCEPTED = 144
# set it a tiny bit higher for invoices as blocks could get mined
# during forward path of payment
MIN_FINAL_CLTV_DELTA_FOR_INVOICE = MIN_FINAL_CLTV_DELTA_ACCEPTED + 3
MIN_FINAL_CLTV_DELTA_FOR_INVOICE = MIN_FINAL_CLTV_DELTA_ACCEPTED
# Buffer added to the min final cltv delta of all created bolt11 invoices so that the received htlcs
# locktime is still above the limit requested by the creator of the invoice even if some blocks got
# mined during forwarding
MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE = 3

# the deadline for offered HTLCs:
# the deadline after which the channel has to be failed and timed out on-chain
Expand Down
119 changes: 81 additions & 38 deletions electrum/lnworker.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,8 @@
from concurrent import futures
import urllib.parse
import itertools
import dataclasses
from dataclasses import dataclass

import aiohttp
import dns.asyncresolver
Expand Down Expand Up @@ -67,7 +69,8 @@
LnKeyFamily, LOCAL, REMOTE, MIN_FINAL_CLTV_DELTA_FOR_INVOICE, SENT, RECEIVED, HTLCOwner, UpdateAddHtlc, LnFeatures,
ShortChannelID, HtlcLog, NoPathFound, InvalidGossipMsg, FeeBudgetExceeded, ImportedChannelBackupStorage,
OnchainChannelBackupStorage, ln_compare_features, IncompatibleLightningFeatures, PaymentFeeBudget,
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT
NBLOCK_CLTV_DELTA_TOO_FAR_INTO_FUTURE, GossipForwardingMessage, MIN_FUNDING_SAT,
MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE,
)
from .lnonion import decode_onion_error, OnionFailureCode, OnionRoutingFailure, OnionPacket
from .lnmsg import decode_msg
Expand Down Expand Up @@ -106,12 +109,44 @@ class PaymentDirection(IntEnum):
FORWARDING = 3


class PaymentInfo(NamedTuple):
payment_hash: bytes
@stored_in('lightning_payments')
@dataclass(frozen=True)
class PaymentInfo:
"""Information required to handle incoming htlcs for a payment request"""
rhash: str
amount_msat: Optional[int]
# direction is being used with PaymentDirection and lnutil.Direction?
direction: int
status: int
min_final_cltv_delta: int
# expiration can be used to clean-up PaymentInfo and fail htlcs coming in too late
expiry_delay: int
creation_ts: int = dataclasses.field(default_factory=lambda: int(time.time()))

@property
def payment_hash(self) -> bytes:
return bytes.fromhex(self.rhash)

@property
def expiration_ts(self):
return self.creation_ts + self.expiry_delay

def validate(self):
assert isinstance(self.rhash, str), type(self.rhash)
assert self.amount_msat is None or isinstance(self.amount_msat, int)
assert isinstance(self.direction, int)
assert isinstance(self.status, int)
assert isinstance(self.min_final_cltv_delta, int)
assert isinstance(self.expiry_delay, int) and self.expiry_delay > 0
assert isinstance(self.creation_ts, int)

def __post_init__(self):
self.validate()

def to_json(self):
# required because PaymentInfo doesn't inherit StoredObject so it can be declared frozen
self.validate()
return dataclasses.asdict(self)

# Note: these states are persisted in the wallet file.
# Do not modify them without performing a wallet db upgrade
Expand Down Expand Up @@ -869,7 +904,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv):
LNWorker.__init__(self, self.node_keypair, features, config=self.config)
self.lnwatcher = LNWatcher(self)
self.lnrater: LNRater = None
self.payment_info = self.db.get_dict('lightning_payments') # RHASH -> amount, direction, is_paid
self.payment_info = self.db.get_dict('lightning_payments') # type: dict[str, PaymentInfo]
self._preimages = self.db.get_dict('lightning_preimages') # RHASH -> preimage
self._bolt11_cache = {}
# note: this sweep_address is only used as fallback; as it might result in address-reuse
Expand All @@ -896,6 +931,7 @@ def __init__(self, wallet: 'Abstract_Wallet', xprv):
self._paysessions = dict() # type: Dict[bytes, PaySession]
self.sent_htlcs_info = dict() # type: Dict[SentHtlcKey, SentHtlcInfo]
self.received_mpp_htlcs = self.db.get_dict('received_mpp_htlcs') # type: Dict[str, ReceivedMPPStatus] # payment_key -> ReceivedMPPStatus
self.verified_pending_htlcs = self.db.get_dict('verified_pending_htlcs') # type: Dict[str, None] # htlc_key, to keep track of checks that have to be done only once when receiving the htlc

# detect inflight payments
self.inflight_payments = set() # (not persisted) keys of invoices that are in PR_INFLIGHT state
Expand Down Expand Up @@ -1567,7 +1603,7 @@ async def pay_invoice(
raise PaymentFailure(_("A payment was already initiated for this invoice"))
if payment_hash in self.get_payments(status='inflight'):
raise PaymentFailure(_("A previous attempt to pay this invoice did not clear"))
info = PaymentInfo(payment_hash, amount_to_pay, SENT, PR_UNPAID)
info = PaymentInfo(key, amount_to_pay, SENT, PR_UNPAID, min_final_cltv_delta, LN_EXPIRY_NEVER)
self.save_payment_info(info)
self.wallet.set_label(key, lnaddr.get_description())
self.set_invoice_status(key, PR_INFLIGHT)
Expand Down Expand Up @@ -2238,17 +2274,13 @@ def clear_invoices_cache(self):

def get_bolt11_invoice(
self, *,
payment_hash: bytes,
amount_msat: Optional[int],
payment_info: PaymentInfo,
message: str,
expiry: int, # expiration of invoice (in seconds, relative)
fallback_address: Optional[str],
channels: Optional[Sequence[Channel]] = None,
min_final_cltv_expiry_delta: Optional[int] = None,
) -> Tuple[LnAddr, str]:
assert isinstance(payment_hash, bytes), f"expected bytes, but got {type(payment_hash)}"

pair = self._bolt11_cache.get(payment_hash)
amount_msat = payment_info.amount_msat
pair = self._bolt11_cache.get(payment_info.payment_hash)
if pair:
lnaddr, invoice = pair
assert lnaddr.get_amount_msat() == amount_msat
Expand All @@ -2265,27 +2297,24 @@ def get_bolt11_invoice(
if needs_jit:
# jit only works with single htlcs, mpp will cause LSP to open channels for each htlc
invoice_features &= ~ LnFeatures.BASIC_MPP_OPT & ~ LnFeatures.BASIC_MPP_REQ
payment_secret = self.get_payment_secret(payment_hash)
payment_secret = self.get_payment_secret(payment_info.payment_hash)
amount_btc = amount_msat/Decimal(COIN*1000) if amount_msat else None
if expiry == 0:
expiry = LN_EXPIRY_NEVER
if min_final_cltv_expiry_delta is None:
min_final_cltv_expiry_delta = MIN_FINAL_CLTV_DELTA_FOR_INVOICE
min_final_cltv_delta_requested = payment_info.min_final_cltv_delta + MIN_FINAL_CLTV_DELTA_BUFFER_INVOICE
lnaddr = LnAddr(
paymenthash=payment_hash,
paymenthash=payment_info.payment_hash,
amount=amount_btc,
tags=[
('d', message),
('c', min_final_cltv_expiry_delta),
('x', expiry),
('c', min_final_cltv_delta_requested),
('x', payment_info.expiry_delay),
('9', invoice_features),
('f', fallback_address),
] + routing_hints,
date=timestamp,
payment_secret=payment_secret)
invoice = lnencode(lnaddr, self.node_keypair.privkey)
pair = lnaddr, invoice
self._bolt11_cache[payment_hash] = pair
self._bolt11_cache[payment_info.payment_hash] = pair
return pair

def get_payment_secret(self, payment_hash):
Expand All @@ -2299,10 +2328,17 @@ def _get_payment_key(self, payment_hash: bytes) -> bytes:
payment_secret = self.get_payment_secret(payment_hash)
return payment_hash + payment_secret

def create_payment_info(self, *, amount_msat: Optional[int], write_to_disk=True) -> bytes:
def create_payment_info(
self, *,
amount_msat: Optional[int],
min_final_cltv_delta: Optional[int] = None,
exp_delay: int = LN_EXPIRY_NEVER,
write_to_disk=True
) -> bytes:
payment_preimage = os.urandom(32)
payment_hash = sha256(payment_preimage)
info = PaymentInfo(payment_hash, amount_msat, RECEIVED, PR_UNPAID)
min_final_cltv_delta = min_final_cltv_delta if min_final_cltv_delta else MIN_FINAL_CLTV_DELTA_FOR_INVOICE
info = PaymentInfo(payment_hash.hex(), amount_msat, RECEIVED, PR_UNPAID, min_final_cltv_delta, exp_delay)
self.save_preimage(payment_hash, payment_preimage, write_to_disk=False)
self.save_payment_info(info, write_to_disk=False)
if write_to_disk:
Expand Down Expand Up @@ -2374,14 +2410,17 @@ def get_payment_info(self, payment_hash: bytes) -> Optional[PaymentInfo]:
"""returns None if payment_hash is a payment we are forwarding"""
key = payment_hash.hex()
with self.lock:
if key in self.payment_info:
amount_msat, direction, status = self.payment_info[key]
return PaymentInfo(payment_hash, amount_msat, direction, status)
return None
return self.payment_info.get(key, None)

def add_payment_info_for_hold_invoice(self, payment_hash: bytes, lightning_amount_sat: Optional[int]):
def add_payment_info_for_hold_invoice(
self,
payment_hash: bytes, *,
lightning_amount_sat: Optional[int],
min_final_cltv_delta: int,
exp_delay: int,
):
amount = lightning_amount_sat * 1000 if lightning_amount_sat else None
info = PaymentInfo(payment_hash, amount, RECEIVED, PR_UNPAID)
info = PaymentInfo(payment_hash.hex(), amount, RECEIVED, PR_UNPAID, min_final_cltv_delta, exp_delay)
self.save_payment_info(info, write_to_disk=False)

def register_hold_invoice(self, payment_hash: bytes, cb: Callable[[bytes], Awaitable[None]]):
Expand All @@ -2396,11 +2435,13 @@ def save_payment_info(self, info: PaymentInfo, *, write_to_disk: bool = True) ->
if old_info := self.get_payment_info(payment_hash=info.payment_hash):
if info == old_info:
return # already saved
if info != old_info._replace(status=info.status):
if info.direction == SENT:
# allow saving of newer PaymentInfo if it is a sending attempt
old_info = dataclasses.replace(old_info, creation_ts=info.creation_ts)
if info != dataclasses.replace(old_info, status=info.status):
# differs more than in status. let's fail
raise Exception("payment_hash already in use")
key = info.payment_hash.hex()
self.payment_info[key] = info.amount_msat, info.direction, info.status
raise Exception(f"payment_hash already in use: {info=} != {old_info=}")
self.payment_info[info.rhash] = info
if write_to_disk:
self.wallet.save_db()

Expand Down Expand Up @@ -2577,7 +2618,7 @@ def set_payment_status(self, payment_hash: bytes, status: int) -> None:
if info is None:
# if we are forwarding
return
info = info._replace(status=status)
info = dataclasses.replace(info, status=status)
self.save_payment_info(info)

def is_forwarded_htlc(self, htlc_key) -> Optional[str]:
Expand Down Expand Up @@ -3016,12 +3057,14 @@ async def rebalance_channels(self, chan1: Channel, chan2: Channel, *, amount_msa
raise Exception('Rebalance requires two different channels')
if self.uses_trampoline() and chan1.node_id == chan2.node_id:
raise Exception('Rebalance requires channels from different trampolines')
payment_hash = self.create_payment_info(amount_msat=amount_msat)
lnaddr, invoice = self.get_bolt11_invoice(
payment_hash=payment_hash,
payment_hash = self.create_payment_info(
amount_msat=amount_msat,
exp_delay=3600,
)
info = self.get_payment_info(payment_hash)
lnaddr, invoice = self.get_bolt11_invoice(
payment_info=info,
message='rebalance',
expiry=3600,
fallback_address=None,
channels=[chan2],
)
Expand Down
Loading