Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 28 additions & 5 deletions internal/delivery/mcp/timescale_tool.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,36 @@ import (
"context"
"encoding/json"
"fmt"
"regexp"
"strconv"
"strings"

"github.com/FreePeak/cortex/pkg/server"
cortextools "github.com/FreePeak/cortex/pkg/tools"
)

// validIdentifier matches valid SQL identifiers
var validIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// sanitizeIdentifier validates and escapes a SQL identifier
func sanitizeIdentifier(identifier string) string {
if identifier == "" {
return ""
}
if strings.Contains(identifier, "\x00") {
return ""
}
if validIdentifier.MatchString(identifier) {
return identifier
}
return "\"" + strings.ReplaceAll(identifier, "\"", "\"\"") + "\""
}

// sanitizeStringLiteral escapes a string literal by replacing ' with ”
func sanitizeStringLiteral(s string) string {
return strings.ReplaceAll(s, "'", "''")
}

// TimescaleDBTool implements a tool for TimescaleDB operations
type TimescaleDBTool struct {
name string
Expand Down Expand Up @@ -1743,19 +1766,19 @@ func getBoolParam(params map[string]interface{}, key string) bool {
func buildCreateHypertableSQL(table, timeColumn, chunkTimeInterval, partitioningColumn string, ifNotExists bool) string {
var args []string

// Add required arguments: table name and time column
args = append(args, fmt.Sprintf("'%s'", table))
args = append(args, fmt.Sprintf("'%s'", timeColumn))
// Add required arguments: table name and time column (identifiers, not string literals)
args = append(args, sanitizeIdentifier(table))
args = append(args, sanitizeIdentifier(timeColumn))

// Build optional parameters
var options []string

if chunkTimeInterval != "" {
options = append(options, fmt.Sprintf("chunk_time_interval => interval '%s'", chunkTimeInterval))
options = append(options, fmt.Sprintf("chunk_time_interval => interval '%s'", sanitizeStringLiteral(chunkTimeInterval)))
}

if partitioningColumn != "" {
options = append(options, fmt.Sprintf("partitioning_column => '%s'", partitioningColumn))
options = append(options, fmt.Sprintf("partitioning_column => %s", sanitizeIdentifier(partitioningColumn)))
}

options = append(options, fmt.Sprintf("if_not_exists => %t", ifNotExists))
Expand Down
2 changes: 1 addition & 1 deletion pkg/db/timescale/hypertable.go
Original file line number Diff line number Diff line change
Expand Up @@ -30,7 +30,7 @@ func sanitizeIdentifier(identifier string) string {
return "\"" + strings.ReplaceAll(identifier, "\"", "\"\"") + "\""
}

// sanitizeStringLiteral escapes a string literal by replacing ' with
// sanitizeStringLiteral escapes a string literal by replacing ' with ''
func sanitizeStringLiteral(s string) string {
return strings.ReplaceAll(s, "'", "''")
}
Expand Down
16 changes: 8 additions & 8 deletions pkg/db/timescale/metadata.go
Original file line number Diff line number Diff line change
Expand Up @@ -497,16 +497,16 @@ func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (st
// Generate CREATE HYPERTABLE statement
var createHypertableStmt strings.Builder
createHypertableStmt.WriteString(fmt.Sprintf("SELECT create_hypertable('%s', '%s'",
tableName, metadata.TimeDimension))
sanitizeStringLiteral(tableName), sanitizeStringLiteral(metadata.TimeDimension)))

if metadata.ChunkTimeInterval != "" {
createHypertableStmt.WriteString(fmt.Sprintf(", chunk_time_interval => INTERVAL '%s'",
metadata.ChunkTimeInterval))
sanitizeStringLiteral(metadata.ChunkTimeInterval)))
}

if len(metadata.SpaceDimensions) > 0 {
createHypertableStmt.WriteString(fmt.Sprintf(", partitioning_column => '%s'",
metadata.SpaceDimensions[0]))
sanitizeStringLiteral(metadata.SpaceDimensions[0])))
}

createHypertableStmt.WriteString(");")
Expand All @@ -518,20 +518,20 @@ func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (st
if metadata.Compression {
compressionSettings, err := t.GetCompressionSettings(ctx, tableName)
if err == nil && compressionSettings.CompressionEnabled {
compressionStmt := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true);", tableName)
compressionStmt := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true);", sanitizeIdentifier(tableName))
result += "\n\n" + compressionStmt

// Add compression policy if exists
if compressionSettings.CompressionInterval != "" {
policyStmt := fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'",
tableName, compressionSettings.CompressionInterval)
sanitizeStringLiteral(tableName), sanitizeStringLiteral(compressionSettings.CompressionInterval))

if compressionSettings.SegmentBy != "" {
policyStmt += fmt.Sprintf(", segmentby => '%s'", compressionSettings.SegmentBy)
policyStmt += fmt.Sprintf(", segmentby => '%s'", sanitizeStringLiteral(compressionSettings.SegmentBy))
}

