Skip to content

Experiments

Experiment runner, sweeps, tracking, and analysis.

Runner

lmxlab.experiments.runner.ExperimentConfig dataclass

Configuration for an experiment run.

Parameters:

Name Type Description Default
name str

Experiment name/tag.

'experiment'
description str

Human-readable description.

''
time_budget_s float

Maximum wall-clock time in seconds.

300.0
flop_budget float | None

Maximum FLOPs budget (e.g. 1e15 for 1 PFLOPs).

None
seed int

Random seed.

42
output_dir str

Directory for outputs.

'experiments'
Source code in src/lmxlab/experiments/runner.py
@dataclass
class ExperimentConfig:
    """Configuration for an experiment run.

    Args:
        name: Experiment name/tag.
        description: Human-readable description.
        time_budget_s: Maximum wall-clock time in seconds.
        flop_budget: Maximum FLOPs budget (e.g. 1e15 for 1 PFLOPs).
        seed: Random seed.
        output_dir: Directory for outputs.
    """

    name: str = "experiment"
    description: str = ""
    time_budget_s: float = 300.0  # 5 minutes default
    flop_budget: float | None = None  # e.g., 1e15 for 1 PFLOPs
    seed: int = 42
    output_dir: str = "experiments"

description = '' class-attribute instance-attribute

flop_budget = None class-attribute instance-attribute

name = 'experiment' class-attribute instance-attribute

output_dir = 'experiments' class-attribute instance-attribute

seed = 42 class-attribute instance-attribute

time_budget_s = 300.0 class-attribute instance-attribute

__init__(name='experiment', description='', time_budget_s=300.0, flop_budget=None, seed=42, output_dir='experiments')

lmxlab.experiments.runner.ExperimentRunner

Run experiments with autoresearch patterns.

Enforces fixed time budgets, logs results to results.jsonl, and tracks git commits for reproducibility.

Parameters:

Name Type Description Default
config ExperimentConfig

Experiment configuration.

required
log ExperimentLog | None

Experiment log (defaults to results.jsonl in output_dir).

None
Source code in src/lmxlab/experiments/runner.py
class ExperimentRunner:
    """Run experiments with autoresearch patterns.

    Enforces fixed time budgets, logs results to results.jsonl,
    and tracks git commits for reproducibility.

    Args:
        config: Experiment configuration.
        log: Experiment log (defaults to results.jsonl in output_dir).
    """

    def __init__(
        self,
        config: ExperimentConfig,
        log: ExperimentLog | None = None,
    ) -> None:
        self.config = config
        output = Path(config.output_dir)
        output.mkdir(parents=True, exist_ok=True)
        # Default to an append-only results.jsonl inside output_dir.
        self.log = log or ExperimentLog(output / "results.jsonl")
        # 0.0 is a sentinel meaning "start() has not been called yet".
        self._start_time: float = 0.0

    def time_remaining(self) -> float:
        """Seconds remaining in the time budget."""
        if self._start_time == 0:
            # Timer not started: the full budget is still available.
            return self.config.time_budget_s
        elapsed = time.monotonic() - self._start_time
        return max(0.0, self.config.time_budget_s - elapsed)

    def is_time_up(self) -> bool:
        """Check if the time budget has been exceeded."""
        return self.time_remaining() <= 0

    def start(self) -> None:
        """Start the experiment timer and set the random seed."""
        # monotonic() is immune to system clock adjustments.
        self._start_time = time.monotonic()
        mx.random.seed(self.config.seed)

    def finish(
        self,
        metrics: dict[str, Any],
        param_count: int = 0,
        config_dict: dict[str, Any] | None = None,
        status: str = "keep",
    ) -> LogEntry:
        """Finish the experiment and log results.

        Args:
            metrics: Dict of result metrics (must include
                'val_loss' or 'val_bpb').
            param_count: Number of model parameters.
            config_dict: Full experiment config for logging.
            status: 'keep', 'discard', or 'crash'.

        Returns:
            The logged entry.
        """
        # Guard against finish() without start(): subtracting the 0.0
        # sentinel would log the raw monotonic-clock reading (time since
        # an arbitrary epoch) as the experiment's wall time.
        if self._start_time:
            wall_time = time.monotonic() - self._start_time
        else:
            wall_time = 0.0

        entry = LogEntry(
            experiment=self.config.name,
            commit=_get_git_commit(),
            status=status,
            # Missing metrics default to 0.0 rather than raising.
            val_bpb=metrics.get("val_bpb", 0.0),
            val_loss=metrics.get("val_loss", 0.0),
            train_loss=metrics.get("train_loss", 0.0),
            param_count=param_count,
            wall_time_s=wall_time,
            description=self.config.description,
            config=config_dict or {},
            metrics=metrics,
            seed=self.config.seed,
        )
        self.log.log(entry)
        return entry

_start_time = 0.0 instance-attribute

config = config instance-attribute

log = log or ExperimentLog(output / 'results.jsonl') instance-attribute

