通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

前提

最近一两个月花了很大的功夫做UCloud服务和中间件迁移到阿里云的工作,没什么空闲时间撸文。想起很早之前写过ThreadLocal的源码分析相关文章,里面提到了ThreadLocal存在一个不能向预先创建的线程中进行变量传递的局限性,刚好有一位HSBC的技术大牛前同事提到了团队引入了transmittable-thread-local解决了此问题。借着这个契机,顺便clonetransmittable-thread-local源码进行分析,这篇文章会把ThreadLocalInheritableThreadLocal的局限性分析完毕,并且从一些基本原理以及设计模式的运用分析transmittable-thread-local(下文简称为TTL)整套框架的实现。

如果对线程池和ThreadLocal不熟悉的话,可以先参看一下前置文章:

这篇文章前后花了两周时间编写,行文比价干硬,文字比较多(接近5W字),希望带着耐心阅读。

父子线程的变量传递

Java中没有明确给出一个API可以基于子线程实例获取其父线程实例,有一个相对可行的方案就是在创建子线程Thread实例的时候获取当前线程的实例,用到的APIThread#currentThread()

public class Thread implements Runnable {

    // 省略其他代码

    @HotSpotIntrinsicCandidate
    public static native Thread currentThread();

    // 省略其他代码
}

Thread#currentThread()方法是一个静态本地方法,它是由JVM实现,这是在JDK中唯一可以获取父线程实例的API。一般而言,如果想在子线程实例中得到它的父线程实例,那么需要像如下这样操作:

public class InheritableThread {

    public static void main(String[] args) throws Exception{
        // 父线程就是main线程
        Thread parentThread = Thread.currentThread();
        Thread childThread = new Thread(()-> {
            System.out.println("Parent thread is:" + parentThread.getName());
        },"childThread");
        childThread.start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }
}
// 输出结果:
Parent thread is:main

类似地,如果我们想把一个父子线程共享的变量实例传递,也可以这样做:

public class InheritableVars {

    public static void main(String[] args) throws Exception {
        // 父线程就是main线程
        Thread parentThread = Thread.currentThread();
        final Var var = new Var();
        var.setValue1("var1");
        var.setValue2("var2");
        Thread childThread = new Thread(() -> {
            System.out.println("Parent thread is:" + parentThread.getName());
            methodFrame1(var);
        }, "childThread");
        childThread.start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }

    private static void methodFrame1(Var var) {
        methodFrame2(var);
    }

    private static void methodFrame2(Var var) {

    }

    @Data
    private static class Var {

        private Object value1;
        private Object value2;
    }
}

这种做法其实是可行的,子线程调用的方法栈中的所有方法都必须显示传入需要从父线程传递过来的参数引用Var实例,这样就会产生硬编码问题,既不灵活也导致方法不能复用,所以才衍生出线程本地变量Thread Local,具体的实现有ThreadLocalInheritableThreadLocal。它们两者的基本原理是类似的,实际上所有的变量实例是缓存在线程实例的变量ThreadLocal.ThreadLocalMap中,线程本地变量实例都只是线程实例获取ThreadLocal.ThreadLocalMap的一道桥梁:

public class Thread implements Runnable {

    // 省略其他代码

    // KEY为ThreadLocal实例,VALUE为具体的值
    ThreadLocal.ThreadLocalMap threadLocals = null;
    
    // KEY为InheritableThreadLocal实例,VALUE为具体的值
    ThreadLocal.ThreadLocalMap inheritableThreadLocals = null;

    // 省略其他代码
}

ThreadLocalInheritableThreadLocal之间的区别可以结合源码分析一下(见下一小节)。前面的分析听起来如果觉得抽象的话,可以自己写几个类推敲一下,假如线程其实叫ThrowableThread,而线程本地变量叫ThrowableThreadLocal,那么它们之间的关系如下:

public class Actor {

    static ThrowableThreadLocal THREAD_LOCAL = new ThrowableThreadLocal();

    public static void main(String[] args) throws Exception {
        ThrowableThread throwableThread = new ThrowableThread() {

            @Override
            public void run() {
                methodFrame1();
            }
        };
        throwableThread.start();
    }

    private static void methodFrame1() {
        THREAD_LOCAL.set("throwable");
        methodFrame2();
    }

    private static void methodFrame2() {
        System.out.println(THREAD_LOCAL.get());
    }

    /**
     * 这个类暂且认为是java.lang.Thread
     */
    private static class ThrowableThread implements Runnable {

        ThrowableThreadLocal.ThrowableThreadLocalMap threadLocalMap;

        @Override
        public void run() {

        }

        // 这里模拟VM的实现,返回ThrowableThread自身,大家先认为不是返回NULL
        public static ThrowableThread getCurrentThread() {
//            return new ThrowableThread();
            return null;   // <--- 假设这里在VM的实现里面返回的不是NULL而是当前的ThrowableThread
        }

        public void start() {
            run();
        }
    }

    private static class ThrowableThreadLocal {

        public ThrowableThreadLocal() {

        }

        public void set(Object value) {
            ThrowableThread currentThread = ThrowableThread.getCurrentThread();
            assert null != currentThread;
            ThrowableThreadLocalMap threadLocalMap = currentThread.threadLocalMap;
            if (null == threadLocalMap) {
                threadLocalMap = currentThread.threadLocalMap = new ThrowableThreadLocalMap();
            }
            threadLocalMap.put(this, value);
        }

        public Object get() {
            ThrowableThread currentThread = ThrowableThread.getCurrentThread();
            assert null != currentThread;
            ThrowableThreadLocalMap threadLocalMap = currentThread.threadLocalMap;
            if (null == threadLocalMap) {
                return null;
            }
            return threadLocalMap.get(this);
        }
        
        // 这里其实在ThreadLocal中用的是WeakHashMap
        public static class ThrowableThreadLocalMap extends HashMap<ThrowableThreadLocal, Object> {

        }
    }
}

上面的代码不能运行,只是通过一个自定义的实现说明一下其中的原理和关系。

ThreadLocal和InheritableThreadLocal的局限性

InheritableThreadLocalThreadLocal的子类,它们之间的联系是:两者都是线程Thread实例获取ThreadLocal.ThreadLocalMap的一个中间变量。区别是:两者控制ThreadLocal.ThreadLocalMap创建的时机和通过Thread实例获取ThreadLocal.ThreadLocalMapThread实例中对应的属性并不一样,导致两者的功能有一点差别。通俗来说两者的功能联系和区别是:

  • ThreadLocal:单个线程生命周期强绑定,只能在某个线程的生命周期内对ThreadLocal进行存取,不能跨线程存取。
