Skip to content

Core

The core package provides the building blocks for all architectures.

Configuration

lmxlab.core.config.BlockConfig dataclass

Configuration for a single transformer block.

Defines which components (attention, FFN, norm, position encoding) to use and their parameters. Components are looked up by name from the registry.

Parameters:

Name Type Description Default
attention str

Registry name for attention module.

'mha'
ffn str

Registry name for feed-forward module.

'standard'
norm str

Registry name for normalization function.

'layer_norm'
position str

Registry name for positional encoding.

'sinusoidal'
d_model int

Hidden dimension size.

512
n_heads int

Number of attention heads.

8
n_kv_heads int | None

Number of key/value heads (for GQA). Defaults to n_heads (standard MHA).

None
d_ff int

Feed-forward intermediate dimension.

2048
bias bool

Whether to use bias in linear layers.

True
dropout float

Dropout rate (0.0 = no dropout).

0.0
norm_eps float

Epsilon for normalization layers.

1e-05
rope_theta float

Base frequency for RoPE.

10000.0
max_seq_len int

Maximum sequence length.

2048
pre_norm bool

If True, apply norm before attention/FFN (pre-norm). If False, apply after (post-norm).

True
window_size int | None

Sliding window size for local attention. None means full (global) attention.

None
mamba_bc_norm bool

If True, apply RMSNorm to B and C projections in Mamba-3 (analogous to QK-norm).

False
mamba_trapezoidal bool

If True, use trapezoidal discretization in Mamba-3 (two SSD calls).

False
mamba_complex_a bool

If True, apply data-dependent RoPE to B and C in Mamba-3 (complex eigenvalues).

False
qk_norm bool

If True, apply per-head RMSNorm to Q and K after reshape, before RoPE (OLMo 2 style).

False
attention_chunk_size int | None

Chunk size for chunked local attention. None means full (global) attention.

None
mup bool

If True, use μP attention scaling (1/d_head instead of 1/√d_head).

False
Source code in src/lmxlab/core/config.py
@dataclass(frozen=True)
class BlockConfig:
    """Configuration for a single transformer block.

    Defines which components (attention, FFN, norm, position encoding)
    to use and their parameters. Components are looked up by name
    from the registry.

    Args:
        attention: Registry name for attention module.
        ffn: Registry name for feed-forward module.
        norm: Registry name for normalization function.
        position: Registry name for positional encoding.
        d_model: Hidden dimension size.
        n_heads: Number of attention heads.
        n_kv_heads: Number of key/value heads (for GQA).
            Defaults to n_heads (standard MHA).
        d_ff: Feed-forward intermediate dimension.
        bias: Whether to use bias in linear layers.
        dropout: Dropout rate (0.0 = no dropout).
        norm_eps: Epsilon for normalization layers.
        rope_theta: Base frequency for RoPE.
        max_seq_len: Maximum sequence length.
        pre_norm: If True, apply norm before attention/FFN (pre-norm).
            If False, apply after (post-norm).
        window_size: Sliding window size for local attention.
            None means full (global) attention.
        mamba_bc_norm: If True, apply RMSNorm to B and C
            projections in Mamba-3 (analogous to QK-norm).
        mamba_trapezoidal: If True, use trapezoidal
            discretization in Mamba-3 (two SSD calls).
        mamba_complex_a: If True, apply data-dependent RoPE
            to B and C in Mamba-3 (complex eigenvalues).
        qk_norm: If True, apply per-head RMSNorm to Q and K
            after reshape, before RoPE (OLMo 2 style).
        attention_chunk_size: Chunk size for chunked local
            attention. None means full (global) attention.
        mup: If True, use μP attention scaling (1/d_head
            instead of 1/√d_head).
    """

    attention: str = "mha"
    ffn: str = "standard"
    norm: str = "layer_norm"
    position: str = "sinusoidal"
    d_model: int = 512
    n_heads: int = 8
    n_kv_heads: int | None = None
    d_ff: int = 2048
    bias: bool = True
    dropout: float = 0.0
    norm_eps: float = 1e-5
    rope_theta: float = 10000.0
    max_seq_len: int = 2048
    pre_norm: bool = True
    # Sliding window attention parameters
    window_size: int | None = None
    # MLA (Multi-Head Latent Attention) parameters
    kv_lora_rank: int | None = None
    q_lora_rank: int | None = None
    rope_dim: int | None = None
    # Mixture of Experts parameters
    n_experts: int | None = None
    top_k_experts: int = 2
    n_shared_experts: int | None = None
    # Gated DeltaNet (linear attention) parameters
    conv_kernel_size: int = 4
    use_short_conv: bool = False
    # Mamba-2 (SSM) parameters
    mamba_n_heads: int | None = None
    mamba_head_dim: int | None = None
    ssm_state_size: int = 128
    mamba_expand: int = 2
    mamba_n_groups: int = 1
    mamba_chunk_size: int = 128
    # Mamba-3 enhancements
    mamba_bc_norm: bool = False
    mamba_trapezoidal: bool = False
    mamba_complex_a: bool = False
    # LatentMoE parameters
    moe_latent_size: int | None = None
    moe_d_ff: int | None = None
    shared_expert_d_ff: int | None = None
    moe_routed_scaling_factor: float = 1.0
    moe_n_groups: int = 1
    moe_topk_groups: int = 1
    # QK-norm (per-head RMSNorm on Q and K)
    qk_norm: bool = False
    # Chunked local attention chunk size
    attention_chunk_size: int | None = None
    # μP (Maximal Update Parameterization) flag
    mup: bool = False
    # DeepSeek Sparse Attention (DSA) parameters
    sparse_compress_ratio: int | None = None
    sparse_select_k: int | None = None

    @property
    def head_dim(self) -> int:
        """Dimension per attention head."""
        return self.d_model // self.n_heads

    @property
    def effective_n_kv_heads(self) -> int:
        """Number of KV heads (defaults to n_heads if not set)."""
        return self.n_kv_heads if self.n_kv_heads is not None else self.n_heads

attention = 'mha' class-attribute instance-attribute

attention_chunk_size = None class-attribute instance-attribute

bias = True class-attribute instance-attribute

conv_kernel_size = 4 class-attribute instance-attribute

d_ff = 2048 class-attribute instance-attribute

d_model = 512 class-attribute instance-attribute

dropout = 0.0 class-attribute instance-attribute

effective_n_kv_heads property

Number of KV heads (defaults to n_heads if not set).

ffn = 'standard' class-attribute instance-attribute

head_dim property

Dimension per attention head.

kv_lora_rank = None class-attribute instance-attribute

mamba_bc_norm = False class-attribute instance-attribute

mamba_chunk_size = 128 class-attribute instance-attribute

mamba_complex_a = False class-attribute instance-attribute

mamba_expand = 2 class-attribute instance-attribute

mamba_head_dim = None class-attribute instance-attribute

mamba_n_groups = 1 class-attribute instance-attribute

mamba_n_heads = None class-attribute instance-attribute

mamba_trapezoidal = False class-attribute instance-attribute

max_seq_len = 2048 class-attribute instance-attribute

moe_d_ff = None class-attribute instance-attribute

moe_latent_size = None class-attribute instance-attribute

moe_n_groups = 1 class-attribute instance-attribute

moe_routed_scaling_factor = 1.0 class-attribute instance-attribute

moe_topk_groups = 1 class-attribute instance-attribute

mup = False class-attribute instance-attribute

n_experts = None class-attribute instance-attribute

n_heads = 8 class-attribute instance-attribute

n_kv_heads = None class-attribute instance-attribute

n_shared_experts = None class-attribute instance-attribute

norm = 'layer_norm' class-attribute instance-attribute

norm_eps = 1e-05 class-attribute instance-attribute

position = 'sinusoidal' class-attribute instance-attribute

pre_norm = True class-attribute instance-attribute

q_lora_rank = None class-attribute instance-attribute

qk_norm = False class-attribute instance-attribute

rope_dim = None class-attribute instance-attribute

rope_theta = 10000.0 class-attribute instance-attribute

shared_expert_d_ff = None class-attribute instance-attribute

sparse_compress_ratio = None class-attribute instance-attribute

sparse_select_k = None class-attribute instance-attribute

ssm_state_size = 128 class-attribute instance-attribute

top_k_experts = 2 class-attribute instance-attribute

use_short_conv = False class-attribute instance-attribute

window_size = None class-attribute instance-attribute

__init__(attention='mha', ffn='standard', norm='layer_norm', position='sinusoidal', d_model=512, n_heads=8, n_kv_heads=None, d_ff=2048, bias=True, dropout=0.0, norm_eps=1e-05, rope_theta=10000.0, max_seq_len=2048, pre_norm=True, window_size=None, kv_lora_rank=None, q_lora_rank=None, rope_dim=None, n_experts=None, top_k_experts=2, n_shared_experts=None, conv_kernel_size=4, use_short_conv=False, mamba_n_heads=None, mamba_head_dim=None, ssm_state_size=128, mamba_expand=2, mamba_n_groups=1, mamba_chunk_size=128, mamba_bc_norm=False, mamba_trapezoidal=False, mamba_complex_a=False, moe_latent_size=None, moe_d_ff=None, shared_expert_d_ff=None, moe_routed_scaling_factor=1.0, moe_n_groups=1, moe_topk_groups=1, qk_norm=False, attention_chunk_size=None, mup=False, sparse_compress_ratio=None, sparse_select_k=None)

lmxlab.core.config.ModelConfig dataclass

Configuration for a full language model.

Parameters:

Name Type Description Default
block BlockConfig

Block configuration (shared across all layers).

BlockConfig()
vocab_size int

Vocabulary size.

32000
n_layers int

Number of transformer blocks.

6
tie_embeddings bool

Whether to tie input/output embeddings.

True
block_configs tuple[BlockConfig, ...] | None

Per-layer block overrides (optional). If provided, must have length n_layers.

None
mup_base_width int | None

Base model width for μP. When set, enables μP scaling. None means standard parameterization (SP).

None
mtp_n_predict int

Number of multi-token prediction heads. 0 disables MTP.

0
mtp_lambda float

Weight for MTP auxiliary loss.

0.1
Source code in src/lmxlab/core/config.py
@dataclass(frozen=True)
class ModelConfig:
    """Configuration for a full language model.

    Args:
        block: Block configuration (shared across all layers).
        vocab_size: Vocabulary size.
        n_layers: Number of transformer blocks.
        tie_embeddings: Whether to tie input/output embeddings.
        block_configs: Per-layer block overrides (optional).
            If provided, must have length n_layers.
        mup_base_width: Base model width for μP. When set,
            enables μP scaling and must be positive. None means
            standard parameterization (SP).
        mtp_n_predict: Number of multi-token prediction heads.
            0 disables MTP.
        mtp_lambda: Weight for MTP auxiliary loss.

    Raises:
        ValueError: If ``block_configs`` is given with a length other
            than ``n_layers``, or if ``mup_base_width`` is not positive.
    """

    block: BlockConfig = field(default_factory=BlockConfig)
    vocab_size: int = 32000
    n_layers: int = 6
    tie_embeddings: bool = True
    block_configs: tuple[BlockConfig, ...] | None = None
    mup_base_width: int | None = None
    mtp_n_predict: int = 0
    mtp_lambda: float = 0.1

    def __post_init__(self) -> None:
        """Validate cross-field invariants at construction time."""
        # Enforce the documented length contract up front. Otherwise a
        # too-short tuple only fails when a missing layer is looked up
        # (IndexError in get_block_config), and surplus entries are
        # never consulted if callers only ask for indices < n_layers.
        if (
            self.block_configs is not None
            and len(self.block_configs) != self.n_layers
        ):
            raise ValueError(
                f"block_configs has {len(self.block_configs)} entries, "
                f"expected n_layers={self.n_layers}"
            )
        # width_mult divides by mup_base_width: zero would raise
        # ZeroDivisionError there, and a negative value would flip the
        # sign of the μP width multiplier.
        if self.mup_base_width is not None and self.mup_base_width <= 0:
            raise ValueError(
                f"mup_base_width must be positive, got {self.mup_base_width}"
            )

    @property
    def width_mult(self) -> float:
        """Width multiplier for μP (d_model / base_width).

        Returns 1.0 when μP is disabled. Uses the shared ``block``
        width, even when per-layer ``block_configs`` are set.
        """
        if self.mup_base_width is None:
            return 1.0
        return self.block.d_model / self.mup_base_width

    def get_block_config(self, layer_idx: int) -> BlockConfig:
        """Get block config for a specific layer.

        Args:
            layer_idx: Layer index.

        Returns:
            BlockConfig for the given layer: the per-layer override when
            ``block_configs`` is set, otherwise the shared ``block``.
        """
        if self.block_configs is not None:
            return self.block_configs[layer_idx]
        return self.block

block = field(default_factory=BlockConfig) class-attribute instance-attribute

block_configs = None class-attribute instance-attribute

mtp_lambda = 0.1 class-attribute instance-attribute

mtp_n_predict = 0 class-attribute instance-attribute

mup_base_width = None class-attribute instance-attribute

n_layers = 6 class-attribute instance-attribute

tie_embeddings = True class-attribute instance-attribute

vocab_size = 32000 class-attribute instance-attribute

width_mult property

Width multiplier for μP (d_model / base_width).

Returns 1.0 when μP is disabled.

__init__(block=BlockConfig(), vocab_size=32000, n_layers=6, tie_embeddings=True, block_configs=None, mup_base_width=None, mtp_n_predict=0, mtp_lambda=0.1)

get_block_config(layer_idx)

Get block config for a specific layer.

Parameters:

Name Type Description Default
layer_idx int

Layer index.

required

Returns:

Type Description
BlockConfig

BlockConfig for the given layer.

Source code in src/lmxlab/core/config.py
def get_block_config(self, layer_idx: int) -> BlockConfig:
    """Return the block configuration that applies to layer ``layer_idx``.

    Per-layer overrides in ``block_configs`` win when present; without
    them every layer shares the single ``block`` configuration.

    Args:
        layer_idx: Index of the layer to look up.

    Returns:
        The BlockConfig used by that layer.
    """
    overrides = self.block_configs
    if overrides is None:
        return self.block
    return overrides[layer_idx]

ConfigurableBlock

lmxlab.core.block.ConfigurableBlock

Bases: Module

A transformer block assembled from registry components.

The block uses pre-norm or post-norm residual connections depending on the config. Components (attention, FFN, norm, position encoding) are looked up from registries by name.

Parameters:

Name Type Description Default
config BlockConfig

Block configuration specifying components.

required
Example

>>> config = BlockConfig(
...     attention='gqa', ffn='gated',
...     norm='rms_norm', position='rope',
...     d_model=256, n_heads=4, n_kv_heads=2,
... )
>>> block = ConfigurableBlock(config)

