pnnl-predictive-phenomics / emll

GNU General Public License v2.0
`Scan` operation causing significant overhead #23

Open mcnaughtonadm opened 1 month ago

mcnaughtonadm commented 1 month ago

When applying the linlog model in PyMC inference, we see very long inference times (2 to 10 days), which is too slow for any productive use of the workflow. Using PyMC's built-in profiling methods, we can narrow the issue down to the scan operation in emll, specifically in https://github.com/pnnl-predictive-phenomics/emll/blob/bf4eca3dfb0f20bb4db3df2e61bbdae6ebcfc32e/src/emll/linlog_model.py#L120-L166

This contains the line

        if method == "scan":
            xn, _ = pytensor.scan(
                lambda A, b: self.solve_pytensor(A, b), sequences=[As, bs], strict=True
            )

which is where our code uses scan. We need to find a way to optimize this method of determining xn: it may have worked with Theano, but it is definitely struggling with PyTensor.
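To make the shape of the problem concrete, here is a minimal NumPy sketch (not PyTensor, and not emll's actual solver) of what the scan above does: one solve per (A, b) pair, next to a single batched call over the leading axis. The sizes, the random data, and the assumption that every system is square and nonsingular are all hypothetical; emll's real systems go through `solve_pytensor`, which may behave differently.

```python
import numpy as np

rng = np.random.default_rng(0)

# Hypothetical stand-ins for the As and bs sequences passed to scan:
# 5 independent 3x3 systems.
n_systems, n = 5, 3
# Adding n * I keeps the random matrices away from singular (illustration only).
As = rng.normal(size=(n_systems, n, n)) + n * np.eye(n)
bs = rng.normal(size=(n_systems, n))

# Scan-style: one solve per (A, b) pair, stacked into xn.
xn_loop = np.stack([np.linalg.solve(A, b) for A, b in zip(As, bs)])

# Batched: one call that solves all systems along the leading axis.
# bs gets a trailing axis so each right-hand side is a column vector.
xn_batched = np.linalg.solve(As, bs[..., None])[..., 0]
```

Both paths compute the same xn; the difference is only in how many separate solver calls are dispatched. Whether an equivalent batched formulation is available for emll's solve in PyTensor would need to be checked.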

djinnome commented 1 month ago

Instead of solve_pytensor, should we be using the Cholesky solve? I thought that was the secret sauce. https://github.com/pnnl-predictive-phenomics/emll/blob/bf4eca3dfb0f20bb4db3df2e61bbdae6ebcfc32e/src/emll/linlog_model.py#L266-L271 @pstjohn?

djinnome commented 1 month ago

Nevermind, that is Cholesky solve using scipy. I don't know if pytensor has cholesky solvers.

djinnome commented 1 month ago

Should we try HMC? It has gotten a lot faster

djinnome commented 1 month ago

Nevermind. You still have the same bottleneck

mcnaughtonadm commented 1 month ago

Yea so for my initial "profiling", I ran the following snippet:

model.profile(model.logp()).summary()

which is independent of any inference and only evaluates the log probability of a state of the PyMC model itself. So the slowdown occurs in the model formulation, not in the inference over the model.

But the slowdown does follow the model into the inference step, hence the problem.

mcnaughtonadm commented 1 month ago

> Nevermind, that is Cholesky solve using scipy. I don't know if pytensor has cholesky solvers.

Looking around the PyTensor docs, other solving methods do seem to be available, mainly Cholesky and triangular solves. The original Theano implementation also had access to these, just in a different way. Do you think a Cholesky solve of the system of equations would improve performance over a standard solve?

The CholeskySolve class can be found here. https://github.com/pymc-devs/pytensor/blob/bb028ae2330433755b9d4aa32ab6e8d0c9f662fc/pytensor/tensor/slinalg.py#L237

The standard Solve class, from which the LeastSquaresSolve used by emll inherits, is here: https://github.com/pymc-devs/pytensor/blob/bb028ae2330433755b9d4aa32ab6e8d0c9f662fc/pytensor/tensor/slinalg.py#L366
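For reference, a small SciPy sketch of what a Cholesky solve of a least-squares problem can look like, compared against a general least-squares routine. This is an illustration under assumptions, not emll's implementation: the matrix `A`, the vector `b`, and their sizes are made up, and the approach only applies when `A` has full column rank so that `A.T @ A` is symmetric positive definite.

```python
import numpy as np
from scipy.linalg import cho_factor, cho_solve

rng = np.random.default_rng(0)

# Hypothetical overdetermined system: 6 equations, 3 unknowns.
A = rng.normal(size=(6, 3))
b = rng.normal(size=6)

# Normal equations: (A^T A) x = A^T b.
# A^T A is symmetric positive definite when A has full column rank,
# which is the precondition for a Cholesky factorization.
AtA = A.T @ A
Atb = A.T @ b

# Factor once, then solve with two triangular substitutions.
x_chol = cho_solve(cho_factor(AtA), Atb)

# General-purpose least-squares solution for comparison.
x_lstsq, *_ = np.linalg.lstsq(A, b, rcond=None)
```

The two solutions agree on well-conditioned problems. One trade-off to keep in mind: forming `A.T @ A` squares the condition number, so the Cholesky route can lose accuracy on ill-conditioned systems even when it is faster.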

djinnome commented 1 month ago

I would say it is worth a shot. Peter St John felt that the absence of the Cholesky solver in PyTorch and TensorFlow was a reason not to try porting emll to those frameworks.

mcnaughtonadm commented 1 month ago

I am also trying something else that I noticed. In Theano, the Solve class inherits directly from Op. In PyTensor, however, there is a SolveBase class that inherits directly from Op, and Solve inherits from SolveBase. I am checking whether this extra layer causes unnecessary computation by switching emll's LeastSquaresSolve to inherit from SolveBase instead of Solve.