Skip to content

Commit

Permalink
feat: support for ALTER/DROP SUSCRIPTION (#253)
Browse files Browse the repository at this point in the history
  • Loading branch information
TianyuZhang1214 authored Dec 7, 2024
1 parent 19b1af5 commit 3c612a8
Show file tree
Hide file tree
Showing 8 changed files with 405 additions and 155 deletions.
72 changes: 60 additions & 12 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ func (it *InternalTable) QualifiedName() string {
return it.Schema + "." + it.Name
}

func (it *InternalTable) UpdateStmt(keyColumns []string, valueColumns []string) string {
var b strings.Builder
b.Grow(128)
b.WriteString("UPDATE ")
b.WriteString(it.QualifiedName())
b.WriteString(" SET " + valueColumns[0] + " = ?")

for _, valueColumn := range valueColumns[1:] {
b.WriteString(", ")
b.WriteString(valueColumn)
b.WriteString(" = ?")
}

b.WriteString(" WHERE " + keyColumns[0] + " = ?")
for _, keyColumn := range keyColumns[1:] {
b.WriteString(", ")
b.WriteString(keyColumn)
b.WriteString(" = ?")
}

return b.String()
}

func (it *InternalTable) UpsertStmt() string {
var b strings.Builder
b.Grow(128)
Expand Down Expand Up @@ -50,6 +73,16 @@ func (it *InternalTable) DeleteStmt() string {
return b.String()
}

func (it *InternalTable) DeleteAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
return b.String()
}

func (it *InternalTable) SelectStmt() string {
var b strings.Builder
b.Grow(128)
Expand All @@ -74,6 +107,30 @@ func (it *InternalTable) SelectStmt() string {
return b.String()
}

func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string {
var b strings.Builder
b.Grow(128)
b.WriteString("SELECT ")
b.WriteString(valueColumns[0])
for _, c := range valueColumns[1:] {
b.WriteString(", ")
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
for _, c := range it.KeyColumns[1:] {
b.WriteString(" AND ")
b.WriteString(c)
b.WriteString(" = ?")
}
return b.String()
}

func (it *InternalTable) SelectAllStmt() string {
var b strings.Builder
b.Grow(128)
Expand All @@ -98,7 +155,6 @@ func (it *InternalTable) CountAllStmt() string {
var InternalTables = struct {
PersistentVariable InternalTable
BinlogPosition InternalTable
PgReplicationLSN InternalTable
PgSubscription InternalTable
GlobalStatus InternalTable
// TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'.
Expand All @@ -120,19 +176,12 @@ var InternalTables = struct {
ValueColumns: []string{"position"},
DDL: "channel TEXT PRIMARY KEY, position TEXT",
},
PgReplicationLSN: InternalTable{
Schema: "__sys__",
Name: "pg_replication_lsn",
KeyColumns: []string{"slot_name"},
ValueColumns: []string{"lsn"},
DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT",
},
PgSubscription: InternalTable{
Schema: "__sys__",
Name: "pg_subscription",
KeyColumns: []string{"name"},
ValueColumns: []string{"connection", "publication"},
DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT",
KeyColumns: []string{"subname"},
ValueColumns: []string{"subconninfo", "subpublication", "subskiplsn", "subenabled"},
DDL: "subname TEXT PRIMARY KEY, subconninfo TEXT, subpublication TEXT, subskiplsn TEXT, subenabled BOOLEAN",
},
GlobalStatus: InternalTable{
Schema: "performance_schema",
Expand Down Expand Up @@ -227,7 +276,6 @@ var InternalTables = struct {
var internalTables = []InternalTable{
InternalTables.PersistentVariable,
InternalTables.BinlogPosition,
InternalTables.PgReplicationLSN,
InternalTables.PgSubscription,
InternalTables.GlobalStatus,
InternalTables.PGStatReplication,
Expand Down
10 changes: 2 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,9 @@ func main() {
}

// Check if there is a replication subscription and start replication if there is.
_, conn, pub, ok, err := logrepl.FindReplication(pool.DB)
err = logrepl.UpdateSubscriptions(pgServer.NewInternalCtx())
if err != nil {
logrus.WithError(err).Warnln("Failed to find replication")
} else if ok {
replicator, err := logrepl.NewLogicalReplicator(conn)
if err != nil {
logrus.WithError(err).Fatalln("Failed to create logical replicator")
}
go replicator.StartReplication(pgServer.NewInternalCtx(), pub)
logrus.WithError(err).Warnln("Failed to update subscriptions")
}

// Load the configuration for the Postgres server.
Expand Down
2 changes: 1 addition & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {
}

if query.SubscriptionConfig != nil {
return h.executeCreateSubscriptionSQL(query.SubscriptionConfig)
return h.executeSubscriptionSQL(query.SubscriptionConfig)
}

callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false)
Expand Down
29 changes: 5 additions & 24 deletions pgserver/logrepl/replication.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/binlog"
"github.com/apecloud/myduckserver/catalog"
"github.com/apecloud/myduckserver/delta"
"github.com/apecloud/myduckserver/pgtypes"
"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -46,6 +45,7 @@ type rcvMsg struct {
}

type LogicalReplicator struct {
subscription string
primaryDns string
flushInterval time.Duration

Expand All @@ -60,8 +60,9 @@ type LogicalReplicator struct {
// NewLogicalReplicator creates a new logical replicator instance which connects to the primary and replication
// databases using the connection strings provided. The connection to the replica is established immediately, and the
// connection to the primary is established when StartReplication is called.
func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) {
func NewLogicalReplicator(subscription, primaryDns string) (*LogicalReplicator, error) {
return &LogicalReplicator{
subscription: subscription,
primaryDns: primaryDns,
flushInterval: 200 * time.Millisecond,
mu: &sync.Mutex{},
Expand Down Expand Up @@ -222,7 +223,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
standbyMessageTimeout := 10 * time.Second
nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

lastWrittenLsn, err := r.readWALPosition(sqlCtx, slotName)
lastWrittenLsn, err := SelectSubscriptionLsn(sqlCtx, r.subscription)
if err != nil {
return err
}
Expand Down Expand Up @@ -881,26 +882,6 @@ func (r *LogicalReplicator) processMessage(
return false, nil
}

// readWALPosition reads the recorded WAL position from the WAL position table
func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) (pglogrepl.LSN, error) {
var lsn string
if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgReplicationLSN.SelectStmt(), slotName).Scan(&lsn); err != nil {
if errors.Is(err, stdsql.ErrNoRows) {
// if the LSN doesn't exist, consider this a cold start and return 0
return pglogrepl.LSN(0), nil
}
return 0, err
}

return pglogrepl.ParseLSN(lsn)
}

// WriteWALPosition writes the recorded WAL position to the WAL position table
func (r *LogicalReplicator) WriteWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error {
_, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String())
return err
}

// whereClause returns a WHERE clause string with the contents of the builder if it's non-empty, or the empty
// string otherwise
func whereClause(str strings.Builder) string {
Expand Down Expand Up @@ -1000,7 +981,7 @@ func (r *LogicalReplicator) commitOngoingTxn(state *replicationState, flushReaso
}

r.logger.Debugf("Writing LSN %s\n", state.lastCommitLSN)
if err = r.WriteWALPosition(state.replicaCtx, state.slotName, state.lastCommitLSN); err != nil {
if err = UpdateSubscriptionLsn(state.replicaCtx, state.lastCommitLSN.String(), r.subscription); err != nil {
return err
}

Expand Down
19 changes: 16 additions & 3 deletions pgserver/logrepl/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package logrepl_test
import (
"context"
"fmt"
"github.com/apecloud/myduckserver/adapter"
"github.com/jackc/pglogrepl"
"log"
"os"
"os/exec"
Expand Down Expand Up @@ -660,6 +662,7 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) {
}

const slotName = "myduck_slot"
const subscriptionName = "my_sub_test"

// RunReplicationScript runs the given ReplicationTest.
func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) {
Expand All @@ -686,8 +689,18 @@ func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) {
})
}

func newReplicator(t *testing.T, primaryDns string) *logrepl.LogicalReplicator {
r, err := logrepl.NewLogicalReplicator(primaryDns)
func newReplicator(sqlCtx *sql.Context, t *testing.T, primaryDns string) *logrepl.LogicalReplicator {
err := logrepl.CreateSubscription(sqlCtx, subscriptionName, primaryDns, slotName, pglogrepl.LSN(0).String(), true)
require.NoError(t, err)

tx := adapter.TryGetTxn(sqlCtx)
if tx != nil {
err := tx.Commit()
require.NoError(t, err)
adapter.CloseTxn(sqlCtx)
}

r, err := logrepl.NewLogicalReplicator(subscriptionName, primaryDns)
require.NoError(t, err)
return r
}
Expand All @@ -701,7 +714,7 @@ func runReplicationScript(
replicaConn *pgx.Conn,
primaryDns string,
) {
r := newReplicator(t, primaryDns)
r := newReplicator(server.NewInternalCtx(), t, primaryDns)
defer r.Stop()

if script.Skip {
Expand Down
Loading

0 comments on commit 3c612a8

Please sign in to comment.