
Inference

Advanced inference strategies for language models.

lmxlab provides three inference approaches beyond standard autoregressive generation:

  • Best-of-N sampling: generate multiple completions and select the highest-scoring one (useful when quality matters more than speed)
  • Majority vote: generate N completions and group by content, returning frequency counts (useful for tasks with discrete answers like math or code)
  • Speculative decoding: a small draft model proposes tokens that the target model verifies in a single forward pass (especially interesting on unified memory, where both models share the same memory pool)

Usage

import mlx.core as mx
from lmxlab.models.base import LanguageModel
from lmxlab.models.gpt import gpt_config
from lmxlab.inference import best_of_n, majority_vote, speculative_decode

model = LanguageModel(gpt_config(vocab_size=256, d_model=128, n_heads=4, n_layers=4))
mx.eval(model.parameters())

prompt = mx.array([[1, 2, 3]])

# Best-of-N: pick the highest-scoring completion
best = best_of_n(model, prompt, n=4, max_tokens=50, temperature=0.8)

# Majority vote: group completions by content
results = majority_vote(model, prompt, n=10, max_tokens=20)
for tokens, count in results:
    print(f"  count={count}: {tokens[:5]}...")

# Speculative decoding: draft-then-verify
draft_model = LanguageModel(gpt_config(vocab_size=256, d_model=64, n_heads=2, n_layers=2))
mx.eval(draft_model.parameters())
output, stats = speculative_decode(model, draft_model, prompt, max_tokens=50)
print(f"Acceptance rate: {stats['acceptance_rate']:.1%}")

Sampling

lmxlab.inference.sampling

Advanced sampling strategies: best-of-N, majority vote.

best_of_n(model, prompt, n=4, max_tokens=100, temperature=0.8, score_fn='log_prob')

Generate N completions and return the best one.

Parameters:

  • model (LanguageModel): Language model. Required.
  • prompt (array): Input token IDs (1, prompt_len). Required.
  • n (int): Number of candidate completions. Default: 4.
  • max_tokens (int): Maximum tokens to generate. Default: 100.
  • temperature (float): Sampling temperature. Default: 0.8.
  • score_fn (str | Callable[[array], array]): Scoring function. Either a string ('log_prob' or 'length_normalized') or a callable that takes completions (n, seq_len) and returns scores (n,) or (n, 1). Default: 'log_prob'.

Returns:

  • array: Best completion token IDs (1, total_len).
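
For custom selection criteria, score_fn can be any callable that maps the (n, seq_len) completions array to one score per candidate. Below is a minimal sketch; the scorer name and criterion are illustrative, not part of lmxlab, and model and prompt are assumed to come from the Usage section above.

import mlx.core as mx

def repetition_penalty_score(completions: mx.array) -> mx.array:
    # Hypothetical scorer: count positions where a token differs from the
    # previous one, so completions with less immediate repetition score higher.
    changes = (completions[:, 1:] != completions[:, :-1]).astype(mx.float32)
    return changes.sum(axis=-1)  # shape (n,)

best = best_of_n(model, prompt, n=8, max_tokens=50, score_fn=repetition_penalty_score)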

Source code in src/lmxlab/inference/sampling.py
def best_of_n(
    model: LanguageModel,
    prompt: mx.array,
    n: int = 4,
    max_tokens: int = 100,
    temperature: float = 0.8,
    score_fn: str | Callable[[mx.array], mx.array] = "log_prob",
) -> mx.array:
    """Generate N completions and return the best one.

    Args:
        model: Language model.
        prompt: Input token IDs (1, prompt_len).
        n: Number of candidate completions.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        score_fn: Scoring function. Either a string
            ('log_prob' or 'length_normalized') or a callable
            that takes completions (n, seq_len) and returns
            scores (n,) or (n, 1).

    Returns:
        Best completion token IDs (1, total_len).
    """
    # Generate N completions by repeating prompt
    prompts = mx.repeat(prompt, repeats=n, axis=0)
    completions = generate(
        model,
        prompts,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    mx.eval(completions)

    if callable(score_fn):
        scores = score_fn(completions)
        mx.eval(scores)
        if scores.ndim > 1:
            scores = scores.squeeze(-1)
    else:
        # Score each completion by log probability
        scores = _score_sequences(model, completions, prompt.shape[1])
        mx.eval(scores)

        if score_fn == "length_normalized":
            gen_len = completions.shape[1] - prompt.shape[1]
            scores = scores / gen_len

    # Return best
    best_idx = int(mx.argmax(scores).item())
    return completions[best_idx : best_idx + 1]

majority_vote(model, prompt, n=5, max_tokens=50, temperature=0.8)

Generate N completions and return majority vote results.

Useful for tasks with discrete answers (e.g., math, code). Groups completions by content and returns counts.

Parameters:

  • model (LanguageModel): Language model. Required.
  • prompt (array): Input token IDs (1, prompt_len). Required.
  • n (int): Number of completions to generate. Default: 5.
  • max_tokens (int): Maximum tokens per completion. Default: 50.
  • temperature (float): Sampling temperature. Default: 0.8.

Returns:

  • list[tuple[list[int], int]]: List of (token_list, count) sorted by count descending.
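
Because results are sorted by count descending, the consensus completion is simply the first entry. A short example, reusing model and prompt from the Usage section above:

results = majority_vote(model, prompt, n=10, max_tokens=20)
top_tokens, top_count = results[0]  # most frequent completion
print(f"{top_count}/10 completions agreed on {top_tokens[:5]}...")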

Source code in src/lmxlab/inference/sampling.py
def majority_vote(
    model: LanguageModel,
    prompt: mx.array,
    n: int = 5,
    max_tokens: int = 50,
    temperature: float = 0.8,
) -> list[tuple[list[int], int]]:
    """Generate N completions and return majority vote results.

    Useful for tasks with discrete answers (e.g., math, code).
    Groups completions by content and returns counts.

    Args:
        model: Language model.
        prompt: Input token IDs (1, prompt_len).
        n: Number of completions to generate.
        max_tokens: Maximum tokens per completion.
        temperature: Sampling temperature.

    Returns:
        List of (token_list, count) sorted by count descending.
    """
    prompts = mx.repeat(prompt, repeats=n, axis=0)
    completions = generate(
        model,
        prompts,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    mx.eval(completions)

    prompt_len = prompt.shape[1]

    # Group by generated content
    counts: dict[tuple[int, ...], int] = {}
    for i in range(n):
        raw = completions[i, prompt_len:]
        gen = tuple(int(raw[j].item()) for j in range(raw.shape[0]))
        counts[gen] = counts.get(gen, 0) + 1

    # Sort by frequency
    results = [
        (list(tokens), count)
        for tokens, count in sorted(counts.items(), key=lambda x: -x[1])
    ]
    return results

Speculative Decoding

lmxlab.inference.speculative

Speculative decoding for faster inference.

Uses a small draft model to propose tokens, verified by the target model in a single forward pass. Especially interesting on unified memory where both models share the same memory pool.

speculative_decode(target_model, draft_model, prompt, max_tokens=100, draft_tokens=4, temperature=0.0)

Generate tokens using speculative decoding (greedy).

Draft model proposes tokens, target model verifies in one forward pass. Accepted tokens are kept; on mismatch, use the target model's token and restart drafting.

This is especially efficient on Apple Silicon where both models share unified memory -- no data transfer overhead.

Parameters:

  • target_model (LanguageModel): Large target model. Required.
  • draft_model (LanguageModel): Small draft model. Required.
  • prompt (array): Token IDs (1, prompt_len). Required.
  • max_tokens (int): Maximum new tokens. Default: 100.
  • draft_tokens (int): Tokens to draft per step. Default: 4.
  • temperature (float): Sampling temperature (only 0.0 supported). Default: 0.0.

Returns:

  • tuple[array, dict[str, float]]: Tuple of (tokens, stats_dict).
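
The stats dict (see the source below) carries acceptance_rate, total_drafted, and total_accepted, which can guide the choice of draft_tokens: a low acceptance rate means the draft model diverges from the target quickly, so drafting many tokens per step wastes work. A small sketch, reusing model and draft_model from the Usage section above:

for k in (2, 4, 8):
    # Compare how much of the draft survives verification at each draft length
    output, stats = speculative_decode(model, draft_model, prompt, max_tokens=50, draft_tokens=k)
    print(
        f"draft_tokens={k}: accepted {stats['total_accepted']:.0f}/"
        f"{stats['total_drafted']:.0f} ({stats['acceptance_rate']:.1%})"
    )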

Source code in src/lmxlab/inference/speculative.py
def speculative_decode(
    target_model: LanguageModel,
    draft_model: LanguageModel,
    prompt: mx.array,
    max_tokens: int = 100,
    draft_tokens: int = 4,
    temperature: float = 0.0,
) -> tuple[mx.array, dict[str, float]]:
    """Generate tokens using speculative decoding (greedy).

    Draft model proposes tokens, target model verifies in one
    forward pass. Accepted tokens are kept; on mismatch, use
    the target model's token and restart drafting.

    This is especially efficient on Apple Silicon where both
    models share unified memory -- no data transfer overhead.

    Args:
        target_model: Large target model.
        draft_model: Small draft model.
        prompt: Token IDs (1, prompt_len).
        max_tokens: Maximum new tokens.
        draft_tokens: Tokens to draft per step.
        temperature: Sampling temperature (only 0.0 supported).

    Returns:
        Tuple of (tokens, stats_dict).
    """
    tokens = list(prompt[0].tolist())
    prompt_len = len(tokens)
    n_accepted = 0
    n_drafted = 0

    while len(tokens) - prompt_len < max_tokens:
        remaining = max_tokens - (len(tokens) - prompt_len)
        n_draft = min(draft_tokens, remaining)

        # Draft: generate n_draft tokens with small model
        drafted = []
        for _ in range(n_draft):
            d_input = mx.array([tokens + drafted])
            d_logits, _ = draft_model(d_input)
            mx.eval(d_logits)
            next_tok = mx.argmax(d_logits[:, -1, :], axis=-1).item()
            drafted.append(next_tok)
        n_drafted += len(drafted)

        # Verify: run target model on all tokens + drafted
        verify_seq = tokens + drafted
        t_input = mx.array([verify_seq])
        t_logits, _ = target_model(t_input)
        mx.eval(t_logits)

        # Check each drafted token against target
        accepted = 0
        for i, draft_tok in enumerate(drafted):
            # Target's prediction at position before this token
            pos = len(tokens) + i - 1
            target_tok = mx.argmax(t_logits[:, pos, :], axis=-1).item()

            if target_tok == draft_tok:
                accepted += 1
            else:
                # Keep the accepted prefix, then take the target's token and stop
                tokens.extend(drafted[:accepted])
                tokens.append(target_tok)
                n_accepted += accepted
                break
        else:
            # All drafted tokens accepted
            tokens.extend(drafted)
            n_accepted += accepted

            # Also take the target's bonus token, already computed in the verify
            # pass (not counted as accepted since it was never drafted)
            if len(tokens) - prompt_len < max_tokens:
                last_pos = len(tokens) - 1
                next_tok = mx.argmax(t_logits[:, last_pos, :], axis=-1).item()
                tokens.append(next_tok)

    # Truncate to exact length
    tokens = tokens[: prompt_len + max_tokens]
    result = mx.array([tokens])

    stats = {
        "acceptance_rate": (n_accepted / max(n_drafted, 1)),
        "total_drafted": n_drafted,
        "total_accepted": n_accepted,
    }
    return result, stats