diff --git a/tests/framework/microvm.py b/tests/framework/microvm.py index eb755ffc5fa..22cead8725b 100644 --- a/tests/framework/microvm.py +++ b/tests/framework/microvm.py @@ -247,6 +247,8 @@ def __init__( self.mem_size_bytes = None self.cpu_template_name = None + self._ssh_connections = [] + self._pre_cmd = [] if numa_node: node_str = str(numa_node) @@ -282,6 +284,10 @@ def kill(self): for monitor in self.monitors: monitor.stop() + # Cleanup all SSH connections + for conn in self._ssh_connections: + conn.close() + # We start with vhost-user backends, # because if we stop Firecracker first, the backend will want # to exit as well and this will cause a race condition. @@ -1007,13 +1013,15 @@ def ssh_iface(self, iface_idx=0): """Return a cached SSH connection on a given interface id.""" guest_ip = list(self.iface.values())[iface_idx]["iface"].guest_ip self.ssh_key = Path(self.ssh_key) - return net_tools.SSHConnection( - netns=self.netns.id, + connection = net_tools.SSHConnection( + netns_=self.netns.id, ssh_key=self.ssh_key, user="root", host=guest_ip, on_error=self._dump_debug_information, ) + self._ssh_connections.append(connection) + return connection @property def ssh(self): diff --git a/tests/host_tools/network.py b/tests/host_tools/network.py index 7877b914d28..bfbb7e0fadb 100644 --- a/tests/host_tools/network.py +++ b/tests/host_tools/network.py @@ -5,13 +5,16 @@ import ipaddress import random import string -import subprocess from dataclasses import dataclass, field +from io import BytesIO from pathlib import Path -from tenacity import retry, retry_if_exception_type, stop_after_attempt, wait_fixed +import netns +from fabric import Connection +from tenacity import retry, stop_after_attempt, wait_fixed from framework import utils +from framework.utils import CommandReturn class SSHConnection: @@ -22,13 +25,14 @@ class SSHConnection: the hostname obtained from the MAC address, the username for logging into the image and the path of the ssh key. - This translates into an SSH connection as follows: - ssh -i ssh_key_path username@hostname + Uses the fabric library to establish a single connection once, and then + keep it alive for the lifetime of the microvm, to avoid spurious failures + due to reestablishing SSH connections for every single command sent. """ - def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None): + def __init__(self, netns_, ssh_key: Path, host, user, *, on_error=None): """Instantiate a SSH client and connect to a microVM.""" - self.netns = netns + self.netns = netns_ self.ssh_key = ssh_key # check that the key exists and the permissions are 0o400 # This saves a lot of debugging time. @@ -40,26 +44,23 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None): self._on_error = None - self.options = [ - "-o", - "LogLevel=ERROR", - "-o", - "ConnectTimeout=1", - "-o", - "StrictHostKeyChecking=no", - "-o", - "UserKnownHostsFile=/dev/null", - "-o", - "PreferredAuthentications=publickey", - "-i", - str(self.ssh_key), - ] + self._connection = Connection( + host, + user, + connect_timeout=1, + connect_kwargs={ + "key_filename": str(self.ssh_key), + "banner_timeout": 1, + "auth_timeout": 1, + }, + ) # _init_connection loops until it can connect to the guest # dumping debug state on every iteration is not useful or wanted, so # only dump it once if _all_ iterations fail. try: - self._init_connection() + with netns.NetNS(netns_): + self._init_connection() except Exception as exc: if on_error: on_error(exc) @@ -68,35 +69,15 @@ def __init__(self, netns, ssh_key: Path, host, user, *, on_error=None): self._on_error = on_error - @property - def user_host(self): - """remote address for in SSH format @""" - return f"{self.user}@{self.host}" - - def remote_path(self, path): - """Convert a path to remote""" - return f"{self.user_host}:{path}" - - def _scp(self, path1, path2, options): - """Copy files to/from the VM using scp.""" - self._exec(["scp", *options, path1, path2], check=True) - - def scp_put(self, local_path, remote_path, recursive=False): + def scp_put(self, local_path, remote_path): """Copy files to the VM using scp.""" - opts = self.options.copy() - if recursive: - opts.append("-r") - self._scp(local_path, self.remote_path(remote_path), opts) + self._connection.put(local_path, remote_path) - def scp_get(self, remote_path, local_path, recursive=False): + def scp_get(self, remote_path, local_path): """Copy files from the VM using scp.""" - opts = self.options.copy() - if recursive: - opts.append("-r") - self._scp(self.remote_path(remote_path), local_path, opts) + self._connection.get(remote_path, local_path) @retry( - retry=retry_if_exception_type(ChildProcessError), wait=wait_fixed(0.5), stop=stop_after_attempt(20), reraise=True, @@ -106,61 +87,43 @@ def _init_connection(self): Since we're connecting to a microVM we just started, we'll probably have to wait for it to boot up and start the SSH server. - We'll keep trying to execute a remote command that can't fail - (`/bin/true`), until we get a successful (0) exit code. + We'll keep trying to open the connection in a loop for 20 attempts with 0.5s + delay. Each connection attempt has a timeout of 1s. """ - self.check_output("true", timeout=100, debug=True) + self._connection.open() - def run(self, cmd_string, timeout=None, *, check=False, debug=False): + def run(self, cmd_string, timeout=None, *, check=False): """ Execute the command passed as a string in the ssh context. - - If `debug` is set, pass `-vvv` to `ssh`. Note that this will clobber stderr. """ - command = ["ssh", *self.options, self.user_host, cmd_string] - - if debug: - command.insert(1, "-vvv") - - return self._exec(command, timeout, check=check) - - def check_output(self, cmd_string, timeout=None, *, debug=False): - """Same as `run`, but raises an exception on non-zero return code of remote command""" - return self.run(cmd_string, timeout, check=True, debug=debug) - - def _exec(self, cmd, timeout=None, check=False): - """Private function that handles the ssh client invocation.""" - if self.netns is not None: - cmd = ["ip", "netns", "exec", self.netns] + cmd - try: - return utils.run_cmd(cmd, check=check, timeout=timeout) + # - warn=True means "do not raise exception on non-zero exit code, instead just log", e.g. + # it's the inverse of our "check" argument. + # - hide=True means "do not always log stdout/stderr" + # - in_stream=BytesIO(b"") is needed to immediately close stdin of the remote command + # without this, command that only exit after their stdin is closed would hang forever + # and this hang would bypass the pytest timeout. + result = self._connection.run( + cmd_string, + timeout=timeout, + warn=not check, + hide=True, + in_stream=BytesIO(b""), + ) except Exception as exc: if self._on_error: self._on_error(exc) raise + return CommandReturn(result.exited, result.stdout, result.stderr) - # pylint:disable=invalid-name - def Popen( - self, - cmd: str, - stdin=subprocess.DEVNULL, - stdout=subprocess.PIPE, - stderr=subprocess.PIPE, - **kwargs, - ) -> subprocess.Popen: - """Execute the command in the guest and return a Popen object. - - pop = uvm.ssh.Popen("while true; do echo $(date -Is) $RANDOM; sleep 1; done") - pop.stdout.read(16) - """ - cmd = ["ssh", *self.options, self.user_host, cmd] - if self.netns is not None: - cmd = ["ip", "netns", "exec", self.netns] + cmd - return subprocess.Popen( - cmd, stdin=stdin, stdout=stdout, stderr=stderr, **kwargs - ) + def check_output(self, cmd_string, timeout=None): + """Same as `run`, but raises an exception on non-zero return code of remote command""" + return self.run(cmd_string, timeout, check=True) + + def close(self): + """Closes this SSHConnection""" + self._connection.close() def mac_from_ip(ip_address): diff --git a/tests/integration_tests/functional/test_balloon.py b/tests/integration_tests/functional/test_balloon.py index 2c59fd2e814..d44f8b029df 100644 --- a/tests/integration_tests/functional/test_balloon.py +++ b/tests/integration_tests/functional/test_balloon.py @@ -4,9 +4,9 @@ import logging import time -from subprocess import TimeoutExpired import pytest +from invoke import CommandTimedOut from tenacity import retry, stop_after_attempt, wait_fixed from framework.utils import check_output, get_free_mem_ssh @@ -74,7 +74,7 @@ def make_guest_dirty_memory(ssh_connection, amount_mib=32): logger.error("while running: %s", cmd) logger.error("stdout: %s", stdout) logger.error("stderr: %s", stderr) - except TimeoutExpired: + except CommandTimedOut: # It's ok if this expires. Sometimes the SSH connection # gets killed by the OOM killer *after* the fillmem program # started. As a result, we can ignore timeouts here. diff --git a/tests/integration_tests/functional/test_pause_resume.py b/tests/integration_tests/functional/test_pause_resume.py index 3d0ac124c11..de22b1cb9d8 100644 --- a/tests/integration_tests/functional/test_pause_resume.py +++ b/tests/integration_tests/functional/test_pause_resume.py @@ -51,8 +51,9 @@ def test_pause_resume(uvm_nano): # Flush and reset metrics as they contain pre-pause data. microvm.flush_metrics() - # Verify guest is no longer active. - with pytest.raises(ChildProcessError): + # Verify guest is no longer active (by observing a failure to reconnect) + with pytest.raises(TimeoutError): + microvm.ssh.close() microvm.ssh.check_output("true") # Verify emulation was indeed paused and no events from either @@ -60,7 +61,8 @@ def test_pause_resume(uvm_nano): verify_net_emulation_paused(microvm.flush_metrics()) # Verify guest is no longer active. - with pytest.raises(ChildProcessError): + with pytest.raises(TimeoutError): + microvm.ssh.close() microvm.ssh.check_output("true") # Pausing the microVM when it is already `Paused` is allowed