Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -533,6 +533,56 @@ func setupInterruptionTestWithMLServer(t *testing.T, mlBehavior *mockMLNodeBehav
inferenceUpCmd := broker.NewInferenceUpAllCommand()
err = suite.nodeBroker.QueueMessage(inferenceUpCmd)
require.NoError(t, err)
<-inferenceUpCmd.Response

waitForStableInferenceNode := func() bool {
nodes, nodesErr := suite.nodeBroker.GetNodes()
require.NoError(t, nodesErr)
for _, n := range nodes {
if n.Node.Id == nodeConfig.Id &&
n.State.IntendedStatus == types.HardwareNodeStatus_INFERENCE &&
n.State.CurrentStatus == types.HardwareNodeStatus_INFERENCE &&
n.State.ReconcileInfo == nil {
return true
}
}
return false
}

nodeReady := false
deadline := time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if waitForStableInferenceNode() {
nodeReady = true
break
}
time.Sleep(10 * time.Millisecond)
}
if !nodeReady {
setStatusCommand := broker.NewSetNodesActualStatusCommand(
[]broker.StatusUpdate{
{
NodeId: nodeConfig.Id,
PrevStatus: types.HardwareNodeStatus_UNKNOWN,
NewStatus: types.HardwareNodeStatus_INFERENCE,
Timestamp: time.Now(),
},
},
)
err = suite.nodeBroker.QueueMessage(setStatusCommand)
require.NoError(t, err)
<-setStatusCommand.Response

deadline = time.Now().Add(2 * time.Second)
for time.Now().Before(deadline) {
if waitForStableInferenceNode() {
nodeReady = true
break
}
time.Sleep(10 * time.Millisecond)
}
}
require.True(t, nodeReady, "node did not reach stable INFERENCE status before test start")