public class ThreadLocalMain {

    private static ThreadLocal<String> TL = new ThreadLocal<>();

    public static void main(String[] args) throws Exception {
        new Thread(() -> {
            methodFrame1();
        }, "childThread").start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }

    private static void methodFrame1() {
        TL.set("throwable");
        methodFrame2();
    }

    private static void methodFrame2() {
        System.out.println(TL.get());
    }
}
// 输出结果:
throwable
  • InheritableThreadLocal:(1)可以无感知替代ThreadLocal的功能,当成ThreadLocal使用。(2)明确父-子线程关系的前提下,继承(拷贝)父线程的线程本地变量缓存过的变量,而这个拷贝的时机是子线程Thread实例化时候进行的,也就是子线程实例化完毕后已经完成了InheritableThreadLocal变量的拷贝,这是一个变量传递的过程。
public class InheritableThreadLocalMain {
    
    // 此处可以尝试替换为ThreadLocal,最后会输出null
    static InheritableThreadLocal<String> ITL = new InheritableThreadLocal<>();

    public static void main(String[] args) throws Exception {
        new Thread(() -> {
            // 在父线程中设置变量
            ITL.set("throwable");
            new Thread(() -> {
                methodFrame1();
            }, "childThread").start();
        }, "parentThread").start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }

    private static void methodFrame1() {
        methodFrame2();
    }

    private static void methodFrame2() {
        System.out.println(ITL.get());
    }
}
// 输出结果:
throwable

上面提到的两点可以具体参看ThreadLocalInheritableThreadLocalThread三个类的源码,这里笔者把一些必要的注释和源码段贴出:

// --> java.lang.Thread类的源码片段
public class Thread implements Runnable {

    // 省略其他代码 

    // 这是Thread最基本的构造函数
    private Thread(ThreadGroup g, Runnable target, String name,
                   long stackSize, AccessControlContext acc,
                   boolean inheritThreadLocals) {

        // 省略其他代码

        Thread parent = currentThread();
        this.group = g;
        this.daemon = parent.isDaemon();
        this.priority = parent.getPriority();
        if (security == null || isCCLOverridden(parent.getClass()))
            this.contextClassLoader = parent.getContextClassLoader();
        else
            this.contextClassLoader = parent.contextClassLoader;
        this.inheritedAccessControlContext =
                acc != null ? acc : AccessController.getContext();
        this.target = target;
        setPriority(priority);
        // inheritThreadLocals一般情况下为true
        // 当前子线程实例拷贝父线程的inheritableThreadLocals属性,创建一个新的ThreadLocal.ThreadLocalMap实例赋值到自身的inheritableThreadLocals属性
        if (inheritThreadLocals && parent.inheritableThreadLocals != null)
            this.inheritableThreadLocals = ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
        this.stackSize = stackSize;
        this.tid = nextThreadID();
    }

    // 省略其他代码
}

// --> java.lang.ThreadLocal源码片段
public class ThreadLocal<T> {

    // 省略其他代码 

    public void set(T value) {
        Thread t = Thread.currentThread();
        // 通过当前线程获取线程实例中的threadLocals
        ThreadLocalMap map = getMap(t);
        // 线程实例中的threadLocals为NULL,实例则创建一个ThreadLocal.ThreadLocalMap实例添加当前ThreadLocal->VALUE到ThreadLocalMap中,如果已经存在ThreadLocalMap则进行覆盖对应的Entry
        if (map != null) {
            map.set(this, value);
        } else {
            createMap(t, value);
        }
    }

    // 通过线程实例获取该线程的threadLocals实例,其实是ThreadLocal.ThreadLocalMap类型的属性
    ThreadLocalMap getMap(Thread t) {
        return t.threadLocals;
    }


    public T get() {
        Thread t = Thread.currentThread();
        // 通过当前线程获取线程实例中的threadLocals,再获取ThreadLocal.ThreadLocalMap中匹配上KEY为当前ThreadLocal实例的Entry对应的VALUE
        ThreadLocalMap map = getMap(t);
        if (map != null) {
            ThreadLocalMap.Entry e = map.getEntry(this);
            if (e != null) {
                @SuppressWarnings("unchecked")
                T result = (T)e.value;
                return result;
            }
        }
        // 找不到则尝试初始化ThreadLocal.ThreadLocalMap
        return setInitialValue();
    }
    
    // 如果不存在ThreadLocal.ThreadLocalMap,则通过初始化initialValue()方法的返回值,构造一个ThreadLocal.ThreadLocalMap
    private T setInitialValue() {
        T value = initialValue();
        Thread t = Thread.currentThread();
        ThreadLocalMap map = getMap(t);
        if (map != null)
            map.set(this, value);
        else
            createMap(t, value);
        return value;
    }

    // 省略其他代码 
}

// --> java.lang.InheritableThreadLocal源码 - 太简单,全量贴出
public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    
    // 这个方法使用在线程Thread的构造函数里面ThreadLocal.createInheritedMap(),基于父线程InheritableThreadLocal的属性创建子线程的InheritableThreadLocal属性,它的返回值决定了拷贝父线程的属性时候传入子线程的值
    protected T childValue(T parentValue) {
        return parentValue;
    }
    
    // 覆盖获取线程实例中的绑定的ThreadLocalMap为Thread#inheritableThreadLocals,这个方法其实是覆盖了ThreadLocal中对应的方法,应该加@Override注解
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }
    
    // 覆盖创建ThreadLocalMap的逻辑,赋值到线程实例中的inheritableThreadLocals,而不是threadLocals,这个方法其实是覆盖了ThreadLocal中对应的方法,应该加@Override注解
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

一定要注意,这里的setInitialValue()方法很重要,一个新的线程Thread实例在初始化(对于InheritableThreadLocal而言继承父线程的线程本地变量)或者是首次调用ThreadLocal#set(),会通过此setInitialValue()方法去构造一个全新的ThreadLocal.ThreadLocalMap,会直接使用createMap()方法。

以前面提到的两个例子,贴一个图加深理解:

Example-1

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

Example-2

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

ThreadLocalInheritableThreadLocal的最大局限性就是:无法为预先创建好(未投入使用)的线程实例传递变量(准确来说是首次传递某些场景是可行的,而后面由于线程池中的线程是复用的,无法进行更新或者修改变量的传递值),泛线程池Executor体系、TimerTaskForkJoinPool等一般会预先创建(核心)线程,也就它们都是无法在线程池中由预创建的子线程执行的Runnable任务实例中使用。例如下面的方式会导致参数传递失败:

