关于ThreadLocal你要知道的一切

ThreadLocal是线程本地变量,可以应用在分布式系统追踪,事务管理方面,这里先提出几个较难的问题

  • ThreadLocal的内部大致实现原理?
  • ThreadLocal的Entry的key为何设计成弱引用?
  • ThreadLocal的hash碰撞是如何处理的?
  • ThreadLocal如何处理主线程传值到子线程?
  • 如何让子线程跟随主线程tl值变化而变化?
  • 线程池中使用threadLocal如何保持不变?
  • 阿里TransmittableThreadLocal?

ThreadLocal的内部大致实现原理?

ThreadLocal内部类有一个内部类ThreadLocalMap, ThreadLocalMap有一个内部类Entry<tl,value>,Entry的key是弱引用。
关于ThreadLocal你要知道的一切

Thread类有个属性threadLocals,类型是ThreadLocalMap。
关于ThreadLocal你要知道的一切

当你执行threadLocal.get的时候,他会先获得当前线程,从线程里拿到属性threadLocals,然后从map中获取值
关于ThreadLocal你要知道的一切

ThreadLocal的Entry的key为何设计成弱引用?

实线为强引用,虚线为弱引用
关于ThreadLocal你要知道的一切设想如果是强引用会发生什么?
ThreadLocal tl = new ThreadLocal();
当你tl=null时,1号线断开,但是由于ThreadLocal还被Entry的key强引用着(2号线),不能被回收,就造成了内存泄漏。

key使用弱引用就没有内存泄漏了吗?
当然不是,ThreadLocal确实可以被回收了,但是Entry的value还被threadLocals强引用着(3号线),value还是内存泄露了,
也正是因为这样,ThreadLocal的get,set,remove方法全部存在一个判断如下:

if(null==key){
	//手动释放value
	value = null;
}

ThreadLocal的hash碰撞是如何处理的?

我们知道hashmap处理hash碰撞的方式是拉链法,或者链地址法,具体做法就是数组+链表/红黑树。

那么ThreadLocal的hash碰撞其实是使用闭散列表,或者开放定址法,具体做法就是如果hash后mod数组长度如果发现下标已经有元素了,那么取数组的下一个下标,直到取到为止。
关于ThreadLocal你要知道的一切

ThreadLocal如何处理主线程传值到子线程?

使用InheritableThreadLocal

		 //主线程设置了值
        InheritableThreadLocal<Integer> tl = new InheritableThreadLocal<>();
        tl.set(1);
        
        //子线程可以拿得到
        new Thread(()->{
            System.out.println("子线程:"+tl.get());
        }).start();

实现原理就是在new Thread里的init会把主线程的inheritableThreadLocals属性深拷贝到子线程的inheritableThreadLocals属性
关于ThreadLocal你要知道的一切

如何让子线程跟随主线程tl值变化而变化?

阅读如下代码:

 @Test
    public void testInheritableThreadLocal1() throws InterruptedException {
        int i=1;
        //主线程设置了值
        InheritableThreadLocal<Integer> tl = new InheritableThreadLocal<>();
        tl.set(i);


        new Thread(()->{
            while(true){
                //但子线程始终是初始值,因为只有init才会深拷贝父线程的inheritableThreadLocals属性
                System.out.println("子线程:"+tl.get());//始终拿到的是1
                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }

        }).start();

        //主线程不停修改
        while(true){
            Thread.sleep(1000);
            i++;
            tl.set(i);
            System.out.println("主线程:"+tl.get());
        }
    }

通过如上代码可知:子线程只会在new Thread的init方法深拷贝inheritableThreadLocals数据,所以之后不管主线程如何修改tl的值,子线程都不会改变。

