在GPU上加速RWKV6模型的Linear Attention计算

精简版:经过一些profile发现flash-linear-attention中的rwkv6 linear attention算子的表现比RWKV-CUDA中的实现性能还要更好,然后也看到了继续优化triton版本kernel的线索。接着还分析了一下rwkv6 cuda kernel的几次开发迭代以此说明对于不懂cuda以及平时无法从擅长cuda的大佬身上取经的人比如我就完全放弃cuda了,可以深入学一下和使用triton,这已经完全足够了(除了会写之外还可以了解内部的MLIR相关的编译器知识,可以对GPU体系架构理解得更加深刻)。

0x0. 前言

本文主要讲一些看到的RWKV 6模型的Linear Attention模块推理加速方法,在这篇博客中暂不涉及对kernel的深入解析。首先,flash-linear-attention(https://github.com/sustcsonglin/flash-linear-attention )这个仓库旨在对各种线性Attention架构进行工程加速,例如RetNet,GLA,Manba,RWKV6(2024年4月引入)。它使用Triton来编写代码,并针对不同的线性Transformer架构使用不同的优化方式。例如对于RWKV 6就采用在时间维度进行kernel fuse的方式来加速。其次,RWKV-CUDA是RWKV系列模型迭代中针对Linear Attention模块的改进开发的自定义高性能cuda kernel(https://github.com/BlinkDL/RWKV-CUDA)。flash-rwkv(https://github.com/BBuf/flash-rwkv)仓库在RWKV-CUDA的最优性能算子的基础上进行了封装,提供了rwkv5_cuda_linear_attentionrwkv6_cuda_linear_attention两个接口方便在HuggingFace模型实现中直接加速推理的prefill阶段速度。

本篇文章主要会对比一下RWKV6 Linear Attention模块的naive实现(pure pytorch),RWKV-CUDA的RWKV6 Linear Attention cuda kernel实现(用flash-rwkv提供的接口进行测试),flash-linear-attention里的RWKV6 Linear Attention实现。来说明Triton已经成为目前LLM时代开发的一个趋势,小伙伴们确实可以学起来。目前我对Triton的了解也非常少而且很肤浅,后续也会持续学习和实践。

下面列举本文相关的资料,如果你想对RWKV 6这个架构有一些了解可以阅读后面三个链接,当然不阅读也不影响阅读本文:

  • https://github.com/sustcsonglin/flash-linear-attention
  • https://mp.weixin.qq.com/s/Vol_LeHVHDAwE1pWTHOl2Q
  • 梳理RWKV 4,5(Eagle),6(Finch)架构的区别以及个人理解和建议
  • RWKV 模型保姆级微调教程

另外,本文使用了PyTorch Profiler TensorBoard 插件来做程序的性能分析,感兴趣的小伙伴可以在系统调优助手,PyTorch Profiler TensorBoard 插件教程 获取到详细的教程。

0x1. 瓶颈是什么

RWKV6 推理 Prefill 阶段的性能瓶颈就在于RWKV6模型代码中的rwkv6_linear_attention_cpu函数:https://huggingface.co/RWKV/rwkv-6-world-1b6/blob/main/modeling_rwkv6.py#L54-L104

def rwkv6_linear_attention(
    training,
    receptance,
    key,
    value,
    time_decay,
    time_first,
    state,
):
    no_cuda = any(t.device.type != "cuda" for t in [time_decay, time_first, receptance, key, value])
    # Launching the CUDA kernel for just one token will actually be slower (there is no for loop in the CPU version
    # in this case).
    one_token = key.size(1) == 1
    if no_cuda or one_token:
        return rwkv6_linear_attention_cpu(
            receptance, key, value, time_decay, time_first, state
        )
    else:
        ...

这里的判断是如果是decode阶段(对比prefill阶段)或者非GPU模式执行代码,就使用rwkv6_linear_attention_cpu这个算子,否则就使用优化后的实现比如使用这里的cuda kernel(https://github.com/BlinkDL/RWKV-CUDA/tree/main/wkv6)编译出的CUDA Kernel。flash-linear-attention库的目的是使用Triton来加速rwkv6_linear_attention_cpu这个naive的实现。这个naive实现的代码如下:

def hf_rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
    # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    # within a torch.no_grad.
    batch, seq_length, _ = receptance.shape
    num_heads, head_size = time_first.shape
    key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1) # b, t, h, n -> b, h, t, n -> b, h, n, t
    value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # b, t, h, n -> b, h, t, n
    receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # b, t, h, n -> b, h, t, n
    time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1) # b, t, h, n -> b, h, n, t
    time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1) # h, n -> h * n, 1, 1 -> h, n, 1
    out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)

    for current_index in range(seq_length):
        current_receptance = receptance[:, :, current_index:current_index+1, :]
        current_key = key[:, :, :, current_index:current_index+1]
        current_value = value[:, :, current_index:current_index+1, :]
        current_time_decay = time_decay[:, :, :, current_index:current_index+1]
        attention_output = current_key @ current_value
        out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
        with torch.no_grad():
            # attention_output.shape: [b, h, n, 1] x [b, h, 1, n] -> [b, h, n, n]
            # current_time_decay * state: [b, h, n, 1] * [b, h, n, n] ->[b, h, n, n]
            # state.shape: [b, h, n, n]
            state = attention_output + current_time_decay * state 

    return out, state

