Skip to content

Commit 0ed2a88

Browse files
imporve download functionality
1 parent 1e7c8e3 commit 0ed2a88

File tree

3 files changed

+102
-77
lines changed

3 files changed

+102
-77
lines changed

p2p/kademlia/dht.go

Lines changed: 59 additions & 30 deletions
Original file line numberDiff line numberDiff line change
@@ -40,9 +40,9 @@ const (
4040
defaultDeleteDataInterval = 11 * time.Hour
4141
delKeysCountThreshold = 10
4242
lowSpaceThreshold = 50 // GB
43-
batchStoreSize = 2500
43+
batchRetrieveSize = 1000
4444
storeSameSymbolsBatchConcurrency = 3
45-
storeSymbolsBatchConcurrency = 3.0
45+
fetchSymbolsBatchConcurrency = 6
4646
minimumDataStoreSuccessRate = 75.0
4747

4848
maxIterations = 4
@@ -734,10 +734,10 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
734734
return result, nil
735735
}
736736

737-
batchSize := batchStoreSize
737+
batchSize := batchRetrieveSize
738738
var networkFound int32
739739
totalBatches := int(math.Ceil(float64(required) / float64(batchSize)))
740-
parallelBatches := int(math.Min(float64(totalBatches), storeSymbolsBatchConcurrency))
740+
parallelBatches := int(math.Min(float64(totalBatches), fetchSymbolsBatchConcurrency))
741741

742742
semaphore := make(chan struct{}, parallelBatches)
743743
var wg sync.WaitGroup
@@ -775,7 +775,13 @@ func (s *DHT) BatchRetrieve(ctx context.Context, keys []string, required int32,
775775
wg.Wait()
776776

777777
netFound := int(atomic.LoadInt32(&networkFound))
778-
s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Duration(time.Since(start).Milliseconds())) // NEW
778+
totalFound := int(foundLocalCount) + netFound
779+
780+
s.metrics.RecordBatchRetrieve(len(keys), int(required), int(foundLocalCount), netFound, time.Since(start))
781+
782+
if totalFound < int(required) {
783+
return result, errors.Errorf("insufficient symbols: required=%d, found=%d", required, totalFound)
784+
}
779785

