Skip to content
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 4 additions & 5 deletions client/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -92,7 +92,7 @@ func (c *Conn) readInitialHandshake() error {
pos += 2

// The upper 2 bytes of the Capabilities Flags
c.capability = uint32(binary.LittleEndian.Uint16(data[pos:pos+2]))<<16 | c.capability
c.capability |= uint32(binary.LittleEndian.Uint16(data[pos:pos+2])) << 16
pos += 2

// length of the combined auth_plugin_data (scramble), if auth_plugin_data_len is > 0
Expand Down Expand Up @@ -209,10 +209,8 @@ func (c *Conn) writeAuthHandshake() error {

// Set default client capabilities that reflect the abilities of this library
capability := mysql.CLIENT_PROTOCOL_41 | mysql.CLIENT_SECURE_CONNECTION |
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH
// Adjust client capability flags based on server support
capability |= c.capability & mysql.CLIENT_LONG_FLAG
capability |= c.capability & mysql.CLIENT_QUERY_ATTRIBUTES
mysql.CLIENT_LONG_PASSWORD | mysql.CLIENT_TRANSACTIONS | mysql.CLIENT_PLUGIN_AUTH |
mysql.CLIENT_LONG_FLAG | mysql.CLIENT_QUERY_ATTRIBUTES | mysql.CLIENT_DEPRECATE_EOF
// Adjust client capability flags on specific client requests
// Only flags that would make any sense setting and aren't handled elsewhere
// in the library are supported here
Expand Down Expand Up @@ -275,6 +273,7 @@ func (c *Conn) writeAuthHandshake() error {
data := make([]byte, length+4)

// capability [32 bit]
c.capability &= capability
data[4] = byte(capability)
data[5] = byte(capability >> 8)
data[6] = byte(capability >> 16)
Expand Down
13 changes: 13 additions & 0 deletions client/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -101,6 +101,18 @@ func (s *clientTestSuite) TestConn_Compress() {
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_NoDeprecateEOF() {
addr := fmt.Sprintf("%s:%s", *test_util.MysqlHost, s.port)
conn, err := Connect(addr, *testUser, *testPassword, "", func(conn *Conn) error {
conn.UnsetCapability(mysql.CLIENT_DEPRECATE_EOF)
return nil
})
require.NoError(s.T(), err)

_, err = conn.Execute("SELECT VERSION()")
require.NoError(s.T(), err)
}

func (s *clientTestSuite) TestConn_SetCapability() {
caps := []uint32{
mysql.CLIENT_LONG_PASSWORD,
Expand All @@ -125,6 +137,7 @@ func (s *clientTestSuite) TestConn_SetCapability() {
mysql.CLIENT_PLUGIN_AUTH,
mysql.CLIENT_CONNECT_ATTRS,
mysql.CLIENT_PLUGIN_AUTH_LENENC_CLIENT_DATA,
mysql.CLIENT_DEPRECATE_EOF,
}

for _, capI := range caps {
Expand Down
4 changes: 2 additions & 2 deletions client/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -252,7 +252,7 @@ func (c *Conn) UnsetCapability(cap uint32) {

// HasCapability returns true if the connection has the specific capability
func (c *Conn) HasCapability(cap uint32) bool {
return c.ccaps&cap > 0
return c.ccaps&cap != 0
}

// UseSSL: use default SSL
Expand Down Expand Up @@ -466,7 +466,7 @@ func (c *Conn) FieldList(table string, wildcard string) ([]*mysql.Field, error)
}

// EOF Packet
if c.isEOFPacket(data) {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can you explain why we change it to len(data) <= 0xffffff?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Will add comment later, bit unavailable this week

https://dev.mysql.com/worklog/task/?id=7766

In case of huge data packet with length greater than 16777216L client will treat it as a data packet and process accordingly.

return fs, nil
}

Expand Down
87 changes: 45 additions & 42 deletions client/resp.go
Original file line number Diff line number Diff line change
Expand Up @@ -14,22 +14,7 @@ import (
"github.com/go-mysql-org/go-mysql/utils"
)

func (c *Conn) readUntilEOF() (err error) {
var data []byte

for {
data, err = c.ReadPacket()
if err != nil {
return err
}

// EOF Packet
if c.isEOFPacket(data) {
return err
}
}
}

// this should only be called when CLIENT_DEPRECATE_EOF not enabled
func (c *Conn) isEOFPacket(data []byte) bool {
return data[0] == mysql.EOF_HEADER && len(data) <= 5
}
Expand Down Expand Up @@ -336,33 +321,16 @@ func (c *Conn) readResultsetStreaming(data []byte, binary bool, result *mysql.Re
}

func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
i := 0
var data []byte

for {
for i := range len(result.Fields) {
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

if i != len(result.Fields) {
err = mysql.ErrMalformPacket
}

return err
}

if result.Fields[i] == nil {
result.Fields[i] = &mysql.Field{}
}
Expand All @@ -372,8 +340,30 @@ func (c *Conn) readResultColumns(result *mysql.Result) (err error) {
}

result.FieldNames[utils.ByteSliceToString(result.Fields[i].Name)] = i
}

i++
if c.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
// EOF Packet
rawPkgLen := len(result.RawPkg)
result.RawPkg, err = c.ReadPacketReuseMem(result.RawPkg)
if err != nil {
return err
}
data = result.RawPkg[rawPkgLen:]

if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}
return nil
} else {
return mysql.ErrMalformPacket
}
} else {
return nil
}
}

Expand All @@ -388,15 +378,21 @@ func (c *Conn) readResultRows(result *mysql.Result, isBinary bool) (err error) {
}
data = result.RawPkg[rawPkgLen:]

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
c.status = result.Status
}

break
}

Expand Down Expand Up @@ -435,9 +431,16 @@ func (c *Conn) readResultRowsStreaming(result *mysql.Result, isBinary bool, perR
return err
}

// EOF Packet
if c.isEOFPacket(data) {
if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
if data[0] == mysql.EOF_HEADER && len(data) <= 0xffffff {
if c.capability&mysql.CLIENT_DEPRECATE_EOF != 0 {
// Treat like OK
affectedRows, _, n := mysql.LengthEncodedInt(data[1:])
insertId, _, m := mysql.LengthEncodedInt(data[1+n:])
result.Status = binary.LittleEndian.Uint16(data[1+n+m:])
result.AffectedRows = affectedRows
result.InsertId = insertId
c.status = result.Status
} else if c.capability&mysql.CLIENT_PROTOCOL_41 > 0 {
result.Warnings = binary.LittleEndian.Uint16(data[1:])
// todo add strict_mode, warning will be treat as error
result.Status = binary.LittleEndian.Uint16(data[3:])
Expand Down
38 changes: 31 additions & 7 deletions client/stmt.go
Original file line number Diff line number Diff line change
Expand Up @@ -265,19 +265,43 @@ func (c *Conn) Prepare(query string) (*Stmt, error) {
s.params = int(binary.LittleEndian.Uint16(data[pos:]))
pos += 2

// warnings
s.warnings = int(binary.LittleEndian.Uint16(data[pos:]))
// pos += 2
// reserved
pos += 1

if len(data) >= 12 {
// warnings
s.warnings = int(binary.LittleEndian.Uint16(data[pos:]))
// pos += 2
}

if s.params > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
for range s.params {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

if s.columns > 0 {
if err := s.conn.readUntilEOF(); err != nil {
return nil, errors.Trace(err)
// TODO process when CLIENT_CACHE_METADATA enabled
for range s.columns {
if _, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
}
}
if s.conn.capability&mysql.CLIENT_DEPRECATE_EOF == 0 {
if packet, err := s.conn.ReadPacket(); err != nil {
return nil, errors.Trace(err)
} else if !c.isEOFPacket(packet) {
return nil, mysql.ErrMalformPacket
}
}
}

Expand Down
Loading