diff options
author | Stefan Kreutz <mail@skreutz.com> | 2022-12-18 21:30:46 +0100 |
---|---|---|
committer | Stefan Kreutz <mail@skreutz.com> | 2022-12-18 21:30:46 +0100 |
commit | f2181e1c8d55d4da0e298685f7805ae0c17cf6ae (patch) | |
tree | 5f5d61e83397d435f30a572404e21f8d400ac481 /src | |
download | parseq-0.1.0.tar |
Add initial implementation parseq-0.1.0
Diffstat (limited to 'src')
-rw-r--r-- | src/lib.rs | 470 |
1 files changed, 470 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..3442f58 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,470 @@ +//! Parallel sequential iterator. +//! +//! This crate implements an extension trait [`ParallelIterator`] adding parallel sequential +//! mapping to the standard [`Iterator`] trait. +//! +//! # Example +//! +//! ``` +//! use std::time::Duration; +//! use parseq::ParallelIterator; +//! +//! let mut iter = (0..3) +//! .into_iter() +//! .map_parallel(|i| { +//! // Insert heavy computation here ... +//! std::thread::sleep(Duration::from_millis((i % 3) * 10)); +//! i +//! }); +//! +//! assert_eq!(iter.next(), Some(0)); +//! assert_eq!(iter.next(), Some(1)); +//! assert_eq!(iter.next(), Some(2)); +//! assert_eq!(iter.next(), None); +//! ``` +//! +//! # Rationale +//! +//! This library was created to process a large number of files returned by +//! [walkdir](https://crates.io/crates/walkdir) in parallel, in order, and in constant space. It's +//! API and dependencies were kept to a minimum to ease maintenance. +//! +//! If you don't care about the order of the returned iterator you'll probably want to use +//! [rayon](https://crates.io/crates/rayon) instead. If you do care about the order, take a look at +//! [pariter](https://crates.io/crates/pariter). The latter provides more functionality than this +//! crate and predates it. + +#![forbid(unsafe_code)] +#![warn(missing_docs)] + +use std::{collections::HashMap, num::NonZeroUsize}; + +use crossbeam_channel::{Receiver, Select, Sender, TryRecvError}; + +/// An extension trait adding parallel sequential mapping to the standard [`Iterator`] trait. +pub trait ParallelIterator { + /// Creates an iterator which applies a given closure to each element in parallel. + /// + /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to + /// [`available_parallelism`](std::thread::available_parallelism) threads and buffers a finite + /// number of items. 
Use [`map_parallel_limit`](ParallelIterator) if you want to set + /// parallelism and space limits. + /// + /// The returned iterator + /// + /// * preserves the order of the original iterator + /// * is lazy in the sense that it doesn't consume from the original iterator before [`next`](`Iterator::next`) is called for the first time + /// * doesn't fuse the original iterator + /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the possibly infinite original iterator + /// * propagates panics from the given closure + /// + /// # Example + /// + /// ``` + /// use std::time::Duration; + /// use parseq::ParallelIterator; + /// + /// let mut iter = (0..3) + /// .into_iter() + /// .map_parallel(|i| { + /// std::thread::sleep(Duration::from_millis((i % 3) * 10)); + /// i + /// }); + /// + /// assert_eq!(iter.next(), Some(0)); + /// assert_eq!(iter.next(), Some(1)); + /// assert_eq!(iter.next(), Some(2)); + /// assert_eq!(iter.next(), None); + /// ``` + fn map_parallel<B, F>(self, f: F) -> ParallelMap<Self, B> + where + Self: Iterator + Sized, + Self::Item: Send + 'static, + F: FnMut(Self::Item) -> B + Send + Clone + 'static, + B: Send + 'static, + { + let threads = std::thread::available_parallelism() + .map(NonZeroUsize::get) + .unwrap_or(1); + let buffer_size = threads.saturating_mul(16); + self.map_parallel_limit(threads, buffer_size, f) + } + + /// Creates an iterator which applies a given closure to each element in parallel. + /// + /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to the given + /// number of `threads` and buffers up to `buffer_size` items. If `threads` is zero, up to + /// [`available_parallelism`](std::thread::available_parallelism) threads are used instead. The + /// `buffer_size` should be greater than the number of threads. A `buffer_size < 2` effectively + /// results in single-threaded processing. 
+ /// + /// The returned iterator + /// + /// * preserves the order of the original iterator + /// * is lazy in the sense that it doesn't consume from the original iterator before [`next`](`Iterator::next`) is called for the first time + /// * doesn't fuse the original iterator + /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the possibly infinite original iterator + /// * propagates panics from the given closure + /// + /// # Example + /// + /// ``` + /// use std::time::Duration; + /// use parseq::ParallelIterator; + /// + /// let mut iter = (0..3) + /// .into_iter() + /// .map_parallel_limit(2, 16, |i| { + /// std::thread::sleep(Duration::from_millis((i % 3) * 10)); + /// i + /// }); + /// + /// assert_eq!(iter.next(), Some(0)); + /// assert_eq!(iter.next(), Some(1)); + /// assert_eq!(iter.next(), Some(2)); + /// assert_eq!(iter.next(), None); + /// ``` + fn map_parallel_limit<B, F>( + self, + threads: usize, + buffer_size: usize, + f: F, + ) -> ParallelMap<Self, B> + where + Self: Iterator + Sized, + Self::Item: Send + 'static, + F: FnMut(Self::Item) -> B + Send + Clone + 'static, + B: Send + 'static, + { + ParallelMap::new(self, threads, buffer_size, f) + } +} + +impl<I> ParallelIterator for I where I: Iterator {} + +/// An iterator that maps the elements of another iterator in parallel. +/// +/// This struct is created by the [`map_parallel`](ParallelIterator::map_parallel) method on the +/// [`ParallelIterator`] trait. +pub struct ParallelMap<I, B> +where + I: Iterator, +{ + /// Wrapped iterator. + iter: I, + + /// Maximum number of in-flight input and output items. + buffer_size: usize, + + /// Input sender. + in_tx: Sender<(usize, I::Item)>, + + /// Index of the next input. + in_i: usize, + + /// Output receiver. + out_rx: Receiver<(usize, B)>, + + /// Index of the next output. + out_i: usize, + + /// Worker thread panic receiver. + panic_rx: Receiver<()>, + + /// Output buffer. 
+ buf: HashMap<usize, Option<B>>, +} + +impl<I, B> ParallelMap<I, B> +where + I: Iterator, +{ + /// Returns the number of in-flight items: queued input items and buffered output items. + fn inflight(&self) -> usize { + if self.in_i >= self.out_i { + self.in_i - self.out_i + } else { + (self.in_i + 1) + (usize::MAX - self.out_i) + } + } +} + +impl<I, B> ParallelMap<I, B> +where + I: Iterator, + I::Item: Send + 'static, + B: Send + 'static, +{ + fn new<F>(iter: I, threads: usize, buffer_size: usize, f: F) -> Self + where + F: FnMut(I::Item) -> B + Send + Clone + 'static, + { + let threads = if threads > 0 { + threads + } else { + std::thread::available_parallelism() + .map(NonZeroUsize::get) + .unwrap_or(1) + }; + let buffer_size = if buffer_size > 0 { buffer_size } else { 1 }; + + let (in_tx, in_rx) = crossbeam_channel::bounded(buffer_size); + let (out_tx, out_rx) = crossbeam_channel::bounded(buffer_size); + let (panic_tx, panic_rx) = crossbeam_channel::bounded(threads); + + for _ in 0..threads { + let in_rx = in_rx.clone(); + let out_tx = out_tx.clone(); + let panic_tx = panic_tx.clone(); + let mut f = f.clone(); + + std::thread::spawn(move || { + let _foo = Canary::new(|| { + panic_tx.send(()).ok(); // avoid nested panic + }); + for (i, item) in in_rx.into_iter() { + out_tx.send((i, (f)(item))).unwrap(); + } + }); + } + + ParallelMap { + iter, + buffer_size, + in_tx, + in_i: 0, + out_rx, + out_i: 0, + panic_rx, + buf: HashMap::new(), + } + } +} + +impl<I, B> Iterator for ParallelMap<I, B> +where + I: Iterator, +{ + type Item = B; + + fn next(&mut self) -> Option<Self::Item> { + loop { + // Send input to workers. + while self.inflight() < self.buffer_size { + if let Some(item) = self.iter.next() { + self.in_tx.send((self.in_i, item)).unwrap(); + } else { + self.buf.insert(self.in_i, None); + } + self.in_i = self.in_i.wrapping_add(1); + } + + // Return requested item from buffer, if available. 
+ if let Some(item) = self.buf.remove(&self.out_i) { + self.out_i = self.out_i.wrapping_add(1); + return item; + } + + // Wait for new output from workers. + let mut sel = Select::new(); + sel.recv(&self.out_rx); + let panic_received = sel.recv(&self.panic_rx); + + if sel.ready() == panic_received && self.panic_rx.try_recv().is_ok() { + panic!("worker thread panicked"); + } + + // Receive output from workers. + loop { + match self.out_rx.try_recv() { + Ok((i, item)) => { + self.buf.insert(i, Some(item)); + } + Err(TryRecvError::Empty) => break, + Err(TryRecvError::Disconnected) => break, + } + } + } + } +} + +/// Calls a given closure when the thread unwinds due to a panic. +struct Canary<F: FnMut()> { + f: F, +} + +impl<F: FnMut()> Canary<F> { + /// Creates a canary with the given closure. + /// + /// The closure shouldn't panic. Otherwise the process will be aborted. + fn new(f: F) -> Self { + Canary { f } + } +} + +impl<F: FnMut()> Drop for Canary<F> { + fn drop(&mut self) { + if std::thread::panicking() { + (self.f)(); + } + } +} + +#[cfg(test)] +mod tests { + use std::time::Duration; + + use super::*; + + #[test] + fn empty_iterator() { + assert!(std::iter::empty() + .map_parallel_limit(5, 7, |i: i32| 2 * i) + .eq(std::iter::empty())); + } + + #[test] + fn unit_iterator() { + assert!(std::iter::once(1) + .map_parallel_limit(5, 7, |i| 2 * i) + .eq(std::iter::once(2))); + } + + #[test] + fn preserves_order_with_multiple_threads() { + assert!((0..100) + .into_iter() + .map_parallel_limit(5, 7, |i| { + std::thread::sleep(Duration::from_millis((i % 3) * 10)); + 2 * i + }) + .eq((0..100).into_iter().map(|i| 2 * i))); + } + + #[test] + fn preserves_order_with_single_thread() { + assert!((0..100) + .into_iter() + .map_parallel_limit(1, 7, |i| { + std::thread::sleep(Duration::from_millis((i % 3) * 10)); + 2 * i + }) + .eq((0..100).into_iter().map(|i| 2 * i))); + } + + #[test] + fn preserves_order_with_zero_threads() { + assert!((0..100) + .into_iter() + 
.map_parallel_limit(0, 7, |i| { + std::thread::sleep(Duration::from_millis((i % 3) * 10)); + 2 * i + }) + .eq((0..100).into_iter().map(|i| 2 * i))); + } + + #[test] + fn preserves_order_with_zero_buffer_size() { + assert!((0..100) + .into_iter() + .map_parallel_limit(5, 0, |i| { + std::thread::sleep(Duration::from_millis((i % 3) * 10)); + 2 * i + }) + .eq((0..100).into_iter().map(|i| 2 * i))); + } + + #[test] + fn does_not_fuse() { + let mut i = 0; + let mut iter = std::iter::from_fn(move || { + i += 1; + if i == 2 { + None + } else { + Some(i) + } + }) + .take(3) + .map_parallel_limit(5, 7, |i| i); + + assert_eq!(iter.next(), Some(1)); + assert_eq!(iter.next(), None); + assert_eq!(iter.next(), Some(3)); + assert_eq!(iter.next(), None); + } + + #[test] + fn is_lazy() { + let _iter = (0..10) + .into_iter() + .map_parallel_limit(5, 7, |_| panic!("eager evaluation")); + } + + #[test] + #[should_panic] + #[ntest::timeout(1000)] + fn propagates_panic() { + let _ = (0..100) + .into_iter() + .map_parallel_limit(5, 7, |i| { + if i == 13 { + panic!("boom"); + } else { + i + } + }) + .collect::<Vec<_>>(); + } + + #[test] + fn canary_positive() { + let (tx, rx) = crossbeam_channel::bounded(1); + std::thread::spawn(move || { + let _canary = Canary::new(|| tx.send(()).unwrap()); + panic!("boom"); + }); + assert_eq!(rx.recv_timeout(Duration::from_secs(1)), Ok(())); + } + + #[test] + fn canary_negative() { + let mut panicked = false; + let canary = Canary::new(|| { + panicked = true; + }); + drop(canary); + assert!(!panicked); + } + + #[test] + fn indices_wrap() { + let mut map = ParallelMap::new(0..100, 5, 7, |i| { + std::thread::sleep(Duration::from_millis((i % 3) * 10)); + 2 * i + }); + + // Fast forward + map.in_i = usize::MAX - 13; + map.out_i = usize::MAX - 13; + + assert!(map.eq((0..100).into_iter().map(|i| 2 * i))); + } + + #[test] + fn inflight() { + let inflight = |i, j| { + let mut map = ParallelMap::new(std::iter::empty(), 5, 7, |x: i32| x); + map.in_i = i; + 
map.out_i = j; + map.inflight() + }; + + assert_eq!(inflight(0, 0), 0); + assert_eq!(inflight(usize::MAX, 0), usize::MAX); + assert_eq!(inflight(usize::MAX, usize::MAX), 0); + assert_eq!(inflight(0, usize::MAX), 1); + assert_eq!(inflight(17, 13), 4); + assert_eq!(inflight(13, usize::MAX - 17), 31); + } +} |