parallelStream 底层 ForkJoinPool 实现

​ForkJoinPool源码解析

前言

Java 8Stream是对集合(Collection)对象功能的增强,其特性之一提供了流的并行处理 -> parallelStream。本篇来分析下项目中经常使用的parallelStream底层实现机制。

正文

以main函数为切入点分析, 采用parallelStream来处理集合数据。

public static void main(String[] args) {

    List<Integer> list = new ArrayList<>();
    list.addAll(Arrays.asList(1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 16));
    //1.1 .parallelStream()
    //1.2 .forEach(e ->{})
    list.parallelStream().forEach(e -> {
        try {
            Thread.currentThread().getThreadGroup().list();
            System.out.println("执行线程名:"+Thread.currentThread() + "  id " + e);
        } catch (InterruptedException h) {
            h.printStackTrace();
        }
    });
}

1.1 list调用 parallelStream() 返回一个并行流。

default Stream<E> parallelStream() {
    return StreamSupport.stream(spliterator(), true);
}

spliterator()方法返回一个迭代器对象。Spliterator是java 8引入的接口,用于并行遍历和划分源元素的迭代器。

//ArrayList类:
@Override
public Spliterator<E> spliterator() {
    return new ArrayListSpliterator<>(this, 0, -1, 0);
}

1.2 forEach()用于遍历集合数据,该方法用于接收一个Consumer接口函数。Consumer是java 8提供的函数式接口之一。

Stream接口:

void forEach(Consumer<? super T> action);

Java 8函数式接口

函数式接口定义:有且仅有一个抽象方法

常用函数接口:

  • Consumer 消费接口,接收T(泛型)类型参数,没有返回值。

  • Function 该接口的抽象函数式有返回值。

  • Predicate 返回值固定的boolean。

从forEach出发

//流的源阶段
static class Head<E_IN, E_OUT> extends ReferencePipeline<E_IN, E_OUT> {

    @Override
    public void forEach(Consumer<? super E_OUT> action) {
        if (!isParallel()) {
            sourceStageSpliterator().forEachRemaining(action);
        }
        else {
            //并行实现
            super.forEach(action);
        }
    }
}
//流的源阶段与中间阶段
abstract class ReferencePipeline<P_IN, P_OUT>
        extends AbstractPipeline<P_IN, P_OUT, Stream<P_OUT>>
        implements Stream<P_OUT>  {
    @Override
    public void forEach(Consumer<? super P_OUT> action) {
        //makeRef方法构造一个TerminalOp对象,该操作对流的每个元素执行操作。
        evaluate(ForEachOps.makeRef(action, false));
    }  
}
//使用终端操作评估管道以生成结果。
final <R> R evaluate(TerminalOp<E_OUT, R> terminalOp) {
    assert getOutputShape() == terminalOp.inputShape();
    if (linkedOrConsumed)
        throw new IllegalStateException(MSG_STREAM_LINKED);
    //标记管道已使用
    linkedOrConsumed = true;

    return isParallel()
           ? terminalOp.evaluateParallel(this, sourceSpliterator(terminalOp.getOpFlags()))
           : terminalOp.evaluateSequential(this, sourceSpliterator(terminalOp.getOpFlags()));
}
//静态抽象类:ForEachOp
@Override
public <S> Void evaluateParallel(PipelineHelper<T> helper,
                                 Spliterator<S> spliterator) {
    if (ordered)
        new ForEachOrderedTask<>(helper, spliterator, this).invoke();
    else
        //1.3 用于执行并行操作的任务
        new ForEachTask<>(helper, spliterator, helper.wrapSink(this)).invoke();
    return null;
}

forEach阶段简单的介绍了下,重点在下文。

来张ForkJoinPool运行流程图:

parallelStream 底层 ForkJoinPool 实现

正式分析ForkJoinPool源码之前,先看相关类的属性定义

ForkJoinPool线程池状态定义:

