Skip to content

Models

Language model base class and architecture config factories.

LanguageModel

lmxlab.models.base.LanguageModel

Bases: Module

Transformer language model assembled from config.

Uses ConfigurableBlock for each layer. Supports tied input/output embeddings and KV caching for generation.

Parameters:

Name Type Description Default
config ModelConfig

Full model configuration.

required
Source code in src/lmxlab/models/base.py
class LanguageModel(nn.Module):
    """Transformer language model assembled from config.

    Uses ConfigurableBlock for each layer. Supports tied
    input/output embeddings and KV caching for generation.

    Args:
        config: Full model configuration.
    """

    def __init__(self, config: ModelConfig) -> None:
        super().__init__()
        self.config = config
        block_cfg = config.block

        # Token embedding
        self.embed = nn.Embedding(config.vocab_size, block_cfg.d_model)

        # Embedding dropout
        self.embed_dropout = nn.Dropout(p=block_cfg.dropout)

        # Transformer blocks
        # NOTE(review): stored as a plain Python list; parameters() and
        # tree_flatten below rely on MLX tracking module lists — confirm
        # against mlx.nn.Module container semantics.
        self.blocks = [
            ConfigurableBlock(config.get_block_config(i))
            for i in range(config.n_layers)
        ]

        # Sinusoidal PE applied once at model level
        self._sinusoidal = block_cfg.position == "sinusoidal"

        # Final norm
        final_norm_cls = norm_registry.get(block_cfg.norm)
        self.final_norm = final_norm_cls(block_cfg)

        # Output head (possibly tied with embedding)
        # When tied, no separate head module is created; __call__ reuses
        # embed.weight instead.
        if not config.tie_embeddings:
            self.head = nn.Linear(
                block_cfg.d_model, config.vocab_size, bias=False
            )

        # Apply μP weight initialization scaling
        if config.mup_base_width is not None:
            self._apply_mup_init(config.width_mult)

    def __call__(
        self,
        x: mx.array,
        cache: list | None = None,
        return_hidden: bool = False,
    ) -> tuple[mx.array, list] | tuple[mx.array, list, mx.array]:
        """Forward pass.

        Args:
            x: Token IDs of shape (batch, seq_len).
            cache: Optional list of caches per layer. Cache
                types may be heterogeneous in hybrid models
                (KV tuples for attention, SSM state tuples
                for Mamba, None for identity layers).
            return_hidden: If True, also return hidden states
                from final_norm (before lm_head projection).
                Used by Multi-Token Prediction.

        Returns:
            Tuple of (logits, updated_caches) by default.
            If return_hidden is True, returns
            (logits, updated_caches, hidden_states).
        """
        h = self.embed_dropout(self.embed(x))

        # Sinusoidal position encoding (at model level)
        # Reuses the first block's position module so the encoding is
        # applied exactly once for the whole stack.
        if self._sinusoidal:
            h = self.blocks[0].position(h)

        # Create causal mask
        T = h.shape[1]
        cache_len = 0
        if cache is not None:
            # Find cache_len from first attention-style KV cache.
            # KV caches are (K, V) where both are 4D arrays
            # (B, heads, seq, head_dim) with matching seq dim.
            # SSM caches differ: (ssm_state_4D, conv_state_3D).
            for layer_cache in cache:
                if (
                    layer_cache is not None
                    and isinstance(layer_cache, tuple)
                    and len(layer_cache) == 2
                    and isinstance(layer_cache[0], mx.array)
                    and isinstance(layer_cache[1], mx.array)
                    and layer_cache[0].ndim == 4
                    and layer_cache[1].ndim == 4
                ):
                    cache_len = layer_cache[0].shape[2]
                    break
        mask = _create_causal_mask(T, cache_len)

        new_caches: list = []
        for i, block in enumerate(self.blocks):
            layer_cache = cache[i] if cache is not None else None
            h, new_cache = block(h, mask=mask, cache=layer_cache)
            new_caches.append(new_cache)

        h = self.final_norm(h)

        # Output projection
        if self.config.tie_embeddings:
            # Weight tying: project with the transposed embedding matrix.
            logits = h @ self.embed.weight.T
        else:
            logits = self.head(h)

        # μP: scale logits by 1/width_mult
        if self.config.mup_base_width is not None:
            logits = logits / self.config.width_mult

        if return_hidden:
            return logits, new_caches, h

        return logits, new_caches

    def _apply_mup_init(self, width_mult: float) -> None:
        """Rescale hidden layer weights for μP.

        Scales hidden layer weight init by 1/√width_mult.
        Embedding weights are left unchanged (μP prescribes
        constant embedding init across widths).

        Args:
            width_mult: d_model / base_d_model ratio.
        """
        # Base width: nothing to rescale.
        if width_mult == 1.0:
            return

        scale = 1.0 / math.sqrt(width_mult)

        # Rescale all block (hidden layer) weights
        for block in self.blocks:
            flat = mlx.utils.tree_flatten(block.parameters())
            updates = [
                (k, v * scale)
                for k, v in flat
                if v.ndim >= 2  # only weight matrices, not biases
            ]
            if updates:
                block.load_weights(updates, strict=False)

        # Rescale output head if untied
        # (direct reassignment: the head is a single Linear weight)
        if not self.config.tie_embeddings and hasattr(self, "head"):
            self.head.weight = self.head.weight * scale

    def count_parameters(self) -> int:
        """Count total trainable parameters."""
        leaves = mlx.utils.tree_flatten(self.parameters())
        return sum(p.size for _, p in leaves)

_sinusoidal = block_cfg.position == 'sinusoidal' instance-attribute

blocks = [(ConfigurableBlock(config.get_block_config(i))) for i in (range(config.n_layers))] instance-attribute

config = config instance-attribute

embed = nn.Embedding(config.vocab_size, block_cfg.d_model) instance-attribute

embed_dropout = nn.Dropout(p=(block_cfg.dropout)) instance-attribute

final_norm = final_norm_cls(block_cfg) instance-attribute

head = nn.Linear(block_cfg.d_model, config.vocab_size, bias=False) instance-attribute

__call__(x, cache=None, return_hidden=False)

Forward pass.

Parameters:

Name Type Description Default
x array

Token IDs of shape (batch, seq_len).

required
cache list | None

Optional list of caches per layer. Cache types may be heterogeneous in hybrid models (KV tuples for attention, SSM state tuples for Mamba, None for identity layers).

None
return_hidden bool

If True, also return hidden states from final_norm (before lm_head projection). Used by Multi-Token Prediction.

False

Returns:

Type Description
tuple[array, list] | tuple[array, list, array]

Tuple of (logits, updated_caches) by default.

tuple[array, list] | tuple[array, list, array]

If return_hidden is True, returns

tuple[array, list] | tuple[array, list, array]

(logits, updated_caches, hidden_states).

