LLM Inference Speed-up EP1 - KV Cache

6 minute read

Published:

Large Language Models (LLMs) have revolutionized the field of natural language processing, enabling significant advancements in tasks such as language translation, text summarization, and sentiment analysis. However, despite their impressive capabilities, LLMs are not without limitations. One of the most significant challenges facing LLMs today is inference speed. Due to the sheer size and complexity of these models, making predictions or extracting information from them can be computationally expensive and time-consuming. There are several ways to speed up LLMs without upgrading hardware:

a. Parallelism: Transformers are designed to work in parallel (Self-attention). One can take advantage of parallel processing on multi-core CPUs or multi-GPU systems.

b. Quantization: Quantizing the weights and activations of neural networks can reduce the computational requirements for inference while also reducing memory usage.

c. Pruning: Pruning techniques can be used to remove redundant connections in neural networks, thereby reducing their size and computational requirements.

d. Knowledge Distillation: This technique involves transferring knowledge from a larger pre-trained model (the teacher) to a smaller model (the student). By leveraging the knowledge learned by the teacher, the student model can achieve better performance with fewer computational resources.

There are simpler ways, from an engineering perspective, to optimize LLM inference without changing the model architecture: 1. reduce duplicate calculations (fewer FLOPs); 2. reduce data transfer time (Flash Attention).

KV Cache

During inference, an LLM produces its output token by token, also known as autoregressive decoding. Each generated token depends on all previous tokens, including those in the prompt and any previously generated output tokens.

Therefore, the Keys and Values of previous tokens are involved in the computation for the next token. This ever-growing queue of tokens can become a computation bottleneck in the self-attention stage if we don't reuse the K and V calculated before. To better illustrate this issue, here's an example showing how self-attention works.

The attention score of the first token is obtained by:

\[Att_1(Q,K,V)=softmax(\frac{Q_1K^T_1}{\sqrt{d}})\vec{V_1}\]

where $K=W_k x$, $Q=W_q x$, $V=W_v x$, $x$ is the embedding vector, and $d$ is the dimension of the embedding vectors; we omit the $\sqrt{d}$ scaling in the rest of this discussion for brevity. When generating the second token, the attention matrix is:

\[\begin{aligned} Att(Q,K,V) &=softmax( \begin{bmatrix} Q_1K^T_1 & - \infty \\ Q_2K^T_1 & Q_2K^T_2 \end{bmatrix} ) \begin{bmatrix} \vec{V_1} \\ \vec{V_2} \end{bmatrix} \\ &= \begin{bmatrix} softmax(Q_1K^T_1)\vec{V_1} \\ softmax(Q_2K^T_1)\vec{V_1} + softmax(Q_2K^T_2)\vec{V_2} \end{bmatrix} \end{aligned}\]

so the attention score for the second token is:

\[Att_2(Q,K,V)=softmax(Q_2K^T_1)\vec{V_1} + softmax(Q_2K^T_2)\vec{V_2}\]

Likewise, the attention score for the third token is:

\[Att_3(Q,K,V)=softmax(Q_3K^T_1)\vec{V_1} + softmax(Q_3K^T_2)\vec{V_2} + softmax(Q_3K^T_3)\vec{V_3}\]

We can see that the Keys and Values computed in previous iterations are reused to generate the next token, so storing K and V avoids recomputing them.

kvcache

Nice GIF from João Lages.
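To make this concrete, here is a minimal single-head decoding sketch in NumPy (the weight matrices and dimensions are toy values made up for illustration, not any real model): each step computes Q, K, V only for the newest token and reuses the cached K, V of all earlier tokens.

```python
# A minimal single-head decoding sketch (NumPy) showing how a KV cache works.
# Wq, Wk, Wv and the embedding size are made-up toy values, not a real model.
import numpy as np

d = 8                                  # toy embedding dimension
Wq, Wk, Wv = (np.random.randn(d, d) for _ in range(3))

def step(x_t, k_cache, v_cache):
    """Attention output for the newest token x_t, reusing cached K/V."""
    q_t = x_t @ Wq                     # query for the new token only
    k_t = x_t @ Wk                     # compute K, V once for the new token...
    v_t = x_t @ Wv
    k_cache.append(k_t)                # ...and append them to the cache
    v_cache.append(v_t)
    K = np.stack(k_cache)              # (t, d): all keys so far, no recomputation
    V = np.stack(v_cache)
    scores = K @ q_t / np.sqrt(d)      # (t,)
    w = np.exp(scores - scores.max())
    w /= w.sum()                       # softmax over past tokens (causal by construction)
    return w @ V                       # attention output for the new token

k_cache, v_cache = [], []
for t in range(5):                     # pretend these are embeddings of generated tokens
    x_t = np.random.randn(d)
    out = step(x_t, k_cache, v_cache)
print(len(k_cache), out.shape)         # 5 cached keys, output of shape (d,)
```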

Is memory being eaten up? The KV cache can be very large, sometimes up to several GB. Let's estimate its size for a single batch when data is stored in fp16 (2 bytes):

\[\text{kv\_size} = 2 \times 2 \times d \times n_{\text{layers}} \times \text{max\_context\_length}\]

where the first factor of 2 counts K and V, the second is the bytes per fp16 value, and $d$ is the model's hidden dimension. Note that with Grouped-Query Attention (GQA), multiple query heads share the same Key and Value heads, which reduces the KV cache size.
GQA
There is ongoing research showing that a quantized KV cache (with some tricks) works nearly as well, so things aren't that bad. Overall, it's a fair trade-off: more memory consumption in exchange for less computation and faster inference.
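As a quick back-of-the-envelope check of the formula above, here is a tiny script; the configuration numbers are an assumption roughly matching a LLaMA-2-7B-style model (32 layers, hidden size 4096, 4096-token context), not an exact spec.

```python
# Plugging numbers into the KV cache size formula above.
# The model configuration is an illustrative assumption, not an exact spec.
bytes_per_value = 2          # fp16
n_layers = 32
d = 4096                     # hidden size = n_heads * head_dim
max_context_length = 4096
batch_size = 1

kv_size = 2 * bytes_per_value * d * n_layers * max_context_length * batch_size
print(f"{kv_size / 1024**3:.1f} GiB")   # ~2.0 GiB for a single sequence
```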


PagedAttention

In the beginning, the KV cache was simply saved in VRAM, which requires a large contiguous memory allocation. PagedAttention observes that a considerable amount of memory is wasted in the KV cache because of excessive reservation. Instead of directly allocating the maximum amount of memory to cover the model's context length, no matter how long the real prompt and generation turn out to be, PagedAttention dynamically allocates memory and saves the KV cache in non-contiguous blocks:


pagedattention

PagedAttention also allows more sophisticated management. For instance, when multiple inference requests share the same large initial system prompt, it's wise to save the key and value vectors for that prompt once and share them among requests (see the sketch below). Now you can tell why PagedAttention is so popular: it increases model throughput without any architecture modification and functions as a cache management layer that can be seamlessly integrated with various LLMs, which is especially useful when LLMs handle large batches of prompts.
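Here is a toy sketch of the paging idea, purely illustrative and not vLLM's actual implementation: KV vectors live in fixed-size blocks drawn from a shared pool, each request keeps a block table mapping its logical token positions to physical blocks, and requests that share a system prompt can point at the same prefix blocks. All class and variable names are hypothetical.

```python
# A toy sketch of paged KV cache management (not vLLM's actual implementation).
BLOCK_SIZE = 16                               # tokens per block

class BlockPool:
    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))   # physical block ids
    def alloc(self):
        return self.free.pop()

class Request:
    def __init__(self, pool, shared_prefix_blocks=None):
        self.pool = pool
        # Requests sharing the same system prompt reuse its blocks read-only;
        # here we assume the prefix exactly fills its blocks.
        self.block_table = list(shared_prefix_blocks or [])
        self.num_tokens = len(self.block_table) * BLOCK_SIZE

    def append_token(self):
        # Allocate a new physical block only when the last one is full.
        if self.num_tokens % BLOCK_SIZE == 0:
            self.block_table.append(self.pool.alloc())
        self.num_tokens += 1

pool = BlockPool(num_blocks=1024)
prefix = [pool.alloc() for _ in range(4)]     # blocks holding a shared system prompt
a, b = Request(pool, prefix), Request(pool, prefix)
for _ in range(40):
    a.append_token()
print(a.block_table[:4] == b.block_table[:4]) # True: prefix blocks are shared
```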

Further thoughts on memory management: pre-allocating memory is not bad as long as the allocated blocks get filled. vAttention pre-reserves contiguous virtual memory to save the time spent looking up block tables in PagedAttention, and allocates physical memory at runtime, one page at a time and only when a request has used up all of its previously allocated physical memory pages.

Flash Attention

The authors of FlashAttention found that the attention operation is memory-bound, due to tensor transfers within the GPU. This means that instead of targeting matrix multiplication and reducing FLOPs, an IO-aware attention is the key to acceleration and to overcoming the speed bottleneck. To better understand the intention of FlashAttention, here is the structure of GPU memory, the standard attention implementation, and its time cost.

standard attention and gpu memory
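For contrast with the tiled version discussed below, here is a plain NumPy rendering of the standard attention computation (shapes are arbitrary toy values): it materializes the full $N \times N$ score and probability matrices, which on a GPU means writing them out to HBM and reading them back.

```python
# Standard (non-tiled) attention in NumPy: the full N x N matrices S and P
# are materialized, which on a GPU corresponds to extra HBM reads/writes.
import numpy as np

def standard_attention(Q, K, V):
    d = Q.shape[-1]
    S = Q @ K.T / np.sqrt(d)                       # (N, N) scores
    P = np.exp(S - S.max(axis=-1, keepdims=True))  # (N, N) probabilities
    P /= P.sum(axis=-1, keepdims=True)
    return P @ V                                   # (N, d) output

N, d = 1024, 64
Q, K, V = (np.random.randn(N, d) for _ in range(3))
print(standard_attention(Q, K, V).shape)           # (1024, 64)
```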

To perform a standard attention operation on a GPU, the tensors generated in intermediate steps are transferred between SRAM and HBM multiple times, resulting in far more time spent on memory traffic than on the matrix multiplications themselves. In comparison, FlashAttention works in the following way:

fa

The core idea is splitting the matrices inside the softmax (also known as tiling). Recall the softmax computation of a vector $x \in R^B$:

\[\begin{aligned} &m(x):=\max_{i} (x_i)\\ &f(x):=[e^{x_1-m(x)},...,e^{x_B-m(x)}]\\ &l(x):=\sum_if(x)_i\\ &softmax:=\frac{f(x)}{l(x)} \end{aligned}\]

For vectors $x^{(1)},x^{(2)} \in R^B$, the softmax of the concatenated vector $x=[x^{(1)},x^{(2)}] \in R^{2B}$ can be computed as:

\[\begin{aligned} &m(x)=m([x^{(1)},x^{(2)}])=\max(m(x^{(1)}),m(x^{(2)}))\\ &f(x)=[e^{m(x^{(1)})-m(x)}f(x^{(1)}),e^{m(x^{(2)})-m(x)}f(x^{(2)})]\\ &l(x)=e^{m(x^{(1)})-m(x)}l(x^{(1)})+e^{m(x^{(2)})-m(x)}l(x^{(2)})\\ &softmax=\frac{f(x)}{l(x)} \end{aligned}\]
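A quick numerical check in NumPy that the blockwise decomposition above reproduces the softmax of the concatenated vector (block size and values are arbitrary):

```python
# Verify that the blockwise softmax matches the softmax of the concatenated vector.
import numpy as np

def stats(x):                     # per-block statistics used in the text: m, f, l
    m = x.max()
    f = np.exp(x - m)
    return m, f, f.sum()

B = 4
x1, x2 = np.random.randn(B), np.random.randn(B)

m1, f1, l1 = stats(x1)
m2, f2, l2 = stats(x2)
m = max(m1, m2)                                   # m(x) = max(m(x1), m(x2))
f = np.concatenate([np.exp(m1 - m) * f1,          # rescale each block's f to the new max
                    np.exp(m2 - m) * f2])
l = np.exp(m1 - m) * l1 + np.exp(m2 - m) * l2     # combine the normalizers
blockwise = f / l

x = np.concatenate([x1, x2])
direct = np.exp(x - x.max()) / np.exp(x - x.max()).sum()
print(np.allclose(blockwise, direct))             # True
```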

In this way, we can split $Q,K,V$ into blocks and incrementally perform the softmax reduction, computing the softmax one block at a time. FlashAttention does not store the intermediate values ($S,P \in R^{N \times N}$ in Algorithm 1) needed to compute gradients w.r.t. $Q,K,V$ during the backward pass; instead, $S,P$ are recomputed from the stored output $O$ and the softmax normalization statistics $(m,l)$. Recomputation not only reduces the required memory but also speeds up the backward pass thanks to the reduction in HBM accesses. That's why in Algorithm 1 the block sizes are set to $\frac{M}{4d}$: blocks of $Q,K,V,O$ are loaded into SRAM while $m,l$ are kept in registers during computation. Another graph to illustrate:

flash

The final roadblock in understanding Algorithm 1 may lie in step 12. $O_i$ is the intermediate attention result accumulated during iteration. Since $S,P$ are not stored in memory, $\text{diag} (l_i) e^{m_i-m_i^{new}} O_i$ recovers $PV$ from the last iteration. The diagonal matrix here is a way to multiply/divide a matrix row by row, and the exponential factor simply rescales to the new running max; these small mathematical tricks may be clearer after you verify them yourself (a toy check follows below). More details and proofs are provided in the paper.
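As a sanity check on step 12, here is a simplified NumPy sketch for a single query row: we keep running statistics $(m, l)$ and a partial output $o$, and rescale $o$ by $e^{m-m^{new}}$ whenever a new K/V block raises the running max. The block size and shapes are arbitrary toy choices, and this covers only the forward tiling, not the paper's full algorithm.

```python
# Tiled attention for one query row with running (m, l) statistics,
# mirroring the diag(l) * exp(m - m_new) * O rescaling in step 12.
import numpy as np

N, d, Bc = 256, 64, 32
q = np.random.randn(d)
K, V = np.random.randn(N, d), np.random.randn(N, d)

m, l, o = -np.inf, 0.0, np.zeros(d)
for j in range(0, N, Bc):
    Kj, Vj = K[j:j+Bc], V[j:j+Bc]        # load one K/V block (into SRAM on a GPU)
    s = Kj @ q / np.sqrt(d)              # scores for this block
    m_new = max(m, s.max())
    p = np.exp(s - m_new)                # block's unnormalized probabilities
    l = np.exp(m - m_new) * l + p.sum()  # update the running normalizer
    o = np.exp(m - m_new) * o + p @ Vj   # rescale previous partial output, then accumulate
    m = m_new
o /= l                                   # final normalization

# Reference: full softmax attention for the same query row.
s = K @ q / np.sqrt(d)
ref = (np.exp(s - s.max()) / np.exp(s - s.max()).sum()) @ V
print(np.allclose(o, ref))               # True
```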