rusticl: Use the from_raw_parts wrappers

Deduplicates some safety checks and ensures we didn't forget one.

Reviewed-by: Karol Herbst <kherbst@redhat.com>
Part-of: <https://gitlab.freedesktop.org/mesa/mesa/-/merge_requests/26157>
This commit is contained in:
LingMan 2023-11-13 05:29:47 +01:00 committed by Marge Bot
parent 471d89c4fd
commit 76996e2a94
2 changed files with 14 additions and 53 deletions

View file

@ -2466,11 +2466,6 @@ fn enqueue_svm_memcpy_impl(
return Err(CL_INVALID_OPERATION);
}
// CL_INVALID_VALUE if dst_ptr or src_ptr is NULL.
if dst_ptr.is_null() || src_ptr.is_null() {
return Err(CL_INVALID_VALUE);
}
// CL_MEM_COPY_OVERLAP if the values specified for dst_ptr, src_ptr and size result in an
// overlapping copy.
let dst_ptr_addr = dst_ptr as usize;
@ -2481,14 +2476,6 @@ fn enqueue_svm_memcpy_impl(
return Err(CL_MEM_COPY_OVERLAP);
}
// Not technically guaranteed by the OpenCL spec, but required by `from_raw_parts` below.
if isize::try_from(size).is_err()
|| src_ptr_addr.checked_add(size).is_none()
|| dst_ptr_addr.checked_add(size).is_none()
{
return Err(CL_INVALID_VALUE);
}
// CAST: We have no idea about the type or initialization status of these bytes.
// MaybeUninit<u8> is the safe bet.
let src_ptr = src_ptr.cast::<MaybeUninit<u8>>();
@ -2497,16 +2484,14 @@ fn enqueue_svm_memcpy_impl(
// MaybeUninit<u8> is the safe bet.
let dst_ptr = dst_ptr.cast::<MaybeUninit<u8>>();
// SAFETY: We've checked above that the pointer is not NULL, that the size isn't excessive, and
// that addr + size doesn't overflow. It is up to the application to ensure the memory is valid
// to read for `size` bytes and that it doesn't modify it until the command has completed.
let src = unsafe { slice::from_raw_parts(src_ptr, size) };
// SAFETY: It is up to the application to ensure the memory is valid to read for `size` bytes
// and that it doesn't modify it until the command has completed.
let src = unsafe { cl_slice::from_raw_parts(src_ptr, size)? };
// SAFETY: We've checked above that the pointer is not NULL, that the size isn't excessive, and
// that addr + size doesn't overflow. We've also ensured there's no aliasing between src and
// dst. It is up to the application to ensure the memory is valid to read and write for `size`
// bytes and that it doesn't modify or read from it until the command has completed.
let dst = unsafe { slice::from_raw_parts_mut(dst_ptr, size) };
// SAFETY: We've ensured there's no aliasing between src and dst. It is up to the application
// to ensure the memory is valid to read and write for `size` bytes and that it doesn't modify
// or read from it until the command has completed.
let dst = unsafe { cl_slice::from_raw_parts_mut(dst_ptr, size)? };
create_and_queue(
q,
@ -2582,23 +2567,12 @@ fn enqueue_svm_mem_fill_impl(
) -> CLResult<()> {
let q = command_queue.get_arc()?;
let evs = event_list_from_cl(&q, num_events_in_wait_list, event_wait_list)?;
let svm_ptr_addr = svm_ptr as usize;
// CL_INVALID_OPERATION if the device associated with command queue does not support SVM.
if !q.device.svm_supported() {
return Err(CL_INVALID_OPERATION);
}
// CL_INVALID_VALUE if svm_ptr is NULL.
if svm_ptr.is_null() {
return Err(CL_INVALID_VALUE);
}
// CL_INVALID_VALUE if svm_ptr is not aligned to pattern_size bytes.
if svm_ptr_addr & (pattern_size - 1) != 0 {
return Err(CL_INVALID_VALUE);
}
// CL_INVALID_VALUE if pattern is NULL [...]
if pattern.is_null() {
return Err(CL_INVALID_VALUE);
@ -2609,11 +2583,6 @@ fn enqueue_svm_mem_fill_impl(
return Err(CL_INVALID_VALUE);
}
// Not technically guaranteed by the OpenCL spec, but required by `from_raw_parts_mut` below.
if isize::try_from(size).is_err() || svm_ptr_addr.checked_add(size).is_none() {
return Err(CL_INVALID_VALUE);
}
// The provided `$bytesize` must equal `pattern_size`.
macro_rules! generate_fill_closure {
($bytesize:literal) => {{
@ -2672,16 +2641,12 @@ fn enqueue_svm_mem_fill_impl(
// the same layout as `Pattern`.
let svm_ptr = svm_ptr.cast::<MaybeUninit<Pattern>>();
// SAFETY: We've checked that `svm_ptr` is not NULL above. It is otherwise the calling
// application's responsibility to ensure that it is valid for reads and writes up to
// `size` bytes.
// SAFETY: It is the calling application's responsibility to ensure that `svm_ptr` is
// valid for reads and writes up to `size` bytes.
// Since `pattern_size == mem::size_of::<Pattern>()` and `MaybeUninit<Pattern>` has the
// same layout as `Pattern`, we know that
// `size / pattern_size * mem::size_of<MaybeUninit<Pattern>>` equals `size`.
//
// We've also checked that `svm_ptr` has an alignment of `pattern_size` which fulfills
// `Pattern`'s requirement.
//
// Since we're creating a `&[MaybeUninit<Pattern>]` the initialization status does not
// matter.
//
@ -2689,10 +2654,7 @@ fn enqueue_svm_mem_fill_impl(
// particular, since we've made a copy of `pattern`, it doesn't matter if the memory
// region referenced by `pattern` aliases the one referenced by this slice. It is up to
// the application not to access it at all until this command has been completed.
//
// We've checked that `size` does not exceed `isize::MAX` and that `svm_ptr + size`
// does not overflow above.
let svm_slice = unsafe { slice::from_raw_parts_mut(svm_ptr, size / pattern_size) };
let svm_slice = unsafe { cl_slice::from_raw_parts_mut(svm_ptr, size / pattern_size)? };
Box::new(move |_, _| {
for x in svm_slice {
@ -2935,20 +2897,20 @@ fn enqueue_svm_migrate_mem(
return Err(CL_INVALID_OPERATION);
}
// CL_INVALID_VALUE if num_svm_pointers is zero or svm_pointers is NULL.
if num_svm_pointers == 0 || svm_pointers.is_null() {
// CL_INVALID_VALUE if num_svm_pointers is zero
if num_svm_pointers == 0 {
return Err(CL_INVALID_VALUE);
}
let num_svm_pointers = num_svm_pointers as usize;
// SAFETY: Just hoping the application is alright.
let mut svm_pointers =
unsafe { slice::from_raw_parts(svm_pointers, num_svm_pointers) }.to_owned();
unsafe { cl_slice::from_raw_parts(svm_pointers, num_svm_pointers)? }.to_owned();
// if sizes is NULL, every allocation containing the pointers need to be migrated
let mut sizes = if sizes.is_null() {
vec![0; num_svm_pointers]
} else {
unsafe { slice::from_raw_parts(sizes, num_svm_pointers) }.to_owned()
unsafe { cl_slice::from_raw_parts(sizes, num_svm_pointers)? }.to_owned()
};
// CL_INVALID_VALUE if sizes[i] is non-zero range [svm_pointers[i], svm_pointers[i]+sizes[i]) is

View file

@ -366,7 +366,6 @@ pub fn check_copy_overlap(
true
}
#[allow(dead_code)]
pub mod cl_slice {
use crate::api::util::CLResult;
use mesa_rust_util::ptr::addr;