diff --git a/nasdaqdatalink/connection.py b/nasdaqdatalink/connection.py index 2338a5f..718db1e 100644 --- a/nasdaqdatalink/connection.py +++ b/nasdaqdatalink/connection.py @@ -13,103 +13,109 @@ AuthenticationError, ForbiddenError, InvalidRequestError, NotFoundError, ServiceUnavailableError) +# global session +session = None -class Connection: - @classmethod - def request(cls, http_verb, url, **options): - if 'headers' in options: - headers = options['headers'] + +def request(http_verb, url, **options): + if 'headers' in options: + headers = options['headers'] + else: + headers = {} + + accept_value = 'application/json' + if ApiConfig.api_version: + accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version + + headers = Util.merge_to_dicts({ + 'accept': accept_value, + 'request-source': 'python', + 'request-source-version': VERSION + }, headers) + if ApiConfig.api_key: + headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) + + options['headers'] = headers + + abs_url = '%s/%s' % (ApiConfig.api_base, url) + + return execute_request(http_verb, abs_url, **options) + + +def execute_request(http_verb, url, **options): + session = get_session() + + try: + response = session.request( + method=http_verb, + url=url, + verify=ApiConfig.verify_ssl, + **options + ) + if response.status_code < 200 or response.status_code >= 300: + handle_api_error(response) else: - headers = {} - - accept_value = 'application/json' - if ApiConfig.api_version: - accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version - - headers = Util.merge_to_dicts({'accept': accept_value, - 'request-source': 'python', - 'request-source-version': VERSION}, headers) - if ApiConfig.api_key: - headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers) - - options['headers'] = headers - - abs_url = '%s/%s' % (ApiConfig.api_base, url) - - return cls.execute_request(http_verb, abs_url, **options) - - @classmethod - def execute_request(cls, http_verb, url, **options): - session = cls.get_session() - - try: - response = session.request(method=http_verb, - url=url, - verify=ApiConfig.verify_ssl, - **options) - if response.status_code < 200 or response.status_code >= 300: - cls.handle_api_error(response) - else: - return response - except requests.exceptions.RequestException as e: - if e.response: - cls.handle_api_error(e.response) - raise e - - @classmethod - def get_session(cls): + return response + except requests.exceptions.RequestException as e: + if e.response: + handle_api_error(e.response) + raise e + + +def get_retries(): + if not ApiConfig.use_retries: + return Retry(total=0) + + Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries + retries = Retry(total=ApiConfig.number_of_retries, + connect=ApiConfig.number_of_retries, + read=ApiConfig.number_of_retries, + status_forcelist=ApiConfig.retry_status_codes, + backoff_factor=ApiConfig.retry_backoff_factor, + raise_on_status=False) + + return retries + + +def get_session(): + global session + if session is None: session = requests.Session() - adapter = HTTPAdapter(max_retries=cls.get_retries()) + adapter = HTTPAdapter(max_retries=get_retries()) session.mount(ApiConfig.api_protocol, adapter) - - return session - - @classmethod - def get_retries(cls): - if not ApiConfig.use_retries: - return Retry(total=0) - - Retry.BACKOFF_MAX = ApiConfig.max_wait_between_retries - retries = Retry(total=ApiConfig.number_of_retries, - connect=ApiConfig.number_of_retries, - read=ApiConfig.number_of_retries, - status_forcelist=ApiConfig.retry_status_codes, - backoff_factor=ApiConfig.retry_backoff_factor, - raise_on_status=False) - - return retries - - @classmethod - def parse(cls, response): - try: - return response.json() - except ValueError: - raise DataLinkError(http_status=response.status_code, http_body=response.text) - - @classmethod - def handle_api_error(cls, resp): - error_body = cls.parse(resp) - - # if our app does not form a proper data_link_error response - # throw generic error - if 'error' not in error_body: - raise DataLinkError(http_status=resp.status_code, http_body=resp.text) - - code = error_body['error']['code'] - message = error_body['error']['message'] - prog = re.compile('^QE([a-zA-Z])x') - if prog.match(code): - code_letter = prog.match(code).group(1) - - d_klass = { - 'L': LimitExceededError, - 'M': InternalServerError, - 'A': AuthenticationError, - 'P': ForbiddenError, - 'S': InvalidRequestError, - 'C': NotFoundError, - 'X': ServiceUnavailableError - } - klass = d_klass.get(code_letter, DataLinkError) - - raise klass(message, resp.status_code, resp.text, resp.headers, code) + return session + + +def parse(response): + try: + return response.json() + except ValueError: + raise DataLinkError(http_status=response.status_code, http_body=response.text) + + +def handle_api_error(resp): + error_body = parse(resp) + + # if our app does not form a proper data_link_error response + # throw generic error + if 'error' not in error_body: + raise DataLinkError(http_status=resp.status_code, http_body=resp.text) + + code = error_body['error']['code'] + message = error_body['error']['message'] + prog = re.compile('^QE([a-zA-Z])x') + if prog.match(code): + code_letter = prog.match(code).group(1) + + d_klass = { + 'L': LimitExceededError, + 'M': InternalServerError, + 'A': AuthenticationError, + 'P': ForbiddenError, + 'S': InvalidRequestError, + 'C': NotFoundError, + 'X': ServiceUnavailableError + } + klass = d_klass.get(code_letter, DataLinkError) + + raise klass(message, resp.status_code, resp.text, resp.headers, code) diff --git a/nasdaqdatalink/model/database.py b/nasdaqdatalink/model/database.py index 870dedc..fbf9e73 100644 --- a/nasdaqdatalink/model/database.py +++ b/nasdaqdatalink/model/database.py @@ -4,7 +4,7 @@ import nasdaqdatalink.model.dataset from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -43,7 +43,7 @@ def bulk_download_to_file(self, file_or_folder_path, **options): path_url = self._bulk_download_path() options['stream'] = True - r = Connection.request('get', path_url, **options) + r = connection.request('get', path_url, **options) file_path = file_or_folder_path if os.path.isdir(file_or_folder_path): file_path = file_or_folder_path + '/' + os.path.basename(urlparse(r.url).path) diff --git a/nasdaqdatalink/model/datatable.py b/nasdaqdatalink/model/datatable.py index 2edadb8..935590e 100644 --- a/nasdaqdatalink/model/datatable.py +++ b/nasdaqdatalink/model/datatable.py @@ -3,7 +3,7 @@ from six.moves.urllib.request import urlopen -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import DataLinkError from nasdaqdatalink.message import Message from nasdaqdatalink.operations.get import GetOperation @@ -51,7 +51,7 @@ def _request_file_info(self, file_or_folder_path, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, url, **updated_options) + r = connection.request(request_type, url, **updated_options) response_data = r.json() diff --git a/nasdaqdatalink/operations/get.py b/nasdaqdatalink/operations/get.py index 8f93b95..efe3a26 100644 --- a/nasdaqdatalink/operations/get.py +++ b/nasdaqdatalink/operations/get.py @@ -1,7 +1,7 @@ from inflection import singularize from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util @@ -21,7 +21,7 @@ def __get_raw_data__(self): path = Util.constructed_path(cls.get_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) self._raw_data = response_data[singularize(cls.lookup_key())] diff --git a/nasdaqdatalink/operations/list.py b/nasdaqdatalink/operations/list.py index 6aa020a..6e94e78 100644 --- a/nasdaqdatalink/operations/list.py +++ b/nasdaqdatalink/operations/list.py @@ -1,5 +1,5 @@ from .operation import Operation -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.util import Util from nasdaqdatalink.model.paginated_list import PaginatedList from nasdaqdatalink.utils.request_type_util import RequestType @@ -12,7 +12,7 @@ def all(cls, **options): if 'params' not in options: options['params'] = {} path = Util.constructed_path(cls.list_path(), options['params']) - r = Connection.request('get', path, **options) + r = connection.request('get', path, **options) response_data = r.json() Util.convert_to_dates(response_data) resource = cls.create_list_from_response(response_data) @@ -27,7 +27,7 @@ def page(cls, datatable, **options): updated_options = Util.convert_options(request_type=request_type, **options) - r = Connection.request(request_type, path, **updated_options) + r = connection.request(request_type, path, **updated_options) response_data = r.json() Util.convert_to_dates(response_data) diff --git a/test/test_connection.py b/test/test_connection.py index 96d8380..384d1e0 100644 --- a/test/test_connection.py +++ b/test/test_connection.py @@ -1,4 +1,4 @@ -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from nasdaqdatalink.errors.data_link_error import ( DataLinkError, LimitExceededError, InternalServerError, @@ -42,7 +42,7 @@ def test_nasdaqdatalink_exceptions_no_retries(self, request_method): for expected_error in data_link_errors: self.assertRaises( - expected_error[2], lambda: Connection.request(request_method, 'databases')) + expected_error[2], lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_parse_error(self, request_method): @@ -51,7 +51,7 @@ def test_parse_error(self, request_method): "https://data.nasdaq.com/api/v3/databases", body="not json", status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) def test_non_data_link_error(self, request_method): @@ -62,16 +62,16 @@ def test_non_data_link_error(self, request_method): {'foobar': {'code': 'blah', 'message': 'something went wrong'}}), status=500) self.assertRaises( - DataLinkError, lambda: Connection.request(request_method, 'databases')) + DataLinkError, lambda: connection.request(request_method, 'databases')) @parameterized.expand(['GET', 'POST']) - @patch('nasdaqdatalink.connection.Connection.execute_request') + @patch('nasdaqdatalink.connection.execute_request') def test_build_request(self, request_method, mock): ApiConfig.api_key = 'api_token' ApiConfig.api_version = '2015-04-09' params = {'per_page': 10, 'page': 2} headers = {'x-custom-header': 'header value'} - Connection.request(request_method, 'databases', headers=headers, params=params) + connection.request(request_method, 'databases', headers=headers, params=params) expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases', headers={'x-custom-header': 'header value', 'x-api-token': 'api_token', @@ -81,3 +81,15 @@ def test_build_request(self, request_method, mock): 'request-source-version': VERSION}, params={'per_page': 10, 'page': 2}) self.assertEqual(mock.call_args, expected) + + def test_session_reuse(self): + session1 = connection.get_session() + session2 = connection.get_session() + areSessionsSame = session1 is session2 + + adapter1 = connection.get_session().get_adapter(ApiConfig.api_protocol) + adapter2 = connection.get_session().get_adapter(ApiConfig.api_protocol) + areAdaptersSame = adapter1 is adapter2 + + self.assertEqual(areAdaptersSame, True) + self.assertEqual(areSessionsSame, True) diff --git a/test/test_data.py b/test/test_data.py index 7852dbe..53817a1 100644 --- a/test/test_data.py +++ b/test/test_data.py @@ -77,7 +77,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection(self, mock): Data.all(params={'database_code': 'NSE', 'dataset_code': 'OIL'}) expected = call('get', 'datasets/NSE/OIL/data', params={}) diff --git a/test/test_database.py b/test/test_database.py index bbae558..7b38b98 100644 --- a/test/test_database.py +++ b/test/test_database.py @@ -7,7 +7,7 @@ from six.moves.urllib.parse import parse_qs, urlparse from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.errors.data_link_error import (InternalServerError, DataLinkError) from nasdaqdatalink.model.database import Database from test.factories.database import DatabaseFactory @@ -34,7 +34,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_database_calls_connection(self, mock): database = Database('NSE') database.data_fields() @@ -80,7 +80,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_databases_calls_connection(self, mock): Database.all() expected = call('get', 'databases', params={}) @@ -148,7 +148,7 @@ def test_get_bulk_download_url_without_download_type(self): def test_bulk_download_to_fileaccepts_download_type(self): m = mock_open() - with patch.object(Connection, 'request') as mock_method: + with patch.object(connection, 'request') as mock_method: mock_method.return_value.url = 'https://www.blah.com/download/db.zip' with patch('nasdaqdatalink.model.database.open', m, create=True): self.database.bulk_download_to_file( diff --git a/test/test_dataset.py b/test/test_dataset.py index c44ea65..aed9b8a 100644 --- a/test/test_dataset.py +++ b/test/test_dataset.py @@ -30,7 +30,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_dataset_calls_connection(self, mock): d = Dataset('NSE/OIL') d.data_fields() @@ -84,7 +84,7 @@ def tearDownClass(cls): httpretty.disable() httpretty.reset() - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datasets_calls_connection(self, mock): Dataset.all() expected = call('get', 'datasets', params={}) diff --git a/test/test_datatable.py b/test/test_datatable.py index ab80194..ff5525b 100644 --- a/test/test_datatable.py +++ b/test/test_datatable.py @@ -37,26 +37,26 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_metadata_calls_connection(self, mock): Datatable('ZACKS/FC').data_fields() expected = call('get', 'datatables/ZACKS/FC/metadata', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_get_request(self, mock): Datatable('ZACKS/FC').data() expected = call('get', 'datatables/ZACKS/FC', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_data_calls_connection_with_no_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False Datatable('ZACKS/FC').data() expected = call('post', 'datatables/ZACKS/FC', json={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_get_request(self, mock): params = {'ticker': ['AAPL', 'MSFT'], 'per_end_date': {'gte': '2015-01-01'}, @@ -76,7 +76,7 @@ def test_datatable_calls_connection_with_params_for_get_request(self, mock): expected = call('get', 'datatables/ZACKS/FC', params=expected_params) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_calls_connection_with_params_for_post_request(self, mock): RequestType.USE_GET_REQUEST = False params = {'ticker': ['AAPL', 'MSFT'], diff --git a/test/test_datatable_data.py b/test/test_datatable_data.py index 7ba53f3..b837fc9 100644 --- a/test/test_datatable_data.py +++ b/test/test_datatable_data.py @@ -83,7 +83,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_get(self, mock): datatable = Datatable('ZACKS/FC') Data.page(datatable, params={'ticker': ['AAPL', 'MSFT'], @@ -95,7 +95,7 @@ def test_data_calls_connection_get(self, mock): 'qopts.columns[]': ['ticker', 'per_end_date']}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_data_calls_connection_post(self, mock): RequestType.USE_GET_REQUEST = False datatable = Datatable('ZACKS/FC') diff --git a/test/test_get.py b/test/test_get.py index 950c5c5..66f2ba6 100644 --- a/test/test_get.py +++ b/test/test_get.py @@ -8,7 +8,7 @@ from nasdaqdatalink.model.merged_dataset import MergedDataset from nasdaqdatalink.get import get from nasdaqdatalink.api_config import ApiConfig -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection class GetSingleDatasetTest(unittest.TestCase): @@ -36,8 +36,8 @@ def test_returns_numpys_when_requested(self): self.assertIsInstance(result, numpy.core.records.recarray) def test_setting_api_key_config(self): - mock_connection = Mock(wraps=Connection) - with patch('nasdaqdatalink.connection.Connection.execute_request', + mock_connection = Mock(wraps=connection) + with patch('nasdaqdatalink.connection.execute_request', new=mock_connection.execute_request) as mock: ApiConfig.api_key = 'api_key_configured' get('NSE/OIL') diff --git a/test/test_get_point_in_time_data.py b/test/test_get_point_in_time_data.py index 8fa57f7..ef0b236 100644 --- a/test/test_get_point_in_time_data.py +++ b/test/test_get_point_in_time_data.py @@ -27,7 +27,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_point_in_time_returns_data_frame_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_point_in_time( @@ -36,7 +36,7 @@ def test_get_point_in_time_returns_data_frame_object(self, mock): self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate', date='2020-01-01') @@ -44,7 +44,7 @@ def test_asofdate_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -54,7 +54,7 @@ def test_asofdate_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_without_date(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='asofdate') @@ -62,7 +62,7 @@ def test_asofdate_call_without_date(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -75,7 +75,7 @@ def test_from_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -90,7 +90,7 @@ def test_from_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -103,7 +103,7 @@ def test_between_call_connection(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection_with_datetimes(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_point_in_time( @@ -118,7 +118,7 @@ def test_between_call_connection_with_datetimes(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_interval_connection(self, mock): self.assertRaises(InvalidRequestError, lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC')) self.assertRaises( @@ -126,7 +126,7 @@ def test_invalid_interval_connection(self, mock): lambda: nasdaqdatalink.get_point_in_time('ZACKS/FC', interval='nasdaqdatalink') ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_from_connection(self, mock): self.assertRaises( InvalidRequestError, @@ -145,7 +145,7 @@ def test_invalid_from_connection(self, mock): ) ) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_invalid_between_connection(self, mock): self.assertRaises( InvalidRequestError, diff --git a/test/test_get_table.py b/test/test_get_table.py index 8100b66..7f49f84 100644 --- a/test/test_get_table.py +++ b/test/test_get_table.py @@ -37,21 +37,21 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('ZACKS/FC', params={}) self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_datatable_with_code_returns_datatable_object(self, mock): with self.assertWarns(UserWarning): df = nasdaqdatalink.get_table('AR/MWCF', code="ICEP_WAC_Z2017_S") self.assertIsInstance(df, pandas.core.frame.DataFrame) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): with self.assertWarns(UserWarning): nasdaqdatalink.get_table('ZACKS/FC') @@ -59,7 +59,7 @@ def test_get_table_calls_connection_with_no_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False @@ -69,7 +69,7 @@ def test_get_table_calls_connection_with_no_params_for_post_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_get_request(self, mock): with self.assertWarns(UserWarning): params = { @@ -93,7 +93,7 @@ def test_get_table_calls_connection_with_params_for_get_request(self, mock): self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_get_table_calls_connection_with_params_for_post_request(self, mock): with self.assertWarns(UserWarning): RequestType.USE_GET_REQUEST = False diff --git a/test/test_point_in_time.py b/test/test_point_in_time.py index 07918e1..f73e33b 100644 --- a/test/test_point_in_time.py +++ b/test/test_point_in_time.py @@ -26,7 +26,7 @@ def tearDownClass(cls): def tearDown(self): RequestType.USE_GET_REQUEST = True - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_asofdate_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -38,7 +38,7 @@ def test_asofdate_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/asofdate/2020-01-01', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_from_call_connection(self, mock): PointInTime( 'ZACKS/FC', @@ -51,7 +51,7 @@ def test_from_call_connection(self, mock): expected = call('get', 'pit/ZACKS/FC/from/2020-01-01/to/2020-01-02', params={}) self.assertEqual(mock.call_args, expected) - @patch('nasdaqdatalink.connection.Connection.request') + @patch('nasdaqdatalink.connection.request') def test_between_call_connection(self, mock): PointInTime( 'ZACKS/FC', diff --git a/test/test_retries.py b/test/test_retries.py index 3028095..69c7653 100644 --- a/test/test_retries.py +++ b/test/test_retries.py @@ -1,7 +1,7 @@ import unittest import json -from nasdaqdatalink.connection import Connection +import nasdaqdatalink.connection as connection from nasdaqdatalink.api_config import ApiConfig from test.factories.datatable import DatatableFactory from test.helpers.httpretty_extension import httpretty @@ -28,6 +28,8 @@ def tearDown(self): class TestRetries(ModifyRetrySettingsTestCase): def setUp(self): + # reset session to None before every test + connection.session = None ApiConfig.use_retries = True super(TestRetries, self).setUp() @@ -47,13 +49,13 @@ def setUpClass(cls): def test_modifying_use_retries(self): ApiConfig.use_retries = False - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, 0) def test_modifying_number_of_retries(self): ApiConfig.number_of_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.total, ApiConfig.number_of_retries) self.assertEqual(retries.connect, ApiConfig.number_of_retries) @@ -62,19 +64,19 @@ def test_modifying_number_of_retries(self): def test_modifying_retry_backoff_factor(self): ApiConfig.retry_backoff_factor = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.backoff_factor, ApiConfig.retry_backoff_factor) def test_modifying_retry_status_codes(self): ApiConfig.retry_status_codes = [1, 2, 3] - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.status_forcelist, ApiConfig.retry_status_codes) def test_modifying_max_wait_between_retries(self): ApiConfig.max_wait_between_retries = 3000 - retries = Connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries + retries = connection.get_session().get_adapter(ApiConfig.api_protocol).max_retries self.assertEqual(retries.BACKOFF_MAX, ApiConfig.max_wait_between_retries) @httpretty.enabled @@ -87,7 +89,7 @@ def test_correct_response_returned_if_retries_succeed(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - response = Connection.request('get', 'databases') + response = connection.request('get', 'databases') self.assertEqual(response.json(), self.datatable) self.assertEqual(response.status_code, self.success_response.status) @@ -100,7 +102,7 @@ def test_correct_response_exception_raised_if_retries_fail(self): "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases') @httpretty.enabled def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes(self): @@ -110,4 +112,4 @@ def test_correct_response_exception_raised_for_errors_not_in_retry_status_codes( "https://data.nasdaq.com/api/v3/databases", responses=mock_responses) - self.assertRaises(InternalServerError, Connection.request, 'get', 'databases') + self.assertRaises(InternalServerError, connection.request, 'get', 'databases')