Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Prev Previous commit
Next Next commit
add test and fix lint
  • Loading branch information
runawaycoast committed Aug 16, 2022
commit 1a990da91aca7aa89ecb72da114367cbf67ae33f
26 changes: 12 additions & 14 deletions nasdaqdatalink/model/authorized_session.py
Original file line number Diff line number Diff line change
@@ -1,16 +1,12 @@
import nasdaqdatalink
from nasdaqdatalink.api_config import ApiConfig
from nasdaqdatalink.get import get
from nasdaqdatalink.bulkdownload import bulkdownload
from nasdaqdatalink.export_table import export_table
from nasdaqdatalink.get_table import get_table
from nasdaqdatalink.get_point_in_time import get_point_in_time
from urllib3.util.retry import Retry
from requests.adapters import HTTPAdapter
import requests
import urllib


def get_retries(api_config=ApiConfig):
def get_retries(api_config=nasdaqdatalink.ApiConfig):
retries = None
if not api_config.use_retries:
return Retry(total=0)
Expand Down Expand Up @@ -41,19 +37,21 @@ def __init__(self, api_config=ApiConfig) -> None:
self._auth_session.proxies.update(proxies)

def get(self, dataset, **kwargs):
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Outside of this tickets scope, but I would like to see us passing connection details (session, api_config) as explicit arguments, and leave kwargs to params we should send to the API call. This would allow us to avoid all the kwarg popping down the line

get(dataset, session=self._auth_session, api_config=self._api_config, **kwargs)
nasdaqdatalink.get(dataset, session=self._auth_session,
api_config=self._api_config, **kwargs)

def bulkdownload(self, database, **kwargs):
bulkdownload(database, session=self._auth_session, api_config=self._api_config, **kwargs)
nasdaqdatalink.bulkdownload(database, session=self._auth_session,
api_config=self._api_config, **kwargs)

def export_table(self, datatable_code, **kwargs):
export_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **kwargs)
nasdaqdatalink.export_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **kwargs)

def get_table(self, datatable_code, **options):
get_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)
nasdaqdatalink.get_table(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)

def get_point_in_time(self, datatable_code, **options):
get_point_in_time(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)
nasdaqdatalink.get_point_in_time(datatable_code, session=self._auth_session,
api_config=self._api_config, **options)
43 changes: 41 additions & 2 deletions test/test_authorized_session.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,12 @@
from unittest import TestCase
import unittest
from nasdaqdatalink.model.authorized_session import AuthorizedSession
from nasdaqdatalink.api_config import ApiConfig
from requests.sessions import Session
from requests.adapters import HTTPAdapter
from mock import patch


class AuthorizedSessionTest(TestCase):
class AuthorizedSessionTest(unittest.TestCase):
def test_authorized_session_assign_correct_internal_config(self):
authed_session = AuthorizedSession()
self.assertTrue(issubclass(authed_session._api_config, ApiConfig))
Expand All @@ -23,3 +24,41 @@ def test_authorized_session_pass_created_session(self):
adapter = authed_session._auth_session.get_adapter(ApiConfig.api_protocol)
self.assertTrue(isinstance(adapter, HTTPAdapter))
self.assertEqual(adapter.max_retries.connect, 130)

@patch("nasdaqdatalink.get")
def test_call_get_with_session_and_api_config(self, mock):
api_config = ApiConfig()
authed_session = AuthorizedSession(api_config)
authed_session.get('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=api_config,
session=authed_session._auth_session)

@patch("nasdaqdatalink.bulkdownload")
def test_call_bulkdownload_with_session_and_api_config(self, mock):
api_config = ApiConfig()
authed_session = AuthorizedSession(api_config)
authed_session.bulkdownload('NSE')
mock.assert_called_with('NSE', api_config=api_config,
session=authed_session._auth_session)

@patch("nasdaqdatalink.export_table")
def test_call_export_table_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.export_table('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
session=authed_session._auth_session)

@patch("nasdaqdatalink.get_table")
def test_call_get_table_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.get_table('WIKI/AAPL')
mock.assert_called_with('WIKI/AAPL', api_config=ApiConfig,
session=authed_session._auth_session)

@patch("nasdaqdatalink.get_point_in_time")
def test_call_get_point_in_time_with_session_and_api_config(self, mock):
authed_session = AuthorizedSession()
authed_session.get_point_in_time('DATABASE/CODE', interval='asofdate', date='2020-01-01')
mock.assert_called_with('DATABASE/CODE', interval='asofdate',
date='2020-01-01', api_config=ApiConfig,
session=authed_session._auth_session)
50 changes: 48 additions & 2 deletions test/test_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@
NotFoundError, ServiceUnavailableError)
from test.test_retries import ModifyRetrySettingsTestCase
from test.helpers.httpretty_extension import httpretty
import requests
import json
from mock import patch, call
from nasdaqdatalink.version import VERSION
Expand Down Expand Up @@ -59,8 +60,8 @@ def test_non_data_link_error(self, request_method):
httpretty.register_uri(getattr(httpretty, request_method),
"https://data.nasdaq.com/api/v3/databases",
body=json.dumps(
{'foobar':
{'code': 'blah', 'message': 'something went wrong'}}), status=500)
{'foobar':
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm surprised we weren't hit with a lint warning about indentation / formatting.

I'll need to review the lint settings in our CICD.

{'code': 'blah', 'message': 'something went wrong'}}), status=500)
self.assertRaises(
DataLinkError, lambda: Connection.request(request_method, 'databases'))

Expand All @@ -81,3 +82,48 @@ def test_build_request(self, request_method, mock):
'request-source-version': VERSION},
params={'per_page': 10, 'page': 2})
self.assertEqual(mock.call_args, expected)

@parameterized.expand(['GET', 'POST'])
@patch('nasdaqdatalink.connection.Connection.execute_request')
def test_build_request_with_custom_api_config(self, request_method, mock):
ApiConfig.api_key = 'api_token'
ApiConfig.api_version = '2015-04-09'
api_config = ApiConfig()
api_config.api_key = 'custom_api_token'
api_config.api_version = '2022-06-09'
session = requests.session()
params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session}
headers = {'x-custom-header': 'header value'}
Connection.request(request_method, 'databases', headers=headers, params=params)
expected = call(request_method, 'https://data.nasdaq.com/api/v3/databases',
headers={'x-custom-header': 'header value',
'x-api-token': 'custom_api_token',
'accept': ('application/json, '
'application/vnd.data.nasdaq+json;version=2022-06-09'),
'request-source': 'python',
'request-source-version': VERSION},
params={'per_page': 10, 'page': 2,
'session': session, 'api_config': api_config})
self.assertEqual(mock.call_args, expected)

def test_remove_session_and_api_config_param(self):
ApiConfig.api_key = 'api_token'
ApiConfig.api_version = '2015-04-09'
ApiConfig.verify_ssl = True
api_config = ApiConfig()
api_config.api_key = 'custom_api_token'
api_config.api_version = '2022-06-09'
api_config.verify_ssl = False
session = requests.Session()
params = {'per_page': 10, 'page': 2, 'api_config': api_config, 'session': session}
headers = {'x-custom-header': 'header value'}
dummy_response = requests.Response()
dummy_response.status_code = 200
with patch.object(session, 'request', return_value=dummy_response) as mock:
Connection.execute_request(
'GET', 'https://data.nasdaq.com/api/v3/databases', headers=headers, params=params)
mock.assert_called_once_with(method='GET',
url='https://data.nasdaq.com/api/v3/databases',
verify=False,
headers=headers,
params={'per_page': 10, 'page': 2})