Avoid re-using old message bus path when restarting a workflow (DataBiosphere#4445)

* Add test to replicate the missing message bus crash

* Fix DataBiosphere#4438 by not trying to re-use the old message bus log when restarting

* Fix DataBiosphere#4410 by making everybody use thread-local Boto3 caching and a global initialization lock

* Restore handling for custom client config
adamnovak authored Apr 26, 2023
1 parent 2804edf commit 356261d
Showing 4 changed files with 213 additions and 62 deletions.
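To make the thread-safety bullet above concrete, here is a minimal, self-contained sketch of the pattern this commit applies in src/toil/lib/aws/session.py: one Boto3 session per thread and per region, constructed while holding a global lock because Boto3 object construction is not thread safe. The helper name thread_local_session is illustrative only and is not part of Toil's API.

import collections
import threading
from typing import Dict, Optional

import boto3

# Global lock guarding Boto3 object construction, which is not thread safe.
_init_lock = threading.RLock()
# One thread-local slot per region; each thread lazily builds its own session.
_sessions_by_region: Dict[Optional[str], threading.local] = collections.defaultdict(threading.local)

def thread_local_session(region: Optional[str] = None) -> boto3.session.Session:
    storage = _sessions_by_region[region]
    if not hasattr(storage, 'item'):
        # First use of this region on this thread: construct under the lock,
        # then cache the session in thread-local storage for re-use.
        with _init_lock:
            storage.item = boto3.session.Session(region_name=region)
    return storage.item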
26 changes: 23 additions & 3 deletions src/toil/common.py
@@ -198,7 +198,7 @@ def __init__(self) -> None:
self.writeLogs = None
self.writeLogsGzip = None
self.writeLogsFromAllJobs: bool = False
self.write_messages: str = ""
self.write_messages: Optional[str] = None

# Misc
self.environment: Dict[str, str] = {}
@@ -222,6 +222,24 @@ def __init__(self) -> None:
# CWL
self.cwl: bool = False

def prepare_start(self) -> None:
"""
After options are set, prepare for initial start of workflow.
"""
self.workflowAttemptNumber = 0

def prepare_restart(self) -> None:
"""
Before restart options are set, prepare for a restart of a workflow.
Set up any execution-specific parameters and clear out any stale ones.
"""
self.workflowAttemptNumber += 1
# We should clear the stored message bus path, because it may have been
# auto-generated and point to a temp directory that could no longer
# exist and that can't safely be re-made.
self.write_messages = None


def setOptions(self, options: Namespace) -> None:
"""Creates a config object from the options object."""
OptionType = TypeVar("OptionType")
@@ -407,6 +425,8 @@ def check_nodestoreage_overrides(overrides: List[str]) -> bool:
set_option("write_messages", os.path.abspath)

if not self.write_messages:
# The user hasn't specified a place for the message bus so we
# should make one.
self.write_messages = gen_message_bus_path()

assert not (self.writeLogs and self.writeLogsGzip), \
@@ -947,14 +967,14 @@ def __enter__(self) -> "Toil":
self.options.caching = config.caching

if not config.restart:
config.workflowAttemptNumber = 0
config.prepare_start()
jobStore.initialize(config)
else:
jobStore.resume()
# Merge configuration from job store with command line options
config = jobStore.config
config.prepare_restart()
config.setOptions(self.options)
config.workflowAttemptNumber += 1
jobStore.write_config()
self.config = config
self._jobStore = jobStore
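To illustrate the common.py change above, here is a small standalone example (not Toil code; the config dict and file names are invented) of the failure mode being fixed: a message bus path auto-generated under a temporary directory on the first attempt cannot be reopened on a restart once that directory is gone, which is why prepare_restart() clears the stored value so a fresh path gets generated.

import os
import tempfile

config = {}

# Attempt 1: the bus path is auto-generated inside a temporary directory.
first_attempt_dir = tempfile.mkdtemp()
config['write_messages'] = os.path.join(first_attempt_dir, 'messagebus.txt')
open(config['write_messages'], 'w').close()

# Between attempts, the temporary directory disappears.
os.unlink(config['write_messages'])
os.rmdir(first_attempt_dir)

# Old behavior on restart: blindly re-use the stored path and crash, because
# its parent directory no longer exists.
try:
    open(config['write_messages'], 'a')
except FileNotFoundError:
    print('stale bus path cannot be reopened')

# New behavior on restart: clear the stale value, then fall through to the
# "user hasn't specified a place" branch and generate a fresh path.
config['write_messages'] = None
if not config['write_messages']:
    config['write_messages'] = os.path.join(tempfile.mkdtemp(), 'messagebus.txt')
print('fresh bus path:', config['write_messages'])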
172 changes: 118 additions & 54 deletions src/toil/lib/aws/session.py
@@ -18,7 +18,6 @@
import re
import socket
import threading
from functools import lru_cache
from typing import (Any,
Callable,
Dict,
@@ -37,16 +36,33 @@
import boto.connection
import botocore
from boto3 import Session
from botocore.client import Config
from botocore.credentials import JSONFileCache
from botocore.session import get_session

logger = logging.getLogger(__name__)

@lru_cache(maxsize=None)
def establish_boto3_session(region_name: Optional[str] = None) -> Session:
# A note on thread safety:
#
# Boto3 Session: Not thread safe, 1 per thread is required.
#
# Boto3 Resources: Not thread safe, one per thread is required.
#
# Boto3 Client: Thread safe after initialization, but initialization is *not*
# thread safe and only one can be being made at a time. They also are
# restricted to a single Python *process*.
#
# See: <https://stackoverflow.com/questions/52820971/is-boto3-client-thread-safe>

# We use this lock to control initialization so only one thread can be
# initializing Boto3 (or Boto2) things at a time.
_init_lock = threading.RLock()

def _new_boto3_session(region_name: Optional[str] = None) -> Session:
"""
This is the One True Place where Boto3 sessions should be established, and
prepares them with the necessary credential caching.
This is the One True Place where new Boto3 sessions should be made, and
prepares them with the necessary credential caching. Does *not* cache
sessions, because each thread needs its own caching.
:param region_name: If given, the session will be associated with the given AWS region.
"""
@@ -55,35 +71,12 @@ def establish_boto3_session(region_name: Optional[str] = None) -> Session:
# See https://github.com/boto/botocore/pull/1338/
# And https://github.com/boto/botocore/commit/2dae76f52ae63db3304b5933730ea5efaaaf2bfc

botocore_session = get_session()
botocore_session.get_component('credential_provider').get_provider(
'assume-role').cache = JSONFileCache()

return Session(botocore_session=botocore_session, region_name=region_name, profile_name=os.environ.get("TOIL_AWS_PROFILE", None))

@lru_cache(maxsize=None)
def client(service_name: str, *args: List[Any], region_name: Optional[str] = None, **kwargs: Dict[str, Any]) -> botocore.client.BaseClient:
"""
Get a Boto 3 client for a particular AWS service.
Global alternative to AWSConnectionManager.
"""
session = establish_boto3_session(region_name=region_name)
# MyPy can't understand our argument unpacking. See <https://github.com/vemel/mypy_boto3_builder/issues/121>
client: botocore.client.BaseClient = session.client(service_name, *args, **kwargs) # type: ignore
return client

@lru_cache(maxsize=None)
def resource(service_name: str, *args: List[Any], region_name: Optional[str] = None, **kwargs: Dict[str, Any]) -> boto3.resources.base.ServiceResource:
"""
Get a Boto 3 resource for a particular AWS service.
with _init_lock:
botocore_session = get_session()
botocore_session.get_component('credential_provider').get_provider(
'assume-role').cache = JSONFileCache()

Global alternative to AWSConnectionManager.
"""
session = establish_boto3_session(region_name=region_name)
# MyPy can't understand our argument unpacking. See <https://github.com/vemel/mypy_boto3_builder/issues/121>
resource: boto3.resources.base.ServiceResource = session.resource(service_name, *args, **kwargs) # type: ignore
return resource
return Session(botocore_session=botocore_session, region_name=region_name, profile_name=os.environ.get("TOIL_AWS_PROFILE", None))

class AWSConnectionManager:
"""
@@ -98,6 +91,10 @@ class AWSConnectionManager:
connections to multiple regions may need to be managed in the same
provisioner.
We also support None for a region, in which case no region will be
passed to Boto/Boto3. The caller is responsible for implementing e.g.
TOIL_AWS_REGION support.
Since connection objects may not be thread safe (see
<https://boto3.amazonaws.com/v1/documentation/api/1.14.31/guide/session.html#multithreading-or-multiprocessing-with-sessions>),
one is created for each thread that calls the relevant lookup method.
@@ -115,54 +112,87 @@ def __init__(self) -> None:
"""
# This stores Boto3 sessions in .item of a thread-local storage, by
# region.
self.sessions_by_region: Dict[str, threading.local] = collections.defaultdict(threading.local)
self.sessions_by_region: Dict[Optional[str], threading.local] = collections.defaultdict(threading.local)
# This stores Boto3 resources in .item of a thread-local storage, by
# (region, service name) tuples
self.resource_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local)
# (region, service name, endpoint URL) tuples
self.resource_cache: Dict[Tuple[Optional[str], str, Optional[str]], threading.local] = collections.defaultdict(threading.local)
# This stores Boto3 clients in .item of a thread-local storage, by
# (region, service name) tuples
self.client_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local)
# (region, service name, endpoint URL) tuples
self.client_cache: Dict[Tuple[Optional[str], str, Optional[str]], threading.local] = collections.defaultdict(threading.local)
# This stores Boto 2 connections in .item of a thread-local storage, by
# (region, service name) tuples.
self.boto2_cache: Dict[Tuple[str, str], threading.local] = collections.defaultdict(threading.local)
self.boto2_cache: Dict[Tuple[Optional[str], str], threading.local] = collections.defaultdict(threading.local)

def session(self, region: str) -> boto3.session.Session:
def session(self, region: Optional[str]) -> boto3.session.Session:
"""
Get the Boto3 Session to use for the given region.
"""
storage = self.sessions_by_region[region]
if not hasattr(storage, 'item'):
# This is the first time this thread wants to talk to this region
# through this manager
storage.item = establish_boto3_session(region_name=region)
storage.item = _new_boto3_session(region_name=region)
return cast(boto3.session.Session, storage.item)

def resource(self, region: str, service_name: str) -> boto3.resources.base.ServiceResource:
def resource(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource:
"""
Get the Boto3 Resource to use with the given service (like 'ec2') in the given region.
:param endpoint_url: AWS endpoint URL to use for the client. If not
specified, a default is used.
"""
key = (region, service_name)
key = (region, service_name, endpoint_url)
storage = self.resource_cache[key]
if not hasattr(storage, 'item'):
# The Boto3 stubs are missing an overload for `resource` that takes
# a non-literal string. See
# <https://github.com/vemel/mypy_boto3_builder/issues/121#issuecomment-1011322636>
storage.item = self.session(region).resource(service_name) # type: ignore
with _init_lock:
# We lock inside the if check; we don't care if the memoization
# sometimes results in multiple different copies leaking out.
# We lock because we call .resource()

if endpoint_url is not None:
# The Boto3 stubs are missing an overload for `resource` that takes
# a non-literal string. See
# <https://github.com/vemel/mypy_boto3_builder/issues/121#issuecomment-1011322636>
storage.item = self.session(region).resource(service_name, endpoint_url=endpoint_url) # type: ignore
else:
# We might not be able to pass None to Boto3 and have it be the same as no argument.
storage.item = self.session(region).resource(service_name) # type: ignore

return cast(boto3.resources.base.ServiceResource, storage.item)

def client(self, region: str, service_name: str) -> botocore.client.BaseClient:
def client(self, region: Optional[str], service_name: str, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient:
"""
Get the Boto3 Client to use with the given service (like 'ec2') in the given region.
:param endpoint_url: AWS endpoint URL to use for the client. If not
specified, a default is used.
:param config: Custom configuration to use for the client.
"""
key = (region, service_name)

if config is not None:
# Don't try and memoize if a custom config is used
with _init_lock:
if endpoint_url is not None:
return self.session(region).client(service_name, endpoint_url=endpoint_url, config=config) # type: ignore
else:
return self.session(region).client(service_name, config=config) # type: ignore

key = (region, service_name, endpoint_url)
storage = self.client_cache[key]
if not hasattr(storage, 'item'):
# The Boto3 stubs are probably missing an overload here too. See:
# <https://github.com/vemel/mypy_boto3_builder/issues/121#issuecomment-1011322636>
storage.item = self.session(region).client(service_name) # type: ignore
with _init_lock:
# We lock because we call .client()

if endpoint_url is not None:
# The Boto3 stubs are probably missing an overload here too. See:
# <https://github.com/vemel/mypy_boto3_builder/issues/121#issuecomment-1011322636>
storage.item = self.session(region).client(service_name, endpoint_url=endpoint_url) # type: ignore
else:
# We might not be able to pass None to Boto3 and have it be the same as no argument.
storage.item = self.session(region).client(service_name) # type: ignore
return cast(botocore.client.BaseClient , storage.item)

def boto2(self, region: str, service_name: str) -> boto.connection.AWSAuthConnection:
def boto2(self, region: Optional[str], service_name: str) -> boto.connection.AWSAuthConnection:
"""
Get the connected boto2 connection for the given region and service.
"""
@@ -172,5 +202,39 @@ def boto2(self, region: str, service_name: str) -> boto.connection.AWSAuthConnection:
key = (region, service_name)
storage = self.boto2_cache[key]
if not hasattr(storage, 'item'):
storage.item = getattr(boto, service_name).connect_to_region(region, profile_name=os.environ.get("TOIL_AWS_PROFILE", None))
with _init_lock:
storage.item = getattr(boto, service_name).connect_to_region(region, profile_name=os.environ.get("TOIL_AWS_PROFILE", None))
return cast(boto.connection.AWSAuthConnection, storage.item)

# If you don't want your own AWSConnectionManager, we have a global one and some global functions
_global_manager = AWSConnectionManager()

def establish_boto3_session(region_name: Optional[str] = None) -> Session:
"""
Get a Boto 3 session usable by the current thread.
This function may not always establish a *new* session; it can be memoized.
"""

# Just use a global version of the manager. Note that we change the argument order!
return _global_manager.session(region_name)

def client(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None, config: Optional[Config] = None) -> botocore.client.BaseClient:
"""
Get a Boto 3 client for a particular AWS service, usable by the current thread.
Global alternative to AWSConnectionManager.
"""

# Just use a global version of the manager. Note that we change the argument order!
return _global_manager.client(region_name, service_name, endpoint_url=endpoint_url, config=config)

def resource(service_name: str, region_name: Optional[str] = None, endpoint_url: Optional[str] = None) -> boto3.resources.base.ServiceResource:
"""
Get a Boto 3 resource for a particular AWS service, usable by the current thread.
Global alternative to AWSConnectionManager.
"""

# Just use a global version of the manager. Note that we change the argument order!
return _global_manager.resource(region_name, service_name, endpoint_url=endpoint_url)
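
A brief usage sketch of the module-level helpers defined above, based only on the signatures shown in this diff; the region and retry settings are placeholder values.

from botocore.client import Config
from toil.lib.aws.session import client, establish_boto3_session, resource

# One session per calling thread, with assume-role credential caching configured.
session = establish_boto3_session(region_name='us-west-2')

# Clients and resources are cached per (region, service, endpoint URL) and per thread.
s3_client = client('s3', region_name='us-west-2')
s3_resource = resource('s3', region_name='us-west-2')

# Passing a custom botocore Config bypasses the cache and builds a one-off client.
tuned_client = client('s3', region_name='us-west-2', config=Config(retries={'max_attempts': 5}))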
64 changes: 64 additions & 0 deletions src/toil/test/src/busTest.py
@@ -13,12 +13,19 @@
# limitations under the License.

import logging
import os
from threading import Thread, current_thread
from typing import Optional

from toil.batchSystems.abstractBatchSystem import BatchJobExitReason
from toil.bus import JobCompletedMessage, JobIssuedMessage, MessageBus, replay_message_bus
from toil.common import Toil
from toil.job import Job
from toil.exceptions import FailedJobsException
from toil.test import ToilTest, get_temp_file



logger = logging.getLogger(__name__)

class MessageBusTest(ToilTest):
@@ -95,5 +102,62 @@ def send_thread_message() -> None:
# And having polled for those, our handler should have run
self.assertEqual(message_count, 11)

def test_restart_without_bus_path(self) -> None:
"""
Test the ability to restart a workflow when the message bus path used
by the previous attempt is gone.
"""
temp_dir = self._createTempDir(purpose='tempDir')
job_store = self._getTestJobStorePath()

bus_holder_dir = os.path.join(temp_dir, 'bus_holder')
os.mkdir(bus_holder_dir)

start_options = Job.Runner.getDefaultOptions(job_store)
start_options.logLevel = 'DEBUG'
start_options.retryCount = 0
start_options.clean = "never"
start_options.write_messages = os.path.abspath(os.path.join(bus_holder_dir, 'messagebus.txt'))

root = Job.wrapJobFn(failing_job_fn)

try:
with Toil(start_options) as toil:
# Run once and observe a failed job
toil.start(root)
except FailedJobsException:
pass

logger.info('First attempt successfully failed, removing message bus log')

# Get rid of the bus
os.unlink(start_options.write_messages)
os.rmdir(bus_holder_dir)

logger.info('Making second attempt')

# Set up options without a specific bus path
restart_options = Job.Runner.getDefaultOptions(job_store)
restart_options.logLevel = 'DEBUG'
restart_options.retryCount = 0
restart_options.clean = "never"
restart_options.restart = True

try:
with Toil(restart_options) as toil:
# Run again and observe a failed job (and not a failure to start)
toil.restart()
except FailedJobsException:
pass

logger.info('Second attempt successfully failed')


def failing_job_fn(job: Job) -> None:
"""
This function is guaranteed to fail.
"""
raise RuntimeError('Job attempted to run but failed')
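
To run just the new test locally, something like the following should work (assuming pytest is installed in the environment; the node ID uses pytest's standard file::class::method selection syntax):

import pytest

# Select only the restart-without-bus-path test added in this commit.
pytest.main(['-v', 'src/toil/test/src/busTest.py::MessageBusTest::test_restart_without_bus_path'])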


