Evaluation

Metrics for language model evaluation.

Overview

lmxlab provides standard evaluation metrics:

  • Perplexity: exponential of the average cross-entropy loss. Lower is better. A perplexity of 10 means the model is as uncertain as choosing uniformly among 10 tokens.
  • Bits-per-byte (BPB): cross-entropy loss normalized by bytes. This is tokenizer-independent, making it useful for comparing models with different vocabularies.
  • pass@k: estimates the probability that at least one of k code samples passes a set of tests (Chen et al., 2021). Used for code generation evaluation.
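Perplexity and BPB are both simple transforms of the same average cross-entropy loss; a quick sketch with hypothetical numbers (the loss value and bytes-per-token figure below are assumed, not measured):

```python
import math

# Hypothetical: average cross-entropy loss of 2.3026 nats/token (~ln 10),
# and an estimated 3.5 bytes per token for the tokenizer.
avg_loss = 2.3026
bytes_per_token = 3.5

ppl = math.exp(avg_loss)                          # perplexity
bpb = avg_loss / (math.log(2) * bytes_per_token)  # bits-per-byte

print(round(ppl, 1))  # 10.0 -- as uncertain as a uniform choice over 10 tokens
print(round(bpb, 3))
```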

Usage

import mlx.core as mx
from lmxlab.models.base import LanguageModel
from lmxlab.models.gpt import gpt_tiny
from lmxlab.eval import perplexity, bits_per_byte, pass_at_k

model = LanguageModel(gpt_tiny())
mx.eval(model.parameters())

tokens = mx.array([[1, 2, 3, 4, 5, 6, 7, 8]])

# Perplexity on a sequence
ppl = perplexity(model, tokens)
print(f"Perplexity: {ppl:.1f}")

# Bits-per-byte (tokenizer-independent)
bpb = bits_per_byte(model, tokens, bytes_per_token=3.5)
print(f"BPB: {bpb:.3f}")

# pass@k for code generation
# If 3 out of 10 samples pass, estimate pass@1
p1 = pass_at_k(n=10, c=3, k=1)
print(f"pass@1: {p1:.3f}")

Metrics

lmxlab.eval.metrics.perplexity(model, data)

Compute perplexity over a dataset.

PPL = exp(average cross-entropy loss)

Parameters:

Name   Type           Description                                      Default
model  LanguageModel  Language model.                                  required
data   list[array]    List of token ID arrays, each (batch, seq_len).  required

Returns:

Type   Description
float  Perplexity score (lower is better).

Source code in src/lmxlab/eval/metrics.py
def perplexity(
    model: LanguageModel,
    data: list[mx.array],
) -> float:
    """Compute perplexity over a dataset.

    PPL = exp(average cross-entropy loss)

    Args:
        model: Language model.
        data: List of token ID arrays, each (batch, seq_len).

    Returns:
        Perplexity score (lower is better).
    """
    total_loss = 0.0
    n_batches = 0

    for tokens in data:
        loss = _compute_loss(model, tokens)
        mx.eval(loss)
        total_loss += loss.item()
        n_batches += 1

    avg_loss = total_loss / max(n_batches, 1)
    return math.exp(avg_loss)

lmxlab.eval.metrics.bits_per_byte(model, data, bytes_per_token=1.0)

Compute bits-per-byte (BPB).

BPB = (cross-entropy in nats) / (ln(2) * bytes_per_token)

For character-level tokenizers, bytes_per_token ≈ 1.0. For BPE tokenizers, estimate from data.
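One hypothetical way to estimate bytes_per_token for a BPE tokenizer is to divide a corpus sample's UTF-8 byte length by its token count; the token count of 12 below is an assumed value, standing in for something like len(tokenizer.encode(text)) with your tokenizer:

```python
# Estimate bytes_per_token from a corpus sample (sketch with assumed values).
text = "hello world, this is a sample corpus chunk"
n_tokens = 12  # assumed: would come from your tokenizer, e.g. len(tokenizer.encode(text))

bytes_per_token = len(text.encode("utf-8")) / n_tokens
print(round(bytes_per_token, 2))  # 3.5
```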

Parameters:

Name             Type           Description               Default
model            LanguageModel  Language model.           required
data             list[array]    List of token ID arrays.  required
bytes_per_token  float          Average bytes per token.  1.0

Returns:

Type   Description
float  BPB score (lower is better).

Source code in src/lmxlab/eval/metrics.py
def bits_per_byte(
    model: LanguageModel,
    data: list[mx.array],
    bytes_per_token: float = 1.0,
) -> float:
    """Compute bits-per-byte (BPB).

    BPB = (cross-entropy in nats) / (ln(2) * bytes_per_token)

    For character-level tokenizers, bytes_per_token ≈ 1.0.
    For BPE tokenizers, estimate from data.

    Args:
        model: Language model.
        data: List of token ID arrays.
        bytes_per_token: Average bytes per token.

    Returns:
        BPB score (lower is better).
    """
    total_loss = 0.0
    n_batches = 0

    for tokens in data:
        loss = _compute_loss(model, tokens)
        mx.eval(loss)
        total_loss += loss.item()
        n_batches += 1

    avg_loss = total_loss / max(n_batches, 1)
    return avg_loss / (math.log(2) * bytes_per_token)
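Since the average loss equals ln(PPL), the two metrics convert directly: BPB = log2(PPL) / bytes_per_token. A sketch with hypothetical numbers:

```python
import math

# Hypothetical: a model with perplexity 16 at 4 bytes per token.
ppl = 16.0
bytes_per_token = 4.0

avg_loss = math.log(ppl)                          # cross-entropy in nats
bpb = avg_loss / (math.log(2) * bytes_per_token)  # = log2(16) / 4
print(round(bpb, 3))  # 1.0
```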

Code Generation Evaluation

lmxlab.eval.metrics.pass_at_k(n, c, k)

Compute pass@k metric (Chen et al., 2021, arXiv:2107.03374).

Estimates the probability that at least one of k samples passes a given test, given that c of n total samples pass. Uses the unbiased estimator from the Codex paper.

pass@k = 1 - C(n-c, k) / C(n, k)

Parameters:

Name  Type  Description                            Default
n     int   Total number of generated samples.     required
c     int   Number of samples that pass the test.  required
k     int   Number of samples to consider.         required

Returns:

Type   Description
float  pass@k probability in [0, 1].

Example::

# 10 samples generated, 3 pass the test
p1 = pass_at_k(n=10, c=3, k=1)   # ~0.30
p5 = pass_at_k(n=10, c=3, k=5)   # ~0.92
Source code in src/lmxlab/eval/metrics.py
def pass_at_k(
    n: int,
    c: int,
    k: int,
) -> float:
    """Compute pass@k metric (Chen et al., 2021, arXiv:2107.03374).

    Estimates the probability that at least one of k samples
    passes a given test, given that c of n total samples pass.
    Uses the unbiased estimator from the Codex paper.

    pass@k = 1 - C(n-c, k) / C(n, k)

    Args:
        n: Total number of generated samples.
        c: Number of samples that pass the test.
        k: Number of samples to consider.

    Returns:
        pass@k probability in [0, 1].

    Example::

        # 10 samples generated, 3 pass the test
        p1 = pass_at_k(n=10, c=3, k=1)   # ~0.30
        p5 = pass_at_k(n=10, c=3, k=5)   # ~0.92
    """
    if n - c < k:
        return 1.0
    # Use log-space for numerical stability
    # pass@k = 1 - prod((n-c-i)/(n-i) for i in range(k))
    log_prod = 0.0
    for i in range(k):
        log_prod += math.log(n - c - i) - math.log(n - i)
    return 1.0 - math.exp(log_prod)
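The log-space product above is mathematically equivalent to the direct combinatorial form. A standalone check using math.comb (a sketch independent of lmxlab; pass_at_k_direct is a local helper, not a library function):

```python
import math

def pass_at_k_direct(n: int, c: int, k: int) -> float:
    # Direct combinatorial form of the unbiased estimator:
    # pass@k = 1 - C(n-c, k) / C(n, k)
    if n - c < k:
        return 1.0
    return 1.0 - math.comb(n - c, k) / math.comb(n, k)

# n=10 samples, c=3 pass: pass@1 equals c/n
print(round(pass_at_k_direct(10, 3, 1), 3))  # 0.3
# pass@5 = 1 - C(7,5)/C(10,5) = 1 - 21/252
print(round(pass_at_k_direct(10, 3, 5), 3))  # 0.917
```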

lmxlab.eval.metrics.evaluate_pass_at_k(completions, test_fn, k_values=None)

Evaluate pass@k over multiple problems.

For each problem, counts how many of the provided completions pass the test function, then applies the unbiased pass@k estimator.

Parameters:

Name         Type                   Description                                             Default
completions  list[list[str]]        List of problems, each a list of N completion strings.  required
test_fn      Callable[[str], bool]  Function that returns True if a completion is correct.  required
k_values     list[int] | None       Values of k to evaluate. Default: [1, 5, 10].           None

Returns:

Type              Description
dict[str, float]  Dict mapping 'pass@k' to the average score across problems.

Source code in src/lmxlab/eval/metrics.py
def evaluate_pass_at_k(
    completions: list[list[str]],
    test_fn: Callable[[str], bool],
    k_values: list[int] | None = None,
) -> dict[str, float]:
    """Evaluate pass@k over multiple problems.

    For each problem, counts how many of the provided
    completions pass the test function, then applies the
    unbiased pass@k estimator.

    Args:
        completions: List of problems, each a list of N
            completion strings.
        test_fn: Function that returns True if a completion
            is correct.
        k_values: Values of k to evaluate. Default: [1, 5, 10].

    Returns:
        Dict mapping 'pass@k' to the average score across
        problems.
    """
    if k_values is None:
        k_values = [1, 5, 10]

    results: dict[str, list[float]] = {f"pass@{k}": [] for k in k_values}

    for problem_completions in completions:
        n = len(problem_completions)
        c = sum(1 for comp in problem_completions if test_fn(comp))
        for k in k_values:
            if k <= n:
                score = pass_at_k(n, c, k)
                results[f"pass@{k}"].append(score)

    return {
        key: sum(vals) / len(vals) if vals else 0.0
        for key, vals in results.items()
    }