public class ForkJoinPool extends AbstractExecutorService {
    private static final int  RSLOCK     = 1; //锁定态
    private static final int  RSIGNAL    = 1 << 1; //唤醒态
    private static final int  STARTED    = 1 << 2; //开始态
    private static final int  STOP       = 1 << 29; //停止态
    private static final int  TERMINATED = 1 << 30; //终止态
    private static final int  SHUTDOWN   = 1 << 31; //关闭态
    //线程池状态
    volatile int runState;
    
    //线程池的控制变量字段
    /**
     * long类型ctl字段为64位,包含4个16位子字段
     * AC: 活动线程数 49-64位
     * TC: 总线程数 33-48位
     * SS: 顶部等待线程的版本计数和状态(第32位为1表示inative, 其余15位表示版本计数) 17-32位
     * ID: WorkQueue在WorkQueue[]中索引位置 1-16位
     * SP: 低32位。
     *
     * AC为负:没有足够活跃的线程
     * TC为负:总线程数不足
     * SP非0:有等待的线程。
     **/
    volatile long ctl;
    
}

parallelStream 底层 ForkJoinPool 实现

ctl图示

WorkQueue属性定义:

static final class WorkQueue {
    ......
    //扫描状态,<0:不活跃;奇数:扫描任务;偶数:执行任务
    volatile int scanState; 
    //线程探针值,可用于计算workqueue队列在数组中槽位
    int hint; 
    //高16位记录队列模式(FIFO、LIFO),低16位记录工作队列数组槽位(workqueue在数组中索引下标)
    int config;
    //队列状态,1:锁定;0:未锁定;<0:终止
    volatile int qlock;
    //下一个poll操作的索引(栈底/队列头),线程从base端窃取任务
    volatile int base;         // index of next slot for poll
    //下一个push操作的索引(栈顶/队列尾)
    int top;                   // index of next slot for push
    //任务数组
    ForkJoinTask<?>[] array;
    //队列所属线程池(可能为空)
    final ForkJoinPool pool;
    //当前WorkQueue所属工作线程,如果是外部提交的任务,则owner为空
    final ForkJoinWorkerThread owner; // owning thread or null if shared
    //当前正在进行join等待的其它任务
    volatile ForkJoinTask<?> currentJoin;  // task being joined in awaitJoin
    //当前正在偷取的任务
    volatile ForkJoinTask<?> currentSteal; // mainly used by helpStealer
    ......
}

config字段高16位队列模式图:

parallelStream 底层 ForkJoinPool 实现

ForkJoinTask任务状态定义:

public abstract class ForkJoinTask<V> implements Future<V>, Serializable {
    //任务运行状态,初始为0
    volatile int status; 
    static final int DONE_MASK   = 0xf0000000;  //任务状态掩码
    static final int NORMAL      = 0xf0000000;  //完成状态
    static final int CANCELLED   = 0xc0000000;  //取消状态
    static final int EXCEPTIONAL = 0x80000000;  //异常状态
    static final int SIGNAL      = 0x00010000;  //信号状态,用于唤醒阻塞线程
    static final int SMASK       = 0x0000ffff;  //低位掩码
}

ForkJoinTask初探

1.3 跟踪 invoke方法

//ForkJoinTask抽象类
public final V invoke() {
    int s;
    //获取任务执行状态
    if ((s = doInvoke() & DONE_MASK) != NORMAL)
        reportException(s);
    //计算结果
    return getRawResult();
}
    
private int doInvoke() {
    int s; Thread t; ForkJoinWorkerThread wt;
    return (s = doExec()) < 0 ? s :
        ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) ?
        (wt = (ForkJoinWorkerThread)t).pool.
        //1.4 帮助执行或阻塞线程
        awaitJoin(wt.workQueue, this, 0L) :
        //阻塞一个非工作线程直到完成
        externalAwaitDone();
}

externalAwaitDone()方法这里就不贴代码了,简单说下逻辑。后面扫描任务的scan方法有详解,流程差不多。

