diff --git a/config.example.yml b/config.example.yml index 8e36640..ff0252a 100644 --- a/config.example.yml +++ b/config.example.yml @@ -21,6 +21,7 @@ postgresql: test_only: false # 仅测试连接,不执行转换 max_conns: 50 # 连接池配置的最大连接数,从20提升到50 pg_connection_params: search_path=public connect_timeout=300 statement_timeout=0 # PostgreSQL连接参数 + password_encryption: auto # 密码加密方式:md5, scram-sha-256, auto(默认) # 转换配置 conversion: diff --git a/internal/config/config.go b/internal/config/config.go index 029d193..7aa080f 100644 --- a/internal/config/config.go +++ b/internal/config/config.go @@ -44,6 +44,7 @@ type PostgreSQLConfig struct { TestOnly bool `mapstructure:"test_only"` MaxConns int `mapstructure:"max_conns"` // 最大连接数 PgConnectionParams string `mapstructure:"pg_connection_params"` // PostgreSQL连接参数 + PasswordEncryption string `mapstructure:"password_encryption"` // 密码加密方式:md5, scram-sha-256, auto } // ConversionConfig 转换配置 diff --git a/internal/postgres/connection.go b/internal/postgres/connection.go index d0a07c3..84dfcc7 100644 --- a/internal/postgres/connection.go +++ b/internal/postgres/connection.go @@ -39,7 +39,7 @@ func (p *PostgreSQLVersionInfo) IsVersionGreaterOrEqual(major, minor int) bool { return false } -// Connection PostgreSQL连接管理器 +// Connection PostgreSQL 连接管理器 type Connection struct { pool *pgxpool.Pool config *config.PostgreSQLConfig @@ -199,14 +199,29 @@ func getTypedValue(dest *typedDest) interface{} { } } -// NewConnection 创建新的PostgreSQL连接 +// NewConnection 创建新的 PostgreSQL 连接 func NewConnection(config *config.PostgreSQLConfig) (*Connection, error) { ctx := context.Background() - // 使用无压缩连接 + // 构建基础连接字符串 connStr := fmt.Sprintf("host=%s port=%d user=%s password=%s dbname=%s sslmode=disable", config.Host, config.Port, config.Username, config.Password, config.Database) + // 根据密码加密方式添加参数 + // PostgreSQL 支持两种密码加密方式: + // - md5: 传统加密方式,兼容性最好 + // - scram-sha-256: 更安全的加密方式(PostgreSQL 10+ 默认) + // - auto: 自动选择(默认,由 PostgreSQL 服务器决定) + if config.PasswordEncryption != "" { + switch strings.ToLower(config.PasswordEncryption) { + case "md5": + connStr += " password_encryption=md5" + case "scram-sha-256": + connStr += " password_encryption=scram-sha-256" + // "auto" 或不设置时使用 PostgreSQL 默认行为 + } + } + // 添加连接参数 if config.PgConnectionParams != "" { connStr += " " + config.PgConnectionParams @@ -214,7 +229,7 @@ func NewConnection(config *config.PostgreSQLConfig) (*Connection, error) { poolConfig, err := pgxpool.ParseConfig(connStr) if err != nil { - return nil, fmt.Errorf("解析PostgreSQL连接配置失败: %w", err) + return nil, fmt.Errorf("解析 PostgreSQL 连接配置失败:%w", err) } // 设置连接池大小 @@ -223,12 +238,12 @@ func NewConnection(config *config.PostgreSQLConfig) (*Connection, error) { // 创建连接池 pool, err := pgxpool.NewWithConfig(ctx, poolConfig) if err != nil { - return nil, fmt.Errorf("创建PostgreSQL连接池失败: %w", err) + return nil, fmt.Errorf("创建 PostgreSQL 连接池失败:%w", err) } // 测试连接 if err := pool.Ping(ctx); err != nil { - return nil, fmt.Errorf("PostgreSQL连接测试失败: %w", err) + return nil, fmt.Errorf("PostgreSQL 连接测试失败:%w", err) } return &Connection{ @@ -253,18 +268,18 @@ func (c *Connection) BeginTransaction(ctx context.Context) (pgx.Tx, error) { return c.pool.Begin(ctx) } -// ExecuteDDL 执行DDL语句 +// ExecuteDDL 执行 DDL 语句 func (c *Connection) ExecuteDDL(ddl string) error { ctx := context.Background() execDDL := sanitizeDDLForExecution(ddl) _, err := c.pool.Exec(ctx, execDDL) if err != nil { - return fmt.Errorf("执行DDL失败: %w, PostgreSQL SQL: %s", err, execDDL) + return fmt.Errorf("执行 DDL 失败:%w, PostgreSQL SQL: %s", err, execDDL) } return err } -// ExecuteDDLWithTransaction 在事务中执行DDL语句 +// ExecuteDDLWithTransaction 在事务中执行 DDL 语句 func (c *Connection) ExecuteDDLWithTransaction(tx pgx.Tx, ddl string) error { execDDL := sanitizeDDLForExecution(ddl) _, err := tx.Exec(context.Background(), execDDL) @@ -308,13 +323,13 @@ func (c *Connection) InsertData(tableName string, columns []string, rows *sql.Ro // 扫描行数据 if err := rows.Scan(valuePtrs...); err != nil { - return fmt.Errorf("扫描行数据失败: %w", err) + return fmt.Errorf("扫描行数据失败:%w", err) } // 执行插入 _, err := c.pool.Exec(ctx, query, values...) if err != nil { - return fmt.Errorf("执行插入失败: %w", err) + return fmt.Errorf("执行插入失败:%w", err) } } @@ -354,13 +369,13 @@ func (c *Connection) InsertDataWithTransaction(tx pgx.Tx, tableName string, colu // 扫描行数据 if err := rows.Scan(valuePtrs...); err != nil { - return fmt.Errorf("扫描行数据失败: %w", err) + return fmt.Errorf("扫描行数据失败:%w", err) } // 执行插入 _, err := tx.Exec(ctx, query, values...) if err != nil { - return fmt.Errorf("执行插入失败: %w", err) + return fmt.Errorf("执行插入失败:%w", err) } } @@ -382,7 +397,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string, var batchValues []interface{} var rowCount int - // 严格使用传入的batchSize参数,不使用硬编码默认值 + // 严格使用传入的 batchSize 参数,不使用硬编码默认值 effectiveBatchSize := batchSize if effectiveBatchSize <= 0 { effectiveBatchSize = 10000 // 确保至少有一个合理的默认值 @@ -391,7 +406,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string, // 预分配切片容量,减少内存分配 batchValues = make([]interface{}, 0, effectiveBatchSize*len(columns)) - // 重用values和valuePtrs切片,减少内存分配 + // 重用 values 和 valuePtrs 切片,减少内存分配 values := make([]interface{}, len(columns)) valuePtrs := make([]interface{}, len(columns)) for i := range values { @@ -402,7 +417,7 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string, for rows.Next() { // 扫描行数据 if err := rows.Scan(valuePtrs...); err != nil { - return fmt.Errorf("扫描行数据失败: %w", err) + return fmt.Errorf("扫描行数据失败:%w", err) } // 添加到批量值中 @@ -435,9 +450,9 @@ func (c *Connection) BatchInsertDataWithTransaction(tx pgx.Tx, tableName string, // executeBatchInsert 执行批量插入操作 func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableName, columnsStr string, columns []string, values []interface{}) error { - // 计算批次大小,确保总参数数量不超过PostgreSQL的限制(65535) + // 计算批次大小,确保总参数数量不超过 PostgreSQL 的限制 (65535) columnCount := len(columns) - // 计算每个批次的最大行数,确保总参数数量不超过65535 + // 计算每个批次的最大行数,确保总参数数量不超过 65535 maxRowsPerBatch := 65535 / columnCount if maxRowsPerBatch == 0 { maxRowsPerBatch = 1 // 确保至少有一行 @@ -460,7 +475,7 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam // 获取当前批次的值 batchValues := values[startIdx:endIdx] - // 构建VALUES部分 + // 构建 VALUES 部分 var valuesParts strings.Builder // 预分配更大的内存 valuesParts.Grow((end - i) * (columnCount*4 + 5)) // 增加预分配空间 @@ -481,7 +496,7 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam valuesParts.WriteString(")") } - // 构建完整的SQL语句 + // 构建完整的 SQL 语句 query := fmt.Sprintf("INSERT INTO \"%s\" (%s) VALUES %s", tableName, columnsStr, valuesParts.String()) // 执行批量插入 @@ -496,20 +511,20 @@ func (c *Connection) executeBatchInsert(tx pgx.Tx, ctx context.Context, tableNam for j := 0; j < sampleSize; j++ { samples = append(samples, fmt.Sprintf("%v", batchValues[j])) } - return fmt.Errorf("批量插入失败: %w, 数据样本: %v", err, samples) + return fmt.Errorf("批量插入失败:%w, 数据样本:%v", err, samples) } } return nil } -// GetVersion 获取PostgreSQL版本信息 +// GetVersion 获取 PostgreSQL 版本信息 func (c *Connection) GetVersion() (string, error) { ctx := context.Background() var version string err := c.pool.QueryRow(ctx, "SELECT version()").Scan(&version) if err != nil { - return "", fmt.Errorf("获取PostgreSQL版本失败: %w", err) + return "", fmt.Errorf("获取 PostgreSQL 版本失败:%w", err) } return version, nil } @@ -552,12 +567,12 @@ func ParsePostgreSQLVersion(version string) *PostgreSQLVersionInfo { return info } -// TestConnection 测试PostgreSQL连接 +// TestConnection 测试 PostgreSQL 连接 func TestConnection(config *config.PostgreSQLConfig) error { // 测试连接时不使用压缩 conn, err := NewConnection(config) if err != nil { - return fmt.Errorf("PostgreSQL连接测试失败: %w", err) + return fmt.Errorf("PostgreSQL 连接测试失败:%w", err) } defer conn.Close() @@ -578,7 +593,7 @@ func (c *Connection) TableExists(tableName string) (bool, error) { var exists bool err := c.pool.QueryRow(ctx, query, tableName).Scan(&exists) if err != nil { - return false, fmt.Errorf("检查表是否存在失败: %w", err) + return false, fmt.Errorf("检查表是否存在失败:%w", err) } return exists, nil } @@ -595,7 +610,7 @@ func (c *Connection) GrantTablePrivileges(user, tableName string, privileges []s _, err := c.pool.Exec(ctx, query) if err != nil { - return fmt.Errorf("授予表权限失败: %w", err) + return fmt.Errorf("授予表权限失败:%w", err) } return nil @@ -619,7 +634,7 @@ func (c *Connection) GetTablePrivileges(tableName string) ([]map[string]string, rows, err := c.pool.Query(ctx, query, tableName) if err != nil { - return nil, fmt.Errorf("获取表权限失败: %w", err) + return nil, fmt.Errorf("获取表权限失败:%w", err) } defer rows.Close() @@ -627,7 +642,7 @@ func (c *Connection) GetTablePrivileges(tableName string) ([]map[string]string, for rows.Next() { var user, privilege, isGrantable string if err := rows.Scan(&user, &privilege, &isGrantable); err != nil { - return nil, fmt.Errorf("扫描表权限信息失败: %w", err) + return nil, fmt.Errorf("扫描表权限信息失败:%w", err) } privileges = append(privileges, map[string]string{ @@ -648,7 +663,7 @@ func (c *Connection) GetTableRowCount(tableName string) (int64, error) { var count int64 err := c.pool.QueryRow(ctx, query).Scan(&count) if err != nil { - return 0, fmt.Errorf("获取表 %s 行数失败: %w", tableName, err) + return 0, fmt.Errorf("获取表 %s 行数失败:%w", tableName, err) } return count, nil @@ -753,7 +768,7 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta var rowCount int var totalRows int - // 严格使用传入的batchSize参数,不使用硬编码默认值 + // 严格使用传入的 batchSize 参数,不使用硬编码默认值 effectiveBatchSize := batchSize if effectiveBatchSize <= 0 { effectiveBatchSize = 10000 // 确保至少有一个合理的默认值 @@ -787,7 +802,7 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta for rows.Next() { // 扫描行数据 — 使用类型化指针,避免每次 Scan 分配堆对象 if err := rows.Scan(scanPtrs...); err != nil { - return 0, nil, fmt.Errorf("扫描行数据失败: %w", err) + return 0, nil, fmt.Errorf("扫描行数据失败:%w", err) } // 跟踪最后一个主键值 @@ -806,12 +821,12 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta rowCount++ totalRows++ - // 当达到批量大小时执行CopyFrom + // 当达到批量大小时执行 CopyFrom if rowCount == effectiveBatchSize { - // 执行CopyFrom,使用转换后的小写列名 + // 执行 CopyFrom,使用转换后的小写列名 _, err := tx.CopyFrom(ctx, pgx.Identifier{tableName}, copyColumns, pgx.CopyFromRows(copyRows)) if err != nil { - return 0, nil, fmt.Errorf("CopyFrom执行失败: %w", err) + return 0, nil, fmt.Errorf("CopyFrom 执行失败:%w", err) } // 将 rowValues 切片返回 pool 复用 @@ -827,10 +842,10 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta // 执行剩余的数据 if len(copyRows) > 0 { - // 执行CopyFrom,使用转换后的小写列名 + // 执行 CopyFrom,使用转换后的小写列名 _, err := tx.CopyFrom(ctx, pgx.Identifier{tableName}, copyColumns, pgx.CopyFromRows(copyRows)) if err != nil { - return 0, nil, fmt.Errorf("CopyFrom执行失败: %w", err) + return 0, nil, fmt.Errorf("CopyFrom 执行失败:%w", err) } // 将 rowValues 切片返回 pool 复用 for _, rv := range copyRows { @@ -842,19 +857,19 @@ func (c *Connection) BatchInsertDataWithTransactionAndGetLastValue(tx pgx.Tx, ta return 0, nil, err } - // 只有在没有找到主键值的情况下,才执行MAX查询(作为后备方案) + // 只有在没有找到主键值的情况下,才执行 MAX 查询(作为后备方案) if resolvedPrimaryKey != "" && lastValue == nil { query := fmt.Sprintf("SELECT MAX(\"%s\") FROM \"%s\"", resolvedPrimaryKey, tableName) err := tx.QueryRow(ctx, query).Scan(&lastValue) if err != nil && err != pgx.ErrNoRows { - return 0, nil, fmt.Errorf("获取最后一个主键值失败: %w", err) + return 0, nil, fmt.Errorf("获取最后一个主键值失败:%w", err) } } return totalRows, lastValue, nil } -// parseMySQLPoint 解析MySQL的WKB格式Point数据 +// parseMySQLPoint 解析 MySQL 的 WKB 格式 Point 数据 func parseMySQLPoint(data []byte) (string, error) { // MySQL Geometry Header (4 bytes SRID) + WKB (1 byte order + 4 bytes type + 16 bytes coords) // SRID (4) + Order (1) + Type (4) + X (8) + Y (8) = 25 bytes @@ -890,7 +905,7 @@ func parseMySQLPoint(data []byte) (string, error) { y = math.Float64frombits(yBits) } - // 格式化为PostgreSQL Point格式 (x,y) + // 格式化为 PostgreSQL Point 格式 (x,y) return fmt.Sprintf("(%v,%v)", x, y), nil }