From 356261d5ac84913d4c4f897d990c285e2a7e11f6 Mon Sep 17 00:00:00 2001 From: Adam Novak Date: Wed, 26 Apr 2023 07:55:21 -0700 Subject: [PATCH] Avoid re-using old message bus path when restarting a workflow (#4445) * Add test to replicate the missing message bus crash * Fix #4438 by not trying to re-use the old message bus log when restarting * Fix #4410 by making everybody use thread-local Boto3 caching and a global initialization lock * Restore handling for custom client config --- src/toil/common.py | 26 +++++- src/toil/lib/aws/session.py | 172 ++++++++++++++++++++++++----------- src/toil/test/src/busTest.py | 64 +++++++++++++ src/toil/utils/toilStatus.py | 13 ++- 4 files changed, 213 insertions(+), 62 deletions(-) diff --git a/src/toil/common.py b/src/toil/common.py index d0eb1bc82e..d4a8c73c20 100644 --- a/src/toil/common.py +++ b/src/toil/common.py @@ -198,7 +198,7 @@ def __init__(self) -> None: self.writeLogs = None self.writeLogsGzip = None self.writeLogsFromAllJobs: bool = False - self.write_messages: str = "" + self.write_messages: Optional[str] = None # Misc self.environment: Dict[str, str] = {} @@ -222,6 +222,24 @@ def __init__(self) -> None: # CWL self.cwl: bool = False + def prepare_start(self) -> None: + """ + After options are set, prepare for initial start of workflow. + """ + self.workflowAttemptNumber = 0 + + def prepare_restart(self) -> None: + """ + Before restart options are set, prepare for a restart of a workflow. + Set up any execution-specific parameters and clear out any stale ones. + """ + self.workflowAttemptNumber += 1 + # We should clear the stored message bus path, because it may have been + # auto-generated and point to a temp directory that could no longer + # exist and that can't safely be re-made. + self.write_messages = None + + def setOptions(self, options: Namespace) -> None: """Creates a config object from the options object.""" OptionType = TypeVar("OptionType") @@ -407,6 +425,8 @@ def check_nodestoreage_overrides(overrides: List[str]) -> bool: set_option("write_messages", os.path.abspath) if not self.write_messages: + # The user hasn't specified a place for the message bus so we + # should make one. self.write_messages = gen_message_bus_path() assert not (self.writeLogs and self.writeLogsGzip), \ @@ -947,14 +967,14 @@ def __enter__(self) -> "Toil": self.options.caching = config.caching if not config.restart: - config.workflowAttemptNumber = 0 + config.prepare_start() jobStore.initialize(config) else: jobStore.resume() # Merge configuration from job store with command line options config = jobStore.config + config.prepare_restart() config.setOptions(self.options) - config.workflowAttemptNumber += 1 jobStore.write_config() self.config = config self._jobStore = jobStore diff --git a/src/toil/lib/aws/session.py b/src/toil/lib/aws/session.py index 4ef36f165b..cbc2b686b6 100644 --- a/src/toil/lib/aws/session.py +++ b/src/toil/lib/aws/session.py @@ -18,7 +18,6 @@ import re import socket import threading -from functools import lru_cache from typing import (Any, Callable, Dict, @@ -37,16 +36,33 @@ import boto.connection import botocore from boto3 import Session +from botocore.client import Config from botocore.credentials import JSONFileCache from botocore.session import get_session logger = logging.getLogger(__name__) -@lru_cache(maxsize=None) -def establish_boto3_session(region_name: Optional[str] = None) -> Session: +# A note on thread safety: +# +# Boto3 Session: Not thread safe, 1 per thread is required. +# +# Boto3 Resources: Not thread safe, one per thread is required. +# +# Boto3 Client: Thread safe after initialization, but initialization is *not* +# thread safe and only one can be being made at a time. They also are +# restricted to a single Python *process*. +# +# See: + +# We use this lock to control initialization so only one thread can be +# initializing Boto3 (or Boto2) things at a time. +_init_lock = threading.RLock() + +def _new_boto3_session(region_name: Optional[str] = None) -> Session: """ - This is the One True Place where Boto3 sessions should be established, and - prepares them with the necessary credential caching. + This is the One True Place where new Boto3 sessions should be made, and + prepares them with the necessary credential caching. Does *not* cache + sessions, because each thread needs its own caching. :param region_name: If given, the session will be associated with the given AWS region. """ @@ -55,35 +71,12 @@ def establish_boto3_session(region_name: Optional[str] = None) -> Session: # See https://github.com/boto/botocore/pull/1338/ # And https://github.com/boto/botocore/commit/2dae76f52ae63db3304b5933730ea5efaaaf2bfc - botocore_session = get_session() - botocore_session.get_component('credential_provider').get_provider( - 'assume-role').cache = JSONFileCache() - - return Session(botocore_session=botocore_session, region_name=region_name, profile_name=os.environ.get("TOIL_AWS_PROFILE", None)) - -@lru_cache(maxsize=None) -def client(service_name: str, *args: List[Any], region_name: Optional[str] = None, **kwargs: Dict[str, Any]) -> botocore.client.BaseClient: - """ - Get a Boto 3 client for a particular AWS service. - - Global alternative to AWSConnectionManager. - """ - session = establish_boto3_session(region_name=region_name) - # MyPy can't understand our argument unpacking. See - client: botocore.client.BaseClient = session.client(service_name, *args, **kwargs) # type: ignore - return client - -@lru_cache(maxsize=None) -def resource(service_name: str, *args: List[Any], region_name: Optional[str] = None, **kwargs: Dict[str, Any]) -> boto3.resources.base.ServiceResource: - """ - Get a Boto 3 resource for a particular AWS service. + with _init_lock: + botocore_session = get_session() + botocore_session.get_component('credential_provider').get_provider( + 'assume-role').cache = JSONFileCache() - Global alternative to AWSConnectionManager. - """ - session = establish_boto3_session(region_name=region_name) - # MyPy can't understand our argument unpacking. See - resource: boto3.resources.base.ServiceResource = session.resource(service_name, *args, **kwargs) # type: ignore - return resource + return Session(botocore_session=botocore_session, region_name=region_name, profile_name=os.environ.get("TOIL_AWS_PROFILE", None)) class AWSConnectionManager: """ @@ -98,6 +91,10 @@ class AWSConnectionManager: connections to multiple regions may need to be managed in the same provisioner. + We also support None for a region, in which case no region will be + passed to Boto/Boto3. The caller is responsible for implementing e.g. + TOIL_AWS_REGION support. + Since connection objects may not be thread safe (see ), one is created for each thread that calls the relevant lookup method. @@ -115,18 +112,18 @@ def __init__(self) -> None: """ # This stores Boto3 sessions in .item of a thread-local storage, by # region. - self.sessions_by_region: Dict[str, threading.local] = collections.defaultdict(threading.local) + self.sessions_by_region: Dict[Optional[str], threading.local] = collections.defaultdict(threading.local) # This stores Boto3 resources in .item of a thread-local storage, by - # (region, service name) tuples - self.resource_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local) + # (region, service name, endpoint URL) tuples + self.resource_cache: Dict[Tuple[Optional[str], str, Optional[str]], threading.local] = collections.defaultdict(threading.local) # This stores Boto3 clients in .item of a thread-local storage, by - # (region, service name) tuples - self.client_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local) + # (region, service name, endpoint URL) tuples + self.client_cache: Dict[Tuple[Optional[str], str, Optional[str]], threading.local] = collections.defaultdict(threading.local) # This stores Boto 2 connections in .item of a thread-local storage, by # (region, service name) tuples. - self.boto2_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local) + self.boto2_cache: Dict[Tuple[Optional[str], str], threading.local] = collections.defaultdict(threading.local) - def session(self, region: str) -> boto3.session.Session: + def session(self, region: Optional[str]) -> boto3.session.Session: """ Get the Boto3 Session to use for the given region. """ @@ -134,35 +131,68 @@ def session(self, region: str) -> boto3.session.Session: if not hasattr(storage, 'item'): # This is the first time this thread wants to talk to this region # through this manager - storage.item = establish_boto3_session(region_name=region) + storage.item = _new_boto3_session(region_name=region) return cast(boto3.session.Session, storage.item) - def resource(self, region: str, service_name: str) -> boto3.resources.base.ServiceResource: + def resource(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource: """ Get the Boto3 Resource to use with the given service (like 'ec2') in the given region. + + :param endpoint_url: AWS endpoint URL to use for the client. If not + specified, a default is used. """ - key = (region, service_name) + key = (region, service_name, endpoint_url) storage = self.resource_cache[key] if not hasattr(storage, 'item'): - # The Boto3 stubs are missing an overload for `resource` that takes - # a non-literal string. See - # - storage.item = self.session(region).resource(service_name) # type: ignore + with _init_lock: + # We lock inside the if check; we don't care if the memoization + # sometimes results in multiple different copies leaking out. + # We lock because we call .resource() + + if endpoint_url is not None: + # The Boto3 stubs are missing an overload for `resource` that takes + # a non-literal string. See + # + storage.item = self.session(region).resource(service_name, endpoint_url=endpoint_url) # type: ignore + else: + # We might not be able to pass None to Boto3 and have it be the same as no argument. + storage.item = self.session(region).resource(service_name) # type: ignore + return cast(boto3.resources.base.ServiceResource, storage.item) - def client(self, region: str, service_name: str) -> botocore.client.BaseClient: + def client(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient: """ Get the Boto3 Client to use with the given service (like 'ec2') in the given region. + + :param endpoint_url: AWS endpoint URL to use for the client. If not + specified, a default is used. + :param config: Custom configuration to use for the client. """ - key = (region, service_name) + + if config is not None: + # Don't try and memoize if a custom config is used + with _init_lock: + if endpoint_url is not None: + return self.session(region).client(service_name, endpoint_url=endpoint_url, config=config) # type: ignore + else: + return self.session(region).client(service_name, config=config) # type: ignore + + key = (region, service_name, endpoint_url) storage = self.client_cache[key] if not hasattr(storage, 'item'): - # The Boto3 stubs are probably missing an overload here too. See: - # - storage.item = self.session(region).client(service_name) # type: ignore + with _init_lock: + # We lock because we call .client() + + if endpoint_url is not None: + # The Boto3 stubs are probably missing an overload here too. See: + # + storage.item = self.session(region).client(service_name, endpoint_url=endpoint_url) # type: ignore + else: + # We might not be able to pass None to Boto3 and have it be the same as no argument. + storage.item = self.session(region).client(service_name) # type: ignore return cast(botocore.client.BaseClient , storage.item) - def boto2(self, region: str, service_name: str) -> boto.connection.AWSAuthConnection: + def boto2(self, region: Optional[str], service_name: str) -> boto.connection.AWSAuthConnection: """ Get the connected boto2 connection for the given region and service. """ @@ -172,5 +202,39 @@ def boto2(self, region: str, service_name: str) -> boto.connection.AWSAuthConnec key = (region, service_name) storage = self.boto2_cache[key] if not hasattr(storage, 'item'): - storage.item = getattr(boto, service_name).connect_to_region(region, profile_name=os.environ.get("TOIL_AWS_PROFILE", None)) + with _init_lock: + storage.item = getattr(boto, service_name).connect_to_region(region, profile_name=os.environ.get("TOIL_AWS_PROFILE", None)) return cast(boto.connection.AWSAuthConnection, storage.item) + +# If you don't want your own AWSConnectionManager, we have a global one and some global functions +_global_manager = AWSConnectionManager() + +def establish_boto3_session(region_name: Optional[str] = None) -> Session: + """ + Get a Boto 3 session usable by the current thread. + + This function may not always establish a *new* session; it can be memoized. + """ + + # Just use a global version of the manager. Note that we change the argument order! + return _global_manager.session(region_name) + +def client(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient: + """ + Get a Boto 3 client for a particular AWS service, usable by the current thread. + + Global alternative to AWSConnectionManager. + """ + + # Just use a global version of the manager. Note that we change the argument order! + return _global_manager.client(region_name, service_name, endpoint_url=endpoint_url, config=config) + +def resource(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource: + """ + Get a Boto 3 resource for a particular AWS service, usable by the current thread. + + Global alternative to AWSConnectionManager. + """ + + # Just use a global version of the manager. Note that we change the argument order! + return _global_manager.resource(region_name, service_name, endpoint_url=endpoint_url) diff --git a/src/toil/test/src/busTest.py b/src/toil/test/src/busTest.py index 6be0ab4507..da6edb5421 100644 --- a/src/toil/test/src/busTest.py +++ b/src/toil/test/src/busTest.py @@ -13,12 +13,19 @@ # limitations under the License. import logging +import os from threading import Thread, current_thread +from typing import Optional from toil.batchSystems.abstractBatchSystem import BatchJobExitReason from toil.bus import JobCompletedMessage, JobIssuedMessage, MessageBus, replay_message_bus +from toil.common import Toil +from toil.job import Job +from toil.exceptions import FailedJobsException from toil.test import ToilTest, get_temp_file + + logger = logging.getLogger(__name__) class MessageBusTest(ToilTest): @@ -95,5 +102,62 @@ def send_thread_message() -> None: # And having polled for those, our handler should have run self.assertEqual(message_count, 11) + def test_restart_without_bus_path(self) -> None: + """ + Test the ability to restart a workflow when the message bus path used + by the previous attempt is gone. + """ + temp_dir = self._createTempDir(purpose='tempDir') + job_store = self._getTestJobStorePath() + + bus_holder_dir = os.path.join(temp_dir, 'bus_holder') + os.mkdir(bus_holder_dir) + + start_options = Job.Runner.getDefaultOptions(job_store) + start_options.logLevel = 'DEBUG' + start_options.retryCount = 0 + start_options.clean = "never" + start_options.write_messages = os.path.abspath(os.path.join(bus_holder_dir, 'messagebus.txt')) + + root = Job.wrapJobFn(failing_job_fn) + + try: + with Toil(start_options) as toil: + # Run once and observe a failed job + toil.start(root) + except FailedJobsException: + pass + + logger.info('First attempt successfully failed, removing message bus log') + + # Get rid of the bus + os.unlink(start_options.write_messages) + os.rmdir(bus_holder_dir) + + logger.info('Making second attempt') + + # Set up options without a specific bus path + restart_options = Job.Runner.getDefaultOptions(job_store) + restart_options.logLevel = 'DEBUG' + restart_options.retryCount = 0 + restart_options.clean = "never" + restart_options.restart = True + + try: + with Toil(restart_options) as toil: + # Run again and observe a failed job (and not a failure to start) + toil.restart() + except FailedJobsException: + pass + + logger.info('Second attempt successfully failed') + + +def failing_job_fn(job: Job) -> None: + """ + This function is guaranteed to fail. + """ + raise RuntimeError('Job attempted to run but failed') + diff --git a/src/toil/utils/toilStatus.py b/src/toil/utils/toilStatus.py index a25aa648e3..45617555ea 100644 --- a/src/toil/utils/toilStatus.py +++ b/src/toil/utils/toilStatus.py @@ -232,11 +232,14 @@ def print_bus_messages(self) -> None: """ print("\nMessage bus path: ", self.message_bus_path) - - replayed_messages = replay_message_bus(self.message_bus_path) - for key in replayed_messages: - if replayed_messages[key].exit_code != 0: - print(replayed_messages[key]) + if self.message_bus_path is not None: + if os.path.exists(self.message_bus_path): + replayed_messages = replay_message_bus(self.message_bus_path) + for key in replayed_messages: + if replayed_messages[key].exit_code != 0: + print(replayed_messages[key]) + else: + print("Message bus file is missing!") return None