Production Optimizations
Production inference systems use optimizations that deliver 2-10x speedups over naive implementations. This page describes what those optimizations do, why they matter, and how they relate to lmxlab's implementations.
Educational, not implemented
lmxlab does not implement most of these optimizations. This page describes how they work, with references to production systems like vLLM, llama.cpp, and mlx-lm.
Flash Attention
The problem
Standard attention computes O = softmax(QK^T / sqrt(d)) * V. This
materializes the full N x N attention matrix in GPU memory. For a
4096-token sequence in float16, that matrix is 32 MB per head, per
layer. The bottleneck is not compute (modern GPUs have enormous
FLOP/s) but memory bandwidth: reading and writing these large
intermediate matrices dominates wall-clock time.
How Flash Attention solves it
Flash Attention (Dao et al., 2022) is an IO-aware exact attention algorithm. "Exact" is the key word: unlike approximate attention methods, it computes the same result as standard attention (up to floating-point rounding), just with far fewer memory round-trips.
Tiling
Q, K, V are split into blocks that fit in GPU SRAM (on-chip fast memory, ~20 MB on an A100 vs 40-80 GB HBM). Each block of Q attends to all blocks of K, V without ever writing the full N x N matrix to HBM.
Online softmax
The core algorithmic insight. Standard softmax
requires knowing max(x_1, ..., x_N) across the entire row before
computing any output. Online softmax maintains running statistics
(a running maximum and running denominator) updated incrementally as
each new tile of K is processed. When a new tile produces a new
maximum, previous partial results are rescaled. This makes softmax
associative over tiles.
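The sketch below is a minimal pure-Python version of tiling plus online softmax for a single query row. It is checked against the naive version, which materializes every score. The function names and the block_size parameter are illustrative assumptions, not lmxlab or FlashAttention code. A real kernel processes a whole tile of queries at once inside SRAM, but it uses the same rescaling trick.

```python
import math
import random


def naive_attention_row(q, K, V):
    """Reference: materialize all N scores for one query, then softmax."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * v[j] for w, v in zip(weights, V)) / total
            for j in range(len(V[0]))]


def tiled_attention_row(q, K, V, block_size=4):
    """Walk K/V tile by tile, keeping only a running max, a running
    denominator, and an unnormalized output accumulator."""
    scale = 1.0 / math.sqrt(len(q))
    running_max = -math.inf
    running_sum = 0.0
    acc = [0.0] * len(V[0])
    for start in range(0, len(K), block_size):
        K_tile = K[start:start + block_size]
        V_tile = V[start:start + block_size]
        scores = [scale * sum(a * b for a, b in zip(q, k)) for k in K_tile]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new maximum
        # (exp(-inf) == 0.0 on the first tile, so nothing is carried over).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, v in zip(scores, V_tile):
            w = math.exp(s - new_max)
            running_sum += w
            acc = [a + w * vj for a, vj in zip(acc, v)]
        running_max = new_max
    return [a / running_sum for a in acc]


random.seed(0)
N, d = 10, 8
q = [random.gauss(0, 1) for _ in range(d)]
K = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
V = [[random.gauss(0, 1) for _ in range(d)] for _ in range(N)]
exact = naive_attention_row(q, K, V)
tiled = tiled_attention_row(q, K, V, block_size=3)
print(max(abs(a - b) for a, b in zip(exact, tiled)))  # only float rounding
```

Changing block_size changes only how many scores are live at once, never the result. That property is what lets the real kernel pick tile sizes to fit SRAM.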
IO complexity
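Standard attention requires O(Nd + N^2) HBM accesses. Flash Attention requires O(N^2 * d^2 * M^{-1}), where M is the SRAM size. For typical values (d = 64-128, M around 100 KB), d^2 is several times smaller than M, so Flash Attention makes several-fold fewer HBM accesses. The authors also prove that no exact attention algorithm can asymptotically improve on this number of HBM accesses across all SRAM sizes.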
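Here is a back-of-the-envelope comparison of the two bounds. It drops constant factors and counts M in elements; both are simplifying assumptions. The output therefore indicates only the ratio between the bounds, not measured kernel traffic.

```python
def hbm_accesses_standard(N, d):
    # Theta(N*d + N^2), constant factors dropped
    return N * d + N * N


def hbm_accesses_flash(N, d, M):
    # Theta(N^2 * d^2 / M), constant factors dropped
    return N * N * d * d / M


N, M = 4096, 100_000  # sequence length; SRAM size in elements (assumed)
for d in (64, 128):
    ratio = hbm_accesses_standard(N, d) / hbm_accesses_flash(N, d, M)
    print(f"d={d}: ~{ratio:.1f}x fewer HBM accesses")
```

Under these assumptions the ratio is roughly 25x for d=64 and about 6x for d=128. The reduction shrinks as the head dimension grows.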
What lmxlab does instead
lmxlab uses mx.fast.scaled_dot_product_attention, which is MLX's
optimized Metal kernel. It supports causal masking and computes
softmax in float32 regardless of input precision. While optimized
for Apple Silicon, it is not fully IO-aware in the Flash Attention
sense; the Metal FlashAttention project demonstrates that Flash
Attention-style tiling is feasible on Apple GPUs.
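```python
import mlx.core as mx

# lmxlab: uses MLX's optimized kernel.
# q, k, v have shape (batch, n_heads, seq_len, head_dim); scale is usually head_dim ** -0.5
output = mx.fast.scaled_dot_product_attention(
    q, k, v, scale=scale, mask=mask
)
```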
References: FlashAttention (Dao et al., 2022), FlashAttention-2 (Dao, 2023), FlashAttention-3 (Dao et al., 2024)
KV Cache Optimizations
Why KV cache dominates memory
During autoregressive generation, each new token attends to all previous tokens. The KV cache stores previously computed K and V tensors so each step only computes projections for the new token.
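The problem: KV cache size scales as
2 * n_layers * n_kv_heads * head_dim * seq_len * batch_size * bytes_per_element,
where the leading 2 counts K and V, and n_kv_heads equals n_heads in standard
multi-head attention. For large models at long contexts, the KV cache can exceed
the model weights themselves. Consider a 70B model with 80 layers and 64 heads
of dimension 128, using full multi-head attention. In FP16, a single
32K-token sequence needs about 86 GB of KV cache. That is more than half of
the model's 140 GB of FP16 weights, and two such sequences already exceed
the weights.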
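The calculator below evaluates the formula above. The layer, head, and head-dimension values are an assumed 70B-class configuration with full multi-head attention, so n_kv_heads equals n_heads. The weight size assumes 70B parameters at 2 bytes each.

```python
def kv_cache_bytes(n_layers, n_kv_heads, head_dim, seq_len,
                   batch_size=1, bytes_per_element=2):
    # Factor 2: every layer stores one K tensor and one V tensor.
    return (2 * n_layers * n_kv_heads * head_dim * seq_len
            * batch_size * bytes_per_element)


config = dict(n_layers=80, n_kv_heads=64, head_dim=128, seq_len=32_768)
weights_bytes = 70e9 * 2  # 70B parameters in FP16

for batch in (1, 2):
    kv = kv_cache_bytes(**config, batch_size=batch)
    print(f"batch={batch}: KV cache {kv / 1e9:.1f} GB "
          f"vs weights {weights_bytes / 1e9:.0f} GB")
```

Every factor in the formula is linear. Doubling the batch size therefore doubles the cache, just as doubling the context length does.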
PagedAttention (vLLM)
Traditional systems pre-allocate contiguous memory for the KV cache based on the maximum sequence length. Since actual lengths vary, this wastes 60-80% of the allocated memory through fragmentation and over-reservation.
PagedAttention (Kwon et al., 2023) borrows from OS virtual memory: KV cache is broken into fixed-size blocks (e.g., 16 tokens per block) stored non-contiguously. Each request maintains a block table mapping logical blocks to physical locations. The attention kernel follows block table pointers instead of reading contiguous memory.
Result: under 4% memory waste vs 60-80% in traditional systems, enabling 2-4x throughput improvement through larger batch sizes.
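The sketch below is a toy block allocator and block table in the spirit of PagedAttention. The class and method names are hypothetical, not vLLM's API. It leaves out the attention kernel, prefix sharing, and copy-on-write, and shows only how logical token positions map to non-contiguous physical blocks.

```python
BLOCK_SIZE = 16  # tokens per KV block, as in the example above


class BlockAllocator:
    """Hands out fixed-size physical blocks from one shared free pool."""

    def __init__(self, num_blocks):
        self.free = list(range(num_blocks))

    def allocate(self):
        if not self.free:
            raise MemoryError("out of KV cache blocks")
        return self.free.pop()

    def release(self, block_id):
        self.free.append(block_id)


class Sequence:
    def __init__(self):
        self.block_table = []  # logical block i -> physical block id
        self.num_tokens = 0

    def append_token(self, allocator):
        # A new physical block is needed only when the last one is full.
        if self.num_tokens % BLOCK_SIZE == 0:
            self.block_table.append(allocator.allocate())
        self.num_tokens += 1

    def physical_slot(self, token_index):
        """Where a token's K/V live: (physical block id, offset in block)."""
        return (self.block_table[token_index // BLOCK_SIZE],
                token_index % BLOCK_SIZE)

    def free(self, allocator):
        for block_id in self.block_table:
            allocator.release(block_id)
        self.block_table.clear()
        self.num_tokens = 0


alloc = BlockAllocator(num_blocks=8)
seq_a, seq_b = Sequence(), Sequence()
for _ in range(20):
    seq_a.append_token(alloc)  # 20 tokens -> 2 blocks, the second 4/16 full
for _ in range(5):
    seq_b.append_token(alloc)  # 5 tokens -> 1 block
print(seq_a.block_table, seq_b.block_table, len(alloc.free))
```

Only the last block of each sequence can be partly empty. Waste is therefore bounded by BLOCK_SIZE - 1 token slots per sequence, rather than by the gap between the actual and the maximum length.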
KV cache quantization
Since inference is memory-bandwidth-bound, reducing KV cache size directly improves generation speed. Keys are typically quantized per-channel and values per-token because the two have different statistical properties: keys exhibit channel-wise outliers. Quantizing an FP16 KV cache to 4 bits shrinks it roughly 4x (slightly less once per-group scales are counted), and the freed memory can go to larger batches or longer sequences.
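The sketch below shows why keys favor per-channel grouping. It uses synthetic keys (an assumption) in which one channel carries a large constant offset, mimicking a channel-wise outlier. Quantization is a simple min/max 4-bit round-trip, and the helpers are illustrative, not a production KV quantizer. Values lack this channel structure, so they would use the per-token grouping.

```python
import random


def quantize_group(values, bits=4):
    """Min/max affine round-trip (quantize, then dequantize) for one group."""
    levels = 2 ** bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels or 1.0  # constant group: any nonzero scale works
    return [round((v - lo) / scale) * scale + lo for v in values]


def quantize_per_token(X, bits=4):
    # One group per row (token): the grouping used for values.
    return [quantize_group(row, bits) for row in X]


def quantize_per_channel(X, bits=4):
    # One group per column (channel): the grouping used for keys.
    cols = [quantize_group(list(col), bits) for col in zip(*X)]
    return [list(row) for row in zip(*cols)]


def mse(A, B):
    diffs = [(a - b) ** 2 for ra, rb in zip(A, B) for a, b in zip(ra, rb)]
    return sum(diffs) / len(diffs)


random.seed(0)
# 32 tokens x 16 channels; channel 3 is shifted by +20 in every token.
keys = [[random.gauss(0, 1) + (20.0 if c == 3 else 0.0) for c in range(16)]
        for _ in range(32)]
err_token = mse(keys, quantize_per_token(keys))
err_channel = mse(keys, quantize_per_channel(keys))
print(f"per-token MSE {err_token:.4f}, per-channel MSE {err_channel:.4f}")
```

With per-token groups, the outlier channel stretches every row's range, so the other fifteen channels get a coarse step size. Per-channel groups confine that cost to the outlier channel.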
What lmxlab does
lmxlab implements a straightforward KV cache where each layer stores K, V tensors that grow with sequence length:
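```python
import mlx.core as mx

# lmxlab: simple cache growth; cache holds one (K, V) pair per layer
logits, cache = model(next_token, cache=cache)
# MLX is lazy: force evaluation of the logits and every cached tensor
mx.eval(logits, *[c for pair in cache for c in pair])
```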
This keeps the caching mechanism explicit. Production systems add PagedAttention and quantization on top of this same concept.
References: PagedAttention (Kwon et al., 2023), vLLM docs
Fused Kernels
What kernel fusion is
A GPU kernel is a function launched on the GPU. Each kernel reads inputs from and writes outputs to global memory (HBM). If operation A produces a tensor that operation B consumes, the naive approach writes A's output to HBM, then B reads it back. Kernel fusion merges A and B into a single kernel so the intermediate tensor lives only in registers or shared memory.
Why it matters
Modern GPUs are bottlenecked by memory bandwidth, not compute. An A100 offers 312 TFLOP/s of dense FP16 tensor-core compute but only about 2 TB/s of HBM bandwidth. Elementwise operations (activations, normalization, residual adds) do only a few FLOPs per element they read and write. Their ratio of memory access to compute is extremely unfavorable, so they are almost entirely bandwidth-bound. Fusing them with adjacent operations eliminates round-trips to HBM.
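The arithmetic behind that claim fits in a short sketch. The peak numbers are the A100 figures quoted above. The traffic model is a simplifying assumption that ignores caches: a chain of k unary elementwise ops, each reading and writing a full FP16 tensor when unfused.

```python
PEAK_FLOPS = 312e12  # A100 dense FP16 tensor-core peak, FLOP/s
PEAK_BW = 2.0e12     # ~2 TB/s HBM bandwidth, bytes/s

# FLOPs a kernel must do per byte moved before compute becomes the limit.
ridge_point = PEAK_FLOPS / PEAK_BW

# Elementwise FP16 add: read two 2-byte values, write one, for a single FLOP.
add_intensity = 1 / (3 * 2)


def hbm_bytes(n_elements, k_ops, fused, bytes_per_element=2):
    """HBM traffic for k unary elementwise ops applied in sequence."""
    per_pass = 2 * n_elements * bytes_per_element  # one read + one write
    return per_pass if fused else k_ops * per_pass


print(f"ridge point: {ridge_point:.0f} FLOP/byte, add: {add_intensity:.2f} FLOP/byte")
print(hbm_bytes(10**6, 3, fused=False) / hbm_bytes(10**6, 3, fused=True))
```

An elementwise add sits nearly three orders of magnitude below the ridge point, so bytes moved set almost all of its runtime. Fusing a chain of k such ops cuts that traffic by roughly a factor of k.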
Examples
Fused attention
Flash Attention is the canonical example, fusing QK^T matmul, softmax, and multiplication by V into a single tiled kernel.
Fused LayerNorm + Linear
A normalization layer (LayerNorm or RMSNorm) computes reduction statistics over each row and normalizes it, and the next operation is typically a linear projection. Fusing the two avoids writing the normalized tensor to HBM only to read it straight back.
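The pure-Python sketch below contrasts two paths. The unfused path normalizes, stores the normalized vector, then projects. The fused loop folds the RMS statistic into the projection instead. The function names, the eps value, and the row-major weight layout are assumptions. In a real kernel, the point is that the normalized row stays in registers or shared memory.

```python
import math
import random


def rmsnorm(x, gain, eps=1e-6):
    rms = math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [v / rms * g for v, g in zip(x, gain)]


def linear(x, W):
    # W is a list of output rows: y[j] = sum_i x[i] * W[j][i]
    return [sum(xi * wji for xi, wji in zip(x, row)) for row in W]


def unfused(x, gain, W):
    normalized = rmsnorm(x, gain)  # full intermediate (the HBM write, on a GPU)
    return linear(normalized, W)


def fused_rmsnorm_linear(x, gain, W, eps=1e-6):
    # One reduction pass for the statistic; the normalized vector is never stored.
    inv_rms = 1.0 / math.sqrt(sum(v * v for v in x) / len(x) + eps)
    return [inv_rms * sum(xi * gi * wji for xi, gi, wji in zip(x, gain, row))
            for row in W]


random.seed(0)
d_in, d_out = 16, 4
x = [random.gauss(0, 1) for _ in range(d_in)]
gain = [random.gauss(1, 0.1) for _ in range(d_in)]
W = [[random.gauss(0, 0.1) for _ in range(d_in)] for _ in range(d_out)]
print(unfused(x, gain, W))
print(fused_rmsnorm_linear(x, gain, W))
```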
Fused SwiGLU
The gated FFN computes
output = silu(xW_gate) * (xW_up), requiring two projections, an
activation, and an elementwise multiply. Fusing all four operations into one kernel
can yield a 10-13% throughput improvement.
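The minimal pure-Python sketch below computes the gated FFN both ways. It assumes row-major weights (one row per hidden unit) and omits the down projection that follows in a full FFN. The "fused" version computes each hidden unit's gate and up values and combines them immediately, which is the access pattern a fused kernel exploits. It does not model GPU speed.

```python
import math
import random


def silu(z):
    return z / (1.0 + math.exp(-z))  # z * sigmoid(z)


def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


def swiglu_unfused(x, W_gate, W_up):
    gate = [dot(x, row) for row in W_gate]  # intermediate tensor 1
    up = [dot(x, row) for row in W_up]      # intermediate tensor 2
    act = [silu(g) for g in gate]           # intermediate tensor 3
    return [a * u for a, u in zip(act, up)]


def swiglu_fused(x, W_gate, W_up):
    out = []
    for g_row, u_row in zip(W_gate, W_up):
        # Both projections, the activation, and the multiply for one hidden
        # unit happen together; no full intermediate tensor is built.
        out.append(silu(dot(x, g_row)) * dot(x, u_row))
    return out


random.seed(0)
d_model, d_hidden = 8, 16
x = [random.gauss(0, 1) for _ in range(d_model)]
W_gate = [[random.gauss(0, 0.5) for _ in range(d_model)] for _ in range(d_hidden)]
W_up = [[random.gauss(0, 0.5) for _ in range(d_model)] for _ in range(d_hidden)]
print(swiglu_fused(x, W_gate, W_up) == swiglu_unfused(x, W_gate, W_up))  # True
```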
How MLX handles this
MLX's mx.compile performs graph-level fusion automatically. On
compilation, MLX analyzes the computation graph, identifies
fusion opportunities, and generates fused Metal shaders:
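```python
import mlx.core as mx

# lmxlab: mx.compile enables automatic fusion.
# inputs/outputs declare the arrays the compiled step reads and updates
# as implicit state (here, the model's trainable parameters).
self._step_fn = mx.compile(
    self._single_step,
    inputs=model.trainable_parameters(),
    outputs=model.trainable_parameters(),
)
```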
For cases where automatic fusion is insufficient, MLX supports
custom Metal kernels via mx.fast.metal_kernel(). The unified memory
architecture eliminates CPU-GPU transfer costs, but optimizing the
intra-GPU memory hierarchy (registers, threadgroup memory, device
memory) still matters.
References: Fused SwiGLU kernels (Bitdefender Research), MLX custom Metal kernels
Quantization for Inference
How quantization works
Quantization maps high-precision values (FP16) to lower-precision
representations (INT4/INT8):
x_q = round(x / scale) + zero_point. This reduces model size and
increases inference speed (smaller tensors = less bandwidth = faster
for bandwidth-bound operations).
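The minimal sketch below applies the affine mapping above with min/max calibration over a whole tensor. It also includes the matching dequantization, x ≈ (x_q - zero_point) * scale. The helper names and the clamp to the unsigned range [0, 2^bits - 1] are assumptions for illustration. Libraries differ in signedness and in how they pick the scale.

```python
import random


def affine_params(values, bits):
    """Choose scale and zero_point so [min, max] maps onto [0, 2**bits - 1]."""
    qmax = 2 ** bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / qmax or 1.0  # guard against a constant tensor
    zero_point = round(-lo / scale)
    return scale, zero_point


def quantize(values, scale, zero_point, bits):
    qmax = 2 ** bits - 1
    # x_q = round(x / scale) + zero_point, clamped to the integer range
    return [min(max(round(v / scale) + zero_point, 0), qmax) for v in values]


def dequantize(codes, scale, zero_point):
    return [(c - zero_point) * scale for c in codes]


random.seed(0)
x = [random.gauss(0, 1) for _ in range(256)]
for bits in (8, 4):
    scale, zp = affine_params(x, bits)
    codes = quantize(x, scale, zp, bits)
    x_hat = dequantize(codes, scale, zp)
    max_err = max(abs(a - b) for a, b in zip(x, x_hat))
    print(f"INT{bits}: scale={scale:.4f}, zero_point={zp}, max error={max_err:.4f}")
```

The worst-case rounding error is half a quantization step (scale / 2). Dropping from 8 to 4 bits makes the step about 17x larger for the same range.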
Granularity matters
- Per-tensor: One scale for the entire weight tensor. Cheapest but crudest; outliers anywhere penalize everything.
- Per-channel: One scale per output channel. Captures channel-wise distribution differences.
- Per-group: Splits each channel into groups of G elements (commonly 32-128), each with its own scale. This is the dominant approach for LLMs, balancing accuracy and overhead.
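The sketch below shows why per-group scaling dominates: a single outlier weight inflates the shared scale for every weight in its group. The weights are synthetic (an assumption) and treated as one output channel. Setting group_size=len(w) stands in for a single per-tensor scale. The round-trip uses simple min/max affine quantization, and group_size=64 just picks a value inside the 32-128 range above.

```python
import random


def round_trip(values, bits=4):
    """Min/max affine quantize-then-dequantize for one group of weights."""
    levels = 2 ** bits - 1
    lo, hi = min(values), max(values)
    scale = (hi - lo) / levels or 1.0
    return [round((v - lo) / scale) * scale + lo for v in values]


def quantize_grouped(weights, group_size, bits=4):
    # Each consecutive run of group_size weights gets its own scale.
    out = []
    for start in range(0, len(weights), group_size):
        out.extend(round_trip(weights[start:start + group_size], bits))
    return out


def mse(a, b):
    return sum((x - y) ** 2 for x, y in zip(a, b)) / len(a)


random.seed(0)
w = [random.gauss(0, 0.02) for _ in range(1024)]  # one channel's weights
w[100] = 1.0                                      # a single outlier
per_tensor = quantize_grouped(w, group_size=len(w))
per_group = quantize_grouped(w, group_size=64)
print(f"per-tensor MSE {mse(w, per_tensor):.2e}, per-group MSE {mse(w, per_group):.2e}")
```

With groups, the outlier degrades only the 64 weights that share its scale instead of all 1024.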
Post-training quantization methods
GPTQ (Frantar et al., 2022): Processes weights sequentially within each layer. When a weight is quantized, the introduced error is compensated by adjusting remaining weights using approximate second-order (Hessian) information. Strong accuracy at 3-4 bits.
AWQ (Lin et al., 2023): Identifies "salient" weights (connected to large activations) and protects them by applying per-channel scaling before quantization. Faster than GPTQ to apply.
GGUF / llama.cpp K-quants: Use a hierarchical structure. Super-blocks of 256 weights are subdivided into smaller groups (32 weights in Q4_K), and the group scales are themselves quantized to a few bits (6 bits in Q4_K). K-quant mixes such as Q4_K_M allocate more bits to tensors that are more sensitive to quantization (for example, some attention value and FFN down-projection tensors) and fewer to the rest.
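Here is worked arithmetic for the storage cost of one K-quant type, Q4_K. It uses the layout described in llama.cpp's k-quants work, taken here as an assumption: a 256-weight super-block of eight 32-weight groups, each group with a 6-bit scale and a 6-bit minimum, plus one FP16 scale and one FP16 minimum per super-block. The comparison is a hypothetical flat scheme that stores an FP16 scale and minimum for every 32-weight group.

```python
def bits_per_weight(n_weights, weight_bits, metadata_bits):
    return (n_weights * weight_bits + metadata_bits) / n_weights


# Q4_K super-block: 8 groups x (6-bit scale + 6-bit min) + FP16 scale + FP16 min
q4_k = bits_per_weight(256, 4, 8 * (6 + 6) + 2 * 16)

# Flat alternative: every 32-weight group keeps an FP16 scale and an FP16 min
flat = bits_per_weight(32, 4, 2 * 16)

print(q4_k, flat)  # 4.5 5.0
```

Quantizing the per-group scales cuts the metadata overhead from 1 bit per weight to 0.5.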
Quantization-aware training (QAT)
QAT inserts fake quantization during training. The forward pass simulates quantized computation (with straight-through estimator for gradients through rounding). The model learns to be robust to quantization error. Better accuracy than PTQ at very low bits (2-3 bit), but requires full training infrastructure.
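The scalar sketch below illustrates fake quantization with a straight-through estimator. It assumes a symmetric INT4 grid [-8, 7] and a fixed scale, and the function names are hypothetical. The toy loop learns a latent weight whose quantized value matches a target, which is the QAT idea in miniature.

```python
def fake_quant(x, scale, qmin=-8, qmax=7):
    """Forward pass: snap x to the INT4 grid, then dequantize right away."""
    return min(max(round(x / scale), qmin), qmax) * scale


def fake_quant_grad(x, upstream_grad, scale, qmin=-8, qmax=7):
    """Backward pass with the straight-through estimator: treat round() as
    the identity, but block the gradient where the input was clipped."""
    inside = qmin * scale <= x <= qmax * scale
    return upstream_grad if inside else 0.0


# Toy QAT loop: minimize (fake_quant(w) - target)**2 through the STE.
scale, target, lr = 0.1, 0.3, 0.1
w = 0.0
for _ in range(100):
    w_q = fake_quant(w, scale)
    loss_grad = 2 * (w_q - target)  # d(loss)/d(w_q)
    w -= lr * fake_quant_grad(w, loss_grad, scale)
print(w, fake_quant(w, scale))
```

Without the straight-through estimator, the derivative of round() is zero almost everywhere and w would never move.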
Practical guidance: PTQ first (fast, cheap, usually good enough at 4-bit). QAT only if PTQ accuracy is insufficient.
What lmxlab does
lmxlab wraps MLX's native affine quantization:
```python
from lmxlab.core.quantize import quantize_model

# Quantize to 4-bit with group size 64
quantize_model(model, bits=4, group_size=64)
```
MLX provides mx.quantized_matmul() which operates directly on
quantized weights, dequantizing on-the-fly during computation. This
is the core performance primitive: the Metal kernel reads compressed
weights and avoids full dequantization.
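The pure-Python sketch below illustrates the idea behind a quantized matmul. Weights stay packed as 4-bit codes, two per byte, with one scale and bias per group. Each weight is dequantized inside the inner loop as w = code * scale + bias. The packing order and function names are assumptions for illustration. This is not MLX's kernel or its memory layout.

```python
def pack_int4(codes):
    """Pack 4-bit codes (0..15) two per byte, low nibble first."""
    if len(codes) % 2:
        raise ValueError("need an even number of codes")
    return bytes(codes[i] | (codes[i + 1] << 4) for i in range(0, len(codes), 2))


def quantized_matvec(packed_rows, scales, biases, x, group_size):
    """y = W @ x, dequantizing each 4-bit weight on the fly."""
    y = []
    for row, row_scales, row_biases in zip(packed_rows, scales, biases):
        acc = 0.0
        for i, xi in enumerate(x):
            byte = row[i // 2]
            code = (byte >> 4) if i % 2 else (byte & 0x0F)
            g = i // group_size
            acc += (code * row_scales[g] + row_biases[g]) * xi
        y.append(acc)
    return y


# Two output rows, eight inputs, groups of four weights.
codes = [[1, 15, 0, 7, 3, 3, 9, 2], [0, 1, 2, 3, 4, 5, 6, 7]]
scales = [[0.5, 0.25], [0.1, 0.2]]
biases = [[-1.0, 0.0], [0.0, -0.5]]
x = [1.0, 2.0, -1.0, 0.5, 0.0, 1.0, 2.0, -2.0]
packed = [pack_int4(r) for r in codes]
print(len(packed[0]), quantized_matvec(packed, scales, biases, x, group_size=4))
```

Each row occupies 4 bytes instead of 16 bytes of FP16. The full-precision matrix never exists in memory.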
References: GPTQ (Frantar et al., 2022), AWQ (Lin et al., 2023), SqueezeLLM (Kim et al., 2023), Visual Guide to Quantization (Grootendorst)
Speculative Decoding
The problem
Autoregressive generation is memory-bandwidth-bound. At small batch sizes, generating each token requires reading all of the model's weights but performs only matrix-vector multiplications, so GPU compute units are vastly underutilized.
How it works
A small, fast draft model generates K candidate tokens. Then the large target model verifies all K tokens in a single forward pass (processing K tokens in parallel, like a prefill step with high GPU utilization).
The acceptance criterion is mathematically precise: for each draft
token x sampled from draft distribution q(x), accept with probability
min(1, p(x) / q(x)) where p(x) is the target distribution. On
rejection, sample from the residual distribution
normalize(max(0, p(x) - q(x))).
This guarantees the output distribution is identical to the target model's. Speculative decoding is lossless.
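The minimal sketch below implements the acceptance rule for a single drafted token. It also calculates exactly that accept-or-resample reproduces the target distribution. The function names and the toy distributions p and q are assumptions. A full implementation repeats this left to right over the drafted tokens and stops at the first rejection.

```python
import random


def residual(p, q):
    """normalize(max(0, p - q)), the distribution sampled after a rejection."""
    r = [max(0.0, pi - qi) for pi, qi in zip(p, q)]
    total = sum(r)
    # If p == q nothing is ever rejected, so this fallback is never sampled.
    return [ri / total for ri in r] if total > 0 else list(p)


def verify_token(x, p, q, rng):
    """Accept draft token x (sampled from q) with probability min(1, p[x] / q[x])."""
    if rng.random() < min(1.0, p[x] / q[x]):
        return x
    return rng.choices(range(len(p)), weights=residual(p, q))[0]


def output_distribution(p, q):
    """Exact probability of each token leaving one draft-and-verify step."""
    accept = [min(pi, qi) for pi, qi in zip(p, q)]  # q(x) * min(1, p(x)/q(x))
    reject_mass = 1.0 - sum(accept)
    return [a + reject_mass * r for a, r in zip(accept, residual(p, q))]


p = [0.5, 0.3, 0.2]  # target model distribution (toy)
q = [0.2, 0.5, 0.3]  # draft model distribution (toy)
print(output_distribution(p, q))  # matches p up to float rounding

rng = random.Random(0)
counts = [0, 0, 0]
for _ in range(100_000):
    draft = rng.choices(range(3), weights=q)[0]
    counts[verify_token(draft, p, q, rng)] += 1
print([c / 100_000 for c in counts])  # close to p
```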
Variants
Medusa (Cai et al., 2024): Adds multiple lightweight decoding heads to the target model itself, where each head predicts a different future position. No separate draft model needed. 2-3x speedup.
EAGLE (Li et al., 2024): Trains a head that predicts hidden states (features) rather than tokens. Feature-level prediction is easier than token-level, yielding higher acceptance rates.
Lookahead decoding: No additional training or models. Uses Jacobi iteration to generate n-gram candidates and verify them.
What lmxlab does
lmxlab implements the basic draft-then-verify paradigm:
```python
from lmxlab.inference.speculative import speculative_decode

tokens = speculative_decode(
    target_model=large_model,
    draft_model=small_model,
    prompt=prompt,
    max_tokens=100,
    n_draft=5,  # Draft 5 tokens per round
)
```
This is a complete implementation showing the core algorithm. Production systems add tree-structured verification, dynamic draft length, and KV cache management for rejected tokens.
References: Speculative decoding (Leviathan et al., 2022), Medusa (Cai et al., 2024), EAGLE (Li et al., 2024)
Continuous Batching
Why static batching wastes compute
In static batching, the server processes N requests together, and the entire batch waits until the longest sequence finishes. If one request produces 500 tokens and the others produce 50, the slots held by the short requests sit idle for 90% of the batch duration.
How continuous batching works
Scheduling decisions happen at every generation step:
- At each decode iteration, check if any sequence has finished.
- Immediately evict finished sequences and insert waiting requests.
- The batch composition changes dynamically every iteration.
GPU resources freed by a completed request are immediately used by a new one. In Anyscale's benchmarks, continuous batching combined with vLLM's memory optimizations delivered up to 23x higher throughput than static batching.
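The toy simulation below counts decode steps under the two policies. It assumes every request occupies one batch slot and each decode step emits one token per running request. It ignores prefill and memory limits. The request lengths mirror the 500-versus-50 example above, and the function names are hypothetical.

```python
def static_batching_steps(lengths, batch_size):
    """Each fixed batch runs until its longest request finishes."""
    steps = 0
    for start in range(0, len(lengths), batch_size):
        steps += max(lengths[start:start + batch_size])
    return steps


def continuous_batching_steps(lengths, batch_size):
    """Finished requests leave after every step; waiting ones take free slots."""
    waiting = list(lengths)
    running = []  # remaining tokens for each occupied slot
    steps = 0
    while waiting or running:
        while waiting and len(running) < batch_size:
            running.append(waiting.pop(0))
        # One decode step emits one token for every running request.
        running = [r - 1 for r in running if r > 1]
        steps += 1
    return steps


lengths = [500, 50, 50, 50, 500, 50, 50, 50]  # output tokens per request
print(static_batching_steps(lengths, batch_size=4))      # 1000
print(continuous_batching_steps(lengths, batch_size=4))  # 550
```

In this toy run, continuous batching finishes all eight requests in 550 steps instead of 1000, because the slots freed at step 50 are refilled immediately.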
Supporting techniques
- Chunked prefill: Long prompts are processed in chunks interleaved with decode steps, preventing one long prompt from blocking the batch.
- Ragged batching: Variable-length sequences packed without padding, using offset arrays to track boundaries.
- PagedAttention integration: KV cache blocks freed by evicted sequences are reassigned to new ones.
Relevance to Apple Silicon
Continuous batching matters less for single-user local inference (lmxlab's primary use case) but becomes important when serving multiple users. The mlx-lm server implements continuous batching for MLX-based model serving.
References: Continuous batching (HuggingFace), Anyscale benchmark: 23x throughput
Tensor and Pipeline Parallelism
Tensor parallelism (TP)
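Splits individual weight matrices across devices. For Y = XW, each
device holds a shard W_i. If W is split by columns, each device computes
its own slice of Y and the slices are concatenated. If W is split by
rows (with X split to match), each device computes a partial product
X_i W_i and the partials are summed via all-reduce. This scheme was pioneered
by Megatron-LM (Shoeybi et al., 2019), which pairs a column-split layer with a
row-split layer so that each attention or MLP block needs only one all-reduce.
Each transformer layer therefore requires 2 all-reduce operations in the forward pass. TP needs a high-bandwidth interconnect, so it works well within a node (e.g., NVLink at 900 GB/s on H100) but poorly across nodes.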
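The pure-Python simulation below applies Megatron-style tensor parallelism to a two-layer MLP, with each loop iteration standing in for one device. The first weight matrix is split by columns and the second by rows, and a summation plays the role of the all-reduce. Matrix sizes, the ReLU activation, and the helper names are assumptions for illustration.

```python
import random


def matmul(A, B):
    # A: (n, k), B: (k, m), both nested lists
    return [[sum(A[i][t] * B[t][j] for t in range(len(B))) for j in range(len(B[0]))]
            for i in range(len(A))]


def relu(M):
    return [[max(0.0, v) for v in row] for row in M]


def split_columns(W, parts):
    m = len(W[0]) // parts
    return [[row[p * m:(p + 1) * m] for row in W] for p in range(parts)]


def split_rows(W, parts):
    k = len(W) // parts
    return [W[p * k:(p + 1) * k] for p in range(parts)]


def all_reduce_sum(partials):
    # Element-wise sum of every device's partial result.
    return [[sum(vals) for vals in zip(*rows)] for rows in zip(*partials)]


def tensor_parallel_mlp(X, W1, W2, n_devices):
    W1_shards = split_columns(W1, n_devices)  # column-parallel
    W2_shards = split_rows(W2, n_devices)     # row-parallel
    partials = []
    for dev in range(n_devices):
        H_dev = relu(matmul(X, W1_shards[dev]))  # local hidden slice, no communication
        partials.append(matmul(H_dev, W2_shards[dev]))
    return all_reduce_sum(partials)  # the only communication step


random.seed(0)
X = [[random.gauss(0, 1) for _ in range(4)] for _ in range(3)]
W1 = [[random.gauss(0, 1) for _ in range(8)] for _ in range(4)]
W2 = [[random.gauss(0, 1) for _ in range(4)] for _ in range(8)]
reference = matmul(relu(matmul(X, W1)), W2)
sharded = tensor_parallel_mlp(X, W1, W2, n_devices=2)
print(max(abs(a - b) for ra, rb in zip(reference, sharded) for a, b in zip(ra, rb)))
```

The activation can run locally only because the column split keeps each hidden unit entirely on one device. The row split then makes the sum of partials equal to the full product.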
Pipeline parallelism (PP)
Assigns groups of consecutive layers to different devices. Device i sends activations to device i+1. Requires less bandwidth than TP (only activation tensors between stages, not all-reduce of full hidden dimensions).
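Splitting each batch into microbatches keeps the pipeline full and shrinks the idle "bubble" at the start and end of each step. 1F1B (one forward, one backward) scheduling interleaves forward and backward passes so that each stage holds activations for only a bounded number of in-flight microbatches.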
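The one-line model below estimates the pipeline bubble. It assumes p stages that each take the same time per microbatch and ignores communication. With m microbatches, the schedule spans m + p - 1 slots, of which p - 1 are idle fill and drain. This bubble is the same for GPipe-style and non-interleaved 1F1B schedules. The function name is hypothetical.

```python
def bubble_fraction(num_stages, num_microbatches):
    """Fraction of pipeline time spent idle (fill + drain)."""
    p, m = num_stages, num_microbatches
    return (p - 1) / (m + p - 1)


for m in (1, 4, 16, 64):
    print(f"4 stages, {m:>2} microbatches: {bubble_fraction(4, m):.1%} idle")
```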
Relevance to Apple Silicon
On a single M-series chip, unified memory removes CPU-GPU transfers altogether. The M2/M3 Ultra chips are dual-die designs connected via UltraFusion, which Apple rates at 2.5 TB/s of inter-die bandwidth, so the hardware transparently handles a form of multi-chip parallelism.
For multi-node setups (Mac Studios linked via Thunderbolt 5 at 80 Gbps), pipeline parallelism is preferred because TP's all-reduce at every layer is too bandwidth-hungry for Thunderbolt. Recent work demonstrates TB5-connected Mac Studio clusters running distributed inference with MLX for 1T-parameter MoE models.
References: Megatron-LM (Shoeybi et al., 2019), Multi-node expert parallelism on Apple Silicon
Summary: lmxlab vs Production Systems
| Concept | lmxlab (educational) | Production |
|---|---|---|
| Attention | mx.fast.scaled_dot_product_attention | Flash Attention (tiled, IO-aware) |
| KV cache | Simple growing tensors | PagedAttention + quantized cache |
| Kernel fusion | mx.compile (automatic) | Hand-written fused kernels |
| Quantization | quantize_model(bits=4) | GPTQ/AWQ/K-quants + calibration |
| Decoding | Greedy / top-k / top-p, basic speculative decoding | Speculative decoding + tree verification |
| Batching | Single sequence | Continuous batching + chunked prefill |
| Parallelism | Single device | TP + PP across devices |
The readable implementations show what these operations compute. The optimized versions show how to compute them fast. Understanding the simple version makes the optimized version comprehensible.