Distributions API¶
Probability distributions on hyperbolic manifolds.
Overview¶
Hyperbolix provides wrapped normal distributions for probabilistic modeling on hyperbolic manifolds via functional interfaces. These distributions are essential for:
- Variational Autoencoders (VAEs) with hyperbolic latent spaces
- Bayesian neural networks on manifolds
- Uncertainty quantification in hyperbolic embeddings
Riemannian Uniform Distribution¶
Poincaré Uniform¶
hyperbolix.distributions.uniform_poincare ¶
Riemannian-uniform distribution on a geodesic ball in the Poincaré model.
Samples points uniformly (w.r.t. the Riemannian volume element) within a geodesic ball B(center, R) of finite radius R. Uses geodesic polar coordinates: sample a direction on S^{n-1}, sample a radius from the hyperbolic radial density, form a tangent vector, and map to the ball.
The radial density is p(r) ∝ sinh^{n-1}(√c·r) on [0, R]. A substitution u = cosh(√c·r) - 1 simplifies sampling:
- n = 2: u is uniform on [0, cosh(√c·R) - 1] (closed form)
- n ≥ 3: rejection sampling with acceptance ∝ (u·(u+2))^{(n-2)/2}
Dimension key:
- S: sample dimensions (from sample_shape)
- D: spatial/manifold dimension (n)
- Q: quadrature points (64 for Gauss-Legendre)
- T: total flattened samples (for rejection sampling)
volume ¶
Riemannian volume of a geodesic ball B^n_c(R) in n-dim hyperbolic space.
Vol = ω_{n-1} / c^{(n-1)/2} · ∫₀ᴿ sinh^{n-1}(√c·r) dr
where ω_{n-1} is the surface area of the unit (n-1)-sphere.
Computed via 64-point Gauss-Legendre quadrature.
Args:
- c: Positive curvature parameter.
- n: Ambient dimension of the Poincaré ball.
- R: Geodesic radius of the ball.
Returns: Scalar volume.
Source code in hyperbolix/distributions/uniform_poincare.py
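As a sanity check, the integral above can be reproduced with plain NumPy Gauss-Legendre quadrature and compared against the 2D closed form 2π(cosh(√c·R) - 1)/c. This is a standalone sketch, not the library routine:

```python
import numpy as np
from math import gamma, pi

def ball_volume(c, n, R, order=64):
    """Riemannian volume of B^n_c(R) via Gauss-Legendre quadrature."""
    # omega_{n-1}: surface area of the unit (n-1)-sphere
    omega = 2.0 * pi ** (n / 2) / gamma(n / 2)
    nodes, weights = np.polynomial.legendre.leggauss(order)
    r = 0.5 * R * (nodes + 1.0)                    # map [-1, 1] -> [0, R]
    integrand = np.sinh(np.sqrt(c) * r) ** (n - 1)
    integral = 0.5 * R * np.sum(weights * integrand)
    return omega / c ** ((n - 1) / 2) * integral

c, R = 1.0, 1.0
closed_form = 2.0 * np.pi * (np.cosh(np.sqrt(c) * R) - 1.0) / c  # n = 2 closed form
assert abs(ball_volume(c, 2, R) - closed_form) < 1e-10
```

For n = 2 the quadrature agrees with the closed form to machine precision (about 3.412 at c = 1, R = 1); higher n has no elementary closed form, which is why quadrature is used.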
sample ¶
sample(
    key: PRNGKeyArray,
    n: int,
    c: float,
    R: float,
    sample_shape: tuple[int, ...] = (),
    center: Float[Array, "n"] | None = None,
    dtype=None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "... n"]
Sample uniformly from a geodesic ball in the Poincaré model.
Draws points that are Riemannian-uniform within B(center, R).
Algorithm:
1. Sample direction u ~ Uniform(S^{n-1})
2. Sample geodesic radius r ~ p(r) ∝ sinh^{n-1}(√c·r) on [0, R]
3. Form tangent vector t = (r/2)·u (the /2 accounts for λ(0)=2)
4. Map to ball: x₀ = expmap_0(t, c)
5. Move to center: x = center ⊕ x₀ (Möbius addition)
Args:
- key: JAX PRNG key.
- n: Dimension of the Poincaré ball.
- c: Positive curvature parameter.
- R: Geodesic radius of the ball.
- sample_shape: Batch shape of samples. Default: () → single sample.
- center: Center of the geodesic ball, shape (n,). Default: origin.
- dtype: Output dtype. Default: float64.
- manifold_module: Optional Manifold instance. Default: Poincare(dtype).
Returns:
Samples on the Poincaré ball, shape sample_shape + (n,).
Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import uniform_poincare
>>>
>>> key = jax.random.PRNGKey(0)
>>> x = uniform_poincare.sample(key, n=2, c=1.0, R=1.0, sample_shape=(100,))
>>> x.shape
(100, 2)
Source code in hyperbolix/distributions/uniform_poincare.py
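Putting the algorithm together for n = 2, here is a self-contained NumPy sketch that mirrors the steps above (step 5 is omitted since the center is the origin; the library's JAX implementation differs in details):

```python
import numpy as np

rng = np.random.default_rng(0)
c, R, n_samples = 1.0, 1.0, 20000

# Step 1: uniform direction on S^1
theta = rng.uniform(0.0, 2.0 * np.pi, n_samples)
u_dir = np.stack([np.cos(theta), np.sin(theta)], axis=1)

# Step 2: radial sample via the substitution u = cosh(sqrt(c)*r) - 1 (uniform for n = 2)
u = rng.uniform(0.0, np.cosh(np.sqrt(c) * R) - 1.0, n_samples)
r = np.arccosh(1.0 + u) / np.sqrt(c)

# Step 3: tangent vector (the /2 accounts for the conformal factor lambda(0) = 2)
t = 0.5 * r[:, None] * u_dir

# Step 4: expmap_0 maps the tangent vector onto the ball
tn = np.maximum(np.linalg.norm(t, axis=1, keepdims=True), 1e-15)
x = np.tanh(np.sqrt(c) * tn) * t / (np.sqrt(c) * tn)

# Consistency check: geodesic distance from the origin recovers r,
# since d(0, x) = (2/sqrt(c)) * artanh(sqrt(c) * ||x||)
d = 2.0 / np.sqrt(c) * np.arctanh(np.sqrt(c) * np.linalg.norm(x, axis=1))
assert np.allclose(d, r, atol=1e-8)
```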
log_prob ¶
log_prob(
    x: Float[Array, "... n"],
    c: float,
    R: float,
    center: Float[Array, "n"] | None = None,
    manifold_module: Manifold | None = None,
) -> Float[Array, "..."]
Log-probability of the Riemannian-uniform distribution on B(center, R).
Returns -log Vol(B^n_c(R)) for points inside the geodesic ball, -∞ outside.
Args:
- x: Point(s) on the Poincaré ball, shape (..., n).
- c: Positive curvature parameter.
- R: Geodesic radius of the ball.
- center: Center of the geodesic ball, shape (n,). Default: origin.
- manifold_module: Optional Manifold instance. Default: Poincare(dtype).
Returns: Log-probability, shape (...).
Source code in hyperbolix/distributions/uniform_poincare.py
Samples uniformly with respect to the Riemannian volume measure within a geodesic ball \(B(\text{center}, R)\) on the Poincaré ball. Uses geodesic polar decomposition: direction from \(S^{n-1}\) (Muller method), radial component from \(p(r) \propto \sinh^{n-1}(\sqrt{c}\,r)\).
- For \(n=2\): closed-form radial sampling via \(u = \cosh(\sqrt{c}\,r) - 1\)
- For \(n \geq 3\): rejection sampling with `jax.lax.while_loop` (JIT-compatible)
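For intuition, the n ≥ 3 radial step can be sketched with plain NumPy rejection sampling using the same acceptance weight (u·(u+2))^{(n-2)/2} (the library instead uses `jax.lax.while_loop`):

```python
import numpy as np

rng = np.random.default_rng(0)
n, c, R = 3, 1.0, 1.0
u_max = np.cosh(np.sqrt(c) * R) - 1.0
# Envelope constant: the acceptance weight is increasing, so its max is at u_max
m = (u_max * (u_max + 2.0)) ** ((n - 2) / 2)

accepted = []
while len(accepted) < 5000:
    u = rng.uniform(0.0, u_max, size=1000)
    keep = rng.uniform(0.0, m, size=1000) < (u * (u + 2.0)) ** ((n - 2) / 2)
    accepted.extend(u[keep])

r = np.arccosh(1.0 + np.asarray(accepted[:5000])) / np.sqrt(c)

# Empirical check against the radial CDF for n = 3:
# P(r <= 1/2) = I(1/2)/I(1) with I(a) = (sinh(a)cosh(a) - a)/2, about 0.108
assert abs(np.mean(r <= 0.5) - 0.108) < 0.02
```

The change of variables back to r at the end is exactly the u = cosh(√c·r) - 1 substitution run in reverse.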
Usage:
from hyperbolix.distributions import uniform_poincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp
poincare = Poincare()
key = jax.random.PRNGKey(42)
# Sample 100 points uniformly from geodesic ball B(origin, R=1.0) in 2D
samples = uniform_poincare.sample(key, n=2, c=1.0, R=1.0, sample_shape=(100,))
print(samples.shape) # (100, 2)
# All samples inside the geodesic ball
is_valid = jax.vmap(poincare.is_in_manifold, in_axes=(0, None))(samples, 1.0)
print(is_valid.all()) # True
# Volume of geodesic ball (2D closed form: 2π(cosh(√c·R) - 1)/c)
vol = uniform_poincare.volume(c=1.0, n=2, R=1.0)
print(vol)  # ~3.412
# Log probability (constant inside ball, -inf outside)
log_p = jax.vmap(lambda x: uniform_poincare.log_prob(x, c=1.0, R=1.0))(samples)
print(jnp.allclose(log_p, log_p[0])) # True — uniform
Use cases: Uniform priors for hyperbolic VAEs, test/validation sampling, hyperparameter search in hyperbolic spaces.
Wrapped Normal Distribution¶
The wrapped normal distribution extends the Gaussian distribution to hyperbolic manifolds by wrapping Euclidean Gaussians via the exponential map.
Poincaré Wrapped Normal¶
hyperbolix.distributions.wrapped_normal_poincare ¶
Wrapped normal distribution on Poincaré ball.
Simpler implementation than the hyperboloid version: sampling needs no parallel transport, only the exponential map and Möbius addition.
Dimension key:
- S: sample dimensions (from sample_shape)
- B: batch dimensions (from mu batch shape)
- D: spatial/manifold dimension (n)
References: Mathieu et al. "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders" NeurIPS 2019. https://arxiv.org/abs/1901.06033
sample ¶
sample(
key: PRNGKeyArray,
mu: Float[Array, "... n"],
sigma: Float[Array, ...] | float,
c: float,
sample_shape: tuple[int, ...] = (),
dtype=None,
manifold_module: Manifold | None = None,
) -> Float[Array, ...]
Sample from wrapped normal distribution on Poincaré ball.
Simpler than hyperboloid version - no parallel transport needed!
Algorithm:
1. Sample v ~ N(0, Σ) ∈ R^n (directly in tangent space, no embedding)
2. Map to ball at origin: z_0 = exp_0(v)
3. Move to mean: z = μ ⊕ z_0 (Möbius addition)
Args:
- key: JAX random key.
- mu: Mean point on Poincaré ball, shape (..., n).
- sigma: Covariance parameterization. Can be:
    - Scalar: isotropic covariance sigma^2 I (n x n)
    - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
    - 2D array (n, n): full covariance matrix (must be SPD)
- c: Curvature (positive scalar).
- sample_shape: Shape of samples to draw, prepended to output. Default: ().
- dtype: Output dtype. Default: infer from mu.
Returns: Samples from wrapped normal distribution, shape sample_shape + mu.shape
Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_poincare
>>>
>>> # Single sample with isotropic covariance
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([0.0, 0.0])  # Origin in Poincaré ball
>>> sigma = 0.1  # Isotropic
>>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
>>> z.shape
(2,)
>>>
>>> # Multiple samples with diagonal covariance
>>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
>>> z = wrapped_normal_poincare.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
>>> z.shape
(5, 2)
>>>
>>> # Batch of means
>>> mu_batch = jnp.array([[0.0, 0.0], [0.1, 0.1]])  # (2, 2)
>>> z = wrapped_normal_poincare.sample(key, mu_batch, 0.1, c=1.0)
>>> z.shape
(2, 2)
Source code in hyperbolix/distributions/wrapped_normal_poincare.py
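The three sampling steps fit in a short standalone NumPy sketch. The formulas for `expmap0` and Möbius addition below follow the standard c-ball conventions and are written here for illustration, not taken from the library source:

```python
import numpy as np

def expmap0(v, c):
    # exp_0^c(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)
    norm = np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), 1e-15)
    return np.tanh(np.sqrt(c) * norm) * v / (np.sqrt(c) * norm)

def mobius_add(x, y, c):
    # Möbius addition x ⊕ y on the c-ball
    xy = np.sum(x * y, axis=-1, keepdims=True)
    x2 = np.sum(x * x, axis=-1, keepdims=True)
    y2 = np.sum(y * y, axis=-1, keepdims=True)
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    den = 1 + 2 * c * xy + c**2 * x2 * y2
    return num / den

rng = np.random.default_rng(0)
mu = np.array([0.2, 0.3])
sigma, c = 0.1, 1.0
v = sigma * rng.standard_normal((1000, 2))   # step 1: tangent sample at origin
z = mobius_add(mu, expmap0(v, c), c)         # steps 2-3: wrap, then translate to mu
assert np.all(np.linalg.norm(z, axis=-1) < 1.0 / np.sqrt(c))
```

All samples land strictly inside the ball of radius 1/√c, since both the exponential map and Möbius addition preserve it.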
log_prob ¶
log_prob(
    z: Float[Array, "... n"],
    mu: Float[Array, "... n"],
    sigma: Float[Array, "..."] | float,
    c: float,
    manifold_module: Manifold | None = None,
) -> Float[Array, "..."]
Compute log probability of wrapped normal distribution on Poincaré ball.
Implements Algorithm 2 from the paper, adapted for the Poincaré ball:
1. Map z to u = log_μ(z) ∈ T_μB^n (logarithmic map)
2. Move u to v = PT_{μ→0}(u) ∈ T_0B^n (parallel transport to origin)
3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)
Args:
- z: Sample point(s) on Poincaré ball, shape (..., n).
- mu: Mean point on Poincaré ball, shape (..., n).
- sigma: Covariance parameterization. Can be:
    - Scalar: isotropic covariance sigma^2 I (n x n)
    - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
    - 2D array (n, n): full covariance matrix (must be SPD)
- c: Curvature (positive scalar).
Returns: Log probability, shape (...) (spatial dimension removed)
Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_poincare
>>>
>>> # Compute log probability of samples
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([0.0, 0.0])
>>> sigma = 0.1
>>> z = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0)
>>> log_p = wrapped_normal_poincare.log_prob(z, mu, sigma, c=1.0)
>>> log_p.shape
()
>>>
>>> # Batch computation
>>> z_batch = wrapped_normal_poincare.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
>>> log_p_batch = wrapped_normal_poincare.log_prob(z_batch, mu, sigma, c=1.0)
>>> log_p_batch.shape
(10,)
Source code in hyperbolix/distributions/wrapped_normal_poincare.py
Hyperboloid Wrapped Normal¶
hyperbolix.distributions.wrapped_normal_hyperboloid ¶
Wrapped normal distribution on hyperboloid manifold.
Implementation of the wrapped normal distribution that wraps a Gaussian from the tangent space at the origin onto the hyperboloid via parallel transport and exponential map.
Dimension key:
- S: sample dimensions (from sample_shape)
- B: batch dimensions (from mu batch shape)
- D: spatial dimension (n)
- A: ambient dimension (n+1, hyperboloid time+space)
References: Mathieu et al. "Continuous Hierarchical Representations with Poincaré Variational Auto-Encoders" NeurIPS 2019. https://arxiv.org/abs/1901.06033
sample ¶
sample(
    key: PRNGKeyArray,
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, "..."] | float,
    c: float,
    sample_shape: tuple[int, ...] = (),
    dtype=None,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, "..."]
Sample from wrapped normal distribution on hyperboloid.
Implements Algorithm 1 from the paper:
1. Sample v_bar ~ N(0, Σ) ∈ R^n
2. Embed as tangent vector v = [0, v_bar] ∈ T_{μ₀}ℍⁿ at origin
3. Parallel transport to mean: u = PT_{μ₀→μ}(v) ∈ T_μℍⁿ
4. Map to manifold: z = exp_μ(u) ∈ ℍⁿ
Args:
- key: JAX random key.
- mu: Mean point on hyperboloid, shape (..., n+1) in ambient coordinates.
- sigma: Covariance parameterization in spatial coordinates. Can be:
    - Scalar: isotropic covariance sigma^2 I (n x n)
    - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
    - 2D array (n, n): full covariance matrix (must be SPD)
- c: Curvature (positive scalar).
- sample_shape: Shape of samples to draw, prepended to output. Default: ().
- dtype: Output dtype. Default: infer from mu.
Returns: Samples from wrapped normal distribution, shape sample_shape + mu.shape
Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_hyperboloid
>>>
>>> # Single sample with isotropic covariance
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([1.0, 0.0, 0.0])  # Origin in H^2
>>> sigma = 0.1  # Isotropic
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
>>> z.shape
(3,)
>>>
>>> # Multiple samples with diagonal covariance
>>> sigma_diag = jnp.array([0.1, 0.2])  # Diagonal
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma_diag, c=1.0, sample_shape=(5,))
>>> z.shape
(5, 3)
>>>
>>> # Batch of means
>>> mu_batch = jnp.array([[1.0, 0.0, 0.0], [1.0, 0.1, 0.1]])  # (2, 3)
>>> z = wrapped_normal_hyperboloid.sample(key, mu_batch, 0.1, c=1.0)
>>> z.shape
(2, 3)
Source code in hyperbolix/distributions/wrapped_normal_hyperboloid.py
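The four steps can be checked end to end with a plain NumPy sketch for c = 1, where the origin is μ₀ = (1, 0, …, 0). The parallel-transport and exponential-map formulas below are the standard hyperboloid ones, written here for illustration rather than taken from the library:

```python
import numpy as np

def minkowski(x, y):
    # Lorentzian inner product <x, y>_L = -x0*y0 + <x_space, y_space>
    return -x[..., 0] * y[..., 0] + np.sum(x[..., 1:] * y[..., 1:], axis=-1)

def transport_from_origin(mu, v):
    # Parallel transport T_{mu0} -> T_mu for c = 1, mu0 = (1, 0, ..., 0).
    # For v = [0, v_bar] we have <mu0, v>_L = 0, so the general formula reduces to this.
    mu0 = np.zeros_like(mu)
    mu0[0] = 1.0
    alpha = -minkowski(mu0, mu)
    coef = minkowski(mu, v) / (alpha + 1.0)
    return v + coef[..., None] * (mu0 + mu)

def expmap(mu, u):
    # exp_mu(u) = cosh(||u||_L) mu + sinh(||u||_L) u / ||u||_L
    un = np.sqrt(np.maximum(minkowski(u, u), 1e-30))[..., None]
    return np.cosh(un) * mu + np.sinh(un) * u / un

rng = np.random.default_rng(0)
mu = np.array([np.sqrt(1.0 + 0.2**2 + 0.3**2), 0.2, 0.3])   # a point on H^2
v_bar = 0.1 * rng.standard_normal((500, 2))                  # step 1
v = np.concatenate([np.zeros((500, 1)), v_bar], axis=1)      # step 2
u = transport_from_origin(mu, v)                             # step 3
z = expmap(mu, u)                                            # step 4
# Every sample satisfies the hyperboloid constraint <z, z>_L = -1
assert np.allclose(minkowski(z, z), -1.0, atol=1e-8)
```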
log_prob ¶
log_prob(
    z: Float[Array, "... n_plus_1"],
    mu: Float[Array, "... n_plus_1"],
    sigma: Float[Array, "..."] | float,
    c: float,
    manifold_module: Hyperboloid | None = None,
) -> Float[Array, "..."]
Compute log probability of wrapped normal distribution.
Implements Algorithm 2 from the paper:
1. Map z to u = exp_μ⁻¹(z) ∈ T_μℍⁿ (logarithmic map)
2. Move u to v = PT_{μ→μ₀}(u) ∈ T_{μ₀}ℍⁿ (parallel transport to origin)
3. Calculate log p(z) = log p(v) - log det(∂proj_μ(v)/∂v)
Args:
- z: Sample point(s) on hyperboloid, shape (..., n+1).
- mu: Mean point on hyperboloid, shape (..., n+1).
- sigma: Covariance parameterization. Can be:
    - Scalar: isotropic covariance sigma^2 I (n x n)
    - 1D array of length n: diagonal covariance diag(sigma_1^2, ..., sigma_n^2)
    - 2D array (n, n): full covariance matrix (must be SPD)
- c: Curvature (positive scalar).
Returns: Log probability, shape (...) (manifold dimension removed)
Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.distributions import wrapped_normal_hyperboloid
>>>
>>> # Compute log probability of samples
>>> key = jax.random.PRNGKey(0)
>>> mu = jnp.array([1.0, 0.0, 0.0])
>>> sigma = 0.1
>>> z = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0)
>>> log_p = wrapped_normal_hyperboloid.log_prob(z, mu, sigma, c=1.0)
>>> log_p.shape
()
>>>
>>> # Batch computation
>>> z_batch = wrapped_normal_hyperboloid.sample(key, mu, sigma, c=1.0, sample_shape=(10,))
>>> log_p_batch = wrapped_normal_hyperboloid.log_prob(z_batch, mu, sigma, c=1.0)
>>> log_p_batch.shape
(10,)
Source code in hyperbolix/distributions/wrapped_normal_hyperboloid.py
Usage Examples¶
Basic Sampling (Poincaré)¶
from hyperbolix.distributions import wrapped_normal_poincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp
poincare = Poincare()
# Mean on Poincaré ball
mean = jnp.array([0.2, 0.3])
mean_proj = poincare.proj(mean, c=1.0)
# Standard deviation
std = 0.1
# Sample
key = jax.random.PRNGKey(42)
samples = wrapped_normal_poincare.sample(
    key, mean_proj, std, c=1.0, sample_shape=(100,), manifold_module=poincare
)
print(samples.shape) # (100, 2)
# Samples lie inside the Poincaré ball of radius 1/√c
norms = jnp.linalg.norm(samples, axis=-1)
print(jnp.all(norms < 1.0 / jnp.sqrt(1.0)))  # True for c = 1.0
Log Probability¶
# Compute log probability of samples
log_probs = jax.vmap(
    lambda x: wrapped_normal_poincare.log_prob(x, mean_proj, std, c=1.0, manifold_module=poincare)
)(samples)
print(log_probs.shape) # (100,)
# Higher probability near mean
point_near_mean = poincare.proj(jnp.array([0.21, 0.29]), c=1.0)
point_far = poincare.proj(jnp.array([0.7, 0.7]), c=1.0)
print(f"Log prob (near): {wrapped_normal_poincare.log_prob(point_near_mean, mean_proj, std, c=1.0, manifold_module=poincare):.4f}")
print(f"Log prob (far): {wrapped_normal_poincare.log_prob(point_far, mean_proj, std, c=1.0, manifold_module=poincare):.4f}")
Hyperboloid Distribution¶
from hyperbolix.distributions import wrapped_normal_hyperboloid
from hyperbolix.manifolds import Hyperboloid
import jax
import jax.numpy as jnp
hyperboloid = Hyperboloid()
# Mean on hyperboloid (ambient coordinates)
mean_space = jnp.array([0.2, 0.3, -0.1])
mean_ambient = jnp.concatenate([
    jnp.array([jnp.sqrt(jnp.sum(mean_space**2) + 1.0)]),
    mean_space
])
# Sample
key = jax.random.PRNGKey(123)
samples = wrapped_normal_hyperboloid.sample(
    key, mean_ambient, 0.15, c=1.0, sample_shape=(50,), manifold_module=hyperboloid
)
# Compute log probabilities
log_probs = jax.vmap(
    lambda x: wrapped_normal_hyperboloid.log_prob(x, mean_ambient, 0.15, c=1.0, manifold_module=hyperboloid)
)(samples)
VAE Example¶
Using wrapped normal distributions in a Variational Autoencoder:
from flax import nnx
from hyperbolix.distributions import wrapped_normal_poincare
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare
import jax
import jax.numpy as jnp
poincare = Poincare()
class HyperbolicVAE(nnx.Module):
    def __init__(self, latent_dim, rngs):
        self.latent_dim = latent_dim
        # Encoder: Euclidean → Hyperbolic
        self.encoder = nnx.Linear(784, 128, rngs=rngs)
        self.enc_hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=128,
            out_dim=latent_dim,
            rngs=rngs
        )
        # Std head (Euclidean), created once here rather than on every encode call
        self.log_std_layer = nnx.Linear(128, latent_dim, rngs=rngs)
        # Decoder: Hyperbolic → Euclidean
        self.dec_hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=latent_dim,
            out_dim=128,
            rngs=rngs
        )
        self.decoder = nnx.Linear(128, 784, rngs=rngs)

    def encode(self, x, c):
        # Returns mean and std for the latent distribution
        h = jax.nn.relu(self.encoder(x))
        # Project to Poincaré ball
        h_proj = jax.vmap(poincare.proj, in_axes=(0, None))(h, c)
        # Mean on Poincaré ball
        mean = self.enc_hyp(h_proj, c)
        # Std in tangent space (Euclidean)
        log_std = self.log_std_layer(h)
        std = jnp.exp(log_std)
        return mean, std

    def decode(self, z, c):
        h = self.dec_hyp(z, c)
        # Logmap to tangent space for the Euclidean decoder
        # (base point must match h's dimension, the decoder's hidden size)
        h_tangent = jax.vmap(poincare.logmap, in_axes=(None, 0, None))(
            jnp.zeros(h.shape[-1]), h, c
        )
        return jax.nn.sigmoid(self.decoder(h_tangent))

    def __call__(self, x, key, c=1.0):
        # Encode
        mean, std = self.encode(x, c)
        # Sample latent code
        keys = jax.random.split(key, mean.shape[0])
        z = jax.vmap(
            lambda k, m, s: wrapped_normal_poincare.sample(k, m, s, c, (), manifold_module=poincare)
        )(keys, mean, std)
        # Decode
        recon = self.decode(z, c)
        return recon, mean, std, z

# Loss function
def vae_loss(model, x, key, c):
    recon, mean, std, z = model(x, key, c)
    # Reconstruction loss
    recon_loss = jnp.mean((x - recon) ** 2)
    # KL divergence (approximate for wrapped normal):
    # standard Gaussian prior in the tangent space at the origin
    kl_loss = -0.5 * jnp.mean(
        1 + 2 * jnp.log(std) - jnp.sum(mean**2, axis=-1) - std**2
    )
    return recon_loss + kl_loss
Mathematical Background¶
Wrapped Normal Definition¶
Given a mean \(\mu \in \mathcal{M}\) on manifold \(\mathcal{M}\) and standard deviation \(\sigma\), the wrapped normal distribution is defined as:
- Sample \(v \sim \mathcal{N}(0, \sigma^2 I)\) in tangent space \(T_\mu \mathcal{M}\)
- Wrap to manifold: \(x = \exp_\mu(v)\)
The log probability follows from the change of variables induced by \(\exp_\mu\):

\[\log p(x) = \log \mathcal{N}\big(\log_\mu(x);\, 0, \sigma^2 I\big) - \log \det\left(\frac{\partial \exp_\mu(v)}{\partial v}\right)\Bigg|_{v=\log_\mu(x)}\]

where \(\log_\mu\) is the logarithmic map at \(\mu\).
Sampling Algorithm¶
# Conceptual implementation (simplified: wraps directly at the mean;
# the library implementations sample at the origin and then move the
# result to the mean via parallel transport / Möbius addition)
def sample_concept(key, mean, std, c, sample_shape, manifold):
    # 1. Sample in tangent space at the mean
    tangent_sample = std * jax.random.normal(key, sample_shape + mean.shape)
    # 2. Exponential map to the manifold
    manifold_sample = manifold.expmap(tangent_sample, mean, c)
    return manifold_sample
Numerical Considerations¶
Numerical Stability
For small standard deviations and/or high curvatures, the exponential map can become numerically unstable. Consider:
- Using float64 for very small \(\sigma\) (< 0.01)
- Clipping standard deviations to reasonable range: \(\sigma \in [0.01, 1.0]\)
- Using version parameter in manifold operations for better stability
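The clipping suggestion can be applied as a tiny guard before sampling. This is a hypothetical helper, not a library function, and the [0.01, 1.0] range is the heuristic above rather than a library constant:

```python
import numpy as np

def safe_sigma(sigma, lo=1e-2, hi=1.0):
    """Clamp sigma into a numerically safe range before sample()/log_prob()."""
    return np.clip(np.asarray(sigma), lo, hi)

print(safe_sigma(1e-5))        # 0.01
print(safe_sigma([0.5, 3.0]))  # [0.5 1. ]
```

The same one-liner works with `jnp.clip` inside a jitted training step.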
Curvature Choice
The curvature parameter \(c\) affects the distribution:
- Higher \(c\) → More concentrated distributions
- Lower \(c\) → More spread out distributions
Tune \(c\) based on your application's needs.
References¶
Wrapped distributions on manifolds are discussed in:
- Nagano, Y., et al. (2019). "A Wrapped Normal Distribution on Hyperbolic Space for Gradient-Based Learning"
- Davidson, T., et al. (2018). "Hyperspherical Variational Auto-Encoders"
See also:
- Manifolds API: Exponential and logarithmic maps
- NN Layers API: Building VAEs with hyperbolic layers