From 0d579060b9df0b3ef296f1f66c566b23eaa9b379 Mon Sep 17 00:00:00 2001 From: sunayanag Date: Sat, 30 Aug 2025 10:47:22 +0530 Subject: [PATCH 1/2] Add automatic expires_at calculation from expires_in This change adds automatic calculation of expires_at from expires_in in the OAuth2 token response. This improves the automatic token refresh functionality by ensuring that expires_at is always set correctly without requiring manual intervention. - Added automatic calculation of expires_at when expires_in is present - Uses float precision for accurate expiration tracking - Maintains backward compatibility with existing code - All tests pass on Python 3.9 Fixes #561 --- requests_oauthlib/oauth2_session.py | 21 ++++++++++++++++++++- tests/test_compliance_fixes.py | 4 ++-- tests/test_oauth2_session.py | 23 +++++++++++++++++++++++ 3 files changed, 45 insertions(+), 3 deletions(-) diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 93cc4d7..203d77d 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -1,4 +1,5 @@ import logging +import time from oauthlib.common import generate_token, urldecode from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError @@ -171,6 +172,23 @@ def authorized(self): """ return bool(self.access_token) + def _add_expires_at(self, token): + """Add expires_at to token if expires_in is present and expires_at is not. + + OAuth2 responses often include expires_in (seconds until token expires) but + oauthlib expects expires_at (timestamp when token expires) for expiration checks. + + :param token: OAuth2 token dict + :return: Token dict with expires_at if expires_in was present + """ + if token and 'expires_in' in token: + # Keep full precision of time.time() for accurate token expiration + # RFC 6749 requires expires_in to be 1*DIGIT, but some providers send it as string + expires_in = int(token['expires_in']) + # Keep float precision for expires_at since it's an internal field + token['expires_at'] = time.time() + expires_in + return token + def authorization_url(self, url, state=None, **kwargs): """Form an authorization URL. @@ -404,7 +422,7 @@ def fetch_token( r = hook(r) self._client.parse_request_body_response(r.text, scope=self.scope) - self.token = self._client.token + self.token = self._add_expires_at(self._client.token) log.debug("Obtained token %s.", self.token) return self.token @@ -494,6 +512,7 @@ def refresh_token( r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) + self.token = self._add_expires_at(self.token) if "refresh_token" not in self.token: log.debug("No new refresh token given. Re-using old.") self.token["refresh_token"] = refresh_token diff --git a/tests/test_compliance_fixes.py b/tests/test_compliance_fixes.py index 9ad6d09..c5166bd 100644 --- a/tests/test_compliance_fixes.py +++ b/tests/test_compliance_fixes.py @@ -115,7 +115,7 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) # Times should be close - approx_expires_at = round(time.time()) + 3600 + approx_expires_at = time.time() + 3600 actual_expires_at = token.pop("expires_at") self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) @@ -289,7 +289,7 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) - approx_expires_at = round(time.time()) + 86400 + approx_expires_at = time.time() + 86400 actual_expires_at = token.pop("expires_at") self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index 7e3e63c..efdf940 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -488,6 +488,29 @@ def test_token_proxy(self): with self.assertRaises(AttributeError): del sess.token + @mock.patch("time.time", new=lambda: fake_time) + def test_add_expires_at_from_expires_in(self): + """Test that expires_at is correctly calculated from expires_in""" + sess = OAuth2Session("someclientid") + now = fake_time + + # Test with expires_in as string (some providers send it this way) + token = {"access_token": "foo", "expires_in": "3600"} + updated_token = sess._add_expires_at(token) + self.assertIn('expires_at', updated_token) + self.assertAlmostEqual(updated_token['expires_at'], now + 3600, places=2) + + # Test with expires_in as integer (spec-compliant format) + token = {"access_token": "foo", "expires_in": 3600} + updated_token = sess._add_expires_at(token) + self.assertIn('expires_at', updated_token) + self.assertAlmostEqual(updated_token['expires_at'], now + 3600, places=2) + + # Test with missing expires_in (should not modify token) + token = {"access_token": "foo"} + updated_token = sess._add_expires_at(token) + self.assertNotIn('expires_at', updated_token) + def test_authorized_false(self): sess = OAuth2Session("someclientid") self.assertFalse(sess.authorized) From 45fbcfebbf9827891f049b2ac027770ff0c63ad4 Mon Sep 17 00:00:00 2001 From: sginji Date: Sun, 31 Aug 2025 08:36:07 +0530 Subject: [PATCH 2/2] Use int() for expires_at calculation - Changed expires_at calculation to use int() for consistent truncation - Added Date header parsing with fallback to time.time() - Updated tests --- requests_oauthlib/oauth2_session.py | 32 ++++++++++++++++++------ tests/test_compliance_fixes.py | 8 +++--- tests/test_oauth2_session.py | 38 +++++++++++++++++------------ 3 files changed, 51 insertions(+), 27 deletions(-) diff --git a/requests_oauthlib/oauth2_session.py b/requests_oauthlib/oauth2_session.py index 203d77d..3df4a9f 100644 --- a/requests_oauthlib/oauth2_session.py +++ b/requests_oauthlib/oauth2_session.py @@ -1,5 +1,7 @@ import logging import time +import calendar +from datetime import datetime from oauthlib.common import generate_token, urldecode from oauthlib.oauth2 import WebApplicationClient, InsecureTransportError @@ -172,21 +174,37 @@ def authorized(self): """ return bool(self.access_token) - def _add_expires_at(self, token): - """Add expires_at to token if expires_in is present and expires_at is not. + def _add_expires_at(self, token, response_date=None): + """Add expires_at to token if expires_in is present. OAuth2 responses often include expires_in (seconds until token expires) but oauthlib expects expires_at (timestamp when token expires) for expiration checks. + Uses response Date header if provided. Falls back to current time if + Date header is not provided or malformed. + :param token: OAuth2 token dict + :param response_date: Optional Date header value from response (e.g. "Thu, 14 Mar 2024 08:30:00 GMT") :return: Token dict with expires_at if expires_in was present """ if token and 'expires_in' in token: - # Keep full precision of time.time() for accurate token expiration # RFC 6749 requires expires_in to be 1*DIGIT, but some providers send it as string expires_in = int(token['expires_in']) - # Keep float precision for expires_at since it's an internal field - token['expires_at'] = time.time() + expires_in + + # Try to use response Date header if provided + if response_date: + try: + # Parse HTTP date format (RFC 7231) + dt = datetime.strptime(response_date, "%a, %d %b %Y %H:%M:%S GMT") + # Convert UTC time tuple to Unix timestamp (returns integer) + token['expires_at'] = calendar.timegm(dt.utctimetuple()) + expires_in + return token + except (TypeError, ValueError): + # Skip if Date header is malformed or uses non-standard format + log.debug("Failed to parse Date header: %s", response_date) + + # Fall back to current time (truncate to second for conservative expiry) + token['expires_at'] = int(time.time()) + expires_in return token def authorization_url(self, url, state=None, **kwargs): @@ -422,7 +440,7 @@ def fetch_token( r = hook(r) self._client.parse_request_body_response(r.text, scope=self.scope) - self.token = self._add_expires_at(self._client.token) + self.token = self._add_expires_at(self._client.token, response_date=r.headers.get('Date')) log.debug("Obtained token %s.", self.token) return self.token @@ -512,7 +530,7 @@ def refresh_token( r = hook(r) self.token = self._client.parse_request_body_response(r.text, scope=self.scope) - self.token = self._add_expires_at(self.token) + self.token = self._add_expires_at(self.token, response_date=r.headers.get('Date')) if "refresh_token" not in self.token: log.debug("No new refresh token given. Re-using old.") self.token["refresh_token"] = refresh_token diff --git a/tests/test_compliance_fixes.py b/tests/test_compliance_fixes.py index c5166bd..aaec230 100644 --- a/tests/test_compliance_fixes.py +++ b/tests/test_compliance_fixes.py @@ -115,9 +115,9 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) # Times should be close - approx_expires_at = time.time() + 3600 + approx_expires_at = int(time.time()) + 3600 actual_expires_at = token.pop("expires_at") - self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) + self.assertEqual(actual_expires_at, approx_expires_at) # Other token values exact self.assertEqual(token, {"access_token": "mailchimp", "expires_in": 3600}) @@ -289,9 +289,9 @@ def test_fetch_access_token(self): authorization_response="https://i.b/?code=hello", ) - approx_expires_at = time.time() + 86400 + approx_expires_at = int(time.time()) + 86400 actual_expires_at = token.pop("expires_at") - self.assertAlmostEqual(actual_expires_at, approx_expires_at, places=2) + self.assertEqual(actual_expires_at, approx_expires_at) self.assertEqual( token, diff --git a/tests/test_oauth2_session.py b/tests/test_oauth2_session.py index efdf940..6f80b4b 100644 --- a/tests/test_oauth2_session.py +++ b/tests/test_oauth2_session.py @@ -40,7 +40,7 @@ def setUp(self): "access_token": "asdfoiw37850234lkjsdfsdf", "refresh_token": "sldvafkjw34509s8dfsdf", "expires_in": 3600, - "expires_at": fake_time + 3600, + "expires_at": int(fake_time) + 3600, } # use someclientid:someclientsecret to easily differentiate between client and user credentials # these are the values used in oauthlib tests @@ -401,10 +401,10 @@ def test_cleans_previous_token_before_fetching_new_one(self): """ new_token = deepcopy(self.token) - past = time.time() - 7200 now = time.time() - self.token["expires_at"] = past - new_token["expires_at"] = now + 3600 + past = now - 7200 + self.token["expires_at"] = int(past) + new_token["expires_at"] = int(now) + 3600 url = "https://example.com/token" with mock.patch("time.time", lambda: now): @@ -492,24 +492,30 @@ def test_token_proxy(self): def test_add_expires_at_from_expires_in(self): """Test that expires_at is correctly calculated from expires_in""" sess = OAuth2Session("someclientid") - now = fake_time + now = int(fake_time) - # Test with expires_in as string (some providers send it this way) - token = {"access_token": "foo", "expires_in": "3600"} + # Test with missing expires_in (should not modify token) + token = {"access_token": "foo"} updated_token = sess._add_expires_at(token) - self.assertIn('expires_at', updated_token) - self.assertAlmostEqual(updated_token['expires_at'], now + 3600, places=2) + self.assertNotIn('expires_at', updated_token) - # Test with expires_in as integer (spec-compliant format) + # Test with Date header + date_str = "Thu, 14 Mar 2024 08:30:00 GMT" token = {"access_token": "foo", "expires_in": 3600} - updated_token = sess._add_expires_at(token) + updated_token = sess._add_expires_at(token, response_date=date_str) self.assertIn('expires_at', updated_token) - self.assertAlmostEqual(updated_token['expires_at'], now + 3600, places=2) + expected_timestamp = 1710405000 + 3600 # 2024-03-14 08:30:00 UTC + 1 hour + self.assertEqual(updated_token['expires_at'], expected_timestamp) - # Test with missing expires_in (should not modify token) - token = {"access_token": "foo"} - updated_token = sess._add_expires_at(token) - self.assertNotIn('expires_at', updated_token) + # Test with malformed Date header + updated_token = sess._add_expires_at(token, response_date="invalid date format") + self.assertIn('expires_at', updated_token) + self.assertEqual(updated_token['expires_at'], now + 3600) + + # Test with missing Date header + updated_token = sess._add_expires_at(token, response_date=None) + self.assertIn('expires_at', updated_token) + self.assertEqual(updated_token['expires_at'], now + 3600) def test_authorized_false(self): sess = OAuth2Session("someclientid")