Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Implement OAuthenticator.refresh_user #579

Merged
merged 16 commits into from
Nov 27, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
179 changes: 139 additions & 40 deletions oauthenticator/oauth2.py
Original file line number Diff line number Diff line change
Expand Up @@ -270,7 +270,7 @@ class OAuthenticator(Authenticator):
- Override the constant `user_auth_state_key`
- Override various config's default values, such as
`authorize_url`, `token_url`, `userdata_url`, and `login_service`.
- Override various methods called by the `authenticate` method, which
- Override various methods called by :meth:`authenticate`, which
subclasses should not override.
- Override handler classes such as `login_handler`, `callback_handler`, and
`logout_handler`.
Expand Down Expand Up @@ -919,7 +919,8 @@ def get_handlers(self, app):
def build_userdata_request_headers(self, access_token, token_type):
"""
Builds and returns the headers to be used in the userdata request.
Called by the :meth:`oauthenticator.OAuthenticator.token_to_user`

Called by :meth:`.token_to_user`.
"""

# token_type is case-insensitive, but the headers are case-sensitive
Expand All @@ -937,7 +938,8 @@ def build_userdata_request_headers(self, access_token, token_type):
def build_token_info_request_headers(self):
"""
Builds and returns the headers to be used in the access token request.
Called by the :meth:`oauthenticator.OAuthenticator.get_token_info`.

Called by :meth:`.get_token_info`.

The Content-Type header is specified by the OAuth 2.0 RFC in
https://www.rfc-editor.org/rfc/rfc6749#section-4.1.3. utf-8 is also
Expand Down Expand Up @@ -971,7 +973,7 @@ def user_info_to_username(self, user_info):
Returns:
user_info["self.username_claim"] or raises an error if such value isn't found.

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

if callable(self.username_claim):
Expand All @@ -987,25 +989,12 @@ def user_info_to_username(self, user_info):

return username

# Originally a GoogleOAuthenticator only feature
async def get_prev_refresh_token(self, handler, username):
"""
Retrieves the `refresh_token` from previous encrypted auth state.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
"""
user = handler.find_user(username)
if not user:
return None
auth_state = await user.get_auth_state()
if not auth_state:
return None
return auth_state.get("refresh_token", None)

def build_access_tokens_request_params(self, handler, data=None):
"""
Builds the parameters that should be passed to the URL request
that exchanges the OAuth code for the Access Token.
Called by the :meth:`oauthenticator.OAuthenticator.authenticate`.

Called by :meth:`.authenticate`.
"""
code = handler.get_argument("code")
if not code:
Expand Down Expand Up @@ -1042,14 +1031,36 @@ def build_access_tokens_request_params(self, handler, data=None):

return params

def build_refresh_token_request_params(self, refresh_token):
"""
Builds the parameters that should be passed to the URL request
to renew the Access Token based on the Refresh Token

Called by :meth:`.refresh_user`.
"""
params = {
"grant_type": "refresh_token",
"refresh_token": refresh_token,
minrk marked this conversation as resolved.
Show resolved Hide resolved
}

# the client_id and client_secret should not be included in the access token request params
# when basic authentication is used
# ref: https://www.rfc-editor.org/rfc/rfc6749#section-2.3.1
if not self.basic_auth:
params["client_id"] = self.client_id
params["client_secret"] = self.client_secret

return params

async def get_token_info(self, handler, params):
"""
Makes a "POST" request to `self.token_url`, with the parameters received as argument.

Returns:
the JSON response to the `token_url` the request.
the JSON response to the `token_url` the request as described in
https://www.rfc-editor.org/rfc/rfc6749#section-5.1

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

