Numerical Stability Guide¶
Best practices for maintaining numerical precision in hyperbolic operations.
Overview¶
Hyperbolic geometry presents unique numerical challenges due to the exponential growth of the conformal factor near the boundary and the involvement of hyperbolic functions (cosh, sinh, atanh). This guide explains these challenges and provides strategies to maintain numerical stability.
Key Challenges
- Conformal factor explosion: λ(x) grows exponentially as points approach the boundary
- Float32 limitations: ~7 significant digits, insufficient for large distances (>10)
- Hyperbolic function overflow: cosh/sinh overflow for large arguments
- Division by near-zero: Operations involving 1 - c||x||² near the boundary
Float Precision: Float32 vs Float64¶
When to Use Each¶
Float32 (default):

- Sufficient for most applications with small to moderate distances (< 5)
- 2-4x faster on GPU
- Lower memory footprint (important for large models)
- ~7 significant decimal digits

Float64 (high precision):

- Required for large distances (> 10) or near-boundary points
- Better numerical stability in edge cases
- ~15-16 significant decimal digits
- Use for research, validation, or stability-critical applications
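Note that JAX keeps all values in 32-bit types by default, regardless of the dtype you request; to actually get float64 arrays, enable x64 mode before creating them (this is a standard JAX setting, not specific to Hyperbolix):

```python
# JAX disables 64-bit types unless x64 mode is enabled
import jax
jax.config.update("jax_enable_x64", True)

import jax.numpy as jnp

x = jnp.array([0.1, 0.2], dtype=jnp.float64)
print(x.dtype)  # float64
```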
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare
# Float32 (default)
poincare_f32 = Poincare()
x = jnp.array([0.1, 0.2])
y = jnp.array([0.8, 0.5])
dist = poincare_f32.dist(x, y, c=1.0)
# Float64 (high precision) — inputs are automatically cast
poincare_f64 = Poincare(dtype=jnp.float64)
dist = poincare_f64.dist(x, y, c=1.0) # returns float64
Precision Requirements by Distance¶
| Distance from Origin | Float32 Accuracy | Recommended Precision |
|---|---|---|
| d < 3 | Excellent (< 0.01% error) | float32 |
| 3 ≤ d < 5 | Good (< 0.1% error) | float32 |
| 5 ≤ d < 10 | Moderate (< 3% error) | float64 for critical ops |
| d ≥ 10 | Poor (> 3% error) | float64 required |
Quick Check
If your embeddings lie at distances greater than about 7 from the origin, switch to float64.
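One way to run this check without any library calls is the closed-form Poincaré distance from the origin, d(0, x) = (2/√c)·atanh(√c‖x‖); the helper name and threshold below are illustrative:

```python
import jax.numpy as jnp

def max_origin_distance(x_batch, c=1.0):
    # Closed-form Poincare distance from the origin:
    # d(0, x) = (2 / sqrt(c)) * atanh(sqrt(c) * ||x||)
    norms = jnp.linalg.norm(x_batch, axis=-1)
    return float(jnp.max(2.0 / jnp.sqrt(c) * jnp.arctanh(jnp.sqrt(c) * norms)))

x_batch = jnp.array([[0.1, 0.2], [0.9, 0.3]])
needs_float64 = max_origin_distance(x_batch) > 7.0
```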
The Conformal Factor Problem¶
Understanding λ(x)¶
The conformal factor in Poincaré ball geometry is λ_c(x) = 2 / (1 − c‖x‖²). This factor appears in:

- Exponential map: scales tangent vectors
- Logarithmic map: scales back to tangent space
- Riemannian gradient: converts Euclidean gradients to Riemannian gradients
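For instance, because the Poincaré metric is conformal (g_x = λ(x)² I), converting a Euclidean gradient to a Riemannian one is just a rescaling. A minimal sketch of that last point (the helper name is ours, not a library function):

```python
import jax.numpy as jnp

def riemannian_grad_sketch(euclid_grad, x, c=1.0):
    # The metric is g_x = lambda(x)^2 * I, so the Riemannian gradient is
    # the Euclidean gradient divided by lambda(x)^2
    lam = 2.0 / (1.0 - c * jnp.sum(x * x))
    return euclid_grad / lam**2
```

At the origin λ = 2, so the Riemannian gradient is a quarter of the Euclidean one; near the boundary it shrinks much further, which is exactly why large conformal factors interact badly with float32.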
Exponential Growth¶
As points move toward the boundary (||x|| → 1/√c), λ(x) explodes:
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare
poincare = Poincare()
c = 1.0
distances = [0, 1, 2, 3, 5, 7, 10]
for d in distances:
    # Point at distance d from origin
    x = poincare.expmap_0(jnp.array([d, 0.0]), c=c)
    norm = jnp.linalg.norm(x)
    lambda_x = 2.0 / (1.0 - c * norm**2)
    print(f"d={d:2d}: ||x||={norm:.6f}, λ(x)={lambda_x:10.1f}")
Output (approximate; since λ(x) = 2·cosh²(d), exact float32 values vary in the last digits near the boundary):

d= 0: ||x||=0.000000, λ(x)=       2.0
d= 1: ||x||=0.761594, λ(x)=       4.8
d= 2: ||x||=0.964028, λ(x)=      28.3
d= 3: ||x||=0.995055, λ(x)=     202.7
d= 5: ||x||=0.999909, λ(x)=   11014.2
d= 7: ||x||=0.999998, λ(x)=  601303.1
d=10: ||x||=1.000000, λ(x)=       inf
Numerical Issues¶
Problem 1: Precision loss in logmap
# logmap divides by λ(x), then later operations multiply by λ(x)
# With float32 and λ(x) ≈ 10,000:
# - Division by 10,000 loses 4 digits of precision
# - Multiplication by 10,000 doesn't recover them
# Result: ~3 digits of precision remaining (out of 7)
Problem 2: Cancellation in 1 - c||x||²
# Near boundary: ||x||² ≈ 0.999999
# Computing 1 - c||x||² loses significant digits due to catastrophic cancellation
# Float32: 1.0 - 0.999999 = 0.000001 (but stored imprecisely!)
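A tiny NumPy demonstration of the cancellation, with dtypes pinned explicitly to make the effect visible:

```python
import numpy as np

norm = 1.0 - 1e-8  # a point extremely close to the boundary

# float32: norm rounds to exactly 1.0, so all information is lost
res32 = np.float32(1.0) - np.float32(norm) ** 2   # 0.0

# float64: the true value 1 - norm^2 ~ 2e-8 survives
res64 = np.float64(1.0) - np.float64(norm) ** 2
```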
Mitigation Strategies¶
1. Use projection after operations
from hyperbolix.manifolds import Poincare
import jax.numpy as jnp

poincare = Poincare()
x = jnp.array([0.6, 0.5])
y = jnp.array([0.4, 0.3])

# After Möbius addition or other operations
result = poincare.addition(x, y, c=1.0)
result = poincare.proj(result, c=1.0)  # Project back to manifold
2. Keep points away from boundary
import jax
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

poincare = Poincare()

# During initialization
def init_hyperbolic_embeddings(key, n_points, dim, max_norm=0.8):
    """Initialize embeddings safely away from the boundary."""
    x = jax.random.normal(key, (n_points, dim)) * 0.1
    x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)
    # Clip to max_norm to avoid the boundary
    norms = jnp.linalg.norm(x_proj, axis=-1, keepdims=True)
    x_clipped = jnp.where(norms > max_norm, x_proj * max_norm / norms, x_proj)
    return x_clipped
3. Use float64 manifold for critical operations
from hyperbolix.manifolds import Poincare
import jax.numpy as jnp

x = jnp.array([0.1, 0.2])
y = jnp.array([0.8, 0.5])

# Create a float64 manifold (inputs are automatically cast)
poincare_f64 = Poincare(dtype=jnp.float64)
dist_precise = poincare_f64.dist(x, y, c=1.0)  # returns float64
Hyperbolic Function Overflow¶
The Problem¶
Standard implementations of cosh, sinh can overflow:
# Standard numpy/jax
import jax.numpy as jnp
x = jnp.array(100.0, dtype=jnp.float32)
print(jnp.cosh(x)) # inf (overflow!)
print(jnp.sinh(x)) # inf (overflow!)
Solution: Protected Math Utils¶
Hyperbolix provides overflow-protected hyperbolic functions:
import jax.numpy as jnp
from hyperbolix.utils.math_utils import cosh, sinh, acosh, atanh
# Protected versions
x = jnp.array(100.0, dtype=jnp.float32)
print(cosh(x)) # Finite value (clamped to safe range)
print(sinh(x)) # Finite value (clamped to safe range)
# Domain-protected inverse functions
y = jnp.array(0.5, dtype=jnp.float32)
print(acosh(y)) # Clamped to valid domain [1, inf)
z = jnp.array(0.999999, dtype=jnp.float32)
print(atanh(z)) # Clamped away from ±1 singularities
Smooth Clamping¶
The library uses smooth clamping via softplus instead of hard clipping:
from hyperbolix.utils.math_utils import smooth_clamp
import jax.numpy as jnp

# Smooth clamp (differentiable, no gradient issues)
x = jnp.array([-10.0, -1.0, 0.0, 1.0, 10.0])
clamped = smooth_clamp(x, min_value=-5.0, max_value=5.0, smoothing_factor=50.0)
print(clamped)
# Near the boundaries: smooth transition, not an abrupt cutoff
Benefits:

- Differentiable everywhere (no gradient discontinuities)
- Numerically stable (uses softplus internally)
- Adjustable smoothing factor to trade off accuracy against gradient flow
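The idea can be sketched with two softplus terms; this illustrates the technique, not Hyperbolix's exact implementation. The key fact is that softplus(k·t)/k is a smooth, differentiable approximation of max(0, t):

```python
import jax.numpy as jnp
from jax.nn import softplus

def smooth_clamp_sketch(x, min_value, max_value, smoothing_factor=50.0):
    k = smoothing_factor
    # min_value + softplus(k*(x - min_value))/k  is a smooth max(x, min_value);
    # subtracting softplus(k*(x - max_value))/k then takes a smooth min with
    # max_value. In the interior the two corrections cancel and f(x) ~ x.
    return (min_value
            + softplus(k * (x - min_value)) / k
            - softplus(k * (x - max_value)) / k)
```

Larger smoothing factors track hard clipping more closely but concentrate the gradient change near the bounds; smaller values spread it out.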
Version Parameters¶
Purpose¶
Many manifold operations have multiple mathematically equivalent formulations that differ in numerical properties. The version_idx parameter selects which to use.
Poincaré Ball Distance Versions¶
from hyperbolix.manifolds import Poincare
import jax.numpy as jnp
poincare = Poincare()
x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, 0.4])
c = 1.0
# Version 0: Direct Möbius distance (FASTEST, default)
d0 = poincare.dist(x, y, c, version_idx=poincare.VERSION_MOBIUS_DIRECT)
# Version 1: Möbius via addition
d1 = poincare.dist(x, y, c, version_idx=poincare.VERSION_MOBIUS)
# Version 2: Metric tensor induced
d2 = poincare.dist(x, y, c, version_idx=poincare.VERSION_METRIC_TENSOR)
# Version 3: Lorentzian proxy (best near boundary)
d3 = poincare.dist(x, y, c, version_idx=poincare.VERSION_LORENTZIAN_PROXY)
print(f"Version 0: {d0:.6f}")
print(f"Version 1: {d1:.6f}")
print(f"Version 2: {d2:.6f}")
print(f"Version 3: {d3:.6f}")
# All should be approximately equal
Which Version to Use?¶
General recommendation: VERSION_MOBIUS_DIRECT (version 0)
- Fastest
- Fewest intermediate operations
- Best for most applications
Special cases:
- Near-boundary points (||x|| > 0.9): Try VERSION_LORENTZIAN_PROXY (version 3) for better stability
- Very high dimensions (> 1000): VERSION_METRIC_TENSOR (version 2) may be more stable
- Debugging: Compare all versions — significant differences indicate numerical issues
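For the debugging case, a small (hypothetical) helper can flag when supposedly equivalent formulations disagree beyond rounding:

```python
import numpy as np

def versions_agree(dists, rtol=1e-4):
    # Mathematically equivalent formulations should match to within rounding;
    # a large relative spread signals numerical trouble (e.g. near-boundary inputs)
    dists = np.asarray(dists, dtype=np.float64)
    spread = (dists.max() - dists.min()) / max(abs(dists.mean()), 1e-12)
    return bool(spread < rtol)
```

Pass it the list [d0, d1, d2, d3] from the snippet above; a False result is a cue to move to float64 or a boundary-stable version.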
Using Versions with JIT¶
import jax
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

poincare = Poincare()
x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, 0.4])

# IMPORTANT: version_idx must be static for JIT
@jax.jit
def compute_distances(x_batch, y_batch, c):
    # Version baked into the function body (static)
    return jax.vmap(
        lambda xi, yi: poincare.dist(xi, yi, c, version_idx=0)
    )(x_batch, y_batch)

# Or use static_argnames
dist_jit = jax.jit(poincare.dist, static_argnames=['version_idx'])
d = dist_jit(x, y, c=1.0, version_idx=0)
Projection Strategies¶
Why Project?¶
Operations such as Möbius addition and linear transformations can push points off the manifold. Projection restores the manifold constraint.
When to Project¶
Always project:
- After Möbius addition: poincare.addition(x, y, c)
- After neural network layers
- After parameter updates in optimization
Usually don't need projection:
- After expmap (already on manifold)
- After proj (redundant)
Projection¶
Projection ensures points stay on the manifold by clipping norms:
from hyperbolix.manifolds import Poincare
import jax.numpy as jnp

poincare = Poincare()
x = jnp.array([0.95, 0.45])  # may lie outside the ball

# Project to the Poincaré ball
x_proj = poincare.proj(x, c=1.0)
# Projection is numerically stable and automatically handles edge cases
Projection in Training¶
import jax
from hyperbolix.manifolds import Poincare
from hyperbolix.nn_layers import HypLinearPoincare
from flax import nnx

poincare = Poincare()

class HyperbolicModel(nnx.Module):
    def __init__(self, rngs):
        self.layer1 = HypLinearPoincare(poincare, 128, 64, rngs=rngs)
        self.layer2 = HypLinearPoincare(poincare, 64, 32, rngs=rngs)

    def __call__(self, x, c=1.0):
        x = self.layer1(x, c)  # layers already project internally
        x = self.layer2(x, c)
        # Final projection for extra safety
        x = jax.vmap(lambda xi: poincare.proj(xi, c))(x)
        return x
Layer Projection
Hyperbolix layers already project internally after operations, so explicit projection between layers is optional but recommended for extra safety.
Common Edge Cases¶
Edge Case 1: Points Near the Boundary¶
Symptoms: NaN or Inf in gradients, exploding losses
Solution:
import jax.numpy as jnp

# Check if points are too close to the boundary
def check_boundary_proximity(x_batch, c=1.0):
    norms = jnp.linalg.norm(x_batch, axis=-1)
    max_norm = 1.0 / jnp.sqrt(c)
    proximity = norms / max_norm
    if jnp.any(proximity > 0.95):
        print(f"WARNING: Points near boundary (max proximity: {jnp.max(proximity):.4f})")
        return True
    return False

# Clip if needed
def safe_clip_to_interior(x_batch, c=1.0, safety_factor=0.9):
    max_allowed = safety_factor / jnp.sqrt(c)
    norms = jnp.linalg.norm(x_batch, axis=-1, keepdims=True)
    scale = jnp.minimum(1.0, max_allowed / (norms + 1e-8))
    return x_batch * scale
Edge Case 2: Zero or Near-Zero Vectors¶
Symptoms: Division by zero warnings, NaN in tangent operations
Solution:
import jax.numpy as jnp

# Manifold functions handle this internally with MIN_NORM,
# but you can add explicit checks:
def safe_normalize(v, eps=1e-8):
    norm = jnp.linalg.norm(v)
    return jnp.where(norm > eps, v / norm, jnp.zeros_like(v))
Edge Case 3: Large Learning Rates¶
Symptoms: Points shoot to boundary, training collapse
Solution:
# Use conservative learning rates
from hyperbolix.optim import riemannian_adam
# For Poincaré ball
optimizer = riemannian_adam(learning_rate=1e-3) # Not 1e-2 or higher!
# For Hyperboloid
optimizer = riemannian_adam(learning_rate=5e-4) # Even more conservative
# Use learning rate scheduling
from optax import exponential_decay

schedule = exponential_decay(
    init_value=1e-3,
    transition_steps=1000,
    decay_rate=0.96,
    staircase=True,
)
optimizer = riemannian_adam(learning_rate=schedule)
Edge Case 4: High Curvature Values¶
Symptoms: Numerical instability, rapid convergence to boundary
Solution:
import jax.numpy as jnp

# Keep curvature moderate
c = 1.0  # Good default

# High curvature (c > 1) increases numerical challenges
c = 0.1  # Lower curvature = larger hyperbolic space = more stable

# If learning curvature, clip it
def clip_curvature(c, min_c=0.01, max_c=10.0):
    return jnp.clip(c, min_c, max_c)
Checking Manifold Constraints¶
Validation Functions¶
Each manifold provides is_in_manifold for validation:
from hyperbolix.manifolds import Poincare, Hyperboloid
import jax.numpy as jnp
poincare = Poincare()
hyperboloid = Hyperboloid()
# Poincaré ball: ||x||² < 1/c
x = jnp.array([0.5, 0.3])
assert poincare.is_in_manifold(x, c=1.0, atol=1e-5)
# Hyperboloid: -x₀² + Σxᵢ² = -1/c (with x₀ > 0)
x_ambient = jnp.array([1.5, 0.2, 0.3, 0.1]) # (dim+1,)
assert hyperboloid.is_in_manifold(x_ambient, c=1.0, atol=1e-5)
Batch Validation¶
import jax
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

poincare = Poincare()

def validate_batch(x_batch, c=1.0, atol=1e-5):
    """Check if all points in the batch satisfy the manifold constraint."""
    valid = jax.vmap(lambda x: poincare.is_in_manifold(x, c, atol))(x_batch)
    num_valid = jnp.sum(valid)
    total = len(x_batch)
    if num_valid < total:
        print(f"WARNING: {total - num_valid}/{total} points off manifold")
        violations = jnp.where(~valid)[0]
        print(f"Violating indices: {violations[:10]}")  # Show first 10
    return jnp.all(valid)
Best Practices Summary¶
Numerical Stability Checklist
- ✅ Use float32 for distances < 7, float64 for larger
- ✅ Project after operations that might violate constraints
- ✅ Keep points away from the boundary (max norm < 0.9/√c)
- ✅ Use conservative learning rates (< 1e-3 for Poincaré, < 5e-4 for Hyperboloid)
- ✅ Use protected math functions (hyperbolix.utils.math_utils)
- ✅ Monitor conformal factors during training
- ✅ Validate manifold constraints when debugging
- ✅ Use VERSION_MOBIUS_DIRECT for Poincaré distance unless issues arise
- ✅ Clip curvature if learnable (0.01 < c < 10.0)
- ✅ Initialize embeddings conservatively (small norms)
Debugging Numerical Issues¶
Step-by-Step Diagnostic¶
1. Check for NaN/Inf in inputs, outputs, and gradients.
2. Verify manifold constraints with is_in_manifold.
3. Check boundary proximity (norms approaching 1/√c).
4. Try a different distance version (version_idx) and compare the results.
5. Switch to a float64 manifold: Poincare(dtype=jnp.float64).
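The first three checks can be combined into one quick diagnostic; this sketch uses only jnp, and the helper name is ours:

```python
import jax.numpy as jnp

def diagnose(x_batch, c=1.0):
    # NaN/Inf and boundary-proximity report; a proximity of 1.0 means a point
    # sits exactly on the boundary of the ball of radius 1/sqrt(c)
    norms = jnp.linalg.norm(x_batch, axis=-1)
    return {
        "has_nan": bool(jnp.any(jnp.isnan(x_batch))),
        "has_inf": bool(jnp.any(jnp.isinf(x_batch))),
        "max_proximity": float(jnp.max(norms) * jnp.sqrt(c)),
    }
```

Run it on embeddings before and after the failing operation; a max_proximity above ~0.95 or any NaN/Inf flag points at the remaining steps (version choice, float64).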
See Also¶
- Batching & JIT: Performance optimization patterns
- Manifolds API: Manifold function reference
- Training Workflows: End-to-end training examples
- Mathematical Background: Theory and formulas