Java源码解读系列3—ConcurrentHashMap(JDK1.7 )

1 概述

普通的的curd业务工作,一般都是单线程居多,key-value操作基本是HashMap一招吃遍天下鲜。博主由于工作原因,每天工作需要使用大量多线程技术,因此本文不是定位为解释ConcurrentHashMap中的每一行代码,而是从解决并发的视角去思考,为什么ConcurrentHashMap能用于多线程环境!
涉及到并发场景,我们可以使用线程安全容器HashTable和ConcurrentHashMap替代HashMap。HashTable解决多线程访问时对每个方法都加了synchronized,这样实现虽然简单易懂,但是每次只运行一个线程访问,效率低下。所以 Doug Lea老爷子设计了高性能线程安全容器ConcurrentHashMap。

2 设计原理

JKD7的ConcurrentHashMap采用分段锁方式,提高多线程并发的同时保证了线程安全。ConcurrentHashMap的数据结构最外层是一个Segment数组(默认容量为16),Segment类继承了ReentrantLock,每次进行插入或删除等更新操作时,就会对线程访问的每个Segment元素独立上锁,这样既保证线程安全,同时最多允许16个线程对ConcurrentHashMap进行更新操作,大大提高并发性能。
每个Segment里面包含了一个HashEntry数组,HashEntry类里面含有一个next指针,用于指向下一个HashEntry。采用链地址法的设计是为了解决哈希碰撞。

ConcurrentHashMap针对线程安全部分,使用CAS+Unsafe进行解决,如果读者想深入研究,建议阅读博主的《Java源码解读系列2—Unsafe(JDK1.7 )》

Java源码解读系列3—ConcurrentHashMap(JDK1.7 )

2 构造函数

       //默认容量
       static final int DEFAULT_INITIAL_CAPACITY = 16;
   
       //负载因子
       static final float DEFAULT_LOAD_FACTOR = 0.75f;
        
        //默认并发级别
       static final int DEFAULT_CONCURRENCY_LEVEL = 16;
       
       //Segment数组最大容量
       static final int MAX_SEGMENTS = 1 << 16;
       
       //最大容量
       static final int MAXIMUM_CAPACITY = 1 << 30;
       
       //Segment类中包含的HashEntry数组的最小容量
        static final int MIN_SEGMENT_TABLE_CAPACITY = 2;
      
       //默认构造函数
       public ConcurrentHashMap() {
        this(DEFAULT_INITIAL_CAPACITY, DEFAULT_LOAD_FACTOR, DEFAULT_CONCURRENCY_LEVEL);
       }
    
      //有参构造函数
       public ConcurrentHashMap(int initialCapacity,float loadFactor, int concurrencyLevel) {
         //参数校验
        if (!(loadFactor > 0) || initialCapacity < 0 || concurrencyLevel <= 0)
            throw new IllegalArgumentException();
        if (concurrencyLevel > MAX_SEGMENTS)
            concurrencyLevel = MAX_SEGMENTS;
      
        //移动次数
        int sshift = 0;

        //segment数组的容量
        int ssize = 1;
        
        //ssize是不小于concurrencyLevel的最小2^n
        while (ssize < concurrencyLevel) {
            ++sshift;
            ssize <<= 1;
        }
        
        //用于定位Segment
        this.segmentShift = 32 - sshift
        this.segmentMask = ssize - 1;
        
        if (initialCapacity > MAXIMUM_CAPACITY)
            initialCapacity = MAXIMUM_CAPACITY;
        int c = initialCapacity / ssize;
        if (c * ssize < initialCapacity)
            ++c;
        int cap = MIN_SEGMENT_TABLE_CAPACITY;
        while (cap < c)
            cap <<= 1;
       
       //创建一个Segment元素s0,后面创建其他egment元素都是基于s0进行复制
        Segment<K,V> s0 =
            new Segment<K,V>(loadFactor, (int)(cap * loadFactor),
                             (HashEntry<K,V>[])new HashEntry[cap]);
                             
         //创建Segment数组,容量为ssize
        Segment<K,V>[] ss = (Segment<K,V>[])new Segment[ssize];
       
       //将s0添加到Segment数组中
        UNSAFE.putOrderedObject(ss, SBASE, s0); 
        this.segments = ss;
    }
    

3 新增元素(put方法)

