Skip to content

Commit ef25cd6

Browse files
fix p2p concurrency issues
1 parent 98eb76d commit ef25cd6

File tree

4 files changed

+191
-90
lines changed

4 files changed

+191
-90
lines changed

p2p/kademlia/conn_pool.go

Lines changed: 20 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -33,8 +33,6 @@ func NewConnPool(ctx context.Context) *ConnPool {
3333
conns: map[string]*connectionItem{},
3434
}
3535

36-
pool.StartConnEviction(ctx)
37-
3836
return pool
3937
}
4038

@@ -116,41 +114,48 @@ type connWrapper struct {
116114
mtx sync.Mutex
117115
}
118116

119-
// NewSecureClientConn do client handshake and return a secure connection
117+
// NewSecureClientConn does client handshake and returns a secure, pooled-ready connection.
120118
func NewSecureClientConn(ctx context.Context, tc credentials.TransportCredentials, remoteAddr string) (net.Conn, error) {
121-
// Extract identity if in Lumera format
119+
// Extract identity if in Lumera format (e.g., "<bech32>@ip:port")
122120
remoteIdentity, remoteAddress, err := ltc.ExtractIdentity(remoteAddr, true)
123121
if err != nil {
124122
return nil, fmt.Errorf("invalid address format: %w", err)
125123
}
126124

127-
lumeraTC, ok := tc.(*ltc.LumeraTC)
125+
base, ok := tc.(*ltc.LumeraTC)
128126
if !ok {
129127
return nil, fmt.Errorf("invalid credentials type")
130128
}
131129

132-
// Set remote identity in credentials
133-
lumeraTC.SetRemoteIdentity(remoteIdentity)
130+
// Per-connection clone; set remote identity on the clone only.
131+
cloned, ok := base.Clone().(*ltc.LumeraTC)
132+
if !ok {
133+
return nil, fmt.Errorf("failed to clone LumeraTC")
134+
}
135+
cloned.SetRemoteIdentity(remoteIdentity)
134136

135-
// dial the remote address with tcp
136-
var d net.Dialer
137+
// Dial the remote address with a short timeout.
138+
d := net.Dialer{
139+
Timeout: 3 * time.Second,
140+
KeepAlive: 30 * time.Second,
141+
}
137142
rawConn, err := d.DialContext(ctx, "tcp", remoteAddress)
138-
139143
if err != nil {
140144
return nil, errors.Errorf("dial %q: %w", remoteAddress, err)
141145
}
142146

143-
// set the deadline for read and write
144-
rawConn.SetDeadline(time.Now().UTC().Add(defaultConnDeadline))
147+
// Clear any global deadline; per-RPC deadlines are set in Network.Call.
148+
_ = rawConn.SetDeadline(time.Time{})
145149

146-
conn, _, err := tc.ClientHandshake(ctx, "", rawConn)
150+
// TLS/ALTS-ish client handshake using the per-connection cloned creds.
151+
secureConn, _, err := cloned.ClientHandshake(ctx, "", rawConn)
147152
if err != nil {
148-
rawConn.Close()
153+
_ = rawConn.Close()
149154
return nil, errors.Errorf("client secure establish %q: %w", remoteAddress, err)
150155
}
151156

152157
return &connWrapper{
153-
secureConn: conn,
158+
secureConn: secureConn,
154159
rawConn: rawConn,
155160
}, nil
156161
}
@@ -224,31 +229,3 @@ func (conn *connWrapper) SetWriteDeadline(t time.Time) error {
224229
defer conn.mtx.Unlock()
225230
return conn.secureConn.SetWriteDeadline(t)
226231
}
227-
228-
// StartConnEviction starts a goroutine that periodically evicts idle connections.
229-
func (pool *ConnPool) StartConnEviction(ctx context.Context) {
230-
go func() {
231-
ticker := time.NewTicker(time.Minute) // adjust as necessary
232-
defer ticker.Stop()
233-
234-
for {
235-
select {
236-
case <-ticker.C:
237-
pool.mtx.Lock()
238-
239-
for addr, item := range pool.conns {
240-
if time.Since(item.lastAccess) > defaultConnDeadline {
241-
_ = item.conn.Close()
242-
delete(pool.conns, addr)
243-
}
244-
}
245-
246-
pool.mtx.Unlock()
247-
248-
case <-ctx.Done():
249-
// Stop the goroutine when the context is cancelled
250-
return
251-
}
252-
}
253-
}()
254-
}

p2p/kademlia/dht.go