Source code in src/lmxlab/core/block.py
class ConfigurableBlock(nn.Module):
    """A transformer block assembled from registry components.

    The block uses pre-norm or post-norm residual connections
    depending on the config. Components (attention, FFN, norm,
    position encoding) are looked up from registries by name.

    Args:
        config: Block configuration specifying components.

    Example:
        >>> config = BlockConfig(
        ...     attention='gqa', ffn='gated',
        ...     norm='rms_norm', position='rope',
        ...     d_model=256, n_heads=4, n_kv_heads=2,
        ... )
        >>> block = ConfigurableBlock(config)
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__()
        self.config = config

        # Build components from registries. Keep the construction order
        # stable: weight init presumably draws from the global RNG, so
        # reordering would change initial weights for a fixed seed.
        attn_cls = attention_registry.get(config.attention)
        ffn_cls = ffn_registry.get(config.ffn)
        norm_cls = norm_registry.get(config.norm)

        self.attention = attn_cls(config)
        self.ffn = ffn_cls(config)
        self.attn_norm = norm_cls(config)
        self.ffn_norm = norm_cls(config)

        # Residual dropout (applied after sublayer output)
        self.resid_dropout = nn.Dropout(p=config.dropout)

        # Position encoding
        self.position = position_registry.get(config.position)(config)

        # RoPE is passed to attention for Q/K rotation
        self._rope = self.position if config.position == "rope" else None

        # ALiBi is applied to the attention mask
        self._alibi = self.position if config.position == "alibi" else None

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, mx.array] | None = None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Forward pass through the block.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            mask: Optional attention mask.
            cache: Optional KV cache for generation.

        Returns:
            Tuple of (output, updated_cache).
        """
        if self.config.pre_norm:
            return self._pre_norm_forward(x, mask, cache)
        return self._post_norm_forward(x, mask, cache)

    def _alibi_mask(
        self,
        x: mx.array,
        mask: mx.array | None,
        cache: tuple[mx.array, mx.array] | None,
    ) -> mx.array | None:
        """Fold the ALiBi bias into ``mask``; a no-op unless ALiBi is used.

        The bias must cover every key position, so the number of
        already-cached positions is needed. It is read from the mask
        width when a mask is given. Otherwise it is read from the KV
        cache (sequence axis 2, the layout MHA/GQA use for their caches),
        and is 0 when there is no cache. Without the cache fallback, a
        mask-less cached decode step would build the bias as if nothing
        had been cached.
        """
        if self._alibi is None:
            return mask
        L = x.shape[1]
        if mask is not None:
            cache_len = mask.shape[-1] - L
        elif cache is not None:
            # NOTE(review): assumes cache[0] is (batch, heads, seq, dim).
            cache_len = cache[0].shape[2]
        else:
            cache_len = 0
        return self._alibi(mask=mask, seq_len=L, cache_len=cache_len)

    def _pre_norm_forward(
        self,
        x: mx.array,
        mask: mx.array | None,
        cache: tuple[mx.array, mx.array] | None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Pre-norm: norm -> sublayer -> dropout -> residual."""
        mask = self._alibi_mask(x, mask, cache)

        # Attention sublayer
        residual = x
        h = self.attn_norm(x)
        h, new_cache = self.attention(
            h,
            mask=mask,
            cache=cache,
            rope=self._rope,
        )
        x = residual + self.resid_dropout(h)

        # FFN sublayer
        residual = x
        h = self.ffn_norm(x)
        h = self.ffn(h)
        x = residual + self.resid_dropout(h)

        return x, new_cache

    def _post_norm_forward(
        self,
        x: mx.array,
        mask: mx.array | None,
        cache: tuple[mx.array, mx.array] | None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Post-norm: sublayer -> dropout -> residual -> norm."""
        mask = self._alibi_mask(x, mask, cache)

        # Attention sublayer
        h, new_cache = self.attention(
            x,
            mask=mask,
            cache=cache,
            rope=self._rope,
        )
        x = self.attn_norm(x + self.resid_dropout(h))

        # FFN sublayer
        h = self.ffn(x)
        x = self.ffn_norm(x + self.resid_dropout(h))

        return x, new_cache

_alibi = self.position if config.position == 'alibi' else None instance-attribute

_rope = self.position if config.position == 'rope' else None instance-attribute

attention = attn_cls(config) instance-attribute

attn_norm = norm_cls(config) instance-attribute

config = config instance-attribute

ffn = ffn_cls(config) instance-attribute

ffn_norm = norm_cls(config) instance-attribute

position = position_registry.get(config.position)(config) instance-attribute

resid_dropout = nn.Dropout(p=config.dropout) instance-attribute

__call__(x, mask=None, cache=None)

Forward pass through the block.

Parameters:

Name Type Description Default
x array

Input tensor (batch, seq_len, d_model).

required
mask array | None

Optional attention mask.

None
cache tuple[array, array] | None

Optional KV cache for generation.

None

Returns:

Type Description
tuple[array, tuple[array, array] | None]

Tuple of (output, updated_cache).

Source code in src/lmxlab/core/block.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, mx.array] | None = None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Run the block on ``x``.

    Chooses the pre-norm or post-norm path according to
    ``config.pre_norm`` and forwards all arguments to it unchanged.

    Args:
        x: Input tensor (batch, seq_len, d_model).
        mask: Optional attention mask.
        cache: Optional KV cache for generation.

    Returns:
        Tuple of (output, updated_cache).
    """
    forward = (
        self._pre_norm_forward
        if self.config.pre_norm
        else self._post_norm_forward
    )
    return forward(x, mask, cache)

__init__(config)

Source code in src/lmxlab/core/block.py
def __init__(self, config: BlockConfig) -> None:
    super().__init__()
    self.config = config

    # Resolve the component classes first, in the same registry order.
    attention_cls = attention_registry.get(config.attention)
    feedforward_cls = ffn_registry.get(config.ffn)
    normalization_cls = norm_registry.get(config.norm)

    # Instantiate the sublayers. The order is kept as-is: weight init
    # presumably draws from the global RNG, so reordering would change
    # the initial weights for a fixed seed.
    self.attention = attention_cls(config)
    self.ffn = feedforward_cls(config)
    self.attn_norm = normalization_cls(config)
    self.ffn_norm = normalization_cls(config)

    # Dropout applied to each sublayer's output before the residual add.
    self.resid_dropout = nn.Dropout(p=config.dropout)

    # Positional-encoding module, looked up by name.
    position_cls = position_registry.get(config.position)
    self.position = position_cls(config)

    # Route the position module to where it is consumed: RoPE rotates
    # Q/K inside attention, while ALiBi biases the attention mask.
    kind = config.position
    self._rope = self.position if kind == "rope" else None
    self._alibi = self.position if kind == "alibi" else None

_post_norm_forward(x, mask, cache)

Post-norm: sublayer -> dropout -> residual -> norm.

Source code in src/lmxlab/core/block.py
def _post_norm_forward(
    self,
    x: mx.array,
    mask: mx.array | None,
    cache: tuple[mx.array, mx.array] | None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Post-norm: sublayer -> dropout -> residual -> norm.

    Args:
        x: Input tensor (batch, seq_len, d_model).
        mask: Optional attention mask. When present, its last axis is
            taken to span all key positions (cached + current).
        cache: Optional KV cache passed through to attention.

    Returns:
        Tuple of (output, updated_cache as returned by attention).
    """
    # Fold the ALiBi bias into the mask. The bias must cover every key,
    # so it needs the number of cached positions. That count is read
    # from the mask width when a mask is given. Otherwise it is read
    # from the cache (sequence axis 2, the layout MHA/GQA use), so a
    # mask-less decode step does not build the bias as if nothing were
    # cached.
    if self._alibi is not None:
        L = x.shape[1]
        if mask is not None:
            cache_len = mask.shape[-1] - L
        elif cache is not None:
            # NOTE(review): assumes cache[0] is (batch, heads, seq, dim).
            cache_len = cache[0].shape[2]
        else:
            cache_len = 0
        mask = self._alibi(
            mask=mask,
            seq_len=L,
            cache_len=cache_len,
        )

    # Attention sublayer
    h, new_cache = self.attention(
        x,
        mask=mask,
        cache=cache,
        rope=self._rope,
    )
    x = self.attn_norm(x + self.resid_dropout(h))

    # FFN sublayer
    h = self.ffn(x)
    x = self.ffn_norm(x + self.resid_dropout(h))

    return x, new_cache

_pre_norm_forward(x, mask, cache)

Pre-norm: norm -> sublayer -> dropout -> residual.

Source code in src/lmxlab/core/block.py
def _pre_norm_forward(
    self,
    x: mx.array,
    mask: mx.array | None,
    cache: tuple[mx.array, mx.array] | None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Pre-norm: norm -> sublayer -> dropout -> residual.

    Args:
        x: Input tensor (batch, seq_len, d_model).
        mask: Optional attention mask. When present, its last axis is
            taken to span all key positions (cached + current).
        cache: Optional KV cache passed through to attention.

    Returns:
        Tuple of (output, updated_cache as returned by attention).
    """
    # Fold the ALiBi bias into the mask. The bias must cover every key,
    # so it needs the number of cached positions. That count is read
    # from the mask width when a mask is given. Otherwise it is read
    # from the cache (sequence axis 2, the layout MHA/GQA use), so a
    # mask-less decode step does not build the bias as if nothing were
    # cached.
    if self._alibi is not None:
        L = x.shape[1]
        if mask is not None:
            cache_len = mask.shape[-1] - L
        elif cache is not None:
            # NOTE(review): assumes cache[0] is (batch, heads, seq, dim).
            cache_len = cache[0].shape[2]
        else:
            cache_len = 0
        mask = self._alibi(
            mask=mask,
            seq_len=L,
            cache_len=cache_len,
        )

    # Attention sublayer
    residual = x
    h = self.attn_norm(x)
    h, new_cache = self.attention(
        h,
        mask=mask,
        cache=cache,
        rope=self._rope,
    )
    x = residual + self.resid_dropout(h)

    # FFN sublayer
    residual = x
    h = self.ffn_norm(x)
    h = self.ffn(h)
    x = residual + self.resid_dropout(h)

    return x, new_cache

Attention

lmxlab.core.attention.MHA

Bases: AttentionBase

Multi-Head Attention using mx.fast.scaled_dot_product_attention.

Standard MHA where n_kv_heads == n_heads.

Source code in src/lmxlab/core/attention.py
@attention_registry.register("mha")
class MHA(AttentionBase):
    """Standard multi-head attention on mx.fast.scaled_dot_product_attention.

    Every query head has its own key/value head (n_kv_heads == n_heads),
    so all four projections map d_model -> d_model.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)
        width = self.d_model
        use_bias = config.bias
        # Creation order (q, k, v, o, then QK-norm) matches the original
        # layout so parameter initialization is unchanged.
        self.q_proj = nn.Linear(width, width, bias=use_bias)
        self.k_proj = nn.Linear(width, width, bias=use_bias)
        self.v_proj = nn.Linear(width, width, bias=use_bias)
        self.o_proj = nn.Linear(width, width, bias=use_bias)
        self._init_qk_norm()
        # Logit scale: 1/d_head under μP, 1/sqrt(d_head) in standard param.
        self.scale = self.head_dim ** (-1.0 if config.mup else -0.5)

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, mx.array] | None = None,
        rope: nn.Module | None = None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        batch, seq_len, _ = x.shape

        def split_heads(t: mx.array) -> mx.array:
            # (B, L, d_model) -> (B, n_heads, L, head_dim)
            return t.reshape(
                batch, seq_len, self.n_heads, self.head_dim
            ).transpose(0, 2, 1, 3)

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        queries, keys = self._apply_qk_norm(queries, keys)

        if cache is None:
            if rope is not None:
                queries, keys = rope(queries, keys, offset=0)
        else:
            past_keys, past_values = cache[0], cache[1]
            # RoPE positions continue after the cached prefix.
            if rope is not None:
                queries, keys = rope(
                    queries, keys, offset=past_keys.shape[2]
                )
            keys = mx.concatenate([past_keys, keys], axis=2)
            values = mx.concatenate([past_values, values], axis=2)

        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=mask
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(
            batch, seq_len, self.d_model
        )
        return self.o_proj(merged), (keys, values)

__init__(config)

Source code in src/lmxlab/core/attention.py
def __init__(self, config: BlockConfig) -> None:
    super().__init__(config)
    width = self.d_model
    use_bias = config.bias
    # Full-width Q/K/V/O projections: every head has its own K/V.
    # Creation order is preserved so parameter initialization matches.
    self.q_proj = nn.Linear(width, width, bias=use_bias)
    self.k_proj = nn.Linear(width, width, bias=use_bias)
    self.v_proj = nn.Linear(width, width, bias=use_bias)
    self.o_proj = nn.Linear(width, width, bias=use_bias)
    self._init_qk_norm()
    # Logit scale: 1/d_head under μP, 1/sqrt(d_head) in standard param.
    self.scale = self.head_dim ** (-1.0 if config.mup else -0.5)

__call__(x, mask=None, cache=None, rope=None)

Source code in src/lmxlab/core/attention.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, mx.array] | None = None,
    rope: nn.Module | None = None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Multi-head attention forward pass.

    Returns the projected output and the (keys, values) pair, where both
    include any cached prefix (concatenated along axis 2).
    """
    batch, seq_len, _ = x.shape

    def split_heads(t: mx.array) -> mx.array:
        # (B, L, d_model) -> (B, n_heads, L, head_dim)
        return t.reshape(
            batch, seq_len, self.n_heads, self.head_dim
        ).transpose(0, 2, 1, 3)

    queries = split_heads(self.q_proj(x))
    keys = split_heads(self.k_proj(x))
    values = split_heads(self.v_proj(x))

    queries, keys = self._apply_qk_norm(queries, keys)

    if cache is None:
        if rope is not None:
            queries, keys = rope(queries, keys, offset=0)
    else:
        past_keys, past_values = cache[0], cache[1]
        # RoPE positions continue after the cached prefix.
        if rope is not None:
            queries, keys = rope(queries, keys, offset=past_keys.shape[2])
        keys = mx.concatenate([past_keys, keys], axis=2)
        values = mx.concatenate([past_values, values], axis=2)

    attended = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=self.scale, mask=mask
    )
    merged = attended.transpose(0, 2, 1, 3).reshape(
        batch, seq_len, self.d_model
    )
    return self.o_proj(merged), (keys, values)

lmxlab.core.attention.GQA

Bases: AttentionBase

Grouped-Query Attention.

Uses fewer KV heads than query heads for memory efficiency. When n_kv_heads == 1, this is Multi-Query Attention (MQA). When n_kv_heads == n_heads, this is standard MHA.

Source code in src/lmxlab/core/attention.py
@attention_registry.register("gqa")
class GQA(AttentionBase):
    """Grouped-Query Attention.

    Uses fewer KV heads than query heads for memory efficiency.
    When n_kv_heads == 1, this is Multi-Query Attention (MQA).
    When n_kv_heads == n_heads, this is standard MHA.

    Raises:
        ValueError: At construction, if the KV head count is not a
            positive divisor of the query head count.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)
        self.n_kv_heads = config.effective_n_kv_heads
        # Each KV head serves an equal-sized group of query heads, so
        # n_heads must be a whole multiple of n_kv_heads. Checking here
        # fails fast with a readable message instead of an error only
        # surfacing later at call time (or a modulo-by-zero for 0).
        if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
            raise ValueError(
                "GQA requires n_heads to be a positive multiple of "
                f"n_kv_heads (got n_heads={self.n_heads}, "
                f"n_kv_heads={self.n_kv_heads})"
            )
        kv_dim = self.n_kv_heads * self.head_dim

        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
        self.k_proj = nn.Linear(self.d_model, kv_dim, bias=config.bias)
        self.v_proj = nn.Linear(self.d_model, kv_dim, bias=config.bias)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
        self._init_qk_norm()
        # μP uses 1/d_head; SP uses 1/√d_head
        exp = -1.0 if config.mup else -0.5
        self.scale = self.head_dim**exp

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, mx.array] | None = None,
        rope: nn.Module | None = None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Grouped-query attention forward pass.

        Returns the projected output and the (keys, values) pair. Both
        have n_kv_heads heads and include any cached prefix
        (concatenated along axis 2).
        """
        B, L, _ = x.shape

        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Queries use n_heads heads; keys/values use n_kv_heads heads.
        q = q.reshape(B, L, self.n_heads, self.head_dim).transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )
        v = v.reshape(B, L, self.n_kv_heads, self.head_dim).transpose(
            0, 2, 1, 3
        )

        q, k = self._apply_qk_norm(q, k)

        if rope is not None:
            # RoPE positions continue after the cached prefix.
            offset = cache[0].shape[2] if cache is not None else 0
            q, k = rope(q, k, offset=offset)

        if cache is not None:
            k = mx.concatenate([cache[0], k], axis=2)
            v = mx.concatenate([cache[1], v], axis=2)
        new_cache = (k, v)

        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )
        out = out.transpose(0, 2, 1, 3).reshape(B, L, self.d_model)
        return self.o_proj(out), new_cache

__init__(config)

Source code in src/lmxlab/core/attention.py
def __init__(self, config: BlockConfig) -> None:
    super().__init__(config)
    self.n_kv_heads = config.effective_n_kv_heads
    # Each KV head serves an equal-sized group of query heads, so
    # n_heads must be a whole multiple of n_kv_heads. Checking here
    # fails fast with a readable message instead of an error only
    # surfacing later at call time (or a modulo-by-zero for 0).
    if self.n_kv_heads <= 0 or self.n_heads % self.n_kv_heads != 0:
        raise ValueError(
            "GQA requires n_heads to be a positive multiple of "
            f"n_kv_heads (got n_heads={self.n_heads}, "
            f"n_kv_heads={self.n_kv_heads})"
        )
    kv_dim = self.n_kv_heads * self.head_dim

    self.q_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
    self.k_proj = nn.Linear(self.d_model, kv_dim, bias=config.bias)
    self.v_proj = nn.Linear(self.d_model, kv_dim, bias=config.bias)
    self.o_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
    self._init_qk_norm()
    # μP uses 1/d_head; SP uses 1/√d_head
    exp = -1.0 if config.mup else -0.5
    self.scale = self.head_dim**exp

__call__(x, mask=None, cache=None, rope=None)

Source code in src/lmxlab/core/attention.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, mx.array] | None = None,
    rope: nn.Module | None = None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Grouped-query attention forward pass.

    Queries are split into ``n_heads`` heads and keys/values into
    ``n_kv_heads`` heads. Returns the projected output and the
    (keys, values) pair, both including any cached prefix
    (concatenated along axis 2).
    """
    batch, seq_len, _ = x.shape
    head_dim = self.head_dim

    # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
    queries = self.q_proj(x).reshape(
        batch, seq_len, self.n_heads, head_dim
    ).transpose(0, 2, 1, 3)
    keys = self.k_proj(x).reshape(
        batch, seq_len, self.n_kv_heads, head_dim
    ).transpose(0, 2, 1, 3)
    values = self.v_proj(x).reshape(
        batch, seq_len, self.n_kv_heads, head_dim
    ).transpose(0, 2, 1, 3)

    queries, keys = self._apply_qk_norm(queries, keys)

    if cache is None:
        if rope is not None:
            queries, keys = rope(queries, keys, offset=0)
    else:
        past_keys, past_values = cache[0], cache[1]
        # RoPE positions continue after the cached prefix.
        if rope is not None:
            queries, keys = rope(queries, keys, offset=past_keys.shape[2])
        keys = mx.concatenate([past_keys, keys], axis=2)
        values = mx.concatenate([past_values, values], axis=2)

    attended = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=self.scale, mask=mask
    )
    merged = attended.transpose(0, 2, 1, 3).reshape(
        batch, seq_len, self.d_model
    )
    return self.o_proj(merged), (keys, values)

lmxlab.core.attention.SlidingWindowGQA

Bases: AttentionBase

