Skip to content

Commit

Permalink
Fix data race on hcr.channels
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Litvinov <[email protected]>
  • Loading branch information
Zensey committed Mar 5, 2024
1 parent d6e4124 commit 833702d
Show file tree
Hide file tree
Showing 4 changed files with 42 additions and 11 deletions.
30 changes: 25 additions & 5 deletions core/state/state.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,7 +120,13 @@ func NewKeeper(deps KeeperDeps, debounceDuration time.Duration) *Keeper {
deps: deps,
}
k.state.Identities = k.fetchIdentities()
k.state.ProviderChannels = k.deps.EarningsProvider.List(deps.ChainID)

channels := k.deps.EarningsProvider.List(deps.ChainID)
channelsCopy := make([]pingpong.HermesChannel, 0)
if err := copier.CopyWithOption(&channelsCopy, channels, copier.Option{DeepCopy: true}); err != nil {
panic(err)
}
k.state.ProviderChannels = channelsCopy

// provider
k.consumeServiceStateEvent = debounce(k.updateServiceState, debounceDuration)
Expand All @@ -136,6 +142,9 @@ func NewKeeper(deps KeeperDeps, debounceDuration time.Duration) *Keeper {
return k
}

// func (k *Keeper) copyHermesChannels([]pingpong.HermesChannel) []pingpong.HermesChannel {
// }

func (k *Keeper) fetchIdentities() []stateEvent.Identity {
ids := k.deps.IdentityProvider.GetIdentities()
identities := make([]stateEvent.Identity, len(ids))
Expand All @@ -157,15 +166,26 @@ func (k *Keeper) fetchIdentities() []stateEvent.Identity {
}

earnings := k.deps.EarningsProvider.GetEarningsDetailed(k.deps.ChainID, id)
earningsCopy := &pingpongEvent.EarningsDetailed{}
if err := copier.CopyWithOption(earningsCopy, *earnings, copier.Option{DeepCopy: true}); err != nil {
panic(err)
}

balanceCopy := new(big.Int)
err = copier.Copy(balanceCopy, *k.deps.BalanceProvider.GetBalance(k.deps.ChainID, id))
if err != nil {
panic(err)
}

stateIdentity := stateEvent.Identity{
Address: id.Address,
RegistrationStatus: status,
ChannelAddress: channelAddress,
Balance: k.deps.BalanceProvider.GetBalance(k.deps.ChainID, id),
Earnings: earnings.Total.UnsettledBalance,
EarningsTotal: earnings.Total.LifetimeBalance,
Balance: balanceCopy,
Earnings: earningsCopy.Total.UnsettledBalance,
EarningsTotal: earningsCopy.Total.LifetimeBalance,
HermesID: hermesID,
EarningsPerHermes: earnings.PerHermes,
EarningsPerHermes: earningsCopy.PerHermes,
}
identities[idx] = stateIdentity
}
Expand Down
7 changes: 5 additions & 2 deletions services/wireguard/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ package service
import (
"encoding/json"
"fmt"
"maps"
"net"
"sync"
"time"

"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"

Expand Down Expand Up @@ -277,7 +277,10 @@ func (m *Manager) Stop() error {

// prevent concurrent iteration and write
sessionCleanupCopy := make(map[string]func())
maps.Copy(sessionCleanupCopy, m.sessionCleanup)
if err := copier.Copy(&sessionCleanupCopy, m.sessionCleanup); err != nil {
panic(err)
}

for k, v := range sessionCleanupCopy {
cleanupWg.Add(1)
go func(sessionID string, cleanup func()) {
Expand Down
12 changes: 10 additions & 2 deletions session/pingpong/hermes_channel_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ func (hcr *HermesChannelRepository) handleHermesPromiseReceived(payload pingEven
return
}

err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise)
// use parameter "protectChannels" to protect channels on update
err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise, true)
if err != nil {
log.Err(err).Msg("could not update channel state with latest hermes promise")
}
Expand Down Expand Up @@ -438,7 +439,7 @@ func (hcr *HermesChannelRepository) fetchChannel(chainID int64, channelID string
return hermesChannel, nil
}

func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise) error {
func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise, protectChannels bool) error {
gotten, ok := hcr.Get(chainID, id, hermesID)
if !ok {
// this actually performs the update, so no need to do anything
Expand All @@ -447,7 +448,14 @@ func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64
}

hermesChannel := NewHermesChannel(channelID, id, hermesID, gotten.Channel, promise, gotten.Beneficiary)

// protect hcr.channels: handleHermesPromiseReceived -> updateChannelWithLatestPromise -> updateChannel
if protectChannels {
hcr.lock.Lock()
defer hcr.lock.Unlock()
}
hcr.updateChannel(chainID, hermesChannel)

return nil
}

Expand Down
4 changes: 2 additions & 2 deletions session/pingpong/hermes_channel_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) {

// when
promise := HermesPromise{ChannelID: channelID.Hex(), Identity: id, HermesID: hermesID}
err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise)
err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false)
assert.NoError(t, err)
hermesChannel, exists := repo.Get(1, id, hermesID)

Expand All @@ -279,7 +279,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) {
assert.Equal(t, beneficiary, hermesChannel.Beneficiary)

// when
err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise)
err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false)
assert.NoError(t, err)
hermesChannel, exists = repo.Get(1, id, hermesID)

Expand Down

0 comments on commit 833702d

Please sign in to comment.