Skip to content

Commit

Permalink
Fix data race on hcr.channels
Browse files Browse the repository at this point in the history
Signed-off-by: Anton Litvinov <[email protected]>
  • Loading branch information
Zensey committed Mar 4, 2024
1 parent d6e4124 commit 69a6454
Show file tree
Hide file tree
Showing 3 changed files with 16 additions and 6 deletions.
7 changes: 5 additions & 2 deletions services/wireguard/service/service.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,11 @@ package service
import (
"encoding/json"
"fmt"
"maps"
"net"
"sync"
"time"

"github.com/jinzhu/copier"
"github.com/pkg/errors"
"github.com/rs/zerolog/log"

Expand Down Expand Up @@ -277,7 +277,10 @@ func (m *Manager) Stop() error {

// prevent concurrent iteration and write
sessionCleanupCopy := make(map[string]func())
maps.Copy(sessionCleanupCopy, m.sessionCleanup)
if err := copier.Copy(&sessionCleanupCopy, m.sessionCleanup); err != nil {
panic(err)
}

for k, v := range sessionCleanupCopy {
cleanupWg.Add(1)
go func(sessionID string, cleanup func()) {
Expand Down
11 changes: 9 additions & 2 deletions session/pingpong/hermes_channel_repository.go
Original file line number Diff line number Diff line change
Expand Up @@ -261,7 +261,8 @@ func (hcr *HermesChannelRepository) handleHermesPromiseReceived(payload pingEven
return
}

err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise)
// use parameter "protectChannels" to protect channels on update
err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise, true)
if err != nil {
log.Err(err).Msg("could not update channel state with latest hermes promise")
}
Expand Down Expand Up @@ -438,7 +439,7 @@ func (hcr *HermesChannelRepository) fetchChannel(chainID int64, channelID string
return hermesChannel, nil
}

func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise) error {
func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise, protectChannels bool) error {
gotten, ok := hcr.Get(chainID, id, hermesID)
if !ok {
// this actually performs the update, so no need to do anything
Expand All @@ -447,6 +448,12 @@ func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64
}

hermesChannel := NewHermesChannel(channelID, id, hermesID, gotten.Channel, promise, gotten.Beneficiary)

// protect hcr.channels: handleHermesPromiseReceived -> updateChannelWithLatestPromise -> updateChannel
if protectChannels {
hcr.lock.Lock()
defer hcr.lock.Unlock()
}
hcr.updateChannel(chainID, hermesChannel)
return nil
}
Expand Down
4 changes: 2 additions & 2 deletions session/pingpong/hermes_channel_repository_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) {

// when
promise := HermesPromise{ChannelID: channelID.Hex(), Identity: id, HermesID: hermesID}
err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise)
err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false)
assert.NoError(t, err)
hermesChannel, exists := repo.Get(1, id, hermesID)

Expand All @@ -279,7 +279,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) {
assert.Equal(t, beneficiary, hermesChannel.Beneficiary)

// when
err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise)
err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false)
assert.NoError(t, err)
hermesChannel, exists = repo.Get(1, id, hermesID)

Expand Down

0 comments on commit 69a6454

Please sign in to comment.