这样看代码可能会有点懵,可以看下一节的完整demo测试代码。

0x2. Profile代码编写

上一节明确了,我们需要加速RWKV模型中rwkv6_linear_attention_cpu的计算,https://github.com/sustcsonglin/flash-linear-attention 这个库在2024年4月份支持了RWKV6模型,它加速RWKV 6 Linear Attention计算的核心api有两个,fused_recurrent_rwkv6chunk_rwkv6。现在直接写出profile的代码(https://github.com/BBuf/flash-rwkv/blob/main/profile/profile_rwkv6_linear_attention.py)来对naive的实现,RWKV官方提供的cuda kernel以及fused_recurrent_rwkv6chunk_rwkv6进行性能分析。

import sys
import torch
from fla.ops.rwkv6.chunk import chunk_rwkv6
from fla.ops.rwkv6.recurrent_fuse import fused_recurrent_rwkv6
from flash_rwkv import rwkv6_cuda_linear_attention

def hf_rwkv6_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
    # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    # within a torch.no_grad.
    batch, seq_length, _ = receptance.shape
    num_heads, head_size = time_first.shape
    key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
    value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    time_decay = torch.exp(-torch.exp(time_decay.float())).view(batch, seq_length, num_heads, head_size).permute(0, 2, 3, 1)
    time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
    out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)

    for current_index in range(seq_length):
        current_receptance = receptance[:, :, current_index:current_index+1, :]
        current_key = key[:, :, :, current_index:current_index+1]
        current_value = value[:, :, current_index:current_index+1, :]
        current_time_decay = time_decay[:, :, :, current_index:current_index+1]
        attention_output = current_key @ current_value
        out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
        with torch.no_grad():
            state = attention_output + current_time_decay * state

    return out, state



if __name__ == "__main__":
    mode = sys.argv[1]
    B = 1
    H = 32
    L = 54
    D = 64
    HIDDEN_SIZE = H * D
    dtype = torch.float32
    
    if mode == 'hf':
        profile_path = '/bbuf/rwkv_profile_result/hf/'
    elif mode == 'recurrent':
        profile_path = '/bbuf/rwkv_profile_result/recurrent/'
    elif mode == 'chunk':
        profile_path = '/bbuf/rwkv_profile_result/chunk/'
    elif mode == 'cuda':
        profile_path = '/bbuf/rwkv_profile_result/cuda'
    else:
        raise NotImplementedError
    with torch.profiler.profile(
        activities=[
            torch.profiler.ProfilerActivity.CPU,
            torch.profiler.ProfilerActivity.CUDA],
        schedule=torch.profiler.schedule(
            wait=1,
            warmup=1,
            active=1),
        on_trace_ready=torch.profiler.tensorboard_trace_handler(profile_path, worker_name='worker0'),
        record_shapes=True,
        profile_memory=True,  # This will take 1 to 2 minutes. Setting it to False could greatly speedup.
        with_stack=True
    ) as p:
        for i in range(10):
            q = (torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16)).requires_grad_(True)
            k = (torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16)).requires_grad_(True)
            v = torch.randn(B, L, HIDDEN_SIZE).cuda().to(torch.float16).requires_grad_(True)
            w = torch.nn.functional.logsigmoid(torch.randn(B, L, HIDDEN_SIZE)).cuda().to(torch.float32).requires_grad_(True)
            u = (torch.randn(H, D).cuda().to(torch.float16)).requires_grad_(True)
            state = (torch.randn(B, H, D, D).cuda().to(torch.float32)).requires_grad_(True)
            if mode == 'hf':
                o1, state1 = hf_rwkv6_linear_attention_cpu(q, k, v, w, u, state)
            elif mode =='cuda':
                o2, state2 = rwkv6_cuda_linear_attention(q, k, v, w, u.flatten(), state)
            else:
                batch, seq_length, _ = q.shape
                num_heads, head_size = u.shape
                k = k.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K -> B, H, T, K
                v = v.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, T, H, K - > B, H, T, V
                q = q.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2) # B, H, T, K
                w = -torch.exp(w.float()).view(batch, seq_length, num_heads, head_size).permute(0, 2, 1, 3) # B, T, H, K -> B, H, T, K
                u = u.float().reshape(num_heads, head_size) # H, K

                if mode == 'recurrent':
                    o3, state3 = fused_recurrent_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)
                elif mode == 'chunk':
                    o4, state4 = chunk_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)
            p.step()

