Implementation in JDK 7
Data structure:
In JDK 7, a ConcurrentHashMap is split into several sub-hash-tables, each of which is a Segment.
Each Segment extends ReentrantLock, so the number of Segments determines the map's concurrency level.
static final class Segment<K,V> extends ReentrantLock implements Serializable {
    private static final long serialVersionUID = 2249069246763182397L;
    transient volatile HashEntry<K,V>[] table; // this segment's hash table
    transient int count;      // number of entries in this segment
    transient int modCount;
    transient int threshold;  // resize threshold (cap * loadFactor)
    final float loadFactor;
    Segment(float lf, int threshold, HashEntry<K,V>[] tab) {
        this.loadFactor = lf; this.threshold = threshold; this.table = tab;
    }
    // put, rehash, scanAndLockForPut, etc. omitted; they are shown below
}
final Segment<K,V>[] segments;
Constructor
static final int MAX_SEGMENTS = 1 << 16; // upper bound: 2^16 segments
static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
public ConcurrentHashMap(int initialCapacity,
                         float loadFactor, int concurrencyLevel) {
    if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
        throw new IllegalArgumentException();
    if (concurrencyLevel > MAX_SEGMENTS) // keep the segment count within the limit
        concurrencyLevel = MAX_SEGMENTS;
    // Find power-of-two sizes best matching arguments
    int sshift = 0;
    int ssize = 1;
    while (ssize < concurrencyLevel) { // round the segment count up to a power of two
        ++sshift;
        ssize <<= 1;
    }
    this.segmentShift = 32 - sshift;
    this.segmentMask = ssize - 1; // mask for picking a segment, applied after the hash is shifted right by segmentShift
    if (initialCapacity > MAXIMUM_CAPACITY)
        initialCapacity = MAXIMUM_CAPACITY;
    int c = initialCapacity / ssize; // capacity of each Segment
    if (c * ssize < initialCapacity) // rounded up
        ++c;
    int cap = MIN_SEGMENT_TABLE_CAPACITY;
    while (cap < c) // each Segment's capacity is also a power of two
        cap <<= 1;
    // create segments and segments[0]
    Segment<K,V> s0 =
        new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                         (HashEntry<K,V>[])new HashEntry[cap]); // construct segment 0
    Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize]; // the segment array has ssize slots
    UNSAFE.putOrderedObject(ss, SBASE, s0); // publish segments[0]
    this.segments = ss;
}
The default constructor:
public ConcurrentHashMap() {
    this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
}
The constructor takes three arguments: int initialCapacity, float loadFactor, int concurrencyLevel — the initial total capacity, the load factor, and the concurrency level (the number of Segments).
The initial capacity is divided evenly across the Segments and rounded up, and each Segment's capacity is then rounded up to a power of two; the concurrency level is likewise rounded up to a power of two. Once fixed, the number of Segments never changes. The load factor applies per Segment.
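As a quick illustration of this rounding (a standalone sketch mirroring the logic above, with made-up arguments, not JDK code):

public class SegmentSizingDemo {
    public static void main(String[] args) {
        int initialCapacity = 33, concurrencyLevel = 10;
        int sshift = 0, ssize = 1;
        while (ssize < concurrencyLevel) { ++sshift; ssize <<= 1; } // ssize = 16, sshift = 4
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity) ++c;                       // c = 3 (ceil of 33 / 16)
        int cap = 2;                                                // MIN_SEGMENT_TABLE_CAPACITY
        while (cap < c) cap <<= 1;                                  // cap = 4
        System.out.println("segments = " + ssize + ", capacity per segment = " + cap);
    }
}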
The put process
public V put(K key, V value) {
    Segment<K,V> s;
    if (value == null)
        throw new NullPointerException();
    int hash = hash(key); // compute the hash
    int j = (hash >>> segmentShift) & segmentMask; // map the hash to segment j
    if ((s = (Segment<K,V>)UNSAFE.getObject
         (segments, (j << SSHIFT) + SBASE)) == null) // segment j has not been created yet
        s = ensureSegment(j); // initialize segment j
    return s.put(key, hash, value, false); // delegate to segment j; its put locks internally (the segment lock)
}
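For intuition about the segment-selection line above: with ssize = 16 we get sshift = 4, so segmentShift = 28 and segmentMask = 15, and the top four bits of the hash pick the segment (the numbers below are illustrative):

public class SegmentIndexDemo {
    public static void main(String[] args) {
        int segmentShift = 28, segmentMask = 15; // values for ssize = 16
        int hash = 0xA3C152F7;                   // an arbitrary example hash
        int j = (hash >>> segmentShift) & segmentMask;
        System.out.println(j);                   // 10, i.e. the top 4 bits 0xA
    }
}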
The ensureSegment(int k) method:
private Segment<K,V> ensureSegment(int k) {
    final Segment<K,V>[] ss = this.segments;
    long u = (k << SSHIFT) + SBASE; // raw offset of segments[k]
    Segment<K,V> seg;
    if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) { // initialize only if segments[k] is null
        Segment<K,V> proto = ss[0]; // use segment 0 as the prototype
        int cap = proto.table.length;
        float lf = proto.loadFactor;
        int threshold = (int)(cap * lf);
        HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
            == null) { // recheck: segments[k] may have been created in the meantime
            Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
            while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                   == null) { // keep rechecking segments[k]
                if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s)) // the CAS succeeds only if segments[k] is still null, which keeps concurrent callers correct
                    break;
            }
        }
    }
    return seg;
}
Segment.put, which actually inserts the entry:
final V put(K key, int hash, V value, boolean onlyIfAbsent) {
    HashEntry<K,V> node = tryLock() ? null :
        scanAndLockForPut(key, hash, value); // try to acquire the segment lock
    V oldValue;
    try {
        HashEntry<K,V>[] tab = table;
        int index = (tab.length - 1) & hash; // & replaces a mod of length; correct because length is a power of two
        HashEntry<K,V> first = entryAt(tab, index); // head of the chain at index
        for (HashEntry<K,V> e = first;;) {
            if (e != null) {
                K k;
                if ((k = e.key) == key ||
                    (e.hash == hash && key.equals(k))) {
                    oldValue = e.value;
                    if (!onlyIfAbsent) {
                        e.value = value;
                        ++modCount; // record the structural modification
                    }
                    break; // an equal key was found; stop scanning
                }
                e = e.next; // move to the next node in the chain
            }
            else { // reached the end of the chain without finding the key
                if (node != null)
                    node.setNext(first); // link node in at the head of the chain
                else
                    node = new HashEntry<K,V>(hash, key, value, first);
                int c = count + 1; // element count + 1
                if (c > threshold && tab.length < MAXIMUM_CAPACITY) // resize if the threshold is exceeded
                    rehash(node);
                else
                    setEntryAt(tab, index, node); // publish node as tab[index]
                ++modCount;
                count = c;
                oldValue = null;
                break;
            }
        }
    } finally {
        unlock(); // release the segment lock
    }
    return oldValue;
}
The for loop checks whether a node with an equal key already exists; if so, onlyIfAbsent decides whether its value is overwritten. If the whole chain is scanned without finding the key, the new node is inserted at the head of the chain.
After the insertion, count (the number of entries in the segment) is incremented by one, the segment is checked against its threshold to decide whether to resize, and finally the new node is published to table[index].
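A quick check of the masking trick noted in the code above (a standalone sketch; 16 and 77 are arbitrary):

public class MaskModDemo {
    public static void main(String[] args) {
        int len = 16;                         // a power-of-two table length
        int hash = 77;                        // arbitrary example hash
        System.out.println(hash % len);       // 13
        System.out.println(hash & (len - 1)); // 13: identical for powers of two
        // Unlike %, the mask also yields a valid index for negative hashes:
        System.out.println(-77 & (len - 1));  // 3, whereas -77 % 16 == -13
    }
}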
The scanAndLockForPut(K key, int hash, V value) method used while acquiring the lock:
private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
    HashEntry<K,V> first = entryForHash(this, hash);
    HashEntry<K,V> e = first;
    HashEntry<K,V> node = null;
    int retries = -1; // negative while locating node
    while (!tryLock()) {
        HashEntry<K,V> f; // to recheck first below
        if (retries < 0) {
            if (e == null) { // reached the end of the chain (or it is empty)
                if (node == null) // speculatively create the new node
                    node = new HashEntry<K,V>(hash, key, value, null);
                retries = 0;
            }
            else if (key.equals(e.key))
                retries = 0;
            else
                e = e.next;
        }
        else if (++retries > MAX_SCAN_RETRIES) { // spun too long
            lock(); // fall back to blocking
            break;
        }
        else if ((retries & 1) == 0 &&
                 (f = entryForHash(this, hash)) != first) { // another thread changed the head in the meantime; rescan the chain
            e = first = f;
            retries = -1;
        }
    }
    return node;
}
This method does two things: it spins trying to acquire the lock, and while spinning it scans the chain, speculatively creating the new node if no node with the same key is found.
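The locking pattern here is a bounded spin followed by a fall-back to a blocking lock(). A minimal standalone sketch of the same idea (MAX_SPINS is an assumed bound, not the JDK's MAX_SCAN_RETRIES):

import java.util.concurrent.locks.ReentrantLock;

class SpinThenBlock {
    private static final int MAX_SPINS = 64; // assumed bound
    private final ReentrantLock lock = new ReentrantLock();

    void withLock(Runnable critical) {
        int spins = 0;
        while (!lock.tryLock()) {  // spin: cheap attempts, no thread parking
            if (++spins > MAX_SPINS) {
                lock.lock();       // give up spinning and block
                break;
            }
        }
        try {
            critical.run();
        } finally {
            lock.unlock();
        }
    }
}

Spinning pays off when locks are held briefly, since it avoids suspending and resuming the thread; the retry bound keeps contended threads from burning CPU indefinitely.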
Resizing
When the number of entries in a segment exceeds its threshold, rehash is triggered:
private void rehash(HashEntry<K,V> node) {
    HashEntry<K,V>[] oldTable = table;
    int oldCapacity = oldTable.length;
    int newCapacity = oldCapacity << 1; // double the capacity
    threshold = (int)(newCapacity * loadFactor); // update the threshold
    HashEntry<K,V>[] newTable =
        (HashEntry<K,V>[]) new HashEntry[newCapacity]; // create the new table
    int sizeMask = newCapacity - 1;
    for (int i = 0; i < oldCapacity ; i++) {
        HashEntry<K,V> e = oldTable[i];
        if (e != null) {
            HashEntry<K,V> next = e.next;
            int idx = e.hash & sizeMask;
            if (next == null) // single node: move it directly
                newTable[idx] = e;
            else { // Reuse consecutive sequence at same slot
                HashEntry<K,V> lastRun = e;
                int lastIdx = idx;
                for (HashEntry<K,V> last = next;
                     last != null;
                     last = last.next) {
                    int k = last.hash & sizeMask;
                    if (k != lastIdx) { // find the last node whose new index differs from lastIdx
                        lastIdx = k;
                        lastRun = last;
                    }
                }
                newTable[lastIdx] = lastRun; // link lastRun and everything after it into slot lastIdx in one step
                for (HashEntry<K,V> p = e; p != lastRun; p = p.next) { // clone the nodes before lastRun into their new slots
                    V v = p.value;
                    int h = p.hash;
                    int k = h & sizeMask;
                    HashEntry<K,V> n = newTable[k];
                    newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                }
            }
        }
    }
    int nodeIndex = node.hash & sizeMask; // finally insert the new node into newTable
    node.setNext(newTable[nodeIndex]);
    newTable[nodeIndex] = node;
    table = newTable;
}
The copy is optimized: rehash first finds lastRun, the start of the longest suffix of the chain whose nodes all map to the same new index, links that entire suffix into the new table in one step, and only clones the nodes before lastRun one by one.
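A worked example of the lastRun scan (a standalone sketch with made-up hashes; new capacity 8, so the new index is hash & 7):

public class LastRunDemo {
    public static void main(String[] args) {
        int[] hashes = {1, 5, 13, 9, 1}; // chain order, head first; new indices: 1, 5, 5, 1, 1
        int sizeMask = 7;                // new capacity 8
        int lastRun = 0, lastIdx = hashes[0] & sizeMask;
        for (int i = 1; i < hashes.length; i++) {
            int k = hashes[i] & sizeMask;
            if (k != lastIdx) { lastIdx = k; lastRun = i; }
        }
        // lastRun = 3: the suffix {9, 1} is re-linked to newTable[1] in one step,
        // while {1, 5, 13} before it are cloned into newTable[1] / newTable[5].
        System.out.println("lastRun position = " + lastRun + ", target bucket = " + lastIdx);
    }
}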
The get process
public V get(Object key) {
    Segment<K,V> s;
    HashEntry<K,V>[] tab;
    int h = hash(key);
    long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE; // raw offset of the segment this hash maps to
    if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
        (tab = s.table) != null) {
        for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                 (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE); // second index computation: locate the HashEntry slot, then walk its chain
             e != null; e = e.next) {
            K k;
            if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                return e.value;
        }
    }
    return null;
}
get is short: it computes the hash once and performs two volatile reads via getObjectVolatile, the first to locate the Segment and the second to locate the head of the HashEntry chain, which is then traversed without locking.
Implementation in JDK 8
JDK 8 abandons Segment-based locking: the lock granularity shrinks to a single array slot (the head node of a bin), and red-black trees are introduced for long chains.
Constructor
public ConcurrentHashMap(int initialCapacity) {
if (initialCapacity < 0)
throw new IllegalArgumentException();
int cap = ((initialCapacity >= (MAXIMUM_CAPACITY >>> 1)) ?
MAXIMUM_CAPACITY :
tableSizeFor(initialCapacity + (initialCapacity >>> 1) + 1));
this.sizeCtl = cap;
}
private static final int tableSizeFor(int c) {
int n = c - 1;
n |= n >>> 1;
n |= n >>> 2;
n |= n >>> 4;
n |= n >>> 8;
n |= n >>> 16;
return (n < 0) ? 1 : (n >= MAXIMUM_CAPACITY) ? MAXIMUM_CAPACITY : n + 1;
}
cap is the length of the Node array and is always a power of two. tableSizeFor() rounds 1.5 × initialCapacity + 1 up to the next power of two. sizeCtl is a multi-purpose control field; here it simply holds cap, the size to use when the table is actually created (it is not a thread count — it only becomes negative while initialization or resizing is in progress).
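A quick check of that rounding (a standalone sketch of the same bit trick, with the MAXIMUM_CAPACITY clamp omitted):

public class TableSizeForDemo {
    static int tableSizeFor(int c) {
        int n = c - 1;
        n |= n >>> 1; n |= n >>> 2; n |= n >>> 4; n |= n >>> 8; n |= n >>> 16;
        return (n < 0) ? 1 : n + 1;
    }
    public static void main(String[] args) {
        // new ConcurrentHashMap<>(10): 10 + 5 + 1 = 16 -> cap = 16
        System.out.println(tableSizeFor(10 + (10 >>> 1) + 1)); // 16
        // new ConcurrentHashMap<>(16): 16 + 8 + 1 = 25 -> cap = 32
        System.out.println(tableSizeFor(16 + (16 >>> 1) + 1)); // 32
    }
}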
Initialization
private final Node<K,V>[] initTable() {
    Node<K,V>[] tab; int sc;
    while ((tab = table) == null || tab.length == 0) {
        if ((sc = sizeCtl) < 0)
            Thread.yield(); // sizeCtl < 0: another thread is initializing, so give up the CPU for now
        else if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) { // claim the initialization by CASing sizeCtl to -1
            try {
                if ((tab = table) == null || tab.length == 0) {
                    int n = (sc > 0) ? sc : DEFAULT_CAPACITY;
                    @SuppressWarnings("unchecked")
                    Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                    table = tab = nt;
                    sc = n - (n >>> 2); // sizeCtl will become 0.75n, the next resize threshold
                }
            } finally {
                sizeCtl = sc; // publish the threshold and release the claim
            }
            break;
        }
    }
    return tab;
}
Concurrency here is controlled with CAS: the thread that successfully CASes sizeCtl to -1 performs the initialization, while every other thread keeps looping (yielding the CPU) until the table is non-null.
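The same claim-by-CAS pattern in isolation (a minimal sketch using AtomicInteger instead of Unsafe; names are illustrative):

import java.util.concurrent.atomic.AtomicInteger;

class LazyTable {
    private volatile Object[] table;
    private final AtomicInteger ctl = new AtomicInteger(16); // > 0: target size; -1: initializing

    Object[] initTable() {
        Object[] tab;
        while ((tab = table) == null) {
            int sc = ctl.get();
            if (sc < 0)
                Thread.yield();                   // someone else is initializing
            else if (ctl.compareAndSet(sc, -1)) { // claim the initialization
                try {
                    if ((tab = table) == null)
                        table = tab = new Object[sc];
                    sc = sc - (sc >>> 2);         // next threshold: 0.75 * sc
                } finally {
                    ctl.set(sc);                  // publish the threshold, release the claim
                }
                break;
            }
        }
        return tab;
    }
}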
The put process
public V put(K key, V value) {
    return putVal(key, value, false);
}
/** Implementation for put and putIfAbsent */
final V putVal(K key, V value, boolean onlyIfAbsent) {
    if (key == null || value == null) throw new NullPointerException();
    int hash = spread(key.hashCode());
    int binCount = 0;
    for (Node<K,V>[] tab = table;;) {
        Node<K,V> f; int n, i, fh;
        if (tab == null || (n = tab.length) == 0) // lazily initialize the table
            tab = initTable();
        else if ((f = tabAt(tab, i = (n - 1) & hash)) == null) { // empty bin: CAS in the first node
            if (casTabAt(tab, i, null,
                         new Node<K,V>(hash, key, value, null)))
                break; // no lock when adding to empty bin
        }
        else if ((fh = f.hash) == MOVED)
            tab = helpTransfer(tab, f); // a resize is in progress: help migrate
        else { // occupied bin: insert under the bin-level lock
            V oldVal = null;
            synchronized (f) { // lock the head node of this bin
                if (tabAt(tab, i) == f) {
                    if (fh >= 0) { // head hash >= 0 means this bin is a linked list
                        binCount = 1;
                        for (Node<K,V> e = f;; ++binCount) {
                            K ek;
                            if (e.hash == hash &&
                                ((ek = e.key) == key ||
                                 (ek != null && key.equals(ek)))) {
                                oldVal = e.val;
                                if (!onlyIfAbsent)
                                    e.val = value;
                                break;
                            }
                            Node<K,V> pred = e;
                            if ((e = e.next) == null) {
                                pred.next = new Node<K,V>(hash, key,
                                                          value, null);
                                break;
                            }
                        }
                    }
                    else if (f instanceof TreeBin) {
                        Node<K,V> p;
                        binCount = 2;
                        if ((p = ((TreeBin<K,V>)f).putTreeVal(hash, key,
                                                              value)) != null) {
                            oldVal = p.val;
                            if (!onlyIfAbsent)
                                p.val = value;
                        }
                    }
                }
            }
            if (binCount != 0) {
                if (binCount >= TREEIFY_THRESHOLD) // chain length reached 8: try converting the bin to a red-black tree
                    treeifyBin(tab, i);
                if (oldVal != null)
                    return oldVal;
                break;
            }
        }
    }
    addCount(1L, binCount);
    return null;
}
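In miniature, putVal's concurrency story is: empty bins are filled with a lock-free CAS, occupied bins are locked on their head node. A stripped-down sketch of that pattern (AtomicReferenceArray stands in for the Unsafe table access; this is an illustration, not the JDK code — it omits resizing and treeification):

import java.util.concurrent.atomic.AtomicReferenceArray;

class PerBinLocking<K, V> {
    static final class Node<K, V> {
        final K key; volatile V val; volatile Node<K, V> next;
        Node(K k, V v, Node<K, V> n) { key = k; val = v; next = n; }
    }
    final AtomicReferenceArray<Node<K, V>> tab = new AtomicReferenceArray<>(16);

    V put(K key, V value) {
        int i = (tab.length() - 1) & key.hashCode();
        for (;;) {
            Node<K, V> f = tab.get(i);
            if (f == null) {
                // empty bin: no lock, just CAS in the new head
                if (tab.compareAndSet(i, null, new Node<>(key, value, null)))
                    return null;
            } else {
                synchronized (f) {            // lock only this bin's head
                    if (tab.get(i) == f) {    // recheck: the head may have changed
                        for (Node<K, V> e = f;; e = e.next) {
                            if (key.equals(e.key)) {
                                V old = e.val; e.val = value; return old;
                            }
                            if (e.next == null) {
                                e.next = new Node<>(key, value, null);
                                return null;
                            }
                        }
                    }
                }   // head changed under us: retry the outer loop
            }
        }
    }
}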
The treeifyBin(Node<K,V>[] tab, int index) method:
private final void treeifyBin(Node<K,V>[] tab, int index) {
    Node<K,V> b; int n, sc;
    if (tab != null) {
        if ((n = tab.length) < MIN_TREEIFY_CAPACITY) // MIN_TREEIFY_CAPACITY is 64
            tryPresize(n << 1); // if the table is shorter than 64, resize and rehash instead of treeifying
        else if ((b = tabAt(tab, index)) != null && b.hash >= 0) {
            synchronized (b) {
                if (tabAt(tab, index) == b) {
                    TreeNode<K,V> hd = null, tl = null;
                    for (Node<K,V> e = b; e != null; e = e.next) {
                        TreeNode<K,V> p =
                            new TreeNode<K,V>(e.hash, e.key, e.val,
                                              null, null);
                        if ((p.prev = tl) == null) // prev and next preserve the original chain order
                            hd = p;
                        else
                            tl.next = p;
                        tl = p;
                    }
                    setTabAt(tab, index, new TreeBin<K,V>(hd));
                    // install the built red-black tree at this slot
                }
            }
        }
    }
}
The tryPresize(int size) resizing method:
private final void tryPresize(int size) {
    int c = (size >= (MAXIMUM_CAPACITY >>> 1)) ? MAXIMUM_CAPACITY :
        tableSizeFor(size + (size >>> 1) + 1); // compute the target table size
    int sc;
    while ((sc = sizeCtl) >= 0) {
        Node<K,V>[] tab = table; int n;
        if (tab == null || (n = tab.length) == 0) { // the table has not been initialized yet
            n = (sc > c) ? sc : c;
            if (U.compareAndSwapInt(this, SIZECTL, sc, -1)) {
                try {
                    if (table == tab) {
                        @SuppressWarnings("unchecked")
                        Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n];
                        table = nt;
                        sc = n - (n >>> 2); // threshold for the next resize
                    }
                } finally {
                    sizeCtl = sc;
                }
            }
        }
        else if (c <= sc || n >= MAXIMUM_CAPACITY)
            break;
        else if (tab == table) {
            int rs = resizeStamp(n);
            if (sc < 0) { // a resize is already running
                Node<K,V>[] nt;
                if ((sc >>> RESIZE_STAMP_SHIFT) != rs || sc == rs + 1 ||
                    sc == rs + MAX_RESIZERS || (nt = nextTable) == null ||
                    transferIndex <= 0) // the resize is finishing or cannot be helped; stop
                    break;
                if (U.compareAndSwapInt(this, SIZECTL, sc, sc + 1))
                    transfer(tab, nt); // join as a helper
            }
            else if (U.compareAndSwapInt(this, SIZECTL, sc,
                                         (rs << RESIZE_STAMP_SHIFT) + 2))
                transfer(tab, null); // first thread to start this resize
        }
    }
}
During a resize, the first thread sets sizeCtl to a large negative value derived from resizeStamp(n); each thread that joins increments sizeCtl by one, and decrements it by one again when it finishes its share of the migration.
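For reference, the stamp comes from resizeStamp, shown roughly as in the JDK 8 source below: the high 16 bits of sizeCtl encode the old table length, the low 16 bits count resizing threads plus one, and the forced high bit makes the shifted stamp negative:

static final int RESIZE_STAMP_BITS = 16;
static final int RESIZE_STAMP_SHIFT = 32 - RESIZE_STAMP_BITS;

static int resizeStamp(int n) {
    return Integer.numberOfLeadingZeros(n) | (1 << (RESIZE_STAMP_BITS - 1));
}

// e.g. n = 16: numberOfLeadingZeros(16) = 27, so rs = 27 | 0x8000 = 0x801B,
// and the first resizer sets sizeCtl = (0x801B << 16) + 2, a negative value.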
The heart of the resize is the transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) method:
private final void transfer(Node<K,V>[] tab, Node<K,V>[] nextTab) {
    int n = tab.length, stride;
    if ((stride = (NCPU > 1) ? (n >>> 3) / NCPU : n) < MIN_TRANSFER_STRIDE)
        stride = MIN_TRANSFER_STRIDE; // derive the per-claim stride from the CPU count
    if (nextTab == null) { // initialize the nextTab array
        try {
            @SuppressWarnings("unchecked")
            Node<K,V>[] nt = (Node<K,V>[])new Node<?,?>[n << 1]; // double the capacity
            nextTab = nt;
        } catch (Throwable ex) { // try to cope with OOME
            sizeCtl = Integer.MAX_VALUE;
            return;
        }
        nextTable = nextTab; // publish the new array via the nextTable field
        // transferIndex marks the migration frontier; it is a field of the map, and
        // threads resizing the same map concurrently coordinate their claimed ranges
        // through it; it is initialized together with the new array
        transferIndex = n;
    }
    int nextn = nextTab.length;
    // A ForwardingNode marks a bin whose migration is done; it holds a Node array,
    // so nextTab is passed to its constructor, letting reads that arrive during
    // the resize be redirected to the already-migrated data
    ForwardingNode<K,V> fwd = new ForwardingNode<K,V>(nextTab);
    boolean advance = true; // set when the current slot is done and the next one may be claimed
    boolean finishing = false; // to ensure sweep before committing nextTab
    // i is the current index, bound the lower bound; migration runs from high indices down
    for (int i = 0, bound = 0;;) {
        Node<K,V> f; int fh;
        // advance means the next slot can be processed; this while loop positions i
        // on the slot to migrate next
        while (advance) {
            int nextIndex, nextBound;
            if (--i >= bound || finishing)
                advance = false;
            else if ((nextIndex = transferIndex) <= 0) {
                // transferIndex <= 0 means every slot has already been claimed
                i = -1;
                advance = false;
            }
            else if (U.compareAndSwapInt
                     (this, TRANSFERINDEX, nextIndex,
                      nextBound = (nextIndex > stride ?
                                   nextIndex - stride : 0))) {
                // CAS transferIndex down to nextBound, claiming this range for the thread
                bound = nextBound;
                i = nextIndex - 1;
                advance = false;
            }
        }
        // i out of bounds: this thread has no more slots to migrate (or the final recheck is done)
        if (i < 0 || i >= n || i + n >= nextn) {
            int sc;
            if (finishing) { // the whole migration is complete
                nextTable = null;
                table = nextTab;
                sizeCtl = (n << 1) - (n >>> 1); // new threshold: 0.75 * (2n)
                return;
            }
            // CAS sizeCtl - 1: tryPresize did +1 before this thread joined the migration,
            // so -1 here means the thread has finished its own share
            if (U.compareAndSwapInt(this, SIZECTL, sc = sizeCtl, sc - 1)) {
                if ((sc - 2) != resizeStamp(n) << RESIZE_STAMP_SHIFT)
                    return; // stamp check: other threads are still migrating, so just leave
                finishing = advance = true;
                i = n; // recheck before commit
            }
        }
        else if ((f = tabAt(tab, i)) == null) // empty slot: install the ForwardingNode created above
            advance = casTabAt(tab, i, null, fwd);
        else if ((fh = f.hash) == MOVED) // this slot has already been migrated
            advance = true; // already processed
        else {
            synchronized (f) {
                if (tabAt(tab, i) == f) {
                    Node<K,V> ln, hn;
                    if (fh >= 0) { // head hash >= 0: the bin is a linked list
                        int runBit = fh & n;
                        Node<K,V> lastRun = f;
                        for (Node<K,V> p = f.next; p != null; p = p.next) {
                            int b = p.hash & n;
                            if (b != runBit) {
                                runBit = b;
                                lastRun = p;
                            }
                        }
                        if (runBit == 0) {
                            ln = lastRun;
                            hn = null;
                        }
                        else {
                            hn = lastRun;
                            ln = null;
                        }
                        for (Node<K,V> p = f; p != lastRun; p = p.next) {
                            int ph = p.hash; K pk = p.key; V pv = p.val;
                            if ((ph & n) == 0)
                                ln = new Node<K,V>(ph, pk, pv, ln);
                            else
                                hn = new Node<K,V>(ph, pk, pv, hn);
                        }
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                    else if (f instanceof TreeBin) { // migrate a tree bin
                        TreeBin<K,V> t = (TreeBin<K,V>)f;
                        TreeNode<K,V> lo = null, loTail = null;
                        TreeNode<K,V> hi = null, hiTail = null;
                        int lc = 0, hc = 0;
                        for (Node<K,V> e = t.first; e != null; e = e.next) {
                            int h = e.hash;
                            TreeNode<K,V> p = new TreeNode<K,V>
                                (h, e.key, e.val, null, null);
                            if ((h & n) == 0) { // the hash bit corresponding to n splits nodes into low and high halves
                                if ((p.prev = loTail) == null)
                                    lo = p;
                                else
                                    loTail.next = p;
                                loTail = p;
                                ++lc;
                            }
                            else {
                                if ((p.prev = hiTail) == null)
                                    hi = p;
                                else
                                    hiTail.next = p;
                                hiTail = p;
                                ++hc;
                            }
                        }
                        // each half becomes a linked list or a tree depending on its node count
                        ln = (lc <= UNTREEIFY_THRESHOLD) ? untreeify(lo) :
                            (hc != 0) ? new TreeBin<K,V>(lo) : t;
                        hn = (hc <= UNTREEIFY_THRESHOLD) ? untreeify(hi) :
                            (lc != 0) ? new TreeBin<K,V>(hi) : t;
                        setTabAt(nextTab, i, ln);
                        setTabAt(nextTab, i + n, hn);
                        setTabAt(tab, i, fwd);
                        advance = true;
                    }
                }
            }
        }
    }
}
transfer migrates the old array's data into nextTab, which is twice as large, initializing nextTab first if it is null. Because the resize may be driven by several threads at once, the work is split into stride-sized ranges of slots that threads claim one at a time.
Rather than pre-assigning a fixed range to each thread, the claims are coordinated through the transferIndex field: a thread CASes transferIndex downward to claim its range, which guarantees that no range is migrated twice.
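The range-claiming idea in isolation (a standalone sketch with an AtomicInteger in place of the TRANSFERINDEX field):

import java.util.concurrent.atomic.AtomicInteger;

class RangeClaimer {
    final AtomicInteger transferIndex; // next unclaimed position (exclusive)
    final int stride;                  // slots per claim

    RangeClaimer(int tableLength, int stride) {
        this.transferIndex = new AtomicInteger(tableLength);
        this.stride = stride;
    }

    /** Claims [result[0], result[1]] to migrate, high to low; null when nothing is left. */
    int[] claim() {
        for (;;) {
            int nextIndex = transferIndex.get();
            if (nextIndex <= 0)
                return null;           // the whole table has been claimed
            int nextBound = nextIndex > stride ? nextIndex - stride : 0;
            if (transferIndex.compareAndSet(nextIndex, nextBound))
                return new int[] { nextBound, nextIndex - 1 }; // each range is claimed exactly once
        }
    }
}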
To keep get working on bins that have already been migrated while the resize is still in progress, the resizer creates a ForwardingNode fwd whose constructor stores nextTab as its nextTable, and installs fwd in each migrated slot; a read that hits a ForwardingNode is redirected to the new table.
In other words, during a resize a lookup consults the new table for data that has already been migrated, and the old table for data that has not.
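Conceptually, the redirect looks like the following sketch (simplified from the JDK 8 ForwardingNode; this is an illustration, not the JDK code — the real version reads the slot volatilely and handles null keys in the new table):

public class ForwardingDemo {
    static final int MOVED = -1;

    static class Node<K, V> {
        final int hash; final K key; volatile V val; volatile Node<K, V> next;
        Node(int hash, K key, V val, Node<K, V> next) {
            this.hash = hash; this.key = key; this.val = val; this.next = next;
        }
    }

    static final class ForwardingNode<K, V> extends Node<K, V> {
        final Node<K, V>[] nextTable;
        ForwardingNode(Node<K, V>[] nextTable) {
            super(MOVED, null, null, null);
            this.nextTable = nextTable;
        }
    }

    static <K, V> V get(Node<K, V>[] tab, int h, Object key) {
        Node<K, V> e = tab[(tab.length - 1) & h];
        if (e != null && e.hash == MOVED) // migrated bin: retry against the new table
            return get(((ForwardingNode<K, V>) e).nextTable, h, key);
        for (; e != null; e = e.next)
            if (e.hash == h && key.equals(e.key))
                return e.val;
        return null;
    }
}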
Reference: 《Java并发实现原理:JDK源码剖析》