# First Training Run
This guide walks through training a small language model from scratch on Apple Silicon. The result is a model that memorizes a short text and generates completions from it.
## Prerequisites
- Preparing text data for training
- Using lmxlab's `Trainer` with `mx.compile`
- Monitoring loss curves and generating text
- MLX lazy evaluation and unified memory in the training loop
## The full script
The complete training script is shown below. The sections that follow explain each part.
```python
from dataclasses import replace

import mlx.core as mx

from lmxlab.data.batching import batch_iterator
from lmxlab.data.tokenizer import CharTokenizer
from lmxlab.models.base import LanguageModel
from lmxlab.models.generate import generate
from lmxlab.models.gpt import gpt_tiny
from lmxlab.training.callbacks import MetricsLogger, ThroughputMonitor
from lmxlab.training.config import TrainConfig
from lmxlab.training.trainer import Trainer

mx.random.seed(42)

# 1. Prepare data
text = (
    "To be, or not to be, that is the question: "
    "Whether 'tis nobler in the mind to suffer "
    "The slings and arrows of outrageous fortune, "
    "Or to take arms against a sea of troubles, "
    "And by opposing end them."
)
tokenizer = CharTokenizer(text)
tokens = mx.array(tokenizer.encode(text), dtype=mx.int32)

# 2. Build model
config = replace(gpt_tiny(), vocab_size=tokenizer.vocab_size)
model = LanguageModel(config)
mx.eval(model.parameters())

# 3. Train
train_config = TrainConfig(
    learning_rate=1e-3,
    max_steps=200,
    batch_size=4,
    log_interval=25,
    compile_step=False,
    warmup_steps=10,
)
trainer = Trainer(
    model,
    train_config,
    callbacks=[
        MetricsLogger(log_interval=25),
        ThroughputMonitor(log_interval=25, tokens_per_step=4 * 32),
    ],
)
history = trainer.train(
    batch_iterator(tokens, batch_size=4, seq_len=32)
)

# 4. Generate
prompt = mx.array([tokenizer.encode("To be")])
output = generate(model, prompt, max_tokens=60, temperature=0.8)
print(tokenizer.decode(output[0].tolist()))
```
Save the script to a file and run it with Python.
## Step-by-step walkthrough

### 1. Prepare text data
```python
import mlx.core as mx

from lmxlab.data.tokenizer import CharTokenizer

text = "To be, or not to be, that is the question..."
tokenizer = CharTokenizer(text)
tokens = mx.array(tokenizer.encode(text), dtype=mx.int32)
```
`CharTokenizer` maps each unique character to an integer ID. It builds its vocabulary from the input text, so `vocab_size` equals the number of distinct characters. For real training, a BPE tokenizer such as `TiktokenTokenizer('gpt2')` is more appropriate.
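The mapping is easy to see in miniature. Here is a plain-Python sketch of the same idea (not lmxlab's actual implementation):

```python
class TinyCharTokenizer:
    """Minimal character-level tokenizer: one integer ID per distinct char."""

    def __init__(self, text: str):
        # Sort for a deterministic vocabulary order.
        vocab = sorted(set(text))
        self.stoi = {ch: i for i, ch in enumerate(vocab)}
        self.itos = {i: ch for ch, i in self.stoi.items()}
        self.vocab_size = len(vocab)

    def encode(self, s: str) -> list:
        return [self.stoi[ch] for ch in s]

    def decode(self, ids: list) -> str:
        return "".join(self.itos[i] for i in ids)


tok = TinyCharTokenizer("To be, or not to be")
assert tok.decode(tok.encode("To be")) == "To be"  # round-trip is lossless
```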
`batch_iterator` takes this flat token array and creates sliding windows of (input, target) pairs:
```python
from lmxlab.data.batching import batch_iterator

for x, y in batch_iterator(tokens, batch_size=4, seq_len=32):
    # x shape: (4, 32) -- 4 sequences of 32 tokens
    # y shape: (4, 32) -- shifted by 1 position
    break
```
Each target token is the next token after the corresponding input token. This is the standard language modeling objective: predict the next token at every position.
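The shift is easier to see with concrete numbers. A plain-Python sketch of the windowing (lmxlab's `batch_iterator` additionally stacks windows into batches):

```python
def windows(tokens, seq_len):
    """Yield (input, target) pairs; target is input shifted one token right."""
    for start in range(len(tokens) - seq_len):
        x = tokens[start : start + seq_len]
        y = tokens[start + 1 : start + seq_len + 1]
        yield x, y


toks = [5, 1, 4, 1, 5, 9, 2, 6]
x, y = next(windows(toks, seq_len=4))
print(x, y)  # [5, 1, 4, 1] [1, 4, 1, 5]
```

At every position `i`, the model is asked to predict `y[i]`, which is exactly `x[i + 1]`.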
### 2. Build the model
```python
from dataclasses import replace

import mlx.core as mx

from lmxlab.models.gpt import gpt_tiny
from lmxlab.models.base import LanguageModel

config = replace(gpt_tiny(), vocab_size=tokenizer.vocab_size)
model = LanguageModel(config)
mx.eval(model.parameters())
```
`gpt_tiny()` returns a `ModelConfig` with small dimensions (`d_model=64`, 2 layers, 4 heads). The `vocab_size` is overridden to match the tokenizer using `dataclasses.replace`.

The `mx.eval(model.parameters())` call materializes the weights. MLX is lazy by default: without this call, the random weight tensors would not be computed until first use. Evaluating them up front establishes a clean baseline.
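The idea behind lazy evaluation can be conveyed with a toy thunk in plain Python (MLX's real graph machinery is far more general than this sketch):

```python
class Lazy:
    """Toy lazy value: records how to compute, runs nothing until eval()."""

    def __init__(self, fn):
        self.fn = fn
        self.value = None
        self.computed = False

    def eval(self):
        if not self.computed:  # compute at most once
            self.value = self.fn()
            self.computed = True
        return self.value


log = []
weights = Lazy(lambda: (log.append("init"), [0.1, 0.2])[1])

assert log == []       # nothing has run yet -- like MLX before mx.eval(...)
weights.eval()         # materializes the value
assert log == ["init"]
```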
#### Swapping architectures
To use LLaMA instead of GPT, replace the config factory. The rest of the code is identical:
```python
from lmxlab.models.llama import llama_tiny

config = replace(llama_tiny(), vocab_size=tokenizer.vocab_size)
```
The same `LanguageModel` class handles both because architecture differences live in `BlockConfig` string names (`'mha'` vs `'gqa'`, `'layer_norm'` vs `'rms_norm'`, etc.), resolved by the registry at construction time.
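A registry of this kind is essentially a dictionary from string names to classes. A hypothetical sketch (class and function names here are illustrative, not lmxlab's actual registry):

```python
# Hypothetical registry: string names resolve to layer classes at build time.
class LayerNorm: ...
class RMSNorm: ...

NORM_REGISTRY = {"layer_norm": LayerNorm, "rms_norm": RMSNorm}


def build_norm(name):
    try:
        return NORM_REGISTRY[name]()
    except KeyError:
        raise ValueError(f"unknown norm: {name!r}")


assert isinstance(build_norm("rms_norm"), RMSNorm)
```

Because the config stores only strings, swapping `'layer_norm'` for `'rms_norm'` requires no change to the model class itself.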
### 3. Configure and run training
```python
from lmxlab.training.config import TrainConfig
from lmxlab.training.trainer import Trainer

train_config = TrainConfig(
    learning_rate=1e-3,
    max_steps=200,
    batch_size=4,
    log_interval=25,
    compile_step=False,
    warmup_steps=10,
)
trainer = Trainer(model, train_config)
history = trainer.train(
    batch_iterator(tokens, batch_size=4, seq_len=32)
)
```
The `Trainer` performs the following steps:

- Wraps the loss function with `nn.value_and_grad` for functional gradient computation
- Optionally wraps the full step (forward + backward + update) with `mx.compile` for hardware-fused execution
- On each step: computes loss and gradients, clips gradients, and updates weights via AdamW
- Calls `mx.eval(loss, model.parameters(), optimizer.state)` at the eval boundary, which triggers the actual GPU computation
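One of those steps, gradient clipping, is compact enough to sketch in plain Python. This shows clipping by global L2 norm, the behavior the `max_grad_norm` option suggests (the exact scheme lmxlab uses is an assumption):

```python
import math


def clip_by_global_norm(grads, max_norm):
    """Scale all gradients uniformly if their global L2 norm exceeds max_norm."""
    norm = math.sqrt(sum(g * g for g in grads))
    if norm <= max_norm:
        return grads  # within budget: leave untouched
    scale = max_norm / norm
    return [g * scale for g in grads]


clipped = clip_by_global_norm([3.0, 4.0], max_norm=1.0)  # original norm: 5.0
assert abs(clipped[0] - 0.6) < 1e-9 and abs(clipped[1] - 0.8) < 1e-9
```

Scaling the whole gradient vector (rather than clipping each element) preserves the update direction while bounding its magnitude.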
#### compile_step=False for tiny models

Compilation is disabled here because the tiny model runs in microseconds, so compilation overhead dominates. For real models (millions of parameters), set `compile_step=True` for a significant speedup. See Compiled Training.
Training config options:

| Parameter | Default | Description |
|---|---|---|
| `learning_rate` | `3e-4` | Peak learning rate |
| `optimizer` | `'adamw'` | Also: `'lion'`, `'adafactor'`, `'sgd'` |
| `lr_schedule` | `'cosine'` | Also: `'linear'`, `'constant'` |
| `warmup_steps` | `100` | Linear warmup before decay |
| `max_grad_norm` | `1.0` | Gradient clipping threshold |
| `grad_accumulation_steps` | `1` | Micro-batches per update |
| `compile_step` | `True` | Use `mx.compile` |
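The default `'cosine'` schedule with linear warmup is commonly implemented as below; the exact shape lmxlab uses (e.g. whether it decays to zero) is an assumption in this sketch:

```python
import math


def lr_at(step, peak_lr, warmup_steps, max_steps):
    """Linear warmup to peak_lr, then cosine decay to zero at max_steps."""
    if step < warmup_steps:
        return peak_lr * (step + 1) / warmup_steps
    progress = (step - warmup_steps) / max(1, max_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))


assert lr_at(0, 3e-4, warmup_steps=100, max_steps=1000) < 3e-4   # warming up
assert abs(lr_at(99, 3e-4, 100, 1000) - 3e-4) < 1e-12            # at peak
assert lr_at(999, 3e-4, 100, 1000) < 1e-5                        # near zero
```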
### 4. Monitor training
Add callbacks to observe training dynamics:
```python
from lmxlab.training.callbacks import (
    MetricsLogger,
    ThroughputMonitor,
)

trainer = Trainer(
    model,
    train_config,
    callbacks=[
        MetricsLogger(log_interval=25),
        ThroughputMonitor(
            log_interval=25,
            tokens_per_step=4 * 32,  # batch_size * seq_len
        ),
    ],
)
```
`MetricsLogger` prints loss and learning rate at each interval. `ThroughputMonitor` reports tokens per second, which is useful for comparing compiled vs. uncompiled steps or different model sizes.
Expected training behavior:

- Loss starts high (approximately 4-5 for character-level, 10-11 for BPE with a large vocabulary), corresponding to `-log(1/vocab_size)`.
- Loss drops in the first 50 or so steps as the model learns character frequencies.
- Loss plateaus around 1.0-2.0 for this tiny dataset. With more data and a larger model, it would continue decreasing.
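The quoted starting losses follow directly from the uniform-guess cross-entropy, `-log(1/V) = log(V)`:

```python
import math

# Cross-entropy of a uniform prediction over a vocabulary of size V is log(V).
print(math.log(65))     # ~4.2  (a typical character-level vocabulary size)
print(math.log(50257))  # ~10.8 (GPT-2 BPE vocabulary)
```

An untrained model's first-step loss should sit near these values; a much higher number usually points at a bug in data preparation or initialization.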
### 5. Generate text
```python
from lmxlab.models.generate import generate

prompt = mx.array([tokenizer.encode("To be")])
output = generate(
    model, prompt,
    max_tokens=60,
    temperature=0.8,
    top_k=10,
)
print(tokenizer.decode(output[0].tolist()))
```
Generation uses KV caching: the prompt is processed in one forward pass (prefill), then each new token reuses cached key/value projections. This reduces the per-token cost from O(n^2) to O(n).
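A toy sketch of the cache's append-only behavior (scalar "projections" stand in for real key/value tensors; this is not lmxlab's cache implementation):

```python
class ToyKVCache:
    """Append-only cache: each decode step adds one key/value pair."""

    def __init__(self):
        self.keys = []
        self.values = []

    def step(self, k, v):
        # The new token attends over ALL cached positions, but we only
        # computed projections for the single new token: O(n) per step.
        self.keys.append(k)
        self.values.append(v)
        return self.keys, self.values


cache = ToyKVCache()
for t in range(3):
    ks, vs = cache.step(k=float(t), v=float(t) * 10)

print(len(ks))  # 3 cached positions after 3 decode steps
```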
Sampling parameters:

| Parameter | Effect |
|---|---|
| `temperature=0.0` | Greedy (always pick the most likely token) |
| `temperature=0.8` | Balanced creativity |
| `temperature=1.5` | More random, more diverse |
| `top_k=10` | Only consider the top 10 most likely tokens |
| `top_p=0.95` | Nucleus sampling (dynamic vocabulary cutoff) |
| `repetition_penalty=1.2` | Penalize tokens already generated |
For a tiny model trained on a single paragraph, the output will approximately reproduce the training text with some variation.
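Temperature and top-k are transformations applied to the logits before sampling. A self-contained sketch of the standard technique (lmxlab's exact implementation may differ in details):

```python
import math
import random


def sample(logits, temperature=1.0, top_k=None):
    """Sample a token ID with temperature scaling and top-k filtering."""
    if temperature == 0.0:
        # Greedy: argmax, no randomness.
        return max(range(len(logits)), key=lambda i: logits[i])
    if top_k is not None:
        # Mask everything below the k-th largest logit.
        kth = sorted(logits, reverse=True)[top_k - 1]
        logits = [l if l >= kth else float("-inf") for l in logits]
    scaled = [l / temperature for l in logits]
    m = max(scaled)
    exps = [math.exp(s - m) for s in scaled]  # numerically stable softmax
    total = sum(exps)
    probs = [e / total for e in exps]
    return random.choices(range(len(probs)), weights=probs)[0]


random.seed(0)
assert sample([2.0, 1.0, 0.1], temperature=0.0) == 0          # greedy
assert sample([2.0, 1.0, 0.1], temperature=0.8, top_k=2) in (0, 1)
```

Lower temperature sharpens the distribution toward the argmax; `top_k` removes the long tail of unlikely tokens entirely.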
## Evaluating the model
After training, measure quality with perplexity or bits-per-byte:
```python
from lmxlab.eval.metrics import perplexity, bits_per_byte

# Create eval batches
eval_batches = list(
    batch_iterator(tokens, batch_size=2, seq_len=32, shuffle=False)
)
ppl = perplexity(
    model,
    [mx.concatenate([x, y[:, -1:]], axis=1) for x, y in eval_batches],
)
print(f"Perplexity: {ppl:.2f}")
```
Lower perplexity indicates higher confidence in predictions. A model that perfectly memorized the training data would approach perplexity 1.0.
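The relationship between loss and perplexity is just an exponential: perplexity is `exp` of the mean per-token negative log-likelihood. A sketch:

```python
import math


def perplexity_from_nll(token_nlls):
    """Perplexity is exp(mean negative log-likelihood per token)."""
    return math.exp(sum(token_nlls) / len(token_nlls))


# A model that assigns probability 1.0 to every correct token: NLL = 0.
assert perplexity_from_nll([0.0, 0.0, 0.0]) == 1.0

# Uniform guessing over a 65-token vocabulary: NLL = log(65) per token,
# so perplexity equals the vocabulary size.
print(perplexity_from_nll([math.log(65)] * 3))  # ~65.0
```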
## Scaling up
The following modifications extend this example to realistic training settings.
**Larger model:**
```python
from lmxlab.models.llama import llama_config

config = llama_config(
    vocab_size=32000,
    d_model=512,
    n_heads=8,
    n_kv_heads=4,
    n_layers=6,
    d_ff=1376,
)
```
**BPE tokenizer:**
```python
from lmxlab.data.tokenizer import TiktokenTokenizer

tokenizer = TiktokenTokenizer('gpt2')  # 50257 tokens
```
**Real text data:**
```python
from lmxlab.data.dataset import TextDataset

dataset = TextDataset('path/to/text.txt', tokenizer, seq_len=128)
```
**Compiled training (for larger models):**
```python
train_config = TrainConfig(
    compile_step=True,  # Fuse forward+backward+update
    learning_rate=3e-4,
    max_steps=5000,
    batch_size=32,
)
```
**Gradient accumulation (when a batch does not fit in memory):**
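A sketch using the `grad_accumulation_steps` option from the config table above; the specific values here are illustrative:

```python
train_config = TrainConfig(
    learning_rate=3e-4,
    max_steps=5000,
    batch_size=8,               # micro-batch that fits in memory
    grad_accumulation_steps=4,  # effective batch size: 8 * 4 = 32
)
```

Gradients from each micro-batch are summed, and the optimizer applies one update per accumulation window, approximating the larger batch.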
**Checkpointing:**
```python
from lmxlab.training.checkpoints import save_checkpoint

save_checkpoint(model, trainer.optimizer, trainer.step, 'ckpt/')
```
## Next steps

- Quickstart -- Forward passes and generation without training
- Compiled Training -- How `mx.compile` speeds up the training loop
- MLX Idioms -- Lazy evaluation, eval boundaries, and unified memory
- Recipes -- 30+ scripts for training, fine-tuning, and evaluation