Skip to content

Commit 4e82010

Browse files
authored
snapshot/revert views (#4669)
1 parent 061bc1b commit 4e82010

File tree

7 files changed

+83
-7
lines changed

7 files changed

+83
-7
lines changed

action/protocol/protocol.go

Lines changed: 36 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -120,16 +120,50 @@ type (
120120

121121
// Views stores the view for all protocols
122122
Views struct {
123-
vm map[string]View
123+
snapshotID int
124+
snapshots map[int]map[string]int
125+
vm map[string]View
124126
}
125127
)
126128

127129
func NewViews() *Views {
128130
return &Views{
129-
vm: make(map[string]View),
131+
snapshotID: 0,
132+
snapshots: make(map[int]map[string]int),
133+
vm: make(map[string]View),
130134
}
131135
}
132136

137+
func (views *Views) Snapshot() int {
138+
views.snapshotID++
139+
views.snapshots[views.snapshotID] = make(map[string]int)
140+
keys := make([]string, 0, len(views.vm))
141+
for key := range views.vm {
142+
keys = append(keys, key)
143+
}
144+
for _, key := range keys {
145+
views.snapshots[views.snapshotID][key] = views.vm[key].Snapshot()
146+
}
147+
return views.snapshotID
148+
}
149+
150+
func (views *Views) Revert(id int) error {
151+
if id > views.snapshotID || id < 0 {
152+
return errors.Errorf("invalid snapshot id %d, max id is %d", id, views.snapshotID)
153+
}
154+
for k, v := range views.snapshots[id] {
155+
if err := views.vm[k].Revert(v); err != nil {
156+
return err
157+
}
158+
}
159+
views.snapshotID = id
160+
// clean up snapshots that are not needed anymore
161+
for i := id + 1; i <= views.snapshotID; i++ {
162+
delete(views.snapshots, i)
163+
}
164+
return nil
165+
}
166+
133167
func (views *Views) Fork() *Views {
134168
fork := NewViews()
135169
for key, view := range views.vm {

action/protocol/staking/viewdata.go

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,8 @@ type (
2222
Wrap() ContractStakeView
2323
// Fork forks the contract stake view, commit will not affect the original view
2424
Fork() ContractStakeView
25+
// IsDirty checks if the contract stake view is dirty
26+
IsDirty() bool
2527
// Commit commits the contract stake view
2628
Commit(context.Context, protocol.StateManager) error
2729
// CreatePreStates creates pre states for the contract stake view
@@ -91,7 +93,7 @@ func (v *viewData) Commit(ctx context.Context, sm protocol.StateManager) error {
9193
}
9294

9395
func (v *viewData) IsDirty() bool {
94-
return v.candCenter.IsDirty() || v.bucketPool.IsDirty()
96+
return v.candCenter.IsDirty() || v.bucketPool.IsDirty() || (v.contractsStake != nil && v.contractsStake.IsDirty())
9597
}
9698

9799
func (v *viewData) Snapshot() int {
@@ -198,6 +200,19 @@ func (csv *contractStakeView) CreatePreStates(ctx context.Context) error {
198200
return nil
199201
}
200202

203+
func (csv *contractStakeView) IsDirty() bool {
204+
if csv.v1 != nil && csv.v1.IsDirty() {
205+
return true
206+
}
207+
if csv.v2 != nil && csv.v2.IsDirty() {
208+
return true
209+
}
210+
if csv.v3 != nil && csv.v3.IsDirty() {
211+
return true
212+
}
213+
return false
214+
}
215+
201216
func (csv *contractStakeView) Commit(ctx context.Context, sm protocol.StateManager) error {
202217
featureCtx, ok := protocol.GetFeatureCtx(ctx)
203218
if !ok || featureCtx.LoadContractStakingFromIndexer {

blockindex/contractstaking/stakeview.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -53,6 +53,10 @@ func (s *stakeView) assembleBuckets(ids []uint64, types []*BucketType, infos []*
5353
return vbs
5454
}
5555

56+
func (s *stakeView) IsDirty() bool {
57+
return s.cache.IsDirty()
58+
}
59+
5660
func (s *stakeView) WriteBuckets(sm protocol.StateManager) error {
5761
ids, types, infos := s.cache.Buckets()
5862
cssm := contractstaking.NewContractStakingStateManager(sm)

blockindex/contractstaking/wrappedcache.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -287,6 +287,11 @@ func (wc *wrappedCache) Commit(ctx context.Context, ca address.Address, sm proto
287287
wc.base.PutBucketInfo(id, bi)
288288
}
289289
}
290+
wc.updatedBucketInfos = make(map[uint64]*bucketInfo)
291+
wc.updatedBucketTypes = make(map[uint64]*BucketType)
292+
wc.updatedCandidates = make(map[string]map[uint64]bool)
293+
wc.propertyBucketTypeMap = make(map[uint64]map[uint64]uint64)
294+
290295
return wc.base.Commit(ctx, ca, sm)
291296
}
292297

e2etest/expect.go

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -97,9 +97,7 @@ func (ce *candidateExpect) expect(test *e2etest, act *action.SealedEnvelope, rec
9797
cs := test.svr.ChainService(test.cfg.Chain.ID)
9898
sr := cs.StateFactory()
9999
bc := cs.Blockchain()
100-
prtcl, ok := cs.Registry().Find("staking")
101-
require.True(ok)
102-
stkPrtcl := prtcl.(*staking.Protocol)
100+
stkPrtcl := staking.FindProtocol(cs.Registry())
103101
reqBytes, err := proto.Marshal(r)
104102
require.NoError(err)
105103
ctx := protocol.WithRegistry(context.Background(), cs.Registry())

state/factory/workingset.go

Lines changed: 17 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,7 @@ type (
7171
workingSetStoreFactory WorkingSetStoreFactory
7272
height uint64
7373
views *protocol.Views
74+
viewsSnapshots map[int]int
7475
store workingSetStore
7576
finalized bool
7677
txValidator *protocol.GenericValidator
@@ -82,6 +83,7 @@ func newWorkingSet(height uint64, views *protocol.Views, store workingSetStore,
8283
ws := &workingSet{
8384
height: height,
8485
views: views,
86+
viewsSnapshots: make(map[int]int),
8587
store: store,
8688
workingSetStoreFactory: storeFactory,
8789
}
@@ -281,14 +283,28 @@ func (ws *workingSet) finalizeTx(ctx context.Context) {
281283
}
282284

283285
func (ws *workingSet) Snapshot() int {
284-
return ws.store.Snapshot()
286+
id := ws.store.Snapshot()
287+
vid := ws.views.Snapshot()
288+
ws.viewsSnapshots[id] = vid
289+
290+
return id
285291
}
286292

287293
func (ws *workingSet) Revert(snapshot int) error {
294+
vid, ok := ws.viewsSnapshots[snapshot]
295+
if !ok {
296+
return errors.Errorf("snapshot %d not found", snapshot)
297+
}
298+
if err := ws.views.Revert(vid); err != nil {
299+
return errors.Wrapf(err, "failed to revert views to snapshot %d", vid)
300+
}
288301
return ws.store.RevertSnapshot(snapshot)
289302
}
290303

291304
func (ws *workingSet) ResetSnapshots() {
305+
if len(ws.viewsSnapshots) > 0 {
306+
ws.viewsSnapshots = make(map[int]int)
307+
}
292308
ws.store.ResetSnapshots()
293309
}
294310

systemcontractindex/stakingindex/stakeview.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -50,6 +50,10 @@ func (s *stakeView) Fork() staking.ContractStakeView {
5050
}
5151
}
5252

53+
func (s *stakeView) IsDirty() bool {
54+
return s.cache.IsDirty()
55+
}
56+
5357
func (s *stakeView) WriteBuckets(sm protocol.StateManager) error {
5458
ids := s.cache.BucketIdxs()
5559
slices.Sort(ids)

0 commit comments

Comments
 (0)