解决方案:使用线程同步,当父线程值发生改变时通知子线程修改tl的值。
代码如下:
方案一:

    int i = 1;
    @Test
    public void testInheritableThreadLocal2() throws InterruptedException {

        Object obj = new Object();

        //主线程设置了值
        InheritableThreadLocal<Integer> tl = new InheritableThreadLocal<>();
        tl.set(i);

        //new Thread里的init会把主线程的inheritableThreadLocals属性深拷贝到子线程的inheritableThreadLocals属性
        //但子线程可以拿得到
        new Thread(()->{
            synchronized (obj){
                while(true){
                    tl.set(i);
                    System.out.println("子线程:"+tl.get());
                    obj.notify();

                    try {
                        obj.wait();
                    } catch (InterruptedException e) {
                        e.printStackTrace();
                    }
                }
            }

        }).start();

        //主线程不停修改
        synchronized (obj){
            while(true){
                i++;
                obj.notify();//通知子线程修改线程变量
                obj.wait();//释放锁
                tl.set(i);//主线程修改线程变量
                System.out.println("主线程:"+tl.get());

                Thread.sleep(1000);
            }
        }
    }

方案二:(失败,但值得借鉴)使用回调函数callback,当主线程的tl值发生改变,调用子线程的callback接口把tl值传过去
失败原因在注释里有详细说明

    @Test
    public void testInheritableThreadLocal3() throws InterruptedException {

        
        //主线程设置了值
        InheritableThreadLocal<Integer> tl = new InheritableThreadLocal<>();
        tl.set(i);

        final Callback[] callback = new Callback[1];

        new Thread(()->{
            //这样还是不行,子线程始终还是最初的值
            //因为这行代码是主线程调用的,所以修改的还是主线程的tl值
             callback[0] = i -> { tl.set(i); };

            while(true){
                System.out.println("子线程:"+tl.get());

                try {
                    Thread.sleep(1000);
                } catch (InterruptedException e) {
                    e.printStackTrace();
                }
            }
        }).start();

        Thread.sleep(1000);
        //主线程不停修改
        while(true){
            i++;
            //调用子线程回调函数callback
            callback[0].update(i);
            tl.set(i);//主线程修改线程变量
            System.out.println("主线程:"+tl.get());

            Thread.sleep(1000);
        }
    }

    interface Callback{
        void update(int i);
    }

线程池中使用threadLocal如何保持不变?

先看错误例子:

public class ThreadLocalVariableHolder2 {
    private static ThreadLocal<Integer> variableHolder = new InheritableThreadLocal<Integer>(){
        @Override
        protected Integer initialValue() {
            return 0;
        }
    };

    public static void setValue(int val) {
        variableHolder.set(val);
    }

    public static int getValue() {
        return variableHolder.get();
    }

    public static void remove() {
        variableHolder.remove();
    }

    public static void increment() {
        variableHolder.set(variableHolder.get() + 1);
    }


    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        //线程池会复用线程,使用完线程后一定要清楚threadLocal,不然去执行新任务的时候会使用上一次的threadLocal产生错误结果
        for (int i = 0; i < 5; i++) {
            executor.execute(() -> {
                    long threadId = Thread.currentThread().getId();
                    int before = getValue();
                    increment();
                    int after = getValue();
                    System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after);
            });
        }
        executor.shutdown();
    }
}

输出结果如下:

threadId: 13, before: 0, after: 1
threadId: 12, before: 0, after: 1
threadId: 13, before: 1, after: 2
threadId: 13, before: 2, after: 3
threadId: 12, before: 1, after: 2

让我们关注threadId=12的输出
threadId: 12, before: 0, after: 1
threadId: 12, before: 1, after: 2
这就不符合我们想要的结果了,我们想要的结果是
threadId: 12, before: 0, after: 1
threadId: 12, before: 0, after: 1
即使是同一个线程,但是执行的任务却是不同的,前面的任务肯定不能影响后面的任务啊.

解决方案
在每个任务执行结束后执行remove,把当前线程的tl删掉就不会影响到下一个相同线程了,如下:

				try {
                    long threadId = Thread.currentThread().getId();
                    int before = getValue();
                    increment();
                    int after = getValue();
                    System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after);
  **加粗样式**              }finally {
                    remove();
                }