这段代码就是要分别profile hf_rwkv6_linear_attention_cpurwkv6_cuda_linear_attentionfused_recurrent_rwkv6chunk_rwkv6这三个api看一下它们的性能表现以及GPU kernel的详细使用情况。但这段代码中有一些需要说明的地方:

  • hf_rwkv6_linear_attention_cpu这个api接收的输入Tensor形状和fla包提供的两个加速api的输入Tensor形状不一样,所以在对hf_rwkv6_linear_attention_cpu设定输入之后需要经过一些维度重排操作才能给fla包的两个api使用。
  • 对于time_decay来说,hf_rwkv6_linear_attention_cpu在计算时做了两次exp,而fused_recurrent_rwkv6chunk_rwkv6的api内部会做一次exp,所以输入给fused_recurrent_rwkv6chunk_rwkv6的time_decay只需要做内层的-exp操作就足够了。
  • 对于输出来说,fused_recurrent_rwkv6chunk_rwkv6的结果需要转置一下才能得到和hf_rwkv6_linear_attention_cpu一样的计算结果,state不需要做额外操作,直接就可以对应上。
  • 注意api的调用方式,例如chunk_rwkv6(q, k, v, w, u, initial_state=state, scale=1.0, output_final_state=True)里面的kwargs是缺一不可的。

接下来就可以执行这个profile脚本分别得到这三个api的profile结果了。我在一张NVIDIA A800-SXM4-80GB上进行了profile,结果上传到了 https://github.com/BBuf/flash-rwkv/tree/main/profile/rwkv_profile_result ,你可以通过 tensorboard --logdir=./rwkv_profile_result/recurrent/ --bind_all 这样的命令来可视化结果,并在本地的浏览器中打开 http://localhost:6006/#pytorch_profiler 网址来查看详细的结果。

0x3. Profile结果分析

0x3.1 hf_rwkv6_linear_attention_cpu 函数profile结果

在这里插入图片描述

使用hf_rwkv6_linear_attention_cpu函数进行计算时Kernel部分花了1105us,算子总的时间花了21.5ms,然后它的kernel分布为:

在这里插入图片描述我们可以发现在kernel里面只有gemv相关的矩阵乘调用,并且elementwise算子占比非常大已经接近40%。

0x3.2 rwkv6_cuda_linear_attention API profile结果

在这里插入图片描述
kernel的执行时间为73us,算子执行的总时间只花了4.5ms,相比于naive的实现(21.5)速度有大幅提升。观察GPU kernel执行情况:

在这里插入图片描述
现在rwkv6_cuda_linear_attention中的核心kernel: kernel_forward执行时间为101us。并且现在这个版本只有上面截图的2个kernel有耗时,剩下的2个elementwise的kernel耗时只有2us。在这里插入图片描述
由此可见,使用cuda来编写和优化上面的rwkv6_cuda_linear_attention api可以获得大幅度的性能提升。

0x3.3 fused_recurrent_rwkv6 API profile结果

在这里插入图片描述

现在Kernel执行总时间只有125us,算子总的时间花了5.26ms,相比于naive的实现(21.5)速度有大幅提升,同时kernel的占比也明显更小,GPU kernel分布情况:

在这里插入图片描述
在GPU kernel的具体执行分布中,fused_recurrent_rwkv6_fwd_kernel已经是比例的最大的kernel了,而这个kernel的整体耗时非常低只花了64us,而在naive的实现中则存在数个耗时超过100us的elementwise kernel。目前的整体耗时和优化后的cuda kernel实现也是比较接近的。

0x3.4 chunk_rwkv6 API profile结果

在这里插入图片描述

在这里插入图片描述
chunk_rwkv6的情况和fused_recurrent_rwkv6类似,也是达到了不错的性能。

0x3.5 Profile结果总结

