变量的线程传递

背景

我们有些时候会碰到需要在线程池中使用threadLocal的场景,典型的比如分布式链路追踪的traceId,业务场景中的获取用户id等简单信息。如果你去查看下自己公司这块代码的实现,你会发现他们几乎都没有使用jdk自带的ThreadLocal对象,而是使用了alibaba的TransmittleThreadLocal,为什么呢,这就关系到一个线程变量之间的传递性。
在此之前,我们先来了解在java中,一个线程初始化经历了哪些步骤~

线程初始化干了什么

一句话概括,线程初始化时,在参数缺省的情况下,会继承父线程的属性

直接看代码吧,注释在上面。所有的Thread最后都会调用init方法

/*
* 解释下这几个参数
* ThreadGroup g:线程组。线程组见名思义是就是线程的组别,其作用是便于更好的集中管理线程
* Runnable target: 允许的方法
* String name:线程名称,默认是Thread-拼上下一个线程的序号
* long stackSize: 该线程的栈深度,除非是确定该线程的默认配置将用于较深的递归运算,
* 否则不要去修改
* AccessControlContext acc:访问控制上下文,这个不是很了解,和系统安全相关。
* boolean inheritThreadLocals:决定是否要继承父线程的可继承ThreadLocals,默认为true,
* 即继承
*/

public Thread() {
    init(null, null, "Thread-" + nextThreadNum(), 0);
}


private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize) {
    init(g, target, name, stackSize, null, true);
}


private void init(ThreadGroup g, Runnable target, String name,
                  long stackSize, AccessControlContext acc,
                  boolean inheritThreadLocals) {
                  
    // 1. 判断线程的名称是否为空
    if (name == null) {
        throw new NullPointerException("name cannot be null");
    }

    this.name = name;

    Thread parent = currentThread();
    SecurityManager security = System.getSecurityManager();
    // 2. 如果没有设置ThreadGroup,则进行安全检查,并获取父线程的ThreadGroup
    if (g == null) {
        /* Determine if it's an applet or not */

        /* If there is a security manager, ask the security manager
           what to do. */
        if (security != null) {
            g = security.getThreadGroup();
        }

        /* If the security doesn't have a strong opinion of the matter
           use the parent thread group. */
        if (g == null) {
            g = parent.getThreadGroup();
        }
    }

    /* checkAccess regardless of whether or not threadgroup is
       explicitly passed in. */
    g.checkAccess();

    /*
     * Do we have the required permissions?
     */
    if (security != null) {
        if (isCCLOverridden(getClass())) {
            security.checkPermission(SUBCLASS_IMPLEMENTATION_PERMISSION);
        }
    }

    g.addUnstarted();

    this.group = g;
    // 3.下面继承父现场的属性,比如是否是守护线程、优先级、ContextClassLoader...
    this.daemon = parent.isDaemon();
    this.priority = parent.getPriority();
    if (security == null || isCCLOverridden(parent.getClass()))
        this.contextClassLoader = parent.getContextClassLoader();
    else
        this.contextClassLoader = parent.contextClassLoader;
    this.inheritedAccessControlContext =
            acc != null ? acc : AccessController.getContext();
    this.target = target;
    setPriority(priority);
    // 4. 关键,inheritThreadLocals如果是true,就会拷贝父线程的inheritableThreadLocals
    // 到当前线程中
    if (inheritThreadLocals && parent.inheritableThreadLocals != null)
        this.inheritableThreadLocals =
            ThreadLocal.createInheritedMap(parent.inheritableThreadLocals);
    /* Stash the specified stack size in case the VM cares */
    this.stackSize = stackSize;

    /* Set thread ID */
    tid = nextThreadID();
}

看完了上面的代码和注释,相信大家多半有点数了,其中的重点是,子线程copy了父线程的inheritableThreadLocals到自己的inheritableThreadLocals中。而一个ThreadLocal有2个集合存放ThreadLocal,另一个是ThreadLocalsThreadLocals不会被继承
变量的线程传递

如果我们使用ThreadLocal,那set时就会将值放到ThreadLocals中,所以普通的ThreadLocal无法被子线程获取

子线程获取父线程的ThreadLocal

InheritableThreadLocal

InheritableThreadLocal支持在子线程获取父线程中的值

public static void main(String[] args) throws InterruptedException {
    ThreadLocal<String> itl = new InheritableThreadLocal<>();
    itl.set("part");
    new Thread(() -> {
        System.out.println(itl.get());
        itl.set("son");
        System.out.println(itl.get());
    }).start();
    Thread.sleep(500);
    System.out.println("thread:" + itl.get());
}
// 运行结果
part
son
thread:part

