Skip to content

Commit 31e76e5

Browse files
committed
Merge branch 'rr/auth-session-support' into devel
* rr/auth-session-support: change way to remove params fix lint add test fix pop missing key fix get api_config before params assignment clean request payload fix session object not getting through add test add test for api_config fix null access api_config issue add AuthorizedSession support
2 parents 1e81355 + 24390ef commit 31e76e5

11 files changed

+288
-22
lines changed

nasdaqdatalink/__init__.py

+1
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@
1010
from .model.point_in_time import PointInTime
1111
from .model.data import Data
1212
from .model.merged_dataset import MergedDataset
13+
from .model.authorized_session import AuthorizedSession
1314
from .get import get
1415
from .bulkdownload import bulkdownload
1516
from .export_table import export_table

nasdaqdatalink/api_config.py

+23
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,18 @@ class ApiConfig:
1717
retry_status_codes = [429] + list(range(500, 512))
1818
verify_ssl = True
1919

20+
def read_key(self, filename=None):
21+
if not os.path.isfile(filename):
22+
raise_empty_file(filename)
23+
24+
with open(filename, 'r') as f:
25+
apikey = get_first_non_empty(f)
26+
27+
if not apikey:
28+
raise_empty_file(filename)
29+
30+
self.api_key = apikey
31+
2032

2133
def create_file(config_filename):
2234
# Create the file as well as the parent dir if needed.
@@ -102,3 +114,14 @@ def read_key(filename=None):
102114
read_key_from_environment_variable()
103115
elif config_file_exists(filename):
104116
read_key_from_file(filename)
117+
118+
119+
def get_config_from_kwargs(kwargs):
120+
result = ApiConfig
121+
if isinstance(kwargs, dict):
122+
params = kwargs.get('params')
123+
if isinstance(params, dict):
124+
result = params.get('api_config')
125+
if not isinstance(result, ApiConfig):
126+
result = ApiConfig
127+
return result

nasdaqdatalink/connection.py

+24-8
Original file line numberDiff line numberDiff line change
@@ -8,12 +8,17 @@
88

99
from .util import Util
1010
from .version import VERSION
11-
from .api_config import ApiConfig
11+
from .api_config import ApiConfig, get_config_from_kwargs
1212
from nasdaqdatalink.errors.data_link_error import (
1313
DataLinkError, LimitExceededError, InternalServerError,
1414
AuthenticationError, ForbiddenError, InvalidRequestError,
1515
NotFoundError, ServiceUnavailableError)
1616

17+
KW_TO_REMOVE = [
18+
'session',
19+
'api_config'
20+
]
21+
1722

1823
class Connection:
1924
@classmethod
@@ -22,31 +27,37 @@ def request(cls, http_verb, url, **options):
2227
headers = options['headers']
2328
else:
2429
headers = {}
30+
api_config = get_config_from_kwargs(options)
2531

2632
accept_value = 'application/json'
27-
if ApiConfig.api_version:
28-
accept_value += ", application/vnd.data.nasdaq+json;version=%s" % ApiConfig.api_version
33+
if api_config.api_version:
34+
accept_value += ", application/vnd.data.nasdaq+json;version=%s" % api_config.api_version
2935

3036
headers = Util.merge_to_dicts({'accept': accept_value,
3137
'request-source': 'python',
3238
'request-source-version': VERSION}, headers)
33-
if ApiConfig.api_key:
34-
headers = Util.merge_to_dicts({'x-api-token': ApiConfig.api_key}, headers)
39+
if api_config.api_key:
40+
headers = Util.merge_to_dicts({'x-api-token': api_config.api_key}, headers)
3541

3642
options['headers'] = headers
3743

38-
abs_url = '%s/%s' % (ApiConfig.api_base, url)
44+
abs_url = '%s/%s' % (api_config.api_base, url)
3945

4046
return cls.execute_request(http_verb, abs_url, **options)
4147

4248
@classmethod
4349
def execute_request(cls, http_verb, url, **options):
44-
session = cls.get_session()
50+
session = options.get('params', {}).get('session', None)
51+
if session is None:
52+
session = cls.get_session()
53+
54+
api_config = get_config_from_kwargs(options)
4555

