diff --git a/oauthenticator/google.py b/oauthenticator/google.py index 71c87c82..ddfb026c 100644 --- a/oauthenticator/google.py +++ b/oauthenticator/google.py @@ -14,6 +14,7 @@ class GoogleOAuthenticator(OAuthenticator, GoogleOAuth2Mixin): user_auth_state_key = "google_user" + _service_credentials = {} @default("login_service") def _login_service_default(self): @@ -251,7 +252,7 @@ async def update_auth_model(self, auth_model): user_groups = set() if self.allowed_google_groups or self.admin_google_groups: - user_groups = self._fetch_member_groups(user_email, user_domain) + user_groups = await self._fetch_member_groups(user_email, user_domain) # sets are not JSONable, cast to list for auth_state user_info["google_groups"] = list(user_groups) @@ -322,6 +323,36 @@ async def check_allowed(self, username, auth_model): # users should be explicitly allowed via config, otherwise they aren't return False + def _get_service_credentials(self, user_email_domain): + """ + Returns the stored credentials or fetches and stores new ones. + + Checks if the credentials are valid before returning them. Refreshes + if necessary and stores the refreshed credentials. + """ + if ( + user_email_domain not in self._service_credentials + or not self._is_token_valid(user_email_domain) + ): + self._service_credentials[user_email_domain] = ( + self._setup_service_credentials(user_email_domain) + ) + + return self._service_credentials + + def _is_token_valid(self, user_email_domain): + """ + Checks if the stored token is valid. + """ + if not self._service_credentials[user_email_domain]: + return False + if not self._service_credentials[user_email_domain].token: + return False + if self._service_credentials[user_email_domain].expired: + return False + + return True + def _service_client_credentials(self, scopes, user_email_domain): """ Return a configured service client credentials for the API. @@ -346,71 +377,57 @@ def _service_client_credentials(self, scopes, user_email_domain): return credentials - def _service_client(self, service_name, service_version, credentials, http=None): + def _setup_service_credentials(self, user_email_domain): """ - Return a configured service client for the API. + Set up the oauth credentials for Google API. """ + credentials = self._service_client_credentials( + scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"], + user_email_domain=user_email_domain, + ) + try: - from googleapiclient.discovery import build + from google.auth.transport import requests except: raise ImportError( - "Could not import googleapiclient.discovery's build," + "Could not import google.auth.transport's requests," "you may need to run 'pip install oauthenticator[googlegroups]' or not declare google groups" ) - self.log.debug( - f"service_name is {service_name}, service_version is {service_version}" - ) - - return build( - serviceName=service_name, - version=service_version, - credentials=credentials, - cache_discovery=False, - http=http, - ) - - def _setup_service(self, user_email_domain, http=None): - """ - Set up the service client for Google API. - """ - credentials = self._service_client_credentials( - scopes=[f"{self.google_api_url}/auth/admin.directory.group.readonly"], - user_email_domain=user_email_domain, - ) - service = self._service_client( - service_name='admin', - service_version='directory_v1', - credentials=credentials, - http=http, - ) - return service + request = requests.Request() + credentials.refresh(request) + self.log.debug(f"Credentials refreshed for {user_email_domain}") + return credentials - def _fetch_member_groups( + async def _fetch_member_groups( self, member_email, user_email_domain, http=None, checked_groups=None, processed_groups=None, + credentials=None, ): """ Return a set with the google groups a given user/group is a member of, including nested groups if allowed. """ - # FIXME: When this function is used and waiting for web request - # responses, JupyterHub gets blocked from doing other things. - # Ideally the web requests should be made using an async client - # that can be awaited while JupyterHub handles other things. - # - if not hasattr(self, 'service'): - self.service = self._setup_service(user_email_domain, http) - + # WARNING: There's a race condition here if multiple users login at the same time. + # This is currently ignored. + credentials = credentials or self._get_service_credentials(user_email_domain) + token = credentials[user_email_domain].token checked_groups = checked_groups or set() processed_groups = processed_groups or set() - resp = self.service.groups().list(userKey=member_email).execute() + headers = {'Authorization': f'Bearer {token}'} + url = f'https://www.googleapis.com/admin/directory/v1/groups?userKey={member_email}' + group_data = await self.httpfetch( + url, headers=headers, label="fetching google groups" + ) + member_groups = { - g['email'].split('@')[0] for g in resp.get('groups', []) if g.get('email') + g['email'].split('@')[0] + for g in group_data.get('groups', []) + if g.get('email') } self.log.debug(f"Fetched groups for {member_email}: {member_groups}") @@ -422,7 +439,7 @@ def _fetch_member_groups( if group in processed_groups: continue processed_groups.add(group) - nested_groups = self._fetch_member_groups( + nested_groups = await self._fetch_member_groups( f"{group}@{user_email_domain}", user_email_domain, http, diff --git a/oauthenticator/tests/test_google.py b/oauthenticator/tests/test_google.py index 4c3a9e0c..070de014 100644 --- a/oauthenticator/tests/test_google.py +++ b/oauthenticator/tests/test_google.py @@ -3,6 +3,7 @@ import logging import re from unittest import mock +from unittest.mock import AsyncMock from pytest import fixture, mark, raises from traitlets.config import Config @@ -211,7 +212,7 @@ async def test_google( handled_user_model = user_model("user1@example.com", "user1") handler = google_client.handler_for_user(handled_user_model) with mock.patch.object( - authenticator, "_fetch_member_groups", lambda *args: {"group1"} + authenticator, "_fetch_member_groups", AsyncMock(return_value={"group1"}) ): auth_model = await authenticator.get_authenticated_user(handler, None) diff --git a/setup.py b/setup.py index 480e04d3..615695be 100644 --- a/setup.py +++ b/setup.py @@ -91,7 +91,6 @@ def run(self): # googlegroups is required for use of GoogleOAuthenticator configured with # either admin_google_groups and/or allowed_google_groups. 'googlegroups': [ - 'google-api-python-client', 'google-auth-oauthlib', ], # mediawiki is required for use of MWOAuthenticator @@ -105,7 +104,6 @@ def run(self): 'pytest-cov', 'requests-mock', # dependencies from googlegroups: - 'google-api-python-client', 'google-auth-oauthlib', # dependencies from mediawiki: 'mwoauth>=0.3.8',