public class InheritableThreadForExecutor {

    static final InheritableThreadLocal<String> ITL = new InheritableThreadLocal<>();
    static final Executor EXECUTOR = Executors.newFixedThreadPool(1);

    public static void main(String[] args) throws Exception {
        ITL.set("throwable");
        EXECUTOR.execute(() -> {
            System.out.println(ITL.get());
        });
        ITL.set("doge");
        EXECUTOR.execute(() -> {
            System.out.println(ITL.get());
        });
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }
}
// 输出结果:
throwable
throwable   # <--- 可见此处参数传递出现异常

首次变量传递成功是因为线程池中的所有子线程都是派生自main线程。

TTL的简单使用

TTL的使用方式在它的项目README.md或者项目中的单元测试有十分详细的介绍,先引入依赖com.alibaba:transmittable-thread-local:2.11.4,这里演示一个例子:

// 父-子线程
public class TtlSample1 {

    static TransmittableThreadLocal<String> TTL = new TransmittableThreadLocal<>();

    public static void main(String[] args) throws Exception {
        new Thread(() -> {
            // 在父线程中设置变量
            TTL.set("throwable");
            new Thread(TtlRunnable.get(() -> {
                methodFrame1();
            }), "childThread").start();
        }, "parentThread").start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }

    private static void methodFrame1() {
        methodFrame2();
    }

    private static void methodFrame2() {
        System.out.println(TTL.get());
    }
}
// 输出:
throwable

// 线程池
public class TtlSample2 {

    static TransmittableThreadLocal<String> TTL = new TransmittableThreadLocal<>();
    static final Executor EXECUTOR = Executors.newFixedThreadPool(1);

    public static void main(String[] args) throws Exception {
        TTL.set("throwable");
        EXECUTOR.execute(TtlRunnable.get(() -> {
            System.out.println(TTL.get());
        }));
        TTL.set("doge");
        EXECUTOR.execute(TtlRunnable.get(() -> {
            System.out.println(TTL.get());
        }));
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }
}
// 输出:
throwable
doge

TTL实现的基本原理

TTL设计上使用了大量的委托(Delegate),委托是C#里面的说法,对标Java的设计模式就是代理模式。举个简单的例子:

@Slf4j
public class StaticDelegate {

    public static void main(String[] args) throws Exception {
        new RunnableDelegate(() -> log.info("Hello World!")).run();
    }

    @Slf4j
    @RequiredArgsConstructor
    private static final class RunnableDelegate implements Runnable {

        private final Runnable runnable;

        @Override
        public void run() {
            try {
                log.info("Before run...");
                runnable.run();
                log.info("After run...");
            } finally {
                log.info("Finally run...");
            }
        }
    }
}
// 输出结果:
23:45:27.763 [main] INFO club.throwable.juc.StaticDelegate$RunnableDelegate - Before run...
23:45:27.766 [main] INFO club.throwable.juc.StaticDelegate - Hello World!
23:45:27.766 [main] INFO club.throwable.juc.StaticDelegate$RunnableDelegate - After run...
23:45:27.766 [main] INFO club.throwable.juc.StaticDelegate$RunnableDelegate - Finally run...

委托如果使用纯熟的话,可以做出很多十分有用的功能,例如可以基于Micrometer去统计任务的执行时间,上报到Prometheus,然后用Grafana做监控和展示:

// 需要引入io.micrometer:micrometer-core:${version}
@Slf4j
public class MeterDelegate {

    public static void main(String[] args) throws Exception {
        Executor executor = Executors.newFixedThreadPool(1);
        Runnable task = () -> {
            try {
                // 模拟耗时
                Thread.sleep(1000);
            } catch (Exception ignore) {

            }
        };
        Map<String, String> tags = new HashMap<>(8);
        tags.put("_class", "MeterDelegate");
        executor.execute(new MicrometerDelegate(task, "test-task", tags));
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }

    @Slf4j
    @RequiredArgsConstructor
    private static final class MicrometerDelegate implements Runnable {

        private final Runnable runnable;
        private final String taskType;
        private final Map<String, String> tags;

        @Override
        public void run() {
            long start = System.currentTimeMillis();
            try {
                runnable.run();
            } finally {
                long end = System.currentTimeMillis();
                List<Tag> tagsList = Lists.newArrayList();
                Optional.ofNullable(tags).ifPresent(x -> x.forEach((k, v) -> {
                    tagsList.add(Tag.of(k, v));
                }));
                Metrics.summary(taskType, tagsList).record(end - start);
            }
        }
    }
}

委托理论上只要不线程栈溢出,可以无限层级地包装,有点像洋葱的结构,原始的目标方法会被包裹在最里面并且最后执行:

    public static void main(String[] args) throws Exception {
        Runnable target = () -> log.info("target");
        Delegate level1 = new Delegate(target);
        Delegate level2 = new Delegate(level1);
        Delegate level3 = new Delegate(level2);
        // ......
    }
    
    @RequiredArgsConstructor
    static class Delegate implements Runnable{
        
        private final Runnable runnable;

        @Override
        public void run() {
            runnable.run();
        }
    }

当然,委托的层级越多,代码结构就会越复杂,不利于理解和维护。多层级委托这个洋葱结构,再配合Java反射API剥离对具体方法调用的依赖,就是Java中切面编程的普遍原理,spring-aop就是这样实现的。委托如果再结合Agent和字节码增强(使用ASMJavassist等),可以实现类加载时期替换对应的RunnableCallable或者一般接口的实现,这样就能无感知完成了增强功能。此外,TTL中还使用了模板方法模式,如:

@Slf4j
public class TemplateMethod {

    public static void main(String[] args) throws Exception {
        Runnable runnable = () -> log.info("Hello World!");
        Template template = new Template(runnable) {
            @Override
            protected void beforeExecute() {
                log.info("BeforeExecute...");
            }

            @Override
            protected void afterExecute() {
                log.info("AfterExecute...");
            }
        };
        template.run();
    }

    @RequiredArgsConstructor
    static abstract class Template implements Runnable {

        private final Runnable runnable;

        protected void beforeExecute() {

        }

        @Override
        public void run() {
            beforeExecute();
            runnable.run();
            afterExecute();
        }

        protected void afterExecute() {

        }
    }
}
// 输出结果:
00:25:32.862 [main] INFO club.throwable.juc.TemplateMethod - BeforeExecute...
00:25:32.865 [main] INFO club.throwable.juc.TemplateMethod - Hello World!
00:25:32.865 [main] INFO club.throwable.juc.TemplateMethod - AfterExecute...

