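Once you understand how the Fork/Join framework works, reading its source code is much less daunting. The framework has three core classes, ForkJoinTask, ForkJoinWorkerThread and ForkJoinPool, plus two subclasses of ForkJoinTask, RecursiveTask and RecursiveAction. (The listings below come from the JDK 7 implementation; later JDKs rewrote these classes substantially.)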
I. Class Diagram
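As the class diagram shows, ForkJoinPool is an ExecutorService, so every ExecutorService method applies to it, and ForkJoinTask implements Future, so it implements every method Future defines. ForkJoinWorkerThread is a subclass of Thread, a specialized thread implementation. RecursiveTask and RecursiveAction both extend ForkJoinTask: RecursiveTask represents a task that returns a result, and RecursiveAction a task that does not. In practice you usually extend one of these two classes and implement its compute method.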
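Following that pattern, here is a minimal sketch of a RecursiveAction built on the public java.util.concurrent API: it doubles every element of an array in place, halving the range until it is small enough to process sequentially. The class name DoubleInPlace and the THRESHOLD value are illustrative choices, not part of the framework.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveAction;

// A RecursiveAction produces no result: this one doubles every element of a[lo, hi) in place.
class DoubleInPlace extends RecursiveAction {
    static final int THRESHOLD = 1_000; // illustrative cutoff for sequential work
    final long[] a;
    final int lo, hi;

    DoubleInPlace(long[] a, int lo, int hi) {
        this.a = a;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected void compute() {
        if (hi - lo <= THRESHOLD) {
            for (int i = lo; i < hi; i++) {
                a[i] *= 2;
            }
            return;
        }
        int mid = (lo + hi) >>> 1;
        // invokeAll runs both halves and returns once both are done
        invokeAll(new DoubleInPlace(a, lo, mid), new DoubleInPlace(a, mid, hi));
    }

    // Fills an array with 0..n-1, doubles it with a fresh pool, and returns it.
    static long[] run(int n) {
        long[] a = new long[n];
        for (int i = 0; i < n; i++) a[i] = i;
        ForkJoinPool pool = new ForkJoinPool();
        try {
            pool.invoke(new DoubleInPlace(a, 0, n));
        } finally {
            pool.shutdown();
        }
        return a;
    }
}
```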
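II. Source Code Analysis

ForkJoinPool, ForkJoinWorkerThread and ForkJoinTask act as the scheduler, the executor and the unit of execution respectively. The three are tightly interdependent, and implementing any one of them on its own would be meaningless. ForkJoinWorkerThread is the link between ForkJoinPool and ForkJoinTask, so we start with it.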
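1. ForkJoinWorkerThread

ForkJoinWorkerThread is a subclass of Thread that represents a worker thread executing ForkJoinTasks. Each worker maintains a double-ended task queue (deque), so most of its functionality revolves around that queue. The deque is backed by an array: the int variable queueTop marks its top end and the int variable queueBase marks its base end.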
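```java
ForkJoinTask<?>[] queue;      // the deque's array; its length is always a power of two
int queueTop;                 // next slot to push into; written only by the owner
volatile int queueBase;       // next slot to take from at the base end
private static final int INITIAL_QUEUE_CAPACITY = 1 << 13;
private static final int MAXIMUM_QUEUE_CAPACITY = 1 << 24;
```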
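Only the owning worker thread writes queueTop, so it is a plain field updated without any concurrency control; queueBase is updated by several threads, so it is declared volatile to guarantee visibility. The queue starts with a capacity of 1 << 13 (8192) and can never exceed 1 << 24 (16,777,216). It always grows by powers of two, which lets slot indexes be computed with a cheap bitwise AND (index & (length - 1)) instead of a modulo. Because the task array is updated concurrently on enqueue and dequeue, those updates use the CAS operations provided by the Unsafe class. Unsafe's CAS operations work on the actual memory offset of an element, which is why push contains some bit arithmetic to compute the address: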
```java
final void pushTask(ForkJoinTask<?> t) {
    ForkJoinTask<?>[] q; int s, m;
    if ((q = queue) != null) {
        // byte offset of slot (queueTop & (length - 1)) inside the array object
        long u = (((s = queueTop) & (m = q.length - 1)) << ASHIFT) + ABASE;
        UNSAFE.putOrderedObject(q, u, t);   // ordered store of the task
        queueTop = s + 1;
        if ((s -= queueBase) <= 2)          // s = number of tasks before this push
            pool.signalWork();
        else if (s == m)                    // the array has just become full
            growQueue();
    }
}
```
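pushTask computes a slot address as ((queueTop & (q.length - 1)) << ASHIFT) + ABASE. The sketch below reproduces only that arithmetic with plain integers so the masking can be checked by hand. ABASE and the element size are JVM-dependent values that the real code obtains from Unsafe; deriving ASHIFT as log2 of the element size follows JDK 7, and the values 16 and 4 used in the checks are assumptions for illustration. The class name SlotOffset is hypothetical.

```java
// Plain-arithmetic sketch of pushTask's slot-offset computation (no Unsafe involved).
class SlotOffset {
    // ASHIFT = log2(element size); the element size must be a power of two.
    static int ashift(int scale) {
        if ((scale & (scale - 1)) != 0)
            throw new IllegalArgumentException("element scale must be a power of two");
        return 31 - Integer.numberOfLeadingZeros(scale);
    }

    // Slot for logical position s; equals s % capacity only because capacity is a power of two.
    static int index(int s, int capacity) {
        return s & (capacity - 1);
    }

    // Byte offset of that slot from the start of the array object.
    static long offset(int s, int capacity, long abase, int scale) {
        return ((long) index(s, capacity) << ashift(scale)) + abase;
    }
}
```

With a capacity of 8, logical position 10 lands in slot 10 & 7 = 2, and with 4-byte elements and a base of 16 that slot starts at byte 16 + 2 * 4 = 24. Doubling the array on growth keeps this single-AND indexing valid.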
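Pushing a task simply stores it in the array: pushTask computes the memory offset of the target slot, stores the task with UNSAFE.putOrderedObject and advances queueTop. If the queue held at most two tasks before this push, it calls pool.signalWork() so the pool can wake or create a worker to help; otherwise, if the array has just become full, it expands it with growQueue().

For dequeuing, recall ForkJoinPool's work-stealing mechanism: tasks can be taken from either end. Taking from the top makes the queue behave like a stack (LIFO), while taking from the base makes it behave like a queue (FIFO). The owning worker takes from the top or from the base depending on how the ForkJoinPool was configured (the asyncMode constructor argument, stored as locallyFifo), whereas other threads can only take from the base.

Popping from the top (popTask):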
```java
private ForkJoinTask<?> popTask() {
    int m;
    ForkJoinTask<?>[] q = queue;
    if (q != null && (m = q.length - 1) >= 0) {
        for (int s; (s = queueTop) != queueBase;) {
            int i = m & --s;                 // slot of the most recently pushed task
            long u = (i << ASHIFT) + ABASE;  // its raw memory offset
            ForkJoinTask<?> t = q[i];
            if (t == null)                   // already taken by a thief
                break;
            if (UNSAFE.compareAndSwapObject(q, u, t, null)) {
                queueTop = s;
                return t;
            }
        }
    }
    return null;
}
```
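popTask works on the current value of queueTop, i.e. the most recently pushed task. After computing the memory offset of that task's slot, it CASes the slot to null. Because another worker may steal the task at the same moment, there are two cases to handle: if the slot is already null, the task has been stolen and popTask gives up; if the CAS fails, it re-reads queueTop and tries again. If no task can be taken, it returns null.

Taking from the base (deqTask):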
```java
final ForkJoinTask<?> deqTask() {
    ForkJoinTask<?> t; ForkJoinTask<?>[] q; int b, i;
    if (queueTop != (b = queueBase) &&
        (q = queue) != null &&
        (i = (q.length - 1) & b) >= 0 &&
        (t = q[i]) != null && queueBase == b &&
        UNSAFE.compareAndSwapObject(q, (i << ASHIFT) + ABASE, t, null)) {
        queueBase = b + 1;
        return t;
    }
    return null;
}
```
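deqTask works mainly on queueBase. It likewise computes the memory address of the task's slot, then re-checks that queueBase has not changed in the meantime (queueBase is volatile, so any change made by another thread is visible). Finally it CASes the slot to null and increments queueBase.

There is also a similar method, locallyDeqTask(), which only the owning thread may call; unlike deqTask, it keeps retrying until it obtains a task or the queue becomes empty: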
```java
final ForkJoinTask<?> locallyDeqTask() {
    ForkJoinTask<?> t; int m, b, i;
    ForkJoinTask<?>[] q = queue;
    if (q != null && (m = q.length - 1) >= 0) {
        while (queueTop != (b = queueBase)) {
            if ((t = q[i = m & b]) != null &&
                queueBase == b &&
                UNSAFE.compareAndSwapObject(q, (i << ASHIFT) + ABASE,
                                            t, null)) {
                queueBase = b + 1;
                return t;
            }
        }
    }
    return null;
}
```
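The push/pop/deq logic above can be condensed into a small runnable model. This is a simplified sketch under stated assumptions, not the JDK code: AtomicReferenceArray.compareAndSet stands in for Unsafe CAS on raw offsets, both indexes are volatile for simplicity (JDK 7 keeps queueTop a plain field), and a full array makes push return false instead of calling growQueue. The class name WorkDeque is hypothetical.

```java
import java.util.concurrent.atomic.AtomicReferenceArray;

// Simplified model of the JDK 7 worker deque: the owner pushes and pops at the top (LIFO),
// thieves take from the base (FIFO). A task is claimed by CAS-ing its slot to null.
class WorkDeque<T> {
    private final AtomicReferenceArray<T> q;
    private final int mask;          // capacity - 1; capacity must be a power of two
    private volatile int top;        // written only by the owner
    private volatile int base;       // advanced by whichever thread claims the base slot

    WorkDeque(int capacity) {
        if (capacity <= 0 || (capacity & (capacity - 1)) != 0)
            throw new IllegalArgumentException("capacity must be a power of two");
        q = new AtomicReferenceArray<>(capacity);
        mask = capacity - 1;
    }

    // Owner only. Returns false when full (the JDK would call growQueue() instead).
    boolean push(T t) {
        int s = top;
        if (s - base > mask) return false;
        q.set(s & mask, t);
        top = s + 1;
        return true;
    }

    // Owner only: take the most recently pushed task, like popTask().
    T pop() {
        for (int s; (s = top) != base; ) {
            int i = --s & mask;
            T t = q.get(i);
            if (t == null) break;               // slot already claimed by a thief
            if (q.compareAndSet(i, t, null)) {
                top = s;
                return t;
            }
        }
        return null;
    }

    // Any thread: take the oldest task, like deqTask().
    T steal() {
        int b = base;
        if (b == top) return null;
        int i = b & mask;
        T t = q.get(i);
        if (t != null && base == b && q.compareAndSet(i, t, null)) {
            base = b + 1;
            return t;
        }
        return null;
    }
}
```

As in the original, only the slot CAS decides who gets a contested task, and the owner's push needs no CAS at all, which keeps the common case cheap.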
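The dequeue method provided for ForkJoinTask, pollTask, first tries the worker's own queue and then tries to steal from other workers: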
```java
final ForkJoinTask<?> pollTask() {
    ForkJoinWorkerThread[] ws;
    ForkJoinTask<?> t = pollLocalTask();     // own queue first
    if (t != null || (ws = pool.workers) == null)
        return t;
    int n = ws.length;
    int steps = n << 1;                      // probe up to 2n slots
    int r = nextSeed();                      // random starting point
    int i = 0;
    while (i < steps) {
        ForkJoinWorkerThread w = ws[(i++ + r) & (n - 1)];
        if (w != null && w.queueBase != w.queueTop && w.queue != null) {
            if ((t = w.deqTask()) != null)
                return t;
            i = 0;                           // lost a race on a non-empty queue: restart the count
        }
    }
    return null;
}
```
Thread initialization and execution:

Initialization in the constructor:
```java
protected ForkJoinWorkerThread(ForkJoinPool pool) {
    super(pool.nextWorkerName());
    this.pool = pool;
    int k = pool.registerWorker(this);
    poolIndex = k;
    eventCount = ~k & SMASK;
    locallyFifo = pool.locallyFifo;
    Thread.UncaughtExceptionHandler ueh = pool.ueh;
    if (ueh != null)
        setUncaughtExceptionHandler(ueh);
    setDaemon(true);
}
```
Initialization in onStart:
```java
protected void onStart() {
    queue = new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
    int r = pool.workerSeedGenerator.nextInt();
    seed = (r == 0) ? 1 : r;   // the xorshift seed must be non-zero
}
```
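The seed drives the pseudo-random choice of which workers to probe when scanning (it is used, for example, by the xorshift steps in the pool's scan method), so that scanning is spread evenly across the workers.

Running the thread: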
```java
public void run() {
    Throwable exception = null;
    try {
        onStart();
        pool.work(this);
    } catch (Throwable ex) {
        exception = ex;
    } finally {
        onTermination(exception);
    }
}
```
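run() brackets a worker's lifetime between onStart() and onTermination(), and both are protected hooks that subclasses may override. A minimal sketch that plugs a counting worker into a pool through the public ForkJoinWorkerThreadFactory constructor argument; the class names WorkerHooksDemo, CountingWorker and Answer are hypothetical. The JDK documentation requires overriders to call super.onStart() first (in JDK 7 that is where the queue is allocated, as shown above) and super.onTermination() last.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.RecursiveTask;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.atomic.AtomicInteger;

class WorkerHooksDemo {
    static final AtomicInteger started = new AtomicInteger();
    static final AtomicInteger terminated = new AtomicInteger();

    // Hypothetical worker subclass that only counts lifecycle callbacks.
    static class CountingWorker extends ForkJoinWorkerThread {
        CountingWorker(ForkJoinPool pool) {
            super(pool);
        }

        @Override
        protected void onStart() {
            super.onStart();                 // must come first
            started.incrementAndGet();
        }

        @Override
        protected void onTermination(Throwable exception) {
            terminated.incrementAndGet();
            super.onTermination(exception);  // must come last
        }
    }

    static class Answer extends RecursiveTask<Integer> {
        @Override
        protected Integer compute() {
            return 6 * 7;
        }
    }

    // Runs one task on a pool built with the counting factory, then shuts the pool down.
    // Returns true if the task produced 42 and the pool terminated within the timeout.
    static boolean runOnce() {
        ForkJoinPool pool = new ForkJoinPool(2, CountingWorker::new, null, false);
        int result = pool.invoke(new Answer());
        pool.shutdown();
        try {
            return result == 42 && pool.awaitTermination(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            return false;
        }
    }
}
```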
The joinTask method, which ForkJoinTask uses to join a task:
```java
final int joinTask(ForkJoinTask<?> joinMe) {
    ForkJoinTask<?> prevJoin = currentJoin;
    currentJoin = joinMe;
    for (int s, retries = MAX_HELP;;) {
        if ((s = joinMe.status) < 0) {
            currentJoin = prevJoin;
            return s;
        }
        if (retries > 0) {
            if (queueTop != queueBase) {
                if (!localHelpJoinTask(joinMe))
                    retries = 0;
            }
            else if (retries == MAX_HELP >>> 1) {
                --retries;
                if (tryDeqAndExec(joinMe) >= 0)
                    Thread.yield();
            }
            else
                retries = helpJoinTask(joinMe) ? MAX_HELP : retries - 1;
        }
        else {
            retries = MAX_HELP;
            pool.tryAwaitJoin(joinMe);
        }
    }
}
```
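The code above roughly does the following:

1. If the task has already completed (status < 0), return its status.
2. If the worker's own queue is not empty and the task is at the top of it, run it (localHelpJoinTask).
3. If the task sits at the base of another worker's queue, steal it and run it (tryDeqAndExec).
4. If neither applies, scan the other workers' queues and help by running the tasks found there (helpJoinTask).
5. If all of this fails, call the pool's tryAwaitJoin method and block until the task completes.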
The execTask method, which the pool's scan calls to run a stolen task:
```java
final void execTask(ForkJoinTask<?> t) {
    currentSteal = t;
    for (;;) {
        if (t != null)
            t.doExec();
        if (queueTop == queueBase)
            break;
        t = locallyFifo ? locallyDeqTask() : popTask();
    }
    ++stealCount;
    currentSteal = null;
}
```
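After running the stolen task, execTask keeps draining the worker's own queue and picks the end according to locallyFifo, i.e. the asyncMode argument of the ForkJoinPool constructor. The hypothetical simulation below uses an ArrayDeque in place of the worker queue and models only the resulting execution order, not concurrency; DrainOrder is an illustrative name.

```java
import java.util.ArrayDeque;
import java.util.ArrayList;
import java.util.List;

// Simulates execTask's drain loop: after the stolen task, keep taking local tasks
// from the base (locallyFifo == true) or from the top (locallyFifo == false).
class DrainOrder {
    static List<String> drain(String stolen, List<String> pushedInOrder, boolean locallyFifo) {
        ArrayDeque<String> local = new ArrayDeque<>();
        for (String t : pushedInOrder) local.addLast(t);   // addLast plays the role of a push at the top
        List<String> executed = new ArrayList<>();
        String t = stolen;
        while (t != null) {
            executed.add(t);                                // stands in for t.doExec()
            t = locallyFifo ? local.pollFirst()             // like locallyDeqTask(): oldest first
                            : local.pollLast();             // like popTask(): newest first
        }
        return executed;
    }
}
```

Taking the newest task first (LIFO, the default) tends to keep a worker on the pieces it has just split off, while the ForkJoinPool documentation describes FIFO (asyncMode) as more appropriate for event-style tasks that are never joined.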
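2. ForkJoinTask

ForkJoinTask represents a task that can be split and joined. Because it implements Future, it can represent an asynchronous computation and provides methods such as get, cancel, isDone and isCancelled. ForkJoinTask keeps its state in a volatile int field, status, and defines the following four status values: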
```java
volatile int status;
private static final int NORMAL      = -1;
private static final int CANCELLED   = -2;
private static final int EXCEPTIONAL = -3;
private static final int SIGNAL      =  1;
```
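These status values form a small state machine: a task completes through a single CAS from a non-negative value to one of the negative states, and the first completion wins. In the model below, a thread that wants to block first CASes 0 to SIGNAL so that the completer knows it must notify. It uses AtomicInteger and the object's monitor in place of Unsafe; the class and method names are hypothetical, and exception bookkeeping and the helping logic of the real class are left out.

```java
import java.util.concurrent.atomic.AtomicInteger;

// Hypothetical model of ForkJoinTask's status encoding, using the JDK 7 constant values.
class TaskStatus {
    static final int NORMAL = -1;
    static final int CANCELLED = -2;
    static final int EXCEPTIONAL = -3;
    static final int SIGNAL = 1;

    final AtomicInteger status = new AtomicInteger(0); // 0 = initial state

    boolean isDone() {
        return status.get() < 0;                       // every completion state is negative
    }

    // Moves to a completion state at most once and returns the resulting status.
    int complete(int completion) {
        for (;;) {
            int s = status.get();
            if (s < 0)
                return s;                              // already completed: first completion wins
            if (status.compareAndSet(s, completion)) {
                if (s != 0) {                          // s == SIGNAL: a waiter may be blocked
                    synchronized (this) {
                        notifyAll();
                    }
                }
                return completion;
            }
        }
    }

    // Blocks until completion; keeps waiting on interrupt and restores the flag afterwards.
    synchronized int awaitDone() {
        boolean interrupted = false;
        int s;
        while ((s = status.get()) >= 0) {
            // Announce a waiter before sleeping; if the CAS fails, the loop re-reads the status.
            if (s == SIGNAL || status.compareAndSet(0, SIGNAL)) {
                try {
                    wait();
                } catch (InterruptedException e) {
                    interrupted = true;
                }
            }
        }
        if (interrupted)
            Thread.currentThread().interrupt();
        return s;
    }
}
```

Because the waiter holds the monitor from its CAS until wait() releases it, and the completer notifies inside the same monitor, a wakeup cannot be lost. ForkJoinTask.join() is likewise not interruptible, which is why the sketch swallows InterruptedException.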
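All negative values are completion states: normal completion, cancellation and exceptional completion respectively. SIGNAL means a thread is blocked waiting to be notified of the task's completion, and 0 is the initial state.

The two most important methods of ForkJoinTask are fork and join, which split and merge tasks. The fork method: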
```java
public final ForkJoinTask<V> fork() {
    ((ForkJoinWorkerThread) Thread.currentThread())
        .pushTask(this);
    return this;
}
```
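fork is simple: it just pushes the task onto the current worker thread's queue. Because of the cast, in this version fork must be called from a ForkJoinWorkerThread; from any other thread the cast throws ClassCastException.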
The join method:
```java
public final V join() {
    if (doJoin() != NORMAL)
        return reportResult();
    else
        return getRawResult();
}
```
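join calls the core doJoin method. If the result is not NORMAL, for example because the task was cancelled or threw an exception, it calls reportResult; otherwise it returns getRawResult().

The core doJoin method: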
```java
private int doJoin() {
    Thread t; ForkJoinWorkerThread w; int s; boolean completed;
    if ((t = Thread.currentThread()) instanceof ForkJoinWorkerThread) {
        if ((s = status) < 0)
            return s;
        if ((w = (ForkJoinWorkerThread)t).unpushTask(this)) {
            try {
                completed = exec();
            } catch (Throwable rex) {
                return setExceptionalCompletion(rex);
            }
            if (completed)
                return setCompletion(NORMAL);
        }
        return w.joinTask(this);
    }
    else
        return externalAwaitDone();
}
```
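doJoin first checks whether the current thread is a worker thread. If it is and the task has not completed yet, it tries to take this task back off the top of its own queue with unpushTask, which succeeds only if the task is the topmost element; if that works, the worker runs the task directly, otherwise it falls back to its joinTask method. If the joining thread is not a worker thread, it waits with externalAwaitDone.

ForkJoinTask also provides methods such as invoke and invokeAll for executing tasks.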
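Because doJoin first tries unpushTask, joining the task that was forked most recently lets a worker take it straight back off the top of its own queue and run it, unless a thief got there first. That is the reason behind the common idiom of forking one half, computing the other half in the current thread, and then joining the forked half. A minimal sketch with the public API; SumTask and THRESHOLD are illustrative names and values.

```java
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.RecursiveTask;

// Sums a[lo, hi) with the fork / compute / join idiom.
class SumTask extends RecursiveTask<Long> {
    static final int THRESHOLD = 1_000; // illustrative cutoff for sequential work
    final long[] a;
    final int lo, hi;

    SumTask(long[] a, int lo, int hi) {
        this.a = a;
        this.lo = lo;
        this.hi = hi;
    }

    @Override
    protected Long compute() {
        if (hi - lo <= THRESHOLD) {
            long sum = 0;
            for (int i = lo; i < hi; i++) sum += a[i];
            return sum;
        }
        int mid = (lo + hi) >>> 1;
        SumTask left = new SumTask(a, lo, mid);
        left.fork();                                     // pushed onto this worker's queue
        long right = new SumTask(a, mid, hi).compute();  // run the other half directly
        return right + left.join();                      // left is on top unless it was stolen
    }

    static long sum(long[] a) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            return pool.invoke(new SumTask(a, 0, a.length));
        } finally {
            pool.shutdown();
        }
    }
}
```

Joining tasks in the reverse order of forking (or letting invokeAll do it) keeps this fast path available; joining an older task first would typically fall through to the helping or blocking paths of joinTask.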
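III. ForkJoinPool

ForkJoinPool's main job is scheduling the worker threads. It is the most complex of the three classes, and it relies heavily on bit manipulation to update the pool state and the worker array concurrently. ForkJoinPool is also where the work-stealing algorithm is mainly carried out. The pool's state is kept in a single volatile 64-bit long, ctl, which is split into several bit fields, each with its own meaning: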
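```java
// ctl, from the most significant bits down (JDK 7 layout):
//   AC (16 bits): number of active workers minus parallelism
//   TC (16 bits): total number of workers minus parallelism
//   ST (1 bit):   set when the pool is stopping
//   EC (15 bits): event count of the top waiting worker
//   ID (16 bits): ~poolIndex of the top waiting worker (wait stack)
volatile long ctl;
```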
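To make the packing concrete, the following sketch decodes a ctl value the way the ForkJoinPool code does: AC is read with a sign-extending shift, as in (int)(c >> AC_SHIFT) in work(), and bit 31 is the stop flag, so (int) ctl < 0 means the pool is stopping, which is exactly what work()'s loop condition (int)(c = ctl) >= 0 tests. The values AC_SHIFT = 48 and TC_SHIFT = 32 are assumptions matching the JDK 7 layout, and the class Ctl is hypothetical.

```java
// Hypothetical helper that packs and unpacks a JDK 7-style ctl word.
class Ctl {
    static final int AC_SHIFT = 48;                 // assumed: active count field
    static final int TC_SHIFT = 32;                 // assumed: total count field
    static final long AC_MASK = 0xffffL << AC_SHIFT;
    static final long TC_MASK = 0xffffL << TC_SHIFT;
    static final long AC_UNIT = 1L << AC_SHIFT;
    static final long TC_UNIT = 1L << TC_SHIFT;
    static final long STOP_BIT = 1L << 31;

    // Same expression as the ForkJoinPool constructor: both counts start at -parallelism.
    static long initial(int parallelism) {
        long np = -parallelism;
        return ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    }

    static int ac(long c) { return (int) (c >> AC_SHIFT); }    // sign-extending shift
    static int tc(long c) { return (short) (c >>> TC_SHIFT); }  // low 16 bits of the upper half
    static boolean stopping(long c) { return (int) c < 0; }    // bit 31 is the sign of (int) c

    static int activeWorkers(long c, int parallelism) { return parallelism + ac(c); }
    static int totalWorkers(long c, int parallelism) { return parallelism + tc(c); }
}
```

Storing the counts offset by -parallelism means a negative field reads directly as "below target", which is what signalWork tests with sign bits.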
To control concurrent updates of the workers array, the pool also defines a 32-bit int, scanGuard, which is likewise split bitwise into three parts:
```java
volatile int scanGuard;   // low 16 bits (scanGuard & SMASK): index mask for the workers array, as used by scan()
private static final int SG_UNIT = 1 << 16;
```
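ForkJoinPool itself also maintains a task queue, similar to the workers' queues, which holds tasks submitted by external (non-worker) threads; concurrent access to it is controlled with a lock: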
```java
private ForkJoinTask<?>[] submissionQueue;    // tasks submitted by non-worker threads
private final ReentrantLock submissionLock;   // guards submissionQueue
private final Condition termination;          // created from submissionLock in the constructor
```
Next, let us go through the main methods of ForkJoinPool, starting with task submission:
```java
public <T> ForkJoinTask<T> submit(ForkJoinTask<T> task) {
    if (task == null)
        throw new NullPointerException();
    forkOrSubmit(task);
    return task;
}

private <T> void forkOrSubmit(ForkJoinTask<T> task) {
    ForkJoinWorkerThread w;
    Thread t = Thread.currentThread();
    if (shutdown)
        throw new RejectedExecutionException();
    if ((t instanceof ForkJoinWorkerThread) &&
        (w = (ForkJoinWorkerThread)t).pool == this)
        w.pushTask(task);
    else
        addSubmission(task);
}
```
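After the null check, forkOrSubmit rejects the task with RejectedExecutionException if the pool has been shut down. It then checks whether the caller is a worker thread of this pool: if so, the task is pushed directly onto that worker's queue; otherwise it is added to the pool's submission queue.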
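The routing test in forkOrSubmit asks whether the calling thread is a worker of this very pool. The same test can be written against the public API, where ForkJoinWorkerThread.getPool() returns the pool that owns a worker. In the sketch below (class and method names hypothetical), the caller waits on a CountDownLatch instead of joining, so the probe is guaranteed to run on a worker of the pool.

```java
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinWorkerThread;
import java.util.concurrent.TimeUnit;

class SubmitRouting {
    // Same test as forkOrSubmit: is the current thread a worker of this particular pool?
    static boolean isWorkerOf(ForkJoinPool pool) {
        Thread t = Thread.currentThread();
        return t instanceof ForkJoinWorkerThread
            && ((ForkJoinWorkerThread) t).getPool() == pool;
    }

    // Returns {caller vs pool, pool worker vs pool, pool worker vs another pool}.
    static boolean[] probe() {
        ForkJoinPool pool = new ForkJoinPool(2);
        ForkJoinPool other = new ForkJoinPool(2);
        boolean[] inside = new boolean[2];
        CountDownLatch done = new CountDownLatch(1);
        boolean outside = isWorkerOf(pool);
        try {
            pool.execute(() -> {              // plain Runnable; the caller never joins it
                inside[0] = isWorkerOf(pool);
                inside[1] = isWorkerOf(other);
                done.countDown();
            });
            done.await(10, TimeUnit.SECONDS);
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
        } finally {
            pool.shutdown();
            other.shutdown();
        }
        return new boolean[] { outside, inside[0], inside[1] };
    }
}
```

Only the second case would take the pushTask path in forkOrSubmit; the main thread and a worker of a different pool both go through addSubmission.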
```java
private void addSubmission(ForkJoinTask<?> t) {
    final ReentrantLock lock = this.submissionLock;
    lock.lock();
    try {
        ForkJoinTask<?>[] q; int s, m;
        if ((q = submissionQueue) != null) {
            long u = (((s = queueTop) & (m = q.length-1)) << ASHIFT)+ABASE;
            UNSAFE.putOrderedObject(q, u, t);
            queueTop = s + 1;
            if (s - queueBase == m)
                growSubmissionQueue();
        }
    } finally {
        lock.unlock();
    }
    signalWork();
}
```
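addSubmission uses a ReentrantLock to guard the pool's submission queue, because external threads may submit tasks concurrently. Finally it calls signalWork, which either wakes an idle worker thread or adds a new one: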
```java
final void signalWork() {
    long c; int e, u;
    while ((((e = (int)(c = ctl)) | (u = (int)(c >>> 32))) &
            (INT_SIGN|SHORT_SIGN)) == (INT_SIGN|SHORT_SIGN) && e >= 0) {
        if (e > 0) {                         // a worker is waiting: wake it
            int i; ForkJoinWorkerThread w; ForkJoinWorkerThread[] ws;
            if ((ws = workers) == null ||
                (i = ~e & SMASK) >= ws.length ||
                (w = ws[i]) == null)
                break;
            long nc = (((long)(w.nextWait & E_MASK)) |
                       ((long)(u + UAC_UNIT) << 32));
            if (w.eventCount == e &&
                UNSAFE.compareAndSwapLong(this, ctlOffset, c, nc)) {
                w.eventCount = (e + EC_UNIT) & E_MASK;
                if (w.parked)
                    UNSAFE.unpark(w);
                break;
            }
        }
        else if (UNSAFE.compareAndSwapLong   // no waiter: count one more active and total worker
                 (this, ctlOffset, c,
                  (long)(((u + UTC_UNIT) & UTC_MASK) |
                         ((u + UAC_UNIT) & UAC_MASK)) << 32)) {
            addWorker();
            break;
        }
    }
}
```
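signalWork first decodes the pool state from ctl with a series of bit operations. As long as fewer workers are active than the target parallelism and the pool is not stopping, it does one of two things: if an idle worker is waiting, it wakes that worker (unparking it if it is parked); otherwise, if fewer workers exist than the target parallelism, it adds a new worker. signalWork is also called when a worker pushes a task onto its own queue (pushTask()).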
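Leaving out the CAS retry loop and the array bookkeeping, the test at the heart of signalWork can be restated as a small decision function over decoded ctl fields. In this sketch, ac and tc are the active and total worker counts minus parallelism (negative means below target), stopping is the stop bit, and hasWaiter stands for the e > 0 case; the enum and method names are hypothetical.

```java
// Hypothetical restatement of signalWork's decision on already-decoded ctl fields.
class SignalDecision {
    enum Action { NONE, WAKE_WAITER, ADD_WORKER }

    static Action decide(int ac, int tc, boolean stopping, boolean hasWaiter) {
        if (stopping || ac >= 0) return Action.NONE;  // stopping, or enough workers already active
        if (hasWaiter) return Action.WAKE_WAITER;      // reuse an idle worker first
        if (tc < 0) return Action.ADD_WORKER;          // fewer workers exist than parallelism
        return Action.NONE;                            // nobody to wake and no room to add one
    }
}
```

Preferring a waiting worker over creating a thread is what keeps the pool at roughly parallelism threads.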
After addWorker creates a new worker thread, it starts it with t.start(), so the worker's run method is invoked, which in turn calls pool.work:
```java
final void work(ForkJoinWorkerThread w) {
    boolean swept = false;                // true if the last scan found nothing
    long c;
    while (!w.terminate && (int)(c = ctl) >= 0) {   // (int)c < 0: stop bit set
        int a;                            // active workers minus parallelism
        if (!swept && (a = (int)(c >> AC_SHIFT)) <= 0)
            swept = scan(w, a);
        else if (tryAwaitWork(w, c))
            swept = false;
    }
}
```
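work contains the outermost loop of a worker thread. It keeps looping while the worker has not been told to terminate and the pool is not stopping ((int) ctl >= 0). On each iteration it decodes the active-worker count from ctl: unless the previous scan came up empty or more workers are active than the target parallelism, it calls scan, which searches the queues of all workers and the pool's submission queue for a task to run; otherwise it calls tryAwaitWork.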
```java
private boolean scan(ForkJoinWorkerThread w, int a) {
    int g = scanGuard;
    // m == 0 when this is the only active worker and none is blocked
    int m = (parallelism == 1 - a && blockedCount == 0) ? 0 : g & SMASK;
    ForkJoinWorkerThread[] ws = workers;
    if (ws == null || ws.length <= m)         // stale view of the workers array
        return false;
    for (int r = w.seed, k = r, j = -(m + m); j <= m + m; ++j) {
        ForkJoinTask<?> t; ForkJoinTask<?>[] q; int b, i;
        ForkJoinWorkerThread v = ws[k & m];
        if (v != null && (b = v.queueBase) != v.queueTop &&
            (q = v.queue) != null && (i = (q.length - 1) & b) >= 0) {
            long u = (i << ASHIFT) + ABASE;
            if ((t = q[i]) != null && v.queueBase == b &&
                UNSAFE.compareAndSwapObject(q, u, t, null)) {
                int d = (v.queueBase = b + 1) - v.queueTop;
                v.stealHint = w.poolIndex;
                if (d != 0)
                    signalWork();             // victim still has tasks
                w.execTask(t);
            }
            r ^= r << 13; r ^= r >>> 17; w.seed = r ^ (r << 5);
            return false;                     // saw work: store next seed, scan again
        }
        else if (j < 0) {                     // first 2m probes: xorshift random
            r ^= r << 13; r ^= r >>> 17; k = r ^= r << 5;
        }
        else                                  // remaining probes: linear sweep
            ++k;
    }
    if (scanGuard != g)                       // workers array changed meanwhile
        return false;
    else {                                    // try the pool's submission queue
        ForkJoinTask<?> t; ForkJoinTask<?>[] q; int b, i;
        if ((b = queueBase) != queueTop &&
            (q = submissionQueue) != null &&
            (i = (q.length - 1) & b) >= 0) {
            long u = (i << ASHIFT) + ABASE;
            if ((t = q[i]) != null && queueBase == b &&
                UNSAFE.compareAndSwapObject(q, u, t, null)) {
                queueBase = b + 1;
                w.execTask(t);
            }
            return false;
        }
        return true;                          // every queue looked empty
    }
}
```
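scan is complex, but its basic logic is: probe the workers' queues, first at pseudo-random indexes and then linearly, and steal a task from the first non-empty queue found; if no worker's queue yields anything, check the pool's submission queue. It returns true only when every queue looked empty, and false otherwise (including when it saw a task but lost the race for it). Returning false makes work keep looping and scanning; once scan returns true, or once more workers are active than the target parallelism, work calls tryAwaitWork to park the current worker thread until it is woken up.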
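scan chooses victims with the xorshift step r ^= r << 13; r ^= r >>> 17; r ^= r << 5 (Marsaglia's 32-bit xorshift generator). Zero is a fixed point of that step, which is why onStart replaces a zero seed with 1. The sketch below reproduces only the order in which scan probes indexes of the workers array when it finds nothing: 2m pseudo-random probes followed by 2m + 1 consecutive ones. ScanOrder is a hypothetical name, and stealing itself is not modeled.

```java
import java.util.ArrayList;
import java.util.List;

// Simplified model of the order in which JDK 7 scan() probes the workers array
// when every probe comes up empty.
class ScanOrder {
    // One xorshift step (13/17/5); maps 0 to 0, so seeds must be non-zero.
    static int xorshift(int r) {
        r ^= r << 13;
        r ^= r >>> 17;
        return r ^ (r << 5);
    }

    // m is the index mask (workers array length - 1).
    static List<Integer> probes(int seed, int m) {
        List<Integer> visited = new ArrayList<>();
        for (int r = seed, k = r, j = -(m + m); j <= m + m; ++j) {
            visited.add(k & m);
            if (j < 0) {            // first 2m probes: pseudo-random victims
                r = xorshift(r);
                k = r;
            } else {                // remaining 2m + 1 probes: linear sweep
                ++k;
            }
        }
        return visited;
    }
}
```

Random probes spread idle workers over different victims, and since 2m + 1 consecutive indexes masked with m cover all m + 1 slots, the linear tail guarantees every worker's queue is probed before scan falls back to the submission queue.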
```java
private boolean tryAwaitWork(ForkJoinWorkerThread w, long c) {
    int v = w.eventCount;
    w.nextWait = (int)c;                      // link to the previous top waiter
    // w becomes the top waiter and the active count drops by one
    long nc = (long)(v & E_MASK) | ((c - AC_UNIT) & (AC_MASK|TC_MASK));
    if (ctl != c || !UNSAFE.compareAndSwapLong(this, ctlOffset, c, nc)) {
        long d = ctl;                         // CAS failed: tell work() whether to rescan
        return (int)d != (int)c && ((d - c) & AC_MASK) >= 0L;
    }
    for (int sc = w.stealCount; sc != 0;) {   // fold this worker's steals into the pool total
        long s = stealCount;
        if (UNSAFE.compareAndSwapLong(this, stealCountOffset, s, s + sc))
            sc = w.stealCount = 0;
        else if (w.eventCount != v)
            return true;
    }
    if ((!shutdown || !tryTerminate(false)) &&
        (int)c != 0 && parallelism + (int)(nc >> AC_SHIFT) == 0 &&
        blockedCount == 0 && quiescerCount == 0)
        idleAwaitWork(w, nc, c, v);           // every worker is idle
    for (boolean rescanned = false;;) {
        if (w.eventCount != v)
            return true;
        if (!rescanned) {                     // one last look at all queues before parking
            int g = scanGuard, m = g & SMASK;
            ForkJoinWorkerThread[] ws = workers;
            if (ws != null && m < ws.length) {
                rescanned = true;
                for (int i = 0; i <= m; ++i) {
                    ForkJoinWorkerThread u = ws[i];
                    if (u != null) {
                        if (u.queueBase != u.queueTop &&
                            !tryReleaseWaiter())
                            rescanned = false;
                        if (w.eventCount != v)
                            return true;
                    }
                }
            }
            if (scanGuard != g ||
                (queueBase != queueTop && !tryReleaseWaiter()))
                rescanned = false;
            if (!rescanned)
                Thread.yield();
            else
                Thread.interrupted();         // clear the interrupt status before parking
        }
        else {
            w.parked = true;
            if (w.eventCount != v) {
                w.parked = false;
                return true;
            }
            LockSupport.park(this);
            rescanned = w.parked = false;
        }
    }
}
```
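tryAwaitWork performs several checks, a last-ditch effort to avoid entering the parked (blocked) state. Separately, when a worker's joinTask cannot finish a join by helping, it calls tryAwaitJoin, which is also provided by the pool: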
```java
final void tryAwaitJoin(ForkJoinTask<?> joinMe) {
    int s;
    Thread.interrupted();
    if (joinMe.status >= 0) {
        if (tryPreBlock()) {
            joinMe.tryAwaitDone(0L);
            postBlock();
        }
        else if ((ctl & STOP_BIT) != 0L)
            joinMe.cancelIgnoringExceptions();
    }
}
```
Finally, let us look at how a ForkJoinPool is initialized:
```java
public ForkJoinPool(int parallelism,
                    ForkJoinWorkerThreadFactory factory,
                    Thread.UncaughtExceptionHandler handler,
                    boolean asyncMode) {
    checkPermission();
    if (factory == null)
        throw new NullPointerException();
    if (parallelism <= 0 || parallelism > MAX_ID)
        throw new IllegalArgumentException();
    this.parallelism = parallelism;
    this.factory = factory;
    this.ueh = handler;
    this.locallyFifo = asyncMode;
    long np = (long)(-parallelism);           // AC and TC both start at -parallelism
    this.ctl = ((np << AC_SHIFT) & AC_MASK) | ((np << TC_SHIFT) & TC_MASK);
    this.submissionQueue = new ForkJoinTask<?>[INITIAL_QUEUE_CAPACITY];
    int n = parallelism << 1;
    if (n >= MAX_ID)
        n = MAX_ID;
    else {                                    // round up to 2^k - 1, so n + 1 is a power of two
        n |= n >>> 1; n |= n >>> 2; n |= n >>> 4; n |= n >>> 8;
    }
    workers = new ForkJoinWorkerThread[n + 1];
    this.submissionLock = new ReentrantLock();
    this.termination = submissionLock.newCondition();
    StringBuilder sb = new StringBuilder("ForkJoinPool-");
    sb.append(poolNumberGenerator.incrementAndGet());
    sb.append("-worker-");
    this.workerNamePrefix = sb.toString();
}
```
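The bit smearing in the constructor turns n = 2 * parallelism into the smallest value of the form 2^k - 1 that is at least n, so the workers array length n + 1 is a power of two and n can serve as an index mask. A standalone sketch of just that computation; MAX_ID = 0x7fff is an assumption taken from JDK 7, since the listing only references the name, and WorkersSizing is a hypothetical class name.

```java
// Standalone copy of the workers-array sizing arithmetic from the ForkJoinPool constructor.
class WorkersSizing {
    static final int MAX_ID = 0x7fff; // assumed value (JDK 7)

    static int workersLength(int parallelism) {
        int n = parallelism << 1;
        if (n >= MAX_ID)
            n = MAX_ID;
        else {               // copy the highest one-bit into every lower bit: n becomes 2^k - 1
            n |= n >>> 1;
            n |= n >>> 2;
            n |= n >>> 4;
            n |= n >>> 8;
        }
        return n + 1;
    }
}
```

For example, parallelism 3 gives n = 6 (binary 110), which smears to 7, so the array has 8 slots; parallelism 4 gives n = 8, which smears to 15, so the array has 16 slots.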
ForkJoinPool also provides the general ExecutorService methods, which are not analyzed here.
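IV. Summary

The source of ForkJoinPool, ForkJoinWorkerThread and ForkJoinTask shows that the three classes depend on one another and are considerably more complex than ThreadPoolExecutor. For concurrency control they rarely use the Lock APIs (the submission queue's ReentrantLock is the main exception) and rely mostly on volatile fields and Unsafe CAS operations instead.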