Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added a `par_chunk_by` method #925

Merged
merged 12 commits into from
Mar 24, 2024
244 changes: 244 additions & 0 deletions src/slice/chunk_by.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,244 @@
use crate::iter::plumbing::*;
use crate::iter::*;
use std::marker::PhantomData;
use std::{fmt, mem};

/// Common interface over the slice-like types that can be chunked
/// (implemented in this file for `&[T]` and `&mut [T]`).
///
/// `Default` is required so a value can be moved out with `mem::take`, and
/// `Send` so pieces can be handed to other threads.
trait ChunkBySlice<T>: AsRef<[T]> + Default + Send {
    /// Consumes the slice and returns the parts before and from `index`.
    fn split(self, index: usize) -> (Self, Self);

    /// Scans `self[start..end]` front to back for the first adjacent pair
    /// that fails `pred`.
    ///
    /// Returns the position of the pair's second element, counted from
    /// `start` (so the result is always at least 1), or `None` when every
    /// adjacent pair in the range satisfies `pred`.
    fn find(&self, pred: &impl Fn(&T, &T) -> bool, start: usize, end: usize) -> Option<usize> {
        let range = &self.as_ref()[start..end];
        let boundary = range
            .windows(2)
            .position(|pair| !pred(&pair[0], &pair[1]))?;
        Some(boundary + 1)
    }

    /// Scans `self[..end]` back to front for the last adjacent pair that
    /// fails `pred`.
    ///
    /// Returns the index of the pair's second element, or `None` when every
    /// adjacent pair in the prefix satisfies `pred`.
    fn rfind(&self, pred: &impl Fn(&T, &T) -> bool, end: usize) -> Option<usize> {
        let prefix = &self.as_ref()[..end];
        let boundary = prefix
            .windows(2)
            .rposition(|pair| !pred(&pair[0], &pair[1]))?;
        Some(boundary + 1)
    }
}

/// Shared slices are divided with the standard `split_at`.
impl<T: Sync> ChunkBySlice<T> for &[T] {
    fn split(self, index: usize) -> (Self, Self) {
        <[T]>::split_at(self, index)
    }
}

/// Mutable slices are divided with the standard `split_at_mut`, yielding two
/// disjoint mutable halves.
impl<T: Send> ChunkBySlice<T> for &mut [T] {
    fn split(self, index: usize) -> (Self, Self) {
        <[T]>::split_at_mut(self, index)
    }
}

/// Unindexed producer that drives both `ChunkBy` and `ChunkByMut`.
///
/// `Slice` is the slice-like type being chunked; in this file that is
/// `&[T]` or `&mut [T]`, the two `ChunkBySlice` implementors.
struct ChunkByProducer<'p, T, Slice, Pred> {
/// The elements this producer is responsible for.
slice: Slice,
/// Borrowed predicate; a chunk boundary lies between two adjacent elements
/// for which it returns `false`.
pred: &'p Pred,
/// Length of the prefix of `slice` that has not been searched for chunk
/// boundaries yet. Adjacent pairs that are not fully inside `slice[..tail]`
/// are treated as already known to contain no boundary.
tail: usize,
/// Names `T`, which no other field mentions.
marker: PhantomData<fn(&T)>,
}