__init__(config, log=None)

Source code in src/lmxlab/experiments/runner.py
def __init__(
    self,
    config: ExperimentConfig,
    log: ExperimentLog | None = None,
) -> None:
    """Create the output directory and attach a result log."""
    self.config = config
    output = Path(config.output_dir)
    output.mkdir(parents=True, exist_ok=True)
    # Default to an append-only results.jsonl inside output_dir.
    self.log = log or ExperimentLog(output / "results.jsonl")
    # 0.0 is a sentinel meaning start() has not been called yet.
    self._start_time: float = 0.0

finish(metrics, param_count=0, config_dict=None, status='keep')

Finish the experiment and log results.

Parameters:

Name Type Description Default
metrics dict[str, Any]

Dict of result metrics (must include 'val_loss' or 'val_bpb').

required
param_count int

Number of model parameters.

0
config_dict dict[str, Any] | None

Full experiment config for logging.

None
status str

'keep', 'discard', or 'crash'.

'keep'

Returns:

Type Description
LogEntry

The logged entry.

Source code in src/lmxlab/experiments/runner.py
def finish(
    self,
    metrics: dict[str, Any],
    param_count: int = 0,
    config_dict: dict[str, Any] | None = None,
    status: str = "keep",
) -> LogEntry:
    """Finish the experiment and log results.

    Args:
        metrics: Dict of result metrics (must include
            'val_loss' or 'val_bpb').
        param_count: Number of model parameters.
        config_dict: Full experiment config for logging.
        status: 'keep', 'discard', or 'crash'.

    Returns:
        The logged entry.
    """
    # NOTE(review): assumes start() was called first; otherwise
    # _start_time is 0.0 and this records the raw monotonic-clock
    # reading as wall time — confirm callers always pair start/finish.
    wall_time = time.monotonic() - self._start_time

    entry = LogEntry(
        experiment=self.config.name,
        commit=_get_git_commit(),
        status=status,
        # Missing metrics default to 0.0 rather than raising.
        val_bpb=metrics.get("val_bpb", 0.0),
        val_loss=metrics.get("val_loss", 0.0),
        train_loss=metrics.get("train_loss", 0.0),
        param_count=param_count,
        wall_time_s=wall_time,
        description=self.config.description,
        config=config_dict or {},
        metrics=metrics,
        seed=self.config.seed,
    )
    self.log.log(entry)
    return entry

is_time_up()

Check if the time budget has been exceeded.

Source code in src/lmxlab/experiments/runner.py
def is_time_up(self) -> bool:
    """Check if the time budget has been exceeded."""
    # time_remaining() is clamped at 0.0, so <= 0 means "budget spent".
    return self.time_remaining() <= 0

start()

Start the experiment timer and set the random seed.

Source code in src/lmxlab/experiments/runner.py
def start(self) -> None:
    """Start the experiment timer and set the random seed."""
    # monotonic() is immune to system clock adjustments.
    self._start_time = time.monotonic()
    mx.random.seed(self.config.seed)

time_remaining()

Seconds remaining in the time budget.

Source code in src/lmxlab/experiments/runner.py
def time_remaining(self) -> float:
    """Seconds remaining in the time budget."""
    if self._start_time == 0:
        # Timer not started yet: the full budget is still available.
        return self.config.time_budget_s
    elapsed = time.monotonic() - self._start_time
    # Clamp at zero so callers never see a negative remainder.
    return max(0.0, self.config.time_budget_s - elapsed)

Sweep

lmxlab.experiments.sweep

Hyperparameter sweep utilities.

grid_sweep(param_grid)

Generate all combinations from a parameter grid.

Parameters:

Name Type Description Default
param_grid dict[str, list[Any]]

Dict mapping parameter names to lists of values to try.

required

Yields:

Type Description
dict[str, Any]

Dicts with one value per parameter.

Example

>>> list(grid_sweep({'lr': [1e-3, 1e-4], 'layers': [2, 4]}))
[{'lr': 0.001, 'layers': 2}, {'lr': 0.001, 'layers': 4},
 {'lr': 0.0001, 'layers': 2}, {'lr': 0.0001, 'layers': 4}]

Source code in src/lmxlab/experiments/sweep.py
def grid_sweep(
    param_grid: dict[str, list[Any]],
) -> Iterator[dict[str, Any]]:
    """Generate all combinations from a parameter grid.

    Args:
        param_grid: Dict mapping parameter names to lists
            of values to try.

    Yields:
        Dicts with one value per parameter.

    Example:
        >>> list(grid_sweep({'lr': [1e-3, 1e-4], 'layers': [2, 4]}))
        [{'lr': 0.001, 'layers': 2}, {'lr': 0.001, 'layers': 4},
         {'lr': 0.0001, 'layers': 2}, {'lr': 0.0001, 'layers': 4}]
    """
    keys = list(param_grid.keys())
    values = list(param_grid.values())
    for combo in itertools.product(*values):
        yield dict(zip(keys, combo, strict=True))

