KV-Cache in LLM Inference
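The KV-cache is one of the most impactful optimizations in LLM inference: it removes redundant recomputation at every decoding step, and its memory footprint largely determines how many requests a GPU can serve at once. Understanding it is essential for anyone building or scaling AI systems.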
The problem it solves
Transformers use attention: each new token attends to every previous token. Without caching, generating token 1000 of a response requires re-running the model over tokens 1 through 999 to recompute their keys and values, even though they were already computed at earlier steps.
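To make the redundancy concrete, here is a minimal cost model in Python. It is a sketch under simplifying assumptions (one layer, one head, one new token per step, and a naive implementation that re-runs the full causal forward pass over the prefix at every step); the function name `decode_cost` is hypothetical. It counts the K/V projections and query-key dot products each strategy performs.

```python
def decode_cost(num_tokens: int, use_cache: bool) -> tuple[int, int]:
    """Count K/V projections and query-key dot products for autoregressive decoding.

    Hypothetical cost model: one layer, one head, one new token per step.
    Without a cache, step t re-projects all t tokens and recomputes causal
    attention for every position in the prefix.
    """
    projections = 0
    dot_products = 0
    for t in range(1, num_tokens + 1):
        if use_cache:
            projections += 1                  # only the new token's K and V
            dot_products += t                 # one query against t cached keys
        else:
            projections += t                  # every token's K and V, again
            dot_products += t * (t + 1) // 2  # causal attention over the whole prefix
    return projections, dot_products


print(decode_cost(1000, use_cache=False))  # (500500, 167167000)
print(decode_cost(1000, use_cache=True))   # (1000, 500500)
```

With the cache, projections grow linearly with length. The dot products still grow quadratically, because each new query must be compared against every cached key; the cache removes recomputation, not the attention itself.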
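Without KV-cache (K,V recomputed for the whole prefix at every step):
Token 1: compute K,V for 1 token
Token 2: compute K,V for 2 tokens
Token N: compute K,V for N tokens
Total: O(N²) K/V computations → quadratic!

With KV-cache:
Token 1: compute K,V for 1 token → store in cache
Token 2: compute K,V for 1 new token → append to cache
Token N: compute K,V for 1 new token → append to cache
Total: O(N) K/V computations → linear!

The new token's query still attends over all cached keys, so attention itself still costs O(N) per step; what the cache eliminates is the redundant recomputation of K and V for the prefix.
What gets cached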
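For each transformer layer, every token produces three vectors per attention head: Query (Q), Key (K), and Value (V). During generation, only the newest token's Q is needed, so Q changes every step. K and V for previous tokens, however, are reused at every step, and because causal masking means a token never attends to later tokens, they never change. That is why the cache stores K and V, but not Q.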
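The claim that earlier K and V vectors never change relies on causal masking: a token's hidden state, and therefore its K and V at every layer, depends only on itself and earlier tokens. The sketch below checks this on a hypothetical two-layer toy model with random weights, using PyTorch as the rest of this page does (the model, `keys_values`, and all sizes are illustrative assumptions, not a real LLM). Appending a token leaves the prefix's K and V unchanged at every layer.

```python
import torch

torch.manual_seed(0)
D = 8  # toy model width (assumption)

# Hypothetical toy model: two single-head causal self-attention layers
# with random projection weights and residual connections.
LAYERS = [
    {name: torch.randn(D, D) / D**0.5 for name in ("wq", "wk", "wv")}
    for _ in range(2)
]


def causal_self_attention(x, w):
    """x: [seq_len, D]. Returns (output, K, V) for one layer."""
    q, k, v = x @ w["wq"], x @ w["wk"], x @ w["wv"]
    scores = (q @ k.T) / D**0.5
    # Mask out future positions so token i only sees tokens 0..i
    future = torch.triu(torch.ones(len(x), len(x)), diagonal=1).bool()
    scores = scores.masked_fill(future, float("-inf"))
    return torch.softmax(scores, dim=-1) @ v, k, v


def keys_values(x):
    """Run the toy model on x and collect each layer's (K, V)."""
    cache = []
    for w in LAYERS:
        out, k, v = causal_self_attention(x, w)
        cache.append((k, v))
        x = x + out  # residual connection feeds the next layer
    return cache


prefix = torch.randn(5, D)
extended = torch.cat([prefix, torch.randn(1, D)])  # append one new token

for (k_old, v_old), (k_new, v_new) in zip(keys_values(prefix), keys_values(extended)):
    # The first 5 rows (the prefix) are unchanged after appending a token
    assert torch.allclose(k_old, k_new[:5], atol=1e-6)
    assert torch.allclose(v_old, v_new[:5], atol=1e-6)
print("prefix K/V unchanged at every layer")
```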
import torch

def attention_with_kv_cache(query, key_cache, value_cache, new_key, new_value):
    """
    Single-head attention for one decoding step.

    query:       [1, d_head]        current token only
    key_cache:   [seq_len, d_head]  all previous K vectors
    value_cache: [seq_len, d_head]  all previous V vectors
    new_key:     [d_head]           K vector of the current token
    new_value:   [d_head]           V vector of the current token
    """
    # Append new K, V to cache (torch.cat copies the whole cache each step;
    # fine for illustration, not for an efficient implementation)
    key_cache = torch.cat([key_cache, new_key.unsqueeze(0)], dim=0)
    value_cache = torch.cat([value_cache, new_value.unsqueeze(0)], dim=0)

    # Attention over the full cached sequence (seq_len + 1 positions)
    scores = torch.matmul(query, key_cache.T)  # [1, seq_len + 1]
    scores = scores / (query.shape[-1] ** 0.5)
    weights = torch.softmax(scores, dim=-1)
    output = torch.matmul(weights, value_cache)  # [1, d_head]

    return output, key_cache, value_cache
Memory cost of KV-cache
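This is where storage becomes critical. Each layer stores one K and one V vector per head for every token, so the cache for a single request takes 2 × num_layers × num_heads × head_dim × bytes_per_element × seq_len bytes: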
def kv_cache_bytes(num_layers, num_heads, head_dim, seq_len, dtype_bytes=2):
    """
    num_layers:  e.g. 32 (Llama-2-7B has 32 layers)
    num_heads:   e.g. 32 (number of K/V heads; smaller than the query head
                 count in grouped-query attention models)
    head_dim:    e.g. 128
    seq_len:     max context length, e.g. 4096
    dtype_bytes: 2 for float16/bfloat16
    """
    per_token = num_layers * num_heads * head_dim * dtype_bytes * 2  # K + V
    total = per_token * seq_len
    return total
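# Llama-2-7B, 4096-token context, bfloat16
size = kv_cache_bytes(32, 32, 128, 4096, 2)
print(f"KV-cache per request: {size / 1e9:.2f} GB")
# Output: KV-cache per request: 2.15 GB

That is 0.5 MiB per token (32 × 32 × 128 × 2 bytes × 2 for K and V), or about 2.15 GB for a full 4096-token context. If 40 GB of GPU memory is available for the cache, only about 18 full-context requests fit at once (40 / 2.15 ≈ 18.6). That cap on concurrent requests is your throughput ceiling.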
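The concurrency ceiling can be estimated with a few lines of arithmetic. This is a minimal sketch assuming a fixed memory budget already set aside for the KV-cache (the 40 GB figure is illustrative; model weights, activations, and allocator overhead are ignored), and `max_full_context_requests` is a hypothetical helper. It also shows the effect of the 1-byte INT8 cache listed in the takeaways.

```python
def max_full_context_requests(budget_bytes, num_layers, num_kv_heads,
                              head_dim, seq_len, dtype_bytes):
    """How many requests at full context length fit in a KV-cache budget.

    num_kv_heads equals the attention head count for multi-head attention
    models such as Llama-2-7B; for grouped-query attention models, pass the
    (smaller) number of K/V heads.
    """
    per_request = 2 * num_layers * num_kv_heads * head_dim * dtype_bytes * seq_len  # K + V
    return budget_bytes // per_request


BUDGET = 40 * 10**9  # illustrative: 40 GB reserved for the KV-cache

# Llama-2-7B shape, 4096-token context
print(max_full_context_requests(BUDGET, 32, 32, 128, 4096, dtype_bytes=2))  # bf16: 18
print(max_full_context_requests(BUDGET, 32, 32, 128, 4096, dtype_bytes=1))  # int8: 37
```

Real requests rarely use the full context, so how many actually run concurrently depends on how tightly the allocator packs them, which is the problem PagedAttention targets.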
PagedAttention (vLLM)
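vLLM tackles KV-cache memory fragmentation by managing GPU memory much like OS virtual memory: the cache is split into fixed-size pages (blocks), each holding K and V for a fixed number of tokens, and pages are allocated on demand. A per-request block table maps logical pages to physical ones, so a request's cache does not need to be contiguous.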
Traditional KV-cache:
  Request A (2048 tok): [pre-allocated] [wasted if shorter]
  Request B (512 tok):  [padded to max]

PagedAttention (vLLM):
  Request A → Page 1, Page 3, Page 5
  Request B → Page 2, Page 4
  Pages allocated on demand → waste limited to each request's last, partially filled page

Result: the vLLM paper reports 2-4x higher throughput than prior serving systems such as FasterTransformer and Orca at the same level of latency.
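The paging idea can be illustrated with a toy block allocator. This is a simplified sketch, not vLLM's implementation: the class `PagedKVAllocator`, its methods, and the block sizes are hypothetical, and it tracks only bookkeeping, not tensors. Each request gets a block table mapping its logical blocks to physical block IDs, and a new block is taken from a shared free pool only when the request's last block fills up.

```python
from collections import deque


class PagedKVAllocator:
    """Toy paged KV-cache bookkeeping: tracks blocks only, stores no tensors."""

    def __init__(self, num_blocks: int, block_size: int = 16):
        self.block_size = block_size                  # tokens per block (illustrative)
        self.free_blocks = deque(range(num_blocks))   # shared pool of physical blocks
        self.block_tables: dict[str, list[int]] = {}  # request -> physical block IDs
        self.num_tokens: dict[str, int] = {}          # request -> tokens stored

    def append_token(self, request_id: str) -> tuple[int, int]:
        """Reserve a slot for one new token; return (physical block, offset)."""
        table = self.block_tables.setdefault(request_id, [])
        count = self.num_tokens.get(request_id, 0)
        if count % self.block_size == 0:  # no block yet, or the last one is full
            if not self.free_blocks:
                raise MemoryError("out of KV-cache blocks")
            table.append(self.free_blocks.popleft())
        self.num_tokens[request_id] = count + 1
        return table[count // self.block_size], count % self.block_size

    def free(self, request_id: str) -> None:
        """Return a finished request's blocks to the shared pool."""
        self.free_blocks.extend(self.block_tables.pop(request_id, []))
        self.num_tokens.pop(request_id, None)

    def wasted_slots(self) -> int:
        """Unused token slots; all of them sit in each request's last block."""
        return sum(len(table) * self.block_size - self.num_tokens[rid]
                   for rid, table in self.block_tables.items())


# Two requests decoding in an interleaved fashion share one pool
alloc = PagedKVAllocator(num_blocks=8, block_size=4)
for step in range(5):
    alloc.append_token("A")
    if step < 3:
        alloc.append_token("B")

print(alloc.block_tables)      # {'A': [0, 2], 'B': [1]}
print(alloc.wasted_slots())    # 4
alloc.free("B")
print(len(alloc.free_blocks))  # 6
```

Request A ends up in non-contiguous blocks 0 and 2, and freeing B returns its block to the pool immediately, so no memory is reserved up front for a maximum length.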
Key takeaways
| Concept | Detail |
|---|---|
| What's cached | K and V vectors per layer, per token |
| Memory grows | Linearly with sequence length |
| GPU memory limit | Defines max concurrent requests |
| vLLM innovation | Paged KV-cache: on-demand pages nearly eliminate fragmentation |
| Quantization win | INT8 KV-cache halves memory vs. FP16/BF16, allowing roughly 2x more concurrent requests |
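The INT8 row can be illustrated with a simple quantization scheme. This sketch assumes symmetric per-token absmax quantization with one scale per cached vector; it is one common choice, not necessarily what any particular serving engine uses, and `quantize_kv_int8` / `dequantize_kv_int8` are hypothetical names. Each element drops from 2 bytes to 1, at the cost of a small per-token scale and a bounded rounding error.

```python
import torch


def quantize_kv_int8(x: torch.Tensor):
    """Symmetric per-token INT8 quantization of a K or V tensor [seq_len, d_head]."""
    # One scale per token (row), chosen so the largest |value| maps to 127
    scale = x.abs().amax(dim=-1, keepdim=True).clamp(min=1e-8) / 127
    q = torch.round(x / scale).clamp(-127, 127).to(torch.int8)
    return q, scale


def dequantize_kv_int8(q: torch.Tensor, scale: torch.Tensor) -> torch.Tensor:
    return q.to(scale.dtype) * scale


torch.manual_seed(0)
keys = torch.randn(4096, 128)  # one head's cached keys (float32 in this sketch)
q, scale = quantize_kv_int8(keys)
restored = dequantize_kv_int8(q, scale)

bf16_bytes = keys.numel() * keys.to(torch.bfloat16).element_size()
int8_bytes = q.numel() * q.element_size() + scale.numel() * scale.element_size()
print(bf16_bytes, int8_bytes)  # 1048576 540672
print((restored - keys).abs().max().item())  # max error, at most half a scale step
```

The scales add a small overhead (here one float32 per token), so the saving is slightly less than exactly half.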
See also: Flash Attention, GPU Memory Management, vLLM Architecture