方法 RWKV 6 Linear Attention端到端耗时(us) Kernel最大耗时(us)
hf_rwkv6_linear_attention_cpu 21500 432us
rwkv6_cuda_linear_attention 4500 101us
fused_recurrent_rwkv6 5260 64us
chunk_rwkv6 5602 49us

注:hf_rwkv6_linear_attention_cpu中有很多个耗时比较长的element-wise kernel,性能是最差的,这里只记录了耗时最长的那个element-wise kernel,已经足够说明问题。后续三种方案都通过kernel fuse让hf_rwkv6_linear_attention_cpu实现中的seq_length维度的遍历和众多gemv/elemetwise相关kernel最终fuse成1个或者2个kernel。chunk_rwkv6 api的计算分为2个kernel,耗时分别为27和22us,统计kernel最大耗时的时候进行了求和。

结论:手工优化的rwkv6_cuda_linear_attention在端到端的耗时方面目前是最快的,从上面的profile代码也可以看出来主要原因是因为它不需要对各个输入进行一系列的维度转换,而naive的实现和Triton的实现则必须做一堆维度转换来匹配api提供的计算功能。从Kernel最大耗时的角度看,triton实现的fused_recurrent_rwkv6和chunk_rwkv6 kernel本身的计算是比RWKV-CUDA的手工kernel更快的(虽然还不太清楚Triton实现的版本在编译中发生了什么,但真的找到了放弃cuda的理由,毕竟不是专业做这个东西的,而Triton大家都可以写),后续应该会考虑在Triton kernel的基础上继续做优化以及训练性能验证。

0x4. flash-rwkv库中的rwkv5_cuda_linear_attention开发历程

这里讲一下flash-rwkv库中的rwkv5_cuda_linear_attention这个api背后开发的迭代历程。时间回到2023年8月,ChatGPT的火爆让我也想参与到开源的大模型开发过程中,然后Peng Bo说可以参与到实现RWKV5 CUDA算子的事情。为了锻炼下CUDA就开始参与实现和优化RWKV5 CUDA,在这个过程中也有幸见识到了RWKV开源社区中 https://github.com/Blealtan 这位大佬的优化水平,同时也了解了Parallel Scan算法和实现。后续RWKV6的rwkv6_cuda_linear_attention仍然沿用了rwkv5的cuda kernel,只做了微量的坐标映射修改。

HuggingFace中RWKV5模型的Linear Attention Naive实现在 https://huggingface.co/RWKV/rwkv-5-world-1b5/blob/main/modeling_rwkv5.py#L62-L84 ,贴一下这段代码。

def rwkv5_linear_attention_cpu(receptance, key, value, time_decay, time_first, state):
    input_dtype = receptance.dtype
    # For CPU fallback. Will be slower and probably take more memory than the custom CUDA kernel if not executed
    # within a torch.no_grad.
    batch, seq_length, hidden_size = receptance.shape
    num_heads, head_size = time_first.shape
    key = key.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2).transpose(-2, -1)
    value = value.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    receptance = receptance.float().view(batch, seq_length, num_heads, head_size).transpose(1, 2)
    time_decay = torch.exp(-torch.exp(time_decay.float())).reshape(-1, 1, 1).reshape(num_heads, -1, 1)
    time_first = time_first.float().reshape(-1, 1, 1).reshape(num_heads, -1, 1)
    out = torch.zeros_like(key).reshape(batch, seq_length, num_heads, head_size)

    for current_index in range(seq_length):
        current_receptance = receptance[:, :, current_index:current_index+1, :]
        current_key = key[:, :, :, current_index:current_index+1]
        current_value = value[:, :, current_index:current_index+1, :]
        attention_output = current_key @ current_value
        out[:, current_index] = (current_receptance @ (time_first * attention_output + state)).squeeze(2)
        with torch.no_grad():
            state = attention_output + time_decay * state

    return out, state

要把这段代码变成cuda kernel,首先需要在形式上做一些还原,使得它更靠近原始的计算公式。还原之后的原始计算公式如下(https://github.com/BlinkDL/RWKV-CUDA/blob/main/wkv5/run.py#L67-L87):

def RUN_FORMULA_1A(B, T, C, H, r, k, v, w, u):
    N = C // H
    r = r.view(B, T, H, N)
    k = k.view(B, T, H, N)
    v = v.view(B, T, H, N)
    w = w.view(H, N)
    u 
上一篇:MVC和DDD的贫血和充血模型对比-架构区别


下一篇:Crossplane 实战:构建统一的云原生控制平面