random_sweep(param_ranges, n_trials=10, seed=42, log_scale=None)

Generate random parameter combinations.

Samples uniformly from continuous ranges by default. Parameters listed in log_scale are sampled in log-space, which is standard for learning rates and other parameters spanning multiple orders of magnitude.

Uses Python's random module (no MLX dependency), so sweep configuration can be computed without Apple Silicon hardware.

Parameters:

Name Type Description Default
param_ranges dict[str, tuple[float, float]]

Dict mapping parameter names to (min, max) tuples.

required
n_trials int

Number of random combinations.

10
seed int

Random seed for reproducibility.

42
log_scale set[str] | None

Set of parameter names to sample in log-space. For these, (min, max) must both be positive.

None

Yields:

Type Description
dict[str, float]

Dicts with one sampled value per parameter.

Example

>>> configs = list(random_sweep(
...     param_ranges={"lr": (1e-5, 1e-1), "d_model": (64, 512)},
...     n_trials=5,
...     log_scale={"lr"},
... ))

Source code in src/lmxlab/experiments/sweep.py
def random_sweep(
    param_ranges: dict[str, tuple[float, float]],
    n_trials: int = 10,
    seed: int = 42,
    log_scale: set[str] | None = None,
) -> Iterator[dict[str, float]]:
    """Generate random parameter combinations.

    Samples uniformly from continuous ranges by default.
    Parameters listed in ``log_scale`` are sampled in
    log-space, which is standard for learning rates and
    other parameters spanning multiple orders of magnitude.

    Uses Python's ``random`` module (no MLX dependency),
    so sweep configuration can be computed without Apple
    Silicon hardware.

    Args:
        param_ranges: Dict mapping parameter names to
            (min, max) tuples.
        n_trials: Number of random combinations.
        seed: Random seed for reproducibility.
        log_scale: Set of parameter names to sample in
            log-space. For these, (min, max) must both
            be positive.

    Yields:
        Dicts with one sampled value per parameter.

    Example:
        >>> configs = list(random_sweep(
        ...     param_ranges={"lr": (1e-5, 1e-1), "d_model": (64, 512)},
        ...     n_trials=5,
        ...     log_scale={"lr"},
        ... ))
    """
    log_params = log_scale or set()
    rng = random.Random(seed)
    keys = list(param_ranges.keys())
    ranges = list(param_ranges.values())

    for _ in range(n_trials):
        config = {}
        for key, (lo, hi) in zip(keys, ranges, strict=True):
            if key in log_params:
                log_lo = math.log(lo)
                log_hi = math.log(hi)
                config[key] = math.exp(rng.uniform(log_lo, log_hi))
            else:
                config[key] = rng.uniform(lo, hi)
        yield config

Tracking

lmxlab.experiments.tracking.LogEntry dataclass

A single experiment result entry.

Parameters:

Name Type Description Default
experiment str

Experiment name/tag.

''
commit str

Git commit hash.

''
status str

Outcome ('keep', 'discard', 'crash').

'keep'
val_bpb float

Validation bits-per-byte.

0.0
val_loss float

Validation loss.

0.0
train_loss float

Final training loss.

0.0
param_count int

Number of model parameters.

0
total_flops float

Total FLOPs consumed during training.

0.0
peak_memory_mb float

Peak memory usage in MB.

0.0
wall_time_s float

Wall clock time in seconds.

0.0
description str

Human-readable description.

''
config dict[str, Any]

Full experiment config dict.

dict()
metrics dict[str, Any]

Additional metrics dict.

dict()
timestamp float

Unix timestamp (auto-filled).

time()
seed int

Random seed used.

42
Source code in src/lmxlab/experiments/tracking.py
@dataclass
class LogEntry:
    """One row of the experiment results log.

    Args:
        experiment: Experiment name/tag.
        commit: Git commit hash.
        status: Outcome ('keep', 'discard', 'crash').
        val_bpb: Validation bits-per-byte.
        val_loss: Validation loss.
        train_loss: Final training loss.
        param_count: Number of model parameters.
        total_flops: Total FLOPs consumed during training.
        peak_memory_mb: Peak memory usage in MB.
        wall_time_s: Wall clock time in seconds.
        description: Human-readable description.
        config: Full experiment config dict.
        metrics: Additional metrics dict.
        timestamp: Unix timestamp (auto-filled).
        seed: Random seed used.
    """

    experiment: str = ""
    commit: str = ""
    status: str = "keep"
    val_bpb: float = 0.0
    val_loss: float = 0.0
    train_loss: float = 0.0
    param_count: int = 0
    total_flops: float = 0.0
    peak_memory_mb: float = 0.0
    wall_time_s: float = 0.0
    description: str = ""
    # default_factory keeps each entry's dicts independent.
    config: dict[str, Any] = field(default_factory=dict)
    metrics: dict[str, Any] = field(default_factory=dict)
    # Captured at instantiation time, not at class-definition time.
    timestamp: float = field(default_factory=time.time)
    seed: int = 42

commit = '' class-attribute instance-attribute

config = field(default_factory=dict) class-attribute instance-attribute

description = '' class-attribute instance-attribute

experiment = '' class-attribute instance-attribute

metrics = field(default_factory=dict) class-attribute instance-attribute

param_count = 0 class-attribute instance-attribute

peak_memory_mb = 0.0 class-attribute instance-attribute

seed = 42 class-attribute instance-attribute

status = 'keep' class-attribute instance-attribute

timestamp = field(default_factory=(time.time)) class-attribute instance-attribute

total_flops = 0.0 class-attribute instance-attribute

train_loss = 0.0 class-attribute instance-attribute

val_bpb = 0.0 class-attribute instance-attribute

val_loss = 0.0 class-attribute instance-attribute

wall_time_s = 0.0 class-attribute instance-attribute

__init__(experiment='', commit='', status='keep', val_bpb=0.0, val_loss=0.0, train_loss=0.0, param_count=0, total_flops=0.0, peak_memory_mb=0.0, wall_time_s=0.0, description='', config=dict(), metrics=dict(), timestamp=time.time(), seed=42)

lmxlab.experiments.tracking.ExperimentLog

Append-only experiment log backed by results.jsonl.

This is the ground truth for all experiments. Zero dependencies, git-trackable, easy for agents to parse.

Parameters:

Name Type Description Default
path str | Path

Path to results.jsonl file.

'results.jsonl'
Source code in src/lmxlab/experiments/tracking.py
class ExperimentLog:
    """Append-only experiment log backed by results.jsonl.

    This is the ground truth for all experiments. Zero
    dependencies, git-trackable, easy for agents to parse.

    Args:
        path: Path to results.jsonl file.
    """

    def __init__(self, path: str | Path = "results.jsonl") -> None:
        # Stored as-is; the file is created lazily on first log().
        self.path = Path(path)

    def log(self, entry: LogEntry) -> None:
        """Append an entry to the log.

        Args:
            entry: Experiment result to log.
        """
        self.path.parent.mkdir(parents=True, exist_ok=True)
        # Explicit utf-8 so the log round-trips across platforms
        # regardless of the locale's default encoding.
        with open(self.path, "a", encoding="utf-8") as f:
            f.write(json.dumps(asdict(entry)) + "\n")

    def load(self) -> list[LogEntry]:
        """Load all entries from the log.

        Unknown JSON keys are ignored, so entries written by a
        newer LogEntry schema (with extra fields) still load
        instead of raising TypeError.

        Returns:
            List of LogEntry objects.
        """
        if not self.path.exists():
            return []
        # Resolve the accepted field names once, outside the line loop.
        known = set(LogEntry.__dataclass_fields__)
        entries = []
        with open(self.path, encoding="utf-8") as f:
            for line in f:
                line = line.strip()
                # Blank lines are tolerated (e.g. from hand edits).
                if line:
                    data = json.loads(line)
                    entries.append(
                        LogEntry(
                            **{k: v for k, v in data.items() if k in known}
                        )
                    )
        return entries

    def best(
        self,
        metric: str = "val_bpb",
        lower_is_better: bool = True,
    ) -> LogEntry | None:
        """Find the best entry by a metric.

        Args:
            metric: Name of the metric field.
            lower_is_better: If True, minimize; else maximize.

        Returns:
            Best LogEntry, or None if log is empty.
        """
        # Only 'keep' entries compete; discarded/crashed runs are ignored.
        entries = [e for e in self.load() if e.status == "keep"]
        if not entries:
            return None
        # Sign flip turns min() into max() when higher is better.
        return min(
            entries,
            key=lambda e: getattr(e, metric) * (1 if lower_is_better else -1),
        )

    def summary(self) -> dict[str, Any]:
        """Get summary statistics of all experiments.

        Returns:
            Dict with counts, best metrics, etc.
        """
        entries = self.load()
        if not entries:
            return {"total": 0}
        kept = [e for e in entries if e.status == "keep"]
        return {
            "total": len(entries),
            "kept": len(kept),
            "discarded": sum(1 for e in entries if e.status == "discard"),
            "crashed": sum(1 for e in entries if e.status == "crash"),
            # inf when nothing was kept, so callers can detect "no result".
            "best_val_bpb": min(
                (e.val_bpb for e in kept), default=float("inf")
            ),
        }

path = Path(path) instance-attribute

__init__(path='results.jsonl')

Source code in src/lmxlab/experiments/tracking.py
def __init__(self, path: str | Path = "results.jsonl") -> None:
    # Stored as-is; the file is created lazily on first log().
    self.path = Path(path)

best(metric='val_bpb', lower_is_better=True)

Find the best entry by a metric.

Parameters:

Name Type Description Default
metric str

Name of the metric field.

'val_bpb'
lower_is_better bool

If True, minimize; else maximize.

True

Returns:

Type Description
LogEntry | None

