介绍
“分而治之”是理清思路和解决问题的一个重要方法。大到系统架构对功能模块的拆分,小到归并排序的实现,无一不体现着分而治之的思想。在实现分而治之的算法的时候,我们通常使用递归的方法。递归相当于把大的任务拆成多个小的任务,然后大任务等待多个小的子任务执行完成后,合并子任务的结果。一般来说,父任务依赖于子任务的执行结果,子任务与子任务之间没有依赖关系。因此子任务之间可以并发执行来提升性能。于是ForkJoinPool
提供了一个并发处理“分而治之”的框架,让我们能以类似于递归的编程方式获得并发执行的能力。
使用
分而治之代码典型的形式如下:
Result solve(Problem problem) {
if (problem is small) {
directly solve problem
} else {
split problem into independent parts
fork new subtasks to solve each part
join all subtasks
compose result from subresults
}
}
计算斐波那契数:
/**
 * Recursive fork/join task that computes the n-th Fibonacci number.
 *
 * <p>For {@code n > 1} the task forks the {@code n - 1} subproblem so that
 * another worker may pick it up, computes the {@code n - 2} subproblem in the
 * current thread, and then joins the forked half before adding both results.
 */
class Fibonacci extends RecursiveTask<Integer> {
    final int n;

    Fibonacci(int n) {
        this.n = n;
    }

    /**
     * Computes the Fibonacci number for {@link #n}.
     *
     * @return {@code n} itself when {@code n <= 1}, otherwise the sum of the
     *         two preceding Fibonacci numbers
     */
    @Override
    protected Integer compute() {
        if (n <= 1) {
            return n;
        }
        Fibonacci f1 = new Fibonacci(n - 1);
        // Hand the larger half to the pool and keep working on the smaller one.
        f1.fork();
        Fibonacci f2 = new Fibonacci(n - 2);
        return f2.compute() + f1.join();
    }
}
原理
ForkJoinPool
的核心在于其轻量级的调度机制,采用了Cilk的work-stealing的基本调度策略:
每个工作线程维持一个任务队列
任务队列以双端队列的形式维护,不仅支持先进后出的
push
和pop
操作,还支持先进先出的take操作。由父任务
fork
出来的子任务被push
到运行该父任务的工作线程对应的任务队列中。工作线程以先进后出的方式处理
pop
自己任务队列中的任务(优先处理最年轻的任务)。当任务队列中没有任务时,工作线程尝试随机从其他任务队列中窃取任务。
当工作线程没有任务可以执行,且窃取不到任务时,它会“退出”(yield、sleep、优先级调整),经过一段时间后再次尝试。除非其他所有的线程也都没有任务可以执行,这种情况下它们会一直阻塞直到有新的任务从上层添加进来。
一个简单的实现:
/**
 * A deliberately small work-stealing fork/join pool.
 *
 * <p>Each worker thread owns one {@link TaskQueue}. Tasks forked by a running
 * task are pushed onto the forking worker's own queue and popped by that
 * worker in LIFO order; idle workers scan the submission queues and then the
 * other workers' queues and take tasks from the opposite (FIFO) end.
 * Externally submitted tasks go to a randomly chosen submission queue.
 *
 * <p>The pool has no shutdown operation. Its workers are daemon threads so an
 * idle pool does not keep the JVM alive.
 */
public class NaiveForkJoinPool {
    private final TaskQueue[] submissionQueues;
    private final TaskQueue[] workerQueues;
    private final WorkerThread[] workers;
    /** Number of workers that are not parked in {@link #awaitForWork(int)}. */
    private final AtomicInteger aliveCount;
    private final ReentrantLock lock = new ReentrantLock();
    /** Idle workers wait on this condition until new work is announced. */
    private final Condition taskEmpty = lock.newCondition();
    private final int parallelism;

