Skip to content

Training

Compiled training loop, optimizers, and training utilities.

Trainer

lmxlab.training.trainer.Trainer

Training loop with compiled steps and gradient accumulation.

Uses nn.value_and_grad for functional gradient computation and mx.compile for the full training step. When grad_accumulation_steps > 1, gradients are averaged over multiple micro-batches before a single optimizer update.

Parameters:

Name Type Description Default
model LanguageModel

Language model to train.

required
config TrainConfig

Training configuration.

required
optimizer Optimizer | None

Optional pre-built optimizer.

None
callbacks list[Callback] | None

Optional list of callbacks.

None
Source code in src/lmxlab/training/trainer.py
class Trainer:
    """Training loop with compiled steps and gradient accumulation.

    Uses nn.value_and_grad for functional gradient computation
    and mx.compile for the full training step. When
    ``grad_accumulation_steps > 1``, gradients are averaged over
    multiple micro-batches before a single optimizer update.

    Args:
        model: Language model to train.
        config: Training configuration.
        optimizer: Optional pre-built optimizer.
        callbacks: Optional list of callbacks.
    """

    def __init__(
        self,
        model: LanguageModel,
        config: TrainConfig,
        optimizer: optim.Optimizer | None = None,
        callbacks: list[Callback] | None = None,
    ) -> None:
        self.model = model
        self.config = config
        if optimizer is not None:
            self.optimizer = optimizer
        elif model.config.mup_base_width is not None:
            # μP models need per-group learning rates scaled by width.
            self.optimizer = create_mup_optimizer(
                config,
                model.config.width_mult,
            )
        else:
            self.optimizer = create_optimizer(config)
        self.callbacks = callbacks or []
        self.step = 0
        self._accum_steps = config.grad_accumulation_steps

        # Build the training step function
        self._loss_and_grad = nn.value_and_grad(model, _loss_fn)

        # Always define _step_fn. Previously it was only assigned in the
        # no-accumulation branch, so a direct train_step() call with
        # grad_accumulation_steps > 1 raised AttributeError.
        self._step_fn = self._single_step

        if self._accum_steps <= 1:
            # No accumulation: compile full step (fwd + bwd + update)
            # MultiOptimizer (μP) is incompatible with mx.compile
            can_compile = config.compile_step and not isinstance(
                self.optimizer,
                optim.MultiOptimizer,
            )
            if can_compile:
                # The compiled step mutates model parameters AND optimizer
                # state; both must be tracked as implicit inputs/outputs
                # (see MLX "Compiling Training Graphs"). Tracking only the
                # trainable parameters drops optimizer-state updates
                # (e.g. Adam moments, schedule step counter).
                state = [self.model.state, self.optimizer.state]
                self._step_fn = mx.compile(
                    self._single_step,
                    inputs=state,
                    outputs=state,
                )

    def _single_step(
        self,
        x: mx.array,
        y: mx.array,
    ) -> tuple[mx.array, mx.array]:
        """Single training step: forward + backward + update.

        Args:
            x: Input tokens (batch, seq_len).
            y: Target tokens (batch, seq_len).

        Returns:
            Tuple of (loss, grad_norm) where grad_norm is the
            L2 norm of gradients before clipping.
        """
        loss, grads = self._loss_and_grad(self.model, x, y)

        # Gradient clipping
        if self.config.max_grad_norm > 0:
            grads, grad_norm = optim.clip_grad_norm(
                grads, max_norm=self.config.max_grad_norm
            )
        else:
            # No clipping requested: compute the norm manually for metrics.
            flat = tree_flatten(grads)
            grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in flat))

        self.optimizer.update(self.model, grads)
        return loss, grad_norm

    def _accumulation_step(
        self,
        micro_batches: list[tuple[mx.array, mx.array]],
    ) -> tuple[mx.array, mx.array]:
        """Gradient accumulation step over multiple micro-batches.

        Computes gradients for each micro-batch, averages them,
        then applies a single optimizer update.

        Args:
            micro_batches: List of (input, target) micro-batches.

        Returns:
            Tuple of (avg_loss, grad_norm).
        """
        n = len(micro_batches)
        total_loss = mx.array(0.0)
        acc_grads = None

        for x, y in micro_batches:
            loss, grads = self._loss_and_grad(self.model, x, y)
            total_loss = total_loss + loss
            if acc_grads is None:
                acc_grads = grads
            else:
                # Sum gradient trees leaf-wise; averaged after the loop.
                acc_grads = tree_map(
                    lambda a, b: a + b,
                    acc_grads,
                    grads,
                )

        # Average gradients
        avg_grads = tree_map(lambda g: g / n, acc_grads)
        avg_loss = total_loss / n

        # Gradient clipping
        if self.config.max_grad_norm > 0:
            avg_grads, grad_norm = optim.clip_grad_norm(
                avg_grads, max_norm=self.config.max_grad_norm
            )
        else:
            flat = tree_flatten(avg_grads)
            grad_norm = mx.sqrt(sum(mx.sum(g * g) for _, g in flat))

        self.optimizer.update(self.model, avg_grads)
        return avg_loss, grad_norm

    def train_step(
        self,
        batch: tuple[mx.array, mx.array],
    ) -> dict[str, float]:
        """Execute one training step with eval boundary.

        Args:
            batch: Tuple of (input_tokens, target_tokens).

        Returns:
            Dict with 'loss', 'learning_rate', and
            'grad_norm'.
        """
        x, y = batch
        loss, grad_norm = self._step_fn(x, y)

        # Explicit eval boundary: force the lazy graph (loss, norm,
        # updated parameters, and optimizer state) to materialize.
        mx.eval(
            loss,
            grad_norm,
            self.model.parameters(),
            self.optimizer.state,
        )

        self.step += 1
        lr = self.optimizer.learning_rate
        if callable(lr):
            # Schedules are callables of the step index.
            lr = lr(self.step)

        metrics = {
            "loss": loss.item(),
            "learning_rate": float(lr),
            "grad_norm": grad_norm.item(),
        }

        for cb in self.callbacks:
            cb.on_step_end(self.step, metrics)

        return metrics

    def train_step_accumulated(
        self,
        micro_batches: list[tuple[mx.array, mx.array]],
    ) -> dict[str, float]:
        """Execute one accumulated training step.

        Averages gradients over micro-batches, then updates once.

        Args:
            micro_batches: List of (input, target) micro-batches.

        Returns:
            Dict with 'loss', 'learning_rate', and
            'grad_norm'.
        """
        loss, grad_norm = self._accumulation_step(
            micro_batches,
        )

        # Explicit eval boundary
        mx.eval(
            loss,
            grad_norm,
            self.model.parameters(),
            self.optimizer.state,
        )

        self.step += 1
        lr = self.optimizer.learning_rate
        if callable(lr):
            lr = lr(self.step)

        metrics = {
            "loss": loss.item(),
            "learning_rate": float(lr),
            "grad_norm": grad_norm.item(),
        }

        for cb in self.callbacks:
            cb.on_step_end(self.step, metrics)

        return metrics

    def train(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None = None,
    ) -> list[dict[str, Any]]:
        """Run the full training loop.

        When ``grad_accumulation_steps > 1``, collects that many
        micro-batches before each optimizer update.

        Args:
            train_data: Iterator yielding (input, target) batches.
            eval_data: Optional eval data iterator.

        Returns:
            List of per-step metrics dicts.
        """
        for cb in self.callbacks:
            cb.on_train_begin(self.config)

        history: list[dict[str, Any]] = []

        if self._accum_steps > 1:
            history = self._train_accumulated(train_data, eval_data)
        else:
            history = self._train_simple(train_data, eval_data)

        for cb in self.callbacks:
            cb.on_train_end(history)

        return history

    def _train_simple(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    ) -> list[dict[str, Any]]:
        """Training loop without gradient accumulation."""
        history: list[dict[str, Any]] = []

        for batch in train_data:
            if self.step >= self.config.max_steps:
                break

            metrics = self.train_step(batch)
            history.append(metrics)
            self._maybe_eval(eval_data, metrics)

        return history

    def _train_accumulated(
        self,
        train_data: Iterator[tuple[mx.array, mx.array]],
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    ) -> list[dict[str, Any]]:
        """Training loop with gradient accumulation."""
        history: list[dict[str, Any]] = []
        micro_batches: list[tuple[mx.array, mx.array]] = []

        for batch in train_data:
            if self.step >= self.config.max_steps:
                break

            micro_batches.append(batch)

            if len(micro_batches) == self._accum_steps:
                metrics = self.train_step_accumulated(micro_batches)
                history.append(metrics)
                self._maybe_eval(eval_data, metrics)
                micro_batches = []

        # Handle remaining micro-batches (final partial group).
        if micro_batches and self.step < self.config.max_steps:
            metrics = self.train_step_accumulated(micro_batches)
            history.append(metrics)

        return history

    def _maybe_eval(
        self,
        eval_data: Iterator[tuple[mx.array, mx.array]] | None,
        metrics: dict[str, Any],
    ) -> None:
        """Run evaluation if due at the current step."""
        if (
            eval_data is not None
            and self.step % self.config.eval_interval == 0
        ):
            eval_metrics = self.evaluate(eval_data)
            metrics.update(eval_metrics)
            for cb in self.callbacks:
                cb.on_eval_end(self.step, eval_metrics)

    def evaluate(
        self,
        eval_data: Iterator[tuple[mx.array, mx.array]],
    ) -> dict[str, float]:
        """Run evaluation over the eval dataset.

        Args:
            eval_data: Iterator yielding (input, target) batches.

        Returns:
            Dict with 'eval_loss'.
        """
        total_loss = 0.0
        n_batches = 0

        for x, y in eval_data:
            logits, _ = self.model(x)
            logits = logits.reshape(-1, logits.shape[-1])
            targets = y.reshape(-1)
            loss = nn.losses.cross_entropy(logits, targets, reduction="mean")
            mx.eval(loss)
            total_loss += loss.item()
            n_batches += 1

        # max(..., 1) guards against an empty eval iterator.
        avg_loss = total_loss / max(n_batches, 1)
        return {"eval_loss": avg_loss}

_accum_steps = config.grad_accumulation_steps instance-attribute

_loss_and_grad = nn.value_and_grad(model, _loss_fn) instance-attribute

_step_fn = mx.compile(self._single_step, inputs=(model.trainable_parameters()), outputs=(model.trainable_parameters())) instance-attribute

callbacks = callbacks or [] instance-attribute

config = config instance-attribute

model = model instance-attribute

optimizer = optimizer instance-attribute

step = 0 instance-attribute

__init__(model, config, optimizer=None, callbacks=None)

Source code in src/lmxlab/training/trainer.py
def __init__(
    self,
    model: LanguageModel,
    config: TrainConfig,
    optimizer: optim.Optimizer | None = None,
    callbacks: list[Callback] | None = None,
) -> None:
    """Set up model, optimizer, callbacks, and the step function.

    Args:
        model: Language model to train.
        config: Training configuration.
        optimizer: Optional pre-built optimizer; when omitted one is
            created from ``config`` (μP-aware if the model uses μP).
        callbacks: Optional list of callbacks.
    """
    self.model = model
    self.config = config
    if optimizer is not None:
        self.optimizer = optimizer
    elif model.config.mup_base_width is not None:
        # μP models need per-group learning rates scaled by width.
        self.optimizer = create_mup_optimizer(
            config,
            model.config.width_mult,
        )
    else:
        self.optimizer = create_optimizer(config)
    self.callbacks = callbacks or []
    self.step = 0
    self._accum_steps = config.grad_accumulation_steps

    # Build the training step function
    self._loss_and_grad = nn.value_and_grad(model, _loss_fn)

    # Always define _step_fn. Previously it was only assigned in the
    # no-accumulation branch, so a direct train_step() call with
    # grad_accumulation_steps > 1 raised AttributeError.
    self._step_fn = self._single_step

    if self._accum_steps <= 1:
        # No accumulation: compile full step (fwd + bwd + update)
        # MultiOptimizer (μP) is incompatible with mx.compile
        can_compile = config.compile_step and not isinstance(
            self.optimizer,
            optim.MultiOptimizer,
        )
        if can_compile:
            # The compiled step mutates model parameters AND optimizer
            # state; both must be tracked as implicit inputs/outputs
            # (see MLX "Compiling Training Graphs"), otherwise
            # optimizer-state updates are lost.
            state = [self.model.state, self.optimizer.state]
            self._step_fn = mx.compile(
                self._single_step,
                inputs=state,
                outputs=state,
            )
_accumulation_step(micro_batches)

Gradient accumulation step over multiple micro-batches.

Computes gradients for each micro-batch, averages them, then applies a single optimizer update.

Parameters:

Name Type Description Default
micro_batches list[tuple[array, array]]

List of (input, target) micro-batches.

required

Returns:

Type Description
tuple[array, array]

Tuple of (avg_loss, grad_norm).

Source code in src/lmxlab/training/trainer.py
def _accumulation_step(
    self,
    micro_batches: list[tuple[mx.array, mx.array]],
) -> tuple[mx.array, mx.array]:
    """Accumulate gradients over micro-batches, then update once.

    Each micro-batch contributes one forward/backward pass; the
    summed gradients are divided by the micro-batch count before
    clipping and a single optimizer update.

    Args:
        micro_batches: List of (input, target) micro-batches.

    Returns:
        Tuple of (avg_loss, grad_norm).
    """
    count = len(micro_batches)
    loss_sum = mx.array(0.0)
    grad_sum = None

    for inputs, targets in micro_batches:
        loss, grads = self._loss_and_grad(self.model, inputs, targets)
        loss_sum = loss_sum + loss
        # Sum gradient trees leaf-wise; first iteration seeds the tree.
        grad_sum = (
            grads
            if grad_sum is None
            else tree_map(lambda a, b: a + b, grad_sum, grads)
        )

    # Convert sums into means.
    mean_grads = tree_map(lambda g: g / count, grad_sum)
    mean_loss = loss_sum / count

    # Clip (or just measure) the gradient norm.
    if self.config.max_grad_norm > 0:
        mean_grads, grad_norm = optim.clip_grad_norm(
            mean_grads, max_norm=self.config.max_grad_norm
        )
    else:
        grad_norm = mx.sqrt(
            sum(mx.sum(g * g) for _, g in tree_flatten(mean_grads))
        )

    self.optimizer.update(self.model, mean_grads)
    return mean_loss, grad_norm

_maybe_eval(eval_data, metrics)

Run evaluation if due.

Source code in src/lmxlab/training/trainer.py
def _maybe_eval(
    self,
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
    metrics: dict[str, Any],
) -> None:
    """Run evaluation if due at the current step."""
    # Guard clauses: no eval data, or not at an eval boundary.
    if eval_data is None:
        return
    if self.step % self.config.eval_interval != 0:
        return
    eval_metrics = self.evaluate(eval_data)
    metrics.update(eval_metrics)
    for cb in self.callbacks:
        cb.on_eval_end(self.step, eval_metrics)

_single_step(x, y)

Single training step: forward + backward + update.

Parameters:

Name Type Description Default
x array

Input tokens (batch, seq_len).

required
y array

Target tokens (batch, seq_len).

required

Returns:

Type Description
array

Tuple of (loss, grad_norm) where grad_norm is the

array

L2 norm of gradients before clipping.

Source code in src/lmxlab/training/trainer.py
def _single_step(
    self,
    x: mx.array,
    y: mx.array,
) -> tuple[mx.array, mx.array]:
    """Run one forward/backward pass and apply the optimizer.

    Args:
        x: Input tokens (batch, seq_len).
        y: Target tokens (batch, seq_len).

    Returns:
        Tuple of (loss, grad_norm) where grad_norm is the
        L2 norm of gradients before clipping.
    """
    loss, grads = self._loss_and_grad(self.model, x, y)

    max_norm = self.config.max_grad_norm
    if max_norm > 0:
        grads, grad_norm = optim.clip_grad_norm(grads, max_norm=max_norm)
    else:
        # Clipping disabled: compute the norm manually for metrics.
        squares = [mx.sum(g * g) for _, g in tree_flatten(grads)]
        grad_norm = mx.sqrt(sum(squares))

    self.optimizer.update(self.model, grads)
    return loss, grad_norm

_train_accumulated(train_data, eval_data)

Training loop with gradient accumulation.

Source code in src/lmxlab/training/trainer.py
def _train_accumulated(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
) -> list[dict[str, Any]]:
    """Training loop with gradient accumulation."""
    history: list[dict[str, Any]] = []
    pending: list[tuple[mx.array, mx.array]] = []

    for batch in train_data:
        if self.step >= self.config.max_steps:
            break

        pending.append(batch)
        # Keep collecting until a full group of micro-batches is ready.
        if len(pending) < self._accum_steps:
            continue

        metrics = self.train_step_accumulated(pending)
        history.append(metrics)
        self._maybe_eval(eval_data, metrics)
        pending = []

    # Flush a final partial group, if any remains.
    if pending and self.step < self.config.max_steps:
        history.append(self.train_step_accumulated(pending))

    return history

_train_simple(train_data, eval_data)

Training loop without gradient accumulation.

Source code in src/lmxlab/training/trainer.py
def _train_simple(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None,
) -> list[dict[str, Any]]:
    """Training loop without gradient accumulation."""
    history: list[dict[str, Any]] = []

    for batch in train_data:
        # Stop once the configured step budget is exhausted.
        if self.step >= self.config.max_steps:
            break

        step_metrics = self.train_step(batch)
        history.append(step_metrics)
        self._maybe_eval(eval_data, step_metrics)

    return history

evaluate(eval_data)

Run evaluation over the eval dataset.

Parameters:

Name Type Description Default
eval_data Iterator[tuple[array, array]]

Iterator yielding (input, target) batches.

required

Returns:

Type Description
dict[str, float]

Dict with 'eval_loss'.

Source code in src/lmxlab/training/trainer.py
def evaluate(
    self,
    eval_data: Iterator[tuple[mx.array, mx.array]],
) -> dict[str, float]:
    """Compute the mean cross-entropy loss over the eval dataset.

    Args:
        eval_data: Iterator yielding (input, target) batches.

    Returns:
        Dict with 'eval_loss'.
    """
    losses: list[float] = []

    for x, y in eval_data:
        logits, _ = self.model(x)
        # Flatten (batch, seq, vocab) -> (batch*seq, vocab) for the loss.
        flat_logits = logits.reshape(-1, logits.shape[-1])
        flat_targets = y.reshape(-1)
        loss = nn.losses.cross_entropy(
            flat_logits, flat_targets, reduction="mean"
        )
        mx.eval(loss)
        losses.append(loss.item())

    # Empty iterator yields 0.0, matching the original max(n, 1) guard.
    mean_loss = sum(losses) / len(losses) if losses else 0.0
    return {"eval_loss": mean_loss}

train(train_data, eval_data=None)

Run the full training loop.

When grad_accumulation_steps > 1, collects that many micro-batches before each optimizer update.

Parameters:

Name Type Description Default
train_data Iterator[tuple[array, array]]

Iterator yielding (input, target) batches.

required
eval_data Iterator[tuple[array, array]] | None

Optional eval data iterator.

None

Returns:

Type Description
list[dict[str, Any]]

List of per-step metrics dicts.

Source code in src/lmxlab/training/trainer.py
def train(
    self,
    train_data: Iterator[tuple[mx.array, mx.array]],
    eval_data: Iterator[tuple[mx.array, mx.array]] | None = None,
) -> list[dict[str, Any]]:
    """Run the full training loop.

    When ``grad_accumulation_steps > 1``, collects that many
    micro-batches before each optimizer update.

    Args:
        train_data: Iterator yielding (input, target) batches.
        eval_data: Optional eval data iterator.

    Returns:
        List of per-step metrics dicts.
    """
    for cb in self.callbacks:
        cb.on_train_begin(self.config)

    # Pick the loop variant matching the accumulation setting.
    run_loop = (
        self._train_accumulated
        if self._accum_steps > 1
        else self._train_simple
    )
    history = run_loop(train_data, eval_data)

    for cb in self.callbacks:
        cb.on_train_end(history)

    return history

train_step(batch)

Execute one training step with eval boundary.

Parameters:

Name Type Description Default
batch tuple[array, array]

Tuple of (input_tokens, target_tokens).

required

Returns:

Type Description
dict[str, float]

Dict with 'loss', 'learning_rate', and

dict[str, float]

'grad_norm'.

Source code in src/lmxlab/training/trainer.py
def train_step(
    self,
    batch: tuple[mx.array, mx.array],
) -> dict[str, float]:
    """Execute one training step with eval boundary.

    Args:
        batch: Tuple of (input_tokens, target_tokens).

    Returns:
        Dict with 'loss', 'learning_rate', and 'grad_norm'.
    """
    inputs, targets = batch
    loss, grad_norm = self._step_fn(inputs, targets)

    # Explicit eval boundary: force the lazy graph (loss, norm,
    # updated parameters, optimizer state) to materialize.
    mx.eval(
        loss,
        grad_norm,
        self.model.parameters(),
        self.optimizer.state,
    )

    self.step += 1

    # Schedules are callables of the step index; plain floats are not.
    lr = self.optimizer.learning_rate
    current_lr = lr(self.step) if callable(lr) else lr

    metrics = {
        "loss": loss.item(),
        "learning_rate": float(current_lr),
        "grad_norm": grad_norm.item(),
    }

    for cb in self.callbacks:
        cb.on_step_end(self.step, metrics)

    return metrics

train_step_accumulated(micro_batches)

Execute one accumulated training step.

Averages gradients over micro-batches, then updates once.

Parameters:

Name Type Description Default
micro_batches list[tuple[array, array]]

List of (input, target) micro-batches.

required

Returns:

Type Description
dict[str, float]

Dict with 'loss', 'learning_rate', and

dict[str, float]

'grad_norm'.

Source code in src/lmxlab/training/trainer.py
def train_step_accumulated(
    self,
    micro_batches: list[tuple[mx.array, mx.array]],
) -> dict[str, float]:
    """Execute one accumulated training step.

    Averages gradients over micro-batches, then updates once.

    Args:
        micro_batches: List of (input, target) micro-batches.

    Returns:
        Dict with 'loss', 'learning_rate', and 'grad_norm'.
    """
    loss, grad_norm = self._accumulation_step(micro_batches)

    # Explicit eval boundary: force the lazy graph (loss, norm,
    # updated parameters, optimizer state) to materialize.
    mx.eval(
        loss,
        grad_norm,
        self.model.parameters(),
        self.optimizer.state,
    )

    self.step += 1

    # Schedules are callables of the step index; plain floats are not.
    lr = self.optimizer.learning_rate
    current_lr = lr(self.step) if callable(lr) else lr

    metrics = {
        "loss": loss.item(),
        "learning_rate": float(current_lr),
        "grad_norm": grad_norm.item(),
    }

    for cb in self.callbacks:
        cb.on_step_end(self.step, metrics)

    return metrics

Training Config

lmxlab.training.config.TrainConfig dataclass

Configuration for the training loop.

Parameters:

Name Type Description Default
learning_rate float

Peak learning rate.

0.0003
weight_decay float

Weight decay coefficient.

0.01
warmup_steps int

Linear warmup steps.

100
max_steps int

Maximum training steps.

1000
batch_size int

Training batch size.

32
grad_accumulation_steps int

Number of micro-batches to accumulate before an optimizer step.

1
max_grad_norm float

Maximum gradient norm for clipping.

1.0
eval_interval int

Steps between evaluations.

100
log_interval int

Steps between logging.

10
checkpoint_interval int

Steps between checkpoints.

500
optimizer str

Optimizer name ('adamw', 'lion', 'adafactor', 'sgd').

'adamw'
lr_schedule str

Learning rate schedule ('cosine', 'linear', 'constant').

'cosine'
compile_step bool

Whether to mx.compile the training step.

True
seed int

Random seed.

42
Source code in src/lmxlab/training/config.py
@dataclass(frozen=True)
class TrainConfig:
    """Configuration for the training loop.

    Immutable (frozen) value object; construct a new instance to
    change any setting.

    Args:
        learning_rate: Peak learning rate.
        weight_decay: Weight decay coefficient.
        warmup_steps: Linear warmup steps.
        max_steps: Maximum training steps.
        batch_size: Training batch size.
        grad_accumulation_steps: Number of micro-batches
            to accumulate before an optimizer step.
        max_grad_norm: Maximum gradient norm for clipping.
        eval_interval: Steps between evaluations.
        log_interval: Steps between logging.
        checkpoint_interval: Steps between checkpoints.
        optimizer: Optimizer name ('adamw', 'lion', 'adafactor', 'sgd').
        lr_schedule: Learning rate schedule ('cosine', 'linear', 'constant').
        compile_step: Whether to mx.compile the training step.
        seed: Random seed.
    """

    learning_rate: float = 3e-4  # peak LR, reached at end of warmup
    weight_decay: float = 0.01
    warmup_steps: int = 100  # linear ramp from ~0 to learning_rate
    max_steps: int = 1000
    batch_size: int = 32
    grad_accumulation_steps: int = 1  # 1 disables accumulation
    max_grad_norm: float = 1.0  # <= 0 disables clipping
    eval_interval: int = 100
    log_interval: int = 10
    checkpoint_interval: int = 500
    optimizer: str = "adamw"
    lr_schedule: str = "cosine"
    compile_step: bool = True  # ignored when accumulating or using μP
    seed: int = 42

batch_size = 32 class-attribute instance-attribute

checkpoint_interval = 500 class-attribute instance-attribute

compile_step = True class-attribute instance-attribute

eval_interval = 100 class-attribute instance-attribute

grad_accumulation_steps = 1 class-attribute instance-attribute

learning_rate = 0.0003 class-attribute instance-attribute

log_interval = 10 class-attribute instance-attribute

lr_schedule = 'cosine' class-attribute instance-attribute

max_grad_norm = 1.0 class-attribute instance-attribute

max_steps = 1000 class-attribute instance-attribute

optimizer = 'adamw' class-attribute instance-attribute

seed = 42 class-attribute instance-attribute

warmup_steps = 100 class-attribute instance-attribute

weight_decay = 0.01 class-attribute instance-attribute

__init__(learning_rate=0.0003, weight_decay=0.01, warmup_steps=100, max_steps=1000, batch_size=32, grad_accumulation_steps=1, max_grad_norm=1.0, eval_interval=100, log_interval=10, checkpoint_interval=500, optimizer='adamw', lr_schedule='cosine', compile_step=True, seed=42)

Optimizers

lmxlab.training.optimizers.create_optimizer(config)

Create an optimizer with learning rate schedule.

Parameters:

Name Type Description Default
config TrainConfig

Training configuration.

required

Returns:

Type Description
Optimizer

Configured optimizer.

Source code in src/lmxlab/training/optimizers.py
def create_optimizer(
    config: TrainConfig,
) -> optim.Optimizer:
    """Create an optimizer with learning rate schedule.

    Args:
        config: Training configuration.

    Returns:
        Configured optimizer.

    Raises:
        ValueError: If ``config.optimizer`` is not a known name.
    """
    schedule = create_schedule(config)
    name = config.optimizer

    if name == "adamw":
        return optim.AdamW(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
        )
    if name == "lion":
        return optim.Lion(
            learning_rate=schedule,
            weight_decay=config.weight_decay,
        )
    if name == "adafactor":
        return optim.Adafactor(learning_rate=schedule)
    if name == "sgd":
        return optim.SGD(
            learning_rate=schedule,
            momentum=0.9,
            weight_decay=config.weight_decay,
        )
    raise ValueError(f"Unknown optimizer: {name!r}")

lmxlab.training.optimizers.create_schedule(config)

Create a learning rate schedule.

Supports warmup + decay patterns.

Parameters:

Name Type Description Default
config TrainConfig

Training configuration.

required

Returns:

Type Description
Callable[[int], float]

Learning rate scheduler.

Source code in src/lmxlab/training/optimizers.py
def create_schedule(
    config: TrainConfig,
) -> Callable[[int], float]:
    """Create a learning rate schedule.

    Supports warmup + decay patterns: linear warmup from ~0 to the
    peak LR over ``warmup_steps``, then the configured decay.

    Args:
        config: Training configuration.

    Returns:
        Learning rate scheduler.

    Raises:
        ValueError: If ``config.lr_schedule`` is not a known name.
    """
    warmup = optim.schedulers.linear_schedule(
        init=1e-7,
        end=config.learning_rate,
        steps=config.warmup_steps,
    )
    # Guard against a non-positive decay horizon when
    # warmup_steps >= max_steps.
    decay_steps = max(config.max_steps - config.warmup_steps, 1)
    if config.lr_schedule == "cosine":
        decay = optim.schedulers.cosine_decay(
            init=config.learning_rate,
            decay_steps=decay_steps,
        )
    elif config.lr_schedule == "linear":
        decay = optim.schedulers.linear_schedule(
            init=config.learning_rate,
            end=0.0,
            steps=decay_steps,
        )
    elif config.lr_schedule == "constant":
        # join_schedules invokes each segment as a callable, so the
        # constant branch must be a function, not a bare float (a float
        # here raised TypeError on the first post-warmup step).
        decay = lambda _step, _lr=config.learning_rate: _lr  # noqa: E731
    else:
        raise ValueError(f"Unknown schedule: {config.lr_schedule!r}")

    return optim.schedulers.join_schedules(
        schedules=[warmup, decay],
        boundaries=[config.warmup_steps],
    )

Checkpoints

lmxlab.training.checkpoints.save_checkpoint(path, model, optimizer=None, step=0, metadata=None)

Save a training checkpoint.

Saves model weights as safetensors and metadata as JSON.

Parameters:

Name Type Description Default
path str | Path

Directory to save checkpoint.

required
model LanguageModel

Model to save.

required
optimizer Optimizer | None

Optional optimizer to save state.

None
step int

Current training step.

0
metadata dict[str, Any] | None

Additional metadata to save.

None
Source code in src/lmxlab/training/checkpoints.py
def save_checkpoint(
    path: str | Path,
    model: LanguageModel,
    optimizer: optim.Optimizer | None = None,
    step: int = 0,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Save a training checkpoint.

    Saves model weights as safetensors and metadata as JSON.

    Args:
        path: Directory to save checkpoint.
        model: Model to save.
        optimizer: Optional optimizer to save state.
        step: Current training step.
        metadata: Additional metadata to save.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    # Model weights: flatten the nested parameter tree to dotted keys.
    params = dict(mlx.utils.tree_flatten(model.parameters()))
    mx.save_safetensors(str(out_dir / "model.safetensors"), params)

    # Optimizer state, skipped when empty.
    if optimizer is not None:
        state = dict(mlx.utils.tree_flatten(optimizer.state))
        if state:
            mx.save_safetensors(
                str(out_dir / "optimizer.safetensors"), state
            )

    # Metadata JSON; caller-supplied entries may override 'step'.
    payload: dict[str, Any] = {"step": step}
    payload.update(metadata or {})
    (out_dir / "metadata.json").write_text(json.dumps(payload, indent=2))

lmxlab.training.checkpoints.load_checkpoint(path, model, optimizer=None)

Load a training checkpoint.

Parameters:

Name Type Description Default
path str | Path

Directory containing checkpoint.

required
model LanguageModel

Model to load weights into.

required
optimizer Optimizer | None

Optional optimizer to load state into.

None

Returns:

Type Description
dict[str, Any]

Metadata dict from the checkpoint.

Source code in src/lmxlab/training/checkpoints.py
def load_checkpoint(
    path: str | Path,
    model: LanguageModel,
    optimizer: optim.Optimizer | None = None,
) -> dict[str, Any]:
    """Load a training checkpoint.

    Args:
        path: Directory containing checkpoint.
        model: Model to load weights into.
        optimizer: Optional optimizer to load state into.

    Returns:
        Metadata dict from the checkpoint.
    """
    ckpt_dir = Path(path)

    # Restore model weights from the flat safetensors mapping.
    weights = mx.load(str(ckpt_dir / "model.safetensors"))
    model.load_weights(list(weights.items()))  # type: ignore[union-attr]

    # Restore optimizer state when present on disk.
    opt_file = ckpt_dir / "optimizer.safetensors"
    if optimizer is not None and opt_file.exists():
        flat_state = mx.load(str(opt_file))
        # Reconstruct the nested state tree from dotted flat keys.
        optimizer.state = mlx.utils.tree_unflatten(list(flat_state.items()))  # type: ignore[union-attr]

    # Metadata is optional; missing file yields an empty dict.
    meta_file = ckpt_dir / "metadata.json"
    if not meta_file.exists():
        return {}
    return json.loads(meta_file.read_text())  # type: ignore[no-any-return]

Callbacks

lmxlab.training.callbacks.Callback

Bases: Protocol

Protocol for training callbacks.

Source code in src/lmxlab/training/callbacks.py
class Callback(Protocol):
    """Protocol for training callbacks.

    Structural interface: any object providing these four hooks can
    be passed to a trainer; no inheritance is required.
    """

    # Invoked once before the training loop starts, with the run config.
    def on_train_begin(self, config: TrainConfig) -> None: ...
    # Invoked once after the loop finishes, with the full metrics history.
    def on_train_end(self, history: list[dict[str, Any]]) -> None: ...
    # Invoked after every optimization step with that step's metrics dict.
    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None: ...
    # Invoked after an evaluation pass; metrics typically carry "eval_loss".
    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None: ...

lmxlab.training.callbacks.MetricsLogger

Logs training metrics at configured intervals.

Parameters:

Name Type Description Default
log_interval int

Steps between log outputs.

10
Source code in src/lmxlab/training/callbacks.py
class MetricsLogger:
    """Prints training metrics every ``log_interval`` steps.

    Args:
        log_interval: Steps between log outputs.
    """

    def __init__(self, log_interval: int = 10) -> None:
        self.log_interval = log_interval
        self._start_time: float = 0.0

    def on_train_begin(self, config: TrainConfig) -> None:
        # Anchor the wall-clock timer for the end-of-run summary.
        self._start_time = time.monotonic()

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        duration = time.monotonic() - self._start_time
        print(f"Training complete: {len(history)} steps in {duration:.1f}s")

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        # Only report on interval boundaries; guard clause keeps the
        # hot path cheap.
        if step % self.log_interval != 0:
            return
        loss_val = metrics.get("loss", 0.0)
        lr_val = metrics.get("learning_rate", 0.0)
        print(f"step {step}: loss={loss_val:.4f}, lr={lr_val:.2e}")

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        loss_val = metrics.get("eval_loss", 0.0)
        print(f"step {step} eval: loss={loss_val:.4f}")

lmxlab.training.callbacks.EarlyStopping

Stop training when eval loss stops improving.

Parameters:

Name Type Description Default
patience int

Steps without improvement before stopping.

5
min_delta float

Minimum change to qualify as improvement.

0.001
Source code in src/lmxlab/training/callbacks.py
class EarlyStopping:
    """Signal that training should halt when eval loss plateaus.

    Sets ``should_stop`` once ``patience`` consecutive evaluations
    fail to beat the best loss by at least ``min_delta``.

    Args:
        patience: Evaluations without improvement before stopping.
        min_delta: Minimum change to qualify as improvement.
    """

    def __init__(self, patience: int = 5, min_delta: float = 0.001) -> None:
        self.patience = patience
        self.min_delta = min_delta
        self._best_loss: float = float("inf")
        self._wait: int = 0
        self.should_stop: bool = False

    def on_train_begin(self, config: TrainConfig) -> None:
        # Reset all tracking state so the instance can be reused
        # across runs.
        self._best_loss = float("inf")
        self._wait = 0
        self.should_stop = False

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        pass

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        pass

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        current = metrics.get("eval_loss", float("inf"))
        if current < self._best_loss - self.min_delta:
            # New best: record it and reset the patience counter.
            self._best_loss = current
            self._wait = 0
            return
        self._wait += 1
        if self._wait < self.patience:
            return
        self.should_stop = True
        print(
            f"Early stopping at step {step} "
            f"(best loss: {self._best_loss:.4f})"
        )

lmxlab.training.callbacks.ThroughputMonitor

Tracks training throughput (tokens/sec, steps/sec).

Reports throughput at configured intervals, useful for understanding MLX performance characteristics and comparing compiled vs uncompiled training.

Parameters:

Name Type Description Default
log_interval int

Steps between throughput reports.

10
tokens_per_step int | None

Tokens processed per step (batch_size * seq_len). If None, only reports steps/sec.

None
Source code in src/lmxlab/training/callbacks.py
class ThroughputMonitor:
    """Tracks training throughput (tokens/sec, steps/sec).

    Prints a report every ``log_interval`` steps and a summary at
    the end of training — useful for understanding MLX performance
    characteristics and comparing compiled vs uncompiled training.

    Args:
        log_interval: Steps between throughput reports.
        tokens_per_step: Tokens processed per step
            (batch_size * seq_len). If None, only reports
            steps/sec.
    """

    def __init__(
        self,
        log_interval: int = 10,
        tokens_per_step: int | None = None,
    ) -> None:
        self.log_interval = log_interval
        self.tokens_per_step = tokens_per_step
        self._step_times: list[float] = []
        self._last_time: float = 0.0
        self._total_steps: int = 0
        self._train_start: float = 0.0

    def on_train_begin(self, config: TrainConfig) -> None:
        # Reset timers and counters for a fresh run.
        now = time.monotonic()
        self._last_time = now
        self._train_start = now
        self._step_times = []
        self._total_steps = 0

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        elapsed = time.monotonic() - self._train_start
        if self._total_steps <= 0 or elapsed <= 0:
            return
        summary = (
            f"Throughput summary: {self._total_steps} steps "
            f"in {elapsed:.1f}s ({self._total_steps / elapsed:.1f} steps/s"
        )
        if self.tokens_per_step is not None:
            n_tokens = self._total_steps * self.tokens_per_step
            summary += f", {n_tokens / elapsed:.0f} tok/s"
        summary += ")"
        print(summary)

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        now = time.monotonic()
        delta = now - self._last_time
        self._last_time = now
        self._step_times.append(delta)
        self._total_steps += 1

        # Cumulative token count is injected on every step.
        if self.tokens_per_step is not None:
            metrics["tokens_processed"] = (
                self._total_steps * self.tokens_per_step
            )

        if step % self.log_interval != 0 or delta <= 0:
            return

        # Average over a recent window for smoother reporting.
        recent = self._step_times[-self.log_interval :]
        mean_dt = sum(recent) / len(recent)
        steps_sec = 1.0 / mean_dt if mean_dt > 0 else 0
        metrics["steps_per_sec"] = steps_sec

        line = f"step {step}: {steps_sec:.1f} steps/s"
        if self.tokens_per_step is not None:
            tok_sec = self.tokens_per_step * steps_sec
            metrics["tokens_per_sec"] = tok_sec
            line += f", {tok_sec:.0f} tok/s"
        print(line)

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        pass

DPO

lmxlab.training.dpo

Direct Preference Optimization (DPO) training.

dpo_loss(model, ref_model, chosen, rejected, beta=0.1)

Compute DPO loss.

DPO directly optimizes preferences without reward modeling. L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))

Parameters:

Name Type Description Default
model LanguageModel

Policy model being trained.

required
ref_model LanguageModel

Reference (frozen) model.

required
chosen array

Preferred sequence token IDs (batch, seq_len).

required
rejected array

Dispreferred sequence token IDs (batch, seq_len).

required
beta float

Temperature parameter controlling deviation from ref.

0.1

Returns:

Type Description
array

Scalar DPO loss.

Source code in src/lmxlab/training/dpo.py
def dpo_loss(
    model: LanguageModel,
    ref_model: LanguageModel,
    chosen: mx.array,
    rejected: mx.array,
    beta: float = 0.1,
) -> mx.array:
    """Compute the Direct Preference Optimization loss.

    DPO optimizes preferences directly, without reward modeling:
    L = -log(sigmoid(beta * (log_ratio_chosen - log_ratio_rejected)))

    Args:
        model: Policy model being trained.
        ref_model: Reference (frozen) model.
        chosen: Preferred sequence token IDs (batch, seq_len).
        rejected: Dispreferred sequence token IDs (batch, seq_len).
        beta: Temperature parameter controlling deviation from ref.

    Returns:
        Scalar DPO loss.
    """

    def implicit_reward(seq: mx.array) -> mx.array:
        # Beta-scaled log-prob ratio between policy and reference
        # for the next-token targets of ``seq``.
        inputs, targets = seq[:, :-1], seq[:, 1:]
        policy_logits, _ = model(inputs)
        ref_logits, _ = ref_model(inputs)
        policy_logps = _sequence_log_probs(policy_logits, targets)
        ref_logps = _sequence_log_probs(ref_logits, targets)
        return beta * (policy_logps - ref_logps)

    margin = implicit_reward(chosen) - implicit_reward(rejected)
    # -log(sigmoid(x)) == softplus(-x) == logaddexp(0, -x), which is
    # numerically stable for large |x|.
    return mx.mean(mx.logaddexp(0, -margin))

GRPO

lmxlab.training.grpo

Group Relative Policy Optimization (GRPO).

grpo_loss(model, ref_model, prompts, completions, rewards, beta=0.1, epsilon=0.2)

Compute GRPO loss.

GRPO uses group-relative rewards: for each prompt, generate multiple completions, compute rewards, normalize within the group, and optimize using a clipped surrogate objective.

Parameters:

Name Type Description Default
model LanguageModel

Policy model being trained.

required
ref_model LanguageModel

Reference (frozen) model.

required
prompts array

Prompt token IDs (batch, prompt_len).

required
completions array

Full sequences (batch, total_len).

required
rewards array

Scalar rewards per completion (batch,).

required
beta float

KL penalty coefficient.

0.1
epsilon float

Clipping range for surrogate objective.

0.2

Returns:

Type Description
array

Scalar GRPO loss.

Source code in src/lmxlab/training/grpo.py
def grpo_loss(
    model: LanguageModel,
    ref_model: LanguageModel,
    prompts: mx.array,
    completions: mx.array,
    rewards: mx.array,
    beta: float = 0.1,
    epsilon: float = 0.2,
) -> mx.array:
    """Compute GRPO loss.

    GRPO uses group-relative rewards: for each prompt, generate
    multiple completions, compute rewards, normalize within the
    group, and optimize using a clipped surrogate objective.

    Args:
        model: Policy model being trained.
        ref_model: Reference (frozen) model.
        prompts: Prompt token IDs (batch, prompt_len).
        completions: Full sequences (batch, total_len).
        rewards: Scalar rewards per completion (batch,).
        beta: KL penalty coefficient.
        epsilon: Clipping range for surrogate objective.

    Returns:
        Scalar GRPO loss.
    """
    # NOTE(review): ``prompts`` is never used below — the loss is
    # averaged over every position of ``completions``, prompt tokens
    # included. If prompt positions should be excluded from the
    # objective, a mask derived from prompts.shape[1] is needed;
    # confirm the intended behavior.

    # Normalize rewards within the group (zero mean, unit variance)
    reward_mean = mx.mean(rewards)
    # Floor the std at 1e-8 to avoid division by zero when every
    # reward in the group is identical.
    reward_std = mx.maximum(mx.std(rewards), mx.array(1e-8))
    advantages = (rewards - reward_mean) / reward_std

    # Compute log probs for completions (standard next-token shift:
    # predict position t+1 from positions <= t)
    inputs = completions[:, :-1]
    targets = completions[:, 1:]

    logits, _ = model(inputs)
    ref_logits, _ = ref_model(inputs)

    log_probs = _sequence_log_probs(logits, targets)
    ref_log_probs = _sequence_log_probs(ref_logits, targets)

    # Ratio and clipped surrogate.
    # NOTE(review): the importance ratio is taken against the frozen
    # reference model rather than a per-iteration behavior-policy
    # snapshot — confirm this matches the intended GRPO variant.
    ratio = mx.exp(log_probs - ref_log_probs)
    clipped_ratio = mx.clip(ratio, 1.0 - epsilon, 1.0 + epsilon)

    surrogate = mx.minimum(
        ratio * advantages,
        clipped_ratio * advantages,
    )

    # KL penalty (per-sequence estimate: ref_log_probs - log_probs)
    kl = ref_log_probs - log_probs

    # Maximize (surrogate - beta * KL)  =>  minimize its negation.
    loss = -(surrogate - beta * kl)
    return mx.mean(loss)

GRPO Trainer

lmxlab.training.grpo_trainer.GRPOConfig dataclass

Configuration for GRPO training.

Parameters:

Name Type Description Default
group_size int

Number of completions per prompt.

4
max_gen_tokens int

Maximum tokens to generate per completion.

256
temperature float

Sampling temperature for generation.

0.8
beta float

KL penalty coefficient.

0.1
epsilon float

Clipping range for surrogate objective.

0.2
learning_rate float

Optimizer learning rate.

1e-05
max_grad_norm float

Maximum gradient norm for clipping.

1.0
Source code in src/lmxlab/training/grpo_trainer.py
@dataclass(frozen=True)
class GRPOConfig:
    """Configuration for GRPO training.

    Immutable (frozen) value object; construct a new instance to
    change a setting.

    Args:
        group_size: Number of completions per prompt.
        max_gen_tokens: Maximum tokens to generate per completion.
        temperature: Sampling temperature for generation.
        beta: KL penalty coefficient.
        epsilon: Clipping range for surrogate objective.
        learning_rate: Optimizer learning rate.
        max_grad_norm: Maximum gradient norm for clipping.
    """

    group_size: int = 4  # completions sampled per prompt (advantage group)
    max_gen_tokens: int = 256  # generation length cap per completion
    temperature: float = 0.8  # sampling temperature during rollout
    beta: float = 0.1  # weight of the KL penalty vs. the reference model
    epsilon: float = 0.2  # clip range for the surrogate ratio
    learning_rate: float = 1e-5
    max_grad_norm: float = 1.0  # gradients are clipped to this norm

beta = 0.1 class-attribute instance-attribute

epsilon = 0.2 class-attribute instance-attribute

group_size = 4 class-attribute instance-attribute

learning_rate = 1e-05 class-attribute instance-attribute

max_gen_tokens = 256 class-attribute instance-attribute

max_grad_norm = 1.0 class-attribute instance-attribute

temperature = 0.8 class-attribute instance-attribute

__init__(group_size=4, max_gen_tokens=256, temperature=0.8, beta=0.1, epsilon=0.2, learning_rate=1e-05, max_grad_norm=1.0)

lmxlab.training.grpo_trainer.GRPOTrainer

GRPO training loop.

Generates completions, scores them with a reward function, and optimizes the policy model using clipped surrogate GRPO.

Parameters:

Name Type Description Default
model LanguageModel

Policy model to train.

required
ref_model LanguageModel

Frozen reference model for KL penalty.

required
config GRPOConfig

GRPO training configuration.

required
reward_fn Callable[[array, array], float]

Callable that scores (prompt, completion) pairs. Takes two mx.array arguments and returns a scalar float reward.

required
optimizer Optimizer

MLX optimizer instance.

required
callbacks list[Callback] | None

Optional list of training callbacks.

None
Example

trainer = GRPOTrainer(
    model, ref_model, GRPOConfig(),
    reward_fn=lambda p, c: 1.0,
    optimizer=optim.Adam(learning_rate=1e-5),
)
trainer.train(prompt_iterator, n_steps=100)

Source code in src/lmxlab/training/grpo_trainer.py
class GRPOTrainer:
    """GRPO training loop.

    Generates completions, scores them with a reward function,
    and optimizes the policy model using clipped surrogate GRPO.

    Args:
        model: Policy model to train.
        ref_model: Frozen reference model for KL penalty.
        config: GRPO training configuration.
        reward_fn: Callable that scores (prompt, completion)
            pairs. Takes two ``mx.array`` arguments and returns
            a scalar ``float`` reward.
        optimizer: MLX optimizer instance.
        callbacks: Optional list of training callbacks. A callback
            exposing a truthy ``should_stop`` attribute (e.g.
            ``EarlyStopping``) ends training early.

    Example:
        >>> trainer = GRPOTrainer(
        ...     model, ref_model, GRPOConfig(),
        ...     reward_fn=lambda p, c: 1.0,
        ...     optimizer=optim.Adam(learning_rate=1e-5),
        ... )
        >>> trainer.train(prompt_iterator, n_steps=100)
    """

    def __init__(
        self,
        model: LanguageModel,
        ref_model: LanguageModel,
        config: GRPOConfig,
        reward_fn: Callable[[mx.array, mx.array], float],
        optimizer: optim.Optimizer,
        callbacks: list[Callback] | None = None,
    ) -> None:
        self.model = model
        self.ref_model = ref_model
        self.config = config
        self.reward_fn = reward_fn
        self.optimizer = optimizer
        self.callbacks = callbacks or []

        # Build value_and_grad for the GRPO loss
        self._loss_and_grad = nn.value_and_grad(model, self._compute_loss)

    def _expand_prompt(self, prompt: mx.array) -> mx.array:
        """Materialize ``group_size`` copies of a single prompt.

        Shared by generation and the loss call so the two code
        paths cannot drift apart.

        Args:
            prompt: Single prompt (1, prompt_len).

        Returns:
            Copied prompts (group_size, prompt_len).
        """
        expanded = mx.broadcast_to(
            prompt,
            (self.config.group_size, prompt.shape[1]),
        )
        # Force a copy so the broadcast view doesn't interfere
        # with downstream ops.
        return mx.array(expanded)

    def _compute_loss(
        self,
        model: LanguageModel,
        prompts: mx.array,
        completions: mx.array,
        rewards: mx.array,
    ) -> mx.array:
        """Compute GRPO loss for gradient computation.

        Thin adapter: ``model`` is supplied by value_and_grad;
        hyperparameters are bound from the trainer config.

        Args:
            model: Policy model (passed by value_and_grad).
            prompts: Prompt tokens (group_size, prompt_len).
            completions: Full sequences
                (group_size, total_len).
            rewards: Scalar rewards (group_size,).

        Returns:
            Scalar GRPO loss.
        """
        return grpo_loss(
            model,
            self.ref_model,
            prompts,
            completions,
            rewards,
            beta=self.config.beta,
            epsilon=self.config.epsilon,
        )

    def _generate_completions(
        self,
        prompt: mx.array,
    ) -> mx.array:
        """Generate group_size completions for a prompt.

        Args:
            prompt: Single prompt (1, prompt_len).

        Returns:
            Full sequences (group_size, prompt_len + gen_len).
        """
        prompts = self._expand_prompt(prompt)
        return generate(
            self.model,
            prompts,
            max_tokens=self.config.max_gen_tokens,
            temperature=self.config.temperature,
        )

    def _score_completions(
        self,
        prompt: mx.array,
        completions: mx.array,
    ) -> mx.array:
        """Score completions with the reward function.

        Args:
            prompt: Single prompt (1, prompt_len).
            completions: Full sequences
                (group_size, total_len).

        Returns:
            Rewards (group_size,).
        """
        # reward_fn sees the prompt without its batch dim and one
        # completion row at a time.
        rewards = [
            self.reward_fn(prompt[0], completions[i])
            for i in range(completions.shape[0])
        ]
        return mx.array(rewards)

    def train(
        self,
        prompt_iterator: Iterator[mx.array],
        n_steps: int,
    ) -> list[dict[str, Any]]:
        """Run the GRPO training loop.

        Args:
            prompt_iterator: Yields prompt tensors of shape
                (1, prompt_len).
            n_steps: Number of optimization steps.

        Returns:
            List of per-step metrics dicts (may be shorter than
            ``n_steps`` if a callback requests an early stop).
        """
        history: list[dict[str, Any]] = []

        # NOTE(review): Callback.on_train_begin expects a TrainConfig;
        # this trainer has none, so None is passed. Callbacks used here
        # must tolerate a None config.
        for cb in self.callbacks:
            cb.on_train_begin(None)

        for step in range(1, n_steps + 1):
            prompt = next(prompt_iterator)
            if prompt.ndim == 1:
                # Accept unbatched prompts: add a leading batch dim.
                prompt = prompt[None, :]

            # Generate completions and force evaluation before the
            # Python-side reward function reads them.
            completions = self._generate_completions(prompt)
            mx.eval(completions)

            # Score with reward function
            rewards = self._score_completions(prompt, completions)

            # Expand prompt to match group_size for the loss call
            prompts = self._expand_prompt(prompt)

            # Compute loss and gradients
            loss, grads = self._loss_and_grad(
                self.model,
                prompts,
                completions,
                rewards,
            )

            # Clip gradients
            grads, grad_norm = optim.clip_grad_norm(
                grads, max_norm=self.config.max_grad_norm
            )

            # Update model
            self.optimizer.update(self.model, grads)
            mx.eval(self.model.parameters(), self.optimizer.state)

            metrics = {
                "loss": loss.item(),
                "grad_norm": grad_norm.item(),
                "mean_reward": mx.mean(rewards).item(),
            }
            history.append(metrics)

            for cb in self.callbacks:
                cb.on_step_end(step, metrics)

            # Honor callbacks that request an early halt
            # (e.g. EarlyStopping.should_stop).
            if any(
                getattr(cb, "should_stop", False) for cb in self.callbacks
            ):
                break

        for cb in self.callbacks:
            cb.on_train_end(history)

        return history

_loss_and_grad = nn.value_and_grad(model, self._compute_loss) instance-attribute

callbacks = callbacks or [] instance-attribute

config = config instance-attribute

model = model instance-attribute

optimizer = optimizer instance-attribute

ref_model = ref_model instance-attribute

reward_fn = reward_fn instance-attribute

__init__(model, ref_model, config, reward_fn, optimizer, callbacks=None)

Source code in src/lmxlab/training/grpo_trainer.py
def __init__(
    self,
    model: LanguageModel,
    ref_model: LanguageModel,
    config: GRPOConfig,
    reward_fn: Callable[[mx.array, mx.array], float],
    optimizer: optim.Optimizer,
    callbacks: list[Callback] | None = None,
) -> None:
    """Store collaborators and build the gradient function.

    ``ref_model`` is treated as frozen: this class never updates it.
    """
    self.model = model
    self.ref_model = ref_model
    self.config = config
    self.reward_fn = reward_fn
    self.optimizer = optimizer
    # Normalize None to an empty list so iteration is always safe.
    self.callbacks = callbacks or []

    # Build value_and_grad for the GRPO loss
    self._loss_and_grad = nn.value_and_grad(model, self._compute_loss)

_compute_loss(model, prompts, completions, rewards)

Compute GRPO loss for gradient computation.

Parameters:

Name Type Description Default
model LanguageModel

Policy model (passed by value_and_grad).

required
prompts array

Prompt tokens (group_size, prompt_len).

required
completions array

Full sequences (group_size, total_len).

required
rewards array

Scalar rewards (group_size,).

required

Returns:

Type Description
array

Scalar GRPO loss.

Source code in src/lmxlab/training/grpo_trainer.py
def _compute_loss(
    self,
    model: LanguageModel,
    prompts: mx.array,
    completions: mx.array,
    rewards: mx.array,
) -> mx.array:
    """Compute GRPO loss for gradient computation.

    Args:
        model: Policy model (passed by value_and_grad).
        prompts: Prompt tokens (group_size, prompt_len).
        completions: Full sequences
            (group_size, total_len).
        rewards: Scalar rewards (group_size,).

    Returns:
        Scalar GRPO loss.
    """
    # Thin adapter: value_and_grad supplies ``model``; beta/epsilon
    # hyperparameters are bound from the trainer's config.
    return grpo_loss(
        model,
        self.ref_model,
        prompts,
        completions,
        rewards,
        beta=self.config.beta,
        epsilon=self.config.epsilon,
    )

_generate_completions(prompt)

Generate group_size completions for a prompt.

Parameters:

Name Type Description Default
prompt array

Single prompt (1, prompt_len).

required

Returns:

Type Description
array

Full sequences (group_size, prompt_len + gen_len).

Source code in src/lmxlab/training/grpo_trainer.py
def _generate_completions(
    self,
    prompt: mx.array,
) -> mx.array:
    """Generate group_size completions for a prompt.

    Args:
        prompt: Single prompt (1, prompt_len).

    Returns:
        Full sequences (group_size, prompt_len + gen_len).
    """
    # Expand prompt to group_size copies
    prompts = mx.broadcast_to(
        prompt,
        (self.config.group_size, prompt.shape[1]),
    )
    # Force copy so broadcast doesn't interfere
    prompts = mx.array(prompts)

    # One batched call samples all group_size rows; temperature
    # presumably makes the rows diverge — confirm generate() samples
    # each row independently.
    completions = generate(
        self.model,
        prompts,
        max_tokens=self.config.max_gen_tokens,
        temperature=self.config.temperature,
    )
    return completions

_score_completions(prompt, completions)

Score completions with the reward function.

Parameters:

Name Type Description Default
prompt array

Single prompt (1, prompt_len).

required
completions array

Full sequences (group_size, total_len).

required

Returns:

Type Description
array

Rewards (group_size,).

Source code in src/lmxlab/training/grpo_trainer.py
def _score_completions(
    self,
    prompt: mx.array,
    completions: mx.array,
) -> mx.array:
    """Score completions with the reward function.

    Args:
        prompt: Single prompt (1, prompt_len).
        completions: Full sequences
            (group_size, total_len).

    Returns:
        Rewards (group_size,).
    """
    rewards = []
    # reward_fn receives the prompt without its batch dimension and
    # one completion row at a time; scores are collected into a
    # single (group_size,) array.
    for i in range(completions.shape[0]):
        r = self.reward_fn(prompt[0], completions[i])
        rewards.append(r)
    return mx.array(rewards)

train(prompt_iterator, n_steps)

Run the GRPO training loop.

Parameters:

Name Type Description Default
prompt_iterator Iterator[array]

Yields prompt tensors of shape (1, prompt_len).

required
n_steps int

Number of optimization steps.

required

Returns:

Type Description
list[dict[str, Any]]

List of per-step metrics dicts.

Source code in src/lmxlab/training/grpo_trainer.py
def train(
    self,
    prompt_iterator: Iterator[mx.array],
    n_steps: int,
) -> list[dict[str, Any]]:
    """Run the GRPO training loop.

    Args:
        prompt_iterator: Yields prompt tensors of shape
            (1, prompt_len).
        n_steps: Number of optimization steps.

    Returns:
        List of per-step metrics dicts.
    """
    history: list[dict[str, Any]] = []

    # NOTE(review): Callback.on_train_begin expects a TrainConfig;
    # None is passed here, so callbacks used with this trainer must
    # tolerate a None config.
    for cb in self.callbacks:
        cb.on_train_begin(None)

    for step in range(1, n_steps + 1):
        prompt = next(prompt_iterator)
        if prompt.ndim == 1:
            # Accept unbatched prompts: add a leading batch dim.
            prompt = prompt[None, :]

        # Generate completions; mx.eval forces materialization
        # before the Python-side reward function reads them.
        completions = self._generate_completions(prompt)
        mx.eval(completions)

        # Score with reward function
        rewards = self._score_completions(prompt, completions)

        # Expand prompt to match group_size (same broadcast+copy as
        # in _generate_completions)
        prompts = mx.broadcast_to(
            prompt,
            (self.config.group_size, prompt.shape[1]),
        )
        prompts = mx.array(prompts)

        # Compute loss and gradients
        loss, grads = self._loss_and_grad(
            self.model,
            prompts,
            completions,
            rewards,
        )

        # Clip gradients
        grads, grad_norm = optim.clip_grad_norm(
            grads, max_norm=self.config.max_grad_norm
        )

        # Update model
        self.optimizer.update(self.model, grads)
        mx.eval(self.model.parameters(), self.optimizer.state)

        # .item() calls synchronize lazy values into Python floats.
        metrics = {
            "loss": loss.item(),
            "grad_norm": grad_norm.item(),
            "mean_reward": mx.mean(rewards).item(),
        }
        history.append(metrics)

        for cb in self.callbacks:
            cb.on_step_end(step, metrics)

    for cb in self.callbacks:
        cb.on_train_end(history)

    return history

Multi-Token Prediction

lmxlab.training.mtp.MTPHead

Bases: Module

Single multi-token prediction head.

Takes hidden states and previous target embeddings, normalizes both, concatenates, projects back to d_model, then runs through a transformer block.

Parameters:

Name Type Description Default
d_model int

Hidden dimension.

required
block_config BlockConfig

Block configuration for the MTP block.

required
Source code in src/lmxlab/training/mtp.py
class MTPHead(nn.Module):
    """Single multi-token prediction head.

    Normalizes the trunk hidden states and the embeddings of the
    previous target tokens separately, concatenates them, projects
    back to d_model, then runs the result through one transformer
    block.

    Args:
        d_model: Hidden dimension.
        block_config: Block configuration for the MTP block.
    """

    def __init__(
        self,
        d_model: int,
        block_config: BlockConfig,
    ) -> None:
        super().__init__()
        # Separate norms for the two input streams.
        self.hidden_norm = nn.RMSNorm(d_model)
        self.embed_norm = nn.RMSNorm(d_model)
        # Projects the concatenated (2 * d_model) features down.
        self.proj = nn.Linear(
            2 * d_model,
            d_model,
            bias=False,
        )
        self.block = ConfigurableBlock(block_config)

    def __call__(
        self,
        h: mx.array,
        prev_embed: mx.array,
        mask: mx.array | None = None,
    ) -> mx.array:
        """Produce hidden states for future token prediction.

        Args:
            h: Hidden states (batch, seq_len, d_model).
            prev_embed: Embeddings of previous target tokens
                (batch, seq_len, d_model).
            mask: Optional causal mask.

        Returns:
            Hidden states (batch, seq_len, d_model).
        """
        streams = [self.hidden_norm(h), self.embed_norm(prev_embed)]
        fused = self.proj(mx.concatenate(streams, axis=-1))
        out, _ = self.block(fused, mask=mask)
        return out

__init__(d_model, block_config)

Source code in src/lmxlab/training/mtp.py
def __init__(
    self,
    d_model: int,
    block_config: BlockConfig,
) -> None:
    """Build norms, down-projection, and the MTP transformer block."""
    super().__init__()
    # Separate norms for trunk hidden states and target embeddings.
    self.hidden_norm = nn.RMSNorm(d_model)
    self.embed_norm = nn.RMSNorm(d_model)
    # Projects the concatenated (2 * d_model) features back to d_model.
    self.proj = nn.Linear(
        2 * d_model,
        d_model,
        bias=False,
    )
    self.block = ConfigurableBlock(block_config)

__call__(h, prev_embed, mask=None)

Produce hidden states for future token prediction.

Parameters:

Name Type Description Default
h array

Hidden states (batch, seq_len, d_model).

required
prev_embed array

Embeddings of previous target tokens (batch, seq_len, d_model).

required
mask array | None

Optional causal mask.

None

Returns:

Type Description
array

Hidden states (batch, seq_len, d_model).

Source code in src/lmxlab/training/mtp.py
def __call__(
    self,
    h: mx.array,
    prev_embed: mx.array,
    mask: mx.array | None = None,
) -> mx.array:
    """Produce hidden states for future token prediction.

    Args:
        h: Hidden states (batch, seq_len, d_model).
        prev_embed: Embeddings of previous target tokens
            (batch, seq_len, d_model).
        mask: Optional causal mask.

    Returns:
        Hidden states (batch, seq_len, d_model).
    """
    # Normalize each stream independently before fusing them along
    # the feature axis (result has 2 * d_model features).
    combined = mx.concatenate(
        [
            self.hidden_norm(h),
            self.embed_norm(prev_embed),
        ],
        axis=-1,
    )
    # Project back to d_model, then refine with one transformer block.
    projected = self.proj(combined)
    out, _ = self.block(projected, mask=mask)
    return out

lmxlab.training.mtp.MultiTokenPrediction

Bases: Module

Multi-Token Prediction wrapper around a LanguageModel.

Adds n_predict auxiliary prediction heads that predict future tokens at each position. Shares the base model's lm_head for logit projection.

Training-only module. At inference time, use the base model directly.

Parameters:

Name Type Description Default
model LanguageModel

Base language model.

required
n_predict int

Number of future tokens to predict.

2
mtp_weight float

Weight for auxiliary MTP losses.

0.3
block_config BlockConfig | None

Block config for MTP heads. If None, uses the base model's block config.