Best LogEntry, or None if log is empty.

Source code in src/lmxlab/experiments/tracking.py
def best(
    self,
    metric: str = "val_bpb",
    lower_is_better: bool = True,
) -> LogEntry | None:
    """Find the best entry by a metric.

    Args:
        metric: Name of the metric field.
        lower_is_better: If True, minimize; else maximize.

    Returns:
        Best LogEntry, or None if log is empty.
    """
    # Only 'keep' entries compete; discarded/crashed runs are ignored.
    entries = [e for e in self.load() if e.status == "keep"]
    if not entries:
        return None
    # Sign flip turns min() into max() when higher is better.
    return min(
        entries,
        key=lambda e: getattr(e, metric) * (1 if lower_is_better else -1),
    )

load()

Load all entries from the log.

Returns:

Type Description
list[LogEntry]

List of LogEntry objects.

Source code in src/lmxlab/experiments/tracking.py
def load(self) -> list[LogEntry]:
    """Load all entries from the log.

    Returns:
        List of LogEntry objects.
    """
    # A missing file simply means nothing has been logged yet.
    if not self.path.exists():
        return []
    entries = []
    with open(self.path) as f:
        for line in f:
            line = line.strip()
            # Blank lines are tolerated (e.g. from hand edits).
            if line:
                data = json.loads(line)
                entries.append(LogEntry(**data))
    return entries

log(entry)

Append an entry to the log.

Parameters:

Name Type Description Default
entry LogEntry

Experiment result to log.

required
Source code in src/lmxlab/experiments/tracking.py
def log(self, entry: LogEntry) -> None:
    """Append an entry to the log.

    Args:
        entry: Experiment result to log.
    """
    # Create the parent directory lazily so the log can live anywhere.
    self.path.parent.mkdir(parents=True, exist_ok=True)
    # Append-only JSONL: one JSON object per line, never rewritten.
    with open(self.path, "a") as f:
        f.write(json.dumps(asdict(entry)) + "\n")

summary()

Get summary statistics of all experiments.

Returns:

Type Description
dict[str, Any]

Dict with counts, best metrics, etc.

Source code in src/lmxlab/experiments/tracking.py
def summary(self) -> dict[str, Any]:
    """Get summary statistics of all experiments.

    Returns:
        Dict with counts, best metrics, etc.
    """
    entries = self.load()
    if not entries:
        return {"total": 0}
    kept = [e for e in entries if e.status == "keep"]
    return {
        "total": len(entries),
        "kept": len(kept),
        "discarded": sum(1 for e in entries if e.status == "discard"),
        "crashed": sum(1 for e in entries if e.status == "crash"),
        # inf when nothing was kept, so callers can detect "no result".
        "best_val_bpb": min(
            (e.val_bpb for e in kept), default=float("inf")
        ),
    }

Analysis

lmxlab.experiments.analysis

Analysis utilities for experiment results.

cohens_d(group_a, group_b)

Compute Cohen's d effect size between two groups.

Uses pooled standard deviation (equal-variance assumption). Useful for reporting effect sizes alongside p-values, as recommended in the pre-registered experiment plans.

Parameters:

Name Type Description Default
group_a list[float]

Values from the first group.

required
group_b list[float]

Values from the second group.

required

Returns:

Name Type Description
float

Cohen's d. Positive means group_a > group_b.

Conventions: |d| < 0.2 small, 0.5 medium, 0.8 large.

Source code in src/lmxlab/experiments/analysis.py
def cohens_d(
    group_a: list[float],
    group_b: list[float],
) -> float:
    """Compute Cohen's d effect size between two groups.

    Uses pooled standard deviation (equal-variance assumption).
    Useful for reporting effect sizes alongside p-values, as
    recommended in the pre-registered experiment plans.

    Args:
        group_a: Values from the first group.
        group_b: Values from the second group.

    Returns:
        Cohen's d. Positive means group_a > group_b.
        Conventions: |d| < 0.2 small, 0.5 medium, 0.8 large.
    """

    def _mean_and_var(vals: list[float]) -> tuple[float, float]:
        # Sample mean and (n-1)-denominator sample variance.
        count = len(vals)
        center = sum(vals) / count
        return center, sum((v - center) ** 2 for v in vals) / (count - 1)

    size_a, size_b = len(group_a), len(group_b)
    # Variance is undefined for fewer than two samples per group.
    if size_a < 2 or size_b < 2:
        return 0.0

    mean_a, var_a = _mean_and_var(group_a)
    mean_b, var_b = _mean_and_var(group_b)

    pooled_var = (
        (size_a - 1) * var_a + (size_b - 1) * var_b
    ) / (size_a + size_b - 2)
    pooled_std = math.sqrt(pooled_var)

    # Constant groups have zero spread: report no effect.
    return 0.0 if pooled_std == 0 else (mean_a - mean_b) / pooled_std

compare_experiments(log, metric='val_bpb')

Compare all kept experiments by a metric.