外部线程,也就是前面例子中的main线程执行完doExec()完成任务拆分逻辑后,此时任务状态还不是完成态,则由main线程帮助执行子任务。

main线程先执行自己初始化创建队列中的任务,后扫描其它工作队列Workqueue中的任务。

上面流程执行完后,判断任务状态是否为完成态,如果不是完成态,说明有工作线程在执行任务,cas操作设置状态为唤醒态,然后main线程调用wait方法阻塞自己,等其它线程唤醒。工作线程任务执行完后会调用notifyAll唤醒主线程。最终任务全部执行完毕返回。

1.4  java.util.concurrent.ForkJoinPool#awaitJoin:

//帮助或阻塞,直到给定的任务完成或超时
final int awaitJoin(WorkQueue w, ForkJoinTask<?> task, long deadline) {
    int s = 0;
    if (task != null && w != null) {
        //记录前一个等待的任务
        ForkJoinTask<?> prevJoin = w.currentJoin;
        //设置当前任务
        U.putOrderedObject(w, QCURRENTJOIN, task);
        CountedCompleter<?> cc = (task instanceof CountedCompleter) ?
            (CountedCompleter<?>)task : null;
        for (;;) {
            //任务执行完
            if ((s = task.status) < 0)
                break;
            if (cc != null)
                //帮助完成任务,先执行自己队列任务,再扫描其它队列任务
                helpComplete(w, cc, 0);
            //如果队列为空,说明任务被偷
            else if (w.base == w.top || w.tryRemoveAndExec(task))
                //遍历队列找到当前偷取此任务的队列,执行偷取者队列的任务
                helpStealer(w, task);
            if ((s = task.status) < 0)
                break;
            long ms, ns;
            if (deadline == 0L)
                ms = 0L;
            else if ((ns = deadline - System.nanoTime()) <= 0L)
                break;
            else if ((ms = TimeUnit.NANOSECONDS.toMillis(ns)) <= 0L)
                ms = 1L;
            //补偿措施,唤醒线程或创建线程执行任务
            if (tryCompensate(w)) {
                task.internalWait(ms);
                U.getAndAddLong(this, CTL, AC_UNIT);
            }
        }
        U.putOrderedObject(w, QCURRENTJOIN, prevJoin);
    }
    return s;
}

跟踪doExec方法,执行路径:ForkJoinTask#doExec -> CountedCompleter#exec -> ForEachTask#compute。

来看看compute()方法做了哪些事情。

//ForEachTask静态类
public void compute() {
    //ArrayList.ArrayListSpliterator迭代器
    Spliterator<S> rightSplit = spliterator, leftSplit;
    //预估值,初始化时为集合总数,后面为任务分割后子任务内元素和
    long sizeEstimate = rightSplit.estimateSize(), sizeThreshold;
    if ((sizeThreshold = targetSize) == 0L)
        //目标值 = 预估值 / (总线程数 * 4)
        targetSize = sizeThreshold = AbstractTask.suggestTargetSize(sizeEstimate);
    boolean isShortCircuit = StreamOpFlag.SHORT_CIRCUIT.isKnown(helper.getStreamAndOpFlags());
    boolean forkRight = false;
    Sink<S> taskSink = sink;
    ForEachTask<S, T> task = this;
    //任务切分逻辑
    while (!isShortCircuit || !taskSink.cancellationRequested()) {
        // 切分直至子任务的大小小于阈值
        if (sizeEstimate <= sizeThreshold ||
            //trySplit()将rightSplit平均分,返回平分后左边的任务
            //待切分任务小于等于1,停止切分
            (leftSplit = rightSplit.trySplit()) == null) {
            //执行函数式接口
            task.helper.copyInto(taskSink, rightSplit);
            break;
        }
        ForEachTask<S, T> leftTask = new ForEachTask<>(task, leftSplit);
        //原子更新挂起任务数量
        task.addToPendingCount(1);
        ForEachTask<S, T> taskToFork;
        if (forkRight) {
            forkRight = false;
            rightSplit = leftSplit;
            taskToFork = task;
            task = leftTask;
        }
        else {
            forkRight = true;
            taskToFork = leftTask;
        }
        //将子任务提交到线程池ForkJoinPool
        taskToFork.fork();
        sizeEstimate = rightSplit.estimateSize();
    }
    task.spliterator = null;
    //如果挂起计数不为0,则递减。最终任务处理完后设置状态为完成态,并唤醒阻塞的线程
    task.propagateCompletion();
}

