//! Parallel sequential iterator.
//!
//! This crate implements an extension trait [`ParallelIterator`] adding parallel sequential
//! mapping to the standard [`Iterator`] trait.
//!
//! # Example
//!
//! ```
//! use std::time::Duration;
//! use parseq::ParallelIterator;
//!
//! let mut iter = [3,2,1]
//!     .into_iter()
//!     .map_parallel(|i| {
//!         // Insert heavy computation here ...
//!         std::thread::sleep(Duration::from_millis(100*i));
//!         2*i
//!     });
//!
//! assert_eq!(iter.next(), Some(6));
//! assert_eq!(iter.next(), Some(4));
//! assert_eq!(iter.next(), Some(2));
//! assert_eq!(iter.next(), None);
//! ```
//!
//! See the `examples` directory for a real world example.
//!
//! # Features
//!
//! * Parseq utilizes a configurable number of worker threads
//! * Parseq preserves the order of the original iterator
//! * Parseq is lazy in the sense that it doesn't consume from the original iterator before
//!   [`next`](`Iterator::next`) is called for the first time
//! * Parseq doesn't [`fuse`](`Iterator::fuse`) the original iterator
//! * Parseq uses constant space: linear in the number of threads and the size of the buffer,
//!   not in the length of the possibly infinite original iterator
//! * Parseq propagates panics from the given closure
//!
//! # Alternatives
//!
//! If you don't care about the order of the returned iterator you'll probably want to use
//! [Rayon](https://crates.io/crates/rayon) instead. If you do care about the order, take a look at
//! [Pariter](https://crates.io/crates/pariter). The latter provides more functionality than this
//! crate and predates it.

#![forbid(unsafe_code)]
#![warn(missing_docs)]

use std::{collections::HashMap, iter::FusedIterator, num::NonZeroUsize};

use crossbeam_channel::{Receiver, Select, Sender, TryRecvError};

/// An extension trait adding parallel sequential mapping to the standard [`Iterator`] trait.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub trait ParallelIterator {
    /// Creates an iterator which applies a given closure to each element in parallel.
    ///
    /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to
    /// [`available_parallelism`](std::thread::available_parallelism) threads and buffers a finite
    /// number of items. Use [`map_parallel_limit`](ParallelIterator::map_parallel_limit) if you
    /// want to set parallelism and space limits.
    ///
    /// The returned iterator
    ///
    /// * preserves the order of the original iterator
    /// * is lazy in the sense that it doesn't consume from the original iterator before
    ///   [`next`](`Iterator::next`) is called for the first time
    /// * doesn't [`fuse`](`Iterator::fuse`) the original iterator
    /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the
    ///   possibly infinite original iterator
    /// * propagates panics from the given closure
    ///
    /// # Example
    ///
    /// ```
    /// use std::time::Duration;
    /// use parseq::ParallelIterator;
    ///
    /// let mut iter = [3,2,1]
    ///     .into_iter()
    ///     .map_parallel(|i| {
    ///         // Insert heavy computation here ...
    ///         std::thread::sleep(Duration::from_millis(100*i));
    ///         2*i
    ///     });
    ///
    /// assert_eq!(iter.next(), Some(6));
    /// assert_eq!(iter.next(), Some(4));
    /// assert_eq!(iter.next(), Some(2));
    /// assert_eq!(iter.next(), None);
    /// ```
    fn map_parallel<B, F>(self, f: F) -> ParallelMap<Self, B>
    where
        Self: Iterator + Sized,
        Self::Item: Send + 'static,
        F: FnMut(Self::Item) -> B + Send + Clone + 'static,
        B: Send + 'static,
    {
        let threads = std::thread::available_parallelism()
            .map(NonZeroUsize::get)
            .unwrap_or(1);
        let buffer_size = threads.saturating_mul(16);
        self.map_parallel_limit(threads, buffer_size, f)
    }

    /// Creates an iterator which applies a given closure to each element in parallel.
    ///
    /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to the given
    /// number of `threads` and buffers up to `buffer_size` items. If `threads` is zero, up to
    /// [`available_parallelism`](std::thread::available_parallelism) threads are used instead. The
    /// `buffer_size` should be greater than the number of threads. A `buffer_size < 2` effectively
    /// results in single-threaded processing.
    ///
    /// The returned iterator
    ///
    /// * preserves the order of the original iterator
    /// * is lazy in the sense that it doesn't consume from the original iterator before
    ///   [`next`](`Iterator::next`) is called for the first time
    /// * doesn't [`fuse`](`Iterator::fuse`) the original iterator
    /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the
    ///   possibly infinite original iterator
    /// * propagates panics from the given closure
    ///
    /// # Example
    ///
    /// ```
    /// use std::time::Duration;
    /// use parseq::ParallelIterator;
    ///
    /// let mut iter = [3,2,1]
    ///     .into_iter()
    ///     .map_parallel_limit(2, 16, |i| {
    ///         std::thread::sleep(Duration::from_millis(100*i));
    ///         2*i
    ///     });
    ///
    /// assert_eq!(iter.next(), Some(6));
    /// assert_eq!(iter.next(), Some(4));
    /// assert_eq!(iter.next(), Some(2));
    /// assert_eq!(iter.next(), None);
    /// ```
    fn map_parallel_limit<B, F>(
        self,
        threads: usize,
        buffer_size: usize,
        f: F,
    ) -> ParallelMap<Self, B>
    where
        Self: Iterator + Sized,
        Self::Item: Send + 'static,
        F: FnMut(Self::Item) -> B + Send + Clone + 'static,
        B: Send + 'static,
    {
        ParallelMap::new(self, threads, buffer_size, f)
    }
}

impl<I> ParallelIterator for I where I: Iterator {}

/// An iterator that maps the elements of another iterator in parallel.
///
/// This struct is created by the [`map_parallel`](ParallelIterator::map_parallel) method on the
/// [`ParallelIterator`] trait.
#[must_use = "iterators are lazy and do nothing unless consumed"]
pub struct ParallelMap<I, B>
where
    I: Iterator,
{
    /// Wrapped iterator.
    iter: I,
    /// Maximum number of in-flight input and output items.
    buffer_size: usize,
    /// Input sender.
    in_tx: Sender<(usize, I::Item)>,
    /// Index of the next input.
    in_i: usize,
    /// Output receiver.
    out_rx: Receiver<(usize, B)>,
    /// Index of the next output.
    out_i: usize,
    /// Worker thread panic receiver.
    panic_rx: Receiver<()>,
    /// Output buffer.
    buf: HashMap<usize, Option<B>>,
}

impl<I, B> ParallelMap<I, B>
where
    I: Iterator,
{
    /// Returns the number of in-flight items: queued input items and buffered output items.
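    ///
    /// Item indices wrap around at `usize::MAX` (they are advanced with
    /// [`usize::wrapping_add`]), so `in_i` may be numerically smaller than `out_i` even
    /// though it is logically ahead. Worked example for the wrapped case: with
    /// `out_i = usize::MAX - 1` and `in_i = 2`, the in-flight indices are
    /// `usize::MAX - 1`, `usize::MAX`, `0` and `1`, and the formula below yields
    /// `(2 + 1) + (usize::MAX - (usize::MAX - 1)) = 4`.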
    fn inflight(&self) -> usize {
        if self.in_i >= self.out_i {
            self.in_i - self.out_i
        } else {
            // The indices have wrapped around: count the items from `out_i` through
            // `usize::MAX` plus the items from zero up to `in_i`.
            (self.in_i + 1) + (usize::MAX - self.out_i)
        }
    }
}

impl<I, B> ParallelMap<I, B>
where
    I: Iterator,
    I::Item: Send + 'static,
    B: Send + 'static,
{
    fn new<F>(iter: I, threads: usize, buffer_size: usize, f: F) -> Self
    where
        F: FnMut(I::Item) -> B + Send + Clone + 'static,
    {
        let threads = if threads > 0 {
            threads
        } else {
            std::thread::available_parallelism()
                .map(NonZeroUsize::get)
                .unwrap_or(1)
        };
        // A zero buffer size would prevent any input from ever being sent to the workers.
        let buffer_size = if buffer_size > 0 { buffer_size } else { 1 };
        let (in_tx, in_rx) = crossbeam_channel::bounded(buffer_size);
        let (out_tx, out_rx) = crossbeam_channel::bounded(buffer_size);
        let (panic_tx, panic_rx) = crossbeam_channel::bounded(threads);
        for _ in 0..threads {
            let in_rx = in_rx.clone();
            let out_tx = out_tx.clone();
            let panic_tx = panic_tx.clone();
            let mut f = f.clone();
            std::thread::spawn(move || {
                // Notify the consumer if this worker panics.
                let _canary = Canary::new(|| {
                    panic_tx.send(()).ok(); // avoid nested panic
                });
                for (i, item) in in_rx.into_iter() {
                    out_tx.send((i, (f)(item))).ok(); // fails iff the ParallelMap was dropped
                }
            });
        }
        ParallelMap {
            iter,
            buffer_size,
            in_tx,
            in_i: 0,
            out_rx,
            out_i: 0,
            panic_rx,
            buf: HashMap::new(),
        }
    }
}

impl<I, B> Iterator for ParallelMap<I, B>
where
    I: Iterator,
{
    type Item = B;

    fn next(&mut self) -> Option<B> {
        loop {
            // Send input to workers.
            while self.inflight() < self.buffer_size {
                if let Some(item) = self.iter.next() {
                    self.in_tx.send((self.in_i, item)).unwrap();
                } else {
                    self.buf.insert(self.in_i, None);
                }
                self.in_i = self.in_i.wrapping_add(1);
            }

            // Return requested item from buffer, if available.
            if let Some(item) = self.buf.remove(&self.out_i) {
                self.out_i = self.out_i.wrapping_add(1);
                return item;
            }

            // Wait for new output from workers, or for a panic notification.
            let mut sel = Select::new();
            sel.recv(&self.out_rx);
            let panic_received = sel.recv(&self.panic_rx);
            if sel.ready() == panic_received && self.panic_rx.try_recv().is_ok() {
                panic!("worker thread panicked");
            }

            // Receive output from workers.
            loop {
                match self.out_rx.try_recv() {
                    Ok((i, item)) => {
                        self.buf.insert(i, Some(item));
                    }
                    Err(TryRecvError::Empty) => break,
                    Err(TryRecvError::Disconnected) => break,
                }
            }
        }
    }

    fn size_hint(&self) -> (usize, Option<usize>) {
        let (lower, upper) = self.iter.size_hint();
        let inflight = self.inflight();
        (
            lower.saturating_add(inflight),
            upper.and_then(|i| i.checked_add(inflight)),
        )
    }
}

impl<I, B> FusedIterator for ParallelMap<I, B> where I: FusedIterator {}

impl<I, B> ExactSizeIterator for ParallelMap<I, B>
where
    I: ExactSizeIterator,
{
    fn len(&self) -> usize {
        self.iter.len() + self.inflight()
    }
}

/// Calls a given closure when the thread unwinds due to a panic.
struct Canary<F>
where
    F: FnMut(),
{
    f: F,
}

impl<F> Canary<F>
where
    F: FnMut(),
{
    /// Creates a canary with the given closure.
    ///
    /// The closure shouldn't panic; otherwise the process will be aborted.
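    ///
    /// A minimal usage sketch (marked `ignore` since `Canary` is private and the block is
    /// not compiled as a doctest):
    ///
    /// ```ignore
    /// let _canary = Canary::new(|| eprintln!("thread is unwinding"));
    /// // ... work that may panic ...
    /// // The closure runs iff the thread panics while `_canary` is in scope.
    /// ```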
    fn new(f: F) -> Self {
        Canary { f }
    }
}

impl<F> Drop for Canary<F>
where
    F: FnMut(),
{
    fn drop(&mut self) {
        if std::thread::panicking() {
            (self.f)();
        }
    }
}

#[cfg(test)]
mod tests {
    use std::{
        sync::{Arc, Mutex},
        time::Duration,
    };

    use super::*;

    #[test]
    fn map_empty_iterator() {
        assert!(std::iter::empty()
            .map_parallel_limit(5, 7, |i: i32| 2 * i)
            .eq(std::iter::empty()));
    }

    #[test]
    fn map_unit_iterator() {
        assert!(std::iter::once(1)
            .map_parallel_limit(5, 7, |i| 2 * i)
            .eq(std::iter::once(2)));
    }

    #[test]
    fn map_with_multiple_threads() {
        assert!((0..100)
            .map_parallel_limit(5, 7, |i| {
                std::thread::sleep(Duration::from_millis((i % 3) * 10));
                2 * i
            })
            .eq((0..100).map(|i| 2 * i)));
    }

    #[test]
    fn map_with_single_thread() {
        assert!((0..100)
            .map_parallel_limit(1, 7, |i| {
                std::thread::sleep(Duration::from_millis((i % 3) * 10));
                2 * i
            })
            .eq((0..100).map(|i| 2 * i)));
    }

    #[test]
    fn map_with_zero_threads() {
        assert!((0..100)
            .map_parallel_limit(0, 7, |i| {
                std::thread::sleep(Duration::from_millis((i % 3) * 10));
                2 * i
            })
            .eq((0..100).map(|i| 2 * i)));
    }

    #[test]
    fn map_with_zero_buffer_size() {
        assert!((0..100)
            .map_parallel_limit(5, 0, |i| {
                std::thread::sleep(Duration::from_millis((i % 3) * 10));
                2 * i
            })
            .eq((0..100).map(|i| 2 * i)));
    }

    #[test]
    fn map_does_not_fuse() {
        let mut i = 0;
        let mut iter = std::iter::from_fn(move || {
            i += 1;
            if i == 2 {
                None
            } else {
                Some(i)
            }
        })
        .take(3)
        .map_parallel_limit(5, 7, |i| i);
        assert_eq!(iter.next(), Some(1));
        assert_eq!(iter.next(), None);
        assert_eq!(iter.next(), Some(3));
        assert_eq!(iter.next(), None);
    }

    #[test]
    fn map_is_lazy() {
        let _iter = (0..10).map_parallel_limit(5, 7, |_| panic!("eager evaluation"));
    }

    #[test]
    #[should_panic]
    #[ntest::timeout(1000)]
    fn map_propagates_panic() {
        let _ = (0..100)
            .map_parallel_limit(5, 7, |i| {
                if i == 13 {
                    panic!("boom");
                } else {
                    i
                }
            })
            .collect::<Vec<_>>();
    }

    #[test]
    fn canary_positive() {
        let (tx, rx) = crossbeam_channel::bounded(1);
        std::thread::spawn(move || {
            let _canary = Canary::new(|| tx.send(()).unwrap());
            panic!("boom");
        });
        assert_eq!(rx.recv_timeout(Duration::from_secs(1)), Ok(()));
    }

    #[test]
    fn canary_negative() {
        let mut panicked = false;
        let canary = Canary::new(|| {
            panicked = true;
        });
        drop(canary);
        assert!(!panicked);
    }

    #[test]
    fn map_wraps_item_indices() {
        let mut map = ParallelMap::new(0..100, 5, 7, |i| {
            std::thread::sleep(Duration::from_millis((i % 3) * 10));
            2 * i
        });
        // Fast forward the indices so they wrap around during iteration.
        map.in_i = usize::MAX - 13;
        map.out_i = usize::MAX - 13;
        assert!(map.eq((0..100).map(|i| 2 * i)));
    }

    #[test]
    fn inflight() {
        let inflight = |i, j| {
            let mut map = ParallelMap::new(std::iter::empty(), 5, 7, |x: i32| x);
            map.in_i = i;
            map.out_i = j;
            map.inflight()
        };
        assert_eq!(inflight(0, 0), 0);
        assert_eq!(inflight(usize::MAX, 0), usize::MAX);
        assert_eq!(inflight(usize::MAX, usize::MAX), 0);
        assert_eq!(inflight(0, usize::MAX), 1);
        assert_eq!(inflight(17, 13), 4);
        assert_eq!(inflight(13, usize::MAX - 17), 31);
    }

    /// ParallelMap must stop feeding workers when dropped.
    #[test]
    fn drop_parallel_map() {
        let threads = 5;
        let buffer_size = 20;
        let consume = 7;
        let counter = Arc::new(Mutex::new(0));
        let count = counter.clone();
        let mut iter = (0..).map_parallel_limit(threads, buffer_size, move |i| {
            if i < consume {
                std::thread::sleep(Duration::from_millis(100));
            }
            let mut counter = counter.lock().unwrap();
            *counter += 1;
            2 * i
        });
        for _ in 0..consume {
            iter.next();
        }
        drop(iter);
        assert!(*count.lock().unwrap() < consume + buffer_size);
    }
}