线程池中线程都执行完毕再执行后续操作

问题

在AsteriaGraph上查询MayKnow节点时,为了提高查询性能,我们决定使用线程池启动多个查询线程来提高查询效率。我们使用固定大小的线程池,每个线程在查询过程中会将查询结果存放到一个公共的结构中,核心代码如下:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
/*
*使用固定大小的线程池启动多个查询线程计算查询结果
*/
private static void createResultMap(int poolSize, Set<String> midNodeSet,
ConcurrentHashMap<String, Bag> concurrentMap, AsteriaGraphClient graphClient, String startNodeValue,
RelationCondition.RelationConditionBuilder builder)
throws AsteriaGraphException {
//固定数量的线程池
ExecutorService pool = Executors.newFixedThreadPool(poolSize);
RelationshipIterator rels = null;
for (String midNode : midNodeSet) {//遍历每一个中间节点,对每一个中间节点启动一个查询线程
//获取到每个中间节点的所有符合builder配置的直连边的遍历器
rels =
Utils.getDirectedRelsIterator(graphClient, Utils.getNode(graphClient, Constant.UID, midNode), builder);
//启动一个查询线程
pool.execute(new TraversalThread(rels, midNode, startNodeValue, concurrentMap, midNodeSet));
}
//关闭线程池
pool.shutdownNow();
}

在实际环境经过多次验证,我们发现有个现象:每次返回的结果数量可能不一致!再定位到代码,我们觉得问题的原因应该是这样的:在所有查询线程还未全部执行完毕时就输出了查询结果。

解决方法

解决问题的方法就是需要保证在代码pool.shutdownNow()执行之后保证所有线程池内的所有线程都执行完毕才执行后续的代码。

查阅资料后发现有两种解决思路:

①闭锁(CountDownLatch);

②线程池自身的方法(awaitTermination)来判断所有线程是否执行完毕。

因为①的效率不是很好,我们这里着重讲解方法②。

awaitTermination方法:
awaitTermination(long timeOut,TimeUnit unit)方法有两个参数,一个是timeout(超时时间),另一个是unit(时间单位)。调用该方法后,当前线程阻塞,直到
1、等所有已提交的任务(包括正在跑的和队列中等待的)执行完毕
2、或者到达超时时间
3、或者线程被中断,抛出InterruptedException

然后返回true(shutdown请求后所有任务执行完毕)或者false(已超时)。

修改后的代码:

1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
private static void createResultMap(int poolSize, Set<String> midNodeSet,
ConcurrentHashMap<String, Bag> concurrentMap, AsteriaGraphClient graphClient, String startNodeValue,
RelationCondition.RelationConditionBuilder builder)
throws AsteriaGraphException {
//固定数量的线程池
ExecutorService pool = Executors.newFixedThreadPool(poolSize);
RelationshipIterator rels = null;
for (String midNode : midNodeSet) {//遍历每一个中间节点,对每一个中间节点启动一个查询线程
//获取到每个中间节点的所有符合builder配置的直连边的遍历器
rels =
Utils.getDirectedRelsIterator(graphClient, Utils.getNode(graphClient, Constant.UID, midNode), builder);
//启动一个查询线程
pool.execute(new TraversalThread(rels, midNode, startNodeValue, concurrentMap, midNodeSet));
}
//关闭线程池
pool.shutdownNow();
try {
if(!pool.awaitTermination(1, TimeUnit.MINUTES)){//因为我们对于查询时间的要求是秒级的,所以这里的超时时间设置为1分钟
logger.error("查询超时!");
}
} catch (InterruptedException e) {
logger.error("Exception Happened during finding mayknow nodes!",e);
}
}