Data

Tokenizers, datasets, and batching utilities for feeding data to models.

Overview

The data pipeline follows a simple flow:

raw text → Tokenizer → token IDs → Dataset → batch_iterator → (x, y) batches

Three tokenizer implementations are provided:

  • CharTokenizer: character-level tokenization (good for learning, no dependencies)
  • TiktokenTokenizer: OpenAI's BPE tokenizer (GPT-2/GPT-4 vocabularies)
  • HFTokenizer: wraps any HuggingFace AutoTokenizer

Usage

import mlx.core as mx
from lmxlab.data import CharTokenizer, TextDataset, batch_iterator

# Character-level tokenizer (build the vocabulary from text that
# covers every character you intend to encode)
tok = CharTokenizer("the quick brown fox jumps over the lazy dog")
ids = tok.encode("the fox")
print(tok.decode(ids))  # "the fox"

# Create dataset with next-token prediction targets
ds = TextDataset("the quick brown fox jumps over the lazy dog", tok, seq_len=16)
x, y = ds[0]  # x = input tokens, y = shifted targets

# Batch iterator for training
tokens = mx.array(tok.encode("the quick brown fox " * 100), dtype=mx.int32)
for x_batch, y_batch in batch_iterator(tokens, batch_size=4, seq_len=32):
    # x_batch.shape == (4, 32), y_batch.shape == (4, 32)
    pass

Tokenizer

lmxlab.data.tokenizer

Tokenizer protocol and implementations.

CharTokenizer

Character-level tokenizer.

Simple tokenizer that maps each unique character to an ID. Useful for testing and small-scale experiments.

Can be initialized with text directly, or created with default ASCII printable characters (no args). Use fit() to rebuild the vocabulary from new text.

Parameters:

Name Type Description Default
text str | None

Text to build vocabulary from. If None, uses ASCII printable characters (32-126).

None
Source code in src/lmxlab/data/tokenizer.py
class CharTokenizer:
    """Character-level tokenizer.

    Simple tokenizer that maps each unique character to an ID.
    Useful for testing and small-scale experiments.

    Can be initialized with text directly, or created with
    default ASCII printable characters (no args). Use
    ``fit()`` to rebuild the vocabulary from new text.

    Args:
        text: Text to build vocabulary from. If None, uses
            ASCII printable characters (32-126).
    """

    def __init__(self, text: str | None = None) -> None:
        if text is not None:
            chars = sorted(set(text))
        else:
            chars = [chr(i) for i in range(32, 127)]
        self._char_to_id: dict[str, int] = {c: i for i, c in enumerate(chars)}
        self._id_to_char: dict[int, str] = {
            i: c for c, i in self._char_to_id.items()
        }

    def fit(self, text: str) -> None:
        """Build vocabulary from text.

        Args:
            text: Text to extract characters from.
        """
        chars = sorted(set(text))
        self._char_to_id = {c: i for i, c in enumerate(chars)}
        self._id_to_char = {i: c for c, i in self._char_to_id.items()}

    @property
    def vocab_size(self) -> int:
        """Size of the vocabulary."""
        return len(self._char_to_id)

    def encode(self, text: str) -> list[int]:
        """Encode text to character-level token IDs.

        Args:
            text: Input string.

        Returns:
            List of token IDs.

        Raises:
            KeyError: If text contains unknown characters.
        """
        return [self._char_to_id[c] for c in text]

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs back to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded string.
        """
        return "".join(self._id_to_char[t] for t in tokens)

vocab_size property

Size of the vocabulary.

decode(tokens)

Decode token IDs back to text.

Parameters:

Name Type Description Default
tokens list[int]

List of token IDs.

required

Returns:

Type Description
str

Decoded string.

Source code in src/lmxlab/data/tokenizer.py
def decode(self, tokens: list[int]) -> str:
    """Decode token IDs back to text.

    Args:
        tokens: List of token IDs.

    Returns:
        Decoded string.
    """
    return "".join(self._id_to_char[t] for t in tokens)

encode(text)

Encode text to character-level token IDs.

Parameters:

Name Type Description Default
text str

Input string.

required

Returns:

Type Description
list[int]

List of token IDs.

Raises:

Type Description
KeyError

If text contains unknown characters.

Source code in src/lmxlab/data/tokenizer.py
def encode(self, text: str) -> list[int]:
    """Encode text to character-level token IDs.

    Args:
        text: Input string.

    Returns:
        List of token IDs.

    Raises:
        KeyError: If text contains unknown characters.
    """
    return [self._char_to_id[c] for c in text]

fit(text)

Build vocabulary from text.

Parameters:

Name Type Description Default
text str

Text to extract characters from.

required
Source code in src/lmxlab/data/tokenizer.py
def fit(self, text: str) -> None:
    """Build vocabulary from text.

    Args:
        text: Text to extract characters from.
    """
    chars = sorted(set(text))
    self._char_to_id = {c: i for i, c in enumerate(chars)}
    self._id_to_char = {i: c for c, i in self._char_to_id.items()}
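Because the vocabulary is built from ``sorted(set(text))``, token IDs are deterministic for a given text, and any character absent from that text raises ``KeyError`` on encode. A pure-Python sketch of the same mapping logic (mirroring the source above, not the library class itself):

```python
# Sketch of CharTokenizer's vocabulary construction and round-trip.
text = "the quick brown fox"
chars = sorted(set(text))                        # deterministic ID order
char_to_id = {c: i for i, c in enumerate(chars)}
id_to_char = {i: c for c, i in char_to_id.items()}

ids = [char_to_id[c] for c in "the fox"]         # encode
roundtrip = "".join(id_to_char[t] for t in ids)  # decode
assert roundtrip == "the fox"

# Characters unseen at construction time raise KeyError on encode:
try:
    [char_to_id[c] for c in "dog"]   # 'd' and 'g' are not in the vocabulary
except KeyError:
    unknown_raises = True
```

Note that the space character sorts before all letters, so it always receives the lowest ID when present.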

HFTokenizer

HuggingFace tokenizer wrapper.

Wraps a HuggingFace AutoTokenizer for use with lmxlab. Use this when working with pretrained models loaded via load_from_hf.

Requires transformers to be installed::

pip install transformers

Parameters:

Name Type Description Default
repo_id str

HuggingFace model repo ID or local path (e.g., 'meta-llama/Llama-3.2-1B').

required
Example

>>> tok = HFTokenizer('meta-llama/Llama-3.2-1B')
>>> tok.encode('hello world')
[15339, 1917]
>>> tok.decode([15339, 1917])
'hello world'

Source code in src/lmxlab/data/tokenizer.py
class HFTokenizer:
    """HuggingFace tokenizer wrapper.

    Wraps a HuggingFace ``AutoTokenizer`` for use with lmxlab.
    Use this when working with pretrained models loaded via
    ``load_from_hf``.

    Requires ``transformers`` to be installed::

        pip install transformers

    Args:
        repo_id: HuggingFace model repo ID or local path
            (e.g., 'meta-llama/Llama-3.2-1B').

    Example:
        >>> tok = HFTokenizer('meta-llama/Llama-3.2-1B')
        >>> tok.encode('hello world')
        [15339, 1917]
        >>> tok.decode([15339, 1917])
        'hello world'
    """

    def __init__(self, repo_id: str) -> None:
        try:
            from transformers import AutoTokenizer
        except ImportError as e:
            raise ImportError(
                "transformers is required for HFTokenizer. "
                "Install with: pip install transformers"
            ) from e

        self._tok = AutoTokenizer.from_pretrained(repo_id)
        self._repo_id = repo_id

    @property
    def vocab_size(self) -> int:
        """Size of the vocabulary."""
        return len(self._tok)

    @property
    def eos_token_id(self) -> int | None:
        """End-of-sequence token ID, if available."""
        eos = self._tok.eos_token_id
        return int(eos) if eos is not None else None

    @property
    def bos_token_id(self) -> int | None:
        """Beginning-of-sequence token ID, if available."""
        bos = self._tok.bos_token_id
        return int(bos) if bos is not None else None

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs.

        Does not add special tokens (BOS/EOS) by default,
        so the output matches what the model expects for
        continuation.

        Args:
            text: Input string.

        Returns:
            List of token IDs.
        """
        return list(self._tok.encode(text, add_special_tokens=False))

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs back to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded string.
        """
        return str(self._tok.decode(tokens))

bos_token_id property

Beginning-of-sequence token ID, if available.

eos_token_id property

End-of-sequence token ID, if available.

vocab_size property

Size of the vocabulary.

decode(tokens)

Decode token IDs back to text.

Parameters:

Name Type Description Default
tokens list[int]

List of token IDs.

required

Returns:

Type Description
str

Decoded string.

Source code in src/lmxlab/data/tokenizer.py
def decode(self, tokens: list[int]) -> str:
    """Decode token IDs back to text.

    Args:
        tokens: List of token IDs.

    Returns:
        Decoded string.
    """
    return str(self._tok.decode(tokens))

encode(text)

Encode text to token IDs.

Does not add special tokens (BOS/EOS) by default, so the output matches what the model expects for continuation.

Parameters:

Name Type Description Default
text str

Input string.

required

Returns:

Type Description
list[int]

List of token IDs.

Source code in src/lmxlab/data/tokenizer.py
def encode(self, text: str) -> list[int]:
    """Encode text to token IDs.

    Does not add special tokens (BOS/EOS) by default,
    so the output matches what the model expects for
    continuation.

    Args:
        text: Input string.

    Returns:
        List of token IDs.
    """
    return list(self._tok.encode(text, add_special_tokens=False))

TiktokenTokenizer

BPE tokenizer using OpenAI's tiktoken.

Wraps a tiktoken encoding for use with lmxlab. Supports any tiktoken encoding name (e.g. 'gpt2', 'cl100k_base', 'o200k_base').

Requires tiktoken to be installed::

pip install tiktoken

Parameters:

Name Type Description Default
encoding_name str

Name of the tiktoken encoding. Defaults to 'gpt2' (50257 tokens).

'gpt2'
Example

>>> tok = TiktokenTokenizer('gpt2')
>>> tok.encode('hello world')
[31373, 995]
>>> tok.decode([31373, 995])
'hello world'

Source code in src/lmxlab/data/tokenizer.py
class TiktokenTokenizer:
    """BPE tokenizer using OpenAI's tiktoken.

    Wraps a tiktoken encoding for use with lmxlab.
    Supports any tiktoken encoding name (e.g. 'gpt2',
    'cl100k_base', 'o200k_base').

    Requires ``tiktoken`` to be installed::

        pip install tiktoken

    Args:
        encoding_name: Name of the tiktoken encoding.
            Defaults to 'gpt2' (50257 tokens).

    Example:
        >>> tok = TiktokenTokenizer('gpt2')
        >>> tok.encode('hello world')
        [31373, 995]
        >>> tok.decode([31373, 995])
        'hello world'
    """

    def __init__(self, encoding_name: str = "gpt2") -> None:
        try:
            import tiktoken
        except ImportError as e:
            raise ImportError(
                "tiktoken is required for TiktokenTokenizer. "
                "Install it with: pip install tiktoken"
            ) from e

        self._enc = tiktoken.get_encoding(encoding_name)
        self._encoding_name = encoding_name

    @property
    def vocab_size(self) -> int:
        """Size of the vocabulary."""
        return int(self._enc.n_vocab)

    def encode(self, text: str) -> list[int]:
        """Encode text to BPE token IDs.

        Args:
            text: Input string.

        Returns:
            List of token IDs.
        """
        return list(self._enc.encode(text))

    def decode(self, tokens: list[int]) -> str:
        """Decode BPE token IDs back to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded string.
        """
        return str(self._enc.decode(tokens))

vocab_size property

Size of the vocabulary.

decode(tokens)

Decode BPE token IDs back to text.

Parameters:

Name Type Description Default
tokens list[int]

List of token IDs.

required

Returns:

Type Description
str

Decoded string.

Source code in src/lmxlab/data/tokenizer.py
def decode(self, tokens: list[int]) -> str:
    """Decode BPE token IDs back to text.

    Args:
        tokens: List of token IDs.

    Returns:
        Decoded string.
    """
    return str(self._enc.decode(tokens))

encode(text)

Encode text to BPE token IDs.

Parameters:

Name Type Description Default
text str

Input string.

required

Returns:

Type Description
list[int]

List of token IDs.

Source code in src/lmxlab/data/tokenizer.py
def encode(self, text: str) -> list[int]:
    """Encode text to BPE token IDs.

    Args:
        text: Input string.

    Returns:
        List of token IDs.
    """
    return list(self._enc.encode(text))

Tokenizer

Bases: Protocol

Protocol for tokenizers.

All tokenizers must implement encode/decode and expose their vocabulary size.

Source code in src/lmxlab/data/tokenizer.py
class Tokenizer(Protocol):
    """Protocol for tokenizers.

    All tokenizers must implement encode/decode and expose
    their vocabulary size.
    """

    @property
    def vocab_size(self) -> int:
        """Size of the vocabulary."""
        ...

    def encode(self, text: str) -> list[int]:
        """Encode text to token IDs.

        Args:
            text: Input string.

        Returns:
            List of token IDs.
        """
        ...

    def decode(self, tokens: list[int]) -> str:
        """Decode token IDs to text.

        Args:
            tokens: List of token IDs.

        Returns:
            Decoded string.
        """
        ...

vocab_size property

Size of the vocabulary.

decode(tokens)

Decode token IDs to text.

Parameters:

Name Type Description Default
tokens list[int]

List of token IDs.

required

Returns:

Type Description
str

Decoded string.

Source code in src/lmxlab/data/tokenizer.py
def decode(self, tokens: list[int]) -> str:
    """Decode token IDs to text.

    Args:
        tokens: List of token IDs.

    Returns:
        Decoded string.
    """
    ...

encode(text)

Encode text to token IDs.

Parameters:

Name Type Description Default
text str

Input string.

required

Returns:

Type Description
list[int]

List of token IDs.

Source code in src/lmxlab/data/tokenizer.py
def encode(self, text: str) -> list[int]:
    """Encode text to token IDs.

    Args:
        text: Input string.

    Returns:
        List of token IDs.
    """
    ...
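Because ``Tokenizer`` is a ``typing.Protocol``, any class with a matching ``vocab_size`` property and ``encode``/``decode`` methods satisfies it structurally, with no inheritance required. A sketch with a toy word-level tokenizer (hypothetical, not part of the library):

```python
from typing import Protocol


class Tokenizer(Protocol):
    """Structural interface: matching shape is enough, no subclassing."""

    @property
    def vocab_size(self) -> int: ...
    def encode(self, text: str) -> list[int]: ...
    def decode(self, tokens: list[int]) -> str: ...


class WordTokenizer:
    """Toy word-level tokenizer; satisfies the protocol structurally."""

    def __init__(self, text: str) -> None:
        words = sorted(set(text.split()))
        self._w2i = {w: i for i, w in enumerate(words)}
        self._i2w = {i: w for w, i in self._w2i.items()}

    @property
    def vocab_size(self) -> int:
        return len(self._w2i)

    def encode(self, text: str) -> list[int]:
        return [self._w2i[w] for w in text.split()]

    def decode(self, tokens: list[int]) -> str:
        return " ".join(self._i2w[t] for t in tokens)


def roundtrip(tok: Tokenizer, text: str) -> str:
    # Type-checks against the Protocol even though WordTokenizer
    # never inherits from it.
    return tok.decode(tok.encode(text))


tok = WordTokenizer("the quick brown fox")
assert roundtrip(tok, "quick fox") == "quick fox"
```

The same ``roundtrip`` helper works unchanged with ``CharTokenizer``, ``TiktokenTokenizer``, or ``HFTokenizer``.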

Datasets

lmxlab.data.dataset.TextDataset

Dataset that tokenizes raw text.

Tokenizes text and stores as a flat array of token IDs. Yields overlapping windows of (input, target) pairs.

Parameters:

Name Type Description Default
text str

Raw text to tokenize.

required
tokenizer Tokenizer

Tokenizer to use.

required
seq_len int

Sequence length for training windows.

128
Source code in src/lmxlab/data/dataset.py
class TextDataset:
    """Dataset that tokenizes raw text.

    Tokenizes text and stores as a flat array of token IDs.
    Yields overlapping windows of (input, target) pairs.

    Args:
        text: Raw text to tokenize.
        tokenizer: Tokenizer to use.
        seq_len: Sequence length for training windows.
    """

    def __init__(
        self,
        text: str,
        tokenizer: Tokenizer,
        seq_len: int = 128,
    ) -> None:
        self.seq_len = seq_len
        self.tokenizer = tokenizer
        tokens = tokenizer.encode(text)
        self.tokens = mx.array(tokens, dtype=mx.int32)

    def __len__(self) -> int:
        """Number of training windows available."""
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx: int) -> tuple[mx.array, mx.array]:
        """Get a (input, target) pair at the given index.

        Args:
            idx: Starting position in the token array.

        Returns:
            Tuple of (input_tokens, target_tokens), each
            of shape (seq_len,).
        """
        x = self.tokens[idx : idx + self.seq_len]
        y = self.tokens[idx + 1 : idx + self.seq_len + 1]
        return x, y

seq_len = seq_len instance-attribute

tokenizer = tokenizer instance-attribute

tokens = mx.array(tokens, dtype=mx.int32) instance-attribute

__getitem__(idx)

Get an (input, target) pair at the given index.

Parameters:

Name Type Description Default
idx int

Starting position in the token array.

required

Returns:

Type Description
tuple[array, array]

Tuple of (input_tokens, target_tokens), each of shape (seq_len,).

Source code in src/lmxlab/data/dataset.py
def __getitem__(self, idx: int) -> tuple[mx.array, mx.array]:
    """Get a (input, target) pair at the given index.

    Args:
        idx: Starting position in the token array.

    Returns:
        Tuple of (input_tokens, target_tokens), each
        of shape (seq_len,).
    """
    x = self.tokens[idx : idx + self.seq_len]
    y = self.tokens[idx + 1 : idx + self.seq_len + 1]
    return x, y
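The target window is the input shifted one position to the right, so y[t] is the token the model should predict after seeing x up to position t. With plain lists standing in for the mx.array of token IDs, the slicing above works out as:

```python
tokens = list(range(10))   # stand-in for the flat token-ID array
seq_len = 4

# Same slicing as __getitem__ above:
idx = 2
x = tokens[idx : idx + seq_len]           # [2, 3, 4, 5]
y = tokens[idx + 1 : idx + seq_len + 1]   # [3, 4, 5, 6]
assert y[:-1] == x[1:]  # y is x shifted left by one, plus the next token

# __len__ logic: the last valid start index is len(tokens) - seq_len - 1,
# so there are max(0, len(tokens) - seq_len) overlapping windows.
n_windows = max(0, len(tokens) - seq_len)
assert n_windows == 6
```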

__init__(text, tokenizer, seq_len=128)

Source code in src/lmxlab/data/dataset.py
def __init__(
    self,
    text: str,
    tokenizer: Tokenizer,
    seq_len: int = 128,
) -> None:
    self.seq_len = seq_len
    self.tokenizer = tokenizer
    tokens = tokenizer.encode(text)
    self.tokens = mx.array(tokens, dtype=mx.int32)

__len__()

Number of training windows available.

Source code in src/lmxlab/data/dataset.py
def __len__(self) -> int:
    """Number of training windows available."""
    return max(0, len(self.tokens) - self.seq_len)

lmxlab.data.dataset.TokenDataset

Dataset from pre-tokenized data.

Wraps an existing array of token IDs.

Parameters:

Name Type Description Default
tokens array

Array of token IDs.

required
seq_len int

Sequence length for training windows.

128
Source code in src/lmxlab/data/dataset.py
class TokenDataset:
    """Dataset from pre-tokenized data.

    Wraps an existing array of token IDs.

    Args:
        tokens: Array of token IDs.
        seq_len: Sequence length for training windows.
    """

    def __init__(
        self,
        tokens: mx.array,
        seq_len: int = 128,
    ) -> None:
        self.tokens = tokens
        self.seq_len = seq_len

    def __len__(self) -> int:
        """Number of training windows available."""
        return max(0, len(self.tokens) - self.seq_len)

    def __getitem__(self, idx: int) -> tuple[mx.array, mx.array]:
        """Get a (input, target) pair.

        Args:
            idx: Starting position.

        Returns:
            Tuple of (input_tokens, target_tokens).
        """
        x = self.tokens[idx : idx + self.seq_len]
        y = self.tokens[idx + 1 : idx + self.seq_len + 1]
        return x, y

seq_len = seq_len instance-attribute

tokens = tokens instance-attribute

__getitem__(idx)

Get an (input, target) pair.

Parameters:

Name Type Description Default
idx int

Starting position.

required

Returns:

Type Description
tuple[array, array]

Tuple of (input_tokens, target_tokens).

Source code in src/lmxlab/data/dataset.py
def __getitem__(self, idx: int) -> tuple[mx.array, mx.array]:
    """Get a (input, target) pair.

    Args:
        idx: Starting position.

    Returns:
        Tuple of (input_tokens, target_tokens).
    """
    x = self.tokens[idx : idx + self.seq_len]
    y = self.tokens[idx + 1 : idx + self.seq_len + 1]
    return x, y

__init__(tokens, seq_len=128)

Source code in src/lmxlab/data/dataset.py
def __init__(
    self,
    tokens: mx.array,
    seq_len: int = 128,
) -> None:
    self.tokens = tokens
    self.seq_len = seq_len

__len__()

Number of training windows available.

Source code in src/lmxlab/data/dataset.py
def __len__(self) -> int:
    """Number of training windows available."""
    return max(0, len(self.tokens) - self.seq_len)

lmxlab.data.dataset.HFDataset

Dataset backed by a HuggingFace dataset.

Streams or loads a HuggingFace dataset, tokenizes on-the-fly, and yields batches of (input, target) pairs.

Requires the datasets package (pip install datasets).

Parameters:

Name Type Description Default
name str

HuggingFace dataset name (e.g. 'wikitext').

required
tokenizer Tokenizer

Tokenizer implementing the Tokenizer protocol.

required
seq_len int

Sequence length for training windows.

128
split str

Dataset split to use.

'train'
text_field str

Name of the text column in the dataset.

'text'
config_name str | None

Optional dataset configuration name.

None
streaming bool

Whether to stream the dataset.

False
Source code in src/lmxlab/data/dataset.py
class HFDataset:
    """Dataset backed by a HuggingFace dataset.

    Streams or loads a HuggingFace dataset, tokenizes on-the-fly,
    and yields batches of (input, target) pairs.

    Requires the ``datasets`` package (``pip install datasets``).

    Args:
        name: HuggingFace dataset name (e.g. ``'wikitext'``).
        tokenizer: Tokenizer implementing the Tokenizer protocol.
        seq_len: Sequence length for training windows.
        split: Dataset split to use.
        text_field: Name of the text column in the dataset.
        config_name: Optional dataset configuration name.
        streaming: Whether to stream the dataset.
    """

    def __init__(
        self,
        name: str,
        tokenizer: Tokenizer,
        seq_len: int = 128,
        split: str = "train",
        text_field: str = "text",
        config_name: str | None = None,
        streaming: bool = False,
    ) -> None:
        from datasets import load_dataset

        self.tokenizer = tokenizer
        self.seq_len = seq_len
        self.text_field = text_field
        self._streaming = streaming
        self._dataset = load_dataset(
            name, config_name, split=split, streaming=streaming
        )

    def token_iterator(self) -> Iterator[int]:
        """Yield token IDs one at a time from the dataset."""
        for example in self._dataset:
            text = example[self.text_field]
            if text and text.strip():
                yield from self.tokenizer.encode(text)

    def batch_iterator(
        self,
        batch_size: int = 8,
        max_batches: int | None = None,
    ) -> Iterator[tuple[mx.array, mx.array]]:
        """Yield (input, target) batches from the dataset.

        Accumulates tokens into a buffer and yields batches
        of shape ``(batch_size, seq_len)``.

        Args:
            batch_size: Number of sequences per batch.
            max_batches: Stop after this many batches.

        Yields:
            Tuple of (inputs, targets), each of shape
            ``(batch_size, seq_len)``.
        """
        buffer: list[int] = []
        tokens_needed = batch_size * self.seq_len + 1
        n_batches = 0
        for token_id in self.token_iterator():
            buffer.append(token_id)
            if len(buffer) >= tokens_needed:
                arr = mx.array(buffer[:tokens_needed], dtype=mx.int32)
                inputs = arr[:-1].reshape(batch_size, self.seq_len)
                targets = arr[1:].reshape(batch_size, self.seq_len)
                yield inputs, targets
                n_batches += 1
                if max_batches and n_batches >= max_batches:
                    return
                buffer = buffer[tokens_needed - 1 :]

_dataset = load_dataset(name, config_name, split=split, streaming=streaming) instance-attribute

_streaming = streaming instance-attribute

seq_len = seq_len instance-attribute

text_field = text_field instance-attribute

tokenizer = tokenizer instance-attribute

__init__(name, tokenizer, seq_len=128, split='train', text_field='text', config_name=None, streaming=False)

Source code in src/lmxlab/data/dataset.py
def __init__(
    self,
    name: str,
    tokenizer: Tokenizer,
    seq_len: int = 128,
    split: str = "train",
    text_field: str = "text",
    config_name: str | None = None,
    streaming: bool = False,
) -> None:
    from datasets import load_dataset

    self.tokenizer = tokenizer
    self.seq_len = seq_len
    self.text_field = text_field
    self._streaming = streaming
    self._dataset = load_dataset(
        name, config_name, split=split, streaming=streaming
    )

batch_iterator(batch_size=8, max_batches=None)

Yield (input, target) batches from the dataset.

Accumulates tokens into a buffer and yields batches of shape (batch_size, seq_len).

Parameters:

Name Type Description Default
batch_size int

Number of sequences per batch.

8
max_batches int | None

Stop after this many batches.

None

Yields:

Type Description
tuple[array, array]

Tuple of (inputs, targets), each of shape (batch_size, seq_len).

Source code in src/lmxlab/data/dataset.py
def batch_iterator(
    self,
    batch_size: int = 8,
    max_batches: int | None = None,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Yield (input, target) batches from the dataset.

    Accumulates tokens into a buffer and yields batches
    of shape ``(batch_size, seq_len)``.

    Args:
        batch_size: Number of sequences per batch.
        max_batches: Stop after this many batches.

    Yields:
        Tuple of (inputs, targets), each of shape
        ``(batch_size, seq_len)``.
    """
    buffer: list[int] = []
    tokens_needed = batch_size * self.seq_len + 1
    n_batches = 0
    for token_id in self.token_iterator():
        buffer.append(token_id)
        if len(buffer) >= tokens_needed:
            arr = mx.array(buffer[:tokens_needed], dtype=mx.int32)
            inputs = arr[:-1].reshape(batch_size, self.seq_len)
            targets = arr[1:].reshape(batch_size, self.seq_len)
            yield inputs, targets
            n_batches += 1
            if max_batches and n_batches >= max_batches:
                return
            buffer = buffer[tokens_needed - 1 :]

token_iterator()

Yield token IDs one at a time from the dataset.

Source code in src/lmxlab/data/dataset.py
def token_iterator(self) -> Iterator[int]:
    """Yield token IDs one at a time from the dataset."""
    for example in self._dataset:
        text = example[self.text_field]
        if text and text.strip():
            yield from self.tokenizer.encode(text)
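The buffering in batch_iterator above accumulates batch_size * seq_len + 1 tokens, splits them into shifted input/target blocks, and keeps the final token as the one-token overlap for the next batch. Sketched with a plain integer stream in place of the tokenized dataset (nested lists in place of mx.array):

```python
def batches(stream, batch_size=2, seq_len=3):
    # Mirrors HFDataset.batch_iterator's buffering logic.
    buffer, needed = [], batch_size * seq_len + 1
    for tok in stream:
        buffer.append(tok)
        if len(buffer) >= needed:
            chunk = buffer[:needed]
            flat_in, flat_tg = chunk[:-1], chunk[1:]   # shifted by one
            inputs = [flat_in[i * seq_len : (i + 1) * seq_len]
                      for i in range(batch_size)]
            targets = [flat_tg[i * seq_len : (i + 1) * seq_len]
                       for i in range(batch_size)]
            yield inputs, targets
            buffer = buffer[needed - 1 :]  # keep last token as overlap


out = list(batches(range(20)))
# First batch: 7 tokens buffered, split into two length-3 sequences.
assert out[0] == ([[0, 1, 2], [3, 4, 5]], [[1, 2, 3], [4, 5, 6]])
# The next batch starts at token 6, the overlap carried over.
assert out[1][0][0] == [6, 7, 8]
```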

Batching

lmxlab.data.batching

Batch iterator for MLX training.

batch_iterator(tokens, batch_size, seq_len, shuffle=True)

Yield batches of (input, target) pairs from a token array.

Creates non-overlapping windows from the token array, optionally shuffles, and yields batches.

Parameters:

Name Type Description Default
tokens array

Flat array of token IDs.

required
batch_size int

Number of sequences per batch.

required
seq_len int

Length of each sequence.

required
shuffle bool

Whether to shuffle windows each epoch.

True

Yields:

Type Description
tuple[array, array]

Tuples of (input_batch, target_batch), each of shape (batch_size, seq_len).

Source code in src/lmxlab/data/batching.py
def batch_iterator(
    tokens: mx.array,
    batch_size: int,
    seq_len: int,
    shuffle: bool = True,
) -> Iterator[tuple[mx.array, mx.array]]:
    """Yield batches of (input, target) pairs from a token array.

    Creates non-overlapping windows from the token array,
    optionally shuffles, and yields batches.

    Args:
        tokens: Flat array of token IDs.
        batch_size: Number of sequences per batch.
        seq_len: Length of each sequence.
        shuffle: Whether to shuffle windows each epoch.

    Yields:
        Tuples of (input_batch, target_batch), each of
        shape (batch_size, seq_len).
    """
    # Calculate number of complete sequences
    n_tokens = len(tokens)
    n_sequences = (n_tokens - 1) // seq_len

    if n_sequences < batch_size:
        raise ValueError(
            f"Not enough data for batch_size={batch_size}: "
            f"only {n_sequences} sequences available"
        )

    # Truncate to fit evenly
    usable = n_sequences * seq_len
    data = tokens[: usable + 1]

    # Create input/target arrays
    # Shape: (n_sequences, seq_len)
    inputs = mx.stack(
        [data[i * seq_len : (i + 1) * seq_len] for i in range(n_sequences)]
    )
    targets = mx.stack(
        [
            data[i * seq_len + 1 : (i + 1) * seq_len + 1]
            for i in range(n_sequences)
        ]
    )

    # Shuffle
    if shuffle:
        indices = mx.random.permutation(n_sequences)
        inputs = inputs[indices]
        targets = targets[indices]

    # Yield batches
    n_batches = n_sequences // batch_size
    for i in range(n_batches):
        start = i * batch_size
        end = start + batch_size
        yield inputs[start:end], targets[start:end]
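The truncation above means n_sequences = (len(tokens) - 1) // seq_len non-overlapping windows; leftover tokens and any partial final batch are dropped, and fewer than batch_size windows raises ValueError. The arithmetic, with plain lists standing in for mx arrays:

```python
tokens = list(range(101))      # 101 token IDs
batch_size, seq_len = 4, 8

n_sequences = (len(tokens) - 1) // seq_len   # (101 - 1) // 8 = 12 windows
assert n_sequences >= batch_size             # else batch_iterator raises ValueError

usable = n_sequences * seq_len               # 96 tokens used; 4 dropped
data = tokens[: usable + 1]                  # +1 for the shifted targets

# Non-overlapping windows, targets shifted by one (same slicing as above):
inputs = [data[i * seq_len : (i + 1) * seq_len] for i in range(n_sequences)]
targets = [data[i * seq_len + 1 : (i + 1) * seq_len + 1]
           for i in range(n_sequences)]
assert inputs[0] == [0, 1, 2, 3, 4, 5, 6, 7]
assert targets[0] == [1, 2, 3, 4, 5, 6, 7, 8]

n_batches = n_sequences // batch_size        # 3 full batches; remainder dropped
```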