
Distributions API

Probability distributions on hyperbolic manifolds.

Overview

Hyperbolix provides wrapped normal distributions for probabilistic modeling on hyperbolic manifolds via functional interfaces. These distributions are essential for:

  • Variational Autoencoders (VAEs) with hyperbolic latent spaces
  • Bayesian neural networks on manifolds
  • Uncertainty quantification in hyperbolic embeddings

Riemannian Uniform Distribution

Poincaré Uniform

hyperbolix.distributions.uniform_poincare

Riemannian-uniform distribution on a geodesic ball in the Poincaré model.

Samples points uniformly (w.r.t. the Riemannian volume element) within a geodesic ball B(center, R) of finite radius R. Uses geodesic polar coordinates: sample a direction on S^{n-1}, sample a radius from the hyperbolic radial density, form a tangent vector, and map to the ball.

The radial density is p(r) ∝ sinh^{n-1}(√c·r) on [0, R]. The substitution u = cosh(√c·r) - 1 simplifies sampling:

  • n = 2: u is uniform on [0, cosh(√c·R) - 1] (closed form)
  • n ≥ 3: rejection sampling with acceptance ∝ (u·(u+2))^{(n-2)/2}
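For the n = 2 case, the substitution yields a direct inverse-CDF sampler. The sketch below uses NumPy for clarity (the library itself is JAX-based); `sample_radius_2d` is an illustrative name, not a library function:

```python
import numpy as np

def sample_radius_2d(rng, c, R, size):
    """Inverse-CDF sampling of p(r) ∝ sinh(√c·r) on [0, R] for n = 2.

    With u = cosh(√c·r) - 1, u is uniform on [0, cosh(√c·R) - 1],
    so r = arccosh(1 + u) / √c.
    """
    sc = np.sqrt(c)
    u_max = np.cosh(sc * R) - 1.0
    u = rng.uniform(0.0, u_max, size=size)
    return np.arccosh(1.0 + u) / sc

rng = np.random.default_rng(0)
r = sample_radius_2d(rng, c=1.0, R=1.0, size=100_000)
```

The sample mean should match the analytic mean ∫₀¹ r·sinh(r) dr / ∫₀¹ sinh(r) dr ≈ 0.677 for c = 1, R = 1.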

Dimension key:

  • S: sample dimensions (from sample_shape)
  • D: spatial/manifold dimension (n)
  • Q: quadrature points (64 for GL)
  • T: total flattened samples (for rejection sampling)

volume
volume(c: float, n: int, R: float) -> Float[Array, '']

Riemannian volume of a geodesic ball B^n_c(R) in n-dim hyperbolic space.

Vol = ω_{n-1} / c^{(n-1)/2} · ∫₀ᴿ sinh^{n-1}(√c·r) dr

where ω_{n-1} is the surface area of the unit (n-1)-sphere.

Computed via 64-point Gauss-Legendre quadrature.

Args:

  • c: Positive curvature parameter.
  • n: Ambient dimension of the Poincaré ball.
  • R: Geodesic radius of the ball.

Returns: Scalar volume.

Source code in hyperbolix/distributions/uniform_poincare.py
def volume(c: float, n: int, R: float) -> Float[Array, ""]:
    """Riemannian volume of a geodesic ball B^n_c(R) in n-dim hyperbolic space.

    Vol = ω_{n-1} / c^{(n-1)/2} · ∫₀ᴿ sinh^{n-1}(√c·r) dr

    where ω_{n-1} is the surface area of the unit (n-1)-sphere.

    Computed via 64-point Gauss-Legendre quadrature.

    Args:
        c: Positive curvature parameter.
        n: Ambient dimension of the Poincaré ball.
        R: Geodesic radius of the ball.

    Returns:
        Scalar volume.
    """
    # ω_{n-1} = 2 π^{n/2} / Γ(n/2)
    omega = 2.0 * jnp.pi ** (n / 2.0) / jnp.exp(jax.lax.lgamma(n / 2.0))

    sqrt_c = jnp.sqrt(jnp.float64(c))

    # Map GL nodes from [-1, 1] to [0, R]: r = R/2 · (t + 1)
    r_nodes_Q = (R / 2.0) * (_GL_NODES + 1.0)
    integrand_Q = jnp.sinh(sqrt_c * r_nodes_Q) ** (n - 1)
    integral = (R / 2.0) * jnp.sum(_GL_WEIGHTS * integrand_Q)

    vol = omega / sqrt_c ** (n - 1) * integral
    return vol
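As a sanity check, the quadrature can be reproduced in plain NumPy and compared against the n = 2 closed form 2π(cosh(√c·R) - 1)/c. This is an illustrative re-implementation, not the library source:

```python
import numpy as np
from math import gamma, pi

def ball_volume(c, n, R, quad_points=64):
    """NumPy sketch of `volume`: Gauss-Legendre quadrature of the radial integral."""
    omega = 2.0 * pi ** (n / 2.0) / gamma(n / 2.0)   # surface area of unit S^{n-1}
    sc = np.sqrt(c)
    t, w = np.polynomial.legendre.leggauss(quad_points)
    r = (R / 2.0) * (t + 1.0)                        # map nodes from [-1, 1] to [0, R]
    integral = (R / 2.0) * np.sum(w * np.sinh(sc * r) ** (n - 1))
    return omega / sc ** (n - 1) * integral

c, R = 1.0, 1.0
vol = ball_volume(c, 2, R)
closed = 2.0 * pi * (np.cosh(np.sqrt(c) * R) - 1.0) / c   # n = 2 closed form
```

For a smooth integrand like sinh on a bounded interval, 64-point Gauss-Legendre agrees with the closed form to machine precision.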
sample
sample(
    key: PRNGKeyArray,
    n: int,
    c: float,
    R: float,
    sample_shape: tuple[int, ...] = (),
    center: Float[Array, n] | None = None,
    dtype=None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "... n"]

Sample uniformly from a geodesic ball in the Poincaré model.

Draws points that are Riemannian-uniform within B(center, R).

Algorithm:

  1. Sample direction u ~ Uniform(S^{n-1})
  2. Sample geodesic radius r ~ p(r) ∝ sinh^{n-1}(√c·r) on [0, R]
  3. Form tangent vector t = (r/2)·u (the /2 accounts for λ(0)=2)
  4. Map to ball: x₀ = expmap_0(t, c)
  5. Move to center: x = center ⊕ x₀ (Möbius addition)

Args:

  • key: JAX PRNG key.
  • n: Dimension of the Poincaré ball.
  • c: Positive curvature parameter.
  • R: Geodesic radius of the ball.
  • sample_shape: Batch shape of samples. Default: () → single sample.
  • center: Center of the geodesic ball, shape (n,). Default: origin.
  • dtype: Output dtype. Default: float64.
  • manifold_module: Optional Manifold instance. Default: Poincare(dtype).

Returns: Samples on the Poincaré ball, shape sample_shape + (n,).

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import uniform_poincare
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = uniform_poincare.sample(key, n=2, c=1.0, R=1.0, sample_shape=(100,))
>>> x.shape
(100, 2)

Source code in hyperbolix/distributions/uniform_poincare.py
def sample(
    key: PRNGKeyArray,
    n: int,
    c: float,
    R: float,
    sample_shape: tuple[int, ...] = (),
    center: Float[Array, "n"] | None = None,
    dtype=None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "... n"]:
    """Sample uniformly from a geodesic ball in the Poincaré model.

    Draws points that are Riemannian-uniform within B(center, R).

    Algorithm:
        1. Sample direction u ~ Uniform(S^{n-1})
        2. Sample geodesic radius r ~ p(r) ∝ sinh^{n-1}(√c·r) on [0, R]
        3. Form tangent vector t = (r/2)·u  (the /2 accounts for λ(0)=2)
        4. Map to ball: x₀ = expmap_0(t, c)
        5. Move to center: x = center ⊕ x₀  (Möbius addition)

    Args:
        key: JAX PRNG key.
        n: Dimension of the Poincaré ball.
        c: Positive curvature parameter.
        R: Geodesic radius of the ball.
        sample_shape: Batch shape of samples. Default: () → single sample.
        center: Center of the geodesic ball, shape (n,). Default: origin.
        dtype: Output dtype. Default: float64.
        manifold_module: Optional Manifold instance. Default: Poincare(dtype).

    Returns:
        Samples on the Poincaré ball, shape ``sample_shape + (n,)``.

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from hyperbolix.distributions import uniform_poincare
        >>>
        >>> key = jax.random.PRNGKey(0)
        >>> x = uniform_poincare.sample(key, n=2, c=1.0, R=1.0, sample_shape=(100,))
        >>> x.shape
        (100, 2)
    """
    if manifold_module is not None:
        manifold = manifold_module
    else:
        from ..manifolds.poincare import Poincare

        _dtype = dtype if dtype is not None else jnp.float64
        manifold = Poincare(dtype=_dtype)

    if dtype is None:
        dtype = jnp.float64

    k1, k2 = jax.random.split(key)

    # 1. Direction on S^{n-1}
    directions_SD = _sample_uniform_direction(k1, n, sample_shape, dtype)

    # 2. Geodesic radii
    radii_S = _sample_radial(k2, c, n, R, sample_shape, dtype)

    # 3. Tangent vectors: t = (r/2) · u
    # The /2 compensates for expmap_0 mapping ||v|| → geodesic distance 2·||v||
    tangents_SD = (radii_S[..., None] / 2.0) * directions_SD  # (*S, 1) * (*S, D)

    # 4-5. Map to ball and translate to center
    def _map_single(t_D):
        x0_D = manifold.expmap_0(t_D, c)
        if center is not None:
            return manifold.addition(center, x0_D, c)
        return x0_D

    # vmap over all sample dimensions
    mapped_fn = _map_single
    for _ in sample_shape:
        mapped_fn = jax.vmap(mapped_fn)

    if sample_shape:
        result_SD = mapped_fn(tangents_SD)
    else:
        result_SD = _map_single(tangents_SD)

    return result_SD
log_prob
log_prob(
    x: Float[Array, "... n"],
    c: float,
    R: float,
    center: Float[Array, n] | None = None,
    manifold_module: Manifold | None = None,
) -> Float[Array, ...]

Log-probability of the Riemannian-uniform distribution on B(center, R).

Returns -log Vol(B^n_c(R)) for points inside the geodesic ball, -∞ outside.

Args:

  • x: Point(s) on the Poincaré ball, shape (..., n).
  • c: Positive curvature parameter.
  • R: Geodesic radius of the ball.
  • center: Center of the geodesic ball, shape (n,). Default: origin.
  • manifold_module: Optional Manifold instance. Default: Poincare(dtype).

Returns: Log-probability, shape (...).

Source code in hyperbolix/distributions/uniform_poincare.py
def log_prob(
    x: Float[Array, "... n"],
    c: float,
    R: float,
    center: Float[Array, "n"] | None = None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "..."]:
    """Log-probability of the Riemannian-uniform distribution on B(center, R).

    Returns -log Vol(B^n_c(R)) for points inside the geodesic ball, -∞ outside.

    Args:
        x: Point(s) on the Poincaré ball, shape (..., n).
        c: Positive curvature parameter.
        R: Geodesic radius of the ball.
        center: Center of the geodesic ball, shape (n,). Default: origin.
        manifold_module: Optional Manifold instance. Default: Poincare(dtype).

    Returns:
        Log-probability, shape (...).
    """
    if manifold_module is not None:
        manifold = manifold_module
    else:
        from ..manifolds.poincare import Poincare

        manifold = Poincare(dtype=x.dtype)

    n = x.shape[-1]

    # Compute geodesic distance from center
    if center is not None:
        if x.ndim > 1:
            dist_fn = jax.vmap(lambda xi: manifold.dist(xi, center, c))
            d_SB = dist_fn(x)
        else:
            d_SB = manifold.dist(x, center, c)
    else:
        if x.ndim > 1:
            dist_fn = jax.vmap(lambda xi: manifold.dist_0(xi, c))
            d_SB = dist_fn(x)
        else:
            d_SB = manifold.dist_0(x, c)

    log_vol = jnp.log(volume(c, n, R))

    # -log(vol) inside ball, -inf outside
    inside_SB = d_SB <= R
    return jnp.where(inside_SB, -log_vol, -jnp.inf)

Samples uniformly with respect to the Riemannian volume measure within a geodesic ball \(B(\text{center}, R)\) on the Poincaré ball. Uses geodesic polar decomposition: direction from \(S^{n-1}\) (Muller method), radial component from \(p(r) \propto \sinh^{n-1}(\sqrt{c}\,r)\).

  • For \(n=2\): closed-form radial sampling via \(u = \cosh(\sqrt{c}\,r) - 1\)
  • For \(n \geq 3\): rejection sampling with jax.lax.while_loop (JIT-compatible)
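The n ≥ 3 rejection step can be sketched in NumPy (the library uses jax.lax.while_loop for JIT compatibility; `sample_radius_nd` is an illustrative name, not a library function):

```python
import numpy as np

def sample_radius_nd(rng, c, n, R, size):
    """Rejection sampling of p(r) ∝ sinh^{n-1}(√c·r) on [0, R] for n ≥ 3.

    In u = cosh(√c·r) - 1 the density becomes p(u) ∝ (u·(u + 2))^{(n-2)/2},
    which is maximal at u_max, so uniform proposals on [0, u_max] are accepted
    with probability (u·(u + 2))^{(n-2)/2} / (u_max·(u_max + 2))^{(n-2)/2}.
    """
    sc = np.sqrt(c)
    u_max = np.cosh(sc * R) - 1.0
    bound = (u_max * (u_max + 2.0)) ** ((n - 2) / 2.0)
    accepted = np.empty(size)
    filled = 0
    while filled < size:
        u = rng.uniform(0.0, u_max, size=size)
        keep = u[rng.uniform(size=size) < (u * (u + 2.0)) ** ((n - 2) / 2.0) / bound]
        take = min(keep.size, size - filled)
        accepted[filled:filled + take] = keep[:take]
        filled += take
    return np.arccosh(1.0 + accepted) / sc

rng = np.random.default_rng(0)
r3 = sample_radius_nd(rng, c=1.0, n=3, R=1.0, size=50_000)
```

For n = 3, c = 1, R = 1 the analytic mean ∫₀¹ r·sinh²(r) dr / ∫₀¹ sinh²(r) dr ≈ 0.766, which the sample mean should reproduce.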

Usage:

from hyperbolix.distributions import uniform_poincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp

poincare = Poincare()
key = jax.random.PRNGKey(42)

# Sample 100 points uniformly from geodesic ball B(origin, R=1.0) in 2D
samples = uniform_poincare.sample(key, n=2, c=1.0, R=1.0, sample_shape=(100,))
print(samples.shape)  # (100, 2)

# All samples inside the geodesic ball
is_valid = jax.vmap(poincare.is_in_manifold, in_axes=(0, None))(samples, 1.0)
print(is_valid.all())  # True

# Volume of geodesic ball (2D closed form: 2π(cosh(√c·R) - 1)/c)
vol = uniform_poincare.volume(c=1.0, n=2, R=1.0)
print(vol)  # ~3.412

# Log probability (constant inside ball, -inf outside)
log_p = jax.vmap(lambda x: uniform_poincare.log_prob(x, c=1.0, R=1.0))(samples)
print(jnp.allclose(log_p, log_p[0]))  # True — uniform

Use cases: Uniform priors for hyperbolic VAEs, test/validation sampling, hyperparameter search in hyperbolic spaces.

Wrapped Normal Distribution

The wrapped normal distribution extends the Gaussian distribution to hyperbolic manifolds by wrapping Euclidean Gaussians via the exponential map.
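The wrapping construction can be illustrated with the textbook Poincaré-ball formulas, exp_0^c(v) = tanh(√c·‖v‖)·v/(√c·‖v‖) and Möbius addition. A NumPy sketch of the concept, not the library's `Poincare` implementation; the v/2 scaling mirrors the λ(0) = 2 tangent-space convention noted in the library source:

```python
import numpy as np

def expmap_0(v, c):
    """Exponential map at the origin of the Poincaré ball."""
    norm = np.linalg.norm(v, axis=-1, keepdims=True)
    sc = np.sqrt(c)
    return np.tanh(sc * norm) * v / (sc * norm)

def mobius_add(x, y, c):
    """Möbius addition x ⊕ y on the Poincaré ball."""
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1.0 + 2.0 * c * xy + c * y2) * x + (1.0 - c * x2) * y
    return num / (1.0 + 2.0 * c * xy + c**2 * x2 * y2)

# Wrap a Euclidean Gaussian onto the ball: v ~ N(0, σ²I), z = μ ⊕ exp_0(v/2)
rng = np.random.default_rng(0)
mu = np.array([0.2, 0.3])
v = 0.1 * rng.standard_normal((1000, 2))
z = mobius_add(mu, expmap_0(v / 2.0, c=1.0), c=1.0)
```

Every wrapped sample stays strictly inside the unit ball, since exp_0 maps into the ball and Möbius addition is closed on it.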

Poincaré Wrapped Normal

hyperbolix.distributions.wrapped_normal_poincare

Wrapped normal distribution on Poincaré ball.

Simpler implementation than hyperboloid - no parallel transport needed! Uses exponential map and Möbius addition.

Dimension key:

  • S: sample dimensions (from sample_shape)
  • B: batch dimensions (from mu batch shape)
  • D: spatial/manifold dimension (n)

References: Mathieu et al. "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders" NeurIPS 2019. https://arxiv.org/abs/1901.06033

sample
sample(
    key: PRNGKeyArray,
    mu: Float[Array, "... n"],
    sigma: Float[Array, ...] | float,
    c: float,
    sample_shape: tuple[int, ...] = (),
    dtype=None,
    manifold_module: Manifold | None = None,
) -> Float[Array, ...]

Sample from wrapped normal distribution on Poincaré ball.

Simpler than hyperboloid version - no parallel transport needed!

Algorithm:

  1. Sample v ~ N(0, Σ) ∈ R^n (directly in tangent space, no embedding)
  2. Map to ball at origin: z_0 = exp_0(v)
  3. Move to mean: z = μ ⊕ z_0 (Möbius addition)

Args:

  • key: JAX random key
  • mu: Mean point on Poincaré ball, shape (..., n)
  • sigma: Covariance parameterization. Can be:
      - Scalar: isotropic covariance sigma^2 I (n x n)
      - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
      - 2D array (n, n): full covariance matrix (must be SPD)
  • c: Curvature (positive scalar)
  • sample_shape: Shape of samples to draw, prepended to output. Default: ()
  • dtype: Output dtype. Default: infer from mu

Returns: Samples from wrapped normal distribution, shape sample_shape + mu.shape

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_poincare
>>>
>>> # Single sample with isotropic covariance
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([0.0, 0.0])  # Origin in Poincaré ball
>>> sigma = 0.1  # Isotropic
>>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
>>> z.shape
(2,)
>>>
>>> # Multiple samples with diagonal covariance
>>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
>>> z = wrapped_normal_poincare.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
>>> z.shape
(5, 2)
>>>
>>> # Batch of means
>>> mu_batch = jnp.array([[0.0, 0.0], [0.1, 0.1]])  # (2, 2)
>>> z = wrapped_normal_poincare.sample(key, mu_batch, 0.1, c=1.0)
>>> z.shape
(2, 2)

Source code in hyperbolix/distributions/wrapped_normal_poincare.py
def sample(
    key: PRNGKeyArray,
    mu: Float[Array, "... n"],
    sigma: Float[Array, "..."] | float,
    c: float,
    sample_shape: tuple[int, ...] = (),
    dtype=None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "..."]:
    """Sample from wrapped normal distribution on Poincaré ball.

    Simpler than hyperboloid version - no parallel transport needed!

    Algorithm:
    1. Sample v ~ N(0, Σ) ∈ R^n (directly in tangent space, no embedding)
    2. Map to ball at origin: z_0 = exp_0(v)
    3. Move to mean: z = μ ⊕ z_0 (Möbius addition)

    Args:
        key: JAX random key
        mu: Mean point on Poincaré ball, shape (..., n)
        sigma: Covariance parameterization. Can be:
            - Scalar: isotropic covariance sigma^2 I (n x n)
            - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
            - 2D array (n, n): full covariance matrix (must be SPD)
        c: Curvature (positive scalar)
        sample_shape: Shape of samples to draw, prepended to output. Default: ()
        dtype: Output dtype. Default: infer from mu

    Returns:
        Samples from wrapped normal distribution, shape sample_shape + mu.shape

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from hyperbolix.distributions import wrapped_normal_poincare
        >>>
        >>> # Single sample with isotropic covariance
        >>> key = jax.random.PRNGKey(0)
        >>> mu = jnp.array([0.0, 0.0])  # Origin in Poincaré ball
        >>> sigma = 0.1  # Isotropic
        >>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
        >>> z.shape
        (2,)
        >>>
        >>> # Multiple samples with diagonal covariance
        >>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
        >>> z = wrapped_normal_poincare.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
        >>> z.shape
        (5, 2)
        >>>
        >>> # Batch of means
        >>> mu_batch = jnp.array([[0.0, 0.0], [0.1, 0.1]])  # (2, 2)
        >>> z = wrapped_normal_poincare.sample(key, mu_batch, 0.1, c=1.0)
        >>> z.shape
        (2, 2)
    """
    # Use provided manifold module or default class instance
    if manifold_module is not None:
        manifold = manifold_module
    else:
        from ..manifolds.poincare import Poincare

        _dtype = dtype if dtype is not None else mu.dtype
        manifold = Poincare(dtype=_dtype)

    # Determine output dtype
    if dtype is None:
        dtype = mu.dtype

    # Extract dimension
    n = mu.shape[-1]  # Dimension of Poincaré ball

    # Determine batch shape from mu (all dims except the last one)
    # mu.shape = batch_shape + (n,)
    mu_batch_shape = mu.shape[:-1]

    # Step 1: Sample v ~ N(0, Σ) ∈ R^n (directly in tangent space at origin)
    # Full noise shape: sample_shape + mu_batch_shape + (n,) = (*S, *B, D)
    cov_DD = sigma_to_cov(sigma, n, dtype)
    full_sample_shape = sample_shape + mu_batch_shape
    v_SBD = sample_gaussian(key, cov_DD, sample_shape=full_sample_shape, dtype=dtype)

    # Scale Euclidean tangent vector to Riemannian tangent space coordinates.
    # At origin, the conformal factor λ(0) = 2/(1 - c·0) = 2, so Riemannian
    # coordinates require dividing by λ(0). This matches the reference (Mathieu et al.
    # 2019) which does v = v / lambda_x(zero) before the exponential map.
    v_SBD = v_SBD / 2.0

    # Step 2: Map to ball at origin: z_0 = exp_0(v)
    # Step 3: Move to mean: z = μ ⊕ z_0

    def transform_single(v_D, mu_D):
        """Transform a single (v, mu) pair."""
        z_0_D = manifold.expmap_0(v_D, c)
        z_D = manifold.addition(mu_D, z_0_D, c)
        return z_D

    return _batched_transform(transform_single, v_SBD, mu, sample_shape, mu_batch_shape)
log_prob
log_prob(
    z: Float[Array, "... n"],
    mu: Float[Array, "... n"],
    sigma: Float[Array, ...] | float,
    c: float,
    manifold_module: Manifold | None = None,
) -> Float[Array, ...]

Compute log probability of wrapped normal distribution on Poincaré ball.

Implements Algorithm 2 from the paper adapted for Poincaré ball:

  1. Map z to u = log_μ(z) ∈ T_μB^n (logarithmic map)
  2. Move u to v = PT_{μ→0}(u) ∈ T_0B^n (parallel transport to origin)
  3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)

Args:

  • z: Sample point(s) on Poincaré ball, shape (..., n)
  • mu: Mean point on Poincaré ball, shape (..., n)
  • sigma: Covariance parameterization. Can be:
      - Scalar: isotropic covariance sigma^2 I (n x n)
      - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
      - 2D array (n, n): full covariance matrix (must be SPD)
  • c: Curvature (positive scalar)

Returns: Log probability, shape (...) (spatial dimension removed)

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_poincare
>>>
>>> # Compute log probability of samples
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([0.0, 0.0])
>>> sigma = 0.1
>>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
>>> log_p = wrapped_normal_poincare.log_prob(z, mu, sigma, c=1.0)
>>> log_p.shape
()
>>>
>>> # Batch computation
>>> z_batch = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
>>> log_p_batch = wrapped_normal_poincare.log_prob(z_batch, mu, sigma, c=1.0)
>>> log_p_batch.shape
(10,)

Source code in hyperbolix/distributions/wrapped_normal_poincare.py
def log_prob(
    z: Float[Array, "... n"],
    mu: Float[Array, "... n"],
    sigma: Float[Array, "..."] | float,
    c: float,
    manifold_module: Manifold | None = None,
) -> Float[Array, "..."]:
    """Compute log probability of wrapped normal distribution on Poincaré ball.

    Implements Algorithm 2 from the paper adapted for Poincaré ball:
    1. Map z to u = log_μ(z) ∈ T_μB^n (logarithmic map)
    2. Move u to v = PT_{μ→0}(u) ∈ T_0B^n (parallel transport to origin)
    3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)

    Args:
        z: Sample point(s) on Poincaré ball, shape (..., n)
        mu: Mean point on Poincaré ball, shape (..., n)
        sigma: Covariance parameterization. Can be:
            - Scalar: isotropic covariance sigma^2 I (n x n)
            - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
            - 2D array (n, n): full covariance matrix (must be SPD)
        c: Curvature (positive scalar)

    Returns:
        Log probability, shape (...) (spatial dimension removed)

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from hyperbolix.distributions import wrapped_normal_poincare
        >>>
        >>> # Compute log probability of samples
        >>> key = jax.random.PRNGKey(0)
        >>> mu = jnp.array([0.0, 0.0])
        >>> sigma = 0.1
        >>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
        >>> log_p = wrapped_normal_poincare.log_prob(z, mu, sigma, c=1.0)
        >>> log_p.shape
        ()
        >>>
        >>> # Batch computation
        >>> z_batch = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
        >>> log_p_batch = wrapped_normal_poincare.log_prob(z_batch, mu, sigma, c=1.0)
        >>> log_p_batch.shape
        (10,)
    """

    # Use provided manifold module or default class instance
    if manifold_module is not None:
        manifold = manifold_module
    else:
        from ..manifolds.poincare import Poincare

        manifold = Poincare(dtype=z.dtype)

    # Determine dtype
    dtype = z.dtype

    # Extract dimension
    n = mu.shape[-1]  # Dimension of Poincaré ball

    # Step 1: Map z to tangent space at mu: u = log_μ(z), shape (..., D)
    if z.ndim > mu.ndim:
        n_sample_dims = z.ndim - mu.ndim
        logmap_fn = manifold.logmap
        for _ in range(n_sample_dims):
            logmap_fn = jax.vmap(logmap_fn, in_axes=(0, None, None))
        u_SBD = logmap_fn(z, mu, c)
    elif z.ndim == mu.ndim and mu.ndim > 1:
        u_SBD = jax.vmap(lambda zz, mm: manifold.logmap(zz, mm, c))(z, mu)
    else:
        u_SBD = manifold.logmap(z, mu, c)

    # Step 2: Parallel transport from mu to origin: v = PT_{μ→0}(u)
    mu_0_D = jnp.zeros(n, dtype=dtype)

    if u_SBD.ndim > 1:
        if mu.ndim > 1:
            v_SBD = jax.vmap(lambda uu, mm: manifold.ptransp(uu, mm, mu_0_D, c))(u_SBD, mu)
        else:
            v_SBD = jax.vmap(lambda uu: manifold.ptransp(uu, mu, mu_0_D, c))(u_SBD)
    else:
        v_SBD = manifold.ptransp(u_SBD, mu, mu_0_D, c)

    # Step 3: Compute log p(v) where v ~ N(0, Σ)
    # Scale to Riemannian coordinates: v_riem = λ(0) · v_euclid = 2 · v
    v_riem_SBD = v_SBD * 2.0
    log_p_v_SB = gaussian_log_prob(v_riem_SBD, sigma, n, dtype)

    # Step 4: Compute log det Jacobian
    # Riemannian norm r = λ(0) · ||v||_E = 2 · ||v||_E
    r_SB = 2.0 * jnp.sqrt(jnp.sum(v_SBD**2, axis=-1))
    log_det_jac_SB = _log_det_jacobian_from_r(r_SB, c, n)

    # Step 5: log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)
    log_p_z_SB = log_p_v_SB - log_det_jac_SB

    return log_p_z_SB
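The helper `_log_det_jacobian_from_r` used above is not shown. A plausible sketch, under the assumption that it implements the standard hyperbolic polar-volume factor det J = (sinh(√c·r)/(√c·r))^{n-1} from the wrapped-normal change of variables (the function below is hypothetical, not the library source):

```python
import numpy as np

def log_det_jacobian_from_r(r, c, n):
    """Hypothetical sketch of `_log_det_jacobian_from_r` (assumption, not library
    source): the wrapped-normal change of variables contributes
    det J = (sinh(√c·r) / (√c·r))^{n-1}, evaluated in log space for stability."""
    sc_r = np.sqrt(c) * r
    return (n - 1) * (np.log(np.sinh(sc_r)) - np.log(sc_r))
```

As r → 0 the factor tends to 0 (det J → 1), recovering the Euclidean Gaussian near the origin.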

Hyperboloid Wrapped Normal

hyperbolix.distributions.wrapped_normal_hyperboloid

Wrapped normal distribution on hyperboloid manifold.

Implementation of the wrapped normal distribution that wraps a Gaussian from the tangent space at the origin onto the hyperboloid via parallel transport and exponential map.

Dimension key:

  • S: sample dimensions (from sample_shape)
  • B: batch dimensions (from mu batch shape)
  • D: spatial dimension (n)
  • A: ambient dimension (n+1, hyperboloid time+space)

References: Mathieu et al. "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders" NeurIPS 2019. https://arxiv.org/abs/1901.06033

sample
sample(
    key: PRNGKeyArray,
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, ...] | float,
    c: float,
    sample_shape: tuple[int, ...] = (),
    dtype=None,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, ...]

Sample from wrapped normal distribution on hyperboloid.

Implements Algorithm 1 from the paper:

  1. Sample v_bar ~ N(0, Σ) ∈ R^n
  2. Embed as tangent vector v = [0, v_bar] ∈ T_{μ₀}ℍⁿ at origin
  3. Parallel transport to mean: u = PT_{μ₀→μ}(v) ∈ T_μℍⁿ
  4. Map to manifold: z = exp_μ(u) ∈ ℍⁿ

Args:

  • key: JAX random key
  • mu: Mean point on hyperboloid, shape (..., n+1) in ambient coordinates
  • sigma: Covariance parameterization in spatial coordinates. Can be:
      - Scalar: isotropic covariance sigma^2 I (n x n)
      - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
      - 2D array (n, n): full covariance matrix (must be SPD)
  • c: Curvature (positive scalar)
  • sample_shape: Shape of samples to draw, prepended to output. Default: ()
  • dtype: Output dtype. Default: infer from mu

Returns: Samples from wrapped normal distribution, shape sample_shape + mu.shape

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_hyperboloid
>>>
>>> # Single sample with isotropic covariance
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([1.0, 0.0, 0.0])  # Origin in H^2
>>> sigma = 0.1  # Isotropic
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
>>> z.shape
(3,)
>>>
>>> # Multiple samples with diagonal covariance
>>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
>>> z.shape
(5, 3)
>>>
>>> # Batch of means
>>> mu_batch = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.1, 0.1]])  # (2, 3)
>>> z = wrapped_normal_hyperboloid.sample(key, mu_batch, 0.1, c=1.0)
>>> z.shape
(2, 3)

Source code in hyperbolix/distributions/wrapped_normal_hyperboloid.py
def sample(
    key: PRNGKeyArray,
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, "..."] | float,
    c: float,
    sample_shape: tuple[int, ...] = (),
    dtype=None,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, "..."]:
    """Sample from wrapped normal distribution on hyperboloid.

    Implements Algorithm 1 from the paper:
    1. Sample v_bar ~ N(0, Σ) ∈ R^n
    2. Embed as tangent vector v = [0, v_bar] ∈ T_{μ₀}ℍⁿ at origin
    3. Parallel transport to mean: u = PT_{μ₀→μ}(v) ∈ T_μℍⁿ
    4. Map to manifold: z = exp_μ(u) ∈ ℍⁿ

    Args:
        key: JAX random key
        mu: Mean point on hyperboloid, shape (..., n+1) in ambient coordinates
        sigma: Covariance parameterization in spatial coordinates. Can be:
            - Scalar: isotropic covariance sigma^2 I (n x n)
            - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
            - 2D array (n, n): full covariance matrix (must be SPD)
        c: Curvature (positive scalar)
        sample_shape: Shape of samples to draw, prepended to output. Default: ()
        dtype: Output dtype. Default: infer from mu

    Returns:
        Samples from wrapped normal distribution, shape sample_shape + mu.shape

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from hyperbolix.distributions import wrapped_normal_hyperboloid
        >>>
        >>> # Single sample with isotropic covariance
        >>> key = jax.random.PRNGKey(0)
        >>> mu = jnp.array([1.0, 0.0, 0.0])  # Origin in H^2
        >>> sigma = 0.1  # Isotropic
        >>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
        >>> z.shape
        (3,)
        >>>
        >>> # Multiple samples with diagonal covariance
        >>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
        >>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
        >>> z.shape
        (5, 3)
        >>>
        >>> # Batch of means
        >>> mu_batch = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.1, 0.1]])  # (2, 3)
        >>> z = wrapped_normal_hyperboloid.sample(key, mu_batch, 0.1, c=1.0)
        >>> z.shape
        (2, 3)
    """
    # Use provided manifold module or default class instance
    if manifold_module is not None:
        manifold = manifold_module
    else:
        # Determine output dtype first for class instantiation
        _dtype = dtype if dtype is not None else mu.dtype
        manifold = Hyperboloid(dtype=_dtype)

    # Determine output dtype
    if dtype is None:
        dtype = mu.dtype

    # Extract spatial dimension
    n = mu.shape[-1] - 1  # Spatial dimension (n for H^n in R^(n+1))

    # Determine batch shape from mu (all dims except the last one)
    # mu.shape = batch_shape + (n+1,)
    mu_batch_shape = mu.shape[:-1]

    # Step 1: Sample v_bar ~ N(0, Σ) ∈ R^n
    # Full noise shape: sample_shape + mu_batch_shape + (n,) = (*S, *B, D)
    cov_DD = sigma_to_cov(sigma, n, dtype)
    full_sample_shape = sample_shape + mu_batch_shape
    v_spatial_SBD = sample_gaussian(key, cov_DD, sample_shape=full_sample_shape, dtype=dtype)

    # Step 2: Embed as tangent vector v = [0, v_bar] ∈ T_{μ₀}ℍⁿ at origin
    # Shape: (*S, *B, A) where A = n+1
    v_SBA = manifold.embed_spatial_0(v_spatial_SBD)

    # Step 3 & 4: Parallel transport and exponential map

    def transform_single(v_A, mu_A):
        """Transform a single (v, mu) pair."""
        u_A = manifold.ptransp_0(v_A, mu_A, c)  # Step 3: parallel transport to mu
        z_A = manifold.expmap(u_A, mu_A, c)  # Step 4: exponential map at mu
        return z_A

    return _batched_transform(transform_single, v_SBA, mu, sample_shape, mu_batch_shape)
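The parallel-transport and exponential-map steps can be sketched in plain NumPy for c = 1, using the standard hyperboloid formulas from Nagano et al.'s wrapped-normal construction (a conceptual sketch, not the library's `Hyperboloid` implementation, which handles general c):

```python
import numpy as np

def minkowski_dot(x, y):
    """Minkowski inner product <x, y>_L = -x0*y0 + sum_i xi*yi."""
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def ptransp_0(v, mu):
    """Parallel transport T_{mu0}H^n -> T_mu H^n along the geodesic (c = 1)."""
    mu0 = np.zeros_like(mu)
    mu0[0] = 1.0
    alpha = -minkowski_dot(mu0, mu)
    return v + minkowski_dot(mu - alpha * mu0, v) / (alpha + 1.0) * (mu0 + mu)

def expmap(u, mu):
    """Exponential map at mu on the hyperboloid (c = 1)."""
    norm = np.sqrt(np.maximum(minkowski_dot(u, u), 1e-15))
    return np.cosh(norm) * mu + np.sinh(norm) * u / norm

mu = np.array([np.sqrt(1.02), 0.1, 0.1])   # a point on H^2: <mu, mu>_L = -1
v = np.array([0.0, 0.05, -0.02])           # tangent at the origin (time part 0)
u = ptransp_0(v, mu)                       # Step 3: transport to mu
z = expmap(u, mu)                          # Step 4: map onto the manifold
```

Transport preserves tangency (<u, μ>_L = 0) and the exponential map lands back on the manifold (<z, z>_L = -1), which is what makes the wrapped construction well defined.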
log_prob
log_prob(
    z: Float[Array, "... n_plus_1"],
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, ...] | float,
    c: float,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, ...]

Compute log probability of wrapped normal distribution.

Implements Algorithm 2 from the paper:

  1. Map z to u = exp_μ⁻¹(z) ∈ T_μℍⁿ (logarithmic map)
  2. Move u to v = PT_{μ→μ₀}(u) ∈ T_{μ₀}ℍⁿ (parallel transport to origin)
  3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)

Args:

  • z: Sample point(s) on hyperboloid, shape (..., n+1)
  • mu: Mean point on hyperboloid, shape (..., n+1)
  • sigma: Covariance parameterization. Can be:
      - Scalar: isotropic covariance sigma^2 I (n x n)
      - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
      - 2D array (n, n): full covariance matrix (must be SPD)
  • c: Curvature (positive scalar)

Returns: Log probability, shape (...) (manifold dimension removed)

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_hyperboloid
>>>
>>> # Compute log probability of samples
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([1.0, 0.0, 0.0])
>>> sigma = 0.1
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
>>> log_p = wrapped_normal_hyperboloid.log_prob(z, mu, sigma, c=1.0)
>>> log_p.shape
()
>>>
>>> # Batch computation
>>> z_batch = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
>>> log_p_batch = wrapped_normal_hyperboloid.log_prob(z_batch, mu, sigma, c=1.0)
>>> log_p_batch.shape
(10,)

Source code in hyperbolix/distributions/wrapped_normal_hyperboloid.py
def log_prob(
    z: Float[Array, "... n_plus_1"],
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, "..."] | float,
    c: float,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, "..."]:
    """Compute log probability of wrapped normal distribution.

    Implements Algorithm 2 from the paper:
    1. Map z to u = exp_μ⁻¹(z) ∈ T_μℍⁿ (logarithmic map)
    2. Move u to v = PT_{μ→μ₀}(u) ∈ T_{μ₀}ℍⁿ (parallel transport to origin)
    3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)

    Args:
        z: Sample point(s) on hyperboloid, shape (..., n+1)
        mu: Mean point on hyperboloid, shape (..., n+1)
        sigma: Covariance parameterization. Can be:
            - Scalar: isotropic covariance sigma^2 I (n x n)
            - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
            - 2D array (n, n): full covariance matrix (must be SPD)
        c: Curvature (positive scalar)

    Returns:
        Log probability, shape (...) (manifold dimension removed)

    Examples:
        >>> import jax
        >>> import jax.numpy as jnp
        >>> from hyperbolix.distributions import wrapped_normal_hyperboloid
        >>>
        >>> # Compute log probability of samples
        >>> key = jax.random.PRNGKey(0)
        >>> mu = jnp.array([1.0, 0.0, 0.0])
        >>> sigma = 0.1
        >>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
        >>> log_p = wrapped_normal_hyperboloid.log_prob(z, mu, sigma, c=1.0)
        >>> log_p.shape
        ()
        >>>
        >>> # Batch computation
        >>> z_batch = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
        >>> log_p_batch = wrapped_normal_hyperboloid.log_prob(z_batch, mu, sigma, c=1.0)
        >>> log_p_batch.shape
        (10,)
    """
    # Use provided manifold module or default class instance
    if manifold_module is not None:
        manifold = manifold_module
    else:
        manifold = Hyperboloid(dtype=z.dtype)

    # Determine dtype
    dtype = z.dtype

    # Extract spatial dimension
    n = mu.shape[-1] - 1  # Spatial dimension

    # Step 1: Map z to tangent space at mu: u = log_μ(z), shape (..., A)
    if z.ndim > mu.ndim:
        n_sample_dims = z.ndim - mu.ndim
        logmap_fn = manifold.logmap
        for _ in range(n_sample_dims):
            logmap_fn = jax.vmap(logmap_fn, in_axes=(0, None, None))
        u_SBA = logmap_fn(z, mu, c)
    elif z.ndim == mu.ndim and mu.ndim > 1:
        u_SBA = jax.vmap(lambda zz, mm: manifold.logmap(zz, mm, c))(z, mu)
    else:
        u_SBA = manifold.logmap(z, mu, c)

    # Step 2: Parallel transport from mu to origin: v = PT_{μ→μ₀}(u)
    mu_0_A = manifold.create_origin(c, n)

    if u_SBA.ndim > 1:
        if mu.ndim > 1:
            v_SBA = jax.vmap(lambda uu, mm: manifold.ptransp(uu, mm, mu_0_A, c))(u_SBA, mu)
        else:
            v_SBA = jax.vmap(lambda uu: manifold.ptransp(uu, mu, mu_0_A, c))(u_SBA)
    else:
        v_SBA = manifold.ptransp(u_SBA, mu, mu_0_A, c)

    # Step 3: Extract spatial components: v = [0, v_bar] at origin
    v_spatial_SBD = v_SBA[..., 1:]

    # Step 4: Compute log p(v) where v ~ N(0, Σ)
    log_p_v_SB = gaussian_log_prob(v_spatial_SBD, sigma, n, dtype)

    # Step 5: Compute log det Jacobian
    # Minkowski norm at origin: r = ||v_spatial|| (since v = [0, v_bar])
    r_SB = jnp.sqrt(jnp.maximum(jnp.sum(v_spatial_SBD**2, axis=-1), 1e-15))
    log_det_jac_SB = _log_det_jacobian_from_r(r_SB, c, n)

    # Step 6: log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)
    log_p_z_SB = log_p_v_SB - log_det_jac_SB

    return log_p_z_SB

Usage Examples

Basic Sampling (Poincaré)

from hyperbolix.distributions import wrapped_normal_poincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp

poincare = Poincare()

# Mean on Poincaré ball
mean = jnp.array([0.2, 0.3])
mean_proj = poincare.proj(mean, c=1.0)

# Standard deviation
std = 0.1

# Sample
key = jax.random.PRNGKey(42)
samples = wrapped_normal_poincare.sample(
    key, mean_proj, std, c=1.0, sample_shape=(100,), manifold_module=poincare
)
print(samples.shape)  # (100, 2)

# Samples lie inside the Poincaré ball of radius 1/√c (here c = 1)
norms = jnp.linalg.norm(samples, axis=-1)
print(jnp.all(norms < 1.0))  # True
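To inspect how spread out the samples are around the mean, the Poincaré geodesic distance can be computed directly. This is a standalone sketch using the standard closed-form distance formula, not the library's API:

```python
import jax.numpy as jnp

def poincare_dist(x, y, c=1.0):
    # Closed-form geodesic distance on the Poincaré ball of curvature -c:
    # d(x, y) = arccosh(1 + 2c||x - y||² / ((1 - c||x||²)(1 - c||y||²))) / √c
    diff2 = jnp.sum((x - y) ** 2, axis=-1)
    denom = (1 - c * jnp.sum(x**2, axis=-1)) * (1 - c * jnp.sum(y**2, axis=-1))
    return jnp.arccosh(1 + 2 * c * diff2 / denom) / jnp.sqrt(c)

mean = jnp.array([0.2, 0.3])
print(poincare_dist(mean, mean))          # 0.0
print(poincare_dist(jnp.zeros(2), mean))  # ≈ 0.755, distance from the origin
```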

Log Probability

# Compute log probability of samples
log_probs = jax.vmap(
    lambda x: wrapped_normal_poincare.log_prob(x, mean_proj, std, c=1.0, manifold_module=poincare)
)(samples)
print(log_probs.shape)  # (100,)

# Higher probability near mean
point_near_mean = poincare.proj(jnp.array([0.21, 0.29]), c=1.0)
point_far = poincare.proj(jnp.array([0.7, 0.7]), c=1.0)

print(f"Log prob (near): {wrapped_normal_poincare.log_prob(point_near_mean, mean_proj, std, c=1.0, manifold_module=poincare):.4f}")
print(f"Log prob (far): {wrapped_normal_poincare.log_prob(point_far, mean_proj, std, c=1.0, manifold_module=poincare):.4f}")

Hyperboloid Distribution

from hyperbolix.distributions import wrapped_normal_hyperboloid
from hyperbolix.manifolds import Hyperboloid
import jax.numpy as jnp

hyperboloid = Hyperboloid()

# Mean on hyperboloid (ambient coordinates)
mean_space = jnp.array([0.2, 0.3, -0.1])
mean_ambient = jnp.concatenate([
    jnp.array([jnp.sqrt(jnp.sum(mean_space**2) + 1.0)]),
    mean_space
])

# Sample
key = jax.random.PRNGKey(123)
samples = wrapped_normal_hyperboloid.sample(
    key, mean_ambient, 0.15, c=1.0, sample_shape=(50,), manifold_module=hyperboloid
)

# Compute log probabilities
log_probs = jax.vmap(
    lambda x: wrapped_normal_hyperboloid.log_prob(x, mean_ambient, 0.15, c=1.0, manifold_module=hyperboloid)
)(samples)
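As a sanity check, points on the hyperboloid satisfy the Minkowski constraint \(\langle x, x \rangle_L = -1/c\). A standalone sketch (plain jnp, not the library API), applied to the mean constructed above:

```python
import jax.numpy as jnp

def minkowski_inner(x, y):
    # Lorentzian inner product: <x, y>_L = -x_0 y_0 + sum_i x_i y_i
    return -x[..., 0] * y[..., 0] + jnp.sum(x[..., 1:] * y[..., 1:], axis=-1)

# Lift spatial coordinates onto the hyperboloid, as in the example above
mean_space = jnp.array([0.2, 0.3, -0.1])
x0 = jnp.sqrt(jnp.sum(mean_space**2) + 1.0)  # time coordinate for c = 1
x = jnp.concatenate([jnp.array([x0]), mean_space])

# Points on the hyperboloid with c = 1 satisfy <x, x>_L = -1
print(jnp.allclose(minkowski_inner(x, x), -1.0, atol=1e-5))  # True
```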

VAE Example

Using wrapped normal distributions in a Variational Autoencoder:

from flax import nnx
from hyperbolix.distributions import wrapped_normal_poincare
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp

poincare = Poincare()

class HyperbolicVAE(nnx.Module):
    def __init__(self, latent_dim, rngs):
        self.latent_dim = latent_dim

        # Encoder: Euclidean → Hyperbolic
        self.encoder = nnx.Linear(784, 128, rngs=rngs)
        self.enc_hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=128,
            out_dim=latent_dim,
            rngs=rngs
        )

        # Decoder: Hyperbolic → Euclidean
        self.dec_hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=latent_dim,
            out_dim=128,
            rngs=rngs
        )
        self.decoder = nnx.Linear(128, 784, rngs=rngs)

    def encode(self, x, c):
        # Returns mean and std for latent distribution
        h = jax.nn.relu(self.encoder(x))

        # Project to Poincaré ball
        h_proj = jax.vmap(poincare.proj, in_axes=(0, None))(h, c)

        # Mean on Poincaré ball
        mean = self.enc_hyp(h_proj, c)

        # Std in tangent space (Euclidean)
        log_std_layer = nnx.Linear(128, self.latent_dim, rngs=nnx.Rngs(0))
        log_std = log_std_layer(h)
        std = jnp.exp(log_std)

        return mean, std

    def decode(self, z, c):
        h = self.dec_hyp(z, c)

        # Logmap to the tangent space at the origin for the Euclidean decoder
        h_tangent = jax.vmap(poincare.logmap, in_axes=(0, None, None))(
            h, jnp.zeros(self.latent_dim), c
        )

        return jax.nn.sigmoid(self.decoder(h_tangent))

    def __call__(self, x, key, c=1.0):
        # Encode
        mean, std = self.encode(x, c)

        # Sample latent code
        keys = jax.random.split(key, mean.shape[0])
        z = jax.vmap(
            lambda k, m, s: wrapped_normal_poincare.sample(k, m, s, c, (), manifold_module=poincare)
        )(keys, mean, std)

        # Decode
        recon = self.decode(z, c)

        return recon, mean, std, z

# Loss function
def vae_loss(model, x, key, c):
    recon, mean, std, z = model(x, key, c)

    # Reconstruction loss
    recon_loss = jnp.mean((x - recon) ** 2)

    # KL divergence (crude approximation for the wrapped normal):
    # closed-form Euclidean KL against a standard normal prior,
    # summed over the latent dimension and averaged over the batch
    kl_loss = -0.5 * jnp.mean(
        jnp.sum(1 + 2 * jnp.log(std) - mean**2 - std**2, axis=-1)
    )

    return recon_loss + kl_loss
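The closed-form Euclidean KL used in the loss can be sanity-checked against a Monte Carlo estimate. This is a self-contained sketch (the function names are illustrative, not part of Hyperbolix):

```python
import jax
import jax.numpy as jnp

def kl_closed_form(mean, std):
    # KL( N(mean, diag(std²)) || N(0, I) ), summed over latent dimensions
    return -0.5 * jnp.sum(1 + 2 * jnp.log(std) - mean**2 - std**2)

def kl_monte_carlo(key, mean, std, n=200_000):
    # Estimate E_q[log q(z) - log p(z)] with samples from q
    z = mean + std * jax.random.normal(key, (n, mean.shape[0]))
    log_q = jnp.sum(
        -0.5 * ((z - mean) / std) ** 2 - jnp.log(std) - 0.5 * jnp.log(2 * jnp.pi),
        axis=-1,
    )
    log_p = jnp.sum(-0.5 * z**2 - 0.5 * jnp.log(2 * jnp.pi), axis=-1)
    return jnp.mean(log_q - log_p)

mean = jnp.array([0.3, -0.2])
std = jnp.array([0.5, 1.5])
print(kl_closed_form(mean, std))                         # ≈ 0.603
print(kl_monte_carlo(jax.random.PRNGKey(0), mean, std))  # ≈ 0.603 (stochastic)
```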

Mathematical Background

Wrapped Normal Definition

Given a mean \(\mu \in \mathcal{M}\) on manifold \(\mathcal{M}\) and standard deviation \(\sigma\), the wrapped normal distribution is defined as:

  1. Sample \(v \sim \mathcal{N}(0, \sigma^2 I)\) in tangent space \(T_\mu \mathcal{M}\)
  2. Wrap to manifold: \(x = \exp_\mu(v)\)

The log probability combines the Gaussian density of the tangent vector with a change-of-volume correction, since the exponential map distorts volume away from the mean:

\[ \log p(x) = \log \mathcal{N}(v;\, 0, \sigma^2 I) - \log \det\!\left(\frac{\partial \exp_\mu(v)}{\partial v}\right), \qquad v = \log_\mu(x) \]

where \(\log_\mu\) is the logarithmic map at \(\mu\). In \(n\)-dimensional hyperbolic space with curvature parameter \(c\), the correction has the closed form \((n-1)\log\bigl(\sinh(\sqrt{c}\,r)/(\sqrt{c}\,r)\bigr)\) with \(r = \|v\|\). Without it, the Gaussian term \(-\frac{1}{2\sigma^2}\|v\|^2 - \frac{n}{2}\log(2\pi\sigma^2)\) alone would not integrate to one over the manifold.
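The correction term that the source code computes via `_log_det_jacobian_from_r` can be sketched directly, assuming the standard closed form \((n-1)\log(\sinh(\sqrt{c}\,r)/(\sqrt{c}\,r))\) (an illustrative standalone function, not the library's implementation):

```python
import jax.numpy as jnp

def log_det_jacobian(r, c, n):
    # Volume distortion of exp_mu in n-dim hyperbolic space of curvature c:
    # (n - 1) * log( sinh(sqrt(c) * r) / (sqrt(c) * r) )
    s = jnp.sqrt(c)
    return (n - 1) * jnp.log(jnp.sinh(s * r) / (s * r))

print(log_det_jacobian(1e-4, 1.0, 2))  # ≈ 0: locally Euclidean near the mean
print(log_det_jacobian(2.0, 1.0, 2))   # ≈ 0.595: volume dilates far from the mean
```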

Sampling Algorithm

# Conceptual implementation (simplified: the full algorithm also
# parallel-transports the tangent sample between the origin and the mean)
def sample_concept(key, mean, std, c, sample_shape, manifold):
    # 1. Sample in tangent space at mean
    tangent_sample = std * jax.random.normal(key, sample_shape + mean.shape)

    # 2. Exponential map to manifold
    manifold_sample = manifold.expmap(tangent_sample, mean, c)

    return manifold_sample
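At the origin of the hyperboloid the exponential map has a simple closed form, so the recipe above can be sketched end to end without the library. This is a toy illustration (the function name is made up), and it only covers a mean at the origin; for a general mean the tangent sample must additionally be parallel-transported from the origin:

```python
import jax
import jax.numpy as jnp

def sample_origin_hyperboloid(key, sigma, c, n, shape=()):
    # Draw v ~ N(0, sigma² I) in the tangent space at the origin
    # mu0 = (1/√c, 0, ..., 0) and wrap it with the closed-form map
    # exp_{mu0}(v) = (cosh(s r)/s, sinh(s r) v / (s r)), s = √c, r = ||v||
    s = jnp.sqrt(c)
    v = sigma * jax.random.normal(key, shape + (n,))
    r = jnp.linalg.norm(v, axis=-1, keepdims=True)
    x0 = jnp.cosh(s * r) / s
    x_rest = jnp.sinh(s * r) * v / (s * r)
    return jnp.concatenate([x0, x_rest], axis=-1)

key = jax.random.PRNGKey(0)
x = sample_origin_hyperboloid(key, sigma=0.3, c=1.0, n=2, shape=(5,))
print(x.shape)  # (5, 3)

# Every sample satisfies the hyperboloid constraint <x, x>_L = -1/c
mink = -x[..., 0] ** 2 + jnp.sum(x[..., 1:] ** 2, axis=-1)
print(jnp.allclose(mink, -1.0, atol=1e-5))  # True
```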

Numerical Considerations

Numerical Stability

For small standard deviations and/or high curvatures, the exponential map can become numerically unstable. Consider:

  • Using float64 for very small \(\sigma\) (< 0.01)
  • Clipping standard deviations to a reasonable range, e.g. \(\sigma \in [0.01, 1.0]\)
  • Using the version parameter of manifold operations to select a more numerically stable implementation
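Enabling double precision is a one-line JAX configuration change (standard JAX, not Hyperbolix-specific):

```python
import jax
import jax.numpy as jnp

# Switch JAX to double precision; arrays created afterwards default to float64
jax.config.update("jax_enable_x64", True)

x = jnp.ones(3)
print(x.dtype)  # float64
```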

Curvature Choice

The curvature parameter \(c\) affects the distribution:

  • Higher \(c\) → More concentrated distributions
  • Lower \(c\) → More spread out distributions

Tune \(c\) based on your application's needs.

References

Wrapped distributions on manifolds are discussed in:

  • Nagano, Y., et al. (2019). "A Wrapped Normal Distribution on Hyperbolic Space for Gradient-Based Learning." ICML.
  • Davidson, T., et al. (2018). "Hyperspherical Variational Auto-Encoders." UAI.

See also: