Skip to content

Commit

Permalink
Better handle situations where the parent directory doesn't exist dur…
Browse files Browse the repository at this point in the history
…ing sftp (#175)

This change adds some additional logic to not only handle when parent directories
don't exist, but also constructs the destination path when only a directory is given

Fixes #169
  • Loading branch information
JacobCallahan authored Dec 21, 2022
1 parent fca5b61 commit 44f3a22
Show file tree
Hide file tree
Showing 3 changed files with 27 additions and 10 deletions.
30 changes: 22 additions & 8 deletions broker/session.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,6 +81,8 @@ def sftp_read(self, source, destination=None):
"""read a remote file into a local destination"""
if not destination:
destination = source
elif destination.endswith("/"):
destination = destination + Path(source).name
# create the destination path if it doesn't exist
destination = Path(destination)
destination.parent.mkdir(parents=True, exist_ok=True)
Expand All @@ -94,30 +96,38 @@ def sftp_read(self, source, destination=None):
for size, data in remote:
local.write(data)

def sftp_write(self, source, destination=None):
def sftp_write(self, source, destination=None, ensure_dir=True):
"""sftp write a local file to a remote destination"""
if not destination:
destination = source
elif destination.endswith("/"):
destination = destination + Path(source).name
data = Path(source).read_bytes()
if ensure_dir:
self.run(f"mkdir -p {Path(destination).absolute().parent}")
sftp = self.session.sftp_init()
with sftp.open(destination, FILE_FLAGS, SFTP_MODE) as remote:
remote.write(data)

def remote_copy(self, source, dest_host):
def remote_copy(self, source, dest_host, ensure_dir=True):
"""Copy a file from this host to another"""
sftp_down = self.session.sftp_init()
sftp_up = dest_host.session.session.sftp_init()
if ensure_dir:
dest_host.run(f"mkdir -p {Path(source).absolute().parent}")
with sftp_down.open(
source, ssh2_sftp.LIBSSH2_FXF_READ, ssh2_sftp.LIBSSH2_SFTP_S_IRUSR
) as download:
with sftp_up.open(source, FILE_FLAGS, SFTP_MODE) as upload:
for size, data in download:
upload.write(data)

def scp_write(self, source, destination=None):
def scp_write(self, source, destination=None, ensure_dir=True):
"""scp write a local file to a remote destination"""
if not destination:
destination = source
elif destination.endswith("/"):
destination = destination + Path(source).name
fileinfo = os.stat(source)
chan = self.session.scp_send64(
destination,
Expand All @@ -126,6 +136,8 @@ def scp_write(self, source, destination=None):
fileinfo.st_mtime,
fileinfo.st_atime,
)
if ensure_dir:
self.run(f"mkdir -p {Path(destination).absolute().parent}")
with open(source, "rb") as local:
for data in local:
chan.write(data)
Expand Down Expand Up @@ -221,7 +233,7 @@ def disconnect(self):
"""Needed for simple compatability with Session"""
pass

def sftp_write(self, source, destination=None):
def sftp_write(self, source, destination=None, ensure_dir=True):
"""Add one of more files to the container"""
# ensure source is a list of Path objects
if not isinstance(source, list):
Expand All @@ -232,15 +244,17 @@ def sftp_write(self, source, destination=None):
for src in source:
if not Path(src).exists():
raise FileNotFoundError(src)
destination = Path(destination) or source[0].parent
destination = destination or f"{source[0].parent}/"
# Files need to be added to a tarfile
with helpers.temporary_tar(source) as tar:
logger.debug(
f"{self._cont_inst.hostname} adding file(s) {source} to {destination}"
)
# if the destination is a file, create the parent path
if destination.is_file():
self.execute(f"mkdir -p {destination.parent}")
if ensure_dir:
if destination.endswith("/"):
self.run(f"mkdir -m 666 -p {destination}")
else:
self.run(f"mkdir -m 666 -p {Path(destination).parent}")
self._cont_inst._cont_inst.put_archive(str(destination), tar.read_bytes())

def sftp_read(self, source, destination=None):
Expand Down
4 changes: 2 additions & 2 deletions tests/functional/test_containers.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,6 @@ def test_container_e2e_mp():
res = c_host.execute("hostname")
assert res.stdout.strip() == c_host.hostname
# Test that a file can be uploaded to the container
c_host.session.sftp_write("broker_settings.yaml", "/root")
res = c_host.execute("ls")
c_host.session.sftp_write("broker_settings.yaml", "/tmp/fake/")
res = c_host.execute("ls /tmp/fake")
assert "broker_settings.yaml" in res.stdout
3 changes: 3 additions & 0 deletions tests/functional/test_satlab.py
Original file line number Diff line number Diff line change
Expand Up @@ -68,3 +68,6 @@ def test_tower_host():
with Broker(workflow="deploy-base-rhel") as r_host:
res = r_host.execute("hostname")
assert res.stdout.strip() == r_host.hostname
r_host.session.sftp_write("broker_settings.yaml", "/tmp/fake/")
res = r_host.execute("ls /tmp/fake")
assert "broker_settings.yaml" in res.stdout

0 comments on commit 44f3a22

Please sign in to comment.