Skip to content

Commit 26e8bcd

Browse files
committed
paymentsdb: implement RegisterAttempt for sql backend
1 parent 073d4dc commit 26e8bcd

File tree

1 file changed

+246
-0
lines changed

1 file changed

+246
-0
lines changed

payments/db/sql_store.go

Lines changed: 246 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,10 +6,12 @@ import (
66
"errors"
77
"fmt"
88
"math"
9+
"strconv"
910
"time"
1011

1112
"github.com/lightningnetwork/lnd/lntypes"
1213
"github.com/lightningnetwork/lnd/lnwire"
14+
"github.com/lightningnetwork/lnd/routing/route"
1315
"github.com/lightningnetwork/lnd/sqldb"
1416
"github.com/lightningnetwork/lnd/sqldb/sqlc"
1517
)
@@ -1044,3 +1046,247 @@ func (s *SQLStore) InitPayment(paymentHash lntypes.Hash,
10441046

10451047
return nil
10461048
}
1049+
1050+
// insertRouteHops inserts all route hop data for a given set of hops.
1051+
func (s *SQLStore) insertRouteHops(ctx context.Context, db SQLQueries,
1052+
hops []*route.Hop, attemptID uint64) error {
1053+
1054+
for i, hop := range hops {
1055+
// Insert the basic route hop data and get the generated ID.
1056+
hopID, err := db.InsertRouteHop(ctx, sqlc.InsertRouteHopParams{
1057+
HtlcAttemptIndex: int64(attemptID),
1058+
HopIndex: int32(i),
1059+
PubKey: hop.PubKeyBytes[:],
1060+
Scid: strconv.FormatUint(
1061+
hop.ChannelID, 10,
1062+
),
1063+
OutgoingTimeLock: int32(hop.OutgoingTimeLock),
1064+
AmtToForward: int64(hop.AmtToForward),
1065+
MetaData: hop.Metadata,
1066+
})
1067+
if err != nil {
1068+
return fmt.Errorf("failed to insert route hop: %w", err)
1069+
}
1070+
1071+
// Insert the per-hop custom records.
1072+
if len(hop.CustomRecords) > 0 {
1073+
for key, value := range hop.CustomRecords {
1074+
err = db.InsertPaymentHopCustomRecord(
1075+
ctx,
1076+
sqlc.InsertPaymentHopCustomRecordParams{
1077+
HopID: hopID,
1078+
Key: int64(key),
1079+
Value: value,
1080+
})
1081+
if err != nil {
1082+
return fmt.Errorf("failed to insert "+
1083+
"payment hop custom record: %w",
1084+
err)
1085+
}
1086+
}
1087+
}
1088+
1089+
// Insert MPP data if present.
1090+
if hop.MPP != nil {
1091+
paymentAddr := hop.MPP.PaymentAddr()
1092+
err = db.InsertRouteHopMpp(
1093+
ctx, sqlc.InsertRouteHopMppParams{
1094+
HopID: hopID,
1095+
PaymentAddr: paymentAddr[:],
1096+
TotalMsat: int64(hop.MPP.TotalMsat()),
1097+
})
1098+
if err != nil {
1099+
return fmt.Errorf("failed to insert "+
1100+
"route hop MPP: %w", err)
1101+
}
1102+
}
1103+
1104+
// Insert AMP data if present.
1105+
if hop.AMP != nil {
1106+
rootShare := hop.AMP.RootShare()
1107+
setID := hop.AMP.SetID()
1108+
err = db.InsertRouteHopAmp(
1109+
ctx, sqlc.InsertRouteHopAmpParams{
1110+
HopID: hopID,
1111+
RootShare: rootShare[:],
1112+
SetID: setID[:],
1113+
ChildIndex: int32(hop.AMP.ChildIndex()),
1114+
})
1115+
if err != nil {
1116+
return fmt.Errorf("failed to insert "+
1117+
"route hop AMP: %w", err)
1118+
}
1119+
}
1120+
1121+
// Insert blinded route data if present. Every hop in the
1122+
// blinded path must have an encrypted data record. If the
1123+
// encrypted data is not present, we skip the insertion.
1124+
if hop.EncryptedData == nil {
1125+
continue
1126+
}
1127+
1128+
// The introduction point has a blinding point set.
1129+
var blindingPointBytes []byte
1130+
if hop.BlindingPoint != nil {
1131+
blindingPointBytes = hop.BlindingPoint.
1132+
SerializeCompressed()
1133+
}
1134+
1135+
// The total amount is only set for the final hop in a
1136+
// blinded path.
1137+
totalAmtMsat := sql.NullInt64{}
1138+
if i == len(hops)-1 {
1139+
totalAmtMsat = sql.NullInt64{
1140+
Int64: int64(hop.TotalAmtMsat),
1141+
Valid: true,
1142+
}
1143+
}
1144+
1145+
err = db.InsertRouteHopBlinded(ctx,
1146+
sqlc.InsertRouteHopBlindedParams{
1147+
HopID: hopID,
1148+
EncryptedData: hop.EncryptedData,
1149+
BlindingPoint: blindingPointBytes,
1150+
BlindedPathTotalAmt: totalAmtMsat,
1151+
},
1152+
)
1153+
if err != nil {
1154+
return fmt.Errorf("failed to insert "+
1155+
"route hop blinded: %w", err)
1156+
}
1157+
}
1158+
1159+
return nil
1160+
}
1161+
1162+
// RegisterAttempt atomically records a new HTLC attempt for the specified
1163+
// payment. The attempt includes the attempt ID, session key, route information
1164+
// (hops, timelocks, amounts), and optional data such as MPP/AMP parameters,
1165+
// blinded route data, and custom records.
1166+
//
1167+
// Returns the updated MPPayment with the new attempt appended to the HTLCs
1168+
// slice, and the payment state recalculated. Returns an error if the payment
1169+
// doesn't exist or validation fails.
1170+
//
1171+
// This method is part of the PaymentControl interface, which is embedded in
1172+
// the PaymentWriter interface and ultimately the DB interface. It represents
1173+
// step 2 in the payment lifecycle control flow, called after InitPayment and
1174+
// potentially multiple times for multi-path payments.
1175+
func (s *SQLStore) RegisterAttempt(paymentHash lntypes.Hash,
1176+
attempt *HTLCAttemptInfo) (*MPPayment, error) {
1177+
1178+
ctx := context.TODO()
1179+
1180+
var mpPayment *MPPayment
1181+
1182+
err := s.db.ExecTx(ctx, sqldb.WriteTxOpt(), func(db SQLQueries) error {
1183+
// First Fetch the payment and check if it is registrable.
1184+
existingPayment, err := db.FetchPayment(ctx, paymentHash[:])
1185+
if err != nil {
1186+
return fmt.Errorf("failed to fetch payment: %w", err)
1187+
}
1188+
1189+
// We fetch the complete payment to determine if the payment is
1190+
// registrable.
1191+
//
1192+
// TODO(ziggie): We could improve the query here since only
1193+
// the last hop data is needed here not the complete payment
1194+
// data.
1195+
mpPayment, err = s.fetchPaymentWithCompleteData(
1196+
ctx, db, existingPayment,
1197+
)
1198+
if err != nil {
1199+
return fmt.Errorf("failed to fetch payment with "+
1200+
"complete data: %w", err)
1201+
}
1202+
1203+
if err := mpPayment.Registrable(); err != nil {
1204+
return fmt.Errorf("htlc attempt not registrable: %w",
1205+
err)
1206+
}
1207+
1208+
// Verify the attempt is compatible with the existing payment.
1209+
if err := verifyAttempt(mpPayment, attempt); err != nil {
1210+
return fmt.Errorf("failed to verify attempt: %w", err)
1211+
}
1212+
1213+
// Register the plain HTLC attempt next.
1214+
sessionKey := attempt.SessionKey()
1215+
sessionKeyBytes := sessionKey.Serialize()
1216+
1217+
_, err = db.InsertHtlcAttempt(ctx, sqlc.InsertHtlcAttemptParams{
1218+
PaymentID: existingPayment.Payment.ID,
1219+
AttemptIndex: int64(attempt.AttemptID),
1220+
SessionKey: sessionKeyBytes,
1221+
AttemptTime: attempt.AttemptTime,
1222+
PaymentHash: paymentHash[:],
1223+
FirstHopAmountMsat: int64(
1224+
attempt.Route.FirstHopAmount.Val.Int(),
1225+
),
1226+
RouteTotalTimeLock: int32(attempt.Route.TotalTimeLock),
1227+
RouteTotalAmount: int64(attempt.Route.TotalAmount),
1228+
RouteSourceKey: attempt.Route.SourcePubKey[:],
1229+
})
1230+
if err != nil {
1231+
return fmt.Errorf("failed to insert HTLC "+
1232+
"attempt: %w", err)
1233+
}
1234+
1235+
// Insert the route level first hop custom records.
1236+
attemptFirstHopCustomRecords := attempt.Route.
1237+
FirstHopWireCustomRecords
1238+
1239+
for key, value := range attemptFirstHopCustomRecords {
1240+
//nolint:ll
1241+
err = db.InsertPaymentAttemptFirstHopCustomRecord(
1242+
ctx,
1243+
sqlc.InsertPaymentAttemptFirstHopCustomRecordParams{
1244+
HtlcAttemptIndex: int64(attempt.AttemptID),
1245+
Key: int64(key),
1246+
Value: value,
1247+
},
1248+
)
1249+
if err != nil {
1250+
return fmt.Errorf("failed to insert "+
1251+
"payment attempt first hop custom "+
1252+
"record: %w", err)
1253+
}
1254+
}
1255+
1256+
// Insert the route hops.
1257+
err = s.insertRouteHops(
1258+
ctx, db, attempt.Route.Hops, attempt.AttemptID,
1259+
)
1260+
if err != nil {
1261+
return fmt.Errorf("failed to insert route hops: %w",
1262+
err)
1263+
}
1264+
1265+
// We fetch the HTLC attempts again to recalculate the payment
1266+
// state after the attempt is registered. This also makes sure
1267+
// we have the right data in case multiple attempts are
1268+
// registered concurrently.
1269+
//
1270+
// NOTE: While the caller is responsible for serializing calls
1271+
// to RegisterAttempt per payment hash (see PaymentControl
1272+
// interface), we still refetch here to guarantee we return
1273+
// consistent, up-to-date data that reflects all changes made
1274+
// within this transaction.
1275+
mpPayment, err = s.fetchPaymentWithCompleteData(
1276+
ctx, db, existingPayment,
1277+
)
1278+
if err != nil {
1279+
return fmt.Errorf("failed to fetch payment with "+
1280+
"complete data: %w", err)
1281+
}
1282+
1283+
return nil
1284+
}, func() {
1285+
mpPayment = nil
1286+
})
1287+
if err != nil {
1288+
return nil, fmt.Errorf("failed to register attempt: %w", err)
1289+
}
1290+
1291+
return mpPayment, nil
1292+
}

0 commit comments

Comments
 (0)