Skip to content

Training

Compiled training loop, optimizers, and training utilities.

Trainer

lmxlab.training.trainer.Trainer

Training loop with compiled steps and gradient accumulation.

Uses nn.value_and_grad for functional gradient computation and mx.compile for the full training step. When grad_accumulation_steps > 1, gradients are averaged over multiple micro-batches before a single optimizer update.

Parameters:

Name Type Description Default
model LanguageModel

Language model to train.

required
config TrainConfig

Training configuration.

required
optimizer Optimizer | None

Optional pre-built optimizer.

None
callbacks list[Callback] | None

Optional list of callbacks.

None
Source code in src/lmxlab/training/trainer.py
class Trainer:
    """Training loop with compiled steps and gradient accumulation.

    Uses nn.value_and_grad for functional gradient computation
    and mx.compile for the full training step. When
    ``grad_accumulation_steps > 1``, gradients are averaged over
    multiple micro-batches before a single optimizer update.

    Args:
        model: Language model to train.
        config: Training configuration.
        optimizer: Optional pre-built optimizer.
        callbacks: Optional list of callbacks.
    """

    def __init__(
        self,
        model: LanguageModel,
        config: TrainConfig,
        optimizer: optim.Optimizer | None = None,
        callbacks: list[Callback] | None = None,
    ) -> None:
        self.model = model
        self.config = config
        if optimizer is not None:
            self.optimizer = optimizer
        elif model.config.mup_base_width is not None:
            # μP model: use the width-scaled optimizer factory.
            self.optimizer = create_mup_optimizer(
                config,
                model.config.width_mult,
            )
        else:
            self.optimizer = create_optimizer(config)
        self.callbacks = callbacks or []
        self.step = 0
        self._accum_steps = config.grad_accumulation_steps

        # Build the training step function
        self._loss_and_grad = nn.value_and_grad(model, _loss_fn)

        # BUGFIX: always define _step_fn. Previously it was only set
        # when _accum_steps <= 1, so a direct train_step() call on a
        # trainer configured with grad accumulation raised
        # AttributeError instead of running an uncompiled step.
        self._step_fn = self._single_step

        if self._accum_steps <= 1:
            # No accumulation: compile full step (fwd + bwd + update)
            # MultiOptimizer (μP) is incompatible with mx.compile
            can_compile = config.compile_step and not isinstance(
                self.optimizer,
                optim.MultiOptimizer,
            )
            if can_compile:
                self._step_fn = mx.compile(
                    self._single_step,
                    inputs=model.trainable_parameters(),
                    outputs=model.trainable_parameters(),
                )

    def _single_step(
        self,
        x: mx.array,
        y: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Single training step: forward + backward + update.

        Args:
            x: Input tokens (batch, seq_len).
            y: Target tokens (batch, seq_len).

        Returns:
            Tuple of (loss, grad_norm) where grad_norm is the
            L2 norm of gradients before clipping.
        """
        loss, grads = self._loss_and_grad(self.model, x, y)

        # Gradient clipping
        if self.config.max_grad_norm > 0:
            grads, grad_norm = optim.clip_grad_norm(
                grads, max_norm=self.config.max_grad_norm
            )
        else:
            # Clipping disabled: compute the raw L2 norm for reporting.
            flat = tree_flatten(grads)
            grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in flat))

        self.optimizer.update(self.model, grads)
        return loss, grad_norm

    def _accumulation_step(
        self,
        micro_batches: list[tuple[mx.array, mx.array]],
    ) -> tuple[mx.array, mx.array]:
        """Gradient accumulation step over multiple micro-batches.

        Computes gradients for each micro-batch, averages them,
        then applies a single optimizer update.

        Args:
            micro_batches: List of (input, target) micro-batches.

        Returns:
            Tuple of (avg_loss, grad_norm).
        """
        n = len(micro_batches)
        total_loss = mx.array(0.0)
        acc_grads = None

        for x, y in micro_batches:
            loss, grads = self._loss_and_grad(self.model, x, y)
            total_loss = total_loss + loss
            if acc_grads is None:
                acc_grads = grads
            else:
                # Elementwise sum of the two gradient trees.
                acc_grads = tree_map(
                    lambda a, b: a + b,
                    acc_grads,
                    grads,
                )

        # Average gradients
        avg_grads = tree_map(lambda g: g / n, acc_grads)
        avg_loss = total_loss / n

        # Gradient clipping
        if self.config.max_grad_norm > 0:
            avg_grads, grad_norm = optim.clip_grad_norm(
                avg_grads, max_norm=self.config.max_grad_norm
            )
        else:
            # Clipping disabled: compute the raw L2 norm for reporting.
            flat = tree_flatten(avg_grads)
            grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in flat))

        self.optimizer.update(self.model, avg_grads)
        return avg_loss, grad_norm

    def train_step(
        self,
        batch: tuple[mx.array, mx.array],
    ) -> dict[str, float]:
        """Execute one training step with eval boundary.

        Args:
            batch: Tuple of (input_tokens, target_tokens).

        Returns:
            Dict with 'loss', 'learning_rate', and
            'grad_norm'.
        """
        x, y = batch
        loss, grad_norm = self._step_fn(x, y)

        # Explicit eval boundary: force lazy computation to complete
        # before reading scalar metrics.
        mx.eval(
            loss,
            grad_norm,
            self.model.parameters(),
            self.optimizer.state,
        )

        self.step += 1
        lr = self.optimizer.learning_rate
        if callable(lr):
            # Schedules are callables of the step index.
            lr = lr(self.step)

        metrics = {
            "loss": loss.item(),
            "learning_rate": float(lr),
            "grad_norm": grad_norm.item(),
        }

        for cb in self.callbacks:
            cb.on_step_end(self.step, metrics)

        return metrics

    def train_step_accumulated(
        self,
        micro_batches: list[tuple[mx.array, mx.array]],
    ) -> dict[str, float]:
        """Execute one accumulated training step.

        Averages gradients over micro-batches, then updates once.

        Args:
            micro_batches: List of (input, target) micro-batches.

        Returns:
            Dict with 'loss', 'learning_rate', and
            'grad_norm'.
        """
        loss, grad_norm = self._accumulation_step(
            micro_batches,
        )

        # Explicit eval boundary
        mx.eval(
            loss,
            grad_norm,
            self.model.parameters(),
            self.optimizer.state,
        )

        self.step += 1
        lr = self.optimizer.learning_rate
        if callable(lr):
            lr = lr(self.step)

        metrics = {
            "loss": loss.item(),
            "learning_rate": float(lr),
            "grad_norm": grad_norm.item(),
        }

        for cb in self.callbacks:
            cb.on_step_end(self.step, metrics)

        return metrics

    def train(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None = None,
    ) -> list[dict[str, Any]]:
        """Run the full training loop.

        When ``grad_accumulation_steps > 1``, collects that many
        micro-batches before each optimizer update.

        Args:
            train_data: Iterator yielding (input, target) batches.
            eval_data: Optional eval data iterator.

        Returns:
            List of per-step metrics dicts.
        """
        for cb in self.callbacks:
            cb.on_train_begin(self.config)

        history: list[dict[str, Any]] = []

        if self._accum_steps > 1:
            history = self._train_accumulated(train_data, eval_data)
        else:
            history = self._train_simple(train_data, eval_data)

        for cb in self.callbacks:
            cb.on_train_end(history)

        return history

    def _train_simple(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    ) -> list[dict[str, Any]]:
        """Training loop without gradient accumulation."""
        history: list[dict[str, Any]] = []

        for batch in train_data:
            if self.step >= self.config.max_steps:
                break

            metrics = self.train_step(batch)
            history.append(metrics)
            self._maybe_eval(eval_data, metrics)

        return history

    def _train_accumulated(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    ) -> list[dict[str, Any]]:
        """Training loop with gradient accumulation."""
        history: list[dict[str, Any]] = []
        micro_batches: list[tuple[mx.array, mx.array]] = []

        for batch in train_data:
            if self.step >= self.config.max_steps:
                break

            micro_batches.append(batch)

            if len(micro_batches) == self._accum_steps:
                metrics = self.train_step_accumulated(micro_batches)
                history.append(metrics)
                self._maybe_eval(eval_data, metrics)
                micro_batches = []

        # Handle remaining micro-batches as one final (smaller) step.
        # NOTE(review): this trailing step intentionally skips
        # _maybe_eval, matching the original behavior.
        if micro_batches and self.step < self.config.max_steps:
            metrics = self.train_step_accumulated(micro_batches)
            history.append(metrics)

        return history

    def _maybe_eval(
        self,
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
        metrics: dict[str, Any],
    ) -> None:
        """Run evaluation if due."""
        if (
            eval_data is not None
            and self.step % self.config.eval_interval == 0
        ):
            eval_metrics = self.evaluate(eval_data)
            metrics.update(eval_metrics)
            for cb in self.callbacks:
                cb.on_eval_end(self.step, eval_metrics)

    def evaluate(
        self,
        eval_data: Iterator[tuple[mx.array, mx.array]],
    ) -> dict[str, float]:
        """Run evaluation over the eval dataset.

        Args:
            eval_data: Iterator yielding (input, target) batches.

        Returns:
            Dict with 'eval_loss'.
        """
        total_loss = 0.0
        n_batches = 0

        for x, y in eval_data:
            logits, _ = self.model(x)
            logits = logits.reshape(-1, logits.shape[-1])
            targets = y.reshape(-1)
            loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
            mx.eval(loss)
            total_loss += loss.item()
            n_batches += 1

        # max(..., 1) guards against an empty eval iterator.
        avg_loss = total_loss / max(n_batches, 1)
        return {"eval_loss": avg_loss}

_accum_steps = config.grad_accumulation_steps instance-attribute

_loss_and_grad = nn.value_and_grad(model, _loss_fn) instance-attribute

_step_fn = mx.compile(self._single_step, inputs=(model.trainable_parameters()), outputs=(model.trainable_parameters())) instance-attribute

callbacks = callbacks or [] instance-attribute

config = config instance-attribute

model = model instance-attribute

optimizer = optimizer instance-attribute

step = 0 instance-attribute

__init__(model, config, optimizer=None, callbacks=None)

Source code in src/lmxlab/training/trainer.py
def __init__(
    self,
    model: LanguageModel,
    config: TrainConfig,
    optimizer: optim.Optimizer | None = None,
    callbacks: list[Callback] | None = None,
) -> None:
    """Set up the model, optimizer, callbacks, and step function."""
    self.model = model
    self.config = config
    if optimizer is not None:
        self.optimizer = optimizer
    elif model.config.mup_base_width is not None:
        # μP model: use the width-scaled optimizer factory.
        self.optimizer = create_mup_optimizer(
            config,
            model.config.width_mult,
        )
    else:
        self.optimizer = create_optimizer(config)
    self.callbacks = callbacks or []
    self.step = 0
    self._accum_steps = config.grad_accumulation_steps

    # Build the training step function
    self._loss_and_grad = nn.value_and_grad(model, _loss_fn)

    # BUGFIX: always define _step_fn. Previously it was only set when
    # _accum_steps <= 1, so a direct train_step() call on a trainer
    # configured with grad accumulation raised AttributeError.
    self._step_fn = self._single_step

    if self._accum_steps <= 1:
        # No accumulation: compile full step (fwd + bwd + update)
        # MultiOptimizer (μP) is incompatible with mx.compile
        can_compile = config.compile_step and not isinstance(
            self.optimizer,
            optim.MultiOptimizer,
        )
        if can_compile:
            self._step_fn = mx.compile(
                self._single_step,
                inputs=model.trainable_parameters(),
                outputs=model.trainable_parameters(),
            )
_accumulation_step(micro_batches)

Gradient accumulation step over multiple micro-batches.

Computes gradients for each micro-batch, averages them, then applies a single optimizer update.

Parameters:

Name Type Description Default
micro_batches list[tuple[array, array]]

List of (input, target) micro-batches.

required

Returns:

Type Description
tuple[array, array]

Tuple of (avg_loss, grad_norm).

Source code in src/lmxlab/training/trainer.py
def _accumulation_step(
    self,
    micro_batches: list[tuple[mx.array, mx.array]],
) -> tuple[mx.array, mx.array]:
    """Accumulate gradients over several micro-batches, then update once.

    Per-micro-batch gradients are summed as they are produced and
    divided by the micro-batch count before clipping and the single
    optimizer update.

    Args:
        micro_batches: List of (input, target) micro-batches.

    Returns:
        Tuple of (avg_loss, grad_norm).
    """
    count = len(micro_batches)
    loss_sum = mx.array(0.0)
    grad_sum = None

    for inputs, targets in micro_batches:
        loss, grads = self._loss_and_grad(self.model, inputs, targets)
        loss_sum = loss_sum + loss
        grad_sum = (
            grads
            if grad_sum is None
            else tree_map(lambda a, b: a + b, grad_sum, grads)
        )

    # Mean over micro-batches.
    mean_grads = tree_map(lambda g: g / count, grad_sum)
    mean_loss = loss_sum / count

    if self.config.max_grad_norm > 0:
        mean_grads, grad_norm = optim.clip_grad_norm(
            mean_grads, max_norm=self.config.max_grad_norm
        )
    else:
        # Clipping disabled: still report the raw L2 norm.
        leaves = tree_flatten(mean_grads)
        grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in leaves))

    self.optimizer.update(self.model, mean_grads)
    return mean_loss, grad_norm

_maybe_eval(eval_data, metrics)

Run evaluation if due.

Source code in src/lmxlab/training/trainer.py
def _maybe_eval(
    self,
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    metrics: dict[str, Any],
) -> None:
    """Run evaluation if due, merging eval metrics into ``metrics``."""
    if eval_data is None:
        return
    if self.step % self.config.eval_interval != 0:
        return
    eval_metrics = self.evaluate(eval_data)
    metrics.update(eval_metrics)
    for cb in self.callbacks:
        cb.on_eval_end(self.step, eval_metrics)

_single_step(x, y)

Single training step: forward + backward + update.

Parameters:

Name Type Description Default
x array

Input tokens (batch, seq_len).

required
y array

Target tokens (batch, seq_len).

required

Returns:

Type Description
tuple[array, array]

Tuple of (loss, grad_norm) where grad_norm is the
L2 norm of gradients before clipping.

Source code in src/lmxlab/training/trainer.py
def _single_step(
    self,
    x: mx.array,
    y: mx.array,
) -> tuple[mx.array, mx.array]:
    """Forward, backward, and optimizer update for one batch.

    Args:
        x: Input tokens (batch, seq_len).
        y: Target tokens (batch, seq_len).

    Returns:
        Tuple of (loss, grad_norm) where grad_norm is the
        L2 norm of gradients before clipping.
    """
    loss, grads = self._loss_and_grad(self.model, x, y)

    max_norm = self.config.max_grad_norm
    if max_norm > 0:
        grads, grad_norm = optim.clip_grad_norm(grads, max_norm=max_norm)
    else:
        # Clipping disabled: compute the norm only for reporting.
        leaves = tree_flatten(grads)
        grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in leaves))

    self.optimizer.update(self.model, grads)
    return loss, grad_norm

_train_accumulated(train_data, eval_data)

Training loop with gradient accumulation.

Source code in src/lmxlab/training/trainer.py
def _train_accumulated(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
) -> list[dict[str, Any]]:
    """Training loop with gradient accumulation."""
    history: list[dict[str, Any]] = []
    pending: list[tuple[mx.array, mx.array]] = []

    for batch in train_data:
        if self.step >= self.config.max_steps:
            break

        pending.append(batch)
        if len(pending) < self._accum_steps:
            continue

        # Full accumulation window collected: take one optimizer step.
        step_metrics = self.train_step_accumulated(pending)
        history.append(step_metrics)
        self._maybe_eval(eval_data, step_metrics)
        pending = []

    # Flush any leftover micro-batches as one final (smaller) step.
    if pending and self.step < self.config.max_steps:
        history.append(self.train_step_accumulated(pending))

    return history

_train_simple(train_data, eval_data)

Training loop without gradient accumulation.

Source code in src/lmxlab/training/trainer.py
def _train_simple(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
) -> list[dict[str, Any]]:
    """Training loop without gradient accumulation."""
    history: list[dict[str, Any]] = []

    for batch in train_data:
        # Stop once the configured step budget is spent.
        if self.step >= self.config.max_steps:
            break
        step_metrics = self.train_step(batch)
        history.append(step_metrics)
        self._maybe_eval(eval_data, step_metrics)

    return history

evaluate(eval_data)

Run evaluation over the eval dataset.

Parameters:

Name Type Description Default
eval_data Iterator[tuple[array, array]]

Iterator yielding (input, target) batches.

required

Returns:

Type Description
dict[str, float]

Dict with 'eval_loss'.

Source code in src/lmxlab/training/trainer.py
def evaluate(
    self,
    eval_data: Iterator[tuple[mx.array, mx.array]],
) -> dict[str, float]:
    """Compute the mean cross-entropy loss over the eval dataset.

    Args:
        eval_data: Iterator yielding (input, target) batches.

    Returns:
        Dict with 'eval_loss'.
    """
    loss_sum = 0.0
    batch_count = 0

    for inputs, targets in eval_data:
        logits, _ = self.model(inputs)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_targets = targets.reshape(-1)
        batch_loss = nn.losses.cross_entropy(
            flat_logits, flat_targets, reduction="mean"
        )
        mx.eval(batch_loss)
        loss_sum += batch_loss.item()
        batch_count += 1

    # max(..., 1) guards against an empty eval iterator.
    return {"eval_loss": loss_sum / max(batch_count, 1)}

train(train_data, eval_data=None)

Run the full training loop.

When grad_accumulation_steps > 1, collects that many micro-batches before each optimizer update.

Parameters:

Name Type Description Default
train_data Iterator[tuple[array, array]]

Iterator yielding (input, target) batches.

required
eval_data Iterator[tuple[array, array]] | None

Optional eval data iterator.

None

Returns:

Type Description
list[dict[str, Any]]

List of per-step metrics dicts.

Source code in src/lmxlab/training/trainer.py
def train(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None = None,
) -> list[dict[str, Any]]:
    """Run the full training loop.

    When ``grad_accumulation_steps > 1``, collects that many
    micro-batches before each optimizer update.

    Args:
        train_data: Iterator yielding (input, target) batches.
        eval_data: Optional eval data iterator.

    Returns:
        List of per-step metrics dicts.
    """
    for cb in self.callbacks:
        cb.on_train_begin(self.config)

    # Select the loop variant based on the accumulation setting.
    run_loop = (
        self._train_accumulated
        if self._accum_steps > 1
        else self._train_simple
    )
    history = run_loop(train_data, eval_data)

    for cb in self.callbacks:
        cb.on_train_end(history)

    return history

train_step(batch)

Execute one training step with eval boundary.

Parameters:

Name Type Description Default
batch tuple[array, array]

Tuple of (input_tokens, target_tokens).

required

Returns:

Type Description
dict[str, float]

Dict with 'loss', 'learning_rate', and 'grad_norm'.

Source code in src/lmxlab/training/trainer.py
def train_step(
    self,
    batch: tuple[mx.array, mx.array],
) -> dict[str, float]:
    """Execute one training step with eval boundary.

    Args:
        batch: Tuple of (input_tokens, target_tokens).

    Returns:
        Dict with 'loss', 'learning_rate', and 'grad_norm'.
    """
    inputs, targets = batch
    loss, grad_norm = self._step_fn(inputs, targets)

    # Explicit eval boundary: force lazy computation to finish
    # before reading scalar metrics.
    mx.eval(
        loss,
        grad_norm,
        self.model.parameters(),
        self.optimizer.state,
    )

    self.step += 1
    lr = self.optimizer.learning_rate
    if callable(lr):
        # Schedules are callables of the step index.
        lr = lr(self.step)

    metrics = {
        "loss": loss.item(),
        "learning_rate": float(lr),
        "grad_norm": grad_norm.item(),
    }

    for cb in self.callbacks:
        cb.on_step_end(self.step, metrics)

    return metrics

train_step_accumulated(micro_batches)

Execute one accumulated training step.

Averages gradients over micro-batches, then updates once.

Parameters:

Name Type Description Default
micro_batches list[tuple[array, array]]

List of (input, target) micro-batches.

required

Returns:

Type Description
dict[str, float]

Dict with 'loss', 'learning_rate', and 'grad_norm'.

Source code in src/lmxlab/training/trainer.py
def train_step_accumulated(
    self,
    micro_batches: list[tuple[mx.array, mx.array]],
) -> dict[str, float]:
    """Execute one accumulated training step.

    Averages gradients over micro-batches, then updates once.

    Args:
        micro_batches: List of (input, target) micro-batches.

    Returns:
        Dict with 'loss', 'learning_rate', and 'grad_norm'.
    """
    loss, grad_norm = self._accumulation_step(micro_batches)

    # Explicit eval boundary: force lazy computation to finish
    # before reading scalar metrics.
    mx.eval(
        loss,
        grad_norm,
        self.model.parameters(),
        self.optimizer.state,
    )

    self.step += 1
    lr = self.optimizer.learning_rate
    if callable(lr):
        # Schedules are callables of the step index.
        lr = lr(self.step)

    metrics = {
        "loss": loss.item(),
        "learning_rate": float(lr),
        "grad_norm": grad_norm.item(),
    }

    for cb in self.callbacks:
        cb.on_step_end(self.step, metrics)

    return metrics

Training Config

lmxlab.training.config.TrainConfig dataclass

Configuration for the training loop.

Parameters:

Name Type Description Default
learning_rate float

Peak learning rate.

0.0003
weight_decay float

Weight decay coefficient.

0.01
warmup_steps int

Linear warmup steps.

100
max_steps int

Maximum training steps.

1000
batch_size int

Training batch size.

32
grad_accumulation_steps int

Number of micro-batches to accumulate before an optimizer step.

1
max_grad_norm float

Maximum gradient norm for clipping.

1.0
eval_interval int

Steps between evaluations.

100
log_interval int

Steps between logging.

10
checkpoint_interval int

Steps between checkpoints.

500
optimizer str

Optimizer name ('adamw', 'lion', 'adafactor', 'sgd').

'adamw'
lr_schedule str

Learning rate schedule ('cosine', 'linear', 'constant').

'cosine'
compile_step bool

Whether to mx.compile the training step.

True
seed int

Random seed.

42
Source code in src/lmxlab/training/config.py
@dataclass(frozen=True)
class TrainConfig:
    """Immutable configuration for the training loop.

    Attributes:
        learning_rate: Peak learning rate.
        weight_decay: Weight decay coefficient.
        warmup_steps: Linear warmup steps.
        max_steps: Maximum training steps.
        batch_size: Training batch size.
        grad_accumulation_steps: Number of micro-batches to
            accumulate before an optimizer step.
        max_grad_norm: Maximum gradient norm for clipping.
        eval_interval: Steps between evaluations.
        log_interval: Steps between logging.
        checkpoint_interval: Steps between checkpoints.
        optimizer: Optimizer name ('adamw', 'lion', 'adafactor', 'sgd').
        lr_schedule: Learning rate schedule ('cosine', 'linear', 'constant').
        compile_step: Whether to mx.compile the training step.
        seed: Random seed.
    """

    learning_rate: float = 3e-4
    weight_decay: float = 0.01
    warmup_steps: int = 100
    max_steps: int = 1000
    batch_size: int = 32
    grad_accumulation_steps: int = 1
    max_grad_norm: float = 1.0
    eval_interval: int = 100
    log_interval: int = 10
    checkpoint_interval: int = 500
    optimizer: str = "adamw"
    lr_schedule: str = "cosine"
    compile_step: bool = True
    seed: int = 42

batch_size = 32 class-attribute instance-attribute

checkpoint_interval = 500 class-attribute instance-attribute

compile_step = True class-attribute instance-attribute

eval_interval = 100 class-attribute instance-attribute

grad_accumulation_steps = 1 class-attribute instance-attribute

learning_rate = 0.0003 class-attribute instance-attribute

log_interval = 10 class-attribute instance-attribute

lr_schedule = 'cosine' class-attribute instance-attribute

max_grad_norm = 1.0 class-attribute instance-attribute

max_steps = 1000 class-attribute instance-attribute

optimizer = 'adamw' class-attribute instance-attribute

seed = 42 class-attribute instance-attribute

warmup_steps = 100 class-attribute instance-attribute

weight_decay = 0.01 class-attribute instance-attribute

__init__(learning_rate=0.0003, weight_decay=0.01, warmup_steps=100, max_steps=1000, batch_size=32, grad_accumulation_steps=1, max_grad_norm=1.0, eval_interval=100, log_interval=10, checkpoint_interval=500, optimizer='adamw', lr_schedule='cosine', compile_step=True, seed=42)

Optimizers

lmxlab.training.optimizers.create_optimizer(config)

Create an optimizer with learning rate schedule.

Parameters:

Name Type Description Default
config TrainConfig

Training configuration.

required

Returns:

Type Description
Optimizer

Configured optimizer.

Source code in src/lmxlab/training/optimizers.py
def create_optimizer(
    config: TrainConfig,
) -> optim.Optimizer:
    """Create an optimizer with learning rate schedule.

    Args:
        config: Training configuration.

    Returns:
        Configured optimizer.

    Raises:
        ValueError: If ``config.optimizer`` is not a known name.
    """
    schedule = create_schedule(config)
    name = config.optimizer

    if name == "adamw":
        return optim.AdamW(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
        )
    if name == "lion":
        return optim.Lion(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
        )
    if name == "adafactor":
        # Adafactor is constructed without explicit weight decay here.
        return optim.Adafactor(learning_rate=schedule)
    if name == "sgd":
        return optim.SGD(
            learning_rate=schedule,
            momentum=0.9,
            weight_decay=config.weight_decay,
        )
    raise ValueError(f"Unknown optimizer: {config.optimizer!r}")

lmxlab.training.optimizers.create_schedule(config)

Create a learning rate schedule.

Supports warmup + decay patterns.

Parameters:

Name Type Description Default
config TrainConfig

Training configuration.

required

Returns:

Type Description
Callable[[int], float]

Learning rate scheduler.

Source code in src/lmxlab/training/optimizers.py
def create_schedule(
    config: TrainConfig,
) -> Callable[[int], float]:
    """Create a learning rate schedule.

    Supports warmup + decay patterns.

    Args:
        config: Training configuration.

    Returns:
        Learning rate scheduler.

    Raises:
        ValueError: If ``config.lr_schedule`` is not a known name.
    """
    decay_steps = config.max_steps - config.warmup_steps

    # Ramp from near-zero up to the peak learning rate.
    warmup = optim.schedulers.linear_schedule(
        init=1e-7,
        end=config.learning_rate,
        steps=config.warmup_steps,
    )

    kind = config.lr_schedule
    if kind == "cosine":
        decay = optim.schedulers.cosine_decay(
            init=config.learning_rate,
            decay_steps=decay_steps,
        )
    elif kind == "linear":
        decay = optim.schedulers.linear_schedule(
            init=config.learning_rate,
            end=0.0,
            steps=decay_steps,
        )
    elif kind == "constant":
        # Constant rate: the raw float stands in for a schedule.
        decay = config.learning_rate
    else:
        raise ValueError(f"Unknown schedule: {config.lr_schedule!r}")

    # Switch from warmup to decay at the warmup boundary.
    return optim.schedulers.join_schedules(
        schedules=[warmup, decay],
        boundaries=[config.warmup_steps],
    )

Checkpoints

lmxlab.training.checkpoints.save_checkpoint(path, model, optimizer=None, step=0, metadata=None)

Save a training checkpoint.

Saves model weights as safetensors and metadata as JSON.

Parameters:

Name Type Description Default
path str | Path

Directory to save checkpoint.

required
model LanguageModel

Model to save.

required
optimizer Optimizer | None

Optional optimizer to save state.

None
step int

Current training step.

0
metadata dict[str, Any] | None

Additional metadata to save.

None
Source code in src/lmxlab/training/checkpoints.py
def save_checkpoint(
    path: str | Path,
    model: LanguageModel,
    optimizer: optim.Optimizer | None = None,
    step: int = 0,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save a training checkpoint.

    Saves model weights as safetensors and metadata as JSON.

    Args:
        path: Directory to save checkpoint.
        model: Model to save.
        optimizer: Optional optimizer to save state.
        step: Current training step.
        metadata: Additional metadata to save.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Flatten the nested parameter tree into {dotted_name: array}.
    flat_params = dict(mlx.utils.tree_flatten(model.parameters()))
    mx.save_safetensors(str(out_dir / "model.safetensors"), flat_params)

    if optimizer is not None:
        flat_state = dict(mlx.utils.tree_flatten(optimizer.state))
        # Skip writing a file when the optimizer state is empty.
        if flat_state:
            mx.save_safetensors(
                str(out_dir / "optimizer.safetensors"), flat_state
            )

    meta = {"step": step, **(metadata or {})}
    (out_dir / "metadata.json").write_text(json.dumps(meta, indent=2))

lmxlab.training.checkpoints.load_checkpoint(path, model, optimizer=None)

Load a training checkpoint.

Parameters:

Name Type Description Default
path str | Path

Directory containing checkpoint.

required
model LanguageModel

Model to load weights into.

required
optimizer Optimizer | None

Optional optimizer to load state into.

None

Returns:

Type Description
dict[str, Any]

Metadata dict from the checkpoint.

Source code in src/lmxlab/training/checkpoints.py
def load_checkpoint(
    path: str | Path,
    model: LanguageModel,
    optimizer: optim.Optimizer | None = None,
) -> dict[str, Any]:
    """Load a training checkpoint.

    Args:
        path: Directory containing checkpoint.
        model: Model to load weights into.
        optimizer: Optional optimizer to load state into.

    Returns:
        Metadata dict from the checkpoint.
    """
    ckpt_dir = Path(path)

    # Restore model weights.
    weights = mx.load(str(ckpt_dir / "model.safetensors"))
    model.load_weights(list(weights.items()))

    # Restore optimizer state when present on disk and requested.
    opt_file = ckpt_dir / "optimizer.safetensors"
    if optimizer is not None and opt_file.exists():
        flat_state = mx.load(str(opt_file))
        # Rebuild the nested state tree from flat dotted keys.
        optimizer.state = mlx.utils.tree_unflatten(list(flat_state.items()))

    meta_file = ckpt_dir / "metadata.json"
    if not meta_file.exists():
        return {}
    return json.loads(meta_file.read_text())

Callbacks

lmxlab.training.callbacks.Callback

Bases: Protocol

Protocol for training callbacks.

Source code in src/lmxlab/training/callbacks.py
class Callback(Protocol):
    """Protocol for training callbacks.

    Implementations are duck-typed: any object providing these four
    hooks can be registered with the trainer. All hooks return None;
    callbacks communicate via side effects (printing, mutating flags).
    """

    # Invoked once at the start of training with the run's config.
    def on_train_begin(self, config: TrainConfig) -> None: ...
    # Invoked once at the end of training with the metrics history.
    def on_train_end(self, history: list[dict[str, Any]]) -> None: ...
    # Invoked after a training step with that step's metrics dict.
    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None: ...
    # Invoked after an evaluation pass with eval metrics.
    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None: ...

lmxlab.training.callbacks.MetricsLogger

Logs training metrics at configured intervals.

Parameters:

Name Type Description Default
log_interval int

Steps between log outputs.

10
Source code in src/lmxlab/training/callbacks.py
class MetricsLogger:
    """Prints loss and learning rate at a fixed step interval.

    Args:
        log_interval: Steps between log outputs.
    """

    def __init__(self, log_interval: int = 10) -> None:
        self.log_interval = log_interval
        self._start_time: float = 0.0

    def on_train_begin(self, config: TrainConfig) -> None:
        # Record wall-clock start for the end-of-run summary.
        self._start_time = time.monotonic()

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        elapsed = time.monotonic() - self._start_time
        print(f"Training complete: {len(history)} steps in {elapsed:.1f}s")

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        # Guard clause: only log on interval boundaries.
        if step % self.log_interval != 0:
            return
        loss = metrics.get("loss", 0.0)
        lr = metrics.get("learning_rate", 0.0)
        print(f"step {step}: loss={loss:.4f}, lr={lr:.2e}")

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        eval_loss = metrics.get("eval_loss", 0.0)
        print(f"step {step} eval: loss={eval_loss:.4f}")

lmxlab.training.callbacks.EarlyStopping

Stop training when eval loss stops improving.

Parameters:

Name Type Description Default
patience int

Steps without improvement before stopping.

5
min_delta float

Minimum change to qualify as improvement.

0.001
Source code in src/lmxlab/training/callbacks.py
class EarlyStopping:
    """Stop training when eval loss stops improving.

    Sets ``should_stop`` once the eval loss has failed to improve by
    at least ``min_delta`` for ``patience`` consecutive evaluations.

    Args:
        patience: Steps without improvement before stopping.
        min_delta: Minimum change to qualify as improvement.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self._best_loss: float = float("inf")
        self._wait: int = 0
        self.should_stop: bool = False

    def on_train_begin(self, config: TrainConfig) -> None:
        # Fresh run: discard state from any previous training session.
        self.should_stop = False
        self._wait = 0
        self._best_loss = float("inf")

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        pass

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        pass

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        eval_loss = metrics.get("eval_loss", float("inf"))
        if eval_loss < self._best_loss - self.min_delta:
            # New best: reset the patience counter.
            self._best_loss = eval_loss
            self._wait = 0
            return
        self._wait += 1
        if self._wait >= self.patience:
            self.should_stop = True
            print(
                f"Early stopping at step {step} "
                f"(best loss: {self._best_loss:.4f})"
            )

lmxlab.training.callbacks.ThroughputMonitor

Tracks training throughput (tokens/sec, steps/sec).

Reports throughput at configured intervals, useful for understanding MLX performance characteristics and comparing compiled vs uncompiled training.

Parameters:

Name Type Description Default
log_interval int

Steps between throughput reports.

10
tokens_per_step int | None

Tokens processed per step (batch_size * seq_len). If None, only reports steps/sec.

None
Source code in src/lmxlab/training/callbacks.py
class ThroughputMonitor:
    """Tracks training throughput (tokens/sec, steps/sec).

    Reports throughput at configured intervals, useful for
    understanding MLX performance characteristics and comparing
    compiled vs uncompiled training.

    Args:
        log_interval: Steps between throughput reports.
        tokens_per_step: Tokens processed per step
            (batch_size * seq_len). If None, only reports
            steps/sec.
    """

    def __init__(
        self,
        log_interval: int = 10,
        tokens_per_step: int | None = None,
    ) -> None:
        self.log_interval = log_interval
        self.tokens_per_step = tokens_per_step
        # Per-step durations; trimmed to the reporting window in
        # on_step_end so memory stays bounded over long runs.
        self._step_times: list[float] = []
        self._last_time: float = 0.0
        self._total_steps: int = 0
        self._train_start: float = 0.0

    def on_train_begin(self, config: TrainConfig) -> None:
        # Reset all timing state for a fresh run.
        self._last_time = time.monotonic()
        self._train_start = self._last_time
        self._step_times = []
        self._total_steps = 0

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """Print an overall throughput summary for the whole run."""
        elapsed = time.monotonic() - self._train_start
        if self._total_steps > 0 and elapsed > 0:
            avg_steps_sec = self._total_steps / elapsed
            msg = (
                f"Throughput summary: {self._total_steps} steps "
                f"in {elapsed:.1f}s ({avg_steps_sec:.1f} steps/s"
            )
            if self.tokens_per_step is not None:
                total_tokens = self._total_steps * self.tokens_per_step
                tok_sec = total_tokens / elapsed
                msg += f", {tok_sec:.0f} tok/s"
            msg += ")"
            print(msg)

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Record step duration, inject metrics, and report on interval."""
        now = time.monotonic()
        dt = now - self._last_time
        self._last_time = now
        self._step_times.append(dt)
        # Only the most recent log_interval durations are ever read,
        # so drop older entries instead of growing without bound
        # (the previous implementation leaked one float per step).
        if len(self._step_times) > self.log_interval:
            del self._step_times[: -self.log_interval]
        self._total_steps += 1

        # Inject cumulative token count every step
        if self.tokens_per_step is not None:
            total_tokens = self._total_steps * self.tokens_per_step
            metrics["tokens_processed"] = total_tokens

        if step % self.log_interval == 0 and dt > 0:
            # Use recent window for smoother reporting
            window = self._step_times[-self.log_interval :]
            avg_dt = sum(window) / len(window)
            steps_sec = 1.0 / avg_dt if avg_dt > 0 else 0
            metrics["steps_per_sec"] = steps_sec

            msg = f"step {step}: {steps_sec:.1f} steps/s"
            if self.tokens_per_step is not None:
                tok_sec = self.tokens_per_step * steps_sec
                metrics["tokens_per_sec"] = tok_sec
                msg += f", {tok_sec:.0f} tok/s"
            print(msg)

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        pass

DPO

lmxlab.training.dpo

Direct Preference Optimization (DPO) training.

dpo_loss(model, ref_model, chosen, rejected, beta=0.1)

Compute DPO loss.

DPO directly optimizes preferences without reward modeling. L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))

Parameters:

Name Type Description Default
model LanguageModel

Policy model being trained.

required
ref_model LanguageModel

Reference (frozen) model.

required
chosen array

Preferred sequence token IDs (batch, seq_len).

required
rejected array

Dispreferred sequence token IDs (batch, seq_len).

required
beta float

Temperature parameter controlling deviation from ref.

0.1

Returns:

Type Description
array

Scalar DPO loss.

Source code in src/lmxlab/training/dpo.py
def dpo_loss(
    model: LanguageModel,
    ref_model: LanguageModel,
    chosen: mx.array,
    rejected: mx.array,
    beta: float = 0.1,
) -> mx.array:
    """Compute DPO loss.

    DPO directly optimizes preferences without reward modeling.
    L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))

    Args:
        model: Policy model being trained.
        ref_model: Reference (frozen) model.
        chosen: Preferred sequence token IDs (batch, seq_len).
        rejected: Dispreferred sequence token IDs (batch, seq_len).
        beta: Temperature parameter controlling deviation from ref.

    Returns:
        Scalar DPO loss.
    """

    def _score(seq: mx.array) -> tuple[mx.array, mx.array]:
        # Per-sequence log probs of `seq` under policy and reference,
        # scored against next-token targets (inputs shifted by one).
        inp = seq[:, :-1]
        tgt = seq[:, 1:]
        policy_logits, _ = model(inp)
        ref_logits, _ = ref_model(inp)
        return (
            _sequence_log_probs(policy_logits, tgt),
            _sequence_log_probs(ref_logits, tgt),
        )

    chosen_logps, chosen_ref_logps = _score(chosen)
    rejected_logps, rejected_ref_logps = _score(rejected)

    # Implicit rewards are the beta-scaled log ratios vs the reference.
    chosen_rewards = beta * (chosen_logps - chosen_ref_logps)
    rejected_rewards = beta * (rejected_logps - rejected_ref_logps)

    # -log(sigmoid(x)) == softplus(-x) == logaddexp(0, -x), which is
    # numerically stable for large |x|.
    margin = chosen_rewards - rejected_rewards
    return mx.mean(mx.logaddexp(0, -margin))

GRPO

lmxlab.training.grpo

Group Relative Policy Optimization (GRPO).

grpo_loss(model, ref_model, prompts, completions, rewards, beta=0.1, epsilon=0.2)

Compute GRPO loss.

GRPO uses group-relative rewards: for each prompt, generate multiple completions, compute rewards, normalize within the group, and optimize using a clipped surrogate objective.

Parameters:

Name Type Description Default
model LanguageModel

Policy model being trained.

required
ref_model LanguageModel

Reference (frozen) model.

required
prompts array

Prompt token IDs (batch, prompt_len).

required
completions array

Full sequences (batch, total_len).

required
rewards array

Scalar rewards per completion (batch,).

required
beta float

KL penalty coefficient.

0.1
epsilon float

Clipping range for surrogate objective.

0.2

Returns:

Type Description
array

Scalar GRPO loss.

Source code in src/lmxlab/training/grpo.py
def grpo_loss(
    model: LanguageModel,
    ref_model: LanguageModel,
    prompts: mx.array,
    completions: mx.array,
    rewards: mx.array,
    beta: float = 0.1,
    epsilon: float = 0.2,
) -> mx.array:
    """Compute GRPO loss.

    GRPO uses group-relative rewards: for each prompt, generate
    multiple completions, compute rewards, normalize within the
    group, and optimize using a clipped surrogate objective.

    Args:
        model: Policy model being trained.
        ref_model: Reference (frozen) model.
        prompts: Prompt token IDs (batch, prompt_len).
        completions: Full sequences (batch, total_len).
        rewards: Scalar rewards per completion (batch,).
        beta: KL penalty coefficient.
        epsilon: Clipping range for surrogate objective.

    Returns:
        Scalar GRPO loss.
    """
    # Group-relative advantages: standardize rewards within the batch,
    # with a floor on std to avoid division by ~zero for degenerate
    # reward groups.
    advantages = (rewards - mx.mean(rewards)) / mx.maximum(
        mx.std(rewards), mx.array(1e-8)
    )

    # Next-token setup: inputs are shifted by one against targets.
    inputs = completions[:, :-1]
    targets = completions[:, 1:]

    policy_logits, _ = model(inputs)
    frozen_logits, _ = ref_model(inputs)

    log_probs = _sequence_log_probs(policy_logits, targets)
    ref_log_probs = _sequence_log_probs(frozen_logits, targets)

    # PPO-style clipped surrogate on the policy/reference ratio.
    ratio = mx.exp(log_probs - ref_log_probs)
    surrogate = mx.minimum(
        ratio * advantages,
        mx.clip(ratio, 1.0 - epsilon, 1.0 + epsilon) * advantages,
    )

    # KL penalty keeping the policy close to the reference.
    kl = ref_log_probs - log_probs

    return mx.mean(-(surrogate - beta * kl))

Multi-Token Prediction

lmxlab.training.mtp.MTPHead

Bases: Module

Single multi-token prediction head.

Takes hidden states and previous target embeddings, normalizes both, concatenates, projects back to d_model, then runs through a transformer block.

Parameters:

Name Type Description Default
d_model int

Hidden dimension.

required
block_config BlockConfig

Block configuration for the MTP block.

required
Source code in src/lmxlab/training/mtp.py
class MTPHead(nn.Module):
    """Single multi-token prediction head.

    Takes hidden states and previous target embeddings,
    normalizes both, concatenates, projects back to d_model,
    then runs through a transformer block.

    Args:
        d_model: Hidden dimension.
        block_config: Block configuration for the MTP block.
    """

    def __init__(
        self,
        d_model: int,
        block_config: BlockConfig,
    ) -> None:
        super().__init__()
        # Separate norms for the two input streams.
        self.hidden_norm = nn.RMSNorm(d_model)
        self.embed_norm = nn.RMSNorm(d_model)
        # Fuses the concatenated (2*d_model) features back to d_model.
        self.proj = nn.Linear(
            2 * d_model,
            d_model,
            bias=False,
        )
        self.block = ConfigurableBlock(block_config)

    def __call__(
        self,
        h: mx.array,
        prev_embed: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        """Produce hidden states for future token prediction.

        Args:
            h: Hidden states (batch, seq_len, d_model).
            prev_embed: Embeddings of previous target tokens
                (batch, seq_len, d_model).
            mask: Optional causal mask.

        Returns:
            Hidden states (batch, seq_len, d_model).
        """
        normed_hidden = self.hidden_norm(h)
        normed_embed = self.embed_norm(prev_embed)
        fused = self.proj(
            mx.concatenate([normed_hidden, normed_embed], axis=-1)
        )
        out, _ = self.block(fused, mask=mask)
        return out

__init__(d_model, block_config)

Source code in src/lmxlab/training/mtp.py
def __init__(
    self,
    d_model: int,
    block_config: BlockConfig,
) -> None:
    """Build norms, fusion projection, and the MTP transformer block.

    Args:
        d_model: Hidden dimension.
        block_config: Block configuration for the MTP block.
    """
    super().__init__()
    self.hidden_norm = nn.RMSNorm(d_model)
    self.embed_norm = nn.RMSNorm(d_model)
    # Projects concatenated hidden + embedding features (2 * d_model)
    # back down to d_model.
    self.proj = nn.Linear(
        2 * d_model,
        d_model,
        bias=False,
    )
    self.block = ConfigurableBlock(block_config)

__call__(h, prev_embed, mask=None)

Produce hidden states for future token prediction.

Parameters:

Name Type Description Default
h array

Hidden states (batch, seq_len, d_model).

required
prev_embed array

Embeddings of previous target tokens (batch, seq_len, d_model).

required
mask array | None

Optional causal mask.

None

Returns:

Type Description
array

Hidden states (batch, seq_len, d_model).

Source code in src/lmxlab/training/mtp.py
def __call__(
    self,
    h: mx.array,
    prev_embed: mx.array,
    mask: mx.array | None = None,
) -> mx.array:
    """Produce hidden states for future token prediction.

    Args:
        h: Hidden states (batch, seq_len, d_model).
        prev_embed: Embeddings of previous target tokens
            (batch, seq_len, d_model).
        mask: Optional causal mask.

    Returns:
        Hidden states (batch, seq_len, d_model).
    """
    # Normalize each stream separately, then fuse along the feature
    # axis: (batch, seq_len, 2 * d_model).
    combined = mx.concatenate(
        [
            self.hidden_norm(h),
            self.embed_norm(prev_embed),
        ],
        axis=-1,
    )
    # Project back to d_model before running the transformer block.
    projected = self.proj(combined)
    out, _ = self.block(projected, mask=mask)
    return out

lmxlab.training.mtp.MultiTokenPrediction

Bases: Module

Multi-Token Prediction wrapper around a LanguageModel.

Adds n_predict auxiliary prediction heads that predict future tokens at each position. Shares the base model's lm_head for logit projection.

Training-only module. At inference time, use the base model directly.

Parameters:

Name Type Description Default
model LanguageModel

Base language model.

required
n_predict int

Number of future tokens to predict.

2
mtp_weight float

Weight for auxiliary MTP losses.

0.3
block_config BlockConfig | None

Block config for MTP heads. If None, uses the base model's block config.

None
Source code in src/lmxlab/training/mtp.py
class MultiTokenPrediction(nn.Module):
    """Multi-Token Prediction wrapper around a LanguageModel.

    Adds n_predict auxiliary prediction heads that predict
    future tokens at each position. Shares the base model's
    lm_head for logit projection.

    Training-only module. At inference time, use the base
    model directly.

    Args:
        model: Base language model.
        n_predict: Number of future tokens to predict.
        mtp_weight: Weight for auxiliary MTP losses.
        block_config: Block config for MTP heads. If None,
            uses the base model's block config.
    """

    def __init__(
        self,
        model: LanguageModel,
        n_predict: int = 2,
        mtp_weight: float = 0.3,
        block_config: BlockConfig | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_predict = n_predict
        self.mtp_weight = mtp_weight

        d_model = model.config.block.d_model

        # Default MTP block: lightweight attention block
        # (half the base model's heads, 2x d_model FFN, no position
        # encoding of its own).
        if block_config is None:
            block_config = BlockConfig(
                attention="mha",
                ffn="standard",
                norm="rms_norm",
                position="none",
                d_model=d_model,
                n_heads=max(1, model.config.block.n_heads // 2),
                d_ff=d_model * 2,
                bias=False,
                pre_norm=True,
            )

        # One head per prediction depth; head k predicts token t+k.
        self.mtp_heads = [
            MTPHead(d_model, block_config) for _ in range(n_predict)
        ]

    def _project_logits(self, h: mx.array) -> mx.array:
        """Project hidden states to logits using shared head.

        Args:
            h: Hidden states (batch, seq_len, d_model).

        Returns:
            Logits (batch, seq_len, vocab_size).
        """
        # With tied embeddings the embedding matrix doubles as the
        # output projection; otherwise use the model's separate head.
        if self.model.config.tie_embeddings:
            return h @ self.model.embed.weight.T
        return self.model.head(h)

    def __call__(
        self,
        x: mx.array,
        targets: mx.array,
    ) -> tuple[mx.array, dict[str, mx.array]]:
        """Forward pass with multi-token prediction.

        Args:
            x: Input token IDs (batch, seq_len).
            targets: Target token IDs (batch, seq_len).
                For MTP depth k, predicts target at t+k.

        Returns:
            Tuple of (main_logits, loss_dict) where loss_dict
            contains 'main_loss', 'mtp_loss', 'total_loss'.
        """
        # Forward with hidden states
        logits, _, hidden = self.model(
            x,
            return_hidden=True,
        )

        # Main loss (next token prediction)
        # NOTE(review): the last logit position is dropped and targets
        # are truncated to match — assumes targets are already aligned
        # so targets[t] is the label for logits[t]; confirm with caller.
        main_logits = logits[:, :-1, :]
        main_targets = targets[:, : main_logits.shape[1]]
        main_loss = nn.losses.cross_entropy(
            main_logits.reshape(-1, main_logits.shape[-1]),
            main_targets.reshape(-1),
            reduction="mean",
        )

        # MTP losses for each prediction depth
        mtp_losses = []
        h = hidden
        for k, head in enumerate(self.mtp_heads, start=1):
            # Skip depths that would need more future context than
            # the sequence provides.
            if targets.shape[1] <= k:
                continue

            # Hidden states for positions with enough future
            h_slice = h[:, :-k, :]

            # Embeddings of tokens at position t+k-1
            prev_toks = targets[:, k - 1 : -1]
            if prev_toks.shape[1] > h_slice.shape[1]:
                prev_toks = prev_toks[:, : h_slice.shape[1]]
            prev_embed = self.model.embed(prev_toks)

            # MTP head produces new hidden states
            h_mtp = head(h_slice, prev_embed)

            # Project to logits using shared lm_head
            mtp_logits = self._project_logits(h_mtp)
            mtp_targets = targets[:, k : k + mtp_logits.shape[1]]

            if mtp_targets.shape[1] > 0:
                mtp_loss = nn.losses.cross_entropy(
                    mtp_logits.reshape(
                        -1,
                        mtp_logits.shape[-1],
                    ),
                    mtp_targets.reshape(-1),
                    reduction="mean",
                )
                mtp_losses.append(mtp_loss)

            # Chain: next head uses this head's output
            h = h_mtp

        # Combine losses
        if mtp_losses:
            avg_mtp_loss = sum(mtp_losses) / len(mtp_losses)
        else:
            avg_mtp_loss = mx.array(0.0)

        total_loss = main_loss + self.mtp_weight * avg_mtp_loss

        losses = {
            "main_loss": main_loss,
            "mtp_loss": avg_mtp_loss,
            "total_loss": total_loss,
        }

        return logits, losses

__init__(model, n_predict=2, mtp_weight=0.3, block_config=None)

Source code in src/lmxlab/training/mtp.py
def __init__(
    self,
    model: LanguageModel,
    n_predict: int = 2,
    mtp_weight: float = 0.3,
    block_config: BlockConfig | None = None,
) -> None:
    """Wrap a base model with n_predict auxiliary MTP heads.

    Args:
        model: Base language model.
        n_predict: Number of future tokens to predict.
        mtp_weight: Weight for auxiliary MTP losses.
        block_config: Block config for MTP heads. If None,
            a lightweight default derived from the base model is used.
    """
    super().__init__()
    self.model = model
    self.n_predict = n_predict
    self.mtp_weight = mtp_weight

    d_model = model.config.block.d_model

    # Default MTP block: lightweight attention block
    # (half the base model's heads, 2x d_model FFN).
    if block_config is None:
        block_config = BlockConfig(
            attention="mha",
            ffn="standard",
            norm="rms_norm",
            position="none",
            d_model=d_model,
            n_heads=max(1, model.config.block.n_heads // 2),
            d_ff=d_model * 2,
            bias=False,
            pre_norm=True,
        )

    # One head per prediction depth; head k predicts token t+k.
    self.mtp_heads = [
        MTPHead(d_model, block_config) for _ in range(n_predict)
    ]

__call__(x, targets)

Forward pass with multi-token prediction.

Parameters:

Name Type Description Default
x array

Input token IDs (batch, seq_len).

required
targets array

Target token IDs (batch, seq_len). For MTP depth k, predicts target at t+k.

required

Returns:

Type Description
array

Tuple of (main_logits, loss_dict) where loss_dict

dict[str, array]

contains 'main_loss', 'mtp_loss', 'total_loss'.

Source code in src/lmxlab/training/mtp.py
def __call__(
    self,
    x: mx.array,
    targets: mx.array,
) -> tuple[mx.array, dict[str, mx.array]]:
    """Forward pass with multi-token prediction.

    Args:
        x: Input token IDs (batch, seq_len).
        targets: Target token IDs (batch, seq_len).
            For MTP depth k, predicts target at t+k.

    Returns:
        Tuple of (main_logits, loss_dict) where loss_dict
        contains 'main_loss', 'mtp_loss', 'total_loss'.
    """
    # Forward with hidden states
    logits, _, hidden = self.model(
        x,
        return_hidden=True,
    )

    # Main loss (next token prediction)
    # NOTE(review): drops the final logit position and truncates
    # targets to match — assumes targets[t] labels logits[t];
    # confirm the shift convention with the caller.
    main_logits = logits[:, :-1, :]
    main_targets = targets[:, : main_logits.shape[1]]
    main_loss = nn.losses.cross_entropy(
        main_logits.reshape(-1, main_logits.shape[-1]),
        main_targets.reshape(-1),
        reduction="mean",
    )

    # MTP losses for each prediction depth
    mtp_losses = []
    h = hidden
    for k, head in enumerate(self.mtp_heads, start=1):
        # Depth k needs at least k+1 target positions to form a pair.
        if targets.shape[1] <= k:
            continue

        # Hidden states for positions with enough future
        h_slice = h[:, :-k, :]

        # Embeddings of tokens at position t+k-1
        prev_toks = targets[:, k - 1 : -1]
        if prev_toks.shape[1] > h_slice.shape[1]:
            prev_toks = prev_toks[:, : h_slice.shape[1]]
        prev_embed = self.model.embed(prev_toks)

        # MTP head produces new hidden states
        h_mtp = head(h_slice, prev_embed)

        # Project to logits using shared lm_head
        mtp_logits = self._project_logits(h_mtp)
        mtp_targets = targets[:, k : k + mtp_logits.shape[1]]

        if mtp_targets.shape[1] > 0:
            mtp_loss = nn.losses.cross_entropy(
                mtp_logits.reshape(
                    -1,
                    mtp_logits.shape[-1],
                ),
                mtp_targets.reshape(-1),
                reduction="mean",
            )
            mtp_losses.append(mtp_loss)

        # Chain: next head uses this head's output
        h = h_mtp

    # Combine losses
    if mtp_losses:
        avg_mtp_loss = sum(mtp_losses) / len(mtp_losses)
    else:
        avg_mtp_loss = mx.array(0.0)

    total_loss = main_loss + self.mtp_weight * avg_mtp_loss

    losses = {
        "main_loss": main_loss,
        "mtp_loss": avg_mtp_loss,
        "total_loss": total_loss,
    }

    return logits, losses

Curriculum Learning

lmxlab.training.curriculum

Curriculum learning utilities.

difficulty_curriculum(easy_data, hard_data, batch_size, seq_len, n_batches=200, warmup_fraction=0.5)

Mix easy and hard data with increasing difficulty.

Starts with mostly easy data and transitions to hard data.

Parameters:

Name Type Description Default
easy_data array

Token array of easier text.

required
hard_data array

Token array of harder text.

required
batch_size int

Sequences per batch.

required
seq_len int

Sequence length.

required
n_batches int

Total number of batches.

200
warmup_fraction float

Fraction of training spent warming up.

0.5

Yields:

Type Description
tuple[array, array]

(input, target) tuples with mixed difficulty.

Source code in src/lmxlab/training/curriculum.py
def difficulty_curriculum(
    easy_data: mx.array,
    hard_data: mx.array,
    batch_size: int,
    seq_len: int,
    n_batches: int = 200,
    warmup_fraction: float = 0.5,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Mix easy and hard data with increasing difficulty.

    Starts with mostly easy data and transitions to hard data.

    Args:
        easy_data: Token array of easier text.
        hard_data: Token array of harder text.
        batch_size: Sequences per batch.
        seq_len: Sequence length.
        n_batches: Total number of batches.
        warmup_fraction: Fraction of training spent warming up.
            Values <= 0 mean no warmup (all-hard from the start).

    Yields:
        (input, target) tuples with mixed difficulty.
    """

    def _sample(
        data: mx.array,
        count: int,
        inputs: list[mx.array],
        targets: list[mx.array],
    ) -> None:
        # Draw `count` random next-token (input, target) windows
        # from `data` and append them to the batch lists.
        starts = mx.random.randint(
            0, len(data) - seq_len - 1, shape=(count,)
        )
        mx.eval(starts)
        for s_val in cast(list[int], starts.tolist()):
            inputs.append(data[s_val : s_val + seq_len])
            targets.append(data[s_val + 1 : s_val + seq_len + 1])

    for i in range(n_batches):
        # Hard data fraction increases linearly over the warmup
        # portion, then saturates at 1.0.
        progress = i / max(n_batches - 1, 1)
        if warmup_fraction <= 0:
            # Degenerate warmup: go straight to all-hard batches.
            # (The original raised ZeroDivisionError here.)
            hard_fraction = 1.0
        else:
            hard_fraction = min(progress / warmup_fraction, 1.0)
        n_hard = int(batch_size * hard_fraction)
        n_easy = batch_size - n_hard

        inputs_list: list[mx.array] = []
        targets_list: list[mx.array] = []

        # Sample easy first, then hard, matching the original order.
        if n_easy > 0:
            _sample(easy_data, n_easy, inputs_list, targets_list)
        if n_hard > 0:
            _sample(hard_data, n_hard, inputs_list, targets_list)

        yield mx.stack(inputs_list), mx.stack(targets_list)