Returns experiments sorted by the metric (ascending).

Parameters:

Name Type Description Default
log ExperimentLog

Experiment log to analyze.

required
metric str

Metric name to compare.

'val_bpb'

Returns:

Type Description
list[dict[str, Any]]

List of dicts with experiment name, metric value,

list[dict[str, Any]]

param_count, and wall_time.

Source code in src/lmxlab/experiments/analysis.py
def compare_experiments(
    log: ExperimentLog,
    metric: str = "val_bpb",
) -> list[dict[str, Any]]:
    """Compare all kept experiments by a metric.

    Returns experiments sorted by the metric (ascending).

    Args:
        log: Experiment log to analyze.
        metric: Metric name to compare.

    Returns:
        List of dicts with experiment name, metric value,
        param_count, and wall_time.
    """
    # Only 'keep' runs are comparable; a missing metric sorts last via inf.
    kept = sorted(
        (e for e in log.load() if e.status == "keep"),
        key=lambda e: getattr(e, metric, float("inf")),
    )
    rows: list[dict[str, Any]] = []
    for e in kept:
        rows.append(
            {
                "experiment": e.experiment,
                metric: getattr(e, metric),
                "param_count": e.param_count,
                "wall_time_s": e.wall_time_s,
                "description": e.description,
            }
        )
    return rows

compute_statistics(values)

Compute basic statistics for a list of values.

Parameters:

Name Type Description Default
values list[float]

List of numeric values.

required

Returns:

Type Description
dict[str, float]

Dict with mean, std, min, max, n.

Source code in src/lmxlab/experiments/analysis.py
def compute_statistics(
    values: list[float],
) -> dict[str, float]:
    """Compute basic statistics for a list of values.

    Args:
        values: List of numeric values.

    Returns:
        Dict with mean, std, min, max, n.
    """
    count = len(values)
    if count == 0:
        # Empty input: all-zero stats rather than a ZeroDivisionError.
        return {
            "mean": 0.0,
            "std": 0.0,
            "min": 0.0,
            "max": 0.0,
            "n": 0,
        }

    avg = sum(values) / count
    # Sample variance; max(count-1, 1) makes a single value yield std 0.
    sum_sq = sum((v - avg) ** 2 for v in values)
    sample_var = sum_sq / max(count - 1, 1)
    return {
        "mean": avg,
        "std": math.sqrt(sample_var),
        "min": min(values),
        "max": max(values),
        "n": count,
    }

confidence_interval(values, confidence=0.95)

Compute a confidence interval for the mean.

Uses the t-distribution for small samples. Falls back to z-approximation for n >= 30.

Parameters:

Name Type Description Default
values list[float]

Sample values.

required
confidence float

Confidence level (default 0.95).

0.95

Returns:

Type Description
tuple[float, float]

(lower, upper) bounds of the confidence interval.

Source code in src/lmxlab/experiments/analysis.py
def confidence_interval(
    values: list[float],
    confidence: float = 0.95,
) -> tuple[float, float]:
    """Compute a confidence interval for the mean.

    Uses the t-distribution for small samples. Falls back to
    z-approximation for n >= 30.

    Args:
        values: Sample values.
        confidence: Confidence level (default 0.95).

    Returns:
        (lower, upper) bounds of the confidence interval.
    """
    count = len(values)
    if count < 2:
        # Degenerate interval: a single value (or 0.0 when empty).
        center = values[0] if values else 0.0
        return (center, center)

    center = sum(values) / count
    sample_var = sum((v - center) ** 2 for v in values) / (count - 1)
    sem = math.sqrt(sample_var / count)

    # z critical values for common confidence levels; 0.95 is the fallback.
    crit_table = {0.90: 1.645, 0.95: 1.960, 0.99: 2.576}
    crit = crit_table.get(confidence, 1.960)
    if count < 30:
        # Crude small-sample t correction (Abramowitz & Stegun style):
        # t ≈ z + (z + z^3) / (4 * df). Educational use; production
        # code should use scipy.
        dof = count - 1
        crit = crit + (crit + crit**3) / (4 * dof)

    half_width = crit * sem
    return (center - half_width, center + half_width)

simplicity_score(entry, baseline_params, baseline_metric, metric='val_bpb')

Score an experiment by the simplicity bias principle.

Rewards improvements that use fewer parameters. Score = metric_improvement * (baseline_params / param_count)

Higher is better. Positive means improvement over baseline.

Parameters:

Name Type Description Default
entry LogEntry

Experiment entry to score.

required
baseline_params int

Baseline parameter count.

required
baseline_metric float

Baseline metric value.

required
metric str

Metric name (lower is better).

'val_bpb'

Returns:

Type Description
float

Simplicity-weighted improvement score.

