diff --git a/client/column_decoder.go b/client/column_decoder.go index e943541..0898c4f 100644 --- a/client/column_decoder.go +++ b/client/column_decoder.go @@ -95,6 +95,18 @@ func (decoder *Int32ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int32] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT32, DATE: + return NewIntColumn(0, 0, nil, []int32{}) + case FLOAT: + return NewFloatColumn(0, 0, nil, []float32{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -139,6 +151,18 @@ func (decoder *Int64ArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTyp // +---------------+-----------------+-------------+ // | byte | list[byte] | list[int64] | // +---------------+-----------------+-------------+ + + if positionCount == 0 { + switch dataType { + case INT64, TIMESTAMP: + return NewLongColumn(0, 0, nil, []int64{}) + case DOUBLE: + return NewDoubleColumn(0, 0, nil, []float64{}) + default: + return nil, fmt.Errorf("invalid data type: %v", dataType) + } + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -185,6 +209,11 @@ func (decoder *ByteArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataType if dataType != BOOLEAN { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBooleanColumn(0, 0, nil, []bool{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -218,6 +247,11 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if TEXT != dataType { return nil, fmt.Errorf("invalid data type: %v", dataType) } + + if positionCount == 0 { + return NewBinaryColumn(0, 0, nil, []*Binary{}) + } + nullIndicators, err := deserializeNullIndicators(reader, positionCount) if err != nil { return nil, err @@ -232,12 +266,17 @@ func (decoder *BinaryArrayColumnDecoder) ReadColumn(reader *bytes.Reader, dataTy if err != nil { return nil, err } - value := make([]byte, length) - _, err = reader.Read(value) - if err != nil { - return nil, err + + if length == 0 { + values[i] = NewBinary([]byte{}) + } else { + value := make([]byte, length) + _, err = reader.Read(value) + if err != nil { + return nil, err + } + values[i] = NewBinary(value) } - values[i] = NewBinary(value) } return NewBinaryColumn(0, positionCount, nullIndicators, values) } diff --git a/client/column_decoder_test.go b/client/column_decoder_test.go new file mode 100644 index 0000000..3c0ba7d --- /dev/null +++ b/client/column_decoder_test.go @@ -0,0 +1,181 @@ +package client + +import ( + "bytes" + "encoding/binary" + "testing" +) + +func buildNullIndicatorBytes(nulls []bool) []byte { + var buf bytes.Buffer + hasNull := false + for _, n := range nulls { + if n { + hasNull = true + break + } + } + if !hasNull { + buf.WriteByte(0) + return buf.Bytes() + } + buf.WriteByte(1) + packed := make([]byte, (len(nulls)+7)/8) + for i, n := range nulls { + if n { + packed[i/8] |= 0b10000000 >> (uint(i) % 8) + } + } + buf.Write(packed) + return buf.Bytes() +} + +func TestBinaryArrayColumnDecoder_EmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 1) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 1 { + t.Fatalf("expected positionCount=1, got %d", col.GetPositionCount()) + } + if col.IsNull(0) { + t.Fatal("row 0 should not be null") + } + val, err := col.GetBinary(0) + if err != nil { + t.Fatalf("GetBinary(0) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_NullThenEmptyString(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{true, false})) + _ = binary.Write(&buf, binary.BigEndian, int32(0)) + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 2) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if !col.IsNull(0) { + t.Error("row 0 should be null") + } + if col.IsNull(1) { + t.Error("row 1 should not be null") + } + val, err := col.GetBinary(1) + if err != nil { + t.Fatalf("GetBinary(1) failed: %v", err) + } + if len(val.values) != 0 { + t.Fatalf("expected empty string, got %q", string(val.values)) + } +} + +func TestBinaryArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + writeText := func(s string) { + _ = binary.Write(&buf, binary.BigEndian, int32(len(s))) + buf.WriteString(s) + } + writeText("hello") + writeText("world") + + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), TEXT, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetBinary(0); string(v.values) != "hello" { + t.Errorf("row 0: expected \"hello\", got %q", string(v.values)) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetBinary(2); string(v.values) != "world" { + t.Errorf("row 2: expected \"world\", got %q", string(v.values)) + } +} + +func TestInt64ArrayColumnDecoder_WithNull(t *testing.T) { + var buf bytes.Buffer + buf.Write(buildNullIndicatorBytes([]bool{false, true, false})) + _ = binary.Write(&buf, binary.BigEndian, int64(100)) + _ = binary.Write(&buf, binary.BigEndian, int64(200)) + + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(bytes.NewReader(buf.Bytes()), INT64, 3) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.IsNull(0) { + t.Error("row 0 should not be null") + } + if v, _ := col.GetLong(0); v != 100 { + t.Errorf("row 0: expected 100, got %d", v) + } + if !col.IsNull(1) { + t.Error("row 1 should be null") + } + if col.IsNull(2) { + t.Error("row 2 should not be null") + } + if v, _ := col.GetLong(2); v != 200 { + t.Errorf("row 2: expected 200, got %d", v) + } +} + +func TestColumnDecoder_ZeroPositionCount(t *testing.T) { + empty := func() *bytes.Reader { return bytes.NewReader([]byte{}) } + + t.Run("Int32ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int32ArrayColumnDecoder{}).ReadColumn(empty(), INT32, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("Int64ArrayColumnDecoder", func(t *testing.T) { + col, err := (&Int64ArrayColumnDecoder{}).ReadColumn(empty(), INT64, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("ByteArrayColumnDecoder", func(t *testing.T) { + col, err := (&ByteArrayColumnDecoder{}).ReadColumn(empty(), BOOLEAN, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) + + t.Run("BinaryArrayColumnDecoder", func(t *testing.T) { + col, err := (&BinaryArrayColumnDecoder{}).ReadColumn(empty(), TEXT, 0) + if err != nil { + t.Fatalf("ReadColumn failed: %v", err) + } + if col.GetPositionCount() != 0 { + t.Errorf("expected positionCount=0, got %d", col.GetPositionCount()) + } + }) +} diff --git a/client/session.go b/client/session.go index 2cd1e82..0d4084d 100644 --- a/client/session.go +++ b/client/session.go @@ -569,10 +569,12 @@ func (s *Session) ExecuteQueryStatement(sql string, timeoutMs *int64) (*SessionD request.SessionId = s.sessionId request.StatementId = s.requestStatementId resp, err = s.client.ExecuteQueryStatementV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) - } else { - return nil, statusErr + if err == nil && resp != nil { + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet(sql, resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) + } else { + return nil, statusErr + } } } return nil, err @@ -597,10 +599,12 @@ func (s *Session) ExecuteAggregationQuery(paths []string, aggregations []common. if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) - } else { - return nil, statusErr + if err == nil && resp != nil { + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) + } else { + return nil, statusErr + } } } return nil, err @@ -626,10 +630,12 @@ func (s *Session) ExecuteAggregationQueryWithLegalNodes(paths []string, aggregat if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteAggregationQueryV2(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) - } else { - return nil, statusErr + if err == nil && resp != nil { + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) + } else { + return nil, statusErr + } } } return nil, err @@ -653,10 +659,12 @@ func (s *Session) ExecuteFastLastDataQueryForOnePrefixPath(prefixes []string, ti if s.reconnect() { request.SessionId = s.sessionId resp, err = s.client.ExecuteFastLastDataQueryForOnePrefixPath(context.Background(), &request) - if statusErr := VerifySuccess(resp.Status); statusErr == nil { - return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) - } else { - return nil, statusErr + if err == nil && resp != nil { + if statusErr := VerifySuccess(resp.Status); statusErr == nil { + return NewSessionDataSet("", resp.Columns, resp.DataTypeList, resp.ColumnNameIndexMap, *resp.QueryId, s.requestStatementId, s.client, s.sessionId, resp.QueryResult_, resp.IgnoreTimeStamp != nil && *resp.IgnoreTimeStamp, timeoutMs, *resp.MoreData, s.config.FetchSize, s.config.TimeZone, s.timeFactor, resp.GetColumnIndex2TsBlockColumnIndexList()) + } else { + return nil, statusErr + } } } return nil, err diff --git a/client/sessiondataset.go b/client/sessiondataset.go index ef3faba..9a5b414 100644 --- a/client/sessiondataset.go +++ b/client/sessiondataset.go @@ -125,3 +125,7 @@ func (s *SessionDataSet) GetColumnNames() []string { func (s *SessionDataSet) GetColumnTypes() []string { return s.ioTDBRpcDataSet.columnTypeList } + +func (s *SessionDataSet) GetCurrentRowTime() int64 { + return s.ioTDBRpcDataSet.GetCurrentRowTime() +}