Grouped-Query Attention with sliding window masking.

Each token can only attend to the most recent window_size tokens (including itself). Uses GQA head configuration for memory efficiency.

The window size is read from config.window_size.

Source code in src/lmxlab/core/attention.py
@attention_registry.register("sliding_window_gqa")
class SlidingWindowGQA(AttentionBase):
    """Grouped-query attention restricted to a sliding window.

    Each query position may only see the ``window_size`` most recent
    key positions, itself included. Key/value heads use the GQA layout
    (``config.effective_n_kv_heads``), which makes the K/V projections
    and cache smaller than full multi-head attention.

    The window length is taken from ``config.window_size`` and is
    mandatory.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)
        window = config.window_size
        if window is None:
            raise ValueError("SlidingWindowGQA requires config.window_size")
        self.window_size = window
        self.n_kv_heads = config.effective_n_kv_heads

        use_bias = config.bias
        kv_width = self.n_kv_heads * self.head_dim
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=use_bias)
        self.k_proj = nn.Linear(self.d_model, kv_width, bias=use_bias)
        self.v_proj = nn.Linear(self.d_model, kv_width, bias=use_bias)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=use_bias)
        self._init_qk_norm()
        # Standard parametrization scales logits by 1/sqrt(d_head);
        # muP uses 1/d_head instead.
        self.scale = self.head_dim ** (-1.0 if config.mup else -0.5)

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, mx.array] | None = None,
        rope: nn.Module | None = None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Attend over the current chunk plus any cached keys/values.

        Args:
            x: Input of shape (B, L, d_model).
            mask: Optional attention mask, passed through
                ``_apply_sliding_window`` together with the window size
                before being used.
            cache: Optional ``(keys, values)`` from earlier steps, each
                laid out as (B, n_kv_heads, T, head_dim).
            rope: Optional rotary module called as
                ``rope(q, k, offset=<cached length>)``.

        Returns:
            ``(output, (keys, values))`` where output is (B, L, d_model)
            and the returned keys/values include every cached position.
        """
        batch, seq_len, _ = x.shape

        def to_heads(proj: mx.array, heads: int) -> mx.array:
            # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
            return proj.reshape(
                batch, seq_len, heads, self.head_dim
            ).transpose(0, 2, 1, 3)

        queries = to_heads(self.q_proj(x), self.n_heads)
        keys = to_heads(self.k_proj(x), self.n_kv_heads)
        values = to_heads(self.v_proj(x), self.n_kv_heads)
        queries, keys = self._apply_qk_norm(queries, keys)

        # Positions already cached == absolute offset of the first new token.
        past_len = 0 if cache is None else cache[0].shape[2]

        if rope is not None:
            queries, keys = rope(queries, keys, offset=past_len)

        if cache is not None:
            keys = mx.concatenate([cache[0], keys], axis=2)
            values = mx.concatenate([cache[1], values], axis=2)

        windowed = _apply_sliding_window(
            mask, self.window_size, seq_len, past_len
        )
        attended = mx.fast.scaled_dot_product_attention(
            queries, keys, values, scale=self.scale, mask=windowed
        )
        merged = attended.transpose(0, 2, 1, 3).reshape(
            batch, seq_len, self.d_model
        )
        return self.o_proj(merged), (keys, values)

__init__(config)

Source code in src/lmxlab/core/attention.py
def __init__(self, config: BlockConfig) -> None:
    """Create the GQA projections and remember the sliding-window size.

    Args:
        config: Block configuration; ``window_size`` must be set.

    Raises:
        ValueError: If ``config.window_size`` is None.
    """
    super().__init__(config)
    window = config.window_size
    if window is None:
        raise ValueError("SlidingWindowGQA requires config.window_size")
    self.window_size = window
    self.n_kv_heads = config.effective_n_kv_heads

    use_bias = config.bias
    kv_width = self.n_kv_heads * self.head_dim
    self.q_proj = nn.Linear(self.d_model, self.d_model, bias=use_bias)
    self.k_proj = nn.Linear(self.d_model, kv_width, bias=use_bias)
    self.v_proj = nn.Linear(self.d_model, kv_width, bias=use_bias)
    self.o_proj = nn.Linear(self.d_model, self.d_model, bias=use_bias)
    self._init_qk_norm()
    # Standard parametrization scales logits by 1/sqrt(d_head);
    # muP uses 1/d_head instead.
    self.scale = self.head_dim ** (-1.0 if config.mup else -0.5)

__call__(x, mask=None, cache=None, rope=None)

Source code in src/lmxlab/core/attention.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, mx.array] | None = None,
    rope: nn.Module | None = None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Attend over the current chunk plus any cached keys/values.

    Args:
        x: Input of shape (B, L, d_model).
        mask: Optional attention mask, passed through
            ``_apply_sliding_window`` together with the window size
            before being used.
        cache: Optional ``(keys, values)`` from earlier steps, each laid
            out as (B, n_kv_heads, T, head_dim).
        rope: Optional rotary module called as
            ``rope(q, k, offset=<cached length>)``.

    Returns:
        ``(output, (keys, values))`` where output is (B, L, d_model) and
        the returned keys/values include every cached position.
    """
    batch, seq_len, _ = x.shape

    def to_heads(proj: mx.array, heads: int) -> mx.array:
        # (B, L, heads * head_dim) -> (B, heads, L, head_dim)
        return proj.reshape(batch, seq_len, heads, self.head_dim).transpose(
            0, 2, 1, 3
        )

    queries = to_heads(self.q_proj(x), self.n_heads)
    keys = to_heads(self.k_proj(x), self.n_kv_heads)
    values = to_heads(self.v_proj(x), self.n_kv_heads)
    queries, keys = self._apply_qk_norm(queries, keys)

    # Positions already cached == absolute offset of the first new token.
    past_len = 0 if cache is None else cache[0].shape[2]

    if rope is not None:
        queries, keys = rope(queries, keys, offset=past_len)

    if cache is not None:
        keys = mx.concatenate([cache[0], keys], axis=2)
        values = mx.concatenate([cache[1], values], axis=2)

    windowed = _apply_sliding_window(mask, self.window_size, seq_len, past_len)
    attended = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=self.scale, mask=windowed
    )
    merged = attended.transpose(0, 2, 1, 3).reshape(
        batch, seq_len, self.d_model
    )
    return self.o_proj(merged), (keys, values)

Multi-Head Latent Attention

lmxlab.core.mla.MLA

Bases: AttentionBase

Multi-Head Latent Attention.

Compresses KV into a low-rank latent for efficient caching. Uses a shared single-head RoPE key (MQA-style) for the position-dependent portion of K.

Config requirements:

- kv_lora_rank: Latent dimension for KV compression.
- q_lora_rank: Latent dimension for Q compression (optional).
- rope_dim: Dimensions allocated for decoupled RoPE.

Source code in src/lmxlab/core/mla.py
@attention_registry.register("mla")
class MLA(AttentionBase):
    """Multi-Head Latent Attention.

    Compresses K/V into a low-rank latent ``c_kv`` so that the cache
    stores only compressed representations instead of full per-head
    keys and values. Uses a shared single-head RoPE key (MQA-style) for
    the position-dependent portion of K.

    Config requirements:
        kv_lora_rank: Latent dimension for KV compression (required).
        q_lora_rank: Latent dimension for Q compression (optional).
        rope_dim: Dimensions allocated for decoupled RoPE (optional;
            ``None``/0 disables it). Must satisfy
            ``0 <= rope_dim <= head_dim``.
        mup: If True, scale attention logits by 1/head_dim (μP)
            instead of 1/sqrt(head_dim).

    Raises:
        ValueError: If ``kv_lora_rank`` is missing or ``rope_dim`` is
            outside ``[0, head_dim]``.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)

        kv_lora_rank = config.kv_lora_rank
        if kv_lora_rank is None:
            raise ValueError("MLA requires kv_lora_rank in BlockConfig")

        self.kv_lora_rank = kv_lora_rank
        self.q_lora_rank = config.q_lora_rank
        self.rope_dim = config.rope_dim or 0
        # The decoupled RoPE dims are carved out of each head; anything
        # outside [0, head_dim] would make nope_dim negative and break
        # the kv_up / reshape shapes with an obscure error.
        if not 0 <= self.rope_dim <= self.head_dim:
            raise ValueError(
                f"MLA requires 0 <= rope_dim <= head_dim, got "
                f"rope_dim={self.rope_dim}, head_dim={self.head_dim}"
            )
        self.nope_dim = self.head_dim - self.rope_dim

        # KV down-projection: produces latent + shared rope key
        # in a single projection for efficiency
        kv_down_dim = kv_lora_rank
        if self.rope_dim > 0:
            kv_down_dim += self.rope_dim  # shared single-head
        self.kv_down = nn.Linear(self.d_model, kv_down_dim, bias=False)
        self.kv_norm = nn.RMSNorm(kv_lora_rank)

        # KV up-projection: latent -> multi-head K_nope and V
        self.kv_up = nn.Linear(
            kv_lora_rank,
            self.n_heads * (self.nope_dim + self.head_dim),
            bias=False,
        )

        # Q projection (optionally compressed via LoRA)
        if self.q_lora_rank is not None:
            self.q_down = nn.Linear(self.d_model, self.q_lora_rank, bias=False)
            self.q_norm = nn.RMSNorm(self.q_lora_rank)
            self.q_up = nn.Linear(
                self.q_lora_rank,
                self.n_heads * self.head_dim,
                bias=False,
            )
        else:
            self.q_proj = nn.Linear(
                self.d_model, self.d_model, bias=config.bias
            )

        # Output projection
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
        # Honour config.mup like the other attention modules:
        # μP uses 1/d_head; SP uses 1/√d_head.
        exp = -1.0 if config.mup else -0.5
        self.scale = self.head_dim**exp

        # RoPE for decoupled dimensions
        if self.rope_dim > 0:
            self._rope = nn.RoPE(
                self.rope_dim,
                traditional=False,
                base=config.rope_theta,
            )

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, mx.array] | None = None,
        rope: nn.Module | None = None,
    ) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
        """Multi-head latent attention forward pass.

        Args:
            x: Input of shape (B, L, d_model).
            mask: Attention mask forwarded unchanged to
                ``mx.fast.scaled_dot_product_attention``.
            cache: ``(c_kv, k_pe)`` from a previous call. ``c_kv`` is
                extended along axis 1 and ``k_pe`` along axis 2; when
                ``rope_dim == 0`` the second element is ignored.
            rope: Unused. MLA applies its own decoupled RoPE
                (``self._rope``) to the ``rope_dim`` slice only.

        Returns:
            ``(output, new_cache)``: output is (B, L, d_model);
            new_cache is ``(c_kv, k_pe)``, or ``(c_kv, c_kv)`` when
            ``rope_dim == 0``.
        """
        B, L, _ = x.shape

        # --- Q projection ---
        if self.q_lora_rank is not None:
            c_q = self.q_norm(self.q_down(x))
            q = self.q_up(c_q)
        else:
            q = self.q_proj(x)

        q = q.reshape(B, L, self.n_heads, self.head_dim)
        q = q.transpose(0, 2, 1, 3)  # (B, n_heads, L, head_dim)

        # --- KV down-projection ---
        compressed = self.kv_down(x)  # (B, L, kv_lora_rank [+ rope])

        if self.rope_dim > 0:
            # Split into latent and shared rope key
            c_kv = compressed[:, :, : self.kv_lora_rank]
            k_pe = compressed[:, :, self.kv_lora_rank :]

            # k_pe is shared single-head: (B, L, rope_dim)
            # -> (B, 1, L, rope_dim) for MQA-style broadcast
            k_pe = k_pe[:, None, :, :]

            # Apply RoPE at the absolute position of the new tokens;
            # cached c_kv is (B, T, kv_lora_rank), so axis 1 is length.
            offset = 0
            if cache is not None:
                offset = cache[0].shape[1]
            q_pe = q[:, :, :, : self.rope_dim]
            q_nope = q[:, :, :, self.rope_dim :]
            q_pe = self._rope(q_pe, offset=offset)
            k_pe = self._rope(k_pe, offset=offset)
            q = mx.concatenate([q_pe, q_nope], axis=-1)
        else:
            c_kv = compressed

        # Normalize latent
        c_kv = self.kv_norm(c_kv)  # (B, L, kv_lora_rank)

        # --- Caching ---
        # Cache: (c_kv, k_pe) — compressed representations
        if cache is not None:
            prev_c_kv, prev_k_pe = cache
            c_kv = mx.concatenate([prev_c_kv, c_kv], axis=1)
            if self.rope_dim > 0:
                k_pe = mx.concatenate([prev_k_pe, k_pe], axis=2)
        # Without decoupled RoPE there is no k_pe; c_kv fills the slot
        # so the cache keeps a uniform 2-tuple shape.
        new_cache = (c_kv, k_pe) if self.rope_dim > 0 else (c_kv, c_kv)

        # --- KV up-projection from latent ---
        kv = self.kv_up(c_kv)
        # (B, total_L, n_heads * (nope_dim + head_dim))
        kv = kv.reshape(
            B, -1, self.n_heads, self.nope_dim + self.head_dim
        ).transpose(0, 2, 1, 3)

        k_nope = kv[:, :, :, : self.nope_dim]
        v = kv[:, :, :, self.nope_dim :]

        # Combine rope and nope K dimensions (same [rope | nope] order
        # as q above, so the dot product pairs matching slices)
        if self.rope_dim > 0:
            # Broadcast shared k_pe (B,1,L,rope) -> (B,n_heads,L,rope)
            k_pe_broad = mx.broadcast_to(
                k_pe,
                (B, self.n_heads, k_pe.shape[2], self.rope_dim),
            )
            k = mx.concatenate([k_pe_broad, k_nope], axis=-1)
        else:
            k = k_nope

        # --- Attention ---
        out = mx.fast.scaled_dot_product_attention(
            q, k, v, scale=self.scale, mask=mask
        )
        out = out.transpose(0, 2, 1, 3).reshape(B, L, self.d_model)
        return self.o_proj(out), new_cache

__init__(config)

Source code in src/lmxlab/core/mla.py
def __init__(self, config: BlockConfig) -> None:
    """Build the MLA projections, latent norms and decoupled RoPE.

    Args:
        config: Block configuration. Requires ``kv_lora_rank``;
            optionally uses ``q_lora_rank``, ``rope_dim``, ``mup``,
            ``bias`` and ``rope_theta``.

    Raises:
        ValueError: If ``kv_lora_rank`` is missing or ``rope_dim`` is
            outside ``[0, head_dim]``.
    """
    super().__init__(config)

    kv_lora_rank = config.kv_lora_rank
    if kv_lora_rank is None:
        raise ValueError("MLA requires kv_lora_rank in BlockConfig")

    self.kv_lora_rank = kv_lora_rank
    self.q_lora_rank = config.q_lora_rank
    self.rope_dim = config.rope_dim or 0
    # The decoupled RoPE dims are carved out of each head; anything
    # outside [0, head_dim] would make nope_dim negative and break the
    # kv_up / reshape shapes with an obscure error.
    if not 0 <= self.rope_dim <= self.head_dim:
        raise ValueError(
            f"MLA requires 0 <= rope_dim <= head_dim, got "
            f"rope_dim={self.rope_dim}, head_dim={self.head_dim}"
        )
    self.nope_dim = self.head_dim - self.rope_dim

    # KV down-projection: produces latent + shared rope key
    # in a single projection for efficiency
    kv_down_dim = kv_lora_rank
    if self.rope_dim > 0:
        kv_down_dim += self.rope_dim  # shared single-head
    self.kv_down = nn.Linear(self.d_model, kv_down_dim, bias=False)
    self.kv_norm = nn.RMSNorm(kv_lora_rank)

    # KV up-projection: latent -> multi-head K_nope and V
    self.kv_up = nn.Linear(
        kv_lora_rank,
        self.n_heads * (self.nope_dim + self.head_dim),
        bias=False,
    )

    # Q projection (optionally compressed via LoRA)
    if self.q_lora_rank is not None:
        self.q_down = nn.Linear(self.d_model, self.q_lora_rank, bias=False)
        self.q_norm = nn.RMSNorm(self.q_lora_rank)
        self.q_up = nn.Linear(
            self.q_lora_rank,
            self.n_heads * self.head_dim,
            bias=False,
        )
    else:
        self.q_proj = nn.Linear(
            self.d_model, self.d_model, bias=config.bias
        )

    # Output projection
    self.o_proj = nn.Linear(self.d_model, self.d_model, bias=config.bias)
    # Honour config.mup like the other attention modules:
    # μP uses 1/d_head; SP uses 1/√d_head.
    exp = -1.0 if config.mup else -0.5
    self.scale = self.head_dim**exp

    # RoPE for decoupled dimensions
    if self.rope_dim > 0:
        self._rope = nn.RoPE(
            self.rope_dim,
            traditional=False,
            base=config.rope_theta,
        )

__call__(x, mask=None, cache=None, rope=None)