Source code in src/lmxlab/experiments/analysis.py
def simplicity_score(
    entry: LogEntry,
    baseline_params: int,
    baseline_metric: float,
    metric: str = "val_bpb",
) -> float:
    """Score an experiment by the simplicity bias principle.

    Rewards improvements that use fewer parameters.
    Score = metric_improvement * (baseline_params / param_count)

    Higher is better. Positive means improvement over baseline.

    Args:
        entry: Experiment entry to score.
        baseline_params: Baseline parameter count.
        baseline_metric: Baseline metric value.
        metric: Metric name (lower is better).

    Returns:
        Simplicity-weighted improvement score.
    """
    observed: float = getattr(entry, metric)
    # Lower-is-better metric, so positive delta means better than baseline.
    delta = baseline_metric - observed
    # max(..., 1) guards against division by a zero param_count.
    weight = baseline_params / max(entry.param_count, 1)
    return delta * weight

Profiling

lmxlab.experiments.profiling.benchmark_fn(fn, n_warmup=3, n_iter=10)

Time a function with warmup iterations.

Runs the function n_warmup times (discarded), then n_iter times (timed). Returns timing statistics.

Parameters:

Name Type Description Default
fn Callable[[], Any]

Callable to benchmark (should include mx.eval).

required
n_warmup int

Number of warmup iterations.

3
n_iter int

Number of timed iterations.

10

Returns:

Type Description
dict[str, float]

Dict with mean_ms, std_ms, min_ms, max_ms, n_iter.

Source code in src/lmxlab/experiments/profiling.py
def benchmark_fn(
    fn: Callable[[], Any],
    n_warmup: int = 3,
    n_iter: int = 10,
) -> dict[str, float]:
    """Time a function with warmup iterations.

    Runs the function n_warmup times (discarded), then n_iter
    times (timed). Returns timing statistics.

    Args:
        fn: Callable to benchmark (should include mx.eval).
        n_warmup: Number of warmup iterations.
        n_iter: Number of timed iterations.

    Returns:
        Dict with mean_ms, std_ms, min_ms, max_ms, n_iter.
    """
    # Warmup runs are discarded so one-time costs don't skew the stats.
    for _ in range(n_warmup):
        fn()

    samples_ms = []
    for _ in range(n_iter):
        start = time.perf_counter()
        fn()
        # Convert seconds to milliseconds as we record.
        samples_ms.append((time.perf_counter() - start) * 1000)

    mean_ms = sum(samples_ms) / len(samples_ms)
    spread = sum((s - mean_ms) ** 2 for s in samples_ms)
    sample_var = spread / max(len(samples_ms) - 1, 1)

    return {
        "mean_ms": mean_ms,
        "std_ms": math.sqrt(sample_var),
        "min_ms": min(samples_ms),
        "max_ms": max(samples_ms),
        "n_iter": n_iter,
    }

lmxlab.experiments.profiling.memory_estimate(model)

Estimate model memory usage from parameter shapes and dtypes.

This is a static estimate based on parameter tensors. Actual memory usage during inference includes activations, KV cache, and MLX graph overhead.

Parameters:

Name Type Description Default
model Module

Model to estimate.

required

Returns:

Type Description
dict[str, Any]

Dict with total_bytes, total_mb, param_count,

dict[str, Any]

and per-dtype breakdown.

Source code in src/lmxlab/experiments/profiling.py
def memory_estimate(model: nn.Module) -> dict[str, Any]:  # type: ignore[name-defined]
    """Estimate model memory usage from parameter shapes and dtypes.

    This is a static estimate over parameter tensors only; actual
    inference additionally pays for activations, KV cache, and
    MLX graph overhead.

    Args:
        model: Model to estimate.

    Returns:
        Dict with total_bytes, total_mb, param_count,
        and per-dtype breakdown.
    """
    leaves = mlx.utils.tree_flatten(model.parameters())

    total_bytes = 0
    param_count = 0
    by_dtype: dict[str, int] = {}
    for _name, param in leaves:  # type: ignore[misc]
        total_bytes += param.nbytes
        param_count += param.size
        # Accumulate a per-dtype byte breakdown alongside the totals.
        key = str(param.dtype)
        by_dtype[key] = by_dtype.get(key, 0) + param.nbytes

    return {
        "total_bytes": total_bytes,
        "total_mb": total_bytes / (1024 * 1024),
        "param_count": param_count,
        "by_dtype": by_dtype,
    }

lmxlab.experiments.profiling.count_parameters_by_module(model)

Count parameters per top-level submodule.

Returns a dict mapping module names to their parameter counts, useful for understanding where parameters are concentrated.

Parameters:

Name Type Description Default
model Module

Model to analyze.

required

Returns:

Type Description
dict[str, int]

Dict mapping module name to parameter count.

Source code in src/lmxlab/experiments/profiling.py
def count_parameters_by_module(
    model: nn.Module,  # type: ignore[name-defined]
) -> dict[str, int]:
    """Count parameters per top-level submodule.

    Returns a dict mapping module names to their parameter
    counts, useful for understanding where parameters are
    concentrated.

    Args:
        model: Model to analyze.

    Returns:
        Dict mapping module name to parameter count.
    """
    counts: dict[str, int] = {}
    for name, child in model.children().items():
        # Sum the sizes of every parameter tensor under this child.
        n_params = sum(
            p.size for _, p in mlx.utils.tree_flatten(child)  # type: ignore[misc]
        )
        # Skip parameter-free submodules so the result stays compact.
        if n_params:
            counts[name] = n_params
    return counts