token_info = await self.httpfetch(
Expand All @@ -1073,9 +1084,9 @@ async def get_token_info(self, handler, params):
async def token_to_user(self, token_info):
"""
Determines who the logged-in user by sending a "GET" request to
:data:`oauthenticator.OAuthenticator.userdata_url` using the `access_token`.
:attr:`.userdata_url` using the `access_token`.

If :data:`oauthenticator.OAuthenticator.userdata_from_id_token` is set then
If :attr:`.userdata_from_id_token` is set then
extracts the corresponding info from an `id_token` instead.

Args:
Expand All @@ -1084,7 +1095,7 @@ async def token_to_user(self, token_info):
Returns:
the JSON response to the `userdata_url` request.

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""
if self.userdata_from_id_token:
# Use id token instead of exchanging access token with userinfo endpoint.
Expand Down Expand Up @@ -1134,7 +1145,7 @@ async def token_to_user(self, token_info):

def build_auth_state_dict(self, token_info, user_info):
"""
Builds the `auth_state` dict that will be returned by a succesfull `authenticate` method call.
Builds the `auth_state` dict that will be returned by a successful `authenticate` method call.
May be async (requires oauthenticator >= 17.0).

Args:
Expand All @@ -1150,13 +1161,13 @@ def build_auth_state_dict(self, token_info, user_info):
- "token_response": the full token_info response
- self.user_auth_state_key: the full user_info response

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.

.. versionchanged:: 17.0
This method may be async.
"""

# We know for sure the `access_token` key exists, oterwise we would have errored out already
# We know for sure the `access_token` key exists, otherwise we would have errored out already
access_token = token_info["access_token"]

refresh_token = token_info.get("refresh_token", None)
Expand Down Expand Up @@ -1221,9 +1232,9 @@ async def update_auth_model(self, auth_model):
- `admin`: the admin status (True/False/None), where None means it
should be unchanged.
- `auth_state`: the auth state dictionary,
returned by :meth:`oauthenticator.OAuthenticator.build_auth_state_dict`
returned by :meth:`.build_auth_state_dict`

Called by the :meth:`oauthenticator.OAuthenticator.authenticate`
Called by :meth:`.authenticate` and :meth:`.refresh_user`.
"""
# NOTE: this base implementation should _not_ be updated to do anything
# subclasses should have full control without calling super()
Expand Down Expand Up @@ -1276,24 +1287,112 @@ async def authenticate(self, handler, data=None, **kwargs):
"""
# build the parameters to be used in the request exchanging the oauth code for the access token
access_token_params = self.build_access_tokens_request_params(handler, data)
# exchange the oauth code for an access token and get the JSON with info about it
token_info = await self.get_token_info(handler, access_token_params)
# call the oauth endpoints
return await self._token_to_auth_model(token_info)

async def refresh_user(self, user, handler=None, **kwargs):
"""
Refresh user authentication

If auth_state is enabled, constructs a fresh user model
(the same as `authenticate`)
using the access_token in auth_state.
If requests with the access token fail
(e.g. because the token has expired)
and a refresh token is found, attempts to exchange
the refresh token for a new access token to store in auth_state.
If the access token still fails after refresh,
return False to require the user to login via oauth again.

Set `Authenticator.auth_refresh_age = 0` to disable.

Returns
-------

True:
If auth info is up-to-date and needs no changes
(always if `enable_auth_state` is False)
False:
If the user needs to login again
(e.g. tokens in `auth_state` unavailable or expired)
auth_model: dict
The same dict as `authenticate`, updating any fields that should change.
Can include things like group membership,
but in OAuthenticator this mainly refreshes
the token fields in `auth_state`.
"""
if not self.enable_auth_state:
# auth state not enabled, can't refresh
return True
minrk marked this conversation as resolved.
Show resolved Hide resolved
auth_state = await user.get_auth_state()
if not auth_state:
self.log.info(
f"No auth_state found for user {user.name} refresh, need full authentication",
)
return False

token_info = auth_state.get("token_response")
auth_model = None
try:
auth_model = await self._token_to_auth_model(token_info)
except HTTPClientError as e:
# assume any client error means an expired token
# most likely 401 or 403 for well-behaved providers
if 400 <= e.code < 500:
self.log.info(
f"Error refreshing auth with current access_token for {user.name}: {e}. Will try to refresh, if possible."
)
else:
raise
refresh_token = auth_state.get("refresh_token", None)
if refresh_token and not auth_model:
self.log.info(f"Refreshing oauth access token for {user.name}")
# access_token expired, try refreshing with refresh_token
refresh_token_params = self.build_refresh_token_request_params(
refresh_token
)
try:
token_info = await self.get_token_info(handler, refresh_token_params)
except Exception as e:
self.log.info(
f"Error using refresh_token for {user.name}: {e}. Requiring fresh login."
)
return False
else:
self.log.debug(
f"Received fresh access_token for {user.name} via refresh_token"
)
# refresh_token may not be returned when refreshing a token
# in which case, keep the current one
if not token_info.get("refresh_token"):
token_info["refresh_token"] = refresh_token
try:
auth_model = await self._token_to_auth_model(token_info)
except Exception as e:
# this means we were issued a fresh access token,
# but it didn't work! Fail harder?
self.log.error(
f"Error refreshing auth with fresh access_token for {user.name}: {e}. Requiring fresh login."
)
return False

# return False if auth_model is None for "needs new login"
return auth_model or False

async def _token_to_auth_model(self, token_info):
"""
Turn a token into the user's `auth_model` to be returned by :meth:`.authenticate`.

Common logic shared by :meth:`.authenticate` and :meth:`.refresh_user`.
"""

# use the access_token to get userdata info
user_info = await self.token_to_user(token_info)
# extract the username out of the user_info dict and normalize it
username = self.user_info_to_username(user_info)
username = self.normalize_username(username)

# check if there any refresh_token in the token_info dict
refresh_token = token_info.get("refresh_token", None)
if self.enable_auth_state and not refresh_token:
self.log.debug(
"Refresh token was empty, will try to pull refresh_token from previous auth_state"
)
refresh_token = await self.get_prev_refresh_token(handler, username)
minrk marked this conversation as resolved.
Show resolved Hide resolved
if refresh_token:
token_info["refresh_token"] = refresh_token

auth_state = self.build_auth_state_dict(token_info, user_info)
if isawaitable(auth_state):
auth_state = await auth_state
Expand Down
52 changes: 41 additions & 11 deletions oauthenticator/tests/mocks.py
Original file line number Diff line number Diff line change
Expand Up @@ -107,6 +107,7 @@ def setup_oauth_mock(
user_path=None,
token_type='Bearer',
token_request_style='post',
enable_refresh_tokens=False,
scope="",
):
"""setup the mock client for OAuth
Expand Down Expand Up @@ -134,6 +135,8 @@ def setup_oauth_mock(

client.oauth_codes = oauth_codes = {}
client.access_tokens = access_tokens = {}
client.refresh_tokens = refresh_tokens = {}
client.enable_refresh_tokens = enable_refresh_tokens

def access_token(request):
"""Handler for access token endpoint
Expand All @@ -146,26 +149,53 @@ def access_token(request):
if not query:
query = request.body.decode('utf8')
query = parse_qs(query)
if 'code' not in query:
grant_type = query.get("grant_type", [""])[0]
if grant_type == 'authorization_code':
if 'code' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
)
user = oauth_codes.pop(code)
elif grant_type == 'refresh_token':
if 'refresh_token' not in query:
return HTTPResponse(
request=request,
code=400,
reason=f"No refresh_token in access token request: url={request.url}, body={request.body}",
)
refresh_token = query['refresh_token'][0]
if refresh_token not in refresh_token:
return HTTPResponse(
request=request,
code=403,
reason=f"No such refresh_toekn: {refresh_token}",
)
user = refresh_tokens[refresh_token]
else:
return HTTPResponse(
request=request,
code=400,
reason=f"No code in access token request: url={request.url}, body={request.body}",
)
code = query['code'][0]
if code not in oauth_codes:
return HTTPResponse(
request=request, code=403, reason=f"No such code: {code}"
reason=f"Invalid grant_type={grant_type}: url={request.url}, body={request.body}",
)

# consume code, allocate token
token = uuid.uuid4().hex
user = oauth_codes.pop(code)
access_tokens[token] = user
access_token = uuid.uuid4().hex
access_tokens[access_token] = user
model = {
'access_token': token,
'access_token': access_token,
'token_type': token_type,
}
if client.enable_refresh_tokens:
refresh_token = uuid.uuid4().hex
refresh_tokens[refresh_token] = user
model['refresh_token'] = refresh_token
if scope:
model['scope'] = scope
if 'id_token' in user:
Expand Down
Loading