# Core

The core package provides the building blocks for all architectures.

## Configuration

### `lmxlab.core.config.BlockConfig` (dataclass)

Configuration for a single transformer block.

Defines which components (attention, FFN, norm, position encoding) to use and their parameters. Components are looked up by name from the registry.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `attention` | `str` | Registry name for attention module. | `'mha'` |
| `ffn` | `str` | Registry name for feed-forward module. | `'standard'` |
| `norm` | `str` | Registry name for normalization function. | `'layer_norm'` |
| `position` | `str` | Registry name for positional encoding. | `'sinusoidal'` |
| `d_model` | `int` | Hidden dimension size. | `512` |
| `n_heads` | `int` | Number of attention heads. | `8` |
| `n_kv_heads` | `int \| None` | Number of key/value heads (for GQA). Defaults to `n_heads` (standard MHA). | `None` |
| `d_ff` | `int` | Feed-forward intermediate dimension. | `2048` |
| `bias` | `bool` | Whether to use bias in linear layers. | `True` |
| `dropout` | `float` | Dropout rate (0.0 = no dropout). | `0.0` |
| `norm_eps` | `float` | Epsilon for normalization layers. | `1e-05` |
| `rope_theta` | `float` | Base frequency for RoPE. | `10000.0` |
| `max_seq_len` | `int` | Maximum sequence length. | `2048` |
| `pre_norm` | `bool` | If True, apply norm before attention/FFN (pre-norm). If False, apply after (post-norm). | `True` |
| `window_size` | `int \| None` | Sliding window size for local attention. None means full (global) attention. | `None` |
| `mamba_bc_norm` | `bool` | If True, apply RMSNorm to B and C projections in Mamba-3 (analogous to QK-norm). | `False` |
| `mamba_trapezoidal` | `bool` | If True, use trapezoidal discretization in Mamba-3 (two SSD calls). | `False` |
| `mamba_complex_a` | `bool` | If True, apply data-dependent RoPE to B and C in Mamba-3 (complex eigenvalues). | `False` |
| `qk_norm` | `bool` | If True, apply per-head RMSNorm to Q and K after reshape, before RoPE (OLMo 2 style). | `False` |
| `attention_chunk_size` | `int \| None` | Chunk size for chunked local attention. None means full (global) attention. | `None` |
| `mup` | `bool` | If True, use μP attention scaling (1/d_head instead of 1/√d_head). | `False` |
Source code in src/lmxlab/core/config.py
Attributes (class/instance attributes unless noted):

- `attention = 'mha'`
- `attention_chunk_size = None`
- `bias = True`
- `conv_kernel_size = 4`
- `d_ff = 2048`
- `d_model = 512`
- `dropout = 0.0`
- `effective_n_kv_heads` (property): Number of KV heads (defaults to `n_heads` if not set).
- `ffn = 'standard'`
- `head_dim` (property): Dimension per attention head.
- `kv_lora_rank = None`
- `mamba_bc_norm = False`
- `mamba_chunk_size = 128`
- `mamba_complex_a = False`
- `mamba_expand = 2`
- `mamba_head_dim = None`
- `mamba_n_groups = 1`
- `mamba_n_heads = None`
- `mamba_trapezoidal = False`
- `max_seq_len = 2048`
- `moe_d_ff = None`
- `moe_latent_size = None`
- `moe_n_groups = 1`
- `moe_routed_scaling_factor = 1.0`
- `moe_topk_groups = 1`
- `mup = False`
- `n_experts = None`
- `n_heads = 8`
- `n_kv_heads = None`
- `n_shared_experts = None`
- `norm = 'layer_norm'`
- `norm_eps = 1e-05`
- `position = 'sinusoidal'`
- `pre_norm = True`
- `q_lora_rank = None`
- `qk_norm = False`
- `rope_dim = None`
- `rope_theta = 10000.0`
- `shared_expert_d_ff = None`
- `sparse_compress_ratio = None`
- `sparse_select_k = None`
- `ssm_state_size = 128`
- `top_k_experts = 2`
- `use_short_conv = False`
- `window_size = None`
`__init__(attention='mha', ffn='standard', norm='layer_norm', position='sinusoidal', d_model=512, n_heads=8, n_kv_heads=None, d_ff=2048, bias=True, dropout=0.0, norm_eps=1e-05, rope_theta=10000.0, max_seq_len=2048, pre_norm=True, window_size=None, kv_lora_rank=None, q_lora_rank=None, rope_dim=None, n_experts=None, top_k_experts=2, n_shared_experts=None, conv_kernel_size=4, use_short_conv=False, mamba_n_heads=None, mamba_head_dim=None, ssm_state_size=128, mamba_expand=2, mamba_n_groups=1, mamba_chunk_size=128, mamba_bc_norm=False, mamba_trapezoidal=False, mamba_complex_a=False, moe_latent_size=None, moe_d_ff=None, shared_expert_d_ff=None, moe_routed_scaling_factor=1.0, moe_n_groups=1, moe_topk_groups=1, qk_norm=False, attention_chunk_size=None, mup=False, sparse_compress_ratio=None, sparse_select_k=None)`
### `lmxlab.core.config.ModelConfig` (dataclass)

Configuration for a full language model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `block` | `BlockConfig` | Block configuration (shared across all layers). | `BlockConfig()` |
| `vocab_size` | `int` | Vocabulary size. | `32000` |
| `n_layers` | `int` | Number of transformer blocks. | `6` |
| `tie_embeddings` | `bool` | Whether to tie input/output embeddings. | `True` |
| `block_configs` | `tuple[BlockConfig, ...] \| None` | Per-layer block overrides (optional). If provided, must have length `n_layers`. | `None` |
| `mup_base_width` | `int \| None` | Base model width for μP. When set, enables μP scaling. None means standard parameterization (SP). | `None` |
| `mtp_n_predict` | `int` | Number of multi-token prediction heads. 0 disables MTP. | `0` |
| `mtp_lambda` | `float` | Weight for MTP auxiliary loss. | `0.1` |
Source code in src/lmxlab/core/config.py
Attributes:

- `block = field(default_factory=BlockConfig)`
- `block_configs = None`
- `mtp_lambda = 0.1`
- `mtp_n_predict = 0`
- `mup_base_width = None`
- `n_layers = 6`
- `tie_embeddings = True`
- `vocab_size = 32000`
- `width_mult` (property): Width multiplier for μP (`d_model / base_width`). Returns 1.0 when μP is disabled.
`__init__(block=BlockConfig(), vocab_size=32000, n_layers=6, tie_embeddings=True, block_configs=None, mup_base_width=None, mtp_n_predict=0, mtp_lambda=0.1)`
get_block_config(layer_idx)
Get block config for a specific layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `layer_idx` | `int` | Layer index. | *required* |

Returns:

| Type | Description |
|---|---|
| `BlockConfig` | BlockConfig for the given layer. |
Source code in src/lmxlab/core/config.py
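Putting the two configs together: a minimal configuration sketch, assuming `lmxlab` is importable and following the signatures above. The `'sliding_window'` registry name is illustrative only; check the attention registry for the actual keys.

```python
from lmxlab.core.config import BlockConfig, ModelConfig

# A LLaMA-style block: GQA attention, SwiGLU FFN, RMSNorm, RoPE.
block = BlockConfig(
    attention='gqa', ffn='gated', norm='rms_norm', position='rope',
    d_model=512, n_heads=8, n_kv_heads=2, bias=False,
)

# Per-layer overrides: alternate local and global attention layers.
# NOTE: 'sliding_window' is a hypothetical registry key for illustration.
local = BlockConfig(attention='sliding_window', window_size=256)
model_cfg = ModelConfig(
    block=block, vocab_size=32000, n_layers=4,
    block_configs=(block, local, block, local),
)
print(model_cfg.get_block_config(1).window_size)  # the layer-1 override
```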
## ConfigurableBlock

### `lmxlab.core.block.ConfigurableBlock`

Bases: `Module`

A transformer block assembled from registry components.

The block uses pre-norm or post-norm residual connections depending on the config. Components (attention, FFN, norm, position encoding) are looked up from registries by name.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `BlockConfig` | Block configuration specifying components. | *required* |
Example:

    config = BlockConfig(
        attention='gqa', ffn='gated',
        norm='rms_norm', position='rope',
        d_model=256, n_heads=4, n_kv_heads=2,
    )
    block = ConfigurableBlock(config)
Source code in src/lmxlab/core/block.py
Attributes:

- `_alibi = self.position if config.position == 'alibi' else None`
- `_rope = self.position if config.position == 'rope' else None`
- `attention = attn_cls(config)`
- `attn_norm = norm_cls(config)`
- `config = config`
- `ffn = ffn_cls(config)`
- `ffn_norm = norm_cls(config)`
- `position = position_registry.get(config.position)(config)`
- `resid_dropout = nn.Dropout(p=config.dropout)`
__call__(x, mask=None, cache=None)
Forward pass through the block.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `array` | Input tensor (batch, seq_len, d_model). | *required* |
| `mask` | `array \| None` | Optional attention mask. | `None` |
| `cache` | `tuple[array, array] \| None` | Optional KV cache for generation. | `None` |

Returns:

| Type | Description |
|---|---|
| `tuple[array, tuple[array, array] \| None]` | Tuple of (output, updated_cache). |
Source code in src/lmxlab/core/block.py
__init__(config)
Source code in src/lmxlab/core/block.py
_post_norm_forward(x, mask, cache)
Post-norm: sublayer -> dropout -> residual -> norm.
Source code in src/lmxlab/core/block.py
_pre_norm_forward(x, mask, cache)
Pre-norm: norm -> sublayer -> dropout -> residual.
Source code in src/lmxlab/core/block.py
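The two residual orderings can be sketched in plain Python. This is illustrative only; `norm`, `sublayer`, and `dropout` stand in for the block's actual components:

```python
def pre_norm_step(x, norm, sublayer, dropout):
    # Pre-norm: norm -> sublayer -> dropout -> residual
    return x + dropout(sublayer(norm(x)))

def post_norm_step(x, norm, sublayer, dropout):
    # Post-norm: sublayer -> dropout -> residual -> norm
    return norm(x + dropout(sublayer(x)))

# Toy scalar check with identity dropout and a doubling sublayer
ident = lambda v: v
double = lambda v: 2 * v
norm = lambda v: v / 10
print(pre_norm_step(5.0, norm, double, ident))   # 5 + 2*(5/10) = 6.0
print(post_norm_step(5.0, norm, double, ident))  # (5 + 10)/10 = 1.5
```

Pre-norm keeps an unnormalized residual stream (the common choice for training stability); post-norm normalizes the stream after every sublayer.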
## Attention

### `lmxlab.core.attention.MHA`

Bases: `AttentionBase`

Multi-Head Attention using `mx.fast.scaled_dot_product_attention`.

Standard MHA where `n_kv_heads == n_heads`.
Source code in src/lmxlab/core/attention.py
__init__(config)
Source code in src/lmxlab/core/attention.py
__call__(x, mask=None, cache=None, rope=None)
Source code in src/lmxlab/core/attention.py
### `lmxlab.core.attention.GQA`

Bases: `AttentionBase`

Grouped-Query Attention.

Uses fewer KV heads than query heads for memory efficiency. When `n_kv_heads == 1`, this is Multi-Query Attention (MQA). When `n_kv_heads == n_heads`, this is standard MHA.
Source code in src/lmxlab/core/attention.py
__init__(config)
Source code in src/lmxlab/core/attention.py
__call__(x, mask=None, cache=None, rope=None)
Source code in src/lmxlab/core/attention.py
### `lmxlab.core.attention.SlidingWindowGQA`

Bases: `AttentionBase`

Grouped-Query Attention with sliding window masking.

Each token can only attend to the most recent `window_size` tokens (including itself). Uses GQA head configuration for memory efficiency. The window size is read from `config.window_size`.
Source code in src/lmxlab/core/attention.py
__init__(config)
Source code in src/lmxlab/core/attention.py
__call__(x, mask=None, cache=None, rope=None)
Source code in src/lmxlab/core/attention.py
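The masking rule can be illustrated in a few lines of plain Python (a sketch of the rule, not the module's MLX implementation):

```python
def sliding_window_mask(seq_len, window_size):
    """allowed[q][k] is True when query position q may attend to key k:
    causal (k <= q) and within the most recent `window_size` positions."""
    return [[q - window_size < k <= q for k in range(seq_len)]
            for q in range(seq_len)]

mask = sliding_window_mask(5, window_size=3)
for row in mask:
    print(''.join('x' if a else '.' for a in row))
# x....
# xx...
# xxx..
# .xxx.
# ..xxx
```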
## Multi-Head Latent Attention

### `lmxlab.core.mla.MLA`

Bases: `AttentionBase`

Multi-Head Latent Attention.

Compresses KV into a low-rank latent for efficient caching. Uses a shared single-head RoPE key (MQA-style) for the position-dependent portion of K.

Config requirements:

- `kv_lora_rank`: Latent dimension for KV compression.
- `q_lora_rank`: Latent dimension for Q compression (optional).
- `rope_dim`: Dimensions allocated for decoupled RoPE.
Source code in src/lmxlab/core/mla.py
__init__(config)
Source code in src/lmxlab/core/mla.py
__call__(x, mask=None, cache=None, rope=None)
Source code in src/lmxlab/core/mla.py
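The caching win shows up in back-of-the-envelope arithmetic, under the assumption (from the description above) that the cache stores one `kv_lora_rank` latent plus one shared `rope_dim` key per token instead of full per-head K and V. The numbers here are illustrative, not package defaults:

```python
# Per-token cache entries: standard MHA KV cache vs. MLA latent cache.
n_heads, head_dim = 16, 64
kv_lora_rank, rope_dim = 256, 32

standard = 2 * n_heads * head_dim      # full K and V for every head
mla = kv_lora_rank + rope_dim          # compressed latent + shared RoPE key
print(standard, mla)                   # 2048 288 -> roughly 7x smaller
```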
## Gated DeltaNet

### `lmxlab.core.deltanet.GatedDeltaNet`

Bases: `AttentionBase`

Gated Delta Network for linear attention.

Uses the delta rule for error-correcting state updates with learned decay and update gates. The state matrix S has fixed size (d_k, d_v) regardless of sequence length, giving O(1) memory per token during inference.
Forward pass per token:

- Project x -> Q, K, V, decay_logits, update_logits
- Apply causal convolution on Q, K, V (local context)
- Compute gates: alpha = sigmoid(decay), beta = sigmoid(update)
- L2 normalize Q, K
- Delta update: S = alpha * S - beta * (S @ k - v) @ k^T
- Output: o = S^T @ q

Cache layout: `(S, conv_state)` where:

- `S`: (B, H, head_dim, head_dim), the state matrix
- `conv_state`: (B, H, K-1, head_dim), conv history
Source code in src/lmxlab/core/deltanet.py
__init__(config)
Source code in src/lmxlab/core/deltanet.py
__call__(x, mask=None, cache=None, rope=None)
Forward pass with delta rule state updates.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `array` | Input (B, L, d_model). | *required* |
| `mask` | `array \| None` | Unused (DeltaNet is inherently causal via recurrent state). | `None` |
| `cache` | `tuple[array, ...] \| None` | Tuple of (S, q_conv, k_conv, v_conv) for autoregressive inference. | `None` |

Returns:

| Type | Description |
|---|---|
| `tuple[array, tuple[array, ...] \| None]` | Tuple of (output, new_cache). |
Source code in src/lmxlab/core/deltanet.py
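A toy, pure-Python version of the delta update makes the error-correcting behavior concrete. This is a sketch, not the module's batched MLX code; for readability the state is written as a (d_v, d_k) matrix so that `S @ k` reads out a value:

```python
# Delta rule: S <- alpha * S - beta * (S @ k - v) k^T.
# With alpha = 1, beta = 1 and a unit-norm key, the updated state
# exactly "stores" the new pair: S_new @ k == v.
def delta_update(S, k, v, alpha, beta):
    d = len(k)
    # prediction error for key k: e = S @ k - v
    e = [sum(S[i][j] * k[j] for j in range(d)) - v[i] for i in range(d)]
    # decay the old state, then apply the rank-1 error correction
    return [[alpha * S[i][j] - beta * e[i] * k[j] for j in range(d)]
            for i in range(d)]

S = [[0.0] * 3 for _ in range(3)]
k = [1.0, 0.0, 0.0]            # unit-norm key (DeltaNet L2-normalizes K)
v = [0.5, -1.0, 2.0]
S = delta_update(S, k, v, alpha=1.0, beta=1.0)
readout = [sum(S[i][j] * k[j] for j in range(3)) for i in range(3)]
print(readout)  # [0.5, -1.0, 2.0]: the state now returns v for key k
```

With alpha < 1 old associations decay; with beta < 1 the correction is only partial, which is what the learned gates control.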
## Feed-Forward Networks

### `lmxlab.core.ffn.StandardFFN`

Bases: `FFNBase`

Standard two-layer feed-forward network with GELU activation.

    FFN(x) = W2 * GELU(W1 * x + b1) + b2
Source code in src/lmxlab/core/ffn.py
__init__(config)
### `lmxlab.core.ffn.GatedFFN`

Bases: `FFNBase`

Gated feed-forward network (SwiGLU variant).

    FFN(x) = W_down * (SiLU(W_gate * x) * W_up * x)

Used in LLaMA, Mistral, etc.
Source code in src/lmxlab/core/ffn.py
__init__(config)
Source code in src/lmxlab/core/ffn.py
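The gating formula in scalar form (illustrative only; the real layer uses learned weight matrices):

```python
import math

def silu(x):
    # SiLU(x) = x * sigmoid(x)
    return x / (1.0 + math.exp(-x))

def gated_ffn_scalar(x, w_gate, w_up, w_down):
    # FFN(x) = W_down * (SiLU(W_gate * x) * W_up * x), with scalar weights
    return w_down * (silu(w_gate * x) * (w_up * x))

print(silu(0.0))  # 0.0 -- the gate closes smoothly around zero
print(gated_ffn_scalar(1.0, 1.0, 2.0, 0.5))
```

The up-projection carries the signal while the SiLU branch acts as a data-dependent gate, which is why three weight matrices replace the usual two.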
## Mixture of Experts

### `lmxlab.core.moe.MoEFFN`

Bases: `Module`

Mixture of Experts feed-forward network.

Routes each token to the top-k experts via a learned router. Each expert is a GatedFFN (SwiGLU).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `BlockConfig` | Block configuration. | *required* |
| `n_experts` | `int \| None` | Total number of experts. | `None` |
| `top_k` | `int \| None` | Number of experts per token. | `None` |
Source code in src/lmxlab/core/moe.py
__init__(config, n_experts=None, top_k=None)
Source code in src/lmxlab/core/moe.py
__call__(x)
Route tokens to top-k experts and combine outputs.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `array` | Input tensor (batch, seq_len, d_model). | *required* |

Returns:

| Type | Description |
|---|---|
| `array` | Output tensor (batch, seq_len, d_model). |
Source code in src/lmxlab/core/moe.py
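Top-k routing for a single token can be sketched in plain Python (illustrative; the module does this with batched MLX ops over all tokens at once):

```python
import math

def route_token(router_logits, expert_outputs, top_k):
    """Pick the top_k experts by router logit, softmax over the selected
    logits only, and return the weighted sum of those experts' outputs."""
    idx = sorted(range(len(router_logits)),
                 key=lambda i: router_logits[i], reverse=True)[:top_k]
    exps = [math.exp(router_logits[i]) for i in idx]
    z = sum(exps)
    weights = [e / z for e in exps]
    return sum(w * expert_outputs[i] for w, i in zip(weights, idx))

# Four experts producing scalar outputs for one token
logits = [2.0, -1.0, 0.5, 2.0]
outputs = [10.0, 20.0, 30.0, 40.0]
print(route_token(logits, outputs, top_k=2))  # experts 0 and 3, equal weight -> 25.0
```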
### `lmxlab.core.moe.SharedExpertMoEFFN`

Bases: `Module`

MoE with shared experts and bias-based load balancing.

Combines a set of always-active shared experts with top-k routed experts. Uses aux-loss-free load balancing: a learnable bias is added to router logits for expert selection, but the original un-biased scores are used for gating weights.

    Output = shared_experts(x) + routed_experts_output(x)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `config` | `BlockConfig` | Block configuration. | *required* |
| `n_experts` | `int \| None` | Number of routed experts. | `None` |
| `top_k` | `int \| None` | Number of routed experts per token. | `None` |
| `n_shared` | `int \| None` | Number of shared (always-active) experts. | `None` |
Source code in src/lmxlab/core/moe.py
__init__(config, n_experts=None, top_k=None, n_shared=None)
Source code in src/lmxlab/core/moe.py
__call__(x)
Route tokens and combine with shared expert output.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `array` | Input tensor (batch, seq_len, d_model). | *required* |

Returns:

| Type | Description |
|---|---|
| `array` | Output tensor (batch, seq_len, d_model). |
Source code in src/lmxlab/core/moe.py
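The aux-loss-free trick in miniature: selection uses the biased logits, but gate weights come from the original scores. A scalar sketch (illustrative, not the module's code):

```python
def select_and_gate(scores, bias, top_k):
    """Select experts by (score + bias), then weight them by the
    original, un-biased scores normalized over the selection."""
    biased = [s + b for s, b in zip(scores, bias)]
    idx = sorted(range(len(scores)), key=lambda i: biased[i],
                 reverse=True)[:top_k]
    total = sum(scores[i] for i in idx)
    return {i: scores[i] / total for i in idx}

scores = [6.0, 3.0, 1.0]     # un-biased router affinities
bias = [-6.0, 0.0, 0.0]      # overloaded expert 0 biased down for selection
print(select_and_gate(scores, bias, top_k=2))
# {1: 0.75, 2: 0.25}: expert 0 loses the selection, gates stay un-biased
```

Because the bias only moves selections (not gradients through gate weights), load balancing needs no auxiliary loss term.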
## Normalization

### `lmxlab.core.norm`

Normalization wrappers for registry use.

#### LayerNorm

Bases: `LayerNorm`

LayerNorm wrapper that constructs from BlockConfig.

Source code in src/lmxlab/core/norm.py

#### RMSNorm

Bases: `RMSNorm`

RMSNorm wrapper that constructs from BlockConfig.

Source code in src/lmxlab/core/norm.py

#### `layer_norm(config)`
## Position Encoding

### `lmxlab.core.position`

Positional encoding modules: RoPE, ALiBi, Sinusoidal.

#### ALiBi

Bases: `Module`

Attention with Linear Biases (Press et al., ICLR 2022, arXiv:2108.12409).

Adds head-specific distance-based biases to the attention mask. Each head gets a fixed slope from a geometric sequence, penalizing distant tokens more strongly. Replaces explicit positional embeddings: no PE is added to inputs. Applied to the attention mask before softmax, not to Q/K.
Source code in src/lmxlab/core/position.py
__call__(mask=None, seq_len=0, cache_len=0)
Create ALiBi-biased attention mask.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `mask` | `array \| None` | Optional causal mask (T_q, T_k). | `None` |
| `seq_len` | `int` | Query sequence length. | `0` |
| `cache_len` | `int` | Cached key sequence length (offset). | `0` |

Returns:

| Type | Description |
|---|---|
| `array` | Combined ALiBi bias + mask (1, H, T_q, T_k). |
Source code in src/lmxlab/core/position.py
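For head counts that are powers of two, the geometric slope schedule from the ALiBi paper is 2^(-8i/n) for head i = 1..n. A quick sketch (illustrative; not necessarily this module's exact code):

```python
def alibi_slopes(n_heads):
    # Geometric sequence from the ALiBi paper: 2^(-8i/n), i = 1..n.
    # Exact for power-of-two head counts.
    ratio = 2.0 ** (-8.0 / n_heads)
    return [ratio ** (i + 1) for i in range(n_heads)]

slopes = alibi_slopes(8)
print(slopes)  # first slope 0.5, last slope 2**-8 = 0.00390625
```

Head h then subtracts `slope[h] * distance` from each attention logit, so heads with large slopes attend locally and heads with tiny slopes stay near-global.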
#### NoPosition

Bases: `Module`

No positional encoding (identity).

Used by architectures that get position information from other mechanisms (e.g. causal convolutions in DeltaNet).

Source code in src/lmxlab/core/position.py

#### RoPE

Bases: `Module`

Rotary Position Embedding wrapper.

Wraps `nn.RoPE` with config-driven initialization.
Source code in src/lmxlab/core/position.py
__call__(q, k, offset=0)
Apply rotary embeddings to queries and keys.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `q` | `array` | Query tensor (batch, heads, seq, head_dim). | *required* |
| `k` | `array` | Key tensor (batch, kv_heads, seq, head_dim). | *required* |
| `offset` | `int` | Position offset for KV cache. | `0` |

Returns:

| Type | Description |
|---|---|
| `tuple[array, array]` | Tuple of (rotated_q, rotated_k). |
Source code in src/lmxlab/core/position.py
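The underlying rotation, sketched for a single vector in pure Python (illustrative; `nn.RoPE` implements this vectorized with configurable base and dimensions):

```python
import math

def rope_rotate(vec, pos, base=10000.0):
    """Rotate consecutive (even, odd) pairs of `vec` by position-dependent
    angles; pair starting at index i uses frequency base**(-i/d)."""
    d = len(vec)
    out = list(vec)
    for i in range(0, d, 2):
        angle = pos * base ** (-i / d)
        c, s = math.cos(angle), math.sin(angle)
        out[i] = vec[i] * c - vec[i + 1] * s
        out[i + 1] = vec[i] * s + vec[i + 1] * c
    return out

# Position 0 rotates by angle 0, i.e. the identity
print(rope_rotate([1.0, 0.0, 0.5, 0.5], pos=0))  # [1.0, 0.0, 0.5, 0.5]
```

Because each pair is a pure rotation, vector norms are preserved and dot products between rotated q and k depend only on the relative offset between their positions.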
#### Sinusoidal

Bases: `Module`

Sinusoidal positional encoding (added to embeddings).
Source code in src/lmxlab/core/position.py
__call__(x)
Add sinusoidal position encoding to input.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `array` | Input tensor (batch, seq_len, d_model). | *required* |

Returns:

| Type | Description |
|---|---|
| `array` | Input with positional encoding added. |
Source code in src/lmxlab/core/position.py
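The classic Transformer formula, as a small sketch (illustrative; the module precomputes this as a table added to the embeddings):

```python
import math

def sinusoidal_pe(pos, d_model, base=10000.0):
    """PE[pos, 2i] = sin(pos / base^(2i/d)),
    PE[pos, 2i+1] = cos(pos / base^(2i/d))."""
    pe = []
    for i in range(0, d_model, 2):
        freq = base ** (-i / d_model)
        pe.append(math.sin(pos * freq))
        pe.append(math.cos(pos * freq))
    return pe

print(sinusoidal_pe(0, 4))  # [0.0, 1.0, 0.0, 1.0] -- position 0
```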
#### `alibi(config)`

#### `rope(config)`
## Quantization

### `lmxlab.core.quantize.quantize_model(model, bits=4, group_size=64, mode='affine')`

Quantize all Linear and Embedding layers in-place.

Uses MLX's native quantization. Linear layers become `nn.QuantizedLinear`, Embedding layers become `nn.QuantizedEmbedding`. Norm layers and other modules are left unchanged.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to quantize (modified in-place). | *required* |
| `bits` | `int` | Bits per weight (2, 4, or 8). | `4` |
| `group_size` | `int` | Quantization group size (32, 64, or 128). | `64` |
| `mode` | `str` | Quantization mode. | `'affine'` |
Source code in src/lmxlab/core/quantize.py
### `lmxlab.core.quantize.dequantize_model(model)`

Dequantize all QuantizedLinear layers back to Linear.

Reconstructs float weights from quantized representation. Useful for fine-tuning after loading quantized weights.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to dequantize (modified in-place). | *required* |
Source code in src/lmxlab/core/quantize.py
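Group-wise affine quantization in miniature, for one group in pure Python (MLX's kernels do this packed and vectorized; this sketch only shows the arithmetic): each group stores a scale and an offset, and the round-trip error is bounded by half a quantization step.

```python
def quantize_group(w, bits):
    """Affine-quantize one group: map [min, max] onto 0 .. 2^bits - 1."""
    levels = (1 << bits) - 1
    lo, hi = min(w), max(w)
    scale = (hi - lo) / levels or 1.0   # avoid zero scale for constant groups
    q = [round((x - lo) / scale) for x in w]
    return q, scale, lo

def dequantize_group(q, scale, lo):
    return [qi * scale + lo for qi in q]

w = [0.1, -0.4, 0.8, 0.3, -0.2, 0.55, 0.0, 0.9]
q, scale, lo = quantize_group(w, bits=4)
w_hat = dequantize_group(q, scale, lo)
err = max(abs(a - b) for a, b in zip(w, w_hat))
print(err <= scale / 2 + 1e-12)  # True: error bounded by half a quant step
```

Smaller groups mean more scales to store but tighter per-group ranges, which is the `group_size` trade-off above.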
## LoRA

### `lmxlab.core.lora.LoRALinear`

Bases: `Module`

Linear layer with low-rank adaptation.

Computes: `y = xW^T + b + scaling * x @ A @ B^T`

where W is frozen and A, B are trainable low-rank matrices. B is zero-initialized so the initial output equals the base linear layer's output.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dims` | `int` | Input feature dimension. | *required* |
| `output_dims` | `int` | Output feature dimension. | *required* |
| `rank` | `int` | LoRA rank (low-rank dimension). | `8` |
| `alpha` | `float` | LoRA scaling factor. Effective scaling = alpha/rank. | `1.0` |
| `bias` | `bool` | Whether the base layer has bias. | `False` |
Source code in src/lmxlab/core/lora.py
__init__(input_dims, output_dims, rank=8, alpha=1.0, bias=False)
Source code in src/lmxlab/core/lora.py
__call__(x)
from_linear(linear, rank=8, alpha=1.0)
classmethod
Create a LoRALinear from an existing nn.Linear.
Copies the base weight and bias, then adds LoRA matrices.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `linear` | `Linear` | Base linear layer to wrap. | *required* |
| `rank` | `int` | LoRA rank. | `8` |
| `alpha` | `float` | LoRA scaling factor. | `1.0` |

Returns:

| Type | Description |
|---|---|
| `LoRALinear` | LoRALinear with the same base weights. |
Source code in src/lmxlab/core/lora.py
#### `to_linear()`

Merge LoRA weights and return a plain nn.Linear.

Computes `W_merged = W + scaling * (A @ B)^T` and returns a new `nn.Linear` with the merged weight.
Source code in src/lmxlab/core/lora.py
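The merge relies on the identity `x @ (W + s*A@B) == x @ W + s * (x @ A) @ B`. A tiny numeric check with hypothetical matrices, in pure Python (the layer itself stores `W` transposed; only the identity matters here):

```python
def matmul(a, b):
    return [[sum(a[i][k] * b[k][j] for k in range(len(b)))
             for j in range(len(b[0]))] for i in range(len(a))]

def add_scaled(w, d, s):
    # elementwise w + s * d
    return [[w[i][j] + s * d[i][j] for j in range(len(w[0]))]
            for i in range(len(w))]

x = [[1.0, 2.0]]                 # one input row
W = [[0.5, -1.0], [2.0, 0.0]]    # base weight, laid out for x @ W
A = [[1.0], [0.0]]               # rank-1 LoRA factors
B = [[3.0, 4.0]]
s = 0.5                          # scaling = alpha / rank

unmerged = add_scaled(matmul(x, W), matmul(matmul(x, A), B), s)[0]
merged = matmul(x, add_scaled(W, matmul(A, B), s))[0]
print(unmerged, merged)  # identical: merging changes weights, not outputs
```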
### `lmxlab.core.lora.apply_lora(model, rank=8, alpha=1.0, targets=None)`

Apply LoRA to a model's linear layers in-place.

Replaces targeted `nn.Linear` layers with `LoRALinear`, freezing the base weights and making only the LoRA matrices trainable.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to modify (in-place). | *required* |
| `rank` | `int` | LoRA rank for all adapted layers. | `8` |
| `alpha` | `float` | LoRA scaling factor. | `1.0` |
| `targets` | `list[str] \| None` | Which submodules to target. | `None` |
Source code in src/lmxlab/core/lora.py
### `lmxlab.core.lora.merge_lora(model)`

Merge LoRA weights into base weights in-place.

Replaces all `LoRALinear` layers with plain `nn.Linear` layers whose weights include the LoRA contribution. After merging, the model has the same output but no LoRA overhead.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model to merge (in-place). | *required* |
Source code in src/lmxlab/core/lora.py
### `lmxlab.core.lora.lora_parameters(model)`

Extract only LoRA parameters from a model.

Useful for saving/loading just the LoRA adapter weights, which are much smaller than the full model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Model with LoRA layers. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict` | Nested dict of only lora_A and lora_B parameters. |
Source code in src/lmxlab/core/lora.py
### `lmxlab.core.lora.save_lora_adapters(path, model, rank=None, alpha=None, metadata=None)`

Save only the LoRA adapter weights to a directory.

Saves lora_A and lora_B parameters as a small safetensors file, much smaller than a full model checkpoint. The adapter can be loaded on top of any compatible base model.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | Directory to save the adapter. | *required* |
| `model` | `Module` | Model with LoRA layers applied. | *required* |
| `rank` | `int \| None` | LoRA rank (saved in config for reference). | `None` |
| `alpha` | `float \| None` | LoRA alpha (saved in config for reference). | `None` |
| `metadata` | `dict[str, Any] \| None` | Additional metadata to include in config. | `None` |
Source code in src/lmxlab/core/lora.py
### `lmxlab.core.lora.load_lora_adapters(path, model)`

Load LoRA adapter weights into a model.

The model must already have LoRA layers applied (via `apply_lora`). This function loads only the lora_A and lora_B weights from a previously saved adapter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `path` | `str \| Path` | Directory containing the adapter files. | *required* |
| `model` | `Module` | Model with LoRA layers to load weights into. | *required* |

Returns:

| Type | Description |
|---|---|
| `dict[str, Any]` | Metadata dict from adapter_config.json. |

Raises:

| Type | Description |
|---|---|
| `FileNotFoundError` | If path does not exist. |
Source code in src/lmxlab/core/lora.py
## QLoRA

### `lmxlab.core.qlora.LoRAQuantizedLinear`

Bases: `Module`

Quantized linear layer with low-rank adaptation.

Computes: `y = quantized_matmul(x, W_q) + bias + scaling * x @ A @ B`

where W_q is a frozen quantized weight and A, B are trainable float LoRA matrices. B is zero-initialized so the initial output equals the quantized layer's output.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `input_dims` | `int` | Input feature dimension. | *required* |
| `output_dims` | `int` | Output feature dimension. | *required* |
| `rank` | `int` | LoRA rank (low-rank dimension). | `8` |
| `alpha` | `float` | LoRA scaling factor. Effective scaling = alpha/rank. | `1.0` |
| `bias` | `bool` | Whether the layer has bias. | `False` |
| `group_size` | `int` | Quantization group size. | `64` |
| `bits` | `int` | Quantization bits. | `4` |
| `mode` | `str` | Quantization mode. | `'affine'` |
Source code in src/lmxlab/core/qlora.py
__init__(input_dims, output_dims, rank=8, alpha=1.0, bias=False, group_size=64, bits=4, mode='affine')
Source code in src/lmxlab/core/qlora.py
__call__(x)
Source code in src/lmxlab/core/qlora.py
from_quantized(ql, rank=8, alpha=1.0)
classmethod
Create from an existing QuantizedLinear layer.
Copies the quantized weights and adds LoRA adapters.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `ql` | `QuantizedLinear` | Quantized linear layer to wrap. | *required* |
| `rank` | `int` | LoRA rank. | `8` |
| `alpha` | `float` | LoRA scaling factor. | `1.0` |

Returns:

| Type | Description |
|---|---|
| `LoRAQuantizedLinear` | LoRAQuantizedLinear with same quantized base weights. |
Source code in src/lmxlab/core/qlora.py
### `lmxlab.core.qlora.apply_qlora(model, rank=8, alpha=1.0, targets=None)`

Apply QLoRA to a quantized model's layers in-place.

Replaces targeted `nn.QuantizedLinear` layers with `LoRAQuantizedLinear`, keeping base weights quantized and adding trainable float LoRA matrices.

The model should already be quantized (via `quantize_model` or `nn.quantize`) before calling this.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `model` | `Module` | Quantized model to modify (in-place). | *required* |
| `rank` | `int` | LoRA rank for all adapted layers. | `8` |
| `alpha` | `float` | LoRA scaling factor. | `1.0` |
| `targets` | `list[str] \| None` | Which submodules to target. | `None` |
Source code in src/lmxlab/core/qlora.py
## Registry

### `lmxlab.core.registry.Registry`

A typed registry mapping string names to factory functions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `name` | `str` | Human-readable name for this registry. | *required* |
Example:

    reg = Registry[type]('attention')

    @reg.register('mha')
    class MHA: ...

    reg.get('mha')  # -> MHA
Source code in src/lmxlab/core/registry.py
Attributes:

- `_entries = {}`
- `_name = name`
- `name` (property): Registry name.
__contains__(key)
__init__(name)
__repr__()
get(key)
Look up a component by key.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `str` | Registered name. | *required* |

Returns:

| Type | Description |
|---|---|
| `T` | The registered component. |

Raises:

| Type | Description |
|---|---|
| `KeyError` | If key is not found. |
Source code in src/lmxlab/core/registry.py
keys()
register(key, value=None)
Register a component under a key.
Can be used as a decorator factory or called directly:

    @registry.register('name')
    class MyClass: ...

    registry.register('name', MyClass)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `key` | `str` | String name for lookup. | *required* |
| `value` | `T \| None` | The component to register. If None, returns a decorator that registers the decorated object. | `None` |

Returns:

| Type | Description |
|---|---|
| `T \| Callable[[T], T]` | The registered value, or a decorator if value is None. |

Raises:

| Type | Description |
|---|---|
| `ValueError` | If key is already registered. |
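The core pattern can be reproduced in a few lines of plain Python (a sketch of the idea, not this module's exact code):

```python
class MiniRegistry:
    """Name -> component mapping with decorator or direct registration."""
    def __init__(self, name):
        self._name = name
        self._entries = {}

    def register(self, key, value=None):
        if key in self._entries:
            raise ValueError(f'{key!r} already registered in {self._name}')
        if value is None:                   # decorator-factory form
            def deco(obj):
                self._entries[key] = obj
                return obj
            return deco
        self._entries[key] = value          # direct form
        return value

    def get(self, key):
        try:
            return self._entries[key]
        except KeyError:
            raise KeyError(f'{key!r} not found in registry {self._name}')

reg = MiniRegistry('attention')

@reg.register('mha')
class MHA: ...

print(reg.get('mha') is MHA)  # True
```

Registering by string name is what lets `BlockConfig` describe an architecture purely as data.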