Source code in src/lmxlab/models/base.py
def __call__(
    self,
    x: mx.array,
    cache: list | None = None,
    return_hidden: bool = False,
) -> tuple[mx.array, list] | tuple[mx.array, list, mx.array]:
    """Run the model over a batch of token IDs.

    Args:
        x: Token IDs of shape (batch, seq_len).
        cache: Optional per-layer cache list. Entries may be
            heterogeneous in hybrid models: KV tuples for
            attention layers, SSM state tuples for Mamba
            layers, or None for identity layers.
        return_hidden: If True, also return the post-final_norm
            hidden states (before the output projection), as
            used by Multi-Token Prediction.

    Returns:
        (logits, updated_caches), or
        (logits, updated_caches, hidden_states) when
        return_hidden is True.
    """
    hidden = self.embed(x)
    hidden = self.embed_dropout(hidden)

    # Sinusoidal PE is applied once at the model level, via the
    # first block's position module.
    if self._sinusoidal:
        hidden = self.blocks[0].position(hidden)

    # Determine how many positions are already cached so the causal
    # mask is offset correctly. A KV cache entry is a (K, V) pair of
    # 4D arrays (B, heads, seq, head_dim); SSM entries pair a 4D
    # state with a 3D conv state, so they fail the ndim test and
    # are skipped, as are None entries.
    seq_len = hidden.shape[1]
    past_len = 0
    if cache is not None:
        for entry in cache:
            looks_like_kv = (
                entry is not None
                and isinstance(entry, tuple)
                and len(entry) == 2
                and isinstance(entry[0], mx.array)
                and isinstance(entry[1], mx.array)
                and entry[0].ndim == 4
                and entry[1].ndim == 4
            )
            if looks_like_kv:
                past_len = entry[0].shape[2]
                break
    mask = _create_causal_mask(seq_len, past_len)

    updated_caches: list = []
    for idx, layer in enumerate(self.blocks):
        prev_state = None if cache is None else cache[idx]
        hidden, state = layer(hidden, mask=mask, cache=prev_state)
        updated_caches.append(state)

    hidden = self.final_norm(hidden)

    # Output projection: tied models reuse the embedding matrix.
    if self.config.tie_embeddings:
        logits = hidden @ self.embed.weight.T
    else:
        logits = self.head(hidden)

    # μP logit scaling.
    if self.config.mup_base_width is not None:
        logits = logits / self.config.width_mult

    if return_hidden:
        return logits, updated_caches, hidden
    return logits, updated_caches

__init__(config)

Source code in src/lmxlab/models/base.py
def __init__(self, config: ModelConfig) -> None:
    """Build embedding, transformer blocks, final norm, and head.

    Args:
        config: Full model configuration.
    """
    super().__init__()
    self.config = config
    bc = config.block

    # Input embedding and its dropout.
    self.embed = nn.Embedding(config.vocab_size, bc.d_model)
    self.embed_dropout = nn.Dropout(p=bc.dropout)

    # One ConfigurableBlock per layer, each from its own config.
    self.blocks = [
        ConfigurableBlock(config.get_block_config(layer_idx))
        for layer_idx in range(config.n_layers)
    ]

    # Sinusoidal PE is added once at model level, not per block.
    self._sinusoidal = bc.position == "sinusoidal"

    # Final normalization, looked up from the norm registry.
    self.final_norm = norm_registry.get(bc.norm)(bc)

    # A separate output head exists only when embeddings are untied;
    # tied models project through the embedding matrix instead.
    if not config.tie_embeddings:
        self.head = nn.Linear(bc.d_model, config.vocab_size, bias=False)

    # μP: rescale hidden-layer init when a base width is configured.
    if config.mup_base_width is not None:
        self._apply_mup_init(config.width_mult)

_apply_mup_init(width_mult)

Rescale hidden layer weights for μP.

Scales hidden layer weight init by 1/√width_mult. Embedding weights are left unchanged (μP prescribes constant embedding init across widths).

Parameters:

Name Type Description Default
width_mult float

d_model / base_d_model ratio.

required
Source code in src/lmxlab/models/base.py
def _apply_mup_init(self, width_mult: float) -> None:
    """Rescale hidden-layer weight init for μP.

    Multiplies hidden-layer weight matrices by 1/√width_mult.
    Embedding weights are untouched, since μP keeps embedding
    init constant across widths.

    Args:
        width_mult: Ratio d_model / base_d_model.
    """
    # At base width there is nothing to rescale.
    if width_mult == 1.0:
        return

    scale = 1.0 / math.sqrt(width_mult)

    # Scale every weight matrix inside each block; vectors
    # (biases, norm gains) are left alone.
    for blk in self.blocks:
        rescaled = []
        for name, value in mlx.utils.tree_flatten(blk.parameters()):
            if value.ndim >= 2:
                rescaled.append((name, value * scale))
        if rescaled:
            blk.load_weights(rescaled, strict=False)

    # The untied output head is a single Linear; scale it directly.
    if not self.config.tie_embeddings and hasattr(self, "head"):
        self.head.weight = self.head.weight * scale

count_parameters()

Count total trainable parameters.

Source code in src/lmxlab/models/base.py
def count_parameters(self) -> int:
    """Return the total number of trainable parameters."""
    total = 0
    for _, arr in mlx.utils.tree_flatten(self.parameters()):
        total += arr.size
    return total

Generation

lmxlab.models.generate

Autoregressive text generation with sampling strategies.

generate(model, prompt, max_tokens=100, temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, stop_tokens=None)

Generate tokens autoregressively with KV caching.

Parameters:

Name Type Description Default
model LanguageModel

Language model to generate from.

required
prompt array

Input token IDs of shape (batch, prompt_len).

required
max_tokens int

Maximum number of new tokens to generate.

100
temperature float

Sampling temperature (0 = greedy).

1.0
top_k int

If > 0, only sample from top-k tokens.

0
top_p float

If < 1.0, use nucleus sampling.

1.0
repetition_penalty float

Penalty for repeating tokens (> 1.0 discourages repetition, 1.0 = no effect).

1.0
stop_tokens list[int] | None

List of token IDs that stop generation. When any batch element generates a stop token, generation stops for all.

None

Returns:

Type Description
array

Generated token IDs of shape

array

(batch, prompt_len + generated_len).

