
LLM Inference Acceleration: GPU Optimization for Attention in the Decode Phase

This article introduces how Attention in the decode phase is optimized on GPUs, based on RTP-LLM practices.

By Jiying Dong

Since Large Language Models (LLMs) are now widely used across many fields, building high-throughput, low-latency inference services at low cost has become an urgent problem. LLM inference on a GPU involves a large number of parameters and a heavy computational load, so even a single stream of execution can keep the GPU fully busy. This lets us break the inference latency of an LLM down to the kernel level. If we leave aside the kernels that account for only a small share of the total time, latency optimization for LLMs reduces to optimizing the GEMM and Attention kernels.

RTP-LLM is a Large Language Model inference acceleration engine developed by Alibaba's Intelligence Engine team. As a high-performance LLM inference solution, it is widely used within Alibaba Group. In this article, we describe how Attention in the decode phase is optimized on GPUs, based on our experience with RTP-LLM.

Background

The following figure shows the common process of computing attention. Q is multiplied by K (transposed), and the result is scaled and masked before SoftMax. The SoftMax output is then multiplied by V to obtain the attention output. In the decode phase of LLM inference, the KV Cache optimization means that only the single newly generated token needs to be computed in each iteration. The computation therefore reduces to one between the current step's Q (seq length 1), the K Cache, and the V Cache.

[Figures 1 and 2: computing the attention score]

The shape of each tensor in the computing process can be expressed as follows:

Q (B, H, 1, D)
K Cache (B, H_kv, S, D)
V Cache (B, H_kv, S, D)
Q * K Cache^T (B, H, 1, S)
O (B, H, 1, D)

The parameters are:

B: batch size (num_seqs)
H: number of attention heads (head_num)
H_kv: number of KV heads (head_num_kv)
S: sequence length of the KV Cache
D: size of each head (head_size)

For the analysis in this article, we consider plain Multi-Head Attention, that is, H == H_kv.

We want a single kernel to implement the computation in the preceding figure. For better performance, the BiasAdd and Rotary Embedding operations from the previous step are fused into this kernel as well. The kernel's input is therefore the Q, K, and V produced by the QKV GEMM, and BiasAdd is performed inside the kernel. Rotary Embedding is then applied to Q and K. The current K and V are appended to the previously computed KV Cache, which extends it to shape (B, H, S, D). Next, Q is multiplied by K Cache^T, and SoftMax is computed over the S dimension. Finally, the result is multiplied by V Cache to obtain the output.

The simplified code example is as follows:

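import math
import torch

def decode_attention(qkv_buffer, bias, past_k, past_v, rotary_embedding):
    # BiasAdd, then split (B, 3, H, D) -> 3 x (B, H, 1, D); bias has shape (3, H, D)
    q, k, v = (qkv_buffer + bias).unsqueeze(3).unbind(dim=1)
    # Rotary Embedding on the new Q and K; shapes stay (B, H, 1, D)
    q, k = rotary_embedding(q, k)
    # Append to the KV Cache: (B, H, S - 1, D) + (B, H, 1, D) -> (B, H, S, D)
    k = torch.cat([past_k, k], dim=2)
    v = torch.cat([past_v, v], dim=2)
    # (B, H, 1, D) @ (B, H, D, S) -> (B, H, 1, S)
    scores = torch.matmul(q, k.transpose(-2, -1)) / math.sqrt(q.size(-1))
    # SoftMax over S; no causal mask is needed since the single new query sees every cached position
    probs = torch.softmax(scores, dim=-1)
    # (B, H, 1, S) @ (B, H, S, D) -> (B, H, 1, D)
    out = torch.matmul(probs, v)
    return out, k, v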

During the entire computing process, BiasAdd and Rotary Embedding have relatively low computation requirements and little impact on the latency of the kernel. Therefore, the analysis for this part is omitted in the following section.

Computing and Analysis

We use the Masked Multi-Head Attention (MMHA) implementation in TensorRT-LLM as an example to analyze how it achieves high performance.

For GPU parallel computing, we first need to decide how to split the task. Here the split is clear: B and H are parallel dimensions. The two matrix multiplications (Q * K Cache^T, and the SoftMax result * V Cache) can each be viewed as a batched GEMV with batch size B * H. SoftMax is a Reduce operation along S, so it is best to compute each GEMV within a single block where possible. The basic task division of MMHA is therefore roughly:

dim3 grid(B, H, 1);
dim3 block(THREAD_PER_BLOCK, 1, 1);
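To make this task division concrete, here is a minimal pure-Python sketch (not CUDA) that emulates the launch configuration above. Each (b, h) pair plays the role of one block and independently computes the decode attention of a single head over S. The function names are hypothetical. The point is only that the B * H blocks share no data, which is what makes them a batch of independent GEMVs.

```python
import math

def attend_one_head(q, k_cache, v_cache):
    """Work of one block: q is a D-vector, k_cache/v_cache are S lists of D-vectors."""
    scale = 1.0 / math.sqrt(len(q))
    # First GEMV: one logit per cached position s
    logits = [scale * sum(qi * ki for qi, ki in zip(q, k)) for k in k_cache]
    # SoftMax over S (a reduction, so it stays inside the block)
    m = max(logits)
    e = [math.exp(x - m) for x in logits]
    total = sum(e)
    p = [x / total for x in e]
    # Second GEMV: SoftMax-weighted sum of the V rows
    d = len(v_cache[0])
    return [sum(p[s] * v_cache[s][j] for s in range(len(p))) for j in range(d)]

def launch_grid(Q, K, V):
    """Emulates dim3 grid(B, H, 1): Q[b][h] is a D-vector, K[b][h] and V[b][h] are S x D."""
    B, H = len(Q), len(Q[0])
    out = [[None] * H for _ in range(B)]
    for b in range(B):          # blockIdx.x
        for h in range(H):      # blockIdx.y
            out[b][h] = attend_one_head(Q[b][h], K[b][h], V[b][h])
    return out
```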

THREAD_PER_BLOCK is the number of threads each block uses to compute one head along S. More threads generally increase the number of active warps per SM, which makes better use of compute resources. They also issue more load instructions, which improves data load efficiency. A larger THREAD_PER_BLOCK is therefore preferred, ideally close to 1024. However, the kernel's overall logic is complex and uses many registers, so the thread count may be limited by the total number of registers available. Under this register limit, we can assume for simplicity that only one block is active on each SM.
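The register limit can be estimated with simple arithmetic. The sketch below makes several assumptions:

- The SM is A100-class, with 65,536 32-bit registers.
- A block can hold at most 1,024 threads.
- The hardware's register allocation granularity is ignored.

Treat its result as an upper bound rather than what an occupancy calculator would report. The per-thread register counts in the usage note are illustrative, not measured from MMHA.

```python
REGISTERS_PER_SM = 65536      # A100: 64K 32-bit registers per SM
MAX_THREADS_PER_BLOCK = 1024
WARP_SIZE = 32

def max_threads_per_block(regs_per_thread: int) -> int:
    """Upper bound on threads in the single resident block, limited by registers.

    Simplification: register allocation granularity is ignored.
    """
    by_regs = REGISTERS_PER_SM // regs_per_thread
    by_regs -= by_regs % WARP_SIZE   # round down to whole warps
    return min(MAX_THREADS_PER_BLOCK, by_regs)
```

For example, `max_threads_per_block(128)` gives 512. A heavily fused kernel that needs 128 registers per thread therefore cannot use a full 1,024-thread block.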

Based on this division, let's analyze how each block works. The QKV buffer passed into the kernel actually has layout (B, 3, H, D). In the TensorRT-LLM implementation, the Q and K of the current step are loaded first, and BiasAdd and RoPE are applied. The new K is then written back to the K Cache in the global buffer. Because the data is still in registers afterward, the corresponding QK dot product is computed right away. These steps take little time, so we skip them and look directly at how TensorRT-LLM computes Q * K Cache.

Q * K Cache^T is a reduction along D. Suppose the KV Cache is stored in half precision and multiply-accumulate is done in float. For efficient loads, each thread loads 16 consecutive bytes, which is 8 half elements. For the common D == 128, 16 threads are therefore needed to cover one key vector. The threads in a block can be viewed as groups of 16, and each group computes the dot product of Q with the key at one position. Each thread loads 8 elements and multiplies and accumulates them. The group then combines its partial sums through warp shuffles and stores the result in SMEM. Different groups cover different positions along S.

[Figure 3: computing Q * K Cache]
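The grouping can be emulated in Python to check the index arithmetic. The sketch makes these assumptions:

- D = 128.
- Each thread does a 16-byte load of half values, so it handles 8 elements.
- The threads combine their results with a butterfly reduction that mimics `__shfl_xor_sync` within a 16-thread group.

It models data movement only and says nothing about timing. The names and the 256-thread block size are illustrative, not TensorRT-LLM's. The 1/sqrt(D) scaling is omitted.

```python
ELEMS_PER_THREAD = 8                       # 16 bytes / 2 bytes per half
D = 128
THREADS_PER_KEY = D // ELEMS_PER_THREAD    # 16 threads cooperate on one key

def group_qk_dot(q, k):
    # Thread t multiplies its 8 consecutive elements (fp32 accumulation)
    partial = [
        sum(q[t * ELEMS_PER_THREAD + i] * k[t * ELEMS_PER_THREAD + i]
            for i in range(ELEMS_PER_THREAD))
        for t in range(THREADS_PER_KEY)
    ]
    # Butterfly reduction with offsets 8, 4, 2, 1, as with __shfl_xor_sync;
    # each step builds a new list, so all lanes read the old values simultaneously
    offset = THREADS_PER_KEY // 2
    while offset > 0:
        partial = [partial[t] + partial[t ^ offset] for t in range(THREADS_PER_KEY)]
        offset //= 2
    return partial[0]   # every lane now holds the full dot product

def qk_logits(q, k_cache, threads_per_block=256):
    """Groups stride over S: group g handles positions g, g + n_groups, ..."""
    n_groups = threads_per_block // THREADS_PER_KEY
    logits = [0.0] * len(k_cache)   # stands in for the SMEM buffer
    for g in range(n_groups):
        for s in range(g, len(k_cache), n_groups):
            logits[s] = group_qk_dot(q, k_cache[s])
    return logits
```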

Next, we will calculate the SoftMax. Since the previous computations ensure that the inputs required by SoftMax are all in SMEM of the current block, the computation of SoftMax can be completed through Block Reduce Max and Block Reduce Sum.
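A Block Reduce is typically a tree reduction over values held in SMEM. The sketch below emulates one with a halving stride loop. It then uses it twice, once with max and once with sum, to produce a numerically stable SoftMax. Padding with the reduction's identity element is an assumption made here to handle S that is not a power of two.

```python
import math

def block_reduce(values, op, identity):
    """Tree reduction as a block would do it in SMEM: halve the active range each step."""
    n = 1
    while n < len(values):
        n *= 2
    buf = list(values) + [identity] * (n - len(values))
    stride = n // 2
    while stride > 0:
        for i in range(stride):          # threads with i < stride are active
            buf[i] = op(buf[i], buf[i + stride])
        stride //= 2                     # a __syncthreads() would sit here on the GPU
    return buf[0]

def block_softmax(logits):
    m = block_reduce(logits, max, -math.inf)          # Block Reduce Max
    e = [math.exp(x - m) for x in logits]             # subtract the max for stability
    total = block_reduce(e, lambda a, b: a + b, 0.0)  # Block Reduce Sum
    return [x / total for x in e]
```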

Multiplying by V Cache follows much the same approach as multiplying by K Cache. The slight difference is that this step accumulates along S. Threads are still grouped: each group of 16 threads handles one D-length row, and each thread is responsible for 8 elements. Because the accumulation runs along S, each thread keeps a partial sum for its 8 elements over the positions its group processes. Finally, the partial sums from different threads are accumulated through SMEM to produce the Attention output.

[Figure 4: multiplying by V Cache]
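The V Cache step can be emulated the same way. The sketch keeps the same assumptions: D = 128, 8 elements per thread, and 16 threads per group. Each group walks its own subset of S and keeps an 8-element fp32 partial sum per thread. The final loop stands in for the SMEM reduction across groups. The group count is an illustrative parameter.

```python
ELEMS_PER_THREAD = 8
D = 128
THREADS_PER_GROUP = D // ELEMS_PER_THREAD   # 16

def pv_accumulate(probs, v_cache, n_groups=16):
    S = len(probs)
    # acc[g][t] is the 8-element partial sum held in registers by thread t of group g
    acc = [[[0.0] * ELEMS_PER_THREAD for _ in range(THREADS_PER_GROUP)]
           for _ in range(n_groups)]
    for g in range(n_groups):
        for s in range(g, S, n_groups):             # groups stride over S
            for t in range(THREADS_PER_GROUP):
                base = t * ELEMS_PER_THREAD
                for i in range(ELEMS_PER_THREAD):
                    acc[g][t][i] += probs[s] * v_cache[s][base + i]
    # Cross-group reduction (done through SMEM on the GPU)
    out = [0.0] * D
    for g in range(n_groups):
        for t in range(THREADS_PER_GROUP):
            for i in range(ELEMS_PER_THREAD):
                out[t * ELEMS_PER_THREAD + i] += acc[g][t][i]
    return out
```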

Besides HFMA, the QK dot product for a single head can also use HMMA (tensor core) instructions. However, the kernel is bound by memory access, so the instruction used for the dot product has little impact on performance. Our tests confirm this.
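A back-of-the-envelope arithmetic-intensity estimate shows why the choice of instruction matters little. The sketch makes these assumptions:

- The K and V Caches are stored in half precision.
- Each multiply-add counts as 2 FLOPs.
- Q, the output, and SoftMax are ignored, since they are small next to the KV Cache traffic for long S.

The `group_size` parameter (the number of query heads sharing one KV head, 1 for MHA) is added here for illustration.

```python
def decode_attention_intensity(S, D, group_size=1, bytes_per_elem=2):
    """FLOPs per byte of KV Cache read for one KV head in one decode step."""
    kv_bytes = 2 * S * D * bytes_per_elem      # K Cache + V Cache
    flops_qk = 2 * S * D * group_size          # Q * K^T: one multiply-add per element per query head
    flops_pv = 2 * S * D * group_size          # P * V
    return (flops_qk + flops_pv) / kv_bytes
```

For MHA this gives about 1 FLOP per byte. That is far below an A100's ratio of peak FP16 tensor-core throughput to HBM bandwidth, which is on the order of 100 or more. The kernel is therefore limited by memory bandwidth whichever instruction computes the dot product. The same arithmetic shows why sharing KV heads, as in GQA, raises the intensity in proportion to the group size.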

Some details are omitted from the analysis above. For example, the KV Cache is now usually stored in a paged KV Block Array. The cache can then be non-contiguous along S, and buffers can be allocated dynamically as S grows. However, paging does not change the contiguity of the D dimension, so it does not affect the analysis above. In addition, when loading the KV Cache, each thread prefetches an extra chunk into registers, so that data loading overlaps with the dot-product computation as much as possible.
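The paged layout only changes how a position on S maps to memory. The sketch below shows that mapping under two assumptions: each sequence has a block table, and each block holds a fixed number of tokens. Both names are hypothetical and are not the RTP-LLM or TensorRT-LLM API, and the head and K/V dimensions are left out for brevity. The D elements of one token stay contiguous, so the 16-byte vectorized loads described earlier are unaffected.

```python
def kv_offset(block_table, s, d, tokens_per_block, head_size):
    """Element offset of (position s, channel d) for one head in a paged KV pool.

    block_table[i] is the physical block holding logical positions
    [i * tokens_per_block, (i + 1) * tokens_per_block).
    Each physical block stores tokens_per_block x head_size elements, D innermost.
    """
    physical_block = block_table[s // tokens_per_block]
    slot = s % tokens_per_block
    return (physical_block * tokens_per_block + slot) * head_size + d
```

Consecutive positions may land in unrelated physical blocks, but `kv_offset(table, s, d + 1, ...)` is always `kv_offset(table, s, d, ...) + 1`.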

Mainstream frameworks such as vLLM and xFormers implement and optimize MMHA along similar lines, differing only in details. Beyond MMHA, TensorRT-LLM also implements XQA to further optimize attention in the decode phase. XQA is not open-sourced, so this article does not analyze it.

Improvement and Optimization

Of course, the simple optimizations discussed above are still not sufficient in practical applications, especially in scenarios with small batch sizes (B) and long sequences (S).

Consider the actual GPU resources. An A100 has 108 SMs, and each SM runs only one block at a time, which means it computes only one head. Utilization is relatively high only when the number of blocks, B * H, exactly fills the 108 SMs or a multiple of 108. Take the 7B model, or the 72B model with a tensor-parallel degree of 2 (TP = 2), as an example. With H = 32 and B = 3, utilization is 96 / 108 = 88.9%. With B = 4, the 128 blocks need two rounds of computation, and utilization drops to 128 / 216 = 59%. With B = 1, it is as low as 30%. If S is large, most of the device sits idle while we wait for a few SMs to finish a long computation.
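The utilization figures above follow from a one-line wave calculation. The sketch assumes that every block takes the same time and that `blocks_per_sm` blocks fit on each SM (1 in the analysis above).

```python
import math

def sm_utilization(num_blocks, num_sms=108, blocks_per_sm=1):
    """Fraction of SM slots doing useful work, assuming equal-duration blocks."""
    slots_per_wave = num_sms * blocks_per_sm
    waves = math.ceil(num_blocks / slots_per_wave)
    return num_blocks / (waves * slots_per_wave)
```

With H = 32, `sm_utilization(32 * 3)`, `sm_utilization(32 * 4)`, and `sm_utilization(32 * 1)` give about 0.889, 0.593, and 0.296. These match the 88.9%, 59%, and 30% quoted above.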

To address this, we also map S onto the grid, and the resource allocation becomes:

dim3 grid(B, H, S_tile);
dim3 block(THREAD_PER_BLOCK, 1, 1);

Under this task division, and keeping the earlier assumption that only one block is active per SM for long sequences, the number of waves is:

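$$\text{waves} = \frac{B \times H \times S_{tile}}{N_{SM}}$$

where N_SM is the number of SMs (108 on A100).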

The closer waves is to ceil(waves), the higher the device occupancy. When B is small but S is large, setting S_tile > 1 increases occupancy. In this case, S_tile blocks jointly compute one head along S, and each block handles S / S_tile positions. The blocks' partial results are then reduced through an additional global buffer. The extra global reads and writes cost time in this mode. However, because device occupancy increases, performance improves significantly when B is small and S is large. This is the idea behind flash decoding, and it is supported by all the mainstream frameworks mentioned above.
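The cross-block Reduce can be sketched as follows. Each of the S_tile blocks writes three values to the global buffer: its local max, its local sum of exponentials, and its unnormalized partial output. A final pass then rescales and combines them. This is the log-sum-exp merge that flash decoding relies on. The buffer layout and the function names here are assumptions, not any framework's code.

```python
import math

def tile_partial(q, k_tile, v_tile):
    """Work of one block over its slice of S: returns (max, sum_exp, unnormalized out)."""
    scale = 1.0 / math.sqrt(len(q))
    logits = [scale * sum(a * b for a, b in zip(q, k)) for k in k_tile]
    m = max(logits)
    w = [math.exp(x - m) for x in logits]
    acc = [sum(w[s] * v_tile[s][j] for s in range(len(w))) for j in range(len(v_tile[0]))]
    return m, sum(w), acc

def combine_tiles(partials):
    """Second pass: merge the per-tile results read back from the global buffer."""
    g = max(m for m, _, _ in partials)
    total = sum(l * math.exp(m - g) for m, l, _ in partials)
    d = len(partials[0][2])
    out = [0.0] * d
    for m, _, acc in partials:
        factor = math.exp(m - g)          # rescale each tile to the global max
        for j in range(d):
            out[j] += acc[j] * factor
    return [x / total for x in out]

def split_s_attention(q, k_cache, v_cache, s_tile):
    """One head, with S split into at most s_tile contiguous chunks."""
    chunk = math.ceil(len(k_cache) / s_tile)
    partials = [tile_partial(q, k_cache[i:i + chunk], v_cache[i:i + chunk])
                for i in range(0, len(k_cache), chunk)]
    return combine_tiles(partials)
```

Splitting changes only where the reduction happens, not the result. `split_s_attention(q, K, V, 4)` matches `split_s_attention(q, K, V, 1)` up to floating-point rounding.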

Beyond performance, splitting along S is also required for ultra-long sequences. The Q * K results must be reduced along S, so SMEM has to hold intermediate data proportional to S. In this kernel implementation, the inputs are half and accumulation is done in float, which comes to roughly 6S bytes. On A100, the SMEM actually available per SM is 163 KB, so the maximum supported S is about 27K. For longer sequences, the work must be split along S for the kernel to run at all.
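The 27K limit is a single division. The sketch reads the 6 bytes per position as a 4-byte float logit plus a 2-byte half SoftMax value. That breakdown is our assumption; the text gives only the total.

```python
def max_seq_len_in_smem(smem_bytes, bytes_per_position=6):
    # Assumed split: 4 bytes (fp32 QK logit) + 2 bytes (half SoftMax value) per position
    return smem_bytes // bytes_per_position
```

`max_seq_len_in_smem(163 * 1024)` returns 27818, which is the "about 27K" quoted above.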

Another scenario that requires a different task split is GQA. In GQA, each KV head is shared by the Q of multiple heads. To avoid loading the same KV Cache repeatedly, the grid is organized by KV head, and the computation is modified accordingly:

dim3 grid(B, H_kv, S_tile);
dim3 block(THREAD_PER_BLOCK, 1, 1);
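With grid(B, H_kv, S_tile), each block owns one KV head and serves all H / H_kv query heads that share it. The sketch below assumes query heads are grouped contiguously, so head h uses KV head h // (H / H_kv), a common convention. It shows each loaded key being reused by every query head in the group. The function names are hypothetical.

```python
import math

def kv_head_for(h, num_heads, num_kv_heads):
    """KV head used by query head h, assuming contiguous grouping."""
    return h // (num_heads // num_kv_heads)

def gqa_block(q_heads, k_cache, v_cache):
    """One block of grid(B, H_kv, S_tile): q_heads holds the G query vectors of the group."""
    scale = 1.0 / math.sqrt(len(k_cache[0]))
    G = len(q_heads)
    logits = [[0.0] * len(k_cache) for _ in range(G)]
    for s, k in enumerate(k_cache):          # each key is loaded once...
        for g in range(G):                   # ...and reused by all G query heads
            logits[g][s] = scale * sum(a * b for a, b in zip(q_heads[g], k))
    outs = []
    for g in range(G):
        m = max(logits[g])
        w = [math.exp(x - m) for x in logits[g]]
        total = sum(w)
        outs.append([sum(w[s] * v_cache[s][j] for s in range(len(w))) / total
                     for j in range(len(v_cache[0]))])
    return outs
```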

In addition to the optimization for task division, MMHA can also be optimized from the following aspects:

1) Optimizing register usage may help reach a higher occupancy rate (you can launch multiple blocks on one SM or increase the threads of each block).

2) Continue to adjust the loading behavior of the KV Cache, allowing computation and data retrieval to further overlap, so as to ease the memory-bound scenario.

3) With large batch sizes and the GQA model, the attention mechanism will become compute-bound, so the calculation mode should be modified to better utilize tensor cores for accelerating computation.

We will continue to explore more flexible and scalable optimization strategies to address increasingly diverse and complex application scenarios. The optimized kernel will be open-sourced in RTP-LLM. We look forward to your feedback.

References

[01] TensorRT-LLM: https://github.com/NVIDIA/TensorRT-LLM
[02] vLLM: https://github.com/vllm-project/vllm
[03] xFormers: https://github.com/facebookresearch/xformers
[04] Flash-Decoding: https://crfm.stanford.edu/2023/10/12/flashdecoding.html
[05] RTP-LLM: https://github.com/alibaba/rtp-llm


Disclaimer: The views expressed herein are for reference only and don't necessarily represent the official views of Alibaba Cloud.
