Skip to content

Inference

Advanced inference strategies for language models.

lmxlab provides three inference approaches beyond standard autoregressive generation:

  • Best-of-N sampling: generate multiple completions and select the highest-scoring one (useful when quality matters more than speed)
  • Majority vote: generate N completions and group by content, returning frequency counts (useful for tasks with discrete answers like math or code)
  • Speculative decoding: use a small draft model to propose tokens, verified by the target model in a single forward pass (especially interesting on unified memory where both models share the same memory pool)

Usage

import mlx.core as mx
from lmxlab.models.base import LanguageModel
from lmxlab.models.gpt import gpt_config
from lmxlab.inference import best_of_n, majority_vote, speculative_decode

model = LanguageModel(gpt_config(vocab_size=256, d_model=128, n_heads=4, n_layers=4))
mx.eval(model.parameters())

prompt = mx.array([[1, 2, 3]])

# Best-of-N: pick the highest-scoring completion
best = best_of_n(model, prompt, n=4, max_tokens=50, temperature=0.8)

# Majority vote: group completions by content
results = majority_vote(model, prompt, n=10, max_tokens=20)
for tokens, count in results:
    print(f"  count={count}: {tokens[:5]}...")

# Speculative decoding: draft-then-verify
draft_model = LanguageModel(gpt_config(vocab_size=256, d_model=64, n_heads=2, n_layers=2))
mx.eval(draft_model.parameters())
output, stats = speculative_decode(model, draft_model, prompt, max_tokens=50)
print(f"Acceptance rate: {stats['acceptance_rate']:.1%}")

Sampling

lmxlab.inference.sampling

Advanced sampling strategies: best-of-N, majority vote.

best_of_n(model, prompt, n=4, max_tokens=100, temperature=0.8, score_fn='log_prob')

Generate N completions and return the best one.

Parameters:

Name Type Description Default
model LanguageModel

Language model.

required
prompt array

Input token IDs (1, prompt_len).

required
n int

Number of candidate completions.

4
max_tokens int

Maximum tokens to generate.

100
temperature float

Sampling temperature.

0.8
score_fn str | Callable[[array], array]

Scoring function. Either a string ('log_prob' or 'length_normalized') or a callable that takes completions (n, seq_len) and returns scores (n,) or (n, 1).

'log_prob'

Returns:

Type Description
array

Best completion token IDs (1, total_len).

