diff --git a/catalog/internal_tables.go b/catalog/internal_tables.go index 1dcbf325..ce505716 100644 --- a/catalog/internal_tables.go +++ b/catalog/internal_tables.go @@ -15,6 +15,29 @@ func (it *InternalTable) QualifiedName() string { return it.Schema + "." + it.Name } +func (it *InternalTable) UpdateStmt(keyColumns []string, valueColumns []string) string { + var b strings.Builder + b.Grow(128) + b.WriteString("UPDATE ") + b.WriteString(it.QualifiedName()) + b.WriteString(" SET " + valueColumns[0] + " = ?") + + for _, valueColumn := range valueColumns[1:] { + b.WriteString(", ") + b.WriteString(valueColumn) + b.WriteString(" = ?") + } + + b.WriteString(" WHERE " + keyColumns[0] + " = ?") + for _, keyColumn := range keyColumns[1:] { + b.WriteString(", ") + b.WriteString(keyColumn) + b.WriteString(" = ?") + } + + return b.String() +} + func (it *InternalTable) UpsertStmt() string { var b strings.Builder b.Grow(128) @@ -50,6 +73,16 @@ func (it *InternalTable) DeleteStmt() string { return b.String() } +func (it *InternalTable) DeleteAllStmt() string { + var b strings.Builder + b.Grow(128) + b.WriteString("DELETE FROM ") + b.WriteString(it.Schema) + b.WriteByte('.') + b.WriteString(it.Name) + return b.String() +} + func (it *InternalTable) SelectStmt() string { var b strings.Builder b.Grow(128) @@ -74,6 +107,30 @@ func (it *InternalTable) SelectStmt() string { return b.String() } +func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string { + var b strings.Builder + b.Grow(128) + b.WriteString("SELECT ") + b.WriteString(valueColumns[0]) + for _, c := range valueColumns[1:] { + b.WriteString(", ") + b.WriteString(c) + } + b.WriteString(" FROM ") + b.WriteString(it.Schema) + b.WriteByte('.') + b.WriteString(it.Name) + b.WriteString(" WHERE ") + b.WriteString(it.KeyColumns[0]) + b.WriteString(" = ?") + for _, c := range it.KeyColumns[1:] { + b.WriteString(" AND ") + b.WriteString(c) + b.WriteString(" = ?") + } + return b.String() +} + func (it *InternalTable) SelectAllStmt() string { var b strings.Builder b.Grow(128) @@ -98,7 +155,6 @@ func (it *InternalTable) CountAllStmt() string { var InternalTables = struct { PersistentVariable InternalTable BinlogPosition InternalTable - PgReplicationLSN InternalTable PgSubscription InternalTable GlobalStatus InternalTable // TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'. @@ -120,19 +176,12 @@ var InternalTables = struct { ValueColumns: []string{"position"}, DDL: "channel TEXT PRIMARY KEY, position TEXT", }, - PgReplicationLSN: InternalTable{ - Schema: "__sys__", - Name: "pg_replication_lsn", - KeyColumns: []string{"slot_name"}, - ValueColumns: []string{"lsn"}, - DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT", - }, PgSubscription: InternalTable{ Schema: "__sys__", Name: "pg_subscription", - KeyColumns: []string{"name"}, - ValueColumns: []string{"connection", "publication"}, - DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT", + KeyColumns: []string{"subname"}, + ValueColumns: []string{"subconninfo", "subpublication", "subskiplsn", "subenabled"}, + DDL: "subname TEXT PRIMARY KEY, subconninfo TEXT, subpublication TEXT, subskiplsn TEXT, subenabled BOOLEAN", }, GlobalStatus: InternalTable{ Schema: "performance_schema", @@ -227,7 +276,6 @@ var InternalTables = struct { var internalTables = []InternalTable{ InternalTables.PersistentVariable, InternalTables.BinlogPosition, - InternalTables.PgReplicationLSN, InternalTables.PgSubscription, InternalTables.GlobalStatus, InternalTables.PGStatReplication, diff --git a/main.go b/main.go index 0e55a7c1..17a52295 100644 --- a/main.go +++ b/main.go @@ -182,15 +182,9 @@ func main() { } // Check if there is a replication subscription and start replication if there is. - _, conn, pub, ok, err := logrepl.FindReplication(pool.DB) + err = logrepl.UpdateSubscriptions(pgServer.NewInternalCtx()) if err != nil { - logrus.WithError(err).Warnln("Failed to find replication") - } else if ok { - replicator, err := logrepl.NewLogicalReplicator(conn) - if err != nil { - logrus.WithError(err).Fatalln("Failed to create logical replicator") - } - go replicator.StartReplication(pgServer.NewInternalCtx(), pub) + logrus.WithError(err).Warnln("Failed to update subscriptions") } // Load the configuration for the Postgres server. diff --git a/pgserver/connection_handler.go b/pgserver/connection_handler.go index 2791a254..60b949c6 100644 --- a/pgserver/connection_handler.go +++ b/pgserver/connection_handler.go @@ -982,7 +982,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error { } if query.SubscriptionConfig != nil { - return h.executeCreateSubscriptionSQL(query.SubscriptionConfig) + return h.executeSubscriptionSQL(query.SubscriptionConfig) } callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false) diff --git a/pgserver/logrepl/replication.go b/pgserver/logrepl/replication.go index afc1d1bd..8a3283ec 100644 --- a/pgserver/logrepl/replication.go +++ b/pgserver/logrepl/replication.go @@ -26,7 +26,6 @@ import ( "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/binlog" - "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/delta" "github.com/apecloud/myduckserver/pgtypes" "github.com/dolthub/go-mysql-server/sql" @@ -46,6 +45,7 @@ type rcvMsg struct { } type LogicalReplicator struct { + subscription string primaryDns string flushInterval time.Duration @@ -60,8 +60,9 @@ type LogicalReplicator struct { // NewLogicalReplicator creates a new logical replicator instance which connects to the primary and replication // databases using the connection strings provided. The connection to the replica is established immediately, and the // connection to the primary is established when StartReplication is called. -func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) { +func NewLogicalReplicator(subscription, primaryDns string) (*LogicalReplicator, error) { return &LogicalReplicator{ + subscription: subscription, primaryDns: primaryDns, flushInterval: 200 * time.Millisecond, mu: &sync.Mutex{}, @@ -222,7 +223,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin standbyMessageTimeout := 10 * time.Second nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout) - lastWrittenLsn, err := r.readWALPosition(sqlCtx, slotName) + lastWrittenLsn, err := SelectSubscriptionLsn(sqlCtx, r.subscription) if err != nil { return err } @@ -881,26 +882,6 @@ func (r *LogicalReplicator) processMessage( return false, nil } -// readWALPosition reads the recorded WAL position from the WAL position table -func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) (pglogrepl.LSN, error) { - var lsn string - if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgReplicationLSN.SelectStmt(), slotName).Scan(&lsn); err != nil { - if errors.Is(err, stdsql.ErrNoRows) { - // if the LSN doesn't exist, consider this a cold start and return 0 - return pglogrepl.LSN(0), nil - } - return 0, err - } - - return pglogrepl.ParseLSN(lsn) -} - -// WriteWALPosition writes the recorded WAL position to the WAL position table -func (r *LogicalReplicator) WriteWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error { - _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String()) - return err -} - // whereClause returns a WHERE clause string with the contents of the builder if it's non-empty, or the empty // string otherwise func whereClause(str strings.Builder) string { @@ -1000,7 +981,7 @@ func (r *LogicalReplicator) commitOngoingTxn(state *replicationState, flushReaso } r.logger.Debugf("Writing LSN %s\n", state.lastCommitLSN) - if err = r.WriteWALPosition(state.replicaCtx, state.slotName, state.lastCommitLSN); err != nil { + if err = UpdateSubscriptionLsn(state.replicaCtx, state.lastCommitLSN.String(), r.subscription); err != nil { return err } diff --git a/pgserver/logrepl/replication_test.go b/pgserver/logrepl/replication_test.go index d12f0619..fa2ae542 100644 --- a/pgserver/logrepl/replication_test.go +++ b/pgserver/logrepl/replication_test.go @@ -17,6 +17,8 @@ package logrepl_test import ( "context" "fmt" + "github.com/apecloud/myduckserver/adapter" + "github.com/jackc/pglogrepl" "log" "os" "os/exec" @@ -660,6 +662,7 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) { } const slotName = "myduck_slot" +const subscriptionName = "my_sub_test" // RunReplicationScript runs the given ReplicationTest. func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) { @@ -686,8 +689,18 @@ func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) { }) } -func newReplicator(t *testing.T, primaryDns string) *logrepl.LogicalReplicator { - r, err := logrepl.NewLogicalReplicator(primaryDns) +func newReplicator(sqlCtx *sql.Context, t *testing.T, primaryDns string) *logrepl.LogicalReplicator { + err := logrepl.CreateSubscription(sqlCtx, subscriptionName, primaryDns, slotName, pglogrepl.LSN(0).String(), true) + require.NoError(t, err) + + tx := adapter.TryGetTxn(sqlCtx) + if tx != nil { + err := tx.Commit() + require.NoError(t, err) + adapter.CloseTxn(sqlCtx) + } + + r, err := logrepl.NewLogicalReplicator(subscriptionName, primaryDns) require.NoError(t, err) return r } @@ -701,7 +714,7 @@ func runReplicationScript( replicaConn *pgx.Conn, primaryDns string, ) { - r := newReplicator(t, primaryDns) + r := newReplicator(server.NewInternalCtx(), t, primaryDns) defer r.Stop() if script.Skip { diff --git a/pgserver/logrepl/subscription.go b/pgserver/logrepl/subscription.go index 34b1f876..b4f9c920 100644 --- a/pgserver/logrepl/subscription.go +++ b/pgserver/logrepl/subscription.go @@ -1,35 +1,137 @@ package logrepl import ( - "context" stdsql "database/sql" - + "errors" + "fmt" "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" "github.com/dolthub/go-mysql-server/sql" + "github.com/jackc/pglogrepl" + "sync" ) -func WriteSubscription(ctx *sql.Context, name, conn, pub string) error { - _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub) - return err +type Subscription struct { + Subscription string + Conn string + Publication string + LsnStr string + Enabled bool + Replicator *LogicalReplicator } -func FindReplication(db *stdsql.DB) (name, conn, pub string, ok bool, err error) { - var rows *stdsql.Rows - rows, err = db.QueryContext(context.Background(), catalog.InternalTables.PgSubscription.SelectAllStmt()) +var keyColumns = []string{"subname"} +var statusValueColumns = []string{"subenabled"} +var lsnValueColumns = []string{"subskiplsn"} + +var subscriptionMap = sync.Map{} + +func UpdateSubscriptions(ctx *sql.Context) error { + rows, err := adapter.QueryCatalog(ctx, catalog.InternalTables.PgSubscription.SelectAllStmt()) if err != nil { - return + return err } defer rows.Close() - if !rows.Next() { - return + var subMap = make(map[string]*Subscription) + for rows.Next() { + var name, conn, pub, lsn string + var enabled bool + if err := rows.Scan(&name, &conn, &pub, &lsn, &enabled); err != nil { + return err + } + subMap[name] = &Subscription{ + Subscription: name, + Conn: conn, + Publication: pub, + LsnStr: lsn, + Enabled: enabled, + Replicator: nil, + } } - if err = rows.Scan(&name, &conn, &pub); err != nil { - return + if err = rows.Err(); err != nil { + return err + } + + for tempName, tempSub := range subMap { + if _, loaded := subscriptionMap.LoadOrStore(tempName, tempSub); !loaded { + replicator, err := NewLogicalReplicator(tempName, tempSub.Conn) + if err != nil { + return fmt.Errorf("failed to create logical replicator: %v", err) + } + + if sub, ok := subscriptionMap.Load(tempName); ok { + if subscription, ok := sub.(*Subscription); ok { + subscription.Replicator = replicator + } + } + + err = replicator.CreateReplicationSlotIfNotExists(tempSub.Publication) + if err != nil { + return fmt.Errorf("failed to create replication slot: %v", err) + } + if tempSub.Enabled { + go replicator.StartReplication(ctx, tempSub.Publication) + } + } else { + if sub, ok := subscriptionMap.Load(tempName); ok { + if subscription, ok := sub.(*Subscription); ok { + if tempSub.Enabled != subscription.Enabled { + subscription.Enabled = tempSub.Enabled + if subscription.Enabled { + go subscription.Replicator.StartReplication(ctx, subscription.Publication) + } else { + subscription.Replicator.Stop() + } + } + } + } + } + } + + subscriptionMap.Range(func(key, value interface{}) bool { + name, _ := key.(string) + subscription, _ := value.(*Subscription) + if _, ok := subMap[name]; !ok { + subscription.Replicator.Stop() + subscriptionMap.Delete(name) + } + return true + }) + + return nil +} + +func CreateSubscription(ctx *sql.Context, name, conn, pub, lsn string, enabled bool) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpsertStmt(), name, conn, pub, lsn, enabled) + return err +} + +func UpdateSubscriptionStatus(ctx *sql.Context, enabled bool, name string) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpdateStmt(keyColumns, statusValueColumns), enabled, name) + return err +} + +func DeleteSubscription(ctx *sql.Context, name string) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.DeleteStmt(), name) + return err +} + +func UpdateSubscriptionLsn(ctx *sql.Context, lsn, name string) error { + _, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgSubscription.UpdateStmt(keyColumns, lsnValueColumns), lsn, name) + return err +} + +func SelectSubscriptionLsn(ctx *sql.Context, subscription string) (pglogrepl.LSN, error) { + var lsn string + if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgSubscription.SelectColumnsStmt(lsnValueColumns), subscription).Scan(&lsn); err != nil { + if errors.Is(err, stdsql.ErrNoRows) { + // if the LSN doesn't exist, consider this a cold start and return 0 + return pglogrepl.LSN(0), nil + } + return 0, err } - ok = true - return + return pglogrepl.ParseLSN(lsn) } diff --git a/pgserver/pg_catalog_handler.go b/pgserver/pg_catalog_handler.go index 71cdd83c..f3d0dcb6 100644 --- a/pgserver/pg_catalog_handler.go +++ b/pgserver/pg_catalog_handler.go @@ -11,7 +11,7 @@ import ( "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" duckConfig "github.com/apecloud/myduckserver/configuration" - tree "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" + "github.com/cockroachdb/cockroachdb-parser/pkg/sql/sem/tree" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pgx/v5/pgproto3" ) @@ -36,7 +36,7 @@ func (h *ConnectionHandler) isInRecovery() (string, error) { return "f", err } var count int - if err := adapter.QueryRow(ctx, catalog.InternalTables.PgReplicationLSN.CountAllStmt()).Scan(&count); err != nil { + if err := adapter.QueryRow(ctx, catalog.InternalTables.PgSubscription.CountAllStmt()).Scan(&count); err != nil { return "f", err } @@ -54,9 +54,12 @@ func (h *ConnectionHandler) readOneWALPositionStr() (string, error) { if err != nil { return "0/0", err } - var slotName string - var lsn string - if err := adapter.QueryRow(ctx, catalog.InternalTables.PgReplicationLSN.SelectAllStmt()).Scan(&slotName, &lsn); err != nil { + + // TODO(neo.zty): needs to be fixed + var subscription, conn, publication, lsn string + var enabled bool + + if err := adapter.QueryRow(ctx, catalog.InternalTables.PgSubscription.SelectAllStmt()).Scan(&subscription, &conn, &publication, &lsn, &enabled); err != nil { if errors.Is(err, stdsql.ErrNoRows) { // if no lsn is stored, return 0 return "0/0", nil diff --git a/pgserver/subscription_handler.go b/pgserver/subscription_handler.go index 4b44f13f..f36c1eee 100644 --- a/pgserver/subscription_handler.go +++ b/pgserver/subscription_handler.go @@ -3,102 +3,221 @@ package pgserver import ( "context" "fmt" - "regexp" - "strings" - "github.com/apecloud/myduckserver/adapter" "github.com/apecloud/myduckserver/catalog" "github.com/apecloud/myduckserver/pgserver/logrepl" "github.com/dolthub/go-mysql-server/sql" "github.com/jackc/pglogrepl" + "regexp" + "strings" ) -// This file implements the logic for handling CREATE SUBSCRIPTION SQL statements. -// Example usage of CREATE SUBSCRIPTION SQL: +// This file handles SQL statements for managing PostgreSQL subscriptions. It supports: +// +// 1. Creating a subscription: +// CREATE SUBSCRIPTION mysub +// CONNECTION 'dbname= host=127.0.0.1 port=5432 user=postgres password=root' +// PUBLICATION mypub; +// This statement sets up a new subscription named 'mysub' that connects to a specified PostgreSQL +// database and listens for changes published under the 'mypub' publication. // -// CREATE SUBSCRIPTION mysub -// CONNECTION 'dbname= host=127.0.0.1 port=5432 user=postgres password=root' -// PUBLICATION mypub; +// 2. Altering a subscription (enable/disable): +// ALTER SUBSCRIPTION mysub enable; +// ALTER SUBSCRIPTION mysub disable; // -// The statement creates a subscription named 'mysub' that connects to a PostgreSQL -// database and subscribes to changes published under the 'mypub' publication. +// 3. Dropping a subscription: +// DROP SUBSCRIPTION mysub; +// This statement removes the specified subscription. + +// Action represents the type of SQL action. +type Action string + +const ( + Create Action = "CREATE" + Drop Action = "DROP" + AlterDisable Action = "DISABLE" + AlterEnable Action = "ENABLE" +) + +// ConnectionDetails holds parsed connection string components. +type ConnectionDetails struct { + DBName string + Host string + Port string + User string + Password string +} +// SubscriptionConfig represents the configuration of a subscription. type SubscriptionConfig struct { SubscriptionName string PublicationName string - DBName string - Host string - Port string - User string - Password string + Connection *ConnectionDetails // Embedded pointer to ConnectionDetails + Action Action } -var subscriptionRegex = regexp.MustCompile(`(?i)CREATE SUBSCRIPTION\s+(\w+)\s+CONNECTION\s+'([^']+)'\s+PUBLICATION\s+(\w+);`) +// createRegex matches and extracts components from a CREATE SUBSCRIPTION SQL statement. Example matched command: +var createRegex = regexp.MustCompile(`(?i)^CREATE\s+SUBSCRIPTION\s+([\w-]+)\s+CONNECTION\s+'([^']+)'(?:\s+PUBLICATION\s+([\w-]+))?;?$`) + +// alterRegex matches ALTER SUBSCRIPTION SQL commands and captures the subscription name and the action to be taken. +var alterRegex = regexp.MustCompile(`(?i)^ALTER\s+SUBSCRIPTION\s+([\w-]+)\s+(disable|enable);?$`) + +// dropRegex matches DROP SUBSCRIPTION SQL commands and captures the subscription name. +var dropRegex = regexp.MustCompile(`(?i)^DROP\s+SUBSCRIPTION\s+([\w-]+);?$`) + +// connectionRegex matches and captures key-value pairs within a connection string. var connectionRegex = regexp.MustCompile(`(\b\w+)=([\w\.\d]*)`) +// ParseSubscriptionSQL parses the given SQL statement and returns a SubscriptionConfig. +func parseSubscriptionSQL(sql string) (*SubscriptionConfig, error) { + var config SubscriptionConfig + switch { + case createRegex.MatchString(sql): + matches := createRegex.FindStringSubmatch(sql) + config.Action = Create + config.SubscriptionName = matches[1] + if len(matches) > 3 { + config.PublicationName = matches[3] + } + conn, err := parseConnectionString(matches[2]) + if err != nil { + return nil, err + } + config.Connection = conn + + case alterRegex.MatchString(sql): + matches := alterRegex.FindStringSubmatch(sql) + config.SubscriptionName = matches[1] + switch strings.ToUpper(matches[2]) { + case string(AlterDisable): + config.Action = AlterDisable + case string(AlterEnable): + config.Action = AlterEnable + default: + return nil, fmt.Errorf("invalid ALTER SUBSCRIPTION action: %s", matches[2]) + } + + case dropRegex.MatchString(sql): + matches := dropRegex.FindStringSubmatch(sql) + config.Action = Drop + config.SubscriptionName = matches[1] + + default: + return nil, nil + } + + return &config, nil +} + +// parseConnectionString parses the given connection string and returns a ConnectionDetails. +func parseConnectionString(connStr string) (*ConnectionDetails, error) { + details := &ConnectionDetails{} + pairs := connectionRegex.FindAllStringSubmatch(connStr, -1) + + if pairs == nil { + return nil, fmt.Errorf("no valid key-value pairs found in connection string") + } + + for _, pair := range pairs { + key := pair[1] + value := pair[2] + switch key { + case "dbname": + details.DBName = value + case "host": + details.Host = value + case "port": + details.Port = value + case "user": + details.User = value + case "password": + details.Password = value + } + } + + // Handle default values + if details.DBName == "" { + details.DBName = "postgres" + } + if details.Port == "" { + details.Port = "5432" + } + + return details, nil +} + // ToConnectionInfo Format SubscriptionConfig into a ConnectionInfo func (config *SubscriptionConfig) ToConnectionInfo() string { return fmt.Sprintf("dbname=%s user=%s password=%s host=%s port=%s", - config.DBName, config.User, config.Password, config.Host, config.Port) + config.Connection.DBName, config.Connection.User, config.Connection.Password, + config.Connection.Host, config.Connection.Port) } // ToDNS Format SubscriptionConfig into a DNS func (config *SubscriptionConfig) ToDNS() string { return fmt.Sprintf("postgres://%s:%s@%s:%s/%s", - config.User, config.Password, config.Host, config.Port, config.DBName) + config.Connection.User, config.Connection.Password, config.Connection.Host, + config.Connection.Port, config.Connection.DBName) } -func parseSubscriptionSQL(sql string) (*SubscriptionConfig, error) { - subscriptionMatch := subscriptionRegex.FindStringSubmatch(sql) - if len(subscriptionMatch) < 4 { - return nil, fmt.Errorf("invalid CREATE SUBSCRIPTION SQL format") +func (h *ConnectionHandler) executeSubscriptionSQL(subscriptionConfig *SubscriptionConfig) error { + switch subscriptionConfig.Action { + case Create: + return h.executeCreate(subscriptionConfig) + case Drop: + return h.executeDrop(subscriptionConfig) + case AlterEnable: + return h.executeEnableSubscription(subscriptionConfig) + case AlterDisable: + return h.executeDisableSubscription(subscriptionConfig) + default: + return fmt.Errorf("unsupported action: %s", subscriptionConfig.Action) } +} - subscriptionName := subscriptionMatch[1] - connectionString := subscriptionMatch[2] - publicationName := subscriptionMatch[3] +func (h *ConnectionHandler) executeEnableSubscription(subscriptionConfig *SubscriptionConfig) error { + sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") + if err != nil { + return fmt.Errorf("failed to create context for query: %w", err) + } - // Parse the connection string into key-value pairs - matches := connectionRegex.FindAllStringSubmatch(connectionString, -1) - if matches == nil { - return nil, fmt.Errorf("no valid key-value pairs found in connection string") + err = logrepl.UpdateSubscriptionStatus(sqlCtx, true, subscriptionConfig.SubscriptionName) + if err != nil { + return fmt.Errorf("failed to delete subscription: %w", err) } - // Initialize SubscriptionConfig struct - config := &SubscriptionConfig{ - SubscriptionName: subscriptionName, - PublicationName: publicationName, + return commitAndUpdate(sqlCtx) +} + +func (h *ConnectionHandler) executeDisableSubscription(subscriptionConfig *SubscriptionConfig) error { + sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") + if err != nil { + return fmt.Errorf("failed to create context for query: %w", err) } - // Map the matches to struct fields - for _, match := range matches { - key := strings.ToLower(match[1]) - switch key { - case "dbname": - config.DBName = match[2] - case "host": - config.Host = match[2] - case "port": - config.Port = match[2] - case "user": - config.User = match[2] - case "password": - config.Password = match[2] - } + err = logrepl.UpdateSubscriptionStatus(sqlCtx, false, subscriptionConfig.SubscriptionName) + if err != nil { + return fmt.Errorf("failed to delete subscription: %w", err) } - // Handle default values - if config.DBName == "" { - config.DBName = "postgres" + return commitAndUpdate(sqlCtx) +} + +func (h *ConnectionHandler) executeDrop(subscriptionConfig *SubscriptionConfig) error { + sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") + if err != nil { + return fmt.Errorf("failed to create context for query: %w", err) } - if config.Port == "" { - config.Port = "5432" + + err = logrepl.DeleteSubscription(sqlCtx, subscriptionConfig.SubscriptionName) + if err != nil { + return fmt.Errorf("failed to delete subscription: %w", err) } - return config, nil + return commitAndUpdate(sqlCtx) } -func (h *ConnectionHandler) executeCreateSubscriptionSQL(subscriptionConfig *SubscriptionConfig) error { +func (h *ConnectionHandler) executeCreate(subscriptionConfig *SubscriptionConfig) error { sqlCtx, err := h.duckHandler.sm.NewContextWithQuery(context.Background(), h.mysqlConn, "") if err != nil { return fmt.Errorf("failed to create context for query: %w", err) @@ -109,21 +228,11 @@ func (h *ConnectionHandler) executeCreateSubscriptionSQL(subscriptionConfig *Sub return fmt.Errorf("failed to create snapshot for CREATE SUBSCRIPTION: %w", err) } - // Do a checkpoint here to merge the WAL logs - // if _, err := adapter.ExecCatalog(sqlCtx, "FORCE CHECKPOINT"); err != nil { - // return fmt.Errorf("failed to execute FORCE CHECKPOINT: %w", err) - // } - // if _, err := adapter.ExecCatalog(sqlCtx, "PRAGMA force_checkpoint;"); err != nil { - // return fmt.Errorf("failed to execute PRAGMA force_checkpoint: %w", err) - // } - - replicator, err := h.doCreateSubscription(sqlCtx, subscriptionConfig, lsn) + err = h.doCreateSubscription(sqlCtx, subscriptionConfig, lsn) if err != nil { return fmt.Errorf("failed to execute CREATE SUBSCRIPTION: %w", err) } - go replicator.StartReplication(h.server.NewInternalCtx(), subscriptionConfig.PublicationName) - return nil } @@ -224,40 +333,40 @@ func (h *ConnectionHandler) doSnapshot(sqlCtx *sql.Context, subscriptionConfig * return lsn, txn.Commit() } -func (h *ConnectionHandler) doCreateSubscription(sqlCtx *sql.Context, subscriptionConfig *SubscriptionConfig, lsn pglogrepl.LSN) (*logrepl.LogicalReplicator, error) { - replicator, err := logrepl.NewLogicalReplicator(subscriptionConfig.ToDNS()) +func (h *ConnectionHandler) doCreateSubscription(sqlCtx *sql.Context, subscriptionConfig *SubscriptionConfig, lsn pglogrepl.LSN) error { + err := logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) if err != nil { - return nil, fmt.Errorf("failed to create logical replicator: %w", err) + return fmt.Errorf("failed to create publication: %w", err) } - err = logrepl.CreatePublicationIfNotExists(subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) - if err != nil { - return nil, fmt.Errorf("failed to create publication: %w", err) - } - - err = replicator.CreateReplicationSlotIfNotExists(subscriptionConfig.PublicationName) - if err != nil { - return nil, fmt.Errorf("failed to create replication slot: %w", err) - } - - // `WriteWALPosition` and `WriteSubscription` execute in a transaction internally, - // so we start a transaction here and commit it after writing the WAL position. tx, err := adapter.GetCatalogTxn(sqlCtx, nil) if err != nil { - return nil, fmt.Errorf("failed to get transaction: %w", err) + return fmt.Errorf("failed to get transaction: %w", err) } defer tx.Rollback() defer adapter.CloseTxn(sqlCtx) - err = replicator.WriteWALPosition(sqlCtx, subscriptionConfig.PublicationName, lsn) + err = logrepl.CreateSubscription(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName, lsn.String(), true) if err != nil { - return nil, fmt.Errorf("failed to write WAL position: %w", err) + return fmt.Errorf("failed to write subscription: %w", err) + } + + return commitAndUpdate(sqlCtx) +} + +func commitAndUpdate(sqlCtx *sql.Context) error { + tx := adapter.TryGetTxn(sqlCtx) + if tx != nil { + if err := tx.Commit(); err != nil { + return fmt.Errorf("failed to commit transaction: %w", err) + } + adapter.CloseTxn(sqlCtx) } - err = logrepl.WriteSubscription(sqlCtx, subscriptionConfig.SubscriptionName, subscriptionConfig.ToDNS(), subscriptionConfig.PublicationName) + err := logrepl.UpdateSubscriptions(sqlCtx) if err != nil { - return nil, fmt.Errorf("failed to write subscription: %w", err) + return fmt.Errorf("failed to update subscriptions: %w", err) } - return replicator, tx.Commit() + return nil }