Utilities API
Utility functions for hyperbolic deep learning.
Math Utilities
Numerically stable implementations of hyperbolic functions.
hyperbolix.utils.math_utils
Math utils functions for hyperbolic operations with numerically stable limits.
Direct JAX port of PyTorch math_utils.py with type annotations using jaxtyping.
cosh
Hyperbolic cosine with overflow protection. Domain=(-inf, inf).
Clamps the input to [-b, b], where b = log(max) * 0.99 and max is the largest finite value of the input's dtype, so the result stays finite.
Args: x: Input array of any shape
Returns: cosh(x) with overflow protection
Source code in hyperbolix/utils/math_utils.py
sinh
Hyperbolic sine with overflow protection. Domain=(-inf, inf).
Clamps the input to [-b, b], where b = log(max) * 0.99 and max is the largest finite value of the input's dtype, so the result stays finite.
Args: x: Input array of any shape
Returns: sinh(x) with overflow protection
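Both cosh and sinh describe the same guard: clamp the argument before evaluating, so that exp(|x|) stays representable. The sketch below is a hypothetical reconstruction that reads "log(max) * 0.99" as a symmetric bound b = 0.99 * log(finfo(dtype).max); the helper names `overflow_bound`, `safe_cosh` and `safe_sinh` are illustrative and are not the hyperbolix implementation.

```python
import jax.numpy as jnp


def overflow_bound(x):
    # Hypothetical: b = 0.99 * log(largest finite value of x's floating dtype).
    # For float32, max is about 3.4e38, so b is about 87.8.
    return 0.99 * jnp.log(jnp.finfo(x.dtype).max)


def safe_cosh(x):
    # Clamp to [-b, b] so cosh(x) <= exp(b) = max**0.99 stays finite.
    b = overflow_bound(x)
    return jnp.cosh(jnp.clip(x, -b, b))


def safe_sinh(x):
    # Same bound; the clip is symmetric, so sinh keeps its sign (odd symmetry).
    b = overflow_bound(x)
    return jnp.sinh(jnp.clip(x, -b, b))


x = jnp.array([0.0, 1.0, 1000.0, -1000.0], dtype=jnp.float32)
print(jnp.cosh(x))   # the +-1000 entries overflow to inf
print(safe_cosh(x))  # every entry is finite
print(safe_sinh(x))  # finite, and negative for the -1000 entry
```

For float32 the clamped value cosh(b) is about 7e37, comfortably below the float32 maximum of about 3.4e38.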
Source code in hyperbolix/utils/math_utils.py
acosh
Inverse hyperbolic cosine with domain clamping. Domain=[1, inf).
Args: x: Input array of any shape
Returns: acosh(x) with domain protection (inputs below 1.0 are clamped to 1.0)
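Clamping matters in practice because hyperbolic distance formulas evaluate acosh of expressions of the form 1 + (something non-negative), and floating-point rounding can push the argument slightly below 1, where a plain arccosh returns NaN. A minimal sketch of the clamp described above, assuming it is a hard `jnp.maximum(x, 1.0)`; `safe_acosh` is a hypothetical name, not the library function.

```python
import jax.numpy as jnp


def safe_acosh(x):
    # Hard clamp into the domain [1, inf) before calling arccosh.
    return jnp.arccosh(jnp.maximum(x, 1.0))


x = jnp.array([0.999, 1.0, 1.5, 10.0])
print(jnp.arccosh(x))  # first entry is nan (outside the domain)
print(safe_acosh(x))   # first entry is 0.0
```

A hard clamp has zero gradient below 1, and the derivative of acosh is unbounded at 1; this may be one reason the library also offers smoothed distance variants such as VERSION_SMOOTHENED.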
Source code in hyperbolix/utils/math_utils.py
atanh
Inverse hyperbolic tangent with domain clamping. Domain=(-1, 1).
Clamps input away from ±1 to avoid singularities.
Args: x: Input array of any shape
Returns: atanh(x) with domain protection
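The entry does not say how far from ±1 the input is clamped. The sketch below assumes a fixed margin `eps` (a hypothetical value of 1e-5, chosen to stay representable in float32); `safe_atanh` and `EPS` are illustrative names, not the hyperbolix implementation.

```python
import jax.numpy as jnp

EPS = 1e-5  # hypothetical margin; the library's actual value is not documented here


def safe_atanh(x, eps=EPS):
    # arctanh(+-1) is +-inf, so keep |x| <= 1 - eps before evaluating.
    return jnp.arctanh(jnp.clip(x, -1.0 + eps, 1.0 - eps))


x = jnp.array([-1.0, 0.0, 0.5, 1.0])
print(safe_atanh(x))  # all entries are finite
```

A fixed margin is a trade-off: if eps is too small, float32 rounding turns 1 - eps back into exactly 1.0; if it is too large, the clamp distorts points near the boundary of the Poincaré ball, where atanh appears in distance formulas.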
Source code in hyperbolix/utils/math_utils.py
smooth_clamp
smooth_clamp(
x: Float[Array, ...],
min_value: float,
max_value: float,
smoothing_factor: float = 50.0,
) -> Float[Array, ...]
Smoothly clamp array values to a range [min_value, max_value].
Args:
    x: Input array of any shape.
    min_value: Lower bound of the clamp range.
    max_value: Upper bound of the clamp range.
    smoothing_factor: Beta parameter of the softplus (higher values give a sharper transition).
Returns: Array with values smoothly clamped to [min_value, max_value]
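One standard way to build a smooth clamp from softplus, consistent with the "beta parameter for softplus" description, is to apply a soft lower bound followed by a soft upper bound. Since softplus(beta * z) / beta approaches max(z, 0) as beta grows, the result approaches a hard clamp. This is a hypothetical sketch (`smooth_clamp_sketch`); the hyperbolix formula may differ.

```python
import jax
import jax.numpy as jnp


def smooth_clamp_sketch(x, min_value, max_value, smoothing_factor=50.0):
    beta = smoothing_factor
    # Soft lower bound: approximately max(x, min_value).
    x = min_value + jax.nn.softplus(beta * (x - min_value)) / beta
    # Soft upper bound: approximately min(x, max_value).
    return max_value - jax.nn.softplus(beta * (max_value - x)) / beta


x = jnp.array([-10.0, 0.5, 1.0, 10.0])
print(smooth_clamp_sketch(x, 0.0, 1.0))
```

Unlike `jnp.clip`, whose gradient is exactly zero outside the range, this version has a nonzero gradient near the bounds. The price is that a value sitting exactly on a bound is pulled slightly inside, by about log(2) / beta (roughly 0.014 for beta = 50).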
Source code in hyperbolix/utils/math_utils.py
Usage Example
import jax.numpy as jnp
from hyperbolix.utils.math_utils import acosh, atanh, smooth_clamp
# Numerically stable hyperbolic functions
x = jnp.array([0.999, 1.0, 1.5, 10.0])
y = acosh(x)  # inputs below 1.0 are clamped to 1.0, so y[0] is 0.0 rather than NaN
w = atanh(jnp.array([0.5, 1.0]))  # finite: inputs are clamped away from ±1
# Smooth clamping for stability
z = jnp.array([0.99, 1.0, 1.01])
z_clamped = smooth_clamp(z, min_value=0.0, max_value=1.0)
Helper Functions
Helper utilities for distance computation and delta-hyperbolicity analysis.
hyperbolix.utils.helpers
Helper utilities for hyperbolic geometry computations.
This module provides utilities for computing pairwise distances, delta-hyperbolicity metrics, and other geometric measures on hyperbolic manifolds.
compute_pairwise_distances
compute_pairwise_distances(
points: Float[Array, "n_points dim"],
manifold_module,
c: Float[Array, ""] | float,
version_idx: int = 0,
) -> Float[Array, "n_points n_points"]
Compute pairwise geodesic distances between points on a manifold.
This function computes the full distance matrix with two nested jax.vmap calls (an outer map over rows, an inner map over columns). The computation is not chunked: the entire distance matrix is produced in a single pass.
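A self-contained sketch of the nested-vmap pattern, using the closed-form Poincaré-ball distance for curvature -c as the pairwise kernel. `poincare_dist` and `pairwise_distances` are hypothetical names for illustration; compute_pairwise_distances instead takes its distance function from `manifold_module` and `version_idx`.

```python
import jax
import jax.numpy as jnp


def poincare_dist(x, y, c):
    # d(x, y) = acosh(1 + 2c|x - y|^2 / ((1 - c|x|^2)(1 - c|y|^2))) / sqrt(c)
    sq_diff = jnp.sum((x - y) ** 2)
    denom = (1.0 - c * jnp.sum(x**2)) * (1.0 - c * jnp.sum(y**2))
    arg = 1.0 + 2.0 * c * sq_diff / denom
    # Guard against rounding pushing the argument below 1.
    return jnp.arccosh(jnp.maximum(arg, 1.0)) / jnp.sqrt(c)


def pairwise_distances(points, c):
    # Inner vmap: distances from one point x to every y, shape (n,).
    row = jax.vmap(poincare_dist, in_axes=(None, 0, None))
    # Outer vmap: one row per x, giving the full (n, n) matrix in one pass.
    return jax.vmap(row, in_axes=(0, None, None))(points, points, c)


points = jnp.array([[0.0, 0.0], [0.5, 0.0], [0.0, -0.3]])
D = pairwise_distances(points, 1.0)
print(D.shape)  # (3, 3)
print(D[0, 1])  # log(3), about 1.0986: distance from the origin to radius 0.5
```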
Memory considerations: for n points, the full n-by-n distance matrix is held in memory. For very large point sets (roughly more than 5,000 to 10,000 points, depending on available memory), consider subsampling or implementing a chunked version. The current implementation favors simplicity and relies on XLA's compiler optimizations for memory use.
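If memory becomes the bottleneck, one possible chunked variant processes a fixed number of rows at a time with `jax.lax.map`, which evaluates the chunks sequentially, so per-chunk intermediates involve only `chunk_size` rows; the final (n, n) matrix is still materialized. This is a hypothetical sketch that accepts any pairwise `dist_fn(x, y)`; it is not part of hyperbolix.

```python
import jax
import jax.numpy as jnp


def pairwise_distances_chunked(points, dist_fn, chunk_size=256):
    n, dim = points.shape
    n_chunks = -(-n // chunk_size)  # ceiling division
    pad = n_chunks * chunk_size - n
    # Pad with copies of the first point (valid on any manifold); padded rows are dropped below.
    padded = jnp.concatenate([points, jnp.repeat(points[:1], pad, axis=0)], axis=0)
    chunks = padded.reshape(n_chunks, chunk_size, dim)

    # Distances from every point in one chunk to all n points: shape (chunk_size, n).
    row = jax.vmap(dist_fn, in_axes=(None, 0))
    block = jax.vmap(row, in_axes=(0, None))

    # lax.map evaluates the chunks one after another instead of all at once.
    blocks = jax.lax.map(lambda chunk: block(chunk, points), chunks)
    return blocks.reshape(n_chunks * chunk_size, n)[:n]


def euclidean(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


pts = jax.random.normal(jax.random.PRNGKey(0), (10, 3))
D = pairwise_distances_chunked(pts, euclidean, chunk_size=4)  # 3 chunks, 2 padded rows
print(D.shape)  # (10, 10)
```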
Args:
    points: Points on the manifold, shape (n_points, dim).
        For Hyperboloid: dim is the ambient dimension (intrinsic dimension + 1).
        For PoincareBall: dim is the intrinsic dimension.
    manifold_module: Manifold module (hyperboloid or poincare).
    c: Curvature parameter (positive scalar).
    version_idx: Distance version index (manifold-specific, default: 0).
        For Hyperboloid:
            0 = VERSION_DEFAULT (standard acosh with hard clipping)
            1 = VERSION_SMOOTHENED (smoothed distance)
        For PoincareBall:
            0 = VERSION_MOBIUS_DIRECT (direct Möbius formula)
            1 = VERSION_MOBIUS (via Möbius addition)
            2 = VERSION_METRIC_TENSOR (induced by the metric tensor)
            3 = VERSION_LORENTZIAN_PROXY (Lorentzian proxy)
Returns: Symmetric distance matrix of shape (n_points, n_points)
Examples:
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from hyperbolix.manifolds import hyperboloid
    >>> from hyperbolix.utils.helpers import compute_pairwise_distances
    >>>
    >>> # Generate random hyperboloid points
    >>> key = jax.random.PRNGKey(0)
    >>> points = jax.random.normal(key, (100, 11))
    >>> points = jax.vmap(hyperboloid.proj, in_axes=(0, None))(points, 1.0)
    >>>
    >>> # Compute pairwise distances
    >>> distmat = compute_pairwise_distances(
    ...     points, hyperboloid, c=1.0, version_idx=hyperboloid.VERSION_DEFAULT
    ... )
    >>> print(distmat.shape)
    (100, 100)
Notes:
    - The PyTorch reference implementation used explicit chunking for memory management; this JAX version uses vmap and relies on XLA optimization.
    - The distance matrix is symmetric: distmat[i, j] == distmat[j, i].
    - Diagonal elements are zero: distmat[i, i] == 0.
    - For large datasets, consider subsampling before calling this function.
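To make the Hyperboloid dimension convention concrete: a point with d intrinsic coordinates is stored with d + 1 ambient coordinates. The sketch below assumes the common convention of a leading "time" coordinate and the constraint -x0^2 + |v|^2 = -1/c; `lift_to_hyperboloid` and `minkowski_inner` are hypothetical helpers, and hyperboloid.proj may use a different parametrization.

```python
import jax
import jax.numpy as jnp


def lift_to_hyperboloid(v, c):
    # v: intrinsic coordinates, shape (..., d). Returns shape (..., d + 1).
    # Solve -x0^2 + |v|^2 = -1/c for the positive time coordinate x0.
    x0 = jnp.sqrt(1.0 / c + jnp.sum(v**2, axis=-1, keepdims=True))
    return jnp.concatenate([x0, v], axis=-1)


def minkowski_inner(x, y):
    # <x, y>_L = -x0 * y0 + sum_i xi * yi
    return -x[..., 0] * y[..., 0] + jnp.sum(x[..., 1:] * y[..., 1:], axis=-1)


v = jax.random.normal(jax.random.PRNGKey(0), (100, 10))  # intrinsic dimension 10
x = lift_to_hyperboloid(v, c=1.0)
print(x.shape)                # (100, 11): ambient dimension 10 + 1
print(minkowski_inner(x, x))  # every entry is close to -1/c = -1
```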
Source code in hyperbolix/utils/helpers.py
compute_hyperbolic_delta
compute_hyperbolic_delta(
distmat: Float[Array, "n_points n_points"],
version: str = "average",
) -> Float[Array, ""]
Compute the delta-hyperbolicity value from a distance matrix.
Delta-hyperbolicity is a metric space property that quantifies how "tree-like" or "hyperbolic" a metric space is. It is based on the Gromov 4-point condition.
For any four points w, x, y, z in a metric space, define:
    S1 = d(w,x) + d(y,z)
    S2 = d(w,y) + d(x,z)
    S3 = d(w,z) + d(x,y)
The 4-point condition requires that the two largest of these sums differ by at most 2δ. A space is δ-hyperbolic if this holds for all quadruples.
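For a single quadruple, the smallest δ that satisfies the condition is half the gap between the two largest sums. A minimal sketch (`four_point_delta` is a hypothetical helper), checked on two 4-point metrics: four points on a line, which form a tree and give δ = 0, and the 4-cycle with unit edges, which gives δ = 1.

```python
import jax.numpy as jnp


def four_point_delta(d, w, x, y, z):
    # d: distance matrix; w, x, y, z: indices of four points.
    sums = jnp.array([
        d[w, x] + d[y, z],  # S1
        d[w, y] + d[x, z],  # S2
        d[w, z] + d[x, y],  # S3
    ])
    s = jnp.sort(sums)
    # The two largest sums s[2] >= s[1] must differ by at most 2 * delta.
    return (s[2] - s[1]) / 2.0


# Points 0, 1, 2, 3 at positions 0, 1, 2, 3 on a line (a tree metric).
line = jnp.abs(jnp.arange(4.0)[:, None] - jnp.arange(4.0)[None, :])
# Unit 4-cycle 0-1-2-3-0: opposite vertices are at distance 2.
cycle = jnp.array([
    [0.0, 1.0, 2.0, 1.0],
    [1.0, 0.0, 1.0, 2.0],
    [2.0, 1.0, 0.0, 1.0],
    [1.0, 2.0, 1.0, 0.0],
])
print(four_point_delta(line, 0, 1, 2, 3))   # 0.0
print(four_point_delta(cycle, 0, 1, 2, 3))  # 1.0
```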
This implementation uses a reference point (the first point) to compute Gromov products efficiently: (x|y)_w = [d(w,x) + d(w,y) - d(x,y)] / 2
Args:
    distmat: Symmetric distance matrix, shape (n_points, n_points).
    version: Which delta statistic to compute (default: "average").
        - "average": mean of the delta values over all point quadruples.
        - "smallest": the smallest delta for which the four-point condition holds, i.e. the maximum delta over all quadruples (worst case).
Returns: Delta-hyperbolicity value (scalar)
References:
    - Gromov, M. (1987). "Hyperbolic groups." In Essays in Group Theory.
    - Chami, I., et al. (2021). "HoroPCA: Hyperbolic dimensionality reduction via horospherical projections." ICML 2021.
Examples:
    >>> import jax.numpy as jnp
    >>> from hyperbolix.utils.helpers import compute_hyperbolic_delta
    >>>
    >>> # Create a distance matrix (should be symmetric)
    >>> distmat = jnp.array([
    ...     [0.0, 1.0, 2.0, 3.0],
    ...     [1.0, 0.0, 1.5, 2.5],
    ...     [2.0, 1.5, 0.0, 1.0],
    ...     [3.0, 2.5, 1.0, 0.0]
    ... ])
    >>>
    >>> delta_avg = compute_hyperbolic_delta(distmat, version="average")
    >>> delta_max = compute_hyperbolic_delta(distmat, version="smallest")
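Notes:
    - The result is multiplied by 2 because a single reference point is fixed: a delta measured from one base point guarantees hyperbolicity only up to a factor of 2.
    - Lower delta values indicate more hyperbolic (tree-like) structure.
    - Euclidean spaces of dimension 2 or more are not δ-hyperbolic for any finite δ, whereas hyperbolic spaces are.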
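A compact sketch of the reference-point computation described above: build the Gromov product matrix relative to the first point, take a max-min matrix product, reduce the per-pair violations with a mean ("average") or a max ("smallest"), and scale by 2 as the notes state. `hyperbolic_delta_sketch` is hypothetical; hyperbolix may organize the computation differently, for example to save memory.

```python
import jax.numpy as jnp


def hyperbolic_delta_sketch(distmat, version="average"):
    # Gromov products relative to the reference point w = 0:
    # A[x, y] = (d(w, x) + d(w, y) - d(x, y)) / 2
    A = 0.5 * (distmat[0][None, :] + distmat[:, 0][:, None] - distmat)
    # Max-min product: M[x, y] = max_z min(A[x, z], A[z, y]).
    # This broadcast materializes an (n, n, n) array.
    M = jnp.max(jnp.minimum(A[:, :, None], A[None, :, :]), axis=1)
    # violations[x, y] >= 0: how far (x|y)_w falls below max_z min((x|z)_w, (z|y)_w).
    violations = M - A
    if version == "average":
        delta = jnp.mean(violations)
    elif version == "smallest":
        delta = jnp.max(violations)
    else:
        raise ValueError(f"unknown version: {version!r}")
    # Factor of 2 for fixing a single reference point.
    return 2.0 * delta


# The matrix from the example above; it satisfies the four-point condition with delta 0.
line = jnp.array([
    [0.0, 1.0, 2.0, 3.0],
    [1.0, 0.0, 1.5, 2.5],
    [2.0, 1.5, 0.0, 1.0],
    [3.0, 2.5, 1.0, 0.0],
])
# Unit 4-cycle: its single quadruple has four-point delta 1.
cycle = jnp.array([
    [0.0, 1.0, 2.0, 1.0],
    [1.0, 0.0, 1.0, 2.0],
    [2.0, 1.0, 0.0, 1.0],
    [1.0, 2.0, 1.0, 0.0],
])
print(hyperbolic_delta_sketch(line, "smallest"))   # 0.0
print(hyperbolic_delta_sketch(cycle, "smallest"))  # 2.0
print(hyperbolic_delta_sketch(cycle, "average"))   # 0.25
```

Because the max-min step allocates an n x n x n array, this sketch is only practical for a few hundred points.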
Source code in hyperbolix/utils/helpers.py
get_delta
get_delta(
points: Float[Array, "n_points dim"],
manifold_module,
c: float,
version_idx: int = 0,
sample_size: int = 1500,
version: str = "average",
key: Key[Array, ""] | None = None,
) -> tuple[
Float[Array, ""], Float[Array, ""], Float[Array, ""]
]
Compute delta-hyperbolicity and related metrics for a point set.
This function subsamples points (if needed), computes the pairwise distance matrix, and then calculates the delta-hyperbolicity value along with the diameter and relative delta (delta normalized by diameter).
Args:
    points: Points on the manifold, shape (n_points, dim).
    manifold_module: Manifold module (hyperboloid or poincare).
    c: Curvature parameter (positive scalar).
    version_idx: Distance version index (manifold-specific, default: 0).
    sample_size: Maximum number of points to use for the delta computation (default: 1500). If n_points > sample_size, a random subsample is drawn.
    version: Which delta statistic to compute (default: "average").
        - "average": mean of the delta values.
        - "smallest": maximum delta (worst case).
    key: JAX random key for subsampling (required if n_points > sample_size).
Returns:
    Tuple of (delta, diameter, relative_delta):
        - delta: Delta-hyperbolicity value.
        - diameter: Maximum pairwise distance in the point set.
        - relative_delta: delta / diameter (scale-invariant measure).
Examples:
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from hyperbolix.manifolds import hyperboloid
    >>> from hyperbolix.utils.helpers import get_delta
    >>>
    >>> # Generate random hyperboloid points
    >>> key = jax.random.PRNGKey(42)
    >>> points = jax.random.normal(key, (2000, 11))
    >>> points = jax.vmap(hyperboloid.proj, in_axes=(0, None))(points, 1.0)
    >>>
    >>> # Compute delta metrics
    >>> key, subkey = jax.random.split(key)
    >>> delta, diam, rel_delta = get_delta(
    ...     points, hyperboloid, c=1.0, sample_size=1500, key=subkey
    ... )
    >>> print(f"Delta: {delta:.4f}, Diameter: {diam:.4f}, Relative: {rel_delta:.4f}")
Notes:
    - Subsampling is done randomly without replacement.
    - For reproducibility, always provide the same random key.
    - The PyTorch version used torch.randperm; this version uses jax.random.permutation.
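The three returned values can be sketched end to end: optionally subsample with jax.random.permutation (as the notes say the library does), compute the dense distance matrix, compute delta from reference-point Gromov products, then divide by the diameter. `get_delta_sketch` takes a plain `dist_fn(x, y)` instead of a manifold module and treats any `version` other than "average" as the worst case; it is a hypothetical illustration, not the library code.

```python
import jax
import jax.numpy as jnp


def get_delta_sketch(points, dist_fn, sample_size=1500, version="average", key=None):
    n = points.shape[0]
    if n > sample_size:
        if key is None:
            raise ValueError("key is required when n_points > sample_size")
        # Random subsample without replacement.
        idx = jax.random.permutation(key, n)[:sample_size]
        points = points[idx]

    # Dense pairwise distance matrix via nested vmap.
    distmat = jax.vmap(jax.vmap(dist_fn, in_axes=(None, 0)), in_axes=(0, None))(points, points)

    # Gromov products w.r.t. the first point, then the max-min product.
    A = 0.5 * (distmat[0][None, :] + distmat[:, 0][:, None] - distmat)
    M = jnp.max(jnp.minimum(A[:, :, None], A[None, :, :]), axis=1)
    violations = M - A
    # "average" -> mean; anything else -> worst case ("smallest"). Scaled by 2.
    delta = 2.0 * (jnp.mean(violations) if version == "average" else jnp.max(violations))

    diameter = jnp.max(distmat)
    return delta, diameter, delta / diameter


def line_dist(x, y):
    return jnp.abs(x - y).sum()


pts = jnp.arange(10.0)[:, None]  # ten points on a line: a tree metric
print(get_delta_sketch(pts, line_dist, sample_size=20))  # delta 0.0, diameter 9.0, relative 0.0
sub = get_delta_sketch(pts, line_dist, sample_size=5, key=jax.random.PRNGKey(0))
```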
Source code in hyperbolix/utils/helpers.py
Usage Examples
Pairwise Distances
import jax
import jax.numpy as jnp
from hyperbolix.utils.helpers import compute_pairwise_distances
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Set of points on Poincaré ball
points = jnp.array([
[0.1, 0.2],
[0.3, -0.1],
[-0.2, 0.4],
[0.0, 0.0]
])
# Compute all pairwise distances
dist_matrix = compute_pairwise_distances(
points,
manifold_module=poincare,
c=1.0,
version_idx=0
)
# Result: (4, 4) matrix of distances
print(dist_matrix.shape) # (4, 4)
Delta-Hyperbolicity
Measure how "hyperbolic" a dataset is using the Gromov delta metric:
import jax
import jax.numpy as jnp
from hyperbolix.utils.helpers import get_delta
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Generate random points
key = jax.random.PRNGKey(0)
key, subkey = jax.random.split(key)
points = jax.random.normal(subkey, (100, 2)) * 0.3
# Project to Poincaré ball
points_proj = jax.vmap(poincare.proj, in_axes=(0, None))(points, 1.0)
# Compute delta-hyperbolicity
delta, diameter, rel_delta = get_delta(
    points_proj,
    manifold_module=poincare,
    c=1.0,
    sample_size=500,  # Maximum number of points used; 100 < 500, so no subsampling occurs
    key=key,          # Only required when subsampling; passed here for completeness
)
print(f"Delta: {delta:.4f}")
print(f"Diameter: {diameter:.4f}")
print(f"Relative delta: {rel_delta:.4f}")
The Gromov delta quantifies tree-likeness:
- δ = 0: exact tree metric (0-hyperbolic); δ ≈ 0 indicates near-tree structure
- Larger δ: further from a tree (less hyperbolic)
- δ/diameter: scale-invariant normalized measure (relative delta)
Performance Tips
JIT Compilation
All utility functions support JIT compilation.
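A minimal, self-contained sketch of how non-array arguments interact with jax.jit: a Python callable and a Python string cannot be traced, so they are marked static with `static_argnames`, and each distinct value triggers its own compilation. By the same reasoning, arguments such as `manifold_module`, `version_idx`, `version` and `sample_size` in the signatures above would likely need to be static when wrapping those functions directly; that is an assumption to check against the library. `pairwise_summary` and `euclidean` are hypothetical.

```python
import functools

import jax
import jax.numpy as jnp


@functools.partial(jax.jit, static_argnames=("dist_fn", "reduction"))
def pairwise_summary(points, dist_fn, reduction="max"):
    # dist_fn and reduction are ordinary Python objects, so they are static:
    # the Python `if` below is resolved once, at trace time.
    d = jax.vmap(jax.vmap(dist_fn, in_axes=(None, 0)), in_axes=(0, None))(points, points)
    if reduction == "max":
        return jnp.max(d)
    return jnp.mean(d)


def euclidean(x, y):
    return jnp.sqrt(jnp.sum((x - y) ** 2))


pts = jnp.array([[0.0, 0.0], [3.0, 4.0]])
print(pairwise_summary(pts, dist_fn=euclidean, reduction="max"))   # 5.0
print(pairwise_summary(pts, dist_fn=euclidean, reduction="mean"))  # 2.5, compiled separately
```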
Batching
For large datasets, consider batching the delta-hyperbolicity computation.
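One possible batching scheme, sketched under that assumption: estimate delta on several independent random subsamples processed one at a time with `jax.lax.map`, so intermediates (a b x b distance matrix and a b x b x b max-min array) are only allocated for one batch of size b at a time, and report the mean and spread of the per-batch values. `batched_delta`, `delta_of_distmat` and their parameters are hypothetical, and the per-batch statistic is the worst-case ("smallest") delta scaled by 2.

```python
import jax
import jax.numpy as jnp


def delta_of_distmat(distmat):
    # Worst-case delta from Gromov products relative to the first point, scaled by 2.
    A = 0.5 * (distmat[0][None, :] + distmat[:, 0][:, None] - distmat)
    M = jnp.max(jnp.minimum(A[:, :, None], A[None, :, :]), axis=1)
    return 2.0 * jnp.max(M - A)


def batched_delta(points, dist_fn, key, n_batches=8, batch_size=128):
    n = points.shape[0]
    keys = jax.random.split(key, n_batches)

    def one_batch(k):
        # Draw batch_size distinct points (requires batch_size <= n).
        idx = jax.random.choice(k, n, shape=(batch_size,), replace=False)
        sub = points[idx]
        d = jax.vmap(jax.vmap(dist_fn, in_axes=(None, 0)), in_axes=(0, None))(sub, sub)
        return delta_of_distmat(d)

    # lax.map runs the batches sequentially rather than all at once.
    deltas = jax.lax.map(one_batch, keys)
    return deltas.mean(), deltas.std()


def line_dist(x, y):
    return jnp.abs(x - y).sum()


pts = jnp.arange(200.0)[:, None]  # points on a line: every batch has delta 0
mean_delta, std_delta = batched_delta(
    pts, line_dist, jax.random.PRNGKey(0), n_batches=4, batch_size=32
)
```

Replacing `jax.lax.map` with `jax.vmap` would run the batches in parallel at the cost of holding all of their intermediates at once.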
References
- Gromov Delta: Gromov, M. (1987). "Hyperbolic groups."
See also:
- Manifolds API: Core geometric operations
- Numerical Stability Guide: Best practices