Skip to content
Draft
214 changes: 142 additions & 72 deletions bindings/python/proxsuite/torch/qplayer.py
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,10 @@ def QPFunction(
class QPFunctionFn(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)
if len(Q_.size()) == 3:
nBatch = Q_.size(0)
else:
nBatch = 1
Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
G, _ = expandParam(G_, nBatch, 3)
Expand All @@ -103,8 +106,13 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
ctx.vector_of_qps = proxsuite.proxqp.dense.BatchQP()

ctx.nBatch = nBatch

_, nineq, nz = G.size()
do_neq = True
if len(G.size()) == 3 or len(G.size()) == 2:
nineq, nz = G.size()[1:]
else:
nineq = 0
nz = Q.size()[-1]
do_neq = False
neq = A.size(1) if A.nelement() > 0 else 0
assert neq > 0 or nineq > 0
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz
Expand All @@ -121,7 +129,7 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
qp = ctx.vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
qp.settings.primal_infeasibility_solving = False
qp.settings.max_iter = maxIter
qp.settings.max_iter_in = 100
qp.settings.max_iter_in = 1000
default_rho = 5.0e-5
qp.settings.default_rho = default_rho
qp.settings.refactor_rho_threshold = default_rho # no refactorization
Expand All @@ -134,21 +142,20 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if p[i] is not None:
p__ = p[i].cpu().numpy()
G__ = None
if G[i] is not None:
if do_neq and G[i] is not None:
G__ = G[i].cpu().numpy()
u__ = None
if u[i] is not None:
if do_neq and u[i] is not None:
u__ = u[i].cpu().numpy()
l__ = None
if l[i] is not None:
if do_neq and l[i] is not None:
l__ = l[i].cpu().numpy()
A__ = None
if Ai is not None:
A__ = Ai.cpu().numpy()
b__ = None
if bi is not None:
b__ = bi.cpu().numpy()

qp.init(
H=H__, g=p__, A=A__, b=b__, C=G__, l=l__, u=u__, rho=default_rho
)
Expand Down Expand Up @@ -255,8 +262,22 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus):
class QPFunctionFn_infeas(Function):
@staticmethod
def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
n_in, nz = G_.size() # true double-sided inequality size
nBatch = extract_nBatch(Q_, p_, A_, b_, G_, l_, u_)

do_neq = True
if len(G_.size()) == 3:
_, n_in, nz = G_.size()
elif len(G_.size()) == 2:
n_in = G_.size()[-2]
nz = G_.size()[-1]
else:
n_in = Q_.size()[-1]
nz = Q_.size()[-1]
do_neq = False
ctx.G_size = G_.size()
if len(Q_.size()) == 3:
nBatch = Q_.size(0)
else:
nBatch = 1

Q, _ = expandParam(Q_, nBatch, 3)
p, _ = expandParam(p_, nBatch, 2)
Expand All @@ -268,32 +289,43 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):

h = torch.cat((-l, u), axis=1) # single-sided inequality
G = torch.cat((-G, G), axis=1) # single-sided inequality

_, nineq, nz = G.size()
neq = A.size(1) if A.nelement() > 0 else 0
if len(G.size()) == 3:
_, nineq, nz = G.size()
else:
nineq = 0
nz = Q.size()[-1]
if len(A.size()) == 3 or len(A.size()) == 2:
neq = A.size(-2) if A.nelement() > 0 else 0
else:
neq = 0
assert neq > 0 or nineq > 0
ctx.neq, ctx.nineq, ctx.nz = neq, nineq, nz

zhats = torch.empty((nBatch, ctx.nz), dtype=Q.dtype)
nus = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
nus_sol = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # double-sided inequality multiplier
if do_neq:
nus_sol = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # double-sided inequality multiplier
else:
nus_sol = None
lams = (
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.empty()
else torch.tensor([])
)
s_e = (
torch.empty(nBatch, ctx.neq, dtype=Q.dtype)
if ctx.neq > 0
else torch.empty()
else torch.tensor([])
)
slacks = torch.empty((nBatch, ctx.nineq), dtype=Q.dtype)
s_i = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # this one is of size the one of the original n_in