56+
cls.options_kw_strip(options)
4657
try:
4758
response = session.request(method=http_verb,
4859
url=url,
49-
verify=ApiConfig.verify_ssl,
60+
verify=api_config.verify_ssl,
5061
**options)
5162
if response.status_code < 200 or response.status_code >= 300:
5263
cls.handle_api_error(response)
@@ -118,3 +129,8 @@ def handle_api_error(cls, resp):
118129
klass = d_klass.get(code_letter, DataLinkError)
119130

120131
raise klass(message, resp.status_code, resp.text, resp.headers, code)
132+
133+
@classmethod
134+
def options_kw_strip(self, options):
135+
for kw in KW_TO_REMOVE:
136+
options.get('params', {}).pop(kw, None)

nasdaqdatalink/get_point_in_time.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@ def get_point_in_time(datatable_code, **options):
2323

2424
data = None
2525
page_count = 0
26+
api_config = options.get('api_config', ApiConfig)
2627
while True:
2728
next_options = copy.deepcopy(options)
2829
next_data = PointInTime(datatable_code, pit=pit_options).data(params=next_options)
@@ -32,10 +33,10 @@ def get_point_in_time(datatable_code, **options):
3233
else:
3334
data.extend(next_data)
3435

35-
if page_count >= ApiConfig.page_limit:
36+
if page_count >= api_config.page_limit:
3637
raise LimitExceededError(
3738
Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code,
38-
ApiConfig.api_key
39+
api_config.api_key
3940
)
4041
)
4142

nasdaqdatalink/get_table.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@ def get_table(datatable_code, **options):
1414

1515
data = None
1616
page_count = 0
17+
api_config = options.get('api_config', ApiConfig)
1718
while True:
1819
next_options = copy.deepcopy(options)
1920
next_data = Datatable(datatable_code).data(params=next_options)
@@ -23,10 +24,10 @@ def get_table(datatable_code, **options):
2324
else:
2425
data.extend(next_data)
2526

26-
if page_count >= ApiConfig.page_limit:
27+
if page_count >= api_config.page_limit:
2728
raise LimitExceededError(
2829
Message.WARN_DATA_LIMIT_EXCEEDED % (datatable_code,
29-
ApiConfig.api_key
30+
api_config.api_key
3031
)
3132
)
3233

+57
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,57 @@
1+
import nasdaqdatalink
2+
from nasdaqdatalink.api_config import ApiConfig
3+
from urllib3.util.retry import Retry
4+
from requests.adapters import HTTPAdapter
5+
import requests
6+
import urllib
7+
8+
9+
def get_retries(api_config=nasdaqdatalink.ApiConfig):
10+
retries = None
11+
if not api_config.use_retries:
12+
return Retry(total=0)
13+
14+
Retry.BACKOFF_MAX = api_config.max_wait_between_retries
15+
retries = Retry(total=api_config.number_of_retries,
16+
connect=api_config.number_of_retries,
17+
read=api_config.number_of_retries,
18+
status_forcelist=api_config.retry_status_codes,
19+
backoff_factor=api_config.retry_backoff_factor,
20+
raise_on_status=False)
21+
return retries
22+
23+
24+
class AuthorizedSession:
25+
def __init__(self, api_config=ApiConfig) -> None:
26+
super(AuthorizedSession, self).__init__()
27+
if not isinstance(api_config, ApiConfig):
28+
api_config = ApiConfig
29+
self._api_config = api_config
30+
self._auth_session = requests.Session()
31+
retries = get_retries(self._api_config)
32+
adapter = HTTPAdapter(max_retries=retries)
33+
self._auth_session.mount(api_config.api_protocol, adapter)
34+
35+
proxies = urllib.request.getproxies()
36+
if proxies is not None:
37+
self._auth_session.proxies.update(proxies)
38+
39+
def get(self, dataset, **kwargs):
40+
nasdaqdatalink.get(dataset, session=self._auth_session,
41+
api_config=self._api_config, **kwargs)
42+
43+
def bulkdownload(self, database, **kwargs):
44+
nasdaqdatalink.bulkdownload(database, session=self._auth_session,
45+
api_config=self._api_config, **kwargs)
46+
47+
def export_table(self, datatable_code, **kwargs):
48+
nasdaqdatalink.export_table(datatable_code, session=self._auth_session,
49+
api_config=self._api_config, **kwargs)
50+
51+
def get_table(self, datatable_code, **options):
52+
nasdaqdatalink.get_table(datatable_code, session=self._auth_session,
53+
api_config=self._api_config, **options)
54+
55+
def get_point_in_time(self, datatable_code, **options):
56+
nasdaqdatalink.get_point_in_time(datatable_code, session=self._auth_session,
57+
api_config=self._api_config, **options)

