simplifies PinnedVec implementation (#16382)

This commit is contained in:
behzad nouri
2021-04-08 10:40:30 +00:00
committed by GitHub
parent 3645092a52
commit 0e262aab3d

View File

@ -11,58 +11,52 @@ use rand::seq::SliceRandom;
use rand::Rng; use rand::Rng;
use rayon::prelude::*; use rayon::prelude::*;
use std::ops::{Index, IndexMut}; use std::ops::{Index, IndexMut};
use std::slice::SliceIndex; use std::slice::{Iter, IterMut, SliceIndex};
use std::sync::{Arc, Weak}; use std::sync::Weak;
use std::os::raw::c_int; use std::os::raw::c_int;
const CUDA_SUCCESS: c_int = 0; const CUDA_SUCCESS: c_int = 0;
pub fn pin<T>(_mem: &mut Vec<T>) { fn pin<T>(_mem: &mut Vec<T>) {
if let Some(api) = perf_libs::api() { if let Some(api) = perf_libs::api() {
unsafe { use std::ffi::c_void;
use core::ffi::c_void;
use std::mem::size_of; use std::mem::size_of;
let err = (api.cuda_host_register)( let ptr = _mem.as_mut_ptr();
_mem.as_mut_ptr() as *mut c_void, let size = _mem.capacity().saturating_mul(size_of::<T>());
_mem.capacity().saturating_mul(size_of::<T>()), let err = unsafe {
0, (api.cuda_host_register)(ptr as *mut c_void, size, /*flags=*/ 0)
); };
if err != CUDA_SUCCESS { if err != CUDA_SUCCESS {
panic!( panic!(
"cudaHostRegister error: {} ptr: {:?} bytes: {}", "cudaHostRegister error: {} ptr: {:?} bytes: {}",
err, err, ptr, size,
_mem.as_ptr(),
_mem.capacity().saturating_mul(size_of::<T>()),
); );
} }
} }
} }
}
pub fn unpin<T>(_mem: *mut T) { fn unpin<T>(_mem: *mut T) {
if let Some(api) = perf_libs::api() { if let Some(api) = perf_libs::api() {
unsafe { use std::ffi::c_void;
use core::ffi::c_void;
let err = (api.cuda_host_unregister)(_mem as *mut c_void); let err = unsafe { (api.cuda_host_unregister)(_mem as *mut c_void) };
if err != CUDA_SUCCESS { if err != CUDA_SUCCESS {
panic!("cudaHostUnregister returned: {} ptr: {:?}", err, _mem); panic!("cudaHostUnregister returned: {} ptr: {:?}", err, _mem);
} }
} }
} }
}
// A vector wrapper where the underlying memory can be // A vector wrapper where the underlying memory can be
// page-pinned. Controlled by flags in case user only wants // page-pinned. Controlled by flags in case user only wants
// to pin in certain circumstances. // to pin in certain circumstances.
#[derive(Debug)] #[derive(Debug, Default)]
pub struct PinnedVec<T: Default + Clone + Sized> { pub struct PinnedVec<T: Default + Clone + Sized> {
x: Vec<T>, x: Vec<T>,
pinned: bool, pinned: bool,
pinnable: bool, pinnable: bool,
recycler: Option<Weak<RecyclerX<PinnedVec<T>>>>, recycler: Weak<RecyclerX<PinnedVec<T>>>,
} }
impl<T: Default + Clone + Sized> Reset for PinnedVec<T> { impl<T: Default + Clone + Sized> Reset for PinnedVec<T> {
@ -74,21 +68,10 @@ impl<T: Default + Clone + Sized> Reset for PinnedVec<T> {
self.resize(size_hint, T::default()); self.resize(size_hint, T::default());
} }
fn set_recycler(&mut self, recycler: Weak<RecyclerX<Self>>) { fn set_recycler(&mut self, recycler: Weak<RecyclerX<Self>>) {
self.recycler = Some(recycler); self.recycler = recycler;
} }
fn unset_recycler(&mut self) { fn unset_recycler(&mut self) {
self.recycler = None; self.recycler = Weak::default();
}
}
impl<T: Clone + Default + Sized> Default for PinnedVec<T> {
fn default() -> Self {
Self {
x: Vec::new(),
pinned: false,
pinnable: false,
recycler: None,
}
} }
} }
@ -99,29 +82,10 @@ impl<T: Clone + Default + Sized> From<PinnedVec<T>> for Vec<T> {
pinned_vec.pinned = false; pinned_vec.pinned = false;
} }
pinned_vec.pinnable = false; pinned_vec.pinnable = false;
pinned_vec.recycler = None; pinned_vec.recycler = Weak::default();
std::mem::take(&mut pinned_vec.x) std::mem::take(&mut pinned_vec.x)
} }
} }
pub struct PinnedIter<'a, T>(std::slice::Iter<'a, T>);
pub struct PinnedIterMut<'a, T>(std::slice::IterMut<'a, T>);
impl<'a, T: Clone + Default + Sized> Iterator for PinnedIter<'a, T> {
type Item = &'a T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next()
}
}
impl<'a, T: Clone + Default + Sized> Iterator for PinnedIterMut<'a, T> {
type Item = &'a mut T;
fn next(&mut self) -> Option<Self::Item> {
self.0.next()
}
}
impl<T: Clone + Default + Sized> IntoIterator for PinnedVec<T> { impl<T: Clone + Default + Sized> IntoIterator for PinnedVec<T> {
type Item = T; type Item = T;
@ -132,21 +96,12 @@ impl<T: Clone + Default + Sized> IntoIterator for PinnedVec<T> {
} }
} }
impl<'a, T: Clone + Default + Sized> IntoIterator for &'a mut PinnedVec<T> {
type Item = &'a T;
type IntoIter = PinnedIter<'a, T>;
fn into_iter(self) -> Self::IntoIter {
PinnedIter(self.x.iter())
}
}
impl<'a, T: Clone + Default + Sized> IntoIterator for &'a PinnedVec<T> { impl<'a, T: Clone + Default + Sized> IntoIterator for &'a PinnedVec<T> {
type Item = &'a T; type Item = &'a T;
type IntoIter = PinnedIter<'a, T>; type IntoIter = Iter<'a, T>;
fn into_iter(self) -> Self::IntoIter { fn into_iter(self) -> Self::IntoIter {
PinnedIter(self.x.iter()) self.x.iter()
} }
} }
@ -167,12 +122,12 @@ impl<T: Clone + Default + Sized, I: SliceIndex<[T]>> IndexMut<I> for PinnedVec<T
} }
impl<T: Clone + Default + Sized> PinnedVec<T> { impl<T: Clone + Default + Sized> PinnedVec<T> {
pub fn iter(&self) -> PinnedIter<T> { pub fn iter(&self) -> Iter<'_, T> {
PinnedIter(self.x.iter()) self.x.iter()
} }
pub fn iter_mut(&mut self) -> PinnedIterMut<T> { pub fn iter_mut(&mut self) -> IterMut<'_, T> {
PinnedIterMut(self.x.iter_mut()) self.x.iter_mut()
} }
pub fn capacity(&self) -> usize { pub fn capacity(&self) -> usize {
@ -237,18 +192,12 @@ impl<T: Clone + Default + Sized> PinnedVec<T> {
x: source, x: source,
pinned: false, pinned: false,
pinnable: false, pinnable: false,
recycler: None, recycler: Weak::default(),
} }
} }
pub fn with_capacity(capacity: usize) -> Self { pub fn with_capacity(capacity: usize) -> Self {
let x = Vec::with_capacity(capacity); Self::from_vec(Vec::with_capacity(capacity))
Self {
x,
pinned: false,
pinnable: false,
recycler: None,
}
} }
pub fn is_empty(&self) -> bool { pub fn is_empty(&self) -> bool {
@ -332,10 +281,6 @@ impl<T: Clone + Default + Sized> PinnedVec<T> {
self.pinned = true; self.pinned = true;
} }
} }
fn recycler_ref(&self) -> Option<Arc<RecyclerX<Self>>> {
let r = self.recycler.as_ref()?;
r.upgrade()
}
} }
impl<T: Clone + Default + Sized> Clone for PinnedVec<T> { impl<T: Clone + Default + Sized> Clone for PinnedVec<T> {
@ -364,12 +309,9 @@ impl<T: Clone + Default + Sized> Clone for PinnedVec<T> {
impl<T: Sized + Default + Clone> Drop for PinnedVec<T> { impl<T: Sized + Default + Clone> Drop for PinnedVec<T> {
fn drop(&mut self) { fn drop(&mut self) {
if let Some(strong) = self.recycler_ref() { if let Some(recycler) = self.recycler.upgrade() {
let mut vec = PinnedVec::default(); recycler.recycle(std::mem::take(self));
std::mem::swap(&mut vec, self); } else if self.pinned {
strong.recycle(vec);
}
if self.pinned {
unpin(self.x.as_mut_ptr()); unpin(self.x.as_mut_ptr());
} }
} }