Source code in src/lmxlab/models/generate.py
def generate(
    model: LanguageModel,
    prompt: mx.array,
    max_tokens: int = 100,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    stop_tokens: list[int] | None = None,
) -> mx.array:
    """Generate tokens autoregressively with KV caching.

    Args:
        model: Language model to generate from.
        prompt: Input token IDs of shape (batch, prompt_len).
        max_tokens: Maximum number of new tokens to generate.
        temperature: Sampling temperature (0 = greedy).
        top_k: If > 0, only sample from top-k tokens.
        top_p: If < 1.0, use nucleus sampling.
        repetition_penalty: Penalty for repeating tokens (> 1.0
            discourages repetition, 1.0 = no effect).
        stop_tokens: List of token IDs that stop generation.
            When any batch element generates a stop token,
            generation stops for all.

    Returns:
        Generated token IDs of shape
        (batch, prompt_len + generated_len).
    """
    tokens = prompt
    cache = None
    stop_set = set(stop_tokens) if stop_tokens else set()

    # Process prompt (prefill).
    logits, cache = model(tokens, cache=cache)
    # Cache entries can be None (identity layers in hybrid models) or
    # tuples of arrays (KV / SSM state), so evaluate only array leaves
    # instead of assuming every entry is an unpackable pair.
    mx.eval(
        logits,
        *[
            a
            for entry in cache
            if isinstance(entry, tuple)
            for a in entry
            if isinstance(a, mx.array)
        ],
    )

    generated: list[mx.array] = []
    for _ in range(max_tokens):
        next_logits = logits[:, -1, :]

        if repetition_penalty != 1.0:
            next_logits = _apply_repetition_penalty(
                next_logits, generated, repetition_penalty
            )

        next_token = _sample_next(next_logits, temperature, top_k, top_p)
        mx.eval(next_token)

        # Stop when ANY batch element emits a stop token, per the
        # documented contract (previously only batch element 0 was
        # checked).
        if stop_set:
            token_vals = next_token.reshape(-1).tolist()
            if any(t in stop_set for t in token_vals):
                break

        generated.append(next_token)

        logits, cache = model(next_token, cache=cache)
        mx.eval(
            logits,
            *[
                a
                for entry in cache
                if isinstance(entry, tuple)
                for a in entry
                if isinstance(a, mx.array)
            ],
        )

    if generated:
        all_generated = mx.concatenate(generated, axis=1)
        return mx.concatenate([prompt, all_generated], axis=1)
    return prompt

stream_generate(model, prompt, max_tokens=100, temperature=1.0, top_k=0, top_p=1.0, repetition_penalty=1.0, stop_tokens=None)

Generate tokens one at a time, yielding each as produced.

This is the standard interface for interactive/streaming applications. Each token is yielded immediately after generation, enabling real-time display.

Parameters:

Name Type Description Default
model LanguageModel

Language model to generate from.

required
prompt array

Input token IDs of shape (1, prompt_len).

required
max_tokens int

Maximum number of new tokens.

100
temperature float

Sampling temperature (0 = greedy).

1.0
top_k int

If > 0, only sample from top-k.

0
top_p float

If < 1.0, use nucleus sampling.

1.0
repetition_penalty float

Penalty for repeating tokens.

1.0
stop_tokens list[int] | None

Token IDs that stop generation.

None

Yields:

Type Description
int

Generated token IDs one at a time.

Source code in src/lmxlab/models/generate.py
def stream_generate(
    model: LanguageModel,
    prompt: mx.array,
    max_tokens: int = 100,
    temperature: float = 1.0,
    top_k: int = 0,
    top_p: float = 1.0,
    repetition_penalty: float = 1.0,
    stop_tokens: list[int] | None = None,
) -> Iterator[int]:
    """Generate tokens one at a time, yielding each as produced.

    This is the standard interface for interactive/streaming
    applications. Each token is yielded immediately after
    generation, enabling real-time display.

    Args:
        model: Language model to generate from.
        prompt: Input token IDs of shape (1, prompt_len).
        max_tokens: Maximum number of new tokens.
        temperature: Sampling temperature (0 = greedy).
        top_k: If > 0, only sample from top-k.
        top_p: If < 1.0, use nucleus sampling.
        repetition_penalty: Penalty for repeating tokens.
        stop_tokens: Token IDs that stop generation.

    Yields:
        Generated token IDs one at a time.
    """
    cache = None
    stop_set = set(stop_tokens) if stop_tokens else set()

    # Prefill.
    logits, cache = model(prompt, cache=cache)
    # Cache entries can be None (identity layers in hybrid models) or
    # tuples of arrays (KV / SSM state), so evaluate only array leaves
    # instead of assuming every entry is an unpackable pair.
    mx.eval(
        logits,
        *[
            a
            for entry in cache
            if isinstance(entry, tuple)
            for a in entry
            if isinstance(a, mx.array)
        ],
    )

    generated: list[mx.array] = []
    for _ in range(max_tokens):
        next_logits = logits[:, -1, :]

        if repetition_penalty != 1.0:
            next_logits = _apply_repetition_penalty(
                next_logits, generated, repetition_penalty
            )

        next_token = _sample_next(next_logits, temperature, top_k, top_p)
        mx.eval(next_token)

        # Batch size is documented as 1, so checking element (0, 0)
        # covers the whole batch.
        token_id = next_token[0, 0].item()

        if stop_set and token_id in stop_set:
            return

        generated.append(next_token)
        yield token_id

        logits, cache = model(next_token, cache=cache)
        mx.eval(
            logits,
            *[
                a
                for entry in cache
                if isinstance(entry, tuple)
                for a in entry
                if isinstance(a, mx.array)
            ],
        )

Config Factories

Each factory returns a ModelConfig that builds the corresponding architecture when passed to LanguageModel.

GPT

lmxlab.models.gpt.gpt_config(vocab_size=50257, d_model=768, n_heads=12, n_layers=12, d_ff=3072, max_seq_len=1024, tie_embeddings=True, dropout=0.0, mup_base_width=None)

Create a GPT-style model configuration.

GPT uses: LayerNorm, standard MHA, standard FFN (GELU), sinusoidal positional encoding, pre-norm, bias everywhere.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size (default: GPT-2 BPE vocab).

50257
d_model int

Hidden dimension.

768
n_heads int

Number of attention heads.

12
n_layers int

Number of transformer layers.

12
d_ff int

Feed-forward intermediate dimension.

3072
max_seq_len int

Maximum sequence length.

1024
tie_embeddings bool

Whether to tie input/output embeddings.

True
dropout float

Dropout rate.

0.0
mup_base_width int | None

Base width for μP. When set, enables μP attention scaling and logit scaling.

None

Returns:

Type Description
ModelConfig

ModelConfig for a GPT-style model.

Source code in src/lmxlab/models/gpt.py
def gpt_config(
    vocab_size: int = 50257,
    d_model: int = 768,
    n_heads: int = 12,
    n_layers: int = 12,
    d_ff: int = 3072,
    max_seq_len: int = 1024,
    tie_embeddings: bool = True,
    dropout: float = 0.0,
    mup_base_width: int | None = None,
) -> ModelConfig:
    """Build a ModelConfig for a GPT-style transformer.

    Architecture choices: LayerNorm, standard multi-head
    attention, standard GELU FFN, sinusoidal positional
    encoding, pre-norm residuals, bias terms everywhere.

    Args:
        vocab_size: Vocabulary size (default: GPT-2 BPE vocab).
        d_model: Hidden dimension.
        n_heads: Number of attention heads.
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        tie_embeddings: Whether to tie input/output embeddings.
        dropout: Dropout rate.
        mup_base_width: Base width for μP. When set, enables
            μP attention scaling and logit scaling.

    Returns:
        ModelConfig for a GPT-style model.
    """
    blk = BlockConfig(
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        attention="mha",
        ffn="standard",
        norm="layer_norm",
        position="sinusoidal",
        bias=True,
        dropout=dropout,
        max_seq_len=max_seq_len,
        pre_norm=True,
        mup=mup_base_width is not None,
    )
    return ModelConfig(
        block=blk,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
        mup_base_width=mup_base_width,
    )

