Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 22 additions & 1 deletion pkg/sql/plan/function/func_cast.go
Original file line number Diff line number Diff line change
Expand Up @@ -4396,6 +4396,10 @@ func strToStr(
}
return nil
}
// Get source type to check if it's TEXT
fromType := from.GetSourceVector().GetType()
isSourceText := fromType.Oid == types.T_text

if totype.Oid != types.T_text && destLen != 0 {
for i = 0; i < l; i++ {
v, null := from.GetStrValue(i)
Expand All @@ -4407,7 +4411,24 @@ func strToStr(
}
// check the length.
s := convertByteSliceToString(v)
if utf8.RuneCountInString(s) > destLen {
// For explicit CAST operations (e.g., CAST(text_col AS CHAR(1))), we should
// always perform length validation, even if source is TEXT, because the user
// explicitly requested a specific type with a length limit.
//
// However, for implicit conversions in UPDATE statements where the target
// column is actually TEXT but misidentified as CHAR/VARCHAR, we should skip
// length validation. We distinguish this by checking the target width:
// - Small widths (like 1, 10, etc.) are likely explicit CASTs and should be validated
// - Large widths (>= 255) might be misidentified TEXT columns in UPDATE operations
//
// The threshold of 255 is chosen because:
// 1. It's a common default width for TEXT columns that get misidentified
// 2. Explicit CASTs from TEXT to CHAR/VARCHAR of width 255 or more are rare.
//    Known limitation: such a cast also skips validation, even though the
//    user likely expects it; this is accepted for compatibility.
// 3. This allows UPDATE operations on TEXT columns to work correctly
shouldSkipLengthCheck := isSourceText && (toType.Oid == types.T_char || toType.Oid == types.T_varchar) && destLen >= 255

if !shouldSkipLengthCheck && utf8.RuneCountInString(s) > destLen {
return formatCastError(ctx, from.GetSourceVector(), totype, fmt.Sprintf(
"Src length %v is larger than Dest length %v", len(s), destLen))
}
Expand Down
158 changes: 158 additions & 0 deletions pkg/sql/plan/function/func_cast_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"context"
"fmt"
"math"
"strings"
"testing"
"time"

Expand Down Expand Up @@ -1953,3 +1954,160 @@ func Benchmark_strToSigned_Binary(b *testing.B) {
})
}
}

