diff --git a/README.md b/README.md index 8f0da930..8d3c7228 100644 --- a/README.md +++ b/README.md @@ -437,6 +437,9 @@ are supported: * "github.com/golang-sql/civil".Date -> date * "github.com/golang-sql/civil".DateTime -> datetime2 * "github.com/golang-sql/civil".Time -> time +* mssql.NullDate -> date (nullable) +* mssql.NullDateTime -> datetime2 (nullable) +* mssql.NullTime -> time (nullable) * mssql.TVP -> Table Value Parameter (TDS version dependent) Using an `int` parameter will send a 4 byte value (int) from a 32bit app and an 8 byte value (bigint) from a 64bit app. diff --git a/bulkcopy_test.go b/bulkcopy_test.go index 088911e2..1ee44e25 100644 --- a/bulkcopy_test.go +++ b/bulkcopy_test.go @@ -13,6 +13,7 @@ import ( "testing" "time" + "github.com/golang-sql/civil" "github.com/stretchr/testify/assert" ) @@ -29,6 +30,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) { "test_nullint16", "test_nulltime", "test_nulluniqueidentifier", + "test_nulldate", + "test_nulldatetime", + "test_nullciviltime", } values := []interface{}{ sql.NullFloat64{Valid: false}, @@ -40,6 +44,9 @@ func TestBulkcopyWithInvalidNullableType(t *testing.T) { sql.NullInt16{Valid: false}, sql.NullTime{Valid: false}, NullUniqueIdentifier{Valid: false}, + NullDate{Valid: false}, + NullDateTime{Valid: false}, + NullTime{Valid: false}, } pool, logger := open(t) @@ -176,6 +183,9 @@ func testBulkcopy(t *testing.T, guidConversion bool) { {"test_nullint32", sql.NullInt32{2147483647, true}, 2147483647}, {"test_nullint16", sql.NullInt16{32767, true}, 32767}, {"test_nulltime", sql.NullTime{time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC), true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)}, + {"test_nulldate", NullDate{civil.Date{Year: 2010, Month: 11, Day: 12}, true}, time.Date(2010, 11, 12, 0, 0, 0, 0, time.UTC)}, + {"test_nulldatetime", NullDateTime{civil.DateTime{Date: civil.Date{Year: 2010, Month: 11, Day: 12}, Time: civil.Time{Hour: 13, Minute: 14, Second: 15, Nanosecond: 120000000}}, true}, time.Date(2010, 11, 12, 13, 14, 15, 120000000, time.UTC)}, + {"test_nullciviltime", NullTime{civil.Time{Hour: 13, Minute: 14, Second: 15, Nanosecond: 123000000}, true}, time.Date(1, 1, 1, 13, 14, 15, 123000000, time.UTC)}, {"test_datetimen_midnight", time.Date(2025, 1, 1, 23, 59, 59, 998_350_000, time.UTC), time.Date(2025, 1, 2, 0, 0, 0, 0, time.UTC)}, // {"test_smallmoney", 1234.56, nil}, // {"test_money", 1234.56, nil}, @@ -351,6 +361,9 @@ func setupNullableTypeTable(ctx context.Context, t *testing.T, conn *sql.Conn, t [test_nullint16] [smallint] NULL, [test_nulltime] [datetime] NULL, [test_nulluniqueidentifier] [uniqueidentifier] NULL, + [test_nulldate] [date] NULL, + [test_nulldatetime] [datetime2] NULL, + [test_nullciviltime] [time] NULL, CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED ( [id] ASC @@ -438,6 +451,9 @@ func setupTable(ctx context.Context, t *testing.T, conn *sql.Conn, tableName str [test_nullint32] [int] NULL, [test_nullint16] [smallint] NULL, [test_nulltime] [datetime] NULL, + [test_nulldate] [date] NULL, + [test_nulldatetime] [datetime2] NULL, + [test_nullciviltime] [time] NULL, [test_datetimen_midnight] [datetime] NULL, CONSTRAINT [PK_` + tableName + `_id] PRIMARY KEY CLUSTERED ( diff --git a/civil_null.go b/civil_null.go new file mode 100644 index 00000000..655938b8 --- /dev/null +++ b/civil_null.go @@ -0,0 +1,214 @@ +package mssql + +import ( + "database/sql/driver" + "encoding/json" + "fmt" + "time" + + "github.com/golang-sql/civil" +) + +// NullDate represents a civil.Date that may be null. +// NullDate implements the Scanner interface so it can be used as a scan destination, +// similar to sql.NullString. +type NullDate struct { + Date civil.Date + Valid bool // Valid is true if Date is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullDate) Scan(value interface{}) error { + if value == nil { + n.Date, n.Valid = civil.Date{}, false + return nil + } + n.Valid = true + switch v := value.(type) { + case time.Time: + n.Date = civil.DateOf(v) + return nil + default: + n.Valid = false + return fmt.Errorf("cannot scan %T into NullDate", value) + } +} + +// Value implements the driver Valuer interface. +func (n NullDate) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.Date.In(time.UTC), nil +} + +// String returns the string representation of the date or "NULL". +func (n NullDate) String() string { + if !n.Valid { + return "NULL" + } + return n.Date.String() +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (n NullDate) MarshalText() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return n.Date.MarshalText() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (n *NullDate) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Date, n.Valid = civil.Date{}, false + return nil + } + err := json.Unmarshal(b, &n.Date) + n.Valid = err == nil + return err +} + +// MarshalJSON implements the json.Marshaler interface. +func (n NullDate) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return json.Marshal(n.Date) +} + +// NullDateTime represents a civil.DateTime that may be null. +// NullDateTime implements the Scanner interface so it can be used as a scan destination, +// similar to sql.NullString. +type NullDateTime struct { + DateTime civil.DateTime + Valid bool // Valid is true if DateTime is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullDateTime) Scan(value interface{}) error { + if value == nil { + n.DateTime, n.Valid = civil.DateTime{}, false + return nil + } + n.Valid = true + switch v := value.(type) { + case time.Time: + n.DateTime = civil.DateTimeOf(v) + return nil + default: + n.Valid = false + return fmt.Errorf("cannot scan %T into NullDateTime", value) + } +} + +// Value implements the driver Valuer interface. +func (n NullDateTime) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return n.DateTime.In(time.UTC), nil +} + +// String returns the string representation of the datetime or "NULL". +func (n NullDateTime) String() string { + if !n.Valid { + return "NULL" + } + return n.DateTime.String() +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (n NullDateTime) MarshalText() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return n.DateTime.MarshalText() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (n *NullDateTime) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.DateTime, n.Valid = civil.DateTime{}, false + return nil + } + err := json.Unmarshal(b, &n.DateTime) + n.Valid = err == nil + return err +} + +// MarshalJSON implements the json.Marshaler interface. +func (n NullDateTime) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return json.Marshal(n.DateTime) +} + +// NullTime represents a civil.Time that may be null. +// NullTime implements the Scanner interface so it can be used as a scan destination, +// similar to sql.NullString. +type NullTime struct { + Time civil.Time + Valid bool // Valid is true if Time is not NULL +} + +// Scan implements the Scanner interface. +func (n *NullTime) Scan(value interface{}) error { + if value == nil { + n.Time, n.Valid = civil.Time{}, false + return nil + } + n.Valid = true + switch v := value.(type) { + case time.Time: + n.Time = civil.TimeOf(v) + return nil + default: + n.Valid = false + return fmt.Errorf("cannot scan %T into NullTime", value) + } +} + +// Value implements the driver Valuer interface. +func (n NullTime) Value() (driver.Value, error) { + if !n.Valid { + return nil, nil + } + return time.Date(1, 1, 1, n.Time.Hour, n.Time.Minute, n.Time.Second, n.Time.Nanosecond, time.UTC), nil +} + +// String returns the string representation of the time or "NULL". +func (n NullTime) String() string { + if !n.Valid { + return "NULL" + } + return n.Time.String() +} + +// MarshalText implements the encoding.TextMarshaler interface. +func (n NullTime) MarshalText() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return n.Time.MarshalText() +} + +// UnmarshalJSON implements the json.Unmarshaler interface. +func (n *NullTime) UnmarshalJSON(b []byte) error { + if string(b) == "null" { + n.Time, n.Valid = civil.Time{}, false + return nil + } + err := json.Unmarshal(b, &n.Time) + n.Valid = err == nil + return err +} + +// MarshalJSON implements the json.Marshaler interface. +func (n NullTime) MarshalJSON() ([]byte, error) { + if !n.Valid { + return []byte("null"), nil + } + return json.Marshal(n.Time) +} diff --git a/civil_null_integration_test.go b/civil_null_integration_test.go new file mode 100644 index 00000000..bb8586a8 --- /dev/null +++ b/civil_null_integration_test.go @@ -0,0 +1,197 @@ +package mssql + +import ( + "context" + "database/sql" + "testing" + "time" + + "github.com/golang-sql/civil" +) + +// TestNullCivilTypesIntegration tests the nullable civil types with actual database operations +// This test requires a SQL Server connection +func TestNullCivilTypesIntegration(t *testing.T) { + checkConnStr(t) + + tl := testLogger{t: t} + defer tl.StopLogging() + + conn, logger := open(t) + defer conn.Close() + defer logger.StopLogging() + + ctx, cancel := context.WithCancel(context.Background()) + defer cancel() + + // Test civil null types as OUT parameters + t.Run("OUT parameters", func(t *testing.T) { + // Test NullDate OUT parameter + t.Run("NullDate", func(t *testing.T) { + var nullDate NullDate + + // Test NULL value + _, err := conn.ExecContext(ctx, "SELECT @p1 = NULL", sql.Out{Dest: &nullDate}) + if err != nil { + t.Fatalf("Failed to execute query with NULL: %v", err) + } + if nullDate.Valid { + t.Error("Expected NullDate to be invalid (NULL)") + } + + // Test valid value + _, err = conn.ExecContext(ctx, "SELECT @p1 = '2023-12-25'", sql.Out{Dest: &nullDate}) + if err != nil { + t.Fatalf("Failed to execute query with date: %v", err) + } + if !nullDate.Valid { + t.Error("Expected NullDate to be valid") + } + expectedDate := civil.Date{Year: 2023, Month: time.December, Day: 25} + if nullDate.Date != expectedDate { + t.Errorf("Expected %v, got %v", expectedDate, nullDate.Date) + } + }) + + // Test NullDateTime OUT parameter + t.Run("NullDateTime", func(t *testing.T) { + var nullDateTime NullDateTime + + // Test NULL value + _, err := conn.ExecContext(ctx, "SELECT @p1 = NULL", sql.Out{Dest: &nullDateTime}) + if err != nil { + t.Fatalf("Failed to execute query with NULL: %v", err) + } + if nullDateTime.Valid { + t.Error("Expected NullDateTime to be invalid (NULL)") + } + + // Test valid value + _, err = conn.ExecContext(ctx, "SELECT @p1 = '2023-12-25 14:30:45'", sql.Out{Dest: &nullDateTime}) + if err != nil { + t.Fatalf("Failed to execute query with datetime: %v", err) + } + if !nullDateTime.Valid { + t.Error("Expected NullDateTime to be valid") + } + // Check that the date and time components are correct + if nullDateTime.DateTime.Date.Year != 2023 || + nullDateTime.DateTime.Date.Month != time.December || + nullDateTime.DateTime.Date.Day != 25 || + nullDateTime.DateTime.Time.Hour != 14 || + nullDateTime.DateTime.Time.Minute != 30 || + nullDateTime.DateTime.Time.Second != 45 { + t.Errorf("Unexpected datetime value: %v", nullDateTime.DateTime) + } + }) + + // Test NullTime OUT parameter + t.Run("NullTime", func(t *testing.T) { + var nullTime NullTime + + // Test NULL value + _, err := conn.ExecContext(ctx, "SELECT @p1 = NULL", sql.Out{Dest: &nullTime}) + if err != nil { + t.Fatalf("Failed to execute query with NULL: %v", err) + } + if nullTime.Valid { + t.Error("Expected NullTime to be invalid (NULL)") + } + + // Test valid value + _, err = conn.ExecContext(ctx, "SELECT @p1 = '14:30:45'", sql.Out{Dest: &nullTime}) + if err != nil { + t.Fatalf("Failed to execute query with time: %v", err) + } + if !nullTime.Valid { + t.Error("Expected NullTime to be valid") + } + if nullTime.Time.Hour != 14 || nullTime.Time.Minute != 30 || nullTime.Time.Second != 45 { + t.Errorf("Expected time 14:30:45, got %02d:%02d:%02d", + nullTime.Time.Hour, nullTime.Time.Minute, nullTime.Time.Second) + } + }) + }) + + // Test civil null types as input parameters + t.Run("Input parameters", func(t *testing.T) { + // Test NullDate input parameter + t.Run("NullDate", func(t *testing.T) { + // Test NULL value + nullDate := NullDate{Valid: false} + var result *time.Time + err := conn.QueryRowContext(ctx, "SELECT @p1", nullDate).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with NULL NullDate: %v", err) + } + if result != nil { + t.Error("Expected result to be nil for NULL input") + } + + // Test valid value + nullDate = NullDate{Date: civil.Date{Year: 2023, Month: time.December, Day: 25}, Valid: true} + err = conn.QueryRowContext(ctx, "SELECT @p1", nullDate).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with valid NullDate: %v", err) + } + if result == nil { + t.Error("Expected result to be non-nil for valid input") + } else { + expectedTime := time.Date(2023, time.December, 25, 0, 0, 0, 0, result.Location()) + if !result.Equal(expectedTime) { + t.Errorf("Expected %v, got %v", expectedTime, *result) + } + } + }) + + // Test NullDateTime input parameter + t.Run("NullDateTime", func(t *testing.T) { + // Test NULL value + nullDateTime := NullDateTime{Valid: false} + var result *time.Time + err := conn.QueryRowContext(ctx, "SELECT @p1", nullDateTime).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with NULL NullDateTime: %v", err) + } + if result != nil { + t.Error("Expected result to be nil for NULL input") + } + + // Test valid value + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + nullDateTime = NullDateTime{DateTime: civil.DateTimeOf(testTime), Valid: true} + err = conn.QueryRowContext(ctx, "SELECT @p1", nullDateTime).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with valid NullDateTime: %v", err) + } + if result == nil { + t.Error("Expected result to be non-nil for valid input") + } + }) + + // Test NullTime input parameter + t.Run("NullTime", func(t *testing.T) { + // Test NULL value + nullTime := NullTime{Valid: false} + var result *time.Time + err := conn.QueryRowContext(ctx, "SELECT @p1", nullTime).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with NULL NullTime: %v", err) + } + if result != nil { + t.Error("Expected result to be nil for NULL input") + } + + // Test valid value + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + nullTime = NullTime{Time: civil.TimeOf(testTime), Valid: true} + err = conn.QueryRowContext(ctx, "SELECT @p1", nullTime).Scan(&result) + if err != nil { + t.Fatalf("Failed to query with valid NullTime: %v", err) + } + if result == nil { + t.Error("Expected result to be non-nil for valid input") + } + }) + }) +} diff --git a/civil_null_test.go b/civil_null_test.go new file mode 100644 index 00000000..810002e2 --- /dev/null +++ b/civil_null_test.go @@ -0,0 +1,489 @@ +package mssql + +import ( + "database/sql/driver" + "encoding/json" + "testing" + "time" + + "github.com/golang-sql/civil" +) + +func TestNullDate(t *testing.T) { + // Test Value() method + t.Run("Value", func(t *testing.T) { + // Valid case + date := civil.Date{Year: 2023, Month: time.December, Day: 25} + nullDate := NullDate{Date: date, Valid: true} + val, err := nullDate.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + expectedTime := date.In(time.UTC) + if val != expectedTime { + t.Errorf("Expected %v, got %v", expectedTime, val) + } + + // Invalid case + nullDate = NullDate{Valid: false} + val, err = nullDate.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if val != nil { + t.Errorf("Expected nil, got %v", val) + } + }) + + // Test Scan() method + t.Run("Scan", func(t *testing.T) { + var nullDate NullDate + + // Scan nil value + err := nullDate.Scan(nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if nullDate.Valid { + t.Error("Expected Valid to be false") + } + + // Scan time.Time value + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + err = nullDate.Scan(testTime) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !nullDate.Valid { + t.Error("Expected Valid to be true") + } + expectedDate := civil.DateOf(testTime) + if nullDate.Date != expectedDate { + t.Errorf("Expected %v, got %v", expectedDate, nullDate.Date) + } + + // Scan invalid type + err = nullDate.Scan("invalid") + if err == nil { + t.Error("Expected error for invalid type") + } + if nullDate.Valid { + t.Error("Expected Valid to be false after error") + } + }) + + // Test String() method + t.Run("String", func(t *testing.T) { + // Valid case + date := civil.Date{Year: 2023, Month: time.December, Day: 25} + nullDate := NullDate{Date: date, Valid: true} + str := nullDate.String() + if str != date.String() { + t.Errorf("Expected %s, got %s", date.String(), str) + } + + // Invalid case + nullDate = NullDate{Valid: false} + str = nullDate.String() + if str != "NULL" { + t.Errorf("Expected 'NULL', got %s", str) + } + }) + + // Test JSON marshaling/unmarshaling + t.Run("JSON", func(t *testing.T) { + // Valid case + date := civil.Date{Year: 2023, Month: time.December, Day: 25} + nullDate := NullDate{Date: date, Valid: true} + data, err := json.Marshal(nullDate) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + + var unmarshaled NullDate + err = json.Unmarshal(data, &unmarshaled) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !unmarshaled.Valid || unmarshaled.Date != date { + t.Errorf("Expected %v (valid), got %v (valid: %t)", date, unmarshaled.Date, unmarshaled.Valid) + } + + // Invalid case + nullDate = NullDate{Valid: false} + data, err = json.Marshal(nullDate) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if string(data) != "null" { + t.Errorf("Expected 'null', got %s", string(data)) + } + + err = json.Unmarshal([]byte("null"), &unmarshaled) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if unmarshaled.Valid { + t.Error("Expected Valid to be false") + } + }) +} + +func TestNullDateTime(t *testing.T) { + // Test Value() method + t.Run("Value", func(t *testing.T) { + // Valid case + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + dateTime := civil.DateTimeOf(testTime) + nullDateTime := NullDateTime{DateTime: dateTime, Valid: true} + val, err := nullDateTime.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + expectedTime := dateTime.In(time.UTC) + if val != expectedTime { + t.Errorf("Expected %v, got %v", expectedTime, val) + } + + // Invalid case + nullDateTime = NullDateTime{Valid: false} + val, err = nullDateTime.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if val != nil { + t.Errorf("Expected nil, got %v", val) + } + }) + + // Test Scan() method + t.Run("Scan", func(t *testing.T) { + var nullDateTime NullDateTime + + // Scan nil value + err := nullDateTime.Scan(nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if nullDateTime.Valid { + t.Error("Expected Valid to be false") + } + + // Scan time.Time value + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + err = nullDateTime.Scan(testTime) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !nullDateTime.Valid { + t.Error("Expected Valid to be true") + } + expectedDateTime := civil.DateTimeOf(testTime) + if nullDateTime.DateTime != expectedDateTime { + t.Errorf("Expected %v, got %v", expectedDateTime, nullDateTime.DateTime) + } + + // Scan invalid type + err = nullDateTime.Scan("invalid") + if err == nil { + t.Error("Expected error for invalid type") + } + if nullDateTime.Valid { + t.Error("Expected Valid to be false after error") + } + }) +} + +func TestNullTime(t *testing.T) { + // Test Value() method + t.Run("Value", func(t *testing.T) { + // Valid case + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + civilTime := civil.TimeOf(testTime) + nullTime := NullTime{Time: civilTime, Valid: true} + val, err := nullTime.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + expectedTime := time.Date(1, 1, 1, civilTime.Hour, civilTime.Minute, civilTime.Second, civilTime.Nanosecond, time.UTC) + if val != expectedTime { + t.Errorf("Expected %v, got %v", expectedTime, val) + } + + // Invalid case + nullTime = NullTime{Valid: false} + val, err = nullTime.Value() + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if val != nil { + t.Errorf("Expected nil, got %v", val) + } + }) + + // Test Scan() method + t.Run("Scan", func(t *testing.T) { + var nullTime NullTime + + // Scan nil value + err := nullTime.Scan(nil) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if nullTime.Valid { + t.Error("Expected Valid to be false") + } + + // Scan time.Time value + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + err = nullTime.Scan(testTime) + if err != nil { + t.Errorf("Unexpected error: %v", err) + } + if !nullTime.Valid { + t.Error("Expected Valid to be true") + } + expectedTime := civil.TimeOf(testTime) + if nullTime.Time != expectedTime { + t.Errorf("Expected %v, got %v", expectedTime, nullTime.Time) + } + + // Scan invalid type + err = nullTime.Scan("invalid") + if err == nil { + t.Error("Expected error for invalid type") + } + if nullTime.Valid { + t.Error("Expected Valid to be false after error") + } + }) +} + +// Test that the types implement the required interfaces +func TestNullCivilTypesImplementInterfaces(t *testing.T) { + var ( + _ driver.Valuer = NullDate{} + _ driver.Valuer = NullDateTime{} + _ driver.Valuer = NullTime{} + ) + // Note: Scanner interface is verified by successful compilation of Scan methods +} + +// TestNullCivilTypesParameterEncoding tests that nullable civil types are properly encoded +// as typed NULL parameters rather than untyped NULLs, which is important for OUT parameters +func TestNullCivilTypesParameterEncoding(t *testing.T) { + // Create a mock connection and statement for testing + c := &Conn{} + c.sess = &tdsSession{} + c.sess.loginAck.TDSVersion = verTDS74 // Use modern TDS version + s := &Stmt{c: c} + + t.Run("NullDate parameter encoding", func(t *testing.T) { + // Test valid NullDate + validDate := NullDate{Date: civil.Date{Year: 2023, Month: time.December, Day: 25}, Valid: true} + param, err := s.makeParam(validDate) + if err != nil { + t.Errorf("Unexpected error for valid NullDate: %v", err) + } + if param.ti.TypeId != typeDateN { + t.Errorf("Expected TypeId %v for valid NullDate, got %v", typeDateN, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid NullDate") + } + + // Test invalid NullDate (NULL) + nullDate := NullDate{Valid: false} + param, err = s.makeParam(nullDate) + if err != nil { + t.Errorf("Unexpected error for NULL NullDate: %v", err) + } + if param.ti.TypeId != typeDateN { + t.Errorf("Expected TypeId %v for NULL NullDate, got %v", typeDateN, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL NullDate should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL NullDate") + } + if param.ti.Size != 3 { + t.Errorf("Expected Size 3 for NULL NullDate, got %v", param.ti.Size) + } + }) + + t.Run("NullDateTime parameter encoding", func(t *testing.T) { + // Test valid NullDateTime + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + validDateTime := NullDateTime{DateTime: civil.DateTimeOf(testTime), Valid: true} + param, err := s.makeParam(validDateTime) + if err != nil { + t.Errorf("Unexpected error for valid NullDateTime: %v", err) + } + if param.ti.TypeId != typeDateTime2N { + t.Errorf("Expected TypeId %v for valid NullDateTime, got %v", typeDateTime2N, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid NullDateTime") + } + + // Test invalid NullDateTime (NULL) + nullDateTime := NullDateTime{Valid: false} + param, err = s.makeParam(nullDateTime) + if err != nil { + t.Errorf("Unexpected error for NULL NullDateTime: %v", err) + } + if param.ti.TypeId != typeDateTime2N { + t.Errorf("Expected TypeId %v for NULL NullDateTime, got %v", typeDateTime2N, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL NullDateTime should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL NullDateTime") + } + if param.ti.Scale != 7 { + t.Errorf("Expected Scale 7 for NULL NullDateTime, got %v", param.ti.Scale) + } + }) + + t.Run("NullTime parameter encoding", func(t *testing.T) { + // Test valid NullTime + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + validTime := NullTime{Time: civil.TimeOf(testTime), Valid: true} + param, err := s.makeParam(validTime) + if err != nil { + t.Errorf("Unexpected error for valid NullTime: %v", err) + } + if param.ti.TypeId != typeTimeN { + t.Errorf("Expected TypeId %v for valid NullTime, got %v", typeTimeN, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid NullTime") + } + + // Test invalid NullTime (NULL) + nullTime := NullTime{Valid: false} + param, err = s.makeParam(nullTime) + if err != nil { + t.Errorf("Unexpected error for NULL NullTime: %v", err) + } + if param.ti.TypeId != typeTimeN { + t.Errorf("Expected TypeId %v for NULL NullTime, got %v", typeTimeN, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL NullTime should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL NullTime") + } + if param.ti.Scale != 7 { + t.Errorf("Expected Scale 7 for NULL NullTime, got %v", param.ti.Scale) + } + }) + + // Test pointer types (as used in OUT parameters) + t.Run("Pointer NullDate parameter encoding", func(t *testing.T) { + // Test valid *NullDate + validDate := &NullDate{Date: civil.Date{Year: 2023, Month: time.December, Day: 25}, Valid: true} + param, err := s.makeParam(validDate) + if err != nil { + t.Errorf("Unexpected error for valid *NullDate: %v", err) + } + if param.ti.TypeId != typeDateN { + t.Errorf("Expected TypeId %v for valid *NullDate, got %v", typeDateN, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid *NullDate") + } + + // Test invalid *NullDate (NULL) + nullDate := &NullDate{Valid: false} + param, err = s.makeParam(nullDate) + if err != nil { + t.Errorf("Unexpected error for NULL *NullDate: %v", err) + } + if param.ti.TypeId != typeDateN { + t.Errorf("Expected TypeId %v for NULL *NullDate, got %v", typeDateN, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL *NullDate should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL *NullDate") + } + if param.ti.Size != 3 { + t.Errorf("Expected Size 3 for NULL *NullDate, got %v", param.ti.Size) + } + }) + + t.Run("Pointer NullDateTime parameter encoding", func(t *testing.T) { + // Test valid *NullDateTime + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + validDateTime := &NullDateTime{DateTime: civil.DateTimeOf(testTime), Valid: true} + param, err := s.makeParam(validDateTime) + if err != nil { + t.Errorf("Unexpected error for valid *NullDateTime: %v", err) + } + if param.ti.TypeId != typeDateTime2N { + t.Errorf("Expected TypeId %v for valid *NullDateTime, got %v", typeDateTime2N, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid *NullDateTime") + } + + // Test invalid *NullDateTime (NULL) + nullDateTime := &NullDateTime{Valid: false} + param, err = s.makeParam(nullDateTime) + if err != nil { + t.Errorf("Unexpected error for NULL *NullDateTime: %v", err) + } + if param.ti.TypeId != typeDateTime2N { + t.Errorf("Expected TypeId %v for NULL *NullDateTime, got %v", typeDateTime2N, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL *NullDateTime should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL *NullDateTime") + } + if param.ti.Scale != 7 { + t.Errorf("Expected Scale 7 for NULL *NullDateTime, got %v", param.ti.Scale) + } + }) + + t.Run("Pointer NullTime parameter encoding", func(t *testing.T) { + // Test valid *NullTime + testTime := time.Date(2023, time.December, 25, 14, 30, 45, 0, time.UTC) + validTime := &NullTime{Time: civil.TimeOf(testTime), Valid: true} + param, err := s.makeParam(validTime) + if err != nil { + t.Errorf("Unexpected error for valid *NullTime: %v", err) + } + if param.ti.TypeId != typeTimeN { + t.Errorf("Expected TypeId %v for valid *NullTime, got %v", typeTimeN, param.ti.TypeId) + } + if len(param.buffer) == 0 { + t.Error("Expected non-empty buffer for valid *NullTime") + } + + // Test invalid *NullTime (NULL) + nullTime := &NullTime{Valid: false} + param, err = s.makeParam(nullTime) + if err != nil { + t.Errorf("Unexpected error for NULL *NullTime: %v", err) + } + if param.ti.TypeId != typeTimeN { + t.Errorf("Expected TypeId %v for NULL *NullTime, got %v", typeTimeN, param.ti.TypeId) + } + if param.ti.TypeId == typeNull { + t.Error("NULL *NullTime should not use untyped NULL (typeNull)") + } + if len(param.buffer) != 0 { + t.Error("Expected empty buffer for NULL *NullTime") + } + if param.ti.Scale != 7 { + t.Errorf("Expected Scale 7 for NULL *NullTime, got %v", param.ti.Scale) + } + }) +} diff --git a/doc/how-to-handle-date-and-time-types.md b/doc/how-to-handle-date-and-time-types.md index 5ffacd82..616b2507 100644 --- a/doc/how-to-handle-date-and-time-types.md +++ b/doc/how-to-handle-date-and-time-types.md @@ -5,13 +5,16 @@ SQL Server has six date and time datatypes: date, time, smalldatetime, datetime, ## Inserting Date and Time Data The following is a list of datatypes that can be used to insert data into a SQL Server date and/or time type column: -- string -- time.Time -- mssql.DateTime1 -- mssql.DateTimeOffset -- "github.com/golang-sql/civil".Date -- "github.com/golang-sql/civil".Time -- "github.com/golang-sql/civil".DateTime +- string +- time.Time +- mssql.DateTime1 +- mssql.DateTimeOffset +- "github.com/golang-sql/civil".Date +- "github.com/golang-sql/civil".Time +- "github.com/golang-sql/civil".DateTime +- mssql.NullDate (nullable civil.Date) +- mssql.NullTime (nullable civil.Time) +- mssql.NullDateTime (nullable civil.DateTime) `time.Time` and `mssql.DateTimeOffset` contain the most information (time zone and over 7 digits precision). Designed to match the SQL Server `datetime` type, `mssql.DateTime1` does not have time zone information, only has up to 3 digits precision and they are rouded to increments of .000, .003 or .007 seconds when the data is passed to SQL Server. If you use `mssql.DateTime1` to hold time zone information or very precised time data (more than 3 decimal digits), you will see data lost when inserting into columns with types that can hold more information. For example: @@ -30,16 +33,31 @@ _, err = stmt.Exec(param, param, param) // precisions are lost in all columns. Also, time zone information is lost in datetimeoffsetCol ``` - `"github.com/golang-sql/civil".DateTime` does not have time zone information. `"github.com/golang-sql/civil".Date` only has the date information, and `"github.com/golang-sql/civil".Time` only has the time information. `string` can also be used to insert data into date and time types columns, but you have to make sure the format is accepted by SQL Server. + `"github.com/golang-sql/civil".DateTime` does not have time zone information. `"github.com/golang-sql/civil".Date` only has the date information, and `"github.com/golang-sql/civil".Time` only has the time information. `string` can also be used to insert data into date and time types columns, but you have to make sure the format is accepted by SQL Server. + +The nullable civil types (`mssql.NullDate`, `mssql.NullDateTime`, `mssql.NullTime`) can be used when you need to handle NULL values, particularly useful for OUT parameters: + +```go +var nullDate mssql.NullDate +_, err := conn.ExecContext(ctx, "SELECT @p1 = NULL", sql.Out{Dest: &nullDate}) +// nullDate.Valid will be false + +var nullDateTime mssql.NullDateTime +_, err = conn.ExecContext(ctx, "SELECT @p1 = '2023-12-25 14:30:45'", sql.Out{Dest: &nullDateTime}) +// nullDateTime.Valid will be true, nullDateTime.DateTime contains the value +``` ## Retrieving Date and Time Data -The following is a list of datatypes that can be used to retrieved data from a SQL Server date and/or time type column: -- string -- sql.RawBytes -- time.Time -- mssql.DateTime1 -- mssql.DateTiimeOffset +The following is a list of datatypes that can be used to retrieved data from a SQL Server date and/or time type column: +- string +- sql.RawBytes +- time.Time +- mssql.DateTime1 +- mssql.DateTiimeOffset +- mssql.NullDate (for nullable date columns) +- mssql.NullTime (for nullable time columns) +- mssql.NullDateTime (for nullable datetime2 columns) When using these data types to retrieve information from a date and/or time type column, you may end up with some extra unexpected information. For example, if you use Go type `time.Time` to retrieve information from a SQL Server `date` column: diff --git a/msdsn/conn_str.go b/msdsn/conn_str.go index 467e279d..30e93140 100644 --- a/msdsn/conn_str.go +++ b/msdsn/conn_str.go @@ -711,11 +711,11 @@ func splitAdoConnectionStringParts(dsn string) []string { var parts []string var current strings.Builder inQuotes := false - + runes := []rune(dsn) for i := 0; i < len(runes); i++ { char := runes[i] - + if char == '"' { if inQuotes && i+1 < len(runes) && runes[i+1] == '"' { // Double quote escape sequence - add both quotes to current part @@ -735,12 +735,12 @@ func splitAdoConnectionStringParts(dsn string) []string { current.WriteRune(char) } } - + // Add the last part if it's not empty if current.Len() > 0 { parts = append(parts, current.String()) } - + return parts } diff --git a/msdsn/conn_str_test.go b/msdsn/conn_str_test.go index 1889b02f..1fb49122 100644 --- a/msdsn/conn_str_test.go +++ b/msdsn/conn_str_test.go @@ -111,12 +111,12 @@ func TestValidConnectionString(t *testing.T) { {"MultiSubnetFailover=false", func(p Config) bool { return !p.MultiSubnetFailover }}, {"timezone=Asia/Shanghai", func(p Config) bool { return p.Encoding.Timezone.String() == "Asia/Shanghai" }}, {"Pwd=placeholder", func(p Config) bool { return p.Password == "placeholder" }}, - + // ADO connection string tests with double-quoted values containing semicolons {"server=test;password=\"pass;word\"", func(p Config) bool { return p.Host == "test" && p.Password == "pass;word" }}, {"password=\"[2+R2B6O:fF/[;]cJsr\"", func(p Config) bool { return p.Password == "[2+R2B6O:fF/[;]cJsr" }}, - {"server=host;user id=user;password=\"complex;pass=word\"", func(p Config) bool { - return p.Host == "host" && p.User == "user" && p.Password == "complex;pass=word" + {"server=host;user id=user;password=\"complex;pass=word\"", func(p Config) bool { + return p.Host == "host" && p.User == "user" && p.Password == "complex;pass=word" }}, {"password=\"value with \"\"quotes\"\" inside\"", func(p Config) bool { return p.Password == "value with \"quotes\" inside" }}, {"server=test;password=\"simple\"", func(p Config) bool { return p.Host == "test" && p.Password == "simple" }}, @@ -125,19 +125,19 @@ func TestValidConnectionString(t *testing.T) { return p.Host == "sql.database.windows.net" && p.Database == "MyDatabase" && p.User == "testadmin@sql.database.windows.net" && p.Password == "[2+R2B6O:fF/[;]cJsr" }}, // Additional edge cases for double-quoted values - {"password=\"\"", func(p Config) bool { return p.Password == "" }}, // Empty quoted password - {"password=\";\"", func(p Config) bool { return p.Password == ";" }}, // Just a semicolon - {"password=\";;\"", func(p Config) bool { return p.Password == ";;" }}, // Multiple semicolons + {"password=\"\"", func(p Config) bool { return p.Password == "" }}, // Empty quoted password + {"password=\";\"", func(p Config) bool { return p.Password == ";" }}, // Just a semicolon + {"password=\";;\"", func(p Config) bool { return p.Password == ";;" }}, // Multiple semicolons {"server=\"host;name\";password=\"pass;word\"", func(p Config) bool { return p.Host == "host;name" && p.Password == "pass;word" }}, // Multiple quoted values - + // Test cases with multibyte UTF-8 characters - {"password=\"пароль;test\"", func(p Config) bool { return p.Password == "пароль;test" }}, // Cyrillic characters with semicolon + {"password=\"пароль;test\"", func(p Config) bool { return p.Password == "пароль;test" }}, // Cyrillic characters with semicolon {"server=\"服务器;name\";password=\"密码;word\"", func(p Config) bool { return p.Host == "服务器;name" && p.Password == "密码;word" }}, // Chinese characters - {"password=\"🔐;secret;🗝️\"", func(p Config) bool { return p.Password == "🔐;secret;🗝️" }}, // Emoji characters with semicolons - {"user id=\"用户名\";password=\"пароль\"", func(p Config) bool { return p.User == "用户名" && p.Password == "пароль" }}, // Mixed multibyte chars - {"password=\"测试\"\"密码\"\"\"", func(p Config) bool { return p.Password == "测试\"密码\"" }}, // Chinese chars with escaped quotes - {"password=\"café;naïve;résumé\"", func(p Config) bool { return p.Password == "café;naïve;résumé" }}, // Accented characters - + {"password=\"🔐;secret;🗝️\"", func(p Config) bool { return p.Password == "🔐;secret;🗝️" }}, // Emoji characters with semicolons + {"user id=\"用户名\";password=\"пароль\"", func(p Config) bool { return p.User == "用户名" && p.Password == "пароль" }}, // Mixed multibyte chars + {"password=\"测试\"\"密码\"\"\"", func(p Config) bool { return p.Password == "测试\"密码\"" }}, // Chinese chars with escaped quotes + {"password=\"café;naïve;résumé\"", func(p Config) bool { return p.Password == "café;naïve;résumé" }}, // Accented characters + // those are supported currently, but maybe should not be {"someparam", func(p Config) bool { return true }}, {";;=;", func(p Config) bool { return true }}, diff --git a/mssql.go b/mssql.go index eae193eb..18570e30 100644 --- a/mssql.go +++ b/mssql.go @@ -996,10 +996,38 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { if valuer.Valid { return s.makeParam(valuer.Int32) } + case NullDate: + if valuer.Valid { + return s.makeParamExtra(valuer.Date) + } + return s.makeParamExtra(valuer) + case NullDateTime: + if valuer.Valid { + return s.makeParamExtra(valuer.DateTime) + } + return s.makeParamExtra(valuer) + case NullTime: + if valuer.Valid { + return s.makeParamExtra(valuer.Time) + } + return s.makeParamExtra(valuer) + case *NullDate: + if valuer.Valid { + return s.makeParamExtra(valuer.Date) + } + return s.makeParamExtra(*valuer) + case *NullDateTime: + if valuer.Valid { + return s.makeParamExtra(valuer.DateTime) + } + return s.makeParamExtra(*valuer) + case *NullTime: + if valuer.Valid { + return s.makeParamExtra(valuer.Time) + } + return s.makeParamExtra(*valuer) case UniqueIdentifier: case NullUniqueIdentifier: - default: - break case driver.Valuer: // If the value has a non-nil value, call MakeParam on its Value val, e := driver.DefaultParameterConverter.ConvertValue(valuer) @@ -1143,6 +1171,34 @@ func (s *Stmt) makeParam(val driver.Value) (res param, err error) { } else { res.ti.TypeId = typeDateTimeN } + case NullDate: // only null values reach here + res.ti.TypeId = typeDateN + res.ti.Size = 3 + res.buffer = []byte{} + case NullDateTime: // only null values reach here + res.ti.TypeId = typeDateTime2N + res.ti.Scale = 7 + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + 3 + res.buffer = []byte{} + case NullTime: // only null values reach here + res.ti.TypeId = typeTimeN + res.ti.Scale = 7 + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + res.buffer = []byte{} + case *NullDate: // only null values reach here + res.ti.TypeId = typeDateN + res.ti.Size = 3 + res.buffer = []byte{} + case *NullDateTime: // only null values reach here + res.ti.TypeId = typeDateTime2N + res.ti.Scale = 7 + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + 3 + res.buffer = []byte{} + case *NullTime: // only null values reach here + res.ti.TypeId = typeTimeN + res.ti.Scale = 7 + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + res.buffer = []byte{} case driver.Valuer: // We have a custom Valuer implementation with a nil value return s.makeParam(nil) diff --git a/mssql_go19.go b/mssql_go19.go index 6df8c366..ac332c4b 100644 --- a/mssql_go19.go +++ b/mssql_go19.go @@ -71,6 +71,12 @@ func convertInputParameter(val interface{}) (interface{}, error) { return val, nil case civil.Time: return val, nil + case NullDate: + return val, nil + case NullDateTime: + return val, nil + case NullTime: + return val, nil // case *apd.Decimal: // return nil case float32: @@ -188,6 +194,34 @@ func (s *Stmt) makeParamExtra(val driver.Value) (res param, err error) { res.ti.Scale = 7 res.buffer = encodeTime(val.Hour, val.Minute, val.Second, val.Nanosecond, int(res.ti.Scale)) res.ti.Size = len(res.buffer) + case NullDate: + res.ti.TypeId = typeDateN + res.ti.Size = 3 + if val.Valid { + res.buffer = encodeDate(val.Date.In(loc)) + } else { + res.buffer = []byte{} + } + case NullDateTime: + res.ti.TypeId = typeDateTime2N + res.ti.Scale = 7 + if val.Valid { + res.buffer = encodeDateTime2(val.DateTime.In(loc), int(res.ti.Scale)) + res.ti.Size = len(res.buffer) + } else { + res.buffer = []byte{} + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + 3 + } + case NullTime: + res.ti.TypeId = typeTimeN + res.ti.Scale = 7 + if val.Valid { + res.buffer = encodeTime(val.Time.Hour, val.Time.Minute, val.Time.Second, val.Time.Nanosecond, int(res.ti.Scale)) + res.ti.Size = len(res.buffer) + } else { + res.buffer = []byte{} + res.ti.Size = calcTimeSize(int(res.ti.Scale)) + } case sql.Out: res, err = s.makeParam(val.Dest) res.Flags = fByRevValue diff --git a/tvp_go19.go b/tvp_go19.go index 9a71b7de..576f57e5 100644 --- a/tvp_go19.go +++ b/tvp_go19.go @@ -13,6 +13,7 @@ import ( "strings" "time" + "github.com/golang-sql/civil" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -108,6 +109,23 @@ func (tvp TVP) encode(schema, name string, columnStr []columnStruct, tvpFieldInd if tvp.verifyStandardTypeOnNull(buf, tvpVal) { continue } + + // Extract inner value from nullable civil types when they are valid + switch v := tvpVal.(type) { + case NullDate: + if v.Valid { + tvpVal = v.Date + } + case NullDateTime: + if v.Valid { + tvpVal = v.DateTime + } + case NullTime: + if v.Valid { + tvpVal = v.Time + } + } + valOf := reflect.ValueOf(tvpVal) elemKind := field.Kind() if elemKind == reflect.Ptr && valOf.IsNil() { @@ -279,6 +297,12 @@ func (tvp TVP) createZeroType(fieldVal interface{}) interface{} { return defaultInt64 case sql.NullString: return defaultString + case NullDate: + return civil.Date{} + case NullDateTime: + return civil.DateTime{} + case NullTime: + return civil.Time{} } return fieldVal } @@ -310,6 +334,21 @@ func (tvp TVP) verifyStandardTypeOnNull(buf *bytes.Buffer, tvpVal interface{}) b binary.Write(buf, binary.LittleEndian, uint64(_PLP_NULL)) return true } + case NullDate: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case NullDateTime: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } + case NullTime: + if !val.Valid { + binary.Write(buf, binary.LittleEndian, defaultNull) + return true + } } return false } diff --git a/tvp_go19_test.go b/tvp_go19_test.go index 6b8927fb..794ecc71 100644 --- a/tvp_go19_test.go +++ b/tvp_go19_test.go @@ -8,6 +8,7 @@ import ( "testing" "time" + "github.com/golang-sql/civil" "github.com/microsoft/go-mssqldb/msdsn" ) @@ -587,3 +588,166 @@ func TestTVP_encode_WithGuidConversion(t *testing.T) { func TestTVP_encode(t *testing.T) { testTVP_encode(t, false /*guidConversion*/) } + +// TestTVPWithNullCivilTypes tests that nullable civil types work correctly in TVP operations +func TestTVPWithNullCivilTypes(t *testing.T) { + type tvpDataRowNullDateTime struct { + T NullDateTime + } + + type tvpDataRowNullDate struct { + D NullDate + } + + type tvpDataRowNullTime struct { + T NullTime + } + + type tvpDataRowMixed struct { + Date NullDate `tvp:"date_col"` + DateTime NullDateTime `tvp:"datetime_col"` + Time NullTime `tvp:"time_col"` + } + + tests := []struct { + name string + tvpData interface{} + wantErr bool + }{ + { + name: "NullDateTime with Valid=false", + tvpData: []tvpDataRowNullDateTime{ + {T: NullDateTime{Valid: false}}, + }, + wantErr: false, + }, + { + name: "NullDateTime with Valid=true", + tvpData: []tvpDataRowNullDateTime{ + {T: NullDateTime{DateTime: civil.DateTime{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Time: civil.Time{Hour: 16, Minute: 10, Second: 55}}, Valid: true}}, + }, + wantErr: false, + }, + { + name: "NullDate with Valid=false", + tvpData: []tvpDataRowNullDate{ + {D: NullDate{Valid: false}}, + }, + wantErr: false, + }, + { + name: "NullDate with Valid=true", + tvpData: []tvpDataRowNullDate{ + {D: NullDate{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Valid: true}}, + }, + wantErr: false, + }, + { + name: "NullTime with Valid=false", + tvpData: []tvpDataRowNullTime{ + {T: NullTime{Valid: false}}, + }, + wantErr: false, + }, + { + name: "NullTime with Valid=true", + tvpData: []tvpDataRowNullTime{ + {T: NullTime{Time: civil.Time{Hour: 16, Minute: 10, Second: 55}, Valid: true}}, + }, + wantErr: false, + }, + { + name: "Mixed nullable civil types with some null, some valid", + tvpData: []tvpDataRowMixed{ + { + Date: NullDate{Valid: false}, + DateTime: NullDateTime{DateTime: civil.DateTime{Date: civil.Date{Year: 2025, Month: 10, Day: 2}, Time: civil.Time{Hour: 16, Minute: 10, Second: 55}}, Valid: true}, + Time: NullTime{Valid: false}, + }, + { + Date: NullDate{Date: civil.Date{Year: 2025, Month: 12, Day: 25}, Valid: true}, + DateTime: NullDateTime{Valid: false}, + Time: NullTime{Time: civil.Time{Hour: 9, Minute: 30, Second: 0}, Valid: true}, + }, + }, + wantErr: false, + }, + { + name: "User example 1: Empty NullDateTime", + tvpData: []tvpDataRowNullDateTime{ + {T: NullDateTime{}}, // Valid defaults to false + }, + wantErr: false, + }, + { + name: "User example 2: Valid NullDateTime", + tvpData: func() []tvpDataRowNullDateTime { + t1, _ := civil.ParseDateTime("2025-10-02T16:10:55") + return []tvpDataRowNullDateTime{ + {T: NullDateTime{DateTime: t1, Valid: true}}, + } + }(), + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tvp := TVP{ + TypeName: "dbo.TestType", + Value: tt.tvpData, + } + + // Test columnTypes + columnStr, tvpFieldIndexes, err := tvp.columnTypes() + if (err != nil) != tt.wantErr { + t.Errorf("TVP.columnTypes() error = %v, wantErr %v", err, tt.wantErr) + return + } + + if err != nil { + return // Skip encoding test if columnTypes failed + } + + // Test encode + _, err = tvp.encode("dbo", "TestType", columnStr, tvpFieldIndexes, msdsn.EncodeParameters{}) + if (err != nil) != tt.wantErr { + t.Errorf("TVP.encode() error = %v, wantErr %v", err, tt.wantErr) + return + } + }) + } +} + +// TestTVPNullCivilTypesCreateZeroType tests that nullable civil types are handled correctly +// in the createZeroType method when building column type information +func TestTVPNullCivilTypesCreateZeroType(t *testing.T) { + type tvpDataRowMixed struct { + Date NullDate `tvp:"date_col"` + DateTime NullDateTime `tvp:"datetime_col"` + Time NullTime `tvp:"time_col"` + } + + tvp := TVP{ + TypeName: "dbo.TestType", + Value: []tvpDataRowMixed{ + {}, // Empty struct to trigger createZeroType for all fields + }, + } + + // Test that we can get column types without error + columnStr, tvpFieldIndexes, err := tvp.columnTypes() + if err != nil { + t.Errorf("TVP.columnTypes() with empty struct failed: %v", err) + return + } + + // Should have 3 columns for the 3 fields + if len(columnStr) != 3 { + t.Errorf("Expected 3 columns, got %d", len(columnStr)) + } + + if len(tvpFieldIndexes) != 3 { + t.Errorf("Expected 3 field indexes, got %d", len(tvpFieldIndexes)) + } +}