if compressionSettings.OrderBy != "" {
policyStmt += fmt.Sprintf(", orderby => '%s'", compressionSettings.OrderBy)
policyStmt += fmt.Sprintf(", orderby => '%s'", sanitizeStringLiteral(compressionSettings.OrderBy))
}

policyStmt += ");"
Expand All @@ -545,7 +545,7 @@ func (t *DB) GenerateHypertableSchema(ctx context.Context, tableName string) (st
retentionSettings, err := t.GetRetentionSettings(ctx, tableName)
if err == nil && retentionSettings.RetentionEnabled && retentionSettings.RetentionInterval != "" {
retentionStmt := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s');",
tableName, retentionSettings.RetentionInterval)
sanitizeStringLiteral(tableName), sanitizeStringLiteral(retentionSettings.RetentionInterval))
result += "\n\n" + retentionStmt
}
}
Expand Down
36 changes: 18 additions & 18 deletions pkg/db/timescale/policy.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@ func (t *DB) EnableCompression(ctx context.Context, tableName string, afterInter
return fmt.Errorf("TimescaleDB extension not available")
}

query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", sanitizeIdentifier(tableName))
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to enable compression: %w", err)
Expand Down Expand Up @@ -59,7 +59,7 @@ func (t *DB) DisableCompression(ctx context.Context, tableName string) error {
}

// Then disable compression
query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = false)", tableName)
query := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = false)", sanitizeIdentifier(tableName))
_, err = t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to disable compression: %w", err)
Expand All @@ -75,7 +75,7 @@ func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segm
}

