KV Cache Explained with Examples from Real World LLMs
Published on Aug 22, 2025
When you run a large language model, repeating attention over previous tokens wastes time and resources, especially when you need quick responses in chat or streaming apps. This guide shows how storing past key and value tensors lets the model skip recomputing attention for earlier tokens, cutting latency and compute cost while preserving context across the full context window. Want to know how that trade-off affects memory, throughput, and prompt length? The goal here is simple: understand how KV caching works well enough to apply it confidently in real-world LLM systems for faster, more efficient inference.
To reach that goal, Inference's AI inference APIs give you a fast path to try caching with real models, measure speed and cost gains, and deploy optimized inference without building in-house infrastructure.
What is KV Caching and How Does It Work?

Self-attention turns a sequence of token embeddings into context vectors by asking, for each token, Which other tokens matter right now? The model makes three linear projections of the input embeddings:
- Queries
- Keys
- Values
Think of each token carrying a question vector q, an index vector k, and a content vector v. You compute a score for how much q should attend to every k by taking q · k, scale and softmax those scores into attention weights, then form the context as the weighted sum of the v vectors. In matrix form, this is fast.
Q @ K^T gives all pairwise scores, softmax yields attention weights, and those weights multiply V to provide the output. For autoregressive generation, you also apply a causal mask so a token cannot attend to future tokens. This single mechanism underpins causal transformers and is what KV caching targets.
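For readers who like code next to the math, here is a minimal PyTorch sketch of causal scaled dot-product attention (tensor names and shapes are illustrative, not taken from any particular library):

import torch
import torch.nn.functional as F

def causal_attention(Q, K, V):
    # Q, K, V: (batch, seq_len, d_head)
    d_head = Q.shape[-1]
    scores = Q @ K.transpose(-2, -1) / d_head ** 0.5    # all pairwise q · k scores
    seq_len = Q.shape[1]
    future = torch.triu(
        torch.ones(seq_len, seq_len, dtype=torch.bool, device=Q.device), diagonal=1
    )
    scores = scores.masked_fill(future, float("-inf"))  # causal mask: no peeking ahead
    weights = F.softmax(scores, dim=-1)                 # attention weights
    return weights @ V                                  # weighted sum of value vectors

# Example: one sequence of four tokens with head dimension 8
Q = K = V = torch.randn(1, 4, 8)
out = causal_attention(Q, K, V)   # (1, 4, 8)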
Why Transformers Repeat Work During Generation
Autoregressive models generate one token at a time. At step t the model computes Q, K, and V for the whole input prefix of length t, computes Q_t @ K_t^T, and uses V_t to form the output. At step t+1 it recomputes K and V for the old tokens even though they never change: the weight matrices are frozen during inference, so the key and value projections of earlier tokens come out identical every time. That repeated projection work and the repeated Q @ K^T multiplications are pure waste. Can you imagine recomputing the index cards for a library every time you add a new book? KV caching saves you from that.
What KV Caching Means in Plain Words
KV caching stores the per-token key and value projections from each attention layer so you do not recompute them every time you extend the sequence. When the model emits a new token you only compute the new token’s Q (and its K and V once), then reuse the cached K and V for every earlier token.
The cache acts like an index of previously computed keys and their corresponding values, allowing you to read them quickly. This reduces the amount of work per generation step from redoing the whole attention matrix to only computing interactions between the new token and the stored history.
Questions That Guide Optimization Choices
- Is your priority single-token latency or throughput for many prompts?
- What is the expected max generation length per request?
- How much GPU memory can you afford for cached K and V?
- Are you using rotary or absolute position encodings and can you ensure position ids stay aligned?
Answering these will shape whether you enable full caching, use sliding windows, or adopt beam-level cache cloning.
Short, Actionable Checklist to Implement KV Caching Safely
- Enable use_cache or the library equivalent.
- Inspect past_key_values shape and dtype so you can estimate memory use.
- Preallocate buffers when you anticipate long generations.
- For beams, clone cache per beam or use a batched beam-aware layout.
- Validate with known short prompts to ensure position ids and causal masks remain correct.
Would you like a small example in PyTorch showing how past_key_values is read and updated step by step?
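If so, here is a minimal sketch using Hugging Face transformers with GPT-2 (the model choice, prompt, and generation length are arbitrary examples): the prompt forward pass builds the cache, and each later step feeds only the new token plus past_key_values.

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("gpt2")
model = AutoModelForCausalLM.from_pretrained("gpt2").eval()

input_ids = tok("The KV cache stores", return_tensors="pt").input_ids

with torch.no_grad():
    out = model(input_ids, use_cache=True)   # prefill: cache is built for the whole prompt
past = out.past_key_values                   # inspect this to estimate memory use

for _ in range(10):
    next_id = out.logits[:, -1].argmax(dim=-1, keepdim=True)
    input_ids = torch.cat([input_ids, next_id], dim=1)
    with torch.no_grad():
        # Only the new token is fed; cached keys and values cover the earlier ones
        out = model(next_id, past_key_values=past, use_cache=True)
    past = out.past_key_values

print(tok.decode(input_ids[0]))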
Related Reading
- Model Context Protocol
- Speculative Decoding
- Lora Fine Tuning
- Gradient Checkpointing
- LLM Quantization
- LLM Use Cases
- Post Training Quantization
- vLLM Continuous Batching
KV Cache Explained for Optimizing Transformer Inference Efficiency from Scratch

Open the two scripts provided. One implements the model without a cache, the other adds the cache. Scan gpt_with_kv_cache.py for lines labeled # NEW to see each change in context.
You can also run a file diff tool between gpt_ch04.py and gpt_with_kv_cache.py to inspect the exact edits. Want a fast check? Look for register_buffer calls, a use_cache flag in forward methods, and a reset_cache helper.
Registering Cache Buffers Inside MultiHeadAttention
Add two non-persistent buffers to hold the concatenated keys and values across incremental decoding steps. Use register_buffer so these tensors live on the same device as the module but do not become model parameters.
Example:
- self.register_buffer("cache_k", None, persistent=False)
- self.register_buffer("cache_v", None, persistent=False)
Why non-persistent? You don’t want these temporary tensors saved with model weights. Why buffers and not parameters? They are runtime state, not learnable weights. The initial value is None. That signals an empty cache ready to accept the first keys and values during the first forward when use_cache=True.
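In context, the registration might sit in the attention module's constructor next to the projection layers used later in this walkthrough (a sketch; the constructor arguments are assumptions):

import torch.nn as nn

class MultiHeadAttention(nn.Module):
    def __init__(self, d_in, d_out, num_heads):
        super().__init__()
        self.num_heads = num_heads
        self.W_query = nn.Linear(d_in, d_out, bias=False)
        self.W_key = nn.Linear(d_in, d_out, bias=False)
        self.W_value = nn.Linear(d_in, d_out, bias=False)
        # Runtime cache state: moves with the module's device, never saved with weights
        self.register_buffer("cache_k", None, persistent=False)
        self.register_buffer("cache_v", None, persistent=False)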
Forward Pass with a use_cache Flag
Add a use_cache argument to the MultiHeadAttention forward. Compute keys and values for the current input tokens, then either use the newly computed tensors or append them to the cache.
Key code flow:
- Compute keys_new, values_new, and queries from the current token batch.
- If use_cache is True:
  - If the cache is empty, set cache_k = keys_new and cache_v = values_new.
  - Otherwise, append (torch.cat) the new keys and values to the existing cache.
  - Use the concatenated cache_k and cache_v for attention.
- If use_cache is False, use keys_new and values_new directly.
# Project to queries, keys, values
keys_new = self.W_key(x)      # (b, seq_len, d_out)
values_new = self.W_value(x)  # (b, seq_len, d_out)
queries = self.W_query(x)     # (b, seq_len, d_out)

if use_cache:
    if self.cache_k is None:
        # First tokens: initialize cache
        self.cache_k, self.cache_v = keys_new, values_new
    else:
        # Append to existing cache
        self.cache_k = torch.cat([self.cache_k, keys_new], dim=1)
        self.cache_v = torch.cat([self.cache_v, values_new], dim=1)
    # Use full cache for attention
    keys, values = self.cache_k, self.cache_v
else:
    # No caching: just use current batch
    keys, values = keys_new, values_new
Notes on Shapes and Attention Masks
Keep careful track of dimensions: concatenation uses the token axis (dim=1). Attention scores are still computed between the queries and all cached keys. If you use causal masking, make sure the mask grows with the cached sequence length; mistakes here produce wrong attention behavior.
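One way to keep the mask consistent is to build it from both the cached length and the number of new tokens; a minimal sketch (the helper name is made up):

import torch

def causal_mask_with_cache(num_new, num_past, device=None):
    # Keys cover num_past cached positions plus num_new fresh ones;
    # new query i may attend to keys 0 .. num_past + i, never beyond.
    total = num_past + num_new
    full = torch.tril(torch.ones(total, total, dtype=torch.bool, device=device))
    return full[num_past:, :]   # shape (num_new, total); True means "may attend"

# Three cached tokens, two new tokens
print(causal_mask_with_cache(2, 3))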
Clearing the Cache Between Generation Sessions with reset_cache
Provide a small reset helper on the attention block to clear its buffers so a fresh prompt won’t attend to prior tokens.
Example:
def reset_cache(self):
    self.cache_k, self.cache_v = None, None

Call this on every block before starting a new incremental generation session. Forgetting to clear the cache leads to stale context bleeding into new outputs.
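The model-level helper used later in generation (reset_kv_cache) can simply walk every block and restart position tracking; a minimal sketch, assuming the blocks are stored in self.trf_blocks:

def reset_kv_cache(self):
    # Clear the attention cache in every transformer block
    for block in self.trf_blocks:
        block.att.reset_cache()
    # Restart position tracking for the next prompt (see the next section)
    self.current_pos = 0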
Propagating use_cache and Position Tracking in the Full Model
Add a current_pos counter to the GPTModel to track how many tokens have already been stored. When use_cache=True, compute position ids starting at current_pos and increment the counter by the incoming seq_len.
Example:
if use_cache:
    # Start positions where the cached sequence left off
    pos_ids = torch.arange(
        self.current_pos,
        self.current_pos + seq_len,
        device=in_idx.device,
        dtype=torch.long,
    )
    self.current_pos += seq_len
else:
    # Fresh sequence, start from zero
    pos_ids = torch.arange(
        0,
        seq_len,
        device=in_idx.device,
        dtype=torch.long,
    )

# Positional embeddings
pos_embeds = self.pos_emb(pos_ids).unsqueeze(0)
# Add to token embeddings
x = token_embeds + pos_embeds
This keeps the positional embeddings for newly fed tokens aligned immediately after the previously cached tokens, so attention indices line up.
Alternative Offset Approach
Instead of current_pos, you can compute an offset per block, for example offset = block.att.cache_k.shape[1], to infer the starting position from the cached key length.
TransformerBlock change — thread use_cache down
Adjust TransformerBlock.forward to accept use_cache and forward that flag into its attention submodule:

def forward(self, x, use_cache=False):
    x = self.att(x, use_cache=use_cache)
    # continue with feed-forward and residual connections
Using the Cache in Generation
Use the cache to avoid recomputing keys/values for past tokens. The generation loop becomes:
- Reset model caches and call model once with the full prompt and use_cache=True to populate caches for the prompt.
- For each new token: a) Pick the next token from logits[:, -1].argmax(dim=-1, keepdim=True). b) Append it to the running sequence if you want to track the full output. c) Call model(next_token, use_cache=True) so the model computes keys/values only for the new token and attends to the combined cache.
# Reset cache before starting generation
model.reset_kv_cache()

# Encode the prompt
idx = prompt_tokens

# Run the model once on the prompt (initialize cache)
with torch.no_grad():
    logits = model(idx, use_cache=True)

# Generate tokens one by one
for _ in range(max_new_tokens):
    # Pick the next token (greedy decoding)
    next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)

    # Append to sequence (optional, if full sequence tracking is needed)
    idx = torch.cat([idx, next_idx], dim=1)

    # Forward pass only on the new token (cache makes this efficient)
    with torch.no_grad():
        logits = model(next_idx, use_cache=True)
Note how the first full forward populates cache for prompt tokens. Subsequent steps pass single token inputs.
Performance Trade-Offs and Optimization Tips
This simple concat-based cache yields big savings in compute but carries overhead from repeated torch.cat operations and Python-level control flow. On small CPU runs, you may already see 4 to 6 times lower latency for multi-token generation. For production:
- Pre-allocate cache tensors up to the maximum context length and write into slices instead of concatenating (see the sketch after this list).
- Use integer offsets to index preallocated key and value tensors.
- Prefer contiguous memory and avoid repeated reallocations.
- For GPUs, use kernels that support incremental attention or FlashAttention-style kernels that handle long context efficiently.
- When batching different prompts, either keep separate caches per prompt or implement packed batch offsets to avoid cross-contamination.
- For beam search, maintain per-beam caches or implement pointer copying to reuse stored keys and values across beams.
- Monitor memory usage: cached keys and values scale linearly with context length and number of heads.
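Here is what the first two tips can look like in practice; a minimal sketch of a preallocated per-layer cache written with integer offsets (class and attribute names are illustrative):

import torch

class PreallocatedKVCache:
    def __init__(self, batch, max_len, d_out, device=None, dtype=torch.float16):
        # Fixed-capacity buffers allocated once, up to the maximum context length
        self.k = torch.empty(batch, max_len, d_out, device=device, dtype=dtype)
        self.v = torch.empty(batch, max_len, d_out, device=device, dtype=dtype)
        self.pos = 0   # integer write offset

    def append(self, keys_new, values_new):
        n = keys_new.shape[1]
        # Write into slices instead of concatenating
        self.k[:, self.pos:self.pos + n] = keys_new
        self.v[:, self.pos:self.pos + n] = values_new
        self.pos += n
        # Return views over the filled region only
        return self.k[:, :self.pos], self.v[:, :self.pos]

    def reset(self):
        self.pos = 0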
Validation and Common Pitfalls
Check functional parity: run the same prompt with and without caching and verify that the generated token sequences are identical. If outputs diverge, inspect these common failure points:
- Off by one in pos_ids or current_pos
- Forgetting to reset caches between prompts.
- Incorrect concatenation axis or shapes on keys/values.
- Wrong attention masks relative to cache length.
- Using in-place ops that unintentionally mutate tensors shared across batches.
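A quick way to run the parity check described above is to greedy-decode the same prompt twice and compare; a sketch that assumes the model API from this walkthrough (use_cache flag plus reset_kv_cache):

import torch

@torch.no_grad()
def assert_cache_parity(model, prompt_tokens, max_new_tokens=8):
    def greedy(use_cache):
        model.reset_kv_cache()
        idx = prompt_tokens.clone()
        logits = model(idx, use_cache=use_cache)
        generated = []
        for _ in range(max_new_tokens):
            next_idx = logits[:, -1].argmax(dim=-1, keepdim=True)
            generated.append(next_idx)
            idx = torch.cat([idx, next_idx], dim=1)
            # Cached path feeds only the new token; uncached path resends the full sequence
            logits = model(next_idx if use_cache else idx, use_cache=use_cache)
        return torch.cat(generated, dim=1)

    assert torch.equal(greedy(False), greedy(True)), "cached and uncached outputs diverge"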
Testing Strategies
- Start with deterministic seeds and tiny inputs so differences surface quickly.
- Compare intermediate attention score matrices between cached and non-cached runs for the same final sequence.
- Log cache shapes after each step to verify the token axis matches expected counts.
Questions for You
Do you want the version that preallocates caches and writes with offsets, or are you happy with the readable concat style while exploring correctness first?
Related Reading
- LLM Performance Metrics
- LLM Serving
- Pytorch Inference
- Serving Ml Models
- LLM Benchmark Comparison
- Inference Optimization
- Inference Latency
- KV Cache Explained
- LLM Performance Benchmarks
KV Cache Management in Production Systems
This section explains the performance wins and the resource trade-offs that come with keeping past keys and values in memory for autoregressive decoding. A KV cache converts attention work from repeated recomputation into reuse: compute each key and value once and read them for every next token. That moves compute from quadratic across tokens to roughly linear during generation, which cuts GPU FLOPs and shortens end-to-end decode time when the same context is reused.
The cache grows with each new token and each transformer layer, so GPU memory usage rises linearly with sequence length. For large models and long sessions, the KV memory footprint can dominate the device, forcing truncation or offload strategies that cost latency or complexity.
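A quick back-of-the-envelope calculation makes the footprint concrete; a sketch with a hypothetical 32-layer model using 32 KV heads of dimension 128 in fp16:

def kv_cache_bytes(layers, kv_heads, head_dim, seq_len, batch=1, bytes_per_elem=2):
    # Factor of 2 accounts for storing both keys and values
    return 2 * layers * kv_heads * head_dim * seq_len * batch * bytes_per_elem

# Hypothetical config: 32 layers, 32 KV heads, head_dim 128, fp16, 8k-token context
print(kv_cache_bytes(32, 32, 128, 8192) / 2**30, "GiB per sequence")   # ≈ 4.0 GiB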
When Sequence Length Grows: How Computational Efficiency and Memory Scale
As sequence length increases, the computational advantage becomes more obvious. Without caching, the attention at step t recomputes keys and values for every prior token, so total work across a generation accumulates toward O(n^2). With a KV cache, each token's key and value are computed once, stored, and read later, keeping per-step work roughly linear in the current sequence length.
The downside grows in parallel:
- Each appended token adds tensors proportional to the number of layers, heads, and head dimension, so memory grows linearly.
- Truncating the cache or windowing the context reduces memory but sacrifices long-range coherence and forces recomputation for the discarded portion.
Engineering The KV Cache: Practical Implementation Patterns That Run in Production
Preallocate large contiguous buffers for each layer and write into slices instead of repeatedly concatenating tensors. Organize storage as [layer, head, seq position, head dim] or merge head and head dim into a packed shape to make GPU memory accesses contiguous. Keep dtypes compact, typically fp16 or bf16 for modern GPUs, and consider 8-bit quantization for very large contexts.
Use a single contiguous GPU buffer per model instance and track write offsets with integer counters. When you must expand capacity, grow by large chunks to avoid repeated reallocation. If you need CPU-level memory, use pinned host memory for fast GPU copies and consider background DMA transfers to hide copy cost.
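For the pinned-memory idea, the copy can run on a side CUDA stream so it overlaps with compute; a rough sketch (shapes and names are placeholders, and it assumes a CUDA device is available):

import torch

# Allocate pinned host memory once and reuse it for offloading a cache slice
host_buf = torch.empty(1, 8192, 4096, dtype=torch.float16, pin_memory=True)
copy_stream = torch.cuda.Stream()

def offload_keys(cache_k_gpu):
    # Asynchronous device-to-host copy on a side stream to hide transfer cost
    with torch.cuda.stream(copy_stream):
        host_buf[:, :cache_k_gpu.shape[1]].copy_(cache_k_gpu, non_blocking=True)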
Common Pitfalls at Scale and How to Fix Them
- Repeated torch.cat calls and many small allocations create fragmentation and spikes in allocation time.
- Replace concatenation with preallocated tensors and index writes.
- Non-contiguous tensors cause extra clones during attention ops; enforce contiguous storage or call .contiguous() at controlled points.
- Watch dtype mismatches that force automatic casting and extra buffers.
- Concurrency errors arise when multiple inference threads share a cache without locks; use atomic counters or per-request copy-on-write for shared buffers.
- Track torch.cuda.memory_allocated and the allocator peak to spot leaks, and use unit tests that simulate long sequences to reproduce growth patterns.
Cache Invalidation That Fits Real User Behavior and Memory Profiles
Session-Based Clearing
Clear the cache at session end or on explicit conversation reset. This is simple to implement:
- Tie the cache lifetime to a session ID
- Free associated GPU buffers when the session terminates
- Guard against abandoned sessions by an inactivity timer
For short independent conversations, this minimizes wasted memory while keeping the logic easy to reason about.
Time to Live Expiration
Assign a TTL to cached contexts and sweep entries that exceed that age. Implement a background sweeper or lazy eviction that checks timestamps on access.
Use coarse TTL buckets to reduce metadata overhead and avoid scanning all caches on each tick. TTL works well when cached relevance decays predictably and lets you balance memory pressure against the need to serve semi-recent contexts quickly.
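A lazy sweep over per-session metadata is enough for a first version; a sketch where cache_registry maps session IDs to a dict holding last_access and a free callback (all names are assumptions):

import time

def sweep_expired(cache_registry, ttl_seconds):
    now = time.time()
    for session_id, meta in list(cache_registry.items()):
        if now - meta["last_access"] > ttl_seconds:
            meta["free"]()                 # release the associated GPU buffers
            del cache_registry[session_id]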
Contextual Relevance-Based Eviction
Evict as soon as the cached context no longer matches the current task or topic. Detect context switches by comparing:
- User metadata
- Document IDs
- Explicit user actions
For example, when a user moves to a new code project, drop the old project's cache. This requires heuristics or signal detection and is the best fit when sessions contain multiple distinct topics.
How To Implement Safe Invalidation
Store metadata per cache:
- Session ID
- Last access timestamp
- Context fingerprint
- Size in bytes
Use a priority queue keyed by eviction score or timestamp for fast selection. Free large contiguous buffers back to the allocator rather than leaking small fragments, and signal running inference to fall back to recomputation if its cache is being evicted.
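A small priority-queue index over that metadata makes eviction selection cheap; a sketch where lower eviction scores are evicted first (the scoring policy itself is up to you):

import heapq
from dataclasses import dataclass, field

@dataclass(order=True)
class CacheEntry:
    eviction_score: float                        # only field used for ordering
    session_id: str = field(compare=False)
    last_access: float = field(compare=False)
    context_fingerprint: str = field(compare=False)
    size_bytes: int = field(compare=False)

class EvictionIndex:
    def __init__(self):
        self._heap = []

    def add(self, entry: CacheEntry):
        heapq.heappush(self._heap, entry)

    def evict_until(self, bytes_needed: int):
        # Pop lowest-score entries until enough bytes would be freed
        freed, victims = 0, []
        while self._heap and freed < bytes_needed:
            entry = heapq.heappop(self._heap)
            freed += entry.size_bytes
            victims.append(entry.session_id)
        return victims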
Cache Reuse and Warm Start Patterns That Save Compute
When multiple requests share a prefix context, they can reuse the same cached keys and values. Precompute KV for common prompts and keep them as warm starters. Use canonicalization to detect identical prefixes across users.
For shared caches, implement reference counting and copy-on-write to avoid accidental mutation. To share across processes, use CUDA IPC or pinned host memory with a small control plane that grants read-only access. For extensive shared contexts, consider storing a compressed quantized version and decompressing into GPU memory on demand.
Batching Strategies And The Latency Throughput Trade Off
Do you prioritize single request latency or overall throughput? The answer drives the batching strategy. For high throughput, group requests by decode step into large batches so the attention kernel sees many queries at once.
Batching increases GPU utilization and amortizes kernel launch and memory overhead, but it adds queuing latency per request. For low latency, use micro-batching with short batching windows and small timers: collect requests for 5 to 20 milliseconds, then run a batch.
Provide priority lanes:
Route latency-sensitive traffic to small batches and bulk traffic to large batches. Use dynamic batching that pads requests to the same decode position and merges their queries into a single attention call, reusing the stored keys and values without duplicating cache buffers.
Eviction Policies That Control Memory Under Pressure
LRU and LFU are standard. LRU works well when recent context predicts reuse; LFU helps when a small set of contexts is repeatedly accessed. Implement approximate LRU, such as CLOCK, for lower overhead at large scale.
Combine size-based eviction with recency:
Evict entries that free the most bytes first when under pressure. For multi-tenant systems enforce per-tenant quotas and hard caps so noisy tenants cannot evict other tenants’ caches.
When eviction occurs, consider offloading the evicted cache to CPU or NVMe in compressed form so it can be restored with some latency instead of a full recompute.
Sharding and Distributed Serving: How to Scale Across GPUs and Nodes
Sharding Options
Shard KV by layer, by head, or by sequence ranges. Layer sharding fits pipeline and tensor parallel models: each GPU holds the KV for a subset of layers.
This keeps per-device memory lower but requires network communication on each decode step. Sequence sharding splits long sequences across nodes and trades communication for local memory savings.
Consistency and Remote Fetches
When you fetch remote keys, add asynchronous prefetching and a local cache to hide network latency. Use GPUDirect RDMA where possible to move tensors directly between devices without staging in host memory. Design a fallback path for cache misses so the caller can continue decoding by recomputing missing keys or waiting with a timeout.
Replication and Recovery
Keep a lightweight metadata log of cache ownership and offsets so the cache can be reconstructed after a node failure. For hot shared contexts, replicate a small number of entries to multiple nodes so a single failure does not stall requests. For strict correctness, avoid duplicating writes; prefer read-only sharing of precomputed prefixes.
Monitoring and Alerting: Metrics and Traces to Catch Problems Early
Track these metrics per model and per device:
- Total KV bytes allocated and number of active cache entries
- Cache hit rate and miss rate
- Eviction count and reclaimed bytes
- Allocation rate and peak GPU memory used
- Per-step latency and tail latency percentiles
- Batching statistics: average batch size and queued wait time
Export metrics to Prometheus and visualize in Grafana. Correlate spikes in allocation rate with increased observed tail latency. Add sampling traces that record when code falls back to recompute due to a missing cache entry.
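A minimal exporter using prometheus_client might look like this (metric names, labels, and the port are assumptions to adapt to your stack):

from prometheus_client import Counter, Gauge, Histogram, start_http_server

KV_BYTES = Gauge("kv_cache_bytes_allocated", "Total KV cache bytes allocated", ["model", "device"])
CACHE_HITS = Counter("kv_cache_hits_total", "KV cache hits", ["model"])
CACHE_MISSES = Counter("kv_cache_misses_total", "KV cache misses", ["model"])
EVICTIONS = Counter("kv_cache_evictions_total", "Evicted cache entries", ["model"])
STEP_LATENCY = Histogram("decode_step_seconds", "Per-step decode latency", ["model"])

start_http_server(9090)   # exposes /metrics for Prometheus to scrape

# Example updates from the serving loop
KV_BYTES.labels(model="my-model", device="cuda:0").set(4.2e9)
CACHE_HITS.labels(model="my-model").inc()
STEP_LATENCY.labels(model="my-model").observe(0.018)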
Debugging Common Production Faults
Symptoms to hunt for:
- Growing GPU memory without freeing: look for lingering references to tensors in Python scopes or scheduler tasks
- Sudden latency spikes when cache grows: check for fragmentation caused by many small frees and allocs
- Non-contiguous tensors or implicit copies: instrument memory copies and contiguous checks
- Race conditions when multiple threads write to the same buffer: enforce per-session locks or use compare and swap offsets
Tools and Commands
Use torch.cuda.memory_summary and torch.cuda.memory_allocated to inspect allocator state. Use nvidia-smi for coarse device memory and nvprof or Nsight for kernel timelines.
Configure the PyTorch caching allocator via the PYTORCH_CUDA_ALLOC_CONF environment variable to analyze and tune allocation patterns. Add synthetic load tests that simulate many concurrent long sessions and measure eviction and swap behavior.
Operational Playbook: How to Deploy, Maintain, and Control Costs
Set per-model and per-tenant hard caps that trigger graceful fallbacks. Implement an admission controller that rejects or queues requests when memory pressure is high. Prewarm caches for high-frequency prefixes during low-load periods.
Build a small background sweeper process that performs TTL eviction, compaction, and optional offload to NVMe. Alert on sustained cache hit rate drops and elevated eviction rates so SREs can add capacity or tune TTLs.
Questions to Decide Your Design Trade-Offs
Ask yourself:
- Do you need sub-50-millisecond tail latency or maximum throughput?
- How long are your typical sessions, and how many concurrent sessions must a single GPU host?
- How often do requests share prefixes, and which parts of the model dominate cache size?
Answer these and pick between heavier reuse and sharding or simpler session-based invalidation and full recompute strategies.
Start Building with $10 in Free API Credits Today!
Inference delivers OpenAI-compatible serverless inference APIs for top open-source LLMs. You call the same style of endpoints you expect from a managed service while the backend handles model hosting, autoscaling, and resource allocation.
The API supports streaming outputs, token-level control, and request-level timeouts so developers can tune throughput and latency. Want the cheapest path to production for a large model without rewriting your code to a new SDK or protocol?
Start Fast With $10 Free Credits And Migration Paths
You can sign up and receive $10 in free API credits to test serverless inference, batch jobs, and document extraction workflows. The APIs mimic OpenAI patterns, which simplifies SDK changes and lets you validate performance and cost with real traffic.
Try a small RAG pipeline, measure token cost with and without KV cache, and iterate on packing and quantization choices while you use the credits. Which experiment will you run first to see immediate savings?
Related Reading
- Continuous Batching LLM
- Inference Solutions
- vLLM Multi-GPU
- Distributed Inference
- KV Caching
- Inference Acceleration
- Memory-Efficient Attention