diff --git a/epoch/flow_operations.go b/epoch/flow_operations.go index 6711918..f3c7b95 100644 --- a/epoch/flow_operations.go +++ b/epoch/flow_operations.go @@ -2,6 +2,7 @@ package epoch import ( "fmt" + "reflect" "github.com/bytedance/sonic/ast" ) @@ -261,7 +262,7 @@ func (op *ResponseRemoveFieldIfDefault) ApplyToResponse(node *ast.Node) error { } // Compare with default - only remove if they match - if fieldValue == op.Default { + if reflect.DeepEqual(fieldValue, op.Default) { return DeleteNodeField(node, op.Name) } diff --git a/epoch/middleware.go b/epoch/middleware.go index dde9537..39f8f4a 100644 --- a/epoch/middleware.go +++ b/epoch/middleware.go @@ -401,8 +401,14 @@ func (vah *VersionAwareHandler) handleWithMigration(c *gin.Context, requestedVer body: make([]byte, 0), statusCode: 200, } + originalWriter := c.Writer c.Writer = responseCapture + // Ensure the original writer is always restored, even if the handler panics. + // Without this, Gin's recovery middleware would operate on the capture writer, + // causing broken error responses. + defer func() { c.Writer = originalWriter }() + // 3. Call the handler (which expects head version data) vah.handler(c) @@ -510,7 +516,9 @@ func (vah *VersionAwareHandler) migrateRequest( return fmt.Errorf("failed to get raw JSON from migrated request: %w", err) } - c.Request.Body = io.NopCloser(bytes.NewReader([]byte(migratedJSON))) + migratedBytes := []byte(migratedJSON) + c.Request.Body = io.NopCloser(bytes.NewReader(migratedBytes)) + c.Request.ContentLength = int64(len(migratedBytes)) return nil } diff --git a/epoch/version_change.go b/epoch/version_change.go index 819fcc2..a5fa3d0 100644 --- a/epoch/version_change.go +++ b/epoch/version_change.go @@ -4,6 +4,7 @@ import ( "context" "fmt" "reflect" + "sort" "strings" "github.com/bytedance/sonic/ast" @@ -337,8 +338,22 @@ type MigrationChain struct { // NewMigrationChain creates a new migration chain with cycle detection func NewMigrationChain(changes []*VersionChange) (*MigrationChain, error) { + // Sort changes by FromVersion (oldest first) so forward migration can iterate sequentially. + // This is required because MigrateRequest relies on sequential iteration through sorted changes. + sorted := make([]*VersionChange, len(changes)) + copy(sorted, changes) + sort.Slice(sorted, func(i, j int) bool { + // Primary sort: by FromVersion (older first) + cmp := sorted[i].FromVersion().Compare(sorted[j].FromVersion()) + if cmp != 0 { + return cmp < 0 + } + // Secondary sort: by ToVersion (older first) for changes with the same FromVersion + return sorted[i].ToVersion().Compare(sorted[j].ToVersion()) < 0 + }) + mc := &MigrationChain{ - changes: changes, + changes: sorted, } // Detect cycles in the version graph @@ -385,36 +400,25 @@ func (mc *MigrationChain) MigrateRequest(ctx context.Context, requestInfo *Reque return nil } - // Find the starting point in the version chain - start := -1 - for i, change := range mc.changes { - if change.FromVersion().Equal(from) || change.FromVersion().IsOlderThan(from) { - start = i - break + // Apply changes in sequence (changes are sorted by FromVersion in NewMigrationChain). + // A change is included if: + // 1. Its FromVersion >= from (don't apply changes from before our starting version) + // 2. Its ToVersion <= targetVersion (don't apply changes past our target) + for _, change := range mc.changes { + // Skip changes that start before our starting version + if change.FromVersion().IsOlderThan(from) { + continue } - } - if start == -1 { - return fmt.Errorf("no migration path found from version %s to %s (available changes: %d)", - from.String(), to.String(), len(mc.changes)) - } - - // Apply changes in sequence until we reach the target version - for i := start; i < len(mc.changes); i++ { - change := mc.changes[i] - - // Stop if this change would take us past the target + // Stop if this change goes past the target version if change.ToVersion().IsNewerThan(targetVersion) { break } - // Apply this change if it's part of the migration path - if (change.ToVersion().Equal(targetVersion) || change.ToVersion().IsOlderThan(targetVersion)) && - (change.FromVersion().IsOlderThan(targetVersion) || change.FromVersion().Equal(targetVersion)) { - if err := change.MigrateRequest(ctx, requestInfo); err != nil { - return fmt.Errorf("migration failed at %s->%s: %w", - change.FromVersion().String(), change.ToVersion().String(), err) - } + // Apply this change - it's in the migration path [from, targetVersion] + if err := change.MigrateRequest(ctx, requestInfo); err != nil { + return fmt.Errorf("migration failed at %s->%s: %w", + change.FromVersion().String(), change.ToVersion().String(), err) } } @@ -456,7 +460,9 @@ func (mc *MigrationChain) MigrateResponse(ctx context.Context, responseInfo *Res // We need to apply changes at each step in reverse iterationCount := 0 - maxIterations := 10 // Safety limit + // Safety limit derived from the number of changes + 1. + // Cycles are already detected at construction time, so this is a defensive safeguard. + maxIterations := len(mc.changes) + 1 for !currentVersion.Equal(to) { iterationCount++ @@ -509,9 +515,34 @@ func (mc *MigrationChain) MigrateResponse(ctx context.Context, responseInfo *Res return nil } -// AddChange adds a new version change to the chain -func (mc *MigrationChain) AddChange(change *VersionChange) { +// AddChange adds a new version change to the chain. +// The change is inserted in sorted order (by FromVersion, then ToVersion) and +// cycle detection is re-run to ensure the chain remains valid. +func (mc *MigrationChain) AddChange(change *VersionChange) error { mc.changes = append(mc.changes, change) + + // Re-sort to maintain the ordering invariant required by MigrateRequest + sort.Slice(mc.changes, func(i, j int) bool { + cmp := mc.changes[i].FromVersion().Compare(mc.changes[j].FromVersion()) + if cmp != 0 { + return cmp < 0 + } + return mc.changes[i].ToVersion().Compare(mc.changes[j].ToVersion()) < 0 + }) + + // Re-run cycle detection to ensure the new change doesn't introduce a cycle + if err := mc.detectCycles(); err != nil { + // Rollback: remove the newly added change + for i, c := range mc.changes { + if c == change { + mc.changes = append(mc.changes[:i], mc.changes[i+1:]...) + break + } + } + return err + } + + return nil } // detectCycles uses depth-first search to find cycles in the version graph