上面的例子有initialValue,如果我们没有复写该方法呢?
如果还是remove的话,下一次相同线程进来的时候就会空指针,getValue产生的空指针。
关于ThreadLocal你要知道的一切
解决方案:
关于ThreadLocal你要知道的一切
之所以这样是因为getValue设计成如果拿到的为null,就会从initValue方法拿。

阿里TransmittableThreadLocal

 <dependency>
      <groupId>com.alibaba</groupId>
      <artifactId>transmittable-thread-local</artifactId>
      <version>2.12.1</version>
    </dependency>

public class ThreadLocalVariableHolder1 {
    private static ThreadLocal<Integer> variableHolder = new TransmittableThreadLocal<Integer>();

    public static void setValue(int val) {
        variableHolder.set(val);
    }

    public static int getValue() {
        return variableHolder.get();
    }

    public static void remove() {
        variableHolder.remove();
    }

    public static void increment() {
        variableHolder.set(variableHolder.get() + 1);
    }


    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        variableHolder.set(1);
        //线程池会复用线程,使用完线程后一定要清楚threadLocal,不然去执行新任务的时候会使用上一次的threadLocal产生错误结果
        for (int i = 0; i < 5; i++) {
            executor.execute(TtlRunnable.get(() -> {
                long threadId = Thread.currentThread().getId();
                int before = getValue();
                increment();
                int after = getValue();
                System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after);
            }));
        }
        executor.shutdown();
    }
}

输出结果:

threadId: 14, before: 1, after: 2
threadId: 13, before: 1, after: 2
threadId: 14, before: 1, after: 2
threadId: 13, before: 1, after: 2
threadId: 14, before: 1, after: 2

结果正确,我们连finally都没写,他是如何做到的?
TtlRunnable.get(task), 这行代码很关键,最终代码如下:

关于ThreadLocal你要知道的一切
基于这个思想,我改造了一下我的代码:

public class ThreadLocalVariableHolder3 {
    private static ThreadLocal<Integer> variableHolder = new InheritableThreadLocal<Integer>();

    public static void setValue(int val) {
        variableHolder.set(val);
    }

    public static int getValue() {
        return variableHolder.get();
    }

    public static void remove() {
        variableHolder.remove();
    }

    public static void increment() {
        variableHolder.set(variableHolder.get() + 1);
    }


    public static void main(String[] args) {
        ExecutorService executor = Executors.newFixedThreadPool(2);
        variableHolder.set(1);

        //线程池会复用线程,使用完线程后一定要清楚threadLocal,不然去执行新任务的时候会使用上一次的threadLocal产生错误结果
        for (int i = 0; i < 5; i++) {
            executor.execute(
          
                    TtlRunnable.get(
                       () -> {
                        int before = getValue();
                        long threadId = Thread.currentThread().getId();
                        increment();
                        int after = getValue();
                        System.out.println("threadId: " + threadId + ", before: " + before + ", after: " + after);
                    },variableHolder)
				//我也不需要在这里还原回去
            );
        }
        
        executor.shutdown();
    }

}

class TtlRunnable implements Runnable{

    private Runnable task;
    private ThreadLocal<Integer> tl;

    private TtlRunnable(Runnable task,ThreadLocal<Integer> tl) {
        this.task = task;
        this.tl = tl;
    }

    public static Runnable get(Runnable task,ThreadLocal<Integer> tl){
        return new TtlRunnable(task,tl);
    }

    @Override
    public void run() {
        //备份线程执行之前threadLocal的值
        int back = tl.get();
        try {
            this.task.run();
        } finally {
            //线程执行完后给恢复回去
            tl.set(back);
        }
    }
}
上一篇:[css] ::before和:after中单冒号和双冒号的区别是什么,这两个伪元素有什么作用?


下一篇:【JAVA并发第四篇】线程安全