alvorithm closed this 4 years ago
I think the warning is caused by the cast to tensor when the parameters/observations are already tensors.
The cast is necessary because some of the simulators are written in NumPy, e.g., the Lotka-Volterra one.
I suggest refactoring the simulation_wrapper.
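A minimal sketch of the cause and of one possible direction for a refactor (not the code proposed in this issue), assuming PyTorch and NumPy are installed. `to_tensor` is a hypothetical helper, not an sbi function. It relies on `torch.as_tensor`, which returns an existing tensor unchanged when dtype and device already match (so no copy and no warning) and still converts NumPy arrays coming from NumPy-based simulators.

```python
import warnings

import numpy as np
import torch


def to_tensor(x, dtype=torch.float32):
    """Hypothetical helper: convert NumPy/array-like input to a tensor.

    torch.as_tensor returns the input itself if it is already a tensor with
    the requested dtype and device, so it avoids the copy-construct warning.
    """
    return torch.as_tensor(x, dtype=dtype)


theta = torch.ones(3, 2)  # float32 by default

# torch.tensor(existing_tensor) emits the UserWarning quoted in this issue.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    copied = torch.tensor(theta)
print(any("To copy construct from a tensor" in str(w.message) for w in caught))  # True

# as_tensor on a matching tensor: no warning and no copy.
with warnings.catch_warnings(record=True) as caught:
    warnings.simplefilter("always")
    same = to_tensor(theta)
print(len(caught), same is theta)  # 0 True

# NumPy output (e.g. from a NumPy-based simulator) is still converted.
x_np = np.zeros((3, 4))
print(to_tensor(x_np).dtype)
```

One caveat of this design: because `as_tensor` may return the very same object (or share memory with a NumPy array), later in-place operations would also modify the caller's data. Where a real copy is wanted, `x.clone().detach()` is what the warning itself recommends.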
A similar problem occurs here: https://github.com/mackelab/sbi/blob/68ec669c2236d405bbd10c66ae6792bf4275ff37/sbi/inference/snpe/base_snpe.py#L262-L263
Here we can simply remove the cast to tensor, because the data come from simulation_wrapper anyway.
Reopening this issue because the UserWarnings are still present on master.
Did you update pyknos/nflows?
Sure
Are you sure it worked? Yesterday I had to fight quite a bit, i.e., ultimately remove them by hand via pip uninstall. Have a look at ~/<path-to-miniconda>/envs/sbi/lib/python3.8/site-packages/nflows/transforms/standard.py; if it's the right tagged version, you should find a PiecewiseAffineTransform there.
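Instead of guessing the site-packages path, one can ask Python which file it actually imports and whether the expected class is defined there. A sketch using only the standard library; `inspect_install` is a hypothetical helper, and the claim that the correct nflows version defines `PiecewiseAffineTransform` comes from the comment above, not from this code.

```python
import importlib
import importlib.util


def inspect_install(module_name, attr):
    """Return (file path, whether attr exists) for an importable module.

    Returns (None, False) if the module or one of its parent packages
    is not installed.
    """
    try:
        spec = importlib.util.find_spec(module_name)
    except ModuleNotFoundError:
        # Raised for dotted names whose parent package (e.g. "nflows") is missing.
        return None, False
    if spec is None:
        return None, False
    module = importlib.import_module(module_name)
    return module.__file__, hasattr(module, attr)


# Usage in the sbi environment:
print(inspect_install("nflows.transforms.standard", "PiecewiseAffineTransform"))
```

A stale copy that pip failed to remove shows up as an unexpected path or as `False` for the attribute.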
I think so:
# pyknos is at d439bb6df487cc2004a44adc2136b8fc13e89a4c
# nflows is at 6d4f5ddcdcf077e155fb6da8d0c2c6f8c3847b35
# sbi is at fb68a06c97ab5426f3c36fa4a068d99b81c8bf93
Error:
/home/ge57bux/sbibm/src/sbi/sbi/simulators/simutils.py:141: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
return torch.tensor(parameters), torch.cat(all_x)
As a side note: it would be good if there were a simple command in sbi that reports the version/commit info of the dependencies. This is currently not helpful/informative:
from pyknos import version; version.__version__ # == "0.10"
from nflows import version; version.__version__ # == "0.1"
from sbi import version; version.__version__ # == "20200209-Conor_base"
How did you get the commit hashes of the dependencies? I guess with that I could solve your request.
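One way to recover commit hashes programmatically, assuming the packages were installed with pip from a git URL (e.g. `pip install git+https://...@<ref>`): pip records the origin in a `direct_url.json` file in the package metadata (PEP 610), including the resolved `commit_id`. The function names below are hypothetical. Editable installs (`pip install -e`) and PyPI wheels record no commit there; for an editable checkout, `git rev-parse HEAD` in the source directory gives it instead.

```python
import json
from importlib import metadata


def commit_from_direct_url(text):
    """Extract the git commit from a PEP 610 direct_url.json payload, if any."""
    data = json.loads(text)
    return data.get("vcs_info", {}).get("commit_id")


def installed_commit(dist_name):
    """Return the commit a distribution was pip-installed from, or None."""
    try:
        dist = metadata.distribution(dist_name)
    except metadata.PackageNotFoundError:
        return None
    text = dist.read_text("direct_url.json")  # None if the file does not exist
    return commit_from_direct_url(text) if text else None


for name in ("sbi", "pyknos", "nflows"):
    print(name, installed_commit(name))
```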
As for the warning in simutils, that's true: I had to revert that because as_tensor conflicted with multi-round SNPE.
We don't need the cast to Tensor anymore once the user input checks are merged, because then everything will be a Tensor.
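The idea that input checks make later casts redundant can be sketched as converting once at the simulator boundary and checking types there. This is a sketch under assumptions: `wrap_numpy_simulator` and `numpy_simulator` are hypothetical names, not sbi functions, and sbi's actual input checks may look different.

```python
import numpy as np
import torch


def wrap_numpy_simulator(simulator):
    """Hypothetical wrapper that converts types once, at the boundary.

    Parameters are passed to the simulator as NumPy arrays and its output is
    turned into a float32 tensor, so downstream code can assume Tensor input
    and needs no further torch.tensor(...) casts.
    """

    def wrapped(theta: torch.Tensor) -> torch.Tensor:
        if not isinstance(theta, torch.Tensor):
            raise TypeError(f"expected a torch.Tensor, got {type(theta).__name__}")
        x = simulator(theta.detach().cpu().numpy())
        return torch.as_tensor(x, dtype=torch.float32)

    return wrapped


def numpy_simulator(theta: np.ndarray) -> np.ndarray:
    # Stand-in for a NumPy-based simulator such as Lotka-Volterra.
    return theta + np.random.randn(*theta.shape)


simulator = wrap_numpy_simulator(numpy_simulator)
x = simulator(torch.zeros(5, 2))
print(type(x), x.dtype, x.shape)
```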
https://github.com/mackelab/sbi/blob/b38cc27fb33f4743f3c33e01e69ed5813146d09f/sbi/simulators/simutils.py#L150
UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
It is present on two lines (both return statements).