    /**
     * Creates the pool and starts its worker threads.
     *
     * @param parallelism number of worker threads (and of submission queues)
     * @throws IllegalArgumentException if {@code parallelism} is not positive
     */
    public NaiveForkJoinPool(int parallelism) {
        if (parallelism <= 0) {
            throw new IllegalArgumentException("parallelism must be positive: " + parallelism);
        }
        this.parallelism = parallelism;
        submissionQueues = new TaskQueue[parallelism];
        workerQueues = new TaskQueue[parallelism];
        workers = new WorkerThread[parallelism];
        aliveCount = new AtomicInteger(parallelism);
        for (int i = 0; i < parallelism; i++) {
            submissionQueues[i] = new TaskQueue();
            workerQueues[i] = new TaskQueue();
            workers[i] = new WorkerThread(this, workerQueues[i]);
        }
        // Start only after every queue exists, because workers scan all of them.
        for (int i = 0; i < parallelism; i++) {
            workers[i].start();
        }
    }

    /**
     * Submits a task and waits for its result.
     *
     * @param task the task to run
     * @param <T> result type
     * @return the value produced by {@code task.compute()}
     * @throws RuntimeException (or {@link Error}) whatever {@code compute()} threw
     */
    public <T> T invoke(Task<T> task) {
        submit(task);
        return task.join();
    }

    /**
     * Submits all tasks first and only then waits for them, so that they can
     * run concurrently (and may depend on each other's progress).
     *
     * @param tasks the tasks to run
     * @param <T> result type
     * @return the results, in the order of {@code tasks}
     */
    public <T> List<T> invokeAll(Task<T>... tasks) {
        for (Task<T> task : tasks) {
            submit(task);
        }
        List<T> res = new LinkedList<>();
        for (Task<T> task : tasks) {
            res.add(task.join());
        }
        return res;
    }

    /** Places a task on a randomly chosen submission queue and wakes an idle worker. */
    private void submit(Task<?> task) {
        // nextInt(bound) is uniform for any queue count, not just powers of two.
        int index = ThreadLocalRandom.current().nextInt(submissionQueues.length);
        submissionQueues[index].push(task);
        tryCompensate();
    }

    /**
     * Wakes one parked worker if any worker is parked. Must be called after the
     * new task has been pushed.
     */
    void tryCompensate() {
        // Cheap unlocked check first; it is safe because awaitForWork() lowers
        // aliveCount before its final scan (see there).
        if (aliveCount.get() < parallelism) {
            lock.lock();
            try {
                if (aliveCount.get() < parallelism) {
                    taskEmpty.signal();
                }
            } finally {
                lock.unlock();
            }
        }
    }

    /** Main loop of a worker thread: run tasks while any can be found, otherwise park. */
    void runWorker() {
        int startIndex = ThreadLocalRandom.current().nextInt(submissionQueues.length);
        Task<?> task = null;
        for (;;) {
            if (task == null) {
                task = scan(startIndex);
            }
            if (task != null) {
                task.runTask();
                task = null;
            } else {
                task = awaitForWork(startIndex);
            }
        }
    }

    /**
     * Looks for a task, submission queues first, then worker queues.
     *
     * @param startIndex queue index at which the circular scan begins
     * @return a task removed from some queue, or {@code null} if all were empty
     */
    Task<?> scan(int startIndex) {
        Task<?> task = scan(startIndex, submissionQueues);
        if (task != null) {
            return task;
        }
        return scan(startIndex, workerQueues);
    }

    /**
     * Visits every queue of {@code queues} once, starting at {@code startIndex}
     * and wrapping around, and takes from the first non-empty one.
     */
    Task<?> scan(int startIndex, TaskQueue[] queues) {
        int len = queues.length;
        for (int i = 0; i < len; i++) {
            // Modulo (not a bit mask) so every queue is visited for any length.
            Task<?> task = queues[(startIndex + i) % len].take();
            if (task != null) {
                return task;
            }
        }
        return null;
    }

    /**
     * Parks the calling worker until {@link #tryCompensate()} signals new work.
     *
     * @param startIndex queue index at which the final scan begins
     * @return a task found by the final scan, or {@code null} after being woken
     *         (the caller is expected to scan again)
     */
    Task<?> awaitForWork(int startIndex) {
        lock.lock();
        try {
            // Announce idleness BEFORE the final scan. A submitter pushes first
            // and reads aliveCount second, so either it sees the lowered count
            // and signals, or its push is already visible to the scan below.
            aliveCount.decrementAndGet();
            try {
                Task<?> task = scan(startIndex);
                if (task != null) {
                    return task;
                }
                // Interrupts are not a wake-up reason here; the flag stays set.
                taskEmpty.awaitUninterruptibly();
                return null;
            } finally {
                aliveCount.incrementAndGet();
            }
        } finally {
            lock.unlock();
        }
    }