put方法将键值对存储到ConcurrentHashMap中,需要加锁操作。

   public V put(K key, V value) {
        Segment<K,V> s;
        //非空检查
        if (value == null)
            throw new NullPointerException();
          
          //通过哈希值hash找到待插入元素在segment数组中的下标
        int hash = hash(key);
        int j = (hash >>> segmentShift) & segmentMask;
        
        //获取Segment数组下标为j 的元素
        if ((s = (Segment<K,V>)UNSAFE.getObject     
             (segments, (j << SSHIFT) + SBASE)) == null)   
             //获取的segments元素为空,需要对Segment数组下标为j的位置进行初始化
            s = ensureSegment(j);
            
            //存储新的键值对
        return s.put(key, hash, value, false);
    }

ensureSegment方法采用的是乐观锁方式对Segment数组待插入下标的元素进行更新。

 private Segment<K,V> ensureSegment(int k) {
        final Segment<K,V>[] ss = this.segments;
        
        //获取待插入元素在segment数组中偏移量的位置
        long u = (k << SSHIFT) + SBASE; // raw offset
        Segment<K,V> seg;
        if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u)) == null) {
        
           //使用segment数组第0位置的元素为模板,复制一个新的元素
            Segment<K,V> proto = ss[0]; 
            int cap = proto.table.length;
            float lf = proto.loadFactor;
            int threshold = (int)(cap * lf);
            HashEntry<K,V>[] tab = (HashEntry<K,V>[])new HashEntry[cap];
            
            //再次检查是否被别的线程初始化了
            if ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                == null) {
                Segment<K,V> s = new Segment<K,V>(lf, threshold, tab);
                
                //通过CAS操作进行原子更新,对Segment数组偏移量为u的元素进行初始化
                while ((seg = (Segment<K,V>)UNSAFE.getObjectVolatile(ss, u))
                       == null) {
                    if (UNSAFE.compareAndSwapObject(ss, u, null, seg = s))
                        break;
                }
            }
        }
        return seg;
    }

操作数据时,先将key所在segment元素加锁,不是整个segment数组加锁,再执行添加操作。执行成功后会执行modCoun数值加1操作,用于记录发生的修改次数。这里有个问题,大家都是公共变量,为什么modCoun不需要用volatile修饰保证内存可见性,而table需要呢?这是因为线程只有拿到锁后才能执行更新操作,更新成功后只从公共堆栈拿取一次modCoun的值。因为有锁操作,每次只允许一个线程执行更新操作,当线程A把锁换回去时候,它已经结束生命周期,公共堆栈中的modCount的值已经发生改变,线程B此时再去读取已经能读到最新的值。
相反当线程A拿到锁后在操作table,线程B虽然没拿到锁,但是线程B可能在自旋操作中,也是会访问table,因此需要马上读取到最新的值,所以需要使用volatile修饰公共变量。

        //Segment内部维护的HashEntry数组
        transient volatile HashEntry<K,V>[] table;

         //记录更新操作次数
        transient int modCount;

        final V put(K key, int hash, V value, boolean onlyIfAbsent) {
             //判断是否获取锁
            HashEntry<K,V> node = tryLock() ? null :
                scanAndLockForPut(key, hash, value);
            V oldValue;
            try {
              //通过哈希值hash找到待插入元素在HashEntry数组中的位置
                HashEntry<K,V>[] tab = table;
                int index = (tab.length - 1) & hash;
                //获取HashEntry链表的头节点
                HashEntry<K,V> first = entryAt(tab, index);
                
                //遍历HashEntry链表
                for (HashEntry<K,V> e = first;;) {
                    if (e != null) {
                        K k;
                        //找到key相等的节点
                        if ((k = e.key) == key ||
                            (e.hash == hash && key.equals(k))) {
                            oldValue = e.value;
                          
                            if (!onlyIfAbsent) {
                               //将新的value值添加到目标节点中
                                e.value = value;
                                //更新操作次数加1                                               
                                ++modCount;
                            }
                            //跳出循环
                            break;
                        }
                        //没有找到目标节点,继续寻找下一个节点
                        e = e.next;
                    }
                    else {
                    //node不为空,使用头插法将新节点插入到链表头节点位置
                        if (node != null)
                            node.setNext(first);
                        else
                        //node为空,创建一个新的HashEntry元素,并且指向first节点
                            node = new HashEntry<K,V>(hash, key, value, first);
                        int c = count + 1;
                        //超过扩容阀值且小于默认最大容量2^30
                        if (c > threshold && tab.length < MAXIMUM_CAPACITY)
                           //扩容HashEntry数组
                            rehash(node);
                        else
                        //将新创建的HashEntry元素更新到链表
                            setEntryAt(tab, index, node);
                        ++modCount;
                        count = c;
                        oldValue = null;
                        break;
                    }
                }
            } finally {
            //释放锁
                unlock();
            }
            //返回旧值
            return oldValue;
        }

