JAX API
The JAX integration provides a functional API compatible with jax.grad, jax.vmap, and jax.jit.
Solver
- class moreau.jax.Solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)
JAX solver with automatic device selection and gradient support.
Supports two usage patterns:
- Full signature: solve(P_data, A_data, q, b)
- Two-step: setup(P_data, A_data) then solve(q, b)
- Parameters:
n – Number of primal variables
m – Number of constraints
P_row_offsets – CSR row pointers for P matrix (array-like)
P_col_indices – CSR column indices for P matrix (array-like)
A_row_offsets – CSR row pointers for A matrix (array-like)
A_col_indices – CSR column indices for A matrix (array-like)
cones – Cone specification (moreau.Cones object)
settings – Optional solver settings (moreau.Settings object)
jit – If True (default), JIT-compile the solve method
Example (full signature):
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solution = solver.solve(P_data, A_data, q, b)
print(solution.x)
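The row-offset and column-index arrays in the example above follow the usual CSR layout. As a minimal sketch of how such arrays can be derived from a dense matrix, the helper below (`dense_to_csr`, a hypothetical name that is not part of moreau) uses plain Python lists, which can then be wrapped in `jnp.array`. It drops every zero entry, so it is only suitable when the sparsity pattern should contain no explicit zeros; whether the solver expects the full P matrix or a single triangle is not covered here.

```python
def dense_to_csr(dense):
    """Convert a dense matrix (list of rows) into CSR arrays.

    Returns (row_offsets, col_indices, data), where row i owns the
    entries data[row_offsets[i]:row_offsets[i + 1]].
    """
    row_offsets = [0]
    col_indices = []
    data = []
    for row in dense:
        for col, value in enumerate(row):
            if value != 0:
                col_indices.append(col)
                data.append(float(value))
        # Each row ends where the running count of nonzeros stands.
        row_offsets.append(len(col_indices))
    return row_offsets, col_indices, data


# The 2x2 identity reproduces the structure used in the example above.
P_dense = [[1.0, 0.0],
           [0.0, 1.0]]
P_row_offsets, P_col_indices, P_data = dense_to_csr(P_dense)

# A rectangular matrix with an uneven number of entries per row.
A_dense = [[0.0, 2.0, 0.0],
           [1.0, 0.0, 3.0]]
A_row_offsets, A_col_indices, A_data = dense_to_csr(A_dense)
```

The length of each row-offset array is the number of rows plus one, and its last entry equals the number of stored values.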
Example (two-step):
solver = Solver(n=2, m=2, ...)
solver.setup(P_data, A_data)  # Set matrices once
solution = solver.solve(q, b)  # Solve with 2 args
- setup(P_data, A_data)
Set P and A matrix values for subsequent solve() calls.
- Parameters:
P_data – P matrix values, shape (nnzP,)
A_data – A matrix values, shape (nnzA,)
- solve(*args)
Solve the optimization problem.
Two signatures supported:
- solve(q, b): Uses P/A from setup()
- solve(P_data, A_data, q, b): Full signature
- Returns:
JaxSolution namedtuple with x, z, s
- info
Metadata from the last solve() call.
Functional API
For more control, use the solver function, which returns a pure function:
- moreau.jax.solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)
Create a JAX-compatible solve function.
Returns a pure function suitable for jax.vmap, jax.grad, and jax.jit.
- Returns:
Function with signature (P_data, A_data, q, b) -> (JaxSolution, JaxSolveInfo)
Example:
from moreau.jax import solver
import moreau
import jax.numpy as jnp
cones = moreau.Cones(num_nonneg_cones=2)
solve = solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solution, info = solve(P_data, A_data, q, b)
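Because the returned function yields a `(solution, info)` pair, a loss built on top of it can hand the metadata back through `jax.grad(..., has_aux=True)`. The sketch below shows the pattern with a hypothetical stand-in, `toy_solve`, so that it runs without moreau. `ToySolution`, `ToyInfo`, and the `objective` field are illustrative assumptions and do not describe the actual fields of `JaxSolution` or `JaxSolveInfo`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp


class ToySolution(NamedTuple):
    x: jax.Array


class ToyInfo(NamedTuple):
    objective: jax.Array


P_diag = jnp.array([1.0, 2.0])


def toy_solve(q):
    # Stand-in "solver": closed-form minimizer of
    # 0.5 * x^T diag(P_diag) x + q^T x (no constraints).
    x = -q / P_diag
    objective = 0.5 * jnp.sum(P_diag * x**2) + jnp.dot(q, x)
    return ToySolution(x=x), ToyInfo(objective=objective)


def loss_with_info(q):
    solution, info = toy_solve(q)
    # First output is differentiated; the second is passed through as-is.
    return jnp.sum(solution.x), info


q = jnp.array([1.0, 1.0])
grad_q, info = jax.grad(loss_with_info, has_aux=True)(q)
```

With `has_aux=True`, only the first element of the returned pair is differentiated, so the metadata does not need to be a scalar.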
Gradient Computation
Use jax.grad for automatic differentiation:
import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
n=2, m=2,
P_row_offsets=jnp.array([0, 1, 2]),
P_col_indices=jnp.array([0, 1]),
A_row_offsets=jnp.array([0, 1, 2]),
A_col_indices=jnp.array([0, 1]),
cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solver.setup(P_data, A_data)
# Define loss function
def loss_fn(q):
    solution = solver.solve(q, b)
    return jnp.sum(solution.x)
# Compute gradient
q = jnp.array([1.0, 1.0])
grad_q = jax.grad(loss_fn)(q)
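A common sanity check for gradients obtained this way is a comparison against central finite differences. The sketch below uses a hypothetical closed-form stand-in (`toy_solve`, the minimizer of an unconstrained diagonal QP) in place of the moreau solver so that it runs on its own; `finite_difference_grad` and `toy_loss` are likewise illustrative names. Applying the same check to `loss_fn` above assumes the solver converges tightly enough for the differences to be meaningful.

```python
import jax
import jax.numpy as jnp

P_diag = jnp.array([1.0, 2.0])


def toy_solve(q):
    # Stand-in for a solver call: minimizer of
    # 0.5 * x^T diag(P_diag) x + q^T x.
    return -q / P_diag


def toy_loss(q):
    return jnp.sum(toy_solve(q))


def finite_difference_grad(f, q, eps=1e-3):
    """Central-difference estimate of the gradient of a scalar function f."""
    grads = []
    for i in range(q.shape[0]):
        step = jnp.zeros_like(q).at[i].set(eps)
        grads.append((f(q + step) - f(q - step)) / (2 * eps))
    return jnp.stack(grads)


q = jnp.array([1.0, 1.0])
grad_ad = jax.grad(toy_loss)(q)               # automatic differentiation
grad_fd = finite_difference_grad(toy_loss, q)  # numerical estimate
max_abs_diff = float(jnp.max(jnp.abs(grad_ad - grad_fd)))
```

In single precision the two results typically agree only to a few digits, so the comparison tolerance should be chosen loosely.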
Batching with vmap
Use jax.vmap for batched solving:
import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
n=2, m=2,
P_row_offsets=jnp.array([0, 1, 2]),
P_col_indices=jnp.array([0, 1]),
A_row_offsets=jnp.array([0, 1, 2]),
A_col_indices=jnp.array([0, 1]),
cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
solver.setup(P_data, A_data)
# Batch over q and b
batched_solve = jax.vmap(solver.solve)
q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b_batch = jnp.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
solutions = batched_solve(q_batch, b_batch)
print(solutions.x.shape) # (4, 2)
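When only one argument varies across the batch, `jax.vmap` accepts an `in_axes` argument so that the other argument is shared instead of tiled. The sketch below demonstrates this with a hypothetical stand-in function, `toy_solve`, so that it runs without moreau; that `solver.solve` behaves the same way under `in_axes=(0, None)` is an assumption that should be checked against the library.

```python
import jax
import jax.numpy as jnp


def toy_solve(q, b):
    # Stand-in for a solver call: elementwise minimizer of
    # 0.5 * x**2 + q * x subject to x <= b.
    return jnp.minimum(-q, b)


# Map over axis 0 of q; pass the same b to every batch element.
batched_solve = jax.vmap(toy_solve, in_axes=(0, None))

q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b = jnp.array([0.5, 0.5])

solutions = batched_solve(q_batch, b)
```

Here `solutions` has one row per entry of `q_batch`, and `b` never needs to be replicated into a `(4, 2)` array.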
JIT Compilation
The solver is JIT-compiled by default. For manual control:
# Disable JIT at construction
solver = Solver(n=2, m=2, ..., cones=cones, jit=False)
# Or JIT your own wrapper
@jax.jit
def my_solve(q, b):
    return solver.solve(q, b)
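A wrapper decorated with `jax.jit` is traced once per combination of input shapes and dtypes, and the compiled result is reused afterwards. The sketch below makes this visible with a Python side effect that only runs during tracing. It uses a hypothetical stand-in, `toy_solve`, instead of the moreau solver, and `trace_log` is an illustrative name.

```python
import jax
import jax.numpy as jnp

trace_log = []


def toy_solve(q, b):
    # Stand-in for solver.solve(q, b).
    return jnp.minimum(-q, b)


@jax.jit
def my_solve(q, b):
    # Python side effect: executes only while JAX traces the function,
    # not on later calls that hit the compilation cache.
    trace_log.append(q.shape)
    return toy_solve(q, b)


b = jnp.array([0.5, 0.5])
x1 = my_solve(jnp.array([1.0, 1.0]), b)
x2 = my_solve(jnp.array([2.0, 1.0]), b)   # same shapes and dtypes: cached
traces_same_shape = len(trace_log)

x3 = my_solve(jnp.ones(3), jnp.full(3, 0.5))  # new shape: traced again
traces_new_shape = len(trace_log)
```

Keeping input shapes and dtypes fixed across calls therefore avoids repeated compilation.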