Examples
Practical examples demonstrating Moreau for common optimization tasks.
Quick Examples
Simple QP
A basic quadratic program:
import moreau
import numpy as np
from scipy import sparse
# minimize   (1/2)||x||^2 + q'x
# subject to Ax <= b
n, m = 10, 5
# The objective is (1/2)x'Px + q'x; with P = I this is (1/2)||x||^2 + q'x.
P = sparse.eye(n, format='csr')
q = np.random.randn(n)
A = sparse.random(m, n, density=0.3, format='csr')
# b > 0, so x = 0 always satisfies Ax <= b and the problem is feasible.
b = np.ones(m)
# Each of the m rows of Ax <= b is a nonnegative-cone constraint.
cones = moreau.Cones(num_nonneg_cones=m)
solver = moreau.Solver(P, q, A, b, cones=cones)
solution = solver.solve()
print(f"Status: {solver.info.status}")
print(f"x = {solution.x}")
Least Squares with Constraints
import moreau
import numpy as np
from scipy import sparse
# Nonnegative least squares:
#   minimize   ||Cx - d||^2
#   subject to x >= 0
n, p = 20, 50
C = np.random.randn(p, n)
d = np.random.randn(p)
# Expand the objective: ||Cx - d||^2 = x'(C'C)x - 2(C'd)'x + d'd.
# The solver minimizes (1/2)x'Px + q'x (same convention as the Simple QP
# example), so P = 2C'C and q = -2C'd. The constant d'd does not change the
# minimizer and is dropped.
P = sparse.csr_matrix(2 * (C.T @ C))
P = (P + P.T) / 2  # Remove floating-point asymmetry from the product
q = -2 * C.T @ d
# Constraints take the form Ax <= b with nonnegative cones, so x >= 0 is
# written as -x <= 0 (A = -I, b = 0).
A = -sparse.eye(n, format='csr')
b = np.zeros(n)
cones = moreau.Cones(num_nonneg_cones=n)
solver = moreau.Solver(P, q, A, b, cones=cones)
solution = solver.solve()
PyTorch Training Loop
import torch
from moreau.torch import Solver
import moreau
# Setup solver with CSR structure: P and A both have one nonzero per row,
# on the diagonal (row i -> column i).
cones = moreau.Cones(num_nonneg_cones=2)
# NOTE(review): device='cuda' requires a CUDA-capable GPU. Also confirm how
# batch_size=32 interacts with the unbatched (shape (2,)) tensors used below.
settings = moreau.Settings(device='cuda', batch_size=32)
solver = Solver(
    n=2,
    m=2,
    P_row_offsets=torch.tensor([0, 1, 2]),
    P_col_indices=torch.tensor([0, 1]),
    A_row_offsets=torch.tensor([0, 1, 2]),
    A_col_indices=torch.tensor([0, 1]),
    cones=cones,
    settings=settings,
)
# Matrix values and target
P_values = torch.tensor([1.0, 1.0], dtype=torch.float64)
A_values = torch.tensor([1.0, 1.0], dtype=torch.float64)
b = torch.tensor([0.5, 0.5], dtype=torch.float64)
target = torch.tensor([0.3, 0.3], dtype=torch.float64)
# Learnable parameter: the linear cost q is trained so that the solution
# x moves toward `target`.
q_param = torch.randn(2, dtype=torch.float64, requires_grad=True)
optimizer = torch.optim.Adam([q_param], lr=0.01)
for epoch in range(100):
    optimizer.zero_grad()
    solver.setup(P_values, A_values)
    solution = solver.solve(q_param, b)
    loss = (solution.x - target).pow(2).mean()
    # Backpropagate through the solve to obtain d(loss)/d(q_param).
    loss.backward()
    optimizer.step()
JAX with vmap
import jax
import jax.numpy as jnp
from moreau.jax import Solver
import moreau
# Two variables and two nonnegative-cone constraints. Both P and A use a
# diagonal CSR sparsity pattern: row i stores a single entry in column i.
cones = moreau.Cones(num_nonneg_cones=2)
solver = Solver(
    n=2,
    m=2,
    P_row_offsets=jnp.array([0, 1, 2]),
    P_col_indices=jnp.array([0, 1]),
    A_row_offsets=jnp.array([0, 1, 2]),
    A_col_indices=jnp.array([0, 1]),
    cones=cones,
)
# Nonzero values for the patterns above, loaded once and shared by every
# solve in the batch.
P_data = jnp.array([1.0, 1.0])
A_data = jnp.array([1.0, 1.0])
solver.setup(P_data, A_data)
# vmap maps solve over the leading axis of q and b: four problems at once.
batched_solve = jax.vmap(solver.solve)
q_batch = jnp.array(
    [
        [1.0, 1.0],
        [2.0, 1.0],
        [1.0, 2.0],
        [2.0, 2.0],
    ]
)
# Every problem in the batch uses the same right-hand side b = (0.5, 0.5).
b_batch = jnp.tile(jnp.array([0.5, 0.5]), (4, 1))
solutions = batched_solve(q_batch, b_batch)
print(solutions.x.shape)  # (4, 2): one solution vector per batch entry
Application Areas
| Domain | Example Problem |
|---|---|
| Finance | Portfolio optimization, risk management |
| Control | Model predictive control, LQR |
| Robotics | Motion planning, inverse kinematics |
| ML | Constrained learning, optimal transport |
| Signal Processing | Sparse reconstruction, denoising |