为什么呢,相比小伙伴们一定猜出来了,InheritableThreadLocal是从inheritableThreadLocals获取的,而InheritableThreadLocal是默认继承的。

/*
 * Copyright (c) 1998, 2012, Oracle and/or its affiliates. All rights reserved.
 * ORACLE PROPRIETARY/CONFIDENTIAL. Use is subject to license terms.
 */

package java.lang;
import java.lang.ref.*;

/**
 * This class extends <tt>ThreadLocal</tt> to provide inheritance of values
 * from parent thread to child thread: when a child thread is created, the
 * child receives initial values for all inheritable thread-local variables
 * for which the parent has values.  Normally the child's values will be
 * identical to the parent's; however, the child's value can be made an
 * arbitrary function of the parent's by overriding the <tt>childValue</tt>
 * method in this class.
 *
 * <p>Inheritable thread-local variables are used in preference to
 * ordinary thread-local variables when the per-thread-attribute being
 * maintained in the variable (e.g., User ID, Transaction ID) must be
 * automatically transmitted to any child threads that are created.
 *
 * @author  Josh Bloch and Doug Lea
 * @see     ThreadLocal
 * @since   1.2
 */

public class InheritableThreadLocal<T> extends ThreadLocal<T> {
    /**
     * Computes the child's initial value for this inheritable thread-local
     * variable as a function of the parent's value at the time the child
     * thread is created.  This method is called from within the parent
     * thread before the child is started.
     * <p>
     * This method merely returns its input argument, and should be overridden
     * if a different behavior is desired.
     *
     * @param parentValue the parent thread's value
     * @return the child thread's initial value
     */
    protected T childValue(T parentValue) {
        return parentValue;
    }

    /**
     * Get the map associated with a ThreadLocal.
     *
     * @param t the current thread
     */
    ThreadLocalMap getMap(Thread t) {
       return t.inheritableThreadLocals;
    }

    /**
     * Create the map associated with a ThreadLocal.
     *
     * @param t the current thread
     * @param firstValue value for the initial entry of the table.
     */
    void createMap(Thread t, T firstValue) {
        t.inheritableThreadLocals = new ThreadLocalMap(this, firstValue);
    }
}

InheritableThreadLocal本身的设计就很简单,他直接继承ThreadLocal,重写了createMap等方法,从线程的InheritableThreadLocals中获取

线程池中使用ThreadLocal

那InheritableThreadLocal既然能够获取父线程的ThreadLocal了,那为什么还要用TransmittableThreadLocal呢?

因为InheritableThreadLocal不支持池化场景

看懂上面代码的小伙伴一定能够注意到线程变量的继承是发生在线程init这个节点,对于线程池这种已经将线程创建好的场景,并不可能再去调用init。
TransmittableThreadLocal则克服了这种场景,它在InheritableThreadLocal的基础上使线程池的线程也可以拿到当前线程的变量。

先说明TransmittableThreadLocal的解题思路

包装Runnable,将调用线程的变量取出来,一起传入线程池中

先看一下TransmittableThreadLocal的使用方式


public static void main(String[] args) throws InterruptedException {
     // 初始化线程池
     ExecutorService executorService = new ThreadPoolExecutor(1, 1,
             0, TimeUnit.MILLISECONDS, new ArrayBlockingQueue<>(1));
     // 包装线程池,使得在向线程池提交任务时, 会无侵入地将当前线程的线程变量传递进去
     Executor ttlExecutor = TtlExecutors.getTtlExecutor(executorService);
     // 使得核心线程创建,调用init方法,结束后该线程已池化到线程池中。
     executorService.execute(()->{
         log.info("xx");
     });
     // 创建一个ttl并设置值
     TransmittableThreadLocal<String> ttl = new TransmittableThreadLocal<>();
     ttl.set("ps5");
     Thread.sleep(1000L);
     // 在线程池中获取ttl的值
     ttlExecutor.execute(()->{
         log.info(ttl.get()); // 正常输出
     });
}

源码简单解析

由于ttl的代码还挺复杂,我也是一知半解,这里只描述下核心代码和流程,帮助大家理解ttl的原理, 在此先列出要使用ttl的步骤(结合代码)

创建一个线程池

没啥好说的

包装提交的线程池任务


public class TtlExecutors{

    public static Executor getTtlExecutor(@Nullable Executor executor) {
        if (TtlAgent.isTtlAgentLoaded() || null == executor || executor instanceof TtlEnhanced) {
            return executor;
        }
        // 返回包装后的线程池
        return new ExecutorTtlWrapper(executor, true);
    }
    // ...
}