Lines changed: 50 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -8,6 +8,7 @@ import (
88
"math"
99
"net"
1010
"net/url"
11+
"strings"
1112
"sync"
1213
"sync/atomic"
1314
"time"
@@ -399,6 +400,12 @@ func (s *DHT) Stats(ctx context.Context) (map[string]interface{}, error) {
399400
func (s *DHT) newMessage(messageType int, receiver *Node, data interface{}) *Message {
400401
supernodeAddr := s.getCachedSupernodeAddress()
401402
hostIP := parseSupernodeAddress(supernodeAddr)
403+
404+
// If fallback produced an invalid address (e.g., 0.0.0.0), do not advertise it.
405+
if ip := net.ParseIP(hostIP); ip == nil || ip.IsUnspecified() || ip.IsLoopback() || ip.IsPrivate() {
406+
hostIP = ""
407+
}
408+
402409
sender := &Node{
403410
IP: hostIP,
404411
ID: s.ht.self.ID,
@@ -1270,6 +1277,13 @@ func (s *DHT) sendStoreData(ctx context.Context, n *Node, request *StoreDataRequ
12701277

12711278
// add a node into the appropriate k bucket, return the removed node if it's full
12721279
func (s *DHT) addNode(ctx context.Context, node *Node) *Node {
1280+
if node.IP == "" || node.IP == "0.0.0.0" || node.IP == "127.0.0.1" {
1281+
logtrace.Info(ctx, "Trying to add invalid node", logtrace.Fields{
1282+
logtrace.FieldModule: "p2p",
1283+
})
1284+
return nil
1285+
}
1286+
12731287
// ensure this is not itself address
12741288
if bytes.Equal(node.ID, s.ht.self.ID) {
12751289
logtrace.Info(ctx, "Trying to add itself", logtrace.Fields{
@@ -1524,6 +1538,26 @@ func (s *DHT) addKnownNodes(ctx context.Context, nodes []*Node, knownNodes map[s
15241538
if _, ok := knownNodes[string(node.ID)]; ok {
15251539
continue
15261540
}
1541+
1542+
// Reject bind/local/link-local/private/bogus addresses early
1543+
if ip := net.ParseIP(node.IP); ip != nil {
1544+
if ip.IsUnspecified() || ip.IsLoopback() || ip.IsLinkLocalUnicast() || ip.IsLinkLocalMulticast() {
1545+
s.ignorelist.IncrementCount(node)
1546+
continue
1547+
}
1548+
// If this overlay is public, also reject RFC1918 private ranges (note: IsPrivate does not cover CGNAT 100.64.0.0/10):
1549+
if ip.IsPrivate() {
1550+
s.ignorelist.IncrementCount(node)
1551+
continue
1552+
}
1553+
} else {
1554+
// Hostname: basic sanity check only (must contain at least one dot; this is not full FQDN validation)
1555+
if !strings.Contains(node.IP, ".") {
1556+
s.ignorelist.IncrementCount(node)
1557+
continue
1558+
}
1559+
}
1560+
15271561
node.SetHashedID()
15281562
knownNodes[string(node.ID)] = node
15291563

@@ -1717,7 +1751,14 @@ func (s *DHT) IterateBatchStore(ctx context.Context, values [][]byte, typ int, i
17171751

17181752
func (s *DHT) batchStoreNetwork(ctx context.Context, values [][]byte, nodes map[string]*Node, storageMap map[string][]int, typ int) chan *MessageWithError {
17191753
responses := make(chan *MessageWithError, len(nodes))
1720-
semaphore := make(chan struct{}, 3) // Semaphore to limit concurrency to 3
1754+
maxStore := 16
1755+
if ln := len(nodes); ln < maxStore {
1756+
maxStore = ln
1757+
}
1758+
if maxStore < 1 {
1759+
maxStore = 1
1760+
}
1761+
semaphore := make(chan struct{}, maxStore)
17211762

17221763
var wg sync.WaitGroup
17231764

@@ -1794,7 +1835,14 @@ func (s *DHT) batchFindNode(ctx context.Context, payload [][]byte, nodes map[str
17941835
responses := make(chan *MessageWithError, len(nodes))
17951836
atleastOneContacted := false
17961837
var wg sync.WaitGroup
1797-
semaphore := make(chan struct{}, 20)
1838+
maxInFlight := 64
1839+
if ln := len(nodes); ln < maxInFlight {
1840+
maxInFlight = ln
1841+
}
1842+
if maxInFlight < 1 {
1843+
maxInFlight = 1
1844+
}
1845+
semaphore := make(chan struct{}, maxInFlight)
17981846

17991847
for _, node := range nodes {
18001848
if _, ok := contacted[string(node.ID)]; ok {

0 commit comments

Comments (0)