Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions config.example.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ postgresql:
test_only: false # 仅测试连接,不执行转换
max_conns: 50 # 连接池配置的最大连接数,从20提升到50
pg_connection_params: search_path=public connect_timeout=300 statement_timeout=0 # PostgreSQL连接参数
password_encryption: auto # 密码加密方式:md5, scram-sha-256, auto(默认)

# 转换配置
conversion:
Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,7 @@ type PostgreSQLConfig struct {
TestOnly bool `mapstructure:"test_only"`
MaxConns int `mapstructure:"max_conns"` // 最大连接数
PgConnectionParams string `mapstructure:"pg_connection_params"` // PostgreSQL连接参数
PasswordEncryption string `mapstructure:"password_encryption"` // 密码加密方式:md5, scram-sha-256, auto
}

// ConversionConfig 转换配置
Expand Down
97 changes: 56 additions & 41 deletions internal/postgres/connection.go
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,7 @@ func (p *PostgreSQLVersionInfo) IsVersionGreaterOrEqual(major, minor int) bool {
return false
}

// Connection PostgreSQL连接管理器
// Connection PostgreSQL 连接管理器
type Connection struct {
pool *pgxpool.Pool
config *config.PostgreSQLConfig
Expand Down Expand Up @@ -199,22 +199,37 @@ func getTypedValue(dest *typedDest) interface{} {
}
}

// NewConnection 创建新的PostgreSQL连接
// NewConnection 创建新的 PostgreSQL 连接
func NewConnection(config *config.PostgreSQLConfig) (*Connection, error) {
ctx := context.Background()

// 使用无压缩连接
// 构建基础连接字符串
connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable",
config.Host, config.Port, config.Username, config.Password, config.Database)

// 根据密码加密方式添加参数
// PostgreSQL 支持两种密码加密方式:
// - md5: 传统加密方式,兼容性最好
// - scram-sha-256: 更安全的加密方式(PostgreSQL 10+ 默认)
// - auto: 自动选择(默认,由 PostgreSQL 服务器决定)
if config.PasswordEncryption != "" {
switch strings.ToLower(config.PasswordEncryption) {
case "md5":
connStr += " password_encryption=md5"
case "scram-sha-256":
connStr += " password_encryption=scram-sha-256"
// "auto" 或不设置时使用 PostgreSQL 默认行为
}
}

// 添加连接参数
if config.PgConnectionParams != "" {
connStr += " " + config.PgConnectionParams
}

poolConfig, err := pgxpool.ParseConfig(connStr)
if err != nil {
return nil, fmt.Errorf("解析PostgreSQL连接配置失败: %w", err)
return nil, fmt.Errorf("解析 PostgreSQL 连接配置失败:%w", err)
}