// Note: this implementation is very similar to `SplitProducer`.
//
// Bookkeeping used by both methods: only adjacent pairs lying entirely
// inside `slice[..tail]` may still hold a chunk boundary. Whatever follows
// the last boundary of that prefix, up to the end of `slice`, is one chunk.
impl<T, Slice, Pred> UnindexedProducer for ChunkByProducer<'_, T, Slice, Pred>
where
Slice: ChunkBySlice<T>,
Pred: Fn(&T, &T) -> bool + Send + Sync,
{
type Item = Slice;

/// Tries to cut this producer in two at a chunk boundary near the middle
/// of the unsearched prefix. Returns `None` as the second half when no
/// boundary is found.
fn split(self) -> (Self, Option<Self>) {
// Fewer than two unsearched elements means no adjacent pair is left to
// test, so mark the search as finished and decline to split.
if self.tail < 2 {
return (Self { tail: 0, ..self }, None);
}

// Look forward for the separator, and failing that look backward.
// `find` reports a position relative to `mid`, so it is rebased here;
// `rfind` already reports an absolute index. The backward range ends at
// `mid + 1` so that the pair `(mid - 1, mid)` is included.
let mid = self.tail / 2;
let index = match self.slice.find(self.pred, mid, self.tail) {
Some(i) => Some(mid + i),
None => self.slice.rfind(self.pred, mid + 1),
};

if let Some(index) = index {
let (left, right) = self.slice.split(index);

// A forward hit is always greater than `mid` (`find` returns at least 1),
// so `index <= mid` identifies a backward hit.
let (left_tail, right_tail) = if index <= mid {
// If we scanned backwards to find the separator, everything in
// the right side is exhausted, with no separators left to find.
// The left side (length `index`) is still entirely unsearched.
(index, 0)
} else {
// Forward hit: pairs from `mid` up to `index` were just checked, so
// the left side only needs its first `mid + 1` elements searched.
// The right side keeps what remains of the old unsearched prefix.
(mid + 1, self.tail - index)
};

// Create the left split before the separator.
let left = Self {
slice: left,
tail: left_tail,
..self
};

// Create the right split following the separator.
let right = Self {
slice: right,
tail: right_tail,
..self
};

(left, Some(right))
} else {
// The search is exhausted, no more separators...
(Self { tail: 0, ..self }, None)
}
}

/// Feeds every chunk of this producer's slice to `folder`, in order, and
/// returns the folder.
fn fold_with<F>(self, mut folder: F) -> F
where
F: Folder<Self::Item>,
{
let Self {
slice, pred, tail, ..
} = self;

// Separate the part that still needs scanning (first element of the
// pair) from the trailing chunk that is already complete (second).
let (slice, tail) = if tail == slice.as_ref().len() {
// No tail section, so just let `consume_iter` do it all.
(Some(slice), None)
} else if let Some(index) = slice.rfind(pred, tail) {
// We found the last separator to complete the tail, so
// end with that slice after `consume_iter` finds the rest.
let (left, right) = slice.split(index);
(Some(left), Some(right))
} else {
// We know there are no separators at all, so it's all "tail".
(None, Some(slice))
};

if let Some(mut slice) = slice {
// TODO (MSRV 1.77) use either:
// folder.consume_iter(slice.chunk_by(pred))
// folder.consume_iter(slice.chunk_by_mut(pred))

// Hand-rolled sequential chunking: each call cuts the next chunk off
// the front of the remaining slice and stops once nothing is left.
folder = folder.consume_iter(std::iter::from_fn(move || {
let len = slice.as_ref().len();
if len > 0 {
// With no boundary left, the rest of the slice is the last chunk.
let i = slice.find(pred, 0, len).unwrap_or(len);
// `mem::take` leaves a default value behind so the slice can be
// split by value; the remainder is stored back right after.
let (head, tail) = mem::take(&mut slice).split(i);
slice = tail;
Some(head)
} else {
None
}
}));
}

// The already-complete trailing chunk, if any, is emitted last.
if let Some(tail) = tail {
folder = folder.consume(tail);
}

folder
}
}

/// Parallel iterator over a slice in (non-overlapping) chunks separated by a predicate.
///
/// Each item is a `&[T]` sub-slice of the original slice.
///
/// This struct is created by the [`par_chunk_by`] method on `&[T]`.
///
/// [`par_chunk_by`]: trait.ParallelSlice.html#method.par_chunk_by
pub struct ChunkBy<'data, T, P> {
/// Predicate tested on adjacent elements to decide where chunks end.
pred: P,
/// The slice being divided into chunks.
slice: &'data [T],
}

impl<'data, T, P: Clone> Clone for ChunkBy<'data, T, P> {
fn clone(&self) -> Self {
ChunkBy {
pred: self.pred.clone(),
slice: self.slice,
}
}
}

/// Prints only the `slice` field; the predicate is left out, so `P` does not
/// need to implement `Debug`.
impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkBy<'data, T, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ChunkBy");
        builder.field("slice", &self.slice);
        builder.finish()
    }
}

impl<'data, T, P> ChunkBy<'data, T, P> {
pub(super) fn new(slice: &'data [T], pred: P) -> Self {
Self { pred, slice }
}
}

impl<'data, T, P> ParallelIterator for ChunkBy<'data, T, P>
where
    T: Sync,
    P: Fn(&T, &T) -> bool + Send + Sync,
{
    type Item = &'data [T];

    /// Wraps the slice in a `ChunkByProducer` and hands it to
    /// `bridge_unindexed` together with `consumer`.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let Self { pred, slice } = self;
        // The whole slice starts out unsearched, hence `tail == len`.
        let producer = ChunkByProducer {
            tail: slice.len(),
            slice,
            pred: &pred,
            marker: PhantomData,
        };
        bridge_unindexed(producer, consumer)
    }
}

