diff --git a/pkg/gcc/rate_controller.go b/pkg/gcc/rate_controller.go index c19a617b..28cc0e1e 100644 --- a/pkg/gcc/rate_controller.go +++ b/pkg/gcc/rate_controller.go @@ -90,7 +90,8 @@ func (c *rateController) onDelayStats(ds DelayStats) { return } c.delayStats = ds - c.delayStats.State = c.delayStats.State.transition(ds.Usage) + c.delayStats.State = c.lastState.transition(ds.Usage) + c.lastState = c.delayStats.State if c.delayStats.State == stateHold { return diff --git a/pkg/gcc/rate_controller_test.go b/pkg/gcc/rate_controller_test.go index 46b3238c..242aefb7 100644 --- a/pkg/gcc/rate_controller_test.go +++ b/pkg/gcc/rate_controller_test.go @@ -76,3 +76,57 @@ func TestRateControllerRun(t *testing.T) { }) } } + +func TestRateController_StateTransition(t *testing.T) { + tcs := []struct { + name string + delayStats []DelayStats + wantStates []state + }{ + { + name: "overuse-normal", + delayStats: []DelayStats{{Usage: usageOver}, {Usage: usageNormal}}, + wantStates: []state{stateDecrease, stateHold}, + }, + { + name: "overuse-underuse", + delayStats: []DelayStats{{Usage: usageOver}, {Usage: usageUnder}}, + wantStates: []state{stateDecrease, stateHold}, + }, + { + name: "normal", + delayStats: []DelayStats{{Usage: usageNormal}}, + wantStates: []state{stateIncrease}, + }, + { + name: "under-over", + delayStats: []DelayStats{{Usage: usageUnder}, {Usage: usageOver}}, + wantStates: []state{stateHold, stateDecrease}, + }, + { + name: "under-normal", + delayStats: []DelayStats{{Usage: usageUnder}, {Usage: usageNormal}}, + wantStates: []state{stateHold, stateIncrease}, + }, + { + name: "under-under", + delayStats: []DelayStats{{Usage: usageUnder}, {Usage: usageUnder}}, + wantStates: []state{stateHold, stateHold}, + }, + } + + for _, tc := range tcs { + t.Run(tc.name, func(t *testing.T) { + rc := newRateController(time.Now, 500_000, 100_000, 1_000_000, func(DelayStats) {}) + // Call it once to initialize the rate controller + rc.onDelayStats(DelayStats{}) + + for i, ds := range tc.delayStats { + rc.onDelayStats(ds) + if rc.lastState != tc.wantStates[i] { + t.Errorf("expected lastState to be %v but got %v", tc.wantStates[i], rc.lastState) + } + } + }) + } +}