length_curriculum(tokens, batch_size, min_seq_len=32, max_seq_len=512, n_stages=4, batches_per_stage=100)

Generate batches with increasing sequence length.

Starts with short sequences and gradually increases, following curriculum learning principles.

Parameters:

Name Type Description Default
tokens array

Flat array of token IDs.

required
batch_size int

Sequences per batch.

required
min_seq_len int

Starting sequence length.

32
max_seq_len int

Final sequence length.

512
n_stages int

Number of curriculum stages.

4
batches_per_stage int

Batches per stage.

100

Yields:

Type Description
tuple[array, array]

(input, target) tuples with progressively longer sequences.

Source code in src/lmxlab/training/curriculum.py
def length_curriculum(
    tokens: mx.array,
    batch_size: int,
    min_seq_len: int = 32,
    max_seq_len: int = 512,
    n_stages: int = 4,
    batches_per_stage: int = 100,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Generate batches with increasing sequence length.

    Starts with short sequences and gradually increases,
    following curriculum learning principles.

    Args:
        tokens: Flat array of token IDs.
        batch_size: Sequences per batch.
        min_seq_len: Starting sequence length.
        max_seq_len: Final sequence length.
        n_stages: Number of curriculum stages.
        batches_per_stage: Batches per stage.

    Yields:
        (input, target) tuples with progressively longer sequences.
    """
    n_tokens = len(tokens)
    span = max_seq_len - min_seq_len
    denom = max(n_stages - 1, 1)

    for stage in range(n_stages):
        # Sequence length grows linearly from min to max across stages.
        seq_len = int(min_seq_len + (stage / denom) * span)

        # Skip a stage if the data can't supply a full batch of
        # sequences at this length.
        if (n_tokens - 1) // seq_len < batch_size:
            continue

        for _ in range(batches_per_stage):
            # Random starting positions for each sequence in the batch.
            starts = mx.random.randint(
                0, n_tokens - seq_len - 1, shape=(batch_size,)
            )
            mx.eval(starts)

            batch_inputs = []
            batch_targets = []
            for pos in cast(list[int], starts.tolist()):
                batch_inputs.append(tokens[pos : pos + seq_len])
                batch_targets.append(tokens[pos + 1 : pos + seq_len + 1])

            yield mx.stack(batch_inputs), mx.stack(batch_targets)

Knowledge Distillation

lmxlab.training.distillation

Knowledge distillation training losses.

Implements teacher-student distillation where a smaller student model learns from a larger teacher model's soft probability distributions (Hinton et al., 2015, arXiv:1503.02531).

The key insight: soft targets carry more information than hard labels. A teacher assigning 0.7 to "cat" and 0.2 to "kitten" teaches the student about word similarity, not just correctness.

Supported modes:

  • Logit distillation: KL divergence between temperature-scaled teacher and student logits. The standard approach.
  • Combined loss: Weighted mix of distillation loss and standard cross-entropy on hard targets. Balances learning from teacher with learning from ground truth.

Example::

from lmxlab.training.distillation import distillation_loss

# Teacher is frozen, student is trained
loss = distillation_loss(
    student, teacher, tokens,
    temperature=4.0, alpha=0.7,
)

distillation_loss(student, teacher, tokens, temperature=4.0, alpha=0.7)

Compute combined distillation + hard-target loss.

Loss = alpha * KL(teacher || student) * T^2 + (1 - alpha) * CE(student, targets)

The T^2 factor compensates for the gradient magnitude reduction caused by temperature scaling (Hinton et al.).

Parameters:

Name Type Description Default
student LanguageModel

Student model (being trained).

required
teacher LanguageModel

Teacher model (frozen, no gradients).

required
tokens array

Input token IDs (batch, seq_len). Targets are tokens shifted by one position.

required
temperature float

Softmax temperature for soft targets. Higher = softer distributions, more knowledge transfer. Typical values: 2-10.

4.0
alpha float

Weight for distillation loss (0-1). Higher means more reliance on teacher, less on hard targets.

0.7

Returns:

Type Description
array

Scalar combined loss.

Source code in src/lmxlab/training/distillation.py
def distillation_loss(
    student: LanguageModel,
    teacher: LanguageModel,
    tokens: mx.array,
    temperature: float = 4.0,
    alpha: float = 0.7,
) -> mx.array:
    """Compute combined distillation + hard-target loss.

    Loss = alpha * KL(teacher || student) * T^2
         + (1 - alpha) * CE(student, targets)

    The T^2 factor compensates for the gradient magnitude
    reduction caused by temperature scaling (Hinton et al.,
    arXiv:1503.02531).

    Args:
        student: Student model (being trained).
        teacher: Teacher model (frozen, no gradients).
        tokens: Input token IDs (batch, seq_len). Targets are
            tokens shifted by one position.
        temperature: Softmax temperature for soft targets.
            Higher = softer distributions, more knowledge
            transfer. Typical values: 2-10.
        alpha: Weight for distillation loss (0-1). Higher means
            more reliance on teacher, less on hard targets.

    Returns:
        Scalar combined loss.
    """
    inputs = tokens[:, :-1]
    targets = tokens[:, 1:]

    # Student forward pass (will receive gradients)
    student_logits, _ = student(inputs)

    # Pure hard-target training: no distillation term, so the
    # teacher forward pass can be skipped entirely.
    if alpha <= 0.0:
        return _cross_entropy(student_logits, targets)

    # Teacher forward pass. stop_gradient detaches the teacher's
    # outputs so no backward graph is ever traced through the
    # frozen teacher, regardless of how the loss is differentiated.
    teacher_logits, _ = teacher(inputs)
    teacher_logits = mx.stop_gradient(teacher_logits)

    # Distillation component: KL divergence on soft targets
    kl = soft_target_loss(student_logits, teacher_logits, temperature)

    # Pure distillation: hard-target cross-entropy not needed.
    if alpha >= 1.0:
        return kl

    # Hard target component: standard cross-entropy
    ce = _cross_entropy(student_logits, targets)

    return alpha * kl + (1.0 - alpha) * ce

soft_target_loss(student_logits, teacher_logits, temperature=4.0)

KL divergence between temperature-scaled distributions.

KL(teacher || student) computed on the temperature-softened probability distributions (logits divided by T before the softmax). Multiplied by T^2 to maintain gradient scale.

Parameters:

Name Type Description Default
student_logits array

Student output (batch, seq_len, vocab).

required
teacher_logits array

Teacher output (batch, seq_len, vocab).

required
temperature float

Softmax temperature.

4.0

Returns:

Type Description
array

Scalar KL divergence loss.

Source code in src/lmxlab/training/distillation.py
def soft_target_loss(
    student_logits: mx.array,
    teacher_logits: mx.array,
    temperature: float = 4.0,
) -> mx.array:
    """KL divergence between temperature-softened distributions.

    Both logit tensors are divided by ``temperature`` before the
    softmax, and the resulting KL(teacher || student) is scaled by
    T^2 so gradient magnitudes stay comparable to the unscaled
    case (Hinton et al.).

    Args:
        student_logits: Student output (batch, seq_len, vocab).
        teacher_logits: Teacher output (batch, seq_len, vocab).
        temperature: Softmax temperature.

    Returns:
        Scalar KL divergence loss.
    """
    # Soften both distributions with the same temperature.
    log_q = nn.log_softmax(student_logits / temperature, axis=-1)
    log_p = nn.log_softmax(teacher_logits / temperature, axis=-1)

    # KL(P || Q) = E_P[log P - log Q], reduced over the vocabulary axis.
    per_token_kl = mx.sum(mx.exp(log_p) * (log_p - log_q), axis=-1)

    # Average over batch and positions, then apply the T^2 factor.
    return (temperature**2) * mx.mean(per_token_kl)