osqp / osqpth

The differentiable OSQP solver layer for PyTorch

Quick note on calling into OSQP again for the backward pass #3

Closed: bamos closed this issue 5 years ago

bamos commented 5 years ago

Just leaving this here in case it's useful in the future (it probably won't be).

I just tried replacing the sparse solve in the backward pass with a second OSQP solve, to see whether reusing the factorizations helps. On some of the tests I'm looking at, it seems to be roughly twice as slow. The relevant part of the backward pass is:

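```python
import numpy as np

# Reuse the per-sample OSQP solver object set up in the forward pass
m = self.solvers[i]
# Active constraints become equalities (l = u = 0); inactive ones are dropped
l = np.zeros(self.m)
u = np.zeros(self.m)
l[ind_inactive] = -np.inf
u[ind_inactive] = np.inf
# Re-solve with the incoming gradient dl_dx as the (negated) linear term
m.update(q=-dl_dx.squeeze(), l=l, u=u)
result = m.solve()
r_x = result.x
r_y = result.y
# Inactive constraints carry no multiplier
r_y[ind_inactive] = 0.
# Multipliers of the constraints active at their lower / upper bound
r_yl = r_y[ind_low]
r_yu = r_y[ind_upp]
```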
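For comparison, the OSQP re-solve above is equivalent to one direct sparse solve of a reduced KKT system: active constraints are pinned to zero, inactive ones are dropped, and OSQP's optimality convention `P x + q + A^T y = 0` is applied with `q = -dl_dx`. The following is a minimal sketch under those assumptions, not osqpth's actual backward code; the helper name `backward_kkt_solve`, the integer index array `ind_active`, and the choice of `scipy.sparse.linalg.spsolve` are all hypothetical.

```python
import numpy as np
import scipy.sparse as sp
import scipy.sparse.linalg as spla


def backward_kkt_solve(P, A, dl_dx, ind_active):
    """Directly solve the reduced KKT system behind the OSQP re-solve above.

    Hypothetical helper (not osqpth's code). Active constraints (integer
    indices in ind_active, assumed non-empty) are pinned to A_act @ r_x = 0
    and inactive ones are dropped. With OSQP's convention P x + q + A^T y = 0
    and q = -dl_dx, the unknowns (r_x, r_y_act) satisfy

        [P      A_act^T] [r_x    ]   [dl_dx]
        [A_act  0      ] [r_y_act] = [  0  ]
    """
    P = sp.csc_matrix(P)
    A_act = sp.csr_matrix(A)[ind_active]  # keep only the active rows
    n, n_act = P.shape[0], A_act.shape[0]
    # bmat fills the None block with an n_act x n_act zero block
    K = sp.bmat([[P, A_act.T], [A_act, None]], format="csc")
    rhs = np.concatenate([np.asarray(dl_dx, dtype=float), np.zeros(n_act)])
    sol = spla.spsolve(K, rhs)
    r_x = sol[:n]
    r_y = np.zeros(A.shape[0])  # inactive multipliers stay zero
    r_y[ind_active] = sol[n:]
    return r_x, r_y
```

Because OSQP's ADMM iterations only return a solution to within its tolerances, the re-solve and a direct factorization like this should agree only approximately; the direct solve gives the reduced system's solution up to floating-point error.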

And line-profiling this gives:

```
Line #      Hits         Time  Per Hit   % Time  Line Contents
==============================================================
   175        10    1017499.0 101749.9     67.9              m.update(q=-dl_dx.squeeze(), l=l, u=u)
   176        10     386027.0  38602.7     25.8              result = m.solve()
```
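In the profile, Per Hit is Time divided by Hits. Each `m.update` call therefore costs roughly 2.6 times as much as the `m.solve` that follows it, and the two lines together account for 93.7% of the profiled function. The snippet below only redoes that arithmetic from the two rows; `parse_rows` is a hypothetical helper, and Time is in line_profiler's timer units because the unit line is not part of the output above. One plausible but unverified explanation for the expensive `update` is this: OSQP gives equality constraints a larger step size rho, so switching constraints between `l = u = 0` and `(-inf, inf)` can change its internal rho vector and trigger a refactorization inside `update`. If so, the forward-pass factorization would not actually be reused.

```python
# Rows copied verbatim from the line_profiler output above.
# Time and Per Hit are in line_profiler timer units (unit line not shown).
PROFILE_ROWS = """\
   175        10    1017499.0 101749.9     67.9              m.update(q=-dl_dx.squeeze(), l=l, u=u)
   176        10     386027.0  38602.7     25.8              result = m.solve()
"""


def parse_rows(text):
    """Hypothetical helper: split each line_profiler row into its columns."""
    rows = []
    for line in text.splitlines():
        # The first five columns are numeric; the rest is the source line
        lineno, hits, time, per_hit, pct, code = line.split(maxsplit=5)
        rows.append({
            "line": int(lineno),
            "hits": int(hits),
            "time": float(time),
            "per_hit": float(per_hit),
            "pct": float(pct),
            "code": code,
        })
    return rows


update_row, solve_row = parse_rows(PROFILE_ROWS)
# Per Hit is just Time / Hits
assert abs(update_row["time"] / update_row["hits"] - update_row["per_hit"]) < 1e-6
ratio = update_row["per_hit"] / solve_row["per_hit"]
share = update_row["pct"] + solve_row["pct"]
print(f"update costs {ratio:.1f}x a solve; together {share:.1f}% of the function")
```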