Optimizers API¶
Riemannian optimization algorithms for training neural networks with hyperbolic parameters.
Overview¶
Hyperbolix provides two Riemannian optimizers that extend standard Euclidean optimizers to manifold-valued parameters:
- Riemannian SGD (RSGD): Stochastic gradient descent with momentum
- Riemannian Adam (RAdam): Adaptive learning rates with moment transport
Both optimizers:

- Follow the standard Optax `GradientTransformation` interface
- Automatically detect manifold parameters via metadata
- Support mixed Euclidean/Riemannian parameter optimization
- Are compatible with the `nnx.Optimizer` wrapper
Riemannian SGD¶
hyperbolix.optim.riemannian_sgd ¶
Riemannian SGD optimizer for JAX/Flax NNX.
This module implements Riemannian Stochastic Gradient Descent (RSGD) as a standard Optax GradientTransformation. It automatically detects manifold parameters via metadata and applies appropriate Riemannian operations, while treating Euclidean parameters with standard SGD.
The optimizer supports:

- Momentum with parallel transport
- Both exponential map (exact) and retraction (first-order approximation)
- Mixed Euclidean/Riemannian parameter optimization
Algorithm (for manifold parameters):

1. Convert the Euclidean gradient to a Riemannian gradient: `grad = manifold.egrad2rgrad(grad, param, c)`
2. Update momentum: `m = momentum * m + grad`
3. Move on the manifold: `new_param = manifold.expmap(-lr * m, param, c)` (or retraction)
4. Transport momentum: `m = manifold.ptransp(m, param, new_param, c)`
For Euclidean parameters, standard SGD is applied.
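The manifold branch of this update can be sketched numerically. Below is a minimal, momentum-free sketch of one retraction-based RSGD step on the Poincaré ball in plain Python (no hyperbolix dependency); the function name and the projection epsilon are illustrative, not the library's actual implementation.

```python
import math

def rsgd_step_retraction(x, egrad, lr, c=1.0):
    """One momentum-free RSGD step on the Poincare ball (retraction variant).

    Illustrative sketch only: egrad2rgrad rescales by the inverse squared
    conformal factor, then we take a Euclidean step and project back into
    the open ball of radius 1/sqrt(c).
    """
    norm_sq = sum(xi * xi for xi in x)
    # Riemannian gradient: egrad / lambda_x^2, with lambda_x = 2 / (1 - c * |x|^2)
    scale = ((1.0 - c * norm_sq) / 2.0) ** 2
    rgrad = [scale * g for g in egrad]
    # Retraction: Euclidean step followed by projection onto the ball
    y = [xi - lr * gi for xi, gi in zip(x, rgrad)]
    max_norm = (1.0 - 1e-5) / math.sqrt(c)  # illustrative boundary epsilon
    y_norm = math.sqrt(sum(yi * yi for yi in y))
    if y_norm > max_norm:
        y = [yi * max_norm / y_norm for yi in y]
    return y

x_new = rsgd_step_retraction([0.1, 0.2], [1.0, -0.5], lr=0.1)
```

The expmap variant replaces the Euclidean step with an exact geodesic move, but the gradient rescaling and final on-manifold guarantee are the same.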
Note: Tensor variables in update_single_leaf match the parameter shape (P), which varies per leaf. Scalars (lr, count) have no suffix.
References
Bécigneul, Gary, and Octavian-Eugen Ganea. "Riemannian adaptive optimization methods." arXiv preprint arXiv:1810.00760 (2018).
RSGDState ¶
Bases: NamedTuple
State for Riemannian SGD optimizer.
Attributes:
| Name | Type | Description |
|---|---|---|
| `momentum` | `Any` | Pytree of momentum terms, same structure as parameters |
| `count` | `Array` | Step count for schedule handling |
riemannian_sgd ¶
riemannian_sgd(
learning_rate: float | Schedule,
momentum: float = 0.0,
use_expmap: bool = True,
) -> optax.GradientTransformation
Create a Riemannian SGD optimizer as an Optax GradientTransformation.
This optimizer automatically detects manifold parameters via metadata and applies Riemannian operations (egrad2rgrad, expmap/retraction, parallel transport), while treating Euclidean parameters with standard SGD.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `learning_rate` | `float` or `Schedule` | Learning rate (static or scheduled) | required |
| `momentum` | `float` | Momentum coefficient (0 for no momentum, typically 0.9) | `0.0` |
| `use_expmap` | `bool` | If True, use exponential map (exact geodesic). If False, use retraction (first-order approximation, faster). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `optimizer` | `GradientTransformation` | An Optax `GradientTransformation` that can be used with `nnx.Optimizer` |
Example
```python
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_sgd
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import poincare

# Create model with manifold parameters
layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))

# Create Riemannian optimizer
tx = riemannian_sgd(learning_rate=0.01, momentum=0.9, use_expmap=True)
optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)

# Training step
def loss_fn(model, x):
    y = model(x, c=1.0)
    return jnp.sum(y ** 2)

x = jax.random.normal(jax.random.key(1), (32, 10))
grads = nnx.grad(loss_fn)(layer, x)
optimizer.update(grads)  # Automatically handles manifold parameters
```
Notes
- Compatible with Optax combinators (optax.chain, schedules, etc.)
- Momentum is parallel transported for manifold parameters
- Works seamlessly with nnx.Optimizer wrapper
- Parameters stay on manifold after updates (via expmap/retraction + projection)
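The chaining compatibility noted above comes from the `GradientTransformation` interface itself, which is just an `(init, update)` pair of pure functions. The following is a simplified stand-in (not Optax itself; real Optax `update` also takes a `params` argument) showing how two transforms compose:

```python
from typing import Any, Callable, NamedTuple

class GradientTransformation(NamedTuple):
    # Stand-in for the Optax interface: a pair of pure functions
    init: Callable[[Any], Any]
    update: Callable[[Any, Any], tuple]

def scale(factor):
    """Multiply every gradient by a constant factor (like optax.scale)."""
    return GradientTransformation(
        init=lambda params: (),
        update=lambda grads, state: ([factor * g for g in grads], state),
    )

def clip(max_abs):
    """Elementwise gradient clipping."""
    return GradientTransformation(
        init=lambda params: (),
        update=lambda grads, state: (
            [max(-max_abs, min(max_abs, g)) for g in grads], state),
    )

def chain(*transforms):
    """Compose transforms left to right, threading per-transform states."""
    def init(params):
        return tuple(t.init(params) for t in transforms)
    def update(grads, states):
        new_states = []
        for t, s in zip(transforms, states):
            grads, s = t.update(grads, s)
            new_states.append(s)
        return grads, tuple(new_states)
    return GradientTransformation(init, update)

tx = chain(clip(1.0), scale(-0.01))  # clip, then scale by -lr
state = tx.init(None)
updates, state = tx.update([5.0, -0.3], state)
```

Because `riemannian_sgd` returns the same kind of `(init, update)` pair, it slots into `optax.chain` alongside standard transforms.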
Source code in hyperbolix/optim/riemannian_sgd.py
Example¶
```python
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_sgd
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import Poincare

poincare = Poincare()

# Create model with hyperbolic parameters
model = HypLinearPoincare(
    manifold_module=poincare,
    in_dim=32,
    out_dim=16,
    rngs=nnx.Rngs(0)
)

# Create Riemannian SGD optimizer
optimizer = nnx.Optimizer(
    model,
    riemannian_sgd(learning_rate=0.01, momentum=0.9),
    wrt=nnx.Param
)

# Training step
def train_step(model, optimizer, x, y):
    def loss_fn(model):
        pred = model(x, c=1.0)
        return jnp.mean((pred - y) ** 2)
    loss, grads = nnx.value_and_grad(loss_fn)(model)
    optimizer.update(model, grads)
    return loss
```
Riemannian Adam¶
hyperbolix.optim.riemannian_adam ¶
Riemannian Adam optimizer for JAX/Flax NNX.
This module implements Riemannian Adam as a standard Optax GradientTransformation. It automatically detects manifold parameters via metadata and applies appropriate Riemannian operations (adaptive learning rates on manifolds), while treating Euclidean parameters with standard Adam.
The optimizer supports:

- Adaptive learning rates with first and second moment estimation
- Parallel transport of first moments (second moments follow the PyTorch scalar update)
- Both exponential map (exact) and retraction (first-order approximation)
- Mixed Euclidean/Riemannian parameter optimization
Algorithm (for manifold parameters):
1. Convert the Euclidean gradient to a Riemannian gradient: `grad = manifold.egrad2rgrad(grad, param, c)`
2. Update first moment: `m1 = beta1 * m1 + (1 - beta1) * grad`
3. Update second moment: `m2 = beta2 * m2 + (1 - beta2) * inner(grad, grad)` (tangent inner product, kept as a scalar accumulator)
4. Bias-correct the moments and move on the manifold: `new_param = manifold.expmap(-lr * m1_hat / (sqrt(m2_hat) + eps), param, c)` (or retraction)
5. Transport the first moment: `m1 = manifold.ptransp(m1, param, new_param, c)`

For Euclidean parameters, standard Adam is applied.
Note: Tensor variables in the update loop match the parameter shape (P), which varies per leaf. Scalars (lr, count, bias_correction) have no suffix.
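The moment and bias-correction bookkeeping can be illustrated with plain arithmetic. A sketch for a single leaf, simplified to a scalar gradient with no manifold operations (on a manifold leaf, `g` would be the Riemannian gradient and `m2` would accumulate the tangent inner product):

```python
import math

def adam_scalar_step(g, m1, m2, t, lr=0.001, beta1=0.9, beta2=0.999, eps=1e-8):
    """One Adam moment update with bias correction for a scalar gradient.

    Sketch only: the scalar g * g stands in for the tangent inner
    product inner(g, g) used for manifold parameters.
    """
    m1 = beta1 * m1 + (1 - beta1) * g
    m2 = beta2 * m2 + (1 - beta2) * g * g
    m1_hat = m1 / (1 - beta1 ** t)  # bias correction for step t >= 1
    m2_hat = m2 / (1 - beta2 ** t)
    step = -lr * m1_hat / (math.sqrt(m2_hat) + eps)
    return step, m1, m2

step, m1, m2 = adam_scalar_step(g=0.5, m1=0.0, m2=0.0, t=1)
```

At `t = 1` the bias correction exactly cancels the `(1 - beta)` factors, so the very first step has magnitude close to `lr` regardless of the gradient scale.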
References
Bécigneul, Gary, and Octavian-Eugen Ganea. "Riemannian adaptive optimization methods." arXiv preprint arXiv:1810.00760 (2018).

Kingma, Diederik P., and Jimmy Ba. "Adam: A method for stochastic optimization." arXiv preprint arXiv:1412.6980 (2014).
RAdamState ¶
Bases: NamedTuple
State for Riemannian Adam optimizer.
Attributes:
| Name | Type | Description |
|---|---|---|
| `m1` | `Any` | Pytree of first moment estimates (exponential moving average of gradients) |
| `m2` | `Any` | Pytree of second moment estimates (exponential moving average of squared gradients) |
| `count` | `Array` | Step count for bias correction |
riemannian_adam ¶
riemannian_adam(
learning_rate: float | Schedule,
beta1: float = 0.9,
beta2: float = 0.999,
eps: float = 1e-08,
use_expmap: bool = True,
) -> optax.GradientTransformation
Create a Riemannian Adam optimizer as an Optax GradientTransformation.
This optimizer automatically detects manifold parameters via metadata and applies Riemannian operations with adaptive learning rates, while treating Euclidean parameters with standard Adam.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `learning_rate` | `float` or `Schedule` | Learning rate (static or scheduled) | required |
| `beta1` | `float` | Exponential decay rate for first moment estimates | `0.9` |
| `beta2` | `float` | Exponential decay rate for second moment estimates | `0.999` |
| `eps` | `float` | Small constant for numerical stability | `1e-8` |
| `use_expmap` | `bool` | If True, use exponential map (exact geodesic). If False, use retraction (first-order approximation, faster). | `True` |
Returns:
| Name | Type | Description |
|---|---|---|
| `optimizer` | `GradientTransformation` | An Optax `GradientTransformation` that can be used with `nnx.Optimizer` |
Example
```python
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import riemannian_adam
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.manifolds import poincare

# Create model with manifold parameters
layer = HypLinearPoincare(poincare, 10, 5, rngs=nnx.Rngs(0))

# Create Riemannian Adam optimizer
tx = riemannian_adam(learning_rate=0.001, use_expmap=True)
optimizer = nnx.Optimizer(layer, tx, wrt=nnx.Param)

# Training step
def loss_fn(model, x):
    y = model(x, c=1.0)
    return jnp.sum(y ** 2)

x = jax.random.normal(jax.random.key(1), (32, 10))
grads = nnx.grad(loss_fn)(layer, x)
optimizer.update(grads)  # Automatically handles manifold parameters
```
Notes
- Compatible with Optax combinators (optax.chain, schedules, etc.)
- First moments are parallel transported for manifold parameters (matching PyTorch behaviour)
- Second moments follow Geoopt/PyTorch: accumulated as tangent inner products without transport
- Works seamlessly with nnx.Optimizer wrapper
- Parameters stay on manifold after updates (via expmap/retraction + projection)
Source code in hyperbolix/optim/riemannian_adam.py
Example¶
```python
from flax import nnx
from hyperbolix.optim import riemannian_adam

# Create Riemannian Adam optimizer (model as in the RSGD example above)
optimizer = nnx.Optimizer(
    model,
    riemannian_adam(
        learning_rate=0.001,
        beta1=0.9,
        beta2=0.999,
        eps=1e-8
    ),
    wrt=nnx.Param
)

# Use in training loop (same as the RSGD example: call optimizer.update(model, grads))
```
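Since `learning_rate` also accepts a `Schedule`, any step-to-value callable works. A minimal stand-in sketch of the Optax schedule convention (in real code you would use `optax.cosine_decay_schedule` or similar rather than hand-rolling one):

```python
import math

def cosine_schedule(init_lr, total_steps):
    """Return a step -> learning-rate callable, cosine-decayed to zero.

    Mimics the Optax Schedule convention: a pure function of the
    optimizer step count, evaluated inside the update.
    """
    def schedule(step):
        frac = min(step / total_steps, 1.0)
        return init_lr * 0.5 * (1.0 + math.cos(math.pi * frac))
    return schedule

lr_fn = cosine_schedule(init_lr=0.001, total_steps=1000)
# Pass lr_fn as learning_rate= to riemannian_sgd / riemannian_adam
```

The optimizers track the step count in their state (`count`), which is what gets fed to the schedule.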
Manifold Metadata System¶
Hyperbolix uses ManifoldParam, an nnx.Param subclass, to tag parameters that live on a Riemannian manifold. The optimizers detect these via isinstance checks.
How It Works¶
```python
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.optim import ManifoldParam

# Layer tags hyperbolic parameters using ManifoldParam
class HypLinearPoincare(nnx.Module):
    def __init__(self, manifold_module, in_dim, out_dim, *, rngs):
        self.manifold = manifold_module
        # Kernel is Euclidean (plain nnx.Param)
        self.kernel = nnx.Param(
            jax.random.normal(rngs.params(), (out_dim, in_dim)) * 0.01
        )
        # Bias lives on Poincaré ball (ManifoldParam)
        self.bias = ManifoldParam(
            jnp.zeros(out_dim),
            manifold=manifold_module,
            curvature=1.0,
        )
```
The optimizer automatically:

- Detects `ManifoldParam` instances
- Applies Riemannian gradient updates (expmap/retraction)
- Performs parallel transport for momentum/adaptive moments
- Falls back to Euclidean updates for plain `nnx.Param` parameters
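The detection mechanism itself is ordinary `isinstance` dispatch. A stand-in sketch with plain classes (not the real `nnx.Param`/`ManifoldParam`, which carry full NNX Variable machinery) shows the shape of the branch the optimizer takes per leaf:

```python
class Param:
    """Stand-in for nnx.Param: wraps a value."""
    def __init__(self, value):
        self.value = value

class ManifoldParam(Param):
    """Stand-in: a Param that also records manifold metadata."""
    def __init__(self, value, *, manifold, curvature):
        super().__init__(value)
        self.manifold = manifold
        self.curvature = curvature

def get_manifold_info(param):
    """Return (manifold, curvature) for manifold params, else None."""
    if isinstance(param, ManifoldParam):
        return param.manifold, param.curvature
    return None

kernel = Param([0.1, 0.2])                                        # Euclidean leaf
bias = ManifoldParam([0.0, 0.0], manifold="poincare", curvature=1.0)  # manifold leaf
# The optimizer branches on get_manifold_info(...) per leaf:
# None -> Euclidean update; (manifold, c) -> Riemannian update
```

Because `ManifoldParam` subclasses `Param`, any code that only knows about `Param` (state extraction, serialization) keeps working unchanged.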
API Reference¶
hyperbolix.optim.manifold_metadata ¶
Manifold metadata utilities for Riemannian optimization.
This module provides ManifoldParam, an nnx.Param subclass that marks
parameters as living on a Riemannian manifold. Riemannian optimizers detect
these parameters via isinstance(param, ManifoldParam) and apply the
appropriate exponential-map / retraction updates automatically.
Design rationale:
- ManifoldParam is a thin nnx.Param subclass — all NNX machinery
(state extraction, serialization, JIT) works unchanged
- Manifold and curvature are stored as standard Variable metadata kwargs,
accessible via attribute access (param.manifold, param.curvature)
- The manifold instance carries dtype control via _cast(), so Riemannian
optimizer operations automatically respect the layer's precision setting
- Supports both static and callable curvature parameters
Example:

```python
>>> import jax.numpy as jnp
>>> from flax import nnx
>>> from hyperbolix.manifolds.poincare import Poincare
>>> from hyperbolix.optim import ManifoldParam, get_manifold_info
>>>
>>> # Create a parameter on the Poincaré manifold
>>> manifold = Poincare(dtype=jnp.float64)
>>> bias = ManifoldParam(
...     jnp.zeros((10,)),
...     manifold=manifold,
...     curvature=1.0,
... )
>>>
>>> # In optimizer: extract manifold info
>>> manifold_info = get_manifold_info(bias)
>>> if manifold_info is not None:
...     manifold_instance, c = manifold_info
...     # Apply Riemannian operations via public methods...
```
ManifoldParam ¶
ManifoldParam(
value: Any,
*,
manifold: Manifold,
curvature: float | Callable[[], Any],
**metadata: Any,
)
Bases: Param
nnx.Param subclass for parameters living on a Riemannian manifold.
Stores the manifold instance and curvature as Variable metadata kwargs,
providing type-safe detection via isinstance(param, ManifoldParam).
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `value` | array-like | The parameter value (a JAX array). | required |
| `manifold` | `Manifold` | A manifold class instance (e.g., `Poincare()`). | required |
| `curvature` | `float` or callable | Either a static curvature value or a callable that returns the current curvature. Use a callable (e.g., `lambda: self.c.value`) for learnable curvature. | required |
Example
```python
import jax
import jax.numpy as jnp
from flax import nnx
from hyperbolix.manifolds.poincare import Poincare
from hyperbolix.optim import ManifoldParam

manifold = Poincare(dtype=jnp.float64)

# Static curvature
bias = ManifoldParam(
    jax.random.normal(jax.random.key(0), (10,)) * 0.01,
    manifold=manifold,
    curvature=1.0,
)

# Learnable curvature (callable)
class MyLayer(nnx.Module):
    def __init__(self, rngs):
        self.c = nnx.Param(jnp.array(1.0))
        self.bias = ManifoldParam(
            jax.random.normal(rngs.params(), (10,)) * 0.01,
            manifold=manifold,
            curvature=lambda: self.c.value,
        )
```
Source code in hyperbolix/optim/manifold_metadata.py
mark_manifold_param ¶
mark_manifold_param(
param: Param,
manifold: Manifold,
curvature: float | Callable[[], Any],
) -> ManifoldParam
Create a ManifoldParam from an existing nnx.Param.
This is a convenience wrapper around ManifoldParam. Prefer using
ManifoldParam directly in new code.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `param` | `Param` | The parameter whose value will be copied into a new `ManifoldParam`. | required |
| `manifold` | `Manifold` | A manifold class instance. | required |
| `curvature` | `float` or callable | Static curvature value or callable returning current curvature. | required |
Returns:
| Type | Description |
|---|---|
| `ManifoldParam` | A new `ManifoldParam` carrying the given manifold metadata. |
Source code in hyperbolix/optim/manifold_metadata.py
get_manifold_info ¶
Extract manifold information from a parameter.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `param` | `Variable` | The parameter to extract manifold info from. | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `manifold_info` | tuple of (`Manifold`, curvature) or `None` | If the parameter is a `ManifoldParam`, a `(manifold, curvature)` tuple; otherwise `None`. |
Example
```python
manifold_info = get_manifold_info(param)
if manifold_info is not None:
    manifold, c = manifold_info
    rgrad = manifold.egrad2rgrad(grad, param[...], c)
```
Source code in hyperbolix/optim/manifold_metadata.py
has_manifold_params ¶
Check if a parameter pytree contains any manifold parameters.
Parameters:
| Name | Type | Description | Default |
|---|---|---|---|
| `params_pytree` | `Any` | A pytree of parameters (typically from `nnx.state(model, nnx.Param)`). | required |
Returns:
| Name | Type | Description |
|---|---|---|
| `has_manifold` | `bool` | True if any parameter in the pytree is a `ManifoldParam`. |
Example
```python
import jax
from flax import nnx

model = MyHyperbolicModel(rngs=nnx.Rngs(0))
params = nnx.state(model, nnx.Param)
if has_manifold_params(params):
    print("Model contains manifold parameters")
```
Source code in hyperbolix/optim/manifold_metadata.py
Expmap vs Retraction¶
Both optimizers support two update modes:

- Exponential map (default): `expmap(x, -lr * grad)`
  - Exact geodesic following
  - Numerically stable for large steps
  - Slightly slower
- Retraction: `proj(x - lr * grad)`
  - First-order approximation
  - Faster computation
  - Can be less stable for large learning rates
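The two modes agree to first order in the step size, which the 1-D Poincaré ball makes easy to check numerically (with c = 1 the manifold is the interval (-1, 1), and Möbius addition reduces to `(a + b) / (1 + a*b)`). The helper names below are illustrative, not the library's API:

```python
import math

def expmap_1d(x, v):
    """Exact exponential map on the 1-D Poincare ball (c = 1).

    With lambda_x = 2 / (1 - x^2): exp_x(v) = x (+) tanh(lambda_x * v / 2),
    where (+) is 1-D Mobius addition (a + b) / (1 + a*b).
    """
    y = math.tanh(v / (1.0 - x * x))
    return (x + y) / (1.0 + x * y)

def retract_1d(x, v):
    """First-order retraction: Euclidean step, clipped into the open ball."""
    return max(-1 + 1e-5, min(1 - 1e-5, x + v))

x, v = 0.3, -0.1  # current point and tangent step (-lr * rgrad)
exact = expmap_1d(x, v)    # ~0.1970: follows the geodesic
approx = retract_1d(x, v)  # 0.2: Euclidean step + projection
```

For small steps the two results differ by O(v^2); the gap grows with the learning rate, which is why retraction can be less stable.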
Choosing Update Mode¶
```python
# Use exponential map (default, recommended)
opt = riemannian_adam(learning_rate=0.001)

# For performance-critical applications, switch to retraction-based
# updates (faster, first-order approximation) via the use_expmap flag
opt_fast = riemannian_adam(learning_rate=0.001, use_expmap=False)
```
In practice, exponential maps provide better stability and convergence, especially for hyperbolic neural networks.
Mixed Optimization¶
The optimizers seamlessly handle models with both Euclidean and hyperbolic parameters:
```python
from flax import nnx
from hyperbolix.manifolds import Poincare
from hyperbolix.nn_layers import HypLinearPoincare
from hyperbolix.optim import riemannian_adam

poincare = Poincare()

class MixedModel(nnx.Module):
    def __init__(self, rngs):
        # Euclidean linear layer
        self.fc1 = nnx.Linear(32, 64, rngs=rngs)
        # Hyperbolic layer (bias has manifold metadata)
        self.hyp = HypLinearPoincare(
            manifold_module=poincare,
            in_dim=64,
            out_dim=16,
            rngs=rngs
        )
        # Another Euclidean layer
        self.fc2 = nnx.Linear(16, 10, rngs=rngs)

model = MixedModel(rngs=nnx.Rngs(0))

# Optimizer handles all parameter types automatically
optimizer = nnx.Optimizer(model, riemannian_adam(learning_rate=0.001), wrt=nnx.Param)
```
The optimizer will:

- Apply standard Adam updates to `fc1` and `fc2` parameters
- Apply Riemannian Adam updates to `hyp.bias` (tagged with metadata)
- Apply Euclidean Adam updates to `hyp.kernel` (no metadata)
Performance Considerations¶
JIT Compilation

Both optimizers are JIT-compatible; for best performance, JIT-compile the full training step rather than individual optimizer calls.

Curvature as Static Argument

If the curvature `c` is constant during training, pass it as a static argument so the compiler can specialize on its value instead of tracing it.
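The static-argument pattern can be sketched in plain Python with a closure. This is an illustrative stand-in (the `loss_fn` here is a toy, not a hyperbolix function); with JAX you would typically capture `c` in a closure or use `jax.jit(step_fn, static_argnames='c')`:

```python
from functools import partial

def loss_fn(params, x, c):
    # Toy stand-in for a hyperbolic loss that depends on curvature c
    return sum(p * p for p in params) * c + sum(x)

# Bake the constant curvature into the function so a tracer/compiler
# sees it as a Python constant rather than a traced runtime argument
loss_c1 = partial(loss_fn, c=1.0)

value = loss_c1([0.1, 0.2], [1.0, 2.0])
```

When `c` is learnable instead, keep it as a traced argument (or a callable on `ManifoldParam`) so gradients flow through it.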
References¶
The Riemannian optimizers are based on:
- Bécigneul, G., & Ganea, O. (2019). "Riemannian Adaptive Optimization Methods." ICLR 2019.
- Bonnabel, S. (2013). "Stochastic gradient descent on Riemannian manifolds." IEEE TAC.
See the User Guide for detailed explanations and best practices.