compute方法主要逻辑就是达到拆分条件就平均拆分任务,待切分任务小于1时,停止切分,去执行任务。对应例子中的consumer函数式接口。

拆分后的任务会提交到队列中,由线程获取任务后继续调用 compute() 来判断拆分还是执行。

继续看fork逻辑:

//ForkJoinTask抽象类
public final ForkJoinTask<V> fork() {
    Thread t;
    //如果当前线程为ForkJoin工作线程,则任务放入自身工作线程队列
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread)
        ((ForkJoinWorkerThread)t).workQueue.push(this);
    else
        //如果是外部线程池调用,则任务push到公共的线程池执行,由于本文是main函数开头,所以会执行此处
        ForkJoinPool.common.externalPush(this);
    return this;
}

ForkJoinPool类解析

来看下公共线程池的创建:
ForkJoinPool.common = makeCommonPool()
ForkJoinPool类初始化时会执行静态代码块,静态代码块中会执行makeCommonPool方法,返回ForkJoinPool对象实例

//ForkJoinPool类

//创建并返回一个公共线程池
private static ForkJoinPool makeCommonPool() {
    int parallelism = -1;
    ForkJoinWorkerThreadFactory factory = null;
    ......
    if (fp != null)
        factory = ((ForkJoinWorkerThreadFactory)ClassLoader.
                   getSystemClassLoader().loadClass(fp).newInstance());

    if (parallelism < 0 && // default 1 less than #cores
        //线程池数量,默认cpu核数-1, 用户也可在环境变量中指定数量
        (parallelism = Runtime.getRuntime().availableProcessors() - 1) <= 0)
        parallelism = 1;
    if (parallelism > MAX_CAP)
        parallelism = MAX_CAP;
    return new ForkJoinPool(parallelism, factory, handler, LIFO_QUEUE,
                            "ForkJoinPool.commonPool-worker-");
}

回到externalPush方法:

//外部任务提交
final void externalPush(ForkJoinTask<?> task) {
    WorkQueue[] ws; WorkQueue q; int m;
    //探针值,用于计算WorkQueue槽位索引。
    //ThreadLocalRandom原理:获取随机数时,每个线程获取自己初始化种子,避免多线程使用同一个原子种子变量,从而导致原子变量的竞争。
    int r = ThreadLocalRandom.getProbe();
    int rs = runState;
    //外部线程第一次进入,此时workQueues还未初始化,if代码不会执行
    if ((ws = workQueues) != null && (m = (ws.length - 1)) >= 0 &&
        //这里会获取偶数槽位,外部提交的任务会放在WorkQueue数组的偶数槽位。后面会配图说明
        (q = ws[m & r & SQMASK]) != null && r != 0 && rs > 0 &&
        //加锁
        U.compareAndSwapInt(q, QLOCK, 0, 1)) {
        ForkJoinTask<?>[] a; int am, n, s;
        if ((a = q.array) != null &&
            (am = a.length - 1) > (n = (s = q.top) - q.base)) {
            //计算任务在内存中偏移量
            int j = ((am & s) << ASHIFT) + ABASE;
            //向队列放入任务
            U.putOrderedObject(a, j, task);
            //top值 +1
            U.putOrderedInt(q, QTOP, s + 1);
            //解锁
            U.putIntVolatile(q, QLOCK, 0);
            //任务数小于等于1,尝试创建或激活工作线程
            if (n <= 1)
                signalWork(ws, q);
            return;
        }
        U.compareAndSwapInt(q, QLOCK, 1, 0);
    }
    externalSubmit(task);
  }

