From 917f40d75c0f07c1db973c1217f5597d52d0f0c3 Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Mon, 13 May 2024 02:24:46 +0400 Subject: [PATCH 1/6] Provider checker mode Signed-off-by: Anton Litvinov --- cmd/bootstrap.go | 1 + cmd/di.go | 15 +- config/flags_node.go | 10 + core/connection/interface.go | 5 + core/connection/manager.go | 8 + core/connection/manager_test.go | 4 + core/connection/pinger.go | 50 ++++ core/node/options.go | 6 +- core/quality/metrics.go | 9 + services/wireguard/connection/connection.go | 7 + .../wireguard/connection/connection_test.go | 3 + services/wireguard/endpoint.go | 1 + .../wireguard/endpoint/diagclient/client.go | 150 +++++++++++ services/wireguard/endpoint/endpoint.go | 10 + services/wireguard/endpoint/wg_client.go | 10 + services/wireguard/service/service_test.go | 3 + tequilapi/contract/connection.go | 8 + tequilapi/endpoints/connection-diag.go | 244 ++++++++++++++++++ 18 files changed, 539 insertions(+), 5 deletions(-) create mode 100644 core/connection/pinger.go create mode 100644 services/wireguard/endpoint/diagclient/client.go create mode 100644 tequilapi/endpoints/connection-diag.go diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index a45f464cee..e38b9cb968 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -70,6 +70,7 @@ func (di *Dependencies) bootstrapTequilapi(nodeOptions node.Options, listener ne tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), + tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions), tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), diff --git a/cmd/di.go b/cmd/di.go index 1088d8ce3c..d1a161e728 100644 --- a/cmd/di.go +++ b/cmd/di.go @@ -210,6 +210,8 @@ type Dependencies struct { NodeStatusTracker *monitoring.StatusTracker NodeStatsTracker *node.StatsTracker uiVersionConfig versionmanager.NodeUIVersionConfig + + provPinger *connection.ProviderChecker } // Bootstrap initiates all container dependencies @@ -287,7 +289,7 @@ func (di *Dependencies) Bootstrap(nodeOptions node.Options) error { return err } - if err := di.bootstrapQualityComponents(nodeOptions.Quality); err != nil { + if err := di.bootstrapQualityComponents(nodeOptions.Quality, nodeOptions); err != nil { return err } @@ -299,6 +301,7 @@ func (di *Dependencies) Bootstrap(nodeOptions node.Options) error { if err = di.handleConnStateChange(); err != nil { return err } + if err := di.Node.Start(); err != nil { return err } @@ -581,6 +584,7 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil di.bootstrapBeneficiarySaver(nodeOptions) di.ConnectionRegistry = connection.NewRegistry() + di.MultiConnectionManager = connection.NewMultiConnectionManager(func() connection.Manager { return connection.NewManager( pingpong.ExchangeFactoryFunc( @@ -604,6 +608,7 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil di.P2PDialer, di.allowTrustedDomainBypassTunnel, di.disallowTrustedDomainBypassTunnel, + di.provPinger, ) }) @@ -883,7 +888,7 @@ func (di *Dependencies) bootstrapIdentityComponents(options node.Options) error return nil } -func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality) (err error) { +func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality, nodeOptions node.Options) (err error) { if err := di.AllowURLAccess(options.Address); err != nil { return err } @@ -924,6 +929,10 @@ func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality) return err } + if nodeOptions.ProvChecker { + di.provPinger = connection.NewProviderChecker(di.EventBus) + } + return nil } @@ -1065,7 +1074,7 @@ func (di *Dependencies) handleConnStateChange() error { latestState := connectionstate.NotConnected return di.EventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, func(e connectionstate.AppEventConnectionState) { - if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) { + if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) || config.GetBool(config.FlagProvCheckerMode) { return // Proxy mode doesn't establish system wide tunnels, no reconnect required. } diff --git a/config/flags_node.go b/config/flags_node.go index acaf62e56c..8c8fd3fada 100644 --- a/config/flags_node.go +++ b/config/flags_node.go @@ -229,6 +229,13 @@ var ( Value: false, } + // FlagProvCheckerMode allows running node under current user as a provider checker agent. + FlagProvCheckerMode = cli.BoolFlag{ + Name: "provchecker", + Usage: "", + Value: false, + } + // FlagUserspace allows running a node without privileged permissions. FlagUserspace = cli.BoolFlag{ Name: "userspace", @@ -349,6 +356,7 @@ func RegisterFlagsNode(flags *[]cli.Flag) error { &FlagUserMode, &FlagDVPNMode, &FlagProxyMode, + &FlagProvCheckerMode, &FlagUserspace, &FlagVendorID, &FlagLauncherVersion, @@ -411,6 +419,8 @@ func ParseFlagsNode(ctx *cli.Context) { Current.ParseBoolFlag(ctx, FlagUserMode) Current.ParseBoolFlag(ctx, FlagDVPNMode) Current.ParseBoolFlag(ctx, FlagProxyMode) + Current.ParseBoolFlag(ctx, FlagProvCheckerMode) + Current.ParseBoolFlag(ctx, FlagUserspace) Current.ParseStringFlag(ctx, FlagVendorID) Current.ParseStringFlag(ctx, FlagLauncherVersion) diff --git a/core/connection/interface.go b/core/connection/interface.go index cf7f678140..531cca8f0f 100644 --- a/core/connection/interface.go +++ b/core/connection/interface.go @@ -39,6 +39,11 @@ type Connection interface { Statistics() (connectionstate.Statistics, error) } +// ConnectionDiag is a specialised Connection interface for provider check +type ConnectionDiag interface { + Diag() bool +} + // StateChannel is the channel we receive state change events on type StateChannel chan connectionstate.State diff --git a/core/connection/manager.go b/core/connection/manager.go index e87d5a32ad..15ce1bca25 100644 --- a/core/connection/manager.go +++ b/core/connection/manager.go @@ -170,6 +170,8 @@ type connectionManager struct { statsTracker statsTracker uuid string + + provChecker *ProviderChecker } // NewManager creates connection manager with given dependencies @@ -184,6 +186,7 @@ func NewManager( validator validator, p2pDialer p2p.Dialer, preReconnect, postReconnect func(), + provChecker *ProviderChecker, ) *connectionManager { uuid, err := uuid.NewV4() if err != nil { @@ -207,6 +210,7 @@ func NewManager( preReconnect: preReconnect, postReconnect: postReconnect, uuid: uuid.String(), + provChecker: provChecker, } m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold) @@ -301,6 +305,10 @@ func (m *connectionManager) Connect(consumerID identity.Identity, hermesID commo return nil }) + if m.provChecker != nil { + go m.provChecker.Diag(m, proposal.ProviderID) + } + go m.consumeConnectionStates(m.activeConnection.State()) go m.checkSessionIP(m.channel, m.connectOptions.ConsumerID, m.connectOptions.SessionID, originalPublicIP) diff --git a/core/connection/manager_test.go b/core/connection/manager_test.go index 4e7ec1da1d..92ba87f454 100644 --- a/core/connection/manager_test.go +++ b/core/connection/manager_test.go @@ -61,6 +61,8 @@ type testContext struct { statsReportInterval time.Duration mockP2P *mockP2PDialer mockTime time.Time + provChecker *ProviderChecker + sync.RWMutex } @@ -140,6 +142,7 @@ func (tc *testContext) SetupTest() { tc.mockP2P = &mockP2PDialer{&mockP2PChannel{}} tc.mockTime = time.Date(2000, time.January, 0, 10, 12, 3, 0, time.UTC) + tc.provChecker = NewProviderChecker(tc.stubPublisher) tc.connManager = NewManager( func(senderUUID string, channel p2p.Channel, @@ -159,6 +162,7 @@ func (tc *testContext) SetupTest() { &mockValidator{}, tc.mockP2P, func() {}, func() {}, + tc.provChecker, ) tc.connManager.timeGetter = func() time.Time { return tc.mockTime diff --git a/core/connection/pinger.go b/core/connection/pinger.go new file mode 100644 index 0000000000..9ed462507d --- /dev/null +++ b/core/connection/pinger.go @@ -0,0 +1,50 @@ +/* + * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package connection + +import ( + "github.com/mysteriumnetwork/node/core/quality" + "github.com/mysteriumnetwork/node/eventbus" + "github.com/rs/zerolog/log" +) + +// ProviderChecker is a service for provider testing +type ProviderChecker struct { + bus eventbus.Publisher +} + +// NewProviderChecker is a ProviderChecker constructor +func NewProviderChecker(bus eventbus.Publisher) *ProviderChecker { + return &ProviderChecker{ + bus: bus, + } +} + +// Diag is used to start provider check +func (p *ProviderChecker) Diag(cm *connectionManager, providerID string) { + c, ok := cm.activeConnection.(ConnectionDiag) + res := false + if ok { + log.Debug().Msgf("Check provider> %v", providerID) + + res = c.Diag() + cm.Disconnect() + } + ev := quality.DiagEvent{ProviderID: providerID, Result: res} + p.bus.Publish(quality.AppTopicConnectionDiagRes, ev) +} diff --git a/core/node/options.go b/core/node/options.go index 04085ff27e..7ecb9cd22a 100644 --- a/core/node/options.go +++ b/core/node/options.go @@ -81,8 +81,9 @@ type Options struct { Payments OptionsPayments - Consumer bool - Mobile bool + Consumer bool + Mobile bool + ProvChecker bool SwarmDialerDNSHeadstart time.Duration PilvytisAddress string @@ -205,6 +206,7 @@ func GetOptions() *Options { SSE: OptionsSSE{ Enabled: config.GetBool(config.FlagSSEEnable), }, + ProvChecker: config.GetBool(config.FlagProvCheckerMode), } } diff --git a/core/quality/metrics.go b/core/quality/metrics.go index 51d3489607..9beedfd6bb 100644 --- a/core/quality/metrics.go +++ b/core/quality/metrics.go @@ -102,6 +102,12 @@ type PingEvent struct { Duration time.Duration `json:"duration"` } +// DiagEvent represents provider check result event +type DiagEvent struct { + ProviderID string + Result bool +} + const ( // AppTopicConnectionEvents represents event bus topic for the connection events. AppTopicConnectionEvents = "connection_events" @@ -111,4 +117,7 @@ const ( // AppTopicProviderPingP2P represents event bus topic for provider p2p pings to consumer. AppTopicProviderPingP2P = "provider_ping_p2p" + + // AppTopicConnectionDiagRes represents event bus topic for provider check result. + AppTopicConnectionDiagRes = "connection_diag" ) diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go index 3657b8b00c..7909dfd815 100644 --- a/services/wireguard/connection/connection.go +++ b/services/wireguard/connection/connection.go @@ -86,6 +86,11 @@ func (c *Connection) State() <-chan connectionstate.State { return c.stateCh } +// Diag is used to start provider check +func (c *Connection) Diag() bool { + return c.connectionEndpoint.Diag() +} + // Statistics returns connection statistics channel. func (c *Connection) Statistics() (connectionstate.Statistics, error) { stats, err := c.connectionEndpoint.PeerStats() @@ -110,6 +115,8 @@ func (c *Connection) Reconnect(ctx context.Context, options connection.ConnectOp } func (c *Connection) start(ctx context.Context, start startConn, options connection.ConnectOptions) (err error) { + log.Info().Msg("+++++++++++++++++++++++++++++++++++++++++++++++++++++ *Connection) start") + var config wg.ServiceConfig if err = json.Unmarshal(options.SessionConfig, &config); err != nil { return errors.Wrap(err, "failed to unmarshal connection config") diff --git a/services/wireguard/connection/connection_test.go b/services/wireguard/connection/connection_test.go index cbc2930b6a..ad82dbc3a8 100644 --- a/services/wireguard/connection/connection_test.go +++ b/services/wireguard/connection/connection_test.go @@ -158,6 +158,9 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) { return wgcfg.Stats{LastHandshake: time.Now(), BytesSent: 10, BytesReceived: 11}, nil } +func (mce *mockConnectionEndpoint) Diag() bool { + return true +} type mockHandshakeWaiter struct { err error diff --git a/services/wireguard/endpoint.go b/services/wireguard/endpoint.go index 8d2201b795..e6df362a67 100644 --- a/services/wireguard/endpoint.go +++ b/services/wireguard/endpoint.go @@ -34,4 +34,5 @@ type ConnectionEndpoint interface { Config() (ServiceConfig, error) InterfaceName() string Stop() error + Diag() bool } diff --git a/services/wireguard/endpoint/diagclient/client.go b/services/wireguard/endpoint/diagclient/client.go new file mode 100644 index 0000000000..646ff7f14e --- /dev/null +++ b/services/wireguard/endpoint/diagclient/client.go @@ -0,0 +1,150 @@ +/* + * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package diagclient + +import ( + "bufio" + "fmt" + "io" + "net/http" + "net/netip" + "strings" + "sync" + "time" + + "github.com/rs/zerolog/log" + "golang.zx2c4.com/wireguard/conn" + "golang.zx2c4.com/wireguard/device" + + "github.com/mysteriumnetwork/node/services/wireguard/endpoint/netstack" + "github.com/mysteriumnetwork/node/services/wireguard/endpoint/userspace" + "github.com/mysteriumnetwork/node/services/wireguard/wgcfg" +) + +type client struct { + mu sync.Mutex + Device *device.Device + tnet *netstack.Net +} + +// New create new WireGuard client testing the provider. +func New() (*client, error) { + log.Error().Msg("Creating pinger wg client") + return &client{}, nil +} + +func (c *client) ReConfigureDevice(config wgcfg.DeviceConfig) error { + return c.ConfigureDevice(config) +} + +func (c *client) ConfigureDevice(cfg wgcfg.DeviceConfig) error { + localAddr, err := netip.ParseAddr(cfg.Subnet.IP.String()) + if err != nil { + return fmt.Errorf("could not parse local addr: %w", err) + } + if len(cfg.DNS) == 0 { + return fmt.Errorf("DNS addr list is empty") + } + dnsAddr, err := netip.ParseAddr(cfg.DNS[0]) + if err != nil { + return fmt.Errorf("could not parse DNS addr: %w", err) + } + tunnel, tnet, err := netstack.CreateNetTUN([]netip.Addr{localAddr}, []netip.Addr{dnsAddr}, device.DefaultMTU) + if err != nil { + return fmt.Errorf("failed to create netstack device %s: %w", cfg.IfaceName, err) + } + + logger := device.NewLogger(device.LogLevelVerbose, fmt.Sprintf("(%s) ", cfg.IfaceName)) + wgDevice := device.NewDevice(tunnel, conn.NewDefaultBind(), logger) + + log.Info().Msg("Applying interface configuration") + if err := wgDevice.IpcSetOperation(bufio.NewReader(strings.NewReader(cfg.Encode()))); err != nil { + wgDevice.Close() + return fmt.Errorf("could not set device uapi config: %w", err) + } + + log.Info().Msg("Bringing device up") + + wgDevice.Up() + + c.mu.Lock() + c.Device = wgDevice + c.mu.Unlock() + c.tnet = tnet + + return nil +} + +func (c *client) DestroyDevice(iface string) error { + return c.Close() +} + +func (c *client) PeerStats(iface string) (wgcfg.Stats, error) { + deviceState, err := userspace.ParseUserspaceDevice(c.Device.IpcGetOperation) + if err != nil { + return wgcfg.Stats{}, fmt.Errorf("could not parse device state: %w", err) + } + + stats, statErr := userspace.ParseDevicePeerStats(deviceState) + if statErr != nil { + err = statErr + log.Warn().Err(err).Msg("Failed to parse device stats, will try again") + } else { + return stats, nil + } + + return wgcfg.Stats{}, fmt.Errorf("could not parse device state: %w", err) +} + +func (c *client) Close() (err error) { + c.mu.Lock() + defer c.mu.Unlock() + + log.Error().Err(err).Msg("Shutting down pinger ...") + + if c.Device != nil { + go func() { + time.Sleep(5 * time.Second) + c.Device.Close() + }() + } + return nil +} + +func (c *client) Diag() bool { + client := http.Client{ + Transport: &http.Transport{ + DialContext: c.tnet.DialContext, + }, + } + resp, err := client.Get("http://1.1.1.1/") + if err != nil { + log.Error().Err(err).Msg("Get failed") + return false + } + defer resp.Body.Close() + + body, err := io.ReadAll(resp.Body) + if err != nil { + log.Error().Err(err).Msg("Readall failed") + return false + } + _ = body + + return true +} diff --git a/services/wireguard/endpoint/endpoint.go b/services/wireguard/endpoint/endpoint.go index 7b77489300..c9b8c7bc73 100644 --- a/services/wireguard/endpoint/endpoint.go +++ b/services/wireguard/endpoint/endpoint.go @@ -52,6 +52,14 @@ type connectionEndpoint struct { wgClient WgClient } +func (ce *connectionEndpoint) Diag() bool { + c, ok := ce.wgClient.(WgClientDiag) + if ok { + return c.Diag() + } + return false +} + // StartConsumerMode starts and configure wireguard network interface running in consumer mode. func (ce *connectionEndpoint) StartConsumerMode(cfg wgcfg.DeviceConfig) error { if err := ce.cleanAbandonedInterfaces(); err != nil { @@ -80,6 +88,8 @@ func (ce *connectionEndpoint) StartConsumerMode(cfg wgcfg.DeviceConfig) error { } return errors.Wrap(err, "could not configure device") } + + // ce.wgClient.Diag() return nil } diff --git a/services/wireguard/endpoint/wg_client.go b/services/wireguard/endpoint/wg_client.go index b991bb2ab2..74cd0f6111 100644 --- a/services/wireguard/endpoint/wg_client.go +++ b/services/wireguard/endpoint/wg_client.go @@ -24,6 +24,7 @@ import ( "github.com/rs/zerolog/log" "github.com/mysteriumnetwork/node/config" + "github.com/mysteriumnetwork/node/services/wireguard/endpoint/diagclient" "github.com/mysteriumnetwork/node/services/wireguard/endpoint/dvpnclient" "github.com/mysteriumnetwork/node/services/wireguard/endpoint/kernelspace" netstack_provider "github.com/mysteriumnetwork/node/services/wireguard/endpoint/netstack-provider" @@ -43,6 +44,11 @@ type WgClient interface { Close() error } +// WgClientDiag is a specialised WgClient interface for provider check +type WgClientDiag interface { + Diag() bool +} + // WgClientFactory represents WireGuard client factory. type WgClientFactory struct { once sync.Once @@ -56,6 +62,10 @@ func NewWGClientFactory() *WgClientFactory { // NewWGClient returns a new wireguard client. func (wcf *WgClientFactory) NewWGClient() (WgClient, error) { + + if config.GetBool(config.FlagProvCheckerMode) { + return diagclient.New() + } if config.GetBool(config.FlagDVPNMode) { return dvpnclient.New() } diff --git a/services/wireguard/service/service_test.go b/services/wireguard/service/service_test.go index 96f210be70..aa4742a17e 100644 --- a/services/wireguard/service/service_test.go +++ b/services/wireguard/service/service_test.go @@ -153,6 +153,9 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) { return wgcfg.Stats{LastHandshake: time.Now()}, nil } +func (mce *mockConnectionEndpoint) Diag() bool { + return true +} func newManagerStub(pub, out, country string) *Manager { dnsHandler, _ := dns.ResolveViaSystem() diff --git a/tequilapi/contract/connection.go b/tequilapi/contract/connection.go index c20432e63e..226e9fdb37 100644 --- a/tequilapi/contract/connection.go +++ b/tequilapi/contract/connection.go @@ -51,6 +51,14 @@ func NewConnectionInfoDTO(session connectionstate.Status) ConnectionInfoDTO { return response } +// ConnectionDiagInfoDTO holds provider check result +// swagger:model ConnectionDiagInfoDTO +type ConnectionDiagInfoDTO struct { + Status bool `json:"status"` + Error interface{} `json:"error"` + ProviderID string `json:"provider_id"` +} + // ConnectionInfoDTO holds partial consumer connection details. // swagger:model ConnectionInfoDTO type ConnectionInfoDTO struct { diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go new file mode 100644 index 0000000000..3435fa0178 --- /dev/null +++ b/tequilapi/endpoints/connection-diag.go @@ -0,0 +1,244 @@ +/* + * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package endpoints + +import ( + "fmt" + "strconv" + + "github.com/ethereum/go-ethereum/common" + "github.com/gin-gonic/gin" + "github.com/gofrs/uuid" + "github.com/pkg/errors" + "github.com/rs/zerolog/log" + + "github.com/mysteriumnetwork/go-rest/apierror" + "github.com/mysteriumnetwork/node/config" + "github.com/mysteriumnetwork/node/core/connection" + "github.com/mysteriumnetwork/node/core/discovery/proposal" + "github.com/mysteriumnetwork/node/core/node" + "github.com/mysteriumnetwork/node/core/quality" + "github.com/mysteriumnetwork/node/eventbus" + "github.com/mysteriumnetwork/node/identity" + "github.com/mysteriumnetwork/node/identity/registry" + "github.com/mysteriumnetwork/node/identity/selector" + "github.com/mysteriumnetwork/node/tequilapi/contract" + "github.com/mysteriumnetwork/node/tequilapi/utils" +) + +// ConnectionDiagEndpoint struct represents /connection resource and it's subresources +type ConnectionDiagEndpoint struct { + manager connection.MultiManager + publisher eventbus.Publisher + subscriber eventbus.Subscriber + + stateProvider stateProvider + // TODO connection should use concrete proposal from connection params and avoid going to marketplace + proposalRepository proposalRepository + identityRegistry identityRegistry + addressProvider addressProvider + identitySelector selector.Handler +} + +// NewConnectionDiagEndpoint creates and returns connection endpoint +func NewConnectionDiagEndpoint(manager connection.MultiManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint { + return &ConnectionDiagEndpoint{ + manager: manager, + publisher: publisher, + subscriber: subscriber, + stateProvider: stateProvider, + proposalRepository: proposalRepository, + identityRegistry: identityRegistry, + addressProvider: addressProvider, + identitySelector: identitySelector, + } +} + +// Status returns result of provider check +// swagger:operation GET /prov-checker ConnectionDiagInfoDTO +// +// --- +// summary: Returns connection status +// description: Returns status of current connection +// responses: +// 200: +// description: Status +// schema: +// "$ref": "#/definitions/ConnectionInfoDTO" +// 400: +// description: Failed to parse or request validation failed +// schema: +// "$ref": "#/definitions/APIError" +// 500: +// description: Internal server error +// schema: +// "$ref": "#/definitions/APIError" +func (ce *ConnectionDiagEndpoint) Status(c *gin.Context) { + n := 0 + id := c.Query("id") + if len(id) > 0 { + var err error + n, err = strconv.Atoi(id) + if err != nil { + c.Error(apierror.ParseFailed()) + return + } + } + status := ce.manager.Status(n) + statusResponse := contract.NewConnectionInfoDTO(status) + utils.WriteAsJSON(statusResponse, c.Writer) +} + +// Diag is used to start provider check +func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { + log.Error().Msgf("Diag >>>") + + chainID := config.GetInt64(config.FlagChainID) + consumerID_, err := ce.identitySelector.UseOrCreate(config.FlagIdentity.Value, config.FlagIdentityPassphrase.Value, chainID) + if err != nil { + c.Error(apierror.Internal("Failed to unlock identity", err.Error())) + return + } + log.Error().Msgf("Unlocked identity: %v", consumerID_) + + hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) + if err != nil { + c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) + return + } + + prov := c.Query("id") + if len(prov) == 0 { + c.Error(errors.New("Empty prameter: prov")) + return + } + cr := &contract.ConnectionCreateRequest{ + ConsumerID: consumerID_.Address, + ProviderID: prov, + Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, + HermesID: hermes.Hex(), + ServiceType: "wireguard", + ConnectOptions: contract.ConnectOptions{}, + } + + if err := cr.Validate(); err != nil { + c.Error(err) + return + } + + consumerID := identity.FromAddress(cr.ConsumerID) + status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) + if err != nil { + log.Error().Err(err).Stack().Msg("Could not check registration status") + c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck)) + return + } + + switch status { + case registry.Unregistered, registry.RegistrationError, registry.Unknown: + log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) + c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q is not registered. Please register the identity first", cr.ConsumerID), contract.ErrCodeIDNotRegistered)) + return + case registry.InProgress: + log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) + case registry.Registered: + log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) + default: + log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) + c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q has unknown status. Aborting", cr.ConsumerID), contract.ErrCodeIDStatusUnknown)) + return + } + + if len(cr.ProviderID) > 0 { + cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) + } + + f := &proposal.Filter{ + ServiceType: cr.ServiceType, + LocationCountry: cr.Filter.CountryCode, + ProviderIDs: cr.Filter.Providers, + IPType: cr.Filter.IPType, + IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, + AccessPolicy: "all", + } + proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + + res := make(chan bool) + cb := func(r quality.DiagEvent) { + if r.ProviderID == prov { + res <- r.Result + } + } + + uid, err := uuid.NewV4() + if err != nil { + log.Error().Msgf("Error > %v", err) + c.Error(err) + return + } + + ce.subscriber.SubscribeWithUID(quality.AppTopicConnectionDiagRes, uid.String(), cb) + defer ce.subscriber.UnsubscribeWithUID(quality.AppTopicConnectionDiagRes, uid.String(), cb) + + err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) + if err != nil { + switch err { + case connection.ErrAlreadyExists: + c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) + case connection.ErrConnectionCancelled: + c.Error(apierror.Unprocessable("Connection cancelled", contract.ErrCodeConnectionCancelled)) + default: + log.Error().Err(err).Msg("Failed to connect") + c.Error(apierror.Internal("Failed to connect: "+err.Error(), contract.ErrCodeConnect)) + } + + return + } + + r := <-res + log.Debug().Msgf("Result > %v", r) + resp := contract.ConnectionDiagInfoDTO{ + ProviderID: prov, + Status: r, + } + utils.WriteAsJSON(resp, c.Writer) +} + +// AddRoutesForConnectionDiag adds proder check route to given router +func AddRoutesForConnectionDiag( + manager connection.MultiManager, + stateProvider stateProvider, + proposalRepository proposalRepository, + identityRegistry identityRegistry, + publisher eventbus.Publisher, + publisher2 eventbus.Subscriber, + addressProvider addressProvider, + identitySelector selector.Handler, + options node.Options, +) func(*gin.Engine) error { + ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, publisher2, addressProvider, identitySelector) + return func(e *gin.Engine) error { + connGroup := e.Group("") + { + if options.ProvChecker { + connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) + } + } + return nil + } +} From 7b7c5ab99d09fb70d2b4facb2f53ffd57ed71ab2 Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Mon, 20 May 2024 15:57:25 +0400 Subject: [PATCH 2/6] Add concurrency for diag endpoint Signed-off-by: Anton Litvinov --- cmd/bootstrap.go | 2 +- cmd/di.go | 37 +- cmd/node.go | 26 +- core/connection/interface.go | 12 + core/connection/manager-diag.go | 978 +++++++++++++++++++++++++ core/connection/manager.go | 8 - core/connection/manager_test.go | 3 - core/connection/pinger.go | 22 +- tequilapi/endpoints/connection-diag.go | 71 +- 9 files changed, 1053 insertions(+), 106 deletions(-) create mode 100644 core/connection/manager-diag.go diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index e38b9cb968..b0bae35d4e 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -70,7 +70,7 @@ func (di *Dependencies) bootstrapTequilapi(nodeOptions node.Options, listener ne tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), - tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions), + tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions), tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), diff --git a/cmd/di.go b/cmd/di.go index d1a161e728..726fbfa90c 100644 --- a/cmd/di.go +++ b/cmd/di.go @@ -150,8 +150,9 @@ type Dependencies struct { EventBus eventbus.EventBus - MultiConnectionManager connection.MultiManager - ConnectionRegistry *connection.Registry + MultiConnectionManager connection.MultiManager + MultiConnectionDiagManager connection.DiagManager + ConnectionRegistry *connection.Registry ServicesManager *service.Manager ServiceRegistry *service.Registry @@ -210,8 +211,6 @@ type Dependencies struct { NodeStatusTracker *monitoring.StatusTracker NodeStatsTracker *node.StatsTracker uiVersionConfig versionmanager.NodeUIVersionConfig - - provPinger *connection.ProviderChecker } // Bootstrap initiates all container dependencies @@ -608,10 +607,34 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil di.P2PDialer, di.allowTrustedDomainBypassTunnel, di.disallowTrustedDomainBypassTunnel, - di.provPinger, ) }) + if nodeOptions.ProvChecker { + di.MultiConnectionDiagManager = connection.NewDiagManager( + pingpong.ExchangeFactoryFunc( + di.Keystore, + di.SignerFactory, + di.ConsumerTotalsStorage, + di.AddressProvider, + di.EventBus, + nodeOptions.Payments.ConsumerDataLeewayMegabytes, + ), + di.ConnectionRegistry.CreateConnection, + di.EventBus, + di.IPResolver, + di.LocationResolver, + connection.DefaultConfig(), + config.GetDuration(config.FlagStatsReportInterval), + connection.NewValidator( + di.ConsumerBalanceTracker, + di.IdentityManager, + ), + di.P2PDialer, + di.allowTrustedDomainBypassTunnel, + di.disallowTrustedDomainBypassTunnel, + ) + } di.NATProber = natprobe.NewNATProber(di.MultiConnectionManager, di.EventBus) di.LogCollector = logconfig.NewCollector(&logconfig.CurrentLogOptions) @@ -660,7 +683,7 @@ func (di *Dependencies) bootstrapNodeComponents(nodeOptions node.Options, tequil sleepNotifier := sleep.NewNotifier(di.MultiConnectionManager, di.EventBus) sleepNotifier.Subscribe() - di.Node = NewNode(di.MultiConnectionManager, tequilapiHTTPServer, di.EventBus, di.UIServer, sleepNotifier) + di.Node = NewNode(di.MultiConnectionManager, di.MultiConnectionDiagManager, tequilapiHTTPServer, di.EventBus, di.UIServer, sleepNotifier) return nil } @@ -930,7 +953,7 @@ func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality, } if nodeOptions.ProvChecker { - di.provPinger = connection.NewProviderChecker(di.EventBus) + // di.provPinger = connection.NewProviderChecker(di.EventBus) } return nil diff --git a/cmd/node.go b/cmd/node.go index 870fcf635b..dfe7973436 100644 --- a/cmd/node.go +++ b/cmd/node.go @@ -42,23 +42,27 @@ type SleepNotifier interface { } // NewNode function creates new Mysterium node by given options -func NewNode(connectionManager connection.MultiManager, tequilapiServer tequilapi.APIServer, publisher Publisher, uiServer UIServer, notifier SleepNotifier) *Node { +func NewNode(connectionManager connection.MultiManager, connectionDiagManager connection.DiagManager, tequilapiServer tequilapi.APIServer, publisher Publisher, uiServer UIServer, notifier SleepNotifier) *Node { return &Node{ - connectionManager: connectionManager, - httpAPIServer: tequilapiServer, - publisher: publisher, - uiServer: uiServer, - sleepNotifier: notifier, + connectionManager: connectionManager, + connectionDiagManager: connectionDiagManager, + + httpAPIServer: tequilapiServer, + publisher: publisher, + uiServer: uiServer, + sleepNotifier: notifier, } } // Node represent entrypoint for Mysterium node with top level components type Node struct { - connectionManager connection.MultiManager - httpAPIServer tequilapi.APIServer - publisher Publisher - uiServer UIServer - sleepNotifier SleepNotifier + connectionManager connection.MultiManager + connectionDiagManager connection.DiagManager + + httpAPIServer tequilapi.APIServer + publisher Publisher + uiServer UIServer + sleepNotifier SleepNotifier } // Start starts Mysterium node (Tequilapi service, fetches location) diff --git a/core/connection/interface.go b/core/connection/interface.go index 531cca8f0f..dbbb8b2245 100644 --- a/core/connection/interface.go +++ b/core/connection/interface.go @@ -78,3 +78,15 @@ type MultiManager interface { // Reconnect reconnects current session Reconnect(n int) } + +// DiagManager interface provides methods to manage diagnotic connection +type DiagManager interface { + // Connect creates new connection from given consumer to provider, reports error if connection already exists + Connect(consumerID identity.Identity, hermesID common.Address, proposal ProposalLookup, params ConnectParams) error + // Status queries current status of connection + Status() connectionstate.Status + // GetReadyChan returns a channel for getting a diagnostic result + GetReadyChan(providerID string) chan interface{} + // HasConnection returns true if a diagnostic connection is already established + HasConnection(providerID string) bool +} diff --git a/core/connection/manager-diag.go b/core/connection/manager-diag.go new file mode 100644 index 0000000000..1a984877db --- /dev/null +++ b/core/connection/manager-diag.go @@ -0,0 +1,978 @@ +/* + * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package connection + +import ( + "context" + "encoding/json" + "errors" + "fmt" + "math/big" + "sync" + "time" + + "github.com/ethereum/go-ethereum/common" + "github.com/rs/zerolog/log" + + "github.com/mysteriumnetwork/node/config" + "github.com/mysteriumnetwork/node/core/connection/connectionstate" + "github.com/mysteriumnetwork/node/core/discovery/proposal" + "github.com/mysteriumnetwork/node/core/ip" + "github.com/mysteriumnetwork/node/core/location" + "github.com/mysteriumnetwork/node/core/quality" + "github.com/mysteriumnetwork/node/eventbus" + "github.com/mysteriumnetwork/node/firewall" + "github.com/mysteriumnetwork/node/identity" + "github.com/mysteriumnetwork/node/market" + "github.com/mysteriumnetwork/node/p2p" + "github.com/mysteriumnetwork/node/pb" + "github.com/mysteriumnetwork/node/session" + "github.com/mysteriumnetwork/node/session/connectivity" + "github.com/mysteriumnetwork/node/trace" +) + +type conn struct { + // These are populated by Connect at runtime. + ctx context.Context + ctxLock sync.RWMutex + status connectionstate.Status + statusLock sync.RWMutex + cleanupLock sync.Mutex + cleanup []func() error + cleanupAfterDisconnect []func() error + cleanupFinished chan struct{} + cleanupFinishedLock sync.Mutex + acknowledge func() + cancel func() + channel p2p.Channel + + preReconnect func() + postReconnect func() + + discoLock sync.Mutex + connectOptions ConnectOptions + + activeConnection Connection + // statsTracker statsTracker + + resChannel chan interface{} + + uuid string +} + +type diagConnectionManager struct { + // These are passed on creation. + paymentEngineFactory PaymentEngineFactory + newConnection Creator + eventBus eventbus.EventBus + ipResolver ip.Resolver + locationResolver location.OriginResolver + config Config + statsReportInterval time.Duration + validator validator + p2pDialer p2p.Dialer + timeGetter TimeGetter + + // populated by Connect at runtime. + connsMu sync.Mutex + conns map[string]*conn +} + +// NewDiagManager creates connection manager with given dependencies +func NewDiagManager( + paymentEngineFactory PaymentEngineFactory, + connectionCreator Creator, + eventBus eventbus.EventBus, + ipResolver ip.Resolver, + locationResolver location.OriginResolver, + config Config, + statsReportInterval time.Duration, + validator validator, + p2pDialer p2p.Dialer, + preReconnect, postReconnect func(), +) *diagConnectionManager { + + m := &diagConnectionManager{ + conns: make(map[string]*conn), + + newConnection: connectionCreator, + eventBus: eventBus, + paymentEngineFactory: paymentEngineFactory, + ipResolver: ipResolver, + locationResolver: locationResolver, + config: config, + statsReportInterval: statsReportInterval, + validator: validator, + p2pDialer: p2pDialer, + timeGetter: time.Now, + } + + m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold) + + return m +} + +func (m *diagConnectionManager) chainID() int64 { + return config.GetInt64(config.FlagChainID) +} + +func (m *diagConnectionManager) HasConnection(providerID string) bool { + m.connsMu.Lock() + defer m.connsMu.Unlock() + + _, ok := m.conns[providerID] + return ok +} + +func (m *diagConnectionManager) GetReadyChan(providerID string) chan interface{} { + m.connsMu.Lock() + defer m.connsMu.Unlock() + + con, ok := m.conns[providerID] + if ok { + return con.resChannel + } + return nil +} + +func (m *diagConnectionManager) Connect(consumerID identity.Identity, hermesID common.Address, proposalLookup ProposalLookup, params ConnectParams) (err error) { + var sessionID session.ID + + proposal, err := proposalLookup() + if err != nil { + return fmt.Errorf("failed to lookup proposal: %w", err) + } + + tracer := trace.NewTracer("Consumer whole Connect") + defer func() { + traceResult := tracer.Finish(m.eventBus, string(sessionID)) + log.Debug().Msgf("Consumer connection trace: %s", traceResult) + }() + + fmt.Println("Connect>", proposal.ProviderID) + uuid := proposal.ProviderID + con, ok := m.conns[uuid] + if !ok { + con = new(conn) + con.status.State = connectionstate.NotConnected + con.uuid = uuid + m.conns[uuid] = con + } + removeConnection := func() { + m.connsMu.Lock() + defer m.connsMu.Unlock() + delete(m.conns, uuid) + } + + // make sure cache is cleared when connect terminates at any stage as part of disconnect + // we assume that IPResolver might be used / cache IP before connect + m.addCleanup(con, func() error { + m.clearIPCache() + return nil + }) + + if m.Status().State != connectionstate.NotConnected { + removeConnection() + return ErrAlreadyExists + } + + prc := m.priceFromProposal(*proposal) + + err = m.validator.Validate(m.chainID(), consumerID, prc) + if err != nil { + removeConnection() + return err + } + + con.ctxLock.Lock() + con.ctx, con.cancel = context.WithCancel(context.Background()) + con.ctxLock.Unlock() + + m.statusConnecting(con, consumerID, hermesID, *proposal) + defer func() { + if err != nil { + log.Err(err).Msg("Connect failed, disconnecting") + m.disconnect(con) + } + }() + + con.connectOptions = ConnectOptions{ + ConsumerID: consumerID, + HermesID: hermesID, + Proposal: *proposal, + ProposalLookup: proposalLookup, + Params: params, + } + + con.activeConnection, err = m.newConnection(proposal.ServiceType) + if err != nil { + removeConnection() + return err + } + + sessionID, err = m.initSession(con, tracer, prc) + if err != nil { + removeConnection() + return err + } + + originalPublicIP := m.getPublicIP() + + err = m.startConnection(con, m.currentCtx(con), con.activeConnection, con.activeConnection.Start, con.connectOptions, tracer) + if err != nil { + removeConnection() + return m.handleStartError(con, sessionID, err) + } + + err = m.waitForConnectedState(con, con.activeConnection.State()) + if err != nil { + removeConnection() + return m.handleStartError(con, sessionID, err) + } + + //m.statsTracker = newStatsTracker(m.eventBus, m.statsReportInterval) + //go m.statsTracker.start(m, m.activeConnection) + m.addCleanup(con, func() error { + log.Trace().Msg("Cleaning: stopping statistics publisher") + defer log.Trace().Msg("Cleaning: stopping statistics publisher DONE") + //m.statsTracker.stop() + + removeConnection() + return nil + }) + + con.resChannel = make(chan interface{}) + go Diag(m, con, proposal.ProviderID) + + go m.consumeConnectionStates(con, con.activeConnection.State()) + go m.checkSessionIP(con, con.channel, con.connectOptions.ConsumerID, con.connectOptions.SessionID, originalPublicIP) + + return nil +} + +func (m *diagConnectionManager) autoReconnect(con *conn) (err error) { + var sessionID session.ID + + tracer := trace.NewTracer("Consumer whole autoReconnect") + defer func() { + traceResult := tracer.Finish(m.eventBus, string(sessionID)) + log.Debug().Msgf("Consumer connection trace: %s", traceResult) + }() + + proposal, err := con.connectOptions.ProposalLookup() + if err != nil { + return fmt.Errorf("failed to lookup proposal: %w", err) + } + + con.connectOptions.Proposal = *proposal + + sessionID, err = m.initSession(con, tracer, m.priceFromProposal(con.connectOptions.Proposal)) + if err != nil { + return err + } + + err = m.startConnection(con, m.currentCtx(con), con.activeConnection, con.activeConnection.Reconnect, con.connectOptions, tracer) + if err != nil { + return m.handleStartError(con, sessionID, err) + } + + return nil +} + +func (m *diagConnectionManager) priceFromProposal(proposal proposal.PricedServiceProposal) market.Price { + p := market.Price{ + PricePerHour: proposal.Price.PricePerHour, + PricePerGiB: proposal.Price.PricePerGiB, + } + + if config.GetBool(config.FlagPaymentsDuringSessionDebug) { + log.Info().Msg("Payments debug bas been enabled, will use absurd amounts for the proposal price") + amount := config.GetUInt64(config.FlagPaymentsAmountDuringSessionDebug) + if amount == 0 { + amount = 5000000000000000000 + } + + p = market.Price{ + PricePerHour: new(big.Int).SetUint64(amount), + PricePerGiB: new(big.Int).SetUint64(amount), + } + } + + return p +} + +func (m *diagConnectionManager) initSession(con *conn, tracer *trace.Tracer, prc market.Price) (sessionID session.ID, err error) { + + err = m.createP2PChannel(con, con.connectOptions, tracer) + if err != nil { + return sessionID, fmt.Errorf("could not create p2p channel during connect: %w", err) + } + + con.connectOptions.ProviderNATConn = con.channel.ServiceConn() + con.connectOptions.ChannelConn = con.channel.Conn() + + paymentSession, err := m.paymentLoop(con, con.connectOptions, prc) + if err != nil { + return sessionID, err + } + + sessionDTO, err := m.createP2PSession(con, con.activeConnection, con.connectOptions, tracer, prc) + sessionID = session.ID(sessionDTO.GetID()) + if err != nil { + m.sendSessionStatus(con, con.channel, con.connectOptions.ConsumerID, sessionID, connectivity.StatusSessionEstablishmentFailed, err) + return sessionID, err + } + + traceStart := tracer.StartStage("Consumer session creation (start)") + go m.keepAliveLoop(con, con.channel, sessionID) + m.setStatus(con, func(status *connectionstate.Status) { + status.SessionID = sessionID + }) + m.publishSessionCreate(con, sessionID) + paymentSession.SetSessionID(string(sessionID)) + tracer.EndStage(traceStart) + + con.connectOptions.SessionID = sessionID + con.connectOptions.SessionConfig = sessionDTO.GetConfig() + + return sessionID, nil +} + +func (m *diagConnectionManager) handleStartError(con *conn, sessionID session.ID, err error) error { + + if errors.Is(err, context.Canceled) { + return ErrConnectionCancelled + } + m.addCleanupAfterDisconnect(con, func() error { + return m.sendSessionStatus(con, con.channel, con.connectOptions.ConsumerID, sessionID, connectivity.StatusConnectionFailed, err) + }) + m.publishStateEvent(con, connectionstate.StateConnectionFailed) + + log.Info().Err(err).Msg("Cancelling connection initiation: ") + m.Cancel() + return err +} + +func (m *diagConnectionManager) clearIPCache() { + if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) { + return + } + + if cr, ok := m.ipResolver.(*ip.CachedResolver); ok { + cr.ClearCache() + } +} + +// checkSessionIP checks if IP has changed after connection was established. +func (m *diagConnectionManager) checkSessionIP(con *conn, channel p2p.Channel, consumerID identity.Identity, sessionID session.ID, originalPublicIP string) { + if config.GetBool(config.FlagProxyMode) || config.GetBool(config.FlagDVPNMode) { + return + } + + for i := 1; i <= m.config.IPCheck.MaxAttempts; i++ { + // Skip check if not connected. This may happen when context was canceled via Disconnect. + if m.Status().State != connectionstate.Connected { + return + } + + newPublicIP := m.getPublicIP() + // If ip is changed notify peer that connection is successful. + if originalPublicIP != newPublicIP { + m.sendSessionStatus(con, channel, consumerID, sessionID, connectivity.StatusConnectionOk, nil) + return + } + + // Notify peer and quality oracle that ip is not changed after tunnel connection was established. + if i == m.config.IPCheck.MaxAttempts { + m.sendSessionStatus(con, channel, consumerID, sessionID, connectivity.StatusSessionIPNotChanged, nil) + m.publishStateEvent(con, connectionstate.StateIPNotChanged) + return + } + + time.Sleep(m.config.IPCheck.SleepDurationAfterCheck) + } +} + +// sendSessionStatus sends session connectivity status to other peer. +func (m *diagConnectionManager) sendSessionStatus(con *conn, channel p2p.ChannelSender, consumerID identity.Identity, sessionID session.ID, code connectivity.StatusCode, errDetails error) error { + var errDetailsMsg string + if errDetails != nil { + errDetailsMsg = errDetails.Error() + } + + sessionStatus := &pb.SessionStatus{ + ConsumerID: consumerID.Address, + SessionID: string(sessionID), + Code: uint32(code), + Message: errDetailsMsg, + } + + log.Debug().Msgf("Sending session status P2P message to %q: %s", p2p.TopicSessionStatus, sessionStatus.String()) + + ctx, cancel := context.WithTimeout(m.currentCtx(con), 20*time.Second) + defer cancel() + _, err := channel.Send(ctx, p2p.TopicSessionStatus, p2p.ProtoMessage(sessionStatus)) + if err != nil { + return fmt.Errorf("could not send p2p session status message: %w", err) + } + + return nil +} + +func (m *diagConnectionManager) getPublicIP() string { + currentPublicIP, err := m.ipResolver.GetPublicIP() + if err != nil { + log.Error().Err(err).Msg("Could not get current public IP") + return "" + } + return currentPublicIP +} + +func (m *diagConnectionManager) paymentLoop(con *conn, opts ConnectOptions, price market.Price) (PaymentIssuer, error) { + + payments, err := m.paymentEngineFactory(con.uuid, con.channel, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.HermesID, opts.Proposal, price) + if err != nil { + return nil, err + } + m.addCleanup(con, func() error { + log.Trace().Msg("Cleaning: payments") + defer log.Trace().Msg("Cleaning: payments DONE") + payments.Stop() + return nil + }) + + go func() { + err := payments.Start() + if err != nil { + log.Error().Err(err).Msg("Payment error") + + if config.GetBool(config.FlagKeepConnectedOnFail) { + m.statusOnHold(con) + } else { + err = m.Disconnect() + if err != nil { + log.Error().Err(err).Msg("Could not disconnect gracefully") + } + } + } + }() + return payments, nil +} + +func (m *diagConnectionManager) cleanConnection(con *conn) { + con.cleanupLock.Lock() + defer con.cleanupLock.Unlock() + + for i := len(con.cleanup) - 1; i >= 0; i-- { + log.Trace().Msgf("Connection cleaning up: (%v/%v)", i+1, len(con.cleanup)) + err := con.cleanup[i]() + if err != nil { + log.Warn().Err(err).Msg("Cleanup error") + } + } + con.cleanup = nil +} + +func (m *diagConnectionManager) cleanAfterDisconnect(con *conn) { + con.cleanupLock.Lock() + defer con.cleanupLock.Unlock() + + for i := len(con.cleanupAfterDisconnect) - 1; i >= 0; i-- { + log.Trace().Msgf("Connection cleaning up (after disconnect): (%v/%v)", i+1, len(con.cleanupAfterDisconnect)) + err := con.cleanupAfterDisconnect[i]() + if err != nil { + log.Warn().Err(err).Msg("Cleanup error") + } + } + con.cleanupAfterDisconnect = nil +} + +func (m *diagConnectionManager) createP2PChannel(con *conn, opts ConnectOptions, tracer *trace.Tracer) error { + trace := tracer.StartStage("Consumer P2P channel creation") + defer tracer.EndStage(trace) + + contactDef, err := p2p.ParseContact(opts.Proposal.Contacts) + if err != nil { + return fmt.Errorf("provider does not support p2p communication: %w", err) + } + + timeoutCtx, cancel := context.WithTimeout(m.currentCtx(con), p2pDialTimeout) + defer cancel() + + // TODO register all handlers before channel read/write loops + channel, err := m.p2pDialer.Dial(timeoutCtx, opts.ConsumerID, identity.FromAddress(opts.Proposal.ProviderID), opts.Proposal.ServiceType, contactDef, tracer) + if err != nil { + return fmt.Errorf("p2p dialer failed: %w", err) + } + m.addCleanupAfterDisconnect(con, func() error { + log.Trace().Msg("Cleaning: closing P2P communication channel") + defer log.Trace().Msg("Cleaning: P2P communication channel DONE") + + return channel.Close() + }) + + con.channel = channel + return nil +} + +func (m *diagConnectionManager) addCleanupAfterDisconnect(con *conn, fn func() error) { + con.cleanupLock.Lock() + defer con.cleanupLock.Unlock() + con.cleanupAfterDisconnect = append(con.cleanupAfterDisconnect, fn) +} + +func (m *diagConnectionManager) addCleanup(con *conn, fn func() error) { + con.cleanupLock.Lock() + defer con.cleanupLock.Unlock() + con.cleanup = append(con.cleanup, fn) +} + +func (m *diagConnectionManager) createP2PSession(con *conn, c Connection, opts ConnectOptions, tracer *trace.Tracer, requestedPrice market.Price) (*pb.SessionResponse, error) { + trace := tracer.StartStage("Consumer session creation") + defer tracer.EndStage(trace) + + sessionCreateConfig, err := c.GetConfig() + if err != nil { + return nil, fmt.Errorf("could not get session config: %w", err) + } + + config, err := json.Marshal(sessionCreateConfig) + if err != nil { + return nil, fmt.Errorf("could not marshal session config: %w", err) + } + + sessionRequest := &pb.SessionRequest{ + Consumer: &pb.ConsumerInfo{ + Id: opts.ConsumerID.Address, + HermesID: opts.HermesID.Hex(), + PaymentVersion: "v3", + Location: &pb.LocationInfo{ + Country: m.Status().ConsumerLocation.Country, + }, + Pricing: &pb.Pricing{ + PerGib: requestedPrice.PricePerGiB.Bytes(), + PerHour: requestedPrice.PricePerHour.Bytes(), + }, + }, + ProposalID: opts.Proposal.ID, + Config: config, + } + log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionCreate, sessionRequest.String()) + ctx, cancel := context.WithTimeout(m.currentCtx(con), 20*time.Second) + defer cancel() + res, err := con.channel.Send(ctx, p2p.TopicSessionCreate, p2p.ProtoMessage(sessionRequest)) + if err != nil { + return nil, fmt.Errorf("could not send p2p session create request: %w", err) + } + + var sessionResponse pb.SessionResponse + err = res.UnmarshalProto(&sessionResponse) + if err != nil { + return nil, fmt.Errorf("could not unmarshal session reply to proto: %w", err) + } + + channel := con.channel + con.acknowledge = func() { + pc := &pb.SessionInfo{ + ConsumerID: opts.ConsumerID.Address, + SessionID: sessionResponse.GetID(), + } + log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionAcknowledge, pc.String()) + ctx, cancel := context.WithTimeout(context.Background(), 20*time.Second) + defer cancel() + _, err := channel.Send(ctx, p2p.TopicSessionAcknowledge, p2p.ProtoMessage(pc)) + if err != nil { + log.Warn().Err(err).Msg("Acknowledge failed") + } + } + m.addCleanupAfterDisconnect(con, func() error { + log.Trace().Msg("Cleaning: requesting session destroy") + defer log.Trace().Msg("Cleaning: requesting session destroy DONE") + + sessionDestroy := &pb.SessionInfo{ + ConsumerID: opts.ConsumerID.Address, + SessionID: sessionResponse.GetID(), + } + + log.Debug().Msgf("Sending P2P message to %q: %s", p2p.TopicSessionDestroy, sessionDestroy.String()) + ctx, cancel := context.WithTimeout(context.Background(), 1*time.Second) + defer cancel() + _, err := con.channel.Send(ctx, p2p.TopicSessionDestroy, p2p.ProtoMessage(sessionDestroy)) + if err != nil { + return fmt.Errorf("could not send session destroy request: %w", err) + } + + return nil + }) + + return &sessionResponse, nil +} + +func (m *diagConnectionManager) publishSessionCreate(con *conn, sessionID session.ID) { + sessionInfo := m.Status() + // avoid printing IP address in logs + sessionInfo.ConsumerLocation.IP = "" + + m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{ + Status: connectionstate.SessionCreatedStatus, + SessionInfo: sessionInfo, + }) + + m.addCleanup(con, func() error { + log.Trace().Msg("Cleaning: publishing session ended status") + defer log.Trace().Msg("Cleaning: publishing session ended status DONE") + + sessionInfo := m.Status() + // avoid printing IP address in logs + sessionInfo.ConsumerLocation.IP = "" + + m.eventBus.Publish(connectionstate.AppTopicConnectionSession, connectionstate.AppEventConnectionSession{ + Status: connectionstate.SessionEndedStatus, + SessionInfo: sessionInfo, + }) + return nil + }) +} + +func (m *diagConnectionManager) startConnection(con *conn, ctx context.Context, conn Connection, start ConnectionStart, connectOptions ConnectOptions, tracer *trace.Tracer) (err error) { + trace := tracer.StartStage("Consumer start connection") + defer tracer.EndStage(trace) + + if err = start(ctx, connectOptions); err != nil { + return err + } + m.addCleanup(con, func() error { + log.Trace().Msg("Cleaning: stopping connection") + defer log.Trace().Msg("Cleaning: stopping connection DONE") + conn.Stop() + return nil + }) + + err = m.setupTrafficBlock(con, connectOptions.Params.DisableKillSwitch) + if err != nil { + return err + } + + // Clear IP cache so session IP check can report that IP has really changed. + m.clearIPCache() + + return nil +} + +func (m *diagConnectionManager) Status() connectionstate.Status { + log.Debug().Msg("Status() - not used") + return connectionstate.Status{State: connectionstate.NotConnected} +} + +func (m *diagConnectionManager) UUID() string { + log.Debug().Msg("UUID() - not used") + return "" +} + +func (m *diagConnectionManager) Stats() connectionstate.Statistics { + log.Debug().Msg("Stats() - not used") + return connectionstate.Statistics{} +} + +func (m *diagConnectionManager) setStatus(con *conn, delta func(status *connectionstate.Status)) { + con.statusLock.Lock() + stateWas := con.status.State + + delta(&con.status) + + state := con.status.State + con.statusLock.Unlock() + + if state != stateWas { + log.Info().Msgf("Connection state: %v -> %v", stateWas, state) + m.publishStateEvent(con, state) + } +} + +func (m *diagConnectionManager) statusConnecting(con *conn, consumerID identity.Identity, accountantID common.Address, proposal proposal.PricedServiceProposal) { + m.setStatus(con, func(status *connectionstate.Status) { + *status = connectionstate.Status{ + StartedAt: m.timeGetter(), + ConsumerID: consumerID, + ConsumerLocation: m.locationResolver.GetOrigin(), + HermesID: accountantID, + Proposal: proposal, + State: connectionstate.Connecting, + } + }) +} + +func (m *diagConnectionManager) statusConnected(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.Connected + }) +} + +func (m *diagConnectionManager) statusReconnecting(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.Reconnecting + }) +} + +func (m *diagConnectionManager) statusNotConnected(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.NotConnected + }) +} + +func (m *diagConnectionManager) statusDisconnecting(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.Disconnecting + }) +} + +func (m *diagConnectionManager) statusCanceled(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.Canceled + }) +} + +func (m *diagConnectionManager) statusOnHold(con *conn) { + m.setStatus(con, func(status *connectionstate.Status) { + status.State = connectionstate.StateOnHold + }) +} + +func (m *diagConnectionManager) Cancel() { + log.Error().Msg("Cancel() - not used") +} + +func (m *diagConnectionManager) DisconnectSingle(con *conn) error { + m.statusDisconnecting(con) + m.disconnect(con) + return nil +} + +func (m *diagConnectionManager) Disconnect() error { + log.Error().Msg("Disconnect() - not used") + return nil +} + +func (m *diagConnectionManager) CheckChannel(con *conn, ctx context.Context) error { + if err := m.sendKeepAlivePing(ctx, con.channel, m.Status().SessionID); err != nil { + return fmt.Errorf("keep alive ping failed: %w", err) + } + return nil +} + +func (m *diagConnectionManager) disconnect(con *conn) { + con.discoLock.Lock() + defer con.discoLock.Unlock() + + con.cleanupFinishedLock.Lock() + defer con.cleanupFinishedLock.Unlock() + con.cleanupFinished = make(chan struct{}) + defer close(con.cleanupFinished) + + con.ctxLock.Lock() + con.cancel() + con.ctxLock.Unlock() + + m.cleanConnection(con) + m.statusNotConnected(con) + + m.cleanAfterDisconnect(con) +} + +func (m *diagConnectionManager) waitForConnectedState(con *conn, stateChannel <-chan connectionstate.State) error { + log.Debug().Msg("waiting for connected state") + for { + select { + case state, more := <-stateChannel: + if !more { + return ErrConnectionFailed + } + + switch state { + case connectionstate.Connected: + log.Debug().Msg("Connected started event received") + if con.acknowledge != nil { + go con.acknowledge() + } + m.onStateChanged(con, state) + return nil + default: + m.onStateChanged(con, state) + } + case <-m.currentCtx(con).Done(): + return m.currentCtx(con).Err() + } + } +} + +func (m *diagConnectionManager) consumeConnectionStates(con *conn, stateChannel <-chan connectionstate.State) { + for state := range stateChannel { + m.onStateChanged(con, state) + } +} + +func (m *diagConnectionManager) onStateChanged(con *conn, state connectionstate.State) { + log.Debug().Msgf("Connection state received: %s", state) + + // React just to certain stains from connection. Because disconnect happens in connectionWaiter + switch state { + case connectionstate.Connected: + m.statusConnected(con) + case connectionstate.Reconnecting: + m.statusReconnecting(con) + } +} + +func (m *diagConnectionManager) setupTrafficBlock(con *conn, disableKillSwitch bool) error { + if disableKillSwitch { + return nil + } + + outboundIP, err := m.ipResolver.GetOutboundIP() + if err != nil { + return err + } + + removeRule, err := firewall.BlockNonTunnelTraffic(firewall.Session, outboundIP) + if err != nil { + return err + } + m.addCleanup(con, func() error { + log.Trace().Msg("Cleaning: traffic block rule") + defer log.Trace().Msg("Cleaning: traffic block rule DONE") + + removeRule() + + return nil + }) + return nil +} + +func (m *diagConnectionManager) reconnectOnHold(state connectionstate.AppEventConnectionState) { + if state.State != connectionstate.StateOnHold || !config.GetBool(config.FlagAutoReconnect) { + return + } + + con, ok := m.conns[state.UUID] + if !ok { + return + } + + if con.channel != nil { + con.channel.Close() + } + + con.preReconnect() + m.clearIPCache() + + for err := m.autoReconnect(con); err != nil; err = m.autoReconnect(con) { + select { + case <-m.currentCtx(con).Done(): + log.Info().Err(m.currentCtx(con).Err()).Msg("Stopping reconnect") + return + default: + log.Error().Err(err).Msg("Failed to reconnect active session, will try again") + } + } + con.postReconnect() +} + +func (m *diagConnectionManager) publishStateEvent(con *conn, state connectionstate.State) { + sessionInfo := m.Status() + // avoid printing IP address in logs + sessionInfo.ConsumerLocation.IP = "" + + m.eventBus.Publish(connectionstate.AppTopicConnectionState, connectionstate.AppEventConnectionState{ + UUID: con.uuid, + State: state, + SessionInfo: sessionInfo, + }) +} + +func (m *diagConnectionManager) keepAliveLoop(con *conn, channel p2p.Channel, sessionID session.ID) { + // Register handler for handling p2p keep alive pings from provider. + channel.Handle(p2p.TopicKeepAlive, func(c p2p.Context) error { + var ping pb.P2PKeepAlivePing + if err := c.Request().UnmarshalProto(&ping); err != nil { + return err + } + + log.Debug().Msgf("Received p2p keepalive ping with SessionID=%s from %s", ping.SessionID, c.PeerID().ToCommonAddress()) + return c.OK() + }) + + // Send pings to provider. + var errCount int + for { + select { + case <-m.currentCtx(con).Done(): + log.Debug().Msgf("Stopping p2p keepalive: %v", m.currentCtx(con).Err()) + return + case <-time.After(m.config.KeepAlive.SendInterval): + ctx, cancel := context.WithTimeout(context.Background(), m.config.KeepAlive.SendTimeout) + if err := m.sendKeepAlivePing(ctx, channel, sessionID); err != nil { + log.Err(err).Msgf("Failed to send p2p keepalive ping. SessionID=%s", sessionID) + errCount++ + if errCount == m.config.KeepAlive.MaxSendErrCount { + log.Error().Msgf("Max p2p keepalive err count reached, disconnecting. SessionID=%s", sessionID) + if config.GetBool(config.FlagKeepConnectedOnFail) { + m.statusOnHold(con) + } else { + m.Disconnect() + } + cancel() + return + } + } else { + errCount = 0 + } + cancel() + } + } +} + +func (m *diagConnectionManager) sendKeepAlivePing(ctx context.Context, channel p2p.Channel, sessionID session.ID) error { + msg := &pb.P2PKeepAlivePing{ + SessionID: string(sessionID), + } + + start := time.Now() + _, err := channel.Send(ctx, p2p.TopicKeepAlive, p2p.ProtoMessage(msg)) + if err != nil { + return err + } + + _ = start + m.eventBus.Publish(quality.AppTopicConsumerPingP2P, quality.PingEvent{ + SessionID: string(sessionID), + Duration: time.Since(start), + }) + + return nil +} + +func (m *diagConnectionManager) currentCtx(con *conn) context.Context { + con.ctxLock.RLock() + defer con.ctxLock.RUnlock() + + return con.ctx +} + +func (m *diagConnectionManager) Reconnect() { + log.Error().Msg("Reconnect - not used") +} diff --git a/core/connection/manager.go b/core/connection/manager.go index 15ce1bca25..e87d5a32ad 100644 --- a/core/connection/manager.go +++ b/core/connection/manager.go @@ -170,8 +170,6 @@ type connectionManager struct { statsTracker statsTracker uuid string - - provChecker *ProviderChecker } // NewManager creates connection manager with given dependencies @@ -186,7 +184,6 @@ func NewManager( validator validator, p2pDialer p2p.Dialer, preReconnect, postReconnect func(), - provChecker *ProviderChecker, ) *connectionManager { uuid, err := uuid.NewV4() if err != nil { @@ -210,7 +207,6 @@ func NewManager( preReconnect: preReconnect, postReconnect: postReconnect, uuid: uuid.String(), - provChecker: provChecker, } m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold) @@ -305,10 +301,6 @@ func (m *connectionManager) Connect(consumerID identity.Identity, hermesID commo return nil }) - if m.provChecker != nil { - go m.provChecker.Diag(m, proposal.ProviderID) - } - go m.consumeConnectionStates(m.activeConnection.State()) go m.checkSessionIP(m.channel, m.connectOptions.ConsumerID, m.connectOptions.SessionID, originalPublicIP) diff --git a/core/connection/manager_test.go b/core/connection/manager_test.go index 92ba87f454..61f7a5e5b4 100644 --- a/core/connection/manager_test.go +++ b/core/connection/manager_test.go @@ -61,7 +61,6 @@ type testContext struct { statsReportInterval time.Duration mockP2P *mockP2PDialer mockTime time.Time - provChecker *ProviderChecker sync.RWMutex } @@ -142,7 +141,6 @@ func (tc *testContext) SetupTest() { tc.mockP2P = &mockP2PDialer{&mockP2PChannel{}} tc.mockTime = time.Date(2000, time.January, 0, 10, 12, 3, 0, time.UTC) - tc.provChecker = NewProviderChecker(tc.stubPublisher) tc.connManager = NewManager( func(senderUUID string, channel p2p.Channel, @@ -162,7 +160,6 @@ func (tc *testContext) SetupTest() { &mockValidator{}, tc.mockP2P, func() {}, func() {}, - tc.provChecker, ) tc.connManager.timeGetter = func() time.Time { return tc.mockTime diff --git a/core/connection/pinger.go b/core/connection/pinger.go index 9ed462507d..6ececc5437 100644 --- a/core/connection/pinger.go +++ b/core/connection/pinger.go @@ -19,32 +19,20 @@ package connection import ( "github.com/mysteriumnetwork/node/core/quality" - "github.com/mysteriumnetwork/node/eventbus" "github.com/rs/zerolog/log" ) -// ProviderChecker is a service for provider testing -type ProviderChecker struct { - bus eventbus.Publisher -} - -// NewProviderChecker is a ProviderChecker constructor -func NewProviderChecker(bus eventbus.Publisher) *ProviderChecker { - return &ProviderChecker{ - bus: bus, - } -} - // Diag is used to start provider check -func (p *ProviderChecker) Diag(cm *connectionManager, providerID string) { - c, ok := cm.activeConnection.(ConnectionDiag) +func Diag(cm *diagConnectionManager, con *conn, providerID string) { + c, ok := con.activeConnection.(ConnectionDiag) res := false if ok { log.Debug().Msgf("Check provider> %v", providerID) res = c.Diag() - cm.Disconnect() + cm.DisconnectSingle(con) } ev := quality.DiagEvent{ProviderID: providerID, Result: res} - p.bus.Publish(quality.AppTopicConnectionDiagRes, ev) + con.resChannel <- ev + close(con.resChannel) } diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index 3435fa0178..3ff08bc4f8 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -19,11 +19,9 @@ package endpoints import ( "fmt" - "strconv" "github.com/ethereum/go-ethereum/common" "github.com/gin-gonic/gin" - "github.com/gofrs/uuid" "github.com/pkg/errors" "github.com/rs/zerolog/log" @@ -43,7 +41,7 @@ import ( // ConnectionDiagEndpoint struct represents /connection resource and it's subresources type ConnectionDiagEndpoint struct { - manager connection.MultiManager + manager connection.DiagManager publisher eventbus.Publisher subscriber eventbus.Subscriber @@ -56,7 +54,7 @@ type ConnectionDiagEndpoint struct { } // NewConnectionDiagEndpoint creates and returns connection endpoint -func NewConnectionDiagEndpoint(manager connection.MultiManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint { +func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint { return &ConnectionDiagEndpoint{ manager: manager, publisher: publisher, @@ -69,44 +67,9 @@ func NewConnectionDiagEndpoint(manager connection.MultiManager, stateProvider st } } -// Status returns result of provider check -// swagger:operation GET /prov-checker ConnectionDiagInfoDTO -// -// --- -// summary: Returns connection status -// description: Returns status of current connection -// responses: -// 200: -// description: Status -// schema: -// "$ref": "#/definitions/ConnectionInfoDTO" -// 400: -// description: Failed to parse or request validation failed -// schema: -// "$ref": "#/definitions/APIError" -// 500: -// description: Internal server error -// schema: -// "$ref": "#/definitions/APIError" -func (ce *ConnectionDiagEndpoint) Status(c *gin.Context) { - n := 0 - id := c.Query("id") - if len(id) > 0 { - var err error - n, err = strconv.Atoi(id) - if err != nil { - c.Error(apierror.ParseFailed()) - return - } - } - status := ce.manager.Status(n) - statusResponse := contract.NewConnectionInfoDTO(status) - utils.WriteAsJSON(statusResponse, c.Writer) -} - // Diag is used to start provider check func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { - log.Error().Msgf("Diag >>>") + log.Debug().Msgf("Diag >>>") chainID := config.GetInt64(config.FlagChainID) consumerID_, err := ce.identitySelector.UseOrCreate(config.FlagIdentity.Value, config.FlagIdentityPassphrase.Value, chainID) @@ -178,23 +141,12 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { } proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) - res := make(chan bool) - cb := func(r quality.DiagEvent) { - if r.ProviderID == prov { - res <- r.Result - } - } - - uid, err := uuid.NewV4() - if err != nil { - log.Error().Msgf("Error > %v", err) - c.Error(err) + hasConnection := ce.manager.HasConnection(cr.ProviderID) + if hasConnection { + c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) return } - ce.subscriber.SubscribeWithUID(quality.AppTopicConnectionDiagRes, uid.String(), cb) - defer ce.subscriber.UnsubscribeWithUID(quality.AppTopicConnectionDiagRes, uid.String(), cb) - err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) if err != nil { switch err { @@ -206,22 +158,23 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { log.Error().Err(err).Msg("Failed to connect") c.Error(apierror.Internal("Failed to connect: "+err.Error(), contract.ErrCodeConnect)) } - return } - r := <-res - log.Debug().Msgf("Result > %v", r) + resChannel := ce.manager.GetReadyChan(cr.ProviderID) + res := <-resChannel + log.Error().Msgf("Result > %v", res) + resp := contract.ConnectionDiagInfoDTO{ ProviderID: prov, - Status: r, + Status: res.(quality.DiagEvent).Result, } utils.WriteAsJSON(resp, c.Writer) } // AddRoutesForConnectionDiag adds proder check route to given router func AddRoutesForConnectionDiag( - manager connection.MultiManager, + manager connection.DiagManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, From 838c712c57c974cf881c897dfbf828435de5c62b Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Thu, 23 May 2024 02:15:43 +0400 Subject: [PATCH 3/6] Add batch mode for provider diagnostic endpoint Signed-off-by: Anton Litvinov --- cmd/bootstrap.go | 96 +++++------ cmd/di.go | 4 - core/connection/manager-diag.go | 1 - core/connection/manager_test.go | 1 - core/quality/metrics.go | 3 - services/wireguard/connection/connection.go | 2 - services/wireguard/endpoint/endpoint.go | 2 - tequilapi/endpoints/connection-diag.go | 171 +++++++++++++++++--- 8 files changed, 202 insertions(+), 78 deletions(-) diff --git a/cmd/bootstrap.go b/cmd/bootstrap.go index b0bae35d4e..5024f8fdde 100644 --- a/cmd/bootstrap.go +++ b/cmd/bootstrap.go @@ -45,56 +45,60 @@ func (di *Dependencies) bootstrapTequilapi(nodeOptions node.Options, listener ne } tequilaApiClient := tequilapi_client.NewClient(nodeOptions.TequilapiAddress, nodeOptions.TequilapiPort) + handlers := []func(engine *gin.Engine) error{ + func(e *gin.Engine) error { + if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil { + return err + } + return nil + }, + func(e *gin.Engine) error { + if config.GetBool(config.FlagPProfEnable) { + tequilapi_endpoints.AddRoutesForPProf(e) + } + return nil + }, + func(e *gin.Engine) error { + e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck) + return nil + }, + tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)), + tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), + tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), + tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), + tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), + tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), + tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), + tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient), + tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)), + tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber), + tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)), + tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker), + tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI), + tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator), + tequilapi_endpoints.AddRoutesForConfig, + tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator), + tequilapi_endpoints.AddRoutesForFeedback(di.Reporter), + tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage), + tequilapi_endpoints.AddRoutesForDocs, + tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI), + tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver), + tequilapi_endpoints.AddRoutesForTerms, + tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator( + config.FlagPaymentPriceGiB.Value, + config.FlagPaymentPriceHour.Value, + )), + tequilapi_endpoints.AddRoutesForValidator, + } + if nodeOptions.ProvChecker { + handlers = append(handlers, tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions)) + } + return tequilapi.NewServer( listener, nodeOptions, di.JWTAuthenticator, - []func(engine *gin.Engine) error{ - func(e *gin.Engine) error { - if err := tequilapi_endpoints.AddRoutesForSSE(e, di.StateKeeper, di.EventBus); err != nil { - return err - } - return nil - }, - func(e *gin.Engine) error { - if config.GetBool(config.FlagPProfEnable) { - tequilapi_endpoints.AddRoutesForPProf(e) - } - return nil - }, - func(e *gin.Engine) error { - e.GET("/healthcheck", tequilapi_endpoints.HealthCheckEndpointFactory(time.Now, os.Getpid).HealthCheck) - return nil - }, - tequilapi_endpoints.AddRouteForStop(utils.SoftKiller(di.Shutdown)), - tequilapi_endpoints.AddRoutesForAuthentication(di.Authenticator, di.JWTAuthenticator, di.SSOMystnodes), - tequilapi_endpoints.AddRoutesForIdentities(di.IdentityManager, di.IdentitySelector, di.IdentityRegistry, di.ConsumerBalanceTracker, di.AddressProvider, di.HermesChannelRepository, di.BCHelper, di.Transactor, di.BeneficiaryProvider, di.IdentityMover, di.BeneficiaryAddressStorage, di.HermesMigrator), - tequilapi_endpoints.AddRoutesForConnection(di.MultiConnectionManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.AddressProvider), - tequilapi_endpoints.AddRoutesForConnectionDiag(di.MultiConnectionDiagManager, di.StateKeeper, di.ProposalRepository, di.IdentityRegistry, di.EventBus, di.EventBus, di.AddressProvider, di.IdentitySelector, nodeOptions), - tequilapi_endpoints.AddRoutesForSessions(di.SessionStorage), - tequilapi_endpoints.AddRoutesForConnectionLocation(di.IPResolver, di.LocationResolver, di.LocationResolver), - tequilapi_endpoints.AddRoutesForProposals(di.ProposalRepository, di.PricingHelper, di.LocationResolver, di.FilterPresetStorage, di.NATProber), - tequilapi_endpoints.AddRoutesForService(di.ServicesManager, services.JSONParsersByType, di.ProposalRepository, tequilaApiClient), - tequilapi_endpoints.AddRoutesForAccessPolicies(di.HTTPClient, config.GetString(config.FlagAccessPolicyAddress)), - tequilapi_endpoints.AddRoutesForNAT(di.StateKeeper, di.NATProber), - tequilapi_endpoints.AddRoutesForNodeUI(versionmanager.NewVersionManager(di.UIServer, di.HTTPClient, di.uiVersionConfig)), - tequilapi_endpoints.AddRoutesForNode(di.NodeStatusTracker, di.NodeStatsTracker), - tequilapi_endpoints.AddRoutesForTransactor(di.IdentityRegistry, di.Transactor, di.Affiliator, di.HermesPromiseSettler, di.SettlementHistoryStorage, di.AddressProvider, di.BeneficiaryProvider, di.BeneficiarySaver, di.PilvytisAPI), - tequilapi_endpoints.AddRoutesForAffiliator(di.Affiliator), - tequilapi_endpoints.AddRoutesForConfig, - tequilapi_endpoints.AddRoutesForMMN(di.MMN, di.SSOMystnodes, di.Authenticator), - tequilapi_endpoints.AddRoutesForFeedback(di.Reporter), - tequilapi_endpoints.AddRoutesForConnectivityStatus(di.SessionConnectivityStatusStorage), - tequilapi_endpoints.AddRoutesForDocs, - tequilapi_endpoints.AddRoutesForCurrencyExchange(di.PilvytisAPI), - tequilapi_endpoints.AddRoutesForPilvytis(di.PilvytisAPI, di.PilvytisOrderIssuer, di.LocationResolver), - tequilapi_endpoints.AddRoutesForTerms, - tequilapi_endpoints.AddEntertainmentRoutes(entertainment.NewEstimator( - config.FlagPaymentPriceGiB.Value, - config.FlagPaymentPriceHour.Value, - )), - tequilapi_endpoints.AddRoutesForValidator, - }, + handlers, ) } diff --git a/cmd/di.go b/cmd/di.go index 726fbfa90c..114b9f936a 100644 --- a/cmd/di.go +++ b/cmd/di.go @@ -952,10 +952,6 @@ func (di *Dependencies) bootstrapQualityComponents(options node.OptionsQuality, return err } - if nodeOptions.ProvChecker { - // di.provPinger = connection.NewProviderChecker(di.EventBus) - } - return nil } diff --git a/core/connection/manager-diag.go b/core/connection/manager-diag.go index 1a984877db..c1f7a2abef 100644 --- a/core/connection/manager-diag.go +++ b/core/connection/manager-diag.go @@ -957,7 +957,6 @@ func (m *diagConnectionManager) sendKeepAlivePing(ctx context.Context, channel p return err } - _ = start m.eventBus.Publish(quality.AppTopicConsumerPingP2P, quality.PingEvent{ SessionID: string(sessionID), Duration: time.Since(start), diff --git a/core/connection/manager_test.go b/core/connection/manager_test.go index 61f7a5e5b4..4e7ec1da1d 100644 --- a/core/connection/manager_test.go +++ b/core/connection/manager_test.go @@ -61,7 +61,6 @@ type testContext struct { statsReportInterval time.Duration mockP2P *mockP2PDialer mockTime time.Time - sync.RWMutex } diff --git a/core/quality/metrics.go b/core/quality/metrics.go index 9beedfd6bb..75ac4268e0 100644 --- a/core/quality/metrics.go +++ b/core/quality/metrics.go @@ -117,7 +117,4 @@ const ( // AppTopicProviderPingP2P represents event bus topic for provider p2p pings to consumer. AppTopicProviderPingP2P = "provider_ping_p2p" - - // AppTopicConnectionDiagRes represents event bus topic for provider check result. - AppTopicConnectionDiagRes = "connection_diag" ) diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go index 7909dfd815..690226df04 100644 --- a/services/wireguard/connection/connection.go +++ b/services/wireguard/connection/connection.go @@ -115,8 +115,6 @@ func (c *Connection) Reconnect(ctx context.Context, options connection.ConnectOp } func (c *Connection) start(ctx context.Context, start startConn, options connection.ConnectOptions) (err error) { - log.Info().Msg("+++++++++++++++++++++++++++++++++++++++++++++++++++++ *Connection) start") - var config wg.ServiceConfig if err = json.Unmarshal(options.SessionConfig, &config); err != nil { return errors.Wrap(err, "failed to unmarshal connection config") diff --git a/services/wireguard/endpoint/endpoint.go b/services/wireguard/endpoint/endpoint.go index c9b8c7bc73..ad1d1ddd28 100644 --- a/services/wireguard/endpoint/endpoint.go +++ b/services/wireguard/endpoint/endpoint.go @@ -88,8 +88,6 @@ func (ce *connectionEndpoint) StartConsumerMode(cfg wgcfg.DeviceConfig) error { } return errors.Wrap(err, "could not configure device") } - - // ce.wgClient.Diag() return nil } diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index 3ff08bc4f8..5c144a2000 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -19,11 +19,13 @@ package endpoints import ( "fmt" + "sort" "github.com/ethereum/go-ethereum/common" "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/rs/zerolog/log" + "gvisor.dev/gvisor/pkg/sync" "github.com/mysteriumnetwork/go-rest/apierror" "github.com/mysteriumnetwork/node/config" @@ -51,11 +53,13 @@ type ConnectionDiagEndpoint struct { identityRegistry identityRegistry addressProvider addressProvider identitySelector selector.Handler + + consumerAddress string } // NewConnectionDiagEndpoint creates and returns connection endpoint func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider stateProvider, proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler) *ConnectionDiagEndpoint { - return &ConnectionDiagEndpoint{ + ce := &ConnectionDiagEndpoint{ manager: manager, publisher: publisher, subscriber: subscriber, @@ -65,19 +69,153 @@ func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider sta addressProvider: addressProvider, identitySelector: identitySelector, } -} - -// Diag is used to start provider check -func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { - log.Debug().Msgf("Diag >>>") chainID := config.GetInt64(config.FlagChainID) consumerID_, err := ce.identitySelector.UseOrCreate(config.FlagIdentity.Value, config.FlagIdentityPassphrase.Value, chainID) if err != nil { - c.Error(apierror.Internal("Failed to unlock identity", err.Error())) + panic(err) + } + log.Error().Msgf("Unlocked identity: %v", consumerID_.Address) + ce.consumerAddress = consumerID_.Address + + return ce +} + +func dedupeSortedStrings(s []string) []string { + if len(s) < 2 { + return s + } + var e = 1 + for i := 1; i < len(s); i++ { + if s[i] == s[i-1] { + continue + } + s[e] = s[i] + e++ + } + + return s[:e] +} + +// DiagBatch is used to start a given providers check (batch mode) +func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { + hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) + if err != nil { + c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) return } - log.Error().Msgf("Unlocked identity: %v", consumerID_) + + provs := make([]string, 0) + c.Bind(&provs) + sort.Strings(provs) + provs = dedupeSortedStrings(provs) + + var ( + wg sync.WaitGroup + mu sync.Mutex + ) + resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(provs)) + wg.Add(len(provs)) + + for _, prov := range provs { + go func(prov string) { + result := contract.ConnectionDiagInfoDTO{ + ProviderID: prov, + } + defer func() { + mu.Lock() + resultMap[prov] = result + mu.Unlock() + + wg.Done() + }() + + cr := &contract.ConnectionCreateRequest{ + ConsumerID: ce.consumerAddress, + ProviderID: prov, + Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, + HermesID: hermes.Hex(), + ServiceType: "wireguard", + ConnectOptions: contract.ConnectOptions{}, + } + if err := cr.Validate(); err != nil { + result.Error = err + return + } + + consumerID := identity.FromAddress(cr.ConsumerID) + status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) + if err != nil { + log.Error().Err(err).Stack().Msg("Could not check registration status") + result.Error = contract.ErrCodeIDRegistrationCheck + return + } + switch status { + case registry.Unregistered, registry.RegistrationError, registry.Unknown: + log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) + result.Error = contract.ErrCodeIDNotRegistered + return + case registry.InProgress: + log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) + case registry.Registered: + log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) + default: + log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) + result.Error = contract.ErrCodeIDStatusUnknown + return + } + + if len(cr.ProviderID) > 0 { + cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) + } + f := &proposal.Filter{ + ServiceType: cr.ServiceType, + LocationCountry: cr.Filter.CountryCode, + ProviderIDs: cr.Filter.Providers, + IPType: cr.Filter.IPType, + IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, + AccessPolicy: "all", + } + proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + + if ce.manager.HasConnection(cr.ProviderID) { + result.Error = contract.ErrCodeConnectionAlreadyExists + return + } + + err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) + if err != nil { + switch err { + case connection.ErrAlreadyExists: + result.Error = contract.ErrCodeConnectionAlreadyExists + case connection.ErrConnectionCancelled: + result.Error = contract.ErrCodeConnectionCancelled + default: + log.Error().Err(err).Msgf("Failed to connect: %v", prov) + result.Error = contract.ErrCodeConnect + } + return + } + + resChannel := ce.manager.GetReadyChan(cr.ProviderID) + res := <-resChannel + log.Error().Msgf("Result > %v", res) + result.Status = res.(quality.DiagEvent).Result + + }(prov) + } + wg.Wait() + + out := make([]contract.ConnectionDiagInfoDTO, 0) + for _, prov := range provs { + out = append(out, resultMap[prov]) + } + utils.WriteAsJSON(out, c.Writer) +} + +// Diag is used to start a given provider check +func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { + log.Debug().Msgf("Diag >>>") hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) if err != nil { @@ -91,14 +229,13 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { return } cr := &contract.ConnectionCreateRequest{ - ConsumerID: consumerID_.Address, + ConsumerID: ce.consumerAddress, ProviderID: prov, Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, HermesID: hermes.Hex(), ServiceType: "wireguard", ConnectOptions: contract.ConnectOptions{}, } - if err := cr.Validate(); err != nil { c.Error(err) return @@ -111,7 +248,6 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck)) return } - switch status { case registry.Unregistered, registry.RegistrationError, registry.Unknown: log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) @@ -130,7 +266,6 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { if len(cr.ProviderID) > 0 { cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) } - f := &proposal.Filter{ ServiceType: cr.ServiceType, LocationCountry: cr.Filter.CountryCode, @@ -141,8 +276,7 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { } proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) - hasConnection := ce.manager.HasConnection(cr.ProviderID) - if hasConnection { + if ce.manager.HasConnection(cr.ProviderID) { c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) return } @@ -179,18 +313,17 @@ func AddRoutesForConnectionDiag( proposalRepository proposalRepository, identityRegistry identityRegistry, publisher eventbus.Publisher, - publisher2 eventbus.Subscriber, + subscriber eventbus.Subscriber, addressProvider addressProvider, identitySelector selector.Handler, options node.Options, ) func(*gin.Engine) error { - ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, publisher2, addressProvider, identitySelector) + ConnectionDiagEndpoint := NewConnectionDiagEndpoint(manager, stateProvider, proposalRepository, identityRegistry, publisher, subscriber, addressProvider, identitySelector) return func(e *gin.Engine) error { connGroup := e.Group("") { - if options.ProvChecker { - connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) - } + connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) + connGroup.POST("/prov-checker-batch", ConnectionDiagEndpoint.DiagBatch) } return nil } From e5781361a9379a04e7ce960f2043d285470fe2b0 Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Mon, 15 Jul 2024 12:36:56 +0400 Subject: [PATCH 4/6] Make batch mode usable for global scan of nodes Signed-off-by: Anton Litvinov --- core/connection/interface.go | 2 +- core/connection/manager-diag.go | 22 +- core/connection/pinger.go | 7 +- core/quality/metrics.go | 1 + go.mod | 10 +- go.sum | 20 ++ services/wireguard/connection/connection.go | 2 +- .../wireguard/connection/connection_test.go | 4 +- services/wireguard/endpoint.go | 3 +- .../wireguard/endpoint/diagclient/client.go | 17 +- services/wireguard/endpoint/endpoint.go | 4 +- services/wireguard/endpoint/wg_client.go | 2 +- tequilapi/contract/connection.go | 6 +- tequilapi/endpoints/connection-diag.go | 199 +++++++++++++++++- 14 files changed, 267 insertions(+), 32 deletions(-) diff --git a/core/connection/interface.go b/core/connection/interface.go index dbbb8b2245..bcf6a4b24c 100644 --- a/core/connection/interface.go +++ b/core/connection/interface.go @@ -41,7 +41,7 @@ type Connection interface { // ConnectionDiag is a specialised Connection interface for provider check type ConnectionDiag interface { - Diag() bool + Diag() error } // StateChannel is the channel we receive state change events on diff --git a/core/connection/manager-diag.go b/core/connection/manager-diag.go index c1f7a2abef..d8a061a047 100644 --- a/core/connection/manager-diag.go +++ b/core/connection/manager-diag.go @@ -28,6 +28,7 @@ import ( "github.com/ethereum/go-ethereum/common" "github.com/rs/zerolog/log" + "golang.org/x/time/rate" "github.com/mysteriumnetwork/node/config" "github.com/mysteriumnetwork/node/core/connection/connectionstate" @@ -91,6 +92,8 @@ type diagConnectionManager struct { // populated by Connect at runtime. connsMu sync.Mutex conns map[string]*conn + + ratelimiter *rate.Limiter } // NewDiagManager creates connection manager with given dependencies @@ -120,6 +123,8 @@ func NewDiagManager( validator: validator, p2pDialer: p2pDialer, timeGetter: time.Now, + + ratelimiter: rate.NewLimiter(rate.Every(1000*time.Millisecond), 1), } m.eventBus.SubscribeAsync(connectionstate.AppTopicConnectionState, m.reconnectOnHold) @@ -153,6 +158,13 @@ func (m *diagConnectionManager) GetReadyChan(providerID string) chan interface{} func (m *diagConnectionManager) Connect(consumerID identity.Identity, hermesID common.Address, proposalLookup ProposalLookup, params ConnectParams) (err error) { var sessionID session.ID + ctx := context.Background() + err = m.ratelimiter.Wait(ctx) // This is a blocking call. Honors the rate limit + if err != nil { + log.Error().Msgf("ratelimiter.Wait: %s", err) + return err + } + proposal, err := proposalLookup() if err != nil { return fmt.Errorf("failed to lookup proposal: %w", err) @@ -164,8 +176,10 @@ func (m *diagConnectionManager) Connect(consumerID identity.Identity, hermesID c log.Debug().Msgf("Consumer connection trace: %s", traceResult) }() - fmt.Println("Connect>", proposal.ProviderID) + log.Error().Msgf("Connect > %v", proposal.ProviderID) uuid := proposal.ProviderID + + m.connsMu.Lock() con, ok := m.conns[uuid] if !ok { con = new(conn) @@ -173,6 +187,8 @@ func (m *diagConnectionManager) Connect(consumerID identity.Identity, hermesID c con.uuid = uuid m.conns[uuid] = con } + m.connsMu.Unlock() + removeConnection := func() { m.connsMu.Lock() defer m.connsMu.Unlock() @@ -933,7 +949,9 @@ func (m *diagConnectionManager) keepAliveLoop(con *conn, channel p2p.Channel, se if config.GetBool(config.FlagKeepConnectedOnFail) { m.statusOnHold(con) } else { - m.Disconnect() + //m.Disconnect() + log.Error().Msgf("Max p2p keepalive err count reached, disconnecting. SessionID=%s >>>>>>>>>", sessionID) + m.DisconnectSingle(con) } cancel() return diff --git a/core/connection/pinger.go b/core/connection/pinger.go index 6ececc5437..202ccf70bb 100644 --- a/core/connection/pinger.go +++ b/core/connection/pinger.go @@ -25,14 +25,17 @@ import ( // Diag is used to start provider check func Diag(cm *diagConnectionManager, con *conn, providerID string) { c, ok := con.activeConnection.(ConnectionDiag) - res := false + res := error(nil) if ok { log.Debug().Msgf("Check provider> %v", providerID) res = c.Diag() cm.DisconnectSingle(con) } - ev := quality.DiagEvent{ProviderID: providerID, Result: res} + ev := quality.DiagEvent{ProviderID: providerID, Result: res == nil} + if res != nil { + ev.Error = res + } con.resChannel <- ev close(con.resChannel) } diff --git a/core/quality/metrics.go b/core/quality/metrics.go index 75ac4268e0..ae85bc2c47 100644 --- a/core/quality/metrics.go +++ b/core/quality/metrics.go @@ -106,6 +106,7 @@ type PingEvent struct { type DiagEvent struct { ProviderID string Result bool + Error error } const ( diff --git a/go.mod b/go.mod index 7de7350f57..eba0ee5878 100644 --- a/go.mod +++ b/go.mod @@ -153,11 +153,15 @@ require ( github.com/imdario/mergo v0.3.12 // indirect github.com/ipfs/go-cid v0.4.1 // indirect github.com/ipfs/go-log/v2 v2.5.1 // indirect + github.com/jackc/pgpassfile v1.0.0 // indirect + github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect + github.com/jackc/pgx/v5 v5.5.5 // indirect + github.com/jackc/puddle/v2 v2.2.1 // indirect github.com/jackpal/go-nat-pmp v1.0.2 // indirect github.com/jbenet/go-context v0.0.0-20150711004518-d14ea06fba99 // indirect github.com/jbenet/go-temp-err-catcher v0.1.0 // indirect github.com/jinzhu/gorm v1.9.2 // indirect - github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a // indirect + github.com/jinzhu/inflection v1.0.0 // indirect github.com/jinzhu/now v1.1.5 // indirect github.com/jmespath/go-jmespath v0.4.0 // indirect github.com/josharian/native v0.0.0-20200817173448-b6b71def0850 // indirect @@ -183,6 +187,7 @@ require ( github.com/mattn/go-colorable v0.1.13 // indirect github.com/mattn/go-isatty v0.0.20 // indirect github.com/mattn/go-pointer v0.0.1 // indirect + github.com/mattn/go-sqlite3 v1.14.22 // indirect github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect github.com/mdlayher/genetlink v1.1.0 // indirect github.com/mdlayher/netlink v1.4.2 // indirect @@ -265,6 +270,9 @@ require ( gopkg.in/intercom/intercom-go.v2 v2.0.0-20210504094731-2bd1af0ce4b2 // indirect gopkg.in/warnings.v0 v0.1.2 // indirect gopkg.in/yaml.v3 v3.0.1 // indirect + gorm.io/driver/postgres v1.5.9 // indirect + gorm.io/driver/sqlite v1.5.6 // indirect + gorm.io/gorm v1.25.11 // indirect honnef.co/go/tools v0.4.2 // indirect lukechampine.com/blake3 v1.2.1 // indirect rsc.io/tmplfunc v0.0.3 // indirect diff --git a/go.sum b/go.sum index 5ecaaf07cb..2e9595a91b 100644 --- a/go.sum +++ b/go.sum @@ -926,6 +926,14 @@ github.com/ipfs/go-detect-race v0.0.1 h1:qX/xay2W3E4Q1U7d9lNs1sU9nvguX0a7319XbyQ github.com/ipfs/go-detect-race v0.0.1/go.mod h1:8BNT7shDZPo99Q74BpGMK+4D8Mn4j46UU0LZ723meps= github.com/ipfs/go-log/v2 v2.5.1 h1:1XdUzF7048prq4aBjDQQ4SL5RxftpRGdXhNRwKSAlcY= github.com/ipfs/go-log/v2 v2.5.1/go.mod h1:prSpmC1Gpllc9UYWxDiZDreBYw7zp4Iqp1kOLU9U5UI= +github.com/jackc/pgpassfile v1.0.0 h1:/6Hmqy13Ss2zCq62VdNG8tM1wchn8zjSGOBJ6icpsIM= +github.com/jackc/pgpassfile v1.0.0/go.mod h1:CEx0iS5ambNFdcRtxPj5JhEz+xB6uRky5eyVu/W2HEg= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a h1:bbPeKD0xmW/Y25WS6cokEszi5g+S0QxI/d45PkRi7Nk= +github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a/go.mod h1:5TJZWKEWniPve33vlWYSoGYefn3gLQRzjfDlhSJ9ZKM= +github.com/jackc/pgx/v5 v5.5.5 h1:amBjrZVmksIdNjxGW/IiIMzxMKZFelXbUoPNb+8sjQw= +github.com/jackc/pgx/v5 v5.5.5/go.mod h1:ez9gk+OAat140fv9ErkZDYFWmXLfV+++K0uAOiwgm1A= +github.com/jackc/puddle/v2 v2.2.1 h1:RhxXJtFG022u4ibrCSMSiu5aOq1i77R3OHKNJj77OAk= +github.com/jackc/puddle/v2 v2.2.1/go.mod h1:vriiEXHvEE654aYKXXjOvZM39qJ0q+azkZFrfEOc3H4= github.com/jackpal/gateway v1.0.6 h1:/MJORKvJEwNVldtGVJC2p2cwCnsSoLn3hl3zxmZT7tk= github.com/jackpal/gateway v1.0.6/go.mod h1:lTpwd4ACLXmpyiCTRtfiNyVnUmqT9RivzCDQetPfnjA= github.com/jackpal/go-nat-pmp v1.0.2 h1:KzKSgb7qkJvOUTqYl9/Hg/me3pWgBmERKrTGD7BdWus= @@ -943,6 +951,8 @@ github.com/jinzhu/gorm v1.9.2 h1:lCvgEaqe/HVE+tjAR2mt4HbbHAZsQOv3XAZiEZV37iw= github.com/jinzhu/gorm v1.9.2/go.mod h1:Vla75njaFJ8clLU1W44h34PjIkijhjHIYnZxMqCdxqo= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a h1:eeaG9XMUvRBYXJi4pg1ZKM7nxc5AfXfojeLLW7O5J3k= github.com/jinzhu/inflection v0.0.0-20180308033659-04140366298a/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= +github.com/jinzhu/inflection v1.0.0 h1:K317FqzuhWc8YvSVlFMCCUb36O/S9MCKRDI7QkRKD/E= +github.com/jinzhu/inflection v1.0.0/go.mod h1:h+uFLlag+Qp1Va5pdKtLDYj+kHp5pxUVkryuEj+Srlc= github.com/jinzhu/now v1.1.5 h1:/o9tlHleP7gOFmsnYNz3RGnqzefHA47wQpKrrdTIwXQ= github.com/jinzhu/now v1.1.5/go.mod h1:d3SSVoowX0Lcu0IBviAWJpolVfI5UJVZZ7cO71lE/z8= github.com/jmespath/go-jmespath v0.4.0 h1:BEgLn5cpjn8UN1mAw4NjwDrS35OdebyEtFe+9YPoQUg= @@ -1074,6 +1084,8 @@ github.com/mattn/go-runewidth v0.0.13 h1:lTGmDsbAYt5DmK6OnoV7EuIF1wEIFAcxld6ypU4 github.com/mattn/go-runewidth v0.0.13/go.mod h1:Jdepj2loyihRzMpdS35Xk/zdY8IAYHsh153qUoGf23w= github.com/mattn/go-sqlite3 v1.10.0 h1:jbhqpg7tQe4SupckyijYiy0mJJ/pRyHvXf7JdWK860o= github.com/mattn/go-sqlite3 v1.10.0/go.mod h1:FPy6KqzDD04eiIsT53CuJW3U88zkxoIYsOqkbpncsNc= +github.com/mattn/go-sqlite3 v1.14.22 h1:2gZY6PC6kBnID23Tichd1K+Z0oS6nE/XwU+Vz/5o4kU= +github.com/mattn/go-sqlite3 v1.14.22/go.mod h1:Uh1q+B4BYcTPb+yiD3kU8Ct7aC0hY9fxUwlHK0RXw+Y= github.com/matttproud/golang_protobuf_extensions v1.0.1/go.mod h1:D8He9yQNgCq6Z5Ld7szi9bcBfOoFv/3dc6xSMkL2PC0= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg= github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k= @@ -2247,6 +2259,14 @@ gopkg.in/yaml.v3 v3.0.0-20200615113413-eeeca48fe776/go.mod h1:K4uyk7z7BCEPqu6E+C gopkg.in/yaml.v3 v3.0.0-20210107192922-496545a6307b/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= +gorm.io/driver/postgres v1.5.9 h1:DkegyItji119OlcaLjqN11kHoUgZ/j13E0jkJZgD6A8= +gorm.io/driver/postgres v1.5.9/go.mod h1:DX3GReXH+3FPWGrrgffdvCk3DQ1dwDPdmbenSkweRGI= +gorm.io/driver/sqlite v1.5.6 h1:fO/X46qn5NUEEOZtnjJRWRzZMe8nqJiQ9E+0hi+hKQE= +gorm.io/driver/sqlite v1.5.6/go.mod h1:U+J8craQU6Fzkcvu8oLeAQmi50TkwPEhHDEjQZXDah4= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde h1:9DShaph9qhkIYw7QF91I/ynrr4cOO2PZra2PFD7Mfeg= +gorm.io/gorm v1.25.7-0.20240204074919-46816ad31dde/go.mod h1:hbnx/Oo0ChWMn1BIhpy1oYozzpM15i4YPuHDmfYtwg8= +gorm.io/gorm v1.25.11 h1:/Wfyg1B/je1hnDx3sMkX+gAlxrlZpn6X0BXRlwXlvHg= +gorm.io/gorm v1.25.11/go.mod h1:xh7N7RHfYlNc5EmcI/El95gXusucDrQnHXe0+CgWcLQ= gotest.tools v2.2.0+incompatible/go.mod h1:DsYFclhRJ6vuDpmuTbkuFWG+y2sxOXAzmJt81HFBacw= gotest.tools/v3 v3.4.0/go.mod h1:CtbdzLSsqVhDgMtKsx03ird5YTGB3ar27v0u/yKBW5g= grpc.go4.org v0.0.0-20170609214715-11d0a25b4919/go.mod h1:77eQGdRu53HpSqPFJFmuJdjuHRquDANNeA4x7B8WQ9o= diff --git a/services/wireguard/connection/connection.go b/services/wireguard/connection/connection.go index 690226df04..0a292591b2 100644 --- a/services/wireguard/connection/connection.go +++ b/services/wireguard/connection/connection.go @@ -87,7 +87,7 @@ func (c *Connection) State() <-chan connectionstate.State { } // Diag is used to start provider check -func (c *Connection) Diag() bool { +func (c *Connection) Diag() error { return c.connectionEndpoint.Diag() } diff --git a/services/wireguard/connection/connection_test.go b/services/wireguard/connection/connection_test.go index ad82dbc3a8..07af4e923e 100644 --- a/services/wireguard/connection/connection_test.go +++ b/services/wireguard/connection/connection_test.go @@ -158,8 +158,8 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) { return wgcfg.Stats{LastHandshake: time.Now(), BytesSent: 10, BytesReceived: 11}, nil } -func (mce *mockConnectionEndpoint) Diag() bool { - return true +func (mce *mockConnectionEndpoint) Diag() error { + return nil } type mockHandshakeWaiter struct { diff --git a/services/wireguard/endpoint.go b/services/wireguard/endpoint.go index e6df362a67..ea3e3d9960 100644 --- a/services/wireguard/endpoint.go +++ b/services/wireguard/endpoint.go @@ -34,5 +34,6 @@ type ConnectionEndpoint interface { Config() (ServiceConfig, error) InterfaceName() string Stop() error - Diag() bool + + Diag() error } diff --git a/services/wireguard/endpoint/diagclient/client.go b/services/wireguard/endpoint/diagclient/client.go index 646ff7f14e..6e50d23759 100644 --- a/services/wireguard/endpoint/diagclient/client.go +++ b/services/wireguard/endpoint/diagclient/client.go @@ -19,6 +19,7 @@ package diagclient import ( "bufio" + "errors" "fmt" "io" "net/http" @@ -126,25 +127,29 @@ func (c *client) Close() (err error) { return nil } -func (c *client) Diag() bool { +func (c *client) Diag() error { client := http.Client{ Transport: &http.Transport{ DialContext: c.tnet.DialContext, }, + Timeout: 15 * time.Second, } - resp, err := client.Get("http://1.1.1.1/") + resp, err := client.Get("http://107.173.23.19:8080/test") if err != nil { log.Error().Err(err).Msg("Get failed") - return false + return err } defer resp.Body.Close() body, err := io.ReadAll(resp.Body) if err != nil { log.Error().Err(err).Msg("Readall failed") - return false + return err + } + if len(body) < 6 { + log.Error().Msg("Wrong length") + return errors.New("Wrong body length") } - _ = body - return true + return nil } diff --git a/services/wireguard/endpoint/endpoint.go b/services/wireguard/endpoint/endpoint.go index ad1d1ddd28..5dd43a0800 100644 --- a/services/wireguard/endpoint/endpoint.go +++ b/services/wireguard/endpoint/endpoint.go @@ -52,12 +52,12 @@ type connectionEndpoint struct { wgClient WgClient } -func (ce *connectionEndpoint) Diag() bool { +func (ce *connectionEndpoint) Diag() error { c, ok := ce.wgClient.(WgClientDiag) if ok { return c.Diag() } - return false + return nil } // StartConsumerMode starts and configure wireguard network interface running in consumer mode. diff --git a/services/wireguard/endpoint/wg_client.go b/services/wireguard/endpoint/wg_client.go index 74cd0f6111..22dd36781e 100644 --- a/services/wireguard/endpoint/wg_client.go +++ b/services/wireguard/endpoint/wg_client.go @@ -46,7 +46,7 @@ type WgClient interface { // WgClientDiag is a specialised WgClient interface for provider check type WgClientDiag interface { - Diag() bool + Diag() error } // WgClientFactory represents WireGuard client factory. diff --git a/tequilapi/contract/connection.go b/tequilapi/contract/connection.go index 226e9fdb37..835805c254 100644 --- a/tequilapi/contract/connection.go +++ b/tequilapi/contract/connection.go @@ -54,9 +54,9 @@ func NewConnectionInfoDTO(session connectionstate.Status) ConnectionInfoDTO { // ConnectionDiagInfoDTO holds provider check result // swagger:model ConnectionDiagInfoDTO type ConnectionDiagInfoDTO struct { - Status bool `json:"status"` - Error interface{} `json:"error"` - ProviderID string `json:"provider_id"` + ProviderID string `json:"provider_id"` + Error string `json:"error"` + DiagError string `json:"diag_err"` } // ConnectionInfoDTO holds partial consumer connection details. diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index 5c144a2000..e719aeab88 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -25,6 +25,9 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/rs/zerolog/log" + + "gorm.io/driver/postgres" + "gorm.io/gorm" "gvisor.dev/gvisor/pkg/sync" "github.com/mysteriumnetwork/go-rest/apierror" @@ -55,6 +58,8 @@ type ConnectionDiagEndpoint struct { identitySelector selector.Handler consumerAddress string + + db *gorm.DB } // NewConnectionDiagEndpoint creates and returns connection endpoint @@ -78,6 +83,17 @@ func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider sta log.Error().Msgf("Unlocked identity: %v", consumerID_.Address) ce.consumerAddress = consumerID_.Address + dsn := "host=____ user=mypguser password=___ dbname=myst_nodes port=5432 sslmode=disable" + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + panic(err) + } + + ce.db = db + if err != nil { + panic(err) + } + return ce } @@ -139,7 +155,7 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { ConnectOptions: contract.ConnectOptions{}, } if err := cr.Validate(); err != nil { - result.Error = err + result.Error = err.Error() return } @@ -147,13 +163,13 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) if err != nil { log.Error().Err(err).Stack().Msg("Could not check registration status") - result.Error = contract.ErrCodeIDRegistrationCheck + result.Error = (contract.ErrCodeIDRegistrationCheck) return } switch status { case registry.Unregistered, registry.RegistrationError, registry.Unknown: log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) - result.Error = contract.ErrCodeIDNotRegistered + result.Error = (contract.ErrCodeIDNotRegistered) return case registry.InProgress: log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) @@ -161,7 +177,7 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) default: log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) - result.Error = contract.ErrCodeIDStatusUnknown + result.Error = (contract.ErrCodeIDStatusUnknown) return } @@ -179,7 +195,7 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) if ce.manager.HasConnection(cr.ProviderID) { - result.Error = contract.ErrCodeConnectionAlreadyExists + result.Error = (contract.ErrCodeConnectionAlreadyExists) return } @@ -187,12 +203,12 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { if err != nil { switch err { case connection.ErrAlreadyExists: - result.Error = contract.ErrCodeConnectionAlreadyExists + result.Error = (contract.ErrCodeConnectionAlreadyExists) case connection.ErrConnectionCancelled: - result.Error = contract.ErrCodeConnectionCancelled + result.Error = (contract.ErrCodeConnectionCancelled) default: log.Error().Err(err).Msgf("Failed to connect: %v", prov) - result.Error = contract.ErrCodeConnect + result.Error = (contract.ErrCodeConnect) } return } @@ -200,7 +216,6 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { resChannel := ce.manager.GetReadyChan(cr.ProviderID) res := <-resChannel log.Error().Msgf("Result > %v", res) - result.Status = res.(quality.DiagEvent).Result }(prov) } @@ -301,11 +316,174 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { resp := contract.ConnectionDiagInfoDTO{ ProviderID: prov, - Status: res.(quality.DiagEvent).Result, } utils.WriteAsJSON(resp, c.Writer) } +type proposalDB struct { + ID string + Error string + DiagError string `json:"diag_error"` + Country string +} + +func (proposalDB) TableName() string { + return "node" +} + +// DiagBatch is used to start a given providers check (batch mode) +func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) { + + hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) + if err != nil { + c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) + return + } + + country := c.Query("location") + f := &proposal.Filter{ + ServiceType: "wireguard", + LocationCountry: country, + ExcludeUnsupported: true, + IncludeMonitoringFailed: true, + } + pp, err := ce.proposalRepository.Proposals(f) + if err != nil { + log.Error().Err(err).Stack().Msg("Proposals>") + } + log.Error().Msgf("pp> %v", len(pp)) + + var ( + wg sync.WaitGroup + mu sync.Mutex + ) + resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(pp)) + wg.Add(len(pp)) + + maxGoroutines := 15 + guard := make(chan struct{}, maxGoroutines) + + for _, pr := range pp { + guard <- struct{}{} + + worker := func(provID string) (result contract.ConnectionDiagInfoDTO) { + result.ProviderID = provID + + cr := &contract.ConnectionCreateRequest{ + ConsumerID: ce.consumerAddress, + ProviderID: provID, + Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, + HermesID: hermes.Hex(), + ServiceType: "wireguard", + ConnectOptions: contract.ConnectOptions{}, + } + if err := cr.Validate(); err != nil { + result.Error = err.Error() + return + } + + consumerID := identity.FromAddress(cr.ConsumerID) + status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) + if err != nil { + log.Error().Err(err).Stack().Msg("Could not check registration status") + result.Error = (contract.ErrCodeIDRegistrationCheck) + return + } + switch status { + case registry.Unregistered, registry.RegistrationError, registry.Unknown: + log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) + result.Error = (contract.ErrCodeIDNotRegistered) + return + case registry.InProgress: + log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) + case registry.Registered: + log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) + default: + log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) + result.Error = (contract.ErrCodeIDStatusUnknown) + return + } + + if len(cr.ProviderID) > 0 { + cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) + } + f := &proposal.Filter{ + ServiceType: cr.ServiceType, + LocationCountry: cr.Filter.CountryCode, + ProviderIDs: cr.Filter.Providers, + IPType: cr.Filter.IPType, + IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, + AccessPolicy: "all", + } + proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + + if ce.manager.HasConnection(cr.ProviderID) { + result.Error = (contract.ErrCodeConnectionAlreadyExists) + return + } + + err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) + if err != nil { + switch err { + case connection.ErrAlreadyExists: + result.Error = (contract.ErrCodeConnectionAlreadyExists) + case connection.ErrConnectionCancelled: + result.Error = (contract.ErrCodeConnectionCancelled) + default: + log.Error().Err(err).Msgf("Failed to connect: %v", provID) + result.Error = (contract.ErrCodeConnect) + } + return + } + + resChannel := ce.manager.GetReadyChan(cr.ProviderID) + res := <-resChannel + log.Error().Msgf("Result > %v", res) + + ev := res.(quality.DiagEvent) + // result.Status = ev.Result + if ev.Error != nil { + result.DiagError = ev.Error.Error() + } + + return + } + go func(pr proposal.PricedServiceProposal) { + + result := worker(pr.ProviderID) + + mu.Lock() + resultMap[pr.ProviderID] = result + mu.Unlock() + + // update + provRec := proposalDB{ID: result.ProviderID, Country: pr.Location.Country} + provRec.Error = "" + provRec.DiagError = "" + provRec.Error = result.Error + provRec.DiagError = result.DiagError + if ce.db.Model(&provRec).Select("Error", "DiagError", "Country").Updates(provRec).RowsAffected == 0 { + ce.db.Create(&provRec) + } + + wg.Done() + <-guard + }(pr) + + } + wg.Wait() + + out := make([]contract.ConnectionDiagInfoDTO, 0) + for _, prov := range pp { + res := resultMap[prov.ProviderID] + if res.Error != "" || res.DiagError != "" { + out = append(out, resultMap[prov.ProviderID]) + } + + } + utils.WriteAsJSON(out, c.Writer) +} + // AddRoutesForConnectionDiag adds proder check route to given router func AddRoutesForConnectionDiag( manager connection.DiagManager, @@ -324,6 +502,7 @@ func AddRoutesForConnectionDiag( { connGroup.GET("/prov-checker", ConnectionDiagEndpoint.Diag) connGroup.POST("/prov-checker-batch", ConnectionDiagEndpoint.DiagBatch) + connGroup.GET("/prov-checker-batch2", ConnectionDiagEndpoint.DiagBatch2) } return nil } From d89c3064224be35ddaf128c64ecd1d4bc3a5b8ab Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Wed, 17 Jul 2024 16:28:07 +0400 Subject: [PATCH 5/6] Fix linter warnings Signed-off-by: Anton Litvinov --- services/wireguard/endpoint.go | 1 - services/wireguard/service/service_test.go | 4 ++-- tequilapi/endpoints/connection-diag.go | 3 +-- 3 files changed, 3 insertions(+), 5 deletions(-) diff --git a/services/wireguard/endpoint.go b/services/wireguard/endpoint.go index ea3e3d9960..3085434720 100644 --- a/services/wireguard/endpoint.go +++ b/services/wireguard/endpoint.go @@ -34,6 +34,5 @@ type ConnectionEndpoint interface { Config() (ServiceConfig, error) InterfaceName() string Stop() error - Diag() error } diff --git a/services/wireguard/service/service_test.go b/services/wireguard/service/service_test.go index aa4742a17e..f378e3fada 100644 --- a/services/wireguard/service/service_test.go +++ b/services/wireguard/service/service_test.go @@ -153,8 +153,8 @@ func (mce *mockConnectionEndpoint) ConfigureRoutes(_ net.IP) error { retur func (mce *mockConnectionEndpoint) PeerStats() (wgcfg.Stats, error) { return wgcfg.Stats{LastHandshake: time.Now()}, nil } -func (mce *mockConnectionEndpoint) Diag() bool { - return true +func (mce *mockConnectionEndpoint) Diag() error { + return nil } func newManagerStub(pub, out, country string) *Manager { diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index e719aeab88..6e5c20152e 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -25,7 +25,6 @@ import ( "github.com/gin-gonic/gin" "github.com/pkg/errors" "github.com/rs/zerolog/log" - "gorm.io/driver/postgres" "gorm.io/gorm" "gvisor.dev/gvisor/pkg/sync" @@ -331,7 +330,7 @@ func (proposalDB) TableName() string { return "node" } -// DiagBatch is used to start a given providers check (batch mode) +// DiagBatch2 is used to start a check of providers from a given country or all countries func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) { hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) From cdc46441449adcd9bd7a028cab8022e8773e010d Mon Sep 17 00:00:00 2001 From: Anton Litvinov Date: Thu, 18 Jul 2024 18:09:43 +0400 Subject: [PATCH 6/6] Add new build target; Add cli flag --checker.dsn for database storage of found nodes Signed-off-by: Anton Litvinov --- ci/packages/build.go | 28 +- ci/packages/package.go | 4 +- config/flags_network.go | 9 + tequilapi/endpoints/connection-diag-empty.go | 46 +++ tequilapi/endpoints/connection-diag.go | 372 ++++++++----------- 5 files changed, 229 insertions(+), 230 deletions(-) create mode 100644 tequilapi/endpoints/connection-diag-empty.go diff --git a/ci/packages/build.go b/ci/packages/build.go index 19c851d88a..d2a7c2d6b1 100644 --- a/ci/packages/build.go +++ b/ci/packages/build.go @@ -32,16 +32,31 @@ import ( "github.com/rs/zerolog/log" ) +// BuildProvChecker builds myst binary with provider checker API. Like go tool, it supports cross-platform build with env vars: GOOS, GOARCH. +func BuildProvChecker() error { + logconfig.Bootstrap() + if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst", true); err != nil { + return err + } + if err := copyConfig("myst"); err != nil { + return err + } + if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", false); err != nil { + return err + } + return nil +} + // Build builds the project. Like go tool, it supports cross-platform build with env vars: GOOS, GOARCH. func Build() error { logconfig.Bootstrap() - if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst"); err != nil { + if err := buildBinary(path.Join("cmd", "mysterium_node", "mysterium_node.go"), "myst", false); err != nil { return err } if err := copyConfig("myst"); err != nil { return err } - if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor"); err != nil { + if err := buildBinary(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", false); err != nil { return err } return nil @@ -67,7 +82,7 @@ func buildCrossBinary(os, arch string) error { return sh.Run("bin/build_xgo", os+"/"+arch) } -func buildBinary(source, target string) error { +func buildBinary(source, target string, provChecker bool) error { targetOS, ok := os.LookupEnv("GOOS") if !ok { targetOS = runtime.GOOS @@ -76,10 +91,10 @@ func buildBinary(source, target string) error { if !ok { targetArch = runtime.GOARCH } - return buildBinaryFor(source, target, targetOS, targetArch, nil, false) + return buildBinaryFor(source, target, targetOS, targetArch, nil, false, provChecker) } -func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[string]string, buildStatic bool) error { +func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[string]string, buildStatic, provChecker bool) error { log.Info().Msgf("Building %s -> %s %s/%s", source, target, targetOS, targetArch) buildDir, err := filepath.Abs(path.Join("build", target)) @@ -100,6 +115,9 @@ func buildBinaryFor(source, target, targetOS, targetArch string, extraEnvs map[s if buildStatic { flags = append(flags, "-a", "-tags", "netgo") } + if provChecker { + flags = append(flags, "-tags", "prov_checker") + } if targetOS == "windows" { target += ".exe" diff --git a/ci/packages/package.go b/ci/packages/package.go index 2e156bb124..5c165f0493 100644 --- a/ci/packages/package.go +++ b/ci/packages/package.go @@ -380,7 +380,7 @@ func packageStandalone(binaryPath, os, arch string, extraEnvs map[string]string) if os == "linux" { filename := path.Base(binaryPath) binaryPath = path.Join("build", filename, filename) - err = buildBinaryFor(path.Join("cmd", "mysterium_node", "mysterium_node.go"), filename, os, arch, extraEnvs, true) + err = buildBinaryFor(path.Join("cmd", "mysterium_node", "mysterium_node.go"), filename, os, arch, extraEnvs, true, false) } else { err = buildCrossBinary(os, arch) } @@ -388,7 +388,7 @@ func packageStandalone(binaryPath, os, arch string, extraEnvs map[string]string) return err } - err = buildBinaryFor(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", os, arch, extraEnvs, true) + err = buildBinaryFor(path.Join("cmd", "supervisor", "supervisor.go"), "myst_supervisor", os, arch, extraEnvs, true, false) if err != nil { return err } diff --git a/config/flags_network.go b/config/flags_network.go index ce3255566b..b6f991d00c 100644 --- a/config/flags_network.go +++ b/config/flags_network.go @@ -154,6 +154,13 @@ var ( Usage: "DNS listen port for services", Value: 11253, } + + // FlagProvCheckerDatabaseDSN sets DNS for checker API database + FlagProvCheckerDatabaseDSN = cli.StringFlag{ + Name: "checker.dsn", + Usage: "Database DSN for provider checker", + Value: "", + } ) // RegisterFlagsNetwork function register network flags to flag list @@ -179,6 +186,7 @@ func RegisterFlagsNetwork(flags *[]cli.Flag) { &FlagPortCheckServers, &FlagStatsReportInterval, &FlagDNSListenPort, + &FlagProvCheckerDatabaseDSN, ) } @@ -203,6 +211,7 @@ func ParseFlagsNetwork(ctx *cli.Context) { Current.ParseStringFlag(ctx, FlagPortCheckServers) Current.ParseDurationFlag(ctx, FlagStatsReportInterval) Current.ParseIntFlag(ctx, FlagDNSListenPort) + Current.ParseStringFlag(ctx, FlagProvCheckerDatabaseDSN) } // BlockchainNetwork defines a blockchain network diff --git a/tequilapi/endpoints/connection-diag-empty.go b/tequilapi/endpoints/connection-diag-empty.go new file mode 100644 index 0000000000..ce8d94f720 --- /dev/null +++ b/tequilapi/endpoints/connection-diag-empty.go @@ -0,0 +1,46 @@ +//go:build !prov_checker + +/* + * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package endpoints + +import ( + "github.com/gin-gonic/gin" + + "github.com/mysteriumnetwork/node/core/connection" + "github.com/mysteriumnetwork/node/core/node" + "github.com/mysteriumnetwork/node/eventbus" + "github.com/mysteriumnetwork/node/identity/selector" +) + +// AddRoutesForConnectionDiag adds proder check route to given router +func AddRoutesForConnectionDiag( + manager connection.DiagManager, + stateProvider stateProvider, + proposalRepository proposalRepository, + identityRegistry identityRegistry, + publisher eventbus.Publisher, + subscriber eventbus.Subscriber, + addressProvider addressProvider, + identitySelector selector.Handler, + options node.Options, +) func(*gin.Engine) error { + return func(e *gin.Engine) error { + return nil + } +} diff --git a/tequilapi/endpoints/connection-diag.go b/tequilapi/endpoints/connection-diag.go index 6e5c20152e..2de6151349 100644 --- a/tequilapi/endpoints/connection-diag.go +++ b/tequilapi/endpoints/connection-diag.go @@ -1,3 +1,5 @@ +//go:build prov_checker + /* * Copyright (C) 2024 The "MysteriumNetwork/node" Authors. * @@ -82,17 +84,14 @@ func NewConnectionDiagEndpoint(manager connection.DiagManager, stateProvider sta log.Error().Msgf("Unlocked identity: %v", consumerID_.Address) ce.consumerAddress = consumerID_.Address - dsn := "host=____ user=mypguser password=___ dbname=myst_nodes port=5432 sslmode=disable" - db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) - if err != nil { - panic(err) - } - - ce.db = db - if err != nil { - panic(err) + dsn := config.GetString(config.FlagProvCheckerDatabaseDSN) + if dsn != "" { + db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{}) + if err != nil { + panic(err) + } + ce.db = db } - return ce } @@ -112,6 +111,96 @@ func dedupeSortedStrings(s []string) []string { return s[:e] } +// Diag is used to start a given provider check +func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { + hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) + if err != nil { + c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) + return + } + + prov := c.Query("id") + if len(prov) == 0 { + c.Error(errors.New("Empty prameter: prov")) + return + } + cr := &contract.ConnectionCreateRequest{ + ConsumerID: ce.consumerAddress, + ProviderID: prov, + Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, + HermesID: hermes.Hex(), + ServiceType: "wireguard", + ConnectOptions: contract.ConnectOptions{}, + } + if err := cr.Validate(); err != nil { + c.Error(err) + return + } + + consumerID := identity.FromAddress(cr.ConsumerID) + status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) + if err != nil { + log.Error().Err(err).Stack().Msg("Could not check registration status") + c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck)) + return + } + switch status { + case registry.Unregistered, registry.RegistrationError, registry.Unknown: + log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) + c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q is not registered. Please register the identity first", cr.ConsumerID), contract.ErrCodeIDNotRegistered)) + return + case registry.InProgress: + log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) + case registry.Registered: + log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) + default: + log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) + c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q has unknown status. Aborting", cr.ConsumerID), contract.ErrCodeIDStatusUnknown)) + return + } + + if len(cr.ProviderID) > 0 { + cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) + } + f := &proposal.Filter{ + ServiceType: cr.ServiceType, + LocationCountry: cr.Filter.CountryCode, + ProviderIDs: cr.Filter.Providers, + IPType: cr.Filter.IPType, + IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, + AccessPolicy: "all", + } + proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + + if ce.manager.HasConnection(cr.ProviderID) { + c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) + return + } + + err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) + if err != nil { + switch err { + case connection.ErrAlreadyExists: + c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) + case connection.ErrConnectionCancelled: + c.Error(apierror.Unprocessable("Connection cancelled", contract.ErrCodeConnectionCancelled)) + default: + log.Error().Err(err).Msg("Failed to connect") + c.Error(apierror.Internal("Failed to connect: "+err.Error(), contract.ErrCodeConnect)) + } + return + } + + resChannel := ce.manager.GetReadyChan(cr.ProviderID) + res := <-resChannel + log.Error().Msgf("Result > %v", res) + + resp := contract.ConnectionDiagInfoDTO{ + ProviderID: prov, + } + utils.WriteAsJSON(resp, c.Writer) +} + // DiagBatch is used to start a given providers check (batch mode) func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) @@ -132,91 +221,23 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { resultMap := make(map[string]contract.ConnectionDiagInfoDTO, len(provs)) wg.Add(len(provs)) - for _, prov := range provs { - go func(prov string) { - result := contract.ConnectionDiagInfoDTO{ - ProviderID: prov, - } - defer func() { - mu.Lock() - resultMap[prov] = result - mu.Unlock() - - wg.Done() - }() - - cr := &contract.ConnectionCreateRequest{ - ConsumerID: ce.consumerAddress, - ProviderID: prov, - Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, - HermesID: hermes.Hex(), - ServiceType: "wireguard", - ConnectOptions: contract.ConnectOptions{}, - } - if err := cr.Validate(); err != nil { - result.Error = err.Error() - return - } - - consumerID := identity.FromAddress(cr.ConsumerID) - status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) - if err != nil { - log.Error().Err(err).Stack().Msg("Could not check registration status") - result.Error = (contract.ErrCodeIDRegistrationCheck) - return - } - switch status { - case registry.Unregistered, registry.RegistrationError, registry.Unknown: - log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) - result.Error = (contract.ErrCodeIDNotRegistered) - return - case registry.InProgress: - log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) - case registry.Registered: - log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) - default: - log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) - result.Error = (contract.ErrCodeIDStatusUnknown) - return - } + maxGoroutines := 15 + guard := make(chan struct{}, maxGoroutines) - if len(cr.ProviderID) > 0 { - cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) - } - f := &proposal.Filter{ - ServiceType: cr.ServiceType, - LocationCountry: cr.Filter.CountryCode, - ProviderIDs: cr.Filter.Providers, - IPType: cr.Filter.IPType, - IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, - AccessPolicy: "all", - } - proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) + for _, pr := range provs { + guard <- struct{}{} - if ce.manager.HasConnection(cr.ProviderID) { - result.Error = (contract.ErrCodeConnectionAlreadyExists) - return - } + go func(providerID string) { + result := ce.worker(providerID, hermes) - err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) - if err != nil { - switch err { - case connection.ErrAlreadyExists: - result.Error = (contract.ErrCodeConnectionAlreadyExists) - case connection.ErrConnectionCancelled: - result.Error = (contract.ErrCodeConnectionCancelled) - default: - log.Error().Err(err).Msgf("Failed to connect: %v", prov) - result.Error = (contract.ErrCodeConnect) - } - return - } + mu.Lock() + resultMap[providerID] = result + mu.Unlock() - resChannel := ce.manager.GetReadyChan(cr.ProviderID) - res := <-resChannel - log.Error().Msgf("Result > %v", res) + wg.Done() + <-guard + }(pr) - }(prov) } wg.Wait() @@ -227,31 +248,30 @@ func (ce *ConnectionDiagEndpoint) DiagBatch(c *gin.Context) { utils.WriteAsJSON(out, c.Writer) } -// Diag is used to start a given provider check -func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { - log.Debug().Msgf("Diag >>>") +type proposalDB struct { + ID string + Error string + DiagError string `json:"diag_error"` + Country string +} - hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) - if err != nil { - c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) - return - } +func (proposalDB) TableName() string { + return "node" +} + +func (ce *ConnectionDiagEndpoint) worker(provID string, hermes common.Address) (result contract.ConnectionDiagInfoDTO) { + result.ProviderID = provID - prov := c.Query("id") - if len(prov) == 0 { - c.Error(errors.New("Empty prameter: prov")) - return - } cr := &contract.ConnectionCreateRequest{ ConsumerID: ce.consumerAddress, - ProviderID: prov, + ProviderID: provID, Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, HermesID: hermes.Hex(), ServiceType: "wireguard", ConnectOptions: contract.ConnectOptions{}, } if err := cr.Validate(); err != nil { - c.Error(err) + result.Error = err.Error() return } @@ -259,13 +279,13 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) if err != nil { log.Error().Err(err).Stack().Msg("Could not check registration status") - c.Error(apierror.Internal("Failed to check ID registration status: "+err.Error(), contract.ErrCodeIDRegistrationCheck)) + result.Error = (contract.ErrCodeIDRegistrationCheck) return } switch status { case registry.Unregistered, registry.RegistrationError, registry.Unknown: log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) - c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q is not registered. Please register the identity first", cr.ConsumerID), contract.ErrCodeIDNotRegistered)) + result.Error = (contract.ErrCodeIDNotRegistered) return case registry.InProgress: log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) @@ -273,7 +293,7 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) default: log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) - c.Error(apierror.Unprocessable(fmt.Sprintf("Identity %q has unknown status. Aborting", cr.ConsumerID), contract.ErrCodeIDStatusUnknown)) + result.Error = (contract.ErrCodeIDStatusUnknown) return } @@ -291,7 +311,7 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) if ce.manager.HasConnection(cr.ProviderID) { - c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) + result.Error = (contract.ErrCodeConnectionAlreadyExists) return } @@ -299,12 +319,12 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { if err != nil { switch err { case connection.ErrAlreadyExists: - c.Error(apierror.Unprocessable("Connection already exists", contract.ErrCodeConnectionAlreadyExists)) + result.Error = (contract.ErrCodeConnectionAlreadyExists) case connection.ErrConnectionCancelled: - c.Error(apierror.Unprocessable("Connection cancelled", contract.ErrCodeConnectionCancelled)) + result.Error = (contract.ErrCodeConnectionCancelled) default: - log.Error().Err(err).Msg("Failed to connect") - c.Error(apierror.Internal("Failed to connect: "+err.Error(), contract.ErrCodeConnect)) + log.Error().Err(err).Msgf("Failed to connect: %v", provID) + result.Error = (contract.ErrCodeConnect) } return } @@ -313,26 +333,16 @@ func (ce *ConnectionDiagEndpoint) Diag(c *gin.Context) { res := <-resChannel log.Error().Msgf("Result > %v", res) - resp := contract.ConnectionDiagInfoDTO{ - ProviderID: prov, + ev := res.(quality.DiagEvent) + if ev.Error != nil { + result.DiagError = ev.Error.Error() } - utils.WriteAsJSON(resp, c.Writer) -} -type proposalDB struct { - ID string - Error string - DiagError string `json:"diag_error"` - Country string -} - -func (proposalDB) TableName() string { - return "node" + return } // DiagBatch2 is used to start a check of providers from a given country or all countries func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) { - hermes, err := ce.addressProvider.GetActiveHermes(config.GetInt64(config.FlagChainID)) if err != nil { c.Error(apierror.Internal("Failed to get active hermes", contract.ErrCodeActiveHermes)) @@ -365,104 +375,24 @@ func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) { for _, pr := range pp { guard <- struct{}{} - worker := func(provID string) (result contract.ConnectionDiagInfoDTO) { - result.ProviderID = provID - - cr := &contract.ConnectionCreateRequest{ - ConsumerID: ce.consumerAddress, - ProviderID: provID, - Filter: contract.ConnectionCreateFilter{IncludeMonitoringFailed: true}, - HermesID: hermes.Hex(), - ServiceType: "wireguard", - ConnectOptions: contract.ConnectOptions{}, - } - if err := cr.Validate(); err != nil { - result.Error = err.Error() - return - } - - consumerID := identity.FromAddress(cr.ConsumerID) - status, err := ce.identityRegistry.GetRegistrationStatus(config.GetInt64(config.FlagChainID), consumerID) - if err != nil { - log.Error().Err(err).Stack().Msg("Could not check registration status") - result.Error = (contract.ErrCodeIDRegistrationCheck) - return - } - switch status { - case registry.Unregistered, registry.RegistrationError, registry.Unknown: - log.Error().Msgf("Identity %q is not registered, aborting...", cr.ConsumerID) - result.Error = (contract.ErrCodeIDNotRegistered) - return - case registry.InProgress: - log.Info().Msgf("identity %q registration is in progress, continuing...", cr.ConsumerID) - case registry.Registered: - log.Info().Msgf("identity %q is registered, continuing...", cr.ConsumerID) - default: - log.Error().Msgf("identity %q has unknown status, aborting...", cr.ConsumerID) - result.Error = (contract.ErrCodeIDStatusUnknown) - return - } - - if len(cr.ProviderID) > 0 { - cr.Filter.Providers = append(cr.Filter.Providers, cr.ProviderID) - } - f := &proposal.Filter{ - ServiceType: cr.ServiceType, - LocationCountry: cr.Filter.CountryCode, - ProviderIDs: cr.Filter.Providers, - IPType: cr.Filter.IPType, - IncludeMonitoringFailed: cr.Filter.IncludeMonitoringFailed, - AccessPolicy: "all", - } - proposalLookup := connection.FilteredProposals(f, cr.Filter.SortBy, ce.proposalRepository) - - if ce.manager.HasConnection(cr.ProviderID) { - result.Error = (contract.ErrCodeConnectionAlreadyExists) - return - } - - err = ce.manager.Connect(consumerID, common.HexToAddress(cr.HermesID), proposalLookup, getConnectOptions(cr)) - if err != nil { - switch err { - case connection.ErrAlreadyExists: - result.Error = (contract.ErrCodeConnectionAlreadyExists) - case connection.ErrConnectionCancelled: - result.Error = (contract.ErrCodeConnectionCancelled) - default: - log.Error().Err(err).Msgf("Failed to connect: %v", provID) - result.Error = (contract.ErrCodeConnect) - } - return - } - - resChannel := ce.manager.GetReadyChan(cr.ProviderID) - res := <-resChannel - log.Error().Msgf("Result > %v", res) - - ev := res.(quality.DiagEvent) - // result.Status = ev.Result - if ev.Error != nil { - result.DiagError = ev.Error.Error() - } - - return - } go func(pr proposal.PricedServiceProposal) { - result := worker(pr.ProviderID) + result := ce.worker(pr.ProviderID, hermes) mu.Lock() resultMap[pr.ProviderID] = result mu.Unlock() - // update - provRec := proposalDB{ID: result.ProviderID, Country: pr.Location.Country} - provRec.Error = "" - provRec.DiagError = "" - provRec.Error = result.Error - provRec.DiagError = result.DiagError - if ce.db.Model(&provRec).Select("Error", "DiagError", "Country").Updates(provRec).RowsAffected == 0 { - ce.db.Create(&provRec) + if ce.db != nil { + // update + provRec := proposalDB{ID: result.ProviderID, Country: pr.Location.Country} + provRec.Error = "" + provRec.DiagError = "" + provRec.Error = result.Error + provRec.DiagError = result.DiagError + if ce.db.Model(&provRec).Select("Error", "DiagError", "Country").Updates(provRec).RowsAffected == 0 { + ce.db.Create(&provRec) + } } wg.Done() @@ -474,11 +404,7 @@ func (ce *ConnectionDiagEndpoint) DiagBatch2(c *gin.Context) { out := make([]contract.ConnectionDiagInfoDTO, 0) for _, prov := range pp { - res := resultMap[prov.ProviderID] - if res.Error != "" || res.DiagError != "" { - out = append(out, resultMap[prov.ProviderID]) - } - + out = append(out, resultMap[prov.ProviderID]) } utils.WriteAsJSON(out, c.Writer) }