Source code in src/lmxlab/core/mla.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, mx.array] | None = None,
    rope: nn.Module | None = None,
) -> tuple[mx.array, tuple[mx.array, mx.array] | None]:
    """Multi-head latent attention forward pass.

    Args:
        x: Input of shape (B, L, d_model).
        mask: Attention mask forwarded unchanged to
            ``mx.fast.scaled_dot_product_attention``.
        cache: ``(c_kv, k_pe)`` from a previous call. ``c_kv`` is
            extended along axis 1 and ``k_pe`` along axis 2; when
            ``rope_dim == 0`` the second element is ignored.
        rope: Unused. MLA rotates only its own ``rope_dim`` slice via
            ``self._rope``.

    Returns:
        ``(output, new_cache)``: output is (B, L, d_model); new_cache is
        ``(c_kv, k_pe)``, or ``(c_kv, c_kv)`` when ``rope_dim == 0``.
    """
    batch, seq_len, _ = x.shape
    has_rope = self.rope_dim > 0

    # Queries, optionally routed through the low-rank bottleneck.
    if self.q_lora_rank is None:
        q_flat = self.q_proj(x)
    else:
        q_flat = self.q_up(self.q_norm(self.q_down(x)))
    queries = q_flat.reshape(
        batch, seq_len, self.n_heads, self.head_dim
    ).transpose(0, 2, 1, 3)

    # One projection yields the KV latent and (optionally) the shared
    # single-head RoPE key.
    latent = self.kv_down(x)
    if has_rope:
        rank = self.kv_lora_rank
        shared_pe = latent[:, :, rank:][:, None, :, :]  # (B, 1, L, rope)
        latent = latent[:, :, :rank]
        # Cached latent is (B, T, rank): axis 1 is the cached length.
        start = 0 if cache is None else cache[0].shape[1]
        rotated_q = self._rope(queries[:, :, :, : self.rope_dim], offset=start)
        shared_pe = self._rope(shared_pe, offset=start)
        queries = mx.concatenate(
            [rotated_q, queries[:, :, :, self.rope_dim :]], axis=-1
        )
    latent = self.kv_norm(latent)

    if cache is not None:
        prev_latent, prev_pe = cache
        latent = mx.concatenate([prev_latent, latent], axis=1)
        if has_rope:
            shared_pe = mx.concatenate([prev_pe, shared_pe], axis=2)
    new_cache = (latent, shared_pe if has_rope else latent)

    # Expand the whole latent history into per-head K_nope and V.
    expanded = (
        self.kv_up(latent)
        .reshape(batch, -1, self.n_heads, self.nope_dim + self.head_dim)
        .transpose(0, 2, 1, 3)
    )
    keys = expanded[:, :, :, : self.nope_dim]
    values = expanded[:, :, :, self.nope_dim :]

    if has_rope:
        total_len = shared_pe.shape[2]
        broadcast_pe = mx.broadcast_to(
            shared_pe, (batch, self.n_heads, total_len, self.rope_dim)
        )
        keys = mx.concatenate([broadcast_pe, keys], axis=-1)

    attended = mx.fast.scaled_dot_product_attention(
        queries, keys, values, scale=self.scale, mask=mask
    )
    merged = attended.transpose(0, 2, 1, 3).reshape(
        batch, seq_len, self.d_model
    )
    return self.o_proj(merged), new_cache

Gated DeltaNet

lmxlab.core.deltanet.GatedDeltaNet

Bases: AttentionBase

Gated Delta Network for linear attention.

Uses the delta rule for error-correcting state updates with learned decay and update gates. The state matrix S has fixed size (d_k, d_v) regardless of sequence length, giving O(1) memory per token during inference.

Forward pass per token:
  1. Project x -> Q, K, V, decay_logits, update_logits
  2. Apply causal convolution on Q, K, V (local context), if enabled
  3. Compute gates: alpha = sigmoid(decay), beta = sigmoid(update)
  4. L2 normalize Q, K
  5. Delta update: S = alpha * S - beta * (S @ k - v) @ k^T
  6. Output: o = S @ q

Cache format: (S, q_conv, k_conv, v_conv), where:

- S: (B, H, head_dim, head_dim) — the state matrix
- q_conv, k_conv, v_conv: (B, H, K-1, head_dim) — convolution histories for Q, K and V

Source code in src/lmxlab/core/deltanet.py
@attention_registry.register("gated_deltanet")
class GatedDeltaNet(AttentionBase):
    """Gated Delta Network for linear attention.

    Uses the delta rule for error-correcting state updates
    with learned decay and update gates. The per-head state
    matrix S has fixed size (head_dim, head_dim) regardless of
    sequence length, giving O(1) memory per token during inference.

    Forward pass per token:
        1. Project x -> Q, K, V, decay_logits, update_logits
        2. If ``use_short_conv``: apply a short causal convolution
           on Q, K, V (local context), then SiLU on Q only
        3. Compute gates: alpha = sigmoid(decay), beta = sigmoid(update)
        4. L2 normalize Q, K
        5. Delta update: S = alpha * S - beta * (S @ k - v) @ k^T
        6. Output: o = S @ q (S's last axis contracted with q, using
           the state *after* this step's update)
        7. Multiply by the SiLU output gate, then apply o_proj

    Cache format: (S, q_conv, k_conv, v_conv) where:
        S: (B, H, head_dim, head_dim) — the state matrix
        q_conv, k_conv, v_conv: conv histories as returned by
            ``_causal_conv1d`` (presumably (B, H, K-1, head_dim) —
            TODO confirm); ``mx.array(0)`` placeholders when
            ``use_short_conv`` is False.
    """

    def __init__(self, config: BlockConfig) -> None:
        """Build projections, per-head gates and optional conv weights.

        Args:
            config: Block configuration. ``conv_kernel_size`` and
                ``use_short_conv`` are read here; ``self.d_model``,
                ``self.n_heads`` and ``self.head_dim`` are presumably
                set by ``AttentionBase.__init__``.
        """
        super().__init__(config)

        # Projections (all bias-free, d_model -> d_model)
        self.q_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.k_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.v_proj = nn.Linear(self.d_model, self.d_model, bias=False)
        self.o_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # Gate projections (per-head scalar gates)
        self.decay_proj = nn.Linear(self.d_model, self.n_heads, bias=True)
        self.update_proj = nn.Linear(self.d_model, self.n_heads, bias=True)

        # Output gate: full d_model width (one gate value per channel,
        # not per head); SiLU is applied to it in __call__.
        self.out_gate_proj = nn.Linear(self.d_model, self.d_model, bias=False)

        # Short causal convolution weights (per head)
        self.conv_kernel_size = config.conv_kernel_size
        self.use_short_conv = config.use_short_conv
        if self.use_short_conv:
            # (n_heads, kernel_size, head_dim), random normal * 0.02
            self.q_conv_w = (
                mx.random.normal(
                    shape=(
                        self.n_heads,
                        self.conv_kernel_size,
                        self.head_dim,
                    )
                )
                * 0.02
            )
            self.k_conv_w = (
                mx.random.normal(
                    shape=(
                        self.n_heads,
                        self.conv_kernel_size,
                        self.head_dim,
                    )
                )
                * 0.02
            )
            self.v_conv_w = (
                mx.random.normal(
                    shape=(
                        self.n_heads,
                        self.conv_kernel_size,
                        self.head_dim,
                    )
                )
                * 0.02
            )

        # Initialize gate biases to -3 so both gates start near
        # sigmoid(-3) ≈ 0.047. For beta this means small
        # (conservative) state writes at init.
        # NOTE(review): alpha multiplies S directly in the update, so
        # alpha ≈ 0.047 discards almost all of the previous state every
        # step at init. Confirm this is intended; a positive decay bias
        # would start alpha near 1 (long memory).
        self.decay_proj.bias = mx.full((self.n_heads,), -3.0)
        self.update_proj.bias = mx.full((self.n_heads,), -3.0)

    def __call__(
        self,
        x: mx.array,
        mask: mx.array | None = None,
        cache: tuple[mx.array, ...] | None = None,
        rope: nn.Module | None = None,
    ) -> tuple[mx.array, tuple[mx.array, ...] | None]:
        """Forward pass with delta rule state updates.

        Args:
            x: Input (B, L, d_model).
            mask: Unused (DeltaNet is inherently causal via
                recurrent state).
            cache: Tuple of (S, q_conv, k_conv, v_conv) for
                autoregressive inference. The conv entries are only
                read when ``use_short_conv`` is True.
            rope: Unused; no positional rotation is applied.

        Returns:
            Tuple of (output, new_cache). ``output`` is (B, L, d_model);
            ``new_cache`` is always a 4-tuple (S, q_conv, k_conv,
            v_conv), with ``mx.array(0)`` placeholders for the conv
            entries when short convolutions are disabled.
        """
        B, L, _ = x.shape

        # Project
        q = self.q_proj(x)
        k = self.k_proj(x)
        v = self.v_proj(x)

        # Gate logits
        decay_logits = self.decay_proj(x)  # (B, L, H)
        update_logits = self.update_proj(x)  # (B, L, H)

        # Output gate
        out_gate = nn.silu(self.out_gate_proj(x))
        # (B, L, d_model)

        # Reshape to (B, H, L, head_dim)
        q = q.reshape(B, L, self.n_heads, self.head_dim)
        q = q.transpose(0, 2, 1, 3)
        k = k.reshape(B, L, self.n_heads, self.head_dim)
        k = k.transpose(0, 2, 1, 3)
        v = v.reshape(B, L, self.n_heads, self.head_dim)
        v = v.transpose(0, 2, 1, 3)

        # Parse cache (conv entries are placeholders when short conv is off)
        if cache is not None:
            S_prev = cache[0]
            q_conv_state = cache[1] if self.use_short_conv else None
            k_conv_state = cache[2] if self.use_short_conv else None
            v_conv_state = cache[3] if self.use_short_conv else None
        else:
            S_prev = mx.zeros((B, self.n_heads, self.head_dim, self.head_dim))
            q_conv_state = None
            k_conv_state = None
            v_conv_state = None

        # Apply short causal convolutions (optional); SiLU on q only
        if self.use_short_conv:
            q, q_conv_state = _causal_conv1d(q, self.q_conv_w, q_conv_state)
            k, k_conv_state = _causal_conv1d(k, self.k_conv_w, k_conv_state)
            v, v_conv_state = _causal_conv1d(v, self.v_conv_w, v_conv_state)
            q = nn.silu(q)

        # L2 normalize Q and K
        q = _l2_normalize(q)
        k = _l2_normalize(k)

        # Compute gates: (B, L, H) -> (B, H, L, 1)
        alpha = mx.sigmoid(decay_logits)
        alpha = alpha.transpose(0, 2, 1)[:, :, :, None]
        beta = mx.sigmoid(update_logits)
        beta = beta.transpose(0, 2, 1)[:, :, :, None]

        # Recurrent delta rule over sequence (sequential Python loop;
        # statement order inside matters: o_t reads the updated S)
        outputs = []
        S = S_prev
        for t in range(L):
            q_t = q[:, :, t, :]  # (B, H, d)
            k_t = k[:, :, t, :]  # (B, H, d)
            v_t = v[:, :, t, :]  # (B, H, d)
            a_t = alpha[:, :, t, :]  # (B, H, 1)
            b_t = beta[:, :, t, :]  # (B, H, 1)

            # Prediction: S @ k -> (B, H, d)
            # S is (B, H, d, d), k_t is (B, H, d); contracts S's last axis
            pred = mx.sum(S * k_t[:, :, None, :], axis=-1)

            # Error: predicted - actual
            error = pred - v_t  # (B, H, d)

            # Delta update: S = alpha * S - beta * error @ k^T
            # error @ k^T: (B, H, d, 1) * (B, H, 1, d)
            # NOTE(review): the Gated DeltaNet paper writes
            # S_t = S_{t-1} alpha_t (I - beta_t k k^T) + beta_t v k^T,
            # i.e. the prediction term uses the decayed state
            # (alpha * S @ k); here it uses the undecayed S. Confirm
            # this variant is intended.
            correction = error[:, :, :, None] * k_t[:, :, None, :]
            S = a_t[:, :, :, None] * S - b_t[:, :, :, None] * correction

            # Output: S @ q -> (B, H, d), same contraction as pred,
            # read from the state after this step's update
            o_t = mx.sum(q_t[:, :, None, :] * S, axis=-1)
            outputs.append(o_t)

        # Stack outputs: (B, H, L, d)
        out = mx.stack(outputs, axis=2)

        # Reshape to (B, L, d_model)
        out = out.transpose(0, 2, 1, 3).reshape(B, L, self.d_model)

        # Apply output gate
        out = out * out_gate

        # Output projection
        out = self.o_proj(out)

        # Build new cache (always a 4-tuple; placeholders keep the
        # structure fixed when short conv is disabled)
        new_cache: tuple[mx.array, ...]
        if self.use_short_conv:
            new_cache = (S, q_conv_state, k_conv_state, v_conv_state)
        else:
            new_cache = (S, mx.array(0), mx.array(0), mx.array(0))

        return out, new_cache

__init__(config)

Source code in src/lmxlab/core/deltanet.py
def __init__(self, config: BlockConfig) -> None:
    """Create the DeltaNet projections, gates and optional conv kernels.

    Args:
        config: Block configuration. ``conv_kernel_size`` and
            ``use_short_conv`` are read here; ``self.d_model``,
            ``self.n_heads`` and ``self.head_dim`` are presumably set by
            ``AttentionBase.__init__``.
    """
    super().__init__(config)
    width = self.d_model

    # Bias-free Q/K/V/output projections (d_model -> d_model).
    self.q_proj = nn.Linear(width, width, bias=False)
    self.k_proj = nn.Linear(width, width, bias=False)
    self.v_proj = nn.Linear(width, width, bias=False)
    self.o_proj = nn.Linear(width, width, bias=False)

    # One scalar decay logit and one update logit per head.
    self.decay_proj = nn.Linear(width, self.n_heads, bias=True)
    self.update_proj = nn.Linear(width, self.n_heads, bias=True)

    # Channel-wise output gate (d_model -> d_model).
    self.out_gate_proj = nn.Linear(width, width, bias=False)

    self.conv_kernel_size = config.conv_kernel_size
    self.use_short_conv = config.use_short_conv
    if self.use_short_conv:
        kernel_shape = (self.n_heads, self.conv_kernel_size, self.head_dim)
        # Drawn in q, k, v order so the RNG stream is consumed the
        # same way as before.
        self.q_conv_w = mx.random.normal(shape=kernel_shape) * 0.02
        self.k_conv_w = mx.random.normal(shape=kernel_shape) * 0.02
        self.v_conv_w = mx.random.normal(shape=kernel_shape) * 0.02

    # Both gate biases start at -3, i.e. sigmoid(-3) ≈ 0.047 at init.
    # NOTE(review): for the decay gate this means the state is almost
    # fully forgotten each step at init — confirm intended.
    self.decay_proj.bias = mx.full((self.n_heads,), -3.0)
    self.update_proj.bias = mx.full((self.n_heads,), -3.0)

__call__(x, mask=None, cache=None, rope=None)

Forward pass with delta rule state updates.

Parameters:

Name Type Description Default
x array

Input (B, L, d_model).

required
mask array | None

Unused (DeltaNet is inherently causal via recurrent state).

None
cache tuple[array, ...] | None

Tuple of (S, q_conv, k_conv, v_conv) for autoregressive inference.

None

Returns:

Type Description
tuple[array, tuple[array, ...] | None]

Tuple of (output, new_cache).