// 7. Create the public server
payloadStorage := newMockPayloadStorage()
Expand Down
146 changes: 108 additions & 38 deletions decentralized-api/internal/validation/inference_validation.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ import (
"sort"
"strconv"
"sync"
"sync/atomic"
"time"

"github.com/cosmos/cosmos-sdk/types/query"
Expand All @@ -37,11 +38,27 @@ import (
// and the inference is post-upgrade (no on-chain fallback available).
var ErrPayloadUnavailable = errors.New("payload unavailable after all retries")

// maxConcurrentValidations caps validation replay work across every production
// entrypoint in this validator process. Each validation can hold payloads,
// retry state, an HTTP connection, and a broker model lock for minutes.
const maxConcurrentValidations = 10

// validationHTTPClient is used for replay calls into the validator's ML node.
// http.Post uses http.DefaultClient with no timeout, which can pin a validation
// goroutine and its broker lock indefinitely when the ML node stalls.
var validationHTTPClient = &http.Client{Timeout: 5 * time.Minute}

type InferenceValidator struct {
recorder cosmosclient.CosmosMessageClient
nodeBroker *broker.Broker
configManager *apiconfig.ConfigManager
phaseTracker *chainphase.ChainPhaseTracker
recorder cosmosclient.CosmosMessageClient
nodeBroker *broker.Broker
configManager *apiconfig.ConfigManager
phaseTracker *chainphase.ChainPhaseTracker
validationSlotsOnce sync.Once
validationSlots chan struct{}

// recoveryRunning prevents the new-block dispatcher, startup recovery, and
// admin recovery handler from running recovery validation execution together.
recoveryRunning atomic.Bool
}

func NewInferenceValidator(
Expand All @@ -50,13 +67,43 @@ func NewInferenceValidator(
recorder cosmosclient.CosmosMessageClient,
phaseTracker *chainphase.ChainPhaseTracker) *InferenceValidator {
return &InferenceValidator{
nodeBroker: nodeBroker,
configManager: configManager,
recorder: recorder,
phaseTracker: phaseTracker,
nodeBroker: nodeBroker,
configManager: configManager,
recorder: recorder,
phaseTracker: phaseTracker,
validationSlots: make(chan struct{}, maxConcurrentValidations),
}
}

func (s *InferenceValidator) validationLimiter() chan struct{} {
s.validationSlotsOnce.Do(func() {
if s.validationSlots == nil {
s.validationSlots = make(chan struct{}, maxConcurrentValidations)
}
})
return s.validationSlots
}

func (s *InferenceValidator) acquireValidationSlot() func() {
ch := s.validationLimiter()
ch <- struct{}{}
return func() {
<-ch
}
}

// startValidationWithSlot runs validation in its own goroutine and acquires a
// shared validation slot inside that goroutine. Slot acquisition is intentionally
// done in the background so callers (event-handler workers) are never blocked
// waiting for an in-flight validation to release a slot.
func (s *InferenceValidator) startValidationWithSlot(inf types.Inference, recorder cosmosclient.InferenceCosmosClient, revalidation bool) {
go func() {
release := s.acquireValidationSlot()
defer release()
s.validateInferenceAndSendValMessage(inf, recorder, revalidation)
}()
}

func (s *InferenceValidator) VerifyInvalidation(events map[string][]string, recorder cosmosclient.InferenceCosmosClient) {
inferenceIds, ok := events["inference_validation.inference_id"]
if !ok || len(inferenceIds) == 0 {
Expand All @@ -77,10 +124,7 @@ func (s *InferenceValidator) VerifyInvalidation(events map[string][]string, reco
}

logInferencesToValidate([]string{inferenceId})
go func() {
s.validateInferenceAndSendValMessage(r.Inference, recorder, true)
}()

s.startValidationWithSlot(r.Inference, recorder, true)
}

// shouldValidateInference determines if the current participant should validate a specific inference
Expand Down Expand Up @@ -392,9 +436,15 @@ func (s *InferenceValidator) DetectMissedValidations(epochIndex uint64, seed int
}

// ExecuteRecoveryValidations executes validation for a list of missed inferences
// This function uses the inference data already obtained and executes validations in parallel goroutines
// It waits for all validations to complete before returning
// while sharing the process-wide validation replay cap with live sampled
// validation and revalidation work.
func (s *InferenceValidator) ExecuteRecoveryValidations(missedInferences []types.Inference) (int, error) {
if !s.recoveryRunning.CompareAndSwap(false, true) {
logging.Warn("Skipping recovery: another recovery execution is already running", types.ValidationRecovery)
return 0, nil
}
defer s.recoveryRunning.Store(false)

// TODO: allow to send validation for previous epoch and then rollback changes
// Chain requires validator to be active in CURRENT epoch
if !s.isActiveInCurrentEpoch() {
Expand Down Expand Up @@ -426,33 +476,37 @@ func (s *InferenceValidator) ExecuteRecoveryValidations(missedInferences []types
return 0, nil
}

logging.Info("Starting recovery validation execution", types.ValidationRecovery, "missedValidations", len(missedInferencesToValidate))
concreteRecorder, ok := s.recorder.(*cosmosclient.InferenceCosmosClient)
if !ok {
return 0, fmt.Errorf("recovery validation requires *InferenceCosmosClient recorder, got %T", s.recorder)
}

logging.Info("Starting recovery validation execution", types.ValidationRecovery,
"missedValidations", len(missedInferencesToValidate),
"maxConcurrentValidations", maxConcurrentValidations)

var wg sync.WaitGroup

// Execute recovery validations in parallel goroutines with WaitGroup synchronization
for _, inf := range missedInferencesToValidate {
release := s.acquireValidationSlot()
wg.Add(1)
go func(inference types.Inference) {
go func(inference types.Inference, release func()) {
defer wg.Done()
defer release()

logging.Info("Executing recovery validation", types.ValidationRecovery, "inferenceId", inference.InferenceId)

// Use existing validation infrastructure
// The validateInferenceAndSendValMessage function handles all validation logic, node locking, and message sending
// Cast the interface back to concrete type (safe since it's always *InferenceCosmosClient)
concreteRecorder := s.recorder.(*cosmosclient.InferenceCosmosClient)
s.validateInferenceAndSendValMessage(inference, *concreteRecorder, false)

logging.Info("Recovery validation completed", types.ValidationRecovery, "inferenceId", inference.InferenceId)
}(inf)
}(inf, release)
}

// Wait for all recovery validations to complete
logging.Info("Waiting for all recovery validations to complete", types.ValidationRecovery, "count", len(missedInferences))
logging.Info("Waiting for all recovery validations to complete", types.ValidationRecovery, "count", len(missedInferencesToValidate))
wg.Wait()

logging.Info("All recovery validations completed", types.ValidationRecovery, "count", len(missedInferences))
logging.Info("All recovery validations completed", types.ValidationRecovery, "count", len(missedInferencesToValidate))
return len(missedInferencesToValidate), nil
}

Expand Down Expand Up @@ -532,16 +586,29 @@ func (s *InferenceValidator) SampleInferenceToValidate(ids []string, transaction
}

logInferencesToValidate(toValidateIds)
for _, inf := range toValidateIds {
go func() {
response, err := queryClient.Inference(transactionRecorder.GetContext(), &types.QueryGetInferenceRequest{Index: inf})
if err != nil {
logging.Error("Failed to get inference by id", types.Validation, "id", response, "error", err)
return
}
s.validateInferenceAndSendValMessage(response.Inference, transactionRecorder, false)
}()
if len(toValidateIds) == 0 {
return
}

// SampleInferenceToValidate is called from the event-handler worker pool
// and must remain fire-and-forget so chain event processing is never
// blocked. The dispatcher goroutine below is what may block waiting for
// validation slots; the caller returns immediately.
go func() {
for _, infID := range toValidateIds {
release := s.acquireValidationSlot()
go func(inferenceID string, release func()) {
defer release()

response, err := queryClient.Inference(transactionRecorder.GetContext(), &types.QueryGetInferenceRequest{Index: inferenceID})
if err != nil {
logging.Error("Failed to get inference by id", types.Validation, "id", inferenceID, "error", err)
return
}
s.validateInferenceAndSendValMessage(response.Inference, transactionRecorder, false)
}(infID, release)
}
}()
}

func logInferencesToSample(inferences []*types.InferenceValidationDetails) {
Expand Down Expand Up @@ -923,11 +990,14 @@ func (s *InferenceValidator) validateWithPayloads(inference types.Inference, inf
return nil, err
}

resp, err := http.Post(
completionsUrl,
"application/json",
bytes.NewReader(requestBody),
)
req, err := http.NewRequestWithContext(s.recorder.GetContext(), http.MethodPost, completionsUrl, bytes.NewReader(requestBody))
if err != nil {
logging.Error("Failed to create ML node validation request", types.Validation, "url", completionsUrl, "error", err)
return nil, err
}
req.Header.Set("Content-Type", "application/json")

resp, err := validationHTTPClient.Do(req)
if err != nil {
return nil, err
}
Expand Down
Loading