nasdaqdatalink/model/database.py

+7-6
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@
33
from six.moves.urllib.parse import urlencode, urlparse
44

55
import nasdaqdatalink.model.dataset
6-
from nasdaqdatalink.api_config import ApiConfig
6+
from nasdaqdatalink.api_config import get_config_from_kwargs
77
from nasdaqdatalink.connection import Connection
88
from nasdaqdatalink.errors.data_link_error import DataLinkError
99
from nasdaqdatalink.message import Message
@@ -21,15 +21,16 @@ def get_code_from_meta(cls, metadata):
2121
return metadata['database_code']
2222

2323
def bulk_download_url(self, **options):
24+
api_config = get_config_from_kwargs(options)
2425
url = self._bulk_download_path()
25-
url = ApiConfig.api_base + '/' + url
26+
url = api_config.api_base + '/' + url
2627

2728
if 'params' not in options:
2829
options['params'] = {}
29-
if ApiConfig.api_key:
30-
options['params']['api_key'] = ApiConfig.api_key
31-
if ApiConfig.api_version:
32-
options['params']['api_version'] = ApiConfig.api_version
30+
if api_config.api_key:
31+
options['params']['api_key'] = api_config.api_key
32+
if api_config.api_version:
33+
options['params']['api_version'] = api_config.api_version
3334

3435
if list(options.keys()):
3536
url += '?' + urlencode(options['params'])

nasdaqdatalink/utils/request_type_util.py

+3-2
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,5 @@
11
from urllib.parse import urlencode
2-
from nasdaqdatalink.api_config import ApiConfig
2+
from nasdaqdatalink.api_config import get_config_from_kwargs
33

44

55
class RequestType(object):
@@ -13,7 +13,8 @@ class RequestType(object):
1313
@classmethod
1414
def get_request_type(cls, url, **params):
1515
query_string = urlencode(params['params'])
16-
request_url = '%s/%s/%s' % (ApiConfig.api_base, url, query_string)
16+
api_config = get_config_from_kwargs(params)
17+
request_url = '%s/%s/%s' % (api_config.api_base, url, query_string)
1718
if RequestType.USE_GET_REQUEST and (len(request_url) < cls.MAX_URL_LENGTH_FOR_GET):
1819
return 'get'
1920
else:

test/test_api_config.py

+55
Original file line numberDiff line numberDiff line change
@@ -132,3 +132,58 @@ def test_read_key_from_file_with_tab(self):
132132
def test_read_key_from_file_with_multi_newline(self):
133133
given = "keyfordefaultfile\n\nanotherkey\n"
134134
self._read_key_from_file_helper(given, TEST_DEFAULT_FILE_CONTENTS)
135+
136+
def test_default_instance_will_have_share_values_with_singleton(self):
137+
os.environ['NASDAQ_DATA_LINK_API_KEY'] = 'setinenv'
138+
ApiConfig.api_key = None
139+
read_key()
140+
api_config = ApiConfig()
141+
self.assertEqual(api_config.api_key, "setinenv")
142+
# make sure change in instance will not affect the singleton
143+
api_config.api_key = None
144+
self.assertEqual(ApiConfig.api_key, "setinenv")
145+
146+
def test_get_config_from_kwargs_return_api_config_if_present(self):
147+
api_config = get_config_from_kwargs({
148+
'params': {
149+
'api_config': ApiConfig()
150+
}
151+
})
152+
self.assertTrue(isinstance(api_config, ApiConfig))
153+
154+
def test_get_config_from_kwargs_return_singleton_if_not_present_or_wrong_type(self):
155+
api_config = get_config_from_kwargs(None)
156+
self.assertTrue(issubclass(api_config, ApiConfig))
157+
self.assertFalse(isinstance(api_config, ApiConfig))
158+
api_config = get_config_from_kwargs(1)
159+
self.assertTrue(issubclass(api_config, ApiConfig))
160+
self.assertFalse(isinstance(api_config, ApiConfig))
161+
api_config = get_config_from_kwargs({
162+
'params': None
163+
})
164+
self.assertTrue(issubclass(api_config, ApiConfig))
165+
self.assertFalse(isinstance(api_config, ApiConfig))
166+
167+
def test_instance_read_key_should_raise_error(self):
168+
api_config = ApiConfig()
169+
with self.assertRaises(TypeError):
170+
api_config.read_key(None)
171+
with self.assertRaises(ValueError):
172+
api_config.read_key('')
173+
174+
def test_instance_read_key_should_raise_error_when_empty(self):
175+
save_key("", TEST_KEY_FILE)
176+
api_config = ApiConfig()
177+
with self.assertRaises(ValueError):
178+
# read empty file
179+
api_config.read_key(TEST_KEY_FILE)
180+
181+
def test_instance_read_the_right_key(self):
182+
expected_key = 'ilovepython'
183+
save_key(expected_key, TEST_KEY_FILE)
184+
api_config = ApiConfig()
185+
api_config.api_key = ''
186+
api_config.read_key(TEST_KEY_FILE)
187+
self.assertEqual(ApiConfig.api_key, expected_key)
188+
189+

