Neural Network Layers API¶
Hyperbolic neural network layers built with Flax NNX.
Overview¶
Hyperbolix provides 20+ neural network layer classes and 5 activation functions for building hyperbolic deep learning models:
- Linear Layers: Poincaré and Hyperboloid linear transformations, including FGG (Fast and Geometrically Grounded) layers
- Convolutional Layers: HCat-based, HRC-based, and FGG hyperbolic convolutions
- Normalization: Poincaré batch normalization (`PoincareBatchNorm2D`), HRC-wrapped norms, and FGG mean-only batch norm
- Hypformer Components: HTC (Hyperbolic Transformation Component) and HRC (Hyperbolic Regularization Component) with curvature-change support
- FGG Components: `FGGLinear`, `FGGConv2D`, `FGGMeanOnlyBatchNorm` from Klis et al. (2026) — linear-distance growth, ~3× faster than prior work
- Attention Layers: Three hyperbolic attention variants (linear O(N), softmax O(N²), full Lorentzian O(N²)) from the Hypformer paper
- Positional Encoding: HOPE (Hyperbolic Rotary PE) and Hypformer learnable positional encodings for Transformers
- Regression Layers: Single-layer classifiers with Riemannian geometry, including `FGGLorentzMLR`
- Activation Functions: Hyperbolic ReLU, Leaky ReLU, Tanh, Swish, GELU
- Helper Functions: Utilities for regression and conformal factor computation
All layers follow Flax NNX conventions and store manifold module references.
Linear Layers¶
Poincaré Linear¶
hyperbolix.nn_layers.HypLinearPoincare ¶
HypLinearPoincare(
manifold_module: Manifold,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
)
Bases: Module
Hyperbolic Neural Networks fully connected layer (Poincaré ball model).
Computation steps:
0. Project the input tensor to the tangent space (optional)
1. Perform matrix-vector multiplication in the tangent space at the origin
2. Map the result to the manifold
3. Add the manifold bias to the result
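The steps above can be sketched with the standard Poincaré expmap_0/logmap_0 formulas (the Möbius bias addition of step 3 is omitted; the helper names here are illustrative, not the library's API):

```python
import jax
import jax.numpy as jnp

def expmap0(v, c=1.0):
    # exponential map at the origin: tangent vector -> Poincaré ball point
    norm = jnp.clip(jnp.linalg.norm(v, axis=-1, keepdims=True), 1e-12)
    return jnp.tanh(jnp.sqrt(c) * norm) * v / (jnp.sqrt(c) * norm)

def logmap0(p, c=1.0):
    # logarithmic map at the origin: ball point -> tangent vector
    norm = jnp.clip(jnp.linalg.norm(p, axis=-1, keepdims=True), 1e-12)
    return jnp.arctanh(jnp.sqrt(c) * norm) * p / (jnp.sqrt(c) * norm)

x = expmap0(0.3 * jax.random.normal(jax.random.PRNGKey(0), (8, 32)))  # valid ball points
W = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (32, 16))

v = logmap0(x)   # step 0: project to tangent space at the origin
v = v @ W        # step 1: matrix-vector multiplication in tangent space
y = expmap0(v)   # step 2: map back to the manifold (step 3, the bias, omitted)
```

Because `expmap0` applies `tanh` to the norm, the outputs stay strictly inside the unit ball for c = 1.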
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Poincare manifold instance | *required* |
| `in_dim` | `int` | Dimension of the input space | *required* |
| `out_dim` | `int` | Dimension of the output space | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space) are treated as static and will be baked into the compiled function. Changing these values after JIT compilation will trigger automatic recompilation.
References
Octavian Ganea, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in Neural Information Processing Systems 31 (2018).
Source code in hyperbolix/nn_layers/poincare_linear.py
__call__ ¶
Forward pass through the hyperbolic linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, in_dim)` | Input tensor where the hyperbolic_axis is last | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `res` | `Array` of shape `(batch, out_dim)` | Output on the Poincaré ball manifold |
Source code in hyperbolix/nn_layers/poincare_linear.py
hyperbolix.nn_layers.HypLinearPoincarePP ¶
HypLinearPoincarePP(
manifold_module: Poincare,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic Neural Networks ++ fully connected layer (Poincaré ball model).
Computation steps:
0. Project the input tensor onto the manifold (optional)
1. Compute the multinomial linear regression score(s)
2. Calculate the generalized linear transformation from the regression score(s)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Poincare manifold instance | *required* |
| `in_dim` | `int` | Dimension of the input space | *required* |
| `out_dim` | `int` | Dimension of the output space | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output (default: 1.0) | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output (default: 50.0) | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.
References
Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).
Source code in hyperbolix/nn_layers/poincare_linear.py
__call__ ¶
Forward pass through the HNN++ hyperbolic linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, in_dim)` | Input tensor where the hyperbolic_axis is last | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `res` | `Array` of shape `(batch, out_dim)` | Output on the Poincaré ball manifold |
Source code in hyperbolix/nn_layers/poincare_linear.py
Hyperboloid Linear¶
hyperbolix.nn_layers.HypLinearHyperboloidFHCNN ¶
HypLinearHyperboloidFHCNN(
manifold_module: Manifold,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
init_scale: float = 2.3,
learnable_scale: bool = False,
eps: float = 1e-05,
activation: Callable[[Array], Array] | None = None,
normalize: bool = False,
)
Bases: Module
Fully Hyperbolic Convolutional Neural Networks fully connected layer (Hyperboloid model).
Computation steps:
0. Project the input tensor to the manifold (optional)
1. Apply activation (optional)
2. Either:
   a) If normalize is True, compute the time and space coordinates of the output by applying a scaled sigmoid to the weight- and bias-transformed coordinates of the input (or of the previous step's result).
   b) If normalize is False, compute the weight- and bias-transformed space coordinates of the input (or of the previous step's result) and set the time coordinate so that the result lies on the manifold.
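For intuition, the normalize=False branch (step 2b) can be sketched as follows; `W`, `b`, and the input are random stand-ins, not the layer's actual parameters:

```python
import jax
import jax.numpy as jnp

c = 1.0
x = jax.random.normal(jax.random.PRNGKey(0), (8, 9))        # stand-in input features
W = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (9, 5))
b = jnp.zeros(5)

space = x @ W + b                                           # transformed space coordinates
time = jnp.sqrt(jnp.sum(space**2, axis=-1) + 1.0 / c)       # time chosen so the point lies on the manifold
y = jnp.concatenate([time[:, None], space], axis=-1)

# Lorentz constraint satisfied by construction: -t^2 + ||s||^2 == -1/c
constraint = -y[:, 0] ** 2 + jnp.sum(y[:, 1:] ** 2, axis=-1)
```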
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Hyperboloid manifold instance | *required* |
| `in_dim` | `int` | Dimension of the input space | *required* |
| `out_dim` | `int` | Dimension of the output space | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `init_scale` | `float` | Initial value for the sigmoid scale parameter (default: 2.3) | `2.3` |
| `learnable_scale` | `bool` | Whether the scale parameter should be learnable (default: False) | `False` |
| `eps` | `float` | Small value to ensure that the time coordinate is bigger than 1/sqrt(c) (default: 1e-5) | `1e-05` |
| `activation` | `callable` or `None` | Activation function to apply before the linear transformation (default: None). Note: This is a static configuration - changing it after initialization requires recompilation. | `None` |
| `normalize` | `bool` | Whether to normalize the space coordinates before rescaling (default: False). Note: This is a static configuration - changing it after initialization requires recompilation. | `False` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, activation, normalize) are treated as static and will be baked into the compiled function.
Relationship to HTC/HRC:
When normalize=False and c_in = c_out, this layer uses the same time reconstruction
pattern as htc: time = sqrt(||space||^2 + 1/c). The key difference is that FHCNN
applies a linear transform to the full input and discards the computed time, while htc
uses the linear output directly as spatial components. When normalize=True, FHCNN uses
a learned sigmoid scaling which differs from both htc and hrc.
See Also
htc : Hyperbolic Transformation Component with curvature change support. Similar time reconstruction pattern when normalize=False.
HTCLinear : Module wrapper for htc with learnable linear transformation.
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
__call__ ¶
Forward pass through the FHCNN hyperbolic linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, in_dim)` | Input tensor where the hyperbolic_axis is last. x.shape[-1] must equal self.in_dim. | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `res` | `Array` of shape `(batch, out_dim)` | Output on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
hyperbolix.nn_layers.HypLinearHyperboloidFHNN ¶
HypLinearHyperboloidFHNN(
manifold_module: Manifold,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
init_scale: float = 2.3,
eps: float = 1e-05,
activation: Callable[[Array], Array] | None = None,
dropout_rate: float | None = None,
)
Bases: Module
Fully Hyperbolic Neural Networks linear layer (Chen et al. 2021).
Time-primary parameterization: the time coordinate is computed via scaled sigmoid with an additive floor at 1/sqrt(c), and spatial coordinates are rescaled so the output lies on the hyperboloid.
Computation steps:
0. Project the input tensor to the manifold (optional)
1. Apply activation (optional)
2. Apply dropout (optional)
3. Linear transform: z = W @ x + b
4. Time: y0 = exp(scale) * sigmoid(z0) + 1/sqrt(c) + eps
5. Spatial: y_rem = sqrt(y0^2 - 1/c) / ||z_rem|| * z_rem
6. Output: [y0, y_rem] on the hyperboloid
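Steps 4-6 can be sketched numerically (`z` is a stand-in for the linear output W @ x + b, not the layer's actual values):

```python
import jax
import jax.numpy as jnp

c, eps, scale = 1.0, 1e-5, 2.3
z = jnp.array([[0.7, 0.2, -0.4, 1.1]])                    # stand-in for z = W @ x + b
z0, z_rem = z[:, 0], z[:, 1:]

# step 4: time via scaled sigmoid with an additive floor at 1/sqrt(c)
y0 = jnp.exp(scale) * jax.nn.sigmoid(z0) + 1.0 / jnp.sqrt(c) + eps
# step 5: rescale spatial part so the hyperboloid constraint holds
y_rem = (jnp.sqrt(y0**2 - 1.0 / c) / jnp.linalg.norm(z_rem, axis=-1))[:, None] * z_rem
# step 6: assemble the output point
y = jnp.concatenate([y0[:, None], y_rem], axis=-1)

constraint = -y[:, 0] ** 2 + jnp.sum(y[:, 1:] ** 2, axis=-1)  # == -1/c on the hyperboloid
```

By construction, -y0² + ||y_rem||² = -y0² + (y0² - 1/c) = -1/c, so the output lies on the manifold regardless of z.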
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Hyperboloid manifold instance | *required* |
| `in_dim` | `int` | Ambient input dimension (d+1, including time) | *required* |
| `out_dim` | `int` | Ambient output dimension (d+1, including time) | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `init_scale` | `float` | Initial value for the learnable sigmoid scale (default: 2.3) | `2.3` |
| `eps` | `float` | Numerical stability epsilon (default: 1e-5) | `1e-05` |
| `activation` | `callable` or `None` | Activation function to apply before the linear transformation (default: None). Note: This is a static configuration - changing it after initialization requires recompilation. | `None` |
| `dropout_rate` | `float` or `None` | Dropout rate applied before the linear transformation (default: None). | `None` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, activation) are treated as static and will be baked into the compiled function.
Weight Initialization: Weights are initialized as tangent vectors at the hyperboloid origin: U(-0.02, 0.02) with the time column (column 0) zeroed. This matches the Chen et al. 2021 init.
Relationship to FHCNN: Both FHNN and FHCNN ensure outputs lie on the hyperboloid, but differ in which coordinate is primary. FHNN controls the time coordinate via sigmoid with an additive floor at 1/sqrt(c), then derives the spatial norm from the hyperboloid constraint. FHCNN (normalize=True) controls the spatial norm via sigmoid, then derives time.
See Also
HypLinearHyperboloidFHCNN : FHCNN layer with spatial-primary parameterization.
HypLinearHyperboloidPP : HNN++ layer using MLR + sinh diffeomorphism.
References
Weize Chen, et al. "Fully hyperbolic neural networks." arXiv preprint arXiv:2105.14686 (2021).
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
__call__ ¶
__call__(
x: Float[Array, "batch in_dim"],
c: float = 1.0,
deterministic: bool = True,
) -> Float[Array, "batch out_dim"]
Forward pass through the FHNN hyperbolic linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, in_dim)` | Input tensor. x.shape[-1] must equal self.in_dim. | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
| `deterministic` | `bool` | If True, dropout is disabled (default: True). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `res` | `Array` of shape `(batch, out_dim)` | Output on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
hyperbolix.nn_layers.HypLinearHyperboloidPP ¶
HypLinearHyperboloidPP(
manifold_module: Hyperboloid,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic Neural Networks ++ fully connected layer (Hyperboloid model).
Computation steps:
0) Project the input tensor onto the manifold (optional)
1) Compute the multinomial linear regression score(s) via compute_mlr
2) Apply element-wise sinh diffeomorphism to obtain spatial coordinates
3) Reconstruct time coordinate from the hyperboloid constraint
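Steps 2-3 can be sketched as follows; the exact 1/sqrt(c) scaling of the sinh output is an assumption for illustration (the constraint-based time reconstruction holds for any scaling):

```python
import jax.numpy as jnp

c = 1.0
scores = jnp.array([[0.5, -1.2, 0.3]])                    # stand-in MLR scores
space = jnp.sinh(scores) / jnp.sqrt(c)                    # step 2: element-wise sinh diffeomorphism
time = jnp.sqrt(jnp.sum(space**2, axis=-1) + 1.0 / c)     # step 3: time from hyperboloid constraint
y = jnp.concatenate([time[:, None], space], axis=-1)

constraint = -y[:, 0] ** 2 + jnp.sum(y[:, 1:] ** 2, axis=-1)  # == -1/c
```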
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Hyperboloid manifold instance | *required* |
| `in_dim` | `int` | Full input dimension (ambient, d+1) | *required* |
| `out_dim` | `int` | Full output dimension (ambient, d+1) | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output (default: 1.0) | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output (default: 50.0) | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.
References
Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
__call__ ¶
Forward pass through the HNN++ hyperboloid linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, in_dim)` | Input tensor. x.shape[-1] must equal self.in_dim. | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `res` | `Array` of shape `(batch, out_dim)` | Output on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
FGG Linear (Klis et al. 2026)¶
hyperbolix.nn_layers.FGGLinear ¶
FGGLinear(
in_features: int,
out_features: int,
*,
rngs: Rngs,
activation: Callable[[Array], Array] | None = None,
reset_params: str = "eye",
use_weight_norm: bool = False,
init_bias: float = 0.5,
eps: float = 1e-07,
)
Bases: Module
Fast and Geometrically Grounded Lorentz linear layer.
Implements the FGG linear layer from Klis et al. 2026. The key insight is that the sinh/arcsinh cancellation in the Lorentzian activation chain simplifies the forward pass to: matmul with spacelike V matrix -> Euclidean activation -> time reconstruction. This achieves linear growth of hyperbolic distance (vs logarithmic for Chen et al. 2022) and ~3x faster training/inference.
Forward pass:
1. Build spacelike V matrix from (U, b) with Minkowski metric absorbed
2. z = x @ V (Minkowski inner products via a single matmul)
3. z = h(z) (Euclidean activation, e.g. ReLU)
4. y_0 = sqrt(||z||^2 + 1/c) (time reconstruction)
5. y = [y_0, z] (on hyperboloid)
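Steps 2-5 of this forward pass can be sketched as follows; the construction of the spacelike V from (U, b) in step 1 is omitted and a random stand-in matrix is used instead:

```python
import jax
import jax.numpy as jnp

c = 1.0
B, Ai, Ao = 8, 33, 65
# stand-in hyperboloid inputs: random spatial part, time from the constraint
s = jax.random.normal(jax.random.PRNGKey(0), (B, Ai - 1))
x = jnp.concatenate([jnp.sqrt(jnp.sum(s**2, -1, keepdims=True) + 1.0 / c), s], axis=-1)

V = 0.1 * jax.random.normal(jax.random.PRNGKey(1), (Ai, Ao - 1))  # stand-in for the spacelike V
z = jax.nn.relu(x @ V)                                  # steps 2-3: single matmul + Euclidean activation
y0 = jnp.sqrt(jnp.sum(z**2, axis=-1) + 1.0 / c)         # step 4: time reconstruction
y = jnp.concatenate([y0[:, None], z], axis=-1)          # step 5: output on the hyperboloid
```

The entire hyperbolic layer reduces to one matmul, one Euclidean activation, and one norm, which is where the speedup over prior Lorentz layers comes from.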
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input ambient dimension (D_in + 1), including time component. | *required* |
| `out_features` | `int` | Output ambient dimension (D_out + 1), including time component. | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization. | *required* |
| `activation` | `Callable` or `None` | Euclidean activation function applied after matmul (default: None). | `None` |
| `reset_params` | `str` | Weight initialization scheme: | `'eye'` |
| `use_weight_norm` | `bool` | If True, reparameterize U as | `False` |
| `init_bias` | `float` | Initial value for bias entries (default: 0.5). | `0.5` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
References
Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026).
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from hyperbolix.nn_layers import FGGLinear
>>>
>>> layer = FGGLinear(33, 65, rngs=nnx.Rngs(0), activation=jax.nn.relu)
>>> x = jnp.ones((8, 33))
>>> # project to hyperboloid
>>> x = x.at[:, 0].set(jnp.sqrt(jnp.sum(x[:, 1:]**2, axis=-1) + 1.0))
>>> y = layer(x, c=1.0)
>>> y.shape
(8, 65)
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
__call__ ¶
__call__(
x_BAi: Float[Array, "batch in_features"], c: float = 1.0
) -> Float[Array, "batch out_features"]
Forward pass through the FGG linear layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x_BAi` | `Array` of shape `(B, Ai)` | Input points on the hyperboloid with curvature `c` | *required* |
| `c` | `float` | Curvature parameter (default: 1.0). | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `y_BAo` | `Array` of shape `(B, Ao)` | Output points on the hyperboloid with curvature `c` |
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
Usage Example¶
import jax
from flax import nnx
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Create hyperbolic linear layer
layer = HypLinearPoincare(
manifold_module=poincare,
in_dim=32,
out_dim=16,
rngs=nnx.Rngs(0)
)
# Forward pass
x = jax.random.normal(jax.random.PRNGKey(1), (10, 32)) * 0.3
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)
output = layer(x_proj, c=1.0)
print(output.shape) # (10, 16)
Convolutional Layers¶
Hyperboloid Convolutions¶
hyperbolix.nn_layers.HypConv2DHyperboloid ¶
HypConv2DHyperboloid(
manifold_module: Hyperboloid,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
input_space: str = "manifold",
)
Bases: Module
Hyperbolic 2D Convolutional Layer for Hyperboloid model.
This layer implements fully hyperbolic convolution as described in "Fully Hyperbolic Convolutional Neural Networks for Computer Vision".
Computation steps:
1. Extract receptive field (kernel_size x kernel_size) of hyperbolic points
2. Apply HCat (Lorentz direct concatenation) to combine receptive field points
3. Pass through hyperbolic linear layer (LFC)
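The HCat step can be sketched with one common formulation of Lorentz direct concatenation: keep every point's spatial components and reconstruct a single time coordinate from the hyperboloid constraint (the `lift` helper is illustrative, not the library's API):

```python
import jax.numpy as jnp

c = 1.0

def lift(s):
    # place a spatial vector on the hyperboloid by solving for the time coordinate
    t = jnp.sqrt(jnp.sum(s**2, axis=-1, keepdims=True) + 1.0 / c)
    return jnp.concatenate([t, s], axis=-1)

# a 3x3 receptive field of hyperboloid points with d = 2 spatial dims each
pts = [lift(jnp.array([[0.1 * i, -0.2]])) for i in range(9)]
spatial = jnp.concatenate([p[:, 1:] for p in pts], axis=-1)  # drop each point's time coordinate
hcat = lift(spatial)                                         # one shared time coordinate

# ambient dimension grows: (d * N) + 1 = (2 * 9) + 1 = 19
```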
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Hyperboloid manifold instance | *required* |
| `in_channels` | `int` | Number of input channels | *required* |
| `out_channels` | `int` | Number of output channels | *required* |
| `kernel_size` | `int` or `tuple[int, int]` | Size of the convolutional kernel | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `stride` | `int` or `tuple[int, int]` | Stride of the convolution (default: 1) | `1` |
| `padding` | `str` | Padding mode, either 'SAME' or 'VALID' (default: 'SAME') | `'SAME'` |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space) are treated as static and will be baked into the compiled function.
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the hyperbolic convolutional layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, height, width, in_channels)` | Input feature map where each pixel is a point on the Hyperboloid manifold | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `Array` of shape `(batch, out_height, out_width, out_channels)` | Output feature map on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
hyperbolix.nn_layers.HypConv2DHyperboloidFHNN ¶
HypConv2DHyperboloidFHNN(
manifold_module: Hyperboloid,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
input_space: str = "manifold",
init_scale: float = 2.3,
eps: float = 1e-05,
activation: Callable[[Array], Array] | None = None,
dropout_rate: float | None = None,
)
Bases: Module
Fully Hyperbolic Neural Networks 2D convolutional layer (Chen et al. 2021).
Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies the FHNN linear transform (time-primary sigmoid parameterization) for channel mixing.
Computation steps:
1. Map input to manifold via expmap_0 if input_space="tangent"
2. Extract receptive field (kernel_size x kernel_size) of hyperbolic points
3. Apply HCat (Lorentz direct concatenation) to combine receptive field points
4. Pass through FHNN linear (sigmoid time + spatial rescaling)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `Hyperboloid` | Class-based Hyperboloid manifold instance | *required* |
| `in_channels` | `int` | Number of input channels (ambient dimension, including time component) | *required* |
| `out_channels` | `int` | Number of output channels (ambient dimension, including time component) | *required* |
| `kernel_size` | `int` or `tuple[int, int]` | Size of the convolutional kernel | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `stride` | `int` or `tuple[int, int]` | Stride of the convolution (default: 1) | `1` |
| `padding` | `str` | Padding mode, either 'SAME' or 'VALID' (default: 'SAME') | `'SAME'` |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `init_scale` | `float` | Initial value for the learnable sigmoid scale (default: 2.3) | `2.3` |
| `eps` | `float` | Numerical stability epsilon (default: 1e-5) | `1e-05` |
| `activation` | `callable` or `None` | Activation function to apply before the linear transformation (default: None). | `None` |
| `dropout_rate` | `float` or `None` | Dropout rate applied before the linear transformation (default: None). | `None` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space, activation) are treated as static and baked into the compiled function.
See Also
HypConv2DHyperboloid : Equivalent convolution using FHCNN linear instead of FHNN.
HypLinearHyperboloidFHNN : The underlying FHNN linear layer.
References
Weize Chen, et al. "Fully hyperbolic neural networks." arXiv preprint arXiv:2105.14686 (2021).
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
deterministic: bool = True,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the FHNN hyperbolic convolutional layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, height, width, in_channels)` | Input feature map where each pixel is a point on the Hyperboloid manifold | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
| `deterministic` | `bool` | If True, dropout is disabled (default: True). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `Array` of shape `(batch, out_height, out_width, out_channels)` | Output feature map on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
hyperbolix.nn_layers.HypConv2DHyperboloidPP ¶
HypConv2DHyperboloidPP(
manifold_module: Hyperboloid,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic Neural Networks ++ 2D convolutional layer (Hyperboloid model).
Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies the HNN++ linear transform (MLR + sinh diffeomorphism) for channel mixing.
Computation steps:
1. Extract receptive field (kernel_size x kernel_size) of hyperbolic points
2. Apply HCat (Lorentz direct concatenation) to combine receptive field points
3. Pass through HNN++ linear (MLR scores -> sinh -> time reconstruction)
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `Hyperboloid` | Class-based Hyperboloid manifold instance | *required* |
| `in_channels` | `int` | Number of input channels (ambient dimension, including time component) | *required* |
| `out_channels` | `int` | Number of output channels (ambient dimension, including time component) | *required* |
| `kernel_size` | `int` or `tuple[int, int]` | Size of the convolutional kernel | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `stride` | `int` or `tuple[int, int]` | Stride of the convolution (default: 1) | `1` |
| `padding` | `str` | Padding mode, either 'SAME' or 'VALID' (default: 'SAME') | `'SAME'` |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output (default: 1.0) | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output (default: 50.0) | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space, clamping_factor, smoothing_factor) are treated as static and baked into the compiled function.
References
Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the HNN++ hyperboloid convolutional layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `x` | `Array` of shape `(batch, height, width, in_channels)` | Input feature map where each pixel is a point on the Hyperboloid manifold | *required* |
| `c` | `float` | Manifold curvature (default: 1.0) | `1.0` |
Returns:
| Name | Type | Description |
|---|---|---|
| `out` | `Array` of shape `(batch, out_height, out_width, out_channels)` | Output feature map on the Hyperboloid manifold |
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
Usage Example¶
import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import HypConv2DHyperboloid
from hyperbolix.manifolds import Hyperboloid
from flax import nnx
hyperboloid = Hyperboloid()
# Create 2D hyperbolic convolution
conv = HypConv2DHyperboloid(
manifold_module=hyperboloid,
in_channels=16,
out_channels=32,
kernel_size=(3, 3),
stride=(1, 1),
rngs=nnx.Rngs(0)
)
# Input: (batch, height, width, in_channels); each pixel must be a point
# on the hyperboloid, so 16 channels = 1 time + 15 spatial coordinates
x = jax.random.normal(jax.random.PRNGKey(1), (8, 28, 28, 15))
# Project to hyperboloid (prepend the time coordinate)
x_ambient = jnp.concatenate([
    jnp.sqrt(jnp.sum(x**2, axis=-1, keepdims=True) + 1.0),
    x
], axis=-1)  # (8, 28, 28, 16)
# Forward pass (input_space="manifold" by default)
output = conv(x_ambient, c=1.0)
print(output.shape)  # (8, 28, 28, 32)
Dimensional Growth
The HCat concatenation inside hyperboloid convolutions increases dimensionality before the linear sublayer:
- Input: `d+1` dimensions
- After HCat: `(d×N)+1` dimensions, where `N = kernel_height × kernel_width`
For a 3×3 kernel: 3D input → 28D after concatenation. Use small kernels or add dimensionality reduction layers.
Poincaré Convolution¶
hyperbolix.nn_layers.HypConv2DPoincare ¶
HypConv2DPoincare(
manifold_module: Poincare,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
input_space: str = "tangent",
id_init: bool = True,
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic 2D Convolutional Layer for Poincaré ball model.
This layer implements hyperbolic convolution following the Poincaré ResNet (van Spengler et al. 2023) approach: beta-scaling in tangent space, patch extraction, expmap to manifold, HNN++ fully-connected, logmap back to tangent space.
The layer operates in tangent space internally and returns tangent-space output, matching the reference implementation. This design avoids the numerically unstable logmap_0 round-trips that cause NaN gradients when points approach the Poincaré ball boundary.
Computation steps:
1. Map to tangent space if input is on manifold
2. Scale tangent vectors by beta function ratio (beta-concatenation scaling)
3. Extract patches via im2col (zero-padding in tangent space)
4. Map concatenated patch vectors to manifold via expmap_0
5. Apply HNN++ fully-connected layer
6. Map back to tangent space via logmap_0
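The beta-scaling in step 2 follows the β-concatenation of HNN++, which rescales tangent vectors by a ratio of Beta-function values; the sketch below assumes the β_n = B(n/2, 1/2) convention from that paper:

```python
from math import gamma

def beta_fn(a, b):
    # Euler Beta function via the Gamma function
    return gamma(a) * gamma(b) / gamma(a + b)

n, K = 16, 3                  # per-pixel ball dimension, kernel size
n_cat = K * K * n             # dimension after concatenating the receptive field
ratio = beta_fn(n_cat / 2, 0.5) / beta_fn(n / 2, 0.5)
# each tangent block is multiplied by this ratio (< 1) before concatenation,
# keeping the expected norm of the concatenated vector well-behaved
```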
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `Poincare` | Class-based Poincaré manifold instance | *required* |
| `in_channels` | `int` | Number of input channels (Poincaré ball dimension per pixel) | *required* |
| `out_channels` | `int` | Number of output channels (Poincaré ball dimension per pixel) | *required* |
| `kernel_size` | `int` or `tuple[int, int]` | Size of the convolutional kernel | *required* |
| `rngs` | `Rngs` | Random number generators for parameter initialization | *required* |
| `stride` | `int` or `tuple[int, int]` | Stride of the convolution (default: 1) | `1` |
| `padding` | `str` | Padding mode, either 'SAME' or 'VALID' (default: 'SAME') | `'SAME'` |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'tangent'). Note: This is a static configuration - changing it after initialization requires recompilation. | `'tangent'` |
| `id_init` | `bool` | If True, use identity initialization (1/2 * I) for the linear sublayer weights, matching the reference Poincaré ResNet implementation (default: True). The 1/2 factor compensates for the factor of 2 inside the HNN++ distance formula. | `True` |
| `clamping_factor` | `float` | Clamping factor for the HNN++ linear layer output (default: 1.0) | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the HNN++ linear layer output (default: 50.0) | `50.0` |
Notes
Output Space: This layer always returns tangent-space output (matching the reference). Between conv layers, use standard activations (e.g., jax.nn.relu) directly on the tangent-space features. Use expmap_0 to map back to the manifold only when needed (e.g., before the classification head).
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space) are treated as static and will be baked into the compiled function.
Dimension math: - beta-scaling + patch extraction: (H, W, C_in) → (oh, ow, K^2 * C_in) - HNN++ linear: in_dim = K^2 * C_in, out_dim = C_out
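The dimension math above follows standard convolution arithmetic. A small illustrative helper (not part of the library) that computes the output spatial dims and the HNN++ linear input dimension:

```python
import math

def conv_dims(h, w, k, stride=1, padding="SAME", c_in=16):
    """Output spatial dims and HNN++ linear in_dim for a KxK conv."""
    if padding == "SAME":
        oh, ow = math.ceil(h / stride), math.ceil(w / stride)
    else:  # VALID
        oh, ow = (h - k) // stride + 1, (w - k) // stride + 1
    return oh, ow, k * k * c_in  # in_dim of the HNN++ linear sublayer
```

For a 3x3 kernel over 16 channels, the HNN++ sublayer therefore sees 144-dimensional patch vectors.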
References
Shimizu et al. "Hyperbolic neural networks++." arXiv:2006.08210 (2020). van Spengler et al. "Poincaré ResNet." ICML 2023.
Source code in hyperbolix/nn_layers/poincare_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the Poincaré convolutional layer.
Follows the reference computation flow: tangent-space beta-scaling, patch extraction, expmap_0, HNN++ FC, logmap_0.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (batch, height, width, in_channels) | Input feature map. Can be tangent-space or manifold points depending on the input_space setting. | required |
| c | float | Manifold curvature (default: 1.0) | 1.0 |
Returns:
| Name | Type | Description |
|---|---|---|
| out | Array of shape (batch, out_height, out_width, out_channels) | Output feature map in tangent space at the origin. Use standard activations (e.g., jax.nn.relu) between layers. |
Source code in hyperbolix/nn_layers/poincare_conv.py
HypConv2DPoincare extracts patches, applies beta-concatenation (HNN++, Shimizu et al. 2020) over the receptive field, then passes through a HypLinearPoincarePP layer. Dimension math: K² × C_in → C_out where K is the kernel size.
Key Differences from Hyperboloid Convolutions:
| Feature | HypConv2DPoincare | HypConv2DHyperboloid |
|---|---|---|
| Model | Poincaré ball | Hyperboloid |
| Aggregation | Beta-concatenation | HCat (Lorentz concatenation) |
| Dimension | Preserved | Grows: (d-1)×K²+1 |
| Default input | Tangent space | Manifold (ambient) |
Usage Example¶
import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import HypConv2DPoincare
from hyperbolix.manifolds import Poincare
from flax import nnx
poincare = Poincare()
# Create Poincaré 2D convolution
conv = HypConv2DPoincare(
manifold_module=poincare,
in_channels=16,
out_channels=32,
kernel_size=3,
stride=1,
rngs=nnx.Rngs(0)
)
# Input: (batch, height, width, in_channels) in tangent space (default input_space="tangent")
x = jax.random.normal(jax.random.PRNGKey(1), (8, 28, 28, 16)) * 0.1
# Forward pass — returns tangent-space output
output = conv(x, c=1.0)
print(output.shape) # (8, 28, 28, 32)
Poincaré Batch Normalization¶
hyperbolix.nn_layers.PoincareBatchNorm2D ¶
PoincareBatchNorm2D(
manifold_module: Poincare,
num_features: int,
*,
use_running_average: bool = False,
momentum: float = 0.9,
eps: float = 1e-06,
)
Bases: Module
Poincaré batch normalization for 2D feature maps.
Operates in tangent space (matching conv layer I/O). Internally maps to the manifold for geometric operations (midpoint, parallel transport, variance rescaling), then maps back to tangent space.
Follows nnx.BatchNorm interface: use_running_average is a
constructor parameter overridable at call time.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| manifold_module | Poincare | Poincaré manifold instance. | required |
| num_features | int | Number of channels (Poincaré ball dimension per pixel). | required |
| use_running_average | bool | If True, use running statistics instead of batch statistics (default: False). Overridable at call time. | False |
| momentum | float | EMA momentum for running statistics (default: 0.9). | 0.9 |
| eps | float | Numerical stability floor (default: 1e-6). | 1e-06 |
References
van Spengler et al. "Poincaré ResNet." ICML 2023.
Source code in hyperbolix/nn_layers/poincare_batchnorm.py
__call__ ¶
__call__(
x_BHWC: Float[Array, "B H W C"],
c: float = 1.0,
use_running_average: bool | None = None,
) -> Float[Array, "B H W C"]
Apply Poincaré batch normalization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_BHWC | Array, shape (B, H, W, C) | Tangent-space input features. | required |
| c | float | Curvature (positive, default: 1.0). | 1.0 |
| use_running_average | bool or None | Override constructor setting. None uses the constructor value. | None |
Returns:
| Type | Description |
|---|---|
| Array, shape (B, H, W, C) | Normalized tangent-space output. |
Source code in hyperbolix/nn_layers/poincare_batchnorm.py
PoincareBatchNorm2D operates in tangent space (matching conv layer I/O), mapping to the manifold internally for geometric operations: Einstein midpoint, Fréchet variance, parallel transport, and variance rescaling. Use between Poincaré convolution layers following the reference ResNet pattern: conv → bn → relu → conv → bn → skip.
hyperbolix.nn_layers.poincare_midpoint ¶
poincare_midpoint(
x_NC: Float[Array, "N C"],
manifold: Poincare,
c: float,
eps: float = 1e-06,
) -> Float[Array, "C"]
Compute Einstein midpoint of Poincaré ball points.
Uses conformal factor weighting: midpoint = Σ(λ²·x) / Σ(λ²), then projects onto the ball.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_NC | Array, shape (N, C) | Points on the Poincaré ball. | required |
| manifold | Poincare | Poincaré manifold instance. | required |
| c | float | Curvature (positive). | required |
| eps | float | Numerical stability floor (default: 1e-6). | 1e-06 |
Returns:
| Type | Description |
|---|---|
| Array, shape (C,) | Einstein midpoint on the Poincaré ball. |
Source code in hyperbolix/nn_layers/poincare_batchnorm.py
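Under the conformal-factor weighting above (λ_x = 2 / (1 − c‖x‖²)), the midpoint can be sketched in a few lines of plain JAX. This is illustrative only; the library's version additionally projects the result onto the ball:

```python
import jax.numpy as jnp

def einstein_midpoint(x_NC, c=1.0, eps=1e-6):
    """midpoint = sum(lam^2 * x) / sum(lam^2), lam = 2 / (1 - c * ||x||^2)."""
    sq_norm = jnp.sum(x_NC**2, axis=-1, keepdims=True)   # (N, 1)
    lam = 2.0 / jnp.maximum(1.0 - c * sq_norm, eps)      # conformal factors
    w = lam**2
    return jnp.sum(w * x_NC, axis=0) / jnp.maximum(jnp.sum(w), eps)
```

Points placed symmetrically about the origin average to the origin, and a set of identical points averages to that point.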
hyperbolix.nn_layers.frechet_variance ¶
frechet_variance(
x_NC: Float[Array, "N C"],
mean_C: Float[Array, "C"],
manifold: Poincare,
c: float,
) -> Float[Array, ""]
Compute Fréchet variance: mean squared geodesic distance to mean.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x_NC | Array, shape (N, C) | Points on the Poincaré ball. | required |
| mean_C | Array, shape (C,) | Mean point on the Poincaré ball. | required |
| manifold | Poincare | Poincaré manifold instance. | required |
| c | float | Curvature (positive). | required |
Returns:
| Type | Description |
|---|---|
| Array, scalar | Mean of squared geodesic distances. |
Source code in hyperbolix/nn_layers/poincare_batchnorm.py
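The geodesic distance entering the variance has the closed form d_c(x, y) = (2/√c) · artanh(√c ‖(−x) ⊕_c y‖) on the Poincaré ball, where ⊕_c is Möbius addition. A self-contained sketch using these standard formulas (standalone helpers, not the library API):

```python
import jax.numpy as jnp

def mobius_add(u, v, c=1.0):
    """Mobius addition u (+)_c v on the Poincare ball."""
    uv = jnp.sum(u * v, axis=-1, keepdims=True)
    uu = jnp.sum(u * u, axis=-1, keepdims=True)
    vv = jnp.sum(v * v, axis=-1, keepdims=True)
    num = (1 + 2 * c * uv + c * vv) * u + (1 - c * uu) * v
    den = 1 + 2 * c * uv + c**2 * uu * vv
    return num / den

def frechet_var(x_NC, mean_C, c=1.0, eps=1e-7):
    """Mean squared geodesic distance of points to the mean."""
    diff = mobius_add(-mean_C, x_NC, c)  # broadcasts mean over N points
    norm = jnp.clip(jnp.linalg.norm(diff, axis=-1), eps, (1 - eps) / jnp.sqrt(c))
    d = (2.0 / jnp.sqrt(c)) * jnp.arctanh(jnp.sqrt(c) * norm)
    return jnp.mean(d**2)
```

With the mean at the origin, the distance reduces to 2 · artanh(‖x‖) for c = 1, and the variance of points coinciding with the mean is (numerically) zero.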
Poincaré BatchNorm Example¶
from hyperbolix.nn_layers import HypConv2DPoincare, PoincareBatchNorm2D
from hyperbolix.manifolds import Poincare
from flax import nnx
import jax
poincare = Poincare()
# ResNet-style block: conv → bn → relu
conv = HypConv2DPoincare(
manifold_module=poincare,
in_channels=16, out_channels=16,
kernel_size=3, rngs=nnx.Rngs(0),
)
bn = PoincareBatchNorm2D(poincare, num_features=16)
# Input: tangent-space features (B, H, W, C)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 8, 16)) * 0.1
# Training: use_running_average=False (default)
h = conv(x, c=0.1)
h = bn(h, c=0.1)
h = jax.nn.relu(h)
# Evaluation: use_running_average=True
h_eval = conv(x, c=0.1)
h_eval = bn(h_eval, c=0.1, use_running_average=True)
h_eval = jax.nn.relu(h_eval)
FGGConv2D (Klis et al. 2026)¶
hyperbolix.nn_layers.FGGConv2D ¶
FGGConv2D(
manifold_module: Hyperboloid,
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
pad_mode: str = "origin",
activation: Callable | None = None,
reset_params: str = "lorentz_kaiming",
use_weight_norm: bool = False,
init_bias: float = 0.5,
eps: float = 1e-07,
)
Bases: Module
Fast and Geometrically Grounded Lorentz 2D convolutional layer.
Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies FGGLinear for the channel mixing. This matches the reference implementation pattern from Klis et al. 2026.
Computation steps: 1) Extract receptive field patches, pad with manifold origin if needed 2) Apply HCat (Lorentz direct concatenation) to combine patch points 3) Pass through FGGLinear for channel transformation
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| manifold_module | Hyperboloid | Class-based Hyperboloid manifold instance. | required |
| in_channels | int | Input ambient channels (D_in + 1), including time component. | required |
| out_channels | int | Output ambient channels (D_out + 1), including time component. | required |
| kernel_size | int or tuple[int, int] | Size of the convolutional kernel. | required |
| rngs | Rngs | Random number generators for parameter initialization. | required |
| stride | int or tuple[int, int] | Stride of the convolution (default: 1). | 1 |
| padding | str | Padding mode, either 'SAME' or 'VALID' (default: 'SAME'). | 'SAME' |
| pad_mode | str | How to fill padding pixels (default: 'origin', the manifold origin). | 'origin' |
| activation | Callable or None | Euclidean activation for the FGGLinear (default: None). | None |
| reset_params | str | Weight initialization scheme for FGGLinear (default: 'lorentz_kaiming'). | 'lorentz_kaiming' |
| use_weight_norm | bool | Weight normalization in FGGLinear (default: False). | False |
| init_bias | float | Initial bias for FGGLinear (default: 0.5). | 0.5 |
| eps | float | Numerical stability floor (default: 1e-7). | 1e-07 |
References
Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026).
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the FGG convolutional layer.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array, shape (B, H, W, in_channels) | Input feature map on the hyperboloid. | required |
| c | float | Curvature parameter (default: 1.0). | 1.0 |
Returns:
| Name | Type | Description |
|---|---|---|
| out | Array, shape (B, H', W', out_channels) | Output feature map on the hyperboloid. |
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
FGGConv2D combines HCat patch extraction with FGGLinear for channel mixing. Unlike HypConv2DHyperboloid, it uses the FGG spacelike-V construction, achieving linear growth of hyperbolic distance rather than logarithmic. Supports manifold-origin padding (pad_mode="origin") matching the reference implementation.
| Feature | FGGConv2D | HypConv2DHyperboloid | HypConv2DHyperboloidPP |
|---|---|---|---|
| Linear layer | FGGLinear (V-matrix) | HypLinearHyperboloidFHCNN | HypLinearHyperboloidPP (MLR) |
| Distance growth | Linear | Logarithmic | Logarithmic |
| Default padding | Manifold origin | Edge replication | Edge replication |
| Weight norm | Optional (use_weight_norm) | No | No |
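HCat (Lorentz direct concatenation) underlies both FGGConv2D and the HypConv2DHyperboloid family. One common construction keeps all space components and recomputes a single time component so the Lorentz constraint still holds. A hedged sketch following the usual HCat formula (not necessarily the library's exact code):

```python
import jax.numpy as jnp

def hcat(points, c=1.0):
    """Concatenate K hyperboloid points [t_i, s_i] into one point.

    The new time t = sqrt(sum_i t_i^2 - (K - 1)/c) preserves
    -t^2 + ||s||^2 = -1/c for the concatenated point.
    """
    K = len(points)
    t = jnp.stack([p[0] for p in points])
    s = jnp.concatenate([p[1:] for p in points])
    t_new = jnp.sqrt(jnp.sum(t**2) - (K - 1) / c)
    return jnp.concatenate([t_new[None], s])
```

This is also where the (d-1)×K²+1 dimensional growth of HCat-based convolutions comes from: space parts concatenate while only one time slot remains.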
LorentzConv2D (HRC-Based)¶
hyperbolix.nn_layers.LorentzConv2D ¶
LorentzConv2D(
in_channels: int,
out_channels: int,
kernel_size: int | tuple[int, int],
*,
rngs: Rngs,
stride: int | tuple[int, int] = 1,
padding: str = "SAME",
)
Bases: Module
Lorentz 2D Convolutional Layer using the Hyperbolic Layer (HL) approach.
This layer applies convolution to the space-like components of Lorentzian vectors and reconstructs the time-like component to maintain the manifold constraint. This is equivalent to an HRC (Hyperbolic Regularization Component) wrapper around a standard Conv2D.
Computation steps: 1) Extract space-like components x_s from input x = [x_t, x_s]^T 2) Apply Euclidean convolution: y_s = Conv2D(x_s) 3) Reconstruct time component: y_t = sqrt(||y_s||^2 + 1/c) 4) Return y = [y_t, y_s]^T
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_channels | int | Number of input channels (ambient dimension, including time component) | required |
| out_channels | int | Number of output channels (ambient dimension, including time component) | required |
| kernel_size | int or tuple[int, int] | Size of the convolutional kernel | required |
| rngs | Rngs | Random number generators for parameter initialization | required |
| stride | int or tuple[int, int] | Stride of the convolution (default: 1) | 1 |
| padding | str or int or tuple | Padding mode: 'SAME', 'VALID', or explicit padding (default: 'SAME') | 'SAME' |
Notes
This implementation follows the Hyperbolic Layer (HL) approach from "Fully Hyperbolic Convolutional Neural Networks for Computer Vision".
The layer operates only on space-like components, making it more computationally efficient than the HCat-based approach (HypConv2DHyperboloid), though it doesn't perform true hyperbolic convolution. Instead, it applies Euclidean operations to spatial components and reconstructs the time component.
See Also
hypformer.hrc : Core HRC function this layer is based on HypConv2DHyperboloid : Full hyperbolic convolution using HCat concatenation
References
He, Neil, Menglin Yang, and Rex Ying. "Lorentzian residual neural networks." Proceedings of the 31st ACM SIGKDD Conference on Knowledge Discovery and Data Mining V. 1. 2025.
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
__call__ ¶
__call__(
x: Float[Array, "batch height width in_channels"],
c: float = 1.0,
) -> Float[
Array, "batch out_height out_width out_channels"
]
Forward pass through the Lorentz convolutional layer.
This layer is a specific instance of the Hyperbolic Regularization Component (HRC) where the regularization function f_r is a 2D convolution. The HRC pattern: 1. Extracts space components 2. Applies Euclidean convolution 3. Reconstructs time component using Lorentz constraint
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (batch, height, width, in_channels) | Input feature map where x[..., 0] is the time component and x[..., 1:] are space components on the Lorentz manifold | required |
| c | float | Manifold curvature parameter (default: 1.0) | 1.0 |
Returns:
| Name | Type | Description |
|---|---|---|
| out | Array of shape (batch, out_height, out_width, out_channels) | Output feature map on the Lorentz manifold |
Notes
This implementation uses the HRC function from hypformer.py, demonstrating that LorentzConv2D (from LResNet) and HRC (from Hypformer) are mathematically equivalent approaches to adapting Euclidean operations for hyperbolic geometry.
Source code in hyperbolix/nn_layers/hyperboloid_conv.py
LorentzConv2D provides a simpler, more efficient alternative to HCat-based convolutions by using the Hyperbolic Regularization Component (HRC) pattern from the Hypformer paper.
Key Differences from HypConv2DHyperboloid:
| Feature | HypConv2DHyperboloid (HCat+FHCNN) | HypConv2DHyperboloidPP (HCat+HNN++) | LorentzConv2D (HRC) |
|---|---|---|---|
| Method | HCat + FHCNN linear | HCat + MLR + sinh diffeomorphism | Euclidean conv on space components |
| Dimension | Grows: (d-1)×N+1 | Grows: (d-1)×N+1 | Preserved |
| Speed | Slower (~80 s/epoch) | Similar to FHCNN | ~2.5× faster (~32 s/epoch) |
| Accuracy | Higher (~71% on MNIST) | Better gradient flow (HNN++) | Lower (~46% on MNIST) |
| Use case | HCat accuracy | Deep HCat networks | Speed/memory efficiency |
Theoretical Connection:
LorentzConv2D implements the Hyperbolic Layer (HL) pattern from LResNet, which is mathematically equivalent to the Hyperbolic Regularization Component (HRC) from Hypformer:
# Both approaches:
# 1. Extract space components: x_s = x[..., 1:]
# 2. Apply Euclidean function: y_s = f(x_s)
# 3. Reconstruct time: y_t = sqrt(||y_s||^2 + 1/c)
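Spelled out in plain JAX, the shared pattern looks as follows (a sketch of the pattern, not the library's hrc function itself), with the Lorentz constraint recovered exactly by construction:

```python
import jax
import jax.numpy as jnp

def hrc_pattern(x, f, c=1.0):
    """Apply a Euclidean map f to space components, rebuild the time component."""
    y_s = f(x[..., 1:])                                              # Euclidean op
    y_t = jnp.sqrt(jnp.sum(y_s**2, axis=-1, keepdims=True) + 1.0 / c)
    return jnp.concatenate([y_t, y_s], axis=-1)                      # back on manifold
```

Any f works here (activation, conv, norm); the output always satisfies -y_t² + ‖y_s‖² = -1/c.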
Usage Example:
from hyperbolix.nn_layers import LorentzConv2D
from flax import nnx
import jax.numpy as jnp
# Create efficient hyperbolic convolution
conv = LorentzConv2D(
in_channels=33, # Including time component
out_channels=65, # Including time component
kernel_size=3,
stride=2,
padding="SAME",
rngs=nnx.Rngs(0)
)
# Input: points on Lorentz manifold (batch, height, width, in_channels)
x = jnp.ones((8, 28, 28, 33))
x_space = x[..., 1:]
x_time = jnp.sqrt(jnp.sum(x_space**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([x_time, x_space], axis=-1)
# Forward pass
output = conv(x, c=1.0)
print(output.shape) # (8, 14, 14, 65) - dimensions preserved!
When to Use LorentzConv2D
Choose LorentzConv2D when:
- Speed and memory efficiency are priorities
- Working with resource-constrained environments
- Acceptable accuracy trade-off for 2.5x speedup
Choose HypConv2DHyperboloid or HypConv2DHyperboloidPP when:
- Maximum accuracy is required
- Willing to accept slower training and dimensional growth
- Use HypConv2DHyperboloidPP for deeper networks (better gradient flow via HNN++)
Hypformer Components¶
The Hyperbolic Transformation Component (HTC) and Hyperbolic Regularization Component (HRC) from the Hypformer paper provide general-purpose wrappers for adapting Euclidean operations to hyperbolic geometry with curvature-change support.
Core Functions¶
hyperbolix.nn_layers.hrc ¶
hrc(
x: Float[Array, "... dim_plus_1"],
f_r: Callable[[Float[Array, ...]], Float[Array, ...]],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... out_dim_plus_1"]
Hyperbolic Regularization Component.
Applies a Euclidean regularization/activation function f_r to the spatial components of hyperboloid points, then maps the result to the hyperboloid with curvature c_out.
Mathematical formula: space = sqrt(c_in/c_out) * f_r(x[..., 1:]) time = sqrt(||space||^2 + 1/c_out) output = [time, space]
When c_in = c_out = c, this reduces to: output = [sqrt(||f_r(x_s)||^2 + 1/c), f_r(x_s)] which is the pattern used by curvature-preserving hyperboloid activations.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (..., dim+1) | Input point(s) on the hyperboloid manifold with curvature c_in. The first element is the time-like component; the rest are spatial. | required |
| f_r | Callable | Euclidean function applied to the spatial components. Can be any activation, normalization, dropout, etc. Takes spatial components and returns transformed spatial components (may change dimension). | required |
| c_in | float | Input curvature parameter (must be positive, c > 0). | required |
| c_out | float | Output curvature parameter (must be positive, c > 0). | required |
| eps | float | Small value for numerical stability (default: 1e-7). | 1e-07 |
Returns:
| Name | Type | Description |
|---|---|---|
| y | Array of shape (..., out_dim+1) | Output point(s) on the hyperboloid manifold with curvature c_out. |
Notes
- f_r operates only on spatial components x[..., 1:], not the time component
- The time component is reconstructed using the hyperboloid constraint: -x₀² + ||x_rest||² = -1/c_out
- This avoids expensive exp/log maps while maintaining mathematical correctness
- The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation
See Also
htc : Hyperbolic Transformation Component for full-point operations.
References
Hypformer paper (citation to be added)
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers.hyperboloid_core import hrc
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create a point on the hyperboloid
>>> manifold = Hyperboloid()
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> x = manifold.proj(x, c=1.0)
>>>
>>> # Apply HRC with ReLU (curvature-preserving)
>>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=1.0)
>>>
>>> # Apply HRC with curvature change
>>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=2.0)
>>>
>>> # Custom activation
>>> def custom_act(z):
... return jax.nn.gelu(z) * 0.5
>>> y = hrc(x, custom_act, c_in=1.0, c_out=0.5)
Source code in hyperbolix/nn_layers/hyperboloid_core.py
hyperbolix.nn_layers.htc ¶
htc(
x: Float[Array, "... in_dim_plus_1"],
f_t: Callable[[Float[Array, ...]], Float[Array, ...]],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... out_dim_plus_1"]
Hyperbolic Transformation Component.
Applies a Euclidean linear transformation f_t to the full hyperboloid point (including time component), then maps the result to the hyperboloid with curvature c_out.
Mathematical formula: space = sqrt(c_in/c_out) * f_t(x) time = sqrt(||space||^2 + 1/c_out) output = [time, space]
where f_t takes the full (dim+1)-dimensional input and produces the output spatial components.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (..., in_dim+1) | Input point(s) on the hyperboloid manifold with curvature c_in. All components (time and spatial) are passed to f_t. | required |
| f_t | Callable | Euclidean linear transformation applied to the full input. Takes (in_dim+1)-dimensional input and produces out_dim-dimensional output (which becomes the spatial components of the output). | required |
| c_in | float | Input curvature parameter (must be positive, c > 0). | required |
| c_out | float | Output curvature parameter (must be positive, c > 0). | required |
| eps | float | Small value for numerical stability (default: 1e-7). | 1e-07 |
Returns:
| Name | Type | Description |
|---|---|---|
| y | Array of shape (..., out_dim+1) | Output point(s) on the hyperboloid manifold with curvature c_out. |
Notes
- Unlike HRC, f_t operates on the full point including the time component
- f_t's output dimension determines the output spatial dimension
- This is typically used for learnable linear transformations
- The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation
See Also
hrc : Hyperbolic Regularization Component for spatial-only operations. HTCLinear : Module wrapper for htc with learnable linear transformation.
References
Hypformer paper (citation to be added)
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers.hyperboloid_core import htc
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create a point on the hyperboloid
>>> manifold = Hyperboloid()
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> x = manifold.proj(x, c=1.0)
>>>
>>> # Define a linear transformation
>>> W = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
>>> def linear(z):
... return z @ W.T
>>>
>>> # Apply HTC
>>> y = htc(x, linear, c_in=1.0, c_out=2.0)
>>> y.shape
(4,) # (3 spatial + 1 time)
Source code in hyperbolix/nn_layers/hyperboloid_core.py
HTC/HRC Modules¶
hyperbolix.nn_layers.HTCLinear ¶
HTCLinear(
in_features: int,
out_features: int,
*,
rngs: Rngs,
use_bias: bool = True,
init_bound: float = 0.02,
eps: float = 1e-07,
)
Bases: Module
Hyperbolic Transformation Component with learnable linear transformation.
This module wraps a Euclidean linear layer with the HTC operation, enabling learnable transformations between hyperboloid manifolds with different curvatures.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| in_features | int | Input feature dimension (full hyperboloid dimension, including time component). | required |
| out_features | int | Output spatial dimension (time component is reconstructed automatically). | required |
| rngs | Rngs | Random number generators for parameter initialization. | required |
| use_bias | bool | Whether to include a bias term (default: True). | True |
| init_bound | float | Bound for uniform weight initialization. Weights are initialized from Uniform(-init_bound, init_bound). Small values keep initial outputs close to the hyperboloid origin for stable training (default: 0.02). | 0.02 |
| eps | float | Small value for numerical stability (default: 1e-7). | 1e-07 |
Attributes:
| Name | Type | Description |
|---|---|---|
| kernel | Param | Weight matrix of shape (in_features, out_features). |
| bias | Param or None | Bias vector of shape (out_features,) if use_bias=True, else None. |
| eps | float | Numerical stability parameter. |
Notes
Weight Initialization: This layer uses small uniform initialization U(-0.02, 0.02) by default, matching the initialization used by FHNN/FHCNN layers. Standard deep learning initializations (Xavier, Lecun) produce weights that are too large for hyperbolic operations, causing gradient explosion and training instability.
See Also
hyperbolix.nn_layers.hyperboloid_core.htc : Core HTC function for functional transformations. HypLinearHyperboloidFHCNN : Alternative hyperbolic linear layer with sigmoid scaling.
References
Hypformer paper (citation to be added)
Examples:
>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HTCLinear
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create layer
>>> layer = HTCLinear(in_features=5, out_features=8, rngs=nnx.Rngs(0))
>>>
>>> # Forward pass
>>> manifold = Hyperboloid()
>>> x = jnp.ones((32, 5)) # batch of 32 points
>>> x = jax.vmap(manifold.proj, in_axes=(0, None))(x, 1.0)
>>> y = layer(x, c_in=1.0, c_out=2.0)
>>> y.shape
(32, 9) # 8 spatial + 1 time
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
__call__ ¶
__call__(
x: Float[Array, "batch in_features"],
c_in: float = 1.0,
c_out: float = 1.0,
) -> Float[Array, "batch out_features_plus_1"]
Apply HTC linear transformation.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (batch, in_features) | Input points on hyperboloid with curvature c_in. | required |
| c_in | float | Input curvature (default: 1.0). | 1.0 |
| c_out | float | Output curvature (default: 1.0). | 1.0 |
Returns:
| Name | Type | Description |
|---|---|---|
| y | Array of shape (batch, out_features+1) | Output points on hyperboloid with curvature c_out. |
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
hyperbolix.nn_layers.HRCBatchNorm ¶
HRCBatchNorm(
num_features: int,
*,
rngs: Rngs,
momentum: float = 0.99,
epsilon: float = 1e-05,
eps: float = 1e-07,
)
Bases: Module
Hyperbolic Regularization Component with batch normalization.
Applies batch normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_features | int | Number of spatial features to normalize. | required |
| rngs | Rngs | Random number generators for parameter initialization. | required |
| momentum | float | Momentum for running statistics (default: 0.99). | 0.99 |
| epsilon | float | Small value for numerical stability in batch norm (default: 1e-5). | 1e-05 |
| eps | float | Small value for numerical stability in HRC (default: 1e-7). | 1e-07 |
Attributes:
| Name | Type | Description |
|---|---|---|
| bn | BatchNorm | Flax batch normalization. |
| eps | float | Numerical stability parameter for HRC. |
Notes
Training vs Evaluation Mode: During training (use_running_average=False), batch norm computes statistics from the current batch and updates running averages. During evaluation (use_running_average=True), it uses the accumulated running statistics.
Examples:
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCBatchNorm
>>>
>>> # Create batch norm layer
>>> bn = HRCBatchNorm(num_features=64, rngs=nnx.Rngs(0))
>>>
>>> # Training mode
>>> y_train = bn(x, c_in=1.0, c_out=2.0, use_running_average=False)
>>>
>>> # Evaluation mode
>>> y_eval = bn(x, c_in=1.0, c_out=2.0, use_running_average=True)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
__call__ ¶
__call__(
x: Float[Array, "batch dim_plus_1"],
c_in: float = 1.0,
c_out: float = 1.0,
use_running_average: bool | None = None,
) -> Float[Array, "batch dim_plus_1"]
Apply HRC batch normalization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (batch, dim+1) | Input points on hyperboloid with curvature c_in. | required |
| c_in | float | Input curvature (default: 1.0). | 1.0 |
| c_out | float | Output curvature (default: 1.0). | 1.0 |
| use_running_average | bool or None | If True, use running statistics (eval mode). If False, use batch statistics (train mode). If None, use the default set during initialization. | None |
Returns:
| Name | Type | Description |
|---|---|---|
| y | Array of shape (batch, dim+1) | Output points on hyperboloid with curvature c_out. |
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
hyperbolix.nn_layers.HRCLayerNorm ¶
Bases: Module
Hyperbolic Regularization Component with layer normalization.
Applies layer normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_features | int | Number of spatial features to normalize. | required |
| rngs | Rngs | Random number generators for parameter initialization. | required |
| epsilon | float | Small value for numerical stability in layer norm (default: 1e-5). | 1e-05 |
| eps | float | Small value for numerical stability in HRC (default: 1e-7). | 1e-07 |
Attributes:
| Name | Type | Description |
|---|---|---|
| ln | LayerNorm | Flax layer normalization. |
| eps | float | Numerical stability parameter for HRC. |
Examples:
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCLayerNorm
>>>
>>> ln = HRCLayerNorm(num_features=64, rngs=nnx.Rngs(0))
>>> y = ln(x, c_in=1.0, c_out=2.0)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
__call__ ¶
__call__(
x: Float[Array, "batch dim_plus_1"],
c_in: float = 1.0,
c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]
Apply HRC layer normalization.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| x | Array of shape (batch, dim+1) | Input points on hyperboloid with curvature c_in. | required |
| c_in | float | Input curvature (default: 1.0). | 1.0 |
| c_out | float | Output curvature (default: 1.0). | 1.0 |
Returns:
| Name | Type | Description |
|---|---|---|
| y | Array of shape (batch, dim+1) | Output points on hyperboloid with curvature c_out. |
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
hyperbolix.nn_layers.HRCRMSNorm ¶
Bases: Module
Hyperbolic Regularization Component with RMS normalization.
Applies RMS normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature. RMSNorm is a simpler and faster variant of LayerNorm that only normalizes by the root mean square without centering.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| num_features | int | Number of spatial features to normalize. | required |
| rngs | Rngs | Random number generators for parameter initialization. | required |
| epsilon | float | Small value for numerical stability in RMS norm (default: 1e-6). | 1e-06 |
| eps | float | Small value for numerical stability in HRC (default: 1e-7). | 1e-07 |
Attributes:
| Name | Type | Description |
|---|---|---|
| rms | RMSNorm | Flax RMS normalization. |
| eps | float | Numerical stability parameter for HRC. |
Notes
RMSNorm Formula: y = x / RMS(x) * scale, where RMS(x) = sqrt(mean(x^2) + epsilon)
RMSNorm is more efficient than LayerNorm as it skips mean subtraction and is commonly used in modern transformers (e.g., LLaMA, GPT-NeoX).
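The RMSNorm formula above is small enough to verify by hand. The following illustrative NumPy sketch (not the Flax implementation) applies `y = x / RMS(x) * scale`:

```python
import numpy as np

def rmsnorm_sketch(x, scale=1.0, epsilon=1e-6):
    """RMSNorm (illustrative): normalize by root mean square, no centering."""
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + epsilon)
    return x / rms * scale

x = np.array([[3.0, -4.0]])
y = rmsnorm_sketch(x)
# Mean of squares is (9 + 16) / 2 = 12.5, so RMS is sqrt(12.5) ~= 3.5355
assert np.allclose(y, x / np.sqrt(12.5 + 1e-6), atol=1e-8)
```

Unlike LayerNorm, no mean is subtracted, which is exactly what makes the HRC wrapping cheap.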
Examples:
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCRMSNorm
>>>
>>> rms = HRCRMSNorm(num_features=64, rngs=nnx.Rngs(0))
>>> y = rms(x, c_in=1.0, c_out=2.0)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
__call__ ¶
__call__(
x: Float[Array, "batch dim_plus_1"],
c_in: float = 1.0,
c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]
Apply HRC RMS normalization.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape (batch, dim+1) | Input points on hyperboloid with curvature `c_in`. | required |
| `c_in` | `float` | Input curvature (default: 1.0). | `1.0` |
| `c_out` | `float` | Output curvature (default: 1.0). | `1.0` |
Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape (batch, dim+1) | Output points on hyperboloid with curvature `c_out`. |
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
hyperbolix.nn_layers.HRCDropout ¶
Bases: Module
Hyperbolic Regularization Component with dropout.
Applies dropout to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `rate` | `float` | Dropout probability (fraction of units to drop). | required |
| `rngs` | `Rngs` | Random number generators for dropout. | required |
| `eps` | `float` | Small value for numerical stability (default: 1e-7). | `1e-07` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `dropout` | `Dropout` | Flax dropout layer. |
| `eps` | `float` | Numerical stability parameter. |
Examples:
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCDropout
>>>
>>> dropout = HRCDropout(rate=0.1, rngs=nnx.Rngs(dropout=42))
>>> y = dropout(x, c_in=1.0, c_out=1.0, deterministic=False)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
__call__ ¶
__call__(
x: Float[Array, "batch dim_plus_1"],
c_in: float = 1.0,
c_out: float = 1.0,
deterministic: bool = False,
) -> Float[Array, "batch dim_plus_1"]
Apply HRC dropout.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape (batch, dim+1) | Input points on hyperboloid with curvature `c_in`. | required |
| `c_in` | `float` | Input curvature (default: 1.0). | `1.0` |
| `c_out` | `float` | Output curvature (default: 1.0). | `1.0` |
| `deterministic` | `bool` | If True, no dropout is applied (for evaluation mode). | `False` |
Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape (batch, dim+1) | Output points on hyperboloid with curvature `c_out`. |
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
hyperbolix.nn_layers.FGGMeanOnlyBatchNorm ¶
Bases: Module
Mean-only batch normalization for FGG-LNN layers.
Subtracts the batch mean from spatial components without dividing by variance, then adds a learnable bias. This avoids the instability that standard BatchNorm causes in hyperbolic space (exponentially large/small variance denominators) while still centering activations for stable training.
Designed to pair with weight normalization (FGGLinear(use_weight_norm=True)),
which handles magnitude control. Together they replace standard BatchNorm in
fully hyperbolic FGG-LNN networks.
Formula (applied to spatial components via HRC):
y = x - E[x] + bias
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `num_features` | `int` | Number of spatial features (D, not D+1). | required |
| `momentum` | `float` | Exponential moving average momentum for the running mean (default: 0.99). | `0.99` |
| `eps` | `float` | Numerical stability floor for HRC time reconstruction (default: 1e-7). | `1e-07` |
References
Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026), §4.4. Salimans & Kingma "Weight Normalization" (2016).
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
__call__ ¶
__call__(
x: Float[Array, "batch dim_plus_1"],
c_in: float = 1.0,
c_out: float = 1.0,
use_running_average: bool = False,
) -> Float[Array, "batch dim_plus_1"]
Apply mean-only batch normalization via HRC.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape (B, D+1) or (B, ..., D+1) | Input points on hyperboloid with curvature `c_in`. | required |
| `c_in` | `float` | Input curvature (default: 1.0). | `1.0` |
| `c_out` | `float` | Output curvature (default: 1.0). | `1.0` |
| `use_running_average` | `bool` | If True, use running mean (eval mode). If False, compute from batch and update running mean (train mode). | `False` |
Returns:

| Type | Description |
|---|---|
| Array, same shape as input | Points on hyperboloid with curvature `c_out`. |
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
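The `y = x - E[x] + bias` update above can be sketched in isolation. This illustrative NumPy snippet (not the library implementation) centers the spatial components and reconstructs the time coordinate via the hyperboloid constraint, mirroring how `FGGMeanOnlyBatchNorm` routes the operation through HRC:

```python
import numpy as np

def mean_only_bn_sketch(x, bias, c_out=1.0):
    """Mean-only batch norm on hyperboloid points (illustrative sketch).

    x: (B, D+1) points; column 0 is the time coordinate.
    """
    spatial = x[:, 1:]                                  # (B, D) spatial parts
    centered = spatial - spatial.mean(axis=0) + bias    # y = x - E[x] + bias
    # Reconstruct time so that -t^2 + ||s||^2 = -1/c_out
    time = np.sqrt((centered ** 2).sum(axis=1, keepdims=True) + 1.0 / c_out)
    return np.concatenate([time, centered], axis=1)

rng = np.random.default_rng(0)
spatial = rng.normal(size=(32, 8)) * 0.1
time = np.sqrt((spatial ** 2).sum(axis=1, keepdims=True) + 1.0)
x = np.concatenate([time, spatial], axis=1)

y = mean_only_bn_sketch(x, bias=np.zeros(8), c_out=1.0)
# Spatial components are centered and every output satisfies the constraint
minkowski = -y[:, 0] ** 2 + (y[:, 1:] ** 2).sum(axis=1)
assert np.allclose(y[:, 1:].mean(axis=0), 0.0, atol=1e-7)
assert np.allclose(minkowski, -1.0, atol=1e-6)
```

Note that no variance division appears anywhere, which is the point: the exponentially distributed variances of hyperbolic activations never enter a denominator.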
Hypformer Example¶
from hyperbolix.nn_layers import HTCLinear, HRCBatchNorm, HRCRMSNorm, hrc_relu
from hyperbolix.manifolds import Hyperboloid
from flax import nnx
import jax
import jax.numpy as jnp
hyperboloid = Hyperboloid()
class HypformerBlock(nnx.Module):
"""Example using HTC/HRC components with curvature change."""
def __init__(self, in_dim, out_dim, rngs):
self.linear = HTCLinear(
in_features=in_dim,
out_features=out_dim,
rngs=rngs
)
# Can use BatchNorm or RMSNorm for normalization
self.bn = HRCBatchNorm(num_features=out_dim, rngs=rngs)
# self.rms = HRCRMSNorm(num_features=out_dim, rngs=rngs) # Alternative: faster, simpler
def __call__(self, x, c_in=1.0, c_out=2.0, use_running_average=False):
# Linear transformation with curvature change
x = self.linear(x, c_in=c_in, c_out=c_out)
# Batch normalization (curvature-preserving)
x = self.bn(x, c_in=c_out, c_out=c_out,
use_running_average=use_running_average)
# Activation (curvature-preserving)
x = hrc_relu(x, c_in=c_out, c_out=c_out)
return x
# Create and use block
block = HypformerBlock(in_dim=33, out_dim=64, rngs=nnx.Rngs(0))
# Input on hyperboloid with curvature 1.0
x = jax.random.normal(jax.random.PRNGKey(1), (32, 33))
x_proj = jax.vmap(hyperboloid.proj, in_axes=(0, None))(x, 1.0)
# Transform to curvature 2.0
output = block(x_proj, c_in=1.0, c_out=2.0)
print(output.shape) # (32, 65) - 64 spatial + 1 time
HTC vs HRC

HRC (Hyperbolic Regularization Component):
- Applies a Euclidean function `f_r` to the space components only
- Use for: activations, normalization, dropout, convolutions
- Formula: `space = f_r(x_s)`, `time = sqrt(||space||^2 + 1/c_out)`

HTC (Hyperbolic Transformation Component):
- Applies a Euclidean function `f_t` to the full point (time + space)
- Use for: learnable linear transformations
- Formula: `space = f_t(x)`, `time = sqrt(||space||^2 + 1/c_out)`

Both support curvature changes (c_in → c_out) for flexible network design.
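The two wrapping patterns can be sketched side by side. This illustrative NumPy snippet (not the library code) applies a ReLU through the HRC pattern and a linear map through the HTC pattern, reconstructing the time coordinate in both cases:

```python
import numpy as np

def hrc_sketch(f_r, x, c_out=1.0):
    """HRC pattern (illustrative): apply f_r to space components, rebuild time."""
    space = f_r(x[..., 1:])                 # f_r sees space components only
    time = np.sqrt((space ** 2).sum(axis=-1, keepdims=True) + 1.0 / c_out)
    return np.concatenate([time, space], axis=-1)

def htc_sketch(f_t, x, c_out=1.0):
    """HTC pattern (illustrative): apply f_t to the full point, rebuild time."""
    space = f_t(x)                          # f_t sees time + space
    time = np.sqrt((space ** 2).sum(axis=-1, keepdims=True) + 1.0 / c_out)
    return np.concatenate([time, space], axis=-1)

spatial = np.random.default_rng(0).normal(size=(4, 8)) * 0.1
x = np.concatenate(
    [np.sqrt((spatial ** 2).sum(-1, keepdims=True) + 1.0), spatial], axis=-1)

# HRC with a curvature change to c_out = 2: constraint -t^2 + ||s||^2 = -1/2
y = hrc_sketch(lambda v: np.maximum(v, 0.0), x, c_out=2.0)
assert np.allclose(-y[..., 0] ** 2 + (y[..., 1:] ** 2).sum(-1), -0.5, atol=1e-6)

# HTC with a (D+1) -> D linear map, staying at c_out = 1
W = np.random.default_rng(1).normal(size=(9, 8)) * 0.1
z = htc_sketch(lambda v: v @ W, x, c_out=1.0)
assert np.allclose(-z[..., 0] ** 2 + (z[..., 1:] ** 2).sum(-1), -1.0, atol=1e-6)
```

In both cases the output lands exactly on the target hyperboloid by construction, which is why these components compose freely.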
Attention Layers¶
Three hyperbolic attention variants from the Hypformer paper (Yang et al. 2025, Section 4.3). All operate on hyperboloid points and support independent curvatures for input (c_in), attention computation (c_attn), and output (c_out). All variants support causal (autoregressive) masking via the causal=True flag, making them suitable for language models and sequence generation tasks.
Core Utilities¶
hyperbolix.nn_layers.spatial_to_hyperboloid ¶
spatial_to_hyperboloid(
spatial: Float[Array, "... D"],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... D_plus_1"]
Scale spatial components and reconstruct time to produce a hyperboloid point.
Extracts the common tail of hrc/htc: curvature scaling + time
reconstruction via the hyperboloid constraint.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `spatial` | Array of shape (..., D) | Spatial components (no time coordinate). | required |
| `c_in` | `float` | Source curvature (positive). | required |
| `c_out` | `float` | Target curvature (positive). | required |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., D+1) | Points on the hyperboloid with curvature `c_out`. |
Source code in hyperbolix/nn_layers/hyperboloid_core.py
hyperbolix.nn_layers.lorentz_midpoint ¶
lorentz_midpoint(
points: Float[Array, "... M A"],
weights: Float[Array, "... N M"],
c: float,
eps: float = 1e-07,
) -> Float[Array, "... N A"]
Weighted Lorentzian midpoint over M points.
Generalises `lorentz_residual` (which handles two points) to an arbitrary weighted combination, used by full attention aggregation and multi-head averaging.
Formula (HELM, Chen et al. 2024):
h = weights @ points (weighted sum)
mu = h / (sqrt(c) * ||h||_L)
where ||h||_L = sqrt(-<h,h>_L) and <h,h>_L = -h_0^2 + ||h_s||^2.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `points` | Array of shape (..., M, A) | Points on the hyperboloid with curvature `c`. | required |
| `weights` | Array of shape (..., N, M) | Combination weights (e.g. attention weights, uniform weights for averaging). | required |
| `c` | `float` | Curvature parameter (positive). | required |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., N, A) | Midpoints on the hyperboloid with curvature `c`. |
Source code in hyperbolix/nn_layers/hyperboloid_core.py
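The HELM midpoint formula above is short enough to check numerically. This illustrative NumPy sketch (not the library implementation) computes `mu = h / (sqrt(c) * ||h||_L)` for a batch of attention-style weights and verifies the result lands back on the hyperboloid:

```python
import numpy as np

def lorentz_midpoint_sketch(points, weights, c=1.0):
    """Weighted Lorentzian midpoint (illustrative sketch of the HELM formula).

    points: (M, A) hyperboloid points; weights: (N, M) combination weights.
    """
    h = weights @ points                     # (N, A) weighted sums
    # ||h||_L = sqrt(-<h,h>_L) with <h,h>_L = -h_0^2 + ||h_s||^2
    inner = -h[:, :1] ** 2 + (h[:, 1:] ** 2).sum(axis=1, keepdims=True)
    return h / (np.sqrt(c) * np.sqrt(-inner))

rng = np.random.default_rng(0)
spatial = rng.normal(size=(5, 8)) * 0.1
points = np.concatenate(
    [np.sqrt((spatial ** 2).sum(1, keepdims=True) + 1.0), spatial], axis=1)
weights = rng.random((3, 5))
weights /= weights.sum(axis=1, keepdims=True)   # e.g. softmax attention weights

mu = lorentz_midpoint_sketch(points, weights, c=1.0)
# Midpoints land back on the unit-curvature hyperboloid: <mu,mu>_L = -1/c
assert np.allclose(-mu[:, 0] ** 2 + (mu[:, 1:] ** 2).sum(1), -1.0, atol=1e-6)
```

Because a positive combination of future-pointing timelike vectors is again timelike, the Lorentzian norm in the denominator is always well defined for valid attention weights.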
hyperbolix.nn_layers.focus_transform ¶
focus_transform(
x_D: Float[Array, "... D"],
temperature: Float[Array, ""],
power: float,
eps: float = 1e-07,
) -> Float[Array, "... D"]
Norm-preserving focus function (Eq 19, Hypformer).
Applies temperature-scaled ReLU followed by element-wise power sharpening while preserving the original norm.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_D` | Array of shape (..., D) | Input spatial features. | required |
| `temperature` | scalar Array | Learnable temperature parameter. | required |
| `power` | `float` | Sharpening exponent. | required |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., D) | Focus-transformed features with the same norm as the input. |
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
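One plausible reading of the norm-preserving focus function, assuming the described pipeline of a temperature-scaled ReLU followed by an element-wise power and renormalization to the original norm (an illustrative sketch, not the library implementation):

```python
import numpy as np

def focus_sketch(x, temperature=1.0, power=2.0, eps=1e-7):
    """Norm-preserving focus (illustrative): sharpen, then restore the norm."""
    f = np.maximum(temperature * x, 0.0) ** power     # scaled ReLU + power
    orig_norm = np.sqrt((x ** 2).sum(axis=-1, keepdims=True))
    f_norm = np.sqrt((f ** 2).sum(axis=-1, keepdims=True))
    return f * orig_norm / (f_norm + eps)             # same norm as input

x = np.array([[1.0, 2.0, -1.0, 0.5]])
y = focus_sketch(x, temperature=1.0, power=2.0)
# Norm preserved; negative entries suppressed; large entries sharpened
assert np.allclose(np.linalg.norm(y), np.linalg.norm(x), atol=1e-4)
assert y[0, 2] == 0.0
```

Preserving the norm matters here because the spatial norm determines the time coordinate on the hyperboloid, so the focus function leaves the manifold geometry intact.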
Attention Modules¶
hyperbolix.nn_layers.HyperbolicLinearAttention ¶
HyperbolicLinearAttention(
in_features: int,
out_features: int,
*,
num_heads: int = 1,
power: float = 2.0,
init_bound: float = 0.02,
eps: float = 1e-07,
rngs: Rngs,
)
Bases: _HyperbolicAttentionBase
Hyperbolic linear attention with focus function (Eq 14-19).
The paper's main contribution: O(N) attention using the kernel trick in the spatial domain of the hyperboloid. Focus function φ sharpens query and key.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Ambient input dimension (d+1, including time component). | required |
| `out_features` | `int` | Spatial output dimension per head. | required |
| `num_heads` | `int` | Number of attention heads (default: 1). | `1` |
| `power` | `float` | Focus function sharpening exponent (default: 2.0). | `2.0` |
| `init_bound` | `float` | Uniform init bound for weights (default: 0.02). | `0.02` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
| `rngs` | `Rngs` | Random number generators. | required |
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
hyperbolix.nn_layers.HyperbolicSoftmaxAttention ¶
HyperbolicSoftmaxAttention(
in_features: int,
out_features: int,
*,
num_heads: int = 1,
init_bound: float = 0.02,
eps: float = 1e-07,
rngs: Rngs,
)
Bases: _HyperbolicAttentionBase
Hyperbolic softmax attention in the spatial domain.
Standard scaled dot-product attention applied to spatial components of query, key, value, followed by the same HRC pipeline (residual + time calibration) as the linear variant.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Ambient input dimension (d+1, including time component). | required |
| `out_features` | `int` | Spatial output dimension per head. | required |
| `num_heads` | `int` | Number of attention heads (default: 1). | `1` |
| `init_bound` | `float` | Uniform init bound for weights (default: 0.02). | `0.02` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
| `rngs` | `Rngs` | Random number generators. | required |
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
hyperbolix.nn_layers.HyperbolicFullAttention ¶
HyperbolicFullAttention(
in_features: int,
out_features: int,
*,
num_heads: int = 1,
init_bound: float = 0.02,
eps: float = 1e-07,
rngs: Rngs,
)
Bases: _HyperbolicAttentionBase
Full Lorentzian attention with midpoint aggregation.
Uses the Lorentzian inner product for similarity and weighted Lorentzian midpoint for aggregation — operating on full hyperboloid points throughout.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Ambient input dimension (d+1, including time component). | required |
| `out_features` | `int` | Spatial output dimension per head. | required |
| `num_heads` | `int` | Number of attention heads (default: 1). | `1` |
| `init_bound` | `float` | Uniform init bound for weights (default: 0.02). | `0.02` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
| `rngs` | `Rngs` | Random number generators. | required |
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
Attention Example¶
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds import Hyperboloid
from hyperbolix.nn_layers import (
HyperbolicLinearAttention,
HyperbolicSoftmaxAttention,
HyperbolicFullAttention,
)
hyperboloid = Hyperboloid()
# Input: (batch, seq_len, ambient_dim) on the hyperboloid
B, N, A_in, D_out = 4, 8, 9, 8 # 8-dim spatial + 1 time
key = jax.random.PRNGKey(0)
spatial = jax.random.normal(key, (B, N, A_in - 1)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([time, spatial], axis=-1) # (B, N, A_in)
# O(N) linear attention with focus function — fastest, main Hypformer contribution
linear_attn = HyperbolicLinearAttention(
in_features=A_in,
out_features=D_out,
num_heads=2,
power=2.0,
rngs=nnx.Rngs(0),
)
y = linear_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)
print(y.shape) # (4, 8, 9) — D_out spatial + 1 time
# O(N²) softmax attention in the spatial domain
softmax_attn = HyperbolicSoftmaxAttention(
in_features=A_in,
out_features=D_out,
num_heads=2,
rngs=nnx.Rngs(0),
)
y = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)
# O(N²) full Lorentzian attention — operates entirely on hyperboloid points
full_attn = HyperbolicFullAttention(
in_features=A_in,
out_features=D_out,
num_heads=2,
rngs=nnx.Rngs(0),
)
y = full_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)
# Verify outputs are on the hyperboloid
for b in range(B):
for n in range(N):
assert hyperboloid.is_in_manifold(y[b, n], c=1.0, atol=1e-4)
Causal (Autoregressive) Attention¶
All three variants support causal masking via causal=True. Position n can only attend to positions m ≤ n, which is required for autoregressive tasks like language modeling.
# Bidirectional (default) — each token attends to all tokens
y_bidir = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0, causal=False)
# Causal — position n only attends to positions 0..n
y_causal = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0, causal=True)
# Causal is JIT-compatible
@nnx.jit
def forward(model, inp):
return model(inp, c_in=1.0, c_attn=1.0, c_out=1.0, causal=True)
Causal masking implementations
The three variants implement causal masking differently:
- `HyperbolicSoftmaxAttention` and `HyperbolicFullAttention`: apply a lower-triangular `-inf` mask to the score matrix before softmax — O(N²) in both causal and non-causal mode.
- `HyperbolicLinearAttention`: uses a cumulative-sum recurrence (`jax.lax.scan`) following Katharopoulos et al. (2020): `S_i = Σ_{j≤i} φ(K_j) V_j^T` computed in O(1) per step → O(N) total, making it especially well-suited for long autoregressive sequences.
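The cumulative-state recurrence can be illustrated in plain NumPy. This is a Euclidean sketch of the Katharopoulos-style recurrence, not the hyperbolic layer itself; `phi` is a stand-in positive feature map:

```python
import numpy as np

def causal_linear_attention_sketch(phi_q, phi_k, v, eps=1e-7):
    """Causal linear attention recurrence (Euclidean sketch).

    out_i = phi(Q_i)^T S_i / (phi(Q_i)^T z_i), with
    S_i = sum_{j<=i} phi(K_j) V_j^T and z_i = sum_{j<=i} phi(K_j).
    """
    N, d = phi_k.shape
    dv = v.shape[1]
    S = np.zeros((d, dv))
    z = np.zeros(d)
    out = np.zeros((N, dv))
    for i in range(N):                      # O(1) work per step -> O(N) total
        S += np.outer(phi_k[i], v[i])
        z += phi_k[i]
        out[i] = phi_q[i] @ S / (phi_q[i] @ z + eps)
    return out

rng = np.random.default_rng(0)
N, d, dv = 6, 4, 3
phi = lambda x: np.maximum(x, 0.0) + 1e-3   # stand-in positive feature map
q, k = phi(rng.normal(size=(N, d))), phi(rng.normal(size=(N, d)))
v = rng.normal(size=(N, dv))

out = causal_linear_attention_sketch(q, k, v)
# Row i matches explicit masked attention over positions 0..i
scores = q @ k.T
mask = np.tril(np.ones((N, N)))
ref = (scores * mask) @ v / ((scores * mask).sum(axis=1, keepdims=True) + 1e-7)
assert np.allclose(out, ref, atol=1e-5)
```

The library replaces the Python loop with `jax.lax.scan` and wraps the result in the HRC pipeline, but the O(N) structure is the same.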
Choosing an Attention Variant
| Variant | Complexity | Causal complexity | Mechanism | Best for |
|---|---|---|---|---|
| `HyperbolicLinearAttention` | O(N) | O(N) | Kernel trick + focus function φ | Long sequences, autoregressive models |
| `HyperbolicSoftmaxAttention` | O(N²) | O(N²) | Standard softmax on spatial components | Short sequences, simplicity |
| `HyperbolicFullAttention` | O(N²) | O(N²) | Lorentzian inner product + midpoint | Maximum geometric fidelity |
All variants support independent curvatures: c_in for input, c_attn for Q/K/V projections, c_out for output.
Positional Encoding¶
Positional encoding layers for hyperbolic Transformers and attention mechanisms. These layers enable position-aware models on the hyperboloid manifold while preserving geometric structure.
Lorentzian Residual Connection¶
hyperbolix.nn_layers.lorentz_residual ¶
lorentz_residual(
x: Float[Array, "... dim_plus_1"],
y: Float[Array, "... dim_plus_1"],
w_y: float | Float[Array, ""],
c: float,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
Lorentzian midpoint-based residual connection (LResNet from HELM).
Computes the weighted Lorentzian midpoint of x and y, projecting back to the hyperboloid:
ave = x + w_y * y
result = ave / sqrt(c * |<ave, ave>_L|)
where `<a, a>_L = -a_0^2 + ||a_s||^2` is the Minkowski inner product.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape (..., d+1) | Points on hyperboloid with curvature `c`. | required |
| `y` | Array of shape (..., d+1) | Points on hyperboloid with curvature `c` (added with weight `w_y`). | required |
| `w_y` | `float` or scalar Array | Weight for the `y` contribution. | required |
| `c` | `float` | Curvature parameter (positive, c > 0). | required |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., d+1) | Points on hyperboloid with curvature `c`. |
References
Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.
Source code in hyperbolix/nn_layers/hyperboloid_core.py
HOPE (Hyperbolic Rotary Positional Encoding)¶
hyperbolix.nn_layers.hope ¶
hope(
z: Float[Array, "... seq d_plus_1"],
positions: Float[Array, seq],
c: float = 1.0,
base: float = 10000.0,
eps: float = 1e-07,
) -> Float[Array, "... seq d_plus_1"]
Hyperbolic Rotary Positional Encoding (HOPE).
Applies RoPE-style rotation to the spatial components of hyperboloid
points, then reconstructs the time component to satisfy the manifold
constraint. Equivalent to hrc(z, R_{i,Theta}, c, c) where R is a
block-diagonal rotation matrix.
Since rotation preserves norms, the Minkowski inner product between encoded points depends only on the relative position offset, giving the standard RoPE relative-position property on the hyperboloid.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `z` | Array of shape (..., seq_len, d+1) | Points on hyperboloid (d must be even). | required |
| `positions` | Array of shape (seq_len,) | Integer position indices. | required |
| `c` | `float` | Curvature parameter (default: 1.0). | `1.0` |
| `base` | `float` | Frequency base for rotation angles (default: 10000.0). | `10000.0` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., seq_len, d+1) | Rotated points on hyperboloid with curvature `c`. |
References
Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.
Source code in hyperbolix/nn_layers/hyperboloid_positional.py
hyperbolix.nn_layers.HyperbolicRoPE ¶
Bases: Module
NNX module wrapper for HOPE (Hyperbolic Rotary Positional Encoding).
This is a stateless module (no learnable parameters) that wraps the functional `hope` for convenient use in NNX model definitions.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | Spatial dimension d (must be even). | required |
| `max_seq_len` | `int` | Maximum sequence length (for documentation; not enforced, default: 2048). | `2048` |
| `base` | `float` | Frequency base for rotation angles (default: 10000.0). | `10000.0` |
| `eps` | `float` | Numerical stability floor (default: 1e-7). | `1e-07` |
References
Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.
Source code in hyperbolix/nn_layers/hyperboloid_positional.py
__call__ ¶
__call__(
z: Float[Array, "... seq d_plus_1"],
positions: Float[Array, seq],
c: float = 1.0,
) -> Float[Array, "... seq d_plus_1"]
Apply HOPE positional encoding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `z` | Array of shape (..., seq_len, d+1) | Points on hyperboloid (spatial dim must equal `self.dim`, must be even). | required |
| `positions` | Array of shape (seq_len,) | Integer position indices. | required |
| `c` | `float` | Curvature parameter (default: 1.0). | `1.0` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., seq_len, d+1) | Rotated points on hyperboloid. |
Source code in hyperbolix/nn_layers/hyperboloid_positional.py
Hypformer Positional Encoding¶
hyperbolix.nn_layers.HypformerPositionalEncoding ¶
HypformerPositionalEncoding(
in_features: int,
out_features: int,
*,
rngs: Rngs,
init_bound: float = 0.02,
eps: float = 1e-07,
)
Bases: Module
Learnable relative positional encoding from Hypformer.
Computes a position vector via HTCLinear, then combines it with the input using a Lorentzian residual connection:
p = HTCLinear(x)
result = lorentz_residual(x, p, w_y=epsilon, c=c)
where epsilon is a learnable scalar magnitude parameter.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input ambient dimension (d+1, including time component). | required |
| `out_features` | `int` | Output spatial dimension (d). The HTCLinear output will have ambient dimension d+1 (= out_features + 1), matching the input. | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization. | required |
| `init_bound` | `float` | Bound for HTCLinear uniform weight initialization (default: 0.02). | `0.02` |
| `eps` | `float` | Numerical stability floor for `lorentz_residual` (default: 1e-7). | `1e-07` |
Attributes:

| Name | Type | Description |
|---|---|---|
| `htc_linear` | `HTCLinear` | Linear transformation producing the position encoding vector. |
| `epsilon` | `Param` | Learnable scalar weight for the position encoding contribution. |
| `eps` | `float` | Numerical stability parameter. |
References
Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024. Hypformer paper (citation to be added).
Source code in hyperbolix/nn_layers/hyperboloid_positional.py
__call__ ¶
Apply learnable positional encoding.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape (..., d+1) | Points on hyperboloid with curvature `c`. | required |
| `c` | `float` | Curvature parameter (default: 1.0). | `1.0` |
Returns:

| Type | Description |
|---|---|
| Array of shape (..., d+1) | Positionally-encoded points on hyperboloid with curvature `c`. |
Source code in hyperbolix/nn_layers/hyperboloid_positional.py
HOPE Example¶
from hyperbolix.nn_layers import hope, HyperbolicRoPE
from hyperbolix.manifolds import Hyperboloid
import jax.numpy as jnp
import jax
hyperboloid = Hyperboloid()
# Create sequence of hyperboloid points (batch, seq_len, d+1)
key = jax.random.PRNGKey(42)
batch, seq_len, d = 4, 16, 8
spatial = jax.random.normal(key, (batch, seq_len, d)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
z = jnp.concatenate([time, spatial], axis=-1) # (4, 16, 9)
# Position indices
positions = jnp.arange(seq_len)
# Apply HOPE (functional interface)
z_encoded = hope(z, positions, c=1.0)
print(z_encoded.shape) # (4, 16, 9)
# Or use the NNX module wrapper
from flax import nnx
rope = HyperbolicRoPE(dim=d, max_seq_len=64, base=10000.0)
z_encoded = rope(z, positions, c=1.0)
# Verify manifold constraint
for b in range(batch):
for s in range(seq_len):
assert hyperboloid.is_in_manifold(z_encoded[b, s], c=1.0, atol=1e-4)
# Verify relative position property: <HOPE(q,i), HOPE(k,j)>_L depends only on i-j
q = z[0, 0] # Single point
k = z[0, 1] # Another point
# Same relative offset (=3) at different absolute positions
q_enc_0 = hope(q[None, None, :], jnp.array([0]), c=1.0)[0, 0]
k_enc_3 = hope(k[None, None, :], jnp.array([3]), c=1.0)[0, 0]
q_enc_10 = hope(q[None, None, :], jnp.array([10]), c=1.0)[0, 0]
k_enc_13 = hope(k[None, None, :], jnp.array([13]), c=1.0)[0, 0]
# Minkowski inner products should be equal
def minkowski_inner(x, y):
return -x[0]*y[0] + jnp.sum(x[1:]*y[1:])
ip1 = minkowski_inner(q_enc_0, k_enc_3)
ip2 = minkowski_inner(q_enc_10, k_enc_13)
print(jnp.allclose(ip1, ip2, atol=1e-5)) # True
Hypformer Positional Encoding Example¶
from hyperbolix.nn_layers import HypformerPositionalEncoding
from hyperbolix.manifolds import Hyperboloid
from flax import nnx
hyperboloid = Hyperboloid()
import jax.numpy as jnp
import jax
# Create positional encoding layer
d = 8 # spatial dimension
in_features = d + 1 # ambient dimension (including time)
pe = HypformerPositionalEncoding(
in_features=in_features,
out_features=d, # output spatial dimension
rngs=nnx.Rngs(0),
init_bound=0.02 # small initialization for stability
)
# Input: batch of hyperboloid points
key = jax.random.PRNGKey(42)
batch_size = 32
spatial = jax.random.normal(key, (batch_size, d)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([time, spatial], axis=-1) # (32, 9)
# Apply learnable positional encoding
x_encoded = pe(x, c=1.0)
print(x_encoded.shape) # (32, 9) - shape preserved
# Verify manifold constraint
is_valid = jax.vmap(hyperboloid.is_in_manifold, in_axes=(0, None))(x_encoded, 1.0)
print(is_valid.all()) # True
# The epsilon parameter is learnable
print(f"Epsilon: {pe.epsilon.value}") # Initially 1.0
# Use in training loop with gradient updates
import optax
optimizer = nnx.Optimizer(pe, optax.adam(1e-3), wrt=nnx.Param)
def loss_fn(model):
out = model(x, c=1.0)
return jnp.sum(out**2) # Dummy loss
loss, grads = nnx.value_and_grad(loss_fn)(pe)
optimizer.update(pe, grads)
print(f"Epsilon after update: {pe.epsilon.value}") # Changed
Lorentzian Residual Example¶
from hyperbolix.nn_layers import lorentz_residual
from hyperbolix.manifolds import Hyperboloid
import jax.numpy as jnp
import jax
hyperboloid = Hyperboloid()
# Create two hyperboloid points
key1, key2 = jax.random.split(jax.random.PRNGKey(42))
d = 6
c = 1.0
spatial_x = jax.random.normal(key1, (d,)) * 0.1
time_x = jnp.sqrt(jnp.sum(spatial_x**2) + 1/c)
x = jnp.concatenate([time_x[None], spatial_x])
spatial_y = jax.random.normal(key2, (d,)) * 0.1
time_y = jnp.sqrt(jnp.sum(spatial_y**2) + 1/c)
y = jnp.concatenate([time_y[None], spatial_y])
# Combine with Lorentzian residual (weighted midpoint)
result = lorentz_residual(x, y, w_y=0.5, c=c)
# Verify output is on hyperboloid
assert hyperboloid.is_in_manifold(result, c, atol=1e-5)
# Works with batches too
x_batch = jax.random.normal(jax.random.PRNGKey(0), (8, d+1))
y_batch = jax.random.normal(jax.random.PRNGKey(1), (8, d+1))
# Project to manifold
x_batch = jax.vmap(hyperboloid.proj, in_axes=(0, None))(x_batch, c)
y_batch = jax.vmap(hyperboloid.proj, in_axes=(0, None))(y_batch, c)
# Apply residual connection
result_batch = lorentz_residual(x_batch, y_batch, w_y=0.3, c=c)
print(result_batch.shape) # (8, 7)
Positional Encoding for Hyperbolic Transformers

HOPE (Hyperbolic Rotary Positional Encoding):
- Deterministic, no learnable parameters
- Based on RoPE: applies rotations to spatial components
- Preserves relative position information: `⟨HOPE(q,i), HOPE(k,j)⟩_L` depends only on `i-j`
- Rotation is an isometry: preserves spatial norms
- Identity at position 0
- Suitable for long sequences (no learned embeddings to store)

HypformerPositionalEncoding:
- Learnable, adapts to the task
- Uses HTCLinear + Lorentzian residual
- `epsilon` parameter controls position encoding magnitude
- More flexible but requires training
- Suitable when position patterns are task-specific

Lorentzian Residual:
- Building block for both approaches
- Computes the weighted Lorentzian midpoint
- Used for skip connections in hyperbolic Transformers/ResNets
- Formula: `ave = x + w_y*y`, normalized to the hyperboloid
Hybrid Euclidean-Hyperbolic¶
Hyper++ Feature Scaling¶
hyperbolix.nn_layers.HyperPPFeatureScaling ¶
HyperPPFeatureScaling(
dim: int,
*,
alpha: float | None = None,
activation: Callable | None = jax.nn.tanh,
rngs: Rngs,
)
Bases: Module
Hyper++ feature scaling for hybrid Euclidean-hyperbolic networks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `dim` | `int` | Embedding dimension (needed for `nnx.Linear` and `1/sqrt(d)` scaling). | required |
| `alpha` | `float` or None | Value in (0, 1) enabling learned rescaling. None disables it and the layer is entirely parameter-free. | `None` |
| `activation` | Callable or None | Lipschitz activation function (default: `jax.nn.tanh`). None to skip. | `tanh` |
| `rngs` | `Rngs` | Random number generators for sub-layers. | required |
Examples:
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HyperPPFeatureScaling
>>>
>>> # Parameter-free mode
>>> layer = HyperPPFeatureScaling(dim=64, rngs=nnx.Rngs(0))
>>> y = layer(x, c=1.0)
>>>
>>> # Learned rescaling mode
>>> layer = HyperPPFeatureScaling(dim=64, alpha=0.9, rngs=nnx.Rngs(0))
>>> y = layer(x, c=0.1)
Source code in hyperbolix/nn_layers/hybrid_regularization.py
__call__ ¶
Apply Hyper++ feature scaling pipeline.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_BD` | Array of shape (B, D) | Euclidean features from the last Euclidean layer. | required |
| `c` | `float` | Curvature parameter (positive). Passed at call time per library convention to support learnable curvature. | `1.0` |
Returns:

| Type | Description |
|---|---|
| Array of shape (B, D) | Scaled features ready for `expmap_0`. |
Source code in hyperbolix/nn_layers/hybrid_regularization.py
HyperPPFeatureScaling is applied after the last Euclidean layer and before expmap_0 to either the Poincaré ball or Hyperboloid. It operates entirely in Euclidean space but uses hyperbolic geometry to bound the output norm via rho_max = atanh(alpha) / sqrt(c).
Usage Example¶
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.nn_layers import HyperPPFeatureScaling
from hyperbolix.manifolds import Poincare
poincare = Poincare()
# Parameter-free mode (RMSNorm + tanh + dim scaling only)
layer = HyperPPFeatureScaling(dim=64, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.PRNGKey(0), (32, 64))
scaled = layer(x, c=1.0)
# With learned rescaling (adds rho_max * sigmoid(xi(x)) * x)
layer = HyperPPFeatureScaling(dim=64, alpha=0.9, rngs=nnx.Rngs(0))
scaled = layer(x, c=0.1)
# Map to Poincaré ball
expmap_batch = jax.vmap(poincare.expmap_0, in_axes=(0, None))
points = expmap_batch(scaled, 0.1)
Pipeline Steps

1. RMSNorm (parameter-free): normalizes feature magnitudes
2. Lipschitz activation (default `tanh`, configurable): bounds per-component values
3. Dimension scaling (`1/sqrt(d)`): ensures the norm doesn't grow with dimension
4. Learned rescaling (when `alpha` is set): `rho_max * sigmoid(xi_theta(x)) * x` where `rho_max = atanh(alpha) / sqrt(c)`
When alpha is None, the layer is entirely parameter-free. When set, only xi_theta (a linear projection to scalar) has learnable parameters.
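The parameter-free path of the pipeline can be sketched directly from the step descriptions. This illustrative NumPy snippet (not the library implementation, and assuming the steps compose in the listed order) shows why the output norm stays bounded before `expmap_0`:

```python
import numpy as np

def hyperpp_scaling_sketch(x, eps=1e-6):
    """Parameter-free Hyper++ feature scaling (illustrative sketch).

    RMSNorm -> tanh -> 1/sqrt(d) scaling, bounding ||output|| <= 1.
    """
    d = x.shape[-1]
    rms = np.sqrt((x ** 2).mean(axis=-1, keepdims=True) + eps)
    h = np.tanh(x / rms)                  # per-component values in (-1, 1)
    return h / np.sqrt(d)                 # norm no longer grows with d

# Even large-magnitude inputs come out with norm at most 1
x = np.random.default_rng(0).normal(size=(32, 64)) * 5.0
scaled = hyperpp_scaling_sketch(x)
assert (np.linalg.norm(scaled, axis=-1) <= 1.0 + 1e-9).all()
```

A bounded pre-`expmap_0` norm is what keeps the mapped points away from the numerically unstable boundary region of the Poincaré ball.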
Regression Layers¶
Single-layer classifiers with Riemannian geometry.
Poincaré Regression¶
hyperbolix.nn_layers.HypRegressionPoincare ¶
HypRegressionPoincare(
manifold_module: Poincare,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic Neural Networks multinomial linear regression layer (Poincaré ball model).
Computation steps:
0. Project the input tensor onto the manifold (optional)
1. Compute the multinomial linear regression score(s)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Poincare manifold instance | required |
| `in_dim` | `int` | Dimension of the input space | required |
| `out_dim` | `int` | Dimension of the output space | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization | required |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: this is a static configuration; changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output (default: 1.0) | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output (default: 50.0) | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.
References
Octavian Ganea, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in Neural Information Processing Systems 31 (2018).
Source code in hyperbolix/nn_layers/poincare_regression.py
__call__ ¶
Forward pass through the hyperbolic regression layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(batch, in_dim)` | Input tensor where the `hyperbolic_axis` is last | required |
| `c` | `float` | Manifold curvature | `1.0` |

Returns:

| Name | Type | Description |
|---|---|---|
| `res` | Array of shape `(batch, out_dim)` | Multinomial linear regression scores |
Source code in hyperbolix/nn_layers/poincare_regression.py
hyperbolix.nn_layers.HypRegressionPoincarePP ¶
HypRegressionPoincarePP(
manifold_module: Poincare,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Hyperbolic Neural Networks ++ multinomial linear regression layer (Poincaré ball model).
Computation steps: 0) Project the input tensor onto the manifold (optional) 1) Compute the multinomial linear regression score(s)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Poincare manifold instance | required |
| `in_dim` | `int` | Dimension of the input space | required |
| `out_dim` | `int` | Dimension of the output space | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization | required |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold'. Note: this is a static configuration; changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.
References
Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).
Source code in hyperbolix/nn_layers/poincare_regression.py
__call__ ¶
Forward pass through the HNN++ hyperbolic regression layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(batch, in_dim)` | Input tensor where the `hyperbolic_axis` is last | required |
| `c` | `float` | Manifold curvature | `1.0` |

Returns:

| Name | Type | Description |
|---|---|---|
| `res` | Array of shape `(batch, out_dim)` | Multinomial linear regression scores |
Source code in hyperbolix/nn_layers/poincare_regression.py
Hyperboloid Regression¶
hyperbolix.nn_layers.HypRegressionHyperboloid ¶
HypRegressionHyperboloid(
manifold_module: Hyperboloid,
in_dim: int,
out_dim: int,
*,
rngs: Rngs,
input_space: str = "manifold",
clamping_factor: float = 1.0,
smoothing_factor: float = 50.0,
)
Bases: Module
Fully Hyperbolic Convolutional Neural Networks multinomial linear regression layer (Hyperboloid model).
Computation steps: 0) Project the input tensor onto the manifold (optional) 1) Compute the multinomial linear regression score(s)
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `manifold_module` | `object` | Class-based Hyperboloid manifold instance | required |
| `in_dim` | `int` | Dimension of the input space | required |
| `out_dim` | `int` | Dimension of the output space | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization | required |
| `input_space` | `str` | Type of the input tensor, either 'tangent' or 'manifold'. Note: this is a static configuration; changing it after initialization requires recompilation. | `'manifold'` |
| `clamping_factor` | `float` | Clamping factor for the multinomial linear regression output | `1.0` |
| `smoothing_factor` | `float` | Smoothing factor for the multinomial linear regression output | `50.0` |
Notes
JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Source code in hyperbolix/nn_layers/hyperboloid_regression.py
__call__ ¶
Forward pass through the hyperbolic regression layer.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(batch, in_dim)` | Input tensor where the `hyperbolic_axis` is last | required |
| `c` | `float` | Manifold curvature | `1.0` |

Returns:

| Name | Type | Description |
|---|---|---|
| `res` | Array of shape `(batch, out_dim)` | Multinomial linear regression scores |
Source code in hyperbolix/nn_layers/hyperboloid_regression.py
FGG Lorentz MLR (Klis et al. 2026)¶
hyperbolix.nn_layers.FGGLorentzMLR ¶
FGGLorentzMLR(
in_features: int,
num_classes: int,
*,
rngs: Rngs,
reset_params: str = "mlr",
init_bias: float = 0.5,
eps: float = 1e-07,
)
Bases: Module
Fast and Geometrically Grounded Lorentz multinomial logistic regression.
Outputs Euclidean logits (signed scaled distances to hyperplanes) using the
FGG spacelike V construction. Unlike HypRegressionHyperboloid, this layer
uses the distance-to-hyperplane formulation matching the reference fc_mlr
(signed_dist2hyperplanes_scaled_angle).
Forward pass (matching reference fc_mlr):

1. Build `V_mink` from `(z, a)`
2. `mink = x @ V_mink` (Minkowski inner products)
3. `logits = asinh(sqrt(c) * mink) / sqrt(c)` (signed scaled distances)
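Steps 2-3 above can be sketched in plain JAX. The exact construction of `V_mink` from `(z, a)` is internal to the layer, so a random stand-in matrix `V` is used here purely for illustration:

```python
import jax
import jax.numpy as jnp

def signed_scaled_distances(x_BAi, V_AiK, c=1.0):
    # Step 2: Minkowski inner products <x, v_k>_L = -x0*v0 + <x_s, v_s>
    signs = jnp.ones(x_BAi.shape[-1]).at[0].set(-1.0)
    mink_BK = (x_BAi * signs) @ V_AiK
    # Step 3: signed scaled distances to the hyperplanes
    return jnp.arcsinh(jnp.sqrt(c) * mink_BK) / jnp.sqrt(c)

key1, key2 = jax.random.split(jax.random.PRNGKey(0))
x_s = jax.random.normal(key1, (8, 4))
# lift spatial parts onto the c=1 hyperboloid
x = jnp.concatenate([jnp.sqrt(jnp.sum(x_s**2, -1, keepdims=True) + 1.0), x_s], -1)
V = jax.random.normal(key2, (5, 3))  # hypothetical stand-in for V_mink
logits = signed_scaled_distances(x, V, c=1.0)
print(logits.shape)  # (8, 3)
```

Because `arcsinh` is odd and monotonic, the sign of each logit matches the sign of the corresponding Minkowski inner product, i.e. which side of the hyperplane the point lies on.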
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `in_features` | `int` | Input ambient dimension (D_in + 1), including the time component | required |
| `num_classes` | `int` | Number of output classes | required |
| `rngs` | `Rngs` | Random number generators for parameter initialization | required |
| `reset_params` | `str` | Weight initialization scheme for hyperplane normals | `'mlr'` |
| `init_bias` | `float` | Initial value for bias entries | `0.5` |
| `eps` | `float` | Numerical stability floor | `1e-07` |
References
Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026), Eq. 23.
Source code in hyperbolix/nn_layers/hyperboloid_regression.py
__call__ ¶
__call__(
x_BAi: Float[Array, "batch in_features"], c: float = 1.0
) -> Float[Array, "batch num_classes"]
Forward pass returning Euclidean logits.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x_BAi` | Array of shape `(B, Ai)` | Input points on the hyperboloid with curvature `c` | required |
| `c` | `float` | Curvature parameter | `1.0` |

Returns:

| Name | Type | Description |
|---|---|---|
| `logits_BK` | Array of shape `(B, K)` | Euclidean logits (signed scaled distances to hyperplanes) |
Source code in hyperbolix/nn_layers/hyperboloid_regression.py
FGG Usage Example¶
```python
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds import Hyperboloid
from hyperbolix.nn_layers import FGGLinear, FGGConv2D, FGGLorentzMLR, FGGMeanOnlyBatchNorm

hyperboloid = Hyperboloid()
rngs = nnx.Rngs(0)

# --- FGGLinear: FC layer with linear hyperbolic distance growth ---
linear = FGGLinear(
    in_features=33,   # 32 spatial + 1 time
    out_features=65,  # 64 spatial + 1 time
    rngs=rngs,
    activation=jax.nn.relu,
    reset_params="lorentz_kaiming",
)
x_B33 = jnp.ones((8, 33))
x_B33 = x_B33.at[:, 0].set(jnp.sqrt(jnp.sum(x_B33[:, 1:]**2, axis=-1) + 1.0))
y_B65 = linear(x_B33, c=1.0)
print(y_B65.shape)  # (8, 65)

# --- FGGConv2D: 2D conv with manifold-origin padding ---
conv = FGGConv2D(
    manifold_module=hyperboloid,
    in_channels=33,
    out_channels=65,
    kernel_size=3,
    rngs=rngs,
    activation=jax.nn.relu,
    pad_mode="origin",  # manifold origin padding (reference default)
)
x_BHWC = jnp.zeros((4, 14, 14, 33))
x_BHWC = x_BHWC.at[..., 0].set(1.0)  # valid origin point at c=1
y_BHWC = conv(x_BHWC, c=1.0)
print(y_BHWC.shape)  # (4, 14, 14, 65)

# --- FGGLorentzMLR: classification head ---
mlr = FGGLorentzMLR(
    in_features=65,
    num_classes=10,
    rngs=rngs,
    reset_params="mlr",  # N(0, sqrt(5/Ai)) where Ai = in_features (ambient)
    init_bias=0.5,
)
logits_B10 = mlr(y_BHWC.reshape(-1, 65), c=1.0)
print(logits_B10.shape)  # (784, 10)

# --- FGGMeanOnlyBatchNorm: pairs with FGGLinear(use_weight_norm=True) ---
# num_features is the SPATIAL (out) dimension
bn = FGGMeanOnlyBatchNorm(num_features=64, rngs=rngs)
y_normed = bn(y_B65, c_in=1.0, c_out=1.0, use_running_average=False)
print(y_normed.shape)  # (8, 65)
```
FGG Layer Family
The four FGG components from Klis et al. (2026) form a complete layer stack:
| Layer | Role | Key params |
|---|---|---|
| `FGGLinear` | Fully-connected | `reset_params`, `use_weight_norm`, `activation` |
| `FGGConv2D` | 2D convolution | `pad_mode="origin"`, wraps `FGGLinear` |
| `FGGLorentzMLR` | Classification head | `reset_params="mlr"`, `init_bias=0.5` |
| `FGGMeanOnlyBatchNorm` | Batch normalization | pairs with `use_weight_norm=True` |
Core insight: the sinh/arcsinh cancellation in the Lorentzian activation chain reduces the forward pass to a single matmul with a spacelike V matrix, Euclidean activation, then time reconstruction — achieving linear (not logarithmic) growth of hyperbolic distance.
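That reduced forward pass can be sketched on the spatial components alone. This is a sketch of the pattern with a generic weight matrix `W`, not the library's `FGGLinear`, which additionally handles the time input component, weight norm, and initialization:

```python
import jax
import jax.numpy as jnp

def fgg_block_sketch(x_s, W, c=1.0, act=jax.nn.relu):
    """Single Euclidean matmul + activation on the space part, then time reconstruction."""
    y_s = act(x_s @ W)  # one matmul, Euclidean activation
    y_t = jnp.sqrt(jnp.sum(y_s**2, axis=-1, keepdims=True) + 1.0 / c)  # Lorentz constraint
    return jnp.concatenate([y_t, y_s], axis=-1)

x_s = jax.random.normal(jax.random.PRNGKey(0), (8, 16))
W = jax.random.normal(jax.random.PRNGKey(1), (16, 32)) * 0.1
y = fgg_block_sketch(x_s, W, c=1.0)
# output satisfies the hyperboloid constraint -y0^2 + ||y_s||^2 = -1/c by construction
constraint = -y[:, 0]**2 + jnp.sum(y[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -1.0, atol=1e-5))  # True
```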
Regression Example¶
```python
import jax
from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.nn_layers import HypRegressionPoincare

poincare = Poincare()

# Multi-class classification (10 classes)
regressor = HypRegressionPoincare(
    manifold_module=poincare,
    in_dim=32,
    out_dim=10,
    rngs=nnx.Rngs(0),
)

# Input: hyperbolic embeddings
x = jax.random.normal(jax.random.PRNGKey(1), (64, 32)) * 0.3
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)

# Forward pass returns logits
logits = regressor(x_proj, c=1.0)
print(logits.shape)  # (64, 10)

# Use with softmax for classification
probs = jax.nn.softmax(logits, axis=-1)
```
Activation Functions¶
Hyperbolic activation functions that preserve manifold constraints. All activations follow the HRC pattern: apply function to space components, then reconstruct time.
Curvature-Preserving Activations¶
hyperbolix.nn_layers.hyp_relu ¶
Apply ReLU activation to space components of hyperboloid point(s).
Curvature-preserving wrapper around hrc_relu(x, c_in=c, c_out=c).
This function applies the ReLU activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.
Mathematical formula: y = [sqrt(||ReLU(x_s)||^2 + 1/c), ReLU(x_s)]
where x_s are the spatial components x[..., 1:].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold in ambient space, where `...` represents arbitrary batch dimensions. The last dimension contains the time component (`x[..., 0]`) and spatial components (`x[..., 1:]`). | required |
| `c` | `float` | Curvature parameter, must be positive (c > 0) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold, same shape as input |
Notes
- This function applies ReLU only to spatial components, not the time component
- The time component is reconstructed using the hyperboloid constraint: -x₀² + ||x_rest||² = -1/c
- This approach avoids frequent exp/log maps for better numerical stability
- Works on arrays of any shape, similar to jax.nn.relu
- For curvature-changing transformations, use `hrc_relu`, which supports different input/output curvatures
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_relu
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_relu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5)) # 8 points in 5-dim ambient space
>>> y_batch = hyp_relu(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch (e.g., feature maps)
>>> x_feature = jnp.ones((4, 16, 16, 10)) # 4 images, 16x16 spatial, 10-dim
>>> y_feature = hyp_relu(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hyp_leaky_relu ¶
hyp_leaky_relu(
x: Float[Array, "... dim_plus_1"],
c: float,
negative_slope: float = 0.01,
) -> Float[Array, "... dim_plus_1"]
Apply LeakyReLU activation to space components of hyperboloid point(s).
Curvature-preserving wrapper around hrc_leaky_relu(x, c_in=c, c_out=c, negative_slope).
This function applies the LeakyReLU activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.
Mathematical formula: y = [sqrt(||LeakyReLU(x_s)||^2 + 1/c), LeakyReLU(x_s)]
where x_s are the spatial components x[..., 1:], and LeakyReLU(x) = x if x > 0 else negative_slope * x.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold in ambient space, where `...` represents arbitrary batch dimensions. The last dimension contains the time component (`x[..., 0]`) and spatial components (`x[..., 1:]`). | required |
| `c` | `float` | Curvature parameter, must be positive (c > 0) | required |
| `negative_slope` | `float` | Negative slope coefficient for LeakyReLU | `0.01` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold, same shape as input |
Notes
- This function applies LeakyReLU only to spatial components
- The time component is reconstructed using the hyperboloid constraint
- LeakyReLU allows small negative values (scaled by negative_slope) which can help gradient flow compared to standard ReLU
- Works on arrays of any shape, similar to jax.nn.leaky_relu
- For curvature-changing transformations, use `hrc_leaky_relu`, which supports different input/output curvatures
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_leaky_relu
>>>
>>> # Single point with default negative_slope
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_leaky_relu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Custom negative_slope
>>> y = hyp_leaky_relu(x, c=1.0, negative_slope=0.1)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_leaky_relu(x_batch, c=1.0, negative_slope=0.01)
>>> y_batch.shape
(8, 5)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hyp_tanh ¶
Apply tanh activation to space components of hyperboloid point(s).
Curvature-preserving wrapper around hrc_tanh(x, c_in=c, c_out=c).
This function applies the hyperbolic tangent activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.
Mathematical formula: y = [sqrt(||tanh(x_s)||^2 + 1/c), tanh(x_s)]
where x_s are the spatial components x[..., 1:].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold in ambient space, where `...` represents arbitrary batch dimensions. The last dimension contains the time component (`x[..., 0]`) and spatial components (`x[..., 1:]`). | required |
| `c` | `float` | Curvature parameter, must be positive (c > 0) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold, same shape as input |
Notes
- This function applies tanh only to spatial components
- The time component is reconstructed using the hyperboloid constraint
- Tanh naturally bounds outputs in [-1, 1], which can help with stability
- Works on arrays of any shape, similar to jax.nn.tanh
- For curvature-changing transformations, use `hrc_tanh`, which supports different input/output curvatures
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_tanh
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_tanh(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_tanh(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Verify spatial components are bounded
>>> import jax
>>> assert jnp.all(jnp.abs(y_batch[..., 1:]) <= 1.0)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hyp_swish ¶
Apply Swish/SiLU activation to space components of hyperboloid point(s).
Curvature-preserving wrapper around hrc_swish(x, c_in=c, c_out=c).
This function applies the Swish (also known as SiLU) activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.
Swish is defined as: swish(x) = x * sigmoid(x)
Mathematical formula: y = [sqrt(||swish(x_s)||^2 + 1/c), swish(x_s)]
where x_s are the spatial components x[..., 1:].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold in ambient space, where `...` represents arbitrary batch dimensions. The last dimension contains the time component (`x[..., 0]`) and spatial components (`x[..., 1:]`). | required |
| `c` | `float` | Curvature parameter, must be positive (c > 0) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold, same shape as input |
Notes
- This function applies Swish only to spatial components
- The time component is reconstructed using the hyperboloid constraint
- Swish is smooth and non-monotonic, often performing well in deep networks
- Works on arrays of any shape, similar to jax.nn.swish
- For curvature-changing transformations, use `hrc_swish`, which supports different input/output curvatures
References
Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_swish
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_swish(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_swish(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch
>>> x_feature = jnp.ones((4, 16, 16, 10))
>>> y_feature = hyp_swish(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hyp_gelu ¶
Apply GELU activation to space components of hyperboloid point(s).
Curvature-preserving wrapper around hrc_gelu(x, c_in=c, c_out=c).
This function applies the Gaussian Error Linear Unit (GELU) activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.
Mathematical formula: y = [sqrt(||GELU(x_s)||^2 + 1/c), GELU(x_s)]
where x_s are the spatial components x[..., 1:].
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold in ambient space, where `...` represents arbitrary batch dimensions. The last dimension contains the time component (`x[..., 0]`) and spatial components (`x[..., 1:]`). | required |
| `c` | `float` | Curvature parameter, must be positive (c > 0) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold, same shape as input |
Notes
- This function applies GELU only to spatial components
- The time component is reconstructed using the hyperboloid constraint
- GELU is smooth and commonly used in transformer architectures
- Works on arrays of any shape, similar to jax.nn.gelu
- For curvature-changing transformations, use `hrc_gelu`, which supports different input/output curvatures
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_gelu
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_gelu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_gelu(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch
>>> x_feature = jnp.ones((4, 16, 16, 10))
>>> y_feature = hyp_gelu(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
Poincaré Activations¶
Thin wrappers that apply standard activations in the Poincaré tangent space via logmap_0 → activation → expmap_0.
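The wrapper pattern can be sketched with the standard Poincaré-ball origin maps. This is a sketch under those textbook formulas; the library's `logmap_0`/`expmap_0` numerics and clamping may differ:

```python
import jax
import jax.numpy as jnp

def poincare_act_sketch(x, c=1.0, act=jax.nn.relu, eps=1e-7):
    sc = jnp.sqrt(c)
    n = jnp.maximum(jnp.linalg.norm(x, axis=-1, keepdims=True), eps)
    v = jnp.arctanh(jnp.minimum(sc * n, 1.0 - eps)) * x / (sc * n)  # log_0^c: ball -> tangent space
    u = act(v)                                                       # Euclidean activation in tangent space
    m = jnp.maximum(jnp.linalg.norm(u, axis=-1, keepdims=True), eps)
    return jnp.tanh(sc * m) * u / (sc * m)                           # exp_0^c: tangent space -> ball

x = jnp.array([[0.1, -0.2, 0.15], [0.3, 0.1, -0.05]])
y = poincare_act_sketch(x, c=1.0)
print(jnp.all(jnp.linalg.norm(y, axis=-1) < 1.0))  # True: outputs stay inside the unit ball
```

Because `tanh` is strictly bounded, the round trip always returns points strictly inside the ball of radius `1/sqrt(c)`.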
hyperbolix.nn_layers.poincare_relu ¶
Poincaré ReLU activation: exp_0^c ∘ ReLU ∘ log_0^c.
Applies ReLU in the tangent space at the origin of the Poincaré ball, then maps back to the manifold. This is the standard nonlinearity for Poincaré ball neural networks.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim)` | Input point(s) on the Poincaré ball. Supports arbitrary batch dimensions (e.g., `(batch, H, W, channels)` for feature maps). | required |
| `c` | `float` | Curvature parameter (positive) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim)` | Output point(s) on the Poincaré ball |
References
van Spengler et al. "Poincaré ResNet." ICML 2023.
Examples:
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import poincare_relu
>>>
>>> # Single point
>>> x = jnp.array([0.1, -0.2, 0.15])
>>> y = poincare_relu(x, c=1.0)
>>> y.shape
(3,)
>>>
>>> # Batch of feature maps
>>> x_batch = jnp.ones((4, 14, 14, 8)) * 0.1
>>> y_batch = poincare_relu(x_batch, c=1.0)
>>> y_batch.shape
(4, 14, 14, 8)
Source code in hyperbolix/nn_layers/poincare_activations.py
hyperbolix.nn_layers.poincare_leaky_relu ¶
poincare_leaky_relu(
x: Float[Array, "... dim"],
c: float,
negative_slope: float = 0.01,
) -> Float[Array, "... dim"]
Poincaré LeakyReLU activation: exp_0^c ∘ LeakyReLU ∘ log_0^c.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim)` | Input point(s) on the Poincaré ball | required |
| `c` | `float` | Curvature parameter (positive) | required |
| `negative_slope` | `float` | Negative slope coefficient | `0.01` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim)` | Output point(s) on the Poincaré ball |
Source code in hyperbolix/nn_layers/poincare_activations.py
hyperbolix.nn_layers.poincare_tanh ¶
Poincaré tanh activation: exp_0^c ∘ tanh ∘ log_0^c.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim)` | Input point(s) on the Poincaré ball | required |
| `c` | `float` | Curvature parameter (positive) | required |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim)` | Output point(s) on the Poincaré ball |
Source code in hyperbolix/nn_layers/poincare_activations.py
Curvature-Changing Activations (HRC-based)¶
For advanced use cases requiring curvature transformations:
hyperbolix.nn_layers.hrc_relu ¶
hrc_relu(
x: Float[Array, "... dim_plus_1"],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
HRC with ReLU activation.
Equivalent to hrc(x, jax.nn.relu, c_in, c_out, eps). When c_in = c_out = c, this is equivalent to hyp_relu(x, c).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold with curvature `c_in` | required |
| `c_in` | `float` | Input curvature parameter (must be positive) | required |
| `c_out` | `float` | Output curvature parameter (must be positive) | required |
| `eps` | `float` | Small value for numerical stability | `1e-07` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold with curvature `c_out` |
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hrc_leaky_relu ¶
hrc_leaky_relu(
x: Float[Array, "... dim_plus_1"],
c_in: float,
c_out: float,
negative_slope: float = 0.01,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
HRC with LeakyReLU activation.
When c_in = c_out = c, this is equivalent to hyp_leaky_relu(x, c, negative_slope).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold with curvature `c_in` | required |
| `c_in` | `float` | Input curvature parameter (must be positive) | required |
| `c_out` | `float` | Output curvature parameter (must be positive) | required |
| `negative_slope` | `float` | Negative slope coefficient for LeakyReLU | `0.01` |
| `eps` | `float` | Small value for numerical stability | `1e-07` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold with curvature `c_out` |
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hrc_tanh ¶
hrc_tanh(
x: Float[Array, "... dim_plus_1"],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
HRC with tanh activation.
When c_in = c_out = c, this is equivalent to hyp_tanh(x, c).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold with curvature `c_in` | required |
| `c_in` | `float` | Input curvature parameter (must be positive) | required |
| `c_out` | `float` | Output curvature parameter (must be positive) | required |
| `eps` | `float` | Small value for numerical stability | `1e-07` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold with curvature `c_out` |
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hrc_swish ¶
hrc_swish(
x: Float[Array, "... dim_plus_1"],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
HRC with Swish/SiLU activation.
When c_in = c_out = c, this is equivalent to hyp_swish(x, c).
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold with curvature `c_in` | required |
| `c_in` | `float` | Input curvature parameter (must be positive) | required |
| `c_out` | `float` | Output curvature parameter (must be positive) | required |
| `eps` | `float` | Small value for numerical stability | `1e-07` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold with curvature `c_out` |
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
hyperbolix.nn_layers.hrc_gelu ¶
hrc_gelu(
x: Float[Array, "... dim_plus_1"],
c_in: float,
c_out: float,
eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]
HRC with GELU activation.
Parameters:

| Name | Type | Description | Default |
|---|---|---|---|
| `x` | Array of shape `(..., dim+1)` | Input point(s) on the hyperboloid manifold with curvature `c_in` | required |
| `c_in` | `float` | Input curvature parameter (must be positive) | required |
| `c_out` | `float` | Output curvature parameter (must be positive) | required |
| `eps` | `float` | Small value for numerical stability | `1e-07` |

Returns:

| Name | Type | Description |
|---|---|---|
| `y` | Array of shape `(..., dim+1)` | Output point(s) on the hyperboloid manifold with curvature `c_out` |
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
Activation Examples¶
Curvature-Preserving Activation:
```python
import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import hyp_relu, hyp_gelu

# Points on hyperboloid (ambient coordinates)
x = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
x_ambient = jnp.concatenate([
    jnp.sqrt(jnp.sum(x**2, axis=-1, keepdims=True) + 1.0),
    x,
], axis=-1)

# Apply hyperbolic ReLU (curvature preserving)
output = hyp_relu(x_ambient, c=1.0)
print(output.shape)  # (10, 6) - same shape

# Verify manifold constraint
constraint = -output[:, 0]**2 + jnp.sum(output[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -1.0, atol=1e-5))  # True

# Use GELU instead
output_gelu = hyp_gelu(x_ambient, c=1.0)
```
Curvature-Changing Activation:

```python
from hyperbolix.nn_layers import hrc_relu

# x_ambient and jnp as in the previous example
# Transform from curvature 1.0 to curvature 2.0
output = hrc_relu(x_ambient, c_in=1.0, c_out=2.0)

# Verify new manifold constraint (c=2.0)
constraint = -output[:, 0]**2 + jnp.sum(output[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -1.0/2.0, atol=1e-5))  # True
```
How Activations Work
Hyperbolic activations follow the HRC pattern:

- Extract space components: `x_s = x[..., 1:]`
- Apply activation to space: `y_s = activation(x_s)`
- Scale for curvature change: `y_s = sqrt(c_in/c_out) * y_s`
- Reconstruct time: `y_t = sqrt(||y_s||^2 + 1/c_out)`
This avoids expensive exp/log maps while preserving geometry and enabling flexible curvature transformations.
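The four steps above can be written out as a minimal sketch (no eps safeguards; the library's `hrc` implementation adds numerical-stability handling):

```python
import jax
import jax.numpy as jnp

def hrc_sketch(x, f, c_in, c_out):
    x_s = x[..., 1:]                    # 1) extract space components
    y_s = f(x_s)                        # 2) Euclidean activation on space
    y_s = jnp.sqrt(c_in / c_out) * y_s  # 3) rescale for the curvature change
    y_t = jnp.sqrt(jnp.sum(y_s**2, axis=-1, keepdims=True) + 1.0 / c_out)  # 4) time
    return jnp.concatenate([y_t, y_s], axis=-1)

x_s = jax.random.normal(jax.random.PRNGKey(0), (8, 4))
# lift onto the c_in=1 hyperboloid
x = jnp.concatenate([jnp.sqrt(jnp.sum(x_s**2, -1, keepdims=True) + 1.0), x_s], -1)
y = hrc_sketch(x, jax.nn.relu, c_in=1.0, c_out=2.0)
constraint = -y[:, 0]**2 + jnp.sum(y[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -0.5, atol=1e-5))  # True: output lies on the c_out=2 hyperboloid
```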
Building Models¶
Example of a complete hyperbolic neural network:
```python
import jax
from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.nn_layers import HypLinearPoincare, poincare_relu

poincare = Poincare()

class HyperbolicNN(nnx.Module):
    def __init__(self, rngs):
        self.layer1 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=784,  # MNIST flattened
            out_dim=256,
            rngs=rngs,
        )
        self.layer2 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=256,
            out_dim=128,
            rngs=rngs,
        )
        self.layer3 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=128,
            out_dim=10,
            rngs=rngs,
        )

    def __call__(self, x, c=1.0):
        # x: (batch, 784) on Poincaré ball
        x = self.layer1(x, c)
        # Poincaré-ball activation; it supports batched inputs directly
        # (hyp_relu expects hyperboloid ambient points, not Poincaré points)
        x = poincare_relu(x, c)
        x = self.layer2(x, c)
        x = poincare_relu(x, c)
        x = self.layer3(x, c)
        return x

# Create and use model
model = HyperbolicNN(rngs=nnx.Rngs(0))

# Input data (projected to Poincaré ball)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784)) * 0.1
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)

output = model(x_proj, c=1.0)
print(output.shape)  # (32, 10)
```
References¶
The neural network layers implement methods from:
- Ganea et al. (2018): "Hyperbolic Neural Networks" - Poincaré linear layers and activations
- Shimizu et al. (2020): "Hyperbolic Neural Networks++" - Enhanced Poincaré and Hyperboloid operations (`HypLinearPoincarePP`, `HypLinearHyperboloidPP`, `HypConv2DHyperboloidPP`)
- Bdeir et al. (2023): "Fully Hyperbolic Convolutional Neural Networks for Computer Vision" - HCat-based convolutions (`HypConv2DHyperboloid`)
- Chen et al. (2022): "Fully Hyperbolic Neural Networks" - FHCNN linear layers
- LResNet (2023): "Lorentzian ResNet" - HRC-based convolutions (`LorentzConv2D`)
- Hypformer (Yang et al. 2025): "Hyperbolic Transformers" - HTC/HRC components with curvature-change support
- Chen et al. (2024): "Hyperbolic Embeddings for Learning on Manifolds (HELM)" - HOPE positional encoding and Lorentzian residual connections
- Klis et al. (2026): "Fast and Geometrically Grounded Lorentz Neural Networks" - `FGGLinear`, `FGGConv2D`, `FGGLorentzMLR`, `FGGMeanOnlyBatchNorm`; sinh/arcsinh cancellation for linear hyperbolic distance growth
Key Theoretical Connections¶
- HL (Hyperbolic Layer) from LResNet ≡ HRC (Hyperbolic Regularization Component) from Hypformer
- Both apply Euclidean operations to spatial components and reconstruct time using the Lorentz constraint
- `LorentzConv2D` is a specific instance of `hrc()` where `f_r` is a 2D convolution
See also:
- Manifolds API: Underlying geometric operations
- Optimizers API: Training with Riemannian optimization
- Training Workflows: Complete training examples