lmxlab.models.gpt.gpt_tiny()

Tiny GPT for testing (d=64, 2 layers, 2 heads).

Source code in src/lmxlab/models/gpt.py
def gpt_tiny() -> ModelConfig:
    """Minimal GPT preset for tests (d_model=64, 2 layers, 2 heads)."""
    overrides = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 2,
        "n_layers": 2,
        "d_ff": 128,
        "max_seq_len": 128,
    }
    return gpt_config(**overrides)

lmxlab.models.gpt.gpt_small()

GPT-small (~125M params).

Source code in src/lmxlab/models/gpt.py
def gpt_small() -> ModelConfig:
    """GPT-small preset (~125M params); all factory defaults."""
    config = gpt_config()
    return config

lmxlab.models.gpt.gpt_medium()

GPT-medium (~350M params).

Source code in src/lmxlab/models/gpt.py
def gpt_medium() -> ModelConfig:
    """GPT-medium preset (~350M params)."""
    overrides = {
        "d_model": 1024,
        "n_heads": 16,
        "n_layers": 24,
        "d_ff": 4096,
    }
    return gpt_config(**overrides)

LLaMA

lmxlab.models.llama.llama_config(vocab_size=32000, d_model=4096, n_heads=32, n_kv_heads=8, n_layers=32, d_ff=11008, max_seq_len=4096, rope_theta=10000.0, tie_embeddings=False, dropout=0.0, mup_base_width=None)

Create a LLaMA-style model configuration.

LLaMA uses: RMSNorm, GQA, GatedFFN (SwiGLU), RoPE, pre-norm, no bias.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

32000
d_model int

Hidden dimension.

4096
n_heads int

Number of query heads.

32
n_kv_heads int

Number of KV heads (for GQA).

8
n_layers int

Number of transformer layers.

32
d_ff int

Feed-forward intermediate dimension.

11008
max_seq_len int

Maximum sequence length.

4096
rope_theta float

RoPE base frequency.

10000.0
tie_embeddings bool

Whether to tie input/output embeddings.

False
dropout float

Dropout rate.

0.0
mup_base_width int | None

Base width for μP. When set, enables μP attention scaling and logit scaling.

None

Returns:

Type Description
ModelConfig

ModelConfig for a LLaMA-style model.

Source code in src/lmxlab/models/llama.py
def llama_config(
    vocab_size: int = 32000,
    d_model: int = 4096,
    n_heads: int = 32,
    n_kv_heads: int = 8,
    n_layers: int = 32,
    d_ff: int = 11008,
    max_seq_len: int = 4096,
    rope_theta: float = 10000.0,
    tie_embeddings: bool = False,
    dropout: float = 0.0,
    mup_base_width: int | None = None,
) -> ModelConfig:
    """Build a ModelConfig for a LLaMA-style transformer.

    Architecture choices: RMSNorm, grouped-query attention,
    gated SwiGLU FFN, rotary position embeddings, pre-norm
    residuals, no bias terms.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of query heads.
        n_kv_heads: Number of KV heads (for GQA).
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        tie_embeddings: Whether to tie input/output embeddings.
        dropout: Dropout rate.
        mup_base_width: Base width for μP. When set, enables
            μP attention scaling and logit scaling.

    Returns:
        ModelConfig for a LLaMA-style model.
    """
    blk = BlockConfig(
        d_model=d_model,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        d_ff=d_ff,
        attention="gqa",
        ffn="gated",
        norm="rms_norm",
        position="rope",
        bias=False,
        dropout=dropout,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
        mup=mup_base_width is not None,
    )
    return ModelConfig(
        block=blk,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
        mup_base_width=mup_base_width,
    )

lmxlab.models.llama.llama_tiny()

Tiny LLaMA for testing (d=64, 2 layers, 4 heads, 2 kv).

Source code in src/lmxlab/models/llama.py
def llama_tiny() -> ModelConfig:
    """Minimal LLaMA preset for tests (d=64, 2 layers, 4 heads, 2 kv)."""
    overrides = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "n_layers": 2,
        "d_ff": 128,
        "max_seq_len": 128,
    }
    return llama_config(**overrides)

lmxlab.models.llama.llama_7b()

LLaMA-7B configuration.

Source code in src/lmxlab/models/llama.py
def llama_7b() -> ModelConfig:
    """LLaMA-7B preset; all factory defaults."""
    config = llama_config()
    return config

lmxlab.models.llama.llama_13b()

LLaMA-13B configuration.

Source code in src/lmxlab/models/llama.py
def llama_13b() -> ModelConfig:
    """LLaMA-13B preset."""
    overrides = {
        "d_model": 5120,
        "n_heads": 40,
        "n_kv_heads": 10,
        "n_layers": 40,
        "d_ff": 13824,
    }
    return llama_config(**overrides)

Gemma

lmxlab.models.gemma.gemma_config(vocab_size=256000, d_model=2048, n_heads=8, n_kv_heads=1, n_layers=18, d_ff=16384, max_seq_len=8192, rope_theta=10000.0, tie_embeddings=True)

Create a Gemma-style model configuration.

Gemma uses: RMSNorm, GQA (multi-query), GatedFFN (GeGLU), RoPE, pre-norm, no bias, tied embeddings.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

256000
d_model int

Hidden dimension.

2048
n_heads int

Number of query heads.

8
n_kv_heads int

Number of KV heads.

1
n_layers int

Number of transformer layers.

18
d_ff int

Feed-forward intermediate dimension.

16384
max_seq_len int

Maximum sequence length.

8192
rope_theta float

RoPE base frequency.

10000.0
tie_embeddings bool

Whether to tie embeddings.

True

Returns:

Type Description
ModelConfig

ModelConfig for a Gemma-style model.

Source code in src/lmxlab/models/gemma.py
def gemma_config(
    vocab_size: int = 256000,
    d_model: int = 2048,
    n_heads: int = 8,
    n_kv_heads: int = 1,
    n_layers: int = 18,
    d_ff: int = 16384,
    max_seq_len: int = 8192,
    rope_theta: float = 10000.0,
    tie_embeddings: bool = True,
    dropout: float = 0.0,
    mup_base_width: int | None = None,
) -> ModelConfig:
    """Create a Gemma-style model configuration.

    Gemma uses: RMSNorm, GQA (multi-query), GatedFFN (GeGLU),
    RoPE, pre-norm, no bias, tied embeddings.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of query heads.
        n_kv_heads: Number of KV heads.
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        tie_embeddings: Whether to tie embeddings.
        dropout: Dropout rate.
        mup_base_width: Base width for μP. When set, enables
            μP attention scaling and logit scaling.

    Returns:
        ModelConfig for a Gemma-style model.
    """
    # dropout / mup_base_width added for consistency with
    # gpt_config and llama_config; defaults preserve the
    # previous behavior exactly.
    block = BlockConfig(
        attention="gqa",
        ffn="gated",
        norm="rms_norm",
        position="rope",
        d_model=d_model,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        d_ff=d_ff,
        bias=False,
        dropout=dropout,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
        mup=mup_base_width is not None,
    )
    return ModelConfig(
        block=block,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
        mup_base_width=mup_base_width,
    )

