@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172}
173173
174174// Closes the network connection and unsets internal variables. Do not call this 
175- // function after successfully  authentication, call Close instead. This function 
175+ // function after successful  authentication, call Close instead. This function 
176176// is called before auth or on auth failure because MySQL will have already 
177177// closed the network connection. 
178178func  (mc  * mysqlConn ) cleanup () {
@@ -246,100 +246,172 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246}
247247
248248func  (mc  * mysqlConn ) interpolateParams (query  string , args  []driver.Value ) (string , error ) {
249- 	// Number of ? should be same to len(args) 
250- 	if  strings .Count (query , "?" ) !=  len (args ) {
251- 		return  "" , driver .ErrSkip 
252- 	}
249+ 	noBackslashEscapes  :=  (mc .status  &  statusNoBackslashEscapes ) !=  0 
250+ 	const  (
251+ 		stateNormal  =  iota 
252+ 		stateString 
253+ 		stateEscape 
254+ 		stateEOLComment 
255+ 		stateSlashStarComment 
256+ 		stateBacktick 
257+ 	)
258+ 
259+ 	var  (
260+ 		QUOTE_BYTE          =  byte ('\'' )
261+ 		DBL_QUOTE_BYTE      =  byte ('"' )
262+ 		BACKSLASH_BYTE      =  byte ('\\' )
263+ 		QUESTION_MARK_BYTE  =  byte ('?' )
264+ 		SLASH_BYTE          =  byte ('/' )
265+ 		STAR_BYTE           =  byte ('*' )
266+ 		HASH_BYTE           =  byte ('#' )
267+ 		MINUS_BYTE          =  byte ('-' )
268+ 		LINE_FEED_BYTE      =  byte ('\n' )
269+ 		RADICAL_BYTE        =  byte ('`' )
270+ 	)
253271
254272	buf , err  :=  mc .buf .takeCompleteBuffer ()
255273	if  err  !=  nil  {
256- 		// can not take the buffer. Something must be wrong with the connection 
257274		mc .cleanup ()
258- 		// interpolateParams would be called before sending any query. 
259- 		// So its safe to retry. 
260275		return  "" , driver .ErrBadConn 
261276	}
262277	buf  =  buf [:0 ]
278+ 	state  :=  stateNormal 
279+ 	singleQuotes  :=  false 
280+ 	lastChar  :=  byte (0 )
263281	argPos  :=  0 
264- 
265- 	for  i  :=  0 ; i  <  len (query ); i ++  {
266- 		q  :=  strings .IndexByte (query [i :], '?' )
267- 		if  q  ==  - 1  {
268- 			buf  =  append (buf , query [i :]... )
269- 			break 
270- 		}
271- 		buf  =  append (buf , query [i :i + q ]... )
272- 		i  +=  q 
273- 
274- 		arg  :=  args [argPos ]
275- 		argPos ++ 
276- 
277- 		if  arg  ==  nil  {
278- 			buf  =  append (buf , "NULL" ... )
282+ 	lenQuery  :=  len (query )
283+ 	lastIdx  :=  0 
284+ 
285+ 	for  i  :=  0 ; i  <  lenQuery ; i ++  {
286+ 		currentChar  :=  query [i ]
287+ 		if  state  ==  stateEscape  &&  ! ((currentChar  ==  QUOTE_BYTE  &&  singleQuotes ) ||  (currentChar  ==  DBL_QUOTE_BYTE  &&  ! singleQuotes )) {
288+ 			state  =  stateString 
289+ 			lastChar  =  currentChar 
279290			continue 
280291		}
281- 
282- 		switch  v  :=  arg .(type ) {
283- 		case  int64 :
284- 			buf  =  strconv .AppendInt (buf , v , 10 )
285- 		case  uint64 :
286- 			// Handle uint64 explicitly because our custom ConvertValue emits unsigned values 
287- 			buf  =  strconv .AppendUint (buf , v , 10 )
288- 		case  float64 :
289- 			buf  =  strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
290- 		case  bool :
291- 			if  v  {
292- 				buf  =  append (buf , '1' )
293- 			} else  {
294- 				buf  =  append (buf , '0' )
292+ 		switch  currentChar  {
293+ 		case  STAR_BYTE :
294+ 			if  state  ==  stateNormal  &&  lastChar  ==  SLASH_BYTE  {
295+ 				state  =  stateSlashStarComment 
295296			}
296- 		case  time.Time :
297- 			if  v .IsZero () {
298- 				buf  =  append (buf , "'0000-00-00'" ... )
299- 			} else  {
300- 				buf  =  append (buf , '\'' )
301- 				buf , err  =  appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
302- 				if  err  !=  nil  {
303- 					return  "" , err 
304- 				}
305- 				buf  =  append (buf , '\'' )
297+ 		case  SLASH_BYTE :
298+ 			if  state  ==  stateSlashStarComment  &&  lastChar  ==  STAR_BYTE  {
299+ 				state  =  stateNormal 
306300			}
307- 		case  json.RawMessage :
308- 			buf  =  append (buf , '\'' )
309- 			if  mc .status & statusNoBackslashEscapes  ==  0  {
310- 				buf  =  escapeBytesBackslash (buf , v )
311- 			} else  {
312- 				buf  =  escapeBytesQuotes (buf , v )
301+ 		case  HASH_BYTE :
302+ 			if  state  ==  stateNormal  {
303+ 				state  =  stateEOLComment 
313304			}
314- 			buf  =  append (buf , '\'' )
315- 		case  []byte :
316- 			if  v  ==  nil  {
317- 				buf  =  append (buf , "NULL" ... )
318- 			} else  {
319- 				buf  =  append (buf , "_binary'" ... )
320- 				if  mc .status & statusNoBackslashEscapes  ==  0  {
321- 					buf  =  escapeBytesBackslash (buf , v )
322- 				} else  {
323- 					buf  =  escapeBytesQuotes (buf , v )
324- 				}
325- 				buf  =  append (buf , '\'' )
305+ 		case  MINUS_BYTE :
306+ 			if  state  ==  stateNormal  &&  lastChar  ==  MINUS_BYTE  {
307+ 				state  =  stateEOLComment 
326308			}
327- 		case  string :
328- 			buf  =  append (buf , '\'' )
329- 			if  mc .status & statusNoBackslashEscapes  ==  0  {
330- 				buf  =  escapeStringBackslash (buf , v )
331- 			} else  {
332- 				buf  =  escapeStringQuotes (buf , v )
309+ 		case  LINE_FEED_BYTE :
310+ 			if  state  ==  stateEOLComment  {
311+ 				state  =  stateNormal 
333312			}
334- 			buf  =  append (buf , '\'' )
335- 		default :
336- 			return  "" , driver .ErrSkip 
337- 		}
313+ 		case  DBL_QUOTE_BYTE :
314+ 			if  state  ==  stateNormal  {
315+ 				state  =  stateString 
316+ 				singleQuotes  =  false 
317+ 			} else  if  state  ==  stateString  &&  ! singleQuotes  {
318+ 				state  =  stateNormal 
319+ 			} else  if  state  ==  stateEscape  {
320+ 				state  =  stateString 
321+ 			}
322+ 		case  QUOTE_BYTE :
323+ 			if  state  ==  stateNormal  {
324+ 				state  =  stateString 
325+ 				singleQuotes  =  true 
326+ 			} else  if  state  ==  stateString  &&  singleQuotes  {
327+ 				state  =  stateNormal 
328+ 			} else  if  state  ==  stateEscape  {
329+ 				state  =  stateString 
330+ 			}
331+ 		case  BACKSLASH_BYTE :
332+ 			if  state  ==  stateString  &&  ! noBackslashEscapes  {
333+ 				state  =  stateEscape 
334+ 			}
335+ 		case  QUESTION_MARK_BYTE :
336+ 			if  state  ==  stateNormal  {
337+ 				if  argPos  >=  len (args ) {
338+ 					return  "" , driver .ErrSkip 
339+ 				}
340+ 				buf  =  append (buf , query [lastIdx :i ]... )
341+ 				arg  :=  args [argPos ]
342+ 				argPos ++ 
343+ 
344+ 				if  arg  ==  nil  {
345+ 					buf  =  append (buf , "NULL" ... )
346+ 					lastIdx  =  i  +  1 
347+ 					break 
348+ 				}
349+ 
350+ 				switch  v  :=  arg .(type ) {
351+ 				case  int64 :
352+ 					buf  =  strconv .AppendInt (buf , v , 10 )
353+ 				case  uint64 :
354+ 					buf  =  strconv .AppendUint (buf , v , 10 )
355+ 				case  float64 :
356+ 					buf  =  strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
357+ 				case  bool :
358+ 					if  v  {
359+ 						buf  =  append (buf , '1' )
360+ 					} else  {
361+ 						buf  =  append (buf , '0' )
362+ 					}
363+ 				case  time.Time :
364+ 					if  v .IsZero () {
365+ 						buf  =  append (buf , "'0000-00-00'" ... )
366+ 					} else  {
367+ 						buf  =  append (buf , '\'' )
368+ 						buf , err  =  appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
369+ 						if  err  !=  nil  {
370+ 							return  "" , err 
371+ 						}
372+ 						buf  =  append (buf , '\'' )
373+ 					}
374+ 				case  json.RawMessage :
375+ 					if  noBackslashEscapes  {
376+ 						buf  =  escapeBytesQuotes (buf , v , false )
377+ 					} else  {
378+ 						buf  =  escapeBytesBackslash (buf , v , false )
379+ 					}
380+ 				case  []byte :
381+ 					if  v  ==  nil  {
382+ 						buf  =  append (buf , "NULL" ... )
383+ 					} else  {
384+ 						if  noBackslashEscapes  {
385+ 							buf  =  escapeBytesQuotes (buf , v , true )
386+ 						} else  {
387+ 							buf  =  escapeBytesBackslash (buf , v , true )
388+ 						}
389+ 					}
390+ 				case  string :
391+ 					if  noBackslashEscapes  {
392+ 						buf  =  escapeStringQuotes (buf , v )
393+ 					} else  {
394+ 						buf  =  escapeStringBackslash (buf , v )
395+ 					}
396+ 				default :
397+ 					return  "" , driver .ErrSkip 
398+ 				}
338399
339- 		if  len (buf )+ 4  >  mc .maxAllowedPacket  {
340- 			return  "" , driver .ErrSkip 
400+ 				if  len (buf )+ 4  >  mc .maxAllowedPacket  {
401+ 					return  "" , driver .ErrSkip 
402+ 				}
403+ 				lastIdx  =  i  +  1 
404+ 			}
405+ 		case  RADICAL_BYTE :
406+ 			if  state  ==  stateBacktick  {
407+ 				state  =  stateNormal 
408+ 			} else  if  state  ==  stateNormal  {
409+ 				state  =  stateBacktick 
410+ 			}
341411		}
412+ 		lastChar  =  currentChar 
342413	}
414+ 	buf  =  append (buf , query [lastIdx :]... )
343415	if  argPos  !=  len (args ) {
344416		return  "" , driver .ErrSkip 
345417	}
0 commit comments