diff --git a/ci/packages/build.go b/ci/packages/build.go
index 19c851d88a..d2a7c2d6b1 100644
--- a/ci/packages/build.go
+++ b/ci/packages/build.go
@@ -32,16 +32,31 @@ import (
"github.com/rs/zerolog/log"
)
+// BuildProvChecker builds myst binary with provider checker API. Like go tool, it supports cross-platform build with env vars: GOOS, GOARCH.
+func BuildProvChecker() error {
+ logconfig.Bootstrap()
+ if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst", true); err != nil {
+ return err
+ }
+ if err := copyConfig("myst"); err != nil {
+ return err
+ }
+ if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", false); err != nil {
+ return err
+ }
+ return nil
+}
+
// Build builds the project. Like go tool, it supports cross-platform build with env vars: GOOS, GOARCH.
func Build() error {
logconfig.Bootstrap()
- if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst"); err != nil {
+ if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst", false); err != nil {
return err
}
if err := copyConfig("myst"); err != nil {
return err
}
- if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor"); err != nil {
+ if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", false); err != nil {
return err
}
return nil
@@ -67,7 +82,7 @@ func buildCrossBinary(os, arch string) error {
return sh.Run("bin/build_xgo", os+"/"+arch)
}
-func buildBinary(source, target string) error {
+func buildBinary(source, target string, provChecker bool) error {
targetOS, ok := os.LookupEnv("GOOS")
if !ok {
targetOS = runtime.GOOS
@@ -76,10 +91,10 @@ func buildBinary(source, target string) error {
if !ok {
targetArch = runtime.GOARCH
}
- return buildBinaryFor(source, target, targetOS, targetArch, nil, false)
+ return buildBinaryFor(source, target, targetOS, targetArch, nil, false, provChecker)
}
-func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[string]string, buildStatic bool) error {
+func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[string]string, buildStatic, provChecker bool) error {
log.Info().Msgf("Building %s -> %s %s/%s", source, target, targetOS, targetArch)
buildDir, err := filepath.Abs(path.Join("build", target))
@@ -100,6 +115,9 @@ func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[s
if buildStatic {
flags = append(flags, "-a", "-tags", "netgo")
}
+ if provChecker {
+ flags = append(flags, "-tags", "prov_checker")
+ }
if targetOS == "windows" {
target += ".exe"
diff --git a/ci/packages/package.go b/ci/packages/package.go
index 2e156bb124..5c165f0493 100644
--- a/ci/packages/package.go
+++ b/ci/packages/package.go
@@ -380,7 +380,7 @@ func packageStandalone(binaryPath, os, arch string, extraEnvs map[string]string)
if os == "linux" {
filename := path.Base(binaryPath)
binaryPath = path.Join("build", filename, filename)
- err = buildBinaryFor(path.Join("cmd", "mysterium_node", "mysterium_node.go"), filename, os, arch, extraEnvs, true)
+ err = buildBinaryFor(path.Join("cmd", "mysterium_node", "mysterium_node.go"), filename, os, arch, extraEnvs, true, false)
} else {
err = buildCrossBinary(os, arch)
}
@@ -388,7 +388,7 @@ func packageStandalone(binaryPath, os, arch string, extraEnvs map[string]string)
return err
}
- err = buildBinaryFor(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", os, arch, extraEnvs, true)
+ err = buildBinaryFor(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", os, arch, extraEnvs, true, false)
if err != nil {
return err
}
diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go
index a45f464cee..5024f8fdde 100644
--- a/cmd/bootstrap.go
+++ b/cmd/bootstrap.go
@@ -45,55 +45,60 @@ func (di *Dependencies) bootstrapTequilapi(nodeOptions node.Options, listener ne
}
tequilaApiClient := tequilapi_client.NewClient(nodeOptions.TequilapiAddress, nodeOptions.TequilapiPort)
+ handlers := []func(engine *gin.Engine) error{
+ func(e *gin.Engine) error {
+ if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil {
+ return err
+ }
+ return nil
+ },
+ func(e *gin.Engine) error {
+ if config.GetBool(config.FlagPProfEnable) {
+ tequilapi_endpoints.AddRoutesForPProf(e)
+ }
+ return nil
+ },
+ func(e *gin.Engine) error {
+ e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck)
+ return nil
+ },
+ tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)),
+ tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes),
+ tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator),
+ tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider),
+ tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage),
+ tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver),
+ tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber),
+ tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient),
+ tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)),
+ tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber),
+ tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)),
+ tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker),
+ tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI),
+ tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator),
+ tequilapi_endpoints.AddRoutesForConfig,
+ tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator),
+ tequilapi_endpoints.AddRoutesForFeedback(di.Reporter),
+ tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage),
+ tequilapi_endpoints.AddRoutesForDocs,
+ tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI),
+ tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver),
+ tequilapi_endpoints.AddRoutesForTerms,
+ tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator(
+ config.FlagPaymentPriceGiB.Value,
+ config.FlagPaymentPriceHour.Value,
+ )),
+ tequilapi_endpoints.AddRoutesForValidator,
+ }
+ if nodeOptions.ProvChecker {
+ handlers = append(handlers, tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions))
+ }
+
return tequilapi.NewServer(
listener,
nodeOptions,
di.JWTAuthenticator,
- []func(engine *gin.Engine) error{
- func(e *gin.Engine) error {
- if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil {
- return err
- }
- return nil
- },
- func(e *gin.Engine) error {
- if config.GetBool(config.FlagPProfEnable) {
- tequilapi_endpoints.AddRoutesForPProf(e)
- }
- return nil
- },
- func(e *gin.Engine) error {
- e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck)
- return nil
- },
- tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)),
- tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes),
- tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator),
- tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider),
- tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage),
- tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver),
- tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber),
- tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient),
- tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)),
- tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber),
- tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)),
- tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker),
- tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI),
- tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator),
- tequilapi_endpoints.AddRoutesForConfig,
- tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator),
- tequilapi_endpoints.AddRoutesForFeedback(di.Reporter),
- tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage),
- tequilapi_endpoints.AddRoutesForDocs,
- tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI),
- tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver),
- tequilapi_endpoints.AddRoutesForTerms,
- tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator(
- config.FlagPaymentPriceGiB.Value,
- config.FlagPaymentPriceHour.Value,
- )),
- tequilapi_endpoints.AddRoutesForValidator,
- },
+ handlers,
)
}
diff --git a/cmd/di.go b/cmd/di.go
index 1088d8ce3c..114b9f936a 100644
--- a/cmd/di.go
+++ b/cmd/di.go
@@ -150,8 +150,9 @@ type Dependencies struct {
EventBus eventbus.EventBus
- MultiConnectionManager connection.MultiManager
- ConnectionRegistry *connection.Registry
+ MultiConnectionManager connection.MultiManager
+ MultiConnectionDiagManager connection.DiagManager
+ ConnectionRegistry *connection.Registry
ServicesManager *service.Manager
ServiceRegistry *service.Registry
@@ -287,7 +288,7 @@ func (di *Dependencies) Bootstrap(nodeOptions node.Options) error {
return err
}
- if err := di.bootstrapQualityComponents(nodeOptions.Quality); err != nil {
+ if err := di.bootstrapQualityComponents(nodeOptions.Quality, nodeOptions); err != nil {
return err
}
@@ -299,6 +300,7 @@ func (di *Dependencies) Bootstrap(nodeOptions node.Options) error {
if err = di.handleConnStateChange(); err != nil {
return err
}
+
if err := di.Node.Start(); err != nil {
return err
}
@@ -581,6 +583,7 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil
di.bootstrapBeneficiarySaver(nodeOptions)
di.ConnectionRegistry = connection.NewRegistry()
+
di.MultiConnectionManager = connection.NewMultiConnectionManager(func() connection.Manager {
return connection.NewManager(
pingpong.ExchangeFactoryFunc(
@@ -607,6 +610,31 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil
)
})
+ if nodeOptions.ProvChecker {
+ di.MultiConnectionDiagManager = connection.NewDiagManager(
+ pingpong.ExchangeFactoryFunc(
+ di.Keystore,
+ di.SignerFactory,
+ di.ConsumerTotalsStorage,
+ di.AddressProvider,
+ di.EventBus,
+ nodeOptions.Payments.ConsumerDataLeewayMegabytes,
+ ),
+ di.ConnectionRegistry.CreateConnection,
+ di.EventBus,
+ di.IPResolver,
+ di.LocationResolver,
+ connection.DefaultConfig(),
+ config.GetDuration(config.FlagStatsReportInterval),
+ connection.NewValidator(
+ di.ConsumerBalanceTracker,
+ di.IdentityManager,
+ ),
+ di.P2PDialer,
+ di.allowTrustedDomainBypassTunnel,
+ di.disallowTrustedDomainBypassTunnel,
+ )
+ }
di.NATProber = natprobe.NewNATProber(di.MultiConnectionManager, di.EventBus)
di.LogCollector = logconfig.NewCollector(&logconfig.CurrentLogOptions)
@@ -655,7 +683,7 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil
sleepNotifier := sleep.NewNotifier(di.MultiConnectionManager, di.EventBus)
sleepNotifier.Subscribe()
- di.Node = NewNode(di.MultiConnectionManager, tequilapiHTTPServer, di.EventBus, di.UIServer, sleepNotifier)
+ di.Node = NewNode(di.MultiConnectionManager, di.MultiConnectionDiagManager, tequilapiHTTPServer, di.EventBus, di.UIServer, sleepNotifier)
return nil
}
@@ -883,7 +911,7 @@ func (di *Dependencies) bootstrapIdentityComponents(options node.Options) error
return nil
}
-func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality) (err error) {
+func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality, nodeOptions node.Options) (err error) {
if err := di.AllowURLAccess(options.Address); err != nil {
return err
}
@@ -1065,7 +1093,7 @@ func (di *Dependencies) handleConnStateChange() error {
latestState := connectionstate.NotConnected
return di.EventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, func(e connectionstate.AppEventConnectionState) {
- if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) {
+ if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) || config.GetBool(config.FlagProvCheckerMode) {
return // Proxy mode doesn't establish system wide tunnels, no reconnect required.
}
diff --git a/cmd/node.go b/cmd/node.go
index 870fcf635b..dfe7973436 100644
--- a/cmd/node.go
+++ b/cmd/node.go
@@ -42,23 +42,27 @@ type SleepNotifier interface {
}
// NewNode function creates new Mysterium node by given options
-func NewNode(connectionManager connection.MultiManager, tequilapiServer tequilapi.APIServer, publisher Publisher, uiServer UIServer, notifier SleepNotifier) *Node {
+func NewNode(connectionManager connection.MultiManager, connectionDiagManager connection.DiagManager, tequilapiServer tequilapi.APIServer, publisher Publisher, uiServer UIServer, notifier SleepNotifier) *Node {
return &Node{
- connectionManager: connectionManager,
- httpAPIServer: tequilapiServer,
- publisher: publisher,
- uiServer: uiServer,
- sleepNotifier: notifier,
+ connectionManager: connectionManager,
+ connectionDiagManager: connectionDiagManager,
+
+ httpAPIServer: tequilapiServer,
+ publisher: publisher,
+ uiServer: uiServer,
+ sleepNotifier: notifier,
}
}
// Node represent entrypoint for Mysterium node with top level components
type Node struct {
- connectionManager connection.MultiManager
- httpAPIServer tequilapi.APIServer
- publisher Publisher
- uiServer UIServer
- sleepNotifier SleepNotifier
+ connectionManager connection.MultiManager
+ connectionDiagManager connection.DiagManager
+
+ httpAPIServer tequilapi.APIServer
+ publisher Publisher
+ uiServer UIServer
+ sleepNotifier SleepNotifier
}
// Start starts Mysterium node (Tequilapi service, fetches location)
diff --git a/config/flags_network.go b/config/flags_network.go
index ce3255566b..b6f991d00c 100644
--- a/config/flags_network.go
+++ b/config/flags_network.go
@@ -154,6 +154,13 @@ var (
Usage: "DNS listen port for services",
Value: 11253,
}
+
+ // FlagProvCheckerDatabaseDSN sets DNS for checker API database
+ FlagProvCheckerDatabaseDSN = cli.StringFlag{
+ Name: "checker.dsn",
+ Usage: "Database DSN for provider checker",
+ Value: "",
+ }
)
// RegisterFlagsNetwork function register network flags to flag list
@@ -179,6 +186,7 @@ func RegisterFlagsNetwork(flags *[]cli.Flag) {
&FlagPortCheckServers,
&FlagStatsReportInterval,
&FlagDNSListenPort,
+ &FlagProvCheckerDatabaseDSN,
)
}
@@ -203,6 +211,7 @@ func ParseFlagsNetwork(ctx *cli.Context) {
Current.ParseStringFlag(ctx, FlagPortCheckServers)
Current.ParseDurationFlag(ctx, FlagStatsReportInterval)
Current.ParseIntFlag(ctx, FlagDNSListenPort)
+ Current.ParseStringFlag(ctx, FlagProvCheckerDatabaseDSN)
}
// BlockchainNetwork defines a blockchain network
diff --git a/config/flags_node.go b/config/flags_node.go
index acaf62e56c..8c8fd3fada 100644
--- a/config/flags_node.go
+++ b/config/flags_node.go
@@ -229,6 +229,13 @@ var (
Value: false,
}
+ // FlagProvCheckerMode allows running node under current user as a provider checker agent.
+ FlagProvCheckerMode = cli.BoolFlag{
+ Name: "provchecker",
+ Usage: "",
+ Value: false,
+ }
+
// FlagUserspace allows running a node without privileged permissions.
FlagUserspace = cli.BoolFlag{
Name: "userspace",
@@ -349,6 +356,7 @@ func RegisterFlagsNode(flags *[]cli.Flag) error {
&FlagUserMode,
&FlagDVPNMode,
&FlagProxyMode,
+ &FlagProvCheckerMode,
&FlagUserspace,
&FlagVendorID,
&FlagLauncherVersion,
@@ -411,6 +419,8 @@ func ParseFlagsNode(ctx *cli.Context) {
Current.ParseBoolFlag(ctx, FlagUserMode)
Current.ParseBoolFlag(ctx, FlagDVPNMode)
Current.ParseBoolFlag(ctx, FlagProxyMode)
+ Current.ParseBoolFlag(ctx, FlagProvCheckerMode)
+
Current.ParseBoolFlag(ctx, FlagUserspace)
Current.ParseStringFlag(ctx, FlagVendorID)
Current.ParseStringFlag(ctx, FlagLauncherVersion)
diff --git a/core/connection/interface.go b/core/connection/interface.go
index cf7f678140..bcf6a4b24c 100644
--- a/core/connection/interface.go
+++ b/core/connection/interface.go
@@ -39,6 +39,11 @@ type Connection interface {
Statistics() (connectionstate.Statistics, error)
}
+// ConnectionDiag is a specialised Connection interface for provider check
+type ConnectionDiag interface {
+ Diag() error
+}
+
// StateChannel is the channel we receive state change events on
type StateChannel chan connectionstate.State
@@ -73,3 +78,15 @@ type MultiManager interface {
// Reconnect reconnects current session
Reconnect(n int)
}
+
+// DiagManager interface provides methods to manage diagnotic connection
+type DiagManager interface {
+ // Connect creates new connection from given consumer to provider, reports error if connection already exists
+ Connect(consumerID identity.Identity, hermesID common.Address, proposal ProposalLookup, params ConnectParams) error
+ // Status queries current status of connection
+ Status() connectionstate.Status
+ // GetReadyChan returns a channel for getting a diagnostic result
+ GetReadyChan(providerID string) chan interface{}
+ // HasConnection returns true if a diagnostic connection is already established
+ HasConnection(providerID string) bool
+}
diff --git a/core/connection/manager-diag.go b/core/connection/manager-diag.go
new file mode 100644
index 0000000000..d8a061a047
--- /dev/null
+++ b/core/connection/manager-diag.go
@@ -0,0 +1,995 @@
+/*
+ * Copyright (C) 2024 The "MysteriumNetwork/node" Authors.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package connection
+
+import (
+ "context"
+ "encoding/json"
+ "errors"
+ "fmt"
+ "math/big"
+ "sync"
+ "time"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/rs/zerolog/log"
+ "golang.org/x/time/rate"
+
+ "github.com/mysteriumnetwork/node/config"
+ "github.com/mysteriumnetwork/node/core/connection/connectionstate"
+ "github.com/mysteriumnetwork/node/core/discovery/proposal"
+ "github.com/mysteriumnetwork/node/core/ip"
+ "github.com/mysteriumnetwork/node/core/location"
+ "github.com/mysteriumnetwork/node/core/quality"
+ "github.com/mysteriumnetwork/node/eventbus"
+ "github.com/mysteriumnetwork/node/firewall"
+ "github.com/mysteriumnetwork/node/identity"
+ "github.com/mysteriumnetwork/node/market"
+ "github.com/mysteriumnetwork/node/p2p"
+ "github.com/mysteriumnetwork/node/pb"
+ "github.com/mysteriumnetwork/node/session"
+ "github.com/mysteriumnetwork/node/session/connectivity"
+ "github.com/mysteriumnetwork/node/trace"
+)
+
+type conn struct {
+ // These are populated by Connect at runtime.
+ ctx context.Context
+ ctxLock sync.RWMutex
+ status connectionstate.Status
+ statusLock sync.RWMutex
+ cleanupLock sync.Mutex
+ cleanup []func() error
+ cleanupAfterDisconnect []func() error
+ cleanupFinished chan struct{}
+ cleanupFinishedLock sync.Mutex
+ acknowledge func()
+ cancel func()
+ channel p2p.Channel
+
+ preReconnect func()
+ postReconnect func()
+
+ discoLock sync.Mutex
+ connectOptions ConnectOptions
+
+ activeConnection Connection
+ // statsTracker statsTracker
+
+ resChannel chan interface{}
+
+ uuid string
+}
+
+type diagConnectionManager struct {
+ // These are passed on creation.
+ paymentEngineFactory PaymentEngineFactory
+ newConnection Creator
+ eventBus eventbus.EventBus
+ ipResolver ip.Resolver
+ locationResolver location.OriginResolver
+ config Config
+ statsReportInterval time.Duration
+ validator validator
+ p2pDialer p2p.Dialer
+ timeGetter TimeGetter
+
+ // populated by Connect at runtime.
+ connsMu sync.Mutex
+ conns map[string]*conn
+
+ ratelimiter *rate.Limiter
+}
+
+// NewDiagManager creates connection manager with given dependencies
+func NewDiagManager(
+ paymentEngineFactory PaymentEngineFactory,
+ connectionCreator Creator,
+ eventBus eventbus.EventBus,
+ ipResolver ip.Resolver,
+ locationResolver location.OriginResolver,
+ config Config,
+ statsReportInterval time.Duration,
+ validator validator,
+ p2pDialer p2p.Dialer,
+ preReconnect, postReconnect func(),
+) *diagConnectionManager {
+
+ m := &diagConnectionManager{
+ conns: make(map[string]*conn),
+
+ newConnection: connectionCreator,
+ eventBus: eventBus,
+ paymentEngineFactory: paymentEngineFactory,
+ ipResolver: ipResolver,
+ locationResolver: locationResolver,
+ config: config,
+ statsReportInterval: statsReportInterval,
+ validator: validator,
+ p2pDialer: p2pDialer,
+ timeGetter: time.Now,
+
+ ratelimiter: rate.NewLimiter(rate.Every(1000*time.Millisecond), 1),
+ }
+
+ m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold)
+
+ return m
+}
+
+func (m *diagConnectionManager) chainID() int64 {
+ return config.GetInt64(config.FlagChainID)
+}
+
+func (m *diagConnectionManager) HasConnection(providerID string) bool {
+ m.connsMu.Lock()
+ defer m.connsMu.Unlock()
+
+ _, ok := m.conns[providerID]
+ return ok
+}
+
+func (m *diagConnectionManager) GetReadyChan(providerID string) chan interface{} {
+ m.connsMu.Lock()
+ defer m.connsMu.Unlock()
+
+ con, ok := m.conns[providerID]
+ if ok {
+ return con.resChannel
+ }
+ return nil
+}
+
+func (m *diagConnectionManager) Connect(consumerID identity.Identity, hermesID common.Address, proposalLookup ProposalLookup, params ConnectParams) (err error) {
+ var sessionID session.ID
+
+ ctx := context.Background()
+ err = m.ratelimiter.Wait(ctx) // This is a blocking call. Honors the rate limit
+ if err != nil {
+ log.Error().Msgf("ratelimiter.Wait: %s", err)
+ return err
+ }
+
+ proposal, err := proposalLookup()
+ if err != nil {
+ return fmt.Errorf("failed to lookup proposal: %w", err)
+ }
+
+ tracer := trace.NewTracer("Consumer whole Connect")
+ defer func() {
+ traceResult := tracer.Finish(m.eventBus, string(sessionID))
+ log.Debug().Msgf("Consumer connection trace: %s", traceResult)
+ }()
+
+ log.Error().Msgf("Connect > %v", proposal.ProviderID)
+ uuid := proposal.ProviderID
+
+ m.connsMu.Lock()
+ con, ok := m.conns[uuid]
+ if !ok {
+ con = new(conn)
+ con.status.State = connectionstate.NotConnected
+ con.uuid = uuid
+ m.conns[uuid] = con
+ }
+ m.connsMu.Unlock()
+
+ removeConnection := func() {
+ m.connsMu.Lock()
+ defer m.connsMu.Unlock()
+ delete(m.conns, uuid)
+ }
+
+ // make sure cache is cleared when connect terminates at any stage as part of disconnect
+ // we assume that IPResolver might be used / cache IP before connect
+ m.addCleanup(con, func() error {
+ m.clearIPCache()
+ return nil
+ })
+
+ if m.Status().State != connectionstate.NotConnected {
+ removeConnection()
+ return ErrAlreadyExists
+ }
+
+ prc := m.priceFromProposal(*proposal)
+
+ err = m.validator.Validate(m.chainID(), consumerID, prc)
+ if err != nil {
+ removeConnection()
+ return err
+ }
+
+ con.ctxLock.Lock()
+ con.ctx, con.cancel = context.WithCancel(context.Background())
+ con.ctxLock.Unlock()
+
+ m.statusConnecting(con, consumerID, hermesID, *proposal)
+ defer func() {
+ if err != nil {
+ log.Err(err).Msg("Connect failed, disconnecting")
+ m.disconnect(con)
+ }
+ }()
+
+ con.connectOptions = ConnectOptions{
+ ConsumerID: consumerID,
+ HermesID: hermesID,
+ Proposal: *proposal,
+ ProposalLookup: proposalLookup,
+ Params: params,
+ }
+
+ con.activeConnection, err = m.newConnection(proposal.ServiceType)
+ if err != nil {
+ removeConnection()
+ return err
+ }
+
+ sessionID, err = m.initSession(con, tracer, prc)
+ if err != nil {
+ removeConnection()
+ return err
+ }
+
+ originalPublicIP := m.getPublicIP()
+
+ err = m.startConnection(con, m.currentCtx(con), con.activeConnection, con.activeConnection.Start, con.connectOptions, tracer)
+ if err != nil {
+ removeConnection()
+ return m.handleStartError(con, sessionID, err)
+ }
+
+ err = m.waitForConnectedState(con, con.activeConnection.State())
+ if err != nil {
+ removeConnection()
+ return m.handleStartError(con, sessionID, err)
+ }
+
+ //m.statsTracker = newStatsTracker(m.eventBus, m.statsReportInterval)
+ //go m.statsTracker.start(m, m.activeConnection)
+ m.addCleanup(con, func() error {
+ log.Trace().Msg("Cleaning: stopping statistics publisher")
+ defer log.Trace().Msg("Cleaning: stopping statistics publisher DONE")
+ //m.statsTracker.stop()
+
+ removeConnection()
+ return nil
+ })
+
+ con.resChannel = make(chan interface{})
+ go Diag(m, con, proposal.ProviderID)
+
+ go m.consumeConnectionStates(con, con.activeConnection.State())
+ go m.checkSessionIP(con, con.channel, con.connectOptions.ConsumerID, con.connectOptions.SessionID, originalPublicIP)
+
+ return nil
+}
+
+func (m *diagConnectionManager) autoReconnect(con *conn) (err error) {
+ var sessionID session.ID
+
+ tracer := trace.NewTracer("Consumer whole autoReconnect")
+ defer func() {
+ traceResult := tracer.Finish(m.eventBus, string(sessionID))
+ log.Debug().Msgf("Consumer connection trace: %s", traceResult)
+ }()
+
+ proposal, err := con.connectOptions.ProposalLookup()
+ if err != nil {
+ return fmt.Errorf("failed to lookup proposal: %w", err)
+ }
+
+ con.connectOptions.Proposal = *proposal
+
+ sessionID, err = m.initSession(con, tracer, m.priceFromProposal(con.connectOptions.Proposal))
+ if err != nil {
+ return err
+ }
+
+ err = m.startConnection(con, m.currentCtx(con), con.activeConnection, con.activeConnection.Reconnect, con.connectOptions, tracer)
+ if err != nil {
+ return m.handleStartError(con, sessionID, err)
+ }
+
+ return nil
+}
+
+func (m *diagConnectionManager) priceFromProposal(proposal proposal.PricedServiceProposal) market.Price {
+ p := market.Price{
+ PricePerHour: proposal.Price.PricePerHour,
+ PricePerGiB: proposal.Price.PricePerGiB,
+ }
+
+ if config.GetBool(config.FlagPaymentsDuringSessionDebug) {
+ log.Info().Msg("Payments debug bas been enabled, will use absurd amounts for the proposal price")
+ amount := config.GetUInt64(config.FlagPaymentsAmountDuringSessionDebug)
+ if amount == 0 {
+ amount = 5000000000000000000
+ }
+
+ p = market.Price{
+ PricePerHour: new(big.Int).SetUint64(amount),
+ PricePerGiB: new(big.Int).SetUint64(amount),
+ }
+ }
+
+ return p
+}
+
+func (m *diagConnectionManager) initSession(con *conn, tracer *trace.Tracer, prc market.Price) (sessionID session.ID, err error) {
+
+ err = m.createP2PChannel(con, con.connectOptions, tracer)
+ if err != nil {
+ return sessionID, fmt.Errorf("could not create p2p channel during connect: %w", err)
+ }
+
+ con.connectOptions.ProviderNATConn = con.channel.ServiceConn()
+ con.connectOptions.ChannelConn = con.channel.Conn()
+
+ paymentSession, err := m.paymentLoop(con, con.connectOptions, prc)
+ if err != nil {
+ return sessionID, err
+ }
+
+ sessionDTO, err := m.createP2PSession(con, con.activeConnection, con.connectOptions, tracer, prc)
+ sessionID = session.ID(sessionDTO.GetID())
+ if err != nil {
+ m.sendSessionStatus(con, con.channel, con.connectOptions.ConsumerID, sessionID, connectivity.StatusSessionEstablishmentFailed, err)
+ return sessionID, err
+ }
+
+ traceStart := tracer.StartStage("Consumer session creation (start)")
+ go m.keepAliveLoop(con, con.channel, sessionID)
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.SessionID = sessionID
+ })
+ m.publishSessionCreate(con, sessionID)
+ paymentSession.SetSessionID(string(sessionID))
+ tracer.EndStage(traceStart)
+
+ con.connectOptions.SessionID = sessionID
+ con.connectOptions.SessionConfig = sessionDTO.GetConfig()
+
+ return sessionID, nil
+}
+
+func (m *diagConnectionManager) handleStartError(con *conn, sessionID session.ID, err error) error {
+
+ if errors.Is(err, context.Canceled) {
+ return ErrConnectionCancelled
+ }
+ m.addCleanupAfterDisconnect(con, func() error {
+ return m.sendSessionStatus(con, con.channel, con.connectOptions.ConsumerID, sessionID, connectivity.StatusConnectionFailed, err)
+ })
+ m.publishStateEvent(con, connectionstate.StateConnectionFailed)
+
+ log.Info().Err(err).Msg("Cancelling connection initiation: ")
+ m.Cancel()
+ return err
+}
+
+func (m *diagConnectionManager) clearIPCache() {
+ if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) {
+ return
+ }
+
+ if cr, ok := m.ipResolver.(*ip.CachedResolver); ok {
+ cr.ClearCache()
+ }
+}
+
+// checkSessionIP checks if IP has changed after connection was established.
+func (m *diagConnectionManager) checkSessionIP(con *conn, channel p2p.Channel, consumerID identity.Identity, sessionID session.ID, originalPublicIP string) {
+ if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) {
+ return
+ }
+
+ for i := 1; i <= m.config.IPCheck.MaxAttempts; i++ {
+ // Skip check if not connected. This may happen when context was canceled via Disconnect.
+ if m.Status().State != connectionstate.Connected {
+ return
+ }
+
+ newPublicIP := m.getPublicIP()
+ // If ip is changed notify peer that connection is successful.
+ if originalPublicIP != newPublicIP {
+ m.sendSessionStatus(con, channel, consumerID, sessionID, connectivity.StatusConnectionOk, nil)
+ return
+ }
+
+ // Notify peer and quality oracle that ip is not changed after tunnel connection was established.
+ if i == m.config.IPCheck.MaxAttempts {
+ m.sendSessionStatus(con, channel, consumerID, sessionID, connectivity.StatusSessionIPNotChanged, nil)
+ m.publishStateEvent(con, connectionstate.StateIPNotChanged)
+ return
+ }
+
+ time.Sleep(m.config.IPCheck.SleepDurationAfterCheck)
+ }
+}
+
+// sendSessionStatus sends session connectivity status to other peer.
+func (m *diagConnectionManager) sendSessionStatus(con *conn, channel p2p.ChannelSender, consumerID identity.Identity, sessionID session.ID, code connectivity.StatusCode, errDetails error) error {
+ var errDetailsMsg string
+ if errDetails != nil {
+ errDetailsMsg = errDetails.Error()
+ }
+
+ sessionStatus := &pb.SessionStatus{
+ ConsumerID: consumerID.Address,
+ SessionID: string(sessionID),
+ Code: uint32(code),
+ Message: errDetailsMsg,
+ }
+
+ log.Debug().Msgf("Sending session status P2P message to %q: %s", p2p.TopicSessionStatus, sessionStatus.String())
+
+ ctx, cancel := context.WithTimeout(m.currentCtx(con), 20*time.Second)
+ defer cancel()
+ _, err := channel.Send(ctx, p2p.TopicSessionStatus, p2p.ProtoMessage(sessionStatus))
+ if err != nil {
+ return fmt.Errorf("could not send p2p session status message: %w", err)
+ }
+
+ return nil
+}
+
+func (m *diagConnectionManager) getPublicIP() string {
+ currentPublicIP, err := m.ipResolver.GetPublicIP()
+ if err != nil {
+ log.Error().Err(err).Msg("Could not get current public IP")
+ return ""
+ }
+ return currentPublicIP
+}
+
+func (m *diagConnectionManager) paymentLoop(con *conn, opts ConnectOptions, price market.Price) (PaymentIssuer, error) {
+
+ payments, err := m.paymentEngineFactory(con.uuid, con.channel, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.HermesID, opts.Proposal, price)
+ if err != nil {
+ return nil, err
+ }
+ m.addCleanup(con, func() error {
+ log.Trace().Msg("Cleaning: payments")
+ defer log.Trace().Msg("Cleaning: payments DONE")
+ payments.Stop()
+ return nil
+ })
+
+ go func() {
+ err := payments.Start()
+ if err != nil {
+ log.Error().Err(err).Msg("Payment error")
+
+ if config.GetBool(config.FlagKeepConnectedOnFail) {
+ m.statusOnHold(con)
+ } else {
+ err = m.Disconnect()
+ if err != nil {
+ log.Error().Err(err).Msg("Could not disconnect gracefully")
+ }
+ }
+ }
+ }()
+ return payments, nil
+}
+
+func (m *diagConnectionManager) cleanConnection(con *conn) {
+ con.cleanupLock.Lock()
+ defer con.cleanupLock.Unlock()
+
+ for i := len(con.cleanup) - 1; i >= 0; i-- {
+ log.Trace().Msgf("Connection cleaning up: (%v/%v)", i+1, len(con.cleanup))
+ err := con.cleanup[i]()
+ if err != nil {
+ log.Warn().Err(err).Msg("Cleanup error")
+ }
+ }
+ con.cleanup = nil
+}
+
+func (m *diagConnectionManager) cleanAfterDisconnect(con *conn) {
+ con.cleanupLock.Lock()
+ defer con.cleanupLock.Unlock()
+
+ for i := len(con.cleanupAfterDisconnect) - 1; i >= 0; i-- {
+ log.Trace().Msgf("Connection cleaning up (after disconnect): (%v/%v)", i+1, len(con.cleanupAfterDisconnect))
+ err := con.cleanupAfterDisconnect[i]()
+ if err != nil {
+ log.Warn().Err(err).Msg("Cleanup error")
+ }
+ }
+ con.cleanupAfterDisconnect = nil
+}
+
+func (m *diagConnectionManager) createP2PChannel(con *conn, opts ConnectOptions, tracer *trace.Tracer) error {
+ trace := tracer.StartStage("Consumer P2P channel creation")
+ defer tracer.EndStage(trace)
+
+ contactDef, err := p2p.ParseContact(opts.Proposal.Contacts)
+ if err != nil {
+ return fmt.Errorf("provider does not support p2p communication: %w", err)
+ }
+
+ timeoutCtx, cancel := context.WithTimeout(m.currentCtx(con), p2pDialTimeout)
+ defer cancel()
+
+ // TODO register all handlers before channel read/write loops
+ channel, err := m.p2pDialer.Dial(timeoutCtx, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.Proposal.ServiceType, contactDef, tracer)
+ if err != nil {
+ return fmt.Errorf("p2p dialer failed: %w", err)
+ }
+ m.addCleanupAfterDisconnect(con, func() error {
+ log.Trace().Msg("Cleaning: closing P2P communication channel")
+ defer log.Trace().Msg("Cleaning: P2P communication channel DONE")
+
+ return channel.Close()
+ })
+
+ con.channel = channel
+ return nil
+}
+
+func (m *diagConnectionManager) addCleanupAfterDisconnect(con *conn, fn func() error) {
+ con.cleanupLock.Lock()
+ defer con.cleanupLock.Unlock()
+ con.cleanupAfterDisconnect = append(con.cleanupAfterDisconnect, fn)
+}
+
+func (m *diagConnectionManager) addCleanup(con *conn, fn func() error) {
+ con.cleanupLock.Lock()
+ defer con.cleanupLock.Unlock()
+ con.cleanup = append(con.cleanup, fn)
+}
+
+func (m *diagConnectionManager) createP2PSession(con *conn, c Connection, opts ConnectOptions, tracer *trace.Tracer, requestedPrice market.Price) (*pb.SessionResponse, error) {
+ trace := tracer.StartStage("Consumer session creation")
+ defer tracer.EndStage(trace)
+
+ sessionCreateConfig, err := c.GetConfig()
+ if err != nil {
+ return nil, fmt.Errorf("could not get session config: %w", err)
+ }
+
+ config, err := json.Marshal(sessionCreateConfig)
+ if err != nil {
+ return nil, fmt.Errorf("could not marshal session config: %w", err)
+ }
+
+ sessionRequest := &pb.SessionRequest{
+ Consumer: &pb.ConsumerInfo{
+ Id: opts.ConsumerID.Address,
+ HermesID: opts.HermesID.Hex(),
+ PaymentVersion: "v3",
+ Location: &pb.LocationInfo{
+ Country: m.Status().ConsumerLocation.Country,
+ },
+ Pricing: &pb.Pricing{
+ PerGib: requestedPrice.PricePerGiB.Bytes(),
+ PerHour: requestedPrice.PricePerHour.Bytes(),
+ },
+ },
+ ProposalID: opts.Proposal.ID,
+ Config: config,
+ }
+ log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionCreate, sessionRequest.String())
+ ctx, cancel := context.WithTimeout(m.currentCtx(con), 20*time.Second)
+ defer cancel()
+ res, err := con.channel.Send(ctx, p2p.TopicSessionCreate, p2p.ProtoMessage(sessionRequest))
+ if err != nil {
+ return nil, fmt.Errorf("could not send p2p session create request: %w", err)
+ }
+
+ var sessionResponse pb.SessionResponse
+ err = res.UnmarshalProto(&sessionResponse)
+ if err != nil {
+ return nil, fmt.Errorf("could not unmarshal session reply to proto: %w", err)
+ }
+
+ channel := con.channel
+ con.acknowledge = func() {
+ pc := &pb.SessionInfo{
+ ConsumerID: opts.ConsumerID.Address,
+ SessionID: sessionResponse.GetID(),
+ }
+ log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionAcknowledge, pc.String())
+ ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second)
+ defer cancel()
+ _, err := channel.Send(ctx, p2p.TopicSessionAcknowledge, p2p.ProtoMessage(pc))
+ if err != nil {
+ log.Warn().Err(err).Msg("Acknowledge failed")
+ }
+ }
+ m.addCleanupAfterDisconnect(con, func() error {
+ log.Trace().Msg("Cleaning: requesting session destroy")
+ defer log.Trace().Msg("Cleaning: requesting session destroy DONE")
+
+ sessionDestroy := &pb.SessionInfo{
+ ConsumerID: opts.ConsumerID.Address,
+ SessionID: sessionResponse.GetID(),
+ }
+
+ log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionDestroy, sessionDestroy.String())
+ ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second)
+ defer cancel()
+ _, err := con.channel.Send(ctx, p2p.TopicSessionDestroy, p2p.ProtoMessage(sessionDestroy))
+ if err != nil {
+ return fmt.Errorf("could not send session destroy request: %w", err)
+ }
+
+ return nil
+ })
+
+ return &sessionResponse, nil
+}
+
+func (m *diagConnectionManager) publishSessionCreate(con *conn, sessionID session.ID) {
+ sessionInfo := m.Status()
+ // avoid printing IP address in logs
+ sessionInfo.ConsumerLocation.IP = ""
+
+ m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{
+ Status: connectionstate.SessionCreatedStatus,
+ SessionInfo: sessionInfo,
+ })
+
+ m.addCleanup(con, func() error {
+ log.Trace().Msg("Cleaning: publishing session ended status")
+ defer log.Trace().Msg("Cleaning: publishing session ended status DONE")
+
+ sessionInfo := m.Status()
+ // avoid printing IP address in logs
+ sessionInfo.ConsumerLocation.IP = ""
+
+ m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{
+ Status: connectionstate.SessionEndedStatus,
+ SessionInfo: sessionInfo,
+ })
+ return nil
+ })
+}
+
+func (m *diagConnectionManager) startConnection(con *conn, ctx context.Context, conn Connection, start ConnectionStart, connectOptions ConnectOptions, tracer *trace.Tracer) (err error) {
+ trace := tracer.StartStage("Consumer start connection")
+ defer tracer.EndStage(trace)
+
+ if err = start(ctx, connectOptions); err != nil {
+ return err
+ }
+ m.addCleanup(con, func() error {
+ log.Trace().Msg("Cleaning: stopping connection")
+ defer log.Trace().Msg("Cleaning: stopping connection DONE")
+ conn.Stop()
+ return nil
+ })
+
+ err = m.setupTrafficBlock(con, connectOptions.Params.DisableKillSwitch)
+ if err != nil {
+ return err
+ }
+
+ // Clear IP cache so session IP check can report that IP has really changed.
+ m.clearIPCache()
+
+ return nil
+}
+
+func (m *diagConnectionManager) Status() connectionstate.Status {
+ log.Debug().Msg("Status() - not used")
+ return connectionstate.Status{State: connectionstate.NotConnected}
+}
+
+func (m *diagConnectionManager) UUID() string {
+ log.Debug().Msg("UUID() - not used")
+ return ""
+}
+
+func (m *diagConnectionManager) Stats() connectionstate.Statistics {
+ log.Debug().Msg("Stats() - not used")
+ return connectionstate.Statistics{}
+}
+
+func (m *diagConnectionManager) setStatus(con *conn, delta func(status *connectionstate.Status)) {
+ con.statusLock.Lock()
+ stateWas := con.status.State
+
+ delta(&con.status)
+
+ state := con.status.State
+ con.statusLock.Unlock()
+
+ if state != stateWas {
+ log.Info().Msgf("Connection state: %v -> %v", stateWas, state)
+ m.publishStateEvent(con, state)
+ }
+}
+
+func (m *diagConnectionManager) statusConnecting(con *conn, consumerID identity.Identity, accountantID common.Address, proposal proposal.PricedServiceProposal) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ *status = connectionstate.Status{
+ StartedAt: m.timeGetter(),
+ ConsumerID: consumerID,
+ ConsumerLocation: m.locationResolver.GetOrigin(),
+ HermesID: accountantID,
+ Proposal: proposal,
+ State: connectionstate.Connecting,
+ }
+ })
+}
+
+func (m *diagConnectionManager) statusConnected(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.Connected
+ })
+}
+
+func (m *diagConnectionManager) statusReconnecting(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.Reconnecting
+ })
+}
+
+func (m *diagConnectionManager) statusNotConnected(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.NotConnected
+ })
+}
+
+func (m *diagConnectionManager) statusDisconnecting(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.Disconnecting
+ })
+}
+
+func (m *diagConnectionManager) statusCanceled(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.Canceled
+ })
+}
+
+func (m *diagConnectionManager) statusOnHold(con *conn) {
+ m.setStatus(con, func(status *connectionstate.Status) {
+ status.State = connectionstate.StateOnHold
+ })
+}
+
+func (m *diagConnectionManager) Cancel() {
+ log.Error().Msg("Cancel() - not used")
+}
+
+func (m *diagConnectionManager) DisconnectSingle(con *conn) error {
+ m.statusDisconnecting(con)
+ m.disconnect(con)
+ return nil
+}
+
+func (m *diagConnectionManager) Disconnect() error {
+ log.Error().Msg("Disconnect() - not used")
+ return nil
+}
+
+func (m *diagConnectionManager) CheckChannel(con *conn, ctx context.Context) error {
+ if err := m.sendKeepAlivePing(ctx, con.channel, m.Status().SessionID); err != nil {
+ return fmt.Errorf("keep alive ping failed: %w", err)
+ }
+ return nil
+}
+
+func (m *diagConnectionManager) disconnect(con *conn) {
+ con.discoLock.Lock()
+ defer con.discoLock.Unlock()
+
+ con.cleanupFinishedLock.Lock()
+ defer con.cleanupFinishedLock.Unlock()
+ con.cleanupFinished = make(chan struct{})
+ defer close(con.cleanupFinished)
+
+ con.ctxLock.Lock()
+ con.cancel()
+ con.ctxLock.Unlock()
+
+ m.cleanConnection(con)
+ m.statusNotConnected(con)
+
+ m.cleanAfterDisconnect(con)
+}
+
+func (m *diagConnectionManager) waitForConnectedState(con *conn, stateChannel <-chan connectionstate.State) error {
+ log.Debug().Msg("waiting for connected state")
+ for {
+ select {
+ case state, more := <-stateChannel:
+ if !more {
+ return ErrConnectionFailed
+ }
+
+ switch state {
+ case connectionstate.Connected:
+ log.Debug().Msg("Connected started event received")
+ if con.acknowledge != nil {
+ go con.acknowledge()
+ }
+ m.onStateChanged(con, state)
+ return nil
+ default:
+ m.onStateChanged(con, state)
+ }
+ case <-m.currentCtx(con).Done():
+ return m.currentCtx(con).Err()
+ }
+ }
+}
+
+func (m *diagConnectionManager) consumeConnectionStates(con *conn, stateChannel <-chan connectionstate.State) {
+ for state := range stateChannel {
+ m.onStateChanged(con, state)
+ }
+}
+
+func (m *diagConnectionManager) onStateChanged(con *conn, state connectionstate.State) {
+ log.Debug().Msgf("Connection state received: %s", state)
+
+ // React just to certain stains from connection. Because disconnect happens in connectionWaiter
+ switch state {
+ case connectionstate.Connected:
+ m.statusConnected(con)
+ case connectionstate.Reconnecting:
+ m.statusReconnecting(con)
+ }
+}
+
+func (m *diagConnectionManager) setupTrafficBlock(con *conn, disableKillSwitch bool) error {
+ if disableKillSwitch {
+ return nil
+ }
+
+ outboundIP, err := m.ipResolver.GetOutboundIP()
+ if err != nil {
+ return err
+ }
+
+ removeRule, err := firewall.BlockNonTunnelTraffic(firewall.Session, outboundIP)
+ if err != nil {
+ return err
+ }
+ m.addCleanup(con, func() error {
+ log.Trace().Msg("Cleaning: traffic block rule")
+ defer log.Trace().Msg("Cleaning: traffic block rule DONE")
+
+ removeRule()
+
+ return nil
+ })
+ return nil
+}
+
+func (m *diagConnectionManager) reconnectOnHold(state connectionstate.AppEventConnectionState) {
+ if state.State != connectionstate.StateOnHold || !config.GetBool(config.FlagAutoReconnect) {
+ return
+ }
+
+ con, ok := m.conns[state.UUID]
+ if !ok {
+ return
+ }
+
+ if con.channel != nil {
+ con.channel.Close()
+ }
+
+ con.preReconnect()
+ m.clearIPCache()
+
+ for err := m.autoReconnect(con); err != nil; err = m.autoReconnect(con) {
+ select {
+ case <-m.currentCtx(con).Done():
+ log.Info().Err(m.currentCtx(con).Err()).Msg("Stopping reconnect")
+ return
+ default:
+ log.Error().Err(err).Msg("Failed to reconnect active session, will try again")
+ }
+ }
+ con.postReconnect()
+}
+
+func (m *diagConnectionManager) publishStateEvent(con *conn, state connectionstate.State) {
+ sessionInfo := m.Status()
+ // avoid printing IP address in logs
+ sessionInfo.ConsumerLocation.IP = ""
+
+ m.eventBus.Publish(connectionstate.AppTopicConnectionState, connectionstate.AppEventConnectionState{
+ UUID: con.uuid,
+ State: state,
+ SessionInfo: sessionInfo,
+ })
+}
+
+func (m *diagConnectionManager) keepAliveLoop(con *conn, channel p2p.Channel, sessionID session.ID) {
+ // Register handler for handling p2p keep alive pings from provider.
+ channel.Handle(p2p.TopicKeepAlive, func(c p2p.Context) error {
+ var ping pb.P2PKeepAlivePing
+ if err := c.Request().UnmarshalProto(&ping); err != nil {
+ return err
+ }
+
+ log.Debug().Msgf("Received p2p keepalive ping with SessionID=%s from %s", ping.SessionID, c.PeerID().ToCommonAddress())
+ return c.OK()
+ })
+
+ // Send pings to provider.
+ var errCount int
+ for {
+ select {
+ case <-m.currentCtx(con).Done():
+ log.Debug().Msgf("Stopping p2p keepalive: %v", m.currentCtx(con).Err())
+ return
+ case <-time.After(m.config.KeepAlive.SendInterval):
+ ctx, cancel := context.WithTimeout(context.Background(), m.config.KeepAlive.SendTimeout)
+ if err := m.sendKeepAlivePing(ctx, channel, sessionID); err != nil {
+ log.Err(err).Msgf("Failed to send p2p keepalive ping. SessionID=%s", sessionID)
+ errCount++
+ if errCount == m.config.KeepAlive.MaxSendErrCount {
+ log.Error().Msgf("Max p2p keepalive err count reached, disconnecting. SessionID=%s", sessionID)
+ if config.GetBool(config.FlagKeepConnectedOnFail) {
+ m.statusOnHold(con)
+ } else {
+ //m.Disconnect()
+ log.Error().Msgf("Max p2p keepalive err count reached, disconnecting. SessionID=%s >>>>>>>>>", sessionID)
+ m.DisconnectSingle(con)
+ }
+ cancel()
+ return
+ }
+ } else {
+ errCount = 0
+ }
+ cancel()
+ }
+ }
+}
+
+func (m *diagConnectionManager) sendKeepAlivePing(ctx context.Context, channel p2p.Channel, sessionID session.ID) error {
+ msg := &pb.P2PKeepAlivePing{
+ SessionID: string(sessionID),
+ }
+
+ start := time.Now()
+ _, err := channel.Send(ctx, p2p.TopicKeepAlive, p2p.ProtoMessage(msg))
+ if err != nil {
+ return err
+ }
+
+ m.eventBus.Publish(quality.AppTopicConsumerPingP2P, quality.PingEvent{
+ SessionID: string(sessionID),
+ Duration: time.Since(start),
+ })
+
+ return nil
+}
+
+func (m *diagConnectionManager) currentCtx(con *conn) context.Context {
+ con.ctxLock.RLock()
+ defer con.ctxLock.RUnlock()
+
+ return con.ctx
+}
+
+func (m *diagConnectionManager) Reconnect() {
+ log.Error().Msg("Reconnect - not used")
+}
diff --git a/core/connection/pinger.go b/core/connection/pinger.go
new file mode 100644
index 0000000000..202ccf70bb
--- /dev/null
+++ b/core/connection/pinger.go
@@ -0,0 +1,41 @@
+/*
+ * Copyright (C) 2024 The "MysteriumNetwork/node" Authors.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package connection
+
+import (
+ "github.com/mysteriumnetwork/node/core/quality"
+ "github.com/rs/zerolog/log"
+)
+
+// Diag is used to start provider check
+func Diag(cm *diagConnectionManager, con *conn, providerID string) {
+ c, ok := con.activeConnection.(ConnectionDiag)
+ res := error(nil)
+ if ok {
+ log.Debug().Msgf("Check provider> %v", providerID)
+
+ res = c.Diag()
+ cm.DisconnectSingle(con)
+ }
+ ev := quality.DiagEvent{ProviderID: providerID, Result: res == nil}
+ if res != nil {
+ ev.Error = res
+ }
+ con.resChannel <- ev
+ close(con.resChannel)
+}
diff --git a/core/node/options.go b/core/node/options.go
index 04085ff27e..7ecb9cd22a 100644
--- a/core/node/options.go
+++ b/core/node/options.go
@@ -81,8 +81,9 @@ type Options struct {
Payments OptionsPayments
- Consumer bool
- Mobile bool
+ Consumer bool
+ Mobile bool
+ ProvChecker bool
SwarmDialerDNSHeadstart time.Duration
PilvytisAddress string
@@ -205,6 +206,7 @@ func GetOptions() *Options {
SSE: OptionsSSE{
Enabled: config.GetBool(config.FlagSSEEnable),
},
+ ProvChecker: config.GetBool(config.FlagProvCheckerMode),
}
}
diff --git a/core/quality/metrics.go b/core/quality/metrics.go
index 51d3489607..ae85bc2c47 100644
--- a/core/quality/metrics.go
+++ b/core/quality/metrics.go
@@ -102,6 +102,13 @@ type PingEvent struct {
Duration time.Duration `json:"duration"`
}
+// DiagEvent represents provider check result event
+type DiagEvent struct {
+ ProviderID string
+ Result bool
+ Error error
+}
+
const (
// AppTopicConnectionEvents represents event bus topic for the connection events.
AppTopicConnectionEvents = "connection_events"
diff --git a/go.mod b/go.mod
index 7de7350f57..eba0ee5878 100644
--- a/go.mod
+++ b/go.mod
@@ -153,11 +153,15 @@ require (
github.com/imdario/mergo v0.3.12 // indirect
github.com/ipfs/go-cid v0.4.1 // indirect
github.com/ipfs/go-log/v2 v2.5.1 // indirect
+ github.com/jackc/pgpassfile v1.0.0 // indirect
+ github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
+ github.com/jackc/pgx/v5 v5.5.5 // indirect
+ github.com/jackc/puddle/v2 v2.2.1 // indirect
github.com/jackpal/go-nat-pmp v1.0.2 // indirect
github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect
github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect
github.com/jinzhu/gorm v1.9.2 // indirect
- github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a // indirect
+ github.com/jinzhu/inflection v1.0.0 // indirect
github.com/jinzhu/now v1.1.5 // indirect
github.com/jmespath/go-jmespath v0.4.0 // indirect
github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect
@@ -183,6 +187,7 @@ require (
github.com/mattn/go-colorable v0.1.13 // indirect
github.com/mattn/go-isatty v0.0.20 // indirect
github.com/mattn/go-pointer v0.0.1 // indirect
+ github.com/mattn/go-sqlite3 v1.14.22 // indirect
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
github.com/mdlayher/genetlink v1.1.0 // indirect
github.com/mdlayher/netlink v1.4.2 // indirect
@@ -265,6 +270,9 @@ require (
gopkg.in/intercom/intercom-go.v2 v2.0.0-20210504094731-2bd1af0ce4b2 // indirect
gopkg.in/warnings.v0 v0.1.2 // indirect
gopkg.in/yaml.v3 v3.0.1 // indirect
+ gorm.io/driver/postgres v1.5.9 // indirect
+ gorm.io/driver/sqlite v1.5.6 // indirect
+ gorm.io/gorm v1.25.11 // indirect
honnef.co/go/tools v0.4.2 // indirect
lukechampine.com/blake3 v1.2.1 // indirect
rsc.io/tmplfunc v0.0.3 // indirect
diff --git a/go.sum b/go.sum
index 5ecaaf07cb..2e9595a91b 100644
--- a/go.sum
+++ b/go.sum
@@ -926,6 +926,14 @@ github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ
github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps=
github.com/ipfs/go-log/v2 v2.5.1 h1:1XdUzF7048prq4aBjDQQ4SL5RxftpRGdXhNRwKSAlcY=
github.com/ipfs/go-log/v2 v2.5.1/go.mod h1:prSpmC1Gpllc9UYWxDiZDreBYw7zp4Iqp1kOLU9U5UI=
+github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM=
+github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk=
+github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM=
+github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw=
+github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A=
+github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk=
+github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4=
github.com/jackpal/gateway v1.0.6 h1:/MJORKvJEwNVldtGVJC2p2cwCnsSoLn3hl3zxmZT7tk=
github.com/jackpal/gateway v1.0.6/go.mod h1:lTpwd4ACLXmpyiCTRtfiNyVnUmqT9RivzCDQetPfnjA=
github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus=
@@ -943,6 +951,8 @@ github.com/jinzhu/gorm v1.9.2 h1:lCvgEaqe/HVE+tjAR2mt4HbbHAZsQOv3XAZiEZV37iw=
github.com/jinzhu/gorm v1.9.2/go.mod h1:Vla75njaFJ8clLU1W44h34PjIkijhjHIYnZxMqCdxqo=
github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k=
github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
+github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E=
+github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc=
github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ=
github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8=
github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg=
@@ -1074,6 +1084,8 @@ github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4
github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w=
github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o=
github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc=
+github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU=
+github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y=
github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
@@ -2247,6 +2259,14 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C
gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
+gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8=
+gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI=
+gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE=
+gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4=
+gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg=
+gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8=
+gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg=
+gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ=
gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw=
gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g=
grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o=
diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go
index 3657b8b00c..0a292591b2 100644
--- a/services/wireguard/connection/connection.go
+++ b/services/wireguard/connection/connection.go
@@ -86,6 +86,11 @@ func (c *Connection) State() <-chan connectionstate.State {
return c.stateCh
}
+// Diag is used to start provider check
+func (c *Connection) Diag() error {
+ return c.connectionEndpoint.Diag()
+}
+
// Statistics returns connection statistics channel.
func (c *Connection) Statistics() (connectionstate.Statistics, error) {
stats, err := c.connectionEndpoint.PeerStats()
diff --git a/services/wireguard/connection/connection_test.go b/services/wireguard/connection/connection_test.go
index cbc2930b6a..07af4e923e 100644
--- a/services/wireguard/connection/connection_test.go
+++ b/services/wireguard/connection/connection_test.go
@@ -158,6 +158,9 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur
func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) {
return wgcfg.Stats{LastHandshake: time.Now(), BytesSent: 10, BytesReceived: 11}, nil
}
+func (mce *mockConnectionEndpoint) Diag() error {
+ return nil
+}
type mockHandshakeWaiter struct {
err error
diff --git a/services/wireguard/endpoint.go b/services/wireguard/endpoint.go
index 8d2201b795..3085434720 100644
--- a/services/wireguard/endpoint.go
+++ b/services/wireguard/endpoint.go
@@ -34,4 +34,5 @@ type ConnectionEndpoint interface {
Config() (ServiceConfig, error)
InterfaceName() string
Stop() error
+ Diag() error
}
diff --git a/services/wireguard/endpoint/diagclient/client.go b/services/wireguard/endpoint/diagclient/client.go
new file mode 100644
index 0000000000..6e50d23759
--- /dev/null
+++ b/services/wireguard/endpoint/diagclient/client.go
@@ -0,0 +1,155 @@
+/*
+ * Copyright (C) 2024 The "MysteriumNetwork/node" Authors.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package diagclient
+
+import (
+ "bufio"
+ "errors"
+ "fmt"
+ "io"
+ "net/http"
+ "net/netip"
+ "strings"
+ "sync"
+ "time"
+
+ "github.com/rs/zerolog/log"
+ "golang.zx2c4.com/wireguard/conn"
+ "golang.zx2c4.com/wireguard/device"
+
+ "github.com/mysteriumnetwork/node/services/wireguard/endpoint/netstack"
+ "github.com/mysteriumnetwork/node/services/wireguard/endpoint/userspace"
+ "github.com/mysteriumnetwork/node/services/wireguard/wgcfg"
+)
+
+type client struct {
+ mu sync.Mutex
+ Device *device.Device
+ tnet *netstack.Net
+}
+
+// New create new WireGuard client testing the provider.
+func New() (*client, error) {
+ log.Error().Msg("Creating pinger wg client")
+ return &client{}, nil
+}
+
+func (c *client) ReConfigureDevice(config wgcfg.DeviceConfig) error {
+ return c.ConfigureDevice(config)
+}
+
+func (c *client) ConfigureDevice(cfg wgcfg.DeviceConfig) error {
+ localAddr, err := netip.ParseAddr(cfg.Subnet.IP.String())
+ if err != nil {
+ return fmt.Errorf("could not parse local addr: %w", err)
+ }
+ if len(cfg.DNS) == 0 {
+ return fmt.Errorf("DNS addr list is empty")
+ }
+ dnsAddr, err := netip.ParseAddr(cfg.DNS[0])
+ if err != nil {
+ return fmt.Errorf("could not parse DNS addr: %w", err)
+ }
+ tunnel, tnet, err := netstack.CreateNetTUN([]netip.Addr{localAddr}, []netip.Addr{dnsAddr}, device.DefaultMTU)
+ if err != nil {
+ return fmt.Errorf("failed to create netstack device %s: %w", cfg.IfaceName, err)
+ }
+
+ logger := device.NewLogger(device.LogLevelVerbose, fmt.Sprintf("(%s) ", cfg.IfaceName))
+ wgDevice := device.NewDevice(tunnel, conn.NewDefaultBind(), logger)
+
+ log.Info().Msg("Applying interface configuration")
+ if err := wgDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg.Encode()))); err != nil {
+ wgDevice.Close()
+ return fmt.Errorf("could not set device uapi config: %w", err)
+ }
+
+ log.Info().Msg("Bringing device up")
+
+ wgDevice.Up()
+
+ c.mu.Lock()
+ c.Device = wgDevice
+ c.mu.Unlock()
+ c.tnet = tnet
+
+ return nil
+}
+
+func (c *client) DestroyDevice(iface string) error {
+ return c.Close()
+}
+
+func (c *client) PeerStats(iface string) (wgcfg.Stats, error) {
+ deviceState, err := userspace.ParseUserspaceDevice(c.Device.IpcGetOperation)
+ if err != nil {
+ return wgcfg.Stats{}, fmt.Errorf("could not parse device state: %w", err)
+ }
+
+ stats, statErr := userspace.ParseDevicePeerStats(deviceState)
+ if statErr != nil {
+ err = statErr
+ log.Warn().Err(err).Msg("Failed to parse device stats, will try again")
+ } else {
+ return stats, nil
+ }
+
+ return wgcfg.Stats{}, fmt.Errorf("could not parse device state: %w", err)
+}
+
+func (c *client) Close() (err error) {
+ c.mu.Lock()
+ defer c.mu.Unlock()
+
+ log.Error().Err(err).Msg("Shutting down pinger ...")
+
+ if c.Device != nil {
+ go func() {
+ time.Sleep(5 * time.Second)
+ c.Device.Close()
+ }()
+ }
+ return nil
+}
+
+func (c *client) Diag() error {
+ client := http.Client{
+ Transport: &http.Transport{
+ DialContext: c.tnet.DialContext,
+ },
+ Timeout: 15 * time.Second,
+ }
+ resp, err := client.Get("http://107.173.23.19:8080/test")
+ if err != nil {
+ log.Error().Err(err).Msg("Get failed")
+ return err
+ }
+ defer resp.Body.Close()
+
+ body, err := io.ReadAll(resp.Body)
+ if err != nil {
+ log.Error().Err(err).Msg("Readall failed")
+ return err
+ }
+ if len(body) < 6 {
+ log.Error().Msg("Wrong length")
+ return errors.New("Wrong body length")
+ }
+
+ return nil
+}
diff --git a/services/wireguard/endpoint/endpoint.go b/services/wireguard/endpoint/endpoint.go
index 7b77489300..5dd43a0800 100644
--- a/services/wireguard/endpoint/endpoint.go
+++ b/services/wireguard/endpoint/endpoint.go
@@ -52,6 +52,14 @@ type connectionEndpoint struct {
wgClient WgClient
}
+func (ce *connectionEndpoint) Diag() error {
+ c, ok := ce.wgClient.(WgClientDiag)
+ if ok {
+ return c.Diag()
+ }
+ return nil
+}
+
// StartConsumerMode starts and configure wireguard network interface running in consumer mode.
func (ce *connectionEndpoint) StartConsumerMode(cfg wgcfg.DeviceConfig) error {
if err := ce.cleanAbandonedInterfaces(); err != nil {
diff --git a/services/wireguard/endpoint/wg_client.go b/services/wireguard/endpoint/wg_client.go
index b991bb2ab2..22dd36781e 100644
--- a/services/wireguard/endpoint/wg_client.go
+++ b/services/wireguard/endpoint/wg_client.go
@@ -24,6 +24,7 @@ import (
"github.com/rs/zerolog/log"
"github.com/mysteriumnetwork/node/config"
+ "github.com/mysteriumnetwork/node/services/wireguard/endpoint/diagclient"
"github.com/mysteriumnetwork/node/services/wireguard/endpoint/dvpnclient"
"github.com/mysteriumnetwork/node/services/wireguard/endpoint/kernelspace"
netstack_provider "github.com/mysteriumnetwork/node/services/wireguard/endpoint/netstack-provider"
@@ -43,6 +44,11 @@ type WgClient interface {
Close() error
}
+// WgClientDiag is a specialised WgClient interface for provider check
+type WgClientDiag interface {
+ Diag() error
+}
+
// WgClientFactory represents WireGuard client factory.
type WgClientFactory struct {
once sync.Once
@@ -56,6 +62,10 @@ func NewWGClientFactory() *WgClientFactory {
// NewWGClient returns a new wireguard client.
func (wcf *WgClientFactory) NewWGClient() (WgClient, error) {
+
+ if config.GetBool(config.FlagProvCheckerMode) {
+ return diagclient.New()
+ }
if config.GetBool(config.FlagDVPNMode) {
return dvpnclient.New()
}
diff --git a/services/wireguard/service/service_test.go b/services/wireguard/service/service_test.go
index 96f210be70..f378e3fada 100644
--- a/services/wireguard/service/service_test.go
+++ b/services/wireguard/service/service_test.go
@@ -153,6 +153,9 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur
func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) {
return wgcfg.Stats{LastHandshake: time.Now()}, nil
}
+func (mce *mockConnectionEndpoint) Diag() error {
+ return nil
+}
func newManagerStub(pub, out, country string) *Manager {
dnsHandler, _ := dns.ResolveViaSystem()
diff --git a/tequilapi/contract/connection.go b/tequilapi/contract/connection.go
index c20432e63e..835805c254 100644
--- a/tequilapi/contract/connection.go
+++ b/tequilapi/contract/connection.go
@@ -51,6 +51,14 @@ func NewConnectionInfoDTO(session connectionstate.Status) ConnectionInfoDTO {
return response
}
+// ConnectionDiagInfoDTO holds provider check result
+// swagger:model ConnectionDiagInfoDTO
+type ConnectionDiagInfoDTO struct {
+ ProviderID string `json:"provider_id"`
+ Error string `json:"error"`
+ DiagError string `json:"diag_err"`
+}
+
// ConnectionInfoDTO holds partial consumer connection details.
// swagger:model ConnectionInfoDTO
type ConnectionInfoDTO struct {
diff --git a/tequilapi/endpoints/connection-diag-empty.go b/tequilapi/endpoints/connection-diag-empty.go
new file mode 100644
index 0000000000..ce8d94f720
--- /dev/null
+++ b/tequilapi/endpoints/connection-diag-empty.go
@@ -0,0 +1,46 @@
+//go:build !prov_checker
+
+/*
+ * Copyright (C) 2024 The "MysteriumNetwork/node" Authors.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package endpoints
+
+import (
+ "github.com/gin-gonic/gin"
+
+ "github.com/mysteriumnetwork/node/core/connection"
+ "github.com/mysteriumnetwork/node/core/node"
+ "github.com/mysteriumnetwork/node/eventbus"
+ "github.com/mysteriumnetwork/node/identity/selector"
+)
+
+// AddRoutesForConnectionDiag adds proder check route to given router
+func AddRoutesForConnectionDiag(
+ manager connection.DiagManager,
+ stateProvider stateProvider,
+ proposalRepository proposalRepository,
+ identityRegistry identityRegistry,
+ publisher eventbus.Publisher,
+ subscriber eventbus.Subscriber,
+ addressProvider addressProvider,
+ identitySelector selector.Handler,
+ options node.Options,
+) func(*gin.Engine) error {
+ return func(e *gin.Engine) error {
+ return nil
+ }
+}
diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go
new file mode 100644
index 0000000000..2de6151349
--- /dev/null
+++ b/tequilapi/endpoints/connection-diag.go
@@ -0,0 +1,434 @@
+//go:build prov_checker
+
+/*
+ * Copyright (C) 2024 The "MysteriumNetwork/node" Authors.
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see .
+ */
+
+package endpoints
+
+import (
+ "fmt"
+ "sort"
+
+ "github.com/ethereum/go-ethereum/common"
+ "github.com/gin-gonic/gin"
+ "github.com/pkg/errors"
+ "github.com/rs/zerolog/log"
+ "gorm.io/driver/postgres"
+ "gorm.io/gorm"
+ "gvisor.dev/gvisor/pkg/sync"
+
+ "github.com/mysteriumnetwork/go-rest/apierror"
+ "github.com/mysteriumnetwork/node/config"
+ "github.com/mysteriumnetwork/node/core/connection"
+ "github.com/mysteriumnetwork/node/core/discovery/proposal"
+ "github.com/mysteriumnetwork/node/core/node"
+ "github.com/mysteriumnetwork/node/core/quality"
+ "github.com/mysteriumnetwork/node/eventbus"
+ "github.com/mysteriumnetwork/node/identity"
+ "github.com/mysteriumnetwork/node/identity/registry"
+ "github.com/mysteriumnetwork/node/identity/selector"
+ "github.com/mysteriumnetwork/node/tequilapi/contract"
+ "github.com/mysteriumnetwork/node/tequilapi/utils"
+)
+
+// ConnectionDiagEndpoint struct represents /connection resource and it's subresources
+type ConnectionDiagEndpoint struct {
+ manager connection.DiagManager
+ publisher eventbus.Publisher
+ subscriber eventbus.Subscriber
+
+ stateProvider stateProvider
+ // TODO connection should use concrete proposal from connection params and avoid going to marketplace
+ proposalRepository proposalRepository
+ identityRegistry identityRegistry
+ addressProvider addressProvider
+ identitySelector selector.Handler
+
+ consumerAddress string
+
+ db *gorm.DB
+}
+
+// NewConnectionDiagEndpoint creates and returns connection endpoint
+func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint {
+ ce := &ConnectionDiagEndpoint{
+ manager: manager,
+ publisher: publisher,
+ subscriber: subscriber,
+ stateProvider: stateProvider,
+ proposalRepository: proposalRepository,
+ identityRegistry: identityRegistry,
+ addressProvider: addressProvider,
+ identitySelector: identitySelector,
+ }
+
+ chainID := config.GetInt64(config.FlagChainID)
+ consumerID_, err := ce.identitySelector.UseOrCreate(config.FlagIdentity.Value, config.FlagIdentityPassphrase.Value, chainID)
+ if err != nil {
+ panic(err)
+ }
+ log.Error().Msgf("Unlocked identity: %v", consumerID_.Address)
+ ce.consumerAddress = consumerID_.Address
+
+ dsn := config.GetString(config.FlagProvCheckerDatabaseDSN)
+ if dsn != "" {
+ db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
+ if err != nil {
+ panic(err)
+ }
+ ce.db = db
+ }
+ return ce
+}
+
+func dedupeSortedStrings(s []string) []string {
+ if len(s) < 2 {
+ return s
+ }
+ var e = 1
+ for i := 1; i < len(s); i++ {
+ if s[i] == s[i-1] {
+ continue
+ }
+ s[e] = s[i]
+ e++
+ }
+
+ return s[:e]
+}
+
+// Diag is used to start a given provider check
+func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) {
+ hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID))
+ if err != nil {
+ c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes))
+ return
+ }
+
+ prov := c.Query("id")
+ if len(prov) == 0 {
+ c.Error(errors.New("Empty prameter: prov"))
+ return
+ }
+ cr := &contract.ConnectionCreateRequest{
+ ConsumerID: ce.consumerAddress,
+ ProviderID: prov,
+ Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true},
+ HermesID: hermes.Hex(),
+ ServiceType: "wireguard",
+ ConnectOptions: contract.ConnectOptions{},
+ }
+ if err := cr.Validate(); err != nil {
+ c.Error(err)
+ return
+ }
+
+ consumerID := identity.FromAddress(cr.ConsumerID)
+ status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID)
+ if err != nil {
+ log.Error().Err(err).Stack().Msg("Could not check registration status")
+ c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck))
+ return
+ }
+ switch status {
+ case registry.Unregistered, registry.RegistrationError, registry.Unknown:
+ log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID)
+ c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q is not registered. Please register the identity first", cr.ConsumerID), contract.ErrCodeIDNotRegistered))
+ return
+ case registry.InProgress:
+ log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID)
+ case registry.Registered:
+ log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID)
+ default:
+ log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID)
+ c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q has unknown status. Aborting", cr.ConsumerID), contract.ErrCodeIDStatusUnknown))
+ return
+ }
+
+ if len(cr.ProviderID) > 0 {
+ cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID)
+ }
+ f := &proposal.Filter{
+ ServiceType: cr.ServiceType,
+ LocationCountry: cr.Filter.CountryCode,
+ ProviderIDs: cr.Filter.Providers,
+ IPType: cr.Filter.IPType,
+ IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed,
+ AccessPolicy: "all",
+ }
+ proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository)
+
+ if ce.manager.HasConnection(cr.ProviderID) {
+ c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists))
+ return
+ }
+
+ err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr))
+ if err != nil {
+ switch err {
+ case connection.ErrAlreadyExists:
+ c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists))
+ case connection.ErrConnectionCancelled:
+ c.Error(apierror.Unprocessable("Connection cancelled", contract.ErrCodeConnectionCancelled))
+ default:
+ log.Error().Err(err).Msg("Failed to connect")
+ c.Error(apierror.Internal("Failed to connect: "+err.Error(), contract.ErrCodeConnect))
+ }
+ return
+ }
+
+ resChannel := ce.manager.GetReadyChan(cr.ProviderID)
+ res := <-resChannel
+ log.Error().Msgf("Result > %v", res)
+
+ resp := contract.ConnectionDiagInfoDTO{
+ ProviderID: prov,
+ }
+ utils.WriteAsJSON(resp, c.Writer)
+}
+
+// DiagBatch is used to start a given providers check (batch mode)
+func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) {
+ hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID))
+ if err != nil {
+ c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes))
+ return
+ }
+
+ provs := make([]string, 0)
+ c.Bind(&provs)
+ sort.Strings(provs)
+ provs = dedupeSortedStrings(provs)
+
+ var (
+ wg sync.WaitGroup
+ mu sync.Mutex
+ )
+ resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(provs))
+ wg.Add(len(provs))
+
+ maxGoroutines := 15
+ guard := make(chan struct{}, maxGoroutines)
+
+ for _, pr := range provs {
+ guard <- struct{}{}
+
+ go func(providerID string) {
+ result := ce.worker(providerID, hermes)
+
+ mu.Lock()
+ resultMap[providerID] = result
+ mu.Unlock()
+
+ wg.Done()
+ <-guard
+ }(pr)
+
+ }
+ wg.Wait()
+
+ out := make([]contract.ConnectionDiagInfoDTO, 0)
+ for _, prov := range provs {
+ out = append(out, resultMap[prov])
+ }
+ utils.WriteAsJSON(out, c.Writer)
+}
+
+type proposalDB struct {
+ ID string
+ Error string
+ DiagError string `json:"diag_error"`
+ Country string
+}
+
+func (proposalDB) TableName() string {
+ return "node"
+}
+
+func (ce *ConnectionDiagEndpoint) worker(provID string, hermes common.Address) (result contract.ConnectionDiagInfoDTO) {
+ result.ProviderID = provID
+
+ cr := &contract.ConnectionCreateRequest{
+ ConsumerID: ce.consumerAddress,
+ ProviderID: provID,
+ Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true},
+ HermesID: hermes.Hex(),
+ ServiceType: "wireguard",
+ ConnectOptions: contract.ConnectOptions{},
+ }
+ if err := cr.Validate(); err != nil {
+ result.Error = err.Error()
+ return
+ }
+
+ consumerID := identity.FromAddress(cr.ConsumerID)
+ status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID)
+ if err != nil {
+ log.Error().Err(err).Stack().Msg("Could not check registration status")
+ result.Error = (contract.ErrCodeIDRegistrationCheck)
+ return
+ }
+ switch status {
+ case registry.Unregistered, registry.RegistrationError, registry.Unknown:
+ log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID)
+ result.Error = (contract.ErrCodeIDNotRegistered)
+ return
+ case registry.InProgress:
+ log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID)
+ case registry.Registered:
+ log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID)
+ default:
+ log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID)
+ result.Error = (contract.ErrCodeIDStatusUnknown)
+ return
+ }
+
+ if len(cr.ProviderID) > 0 {
+ cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID)
+ }
+ f := &proposal.Filter{
+ ServiceType: cr.ServiceType,
+ LocationCountry: cr.Filter.CountryCode,
+ ProviderIDs: cr.Filter.Providers,
+ IPType: cr.Filter.IPType,
+ IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed,
+ AccessPolicy: "all",
+ }
+ proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository)
+
+ if ce.manager.HasConnection(cr.ProviderID) {
+ result.Error = (contract.ErrCodeConnectionAlreadyExists)
+ return
+ }
+
+ err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr))
+ if err != nil {
+ switch err {
+ case connection.ErrAlreadyExists:
+ result.Error = (contract.ErrCodeConnectionAlreadyExists)
+ case connection.ErrConnectionCancelled:
+ result.Error = (contract.ErrCodeConnectionCancelled)
+ default:
+ log.Error().Err(err).Msgf("Failed to connect: %v", provID)
+ result.Error = (contract.ErrCodeConnect)
+ }
+ return
+ }
+
+ resChannel := ce.manager.GetReadyChan(cr.ProviderID)
+ res := <-resChannel
+ log.Error().Msgf("Result > %v", res)
+
+ ev := res.(quality.DiagEvent)
+ if ev.Error != nil {
+ result.DiagError = ev.Error.Error()
+ }
+
+ return
+}
+
+// DiagBatch2 is used to start a check of providers from a given country or all countries
+func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) {
+ hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID))
+ if err != nil {
+ c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes))
+ return
+ }
+
+ country := c.Query("location")
+ f := &proposal.Filter{
+ ServiceType: "wireguard",
+ LocationCountry: country,
+ ExcludeUnsupported: true,
+ IncludeMonitoringFailed: true,
+ }
+ pp, err := ce.proposalRepository.Proposals(f)
+ if err != nil {
+ log.Error().Err(err).Stack().Msg("Proposals>")
+ }
+ log.Error().Msgf("pp> %v", len(pp))
+
+ var (
+ wg sync.WaitGroup
+ mu sync.Mutex
+ )
+ resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(pp))
+ wg.Add(len(pp))
+
+ maxGoroutines := 15
+ guard := make(chan struct{}, maxGoroutines)
+
+ for _, pr := range pp {
+ guard <- struct{}{}
+
+ go func(pr proposal.PricedServiceProposal) {
+
+ result := ce.worker(pr.ProviderID, hermes)
+
+ mu.Lock()
+ resultMap[pr.ProviderID] = result
+ mu.Unlock()
+
+ if ce.db != nil {
+ // update
+ provRec := proposalDB{ID: result.ProviderID, Country: pr.Location.Country}
+ provRec.Error = ""
+ provRec.DiagError = ""
+ provRec.Error = result.Error
+ provRec.DiagError = result.DiagError
+ if ce.db.Model(&provRec).Select("Error", "DiagError", "Country").Updates(provRec).RowsAffected == 0 {
+ ce.db.Create(&provRec)
+ }
+ }
+
+ wg.Done()
+ <-guard
+ }(pr)
+
+ }
+ wg.Wait()
+
+ out := make([]contract.ConnectionDiagInfoDTO, 0)
+ for _, prov := range pp {
+ out = append(out, resultMap[prov.ProviderID])
+ }
+ utils.WriteAsJSON(out, c.Writer)
+}
+
+// AddRoutesForConnectionDiag adds proder check route to given router
+func AddRoutesForConnectionDiag(
+ manager connection.DiagManager,
+ stateProvider stateProvider,
+ proposalRepository proposalRepository,
+ identityRegistry identityRegistry,
+ publisher eventbus.Publisher,
+ subscriber eventbus.Subscriber,
+ addressProvider addressProvider,
+ identitySelector selector.Handler,
+ options node.Options,
+) func(*gin.Engine) error {
+ ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, subscriber, addressProvider, identitySelector)
+ return func(e *gin.Engine) error {
+ connGroup := e.Group("")
+ {
+ connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag)
+ connGroup.POST("/prov-checker-batch", ConnectionDiagEndpoint.DiagBatch)
+ connGroup.GET("/prov-checker-batch2", ConnectionDiagEndpoint.DiagBatch2)
+ }
+ return nil
+ }
+}