summaryrefslogtreecommitdiff
path: root/src/lib.rs
diff options
context:
space:
mode:
Diffstat (limited to 'src/lib.rs')
-rw-r--r--src/lib.rs470
1 files changed, 470 insertions, 0 deletions
diff --git a/src/lib.rs b/src/lib.rs
new file mode 100644
index 0000000..3442f58
--- /dev/null
+++ b/src/lib.rs
@@ -0,0 +1,470 @@
+//! Parallel sequential iterator.
+//!
+//! This crate implements an extension trait [`ParallelIterator`] adding parallel sequential
+//! mapping to the standard [`Iterator`] trait.
+//!
+//! # Example
+//!
+//! ```
+//! use std::time::Duration;
+//! use parseq::ParallelIterator;
+//!
+//! let mut iter = (0..3)
+//! .into_iter()
+//! .map_parallel(|i| {
+//! // Insert heavy computation here ...
+//! std::thread::sleep(Duration::from_millis((i % 3) * 10));
+//! i
+//! });
+//!
+//! assert_eq!(iter.next(), Some(0));
+//! assert_eq!(iter.next(), Some(1));
+//! assert_eq!(iter.next(), Some(2));
+//! assert_eq!(iter.next(), None);
+//! ```
+//!
+//! # Rationale
+//!
+//! This library was created to process a large number of files returned by
+//! [walkdir](https://crates.io/crates/walkdir) in parallel, in order, and in constant space. It's
+//! API and dependencies were kept to a minimum to ease maintenance.
+//!
+//! If you don't care about the order of the returned iterator you'll probably want to use
+//! [rayon](https://crates.io/crates/rayon) instead. If you do care about the order, take a look at
+//! [pariter](https://crates.io/crates/pariter). The latter provides more functionality than this
+//! crate and predates it.
+
+#![forbid(unsafe_code)]
+#![warn(missing_docs)]
+
+use std::{collections::HashMap, num::NonZeroUsize};
+
+use crossbeam_channel::{Receiver, Select, Sender, TryRecvError};
+
+/// An extension trait adding parallel sequential mapping to the standard [`Iterator`] trait.
+pub trait ParallelIterator {
+ /// Creates an iterator which applies a given closure to each element in parallel.
+ ///
+ /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to
+ /// [`available_parallelism`](std::thread::available_parallelism) threads and buffers a finite
+ /// number of items. Use [`map_parallel_limit`](ParallelIterator) if you want to set
+ /// parallelism and space limits.
+ ///
+ /// The returned iterator
+ ///
+ /// * preserves the order of the original iterator
+ /// * is lazy in the sense that it doesn't consume from the original iterator before [`next`](`Iterator::next`) is called for the first time
+ /// * doesn't fuse the original iterator
+ /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the possibly infinite original iterator
+ /// * propagates panics from the given closure
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use std::time::Duration;
+ /// use parseq::ParallelIterator;
+ ///
+ /// let mut iter = (0..3)
+ /// .into_iter()
+ /// .map_parallel(|i| {
+ /// std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ /// i
+ /// });
+ ///
+ /// assert_eq!(iter.next(), Some(0));
+ /// assert_eq!(iter.next(), Some(1));
+ /// assert_eq!(iter.next(), Some(2));
+ /// assert_eq!(iter.next(), None);
+ /// ```
+ fn map_parallel<B, F>(self, f: F) -> ParallelMap<Self, B>
+ where
+ Self: Iterator + Sized,
+ Self::Item: Send + 'static,
+ F: FnMut(Self::Item) -> B + Send + Clone + 'static,
+ B: Send + 'static,
+ {
+ let threads = std::thread::available_parallelism()
+ .map(NonZeroUsize::get)
+ .unwrap_or(1);
+ let buffer_size = threads.saturating_mul(16);
+ self.map_parallel_limit(threads, buffer_size, f)
+ }
+
+ /// Creates an iterator which applies a given closure to each element in parallel.
+ ///
+ /// This function is a multi-threaded equivalent of [`Iterator::map`]. It uses up to the given
+ /// number of `threads` and buffers up to `buffer_size` items. If `threads` is zero, up to
+ /// [`available_parallelism`](std::thread::available_parallelism) threads are used instead. The
+ /// `buffer_size` should be greater than the number of threads. A `buffer_size < 2` effectively
+ /// results in single-threaded processing.
+ ///
+ /// The returned iterator
+ ///
+ /// * preserves the order of the original iterator
+ /// * is lazy in the sense that it doesn't consume from the original iterator before [`next`](`Iterator::next`) is called for the first time
+ /// * doesn't fuse the original iterator
+ /// * uses constant space: linear in `threads` and `buffer_size`, not in the length of the possibly infinite original iterator
+ /// * propagates panics from the given closure
+ ///
+ /// # Example
+ ///
+ /// ```
+ /// use std::time::Duration;
+ /// use parseq::ParallelIterator;
+ ///
+ /// let mut iter = (0..3)
+ /// .into_iter()
+ /// .map_parallel_limit(2, 16, |i| {
+ /// std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ /// i
+ /// });
+ ///
+ /// assert_eq!(iter.next(), Some(0));
+ /// assert_eq!(iter.next(), Some(1));
+ /// assert_eq!(iter.next(), Some(2));
+ /// assert_eq!(iter.next(), None);
+ /// ```
+ fn map_parallel_limit<B, F>(
+ self,
+ threads: usize,
+ buffer_size: usize,
+ f: F,
+ ) -> ParallelMap<Self, B>
+ where
+ Self: Iterator + Sized,
+ Self::Item: Send + 'static,
+ F: FnMut(Self::Item) -> B + Send + Clone + 'static,
+ B: Send + 'static,
+ {
+ ParallelMap::new(self, threads, buffer_size, f)
+ }
+}
+
+impl<I> ParallelIterator for I where I: Iterator {}
+
+/// An iterator that maps the elements of another iterator in parallel.
+///
+/// This struct is created by the [`map_parallel`](ParallelIterator::map_parallel) method on the
+/// [`ParallelIterator`] trait.
+pub struct ParallelMap<I, B>
+where
+ I: Iterator,
+{
+ /// Wrapped iterator.
+ iter: I,
+
+ /// Maximum number of in-flight input and output items.
+ buffer_size: usize,
+
+ /// Input sender.
+ in_tx: Sender<(usize, I::Item)>,
+
+ /// Index of the next input.
+ in_i: usize,
+
+ /// Output receiver.
+ out_rx: Receiver<(usize, B)>,
+
+ /// Index of the next output.
+ out_i: usize,
+
+ /// Worker thread panic receiver.
+ panic_rx: Receiver<()>,
+
+ /// Output buffer.
+ buf: HashMap<usize, Option<B>>,
+}
+
+impl<I, B> ParallelMap<I, B>
+where
+ I: Iterator,
+{
+ /// Returns the number of in-flight items: queued input items and buffered output items.
+ fn inflight(&self) -> usize {
+ if self.in_i >= self.out_i {
+ self.in_i - self.out_i
+ } else {
+ (self.in_i + 1) + (usize::MAX - self.out_i)
+ }
+ }
+}
+
+impl<I, B> ParallelMap<I, B>
+where
+ I: Iterator,
+ I::Item: Send + 'static,
+ B: Send + 'static,
+{
+ fn new<F>(iter: I, threads: usize, buffer_size: usize, f: F) -> Self
+ where
+ F: FnMut(I::Item) -> B + Send + Clone + 'static,
+ {
+ let threads = if threads > 0 {
+ threads
+ } else {
+ std::thread::available_parallelism()
+ .map(NonZeroUsize::get)
+ .unwrap_or(1)
+ };
+ let buffer_size = if buffer_size > 0 { buffer_size } else { 1 };
+
+ let (in_tx, in_rx) = crossbeam_channel::bounded(buffer_size);
+ let (out_tx, out_rx) = crossbeam_channel::bounded(buffer_size);
+ let (panic_tx, panic_rx) = crossbeam_channel::bounded(threads);
+
+ for _ in 0..threads {
+ let in_rx = in_rx.clone();
+ let out_tx = out_tx.clone();
+ let panic_tx = panic_tx.clone();
+ let mut f = f.clone();
+
+ std::thread::spawn(move || {
+ let _foo = Canary::new(|| {
+ panic_tx.send(()).ok(); // avoid nested panic
+ });
+ for (i, item) in in_rx.into_iter() {
+ out_tx.send((i, (f)(item))).unwrap();
+ }
+ });
+ }
+
+ ParallelMap {
+ iter,
+ buffer_size,
+ in_tx,
+ in_i: 0,
+ out_rx,
+ out_i: 0,
+ panic_rx,
+ buf: HashMap::new(),
+ }
+ }
+}
+
+impl<I, B> Iterator for ParallelMap<I, B>
+where
+ I: Iterator,
+{
+ type Item = B;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ loop {
+ // Send input to workers.
+ while self.inflight() < self.buffer_size {
+ if let Some(item) = self.iter.next() {
+ self.in_tx.send((self.in_i, item)).unwrap();
+ } else {
+ self.buf.insert(self.in_i, None);
+ }
+ self.in_i = self.in_i.wrapping_add(1);
+ }
+
+ // Return requested item from buffer, if available.
+ if let Some(item) = self.buf.remove(&self.out_i) {
+ self.out_i = self.out_i.wrapping_add(1);
+ return item;
+ }
+
+ // Wait for new output from workers.
+ let mut sel = Select::new();
+ sel.recv(&self.out_rx);
+ let panic_received = sel.recv(&self.panic_rx);
+
+ if sel.ready() == panic_received && self.panic_rx.try_recv().is_ok() {
+ panic!("worker thread panicked");
+ }
+
+ // Receive output from workers.
+ loop {
+ match self.out_rx.try_recv() {
+ Ok((i, item)) => {
+ self.buf.insert(i, Some(item));
+ }
+ Err(TryRecvError::Empty) => break,
+ Err(TryRecvError::Disconnected) => break,
+ }
+ }
+ }
+ }
+}
+
+/// Calls a given closure when the thread unwinds due to a panic.
+struct Canary<F: FnMut()> {
+ f: F,
+}
+
+impl<F: FnMut()> Canary<F> {
+ /// Creates a canary with the given closure.
+ ///
+ /// The closure shouldn't panic. Otherwise the process will be aborted.
+ fn new(f: F) -> Self {
+ Canary { f }
+ }
+}
+
+impl<F: FnMut()> Drop for Canary<F> {
+ fn drop(&mut self) {
+ if std::thread::panicking() {
+ (self.f)();
+ }
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use std::time::Duration;
+
+ use super::*;
+
+ #[test]
+ fn empty_iterator() {
+ assert!(std::iter::empty()
+ .map_parallel_limit(5, 7, |i: i32| 2 * i)
+ .eq(std::iter::empty()));
+ }
+
+ #[test]
+ fn unit_iterator() {
+ assert!(std::iter::once(1)
+ .map_parallel_limit(5, 7, |i| 2 * i)
+ .eq(std::iter::once(2)));
+ }
+
+ #[test]
+ fn preserves_order_with_multiple_threads() {
+ assert!((0..100)
+ .into_iter()
+ .map_parallel_limit(5, 7, |i| {
+ std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ 2 * i
+ })
+ .eq((0..100).into_iter().map(|i| 2 * i)));
+ }
+
+ #[test]
+ fn preserves_order_with_single_thread() {
+ assert!((0..100)
+ .into_iter()
+ .map_parallel_limit(1, 7, |i| {
+ std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ 2 * i
+ })
+ .eq((0..100).into_iter().map(|i| 2 * i)));
+ }
+
+ #[test]
+ fn preserves_order_with_zero_threads() {
+ assert!((0..100)
+ .into_iter()
+ .map_parallel_limit(0, 7, |i| {
+ std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ 2 * i
+ })
+ .eq((0..100).into_iter().map(|i| 2 * i)));
+ }
+
+ #[test]
+ fn preserves_order_with_zero_buffer_size() {
+ assert!((0..100)
+ .into_iter()
+ .map_parallel_limit(5, 0, |i| {
+ std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ 2 * i
+ })
+ .eq((0..100).into_iter().map(|i| 2 * i)));
+ }
+
+ #[test]
+ fn does_not_fuse() {
+ let mut i = 0;
+ let mut iter = std::iter::from_fn(move || {
+ i += 1;
+ if i == 2 {
+ None
+ } else {
+ Some(i)
+ }
+ })
+ .take(3)
+ .map_parallel_limit(5, 7, |i| i);
+
+ assert_eq!(iter.next(), Some(1));
+ assert_eq!(iter.next(), None);
+ assert_eq!(iter.next(), Some(3));
+ assert_eq!(iter.next(), None);
+ }
+
+ #[test]
+ fn is_lazy() {
+ let _iter = (0..10)
+ .into_iter()
+ .map_parallel_limit(5, 7, |_| panic!("eager evaluation"));
+ }
+
+ #[test]
+ #[should_panic]
+ #[ntest::timeout(1000)]
+ fn propagates_panic() {
+ let _ = (0..100)
+ .into_iter()
+ .map_parallel_limit(5, 7, |i| {
+ if i == 13 {
+ panic!("boom");
+ } else {
+ i
+ }
+ })
+ .collect::<Vec<_>>();
+ }
+
+ #[test]
+ fn canary_positive() {
+ let (tx, rx) = crossbeam_channel::bounded(1);
+ std::thread::spawn(move || {
+ let _canary = Canary::new(|| tx.send(()).unwrap());
+ panic!("boom");
+ });
+ assert_eq!(rx.recv_timeout(Duration::from_secs(1)), Ok(()));
+ }
+
+ #[test]
+ fn canary_negative() {
+ let mut panicked = false;
+ let canary = Canary::new(|| {
+ panicked = true;
+ });
+ drop(canary);
+ assert!(!panicked);
+ }
+
+ #[test]
+ fn indices_wrap() {
+ let mut map = ParallelMap::new(0..100, 5, 7, |i| {
+ std::thread::sleep(Duration::from_millis((i % 3) * 10));
+ 2 * i
+ });
+
+ // Fast forward
+ map.in_i = usize::MAX - 13;
+ map.out_i = usize::MAX - 13;
+
+ assert!(map.eq((0..100).into_iter().map(|i| 2 * i)));
+ }
+
+ #[test]
+ fn inflight() {
+ let inflight = |i, j| {
+ let mut map = ParallelMap::new(std::iter::empty(), 5, 7, |x: i32| x);
+ map.in_i = i;
+ map.out_i = j;
+ map.inflight()
+ };
+
+ assert_eq!(inflight(0, 0), 0);
+ assert_eq!(inflight(usize::MAX, 0), usize::MAX);
+ assert_eq!(inflight(usize::MAX, usize::MAX), 0);
+ assert_eq!(inflight(0, usize::MAX), 1);
+ assert_eq!(inflight(17, 13), 4);
+ assert_eq!(inflight(13, usize::MAX - 17), 31);
+ }
+}
Generated by cgit. See skreutz.com for my tech blog and contact information.