lmxlab.models.gemma.gemma_tiny()

Tiny Gemma for testing.

Source code in src/lmxlab/models/gemma.py
def gemma_tiny() -> ModelConfig:
    """Minimal Gemma preset for tests."""
    overrides = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 1,
        "n_layers": 2,
        "d_ff": 128,
        "max_seq_len": 128,
    }
    return gemma_config(**overrides)

Gemma 3

lmxlab.models.gemma3.gemma3_config(vocab_size=256000, d_model=2048, n_heads=8, n_kv_heads=4, n_layers=26, d_ff=16384, max_seq_len=8192, rope_theta=10000.0, window_size=4096, global_every=6, tie_embeddings=True)

Create a Gemma 3-style model configuration.

Gemma 3 interleaves local (sliding window) and global attention layers. Every global_every-th layer (0-indexed, i.e. layers 5, 11, 17, ...) uses full global GQA; all other layers use sliding window GQA with the given window size.

Uses: RMSNorm, GatedFFN (GeGLU), RoPE, pre-norm, no bias, tied embeddings.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

256000
d_model int

Hidden dimension.

2048
n_heads int

Number of query heads.

8
n_kv_heads int

Number of KV heads.

4
n_layers int

Number of transformer layers.

26
d_ff int

Feed-forward intermediate dimension.

16384
max_seq_len int

Maximum sequence length.

8192
rope_theta float

RoPE base frequency.

10000.0
window_size int

Sliding window size for local layers.

4096
global_every int

Place a global attention layer every N layers (0-indexed positions global_every-1, 2*global_every-1, ...).

6
tie_embeddings bool

Whether to tie embeddings.

True

Returns:

Type Description
ModelConfig

ModelConfig for a Gemma 3-style model.

Source code in src/lmxlab/models/gemma3.py
def gemma3_config(
    vocab_size: int = 256000,
    d_model: int = 2048,
    n_heads: int = 8,
    n_kv_heads: int = 4,
    n_layers: int = 26,
    d_ff: int = 16384,
    max_seq_len: int = 8192,
    rope_theta: float = 10000.0,
    window_size: int = 4096,
    global_every: int = 6,
    tie_embeddings: bool = True,
) -> ModelConfig:
    """Create a Gemma 3-style model configuration.

    Gemma 3 interleaves local (sliding window) and global
    attention. With ``global_every=6``, the layers at 0-indexed
    positions 5, 11, 17, ... use full global GQA; every other
    layer uses sliding-window GQA with the given window size.

    Uses: RMSNorm, GatedFFN (GeGLU), RoPE, pre-norm, no bias,
    tied embeddings.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of query heads.
        n_kv_heads: Number of KV heads.
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        window_size: Sliding window size for local layers.
        global_every: Place a global attention layer every N
            layers (0-indexed positions global_every-1,
            2*global_every-1, ...).
        tie_embeddings: Whether to tie embeddings.

    Returns:
        ModelConfig for a Gemma 3-style model.
    """
    # Settings shared by both local and global layers.
    shared = dict(
        ffn="gated",
        norm="rms_norm",
        position="rope",
        d_model=d_model,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        d_ff=d_ff,
        bias=False,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
    )

    # Local layers (the common case) use sliding-window GQA.
    local_block = BlockConfig(
        attention="sliding_window_gqa",
        window_size=window_size,
        **shared,
    )

    # Global layers use standard (full-context) GQA.
    global_block = BlockConfig(attention="gqa", **shared)

    # Every global_every-th layer is global; the rest are local.
    layer_configs = tuple(
        global_block if (i + 1) % global_every == 0 else local_block
        for i in range(n_layers)
    )

    return ModelConfig(
        block=local_block,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
        block_configs=layer_configs,
    )

lmxlab.models.gemma3.gemma3_tiny()

Tiny Gemma 3 for testing (4 layers, global every 4th).

Source code in src/lmxlab/models/gemma3.py
def gemma3_tiny() -> ModelConfig:
    """Minimal Gemma 3 config for tests (4 layers, global every 4th)."""
    # Shrink every dimension so a model builds and runs in milliseconds.
    test_dims = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "n_layers": 4,
        "d_ff": 128,
        "max_seq_len": 128,
        "window_size": 16,
        "global_every": 4,
    }
    return gemma3_config(**test_dims)

Qwen

lmxlab.models.qwen.qwen_config(vocab_size=151936, d_model=4096, n_heads=32, n_kv_heads=32, n_layers=32, d_ff=11008, max_seq_len=32768, rope_theta=1000000.0, tie_embeddings=False)

Create a Qwen-style model configuration.

Qwen uses: RMSNorm, GQA, GatedFFN (SwiGLU), RoPE (high theta for long context), pre-norm, bias in QKV.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

151936
d_model int

Hidden dimension.

4096
n_heads int

Number of query heads.

32
n_kv_heads int

Number of KV heads.

32
n_layers int

Number of transformer layers.

32
d_ff int

Feed-forward intermediate dimension.

11008
max_seq_len int

Maximum sequence length.

32768
rope_theta float

RoPE base frequency.

1000000.0
tie_embeddings bool

Whether to tie embeddings.

False

Returns:

Type Description
ModelConfig

ModelConfig for a Qwen-style model.

Source code in src/lmxlab/models/qwen.py
def qwen_config(
    vocab_size: int = 151936,
    d_model: int = 4096,
    n_heads: int = 32,
    n_kv_heads: int = 32,
    n_layers: int = 32,
    d_ff: int = 11008,
    max_seq_len: int = 32768,
    rope_theta: float = 1000000.0,
    tie_embeddings: bool = False,
) -> ModelConfig:
    """Build a Qwen-style model configuration.

    The Qwen recipe: RMSNorm, GQA attention, SwiGLU (gated) FFN,
    RoPE with a high theta for long-context use, pre-norm, and
    bias enabled on the attention projections.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of query heads.
        n_kv_heads: Number of KV heads.
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        tie_embeddings: Whether to tie embeddings.

    Returns:
        ModelConfig for a Qwen-style model.
    """
    # Qwen is homogeneous: one block config is shared by all layers,
    # so the per-layer BlockConfig can be built inline.
    return ModelConfig(
        block=BlockConfig(
            attention="gqa",
            ffn="gated",
            norm="rms_norm",
            position="rope",
            d_model=d_model,
            n_heads=n_heads,
            n_kv_heads=n_kv_heads,
            d_ff=d_ff,
            bias=True,  # Qwen keeps bias on QKV projections
            rope_theta=rope_theta,
            max_seq_len=max_seq_len,
            pre_norm=True,
        ),
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
    )

lmxlab.models.qwen.qwen_tiny()

Tiny Qwen for testing.

Source code in src/lmxlab/models/qwen.py
def qwen_tiny() -> ModelConfig:
    """Minimal Qwen config suitable for fast tests."""
    # Tiny dimensions; everything not listed keeps qwen_config defaults.
    test_dims = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "n_layers": 2,
        "d_ff": 128,
        "max_seq_len": 128,
    }
    return qwen_config(**test_dims)