lmxlab.experiments.profiling.profile_forward(model, tokens, n_warmup=2, n_iter=5)

Profile forward pass throughput.

Times the model's forward pass and computes tokens/second.

Parameters:

Name Type Description Default
model Module

Language model to profile.

required
tokens array

Input token IDs (batch, seq_len).

required
n_warmup int

Warmup iterations.

2
n_iter int

Timed iterations.

5

Returns:

Type Description
dict[str, Any]

Dict with timing stats, tokens_per_sec, batch_size, seq_len.

Source code in src/lmxlab/experiments/profiling.py
def profile_forward(
    model: nn.Module,  # type: ignore[name-defined]
    tokens: mx.array,
    n_warmup: int = 2,
    n_iter: int = 5,
) -> dict[str, Any]:
    """Profile forward pass throughput.

    Times the model's forward pass and computes tokens/second.

    Args:
        model: Language model to profile.
        tokens: Input token IDs (batch, seq_len).
        n_warmup: Warmup iterations.
        n_iter: Timed iterations.

    Returns:
        Dict with timing stats, tokens_per_sec, batch_size,
        seq_len.
    """
    batch_size, seq_len = tokens.shape

    def forward_once() -> None:
        # mx.eval forces the lazy graph so the timing reflects
        # actual compute, not graph construction.
        out, _ = model(tokens)
        mx.eval(out)

    stats = benchmark_fn(forward_once, n_warmup=n_warmup, n_iter=n_iter)

    n_tokens = batch_size * seq_len
    mean_ms = stats["mean_ms"]
    if mean_ms > 0:
        tokens_per_sec = n_tokens / (mean_ms / 1000)
    else:
        # Guard against division by zero on degenerate timings.
        tokens_per_sec = 0

    result: dict[str, Any] = dict(stats)
    result["tokens_per_sec"] = tokens_per_sec
    result["batch_size"] = batch_size
    result["seq_len"] = seq_len
    return result

lmxlab.experiments.profiling.profile_generation(model, prompt, max_tokens=50)

Profile autoregressive generation throughput.

Measures time-to-first-token (prompt processing) and per-token generation speed.

Parameters:

Name Type Description Default
model Module

Language model.

required
prompt array

Prompt token IDs (1, prompt_len).

required
max_tokens int

Number of tokens to generate.

50

Returns:

Type Description
dict[str, Any]

Dict with prefill_ms, decode_ms_per_token, total_ms, tokens_generated.

Source code in src/lmxlab/experiments/profiling.py
def profile_generation(
    model: nn.Module,  # type: ignore[name-defined]
    prompt: mx.array,
    max_tokens: int = 50,
) -> dict[str, Any]:
    """Profile autoregressive generation throughput.

    Measures time-to-first-token (prompt processing) and
    per-token generation speed.

    Args:
        model: Language model.
        prompt: Prompt token IDs (1, prompt_len).
        max_tokens: Number of tokens to generate.

    Returns:
        Dict with prefill_ms, decode_ms_per_token,
        total_ms, tokens_generated.
    """

    def greedy_step(logits: mx.array) -> mx.array:
        # Greedy decoding: argmax over the vocab at the last position.
        return mx.argmax(logits[:, -1, :], axis=-1, keepdims=True)

    def sync(token: mx.array, cache: Any) -> None:
        # Force evaluation of the token and all cached tensors so the
        # wall-clock timings measure real compute, not lazy-graph build.
        mx.eval(token, *(c for pair in cache for c in pair))

    # Prefill: process the full prompt in a single pass.
    start = time.perf_counter()
    logits, cache = model(prompt)
    next_token = greedy_step(logits)
    sync(next_token, cache)
    prefill_ms = (time.perf_counter() - start) * 1000

    # Decode: generate the remaining tokens one at a time.
    n_decoded = 0
    start = time.perf_counter()
    for _ in range(max_tokens - 1):
        logits, cache = model(next_token, cache=cache)
        next_token = greedy_step(logits)
        sync(next_token, cache)
        n_decoded += 1
    decode_ms = (time.perf_counter() - start) * 1000

    # max(..., 1) guards against max_tokens <= 1 (no decode steps).
    decode_per_token = decode_ms / max(n_decoded, 1)

    return {
        "prefill_ms": prefill_ms,
        "decode_ms_per_token": decode_per_token,
        "total_ms": prefill_ms + decode_ms,
        "tokens_generated": n_decoded + 1,  # include the prefill token
        "prompt_len": prompt.shape[1],
        "decode_tokens_per_sec": (
            1000 / decode_per_token if decode_per_token > 0 else 0
        ),
    }