它到底干了什么呢?,看一下ExecutorTtlWrapper,它重写了execute方法,将Runable包装成了
TtlRunnable,相比较Runable,TtlRunable中最重要的一点就是包含了AtomicReference对象,该对象持有ttl的副本的引用,当我们向线程池提交任务的时候,该引用的值会被放入当前线程的InheritableThreadLocals中,所以我们能在方法中引用ttl~

package com.alibaba.ttl.threadpool;

import com.alibaba.ttl.TransmittableThreadLocal;
import com.alibaba.ttl.TtlRunnable;
import com.alibaba.ttl.spi.TtlEnhanced;
import com.alibaba.ttl.spi.TtlWrapper;
import edu.umd.cs.findbugs.annotations.NonNull;

import java.util.concurrent.Executor;

/**
 * {@link TransmittableThreadLocal} Wrapper of {@link Executor},
 * transmit the {@link TransmittableThreadLocal} from the task submit time of {@link Runnable}
 * to the execution time of {@link Runnable}.
 *
 * @author Jerry Lee (oldratlee at gmail dot com)
 * @since 0.9.0
 */
class ExecutorTtlWrapper implements Executor, TtlWrapper<Executor>, TtlEnhanced {
    private final Executor executor;
    protected final boolean idempotent;

    ExecutorTtlWrapper(@NonNull Executor executor, boolean idempotent) {
        this.executor = executor;
        this.idempotent = idempotent;
    }
    
    // 重点嗷
    // 包装Runnable~
    @Override
    public void execute(@NonNull Runnable command) {
        executor.execute(TtlRunnable.get(command, false, idempotent));
    }
    // ...
}

TtlRunnable对象,包装了Runnable,持有AtomicReference


public final class TtlRunnable implements Runnable, TtlWrapper<Runnable>, TtlEnhanced, TtlAttachments {
    // 存放了外部线程ttl的局部变量的副本
    private final AtomicReference<Object> capturedRef;
    private final Runnable runnable;
    private final boolean releaseTtlValueReferenceAfterRun;
    
    // 初始化
    private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // 这里的capture方法很重要,它获取当前线程中ttl的值的副本的引用
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

    /**
     * wrap method {@link Runnable#run()}.
     */
    @Override
    // 略

capturedRef是什么时候有值的呢?当然是在构造该对象的时候啦~
来看一下它的构造方法

private TtlRunnable(@NonNull Runnable runnable, boolean releaseTtlValueReferenceAfterRun) {
        // 这里的capture方法很重要,它获取当前线程中ttl的值的副本的引用
        this.capturedRef = new AtomicReference<Object>(capture());
        this.runnable = runnable;
        this.releaseTtlValueReferenceAfterRun = releaseTtlValueReferenceAfterRun;
    }

重点,capture方法


// 返回的是ttl的快照,是一个Snapshot对象
@NonNull
public static Object capture() {
    return new Snapshot(captureTtlValues(), captureThreadLocalValues());
}

// 从当前线程(执行到这部时还是调用线程而非线程池中的执行线程)中获取ttl,并值塞入一个map中
private static HashMap<TransmittableThreadLocal<Object>, Object> captureTtlValues() {
    HashMap<TransmittableThreadLocal<Object>, Object> ttl2Value = new HashMap<TransmittableThreadLocal<Object>, Object>();
    // 拷贝ttl的副本到map中
    for (TransmittableThreadLocal<Object> threadLocal : holder.get().keySet()) {
        ttl2Value.put(threadLocal, threadLocal.copyValue());
    }
    // 返回当前map,作为snapshot的一部分
    return ttl2Value;
}

// 这个还没看懂是干啥的...
private static HashMap<ThreadLocal<Object>, Object> captureThreadLocalValues() {
    final HashMap<ThreadLocal<Object>, Object> threadLocal2Value = new HashMap<ThreadLocal<Object>, Object>();
    for (Map.Entry<ThreadLocal<Object>, TtlCopier<Object>> entry : threadLocalHolder.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        final TtlCopier<Object> copier = entry.getValue();

        threadLocal2Value.put(threadLocal, copier.copy(threadLocal.get()));
    }
    return threadLocal2Value;
}

从上面可以看到,capturedRef中保存的就是ttl的副本的AtomicReference引用,其实上面有一个非常核心的holder全局变量,它是用来存放ttl的,稍后细说。到此刻,ttl在线程池的传递的前置工作(指由调用线程执行的动作)都已经全部完成。

总结下这一步干了什么事

