diff --git a/firewalldb/sql_migration.go b/firewalldb/sql_migration.go index 1e114c12c..422b4c084 100644 --- a/firewalldb/sql_migration.go +++ b/firewalldb/sql_migration.go @@ -67,6 +67,11 @@ func (e *kvEntry) namespacedKey() string { return ns } +// privacyPairs is a type alias for a map that holds the privacy pairs, where +// the outer key is the group ID, and the value is a map of real to pseudo +// values. +type privacyPairs = map[int64]map[string]string + // MigrateFirewallDBToSQL runs the migration of the firwalldb stores from the // bbolt database to a SQL database. The migration is done in a single // transaction to ensure that all rows in the stores are migrated or none at @@ -87,10 +92,14 @@ func MigrateFirewallDBToSQL(ctx context.Context, kvStore *bbolt.DB, return err } + err = migratePrivacyMapperDBToSQL(ctx, kvStore, sqlTx) + if err != nil { + return err + } + log.Infof("The rules DB has been migrated from KV to SQL.") - // TODO(viktor): Add migration for the privacy mapper and the action - // stores. + // TODO(viktor): Add migration for the action stores. return nil } @@ -490,3 +499,299 @@ func verifyBktKeys(bkt *bbolt.Bucket, errorOnKeyValues bool, return fmt.Errorf("unexpected key found: %s", key) }) } + +func migratePrivacyMapperDBToSQL(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) error { + + log.Infof("Starting migration of the privacy mapper store to SQL") + + // 1) Collect all privacy pairs from the KV store. + privPairs, err := collectPrivacyPairs(ctx, kvStore, sqlTx) + if err != nil { + return fmt.Errorf("error migrating privacy mapper store: %w", + err) + } + + // 2) Insert all collected privacy pairs into the SQL database. + err = insertPrivacyPairs(ctx, sqlTx, privPairs) + if err != nil { + return fmt.Errorf("insertion of privacy pairs failed: %w", err) + } + + // 3) Validate that all inserted privacy pairs match the original values + // in the KV store. Note that this is done after all values have been + // inserted, to ensure that the migration doesn't overwrite any values + // after they were inserted. + err = validatePrivacyPairsMigration(ctx, sqlTx, privPairs) + if err != nil { + return fmt.Errorf("migration validation of privacy pairs "+ + "failed: %w", err) + } + + log.Infof("Migration of the privacy mapper stores to SQL completed. "+ + "Total number of rows migrated: %d", len(privPairs)) + return nil +} + +// collectPrivacyPairs collects all privacy pairs from the KV store and +// returns them as the privacyPairs type alias. +func collectPrivacyPairs(ctx context.Context, kvStore *bbolt.DB, + sqlTx SQLQueries) (privacyPairs, error) { + + groupPairs := make(privacyPairs) + + return groupPairs, kvStore.View(func(kvTx *bbolt.Tx) error { + bkt := kvTx.Bucket(privacyBucketKey) + if bkt == nil { + // If we haven't generated any privacy bucket yet, + // we can skip the migration, as there are no privacy + // pairs to migrate. + return nil + } + + return bkt.ForEach(func(groupId, v []byte) error { + if v != nil { + return fmt.Errorf("expected only buckets "+ + "under %s bkt, but found value %s", + privacyBucketKey, v) + } + + gBkt := bkt.Bucket(groupId) + if gBkt == nil { + return fmt.Errorf("group bkt for group id "+ + "%s not found", groupId) + } + + groupSqlId, err := sqlTx.GetSessionIDByAlias( + ctx, groupId, + ) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("session with group id %x "+ + "not found in sql db", groupId) + } else if err != nil { + return err + } + + groupRealToPseudoPairs, err := collectGroupPairs(gBkt) + if err != nil { + return fmt.Errorf("processing group bkt "+ + "for group id %s (sqlID %d) failed: %w", + groupId, groupSqlId, err) + } + + groupPairs[groupSqlId] = groupRealToPseudoPairs + + return nil + }) + }) +} + +// collectGroupPairs collects all privacy pairs for a specific session group, +// i.e. the group buckets under the privacy mapper bucket in the KV store. +// The function returns them as a map, where the key is the real value, and +// the value for the key is the pseudo values. +// It also checks that the pairs are consistent, i.e. that for each real value +// there is a corresponding pseudo value, and vice versa. If the pairs are +// inconsistent, it returns an error indicating the mismatch. +func collectGroupPairs(bkt *bbolt.Bucket) (map[string]string, error) { + var ( + realToPseudoRes map[string]string + pseudoToRealRes map[string]string + err error + missMatchErr = errors.New("privacy mapper pairs mismatch") + ) + + if realBkt := bkt.Bucket(realToPseudoKey); realBkt != nil { + realToPseudoRes, err = collectPairs(realBkt) + if err != nil { + return nil, fmt.Errorf("fetching real to pseudo pairs "+ + "failed: %w", err) + } + } else { + return nil, fmt.Errorf("%s bucket not found", realToPseudoKey) + } + + if pseudoBkt := bkt.Bucket(pseudoToRealKey); pseudoBkt != nil { + pseudoToRealRes, err = collectPairs(pseudoBkt) + if err != nil { + return nil, fmt.Errorf("fetching pseudo to real pairs "+ + "failed: %w", err) + } + } else { + return nil, fmt.Errorf("%s bucket not found", pseudoToRealKey) + } + + if len(realToPseudoRes) != len(pseudoToRealRes) { + return nil, missMatchErr + } + + for realVal, pseudoVal := range realToPseudoRes { + if rv, ok := pseudoToRealRes[pseudoVal]; !ok || rv != realVal { + return nil, missMatchErr + } + } + + return realToPseudoRes, nil +} + +// collectPairs collects all privacy pairs from a specific realToPseudoKey or +// pseudoToRealKey bucket in the KV store. It returns a map where the key is +// the real value or pseudo value, and the value is the corresponding pseudo +// value or real value, respectively (depending on if the realToPseudo or +// pseudoToReal bucket is passed to the function). +func collectPairs(pairsBucket *bbolt.Bucket) (map[string]string, error) { + pairsRes := make(map[string]string) + + return pairsRes, pairsBucket.ForEach(func(k, v []byte) error { + if v == nil { + return fmt.Errorf("expected only key-values under "+ + "pairs bucket, but found bucket %s", k) + } + + if len(v) == 0 { + return fmt.Errorf("empty value stored for privacy "+ + "pairs key %s", k) + } + + pairsRes[string(k)] = string(v) + + return nil + }) +} + +// insertPrivacyPairs inserts the collected privacy pairs into the SQL database. +func insertPrivacyPairs(ctx context.Context, sqlTx SQLQueries, + pairs privacyPairs) error { + + for groupId, groupPairs := range pairs { + err := insertGroupPairs(ctx, sqlTx, groupPairs, groupId) + if err != nil { + return fmt.Errorf("inserting group pairs for group "+ + "id %d failed: %w", groupId, err) + } + } + + return nil +} + +// insertGroupPairs inserts the privacy pairs for a specific group into +// the SQL database. It checks for duplicates before inserting, and returns +// an error if a duplicate pair is found. The function takes a map of real +// to pseudo values, where the key is the real value and the value is the +// corresponding pseudo value. +func insertGroupPairs(ctx context.Context, sqlTx SQLQueries, + pairs map[string]string, groupID int64) error { + + for realVal, pseudoVal := range pairs { + _, err := sqlTx.GetPseudoForReal( + ctx, sqlc.GetPseudoForRealParams{ + GroupID: groupID, + RealVal: realVal, + }, + ) + if err == nil { + return fmt.Errorf("duplicate privacy pair %s:%s: %w", + realVal, pseudoVal, ErrDuplicatePseudoValue) + } else if !errors.Is(err, sql.ErrNoRows) { + return err + } + + _, err = sqlTx.GetRealForPseudo( + ctx, sqlc.GetRealForPseudoParams{ + GroupID: groupID, + PseudoVal: pseudoVal, + }, + ) + if err == nil { + return fmt.Errorf("duplicate privacy pair %s:%s: %w", + realVal, pseudoVal, ErrDuplicatePseudoValue) + } else if !errors.Is(err, sql.ErrNoRows) { + return err + } + + err = sqlTx.InsertPrivacyPair( + ctx, sqlc.InsertPrivacyPairParams{ + GroupID: groupID, + RealVal: realVal, + PseudoVal: pseudoVal, + }, + ) + if err != nil { + return fmt.Errorf("inserting privacy pair %s:%s "+ + "failed: %w", realVal, pseudoVal, err) + } + } + + return nil +} + +// validatePrivacyPairsMigration validates that the migrated privacy pairs +// match the original values in the KV store. +func validatePrivacyPairsMigration(ctx context.Context, sqlTx SQLQueries, + pairs privacyPairs) error { + + for groupId, groupPairs := range pairs { + err := validateGroupPairsMigration( + ctx, sqlTx, groupPairs, groupId, + ) + if err != nil { + return fmt.Errorf("migration validation of privacy "+ + "pairs for group %d failed: %w", groupId, err) + } + } + + return nil +} + +// validateGroupPairsMigration validates that the migrated privacy pairs for +// a specific group match the original values in the KV store. It checks that +// for each real value, the pseudo value in the SQL database matches the +// original pseudo value, and vice versa. If any mismatch is found, it returns +// an error indicating the mismatch. +func validateGroupPairsMigration(ctx context.Context, sqlTx SQLQueries, + pairs map[string]string, groupID int64) error { + + for realVal, pseudoVal := range pairs { + resPseudoVal, err := sqlTx.GetPseudoForReal( + ctx, sqlc.GetPseudoForRealParams{ + GroupID: groupID, + RealVal: realVal, + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("migrated privacy pair %s:%s not "+ + "found for real value", realVal, pseudoVal) + } + if err != nil { + return err + } + + if resPseudoVal != pseudoVal { + return fmt.Errorf("pseudo value in db %s, does not "+ + "match original value %s, for real value %s", + resPseudoVal, pseudoVal, realVal) + } + + resRealVal, err := sqlTx.GetRealForPseudo( + ctx, sqlc.GetRealForPseudoParams{ + GroupID: groupID, + PseudoVal: pseudoVal, + }, + ) + if errors.Is(err, sql.ErrNoRows) { + return fmt.Errorf("migrated privacy pair %s:%s not "+ + "found for pseudo value", realVal, pseudoVal) + } + if err != nil { + return err + } + + if resRealVal != realVal { + return fmt.Errorf("real value in db %s, does not "+ + "match original value %s, for pseudo value %s", + resRealVal, realVal, pseudoVal) + } + } + + return nil +} diff --git a/firewalldb/sql_migration_test.go b/firewalldb/sql_migration_test.go index 1298e3e53..cdc7b655d 100644 --- a/firewalldb/sql_migration_test.go +++ b/firewalldb/sql_migration_test.go @@ -34,6 +34,12 @@ var ( testEntryValue = []byte{1, 2, 3} ) +// expectedResult represents the expected result of a migration test. +type expectedResult struct { + kvEntries fn.Option[[]*kvEntry] + privPairs fn.Option[privacyPairs] +} + // TestFirewallDBMigration tests the migration of firewalldb from a bolt // backend to a SQL database. Note that this test does not attempt to be a // complete migration test. @@ -72,10 +78,10 @@ func TestFirewallDBMigration(t *testing.T) { return store, genericExecutor } - // The assertMigrationResults function will currently assert that + // The assertKvStoreMigrationResults function will currently assert that // the migrated kv stores entries in the SQLDB match the original kv // stores entries in the BoltDB. - assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + assertKvStoreMigrationResults := func(t *testing.T, sqlStore *SQLDB, kvEntries []*kvEntry) { var ( @@ -213,6 +219,50 @@ func TestFirewallDBMigration(t *testing.T) { } } + assertPrivacyMapperMigrationResults := func(t *testing.T, + sqlStore *SQLDB, privPairs privacyPairs) { + + for groupID, groupPairs := range privPairs { + storePairs, err := sqlStore.GetAllPrivacyPairs( + ctx, groupID, + ) + require.NoError(t, err) + require.Len(t, storePairs, len(groupPairs)) + + for _, storePair := range storePairs { + // Assert that the store pair is in the + // original pairs. + pseudo, ok := groupPairs[storePair.RealVal] + require.True(t, ok) + + // Assert that the pseudo value matches + // the one in the store. + require.Equal(t, pseudo, storePair.PseudoVal) + } + } + } + + // The assertMigrationResults function will currently assert that + // the migrated kv stores records and privacy pairs in the SQLDB match + // the original entries in the BoltDB. + assertMigrationResults := func(t *testing.T, sqlStore *SQLDB, + expRes *expectedResult) { + + // If the expected result contains kv records, then we + // assert that the kv store migration results match + // the expected results. + expRes.kvEntries.WhenSome(func(kvEntries []*kvEntry) { + assertKvStoreMigrationResults(t, sqlStore, kvEntries) + }) + + // If the expected result contains privacy pairs, then we + // assert that the privacy mapper migration results match + // the expected results. + expRes.privPairs.WhenSome(func(pairs privacyPairs) { + assertPrivacyMapperMigrationResults(t, sqlStore, pairs) + }) + } + // The tests slice contains all the tests that we will run for the // migration of the firewalldb from a BoltDB to a SQLDB. // Note that the tests currently only test the migration of the KV @@ -221,38 +271,63 @@ func TestFirewallDBMigration(t *testing.T) { tests := []struct { name string populateDB func(t *testing.T, ctx context.Context, - boltDB *BoltDB, sessionStore session.Store) []*kvEntry + boltDB *BoltDB, sessionStore session.Store) *expectedResult }{ { name: "empty", populateDB: func(t *testing.T, ctx context.Context, boltDB *BoltDB, - sessionStore session.Store) []*kvEntry { + sessionStore session.Store) *expectedResult { + + // Don't populate the DB, and return empty kv + // records and privacy pairs. - // Don't populate the DB. - return make([]*kvEntry, 0) + pairsRes := make(privacyPairs) + + return &expectedResult{ + kvEntries: fn.Some( + []*kvEntry{}, + ), + privPairs: fn.Some(pairsRes), + } }, }, { - name: "global entries", + name: "global kv entries", populateDB: globalEntries, }, { - name: "session specific entries", + name: "session specific kv entries", populateDB: sessionSpecificEntries, }, { - name: "feature specific entries", + name: "feature specific kv entries", populateDB: featureSpecificEntries, }, { - name: "all entry combinations", + name: "all kv entry combinations", populateDB: allEntryCombinations, }, { - name: "random entries", + name: "random kv entries", populateDB: randomKVEntries, }, + { + name: "one session and privacy pair", + populateDB: oneSessionAndPrivPair, + }, + { + name: "multiple sessions with one privacy pair", + populateDB: multiSessionsOnePrivPairs, + }, + { + name: "multiple privacy pairs", + populateDB: multipleSessionsAndPrivacyPairs, + }, + { + name: "random privacy pairs", + populateDB: randomPrivacyPairs, + }, } for _, test := range tests { @@ -315,7 +390,7 @@ func TestFirewallDBMigration(t *testing.T) { // globalEntries populates the kv store with one global entry for the temp // store, and one for the perm store. func globalEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, - _ session.Store) []*kvEntry { + _ session.Store) *expectedResult { return insertTempAndPermEntry( t, ctx, boltDB, testRuleName, fn.None[[]byte](), @@ -327,7 +402,7 @@ func globalEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, // entry for the local temp store, and one session specific entry for the perm // local store. func sessionSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, - sessionStore session.Store) []*kvEntry { + sessionStore session.Store) *expectedResult { groupAlias := getNewSessionAlias(t, ctx, sessionStore) @@ -341,7 +416,7 @@ func sessionSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, // entry for the local temp store, and one feature specific entry for the perm // local store. func featureSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, - sessionStore session.Store) []*kvEntry { + sessionStore session.Store) *expectedResult { groupAlias := getNewSessionAlias(t, ctx, sessionStore) @@ -359,11 +434,11 @@ func featureSpecificEntries(t *testing.T, ctx context.Context, boltDB *BoltDB, // any entries when the entry set is more complex than just a single entry at // each level. func allEntryCombinations(t *testing.T, ctx context.Context, boltDB *BoltDB, - sessionStore session.Store) []*kvEntry { + sessionStore session.Store) *expectedResult { var result []*kvEntry - add := func(entry []*kvEntry) { - result = append(result, entry...) + add := func(entry *expectedResult) { + result = append(result, entry.kvEntries.UnwrapOrFail(t)...) } // First lets create standard entries at all levels, which represents @@ -444,7 +519,10 @@ func allEntryCombinations(t *testing.T, ctx context.Context, boltDB *BoltDB, fn.Some(testFeatureName), testEntryKey4, emptyValue, )) - return result + return &expectedResult{ + kvEntries: fn.Some(result), + privPairs: fn.None[privacyPairs](), + } } func getNewSessionAlias(t *testing.T, ctx context.Context, @@ -465,7 +543,7 @@ func getNewSessionAlias(t *testing.T, ctx context.Context, func insertTempAndPermEntry(t *testing.T, ctx context.Context, boltDB *BoltDB, ruleName string, groupAlias fn.Option[[]byte], featureNameOpt fn.Option[string], entryKey string, - entryValue []byte) []*kvEntry { + entryValue []byte) *expectedResult { tempKvEntry := &kvEntry{ ruleName: ruleName, @@ -489,7 +567,11 @@ func insertTempAndPermEntry(t *testing.T, ctx context.Context, insertKvEntry(t, ctx, boltDB, permKvEntry) - return []*kvEntry{tempKvEntry, permKvEntry} + return &expectedResult{ + kvEntries: fn.Some([]*kvEntry{tempKvEntry, permKvEntry}), + // No privacy pairs are inserted in this test. + privPairs: fn.None[privacyPairs](), + } } // insertKvEntry populates the kv store with passed entry, and asserts that the @@ -538,7 +620,7 @@ func insertKvEntry(t *testing.T, ctx context.Context, // across all possible combinations of different levels of entries in the kv // store. All values and different bucket names are randomly generated. func randomKVEntries(t *testing.T, ctx context.Context, - boltDB *BoltDB, sessionStore session.Store) []*kvEntry { + boltDB *BoltDB, sessionStore session.Store) *expectedResult { var ( // We set the number of entries to insert to 1000, as that @@ -649,7 +731,139 @@ func randomKVEntries(t *testing.T, ctx context.Context, insertedEntries = append(insertedEntries, entry) } - return insertedEntries + return &expectedResult{ + kvEntries: fn.Some(insertedEntries), + // No privacy pairs are inserted in this test. + privPairs: fn.None[privacyPairs](), + } +} + +// multiSessionsOnePrivPairs inserts 1 session with 1 privacy pair into the +// boltDB. +func oneSessionAndPrivPair(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) *expectedResult { + + return createPrivacyPairs(t, ctx, boltDB, sessionStore, 1, 1) +} + +// multiSessionsOnePrivPairs inserts 1 session with 10 privacy pairs into the +// boltDB. +func multiSessionsOnePrivPairs(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) *expectedResult { + + return createPrivacyPairs(t, ctx, boltDB, sessionStore, 1, 10) +} + +// multipleSessionsAndPrivacyPairs inserts 5 sessions with 10 privacy pairs +// per session into the boltDB. +func multipleSessionsAndPrivacyPairs(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) *expectedResult { + + return createPrivacyPairs(t, ctx, boltDB, sessionStore, 5, 10) +} + +// createPrivacyPairs is a helper function that creates a number of sessions +// with a number of privacy pairs per session. It returns an expectedResult +// struct that contains the expected privacy pairs and no kv records. +func createPrivacyPairs(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store, numSessions int, + numPairsPerSession int) *expectedResult { + + pairs := make(privacyPairs) + + sessSQLStore, ok := sessionStore.(*session.SQLStore) + require.True(t, ok) + + for i := 0; i < numSessions; i++ { + sess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupID := sess.GroupID + sqlGroupID, err := sessSQLStore.GetSessionIDByAlias( + ctx, groupID[:], + ) + require.NoError(t, err) + + groupPairs := make(map[string]string) + + for j := 0; j < numPairsPerSession; j++ { + realKey := fmt.Sprintf("real-%d-%d", i, j) + pseudoKey := fmt.Sprintf("pseudo-%d-%d", i, j) + + f := func(ctx context.Context, tx PrivacyMapTx) error { + return tx.NewPair(ctx, realKey, pseudoKey) + } + + err := boltDB.PrivacyDB(groupID).Update(ctx, f) + require.NoError(t, err) + + groupPairs[realKey] = pseudoKey + } + + pairs[sqlGroupID] = groupPairs + } + + return &expectedResult{ + kvEntries: fn.None[[]*kvEntry](), + privPairs: fn.Some(pairs), + } +} + +// randomPrivacyPairs creates a random number of privacy pairs to 10 sessions. +func randomPrivacyPairs(t *testing.T, ctx context.Context, + boltDB *BoltDB, sessionStore session.Store) *expectedResult { + + numSessions := 10 + maxPairsPerSession := 20 + pairs := make(privacyPairs) + + sessSQLStore, ok := sessionStore.(*session.SQLStore) + require.True(t, ok) + + for i := 0; i < numSessions; i++ { + sess, err := sessionStore.NewSession( + ctx, fmt.Sprintf("rand-session-%d", i), + session.Type(uint8(rand.Intn(5))), + time.Unix(1000, 0), randomString(rand.Intn(10)+1), + ) + require.NoError(t, err) + + groupID := sess.GroupID + sqlGroupID, err := sessSQLStore.GetSessionIDByAlias( + ctx, groupID[:], + ) + require.NoError(t, err) + + numPairs := rand.Intn(maxPairsPerSession) + 1 + groupPairs := make(map[string]string) + + for j := 0; j < numPairs; j++ { + realKey := fmt.Sprintf("real-%s", + randomString(rand.Intn(10)+5)) + pseudoKey := fmt.Sprintf("pseudo-%s", + randomString(rand.Intn(10)+5)) + + f := func(ctx context.Context, tx PrivacyMapTx) error { + return tx.NewPair(ctx, realKey, pseudoKey) + } + + err := boltDB.PrivacyDB(groupID).Update(ctx, f) + require.NoError(t, err) + + groupPairs[realKey] = pseudoKey + } + + pairs[sqlGroupID] = groupPairs + } + + return &expectedResult{ + kvEntries: fn.None[[]*kvEntry](), + privPairs: fn.Some(pairs), + } } // randomString generates a random string of the passed length n.