# Compiled Training
This page explains how lmxlab's training loop uses `mx.compile` to fuse
the entire training step into a single optimized computation graph. The
resulting reduction in per-step overhead is significant on Apple Silicon.
## The basic training step
Without compilation, a training step looks like this:
```python
# 1. Forward + backward (functional gradient)
loss, grads = loss_and_grad_fn(model, x, y)

# 2. Gradient clipping
grads, _ = optim.clip_grad_norm(grads, max_norm=1.0)

# 3. Optimizer update
optimizer.update(model, grads)

# 4. Force evaluation
mx.eval(loss, model.parameters(), optimizer.state)
```
Each of these steps builds a computation graph. MLX's lazy evaluation means
nothing actually runs until `mx.eval`. Without compilation, however, every
call to the step function constructs a fresh graph, which MLX must trace,
optimize, and schedule from scratch.
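The trace-once idea can be sketched in plain Python. This is a toy illustration of caching a traced plan, not MLX's actual implementation; `toy_compile` is a hypothetical name:

```python
# Toy sketch of trace-once caching (NOT MLX internals): the first call
# with a given input signature "traces" the function and caches the
# result; later calls reuse the cache instead of rebuilding the graph.
def toy_compile(fn):
    cache = {}

    def wrapper(*args):
        # Key on the input "signature" (a stand-in for shapes/dtypes).
        key = tuple(type(a).__name__ for a in args)
        if key not in cache:
            wrapper.trace_count += 1   # the expensive part: trace + optimize
            cache[key] = fn            # stand-in for the optimized graph
        return cache[key](*args)

    wrapper.trace_count = 0
    return wrapper

step = toy_compile(lambda x: x + 1)
step(1.0)
step(2.0)  # same signature: cached plan reused, no re-trace
```

Real compilation also fuses kernels and plans buffers, but the caching behavior above is why repeated calls get cheaper after the first one.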
## What `mx.compile` does

`mx.compile` traces the function once, captures the computation graph, and
reuses it on subsequent calls:
```python
# In Trainer.__init__:
if config.compile_step:
    self._step_fn = mx.compile(
        self._single_step,
        inputs=model.trainable_parameters(),
        outputs=model.trainable_parameters(),
    )
```
After the first call, the compiled function:
- Skips graph construction by reusing the cached graph
- Fuses multiple operations into single GPU kernels
- Plans buffer reuse across the graph, reducing memory allocation
- Eliminates Python-level tracing on subsequent calls
## The `inputs` and `outputs` contract

The `inputs` and `outputs` arguments tell the compiler which state is
read and mutated by the function:
```python
mx.compile(
    self._single_step,
    inputs=model.trainable_parameters(),   # state read by the function
    outputs=model.trainable_parameters(),  # state written by the function
)
```
This is necessary because `_single_step` mutates model parameters via
`optimizer.update`. Without declaring this, the compiler would not know
that the model's parameter arrays change between calls.
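Why the declaration matters can be sketched in plain Python. This is a toy model of trace-time state capture, not MLX internals; `toy_compile_stateful` and `apply_w` are hypothetical names:

```python
# Toy sketch (NOT MLX internals): at trace time, every value a function
# reads is snapshotted. Only state declared as an input is re-read on
# later calls; undeclared state stays frozen at its trace-time value.
params = {"w": 1.0}

def toy_compile_stateful(fn, inputs=None):
    snapshot = dict(params)            # trace-time capture of all state
    declared = inputs if inputs is not None else {}

    def compiled(x):
        state = dict(snapshot)
        state.update(declared)         # declared inputs see live values
        return fn(state, x)

    return compiled

def apply_w(state, x):
    return state["w"] * x

frozen = toy_compile_stateful(apply_w)               # params NOT declared
live = toy_compile_stateful(apply_w, inputs=params)  # params declared

params["w"] = 3.0   # simulate an optimizer update between steps
```

After the update, `frozen(2.0)` still uses the stale trace-time value of `w`, while `live(2.0)` sees the new one — the same failure mode as forgetting to declare mutated state to `mx.compile`.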
### Getting `inputs`/`outputs` wrong

If optimizer state is not included in `outputs`, the optimizer's
internal state (momentum and second moments for Adam) will not be
updated correctly after the first step. In lmxlab, we pass
`model.trainable_parameters()`, which captures both the parameters
and the optimizer's state through the model's parameter tree.
## When to compile (and when not to)

Compile when:

- Running production training loops (the default: `compile_step=True`)
- Profiling throughput, since compilation gives realistic performance numbers
- The training step has no data-dependent control flow

Don't compile when:

- Debugging, because compiled functions give less informative stack traces
- Prototyping, because compilation adds startup latency for the first step
- Inputs vary in shape: if batch size or sequence length changes, the graph must be retraced (triggering recompilation)
Set `compile_step=False` in `TrainConfig` to disable compilation.
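For example, a sketch using the `Trainer`/`TrainConfig` API as it appears elsewhere on this page:

```python
# Disable compilation, e.g. while debugging or prototyping.
trainer = Trainer(model, TrainConfig(compile_step=False))
```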
## The compile gotcha: captured control flow

`mx.compile` traces the function once and caches the graph. Any
Python-level control flow that depends on tensor values is captured
at trace time and frozen:
```python
# BAD: condition depends on loss value (a tensor)
def step(x, y):
    loss, grads = loss_and_grad_fn(model, x, y)
    if loss > 10.0:  # This is evaluated ONCE at trace time!
        grads = scale_grads(grads, 0.1)
    optimizer.update(model, grads)
    return loss
```
After tracing, the if branch is permanently taken (or not), regardless
of the actual loss value. The compiled function becomes a fixed graph.
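The freeze can be demonstrated with a plain-Python stand-in for tracing. This is a toy illustration, not MLX's tracer; `trace_once` and `raw_step` are hypothetical names:

```python
# Toy tracer (NOT MLX internals): run the function once, record which
# ops executed, then replay that fixed recording on every later call.
def trace_once(fn, example_loss):
    ops = []
    fn(example_loss, ops)        # the Python `if` runs exactly once, here

    def compiled(_loss):
        return list(ops)         # replay: new loss values are ignored
    return compiled

def raw_step(loss, ops):
    ops.append("backward")
    if loss > 10.0:              # data-dependent Python branch
        ops.append("scale_grads")
    ops.append("update")

# Traced with loss > 10, so the scaling branch is baked in forever.
compiled_step = trace_once(raw_step, example_loss=20.0)
```

Calling `compiled_step` with any loss value, large or small, replays the same op sequence, including `scale_grads`, because the branch was decided at trace time.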
The fix: keep compiled functions free of data-dependent branches.
Use `mx.where` for conditional computation that should vary with data:
```python
# OK: mx.where is a graph operation, not Python control flow
from mlx.utils import tree_map

scale = mx.where(loss > 10.0, 0.1, 1.0)
grads = tree_map(lambda g: g * scale, grads)
```
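The selection-as-data pattern can be mimicked in plain Python. This toy only reproduces the *behavior* (the pick tracks the runtime value on every call), not the tracing semantics; `toy_where` and `scaled_step` are hypothetical names:

```python
# Toy stand-in: selection happens as data, so the "branch" is
# re-decided per call from the actual loss value.
def toy_where(cond, a, b):
    # mx.where is a graph op; this version only mimics its behavior.
    return a if cond else b

def scaled_step(loss, grads):
    scale = toy_where(loss > 10.0, 0.1, 1.0)  # re-evaluated every call
    return {k: g * scale for k, g in grads.items()}
```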
## Compilation and LoRA

LoRA fine-tuning works with compiled training. Since `apply_lora` freezes
all non-LoRA parameters, `model.trainable_parameters()` returns only
the LoRA matrices (`lora_A` and `lora_B`). The compiled step correctly
updates only these:
```python
apply_lora(model, rank=8)

# trainable_parameters() now returns only LoRA params;
# the compiled step will only compute gradients for these.
trainer = Trainer(model, TrainConfig(compile_step=True))
```
## Performance impact
The speedup from compilation depends on model size and step complexity. Expected improvements on Apple Silicon (approximate, based on MLX documentation and community benchmarks; actual results vary by hardware, batch size, and sequence length):
| Scenario | Uncompiled | Compiled | Speedup |
|---|---|---|---|
| Tiny model (64d, 2L) | ~1ms/step | ~0.8ms/step | ~1.3x |
| Small model (256d, 6L) | ~5ms/step | ~3ms/step | ~1.7x |
| Medium model (1024d, 12L) | ~30ms/step | ~15ms/step | ~2x |
The larger the model, the more opportunity for kernel fusion and the
greater the relative reduction in Python overhead. Use
`benchmark_compile.py` to measure on a given hardware configuration.
## How lmxlab structures the compiled step

The full compiled function in `Trainer._single_step`:
```python
def _single_step(self, x, y):
    # Forward pass + backward pass (functional)
    loss, grads = self._loss_and_grad(self.model, x, y)

    # Gradient clipping (functional, returns new grads)
    if self.config.max_grad_norm > 0:
        grads, _ = optim.clip_grad_norm(
            grads, max_norm=self.config.max_grad_norm
        )

    # Optimizer update (mutates model params in-place)
    self.optimizer.update(self.model, grads)
    return loss
```
Key design choices:
- Forward, backward, clipping, and optimizer update are fused into a single compiled graph.
- `nn.value_and_grad` returns a gradient dict (not in-place `.grad` attributes), which is what `mx.compile` expects.
- `mx.eval` is called outside the compiled function, after it returns, to force evaluation of the entire graph.
- The `if max_grad_norm > 0` check is on a Python float (a config value), not a tensor, so it is safe inside a compiled function.