分析了两种设计模式,下面简单理解一下TTL实现的伪代码:

# TTL extends InheritableThreadLocal
# Holder of TTL -> InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> [? => NULL]
(1)创建一个全局的Holder,用于保存父线程(或者明确了父线程的子线程)的TTL对象,这里注意,是TTL对象,Holder是当作Set使用
(2)(父)线程A中使用了TTL,则所有设置的变量会被TTL捕获
(3)(子)线程B使用了TtlRunnable(Runnable的TTL实现,使用了前面提到的委托,像Callable的实现是TtlCallable),会重放所有存储在TTL中的,来自于线程A的存储变量
(4)线程B重放完毕后,清理线程B独立产生的ThreadLocal变量,归还变TTL的变量

主要就是这几步,里面的话术有点抽象,后面一节分析源码的时候会详细讲解。

TTL的源码分析

主要分析:

  • 框架的骨架。
  • 核心类TransmittableThreadLocal
  • 发射器Transmitter
  • 捕获、重放和复原。
  • Agent模块。

TTL框架骨架

TTL是一个十分精悍的框架,它依赖少量的类实现了比较强大的功能,除了提供给用户使用的API,还提供了基于Agent和字节码增强实现了无感知增强泛线程池对应类的功能,这一点是比较惊艳的。这里先分析编程式的API,再简单分析Agent部分的实现。笔者阅读TTL框架的时间是2020年五一劳动节前后,当前的最新发行版本为2.11.4TTL的项目结构很简单:

- transmittable-thread-local
  - com.alibaba.ttl
   - spi   SPI接口和一些实现
   - threadpool   线程池增强,包括ThreadFactory和线程池的Wrapper等
     - agent   线程池的Agent实现相关
   最外层的包有一些Wrapper的实现和TTL

先看spi包:

- spi
  TtlAttachments
  TtlAttachmentsDelegate
  TtlEnhanced
  TtlWrapper

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

TtlEnhancedTTL的标识接口(空接口),标识具体的组件被TTL增强:

public interface TtlEnhanced {

}

通过instanceof关键字就可以判断具体的实现是否TTL增强过的组件。TtlWrapper接口继承自接口TtlEnhanced,用于标记实现类可以解包装获得原始实例:

public interface TtlWrapper<T> extends TtlEnhanced {
    
    // 返回解包装实例,实际是就是原始实例
    @NonNull
    T unwrap();
}

TtlAttachments接口也是继承自接口TtlEnhanced,用于为TTL添加K-V结构的附件,TtlAttachmentsDelegate是其实现类,K-V的存储实际上是委托给ConcurrentHashMap

public interface TtlAttachments extends TtlEnhanced {
    
    // 添加K-V附件
    void setTtlAttachment(@NonNull String key, Object value);
    
    // 通过KEY获取值
    <T> T getTtlAttachment(@NonNull String key);
    
    // 标识自动包装的KEY,Agent模式会使用自动包装,这个时候会传入一个附件的K-V,其中KEY就是KEY_IS_AUTO_WRAPPER
    String KEY_IS_AUTO_WRAPPER = "ttl.is.auto.wrapper";
}

// TtlAttachmentsDelegate
public class TtlAttachmentsDelegate implements TtlAttachments {

    private final ConcurrentMap<String, Object> attachments = new ConcurrentHashMap<String, Object>();

    @Override
    public void setTtlAttachment(@NonNull String key, Object value) {
        attachments.put(key, value);
    }

    @Override
    @SuppressWarnings("unchecked")
    public <T> T getTtlAttachment(@NonNull String key) {
        return (T) attachments.get(key);
    }
}

因为TTL的实现覆盖了泛线程池ExecutorExecutorServiceScheduledExecutorServiceForkJoinPoolTimerTask(在TTL中组件已经标记为过期,推荐使用ScheduledExecutorService),范围比较广,短篇幅无法分析所有的源码,而且它们的实现思路是基本一致的,笔者下文只会挑选Executor的实现路线进行分析。

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

核心类TransmittableThreadLocal

TransmittableThreadLocalTTL的核心类,TTL框架就是用这个类来命名的。先看它的构造函数和关键属性:

// 函数式接口,TTL拷贝器
@FunctionalInterface
public interface TtlCopier<T> {
   
    // 拷贝父属性
    T copy(T parentValue);
}

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {

    // 日志句柄,使用的不是SLF4J的接口,而是java.util.logging的实现
    private static final Logger logger = Logger.getLogger(TransmittableThreadLocal.class.getName());
    
    // 是否禁用忽略NULL值的语义
    private final boolean disableIgnoreNullValueSemantics;
    
    // 默认是false,也就是不禁用忽略NULL值的语义,也就是忽略NULL值,也就是默认的话,NULL值传入不会覆盖原来已经存在的值
    public TransmittableThreadLocal() {
        this(false);
    }
    
    // 可以通过手动设置,去覆盖IgnoreNullValue的语义,如果设置为true,则是支持NULL值的设置,设置为true的时候,与ThreadLocal的语义一致
    public TransmittableThreadLocal(boolean disableIgnoreNullValueSemantics) {
        this.disableIgnoreNullValueSemantics = disableIgnoreNullValueSemantics;
    }
    
    // 先忽略其他代码
}

disableIgnoreNullValueSemantics属性相关可以查看Issue157,下文分析方法的时候也会说明具体的场景。TransmittableThreadLocal继承自InheritableThreadLocal,本质就是ThreadLocal,那它到底怎么样保证变量可以在线程池中的线程传递?接着分析其他所有方法:

public class TransmittableThreadLocal<T> extends InheritableThreadLocal<T> implements TtlCopier<T> {
    
    // 拷贝器的拷贝方法实现
    public T copy(T parentValue) {
        return parentValue;
    }

    // 模板方法,留给子类实现,在TtlRunnable或者TtlCallable执行前回调
    protected void beforeExecute() {
    }

    // 模板方法,留给子类实现,在TtlRunnable或者TtlCallable执行后回调
    protected void afterExecute() {
    }

    // 获取值,直接从InheritableThreadLocal#get()获取
    @Override
    public final T get() {
        T value = super.get();
        // 如果值不为NULL 或者 禁用了忽略空值的语义(也就是和ThreadLocal语义一致),则重新添加TTL实例自身到存储器
        if (disableIgnoreNullValueSemantics || null != value) addThisToHolder();
        return value;
    }
    
