Training
Compiled training loop, optimizers, and training utilities.
Trainer
lmxlab.training.trainer.Trainer
Training loop with compiled steps and gradient accumulation.
Uses nn.value_and_grad for functional gradient computation
and mx.compile for the full training step. When
grad_accumulation_steps > 1, gradients are averaged over
multiple micro-batches before a single optimizer update.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LanguageModel | Language model to train. | required |
| config | TrainConfig | Training configuration. | required |
| optimizer | Optimizer \| None | Optional pre-built optimizer. | None |
| callbacks | list[Callback] \| None | Optional list of callbacks. | None |
Source code in src/lmxlab/training/trainer.py
Instance attributes:

- `_accum_steps = config.grad_accumulation_steps`
- `_loss_and_grad = nn.value_and_grad(model, _loss_fn)`
- `_step_fn = mx.compile(self._single_step, inputs=(model.trainable_parameters()), outputs=(model.trainable_parameters()))`
- `callbacks = callbacks or []`
- `config = config`
- `model = model`
- `optimizer = optimizer`
- `step = 0`
__init__(model, config, optimizer=None, callbacks=None)
Source code in src/lmxlab/training/trainer.py
_accumulation_step(micro_batches)
Gradient accumulation step over multiple micro-batches.
Computes gradients for each micro-batch, averages them, then applies a single optimizer update.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| micro_batches | list[tuple[array, array]] | List of (input, target) micro-batches. | required |

Returns:

| Type | Description |
|---|---|
| tuple[array, array] | Tuple of (avg_loss, grad_norm). |
Source code in src/lmxlab/training/trainer.py
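The per-key gradient averaging is easiest to see outside of MLX. A minimal NumPy sketch of the accumulation arithmetic (the `average_grads` helper is hypothetical, not part of lmxlab, and real gradients live in nested parameter trees rather than flat dicts):

```python
import numpy as np

def average_grads(grad_list):
    """Average a list of per-micro-batch gradient dicts, key by key.

    Each dict maps a parameter name to its gradient array; the result
    is the element-wise mean, which is what a single optimizer update
    sees after accumulation.
    """
    n = len(grad_list)
    return {
        key: sum(g[key] for g in grad_list) / n
        for key in grad_list[0]
    }

# Three micro-batches' worth of gradients for one parameter "w".
grads = [{"w": np.array([3.0, 0.0])},
         {"w": np.array([0.0, 3.0])},
         {"w": np.array([3.0, 3.0])}]
avg = average_grads(grads)   # element-wise mean over the three dicts
```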
_maybe_eval(eval_data, metrics)
Run evaluation if due.
Source code in src/lmxlab/training/trainer.py
_single_step(x, y)
Single training step: forward + backward + update.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | array | Input tokens (batch, seq_len). | required |
| y | array | Target tokens (batch, seq_len). | required |

Returns:

| Type | Description |
|---|---|
| tuple[array, array] | Tuple of (loss, grad_norm), where grad_norm is the L2 norm of the gradients before clipping. |
Source code in src/lmxlab/training/trainer.py
_train_accumulated(train_data, eval_data)
Training loop with gradient accumulation.
Source code in src/lmxlab/training/trainer.py
_train_simple(train_data, eval_data)
Training loop without gradient accumulation.
Source code in src/lmxlab/training/trainer.py
evaluate(eval_data)
Run evaluation over the eval dataset.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| eval_data | Iterator[tuple[array, array]] | Iterator yielding (input, target) batches. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dict with 'eval_loss'. |
Source code in src/lmxlab/training/trainer.py
train(train_data, eval_data=None)
Run the full training loop.
When grad_accumulation_steps > 1, collects that many
micro-batches before each optimizer update.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| train_data | Iterator[tuple[array, array]] | Iterator yielding (input, target) batches. | required |
| eval_data | Iterator[tuple[array, array]] \| None | Optional eval data iterator. | None |

Returns:

| Type | Description |
|---|---|
| list[dict[str, Any]] | List of per-step metrics dicts. |
Source code in src/lmxlab/training/trainer.py
train_step(batch)
Execute one training step with eval boundary.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| batch | tuple[array, array] | Tuple of (input_tokens, target_tokens). | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dict with 'loss', 'learning_rate', and 'grad_norm'. |
Source code in src/lmxlab/training/trainer.py
train_step_accumulated(micro_batches)
Execute one accumulated training step.
Averages gradients over micro-batches, then updates once.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| micro_batches | list[tuple[array, array]] | List of (input, target) micro-batches. | required |

Returns:

| Type | Description |
|---|---|
| dict[str, float] | Dict with 'loss', 'learning_rate', and 'grad_norm'. |
Source code in src/lmxlab/training/trainer.py
Training Config
lmxlab.training.config.TrainConfig
dataclass
Configuration for the training loop.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| learning_rate | float | Peak learning rate. | 0.0003 |
| weight_decay | float | Weight decay coefficient. | 0.01 |
| warmup_steps | int | Linear warmup steps. | 100 |
| max_steps | int | Maximum training steps. | 1000 |
| batch_size | int | Training batch size. | 32 |
| grad_accumulation_steps | int | Number of micro-batches to accumulate before an optimizer step. | 1 |
| max_grad_norm | float | Maximum gradient norm for clipping. | 1.0 |
| eval_interval | int | Steps between evaluations. | 100 |
| log_interval | int | Steps between logging. | 10 |
| checkpoint_interval | int | Steps between checkpoints. | 500 |
| optimizer | str | Optimizer name ('adamw', 'lion', 'adafactor', 'sgd'). | 'adamw' |
| lr_schedule | str | Learning rate schedule ('cosine', 'linear', 'constant'). | 'cosine' |
| compile_step | bool | Whether to mx.compile the training step. | True |
| seed | int | Random seed. | 42 |
Source code in src/lmxlab/training/config.py
Class attributes (defaults, overridable per instance):

- `batch_size = 32`
- `checkpoint_interval = 500`
- `compile_step = True`
- `eval_interval = 100`
- `grad_accumulation_steps = 1`
- `learning_rate = 0.0003`
- `log_interval = 10`
- `lr_schedule = 'cosine'`
- `max_grad_norm = 1.0`
- `max_steps = 1000`
- `optimizer = 'adamw'`
- `seed = 42`
- `warmup_steps = 100`
- `weight_decay = 0.01`
__init__(learning_rate=0.0003, weight_decay=0.01, warmup_steps=100, max_steps=1000, batch_size=32, grad_accumulation_steps=1, max_grad_norm=1.0, eval_interval=100, log_interval=10, checkpoint_interval=500, optimizer='adamw', lr_schedule='cosine', compile_step=True, seed=42)
Optimizers
lmxlab.training.optimizers.create_optimizer(config)
Create an optimizer with learning rate schedule.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | TrainConfig | Training configuration. | required |

Returns:

| Type | Description |
|---|---|
| Optimizer | Configured optimizer. |
Source code in src/lmxlab/training/optimizers.py
lmxlab.training.optimizers.create_schedule(config)
Create a learning rate schedule.
Supports warmup + decay patterns.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| config | TrainConfig | Training configuration. | required |

Returns:

| Type | Description |
|---|---|
| Callable[[int], float] | Learning rate scheduler. |
Source code in src/lmxlab/training/optimizers.py
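With the defaults above (peak 3e-4, 100 warmup steps, 1000 max steps, 'cosine'), the warmup + decay shape can be sketched in plain Python. This is an illustration of the pattern, not necessarily lmxlab's exact implementation; `warmup_cosine` is a hypothetical helper:

```python
import math

def warmup_cosine(step, peak_lr=3e-4, warmup_steps=100, max_steps=1000):
    """Linear warmup to peak_lr, then cosine decay to zero at max_steps."""
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / (max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

lr_mid_warmup = warmup_cosine(50)    # half the peak rate
lr_peak = warmup_cosine(100)         # warmup just finished: peak rate
lr_end = warmup_cosine(1000)         # fully decayed to zero
```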
Checkpoints
lmxlab.training.checkpoints.save_checkpoint(path, model, optimizer=None, step=0, metadata=None)
Save a training checkpoint.
Saves model weights as safetensors and metadata as JSON.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Directory to save checkpoint. | required |
| model | LanguageModel | Model to save. | required |
| optimizer | Optimizer \| None | Optional optimizer to save state. | None |
| step | int | Current training step. | 0 |
| metadata | dict[str, Any] \| None | Additional metadata to save. | None |
Source code in src/lmxlab/training/checkpoints.py
lmxlab.training.checkpoints.load_checkpoint(path, model, optimizer=None)
Load a training checkpoint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| path | str \| Path | Directory containing checkpoint. | required |
| model | LanguageModel | Model to load weights into. | required |
| optimizer | Optimizer \| None | Optional optimizer to load state into. | None |

Returns:

| Type | Description |
|---|---|
| dict[str, Any] | Metadata dict from the checkpoint. |
Source code in src/lmxlab/training/checkpoints.py
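The metadata half of the round-trip can be sketched with the standard library alone (the weights half, saved as safetensors in the real code, is omitted; `save_metadata`/`load_metadata` are hypothetical helpers, not lmxlab's API):

```python
import json
import tempfile
from pathlib import Path

def save_metadata(path, step, metadata=None):
    """Write checkpoint metadata as JSON inside the checkpoint directory."""
    path = Path(path)
    path.mkdir(parents=True, exist_ok=True)
    payload = {"step": step, **(metadata or {})}
    (path / "metadata.json").write_text(json.dumps(payload))

def load_metadata(path):
    """Read the metadata dict back from a checkpoint directory."""
    return json.loads((Path(path) / "metadata.json").read_text())

with tempfile.TemporaryDirectory() as d:
    ckpt = Path(d) / "ckpt-100"
    save_metadata(ckpt, step=100, metadata={"eval_loss": 2.31})
    meta = load_metadata(ckpt)   # round-trips step and extra metadata
```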
Callbacks
lmxlab.training.callbacks.Callback
Bases: Protocol
Protocol for training callbacks.
Source code in src/lmxlab/training/callbacks.py
lmxlab.training.callbacks.MetricsLogger
Logs training metrics at configured intervals.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| log_interval | int | Steps between log outputs. | 10 |
Source code in src/lmxlab/training/callbacks.py
lmxlab.training.callbacks.EarlyStopping
Stop training when eval loss stops improving.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| patience | int | Steps without improvement before stopping. | 5 |
| min_delta | float | Minimum change to qualify as improvement. | 0.001 |
Source code in src/lmxlab/training/callbacks.py
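The stopping rule can be sketched as a small stateful class. This is a hedged re-implementation of the idea (best-so-far plus a patience counter), not the lmxlab callback itself, and it counts evaluations rather than raw steps:

```python
class SimpleEarlyStopping:
    """Stop when eval loss fails to improve by min_delta for `patience` evals."""

    def __init__(self, patience=5, min_delta=0.001):
        self.patience = patience
        self.min_delta = min_delta
        self.best = float("inf")
        self.bad_evals = 0

    def update(self, eval_loss):
        """Record one eval result; return True if training should stop."""
        if eval_loss < self.best - self.min_delta:
            self.best = eval_loss
            self.bad_evals = 0
        else:
            self.bad_evals += 1
        return self.bad_evals >= self.patience

stopper = SimpleEarlyStopping(patience=2, min_delta=0.01)
# Two real improvements, then two sub-min_delta "improvements" -> stop.
decisions = [stopper.update(loss) for loss in [3.0, 2.5, 2.499, 2.498]]
```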
lmxlab.training.callbacks.ThroughputMonitor
Tracks training throughput (tokens/sec, steps/sec).
Reports throughput at configured intervals, useful for understanding MLX performance characteristics and comparing compiled vs uncompiled training.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| log_interval | int | Steps between throughput reports. | 10 |
| tokens_per_step | int \| None | Tokens processed per step (batch_size * seq_len). If None, only reports steps/sec. | None |
Source code in src/lmxlab/training/callbacks.py
DPO
lmxlab.training.dpo
Direct Preference Optimization (DPO) training.
dpo_loss(model, ref_model, chosen, rejected, beta=0.1)
Compute DPO loss.
DPO directly optimizes preferences without reward modeling:

    L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LanguageModel | Policy model being trained. | required |
| ref_model | LanguageModel | Reference (frozen) model. | required |
| chosen | array | Preferred sequence token IDs (batch, seq_len). | required |
| rejected | array | Dispreferred sequence token IDs (batch, seq_len). | required |
| beta | float | Temperature parameter controlling deviation from ref. | 0.1 |

Returns:

| Type | Description |
|---|---|
| array | Scalar DPO loss. |
Source code in src/lmxlab/training/dpo.py
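Given per-sequence log-probabilities, the formula above reduces to a few lines of NumPy. `dpo_loss_np` is an illustrative helper, not lmxlab's `dpo_loss` (which computes the log-probs from the models internally); the sequence log-probs here are assumed precomputed:

```python
import numpy as np

def sigmoid(x):
    return 1.0 / (1.0 + np.exp(-x))

def dpo_loss_np(policy_chosen_lp, policy_rejected_lp,
                ref_chosen_lp, ref_rejected_lp, beta=0.1):
    """DPO loss from summed sequence log-probs under policy and reference."""
    log_ratio_chosen = policy_chosen_lp - ref_chosen_lp
    log_ratio_rejected = policy_rejected_lp - ref_rejected_lp
    margin = log_ratio_chosen - log_ratio_rejected
    return -np.log(sigmoid(beta * margin)).mean()

# Policy already prefers the chosen completion relative to the reference:
# margin = (-10 - (-12)) - (-14 - (-12)) = 4, so loss = -log(sigmoid(0.4)).
loss = dpo_loss_np(np.array([-10.0]), np.array([-14.0]),
                   np.array([-12.0]), np.array([-12.0]), beta=0.1)
```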
GRPO
lmxlab.training.grpo
Group Relative Policy Optimization (GRPO).
grpo_loss(model, ref_model, prompts, completions, rewards, beta=0.1, epsilon=0.2)
Compute GRPO loss.
GRPO uses group-relative rewards: for each prompt, generate multiple completions, compute rewards, normalize within the group, and optimize using a clipped surrogate objective.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LanguageModel | Policy model being trained. | required |
| ref_model | LanguageModel | Reference (frozen) model. | required |
| prompts | array | Prompt token IDs (batch, prompt_len). | required |
| completions | array | Full sequences (batch, total_len). | required |
| rewards | array | Scalar rewards per completion (batch,). | required |
| beta | float | KL penalty coefficient. | 0.1 |
| epsilon | float | Clipping range for surrogate objective. | 0.2 |

Returns:

| Type | Description |
|---|---|
| array | Scalar GRPO loss. |
Source code in src/lmxlab/training/grpo.py
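The two ingredients named above, group-relative reward normalization and the clipped surrogate, can be sketched in NumPy. Both helpers are hypothetical, and details such as the epsilon in the normalization denominator are assumptions about the real implementation:

```python
import numpy as np

def group_advantages(rewards, eps=1e-8):
    """Group-relative advantages: normalize rewards within one prompt's
    group of completions, (r - mean) / (std + eps)."""
    rewards = np.asarray(rewards, dtype=np.float64)
    return (rewards - rewards.mean()) / (rewards.std() + eps)

def clipped_surrogate(ratio, advantage, epsilon=0.2):
    """PPO-style clipped objective term (to be maximized): the ratio's
    contribution is capped inside [1 - epsilon, 1 + epsilon]."""
    return np.minimum(ratio * advantage,
                      np.clip(ratio, 1 - epsilon, 1 + epsilon) * advantage)

adv = group_advantages([1.0, 2.0, 3.0])   # zero-mean, unit-ish scale
term = clipped_surrogate(1.5, 1.0)        # ratio 1.5 clipped to 1.2
```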
Multi-Token Prediction
lmxlab.training.mtp.MTPHead
Bases: Module
Single multi-token prediction head.
Takes hidden states and previous target embeddings, normalizes both, concatenates, projects back to d_model, then runs through a transformer block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| d_model | int | Hidden dimension. | required |
| block_config | BlockConfig | Block configuration for the MTP block. | required |
Source code in src/lmxlab/training/mtp.py
__init__(d_model, block_config)
Source code in src/lmxlab/training/mtp.py
__call__(h, prev_embed, mask=None)
Produce hidden states for future token prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| h | array | Hidden states (batch, seq_len, d_model). | required |
| prev_embed | array | Embeddings of previous target tokens (batch, seq_len, d_model). | required |
| mask | array \| None | Optional causal mask. | None |

Returns:

| Type | Description |
|---|---|
| array | Hidden states (batch, seq_len, d_model). |
Source code in src/lmxlab/training/mtp.py
lmxlab.training.mtp.MultiTokenPrediction
Bases: Module
Multi-Token Prediction wrapper around a LanguageModel.
Adds n_predict auxiliary prediction heads that predict future tokens at each position. Shares the base model's lm_head for logit projection.
Training-only module. At inference time, use the base model directly.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| model | LanguageModel | Base language model. | required |
| n_predict | int | Number of future tokens to predict. | 2 |
| mtp_weight | float | Weight for auxiliary MTP losses. | 0.3 |
| block_config | BlockConfig \| None | Block config for MTP heads. If None, uses the base model's block config. | None |
Source code in src/lmxlab/training/mtp.py
__init__(model, n_predict=2, mtp_weight=0.3, block_config=None)
Source code in src/lmxlab/training/mtp.py
__call__(x, targets)
Forward pass with multi-token prediction.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| x | array | Input token IDs (batch, seq_len). | required |
| targets | array | Target token IDs (batch, seq_len). For MTP depth k, predicts target at t+k. | required |

Returns:

| Type | Description |
|---|---|
| tuple[array, dict[str, array]] | Tuple of (main_logits, loss_dict), where loss_dict contains 'main_loss', 'mtp_loss', 'total_loss'. |
Source code in src/lmxlab/training/mtp.py
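The "depth k predicts target at t+k" convention amounts to a shift of the target array. A NumPy sketch of that indexing (`mtp_targets` is a hypothetical helper for illustration, not part of lmxlab):

```python
import numpy as np

def mtp_targets(targets, k):
    """Targets for an MTP head at depth k: position t predicts targets[t + k].

    Returns the shifted target array and the number of valid positions;
    the final k positions have no future token to predict and are dropped.
    """
    valid = targets.shape[-1] - k
    return targets[..., k:], valid

targets = np.array([10, 11, 12, 13, 14])
shifted, valid = mtp_targets(targets, k=2)   # depth-2 head looks 2 tokens ahead
```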
Curriculum Learning
lmxlab.training.curriculum
Curriculum learning utilities.
difficulty_curriculum(easy_data, hard_data, batch_size, seq_len, n_batches=200, warmup_fraction=0.5)
Mix easy and hard data with increasing difficulty.
Starts with mostly easy data and transitions to hard data.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| easy_data | array | Token array of easier text. | required |
| hard_data | array | Token array of harder text. | required |
| batch_size | int | Sequences per batch. | required |
| seq_len | int | Sequence length. | required |
| n_batches | int | Total number of batches. | 200 |
| warmup_fraction | float | Fraction of training spent warming up. | 0.5 |

Yields:

| Type | Description |
|---|---|
| tuple[array, array] | (input, target) tuples with mixed difficulty. |
Source code in src/lmxlab/training/curriculum.py
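One plausible mixing schedule: ramp the hard-data fraction from 0 to 1 over the warmup portion, then stay fully hard. The linear ramp here is an assumption about the actual curve, and `hard_fraction` is a hypothetical helper:

```python
def hard_fraction(batch_idx, n_batches=200, warmup_fraction=0.5):
    """Fraction of hard data in batch `batch_idx`: 0 at the start,
    ramping linearly to 1 over the warmup portion of training."""
    warmup_batches = int(n_batches * warmup_fraction)
    if batch_idx >= warmup_batches:
        return 1.0
    return batch_idx / warmup_batches

f_start = hard_fraction(0)     # all easy
f_mid = hard_fraction(50)      # halfway through warmup: 50/50 mix
f_late = hard_fraction(150)    # past warmup: all hard
```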
length_curriculum(tokens, batch_size, min_seq_len=32, max_seq_len=512, n_stages=4, batches_per_stage=100)
Generate batches with increasing sequence length.
Starts with short sequences and gradually increases, following curriculum learning principles.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| tokens | array | Flat array of token IDs. | required |
| batch_size | int | Sequences per batch. | required |
| min_seq_len | int | Starting sequence length. | 32 |
| max_seq_len | int | Final sequence length. | 512 |
| n_stages | int | Number of curriculum stages. | 4 |
| batches_per_stage | int | Batches per stage. | 100 |

Yields:

| Type | Description |
|---|---|
| tuple[array, array] | (input, target) tuples with progressively longer sequences. |
Source code in src/lmxlab/training/curriculum.py
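A sketch of how per-stage lengths might be spaced between `min_seq_len` and `max_seq_len`. Geometric interpolation is an assumption (the real schedule could just as well be linear), and `stage_seq_len` is a hypothetical helper:

```python
def stage_seq_len(stage, min_seq_len=32, max_seq_len=512, n_stages=4):
    """Sequence length for curriculum stage `stage` (0-indexed), spaced
    geometrically from min_seq_len up to max_seq_len."""
    if n_stages == 1:
        return max_seq_len
    ratio = (max_seq_len / min_seq_len) ** (stage / (n_stages - 1))
    return int(min_seq_len * ratio)

# With the defaults, the four stages roughly double each time.
lengths = [stage_seq_len(s) for s in range(4)]
```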
Knowledge Distillation
lmxlab.training.distillation
Knowledge distillation training losses.
Implements teacher-student distillation where a smaller student model learns from a larger teacher model's soft probability distributions (Hinton et al., 2015, arXiv:1503.02531).
The key insight: soft targets carry more information than hard labels. A teacher assigning 0.7 to "cat" and 0.2 to "kitten" teaches the student about word similarity, not just correctness.
Supported modes:
- Logit distillation: KL divergence between temperature-scaled teacher and student logits. The standard approach.
- Combined loss: Weighted mix of distillation loss and standard cross-entropy on hard targets. Balances learning from teacher with learning from ground truth.
Example:

```python
from lmxlab.training.distillation import distillation_loss

# Teacher is frozen, student is trained
loss = distillation_loss(
    student, teacher, tokens,
    temperature=4.0, alpha=0.7,
)
```
distillation_loss(student, teacher, tokens, temperature=4.0, alpha=0.7)
Compute combined distillation + hard-target loss.
Loss = alpha * KL(teacher || student) * T^2 + (1 - alpha) * CE(student, targets)
The T^2 factor compensates for the gradient magnitude reduction caused by temperature scaling (Hinton et al.).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| student | LanguageModel | Student model (being trained). | required |
| teacher | LanguageModel | Teacher model (frozen, no gradients). | required |
| tokens | array | Input token IDs (batch, seq_len). Targets are tokens shifted by one position. | required |
| temperature | float | Softmax temperature for soft targets. Higher = softer distributions, more knowledge transfer. Typical values: 2-10. | 4.0 |
| alpha | float | Weight for distillation loss (0-1). Higher means more reliance on teacher, less on hard targets. | 0.7 |

Returns:

| Type | Description |
|---|---|
| array | Scalar combined loss. |
Source code in src/lmxlab/training/distillation.py
soft_target_loss(student_logits, teacher_logits, temperature=4.0)
KL divergence between temperature-scaled distributions.
KL(teacher || student) computed on softened logits. Multiplied by T^2 to maintain gradient scale.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| student_logits | array | Student output (batch, seq_len, vocab). | required |
| teacher_logits | array | Teacher output (batch, seq_len, vocab). | required |
| temperature | float | Softmax temperature. | 4.0 |

Returns:

| Type | Description |
|---|---|
| array | Scalar KL divergence loss. |
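The temperature-scaled KL with the T^2 correction can be sketched in NumPy. `soft_target_loss_np` is an illustrative stand-in for the MLX implementation, using the numerically stable log-softmax formulation:

```python
import numpy as np

def soft_target_loss_np(student_logits, teacher_logits, temperature=4.0):
    """KL(teacher || student) on temperature-scaled logits, times T^2
    to keep gradient magnitudes comparable across temperatures."""
    def log_softmax(z):
        z = z / temperature
        z = z - z.max(axis=-1, keepdims=True)   # stability shift
        return z - np.log(np.exp(z).sum(axis=-1, keepdims=True))

    t_logp = log_softmax(teacher_logits)
    s_logp = log_softmax(student_logits)
    t_p = np.exp(t_logp)
    kl = (t_p * (t_logp - s_logp)).sum(axis=-1).mean()
    return kl * temperature ** 2

logits = np.array([[2.0, 1.0, 0.1]])
zero_loss = soft_target_loss_np(logits, logits)            # identical dists
pos_loss = soft_target_loss_np(np.array([[1.0, 0.0]]),
                               np.array([[0.0, 1.0]]))     # disagreement
```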