Batching & JIT Guide¶
Efficient JAX patterns for hyperbolic deep learning with vmap-native APIs and JIT compilation.
Overview¶
Hyperbolix adopts a vmap-native API design where all manifold functions operate on single points/vectors. This design provides maximum flexibility and composability with JAX's transformation system.
Key Design Principles

- Functions operate on single points with shape (dim,) or (dim+1,) (ambient)
- Use jax.vmap for batch operations
- Use jax.jit for compilation with appropriate static arguments
- No built-in axis or keepdim parameters — compose transformations explicitly (see the sketch below)
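Because there is no axis or keepdim argument, shape handling is composed explicitly as well. A minimal sketch using plain jnp operations (the norm reduction here is only illustrative):
import jax
import jax.numpy as jnp

x_batch = jnp.ones((8, 2)) * 0.1            # (8, 2) batch of points
norms = jax.vmap(jnp.linalg.norm)(x_batch)  # (8,) -- no axis/keepdim flags needed
norms_kept = norms[:, None]                 # (8, 1) -- the "keepdims" shape, added explicitly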
The vmap-Native API¶
Single Point Operations¶
All manifold methods work with individual points:
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Single points (intrinsic coordinates)
x = jnp.array([0.1, 0.2]) # Shape: (2,)
y = jnp.array([0.3, 0.4]) # Shape: (2,)
# Compute distance between two points
distance = poincare.dist(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
print(distance) # Scalar
# Exponential map from origin
v = jnp.array([0.5, 0.0]) # Tangent vector at origin
point = poincare.expmap_0(v, c=1.0)
print(point.shape) # (2,)
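A quick sanity check on the result, assuming the usual convention that the Poincaré ball with curvature parameter c has radius 1/sqrt(c):
# expmap_0 should land strictly inside the ball of radius 1/sqrt(c)
radius = 1.0 / jnp.sqrt(1.0)             # c = 1.0 here
print(jnp.linalg.norm(point) < radius)   # True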
Batching with vmap¶
Use jax.vmap to process batches efficiently:
import jax
poincare = Poincare()
# Batch of points
x_batch = jnp.array([[0.1, 0.2], [0.15, 0.25], [0.05, 0.1]]) # (3, 2)
y_batch = jnp.array([[0.3, 0.4], [0.35, 0.45], [0.2, 0.3]]) # (3, 2)
# Option 1: Explicit vmap
dist_fn = jax.vmap(poincare.dist, in_axes=(0, 0, None, None))
distances = dist_fn(x_batch, y_batch, 1.0, poincare.VERSION_MOBIUS_DIRECT)
print(distances.shape) # (3,)
# Option 2: Inline vmap
distances = jax.vmap(
lambda x, y: poincare.dist(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
)(x_batch, y_batch)
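Because the underlying function takes single points, nesting two vmaps yields a full pairwise distance matrix. A small sketch reusing x_batch and y_batch from above:
# Outer vmap over rows of x_batch, inner vmap over rows of y_batch
pairwise = jax.vmap(
    lambda x: jax.vmap(
        lambda y: poincare.dist(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
    )(y_batch)
)(x_batch)
print(pairwise.shape)  # (3, 3)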
Understanding in_axes¶
The in_axes parameter specifies which axes to map over:
# in_axes=(0, 0, None, None) means:
# - Map over axis 0 of first argument (x_batch)
# - Map over axis 0 of second argument (y_batch)
# - Don't map over curvature (c) — use same value for all
# - Don't map over version_idx — static argument
Common patterns:
poincare = Poincare()
# Project batch of points
x_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 16))
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x_batch, 1.0)
# Compute distances from single point to batch
origin = jnp.zeros(16)
x_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 16)) * 0.3
distances = jax.vmap(
lambda x: poincare.dist(origin, x, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
)(x_batch)
print(distances.shape) # (100,)
# Exponential map with batch of tangent vectors
v_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 16))
base_point = jnp.zeros(16)
points = jax.vmap(
lambda v: poincare.expmap(v, base_point, c=1.0)
)(v_batch)
print(points.shape) # (100, 16)
JIT Compilation¶
Basic JIT Usage¶
Use jax.jit to compile functions for 10-100x speedup:
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Without JIT
distance = poincare.dist(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
# With JIT (version_idx is static since it controls which kernel to run)
dist_jit = jax.jit(poincare.dist, static_argnames=['version_idx'])
distance = dist_jit(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
JIT Performance
- First call: Slow (compilation overhead, 100ms-1s)
- Subsequent calls: Fast (10-100x speedup)
- Most beneficial for large batches (1000+) and high dimensions (128+)
Static vs Dynamic Arguments¶
Static arguments are known at compile time and trigger recompilation if changed:
# version_idx is static (integer constant)
dist_jit = jax.jit(poincare.dist, static_argnames=['version_idx'])
# These compile once and reuse:
d1 = dist_jit(x1, y1, c=1.0, version_idx=0)
d2 = dist_jit(x2, y2, c=1.5, version_idx=0) # Reuses compilation
# This triggers recompilation (different version_idx):
d3 = dist_jit(x3, y3, c=1.0, version_idx=1)
Dynamic arguments can change without recompilation:
# Curvature 'c' is dynamic (can vary)
d1 = dist_jit(x1, y1, c=1.0, version_idx=0)
d2 = dist_jit(x2, y2, c=2.5, version_idx=0) # No recompilation needed
Learnable Curvature
Keep curvature parameter c dynamic (not static) to support gradient-based learning of curvature values during training.
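Because c stays dynamic, you can differentiate with respect to it. A minimal sketch (the squared-distance loss is purely illustrative):
def loss_wrt_curvature(c, x, y):
    # Any loss that depends on the curvature; here a squared hyperbolic distance
    return poincare.dist(x, y, c=c, version_idx=poincare.VERSION_MOBIUS_DIRECT) ** 2

x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, 0.4])
dloss_dc = jax.grad(loss_wrt_curvature)(1.0, x, y)  # gradient with respect to curvature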
Combining vmap and jit¶
The order matters for performance:
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Pattern 1: JIT then vmap (RECOMMENDED)
@jax.jit
def distance_fn(x, y, c):
    return poincare.dist(x, y, c, version_idx=poincare.VERSION_MOBIUS_DIRECT)

distances = jax.vmap(distance_fn, in_axes=(0, 0, None))(x_batch, y_batch, 1.0)
# Pattern 2: vmap then JIT
dist_batched = jax.vmap(poincare.dist, in_axes=(0, 0, None, None))
distances = jax.jit(dist_batched, static_argnames=['version_idx'])(
x_batch, y_batch, 1.0, poincare.VERSION_MOBIUS_DIRECT
)
# Pattern 3: Combined (one-liner)
distances = jax.jit(
jax.vmap(poincare.dist, in_axes=(0, 0, None, None)),
static_argnames=['version_idx']
)(x_batch, y_batch, 1.0, poincare.VERSION_MOBIUS_DIRECT)
Best Practice
JIT the inner single-point function and apply vmap on the outside for the best combination of performance and flexibility.
Neural Network Patterns¶
Forward Pass¶
Flax NNX layers automatically handle batching:
from flax import nnx
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Create layer
layer = HypLinearPoincare(
manifold_module=poincare,
in_dim=128,
out_dim=64,
rngs=nnx.Rngs(0)
)
# Batch input: (batch_size, in_dim)
x_batch = jax.random.normal(jax.random.PRNGKey(1), (32, 128)) * 0.3
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x_batch, 1.0)
# Forward pass handles batching internally
output = layer(x_proj, c=1.0)
print(output.shape) # (32, 64)
Activations with vmap¶
Hyperbolic activations are functional and need explicit batching:
from hyperbolix.nn_layers import hyp_relu
# Single point (ambient coordinates, d+1 dims for hyperboloid)
x = jnp.array([1.5, 0.2, 0.3, 0.1]) # Ambient coordinates (4,)
activated = hyp_relu(x, c=1.0)
# Batch of points - use vmap (random values are illustrative; real inputs should lie on the hyperboloid)
x_batch = jax.random.normal(jax.random.PRNGKey(0), (32, 4))
activated_batch = jax.vmap(lambda x: hyp_relu(x, c=1.0))(x_batch)
print(activated_batch.shape) # (32, 4)
Complete Model with JIT¶
from hyperbolix.nn_layers import HypLinearPoincare, hyp_relu
from hyperbolix.manifolds import Poincare
poincare = Poincare()
class HyperbolicClassifier(nnx.Module):
    def __init__(self, rngs):
        self.layer1 = HypLinearPoincare(poincare, 784, 256, rngs=rngs)
        self.layer2 = HypLinearPoincare(poincare, 256, 128, rngs=rngs)
        self.layer3 = HypLinearPoincare(poincare, 128, 10, rngs=rngs)

    def __call__(self, x, c=1.0):
        x = self.layer1(x, c)
        # vmap activation over batch
        x = jax.vmap(lambda xi: hyp_relu(xi, c))(x)
        x = self.layer2(x, c)
        x = jax.vmap(lambda xi: hyp_relu(xi, c))(x)
        x = self.layer3(x, c)
        return x
# Create model
model = HyperbolicClassifier(rngs=nnx.Rngs(0))
# JIT the forward pass
@jax.jit
def forward(model, x, c):
    return model(x, c)
# Use with batch
x_batch = jax.random.normal(jax.random.PRNGKey(1), (32, 784)) * 0.1
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x_batch, 1.0)
logits = forward(model, x_proj, c=1.0)
print(logits.shape) # (32, 10)
Training Loop Patterns¶
Efficient Training Step¶
from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.optim import riemannian_adam
poincare = Poincare()
@jax.jit
def train_step(model, optimizer, x_batch, y_batch, c):
    """Single training step with JIT compilation."""
    def loss_fn(model):
        preds = model(x_batch, c)
        return jnp.mean((preds - y_batch) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss

# Training loop
for epoch in range(num_epochs):
    for x_batch, y_batch in dataloader:
        # Project to manifold
        x_batch = jax.vmap(poincare.proj, in_axes=(0, None))(x_batch, 1.0)
        # Single JIT-compiled step
        loss = train_step(model, optimizer, x_batch, y_batch, c=1.0)
    print(f"Loss: {loss:.4f}")
Performance Optimization Tips¶
1. Profile Before Optimizing¶
import time

# Warmup JIT compilation
_ = dist_jit(x, y, c=1.0, version_idx=0)

# Time subsequent calls; block_until_ready ensures we measure the computation,
# not just JAX's asynchronous dispatch
start = time.time()
for _ in range(1000):
    result = dist_jit(x, y, c=1.0, version_idx=0)
result.block_until_ready()
elapsed = time.time() - start
print(f"Time per call: {elapsed/1000*1e6:.2f} µs")
2. Minimize Recompilation¶
# BAD: Different shapes trigger recompilation
d1 = dist_jit(x1, y1, c=1.0, version_idx=0) # Compile for shape (16,)
d2 = dist_jit(x2, y2, c=1.0, version_idx=0) # Recompile for shape (32,)
# GOOD: Batch inputs of one consistent shape so a single compiled version is reused
x_batch = jnp.array([[0.1, 0.2], [0.3, 0.4]])   # (2, 2) batch of points
y_batch = jnp.array([[0.3, 0.4], [0.1, 0.2]])   # (2, 2) batch of points
distances = jax.vmap(dist_jit, in_axes=(0, 0, None, None))(
    x_batch, y_batch, 1.0, poincare.VERSION_MOBIUS_DIRECT
)
3. Use Static Arguments Appropriately¶
# GOOD: Keep curvature dynamic
@jax.jit
def process_batch(x_batch, c):
    return jax.vmap(
        lambda x: poincare.proj(x, c)  # Simple projection
    )(x_batch)

# BAD: Making everything static reduces flexibility
@jax.jit
def process_batch_bad(x_batch):  # c=1.0 hardcoded
    return jax.vmap(
        lambda x: poincare.proj(x, c=1.0)
    )(x_batch)  # Can't change curvature without recompilation
4. Batch Size Considerations¶
# Small batches: Less JIT benefit
x_small = jax.random.normal(jax.random.PRNGKey(0), (10, 128))
# ~10-20x speedup
# Large batches: Maximum JIT benefit
x_large = jax.random.normal(jax.random.PRNGKey(0), (1000, 128))
# ~50-100x speedup
5. Memory vs Computation Trade-offs¶
# Memory-efficient: Process in chunks
def process_large_batch(x_batch, chunk_size=1000):
    n = len(x_batch)
    results = []
    for i in range(0, n, chunk_size):
        chunk = x_batch[i:i+chunk_size]
        results.append(jax.vmap(some_fn)(chunk))
    return jnp.concatenate(results)

# Compute-efficient: Process all at once (may OOM)
def process_all_at_once(x_batch):
    return jax.vmap(some_fn)(x_batch)
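A middle ground is jax.lax.map, which applies the function one element at a time inside a single compiled program, trading throughput for a smaller peak memory footprint. A sketch, assuming some_fn takes a single point as above:
@jax.jit
def process_sequential(x_batch):
    # lax.map runs some_fn sequentially over the leading axis instead of all at once
    return jax.lax.map(some_fn, x_batch)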
Common Pitfalls¶
Pitfall 1: Forgetting to vmap Activations¶
# UNCLEAR: hyp_relu is defined on single points, so applying it directly to a
# batch relies on implicit broadcasting and leaves the per-point intent ambiguous
x = layer(x_batch, c=1.0)       # (batch, dim+1) for hyperboloid
activated = hyp_relu(x, c=1.0)  # may work, but semantics are unclear

# CORRECT: Explicit vmap makes the per-point application unambiguous
activated = jax.vmap(lambda xi: hyp_relu(xi, c=1.0))(x)

# If hyp_relu does accept (batch, dim+1) inputs directly, calling it on the batch
# also works, but the explicit vmap keeps the code consistent with the rest of the API
activated = hyp_relu(x, c=1.0)
Pitfall 2: Shape Mismatches with vmap¶
# WRONG: Mapped arguments with incompatible batch sizes
x_batch = jnp.array([[0.1, 0.2]])   # (1, 2)
y_batch = jnp.array([[0.3, 0.4]])   # (1, 2)
c_batch = jnp.array([1.0, 1.5])     # (2,)
distances = jax.vmap(
    lambda x, y, c: poincare.dist(x, y, c=c, version_idx=poincare.VERSION_MOBIUS_DIRECT)
)(x_batch, y_batch, c_batch)        # Error: batch size mismatch, 1 vs 2

# CORRECT: Broadcast the curvature by not mapping over it
distances = jax.vmap(
    lambda x, y: poincare.dist(x, y, c=1.0, version_idx=poincare.VERSION_MOBIUS_DIRECT)
)(x_batch, y_batch)
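If a different curvature per example is actually intended, map over c as well and give every mapped argument the same leading batch size. A sketch:
# One curvature per pair -- all mapped arguments share batch size 2
x_batch = jnp.array([[0.1, 0.2], [0.15, 0.25]])   # (2, 2)
y_batch = jnp.array([[0.3, 0.4], [0.35, 0.45]])   # (2, 2)
c_batch = jnp.array([1.0, 1.5])                    # (2,)
distances = jax.vmap(
    lambda x, y, c: poincare.dist(x, y, c=c, version_idx=poincare.VERSION_MOBIUS_DIRECT)
)(x_batch, y_batch, c_batch)
print(distances.shape)  # (2,)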
Pitfall 3: Static Curvature¶
# WRONG: Can't learn curvature
@jax.jit
def model_forward(x, c=1.0):  # c fixed at compile time
    return poincare.proj(x, c)

# CORRECT: Keep c dynamic
@jax.jit
def model_forward(x, c):  # c can vary
    return poincare.proj(x, c)
Benchmark Results¶
Typical speedups on M1/M2 Mac or modern GPU:
| Operation | Batch Size | No JIT | With JIT | Speedup |
|---|---|---|---|---|
| Distance (dim=128) | 100 | 12 ms | 0.8 ms | 15x |
| Distance (dim=128) | 1000 | 120 ms | 1.5 ms | 80x |
| Expmap (dim=256) | 100 | 18 ms | 1.2 ms | 15x |
| Linear layer forward | 1000 | 45 ms | 2.1 ms | 21x |
| Full model (3 layers) | 1000 | 150 ms | 6.5 ms | 23x |
Run the benchmarks yourself to verify these numbers on your own hardware.
See Also¶
- Manifolds API: Manifold function signatures
- NN Layers API: Layer implementations
- Training Workflows: Complete training examples
- Numerical Stability: Float precision considerations