/// Parallel iterator over a slice in (non-overlapping) mutable chunks
/// separated by a predicate.
///
/// Each item is a `&mut [T]` sub-slice of the original slice.
///
/// This struct is created by the [`par_chunk_by_mut`] method on `&mut [T]`.
///
/// [`par_chunk_by_mut`]: trait.ParallelSliceMut.html#method.par_chunk_by_mut
pub struct ChunkByMut<'data, T, P> {
/// Predicate tested on adjacent elements to decide where chunks end.
pred: P,
/// The slice being divided into mutable chunks.
slice: &'data mut [T],
}

/// Prints only the `slice` field; the predicate is left out, so `P` does not
/// need to implement `Debug`.
impl<'data, T: fmt::Debug, P> fmt::Debug for ChunkByMut<'data, T, P> {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("ChunkByMut");
        builder.field("slice", &self.slice);
        builder.finish()
    }
}

impl<'data, T, P> ChunkByMut<'data, T, P> {
pub(super) fn new(slice: &'data mut [T], pred: P) -> Self {
Self { pred, slice }
}
}

impl<'data, T, P> ParallelIterator for ChunkByMut<'data, T, P>
where
    T: Send,
    P: Fn(&T, &T) -> bool + Send + Sync,
{
    type Item = &'data mut [T];

    /// Wraps the mutable slice in a `ChunkByProducer` and hands it to
    /// `bridge_unindexed` together with `consumer`.
    fn drive_unindexed<C>(self, consumer: C) -> C::Result
    where
        C: UnindexedConsumer<Self::Item>,
    {
        let Self { pred, slice } = self;
        // The whole slice starts out unsearched, hence `tail == len`.
        let producer = ChunkByProducer {
            tail: slice.len(),
            slice,
            pred: &pred,
            marker: PhantomData,
        };
        bridge_unindexed(producer, consumer)
    }
}
49 changes: 49 additions & 0 deletions src/slice/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
//!
//! [std::slice]: https://doc.rust-lang.org/stable/std/slice/

mod chunk_by;
mod chunks;
mod mergesort;
mod quicksort;
Expand All @@ -22,6 +23,7 @@ use std::cmp::Ordering;
use std::fmt::{self, Debug};
use std::mem;

pub use self::chunk_by::{ChunkBy, ChunkByMut};
pub use self::chunks::{Chunks, ChunksExact, ChunksExactMut, ChunksMut};
pub use self::rchunks::{RChunks, RChunksExact, RChunksExactMut, RChunksMut};

Expand Down Expand Up @@ -173,6 +175,29 @@ pub trait ParallelSlice<T: Sync> {
assert!(chunk_size != 0, "chunk_size must not be zero");
RChunksExact::new(chunk_size, self.as_parallel_slice())
}

/// Returns a parallel iterator over the slice producing non-overlapping runs
/// of elements, using the predicate to decide where one run ends and the
/// next begins.
///
/// The predicate is called on pairs of consecutive elements, that is on
/// `slice[0]` and `slice[1]`, then on `slice[1]` and `slice[2]`, and so on.
/// A new run starts between two elements for which it returns `false`.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let chunks: Vec<_> = [1, 2, 2, 3, 3, 3].par_chunk_by(|&x, &y| x == y).collect();
/// assert_eq!(chunks[0], &[1]);
/// assert_eq!(chunks[1], &[2, 2]);
/// assert_eq!(chunks[2], &[3, 3, 3]);
/// ```
fn par_chunk_by<F>(&self, pred: F) -> ChunkBy<'_, T, F>
where
    F: Fn(&T, &T) -> bool + Send + Sync,
{
    let slice = self.as_parallel_slice();
    ChunkBy::new(slice, pred)
}
}

