MLX Development Guide
Environment Setup
Use uv for Python environment and package management:
# Install MLX
uv add mlx
# Run MLX scripts
uv run python train.py
# Run with specific dependencies
uv run --with mlx python script.py
Critical Rules
1. Lazy Evaluation - Always Evaluate at Loop Boundaries
Operations build a graph; nothing computes until mx.eval():
# CORRECT: Evaluate at iteration boundaries
for batch in dataset:
    loss, grads = value_and_grad_fn(model, batch)
    optimizer.update(model, grads)
    mx.eval(loss, model.parameters())  # ALL computation happens here

# WRONG: Evaluating too frequently
for _ in range(100):
    a = a + b
    mx.eval(a)  # Massive overhead!
Implicit eval triggers: print(a), a.item(), np.array(a), if a > 0:.
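For example (a minimal sketch; the arrays are illustrative), a Python-level branch on an array value forces the whole pending graph to evaluate at that point:

a = mx.random.normal((1024, 1024)) @ mx.random.normal((1024, 1024))  # still lazy
if a.sum() > 0:        # __bool__ conversion evaluates the matmul and the sum here
    print("positive")  # print(a) or a.sum().item() would trigger the same evaluation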
2. Array Indexing Differs from NumPy
# Lists must be mx.array
a[[0, 1]] # ValueError!
a[mx.array([0, 1])] # Works
# Slice indices must be Python ints
i = mx.array(2)
x[i:i+2] # ValueError!
x[i.item():i.item()+2] # Works (forces eval)
# Slices create COPIES, not views (opposite of NumPy)
b = a[:]
b[2] = 0 # a is unchanged!
# Boolean mask READS not supported
a[mask] # Not supported - use mx.where()
# No bounds checking - out-of-bounds returns garbage
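Since boolean-mask reads are unsupported, the usual workaround (a minimal sketch) is to keep the full shape and select with mx.where, filling the masked-out slots with a neutral value:

a = mx.array([1.0, -2.0, 3.0])
mask = a > 0
kept = mx.where(mask, a, 0.0)   # [1.0, 0.0, 3.0] - same shape, masked values zeroed
total = kept.sum()              # reductions then ignore the neutral fill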
For accumulating updates, use at[] syntax:
a = a.at[idx].add(1) # Properly accumulates at duplicate indices
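A small sketch of why this matters (values are illustrative): plain indexed assignment writes each duplicate index only once, while at[].add() accumulates:

idx = mx.array([0, 0, 2])
assigned = mx.zeros((4,))
assigned[idx] = assigned[idx] + 1            # [1, 0, 1, 0] - the duplicate index 0 collapses
accumulated = mx.zeros((4,)).at[idx].add(1)  # [2, 0, 1, 0] - duplicates accumulate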
See references/array-indexing.md for complete patterns.
3. Neural Networks: NHWC Format and __call__
# Conv2d uses NHWC (not NCHW like PyTorch)
x_mlx = mx.array(x_torch.numpy().transpose(0, 2, 3, 1))
# Override __call__, not forward()
class MyModel(nn.Module):
    def __call__(self, x):  # NOT forward()
        return self.layer(x)
# No dtype in constructors - use set_dtype()
layer = nn.Linear(10, 10)
layer.set_dtype(mx.bfloat16)
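Tying the layout rule to a layer call, a minimal sketch (shapes are illustrative):

conv = nn.Conv2d(3, 16, kernel_size=3, padding=1)
x = mx.zeros((8, 32, 32, 3))   # NHWC: (batch, height, width, channels)
y = conv(x)                    # (8, 32, 32, 16) - channels stay in the last axis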
See references/neural-networks.md for layer equivalents.
4. Data Types: float64 is CPU-Only
a = mx.array([1.0], dtype=mx.float64)
mx.exp(a, stream=mx.gpu) # RuntimeError!
# Solutions:
mx.exp(a, stream=mx.cpu)
mx.exp(a.astype(mx.float32))
# bfloat16 from external sources gets misinterpreted
from ml_dtypes import bfloat16
x = np.array(1., dtype=bfloat16)
mx.array(x) # Returns complex64!
mx.array(x.astype(np.float32), dtype=mx.bfloat16) # Correct
See references/dtypes.md for full type support table.
5. Compilation: Capture All Mutable State
from functools import partial
state = [model.state, optimizer.state, mx.random.state] # Include random!
@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss
# No print() in compiled functions - crashes during tracing
# String decoding triggers recompilation - decode outside loop
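The corresponding logging pattern (a sketch; train_step, state, and the batch come from the surrounding training loop) is to return the loss, evaluate, and only then print outside the compiled function:

loss = train_step(x_batch, y_batch)
mx.eval(state)
print(f"loss: {loss.item():.4f}")  # printing and string formatting stay outside the traced code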
See references/compilation.md for recompilation triggers.
Quick Reference Tables
Dtype Support
| Type | GPU | Notes |
|------|-----|-------|
| float32 | Yes | Default float |
| float16 | Yes | |
| bfloat16 | Yes | M3+ recommended |
| float64 | CPU only | GPU throws! |
| int8-64, uint8-64 | Yes | |
| complex64 | Partial | No matmul |
PyTorch → MLX Equivalents
| PyTorch | MLX |
|---------|-----|
| tensor.to('cuda') | Not needed (unified memory) |
| nn.forward() | nn.__call__() |
| NCHW format | NHWC format |
| torch.gather() | mx.take_along_axis() |
| torch.scatter_add_() | arr.at[idx].add() |
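A hedged sketch of the last two rows (values are illustrative):

a = mx.array([[1, 2], [3, 4]])
idx = mx.array([[0], [1]])
gathered = mx.take_along_axis(a, idx, axis=1)  # [[1], [4]], like torch.gather(a, 1, idx)

out = mx.zeros((3,))
out = out.at[mx.array([0, 0, 2])].add(mx.array([1.0, 2.0, 3.0]))  # [3.0, 0.0, 3.0]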
Not Available in MLX
- np.nonzero() - restructure algorithm
- np.unique() - pre-sort or use dicts
- arr[bool_mask] read - use mx.where()
- np.linalg.det(), np.linalg.lstsq()
Performance Notes
- Transformers: MLX typically 2-3x faster than PyTorch MPS
- Convolutions: 10-150x SLOWER than PyTorch MPS (known limitation)
- LLM inference: Excellent, especially quantized
- Use float16/bfloat16 for 2x memory bandwidth
- Use 4-bit quantization for LLMs (4x bandwidth) - a short sketch of both follows below
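A minimal sketch of both options, assuming a model built from nn.Module layers (nn.quantize swaps supported layers for quantized versions in place):

model = Model()  # any nn.Module, e.g. the training example below

# Option 1: half-precision weights for roughly half the memory traffic
model.set_dtype(mx.bfloat16)

# Option 2: 4-bit weight quantization, typically for LLM inference
nn.quantize(model, group_size=64, bits=4)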
See Also
- references/array-indexing.md - Complete indexing patterns, at[] syntax, gather/scatter
- references/neural-networks.md - nn module, layers, weight format
- references/compilation.md - mx.compile patterns, state capture, recompilation
- references/memory-management.md - Memory APIs, debugging leaks
- references/dtypes.md - Type support, conversion patterns
- references/random.md - Seed vs key, splitting patterns
- references/gradients.md - Autodiff, value_and_grad, control flow
- references/pytorch-migration.md - Weight conversion, format changes
- references/error-decoder.md - Common errors → solutions
Idiomatic Training Example
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim
from functools import partial
class Model(nn.Module):
    def __init__(self):
        super().__init__()
        self.layers = [nn.Linear(784, 256), nn.Linear(256, 10)]

    def __call__(self, x):
        for layer in self.layers[:-1]:
            x = mx.maximum(layer(x), 0)
        return self.layers[-1](x)

def loss_fn(model, x, y):
    return nn.losses.cross_entropy(model(x), y, reduction="mean")

model = Model()
optimizer = optim.AdamW(learning_rate=1e-3)
state = [model.state, optimizer.state, mx.random.state]

@partial(mx.compile, inputs=state, outputs=state)
def train_step(x, y):
    loss, grads = nn.value_and_grad(model, loss_fn)(model, x, y)
    optimizer.update(model, grads)
    return loss

for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        loss = train_step(x_batch, y_batch)
        mx.eval(state)
    print(f"Epoch {epoch}: {loss.item():.4f}")