Batched Solving

Moreau excels at solving many problems with shared structure. The CompiledSolver API is optimized for high-throughput batched workloads.

When to Use Batching

Use CompiledSolver when you have:

  • Multiple problems with the same sparsity pattern

  • A need for parallel GPU execution

  • Training loops where the P/A structure is fixed

# Batched: 1000 similar QPs
settings = moreau.Settings(batch_size=1000, device='cuda')
solver = moreau.CompiledSolver(n, m, ..., settings=settings)
solver.setup(P_values, A_values)
solutions = solver.solve(qs, bs)  # Solves all 1000 in parallel

Three-Step API

CompiledSolver uses a three-step pattern:

1. Construct with Structure

Define the problem structure using CSR format:

import moreau
import numpy as np

cones = moreau.Cones(num_zero_cones=1, num_nonneg_cones=2)
settings = moreau.Settings(batch_size=64)

solver = moreau.CompiledSolver(
    n=2,  # Number of variables
    m=3,  # Number of constraints

    # P matrix structure (CSR)
    P_row_offsets=[0, 1, 2],   # Row pointers, length n+1
    P_col_indices=[0, 1],      # Column indices, one per nonzero

    # A matrix structure (CSR)
    A_row_offsets=[0, 2, 3, 4],  # Row pointers, length m+1
    A_col_indices=[0, 1, 0, 1],  # Column indices, one per nonzero

    cones=cones,
    settings=settings,
)
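
If it helps to see what these arrays encode, the patterns can be reconstructed with scipy (a quick sanity check only; scipy is not required by the solver itself):

import numpy as np
from scipy.sparse import csr_matrix

# Rebuild the patterns with placeholder values of 1.0 (structure only)
P = csr_matrix((np.ones(2), [0, 1], [0, 1, 2]), shape=(2, 2))
A = csr_matrix((np.ones(4), [0, 1, 0, 1], [0, 2, 3, 4]), shape=(3, 2))

print(P.toarray())  # [[1. 0.]
                    #  [0. 1.]]
print(A.toarray())  # [[1. 1.]
                    #  [1. 0.]
                    #  [0. 1.]]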

2. Setup Matrix Values

Set the numerical values for the P and A matrices:

# Shared across batch (1D)
P_values = np.array([1.0, 1.0])
A_values = np.array([1.0, 1.0, 1.0, 1.0])
solver.setup(P_values, A_values)

# Or per-problem (2D): shape (batch_size, nnz)
P_values_batch = np.tile([1.0, 1.0], (64, 1))
A_values_batch = np.tile([1.0, 1.0, 1.0, 1.0], (64, 1))
solver.setup(P_values_batch, A_values_batch)
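
Per-problem values are the natural choice when the data varies across the batch while the pattern stays fixed, for example randomly scaled diagonals (a sketch building on the arrays above):

# Same diagonal pattern, but a different scale for each problem
scales = np.random.uniform(0.5, 2.0, size=(64, 1))
P_values_batch = scales * np.array([1.0, 1.0])  # Shape (64, 2)
solver.setup(P_values_batch, A_values_batch)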

3. Solve with Parameters

Solve the batch with the q and b vectors:

# Shape: (batch_size, n) and (batch_size, m)
qs = np.random.randn(64, 2)
bs = np.ones((64, 3))

solution = solver.solve(qs, bs)

# Results have batch dimension
print(solution.x.shape)  # (64, 2)
print(solution.z.shape)  # (64, 3)

Value Broadcasting

Both setup() and solve() support broadcasting:

Shared Values (1D input)

# Same P and A for all problems
solver.setup(
    P_values=[1.0, 1.0],      # Shape (nnz_P,) - shared
    A_values=[1.0, 1.0, 1.0, 1.0],  # Shape (nnz_A,) - shared
)

Per-Problem Values (2D input)

# Different P and A for each problem
solver.setup(
    P_values=P_batch,   # Shape (batch, nnz_P)
    A_values=A_batch,   # Shape (batch, nnz_A)
)
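
The same convention should extend to solve() (a sketch, assuming solve() broadcasts 1D inputs the way setup() does):

# Shared q across the batch, per-problem b
q_shared = np.array([1.0, -1.0])   # Shape (n,) - shared
bs = np.ones((64, 3))              # Shape (batch, m) - per-problem
solution = solver.solve(q_shared, bs)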

Accessing Batch Results

The solution object and the solver info both carry a batch dimension:

solution = solver.solve(qs, bs)
info = solver.info

# Solutions have batch dimension
x = solution.x  # Shape (batch, n)
z = solution.z  # Shape (batch, m)
s = solution.s  # Shape (batch, m)

# Info has list of statuses
statuses = info.status  # List of SolverStatus
obj_vals = info.obj_val  # Array of shape (batch,)

# Check specific problem
print(f"Problem 0: status={statuses[0]}, obj={obj_vals[0]}")

Indexing Solutions

You can index into batched solutions:

# Get solution for problem 5
sol_5 = solution[5]
print(sol_5.x)  # Shape (n,)
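
Combined with integer indexing, this gives a simple way to walk through individual problems:

# Inspect the first three problems one at a time
for i in range(3):
    sol_i = solution[i]
    print(f"Problem {i}: x = {sol_i.x}")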

Re-solving with New Data

The compiled structure supports efficient re-solving:

# Solve first batch
solution1 = solver.solve(qs_batch1, bs_batch1)

# Update values and solve again
solver.setup(new_P_values, new_A_values)
solution2 = solver.solve(qs_batch2, bs_batch2)
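
Only the numerical values may change between setup() calls; the sparsity pattern is fixed at construction, so new values must match the compiled nnz counts. For the structure from the three-step example:

# New values must match the compiled pattern (nnz_P = 2, nnz_A = 4)
new_P_values = np.array([2.0, 0.5])
new_A_values = np.array([1.0, -1.0, 2.0, 3.0])
solver.setup(new_P_values, new_A_values)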

GPU Performance Tips

For maximum GPU throughput:

  1. Use large batch sizes: GPU parallelism scales with batch size

    settings = moreau.Settings(batch_size=1024, device='cuda')
    
  2. Pre-allocate: Reuse the same solver object

    for batch in data_loader:
        solver.setup(batch.P, batch.A)
        solution = solver.solve(batch.q, batch.b)
    
  3. Avoid host transfers: Keep data on GPU when using PyTorch/JAX

    from moreau.torch import Solver
    solver = Solver(n, m, ..., settings=settings)  # construct with the problem structure (see CompiledSolver above)

    # Tensors stay on GPU throughout
    solver.setup(P_values.cuda(), A_values.cuda())
    solution = solver.solve(q.cuda(), b.cuda())
    

Memory Considerations

Batch size affects memory usage:

Memory ~ batch_size * (n + 2*m) * 8 bytes (float64), covering the x, z, and s iterates

For large problems, balance batch size against available memory:

# Estimate memory for batch
n, m, batch_size = 1000, 2000, 512
memory_mb = batch_size * (n + 2*m) * 8 / 1e6
print(f"Estimated memory: {memory_mb:.1f} MB")