// Test_strToStr_TextToCharVarchar exercises the length validation performed by
// strToStr when casting between string types, with a focus on TEXT sources.
//
// Expectations encoded in the table:
//   - TEXT -> CHAR(255)/VARCHAR(255): values longer than the width are accepted
//     and returned unchanged.
//   - TEXT -> CHAR/VARCHAR with a small width (1 or 10): overlong values are
//     rejected with a "larger than Dest length" error.
//   - VARCHAR -> CHAR(10): overlong values are rejected.
//   - TEXT -> TEXT: the value passes through.
//
// NULL rows must stay NULL in the result vector.
func Test_strToStr_TextToCharVarchar(t *testing.T) {
	ctx := context.Background()
	mp := mpool.MustNewZero()

	// Fixtures that exceed the destination widths used below.
	text260 := strings.Repeat("a", 260) // 260 characters
	text100 := strings.Repeat("b", 100) // 100 characters

	// castCase describes one strToStr invocation and its expected outcome.
	type castCase struct {
		name      string
		inputs    []string
		nulls     []uint64
		fromType  types.Type
		toType    types.Type
		want      []string
		wantNulls []uint64
		wantErr   bool
		errMsg    string
	}

	cases := []castCase{
		{
			name:     "TEXT to CHAR(255) with length 260 - should succeed",
			inputs:   []string{text260},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_char, 255, 0),
			want:     []string{text260}, // value is kept at its original length
		},
		{
			name:     "TEXT to VARCHAR(255) with length 260 - should succeed",
			inputs:   []string{text260},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_varchar, 255, 0),
			want:     []string{text260}, // value is kept at its original length
		},
		{
			name:      "TEXT to CHAR(255) with NULL - should handle NULL",
			inputs:    []string{"", "test"},
			nulls:     []uint64{0},
			fromType:  types.T_text.ToType(),
			toType:    types.New(types.T_char, 255, 0),
			want:      []string{"", "test"},
			wantNulls: []uint64{0},
		},
		{
			name:     "VARCHAR to CHAR(10) with length 100 - should fail (normal behavior)",
			inputs:   []string{text100},
			fromType: types.New(types.T_varchar, 100, 0),
			toType:   types.New(types.T_char, 10, 0),
			wantErr:  true,
			errMsg:   "larger than Dest length",
		},
		{
			name:     "TEXT to CHAR(1) with length > 1 - should fail (explicit CAST)",
			inputs:   []string{"ab"},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_char, 1, 0),
			wantErr:  true,
			errMsg:   "larger than Dest length",
		},
		{
			name:     "TEXT to CHAR(10) with length 100 - should fail (explicit CAST to small width)",
			inputs:   []string{text100},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_char, 10, 0),
			wantErr:  true,
			errMsg:   "larger than Dest length",
		},
		{
			name:     "TEXT to VARCHAR(10) with length 100 - should fail (explicit CAST to small width)",
			inputs:   []string{text100},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_varchar, 10, 0),
			wantErr:  true,
			errMsg:   "larger than Dest length",
		},
		{
			name:     "TEXT to TEXT - should succeed",
			inputs:   []string{"test text"},
			fromType: types.T_text.ToType(),
			toType:   types.T_text.ToType(),
			want:     []string{"test text"},
		},
		{
			name:     "TEXT to CHAR(255) with multiple values",
			inputs:   []string{"short", text260, "medium length string"},
			fromType: types.T_text.ToType(),
			toType:   types.New(types.T_char, 255, 0),
			want:     []string{"short", text260, "medium length string"},
		},
	}

	for _, tc := range cases {
		t.Run(tc.name, func(t *testing.T) {
			rowCount := len(tc.inputs)

			// TEXT sources are built directly; any other source starts as a
			// VARCHAR vector and is then retyped to the requested source type.
			var src *vector.Vector
			switch tc.fromType.Oid {
			case types.T_text:
				src = testutil.MakeTextVector(tc.inputs, tc.nulls)
			default:
				src = testutil.MakeVarcharVector(tc.inputs, tc.nulls)
				src.SetType(tc.fromType)
			}
			defer src.Free(mp)

			param := vector.GenerateFunctionStrParameter(src)

			result := vector.NewFunctionResultWrapper(tc.toType, mp).(*vector.FunctionResult[types.Varlena])
			defer result.Free()
			require.NoError(t, result.PreExtendAndReset(rowCount))

			castErr := strToStr(ctx, param, result, rowCount, tc.toType)

			// Failure cases only assert on the error and stop here.
			if tc.wantErr {
				require.Error(t, castErr)
				if tc.errMsg != "" {
					require.Contains(t, castErr.Error(), tc.errMsg)
				}
				return
			}
			require.NoError(t, castErr)

			// Compare every expected row against what the cast produced.
			out := vector.GenerateFunctionStrParameter(result.GetResultVector())
			for row, expected := range tc.want {
				got, isNull := out.GetStrValue(uint64(row))

				if contains(tc.wantNulls, uint64(row)) {
					require.True(t, isNull, "row %d should be null", row)
					continue
				}
				require.False(t, isNull, "row %d should not be null", row)
				require.Equal(t, expected, string(got), "row %d value not match", row)
			}

			// The result's null bitmap must hold the expected positions, or be
			// empty when no NULLs are expected.
			nullMask := result.GetResultVector().GetNulls()
			if len(tc.wantNulls) == 0 {
				require.True(t, nullMask.IsEmpty())
				return
			}
			for _, pos := range tc.wantNulls {
				require.True(t, nullMask.Contains(pos))
			}
		})
	}
}
Loading