TyXe-BDL / TyXe

MIT License
144 stars, 33 forks

Update to pyro 1.8.1 #18

Closed: icfly2 closed this pull request 2 years ago

icfly2 commented 2 years ago

I've started the process of updating the dependencies and added an example for classification.

Things changed to enable the move to pyro-ppl 1.8.1:

Line 207 of bnn.py triggers:

ValueError: at site "likelihood.data", invalid log_prob shape
  Expected [20], actual [20, 20]
  Try one of the following fixes:
  - enclose the batched tensor in a with pyro.plate(...): context
  - .to_event(...) the distribution being sampled
  - .permute() data dimensions

I haven't found what it refers to: when I step through it, all pyro.sample statements are inside plates. The relevant Pyro example, eight_schools, has also been updated for newer Pyro versions to include these pyro.plate contexts.

I'd appreciate input on the approach for the classification model as well as help with fixing these last two errors.

hpplyt commented 2 years ago

Cool, thanks for starting to look into this!

Regarding the VI test, that should just be a matter of adapting the optimization parameters. I think I set it up fairly coarsely, so if the randomness changed, that could affect the test. I have a similar test sitting around in another codebase; I'll compare the setups.

For the likelihood shape: I think you need to pass in event_dim=1 when you're setting up the obs_model, since your data and network outputs have a trailing dimension of size 1. Alternatively, make y a vector and add a final Squeeze layer to your net. That being said, I'd rather only fix issues related to bumping the Pyro version in this PR and keep new examples for a separate PR, but I'm more than happy to include new examples, of course!
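A minimal sketch of the shape problem, using plain torch.distributions rather than TyXe's likelihood classes. It assumes (a guess, not confirmed in the thread) that the [20, 20] log-prob comes from a network output of shape [20, 1] being broadcast against a target vector of shape [20]. It shows both suggested fixes: keep the trailing dimension on both tensors and declare it an event dimension (the analogue of event_dim=1), or squeeze the output so both are vectors.

```python
import torch
from torch.distributions import Independent, Normal

n = 20
preds = torch.randn(n, 1)  # network output with a trailing dimension of size 1
y_vec = torch.randn(n)     # targets stored as a plain vector

# Mismatch: [20, 1] broadcasts against [20] to give a [20, 20] log-prob.
bad = Normal(preds, 0.1).log_prob(y_vec)
print(bad.shape)  # torch.Size([20, 20])

# Fix 1: keep the trailing dimension on the targets too and treat it as an
# event dimension (Pyro's .to_event(1) wraps the distribution the same way).
y_col = y_vec.unsqueeze(-1)
good_event = Independent(Normal(preds, 0.1), 1).log_prob(y_col)
print(good_event.shape)  # torch.Size([20])

# Fix 2: squeeze the network output so predictions and targets are both vectors.
good_vec = Normal(preds.squeeze(-1), 0.1).log_prob(y_vec)
print(good_vec.shape)  # torch.Size([20])
```

Both fixes give one log-likelihood term per data point, which is what a data plate of size 20 expects.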

Which examples are failing for you specifically? I remember an issue with the PretrainedInitializer in examples/resnet.py when trying to bump to pyro 1.4.0; I think that might be fixable by changing the default of the prefix argument in the from_net function to "net".

That being said, I'd rather just pin the Pyro version to 1.8.1 than support several different versions. If I'm not mistaken, that's also how Pyro handles PyTorch versions, and I think it's only a fairly minor loss of backwards compatibility for a significant simplification of the codebase.

icfly2 commented 2 years ago

I have made a few decisions regarding the new names of the Pyro parameters. It adds "net." to the parameter names, but not for all of the methods that invoke it.

The vcl.py example fails when the prior is updated to the DictPrior. Either a workaround could try both the plain names and "net." + names, or the example needs further adjustments.

There is also a script that tests all the examples except two (though not with every option), to make sure they keep running during future development. Please have a look and let me know where you disagree with the direction. Any help with the bnn test would also be much appreciated.
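The example test script itself isn't shown in the thread. A minimal sketch of what such a smoke test might look like, assuming the examples are standalone scripts in a single directory; the function name run_examples, its parameters, and the skip-list entry are hypothetical, not the script from the PR.

```python
import subprocess
import sys
from pathlib import Path


def run_examples(example_dir, skip=(), timeout=600):
    """Run each *.py script in example_dir in a fresh interpreter.

    Returns the file names of scripts that exited with a non-zero status.
    Scripts whose names are in `skip` are not run.
    """
    failures = []
    for script in sorted(Path(example_dir).glob("*.py")):
        if script.name in skip:
            continue
        result = subprocess.run(
            [sys.executable, str(script)],
            capture_output=True,
            text=True,
            timeout=timeout,
        )
        if result.returncode != 0:
            failures.append(script.name)
    return failures


if __name__ == "__main__":
    # Hypothetical usage: run everything under examples/ except a skip list.
    failed = run_examples("examples", skip={"slow_example.py"})
    if failed:
        sys.exit(f"Failing examples: {failed}")
```

Running each script in a separate interpreter keeps global state (such as Pyro's parameter store) from leaking between examples.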

icfly2 commented 2 years ago

After a bit more digging, it turns out that in newer versions of Pyro net_guide.scale_tril returns a unit lower-triangular matrix (ones on the diagonal) rather than a general lower-triangular one. This is not due to the change from torch.cholesky to torch.linalg.cholesky: I've checked, and both return the same result, as expected and documented.

This can be remedied by multiplying the result by the scale vector, which recreates the old lower-triangular matrix. I don't know where else this change might have implications, and I haven't found where it was documented in the Pyro changelog.
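A torch-only sketch of that fix. It assumes, from how recent Pyro releases appear to parametrize AutoMultivariateNormal (worth checking get_posterior in the installed version), that the full lower-triangular factor is the unit lower-triangular scale_tril with row i multiplied by entry i of the positive scale vector. A random covariance matrix stands in for a fitted guide.

```python
import torch

torch.manual_seed(0)

# A random positive-definite covariance matrix (illustrative data only).
d = 4
a = torch.randn(d, d)
cov = a @ a.T + d * torch.eye(d)

# Old-style parametrization: a general lower-triangular Cholesky factor.
chol = torch.linalg.cholesky(cov)

# Assumed new-style parametrization: a positive scale vector plus a
# unit lower-triangular matrix whose diagonal is all ones.
scale = chol.diagonal()
unit_tril = chol / scale[:, None]

# Recreate the old factor by scaling each row of the unit triangular matrix.
recovered = scale[:, None] * unit_tril

print(torch.allclose(unit_tril.diagonal(), torch.ones(d)))      # True
print(torch.allclose(recovered, chol))                          # True
print(torch.allclose(recovered @ recovered.T, cov, atol=1e-4))  # True
```

Multiplying by scale[:, None] is the same as left-multiplying by diag(scale), so the product stays lower triangular and its diagonal equals the scale vector.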

Tests now pass. The last remaining issue is the DictPrior naming in the vcl example; I'm tempted to leave that for now and fix it after the PR.

hpplyt commented 2 years ago

Thanks for the updates!

> I have made a few decisions regarding the new names of the Pyro parameters. It adds "net." to the parameter names, but not for all of the methods that invoke it.

That's not quite what I had in mind: I was thinking of just changing the default of the prefix argument to "net" rather than introducing a second one. Since it is passed into the call to named_parameters, it will be prepended to every name in the loop, i.e. it takes care of what you're doing manually when assigning to the dictionary. Or have I overlooked something and this doesn't work?
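To illustrate the point about named_parameters: PyTorch's nn.Module.named_parameters accepts a prefix argument and prepends it, together with the "." separator, to every parameter name. The small Sequential network below is an arbitrary stand-in for a BNN's underlying net.

```python
import torch.nn as nn

net = nn.Sequential(nn.Linear(3, 4), nn.ReLU(), nn.Linear(4, 1))

# Names relative to the module itself.
plain = [name for name, _ in net.named_parameters()]

# With prefix="net", PyTorch prepends the prefix and inserts the "." itself.
prefixed = [name for name, _ in net.named_parameters(prefix="net")]

print(plain)     # ['0.weight', '0.bias', '2.weight', '2.bias']
print(prefixed)  # ['net.0.weight', 'net.0.bias', 'net.2.weight', 'net.2.bias']
```

Because the prefix is applied inside the loop over parameters, no manual renaming of dictionary keys is needed afterwards.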

> The vcl.py example fails when the prior is updated to the DictPrior. Either a workaround could try both the plain names and "net." + names, or the example needs further adjustments.

I think this might be fixed by replacing tyxe.util.pyro_sample_sites(bnn.net) with tyxe.util.pyro_sample_sites(bnn) in the script. Let me know if that doesn't do it; something could also be going wrong downstream in the autoguide.

> There is also a script that tests all the examples except two (though not with every option), to make sure they keep running during future development.

Awesome, I'll need to look into setting up continuous integration to run this automatically.

> After a bit more digging, it turns out that in newer versions of Pyro net_guide.scale_tril returns a unit lower-triangular matrix (ones on the diagonal) rather than a general lower-triangular one. This is not due to the change from torch.cholesky to torch.linalg.cholesky: I've checked, and both return the same result, as expected and documented.

Nice catch! Seems like they've changed the parametrization of the AutoMultivariateNormal autoguide's covariance.

> This can be remedied by multiplying the result by the scale vector, which recreates the old lower-triangular matrix.

Yes that should do it.

> I don't know where else this change might have implications

I don't think we're using the multivariate Normal autoguide outside of testing, so that should be it.

> The last remaining issue is the DictPrior naming in the vcl example; I'm tempted to leave that for now and fix it after the PR.

I'd rather not leave things in an inconsistent state: everything should work if we're bumping the Pyro version. I'll need to go through your changes in a bit more detail anyway, and I should hopefully be able to find some time over the weekend, so I can also finish things up if you want to leave it at this. Thanks a lot again for your help already!

icfly2 commented 2 years ago

I've done my best to address your comments.

See the comments in the files for the remaining issues.

I'll contribute GitHub Actions workflows for testing and for releasing to PyPI after this PR, if you like.

hpplyt commented 2 years ago

Left you a few comments. Thanks again for the updates. The big TODOs left at this point for me are:

Let me know if I forgot anything :)

Btw, should I have access to your branch? I think there's a GitHub option to set up PRs that way, in which case I could also make some code changes myself.

icfly2 commented 2 years ago

> Left you a few comments. Thanks again for the updates. The big TODOs left at this point for me are:

Thanks a lot, I'll see what I can get through in the next few days; most are fairly minor. I agree with all your points.

> Btw, should I have access to your branch? I think there's a GitHub option to set up PRs that way, in which case I could also make some code changes myself.

The branch isn't protected, but I guess, given the way GitHub works, you'd have to make PRs against my branch. If this turns into a continuing effort, I could be added as a collaborator on the original repo and make branches there. But if you know of a setting I'm missing, let me know; I'm not that familiar with GitHub.

icfly2 commented 2 years ago

OK, I think I've addressed everything I can. Only the deprecation warnings are left, and then the actual push to PyPI, but I guess those are for after the PR is merged. Let me know if there is anything else I've missed.

hpplyt commented 2 years ago

Thanks! Could you also remove the regression example?

Also, regarding the prefix, I meant having it as an argument of the get_detached_distributions method of the AutoNormal autoguide class. So if you could move that functionality there and revert the changes to DictPrior, that would be great. Sorry if I wasn't clear or am being pedantic, but I think it makes a lot more sense to pull these distribution names correctly out of the guide in case anyone wants to use them downstream. Also have a look at how PyTorch handles prefixes: there's no need to include a "." in the string, just add it when concatenating the prefix with the name. That makes calling the function one character easier :)
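A minimal sketch of the prefix convention described above, written as a hypothetical standalone helper (detached_normals) rather than TyXe's actual get_detached_distributions: the caller passes "net", and the "." is inserted only when the prefix is non-empty, as PyTorch does for module and parameter names. The site names and tensors are made up for illustration.

```python
import torch
from torch.distributions import Normal


def detached_normals(locs, scales, prefix=""):
    """Hypothetical helper: one detached Normal per site, keyed by full name.

    Mirrors PyTorch's convention: the caller passes prefix="net" and the
    "." separator is only inserted when the prefix is non-empty.
    """
    dists = {}
    for name, loc in locs.items():
        full_name = prefix + ("." if prefix else "") + name
        # Detach so gradients do not flow back into the guide's parameters.
        dists[full_name] = Normal(loc.detach(), scales[name].detach())
    return dists


locs = {"0.weight": torch.zeros(2, 3, requires_grad=True)}
scales = {"0.weight": torch.ones(2, 3, requires_grad=True)}

print(sorted(detached_normals(locs, scales)))                # ['0.weight']
print(sorted(detached_normals(locs, scales, prefix="net")))  # ['net.0.weight']
```

Keeping the separator out of the argument means the default of "" produces unprefixed names with no stray leading dot.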

Other little bits to clean up (I can handle those as well after merging):

icfly2 commented 2 years ago

Not 100% sure I follow. These functions are only accessed internally by TyXe functions, and the import in utils is just a renaming. Unless you want to add the Pyro version check back, this won't run anyway. Also, as there is no PyPI package, I'm not sure how you would push the deprecation warning to users before the functions are actually deprecated. But maybe I'm just misunderstanding the whole thing, in which case please just change it as you see fit.

icfly2 commented 2 years ago

Other than the deprecation warning, where I don't quite understand what you're trying to achieve, I think I have addressed all your points. I appreciate the detailed, nitpicking feedback; that is how things progress well.

hpplyt commented 2 years ago

Cool, looks good to merge now. I was playing around with the whole prefix issue for the prior update and found a nicer way of fixing it directly in the BNN class; I'll push that after merging your PR. (I was actually wrong that the issue came from the guide: the net prefix is now handled correctly there out of the box.)

Regarding the deprecation warning: I don't want to remove public-facing functions without a warning. I don't know whether anyone is using them, and if they pull updates to the codebase, keeping the functions with a warning just seems like the safer choice to me than deleting them outright.
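One common way to do this in Python is to keep the old public function and have it emit a DeprecationWarning before delegating to its replacement. A minimal sketch; the decorator and the names old_helper and new_helper are hypothetical stand-ins, not TyXe functions.

```python
import functools
import warnings


def deprecated(replacement):
    """Decorator that keeps a public function working but warns its callers."""
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            warnings.warn(
                f"{func.__name__} is deprecated and will be removed in a future "
                f"release; use {replacement} instead.",
                DeprecationWarning,
                stacklevel=2,  # point the warning at the caller, not this wrapper
            )
            return func(*args, **kwargs)
        return wrapper
    return decorator


def new_helper(x):
    return x * 2


@deprecated("new_helper")
def old_helper(x):  # hypothetical public function kept for compatibility
    return new_helper(x)


with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    result = old_helper(3)

print(result)                       # 6
print(caught[0].category.__name__)  # DeprecationWarning
```

Anyone pulling the update keeps working code and sees which function to switch to; the old name can be removed in a later release.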

Glad you appreciate the feedback, I hope it wasn't too much back and forth about small details!