From 69a6454b0eb20b46805d8ba743c4495df29815fb Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Mon, 4 Mar 2024 19:57:45 +0400 Subject: [PATCH] Fix data race on hcr.channels Signed-off-by: Anton Litvinov --- services/wireguard/service/service.go | 7 +++++-- session/pingpong/hermes_channel_repository.go | 11 +++++++++-- session/pingpong/hermes_channel_repository_test.go | 4 ++-- 3 files changed, 16 insertions(+), 6 deletions(-) diff --git a/services/wireguard/service/service.go b/services/wireguard/service/service.go index ccc9058455..a0bc4acd3d 100644 --- a/services/wireguard/service/service.go +++ b/services/wireguard/service/service.go @@ -20,11 +20,11 @@ package service import ( "encoding/json" "fmt" - "maps" "net" "sync" "time" + "github.com/jinzhu/copier" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -277,7 +277,10 @@ func (m *Manager) Stop() error { // prevent concurrent iteration and write sessionCleanupCopy := make(map[string]func()) - maps.Copy(sessionCleanupCopy, m.sessionCleanup) + if err := copier.Copy(&sessionCleanupCopy, m.sessionCleanup); err != nil { + panic(err) + } + for k, v := range sessionCleanupCopy { cleanupWg.Add(1) go func(sessionID string, cleanup func()) { diff --git a/session/pingpong/hermes_channel_repository.go b/session/pingpong/hermes_channel_repository.go index 0d58dfbba9..958848774b 100644 --- a/session/pingpong/hermes_channel_repository.go +++ b/session/pingpong/hermes_channel_repository.go @@ -261,7 +261,8 @@ func (hcr *HermesChannelRepository) handleHermesPromiseReceived(payload pingEven return } - err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise) + // use parameter "protectChannels" to protect channels on update + err = hcr.updateChannelWithLatestPromise(payload.Promise.ChainID, promise.ChannelID, payload.ProviderID, payload.HermesID, promise, true) if err != nil { log.Err(err).Msg("could not update channel state with latest hermes promise") } @@ -438,7 +439,7 @@ func (hcr *HermesChannelRepository) fetchChannel(chainID int64, channelID string return hermesChannel, nil } -func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise) error { +func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64, channelID string, id identity.Identity, hermesID common.Address, promise HermesPromise, protectChannels bool) error { gotten, ok := hcr.Get(chainID, id, hermesID) if !ok { // this actually performs the update, so no need to do anything @@ -447,6 +448,12 @@ func (hcr *HermesChannelRepository) updateChannelWithLatestPromise(chainID int64 } hermesChannel := NewHermesChannel(channelID, id, hermesID, gotten.Channel, promise, gotten.Beneficiary) + + // protect hcr.channels: handleHermesPromiseReceived -> updateChannelWithLatestPromise -> updateChannel + if protectChannels { + hcr.lock.Lock() + defer hcr.lock.Unlock() + } hcr.updateChannel(chainID, hermesChannel) return nil } diff --git a/session/pingpong/hermes_channel_repository_test.go b/session/pingpong/hermes_channel_repository_test.go index 7a87c0abc8..04da4bb78c 100644 --- a/session/pingpong/hermes_channel_repository_test.go +++ b/session/pingpong/hermes_channel_repository_test.go @@ -270,7 +270,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) { // when promise := HermesPromise{ChannelID: channelID.Hex(), Identity: id, HermesID: hermesID} - err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise) + err := repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false) assert.NoError(t, err) hermesChannel, exists := repo.Get(1, id, hermesID) @@ -279,7 +279,7 @@ func TestHermesChannelRepository_BeneficiaryReset(t *testing.T) { assert.Equal(t, beneficiary, hermesChannel.Beneficiary) // when - err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise) + err = repo.updateChannelWithLatestPromise(1, channelID.Hex(), id, hermesID, promise, false) assert.NoError(t, err) hermesChannel, exists = repo.Get(1, id, hermesID)