JAX API

The JAX integration provides a functional API compatible with jax.grad, jax.vmap, and jax.jit.

Solver

class moreau.jax.Solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)

JAX solver with automatic device selection and gradient support.

Supports two usage patterns:

  1. Full signature: solve(P_data, A_data, q, b)

  2. Two-step: setup(P_data, A_data) then solve(q, b)

Parameters:
  • n – Number of primal variables

  • m – Number of constraints

  • P_row_offsets – CSR row pointers for the P matrix (array-like). P must be stored as a full symmetric matrix (both upper and lower triangles), not as a single triangle.

  • P_col_indices – CSR column indices for the P matrix (array-like)

  • A_row_offsets – CSR row pointers for the A matrix (array-like)

  • A_col_indices – CSR column indices for the A matrix (array-like)

  • cones – Cone specification (moreau.Cones object)

  • settings – Optional solver settings (moreau.Settings object)

  • jit – If True (default), JIT-compile the solve method
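The CSR arrays above can be derived from a dense matrix. The following is a minimal NumPy sketch that assumes dense inputs. `dense_to_csr` is a hypothetical helper, not part of moreau. Because P must contain both triangles, the helper is applied to the full symmetric P rather than to its upper triangle.

```python
import numpy as np

def dense_to_csr(M, tol=0.0):
    """Hypothetical helper: return (row_offsets, col_indices, data) for dense M."""
    M = np.asarray(M, dtype=float)
    row_offsets = [0]
    col_indices = []
    data = []
    for row in M:
        # Keep entries whose magnitude exceeds tol, in increasing column order
        cols = np.nonzero(np.abs(row) > tol)[0]
        col_indices.extend(cols.tolist())
        data.extend(row[cols].tolist())
        row_offsets.append(len(col_indices))  # running count of stored entries
    return (np.array(row_offsets, dtype=np.int64),
            np.array(col_indices, dtype=np.int64),
            np.array(data, dtype=float))

# P must be full symmetric, so check symmetry before converting
P = np.array([[2.0, 0.5],
              [0.5, 1.0]])
assert np.allclose(P, P.T), "P must be full symmetric"
P_row_offsets, P_col_indices, P_data = dense_to_csr(P)
```

With `np.eye(2)`, this helper yields row offsets [0, 1, 2] and column indices [0, 1]. That is the same structure the examples in this section use for both P and A.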

Example (full signature):

import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])

solution = solver.solve(P_data, A_data, q, b)
print(solution.x)

Example (two-step):

solver = Solver(n=2, m=2, ...)
solver.setup(P_data, A_data)  # Set matrices once
solution = solver.solve(q, b)  # Solve with 2 args

setup(P_data, A_data)

Set P and A matrix values for subsequent solve() calls.

Gradients with respect to P_data and A_data are still computed when using the two-argument solve(q, b).

Parameters:
  • P_data – P matrix values, shape (nnzP,)

  • A_data – A matrix values, shape (nnzA,)
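In CSR, the lengths `nnzP` and `nnzA` are the number of stored entries, which equals the last element of the corresponding row-offset array. The following is a sketch of a pre-flight check to run before calling setup(). `check_csr_data` is a hypothetical helper, not a moreau function. It assumes P is n × n and A is m × n, which follows from n being the number of primal variables and m the number of constraints.

```python
import numpy as np

def check_csr_data(row_offsets, col_indices, data, n_rows, n_cols):
    """Hypothetical helper: validate CSR arrays before passing values to setup()."""
    row_offsets = np.asarray(row_offsets)
    col_indices = np.asarray(col_indices)
    data = np.asarray(data)
    if row_offsets.shape != (n_rows + 1,):
        raise ValueError(f"row_offsets must have length {n_rows + 1}")
    if row_offsets[0] != 0 or np.any(np.diff(row_offsets) < 0):
        raise ValueError("row_offsets must start at 0 and be non-decreasing")
    nnz = int(row_offsets[-1])  # number of stored entries
    if col_indices.shape != (nnz,) or data.shape != (nnz,):
        raise ValueError(f"col_indices and data must have shape ({nnz},)")
    if nnz and (col_indices.min() < 0 or col_indices.max() >= n_cols):
        raise ValueError("column index out of range")
    return nnz
```

For P, pass `n_rows = n_cols = n`. For A, pass `n_rows = m` and `n_cols = n`.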

solve(*args, warm_start=None)

Solve the optimization problem.

Two signatures are supported:

  • solve(q, b): Uses P/A values from setup() (raises RuntimeError if setup() has not been called)

  • solve(P_data, A_data, q, b): Full signature
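One way to picture the two signatures is as dispatch on the number of positional arguments. The class below is a hypothetical pure-Python sketch of the documented behavior, including the RuntimeError. It is not moreau's implementation. The actual solve is replaced by a stub that returns its inputs, and `warm_start` is accepted but ignored.

```python
class TwoSignatureSolverSketch:
    """Hypothetical illustration of setup()/solve() dispatch; not moreau code."""

    def __init__(self):
        self._P_data = None
        self._A_data = None

    def setup(self, P_data, A_data):
        # Cache matrix values for later two-argument calls
        self._P_data, self._A_data = P_data, A_data

    def solve(self, *args, warm_start=None):
        # warm_start is ignored in this sketch
        if len(args) == 4:
            P_data, A_data, q, b = args
        elif len(args) == 2:
            if self._P_data is None:
                raise RuntimeError("call setup(P_data, A_data) before solve(q, b)")
            P_data, A_data = self._P_data, self._A_data
            q, b = args
        else:
            raise TypeError("solve() takes (q, b) or (P_data, A_data, q, b)")
        return self._solve_impl(P_data, A_data, q, b)

    def _solve_impl(self, P_data, A_data, q, b):
        # Stub standing in for the actual interior-point solve
        return (P_data, A_data, q, b)
```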

Parameters:

warm_start – Optional WarmStart or BatchedWarmStart from a previous solve (e.g. solution.to_warm_start()). If the warm-started solve fails, it is automatically retried without warm start. Gradients do not flow through warm start values.
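The automatic retry follows a simple pattern, sketched below. moreau performs this internally, so user code does not need to. `solve_fn` and `succeeded` are hypothetical callables. Treating failure as an unsuccessful result rather than a raised exception is an assumption of this sketch.

```python
def solve_with_fallback(solve_fn, succeeded, warm_start=None):
    """Hypothetical sketch: try a warm-started solve, retry cold on failure."""
    if warm_start is not None:
        result = solve_fn(warm_start=warm_start)
        if succeeded(result):
            return result
    # No warm start was given, or the warm-started attempt failed
    return solve_fn(warm_start=None)
```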

Returns:

JaxSolution namedtuple with fields x, z, and s

info: JaxSolveInfo

Metadata from the last solve() call. Returns None if solve() has not been called yet.

Note

For jax.vmap calls, this returns info from the last single call, not the batched result. To access batched info, use the functional solver() API which returns (JaxSolution, JaxSolveInfo) tuples.

device: str

Active device name ('cpu' or 'cuda').

n: int

Number of primal variables.

m: int

Number of constraints.

construction_time: float

Time spent constructing solver structure (seconds).

tune_result: TuneResult or None

Result from auto-tuning on the first solve() call. Returns None if auto-tune has not run (e.g. device and method were set explicitly, or solve() has not been called).

Functional API

For more control, use the solver() function, which returns a pure function:

moreau.jax.solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)

Create a JAX-compatible solve function.

Returns a pure function suitable for jax.vmap, jax.grad, and jax.jit. The functional API does not support warm starting.

Returns:

Function with signature (P_data, A_data, q, b) -> (JaxSolution, JaxSolveInfo)

Example:

from moreau.jax import solver
import moreau
import jax.numpy as jnp

cones = moreau.Cones(num_nonneg_cones=2)
solve = solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])

solution, info = solve(P_data, A_data, q, b)
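Because the returned function is pure, it composes with jax.vmap. Setting `in_axes` lets you share the matrix values across a batch while mapping over q and b. The sketch below uses a hypothetical stand-in, `toy_solve`, with the same (P_data, A_data, q, b) -> (solution, info) shape so that it runs without moreau. It does not solve the conic problem. With the real library, you should be able to pass the function returned by solver() in place of `toy_solve`.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp

class ToySolution(NamedTuple):  # hypothetical stand-in for JaxSolution
    x: jax.Array
    z: jax.Array
    s: jax.Array

class ToyInfo(NamedTuple):  # hypothetical stand-in for JaxSolveInfo
    status: jax.Array

def toy_solve(P_data, A_data, q, b):
    # Not a real solve: diagonal update x = -q / P_data, ignoring A_data and b
    x = -q / P_data
    sol = ToySolution(x=x, z=jnp.zeros_like(b), s=jnp.zeros_like(b))
    # Float status, one value per problem in the batch
    status = jnp.where(jnp.all(jnp.isfinite(x)), 1.0, 0.0)
    return sol, ToyInfo(status=status)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b_batch = jnp.full((4, 2), 0.5)

# Share P_data and A_data (None) and map over the leading axis of q and b (0)
batched = jax.vmap(toy_solve, in_axes=(None, None, 0, 0))
solutions, infos = batched(P_data, A_data, q_batch, b_batch)
print(solutions.x.shape, infos.status.shape)  # (4, 2) (4,)
```

Each field of `infos` carries one entry per problem. This is the batched metadata that the Note under Solver.info directs you to this API for.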

Gradient Computation

Use jax.grad for automatic differentiation:

import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solver.setup(P_data, A_data)

# Define loss function
def loss_fn(q):
    solution = solver.solve(q, b)
    return jnp.sum(solution.x)

# Compute gradient
q = jnp.array([1.0, 1.0])
grad_q = jax.grad(loss_fn)(q)
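When debugging a loss built on solver outputs, a central finite-difference check is a quick way to sanity-check jax.grad. The sketch below is generic JAX. Its `loss_fn` uses a hypothetical closed-form stand-in for solver.solve(q, b).x so that it runs without moreau. The step size and tolerance are illustrative assumptions. With a real solver, expect to loosen the tolerance, because solutions are only computed to the solver's accuracy.

```python
import jax
import jax.numpy as jnp

def fd_grad(f, x, eps=1e-3):
    """Central finite differences of a scalar function f at x, one coordinate at a time."""
    basis = jnp.eye(x.shape[0], dtype=x.dtype)
    return jnp.stack([(f(x + eps * e) - f(x - eps * e)) / (2 * eps) for e in basis])

p = jnp.array([1.0, 2.0])

def loss_fn(q):
    x = -q / p  # hypothetical stand-in for solver.solve(q, b).x
    return jnp.sum(x ** 2)

q = jnp.array([1.0, 1.0])
g_auto = jax.grad(loss_fn)(q)
g_fd = fd_grad(loss_fn, q)
print(jnp.allclose(g_auto, g_fd, atol=1e-2))  # True
```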

Batching with vmap

Use jax.vmap for batched solving:

import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
solver.setup(P_data, A_data)

# Batch over q and b
batched_solve = jax.vmap(solver.solve)

q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b_batch = jnp.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])

solutions = batched_solve(q_batch, b_batch)
print(solutions.x.shape)  # (4, 2)

JIT Compilation

The solver is JIT-compiled by default. For manual control:

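import jax
from moreau.jax import Solver

# Disable JIT at construction (other arguments as in the examples above)
solver = Solver(n=2, m=2, ..., cones=cones, jit=False)

# Or JIT your own wrapper (call solver.setup(P_data, A_data) before using it)
@jax.jit
def my_solve(q, b):
    return solver.solve(q, b)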
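When you wrap solve calls in your own jax.jit function, JAX traces the Python body once for each new combination of input shapes and dtypes, then reuses the compiled version. The sketch below counts traces with a Python side effect. `toy_solve` is a hypothetical stand-in for solver.solve so that the example runs without moreau.

```python
import jax
import jax.numpy as jnp

trace_count = 0

def toy_solve(q, b):
    # Hypothetical stand-in for solver.solve(q, b)
    return -q + 0.0 * b

@jax.jit
def my_solve(q, b):
    global trace_count
    trace_count += 1  # Python side effect: runs only while tracing, not on cached calls
    return toy_solve(q, b)

my_solve(jnp.ones(2), jnp.ones(2))   # first call: traced and compiled
my_solve(jnp.zeros(2), jnp.ones(2))  # same shapes and dtypes: cached
my_solve(jnp.ones(3), jnp.ones(3))   # new shape: traced again
print(trace_count)  # 2
```

A given Solver has fixed n and m, so q and b usually keep the same shapes and such a wrapper typically compiles once. Changing the batch size under jax.vmap would trigger a new trace.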

Data Types

JaxSolution

class moreau.jax.JaxSolution

NamedTuple containing solution arrays (JAX pytree-compatible).

x: jax.Array

Primal solution

z: jax.Array

Dual variables

s: jax.Array

Slack variables

to_warm_start()

Create a WarmStart from this solution (converts JAX arrays to NumPy arrays).

Return type:

WarmStart
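Because JaxSolution is a NamedTuple, JAX treats it as a pytree, so all of its fields can be transformed together. The sketch below converts every field to NumPy with jax.tree_util.tree_map. This mirrors the array conversion that to_warm_start() is described as performing. `ToySolution` is a hypothetical stand-in, and its field values are arbitrary. The real method returns a WarmStart, not a NamedTuple.

```python
from typing import NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

class ToySolution(NamedTuple):  # hypothetical stand-in for JaxSolution
    x: jax.Array
    z: jax.Array
    s: jax.Array

sol = ToySolution(x=jnp.array([-0.5, -0.5]), z=jnp.zeros(2), s=jnp.ones(2))

# tree_map visits each leaf (x, z, s) and rebuilds a ToySolution from the results
sol_np = jax.tree_util.tree_map(np.asarray, sol)
print(type(sol_np).__name__, type(sol_np.x).__name__)  # ToySolution ndarray
```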

JaxSolveInfo

class moreau.jax.JaxSolveInfo

NamedTuple containing solver metadata (JAX pytree-compatible).

status: float

Solver status as a float (the SolverStatus integer value cast to float for JAX pytree compatibility). To check for success, use int(info.status) == SolverStatus.Solved.
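Because status arrives as a float, cast it back to an integer before comparing it with the enum. For batched results from the functional API, do the cast elementwise, outside of jit. The sketch below uses a hypothetical `SolverStatus` stand-in whose member value is an assumption. In real code, use moreau's own SolverStatus.

```python
from enum import IntEnum

import numpy as np

class SolverStatus(IntEnum):  # hypothetical stand-in; the real value may differ
    Solved = 1

def is_solved(status):
    """True where status (a float scalar or batched float array) equals Solved."""
    codes = np.asarray(status).astype(np.int64)  # float -> int, elementwise
    return codes == int(SolverStatus.Solved)

print(bool(is_solved(1.0)))                           # True
print(is_solved(np.array([1.0, 0.0, 1.0])).tolist())  # [True, False, True]
```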

obj_val: float

Objective value

iterations: int

Number of interior-point method (IPM) iterations

solve_time: float

Solve time in seconds

setup_time: float

Time spent setting matrix values (seconds)

construction_time: float

Time spent constructing the solver (seconds)