Mixtral

lmxlab.models.mixtral.mixtral_config(vocab_size=32000, d_model=4096, n_heads=32, n_kv_heads=8, n_layers=32, d_ff=14336, n_experts=8, top_k_experts=2, max_seq_len=32768, rope_theta=1000000.0, tie_embeddings=False)

Create a Mixtral-style model configuration.

Mixtral uses GQA attention with MoE FFN: each token is routed to top-k of n_experts GatedFFN (SwiGLU) experts.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

32000
d_model int

Hidden dimension.

4096
n_heads int

Number of query heads.

32
n_kv_heads int

Number of KV heads.

8
n_layers int

Number of transformer layers.

32
d_ff int

Per-expert feed-forward dimension.

14336
n_experts int

Number of expert FFNs.

8
top_k_experts int

Experts per token.

2
max_seq_len int

Maximum sequence length.

32768
rope_theta float

RoPE base frequency.

1000000.0
tie_embeddings bool

Whether to tie embeddings.

False

Returns:

Type Description
ModelConfig

ModelConfig for a Mixtral-style model.

Source code in src/lmxlab/models/mixtral.py
def mixtral_config(
    vocab_size: int = 32000,
    d_model: int = 4096,
    n_heads: int = 32,
    n_kv_heads: int = 8,
    n_layers: int = 32,
    d_ff: int = 14336,
    n_experts: int = 8,
    top_k_experts: int = 2,
    max_seq_len: int = 32768,
    rope_theta: float = 1000000.0,
    tie_embeddings: bool = False,
) -> ModelConfig:
    """Build a Mixtral-style model configuration.

    Combines GQA attention with a mixture-of-experts FFN: each
    token is routed to ``top_k_experts`` of ``n_experts`` SwiGLU
    (GatedFFN) experts.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of query heads.
        n_kv_heads: Number of KV heads.
        n_layers: Number of transformer layers.
        d_ff: Per-expert feed-forward dimension.
        n_experts: Number of expert FFNs.
        top_k_experts: Experts per token.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        tie_embeddings: Whether to tie embeddings.

    Returns:
        ModelConfig for a Mixtral-style model.
    """
    # Homogeneous stack: every layer reuses this single MoE block.
    moe_block = BlockConfig(
        attention="gqa",
        ffn="moe",
        norm="rms_norm",
        position="rope",
        d_model=d_model,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        d_ff=d_ff,
        bias=False,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
        # MoE routing parameters (d_ff above is per expert).
        n_experts=n_experts,
        top_k_experts=top_k_experts,
    )
    return ModelConfig(
        block=moe_block,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
    )

lmxlab.models.mixtral.mixtral_tiny()

Tiny Mixtral for testing (with MoE).

Source code in src/lmxlab/models/mixtral.py
def mixtral_tiny() -> ModelConfig:
    """Minimal Mixtral config for tests (keeps a real MoE FFN)."""
    # Small enough to build instantly, but still 4 experts / top-2
    # so MoE routing code paths are exercised.
    test_dims = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "n_layers": 2,
        "d_ff": 128,
        "n_experts": 4,
        "top_k_experts": 2,
        "max_seq_len": 128,
    }
    return mixtral_config(**test_dims)

Qwen 3.5 (Hybrid DeltaNet)

lmxlab.models.qwen35.qwen35_config(vocab_size=151936, d_model=2048, n_heads=16, n_kv_heads=4, n_layers=28, d_ff=5504, max_seq_len=32768, rope_theta=1000000.0, global_every=4, tie_embeddings=False)

Create a Qwen 3.5-style model configuration.

Qwen 3.5 interleaves Gated DeltaNet (linear attention) and standard GQA layers. Every global_every-th layer uses full GQA; all other layers use Gated DeltaNet.

Uses: RMSNorm, GatedFFN (SwiGLU), RoPE (for GQA layers), short causal convolutions (for DeltaNet layers), no bias.

The 3:1 hybrid ratio (75% DeltaNet, 25% GQA) balances efficiency and expressiveness: - DeltaNet: O(d^2) per token, fixed-size state, no KV cache - GQA: O(n^2) per token, growing KV cache, global context

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

151936
d_model int

Hidden dimension.

2048
n_heads int

Number of attention heads.

16
n_kv_heads int

Number of KV heads (for GQA layers).

4
n_layers int

Number of transformer layers.

28
d_ff int

Feed-forward intermediate dimension.

5504
max_seq_len int

Maximum sequence length.

32768
rope_theta float

RoPE base frequency (for GQA layers).

1000000.0
global_every int

Place a GQA layer every N layers.

4
tie_embeddings bool

Whether to tie embeddings.

False

Returns:

Type Description
ModelConfig

ModelConfig for a Qwen 3.5-style model.

Source code in src/lmxlab/models/qwen35.py
def qwen35_config(
    vocab_size: int = 151936,
    d_model: int = 2048,
    n_heads: int = 16,
    n_kv_heads: int = 4,
    n_layers: int = 28,
    d_ff: int = 5504,
    max_seq_len: int = 32768,
    rope_theta: float = 1000000.0,
    global_every: int = 4,
    tie_embeddings: bool = False,
) -> ModelConfig:
    """Build a Qwen 3.5-style model configuration.

    Interleaves Gated DeltaNet (linear attention) layers with
    standard GQA layers: every ``global_every``-th layer is full
    GQA, and every other layer is Gated DeltaNet.

    Uses: RMSNorm, GatedFFN (SwiGLU), RoPE on the GQA layers,
    short causal convolutions on the DeltaNet layers, no bias.

    The default 3:1 ratio (75% DeltaNet, 25% GQA) trades
    efficiency against expressiveness:
    - DeltaNet: O(d^2) per token, fixed-size state, no KV cache
    - GQA: O(n^2) per token, growing KV cache, global context

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of attention heads.
        n_kv_heads: Number of KV heads (for GQA layers).
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency (for GQA layers).
        global_every: Place a GQA layer every N layers.
        tie_embeddings: Whether to tie embeddings.

    Returns:
        ModelConfig for a Qwen 3.5-style model.
    """
    # Linear-attention block used on the majority of layers.
    linear_block = BlockConfig(
        attention="gated_deltanet",
        ffn="gated",
        norm="rms_norm",
        position="none",  # DeltaNet layers take no explicit positions
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        bias=False,
        max_seq_len=max_seq_len,
        pre_norm=True,
        use_short_conv=True,
        conv_kernel_size=4,
    )

    # Full-attention block placed on every global_every-th layer.
    full_attn_block = BlockConfig(
        attention="gqa",
        ffn="gated",
        norm="rms_norm",
        position="rope",
        d_model=d_model,
        n_heads=n_heads,
        n_kv_heads=n_kv_heads,
        d_ff=d_ff,
        bias=False,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
    )

    # Layer plan: GQA where (0-indexed layer + 1) divides evenly by
    # global_every; DeltaNet everywhere else.
    layer_plan = []
    for layer_idx in range(n_layers):
        is_global = (layer_idx + 1) % global_every == 0
        layer_plan.append(full_attn_block if is_global else linear_block)

    return ModelConfig(
        block=linear_block,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
        block_configs=tuple(layer_plan),
    )