Source code in src/lmxlab/core/deltanet.py
def __call__(
    self,
    x: mx.array,
    mask: mx.array | None = None,
    cache: tuple[mx.array, ...] | None = None,
    rope: nn.Module | None = None,
) -> tuple[mx.array, tuple[mx.array, ...] | None]:
    """Forward pass with delta rule state updates.

    Args:
        x: Input (B, L, d_model).
        mask: Unused (DeltaNet is inherently causal via
            recurrent state).
        cache: Tuple of (S, q_conv, k_conv, v_conv) for
            autoregressive inference. The conv entries are only
            read when ``use_short_conv`` is True.
        rope: Unused; no positional rotation is applied.

    Returns:
        Tuple of (output, new_cache). ``output`` is (B, L, d_model);
        ``new_cache`` is always a 4-tuple (S, q_conv, k_conv, v_conv),
        with ``mx.array(0)`` placeholders for the conv entries when
        short convolutions are disabled.
    """
    B, L, _ = x.shape

    # Project
    q = self.q_proj(x)
    k = self.k_proj(x)
    v = self.v_proj(x)

    # Gate logits
    decay_logits = self.decay_proj(x)  # (B, L, H)
    update_logits = self.update_proj(x)  # (B, L, H)

    # Output gate
    out_gate = nn.silu(self.out_gate_proj(x))
    # (B, L, d_model)

    # Reshape to (B, H, L, head_dim)
    q = q.reshape(B, L, self.n_heads, self.head_dim)
    q = q.transpose(0, 2, 1, 3)
    k = k.reshape(B, L, self.n_heads, self.head_dim)
    k = k.transpose(0, 2, 1, 3)
    v = v.reshape(B, L, self.n_heads, self.head_dim)
    v = v.transpose(0, 2, 1, 3)

    # Parse cache (conv entries are placeholders when short conv is off)
    if cache is not None:
        S_prev = cache[0]
        q_conv_state = cache[1] if self.use_short_conv else None
        k_conv_state = cache[2] if self.use_short_conv else None
        v_conv_state = cache[3] if self.use_short_conv else None
    else:
        S_prev = mx.zeros((B, self.n_heads, self.head_dim, self.head_dim))
        q_conv_state = None
        k_conv_state = None
        v_conv_state = None

    # Apply short causal convolutions (optional); SiLU on q only
    if self.use_short_conv:
        q, q_conv_state = _causal_conv1d(q, self.q_conv_w, q_conv_state)
        k, k_conv_state = _causal_conv1d(k, self.k_conv_w, k_conv_state)
        v, v_conv_state = _causal_conv1d(v, self.v_conv_w, v_conv_state)
        q = nn.silu(q)

    # L2 normalize Q and K
    q = _l2_normalize(q)
    k = _l2_normalize(k)

    # Compute gates: (B, L, H) -> (B, H, L, 1)
    alpha = mx.sigmoid(decay_logits)
    alpha = alpha.transpose(0, 2, 1)[:, :, :, None]
    beta = mx.sigmoid(update_logits)
    beta = beta.transpose(0, 2, 1)[:, :, :, None]

    # Recurrent delta rule over sequence (sequential Python loop;
    # statement order inside matters: o_t reads the updated S)
    outputs = []
    S = S_prev
    for t in range(L):
        q_t = q[:, :, t, :]  # (B, H, d)
        k_t = k[:, :, t, :]  # (B, H, d)
        v_t = v[:, :, t, :]  # (B, H, d)
        a_t = alpha[:, :, t, :]  # (B, H, 1)
        b_t = beta[:, :, t, :]  # (B, H, 1)

        # Prediction: S @ k -> (B, H, d)
        # S is (B, H, d, d), k_t is (B, H, d); contracts S's last axis
        pred = mx.sum(S * k_t[:, :, None, :], axis=-1)

        # Error: predicted - actual
        error = pred - v_t  # (B, H, d)

        # Delta update: S = alpha * S - beta * error @ k^T
        # error @ k^T: (B, H, d, 1) * (B, H, 1, d)
        # NOTE(review): the Gated DeltaNet paper writes
        # S_t = S_{t-1} alpha_t (I - beta_t k k^T) + beta_t v k^T,
        # i.e. the prediction term uses the decayed state
        # (alpha * S @ k); here it uses the undecayed S. Confirm this
        # variant is intended.
        correction = error[:, :, :, None] * k_t[:, :, None, :]
        S = a_t[:, :, :, None] * S - b_t[:, :, :, None] * correction

        # Output: S @ q -> (B, H, d), same contraction as pred,
        # read from the state after this step's update
        o_t = mx.sum(q_t[:, :, None, :] * S, axis=-1)
        outputs.append(o_t)

    # Stack outputs: (B, H, L, d)
    out = mx.stack(outputs, axis=2)

    # Reshape to (B, L, d_model)
    out = out.transpose(0, 2, 1, 3).reshape(B, L, self.d_model)

    # Apply output gate
    out = out * out_gate

    # Output projection
    out = self.o_proj(out)

    # Build new cache (always a 4-tuple; placeholders keep the
    # structure fixed when short conv is disabled)
    new_cache: tuple[mx.array, ...]
    if self.use_short_conv:
        new_cache = (S, q_conv_state, k_conv_state, v_conv_state)
    else:
        new_cache = (S, mx.array(0), mx.array(0), mx.array(0))

    return out, new_cache

Feed-Forward Networks

lmxlab.core.ffn.StandardFFN

Bases: FFNBase

Standard two-layer feed-forward network with GELU activation.

FFN(x) = W2 * GELU(W1 * x + b1) + b2

Source code in src/lmxlab/core/ffn.py
@ffn_registry.register("standard")
class StandardFFN(FFNBase):
    """Position-wise two-layer MLP with a GELU non-linearity.

    Computes ``down(gelu(up(x)))``, i.e.
    FFN(x) = W2 * GELU(W1 * x + b1) + b2, where the biases are present
    only when ``config.bias`` is true.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)
        model_dim, inner_dim = config.d_model, config.d_ff
        self.up = nn.Linear(model_dim, inner_dim, bias=config.bias)
        self.down = nn.Linear(inner_dim, model_dim, bias=config.bias)

    def __call__(self, x: mx.array) -> mx.array:
        expanded = self.up(x)
        activated = nn.gelu(expanded)
        return self.down(activated)

__init__(config)

Source code in src/lmxlab/core/ffn.py
def __init__(self, config: BlockConfig) -> None:
    """Create the d_model -> d_ff -> d_model projection pair."""
    super().__init__(config)
    model_dim, inner_dim = config.d_model, config.d_ff
    self.up = nn.Linear(model_dim, inner_dim, bias=config.bias)
    self.down = nn.Linear(inner_dim, model_dim, bias=config.bias)

__call__(x)

Source code in src/lmxlab/core/ffn.py
def __call__(self, x: mx.array) -> mx.array:
    """Return ``down(gelu(up(x)))``."""
    expanded = self.up(x)
    activated = nn.gelu(expanded)
    return self.down(activated)

lmxlab.core.ffn.GatedFFN

Bases: FFNBase

Gated feed-forward network (SwiGLU variant).

FFN(x) = W_down * (SiLU(W_gate * x) ⊙ (W_up * x)), where ⊙ is element-wise multiplication. Used in LLaMA, Mistral, etc.

Source code in src/lmxlab/core/ffn.py
@ffn_registry.register("gated")
class GatedFFN(FFNBase):
    """SwiGLU feed-forward block.

    FFN(x) = W_down * (SiLU(W_gate * x) ⊙ (W_up * x)), with ⊙ the
    element-wise product. This is the FFN variant used in LLaMA,
    Mistral, etc.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__(config)
        d_model, d_ff, use_bias = config.d_model, config.d_ff, config.bias
        self.gate = nn.Linear(d_model, d_ff, bias=use_bias)
        self.up = nn.Linear(d_model, d_ff, bias=use_bias)
        self.down = nn.Linear(d_ff, d_model, bias=use_bias)

    def __call__(self, x: mx.array) -> mx.array:
        gated = nn.silu(self.gate(x))
        return self.down(gated * self.up(x))

__init__(config)

Source code in src/lmxlab/core/ffn.py
def __init__(self, config: BlockConfig) -> None:
    """Create the gate, up (d_model -> d_ff) and down (d_ff -> d_model) layers."""
    super().__init__(config)
    d_model, d_ff, use_bias = config.d_model, config.d_ff, config.bias
    self.gate = nn.Linear(d_model, d_ff, bias=use_bias)
    self.up = nn.Linear(d_model, d_ff, bias=use_bias)
    self.down = nn.Linear(d_ff, d_model, bias=use_bias)

__call__(x)

Source code in src/lmxlab/core/ffn.py
def __call__(self, x: mx.array) -> mx.array:
    """Return ``down(silu(gate(x)) * up(x))``."""
    gated = nn.silu(self.gate(x))
    return self.down(gated * self.up(x))

Mixture of Experts

lmxlab.core.moe.MoEFFN

Bases: Module

Mixture of Experts feed-forward network.

Routes each token to the top-k experts via a learned router. Each expert is a GatedFFN (SwiGLU).

Parameters:

Name Type Description Default
config BlockConfig

Block configuration.

required
n_experts int | None

Total number of experts.

None
top_k int | None

Number of experts per token.

None
Source code in src/lmxlab/core/moe.py
@ffn_registry.register("moe")
class MoEFFN(nn.Module):
    """Mixture of Experts feed-forward network.

    Routes each token to the top-k experts via a learned
    router. Each expert is a GatedFFN (SwiGLU).

    Args:
        config: Block configuration.
        n_experts: Total number of experts. Falls back to
            ``config.n_experts``, then 8.
        top_k: Number of experts per token. Falls back to
            ``config.top_k_experts``.
    """

    def __init__(
        self,
        config: BlockConfig,
        n_experts: int | None = None,
        top_k: int | None = None,
    ) -> None:
        super().__init__()
        self.n_experts = n_experts or config.n_experts or 8
        self.top_k = top_k or config.top_k_experts

        # Router: projects hidden states to expert logits
        self.router = nn.Linear(config.d_model, self.n_experts, bias=False)

        # Expert FFNs
        self.experts = [GatedFFN(config) for _ in range(self.n_experts)]

    def __call__(self, x: mx.array) -> mx.array:
        """Route tokens to top-k experts and combine outputs.

        Args:
            x: Input tensor (..., d_model), typically
                (batch, seq_len, d_model).

        Returns:
            Output tensor with the same shape as ``x``.
        """
        # Compute routing weights
        router_logits = self.router(x)  # (..., n_experts)

        # Select top-k experts. With kth=top_k-1 the first top_k slots
        # hold the top_k largest logits, and kth stays in range even
        # when top_k == n_experts (kth=top_k would be out of bounds).
        top_k_indices = mx.argpartition(
            -router_logits, kth=self.top_k - 1, axis=-1
        )[..., : self.top_k]  # (..., top_k)

        # Softmax over top-k logits only (Mixtral convention)
        top_k_logits = mx.take_along_axis(
            router_logits, top_k_indices, axis=-1
        )
        top_k_weights = mx.softmax(top_k_logits, axis=-1)

        # Run each expert at most once and weight its output by the gate
        # each token assigned to it (0 for tokens not routed to it).
        output = mx.zeros_like(x)
        for e, expert in enumerate(self.experts):
            selected = top_k_indices == e  # (..., top_k)
            if not mx.any(selected).item():
                continue
            # argpartition yields distinct indices, so at most one of a
            # token's top_k slots matches e; the sum extracts that weight.
            gate = mx.sum(
                top_k_weights * selected, axis=-1, keepdims=True
            )  # (..., 1)
            output = output + expert(x) * gate

        return output

__init__(config, n_experts=None, top_k=None)

Source code in src/lmxlab/core/moe.py
def __init__(
    self,
    config: BlockConfig,
    n_experts: int | None = None,
    top_k: int | None = None,
) -> None:
    """Build the router and the pool of GatedFFN experts.

    Explicit ``n_experts``/``top_k`` take precedence; otherwise the
    values come from ``config`` (with 8 experts as a last resort).
    """
    super().__init__()
    self.n_experts = n_experts or config.n_experts or 8
    self.top_k = top_k or config.top_k_experts

    # One bias-free logit per expert.
    self.router = nn.Linear(config.d_model, self.n_experts, bias=False)
    self.experts = [GatedFFN(config) for _ in range(self.n_experts)]

__call__(x)

Route tokens to top-k experts and combine outputs.

Parameters:

Name Type Description Default
x array

Input tensor (batch, seq_len, d_model).

required

Returns:

Type Description
array

Output tensor (batch, seq_len, d_model).

Source code in src/lmxlab/core/moe.py
def __call__(self, x: mx.array) -> mx.array:
    """Route tokens to top-k experts and combine outputs.

    Args:
        x: Input tensor (..., d_model), typically
            (batch, seq_len, d_model).

    Returns:
        Output tensor with the same shape as ``x``.
    """
    # Compute routing weights
    router_logits = self.router(x)  # (..., n_experts)

    # Select top-k experts. With kth=top_k-1 the first top_k slots hold
    # the top_k largest logits, and kth stays in range even when
    # top_k == n_experts (kth=top_k would be out of bounds).
    top_k_indices = mx.argpartition(
        -router_logits, kth=self.top_k - 1, axis=-1
    )[..., : self.top_k]  # (..., top_k)

    # Softmax over top-k logits only (Mixtral convention)
    top_k_logits = mx.take_along_axis(
        router_logits, top_k_indices, axis=-1
    )
    top_k_weights = mx.softmax(top_k_logits, axis=-1)

    # Run each expert at most once and weight its output by the gate
    # each token assigned to it (0 for tokens not routed to it).
    output = mx.zeros_like(x)
    for e, expert in enumerate(self.experts):
        selected = top_k_indices == e  # (..., top_k)
        if not mx.any(selected).item():
            continue
        # argpartition yields distinct indices, so at most one of a
        # token's top_k slots matches e; the sum extracts that weight.
        gate = mx.sum(
            top_k_weights * selected, axis=-1, keepdims=True
        )  # (..., 1)
        output = output + expert(x) * gate

    return output

lmxlab.core.moe.SharedExpertMoEFFN

Bases: Module

MoE with shared experts and bias-based load balancing.

Combines a set of always-active shared experts with top-k routed experts. Uses aux-loss-free load balancing: a learnable bias is added to router logits for expert selection, but the original un-biased scores are used for gating weights.

Output = shared_experts(x) + routed_experts_output(x)

Parameters:

Name Type Description Default
config BlockConfig

Block configuration.

required
n_experts int | None

Number of routed experts.

None
top_k int | None

Number of routed experts per token.

None
n_shared int | None

Number of shared (always-active) experts.

None
Source code in src/lmxlab/core/moe.py
@ffn_registry.register("shared_moe")
class SharedExpertMoEFFN(nn.Module):
    """MoE with shared experts and bias-based load balancing.

    Combines a set of always-active shared experts with top-k
    routed experts. Uses aux-loss-free load balancing: a
    bias (``expert_bias``) is added to router logits for expert
    selection, but the original un-biased scores are used
    for gating weights.

    Output = shared_experts(x) + routed_experts_output(x)

    Args:
        config: Block configuration.
        n_experts: Number of routed experts. Falls back to
            ``config.n_experts``, then 8.
        top_k: Number of routed experts per token. Falls back to
            ``config.top_k_experts``.
        n_shared: Number of shared (always-active) experts. Falls
            back to ``config.n_shared_experts``, then 1.
    """

    def __init__(
        self,
        config: BlockConfig,
        n_experts: int | None = None,
        top_k: int | None = None,
        n_shared: int | None = None,
    ) -> None:
        super().__init__()
        self.n_experts = n_experts or config.n_experts or 8
        self.top_k = top_k or config.top_k_experts
        self.n_shared = n_shared or config.n_shared_experts or 1

        # Router: projects hidden states to expert logits
        self.router = nn.Linear(config.d_model, self.n_experts, bias=False)

        # Bias for aux-loss-free load balancing.
        # Added to logits for selection only, not for weights.
        # NOTE(review): its only use is inside argpartition, whose integer
        # output carries no gradient, so backprop never updates it;
        # presumably it is adjusted outside the optimizer step — confirm
        # against the training loop.
        self.expert_bias = mx.zeros((self.n_experts,))

        # Routed experts
        self.experts = [GatedFFN(config) for _ in range(self.n_experts)]

        # Shared experts (always active, not gated)
        self.shared_experts = [GatedFFN(config) for _ in range(self.n_shared)]

    def __call__(self, x: mx.array) -> mx.array:
        """Route tokens and combine with shared expert output.

        Args:
            x: Input tensor (..., d_model), typically
                (batch, seq_len, d_model).

        Returns:
            Output tensor with the same shape as ``x``.
        """
        # --- Shared expert path (always active) ---
        shared_out = self.shared_experts[0](x)
        for i in range(1, self.n_shared):
            shared_out = shared_out + self.shared_experts[i](x)

        # --- Routed expert path ---
        router_logits = self.router(x)  # (..., E)

        # Bias-based selection: add bias for top-k picking
        biased_logits = router_logits + self.expert_bias

        # Select top-k using biased logits. kth=top_k-1 keeps kth in range
        # even when top_k == n_experts.
        top_k_indices = mx.argpartition(
            -biased_logits, kth=self.top_k - 1, axis=-1
        )[..., : self.top_k]  # (..., top_k)

        # Softmax over top-k un-biased logits (DSV3)
        top_k_logits = mx.take_along_axis(
            router_logits, top_k_indices, axis=-1
        )
        top_k_weights = mx.softmax(top_k_logits, axis=-1)

        # Run each routed expert at most once, weighted per token by the
        # gate it received (0 for tokens not routed to it).
        routed_out = mx.zeros_like(x)
        for e, expert in enumerate(self.experts):
            selected = top_k_indices == e  # (..., top_k)
            if not mx.any(selected).item():
                continue
            # Indices are distinct per token, so at most one slot matches.
            gate = mx.sum(
                top_k_weights * selected, axis=-1, keepdims=True
            )  # (..., 1)
            routed_out = routed_out + expert(x) * gate

        return shared_out + routed_out

__init__(config, n_experts=None, top_k=None, n_shared=None)

Source code in src/lmxlab/core/moe.py
def __init__(
    self,
    config: BlockConfig,
    n_experts: int | None = None,
    top_k: int | None = None,
    n_shared: int | None = None,
) -> None:
    """Build the router, load-balancing bias, routed and shared experts.

    Explicit arguments take precedence over ``config``; fallbacks are
    8 routed experts and 1 shared expert.
    """
    super().__init__()
    self.n_experts = n_experts or config.n_experts or 8
    self.top_k = top_k or config.top_k_experts
    self.n_shared = n_shared or config.n_shared_experts or 1

    # One bias-free logit per routed expert.
    self.router = nn.Linear(config.d_model, self.n_experts, bias=False)

    # Selection-only bias (not used for gate weights).
    # NOTE(review): only consumed by argpartition, so it gets no gradient;
    # presumably updated outside backprop — confirm.
    self.expert_bias = mx.zeros((self.n_experts,))

    self.experts = [GatedFFN(config) for _ in range(self.n_experts)]
    # Always-active experts whose outputs are summed ungated.
    self.shared_experts = [GatedFFN(config) for _ in range(self.n_shared)]

__call__(x)

Route tokens and combine with shared expert output.

Parameters:

Name Type Description Default
x array

Input tensor (batch, seq_len, d_model).

required

Returns:

Type Description
array

