diff --git a/src/vmm/src/devices/virtio/iovec.rs b/src/vmm/src/devices/virtio/iovec.rs index 0fe08a91659..1267412b3e2 100644 --- a/src/vmm/src/devices/virtio/iovec.rs +++ b/src/vmm/src/devices/virtio/iovec.rs @@ -1,8 +1,6 @@ // Copyright 2020 Amazon.com, Inc. or its affiliates. All Rights Reserved. // SPDX-License-Identifier: Apache-2.0 -use std::marker::PhantomData; - use libc::{c_void, iovec, size_t}; use utils::vm_memory::{Bitmap, GuestMemory, GuestMemoryMmap}; @@ -21,93 +19,6 @@ pub enum IoVecError { GuestMemory(#[from] utils::vm_memory::GuestMemoryError), } -// Describes a sub-region of a buffer described as a slice of `iovec` structs. -#[derive(Debug)] -struct IovVecSubregion<'a> { - // An iterator of the iovec items we are iterating - iovecs: Vec, - // Lifetime of the origin buffer - phantom: PhantomData<&'a iovec>, -} - -impl<'a> IovVecSubregion<'a> { - // Create a new `IovVecSubregion` - // - // Given an initial buffer (described as a collecetion of `iovec` structs) and a sub-region - // inside it, in the form of [offset; size] create a "sub-region" inside it, if the sub-region - // does not fall outside the original buffer, i.e. `offset` is not after the end of the original - // buffer. - // - // # Arguments - // - // * `iovecs` - A slice of `iovec` structures describing the buffer. - // * `len` - The total length of the buffer, i.e. the sum of the lengths of all `iovec` - // structs. - // * `offset` - The offset inside the buffer at which the sub-region starts. - // * `size` - The size of the sub-region - // - // # Returns - // - // If the sub-region is within the range of the buffer, i.e. the offset is not past the end of - // the buffer, it will return an `IovVecSubregion`. - fn new(iovecs: &'a [iovec], len: usize, mut offset: usize, mut size: usize) -> Option { - // Out-of-bounds sub-region - if offset >= len { - return None; - } - - // Empty sub-region - if size == 0 { - return None; - } - - let sub_regions = iovecs - .iter() - .filter_map(|iov| { - // If offset is bigger than the length of the current `iovec`, this `iovec` is not - // part of the sub-range - if offset >= iov.iov_len { - offset -= iov.iov_len; - return None; - } - - // No more `iovec` structs needed - if size == 0 { - return None; - } - - // SAFETY: This is safe because we chacked that `offset < iov.iov_len`. - let iov_base = unsafe { iov.iov_base.add(offset) }; - let iov_len = std::cmp::min(iov.iov_len - offset, size); - offset = 0; - size -= iov_len; - - Some(iovec { iov_base, iov_len }) - }) - .collect(); - - Some(Self { - iovecs: sub_regions, - phantom: PhantomData, - }) - } - - #[cfg(test)] - fn len(&self) -> usize { - self.iovecs.iter().fold(0, |acc, iov| acc + iov.iov_len) - } -} - -impl<'a> IntoIterator for IovVecSubregion<'a> { - type Item = iovec; - - type IntoIter = std::vec::IntoIter; - - fn into_iter(self) -> Self::IntoIter { - self.iovecs.into_iter() - } -} - /// This is essentially a wrapper of a `Vec` which can be passed to `libc::writev`. /// /// It describes a buffer passed to us by the guest that is scattered across multiple @@ -171,9 +82,51 @@ impl IoVecBuffer { self.vecs.len() } - /// Get a sub-region of the buffer - fn sub_region(&self, offset: usize, size: usize) -> Option { - IovVecSubregion::new(&self.vecs, self.len, offset, size) + // Read data from a subregion of the IoVecBuffer. + // + // This will read data into `buf` from a subregion that starts at `offset` and it is + // `buf.len()` long in the `buf` slice. Here we assume that [`offset`, `offset` + `buf.len()`] + // is within range, so it is the responsibility of the caller function to perform the necessary + // checks. + fn read_subregion(&self, buf: &mut [u8], mut offset: usize) -> usize { + debug_assert!(offset + buf.len() <= self.len()); + let mut bytes = 0; + let mut buf_ptr = buf.as_mut_ptr(); + for iov in self.vecs.iter() { + // We filled up all of `buf`, we 're done. + if bytes == buf.len() { + break; + } + + // While `offset` is past the end of an `iov`, this `iov` is not + // part of the subregion. + if offset >= iov.iov_len { + offset -= iov.iov_len; + continue; + } + + // SAFETY: This is safe because we checked that `offset < iov.iov_len`. + let src = unsafe { iov.iov_base.add(offset).cast::() }; + let len = std::cmp::min(iov.iov_len - offset, buf.len() - bytes); + offset = 0; + + // SAFETY: + // The call to `copy_nonoverlapping` is safe because: + // 1. `iov` describes a valid range in guest memory. The constructor of `IoVecBuffer` + // has checked that. + // 2. `buf_ptr` is a pointer inside the `buf` slice. We only get this pointer using + // safe methods. + // 3. Both pointers point to `u8` elements, so they're always aligned. + // 4. The memory regions these pointers point to are not overlapping. `src` points to + // guest physical memory and `buf_ptr` to Firecracker-owned memory. + unsafe { + std::ptr::copy_nonoverlapping(src, buf_ptr, len); + } + buf_ptr = buf[len..].as_mut_ptr(); + bytes += len; + } + + bytes } /// Reads a number of bytes from the `IoVecBuffer` starting at a given offset. @@ -185,32 +138,14 @@ impl IoVecBuffer { /// /// The number of bytes read (if any) pub fn read_at(&self, buf: &mut [u8], offset: usize) -> Option { - self.sub_region(offset, buf.len()).map(|sub_region| { - let mut bytes = 0; - let mut buf_ptr = buf.as_mut_ptr(); - - sub_region.into_iter().for_each(|iov| { - let src = iov.iov_base.cast::(); - // SAFETY: - // The call to `copy_nonoverlapping` is safe because: - // 1. `iov` is a an iovec describing a segment inside `Self`. `IoVecSubregion` has - // performed all necessary bound checks. - // 2. `buf_ptr` is a pointer inside the memory of `buf` - // 3. Both pointers point to `u8` elements, so they're always aligned. - // 4. The memory regions these pointers point to are not overlapping. `src` points - // to guest physical memory and `buf_ptr` to Firecracker-owned memory. - // - // `buf_ptr.add()` is safe because `IoVecSubregion` gives us `iovec` structs that - // their size adds up to `buf.len()`. - unsafe { - std::ptr::copy_nonoverlapping(src, buf_ptr, iov.iov_len); - buf_ptr = buf_ptr.add(iov.iov_len); - } - bytes += iov.iov_len; - }); - - bytes - }) + if offset < self.len() { + // Make sure we only read up to the end of the `IoVecBuffer`. + let size = buf.len().min(self.len() - offset); + Some(self.read_subregion(&mut buf[..size], offset)) + } else { + // If `offset` is past size, there's nothing to read. + None + } } } @@ -267,9 +202,51 @@ impl IoVecBufferMut { self.len } - /// Get a sub-region of the buffer - fn sub_region(&self, offset: usize, size: usize) -> Option { - IovVecSubregion::new(&self.vecs, self.len, offset, size) + // Write data into a subregion of the IoVecBuffer. + // + // This will write data from `buf` into a subregion that starts at `offset` and it is + // `buf.len()` long in the `buf` slice. Here we assume that [`offset`, `offset` + `buf.len()`] + // is within range, so it is the responsibility of the caller function to perform the necessary + // checks. + fn write_subregion(&self, buf: &[u8], mut offset: usize) -> usize { + debug_assert!(offset + buf.len() <= self.len()); + let mut bytes = 0; + let mut buf_ptr = buf.as_ptr(); + for iov in self.vecs.iter() { + // We read all of `buf`, we 're done. + if bytes == buf.len() { + break; + } + + // While `offset` is past the end of an `iov`, this `iov` is not + // part of the subregion. + if offset >= iov.iov_len { + offset -= iov.iov_len; + continue; + } + + // SAFETY: This is safe because we checked that `offset < iov.iov_len`. + let dst = unsafe { iov.iov_base.add(offset).cast::() }; + let len = std::cmp::min(iov.iov_len - offset, buf.len() - bytes); + offset = 0; + + // SAFETY: + // The call to `copy_nonoverlapping` is safe because: + // 1. `iov` describes a valid range in guest memory. The constructor of + // `IoVecBufferMut` has checked that. + // 2. `buf_ptr` is a pointer inside the `buf` slice. We only get this pointer using + // safe methods. + // 3. Both pointers point to `u8` elements, so they're always aligned. + // 4. The memory regions these pointers point to are not overlapping. `dst` points to + // guest physical memory and `buf_ptr` to Firecracker-owned memory. + unsafe { + std::ptr::copy_nonoverlapping(buf_ptr, dst, len); + } + buf_ptr = buf[len..].as_ptr(); + bytes += len; + } + + bytes } /// Writes a number of bytes into the `IoVecBufferMut` starting at a given offset. @@ -282,32 +259,14 @@ impl IoVecBufferMut { /// /// The number of bytes written (if any) pub fn write_at(&mut self, buf: &[u8], offset: usize) -> Option { - self.sub_region(offset, buf.len()).map(|sub_region| { - let mut bytes = 0; - let mut buf_ptr = buf.as_ptr(); - - sub_region.into_iter().for_each(|iov| { - let dst = iov.iov_base.cast::(); - // SAFETY: - // The call to `copy_nonoverlapping` is safe because: - // 1. `iov` is a an iovec describing a segment inside `Self`. `IoVecSubregion` has - // performed all necessary bound checks. - // 2. `buf_ptr` is a pointer inside the memory of `buf` - // 3. Both pointers point to `u8` elements, so they're always aligned. - // 4. The memory regions these pointers point to are not overlapping. `src` points - // to guest physical memory and `buf_ptr` to Firecracker-owned memory. - // - // `buf_ptr.add()` is safe because `IoVecSubregion` gives us `iovec` structs that - // their size adds up to `buf.len()`. - unsafe { - std::ptr::copy_nonoverlapping(buf_ptr, dst, iov.iov_len); - buf_ptr = buf_ptr.add(iov.iov_len); - } - bytes += iov.iov_len; - }); - - bytes - }) + if offset < self.len() { + // Make sure we only write up to the end of the `IoVecBufferMut`. + let size = buf.len().min(self.len() - offset); + Some(self.write_subregion(&buf[..size], offset)) + } else { + // We cannot write past the end of the `IoVecBufferMut`. + None + } } } @@ -564,53 +523,6 @@ mod tests { vq.dtable[2].check_data(&test_vec3); vq.dtable[3].check_data(&test_vec4); } - - #[test] - fn test_sub_range() { - let mem = default_mem(); - let (mut q, _) = read_only_chain(&mem); - let head = q.pop(&mem).unwrap(); - - // This is a descriptor chain with 4 elements 64 bytes long each, - // so 256 bytes long. - let iovec = IoVecBuffer::from_descriptor_chain(&mem, head).unwrap(); - - // Sub-ranges past the end of the buffer are invalid - assert!(iovec.sub_region(iovec.len(), 256).is_none()); - - // Getting an empty sub-range is invalid - assert!(iovec.sub_region(0, 0).is_none()); - - // Let's take the whole region - let sub = iovec.sub_region(0, iovec.len()).unwrap(); - assert_eq!(iovec.len(), sub.len()); - - // Let's take a valid sub-region that ends past the the end of the buffer - let sub = iovec.sub_region(128, 256).unwrap(); - assert_eq!(128, sub.len()); - - // Getting a sub-region that falls in a single iovec of the buffer - for i in 0..4 { - let sub = iovec.sub_region(10 + i * 64, 50).unwrap(); - assert_eq!(50, sub.len()); - assert_eq!(1, sub.iovecs.len()); - // SAFETY: All `iovecs` are 64 bytes long - assert_eq!(sub.iovecs[0].iov_base, unsafe { - iovec.vecs[i].iov_base.add(10) - }); - } - - // Get a sub-region that traverses more than one iovec of the buffer - let sub = iovec.sub_region(10, 100).unwrap(); - assert_eq!(100, sub.len()); - assert_eq!(2, sub.iovecs.len()); - // SAFETY: all `iovecs` are 64 bytes long - assert_eq!(sub.iovecs[0].iov_base, unsafe { - iovec.vecs[0].iov_base.add(10) - }); - - assert_eq!(sub.iovecs[1].iov_base, iovec.vecs[1].iov_base); - } } #[cfg(kani)]