    @Override
    public final void set(T value) {
        // 如果不禁用忽略空值的语义,也就是需要忽略空值,并且设置的入参值为空,则做一次彻底的移除,包括从存储器移除TTL自身实例,TTL(ThrealLocalMap)中也移除对应的值
        if (!disableIgnoreNullValueSemantics && null == value) {
            // may set null to remove value
            remove();
        } else {
            // TTL(ThrealLocalMap)中设置对应的值
            super.set(value);
            // 添加TTL实例自身到存储器
            addThisToHolder();
        }
    }
   
    // 从存储器移除TTL自身实例,从TTL(ThrealLocalMap)中移除对应的值
    @Override
    public final void remove() {
        removeThisFromHolder();
        super.remove();
    }
    
    // 从TTL(ThrealLocalMap)中移除对应的值
    private void superRemove() {
        super.remove();
    }
    
    // 拷贝值,主要是拷贝get()的返回值
    private T copyValue() {
        return copy(get());
    }
     
    // 存储器,本身就是一个InheritableThreadLocal(ThreadLocal)
    // 它的存放对象是WeakHashMap<TransmittableThreadLocal<Object>, ?>类型,而WeakHashMap的VALUE总是为NULL,这里当做Set容器使用,WeakHashMap支持NULL值
    private static InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
            new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
                }

                @Override
                protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
                    // 注意这里的WeakHashMap总是拷贝父线程的值
                    return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
                }
            };
    
    // 添加TTL自身实例到存储器,不存在则添加策略
    @SuppressWarnings("unchecked")
    private void addThisToHolder() {
        if (!holder.get().containsKey(this)) {
            holder.get().put((TransmittableThreadLocal<Object>) this, null); // WeakHashMap supports null value.
        }
    }
    
    // 从存储器移除TTL自身的实例
    private void removeThisFromHolder() {
        holder.get().remove(this);
    }
    
    // 执行目标方法,isBefore决定回调beforeExecute还是afterExecute,注意此回调方法会吞掉所有的异常只打印日志
    private static void doExecuteCallback(boolean isBefore) {
        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
            try {
                if (isBefore) threadLocal.beforeExecute();
                else threadLocal.afterExecute();
            } catch (Throwable t) {
                if (logger.isLoggable(Level.WARNING)) {
                    logger.log(Level.WARNING, "TTL exception when " + (isBefore ? "beforeExecute" : "afterExecute") + ", cause: " + t.toString(), t);
                }
            }
        }
    }
    
    // DEBUG模式下打印TTL里面的所有值
    static void dump(@Nullable String title) {
        if (title != null && title.length() > 0) {
            System.out.printf("Start TransmittableThreadLocal[%s] Dump...%n", title);
        } else {
            System.out.println("Start TransmittableThreadLocal Dump...");
        }

        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
            System.out.println(threadLocal.get());
        }
        System.out.println("TransmittableThreadLocal Dump end!");
    }
    
    // DEBUG模式下打印TTL里面的所有值
    static void dump() {
        dump(null);
    }

    // 省略静态类Transmitter的实现代码
}

这里一定要记住holder是全局静态的,并且它自身也是一个InheritableThreadLocalget()方法也是线程隔离的),它实际上就是父线程管理所有TransmittableThreadLocal的桥梁。这里可以考虑一个单线程的例子来说明TransmittableThreadLocal的存储架构:

public class TtlSample3 {

    static TransmittableThreadLocal<String> TTL1 = new TransmittableThreadLocal<>();
    static TransmittableThreadLocal<String> TTL2 = new TransmittableThreadLocal<>();
    static TransmittableThreadLocal<String> TTL3 = new TransmittableThreadLocal<>();

    public static void main(String[] args) throws Exception {
        TTL1.set("VALUE-1");
        TTL2.set("VALUE-2");
        TTL3.set("VALUE-3");
    }
}

这里简化了例子,只演示了单线程的场景,图中的一些对象的哈希码有可能每次启动JVM实例都不一样,这里只是做示例:

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

注释里面也提到,holder里面的WeakHashMap是当成Set容器使用,映射的值都是NULL,每次遍历它的所有KEY就能获取holder里面的所有的TransmittableThreadLocal实例,它是一个全局的存储器,但是本身是一个InheritableThreadLocal,多线程共享后的映射关系会相对复杂:

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

再聊一下disableIgnoreNullValueSemantics的作用,默认情况下disableIgnoreNullValueSemantics=falseTTL如果设置NULL值,会直接从holder移除对应的TTL实例,在TTL#get()方法被调用的时候,如果原来持有的属性不为NULL,该TTL实例会重新加到holder。如果设置disableIgnoreNullValueSemantics=true,则set(null)的语义和ThreadLocal一致。见下面的例子:

public class TtlSample4 {

    static TransmittableThreadLocal<Integer> TL1 = new TransmittableThreadLocal<Integer>(false) {
        @Override
        protected Integer initialValue() {
            return 5;
        }

        @Override
        protected Integer childValue(Integer parentValue) {
            return 10;
        }
    };

    static TransmittableThreadLocal<Integer> TL2 = new TransmittableThreadLocal<Integer>(true) {
        @Override
        protected Integer initialValue() {
            return 5;
        }

        @Override
        protected Integer childValue(Integer parentValue) {
            return 10;
        }
    };

    public static void main(String[] args) throws Exception {
        TL1.set(null);
        TL2.set(null);
        Thread t1 = new Thread(TtlRunnable.get(() -> {
            System.out.println(String.format("Thread:%s,value:%s", Thread.currentThread().getName(), TL1.get()));
        }), "T1");

        Thread t2 = new Thread(TtlRunnable.get(() -> {
            System.out.println(String.format("Thread:%s,value:%s", Thread.currentThread().getName(), TL2.get()));
        }), "T2");
        t1.start();
        t2.start();
        TimeUnit.SECONDS.sleep(Long.MAX_VALUE);
    }
}
// 输出结果:
Thread:T2,value:null
Thread:T1,value:5

这是因为框架的设计者不想把NULL作为有状态的值,如果真的有需要保持和ThreadLocal一致的用法,可以在构造TransmittableThreadLocal实例的时候传入true

发射器Transmitter

发射器TransmitterTransmittableThreadLocal的一个公有静态类,它的核心功能是传输所有的TransmittableThreadLocal实例和提供静态方法注册当前线程的变量到其他线程。按照笔者阅读源码的习惯,先看构造函数和关键属性:

// # TransmittableThreadLocal#Transmitter
public static class Transmitter {
    
