diff --git a/management/server/activity/store/sql_store.go b/management/server/activity/store/sql_store.go index 80b1659385d..28cb0f615b6 100644 --- a/management/server/activity/store/sql_store.go +++ b/management/server/activity/store/sql_store.go @@ -10,6 +10,7 @@ import ( log "github.com/sirupsen/logrus" "gorm.io/driver/postgres" + "gorm.io/driver/mysql" "gorm.io/driver/sqlite" "gorm.io/gorm" "gorm.io/gorm/clause" @@ -30,6 +31,7 @@ const ( storeEngineEnv = "NB_ACTIVITY_EVENT_STORE_ENGINE" postgresDsnEnv = "NB_ACTIVITY_EVENT_POSTGRES_DSN" + mysqlDsnEnv = "NB_ACTIVITY_EVENT_MYSQL_DSN" sqlMaxOpenConnsEnv = "NB_SQL_MAX_OPEN_CONNS" ) @@ -254,6 +256,12 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) { return nil, fmt.Errorf("%s environment variable not set", postgresDsnEnv) } dialector = postgres.Open(dsn) + case types.MysqlStoreEngine: + dsn, ok := os.LookupEnv(mysqlDsnEnv) + if !ok { + return nil, fmt.Errorf("%s environment variable not set", mysqlDsnEnv) + } + dialector = mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local") default: return nil, fmt.Errorf("unsupported store engine: %s", storeEngine) } diff --git a/management/server/activity/store/sql_store_test.go b/management/server/activity/store/sql_store_test.go index 8c0d159df7b..3f207e2ab5f 100644 --- a/management/server/activity/store/sql_store_test.go +++ b/management/server/activity/store/sql_store_test.go @@ -9,49 +9,87 @@ import ( "github.com/stretchr/testify/assert" "github.com/netbirdio/netbird/management/server/activity" + "github.com/netbirdio/netbird/management/server/testutil" + "github.com/netbirdio/netbird/management/server/types" ) -func TestNewSqlStore(t *testing.T) { +var enginesToTest = []types.Engine{types.SqliteStoreEngine,types.PostgresStoreEngine,types.MysqlStoreEngine} + +func runTestForAllEngines(t *testing.T, test func( t *testing.T, store *Store )) { dataDir := t.TempDir() key, _ := GenerateKey() - store, err := NewSqlStore(context.Background(), dataDir, key) - if err != nil { - t.Fatal(err) - return - } - defer store.Close(context.Background()) //nolint - - accountID := "account_1" + + for _,engine := range enginesToTest { + t.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE",string(engine)) + switch engine { + case types.PostgresStoreEngine : + cleanup, dsn, err := testutil.CreatePostgresTestContainer() + if err != nil { + t.Fatalf("could not start Postgres container %s",err) + } + t.Cleanup(cleanup) + t.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN",dsn) + case types.MysqlStoreEngine : + cleanup, dsn, err := testutil.CreateMysqlTestContainer() + if err != nil { + t.Fatalf("could not start MySQL container %s",err) + } + t.Cleanup(cleanup) + t.Setenv("NB_ACTIVITY_EVENT_MYSQL_DSN",dsn) + default: + t.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE",string(types.SqliteStoreEngine)) - for i := 0; i < 10; i++ { - _, err = store.Save(context.Background(), &activity.Event{ - Timestamp: time.Now().UTC(), - Activity: activity.PeerAddedByUser, - InitiatorID: "user_" + fmt.Sprint(i), - TargetID: "peer_" + fmt.Sprint(i), - AccountID: accountID, - }) + } + store, err := NewSqlStore(context.Background(), dataDir, key) if err != nil { t.Fatal(err) return } + assert.NoError(t,err) + t.Run(string(engine), func(t *testing.T) { + test(t, store) + }) } - result, err := store.Get(context.Background(), accountID, 0, 10, false) - if err != nil { - t.Fatal(err) - return - } +} + +func TestNewSqlStore(t *testing.T) { - assert.Len(t, result, 10) - assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp)) + runTestForAllEngines(t, func(t *testing.T, store *Store){ + defer store.Close(context.Background()) //nolint - result, err = store.Get(context.Background(), accountID, 0, 5, true) - if err != nil { - t.Fatal(err) - return - } + accountID := "account_1" + + for i := range 10 { + _, err := store.Save(context.Background(), &activity.Event{ + Timestamp: time.Now().UTC(), + Activity: activity.PeerAddedByUser, + InitiatorID: "user_" + fmt.Sprint(i), + TargetID: "peer_" + fmt.Sprint(i), + AccountID: accountID, + }) + if err != nil { + t.Fatal(err) + return + } + } + + result, err := store.Get(context.Background(), accountID, 0, 10, false) + if err != nil { + t.Fatal(err) + return + } + + assert.Len(t, result, 10) + assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp)) + + result, err = store.Get(context.Background(), accountID, 0, 5, true) + if err != nil { + t.Fatal(err) + return + } - assert.Len(t, result, 5) - assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp)) + assert.Len(t, result, 5) + assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp)) + }) }