Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: support for ALTER/DROP SUSCRIPTION #253

Merged
merged 14 commits into from
Dec 7, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
72 changes: 60 additions & 12 deletions catalog/internal_tables.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,29 @@ func (it *InternalTable) QualifiedName() string {
return it.Schema + "." + it.Name
}

func (it *InternalTable) UpdateStmt(keyColumns []string, valueColumns []string) string {
var b strings.Builder
b.Grow(128)
b.WriteString("UPDATE ")
b.WriteString(it.QualifiedName())
b.WriteString(" SET " + valueColumns[0] + " = ?")

for _, valueColumn := range valueColumns[1:] {
b.WriteString(", ")
b.WriteString(valueColumn)
b.WriteString(" = ?")
}

b.WriteString(" WHERE " + keyColumns[0] + " = ?")
for _, keyColumn := range keyColumns[1:] {
b.WriteString(", ")
b.WriteString(keyColumn)
b.WriteString(" = ?")
}

return b.String()
}

func (it *InternalTable) UpsertStmt() string {
var b strings.Builder
b.Grow(128)
Expand Down Expand Up @@ -50,6 +73,16 @@ func (it *InternalTable) DeleteStmt() string {
return b.String()
}

func (it *InternalTable) DeleteAllStmt() string {
var b strings.Builder
b.Grow(128)
b.WriteString("DELETE FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
return b.String()
}

func (it *InternalTable) SelectStmt() string {
var b strings.Builder
b.Grow(128)
Expand All @@ -74,6 +107,30 @@ func (it *InternalTable) SelectStmt() string {
return b.String()
}

func (it *InternalTable) SelectColumnsStmt(valueColumns []string) string {
var b strings.Builder
b.Grow(128)
b.WriteString("SELECT ")
b.WriteString(valueColumns[0])
for _, c := range valueColumns[1:] {
b.WriteString(", ")
b.WriteString(c)
}
b.WriteString(" FROM ")
b.WriteString(it.Schema)
b.WriteByte('.')
b.WriteString(it.Name)
b.WriteString(" WHERE ")
b.WriteString(it.KeyColumns[0])
b.WriteString(" = ?")
for _, c := range it.KeyColumns[1:] {
b.WriteString(" AND ")
b.WriteString(c)
b.WriteString(" = ?")
}
return b.String()
}

func (it *InternalTable) SelectAllStmt() string {
var b strings.Builder
b.Grow(128)
Expand All @@ -98,7 +155,6 @@ func (it *InternalTable) CountAllStmt() string {
var InternalTables = struct {
PersistentVariable InternalTable
BinlogPosition InternalTable
PgReplicationLSN InternalTable
PgSubscription InternalTable
GlobalStatus InternalTable
// TODO(sean): This is a temporary work around for clients that query the 'pg_catalog.pg_stat_replication'.
Expand All @@ -120,19 +176,12 @@ var InternalTables = struct {
ValueColumns: []string{"position"},
DDL: "channel TEXT PRIMARY KEY, position TEXT",
},
PgReplicationLSN: InternalTable{
Schema: "__sys__",
Name: "pg_replication_lsn",
KeyColumns: []string{"slot_name"},
ValueColumns: []string{"lsn"},
DDL: "slot_name TEXT PRIMARY KEY, lsn TEXT",
},
PgSubscription: InternalTable{
Schema: "__sys__",
Name: "pg_subscription",
KeyColumns: []string{"name"},
ValueColumns: []string{"connection", "publication"},
DDL: "name TEXT PRIMARY KEY, connection TEXT, publication TEXT",
KeyColumns: []string{"subname"},
ValueColumns: []string{"subconninfo", "subpublication", "subskiplsn", "subenabled"},
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The sub prefix for all column names can be removed.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

sub is chosen to maintain consistency with the pg_subscription catalog in PostgreSQL.

DDL: "subname TEXT PRIMARY KEY, subconninfo TEXT, subpublication TEXT, subskiplsn TEXT, subenabled BOOLEAN",
},
GlobalStatus: InternalTable{
Schema: "performance_schema",
Expand Down Expand Up @@ -227,7 +276,6 @@ var InternalTables = struct {
var internalTables = []InternalTable{
InternalTables.PersistentVariable,
InternalTables.BinlogPosition,
InternalTables.PgReplicationLSN,
InternalTables.PgSubscription,
InternalTables.GlobalStatus,
InternalTables.PGStatReplication,
Expand Down
10 changes: 2 additions & 8 deletions main.go
Original file line number Diff line number Diff line change
Expand Up @@ -182,15 +182,9 @@ func main() {
}

// Check if there is a replication subscription and start replication if there is.
_, conn, pub, ok, err := logrepl.FindReplication(pool.DB)
err = logrepl.UpdateSubscriptions(pgServer.NewInternalCtx())
if err != nil {
logrus.WithError(err).Warnln("Failed to find replication")
} else if ok {
replicator, err := logrepl.NewLogicalReplicator(conn)
if err != nil {
logrus.WithError(err).Fatalln("Failed to create logical replicator")
}
go replicator.StartReplication(pgServer.NewInternalCtx(), pub)
logrus.WithError(err).Warnln("Failed to update subscriptions")
}

// Load the configuration for the Postgres server.
Expand Down
2 changes: 1 addition & 1 deletion pgserver/connection_handler.go
Original file line number Diff line number Diff line change
Expand Up @@ -982,7 +982,7 @@ func (h *ConnectionHandler) query(query ConvertedQuery) error {
}

if query.SubscriptionConfig != nil {
return h.executeCreateSubscriptionSQL(query.SubscriptionConfig)
return h.executeSubscriptionSQL(query.SubscriptionConfig)
}

callback := h.spoolRowsCallback(query.StatementTag, &rowsAffected, false)
Expand Down
29 changes: 5 additions & 24 deletions pgserver/logrepl/replication.go
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@ import (

"github.com/apecloud/myduckserver/adapter"
"github.com/apecloud/myduckserver/binlog"
"github.com/apecloud/myduckserver/catalog"
"github.com/apecloud/myduckserver/delta"
"github.com/apecloud/myduckserver/pgtypes"
"github.com/dolthub/go-mysql-server/sql"
Expand All @@ -46,6 +45,7 @@ type rcvMsg struct {
}

type LogicalReplicator struct {
subscription string
primaryDns string
flushInterval time.Duration

Expand All @@ -60,8 +60,9 @@ type LogicalReplicator struct {
// NewLogicalReplicator creates a new logical replicator instance which connects to the primary and replication
// databases using the connection strings provided. The connection to the replica is established immediately, and the
// connection to the primary is established when StartReplication is called.
func NewLogicalReplicator(primaryDns string) (*LogicalReplicator, error) {
func NewLogicalReplicator(subscription, primaryDns string) (*LogicalReplicator, error) {
return &LogicalReplicator{
subscription: subscription,
primaryDns: primaryDns,
flushInterval: 200 * time.Millisecond,
mu: &sync.Mutex{},
Expand Down Expand Up @@ -222,7 +223,7 @@ func (r *LogicalReplicator) StartReplication(sqlCtx *sql.Context, slotName strin
standbyMessageTimeout := 10 * time.Second
nextStandbyMessageDeadline := time.Now().Add(standbyMessageTimeout)

lastWrittenLsn, err := r.readWALPosition(sqlCtx, slotName)
lastWrittenLsn, err := SelectSubscriptionLsn(sqlCtx, r.subscription)
if err != nil {
return err
}
Expand Down Expand Up @@ -881,26 +882,6 @@ func (r *LogicalReplicator) processMessage(
return false, nil
}

// readWALPosition reads the recorded WAL position from the WAL position table
func (r *LogicalReplicator) readWALPosition(ctx *sql.Context, slotName string) (pglogrepl.LSN, error) {
var lsn string
if err := adapter.QueryRowCatalog(ctx, catalog.InternalTables.PgReplicationLSN.SelectStmt(), slotName).Scan(&lsn); err != nil {
if errors.Is(err, stdsql.ErrNoRows) {
// if the LSN doesn't exist, consider this a cold start and return 0
return pglogrepl.LSN(0), nil
}
return 0, err
}

return pglogrepl.ParseLSN(lsn)
}

// WriteWALPosition writes the recorded WAL position to the WAL position table
func (r *LogicalReplicator) WriteWALPosition(ctx *sql.Context, slotName string, lsn pglogrepl.LSN) error {
_, err := adapter.ExecCatalogInTxn(ctx, catalog.InternalTables.PgReplicationLSN.UpsertStmt(), slotName, lsn.String())
return err
}

// whereClause returns a WHERE clause string with the contents of the builder if it's non-empty, or the empty
// string otherwise
func whereClause(str strings.Builder) string {
Expand Down Expand Up @@ -1000,7 +981,7 @@ func (r *LogicalReplicator) commitOngoingTxn(state *replicationState, flushReaso
}

r.logger.Debugf("Writing LSN %s\n", state.lastCommitLSN)
if err = r.WriteWALPosition(state.replicaCtx, state.slotName, state.lastCommitLSN); err != nil {
if err = UpdateSubscriptionLsn(state.replicaCtx, state.lastCommitLSN.String(), r.subscription); err != nil {
return err
}

Expand Down
19 changes: 16 additions & 3 deletions pgserver/logrepl/replication_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ package logrepl_test
import (
"context"
"fmt"
"github.com/apecloud/myduckserver/adapter"
"github.com/jackc/pglogrepl"
"log"
"os"
"os/exec"
Expand Down Expand Up @@ -660,6 +662,7 @@ func RunReplicationScripts(t *testing.T, scripts []ReplicationTest) {
}

const slotName = "myduck_slot"
const subscriptionName = "my_sub_test"

// RunReplicationScript runs the given ReplicationTest.
func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) {
Expand All @@ -686,8 +689,18 @@ func RunReplicationScript(t *testing.T, dsn string, script ReplicationTest) {
})
}

func newReplicator(t *testing.T, primaryDns string) *logrepl.LogicalReplicator {
r, err := logrepl.NewLogicalReplicator(primaryDns)
func newReplicator(sqlCtx *sql.Context, t *testing.T, primaryDns string) *logrepl.LogicalReplicator {
err := logrepl.CreateSubscription(sqlCtx, subscriptionName, primaryDns, slotName, pglogrepl.LSN(0).String(), true)
require.NoError(t, err)

tx := adapter.TryGetTxn(sqlCtx)
if tx != nil {
err := tx.Commit()
require.NoError(t, err)
adapter.CloseTxn(sqlCtx)
}

r, err := logrepl.NewLogicalReplicator(subscriptionName, primaryDns)
require.NoError(t, err)
return r
}
Expand All @@ -701,7 +714,7 @@ func runReplicationScript(
replicaConn *pgx.Conn,
primaryDns string,
) {
r := newReplicator(t, primaryDns)
r := newReplicator(server.NewInternalCtx(), t, primaryDns)
defer r.Stop()

if script.Skip {
Expand Down
Loading
Loading