// 设置连接池大小
Expand All @@ -223,12 +238,12 @@ func NewConnection(config *config.PostgreSQLConfig) (*Connection, error) {
// 创建连接池
pool, err := pgxpool.NewWithConfig(ctx, poolConfig)
if err != nil {
return nil, fmt.Errorf("创建PostgreSQL连接池失败: %w", err)
return nil, fmt.Errorf("创建 PostgreSQL 连接池失败:%w", err)
}

// 测试连接
if err := pool.Ping(ctx); err != nil {
return nil, fmt.Errorf("PostgreSQL连接测试失败: %w", err)
return nil, fmt.Errorf("PostgreSQL 连接测试失败:%w", err)
}

return &Connection{
Expand All @@ -253,18 +268,18 @@ func (c *Connection) BeginTransaction(ctx context.Context) (pgx.Tx, error) {
return c.pool.Begin(ctx)
}

// ExecuteDDL 执行DDL语句
// ExecuteDDL 执行 DDL 语句
func (c *Connection) ExecuteDDL(ddl string) error {
ctx := context.Background()
execDDL := sanitizeDDLForExecution(ddl)
_, err := c.pool.Exec(ctx, execDDL)
if err != nil {
return fmt.Errorf("执行DDL失败: %w, PostgreSQL SQL: %s", err, execDDL)
return fmt.Errorf("执行 DDL 失败:%w, PostgreSQL SQL: %s", err, execDDL)
}
return err
}

// ExecuteDDLWithTransaction 在事务中执行DDL语句
// ExecuteDDLWithTransaction 在事务中执行 DDL 语句
func (c *Connection) ExecuteDDLWithTransaction(tx pgx.Tx, ddl string) error {
execDDL := sanitizeDDLForExecution(ddl)
_, err := tx.Exec(context.Background(), execDDL)
Expand Down Expand Up @@ -308,13 +323,13 @@ func (c *Connection) InsertData(tableName string, columns []string, rows *sql.Ro

// 扫描行数据
if err := rows.Scan(valuePtrs...); err != nil {
return fmt.Errorf("扫描行数据失败: %w", err)
return fmt.Errorf("扫描行数据失败%w", err)
}

// 执行插入
_, err := c.pool.Exec(ctx, query, values...)
if err != nil {
return fmt.Errorf("执行插入失败: %w", err)
return fmt.Errorf("执行插入失败%w", err)
}
}

Expand Down Expand Up @@ -354,13 +369,13 @@ func (c *Connection) InsertDataWithTransaction(tx pgx.Tx, tableName string, colu

// 扫描行数据
if err := rows.Scan(valuePtrs...); err != nil {
return fmt.Errorf("扫描行数据失败: %w", err)
return fmt.Errorf("扫描行数据失败%w", err)
}

// 执行插入
_, err := tx.Exec(ctx, query, values...)
if err != nil {
return fmt.Errorf("执行插入失败: %w", err)
return fmt.Errorf("执行插入失败%w", err)
}
}

Expand All @@ -382,7 +397,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string,
var batchValues []interface{}
var rowCount int

// 严格使用传入的batchSize参数,不使用硬编码默认值
// 严格使用传入的 batchSize 参数,不使用硬编码默认值
effectiveBatchSize := batchSize
if effectiveBatchSize <= 0 {
effectiveBatchSize = 10000 // 确保至少有一个合理的默认值
Expand All @@ -391,7 +406,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string,
// 预分配切片容量,减少内存分配
batchValues = make([]interface{}, 0, effectiveBatchSize*len(columns))

// 重用values和valuePtrs切片,减少内存分配
// 重用 values 和 valuePtrs 切片,减少内存分配
values := make([]interface{}, len(columns))
valuePtrs := make([]interface{}, len(columns))
for i := range values {
Expand All @@ -402,7 +417,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string,
for rows.Next() {
// 扫描行数据
if err := rows.Scan(valuePtrs...); err != nil {
return fmt.Errorf("扫描行数据失败: %w", err)
return fmt.Errorf("扫描行数据失败%w", err)
}

// 添加到批量值中
Expand Down Expand Up @@ -435,9 +450,9 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string,

// executeBatchInsert 执行批量插入操作
func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableName, columnsStr string, columns []string, values []interface{}) error {
// 计算批次大小,确保总参数数量不超过PostgreSQL的限制(65535)
// 计算批次大小,确保总参数数量不超过 PostgreSQL 的限制 (65535)
columnCount := len(columns)
// 计算每个批次的最大行数,确保总参数数量不超过65535
// 计算每个批次的最大行数,确保总参数数量不超过 65535
maxRowsPerBatch := 65535 / columnCount
if maxRowsPerBatch == 0 {
maxRowsPerBatch = 1 // 确保至少有一行
Expand All @@ -460,7 +475,7 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam
// 获取当前批次的值
batchValues := values[startIdx:endIdx]

// 构建VALUES部分
// 构建 VALUES 部分
var valuesParts strings.Builder
// 预分配更大的内存
valuesParts.Grow((end - i) * (columnCount*4 + 5)) // 增加预分配空间
Expand All @@ -481,7 +496,7 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam
valuesParts.WriteString(")")
}

// 构建完整的SQL语句
// 构建完整的 SQL 语句
query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", tableName, columnsStr, valuesParts.String())

// 执行批量插入
Expand All @@ -496,20 +511,20 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam
for j := 0; j < sampleSize; j++ {
samples = append(samples, fmt.Sprintf("%v", batchValues[j]))
}
return fmt.Errorf("批量插入失败: %w, 数据样本: %v", err, samples)
return fmt.Errorf("批量插入失败%w, 数据样本%v", err, samples)
}
}

return nil
}

// GetVersion 获取PostgreSQL版本信息
// GetVersion 获取 PostgreSQL 版本信息
func (c *Connection) GetVersion() (string, error) {
ctx := context.Background()
var version string
err := c.pool.QueryRow(ctx, "SELECT version()").Scan(&version)
if err != nil {
return "", fmt.Errorf("获取PostgreSQL版本失败: %w", err)
return "", fmt.Errorf("获取 PostgreSQL 版本失败:%w", err)
}
return version, nil
}
Expand Down Expand Up @@ -552,12 +567,12 @@ func ParsePostgreSQLVersion(version string) *PostgreSQLVersionInfo {
return info
}

// TestConnection 测试PostgreSQL连接
// TestConnection 测试 PostgreSQL 连接
func TestConnection(config *config.PostgreSQLConfig) error {
// 测试连接时不使用压缩
conn, err := NewConnection(config)
if err != nil {
return fmt.Errorf("PostgreSQL连接测试失败: %w", err)
return fmt.Errorf("PostgreSQL 连接测试失败:%w", err)
}
defer conn.Close()

Expand All @@ -578,7 +593,7 @@ func (c *Connection) TableExists(tableName string) (bool, error) {
var exists bool
err := c.pool.QueryRow(ctx, query, tableName).Scan(&exists)
if err != nil {
return false, fmt.Errorf("检查表是否存在失败: %w", err)
return false, fmt.Errorf("检查表是否存在失败%w", err)
}
return exists, nil
}
Expand All @@ -595,7 +610,7 @@ func (c *Connection) GrantTablePrivileges(user, tableName string, privileges []s

_, err := c.pool.Exec(ctx, query)
if err != nil {
return fmt.Errorf("授予表权限失败: %w", err)
return fmt.Errorf("授予表权限失败%w", err)
}

return nil
Expand All @@ -619,15 +634,15 @@ func (c *Connection) GetTablePrivileges(tableName string) ([]map[string]string,

rows, err := c.pool.Query(ctx, query, tableName)
if err != nil {
return nil, fmt.Errorf("获取表权限失败: %w", err)
return nil, fmt.Errorf("获取表权限失败%w", err)
}
defer rows.Close()

var privileges []map[string]string
for rows.Next() {
var user, privilege, isGrantable string
if err := rows.Scan(&user, &privilege, &isGrantable); err != nil {
return nil, fmt.Errorf("扫描表权限信息失败: %w", err)
return nil, fmt.Errorf("扫描表权限信息失败%w", err)
}

privileges = append(privileges, map[string]string{
Expand All @@ -648,7 +663,7 @@ func (c *Connection) GetTableRowCount(tableName string) (int64, error) {
var count int64
err := c.pool.QueryRow(ctx, query).Scan(&count)
if err != nil {
return 0, fmt.Errorf("获取表 %s 行数失败: %w", tableName, err)
return 0, fmt.Errorf("获取表 %s 行数失败%w", tableName, err)
}

return count, nil
Expand Down Expand Up @@ -753,7 +768,7 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta
var rowCount int
var totalRows int

// 严格使用传入的batchSize参数,不使用硬编码默认值
// 严格使用传入的 batchSize 参数,不使用硬编码默认值
effectiveBatchSize := batchSize
if effectiveBatchSize <= 0 {
effectiveBatchSize = 10000 // 确保至少有一个合理的默认值
Expand Down Expand Up @@ -787,7 +802,7 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta
for rows.Next() {
// 扫描行数据 — 使用类型化指针,避免每次 Scan 分配堆对象
if err := rows.Scan(scanPtrs...); err != nil {
return 0, nil, fmt.Errorf("扫描行数据失败: %w", err)
return 0, nil, fmt.Errorf("扫描行数据失败%w", err)
}

// 跟踪最后一个主键值
Expand All @@ -806,12 +821,12 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta
rowCount++
totalRows++

// 当达到批量大小时执行CopyFrom
// 当达到批量大小时执行 CopyFrom
if rowCount == effectiveBatchSize {
// 执行CopyFrom,使用转换后的小写列名
// 执行 CopyFrom,使用转换后的小写列名
_, err := tx.CopyFrom(ctx, pgx.Identifier{tableName}, copyColumns, pgx.CopyFromRows(copyRows))
if err != nil {
return 0, nil, fmt.Errorf("CopyFrom执行失败: %w", err)
return 0, nil, fmt.Errorf("CopyFrom 执行失败:%w", err)
}

// 将 rowValues 切片返回 pool 复用
Expand All @@ -827,10 +842,10 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta

// 执行剩余的数据
if len(copyRows) > 0 {
// 执行CopyFrom,使用转换后的小写列名
// 执行 CopyFrom,使用转换后的小写列名
_, err := tx.CopyFrom(ctx, pgx.Identifier{tableName}, copyColumns, pgx.CopyFromRows(copyRows))
if err != nil {
return 0, nil, fmt.Errorf("CopyFrom执行失败: %w", err)
return 0, nil, fmt.Errorf("CopyFrom 执行失败:%w", err)
}
// 将 rowValues 切片返回 pool 复用
for _, rv := range copyRows {
Expand All @@ -842,19 +857,19 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta
return 0, nil, err
}

// 只有在没有找到主键值的情况下,才执行MAX查询(作为后备方案)
// 只有在没有找到主键值的情况下,才执行 MAX 查询(作为后备方案)
if resolvedPrimaryKey != "" && lastValue == nil {
query := fmt.Sprintf("SELECT MAX(\"%s\") FROM \"%s\"", resolvedPrimaryKey, tableName)
err := tx.QueryRow(ctx, query).Scan(&lastValue)
if err != nil && err != pgx.ErrNoRows {
return 0, nil, fmt.Errorf("获取最后一个主键值失败: %w", err)
return 0, nil, fmt.Errorf("获取最后一个主键值失败%w", err)
}
}

return totalRows, lastValue, nil
}

// parseMySQLPoint 解析MySQL的WKB格式Point数据
// parseMySQLPoint 解析 MySQL 的 WKB 格式 Point 数据
func parseMySQLPoint(data []byte) (string, error) {
// MySQL Geometry Header (4 bytes SRID) + WKB (1 byte order + 4 bytes type + 16 bytes coords)
// SRID (4) + Order (1) + Type (4) + X (8) + Y (8) = 25 bytes
Expand Down Expand Up @@ -890,7 +905,7 @@ func parseMySQLPoint(data []byte) (string, error) {
y = math.Float64frombits(yBits)
}

// 格式化为PostgreSQL Point格式 (x,y)
// 格式化为 PostgreSQL Point 格式 (x,y)
return fmt.Sprintf("(%v,%v)", x, y), nil
}

Expand Down
Loading