lmxlab.models.qwen35.qwen35_tiny()

Tiny Qwen 3.5 for testing (4 layers, global every 4th).

Source code in src/lmxlab/models/qwen35.py
def qwen35_tiny() -> ModelConfig:
    """Minimal Qwen 3.5 config for tests (4 layers, global every 4th)."""
    # Four layers with global_every=4 gives 3 DeltaNet + 1 GQA layer,
    # so both attention paths are covered.
    test_dims = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_kv_heads": 2,
        "n_layers": 4,
        "d_ff": 128,
        "max_seq_len": 128,
        "global_every": 4,
    }
    return qwen35_config(**test_dims)

DeepSeek

lmxlab.models.deepseek.deepseek_config(vocab_size=102400, d_model=5120, n_heads=128, n_layers=60, d_ff=12288, kv_lora_rank=512, q_lora_rank=1536, rope_dim=64, max_seq_len=4096, rope_theta=10000.0, tie_embeddings=False)

Create a DeepSeek V2-style model configuration.

DeepSeek V2 uses: RMSNorm, MLA, GatedFFN (SwiGLU), decoupled RoPE, pre-norm, no bias.

Parameters:

Name Type Description Default
vocab_size int

Vocabulary size.

102400
d_model int

Hidden dimension.

5120
n_heads int

Number of attention heads.

128
n_layers int

Number of transformer layers.

60
d_ff int

Feed-forward intermediate dimension.

12288
kv_lora_rank int

Latent dimension for KV compression.

512
q_lora_rank int

Latent dimension for Q compression.

1536
rope_dim int

Number of head dims for RoPE.

64
max_seq_len int

Maximum sequence length.

4096
rope_theta float

RoPE base frequency.

10000.0
tie_embeddings bool

Whether to tie embeddings.

False

Returns:

Type Description
ModelConfig

ModelConfig for a DeepSeek V2-style model.

Source code in src/lmxlab/models/deepseek.py
def deepseek_config(
    vocab_size: int = 102400,
    d_model: int = 5120,
    n_heads: int = 128,
    n_layers: int = 60,
    d_ff: int = 12288,
    kv_lora_rank: int = 512,
    q_lora_rank: int = 1536,
    rope_dim: int = 64,
    max_seq_len: int = 4096,
    rope_theta: float = 10000.0,
    tie_embeddings: bool = False,
) -> ModelConfig:
    """Build a DeepSeek V2-style model configuration.

    The DeepSeek V2 recipe: RMSNorm, MLA (multi-head latent
    attention), GatedFFN (SwiGLU), decoupled RoPE, pre-norm,
    no bias.

    Args:
        vocab_size: Vocabulary size.
        d_model: Hidden dimension.
        n_heads: Number of attention heads.
        n_layers: Number of transformer layers.
        d_ff: Feed-forward intermediate dimension.
        kv_lora_rank: Latent dimension for KV compression.
        q_lora_rank: Latent dimension for Q compression.
        rope_dim: Number of head dims for RoPE.
        max_seq_len: Maximum sequence length.
        rope_theta: RoPE base frequency.
        tie_embeddings: Whether to tie embeddings.

    Returns:
        ModelConfig for a DeepSeek V2-style model.
    """
    # Homogeneous stack: a single MLA block config for every layer.
    mla_block = BlockConfig(
        attention="mla",
        ffn="gated",
        norm="rms_norm",
        position="rope",
        d_model=d_model,
        n_heads=n_heads,
        d_ff=d_ff,
        bias=False,
        rope_theta=rope_theta,
        max_seq_len=max_seq_len,
        pre_norm=True,
        # MLA-specific low-rank compression / decoupled-RoPE knobs.
        kv_lora_rank=kv_lora_rank,
        q_lora_rank=q_lora_rank,
        rope_dim=rope_dim,
    )
    return ModelConfig(
        block=mla_block,
        vocab_size=vocab_size,
        n_layers=n_layers,
        tie_embeddings=tie_embeddings,
    )

lmxlab.models.deepseek.deepseek_tiny()

Tiny DeepSeek for testing (d=64, 2 layers, 4 heads).

Source code in src/lmxlab/models/deepseek.py
def deepseek_tiny() -> ModelConfig:
    """Minimal DeepSeek config for tests (d=64, 2 layers, 4 heads)."""
    # Small latent ranks and rope_dim keep the MLA projections tiny.
    test_dims = {
        "vocab_size": 256,
        "d_model": 64,
        "n_heads": 4,
        "n_layers": 2,
        "d_ff": 128,
        "kv_lora_rank": 16,
        "q_lora_rank": 32,
        "rope_dim": 8,
        "max_seq_len": 128,
    }
    return deepseek_config(**test_dims)

Weight Conversion

lmxlab.models.convert.load_from_hf(repo_id, revision=None, dtype=mx.float16, quantize=None)

Download and load a HuggingFace model into lmxlab.

Requires the huggingface_hub package.

Parameters:

Name Type Description Default
repo_id str

HuggingFace repo ID (e.g., 'meta-llama/Llama-3.2-1B').

required
revision str | None

Git revision (branch, tag, or commit hash).

None
dtype Dtype

Target dtype for weights (default: float16).

float16
quantize int | None

If set, quantize the model to this many bits (4 or 8) after loading. Reduces memory usage.

None

Returns:

Type Description
tuple[LanguageModel, ModelConfig]

Tuple of (loaded LanguageModel, ModelConfig).

Raises:

Type Description
ImportError

If huggingface_hub is not installed.

ValueError

If model_type is not supported.

