Skip to content

Commit

Permalink
Add SSH command for executions (#314)
Browse files Browse the repository at this point in the history
  • Loading branch information
memona008 authored Sep 12, 2024
1 parent 1ae4a55 commit fb79331
Show file tree
Hide file tree
Showing 10 changed files with 338 additions and 9 deletions.
80 changes: 80 additions & 0 deletions tests/commands/execution/test_ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,80 @@
import shutil
import subprocess
import time

from tests.commands.execution.utils import get_execution_data_mock, no_sleep
from tests.fixture_data import EXECUTION_DATA, STATUS_EVENT_RESPONSE_DATA
from valohai_cli.commands.execution.ssh import ssh


def test_ssh_in_completed_execution(runner, logged_in_and_linked, monkeypatch):
with get_execution_data_mock():
counter = EXECUTION_DATA["counter"]
result = runner.invoke(ssh, [str(counter)], catch_exceptions=False)
assert f"Error: Execution #{counter} is complete. Cannot SSH into it.\n" in result.output
assert result.exit_code == 1


def test_ssh_in_queued_execution(runner, logged_in_and_linked, monkeypatch):
counter = EXECUTION_DATA["counter"]
monkeypatch.setitem(EXECUTION_DATA, "status", "queued")
monkeypatch.setattr(time, "sleep", no_sleep)
with get_execution_data_mock():
result = runner.invoke(ssh, [str(counter)], catch_exceptions=False)
assert f"Execution #{counter} is queued. Waiting for it to start...\n" in result.output
assert result.exit_code == 1


def test_ssh_with_no_ssh_details_present(runner, logged_in_and_linked, monkeypatch):
counter = EXECUTION_DATA["counter"]
monkeypatch.setitem(EXECUTION_DATA, "status", "started")
monkeypatch.setattr(time, "sleep", lambda x: None)
with get_execution_data_mock() as m:
m.get(
f"https://app.valohai.com/api/v0/executions/{EXECUTION_DATA['id']}/status-events/",
json={"status_events": []},
)
result = runner.invoke(ssh, [str(counter)], catch_exceptions=False)
output = result.output
assert "1/5 Retrying: No SSH details found...\n" in output
assert "2/5 Retrying: No SSH details found...\n" in output
assert "3/5 Retrying: No SSH details found...\n" in output
assert "4/5 Retrying: No SSH details found...\n" in output
assert "5/5 Retrying: No SSH details found...\n" in output

assert result.exit_code == 1


def test_ssh(runner, logged_in_and_linked, monkeypatch, tmp_path):
counter = EXECUTION_DATA["counter"]
monkeypatch.setitem(EXECUTION_DATA, "status", "started")

def mock_prompt():
return tmp_path

monkeypatch.setattr(
"valohai_cli.commands.execution.ssh.select_private_key_from_possible_directories",
mock_prompt,
)

with get_execution_data_mock() as m:
m.get(
f"https://app.valohai.com/api/v0/executions/{EXECUTION_DATA['id']}/status-events/",
json=STATUS_EVENT_RESPONSE_DATA,
)
result = runner.invoke(ssh, [str(counter)], catch_exceptions=False)
output = result.output
assert "SSH address is 127.0.0.1:2222" in output

def mock_subprocess_run(*args, **kwargs):
print(args[0])
return subprocess.CompletedProcess(args=args, returncode=0)

monkeypatch.setattr(subprocess, "run", mock_subprocess_run)

result = result.runner.invoke(ssh, [str(counter)], input="1", catch_exceptions=False)
assert (
f"['{shutil.which('ssh')}', '-i', PosixPath('{tmp_path}'), '[email protected]', '-p', '2222', '-t', '/bin/bash']"
in result.output
)
assert result.exit_code == 0
6 changes: 1 addition & 5 deletions tests/commands/execution/test_watch.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,10 @@
import time

from tests.commands.execution.utils import get_execution_data_mock
from tests.commands.execution.utils import get_execution_data_mock, no_sleep
from tests.fixture_data import EXECUTION_DATA, PROJECT_DATA
from valohai_cli.commands.execution.watch import watch


def no_sleep(t):
raise KeyboardInterrupt("no... sleep... til... Brooklyn!")


def test_execution_watch(runner, logged_in_and_linked, monkeypatch):
monkeypatch.setattr(time, "sleep", no_sleep)
with get_execution_data_mock():
Expand Down
4 changes: 4 additions & 0 deletions tests/commands/execution/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,3 +36,7 @@ def get_execution_data_mock():
m.delete(execution_by_counter_url, json={"ok": True})
m.post(re.compile("^https://app.valohai.com/api/v0/data/(.+?)/purge/$"), json={"ok": True})
return m


def no_sleep(t):
raise KeyboardInterrupt("no... sleep... til... Brooklyn!")
17 changes: 16 additions & 1 deletion tests/fixture_data.py
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@
path: model.h5
"""


PIPELINE_YAML = (
CONFIG_YAML
+ """
Expand Down Expand Up @@ -665,3 +664,19 @@ def main(config) -> Pipeline:
description: Model output file from TensorFlow
path: model.h5
"""

STATUS_EVENT_RESPONSE_DATA = {
"total": 2,
"status_events": [
{
"stream": "status",
"message": '::ssh::{"port": 2222, "address": "127.0.0.1"}',
"time": "2024-09-04T12:16:20.722000",
},
{
"stream": "status",
"message": " $ ssh -i <path-to-private-key> 127.0.0.1 -p 2222 -t /bin/bash",
"time": "2024-09-04T12:16:20.723000",
},
],
}
10 changes: 10 additions & 0 deletions valohai_cli/commands/execution/run/dynamic_run_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@
from valohai_cli.utils.file_input import read_data_file
from valohai_cli.utils.friendly_option_parser import FriendlyOptionParser

from ..ssh import ssh
from .excs import ExecutionCreationAPIError


Expand Down Expand Up @@ -59,6 +60,7 @@ def __init__(
tags: Optional[Sequence[str]] = None,
runtime_config: Optional[dict] = None,
runtime_config_preset: Optional[str] = None,
ssh: bool = False,
) -> None:
"""
Initialize the dynamic run command.
Expand All @@ -75,6 +77,7 @@ def __init__(
:param download_directory: Where to (if somewhere) to download execution outputs (sync mode)
:param runtime_config: Runtime config dict
:param runtime_config_preset: Runtime config preset identifier (UUID)
:param ssh: Whether to chain to `exec ssh` afterward
"""
assert isinstance(step, Step)
self.project = project
Expand All @@ -90,6 +93,7 @@ def __init__(
self.tags = list(tags or [])
self.runtime_config = dict(runtime_config or {})
self.runtime_config_preset = runtime_config_preset
self.ssh = ssh
super().__init__(
name=sanitize_option_name(step.name.lower()),
callback=self.execute,
Expand Down Expand Up @@ -201,6 +205,12 @@ def execute(self, **kwargs: Any) -> None:

webbrowser.open(resp["urls"]["display"])

if self.ssh:
try:
ctx.invoke(ssh, counter=resp["counter"])
except Exception as e:
warn(f"Failed to open SSH connection: {e}")

if self.watch:
from valohai_cli.commands.execution.watch import watch

Expand Down
12 changes: 10 additions & 2 deletions valohai_cli/commands/execution/run/frontend_command.py
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,11 @@
)
@click.option("--debug-port", type=int)
@click.option("--debug-key-file", type=click.Path(file_okay=True, readable=True, writable=False))
@click.option(
"--ssh",
is_flag=True,
help="Start ssh remote connection for debugging the execution after it starts.",
)
@click.option(
"--autorestart/--no-autorestart",
help="Enable Automatic Restart on Spot Instance Interruption",
Expand Down Expand Up @@ -184,6 +189,7 @@ def run(
k8s_devices: List[str],
k8s_device_none: bool,
k8s_preset: Optional[str],
ssh: bool = False,
) -> Any:
"""
Start an execution of a step.
Expand All @@ -197,8 +203,8 @@ def run(
project = get_project(require=True)
project.refresh_details()

if download_directory and watch:
raise click.UsageError("Combining --sync and --watch not supported yet.")
if sum([watch, download_directory is not None, ssh]) > 1:
raise click.UsageError("Only one of --watch, --sync or --ssh can be set.")

if not commit and project.is_remote:
# For remote projects, we need to resolve early.
Expand Down Expand Up @@ -239,6 +245,7 @@ def run(
"debug_port": debug_port,
"debug_key": key,
}

if autorestart:
runtime_config["autorestart"] = autorestart

Expand Down Expand Up @@ -291,6 +298,7 @@ def run(
tags=tags,
runtime_config=runtime_config,
runtime_config_preset=k8s_preset,
ssh=ssh,
)
with rc.make_context(rc.name, list(args), parent=ctx) as child_ctx:
return rc.invoke(child_ctx)
Expand Down
46 changes: 46 additions & 0 deletions valohai_cli/commands/execution/ssh.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
import click

from valohai_cli.exceptions import CLIException
from valohai_cli.utils.cli_utils import counter_argument
from valohai_cli.utils.ssh import (
get_ssh_details_with_retry,
make_ssh_connection,
select_private_key_from_possible_directories,
)


@click.command()
@counter_argument
@click.option(
"--private-key-file",
default=None,
type=click.Path(file_okay=True, exists=True),
help="Private SSH key to use for the connection.",
)
@click.option(
"--address",
default=None,
help='Address of the container in "ip:port" format. If not provided, '
"the address from the execution will be used.",
)
def ssh(counter: int, private_key_file: str, address: str) -> None:
"""
Make SSH Connection to the execution container.
"""
if address:
try:
ip_address, _, port_str = address.partition(":")
if not ip_address or not port_str:
raise CLIException("Address must be in 'ip:port' format.")
port = int(port_str)
if port <= 1023:
raise CLIException("Port must be above 1023")
except ValueError as e:
raise CLIException(f"Invalid address format: {e}")
else:
ip_address, port = get_ssh_details_with_retry(counter)

click.echo(f"SSH address is {ip_address}:{port}")
if not private_key_file:
private_key_file = select_private_key_from_possible_directories()
make_ssh_connection(ip_address, port, private_key_file)
4 changes: 4 additions & 0 deletions valohai_cli/settings/paths.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,10 @@ def get_settings_root_path() -> str: # pragma: no cover


def get_settings_file_name(name: str) -> str:
"""
Get the path to a settings file in the user's configuration directory.
name can be empty string to get the directory itself.
"""
path = os.environ.get("VALOHAI_CONFIG_DIR")
if path:
if not os.path.isdir(path):
Expand Down
24 changes: 23 additions & 1 deletion valohai_cli/utils/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,9 +3,10 @@
import random
import re
import string
import time
import unicodedata
import webbrowser
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Union
from typing import Any, Callable, Dict, Iterable, Iterator, Tuple, Type, TypeVar, Union

import click

Expand Down Expand Up @@ -154,3 +155,24 @@ def parse_environment_variable_strings(

def compact_dict(dct: dict) -> dict:
return {key: value for (key, value) in dct.items() if key and value}


T = TypeVar("T")


def call_with_retry(
func: Callable[[], T],
retries: int = 3,
delay_range: Tuple[int, int] = (1, 5),
retry_on_exception_classes: Tuple[Type[Exception], ...] = (Exception,),
) -> T:
for attempt in range(retries):
try:
return func()
except retry_on_exception_classes as e:
click.echo(f"{attempt + 1}/{retries} Retrying: {e}...")
if attempt + 1 == retries:
raise RuntimeError(f"Failed after {retries} attempts") from e
time.sleep(random.uniform(*delay_range))

raise RuntimeError("Shouldn't get here!")
Loading

0 comments on commit fb79331

Please sign in to comment.