//完整版externalPush,用于将任务提交到池中,如果是第一次提交,则初始化工作队列
private void externalSubmit(ForkJoinTask<?> task) {
    int r; //线程探针值
    if ((r = ThreadLocalRandom.getProbe()) == 0) {
        //初始化线程字段,生成原子种子变量,获取随机数
        ThreadLocalRandom.localInit();
        r = ThreadLocalRandom.getProbe();
    }
    for (;;) {
        WorkQueue[] ws; WorkQueue q; int rs, m, k;
        boolean move = false;
        //如果线程池为终止状态,则抛出拒绝执行异常
        if ((rs = runState) < 0) {
            tryTerminate(false, false);     // help terminate
            throw new RejectedExecutionException();
        }
        //线程池状态不是开始状态,则状态加锁,初始化WorkQueue 数组
        else if ((rs & STARTED) == 0 ||     // initialize
                 ((ws = workQueues) == null || (m = ws.length - 1) < 0)) {
            int ns = 0;
            //加锁
            rs = lockRunState();
            try {
                if ((rs & STARTED) == 0) {
                    //通过原子操作,完成窃取任务总数这个计数器的初始化
                    U.compareAndSwapObject(this, STEALCOUNTER, null,
                                           new AtomicLong());
                    // create workQueues array with size a power of two
                    //获取并行度config的值,前面说过config值为cpuh核数-1,比如我的系统是8核,那这里p=8-1=7。
                    int p = config & SMASK; // ensure at least 2 slots
                    int n = (p > 1) ? p - 1 : 1;
                    n |= n >>> 1; n |= n >>> 2;  n |= n >>> 4;
                    n |= n >>> 8; n |= n >>> 16; n = (n + 1) << 1;
                    //数组大小这里可以找到规律。找到大于2倍p的值最近的2^m。由于最后 (n+1)左移了一位,所以n最小值为4。
                    //举例:p = 7, 那么 2p = 14, 大于2p的2^m = 16。所以这里m=4, n=16。
                    //假如p=3, 2p=6,2^m=8 > 6。 则m=3, n=8。 
                    workQueues = new WorkQueue[n]; //由于我这里p=7, 所以初始化队列数组长度为16。
                    ns = STARTED;
                }
            } finally {
                //解锁,状态修改为开始态
                unlockRunState(rs, (rs & ~RSLOCK) | ns);
            }
        }
        //判断偶数槽位是否为空
        else if ((q = ws[k = r & m & SQMASK]) != null) {
            //qlock属性定义,初始值0,1:锁定,< 0:终止
            //队列加锁
            if (q.qlock == 0 && U.compareAndSwapInt(q, QLOCK, 0, 1)) {
                //任务数组
                ForkJoinTask<?>[] a = q.array;
                int s = q.top;
                boolean submitted = false;
                try {
                    if ((a != null && a.length > s + 1 - q.base) ||
                        //growArray()逻辑:任务数组为空,则创建数组,初始大小为8192,不为空,则2倍扩容。任务移到新数组。
                        (a = q.growArray()) != null) {
                        //计算top在内存中偏移量
                        int j = (((a.length - 1) & s) << ASHIFT) + ABASE; 
                        //任务存储到任务数组
                        U.putOrderedObject(a, j, task);
                        //更新top值。top定义:push操作的索引(栈顶)
                        U.putOrderedInt(q, QTOP, s + 1);
                        submitted = true;
                    }
                } finally {
                    //队列解锁
                    U.compareAndSwapInt(q, QLOCK, 1, 0);
                }
                if (submitted) {
                    //1.5 创建或激活工作线程
                    signalWork(ws, q);
                    return;
                }
            }
            move = true;                   // move on failure
        }
        //队列数组槽位为空,则创建一个WorkQueue放入数组
        else if (((rs = runState) & RSLOCK) == 0) {
            //外部线程所属队列是共享队列,参数owner传入null
            q = new WorkQueue(this, null);
            //记录这个探针值
            q.hint = r;
            //记录队列数组槽位和队列模式
            q.config = k | SHARED_QUEUE;
            //共享队列初始化为inactive状态
            q.scanState = INACTIVE;
            rs = lockRunState();           // publish index
            if (rs > 0 &&  (ws = workQueues) != null &&
                k < ws.length && ws[k] == null)
                ws[k] = q;
            unlockRunState(rs, rs & ~RSLOCK);
        }
        else
            move = true;
        if (move)
            r = ThreadLocalRandom.advanceProbe(r);
    }
  }

