Manifolds API
This page documents the core manifold operations in Hyperbolix. Each manifold is a class that provides geometric operations and automatic dtype casting.
Overview
Hyperbolix provides three manifold classes:
- Euclidean: Flat Euclidean space (baseline)
- Poincaré Ball: Conformal model of hyperbolic space
- Hyperboloid: Lorentz/Minkowski model of hyperbolic space
All manifolds share a common interface defined by the Manifold protocol and support:
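- Automatic dtype casting: pass dtype=jnp.float64 for higher precision (64-bit dtypes require enabling JAX's jax_enable_x64 flag; otherwise JAX truncates them to float32)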
- vmap-native methods: Methods operate on single points; use jax.vmap for batching
- JIT compatibility: All methods are JIT-compilable
Manifold Protocol
hyperbolix.manifolds.protocol.Manifold
Bases: Protocol
Structural protocol for manifold classes.
All three concrete manifold classes (Poincare, Hyperboloid,
Euclidean) satisfy this protocol without modification.
The method signatures use the minimal common interface so that
manifold-specific optional parameters (e.g. version_idx, atol)
do not break compatibility.
Euclidean
Flat Euclidean space (identity operations).
hyperbolix.manifolds.euclidean.Euclidean
Euclidean(dtype: dtype = jnp.float32)
Bases: ManifoldBase
Euclidean manifold with automatic dtype casting.
Provides all manifold operations with automatic casting of array inputs
to the specified dtype.
Args:
dtype: Target JAX dtype for computations (default: jnp.float32)
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.euclidean import Euclidean
>>>
>>> manifold = Euclidean(dtype=jnp.float64)
>>> x = jnp.array([1.0, 2.0])
>>> y = jnp.array([3.0, 4.0])
>>> d = manifold.dist(x, y)
Source code in hyperbolix/manifolds/_base.py
| def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
self.dtype = dtype
|
proj
proj(
x: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]
Project point onto Euclidean space (identity).
Source code in hyperbolix/manifolds/euclidean.py
| def proj(self, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Project point onto Euclidean space (identity)."""
return _proj(self._cast(x), c)
|
addition
addition(
x: Float[Array, dim],
y: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Add Euclidean points.
Source code in hyperbolix/manifolds/euclidean.py
| def addition(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Add Euclidean points."""
return _addition(self._cast(x), self._cast(y), c)
|
scalar_mul
scalar_mul(
r: float, x: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]
Scalar multiplication.
Source code in hyperbolix/manifolds/euclidean.py
| def scalar_mul(self, r: float, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Scalar multiplication."""
x = self._cast(x)
r_cast = jnp.asarray(r, dtype=x.dtype)
return _scalar_mul(r_cast, x, c) # type: ignore[arg-type]
|
dist
dist(
x: Float[Array, dim],
y: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, ""]
Compute distance.
Source code in hyperbolix/manifolds/euclidean.py
| def dist(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
"""Compute distance."""
return _dist(self._cast(x), self._cast(y), c)
|
dist_0
dist_0(
x: Float[Array, dim], c: float = 0.0
) -> Float[Array, ""]
Distance from origin.
Source code in hyperbolix/manifolds/euclidean.py
| def dist_0(self, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
"""Distance from origin."""
return _dist_0(self._cast(x), c)
|
expmap
expmap(
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Exponential map.
Source code in hyperbolix/manifolds/euclidean.py
| def expmap(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Exponential map."""
return _expmap(self._cast(v), self._cast(x), c)
|
expmap_0
expmap_0(
v: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]
Exponential map from origin.
Source code in hyperbolix/manifolds/euclidean.py
| def expmap_0(self, v: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Exponential map from origin."""
return _expmap_0(self._cast(v), c)
|
retraction
retraction(
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Retraction.
Source code in hyperbolix/manifolds/euclidean.py
| def retraction(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Retraction."""
return _retraction(self._cast(v), self._cast(x), c)
|
logmap
logmap(
y: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Logarithmic map.
Source code in hyperbolix/manifolds/euclidean.py
| def logmap(self, y: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Logarithmic map."""
return _logmap(self._cast(y), self._cast(x), c)
|
logmap_0
logmap_0(
y: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]
Logarithmic map from origin.
Source code in hyperbolix/manifolds/euclidean.py
| def logmap_0(self, y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Logarithmic map from origin."""
return _logmap_0(self._cast(y), c)
|
ptransp
ptransp(
v: Float[Array, dim],
x: Float[Array, dim],
y: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Parallel transport.
Source code in hyperbolix/manifolds/euclidean.py
| def ptransp(
self, v: Float[Array, "dim"], x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0
) -> Float[Array, "dim"]:
"""Parallel transport."""
return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)
|
ptransp_0
ptransp_0(
v: Float[Array, dim],
y: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Parallel transport from origin.
Source code in hyperbolix/manifolds/euclidean.py
| def ptransp_0(self, v: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Parallel transport from origin."""
return _ptransp_0(self._cast(v), self._cast(y), c)
|
tangent_inner
tangent_inner(
u: Float[Array, dim],
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, ""]
Tangent inner product.
Source code in hyperbolix/manifolds/euclidean.py
| def tangent_inner(
self, u: Float[Array, "dim"], v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0
) -> Float[Array, ""]:
"""Tangent inner product."""
return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)
|
tangent_norm
tangent_norm(
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, ""]
Tangent norm.
Source code in hyperbolix/manifolds/euclidean.py
| def tangent_norm(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
"""Tangent norm."""
return _tangent_norm(self._cast(v), self._cast(x), c)
|
egrad2rgrad
egrad2rgrad(
grad: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Euclidean to Riemannian gradient.
Source code in hyperbolix/manifolds/euclidean.py
| def egrad2rgrad(self, grad: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Euclidean to Riemannian gradient."""
return _egrad2rgrad(self._cast(grad), self._cast(x), c)
|
tangent_proj
tangent_proj(
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Float[Array, dim]
Project onto tangent space.
Source code in hyperbolix/manifolds/euclidean.py
| def tangent_proj(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
"""Project onto tangent space."""
return _tangent_proj(self._cast(v), self._cast(x), c)
|
is_in_manifold
is_in_manifold(
x: Float[Array, dim], c: float = 0.0
) -> Array
Check if on manifold.
Source code in hyperbolix/manifolds/euclidean.py
| def is_in_manifold(self, x: Float[Array, "dim"], c: float = 0.0) -> Array:
"""Check if on manifold."""
return _is_in_manifold(self._cast(x), c)
|
is_in_tangent_space
is_in_tangent_space(
v: Float[Array, dim],
x: Float[Array, dim],
c: float = 0.0,
) -> Array
Check if in tangent space.
Source code in hyperbolix/manifolds/euclidean.py
| def is_in_tangent_space(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Array:
"""Check if in tangent space."""
return _is_in_tangent_space(self._cast(v), self._cast(x), c)
|
Poincaré Ball
The Poincaré ball model with Möbius operations.
Distance Versions
The Poincaré dist and dist_0 methods take a version_idx parameter that selects among 4 formulations:
VERSION_MOBIUS_DIRECT (0): Möbius addition formula (default, fastest)
VERSION_MOBIUS (1): Möbius via addition
VERSION_METRIC_TENSOR (2): Direct metric tensor integration
VERSION_LORENTZIAN_PROXY (3): Lorentzian model proxy (best near boundary)
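These constants are module-level names in hyperbolix.manifolds.poincare; reference them as poincare.VERSION_MOBIUS_DIRECT etc., or import them directly (e.g. from hyperbolix.manifolds.poincare import VERSION_MOBIUS_DIRECT).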
hyperbolix.manifolds.poincare.Poincare
Poincare(dtype: dtype = jnp.float32)
Bases: ManifoldBase
Poincaré ball manifold with automatic dtype casting.
Provides all manifold operations with automatic casting of array inputs
to the specified dtype. This eliminates the need for manual casting and
provides better numerical stability control.
Args:
dtype: Target JAX dtype for computations (default: jnp.float32)
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.poincare import Poincare, VERSION_MOBIUS_DIRECT
>>>
>>> # Create manifold with float64 for better precision
>>> manifold = Poincare(dtype=jnp.float64)
>>>
>>> # Arrays are automatically cast to float64
>>> x = jnp.array([0.1, 0.2], dtype=jnp.float32)
>>> y = jnp.array([0.3, 0.4], dtype=jnp.float32)
>>> d = manifold.dist(x, y, c=1.0)
>>> d.dtype # float64 (requires jax_enable_x64; otherwise float32)
Source code in hyperbolix/manifolds/_base.py
| def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
self.dtype = dtype
|
proj
proj(x: Float[Array, dim], c: float) -> Float[Array, dim]
Project point onto Poincaré ball by clipping norm.
Source code in hyperbolix/manifolds/poincare.py
| def proj(self, x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Project point onto Poincaré ball by clipping norm."""
return _proj(self._cast(x), c)
|
gyration
gyration(
x: Float[Array, dim],
y: Float[Array, dim],
z: Float[Array, dim],
c: float,
) -> Float[Array, dim]
Compute gyration gyr[x,y]z to restore commutativity.
Source code in hyperbolix/manifolds/poincare.py
| def gyration(
self, x: Float[Array, "dim"], y: Float[Array, "dim"], z: Float[Array, "dim"], c: float
) -> Float[Array, "dim"]:
"""Compute gyration gyr[x,y]z to restore commutativity."""
return _gyration(self._cast(x), self._cast(y), self._cast(z), c)
|
addition
addition(
x: Float[Array, dim], y: Float[Array, dim], c: float
) -> Float[Array, dim]
Möbius gyrovector addition x ⊕ y.
Source code in hyperbolix/manifolds/poincare.py
| def addition(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Möbius gyrovector addition x ⊕ y."""
return _addition(self._cast(x), self._cast(y), c)
|
scalar_mul
scalar_mul(
r: float, x: Float[Array, dim], c: float
) -> Float[Array, dim]
Scalar multiplication r ⊗ x on Poincaré ball.
Source code in hyperbolix/manifolds/poincare.py
| def scalar_mul(self, r: float, x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Scalar multiplication r ⊗ x on Poincaré ball."""
x = self._cast(x)
r_cast = jnp.asarray(r, dtype=x.dtype)
return _scalar_mul(r_cast, x, c) # type: ignore[arg-type]
|
dist
dist(
x: Float[Array, dim],
y: Float[Array, dim],
c: float,
version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]
Compute geodesic distance between Poincaré ball points.
Source code in hyperbolix/manifolds/poincare.py
| def dist(
self,
x: Float[Array, "dim"],
y: Float[Array, "dim"],
c: float,
version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]:
"""Compute geodesic distance between Poincaré ball points."""
return _dist(self._cast(x), self._cast(y), c, version_idx)
|
dist_0
dist_0(
x: Float[Array, dim],
c: float,
version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]
Compute geodesic distance from Poincaré ball origin.
Source code in hyperbolix/manifolds/poincare.py
| def dist_0(self, x: Float[Array, "dim"], c: float, version_idx: int = VERSION_MOBIUS_DIRECT) -> Float[Array, ""]:
"""Compute geodesic distance from Poincaré ball origin."""
return _dist_0(self._cast(x), c, version_idx)
|
expmap
expmap(
v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]
Exponential map: map tangent vector v at point x to manifold.
Source code in hyperbolix/manifolds/poincare.py
| def expmap(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Exponential map: map tangent vector v at point x to manifold."""
return _expmap(self._cast(v), self._cast(x), c)
|
expmap_0
expmap_0(
v: Float[Array, dim], c: float
) -> Float[Array, dim]
Exponential map from origin: map tangent vector v at origin to manifold.
Source code in hyperbolix/manifolds/poincare.py
| def expmap_0(self, v: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Exponential map from origin: map tangent vector v at origin to manifold."""
return _expmap_0(self._cast(v), c)
|
retraction
retraction(
v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]
Retraction: first-order approximation of exponential map.
Source code in hyperbolix/manifolds/poincare.py
| def retraction(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Retraction: first-order approximation of exponential map."""
return _retraction(self._cast(v), self._cast(x), c)
|
logmap
logmap(
y: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]
Logarithmic map: map point y to tangent space at point x.
Source code in hyperbolix/manifolds/poincare.py
| def logmap(self, y: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Logarithmic map: map point y to tangent space at point x."""
return _logmap(self._cast(y), self._cast(x), c)
|
logmap_0
logmap_0(
y: Float[Array, dim], c: float
) -> Float[Array, dim]
Logarithmic map from origin: map point y to tangent space at origin.
Source code in hyperbolix/manifolds/poincare.py
| def logmap_0(self, y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Logarithmic map from origin: map point y to tangent space at origin."""
return _logmap_0(self._cast(y), c)
|
ptransp
ptransp(
v: Float[Array, dim],
x: Float[Array, dim],
y: Float[Array, dim],
c: float,
) -> Float[Array, dim]
Parallel transport tangent vector v from point x to point y.
Source code in hyperbolix/manifolds/poincare.py
| def ptransp(self, v: Float[Array, "dim"], x: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Parallel transport tangent vector v from point x to point y."""
return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)
|
ptransp_0
ptransp_0(
v: Float[Array, dim], y: Float[Array, dim], c: float
) -> Float[Array, dim]
Parallel transport tangent vector v from origin to point y.
Source code in hyperbolix/manifolds/poincare.py
| def ptransp_0(self, v: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Parallel transport tangent vector v from origin to point y."""
return _ptransp_0(self._cast(v), self._cast(y), c)
|
tangent_inner
tangent_inner(
u: Float[Array, dim],
v: Float[Array, dim],
x: Float[Array, dim],
c: float,
) -> Float[Array, ""]
Compute inner product of tangent vectors u and v at point x.
Source code in hyperbolix/manifolds/poincare.py
| def tangent_inner(
self, u: Float[Array, "dim"], v: Float[Array, "dim"], x: Float[Array, "dim"], c: float
) -> Float[Array, ""]:
"""Compute inner product of tangent vectors u and v at point x."""
return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)
|
tangent_norm
tangent_norm(
v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, ""]
Compute norm of tangent vector v at point x.
Source code in hyperbolix/manifolds/poincare.py
| def tangent_norm(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, ""]:
"""Compute norm of tangent vector v at point x."""
return _tangent_norm(self._cast(v), self._cast(x), c)
|
egrad2rgrad
egrad2rgrad(
grad: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]
Convert Euclidean gradient to Riemannian gradient.
Source code in hyperbolix/manifolds/poincare.py
| def egrad2rgrad(self, grad: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Convert Euclidean gradient to Riemannian gradient."""
return _egrad2rgrad(self._cast(grad), self._cast(x), c)
|
tangent_proj
tangent_proj(
v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]
Project vector v onto tangent space at point x.
Source code in hyperbolix/manifolds/poincare.py
| def tangent_proj(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
"""Project vector v onto tangent space at point x."""
return _tangent_proj(self._cast(v), self._cast(x), c)
|
is_in_manifold
is_in_manifold(
x: Float[Array, dim], c: float, atol: float = 1e-05
) -> Array
Check if point x lies in Poincaré ball.
Source code in hyperbolix/manifolds/poincare.py
| def is_in_manifold(self, x: Float[Array, "dim"], c: float, atol: float = 1e-5) -> Array:
"""Check if point x lies in Poincaré ball."""
return _is_in_manifold(self._cast(x), c, atol)
|
is_in_tangent_space
is_in_tangent_space(
v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Array
Check if vector v lies in tangent space at point x.
Source code in hyperbolix/manifolds/poincare.py
| def is_in_tangent_space(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Array:
"""Check if vector v lies in tangent space at point x."""
return _is_in_tangent_space(self._cast(v), self._cast(x), c)
|
conformal_factor
conformal_factor(
x: Float[Array, "... dim"], c: float
) -> Float[Array, "... 1"]
Numerically stable conformal factor lambda(x) = 2 / (1 - c||x||^2).
Batch-compatible version that handles arbitrary leading dimensions.
Source code in hyperbolix/manifolds/poincare.py
| def conformal_factor(self, x: Float[Array, "... dim"], c: float) -> Float[Array, "... 1"]:
"""Numerically stable conformal factor lambda(x) = 2 / (1 - c||x||^2).
Batch-compatible version that handles arbitrary leading dimensions.
"""
return _conformal_factor_batch(self._cast(x), c)
|
compute_mlr_pp
compute_mlr_pp(
x: Float[Array, "batch in_dim"],
z: Float[Array, "out_dim in_dim"],
r: Float[Array, "out_dim 1"],
c: float,
clamping_factor: float,
smoothing_factor: float,
min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]
Compute HNN++ multinomial logistic regression (MLR) on the Poincaré ball.
Source code in hyperbolix/manifolds/poincare.py
| def compute_mlr_pp(
self,
x: Float[Array, "batch in_dim"],
z: Float[Array, "out_dim in_dim"],
r: Float[Array, "out_dim 1"],
c: float,
clamping_factor: float,
smoothing_factor: float,
min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]:
"""Compute HNN++ multinomial linear regression on the Poincare ball."""
return _compute_mlr_pp(self._cast(x), self._cast(z), self._cast(r), c, clamping_factor, smoothing_factor, min_enorm)
|
beta_concat
beta_concat(
points: Float[Array, "M n_i"], c: float
) -> Float[Array, n]
Beta-concatenation of M equal-dimensional Poincaré ball points.
Source code in hyperbolix/manifolds/poincare.py
| def beta_concat(self, points: Float[Array, "M n_i"], c: float) -> Float[Array, "n"]:
"""Beta-concatenation of M equal-dimensional Poincaré ball points."""
return _beta_concat(self._cast(points), c)
|
Hyperboloid
The hyperboloid (Lorentz) model with Minkowski geometry.
Lorentz Operations
The Hyperboloid class includes specialized operations for convolutional layers:
lorentz_boost: Lorentz boost transformation
distance_rescale: Distance-based rescaling
hcat: Lorentz direct concatenation for convolutions
hyperbolix.manifolds.hyperboloid.Hyperboloid
Hyperboloid(dtype: dtype = jnp.float32)
Bases: ManifoldBase
Hyperboloid manifold with automatic dtype casting.
Provides all manifold operations with automatic casting of array inputs
to the specified dtype. This eliminates the need for manual casting and
provides better numerical stability control.
Args:
dtype: Target JAX dtype for computations (default: jnp.float32)
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.hyperboloid import Hyperboloid, VERSION_DEFAULT
>>>
>>> # Create manifold with float64 for better precision
>>> manifold = Hyperboloid(dtype=jnp.float64)
>>>
>>> # Arrays are automatically cast to float64
>>> x = jnp.array([1.0, 0.1, 0.2], dtype=jnp.float32)
>>> x = manifold.proj(x, c=1.0)
>>> x.dtype # float64 (requires jax_enable_x64; otherwise float32)
Source code in hyperbolix/manifolds/_base.py
| def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
self.dtype = dtype
|
create_origin
create_origin(
c: float, dim: int
) -> Float[Array, dim_plus_1]
Create hyperboloid origin [1/√c, 0, ..., 0].
Source code in hyperbolix/manifolds/hyperboloid.py
| def create_origin(self, c: float, dim: int) -> Float[Array, "dim_plus_1"]:
"""Create hyperboloid origin [1/√c, 0, ..., 0]."""
return _create_origin(c, dim, self.dtype)
|
minkowski_inner
minkowski_inner(
x: Float[Array, dim_plus_1], y: Float[Array, dim_plus_1]
) -> Float[Array, ""]
Compute Minkowski inner product ⟨x, y⟩_L = -x₀y₀ + ⟨x_rest, y_rest⟩.
Source code in hyperbolix/manifolds/hyperboloid.py
| def minkowski_inner(self, x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"]) -> Float[Array, ""]:
"""Compute Minkowski inner product ⟨x, y⟩_L = -x₀y₀ + ⟨x_rest, y_rest⟩."""
return _minkowski_inner(self._cast(x), self._cast(y))
|
proj
proj(
x: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]
Project point onto hyperboloid.
Source code in hyperbolix/manifolds/hyperboloid.py
| def proj(self, x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Project point onto hyperboloid."""
return _proj(self._cast(x), c)
|
proj_batch
proj_batch(
x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]
Project batched points onto hyperboloid (handles arbitrary leading dimensions).
Source code in hyperbolix/manifolds/hyperboloid.py
| def proj_batch(self, x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
"""Project batched points onto hyperboloid (handles arbitrary leading dimensions)."""
return _proj_batch(self._cast(x), c)
|
addition
addition(
x: Float[Array, dim_plus_1],
y: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Gyrovector addition on hyperboloid.
Source code in hyperbolix/manifolds/hyperboloid.py
| def addition(self, x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Gyrovector addition on hyperboloid."""
return _addition(self._cast(x), self._cast(y), c)
|
scalar_mul
scalar_mul(
r: float, x: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]
Scalar multiplication on hyperboloid.
Source code in hyperbolix/manifolds/hyperboloid.py
| def scalar_mul(self, r: float, x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Scalar multiplication on hyperboloid."""
x = self._cast(x)
r_cast = jnp.asarray(r, dtype=x.dtype)
return _scalar_mul(r_cast, x, c) # type: ignore[arg-type]
|
dist
dist(
x: Float[Array, dim_plus_1],
y: Float[Array, dim_plus_1],
c: float,
version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]
Compute geodesic distance between hyperboloid points.
Source code in hyperbolix/manifolds/hyperboloid.py
| def dist(
self,
x: Float[Array, "dim_plus_1"],
y: Float[Array, "dim_plus_1"],
c: float,
version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]:
"""Compute geodesic distance between hyperboloid points."""
return _dist(self._cast(x), self._cast(y), c, version_idx)
|
dist_0
dist_0(
x: Float[Array, dim_plus_1],
c: float,
version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]
Compute geodesic distance from hyperboloid origin.
Source code in hyperbolix/manifolds/hyperboloid.py
| def dist_0(self, x: Float[Array, "dim_plus_1"], c: float, version_idx: int = VERSION_DEFAULT) -> Float[Array, ""]:
"""Compute geodesic distance from hyperboloid origin."""
return _dist_0(self._cast(x), c, version_idx)
|
expmap
expmap(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Exponential map: map tangent vector v at point x to manifold.
Source code in hyperbolix/manifolds/hyperboloid.py
| def expmap(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Exponential map: map tangent vector v at point x to manifold."""
return _expmap(self._cast(v), self._cast(x), c)
|
expmap_0
expmap_0(
v: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]
Exponential map from origin.
Source code in hyperbolix/manifolds/hyperboloid.py
| def expmap_0(self, v: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Exponential map from origin."""
return _expmap_0(self._cast(v), c)
|
retraction
retraction(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Retraction: first-order approximation of exponential map.
Source code in hyperbolix/manifolds/hyperboloid.py
| def retraction(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Retraction: first-order approximation of exponential map."""
return _retraction(self._cast(v), self._cast(x), c)
|
logmap
logmap(
y: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Logarithmic map: map point y to tangent space at point x.
Source code in hyperbolix/manifolds/hyperboloid.py
| def logmap(self, y: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Logarithmic map: map point y to tangent space at point x."""
return _logmap(self._cast(y), self._cast(x), c)
|
logmap_0
logmap_0(
y: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]
Logarithmic map from origin.
Source code in hyperbolix/manifolds/hyperboloid.py
| def logmap_0(self, y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Logarithmic map from origin."""
return _logmap_0(self._cast(y), c)
|
ptransp
ptransp(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
y: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Parallel transport tangent vector v from point x to point y.
Source code in hyperbolix/manifolds/hyperboloid.py
| def ptransp(
self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
"""Parallel transport tangent vector v from point x to point y."""
return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)
|
ptransp_0
ptransp_0(
v: Float[Array, dim_plus_1],
y: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Parallel transport tangent vector v from origin to point y.
Source code in hyperbolix/manifolds/hyperboloid.py
| def ptransp_0(self, v: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
"""Parallel transport tangent vector v from origin to point y."""
return _ptransp_0(self._cast(v), self._cast(y), c)
|
tangent_inner
tangent_inner(
u: Float[Array, dim_plus_1],
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, ""]
Compute inner product of tangent vectors u and v at point x.
Source code in hyperbolix/manifolds/hyperboloid.py
| def tangent_inner(
self, u: Float[Array, "dim_plus_1"], v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, ""]:
"""Compute inner product of tangent vectors u and v at point x."""
return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)
|
tangent_norm
tangent_norm(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, ""]
Compute norm of tangent vector v at point x.
Source code in hyperbolix/manifolds/hyperboloid.py
| def tangent_norm(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, ""]:
"""Compute norm of tangent vector v at point x."""
return _tangent_norm(self._cast(v), self._cast(x), c)
|
egrad2rgrad
egrad2rgrad(
grad: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Convert Euclidean gradient to Riemannian gradient.
Source code in hyperbolix/manifolds/hyperboloid.py
| def egrad2rgrad(
self, grad: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
"""Convert Euclidean gradient to Riemannian gradient."""
return _egrad2rgrad(self._cast(grad), self._cast(x), c)
|
tangent_proj
tangent_proj(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Float[Array, dim_plus_1]
Project vector v onto tangent space at point x.
Source code in hyperbolix/manifolds/hyperboloid.py
| def tangent_proj(
self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
"""Project vector v onto tangent space at point x."""
return _tangent_proj(self._cast(v), self._cast(x), c)
|
is_in_manifold
is_in_manifold(
x: Float[Array, dim_plus_1],
c: float,
atol: float = 0.0001,
) -> Array
Check if point x lies on hyperboloid.
Source code in hyperbolix/manifolds/hyperboloid.py
| def is_in_manifold(self, x: Float[Array, "dim_plus_1"], c: float, atol: float = 1e-4) -> Array:
"""Check if point x lies on hyperboloid."""
return _is_in_manifold(self._cast(x), c, atol)
|
is_in_tangent_space
is_in_tangent_space(
v: Float[Array, dim_plus_1],
x: Float[Array, dim_plus_1],
c: float,
) -> Array
Check if vector v lies in tangent space at point x.
Source code in hyperbolix/manifolds/hyperboloid.py
| def is_in_tangent_space(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Array:
"""Check if vector v lies in tangent space at point x."""
return _is_in_tangent_space(self._cast(v), self._cast(x), c)
|
hcat
hcat(
points: Float[Array, "N n"], c: float = 1.0
) -> Float[Array, dN_plus_1]
Hyperbolic concatenation of N points into one point.
Source code in hyperbolix/manifolds/hyperboloid.py
| def hcat(
self,
points: Float[Array, "N n"],
c: float = 1.0,
) -> Float[Array, "dN_plus_1"]:
"""Hyperbolic concatenation of N points into one point."""
return _hcat(self._cast(points), c)
|
embed_spatial_0
embed_spatial_0(
v_spatial: Float[Array, "... n"],
) -> Float[Array, "... n_plus_1"]
Embed spatial vector as tangent vector at origin.
Source code in hyperbolix/manifolds/hyperboloid.py
| def embed_spatial_0(self, v_spatial: Float[Array, "... n"]) -> Float[Array, "... n_plus_1"]:
"""Embed spatial vector as tangent vector at origin."""
return _embed_spatial_0(self._cast(v_spatial))
|
compute_mlr
compute_mlr(
x: Float[Array, "batch in_dim"],
z: Float[Array, "out_dim in_dim_minus_1"],
r: Float[Array, "out_dim 1"],
c: float,
clamping_factor: float,
smoothing_factor: float,
min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]
Compute multinomial logistic regression (MLR) on the hyperboloid.
Source code in hyperbolix/manifolds/hyperboloid.py
| def compute_mlr(
self,
x: Float[Array, "batch in_dim"],
z: Float[Array, "out_dim in_dim_minus_1"],
r: Float[Array, "out_dim 1"],
c: float,
clamping_factor: float,
smoothing_factor: float,
min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]:
"""Compute multinomial linear regression on hyperboloid."""
return _compute_mlr(self._cast(x), self._cast(z), self._cast(r), c, clamping_factor, smoothing_factor, min_enorm)
|
Isometry Mappings
Distance-preserving maps between Poincaré ball and hyperboloid models.
hyperbolix.manifolds.isometry_mappings
Isometry mappings between hyperbolic manifold models.
This module implements distance-preserving transformations (isometries) between
different models of hyperbolic geometry. All functions operate on single points
and use JAX's vmap for batch operations.
Supported Models:
- Hyperboloid model (Lorentz model): Points in R^(d+1) satisfying ⟨x,x⟩_L = -1/c
- Poincaré ball model: Points in R^d with ||y||² < 1/c
The mappings are implemented via stereographic projection from the hyperboloid
to the Poincaré ball, projecting through the point [-1/√c, 0, ..., 0] (i.e. [-1, 0, ..., 0] when c = 1).
JIT Compilation & Batching
All functions work with single points and return single points.
Use jax.vmap for batch operations:
>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Single point conversion
>>> x_hyp = jnp.array([jnp.sqrt(1.05), 0.1, 0.2]) # Hyperboloid point for c=1: t = sqrt(1 + 0.1² + 0.2²)
>>> y_poinc = isometry_mappings.hyperboloid_to_poincare(x_hyp, c=1.0)
>>>
>>> # Batch conversion with vmap
>>> x_batch = jnp.array([[jnp.sqrt(1.05), 0.1, 0.2], [jnp.sqrt(1.085), 0.15, 0.25]])  # each row lies on the c=1 hyperboloid
>>> convert_batch = jax.vmap(isometry_mappings.hyperboloid_to_poincare, in_axes=(0, None))
>>> y_batch = convert_batch(x_batch, 1.0)
References:
Wikipedia: Hyperboloid model
https://en.wikipedia.org/wiki/Hyperboloid_model#Relation_to_other_models
hyperboloid_to_poincare
hyperboloid_to_poincare(
x: Float[Array, dim_plus_1], c: Float[Array, ""] | float
) -> Float[Array, dim]
Convert hyperboloid point to Poincaré ball via stereographic projection.
Projects the hyperboloid point onto the hyperplane t = 0 by intersecting
with a line through [-1/√c, 0, ..., 0] ([-1, 0, ..., 0] when c = 1). This implements
the canonical isometry between the two models.
Formula:
y_i = x_i / (1 + √c · t)
where x = [t, x_1, ..., x_n] on hyperboloid
Args:
x: Point on hyperboloid, shape (dim+1,). Should satisfy ⟨x,x⟩_L = -1/c.
c: Curvature (positive)
Returns:
Point in Poincaré ball, shape (dim,). Satisfies ||y||² < 1/c.
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert hyperboloid origin to Poincaré origin
>>> x_origin = jnp.array([1.0, 0.0, 0.0]) # c=1.0 origin
>>> y = isometry_mappings.hyperboloid_to_poincare(x_origin, c=1.0)
>>> jnp.allclose(y, jnp.zeros(2))
True
References:
Wikipedia: Hyperboloid model - Relation to other models
Source code in hyperbolix/manifolds/isometry_mappings.py
| def hyperboloid_to_poincare(
x: Float[Array, "dim_plus_1"],
c: Float[Array, ""] | float,
) -> Float[Array, "dim"]:
"""Convert hyperboloid point to Poincaré ball via stereographic projection.
Projects the hyperboloid point onto the hyperplane t = 0 by intersecting
with a line through [-1, 0, ..., 0]. This implements the canonical isometry
between the two models.
Formula:
y_i = x_i / (1 + t)
where x = [t, x_1, ..., x_n] on hyperboloid
Args:
x: Point on hyperboloid, shape (dim+1,). Should satisfy ⟨x,x⟩_L = -1/c.
c: Curvature (positive)
Returns:
Point in Poincaré ball, shape (dim,). Satisfies ||y||² < 1/c.
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert hyperboloid origin to Poincaré origin
>>> x_origin = jnp.array([1.0, 0.0, 0.0]) # c=1.0 origin
>>> y = isometry_mappings.hyperboloid_to_poincare(x_origin, c=1.0)
>>> jnp.allclose(y, jnp.zeros(2))
True
References:
Wikipedia: Hyperboloid model - Relation to other models
"""
t = x[0] # Temporal component
x_spatial = x[1:] # Spatial components (x_1, ..., x_n)
# Stereographic projection: y_i = x_i / (1 + t)
denominator = jnp.maximum(1.0 + t, MIN_DENOM)
return x_spatial / denominator
|
poincare_to_hyperboloid
poincare_to_hyperboloid(
y: Float[Array, dim], c: Float[Array, ""] | float
) -> Float[Array, dim_plus_1]
Convert Poincaré ball point to hyperboloid via inverse stereographic projection.
Inverts the stereographic projection to map points from the Poincaré ball
back to the hyperboloid. This implements the canonical isometry between
the two models.
Formula:
(t, x_i) = ((1 + c·Σy_i²) / √c, 2y_i) / (1 - c·Σy_i²)
where y = [y_1, ..., y_n] in Poincaré ball
Args:
y: Point in Poincaré ball, shape (dim,). Should satisfy ||y||² < 1/c.
c: Curvature (positive)
Returns:
Point on hyperboloid, shape (dim+1,). Satisfies ⟨x,x⟩_L = -1/c.
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert Poincaré origin to hyperboloid origin
>>> y_origin = jnp.array([0.0, 0.0])
>>> x = isometry_mappings.poincare_to_hyperboloid(y_origin, c=1.0)
>>> jnp.allclose(x, jnp.array([1.0, 0.0, 0.0]))
True
References:
Wikipedia: Hyperboloid model - Relation to other models
Source code in hyperbolix/manifolds/isometry_mappings.py
| def poincare_to_hyperboloid(
y: Float[Array, "dim"],
c: Float[Array, ""] | float,
) -> Float[Array, "dim_plus_1"]:
"""Convert Poincaré ball point to hyperboloid via inverse stereographic projection.
Inverts the stereographic projection to map points from the Poincaré ball
back to the hyperboloid. This implements the canonical isometry between
the two models.
Formula:
(t, x_i) = ((1 + Σy_i²), 2y_i) / ((1 - Σy_i²) * √c)
where y = [y_1, ..., y_n] in Poincaré ball
Args:
y: Point in Poincaré ball, shape (dim,). Should satisfy ||y||² < 1/c.
c: Curvature (positive)
Returns:
Point on hyperboloid, shape (dim+1,). Satisfies ⟨x,x⟩_L = -1/c.
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert Poincaré origin to hyperboloid origin
>>> y_origin = jnp.array([0.0, 0.0])
>>> x = isometry_mappings.poincare_to_hyperboloid(y_origin, c=1.0)
>>> jnp.allclose(x, jnp.array([1.0, 0.0, 0.0]))
True
References:
Wikipedia: Hyperboloid model - Relation to other models
"""
y_sqnorm = jnp.dot(y, y)
sqrt_c = jnp.sqrt(c)
# Inverse stereographic projection with curvature scaling
numerator = 1.0 + y_sqnorm
denominator = jnp.maximum((1.0 - y_sqnorm) * sqrt_c, MIN_DENOM)
t = numerator / denominator
x_spatial = 2.0 * y / denominator
# Concatenate temporal and spatial components: [t, x_1, ..., x_n]
return jnp.concatenate([jnp.array([t]), x_spatial])
|
Usage Examples
Basic Distance Computation
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

c = 1.0  # curvature
poincare = Poincare()

# Two points inside the Poincaré ball
x, y = jnp.array([0.1, 0.2]), jnp.array([0.3, -0.1])

# Geodesic distance (default: VERSION_MOBIUS_DIRECT)
distance = poincare.dist(x, y, c)
Float64 Precision
import jax
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

# JAX silently truncates float64 to float32 unless x64 mode is enabled.
# Do this at program start, before any arrays are created.
jax.config.update("jax_enable_x64", True)

# High-precision manifold
poincare_f64 = Poincare(dtype=jnp.float64)
# Explicit float32 inputs: with x64 enabled, jnp.array would default to float64
x = jnp.array([0.1, 0.2], dtype=jnp.float32)
y = jnp.array([0.3, -0.1], dtype=jnp.float32)
distance = poincare_f64.dist(x, y, c=1.0)  # inputs automatically cast to float64
print(distance.dtype)  # float64
Batched Operations with vmap
import jax
from hyperbolix.manifolds import Hyperboloid

hyperboloid = Hyperboloid()
c = 1.0

# 100 random ambient points with d + 1 = 4 coordinates each
key_x, key_y = jax.random.PRNGKey(0), jax.random.PRNGKey(1)
x_batch = jax.random.normal(key_x, (100, 4))
y_batch = jax.random.normal(key_y, (100, 4))

# Methods act on single points: map over the batch axis, broadcast the curvature
batched_proj = jax.vmap(hyperboloid.proj, in_axes=(0, None))
batched_dist = jax.vmap(hyperboloid.dist, in_axes=(0, 0, None))

# Project onto the hyperboloid, then compute row-wise distances
x_proj = batched_proj(x_batch, c)
y_proj = batched_proj(y_batch, c)
distances = batched_dist(x_proj, y_proj, c)
Exponential and Logarithmic Maps
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

poincare = Poincare()
c = 1.0

# Base point: a raw array mapped onto the manifold with proj
x = poincare.proj(jnp.array([0.2, 0.3]), c=c)
# Tangent vector at x
v = jnp.array([0.1, -0.05])

# Exponential map: follow the geodesic from x in direction v
y = poincare.expmap(v, x, c=c)
# Logarithmic map: the inverse operation, recovering the tangent vector
v_recovered = poincare.logmap(y, x, c=c)
Isometry Mappings
import jax.numpy as jnp
from hyperbolix.manifolds import isometry_mappings

c = 1.0
# Hyperboloid point (ambient coordinates, d+1 dims). It must satisfy the
# Lorentz constraint -t² + ||x_s||² = -1/c, so derive the time component t
# from the spatial part instead of hard-coding it.
x_spatial = jnp.array([0.5, 0.3])
t = jnp.sqrt(1.0 / c + jnp.dot(x_spatial, x_spatial))
x_hyperboloid = jnp.concatenate([t[None], x_spatial])
# Map to Poincaré ball (intrinsic coordinates, d dims)
x_poincare = isometry_mappings.hyperboloid_to_poincare(x_hyperboloid, c=c)
# Map back (round-trip): recovers x_hyperboloid up to floating-point error
x_hyperboloid_recovered = isometry_mappings.poincare_to_hyperboloid(x_poincare, c=c)
Numerical Considerations
Float32 Precision
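Float32 can cause numerical issues, especially in the Poincaré ball near the boundary. Use Poincare(dtype=jnp.float64) for the cases below. Also enable JAX's 64-bit mode with jax.config.update("jax_enable_x64", True); otherwise float64 is silently truncated to float32.
- High curvature values (c > 1.0)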
- Points near manifold boundaries
- Deep neural networks with many layers
See the Numerical Stability guide for details.