    // 保存手动注册的ThreadLocal->TtlCopier映射,这里是因为部分API提供了TtlCopier给用户实现
    private static volatile WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>> threadLocalHolder = new WeakHashMap<ThreadLocal<Object>, TtlCopier<Object>>();
    // threadLocalHolder更变时候的监视器
    private static final Object threadLocalHolderUpdateLock = new Object();
    // 标记WeakHashMap中的ThreadLocal的对应值为NULL的属性,便于后面清理
    private static final Object threadLocalClearMark = new Object();
    
    // 默认的拷贝器,影子拷贝,直接返回父值
    private static final TtlCopier<Object> shadowCopier = new TtlCopier<Object>() {
        @Override
        public Object copy(Object parentValue) {
            return parentValue;
        }
    };
    
    // 私有构造,说明只能通过静态方法提供外部调用
    private Transmitter() {
        throw new InstantiationError("Must not instantiate this class");
    }
    
    // 私有静态类,快照,保存从holder中捕获的所有TransmittableThreadLocal和外部手动注册保存在threadLocalHolder的ThreadLocal的K-V映射快照
    private static class Snapshot {
        final WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value;
        final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value;

        private Snapshot(WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value, WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value) {
            this.ttl2Value = ttl2Value;
            this.threadLocal2Value = threadLocal2Value;
        }
    }
}

Transmitter在设计上是一个典型的工具类,外部只能调用其公有静态方法。接着看其他静态方法:

// # TransmittableThreadLocal#Transmitter
public static class Transmitter {

    //######################################### 捕获 ###########################################################

    // 捕获当前线程绑定的所有的TransmittableThreadLocal和已经注册的ThreadLocal的值 - 使用了用时拷贝快照的策略
    // 笔者注:它一般在构造任务实例的时候被调用,因此当前线程相对于子线程或者线程池的任务就是父线程,其实本质是捕获父线程的所有线程本地变量的值
    @NonNull
    public static Object capture() {
        return new Snapshot(captureTtlValues(), captureThreadLocalValues());
    }
    
    // 新建一个WeakHashMap,遍历TransmittableThreadLocal#holder中的所有TransmittableThreadLocal的Entry,获取K-V,存放到这个新的WeakHashMap返回
    private static WeakHashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
        WeakHashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
        for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
            ttl2Value.put(threadLocal, threadLocal.copyValue());
        }
        return ttl2Value;
    }
    
    // 新建一个WeakHashMap,遍历threadLocalHolder中的所有ThreadLocal的Entry,获取K-V,存放到这个新的WeakHashMap返回
    private static WeakHashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
        final WeakHashMap<ThreadLocal<Object>, Object> threadLocal2Value = new WeakHashMap<ThreadLocal<Object>, Object>();
        for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
            final ThreadLocal<Object> threadLocal = entry.getKey();
            final TtlCopier<Object> copier = entry.getValue();
            threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
        }
        return threadLocal2Value;
    }

    //######################################### 重放 ###########################################################

    // 重放capture()方法中捕获的TransmittableThreadLocal和手动注册的ThreadLocal中的值,本质是重新拷贝holder中的所有变量,生成新的快照
    // 笔者注:重放操作一般会在子线程或者线程池中的线程的任务执行的时候调用,因此此时的holder#get()拿到的是子线程的原来就存在的本地线程变量,重放操作就是把这些子线程原有的本地线程变量备份
    @NonNull
    public static Object replay(@NonNull Object captured) {
        final Snapshot capturedSnapshot = (Snapshot) captured;
        return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
    }
    
    // 重放所有的TTL的值
    @NonNull
    private static WeakHashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> captured) {
        // 新建一个新的备份WeakHashMap,其实也是一个快照
        WeakHashMap<TransmittableThreadLocal<Object>, Object> backup = new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
        // 这里的循环针对的是子线程,用于获取的是子线程的所有线程本地变量
        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
            TransmittableThreadLocal<Object> threadLocal = iterator.next();

            // 拷贝holder当前线程(子线程)绑定的所有TransmittableThreadLocal的K-V结构到备份中
            backup.put(threadLocal, threadLocal.get());

            // 清理所有的非捕获快照中的TTL变量,以防有中间过程引入的额外的TTL变量(除了父线程的本地变量)影响了任务执行后的重放操作
            // 简单来说就是:移除所有子线程的不包含在父线程捕获的线程本地变量集合的中所有子线程本地变量和对应的值
            /**
             * 这个问题可以举个简单的例子:
             * static TransmittableThreadLocal<Integer> TTL = new TransmittableThreadLocal<>();
             * 
             * 线程池中的子线程C中原来初始化的时候,在线程C中绑定了TTL的值为10087,C线程是核心线程不会主动销毁。
             * 
             * 父线程P在没有设置TTL值的前提下,调用了线程C去执行任务,那么在C线程的Runnable包装类中通过TTL#get()就会获取到10087,显然是不符合预期的
             *
             * 所以,在C线程的Runnable包装类之前之前,要从C线程的线程本地变量,移除掉不包含在父线程P中的所有线程本地变量,确保Runnable包装类执行期间只能拿到父线程中捕获到的线程本地变量
             *
             * 下面这个判断和移除做的就是这个工作
             */
            if (!captured.containsKey(threadLocal)) {
                iterator.remove();
                threadLocal.superRemove();
            }
        }

        // 重新设置TTL的值到捕获的快照中
        // 其实真实的意图是:把从父线程中捕获的所有线程本地变量重写设置到TTL中,本质上,子线程holder里面的TTL绑定的值会被刷新
        setTtlValuesTo(captured);

        // 回调模板方法beforeExecute
        doExecuteCallback(true);

        return backup;
    }
    
    // 提取WeakHashMap中的KeySet,遍历所有的TransmittableThreadLocal,重新设置VALUE
    private static void setTtlValuesTo(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> ttlValues) {
        for (Map.Entry<TransmittableThreadLocal<Object>, Object> entry : ttlValues.entrySet()) {
            TransmittableThreadLocal<Object> threadLocal = entry.getKey();
            // 重新设置TTL值,本质上,当前线程(子线程)holder里面的TTL绑定的值会被刷新
            threadLocal.set(entry.getValue());
        }
    }
    
    // 重放所有的手动注册的ThreadLocal的值
    private static WeakHashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> captured) {
        // 新建备份
        final WeakHashMap<ThreadLocal<Object>, Object> backup = new WeakHashMap<ThreadLocal<Object>, Object>();
        // 注意这里是遍历捕获的快照中的ThreadLocal
        for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
            final ThreadLocal<Object> threadLocal = entry.getKey();
            // 添加到备份中
            backup.put(threadLocal, threadLocal.get());
            final Object value = entry.getValue();
            // 如果值为清除标记则绑定在当前线程的变量进行remove,否则设置值覆盖
            if (value == threadLocalClearMark) threadLocal.remove();
            else threadLocal.set(value);
        }
        return backup;
    }

    // 从relay()或者clear()方法中恢复TransmittableThreadLocal和手工注册的ThreadLocal的值对应的备份
    // 笔者注:恢复操作一般会在子线程或者线程池中的线程的任务执行的时候调用
    public static void restore(@NonNull Object backup) {
        final Snapshot backupSnapshot = (Snapshot) backup;
        restoreTtlValues(backupSnapshot.ttl2Value);
        restoreThreadLocalValues(backupSnapshot.threadLocal2Value);
    }

    private static void restoreTtlValues(@NonNull WeakHashMap<TransmittableThreadLocal<Object>, Object> backup) {
        // 回调模板方法afterExecute
        doExecuteCallback(false);
        // 这里的循环针对的是子线程,用于获取的是子线程的所有线程本地变量
        for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
            TransmittableThreadLocal<Object> threadLocal = iterator.next();
            // 如果子线程原来就绑定的线程本地变量的值,如果不包含某个父线程传来的对象,那么就删除
            // 这一步可以结合前面reply操作里面的方法段一起思考,如果不删除的话,就相当于子线程的原来存在的线程本地变量绑定值被父线程对应的值污染了
            if (!backup.containsKey(threadLocal)) {
                iterator.remove();
                threadLocal.superRemove();
            }
        }

        // 重新设置TTL的值到捕获的快照中
        // 其实真实的意图是:把子线程的线程本地变量恢复到reply()的备份(前面的循环已经做了父线程捕获变量的判断),本质上,等于把holder中绑定于子线程本地变量的部分恢复到reply操作之前的状态
        setTtlValuesTo(backup);
    }
    
    // 恢复所有的手动注册的ThreadLocal的值
    private static void restoreThreadLocalValues(@NonNull WeakHashMap<ThreadLocal<Object>, Object> backup) {
        for (Map.Entry<ThreadLocal<Object>, Object> entry : backup.entrySet()) {
            final ThreadLocal<Object> threadLocal = entry.getKey();
            threadLocal.set(entry.getValue());
        }
    }
}   

