
Optimizers API

Riemannian optimization algorithms for training neural networks with hyperbolic parameters.

Overview

Hyperbolix provides two Riemannian optimizers that extend standard Euclidean optimizers to manifold-valued parameters:

  • Riemannian SGD (RSGD): Stochastic gradient descent with momentum
  • Riemannian Adam (RAdam): Adaptive learning rates with moment transport

Both optimizers:

  • Follow the standard Optax GradientTransformation interface
  • Automatically detect manifold parameters via metadata
  • Support mixed Euclidean/Riemannian parameter optimization
  • Are compatible with nnx.Optimizer wrapper
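
Because both optimizers return a plain optax.GradientTransformation, they compose with the usual Optax tooling. A minimal sketch (the model is assumed to be any NNX module containing ManifoldParam-tagged weights, such as the layers shown further down this page):

import optax
from flax import nnx
from hyperbolix.optim import riemannian_adam

# riemannian_adam returns a standard optax.GradientTransformation,
# so it composes with Optax combinators such as optax.chain.
tx = optax.chain(
    optax.clip_by_global_norm(1.0),   # clip raw Euclidean gradients first
    riemannian_adam(learning_rate=1e-3),
)
optimizer = nnx.Optimizer(model, tx, wrt=nnx.Param)  # model: any module with ManifoldParam weights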

Riemannian SGD

hyperbolix.optim.riemannian_sgd

Riemannian SGD optimizer for JAX/Flax NNX.

This module implements Riemannian Stochastic Gradient Descent (RSGD) as a standard Optax GradientTransformation. It automatically detects manifold parameters via metadata and applies appropriate Riemannian operations, while treating Euclidean parameters with standard SGD.

The optimizer supports:

  • Momentum with parallel transport
  • Both the exponential map (exact) and retraction (first-order approximation)
  • Mixed Euclidean/Riemannian parameter optimization

Algorithm (for manifold parameters):

  1. Convert the Euclidean gradient to a Riemannian gradient: grad = manifold.egrad2rgrad(grad, param, c)
  2. Update momentum: m = momentum * m + grad
  3. Move on the manifold: new_param = manifold.expmap(-lr * m, param, c)  # or retraction
  4. Transport momentum: m = manifold.ptransp(m, param, new_param, c)

For Euclidean parameters, standard SGD is applied.
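
As a reading aid, here is a per-parameter sketch of these four steps, written directly against the manifold methods named above. It is illustrative only; the library's actual update also handles Euclidean leaves, retraction, and schedules, as shown in the source further down.

def rsgd_step_manifold(param, grad, m, manifold, c, lr, momentum=0.9):
    """One RSGD step for a single manifold-valued parameter (illustrative sketch)."""
    rgrad = manifold.egrad2rgrad(grad, param, c)          # 1. Euclidean -> Riemannian gradient
    m = momentum * m + rgrad                              # 2. momentum update
    new_param = manifold.expmap(-lr * m, param, c)        # 3. move along the geodesic
    m = manifold.ptransp(m, param, new_param, c)          # 4. transport momentum to the new point
    return new_param, m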

Note: Tensor variables in update_single_leaf match the parameter shape (P), which varies per leaf. Scalars (lr, count) have no suffix.

References

Bécigneul, Gary, and Octavian-Eugen Ganea. "Riemannian adaptive optimization methods." arXiv preprint arXiv:1810.00760 (2018).

RSGDState

Bases: NamedTuple

State for Riemannian SGD optimizer.

Attributes:

  • momentum (Any): Pytree of momentum terms, same structure as parameters
  • count (Array): Step count for schedule handling

riemannian_sgd

riemannian_sgd(
    learning_rate: float | Schedule,
    momentum: float = 0.0,
    use_expmap: bool = True,
) -> optax.GradientTransformation

Create a Riemannian SGD optimizer as an Optax GradientTransformation.

This optimizer automatically detects manifold parameters via metadata and applies Riemannian operations (egrad2rgrad, expmap/retraction, parallel transport), while treating Euclidean parameters with standard SGD.

Parameters:

  • learning_rate (float or Schedule, required): Learning rate (static or scheduled)
  • momentum (float, default 0.0): Momentum coefficient (0 for no momentum, typically 0.9)
  • use_expmap (bool, default True): If True, use the exponential map (exact geodesic). If False, use retraction (first-order approximation, faster).

Returns:

  • optimizer (GradientTransformation): An Optax GradientTransformation that can be used with nnx.Optimizer

Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_sgd
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import poincare

# Create model with manifold parameters
layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))

# Create Riemannian optimizer
tx = riemannian_sgd(learning_rate=0.01, momentum=0.9, use_expmap=True)
optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)

# Training step
def loss_fn(model, x):
    y = model(x, c=1.0)
    return jnp.sum(y ** 2)

x = jax.random.normal(jax.random.key(1), (32, 10))
grads = nnx.grad(loss_fn)(layer, x)
optimizer.update(grads)  # Automatically handles manifold parameters

Notes
  • Compatible with Optax combinators (optax.chain, schedules, etc.)
  • Momentum is parallel transported for manifold parameters
  • Works seamlessly with nnx.Optimizer wrapper
  • Parameters stay on manifold after updates (via expmap/retraction + projection)
Source code in hyperbolix/optim/riemannian_sgd.py
def riemannian_sgd(
    learning_rate: float | optax.Schedule,
    momentum: float = 0.0,
    use_expmap: bool = True,
) -> optax.GradientTransformation:
    """Create a Riemannian SGD optimizer as an Optax GradientTransformation.

    This optimizer automatically detects manifold parameters via metadata and
    applies Riemannian operations (egrad2rgrad, expmap/retraction, parallel transport),
    while treating Euclidean parameters with standard SGD.

    Parameters
    ----------
    learning_rate : float or optax.Schedule
        Learning rate (static or scheduled)
    momentum : float, default=0.0
        Momentum coefficient (0 for no momentum, typically 0.9)
    use_expmap : bool, default=True
        If True, use exponential map (exact geodesic).
        If False, use retraction (first-order approximation, faster).

    Returns
    -------
    optimizer : optax.GradientTransformation
        An Optax GradientTransformation that can be used with nnx.Optimizer

    Example
    -------
    >>> import jax
    >>> from flax import nnx
    >>> from hyperbolix.optim import riemannian_sgd
    >>> from hyperbolix.nn_layers import HypLinearPoincare
    >>> from hyperbolix.manifolds import poincare
    >>>
    >>> # Create model with manifold parameters
    >>> layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))
    >>>
    >>> # Create Riemannian optimizer
    >>> tx = riemannian_sgd(learning_rate=0.01, momentum=0.9, use_expmap=True)
    >>> optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)
    >>>
    >>> # Training step
    >>> def loss_fn(model, x):
    ...     y = model(x, c=1.0)
    ...     return jnp.sum(y ** 2)
    >>>
    >>> x = jax.random.normal(jax.random.key(1), (32, 10))
    >>> grads = nnx.grad(loss_fn)(layer, x)
    >>> optimizer.update(grads)  # Automatically handles manifold parameters

    Notes
    -----
    - Compatible with Optax combinators (optax.chain, schedules, etc.)
    - Momentum is parallel transported for manifold parameters
    - Works seamlessly with nnx.Optimizer wrapper
    - Parameters stay on manifold after updates (via expmap/retraction + projection)
    """

    def manifold_leaf_fn(rgrad, moments, param_value, manifold_module, c, lr, count):
        (mom_value,) = moments
        new_mom = momentum * mom_value + rgrad
        lr_cast = lr.astype(new_mom.dtype)
        direction = -lr_cast * new_mom
        # Parallel transport momentum only if momentum > 0
        ptransp_indices = (0,) if momentum > 0.0 else ()
        return direction, (new_mom,), ptransp_indices

    def euclidean_leaf_fn(grad_value, moments, lr, count):
        (mom_value,) = moments
        new_mom = momentum * mom_value + grad_value
        lr_cast = lr.astype(new_mom.dtype)
        param_update = -lr_cast * new_mom
        return param_update, (new_mom,)

    return make_riemannian_optimizer(
        n_moments=1,
        state_cls=RSGDState,
        manifold_leaf_fn=manifold_leaf_fn,
        euclidean_leaf_fn=euclidean_leaf_fn,
        learning_rate=learning_rate,
        use_expmap=use_expmap,
    )

Example

import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_sgd
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare

poincare = Poincare()

# Create model with hyperbolic parameters
model = HypLinearPoincare(
    manifold_module=poincare,
    in_dim=32,
    out_dim=16,
    rngs=nnx.Rngs(0)
)

# Create Riemannian SGD optimizer
optimizer = nnx.Optimizer(
    model,
    riemannian_sgd(learning_rate=0.01, momentum=0.9),
    wrt=nnx.Param
)

# Training step
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x, c=1.0)
        return jnp.mean((pred - y) ** 2)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)

    return loss
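
Since learning_rate accepts an optax.Schedule as well as a float, a scheduled rate can be passed directly. A small sketch (the warmup/decay values are illustrative; model and nnx come from the example above):

import optax
from hyperbolix.optim import riemannian_sgd

# Linear warmup followed by cosine decay, evaluated per step by the optimizer.
schedule = optax.warmup_cosine_decay_schedule(
    init_value=0.0,
    peak_value=0.01,
    warmup_steps=500,
    decay_steps=10_000,
)
optimizer = nnx.Optimizer(
    model,
    riemannian_sgd(learning_rate=schedule, momentum=0.9),
    wrt=nnx.Param,
)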

Riemannian Adam

hyperbolix.optim.riemannian_adam

Riemannian Adam optimizer for JAX/Flax NNX.

This module implements Riemannian Adam as a standard Optax GradientTransformation. It automatically detects manifold parameters via metadata and applies appropriate Riemannian operations (adaptive learning rates on manifolds), while treating Euclidean parameters with standard Adam.

The optimizer supports:

  • Adaptive learning rates with first and second moment estimation
  • Parallel transport of first moments (second moments follow the PyTorch scalar update)
  • Both the exponential map (exact) and retraction (first-order approximation)
  • Mixed Euclidean/Riemannian parameter optimization

Algorithm (for manifold parameters):

  1. Convert the Euclidean gradient to a Riemannian gradient: grad = manifold.egrad2rgrad(grad, param, c)
  2. Update the first moment: m1 = beta1 * m1 + (1 - beta1) * grad
  3. Update the second moment: m2 = beta2 * m2 + (1 - beta2) * ⟨grad, grad⟩_param  # tangent inner product at param
  4. Bias correction: m1_hat = m1 / (1 - beta1^t), m2_hat = m2 / (1 - beta2^t)
  5. Compute the direction: direction = m1_hat / (sqrt(m2_hat) + eps)
  6. Move on the manifold: new_param = manifold.expmap(-lr * direction, param, c)  # or retraction
  7. Transport the first moment: m1 = manifold.ptransp(m1, param, new_param, c)  # m2 is accumulated via the tangent inner product and is not transported

For Euclidean parameters, standard Adam is applied.
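
A matching per-parameter sketch of steps 1-7, using the same manifold methods plus manifold.tangent_inner for the second moment (the call used in the source below). Again, this is illustrative rather than the library's actual update path:

import jax.numpy as jnp

def radam_step_manifold(param, grad, m1, m2, manifold, c, lr, t,
                        beta1=0.9, beta2=0.999, eps=1e-8):
    """One Riemannian Adam step for a single manifold parameter (illustrative sketch)."""
    rgrad = manifold.egrad2rgrad(grad, param, c)                  # 1. Riemannian gradient
    m1 = beta1 * m1 + (1 - beta1) * rgrad                         # 2. first moment
    g_sq = manifold.tangent_inner(rgrad, rgrad, param, c)         # <grad, grad>_param
    m2 = beta2 * m2 + (1 - beta2) * g_sq                          # 3. second moment (scalar)
    m1_hat = m1 / (1 - beta1**t)                                  # 4. bias correction
    m2_hat = m2 / (1 - beta2**t)
    direction = m1_hat / (jnp.sqrt(m2_hat) + eps)                 # 5. adaptive direction
    new_param = manifold.expmap(-lr * direction, param, c)        # 6. move on the manifold
    m1 = manifold.ptransp(m1, param, new_param, c)                # 7. transport first moment only
    return new_param, m1, m2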

Note: Tensor variables in the update loop match the parameter shape (P), which varies per leaf. Scalars (lr, count, bias_correction) have no suffix.

References

Bécigneul, Gary, and Octavian-Eugen Ganea. "Riemannian adaptive optimization methods." arXiv preprint arXiv:1810.00760 (2018).

Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).

RAdamState

Bases: NamedTuple

State for Riemannian Adam optimizer.

Attributes:

  • m1 (Any): Pytree of first moment estimates (exponential moving average of gradients)
  • m2 (Any): Pytree of second moment estimates (exponential moving average of squared gradients)
  • count (Array): Step count for bias correction

riemannian_adam

riemannian_adam(
    learning_rate: float | Schedule,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-08,
    use_expmap: bool = True,
) -> optax.GradientTransformation

Create a Riemannian Adam optimizer as an Optax GradientTransformation.

This optimizer automatically detects manifold parameters via metadata and applies Riemannian operations with adaptive learning rates, while treating Euclidean parameters with standard Adam.

Parameters:

  • learning_rate (float or Schedule, required): Learning rate (static or scheduled)
  • beta1 (float, default 0.9): Exponential decay rate for first moment estimates
  • beta2 (float, default 0.999): Exponential decay rate for second moment estimates
  • eps (float, default 1e-8): Small constant for numerical stability
  • use_expmap (bool, default True): If True, use the exponential map (exact geodesic). If False, use retraction (first-order approximation, faster).

Returns:

  • optimizer (GradientTransformation): An Optax GradientTransformation that can be used with nnx.Optimizer

Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_adam
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import poincare

# Create model with manifold parameters
layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))

# Create Riemannian Adam optimizer
tx = riemannian_adam(learning_rate=0.001, use_expmap=True)
optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)

# Training step
def loss_fn(model, x):
    y = model(x, c=1.0)
    return jnp.sum(y ** 2)

x = jax.random.normal(jax.random.key(1), (32, 10))
grads = nnx.grad(loss_fn)(layer, x)
optimizer.update(grads)  # Automatically handles manifold parameters

Notes
  • Compatible with Optax combinators (optax.chain, schedules, etc.)
  • First moments are parallel transported for manifold parameters (matching PyTorch behaviour)
  • Second moments follow Geoopt/PyTorch: accumulated as tangent inner products without transport
  • Works seamlessly with nnx.Optimizer wrapper
  • Parameters stay on manifold after updates (via expmap/retraction + projection)
Source code in hyperbolix/optim/riemannian_adam.py
def riemannian_adam(
    learning_rate: float | optax.Schedule,
    beta1: float = 0.9,
    beta2: float = 0.999,
    eps: float = 1e-8,
    use_expmap: bool = True,
) -> optax.GradientTransformation:
    """Create a Riemannian Adam optimizer as an Optax GradientTransformation.

    This optimizer automatically detects manifold parameters via metadata and
    applies Riemannian operations with adaptive learning rates, while treating
    Euclidean parameters with standard Adam.

    Parameters
    ----------
    learning_rate : float or optax.Schedule
        Learning rate (static or scheduled)
    beta1 : float, default=0.9
        Exponential decay rate for first moment estimates
    beta2 : float, default=0.999
        Exponential decay rate for second moment estimates
    eps : float, default=1e-8
        Small constant for numerical stability
    use_expmap : bool, default=True
        If True, use exponential map (exact geodesic).
        If False, use retraction (first-order approximation, faster).

    Returns
    -------
    optimizer : optax.GradientTransformation
        An Optax GradientTransformation that can be used with nnx.Optimizer

    Example
    -------
    >>> import jax
    >>> from flax import nnx
    >>> from hyperbolix.optim import riemannian_adam
    >>> from hyperbolix.nn_layers import HypLinearPoincare
    >>> from hyperbolix.manifolds import poincare
    >>>
    >>> # Create model with manifold parameters
    >>> layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))
    >>>
    >>> # Create Riemannian Adam optimizer
    >>> tx = riemannian_adam(learning_rate=0.001, use_expmap=True)
    >>> optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)
    >>>
    >>> # Training step
    >>> def loss_fn(model, x):
    ...     y = model(x, c=1.0)
    ...     return jnp.sum(y ** 2)
    >>>
    >>> x = jax.random.normal(jax.random.key(1), (32, 10))
    >>> grads = nnx.grad(loss_fn)(layer, x)
    >>> optimizer.update(grads)  # Automatically handles manifold parameters

    Notes
    -----
    - Compatible with Optax combinators (optax.chain, schedules, etc.)
    - First moments are parallel transported for manifold parameters (matching PyTorch behaviour)
    - Second moments follow Geoopt/PyTorch: accumulated as tangent inner products without transport
    - Works seamlessly with nnx.Optimizer wrapper
    - Parameters stay on manifold after updates (via expmap/retraction + projection)
    """

    def manifold_leaf_fn(rgrad, moments, param_value, manifold_module, c, lr, count):
        m1_value, m2_value = moments

        # First moment: EMA of Riemannian gradients
        new_m1 = beta1 * m1_value + (1 - beta1) * rgrad

        # Second moment: EMA of tangent inner product (scalar broadcast to param shape)
        rgrad_sq = manifold_module.tangent_inner(rgrad, rgrad, param_value, c)
        rgrad_sq = jnp.asarray(rgrad_sq, dtype=rgrad.dtype)
        rgrad_sq = jnp.broadcast_to(rgrad_sq, m2_value.shape)
        new_m2 = beta2 * m2_value + (1 - beta2) * rgrad_sq

        # Bias correction
        bias_correction1 = 1 - beta1**count
        bias_correction2 = 1 - beta2**count
        m1_hat = new_m1 / bias_correction1
        m2_hat = new_m2 / bias_correction2

        # Direction (expmap/retraction applied by base)
        lr_cast = lr.astype(m1_hat.dtype)
        direction = -lr_cast * m1_hat / (jnp.sqrt(m2_hat) + eps)

        # Parallel transport m1 only (index 0); m2 stays (no transport)
        return direction, (new_m1, new_m2), (0,)

    def euclidean_leaf_fn(grad_value, moments, lr, count):
        m1_value, m2_value = moments

        new_m1 = beta1 * m1_value + (1 - beta1) * grad_value
        new_m2 = beta2 * m2_value + (1 - beta2) * (grad_value**2)

        bias_correction1 = 1 - beta1**count
        bias_correction2 = 1 - beta2**count
        m1_hat = new_m1 / bias_correction1
        m2_hat = new_m2 / bias_correction2

        lr_cast = lr.astype(m1_hat.dtype)
        param_update = -lr_cast * m1_hat / (jnp.sqrt(m2_hat) + eps)

        return param_update, (new_m1, new_m2)

    return make_riemannian_optimizer(
        n_moments=2,
        state_cls=RAdamState,
        manifold_leaf_fn=manifold_leaf_fn,
        euclidean_leaf_fn=euclidean_leaf_fn,
        learning_rate=learning_rate,
        use_expmap=use_expmap,
    )

Example

from hyperbolix.optim import riemannian_adam

# Create Riemannian Adam optimizer
optimizer = nnx.Optimizer(
    model,
    riemannian_adam(
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8
    ),
    wrt=nnx.Param
)

# Use in training loop (same as RSGD example, call optimizer.update(model, grads))

Manifold Metadata System

Hyperbolix uses ManifoldParam, an nnx.Param subclass, to tag parameters that live on a Riemannian manifold. The optimizers detect these via isinstance checks.

How It Works

import jax
import jax.numpy as jnp
from flax import nnx

from hyperbolix.optim import ManifoldParam

# Layer tags hyperbolic parameters using ManifoldParam
class HypLinearPoincare(nnx.Module):
    def __init__(self, manifold_module, in_dim, out_dim, *, rngs):
        self.manifold = manifold_module

        # Kernel is Euclidean (plain nnx.Param)
        self.kernel = nnx.Param(
            jax.random.normal(rngs.params(), (out_dim, in_dim)) * 0.01
        )

        # Bias lives on Poincaré ball (ManifoldParam)
        self.bias = ManifoldParam(
            jnp.zeros(out_dim),
            manifold=manifold_module,
            curvature=1.0,
        )

The optimizer automatically:

  1. Detects ManifoldParam instances
  2. Applies Riemannian gradient updates (expmap/retraction)
  3. Performs parallel transport for momentum/adaptive moments
  4. Falls back to Euclidean updates for plain nnx.Param parameters

API Reference

hyperbolix.optim.manifold_metadata

Manifold metadata utilities for Riemannian optimization.

This module provides ManifoldParam, an nnx.Param subclass that marks parameters as living on a Riemannian manifold. Riemannian optimizers detect these parameters via isinstance(param, ManifoldParam) and apply the appropriate exponential-map / retraction updates automatically.

Design rationale:

  • ManifoldParam is a thin nnx.Param subclass — all NNX machinery (state extraction, serialization, JIT) works unchanged
  • Manifold and curvature are stored as standard Variable metadata kwargs, accessible via attribute access (param.manifold, param.curvature)
  • The manifold instance carries dtype control via _cast(), so Riemannian optimizer operations automatically respect the layer's precision setting
  • Supports both static and callable curvature parameters

Example:

import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds.poincare import Poincare
from hyperbolix.optim import ManifoldParam, get_manifold_info

# Create a parameter on the Poincaré manifold
manifold = Poincare(dtype=jnp.float64)
bias = ManifoldParam(
    jnp.zeros((10,)),
    manifold=manifold,
    curvature=1.0,
)

# In optimizer: extract manifold info
manifold_info = get_manifold_info(bias)
if manifold_info is not None:
    manifold_instance, c = manifold_info
    # Apply Riemannian operations via public methods...

ManifoldParam
ManifoldParam(
    value: Any,
    *,
    manifold: Manifold,
    curvature: float | Callable[[], Any],
    **metadata: Any,
)

Bases: Param

nnx.Param subclass for parameters living on a Riemannian manifold.

Stores the manifold instance and curvature as Variable metadata kwargs, providing type-safe detection via isinstance(param, ManifoldParam).

Parameters:

  • value (array-like, required): The parameter value (a JAX array).
  • manifold (Manifold, required): A manifold class instance (e.g., Poincare(dtype=jnp.float64)).
  • curvature (float or callable, required): Either a static curvature value or a callable that returns the current curvature. Use a callable (e.g., lambda: self.c[...]) for learnable curvature.
Example

import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds.poincare import Poincare
from hyperbolix.optim import ManifoldParam

manifold = Poincare(dtype=jnp.float64)

# Static curvature
bias = ManifoldParam(
    jax.random.normal(jax.random.key(0), (10,)) * 0.01,
    manifold=manifold,
    curvature=1.0,
)

# Learnable curvature (callable)
class MyLayer(nnx.Module):
    def __init__(self, rngs):
        self.c = nnx.Param(jnp.array(1.0))
        self.bias = ManifoldParam(
            jax.random.normal(rngs.params(), (10,)) * 0.01,
            manifold=manifold,
            curvature=lambda: self.c.value,
        )

Source code in hyperbolix/optim/manifold_metadata.py
def __init__(
    self,
    value: Any,
    *,
    manifold: Manifold,
    curvature: float | Callable[[], Any],
    **metadata: Any,
) -> None:
    super().__init__(value, manifold=manifold, curvature=curvature, **metadata)
mark_manifold_param
mark_manifold_param(
    param: Param,
    manifold: Manifold,
    curvature: float | Callable[[], Any],
) -> ManifoldParam

Create a ManifoldParam from an existing nnx.Param.

This is a convenience wrapper around ManifoldParam. Prefer using ManifoldParam directly in new code.

Parameters:

  • param (Param, required): The parameter whose value will be copied into a new ManifoldParam.
  • manifold (Manifold, required): A manifold class instance.
  • curvature (float or callable, required): Static curvature value or callable returning current curvature.

Returns:

  • ManifoldParam: A new ManifoldParam wrapping the same array value.
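
A small usage sketch, assuming mark_manifold_param is re-exported from hyperbolix.optim alongside ManifoldParam (otherwise import it from hyperbolix.optim.manifold_metadata):

import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.optim import mark_manifold_param

# Promote an existing Euclidean nnx.Param to a manifold-tagged parameter.
plain_bias = nnx.Param(jnp.zeros((10,)))
bias = mark_manifold_param(plain_bias, manifold=Poincare(), curvature=1.0)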

Source code in hyperbolix/optim/manifold_metadata.py
def mark_manifold_param(
    param: nnx.Param,
    manifold: Manifold,
    curvature: float | Callable[[], Any],
) -> ManifoldParam:
    """Create a ``ManifoldParam`` from an existing ``nnx.Param``.

    This is a convenience wrapper around ``ManifoldParam``.  Prefer using
    ``ManifoldParam`` directly in new code.

    Parameters
    ----------
    param : nnx.Param
        The parameter whose value will be copied into a new ``ManifoldParam``.
    manifold : Manifold
        A manifold class instance.
    curvature : float or callable
        Static curvature value or callable returning current curvature.

    Returns
    -------
    ManifoldParam
        A new ``ManifoldParam`` wrapping the same array value.
    """
    return ManifoldParam(param[...], manifold=manifold, curvature=curvature)
get_manifold_info
get_manifold_info(
    param: Variable,
) -> tuple[Manifold, Any] | None

Extract manifold information from a parameter.

Parameters:

  • param (Variable, required): The parameter to extract manifold info from.

Returns:

  • manifold_info (tuple of (Manifold, curvature) or None): If the parameter is a ManifoldParam, a tuple of the manifold class instance and the current curvature value (evaluated if callable). Otherwise None.
Example

manifold_info = get_manifold_info(param)
if manifold_info is not None:
    manifold, c = manifold_info
    rgrad = manifold.egrad2rgrad(grad, param[...], c)

Source code in hyperbolix/optim/manifold_metadata.py
def get_manifold_info(param: nnx.Variable) -> tuple[Manifold, Any] | None:
    """Extract manifold information from a parameter.

    Parameters
    ----------
    param : nnx.Variable
        The parameter to extract manifold info from.

    Returns
    -------
    manifold_info : tuple of (Manifold, curvature) or None
        If the parameter is a ``ManifoldParam``:
            - manifold: The manifold class instance
            - curvature: The current curvature value (evaluated if callable)
        Otherwise ``None``.

    Example
    -------
    >>> manifold_info = get_manifold_info(param)
    >>> if manifold_info is not None:
    ...     manifold, c = manifold_info
    ...     rgrad = manifold.egrad2rgrad(grad, param[...], c)
    """
    if not isinstance(param, ManifoldParam):
        return None

    manifold = param.manifold
    curvature_value = param.curvature
    if callable(curvature_value):
        curvature_value = curvature_value()

    return (manifold, curvature_value)
has_manifold_params
has_manifold_params(params_pytree: Any) -> bool

Check if a parameter pytree contains any manifold parameters.

Parameters:

  • params_pytree (Any, required): A pytree of parameters (typically from nnx.state(model, nnx.Param)).

Returns:

  • has_manifold (bool): True if any parameter in the pytree is a ManifoldParam.

Example

import jax
from flax import nnx

model = MyHyperbolicModel(rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)
if has_manifold_params(params):
    print("Model contains manifold parameters")

Source code in hyperbolix/optim/manifold_metadata.py
def has_manifold_params(params_pytree: Any) -> bool:
    """Check if a parameter pytree contains any manifold parameters.

    Parameters
    ----------
    params_pytree : Any
        A pytree of parameters (typically from ``nnx.state(model, nnx.Param)``).

    Returns
    -------
    has_manifold : bool
        True if any parameter in the pytree is a ``ManifoldParam``.

    Example
    -------
    >>> import jax
    >>> from flax import nnx
    >>>
    >>> model = MyHyperbolicModel(rngs=nnx.Rngs(0))
    >>> params = nnx.state(model, nnx.Param)
    >>> if has_manifold_params(params):
    ...     print("Model contains manifold parameters")
    """
    from jax import tree_util

    leaves = tree_util.tree_leaves(params_pytree)
    return any(isinstance(leaf, ManifoldParam) for leaf in leaves)

Expmap vs Retraction

Both optimizers support two update modes:

  • Exponential map (default): expmap(x, -lr * grad)
      • Exact geodesic following
      • Numerically stable for large steps
      • Slightly slower
  • Retraction: proj(x - lr * grad)
      • First-order approximation
      • Faster computation
      • Can be less stable for large learning rates

Choosing Update Mode

# Use exponential map (default, recommended)
opt = riemannian_adam(learning_rate=0.001)

# Use retraction for faster, first-order updates
opt = riemannian_adam(learning_rate=0.001, use_expmap=False)

In practice, exponential maps provide better stability and convergence, especially for hyperbolic neural networks.

Mixed Optimization

The optimizers seamlessly handle models with both Euclidean and hyperbolic parameters:

from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.optim import riemannian_adam

poincare = Poincare()

class MixedModel(nnx.Module):
    def __init__(self, rngs):
        # Euclidean linear layer
        self.fc1 = nnx.Linear(32, 64, rngs=rngs)

        # Hyperbolic layer (bias has manifold metadata)
        self.hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=64,
            out_dim=16,
            rngs=rngs
        )

        # Another Euclidean layer
        self.fc2 = nnx.Linear(16, 10, rngs=rngs)

# Optimizer handles all parameter types automatically
model = MixedModel(rngs=nnx.Rngs(0))
optimizer = nnx.Optimizer(model, riemannian_adam(learning_rate=0.001), wrt=nnx.Param)

The optimizer will:

  • Apply standard Adam updates to fc1 and fc2 parameters
  • Apply Riemannian Adam updates to hyp.bias (tagged with metadata)
  • Apply Euclidean Adam updates to hyp.kernel (no metadata)
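
To see which leaves will receive Riemannian updates, the metadata helpers from the previous section can be applied to the model's parameter state. A minimal sketch, continuing from the MixedModel example above:

import jax
from flax import nnx
from hyperbolix.optim import ManifoldParam, has_manifold_params

params = nnx.state(model, nnx.Param)
print(has_manifold_params(params))  # True: hyp.bias is a ManifoldParam

# Count manifold vs. Euclidean leaves (mirrors the optimizer's isinstance check)
leaves = jax.tree_util.tree_leaves(params)
n_manifold = sum(isinstance(p, ManifoldParam) for p in leaves)
print(f"{n_manifold} manifold parameter(s), {len(leaves) - n_manifold} Euclidean")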

Performance Considerations

JIT Compilation

Both optimizers are JIT-compatible. For best performance:

@nnx.jit
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        return compute_loss(model, x, y)

    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss

Curvature as Static Argument

If curvature c is constant during training, pass it as a static argument to enable better JIT optimization:

@jax.jit
def forward(model, x, c):
    return model(x, c=c)  # c passed as a regular argument is traced

# Better: mark c as a static argument so JIT specializes on its value
from functools import partial

@partial(jax.jit, static_argnums=(2,))
def forward(model, x, c):
    return model(x, c=c)

output = forward(model, x, 1.0)  # c=1.0 is static

References

The Riemannian optimizers are based on:

  • Bécigneul, G., & Ganea, O. (2019). "Riemannian Adaptive Optimization Methods." ICLR 2019.
  • Bonnabel, S. (2013). "Stochastic gradient descent on Riemannian manifolds." IEEE TAC.

See the User Guide for detailed explanations and best practices.