externalSubmit方法三部曲:

  • 初始化工作队列数组 WorkQueue[],

  • 创建工作队列 WorkQueue,

  • 任务放入任务数组 ForkJoinTask[], 更新top值, 创建线程执行任务。

画个图帮助理解任务提交,看到这的点个赞把☺(* ̄︶ ̄)

parallelStream 底层 ForkJoinPool 实现

m & r & SQMASK必为偶数,所以通过externalSubmit方法提交的任务都添加到了偶数索引的任务队列中(没有绑定的工作线程)。

1.5  java.util.concurrent.ForkJoinPool#signalWork:

工作线程数不足:创建一个工作线程;

工作线程数足够:唤醒一个空闲(阻塞)的工作线程。

//创建或激活工作线程
final void signalWork(WorkQueue[] ws, WorkQueue q) {
    long c; int sp, i; WorkQueue v; Thread p;
    //ctl小于0时,活跃的线程数不够
    while ((c = ctl) < 0L) {
        //取ctl的低32位,如果为0,说明没有等待的线程
        if ((sp = (int)c) == 0) { 
            //取tc的高位,不为0说明总线程数不够,创建线程
            if ((c & ADD_WORKER) != 0L)
                tryAddWorker(c);
            break;
        }
        if (ws == null) 
            break;
        if (ws.length <= (i = sp & SMASK))
            break;
        if ((v = ws[i]) == null)    
            break;
        int vs = (sp + SS_SEQ) & ~INACTIVE; 
        int d = sp - v.scanState;
        long nc = (UC_MASK & (c + AC_UNIT)) | (SP_MASK & v.stackPred);
        if (d == 0 && U.compareAndSwapLong(this, CTL, c, nc)) {
            v.scanState = vs;
            if ((p = v.parker) != null)
                //唤醒线程
                U.unpark(p);
            break;
        }
        if (q != null && q.base == q.top)
            break;
    }
}

private void tryAddWorker(long c) {
    boolean add = false;
    do {
        //活跃线程和总线程数都 +1
        long nc = ((AC_MASK & (c + AC_UNIT)) |
                   (TC_MASK & (c + TC_UNIT)));
        if (ctl == c) {
            int rs, stop; 
            //加锁,并判断线程池状态是否为停止态
            if ((stop = (rs = lockRunState()) & STOP) == 0)
                //cas更新ctl的值
                add = U.compareAndSwapLong(this, CTL, c, nc);
            //解锁
            unlockRunState(rs, rs & ~RSLOCK);
            if (stop != 0)
                break;
            if (add) {
                //1.6 创建新的工作线程
                createWorker();
                break;
            }
        }
      //ADD_WORKER的第48位是1,其余位都为0,与ctl位与判断TC总线程是否已满,如果线程满了, 那么ctl的第48位为0,0 & ADD_WORKER = 0
      //(int)c == 0 表示截取ctl的低32位(int是4字节,1字节是8位)
    } while (((c = ctl) & ADD_WORKER) != 0L && (int)c == 0);
}

通过ctl的高32位来判断是否需要创建线程或激活线程。

上文有讲,公共线程池数量为cpu核数-1,我电脑是8核,所以总线程数为7。也就是并行度为7。高32位 +7最终等于0。

parallelStream 底层 ForkJoinPool 实现

1.6  java.util.concurrent.ForkJoinPool#createWorker:

private boolean createWorker() {
    ForkJoinWorkerThreadFactory fac = factory;
    Throwable ex = null;
    ForkJoinWorkerThread wt = null;
    try {
        //使用线程工厂创建工作线程
        if (fac != null && (wt = fac.newThread(this)) != null) {
            //启动线程
            wt.start();
            return true;
        }
    } catch (Throwable rex) {
        ex = rex;
    }
    deregisterWorker(wt, ex);
    return false;
}

线程工厂创建线程。

外部线程提交任务后创建线程,线程扫描到任务后会继续创建线程,直到达到总线程数。

本地测试中某一次线程创建流程图如下,可以看到main线程创建了worker-1和worker-3工作线程,worker-1又创建了worker-2和worker-4工作线程。都属于同一个main线程组。

parallelStream 底层 ForkJoinPool 实现

ForkJoinWorkerThread

public void run() {
//线程初次运行任务数组为空
if (workQueue.array == null) { // only run once
    Throwable exception = null;
    try {
        onStart();
        pool.runWorker(workQueue);
    } catch (Throwable ex) {
        exception = ex;
    } finally {
        try {
            onTermination(exception);
        } catch (Throwable ex) {
            if (exception == null)
                exception = ex;
        } finally {
            pool.deregisterWorker(this, exception);
        }
    }
}
}
java.util.concurrent.ForkJoinPool#runWorkerfinal void runWorker(WorkQueue w) {
    //初始化或两倍扩容
    w.growArray();     
    //获取创建工作队列的线程探针值
    int seed = w.hint; 
    int r = (seed == 0) ? 1 : seed;  
    for (ForkJoinTask<?> t;;) {
        //扫描任务
        if ((t = scan(w, r)) != null)
            //1.7 执行获取到的任务
            w.runTask(t);
        else if (!awaitWork(w, r))
            break;
        r ^= r << 13; r ^= r >>> 17; r ^= r << 5; // xorshift
    }
}

private ForkJoinTask<?> scan(WorkQueue w, int r) {
    WorkQueue[] ws; int m;
    //判断任务队列数组是否为空
    if ((ws = workQueues) != null && (m = ws.length - 1) > 0 && w != null) {
        //获取扫描状态
        int ss = w.scanState; 
        for (int origin = r & m, k = origin, oldSum = 0, checkSum = 0;;) {
            WorkQueue q; ForkJoinTask<?>[] a; ForkJoinTask<?> t;
            int b, n; long c;
            //如果k槽位不为空
            if ((q = ws[k]) != null) {
                if ((n = (b = q.base) - q.top) < 0 &&
                    (a = q.array) != null) {
                    //获取base的偏移量
                    long i = (((a.length - 1) & b) << ASHIFT) + ABASE;
                    //FIFO模式,从base端获取任务
                    if ((t = ((ForkJoinTask<?>)
                              U.getObjectVolatile(a, i))) != null &&
                        q.base == b) {
                        //判断线程是否活跃状态
                        if (ss >= 0) {
                            //任务被偷了。更新任务数组a在内存中偏移量为base位置的值为空
                            if (U.compareAndSwapObject(a, i, t, null)) {  
                                //更新base值
                                q.base = b + 1;
                                if (n < -1)
                                    //数组还有任务,创建或激活线程执行任务
                                    signalWork(ws, q);
                                //返回扫描到的任务
                                return t;
                            }
                        }
                        else if (oldSum == 0 && 
                                 w.scanState < 0)
                            tryRelease(c = ctl, ws[m & (int)c], AC_UNIT);
                    }
                    if (ss < 0) 
                        ss = w.scanState;
                    r ^= r << 1; r ^= r >>> 3; r ^= r << 10;
                    origin = k = r & m; 
                    oldSum = checkSum = 0;
                    continue;
                }
                checkSum += b;
            }
            //m二进制位都为1,(k+1)&m 会遍历工作队列数组所有槽位
            if ((k = (k + 1) & m) == origin) {
                if ((ss >= 0 || (ss == (ss = w.scanState))) &&
                    oldSum == (oldSum = checkSum)) {
                    if (ss < 0 || w.qlock < 0)
                        break;
                    int ns = ss | INACTIVE;
                    //活跃线程-1
                    long nc = ((SP_MASK & ns) |
                               (UC_MASK & ((c = ctl) - AC_UNIT)));
                    w.stackPred = (int)c;
                    U.putInt(w, QSCANSTATE, ns);
                    if (U.compareAndSwapLong(this, CTL, c, nc))
                        ss = ns;
                    else
                        w.scanState = ss;         // back out
                }
                checkSum = 0;
            }
        }
    }
    return null;
}

scan扫描任务,从其它队列的base端偷取。

base属性加了volatile关键字,保证了共享变量的内存可见性。获取任务后通过cas乐观锁操作将被窃取的队列中任务置位空。

此窃取机制减少了top和base端的竞争(队列线程和窃取线程分别从top和base端操作),用cas操作也提高了效率。

比如外部线程分割了3个任务,那么top+3=4099,三个任务放在内存偏移量为4096、4097、4098的位置。启动的工作线程从base端偷取任务。如下图:

parallelStream 底层 ForkJoinPool 实现

1.7  java.util.concurrent.ForkJoinPool.WorkQueue#runTask:

final void runTask(ForkJoinTask<?> task) {
    if (task != null) {
        //此时scanState变成了偶数,表示正在执行任务
        scanState &= ~SCANNING; // mark as busy
        //执行窃取到的任务,doExec上面有讲,会在compute方法中执行任务或分割任务,分割的任务放入自己队列的任务数组里。
        (currentSteal = task).doExec();
        //将窃取到的任务字段置空
        U.putOrderedObject(this, QCURRENTSTEAL, null); // release for GC
        //执行自个任务数组里任务。任务来源就是上面分割后的子任务。
        execLocalTasks();
        ForkJoinWorkerThread thread = owner;
        //窃取任务数+1
        if (++nsteals < 0)      // collect on overflow
            //任务数累加到pool的stealCounter字段中
            transferStealCount(pool);
        //恢复扫描状态
        scanState |= SCANNING;
        if (thread != null)
            //钩子方法
            thread.afterTopLevelExec();
    }
}

final void execLocalTasks() {
    int b = base, m, s;
    ForkJoinTask<?>[] a = array;
    //任务数组中有任务
    if (b - (s = top - 1) <= 0 && a != null &&
        (m = a.length - 1) >= 0) {
        //队列模式,判断是否为后进先出模式 LIFO。
        if ((config & FIFO_QUEUE) == 0) {
            for (ForkJoinTask<?> t;;) {
                //从top端取任务
                if ((t = (ForkJoinTask<?>)U.getAndSetObject
                     (a, ((m & s) << ASHIFT) + ABASE, null)) == null)
                    break;
                //修改内存中top值
                U.putOrderedInt(this, QTOP, s);
                //执行任务
                t.doExec();
                if (base - (s = top - 1) > 0)
                    break;
            }
        }
        else
            //先进先出模式 FIFO, 轮询从base端取任务运行,直到为空
            pollAndExecAll();
    }
}


runTask方法主要逻辑就是将扫描到的任务执行,执行过程中可能会继续分割。执行完窃取的任务后,就执行自个队列里的任务。

至此,ForkJoinPool就分析到这。

总结:

大概可以用一句话总结ForkJoinPool原理:一个可以并行执行任务的线程池,可以处理一个可递归划分的任务并获取结果。

参考博文:

[1]  https://blog.csdn.net/qq_27785239/article/details/103395079
[2]  https://blog.csdn.net/yusimiao/article/details/114009972

上一篇:ThreadPoolExecutor 中为什么WorkQueue会在corePoolSize满了之后入队


下一篇:forkJoin源码解读