From 803dc7114ac1e9414c98f54e4f3078f8ed5a45dc Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julien=20LE=20M=C3=89NER?= Date: Tue, 28 Mar 2023 11:26:12 +0200 Subject: [PATCH 01/14] feat: Implement generic refresh_user() method, not tested --- oauthenticator/oauth2.py | 107 ++++++++++++++++++++++++++++++++++++++- 1 file changed, 105 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 5daa3574..9129bedd 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -7,6 +7,7 @@ import base64 import json import os +import time import uuid from functools import reduce from inspect import isawaitable @@ -695,6 +696,15 @@ def _client_secret_default(self): return client_secret return os.getenv("OAUTH_CLIENT_SECRET", "") + access_token_expiration_env = "OAUTH_ACCESS_TOKEN_EXPIRATION" + access_token_expiration = Unicode( + config=True, + help="""Default expiration, in seconds, of the access token.""" + ) + + def _access_token_expiration_default(self): + return os.getenv(self.access_token_expiration_env, "3600") + validate_server_cert_env = "OAUTH_TLS_VERIFY" validate_server_cert = Bool( config=True, @@ -992,6 +1002,27 @@ def build_access_tokens_request_params(self, handler, data=None): return params + def build_refresh_token_request_params(self, refresh_token): + """ + Builds the parameters that should be passed to the URL request + that renew Access Token from Refresh Token. + Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`. + """ + params = { + "grant_type": "refresh_token", + "refresh_token": refresh_token, + } + + # the client_id and client_secret should not be included in the access token request params + # when basic authentication is used + # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 + if self.basic_auth: + params.update( + [("client_id", self.client_id), ("client_secret", self.client_secret)] + ) + + return params + async def get_token_info(self, handler, params): """ Makes a "POST" request to `self.token_url`, with the parameters received as argument. @@ -1082,6 +1113,38 @@ async def token_to_user(self, token_info): validate_cert=self.validate_server_cert, ) + def get_access_token_creation_date(self, token_info): + """ + Returns the access token creation date, in seconds (Unix epoch time). + + Example: 1679994631 + + Args: + token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token) + + Returns: + creation_date: a number representing the access token creation date, in seconds (Unix epoch time) + + Called by the :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` + """ + return token_info.get("created_at", time.time()) + + def get_access_token_lifetime(self, token_info): + """ + Returns the access token lifetime, in seconds. + + Example: 7200 + + Args: + token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token) + + Returns: + lifetime: a number representing the access token lifetime, in seconds + + Called by the :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` + """ + return token_info.get("expires_in", self.access_token_expiration) + def build_auth_state_dict(self, token_info, user_info): """ Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call. @@ -1094,6 +1157,8 @@ def build_auth_state_dict(self, token_info, user_info): Returns: auth_state: a dictionary of auth state that should be persisted with the following keys: - "access_token": the access_token + - "created_at": creation date, in seconds, of the access_token + - "expires_in": expiration date, in seconds, of the access_token - "refresh_token": the refresh_token, if available - "id_token": the id_token, if available - "scope": the scopes, if available @@ -1106,8 +1171,10 @@ def build_auth_state_dict(self, token_info, user_info): This method be async. """ - # We know for sure the `access_token` key exists, oterwise we would have errored out already + # We know for sure the `access_token` key exists, otherwise we would have errored out already access_token = token_info["access_token"] + created_at = self.get_access_token_creation_date(token_info) + expires_in = self.get_access_token_lifetime(token_info) refresh_token = token_info.get("refresh_token", None) id_token = token_info.get("id_token", None) @@ -1118,6 +1185,8 @@ def build_auth_state_dict(self, token_info, user_info): return { "access_token": access_token, + "created_at": created_at, + "expires_in": expires_in, "refresh_token": refresh_token, "id_token": id_token, "scope": scope, @@ -1212,8 +1281,42 @@ async def authenticate(self, handler, data=None, **kwargs): """ # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) + # call the oauth endpoints + return await self._oauth_call(handler, access_token_params) + + async def refresh_user(self, user, handler=None, **kwargs): + ''' + Renew the Access Token with a valid Refresh Token + ''' + + auth_state = await user.get_auth_state() + if not auth_state: + self.log.info( + "No auth_state found for user %s refresh, need full authentication", + user, + ) + return False + + created_at = auth_state.get('created_at', 0) + expires_in = auth_state.get('expires_in', 0) + is_expired = created_at + expires_in - time.time() < 0 + if not is_expired: + self.log.info( + "access_token still valid for user %s, skip refresh", + user, + ) + return True + + refresh_token_params = self.build_refresh_token_request_params(auth_state['refresh_token']) + return await self._oauth_call(handler, refresh_token_params) + + async def _oauth_call(self, handler, params, data=None): + """ + Common logic shared by authenticate() and refresh_user() + """ + # exchange the oauth code for an access token and get the JSON with info about it - token_info = await self.get_token_info(handler, access_token_params) + token_info = await self.get_token_info(handler, params) # use the access_token to get userdata info user_info = await self.token_to_user(token_info) # extract the username out of the user_info dict and normalize it From 31f8cbb17aa4ff6ef082ecfcb467b270cf363a3f Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 28 Mar 2023 09:40:40 +0000 Subject: [PATCH 02/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- oauthenticator/oauth2.py | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 9129bedd..d6a9e43b 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -698,8 +698,7 @@ def _client_secret_default(self): access_token_expiration_env = "OAUTH_ACCESS_TOKEN_EXPIRATION" access_token_expiration = Unicode( - config=True, - help="""Default expiration, in seconds, of the access token.""" + config=True, help="""Default expiration, in seconds, of the access token.""" ) def _access_token_expiration_default(self): @@ -1307,7 +1306,9 @@ async def refresh_user(self, user, handler=None, **kwargs): ) return True - refresh_token_params = self.build_refresh_token_request_params(auth_state['refresh_token']) + refresh_token_params = self.build_refresh_token_request_params( + auth_state['refresh_token'] + ) return await self._oauth_call(handler, refresh_token_params) async def _oauth_call(self, handler, params, data=None): From 2c5b30d1a1d26649ff4315ce54ce26eee1d1c185 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Julien=20LE=20M=C3=89NER?= Date: Tue, 28 Mar 2023 12:07:49 +0200 Subject: [PATCH 03/14] Add missing kwargs for test, fix format --- oauthenticator/oauth2.py | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index d6a9e43b..5cf5b8a9 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1281,7 +1281,7 @@ async def authenticate(self, handler, data=None, **kwargs): # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) # call the oauth endpoints - return await self._oauth_call(handler, access_token_params) + return await self._oauth_call(handler, access_token_params, **kwargs) async def refresh_user(self, user, handler=None, **kwargs): ''' @@ -1309,9 +1309,9 @@ async def refresh_user(self, user, handler=None, **kwargs): refresh_token_params = self.build_refresh_token_request_params( auth_state['refresh_token'] ) - return await self._oauth_call(handler, refresh_token_params) + return await self._oauth_call(handler, refresh_token_params, **kwargs) - async def _oauth_call(self, handler, params, data=None): + async def _oauth_call(self, handler, params, data=None, **kwargs): """ Common logic shared by authenticate() and refresh_user() """ From 4e7943bbe94ab3a91d1b213c7149b1826a77f266 Mon Sep 17 00:00:00 2001 From: Georgiana Dolocan Date: Mon, 2 Sep 2024 20:35:30 +0300 Subject: [PATCH 04/14] Refactor and rm the env var support --- oauthenticator/oauth2.py | 65 +++++++++++----------------------------- 1 file changed, 17 insertions(+), 48 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 5cf5b8a9..00c2aab7 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -696,14 +696,16 @@ def _client_secret_default(self): return client_secret return os.getenv("OAUTH_CLIENT_SECRET", "") - access_token_expiration_env = "OAUTH_ACCESS_TOKEN_EXPIRATION" access_token_expiration = Unicode( - config=True, help="""Default expiration, in seconds, of the access token.""" + config=True, + default="3600", + help=""" + If the `expires_in` field is omitted in the OAuth 2.0 token response + then this value will be the default expiration, in seconds, of the + access token. + """ ) - def _access_token_expiration_default(self): - return os.getenv(self.access_token_expiration_env, "3600") - validate_server_cert_env = "OAUTH_TLS_VERIFY" validate_server_cert = Bool( config=True, @@ -1004,7 +1006,8 @@ def build_access_tokens_request_params(self, handler, data=None): def build_refresh_token_request_params(self, refresh_token): """ Builds the parameters that should be passed to the URL request - that renew Access Token from Refresh Token. + to renew the Access Token based on the Refresh Token + Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`. """ params = { @@ -1112,41 +1115,9 @@ async def token_to_user(self, token_info): validate_cert=self.validate_server_cert, ) - def get_access_token_creation_date(self, token_info): - """ - Returns the access token creation date, in seconds (Unix epoch time). - - Example: 1679994631 - - Args: - token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token) - - Returns: - creation_date: a number representing the access token creation date, in seconds (Unix epoch time) - - Called by the :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` - """ - return token_info.get("created_at", time.time()) - - def get_access_token_lifetime(self, token_info): - """ - Returns the access token lifetime, in seconds. - - Example: 7200 - - Args: - token_info: the dictionary returned by the token request (exchanging the OAuth code for an Access Token) - - Returns: - lifetime: a number representing the access token lifetime, in seconds - - Called by the :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` - """ - return token_info.get("expires_in", self.access_token_expiration) - def build_auth_state_dict(self, token_info, user_info): """ - Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call. + Builds the `auth_state` dict that will be returned by a successful `authenticate` method call. May be async (requires oauthenticator >= 17.0). Args: @@ -1172,8 +1143,8 @@ def build_auth_state_dict(self, token_info, user_info): # We know for sure the `access_token` key exists, otherwise we would have errored out already access_token = token_info["access_token"] - created_at = self.get_access_token_creation_date(token_info) - expires_in = self.get_access_token_lifetime(token_info) + created_at = token_info.get("created_at", time.time()) + expires_in = token_info.get("expires_in", self.access_token_expiration) refresh_token = token_info.get("refresh_token", None) id_token = token_info.get("id_token", None) @@ -1284,15 +1255,14 @@ async def authenticate(self, handler, data=None, **kwargs): return await self._oauth_call(handler, access_token_params, **kwargs) async def refresh_user(self, user, handler=None, **kwargs): - ''' + """ Renew the Access Token with a valid Refresh Token - ''' + """ auth_state = await user.get_auth_state() if not auth_state: self.log.info( - "No auth_state found for user %s refresh, need full authentication", - user, + f"No auth_state found for user {user} refresh, need full authentication", ) return False @@ -1301,8 +1271,7 @@ async def refresh_user(self, user, handler=None, **kwargs): is_expired = created_at + expires_in - time.time() < 0 if not is_expired: self.log.info( - "access_token still valid for user %s, skip refresh", - user, + f"The access_token is still valid for user {user}, skipping refresh", ) return True @@ -1311,7 +1280,7 @@ async def refresh_user(self, user, handler=None, **kwargs): ) return await self._oauth_call(handler, refresh_token_params, **kwargs) - async def _oauth_call(self, handler, params, data=None, **kwargs): + async def _oauth_call(self, handler, params, **kwargs): """ Common logic shared by authenticate() and refresh_user() """ From c0fbad443e2cc3318c191bcd2a59407c1be08481 Mon Sep 17 00:00:00 2001 From: Georgiana Dolocan Date: Mon, 2 Sep 2024 20:35:52 +0300 Subject: [PATCH 05/14] Get the conditional right --- oauthenticator/oauth2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 00c2aab7..9ac5416c 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1018,7 +1018,7 @@ def build_refresh_token_request_params(self, refresh_token): # the client_id and client_secret should not be included in the access token request params # when basic authentication is used # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 - if self.basic_auth: + if not self.basic_auth: params.update( [("client_id", self.client_id), ("client_secret", self.client_secret)] ) From 61f1b03eeb31ccd91c38985cd592a68461cb2ae1 Mon Sep 17 00:00:00 2001 From: Georgiana Dolocan Date: Mon, 2 Sep 2024 21:27:54 +0300 Subject: [PATCH 06/14] Add comment link to rfc --- oauthenticator/oauth2.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 9ac5416c..70ebb72d 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1030,7 +1030,8 @@ async def get_token_info(self, handler, params): Makes a "POST" request to `self.token_url`, with the parameters received as argument. Returns: - the JSON response to the `token_url` the request. + the JSON response to the `token_url` the request as described in + https://www.rfc-editor.org/rfc/rfc6749#section-5.1 Called by the :meth:`oauthenticator.OAuthenticator.authenticate` """ From 3ef16dfcfd2d7d4e5bb3528fca24adf1e7c8dd1b Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Mon, 2 Sep 2024 18:51:02 +0000 Subject: [PATCH 07/14] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- oauthenticator/oauth2.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 70ebb72d..0bc01a23 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -703,7 +703,7 @@ def _client_secret_default(self): If the `expires_in` field is omitted in the OAuth 2.0 token response then this value will be the default expiration, in seconds, of the access token. - """ + """, ) validate_server_cert_env = "OAUTH_TLS_VERIFY" From 6156c936f209ff1a0bb32e34a6aa7a2f80e838a0 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 17 Oct 2024 11:06:23 +0200 Subject: [PATCH 08/14] remove duplicate auth expiration config refresh_user should always refresh auth, JupyterHub config already exists to determine expiration --- oauthenticator/oauth2.py | 30 ++---------------------------- 1 file changed, 2 insertions(+), 28 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index f2e6481e..5c37f0db 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -746,16 +746,6 @@ def _client_secret_default(self): return client_secret return os.getenv("OAUTH_CLIENT_SECRET", "") - access_token_expiration = Unicode( - config=True, - default="3600", - help=""" - If the `expires_in` field is omitted in the OAuth 2.0 token response - then this value will be the default expiration, in seconds, of the - access token. - """, - ) - validate_server_cert_env = "OAUTH_TLS_VERIFY" validate_server_cert = Bool( config=True, @@ -1082,7 +1072,7 @@ def build_refresh_token_request_params(self, refresh_token): # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 if not self.basic_auth: params.update( - [("client_id", self.client_id), ("client_secret", self.client_secret)] + {"client_id": self.client_id, "client_secret": self.client_secret,} ) return params @@ -1190,8 +1180,6 @@ def build_auth_state_dict(self, token_info, user_info): Returns: auth_state: a dictionary of auth state that should be persisted with the following keys: - "access_token": the access_token - - "created_at": creation date, in seconds, of the access_token - - "expires_in": expiration date, in seconds, of the access_token - "refresh_token": the refresh_token, if available - "id_token": the id_token, if available - "scope": the scopes, if available @@ -1206,8 +1194,6 @@ def build_auth_state_dict(self, token_info, user_info): # We know for sure the `access_token` key exists, otherwise we would have errored out already access_token = token_info["access_token"] - created_at = token_info.get("created_at", time.time()) - expires_in = token_info.get("expires_in", self.access_token_expiration) refresh_token = token_info.get("refresh_token", None) id_token = token_info.get("id_token", None) @@ -1218,8 +1204,6 @@ def build_auth_state_dict(self, token_info, user_info): return { "access_token": access_token, - "created_at": created_at, - "expires_in": expires_in, "refresh_token": refresh_token, "id_token": id_token, "scope": scope, @@ -1335,23 +1319,13 @@ async def refresh_user(self, user, handler=None, **kwargs): """ Renew the Access Token with a valid Refresh Token """ - auth_state = await user.get_auth_state() if not auth_state: self.log.info( - f"No auth_state found for user {user} refresh, need full authentication", + f"No auth_state found for user {user.name} refresh, need full authentication", ) return False - created_at = auth_state.get('created_at', 0) - expires_in = auth_state.get('expires_in', 0) - is_expired = created_at + expires_in - time.time() < 0 - if not is_expired: - self.log.info( - f"The access_token is still valid for user {user}, skipping refresh", - ) - return True - refresh_token_params = self.build_refresh_token_request_params( auth_state['refresh_token'] ) From 3d37ff89885b319cd457fd71870f034c60320e3d Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 17 Oct 2024 13:17:59 +0200 Subject: [PATCH 09/14] more cases for refresh_user - do not refresh if auth_state is disabled (would force re-login every 5 minutes in default config) - always refresh if refresh_token is defined - if refresh_token not available, only check validity of access_token and refresh associated user info --- oauthenticator/oauth2.py | 64 ++++++++++++++++---------- oauthenticator/tests/mocks.py | 52 ++++++++++++++++----- oauthenticator/tests/test_generic.py | 68 ++++++++++++++++++++++++++++ oauthenticator/tests/test_github.py | 2 +- 4 files changed, 151 insertions(+), 35 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 5c37f0db..4cceaba6 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -9,7 +9,6 @@ import json import os import secrets -import time import uuid from functools import reduce from inspect import isawaitable @@ -1071,9 +1070,8 @@ def build_refresh_token_request_params(self, refresh_token): # when basic authentication is used # ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1 if not self.basic_auth: - params.update( - {"client_id": self.client_id, "client_secret": self.client_secret,} - ) + params["client_id"] = self.client_id + params["client_secret"] = self.client_secret return params @@ -1312,48 +1310,68 @@ async def authenticate(self, handler, data=None, **kwargs): """ # build the parameters to be used in the request exchanging the oauth code for the access token access_token_params = self.build_access_tokens_request_params(handler, data) + token_info = await self.get_token_info(handler, access_token_params) # call the oauth endpoints - return await self._oauth_call(handler, access_token_params, **kwargs) + return await self._token_to_auth_model(token_info) async def refresh_user(self, user, handler=None, **kwargs): """ Renew the Access Token with a valid Refresh Token """ + if not self.enable_auth_state: + # auth state not enabled, can't refresh + return True auth_state = await user.get_auth_state() if not auth_state: self.log.info( f"No auth_state found for user {user.name} refresh, need full authentication", ) return False + refresh_token = auth_state.get("refresh_token", None) + if refresh_token: + refresh_token_params = self.build_refresh_token_request_params( + refresh_token + ) + try: + token_info = await self.get_token_info(handler, refresh_token_params) + except Exception as e: + self.log.info( + f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired." + ) + return False + # refresh_token may not be returned when refreshing a token + if not token_info.get("refresh_token"): + token_info["refresh_token"] = refresh_token + else: + # no refresh token, check access token validity + self.log.debug( + f"No refresh token for user {user.name}, checking access_token validity" + ) + token_info = auth_state.get("token_response") + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # handle more specific errors? + # e.g. expired token! + self.log.info( + f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired." + ) + return False + else: + # return False if auth_model is None for no-longer-authorized + return auth_model or False - refresh_token_params = self.build_refresh_token_request_params( - auth_state['refresh_token'] - ) - return await self._oauth_call(handler, refresh_token_params, **kwargs) - - async def _oauth_call(self, handler, params, **kwargs): + async def _token_to_auth_model(self, token_info): """ Common logic shared by authenticate() and refresh_user() """ - # exchange the oauth code for an access token and get the JSON with info about it - token_info = await self.get_token_info(handler, params) # use the access_token to get userdata info user_info = await self.token_to_user(token_info) # extract the username out of the user_info dict and normalize it username = self.user_info_to_username(user_info) username = self.normalize_username(username) - # check if there any refresh_token in the token_info dict - refresh_token = token_info.get("refresh_token", None) - if self.enable_auth_state and not refresh_token: - self.log.debug( - "Refresh token was empty, will try to pull refresh_token from previous auth_state" - ) - refresh_token = await self.get_prev_refresh_token(handler, username) - if refresh_token: - token_info["refresh_token"] = refresh_token - auth_state = self.build_auth_state_dict(token_info, user_info) if isawaitable(auth_state): auth_state = await auth_state diff --git a/oauthenticator/tests/mocks.py b/oauthenticator/tests/mocks.py index d9174604..5186f5ef 100644 --- a/oauthenticator/tests/mocks.py +++ b/oauthenticator/tests/mocks.py @@ -107,6 +107,7 @@ def setup_oauth_mock( user_path=None, token_type='Bearer', token_request_style='post', + enable_refresh_tokens=False, scope="", ): """setup the mock client for OAuth @@ -134,6 +135,8 @@ def setup_oauth_mock( client.oauth_codes = oauth_codes = {} client.access_tokens = access_tokens = {} + client.refresh_tokens = refresh_tokens = {} + client.enable_refresh_tokens = enable_refresh_tokens def access_token(request): """Handler for access token endpoint @@ -146,26 +149,53 @@ def access_token(request): if not query: query = request.body.decode('utf8') query = parse_qs(query) - if 'code' not in query: + grant_type = query.get("grant_type", [""])[0] + if grant_type == 'authorization_code': + if 'code' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No code in access token request: url={request.url}, body={request.body}", + ) + code = query['code'][0] + if code not in oauth_codes: + return HTTPResponse( + request=request, code=403, reason=f"No such code: {code}" + ) + user = oauth_codes.pop(code) + elif grant_type == 'refresh_token': + if 'refresh_token' not in query: + return HTTPResponse( + request=request, + code=400, + reason=f"No refresh_token in access token request: url={request.url}, body={request.body}", + ) + refresh_token = query['refresh_token'][0] + if refresh_token not in refresh_token: + return HTTPResponse( + request=request, + code=403, + reason=f"No such refresh_toekn: {refresh_token}", + ) + user = refresh_tokens[refresh_token] + else: return HTTPResponse( request=request, code=400, - reason=f"No code in access token request: url={request.url}, body={request.body}", - ) - code = query['code'][0] - if code not in oauth_codes: - return HTTPResponse( - request=request, code=403, reason=f"No such code: {code}" + reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}", ) # consume code, allocate token - token = uuid.uuid4().hex - user = oauth_codes.pop(code) - access_tokens[token] = user + access_token = uuid.uuid4().hex + access_tokens[access_token] = user model = { - 'access_token': token, + 'access_token': access_token, 'token_type': token_type, } + if client.enable_refresh_tokens: + refresh_token = uuid.uuid4().hex + refresh_tokens[refresh_token] = user + model['refresh_token'] = refresh_token if scope: model['scope'] = scope if 'id_token' in user: diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index e3225039..bb14f72f 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -505,6 +505,74 @@ async def test_check_allowed_no_auth_state(get_authenticator, name, allowed): assert await authenticator.check_allowed(name, None) +class MockUser: + """Mock subset of JupyterHub User API from the `auth_model` dict""" + + name: str + + def __init__(self, auth_model): + self._auth_model = auth_model + self.name = auth_model["name"] + + async def get_auth_state(self): + return self._auth_model["auth_state"] + + +@mark.parametrize("enable_refresh_tokens", [True, False]) +async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): + generic_client.enable_refresh_tokens = enable_refresh_tokens + authenticator = get_authenticator(allowed_users={"user1"}) + handled_user_model = user_model("user1", permissions={"groups": ["super_user"]}) + handler = generic_client.handler_for_user(handled_user_model) + auth_model = await authenticator.get_authenticated_user(handler, None) + auth_state = auth_model["auth_state"] + if enable_refresh_tokens: + assert "refresh_token" in auth_state + assert "refresh_token" in auth_state["token_response"] + assert ( + auth_state["refresh_token"] == auth_state["token_response"]["refresh_token"] + ) + else: + assert "refresh_token" not in auth_state["token_response"] + assert auth_state.get("refresh_token") is None + user = MockUser(auth_model) + # case: auth_state not enabled, nothing to refresh + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is True + + # from here on, enable auth state required for refresh to do anything + authenticator.enable_auth_state = True + + # case: no auth state, but auth state enabled needs refresh + auth_without_state = auth_model.copy() + auth_without_state["auth_state"] = None + user_without_state = MockUser(auth_without_state) + refreshed = await authenticator.refresh_user(user_without_state, handler) + assert refreshed is False + + # case: actually refresh + refreshed = await authenticator.refresh_user(user, handler) + assert isinstance(refreshed, dict) + assert refreshed["name"] == auth_model["name"] + refreshed_state = refreshed["auth_state"] + assert "access_token" in refreshed_state + if enable_refresh_tokens: + # refresh_token refreshed the access token + assert refreshed_state["access_token"] != auth_state["access_token"] + assert refreshed_state["refresh_token"] + else: + # refresh with access token succeeds, keeps access token unchanged + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: token used for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) + if enable_refresh_tokens: + generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False + + @mark.parametrize( "test_variation_id,class_config,expect_config,expect_loglevel,expect_message", [ diff --git a/oauthenticator/tests/test_github.py b/oauthenticator/tests/test_github.py index e49fe064..1ca5f590 100644 --- a/oauthenticator/tests/test_github.py +++ b/oauthenticator/tests/test_github.py @@ -141,7 +141,7 @@ async def test_github( assert user_info == handled_user_model assert auth_model["name"] == user_info[authenticator.username_claim] else: - assert auth_model == None + assert auth_model is None def make_link_header(urlinfo, page): From a8fe50012eba40ec7dd5bf546ab719158d5314df Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 17 Oct 2024 14:09:25 +0200 Subject: [PATCH 10/14] only refresh expired access tokens --- oauthenticator/oauth2.py | 48 ++++++++++++++++------------ oauthenticator/tests/test_generic.py | 47 ++++++++++++++++++--------- 2 files changed, 60 insertions(+), 35 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 4cceaba6..123557cb 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1327,8 +1327,21 @@ async def refresh_user(self, user, handler=None, **kwargs): f"No auth_state found for user {user.name} refresh, need full authentication", ) return False + + token_info = auth_state.get("token_response") + auth_model = None + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # usually this means the access token has expired + # handle more specific errors? + self.log.info( + f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." + ) refresh_token = auth_state.get("refresh_token", None) - if refresh_token: + if refresh_token and not auth_model: + self.log.info(f"Refreshing oauth access token for {user.name}") + # access_token expired, try refreshing with refresh_token refresh_token_params = self.build_refresh_token_request_params( refresh_token ) @@ -1336,30 +1349,25 @@ async def refresh_user(self, user, handler=None, **kwargs): token_info = await self.get_token_info(handler, refresh_token_params) except Exception as e: self.log.info( - f"Error using refresh_token for {user.name}: {e}. Treating auth info as expired." + f"Error using refresh_token for {user.name}: {e}. Requiring fresh login." ) return False # refresh_token may not be returned when refreshing a token + # in which case, keep the current one if not token_info.get("refresh_token"): token_info["refresh_token"] = refresh_token - else: - # no refresh token, check access token validity - self.log.debug( - f"No refresh token for user {user.name}, checking access_token validity" - ) - token_info = auth_state.get("token_response") - try: - auth_model = await self._token_to_auth_model(token_info) - except Exception as e: - # handle more specific errors? - # e.g. expired token! - self.log.info( - f"Error refreshing auth with access_token for {user.name}: {e}. Treating auth info as expired." - ) - return False - else: - # return False if auth_model is None for no-longer-authorized - return auth_model or False + try: + auth_model = await self._token_to_auth_model(token_info) + except Exception as e: + # this means we were issued a fresh access token, + # but it didn't work! Fail harder? + self.log.error( + f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login." + ) + return False + + # return False if auth_model is None for "needs new login" + return auth_model or False async def _token_to_auth_model(self, token_info): """ diff --git a/oauthenticator/tests/test_generic.py b/oauthenticator/tests/test_generic.py index bb14f72f..e5c99602 100644 --- a/oauthenticator/tests/test_generic.py +++ b/oauthenticator/tests/test_generic.py @@ -15,14 +15,15 @@ def user_model(username, **kwargs): """Return a user model""" - return { + model = { "username": username, "aud": client_id, "sub": "oauth2|cilogon|http://cilogon.org/servera/users/43431", "scope": "basic", "groups": ["group1"], - **kwargs, } + model.update(kwargs) + return model @fixture(params=["id_token", "userdata_url"]) @@ -522,10 +523,13 @@ async def get_auth_state(self): async def test_refresh_user(get_authenticator, generic_client, enable_refresh_tokens): generic_client.enable_refresh_tokens = enable_refresh_tokens authenticator = get_authenticator(allowed_users={"user1"}) - handled_user_model = user_model("user1", permissions={"groups": ["super_user"]}) - handler = generic_client.handler_for_user(handled_user_model) + authenticator.manage_groups = True + authenticator.auth_state_groups_key = "oauth_user.groups" + oauth_userinfo = user_model("user1", groups=["round1"]) + handler = generic_client.handler_for_user(oauth_userinfo) auth_model = await authenticator.get_authenticated_user(handler, None) auth_state = auth_model["auth_state"] + assert auth_model["groups"] == ["round1"] if enable_refresh_tokens: assert "refresh_token" in auth_state assert "refresh_token" in auth_state["token_response"] @@ -551,26 +555,39 @@ async def test_refresh_user(get_authenticator, generic_client, enable_refresh_to assert refreshed is False # case: actually refresh + oauth_userinfo["groups"] = ["refreshed"] refreshed = await authenticator.refresh_user(user, handler) - assert isinstance(refreshed, dict) + assert refreshed assert refreshed["name"] == auth_model["name"] + assert refreshed["groups"] == ["refreshed"] refreshed_state = refreshed["auth_state"] assert "access_token" in refreshed_state + # refresh with access token succeeds, keeps tokens unchanged + assert refreshed_state.get("refresh_token") == auth_state.get("refresh_token") + assert refreshed_state["access_token"] == auth_state["access_token"] + + # case: access token is no longer valid, triggers refresh + oauth_userinfo["groups"] = ["token_refreshed"] + generic_client.access_tokens.pop(refreshed_state["access_token"]) + refreshed = await authenticator.refresh_user(user, handler) if enable_refresh_tokens: - # refresh_token refreshed the access token - assert refreshed_state["access_token"] != auth_state["access_token"] - assert refreshed_state["refresh_token"] + # access_token refreshed + assert refreshed + refreshed_state = refreshed["auth_state"] + assert ( + refreshed_state["access_token"] != auth_model["auth_state"]["access_token"] + ) + assert refreshed["groups"] == ["token_refreshed"] else: - # refresh with access token succeeds, keeps access token unchanged - assert refreshed_state["access_token"] == auth_state["access_token"] + assert refreshed is False - # case: token used for refresh is no longer valid - user = MockUser(refreshed) - generic_client.access_tokens.pop(refreshed_state["access_token"]) if enable_refresh_tokens: + # case: token used for refresh is no longer valid + user = MockUser(refreshed) + generic_client.access_tokens.pop(refreshed_state["access_token"]) generic_client.refresh_tokens.pop(refreshed_state["refresh_token"]) - refreshed = await authenticator.refresh_user(user, handler) - assert refreshed is False + refreshed = await authenticator.refresh_user(user, handler) + assert refreshed is False @mark.parametrize( From 19001676de1665af420b1bb880359d9cea4b9ba7 Mon Sep 17 00:00:00 2001 From: Min RK Date: Thu, 7 Nov 2024 14:47:57 +0100 Subject: [PATCH 11/14] more specific error handling for failure to refresh auth only consider HTTP 4XX expired auth --- oauthenticator/oauth2.py | 20 ++++++++++++++------ 1 file changed, 14 insertions(+), 6 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 123557cb..626ec092 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1320,6 +1320,7 @@ async def refresh_user(self, user, handler=None, **kwargs): """ if not self.enable_auth_state: # auth state not enabled, can't refresh + self.log.debug("auth_state disabled, no auth state to refresh") return True auth_state = await user.get_auth_state() if not auth_state: @@ -1332,12 +1333,15 @@ async def refresh_user(self, user, handler=None, **kwargs): auth_model = None try: auth_model = await self._token_to_auth_model(token_info) - except Exception as e: - # usually this means the access token has expired - # handle more specific errors? - self.log.info( - f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." - ) + except HTTPClientError as e: + # assume any client error means an expired token + # most likely 401 or 403 for well-behaved providers + if 400 <= e.code < 500: + self.log.info( + f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible." + ) + else: + raise refresh_token = auth_state.get("refresh_token", None) if refresh_token and not auth_model: self.log.info(f"Refreshing oauth access token for {user.name}") @@ -1352,6 +1356,10 @@ async def refresh_user(self, user, handler=None, **kwargs): f"Error using refresh_token for {user.name}: {e}. Requiring fresh login." ) return False + else: + self.log.debug( + f"Received fresh access_token for {user.name} via refresh_token" + ) # refresh_token may not be returned when refreshing a token # in which case, keep the current one if not token_info.get("refresh_token"): From 1ddea4e58bfbb61cd1a713d511b8f83019466e94 Mon Sep 17 00:00:00 2001 From: Min RK Date: Mon, 11 Nov 2024 08:30:41 +0100 Subject: [PATCH 12/14] remove unused get_prev_refresh_token --- oauthenticator/oauth2.py | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 626ec092..e4e728f3 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -16,7 +16,6 @@ import jwt from jupyterhub.auth import Authenticator -from jupyterhub.crypto import EncryptionUnavailable, InvalidToken, decrypt from jupyterhub.handlers import BaseHandler, LogoutHandler from jupyterhub.utils import url_path_join from tornado import web @@ -988,31 +987,6 @@ def user_info_to_username(self, user_info): return username - # Originally a GoogleOAuthenticator only feature - async def get_prev_refresh_token(self, handler, username): - """ - Retrieves the `refresh_token` from previous encrypted auth state. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` - """ - user = handler.find_user(username) - if not user or not user.encrypted_auth_state: - return - - self.log.debug( - "Encrypted_auth_state was found, will try to decrypt and pull refresh_token from it..." - ) - - try: - encrypted = user.encrypted_auth_state - auth_state = await decrypt(encrypted) - - return auth_state.get("refresh_token") - except (ValueError, InvalidToken, EncryptionUnavailable) as e: - self.log.warning( - f"Failed to retrieve encrypted auth_state for {username}. Error was {e}.", - ) - return - def build_access_tokens_request_params(self, handler, data=None): """ Builds the parameters that should be passed to the URL request From 12720978c8dc3352eefe75e8f8a985ea73693e8e Mon Sep 17 00:00:00 2001 From: Min RK Date: Mon, 11 Nov 2024 09:16:18 +0100 Subject: [PATCH 13/14] expand docstring for refresh_user --- oauthenticator/oauth2.py | 30 ++++++++++++++++++++++++++++-- 1 file changed, 28 insertions(+), 2 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index e4e728f3..204074fb 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -1290,11 +1290,37 @@ async def authenticate(self, handler, data=None, **kwargs): async def refresh_user(self, user, handler=None, **kwargs): """ - Renew the Access Token with a valid Refresh Token + Refresh user authentication + + If auth_state is enabled, constructs a fresh user model + (the same as `authenticate`) + using the access_token in auth_state. + If requests with the access token fail + (e.g. because the token has expired) + and a refresh token is found, attempts to exchange + the refresh token for a new access token to store in auth_state. + If the access token still fails after refresh, + return False to require the user to login via oauth again. + + Set `Authenticator.auth_refresh_age = 0` to disable. + + Returns + ------- + + True: + If auth info is up-to-date and needs no changes + (always if `enable_auth_state` is False) + False: + If the user needs to login again + (e.g. tokens in `auth_state` unavailable or expired) + auth_model: dict + The same dict as `authenticate`, updating any fields that should change. + Can include things like group membership, + but in OAuthenticator this mainly refreshes + the token fields in `auth_state`. """ if not self.enable_auth_state: # auth state not enabled, can't refresh - self.log.debug("auth_state disabled, no auth state to refresh") return True auth_state = await user.get_auth_state() if not auth_state: From e3a01a59960f581de75b8956718cec265a638ffa Mon Sep 17 00:00:00 2001 From: Min RK Date: Mon, 11 Nov 2024 09:17:26 +0100 Subject: [PATCH 14/14] Fix "Called by" cross-links in docstrings - remove incorrect 'the' - add `refresh_user` where appropriate - fix some links to the current class --- oauthenticator/oauth2.py | 33 +++++++++++++++++++-------------- 1 file changed, 19 insertions(+), 14 deletions(-) diff --git a/oauthenticator/oauth2.py b/oauthenticator/oauth2.py index 204074fb..e4865948 100644 --- a/oauthenticator/oauth2.py +++ b/oauthenticator/oauth2.py @@ -270,7 +270,7 @@ class OAuthenticator(Authenticator): - Override the constant `user_auth_state_key` - Override various config's default values, such as `authorize_url`, `token_url`, `userdata_url`, and `login_service`. - - Override various methods called by the `authenticate` method, which + - Override various methods called by :meth:`authenticate`, which subclasses should not override. - Override handler classes such as `login_handler`, `callback_handler`, and `logout_handler`. @@ -919,7 +919,8 @@ def get_handlers(self, app): def build_userdata_request_headers(self, access_token, token_type): """ Builds and returns the headers to be used in the userdata request. - Called by the :meth:`oauthenticator.OAuthenticator.token_to_user` + + Called by :meth:`.token_to_user`. """ # token_type is case-insensitive, but the headers are case-sensitive @@ -937,7 +938,8 @@ def build_userdata_request_headers(self, access_token, token_type): def build_token_info_request_headers(self): """ Builds and returns the headers to be used in the access token request. - Called by the :meth:`oauthenticator.OAuthenticator.get_token_info`. + + Called by :meth:`.get_token_info`. The Content-Type header is specified by the OAuth 2.0 RFC in https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3. utf-8 is also @@ -971,7 +973,7 @@ def user_info_to_username(self, user_info): Returns: user_info["self.username_claim"] or raises an error if such value isn't found. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ if callable(self.username_claim): @@ -991,7 +993,8 @@ def build_access_tokens_request_params(self, handler, data=None): """ Builds the parameters that should be passed to the URL request that exchanges the OAuth code for the Access Token. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate`. + + Called by :meth:`.authenticate`. """ code = handler.get_argument("code") if not code: @@ -1033,7 +1036,7 @@ def build_refresh_token_request_params(self, refresh_token): Builds the parameters that should be passed to the URL request to renew the Access Token based on the Refresh Token - Called by the :meth:`oauthenticator.OAuthenticator.refresh_user`. + Called by :meth:`.refresh_user`. """ params = { "grant_type": "refresh_token", @@ -1057,7 +1060,7 @@ async def get_token_info(self, handler, params): the JSON response to the `token_url` the request as described in https://www.rfc-editor.org/rfc/rfc6749#section-5.1 - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ token_info = await self.httpfetch( @@ -1081,9 +1084,9 @@ async def get_token_info(self, handler, params): async def token_to_user(self, token_info): """ Determines who the logged-in user by sending a "GET" request to - :data:`oauthenticator.OAuthenticator.userdata_url` using the `access_token`. + :attr:`.userdata_url` using the `access_token`. - If :data:`oauthenticator.OAuthenticator.userdata_from_id_token` is set then + If :attr:`.userdata_from_id_token` is set then extracts the corresponding info from an `id_token` instead. Args: @@ -1092,7 +1095,7 @@ async def token_to_user(self, token_info): Returns: the JSON response to the `userdata_url` request. - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ if self.userdata_from_id_token: # Use id token instead of exchanging access token with userinfo endpoint. @@ -1158,7 +1161,7 @@ def build_auth_state_dict(self, token_info, user_info): - "token_response": the full token_info response - self.user_auth_state_key: the full user_info response - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. .. versionchanged:: 17.0 This method may be async. @@ -1229,9 +1232,9 @@ async def update_auth_model(self, auth_model): - `admin`: the admin status (True/False/None), where None means it should be unchanged. - `auth_state`: the auth state dictionary, - returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict` + returned by :meth:`.build_auth_state_dict` - Called by the :meth:`oauthenticator.OAuthenticator.authenticate` + Called by :meth:`.authenticate` and :meth:`.refresh_user`. """ # NOTE: this base implementation should _not_ be updated to do anything # subclasses should have full control without calling super() @@ -1379,7 +1382,9 @@ async def refresh_user(self, user, handler=None, **kwargs): async def _token_to_auth_model(self, token_info): """ - Common logic shared by authenticate() and refresh_user() + Turn a token into the user's `auth_model` to be returned by :meth:`.authenticate`. + + Common logic shared by :meth:`.authenticate` and :meth:`.refresh_user`. """ # use the access_token to get userdata info