if do_neq:
s_i = torch.empty(
(nBatch, n_in), dtype=Q.dtype
) # this one is of size the one of the original n_in
else:
s_i = None
vector_of_qps = proxsuite.proxqp.dense.BatchQP()

ctx.cpu = os.cpu_count()
Expand All @@ -305,24 +337,27 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
qp = vector_of_qps.init_qp_in_place(ctx.nz, ctx.neq, ctx.nineq)
qp.settings.primal_infeasibility_solving = True
qp.settings.max_iter = maxIter
qp.settings.max_iter_in = 100
qp.settings.max_iter_in = 1000
default_rho = 5.0e-5
qp.settings.default_rho = default_rho
qp.settings.refactor_rho_threshold = default_rho # no refactorization
qp.settings.eps_abs = eps
Ai, bi = (A[i], b[i]) if neq > 0 else (None, None)

H__ = None
if Q[i] is not None:
H__ = Q[i].cpu().numpy()
p__ = None
if p[i] is not None:
p__ = p[i].cpu().numpy()
G__ = None
if G[i] is not None:
if do_neq and G[i] is not None:
G__ = G[i].cpu().numpy()
u__ = None
if h[i] is not None:
if do_neq and h[i] is not None:
u__ = h[i].cpu().numpy()
if not do_neq:
l = None
# l__ = None
# if (l[i] is not None):
# l__ = l[i].cpu().numpy()
Expand All @@ -332,7 +367,6 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
b__ = None
if bi is not None:
b__ = bi.cpu().numpy()

qp.init(H=H__, g=p__, A=A__, b=b__, C=G__, l=l, u=u__, rho=default_rho)

if proxqp_parallel:
Expand All @@ -348,16 +382,18 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
if nineq > 0:
# we re-convert the solution to a double sided inequality QP
slack = -h[i] + G[i] @ vector_of_qps.get(i).results.x
nus_sol[i] = torch.Tensor(
-vector_of_qps.get(i).results.z[:n_in]
+ vector_of_qps.get(i).results.z[n_in:]
) # de-projecting this one may provoke loss of information when using inexact solution
if do_neq:
nus_sol[i] = torch.Tensor(
-vector_of_qps.get(i).results.z[:n_in]
+ vector_of_qps.get(i).results.z[n_in:]
) # de-projecting this one may provoke loss of information when using inexact solution
nus[i] = torch.tensor(vector_of_qps.get(i).results.z)
slacks[i] = slack.clone().detach()
s_i[i] = torch.tensor(
-vector_of_qps.get(i).results.si[:n_in]
+ vector_of_qps.get(i).results.si[n_in:]
)
if do_neq:
s_i[i] = torch.tensor(
-vector_of_qps.get(i).results.si[:n_in]
+ vector_of_qps.get(i).results.si[n_in:]
)
if neq > 0:
lams[i] = torch.tensor(vector_of_qps.get(i).results.y)
s_e[i] = torch.tensor(vector_of_qps.get(i).results.se)
Expand All @@ -371,7 +407,10 @@ def forward(ctx, Q_, p_, A_, b_, G_, l_, u_):
@staticmethod
def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
zhats, s_e, Q, p, G, l, u, A, b = ctx.saved_tensors
nBatch = extract_nBatch(Q, p, A, b, G, l, u)
if len(Q.size()) == 3:
nBatch = Q.size(0)
else:
nBatch = 1

Q, Q_e = expandParam(Q, nBatch, 3)
p, p_e = expandParam(p, nBatch, 2)
Expand Down Expand Up @@ -414,7 +453,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):

for i in range(nBatch):
Q_i = Q[i].numpy()
C_i = G[i].numpy()
C_i = None
if G is not None and G.numel() != 0:
C_i = G[i].numpy()
A_i = None
if A is not None:
if A.shape[0] != 0:
Expand All @@ -436,16 +477,17 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
if neq > 0:
kkt[:dim, dim : dim + n_eq] = A_i.transpose()
kkt[dim : dim + n_eq, :dim] = A_i
kkt[
dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq
] = -np.eye(n_eq)
kkt[dim + n_eq + n_in : dim + 2 * n_eq + n_in, dim : dim + n_eq] = (
-np.eye(n_eq)
)
kkt[
dim + n_eq + n_in : dim + 2 * n_eq + n_in,
dim + n_eq + 2 * n_in : 2 * dim + n_eq + 2 * n_in,
] = A_i

kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose()
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i
if n_in > 0:
kkt[:dim, dim + n_eq : dim + n_eq + n_in] = C_i.transpose()
kkt[dim + n_eq : dim + n_eq + n_in, :dim] = C_i

D_1_c = np.eye(n_in) # represents [s_i]_- + z_i < 0
D_1_c[P_1, P_1] = 0.0
Expand Down Expand Up @@ -485,9 +527,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
rhs[dim + n_eq : dim + n_eq + n_in_sol][~active_set] = dl_dnus[
i
][~active_set]
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][
active_set
] = -dl_dnus[i][active_set]
rhs[dim + n_eq + n_in_sol : dim + n_eq + n_in][active_set] = (
-dl_dnus[i][active_set]
)
if dl_ds_e is not None:
if dl_ds_e.shape[0] != 0:
rhs[dim + n_eq + n_in : dim + 2 * n_eq + n_in] = -dl_ds_e[i]
Expand Down Expand Up @@ -515,9 +557,9 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):

qp.settings.primal_infeasibility_solving = True
qp.settings.eps_abs = eps_backward
qp.settings.max_iter = 10
qp.settings.default_rho = 1.0e-3
qp.settings.refactor_rho_threshold = 1.0e-3
qp.settings.max_iter = 1000
qp.settings.default_rho = 5.0e-5
qp.settings.refactor_rho_threshold = 5.0e-5
qp.init(
H,
g,
Expand All @@ -542,13 +584,19 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
)
if n_eq > 0:
dlam[i] = torch.from_numpy(
np.float64(vector_of_qps.get(i).results.x[dim : dim + n_eq])
vector_of_qps.get(i)
.results.x[dim : dim + n_eq]
.astype(np.float64)
)
dnu[i] = torch.from_numpy(
np.float64(
vector_of_qps.get(i).results.x[dim + n_eq : dim + n_eq + n_in]

if dnu is not None:
dnu[i] = torch.from_numpy(
np.float64(
vector_of_qps.get(i).results.x[
dim + n_eq : dim + n_eq + n_in
]
)
)
)
dim_ = 0
if n_eq > 0:
b_5[i] = torch.from_numpy(
Expand All @@ -566,16 +614,18 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
)

dps = dx
dGs = (
bger(dnu.double(), zhats.double())
+ bger(ctx.nus.double(), dx.double())
+ bger(P_2_c_s_i.double(), b_6.double())
)
if G_e:
dGs = dGs.mean(0)
dhs = -dnu
if h_e:
dhs = dhs.mean(0)
dGs = None
if dnu is not None:
dGs = (
bger(dnu.double(), zhats.double())
+ bger(ctx.nus.double(), dx.double())
+ bger(P_2_c_s_i.double(), b_6.double())
)
if G_e:
dGs = dGs.mean(0)
dhs = -dnu
if h_e:
dhs = dhs.mean(0)
if neq > 0:
dAs = (
bger(dlam.double(), zhats.double())
Expand All @@ -597,16 +647,36 @@ def backward(ctx, dl_dzhat, dl_dlams, dl_dnus, dl_ds_e, dl_ds_i):
if p_e:
dps = dps.mean(0)

grads = (
dQs,
dps,
dAs,
dbs,
dGs[n_in_sol:, :],
-dhs[:n_in_sol],
dhs[n_in_sol:],
)

if len(ctx.G_size) == 2:
grads = (
dQs,
dps,
dAs,
dbs,
dGs[n_in_sol:, :],
-dhs[:n_in_sol],
dhs[n_in_sol:],
)
elif len(ctx.G_size) == 3:
grads = (
dQs,
dps,
dAs,
dbs,
dGs[:, n_in_sol:, :],
-dhs[:, :n_in_sol],
dhs[:, n_in_sol:],
)
else:
grads = (
dQs,
dps,
dAs,
dbs,
None,
None,
None,
)
return grads

if structural_feasibility:
Expand Down
Loading