None
Source code in src/lmxlab/training/mtp.py
class MultiTokenPrediction(nn.Module):
    """Multi-Token Prediction wrapper around a LanguageModel.

    Adds n_predict auxiliary prediction heads that predict
    future tokens at each position. Shares the base model's
    lm_head for logit projection.

    Training-only module. At inference time, use the base
    model directly.

    Args:
        model: Base language model.
        n_predict: Number of future tokens to predict.
        mtp_weight: Weight for auxiliary MTP losses.
        block_config: Block config for MTP heads. If None, a
            lightweight default is built (half the base model's
            heads, 2x FFN) — not the base model's block config.
    """

    def __init__(
        self,
        model: LanguageModel,
        n_predict: int = 2,
        mtp_weight: float = 0.3,
        block_config: BlockConfig | None = None,
    ) -> None:
        super().__init__()
        self.model = model
        self.n_predict = n_predict
        self.mtp_weight = mtp_weight

        d_model = model.config.block.d_model

        # Default MTP block: lightweight attention block with
        # half the base model's heads and a 2x FFN expansion.
        if block_config is None:
            block_config = BlockConfig(
                attention="mha",
                ffn="standard",
                norm="rms_norm",
                position="none",
                d_model=d_model,
                n_heads=max(1, model.config.block.n_heads // 2),
                d_ff=d_model * 2,
                bias=False,
                pre_norm=True,
            )

        # NOTE(review): stored as a plain list — assumes MLX's
        # nn.Module registers modules held in lists; confirm.
        self.mtp_heads = [
            MTPHead(d_model, block_config) for _ in range(n_predict)
        ]

    def _project_logits(self, h: mx.array) -> mx.array:
        """Project hidden states to logits using shared head.

        Args:
            h: Hidden states (batch, seq_len, d_model).

        Returns:
            Logits (batch, seq_len, vocab_size).
        """
        # With tied embeddings the embedding matrix doubles as
        # the output projection; otherwise use the separate head.
        if self.model.config.tie_embeddings:
            return h @ self.model.embed.weight.T
        return self.model.head(h)

    def __call__(
        self,
        x: mx.array,
        targets: mx.array,
    ) -> tuple[mx.array, dict[str, mx.array]]:
        """Forward pass with multi-token prediction.

        Args:
            x: Input token IDs (batch, seq_len).
            targets: Target token IDs (batch, seq_len).
                For MTP depth k, predicts target at t+k.

        Returns:
            Tuple of (main_logits, loss_dict) where loss_dict
            contains 'main_loss', 'mtp_loss', 'total_loss'.
        """
        # Forward with hidden states (last-layer activations)
        logits, _, hidden = self.model(
            x,
            return_hidden=True,
        )

        # Main loss (next token prediction); targets are trimmed
        # to match the shifted logits length.
        main_logits = logits[:, :-1, :]
        main_targets = targets[:, : main_logits.shape[1]]
        main_loss = nn.losses.cross_entropy(
            main_logits.reshape(-1, main_logits.shape[-1]),
            main_targets.reshape(-1),
            reduction="mean",
        )

        # MTP losses for each prediction depth
        mtp_losses = []
        h = hidden
        for k, head in enumerate(self.mtp_heads, start=1):
            # Sequence too short for this depth; deeper heads
            # would need even more context, so they skip too.
            if targets.shape[1] <= k:
                continue

            # Hidden states for positions with enough future
            h_slice = h[:, :-k, :]

            # Embeddings of tokens at position t+k-1
            prev_toks = targets[:, k - 1 : -1]
            if prev_toks.shape[1] > h_slice.shape[1]:
                prev_toks = prev_toks[:, : h_slice.shape[1]]
            prev_embed = self.model.embed(prev_toks)

            # MTP head produces new hidden states
            h_mtp = head(h_slice, prev_embed)

            # Project to logits using shared lm_head
            mtp_logits = self._project_logits(h_mtp)
            mtp_targets = targets[:, k : k + mtp_logits.shape[1]]

            if mtp_targets.shape[1] > 0:
                mtp_loss = nn.losses.cross_entropy(
                    mtp_logits.reshape(
                        -1,
                        mtp_logits.shape[-1],
                    ),
                    mtp_targets.reshape(-1),
                    reduction="mean",
                )
                mtp_losses.append(mtp_loss)

            # Chain: next head uses this head's output.
            # NOTE(review): h shrinks by k tokens per depth and
            # is sliced again next iteration — verify index
            # alignment for n_predict > 2.
            h = h_mtp

        # Combine losses; if no head could run (sequence too
        # short), the auxiliary loss is zero.
        if mtp_losses:
            avg_mtp_loss = sum(mtp_losses) / len(mtp_losses)
        else:
            avg_mtp_loss = mx.array(0.0)

        total_loss = main_loss + self.mtp_weight * avg_mtp_loss

        losses = {
            "main_loss": main_loss,
            "mtp_loss": avg_mtp_loss,
            "total_loss": total_loss,
        }

        return logits, losses  # type: ignore[return-value]

__init__(model, n_predict=2, mtp_weight=0.3, block_config=None)

Source code in src/lmxlab/training/mtp.py
def __init__(
    self,
    model: LanguageModel,
    n_predict: int = 2,
    mtp_weight: float = 0.3,
    block_config: BlockConfig | None = None,
) -> None:
    """Initialize the MTP wrapper around a base model.

    Args:
        model: Base language model.
        n_predict: Number of future tokens to predict.
        mtp_weight: Weight for auxiliary MTP losses.
        block_config: Block config for MTP heads. When None, a
            lightweight default attention block is constructed.
    """
    super().__init__()
    self.model = model
    self.n_predict = n_predict
    self.mtp_weight = mtp_weight

    hidden_size = model.config.block.d_model

    # No explicit config: build a lightweight attention block
    # with half the base model's heads and a 2x FFN expansion.
    if block_config is None:
        block_config = BlockConfig(
            attention="mha",
            ffn="standard",
            norm="rms_norm",
            position="none",
            d_model=hidden_size,
            n_heads=max(1, model.config.block.n_heads // 2),
            d_ff=hidden_size * 2,
            bias=False,
            pre_norm=True,
        )

    # One auxiliary head per predicted future token.
    self.mtp_heads = [
        MTPHead(hidden_size, block_config) for _ in range(n_predict)
    ]

__call__(x, targets)

Forward pass with multi-token prediction.

Parameters:

Name Type Description Default
x array

Input token IDs (batch, seq_len).

required
targets array

Target token IDs (batch, seq_len). For MTP depth k, predicts target at t+k.

required

Returns:

Type Description
array

Tuple of (main_logits, loss_dict) where loss_dict

dict[str, array]

contains 'main_loss', 'mtp_loss', 'total_loss'.

Source code in src/lmxlab/training/mtp.py
def __call__(
    self,
    x: mx.array,
    targets: mx.array,
) -> tuple[mx.array, dict[str, mx.array]]:
    """Forward pass with multi-token prediction.

    Args:
        x: Input token IDs (batch, seq_len).
        targets: Target token IDs (batch, seq_len).
            For MTP depth k, predicts target at t+k.

    Returns:
        Tuple of (main_logits, loss_dict) where loss_dict
        contains 'main_loss', 'mtp_loss', 'total_loss'.
    """
    # Forward with hidden states (last-layer activations)
    logits, _, hidden = self.model(
        x,
        return_hidden=True,
    )

    # Main loss (next token prediction); targets are trimmed
    # to match the shifted logits length.
    main_logits = logits[:, :-1, :]
    main_targets = targets[:, : main_logits.shape[1]]
    main_loss = nn.losses.cross_entropy(
        main_logits.reshape(-1, main_logits.shape[-1]),
        main_targets.reshape(-1),
        reduction="mean",
    )

    # MTP losses for each prediction depth
    mtp_losses = []
    h = hidden
    for k, head in enumerate(self.mtp_heads, start=1):
        # Sequence too short for this depth; deeper heads
        # would need even more context, so they skip too.
        if targets.shape[1] <= k:
            continue

        # Hidden states for positions with enough future
        h_slice = h[:, :-k, :]

        # Embeddings of tokens at position t+k-1
        prev_toks = targets[:, k - 1 : -1]
        if prev_toks.shape[1] > h_slice.shape[1]:
            prev_toks = prev_toks[:, : h_slice.shape[1]]
        prev_embed = self.model.embed(prev_toks)

        # MTP head produces new hidden states
        h_mtp = head(h_slice, prev_embed)

        # Project to logits using shared lm_head
        mtp_logits = self._project_logits(h_mtp)
        mtp_targets = targets[:, k : k + mtp_logits.shape[1]]

        if mtp_targets.shape[1] > 0:
            mtp_loss = nn.losses.cross_entropy(
                mtp_logits.reshape(
                    -1,
                    mtp_logits.shape[-1],
                ),
                mtp_targets.reshape(-1),
                reduction="mean",
            )
            mtp_losses.append(mtp_loss)

        # Chain: next head uses this head's output.
        # NOTE(review): h shrinks by k tokens per depth and is
        # sliced again next iteration — verify index alignment
        # for n_predict > 2.
        h = h_mtp

    # Combine losses; if no head could run (sequence too
    # short), the auxiliary loss is zero.
    if mtp_losses:
        avg_mtp_loss = sum(mtp_losses) / len(mtp_losses)
    else:
        avg_mtp_loss = mx.array(0.0)

    total_loss = main_loss + self.mtp_weight * avg_mtp_loss

    losses = {
        "main_loss": main_loss,
        "mtp_loss": avg_mtp_loss,
        "total_loss": total_loss,
    }

    return logits, losses  # type: ignore[return-value]

Curriculum Learning

lmxlab.training.curriculum

Curriculum learning utilities.

difficulty_curriculum(easy_data, hard_data, batch_size, seq_len, n_batches=200, warmup_fraction=0.5)

Mix easy and hard data with increasing difficulty.

Starts with mostly easy data and transitions to hard data.

Parameters:

Name Type Description Default
easy_data array

Token array of easier text.

required
hard_data array

Token array of harder text.

required
batch_size int

Sequences per batch.

required
seq_len int

Sequence length.

required
n_batches int

Total number of batches.

200
warmup_fraction float

Fraction of training spent warming up.

0.5

Yields:

Type Description
tuple[array, array]

(input, target) tuples with mixed difficulty.

Source code in src/lmxlab/training/curriculum.py
def difficulty_curriculum(
    easy_data: mx.array,
    hard_data: mx.array,
    batch_size: int,
    seq_len: int,
    n_batches: int = 200,
    warmup_fraction: float = 0.5,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Mix easy and hard data with increasing difficulty.

    Starts with mostly easy data and transitions to hard data.
    The hard-data fraction grows linearly from 0 to 1 over the
    first ``warmup_fraction`` of training, then stays at 1.

    Args:
        easy_data: Token array of easier text.
        hard_data: Token array of harder text.
        batch_size: Sequences per batch.
        seq_len: Sequence length.
        n_batches: Total number of batches.
        warmup_fraction: Fraction of training spent warming up.
            Non-positive values mean "all hard from the start"
            (previously raised ZeroDivisionError for 0).

    Yields:
        (input, target) tuples with mixed difficulty.
    """

    def _sample_windows(
        data: mx.array,
        count: int,
        inputs: list,
        targets: list,
    ) -> None:
        # Append `count` random (input, shifted-target) windows.
        starts = mx.random.randint(
            0, len(data) - seq_len - 1, shape=(count,)
        )
        mx.eval(starts)
        for s_val in cast(list[int], starts.tolist()):
            inputs.append(data[s_val : s_val + seq_len])
            targets.append(data[s_val + 1 : s_val + seq_len + 1])

    for i in range(n_batches):
        # Hard data fraction increases linearly during warmup.
        progress = i / max(n_batches - 1, 1)
        if warmup_fraction > 0:
            hard_fraction = min(progress / warmup_fraction, 1.0)
        else:
            # Bug fix: warmup_fraction == 0 used to divide by zero;
            # interpret it as "no warmup phase".
            hard_fraction = 1.0
        n_hard = int(batch_size * hard_fraction)
        n_easy = batch_size - n_hard

        inputs_list: list[mx.array] = []
        targets_list: list[mx.array] = []

        # Sample from easy data
        if n_easy > 0:
            _sample_windows(easy_data, n_easy, inputs_list, targets_list)

        # Sample from hard data
        if n_hard > 0:
            _sample_windows(hard_data, n_hard, inputs_list, targets_list)

        yield mx.stack(inputs_list), mx.stack(targets_list)

length_curriculum(tokens, batch_size, min_seq_len=32, max_seq_len=512, n_stages=4, batches_per_stage=100)

Generate batches with increasing sequence length.

Starts with short sequences and gradually increases, following curriculum learning principles.

Parameters:

Name Type Description Default
tokens array

Flat array of token IDs.

required
batch_size int

Sequences per batch.

required
min_seq_len int

Starting sequence length.

32
max_seq_len int

Final sequence length.

512
n_stages int

Number of curriculum stages.

4
batches_per_stage int

Batches per stage.

100

Yields:

Type Description
tuple[array, array]

(input, target) tuples with progressively longer sequences.

Source code in src/lmxlab/training/curriculum.py
def length_curriculum(
    tokens: mx.array,
    batch_size: int,
    min_seq_len: int = 32,
    max_seq_len: int = 512,
    n_stages: int = 4,
    batches_per_stage: int = 100,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Generate batches with increasing sequence length.

    Starts with short sequences and gradually increases,
    following curriculum learning principles.

    Args:
        tokens: Flat array of token IDs.
        batch_size: Sequences per batch.
        min_seq_len: Starting sequence length.
        max_seq_len: Final sequence length.
        n_stages: Number of curriculum stages.
        batches_per_stage: Batches per stage.

    Yields:
        (input, target) tuples with progressively longer sequences.
    """
    n_tokens = len(tokens)
    span = max_seq_len - min_seq_len
    denom = max(n_stages - 1, 1)

    for stage in range(n_stages):
        # Sequence length interpolated linearly across stages.
        seq_len = int(min_seq_len + (stage / denom) * span)

        # Skip stages where the data cannot fill a single batch.
        if (n_tokens - 1) // seq_len < batch_size:
            continue

        for _ in range(batches_per_stage):
            # Random starting positions for each sequence.
            starts = mx.random.randint(
                0, n_tokens - seq_len - 1, shape=(batch_size,)
            )
            mx.eval(starts)

            batch_inputs = []
            batch_targets = []
            for pos in cast(list[int], starts.tolist()):
                batch_inputs.append(tokens[pos : pos + seq_len])
                batch_targets.append(tokens[pos + 1 : pos + seq_len + 1])

            yield mx.stack(batch_inputs), mx.stack(batch_targets)

Knowledge Distillation

lmxlab.training.distillation

Knowledge distillation training losses.

Implements teacher-student distillation where a smaller student model learns from a larger teacher model's soft probability distributions (Hinton et al., 2015, arXiv:1503.02531).

The key insight: soft targets carry more information than hard labels. A teacher assigning 0.7 to "cat" and 0.2 to "kitten" teaches the student about word similarity, not just correctness.

Supported modes:

  • Logit distillation: KL divergence between temperature-scaled teacher and student logits. The standard approach.
  • Combined loss: Weighted mix of distillation loss and standard cross-entropy on hard targets. Balances learning from teacher with learning from ground truth.

Example::

from lmxlab.training.distillation import distillation_loss

# Teacher is frozen, student is trained
loss = distillation_loss(
    student, teacher, tokens,
    temperature=4.0, alpha=0.7,
)

distillation_loss(student, teacher, tokens, temperature=4.0, alpha=0.7)

Compute combined distillation + hard-target loss.

Loss = alpha * KL(teacher || student) * T^2 + (1 - alpha) * CE(student, targets)

The T^2 factor compensates for the gradient magnitude reduction caused by temperature scaling (Hinton et al.).

Parameters:

Name Type Description Default
student LanguageModel

Student model (being trained).

required
teacher LanguageModel

Teacher model (frozen, no gradients).

required
tokens array

Input token IDs (batch, seq_len). Targets are tokens shifted by one position.

required
temperature float

Softmax temperature for soft targets. Higher = softer distributions, more knowledge transfer. Typical values: 2-10.

4.0
alpha float

Weight for distillation loss (0-1). Higher means more reliance on teacher, less on hard targets.

0.7

Returns:

Type Description
array

Scalar combined loss.

Source code in src/lmxlab/training/distillation.py
def distillation_loss(
    student: LanguageModel,
    teacher: LanguageModel,
    tokens: mx.array,
    temperature: float = 4.0,
    alpha: float = 0.7,
) -> mx.array:
    """Compute combined distillation + hard-target loss.

    Loss = alpha * KL(teacher || student) * T^2
         + (1 - alpha) * CE(student, targets)

    The T^2 factor compensates for the gradient magnitude
    reduction caused by temperature scaling (Hinton et al.).

    Args:
        student: Student model (being trained).
        teacher: Teacher model (frozen, no gradients).
        tokens: Input token IDs (batch, seq_len). Targets are
            tokens shifted by one position.
        temperature: Softmax temperature for soft targets.
            Higher = softer distributions, more knowledge
            transfer. Typical values: 2-10.
        alpha: Weight for distillation loss (0-1). Higher means
            more reliance on teacher, less on hard targets.

    Returns:
        Scalar combined loss.
    """
    inputs = tokens[:, :-1]
    targets = tokens[:, 1:]

    # Student forward pass (will receive gradients)
    student_logits, _ = student(inputs)

    # Teacher forward pass. Bug fix: detach the result so no
    # backward graph is built through the frozen teacher — the
    # contract is "no gradients", and stop_gradient enforces it
    # while saving memory during the student's backward pass.
    teacher_logits, _ = teacher(inputs)
    teacher_logits = mx.stop_gradient(teacher_logits)

    # Distillation component: KL divergence on soft targets
    kl = soft_target_loss(student_logits, teacher_logits, temperature)

    # Pure distillation: skip the hard-target term entirely.
    if alpha >= 1.0:
        return kl

    # Hard target component: standard cross-entropy
    ce = _cross_entropy(student_logits, targets)

    return alpha * kl + (1.0 - alpha) * ce

soft_target_loss(student_logits, teacher_logits, temperature=4.0)

KL divergence between temperature-scaled distributions.

KL(teacher || student) computed on softened logits. Multiplied by T^2 to maintain gradient scale.

Parameters:

Name Type Description Default
student_logits array

Student output (batch, seq_len, vocab).

required
teacher_logits array

Teacher output (batch, seq_len, vocab).

required
temperature float

Softmax temperature.

4.0

Returns:

Type Description
array

Scalar KL divergence loss.

Source code in src/lmxlab/training/distillation.py
def soft_target_loss(
    student_logits: mx.array,
    teacher_logits: mx.array,
    temperature: float = 4.0,
) -> mx.array:
    """KL divergence between temperature-scaled distributions.

    Computes KL(teacher || student) on logits softened by
    ``temperature``, then multiplies by T^2 so gradient
    magnitudes match the unscaled regime (Hinton et al.).

    Args:
        student_logits: Student output (batch, seq_len, vocab).
        teacher_logits: Teacher output (batch, seq_len, vocab).
        temperature: Softmax temperature.

    Returns:
        Scalar KL divergence loss.
    """
    # Log-probabilities of the softened distributions.
    log_q = nn.log_softmax(student_logits / temperature, axis=-1)
    log_p = nn.log_softmax(teacher_logits / temperature, axis=-1)

    # KL(P || Q) = sum_v P(v) * (log P(v) - log Q(v)) per position.
    teacher_p = mx.exp(log_p)
    per_position_kl = mx.sum(teacher_p * (log_p - log_q), axis=-1)

    # Mean over batch and positions, scaled by T^2.
    return mx.mean(per_position_kl) * (temperature**2)

Hardware Detection

lmxlab.training.hardware.detect_peak_tflops()

Detect FP32 peak TFLOP/s for the current GPU.

Uses mx.device_info() to identify the Apple Silicon architecture and returns a known peak value. Returns None if the architecture is not recognized.

Returns:

Type Description
float | None

Peak FP32 TFLOP/s, or None if unknown.

Source code in src/lmxlab/training/hardware.py
def detect_peak_tflops() -> float | None:
    """Detect FP32 peak TFLOP/s for the current GPU.

    Looks up the Apple Silicon architecture string reported by
    ``mx.device_info()`` in a table of known peak throughputs.

    Returns:
        Peak FP32 TFLOP/s, or None if unknown.
    """
    architecture = mx.device_info().get("architecture", "")
    return _APPLE_SILICON_TFLOPS.get(architecture)

Metric Callbacks

lmxlab.training.metric_callbacks.GradientStatsCallback

Tracks per-layer gradient norm statistics.

Computes gradient norms via a separate forward+backward pass on a stored probe batch at measurement intervals.

Parameters:

Name Type Description Default
model Module

Model to measure gradients on.

required
loss_fn Any

Loss function (model, x, y) -> scalar.

required
log_interval int

Steps between measurements.

100
Source code in src/lmxlab/training/metric_callbacks.py
class GradientStatsCallback:
    """Tracks per-layer gradient norm statistics.

    Runs an extra forward+backward pass on a stored probe batch
    every ``log_interval`` steps and records summary statistics
    of the per-parameter gradient norms.

    Args:
        model: Model to measure gradients on.
        loss_fn: Loss function ``(model, x, y) -> scalar``.
        log_interval: Steps between measurements.
    """

    def __init__(
        self,
        model: nn.Module,
        loss_fn: Any,
        log_interval: int = 100,
    ) -> None:
        self.model = model
        self.loss_fn = loss_fn
        self.log_interval = log_interval
        self._probe_batch: tuple[mx.array, mx.array] | None = None

    def on_train_begin(self, config: TrainConfig) -> None:
        """No-op: nothing to set up at train start."""

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Measure gradient norms every ``log_interval`` steps."""
        if self._probe_batch is None or step % self.log_interval != 0:
            return

        # Measure in eval mode; restore training mode even if
        # the backward pass raises.
        was_training = self.model.training
        self.model.eval()
        try:
            grad_fn = nn.value_and_grad(self.model, self.loss_fn)
            probe_x, probe_y = self._probe_batch
            _, grads = grad_fn(self.model, probe_x, probe_y)
        finally:
            if was_training:
                self.model.train()

        # L2 norm of each flattened gradient tensor.
        norms: list[float] = []
        for _, grad in tree_flatten(grads):
            grad_norm = mx.sqrt(mx.sum(grad * grad))
            mx.eval(grad_norm)
            norms.append(grad_norm.item())

        if not norms:
            return
        mean_norm = sum(norms) / len(norms)
        # Population variance of the per-tensor norms.
        variance = sum((v - mean_norm) ** 2 for v in norms) / len(norms)
        metrics["exp_grad_norm_mean"] = mean_norm
        metrics["exp_grad_norm_std"] = variance**0.5
        metrics["exp_grad_norm_max_layer"] = float(
            max(range(len(norms)), key=norms.__getitem__)
        )

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """No-op: does not react to evaluation events."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """No-op: nothing to finalize."""

    def set_probe_batch(self, batch: tuple[mx.array, mx.array]) -> None:
        """Register the batch used for gradient measurement.

        Args:
            batch: Tuple of (input, target) arrays.
        """
        self._probe_batch = batch

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """No action on eval; present to satisfy the callback protocol."""

on_step_end(step, metrics)

Compute gradient stats at log_interval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Compute gradient stats at log_interval.

    Runs an extra forward+backward pass on the probe batch and
    records mean/std of per-parameter gradient norms plus the
    index of the parameter with the largest norm.
    """
    # No probe batch registered yet: nothing to measure.
    if self._probe_batch is None:
        return
    if step % self.log_interval != 0:
        return

    # Measure in eval mode; restore training mode even if the
    # backward pass raises.
    was_training = self.model.training
    self.model.eval()
    try:
        loss_and_grad = nn.value_and_grad(self.model, self.loss_fn)
        x, y = self._probe_batch
        _, grads = loss_and_grad(self.model, x, y)
    finally:
        if was_training:
            self.model.train()

    # L2 norm of each flattened gradient tensor.
    norms = []
    for _, g in tree_flatten(grads):
        norm = mx.sqrt(mx.sum(g * g))
        mx.eval(norm)
        norms.append(norm.item())

    if norms:
        mean_norm = sum(norms) / len(norms)
        # Population standard deviation of the per-tensor norms.
        std_norm = (
            sum((n - mean_norm) ** 2 for n in norms) / len(norms)
        ) ** 0.5
        max_idx = max(range(len(norms)), key=norms.__getitem__)
        metrics["exp_grad_norm_mean"] = mean_norm
        metrics["exp_grad_norm_std"] = std_norm
        metrics["exp_grad_norm_max_layer"] = float(max_idx)

on_train_begin(config)

No action on train begin.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """No action on train begin; present to satisfy the callback protocol."""

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """No action on train end; present to satisfy the callback protocol."""

set_probe_batch(batch)

Store a probe batch for gradient measurement.

Parameters:

Name Type Description Default
batch tuple[array, array]

Tuple of (input, target) arrays.

required
Source code in src/lmxlab/training/metric_callbacks.py
def set_probe_batch(self, batch: tuple[mx.array, mx.array]) -> None:
    """Store a probe batch for gradient measurement.

    Must be called before any stats are produced; until then
    ``on_step_end`` is a no-op.

    Args:
        batch: Tuple of (input, target) arrays.
    """
    self._probe_batch = batch

lmxlab.training.metric_callbacks.WeightStatsCallback

Tracks weight norm and weight delta statistics.

Stores initial norms in on_train_begin and computes delta (change from initial) at measurement intervals.

Parameters:

Name Type Description Default
model Module

Model to measure weights on.

required
log_interval int

Steps between measurements.

100
Source code in src/lmxlab/training/metric_callbacks.py
class WeightStatsCallback:
    """Tracks weight norm and weight delta statistics.

    Records the global weight norm at the start of training and,
    every ``log_interval`` steps, logs the current norm together
    with its absolute change since initialization.

    Args:
        model: Model to measure weights on.
        log_interval: Steps between measurements.
    """

    def __init__(
        self,
        model: nn.Module,
        log_interval: int = 100,
    ) -> None:
        self.model = model
        self.log_interval = log_interval
        self._initial_norm: float = 0.0

    def on_train_begin(self, config: TrainConfig) -> None:
        """Record the weight norm before any updates."""
        self._initial_norm = self._compute_weight_norm()

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Log weight norm and drift every ``log_interval`` steps."""
        if step % self.log_interval != 0:
            return
        norm_now = self._compute_weight_norm()
        metrics["exp_weight_norm"] = norm_now
        metrics["exp_weight_delta"] = abs(norm_now - self._initial_norm)

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """No-op: does not react to evaluation events."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """No-op: nothing to finalize."""

    def _compute_weight_norm(self) -> float:
        """Return the global L2 norm over all trainable parameters."""
        sq_sum = mx.array(0.0)
        for _, param in tree_flatten(self.model.trainable_parameters()):
            sq_sum = sq_sum + mx.sum(param * param)
        mx.eval(sq_sum)
        return mx.sqrt(sq_sum).item()

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """No action on eval; present to satisfy the callback protocol."""

on_step_end(step, metrics)

Compute weight stats at log_interval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Log weight norm and its drift every ``log_interval`` steps."""
    if step % self.log_interval != 0:
        return
    norm_now = self._compute_weight_norm()
    metrics["exp_weight_norm"] = norm_now
    metrics["exp_weight_delta"] = abs(norm_now - self._initial_norm)

on_train_begin(config)

Store initial weight norm.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """Store initial weight norm so later steps can report drift."""
    self._initial_norm = self._compute_weight_norm()

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """No action on train end; present to satisfy the callback protocol."""

lmxlab.training.metric_callbacks.ActivationStatsCallback

Tracks activation norm ratios and sparsity.

Uses ActivationCapture from the analysis module to capture layer activations on a probe batch.

Parameters:

Name Type Description Default
model Any

Language model to instrument.

required
probe_batch array

Input tokens for activation capture.

required
eval_interval int

Steps between measurements.

500
eps float

Threshold for sparsity (fraction |x| < eps).

0.001
Source code in src/lmxlab/training/metric_callbacks.py
class ActivationStatsCallback:
    """Tracks activation norm ratios and sparsity.

    Captures layer activations on a fixed probe batch using
    ``ActivationCapture`` and logs the last/first output-norm
    ratio plus mean activation sparsity.

    Args:
        model: Language model to instrument.
        probe_batch: Input tokens for activation capture.
        eval_interval: Steps between measurements.
        eps: Threshold for sparsity (fraction |x| < eps).
    """

    def __init__(
        self,
        model: Any,
        probe_batch: mx.array,
        eval_interval: int = 500,
        eps: float = 1e-3,
    ) -> None:
        self.model = model
        self.probe_batch = probe_batch
        self.eval_interval = eval_interval
        self.eps = eps

    def on_train_begin(self, config: TrainConfig) -> None:
        """No-op: nothing to set up at train start."""

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Capture activations and log stats every ``eval_interval`` steps."""
        if step % self.eval_interval != 0:
            return

        # Imported lazily so normal training does not require
        # the analysis module.
        from lmxlab.analysis.activations import (
            ActivationCapture,
        )

        # Capture in eval mode; restore training mode even if
        # the forward pass raises.
        was_training = self.model.training
        self.model.eval()
        try:
            with ActivationCapture(self.model) as cap:
                self.model(self.probe_batch)
        finally:
            if was_training:
                self.model.train()

        norms: list[float] = []
        zero_fracs: list[float] = []
        for key, act in sorted(cap.activations.items()):
            if not key.endswith("/output"):
                continue
            mx.eval(act)
            # Per-layer output norm.
            norms.append(mx.sqrt(mx.sum(act * act)).item())
            # Sparsity: fraction of elements near zero.
            near_zero = mx.mean((mx.abs(act) < self.eps).astype(mx.float32))
            mx.eval(near_zero)
            zero_fracs.append(near_zero.item())

        if len(norms) >= 2:
            # Ratio of last-layer to first-layer output norm.
            metrics["exp_act_norm_ratio"] = norms[-1] / max(norms[0], 1e-10)
        if zero_fracs:
            metrics["exp_act_sparsity_mean"] = sum(zero_fracs) / len(zero_fracs)

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """No-op: does not react to evaluation events."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """No-op: nothing to finalize."""

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """No action on eval; present to satisfy the callback protocol."""

on_step_end(step, metrics)

Compute activation stats at eval_interval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Compute activation stats at eval_interval.

    Captures layer activations on the probe batch and records
    the last/first output-norm ratio and mean sparsity.
    """
    if step % self.eval_interval != 0:
        return

    # Imported lazily so normal training does not require the
    # analysis module.
    from lmxlab.analysis.activations import (
        ActivationCapture,
    )

    # Capture in eval mode; restore training mode even if the
    # forward pass raises.
    was_training = self.model.training
    self.model.eval()
    try:
        with ActivationCapture(self.model) as cap:
            self.model(self.probe_batch)
    finally:
        if was_training:
            self.model.train()

    # Compute per-layer output norms
    output_norms = []
    sparsities = []
    for key, val in sorted(cap.activations.items()):
        # Only consider captured layer outputs.
        if not key.endswith("/output"):
            continue
        mx.eval(val)
        norm = mx.sqrt(mx.sum(val * val)).item()
        output_norms.append(norm)
        # Sparsity: fraction of elements near zero
        sparse_frac = mx.mean((mx.abs(val) < self.eps).astype(mx.float32))
        mx.eval(sparse_frac)
        sparsities.append(sparse_frac.item())

    # Ratio of last-layer to first-layer output norm.
    if len(output_norms) >= 2:
        metrics["exp_act_norm_ratio"] = output_norms[-1] / max(
            output_norms[0], 1e-10
        )
    if sparsities:
        metrics["exp_act_sparsity_mean"] = sum(sparsities) / len(
            sparsities
        )

on_train_begin(config)

No action on train begin.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """Start-of-training hook; nothing to initialize here."""

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """End-of-training hook; nothing to finalize here."""

lmxlab.training.metric_callbacks.AttentionEntropyCallback

Tracks Shannon entropy of attention weights.

Uses extract_attention_maps from the analysis module to get per-head attention weights.

Parameters:

Name Type Description Default
model Any

Language model with attention layers.

required
probe_batch array

Input tokens for attention extraction.

required
eval_interval int

Steps between measurements.

500
Source code in src/lmxlab/training/metric_callbacks.py
class AttentionEntropyCallback:
    """Measures the Shannon entropy of attention distributions.

    Periodically runs ``extract_attention_maps`` (from the analysis
    module) on a fixed probe batch and records the mean and standard
    deviation of per-layer attention entropy under the
    ``exp_attn_entropy_mean`` / ``exp_attn_entropy_std`` metric keys.

    Args:
        model: Language model with attention layers.
        probe_batch: Input tokens for attention extraction.
        eval_interval: Steps between measurements.
    """

    def __init__(
        self,
        model: Any,
        probe_batch: mx.array,
        eval_interval: int = 500,
    ) -> None:
        self.model = model
        self.probe_batch = probe_batch
        self.eval_interval = eval_interval

    def on_train_begin(self, config: TrainConfig) -> None:
        """Start-of-training hook; nothing to initialize here."""

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Record attention-entropy stats every ``eval_interval`` steps."""
        if step % self.eval_interval:
            return

        from lmxlab.analysis.attention import (
            extract_attention_maps,
        )

        # Probe in eval mode; restore training mode afterwards.
        restore_train = self.model.training
        self.model.eval()
        try:
            attn_maps = extract_attention_maps(self.model, self.probe_batch)
        finally:
            if restore_train:
                self.model.train()

        layer_entropies: list[float] = []
        for attn in attn_maps.values():
            # attn: (batch, heads, seq, seq); clamp so log(0) never occurs.
            clipped = mx.clip(attn, 1e-10, 1.0)
            # Shannon entropy over the key axis, averaged over the rest.
            avg_entropy = mx.mean(
                -mx.sum(clipped * mx.log(clipped), axis=-1)
            )
            mx.eval(avg_entropy)
            layer_entropies.append(avg_entropy.item())

        if not layer_entropies:
            return
        count = len(layer_entropies)
        mean_ent = sum(layer_entropies) / count
        variance = sum((e - mean_ent) ** 2 for e in layer_entropies) / count
        metrics["exp_attn_entropy_mean"] = mean_ent
        metrics["exp_attn_entropy_std"] = variance**0.5

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Evaluation hook; intentionally a no-op for this callback."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """End-of-training hook; nothing to finalize here."""

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Evaluation hook; this callback records nothing on eval."""

on_step_end(step, metrics)

Compute attention entropy at eval_interval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Record attention-entropy stats every ``eval_interval`` steps.

    Adds ``exp_attn_entropy_mean`` and ``exp_attn_entropy_std``
    (mean/std of per-layer attention entropy) to ``metrics``.
    """
    if step % self.eval_interval:
        return

    from lmxlab.analysis.attention import (
        extract_attention_maps,
    )

    # Probe in eval mode; restore training mode afterwards.
    restore_train = self.model.training
    self.model.eval()
    try:
        attn_maps = extract_attention_maps(self.model, self.probe_batch)
    finally:
        if restore_train:
            self.model.train()

    layer_entropies: list[float] = []
    for attn in attn_maps.values():
        # attn: (batch, heads, seq, seq); clamp so log(0) never occurs.
        clipped = mx.clip(attn, 1e-10, 1.0)
        # Shannon entropy over the key axis, averaged over the rest.
        avg_entropy = mx.mean(-mx.sum(clipped * mx.log(clipped), axis=-1))
        mx.eval(avg_entropy)
        layer_entropies.append(avg_entropy.item())

    if not layer_entropies:
        return
    count = len(layer_entropies)
    mean_ent = sum(layer_entropies) / count
    variance = sum((e - mean_ent) ** 2 for e in layer_entropies) / count
    metrics["exp_attn_entropy_mean"] = mean_ent
    metrics["exp_attn_entropy_std"] = variance**0.5

on_train_begin(config)

No action on train begin.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """Start-of-training hook; nothing to initialize here."""

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """End-of-training hook; nothing to finalize here."""

lmxlab.training.metric_callbacks.LossCurvatureCallback

Tracks gradient noise scale from grad_norm history.

Maintains a running window of grad_norm values from the metrics dict and computes gradient noise scale = std(grad_norms) / mean(grad_norms).

Parameters:

Name Type Description Default
window_size int

Number of recent grad_norms to track.

50
Source code in src/lmxlab/training/metric_callbacks.py
class LossCurvatureCallback:
    """Tracks gradient noise scale from grad_norm history.

    Keeps a bounded window of the most recent ``grad_norm`` values
    read from the step metrics and reports their coefficient of
    variation (population std divided by mean) under the
    ``exp_grad_noise_scale`` metric key.

    Args:
        window_size: Number of recent grad_norms to track.
    """

    def __init__(self, window_size: int = 50) -> None:
        self.window_size = window_size
        # Bounded history; old values fall off automatically.
        self._window: deque[float] = deque(maxlen=window_size)

    def on_train_begin(self, config: TrainConfig) -> None:
        """Start each run with an empty history window."""
        self._window.clear()

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Record this step's grad_norm and emit the noise scale."""
        norm = metrics.get("grad_norm")
        if norm is None:
            return
        self._window.append(float(norm))
        if len(self._window) < 2:
            return
        history = list(self._window)
        mean = sum(history) / len(history)
        # Guard against dividing by a vanishing mean.
        if mean > 1e-10:
            variance = sum((v - mean) ** 2 for v in history) / len(history)
            metrics["exp_grad_noise_scale"] = variance**0.5 / mean

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Evaluation hook; intentionally a no-op for this callback."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """End-of-training hook; nothing to finalize here."""

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Evaluation hook; this callback records nothing on eval."""

on_step_end(step, metrics)

Update window and compute noise scale.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Record this step's grad_norm and emit the noise scale.

    Appends ``metrics["grad_norm"]`` (when present) to the running
    window, then writes ``exp_grad_noise_scale`` =
    population std / mean of the window once it has >= 2 entries.
    """
    norm = metrics.get("grad_norm")
    if norm is None:
        return
    self._window.append(float(norm))
    if len(self._window) < 2:
        return
    history = list(self._window)
    mean = sum(history) / len(history)
    # Guard against dividing by a vanishing mean.
    if mean > 1e-10:
        variance = sum((v - mean) ** 2 for v in history) / len(history)
        metrics["exp_grad_noise_scale"] = variance**0.5 / mean

on_train_begin(config)

Reset window.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """Discard any grad_norm history from a previous run."""
    self._window.clear()

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """End-of-training hook; nothing to finalize here."""

lmxlab.training.metric_callbacks.EffectiveRankCallback

Tracks effective rank of weight matrices.

Computes the SVD of each 2-D trainable weight matrix whose smaller dimension is at least 4, then effective rank = exp(entropy of normalized singular values), averaged across matrices.

Parameters:

Name Type Description Default
model Module

Model to analyze.

required
eval_interval int

Steps between measurements.

500
Source code in src/lmxlab/training/metric_callbacks.py
class EffectiveRankCallback:
    """Tracks effective rank of weight matrices.

    For every 2-D trainable parameter whose smaller dimension is at
    least 4, computes the singular values and reports
    effective rank = exp(entropy of the normalized singular values),
    averaged across matrices, as ``exp_effective_rank_mean``.

    Args:
        model: Model to analyze.
        eval_interval: Steps between measurements.
    """

    def __init__(
        self,
        model: nn.Module,
        eval_interval: int = 500,
    ) -> None:
        self.model = model
        self.eval_interval = eval_interval

    def on_train_begin(self, config: TrainConfig) -> None:
        """Start-of-training hook; nothing to initialize here."""

    def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Measure effective rank every ``eval_interval`` steps."""
        if step % self.eval_interval:
            return

        effective_ranks: list[float] = []
        for _name, param in tree_flatten(self.model.trainable_parameters()):
            # Skip vectors/higher-rank tensors and tiny matrices.
            if param.ndim != 2 or min(param.shape) < 4:
                continue
            singular = mx.linalg.svd(param, compute_uv=False, stream=mx.cpu)
            mx.eval(singular)
            total = mx.sum(singular)
            if total.item() < 1e-10:
                # Degenerate (near-zero) matrix; skip it.
                continue
            probs = singular / total  # type: ignore[operator]
            probs = mx.clip(probs, 1e-10, 1.0)  # avoid log(0)
            ent = -mx.sum(probs * mx.log(probs))
            mx.eval(ent)
            effective_ranks.append(math.exp(ent.item()))

        if effective_ranks:
            metrics["exp_effective_rank_mean"] = sum(effective_ranks) / len(
                effective_ranks
            )

    def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
        """Evaluation hook; intentionally a no-op for this callback."""

    def on_train_end(self, history: list[dict[str, Any]]) -> None:
        """End-of-training hook; nothing to finalize here."""

on_eval_end(step, metrics)

No action on eval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_eval_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Evaluation hook; this callback records nothing on eval."""

on_step_end(step, metrics)

Compute effective rank at eval_interval.

Source code in src/lmxlab/training/metric_callbacks.py
def on_step_end(self, step: int, metrics: dict[str, Any]) -> None:
    """Measure effective rank every ``eval_interval`` steps.

    Adds ``exp_effective_rank_mean`` — the mean over all qualifying
    2-D trainable parameters of exp(entropy of normalized singular
    values) — to ``metrics``.
    """
    if step % self.eval_interval:
        return

    effective_ranks: list[float] = []
    for _name, param in tree_flatten(self.model.trainable_parameters()):
        # Skip vectors/higher-rank tensors and tiny matrices.
        if param.ndim != 2 or min(param.shape) < 4:
            continue
        singular = mx.linalg.svd(param, compute_uv=False, stream=mx.cpu)
        mx.eval(singular)
        total = mx.sum(singular)
        if total.item() < 1e-10:
            # Degenerate (near-zero) matrix; skip it.
            continue
        probs = singular / total  # type: ignore[operator]
        probs = mx.clip(probs, 1e-10, 1.0)  # avoid log(0)
        ent = -mx.sum(probs * mx.log(probs))
        mx.eval(ent)
        effective_ranks.append(math.exp(ent.item()))

    if effective_ranks:
        metrics["exp_effective_rank_mean"] = sum(effective_ranks) / len(
            effective_ranks
        )

on_train_begin(config)

No action on train begin.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_begin(self, config: TrainConfig) -> None:
    """Start-of-training hook; nothing to initialize here."""

on_train_end(history)

No action on train end.

Source code in src/lmxlab/training/metric_callbacks.py
def on_train_end(self, history: list[dict[str, Any]]) -> None:
    """End-of-training hook; nothing to finalize here."""