diff --git a/core/cli/explorer.go b/core/cli/explorer.go index f3e3618de18..67d25304165 100644 --- a/core/cli/explorer.go +++ b/core/cli/explorer.go @@ -14,6 +14,9 @@ type ExplorerCMD struct { PoolDatabase string `env:"LOCALAI_POOL_DATABASE,POOL_DATABASE" default:"explorer.json" help:"Path to the pool database" group:"api"` ConnectionTimeout string `env:"LOCALAI_CONNECTION_TIMEOUT,CONNECTION_TIMEOUT" default:"2m" help:"Connection timeout for the explorer" group:"api"` ConnectionErrorThreshold int `env:"LOCALAI_CONNECTION_ERROR_THRESHOLD,CONNECTION_ERROR_THRESHOLD" default:"3" help:"Connection failure threshold for the explorer" group:"api"` + + WithSync bool `env:"LOCALAI_WITH_SYNC,WITH_SYNC" default:"false" help:"Enable sync with the network" group:"api"` + OnlySync bool `env:"LOCALAI_ONLY_SYNC,ONLY_SYNC" default:"false" help:"Only sync with the network" group:"api"` } func (e *ExplorerCMD) Run(ctx *cliContext.Context) error { @@ -27,10 +30,20 @@ func (e *ExplorerCMD) Run(ctx *cliContext.Context) error { if err != nil { return err } - ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold) - go ds.Start(context.Background()) - appHTTP := http.Explorer(db, ds) + if e.WithSync { + ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold) + go ds.Start(context.Background(), true) + } + + if e.OnlySync { + ds := explorer.NewDiscoveryServer(db, dur, e.ConnectionErrorThreshold) + ctx := context.Background() + + return ds.Start(ctx, false) + } + + appHTTP := http.Explorer(db) return appHTTP.Listen(e.Address) } diff --git a/core/explorer/database.go b/core/explorer/database.go index 8535140c909..e24de0aad26 100644 --- a/core/explorer/database.go +++ b/core/explorer/database.go @@ -7,58 +7,83 @@ import ( "os" "sort" "sync" + + "github.com/gofrs/flock" ) // Database is a simple JSON database for storing and retrieving p2p network tokens and a name and description. type Database struct { - sync.RWMutex - path string - data map[string]TokenData + path string + data map[string]TokenData + flock *flock.Flock + sync.Mutex } // TokenData is a p2p network token with a name and description. type TokenData struct { Name string `json:"name"` Description string `json:"description"` + Clusters []ClusterData + Failures int +} + +type ClusterData struct { + Workers []string + Type string + NetworkID string } // NewDatabase creates a new Database with the given path. func NewDatabase(path string) (*Database, error) { + fileLock := flock.New(path + ".lock") db := &Database{ - data: make(map[string]TokenData), - path: path, + data: make(map[string]TokenData), + path: path, + flock: fileLock, } return db, db.load() } // Get retrieves a Token from the Database by its token. func (db *Database) Get(token string) (TokenData, bool) { - db.RLock() - defer db.RUnlock() + db.flock.Lock() // we are making sure that the file is not being written to + defer db.flock.Unlock() + db.Lock() // we are making sure that is safe if called by another instance in the same process + defer db.Unlock() + db.load() t, ok := db.data[token] return t, ok } // Set stores a Token in the Database by its token. func (db *Database) Set(token string, t TokenData) error { + db.flock.Lock() + defer db.flock.Unlock() db.Lock() + defer db.Unlock() + db.load() db.data[token] = t - db.Unlock() - return db.Save() + return db.save() } // Delete removes a Token from the Database by its token. func (db *Database) Delete(token string) error { + db.flock.Lock() + defer db.flock.Unlock() db.Lock() + defer db.Unlock() + db.load() delete(db.data, token) - db.Unlock() - return db.Save() + return db.save() } func (db *Database) TokenList() []string { - db.RLock() - defer db.RUnlock() + db.flock.Lock() + defer db.flock.Unlock() + db.Lock() + defer db.Unlock() + db.load() tokens := []string{} for k := range db.data { tokens = append(tokens, k) @@ -74,9 +99,6 @@ func (db *Database) TokenList() []string { // load reads the Database from disk. func (db *Database) load() error { - db.Lock() - defer db.Unlock() - if _, err := os.Stat(db.path); os.IsNotExist(err) { return nil } @@ -91,10 +113,7 @@ func (db *Database) load() error { } // Save writes the Database to disk. -func (db *Database) Save() error { - db.RLock() - defer db.RUnlock() - +func (db *Database) save() error { // Marshal db.data into JSON // Write the JSON to the file f, err := os.Create(db.path) diff --git a/core/explorer/discovery.go b/core/explorer/discovery.go index 6a29442fc30..fe6470cb825 100644 --- a/core/explorer/discovery.go +++ b/core/explorer/discovery.go @@ -16,22 +16,10 @@ import ( type DiscoveryServer struct { sync.Mutex database *Database - networkState *NetworkState connectionTime time.Duration - failures map[string]int errorThreshold int } -type NetworkState struct { - Networks map[string]Network -} - -func (s *DiscoveryServer) NetworkState() *NetworkState { - s.Lock() - defer s.Unlock() - return s.networkState -} - // NewDiscoveryServer creates a new DiscoveryServer with the given Database. // it keeps the db state in sync with the network state func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) *DiscoveryServer { @@ -44,11 +32,7 @@ func NewDiscoveryServer(db *Database, dur time.Duration, failureThreshold int) * return &DiscoveryServer{ database: db, connectionTime: dur, - networkState: &NetworkState{ - Networks: map[string]Network{}, - }, errorThreshold: failureThreshold, - failures: make(map[string]int), } } @@ -116,10 +100,10 @@ func (s *DiscoveryServer) runBackground() { if hasWorkers { s.Lock() - s.networkState.Networks[token] = Network{ - Clusters: ledgerK, - } - delete(s.failures, token) + data, _ := s.database.Get(token) + (&data).Clusters = ledgerK + (&data).Failures = 0 + s.database.Set(token, data) s.Unlock() } else { s.failedToken(token) @@ -132,27 +116,23 @@ func (s *DiscoveryServer) runBackground() { func (s *DiscoveryServer) failedToken(token string) { s.Lock() defer s.Unlock() - s.failures[token]++ + data, _ := s.database.Get(token) + (&data).Failures++ + s.database.Set(token, data) } func (s *DiscoveryServer) deleteFailedConnections() { s.Lock() defer s.Unlock() - for k, v := range s.failures { - if v > s.errorThreshold { - log.Info().Any("network", k).Msg("Network has been removed from the database") - s.database.Delete(k) - delete(s.failures, k) + for _, t := range s.database.TokenList() { + data, _ := s.database.Get(t) + if data.Failures > s.errorThreshold { + log.Info().Any("token", t).Msg("Token has been removed from the database") + s.database.Delete(t) } } } -type ClusterData struct { - Workers []string - Type string - NetworkID string -} - func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockchain.Ledger, networkData chan ClusterData) { clusters := map[string]ClusterData{} @@ -217,7 +197,7 @@ func (s *DiscoveryServer) retrieveNetworkData(c context.Context, ledger *blockch } // Start the discovery server. This is meant to be run in to a goroutine. -func (s *DiscoveryServer) Start(ctx context.Context) error { +func (s *DiscoveryServer) Start(ctx context.Context, keepRunning bool) error { for { select { case <-ctx.Done(): @@ -225,6 +205,9 @@ func (s *DiscoveryServer) Start(ctx context.Context) error { default: // Collect data s.runBackground() + if !keepRunning { + return nil + } } } } diff --git a/core/http/endpoints/explorer/dashboard.go b/core/http/endpoints/explorer/dashboard.go index 7cd9f3c9842..9c731d9a4f7 100644 --- a/core/http/endpoints/explorer/dashboard.go +++ b/core/http/endpoints/explorer/dashboard.go @@ -11,7 +11,6 @@ import ( func Dashboard() func(*fiber.Ctx) error { return func(c *fiber.Ctx) error { - summary := fiber.Map{ "Title": "LocalAI API - " + internal.PrintableVersion(), "Version": internal.PrintableVersion(), @@ -34,26 +33,24 @@ type AddNetworkRequest struct { } type Network struct { - explorer.Network explorer.TokenData Token string `json:"token"` } -func ShowNetworks(db *explorer.Database, ds *explorer.DiscoveryServer) func(*fiber.Ctx) error { +func ShowNetworks(db *explorer.Database) func(*fiber.Ctx) error { return func(c *fiber.Ctx) error { - networkState := ds.NetworkState() results := []Network{} - for token, network := range networkState.Networks { + for _, token := range db.TokenList() { networkData, exists := db.Get(token) // get the token data hasWorkers := false - for _, cluster := range network.Clusters { + for _, cluster := range networkData.Clusters { if len(cluster.Workers) > 0 { hasWorkers = true break } } if exists && hasWorkers { - results = append(results, Network{Network: network, TokenData: networkData, Token: token}) + results = append(results, Network{TokenData: networkData, Token: token}) } } diff --git a/core/http/explorer.go b/core/http/explorer.go index 608ecdb51b8..bdcb93b16d5 100644 --- a/core/http/explorer.go +++ b/core/http/explorer.go @@ -10,7 +10,7 @@ import ( "github.com/mudler/LocalAI/core/http/routes" ) -func Explorer(db *explorer.Database, discoveryServer *explorer.DiscoveryServer) *fiber.App { +func Explorer(db *explorer.Database) *fiber.App { fiberCfg := fiber.Config{ Views: renderEngine(), @@ -22,7 +22,7 @@ func Explorer(db *explorer.Database, discoveryServer *explorer.DiscoveryServer) app := fiber.New(fiberCfg) - routes.RegisterExplorerRoutes(app, db, discoveryServer) + routes.RegisterExplorerRoutes(app, db) httpFS := http.FS(embedDirStatic) diff --git a/core/http/routes/explorer.go b/core/http/routes/explorer.go index b3c0d40b995..960b476b8ff 100644 --- a/core/http/routes/explorer.go +++ b/core/http/routes/explorer.go @@ -6,8 +6,8 @@ import ( "github.com/mudler/LocalAI/core/http/endpoints/explorer" ) -func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database, ds *coreExplorer.DiscoveryServer) { +func RegisterExplorerRoutes(app *fiber.App, db *coreExplorer.Database) { app.Get("/", explorer.Dashboard()) app.Post("/network/add", explorer.AddNetwork(db)) - app.Get("/networks", explorer.ShowNetworks(db, ds)) + app.Get("/networks", explorer.ShowNetworks(db)) } diff --git a/core/p2p/p2p.go b/core/p2p/p2p.go index 37b892d9564..bfa12287baf 100644 --- a/core/p2p/p2p.go +++ b/core/p2p/p2p.go @@ -236,6 +236,7 @@ func ensureService(ctx context.Context, n *node.Node, nd *NodeData, sserv string if ndService, found := service[nd.Name]; !found { if !nd.IsOnline() { // if node is offline and not present, do nothing + zlog.Debug().Msgf("Node %s is offline", nd.ID) return } newCtxm, cancel := context.WithCancel(ctx) diff --git a/go.mod b/go.mod index b35db1b1120..dcece45cef1 100644 --- a/go.mod +++ b/go.mod @@ -67,6 +67,7 @@ require ( github.com/cpuguy83/go-md2man/v2 v2.0.4 // indirect github.com/go-task/slim-sprig/v3 v3.0.0 // indirect github.com/go-viper/mapstructure/v2 v2.0.0 // indirect + github.com/gofrs/flock v0.12.1 // indirect github.com/labstack/echo/v4 v4.12.0 // indirect github.com/labstack/gommon v0.4.2 // indirect github.com/moby/docker-image-spec v1.3.1 // indirect diff --git a/go.sum b/go.sum index 47fd4c0625b..db47c36bfec 100644 --- a/go.sum +++ b/go.sum @@ -204,6 +204,8 @@ github.com/gofiber/template/html/v2 v2.1.2 h1:wkK/mYJ3nIhongTkG3t0QgV4ADdgOYJYVS github.com/gofiber/template/html/v2 v2.1.2/go.mod h1:E98Z/FzvpaSib06aWEgYk6GXNf3ctoyaJH8yW5ay5ak= github.com/gofiber/utils v1.1.0 h1:vdEBpn7AzIUJRhe+CiTOJdUcTg4Q9RK+pEa0KPbLdrM= github.com/gofiber/utils v1.1.0/go.mod h1:poZpsnhBykfnY1Mc0KeEa6mSHrS3dV0+oBWyeQmb2e0= +github.com/gofrs/flock v0.12.1 h1:MTLVXXHf8ekldpJk3AKicLij9MdwOWkZ+a/jHHZby9E= +github.com/gofrs/flock v0.12.1/go.mod h1:9zxTsyu5xtJ9DK+1tFZyibEV7y3uwDxPPfbxeeHCoD0= github.com/gogo/protobuf v1.1.1/go.mod h1:r8qH/GZQm5c6nD/R0oafs1akxWv10x8SbQlK7atdtwQ= github.com/gogo/protobuf v1.3.1/go.mod h1:SlYgWuQ5SjCEi6WLHjHCa1yvBfUnHcTbrrZtXPKa29o= github.com/gogo/protobuf v1.3.2 h1:Ov1cvc58UF3b5XjBnZv7+opcTcQFZebYjWzi34vdm4Q=