
Manifolds API

This page documents the core manifold operations in Hyperbolix. Each manifold is a class that provides geometric operations and automatic dtype casting.

Overview

Hyperbolix provides three manifold classes:

  • Euclidean: Flat Euclidean space (baseline)
  • Poincaré Ball: Conformal model of hyperbolic space
  • Hyperboloid: Lorentz/Minkowski model of hyperbolic space

All manifolds share a common interface defined by the Manifold protocol and support:

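  • Automatic dtype casting: Pass dtype=jnp.float64 for higher precision (this also requires enabling JAX's 64-bit mode, e.g. jax.config.update("jax_enable_x64", True); otherwise JAX truncates to float32)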
  • vmap-native methods: Methods operate on single points; use jax.vmap for batching
  • JIT compatibility: All methods are JIT-compilable
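Because methods take single points, batching is left to the caller. The sketch below shows the pattern with a hypothetical stand-alone `dist_single` function (plain Euclidean distance) standing in for a manifold method. It shows both a pairwise batch and a "batch against one reference point" broadcast via `in_axes=None`.

```python
import jax
import jax.numpy as jnp


def dist_single(x: jnp.ndarray, y: jnp.ndarray) -> jnp.ndarray:
    """Hypothetical single-point distance: two (dim,) vectors in, one scalar out."""
    return jnp.linalg.norm(x - y)


# Pairwise over two batches: map over axis 0 of both arguments.
batched_dist = jax.jit(jax.vmap(dist_single, in_axes=(0, 0)))

# One reference point against a batch: map over x only, broadcast y.
dist_to_ref = jax.jit(jax.vmap(dist_single, in_axes=(0, None)))

xs = jnp.array([[0.0, 0.0], [3.0, 4.0]])
ys = jnp.array([[1.0, 0.0], [0.0, 0.0]])
print(batched_dist(xs, ys))                    # [1. 5.]
print(dist_to_ref(xs, jnp.array([0.0, 0.0])))  # [0. 5.]
```

With an actual manifold instance, the same idea should apply to the documented signatures. For example, `jax.vmap(manifold.dist, in_axes=(0, 0, None))` would batch `x` and `y` while keeping the curvature `c` unbatched.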

Manifold Protocol

hyperbolix.manifolds.protocol.Manifold

Bases: Protocol

Structural protocol for manifold classes.

All three concrete manifold classes (Poincare, Hyperboloid, Euclidean) satisfy this protocol without modification.

The method signatures use the minimal common interface so that manifold-specific optional parameters (e.g. version_idx, atol) do not break compatibility.
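To illustrate what "structural" means here, the following sketch defines a hypothetical, cut-down protocol (`MinimalManifold`, with only two methods) and a toy class that never inherits from it but still satisfies it. The toy's extra optional parameter mimics how manifold-specific options such as `atol` stay compatible. This is not the real `Manifold` protocol, whose full method list is not reproduced on this page.

```python
from typing import Protocol, runtime_checkable

import jax.numpy as jnp


@runtime_checkable
class MinimalManifold(Protocol):
    """Hypothetical two-method protocol (not hyperbolix's actual Manifold protocol)."""

    def dist(self, x: jnp.ndarray, y: jnp.ndarray, c: float) -> jnp.ndarray: ...

    def expmap_0(self, v: jnp.ndarray, c: float) -> jnp.ndarray: ...


class FlatToy:
    """Toy class that never mentions MinimalManifold but matches it structurally."""

    def dist(self, x, y, c=0.0):
        return jnp.linalg.norm(x - y)

    # An extra optional parameter does not break structural compatibility.
    def expmap_0(self, v, c=0.0, atol=1e-5):
        return v


def dist_from_origin(m: MinimalManifold, x: jnp.ndarray, c: float) -> jnp.ndarray:
    """Generic code written only against the protocol."""
    return m.dist(x, jnp.zeros_like(x), c)


toy = FlatToy()
# runtime_checkable isinstance() only checks that the methods exist;
# signature compatibility is checked by static type checkers, not at runtime.
print(isinstance(toy, MinimalManifold))                   # True
print(dist_from_origin(toy, jnp.array([3.0, 4.0]), 0.0))  # 5.0
```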

Euclidean

Flat Euclidean space (identity operations).
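In flat space the geometric operations reduce to vector arithmetic. The sketch below writes out the textbook flat-space formulas that these "identity operations" correspond to. The formulas are assumed, not read from hyperbolix's source, and the argument order `(v, x)` / `(y, x)` follows the documented signatures.

```python
import jax.numpy as jnp


def expmap(v, x):
    return x + v  # follow the straight line from x along v


def logmap(y, x):
    return y - x  # tangent vector at x pointing to y


def dist(x, y):
    return jnp.linalg.norm(x - y)


def ptransp(v, x, y):
    return v  # all tangent spaces coincide, so transport is the identity


x = jnp.array([1.0, 2.0])
y = jnp.array([3.0, 4.0])
v = logmap(y, x)
print(v)                                            # [2. 2.]
print(expmap(v, x))                                 # [3. 4.]
print(jnp.isclose(jnp.linalg.norm(v), dist(x, y)))  # True
```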

hyperbolix.manifolds.euclidean.Euclidean

Euclidean(dtype: dtype = jnp.float32)

Bases: ManifoldBase

Euclidean manifold with automatic dtype casting.

Provides all manifold operations with automatic casting of array inputs to the specified dtype.

Args:
    dtype: Target JAX dtype for computations (default: jnp.float32)

Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.euclidean import Euclidean
>>>
>>> manifold = Euclidean(dtype=jnp.float64)
>>> x = jnp.array([1.0, 2.0])
>>> y = jnp.array([3.0, 4.0])
>>> d = manifold.dist(x, y)

Source code in hyperbolix/manifolds/_base.py
def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
    self.dtype = dtype

proj

proj(
    x: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]

Project point onto Euclidean space (identity).

Source code in hyperbolix/manifolds/euclidean.py
def proj(self, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Project point onto Euclidean space (identity)."""
    return _proj(self._cast(x), c)

addition

addition(
    x: Float[Array, dim],
    y: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Add Euclidean points.

Source code in hyperbolix/manifolds/euclidean.py
def addition(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Add Euclidean points."""
    return _addition(self._cast(x), self._cast(y), c)

scalar_mul

scalar_mul(
    r: float, x: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]

Scalar multiplication.

Source code in hyperbolix/manifolds/euclidean.py
def scalar_mul(self, r: float, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Scalar multiplication."""
    x = self._cast(x)
    r_cast = jnp.asarray(r, dtype=x.dtype)
    return _scalar_mul(r_cast, x, c)  # type: ignore[arg-type]

dist

dist(
    x: Float[Array, dim],
    y: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, ""]

Compute distance.

Source code in hyperbolix/manifolds/euclidean.py
def dist(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
    """Compute distance."""
    return _dist(self._cast(x), self._cast(y), c)

dist_0

dist_0(
    x: Float[Array, dim], c: float = 0.0
) -> Float[Array, ""]

Distance from origin.

Source code in hyperbolix/manifolds/euclidean.py
def dist_0(self, x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
    """Distance from origin."""
    return _dist_0(self._cast(x), c)

expmap

expmap(
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Exponential map.

Source code in hyperbolix/manifolds/euclidean.py
def expmap(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Exponential map."""
    return _expmap(self._cast(v), self._cast(x), c)

expmap_0

expmap_0(
    v: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]

Exponential map from origin.

Source code in hyperbolix/manifolds/euclidean.py
def expmap_0(self, v: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Exponential map from origin."""
    return _expmap_0(self._cast(v), c)

retraction

retraction(
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Retraction.

Source code in hyperbolix/manifolds/euclidean.py
def retraction(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Retraction."""
    return _retraction(self._cast(v), self._cast(x), c)

logmap

logmap(
    y: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Logarithmic map.

Source code in hyperbolix/manifolds/euclidean.py
def logmap(self, y: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Logarithmic map."""
    return _logmap(self._cast(y), self._cast(x), c)

logmap_0

logmap_0(
    y: Float[Array, dim], c: float = 0.0
) -> Float[Array, dim]

Logarithmic map from origin.

Source code in hyperbolix/manifolds/euclidean.py
def logmap_0(self, y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Logarithmic map from origin."""
    return _logmap_0(self._cast(y), c)

ptransp

ptransp(
    v: Float[Array, dim],
    x: Float[Array, dim],
    y: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Parallel transport.

Source code in hyperbolix/manifolds/euclidean.py
def ptransp(
    self, v: Float[Array, "dim"], x: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0
) -> Float[Array, "dim"]:
    """Parallel transport."""
    return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)

ptransp_0

ptransp_0(
    v: Float[Array, dim],
    y: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Parallel transport from origin.

Source code in hyperbolix/manifolds/euclidean.py
def ptransp_0(self, v: Float[Array, "dim"], y: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Parallel transport from origin."""
    return _ptransp_0(self._cast(v), self._cast(y), c)

tangent_inner

tangent_inner(
    u: Float[Array, dim],
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, ""]

Tangent inner product.

Source code in hyperbolix/manifolds/euclidean.py
def tangent_inner(
    self, u: Float[Array, "dim"], v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0
) -> Float[Array, ""]:
    """Tangent inner product."""
    return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)

tangent_norm

tangent_norm(
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, ""]

Tangent norm.

Source code in hyperbolix/manifolds/euclidean.py
def tangent_norm(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, ""]:
    """Tangent norm."""
    return _tangent_norm(self._cast(v), self._cast(x), c)

egrad2rgrad

egrad2rgrad(
    grad: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Euclidean to Riemannian gradient.

Source code in hyperbolix/manifolds/euclidean.py
def egrad2rgrad(self, grad: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Euclidean to Riemannian gradient."""
    return _egrad2rgrad(self._cast(grad), self._cast(x), c)

tangent_proj

tangent_proj(
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Float[Array, dim]

Project onto tangent space.

Source code in hyperbolix/manifolds/euclidean.py
def tangent_proj(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Float[Array, "dim"]:
    """Project onto tangent space."""
    return _tangent_proj(self._cast(v), self._cast(x), c)

is_in_manifold

is_in_manifold(
    x: Float[Array, dim], c: float = 0.0
) -> Array

Check if on manifold.

Source code in hyperbolix/manifolds/euclidean.py
def is_in_manifold(self, x: Float[Array, "dim"], c: float = 0.0) -> Array:
    """Check if on manifold."""
    return _is_in_manifold(self._cast(x), c)

is_in_tangent_space

is_in_tangent_space(
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float = 0.0,
) -> Array

Check if in tangent space.

Source code in hyperbolix/manifolds/euclidean.py
def is_in_tangent_space(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float = 0.0) -> Array:
    """Check if in tangent space."""
    return _is_in_tangent_space(self._cast(v), self._cast(x), c)

Poincaré Ball

The Poincaré ball model with Möbius operations.

Distance Versions

The Poincaré dist method has a version_idx parameter selecting between 4 formulations:

  • VERSION_MOBIUS_DIRECT (0): Möbius addition formula (default, fastest)
  • VERSION_MOBIUS (1): Möbius via addition
  • VERSION_METRIC_TENSOR (2): Direct metric tensor integration
  • VERSION_LORENTZIAN_PROXY (3): Lorentzian model proxy (best near boundary)

Constants are available as poincare.VERSION_MOBIUS_DIRECT etc., or from hyperbolix.manifolds.poincare.
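The formulations are mathematically equivalent and differ only in numerical behaviour. As an illustration, the sketch below compares two textbook closed forms for the Poincaré distance:

- the Möbius form d = (2/√c) artanh(√c‖(−x) ⊕ y‖);
- the arcosh form d = (1/√c) arcosh(1 + 2c‖x − y‖² / ((1 − c‖x‖²)(1 − c‖y‖²))).

It is a stand-alone sketch, not a transcription of what any particular `version_idx` computes.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c**2 * x2 * y2)


def dist_mobius(x, y, c):
    # d(x, y) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-x) (+) y||)
    sqrt_c = jnp.sqrt(c)
    return 2 / sqrt_c * jnp.arctanh(sqrt_c * jnp.linalg.norm(mobius_add(-x, y, c)))


def dist_arcosh(x, y, c):
    # d(x, y) = (1 / sqrt(c)) * arcosh(1 + 2c||x - y||^2 / ((1 - c||x||^2)(1 - c||y||^2)))
    diff2 = jnp.sum((x - y) ** 2)
    denom = (1 - c * jnp.sum(x**2)) * (1 - c * jnp.sum(y**2))
    return jnp.arccosh(1 + 2 * c * diff2 / denom) / jnp.sqrt(c)


x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, 0.4])
print(dist_mobius(x, y, 1.0), dist_arcosh(x, y, 1.0))  # both about 0.658
```

Near the boundary, the artanh argument approaches 1 and the arcosh argument grows large, so the two forms lose precision differently in float32. This is the kind of trade-off that motivates offering several versions.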

hyperbolix.manifolds.poincare.Poincare

Poincare(dtype: dtype = jnp.float32)

Bases: ManifoldBase

Poincaré ball manifold with automatic dtype casting.

Provides all manifold operations with automatic casting of array inputs to the specified dtype. This removes the need for manual casting and gives a single place to control numerical precision.

Args:
    dtype: Target JAX dtype for computations (default: jnp.float32)

Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.poincare import Poincare, VERSION_MOBIUS_DIRECT
>>>
>>> # Create manifold with float64 for better precision
>>> manifold = Poincare(dtype=jnp.float64)
>>>
>>> # Arrays are automatically cast to float64
>>> x = jnp.array([0.1, 0.2], dtype=jnp.float32)
>>> y = jnp.array([0.3, 0.4], dtype=jnp.float32)
>>> d = manifold.dist(x, y, c=1.0)
>>> d.dtype  # float64

Source code in hyperbolix/manifolds/_base.py
def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
    self.dtype = dtype

proj

proj(x: Float[Array, dim], c: float) -> Float[Array, dim]

Project point onto Poincaré ball by clipping norm.

Source code in hyperbolix/manifolds/poincare.py
def proj(self, x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Project point onto Poincaré ball by clipping norm."""
    return _proj(self._cast(x), c)
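"Clipping the norm" can be pictured as a radial rescale into a slightly shrunken ball. The sketch below assumes a hypothetical safety margin `eps` and a maximum radius of (1 − eps)/√c. Hyperbolix's actual margin and formula are not shown on this page.

```python
import jax.numpy as jnp


def proj_ball(x, c, eps=1e-5):
    """Rescale x so that ||x|| <= (1 - eps) / sqrt(c); eps is a hypothetical margin."""
    max_norm = (1.0 - eps) / jnp.sqrt(c)
    norm = jnp.linalg.norm(x)
    # Points inside keep scale 1; points outside are pulled back radially.
    # The floor on norm avoids dividing by zero for x = 0.
    scale = jnp.minimum(1.0, max_norm / jnp.maximum(norm, 1e-15))
    return x * scale


inside = jnp.array([0.1, 0.2])
outside = jnp.array([3.0, 4.0])
print(proj_ball(inside, 1.0))                    # unchanged
print(jnp.linalg.norm(proj_ball(outside, 1.0)))  # just under 1.0
```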

gyration

gyration(
    x: Float[Array, dim],
    y: Float[Array, dim],
    z: Float[Array, dim],
    c: float,
) -> Float[Array, dim]

Compute gyration gyr[x,y]z to restore commutativity.

Source code in hyperbolix/manifolds/poincare.py
def gyration(
    self, x: Float[Array, "dim"], y: Float[Array, "dim"], z: Float[Array, "dim"], c: float
) -> Float[Array, "dim"]:
    """Compute gyration gyr[x,y]z to restore commutativity."""
    return _gyration(self._cast(x), self._cast(y), self._cast(z), c)
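Möbius addition is not commutative, and the gyration measures the discrepancy. The sketch below uses the standard identity gyr[x, y]z = ⊖(x ⊕ y) ⊕ (x ⊕ (y ⊕ z)), where ⊖u = −u on the Poincaré ball. It then checks the gyrocommutative law x ⊕ y = gyr[x, y](y ⊕ x). This is an independent sketch, not hyperbolix's implementation.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    """Möbius addition x (+) y on the Poincaré ball of curvature -c."""
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c**2 * x2 * y2)


def gyration(x, y, z, c):
    """gyr[x, y]z = -(x (+) y) (+) (x (+) (y (+) z))."""
    return mobius_add(-mobius_add(x, y, c), mobius_add(x, mobius_add(y, z, c), c), c)


x = jnp.array([0.1, 0.3])
y = jnp.array([-0.4, 0.2])
c = 1.0
print(jnp.allclose(mobius_add(x, y, c), mobius_add(y, x, c)))  # False: not commutative
lhs = mobius_add(x, y, c)
rhs = gyration(x, y, mobius_add(y, x, c), c)
print(jnp.allclose(lhs, rhs, atol=1e-5))  # True: x (+) y == gyr[x, y](y (+) x)
```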

addition

addition(
    x: Float[Array, dim], y: Float[Array, dim], c: float
) -> Float[Array, dim]

Möbius gyrovector addition x ⊕ y.

Source code in hyperbolix/manifolds/poincare.py
def addition(self, x: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Möbius gyrovector addition x ⊕ y."""
    return _addition(self._cast(x), self._cast(y), c)
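For reference, the sketch below writes out the textbook Möbius addition formula for the ball ‖x‖² < 1/c and checks a few of its basic properties: 0 is the identity, −x is the inverse of x, results stay inside the ball, and c → 0 recovers ordinary vector addition. It is an illustrative sketch, not hyperbolix's code.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    """x (+) y = ((1 + 2c<x,y> + c||y||^2) x + (1 - c||x||^2) y) / (1 + 2c<x,y> + c^2 ||x||^2 ||y||^2)."""
    xy = jnp.dot(x, y)
    x2 = jnp.dot(x, x)
    y2 = jnp.dot(y, y)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    den = 1 + 2 * c * xy + c**2 * x2 * y2
    return num / den


x = jnp.array([0.5, 0.0])
y = jnp.array([0.0, 0.5])
zero = jnp.zeros(2)

print(mobius_add(zero, y, 1.0))                  # equals y (0 is the identity)
print(mobius_add(-x, x, 1.0))                    # zero vector (-x is the inverse of x)
print(jnp.linalg.norm(mobius_add(x, y, 1.0)))    # about 0.686, inside the unit ball
print(mobius_add(x, y, 1e-6))                    # close to x + y as c -> 0
```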

scalar_mul

scalar_mul(
    r: float, x: Float[Array, dim], c: float
) -> Float[Array, dim]

Scalar multiplication r ⊗ x on Poincaré ball.

Source code in hyperbolix/manifolds/poincare.py
def scalar_mul(self, r: float, x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Scalar multiplication r ⊗ x on Poincaré ball."""
    x = self._cast(x)
    r_cast = jnp.asarray(r, dtype=x.dtype)
    return _scalar_mul(r_cast, x, c)  # type: ignore[arg-type]
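The standard Möbius scalar multiplication is r ⊗ x = tanh(r · artanh(√c‖x‖)) · x / (√c‖x‖). The sketch below implements this naive form, which is undefined at x = 0 and would need a guard in practice. It then checks that it agrees with repeated addition, 2 ⊗ x = x ⊕ x. Treat it as an illustration, not as hyperbolix's implementation.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c**2 * x2 * y2)


def mobius_scalar_mul(r, x, c):
    """r (x) x = tanh(r * artanh(sqrt(c)||x||)) * x / (sqrt(c)||x||); naive, requires x != 0."""
    sqrt_c = jnp.sqrt(c)
    norm = jnp.linalg.norm(x)
    return jnp.tanh(r * jnp.arctanh(sqrt_c * norm)) * x / (sqrt_c * norm)


x = jnp.array([0.2, -0.1])
c = 1.0
print(jnp.allclose(mobius_scalar_mul(1.0, x, c), x, atol=1e-5))                     # True: 1 (x) x == x
print(jnp.allclose(mobius_scalar_mul(2.0, x, c), mobius_add(x, x, c), atol=1e-5))  # True: 2 (x) x == x (+) x
```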

dist

dist(
    x: Float[Array, dim],
    y: Float[Array, dim],
    c: float,
    version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]

Compute geodesic distance between Poincaré ball points.

Source code in hyperbolix/manifolds/poincare.py
def dist(
    self,
    x: Float[Array, "dim"],
    y: Float[Array, "dim"],
    c: float,
    version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]:
    """Compute geodesic distance between Poincaré ball points."""
    return _dist(self._cast(x), self._cast(y), c, version_idx)

dist_0

dist_0(
    x: Float[Array, dim],
    c: float,
    version_idx: int = VERSION_MOBIUS_DIRECT,
) -> Float[Array, ""]

Compute geodesic distance from Poincaré ball origin.

Source code in hyperbolix/manifolds/poincare.py
def dist_0(self, x: Float[Array, "dim"], c: float, version_idx: int = VERSION_MOBIUS_DIRECT) -> Float[Array, ""]:
    """Compute geodesic distance from Poincaré ball origin."""
    return _dist_0(self._cast(x), c, version_idx)

expmap

expmap(
    v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]

Exponential map: map tangent vector v at point x to manifold.

Source code in hyperbolix/manifolds/poincare.py
def expmap(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Exponential map: map tangent vector v at point x to manifold."""
    return _expmap(self._cast(v), self._cast(x), c)

expmap_0

expmap_0(
    v: Float[Array, dim], c: float
) -> Float[Array, dim]

Exponential map from origin: map tangent vector v at origin to manifold.

Source code in hyperbolix/manifolds/poincare.py
def expmap_0(self, v: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Exponential map from origin: map tangent vector v at origin to manifold."""
    return _expmap_0(self._cast(v), c)

retraction

retraction(
    v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]

Retraction: first-order approximation of exponential map.

Source code in hyperbolix/manifolds/poincare.py
def retraction(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Retraction: first-order approximation of exponential map."""
    return _retraction(self._cast(v), self._cast(x), c)

logmap

logmap(
    y: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]

Logarithmic map: map point y to tangent space at point x.

Source code in hyperbolix/manifolds/poincare.py
def logmap(self, y: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Logarithmic map: map point y to tangent space at point x."""
    return _logmap(self._cast(y), self._cast(x), c)

logmap_0

logmap_0(
    y: Float[Array, dim], c: float
) -> Float[Array, dim]

Logarithmic map from origin: map point y to tangent space at origin.

Source code in hyperbolix/manifolds/poincare.py
def logmap_0(self, y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Logarithmic map from origin: map point y to tangent space at origin."""
    return _logmap_0(self._cast(y), c)
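At the origin the exponential and logarithmic maps have simple closed forms:

- exp_0(v) = tanh(√c‖v‖) v / (√c‖v‖)
- log_0(y) = artanh(√c‖y‖) y / (√c‖y‖)

The sketch below implements these textbook formulas (not hyperbolix's code) and checks that log_0 inverts exp_0. It also checks that the geodesic distance from the origin equals 2‖v‖, because the conformal factor at the origin is 2.

```python
import jax.numpy as jnp


def expmap_0(v, c):
    """exp_0(v) = tanh(sqrt(c)||v||) v / (sqrt(c)||v||), for v != 0."""
    sqrt_c, n = jnp.sqrt(c), jnp.linalg.norm(v)
    return jnp.tanh(sqrt_c * n) * v / (sqrt_c * n)


def logmap_0(y, c):
    """log_0(y) = artanh(sqrt(c)||y||) y / (sqrt(c)||y||), for y != 0."""
    sqrt_c, n = jnp.sqrt(c), jnp.linalg.norm(y)
    return jnp.arctanh(sqrt_c * n) * y / (sqrt_c * n)


def dist_0(y, c):
    """Geodesic distance from the origin: (2 / sqrt(c)) artanh(sqrt(c)||y||)."""
    sqrt_c = jnp.sqrt(c)
    return 2 / sqrt_c * jnp.arctanh(sqrt_c * jnp.linalg.norm(y))


v = jnp.array([0.3, -0.4])  # tangent vector at the origin, ||v|| = 0.5
c = 0.5
y = expmap_0(v, c)
print(jnp.allclose(logmap_0(y, c), v, atol=1e-5))  # True: log_0 inverts exp_0
print(dist_0(y, c))                                # about 1.0 == 2 * ||v||
```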

ptransp

ptransp(
    v: Float[Array, dim],
    x: Float[Array, dim],
    y: Float[Array, dim],
    c: float,
) -> Float[Array, dim]

Parallel transport tangent vector v from point x to point y.

Source code in hyperbolix/manifolds/poincare.py
def ptransp(self, v: Float[Array, "dim"], x: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Parallel transport tangent vector v from point x to point y."""
    return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)

ptransp_0

ptransp_0(
    v: Float[Array, dim], y: Float[Array, dim], c: float
) -> Float[Array, dim]

Parallel transport tangent vector v from origin to point y.

Source code in hyperbolix/manifolds/poincare.py
def ptransp_0(self, v: Float[Array, "dim"], y: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Parallel transport tangent vector v from origin to point y."""
    return _ptransp_0(self._cast(v), self._cast(y), c)

tangent_inner

tangent_inner(
    u: Float[Array, dim],
    v: Float[Array, dim],
    x: Float[Array, dim],
    c: float,
) -> Float[Array, ""]

Compute inner product of tangent vectors u and v at point x.

Source code in hyperbolix/manifolds/poincare.py
def tangent_inner(
    self, u: Float[Array, "dim"], v: Float[Array, "dim"], x: Float[Array, "dim"], c: float
) -> Float[Array, ""]:
    """Compute inner product of tangent vectors u and v at point x."""
    return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)

tangent_norm

tangent_norm(
    v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, ""]

Compute norm of tangent vector v at point x.

Source code in hyperbolix/manifolds/poincare.py
def tangent_norm(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, ""]:
    """Compute norm of tangent vector v at point x."""
    return _tangent_norm(self._cast(v), self._cast(x), c)

egrad2rgrad

egrad2rgrad(
    grad: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]

Convert Euclidean gradient to Riemannian gradient.

Source code in hyperbolix/manifolds/poincare.py
def egrad2rgrad(self, grad: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Convert Euclidean gradient to Riemannian gradient."""
    return _egrad2rgrad(self._cast(grad), self._cast(x), c)
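On the Poincaré ball the metric is the Euclidean one scaled by λ_x², so the Riemannian gradient is the Euclidean gradient divided by λ_x² = (2 / (1 − c‖x‖²))². The sketch below applies this standard conversion to a gradient from `jax.grad`. The loss and target point are hypothetical, and this is not hyperbolix's implementation. Note how the step shrinks sharply near the boundary.

```python
import jax
import jax.numpy as jnp


def egrad2rgrad(grad, x, c):
    """Divide by lambda_x^2, where lambda_x = 2 / (1 - c||x||^2)."""
    lam = 2.0 / (1.0 - c * jnp.dot(x, x))
    return grad / lam**2


def loss(x):
    # Hypothetical objective: squared Euclidean distance to a fixed target.
    target = jnp.array([0.5, 0.0])
    return jnp.sum((x - target) ** 2)


c = 1.0
for point in (jnp.zeros(2), jnp.array([0.0, 0.9])):
    egrad = jax.grad(loss)(point)
    rgrad = egrad2rgrad(egrad, point, c)
    print(point, egrad, rgrad)  # scaling is 1/4 at the origin, about 0.009 at ||x|| = 0.9
```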

tangent_proj

tangent_proj(
    v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Float[Array, dim]

Project vector v onto tangent space at point x.

Source code in hyperbolix/manifolds/poincare.py
def tangent_proj(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Float[Array, "dim"]:
    """Project vector v onto tangent space at point x."""
    return _tangent_proj(self._cast(v), self._cast(x), c)

is_in_manifold

is_in_manifold(
    x: Float[Array, dim], c: float, atol: float = 1e-05
) -> Array

Check if point x lies in Poincaré ball.

Source code in hyperbolix/manifolds/poincare.py
def is_in_manifold(self, x: Float[Array, "dim"], c: float, atol: float = 1e-5) -> Array:
    """Check if point x lies in Poincaré ball."""
    return _is_in_manifold(self._cast(x), c, atol)

is_in_tangent_space

is_in_tangent_space(
    v: Float[Array, dim], x: Float[Array, dim], c: float
) -> Array

Check if vector v lies in tangent space at point x.

Source code in hyperbolix/manifolds/poincare.py
def is_in_tangent_space(self, v: Float[Array, "dim"], x: Float[Array, "dim"], c: float) -> Array:
    """Check if vector v lies in tangent space at point x."""
    return _is_in_tangent_space(self._cast(v), self._cast(x), c)

conformal_factor

conformal_factor(
    x: Float[Array, "... dim"], c: float
) -> Float[Array, "... 1"]

Numerically stable conformal factor lambda(x) = 2 / (1 - c||x||^2).

Batch-compatible version that handles arbitrary leading dimensions.

Source code in hyperbolix/manifolds/poincare.py
def conformal_factor(self, x: Float[Array, "... dim"], c: float) -> Float[Array, "... 1"]:
    """Numerically stable conformal factor lambda(x) = 2 / (1 - c||x||^2).

    Batch-compatible version that handles arbitrary leading dimensions.
    """
    return _conformal_factor_batch(self._cast(x), c)
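The batch-compatible variant reduces over the last axis only and keeps a trailing size-1 axis, so the result broadcasts against `x`. A minimal sketch of that shape behaviour follows. The `1e-15` floor on the denominator is a hypothetical guard, not necessarily how hyperbolix achieves numerical stability.

```python
import jax.numpy as jnp


def conformal_factor(x, c):
    """lambda(x) = 2 / (1 - c||x||^2) over the last axis, keeping a trailing size-1 axis."""
    sq_norm = jnp.sum(x**2, axis=-1, keepdims=True)
    # Hypothetical guard: keep the denominator positive at or past the boundary.
    denom = jnp.maximum(1.0 - c * sq_norm, 1e-15)
    return 2.0 / denom


x = jnp.zeros((4, 3, 2))                             # leading dims (4, 3), dim = 2
print(conformal_factor(x, 1.0).shape)                # (4, 3, 1)
print(conformal_factor(jnp.array([0.6, 0.0]), 1.0))  # about [3.125]
```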

compute_mlr_pp

compute_mlr_pp(
    x: Float[Array, "batch in_dim"],
    z: Float[Array, "out_dim in_dim"],
    r: Float[Array, "out_dim 1"],
    c: float,
    clamping_factor: float,
    smoothing_factor: float,
    min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]

Compute HNN++ multinomial linear regression on the Poincaré ball.

Source code in hyperbolix/manifolds/poincare.py
def compute_mlr_pp(
    self,
    x: Float[Array, "batch in_dim"],
    z: Float[Array, "out_dim in_dim"],
    r: Float[Array, "out_dim 1"],
    c: float,
    clamping_factor: float,
    smoothing_factor: float,
    min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]:
    """Compute HNN++ multinomial linear regression on the Poincare ball."""
    return _compute_mlr_pp(self._cast(x), self._cast(z), self._cast(r), c, clamping_factor, smoothing_factor, min_enorm)

beta_concat

beta_concat(
    points: Float[Array, "M n_i"], c: float
) -> Float[Array, n]

Beta-concatenation of M equal-dimensional Poincaré ball points.

Source code in hyperbolix/manifolds/poincare.py
def beta_concat(self, points: Float[Array, "M n_i"], c: float) -> Float[Array, "n"]:
    """Beta-concatenation of M equal-dimensional Poincaré ball points."""
    return _beta_concat(self._cast(points), c)
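β-concatenation, as described for HNN++, works in three steps:

1. Map each point to the tangent space at the origin with log_0.
2. Rescale each tangent vector by B(n/2, 1/2) / B(n_i/2, 1/2), where B is the Euler beta function and n = M·n_i.
3. Concatenate the rescaled vectors and map back with exp_0.

The sketch below follows that recipe for M equal-dimensional points. It is an assumption-based illustration and has not been checked against hyperbolix's `_beta_concat`.

```python
import jax.numpy as jnp
from jax.scipy.special import gammaln


def beta_fn(a, b):
    """Euler beta function B(a, b) computed via log-gamma."""
    return jnp.exp(gammaln(a) + gammaln(b) - gammaln(a + b))


def expmap_0(v, c):
    sqrt_c, n = jnp.sqrt(c), jnp.linalg.norm(v, axis=-1, keepdims=True)
    return jnp.tanh(sqrt_c * n) * v / (sqrt_c * n)


def logmap_0(y, c):
    sqrt_c, n = jnp.sqrt(c), jnp.linalg.norm(y, axis=-1, keepdims=True)
    return jnp.arctanh(sqrt_c * n) * y / (sqrt_c * n)


def beta_concat(points, c):
    """Sketch of HNN++-style beta-concatenation for points of shape (M, n_i)."""
    m, n_i = points.shape
    n = m * n_i
    # Rescale each tangent vector so the concatenated vector keeps a comparable norm.
    scale = beta_fn(n / 2, 0.5) / beta_fn(n_i / 2, 0.5)
    v = (scale * logmap_0(points, c)).reshape(n)
    return expmap_0(v, c)


pts = jnp.array([[0.1, 0.2], [0.3, -0.1], [-0.2, 0.05]])  # M = 3 points in B^2
y = beta_concat(pts, 1.0)
print(y.shape)                   # (6,)
print(jnp.linalg.norm(y) < 1.0)  # True: the result lies in the unit ball
```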

Hyperboloid

The hyperboloid (Lorentz) model with Minkowski geometry.

Lorentz Operations

The Hyperboloid class includes specialized operations for convolutional layers:

  • lorentz_boost: Lorentz boost transformation
  • distance_rescale: Distance-based rescaling
  • hcat: Lorentz direct concatenation for convolutions

hyperbolix.manifolds.hyperboloid.Hyperboloid

Hyperboloid(dtype: dtype = jnp.float32)

Bases: ManifoldBase

Hyperboloid manifold with automatic dtype casting.

Provides all manifold operations with automatic casting of array inputs to the specified dtype. This removes the need for manual casting and gives a single place to control numerical precision.

Args:
    dtype: Target JAX dtype for computations (default: jnp.float32)

Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds.hyperboloid import Hyperboloid, VERSION_DEFAULT
>>>
>>> # Create manifold with float64 for better precision
>>> manifold = Hyperboloid(dtype=jnp.float64)
>>>
>>> # Arrays are automatically cast to float64
>>> x = jnp.array([1.0, 0.1, 0.2], dtype=jnp.float32)
>>> x = manifold.proj(x, c=1.0)
>>> x.dtype  # float64

Source code in hyperbolix/manifolds/_base.py
def __init__(self, dtype: jnp.dtype = jnp.float32) -> None:
    self.dtype = dtype

create_origin

create_origin(
    c: float, dim: int
) -> Float[Array, dim_plus_1]

Create hyperboloid origin [1/√c, 0, ..., 0].

Source code in hyperbolix/manifolds/hyperboloid.py
def create_origin(self, c: float, dim: int) -> Float[Array, "dim_plus_1"]:
    """Create hyperboloid origin [1/√c, 0, ..., 0]."""
    return _create_origin(c, dim, self.dtype)

minkowski_inner

minkowski_inner(
    x: Float[Array, dim_plus_1], y: Float[Array, dim_plus_1]
) -> Float[Array, ""]

Compute Minkowski inner product ⟨x, y⟩_L = -x₀y₀ + ⟨x_rest, y_rest⟩.

Source code in hyperbolix/manifolds/hyperboloid.py
def minkowski_inner(self, x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"]) -> Float[Array, ""]:
    """Compute Minkowski inner product ⟨x, y⟩_L = -x₀y₀ + ⟨x_rest, y_rest⟩."""
    return _minkowski_inner(self._cast(x), self._cast(y))
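The origin and the Minkowski inner product fit together: the origin [1/√c, 0, ..., 0] satisfies ⟨o, o⟩_L = −1/c, which is the defining equation of the hyperboloid. A short stand-alone check, re-implementing both formulas as documented above rather than calling hyperbolix:

```python
import jax.numpy as jnp


def create_origin(c, dim):
    """Hyperboloid origin [1/sqrt(c), 0, ..., 0] with dim + 1 ambient coordinates."""
    return jnp.zeros(dim + 1).at[0].set(1.0 / jnp.sqrt(c))


def minkowski_inner(x, y):
    """<x, y>_L = -x_0 y_0 + <x_rest, y_rest>."""
    return -x[0] * y[0] + jnp.dot(x[1:], y[1:])


c = 2.0
o = create_origin(c, dim=3)
print(o.shape)                        # (4,)
print(minkowski_inner(o, o), -1 / c)  # both about -0.5
```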

proj

proj(
    x: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]

Project point onto hyperboloid.

Source code in hyperbolix/manifolds/hyperboloid.py
def proj(self, x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Project point onto hyperboloid."""
    return _proj(self._cast(x), c)
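One common way to project onto the hyperboloid keeps the spatial coordinates and recomputes the time coordinate as x₀ = √(1/c + ‖x_rest‖²), which forces ⟨x, x⟩_L = −1/c. The sketch below assumes this approach; hyperbolix's `_proj` may differ in detail.

```python
import jax.numpy as jnp


def minkowski_inner(x, y):
    return -x[0] * y[0] + jnp.dot(x[1:], y[1:])


def proj_hyperboloid(x, c):
    """Keep the spatial part and recompute x_0 = sqrt(1/c + ||x_rest||^2) (assumed method)."""
    x_rest = x[1:]
    x0 = jnp.sqrt(1.0 / c + jnp.dot(x_rest, x_rest))
    return jnp.concatenate([x0[None], x_rest])


x = jnp.array([1.0, 0.1, 0.2])  # close to, but not on, the hyperboloid
c = 1.0
p = proj_hyperboloid(x, c)
print(minkowski_inner(x, x), minkowski_inner(p, p))  # about -0.95, then about -1.0
```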

proj_batch

proj_batch(
    x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]

Project batched points onto hyperboloid (handles arbitrary leading dimensions).

Source code in hyperbolix/manifolds/hyperboloid.py
def proj_batch(self, x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
    """Project batched points onto hyperboloid (handles arbitrary leading dimensions)."""
    return _proj_batch(self._cast(x), c)

addition

addition(
    x: Float[Array, dim_plus_1],
    y: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Gyrovector addition on hyperboloid.

Source code in hyperbolix/manifolds/hyperboloid.py
def addition(self, x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Gyrovector addition on hyperboloid."""
    return _addition(self._cast(x), self._cast(y), c)

scalar_mul

scalar_mul(
    r: float, x: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]

Scalar multiplication on hyperboloid.

Source code in hyperbolix/manifolds/hyperboloid.py
def scalar_mul(self, r: float, x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Scalar multiplication on hyperboloid."""
    x = self._cast(x)
    r_cast = jnp.asarray(r, dtype=x.dtype)
    return _scalar_mul(r_cast, x, c)  # type: ignore[arg-type]

dist

dist(
    x: Float[Array, dim_plus_1],
    y: Float[Array, dim_plus_1],
    c: float,
    version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]

Compute geodesic distance between hyperboloid points.

Source code in hyperbolix/manifolds/hyperboloid.py
def dist(
    self,
    x: Float[Array, "dim_plus_1"],
    y: Float[Array, "dim_plus_1"],
    c: float,
    version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]:
    """Compute geodesic distance between hyperboloid points."""
    return _dist(self._cast(x), self._cast(y), c, version_idx)
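The textbook hyperboloid distance is d(x, y) = arcosh(−c⟨x, y⟩_L) / √c. The sketch below implements it with a clamp that keeps rounding error from pushing the argument below 1 (the domain of arcosh). It then checks the formula on a point placed a known geodesic distance s from the origin. This is not necessarily what `VERSION_DEFAULT` computes.

```python
import jax.numpy as jnp


def minkowski_inner(x, y):
    return -x[0] * y[0] + jnp.dot(x[1:], y[1:])


def dist_hyperboloid(x, y, c):
    """d(x, y) = arcosh(-c <x, y>_L) / sqrt(c), with the argument clamped to >= 1."""
    arg = jnp.maximum(-c * minkowski_inner(x, y), 1.0)
    return jnp.arccosh(arg) / jnp.sqrt(c)


c = 4.0
sqrt_c = jnp.sqrt(c)
s = 0.75  # intended geodesic distance from the origin
o = jnp.array([1.0 / sqrt_c, 0.0, 0.0])
p = jnp.array([jnp.cosh(sqrt_c * s) / sqrt_c, jnp.sinh(sqrt_c * s) / sqrt_c, 0.0])
print(dist_hyperboloid(o, p, c))  # about 0.75
```

arcosh is ill-conditioned near 1, so distances between nearby points are the least accurate case for this formula in float32.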

dist_0

dist_0(
    x: Float[Array, dim_plus_1],
    c: float,
    version_idx: int = VERSION_DEFAULT,
) -> Float[Array, ""]

Compute geodesic distance from hyperboloid origin.

Source code in hyperbolix/manifolds/hyperboloid.py
def dist_0(self, x: Float[Array, "dim_plus_1"], c: float, version_idx: int = VERSION_DEFAULT) -> Float[Array, ""]:
    """Compute geodesic distance from hyperboloid origin."""
    return _dist_0(self._cast(x), c, version_idx)

expmap

expmap(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Exponential map: map tangent vector v at point x to manifold.

Source code in hyperbolix/manifolds/hyperboloid.py
def expmap(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Exponential map: map tangent vector v at point x to manifold."""
    return _expmap(self._cast(v), self._cast(x), c)

expmap_0

expmap_0(
    v: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]

Exponential map from origin.

Source code in hyperbolix/manifolds/hyperboloid.py
def expmap_0(self, v: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Exponential map from origin."""
    return _expmap_0(self._cast(v), c)

retraction

retraction(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Retraction: first-order approximation of exponential map.

Source code in hyperbolix/manifolds/hyperboloid.py
def retraction(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Retraction: first-order approximation of exponential map."""
    return _retraction(self._cast(v), self._cast(x), c)

logmap

logmap(
    y: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Logarithmic map: map point y to tangent space at point x.

Source code in hyperbolix/manifolds/hyperboloid.py
def logmap(self, y: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Logarithmic map: map point y to tangent space at point x."""
    return _logmap(self._cast(y), self._cast(x), c)

logmap_0

logmap_0(
    y: Float[Array, dim_plus_1], c: float
) -> Float[Array, dim_plus_1]

Logarithmic map from origin.

Source code in hyperbolix/manifolds/hyperboloid.py
def logmap_0(self, y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Logarithmic map from origin."""
    return _logmap_0(self._cast(y), c)

ptransp

ptransp(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    y: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Parallel transport tangent vector v from point x to point y.

Source code in hyperbolix/manifolds/hyperboloid.py
def ptransp(
    self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
    """Parallel transport tangent vector v from point x to point y."""
    return _ptransp(self._cast(v), self._cast(x), self._cast(y), c)

ptransp_0

ptransp_0(
    v: Float[Array, dim_plus_1],
    y: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Parallel transport tangent vector v from origin to point y.

Source code in hyperbolix/manifolds/hyperboloid.py
def ptransp_0(self, v: Float[Array, "dim_plus_1"], y: Float[Array, "dim_plus_1"], c: float) -> Float[Array, "dim_plus_1"]:
    """Parallel transport tangent vector v from origin to point y."""
    return _ptransp_0(self._cast(v), self._cast(y), c)

tangent_inner

tangent_inner(
    u: Float[Array, dim_plus_1],
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, ""]

Compute inner product of tangent vectors u and v at point x.

Source code in hyperbolix/manifolds/hyperboloid.py
def tangent_inner(
    self, u: Float[Array, "dim_plus_1"], v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, ""]:
    """Compute inner product of tangent vectors u and v at point x."""
    return _tangent_inner(self._cast(u), self._cast(v), self._cast(x), c)

tangent_norm

tangent_norm(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, ""]

Compute norm of tangent vector v at point x.

Source code in hyperbolix/manifolds/hyperboloid.py
def tangent_norm(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Float[Array, ""]:
    """Compute norm of tangent vector v at point x."""
    return _tangent_norm(self._cast(v), self._cast(x), c)

egrad2rgrad

egrad2rgrad(
    grad: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Convert Euclidean gradient to Riemannian gradient.

Source code in hyperbolix/manifolds/hyperboloid.py
def egrad2rgrad(
    self, grad: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
    """Convert Euclidean gradient to Riemannian gradient."""
    return _egrad2rgrad(self._cast(grad), self._cast(x), c)

tangent_proj

tangent_proj(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Float[Array, dim_plus_1]

Project vector v onto tangent space at point x.

Source code in hyperbolix/manifolds/hyperboloid.py
def tangent_proj(
    self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float
) -> Float[Array, "dim_plus_1"]:
    """Project vector v onto tangent space at point x."""
    return _tangent_proj(self._cast(v), self._cast(x), c)
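A sketch of the standard hyperboloid formulas for these two operations (as used, for example, in Riemannian SGD on the Lorentz model; not read from hyperbolix's source):

- The tangent space at x is {u : ⟨x, u⟩_L = 0}, and v ↦ v + c⟨x, v⟩_L x projects onto it.
- The Riemannian gradient flips the sign of the time component of the Euclidean gradient, then projects.

```python
import jax.numpy as jnp


def minkowski_inner(x, y):
    return -x[0] * y[0] + jnp.dot(x[1:], y[1:])


def tangent_proj(v, x, c):
    """Project an ambient vector onto T_x = {u : <x, u>_L = 0}: v + c <x, v>_L x."""
    return v + c * minkowski_inner(x, v) * x


def egrad2rgrad(grad, x, c):
    """Apply diag(-1, 1, ..., 1) (flip the time component), then project onto T_x."""
    h = grad.at[0].multiply(-1.0)
    return tangent_proj(h, x, c)


c = 1.0
x = jnp.array([jnp.sqrt(1.0 + 0.3**2 + 0.4**2), 0.3, 0.4])  # on the hyperboloid for c = 1
g = jnp.array([0.5, -1.0, 2.0])                              # some Euclidean gradient
r = egrad2rgrad(g, x, c)
print(minkowski_inner(x, r))  # about 0: the Riemannian gradient is tangent at x
```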

is_in_manifold

is_in_manifold(
    x: Float[Array, dim_plus_1],
    c: float,
    atol: float = 0.0001,
) -> Array

Check if point x lies on hyperboloid.

Source code in hyperbolix/manifolds/hyperboloid.py
def is_in_manifold(self, x: Float[Array, "dim_plus_1"], c: float, atol: float = 1e-4) -> Array:
    """Check if point x lies on hyperboloid."""
    return _is_in_manifold(self._cast(x), c, atol)

is_in_tangent_space

is_in_tangent_space(
    v: Float[Array, dim_plus_1],
    x: Float[Array, dim_plus_1],
    c: float,
) -> Array

Check if vector v lies in tangent space at point x.

Source code in hyperbolix/manifolds/hyperboloid.py
def is_in_tangent_space(self, v: Float[Array, "dim_plus_1"], x: Float[Array, "dim_plus_1"], c: float) -> Array:
    """Check if vector v lies in tangent space at point x."""
    return _is_in_tangent_space(self._cast(v), self._cast(x), c)

hcat

hcat(
    points: Float[Array, "N n"], c: float = 1.0
) -> Float[Array, dN_plus_1]

Hyperbolic concatenation of N points into one point.

Source code in hyperbolix/manifolds/hyperboloid.py
def hcat(
    self,
    points: Float[Array, "N n"],
    c: float = 1.0,
) -> Float[Array, "dN_plus_1"]:
    """Hyperbolic concatenation of N points into one point."""
    return _hcat(self._cast(points), c)
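Lorentz direct concatenation, as usually described for hyperbolic convolutions, stacks the spatial parts of the N input points. It then picks the time coordinate that puts the result back on the hyperboloid: t = √(Σᵢ tᵢ² − (N − 1)/c). For inputs with d spatial coordinates each, this gives the documented output length dN + 1. The sketch below assumes this formula; `lift` is a hypothetical helper for building test inputs.

```python
import jax.numpy as jnp


def minkowski_inner(x, y):
    return -x[0] * y[0] + jnp.dot(x[1:], y[1:])


def hcat(points, c=1.0):
    """Concatenate N hyperboloid points of shape (N, d + 1) into one of shape (d * N + 1,)."""
    n = points.shape[0]
    t = jnp.sqrt(jnp.sum(points[:, 0] ** 2) - (n - 1) / c)
    return jnp.concatenate([t[None], points[:, 1:].reshape(-1)])


def lift(spatial, c):
    """Hypothetical helper: place a spatial vector on the hyperboloid by solving for x_0."""
    return jnp.concatenate([jnp.sqrt(1.0 / c + jnp.sum(spatial**2))[None], spatial])


c = 1.0
pts = jnp.stack([lift(jnp.array([0.1, 0.2]), c), lift(jnp.array([-0.3, 0.5]), c)])  # N = 2, d = 2
y = hcat(pts, c)
print(y.shape)               # (5,) = d * N + 1
print(minkowski_inner(y, y)) # about -1.0 = -1/c
```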

embed_spatial_0

embed_spatial_0(
    v_spatial: Float[Array, "... n"],
) -> Float[Array, "... n_plus_1"]

Embed spatial vector as tangent vector at origin.

Source code in hyperbolix/manifolds/hyperboloid.py
def embed_spatial_0(self, v_spatial: Float[Array, "... n"]) -> Float[Array, "... n_plus_1"]:
    """Embed spatial vector as tangent vector at origin."""
    return _embed_spatial_0(self._cast(v_spatial))
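The hyperboloid origin is o = [1/√c, 0, ..., 0], so its tangent space consists of vectors whose time component is zero. Embedding a spatial vector therefore amounts to prepending a zero. The sketch below shows that step, followed by the exponential map at the origin, which is a common way to push such vectors onto the manifold. Both functions are hypothetical stand-ins written from the standard formulas. They are not hyperbolix's `_embed_spatial_0` or `expmap`.

```python
import jax.numpy as jnp


def embed_spatial_0_sketch(v_spatial):
    # Tangent vectors at the origin o = [1/sqrt(c), 0, ..., 0] have a zero time
    # component, so embedding prepends a 0 along the last axis (any batch shape)
    zeros = jnp.zeros(v_spatial.shape[:-1] + (1,), dtype=v_spatial.dtype)
    return jnp.concatenate([zeros, v_spatial], axis=-1)


def expmap_0_sketch(v, c=1.0):
    # Hypothetical helper: exponential map at o for one tangent vector v = [0, u]
    # exp_o(v) = [cosh(sqrt(c)||u||) / sqrt(c), sinh(sqrt(c)||u||) u / (sqrt(c)||u||)]
    sqrt_c = jnp.sqrt(c)
    u = v[1:]
    norm = jnp.maximum(jnp.linalg.norm(u), 1e-15)  # avoid 0/0 when u = 0
    t = jnp.cosh(sqrt_c * norm) / sqrt_c
    s = jnp.sinh(sqrt_c * norm) * u / (sqrt_c * norm)
    return jnp.concatenate([t[None], s])


v = embed_spatial_0_sketch(jnp.array([[0.2, -0.1], [1.0, 0.5]]))  # batch of two
print(v.shape)  # (2, 3)
x = expmap_0_sketch(v[1], c=1.0)
print(float(-x[0] ** 2 + jnp.dot(x[1:], x[1:])))  # close to -1/c = -1
```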

compute_mlr

compute_mlr(
    x: Float[Array, "batch in_dim"],
    z: Float[Array, "out_dim in_dim_minus_1"],
    r: Float[Array, "out_dim 1"],
    c: float,
    clamping_factor: float,
    smoothing_factor: float,
    min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]

Compute hyperbolic multinomial logistic regression (MLR) on the hyperboloid.

Source code in hyperbolix/manifolds/hyperboloid.py
def compute_mlr(
    self,
    x: Float[Array, "batch in_dim"],
    z: Float[Array, "out_dim in_dim_minus_1"],
    r: Float[Array, "out_dim 1"],
    c: float,
    clamping_factor: float,
    smoothing_factor: float,
    min_enorm: float = 1e-15,
) -> Float[Array, "batch out_dim"]:
    """Compute hyperbolic multinomial logistic regression (MLR) on the hyperboloid."""
    return _compute_mlr(self._cast(x), self._cast(z), self._cast(r), c, clamping_factor, smoothing_factor, min_enorm)

Isometry Mappings

Distance-preserving maps between the Poincaré ball and hyperboloid models.

hyperbolix.manifolds.isometry_mappings

Isometry mappings between hyperbolic manifold models.

This module implements distance-preserving transformations (isometries) between different models of hyperbolic geometry. All functions operate on single points; use jax.vmap to apply them to batches.

Supported models:

  • Hyperboloid model (Lorentz model): points in R^(d+1) satisfying ⟨x,x⟩_L = -1/c
  • Poincaré ball model: points in R^d with ||y||² < 1/c

The mappings are implemented via stereographic projection from the hyperboloid to the Poincaré ball, projecting through the point [-1/√c, 0, ..., 0] (the point [-1, 0, ..., 0] when c = 1).
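A few lines of algebra show why the image lands inside the ball of radius 1/√c. The derivation below assumes an upper-sheet point x = (t, x_s) with t > 0, and the curvature-c projection through [-1/√c, 0, ..., 0] (the point [-1, 0, ..., 0] when c = 1). It follows from the Lorentz constraint alone and is not taken from hyperbolix's source.

```latex
\begin{aligned}
&x = (t, \mathbf{x}_s), \qquad -t^2 + \|\mathbf{x}_s\|^2 = -\tfrac{1}{c}, \qquad \sqrt{c}\,t \ge 1, \\
&\mathbf{y} = \frac{\mathbf{x}_s}{1 + \sqrt{c}\,t}
\;\Longrightarrow\;
\|\mathbf{y}\|^2
= \frac{t^2 - 1/c}{(1 + \sqrt{c}\,t)^2}
= \frac{(\sqrt{c}\,t - 1)(\sqrt{c}\,t + 1)}{c\,(\sqrt{c}\,t + 1)^2}
= \frac{1}{c}\cdot\frac{\sqrt{c}\,t - 1}{\sqrt{c}\,t + 1}
< \frac{1}{c}.
\end{aligned}
```

Setting c = 1 recovers y_i = x_i / (1 + t).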

JIT Compilation & Batching

All functions work with single points and return single points. Use jax.vmap for batch operations:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
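>>> # Single point conversion: build a valid hyperboloid point (c = 1) from its spatial part
>>> x_spatial = jnp.array([0.1, 0.2])
>>> x_hyp = jnp.concatenate([jnp.sqrt(1.0 + x_spatial @ x_spatial)[None], x_spatial])
>>> y_poinc = isometry_mappings.hyperboloid_to_poincare(x_hyp, c=1.0)
>>>
>>> # Batch conversion with vmap (time coordinates chosen so each row is on the hyperboloid)
>>> s_batch = jnp.array([[0.1, 0.2], [0.15, 0.25]])
>>> t_batch = jnp.sqrt(1.0 + jnp.sum(s_batch**2, axis=-1, keepdims=True))
>>> x_batch = jnp.concatenate([t_batch, s_batch], axis=-1)  # shape (2, 3)
>>> convert_batch = jax.vmap(isometry_mappings.hyperboloid_to_poincare, in_axes=(0, None))
>>> y_batch = convert_batch(x_batch, 1.0)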

References: Wikipedia: Hyperboloid model https://en.wikipedia.org/wiki/Hyperboloid_model#Relation_to_other_models

hyperboloid_to_poincare

hyperboloid_to_poincare(
    x: Float[Array, dim_plus_1], c: Float[Array, ""] | float
) -> Float[Array, dim]

Convert a hyperboloid point to the Poincaré ball via stereographic projection.

Projects the hyperboloid point onto the hyperplane t = 0 along the line through [-1/√c, 0, ..., 0] (the point [-1, 0, ..., 0] when c = 1). This implements the canonical isometry between the two models.

Formula: y_i = x_i / (1 + √c·t), where x = [t, x_1, ..., x_n] lies on the hyperboloid

Args:
  x: Point on the hyperboloid, shape (dim+1,). Should satisfy ⟨x,x⟩_L = -1/c.
  c: Curvature (positive).

Returns: Point in the Poincaré ball, shape (dim,). Satisfies ||y||² < 1/c.

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert hyperboloid origin to Poincaré origin
>>> x_origin = jnp.array([1.0, 0.0, 0.0])  # c=1.0 origin
>>> y = isometry_mappings.hyperboloid_to_poincare(x_origin, c=1.0)
>>> bool(jnp.allclose(y, jnp.zeros(2)))
True

References: Wikipedia: Hyperboloid model - Relation to other models

Source code in hyperbolix/manifolds/isometry_mappings.py
def hyperboloid_to_poincare(
    x: Float[Array, "dim_plus_1"],
    c: Float[Array, ""] | float,
) -> Float[Array, "dim"]:
    """Convert hyperboloid point to Poincaré ball via stereographic projection.

    Projects the hyperboloid point onto the hyperplane t = 0 by intersecting
    with a line through [-1/√c, 0, ..., 0]. This implements the canonical
    isometry between the two models.

    Formula:
        y_i = x_i / (1 + √c·t)
        where x = [t, x_1, ..., x_n] on hyperboloid

    Args:
        x: Point on hyperboloid, shape (dim+1,). Should satisfy ⟨x,x⟩_L = -1/c.
        c: Curvature (positive)

    Returns:
        Point in Poincaré ball, shape (dim,). Satisfies ||y||² < 1/c.

    Examples:
        >>> import jax.numpy as jnp
        >>> from hyperbolix.manifolds import isometry_mappings
        >>>
        >>> # Convert hyperboloid origin to Poincaré origin
        >>> x_origin = jnp.array([1.0, 0.0, 0.0])  # c=1.0 origin
        >>> y = isometry_mappings.hyperboloid_to_poincare(x_origin, c=1.0)
        >>> bool(jnp.allclose(y, jnp.zeros(2)))
        True

    References:
        Wikipedia: Hyperboloid model - Relation to other models
    """
    t = x[0]  # Temporal component
    x_spatial = x[1:]  # Spatial components (x_1, ..., x_n)

    # Stereographic projection through [-1/√c, 0, ..., 0]: y_i = x_i / (1 + √c·t)
    denominator = jnp.maximum(1.0 + jnp.sqrt(c) * t, MIN_DENOM)
    return x_spatial / denominator
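To check the ball constraint numerically for a curvature other than 1, the sketch below lifts a few spatial vectors onto the c = 2 hyperboloid and projects them. It includes one point far from the origin. The projection is written from the standard curvature-c formula rather than imported from hyperbolix, and `lift_to_hyperboloid` and `hyperboloid_to_poincare_sketch` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def lift_to_hyperboloid(s, c):
    # Hypothetical helper: upper-sheet point [t, s] with t = sqrt(1/c + ||s||^2)
    return jnp.concatenate([jnp.sqrt(1.0 / c + jnp.dot(s, s))[None], s])


def hyperboloid_to_poincare_sketch(x, c):
    # Stereographic projection through [-1/sqrt(c), 0, ..., 0]
    return x[1:] / (1.0 + jnp.sqrt(c) * x[0])


c = 2.0
spatial = jnp.array([[0.0, 0.0], [1.0, -0.5], [10.0, 3.0]])  # last point is far from the origin
x = jax.vmap(lift_to_hyperboloid, in_axes=(0, None))(spatial, c)
y = jax.vmap(hyperboloid_to_poincare_sketch, in_axes=(0, None))(x, c)
sq_norms = jnp.sum(y**2, axis=-1)
print(bool(jnp.all(sq_norms < 1.0 / c)))  # True: every image is inside the ball

# Without the sqrt(c) factor, y = x_s / (1 + t) is only valid for c = 1
y_unscaled = x[:, 1:] / (1.0 + x[:, :1])
print(bool(jnp.sum(y_unscaled[-1] ** 2) > 1.0 / c))  # True: the far point leaves the ball
```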

poincare_to_hyperboloid

poincare_to_hyperboloid(
    y: Float[Array, dim], c: Float[Array, ""] | float
) -> Float[Array, dim_plus_1]

Convert a Poincaré ball point to the hyperboloid via inverse stereographic projection.

Inverts the stereographic projection to map points from the Poincaré ball back to the hyperboloid. This implements the canonical isometry between the two models.

Formula: (t, x_i) = ((1 + c·Σy_j²) / √c, 2y_i) / (1 - c·Σy_j²), where y = [y_1, ..., y_n] lies in the Poincaré ball

Args:
  y: Point in the Poincaré ball, shape (dim,). Should satisfy ||y||² < 1/c.
  c: Curvature (positive).

Returns: Point on the hyperboloid, shape (dim+1,). Satisfies ⟨x,x⟩_L = -1/c.

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.manifolds import isometry_mappings
>>>
>>> # Convert Poincaré origin to hyperboloid origin
>>> y_origin = jnp.array([0.0, 0.0])
>>> x = isometry_mappings.poincare_to_hyperboloid(y_origin, c=1.0)
>>> bool(jnp.allclose(x, jnp.array([1.0, 0.0, 0.0])))
True

References: Wikipedia: Hyperboloid model - Relation to other models

Source code in hyperbolix/manifolds/isometry_mappings.py
def poincare_to_hyperboloid(
    y: Float[Array, "dim"],
    c: Float[Array, ""] | float,
) -> Float[Array, "dim_plus_1"]:
    """Convert Poincaré ball point to hyperboloid via inverse stereographic projection.

    Inverts the stereographic projection to map points from the Poincaré ball
    back to the hyperboloid. This implements the canonical isometry between
    the two models.

    Formula:
        (t, x_i) = ((1 + c·Σy_j²) / √c, 2y_i) / (1 - c·Σy_j²)
        where y = [y_1, ..., y_n] in Poincaré ball

    Args:
        y: Point in Poincaré ball, shape (dim,). Should satisfy ||y||² < 1/c.
        c: Curvature (positive)

    Returns:
        Point on hyperboloid, shape (dim+1,). Satisfies ⟨x,x⟩_L = -1/c.

    Examples:
        >>> import jax.numpy as jnp
        >>> from hyperbolix.manifolds import isometry_mappings
        >>>
        >>> # Convert Poincaré origin to hyperboloid origin
        >>> y_origin = jnp.array([0.0, 0.0])
        >>> x = isometry_mappings.poincare_to_hyperboloid(y_origin, c=1.0)
        >>> bool(jnp.allclose(x, jnp.array([1.0, 0.0, 0.0])))
        True

    References:
        Wikipedia: Hyperboloid model - Relation to other models
    """
    y_sqnorm = jnp.dot(y, y)
    sqrt_c = jnp.sqrt(c)

    # Inverse stereographic projection; the ball has radius 1/√c, so c·||y||² < 1
    denominator = jnp.maximum(1.0 - c * y_sqnorm, MIN_DENOM)

    t = (1.0 + c * y_sqnorm) / (sqrt_c * denominator)
    x_spatial = 2.0 * y / denominator

    # Concatenate temporal and spatial components: [t, x_1, ..., x_n]
    return jnp.concatenate([jnp.array([t]), x_spatial])
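Both properties claimed for this pair of maps can be checked numerically: the maps invert each other, and they preserve distance. The sketch below does this at c = 2. It uses textbook closed forms for the Poincaré and hyperboloid distances, and it re-implements the curvature-c maps instead of importing hyperbolix. All function names are hypothetical, and the sketch omits the MIN_DENOM clamping used by the library.

```python
import jax.numpy as jnp


def poincare_to_hyperboloid_sketch(y, c):
    # Inverse stereographic projection for the ball ||y||^2 < 1/c
    sq = c * jnp.dot(y, y)
    t = (1.0 + sq) / (jnp.sqrt(c) * (1.0 - sq))
    return jnp.concatenate([t[None], 2.0 * y / (1.0 - sq)])


def hyperboloid_to_poincare_sketch(x, c):
    # Stereographic projection through [-1/sqrt(c), 0, ..., 0]
    return x[1:] / (1.0 + jnp.sqrt(c) * x[0])


def lorentz_dist(x, z, c):
    # d(x, z) = arccosh(-c <x, z>_L) / sqrt(c)
    inner = -x[0] * z[0] + jnp.dot(x[1:], z[1:])
    return jnp.arccosh(jnp.maximum(-c * inner, 1.0)) / jnp.sqrt(c)


def poincare_dist(u, v, c):
    # d(u, v) = arccosh(1 + 2c||u - v||^2 / ((1 - c||u||^2)(1 - c||v||^2))) / sqrt(c)
    num = 2.0 * c * jnp.dot(u - v, u - v)
    den = (1.0 - c * jnp.dot(u, u)) * (1.0 - c * jnp.dot(v, v))
    return jnp.arccosh(1.0 + num / den) / jnp.sqrt(c)


c = 2.0
u = jnp.array([0.3, -0.2])
v = jnp.array([-0.1, 0.4])
xu = poincare_to_hyperboloid_sketch(u, c)
xv = poincare_to_hyperboloid_sketch(v, c)

print(bool(jnp.allclose(hyperboloid_to_poincare_sketch(xu, c), u, atol=1e-5)))  # True: round trip
print(float(poincare_dist(u, v, c)), float(lorentz_dist(xu, xv, c)))  # agree up to float32 rounding
```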

Usage Examples

Basic Distance Computation

import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

poincare = Poincare()

x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, -0.1])
c = 1.0

# Compute distance (default: VERSION_MOBIUS_DIRECT)
distance = poincare.dist(x, y, c)
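The comment above names VERSION_MOBIUS_DIRECT as the default distance. We assume this refers to the Möbius-addition form d_c(x, y) = (2/√c)·artanh(√c·||(-x) ⊕_c y||), but the source shown here does not confirm it. The self-contained sketch below evaluates that textbook formula on the same two points. `mobius_add` and `mobius_dist` are hypothetical names.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    # Möbius addition on the Poincaré ball of curvature -c
    xy = jnp.dot(x, y)
    x2 = jnp.dot(x, x)
    y2 = jnp.dot(y, y)
    num = (1.0 + 2.0 * c * xy + c * y2) * x + (1.0 - c * x2) * y
    den = 1.0 + 2.0 * c * xy + c**2 * x2 * y2
    return num / den


def mobius_dist(x, y, c):
    # d_c(x, y) = (2 / sqrt(c)) * artanh(sqrt(c) * ||(-x) ⊕_c y||)
    sqrt_c = jnp.sqrt(c)
    return 2.0 / sqrt_c * jnp.arctanh(sqrt_c * jnp.linalg.norm(mobius_add(-x, y, c)))


x = jnp.array([0.1, 0.2])
y = jnp.array([0.3, -0.1])
print(float(mobius_dist(x, y, 1.0)))  # about 0.761
```

The same value follows from the closed form arccosh(1 + 2||x - y||² / ((1 - ||x||²)(1 - ||y||²))) at c = 1. Comparing the two is a quick sanity check when switching distance versions.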

Float64 Precision

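import jax
import jax.numpy as jnp
from hyperbolix.manifolds import Poincare

# JAX creates float64 arrays only when 64-bit mode is enabled (set this at startup)
jax.config.update("jax_enable_x64", True)

# High-precision manifold
poincare_f64 = Poincare(dtype=jnp.float64)

x = jnp.array([0.1, 0.2], dtype=jnp.float32)  # float32 inputs
y = jnp.array([0.3, -0.1], dtype=jnp.float32)
distance = poincare_f64.dist(x, y, c=1.0)  # automatically cast to float64
print(distance.dtype)  # float64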

Batched Operations with vmap

import jax
from hyperbolix.manifolds import Hyperboloid

hyperboloid = Hyperboloid()
c = 1.0

# Batch of ambient points (d+1 dimensions)
x_batch = jax.random.normal(jax.random.PRNGKey(0), (100, 4))
y_batch = jax.random.normal(jax.random.PRNGKey(1), (100, 4))

# Project to hyperboloid
x_proj = jax.vmap(hyperboloid.proj, in_axes=(0, None))(x_batch, c)
y_proj = jax.vmap(hyperboloid.proj, in_axes=(0, None))(y_batch, c)

# Compute distances
distances = jax.vmap(hyperboloid.dist, in_axes=(0, 0, None))(x_proj, y_proj, c)
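The same single-point convention extends to all-pairs distances by nesting two vmap calls. To stay self-contained, the sketch below uses a textbook hyperboloid distance and a hypothetical `lift` helper instead of `hyperboloid.proj` and `hyperboloid.dist`. With hyperbolix you would pass `hyperboloid.dist` in place of `lorentz_dist`, assuming its first three positional arguments are (x, y, c) as in the example above.

```python
import jax
import jax.numpy as jnp


def lorentz_dist(x, y, c):
    # Geodesic distance on the hyperboloid <x, x>_L = -1/c
    inner = -x[0] * y[0] + jnp.dot(x[1:], y[1:])
    return jnp.arccosh(jnp.maximum(-c * inner, 1.0)) / jnp.sqrt(c)


def lift(s, c):
    # Hypothetical helper: upper-sheet point whose spatial part is s
    return jnp.concatenate([jnp.sqrt(1.0 / c + jnp.dot(s, s))[None], s])


c = 1.0
xs = jax.vmap(lift, in_axes=(0, None))(jax.random.normal(jax.random.PRNGKey(0), (5, 3)), c)
ys = jax.vmap(lift, in_axes=(0, None))(jax.random.normal(jax.random.PRNGKey(1), (7, 3)), c)

# Inner vmap maps over ys for a fixed x; outer vmap maps over xs -> (5, 7) matrix
pairwise = jax.vmap(jax.vmap(lorentz_dist, in_axes=(None, 0, None)), in_axes=(0, None, None))
D = pairwise(xs, ys, c)
print(D.shape)  # (5, 7)
```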

Exponential and Logarithmic Maps

from hyperbolix.manifolds import Poincare
import jax.numpy as jnp

poincare = Poincare()

# Point on manifold
x = poincare.proj(jnp.array([0.2, 0.3]), c=1.0)

# Tangent vector
v = jnp.array([0.1, -0.05])

# Exponential map (move along geodesic)
y = poincare.expmap(v, x, c=1.0)

# Logarithmic map (inverse operation)
v_recovered = poincare.logmap(y, x, c=1.0)
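For intuition about what this round trip computes, the sketch below writes out the standard Poincaré-ball exponential and logarithmic maps, which are built from Möbius addition and the conformal factor λ_x = 2 / (1 - c||x||²). It keeps the same (v, x, c) and (y, x, c) argument order as the example above. The functions are hypothetical re-implementations of the textbook formulas, not hyperbolix's `expmap` and `logmap`.

```python
import jax.numpy as jnp


def mobius_add(x, y, c):
    # Möbius addition on the Poincaré ball of curvature -c
    xy, x2, y2 = jnp.dot(x, y), jnp.dot(x, x), jnp.dot(y, y)
    num = (1.0 + 2.0 * c * xy + c * y2) * x + (1.0 - c * x2) * y
    return num / (1.0 + 2.0 * c * xy + c**2 * x2 * y2)


def conformal_factor(x, c):
    # lambda_x = 2 / (1 - c ||x||^2)
    return 2.0 / (1.0 - c * jnp.dot(x, x))


def expmap_sketch(v, x, c):
    # exp_x(v) = x ⊕ tanh(sqrt(c) lambda_x ||v|| / 2) v / (sqrt(c) ||v||)
    sqrt_c = jnp.sqrt(c)
    v_norm = jnp.maximum(jnp.linalg.norm(v), 1e-15)
    step = jnp.tanh(sqrt_c * conformal_factor(x, c) * v_norm / 2.0) * v / (sqrt_c * v_norm)
    return mobius_add(x, step, c)


def logmap_sketch(y, x, c):
    # log_x(y) = 2 / (sqrt(c) lambda_x) * artanh(sqrt(c) ||w||) w / ||w||, with w = (-x) ⊕ y
    sqrt_c = jnp.sqrt(c)
    w = mobius_add(-x, y, c)
    w_norm = jnp.maximum(jnp.linalg.norm(w), 1e-15)
    return 2.0 / (sqrt_c * conformal_factor(x, c)) * jnp.arctanh(sqrt_c * w_norm) * w / w_norm


x = jnp.array([0.2, 0.3])
v = jnp.array([0.1, -0.05])
y = expmap_sketch(v, x, 1.0)
v_recovered = logmap_sketch(y, x, 1.0)
print(bool(jnp.allclose(v, v_recovered, atol=1e-5)))  # True
```

The round trip works because Möbius addition satisfies the left-cancellation law (-x) ⊕ (x ⊕ s) = s, so the logarithmic map recovers exactly the step taken by the exponential map.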

Isometry Mappings

from hyperbolix.manifolds import isometry_mappings
import jax.numpy as jnp

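# Hyperboloid point (ambient coordinates, d+1 dims); the time coordinate is chosen
# so that the Lorentz constraint -t² + ||x_s||² = -1/c holds for c = 1
x_spatial = jnp.array([0.5, 0.3])
x_hyperboloid = jnp.concatenate([jnp.sqrt(1.0 + x_spatial @ x_spatial)[None], x_spatial])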

# Map to Poincaré ball (intrinsic coordinates, d dims)
x_poincare = isometry_mappings.hyperboloid_to_poincare(x_hyperboloid, c=1.0)

# Map back (round-trip)
x_hyperboloid_recovered = isometry_mappings.poincare_to_hyperboloid(x_poincare, c=1.0)

Numerical Considerations

Float32 Precision

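Float32 can cause numerical issues (loss of precision, or inf/nan values), especially in the Poincaré ball near the boundary. Use Poincare(dtype=jnp.float64), with JAX's 64-bit mode enabled via jax.config.update("jax_enable_x64", True), for: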

  • High curvature values (c > 1.0)
  • Points near manifold boundaries
  • Deep neural networks with many layers

See the Numerical Stability guide for details.
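The boundary effect is easy to reproduce with plain NumPy scalars. NumPy is used here instead of hyperbolix so that the float32 and float64 paths can be compared without enabling JAX's x64 mode. The sketch uses the textbook Poincaré distance from the origin, d(0, x) = (2/√c)·artanh(√c·||x||), and `dist_to_origin` is a hypothetical helper.

```python
import numpy as np


def dist_to_origin(r, c=1.0):
    # Poincaré-ball distance from the origin to a point of Euclidean norm r:
    # d(0, x) = (2 / sqrt(c)) * artanh(sqrt(c) * r)
    sqrt_c = np.sqrt(c)
    return 2.0 / sqrt_c * np.arctanh(sqrt_c * r)


r = 1.0 - 1e-8  # a point very close to the unit-ball boundary (c = 1)
r32 = np.float32(r)  # rounds to exactly 1.0: float32 spacing just below 1 is about 6e-8
r64 = np.float64(r)

with np.errstate(divide="ignore"):
    print(dist_to_origin(r32))  # inf: the point has collapsed onto the boundary
print(dist_to_origin(r64))  # finite, roughly 19.1
```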