locuslab / optnet

OptNet: Differentiable Optimization as a Layer in Neural Networks
Apache License 2.0

Sudoku experiment #1

Closed robotsorcerer closed 7 years ago

robotsorcerer commented 7 years ago

I tried running the sudoku example. My command line arguments are as follows:

python main.py --boardSz 2 --batchSz 100 --testBatchSz 100 --nEpoch 50 --testPct 0.1 --save SAVE --work 'work' optnet

I get errors arising from models.py:

TypeError: super() takes at least 1 argument (0 given)
> /home/robotec/Documents/NNs/optnet/sudoku/models.py(80)__init__()
     78     """ Solve a single SQP iteration of the scheduling problem"""
     79     def __init__(self, n, Qpenalty):
---> 80         super().__init__()
     81         nx = (n**2)**3
     82         self.Q = Variable(Qpenalty*torch.eye(nx).double().cuda())

ipdb> 

Did I miss something?

bamos commented 7 years ago

You should be using Python 3.4+ with a source install of the master branch of our customized PyTorch lib (https://github.com/locuslab/pytorch). Can you make sure that you're using both of these?

robotsorcerer commented 7 years ago

Ah! That must be the problem. I am on Python 2.7 and using Soumith's version of pytorch. I will upgrade these and see what happens.
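For readers hitting the same TypeError: the zero-argument form of super() exists only in Python 3, which is why line 80 of models.py fails under Python 2.7. Below is a minimal sketch of the two spellings. The class names Base, LayerPy3 and LayerPy2Compatible are hypothetical and are not taken from models.py.

```python
class Base(object):
    def __init__(self):
        self.ready = True


class LayerPy3(Base):
    def __init__(self, n):
        # Zero-argument form: accepted by Python 3 only.
        # Python 2.7 raises "TypeError: super() takes at least 1 argument (0 given)".
        super().__init__()
        self.n = n


class LayerPy2Compatible(Base):
    def __init__(self, n):
        # Explicit form: accepted by both Python 2 and Python 3.
        super(LayerPy2Compatible, self).__init__()
        self.n = n
```

This only explains the error message. Per the reply above, the repository still expects Python 3.4+ and the customized PyTorch build, so switching interpreters is the actual fix.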

robotsorcerer commented 7 years ago

A quick question: I am not sure what the parameters in the QPFunction denote. For example, the OptNet example in your README is initialized as follows:

class OptNet(nn.Module):
    def __init__(self, nFeatures, nHidden, nCls, bn, nineq=200, neq=0, eps=1e-4):
   ...   ...

I can deduce that nineq stands for the number of inequalities and neq stands for the number of equalities. What do nFeatures and nCls stand for? For example, I want to solve the following inequality-constrained optimization problem:

    minimize \sum_{i = 0}^{i=N} (y(x)_i - \hat{y}(x)_i)^2
       subject to x_j + s_j = 1,  (j = 1, 2, ..., 6)
                       x_k - s_k = 0.1 (k= 1, 2,...6)

where y is a function of x that some previous neural network layers have modeled. The constraints are meant to be added as a penultimate layer before an FC output layer in order to get a realistic output. So I have twelve inequality constraints (augmented by the slack variables s in the equations above). My neq variable would be 0. What do nFeatures and nCls stand for in your code?

bamos commented 7 years ago

Hi @lakehanne - sorry for missing this message earlier. The OptNet model you posted is to solve a classification problem with nFeatures features and nCls classes and is not a general QP form. For your problem, you should use our qpth library directly (which this repo also uses): https://locuslab.github.io/qpth/

The qpth library is currently missing inline source code documentation, but I hope the usage is clear enough from the examples I put on the website for now. Join in on this discussion if you think there's anything unclear in the usage that I should add to the docs: https://github.com/locuslab/qpth/issues/3

-Brandon.
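As a rough illustration of the mapping discussed above: the qpth documentation states its QP form as minimizing 1/2 z^T Q z + p^T z subject to A z = b and G z <= h. The sketch below builds G and h for the twelve constraints in plain Python. It assumes the slack variables are nonnegative and separate for each constraint, so that the two constraint families reduce to the box bounds 0.1 <= x_j <= 1. That reading is an assumption, not something confirmed in the thread. The helper names box_constraints and is_feasible are hypothetical and not part of qpth.

```python
def box_constraints(n, lower, upper):
    """Return (G, h) such that G z <= h encodes lower <= z_i <= upper."""
    G, h = [], []
    for i in range(n):
        row = [0.0] * n
        row[i] = 1.0
        G.append(row)        # z_i <= upper
        h.append(upper)
    for i in range(n):
        row = [0.0] * n
        row[i] = -1.0
        G.append(row)        # -z_i <= -lower, i.e. z_i >= lower
        h.append(-lower)
    return G, h


def is_feasible(G, h, z, tol=1e-9):
    """Check every row of G z <= h, with a small tolerance."""
    return all(
        sum(g * zi for g, zi in zip(row, z)) <= hi + tol
        for row, hi in zip(G, h)
    )


# Six variables bounded between 0.1 and 1 give twelve inequality rows.
G, h = box_constraints(6, 0.1, 1.0)
```

With this reading, nineq would be 12 and neq would be 0. The nested lists would still need to be converted to tensors before being handed to the solver.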

robotsorcerer commented 7 years ago

Thank you for your [albeit late :)] reply. I have since made some progress in understanding your qpth code. There is a btrifact() call that you make in the pre_factor_kkt function. When I run my own code, I get errors of the following sort:

/home/robotec/anaconda2/lib/python2.7/site-packages/qpth-0.0.2-py2.7.egg/qpth/solvers/pdipm/batch.pyc in pre_factor_kkt(Q=
(0 ,.,.) = 

Columns 0 to 8 
   0.1000  0.0000 ...000  0.1000
[torch.DoubleTensor of size 5x12x12]
, G=
(0 ,.,.) = 
  -1   0   0   0   0   0   0   0   ...  0   0   1
[torch.DoubleTensor of size 5x12x12]
, A=
(0 ,.,.) = 

Columns 0 to 8 
   ;... ...
)
    262     # for more details.
    263 
--> 264     G_invQ_GT = torch.bmm(G, G.transpose(1,2).btrisolve(*Q_LU))
    265     R = G_invQ_GT.clone()
    266     S_LU_pivots = torch.IntTensor(range(1,1+neq+nineq)).unsqueeze(0) \

RuntimeError: Unimplemented at /py/conda-bld/pytorch_1490977962696/work/torch/lib/TH/generic/THTensorLapack.c:1067
> /home/robotec/catkin_ws/src/RAL2017/pyrnn/src/build/bdist.linux-x86_64/egg/qpth/solvers/pdipm/batch.py(264)pre_factor_kkt()

I tried checking whether the btrifact() function would work on an ordinary tensor, but it gives the same error. Is there anything special one needs to do when using the original pytorch with qpth? I just reinstalled pytorch with conda and qpth with pip this afternoon.
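For context on the failing line: btrifact and btrisolve are PyTorch's batched LU factorization and LU-based solve, and line 264 uses them to form G Q^{-1} G^T without inverting Q explicitly. Below is a plain-Python sketch of the same factor-then-solve idea for a single matrix. The names lu_factor and lu_solve are hypothetical and are not the PyTorch API, and the pivot convention here is 0-based and chosen only for illustration.

```python
def lu_factor(a):
    """LU-factor a square matrix with partial pivoting.

    Returns (lu, pivots). lu packs L below the diagonal (unit diagonal
    implied) and U on and above it. pivots[k] is the row swapped with
    row k at elimination step k.
    """
    n = len(a)
    lu = [[float(v) for v in row] for row in a]
    pivots = []
    for k in range(n):
        # Choose the row with the largest magnitude entry in column k.
        p = max(range(k, n), key=lambda r: abs(lu[r][k]))
        if lu[p][k] == 0.0:
            raise ValueError("matrix is singular")
        pivots.append(p)
        lu[k], lu[p] = lu[p], lu[k]
        for r in range(k + 1, n):
            lu[r][k] /= lu[k][k]              # multiplier, stored in L
            for c in range(k + 1, n):
                lu[r][c] -= lu[r][k] * lu[k][c]
    return lu, pivots


def lu_solve(lu, pivots, b):
    """Solve A x = b given the factorization returned by lu_factor."""
    n = len(lu)
    x = [float(v) for v in b]
    # Apply the recorded row swaps to the right-hand side.
    for k, p in enumerate(pivots):
        x[k], x[p] = x[p], x[k]
    # Forward substitution with the unit lower-triangular factor.
    for r in range(n):
        for c in range(r):
            x[r] -= lu[r][c] * x[c]
    # Back substitution with the upper-triangular factor.
    for r in reversed(range(n)):
        for c in range(r + 1, n):
            x[r] -= lu[r][c] * x[c]
        x[r] /= lu[r][r]
    return x


# Small system that needs a row swap: 2*y = 4 and x + y = 3.
lu, pivots = lu_factor([[0.0, 2.0], [1.0, 1.0]])
solution = lu_solve(lu, pivots, [4.0, 3.0])
```

The factorization is computed once and then reused for each right-hand side, which is the reason pre_factor_kkt factors Q up front.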

bamos commented 7 years ago

Hi @lakehanne - oops, I accidentally made a typo in the CPU version of btrisolve that makes it throw this unimplemented error when it actually is implemented. I just fixed this in the pytorch source code and sent in a PR to their master branch: https://github.com/pytorch/pytorch/pull/1185.

Can you try again after re-compiling pytorch with this change added?

-Brandon.

bamos commented 7 years ago

Also @lakehanne - the CPU version may be prohibitively slow and you might need access to a GPU to train within a reasonable period of time. I implemented the CPU versions of btrifact and btrisolve for compatibility with deployed models, not for training.

robotsorcerer commented 7 years ago

I know, haha! My problem size is a small adaptive control problem, not the gargantuan deep learning problems you guys solve :)