@@ -226,6 +226,8 @@ type Greeting struct {
226
226
227
227
// Opts is a way to configure Connection
228
228
type Opts struct {
229
+ // Auth is an authentication method.
230
+ Auth Auth
229
231
// Timeout for response to a particular request. The timeout is reset when
230
232
// push messages are received. If Timeout is zero, any request can be
231
233
// blocked infinitely.
@@ -546,19 +548,40 @@ func (conn *Connection) dial() (err error) {
546
548
547
549
// Auth.
548
550
if opts .User != "" {
549
- scr , err := scramble (conn .Greeting .auth , opts .Pass )
550
- if err != nil {
551
- err = errors .New ("auth: scrambling failure " + err .Error ())
551
+ auth := opts .Auth
552
+ if opts .Auth == AutoAuth {
553
+ if conn .serverProtocolInfo .Auth != AutoAuth {
554
+ auth = conn .serverProtocolInfo .Auth
555
+ } else {
556
+ auth = ChapSha1Auth
557
+ }
558
+ }
559
+
560
+ var req Request
561
+ if auth == ChapSha1Auth {
562
+ salt := conn .Greeting .auth
563
+ req , err = newChapSha1AuthRequest (conn .opts .User , salt , opts .Pass )
564
+ if err != nil {
565
+ return fmt .Errorf ("auth: %w" , err )
566
+ }
567
+ } else if auth == PapSha256Auth {
568
+ if opts .Transport != connTransportSsl {
569
+ return errors .New ("auth: forbidden to use " + auth .String () +
570
+ " unless SSL is enabled for the connection" )
571
+ }
572
+ req = newPapSha256AuthRequest (conn .opts .User , opts .Pass )
573
+ } else {
552
574
connection .Close ()
553
- return err
575
+ return errors . New ( "auth: " + auth . String ())
554
576
}
555
- if err = conn .writeAuthRequest (w , scr ); err != nil {
577
+
578
+ if err = conn .writeRequest (w , req ); err != nil {
556
579
connection .Close ()
557
- return err
580
+ return fmt . Errorf ( "auth: %w" , err )
558
581
}
559
- if err = conn .readAuthResponse (r ); err != nil {
582
+ if _ , err = conn .readResponse (r ); err != nil {
560
583
connection .Close ()
561
- return err
584
+ return fmt . Errorf ( "auth: %w" , err )
562
585
}
563
586
}
564
587
@@ -662,28 +685,6 @@ func (conn *Connection) writeRequest(w *bufio.Writer, req Request) error {
662
685
return err
663
686
}
664
687
665
- func (conn * Connection ) writeAuthRequest (w * bufio.Writer , scramble []byte ) error {
666
- req := newAuthRequest (conn .opts .User , string (scramble ))
667
-
668
- err := conn .writeRequest (w , req )
669
- if err != nil {
670
- return fmt .Errorf ("auth: %w" , err )
671
- }
672
-
673
- return nil
674
- }
675
-
676
- func (conn * Connection ) writeIdRequest (w * bufio.Writer , protocolInfo ProtocolInfo ) error {
677
- req := NewIdRequest (protocolInfo )
678
-
679
- err := conn .writeRequest (w , req )
680
- if err != nil {
681
- return fmt .Errorf ("identify: %w" , err )
682
- }
683
-
684
- return nil
685
- }
686
-
687
688
func (conn * Connection ) readResponse (r io.Reader ) (Response , error ) {
688
689
respBytes , err := conn .read (r )
689
690
if err != nil {
@@ -707,24 +708,6 @@ func (conn *Connection) readResponse(r io.Reader) (Response, error) {
707
708
return resp , nil
708
709
}
709
710
710
- func (conn * Connection ) readAuthResponse (r io.Reader ) error {
711
- _ , err := conn .readResponse (r )
712
- if err != nil {
713
- return fmt .Errorf ("auth: %w" , err )
714
- }
715
-
716
- return nil
717
- }
718
-
719
- func (conn * Connection ) readIdResponse (r io.Reader ) (Response , error ) {
720
- resp , err := conn .readResponse (r )
721
- if err != nil {
722
- return resp , fmt .Errorf ("identify: %w" , err )
723
- }
724
-
725
- return resp , nil
726
- }
727
-
728
711
func (conn * Connection ) createConnection (reconnect bool ) (err error ) {
729
712
var reconnects uint
730
713
for conn .c == nil && conn .state == connDisconnected {
@@ -1625,19 +1608,20 @@ func checkProtocolInfo(expected ProtocolInfo, actual ProtocolInfo) error {
1625
1608
func (conn * Connection ) identify (w * bufio.Writer , r * bufio.Reader ) error {
1626
1609
var ok bool
1627
1610
1628
- werr := conn .writeIdRequest (w , clientProtocolInfo )
1611
+ req := NewIdRequest (clientProtocolInfo )
1612
+ werr := conn .writeRequest (w , req )
1629
1613
if werr != nil {
1630
- return werr
1614
+ return fmt . Errorf ( "identify: %w" , werr )
1631
1615
}
1632
1616
1633
- resp , rerr := conn .readIdResponse (r )
1617
+ resp , rerr := conn .readResponse (r )
1634
1618
if rerr != nil {
1635
1619
if resp .Code == ErrUnknownRequestType {
1636
1620
// IPROTO_ID requests are not supported by server.
1637
1621
return nil
1638
1622
}
1639
1623
1640
- return rerr
1624
+ return fmt . Errorf ( "identify: %w" , rerr )
1641
1625
}
1642
1626
1643
1627
if len (resp .Data ) == 0 {
@@ -1664,5 +1648,7 @@ func (conn *Connection) ServerProtocolInfo() ProtocolInfo {
1664
1648
// supported by Go connection client.
1665
1649
// Since 1.10.0
1666
1650
func (conn * Connection ) ClientProtocolInfo () ProtocolInfo {
1667
- return clientProtocolInfo .Clone ()
1651
+ info := clientProtocolInfo .Clone ()
1652
+ info .Auth = conn .opts .Auth
1653
+ return info
1668
1654
}
0 commit comments