这里三个核心方法,看起来比较抽象,要结合多线程的场景和一些空间想象进行推敲才能比较容易地理解:

  • capture():捕获操作,父线程原来就存在的线程本地变量映射和手动注册的线程本地变量映射捕获,得到捕获的快照值captured
  • reply():重放操作,子线程原来就存在的线程本地变量映射和手动注册的线程本地变量生成备份backup,刷新captured的所有值到子线程在全局存储器holder中绑定的值。
  • restore():复原操作,子线程原来就存在的线程本地变量映射和手动注册的线程本地变量恢复成backup

setTtlValuesTo()这个方法比较隐蔽,要特别要结合多线程和空间思维去思考,例如当入参是captured,本质是从父线程捕获到的绑定在父线程的所有线程本地变量,调用的时机在reply()restore(),这两个方法只会在子线程中调用,setTtlValuesTo()里面拿到的TransmittableThreadLocal实例调用set()方法相当于把绑定在父线程的所有线程本地变量的值全部刷新到子线程当前绑定的TTL中的线程本地变量的值,更深层次地想,是基于外部的传入值刷新了子线程绑定在全局存储器holder里面绑定到该子线程的线程本地变量的值。

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

Transmitter还有不少静态工具方法,这里不做展开,可以参考项目里面的测试demoREADME.md进行调试。

捕获、重放和复原

其实上面一节已经介绍了Transmitter提供的捕获、重放和复原的API,这一节主要结合分析TtlRunnable中的相关逻辑。TtlRunnable的源码如下:

public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {

    // 存放从父线程捕获得到的线程本地变量映射的备份
    private final AtomicReference<Object> capturedRef;
    // 原始的Runable实例
    private final Runnable runnable;
    // 执行之后是否释放TTL值引用
    private final boolean releaseTtlValueReferenceAfterRun;

    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // 这里关键点:TtlRunnable实例化的时候就已经进行了线程本地变量的捕获,所以一定是针对父线程的,因为此时任务还没提交到线程池
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

    @Override
    public void run() {
        // 获取父线程捕获到的线程本地变量映射的备份,做一些前置判断
        Object captured = capturedRef.get();
        if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
            throw new IllegalStateException("TTL value reference is released after run!");
        }
        // 重放操作
        Object backup = replay(captured);
        try {
            // 真正的Runnable调用
            runnable.run();
        } finally {
            // 复原操作
            restore(backup);
        }
    }

    @Nullable
    public static TtlRunnable get(@Nullable Runnable runnable) {
        return get(runnable, false, false);
    }

    @Nullable
    public static TtlRunnable get(@Nullable Runnable runnable, boolean releaseTtlValueReferenceAfterRun, boolean idempotent) {
        if (null == runnable) return null;
        if (runnable instanceof TtlEnhanced) {
            // avoid redundant decoration, and ensure idempotency
            if (idempotent) return (TtlRunnable) runnable;
            else throw new IllegalStateException("Already TtlRunnable!");
        }
        return new TtlRunnable(runnable, releaseTtlValueReferenceAfterRun);
    }
    
    // 省略其他不太重要的方法
}

其实关注点只需要放在构造函数、run()方法,其他都是基于此做修饰或者扩展。构造函数的源码说明,capture()TtlRunnable实例化的时候已经被调用,实例化它的一般就是父线程,所以整体的执行流程如下:

通过transmittable-thread-local源码理解线程池线程本地变量传递的原理

Agent模块

启用Agent功能,需要在Java的启动参数添加:-javaagent:path/to/transmittable-thread-local-x.yzx.jar。原理是通过Instrumentation回调激发ClassFileTransformer实现目标类的字节码增强,使用到javassist,被增强的类主要是泛线程池的类:

  • Executor体系:主要包括ThreadPoolExecutorScheduledThreadPoolExecutor,对应的字节码增强类实现是TtlExecutorTransformlet
  • ForkJoinPool:对应的字节码增强类实现是TtlForkJoinTransformlet
  • TimerTask:对应的字节码增强类实现是TtlTimerTaskTransformlet

Agent的入口类是TtlAgent,这里查看对应的源码:

public final class TtlAgent {
    
