JAX API¶
The JAX integration provides a functional API compatible with jax.grad, jax.vmap, and jax.jit.
Solver¶
- class moreau.jax.Solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)¶
JAX solver with automatic device selection and gradient support.
Supports two usage patterns:
Full signature: solve(P_data, A_data, q, b)
Two-step: setup(P_data, A_data), then solve(q, b)
- Parameters:
n – Number of primal variables
m – Number of constraints
P_row_offsets – CSR row pointers for P matrix (array-like). P must be stored as a full symmetric matrix (both upper and lower triangles).
P_col_indices – CSR column indices for P matrix (array-like)
A_row_offsets – CSR row pointers for A matrix (array-like)
A_col_indices – CSR column indices for A matrix (array-like)
cones – Cone specification (moreau.Cones object)
settings – Optional solver settings (moreau.Settings object)
jit – If True (default), JIT-compile the solve method
Example (full signature):
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solution = solver.solve(P_data, A_data, q, b)
print(solution.x)
Example (two-step):
solver = Solver(n=2, m=2, ...)
solver.setup(P_data, A_data)   # Set matrices once
solution = solver.solve(q, b)  # Solve with 2 args
- setup(P_data, A_data)¶
Set P and A matrix values for subsequent solve() calls. Gradients w.r.t. P and A are still computed when using the 2-arg solve().
- Parameters:
P_data – P matrix values, shape (nnzP,)
A_data – A matrix values, shape (nnzA,)
- solve(*args, warm_start=None)¶
Solve the optimization problem.
Two signatures supported:
solve(q, b): Uses P/A from setup() (raises RuntimeError if setup() has not been called)
solve(P_data, A_data, q, b): Full signature
- Parameters:
warm_start – Optional WarmStart or BatchedWarmStart from a previous solve (e.g. solution.to_warm_start()). If the warm-started solve fails, it is automatically retried without warm start. Gradients do not flow through warm start values.
- Returns:
JaxSolution namedtuple with fields x, z, and s
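As a minimal sketch of the warm-start path (warm_start and to_warm_start() are documented above; the solver, q, and b are reused from the examples above, and the perturbed data is illustrative):
solution = solver.solve(q, b)
warm = solution.to_warm_start()
# Warm-start a later solve with perturbed data; if the warm-started
# solve fails, it is retried automatically without the warm start.
solution2 = solver.solve(q + 0.1, b, warm_start=warm)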
- info: JaxSolveInfo¶
Metadata from the last solve() call. Returns None if solve() has not been called yet.
Note
For jax.vmap calls, this returns info from the last single call, not the batched result. To access batched info, use the functional solver() API, which returns (JaxSolution, JaxSolveInfo) tuples.
- tune_result: TuneResult or None¶
Result from auto-tuning on the first solve() call. Returns None if auto-tune has not run (e.g. device and method were set explicitly, or solve() has not been called).
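A short sketch of inspecting this metadata after a solve; only the attributes documented above are used:
solution = solver.solve(q, b)
print(solver.info)         # JaxSolveInfo from the last solve() call
print(solver.tune_result)  # TuneResult, or None if auto-tune did not run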
Functional API¶
For more control, use the solver() function, which returns a pure function:
- moreau.jax.solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)¶
Create a JAX-compatible solve function.
Returns a pure function suitable for jax.vmap, jax.grad, and jax.jit. The functional API does not support warm starting.
- Returns:
Function with signature (P_data, A_data, q, b) -> (JaxSolution, JaxSolveInfo)
Example:
from moreau.jax import solver
import moreau
import jax.numpy as jnp
cones = moreau.Cones(num_nonneg_cones=2)
solve = solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solution, info = solve(P_data, A_data, q, b)
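Because the functional API returns (JaxSolution, JaxSolveInfo) tuples, per-problem info survives batching. A minimal sketch continuing the example above; tiling the matrix data across the batch is one way to satisfy jax.vmap, not a library requirement:
import jax
batch = 4
P_batch = jnp.tile(P_data, (batch, 1))  # same matrix values for each problem
A_batch = jnp.tile(A_data, (batch, 1))
q_batch = jnp.tile(q, (batch, 1))
b_batch = jnp.tile(b, (batch, 1))
solutions, infos = jax.vmap(solve)(P_batch, A_batch, q_batch, b_batch)
print(solutions.x.shape)  # (4, 2)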
Gradient Computation¶
Use jax.grad for automatic differentiation:
import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
n=2, m=2,
P_row_offsets=jnp.array([0, 1, 2]),
P_col_indices=jnp.array([0, 1]),
A_row_offsets=jnp.array([0, 1, 2]),
A_col_indices=jnp.array([0, 1]),
cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solver.setup(P_data, A_data)
# Define loss function
def loss_fn(q):
solution = solver.solve(q, b)
return jnp.sum(solution.x)
# Compute gradient
q = jnp.array([1.0, 1.0])
grad_q = jax.grad(loss_fn)(q)
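The full four-argument signature is differentiable in the matrix values as well. A minimal sketch, assuming the same solver and data as above (differentiating w.r.t. P_data instead of q):
# Gradient of the loss w.r.t. the P matrix values, via the full signature.
def loss_P(P_data):
    solution = solver.solve(P_data, A_data, q, b)
    return jnp.sum(solution.x)
grad_P = jax.grad(loss_P)(P_data)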
Batching with vmap¶
Use jax.vmap for batched solving:
import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
n=2, m=2,
P_row_offsets=jnp.array([0, 1, 2]),
P_col_indices=jnp.array([0, 1]),
A_row_offsets=jnp.array([0, 1, 2]),
A_col_indices=jnp.array([0, 1]),
cones=cones,
)
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
solver.setup(P_data, A_data)
# Batch over q and b
batched_solve = jax.vmap(solver.solve)
q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b_batch = jnp.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])
solutions = batched_solve(q_batch, b_batch)
print(solutions.x.shape) # (4, 2)
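jax.vmap composes with jax.grad in the usual way, so per-problem gradients are a small extension of the batched solve above. A sketch:
# Per-problem gradients: grad of a scalar loss w.r.t. q, vmapped over the batch.
def loss_fn(q, b):
    return jnp.sum(solver.solve(q, b).x)
grad_batch = jax.vmap(jax.grad(loss_fn))(q_batch, b_batch)
print(grad_batch.shape)  # (4, 2)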
JIT Compilation¶
The solver is JIT-compiled by default. For manual control:
# Disable JIT at construction
solver = Solver(n=2, m=2, ..., cones=cones, jit=False)
# Or JIT your own wrapper
@jax.jit
def my_solve(q, b):
return solver.solve(q, b)