test/test_authorized_session.py

+64
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
import unittest
2+
from nasdaqdatalink.model.authorized_session import AuthorizedSession
3+
from nasdaqdatalink.api_config import ApiConfig
4+
from requests.sessions import Session
5+
from requests.adapters import HTTPAdapter
6+
from mock import patch
7+
8+
9+
class AuthorizedSessionTest(unittest.TestCase):
10+
def test_authorized_session_assign_correct_internal_config(self):
11+
authed_session = AuthorizedSession()
12+
self.assertTrue(issubclass(authed_session._api_config, ApiConfig))
13+
authed_session = AuthorizedSession(None)
14+
self.assertTrue(issubclass(authed_session._api_config, ApiConfig))
15+
api_config = ApiConfig()
16+
authed_session = AuthorizedSession(api_config)
17+
self.assertTrue(isinstance(authed_session._api_config, ApiConfig))
18+
19+
def test_authorized_session_pass_created_session(self):
20+
ApiConfig.use_retries = True
21+
ApiConfig.number_of_retries = 130
22+
authed_session = AuthorizedSession()
23+
self.assertTrue(isinstance(authed_session._auth_session, Session))
24+
adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol)
25+
self.assertTrue(isinstance(adapter, HTTPAdapter))
26+
self.assertEqual(adapter.max_retries.connect, 130)
27+
28+
@patch("nasdaqdatalink.get")
29+
def test_call_get_with_session_and_api_config(self, mock):
30+
api_config = ApiConfig()
31+
authed_session = AuthorizedSession(api_config)
32+
authed_session.get('WIKI/AAPL')
33+
mock.assert_called_with('WIKI/AAPL', api_config=api_config,
34+
session=authed_session._auth_session)
35+
36+
@patch("nasdaqdatalink.bulkdownload")
37+
def test_call_bulkdownload_with_session_and_api_config(self, mock):
38+
api_config = ApiConfig()
39+
authed_session = AuthorizedSession(api_config)
40+
authed_session.bulkdownload('NSE')
41+
mock.assert_called_with('NSE', api_config=api_config,
42+
session=authed_session._auth_session)
43+
44+
@patch("nasdaqdatalink.export_table")
45+
def test_call_export_table_with_session_and_api_config(self, mock):
46+
authed_session = AuthorizedSession()
47+
authed_session.export_table('WIKI/AAPL')
48+
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
49+
session=authed_session._auth_session)
50+
51+
@patch("nasdaqdatalink.get_table")
52+
def test_call_get_table_with_session_and_api_config(self, mock):
53+
authed_session = AuthorizedSession()
54+
authed_session.get_table('WIKI/AAPL')
55+
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
56+
session=authed_session._auth_session)
57+
58+
@patch("nasdaqdatalink.get_point_in_time")
59+
def test_call_get_point_in_time_with_session_and_api_config(self, mock):
60+
authed_session = AuthorizedSession()
61+
authed_session.get_point_in_time('DATABASE/CODE', interval='asofdate', date='2020-01-01')
62+
mock.assert_called_with('DATABASE/CODE', interval='asofdate',
63+
date='2020-01-01', api_config=ApiConfig,
64+
session=authed_session._auth_session)

0 commit comments

Comments
 (0)