Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
8 changes: 8 additions & 0 deletions management/server/activity/store/sql_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (

log "github.com/sirupsen/logrus"
"gorm.io/driver/postgres"
"gorm.io/driver/mysql"
"gorm.io/driver/sqlite"
"gorm.io/gorm"
"gorm.io/gorm/clause"
Expand All @@ -30,6 +31,7 @@ const (

storeEngineEnv = "NB_ACTIVITY_EVENT_STORE_ENGINE"
postgresDsnEnv = "NB_ACTIVITY_EVENT_POSTGRES_DSN"
mysqlDsnEnv = "NB_ACTIVITY_EVENT_MYSQL_DSN"
sqlMaxOpenConnsEnv = "NB_SQL_MAX_OPEN_CONNS"
)

Expand Down Expand Up @@ -254,6 +256,12 @@ func initDatabase(ctx context.Context, dataDir string) (*gorm.DB, error) {
return nil, fmt.Errorf("%s environment variable not set", postgresDsnEnv)
}
dialector = postgres.Open(dsn)
case types.MysqlStoreEngine:
dsn, ok := os.LookupEnv(mysqlDsnEnv)
if !ok {
return nil, fmt.Errorf("%s environment variable not set", mysqlDsnEnv)
}
dialector = mysql.Open(dsn+"?charset=utf8&parseTime=True&loc=Local")
default:
return nil, fmt.Errorf("unsupported store engine: %s", storeEngine)
}
Expand Down
100 changes: 69 additions & 31 deletions management/server/activity/store/sql_store_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -9,49 +9,87 @@ import (
"github.com/stretchr/testify/assert"

"github.com/netbirdio/netbird/management/server/activity"
"github.com/netbirdio/netbird/management/server/testutil"
"github.com/netbirdio/netbird/management/server/types"
)

func TestNewSqlStore(t *testing.T) {
var enginesToTest = []types.Engine{types.SqliteStoreEngine,types.PostgresStoreEngine,types.MysqlStoreEngine}

func runTestForAllEngines(t *testing.T, test func( t *testing.T, store *Store )) {
dataDir := t.TempDir()
key, _ := GenerateKey()
store, err := NewSqlStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)
return
}
defer store.Close(context.Background()) //nolint

accountID := "account_1"

for _,engine := range enginesToTest {
t.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE",string(engine))
switch engine {
case types.PostgresStoreEngine :
cleanup, dsn, err := testutil.CreatePostgresTestContainer()
if err != nil {
t.Fatalf("could not start Postgres container %s",err)
}
t.Cleanup(cleanup)
t.Setenv("NB_ACTIVITY_EVENT_POSTGRES_DSN",dsn)
case types.MysqlStoreEngine :
cleanup, dsn, err := testutil.CreateMysqlTestContainer()
if err != nil {
t.Fatalf("could not start MySQL container %s",err)
}
t.Cleanup(cleanup)
t.Setenv("NB_ACTIVITY_EVENT_MYSQL_DSN",dsn)
default:
t.Setenv("NB_ACTIVITY_EVENT_STORE_ENGINE",string(types.SqliteStoreEngine))

for i := 0; i < 10; i++ {
_, err = store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
TargetID: "peer_" + fmt.Sprint(i),
AccountID: accountID,
})
}
store, err := NewSqlStore(context.Background(), dataDir, key)
if err != nil {
t.Fatal(err)
return
}
assert.NoError(t,err)
t.Run(string(engine), func(t *testing.T) {
test(t, store)
})
}

result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
}
}

func TestNewSqlStore(t *testing.T) {

assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))
runTestForAllEngines(t, func(t *testing.T, store *Store){
defer store.Close(context.Background()) //nolint

result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return
}
accountID := "account_1"

for i := range 10 {
_, err := store.Save(context.Background(), &activity.Event{
Timestamp: time.Now().UTC(),
Activity: activity.PeerAddedByUser,
InitiatorID: "user_" + fmt.Sprint(i),
TargetID: "peer_" + fmt.Sprint(i),
AccountID: accountID,
})
if err != nil {
t.Fatal(err)
return
}
}

result, err := store.Get(context.Background(), accountID, 0, 10, false)
if err != nil {
t.Fatal(err)
return
}

assert.Len(t, result, 10)
assert.True(t, result[0].Timestamp.Before(result[len(result)-1].Timestamp))

result, err = store.Get(context.Background(), accountID, 0, 5, true)
if err != nil {
t.Fatal(err)
return
}

assert.Len(t, result, 5)
assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp))
assert.Len(t, result, 5)
assert.True(t, result[0].Timestamp.After(result[len(result)-1].Timestamp))
})
}
Loading