  1. 将线程池包装成ExecutorTtlWrapper,被包装成ExecutorTtlWrapper的线程池会将每一个提交到线程池中的Runnable变成TtlRunnable
  2. 从当前线程获取ttl上下文的副本,并用AtomicReference存在TtlRunnable

任务执行

从这里开始,执行线程就是线程池中的线程了

/**
* TtlRunnable
*/
@Override
public void run() {
    // 获取当前TtlRunnable的ttl副本引用
    final Object captured = capturedRef.get();
    if (captured == null || releaseTtlValueReferenceAfterRun && !capturedRef.compareAndSet(captured, null)) {
        throw new IllegalStateException("TTL value reference is released after run!");
    }
    // 将调用线程的ttl存入当前线程上下文,并返回备份的存入前当前线程的上下文
    final Object backup = replay(captured);
    // 执行
    try {
        // 此时可以从inheritThreadLocals中直接拿到ttl的值
        runnable.run();
    } finally {
        // 恢复备份的当前线程的上下文
        restore(backup);
    }
}



// 将Snapshot的值塞入到当前线程的inheritThreadLocals
@NonNull
public static Object replay(@NonNull Object captured) {
    final Snapshot capturedSnapshot = (Snapshot) captured;
    return new Snapshot(replayTtlValues(capturedSnapshot.ttl2Value), replayThreadLocalValues(capturedSnapshot.threadLocal2Value));
}

@NonNull
private static HashMap<TransmittableThreadLocal<Object>, Object> replayTtlValues(@NonNull HashMap<TransmittableThreadLocal<Object>, Object> captured) {
    HashMap<TransmittableThreadLocal<Object>, Object> backup = new HashMap<TransmittableThreadLocal<Object>, Object>();

    for (final Iterator<TransmittableThreadLocal<Object>> iterator = holder.get().keySet().iterator(); iterator.hasNext(); ) {
        TransmittableThreadLocal<Object> threadLocal = iterator.next();
        // 备份当前ttl上下文中的所有值
        backup.put(threadLocal, threadLocal.get());

        // clear the TTL values that is not in captured
        // avoid the extra TTL values after replay when run task
        // 从线程上下文中清除不是调用线程ttl的ttl
        if (!captured.containsKey(threadLocal)) {
            iterator.remove();
            threadLocal.superRemove();
        }
    }

    // 设置ttl值到当前线程的inheritThreadLocals
    setTtlValuesTo(captured);

    // call beforeExecute callback
    doExecuteCallback(true);
    // 返回备份
    return backup;
}

private static HashMap<ThreadLocal<Object>, Object> replayThreadLocalValues(@NonNull HashMap<ThreadLocal<Object>, Object> captured) {
    final HashMap<ThreadLocal<Object>, Object> backup = new HashMap<ThreadLocal<Object>, Object>();

    for (Map.Entry<ThreadLocal<Object>, Object> entry : captured.entrySet()) {
        final ThreadLocal<Object> threadLocal = entry.getKey();
        backup.put(threadLocal, threadLocal.get());

        final Object value = entry.getValue();
        if (value == threadLocalClearMark) threadLocal.remove();
        else threadLocal.set(value);
    }

    return backup;
}

总结下这一步干了什么事

  1. 从captureRef中获得ttl的副本
  2. 备份当前线程的上下文,并将captureRef引用的副本都写入inheritThreadLocal中
  3. 执行方法
  4. 还原当前线程的上下文

这些步骤执行完,我们就可以在线程池中拿到ttl变量了

holder的作用

从上面的代码我们可以发现,holder就是存放ttl的地方,我们看一下holder的结构


private static final InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>> holder =
    new InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal<Object>, ?>>() {
        @Override
        protected WeakHashMap<TransmittableThreadLocal<Object>, ?> initialValue() {
            return new WeakHashMap<TransmittableThreadLocal<Object>, Object>();
        }

        @Override
        protected WeakHashMap<TransmittableThreadLocal<Object>, ?> childValue(WeakHashMap<TransmittableThreadLocal<Object>, ?> parentValue) {
            return new WeakHashMap<TransmittableThreadLocal<Object>, Object>(parentValue);
        }
    };

holder其实是一个全局静态的InheritableThreadLocal<WeakHashMap<TransmittableThreadLocal, ?>>变量,注意holder是全局静态的变量,但是通过holder只能操作到当前线程的上下文,理解这点很重要
当我们对ttl进行set时,除了将值存到inheritThreadLocals,还会将值存入到holder中。但需要注意,get时,ttl还是从inheritThreadLocals中读取值,而不是holder。

可以理解为holder就是存取ttl入口,以及为了使ttl具备线程池传递性的缓冲层,而ttl真正存储的位置还是inheritThreadLocals

上一篇:Hive UDF 实验1


下一篇:RabbitMQ的高级特性--TTL、死信队列、延迟队列