diff --git a/tests/api2/test_smb_ioctl.py b/tests/api2/test_smb_ioctl.py new file mode 100644 index 000000000000..0dc4f0ad0ba8 --- /dev/null +++ b/tests/api2/test_smb_ioctl.py @@ -0,0 +1,141 @@ +import pytest +import random + +from dataclasses import asdict +from middlewared.test.integration.assets.account import user, group +from middlewared.test.integration.assets.pool import dataset + +from protocols import smb_connection +from protocols.SMB import ( + FsctlQueryFileRegionsRequest, + FileRegionInfo, + FileUsage, +) + +from samba import ntstatus +from samba import NTSTATUSError + +SHARE_NAME = 'ioctl_share' + + +@pytest.fixture(scope='module') +def setup_smb_tests(request): + with dataset('smbclient-testing', data={'share_type': 'SMB'}) as ds: + with user({ + 'username': 'smbuser', + 'full_name': 'smbuser', + 'group_create': True, + 'password': 'Abcd1234' + }) as u: + with smb_share(os.path.join('/mnt', ds), SHARE_NAME) as s: + try: + call('service.start', 'cifs') + yield {'dataset': ds, 'share': s, 'user': u} + finally: + call('service.stop', 'cifs') + + +def test__query_file_regions_normal(setup_smb_tests): + ds, share, smb_user = setup_smb_tests + with smb_connection( + share=SHARE_NAME, + username=smbuser['username'], + password='Abcd1234', + smb1=False + ) as c: + fd = c.create_file("file_regions_normal", "w") + buf = random.randbytes(1024) + + for offset in range(0, 128): + c.write(fd, offset=offset * 1024, data=buf) + + # First get with region omitted. This should return entire file + fsctl_request_null_region = FsctlQueryFileRegionsRequest(region=None) + fsctl_resp = c.fsctl(fd, fsctl_request_null_region) + + assert fsctl_resp.flags == 0 + assert fsctl_resp.total_region_entry_count == 1 + assert fsctl_resp.region_entry_count == 1 + assert fsctl_resp.reserved == 1 + assert fsctl_resp.region is not None + + assert fsctl_resp.region.offset == 0 + assert fsctl_resp.region.length == 128 * 1024 + assert fsctl_resp.region.desired_usage == FileUsage.VALID_CACHED_DATA + assert fsctl_resp.region.reserved == 0 + + # Take same region we retrieved from server and use with new request + fsctl_request_with_region = FsctlQueryFileRegionsRequest(region=fsctl_resp.region) + fsctl_resp2 = c.fsctl(fd, fsctl_request_with_region) + + assert asdict(fsctl_resp) == asdict(fsctl_resp2) + + +def test__query_file_regions_with_holes(setup_smb_tests): + ds, share, smb_user = setup_smb_tests + with smb_connection( + share=SHARE_NAME, + username=smbuser['username'], + password='Abcd1234', + smb1=False + ) as c: + fd = c.create_file("file_regions_normal", "w") + buf = random.randbytes(4096) + + # insert some holes in file + for offset in range(0, 130): + if offset % 2 == 0: + c.write(fd, offset=offset * 4096, data=buf) + + fsctl_request_null_region = FsctlQueryFileRegionsRequest(region=None) + fsctl_resp = c.fsctl(fd, fsctl_request_null_region) + + assert fsctl_resp.flags == 0 + assert fsctl_resp.total_region_entry_count == 1 + assert fsctl_resp.region_entry_count == 1 + assert fsctl_resp.reserved == 1 + assert fsctl_resp.region is not None + + assert fsctl_resp.region.offset == 0 + assert fsctl_resp.region.length == 128 * 4096 + assert fsctl_resp.region.desired_usage == FileUsage.VALID_CACHED_DATA + assert fsctl_resp.region.reserved == 0 + + # Take same region we retrieved from server and use with new request + fsctl_request_with_region = FsctlQueryFileRegionsRequest(region=fsctl_resp.region) + fsctl_resp2 = c.fsctl(fd, fsctl_request_with_region) + + assert asdict(fsctl_resp) == asdict(fsctl_resp2) + + +def test__query_file_regions_trailing_zeroes(setup_smb_tests): + """ + FileRegionInfo should contain Valid Data Length which is length in bytes of data + that has been written to the file in the specified region, from the beginning of + the region untile the last byte that has not been zeroed or uninitialized + """ + ds, share, smb_user = setup_smb_tests + with smb_connection( + share=SHARE_NAME, + username=smbuser['username'], + password='Abcd1234', + smb1=False + ) as c: + fd = c.create_file("file_regions_normal", "w") + buf = random.randbytes(4096) + + # insert a hole in file + c.write(fd, offset=0, data=buf) + c.write(fd, offset=8192, data=buf) + + # requesting entire file should give full length + fsctl_request_null_region = FsctlQueryFileRegionsRequest(region=None) + fsctl_resp = c.fsctl(fd, fsctl_request_null_region) + assert fsctl_resp.region.length == 12288 + + # requesting region that has hole at end of it should only give data length + limited_region = FileRegionInfo(offset=0, length=8192) + fsctl_request_limited_region = FsctlQueryFileRegionsRequest(region=limited_region) + fsctl_resp = c.fsctl(fd, fsctl_request_limited_region) + assert fsctl_resp.region.offset == 0 + assert fsctl_resp.region.length == 4096 diff --git a/tests/protocols/smb_proto.py b/tests/protocols/smb_proto.py index 7fe6fdd1566c..83192cd07189 100644 --- a/tests/protocols/smb_proto.py +++ b/tests/protocols/smb_proto.py @@ -1,6 +1,8 @@ import sys import enum +import struct import subprocess +from dataclasses import dataclass from functions import SRVTarget, get_host_ip from platform import system @@ -32,6 +34,60 @@ libsmb_has_rename = 'rename' in dir(libsmb.Conn) +class Fsctl(enum.IntEnum): + QUERY_FILE_REGIONS = 0x00090284 + + +class FileUsage(enum.IntEnum): + VALID_CACHED_DATA = 0x00000001 # NTFS + VALID_NONCACHED_DATA = 0x00000002 # REFS + + +@dataclass(frozen=True) +class FileRegionInfo: + """ MS-FSCC 2.3.56.1 """ + offset: int + length: int + desired_usage: FileUsage = FileUsage.VALID_CACHED_DATA + reserved: int = 0 # by protocol must be zero + + +@dataclass(frozen=True) +class FsctlQueryFileRegionsReply: + """ MS-FSCC 2.3.56 """ + flags: int # by protocol must be zero + total_region_entry_count: int + region_entry_count: int + reserved: int # by protocol must be zero + region: FileRegionInfo + + +@dataclass(frozen=True) +class FsctlQueryFileRegionsRequest: + """ MS-FSCC 2.3.55 """ + region_info: FileRegionInfo | None + + def __post_init__(self): + self.fsctl = Fsctl.QUERY_FILE_REGIONS + + def pack(self): + if self.region_info is None: + return b'' + + return struct.pack( + '