// First, check if the table has compression enabled
query := fmt.Sprintf("SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'", tableName)
query := fmt.Sprintf("SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'", sanitizeStringLiteral(tableName))
result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to check compression status: %w", err)
Expand All @@ -89,7 +89,7 @@ func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segm
isCompressed := rows[0]["compress"]
if isCompressed == nil || fmt.Sprintf("%v", isCompressed) == "false" {
// Enable compression
enableQuery := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", tableName)
enableQuery := fmt.Sprintf("ALTER TABLE %s SET (timescaledb.compress = true)", sanitizeIdentifier(tableName))
_, err := t.ExecuteSQLWithoutParams(ctx, enableQuery)
if err != nil {
return fmt.Errorf("failed to enable compression: %w", err)
Expand All @@ -98,14 +98,14 @@ func (t *DB) AddCompressionPolicy(ctx context.Context, tableName, interval, segm

// Build the compression policy query
var policyQuery strings.Builder
policyQuery.WriteString(fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'", tableName, interval))
policyQuery.WriteString(fmt.Sprintf("SELECT add_compression_policy('%s', INTERVAL '%s'", sanitizeStringLiteral(tableName), sanitizeStringLiteral(interval)))

if segmentBy != "" {
policyQuery.WriteString(fmt.Sprintf(", segmentby => '%s'", segmentBy))
policyQuery.WriteString(fmt.Sprintf(", segmentby => '%s'", sanitizeStringLiteral(segmentBy)))
}

if orderBy != "" {
policyQuery.WriteString(fmt.Sprintf(", orderby => '%s'", orderBy))
policyQuery.WriteString(fmt.Sprintf(", orderby => '%s'", sanitizeStringLiteral(orderBy)))
}

policyQuery.WriteString(")")
Expand All @@ -128,7 +128,7 @@ func (t *DB) RemoveCompressionPolicy(ctx context.Context, tableName string) erro
// Find the policy ID
query := fmt.Sprintf(
"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_compression'",
tableName,
sanitizeStringLiteral(tableName),
)

result, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand Down Expand Up @@ -167,7 +167,7 @@ func (t *DB) GetCompressionSettings(ctx context.Context, tableName string) (*Com
// Check if the table has compression enabled
query := fmt.Sprintf(
"SELECT compress FROM timescaledb_information.hypertables WHERE hypertable_name = '%s'",
tableName,
sanitizeStringLiteral(tableName),
)

result, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand All @@ -191,7 +191,7 @@ func (t *DB) GetCompressionSettings(ctx context.Context, tableName string) (*Com
// Get compression-specific settings
settingsQuery := fmt.Sprintf(
"SELECT segmentby, orderby FROM timescaledb_information.compression_settings WHERE hypertable_name = '%s'",
tableName,
sanitizeStringLiteral(tableName),
)

settingsResult, err := t.ExecuteSQLWithoutParams(ctx, settingsQuery)
Expand Down Expand Up @@ -243,7 +243,7 @@ func (t *DB) AddRetentionPolicy(ctx context.Context, tableName, interval string)
return fmt.Errorf("TimescaleDB extension not available")
}

query := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s')", tableName, interval)
query := fmt.Sprintf("SELECT add_retention_policy('%s', INTERVAL '%s')", sanitizeStringLiteral(tableName), sanitizeStringLiteral(interval))
_, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
return fmt.Errorf("failed to add retention policy: %w", err)
Expand All @@ -261,7 +261,7 @@ func (t *DB) RemoveRetentionPolicy(ctx context.Context, tableName string) error
// Find the policy ID
query := fmt.Sprintf(
"SELECT job_id FROM timescaledb_information.jobs WHERE hypertable_name = '%s' AND proc_name = 'policy_retention'",
tableName,
sanitizeStringLiteral(tableName),
)

result, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand Down Expand Up @@ -306,7 +306,7 @@ func (t *DB) GetRetentionSettings(ctx context.Context, tableName string) (*Reten
"SELECT s.schedule_interval FROM timescaledb_information.jobs j "+
"JOIN timescaledb_information.job_stats s ON j.job_id = s.job_id "+
"WHERE j.hypertable_name = '%s' AND j.proc_name = 'policy_retention'",
tableName,
sanitizeStringLiteral(tableName),
)

result, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand Down Expand Up @@ -334,11 +334,11 @@ func (t *DB) CompressChunks(ctx context.Context, tableName, olderThan string) er
var query string
if olderThan == "" {
// Compress all chunks
query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s')", tableName)
query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s')", sanitizeStringLiteral(tableName))
} else {
// Compress chunks older than the specified interval
query = fmt.Sprintf("SELECT compress_chunks(hypertable => '%s', older_than => INTERVAL '%s')",
tableName, olderThan)
sanitizeStringLiteral(tableName), sanitizeStringLiteral(olderThan))
}

_, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand All @@ -358,11 +358,11 @@ func (t *DB) DecompressChunks(ctx context.Context, tableName, newerThan string)
var query string
if newerThan == "" {
// Decompress all chunks
query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s')", tableName)
query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s')", sanitizeStringLiteral(tableName))
} else {
// Decompress chunks newer than the specified interval
query = fmt.Sprintf("SELECT decompress_chunks(hypertable => '%s', newer_than => INTERVAL '%s')",
tableName, newerThan)
sanitizeStringLiteral(tableName), sanitizeStringLiteral(newerThan))
}

_, err := t.ExecuteSQLWithoutParams(ctx, query)
Expand Down Expand Up @@ -394,7 +394,7 @@ func (t *DB) GetChunkCompressionStats(ctx context.Context, tableName string) (in
FROM timescaledb_information.chunks
WHERE hypertable_name = '%s'
ORDER BY range_end DESC
`, tableName)
`, sanitizeStringLiteral(tableName))

result, err := t.ExecuteSQLWithoutParams(ctx, query)
if err != nil {
Expand Down
2 changes: 1 addition & 1 deletion pkg/db/timescale/timeseries.go
Original file line number Diff line number Diff line change
Expand Up @@ -170,7 +170,7 @@ func addWindowFunctions(query string, functions []WindowFunction) string {

// Add alias if specified
if fn.Alias != "" {
windowPart.WriteString(fmt.Sprintf(" AS %s", fn.Alias))
windowPart.WriteString(fmt.Sprintf(" AS %s", sanitizeIdentifier(fn.Alias)))
}

// Add comma if not last function
Expand Down
22 changes: 21 additions & 1 deletion pkg/dbtools/schema.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,13 +4,33 @@ import (
"context"
"database/sql"
"fmt"
"regexp"
"strings"
"time"

"github.com/FreePeak/db-mcp-server/pkg/db"
"github.com/FreePeak/db-mcp-server/pkg/logger"
"github.com/FreePeak/db-mcp-server/pkg/tools"
)

// validIdentifier matches valid SQL identifiers (alphanumeric and underscores)
var validIdentifier = regexp.MustCompile(`^[a-zA-Z_][a-zA-Z0-9_]*$`)

// sanitizeIdentifier validates and escapes a SQL identifier for SQLite
func sanitizeIdentifier(identifier string) string {
if identifier == "" {
return ""
}
if strings.Contains(identifier, "\x00") {
return ""
}
if validIdentifier.MatchString(identifier) {
return identifier
}
// Quote the identifier to handle special characters
return "\"" + strings.ReplaceAll(identifier, "\"", "\"\"") + "\""
}

// DatabaseStrategy defines the interface for database-specific query strategies
type DatabaseStrategy interface {
GetTablesQueries() []QueryWithArgs
Expand Down Expand Up @@ -268,7 +288,7 @@ func (s *SQLiteStrategy) GetColumnsQueries(table string) []QueryWithArgs {
return []QueryWithArgs{
// Primary: PRAGMA table_info approach
{
Query: "PRAGMA table_info(" + table + ")",
Query: "PRAGMA table_info(" + sanitizeIdentifier(table) + ")",
Args: []interface{}{},
},
// Secondary: sqlite_master approach for column info
Expand Down
Loading