Segment类继承ReentrantLock,每次并发操作时可以单独对Segment数组中的每一个元素加锁,提高并发度(并发度为Segment数组的长度,默认值为16),scanAndLockForPut操作可以保证一定获取锁。scanAndLockForPut(K key, int hash, V value)方法中while (!tryLock()) {…}部分代码是一个非常优秀的设计。如果使用lock()方法,获取不到锁线程就会进入阻塞状态,存在线程状态切换的代价。Doug Lea老爷子使用tryLock()+MAX_SCAN_RETRIES的设计方法。tryLock()返回值是Boolean类型,代表尝试获取锁的结果,获取不到锁时线程不会进入阻塞状态。在获取不到锁时,最大自旋次数为MAX_SCAN_RETRIES(多核CPU为64次),当自选次数达到 MAX_SCAN_RETRIES时,线程进入阻塞状态,从而避免该线程在没获取锁的时候还一直占用一个CPU用于while循环。


 static final class Segment<K,V> extends ReentrantLock implements Serializable {
 
 //最大的自旋次数,多核CPU最多为64次,单核CPU为1次
 static final int MAX_SCAN_RETRIES =
            Runtime.getRuntime().availableProcessors() > 1 ? 64 : 1;
            
            ......
            
   
      /**
      *  通过调用此方法,保证一定能获取锁
      * /
 private HashEntry<K,V> scanAndLockForPut(K key, int hash, V value) {
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            HashEntry<K,V> node = null;
            int retries = -1;
            
            //线程获取锁时才能跳出循环
            while (!tryLock()) {
                HashEntry<K,V> f;
                
                if (retries < 0) {
                    //e和node为空时,创建一个新的节点
                   //满足此条件只要3种情况:1)Segment为空2 )Segment中的HashEntry数组为空时 3)遍历完HashEntry链表,没有发现key相同的节点
                    if (e == null) {
                        if (node == null) 
                            node = new HashEntry<K,V>(hash, key, value, null);
                        retries = 0;
                    }
                    //遍历HashEntry链表过程中,找到与待插入元素的key相同的节点
                    else if (key.equals(e.key))
                        retries = 0;
                        
                      //遍历HashEntry链表
                    else
                        e = e.next;
                }
                
                //自旋次数大于MAX_SCAN_RETRIES,线程进入阻塞状态
                else if (++retries > MAX_SCAN_RETRIES) {
                    lock();
                    break;
                }
                
                //retries为偶数次且HashEntry链表的头节点被其他线程改变了
                else if ((retries & 1) == 0 &&
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f; // re-traverse if entry changed
                    retries = -1;
                }
            }
            return node;
        }



    /**
      *  通过哈希值h获取segment中HashEntry数组索引位置为((tab.length - 1) &       *   h)的HashEntry链表中的头元素
      * /
 static final <K,V> HashEntry<K,V> entryForHash(Segment<K,V> seg, int h) {
        HashEntry<K,V>[] tab;
        
        //当seg为空或者seg中的HashEntry数组为空时,返回值为null
        return (seg == null || (tab = seg.table) == null) ? null :
            (HashEntry<K,V>) UNSAFE.getObjectVolatile
            
            //获取tab下标为(tab.length - 1) & h)的元素
            (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
    }
    
    ......
    }

4 扩容HashEntry数组(rehash方法)

ConcurrentHashMap类里面的扩容操作,是指每个Segment类里面的HashEntry数组扩容,而不是扩容Segment数组。

HashEntry数组扩容后容量为旧的HashEntry数组的两倍。sizeMask代表数组的索引,即每个元素 哈希后能落到的范围是【0,sizeMask】。扩容前sizeMask为15,其二进制值为 0000 1111,即每个元素得哈希值只能落到[0, 15】。扩容后sizeMask为31,二进制值为 0001 1111,每个元素得哈希能落到[0, 31】比扩容前大了16,因此扩容前落到index位置的元素,在扩容后只可能到i或者index+16,这里的描述是第一次扩容后场景,后续扩容的原理类似不再累述。

  private void rehash(HashEntry<K,V> node) {
            HashEntry<K,V>[] oldTable = table;
            int oldCapacity = oldTable.length;
            int newCapacity = oldCapacity << 1;
            threshold = (int)(newCapacity * loadFactor);
            HashEntry<K,V>[] newTable =
                (HashEntry<K,V>[]) new HashEntry[newCapacity];
            int sizeMask = newCapacity - 1;
            
            
            //下面的for循环主要通过计算元素的哈希值 & sizeMask,获得新的索引,把链表每个旧的元素移动到正确的位置index或者index+16,具体实现比较绕。
            
            for (int i = 0; i < oldCapacity ; i++) {
                HashEntry<K,V> e = oldTable[i];
                if (e != null) {
                    HashEntry<K,V> next = e.next;
                    int idx = e.hash & sizeMask;
                    
                   //元素e不存在下一个节点,不需要移动后续元素
                    if (next == null) 
                        newTable[idx] = e;
                    else { 
                        HashEntry<K,V> lastRun = e;
                        int lastIdx = idx;
                        
                        //遍历整个链表,找到最后与上一个元素的哈希值 & sizeMask 运算结果不同的元素lastRun
                        for (HashEntry<K,V> last = next;
                             last != null;
                             last = last.next) {
                            int k = last.hash & sizeMask;
                            if (k != lastIdx) {
                                lastIdx = k;
                                lastRun = last;
                            }
                        }
                        
                        //将lastRun置于lastIdx位置的首个元素,因为lastRun的哈希值 & sizeMask运算值与后面的节点都相同,因此lastRun置于首位,后面的元素就不需要移动了
                        newTable[lastIdx] = lastRun;
                       
                       
                       // 从链表头开始计算,将lastRun前的元素移动到正确位置
                        for (HashEntry<K,V> p = e; p != lastRun; p = p.next) {
                            V v = p.value;
                            int h = p.hash;
                            int k = h & sizeMask;
                            HashEntry<K,V> n = newTable[k];
                            newTable[k] = new HashEntry<K,V>(h, p.key, v, n);
                        }
                    }
                }
            }
            
            //采用头插法,将新增的元素置于新数组nodeIndex位置的首位
            int nodeIndex = node.hash & sizeMask; 
            node.setNext(newTable[nodeIndex]);
            newTable[nodeIndex] = node;
            table = newTable;
        }

5 查询元素(get方法)

查询方法不涉及锁,整个方法主要做了3件事,第一是在Segment数组中寻找
符合要求的索引,即Segment元素;第二在Segment元素中的HashEntry数组寻找符合要求的索引,即HashEntry元素,最后遍历HashEntry链表,查看是否包含key的节点,有就返回,没有就返回null。

 public V get(Object key) {
        Segment<K,V> s; 
        HashEntry<K,V>[] tab;
        int h = hash(key);
        
        //查找Segment数组中对应索引
        long u = (((h >>> segmentShift) & segmentMask) << SSHIFT) + SBASE;
        if ((s = (Segment<K,V>)UNSAFE.getObjectVolatile(segments, u)) != null &&
            (tab = s.table) != null) {
            
            //(HashEntry<K,V>) UNSAFE.getObjectVolatile (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE)获取HashEntry数组符合要求的索引,这里使用getObjectVolatile方法保证获取的公共堆栈中的变量值,而不是从线程的私有堆栈中获取变量值
                     
            for (HashEntry<K,V> e = (HashEntry<K,V>) UNSAFE.getObjectVolatile
                     (tab, ((long)(((tab.length - 1) & h)) << TSHIFT) + TBASE);
                 e != null; e = e.next) {   //遍历HashEntry链表
                K k;
                //找到HashEntry链表中key相等的节点
                if ((k = e.key) == key || (e.hash == h && key.equals(k)))
                    return e.value;
            }
        }
        return null;
    }

6 删除元素 (remove方法)

remove操作是根据key删除HashEntry的节点

 public V remove(Object key) {
        int hash = hash(key);
        
        //通过哈希值hash找到segment数组中对应下标
        Segment<K,V> s = segmentForHash(hash);
        
        //如果s为空则直接返回,否则执行remove(Object key, int hash, Object value)方法
        return s == null ? null : s.remove(key, hash, null);
    }
    

remove(Object key, int hash, Object value) 方法第一步是获取锁,第二步通过哈希值找到HashEntry数组对应下标,第三步遍历整个HashEntry链表找到key相等的节点,执行删除操作,最后释放锁。这些套路跟前面的方法都类似。

final V remove(Object key, int hash, Object value) {
            //尝试获取锁
            if (!tryLock())
            //调用scanAndLock(Object key, int hash) 函数,保证一定能获得锁
                scanAndLock(key, hash);
            V oldValue = null;
            try {
                HashEntry<K,V>[] tab = table;
                
                //通过哈希值找到HashEntry数组对应下标
                int index = (tab.length - 1) & hash;
                HashEntry<K,V> e = entryAt(tab, index);
                
                HashEntry<K,V> pred = null;
                
                //遍历链表
                while (e != null) {
                    K k;
                    HashEntry<K,V> next = e.next;    
                    //找到key相等的节点
                    if ((k = e.key) == key ||
                        (e.hash == hash && key.equals(k))) {
                        V v = e.value;
                        if (value == null || value == v || value.equals(v)) {
                        //如果e的上一个节点为空,将e的下一个节点next移动到e的位置 上       
                            if (pred == null)
                                setEntryAt(tab, index, next);
                                
                            else
                            //如果e的上一个节点为非空,则将上一个节点pred的指针指向e的下一个节点next
                                pred.setNext(next);
                                
                                //修改次数加1
                            ++modCount;
                            
                            //ConcurrentHashMap中存储的元素数量减1
                            --count;
                            oldValue = v;
                        }
                        //跳出循环
                        break;
                    }
                    
                    //找不到合适节点,继续遍历HashEntry链表
                    pred = e;
                    e = next;
                }
            } finally {
                //释放锁
                unlock();
            }
            //返回被删除的元素
            return oldValue;
        }

涉及节点更新操作,需要加锁。 调用scanAndLock(Object key, int hash) 方法保证线程一定能获取锁,跟新增操作的scanAndLockForPut(K key, int hash, V value)方法类似。

private void scanAndLock(Object key, int hash) {
            // similar to but simpler than scanAndLockForPut
            HashEntry<K,V> first = entryForHash(this, hash);
            HashEntry<K,V> e = first;
            int retries = -1;
            while (!tryLock()) {
                HashEntry<K,V> f;
                if (retries < 0) {
                //遍历完HashEntry链表,没找到key相等的节点或者 找到key相等的节点
                    if (e == null || key.equals(e.key))
                        retries = 0;
                    else
                        e = e.next;
                }
                else if (++retries > MAX_SCAN_RETRIES) {
                    lock();
                    break;
                }
                else if ((retries & 1) == 0 &&
                         (f = entryForHash(this, hash)) != first) {
                    e = first = f;
                    retries = -1;
                }
            }
        }

调用entryAt函数,通过索引i获取HashEntry数组i位置的元素


static final <K,V> HashEntry<K,V> entryAt(HashEntry<K,V>[] tab, int i) {
        return (tab == null) ? null :
           //通过getObjectVolatile保证拿到公共堆栈里面的值
            (HashEntry<K,V>) UNSAFE.getObjectVolatile
            (tab, ((long)i << TSHIFT) + TBASE);
    }

调用setEntryAt函数,将元素e写入到HashEntry数组i位置上

static final <K,V> void setEntryAt(HashEntry<K,V>[] tab, int i,
                                       HashEntry<K,V> e) {
                                       
         //putOrderedObject是有序延、迟版本的putObjectVolatile,不保证
         修改的值的被其他线程立即看到,优点是性能高于putObjectVolatile          
        UNSAFE.putOrderedObject(tab, ((long)i << TSHIFT) + TBASE, e);
    }

7 参考文献

  1. JDK7在线文档
    https://tool.oschina.net/apidocs/apidoc?api=jdk_7u4
  2. Bruce Eckel,Java编程思想 第4版. 2007, 机械工业出版社ConcurrentHashMap
上一篇:2021-07-01


下一篇:【操作】Win7怎么同时连接内外网?(无线网络和本地连接无法同时使用)