Source code in src/lmxlab/models/convert.py
def load_from_hf(
    repo_id: str,
    revision: str | None = None,
    dtype: mx.Dtype = mx.float16,
    quantize: int | None = None,
) -> tuple[LanguageModel, ModelConfig]:
    """Download and load a HuggingFace model into lmxlab.

    Requires the ``huggingface_hub`` package.

    Args:
        repo_id: HuggingFace repo ID (e.g., 'meta-llama/Llama-3.2-1B').
        revision: Git revision (branch, tag, or commit hash).
        dtype: Target dtype for weights (default: float16).
        quantize: If set, quantize the model to this many bits
            (4 or 8) after loading. Reduces memory usage.

    Returns:
        Tuple of (loaded LanguageModel, ModelConfig).

    Raises:
        ImportError: If huggingface_hub is not installed.
        ValueError: If model_type is not supported.
        FileNotFoundError: If the snapshot contains no .safetensors files.
    """
    try:
        from huggingface_hub import snapshot_download
    except ImportError as e:
        raise ImportError(
            "huggingface_hub is required for load_from_hf. "
            "Install with: pip install huggingface_hub"
        ) from e

    # Download only the files we need: weight shards, model config,
    # and tokenizer files.
    local_dir = snapshot_download(
        repo_id,
        revision=revision,
        allow_patterns=[
            "*.safetensors",
            "config.json",
            "tokenizer.json",
            "tokenizer_config.json",
        ],
    )
    local_path = Path(local_dir)

    # Parse HF config.json and map it onto lmxlab's ModelConfig.
    config_path = local_path / "config.json"
    hf_config = json.loads(config_path.read_text())
    model_config = config_from_hf(hf_config)

    # Load weights from all safetensors shards (sorted for a
    # deterministic merge order).
    weight_files = sorted(local_path.glob("*.safetensors"))
    if not weight_files:
        raise FileNotFoundError(f"No .safetensors files found in {local_path}")

    hf_weights: dict[str, mx.array] = {}
    for wf in weight_files:
        loaded = mx.load(str(wf))
        if isinstance(loaded, dict):
            hf_weights.update(loaded)

    # Determine architecture for weight-name mapping.
    arch = hf_config["model_type"]

    # Convert weight names; the hybrid override pattern is only
    # present (and required) for nemotron_h checkpoints.
    pattern = hf_config.get("hybrid_override_pattern")
    lmt_weights = convert_weights(
        hf_weights,
        arch,
        pattern=pattern,
    )

    # Cast to the target dtype. Fix: the previous `if dtype !=
    # mx.float32` guard skipped casting entirely when float32 was
    # requested, leaving e.g. fp16/bf16 checkpoint weights uncast.
    # Cast per-array, and only when the dtype actually differs.
    lmt_weights = {
        k: v.astype(dtype) if v.dtype != dtype else v
        for k, v in lmt_weights.items()
    }

    # Build model and load weights.
    model = LanguageModel(model_config)

    # Warn if converted weights don't cover all model parameters,
    # which usually indicates a gap in the weight-name map.
    import mlx.utils

    model_keys = set(dict(mlx.utils.tree_flatten(model.parameters())).keys())
    loaded_keys = set(lmt_weights.keys())
    missing = model_keys - loaded_keys
    if missing:
        logger.warning(
            "Missing %d model parameters after conversion: %s",
            len(missing),
            sorted(missing)[:10],
        )

    model.load_weights(list(lmt_weights.items()))

    # Optional post-load quantization.
    if quantize is not None:
        from lmxlab.core.quantize import quantize_model

        quantize_model(model, bits=quantize)

    return model, model_config

lmxlab.models.convert.config_from_hf(hf_config)

Create a ModelConfig from a HuggingFace config dict.

Reads config.json fields and maps them to lmxlab's BlockConfig and ModelConfig.

Parameters:

Name Type Description Default
hf_config dict[str, Any]

Parsed HuggingFace config.json dict.

required

Returns:

Type Description
ModelConfig

ModelConfig matching the HF model architecture.

Raises:

Type Description
ValueError

If model_type is not supported.

Source code in src/lmxlab/models/convert.py
def config_from_hf(
    hf_config: dict[str, Any],
) -> ModelConfig:
    """Create a ModelConfig from a HuggingFace config dict.

    Maps fields of a parsed HF ``config.json`` onto lmxlab's
    BlockConfig and ModelConfig.

    Args:
        hf_config: Parsed HuggingFace config.json dict.

    Returns:
        ModelConfig matching the HF model architecture.

    Raises:
        ValueError: If model_type is not supported.
    """
    model_type = hf_config.get("model_type", "")

    # Nemotron-H has its own hybrid layer layout; delegate entirely.
    if model_type == "nemotron_h":
        return _config_from_nemotron_h(hf_config)

    # Everything else must be one of the LLaMA-family architectures.
    llama_types = {"llama", "gemma", "gemma2", "qwen2", "mistral"}
    if model_type not in llama_types:
        raise ValueError(
            f"Unsupported model_type '{model_type}'. "
            f"Supported: {sorted(llama_types | {'nemotron_h'})}"
        )

    # Fail early with a clear message if mandatory keys are absent.
    required = (
        "num_attention_heads",
        "hidden_size",
        "intermediate_size",
        "vocab_size",
        "num_hidden_layers",
    )
    missing = [key for key in required if key not in hf_config]
    if missing:
        raise ValueError(f"HF config missing required keys: {missing}")

    query_heads = hf_config["num_attention_heads"]
    llama_block = BlockConfig(
        attention="gqa",
        ffn="gated",
        norm="rms_norm",
        position="rope",
        d_model=hf_config["hidden_size"],
        n_heads=query_heads,
        # MHA checkpoints omit num_key_value_heads; fall back to MHA.
        n_kv_heads=hf_config.get("num_key_value_heads", query_heads),
        d_ff=hf_config["intermediate_size"],
        bias=False,
        rope_theta=hf_config.get("rope_theta", 10000.0),
        max_seq_len=hf_config.get("max_position_embeddings", 4096),
        pre_norm=True,
    )

    return ModelConfig(
        block=llama_block,
        vocab_size=hf_config["vocab_size"],
        n_layers=hf_config["num_hidden_layers"],
        tie_embeddings=hf_config.get("tie_word_embeddings", False),
    )

lmxlab.models.convert.convert_weights(hf_weights, arch, pattern=None)

Convert HuggingFace weight dict to lmxlab naming.

Parameters:

Name Type Description Default
hf_weights dict[str, array]

Dictionary of HF parameter names to arrays.

required
arch str

Architecture name (e.g., 'llama', 'nemotron_h').

required
pattern str | None

Hybrid override pattern (required for nemotron_h architecture).

None

Returns:

Type Description
dict[str, array]

Dictionary with lmxlab parameter names.

Raises:

Type Description
KeyError

If arch is not supported.

ValueError

If pattern is required but not provided.

Source code in src/lmxlab/models/convert.py
def convert_weights(
    hf_weights: dict[str, mx.array],
    arch: str,
    pattern: str | None = None,
) -> dict[str, mx.array]:
    """Convert HuggingFace weight dict to lmxlab naming.

    Args:
        hf_weights: Dictionary of HF parameter names to arrays.
        arch: Architecture name (e.g., 'llama', 'nemotron_h').
        pattern: Hybrid override pattern (required for
            nemotron_h architecture).

    Returns:
        Dictionary with lmxlab parameter names.

    Raises:
        KeyError: If arch is not supported.
        ValueError: If pattern is required but not provided.
    """
    # Select the name-mapping function for this architecture.
    if arch == "nemotron_h":
        # Nemotron-H builds its map from the hybrid layer pattern.
        if pattern is None:
            raise ValueError("pattern is required for nemotron_h")
        wmap = _nemotron_weight_map(pattern)
    elif arch not in WEIGHT_MAPS:
        raise KeyError(
            f"Unknown architecture '{arch}'. "
            f"Available: {list(WEIGHT_MAPS.keys()) + ['nemotron_h']}"
        )
    else:
        wmap = WEIGHT_MAPS[arch]

    # Keep only parameters the map recognizes (it returns None for
    # names that have no lmxlab counterpart).
    return {
        lmt_name: arr
        for hf_name, arr in hf_weights.items()
        if (lmt_name := wmap(hf_name)) is not None
    }