    public static void premain(String agentArgs, @NonNull Instrumentation inst) {
        kvs = splitCommaColonStringToKV(agentArgs);

        Logger.setLoggerImplType(getLogImplTypeFromAgentArgs(kvs));
        final Logger logger = Logger.getLogger(TtlAgent.class);

        try {
            logger.info("[TtlAgent.premain] begin, agentArgs: " + agentArgs + ", Instrumentation: " + inst);
            final boolean disableInheritableForThreadPool = isDisableInheritableForThreadPool();
            // 装载所有的JavassistTransformlet
            final List<JavassistTransformlet> transformletList = new ArrayList<JavassistTransformlet>();
            transformletList.add(new TtlExecutorTransformlet(disableInheritableForThreadPool));
            transformletList.add(new TtlForkJoinTransformlet(disableInheritableForThreadPool));
            if (isEnableTimerTask()) transformletList.add(new TtlTimerTaskTransformlet());
            final ClassFileTransformer transformer = new TtlTransformer(transformletList);
            inst.addTransformer(transformer, true);
            logger.info("[TtlAgent.premain] addTransformer " + transformer.getClass() + " success");
            logger.info("[TtlAgent.premain] end");
            ttlAgentLoaded = true;
        } catch (Exception e) {
            String msg = "Fail to load TtlAgent , cause: " + e.toString();
            logger.log(Level.SEVERE, msg, e);
            throw new IllegalStateException(msg, e);
        }
    }
}

List<JavassistTransformlet>作为参数传入ClassFileTransformer的实现类TtlTransformer中,其中的转换方法为:

public class TtlTransformer implements ClassFileTransformer {

    private final List<JavassistTransformlet> transformletList = new ArrayList<JavassistTransformlet>();

    TtlTransformer(List<? extends JavassistTransformlet> transformletList) {
        for (JavassistTransformlet transformlet : transformletList) {
            this.transformletList.add(transformlet);
            logger.info("[TtlTransformer] add Transformlet " + transformlet.getClass() + " success");
        }
    }

    @Override
    public final byte[] transform(@Nullable final ClassLoader loader, @Nullable final String classFile, final Class<?> classBeingRedefined,
                                  final ProtectionDomain protectionDomain, @NonNull final byte[] classFileBuffer) {
        try {
            // Lambda has no class file, no need to transform, just return.
            if (classFile == null) return NO_TRANSFORM;
            final String className = toClassName(classFile);
            ClassInfo classInfo = new ClassInfo(className, classFileBuffer, loader);
            // 这里做变量,如果字节码被修改,则跳出循环返回
            for (JavassistTransformlet transformlet : transformletList) {
                transformlet.doTransform(classInfo);
                if (classInfo.isModified()) return classInfo.getCtClass().toBytecode();
            }
        } catch (Throwable t) {
            String msg = "Fail to transform class " + classFile + ", cause: " + t.toString();
            logger.log(Level.SEVERE, msg, t);
            throw new IllegalStateException(msg, t);
        }
        return NO_TRANSFORM;
    }
}

这里挑选TtlExecutorTransformlet的部分方法来看:

    @Override
    public void doTransform(@NonNull final ClassInfo classInfo) throws IOException, NotFoundException, CannotCompileException {
        // 如果当前加载的类包含java.util.concurrent.ThreadPoolExecutor或者java.util.concurrent.ScheduledThreadPoolExecutor
        if (EXECUTOR_CLASS_NAMES.contains(classInfo.getClassName())) {
            final CtClass clazz = classInfo.getCtClass();
            // 遍历所有的方法进行增强
            for (CtMethod method : clazz.getDeclaredMethods()) {
                updateSubmitMethodsOfExecutorClass_decorateToTtlWrapperAndSetAutoWrapperAttachment(method);
            }
            // 省略其他代码
        } 
        // 省略其他代码
    }

    private void updateSubmitMethodsOfExecutorClass_decorateToTtlWrapperAndSetAutoWrapperAttachment(@NonNull final CtMethod method) throws NotFoundException, CannotCompileException {
        final int modifiers = method.getModifiers();
        if (!Modifier.isPublic(modifiers) || Modifier.isStatic(modifiers)) return;
        // 这里主要在java.lang.Runnable构造时候调用com.alibaba.ttl.TtlRunnable#get()包装为com.alibaba.ttl.TtlRunnable
        // 在java.util.concurrent.Callable构造时候调用com.alibaba.ttl.TtlCallable#get()包装为com.alibaba.ttl.TtlCallable
        // 并且设置附件K-V为ttl.is.auto.wrapper=true
        CtClass[] parameterTypes = method.getParameterTypes();
        StringBuilder insertCode = new StringBuilder();
        for (int i = 0; i < parameterTypes.length; i++) {
            final String paramTypeName = parameterTypes[i].getName();
            if (PARAM_TYPE_NAME_TO_DECORATE_METHOD_CLASS.containsKey(paramTypeName)) {
                String code = String.format(
                        // decorate to TTL wrapper,
                        // and then set AutoWrapper attachment/Tag
                        "$%d = %s.get($%d, false, true);"
                                + "\ncom.alibaba.ttl.threadpool.agent.internal.transformlet.impl.Utils.setAutoWrapperAttachment($%<d);",
                        i + 1, PARAM_TYPE_NAME_TO_DECORATE_METHOD_CLASS.get(paramTypeName), i + 1);
                logger.info("insert code before method " + signatureOfMethod(method) + " of class " + method.getDeclaringClass().getName() + ": " + code);
                insertCode.append(code);
            }
        }
        if (insertCode.length() > 0) method.insertBefore(insertCode.toString());
    }

上面分析的方法的功能,就是让java.util.concurrent.ThreadPoolExecutorjava.util.concurrent.ScheduledThreadPoolExecutor的字节码被增强,提交的java.lang.Runnable类型的任务会被包装为TtlRunnable,提交的java.util.concurrent.Callable类型的任务会被包装为TtlCallable,实现了无入侵无感知地嵌入TTL的功能。

小结

TTL在使用线程池等会池化复用线程的执行组件情况下,提供ThreadLocal值的传递功能,解决异步执行时上下文传递的问题。 它是一个Java标准库,为框架/中间件设施开发提供的标配能力,项目代码精悍,只依赖了javassist做字节码增强,实现Agent模式下的近乎无入侵提供TTL功能的特性。TTL能在业务代码中实现透明/自动完成所有异步执行上下文的可定制、规范化的捕捉/传递,如果恰好碰到异步执行时上下文传递的问题,建议可以尝试此库。

参考资料:

  • JDK11相关源码
  • TTL源码

个人博客

(本文完 c-14-d e-a-20200502)

上一篇:ping命令的常用方法


下一篇:全志F1C100S从零开发记录(1)