parallel flatMap always sequential

There are two different aspects.

First, there is only a single pipeline, which is either sequential or parallel as a whole. Whether the inner stream is sequential or parallel is irrelevant. Note that the downstream consumer you see in the cited code snippet represents the entire subsequent stream pipeline. In your code, which ends with .collect(Collectors.toSet()), this consumer will eventually add the resulting elements to a single Set instance that is not thread-safe. So processing the inner stream in parallel with that single consumer would break the entire operation.
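The first point can be observed with a small program. This is a sketch that assumes OpenJDK, whose flatMap calls sequential() on each inner stream before pushing its elements downstream. The class and method names are made up for illustration. The outer pipeline is sequential, so the .parallel() requested by the inner stream has no effect and every flattened element is handled by the calling thread.

```java
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

// Hypothetical demo class, not part of any library
public class InnerParallelIgnored {

    // Runs a sequential outer pipeline whose inner stream requests parallelism
    // and returns the names of all threads that handled flattened elements.
    public static Set<String> threadsUsed() {
        Set<String> threads = ConcurrentHashMap.newKeySet(); // thread-safe, in case it were parallel

        Set<Integer> result = Stream.of(1)                   // sequential outer stream, one element
                .flatMap(i -> IntStream.range(0, 10_000)
                        .boxed()
                        .parallel())                         // has no effect: the pipeline is sequential
                .peek(x -> threads.add(Thread.currentThread().getName()))
                .collect(Collectors.toSet());               // one Set, no thread-safety guarantee

        System.out.println(result.size() + " elements handled by " + threads);
        return threads;
    }

    public static void main(String[] args) {
        threadsUsed();
    }
}
```

On OpenJDK this should report a single thread. Making the outer stream parallel would only let different outer elements run on different threads, never parts of one inner stream.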

If the outer stream gets split, the cited code might get invoked concurrently with different consumers adding to different sets. Each of these calls would process a different element of the outer stream, mapping to a different inner stream instance. Since your outer stream consists of a single element only, it can’t be split.
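The splitting behavior can be checked with a parallel outer stream of several elements. In this sketch (class and method names are hypothetical, and OpenJDK's flatMap is assumed), the work may spread across several ForkJoinPool threads, but each outer element's inner stream is still drained by exactly one thread, because each flatMap call pushes its whole inner stream into the consumer of the task that owns that outer element.

```java
import java.util.Map;
import java.util.Set;
import java.util.concurrent.ConcurrentHashMap;
import java.util.stream.IntStream;

// Hypothetical demo class, not part of any library
public class OuterSplitDemo {

    // For each outer element, records which threads produced its inner elements.
    public static Map<Integer, Set<String>> threadsPerOuterElement(int outerSize) {
        Map<Integer, Set<String>> seen = new ConcurrentHashMap<>();
        IntStream.range(0, outerSize).boxed()
                .parallel()                                   // the outer stream may be split
                .flatMap(outer -> IntStream.range(0, 1_000)
                        .mapToObj(inner -> new int[] {outer, inner}))
                .forEach(pair -> seen
                        .computeIfAbsent(pair[0], k -> ConcurrentHashMap.newKeySet())
                        .add(Thread.currentThread().getName()));
        return seen;
    }

    public static void main(String[] args) {
        threadsPerOuterElement(8).forEach((outer, threads) ->
                System.out.println("outer element " + outer + " drained by " + threads));
    }
}
```

How many distinct threads appear overall depends on the machine and pool, but each per-element set should contain a single name.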

The way this has been implemented is also the reason behind the question Why filter() after flatMap() is “not completely” lazy in Java streams? The inner stream is consumed with forEach, which passes all of its elements to the downstream consumer. As demonstrated by this answer, an alternative implementation supporting laziness and substream splitting is possible. But it is a fundamentally different way of implementing it. The current Stream implementation mostly works by consumer composition: in the end, the source spliterator (and those split off from it) receives a Consumer representing the entire stream pipeline, in either tryAdvance or forEachRemaining. In contrast, the solution in the linked answer does spliterator composition, producing a new Spliterator that delegates to the source spliterators. I suppose both approaches have advantages, and I’m not sure how much the OpenJDK implementation would lose by working the other way round.
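To make the contrast concrete, here is a minimal sketch of the spliterator-composition idea. It is not the implementation from the linked answer: the helper lazyFlatMap is hypothetical, it does not attempt to split inner streams, and unlike flatMap it neither closes inner streams nor accepts null from the mapper. It only illustrates laziness: elements are pulled one at a time through tryAdvance, so even infinite inner streams work with a short-circuiting terminal operation.

```java
import java.util.Spliterator;
import java.util.Spliterators;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.function.Consumer;
import java.util.function.Function;
import java.util.stream.Stream;
import java.util.stream.StreamSupport;

// Hypothetical demo class, not part of any library
public class LazyFlatMap {

    // Hypothetical helper: flatMap built by composing spliterators instead of consumers.
    public static <T, R> Stream<R> lazyFlatMap(
            Stream<T> source, Function<? super T, ? extends Stream<? extends R>> mapper) {
        Spliterator<T> outer = source.spliterator();
        Spliterator<R> composed = new Spliterators.AbstractSpliterator<R>(Long.MAX_VALUE, Spliterator.ORDERED) {
            // Spliterator of the inner stream currently being consumed
            Spliterator<? extends R> inner = Spliterators.emptySpliterator();

            @Override
            public boolean tryAdvance(Consumer<? super R> action) {
                // Try to take one element from the current inner source...
                while (!inner.tryAdvance(action)) {
                    // ...and only when it is exhausted, pull the next outer element.
                    boolean hasOuter = outer.tryAdvance(t -> inner = mapper.apply(t).spliterator());
                    if (!hasOuter) {
                        return false;
                    }
                }
                return true;
            }
        };
        return StreamSupport.stream(composed, false);
    }

    // Counts how many elements are pulled before the first match, using infinite inner streams.
    public static int pulledForFirstMatch() {
        AtomicInteger pulled = new AtomicInteger();
        Stream<Integer> flat = lazyFlatMap(Stream.of(1, 2, 3), i -> Stream.iterate(i * 100, x -> x + 1));
        flat.peek(x -> pulled.incrementAndGet())
            .filter(x -> x % 10 == 5)
            .findFirst();
        return pulled.get();
    }

    public static void main(String[] args) {
        System.out.println("pulled " + pulledForFirstMatch() + " elements");
    }
}
```

Here the first inner stream starts at 100, so only the elements 100 to 105 are pulled before findFirst stops. The sketch says nothing about parallel performance; that is what the substream splitting of the linked answer addresses.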


This is for anyone who, like me, has a dire need to parallelize flatMap and wants a practical solution, not only history and theory.

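The simplest solution I came up with is to do the flattening by hand, replacing flatMap with map + reduce(Stream::concat). This helps because the result is a single stream whose spliterator can be split at the boundaries between the nested streams and then within each of them, and Stream.concat yields a parallel stream if either input is parallel.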

Here's an example to demonstrate how to do this:

import static java.lang.Thread.currentThread;

import java.util.Arrays;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ForkJoinPool;
import java.util.stream.Collectors;
import java.util.stream.Stream;

import org.junit.jupiter.api.Test;

// JUnit 5 test class; the class name is only illustrative
class ParallelFlatMapTest {

    @Test
    void testParallelStream_NOT_WORKING() throws InterruptedException, ExecutionException {
        new ForkJoinPool(10).submit(() -> {
            Stream.iterate(0, i -> i + 1).limit(2)
                    .parallel()

                    // does not parallelize nested streams: flatMap drains each one sequentially
                    .flatMap(i -> generateRangeParallel(i, 100))

                    .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                    .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
        }).get();
        System.out.println("done");
    }

    @Test
    void testParallelStream_WORKING() throws InterruptedException, ExecutionException {
        new ForkJoinPool(10).submit(() -> {
            Stream.iterate(0, i -> i + 1).limit(2)
                    .parallel()

                    // concatenation of nested streams instead of flatMap, parallelizes ALL the items
                    .map(i -> generateRangeParallel(i, 100))
                    .reduce(Stream::concat).orElse(Stream.empty())

                    .peek(i -> System.out.println(currentThread().getName() + " : generated value: i=" + i))
                    .forEachOrdered(i -> System.out.println(currentThread().getName() + " : received value: i=" + i));
        }).get();
        System.out.println("done");
    }

    // Produces num consecutive integers starting at start, as a parallel stream
    Stream<Integer> generateRangeParallel(int start, int num) {
        return Stream.iterate(start, i -> i + 1).limit(num).parallel();
    }

    // Run this method with the captured console output to see how the work was distributed:
    // it counts the printed lines per thread name (the first token of each line)
    void countThreads(String strOut) {
        var res = Arrays.stream(strOut.split("\\R"))
                .filter(line -> line.contains(" : "))   // skip the trailing "done" line and blank lines
                .map(line -> line.split("\\s+"))
                .collect(Collectors.groupingBy(s -> s[0], Collectors.counting()));
        System.out.println(res);
        System.out.println("threads  : " + res.size());
        System.out.println("work     : " + res.values());
    }
}

Stats from a run on my machine, as printed by countThreads:

NOT_WORKING case stats:
{ForkJoinPool-1-worker-23=100, ForkJoinPool-1-worker-5=300}
threads  : 2
work     : [100, 300]

WORKING case stats:
{ForkJoinPool-1-worker-9=16, ForkJoinPool-1-worker-23=20, ForkJoinPool-1-worker-21=36, ForkJoinPool-1-worker-31=17, ForkJoinPool-1-worker-27=177, ForkJoinPool-1-worker-13=17, ForkJoinPool-1-worker-5=21, ForkJoinPool-1-worker-19=8, ForkJoinPool-1-worker-17=21, ForkJoinPool-1-worker-3=67}
threads  : 10
work     : [16, 20, 36, 17, 177, 17, 21, 8, 21, 67]
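One caveat with map + reduce(Stream::concat): the Stream.concat documentation warns that accessing an element of a deeply concatenated stream can result in deep call chains or even a StackOverflowError. A sequential reduce builds a left-deep chain as long as the number of inner streams, and a parallel reduce still folds each of its chunks that way before combining them. With many inner streams, one option is to concatenate them as a balanced tree, as in this sketch (concatBalanced and the surrounding class are hypothetical, not JDK methods). Since Stream.concat returns a parallel stream if either input is parallel, parallel inner streams keep the result parallel.

```java
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
import java.util.stream.Stream;

// Hypothetical demo class, not part of any library
public class BalancedConcat {

    // Hypothetical helper: concatenates the parts pairwise as a balanced tree,
    // so the nesting depth grows with log2(parts.size()) rather than parts.size().
    public static <T> Stream<T> concatBalanced(List<Stream<T>> parts) {
        if (parts.isEmpty()) {
            return Stream.empty();
        }
        if (parts.size() == 1) {
            return parts.get(0);
        }
        int mid = parts.size() / 2;
        return Stream.concat(concatBalanced(parts.subList(0, mid)),
                             concatBalanced(parts.subList(mid, parts.size())));
    }

    // Flattens outer parallel ranges of inner elements each, then doubles every value.
    public static List<Integer> flattenRanges(int outer, int inner) {
        List<Stream<Integer>> parts = IntStream.range(0, outer)
                .mapToObj(i -> IntStream.range(i * inner, (i + 1) * inner).boxed().parallel())
                .collect(Collectors.toList());
        return concatBalanced(parts)
                .map(x -> x * 2)
                .collect(Collectors.toList()); // encounter order kept: concat of ordered streams is ordered
    }

    public static void main(String[] args) {
        List<Integer> result = flattenRanges(1_000, 10);
        System.out.println(result.size() + " elements, last = " + result.get(result.size() - 1));
    }
}
```

The balanced shape also gives the concatenated spliterator reasonably even split points, whereas a left-deep chain first splits off one large prefix.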