janfb closed this pull request 3 months ago.
Attention: Patch coverage is 92.45283% with 4 lines in your changes missing coverage. Please review. Project coverage is 75.56%. Comparing base (337f072) to head (c0f5a6d).
Hi, I've only been using sbi and torch for a short time. I used the following script to train SNPE. The input net and the training data are all on device "cuda". Based on your update, does this mean that SNPE should only be trained on the CPU?
```python
...
neural_posterior = posterior_nn(
    model="maf", embedding_net=net, hidden_features=10, num_transforms=2
)
inference = SNPE(prior=prior, device="cuda", density_estimator=neural_posterior)
for e in range(num_epoch):
    errors_train = []
    for i, (scen_idx, inputs_train, targets_train) in enumerate(train_loader):
        inputs_train, targets_train = inputs_train.to(device), targets_train.to(device)
        embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
        posterior = DirectPosterior(
            posterior_estimator=embedded_sbi, prior=prior, device="cuda"
        )
...
```
Here are the errors I got from it:
```
/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/user_input_checks.py:444: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'. Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.
  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 346, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 292, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 141, in training_loop
    embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 180, in train
    return super().train(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 317, in train
    self._neural_net = self._build_neural_net(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/get_nn_models.py", line 265, in build_fn
    return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/neural_nets/flow.py", line 139, in build_maf
    y_numel = embedding_net(batch_y[:1]).numel()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu
```
I tried with the cpu device as well, but got:
```
  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 349, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 295, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 131, in training_loop
    inference = SNPE(prior=prior,device=device, density_estimator=neural_posterior)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 84, in __init__
    super().__init__(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 64, in __init__
    super().__init__(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/base.py", line 111, in __init__
    self._device = process_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/torchutils.py", line 48, in process_device
    torch.cuda.set_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/__init__.py", line 397, in set_device
    device = _get_device_index(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/_utils.py", line 34, in _get_device_index
    raise ValueError(f"Expected a cuda device, but got: {device}")
ValueError: Expected a cuda device, but got: cpu
```
I'm not sure whether I understood this correctly; could you maybe help? I really appreciate it.
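For reference, the invariant behind both errors is that the embedding net's weights and the data fed to it must live on the same device. A minimal torch-only sketch of that invariant (the `nn.Linear` here is just a hypothetical stand-in for the SPIDNA `net`, not the actual model):

```python
import torch
from torch import nn

# Pick cuda when available, otherwise fall back to cpu.
device = "cuda" if torch.cuda.is_available() else "cpu"

embedding_net = nn.Linear(4, 2).to(device)  # move the net's weights first...
inputs = torch.randn(8, 4, device=device)   # ...and create data on the same device

# Weights and data now share one device, which is what sbi's checks enforce.
assert next(embedding_net.parameters()).device == inputs.device
embedded = embedding_net(inputs)  # runs without a device-mismatch warning
```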
Context
The starting point for this PR is #1161, the incorrect warning that the embedding net's device and the data's device do not match. Along the way I realized that we are treating the `embedding_net` as a separate net that can have its own device, different from the actual net. I think this does not make sense. In general, device handling should be centralized, i.e., have a single entry point. At the moment, this entry point is the inference object, e.g., `SNPE(..., device=device)`. But there are different scenarios:

- `SNPE`: all good, device handling is centralized via the `device` passed to `SNPE`.
- The user calls `posterior_nn` to build a flow with an embedding net: `posterior_nn` normally returns a net on the CPU, but if the `embedding_net` passed by the user is on a different device, things might crash.

My suggestions
We should check in `posterior_nn` etc. that the passed embedding net is on the CPU, or move it there. EDIT: I will add a function that checks the embedding net's device and, if it is not on the CPU, warns and moves it there.

What this PR does so far
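As a rough illustration of the check described above (a pure-Python stand-in, not the actual sbi code; `MockNet` and the function name are hypothetical):

```python
import warnings

class MockNet:
    """Minimal stand-in for a torch.nn.Module that tracks its device as a string."""
    def __init__(self, device="cpu"):
        self.device = device
    def to(self, device):
        self.device = device
        return self

def check_embedding_net_device(embedding_net, device="cpu"):
    """If the embedding net is not on the expected device, warn and move it there."""
    if embedding_net.device != device:
        warnings.warn(
            f"embedding_net is on '{embedding_net.device}', moving it to '{device}'."
        )
        embedding_net = embedding_net.to(device)
    return embedding_net
```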
- Adds `embedding_net` device checking in `build_posterior` (and adds a test).
- Adds `get_numel` to be used across the neural net factory. I had to put it into a separate utils file because putting it into `sbiutils` or `torchutils` causes circular imports 😵

fixes #1161