f-dangel closed this 4 months ago
Hi Felix :)
Thanks a lot for the PR! I re-ran the tests (with the changes you propose) but I couldn't replicate the error.
pytest .\test_cg.py
=================================== test session starts ===================================
platform win32 -- Python 3.12.3, pytest-8.2.2, pluggy-1.5.0
rootdir: C:\Users\Lukas\PromotionOffline\PyTorchHessianFree
configfile: pyproject.toml
collected 117 items
test_cg.py ......................................................................... [ 62%]
............................................ [100%]
=================================== 117 passed in 3.20s ===================================
(pytorch_hf) PS C:\Users\Lukas\PromotionOffline\PyTorchHessianFree\tests> pip show torch
Name: torch
Version: 2.3.0
Summary: Tensors and Dynamic neural networks in Python with strong GPU acceleration
Home-page: https://pytorch.org/
Author: PyTorch Team
Author-email: packages@pytorch.org
License: BSD-3
Location: C:\Users\Lukas\anaconda3\envs\pytorch_hf\Lib\site-packages
Requires: filelock, fsspec, jinja2, mkl, networkx, sympy, typing-extensions
Required-by: backpack-for-pytorch, hessianfree, torchvision, unfoldNd
(pytorch_hf) PS C:\Users\Lukas\PromotionOffline\PyTorchHessianFree\tests> python
Python 3.12.3 | packaged by Anaconda, Inc. | (main, May 6 2024, 19:42:21) [MSC v.1916 64 bit (AMD64)] on win32
Type "help", "copyright", "credits" or "license" for more information.
>>> import torch
>>> torch.get_default_dtype()
torch.float32
>>> exit()
Which pytorch version are you using?
I am using torch==2.2.0 and python==3.9.16, on a Mac. Could be OS-related.
Ok. Could you run test_cg.py with torch.set_default_dtype(torch.float64) to check if this is due to some rounding error?
Yes, that made the tests pass!
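For anyone landing on this thread later, here is a minimal sketch (not from the repo; the tensor sizes and helper name are illustrative) of the kind of single- vs. double-precision rounding gap that can make tight test tolerances fail, and of what torch.set_default_dtype changes globally:

```python
import torch

# Illustrative helper (not from the repo): accumulate the same value in a
# given dtype and measure the drift from the exact result. float32 carries
# a representation error of ~1.5e-9 per element for 0.1, which compounds.
def residual(dtype):
    x = torch.full((10_000,), 0.1, dtype=dtype)
    return abs(x.sum().item() - 10_000 * 0.1)

# Double precision stays orders of magnitude closer to the exact sum.
assert residual(torch.float64) < residual(torch.float32)

# set_default_dtype changes what factory calls produce from here on,
# which is why re-running the test suite under it removes the rounding
# failures without touching the tests themselves.
torch.set_default_dtype(torch.float64)
assert torch.rand(3).dtype == torch.float64
```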
Hi,
I am using your optimizer and setting up the curvature matrix-vector products (mvp) and the gradients (grad) of the step function myself. In that case, one can turn off autodiff when evaluating the loss and model output inside the optimizer, which should reduce memory because the computation graph will not be stored.
Note 1: I was executing the tests and got the following failures (both seem to be unrelated to the lines I changed):
Note 2: I am using contextlib.nullcontext(), which was introduced in Python 3.7. Since Python 3.8 will be deprecated soon, I think this should not be a problem.
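A minimal sketch of the pattern this PR describes (the function name and optimizer signature here are assumptions, not the repo's actual API): when the caller supplies mvp and grad themselves, the internal forward pass can run under torch.no_grad() instead of a nullcontext(), so no graph is kept and memory drops:

```python
import torch
from contextlib import nullcontext

# Hypothetical helper illustrating the idea: pick the context manager
# based on whether the user already provides the gradients externally.
def forward_context(grad_supplied: bool):
    # no_grad() skips graph construction; nullcontext() is a no-op.
    return torch.no_grad() if grad_supplied else nullcontext()

model = torch.nn.Linear(4, 1)
x = torch.rand(8, 4)

with forward_context(grad_supplied=True):
    loss = model(x).square().mean()
assert not loss.requires_grad  # no graph stored, lower memory

with forward_context(grad_supplied=False):
    loss = model(x).square().mean()
assert loss.requires_grad  # graph kept for autodiff inside the optimizer
```

Since nullcontext() is a plain no-op context manager, the non-supplied path behaves exactly as before the change.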