    /** Pool worker; remembers its pool and its own task queue for fork/join. */
    static class WorkerThread extends Thread {
        final NaiveForkJoinPool pool;
        final TaskQueue workQueue;

        public WorkerThread(NaiveForkJoinPool pool, TaskQueue workQueue) {
            this.pool = pool;
            this.workQueue = workQueue;
            // The pool cannot be shut down, so workers must not block JVM exit.
            setDaemon(true);
        }

        @Override
        public void run() {
            pool.runWorker();
        }
    }

    /**
     * A unit of work with a result.
     *
     * @param <T> result type
     */
    static abstract class Task<T> {
        static final int NORMAL = 1;
        static final int EXCEPTIONAL = 2;
        final AtomicInteger status = new AtomicInteger();
        final CountDownLatch isDone = new CountDownLatch(1);
        private T result;
        private Throwable failure;

        /** The computation performed by this task. */
        public abstract T compute();

        /**
         * Runs {@link #compute()} and marks the task complete. A throwable from
         * {@code compute()} is recorded and rethrown by {@link #join()} instead
         * of escaping here, so joiners are always released.
         */
        public void runTask() {
            try {
                result = compute();
                status.set(NORMAL);
            } catch (Throwable t) {
                failure = t;
                status.set(EXCEPTIONAL);
            } finally {
                isDone.countDown();
            }
        }

        /**
         * Pushes this task onto the calling worker's queue.
         *
         * @return this task
         * @throws ClassCastException if the caller is not a pool worker thread
         */
        public Task<T> fork() {
            WorkerThread t = (WorkerThread) Thread.currentThread();
            t.workQueue.push(this);
            t.pool.tryCompensate();
            return this;
        }

        /**
         * Waits for this task to complete and returns its result. A worker
         * thread first runs tasks from its own queue while this task is still
         * pending; any other thread simply blocks.
         *
         * @return the value produced by {@link #compute()}
         * @throws RuntimeException (or {@link Error}) whatever {@code compute()} threw
         */
        public T join() {
            Thread currentThread = Thread.currentThread();
            if (currentThread instanceof WorkerThread) {
                TaskQueue own = ((WorkerThread) currentThread).workQueue;
                Task<?> next;
                // Stop as soon as this task is done, whoever ran it.
                while (isDone.getCount() != 0 && (next = own.pop()) != null) {
                    next.runTask();
                }
            }
            waitForComplete();
            if (failure != null) {
                if (failure instanceof RuntimeException) {
                    throw (RuntimeException) failure;
                }
                if (failure instanceof Error) {
                    throw (Error) failure;
                }
                throw new RuntimeException(failure);
            }
            return result;
        }

        /** Blocks until the task has completed; an interrupt is deferred, not lost. */
        void waitForComplete() {
            boolean interrupted = false;
            try {
                for (;;) {
                    try {
                        isDone.await();
                        return;
                    } catch (InterruptedException e) {
                        // Keep waiting: returning early would expose a missing result.
                        interrupted = true;
                    }
                }
            } finally {
                if (interrupted) {
                    Thread.currentThread().interrupt();
                }
            }
        }
    }

    /** Concurrent deque: the owner pushes/pops at the head, others take from the tail. */
    static class TaskQueue {
        private final Deque<Task<?>> deque = new ConcurrentLinkedDeque<>();

        public void push(Task<?> task) {
            deque.push(task);
        }

        /** Removes the most recently pushed task, or returns {@code null} if empty. */
        public Task<?> pop() {
            return deque.pollFirst();
        }

        /** Removes the oldest task, or returns {@code null} if empty. */
        public Task<?> take() {
            return deque.pollLast();
        }
    }
}