15. Java JUC Source Code Analysis - Collections - ConcurrentHashMap
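It has been a few days since I last looked at JUC. Earlier I read through HashMap, and there is also the similar Hashtable. Their structures are largely the same, so here is a quick comparison of the differences:
1. HashMap is not thread-safe. Hashtable achieves thread safety by locking with synchronized, but that only protects individual calls: a compound sequence such as {get(); ...; put()} in our own code is still not atomic;
2. HashMap can store a null key and null values; Hashtable cannot;
3. The default initial capacity is 11 for Hashtable and 16 for HashMap; on resize, Hashtable grows to 2*old+1 and HashMap to 2*old;
There may be other differences, but we will leave them aside for now.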
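The first point can be made concrete with a small sketch. The class and method names below (`CompoundOpDemo`, `racyIncrement`, `atomicIncrement`) are made up for illustration; only `Hashtable`, `ConcurrentHashMap`, `putIfAbsent` and `replace(key, old, new)` come from the JDK. It assumes Java 8 or later for the lambda syntax.

```java
import java.util.Hashtable;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical demo class; the names here are not from the JDK source.
final class CompoundOpDemo {

    // Check-then-act: get() and put() are each thread-safe on a Hashtable,
    // but another thread can run between them, so increments can be lost.
    static void racyIncrement(Map<String, Integer> map, String key) {
        Integer old = map.get(key);
        map.put(key, old == null ? 1 : old + 1);
    }

    // Retry loop built from the atomic putIfAbsent and replace(key, old, new).
    static void atomicIncrement(ConcurrentHashMap<String, Integer> map, String key) {
        for (;;) {
            Integer old = map.get(key);
            if (old == null) {
                if (map.putIfAbsent(key, 1) == null) return;
            } else if (map.replace(key, old, old + 1)) {
                return;
            }
        }
    }

    // Runs task perThread times on each of the given number of threads.
    static void runConcurrently(int threads, int perThread, Runnable task) {
        Thread[] workers = new Thread[threads];
        for (int i = 0; i < threads; i++) {
            workers[i] = new Thread(() -> {
                for (int n = 0; n < perThread; n++) task.run();
            });
            workers[i].start();
        }
        try {
            for (Thread t : workers) t.join();
        } catch (InterruptedException e) {
            Thread.currentThread().interrupt();
            throw new IllegalStateException(e);
        }
    }

    static int racyTotal(int threads, int perThread) {
        Hashtable<String, Integer> table = new Hashtable<>();
        runConcurrently(threads, perThread, () -> racyIncrement(table, "k"));
        return table.get("k");
    }

    static int atomicTotal(int threads, int perThread) {
        ConcurrentHashMap<String, Integer> map = new ConcurrentHashMap<>();
        runConcurrently(threads, perThread, () -> atomicIncrement(map, "k"));
        return map.get("k");
    }

    public static void main(String[] args) {
        System.out.println("Hashtable get+put total: " + racyTotal(4, 100_000));
        System.out.println("ConcurrentHashMap atomic total: " + atomicTotal(4, 100_000));
    }
}
```

The Hashtable total is usually, though not necessarily on every run, below the expected 400000, while the retry loop always reaches it.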
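This time we study ConcurrentHashMap and see why it is said to be thread-safe. The code analyzed here is the segment-based implementation (as found in JDK 7). Hashtable puts its lock on the whole table, so while one thread is in put no other thread can get, and while one is in get no other can put. ConcurrentHashMap instead splits the table into several parts and locks only the part being operated on, which reduces lock waiting.
ConcurrentHashMap divides the whole table into multiple segments. Each segment is effectively a table of its own and maintains its own lock. That is the general idea.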
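The idea of one lock per part of the table can be sketched in a few lines. This is a deliberately simplified, hypothetical `StripedMap`, not the JDK code: it uses a plain `HashMap` per stripe, holds the lock in a field instead of extending `ReentrantLock`, and locks on reads as well, which the real `get` avoids.

```java
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical, simplified lock-striped map; not the JDK implementation.
final class StripedMap<K, V> {
    private static final class Stripe<K, V> {
        final ReentrantLock lock = new ReentrantLock();
        final Map<K, V> table = new HashMap<>();
    }

    private final Stripe<K, V>[] stripes;
    private final int mask;

    @SuppressWarnings("unchecked")
    StripedMap(int stripeCount) {
        if (stripeCount <= 0 || Integer.bitCount(stripeCount) != 1)
            throw new IllegalArgumentException("stripeCount must be a power of two");
        stripes = (Stripe<K, V>[]) new Stripe[stripeCount];
        for (int i = 0; i < stripeCount; i++) stripes[i] = new Stripe<>();
        mask = stripeCount - 1;
    }

    // Picks the stripe from the (lightly mixed) hash code of the key.
    private Stripe<K, V> stripeFor(Object key) {
        int h = key.hashCode();
        h ^= (h >>> 16);
        return stripes[h & mask];
    }

    // Only the stripe that owns the key is locked; other stripes stay available.
    V put(K key, V value) {
        Stripe<K, V> s = stripeFor(key);
        s.lock.lock();
        try {
            return s.table.put(key, value);
        } finally {
            s.lock.unlock();
        }
    }

    V get(Object key) {
        Stripe<K, V> s = stripeFor(key);
        s.lock.lock();
        try {
            return s.table.get(key);
        } finally {
            s.lock.unlock();
        }
    }
}
```

Two threads writing keys that fall into different stripes never wait on each other, which is the whole benefit over a single table-wide lock.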
Let's look at some of the fields:
<span style="font-size:18px;">// Default initial capacity
static final int DEFAULT_INITIAL_CAPACITY = 16;
// Default load factor
static final float DEFAULT_LOAD_FACTOR = 0.75f;
// Default concurrency level; each segment has its own lock, so this determines the number of segments
static final int DEFAULT_CONCURRENCY_LEVEL = 16;
// Maximum capacity
static final int MAXIMUM_CAPACITY = 1 << 30;
// Minimum capacity of the table inside a segment
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
// Maximum number of segments, 65536
static final int MAX_SEGMENTS = 1 << 16; // slightly conservative
// Number of unlocked retries before resorting to locking; used by size()
static final int RETRIES_BEFORE_LOCK = 2;
// Mask used together with segmentShift to find the segment for a hash; explained below
final int segmentMask;
// Shift applied to the hash before masking
final int segmentShift;
// The segments
final Segment<K,V>[] segments;
</span>
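Most of this is readable enough. The extra fields are mainly used when locating a specific segment; following the constructor makes it clearer how they are used: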
<span style="font-size:18px;">public ConcurrentHashMap(int initialCapacity,
float loadFactor, int concurrencyLevel) {
// Validate the arguments
if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
throw new IllegalArgumentException();
// Cap the number of segments
if (concurrencyLevel > MAX_SEGMENTS)
concurrencyLevel = MAX_SEGMENTS;
// Find power-of-two sizes best matching arguments
int sshift = 0;
int ssize = 1;
// Make the segment count the smallest power of two that is not less than concurrencyLevel
// For example, a concurrencyLevel of 9 to 16 gives ssize = 16, the same idea as HashMap's capacity rounding
while (ssize < concurrencyLevel) {
++sshift;
ssize <<= 1;
}
this.segmentShift = 32 - sshift; // the hash is shifted right by this many bits, keeping only its top sshift bits
this.segmentMask = ssize - 1; // the shifted hash is ANDed with this mask to get the segment index
if (initialCapacity > MAXIMUM_CAPACITY)
initialCapacity = MAXIMUM_CAPACITY;
int c = initialCapacity / ssize; // first estimate of the per-segment capacity
if (c * ssize < initialCapacity) // round up if the total would fall short of initialCapacity
++c;
int cap = MIN_SEGMENT_TABLE_CAPACITY; // 2
while (cap < c) // make each segment's capacity a power of two
cap <<= 1;
// Initialize only s0; some versions of this code initialize every segment here
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // allocate the segments array with ssize slots
UNSAFE.putOrderedObject(ss, SBASE, s0); // ConcurrentHashMap makes heavy use of Unsafe methods; look Unsafe up if it is unfamiliar
this.segments = ss;
}</span>
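To summarize what the constructor does:
1. Validates the arguments;
2. Computes the number of segments and the capacity of each segment's table, both of which are powers of two (not merely multiples of two), and initializes s0. The two fields segmentShift and segmentMask are important: together they determine which segment a key is stored in;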
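The sizing arithmetic is easier to follow with concrete numbers. The sketch below lifts the loops from the constructor into a standalone, hypothetical helper (`SegmentSizing.compute`, a name invented here) that returns the derived values instead of building a map.

```java
import java.util.Arrays;

// Hypothetical helper that replays the constructor's sizing arithmetic.
final class SegmentSizing {
    static final int MAX_SEGMENTS = 1 << 16;
    static final int MAXIMUM_CAPACITY = 1 << 30;
    static final int MIN_SEGMENT_TABLE_CAPACITY = 2;

    // Returns {ssize, segmentShift, segmentMask, cap}.
    static int[] compute(int initialCapacity, int concurrencyLevel) {
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
        int sshift = 0;
        int ssize = 1;
        while (ssize < concurrencyLevel) { // smallest power of two >= concurrencyLevel
            ++sshift;
            ssize <<= 1;
        }
        int segmentShift = 32 - sshift;
        int segmentMask = ssize - 1;
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)    // round up
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)                     // per-segment capacity, power of two
            cap <<= 1;
        return new int[] { ssize, segmentShift, segmentMask, cap };
    }

    public static void main(String[] args) {
        System.out.println(Arrays.toString(compute(16, 16)));  // [16, 28, 15, 2]
        System.out.println(Arrays.toString(compute(100, 9)));  // [16, 28, 15, 8]
    }
}
```

With the defaults, 16 segments of capacity 2 are planned, and only the top 4 bits of the hash (32 - 28) take part in choosing the segment.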
Now the put method:
<span style="font-size:18px;">public V put(K key, V value) {
Segment<K,V> s;
// Note: unlike HashMap, ConcurrentHashMap cannot store a null key or a null value
if (value == null)
throw new NullPointerException();
int hash = hash(key); // rehash the key's hashCode to spread the bits and reduce collisions
int j = (hash >>> segmentShift) & segmentMask; // the key step: shift right to keep the high bits, then mask to get the segment index
if ((s = (Segment<K,V>)UNSAFE.getObject // nonvolatile; recheck
(segments, (j << SSHIFT) + SBASE)) == null) // in ensureSegment
s = ensureSegment(j); // only s0 was created eagerly, so make sure this segment exists
return s.put(key, hash, value, false); // delegate to the segment's put
}
// The constructor only creates s0, so for any other index the segment must be created on demand before use
private Segment<K,V> ensureSegment(int k) {
final Segment<K,V>[] ss = this.segments;
long u = (k << SSHIFT) + SBASE; // raw offset of slot k in the segments array
Segment<K,V> seg;
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
Segment<K,V> proto = ss[0]; // s0 is never null, so its parameters serve as the prototype
int cap = proto.table.length;
float lf = proto.loadFactor;
int threshold = (int)(cap * lf);
HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap]; // the table for the new segment
if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // recheck
Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
== null) { // CAS loop: exits once this thread or another one has installed a segment
if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
break;
}
}
}
return seg;
}</span>
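The most important line is int j = (hash >>> segmentShift) & segmentMask;, which locates the segment. Look back at how the constructor derives these two variables and work through the bit operations: the high bits of the hash select the segment, which leaves the low bits free for indexing inside the segment's table. It is a neat piece of design. The rest is straightforward: after locating the slot, make sure the segment is not null; if it is, create one using s0 as a prototype and install it with CAS; finally call the segment's put. The Segment code is covered last.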
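A small worked example may help with the bit operations. The helper class `SegmentIndex` is hypothetical; the two expressions inside it are the ones used by `put` at the map level and inside the segment. The values 28 and 15 are what the constructor derives for 16 segments.

```java
// Hypothetical helper showing how one hash yields both a segment and a bucket.
final class SegmentIndex {
    // High bits of the hash pick the segment.
    static int segmentFor(int hash, int segmentShift, int segmentMask) {
        return (hash >>> segmentShift) & segmentMask;
    }

    // Low bits of the same hash pick the bucket inside that segment's table.
    static int bucketFor(int hash, int tableLength) {
        return (tableLength - 1) & hash;
    }

    public static void main(String[] args) {
        int hash = 0xA0000003;                          // top 4 bits = 1010, low 2 bits = 11
        System.out.println(segmentFor(hash, 28, 15));   // 10
        System.out.println(bucketFor(hash, 4));         // 3
        // With a single segment, segmentShift is 32. Java masks an int shift
        // distance to its low 5 bits, so the shift is a no-op, and the mask of 0
        // still forces the index to 0.
        System.out.println(segmentFor(0xDEADBEEF, 32, 0)); // 0
    }
}
```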
Now the get operation:
<span style="font-size:18px;">public V get(Object key) {
Segment<K,V> s; // manually integrate access methods to reduce overhead
HashEntry<K,V>[] tab;
int h = hash(key);
long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
(tab = s.table) != null) {
for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
e != null; e = e.next) {
K k;
if ((k = e.key) == key || (e.hash == h && key.equals(k)))
return e.value;
}
}
return null;
}</span>
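get reads through UNSAFE.getObjectVolatile, which gives the reads volatile semantics so that a concurrent reader does not miss the latest published value; no lock is taken. The statement that locates the segment, long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;, computes a raw offset as the base offset of the segments array plus the scaled index. Once the segment is found, the bucket in its table is read the same way, and the chain is then walked comparing keys. If the base-plus-offset arithmetic is unclear, the atomic array classes and the Unsafe code are good references.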
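The base-plus-offset arithmetic can be tried out without touching Unsafe. In the sketch below, the base of 16 bytes and the element scale of 4 bytes are assumed example values (the real ones come from `Unsafe.arrayBaseOffset` and `Unsafe.arrayIndexScale` and depend on the JVM); `ArrayOffsetMath` is a hypothetical name. `AtomicReferenceArray` is shown as the public class that offers volatile element reads.

```java
import java.util.concurrent.atomic.AtomicReferenceArray;

// Hypothetical helper for the "base + (index << shift)" offset computation.
final class ArrayOffsetMath {
    // Converts an element size in bytes (a power of two) into a shift distance.
    static int shiftFor(int scale) {
        if (scale <= 0 || (scale & (scale - 1)) != 0)
            throw new IllegalArgumentException("scale must be a power of two");
        return 31 - Integer.numberOfLeadingZeros(scale);
    }

    // Byte offset of element index, given the array's base offset and the shift.
    static long offsetOf(int index, long base, int shift) {
        return ((long) index << shift) + base;
    }

    public static void main(String[] args) {
        long base = 16;                 // assumed example value
        int shift = shiftFor(4);        // assumed 4-byte references -> shift 2
        System.out.println(offsetOf(3, base, shift)); // 28

        // Public API with the same visibility guarantee for element reads.
        AtomicReferenceArray<String> slots = new AtomicReferenceArray<>(4);
        slots.set(3, "segment-3");
        System.out.println(slots.get(3)); // volatile read of element 3
    }
}
```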
Now size():
<span style="font-size:18px;">public int size() {
// Try a few times to get accurate count. On failure due to
// continuous async changes in table, resort to locking.
final Segment<K,V>[] segments = this.segments;
int size;
boolean overflow; // true if size overflows 32 bits
long sum; // sum of modCounts
long last = 0L; // previous sum
int retries = -1; // first iteration isn't retry
try {
for (;;) {
// After three unlocked passes without a stable result, lock every segment and count again; segments that are still null are force-created here
if (retries++ == RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
ensureSegment(j).lock(); // force creation
}
sum = 0L;
size = 0;
overflow = false;
for (int j = 0; j < segments.length; ++j) {
Segment<K,V> seg = segmentAt(segments, j);
if (seg != null) {
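sum += seg.modCount; // accumulate each segment's structural modification count
int c = seg.count; // number of entries in this segment
if (c < 0 || (size += c) < 0) // guard against int overflow
overflow = true;
}
}
if (sum == last) // same modCount sum as the previous pass: the result is stable, so stop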
break;
last = sum;
}
} finally {
// unlock every segment
if (retries > RETRIES_BEFORE_LOCK) {
for (int j = 0; j < segments.length; ++j)
segmentAt(segments, j).unlock();
}
}
return overflow ? Integer.MAX_VALUE : size;
}
// Reads the segment at index j with getObjectVolatile
static final <K,V> Segment<K,V> segmentAt(Segment<K,V>[] ss, int j) {
long u = (j << SSHIFT) + SBASE;
return ss == null ? null :
(Segment<K,V>) UNSAFE.getObjectVolatile(ss, u);
}</span>
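The overall flow: the loop makes up to three unlocked passes. If two consecutive passes produce the same sum of the segments' modCount values (modCount counts modifications to a segment's table), the ConcurrentHashMap did not change in between, and the size from that pass is returned. The first pass is compared against an initial value of 0, so a map that has never been modified returns after a single pass. If the third pass still differs from the second, the next pass locks every segment in turn (creating any segment that is null), counts again, and finally unlocks them all.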
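The retry-then-lock pattern is independent of hash maps, so it can be shown on a much smaller structure. The following sketch applies the same loop to a hypothetical `StripedCounter` whose stripes each carry a `count` and a `modCount`. The fields are declared volatile here to keep the sketch simple; that is a choice made for this example, not a statement about the JDK fields.

```java
import java.util.concurrent.locks.ReentrantLock;

// Hypothetical counter that sums per-stripe counts the way size() does.
final class StripedCounter {
    static final int RETRIES_BEFORE_LOCK = 2;

    static final class Stripe {
        final ReentrantLock lock = new ReentrantLock();
        volatile int count;
        volatile int modCount;
    }

    final Stripe[] stripes;

    StripedCounter(int n) {
        stripes = new Stripe[n];
        for (int i = 0; i < n; i++) stripes[i] = new Stripe();
    }

    // Adds one element to the given stripe under that stripe's lock.
    void add(int stripeIndex) {
        Stripe s = stripes[stripeIndex];
        s.lock.lock();
        try {
            s.count++;
            s.modCount++;
        } finally {
            s.lock.unlock();
        }
    }

    int size() {
        int size;
        boolean overflow;
        long last = 0L;     // modCount sum seen by the previous pass
        int retries = -1;   // the first pass is not a retry
        try {
            for (;;) {
                if (retries++ == RETRIES_BEFORE_LOCK)   // fourth pass: lock everything
                    for (Stripe s : stripes) s.lock.lock();
                long sum = 0L;
                size = 0;
                overflow = false;
                for (Stripe s : stripes) {
                    sum += s.modCount;
                    int c = s.count;
                    if (c < 0 || (size += c) < 0)
                        overflow = true;
                }
                if (sum == last)    // nothing changed between two passes
                    break;
                last = sum;
            }
        } finally {
            if (retries > RETRIES_BEFORE_LOCK)
                for (Stripe s : stripes) s.lock.unlock();
        }
        return overflow ? Integer.MAX_VALUE : size;
    }
}
```

The optimistic passes cost nothing when the structure is quiet, and the locked pass bounds the work when writers keep interfering.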
Finally, the Segment code:
<span style="font-size:18px;">// Extends ReentrantLock so that each segment has its own lock
static final class Segment<K,V> extends ReentrantLock implements Serializable {
private static final long serialVersionUID = 2249069246763182397L;
/**
* The maximum number of times to tryLock in a prescan before
* possibly blocking on acquire in preparation for a locked
* segment operation. On multiprocessors, using a bounded
* number of retries maintains cache acquired while locating
* nodes.
*/
static final int MAX_SCAN_RETRIES =
Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
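// The table of this segment
transient volatile HashEntry<K,V>[] table;
// Number of elements
transient int count;
// Number of structural modifications
transient int modCount;
// Resize threshold
transient int threshold;
// Load factor
final float loadFactor;
/**
The ConcurrentHashMap constructor shown earlier creates s0 with: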
Segment<K,V> s0 =
new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
(HashEntry<K,V>[])new HashEntry[cap]);
*/
Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
this.loadFactor = lf;
this.threshold = threshold;
this.table = tab;
}
// The map-level put locates the segment and then calls s.put(key, hash, value, false)
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
// Try the lock first; on failure, scanAndLockForPut spins and eventually acquires it
HashEntry<K,V> node = tryLock() ? null :
scanAndLockForPut(key, hash, value);
V oldValue;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash; // bucket index of the key in this segment's table
HashEntry<K,V> first = entryAt(tab, index); // volatile read (via Unsafe) of the head of the chain in that bucket
for (HashEntry<K,V> e = first;;) {
if (e != null) { // walk the chain looking for an equal key, just as HashMap does
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
if (!onlyIfAbsent) {
e.value = value;
++modCount;
}
break;
}
e = e.next;
}
else { // key not found: link a new node at the head of the chain
if (node != null)
node.setNext(first);
else
node = new HashEntry<K,V>(hash, key, value, first);
int c = count + 1;
if (c > threshold && tab.length < MAXIMUM_CAPACITY)
rehash(node); // over the threshold: grow the table and rehash
else
setEntryAt(tab, index, node); // write the new head back into the bucket via Unsafe
++modCount;
count = c;
oldValue = null;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
/**
* Doubles the table length and puts both the existing nodes and the new node into the newly created table
*/
@SuppressWarnings("unchecked")
private void rehash(HashEntry<K,V> node) {
HashEntry<K,V>[] oldTable = table;
int oldCapacity = oldTable.length;
int newCapacity = oldCapacity << 1; // size of the new table
threshold = (int)(newCapacity * loadFactor); // new threshold
HashEntry<K,V>[] newTable =
(HashEntry<K,V>[]) new HashEntry[newCapacity]; // allocate the new table
int sizeMask = newCapacity - 1; // index mask, computed the same way as in HashMap
for (int i = 0; i < oldCapacity ; i++) { // iterate over the old table
HashEntry<K,V> e = oldTable[i];
if (e != null) {
HashEntry<K,V> next = e.next;
int idx = e.hash & sizeMask;
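if (next == null) // single node in the bucket: move it directly
newTable[idx] = e;
else { // reuse the trailing run of nodes
HashEntry<K,V> lastRun = e;
int lastIdx = idx;
// In the first loop below, last walks the chain from next onward; when it finishes, lastRun is the first node
// of the trailing run whose nodes all map to the same index in the new table.
// Example: suppose old table[1] holds 10 nodes. Computing each node's new index, it may turn out that from the
// 4th node onward every node lands in the same new bucket. Then the 4th node, with its tail, can be placed in
// newTable as-is, and only the nodes before it are re-created by the second loop, so not every node is copied.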
for (HashEntry<K,V> last = next;
last != null;
last = last.next) {
int k = last.hash & sizeMask;
if (k != lastIdx) {
lastIdx = k;
lastRun = last;
}
}
newTable[lastIdx] = lastRun; // place the reused run directly into newTable
// Copy the remaining nodes
for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
V v = p.value;
int h = p.hash;
int k = h & sizeMask;
HashEntry<K,V> n = newTable[k];
newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
}
}
}
}
int nodeIndex = node.hash & sizeMask; // add the new node to newTable
node.setNext(newTable[nodeIndex]);
newTable[nodeIndex] = node;
table = newTable;
}
/**
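* Spins on tryLock; while the lock is unavailable, scans the chain of the target bucket. If the key is not in the chain, a node is created speculatively; once the retry limit is reached the thread blocks in lock(). If the key is found, null is returned.
* If another thread changes the chain in the meantime, the scan starts over.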
*/
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
HashEntry<K,V> node = null;
int retries = -1; // negative while locating node
while (!tryLock()) {
HashEntry<K,V> f; // to recheck first below
if (retries < 0) {
if (e == null) { // key not found: create the node once; if found, just keep calling tryLock up to the limit, then block
if (node == null)
node = new HashEntry<K,V>(hash, key, value, null);
retries = 0;
}
else if (key.equals(e.key))
retries = 0;
else
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) { // retry limit exceeded (1 on a single core, 64 on multicore): block in lock()
lock();
break;
}
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) { // checked on every other retry: the head node changed, so another thread modified the chain; start over
e = first = f; // re-traverse if entry changed
retries = -1;
}
}
return node;
}
/**
Similar to scanAndLockForPut but returns nothing: it either succeeds through tryLock or falls back to a blocking lock()
*/
private void scanAndLock(Object key, int hash) {
// similar to but simpler than scanAndLockForPut
HashEntry<K,V> first = entryForHash(this, hash);
HashEntry<K,V> e = first;
int retries = -1;
while (!tryLock()) {
HashEntry<K,V> f;
if (retries < 0) {
if (e == null || key.equals(e.key))
retries = 0;
else
e = e.next;
}
else if (++retries > MAX_SCAN_RETRIES) {
lock();
break;
}
else if ((retries & 1) == 0 &&
(f = entryForHash(this, hash)) != first) {
e = first = f;
retries = -1;
}
}
}
/**
* Remove; match on key only if value null, else match both.
*/
final V remove(Object key, int hash, Object value) {
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V>[] tab = table;
int index = (tab.length - 1) & hash;
HashEntry<K,V> e = entryAt(tab, index);
HashEntry<K,V> pred = null;
while (e != null) {
K k;
HashEntry<K,V> next = e.next;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
V v = e.value;
if (value == null || value == v || value.equals(v)) {
if (pred == null)
setEntryAt(tab, index, next);
else
pred.setNext(next);
++modCount;
--count;
oldValue = v;
}
break;
}
pred = e;
e = next;
}
} finally {
unlock();
}
return oldValue;
}
final boolean replace(K key, int hash, V oldValue, V newValue) {
if (!tryLock())
scanAndLock(key, hash);
boolean replaced = false;
try {
HashEntry<K,V> e;
for (e = entryForHash(this, hash); e != null; e = e.next) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
if (oldValue.equals(e.value)) {
e.value = newValue;
++modCount;
replaced = true;
}
break;
}
}
} finally {
unlock();
}
return replaced;
}
final V replace(K key, int hash, V value) {
if (!tryLock())
scanAndLock(key, hash);
V oldValue = null;
try {
HashEntry<K,V> e;
for (e = entryForHash(this, hash); e != null; e = e.next) {
K k;
if ((k = e.key) == key ||
(e.hash == hash && key.equals(k))) {
oldValue = e.value;
e.value = value;
++modCount;
break;
}
}
} finally {
unlock();
}
return oldValue;
}
final void clear() {
lock();
try {
HashEntry<K,V>[] tab = table;
for (int i = 0; i < tab.length ; i++)
setEntryAt(tab, i, null);
++modCount;
count = 0;
} finally {
unlock();
}
}
}
@SuppressWarnings("unchecked")
static final <K,V> HashEntry<K,V> entryAt(HashEntry<K,V>[] tab, int i) {
return (tab == null) ? null :
(HashEntry<K,V>) UNSAFE.getObjectVolatile
(tab, ((long)i << TSHIFT) + TBASE);
}
/**
* Sets the ith element of given table, with volatile write
* semantics. (See above about use of putOrderedObject.)
*/
static final <K,V> void setEntryAt(HashEntry<K,V>[] tab, int i,
HashEntry<K,V> e) {
UNSAFE.putOrderedObject(tab, ((long)i << TSHIFT) + TBASE, e);
}</span>
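The lastRun trick in rehash is the part most worth tracing by hand. The sketch below replays the two loops on a hypothetical, stripped-down `Node` (hash and next only, no key or value) for a single old bucket, using a chain of six hashes chosen so that the trailing run starts at the fourth node, as in the comment's example.

```java
// Hypothetical, reduced version of the per-bucket transfer done by rehash.
final class LastRunDemo {
    static final class Node {
        final int hash;
        final Node next;
        Node(int hash, Node next) { this.hash = hash; this.next = next; }
    }

    // Builds a chain whose nodes appear in the given order.
    static Node chain(int... hashes) {
        Node head = null;
        for (int i = hashes.length - 1; i >= 0; i--) head = new Node(hashes[i], head);
        return head;
    }

    static int length(Node n) {
        int len = 0;
        for (; n != null; n = n.next) len++;
        return len;
    }

    // Moves one old bucket's chain into newTable and returns the reused node.
    static Node transfer(Node e, Node[] newTable) {
        int sizeMask = newTable.length - 1;
        Node lastRun = e;
        int lastIdx = e.hash & sizeMask;
        // Find the first node of the trailing run that shares one new index.
        for (Node last = e.next; last != null; last = last.next) {
            int k = last.hash & sizeMask;
            if (k != lastIdx) {
                lastIdx = k;
                lastRun = last;
            }
        }
        newTable[lastIdx] = lastRun;            // reuse the run without copying
        // Copy only the nodes in front of the run, linking each at the head.
        for (Node p = e; p != lastRun; p = p.next) {
            int k = p.hash & sizeMask;
            newTable[k] = new Node(p.hash, newTable[k]);
        }
        return lastRun;
    }

    public static void main(String[] args) {
        Node head = chain(1, 3, 5, 7, 3, 7);    // new indices with mask 3: 1,3,1,3,3,3
        Node[] newTable = new Node[4];
        Node reused = transfer(head, newTable);
        System.out.println(reused == head.next.next.next); // true: the 4th node is reused
        System.out.println(length(newTable[1]));           // 2
        System.out.println(length(newTable[3]));           // 4
    }
}
```

Only three of the six nodes are allocated again; the last three keep their identity, so a reader that is still walking the old chain continues to see a valid list.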
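Summary:
1. Unsafe methods are used heavily, so understanding Unsafe is important for reading this code;
2. Segment uses ReentrantLock to implement per-segment locking, which makes put thread-safe; get uses UNSAFE.getObjectVolatile to guarantee visibility;
3. Null keys and null values are not allowed;
4. get, clear, and the iterators (from entrySet, keySet, and values) of ConcurrentHashMap are only weakly consistent; this deserves further study;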
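One visible consequence of the last point can be checked from the public API alone. The hypothetical `WeakConsistencyDemo` below inserts a new key while iterating: the fail-fast HashMap iterator throws ConcurrentModificationException, while the ConcurrentHashMap iterator carries on and may or may not show the new entry.

```java
import java.util.ConcurrentModificationException;
import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.ConcurrentHashMap;

// Hypothetical demo of fail-fast versus weakly consistent iteration.
final class WeakConsistencyDemo {
    // Fills the map, then adds a new key in the middle of an iteration.
    // Returns true if the iterator reacted with ConcurrentModificationException.
    static boolean throwsOnModification(Map<Integer, String> map) {
        for (int i = 1; i <= 3; i++) map.put(i, "v" + i);
        try {
            for (Integer ignored : map.keySet()) {
                map.put(100, "added during iteration");
            }
            return false;
        } catch (ConcurrentModificationException e) {
            return true;
        }
    }

    public static void main(String[] args) {
        System.out.println(throwsOnModification(new HashMap<>()));           // true
        System.out.println(throwsOnModification(new ConcurrentHashMap<>())); // false
    }
}
```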