JAX API

The JAX integration provides a functional API compatible with jax.grad, jax.vmap, and jax.jit.

Solver

class moreau.jax.Solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)

JAX solver with automatic device selection and gradient support.

Supports two usage patterns:

  1. Full signature: solve(P_data, A_data, q, b)

  2. Two-step: setup(P_data, A_data) then solve(q, b)

Parameters:
  • n – Number of primal variables

  • m – Number of constraints

  • P_row_offsets – CSR row pointers for the P matrix (array-like)

  • P_col_indices – CSR column indices for the P matrix (array-like)

  • A_row_offsets – CSR row pointers for the A matrix (array-like)

  • A_col_indices – CSR column indices for the A matrix (array-like)

  • cones – Cone specification (moreau.Cones object)

  • settings – Optional solver settings (moreau.Settings object)

  • jit – If True (default), JIT-compile the solve method
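The constructor receives only the sparsity pattern of P and A in CSR (compressed sparse row) form. The nonzero values are passed later as P_data and A_data. The sketch below shows how the three CSR arrays relate by converting a small dense matrix. dense_to_csr is a hypothetical helper, not part of moreau. It stores every nonzero entry; this section does not say whether moreau expects all of a symmetric P or only one triangle.

```python
from typing import List, Sequence, Tuple


def dense_to_csr(
    dense: Sequence[Sequence[float]],
) -> Tuple[List[int], List[int], List[float]]:
    """Convert a dense row-major matrix to (row_offsets, col_indices, data).

    Zero entries are skipped. row_offsets has len(dense) + 1 entries and its
    last entry equals the number of stored nonzeros.
    """
    row_offsets = [0]
    col_indices: List[int] = []
    data: List[float] = []
    for row in dense:
        for j, value in enumerate(row):
            if value != 0.0:
                col_indices.append(j)
                data.append(float(value))
        # The next row starts where this one ends.
        row_offsets.append(len(col_indices))
    return row_offsets, col_indices, data


# The 2x2 identity used for both P and A in the examples on this page.
offsets, cols, vals = dense_to_csr([[1.0, 0.0], [0.0, 1.0]])
print(offsets, cols, vals)  # [0, 1, 2] [0, 1] [1.0, 1.0]
```

The last row offset equals the number of stored nonzeros. That is the expected length of the corresponding P_data or A_data array.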

Example (full signature):

import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])

solution = solver.solve(P_data, A_data, q, b)
print(solution.x)
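This page does not state the problem form. Assume moreau follows the conic standard form used by solvers such as Clarabel: minimize (1/2) x^T P x + q^T x subject to Ax + s = b, with s in the cone K. In the example above, P and A are diagonal and K is the nonnegative orthant, so the problem splits into independent one-dimensional problems. The hypothetical helper below (not part of moreau) computes the closed-form answer for that special case. This gives a reference value to compare against solution.x.

```python
from typing import List, Sequence


def diagonal_qp_reference(
    p: Sequence[float],
    a: Sequence[float],
    q: Sequence[float],
    b: Sequence[float],
) -> List[float]:
    """Closed-form x* when P = diag(p) and A = diag(a) with positive entries.

    Assumes the form  min (1/2) x^T P x + q^T x  s.t.  Ax + s = b, s >= 0,
    so each coordinate solves  min (1/2) p_i x_i^2 + q_i x_i  s.t.  a_i x_i <= b_i.
    """
    x = []
    for p_i, a_i, q_i, b_i in zip(p, a, q, b):
        unconstrained = -q_i / p_i  # stationary point of the 1-D quadratic
        upper = b_i / a_i           # largest x_i allowed by a_i x_i <= b_i
        # A convex quadratic restricted to (-inf, upper] is minimized at its
        # stationary point if that point is feasible, otherwise at the bound.
        x.append(min(unconstrained, upper))
    return x


# Same data as the full-signature example (P = A = I).
print(diagonal_qp_reference([1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [0.5, 0.5]))
# [-1.0, -1.0]
```

Under that assumption, both coordinates come out at -1.0. The unconstrained minimizer x_i = -1 already satisfies x_i <= 0.5, so the constraints are inactive.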

Example (two-step):

solver = Solver(n=2, m=2, ...)  # same structure and cones as above
solver.setup(P_data, A_data)    # Set matrix values once
solution = solver.solve(q, b)   # Solve with only q and b

setup(P_data, A_data)

Set P and A matrix values for subsequent solve() calls.

Parameters:
  • P_data – P matrix values, shape (nnzP,)

  • A_data – A matrix values, shape (nnzA,)
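In CSR, P_data and A_data hold the nonzero values in the same order as the column indices given to the constructor. As a result, nnzP equals the last entry of P_row_offsets, and likewise for A. A small check like the hypothetical check_csr_values below (not a moreau function) can catch length mismatches before calling setup() or solve().

```python
from typing import Sequence


def check_csr_values(
    n_rows: int,
    row_offsets: Sequence[int],
    col_indices: Sequence[int],
    data: Sequence[float],
) -> None:
    """Raise ValueError if data does not fit the CSR structure."""
    if len(row_offsets) != n_rows + 1:
        raise ValueError(f"expected {n_rows + 1} row offsets, got {len(row_offsets)}")
    nnz = row_offsets[-1]  # number of stored nonzeros
    if len(col_indices) != nnz:
        raise ValueError(f"expected {nnz} column indices, got {len(col_indices)}")
    if len(data) != nnz:
        raise ValueError(f"expected {nnz} values, got {len(data)}")


# P in the examples is n x n with n = 2; A is m x n with m = 2.
check_csr_values(2, [0, 1, 2], [0, 1], [1.0, 1.0])  # passes silently
```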

solve(*args)

Solve the optimization problem.

Two signatures are supported:

  • solve(q, b): Uses P/A from setup()

  • solve(P_data, A_data, q, b): Full signature

Returns:

JaxSolution namedtuple with x, z, s
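The two call forms differ only in where the matrix values come from. The class below is a hypothetical pure-Python illustration of that dispatch, written for this page. It is not moreau's implementation, and the full_solve callable it wraps is a stand-in for the actual conic solve.

```python
from typing import Any, Callable, Optional


class TwoModeSolver:
    """Hypothetical sketch of the solve(q, b) / solve(P_data, A_data, q, b) dispatch."""

    def __init__(self, full_solve: Callable[[Any, Any, Any, Any], Any]) -> None:
        self._full_solve = full_solve
        self._P_data: Optional[Any] = None
        self._A_data: Optional[Any] = None

    def setup(self, P_data: Any, A_data: Any) -> None:
        # Cache matrix values for later two-argument calls.
        self._P_data = P_data
        self._A_data = A_data

    def solve(self, *args: Any) -> Any:
        if len(args) == 4:
            # Full signature: matrix values are passed explicitly.
            return self._full_solve(*args)
        if len(args) == 2:
            if self._P_data is None or self._A_data is None:
                raise RuntimeError("call setup(P_data, A_data) before solve(q, b)")
            q, b = args
            return self._full_solve(self._P_data, self._A_data, q, b)
        raise TypeError(f"solve() takes 2 or 4 arguments, got {len(args)}")


# A stand-in backend that just echoes its inputs.
demo = TwoModeSolver(lambda P, A, q, b: (P, A, q, b))
demo.setup([1.0], [2.0])
print(demo.solve([3.0], [4.0]))  # ([1.0], [2.0], [3.0], [4.0])
```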

info

Metadata from the last solve() call.

device: str

Active device name (‘cpu’ or ‘cuda’).

Functional API

For more control, use the solver function, which returns a pure function:

moreau.jax.solver(n, m, P_row_offsets, P_col_indices, A_row_offsets, A_col_indices, cones, settings=None, jit=True)

Create a JAX-compatible solve function.

Returns a pure function suitable for jax.vmap, jax.grad, and jax.jit.

Returns:

Function with signature (P_data, A_data, q, b) -> (JaxSolution, JaxSolveInfo)

Example:

from moreau.jax import solver
import moreau
import jax.numpy as jnp

cones = moreau.Cones(num_nonneg_cones=2)
solve = solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
q = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])

solution, info = solve(P_data, A_data, q, b)
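The function returned by solver reads only its arguments and the structure fixed when it was created. The factory pattern can be sketched in plain Python as shown below. make_diagonal_solver, ToySolution, and ToyInfo are hypothetical names. The closed-form body is valid only for diagonal P and A with positive entries and a nonnegative cone. It also assumes the standard form minimize (1/2) x^T P x + q^T x subject to Ax + s = b, which this page does not confirm.

```python
from typing import Callable, List, NamedTuple, Sequence, Tuple


class ToySolution(NamedTuple):
    x: List[float]


class ToyInfo(NamedTuple):
    iterations: int


def make_diagonal_solver(n: int) -> Callable[..., Tuple[ToySolution, ToyInfo]]:
    """Capture static structure (here only n) and return a pure solve function."""

    def solve(
        P_data: Sequence[float],
        A_data: Sequence[float],
        q: Sequence[float],
        b: Sequence[float],
    ) -> Tuple[ToySolution, ToyInfo]:
        # Reads only its arguments and the captured n; no hidden mutable state.
        # For a diagonal matrix in CSR, data[i] is the i-th diagonal entry.
        x = [min(-q[i] / P_data[i], b[i] / A_data[i]) for i in range(n)]
        return ToySolution(x=x), ToyInfo(iterations=0)  # 0: no iterative solve here

    return solve


solve = make_diagonal_solver(n=2)
solution, info = solve([1.0, 1.0], [1.0, 1.0], [1.0, 1.0], [0.5, 0.5])
print(solution.x)  # [-1.0, -1.0]
```

Because nothing is stored between calls, the same inputs always give the same outputs. This is the property the functional API relies on when the function is passed to jax.vmap, jax.grad, or jax.jit.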

Gradient Computation

Use jax.grad for automatic differentiation:

import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
b = jnp.array([0.5, 0.5])
solver.setup(P_data, A_data)

# Define loss function
def loss_fn(q):
    solution = solver.solve(q, b)
    return jnp.sum(solution.x)

# Compute gradient
q = jnp.array([1.0, 1.0])
grad_q = jax.grad(loss_fn)(q)
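One way to sanity-check grad_q is a central finite difference on a problem with a known closed-form solution. The sketch below assumes the standard form minimize (1/2) x^T P x + q^T x subject to Ax + s = b, s >= 0. With P = A = I and b = [0.5, 0.5], that gives x_i = min(-q_i, 0.5), so sum(x) has gradient -1 in each coordinate while the constraints are inactive. loss and finite_difference_grad are illustrative names, not part of moreau.

```python
from typing import Callable, List, Sequence


def loss(q: Sequence[float]) -> float:
    # sum(x*) for P = A = I and b = [0.5, 0.5], using the closed form
    # x_i* = min(-q_i, b_i) that holds under the assumed standard form.
    b = [0.5, 0.5]
    return sum(min(-q_i, b_i) for q_i, b_i in zip(q, b))


def finite_difference_grad(
    f: Callable[[List[float]], float], q: Sequence[float], h: float = 1e-6
) -> List[float]:
    """Central-difference estimate of the gradient of f at q."""
    grad = []
    for i in range(len(q)):
        q_plus = list(q)
        q_plus[i] += h
        q_minus = list(q)
        q_minus[i] -= h
        grad.append((f(q_plus) - f(q_minus)) / (2 * h))
    return grad


print(finite_difference_grad(loss, [1.0, 1.0]))  # close to [-1.0, -1.0]
```

Under that assumption, jax.grad(loss_fn)(q) in the example above should be close to [-1.0, -1.0]. When a constraint is active, for example q_i < -0.5, x_i is pinned at 0.5 and the corresponding gradient entry is 0.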

Batching with vmap

Use jax.vmap for batched solving:

import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau

cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2, m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)

P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
solver.setup(P_data, A_data)

# Batch over q and b
batched_solve = jax.vmap(solver.solve)

q_batch = jnp.array([[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]])
b_batch = jnp.array([[0.5, 0.5], [0.5, 0.5], [0.5, 0.5], [0.5, 0.5]])

solutions = batched_solve(q_batch, b_batch)
print(solutions.x.shape)  # (4, 2)
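With its default in_axes=0, jax.vmap maps over the leading axis of each input and stacks the results. As a result, batched_solve(q_batch, b_batch) gives the same values as calling solver.solve on each row and stacking the outputs. The plain-Python loop below spells out that equivalence. solve_one is a hypothetical stand-in that uses the closed form x = min(-q, b). That closed form holds only under the assumed standard form minimize (1/2) x^T x + q^T x subject to x + s = b, s >= 0 (P = A = I, as in the example).

```python
from typing import List, Sequence


def solve_one(q: Sequence[float], b: Sequence[float]) -> List[float]:
    # Stand-in for solver.solve(q, b).x with P = A = I, using the closed form
    # x = min(-q, b) under the assumed form min (1/2) x^T x + q^T x s.t. x + s = b, s >= 0.
    return [min(-q_i, b_i) for q_i, b_i in zip(q, b)]


def loop_batched_solve(
    q_batch: Sequence[Sequence[float]], b_batch: Sequence[Sequence[float]]
) -> List[List[float]]:
    # Solve each (q, b) pair independently and stack the results in order,
    # which matches what jax.vmap returns with its default in_axes=0.
    return [solve_one(q, b) for q, b in zip(q_batch, b_batch)]


q_batch = [[1.0, 1.0], [2.0, 1.0], [1.0, 2.0], [2.0, 2.0]]
b_batch = [[0.5, 0.5]] * 4
x_batch = loop_batched_solve(q_batch, b_batch)
print(x_batch)  # [[-1.0, -1.0], [-2.0, -1.0], [-1.0, -2.0], [-2.0, -2.0]]
```

Row k of the output corresponds to row k of q_batch and b_batch. This is also how the fields of the stacked JaxSolution line up.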

JIT Compilation

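The solve method is JIT-compiled by default (jit=True). For manual control:

import jax
from moreau.jax import Solver

# Disable JIT at construction
solver = Solver(n=2, m=2, ..., cones=cones, jit=False)
solver.setup(P_data, A_data)  # needed for the two-argument solve(q, b) below

# Or JIT your own wrapper
@jax.jit
def my_solve(q, b):
    return solver.solve(q, b)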

Data Types

JaxSolution

class moreau.jax.JaxSolution

NamedTuple containing solution arrays.

x: jax.Array

Primal solution

z: jax.Array

Dual variables

s: jax.Array

Slack variables
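The three arrays can be checked against each other. Assume the standard form Ax + s = b, with s and z in the nonnegative cone. Under that assumption, a solution should have a near-zero primal residual max|Ax + s - b| and a near-zero complementarity s^T z. The helpers below compute both from CSR data; they are hypothetical and not part of moreau. The x, s, and z values are what the full-signature example would produce under that assumption. In practice, take them from solution.x, solution.s, and solution.z, converted to Python lists.

```python
from typing import List, Sequence


def csr_matvec(
    row_offsets: Sequence[int],
    col_indices: Sequence[int],
    data: Sequence[float],
    x: Sequence[float],
) -> List[float]:
    """Compute y = M x for a matrix M stored in CSR form."""
    y = []
    for r in range(len(row_offsets) - 1):
        acc = 0.0
        for k in range(row_offsets[r], row_offsets[r + 1]):
            acc += data[k] * x[col_indices[k]]
        y.append(acc)
    return y


def primal_residual(
    row_offsets: Sequence[int],
    col_indices: Sequence[int],
    A_data: Sequence[float],
    x: Sequence[float],
    s: Sequence[float],
    b: Sequence[float],
) -> float:
    # max |(Ax + s - b)_i|, assuming the equality form Ax + s = b.
    ax = csr_matvec(row_offsets, col_indices, A_data, x)
    return max(abs(ax_i + s_i - b_i) for ax_i, s_i, b_i in zip(ax, s, b))


def complementarity(s: Sequence[float], z: Sequence[float]) -> float:
    # s^T z, which should be near zero at an optimum.
    return sum(s_i * z_i for s_i, z_i in zip(s, z))


# Values the full-signature example would produce under the assumed form.
x = [-1.0, -1.0]
s = [1.5, 1.5]
z = [0.0, 0.0]
print(primal_residual([0, 1, 2], [0, 1], [1.0, 1.0], x, s, [0.5, 0.5]))  # 0.0
print(complementarity(s, z))  # 0.0
```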

JaxSolveInfo

class moreau.jax.JaxSolveInfo

NamedTuple containing solver metadata.

status: int

Solver status (SolverStatus enum value)

obj_val: float

Objective value

iterations: int

Number of interior-point method (IPM) iterations

solve_time: float

Solve time in seconds