
Neural Network Layers API

Hyperbolic neural network layers built with Flax NNX.

Overview

Hyperbolix provides 20+ neural network layer classes and 5 activation functions for building hyperbolic deep learning models:

  • Linear Layers: Poincaré and Hyperboloid linear transformations, including FGG (Fast and Geometrically Grounded) layers
  • Convolutional Layers: HCat-based, HRC-based, and FGG hyperbolic convolutions
  • Normalization: Poincaré batch normalization (PoincareBatchNorm2D), HRC-wrapped norms, and FGG mean-only batch norm
  • Hypformer Components: HTC (Hyperbolic Transformation Component) and HRC (Hyperbolic Regularization Component) with curvature-change support
  • FGG Components: FGGLinear, FGGConv2D, FGGMeanOnlyBatchNorm from Klis et al. (2026) — linear-distance growth, ~3× faster than prior work
  • Attention Layers: Three hyperbolic attention variants (linear O(N), softmax O(N²), full Lorentzian O(N²)) from the Hypformer paper
  • Positional Encoding: HOPE (Hyperbolic Rotary PE) and Hypformer learnable positional encodings for Transformers
  • Regression Layers: Single-layer classifiers with Riemannian geometry, including FGGLorentzMLR
  • Activation Functions: Hyperbolic ReLU, Leaky ReLU, Tanh, Swish, GELU
  • Helper Functions: Utilities for regression and conformal factor computation

All layers follow Flax NNX conventions and store manifold module references.

Linear Layers

Poincaré Linear

hyperbolix.nn_layers.HypLinearPoincare

HypLinearPoincare(
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
)

Bases: Module

Hyperbolic Neural Networks fully connected layer (Poincaré ball model).

Computation steps:
  0) Project the input tensor to the tangent space (optional)
  1) Perform matrix-vector multiplication in the tangent space at the origin
  2) Map the result to the manifold
  3) Add the manifold bias to the result
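These steps can be sketched in plain numpy using the standard Poincaré-ball formulas for logmap_0, expmap_0, and Möbius addition (the library's actual manifold methods may add projection and numerical-stability details not shown here):

```python
import numpy as np

def logmap0(x, c):
    # Lift a ball point to the tangent space at the origin
    n = np.linalg.norm(x)
    return np.arctanh(np.sqrt(c) * n) * x / (np.sqrt(c) * n)

def expmap0(v, c):
    # Inverse of logmap0: map a tangent vector back onto the ball
    n = np.linalg.norm(v)
    return np.tanh(np.sqrt(c) * n) * v / (np.sqrt(c) * n)

def mobius_add(x, y, c):
    # Mobius addition: the manifold bias step on the Poincare ball
    xy, x2, y2 = x @ y, x @ x, y @ y
    num = (1 + 2 * c * xy + c * y2) * x + (1 - c * x2) * y
    return num / (1 + 2 * c * xy + c**2 * x2 * y2)

c = 1.0
W = np.array([[0.3, -0.1], [0.2, 0.4], [0.0, 0.1]])  # (out_dim, in_dim)
b = np.array([0.01, -0.02, 0.03])  # bias point near the origin
x = np.array([0.2, -0.1])          # point inside the unit ball

v = logmap0(x, c)           # 0) lift to tangent space at the origin
z = W @ v                   # 1) Euclidean matmul in tangent space
y = expmap0(z, c)           # 2) map back onto the ball
out = mobius_add(y, b, c)   # 3) manifold bias addition
assert np.linalg.norm(out) < 1.0  # result stays inside the unit ball
```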

Parameters:

manifold_module : Manifold
    Class-based Poincare manifold instance. Required.
in_dim : int
    Dimension of the input space. Required.
out_dim : int
    Dimension of the output space. Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
input_space : str
    Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold').
    Note: this is a static configuration; changing it after initialization requires recompilation.
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space) are treated as static and will be baked into the compiled function. Changing these values after JIT compilation will trigger automatic recompilation.

References

Ganea Octavian, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in neural information processing systems 31 (2018).

Source code in hyperbolix/nn_layers/poincare_linear.py
def __init__(
    self,
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_poincare_manifold(
        manifold_module,
        required_methods=("proj", "addition", "expmap_0", "logmap_0", "compute_mlr_pp"),
    )
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space

    # Trainable parameters
    # Tangent space weight (Euclidean) - initialized with std = 1/sqrt(fan_in)
    # to prevent outputs from saturating at the Poincaré ball boundary
    std = 1.0 / jnp.sqrt(in_dim)
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_dim, in_dim)) * std)
    # Manifold bias (initialized to small random values to avoid gradient issues at origin)
    self.bias = ManifoldParam(
        jax.random.normal(rngs.params(), (out_dim,)) * 0.01,
        manifold=self.manifold,
        curvature=1.0,
    )
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the hyperbolic linear layer.

Parameters:

x : Array of shape (batch, in_dim)
    Input tensor where the hyperbolic_axis is last. Required.
c : float
    Manifold curvature (default: 1.0).

Returns:

res : Array of shape (batch, out_dim)
    Output on the Poincaré ball manifold.

Source code in hyperbolix/nn_layers/poincare_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the hyperbolic linear layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Output on the Poincaré ball manifold
    """
    # Project bias to manifold
    bias_O = self.manifold.proj(self.bias[...], c)

    # Map to tangent space if needed (static branch - JIT friendly)
    if self.input_space == "manifold":
        x_BI = jax.vmap(self.manifold.logmap_0, in_axes=(0, None), out_axes=0)(x, c)
    else:
        x_BI = x

    # Matrix-vector multiplication in tangent space at origin
    x_BO = jnp.einsum("bi,oi->bo", x_BI, self.kernel)  # (B, I) @ (I, O) -> (B, O)

    # Map back to manifold
    x_BO = jax.vmap(self.manifold.expmap_0, in_axes=(0, None), out_axes=0)(x_BO, c)

    # Manifold bias addition (Möbius addition for Poincaré)
    res_BO = jax.vmap(self.manifold.addition, in_axes=(0, None, None), out_axes=0)(x_BO, bias_O, c)
    return res_BO

hyperbolix.nn_layers.HypLinearPoincarePP

HypLinearPoincarePP(
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic Neural Networks ++ fully connected layer (Poincaré ball model).

Computation steps:
  0) Project the input tensor onto the manifold (optional)
  1) Compute the multinomial linear regression score(s)
  2) Calculate the generalized linear transformation from the regression score(s)

Parameters:

manifold_module : Poincare
    Class-based Poincare manifold instance. Required.
in_dim : int
    Dimension of the input space. Required.
out_dim : int
    Dimension of the output space. Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
input_space : str
    Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold').
    Note: this is a static configuration; changing it after initialization requires recompilation.
clamping_factor : float
    Clamping factor for the multinomial linear regression output (default: 1.0).
smoothing_factor : float
    Smoothing factor for the multinomial linear regression output (default: 50.0).
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.

References

Shimizu Ryohei, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).

Source code in hyperbolix/nn_layers/poincare_linear.py
def __init__(
    self,
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_poincare_manifold(
        manifold_module,
        required_methods=("proj", "addition", "expmap_0", "logmap_0", "compute_mlr_pp"),
    )
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters
    # Tangent space weight - initialized with std = 1/sqrt(fan_in)
    # to prevent outputs from saturating at the Poincaré ball boundary
    std = 1.0 / jnp.sqrt(in_dim)
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_dim, in_dim)) * std)
    # Scalar bias
    self.bias = nnx.Param(jnp.zeros((out_dim, 1)))
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the HNN++ hyperbolic linear layer.

Parameters:

x : Array of shape (batch, in_dim)
    Input tensor where the hyperbolic_axis is last. Required.
c : float
    Manifold curvature (default: 1.0).

Returns:

res : Array of shape (batch, out_dim)
    Output on the Poincaré ball manifold.

Source code in hyperbolix/nn_layers/poincare_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the HNN++ hyperbolic linear layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Output on the Poincaré ball manifold
    """
    return _poincare_pp_forward(
        x,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        self.input_space,
        self.clamping_factor,
        self.smoothing_factor,
    )

Hyperboloid Linear

hyperbolix.nn_layers.HypLinearHyperboloidFHCNN

HypLinearHyperboloidFHCNN(
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    init_scale: float = 2.3,
    learnable_scale: bool = False,
    eps: float = 1e-05,
    activation: Callable[[Array], Array] | None = None,
    normalize: bool = False,
)

Bases: Module

Fully Hyperbolic Convolutional Neural Networks fully connected layer (Hyperboloid model).

Computation steps:
  0) Project the input tensor to the manifold (optional)
  1) Apply the activation (optional)
  2) If normalize is True: apply a scaled sigmoid to the affine-transformed (W x + b) coordinates of the previous step's result to obtain the output's time and space coordinates.
     If normalize is False: affine-transform the space coordinates of the previous step's result and set the time coordinate so that the output lies on the manifold.
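For the normalize=False branch, the time-reconstruction pattern can be sketched in plain numpy (illustrative weights only; the layer's actual init and stabilization details are in the source below):

```python
import numpy as np

c = 1.0
rng = np.random.default_rng(0)
W = rng.normal(size=(5, 4)) * 0.02  # (out_dim, in_dim), ambient dims
b = np.zeros(5)

# a point on the hyperboloid: time = sqrt(||space||^2 + 1/c)
space = np.array([0.3, -0.2, 0.1])
x = np.concatenate([[np.sqrt(space @ space + 1.0 / c)], space])

z = W @ x + b                 # affine transform of the full input
s = z[1:]                     # keep the spatial part, discard the computed time
t = np.sqrt(s @ s + 1.0 / c)  # reconstruct time from the hyperboloid constraint
y = np.concatenate([[t], s])

# output satisfies -t^2 + ||s||^2 = -1/c
assert np.isclose(-y[0] ** 2 + y[1:] @ y[1:], -1.0 / c)
```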

Parameters:

manifold_module : Manifold
    Class-based Hyperboloid manifold instance. Required.
in_dim : int
    Dimension of the input space. Required.
out_dim : int
    Dimension of the output space. Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
input_space : str
    Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold').
    Note: this is a static configuration; changing it after initialization requires recompilation.
init_scale : float
    Initial value for the sigmoid scale parameter (default: 2.3).
learnable_scale : bool
    Whether the scale parameter should be learnable (default: False).
eps : float
    Small value ensuring the time coordinate stays above 1/sqrt(c) (default: 1e-5).
activation : callable or None
    Activation function applied before the linear transformation (default: None).
    Note: this is a static configuration; changing it after initialization requires recompilation.
normalize : bool
    Whether to normalize the space coordinates before rescaling (default: False).
    Note: this is a static configuration; changing it after initialization requires recompilation.
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, activation, normalize) are treated as static and will be baked into the compiled function.

Relationship to HTC/HRC: When normalize=False and c_in = c_out, this layer uses the same time reconstruction pattern as htc: time = sqrt(||space||^2 + 1/c). The key difference is that FHCNN applies a linear transform to the full input and discards the computed time, while htc uses the linear output directly as spatial components. When normalize=True, FHCNN uses a learned sigmoid scaling which differs from both htc and hrc.

See Also

htc : Hyperbolic Transformation Component with curvature change support. Similar time reconstruction pattern when normalize=False. HTCLinear : Module wrapper for htc with learnable linear transformation.

References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __init__(
    self,
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    init_scale: float = 2.3,
    learnable_scale: bool = False,
    eps: float = 1e-5,
    activation: Callable[[Array], Array] | None = None,
    normalize: bool = False,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0",))
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.eps = eps
    self.activation = activation
    self.normalize = normalize

    # Trainable parameters
    bound = 0.02
    weight_init = jax.random.uniform(rngs.params(), (out_dim, in_dim), minval=-bound, maxval=bound)
    self.kernel = nnx.Param(weight_init)
    self.bias = nnx.Param(jnp.zeros((1, out_dim)))

    # Scale parameter for sigmoid
    if learnable_scale:
        self.scale = nnx.Param(jnp.array(init_scale))
    else:
        # For non-learnable scale, store as regular Python float (static)
        self.scale = init_scale
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the FHCNN hyperbolic linear layer.

Parameters:

x : Array of shape (batch, in_dim)
    Input tensor where the hyperbolic_axis is last; x.shape[-1] must equal self.in_dim. Required.
c : float
    Manifold curvature (default: 1.0).

Returns:

res : Array of shape (batch, out_dim)
    Output on the Hyperboloid manifold.

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the FHCNN hyperbolic linear layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last. x.shape[-1] must equal self.in_dim.
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Output on the Hyperboloid manifold
    """
    scale_val = self.scale[...] if isinstance(self.scale, nnx.Param) else self.scale
    return _fhcnn_forward(
        x,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        self.input_space,
        self.activation,
        self.normalize,
        scale_val,
        self.eps,
    )

hyperbolix.nn_layers.HypLinearHyperboloidFHNN

HypLinearHyperboloidFHNN(
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    init_scale: float = 2.3,
    eps: float = 1e-05,
    activation: Callable[[Array], Array] | None = None,
    dropout_rate: float | None = None,
)

Bases: Module

Fully Hyperbolic Neural Networks linear layer (Chen et al. 2021).

Time-primary parameterization: the time coordinate is computed via scaled sigmoid with an additive floor at 1/sqrt(c), and spatial coordinates are rescaled so the output lies on the hyperboloid.

Computation steps:
  0) Project the input tensor to the manifold (optional)
  1) Apply activation (optional)
  2) Apply dropout (optional)
  3) Linear transform: z = W @ x + b
  4) Time: y0 = exp(scale) * sigmoid(z0) + 1/sqrt(c) + eps
  5) Spatial: y_rem = sqrt(y0^2 - 1/c) / ||z_rem|| * z_rem
  6) Output: [y0, y_rem] on the hyperboloid
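Steps 4-6 (the time-primary parameterization) can be checked in plain numpy; the values below are illustrative, not the layer's learned parameters:

```python
import numpy as np

c, eps, scale = 1.0, 1e-5, 2.3
sigmoid = lambda u: 1.0 / (1.0 + np.exp(-u))

z = np.array([0.4, 0.3, -0.2, 0.1])  # stand-in for the linear output W @ x + b

# 4) time via scaled sigmoid with an additive floor at 1/sqrt(c)
y0 = np.exp(scale) * sigmoid(z[0]) + 1.0 / np.sqrt(c) + eps
# 5) rescale the spatial part so the output satisfies the hyperboloid constraint
z_rem = z[1:]
y_rem = np.sqrt(y0**2 - 1.0 / c) / np.linalg.norm(z_rem) * z_rem
# 6) assemble the output point
y = np.concatenate([[y0], y_rem])

# the floor keeps y0 > 1/sqrt(c), i.e. the upper hyperboloid sheet
assert y[0] > 1.0 / np.sqrt(c)
assert np.isclose(-y[0]**2 + y[1:] @ y[1:], -1.0 / c)
```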

Parameters:

manifold_module : Manifold
    Class-based Hyperboloid manifold instance. Required.
in_dim : int
    Ambient input dimension (d+1, including time). Required.
out_dim : int
    Ambient output dimension (d+1, including time). Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
input_space : str
    Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold').
    Note: this is a static configuration; changing it after initialization requires recompilation.
init_scale : float
    Initial value for the learnable sigmoid scale (default: 2.3).
eps : float
    Numerical stability epsilon (default: 1e-5).
activation : callable or None
    Activation function applied before the linear transformation (default: None).
    Note: this is a static configuration; changing it after initialization requires recompilation.
dropout_rate : float or None
    Dropout rate applied before the linear transformation (default: None).
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, activation) are treated as static and will be baked into the compiled function.

Weight Initialization: Weights are initialized as tangent vectors at the hyperboloid origin: U(-0.02, 0.02) with the time column (column 0) zeroed. This matches the Chen et al. 2021 init.

Relationship to FHCNN: Both FHNN and FHCNN ensure outputs lie on the hyperboloid, but differ in which coordinate is primary. FHNN controls the time coordinate via sigmoid with an additive floor at 1/sqrt(c), then derives the spatial norm from the hyperboloid constraint. FHCNN (normalize=True) controls the spatial norm via sigmoid, then derives time.

See Also

HypLinearHyperboloidFHCNN : FHCNN layer with spatial-primary parameterization. HypLinearHyperboloidPP : HNN++ layer using MLR + sinh diffeomorphism.

References

Weize Chen, et al. "Fully hyperbolic neural networks." arXiv preprint arXiv:2105.14686 (2021).

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __init__(
    self,
    manifold_module: Manifold,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    init_scale: float = 2.3,
    eps: float = 1e-5,
    activation: Callable[[Array], Array] | None = None,
    dropout_rate: float | None = None,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0",))
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.eps = eps
    self.activation = activation

    # FHNN weight init: U(-0.02, 0.02) with time column zeroed (tangent vectors at origin)
    bound = 0.02
    weight_init = jax.random.uniform(rngs.params(), (out_dim, in_dim), minval=-bound, maxval=bound)
    weight_init = weight_init.at[:, 0].set(0.0)
    self.kernel = nnx.Param(weight_init)
    self.bias = nnx.Param(jnp.zeros((1, out_dim)))

    # Learnable scale for the sigmoid (always learnable in FHNN)
    self.scale = nnx.Param(jnp.array(init_scale))

    # Optional dropout
    if dropout_rate is not None and dropout_rate > 0:
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
    else:
        self.dropout = None
__call__
__call__(
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
    deterministic: bool = True,
) -> Float[Array, "batch out_dim"]

Forward pass through the FHNN hyperbolic linear layer.

Parameters:

x : Array of shape (batch, in_dim)
    Input tensor; x.shape[-1] must equal self.in_dim. Required.
c : float
    Manifold curvature (default: 1.0).
deterministic : bool
    If True, dropout is disabled (default: True).

Returns:

res : Array of shape (batch, out_dim)
    Output on the Hyperboloid manifold.

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
    deterministic: bool = True,
) -> Float[Array, "batch out_dim"]:
    """Forward pass through the FHNN hyperbolic linear layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor. x.shape[-1] must equal self.in_dim.
    c : float
        Manifold curvature (default: 1.0)
    deterministic : bool
        If True, dropout is disabled (default: True).

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Output on the Hyperboloid manifold
    """
    # Build dropout closure for the pure function
    dropout_module = self.dropout
    if dropout_module is not None:
        dropout_fn = lambda z: dropout_module(z, deterministic=deterministic)  # noqa: E731
    else:
        dropout_fn = None

    return _fhnn_forward(
        x,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        self.input_space,
        self.activation,
        dropout_fn,
        self.scale[...],
        self.eps,
    )

hyperbolix.nn_layers.HypLinearHyperboloidPP

HypLinearHyperboloidPP(
    manifold_module: Hyperboloid,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic Neural Networks ++ fully connected layer (Hyperboloid model).

Computation steps:
  0) Project the input tensor onto the manifold (optional)
  1) Compute the multinomial linear regression score(s) via compute_mlr
  2) Apply the element-wise sinh diffeomorphism to obtain spatial coordinates
  3) Reconstruct the time coordinate from the hyperboloid constraint
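Steps 2-3 can be sketched in plain numpy. The sinh scaling shown here is a simplification (the library's compute_mlr and any sqrt(c) factors inside the diffeomorphism may differ); the point is that whatever the spatial values are, reconstructing the time coordinate puts the output on the hyperboloid:

```python
import numpy as np

c = 1.0
scores = np.array([0.5, -0.3, 0.2])  # stand-in MLR scores, one per output spatial dim

# 2) element-wise sinh diffeomorphism (assumed scaling by 1/sqrt(c))
s = np.sinh(scores) / np.sqrt(c)
# 3) time coordinate from the hyperboloid constraint -t^2 + ||s||^2 = -1/c
t = np.sqrt(s @ s + 1.0 / c)
y = np.concatenate([[t], s])

assert np.isclose(-y[0]**2 + y[1:] @ y[1:], -1.0 / c)
```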

Parameters:

manifold_module : Hyperboloid
    Class-based Hyperboloid manifold instance. Required.
in_dim : int
    Full input dimension (ambient, d+1). Required.
out_dim : int
    Full output dimension (ambient, d+1). Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
input_space : str
    Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold').
    Note: this is a static configuration; changing it after initialization requires recompilation.
clamping_factor : float
    Clamping factor for the multinomial linear regression output (default: 1.0).
smoothing_factor : float
    Smoothing factor for the multinomial linear regression output (default: 50.0).
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.

References

Shimizu Ryohei, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0", "compute_mlr"))
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters — standard normal init (Shimizu et al. 2020)
    in_spatial = in_dim - 1
    out_spatial = out_dim - 1
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_spatial, in_spatial)))
    self.bias = nnx.Param(jnp.zeros((out_spatial, 1)))
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the HNN++ hyperboloid linear layer.

Parameters:

x : Array of shape (batch, in_dim)
    Input tensor; x.shape[-1] must equal self.in_dim. Required.
c : float
    Manifold curvature (default: 1.0).

Returns:

res : Array of shape (batch, out_dim)
    Output on the Hyperboloid manifold.

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the HNN++ hyperboloid linear layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor. x.shape[-1] must equal self.in_dim.
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Output on the Hyperboloid manifold
    """
    return _hyperboloid_pp_forward(
        x,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        self.input_space,
        self.clamping_factor,
        self.smoothing_factor,
    )

FGG Linear (Klis et al. 2026)

hyperbolix.nn_layers.FGGLinear

FGGLinear(
    in_features: int,
    out_features: int,
    *,
    rngs: Rngs,
    activation: Callable[[Array], Array] | None = None,
    reset_params: str = "eye",
    use_weight_norm: bool = False,
    init_bias: float = 0.5,
    eps: float = 1e-07,
)

Bases: Module

Fast and Geometrically Grounded Lorentz linear layer.

Implements the FGG linear layer from Klis et al. 2026. The key insight is that the sinh/arcsinh cancellation in the Lorentzian activation chain simplifies the forward pass to: matmul with spacelike V matrix -> Euclidean activation -> time reconstruction. This achieves linear growth of hyperbolic distance (vs logarithmic for Chen et al. 2022) and ~3x faster training/inference.

Forward pass:
  1) Build the spacelike V matrix from (U, b) with the Minkowski metric absorbed
  2) z = x @ V (Minkowski inner products via a single matmul)
  3) z = h(z) (Euclidean activation, e.g. ReLU)
  4) y_0 = sqrt(||z||^2 + 1/c) (time reconstruction)
  5) y = [y_0, z] (on the hyperboloid)
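Steps 2-5 can be sketched in plain numpy. The V matrix below is a random stand-in (the actual assembly of V from (U, b) with the Minkowski metric absorbed is in the source below); the sketch shows why one matmul plus a Euclidean activation still yields a point on the hyperboloid:

```python
import numpy as np

c = 1.0
rng = np.random.default_rng(0)
V = rng.normal(size=(4, 6)) * 0.1  # stand-in: ambient-in x spatial-out

# input point on the hyperboloid: time = sqrt(||space||^2 + 1/c)
space = np.array([0.2, 0.1, -0.3])
x = np.concatenate([[np.sqrt(space @ space + 1.0 / c)], space])

z = x @ V                      # 2) all inner products via a single matmul
z = np.maximum(z, 0.0)         # 3) Euclidean activation (ReLU)
y0 = np.sqrt(z @ z + 1.0 / c)  # 4) time reconstruction
y = np.concatenate([[y0], z])  # 5) output on the hyperboloid

assert np.isclose(-y[0]**2 + y[1:] @ y[1:], -1.0 / c)
```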

Parameters:

in_features : int
    Input ambient dimension (D_in + 1), including the time component. Required.
out_features : int
    Output ambient dimension (D_out + 1), including the time component. Required.
rngs : Rngs
    Random number generators for parameter initialization. Required.
activation : Callable or None
    Euclidean activation function applied after the matmul (default: None).
reset_params : str
    Weight initialization scheme: "eye", "xavier", "kaiming", "lorentz_kaiming", or "mlr" (default: "eye").
use_weight_norm : bool
    If True, reparameterize U as g * v / ||v|| for weight normalization (default: False).
init_bias : float
    Initial value for bias entries (default: 0.5).
eps : float
    Numerical stability floor (default: 1e-7).
References

Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026).

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import FGGLinear
>>> import jax
>>> import jax.numpy as jnp
>>>
>>> layer = FGGLinear(33, 65, rngs=nnx.Rngs(0), activation=jax.nn.relu)
>>> x = jnp.ones((8, 33))
>>> # project to hyperboloid
>>> x = x.at[:, 0].set(jnp.sqrt(jnp.sum(x[:, 1:]**2, axis=-1) + 1.0))
>>> y = layer(x, c=1.0)
>>> y.shape
(8, 65)
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    rngs: nnx.Rngs,
    activation: Callable[[jax.Array], jax.Array] | None = None,
    reset_params: str = "eye",
    use_weight_norm: bool = False,
    init_bias: float = 0.5,
    eps: float = 1e-7,
):
    if reset_params not in ("eye", "xavier", "kaiming", "lorentz_kaiming", "mlr"):
        raise ValueError(
            f"reset_params must be 'eye', 'xavier', 'kaiming', 'lorentz_kaiming', or 'mlr', got '{reset_params}'"
        )

    in_spatial = in_features - 1  # I
    out_spatial = out_features - 1  # O

    self.in_features = in_features
    self.out_features = out_features
    self.activation = activation
    self.use_weight_norm = use_weight_norm
    self.eps = eps

    # Initialize Euclidean weight U: (I, O)
    # Reference computes std from ambient dimensions (in_features, out_features)
    key = rngs.params()
    if reset_params == "eye":
        U_init = 0.5 * jnp.eye(in_spatial, out_spatial)
    elif reset_params == "xavier":
        std = jnp.sqrt(1.0 / (in_features + out_features))
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    elif reset_params == "kaiming":
        std = jnp.sqrt(2.0 / in_features)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    elif reset_params == "lorentz_kaiming":
        std = jnp.sqrt(1.0 / in_features)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    else:  # mlr
        std = jnp.sqrt(5.0 / in_features)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std

    # Weight normalization: decompose kernel = softplus(kernel_scale) * kernel_dir / ||kernel_dir||
    if use_weight_norm:
        # Reference: kernel_dir from reset_params (normalized in forward), kernel_scale fixed magnitude
        self.kernel_dir = nnx.Param(U_init)  # (I, O) direction
        g_init_val = jnp.sqrt(1.0 / (in_features + out_features))
        self.kernel_scale = nnx.Param(jnp.full((out_spatial,), g_init_val))  # (O,)
    else:
        self.kernel = nnx.Param(U_init)  # (I, O)

    # Bias: init to init_bias
    self.bias = nnx.Param(jnp.full((out_spatial,), init_bias))  # (O,)
__call__
__call__(
    x_BAi: Float[Array, "batch in_features"], c: float = 1.0
) -> Float[Array, "batch out_features"]

Forward pass through the FGG linear layer.

Parameters:

x_BAi : Array of shape (batch, in_features)
    Input points on the hyperboloid with curvature c. Ai = in_features (ambient dimension). Required.
c : float
    Curvature parameter (default: 1.0).

Returns:

y_BAo : Array of shape (batch, out_features)
    Output points on the hyperboloid with curvature c. Ao = out_features (ambient dimension).

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __call__(
    self,
    x_BAi: Float[Array, "batch in_features"],
    c: float = 1.0,
) -> Float[Array, "batch out_features"]:
    """Forward pass through the FGG linear layer.

    Parameters
    ----------
    x_BAi : Array, shape (B, Ai)
        Input points on the hyperboloid with curvature ``c``.
        Ai = in_features (ambient dimension).
    c : float, optional
        Curvature parameter (default: 1.0).

    Returns
    -------
    y_BAo : Array, shape (B, Ao)
        Output points on the hyperboloid with curvature ``c``.
        Ao = out_features (ambient dimension).
    """
    U_IO = self._get_kernel()  # (I, O)
    return _fgg_linear_forward(x_BAi, U_IO, self.bias[...], c, self.activation, self.eps)

Usage Example

import jax
from flax import nnx
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare

poincare = Poincare()

# Create hyperbolic linear layer
layer = HypLinearPoincare(
    manifold_module=poincare,
    in_dim=32,
    out_dim=16,
    rngs=nnx.Rngs(0)
)

# Forward pass
x = jax.random.normal(jax.random.PRNGKey(1), (10, 32)) * 0.3
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)

output = layer(x_proj, c=1.0)
print(output.shape)  # (10, 16)
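For intuition about the manifold mapping inside the layer, here is a self-contained sketch of the exponential map at the origin of the Poincaré ball (standard closed form; assumed to agree with the library's `expmap_0`):

```python
import jax.numpy as jnp

def poincare_expmap_0(v, c=1.0, eps=1e-12):
    """expmap_0(v) = tanh(sqrt(c)·||v||) · v / (sqrt(c)·||v||) — standard formula."""
    norm = jnp.maximum(jnp.linalg.norm(v, axis=-1, keepdims=True), eps)
    return jnp.tanh(jnp.sqrt(c) * norm) * v / (jnp.sqrt(c) * norm)

v = jnp.array([[3.0, 4.0]])            # tangent vector with norm 5
y = poincare_expmap_0(v, c=1.0)
print(jnp.linalg.norm(y, axis=-1))     # < 1: the image lies inside the unit ball
```

However large the tangent vector, the image stays strictly inside the unit ball, which is why the layers apply the bias and matrix product in tangent space first.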

Convolutional Layers

Hyperboloid Convolutions

hyperbolix.nn_layers.HypConv2DHyperboloid

HypConv2DHyperboloid(
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
)

Bases: Module

Hyperbolic 2D Convolutional Layer for Hyperboloid model.

This layer implements fully hyperbolic convolution as described in "Fully Hyperbolic Convolutional Neural Networks for Computer Vision".

Computation steps: 1) Extract receptive field (kernel_size x kernel_size) of hyperbolic points 2) Apply HCat (Lorentz direct concatenation) to combine receptive field points 3) Pass through hyperbolic linear layer (LFC)

Parameters:

Name Type Description Default
manifold_module object

Class-based Hyperboloid manifold instance

required
in_channels int

Number of input channels

required
out_channels int

Number of output channels

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel

required
rngs Rngs

Random number generators for parameter initialization

required
stride int or tuple[int, int]

Stride of the convolution (default: 1)

1
padding str

Padding mode, either 'SAME' or 'VALID' (default: 'SAME')

'SAME'
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space) are treated as static and will be baked into the compiled function.

References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
):
    if padding not in ["SAME", "VALID"]:
        raise ValueError(f"padding must be either 'SAME' or 'VALID', got '{padding}'")
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0", "hcat"))
    self.manifold = manifold_module
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.input_space = input_space
    self.padding = padding

    # Handle kernel_size as int or tuple
    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    else:
        self.kernel_size = kernel_size

    # Handle stride as int or tuple
    if isinstance(stride, int):
        self.stride = (stride, stride)
    else:
        self.stride = stride

    # Compute dimensions for the linear layer
    # Receptive field: kernel_h x kernel_w pixels, each is a point in in_channels-dim ambient space
    # HCat input: N = kernel_h x kernel_w points, each in in_channels ambient dimensions
    # HCat output: ambient dim = (in_channels - 1) x N + 1
    #            = (in_channels - 1) x (kernel_h x kernel_w) + 1
    kernel_h, kernel_w = self.kernel_size
    N = kernel_h * kernel_w  # Number of points in receptive field
    d = in_channels - 1  # Input manifold dimension
    hcat_out_ambient_dim = d * N + 1  # HCat output ambient dimension

    # Trainable parameters — owned directly for flat parameter paths
    bound = 0.02
    self.kernel = nnx.Param(
        jax.random.uniform(rngs.params(), (out_channels, hcat_out_ambient_dim), minval=-bound, maxval=bound)
    )
    self.bias = nnx.Param(jnp.zeros((1, out_channels)))
    self.scale = 2.3  # not learnable (matches FHCNN default)
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the hyperbolic convolutional layer.

Parameters:

Name Type Description Default
x Array of shape (batch, height, width, in_channels)

Input feature map where each pixel is a point on the Hyperboloid manifold

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
out Array of shape (batch, out_height, out_width, out_channels)

Output feature map on the Hyperboloid manifold

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """
    Forward pass through the hyperbolic convolutional layer.

    Parameters
    ----------
    x : Array of shape (batch, height, width, in_channels)
        Input feature map where each pixel is a point on the Hyperboloid manifold
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    out : Array of shape (batch, out_height, out_width, out_channels)
        Output feature map on the Hyperboloid manifold
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x_flat_NC = x.reshape(-1, x.shape[-1])  # (B*H*W, C)
        x_mapped_NC = jax.vmap(self.manifold.expmap_0, in_axes=(0, None))(x_flat_NC, c)
        x = x_mapped_NC.reshape(x.shape)  # (B, H, W, C)

    # Extract patches: (B, H, W, kh, kw, C)
    patches_BHWkhkwC = self._extract_patches(x)
    batch, out_h, out_w, kh, kw, in_c = patches_BHWkhkwC.shape

    # Flatten batch+spatial for parallel processing: (B*H*W, K, C)
    patches_flat_NKC = patches_BHWkhkwC.reshape(-1, kh * kw, in_c)

    # HCat: (K, C) -> (hcat_dim,) per patch
    hcat_out_NA = jax.vmap(self.manifold.hcat, in_axes=(0, None))(patches_flat_NKC, c)  # (B*H*W, hcat_dim)

    # Linear: (hcat_dim,) -> (out_channels,)
    linear_out_NC = _fhcnn_forward(
        hcat_out_NA,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        "manifold",
        None,
        False,
        self.scale,
        1e-5,
    )  # (B*H*W, out_channels)

    # Reshape back to spatial
    output_BHWC = linear_out_NC.reshape(batch, out_h, out_w, self.out_channels)

    return output_BHWC
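The HCat dimension bookkeeping above can be checked with a tiny standalone helper (hypothetical, mirroring the formula used in `__init__`):

```python
def hcat_out_ambient_dim(in_channels: int, kernel_size: tuple[int, int]) -> int:
    """(in_channels - 1) spatial dims per point, concatenated over the
    receptive field, plus one shared time coordinate."""
    kh, kw = kernel_size
    return (in_channels - 1) * kh * kw + 1

print(hcat_out_ambient_dim(16, (3, 3)))  # 136 — the kernel's input dimension
print(hcat_out_ambient_dim(3, (3, 3)))   # 19
```

Note that this growth is internal: the kernel of shape `(out_channels, hcat_out_ambient_dim)` projects back down, so the layer's output channel count is always `out_channels`.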

hyperbolix.nn_layers.HypConv2DHyperboloidFHNN

HypConv2DHyperboloidFHNN(
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
    init_scale: float = 2.3,
    eps: float = 1e-05,
    activation: Callable[[Array], Array] | None = None,
    dropout_rate: float | None = None,
)

Bases: Module

Fully Hyperbolic Neural Networks 2D convolutional layer (Chen et al. 2021).

Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies the FHNN linear transform (time-primary sigmoid parameterization) for channel mixing.

Computation steps: 1) Map input to manifold via expmap_0 if input_space="tangent" 2) Extract receptive field (kernel_size x kernel_size) of hyperbolic points 3) Apply HCat (Lorentz direct concatenation) to combine receptive field points 4) Pass through FHNN linear (sigmoid time + spatial rescaling)

Parameters:

Name Type Description Default
manifold_module Hyperboloid

Class-based Hyperboloid manifold instance

required
in_channels int

Number of input channels (ambient dimension, including time component)

required
out_channels int

Number of output channels (ambient dimension, including time component)

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel

required
rngs Rngs

Random number generators for parameter initialization

required
stride int or tuple[int, int]

Stride of the convolution (default: 1)

1
padding str

Padding mode, either 'SAME' or 'VALID' (default: 'SAME')

'SAME'
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
init_scale float

Initial value for the learnable sigmoid scale (default: 2.3)

2.3
eps float

Numerical stability epsilon (default: 1e-5)

1e-05
activation callable or None

Activation function to apply before the linear transformation (default: None).

None
dropout_rate float or None

Dropout rate applied before the linear transformation (default: None).

None
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space, activation) are treated as static and baked into the compiled function.

See Also

HypConv2DHyperboloid : Equivalent convolution using FHCNN linear instead of FHNN. HypLinearHyperboloidFHNN : The underlying FHNN linear layer.

References

Weize Chen et al. "Fully hyperbolic neural networks." arXiv preprint arXiv:2105.14686 (2021).

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
    init_scale: float = 2.3,
    eps: float = 1e-5,
    activation: Callable[[Array], Array] | None = None,
    dropout_rate: float | None = None,
):
    if padding not in ["SAME", "VALID"]:
        raise ValueError(f"padding must be either 'SAME' or 'VALID', got '{padding}'")
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0", "hcat"))
    self.manifold = manifold_module
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.input_space = input_space
    self.padding = padding
    self.eps = eps
    self.activation = activation

    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    else:
        self.kernel_size = kernel_size

    if isinstance(stride, int):
        self.stride = (stride, stride)
    else:
        self.stride = stride

    # HCat output ambient dim: (in_channels - 1) * kh * kw + 1
    kernel_h, kernel_w = self.kernel_size
    N = kernel_h * kernel_w
    d = in_channels - 1
    hcat_out_ambient_dim = d * N + 1

    # FHNN weight init: U(-0.02, 0.02) with time column zeroed (tangent vectors at origin)
    bound = 0.02
    weight_init = jax.random.uniform(rngs.params(), (out_channels, hcat_out_ambient_dim), minval=-bound, maxval=bound)
    weight_init = weight_init.at[:, 0].set(0.0)
    self.kernel = nnx.Param(weight_init)
    self.bias = nnx.Param(jnp.zeros((1, out_channels)))

    # Learnable scale for the sigmoid (always learnable in FHNN)
    self.scale = nnx.Param(jnp.array(init_scale))

    # Optional dropout
    if dropout_rate is not None and dropout_rate > 0:
        self.dropout = nnx.Dropout(dropout_rate, rngs=rngs)
    else:
        self.dropout = None
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
    deterministic: bool = True,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the FHNN hyperbolic convolutional layer.

Parameters:

Name Type Description Default
x Array of shape (batch, height, width, in_channels)

Input feature map where each pixel is a point on the Hyperboloid manifold

required
c float

Manifold curvature (default: 1.0)

1.0
deterministic bool

If True, dropout is disabled (default: True).

True

Returns:

Name Type Description
out Array of shape (batch, out_height, out_width, out_channels)

Output feature map on the Hyperboloid manifold

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
    deterministic: bool = True,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """Forward pass through the FHNN hyperbolic convolutional layer.

    Parameters
    ----------
    x : Array of shape (batch, height, width, in_channels)
        Input feature map where each pixel is a point on the Hyperboloid manifold
    c : float
        Manifold curvature (default: 1.0)
    deterministic : bool
        If True, dropout is disabled (default: True).

    Returns
    -------
    out : Array of shape (batch, out_height, out_width, out_channels)
        Output feature map on the Hyperboloid manifold
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x_flat_NC = x.reshape(-1, x.shape[-1])  # (B*H*W, C)
        x_mapped_NC = jax.vmap(self.manifold.expmap_0, in_axes=(0, None))(x_flat_NC, c)
        x = x_mapped_NC.reshape(x.shape)  # (B, H, W, C)

    # Extract patches: (B, H, W, kh, kw, C)
    patches_BHWkhkwC = self._extract_patches(x)
    batch, out_h, out_w, kh, kw, in_c = patches_BHWkhkwC.shape

    # Flatten batch+spatial for parallel processing: (B*H*W, K, C)
    patches_flat_NKC = patches_BHWkhkwC.reshape(-1, kh * kw, in_c)

    # HCat: (K, C) -> (hcat_dim,) per patch
    hcat_out_NA = jax.vmap(self.manifold.hcat, in_axes=(0, None))(patches_flat_NKC, c)  # (B*H*W, hcat_dim)

    # Build dropout closure for the pure function
    dropout_module = self.dropout
    if dropout_module is not None:
        dropout_fn = lambda z: dropout_module(z, deterministic=deterministic)  # noqa: E731
    else:
        dropout_fn = None

    # FHNN linear: (hcat_dim,) -> (out_channels,)
    linear_out_NC = _fhnn_forward(
        hcat_out_NA,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        "manifold",  # HCat output is already on manifold
        self.activation,
        dropout_fn,
        self.scale[...],
        self.eps,
    )  # (B*H*W, out_channels)

    # Reshape back to spatial
    return linear_out_NC.reshape(batch, out_h, out_w, self.out_channels)
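A recurring primitive in these hyperboloid layers is reconstructing the time coordinate from a spatial vector, which guarantees the output lies exactly on the manifold. A minimal sketch (assuming curvature c > 0; `lift_to_hyperboloid` is an illustrative helper, not a library function):

```python
import jax.numpy as jnp

def lift_to_hyperboloid(s, c=1.0):
    """Given spatial part s, set t = sqrt(1/c + ||s||²) so that [t, s]
    satisfies the hyperboloid constraint -t² + ||s||² = -1/c."""
    t = jnp.sqrt(1.0 / c + jnp.sum(s**2, axis=-1, keepdims=True))
    return jnp.concatenate([t, s], axis=-1)

s = jnp.array([[0.3, -0.2, 0.1]])
y = lift_to_hyperboloid(s, c=1.0)
lorentz = -y[..., 0] ** 2 + jnp.sum(y[..., 1:] ** 2, axis=-1)
print(lorentz)  # ≈ -1.0 (the constraint -1/c with c = 1)
```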

hyperbolix.nn_layers.HypConv2DHyperboloidPP

HypConv2DHyperboloidPP(
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic Neural Networks ++ 2D convolutional layer (Hyperboloid model).

Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies the HNN++ linear transform (MLR + sinh diffeomorphism) for channel mixing.

Computation steps: 1) Extract receptive field (kernel_size x kernel_size) of hyperbolic points 2) Apply HCat (Lorentz direct concatenation) to combine receptive field points 3) Pass through HNN++ linear (MLR scores -> sinh -> time reconstruction)

Parameters:

Name Type Description Default
manifold_module Hyperboloid

Class-based Hyperboloid manifold instance

required
in_channels int

Number of input channels (ambient dimension, including time component)

required
out_channels int

Number of output channels (ambient dimension, including time component)

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel

required
rngs Rngs

Random number generators for parameter initialization

required
stride int or tuple[int, int]

Stride of the convolution (default: 1)

1
padding str

Padding mode, either 'SAME' or 'VALID' (default: 'SAME')

'SAME'
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
clamping_factor float

Clamping factor for the multinomial linear regression output (default: 1.0)

1.0
smoothing_factor float

Smoothing factor for the multinomial linear regression output (default: 50.0)

50.0
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space, clamping_factor, smoothing_factor) are treated as static and baked into the compiled function.

References

Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if padding not in ["SAME", "VALID"]:
        raise ValueError(f"padding must be either 'SAME' or 'VALID', got '{padding}'")
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0", "compute_mlr", "hcat"))
    self.manifold = manifold_module
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.input_space = input_space
    self.padding = padding
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    else:
        self.kernel_size = kernel_size

    if isinstance(stride, int):
        self.stride = (stride, stride)
    else:
        self.stride = stride

    # HCat output ambient dim: (in_channels - 1) * kh * kw + 1
    kernel_h, kernel_w = self.kernel_size
    N = kernel_h * kernel_w
    d = in_channels - 1
    hcat_out_ambient_dim = d * N + 1

    # Trainable parameters — standard normal init (Shimizu et al. 2020)
    out_spatial = out_channels - 1
    hcat_spatial = hcat_out_ambient_dim - 1
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_spatial, hcat_spatial)))
    self.bias = nnx.Param(jnp.zeros((out_spatial, 1)))
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the HNN++ hyperboloid convolutional layer.

Parameters:

Name Type Description Default
x Array of shape (batch, height, width, in_channels)

Input feature map where each pixel is a point on the Hyperboloid manifold

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
out Array of shape (batch, out_height, out_width, out_channels)

Output feature map on the Hyperboloid manifold

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """
    Forward pass through the HNN++ hyperboloid convolutional layer.

    Parameters
    ----------
    x : Array of shape (batch, height, width, in_channels)
        Input feature map where each pixel is a point on the Hyperboloid manifold
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    out : Array of shape (batch, out_height, out_width, out_channels)
        Output feature map on the Hyperboloid manifold
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x_flat_NC = x.reshape(-1, x.shape[-1])  # (B*H*W, C)
        x_mapped_NC = jax.vmap(self.manifold.expmap_0, in_axes=(0, None))(x_flat_NC, c)
        x = x_mapped_NC.reshape(x.shape)  # (B, H, W, C)

    # Extract patches: (B, H', W', kh, kw, C)
    patches_BHWkhkwC = self._extract_patches(x)
    batch, out_h, out_w, kh, kw, in_c = patches_BHWkhkwC.shape

    # Flatten batch+spatial for parallel processing: (B*H'*W', K, C)
    patches_flat_NKC = patches_BHWkhkwC.reshape(-1, kh * kw, in_c)

    # HCat: (K, C) -> (hcat_dim,) per patch
    hcat_out_NA = jax.vmap(self.manifold.hcat, in_axes=(0, None))(patches_flat_NKC, c)  # (B*H'*W', hcat_dim)

    # HNN++ linear: (hcat_dim,) -> (out_channels,)
    linear_out_NC = _hyperboloid_pp_forward(
        hcat_out_NA,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        "manifold",  # HCat output is already on manifold
        self.clamping_factor,
        self.smoothing_factor,
    )  # (B*H'*W', out_channels)

    # Reshape back to spatial
    return linear_out_NC.reshape(batch, out_h, out_w, self.out_channels)

Usage Example

import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import HypConv2DHyperboloid
from hyperbolix.manifolds import Hyperboloid
from flax import nnx

hyperboloid = Hyperboloid()

# Create 2D hyperbolic convolution
conv = HypConv2DHyperboloid(
    manifold_module=hyperboloid,
    in_channels=16,
    out_channels=32,
    kernel_size=(3, 3),
    stride=(1, 1),
    rngs=nnx.Rngs(0)
)

# Input: (batch, height, width, in_channels) — in_channels is the ambient dim (1 time + 15 spatial)
x = jax.random.normal(jax.random.PRNGKey(1), (8, 28, 28, 15))

# Lift to the hyperboloid by computing the time coordinate
x_ambient = jnp.concatenate([
    jnp.sqrt(jnp.sum(x**2, axis=-1, keepdims=True) + 1.0),
    x
], axis=-1)  # (8, 28, 28, 16)

# Forward pass (input_space="manifold" by default)
output = conv(x_ambient, c=1.0)
print(output.shape)  # (8, 28, 28, 32) — HCat growth is internal; the linear projects to out_channels

Dimensional Growth

The HCat step inside hyperboloid convolutions increases dimensionality before the linear projection:

  • Input: d+1 ambient dimensions per point (d = manifold dimension)
  • HCat output: (d×N)+1 ambient dimensions, where N = kernel_height × kernel_width

For a 3×3 kernel, a 3D ambient input (d = 2) produces a 19D HCat intermediate; the linear layer then projects down to out_channels. Use small kernels to keep the intermediate dimension (and kernel parameter count) manageable.

Poincaré Convolution

hyperbolix.nn_layers.HypConv2DPoincare

HypConv2DPoincare(
    manifold_module: Poincare,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "tangent",
    id_init: bool = True,
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic 2D Convolutional Layer for Poincaré ball model.

This layer implements hyperbolic convolution following the Poincaré ResNet (van Spengler et al. 2023) approach: beta-scaling in tangent space, patch extraction, expmap to manifold, HNN++ fully-connected, logmap back to tangent space.

The layer operates in tangent space internally and returns tangent-space output, matching the reference implementation. This design avoids the numerically unstable logmap_0 round-trips that cause NaN gradients when points approach the Poincaré ball boundary.

Computation steps: 1) Map to tangent space if input is on manifold 2) Scale tangent vectors by beta function ratio (beta-concatenation scaling) 3) Extract patches via im2col (zero-padding in tangent space) 4) Map concatenated patch vectors to manifold via expmap_0 5) Apply HNN++ fully-connected layer 6) Map back to tangent space via logmap_0

Parameters:

Name Type Description Default
manifold_module Poincare

Class-based Poincaré manifold instance

required
in_channels int

Number of input channels (Poincaré ball dimension per pixel)

required
out_channels int

Number of output channels (Poincaré ball dimension per pixel)

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel

required
rngs Rngs

Random number generators for parameter initialization

required
stride int or tuple[int, int]

Stride of the convolution (default: 1)

1
padding str

Padding mode, either 'SAME' or 'VALID' (default: 'SAME')

'SAME'
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'tangent'). Note: This is a static configuration - changing it after initialization requires recompilation.

'tangent'
id_init bool

If True, use identity initialization (1/2 * I) for the linear sublayer weights, matching the reference Poincaré ResNet implementation (default: True). The 1/2 factor compensates for the factor of 2 inside the HNN++ distance formula.

True
clamping_factor float

Clamping factor for the HNN++ linear layer output (default: 1.0)

1.0
smoothing_factor float

Smoothing factor for the HNN++ linear layer output (default: 50.0)

50.0
Notes

Output Space: This layer always returns tangent-space output (matching the reference). Between conv layers, use standard activations (e.g., jax.nn.relu) directly on the tangent-space features. Use expmap_0 to map back to the manifold only when needed (e.g., before the classification head).

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (padding, input_space) are treated as static and will be baked into the compiled function.

Dimension math: - beta-scaling + patch extraction: (H, W, C_in) → (oh, ow, K^2 * C_in) - HNN++ linear: in_dim = K^2 * C_in, out_dim = C_out

References

Shimizu et al. "Hyperbolic neural networks++." arXiv:2006.08210 (2020). van Spengler et al. "Poincaré ResNet." ICCV 2023.

Source code in hyperbolix/nn_layers/poincare_conv.py
def __init__(
    self,
    manifold_module: Poincare,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    input_space: str = "tangent",
    id_init: bool = True,
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if padding not in ["SAME", "VALID"]:
        raise ValueError(f"padding must be either 'SAME' or 'VALID', got '{padding}'")
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration
    validate_poincare_manifold(
        manifold_module,
        required_methods=("expmap_0", "logmap_0", "compute_mlr_pp"),
    )
    self.manifold = manifold_module
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.input_space = input_space
    self.padding = padding

    # Handle kernel_size as int or tuple
    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    else:
        self.kernel_size = kernel_size

    # Handle stride as int or tuple
    if isinstance(stride, int):
        self.stride = (stride, stride)
    else:
        self.stride = stride

    # Precompute beta function ratio for tangent-space scaling
    # B(n/2, 1/2) / B(n_i/2, 1/2) where n = K^2 * C_in, n_i = C_in
    kernel_h, kernel_w = self.kernel_size
    K2 = kernel_h * kernel_w
    concat_dim = K2 * in_channels
    beta_n = jax.scipy.special.beta(concat_dim / 2.0, 0.5)
    beta_ni = jax.scipy.special.beta(in_channels / 2.0, 0.5)
    self.beta_scale = float(beta_n / beta_ni)

    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters — owned directly for flat parameter paths
    # Weight initialization: identity from reference (van Spengler et al. 2023)
    # or scaled normal std = 1/sqrt(fan_in)
    if id_init:
        # W = 1/2 * I(C_out, K^2*C_in)
        # The 1/2 factor compensates for the factor of 2 in the HNN++ distance formula.
        kernel_init = 0.5 * jnp.eye(out_channels, concat_dim)
    else:
        std = 1.0 / jnp.sqrt(concat_dim)
        kernel_init = jax.random.normal(rngs.params(), (out_channels, concat_dim)) * std
    self.kernel = nnx.Param(kernel_init)
    self.bias = nnx.Param(jnp.zeros((out_channels, 1)))
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the Poincaré convolutional layer.

Follows the reference computation flow: tangent-space beta-scaling, patch extraction, expmap_0, HNN++ FC, logmap_0.

Parameters:

Name Type Description Default
x Array of shape (batch, height, width, in_channels)

Input feature map. Can be tangent-space or manifold points depending on input_space setting.

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
out Array of shape (batch, out_height, out_width, out_channels)

Output feature map in tangent space at the origin. Use standard activations (e.g., jax.nn.relu) between layers.

Source code in hyperbolix/nn_layers/poincare_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """
    Forward pass through the Poincaré convolutional layer.

    Follows the reference computation flow: tangent-space beta-scaling,
    patch extraction, expmap_0, HNN++ FC, logmap_0.

    Parameters
    ----------
    x : Array of shape (batch, height, width, in_channels)
        Input feature map. Can be tangent-space or manifold points
        depending on input_space setting.
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    out : Array of shape (batch, out_height, out_width, out_channels)
        Output feature map in tangent space at the origin.
        Use standard activations (e.g., jax.nn.relu) between layers.
    """
    # Step 1: Map to tangent space if input is on manifold
    if self.input_space == "manifold":
        orig_shape = x.shape
        x_flat_NC = x.reshape(-1, x.shape[-1])  # (B*H*W, C_in)
        x_flat_NC = jax.vmap(self.manifold.logmap_0, in_axes=(0, None))(x_flat_NC, c)
        x = x_flat_NC.reshape(orig_shape)

    # Now x is in tangent space: (B, H, W, C_in)

    # Step 2: Scale tangent vectors by beta ratio (matching reference)
    x = x * self.beta_scale

    # Step 3: Extract patches in tangent space using zero-padding
    kernel_h, kernel_w = self.kernel_size
    stride_h, stride_w = self.stride

    patches_BHWKC = jax.lax.conv_general_dilated_patches(
        lhs=x,
        filter_shape=(kernel_h, kernel_w),
        window_strides=(stride_h, stride_w),
        padding=self.padding,
        dimension_numbers=("NHWC", "OIHW", "NHWC"),
    )  # (B, out_h, out_w, K²·C_in)

    batch, out_h, out_w, concat_dim = patches_BHWKC.shape

    # Step 4: Map concatenated patch vectors to manifold via expmap_0
    patches_flat_NKC = patches_BHWKC.reshape(-1, concat_dim)  # (N, K²·C_in) where N=B*H*W
    manifold_pts_NKC = jax.vmap(self.manifold.expmap_0, in_axes=(0, None))(
        patches_flat_NKC, c
    )  # (N, K²·C_in) on Poincaré ball

    # Step 5: HNN++ FC — (N, K²·C_in) on manifold → (N, C_out) on manifold
    fc_out_NC = _poincare_pp_forward(
        manifold_pts_NKC,
        self.kernel[...],
        self.bias[...],
        self.manifold,
        c,
        "manifold",
        self.clamping_factor,
        self.smoothing_factor,
    )

    # Step 6: Map back to tangent space
    tangent_out_NC = jax.vmap(self.manifold.logmap_0, in_axes=(0, None))(fc_out_NC, c)  # (N, C_out)

    # Reshape to spatial output
    output_BHWC = tangent_out_NC.reshape(batch, out_h, out_w, self.out_channels)

    return output_BHWC

HypConv2DPoincare extracts patches, applies beta-concatenation (HNN++, Shimizu et al. 2020) over the receptive field, then passes through a HypLinearPoincarePP layer. Dimension math: K² × C_in → C_out where K is the kernel size.
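The beta-scaling factor precomputed in `__init__` can be sketched with the standard Beta function identity B(a, b) = Γ(a)Γ(b)/Γ(a+b); `beta_concat_scale` below is a hypothetical standalone version:

```python
import math

def beta_concat_scale(in_channels: int, kernel_size: int) -> float:
    """B(n/2, 1/2) / B(n_i/2, 1/2) with n = K²·C_in and n_i = C_in."""
    beta = lambda a, b: math.gamma(a) * math.gamma(b) / math.gamma(a + b)
    n = kernel_size**2 * in_channels
    return beta(n / 2.0, 0.5) / beta(in_channels / 2.0, 0.5)

print(beta_concat_scale(16, 3))  # ≈ 0.33 for a 3×3 kernel over 16 channels
```

Since B(a, 1/2) ≈ sqrt(π/a) for large a, the scale approaches sqrt(n_i/n) = 1/K, i.e. roughly 1/3 for a 3×3 kernel.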

Key Differences from Hyperboloid Convolutions:

| Feature | HypConv2DPoincare | HypConv2DHyperboloid |
|---|---|---|
| Model | Poincaré ball | Hyperboloid |
| Aggregation | Beta-concatenation | HCat (Lorentz concatenation) |
| Dimension | Preserved | Grows: (d-1)×K²+1 |
| Default input | Tangent space | Manifold (ambient) |

Usage Example

import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import HypConv2DPoincare
from hyperbolix.manifolds import Poincare
from flax import nnx

poincare = Poincare()

# Create Poincaré 2D convolution
conv = HypConv2DPoincare(
    manifold_module=poincare,
    in_channels=16,
    out_channels=32,
    kernel_size=3,
    stride=1,
    rngs=nnx.Rngs(0)
)

# Input: (batch, height, width, in_channels) in tangent space (default input_space="tangent")
x = jax.random.normal(jax.random.PRNGKey(1), (8, 28, 28, 16)) * 0.1

# Forward pass — returns tangent-space output
output = conv(x, c=1.0)
print(output.shape)  # (8, 28, 28, 32)

Poincaré Batch Normalization

hyperbolix.nn_layers.PoincareBatchNorm2D

PoincareBatchNorm2D(
    manifold_module: Poincare,
    num_features: int,
    *,
    use_running_average: bool = False,
    momentum: float = 0.9,
    eps: float = 1e-06,
)

Bases: Module

Poincaré batch normalization for 2D feature maps.

Operates in tangent space (matching conv layer I/O). Internally maps to the manifold for geometric operations (midpoint, parallel transport, variance rescaling), then maps back to tangent space.

Follows nnx.BatchNorm interface: use_running_average is a constructor parameter overridable at call time.

Parameters:

Name Type Description Default
manifold_module Poincare

Poincaré manifold instance.

required
num_features int

Number of channels (Poincaré ball dimension per pixel).

required
use_running_average bool

If True, use running statistics instead of batch statistics (default: False). Overridable at call time.

False
momentum float

EMA momentum for running statistics (default: 0.9).

0.9
eps float

Numerical stability floor (default: 1e-6).

1e-06
References

van Spengler et al. "Poincaré ResNet." ICML 2023.

Source code in hyperbolix/nn_layers/poincare_batchnorm.py
def __init__(
    self,
    manifold_module: Poincare,
    num_features: int,
    *,
    use_running_average: bool = False,
    momentum: float = 0.9,
    eps: float = 1e-6,
):
    validate_poincare_manifold(
        manifold_module,
        required_methods=("expmap_0", "logmap_0", "logmap", "expmap", "ptransp", "proj", "dist", "conformal_factor"),
    )
    self.manifold = manifold_module
    self.num_features = num_features
    self.use_running_average = use_running_average
    self.momentum = momentum
    self.eps = eps

    # Learnable parameters (tangent space)
    self.mean = nnx.Param(jnp.zeros((num_features,)))  # learned mean
    self.var = nnx.Param(jnp.ones(()))  # learned variance (scalar)

    # Running statistics (tangent space)
    self.running_mean = nnx.BatchStat(jnp.zeros((num_features,)))
    self.running_var = nnx.BatchStat(jnp.ones(()))
__call__
__call__(
    x_BHWC: Float[Array, "B H W C"],
    c: float = 1.0,
    use_running_average: bool | None = None,
) -> Float[Array, "B H W C"]

Apply Poincaré batch normalization.

Parameters:

Name Type Description Default
x_BHWC (Array, shape(B, H, W, C))

Tangent-space input features.

required
c float

Curvature (positive, default: 1.0).

1.0
use_running_average bool or None

Override constructor setting. None uses constructor value.

None

Returns:

Type Description
(Array, shape(B, H, W, C))

Normalized tangent-space output.

Source code in hyperbolix/nn_layers/poincare_batchnorm.py
def __call__(
    self,
    x_BHWC: Float[Array, "B H W C"],
    c: float = 1.0,
    use_running_average: bool | None = None,
) -> Float[Array, "B H W C"]:
    """Apply Poincaré batch normalization.

    Parameters
    ----------
    x_BHWC : Array, shape (B, H, W, C)
        Tangent-space input features.
    c : float
        Curvature (positive, default: 1.0).
    use_running_average : bool or None
        Override constructor setting. None uses constructor value.

    Returns
    -------
    Array, shape (B, H, W, C)
        Normalized tangent-space output.
    """
    # Resolve use_running_average: call-time > constructor
    if use_running_average is None:
        use_running_average = self.use_running_average

    B, H, W, C = x_BHWC.shape

    # Flatten (B, H, W, C) → (N, C)
    x_NC = x_BHWC.reshape(-1, C)  # (N, C)

    # --- Map to manifold ---
    # x_NC are tangent vectors; map to Poincaré ball
    x_manifold_NC = jax.vmap(self.manifold.expmap_0, in_axes=(0, None))(x_NC, c)  # (N, C)

    if use_running_average:
        # Use running stats (eval mode)
        batch_mean_tangent_C = self.running_mean[...]  # (C,)
        batch_var = self.running_var[...]  # scalar
    else:
        # Compute batch stats on manifold
        batch_midpoint_C = poincare_midpoint(x_manifold_NC, self.manifold, c, self.eps)  # (C,)
        batch_var = frechet_variance(x_manifold_NC, batch_midpoint_C, self.manifold, c)  # scalar

        # Map midpoint back to tangent space for running stat storage
        batch_mean_tangent_C = self.manifold.logmap_0(batch_midpoint_C, c)  # (C,)

        # Update running statistics (EMA, no gradient flow)
        self.running_mean[...] = jax.lax.stop_gradient(
            self.momentum * self.running_mean[...] + (1.0 - self.momentum) * batch_mean_tangent_C
        )
        self.running_var[...] = jax.lax.stop_gradient(
            self.momentum * self.running_var[...] + (1.0 - self.momentum) * batch_var
        )

    # Get the manifold-space mean to use for geometric operations
    # (either from batch or from running stats)
    if use_running_average:
        input_mean_C = self.manifold.expmap_0(batch_mean_tangent_C, c)  # (C,)
    else:
        input_mean_C = batch_midpoint_C  # already computed above  # (C,)
    current_var = batch_var

    # Learned mean on manifold
    learned_mean_C = self.manifold.expmap_0(self.mean[...], c)  # (C,)

    # --- Step 4: logmap all points at input mean ---
    v_NC = jax.vmap(self.manifold.logmap, in_axes=(0, None, None))(x_manifold_NC, input_mean_C, c)  # (N, C)

    # --- Step 5: parallel transport from input mean to learned mean ---
    v_NC = jax.vmap(self.manifold.ptransp, in_axes=(0, None, None, None))(v_NC, input_mean_C, learned_mean_C, c)  # (N, C)

    # --- Step 6: rescale by sqrt(learned_var / (batch_var + eps)) ---
    scale = jnp.sqrt(self.var[...] / (current_var + self.eps))  # scalar
    v_NC = v_NC * scale  # (N, C)

    # --- Step 7: expmap at learned mean ---
    out_NC = jax.vmap(self.manifold.expmap, in_axes=(0, None, None))(v_NC, learned_mean_C, c)  # (N, C)

    # --- Step 8: logmap_0 back to tangent space ---
    out_NC = jax.vmap(self.manifold.logmap_0, in_axes=(0, None))(out_NC, c)  # (N, C)

    # Reshape back to (B, H, W, C)
    return out_NC.reshape(B, H, W, C)

PoincareBatchNorm2D operates in tangent space (matching conv layer I/O), mapping to the manifold internally for geometric operations: Einstein midpoint, Fréchet variance, parallel transport, and variance rescaling. Use between Poincaré convolution layers following the reference ResNet pattern: conv → bn → relu → conv → bn → skip.

hyperbolix.nn_layers.poincare_midpoint

poincare_midpoint(
    x_NC: Float[Array, "N C"],
    manifold: Poincare,
    c: float,
    eps: float = 1e-06,
) -> Float[Array, C]

Compute Einstein midpoint of Poincaré ball points.

Uses conformal factor weighting: midpoint = Σ(λ²·x) / Σ(λ²), then projects onto the ball.

Parameters:

Name Type Description Default
x_NC (Array, shape(N, C))

Points on the Poincaré ball.

required
manifold Poincare

Poincaré manifold instance.

required
c float

Curvature (positive).

required
eps float

Numerical stability floor (default: 1e-6).

1e-06

Returns:

Type Description
(Array, shape(C))

Einstein midpoint on the Poincaré ball.

Source code in hyperbolix/nn_layers/poincare_batchnorm.py
def poincare_midpoint(
    x_NC: Float[Array, "N C"],
    manifold: Poincare,
    c: float,
    eps: float = 1e-6,
) -> Float[Array, "C"]:
    """Compute Einstein midpoint of Poincaré ball points.

    Uses conformal factor weighting: midpoint = Σ(λ²·x) / Σ(λ²),
    then projects onto the ball.

    Parameters
    ----------
    x_NC : Array, shape (N, C)
        Points on the Poincaré ball.
    manifold : Poincare
        Poincaré manifold instance.
    c : float
        Curvature (positive).
    eps : float
        Numerical stability floor (default: 1e-6).

    Returns
    -------
    Array, shape (C,)
        Einstein midpoint on the Poincaré ball.
    """
    # lambda_N1: (N, 1) — conformal factor for each point
    lambda_N1 = manifold.conformal_factor(x_NC, c)  # (N, 1)
    lambda_sq_N1 = lambda_N1**2  # (N, 1)

    # Weighted sum: Σ(λ²·x) / Σ(λ²)
    numerator_C = jnp.sum(lambda_sq_N1 * x_NC, axis=0)  # (C,)
    denominator = jnp.sum(lambda_sq_N1, axis=0) + eps  # (1,)

    midpoint_C = numerator_C / denominator  # (C,)
    return manifold.proj(midpoint_C, c)

hyperbolix.nn_layers.frechet_variance

frechet_variance(
    x_NC: Float[Array, "N C"],
    mean_C: Float[Array, C],
    manifold: Poincare,
    c: float,
) -> Float[Array, ""]

Compute Fréchet variance: mean squared geodesic distance to mean.

Parameters:

Name Type Description Default
x_NC (Array, shape(N, C))

Points on the Poincaré ball.

required
mean_C (Array, shape(C))

Mean point on the Poincaré ball.

required
manifold Poincare

Poincaré manifold instance.

required
c float

Curvature (positive).

required

Returns:

Type Description
(Array, scalar)

Mean of squared geodesic distances.

Source code in hyperbolix/nn_layers/poincare_batchnorm.py
def frechet_variance(
    x_NC: Float[Array, "N C"],
    mean_C: Float[Array, "C"],
    manifold: Poincare,
    c: float,
) -> Float[Array, ""]:
    """Compute Fréchet variance: mean squared geodesic distance to mean.

    Parameters
    ----------
    x_NC : Array, shape (N, C)
        Points on the Poincaré ball.
    mean_C : Array, shape (C,)
        Mean point on the Poincaré ball.
    manifold : Poincare
        Poincaré manifold instance.
    c : float
        Curvature (positive).

    Returns
    -------
    Array, scalar
        Mean of squared geodesic distances.
    """
    # dist per point: (N,)
    dists_N = jax.vmap(manifold.dist, in_axes=(0, None, None))(x_NC, mean_C, c)
    return jnp.mean(dists_N**2)

Poincaré BatchNorm Example

from hyperbolix.nn_layers import HypConv2DPoincare, PoincareBatchNorm2D
from hyperbolix.manifolds import Poincare
from flax import nnx
import jax

poincare = Poincare()

# ResNet-style block: conv → bn → relu
conv = HypConv2DPoincare(
    manifold_module=poincare,
    in_channels=16, out_channels=16,
    kernel_size=3, rngs=nnx.Rngs(0),
)
bn = PoincareBatchNorm2D(poincare, num_features=16)

# Input: tangent-space features (B, H, W, C)
x = jax.random.normal(jax.random.PRNGKey(0), (4, 8, 8, 16)) * 0.1

# Training: use_running_average=False (default)
h = conv(x, c=0.1)
h = bn(h, c=0.1)
h = jax.nn.relu(h)

# Evaluation: use_running_average=True
h_eval = conv(x, c=0.1)
h_eval = bn(h_eval, c=0.1, use_running_average=True)
h_eval = jax.nn.relu(h_eval)

FGGConv2D (Klis et al. 2026)

hyperbolix.nn_layers.FGGConv2D

FGGConv2D(
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    pad_mode: str = "origin",
    activation: Callable | None = None,
    reset_params: str = "lorentz_kaiming",
    use_weight_norm: bool = False,
    init_bias: float = 0.5,
    eps: float = 1e-07,
)

Bases: Module

Fast and Geometrically Grounded Lorentz 2D convolutional layer.

Uses HCat (Lorentz direct concatenation) to combine receptive field points, then applies FGGLinear for the channel mixing. This matches the reference implementation pattern from Klis et al. 2026.

Computation steps: 1) Extract receptive field patches, pad with manifold origin if needed 2) Apply HCat (Lorentz direct concatenation) to combine patch points 3) Pass through FGGLinear for channel transformation

Parameters:

Name Type Description Default
manifold_module Hyperboloid

Class-based Hyperboloid manifold instance.

required
in_channels int

Input ambient channels (D_in + 1), including time component.

required
out_channels int

Output ambient channels (D_out + 1), including time component.

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel.

required
rngs Rngs

Random number generators for parameter initialization.

required
stride int or tuple[int, int]

Stride of the convolution (default: 1).

1
padding str

Padding mode: "SAME" or "VALID" (default: "SAME").

'SAME'
pad_mode str

How to fill padding pixels: "origin" fills with the manifold origin (sqrt(1/c), 0, ..., 0) (matching reference), "edge" replicates border values (default: "origin").

'origin'
activation Callable or None

Euclidean activation for the FGGLinear (default: None).

None
reset_params str

Weight init for FGGLinear: "eye", "xavier", "kaiming", "lorentz_kaiming", or "mlr" (default: "lorentz_kaiming").

'lorentz_kaiming'
use_weight_norm bool

Weight normalization in FGGLinear (default: False).

False
init_bias float

Initial bias for FGGLinear (default: 0.5).

0.5
eps float

Numerical stability floor (default: 1e-7).

1e-07
References

Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026).

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
    pad_mode: str = "origin",
    activation: Callable | None = None,
    reset_params: str = "lorentz_kaiming",
    use_weight_norm: bool = False,
    init_bias: float = 0.5,
    eps: float = 1e-7,
):
    if padding not in ("SAME", "VALID"):
        raise ValueError(f"padding must be 'SAME' or 'VALID', got '{padding}'")
    if pad_mode not in ("origin", "edge"):
        raise ValueError(f"pad_mode must be 'origin' or 'edge', got '{pad_mode}'")

    validate_hyperboloid_manifold(manifold_module, required_methods=("hcat",))
    self.manifold = manifold_module
    self.in_channels = in_channels
    self.out_channels = out_channels
    self.padding = padding
    self.pad_mode = pad_mode
    self.eps = eps

    if isinstance(kernel_size, int):
        self.kernel_size = (kernel_size, kernel_size)
    else:
        self.kernel_size = kernel_size

    if isinstance(stride, int):
        self.stride = (stride, stride)
    else:
        self.stride = stride

    # HCat output ambient dim: (in_channels - 1) * kh * kw + 1
    kh, kw = self.kernel_size
    hcat_out_ambient = (in_channels - 1) * kh * kw + 1

    # Trainable parameters — owned directly for flat parameter paths
    if reset_params not in ("eye", "xavier", "kaiming", "lorentz_kaiming", "mlr"):
        raise ValueError(
            f"reset_params must be 'eye', 'xavier', 'kaiming', 'lorentz_kaiming', or 'mlr', got '{reset_params}'"
        )

    in_spatial = hcat_out_ambient - 1  # I
    out_spatial = out_channels - 1  # O

    self.activation = activation
    self.use_weight_norm = use_weight_norm

    # Initialize Euclidean weight U: (I, O)
    # Reference computes std from ambient dimensions (hcat_out_ambient, out_channels)
    key = rngs.params()
    if reset_params == "eye":
        U_init = 0.5 * jnp.eye(in_spatial, out_spatial)
    elif reset_params == "xavier":
        std = jnp.sqrt(1.0 / (hcat_out_ambient + out_channels))
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    elif reset_params == "kaiming":
        std = jnp.sqrt(2.0 / hcat_out_ambient)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    elif reset_params == "lorentz_kaiming":
        std = jnp.sqrt(1.0 / hcat_out_ambient)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std
    else:  # mlr
        std = jnp.sqrt(5.0 / hcat_out_ambient)
        U_init = jax.random.normal(key, (in_spatial, out_spatial)) * std

    # Weight normalization: decompose kernel = softplus(kernel_scale) * kernel_dir / ||kernel_dir||
    if use_weight_norm:
        self.kernel_dir = nnx.Param(U_init)  # (I, O) direction
        g_init_val = jnp.sqrt(1.0 / (hcat_out_ambient + out_channels))
        self.kernel_scale = nnx.Param(jnp.full((out_spatial,), g_init_val))  # (O,)
    else:
        self.kernel = nnx.Param(U_init)  # (I, O)

    self.bias = nnx.Param(jnp.full((out_spatial,), init_bias))  # (O,)
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the FGG convolutional layer.

Parameters:

Name Type Description Default
x (Array, shape(B, H, W, in_channels))

Input feature map on the hyperboloid.

required
c float

Curvature parameter (default: 1.0).

1.0

Returns:

Name Type Description
out Array, shape (B, H', W', out_channels)

Output feature map on the hyperboloid.

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """Forward pass through the FGG convolutional layer.

    Parameters
    ----------
    x : Array, shape (B, H, W, in_channels)
        Input feature map on the hyperboloid.
    c : float, optional
        Curvature parameter (default: 1.0).

    Returns
    -------
    out : Array, shape (B, H', W', out_channels)
        Output feature map on the hyperboloid.
    """
    # Extract patches: (B, H', W', kh, kw, C)
    patches = self._extract_patches(x, c)
    batch, out_h, out_w, kh, kw, in_c = patches.shape

    # Flatten batch+spatial: (B*H'*W', K, C) where K = kh*kw
    patches_flat_NKC = patches.reshape(-1, kh * kw, in_c)

    # HCat: (K, C) -> (hcat_dim,) per patch
    hcat_out_NA = jax.vmap(self.manifold.hcat, in_axes=(0, None))(patches_flat_NKC, c)

    # FGG forward: (hcat_dim,) -> (out_channels,)
    U_IO = _get_effective_kernel(
        getattr(self, "kernel", None),
        getattr(self, "kernel_dir", None),
        getattr(self, "kernel_scale", None),
        self.use_weight_norm,
        self.eps,
    )
    linear_out_NC = _fgg_linear_forward(hcat_out_NA, U_IO, self.bias[...], c, self.activation, self.eps)

    # Reshape back to spatial
    return linear_out_NC.reshape(batch, out_h, out_w, self.out_channels)

FGGConv2D combines HCat patch extraction with FGGLinear for channel mixing. Unlike HypConv2DHyperboloid, it uses the FGG spacelike-V construction, achieving linear growth of hyperbolic distance rather than logarithmic. Supports manifold-origin padding (pad_mode="origin") matching the reference implementation.

| Feature | FGGConv2D | HypConv2DHyperboloid | HypConv2DHyperboloidPP |
|---|---|---|---|
| Linear layer | FGGLinear (V-matrix) | HypLinearHyperboloidFHCNN | HypLinearHyperboloidPP (MLR) |
| Distance growth | Linear | Logarithmic | Logarithmic |
| Default padding | Manifold origin | Edge replication | Edge replication |
| Weight norm | Optional (use_weight_norm) | No | No |

LorentzConv2D (HRC-Based)

hyperbolix.nn_layers.LorentzConv2D

LorentzConv2D(
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
)

Bases: Module

Lorentz 2D Convolutional Layer using the Hyperbolic Layer (HL) approach.

This layer applies convolution to the space-like components of Lorentzian vectors and reconstructs the time-like component to maintain the manifold constraint. This is equivalent to an HRC (Hyperbolic Regularization Component) wrapper around a standard Conv2D.

Computation steps: 1) Extract space-like components x_s from input x = [x_t, x_s]^T 2) Apply Euclidean convolution: y_s = Conv2D(x_s) 3) Reconstruct time component: y_t = sqrt(||y_s||^2 + 1/c) 4) Return y = [y_t, y_s]^T

Parameters:

Name Type Description Default
in_channels int

Number of input channels (ambient dimension, including time component)

required
out_channels int

Number of output channels (ambient dimension, including time component)

required
kernel_size int or tuple[int, int]

Size of the convolutional kernel

required
rngs Rngs

Random number generators for parameter initialization

required
stride int or tuple[int, int]

Stride of the convolution (default: 1)

1
padding str or int or tuple

Padding mode: 'SAME', 'VALID', or explicit padding (default: 'SAME')

'SAME'
Notes

This implementation follows the Hyperbolic Layer (HL) approach from "Fully Hyperbolic Convolutional Neural Networks for Computer Vision".

The layer operates only on space-like components, making it more computationally efficient than the HCat-based approach (HypConv2DHyperboloid), though it doesn't perform true hyperbolic convolution. Instead, it applies Euclidean operations to spatial components and reconstructs the time component.

See Also

hypformer.hrc : Core HRC function this layer is based on
HypConv2DHyperboloid : Full hyperbolic convolution using HCat concatenation

References

He, Neil, Menglin Yang, and Rex Ying. "Lorentzian residual neural networks." Proceedings of the 31st ACM SIGKDD Conference on Knowledge Discovery and Data Mining V. 1. 2025.

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __init__(
    self,
    in_channels: int,
    out_channels: int,
    kernel_size: int | tuple[int, int],
    *,
    rngs: nnx.Rngs,
    stride: int | tuple[int, int] = 1,
    padding: str = "SAME",
):
    self.in_channels = in_channels
    self.out_channels = out_channels

    # Create Euclidean conv layer for space components only
    # in_channels - 1: skip time component at index 0
    # out_channels - 1: time will be reconstructed from constraint
    self.conv = nnx.Conv(
        in_features=in_channels - 1,
        out_features=out_channels - 1,
        kernel_size=kernel_size,
        strides=stride,
        padding=padding,
        rngs=rngs,
    )
__call__
__call__(
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[
    Array, "batch out_height out_width out_channels"
]

Forward pass through the Lorentz convolutional layer.

This layer is a specific instance of the Hyperbolic Regularization Component (HRC) where the regularization function f_r is a 2D convolution. The HRC pattern: 1. Extracts space components 2. Applies Euclidean convolution 3. Reconstructs time component using Lorentz constraint

Parameters:

Name Type Description Default
x Array of shape (batch, height, width, in_channels)

Input feature map where x[..., 0] is time component and x[..., 1:] are space components on the Lorentz manifold

required
c float

Manifold curvature parameter (default: 1.0)

1.0

Returns:

Name Type Description
out Array of shape (batch, out_height, out_width, out_channels)

Output feature map on the Lorentz manifold

Notes

This implementation uses the HRC function from hypformer.py, demonstrating that LorentzConv2D (from LResNet) and HRC (from Hypformer) are mathematically equivalent approaches to adapting Euclidean operations for hyperbolic geometry.

Source code in hyperbolix/nn_layers/hyperboloid_conv.py
def __call__(
    self,
    x: Float[Array, "batch height width in_channels"],
    c: float = 1.0,
) -> Float[Array, "batch out_height out_width out_channels"]:
    """
    Forward pass through the Lorentz convolutional layer.

    This layer is a specific instance of the Hyperbolic Regularization Component (HRC)
    where the regularization function f_r is a 2D convolution. The HRC pattern:
    1. Extracts space components
    2. Applies Euclidean convolution
    3. Reconstructs time component using Lorentz constraint

    Parameters
    ----------
    x : Array of shape (batch, height, width, in_channels)
        Input feature map where x[..., 0] is time component and
        x[..., 1:] are space components on the Lorentz manifold
    c : float
        Manifold curvature parameter (default: 1.0)

    Returns
    -------
    out : Array of shape (batch, out_height, out_width, out_channels)
        Output feature map on the Lorentz manifold

    Notes
    -----
    This implementation uses the HRC function from hypformer.py, demonstrating that
    LorentzConv2D (from LResNet) and HRC (from Hypformer) are mathematically equivalent
    approaches to adapting Euclidean operations for hyperbolic geometry.
    """

    # Define convolution as the HRC regularization function f_r
    def conv_fn(x_space):
        return self.conv(x_space)

    # Apply HRC with curvature-preserving transformation (c_in = c_out = c)
    return hrc(x, conv_fn, c_in=c, c_out=c, eps=1e-8)

LorentzConv2D provides a simpler, more efficient alternative to HCat-based convolutions by using the Hyperbolic Regularization Component (HRC) pattern from the Hypformer paper.

Key Differences from HypConv2DHyperboloid:

| Feature | HypConv2DHyperboloid (HCat+FHCNN) | HypConv2DHyperboloidPP (HCat+HNN++) | LorentzConv2D (HRC) |
|---|---|---|---|
| Method | HCat + FHCNN linear | HCat + MLR + sinh diffeomorphism | Euclidean conv on space components |
| Dimension | Grows: (d-1)×N+1 | Grows: (d-1)×N+1 | Preserved |
| Speed | Slower (~80s/epoch) | Similar to FHCNN | 2.5x faster (~32s/epoch) |
| Accuracy | Higher (~71% on MNIST) | Better gradient flow (HNN++) | Lower (~46% on MNIST) |
| Use case | HCat accuracy | Deep HCat networks | Speed/memory efficiency |

Theoretical Connection:

LorentzConv2D implements the Hyperbolic Layer (HL) pattern from LResNet, which is mathematically equivalent to the Hyperbolic Regularization Component (HRC) from Hypformer:

# Both approaches:
# 1. Extract space components: x_s = x[..., 1:]
# 2. Apply Euclidean function: y_s = f(x_s)
# 3. Reconstruct time: y_t = sqrt(||y_s||^2 + 1/c)

Usage Example:

from hyperbolix.nn_layers import LorentzConv2D
from flax import nnx
import jax.numpy as jnp

# Create efficient hyperbolic convolution
conv = LorentzConv2D(
    in_channels=33,    # Including time component
    out_channels=65,   # Including time component
    kernel_size=3,
    stride=2,
    padding="SAME",
    rngs=nnx.Rngs(0)
)

# Input: points on the Lorentz manifold (batch, height, width, in_channels).
# Build valid hyperboloid points: the time component follows from the
# Lorentz constraint x_t = sqrt(||x_s||^2 + 1/c), here with c = 1.
x_space = jnp.ones((8, 28, 28, 32))
x_time = jnp.sqrt(jnp.sum(x_space**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([x_time, x_space], axis=-1)  # (8, 28, 28, 33)

# Forward pass
output = conv(x, c=1.0)
print(output.shape)  # (8, 14, 14, 65): stride 2 halves H and W; channels set by out_channels, no HCat growth

When to Use LorentzConv2D

Choose LorentzConv2D when:

  • Speed and memory efficiency are priorities
  • Working with resource-constrained environments
  • Acceptable accuracy trade-off for 2.5x speedup

Choose HypConv2DHyperboloid or HypConv2DHyperboloidPP when:

  • Maximum accuracy is required
  • Willing to accept slower training and dimensional growth
  • Use HypConv2DHyperboloidPP for deeper networks (better gradient flow via HNN++)

Hypformer Components

The Hyperbolic Transformation Component (HTC) and Hyperbolic Regularization Component (HRC) from the Hypformer paper provide general-purpose wrappers for adapting Euclidean operations to hyperbolic geometry with curvature-change support.

Core Functions

hyperbolix.nn_layers.hrc

hrc(
    x: Float[Array, "... dim_plus_1"],
    f_r: Callable[[Float[Array, ...]], Float[Array, ...]],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... out_dim_plus_1"]

Hyperbolic Regularization Component.

Applies a Euclidean regularization/activation function f_r to the spatial components of hyperboloid points, then maps the result to the hyperboloid with curvature c_out.

Mathematical formula: space = sqrt(c_in/c_out) * f_r(x[..., 1:]) time = sqrt(||space||^2 + 1/c_out) output = [time, space]

When c_in = c_out = c, this reduces to: output = [sqrt(||f_r(x_s)||^2 + 1/c), f_r(x_s)] which is the pattern used by curvature-preserving hyperboloid activations.

Parameters:

Name Type Description Default
x Array of shape (..., dim+1)

Input point(s) on the hyperboloid manifold with curvature c_in. The first element is the time-like component, remaining are spatial.

required
f_r Callable

Euclidean function to apply to spatial components. Can be any activation, normalization, dropout, etc. Takes spatial components and returns transformed spatial components (may change dimension).

required
c_in float

Input curvature parameter (must be positive, c > 0).

required
c_out float

Output curvature parameter (must be positive, c > 0).

required
eps float

Small value for numerical stability (default: 1e-7).

1e-07

Returns:

Name Type Description
y Array of shape (..., out_dim+1)

Output point(s) on the hyperboloid manifold with curvature c_out.

Notes
  • f_r operates only on spatial components x[..., 1:], not the time component
  • The time component is reconstructed using the hyperboloid constraint: -x₀² + ||x_rest||² = -1/c_out
  • This avoids expensive exp/log maps while maintaining mathematical correctness
  • The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation
See Also

htc : Hyperbolic Transformation Component for full-point operations.

References

Yang, Menglin, et al. "Hypformer: Exploring Efficient Transformer Fully in Hyperbolic Space." KDD 2024.

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers.hyperboloid_core import hrc
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create a point on the hyperboloid
>>> manifold = Hyperboloid()
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> x = manifold.proj(x, c=1.0)
>>>
>>> # Apply HRC with ReLU (curvature-preserving)
>>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=1.0)
>>>
>>> # Apply HRC with curvature change
>>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=2.0)
>>>
>>> # Custom activation
>>> def custom_act(z):
...     return jax.nn.gelu(z) * 0.5
>>> y = hrc(x, custom_act, c_in=1.0, c_out=0.5)
Source code in hyperbolix/nn_layers/hyperboloid_core.py
def hrc(
    x: Float[Array, "... dim_plus_1"],
    f_r: Callable[[Float[Array, "..."]], Float[Array, "..."]],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... out_dim_plus_1"]:
    """Hyperbolic Regularization Component.

    Applies a Euclidean regularization/activation function f_r to the spatial
    components of hyperboloid points, then maps the result to the hyperboloid
    with curvature c_out.

    Mathematical formula:
        space = sqrt(c_in/c_out) * f_r(x[..., 1:])
        time  = sqrt(||space||^2 + 1/c_out)
        output = [time, space]

    When c_in = c_out = c, this reduces to:
        output = [sqrt(||f_r(x_s)||^2 + 1/c), f_r(x_s)]
    which is the pattern used by curvature-preserving hyperboloid activations.

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
        The first element is the time-like component, remaining are spatial.
    f_r : Callable
        Euclidean function to apply to spatial components. Can be any activation,
        normalization, dropout, etc. Takes spatial components and returns
        transformed spatial components (may change dimension).
    c_in : float
        Input curvature parameter (must be positive, c > 0).
    c_out : float
        Output curvature parameter (must be positive, c > 0).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., out_dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.

    Notes
    -----
    - f_r operates only on spatial components x[..., 1:], not the time component
    - The time component is reconstructed using the hyperboloid constraint:
      -x₀² + ||x_rest||² = -1/c_out
    - This avoids expensive exp/log maps while maintaining mathematical correctness
    - The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation

    See Also
    --------
    htc : Hyperbolic Transformation Component for full-point operations.

    References
    ----------
    Hypformer paper (citation to be added)

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers.hyperboloid_core import hrc
    >>> from hyperbolix.manifolds import Hyperboloid
    >>>
    >>> # Create a point on the hyperboloid
    >>> manifold = Hyperboloid()
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> x = manifold.proj(x, c=1.0)
    >>>
    >>> # Apply HRC with ReLU (curvature-preserving)
    >>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=1.0)
    >>>
    >>> # Apply HRC with curvature change
    >>> y = hrc(x, jax.nn.relu, c_in=1.0, c_out=2.0)
    >>>
    >>> # Custom activation
    >>> def custom_act(z):
    ...     return jax.nn.gelu(z) * 0.5
    >>> y = hrc(x, custom_act, c_in=1.0, c_out=0.5)
    """
    x_space_D = x[..., 1:]  # (..., D) spatial components

    out_space_D = f_r(x_space_D)  # (..., D') — may change dim

    # Scale for curvature transformation: sqrt(c_in / c_out)
    scale = jnp.sqrt(c_in / c_out)
    scaled_D = scale * out_space_D  # (..., D')

    # Reconstruct time via hyperboloid constraint: x₀ = sqrt(||x_rest||² + 1/c_out)
    norm_sq = jnp.sum(scaled_D**2, axis=-1)  # (...)
    x0 = jnp.sqrt(jnp.maximum(norm_sq + 1.0 / c_out, eps))  # (...)

    return jnp.concatenate([x0[..., None], scaled_D], axis=-1)  # (..., D'+1)
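The reconstruction step above guarantees that the output satisfies the hyperboloid constraint -x₀² + ||x_rest||² = -1/c_out. A standalone NumPy transcription of the hrc formula (an illustrative sketch, not the library code) makes this easy to check:

```python
import numpy as np

def hrc_np(x, f_r, c_in, c_out, eps=1e-7):
    """NumPy transcription of the hrc formula: scale spatial part, rebuild time."""
    space = np.sqrt(c_in / c_out) * f_r(x[..., 1:])
    time = np.sqrt(np.maximum(np.sum(space**2, axis=-1) + 1.0 / c_out, eps))
    return np.concatenate([time[..., None], space], axis=-1)

x = np.array([1.05, 0.1, -0.2, 0.15])          # point with curvature c_in = 1.0
y = hrc_np(x, lambda z: np.maximum(z, 0.0), c_in=1.0, c_out=2.0)  # ReLU as f_r

# Lorentz inner product of the output equals -1/c_out
lorentz = -y[0]**2 + np.sum(y[1:]**2)
print(lorentz)  # ≈ -0.5
```

Because the time component is recomputed from the transformed spatial part, any Euclidean f_r (activation, norm, dropout) yields a valid hyperboloid point.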

hyperbolix.nn_layers.htc

htc(
    x: Float[Array, "... in_dim_plus_1"],
    f_t: Callable[[Float[Array, ...]], Float[Array, ...]],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... out_dim_plus_1"]

Hyperbolic Transformation Component.

Applies a Euclidean linear transformation f_t to the full hyperboloid point (including time component), then maps the result to the hyperboloid with curvature c_out.

Mathematical formula:

    space  = sqrt(c_in/c_out) * f_t(x)
    time   = sqrt(||space||^2 + 1/c_out)
    output = [time, space]

where f_t takes the full (dim+1)-dimensional input and produces the output spatial components.

Parameters:

- x (Array of shape (..., in_dim+1), required): Input point(s) on the hyperboloid manifold with curvature c_in. All components (time and spatial) are passed to f_t.
- f_t (Callable, required): Euclidean linear transformation applied to the full input. Takes (in_dim+1)-dimensional input and produces out_dim-dimensional output (which becomes the spatial components of the output).
- c_in (float, required): Input curvature parameter (must be positive, c > 0).
- c_out (float, required): Output curvature parameter (must be positive, c > 0).
- eps (float, default 1e-7): Small value for numerical stability.

Returns:

- y (Array of shape (..., out_dim+1)): Output point(s) on the hyperboloid manifold with curvature c_out.

Notes
  • Unlike HRC, f_t operates on the full point including the time component
  • f_t's output dimension determines the output spatial dimension
  • This is typically used for learnable linear transformations
  • The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation
See Also

hrc : Hyperbolic Regularization Component for spatial-only operations. HTCLinear : Module wrapper for htc with learnable linear transformation.

References

Hypformer paper (citation to be added)

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers.hyperboloid_core import htc
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create a point on the hyperboloid
>>> manifold = Hyperboloid()
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> x = manifold.proj(x, c=1.0)
>>>
>>> # Define a linear transformation
>>> W = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
>>> def linear(z):
...     return z @ W.T
>>>
>>> # Apply HTC
>>> y = htc(x, linear, c_in=1.0, c_out=2.0)
>>> y.shape
(4,)  # (3 spatial + 1 time)
Source code in hyperbolix/nn_layers/hyperboloid_core.py
def htc(
    x: Float[Array, "... in_dim_plus_1"],
    f_t: Callable[[Float[Array, "..."]], Float[Array, "..."]],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... out_dim_plus_1"]:
    """Hyperbolic Transformation Component.

    Applies a Euclidean linear transformation f_t to the full hyperboloid point
    (including time component), then maps the result to the hyperboloid with
    curvature c_out.

    Mathematical formula:
        space = sqrt(c_in/c_out) * f_t(x)
        time  = sqrt(||space||^2 + 1/c_out)
        output = [time, space]

    where f_t takes the full (dim+1)-dimensional input and produces the output
    spatial components.

    Parameters
    ----------
    x : Array of shape (..., in_dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
        All components (time and spatial) are passed to f_t.
    f_t : Callable
        Euclidean linear transformation applied to the full input. Takes
        (in_dim+1)-dimensional input and produces out_dim-dimensional output
        (which becomes the spatial components of the output).
    c_in : float
        Input curvature parameter (must be positive, c > 0).
    c_out : float
        Output curvature parameter (must be positive, c > 0).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., out_dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.

    Notes
    -----
    - Unlike HRC, f_t operates on the full point including the time component
    - f_t's output dimension determines the output spatial dimension
    - This is typically used for learnable linear transformations
    - The spatial scaling factor sqrt(c_in/c_out) ensures proper curvature transformation

    See Also
    --------
    hrc : Hyperbolic Regularization Component for spatial-only operations.
    HTCLinear : Module wrapper for htc with learnable linear transformation.

    References
    ----------
    Hypformer paper (citation to be added)

    Examples
    --------
    >>> import jax
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers.hyperboloid_core import htc
    >>> from hyperbolix.manifolds import Hyperboloid
    >>>
    >>> # Create a point on the hyperboloid
    >>> manifold = Hyperboloid()
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> x = manifold.proj(x, c=1.0)
    >>>
    >>> # Define a linear transformation
    >>> W = jax.random.normal(jax.random.PRNGKey(0), (3, 4))
    >>> def linear(z):
    ...     return z @ W.T
    >>>
    >>> # Apply HTC
    >>> y = htc(x, linear, c_in=1.0, c_out=2.0)
    >>> y.shape
    (4,)  # (3 spatial + 1 time)
    """
    # f_t: (..., A_in) → (..., D_out) where A_in = in_dim+1
    out_D = f_t(x)

    # Scale for curvature transformation
    scale = jnp.sqrt(c_in / c_out)
    scaled_D = scale * out_D  # (..., D_out)

    # Reconstruct time via hyperboloid constraint: x₀ = sqrt(||space||² + 1/c_out)
    norm_sq = jnp.sum(scaled_D**2, axis=-1)  # (...)
    x0 = jnp.sqrt(jnp.maximum(norm_sq + 1.0 / c_out, eps))  # (...)

    return jnp.concatenate([x0[..., None], scaled_D], axis=-1)  # (..., D_out+1)
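The same tail applies here, but f_t consumes the full point, so its output width sets the spatial dimension. A NumPy sketch of the formula above (illustrative, not the library implementation):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.normal(size=(3, 4))  # maps the full 4-dim point to 3 spatial outputs

def htc_np(x, f_t, c_in, c_out, eps=1e-7):
    """NumPy transcription of the htc formula: f_t sees time + space."""
    space = np.sqrt(c_in / c_out) * f_t(x)
    time = np.sqrt(np.maximum(np.sum(space**2, axis=-1) + 1.0 / c_out, eps))
    return np.concatenate([time[..., None], space], axis=-1)

x = np.array([1.05, 0.1, -0.2, 0.15])
y = htc_np(x, lambda z: z @ W.T, c_in=1.0, c_out=2.0)

print(y.shape)                       # (4,): 3 spatial + 1 time
print(-y[0]**2 + np.sum(y[1:]**2))   # ≈ -0.5 = -1/c_out
```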

HTC/HRC Modules

hyperbolix.nn_layers.HTCLinear

HTCLinear(
    in_features: int,
    out_features: int,
    *,
    rngs: Rngs,
    use_bias: bool = True,
    init_bound: float = 0.02,
    eps: float = 1e-07,
)

Bases: Module

Hyperbolic Transformation Component with learnable linear transformation.

This module wraps a Euclidean linear layer with the HTC operation, enabling learnable transformations between hyperboloid manifolds with different curvatures.

Parameters:

- in_features (int, required): Input feature dimension (full hyperboloid dimension, including time component).
- out_features (int, required): Output spatial dimension (time component is reconstructed automatically).
- rngs (Rngs, required): Random number generators for parameter initialization.
- use_bias (bool, default True): Whether to include a bias term.
- init_bound (float, default 0.02): Bound for uniform weight initialization. Weights are initialized from Uniform(-init_bound, init_bound). Small values keep initial outputs close to the hyperboloid origin for stable training.
- eps (float, default 1e-7): Small value for numerical stability.

Attributes:

- kernel (Param): Weight matrix of shape (in_features, out_features).
- bias (Param or None): Bias vector of shape (out_features,) if use_bias=True, else None.
- eps (float): Numerical stability parameter.

Notes

Weight Initialization: This layer uses small uniform initialization U(-0.02, 0.02) by default, matching the initialization used by FHNN/FHCNN layers. Standard deep learning initializations (Xavier, Lecun) produce weights that are too large for hyperbolic operations, causing gradient explosion and training instability.

See Also

hyperbolix.nn_layers.hyperboloid_core.htc : Core HTC function for functional transformations. HypLinearHyperboloidFHCNN : Alternative hyperbolic linear layer with sigmoid scaling.

References

Hypformer paper (citation to be added)

Examples:

>>> import jax
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from hyperbolix.nn_layers import HTCLinear
>>> from hyperbolix.manifolds import Hyperboloid
>>>
>>> # Create layer
>>> layer = HTCLinear(in_features=5, out_features=8, rngs=nnx.Rngs(0))
>>>
>>> # Forward pass
>>> manifold = Hyperboloid()
>>> x = jnp.ones((32, 5))  # batch of 32 points
>>> x = jax.vmap(manifold.proj, in_axes=(0, None))(x, 1.0)
>>> y = layer(x, c_in=1.0, c_out=2.0)
>>> y.shape
(32, 9)  # 8 spatial + 1 time
Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    rngs: nnx.Rngs,
    use_bias: bool = True,
    init_bound: float = 0.02,
    eps: float = 1e-7,
):
    # Small uniform initialization for hyperbolic stability
    # Standard initializations (Lecun, Xavier) are too large and cause gradient explosion
    self.kernel = nnx.Param(
        jax.random.uniform(rngs.params(), (in_features, out_features), minval=-init_bound, maxval=init_bound)
    )
    if use_bias:
        self.bias = nnx.Param(jnp.zeros((out_features,)))
    else:
        self.bias = None
    self.eps = eps
__call__
__call__(
    x: Float[Array, "batch in_features"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch out_features_plus_1"]

Apply HTC linear transformation.

Parameters:

- x (Array of shape (batch, in_features), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.

Returns:

- y (Array of shape (batch, out_features+1)): Output points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_linear.py
def __call__(
    self,
    x: Float[Array, "batch in_features"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch out_features_plus_1"]:
    """Apply HTC linear transformation.

    Parameters
    ----------
    x : Array of shape (batch, in_features)
        Input points on hyperboloid with curvature c_in.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).

    Returns
    -------
    y : Array of shape (batch, out_features+1)
        Output points on hyperboloid with curvature c_out.
    """

    def linear_fn(z):
        out = z @ self.kernel[...]
        if self.bias is not None:
            out = out + self.bias[...]
        return out

    return htc(x, linear_fn, c_in, c_out, self.eps)
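The effect of the small U(-0.02, 0.02) initialization noted above can be isolated in NumPy (a sketch with stand-in data, not the Flax layer): with tiny weights the spatial output is near zero, so the reconstructed time component stays close to the hyperboloid origin.

```python
import numpy as np

rng = np.random.default_rng(0)
kernel = rng.uniform(-0.02, 0.02, size=(5, 8))   # HTCLinear-style small init
x = rng.normal(size=(32, 5))                     # stand-in input batch

space = x @ kernel                               # small-magnitude spatial output
time = np.sqrt(np.sum(space**2, axis=-1) + 1.0)  # c_out = 1: origin has time 1

# All outputs start near the origin [1, 0, ..., 0], keeping early training stable
print(np.max(np.abs(time - 1.0)))  # small, well under 0.05
```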

hyperbolix.nn_layers.HRCBatchNorm

HRCBatchNorm(
    num_features: int,
    *,
    rngs: Rngs,
    momentum: float = 0.99,
    epsilon: float = 1e-05,
    eps: float = 1e-07,
)

Bases: Module

Hyperbolic Regularization Component with batch normalization.

Applies batch normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.

Parameters:

- num_features (int, required): Number of spatial features to normalize.
- rngs (Rngs, required): Random number generators for parameter initialization.
- momentum (float, default 0.99): Momentum for running statistics.
- epsilon (float, default 1e-5): Small value for numerical stability in batch norm.
- eps (float, default 1e-7): Small value for numerical stability in HRC.

Attributes:

- bn (BatchNorm): Flax batch normalization.
- eps (float): Numerical stability parameter for HRC.

Notes

Training vs Evaluation Mode: During training (use_running_average=False), batch norm computes statistics from the current batch and updates running averages. During evaluation (use_running_average=True), it uses the accumulated running statistics.

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCBatchNorm
>>>
>>> # Create batch norm layer
>>> bn = HRCBatchNorm(num_features=64, rngs=nnx.Rngs(0))
>>>
>>> # Training mode
>>> y_train = bn(x, c_in=1.0, c_out=2.0, use_running_average=False)
>>>
>>> # Evaluation mode
>>> y_eval = bn(x, c_in=1.0, c_out=2.0, use_running_average=True)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __init__(
    self,
    num_features: int,
    *,
    rngs: nnx.Rngs,
    momentum: float = 0.99,
    epsilon: float = 1e-5,
    eps: float = 1e-7,
):
    self.bn = nnx.BatchNorm(
        num_features,
        momentum=momentum,
        epsilon=epsilon,
        rngs=rngs,
    )
    self.eps = eps
__call__
__call__(
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    use_running_average: bool | None = None,
) -> Float[Array, "batch dim_plus_1"]

Apply HRC batch normalization.

Parameters:

- x (Array of shape (batch, dim+1), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.
- use_running_average (bool or None, default None): If True, use running statistics (eval mode). If False, use batch statistics (train mode). If None, use the default set during initialization.

Returns:

- y (Array of shape (batch, dim+1)): Output points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __call__(
    self,
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    use_running_average: bool | None = None,
) -> Float[Array, "batch dim_plus_1"]:
    """Apply HRC batch normalization.

    Parameters
    ----------
    x : Array of shape (batch, dim+1)
        Input points on hyperboloid with curvature c_in.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).
    use_running_average : bool or None, optional
        If True, use running statistics (eval mode).
        If False, use batch statistics (train mode).
        If None, use the default set during initialization.

    Returns
    -------
    y : Array of shape (batch, dim+1)
        Output points on hyperboloid with curvature c_out.
    """

    def bn_fn(z):
        return self.bn(z, use_running_average=use_running_average)

    return hrc(x, bn_fn, c_in, c_out, self.eps)

hyperbolix.nn_layers.HRCLayerNorm

HRCLayerNorm(
    num_features: int,
    *,
    rngs: Rngs,
    epsilon: float = 1e-05,
    eps: float = 1e-07,
)

Bases: Module

Hyperbolic Regularization Component with layer normalization.

Applies layer normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.

Parameters:

- num_features (int, required): Number of spatial features to normalize.
- rngs (Rngs, required): Random number generators for parameter initialization.
- epsilon (float, default 1e-5): Small value for numerical stability in layer norm.
- eps (float, default 1e-7): Small value for numerical stability in HRC.

Attributes:

- ln (LayerNorm): Flax layer normalization.
- eps (float): Numerical stability parameter for HRC.

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCLayerNorm
>>>
>>> ln = HRCLayerNorm(num_features=64, rngs=nnx.Rngs(0))
>>> y = ln(x, c_in=1.0, c_out=2.0)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __init__(
    self,
    num_features: int,
    *,
    rngs: nnx.Rngs,
    epsilon: float = 1e-5,
    eps: float = 1e-7,
):
    self.ln = nnx.LayerNorm(num_features, epsilon=epsilon, rngs=rngs)
    self.eps = eps
__call__
__call__(
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]

Apply HRC layer normalization.

Parameters:

- x (Array of shape (batch, dim+1), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.

Returns:

- y (Array of shape (batch, dim+1)): Output points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __call__(
    self,
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]:
    """Apply HRC layer normalization.

    Parameters
    ----------
    x : Array of shape (batch, dim+1)
        Input points on hyperboloid with curvature c_in.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).

    Returns
    -------
    y : Array of shape (batch, dim+1)
        Output points on hyperboloid with curvature c_out.
    """
    return hrc(x, self.ln, c_in, c_out, self.eps)

hyperbolix.nn_layers.HRCRMSNorm

HRCRMSNorm(
    num_features: int,
    *,
    rngs: Rngs,
    epsilon: float = 1e-06,
    eps: float = 1e-07,
)

Bases: Module

Hyperbolic Regularization Component with RMS normalization.

Applies RMS normalization to spatial components of hyperboloid points, then reconstructs the time component for the output curvature. RMSNorm is a simpler and faster variant of LayerNorm that only normalizes by the root mean square without centering.

Parameters:

- num_features (int, required): Number of spatial features to normalize.
- rngs (Rngs, required): Random number generators for parameter initialization.
- epsilon (float, default 1e-6): Small value for numerical stability in RMS norm.
- eps (float, default 1e-7): Small value for numerical stability in HRC.

Attributes:

- rms (RMSNorm): Flax RMS normalization.
- eps (float): Numerical stability parameter for HRC.

Notes

RMSNorm Formula: y = x / RMS(x) * scale, where RMS(x) = sqrt(mean(x^2) + epsilon)

RMSNorm is more efficient than LayerNorm as it skips mean subtraction and is commonly used in modern transformers (e.g., LLaMA, GPT-NeoX).

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCRMSNorm
>>>
>>> rms = HRCRMSNorm(num_features=64, rngs=nnx.Rngs(0))
>>> y = rms(x, c_in=1.0, c_out=2.0)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __init__(
    self,
    num_features: int,
    *,
    rngs: nnx.Rngs,
    epsilon: float = 1e-6,
    eps: float = 1e-7,
):
    self.rms = nnx.RMSNorm(num_features, epsilon=epsilon, rngs=rngs)
    self.eps = eps
__call__
__call__(
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]

Apply HRC RMS normalization.

Parameters:

- x (Array of shape (batch, dim+1), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.

Returns:

- y (Array of shape (batch, dim+1)): Output points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __call__(
    self,
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
) -> Float[Array, "batch dim_plus_1"]:
    """Apply HRC RMS normalization.

    Parameters
    ----------
    x : Array of shape (batch, dim+1)
        Input points on hyperboloid with curvature c_in.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).

    Returns
    -------
    y : Array of shape (batch, dim+1)
        Output points on hyperboloid with curvature c_out.
    """
    return hrc(x, self.rms, c_in, c_out, self.eps)
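The RMSNorm formula from the Notes above can be transcribed directly (a plain NumPy sketch, not the Flax nnx.RMSNorm implementation):

```python
import numpy as np

def rms_norm(z, scale, epsilon=1e-6):
    """y = z / RMS(z) * scale, with RMS(z) = sqrt(mean(z^2) + epsilon)."""
    rms = np.sqrt(np.mean(z**2, axis=-1, keepdims=True) + epsilon)
    return z / rms * scale

z = np.array([3.0, -4.0])
y = rms_norm(z, scale=np.ones(2))   # RMS([3, -4]) = sqrt(25/2) ≈ 3.536
print(y)                            # ≈ [0.8485, -1.1314]
```

Note there is no mean subtraction, which is what makes RMSNorm cheaper than LayerNorm.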

hyperbolix.nn_layers.HRCDropout

HRCDropout(rate: float, *, rngs: Rngs, eps: float = 1e-07)

Bases: Module

Hyperbolic Regularization Component with dropout.

Applies dropout to spatial components of hyperboloid points, then reconstructs the time component for the output curvature.

Parameters:

- rate (float, required): Dropout probability (fraction of units to drop).
- rngs (Rngs, required): Random number generators for dropout.
- eps (float, default 1e-7): Small value for numerical stability.

Attributes:

- dropout (Dropout): Flax dropout layer.
- eps (float): Numerical stability parameter.

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import HRCDropout
>>>
>>> dropout = HRCDropout(rate=0.1, rngs=nnx.Rngs(dropout=42))
>>> y = dropout(x, c_in=1.0, c_out=1.0, deterministic=False)
Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __init__(self, rate: float, *, rngs: nnx.Rngs, eps: float = 1e-7):
    self.dropout = nnx.Dropout(rate, rngs=rngs)
    self.eps = eps
__call__
__call__(
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    deterministic: bool = False,
) -> Float[Array, "batch dim_plus_1"]

Apply HRC dropout.

Parameters:

- x (Array of shape (batch, dim+1), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.
- deterministic (bool, default False): If True, no dropout is applied (for evaluation mode).

Returns:

- y (Array of shape (batch, dim+1)): Output points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __call__(
    self,
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    deterministic: bool = False,
) -> Float[Array, "batch dim_plus_1"]:
    """Apply HRC dropout.

    Parameters
    ----------
    x : Array of shape (batch, dim+1)
        Input points on hyperboloid with curvature c_in.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).
    deterministic : bool, optional
        If True, no dropout is applied (for evaluation mode).

    Returns
    -------
    y : Array of shape (batch, dim+1)
        Output points on hyperboloid with curvature c_out.
    """

    def drop_fn(z):
        return self.dropout(z, deterministic=deterministic)

    return hrc(x, drop_fn, c_in, c_out, self.eps)
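Because HRC rebuilds the time component after f_r, dropout cannot knock points off the manifold: even with spatial coordinates zeroed, the output satisfies the constraint. A NumPy sketch of this behavior (illustrative, using standard inverted dropout rather than the Flax layer):

```python
import numpy as np

rng = np.random.default_rng(42)

def hrc_dropout_np(x, rate, c, train=True):
    """Inverted dropout on spatial components, then time reconstruction."""
    space = x[..., 1:]
    if train:
        keep = rng.random(space.shape) >= rate
        space = np.where(keep, space / (1.0 - rate), 0.0)
    time = np.sqrt(np.sum(space**2, axis=-1, keepdims=True) + 1.0 / c)
    return np.concatenate([time, space], axis=-1)

x = np.array([[1.05, 0.1, -0.2, 0.15]])
y = hrc_dropout_np(x, rate=0.5, c=1.0)

# Whatever the mask, the result stays on the c = 1 hyperboloid
print(-y[..., 0]**2 + np.sum(y[..., 1:]**2, axis=-1))  # ≈ [-1.]
```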

hyperbolix.nn_layers.FGGMeanOnlyBatchNorm

FGGMeanOnlyBatchNorm(
    num_features: int,
    *,
    momentum: float = 0.99,
    eps: float = 1e-07,
)

Bases: Module

Mean-only batch normalization for FGG-LNN layers.

Subtracts the batch mean from spatial components without dividing by variance, then adds a learnable bias. This avoids the instability that standard BatchNorm causes in hyperbolic space (exponentially large/small variance denominators) while still centering activations for stable training.

Designed to pair with weight normalization (FGGLinear(use_weight_norm=True)), which handles magnitude control. Together they replace standard BatchNorm in fully hyperbolic FGG-LNN networks.

Formula (applied to spatial components via HRC): y = x - E[x] + bias

Parameters:

- num_features (int, required): Number of spatial features (D, not D+1).
- momentum (float, default 0.99): Exponential moving average momentum for running mean. Update: running_mean = momentum * running_mean + (1 - momentum) * batch_mean.
- eps (float, default 1e-7): Numerical stability floor for HRC time reconstruction.

References

Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026), §4.4. Salimans & Kingma "Weight Normalization" (2016).

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __init__(
    self,
    num_features: int,
    *,
    momentum: float = 0.99,
    eps: float = 1e-7,
):
    self.num_features = num_features
    self.momentum = momentum
    self.eps = eps

    # Learnable bias (shift after centering)
    self.bias = nnx.Param(jnp.zeros((num_features,)))
    # Running mean for eval mode
    self.running_mean = nnx.BatchStat(jnp.zeros((num_features,)))
__call__
__call__(
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    use_running_average: bool = False,
) -> Float[Array, "batch dim_plus_1"]

Apply mean-only batch normalization via HRC.

Parameters:

- x (Array of shape (B, D+1) or (B, ..., D+1), required): Input points on hyperboloid with curvature c_in.
- c_in (float, default 1.0): Input curvature.
- c_out (float, default 1.0): Output curvature.
- use_running_average (bool, default False): If True, use running mean (eval mode). If False, compute from batch and update running mean (train mode).

Returns:

- Array, same shape as input: Points on hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_regularization.py
def __call__(
    self,
    x: Float[Array, "batch dim_plus_1"],
    c_in: float = 1.0,
    c_out: float = 1.0,
    use_running_average: bool = False,
) -> Float[Array, "batch dim_plus_1"]:
    """Apply mean-only batch normalization via HRC.

    Parameters
    ----------
    x : Array, shape (B, D+1) or (B, ..., D+1)
        Input points on hyperboloid with curvature ``c_in``.
    c_in : float, optional
        Input curvature (default: 1.0).
    c_out : float, optional
        Output curvature (default: 1.0).
    use_running_average : bool, optional
        If True, use running mean (eval mode). If False, compute from
        batch and update running mean (train mode). Default: False.

    Returns
    -------
    Array, same shape as input
        Points on hyperboloid with curvature ``c_out``.
    """

    def mean_only_bn(z):
        # z: spatial components, shape (..., D)
        if use_running_average:
            mean = self.running_mean[...]
        else:
            # Compute mean over all dims except the last (feature) dim
            # Flatten leading dims for mean computation
            z_flat = z.reshape(-1, z.shape[-1])  # (N, D)
            mean = jnp.mean(z_flat, axis=0)  # (D,)

            # Update running mean (EMA)
            self.running_mean[...] = self.momentum * self.running_mean[...] + (1.0 - self.momentum) * mean

        return z - mean + self.bias[...]

    return hrc(x, mean_only_bn, c_in, c_out, self.eps)
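The centering behavior of mean_only_bn above can be isolated in plain NumPy (a sketch with toy values, not the module itself): the batch mean is removed without any variance division.

```python
import numpy as np

def mean_only_bn_np(z, bias, running_mean, momentum=0.99, train=True):
    """Mean-only BN: subtract batch mean, add bias; track EMA of the mean."""
    if train:
        mean = z.mean(axis=0)
        running_mean = momentum * running_mean + (1.0 - momentum) * mean
    else:
        mean = running_mean
    return z - mean + bias, running_mean

z = np.array([[1.0, 2.0],
              [3.0, 6.0]])                       # toy spatial components
out, rm = mean_only_bn_np(z, bias=np.zeros(2), running_mean=np.zeros(2))

print(out.mean(axis=0))   # [0. 0.]  (centered, bias is zero)
print(rm)                 # [0.02 0.04]  (EMA step toward batch mean [2, 4])
```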

Hypformer Example

from hyperbolix.nn_layers import HTCLinear, HRCBatchNorm, HRCRMSNorm, hrc_relu
from hyperbolix.manifolds import Hyperboloid
from flax import nnx
import jax
import jax.numpy as jnp

hyperboloid = Hyperboloid()

class HypformerBlock(nnx.Module):
    """Example using HTC/HRC components with curvature change."""

    def __init__(self, in_dim, out_dim, rngs):
        self.linear = HTCLinear(
            in_features=in_dim,
            out_features=out_dim,
            rngs=rngs
        )
        # Can use BatchNorm or RMSNorm for normalization
        self.bn = HRCBatchNorm(num_features=out_dim, rngs=rngs)
        # self.rms = HRCRMSNorm(num_features=out_dim, rngs=rngs)  # Alternative: faster, simpler

    def __call__(self, x, c_in=1.0, c_out=2.0, use_running_average=False):
        # Linear transformation with curvature change
        x = self.linear(x, c_in=c_in, c_out=c_out)

        # Batch normalization (curvature-preserving)
        x = self.bn(x, c_in=c_out, c_out=c_out,
                    use_running_average=use_running_average)

        # Activation (curvature-preserving)
        x = hrc_relu(x, c_in=c_out, c_out=c_out)

        return x

# Create and use block
block = HypformerBlock(in_dim=33, out_dim=64, rngs=nnx.Rngs(0))

# Input on hyperboloid with curvature 1.0
x = jax.random.normal(jax.random.PRNGKey(1), (32, 33))
x_proj = jax.vmap(hyperboloid.proj, in_axes=(0, None))(x, 1.0)

# Transform to curvature 2.0
output = block(x_proj, c_in=1.0, c_out=2.0)
print(output.shape)  # (32, 65) - 64 spatial + 1 time

HTC vs HRC

HRC (Hyperbolic Regularization Component):

  • Applies Euclidean function f_r to space components only
  • Use for: activations, normalization, dropout, convolutions
  • Formula: space = f_r(x_s), time = sqrt(||space||^2 + 1/c_out)

HTC (Hyperbolic Transformation Component):

  • Applies Euclidean function f_t to full point (time + space)
  • Use for: learnable linear transformations
  • Formula: space = f_t(x), time = sqrt(||space||^2 + 1/c_out)

Both support curvature changes (c_in → c_out) for flexible network design.

Attention Layers

Three hyperbolic attention variants from the Hypformer paper (Yang et al. 2025, Section 4.3). All operate on hyperboloid points and support independent curvatures for input (c_in), attention computation (c_attn), and output (c_out). All variants support causal (autoregressive) masking via the causal=True flag, making them suitable for language models and sequence generation tasks.

Core Utilities

hyperbolix.nn_layers.spatial_to_hyperboloid

spatial_to_hyperboloid(
    spatial: Float[Array, "... D"],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... D_plus_1"]

Scale spatial components and reconstruct time to produce a hyperboloid point.

Extracts the common tail of hrc/htc: curvature scaling + time reconstruction via the hyperboloid constraint.

Parameters:

- spatial (Array of shape (..., D), required): Spatial components (no time coordinate).
- c_in (float, required): Source curvature (positive).
- c_out (float, required): Target curvature (positive).
- eps (float, default 1e-7): Numerical stability floor.

Returns:

- Array of shape (..., D+1): Points on the hyperboloid with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_core.py
def spatial_to_hyperboloid(
    spatial: Float[Array, "... D"],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... D_plus_1"]:
    """Scale spatial components and reconstruct time to produce a hyperboloid point.

    Extracts the common tail of ``hrc``/``htc``: curvature scaling + time
    reconstruction via the hyperboloid constraint.

    Parameters
    ----------
    spatial : Array, shape (..., D)
        Spatial components (no time coordinate).
    c_in : float
        Source curvature (positive).
    c_out : float
        Target curvature (positive).
    eps : float, optional
        Numerical stability floor (default: 1e-7).

    Returns
    -------
    Array, shape (..., D+1)
        Points on the hyperboloid with curvature ``c_out``.
    """
    scale = jnp.sqrt(c_in / c_out)
    scaled_D = scale * spatial  # (..., D)

    norm_sq = jnp.sum(scaled_D**2, axis=-1)  # (...)
    x0 = jnp.sqrt(jnp.maximum(norm_sq + 1.0 / c_out, eps))  # (...)

    return jnp.concatenate([x0[..., None], scaled_D], axis=-1)  # (..., D+1)
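This factoring means hrc and htc differ only in what they feed into this tail: the spatial slice versus the full point. A NumPy sketch of the relationship (illustrative, independent of the library):

```python
import numpy as np

def spatial_to_hyperboloid_np(spatial, c_in, c_out, eps=1e-7):
    """Common tail: curvature scaling + time reconstruction."""
    scaled = np.sqrt(c_in / c_out) * spatial
    t = np.sqrt(np.maximum(np.sum(scaled**2, axis=-1) + 1.0 / c_out, eps))
    return np.concatenate([t[..., None], scaled], axis=-1)

f = np.tanh                                  # any Euclidean map
x = np.array([1.05, 0.1, -0.2, 0.15])        # 1 time + 3 spatial components

hrc_out = spatial_to_hyperboloid_np(f(x[1:]), c_in=1.0, c_out=2.0)  # hrc: spatial only
htc_out = spatial_to_hyperboloid_np(f(x), c_in=1.0, c_out=2.0)      # htc: full point

print(hrc_out.shape, htc_out.shape)   # (4,) (5,)
```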

hyperbolix.nn_layers.lorentz_midpoint

lorentz_midpoint(
    points: Float[Array, "... M A"],
    weights: Float[Array, "... N M"],
    c: float,
    eps: float = 1e-07,
) -> Float[Array, "... N A"]

Weighted Lorentzian midpoint over M points.

Generalises lorentz_residual (which handles two points) to an arbitrary weighted combination used by full attention aggregation and multi-head averaging.

Formula (HELM, Chen et al. 2024):

    h = weights @ points          (weighted sum)
    mu = h / (sqrt(c) * ||h||_L)

where ||h||_L = sqrt(-<h,h>_L) and <h,h>_L = -h_0^2 + ||h_s||^2.

Parameters:

Name Type Description Default
points (Array, shape(..., M, A))

Points on the hyperboloid with curvature c. A = d + 1.

required
weights (Array, shape(..., N, M))

Combination weights (e.g. attention weights, uniform 1/M).

required
c float

Curvature parameter (positive).

required
eps float

Numerical stability floor (default: 1e-7).

1e-07

Returns:

Type Description
(Array, shape(..., N, A))

Midpoints on the hyperboloid with curvature c.

Source code in hyperbolix/nn_layers/hyperboloid_core.py
def lorentz_midpoint(
    points: Float[Array, "... M A"],
    weights: Float[Array, "... N M"],
    c: float,
    eps: float = 1e-7,
) -> Float[Array, "... N A"]:
    """Weighted Lorentzian midpoint over M points.

    Generalises :func:`lorentz_residual` (which handles two points) to an
    arbitrary weighted combination used by full attention aggregation and
    multi-head averaging.

    Formula (HELM, Chen et al. 2024):
        ``h = weights @ points``  (weighted sum)
        ``mu = h / (sqrt(c) * ||h||_L)``

    where ``||h||_L = sqrt(-<h,h>_L)`` and ``<h,h>_L = -h_0^2 + ||h_s||^2``.

    Parameters
    ----------
    points : Array, shape (..., M, A)
        Points on the hyperboloid with curvature ``c``.  ``A = d + 1``.
    weights : Array, shape (..., N, M)
        Combination weights (e.g. attention weights, uniform ``1/M``).
    c : float
        Curvature parameter (positive).
    eps : float, optional
        Numerical stability floor (default: 1e-7).

    Returns
    -------
    Array, shape (..., N, A)
        Midpoints on the hyperboloid with curvature ``c``.
    """
    # h = sum_m w_{n,m} * points_m  →  (..., N, A)
    h_NA = jnp.einsum("...nm,...ma->...na", weights, points)

    # Minkowski squared norm: <h,h>_L = -h_0^2 + ||h_s||^2  (should be < 0)
    mink_1 = -(h_NA[..., 0:1] ** 2) + jnp.sum(h_NA[..., 1:] ** 2, axis=-1, keepdims=True)  # (..., N, 1)
    denom_1 = jnp.sqrt(jnp.maximum(c * jnp.abs(mink_1), eps))  # (..., N, 1)

    return h_NA / denom_1  # (..., N, A)
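A key property of the midpoint is that any positive combination of hyperboloid points, once rescaled by its Lorentz norm, lands back on the hyperboloid. The following standalone sketch (reimplementing the formula above with plain `jax.numpy`) verifies this for uniform weights:

```python
import jax
import jax.numpy as jnp

def lorentz_midpoint(points, weights, c, eps=1e-7):
    # Mirror of the source above: weighted sum, then Lorentz-norm rescale.
    h = jnp.einsum("...nm,...ma->...na", weights, points)
    mink = -(h[..., 0:1] ** 2) + jnp.sum(h[..., 1:] ** 2, axis=-1, keepdims=True)
    return h / jnp.sqrt(jnp.maximum(c * jnp.abs(mink), eps))

c = 1.0
# Build M=5 points on the hyperboloid (time coordinate from the constraint)
spatial = jax.random.normal(jax.random.PRNGKey(0), (5, 3)) * 0.2
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0 / c)
points = jnp.concatenate([time, spatial], axis=-1)  # (5, 4)

# Uniform weights for a single (N=1) query over the M=5 points
weights = jnp.full((1, 5), 1.0 / 5.0)
mu = lorentz_midpoint(points, weights, c)  # (1, 4)

# Midpoint satisfies -mu_0^2 + ||mu_s||^2 = -1/c
mink = -mu[..., 0] ** 2 + jnp.sum(mu[..., 1:] ** 2, axis=-1)
print(jnp.allclose(mink, -1.0 / c, atol=1e-5))  # True
```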

hyperbolix.nn_layers.focus_transform

focus_transform(
    x_D: Float[Array, "... D"],
    temperature: Float[Array, ""],
    power: float,
    eps: float = 1e-07,
) -> Float[Array, "... D"]

Norm-preserving focus function (Eq 19, Hypformer).

Applies temperature-scaled ReLU followed by element-wise power sharpening while preserving the original norm.

Parameters:

Name Type Description Default
x_D (Array, shape(..., D))

Input spatial features.

required
temperature scalar Array

Learnable temperature parameter.

required
power float

Sharpening exponent (p > 1 concentrates mass).

required
eps float

Numerical stability floor (default: 1e-7).

1e-07

Returns:

Type Description
(Array, shape(..., D))

Focus-transformed features with ||output|| ≈ ||relu(x)/|t|||.

Source code in hyperbolix/nn_layers/hyperboloid_attention.py
def focus_transform(
    x_D: Float[Array, "... D"],
    temperature: Float[Array, ""],
    power: float,
    eps: float = 1e-7,
) -> Float[Array, "... D"]:
    """Norm-preserving focus function (Eq 19, Hypformer).

    Applies temperature-scaled ReLU followed by element-wise power sharpening
    while preserving the original norm.

    Parameters
    ----------
    x_D : Array, shape (..., D)
        Input spatial features.
    temperature : scalar Array
        Learnable temperature parameter.
    power : float
        Sharpening exponent (``p > 1`` concentrates mass).
    eps : float, optional
        Numerical stability floor (default: 1e-7).

    Returns
    -------
    Array, shape (..., D)
        Focus-transformed features with ``||output|| ≈ ||relu(x)/|t|||``.
    """
    # Temperature-scaled ReLU
    scaled_relu_D = (jax.nn.relu(x_D) + eps) / (jnp.abs(temperature) + eps)  # (..., D)

    # Element-wise power sharpening
    sharpened_D = scaled_relu_D**power  # (..., D)

    # Norm-preserving rescaling
    norm_scaled = jnp.sqrt(jnp.sum(scaled_relu_D**2, axis=-1, keepdims=True) + eps)  # (..., 1)
    norm_sharpened = jnp.sqrt(jnp.sum(sharpened_D**2, axis=-1, keepdims=True) + eps)  # (..., 1)

    return (norm_scaled / norm_sharpened) * sharpened_D  # (..., D)
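The point of the final rescaling is that sharpening redistributes mass toward the largest entries without changing the overall norm. A standalone sketch (mirroring the source above with plain JAX) that checks the norm-preservation property:

```python
import jax
import jax.numpy as jnp

def focus_transform(x, temperature, power, eps=1e-7):
    # Mirror of the source above.
    scaled = (jax.nn.relu(x) + eps) / (jnp.abs(temperature) + eps)
    sharpened = scaled**power
    n_scaled = jnp.sqrt(jnp.sum(scaled**2, axis=-1, keepdims=True) + eps)
    n_sharp = jnp.sqrt(jnp.sum(sharpened**2, axis=-1, keepdims=True) + eps)
    return (n_scaled / n_sharp) * sharpened

x = jnp.array([0.1, 0.5, -0.3, 0.9])
t = jnp.array(1.0)
y = focus_transform(x, t, power=2.0)

# Norm matches the temperature-scaled ReLU norm, up to eps
norm_in = jnp.linalg.norm(jax.nn.relu(x) / jnp.abs(t))
print(jnp.allclose(jnp.linalg.norm(y), norm_in, atol=1e-3))  # True
```

With `power > 1` the output concentrates on the dominant components (here the `0.9` entry), which is what sharpens the attention distribution in the linear-attention kernel.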

Attention Modules

hyperbolix.nn_layers.HyperbolicLinearAttention

HyperbolicLinearAttention(
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    power: float = 2.0,
    init_bound: float = 0.02,
    eps: float = 1e-07,
    rngs: Rngs,
)

Bases: _HyperbolicAttentionBase

Hyperbolic linear attention with focus function (Eq 14-19).

The paper's main contribution: O(N) attention using the kernel trick in the spatial domain of the hyperboloid. The focus function φ sharpens the query and key features.

Parameters:

Name Type Description Default
in_features int

Ambient input dimension (d_in + 1).

required
out_features int

Spatial output dimension per head.

required
num_heads int

Number of attention heads (default: 1).

1
power float

Focus function sharpening exponent (default: 2.0).

2.0
init_bound float

Uniform init bound for weights (default: 0.02).

0.02
eps float

Numerical stability floor (default: 1e-7).

1e-07
rngs Rngs

Random number generators.

required
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    power: float = 2.0,
    init_bound: float = 0.02,
    eps: float = 1e-7,
    rngs: nnx.Rngs,
):
    super().__init__(
        in_features,
        out_features,
        num_heads=num_heads,
        init_bound=init_bound,
        eps=eps,
        rngs=rngs,
    )
    self.power = power
    self.temperature = nnx.Param(jnp.array(1.0))
    # Spatial residual projection ψ: D → D (shared across heads)
    self.residual_proj = nnx.Linear(out_features, out_features, rngs=rngs)

hyperbolix.nn_layers.HyperbolicSoftmaxAttention

HyperbolicSoftmaxAttention(
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    init_bound: float = 0.02,
    eps: float = 1e-07,
    rngs: Rngs,
)

Bases: _HyperbolicAttentionBase

Hyperbolic softmax attention in the spatial domain.

Standard scaled dot-product attention applied to spatial components of query, key, value, followed by the same HRC pipeline (residual + time calibration) as the linear variant.

Parameters:

Name Type Description Default
in_features int

Ambient input dimension (d_in + 1).

required
out_features int

Spatial output dimension per head.

required
num_heads int

Number of attention heads (default: 1).

1
init_bound float

Uniform init bound for weights (default: 0.02).

0.02
eps float

Numerical stability floor (default: 1e-7).

1e-07
rngs Rngs

Random number generators.

required
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    init_bound: float = 0.02,
    eps: float = 1e-7,
    rngs: nnx.Rngs,
):
    super().__init__(
        in_features,
        out_features,
        num_heads=num_heads,
        init_bound=init_bound,
        eps=eps,
        rngs=rngs,
    )
    # Spatial residual projection ψ: D → D (shared across heads)
    self.residual_proj = nnx.Linear(out_features, out_features, rngs=rngs)

hyperbolix.nn_layers.HyperbolicFullAttention

HyperbolicFullAttention(
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    init_bound: float = 0.02,
    eps: float = 1e-07,
    rngs: Rngs,
)

Bases: _HyperbolicAttentionBase

Full Lorentzian attention with midpoint aggregation.

Uses the Lorentzian inner product for similarity and weighted Lorentzian midpoint for aggregation — operating on full hyperboloid points throughout.

Parameters:

Name Type Description Default
in_features int

Ambient input dimension (d_in + 1).

required
out_features int

Spatial output dimension per head.

required
num_heads int

Number of attention heads (default: 1).

1
init_bound float

Uniform init bound for weights (default: 0.02).

0.02
eps float

Numerical stability floor (default: 1e-7).

1e-07
rngs Rngs

Random number generators.

required
Source code in hyperbolix/nn_layers/hyperboloid_attention.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    num_heads: int = 1,
    init_bound: float = 0.02,
    eps: float = 1e-7,
    rngs: nnx.Rngs,
):
    super().__init__(
        in_features,
        out_features,
        num_heads=num_heads,
        init_bound=init_bound,
        eps=eps,
        rngs=rngs,
    )
    self.scale = nnx.Param(jnp.array(1.0))
    self.attn_bias = nnx.Param(jnp.array(0.0))

Attention Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds import Hyperboloid
from hyperbolix.nn_layers import (
    HyperbolicLinearAttention,
    HyperbolicSoftmaxAttention,
    HyperbolicFullAttention,
)

hyperboloid = Hyperboloid()

# Input: (batch, seq_len, ambient_dim) on the hyperboloid
B, N, A_in, D_out = 4, 8, 9, 8  # 8-dim spatial + 1 time
key = jax.random.PRNGKey(0)
spatial = jax.random.normal(key, (B, N, A_in - 1)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([time, spatial], axis=-1)  # (B, N, A_in)

# O(N) linear attention with focus function — fastest, main Hypformer contribution
linear_attn = HyperbolicLinearAttention(
    in_features=A_in,
    out_features=D_out,
    num_heads=2,
    power=2.0,
    rngs=nnx.Rngs(0),
)
y = linear_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)
print(y.shape)  # (4, 8, 9) — D_out spatial + 1 time

# O(N²) softmax attention in the spatial domain
softmax_attn = HyperbolicSoftmaxAttention(
    in_features=A_in,
    out_features=D_out,
    num_heads=2,
    rngs=nnx.Rngs(0),
)
y = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)

# O(N²) full Lorentzian attention — operates entirely on hyperboloid points
full_attn = HyperbolicFullAttention(
    in_features=A_in,
    out_features=D_out,
    num_heads=2,
    rngs=nnx.Rngs(0),
)
y = full_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0)

# Verify outputs are on the hyperboloid
for b in range(B):
    for n in range(N):
        assert hyperboloid.is_in_manifold(y[b, n], c=1.0, atol=1e-4)

Causal (Autoregressive) Attention

All three variants support causal masking via causal=True. Position n can only attend to positions m ≤ n, which is required for autoregressive tasks like language modeling.

# Bidirectional (default) — each token attends to all tokens
y_bidir = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0, causal=False)

# Causal — position n only attends to positions 0..n
y_causal = softmax_attn(x, c_in=1.0, c_attn=1.0, c_out=1.0, causal=True)

# Causal is JIT-compatible
@nnx.jit
def forward(model, inp):
    return model(inp, c_in=1.0, c_attn=1.0, c_out=1.0, causal=True)

Causal masking implementations

The three variants implement causal masking differently:

  • HyperbolicSoftmaxAttention and HyperbolicFullAttention: Apply a lower-triangular -inf mask to the score matrix before softmax — O(N²) in both causal and non-causal mode.
  • HyperbolicLinearAttention: Uses a cumulative-sum recurrence (jax.lax.scan) following Katharopoulos et al. (2020): S_i = Σ_{j≤i} φ(K_j) V_j^T computed in O(1) per step → O(N) total, making it especially well-suited for long autoregressive sequences.
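The cumulative-sum recurrence can be illustrated in isolation. The sketch below is a plain Euclidean version of the Katharopoulos et al. (2020) recurrence, not the library's hyperbolic implementation (which additionally applies the HRC residual and time calibration); `causal_linear_attention` and its arguments are illustrative names:

```python
import jax
import jax.numpy as jnp

def causal_linear_attention(phi_q, phi_k, v):
    """O(N) causal attention via the cumulative-sum recurrence.

    phi_q, phi_k : (N, D) feature-mapped queries/keys (e.g. focus-transformed)
    v            : (N, Dv) values
    """
    def step(carry, inputs):
        S, z = carry  # S: (D, Dv) running sum of k_j v_j^T; z: (D,) sum of k_j
        q, k, val = inputs
        S = S + jnp.outer(k, val)
        z = z + k
        out = (q @ S) / (q @ z + 1e-7)  # normalised output at this position
        return (S, z), out

    D, Dv = phi_k.shape[-1], v.shape[-1]
    init = (jnp.zeros((D, Dv)), jnp.zeros((D,)))
    _, outputs = jax.lax.scan(step, init, (phi_q, phi_k, v))
    return outputs  # (N, Dv)

key = jax.random.PRNGKey(0)
q = jax.nn.relu(jax.random.normal(key, (6, 4)))
k = jax.nn.relu(jax.random.normal(jax.random.PRNGKey(1), (6, 4)))
v = jax.random.normal(jax.random.PRNGKey(2), (6, 3))
out = causal_linear_attention(q, k, v)
print(out.shape)  # (6, 3)
```

Each scan step does O(1) work per feature dimension, so the full pass is O(N) while producing exactly the lower-triangular-masked attention result.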

Choosing an Attention Variant

Variant Complexity Causal complexity Mechanism Best For
HyperbolicLinearAttention O(N) O(N) Kernel trick + focus function φ Long sequences, autoregressive models
HyperbolicSoftmaxAttention O(N²) O(N²) Standard softmax on spatial components Short sequences, simplicity
HyperbolicFullAttention O(N²) O(N²) Lorentzian inner product + midpoint Maximum geometric fidelity

All variants support independent curvatures: c_in for input, c_attn for Q/K/V projections, c_out for output.

Positional Encoding

Positional encoding layers for hyperbolic Transformers and attention mechanisms. These layers enable position-aware models on the hyperboloid manifold while preserving geometric structure.

Lorentzian Residual Connection

hyperbolix.nn_layers.lorentz_residual

lorentz_residual(
    x: Float[Array, "... dim_plus_1"],
    y: Float[Array, "... dim_plus_1"],
    w_y: float | Float[Array, ""],
    c: float,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

Lorentzian midpoint-based residual connection (LResNet from HELM).

Computes the weighted Lorentzian midpoint of x and y, projecting back to the hyperboloid:

ave = x + w_y * y
result = ave / sqrt(c * |<ave, ave>_L|)

where <a, a>_L = -a_0^2 + ||a_s||^2 is the Minkowski inner product.

Parameters:

Name Type Description Default
x (Array, shape(..., d + 1))

Points on hyperboloid with curvature c.

required
y (Array, shape(..., d + 1))

Points on hyperboloid with curvature c (to be added with weight w_y).

required
w_y float or scalar Array

Weight for the y contribution.

required
c float

Curvature parameter (positive, c > 0).

required
eps float

Numerical stability floor (default: 1e-7).

1e-07

Returns:

Type Description
(Array, shape(..., d + 1))

Points on hyperboloid with curvature c.

References

Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.

Source code in hyperbolix/nn_layers/hyperboloid_core.py
def lorentz_residual(
    x: Float[Array, "... dim_plus_1"],
    y: Float[Array, "... dim_plus_1"],
    w_y: float | Float[Array, ""],
    c: float,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """Lorentzian midpoint-based residual connection (LResNet from HELM).

    Computes the weighted Lorentzian midpoint of x and y, projecting back
    to the hyperboloid:

        ave = x + w_y * y
        result = ave / sqrt(c * |<ave, ave>_L|)

    where <a, a>_L = -a_0^2 + ||a_s||^2 is the Minkowski inner product.

    Parameters
    ----------
    x : Array, shape (..., d+1)
        Points on hyperboloid with curvature c.
    y : Array, shape (..., d+1)
        Points on hyperboloid with curvature c (to be added with weight w_y).
    w_y : float or scalar Array
        Weight for the y contribution.
    c : float
        Curvature parameter (positive, c > 0).
    eps : float, optional
        Numerical stability floor (default: 1e-7).

    Returns
    -------
    Array, shape (..., d+1)
        Points on hyperboloid with curvature c.

    References
    ----------
    Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.
    """
    ave_A = x + w_y * y  # (..., A) where A = d+1
    # Minkowski inner: -ave_0^2 + ||ave_s||^2
    mink_1 = -(ave_A[..., 0:1] ** 2) + jnp.sum(ave_A[..., 1:] ** 2, axis=-1, keepdims=True)  # (..., 1)
    denom_1 = jnp.sqrt(jnp.maximum(c * jnp.abs(mink_1), eps))  # (..., 1)
    return ave_A / denom_1  # (..., A)

HOPE (Hyperbolic Rotary Positional Encoding)

hyperbolix.nn_layers.hope

hope(
    z: Float[Array, "... seq d_plus_1"],
    positions: Float[Array, seq],
    c: float = 1.0,
    base: float = 10000.0,
    eps: float = 1e-07,
) -> Float[Array, "... seq d_plus_1"]

Hyperbolic Rotary Positional Encoding (HOPE).

Applies RoPE-style rotation to the spatial components of hyperboloid points, then reconstructs the time component to satisfy the manifold constraint. Equivalent to hrc(z, R_{i,Theta}, c, c) where R is a block-diagonal rotation matrix.

Since rotation preserves norms, the Minkowski inner product between encoded points depends only on the relative position offset, giving the standard RoPE relative-position property on the hyperboloid.

Parameters:

Name Type Description Default
z (Array, shape(..., seq_len, d + 1))

Points on hyperboloid (d must be even).

required
positions (Array, shape(seq_len))

Integer position indices.

required
c float

Curvature parameter (default: 1.0).

1.0
base float

Frequency base for rotation angles (default: 10000.0).

10000.0
eps float

Numerical stability floor (default: 1e-7).

1e-07

Returns:

Type Description
(Array, shape(..., seq_len, d + 1))

Rotated points on hyperboloid with curvature c.

References

Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.

Source code in hyperbolix/nn_layers/hyperboloid_positional.py
def hope(
    z: Float[Array, "... seq d_plus_1"],
    positions: Float[Array, "seq"],
    c: float = 1.0,
    base: float = 10000.0,
    eps: float = 1e-7,
) -> Float[Array, "... seq d_plus_1"]:
    """Hyperbolic Rotary Positional Encoding (HOPE).

    Applies RoPE-style rotation to the spatial components of hyperboloid
    points, then reconstructs the time component to satisfy the manifold
    constraint. Equivalent to ``hrc(z, R_{i,Theta}, c, c)`` where R is a
    block-diagonal rotation matrix.

    Since rotation preserves norms, the Minkowski inner product between
    encoded points depends only on the *relative* position offset, giving
    the standard RoPE relative-position property on the hyperboloid.

    Parameters
    ----------
    z : Array, shape (..., seq_len, d+1)
        Points on hyperboloid (d must be even).
    positions : Array, shape (seq_len,)
        Integer position indices.
    c : float, optional
        Curvature parameter (default: 1.0).
    base : float, optional
        Frequency base for rotation angles (default: 10000.0).
    eps : float, optional
        Numerical stability floor (default: 1e-7).

    Returns
    -------
    Array, shape (..., seq_len, d+1)
        Rotated points on hyperboloid with curvature c.

    References
    ----------
    Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.
    """
    spatial_SD = z[..., 1:]  # (..., S, D) where S=seq, D=spatial dim
    d = spatial_SD.shape[-1]

    # Frequency schedule: theta_i = 1 / base^(2i/d)
    freqs_F = 1.0 / (base ** (jnp.arange(0, d, 2) / d))  # (F,) where F = d//2
    angles_SF = positions[:, None] * freqs_F[None, :]  # (S, F)
    cos_SF = jnp.cos(angles_SF)
    sin_SF = jnp.sin(angles_SF)

    # Rotate spatial components (interleaved pairs)
    rotated_SD = _apply_rotary_interleaved(spatial_SD, cos_SF, sin_SF)  # (..., S, D)

    # Reconstruct time: t = sqrt(||rotated||^2 + 1/c)
    norm_sq_S1 = jnp.sum(rotated_SD**2, axis=-1, keepdims=True)  # (..., S, 1)
    time_S1 = jnp.sqrt(jnp.maximum(norm_sq_S1 + 1.0 / c, eps))  # (..., S, 1)

    return jnp.concatenate([time_S1, rotated_SD], axis=-1)  # (..., S, A)

hyperbolix.nn_layers.HyperbolicRoPE

HyperbolicRoPE(
    dim: int,
    max_seq_len: int = 2048,
    base: float = 10000.0,
    eps: float = 1e-07,
)

Bases: Module

NNX module wrapper for HOPE (Hyperbolic Rotary Positional Encoding).

This is a stateless module (no learnable parameters) that wraps the functional hope for convenient use in NNX model definitions.

Parameters:

Name Type Description Default
dim int

Spatial dimension d (must be even).

required
max_seq_len int

Maximum sequence length (for documentation; not enforced, default: 2048).

2048
base float

Frequency base for rotation angles (default: 10000.0).

10000.0
eps float

Numerical stability floor (default: 1e-7).

1e-07
References

Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024.

Source code in hyperbolix/nn_layers/hyperboloid_positional.py
def __init__(
    self,
    dim: int,
    max_seq_len: int = 2048,
    base: float = 10000.0,
    eps: float = 1e-7,
):
    self.dim = dim
    self.max_seq_len = max_seq_len
    self.base = base
    self.eps = eps
__call__
__call__(
    z: Float[Array, "... seq d_plus_1"],
    positions: Float[Array, seq],
    c: float = 1.0,
) -> Float[Array, "... seq d_plus_1"]

Apply HOPE positional encoding.

Parameters:

Name Type Description Default
z (Array, shape(..., seq_len, d + 1))

Points on hyperboloid (spatial dim must equal self.dim, must be even).

required
positions (Array, shape(seq_len))

Integer position indices.

required
c float

Curvature parameter (default: 1.0).

1.0

Returns:

Type Description
(Array, shape(..., seq_len, d + 1))

Rotated points on hyperboloid.

Source code in hyperbolix/nn_layers/hyperboloid_positional.py
def __call__(
    self,
    z: Float[Array, "... seq d_plus_1"],
    positions: Float[Array, "seq"],
    c: float = 1.0,
) -> Float[Array, "... seq d_plus_1"]:
    """Apply HOPE positional encoding.

    Parameters
    ----------
    z : Array, shape (..., seq_len, d+1)
        Points on hyperboloid (spatial dim must equal self.dim, must be even).
    positions : Array, shape (seq_len,)
        Integer position indices.
    c : float, optional
        Curvature parameter (default: 1.0).

    Returns
    -------
    Array, shape (..., seq_len, d+1)
        Rotated points on hyperboloid.
    """
    return hope(z, positions, c, self.base, self.eps)

Hypformer Positional Encoding

hyperbolix.nn_layers.HypformerPositionalEncoding

HypformerPositionalEncoding(
    in_features: int,
    out_features: int,
    *,
    rngs: Rngs,
    init_bound: float = 0.02,
    eps: float = 1e-07,
)

Bases: Module

Learnable relative positional encoding from Hypformer.

Computes a position vector via HTCLinear, then combines it with the input using a Lorentzian residual connection:

p = HTCLinear(x)
result = lorentz_residual(x, p, w_y=epsilon, c=c)

where epsilon is a learnable scalar magnitude parameter.

Parameters:

Name Type Description Default
in_features int

Input ambient dimension (d+1, including time component).

required
out_features int

Output spatial dimension (d). The HTCLinear output will have ambient dimension d+1 (= out_features + 1), matching the input.

required
rngs Rngs

Random number generators for parameter initialization.

required
init_bound float

Bound for HTCLinear uniform weight initialization (default: 0.02).

0.02
eps float

Numerical stability floor for lorentz_residual (default: 1e-7).

1e-07

Attributes:

Name Type Description
htc_linear HTCLinear

Linear transformation producing the position encoding vector.

epsilon Param

Learnable scalar weight for the position encoding contribution.

eps float

Numerical stability parameter.

References

Chen et al., "Hyperbolic Embeddings for Learning on Manifolds" (HELM), 2024. Hypformer paper (citation to be added).

Source code in hyperbolix/nn_layers/hyperboloid_positional.py
def __init__(
    self,
    in_features: int,
    out_features: int,
    *,
    rngs: nnx.Rngs,
    init_bound: float = 0.02,
    eps: float = 1e-7,
):
    self.htc_linear = HTCLinear(in_features, out_features, rngs=rngs, init_bound=init_bound)
    self.epsilon = nnx.Param(jnp.array(1.0))
    self.eps = eps
__call__
__call__(
    x: Float[Array, "... dim_plus_1"], c: float = 1.0
) -> Float[Array, "... dim_plus_1"]

Apply learnable positional encoding.

Parameters:

Name Type Description Default
x (Array, shape(..., d + 1))

Points on hyperboloid with curvature c.

required
c float

Curvature parameter (default: 1.0).

1.0

Returns:

Type Description
(Array, shape(..., d + 1))

Positionally-encoded points on hyperboloid with curvature c.

Source code in hyperbolix/nn_layers/hyperboloid_positional.py
def __call__(
    self,
    x: Float[Array, "... dim_plus_1"],
    c: float = 1.0,
) -> Float[Array, "... dim_plus_1"]:
    """Apply learnable positional encoding.

    Parameters
    ----------
    x : Array, shape (..., d+1)
        Points on hyperboloid with curvature c.
    c : float, optional
        Curvature parameter (default: 1.0).

    Returns
    -------
    Array, shape (..., d+1)
        Positionally-encoded points on hyperboloid with curvature c.
    """
    p = self.htc_linear(x, c_in=c, c_out=c)  # (..., d+1)
    return lorentz_residual(x, p, w_y=self.epsilon[...], c=c, eps=self.eps)

HOPE Example

from hyperbolix.nn_layers import hope, HyperbolicRoPE
from hyperbolix.manifolds import Hyperboloid
import jax.numpy as jnp
import jax

hyperboloid = Hyperboloid()

# Create sequence of hyperboloid points (batch, seq_len, d+1)
key = jax.random.PRNGKey(42)
batch, seq_len, d = 4, 16, 8
spatial = jax.random.normal(key, (batch, seq_len, d)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
z = jnp.concatenate([time, spatial], axis=-1)  # (4, 16, 9)

# Position indices
positions = jnp.arange(seq_len)

# Apply HOPE (functional interface)
z_encoded = hope(z, positions, c=1.0)
print(z_encoded.shape)  # (4, 16, 9)

# Or use the NNX module wrapper
from flax import nnx
rope = HyperbolicRoPE(dim=d, max_seq_len=64, base=10000.0)
z_encoded = rope(z, positions, c=1.0)

# Verify manifold constraint
for b in range(batch):
    for s in range(seq_len):
        assert hyperboloid.is_in_manifold(z_encoded[b, s], c=1.0, atol=1e-4)

# Verify relative position property: <HOPE(q,i), HOPE(k,j)>_L depends only on i-j
q = z[0, 0]  # Single point
k = z[0, 1]  # Another point

# Same relative offset (=3) at different absolute positions
q_enc_0 = hope(q[None, None, :], jnp.array([0]), c=1.0)[0, 0]
k_enc_3 = hope(k[None, None, :], jnp.array([3]), c=1.0)[0, 0]

q_enc_10 = hope(q[None, None, :], jnp.array([10]), c=1.0)[0, 0]
k_enc_13 = hope(k[None, None, :], jnp.array([13]), c=1.0)[0, 0]

# Minkowski inner products should be equal
def minkowski_inner(x, y):
    return -x[0]*y[0] + jnp.sum(x[1:]*y[1:])

ip1 = minkowski_inner(q_enc_0, k_enc_3)
ip2 = minkowski_inner(q_enc_10, k_enc_13)
print(jnp.allclose(ip1, ip2, atol=1e-5))  # True

Hypformer Positional Encoding Example

from hyperbolix.nn_layers import HypformerPositionalEncoding
from hyperbolix.manifolds import Hyperboloid
from flax import nnx

hyperboloid = Hyperboloid()
import jax.numpy as jnp
import jax

# Create positional encoding layer
d = 8  # spatial dimension
in_features = d + 1  # ambient dimension (including time)
pe = HypformerPositionalEncoding(
    in_features=in_features,
    out_features=d,  # output spatial dimension
    rngs=nnx.Rngs(0),
    init_bound=0.02  # small initialization for stability
)

# Input: batch of hyperboloid points
key = jax.random.PRNGKey(42)
batch_size = 32
spatial = jax.random.normal(key, (batch_size, d)) * 0.1
time = jnp.sqrt(jnp.sum(spatial**2, axis=-1, keepdims=True) + 1.0)
x = jnp.concatenate([time, spatial], axis=-1)  # (32, 9)

# Apply learnable positional encoding
x_encoded = pe(x, c=1.0)
print(x_encoded.shape)  # (32, 9) - shape preserved

# Verify manifold constraint
is_valid = jax.vmap(hyperboloid.is_in_manifold, in_axes=(0, None))(x_encoded, 1.0)
print(is_valid.all())  # True

# The epsilon parameter is learnable
print(f"Epsilon: {pe.epsilon.value}")  # Initially 1.0

# Use in training loop with gradient updates
import optax
optimizer = nnx.Optimizer(pe, optax.adam(1e-3), wrt=nnx.Param)

def loss_fn(model):
    out = model(x, c=1.0)
    return jnp.sum(out**2)  # Dummy loss

loss, grads = nnx.value_and_grad(loss_fn)(pe)
optimizer.update(pe, grads)

print(f"Epsilon after update: {pe.epsilon.value}")  # Changed

Lorentzian Residual Example

from hyperbolix.nn_layers import lorentz_residual
from hyperbolix.manifolds import Hyperboloid
import jax.numpy as jnp
import jax

hyperboloid = Hyperboloid()

# Create two hyperboloid points
key1, key2 = jax.random.split(jax.random.PRNGKey(42))
d = 6
c = 1.0

spatial_x = jax.random.normal(key1, (d,)) * 0.1
time_x = jnp.sqrt(jnp.sum(spatial_x**2) + 1/c)
x = jnp.concatenate([time_x[None], spatial_x])

spatial_y = jax.random.normal(key2, (d,)) * 0.1
time_y = jnp.sqrt(jnp.sum(spatial_y**2) + 1/c)
y = jnp.concatenate([time_y[None], spatial_y])

# Combine with Lorentzian residual (weighted midpoint)
result = lorentz_residual(x, y, w_y=0.5, c=c)

# Verify output is on hyperboloid
assert hyperboloid.is_in_manifold(result, c, atol=1e-5)

# Works with batches too
x_batch = jax.random.normal(jax.random.PRNGKey(0), (8, d+1))
y_batch = jax.random.normal(jax.random.PRNGKey(1), (8, d+1))

# Project to manifold
x_batch = jax.vmap(hyperboloid.proj, in_axes=(0, None))(x_batch, c)
y_batch = jax.vmap(hyperboloid.proj, in_axes=(0, None))(y_batch, c)

# Apply residual connection
result_batch = lorentz_residual(x_batch, y_batch, w_y=0.3, c=c)
print(result_batch.shape)  # (8, 7)

Positional Encoding for Hyperbolic Transformers

HOPE (Hyperbolic Rotary Positional Encoding):

  • Deterministic, no learnable parameters
  • Based on RoPE: applies rotations to spatial components
  • Preserves relative position information: ⟨HOPE(q,i), HOPE(k,j)⟩_L depends only on i-j
  • Rotation is an isometry: preserves spatial norms
  • Identity at position 0
  • Suitable for long sequences (no learned embeddings to store)

HypformerPositionalEncoding:

  • Learnable, adapts to task
  • Uses HTCLinear + Lorentzian residual
  • epsilon parameter controls position encoding magnitude
  • More flexible but requires training
  • Suitable when position patterns are task-specific

Lorentzian Residual:

  • Building block for both approaches
  • Computes weighted Lorentzian midpoint
  • Used for skip connections in hyperbolic Transformers/ResNets
  • Formula: ave = x + w_y*y, normalized to hyperboloid

Hybrid Euclidean-Hyperbolic

Hyper++ Feature Scaling

hyperbolix.nn_layers.HyperPPFeatureScaling

HyperPPFeatureScaling(
    dim: int,
    *,
    alpha: float | None = None,
    activation: Callable | None = jax.nn.tanh,
    rngs: Rngs,
)

Bases: Module

Hyper++ feature scaling for hybrid Euclidean-hyperbolic networks.

Parameters:

Name Type Description Default
dim int

Embedding dimension (needed for nnx.Linear and 1/sqrt(d) scaling).

required
alpha float or None

Value in (0, 1) enabling learned rescaling. None disables it and the layer is entirely parameter-free. Default: None.

None
activation Callable or None

Lipschitz activation function. Default: jax.nn.tanh. None to skip.

tanh
rngs Rngs

Random number generators for sub-layers.

required

Examples:

>>> from flax import nnx
>>> from hyperbolix.nn_layers import HyperPPFeatureScaling
>>>
>>> # Parameter-free mode
>>> layer = HyperPPFeatureScaling(dim=64, rngs=nnx.Rngs(0))
>>> y = layer(x, c=1.0)
>>>
>>> # Learned rescaling mode
>>> layer = HyperPPFeatureScaling(dim=64, alpha=0.9, rngs=nnx.Rngs(0))
>>> y = layer(x, c=0.1)
Source code in hyperbolix/nn_layers/hybrid_regularization.py
def __init__(
    self,
    dim: int,
    *,
    alpha: float | None = None,
    activation: Callable | None = jax.nn.tanh,
    rngs: nnx.Rngs,
):
    if alpha is not None:
        if not (0.0 < alpha < 1.0):
            msg = f"alpha must be in (0, 1), got {alpha}"
            raise ValueError(msg)
        self._atanh_alpha = math.atanh(alpha)
        self.xi_theta = nnx.Linear(dim, 1, rngs=rngs)
    else:
        self._atanh_alpha = None
        self.xi_theta = None

    self.rms_norm = nnx.RMSNorm(dim, use_scale=False, rngs=rngs)
    self._inv_sqrt_dim = 1.0 / math.sqrt(dim)
    self.activation = activation
__call__
__call__(
    x_BD: Float[Array, "B D"], c: float = 1.0
) -> Float[Array, "B D"]

Apply Hyper++ feature scaling pipeline.

Parameters:

Name Type Description Default
x_BD Array of shape (B, D)

Euclidean features from the last Euclidean layer.

required
c float

Curvature parameter (positive). Passed at call time per library convention to support learnable curvature.

1.0

Returns:

Type Description
Array of shape (B, D)

Scaled features ready for expmap_0.

Source code in hyperbolix/nn_layers/hybrid_regularization.py
def __call__(
    self,
    x_BD: Float[Array, "B D"],
    c: float = 1.0,
) -> Float[Array, "B D"]:
    """Apply Hyper++ feature scaling pipeline.

    Parameters
    ----------
    x_BD : Array of shape (B, D)
        Euclidean features from the last Euclidean layer.
    c : float
        Curvature parameter (positive). Passed at call time per
        library convention to support learnable curvature.

    Returns
    -------
    Array of shape (B, D)
        Scaled features ready for expmap_0.
    """
    # Step 1: RMSNorm (parameter-free, no learned scale)
    x_BD = self.rms_norm(x_BD)

    # Step 2: Lipschitz activation
    if self.activation is not None:
        x_BD = self.activation(x_BD)

    # Step 3: Dimension scaling
    x_BD = x_BD * self._inv_sqrt_dim

    # Step 4: Learned rescaling
    if self._atanh_alpha is not None:
        assert self.xi_theta is not None
        rho_max = self._atanh_alpha / jnp.sqrt(c)  # scalar
        scale_B1 = rho_max * jax.nn.sigmoid(self.xi_theta(x_BD))  # (B, 1)
        x_BD = scale_B1 * x_BD

    return x_BD

HyperPPFeatureScaling is applied after the last Euclidean layer and before expmap_0 onto either the Poincaré ball or the hyperboloid. It operates entirely in Euclidean space but uses hyperbolic geometry to bound the output norm via rho_max = atanh(alpha) / sqrt(c).

Usage Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.nn_layers import HyperPPFeatureScaling
from hyperbolix.manifolds import Poincare

poincare = Poincare()

# Parameter-free mode (RMSNorm + tanh + dim scaling only)
layer = HyperPPFeatureScaling(dim=64, rngs=nnx.Rngs(0))
x = jax.random.normal(jax.random.PRNGKey(0), (32, 64))
scaled = layer(x, c=1.0)

# With learned rescaling (adds rho_max * sigmoid(xi(x)) * x)
layer = HyperPPFeatureScaling(dim=64, alpha=0.9, rngs=nnx.Rngs(0))
scaled = layer(x, c=0.1)

# Map to Poincaré ball
expmap_batch = jax.vmap(poincare.expmap_0, in_axes=(0, None))
points = expmap_batch(scaled, 0.1)

Pipeline Steps

  1. RMSNorm (parameter-free): normalizes feature magnitudes
  2. Lipschitz activation (default tanh, configurable): bounds per-component values
  3. Dimension scaling (1/sqrt(d)): ensures norm doesn't grow with dimension
  4. Learned rescaling (when alpha is set): rho_max * sigmoid(xi_theta(x)) * x where rho_max = atanh(alpha) / sqrt(c)

When alpha is None, the layer is entirely parameter-free. When set, only xi_theta (a linear projection to scalar) has learnable parameters.
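
The norm bound from step 4 can be checked numerically. The sketch below re-implements the four steps in plain JAX rather than calling HyperPPFeatureScaling itself; xi_w is a random stand-in for the learned xi_theta projection:

```python
# Standalone sketch of the four Hyper++ scaling steps in plain JAX.
# xi_w is a hypothetical stand-in for the learned xi_theta projection.
import math
import jax
import jax.numpy as jnp

d, c, alpha = 64, 0.1, 0.9
rho_max = math.atanh(alpha) / math.sqrt(c)

x = jax.random.normal(jax.random.PRNGKey(0), (32, d))

# Step 1: parameter-free RMSNorm (no learned scale)
x = x / jnp.sqrt(jnp.mean(x**2, axis=-1, keepdims=True) + 1e-6)
# Step 2: Lipschitz activation (tanh bounds each component to (-1, 1))
x = jnp.tanh(x)
# Step 3: dimension scaling, so each row norm is at most ~1
x = x / math.sqrt(d)
# Step 4: learned rescaling; sigmoid keeps the factor in (0, rho_max)
xi_w = jax.random.normal(jax.random.PRNGKey(1), (d, 1)) * 0.01
y = rho_max * jax.nn.sigmoid(x @ xi_w) * x

# Every row norm stays strictly below rho_max = atanh(alpha) / sqrt(c)
print(bool(jnp.all(jnp.linalg.norm(y, axis=-1) < rho_max)))  # True
```

Because each step only shrinks the norm, the subsequent expmap_0 lands inside a ball of hyperbolic radius controlled by alpha, independent of the input scale.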

Regression Layers

Single-layer classifiers with Riemannian geometry.

Poincaré Regression

hyperbolix.nn_layers.HypRegressionPoincare

HypRegressionPoincare(
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic Neural Networks multinomial linear regression layer (Poincaré ball model).

Computation steps: 0) Project the input tensor onto the manifold (optional) 1) Compute the multinomial linear regression score(s)

Parameters:

Name Type Description Default
manifold_module object

Class-based Poincare manifold instance

required
in_dim int

Dimension of the input space

required
out_dim int

Dimension of the output space

required
rngs Rngs

Random number generators for parameter initialization

required
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
clamping_factor float

Clamping factor for the multinomial linear regression output (default: 1.0)

1.0
smoothing_factor float

Smoothing factor for the multinomial linear regression output (default: 50.0)

50.0
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.

References

Octavian Ganea, Gary Bécigneul, and Thomas Hofmann. "Hyperbolic neural networks." Advances in neural information processing systems 31 (2018).

Source code in hyperbolix/nn_layers/poincare_regression.py
def __init__(
    self,
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_poincare_manifold(
        manifold_module,
        required_methods=("proj", "addition", "expmap_0", "ptransp_0", "conformal_factor", "compute_mlr_pp"),
    )
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters
    # Tangent space weight (Euclidean)
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_dim, in_dim)))
    # Manifold bias (initialized to small random values)
    self.bias = ManifoldParam(
        jax.random.normal(rngs.params(), (out_dim, in_dim)) * 0.01,
        manifold=self.manifold,
        curvature=1.0,
    )
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the hyperbolic regression layer.

Parameters:

Name Type Description Default
x Array of shape (batch, in_dim)

Input tensor where the hyperbolic_axis is last

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
res Array of shape (batch, out_dim)

Multinomial linear regression scores

Source code in hyperbolix/nn_layers/poincare_regression.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the hyperbolic regression layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Multinomial linear regression scores
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x = jax.vmap(self.manifold.expmap_0, in_axes=(0, None), out_axes=0)(x, c)

    # Project bias to manifold (vmap over out_dim dimension)
    bias_PD = jax.vmap(self.manifold.proj, in_axes=(0, None), out_axes=0)(self.bias[...], c)  # type: ignore[arg-type]

    # Parallel transport weight from tangent space at origin to tangent space at bias
    pt_weight_PD = jax.vmap(self.manifold.ptransp_0, in_axes=(0, 0, None), out_axes=0)(self.kernel[...], bias_PD, c)

    # Compute the multinomial linear regression score(s)
    res_BP = self._compute_mlr(x, pt_weight_PD, bias_PD, c)
    return res_BP

hyperbolix.nn_layers.HypRegressionPoincarePP

HypRegressionPoincarePP(
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Hyperbolic Neural Networks ++ multinomial linear regression layer (Poincaré ball model).

Computation steps: 0) Project the input tensor onto the manifold (optional) 1) Compute the multinomial linear regression score(s)

Parameters:

Name Type Description Default
manifold_module object

Class-based Poincare manifold instance

required
in_dim int

Dimension of the input space

required
out_dim int

Dimension of the output space

required
rngs Rngs

Random number generators for parameter initialization

required
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
clamping_factor float

Clamping factor for the multinomial linear regression output (default: 1.0)

1.0
smoothing_factor float

Smoothing factor for the multinomial linear regression output (default: 50.0)

50.0
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.

References

Ryohei Shimizu, Yusuke Mukuta, and Tatsuya Harada. "Hyperbolic neural networks++." arXiv preprint arXiv:2006.08210 (2020).

Source code in hyperbolix/nn_layers/poincare_regression.py
def __init__(
    self,
    manifold_module: Poincare,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_poincare_manifold(
        manifold_module,
        required_methods=("proj", "addition", "expmap_0", "ptransp_0", "conformal_factor", "compute_mlr_pp"),
    )
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters
    # Tangent space weight — scaled to match reference (van Spengler et al. 2023)
    # Reference uses std = (2 * in_dim * out_dim)^{-0.5}; unscaled normal(0,1) gives
    # row norms ≈ sqrt(in_dim) which overwhelms the MLR output scaling.
    std = 1.0 / jnp.sqrt(2.0 * in_dim * out_dim)
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_dim, in_dim)) * std)
    # Scalar bias (initialized to small random values)
    self.bias = nnx.Param(jax.random.normal(rngs.params(), (out_dim, 1)) * 0.01)
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the HNN++ hyperbolic regression layer.

Parameters:

Name Type Description Default
x Array of shape (batch, in_dim)

Input tensor where the hyperbolic_axis is last

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
res Array of shape (batch, out_dim)

Multinomial linear regression scores

Source code in hyperbolix/nn_layers/poincare_regression.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the HNN++ hyperbolic regression layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Multinomial linear regression scores
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x = jax.vmap(self.manifold.expmap_0, in_axes=(0, None), out_axes=0)(x, c)

    # Compute multinomial linear regression
    res = self.manifold.compute_mlr_pp(
        x,
        self.kernel[...],
        self.bias[...],
        c,
        self.clamping_factor,
        self.smoothing_factor,
    )

    return res

Hyperboloid Regression

hyperbolix.nn_layers.HypRegressionHyperboloid

HypRegressionHyperboloid(
    manifold_module: Hyperboloid,
    in_dim: int,
    out_dim: int,
    *,
    rngs: Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
)

Bases: Module

Fully Hyperbolic Convolutional Neural Networks multinomial linear regression layer (Hyperboloid model).

Computation steps: 0) Project the input tensor onto the manifold (optional) 1) Compute the multinomial linear regression score(s)

Parameters:

Name Type Description Default
manifold_module object

Class-based Hyperboloid manifold instance

required
in_dim int

Dimension of the input space

required
out_dim int

Dimension of the output space

required
rngs Rngs

Random number generators for parameter initialization

required
input_space str

Type of the input tensor, either 'tangent' or 'manifold' (default: 'manifold'). Note: This is a static configuration - changing it after initialization requires recompilation.

'manifold'
clamping_factor float

Clamping factor for the multinomial linear regression output (default: 1.0)

1.0
smoothing_factor float

Smoothing factor for the multinomial linear regression output (default: 50.0)

50.0
Notes

JIT Compatibility: This layer is designed to work with nnx.jit. Configuration parameters (input_space, clamping_factor, smoothing_factor) are treated as static and will be baked into the compiled function.

References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Source code in hyperbolix/nn_layers/hyperboloid_regression.py
def __init__(
    self,
    manifold_module: Hyperboloid,
    in_dim: int,
    out_dim: int,
    *,
    rngs: nnx.Rngs,
    input_space: str = "manifold",
    clamping_factor: float = 1.0,
    smoothing_factor: float = 50.0,
):
    if input_space not in ["tangent", "manifold"]:
        raise ValueError(f"input_space must be either 'tangent' or 'manifold', got '{input_space}'")

    # Static configuration (treated as compile-time constants for JIT)
    validate_hyperboloid_manifold(manifold_module, required_methods=("expmap_0", "compute_mlr"))
    self.manifold = manifold_module
    self.in_dim = in_dim
    self.out_dim = out_dim
    self.input_space = input_space
    self.clamping_factor = clamping_factor
    self.smoothing_factor = smoothing_factor

    # Trainable parameters
    # kernel lies in the tangent space at the Hyperboloid origin, so its time coordinate (first component of the last axis) is zero
    self.kernel = nnx.Param(jax.random.normal(rngs.params(), (out_dim, in_dim - 1)))
    # Scalar bias (initialized to small random values)
    self.bias = nnx.Param(jax.random.normal(rngs.params(), (out_dim, 1)) * 0.01)
__call__
__call__(
    x: Float[Array, "batch in_dim"], c: float = 1.0
) -> Float[Array, "batch out_dim"]

Forward pass through the hyperbolic regression layer.

Parameters:

Name Type Description Default
x Array of shape (batch, in_dim)

Input tensor where the hyperbolic_axis is last

required
c float

Manifold curvature (default: 1.0)

1.0

Returns:

Name Type Description
res Array of shape (batch, out_dim)

Multinomial linear regression scores

Source code in hyperbolix/nn_layers/hyperboloid_regression.py
def __call__(
    self,
    x: Float[Array, "batch in_dim"],
    c: float = 1.0,
) -> Float[Array, "batch out_dim"]:
    """
    Forward pass through the hyperbolic regression layer.

    Parameters
    ----------
    x : Array of shape (batch, in_dim)
        Input tensor where the hyperbolic_axis is last
    c : float
        Manifold curvature (default: 1.0)

    Returns
    -------
    res : Array of shape (batch, out_dim)
        Multinomial linear regression scores
    """
    # Map to manifold if needed (static branch - JIT friendly)
    if self.input_space == "tangent":
        x = jax.vmap(self.manifold.expmap_0, in_axes=(0, None), out_axes=0)(x, c)

    # Compute multinomial linear regression
    res = self.manifold.compute_mlr(
        x,
        self.kernel[...],
        self.bias[...],
        c,
        self.clamping_factor,
        self.smoothing_factor,
    )

    return res

FGG Lorentz MLR (Klis et al. 2026)

hyperbolix.nn_layers.FGGLorentzMLR

FGGLorentzMLR(
    in_features: int,
    num_classes: int,
    *,
    rngs: Rngs,
    reset_params: str = "mlr",
    init_bias: float = 0.5,
    eps: float = 1e-07,
)

Bases: Module

Fast and Geometrically Grounded Lorentz multinomial logistic regression.

Outputs Euclidean logits (signed scaled distances to hyperplanes) using the FGG spacelike V construction. Unlike HypRegressionHyperboloid, this layer uses the distance-to-hyperplane formulation matching the reference fc_mlr (signed_dist2hyperplanes_scaled_angle).

Forward pass (matching reference fc_mlr): 1. Build V_mink from (z, a) 2. mink = x @ V_mink (Minkowski inner products) 3. logits = asinh(sqrt(c) * mink) / sqrt(c) (signed scaled distances)

Parameters:

Name Type Description Default
in_features int

Input ambient dimension (D_in + 1), including time component.

required
num_classes int

Number of output classes.

required
rngs Rngs

Random number generators for parameter initialization.

required
reset_params str

Weight initialization scheme for hyperplane normals: "mlr" (normal, std=sqrt(5/Ai), where Ai = in_features is the ambient dimension) or "default" (uniform) (default: "mlr").

'mlr'
init_bias float

Initial value for bias entries (default: 0.5).

0.5
eps float

Numerical stability floor (default: 1e-7).

1e-07
References

Klis et al. "Fast and Geometrically Grounded Lorentz Neural Networks" (2026), Eq. 23.

Source code in hyperbolix/nn_layers/hyperboloid_regression.py
def __init__(
    self,
    in_features: int,
    num_classes: int,
    *,
    rngs: nnx.Rngs,
    reset_params: str = "mlr",
    init_bias: float = 0.5,
    eps: float = 1e-7,
):
    if reset_params not in ("default", "mlr"):
        raise ValueError(f"reset_params must be 'default' or 'mlr', got '{reset_params}'")

    in_spatial = in_features - 1  # I
    self.in_features = in_features
    self.num_classes = num_classes
    self.eps = eps

    # Hyperplane normals (spatial) and bias offsets
    # Reference computes std from ambient dimension (in_features)
    key = rngs.params()
    if reset_params == "mlr":
        std = jnp.sqrt(5.0 / in_features)
        self.kernel = nnx.Param(jax.random.normal(key, (in_spatial, num_classes)) * std)
    else:  # default
        stdv = 1.0 / jnp.sqrt(jnp.array(in_features, dtype=jnp.float32))
        self.kernel = nnx.Param(jax.random.uniform(key, (in_spatial, num_classes), minval=-stdv, maxval=stdv))
    self.bias = nnx.Param(jnp.full((num_classes,), init_bias))
__call__
__call__(
    x_BAi: Float[Array, "batch in_features"], c: float = 1.0
) -> Float[Array, "batch num_classes"]

Forward pass returning Euclidean logits.

Parameters:

Name Type Description Default
x_BAi (Array, shape(B, Ai))

Input points on the hyperboloid with curvature c.

required
c float

Curvature parameter (default: 1.0).

1.0

Returns:

Name Type Description
logits_BK (Array, shape(B, K))

Euclidean logits (signed scaled distances to hyperplanes).

Source code in hyperbolix/nn_layers/hyperboloid_regression.py
def __call__(
    self,
    x_BAi: Float[Array, "batch in_features"],
    c: float = 1.0,
) -> Float[Array, "batch num_classes"]:
    """Forward pass returning Euclidean logits.

    Parameters
    ----------
    x_BAi : Array, shape (B, Ai)
        Input points on the hyperboloid with curvature ``c``.
    c : float, optional
        Curvature parameter (default: 1.0).

    Returns
    -------
    logits_BK : Array, shape (B, K)
        Euclidean logits (signed scaled distances to hyperplanes).
    """
    # 1. Build V_mink from (z, a)
    V_AiK = build_spacelike_V(self.kernel[...], self.bias[...], c, self.eps)  # (Ai, K)
    # Cast V to match input dtype (avoids float32/float64 scatter warnings)
    V_AiK = V_AiK.astype(x_BAi.dtype)

    # 2. Minkowski inner products
    mink_BK = x_BAi @ V_AiK  # (B, K)

    # 3. Signed scaled distances (matching reference fc_mlr: no norm scaling)
    sqrt_c = jnp.sqrt(c)
    logits_BK = jnp.asinh(sqrt_c * mink_BK) / sqrt_c  # (B, K)

    return logits_BK

FGG Usage Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds import Hyperboloid
from hyperbolix.nn_layers import FGGLinear, FGGConv2D, FGGLorentzMLR, FGGMeanOnlyBatchNorm

hyperboloid = Hyperboloid()
rngs = nnx.Rngs(0)

# --- FGGLinear: FC layer with linear hyperbolic distance growth ---
linear = FGGLinear(
    in_features=33,   # 32 spatial + 1 time
    out_features=65,  # 64 spatial + 1 time
    rngs=rngs,
    activation=jax.nn.relu,
    reset_params="lorentz_kaiming",
)
x_B33 = jnp.ones((8, 33))
x_B33 = x_B33.at[:, 0].set(jnp.sqrt(jnp.sum(x_B33[:, 1:]**2, axis=-1) + 1.0))
y_B65 = linear(x_B33, c=1.0)
print(y_B65.shape)  # (8, 65)

# --- FGGConv2D: 2D conv with manifold-origin padding ---
conv = FGGConv2D(
    manifold_module=hyperboloid,
    in_channels=33,
    out_channels=65,
    kernel_size=3,
    rngs=rngs,
    activation=jax.nn.relu,
    pad_mode="origin",   # manifold origin padding (reference default)
)
x_BHWC = jnp.zeros((4, 14, 14, 33))
x_BHWC = x_BHWC.at[..., 0].set(1.0)  # valid origin point at c=1
y_BHWC = conv(x_BHWC, c=1.0)
print(y_BHWC.shape)  # (4, 14, 14, 65)

# --- FGGLorentzMLR: classification head ---
mlr = FGGLorentzMLR(
    in_features=65,
    num_classes=10,
    rngs=rngs,
    reset_params="mlr",   # N(0, sqrt(5/Ai)) where Ai = in_features (ambient)
    init_bias=0.5,
)
logits_B10 = mlr(y_BHWC.reshape(-1, 65), c=1.0)
print(logits_B10.shape)  # (784, 10)

# --- FGGMeanOnlyBatchNorm: pairs with FGGLinear(use_weight_norm=True) ---
# num_features is the SPATIAL (out) dimension
bn = FGGMeanOnlyBatchNorm(num_features=64, rngs=rngs)
y_normed = bn(y_B65, c_in=1.0, c_out=1.0, use_running_average=False)
print(y_normed.shape)  # (8, 65)

FGG Layer Family

The four FGG components from Klis et al. (2026) form a complete layer stack:

Layer Role Key params
FGGLinear Fully-connected reset_params, use_weight_norm, activation
FGGConv2D 2D convolution pad_mode="origin", wraps FGGLinear
FGGLorentzMLR Classification head reset_params="mlr", init_bias=0.5
FGGMeanOnlyBatchNorm Batch normalization pairs with use_weight_norm=True

Core insight: the sinh/arcsinh cancellation in the Lorentzian activation chain reduces the forward pass to a single matmul with a spacelike V matrix, Euclidean activation, then time reconstruction — achieving linear (not logarithmic) growth of hyperbolic distance.
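
This reduced forward pass can be sketched in a few lines of plain JAX. The snippet below is an illustrative reduction, not the actual FGGLinear source: a single Euclidean matmul plus activation on the spatial components, followed by time reconstruction from the hyperboloid constraint, yields a valid manifold point with no exp/log maps:

```python
# Illustrative FGG-style forward in plain JAX (not the FGGLinear source):
# matmul + Euclidean activation on spatial parts, then time reconstruction.
import jax
import jax.numpy as jnp

c = 1.0
kx, kw = jax.random.split(jax.random.PRNGKey(0))
x_space = jax.random.normal(kx, (8, 32))         # spatial inputs
W = jax.random.normal(kw, (32, 64)) * 0.1        # ordinary Euclidean weight

y_space = jax.nn.relu(x_space @ W)               # single matmul + activation
y_time = jnp.sqrt(jnp.sum(y_space**2, axis=-1, keepdims=True) + 1.0 / c)
y = jnp.concatenate([y_time, y_space], axis=-1)  # (8, 65) ambient points

# Outputs satisfy the Lorentz constraint <y, y>_L = -1/c by construction
mink = -y[:, 0] ** 2 + jnp.sum(y[:, 1:] ** 2, axis=-1)
print(bool(jnp.allclose(mink, -1.0 / c, atol=1e-5)))  # True
```

Since the matmul and activation are purely Euclidean, the cost is the same as a standard dense layer; only the cheap time-reconstruction step is manifold-specific.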

Regression Example

import jax
from hyperbolix.nn_layers import HypRegressionPoincare
from hyperbolix.manifolds import Poincare
from flax import nnx

poincare = Poincare()

# Multi-class classification (10 classes)
regressor = HypRegressionPoincare(
    manifold_module=poincare,
    in_dim=32,
    out_dim=10,
    rngs=nnx.Rngs(0)
)

# Input: hyperbolic embeddings
x = jax.random.normal(jax.random.PRNGKey(1), (64, 32)) * 0.3
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)

# Forward pass returns logits
logits = regressor(x_proj, c=1.0)
print(logits.shape)  # (64, 10)

# Use with softmax for classification
probs = jax.nn.softmax(logits, axis=-1)

Activation Functions

Hyperbolic activation functions that preserve manifold constraints. All activations follow the HRC pattern: apply function to space components, then reconstruct time.
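
The HRC pattern is easy to verify by hand. This minimal plain-JAX sketch (not the library's hrc_relu) applies ReLU to the spatial components only, then rebuilds the time component, and the result satisfies the hyperboloid constraint up to float precision:

```python
# Minimal HRC-pattern sketch in plain JAX (not the library's hrc_relu).
import jax.numpy as jnp

c = 1.0
x = jnp.array([[1.05, 0.1, -0.2, 0.15]])   # ambient point, time first
y_space = jnp.maximum(x[..., 1:], 0.0)     # ReLU on spatial components only
y_time = jnp.sqrt(jnp.sum(y_space**2, axis=-1, keepdims=True) + 1.0 / c)
y = jnp.concatenate([y_time, y_space], axis=-1)

# Constraint -y0^2 + ||y_s||^2 = -1/c holds after the activation
mink = -y[..., 0] ** 2 + jnp.sum(y[..., 1:] ** 2, axis=-1)
print(bool(jnp.allclose(mink, -1.0 / c)))  # True
```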

Curvature-Preserving Activations

hyperbolix.nn_layers.hyp_relu

hyp_relu(
    x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]

Apply ReLU activation to space components of hyperboloid point(s).

Curvature-preserving wrapper around hrc_relu(x, c_in=c, c_out=c).

This function applies the ReLU activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.

Mathematical formula: y = [sqrt(||ReLU(x_s)||^2 + 1/c), ReLU(x_s)]

where x_s are the spatial components x[..., 1:].

Parameters:

Name Type Description Default
x Array of shape (..., dim+1)

Input point(s) on the hyperboloid manifold in ambient space, where ... represents arbitrary batch dimensions. The last dimension contains the time component (x[..., 0]) and spatial components (x[..., 1:]).

required
c float

Curvature parameter, must be positive (c > 0).

required

Returns:

Name Type Description
y Array of shape (..., dim+1)

Output point(s) on the hyperboloid manifold, same shape as input.

Notes
  • This function applies ReLU only to spatial components, not the time component
  • The time component is reconstructed using the hyperboloid constraint: -x₀² + ||x_rest||² = -1/c
  • This approach avoids frequent exp/log maps for better numerical stability
  • Works on arrays of any shape, similar to jax.nn.relu
  • For curvature-changing transformations, use hrc_relu which supports different input/output curvatures
References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_relu
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_relu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))  # 8 points in 5-dim ambient space
>>> y_batch = hyp_relu(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch (e.g., feature maps)
>>> x_feature = jnp.ones((4, 16, 16, 10))  # 4 images, 16x16 spatial, 10-dim
>>> y_feature = hyp_relu(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hyp_relu(x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
    """Apply ReLU activation to space components of hyperboloid point(s).

    Curvature-preserving wrapper around hrc_relu(x, c_in=c, c_out=c).

    This function applies the ReLU activation function to the spatial components
    of hyperboloid points and reconstructs valid manifold points using the
    hyperboloid constraint.

    Mathematical formula:
        y = [sqrt(||ReLU(x_s)||^2 + 1/c), ReLU(x_s)]

    where x_s are the spatial components x[..., 1:].

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold in ambient space, where
        ... represents arbitrary batch dimensions. The last dimension contains
        the time component (x[..., 0]) and spatial components (x[..., 1:]).
    c : float
        Curvature parameter, must be positive (c > 0).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold, same shape as input.

    Notes
    -----
    - This function applies ReLU only to spatial components, not the time component
    - The time component is reconstructed using the hyperboloid constraint:
      -x₀² + ||x_rest||² = -1/c
    - This approach avoids frequent exp/log maps for better numerical stability
    - Works on arrays of any shape, similar to jax.nn.relu
    - For curvature-changing transformations, use `hrc_relu` which supports
      different input/output curvatures

    References
    ----------
    Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic
    convolutional neural networks for computer vision." arXiv preprint
    arXiv:2303.15919 (2023).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import hyp_relu
    >>>
    >>> # Single point
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> y = hyp_relu(x, c=1.0)
    >>> y.shape
    (4,)
    >>>
    >>> # Batch of points
    >>> x_batch = jnp.ones((8, 5))  # 8 points in 5-dim ambient space
    >>> y_batch = hyp_relu(x_batch, c=1.0)
    >>> y_batch.shape
    (8, 5)
    >>>
    >>> # Multi-dimensional batch (e.g., feature maps)
    >>> x_feature = jnp.ones((4, 16, 16, 10))  # 4 images, 16x16 spatial, 10-dim
    >>> y_feature = hyp_relu(x_feature, c=1.0)
    >>> y_feature.shape
    (4, 16, 16, 10)
    """
    return hrc_relu(x, c_in=c, c_out=c)

hyperbolix.nn_layers.hyp_leaky_relu

hyp_leaky_relu(
    x: Float[Array, "... dim_plus_1"],
    c: float,
    negative_slope: float = 0.01,
) -> Float[Array, "... dim_plus_1"]

Apply LeakyReLU activation to space components of hyperboloid point(s).

Curvature-preserving wrapper around hrc_leaky_relu(x, c_in=c, c_out=c, negative_slope).

This function applies the LeakyReLU activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.

Mathematical formula: y = [sqrt(||LeakyReLU(x_s)||^2 + 1/c), LeakyReLU(x_s)]

where x_s are the spatial components x[..., 1:], and LeakyReLU(x) = x if x > 0 else negative_slope * x.

Parameters:

Name Type Description Default
x Array of shape (..., dim+1)

Input point(s) on the hyperboloid manifold in ambient space, where ... represents arbitrary batch dimensions. The last dimension contains the time component (x[..., 0]) and spatial components (x[..., 1:]).

required
c float

Curvature parameter, must be positive (c > 0).

required
negative_slope float

Negative slope coefficient for LeakyReLU (default: 0.01).

0.01

Returns:

Name Type Description
y Array of shape (..., dim+1)

Output point(s) on the hyperboloid manifold, same shape as input.

Notes
  • This function applies LeakyReLU only to spatial components
  • The time component is reconstructed using the hyperboloid constraint
  • LeakyReLU allows small negative values (scaled by negative_slope) which can help gradient flow compared to standard ReLU
  • Works on arrays of any shape, similar to jax.nn.leaky_relu
  • For curvature-changing transformations, use hrc_leaky_relu which supports different input/output curvatures
References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_leaky_relu
>>>
>>> # Single point with default negative_slope
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_leaky_relu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Custom negative_slope
>>> y = hyp_leaky_relu(x, c=1.0, negative_slope=0.1)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_leaky_relu(x_batch, c=1.0, negative_slope=0.01)
>>> y_batch.shape
(8, 5)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hyp_leaky_relu(
    x: Float[Array, "... dim_plus_1"], c: float, negative_slope: float = 0.01
) -> Float[Array, "... dim_plus_1"]:
    """Apply LeakyReLU activation to space components of hyperboloid point(s).

    Curvature-preserving wrapper around hrc_leaky_relu(x, c_in=c, c_out=c, negative_slope).

    This function applies the LeakyReLU activation function to the spatial
    components of hyperboloid points and reconstructs valid manifold points
    using the hyperboloid constraint.

    Mathematical formula:
        y = [sqrt(||LeakyReLU(x_s)||^2 + 1/c), LeakyReLU(x_s)]

    where x_s are the spatial components x[..., 1:], and
    LeakyReLU(x) = x if x > 0 else negative_slope * x.

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold in ambient space, where
        ... represents arbitrary batch dimensions. The last dimension contains
        the time component (x[..., 0]) and spatial components (x[..., 1:]).
    c : float
        Curvature parameter, must be positive (c > 0).
    negative_slope : float, optional
        Negative slope coefficient for LeakyReLU (default: 0.01).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold, same shape as input.

    Notes
    -----
    - This function applies LeakyReLU only to spatial components
    - The time component is reconstructed using the hyperboloid constraint
    - LeakyReLU allows small negative values (scaled by negative_slope) which
      can help gradient flow compared to standard ReLU
    - Works on arrays of any shape, similar to jax.nn.leaky_relu
    - For curvature-changing transformations, use `hrc_leaky_relu` which
      supports different input/output curvatures

    References
    ----------
    Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic
    convolutional neural networks for computer vision." arXiv preprint
    arXiv:2303.15919 (2023).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import hyp_leaky_relu
    >>>
    >>> # Single point with default negative_slope
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> y = hyp_leaky_relu(x, c=1.0)
    >>> y.shape
    (4,)
    >>>
    >>> # Custom negative_slope
    >>> y = hyp_leaky_relu(x, c=1.0, negative_slope=0.1)
    >>>
    >>> # Batch of points
    >>> x_batch = jnp.ones((8, 5))
    >>> y_batch = hyp_leaky_relu(x_batch, c=1.0, negative_slope=0.01)
    >>> y_batch.shape
    (8, 5)
    """
    return hrc_leaky_relu(x, c_in=c, c_out=c, negative_slope=negative_slope)

hyperbolix.nn_layers.hyp_tanh

hyp_tanh(
    x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]

Apply tanh activation to space components of hyperboloid point(s).

Curvature-preserving wrapper around hrc_tanh(x, c_in=c, c_out=c).

This function applies the hyperbolic tangent activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.

Mathematical formula: y = [sqrt(||tanh(x_s)||^2 + 1/c), tanh(x_s)]

where x_s are the spatial components x[..., 1:].

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold in ambient space, where ... represents arbitrary batch dimensions. The last dimension contains the time component (x[..., 0]) and spatial components (x[..., 1:]).
c : float
    Curvature parameter, must be positive (c > 0).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold, same shape as input.

Notes
  • This function applies tanh only to spatial components
  • The time component is reconstructed using the hyperboloid constraint
  • Tanh naturally bounds outputs in [-1, 1], which can help with stability
  • Works on arrays of any shape, similar to jax.nn.tanh
  • For curvature-changing transformations, use hrc_tanh which supports different input/output curvatures
References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_tanh
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_tanh(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_tanh(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Verify spatial components are bounded
>>> import jax
>>> assert jnp.all(jnp.abs(y_batch[..., 1:]) <= 1.0)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hyp_tanh(x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
    """Apply tanh activation to space components of hyperboloid point(s).

    Curvature-preserving wrapper around hrc_tanh(x, c_in=c, c_out=c).

    This function applies the hyperbolic tangent activation function to the
    spatial components of hyperboloid points and reconstructs valid manifold
    points using the hyperboloid constraint.

    Mathematical formula:
        y = [sqrt(||tanh(x_s)||^2 + 1/c), tanh(x_s)]

    where x_s are the spatial components x[..., 1:].

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold in ambient space, where
        ... represents arbitrary batch dimensions. The last dimension contains
        the time component (x[..., 0]) and spatial components (x[..., 1:]).
    c : float
        Curvature parameter, must be positive (c > 0).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold, same shape as input.

    Notes
    -----
    - This function applies tanh only to spatial components
    - The time component is reconstructed using the hyperboloid constraint
    - Tanh naturally bounds outputs in [-1, 1], which can help with stability
    - Works on arrays of any shape, similar to jax.nn.tanh
    - For curvature-changing transformations, use `hrc_tanh` which supports
      different input/output curvatures

    References
    ----------
    Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic
    convolutional neural networks for computer vision." arXiv preprint
    arXiv:2303.15919 (2023).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import hyp_tanh
    >>>
    >>> # Single point
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> y = hyp_tanh(x, c=1.0)
    >>> y.shape
    (4,)
    >>>
    >>> # Batch of points
    >>> x_batch = jnp.ones((8, 5))
    >>> y_batch = hyp_tanh(x_batch, c=1.0)
    >>> y_batch.shape
    (8, 5)
    >>>
    >>> # Verify spatial components are bounded
    >>> import jax
    >>> assert jnp.all(jnp.abs(y_batch[..., 1:]) <= 1.0)
    """
    return hrc_tanh(x, c_in=c, c_out=c)

hyperbolix.nn_layers.hyp_swish

hyp_swish(
    x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]

Apply Swish/SiLU activation to space components of hyperboloid point(s).

Curvature-preserving wrapper around hrc_swish(x, c_in=c, c_out=c).

This function applies the Swish (also known as SiLU) activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.

Swish is defined as: swish(x) = x * sigmoid(x)

Mathematical formula: y = [sqrt(||swish(x_s)||^2 + 1/c), swish(x_s)]

where x_s are the spatial components x[..., 1:].

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold in ambient space, where ... represents arbitrary batch dimensions. The last dimension contains the time component (x[..., 0]) and spatial components (x[..., 1:]).
c : float
    Curvature parameter, must be positive (c > 0).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold, same shape as input.

Notes
  • This function applies Swish only to spatial components
  • The time component is reconstructed using the hyperboloid constraint
  • Swish is smooth and non-monotonic, often performing well in deep networks
  • Works on arrays of any shape, similar to jax.nn.swish
  • For curvature-changing transformations, use hrc_swish which supports different input/output curvatures
References

Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic convolutional neural networks for computer vision." arXiv preprint arXiv:2303.15919 (2023).

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_swish
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_swish(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_swish(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch
>>> x_feature = jnp.ones((4, 16, 16, 10))
>>> y_feature = hyp_swish(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hyp_swish(x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
    """Apply Swish/SiLU activation to space components of hyperboloid point(s).

    Curvature-preserving wrapper around hrc_swish(x, c_in=c, c_out=c).

    This function applies the Swish (also known as SiLU) activation function
    to the spatial components of hyperboloid points and reconstructs valid
    manifold points using the hyperboloid constraint.

    Swish is defined as: swish(x) = x * sigmoid(x)

    Mathematical formula:
        y = [sqrt(||swish(x_s)||^2 + 1/c), swish(x_s)]

    where x_s are the spatial components x[..., 1:].

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold in ambient space, where
        ... represents arbitrary batch dimensions. The last dimension contains
        the time component (x[..., 0]) and spatial components (x[..., 1:]).
    c : float
        Curvature parameter, must be positive (c > 0).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold, same shape as input.

    Notes
    -----
    - This function applies Swish only to spatial components
    - The time component is reconstructed using the hyperboloid constraint
    - Swish is smooth and non-monotonic, often performing well in deep networks
    - Works on arrays of any shape, similar to jax.nn.swish
    - For curvature-changing transformations, use `hrc_swish` which supports
      different input/output curvatures

    References
    ----------
    Ahmad Bdeir, Kristian Schwethelm, and Niels Landwehr. "Fully hyperbolic
    convolutional neural networks for computer vision." arXiv preprint
    arXiv:2303.15919 (2023).

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import hyp_swish
    >>>
    >>> # Single point
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> y = hyp_swish(x, c=1.0)
    >>> y.shape
    (4,)
    >>>
    >>> # Batch of points
    >>> x_batch = jnp.ones((8, 5))
    >>> y_batch = hyp_swish(x_batch, c=1.0)
    >>> y_batch.shape
    (8, 5)
    >>>
    >>> # Multi-dimensional batch
    >>> x_feature = jnp.ones((4, 16, 16, 10))
    >>> y_feature = hyp_swish(x_feature, c=1.0)
    >>> y_feature.shape
    (4, 16, 16, 10)
    """
    return hrc_swish(x, c_in=c, c_out=c)

hyperbolix.nn_layers.hyp_gelu

hyp_gelu(
    x: Float[Array, "... dim_plus_1"], c: float
) -> Float[Array, "... dim_plus_1"]

Apply GELU activation to space components of hyperboloid point(s).

Curvature-preserving wrapper around hrc_gelu(x, c_in=c, c_out=c).

This function applies the Gaussian Error Linear Unit (GELU) activation function to the spatial components of hyperboloid points and reconstructs valid manifold points using the hyperboloid constraint.

Mathematical formula: y = [sqrt(||GELU(x_s)||^2 + 1/c), GELU(x_s)]

where x_s are the spatial components x[..., 1:].

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold in ambient space, where ... represents arbitrary batch dimensions. The last dimension contains the time component (x[..., 0]) and spatial components (x[..., 1:]).
c : float
    Curvature parameter, must be positive (c > 0).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold, same shape as input.

Notes
  • This function applies GELU only to spatial components
  • The time component is reconstructed using the hyperboloid constraint
  • GELU is smooth and commonly used in transformer architectures
  • Works on arrays of any shape, similar to jax.nn.gelu
  • For curvature-changing transformations, use hrc_gelu which supports different input/output curvatures

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import hyp_gelu
>>>
>>> # Single point
>>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
>>> y = hyp_gelu(x, c=1.0)
>>> y.shape
(4,)
>>>
>>> # Batch of points
>>> x_batch = jnp.ones((8, 5))
>>> y_batch = hyp_gelu(x_batch, c=1.0)
>>> y_batch.shape
(8, 5)
>>>
>>> # Multi-dimensional batch
>>> x_feature = jnp.ones((4, 16, 16, 10))
>>> y_feature = hyp_gelu(x_feature, c=1.0)
>>> y_feature.shape
(4, 16, 16, 10)
Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hyp_gelu(x: Float[Array, "... dim_plus_1"], c: float) -> Float[Array, "... dim_plus_1"]:
    """Apply GELU activation to space components of hyperboloid point(s).

    Curvature-preserving wrapper around hrc_gelu(x, c_in=c, c_out=c).

    This function applies the Gaussian Error Linear Unit (GELU) activation function
    to the spatial components of hyperboloid points and reconstructs valid manifold
    points using the hyperboloid constraint.

    Mathematical formula:
        y = [sqrt(||GELU(x_s)||^2 + 1/c), GELU(x_s)]

    where x_s are the spatial components x[..., 1:].

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold in ambient space, where
        ... represents arbitrary batch dimensions. The last dimension contains
        the time component (x[..., 0]) and spatial components (x[..., 1:]).
    c : float
        Curvature parameter, must be positive (c > 0).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold, same shape as input.

    Notes
    -----
    - This function applies GELU only to spatial components
    - The time component is reconstructed using the hyperboloid constraint
    - GELU is smooth and commonly used in transformer architectures
    - Works on arrays of any shape, similar to jax.nn.gelu
    - For curvature-changing transformations, use `hrc_gelu` which supports
      different input/output curvatures

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import hyp_gelu
    >>>
    >>> # Single point
    >>> x = jnp.array([1.05, 0.1, -0.2, 0.15])
    >>> y = hyp_gelu(x, c=1.0)
    >>> y.shape
    (4,)
    >>>
    >>> # Batch of points
    >>> x_batch = jnp.ones((8, 5))
    >>> y_batch = hyp_gelu(x_batch, c=1.0)
    >>> y_batch.shape
    (8, 5)
    >>>
    >>> # Multi-dimensional batch
    >>> x_feature = jnp.ones((4, 16, 16, 10))
    >>> y_feature = hyp_gelu(x_feature, c=1.0)
    >>> y_feature.shape
    (4, 16, 16, 10)
    """
    return hrc_gelu(x, c_in=c, c_out=c)

Poincaré Activations

Thin wrappers that apply standard activations in the Poincaré tangent space via logmap_0 → activation → expmap_0.
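The logmap_0 → activation → expmap_0 pattern can be sketched in plain NumPy (a minimal illustration, not the library's internal implementation; the maps below are the standard origin exp/log formulas for the Poincaré ball):

```python
import numpy as np

def logmap0(x, c):
    # log_0^c(x) = artanh(sqrt(c) * ||x||) * x / (sqrt(c) * ||x||)
    norm = np.maximum(np.linalg.norm(x, axis=-1, keepdims=True), 1e-15)
    return np.arctanh(np.sqrt(c) * norm) * x / (np.sqrt(c) * norm)

def expmap0(v, c):
    # exp_0^c(v) = tanh(sqrt(c) * ||v||) * v / (sqrt(c) * ||v||)
    norm = np.maximum(np.linalg.norm(v, axis=-1, keepdims=True), 1e-15)
    return np.tanh(np.sqrt(c) * norm) * v / (np.sqrt(c) * norm)

def tangent_activation(x, f, c):
    # Map to the tangent space at the origin, apply the Euclidean
    # activation, then map back to the Poincaré ball.
    return expmap0(f(logmap0(x, c)), c)

relu = lambda z: np.maximum(z, 0.0)
x = np.array([0.1, -0.2, 0.15])
y = tangent_activation(x, relu, c=1.0)
print(np.linalg.norm(y) < 1.0)  # True: result stays inside the unit ball
```

Because exp and log at the origin only rescale the norm, the output direction is exactly the activated tangent vector's direction.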

hyperbolix.nn_layers.poincare_relu

poincare_relu(
    x: Float[Array, "... dim"], c: float
) -> Float[Array, "... dim"]

Poincaré ReLU activation: exp_0^c ∘ ReLU ∘ log_0^c.

Applies ReLU in the tangent space at the origin of the Poincaré ball, then maps back to the manifold. This is the standard nonlinearity for Poincaré ball neural networks.

Parameters:

x : Array of shape (..., dim)
    Input point(s) on the Poincaré ball. Supports arbitrary batch dimensions (e.g., (batch, H, W, channels) for feature maps).
c : float
    Curvature parameter (positive).

Returns:

y : Array of shape (..., dim)
    Output point(s) on the Poincaré ball.

References

van Spengler et al. "Poincaré ResNet." ICCV 2023.

Examples:

>>> import jax.numpy as jnp
>>> from hyperbolix.nn_layers import poincare_relu
>>>
>>> # Single point
>>> x = jnp.array([0.1, -0.2, 0.15])
>>> y = poincare_relu(x, c=1.0)
>>> y.shape
(3,)
>>>
>>> # Batch of feature maps
>>> x_batch = jnp.ones((4, 14, 14, 8)) * 0.1
>>> y_batch = poincare_relu(x_batch, c=1.0)
>>> y_batch.shape
(4, 14, 14, 8)
Source code in hyperbolix/nn_layers/poincare_activations.py
def poincare_relu(
    x: Float[Array, "... dim"],
    c: float,
) -> Float[Array, "... dim"]:
    """Poincaré ReLU activation: exp_0^c ∘ ReLU ∘ log_0^c.

    Applies ReLU in the tangent space at the origin of the Poincaré ball,
    then maps back to the manifold. This is the standard nonlinearity for
    Poincaré ball neural networks.

    Parameters
    ----------
    x : Array of shape (..., dim)
        Input point(s) on the Poincaré ball. Supports arbitrary batch
        dimensions (e.g., (batch, H, W, channels) for feature maps).
    c : float
        Curvature parameter (positive).

    Returns
    -------
    y : Array of shape (..., dim)
        Output point(s) on the Poincaré ball.

    References
    ----------
    van Spengler et al. "Poincaré ResNet." ICCV 2023.

    Examples
    --------
    >>> import jax.numpy as jnp
    >>> from hyperbolix.nn_layers import poincare_relu
    >>>
    >>> # Single point
    >>> x = jnp.array([0.1, -0.2, 0.15])
    >>> y = poincare_relu(x, c=1.0)
    >>> y.shape
    (3,)
    >>>
    >>> # Batch of feature maps
    >>> x_batch = jnp.ones((4, 14, 14, 8)) * 0.1
    >>> y_batch = poincare_relu(x_batch, c=1.0)
    >>> y_batch.shape
    (4, 14, 14, 8)
    """
    return _apply_in_tangent_space(x, jax.nn.relu, c)

hyperbolix.nn_layers.poincare_leaky_relu

poincare_leaky_relu(
    x: Float[Array, "... dim"],
    c: float,
    negative_slope: float = 0.01,
) -> Float[Array, "... dim"]

Poincaré LeakyReLU activation: exp_0^c ∘ LeakyReLU ∘ log_0^c.

Parameters:

x : Array of shape (..., dim)
    Input point(s) on the Poincaré ball.
c : float
    Curvature parameter (positive).
negative_slope : float, optional
    Negative slope coefficient (default: 0.01).

Returns:

y : Array of shape (..., dim)
    Output point(s) on the Poincaré ball.

Source code in hyperbolix/nn_layers/poincare_activations.py
def poincare_leaky_relu(
    x: Float[Array, "... dim"],
    c: float,
    negative_slope: float = 0.01,
) -> Float[Array, "... dim"]:
    """Poincaré LeakyReLU activation: exp_0^c ∘ LeakyReLU ∘ log_0^c.

    Parameters
    ----------
    x : Array of shape (..., dim)
        Input point(s) on the Poincaré ball.
    c : float
        Curvature parameter (positive).
    negative_slope : float, optional
        Negative slope coefficient (default: 0.01).

    Returns
    -------
    y : Array of shape (..., dim)
        Output point(s) on the Poincaré ball.
    """

    def f(z):
        return jax.nn.leaky_relu(z, negative_slope)

    return _apply_in_tangent_space(x, f, c)

hyperbolix.nn_layers.poincare_tanh

poincare_tanh(
    x: Float[Array, "... dim"], c: float
) -> Float[Array, "... dim"]

Poincaré tanh activation: exp_0^c ∘ tanh ∘ log_0^c.

Parameters:

x : Array of shape (..., dim)
    Input point(s) on the Poincaré ball.
c : float
    Curvature parameter (positive).

Returns:

y : Array of shape (..., dim)
    Output point(s) on the Poincaré ball.

Source code in hyperbolix/nn_layers/poincare_activations.py
def poincare_tanh(
    x: Float[Array, "... dim"],
    c: float,
) -> Float[Array, "... dim"]:
    """Poincaré tanh activation: exp_0^c ∘ tanh ∘ log_0^c.

    Parameters
    ----------
    x : Array of shape (..., dim)
        Input point(s) on the Poincaré ball.
    c : float
        Curvature parameter (positive).

    Returns
    -------
    y : Array of shape (..., dim)
        Output point(s) on the Poincaré ball.
    """
    return _apply_in_tangent_space(x, jnp.tanh, c)

Curvature-Changing Activations (HRC-based)

For advanced use cases requiring curvature transformations:

hyperbolix.nn_layers.hrc_relu

hrc_relu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

HRC with ReLU activation.

Equivalent to hrc(x, jax.nn.relu, c_in, c_out, eps). When c_in = c_out = c, this is equivalent to hyp_relu(x, c).

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold with curvature c_in.
c_in : float
    Input curvature parameter (must be positive).
c_out : float
    Output curvature parameter (must be positive).
eps : float, optional
    Small value for numerical stability (default: 1e-7).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hrc_relu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """HRC with ReLU activation.

    Equivalent to hrc(x, jax.nn.relu, c_in, c_out, eps).
    When c_in = c_out = c, this is equivalent to hyp_relu(x, c).

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
    c_in : float
        Input curvature parameter (must be positive).
    c_out : float
        Output curvature parameter (must be positive).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.
    """
    return hrc(x, jax.nn.relu, c_in, c_out, eps)

hyperbolix.nn_layers.hrc_leaky_relu

hrc_leaky_relu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    negative_slope: float = 0.01,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

HRC with LeakyReLU activation.

When c_in = c_out = c, this is equivalent to hyp_leaky_relu(x, c, negative_slope).

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold with curvature c_in.
c_in : float
    Input curvature parameter (must be positive).
c_out : float
    Output curvature parameter (must be positive).
negative_slope : float, optional
    Negative slope coefficient for LeakyReLU (default: 0.01).
eps : float, optional
    Small value for numerical stability (default: 1e-7).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hrc_leaky_relu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    negative_slope: float = 0.01,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """HRC with LeakyReLU activation.

    When c_in = c_out = c, this is equivalent to hyp_leaky_relu(x, c, negative_slope).

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
    c_in : float
        Input curvature parameter (must be positive).
    c_out : float
        Output curvature parameter (must be positive).
    negative_slope : float, optional
        Negative slope coefficient for LeakyReLU (default: 0.01).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.
    """

    def f_r(z):
        return jax.nn.leaky_relu(z, negative_slope)

    return hrc(x, f_r, c_in, c_out, eps)

hyperbolix.nn_layers.hrc_tanh

hrc_tanh(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

HRC with tanh activation.

When c_in = c_out = c, this is equivalent to hyp_tanh(x, c).

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold with curvature c_in.
c_in : float
    Input curvature parameter (must be positive).
c_out : float
    Output curvature parameter (must be positive).
eps : float, optional
    Small value for numerical stability (default: 1e-7).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hrc_tanh(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """HRC with tanh activation.

    When c_in = c_out = c, this is equivalent to hyp_tanh(x, c).

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
    c_in : float
        Input curvature parameter (must be positive).
    c_out : float
        Output curvature parameter (must be positive).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.
    """
    return hrc(x, jnp.tanh, c_in, c_out, eps)

hyperbolix.nn_layers.hrc_swish

hrc_swish(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

HRC with Swish/SiLU activation.

When c_in = c_out = c, this is equivalent to hyp_swish(x, c).

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold with curvature c_in.
c_in : float
    Input curvature parameter (must be positive).
c_out : float
    Output curvature parameter (must be positive).
eps : float, optional
    Small value for numerical stability (default: 1e-7).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hrc_swish(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """HRC with Swish/SiLU activation.

    When c_in = c_out = c, this is equivalent to hyp_swish(x, c).

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
    c_in : float
        Input curvature parameter (must be positive).
    c_out : float
        Output curvature parameter (must be positive).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.
    """
    return hrc(x, jax.nn.swish, c_in, c_out, eps)

hyperbolix.nn_layers.hrc_gelu

hrc_gelu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-07,
) -> Float[Array, "... dim_plus_1"]

HRC with GELU activation.

Parameters:

x : Array of shape (..., dim+1)
    Input point(s) on the hyperboloid manifold with curvature c_in.
c_in : float
    Input curvature parameter (must be positive).
c_out : float
    Output curvature parameter (must be positive).
eps : float, optional
    Small value for numerical stability (default: 1e-7).

Returns:

y : Array of shape (..., dim+1)
    Output point(s) on the hyperboloid manifold with curvature c_out.

Source code in hyperbolix/nn_layers/hyperboloid_activations.py
def hrc_gelu(
    x: Float[Array, "... dim_plus_1"],
    c_in: float,
    c_out: float,
    eps: float = 1e-7,
) -> Float[Array, "... dim_plus_1"]:
    """HRC with GELU activation.

    Parameters
    ----------
    x : Array of shape (..., dim+1)
        Input point(s) on the hyperboloid manifold with curvature c_in.
    c_in : float
        Input curvature parameter (must be positive).
    c_out : float
        Output curvature parameter (must be positive).
    eps : float, optional
        Small value for numerical stability (default: 1e-7).

    Returns
    -------
    y : Array of shape (..., dim+1)
        Output point(s) on the hyperboloid manifold with curvature c_out.
    """
    return hrc(x, jax.nn.gelu, c_in, c_out, eps)

Activation Examples

Curvature-Preserving Activation:

import jax
import jax.numpy as jnp
from hyperbolix.nn_layers import hyp_relu, hyp_gelu
from hyperbolix.manifolds import Hyperboloid

hyperboloid = Hyperboloid()

# Points on hyperboloid (ambient coordinates)
x = jax.random.normal(jax.random.PRNGKey(0), (10, 5))
x_ambient = jnp.concatenate([
    jnp.sqrt(jnp.sum(x**2, axis=-1, keepdims=True) + 1.0),
    x
], axis=-1)

# Apply hyperbolic ReLU (curvature preserving)
output = hyp_relu(x_ambient, c=1.0)
print(output.shape)  # (10, 6) - same shape

# Verify manifold constraint
constraint = -output[:, 0]**2 + jnp.sum(output[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -1.0, atol=1e-5))  # True

# Use GELU instead
output_gelu = hyp_gelu(x_ambient, c=1.0)

Curvature-Changing Activation:

from hyperbolix.nn_layers import hrc_relu

# Transform from curvature 1.0 to curvature 2.0
output = hrc_relu(x_ambient, c_in=1.0, c_out=2.0)

# Verify new manifold constraint (c=2.0)
constraint = -output[:, 0]**2 + jnp.sum(output[:, 1:]**2, axis=-1)
print(jnp.allclose(constraint, -1.0/2.0, atol=1e-5))  # True

How Activations Work

Hyperbolic activations follow the HRC pattern:

  1. Extract space components x_s = x[..., 1:]
  2. Apply activation to space: y_s = activation(x_s)
  3. Scale for curvature change: y_s = sqrt(c_in/c_out) * y_s
  4. Reconstruct time: y_t = sqrt(||y_s||^2 + 1/c_out)

This avoids expensive exp/log maps while preserving geometry and enabling flexible curvature transformations.
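The four steps above can be sketched directly (a minimal NumPy illustration of the pattern, not the library's `hrc` implementation):

```python
import numpy as np

def hrc_sketch(x, f, c_in, c_out):
    # 1. Extract space components
    x_s = x[..., 1:]
    # 2. Apply the Euclidean activation to the space components
    y_s = f(x_s)
    # 3. Rescale for the curvature change
    y_s = np.sqrt(c_in / c_out) * y_s
    # 4. Reconstruct the time component from the Lorentz constraint
    y_t = np.sqrt(np.sum(y_s**2, axis=-1, keepdims=True) + 1.0 / c_out)
    return np.concatenate([y_t, y_s], axis=-1)

relu = lambda z: np.maximum(z, 0.0)
x = np.array([1.05, 0.1, -0.2, 0.15])  # point on the c = 1 hyperboloid
y = hrc_sketch(x, relu, c_in=1.0, c_out=2.0)
# Output satisfies the c_out constraint: -y_t^2 + ||y_s||^2 = -1/c_out
print(-y[0]**2 + np.sum(y[1:]**2))  # ≈ -0.5
```

Note that the output constraint holds by construction in step 4, regardless of what the activation did to the space components.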

Building Models

Example of a complete hyperbolic neural network:

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.nn_layers import HypLinearPoincare, poincare_relu
from hyperbolix.manifolds import Poincare

poincare = Poincare()

class HyperbolicNN(nnx.Module):
    def __init__(self, rngs):
        self.layer1 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=784,  # MNIST flattened
            out_dim=256,
            rngs=rngs
        )
        self.layer2 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=256,
            out_dim=128,
            rngs=rngs
        )
        self.layer3 = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=128,
            out_dim=10,
            rngs=rngs
        )

    def __call__(self, x, c=1.0):
        # x: (batch, 784) on Poincaré ball
        # Use the Poincaré activation (not hyp_relu, which expects
        # hyperboloid ambient coordinates); it handles batches directly.
        x = self.layer1(x, c)
        x = poincare_relu(x, c)

        x = self.layer2(x, c)
        x = poincare_relu(x, c)

        x = self.layer3(x, c)
        return x

# Create and use model
model = HyperbolicNN(rngs=nnx.Rngs(0))

# Input data (projected to Poincaré ball)
x = jax.random.normal(jax.random.PRNGKey(1), (32, 784)) * 0.1
x_proj = jax.vmap(poincare.proj, in_axes=(0, None))(x, 1.0)

output = model(x_proj, c=1.0)
print(output.shape)  # (32, 10)

References

The neural network layers implement methods from:

  • Ganea et al. (2018): "Hyperbolic Neural Networks" - Poincaré linear layers and activations
  • Shimizu et al. (2020): "Hyperbolic Neural Networks++" - Enhanced Poincaré and Hyperboloid operations (HypLinearPoincarePP, HypLinearHyperboloidPP, HypConv2DHyperboloidPP)
  • Bdeir et al. (2023): "Fully Hyperbolic Convolutional Neural Networks for Computer Vision" - HCat-based convolutions (HypConv2DHyperboloid)
  • Chen et al. (2022): "Fully Hyperbolic Neural Networks" - FHCNN linear layers
  • LResNet (2023): "Lorentzian ResNet" - HRC-based convolutions (LorentzConv2D)
  • Hypformer (Yang et al. 2025): "Hyperbolic Transformers" - HTC/HRC components with curvature-change support
  • Chen et al. (2024): "Hyperbolic Embeddings for Learning on Manifolds (HELM)" - HOPE positional encoding and Lorentzian residual connections
  • Klis et al. (2026): "Fast and Geometrically Grounded Lorentz Neural Networks" - FGGLinear, FGGConv2D, FGGLorentzMLR, FGGMeanOnlyBatchNorm; sinh/arcsinh cancellation for linear hyperbolic distance growth

Key Theoretical Connections

  • HL (Hyperbolic Layer) from LResNet ≡ HRC (Hyperbolic Regularization Component) from Hypformer
  • Both apply Euclidean operations to spatial components and reconstruct time using the Lorentz constraint
  • LorentzConv2D is a specific instance of hrc() where f_r is a 2D convolution

See also: