|
| 1 | +import itertools |
1 | 2 | import socket
|
2 | 3 | import types
|
| 4 | +from unittest import TestCase |
3 | 5 | from unittest import mock
|
4 |
| -from unittest.mock import patch |
5 |
| - |
| 6 | +from unittest.mock import patch, MagicMock |
6 | 7 | import pytest
|
7 | 8 | import redis
|
8 | 9 | from redis import ConnectionPool, Redis
|
|
13 | 14 | SSLConnection,
|
14 | 15 | UnixDomainSocketConnection,
|
15 | 16 | parse_url,
|
| 17 | + UsernamePasswordCredentialProvider, |
| 18 | + AuthenticationError |
16 | 19 | )
|
17 | 20 | from redis.exceptions import ConnectionError, InvalidResponse, TimeoutError
|
18 | 21 | from redis.retry import Retry
|
@@ -55,7 +58,7 @@ def inner():
|
55 | 58 | # assert mod.get('fookey') == d
|
56 | 59 |
|
57 | 60 |
|
58 |
| -class TestConnection: |
| 61 | +class TestConnection(TestCase): |
59 | 62 | def test_disconnect(self):
|
60 | 63 | conn = Connection()
|
61 | 64 | mock_sock = mock.Mock()
|
@@ -131,6 +134,50 @@ def test_connect_timeout_error_without_retry(self):
|
131 | 134 | assert str(e.value) == "Timeout connecting to server"
|
132 | 135 | self.clear(conn)
|
133 | 136 |
|
| 137 | + @patch.object(Connection, 'send_command') |
| 138 | + @patch.object(Connection, 'read_response') |
| 139 | + def test_on_connect(self, mock_read_response, mock_send_command): |
| 140 | + """Test that the on_connect function sends the correct commands""" |
| 141 | + conn = Connection() |
| 142 | + |
| 143 | + conn._parser = MagicMock() |
| 144 | + conn._parser.on_connect.return_value = None |
| 145 | + conn.credential_provider = None |
| 146 | + conn.username = "myuser" |
| 147 | + conn.password = "password" |
| 148 | + conn.protocol = 3 |
| 149 | + conn.client_name = "test-client" |
| 150 | + conn.lib_name = "test" |
| 151 | + conn.lib_version = "1234" |
| 152 | + conn.db = 0 |
| 153 | + conn.client_cache = True |
| 154 | + |
| 155 | + # command response |
| 156 | + mock_read_response.side_effect = itertools.cycle([ |
| 157 | + b'QUEUED', # MULTI |
| 158 | + b'QUEUED', # HELLO |
| 159 | + b'QUEUED', # AUTH |
| 160 | + b'QUEUED', # CLIENT SETNAME |
| 161 | + b'QUEUED', # CLIENT SETINFO LIB-NAME |
| 162 | + b'QUEUED', # CLIENT SETINFO LIB-VER |
| 163 | + b'QUEUED', # SELECT |
| 164 | + b'QUEUED', # CLIENT TRACKING ON |
| 165 | + [ # EXEC response list |
| 166 | + {"proto": 3, "version": "6"}, |
| 167 | + b'OK', |
| 168 | + b'OK', |
| 169 | + b'OK', |
| 170 | + b'OK', |
| 171 | + b'OK', |
| 172 | + b'OK', |
| 173 | + b'OK' |
| 174 | + ] |
| 175 | + ]) |
| 176 | + |
| 177 | + conn.on_connect() |
| 178 | + |
| 179 | + mock_read_response.side_effect = itertools.repeat("OK") |
| 180 | + |
134 | 181 |
|
135 | 182 | @pytest.mark.onlynoncluster
|
136 | 183 | @pytest.mark.parametrize(
|
|
0 commit comments