Output tensor (batch, seq_len, d_model).

Source code in src/lmxlab/core/moe.py
def __call__(self, x: mx.array) -> mx.array:
    """Route tokens and combine with shared expert output.

    Args:
        x: Input tensor (..., d_model), typically
            (batch, seq_len, d_model).

    Returns:
        Output tensor with the same shape as ``x``.
    """
    # --- Shared expert path (always active) ---
    shared_out = self.shared_experts[0](x)
    for i in range(1, self.n_shared):
        shared_out = shared_out + self.shared_experts[i](x)

    # --- Routed expert path ---
    router_logits = self.router(x)  # (..., E)

    # Bias-based selection: add bias for top-k picking
    biased_logits = router_logits + self.expert_bias

    # Select top-k using biased logits. kth=top_k-1 keeps kth in range
    # even when top_k == n_experts.
    top_k_indices = mx.argpartition(
        -biased_logits, kth=self.top_k - 1, axis=-1
    )[..., : self.top_k]  # (..., top_k)

    # Softmax over top-k un-biased logits (DSV3)
    top_k_logits = mx.take_along_axis(
        router_logits, top_k_indices, axis=-1
    )
    top_k_weights = mx.softmax(top_k_logits, axis=-1)

    # Run each routed expert at most once, weighted per token by the
    # gate it received (0 for tokens not routed to it).
    routed_out = mx.zeros_like(x)
    for e, expert in enumerate(self.experts):
        selected = top_k_indices == e  # (..., top_k)
        if not mx.any(selected).item():
            continue
        # Indices are distinct per token, so at most one slot matches.
        gate = mx.sum(
            top_k_weights * selected, axis=-1, keepdims=True
        )  # (..., 1)
        routed_out = routed_out + expert(x) * gate

    return shared_out + routed_out

Normalization

lmxlab.core.norm

Normalization wrappers for registry use.

LayerNorm

Bases: LayerNorm

LayerNorm wrapper that constructs from BlockConfig.

Source code in src/lmxlab/core/norm.py
@norm_registry.register("layer_norm")
class LayerNorm(nn.LayerNorm):
    """``nn.LayerNorm`` sized and configured from a BlockConfig.

    Uses ``config.d_model`` as the normalized dimension and
    ``config.norm_eps`` as epsilon.
    """

    def __init__(self, config: BlockConfig) -> None:
        dims, eps = config.d_model, config.norm_eps
        super().__init__(dims, eps=eps)

RMSNorm

Bases: RMSNorm

RMSNorm wrapper that constructs from BlockConfig.

Source code in src/lmxlab/core/norm.py
@norm_registry.register("rms_norm")
class RMSNorm(nn.RMSNorm):
    """``nn.RMSNorm`` sized and configured from a BlockConfig.

    Uses ``config.d_model`` as the normalized dimension and
    ``config.norm_eps`` as epsilon.
    """

    def __init__(self, config: BlockConfig) -> None:
        dims, eps = config.d_model, config.norm_eps
        super().__init__(dims, eps=eps)

layer_norm(config)

Create a LayerNorm from config.

Source code in src/lmxlab/core/norm.py
def layer_norm(config: BlockConfig) -> LayerNorm:
    """Factory returning a config-driven :class:`LayerNorm`."""
    norm = LayerNorm(config)
    return norm

rms_norm(config)

Create an RMSNorm from config.

Source code in src/lmxlab/core/norm.py
def rms_norm(config: BlockConfig) -> RMSNorm:
    """Factory returning a config-driven :class:`RMSNorm`."""
    norm = RMSNorm(config)
    return norm

Position Encoding

lmxlab.core.position

Positional encoding modules: RoPE, ALiBi, Sinusoidal.

ALiBi

Bases: Module

Attention with Linear Biases (Press et al. ICLR 2022, arXiv:2108.12409).

Adds head-specific distance-based biases to the attention mask. Each head gets a fixed slope from a geometric sequence, penalizing distant tokens more strongly. Replaces explicit positional embeddings — no PE is added to inputs.

Applied to the attention mask before softmax, not to Q/K.

Source code in src/lmxlab/core/position.py
@position_registry.register("alibi")
class ALiBi(nn.Module):
    """ALiBi positional biasing (Press et al. ICLR 2022, arXiv:2108.12409).

    No positional embedding is added to the inputs. Instead, each head
    receives a bias proportional to the query/key distance, with
    per-head slopes drawn from a geometric sequence, so distant tokens
    are penalized more. The bias is applied to the attention mask
    before softmax; Q and K are left untouched.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__()
        self._alibi = nn.ALiBi()
        self.n_heads = config.n_heads

    def __call__(
        self,
        mask: mx.array | None = None,
        seq_len: int = 0,
        cache_len: int = 0,
    ) -> mx.array:
        """Build the ALiBi bias, merged with ``mask`` when one is given.

        Args:
            mask: Optional causal mask of shape (T_q, T_k). When given,
                its last two dimensions fix the query/key lengths.
            seq_len: Query length, used only when ``mask`` is None.
            cache_len: Number of cached key positions; used as the
                query position offset (and added to the key length when
                ``mask`` is None).

        Returns:
            ALiBi bias plus mask, shape (1, H, T_q, T_k).
        """
        if mask is None:
            q_len, k_len = seq_len, seq_len + cache_len
        else:
            q_len, k_len = mask.shape[-2], mask.shape[-1]
        # nn.ALiBi adds its bias onto a score tensor; starting from zeros
        # yields the bias (plus mask) by itself.
        zero_scores = mx.zeros((1, self.n_heads, q_len, k_len))
        return self._alibi(zero_scores, offset=cache_len, mask=mask)

__call__(mask=None, seq_len=0, cache_len=0)

Create ALiBi-biased attention mask.

Parameters:

Name Type Description Default
mask array | None

Optional causal mask (T_q, T_k).

None
seq_len int

Query sequence length.

0
cache_len int

Cached key sequence length (offset).

0

Returns:

Type Description
array

Combined ALiBi bias + mask (1, H, T_q, T_k).

Source code in src/lmxlab/core/position.py
def __call__(
    self,
    mask: mx.array | None = None,
    seq_len: int = 0,
    cache_len: int = 0,
) -> mx.array:
    """Build the ALiBi bias, merged with ``mask`` when one is given.

    Args:
        mask: Optional causal mask of shape (T_q, T_k). When given, its
            last two dimensions fix the query/key lengths.
        seq_len: Query length, used only when ``mask`` is None.
        cache_len: Number of cached key positions; used as the query
            position offset (and added to the key length when ``mask``
            is None).

    Returns:
        ALiBi bias plus mask, shape (1, H, T_q, T_k).
    """
    if mask is None:
        q_len, k_len = seq_len, seq_len + cache_len
    else:
        q_len, k_len = mask.shape[-2], mask.shape[-1]
    # Zero "scores" so the wrapped nn.ALiBi returns just bias (+ mask).
    zero_scores = mx.zeros((1, self.n_heads, q_len, k_len))
    return self._alibi(zero_scores, offset=cache_len, mask=mask)

NoPosition

Bases: Module

No positional encoding (identity).

Used by architectures that get position information from other mechanisms (e.g. causal convolutions in DeltaNet).

Source code in src/lmxlab/core/position.py
@position_registry.register("none")
class NoPosition(nn.Module):
    """Identity positional encoding.

    Meant for architectures whose position information comes from
    another mechanism (e.g. causal convolutions in DeltaNet). ``config``
    is unused; it is accepted so the constructor matches the other
    registered encodings.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__()

    def __call__(self, x: mx.array) -> mx.array:
        """Identity: hand ``x`` back untouched."""
        unchanged = x
        return unchanged

__call__(x)

Return input unchanged.

Source code in src/lmxlab/core/position.py
def __call__(self, x: mx.array) -> mx.array:
    """Identity: hand ``x`` back untouched."""
    unchanged = x
    return unchanged

RoPE

Bases: Module

Rotary Position Embedding wrapper.

Wraps nn.RoPE with config-driven initialization.

Source code in src/lmxlab/core/position.py
@position_registry.register("rope")
class RoPE(nn.Module):
    """Config-driven wrapper around ``nn.RoPE`` (rotary position embedding).

    Rotates over ``config.head_dim`` dimensions using the non-traditional
    layout and ``config.rope_theta`` as the base frequency.
    """

    def __init__(self, config: BlockConfig) -> None:
        super().__init__()
        rotary_dims = config.head_dim
        self._rope = nn.RoPE(
            rotary_dims, traditional=False, base=config.rope_theta
        )

    def __call__(
        self,
        q: mx.array,
        k: mx.array,
        offset: int = 0,
    ) -> tuple[mx.array, mx.array]:
        """Rotate queries and keys by their positions.

        Args:
            q: Queries (batch, heads, seq, head_dim).
            k: Keys (batch, kv_heads, seq, head_dim).
            offset: Starting position, e.g. the KV-cache length.

        Returns:
            ``(rotated_q, rotated_k)``.
        """
        return tuple(self._rope(t, offset=offset) for t in (q, k))

__call__(q, k, offset=0)

Apply rotary embeddings to queries and keys.

Parameters:

Name Type Description Default
q array

Query tensor (batch, heads, seq, head_dim).

required
k array

Key tensor (batch, kv_heads, seq, head_dim).

required
offset int

Position offset for KV cache.

0

Returns:

Type Description
tuple[array, array]

Tuple of (rotated_q, rotated_k).

Source code in src/lmxlab/core/position.py
def __call__(
    self,
    q: mx.array,
    k: mx.array,
    offset: int = 0,
) -> tuple[mx.array, mx.array]:
    """Rotate queries and keys by their positions.

    Args:
        q: Queries (batch, heads, seq, head_dim).
        k: Keys (batch, kv_heads, seq, head_dim).
        offset: Starting position, e.g. the KV-cache length.

    Returns:
        ``(rotated_q, rotated_k)``.
    """
    return tuple(self._rope(t, offset=offset) for t in (q, k))

Sinusoidal

Bases: Module

Sinusoidal positional encoding (added to embeddings).

Source code in src/lmxlab/core/position.py
@position_registry.register("sinusoidal")
class Sinusoidal(nn.Module):
    """Sinusoidal positional encoding (added to embeddings)."""

    def __init__(self, config: BlockConfig) -> None:
        super().__init__()
        # NOTE(review): full_turns=True scales the frequencies by 2*pi. If
        # nn.SinusoidalPositionalEncoding's max_freq default is 1, the top
        # channel becomes sin/cos(2*pi*p), constant for integer positions p.
        # Left unchanged for compatibility with trained models — confirm.
        self._embed = nn.SinusoidalPositionalEncoding(
            config.d_model,
            full_turns=True,
        )

    def __call__(self, x: mx.array, offset: int = 0) -> mx.array:
        """Add sinusoidal position encoding to input.

        Args:
            x: Input tensor (batch, seq_len, d_model).
            offset: Position of the first token in ``x`` (e.g. the
                number of tokens already processed when decoding with
                a KV cache). Defaults to 0, i.e. positions 0..seq_len-1.

        Returns:
            Input with positional encoding added.
        """
        seq_len = x.shape[1]
        positions = mx.arange(offset, offset + seq_len)
        pe = self._embed(positions)  # (seq_len, d_model)
        return x + pe

__call__(x)

Add sinusoidal position encoding to input.

Parameters:

Name Type Description Default
x array

Input tensor (batch, seq_len, d_model).

required

Returns:

Type Description
array

Input with positional encoding added.

Source code in src/lmxlab/core/position.py
def __call__(self, x: mx.array, offset: int = 0) -> mx.array:
    """Add sinusoidal position encoding to input.

    Args:
        x: Input tensor (batch, seq_len, d_model).
        offset: Position of the first token in ``x`` (e.g. the number of
            tokens already processed when decoding with a KV cache).
            Defaults to 0, i.e. positions 0..seq_len-1.

    Returns:
        Input with positional encoding added.
    """
    seq_len = x.shape[1]
    positions = mx.arange(offset, offset + seq_len)
    pe = self._embed(positions)  # (seq_len, d_model)
    return x + pe

alibi(config)

Create an ALiBi module from config.

Source code in src/lmxlab/core/position.py
def alibi(config: BlockConfig) -> ALiBi:
    """Factory returning an :class:`ALiBi` module for ``config``."""
    module = ALiBi(config)
    return module

rope(config)

Create a RoPE module from config.

Source code in src/lmxlab/core/position.py
def rope(config: BlockConfig) -> RoPE:
    """Factory returning a :class:`RoPE` module for ``config``."""
    module = RoPE(config)
    return module

sinusoidal(config)

Create a Sinusoidal module from config.

Source code in src/lmxlab/core/position.py
def sinusoidal(config: BlockConfig) -> Sinusoidal:
    """Factory returning a :class:`Sinusoidal` module for ``config``."""
    module = Sinusoidal(config)
    return module

Quantization

lmxlab.core.quantize.quantize_model(model, bits=4, group_size=64, mode='affine')

Quantize all Linear and Embedding layers in-place.

Uses MLX's native quantization. Linear layers become nn.QuantizedLinear, Embedding layers become nn.QuantizedEmbedding. Norm layers and other modules are left unchanged.

Parameters:

Name Type Description Default
model Module

Model to quantize (modified in-place).

required
bits int

Bits per weight (2, 4, or 8).

4
group_size int

Quantization group size (32, 64, or 128).

64
mode str

Quantization mode. Default: "affine".

'affine'
Source code in src/lmxlab/core/quantize.py
def quantize_model(
    model: nn.Module,
    bits: int = 4,
    group_size: int = 64,
    mode: str = "affine",
) -> None:
    """Quantize a model's Linear and Embedding layers in place.

    Delegates to MLX's ``nn.quantize``: ``nn.Linear`` layers become
    ``nn.QuantizedLinear`` and ``nn.Embedding`` layers become
    ``nn.QuantizedEmbedding``; norms and other modules are untouched.

    Args:
        model: Model to quantize (modified in-place).
        bits: Bits per weight (2, 4, or 8).
        group_size: Quantization group size (32, 64, or 128).
        mode: Quantization mode. Default: ``"affine"``.
    """
    options = {"group_size": group_size, "bits": bits, "mode": mode}
    nn.quantize(model, **options)

lmxlab.core.quantize.dequantize_model(model)

Dequantize all QuantizedLinear and QuantizedEmbedding layers back to Linear and Embedding layers.

Reconstructs float weights from quantized representation. Useful for fine-tuning after loading quantized weights.

Parameters:

Name Type Description Default
model Module

Model to dequantize (modified in-place).

required
Source code in src/lmxlab/core/quantize.py
def dequantize_model(model: nn.Module) -> None:
    """Dequantize quantized Linear/Embedding layers back to float layers.

    Replaces every ``nn.QuantizedLinear`` with an ``nn.Linear`` and every
    ``nn.QuantizedEmbedding`` with an ``nn.Embedding``, reconstructing
    float weights from the quantized representation. Useful for
    fine-tuning after loading quantized weights.

    Args:
        model: Model to dequantize (modified in-place).
    """

    def _float_weight(m: nn.Module) -> mx.array:
        # Use the layer's own quantization mode so layers produced by
        # quantize_model(mode=...) with a non-affine mode are reconstructed
        # correctly; fall back to "affine", quantize_model's default.
        return mx.dequantize(
            m.weight,
            m.scales,
            m.get("biases"),
            m.group_size,
            m.bits,
            mode=getattr(m, "mode", "affine"),
        )

    def _maybe_dequantize(_path: str, m: nn.Module) -> nn.Module:
        if isinstance(m, nn.QuantizedLinear):
            weight = _float_weight(m)
            has_bias = "bias" in m
            # weight is (output_dims, input_dims) after dequantization.
            linear = nn.Linear(weight.shape[1], weight.shape[0], bias=has_bias)
            linear.weight = weight
            if has_bias:
                linear.bias = m.bias
            return linear
        if isinstance(m, nn.QuantizedEmbedding):
            weight = _float_weight(m)
            embed = nn.Embedding(weight.shape[0], weight.shape[1])
            embed.weight = weight
            return embed
        return m

    leaves = model.leaf_modules()
    leaves = tree_map_with_path(
        _maybe_dequantize, leaves, is_leaf=nn.Module.is_module
    )
    model.update_modules(leaves)

LoRA

lmxlab.core.lora.LoRALinear

Bases: Module

Linear layer with low-rank adaptation.

Computes: y = xW^T + b + scaling * x @ A @ B, where A has shape (input_dims, rank) and B has shape (rank, output_dims).

where W is frozen and A, B are trainable low-rank matrices. B is zero-initialized so the initial output equals the base linear layer's output.

Parameters:

Name Type Description Default
input_dims int

Input feature dimension.

required
output_dims int

Output feature dimension.

required
rank int

LoRA rank (low-rank dimension).

8
alpha float

LoRA scaling factor. Effective scaling = alpha/rank.

1.0
bias bool

Whether the base layer has bias.

