# Inference

Advanced inference strategies for language models.

lmxlab provides three inference approaches beyond standard autoregressive generation:
- Best-of-N sampling: generate multiple completions and select the highest-scoring one (useful when quality matters more than speed)
- Majority vote: generate N completions and group by content, returning frequency counts (useful for tasks with discrete answers like math or code)
- Speculative decoding: use a small draft model to propose tokens that the target model verifies in a single forward pass (especially effective on unified memory, where both models share the same memory pool)
## Usage
```python
import mlx.core as mx

from lmxlab.models.base import LanguageModel
from lmxlab.models.gpt import gpt_config
from lmxlab.inference import best_of_n, majority_vote, speculative_decode

model = LanguageModel(gpt_config(vocab_size=256, d_model=128, n_heads=4, n_layers=4))
mx.eval(model.parameters())
prompt = mx.array([[1, 2, 3]])

# Best-of-N: pick the highest-scoring completion
best = best_of_n(model, prompt, n=4, max_tokens=50, temperature=0.8)

# Majority vote: group completions by content
results = majority_vote(model, prompt, n=10, max_tokens=20)
for tokens, count in results:
    print(f"  count={count}: {tokens[:5]}...")

# Speculative decoding: draft-then-verify
draft_model = LanguageModel(gpt_config(vocab_size=256, d_model=64, n_heads=2, n_layers=2))
mx.eval(draft_model.parameters())
output, stats = speculative_decode(model, draft_model, prompt, max_tokens=50)
print(f"Acceptance rate: {stats['acceptance_rate']:.1%}")
```
## Sampling

`lmxlab.inference.sampling`

Advanced sampling strategies: best-of-N, majority vote.
### `best_of_n(model, prompt, n=4, max_tokens=100, temperature=0.8, score_fn='log_prob')`

Generate N completions and return the best one.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `LanguageModel` | Language model. | *required* |
| `prompt` | `array` | Input token IDs `(1, prompt_len)`. | *required* |
| `n` | `int` | Number of candidate completions. | `4` |
| `max_tokens` | `int` | Maximum tokens to generate. | `100` |
| `temperature` | `float` | Sampling temperature. | `0.8` |
| `score_fn` | `str \| Callable[[array], array]` | Scoring function. Either a string (`'log_prob'` or `'length_normalized'`) or a callable that takes completions `(n, seq_len)` and returns scores `(n,)` or `(n, 1)`. | `'log_prob'` |

Returns:

| Type | Description |
|---|---|
| `array` | Best completion token IDs `(1, total_len)`. |
Source code in src/lmxlab/inference/sampling.py
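The two built-in scoring modes differ only in whether the summed log-probability is divided by completion length: raw `'log_prob'` tends to favor shorter completions, while `'length_normalized'` removes that bias. A minimal pure-Python sketch of the selection logic, using toy per-token log-probabilities instead of a model (the function name here is illustrative, not part of the library):

```python
def pick_best(completions, logprobs, score_fn="log_prob"):
    """Select the completion with the highest score.

    completions: list of token lists
    logprobs: per-token log-probabilities for each completion
    """
    def score(lp):
        total = sum(lp)
        if score_fn == "length_normalized":
            return total / len(lp)  # mean log-prob: length-neutral
        return total  # summed log-prob: shorter completions score higher

    best_idx = max(range(len(completions)), key=lambda i: score(logprobs[i]))
    return completions[best_idx]

# Toy example: the longer candidate has a better mean but worse sum
cands = [[1, 2, 3], [4, 5]]
lps = [[-0.6, -0.6, -0.6], [-0.8, -0.9]]
print(pick_best(cands, lps))                       # sums: -1.8 vs -1.7 -> [4, 5]
print(pick_best(cands, lps, "length_normalized"))  # means: -0.6 vs -0.85 -> [1, 2, 3]
```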
### `majority_vote(model, prompt, n=5, max_tokens=50, temperature=0.8)`

Generate N completions and return majority vote results.

Useful for tasks with discrete answers (e.g., math, code). Groups completions by content and returns counts.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `LanguageModel` | Language model. | *required* |
| `prompt` | `array` | Input token IDs `(1, prompt_len)`. | *required* |
| `n` | `int` | Number of completions to generate. | `5` |
| `max_tokens` | `int` | Maximum tokens per completion. | `50` |
| `temperature` | `float` | Sampling temperature. | `0.8` |

Returns:

| Type | Description |
|---|---|
| `list[tuple[list[int], int]]` | List of `(token_list, count)` sorted by count descending. |
Source code in src/lmxlab/inference/sampling.py
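The grouping step itself is model-independent: identical token sequences are tallied and returned as `(token_list, count)` pairs, highest count first. A minimal sketch of that logic in plain Python (the `tally` name is illustrative, not the library's internal function):

```python
from collections import Counter

def tally(completions):
    """Group identical completions and return (token_list, count) pairs,
    sorted by count descending -- the shape the docs describe."""
    counts = Counter(tuple(c) for c in completions)  # tuples are hashable
    return [(list(toks), n) for toks, n in counts.most_common()]

# Toy example: 5 "completions", 3 of which agree
outs = [[7, 7], [7, 7], [3, 1], [7, 7], [3, 1]]
print(tally(outs))  # [([7, 7], 3), ([3, 1], 2)]
```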
## Speculative Decoding

`lmxlab.inference.speculative`

Speculative decoding for faster inference.

Uses a small draft model to propose tokens, which the target model verifies in a single forward pass. Especially effective on unified memory, where both models share the same memory pool.
### `speculative_decode(target_model, draft_model, prompt, max_tokens=100, draft_tokens=4, temperature=0.0)`

Generate tokens using speculative decoding (greedy).

The draft model proposes tokens; the target model verifies them in one forward pass. Accepted tokens are kept; on a mismatch, the target model's token is used and drafting restarts.

This is especially efficient on Apple Silicon, where both models share unified memory -- no data transfer overhead.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `target_model` | `LanguageModel` | Large target model. | *required* |
| `draft_model` | `LanguageModel` | Small draft model. | *required* |
| `prompt` | `array` | Token IDs `(1, prompt_len)`. | *required* |
| `max_tokens` | `int` | Maximum new tokens. | `100` |
| `draft_tokens` | `int` | Tokens to draft per step. | `4` |
| `temperature` | `float` | Sampling temperature (only `0.0` supported). | `0.0` |

Returns:

| Type | Description |
|---|---|
| `tuple[array, dict[str, float]]` | Tuple of `(tokens, stats_dict)`. |
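The greedy accept/reject logic can be sketched without any real models. Below, `target_next` and `draft_next` are toy stand-ins that map a token sequence to its next token; for clarity the sketch calls the target once per drafted token, whereas the actual implementation verifies all drafted tokens in a single batched forward pass. All names are illustrative, not the library's internals:

```python
def speculative_greedy(target_next, draft_next, prompt, max_tokens, draft_k=4):
    """Greedy draft-then-verify loop: accept drafted tokens while they match
    the target's greedy choice; on mismatch, take the target's token and
    restart drafting."""
    tokens = list(prompt)
    accepted = proposed = 0
    while len(tokens) - len(prompt) < max_tokens:
        # Draft phase: propose up to draft_k tokens with the cheap model.
        draft = []
        for _ in range(draft_k):
            draft.append(draft_next(tokens + draft))
        # Verify phase: check each drafted token against the target's choice.
        for t in draft:
            proposed += 1
            expected = target_next(tokens)
            if t == expected:
                tokens.append(t)  # accepted: keep going through the draft
                accepted += 1
            else:
                tokens.append(expected)  # rejected: use the target's token
                break                    # and restart drafting from here
            if len(tokens) - len(prompt) >= max_tokens:
                break
    return tokens, {"acceptance_rate": accepted / proposed}

# Toy models: target counts up by 1; draft agrees except every 3rd position
target = lambda seq: seq[-1] + 1
draft = lambda seq: seq[-1] + (2 if len(seq) % 3 == 0 else 1)
out, stats = speculative_greedy(target, draft, [0], max_tokens=6)
print(out)                        # [0, 1, 2, 3, 4, 5, 6]
print(stats["acceptance_rate"])   # 4 of 6 drafted tokens accepted
```

Even with a rejection every few tokens, each accepted run lets the target model validate several tokens per step instead of generating one at a time, which is where the speedup comes from.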