impl<T: Sync> ParallelSlice<T> for [T] {
Expand Down Expand Up @@ -704,6 +729,30 @@ pub trait ParallelSliceMut<T: Send> {
{
par_quicksort(self.as_parallel_slice_mut(), |a, b| f(a).lt(&f(b)));
}

/// Returns a parallel iterator over the slice producing non-overlapping mutable
/// runs of elements, using the predicate to decide where one run ends and the
/// next begins.
///
/// The predicate is called on pairs of consecutive elements, that is on
/// `slice[0]` and `slice[1]`, then on `slice[1]` and `slice[2]`, and so on.
/// A new run starts between two elements for which it returns `false`.
///
/// # Examples
///
/// ```
/// use rayon::prelude::*;
/// let mut xs = [1, 2, 2, 3, 3, 3];
/// let chunks: Vec<_> = xs.par_chunk_by_mut(|&x, &y| x == y).collect();
/// assert_eq!(chunks[0], &mut [1]);
/// assert_eq!(chunks[1], &mut [2, 2]);
/// assert_eq!(chunks[2], &mut [3, 3, 3]);
/// ```
fn par_chunk_by_mut<F>(&mut self, pred: F) -> ChunkByMut<'_, T, F>
where
    F: Fn(&T, &T) -> bool + Send + Sync,
{
    let slice = self.as_parallel_slice_mut();
    ChunkByMut::new(slice, pred)
}
}

impl<T: Send> ParallelSliceMut<T> for [T] {
Expand Down
46 changes: 46 additions & 0 deletions src/slice/test.rs
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ use rand::distributions::Uniform;
use rand::seq::SliceRandom;
use rand::{thread_rng, Rng};
use std::cmp::Ordering::{Equal, Greater, Less};
use std::sync::atomic::{AtomicUsize, Ordering::Relaxed};

macro_rules! sort {
($f:ident, $name:ident) => {
Expand Down Expand Up @@ -168,3 +169,48 @@ fn test_par_rchunks_exact_mut_remainder() {
assert_eq!(c.take_remainder(), &[]);
assert_eq!(c.len(), 2);
}

#[test]
fn slice_chunk_by() {
    // Neighbours stay in one chunk while they fall on the same side of the
    // `% 10 < 3` test.
    fn same_group(a: &i32, b: &i32) -> bool {
        (a % 10 < 3) == (b % 10 < 3)
    }

    let data: Vec<_> = (0..1000).collect();

    // With fewer than two elements the predicate must never be invoked.
    assert_eq!(data[..0].par_chunk_by(|_, _| todo!()).count(), 0);
    assert_eq!(data[..1].par_chunk_by(|_, _| todo!()).count(), 1);
    assert_eq!(data[..2].par_chunk_by(|_, _| true).count(), 1);
    assert_eq!(data[..2].par_chunk_by(|_, _| false).count(), 2);

    // Every adjacent pair must be tested exactly once, however work is split.
    let calls = AtomicUsize::new(0);
    let parallel: Vec<_> = data
        .par_chunk_by(|a, b| {
            calls.fetch_add(1, Relaxed);
            same_group(a, b)
        })
        .collect();
    assert_eq!(calls.into_inner(), data.len() - 1);

    // The parallel result must match the sequential standard-library one.
    let sequential: Vec<_> = data.chunk_by(same_group).collect();
    assert_eq!(parallel, sequential);
}

#[test]
fn slice_chunk_by_mut() {
    // Neighbours stay in one chunk while they fall on the same side of the
    // `% 10 < 3` test.
    fn same_group(a: &i32, b: &i32) -> bool {
        (a % 10 < 3) == (b % 10 < 3)
    }

    let mut data: Vec<_> = (0..1000).collect();

    // With fewer than two elements the predicate must never be invoked.
    assert_eq!(data[..0].par_chunk_by_mut(|_, _| todo!()).count(), 0);
    assert_eq!(data[..1].par_chunk_by_mut(|_, _| todo!()).count(), 1);
    assert_eq!(data[..2].par_chunk_by_mut(|_, _| true).count(), 1);
    assert_eq!(data[..2].par_chunk_by_mut(|_, _| false).count(), 2);

    // A second copy is needed because the parallel chunks keep `data`
    // mutably borrowed until the final comparison.
    let mut reference = data.clone();

    // Every adjacent pair must be tested exactly once, however work is split.
    let calls = AtomicUsize::new(0);
    let parallel: Vec<_> = data
        .par_chunk_by_mut(|a, b| {
            calls.fetch_add(1, Relaxed);
            same_group(a, b)
        })
        .collect();
    assert_eq!(calls.into_inner(), reference.len() - 1);

    // The parallel result must match the sequential standard-library one.
    let sequential: Vec<_> = reference.chunk_by_mut(same_group).collect();
    assert_eq!(parallel, sequential);
}
1 change: 1 addition & 0 deletions tests/clones.rs
Original file line number Diff line number Diff line change
Expand Up @@ -109,6 +109,7 @@ fn clone_str() {
fn clone_vec() {
let v: Vec<_> = (0..1000).collect();
check(v.par_iter());
check(v.par_chunk_by(i32::eq));
check(v.par_chunks(42));
check(v.par_chunks_exact(42));
check(v.par_rchunks(42));
Expand Down
2 changes: 2 additions & 0 deletions tests/debug.rs
Original file line number Diff line number Diff line change
Expand Up @@ -121,6 +121,8 @@ fn debug_vec() {
let mut v: Vec<_> = (0..10).collect();
check(v.par_iter());
check(v.par_iter_mut());
check(v.par_chunk_by(i32::eq));
check(v.par_chunk_by_mut(i32::eq));
check(v.par_chunks(42));
check(v.par_chunks_exact(42));
check(v.par_chunks_mut(42));
Expand Down