False
Source code in src/lmxlab/core/lora.py
class LoRALinear(nn.Module):
    """Linear layer with low-rank adaptation.

    Computes: y = x @ W^T + b + scaling * (x @ A @ B)

    with W of shape (output_dims, input_dims), A of shape
    (input_dims, rank) and B of shape (rank, output_dims).
    W is frozen and A, B are trainable low-rank matrices.
    B is zero-initialized so the initial output equals the base
    linear layer's output.

    Args:
        input_dims: Input feature dimension.
        output_dims: Output feature dimension.
        rank: LoRA rank (low-rank dimension).
        alpha: LoRA scaling factor. Effective scaling = alpha/rank.
        bias: Whether the base layer has bias.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        rank: int = 8,
        alpha: float = 1.0,
        bias: bool = False,
    ) -> None:
        super().__init__()

        self.rank = rank
        self.scaling = alpha / rank

        # Base weight (frozen)
        scale = math.sqrt(1 / input_dims)
        self.weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        if bias:
            self.bias = mx.zeros((output_dims,))

        # LoRA matrices (trainable)
        # A: Kaiming normal init
        self.lora_A = mx.random.normal((input_dims, rank)) * math.sqrt(
            2 / input_dims
        )
        # B: zero init (so initial LoRA contribution is zero)
        self.lora_B = mx.zeros((rank, output_dims))

        # Freeze base weight, keep LoRA trainable
        self.freeze(keys=["weight", "bias"], recurse=False)

    def __call__(self, x: mx.array) -> mx.array:
        # Base: x @ W^T + bias
        y = x @ self.weight.T
        if "bias" in self:
            y = y + self.bias
        # LoRA: scaling * x @ A @ B
        y = y + (x @ self.lora_A @ self.lora_B) * self.scaling
        return y

    @classmethod
    def from_linear(
        cls,
        linear: nn.Linear,
        rank: int = 8,
        alpha: float = 1.0,
    ) -> "LoRALinear":
        """Create a LoRALinear from an existing nn.Linear.

        Copies the base weight and bias, then adds LoRA matrices.

        Args:
            linear: Base linear layer to wrap.
            rank: LoRA rank.
            alpha: LoRA scaling factor.

        Returns:
            LoRALinear with the same base weights.
        """
        output_dims, input_dims = linear.weight.shape
        has_bias = "bias" in linear

        lora = cls(input_dims, output_dims, rank, alpha, bias=has_bias)
        lora.weight = linear.weight
        if has_bias:
            lora.bias = linear.bias
        # Re-freeze after setting weights
        lora.freeze(keys=["weight", "bias"], recurse=False)
        return lora

    def to_linear(self) -> nn.Linear:
        """Merge LoRA weights and return a plain nn.Linear.

        Computes W_merged = W + scaling * (A @ B)^T and returns
        a new nn.Linear with the merged weight. The LoRA delta is cast
        to the base weight's dtype, so merging does not change the
        layer's dtype (e.g. a float16 base stays float16 even though
        the LoRA matrices are created in float32).
        """
        has_bias = "bias" in self
        delta = (self.lora_A @ self.lora_B).T * self.scaling
        merged_weight = self.weight + delta.astype(self.weight.dtype)
        linear = nn.Linear(
            self.weight.shape[1],
            self.weight.shape[0],
            bias=has_bias,
        )
        linear.weight = merged_weight
        if has_bias:
            linear.bias = self.bias
        return linear

__init__(input_dims, output_dims, rank=8, alpha=1.0, bias=False)

Source code in src/lmxlab/core/lora.py
def __init__(
    self,
    input_dims: int,
    output_dims: int,
    rank: int = 8,
    alpha: float = 1.0,
    bias: bool = False,
) -> None:
    """Create a frozen base weight plus trainable LoRA factors.

    The base weight is uniform in +/- sqrt(1/input_dims); ``lora_A`` is
    normal with std sqrt(2/input_dims); ``lora_B`` starts at zero so the
    adapter initially contributes nothing.
    """
    super().__init__()

    self.rank = rank
    self.scaling = alpha / rank

    # Random draws happen in this order: base weight, then lora_A.
    bound = math.sqrt(1 / input_dims)
    self.weight = mx.random.uniform(
        low=-bound, high=bound, shape=(output_dims, input_dims)
    )
    if bias:
        self.bias = mx.zeros((output_dims,))

    kaiming_std = math.sqrt(2 / input_dims)
    self.lora_A = mx.random.normal((input_dims, rank)) * kaiming_std
    self.lora_B = mx.zeros((rank, output_dims))

    # Only lora_A / lora_B remain trainable on this module.
    self.freeze(keys=["weight", "bias"], recurse=False)

__call__(x)

Source code in src/lmxlab/core/lora.py
def __call__(self, x: mx.array) -> mx.array:
    """Base affine output plus the scaled low-rank update x @ A @ B."""
    out = x @ self.weight.T
    if "bias" in self:
        out = out + self.bias
    low_rank = (x @ self.lora_A @ self.lora_B) * self.scaling
    return out + low_rank

from_linear(linear, rank=8, alpha=1.0) classmethod

Create a LoRALinear from an existing nn.Linear.

Copies the base weight and bias, then adds LoRA matrices.

Parameters:

Name Type Description Default
linear Linear

Base linear layer to wrap.

required
rank int

LoRA rank.

8
alpha float

LoRA scaling factor.

1.0

Returns:

Type Description
LoRALinear

LoRALinear with the same base weights.

Source code in src/lmxlab/core/lora.py
@classmethod
def from_linear(
    cls,
    linear: nn.Linear,
    rank: int = 8,
    alpha: float = 1.0,
) -> "LoRALinear":
    """Wrap an existing ``nn.Linear`` with fresh LoRA factors.

    The base weight (and bias, if present) are taken from ``linear``
    and frozen; only the new LoRA matrices are trainable.

    Args:
        linear: Base linear layer to wrap.
        rank: LoRA rank.
        alpha: LoRA scaling factor.

    Returns:
        LoRALinear with the same base weights.
    """
    out_dims, in_dims = linear.weight.shape
    with_bias = "bias" in linear

    adapter = cls(in_dims, out_dims, rank, alpha, bias=with_bias)
    adapter.weight = linear.weight
    if with_bias:
        adapter.bias = linear.bias
    # The copied arrays replace the originals, so freeze them again.
    adapter.freeze(keys=["weight", "bias"], recurse=False)
    return adapter

to_linear()

Merge LoRA weights and return a plain nn.Linear.

Computes W_merged = W + scaling * (A @ B)^T and returns a new nn.Linear with the merged weight.

Source code in src/lmxlab/core/lora.py
def to_linear(self) -> nn.Linear:
    """Merge LoRA weights and return a plain nn.Linear.

    Computes W_merged = W + scaling * (A @ B)^T and returns
    a new nn.Linear with the merged weight. The LoRA delta is cast to
    the base weight's dtype, so merging does not change the layer's
    dtype (e.g. a float16 base stays float16 even though the LoRA
    matrices are created in float32).
    """
    has_bias = "bias" in self
    delta = (self.lora_A @ self.lora_B).T * self.scaling
    merged_weight = self.weight + delta.astype(self.weight.dtype)
    linear = nn.Linear(
        self.weight.shape[1],
        self.weight.shape[0],
        bias=has_bias,
    )
    linear.weight = merged_weight
    if has_bias:
        linear.bias = self.bias
    return linear

lmxlab.core.lora.apply_lora(model, rank=8, alpha=1.0, targets=None)

Apply LoRA to a model's linear layers in-place.

Replaces targeted nn.Linear layers with LoRALinear, freezing the base weights and making only the LoRA matrices trainable.

Parameters:

Name Type Description Default
model Module

Model to modify (in-place).

required
rank int

LoRA rank for all adapted layers.

8
alpha float

LoRA scaling factor.

1.0
targets list[str] | None

Which submodules to target. Options: 'attention' (q/k/v/o projections), 'ffn' (gate/up/down projections). Default: ['attention'].

None
Source code in src/lmxlab/core/lora.py
def apply_lora(
    model: nn.Module,
    rank: int = 8,
    alpha: float = 1.0,
    targets: list[str] | None = None,
) -> None:
    """Swap targeted ``nn.Linear`` layers for ``LoRALinear`` in place.

    A leaf linear layer is adapted when any entry of ``targets`` occurs
    as a substring of its module path. Afterwards the whole model is
    frozen and only ``lora_A`` / ``lora_B`` are left trainable.

    Args:
        model: Model to modify (in-place).
        rank: LoRA rank for all adapted layers.
        alpha: LoRA scaling factor.
        targets: Which submodules to target. Options:
            ``'attention'`` (q/k/v/o projections),
            ``'ffn'`` (gate/up/down projections).
            Default: ``['attention']``.
    """
    target_names = ["attention"] if targets is None else targets

    def _wrap_if_targeted(path: str, module: nn.Module) -> nn.Module:
        is_target = isinstance(module, nn.Linear) and any(
            name in path for name in target_names
        )
        return LoRALinear.from_linear(module, rank, alpha) if is_target else module

    model.update_modules(
        tree_map_with_path(
            _wrap_if_targeted,
            model.leaf_modules(),
            is_leaf=nn.Module.is_module,
        )
    )

    model.freeze()
    model.unfreeze(keys=["lora_A", "lora_B"])

    # Freeze everything, then unfreeze only LoRA params
    model.freeze()
    model.unfreeze(keys=["lora_A", "lora_B"])

lmxlab.core.lora.merge_lora(model)

Merge LoRA weights into base weights in-place.

Replaces all LoRALinear layers with plain nn.Linear layers whose weights include the LoRA contribution. After merging, the model has the same output but no LoRA overhead.

Parameters:

Name Type Description Default
model Module

Model to merge (in-place).

required
Source code in src/lmxlab/core/lora.py
def merge_lora(model: nn.Module) -> None:
    """Fold every LoRA adapter back into its base layer, in place.

    Each ``LoRALinear`` leaf is swapped for the layer returned by its
    ``to_linear()`` method (documented as a plain ``nn.Linear`` that
    includes the LoRA contribution). All other leaves are left as-is,
    so the merged model carries no LoRA overhead.

    Args:
        model: Model to merge (modified in place).
    """

    def _merge_leaf(_path: str, leaf: nn.Module) -> nn.Module:
        return leaf.to_linear() if isinstance(leaf, LoRALinear) else leaf

    merged_leaves = tree_map_with_path(
        _merge_leaf,
        model.leaf_modules(),
        is_leaf=nn.Module.is_module,
    )
    model.update_modules(merged_leaves)

lmxlab.core.lora.lora_parameters(model)

Extract only LoRA parameters from a model.

Useful for saving/loading just the LoRA adapter weights, which are much smaller than the full model.

Parameters:

Name Type Description Default
model Module

Model with LoRA layers.

required

Returns:

Type Description
dict

Nested dict of only lora_A and lora_B parameters.

Source code in src/lmxlab/core/lora.py
def lora_parameters(model: nn.Module) -> dict:
    """Extract only LoRA parameters from a model.

    Useful for saving/loading just the LoRA adapter weights,
    which are much smaller than the full model.

    Parameters are selected by name (``lora_A`` / ``lora_B``), not
    by trainable state. The result therefore does not depend on which
    other parts of the model happen to be frozen or unfrozen, e.g.
    after unfreezing norms or after ``merge_lora``.

    Args:
        model: Model with LoRA layers.

    Returns:
        Nested dict of only lora_A and lora_B parameters (an empty
        dict if the model has none).
    """
    lora_leaves = [
        (key, value)
        for key, value in mlx.utils.tree_flatten(model.parameters())
        if key.rsplit(".", 1)[-1] in ("lora_A", "lora_B")
    ]
    if not lora_leaves:
        # Nothing to rebuild; return an empty tree directly.
        return {}
    return mlx.utils.tree_unflatten(lora_leaves)

lmxlab.core.lora.save_lora_adapters(path, model, rank=None, alpha=None, metadata=None)

Save only the LoRA adapter weights to a directory.

Saves lora_A and lora_B parameters as a small safetensors file, much smaller than a full model checkpoint. The adapter can be loaded on top of any compatible base model.

Parameters:

Name Type Description Default
path str | Path

Directory to save the adapter.

required
model Module

Model with LoRA layers applied.

required
rank int | None

LoRA rank (saved in config for reference).

None
alpha float | None

LoRA alpha (saved in config for reference).

None
metadata dict[str, Any] | None

Additional metadata to include in config.

None
Source code in src/lmxlab/core/lora.py
def save_lora_adapters(
    path: str | Path,
    model: nn.Module,
    rank: int | None = None,
    alpha: float | None = None,
    metadata: dict[str, Any] | None = None,
) -> None:
    """Write a model's LoRA adapter into a directory.

    Two files are produced inside ``path``, which is created (with
    parents) if needed:

    * ``adapter.safetensors``: every parameter whose flattened name
      contains ``lora_A`` or ``lora_B``.
    * ``adapter_config.json``: ``rank`` and ``alpha`` (only those
      that are not None), then the entries of ``metadata``.

    Args:
        path: Directory to save the adapter.
        model: Model with LoRA layers applied.
        rank: LoRA rank (saved in config for reference).
        alpha: LoRA alpha (saved in config for reference).
        metadata: Additional metadata to include in config. It is
            applied last, so a ``"rank"`` or ``"alpha"`` key here
            overrides the explicit arguments.
    """
    out_dir = Path(path)
    out_dir.mkdir(parents=True, exist_ok=True)

    adapter_tensors = {
        name: tensor
        for name, tensor in mlx.utils.tree_flatten(model.parameters())
        if "lora_A" in name or "lora_B" in name
    }
    mx.save_safetensors(
        str(out_dir / "adapter.safetensors"), adapter_tensors
    )

    config: dict[str, Any] = {
        field: setting
        for field, setting in (("rank", rank), ("alpha", alpha))
        if setting is not None
    }
    if metadata:
        config.update(metadata)

    config_text = json.dumps(config, indent=2)
    (out_dir / "adapter_config.json").write_text(config_text)

lmxlab.core.lora.load_lora_adapters(path, model)

Load LoRA adapter weights into a model.

The model must already have LoRA layers applied (via apply_lora). This function loads only the lora_A and lora_B weights from a previously saved adapter.

Parameters:

Name Type Description Default
path str | Path

Directory containing the adapter files.

required
model Module

Model with LoRA layers to load weights into.

required

Returns:

Type Description
dict[str, Any]

Metadata dict from adapter_config.json.

Raises:

Type Description
FileNotFoundError

If path does not exist.

Source code in src/lmxlab/core/lora.py
def load_lora_adapters(
    path: str | Path,
    model: nn.Module,
) -> dict[str, Any]:
    """Load LoRA adapter weights into a model.

    The model must already have LoRA layers applied (via
    ``apply_lora``). This function loads only the lora_A and
    lora_B weights from a previously saved adapter.

    Args:
        path: Directory containing the adapter files.
        model: Model with LoRA layers to load weights into.

    Returns:
        Metadata dict from adapter_config.json (empty if that file
        is absent).

    Raises:
        FileNotFoundError: If path does not exist, or if it does not
            contain an ``adapter.safetensors`` file.
    """
    path = Path(path)
    if not path.exists():
        raise FileNotFoundError(f"Adapter directory not found: {path}")

    weights_path = path / "adapter.safetensors"
    # Check explicitly so a missing weights file raises the documented
    # FileNotFoundError rather than relying on mx.load's own error.
    if not weights_path.is_file():
        raise FileNotFoundError(
            f"Adapter weights not found: {weights_path}"
        )

    weights = mx.load(str(weights_path))
    # strict=False because the adapter holds only lora_A / lora_B
    # entries, not the full parameter set. NOTE(review): with
    # strict=False, keys the model lacks may be ignored silently;
    # confirm against MLX's load_weights.
    model.load_weights(list(weights.items()), strict=False)

    config_path = path / "adapter_config.json"
    if config_path.exists():
        return json.loads(config_path.read_text())
    return {}

QLoRA

lmxlab.core.qlora.LoRAQuantizedLinear

Bases: Module

Quantized linear layer with low-rank adaptation.

Computes: y = quantized_matmul(x, W_q) + bias + scaling * x @ A @ B

where W_q is a frozen quantized weight and A, B are trainable float LoRA matrices. B is zero-initialized so the initial output equals the quantized layer's output.

Parameters:

Name Type Description Default
input_dims int

Input feature dimension.

required
output_dims int

Output feature dimension.

required
rank int

LoRA rank (low-rank dimension).

8
alpha float

LoRA scaling factor. Effective scaling = alpha/rank.

1.0
bias bool

Whether the layer has bias.

False
group_size int

Quantization group size.

64
bits int

Quantization bits.

4
mode str

Quantization mode.

'affine'
Source code in src/lmxlab/core/qlora.py
class LoRAQuantizedLinear(nn.Module):
    """Quantized linear layer with low-rank adaptation.

    Computes: y = quantized_matmul(x, W_q) + bias + scaling * x @ A @ B

    where W_q is a frozen quantized weight and A, B are trainable
    float LoRA matrices. B is zero-initialized so the initial output
    equals the quantized layer's output. The LoRA term is cast to the
    dtype of the quantized base output, so wrapping a layer does not
    change the dtype of the activations it produces.

    Args:
        input_dims: Input feature dimension.
        output_dims: Output feature dimension.
        rank: LoRA rank (low-rank dimension).
        alpha: LoRA scaling factor. Effective scaling = alpha/rank.
        bias: Whether the layer has bias.
        group_size: Quantization group size.
        bits: Quantization bits.
        mode: Quantization mode.
    """

    def __init__(
        self,
        input_dims: int,
        output_dims: int,
        rank: int = 8,
        alpha: float = 1.0,
        bias: bool = False,
        group_size: int = 64,
        bits: int = 4,
        mode: str = "affine",
    ) -> None:
        super().__init__()

        self.rank = rank
        self.scaling = alpha / rank
        self.group_size = group_size
        self.bits = bits
        self.mode = mode

        # Quantized base weight (frozen) — placeholder; use from_quantized
        scale = math.sqrt(1 / input_dims)
        weight = mx.random.uniform(
            low=-scale,
            high=scale,
            shape=(output_dims, input_dims),
        )
        # mx.quantize may return a trailing "biases" array; keep it if
        # present, otherwise store None.
        self.weight, self.scales, *biases = mx.quantize(
            weight, group_size, bits, mode=mode
        )
        self.biases = biases[0] if biases else None

        if bias:
            self.bias = mx.zeros((output_dims,))

        # LoRA matrices (trainable, float); lora_B starts at zero so the
        # LoRA term contributes nothing at initialisation.
        self.lora_A = mx.random.normal((input_dims, rank)) * math.sqrt(
            2 / input_dims
        )
        self.lora_B = mx.zeros((rank, output_dims))

        # Freeze everything except LoRA
        self.freeze(
            keys=["weight", "scales", "biases", "bias"],
            recurse=False,
        )

    def __call__(self, x: mx.array) -> mx.array:
        """Run the quantized projection and add the scaled LoRA update.

        Args:
            x: Input activations whose last axis has size
                ``input_dims`` (required by ``x @ lora_A``).

        Returns:
            Base projection (plus bias, if present) plus
            ``scaling * x @ lora_A @ lora_B``, in the base output's dtype.
        """
        # Quantized base: uses efficient quantized matmul
        y = mx.quantized_matmul(
            x,
            self["weight"],
            scales=self["scales"],
            biases=self.get("biases"),
            transpose=True,
            group_size=self.group_size,
            bits=self.bits,
            mode=self.mode,
        )
        if "bias" in self:
            y = y + self["bias"]
        # LoRA: full-precision low-rank update. lora_A / lora_B are built
        # with mx.random.normal / mx.zeros (float32 by default in MLX).
        # Without the cast, adding their product would promote
        # half-precision activations to float32 for all downstream layers.
        lora_update = (x @ self.lora_A @ self.lora_B) * self.scaling
        return y + lora_update.astype(y.dtype)

    @classmethod
    def from_quantized(
        cls,
        ql: nn.QuantizedLinear,
        rank: int = 8,
        alpha: float = 1.0,
    ) -> "LoRAQuantizedLinear":
        """Create from an existing QuantizedLinear layer.

        Copies the quantized weights and adds LoRA adapters.

        Args:
            ql: Quantized linear layer to wrap.
            rank: LoRA rank.
            alpha: LoRA scaling factor.

        Returns:
            LoRAQuantizedLinear with same quantized base weights.
        """
        # Infer dimensions from quantized weight. The * 32 factor assumes
        # values are packed into 32-bit words, 32 // bits per word.
        out_dims = ql.weight.shape[0]
        in_dims = (ql.weight.shape[1] * 32) // ql.bits
        has_bias = "bias" in ql

        lora_ql = cls(
            in_dims,
            out_dims,
            rank,
            alpha,
            bias=has_bias,
            group_size=ql.group_size,
            bits=ql.bits,
            mode=ql.mode,
        )
        # Copy quantized state (replaces the random placeholder weights)
        lora_ql.weight = ql.weight
        lora_ql.scales = ql.scales
        if ql.get("biases") is not None:
            lora_ql.biases = ql["biases"]
        if has_bias:
            lora_ql.bias = ql.bias

        # Re-freeze quantized params
        lora_ql.freeze(
            keys=["weight", "scales", "biases", "bias"],
            recurse=False,
        )
        return lora_ql

__init__(input_dims, output_dims, rank=8, alpha=1.0, bias=False, group_size=64, bits=4, mode='affine')

Source code in src/lmxlab/core/qlora.py
def __init__(
    self,
    input_dims: int,
    output_dims: int,
    rank: int = 8,
    alpha: float = 1.0,
    bias: bool = False,
    group_size: int = 64,
    bits: int = 4,
    mode: str = "affine",
) -> None:
    """Build a LoRA-wrapped quantized linear layer.

    The base weight is a randomly initialised placeholder that is
    quantized immediately; ``from_quantized`` overwrites it with real
    weights. The random draws for the placeholder weight and for
    ``lora_A`` happen in this order, so reordering them would change
    the values produced for a given random seed.

    Args:
        input_dims: Input feature dimension.
        output_dims: Output feature dimension.
        rank: LoRA rank (low-rank dimension).
        alpha: LoRA scaling factor; the effective scale is
            ``alpha / rank``.
        bias: If True, add a zero-initialised output bias.
        group_size: Quantization group size passed to ``mx.quantize``.
        bits: Quantization bits passed to ``mx.quantize``.
        mode: Quantization mode passed to ``mx.quantize``.
    """
    super().__init__()

    self.rank = rank
    # Effective scale applied to the x @ lora_A @ lora_B update.
    self.scaling = alpha / rank
    self.group_size = group_size
    self.bits = bits
    self.mode = mode

    # Quantized base weight (frozen) — placeholder; use from_quantized
    scale = math.sqrt(1 / input_dims)
    weight = mx.random.uniform(
        low=-scale,
        high=scale,
        shape=(output_dims, input_dims),
    )
    # mx.quantize may return a trailing "biases" array; keep it if
    # present, otherwise store None.
    self.weight, self.scales, *biases = mx.quantize(
        weight, group_size, bits, mode=mode
    )
    self.biases = biases[0] if biases else None

    if bias:
        self.bias = mx.zeros((output_dims,))

    # LoRA matrices (trainable, float); lora_B starts at zero so the
    # LoRA term contributes nothing at initialisation.
    self.lora_A = mx.random.normal((input_dims, rank)) * math.sqrt(
        2 / input_dims
    )
    self.lora_B = mx.zeros((rank, output_dims))

    # Freeze everything except LoRA
    self.freeze(
        keys=["weight", "scales", "biases", "bias"],
        recurse=False,
    )

__call__(x)

Source code in src/lmxlab/core/qlora.py
def __call__(self, x: mx.array) -> mx.array:
    """Run the quantized projection and add the scaled LoRA update.

    Args:
        x: Input activations whose last axis has size ``input_dims``
            (required by ``x @ lora_A``).

    Returns:
        Base projection (plus bias, if present) plus
        ``scaling * x @ lora_A @ lora_B``, in the base output's dtype.
    """
    # Quantized base: uses efficient quantized matmul
    y = mx.quantized_matmul(
        x,
        self["weight"],
        scales=self["scales"],
        biases=self.get("biases"),
        transpose=True,
        group_size=self.group_size,
        bits=self.bits,
        mode=self.mode,
    )
    if "bias" in self:
        y = y + self["bias"]
    # LoRA: full-precision low-rank update. lora_A / lora_B are built
    # with mx.random.normal / mx.zeros (float32 by default in MLX).
    # Without the cast, adding their product would promote
    # half-precision activations to float32 for all downstream layers.
    lora_update = (x @ self.lora_A @ self.lora_B) * self.scaling
    return y + lora_update.astype(y.dtype)

from_quantized(ql, rank=8, alpha=1.0) classmethod

Create from an existing QuantizedLinear layer.

Copies the quantized weights and adds LoRA adapters.

Parameters:

Name Type Description Default
ql QuantizedLinear

Quantized linear layer to wrap.

required
rank int

LoRA rank.

8
alpha float

LoRA scaling factor.

1.0

Returns:

Type Description
LoRAQuantizedLinear

LoRAQuantizedLinear with same quantized base weights.

Source code in src/lmxlab/core/qlora.py
@classmethod
def from_quantized(
    cls,
    ql: nn.QuantizedLinear,
    rank: int = 8,
    alpha: float = 1.0,
) -> "LoRAQuantizedLinear":
    """Wrap an existing ``QuantizedLinear`` with LoRA adapters.

    The packed weight, scales, optional quantization biases and
    optional output bias are shared from ``ql``. Fresh LoRA matrices
    come from the constructor.

    Args:
        ql: Quantized linear layer to wrap.
        rank: LoRA rank.
        alpha: LoRA scaling factor.

    Returns:
        A LoRAQuantizedLinear whose frozen base matches ``ql``.
    """
    packed = ql.weight
    # The * 32 factor assumes values are packed into 32-bit words,
    # 32 // bits values per word.
    n_out = packed.shape[0]
    n_in = packed.shape[1] * 32 // ql.bits
    with_bias = "bias" in ql

    wrapped = cls(
        n_in,
        n_out,
        rank,
        alpha,
        bias=with_bias,
        group_size=ql.group_size,
        bits=ql.bits,
        mode=ql.mode,
    )

    # Overwrite the constructor's random placeholder with ql's state.
    wrapped.weight = packed
    wrapped.scales = ql.scales
    zero_points = ql.get("biases")
    if zero_points is not None:
        wrapped.biases = zero_points
    if with_bias:
        wrapped.bias = ql.bias

    # The assignments above replace frozen entries; freeze them again.
    wrapped.freeze(
        keys=["weight", "scales", "biases", "bias"],
        recurse=False,
    )
    return wrapped

lmxlab.core.qlora.apply_qlora(model, rank=8, alpha=1.0, targets=None)

Apply QLoRA to a quantized model's layers in-place.

Replaces targeted nn.QuantizedLinear layers with LoRAQuantizedLinear, keeping base weights quantized and adding trainable float LoRA matrices.

The model should already be quantized (via quantize_model or nn.quantize) before calling this.

Parameters:

Name Type Description Default
model Module

Quantized model to modify (in-place).

required
rank int

LoRA rank for all adapted layers.

8
alpha float

LoRA scaling factor.

1.0
targets list[str] | None

Which submodules to target. Options: 'attention' (q/k/v/o projections), 'ffn' (gate/up/down projections). Default: ['attention'].

None
Source code in src/lmxlab/core/qlora.py
def apply_qlora(
    model: nn.Module,
    rank: int = 8,
    alpha: float = 1.0,
    targets: list[str] | str | None = None,
) -> None:
    """Apply QLoRA to a quantized model's layers in-place.

    Replaces targeted ``nn.QuantizedLinear`` layers with
    ``LoRAQuantizedLinear``, keeping base weights quantized and
    adding trainable float LoRA matrices.

    The model should already be quantized (via ``quantize_model``
    or ``nn.quantize``) before calling this.

    Args:
        model: Quantized model to modify (in-place).
        rank: LoRA rank for all adapted layers.
        alpha: LoRA scaling factor.
        targets: Which submodules to target. Options:
            ``'attention'`` (q/k/v/o projections),
            ``'ffn'`` (gate/up/down projections).
            Default: ``['attention']``. A target selects every
            ``nn.QuantizedLinear`` whose module path (as produced by
            ``tree_map_with_path``) contains it as a substring.
            A single string is treated as a one-element list.
    """
    if targets is None:
        targets = ["attention"]
    elif isinstance(targets, str):
        # A bare string would otherwise be iterated character by
        # character, and single letters match almost every path.
        targets = [targets]

    def _maybe_qlora(path: str, m: nn.Module) -> nn.Module:
        # Only quantized linear layers on a targeted path are wrapped;
        # every other leaf is returned unchanged.
        if isinstance(m, nn.QuantizedLinear) and any(
            t in path for t in targets
        ):
            return LoRAQuantizedLinear.from_quantized(m, rank, alpha)
        return m

    leaves = model.leaf_modules()
    leaves = tree_map_with_path(
        _maybe_qlora, leaves, is_leaf=nn.Module.is_module
    )
    model.update_modules(leaves)

    # Freeze everything, then unfreeze only LoRA params so that
    # lora_A / lora_B are the only trainable parameters.
    model.freeze()
    model.unfreeze(keys=["lora_A", "lora_B"])

Registry

lmxlab.core.registry.Registry

A typed registry mapping string names to factory functions.

Parameters:

Name Type Description Default
name str

Human-readable name for this registry.

required
Example

>>> reg = Registry[type]('attention')
>>> @reg.register('mha')
... class MHA: ...
>>> reg.get('mha')
<class 'MHA'>

Source code in src/lmxlab/core/registry.py
class Registry[T]:
    """A typed registry mapping string names to factory functions.

    Args:
        name: Human-readable name for this registry.

    Example:
        >>> reg = Registry[type]('attention')
        >>> @reg.register('mha')
        ... class MHA: ...
        >>> reg.get('mha')
        <class 'MHA'>
    """

    def __init__(self, name: str) -> None:
        self._name = name
        self._entries: dict[str, T] = {}

    @property
    def name(self) -> str:
        """Registry name."""
        return self._name

    @overload
    def register(self, key: str) -> Callable[[T], T]: ...

    @overload
    def register(self, key: str, value: T) -> T: ...

    def register(
        self, key: str, value: T | None = None
    ) -> T | Callable[[T], T]:
        """Register a component under a key.

        Can be used as a decorator factory or called directly:

            @registry.register('name')
            class MyClass: ...

            registry.register('name', MyClass)

        Args:
            key: String name for lookup.
            value: The component to register. If None, returns
                a decorator that registers the decorated object.

        Returns:
            The registered value, or a decorator if value is None.

        Raises:
            ValueError: If key is already registered.
        """
        if value is not None:
            if key in self._entries:
                raise ValueError(
                    f"{self._name} registry already has key {key!r}"
                )
            self._entries[key] = value
            return value

        def decorator(val: T) -> T:
            if key in self._entries:
                raise ValueError(
                    f"{self._name} registry already has key {key!r}"
                )
            self._entries[key] = val
            return val

        return decorator

    def get(self, key: str) -> T:
        """Look up a component by key.

        Args:
            key: Registered name.

        Returns:
            The registered component.

        Raises:
            KeyError: If key is not found.
        """
        if key not in self._entries:
            available = ", ".join(sorted(self._entries))
            raise KeyError(
                f"{self._name} registry has no key {key!r}. "
                f"Available: [{available}]"
            )
        return self._entries[key]

    def keys(self) -> list[str]:
        """List all registered keys."""
        return sorted(self._entries)

    def __contains__(self, key: str) -> bool:
        return key in self._entries

    def __repr__(self) -> str:
        keys = ", ".join(sorted(self._entries))
        return f"Registry({self._name!r}, keys=[{keys}])"

_entries = {} instance-attribute

_name = name instance-attribute

name property

Registry name.

__contains__(key)

Source code in src/lmxlab/core/registry.py
def __contains__(self, key: str) -> bool:
    """Report whether ``key`` has been registered."""
    registered = self._entries.keys()
    return key in registered

__init__(name)

Source code in src/lmxlab/core/registry.py
def __init__(self, name: str) -> None:
    """Create an empty registry labelled ``name``."""
    entries: dict[str, T] = {}
    self._name, self._entries = name, entries

__repr__()

Source code in src/lmxlab/core/registry.py
def __repr__(self) -> str:
    """Show the registry name together with its sorted keys."""
    listing = ", ".join(sorted(self._entries))
    return f"Registry({self._name!r}, keys=[{listing}])"

get(key)

Look up a component by key.

Parameters:

Name Type Description Default
key str

Registered name.

required

Returns:

Type Description
T

The registered component.

Raises:

Type Description
KeyError

If key is not found.

Source code in src/lmxlab/core/registry.py
def get(self, key: str) -> T:
    """Return the component stored under ``key``.

    Args:
        key: Registered name.

    Returns:
        The registered component.

    Raises:
        KeyError: If ``key`` is unknown; the message lists the
            available keys in sorted order.
    """
    if key in self._entries:
        return self._entries[key]
    known = ", ".join(sorted(self._entries))
    raise KeyError(
        f"{self._name} registry has no key {key!r}. "
        f"Available: [{known}]"
    )

keys()

List all registered keys.

Source code in src/lmxlab/core/registry.py
def keys(self) -> list[str]:
    """Return every registered key, sorted ascending."""
    names = list(self._entries.keys())
    names.sort()
    return names

register(key, value=None)

register(key: str) -> Callable[[T], T]
register(key: str, value: T) -> T

Register a component under a key.

Can be used as a decorator factory or called directly:

@registry.register('name')
class MyClass: ...

registry.register('name', MyClass)

Parameters:

Name Type Description Default
key str

String name for lookup.

required
value T | None

The component to register. If None, returns a decorator that registers the decorated object.

None

Returns:

Type Description
T | Callable[[T], T]

The registered value, or a decorator if value is None.

Raises:

Type Description
ValueError

If key is already registered.

Source code in src/lmxlab/core/registry.py
def register(
    self, key: str, value: T | None = None
) -> T | Callable[[T], T]:
    """Register a component under ``key``.

    Works either as a decorator factory or as a direct call::

        @registry.register('name')
        class MyClass: ...

        registry.register('name', MyClass)

    Args:
        key: String name for lookup.
        value: Component to store. When None, a decorator is returned
            instead, and the object it decorates is stored.

    Returns:
        ``value`` itself, or the decorator when ``value`` is None.

    Raises:
        ValueError: If ``key`` is already registered. In the decorator
            form this is raised at decoration time.
    """

    def _store(obj: T) -> T:
        # Refuse to overwrite, so two components never share a key.
        if key in self._entries:
            raise ValueError(
                f"{self._name} registry already has key {key!r}"
            )
        self._entries[key] = obj
        return obj

    return _store if value is None else _store(value)