780786
return result, nil
781787
}
@@ -800,7 +806,7 @@ func (s *DHT) processBatch(
800806
defer wg.Done()
801807
defer func() { <-semaphore }()
802808

803-
for i := 0; i < maxIterations; i++ {
809+
for i := 0; i < 1; i++ {
804810
select {
805811
case <-ctx.Done():
806812
return
@@ -822,9 +828,18 @@ func (s *DHT) processBatch(
822828
}
823829
}
824830

831+
knownMu.Lock()
832+
nodesSnap := make(map[string]*Node, len(knownNodes))
833+
for id, n := range knownNodes {
834+
nodesSnap[id] = n
835+
}
836+
knownMu.Unlock()
837+
825838
foundCount, newClosestContacts, batchErr := s.iterateBatchGetValues(
826-
ctx, knownNodes, batchKeys, batchHexKeys, fetchMap, resMap, required, foundLocalCount+atomic.LoadInt32(networkFound),
839+
ctx, nodesSnap, batchKeys, batchHexKeys, fetchMap, resMap,
840+
required, foundLocalCount+atomic.LoadInt32(networkFound),
827841
)
842+
828843
if batchErr != nil {
829844
logtrace.Error(ctx, "Iterate batch get values failed", logtrace.Fields{
830845
logtrace.FieldModule: "dht", "txid": txID, logtrace.FieldError: batchErr.Error(),
@@ -872,19 +887,36 @@ func (s *DHT) processBatch(
872887
}
873888
}
874889

875-
func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node, keys []string, hexKeys []string, fetchMap map[string][]int,
876-
resMap *sync.Map, req, alreadyFound int32) (int, map[string]*NodeList, error) {
877-
semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency) // Limit concurrency to 1
890+
func (s *DHT) iterateBatchGetValues(
891+
ctx context.Context,
892+
nodes map[string]*Node,
893+
keys []string,
894+
hexKeys []string,
895+
fetchMap map[string][]int,
896+
resMap *sync.Map,
897+
req, alreadyFound int32,
898+
) (int, map[string]*NodeList, error) {
899+
900+
semaphore := make(chan struct{}, storeSameSymbolsBatchConcurrency)
878901
closestContacts := make(map[string]*NodeList)
879902
var wg sync.WaitGroup
880903
contactsMap := make(map[string]map[string][]*Node)
881904
var firstErr error
882-
var mu sync.Mutex // To protect the firstErr
905+
var mu sync.Mutex
883906
foundCount := int32(0)
884907

885-
gctx, cancel := context.WithCancel(ctx) // Create a cancellable context
908+
gctx, cancel := context.WithCancel(ctx)
886909
defer cancel()
887-
for nodeID, node := range nodes {
910+
911+
// ✅ Iterate ONLY nodes that actually have work according to fetchMap
912+
for nodeID, idxs := range fetchMap {
913+
if len(idxs) == 0 {
914+
continue
915+
}
916+
node, ok := nodes[nodeID]
917+
if !ok {
918+
continue
919+
}
888920
if s.ignorelist.Banned(node) {
889921
logtrace.Info(ctx, "Ignore banned node in iterate batch get values", logtrace.Fields{
890922
logtrace.FieldModule: "dht",
@@ -894,8 +926,9 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
894926
}
895927

896928
contactsMap[nodeID] = make(map[string][]*Node)
929+
897930
wg.Add(1)
898-
go func(node *Node, nodeID string) {
931+
go func(node *Node, nodeID string, indices []int) {
899932
defer wg.Done()
900933

901934
select {
@@ -907,17 +940,15 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
907940
defer func() { <-semaphore }()
908941
}
909942

910-
indices := fetchMap[nodeID]
911-
requestKeys := make(map[string]KeyValWithClosest)
943+
// Build requestKeys from the provided indices only
944+
requestKeys := make(map[string]KeyValWithClosest, len(indices))
912945
for _, idx := range indices {
913-
if idx < len(hexKeys) {
914-
_, loaded := resMap.Load(hexKeys[idx]) // check if key is already there in resMap
915-
if !loaded {
946+
if idx >= 0 && idx < len(hexKeys) {
947+
if _, loaded := resMap.Load(hexKeys[idx]); !loaded {
916948
requestKeys[hexKeys[idx]] = KeyValWithClosest{}
917949
}
918950
}
919951
}
920-
921952
if len(requestKeys) == 0 {
922953
return
923954
}
@@ -932,21 +963,20 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
932963
return
933964
}
934965

966+
// Merge values or closest contacts
935967
for k, v := range decompressedData {
936968
if len(v.Value) > 0 {
937-
_, loaded := resMap.LoadOrStore(k, v.Value)
938-
if !loaded {
939-
atomic.AddInt32(&foundCount, 1)
940-
if atomic.LoadInt32(&foundCount) >= int32(req-alreadyFound) {
941-
cancel() // Cancel context to stop other goroutines
969+
if _, loaded := resMap.LoadOrStore(k, v.Value); !loaded {
970+
if atomic.AddInt32(&foundCount, 1) >= int32(req-alreadyFound) {
971+
cancel()
942972
return
943973
}
944974
}
945975
} else {
946976
contactsMap[nodeID][k] = v.Closest
947977
}
948978
}
949-
}(node, nodeID)
979+
}(node, nodeID, idxs)
950980
}
951981

952982
wg.Wait()
@@ -964,27 +994,26 @@ func (s *DHT) iterateBatchGetValues(ctx context.Context, nodes map[string]*Node,
964994
})
965995
}
966996

997+
// Build closestContacts from contactsMap (same as before)
967998
for _, closestNodes := range contactsMap {
968999
for key, nodes := range closestNodes {
9691000
comparator, err := hex.DecodeString(key)
9701001
if err != nil {
971-
logtrace.Error(ctx, "Failed to decode hex key in closestNodes.Range", logtrace.Fields{
1002+
logtrace.Error(ctx, "Failed to decode hex key in closestNodes", logtrace.Fields{
9721003
logtrace.FieldModule: "dht",
9731004
"key": key,
9741005
logtrace.FieldError: err.Error(),
9751006
})
976-
return 0, nil, err
1007+
return int(foundCount), nil, err
9771008
}
9781009
bkey := base58.Encode(comparator)
979-
9801010
if _, ok := closestContacts[bkey]; !ok {
9811011
closestContacts[bkey] = &NodeList{Nodes: nodes, Comparator: comparator}
9821012
} else {
9831013
closestContacts[bkey].AddNodes(nodes)
9841014
}
9851015
}
9861016
}
987-
9881017
for key, nodes := range closestContacts {
9891018
nodes.Sort()
9901019
nodes.TopN(Alpha)

supernode/services/cascade/download.go

Lines changed: 27 additions & 27 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@ import (
1717
)
1818

1919
const (
20-
requiredSymbolPercent = 9
20+
requiredSymbolPercent = 17
2121
)
2222

2323
type DownloadRequest struct {
@@ -36,8 +36,8 @@ func (task *CascadeRegistrationTask) Download(
3636
req *DownloadRequest,
3737
send func(resp *DownloadResponse) error,
3838
) (err error) {
39-
fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
40-
logtrace.Info(ctx, "Cascade download request received", fields)
39+
fields := logtrace.Fields{logtrace.FieldMethod: "Download", logtrace.FieldRequest: req}
40+
logtrace.Info(ctx, "Cascade download request received", fields)
4141

4242
// Ensure task status is finalized regardless of outcome
4343
defer func() {
@@ -54,36 +54,36 @@ func (task *CascadeRegistrationTask) Download(
5454
fields[logtrace.FieldError] = err
5555
return task.wrapErr(ctx, "failed to get action", err, fields)
5656
}
57-
logtrace.Info(ctx, "Action retrieved", fields)
58-
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
57+
logtrace.Info(ctx, "Action retrieved", fields)
58+
task.streamDownloadEvent(SupernodeEventTypeActionRetrieved, "Action retrieved", "", "", send)
5959

6060
if actionDetails.GetAction().State != actiontypes.ActionStateDone {
6161
err = errors.New("action is not in a valid state")
6262
fields[logtrace.FieldError] = "action state is not done yet"
6363
fields[logtrace.FieldActionState] = actionDetails.GetAction().State
6464
return task.wrapErr(ctx, "action not found", err, fields)
6565
}
66-
logtrace.Info(ctx, "Action state validated", fields)
66+
logtrace.Info(ctx, "Action state validated", fields)
6767

6868
metadata, err := task.decodeCascadeMetadata(ctx, actionDetails.GetAction().Metadata, fields)
6969
if err != nil {
7070
fields[logtrace.FieldError] = err.Error()
7171
return task.wrapErr(ctx, "error decoding cascade metadata", err, fields)
7272
}
73-
logtrace.Info(ctx, "Cascade metadata decoded", fields)
74-
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
73+
logtrace.Info(ctx, "Cascade metadata decoded", fields)
74+
task.streamDownloadEvent(SupernodeEventTypeMetadataDecoded, "Cascade metadata decoded", "", "", send)
7575

76-
// Notify: network retrieval phase begins
77-
task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
76+
// Notify: network retrieval phase begins
77+
task.streamDownloadEvent(SupernodeEventTypeNetworkRetrieveStarted, "Network retrieval started", "", "", send)
7878

79-
filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
80-
if err != nil {
81-
fields[logtrace.FieldError] = err.Error()
82-
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
83-
}
84-
logtrace.Info(ctx, "File reconstructed and hash verified", fields)
85-
// Notify: decode completed, file ready on disk
86-
task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
79+
filePath, tmpDir, err := task.downloadArtifacts(ctx, actionDetails.GetAction().ActionID, metadata, fields)
80+
if err != nil {
81+
fields[logtrace.FieldError] = err.Error()
82+
return task.wrapErr(ctx, "failed to download artifacts", err, fields)
83+
}
84+
logtrace.Info(ctx, "File reconstructed and hash verified", fields)
85+
// Notify: decode completed, file ready on disk
86+
task.streamDownloadEvent(SupernodeEventTypeDecodeCompleted, "Decode completed", filePath, tmpDir, send)
8787

8888
return nil
8989
}
@@ -147,15 +147,15 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
147147

148148
fields["totalSymbols"] = totalSymbols
149149
fields["requiredSymbols"] = requiredSymbols
150-
logtrace.Info(ctx, "Symbols to be retrieved", fields)
150+
logtrace.Info(ctx, "Symbols to be retrieved", fields)
151151

152-
// Progressive retrieval moved to helper for readability/testing
153-
decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
154-
if err != nil {
155-
fields[logtrace.FieldError] = err.Error()
156-
logtrace.Error(ctx, "failed to decode symbols progressively", fields)
157-
return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
158-
}
152+
// Progressive retrieval moved to helper for readability/testing
153+
decodeInfo, err := task.retrieveAndDecodeProgressively(ctx, layout, actionID, fields)
154+
if err != nil {
155+
fields[logtrace.FieldError] = err.Error()
156+
logtrace.Error(ctx, "failed to decode symbols progressively", fields)
157+
return "", "", fmt.Errorf("decode symbols using RaptorQ: %w", err)
158+
}
159159

160160
fileHash, err := crypto.HashFileIncrementally(decodeInfo.FilePath, 0)
161161
if err != nil {
@@ -175,7 +175,7 @@ func (task *CascadeRegistrationTask) restoreFileFromLayout(
175175
fields[logtrace.FieldError] = err.Error()
176176
return "", decodeInfo.DecodeTmpDir, err
177177
}
178-
logtrace.Info(ctx, "File successfully restored and hash verified", fields)
178+
logtrace.Info(ctx, "File successfully restored and hash verified", fields)
179179

180180
return decodeInfo.FilePath, decodeInfo.DecodeTmpDir, nil
181181
}

supernode/services/cascade/progressive_decode.go

Lines changed: 16 additions & 20 deletions
Original file line numberDiff line numberDiff line change
@@ -10,16 +10,10 @@ import (
1010
)
1111

1212
// retrieveAndDecodeProgressively performs a minimal two-step retrieval for a single-block layout:
13-
// 1) Fetch approximately requiredSymbolPercent of symbols and try decoding.
14-
// 2) If that fails, fetch all available symbols from the block and try again.
15-
// This replaces earlier multi-block balancing and multi-threshold escalation.
16-
func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
17-
ctx context.Context,
18-
layout codec.Layout,
19-
actionID string,
20-
fields logtrace.Fields,
21-
) (adaptors.DecodeResponse, error) {
22-
// Ensure base context fields are present for logs
13+
// 1) Send ALL keys with a minimum required count (requiredSymbolPercent).
14+
// 2) If decode fails, escalate by asking for ALL symbols (required = total).
15+
func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(ctx context.Context, layout codec.Layout, actionID string,
16+
fields logtrace.Fields) (adaptors.DecodeResponse, error) {
2317
if fields == nil {
2418
fields = logtrace.Fields{}
2519
}
@@ -29,28 +23,27 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
2923
return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no blocks")
3024
}
3125

32-
// Single-block fast path
26+
// Single-block path
3327
if len(layout.Blocks) == 1 {
3428
blk := layout.Blocks[0]
3529
total := len(blk.Symbols)
3630
if total == 0 {
3731
return adaptors.DecodeResponse{}, fmt.Errorf("empty layout: no symbols")
3832
}
3933

40-
// Step 1: try with requiredSymbolPercent of symbols
34+
// Step 1: send ALL keys, require only reqCount
4135
reqCount := (total*requiredSymbolPercent + 99) / 100
4236
if reqCount < 1 {
4337
reqCount = 1
44-
}
45-
if reqCount > total {
38+
} else if reqCount > total {
4639
reqCount = total
4740
}
4841
fields["targetPercent"] = requiredSymbolPercent
4942
fields["targetCount"] = reqCount
43+
fields["total"] = total
5044
logtrace.Info(ctx, "retrieving initial symbols (single block)", fields)
5145

52-
keys := blk.Symbols[:reqCount]
53-
symbols, err := task.P2PClient.BatchRetrieve(ctx, keys, reqCount, actionID)
46+
symbols, err := task.P2PClient.BatchRetrieve(ctx, blk.Symbols, reqCount, actionID)
5447
if err != nil {
5548
fields[logtrace.FieldError] = err.Error()
5649
logtrace.Error(ctx, "failed to retrieve symbols", fields)
@@ -66,21 +59,24 @@ func (task *CascadeRegistrationTask) retrieveAndDecodeProgressively(
6659
return decodeInfo, nil
6760
}
6861

69-
// Step 2: escalate to all symbols
70-
logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", nil)
71-
symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, total, actionID)
62+
// Step 2: escalate to require ALL symbols
63+
fields["escalating"] = true
64+
fields["requiredCount"] = total
65+
logtrace.Info(ctx, "initial decode failed; retrieving all symbols (single block)", fields)
66+
67+
symbols, err = task.P2PClient.BatchRetrieve(ctx, blk.Symbols, reqCount*2, actionID)
7268
if err != nil {
7369
fields[logtrace.FieldError] = err.Error()
7470
logtrace.Error(ctx, "failed to retrieve all symbols", fields)
7571
return adaptors.DecodeResponse{}, fmt.Errorf("failed to retrieve symbols: %w", err)
7672
}
73+
7774
return task.RQ.Decode(ctx, adaptors.DecodeRequest{
7875
ActionID: actionID,
7976
Symbols: symbols,
8077
Layout: layout,
8178
})
8279
}
8380

84-
// Multi-block layouts are not supported by current policy
8581
return adaptors.DecodeResponse{}, fmt.Errorf("unsupported layout: expected 1 block, found %d", len(layout.Blocks))
8682
}

0 commit comments

Comments
 (0)