Source code in src/lmxlab/inference/sampling.py
def best_of_n(
    model: LanguageModel,
    prompt: mx.array,
    n: int = 4,
    max_tokens: int = 100,
    temperature: float = 0.8,
    score_fn: str | Callable[[mx.array], mx.array] = "log_prob",
) -> mx.array:
    """Sample N candidate completions and keep the highest-scoring one.

    Args:
        model: Language model.
        prompt: Input token IDs (1, prompt_len).
        n: Number of candidate completions.
        max_tokens: Maximum tokens to generate.
        temperature: Sampling temperature.
        score_fn: Scoring function. Either a string
            ('log_prob' or 'length_normalized') or a callable
            that takes completions (n, seq_len) and returns
            scores (n,) or (n, 1).

    Returns:
        Best completion token IDs (1, total_len).
    """
    # Tile the prompt n times so one generate() call yields all candidates.
    batched_prompt = mx.repeat(prompt, repeats=n, axis=0)
    candidates = generate(
        model,
        batched_prompt,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    mx.eval(candidates)

    if not callable(score_fn):
        # Built-in scoring: sum of log-probs over the generated suffix.
        scores = _score_sequences(model, candidates, prompt.shape[1])
        mx.eval(scores)
        if score_fn == "length_normalized":
            # Normalize by generation length so short outputs aren't favored.
            generated_len = candidates.shape[1] - prompt.shape[1]
            scores = scores / generated_len
    else:
        scores = score_fn(candidates)
        mx.eval(scores)
        if scores.ndim > 1:
            scores = scores.squeeze(-1)

    winner = int(mx.argmax(scores).item())
    # Slice rather than index so the leading batch dimension is preserved.
    return candidates[winner : winner + 1]

majority_vote(model, prompt, n=5, max_tokens=50, temperature=0.8)

Generate N completions and return majority vote results.

Useful for tasks with discrete answers (e.g., math, code). Groups completions by content and returns counts.

Parameters:

Name Type Description Default
model LanguageModel

Language model.

required
prompt array

Input token IDs (1, prompt_len).

required
n int

Number of completions to generate.

5
max_tokens int

Maximum tokens per completion.

50
temperature float

Sampling temperature.

0.8

Returns:

Type Description
list[tuple[list[int], int]]

List of (token_list, count) sorted by count descending.

Source code in src/lmxlab/inference/sampling.py
def majority_vote(
    model: LanguageModel,
    prompt: mx.array,
    n: int = 5,
    max_tokens: int = 50,
    temperature: float = 0.8,
) -> list[tuple[list[int], int]]:
    """Generate N completions and return majority vote results.

    Useful for tasks with discrete answers (e.g., math, code).
    Groups completions by content and returns counts.

    Args:
        model: Language model.
        prompt: Input token IDs (1, prompt_len).
        n: Number of completions to generate.
        max_tokens: Maximum tokens per completion.
        temperature: Sampling temperature.

    Returns:
        List of (token_list, count) sorted by count descending.
    """
    from collections import Counter

    prompts = mx.repeat(prompt, repeats=n, axis=0)
    completions = generate(
        model,
        prompts,
        max_tokens=max_tokens,
        temperature=temperature,
    )
    mx.eval(completions)

    prompt_len = prompt.shape[1]

    # Group by generated content. tolist() converts each row in one call
    # instead of a Python-level .item() round-trip per token.
    counts = Counter(
        tuple(completions[i, prompt_len:].tolist()) for i in range(n)
    )

    # most_common() sorts by count descending and, like the stable sort it
    # replaces, preserves first-seen order among ties.
    return [(list(tokens), count) for tokens, count in counts.most_common()]

Speculative Decoding

lmxlab.inference.speculative

Speculative decoding for faster inference.

Uses a small draft model to propose tokens, verified by the target model in a single forward pass. Especially interesting on unified memory where both models share the same memory pool.

speculative_decode(target_model, draft_model, prompt, max_tokens=100, draft_tokens=4, temperature=0.0)

Generate tokens using speculative decoding (greedy).

Draft model proposes tokens, target model verifies in one forward pass. Accepted tokens are kept; on mismatch, use the target model's token and restart drafting.

This is especially efficient on Apple Silicon where both models share unified memory -- no data transfer overhead.

Parameters:

Name Type Description Default
target_model LanguageModel

Large target model.

required
draft_model LanguageModel

Small draft model.

required
prompt array

Token IDs (1, prompt_len).

required
max_tokens int

Maximum new tokens.

100
draft_tokens int

Tokens to draft per step.

4
temperature float

Sampling temperature (only 0.0 supported).

0.0

Returns:

Type Description
tuple[array, dict[str, float]]

Tuple of (tokens, stats_dict).

Source code in src/lmxlab/inference/speculative.py
def speculative_decode(
    target_model: LanguageModel,
    draft_model: LanguageModel,
    prompt: mx.array,
    max_tokens: int = 100,
    draft_tokens: int = 4,
    temperature: float = 0.0,
) -> tuple[mx.array, dict[str, float]]:
    """Generate tokens using speculative decoding (greedy).

    Draft model proposes tokens, target model verifies in one
    forward pass. Accepted tokens are kept; on mismatch, the accepted
    prefix plus the target model's correction token are kept and
    drafting restarts.

    This is especially efficient on Apple Silicon where both
    models share unified memory -- no data transfer overhead.

    Args:
        target_model: Large target model.
        draft_model: Small draft model.
        prompt: Token IDs (1, prompt_len).
        max_tokens: Maximum new tokens.
        draft_tokens: Tokens to draft per step.
        temperature: Sampling temperature (only 0.0 supported).

    Returns:
        Tuple of (tokens, stats_dict). `acceptance_rate` counts only
        drafted tokens that the target confirmed, so it is in [0, 1].

    Raises:
        ValueError: If temperature is not 0.0.
    """
    # Only greedy decoding is implemented; fail loudly instead of
    # silently ignoring a non-zero temperature.
    if temperature != 0.0:
        raise ValueError("speculative_decode only supports temperature=0.0")

    tokens = list(prompt[0].tolist())
    prompt_len = len(tokens)
    n_accepted = 0
    n_drafted = 0

    while len(tokens) - prompt_len < max_tokens:
        remaining = max_tokens - (len(tokens) - prompt_len)
        n_draft = min(draft_tokens, remaining)

        # Draft: generate n_draft tokens with the small model (greedy).
        drafted: list[int] = []
        for _ in range(n_draft):
            d_input = mx.array([tokens + drafted])
            d_logits, _ = draft_model(d_input)
            mx.eval(d_logits)
            next_tok = mx.argmax(d_logits[:, -1, :], axis=-1).item()
            drafted.append(next_tok)
        n_drafted += len(drafted)

        # Verify: run the target model once over context + drafted tokens.
        t_input = mx.array([tokens + drafted])
        t_logits, _ = target_model(t_input)
        mx.eval(t_logits)

        # Compare each drafted token to the target's greedy choice.
        accepted = 0
        correction: int | None = None
        for i, draft_tok in enumerate(drafted):
            # Logits at position p predict token p+1, so position
            # len(tokens)+i-1 predicts drafted[i].
            pos = len(tokens) + i - 1
            target_tok = mx.argmax(t_logits[:, pos, :], axis=-1).item()
            if target_tok != draft_tok:
                correction = target_tok
                break
            accepted += 1

        # Only drafted tokens the target confirmed count as accepted.
        # BUGFIX: previously target-supplied correction/bonus tokens were
        # also counted, letting acceptance_rate exceed 1.0.
        n_accepted += accepted

        # BUGFIX: keep the accepted prefix. The original discarded
        # drafted[:accepted] on a mismatch and appended only the
        # correction token, dropping verified tokens from the output.
        tokens.extend(drafted[:accepted])

        if correction is not None:
            # Mismatch: use the target's token and restart drafting.
            tokens.append(correction)
        elif len(tokens) - prompt_len < max_tokens:
            # All drafted tokens accepted: the verify pass also gives us
            # the target's next token for free.
            last_pos = len(tokens) - 1
            bonus = mx.argmax(t_logits[:, last_pos, :], axis=-1).item()
            tokens.append(bonus)

    # Truncate to exact length
    tokens = tokens[: prompt_len + max_tokens]
    result = mx.array([tokens])

    stats = {
        "acceptance_rate": (n_accepted / max(n_drafted, 1)),
        "total_drafted": n_drafted,
        "total_accepted": n_accepted,
    }
    return result, stats

beam_search(model, prompt, beam_width=4, max_tokens=100, score_fn=None)

Generate completions using beam search.

Maintains beam_width candidate sequences and expands each by one token at each step, keeping the top-scoring beams. By default, scores by cumulative log probability.

Parameters:

Name Type Description Default
model LanguageModel

Language model.

required
prompt array

Input token IDs (1, prompt_len).

required
beam_width int

Number of beams to maintain.

4
max_tokens int

Maximum tokens to generate.

100
score_fn Callable[[array], array] | None

Optional scoring function that takes sequences (beam_width, seq_len) and returns scores (beam_width,). If None, uses log-prob.

None

Returns:

Type Description
list[tuple[array, float]]

List of (sequence, score) tuples sorted by score

list[tuple[array, float]]

descending. Each sequence is (1, total_len).

Source code in src/lmxlab/inference/beam_search.py
def beam_search(
    model: LanguageModel,
    prompt: mx.array,
    beam_width: int = 4,
    max_tokens: int = 100,
    score_fn: Callable[[mx.array], mx.array] | None = None,
) -> list[tuple[mx.array, float]]:
    """Generate completions using beam search.

    Maintains ``beam_width`` candidate sequences and expands
    each by one token at each step, keeping the top-scoring
    beams. By default, scores by cumulative log probability.

    Args:
        model: Language model.
        prompt: Input token IDs (1, prompt_len).
        beam_width: Number of beams to maintain.
        max_tokens: Maximum tokens to generate.
        score_fn: Optional scoring function that takes
            sequences (beam_width, seq_len) and returns
            scores (beam_width,). If None, uses log-prob.

    Returns:
        List of (sequence, score) tuples sorted by score
        descending. Each sequence is (1, total_len).
    """
    if prompt.ndim == 1:
        prompt = prompt[None, :]

    B = prompt.shape[0]
    if B != 1:
        raise ValueError("beam_search expects a single prompt (batch=1)")

    # Initialize beams: (sequence, cumulative_log_prob)
    beams: list[tuple[mx.array, float]] = [
        (prompt, 0.0),
    ]

    for _ in range(max_tokens):
        all_candidates: list[tuple[mx.array, float]] = []

        # Batch all current beams for efficiency
        beam_seqs = mx.concatenate(
            [b[0] for b in beams], axis=0
        )  # (n_beams, seq_len)
        beam_scores = [b[1] for b in beams]

        logits, _ = model(beam_seqs)
        mx.eval(logits)
        # Get log probs for last position
        last_logits = logits[:, -1, :]  # (n_beams, vocab)
        log_probs = nn.log_softmax(last_logits, axis=-1)

        # Get top-k candidates per beam
        n_beams = len(beams)
        for i in range(n_beams):
            # Get top beam_width token indices for this beam.
            top_k_idx = mx.argpartition(-log_probs[i], kth=beam_width - 1)[
                :beam_width
            ]
            # BUGFIX: derive the values from the SAME indices. Previously
            # values came from mx.topk and indices from mx.argpartition --
            # two independent calls whose (unspecified) orderings are not
            # guaranteed to correspond element-wise, so a value could be
            # paired with the wrong token.
            top_k_vals = mx.take(log_probs[i], top_k_idx)

            # Sort the (value, index) pairs by value descending.
            sort_order = mx.argsort(-top_k_vals)
            top_k_vals = mx.take(top_k_vals, sort_order)
            top_k_idx = mx.take(top_k_idx, sort_order)

            mx.eval(top_k_vals, top_k_idx)

            for j in range(beam_width):
                token = top_k_idx[j : j + 1][None, :]
                new_seq = mx.concatenate([beams[i][0], token], axis=1)
                new_score = beam_scores[i] + top_k_vals[j].item()
                all_candidates.append((new_seq, new_score))

        # Keep top beam_width candidates
        all_candidates.sort(key=lambda x: -x[1])
        beams = all_candidates[:beam_width]

    # Optional reranking with custom score_fn
    if score_fn is not None:
        beam_seqs = mx.concatenate([b[0] for b in beams], axis=0)
        scores = score_fn(beam_seqs)
        mx.eval(scores)
        if scores.ndim > 1:
            scores = scores.squeeze(-1)
        scored = [(beams[i][0], scores[i].item()) for i in range(len(beams))]
        scored.sort(key=lambda x: -x[1])
        return scored

    return beams

Reward Model

lmxlab.inference.reward_model.RewardModel

Bases: Module

Reward model: language model + scalar head.

Takes token sequences and returns a scalar reward score based on the last-token hidden state.

Parameters:

Name Type Description Default
model LanguageModel

Base language model.

required
Example

rm = RewardModel(model)
scores = rm(token_ids)  # (batch, 1)

Source code in src/lmxlab/inference/reward_model.py
class RewardModel(nn.Module):
    """Reward model: language model + scalar head.

    Takes token sequences and returns a scalar reward score
    based on the last-token hidden state.

    Args:
        model: Base language model.

    Example:
        >>> rm = RewardModel(model)
        >>> scores = rm(token_ids)  # (batch, 1)
    """

    def __init__(self, model: LanguageModel) -> None:
        super().__init__()
        self.model = model
        # Project the final hidden state down to one scalar per sequence.
        self.score_head = nn.Linear(model.config.block.d_model, 1, bias=False)

    def __call__(self, x: mx.array) -> mx.array:
        """Score sequences.

        Args:
            x: Token IDs (batch, seq_len).

        Returns:
            Scalar reward scores (batch, 1).
        """
        # Logits are unused; only the hidden states feed the score head.
        _, _, hidden = self.model(x, return_hidden=True)
        last_hidden = hidden[:, -1, :]  # (batch, d_model)
        return self.score_head(last_hidden)  # (batch, 1)

__init__(model)

Source code in src/lmxlab/inference/reward_model.py
def __init__(self, model: LanguageModel) -> None:
    """Wrap the base language model and attach a scalar scoring head."""
    super().__init__()
    self.model = model
    # Head input width comes from the base model's configured hidden size.
    d_model = model.config.block.d_model
    self.score_head = nn.Linear(d_model, 1, bias=False)

__call__(x)

Score sequences.

Parameters:

Name Type Description Default
x array

Token IDs (batch, seq_len).

required

Returns:

Type Description
array

Scalar reward scores (batch, 1).

Source code in src/lmxlab/inference/reward_model.py
def __call__(self, x: mx.array) -> mx.array:
    """Score sequences.

    Args:
        x: Token IDs (batch, seq_len).

    Returns:
        Scalar reward scores (batch, 1).
    """
    # logits are unused here; only the hidden states feed the score head
    logits, _, hidden = self.model(x, return_hidden=True)
    # Take last-token hidden state
    last_hidden = hidden[:, -1, :]  # (batch, d_model)
    return self.score_head(last_hidden)  # (batch, 1)