sbi-dev / sbi

Simulation-based inference toolkit
https://sbi-dev.github.io/sbi/
Apache License 2.0

How to set up custom embedding nets for CUDA training #1201

Closed: janfb closed this issue 4 months ago

janfb commented 4 months ago

Moving the comment by @ningyuxin1999 under https://github.com/sbi-dev/sbi/pull/1186 here for discussion:

Hi, I've only recently started using sbi and torch. I used the following script to train SNPE. The input net and the training data are all on the "cuda" device. Based on your update, does this mean that SNPE should only be trained on the CPU?

```python
...
neural_posterior = posterior_nn(model="maf", embedding_net=net, hidden_features=10, num_transforms=2)
inference = SNPE(prior=prior, device="cuda", density_estimator=neural_posterior)

for e in range(num_epoch):
    errors_train = []
    for i, (scen_idx, inputs_train, targets_train) in enumerate(train_loader):
        inputs_train, targets_train = inputs_train.to(device), targets_train.to(device)
        embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
        posterior = DirectPosterior(posterior_estimator=embedded_sbi, prior=prior, device="cuda")
        ...
```

Here are the errors I got from it:

```
/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/user_input_checks.py:444: UserWarning: Mismatch between the device of the data fed to the embedding_net and the device of the embedding_net's weights. Fed data has device 'cpu' vs embedding_net weights have device 'cuda:0'. Automatically switching the embedding_net's device to 'cpu', which could otherwise be done manually using the line `embedding_net.to('cpu')`.
  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 346, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 292, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 141, in training_loop
    embedded_sbi = inference.append_simulations(targets_train, inputs_train).train()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 180, in train
    return super().train(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 317, in train
    self._neural_net = self._build_neural_net(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/get_nn_models.py", line 265, in build_fn
    return build_maf(batch_x=batch_theta, batch_y=batch_x, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/neural_nets/flow.py", line 139, in build_maf
    y_numel = embedding_net(batch_y[:1]).numel()
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1532, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1541, in _call_impl
    return forward_call(*args, **kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 171, in forward
    raise RuntimeError("module must have its parameters and buffers "
RuntimeError: module must have its parameters and buffers on device cuda:0 (device_ids[0]) but found one of them on device: cpu
```

I also tried with the CPU device, but got:

```
  warnings.warn(
Traceback (most recent call last):
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 194, in _run_module_as_main
    return _run_code(code, main_globals, None,
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/runpy.py", line 87, in _run_code
    exec(code, run_globals)
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 349, in <module>
    logging.info(start_training(args.training_param_path))
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 295, in start_training
    best_errors_grouped_val = training_loop(prior = prior, net=net, train_loader=train_loader, validation_loader=validation_loader,
  File "/ceph/ibmi/mcpg/ningyuxin/popGen/Yuxin_simulatePop/SPIDNA/flora_code/SPIDNA_train.py", line 131, in training_loop
    inference = SNPE(prior=prior,device=device, density_estimator=neural_posterior)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_c.py", line 84, in __init__
    super().__init__(**kwargs)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/snpe/snpe_base.py", line 64, in __init__
    super().__init__(
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/inference/base.py", line 111, in __init__
    self._device = process_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/sbi/utils/torchutils.py", line 48, in process_device
    torch.cuda.set_device(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/__init__.py", line 397, in set_device
    device = _get_device_index(device)
  File "/ceph/ibmi/mcpg/ningyuxin/.conda/envs/sbi38/lib/python3.8/site-packages/torch/cuda/_utils.py", line 34, in _get_device_index
    raise ValueError(f"Expected a cuda device, but got: {device}")
ValueError: Expected a cuda device, but got: cpu
```

I'm not sure if I understood it correctly. Could you maybe help? I really appreciate it.

Originally posted by @ningyuxin1999 in https://github.com/sbi-dev/sbi/issues/1186#issuecomment-2245461488

janfb commented 4 months ago

Hi @ningyuxin1999

the PR under which you commented has been merged into the main branch but has not been released yet, which is why you still get the warning and the error. We will make the new release soon, probably in August. In the meantime, however, you can work around the problem by setting up the embedding net accordingly.

Essentially, you have to pass the device only once, i.e., when you create the inference class. Importantly, your custom embedding net should not be on the CUDA device yet; it will be moved internally.

```python
from sbi.inference import SNPE
from sbi.utils import posterior_nn

net = YourCustomEmbeddingNet(...)  # lives on the CPU; its output layer returns 10 units
neural_posterior = posterior_nn(model="maf", embedding_net=net, hidden_features=10, num_transforms=2)
inference = SNPE(prior=prior, device="cuda", density_estimator=neural_posterior)
```

SNPE will then combine your embedding net with the MAF density estimator and move the combined network to the desired device for training.
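For illustration, here is a minimal sketch of what `YourCustomEmbeddingNet` could look like. The architecture and the input size of 50 are hypothetical; the sketch only shows a plain `torch.nn.Module` that is created on the CPU and whose last layer returns 10 units. Note that the first traceback above passes through `torch/nn/parallel/data_parallel.py`, which suggests the embedding net there was wrapped in `nn.DataParallel`; this sketch deliberately uses an unwrapped module.

```python
import torch
from torch import nn


class YourCustomEmbeddingNet(nn.Module):
    """Hypothetical embedding net: maps a 50-dim simulator output to 10 features."""

    def __init__(self, input_dim: int = 50, num_features: int = 10):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 64),
            nn.ReLU(),
            nn.Linear(64, num_features),  # output layer returns 10 units
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# Create the net on the CPU; no .to("cuda") and no nn.DataParallel wrapper.
net = YourCustomEmbeddingNet()
param_devices = {p.device.type for p in net.parameters()}
print(param_devices)  # {'cpu'}

# Quick shape check on a dummy batch of 4 simulations.
features = net(torch.randn(4, 50))
print(features.shape)  # torch.Size([4, 10])
```

This `net` is what you would pass as `embedding_net=` to `posterior_nn`; moving it to the GPU is then left to the `device` argument of `SNPE`.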

The same applies to the data: sbi takes care of moving it to the right device. In most cases, it is even better to keep the data on the CPU and move only the current batch to the GPU during training.
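sbi does this for you inside `train()`, so the following is only a generic PyTorch sketch of the pattern described above, with made-up tensor sizes: the full dataset stays on the CPU and each mini-batch is copied to the training device just before it is used.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Use the GPU if one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical training data: 1000 parameter sets (3-dim) and simulations (50-dim),
# kept on the CPU so the full dataset never has to fit into GPU memory.
theta = torch.randn(1000, 3)
x = torch.randn(1000, 50)
loader = DataLoader(TensorDataset(theta, x), batch_size=100, shuffle=True)

moved_devices = []
for theta_batch, x_batch in loader:
    # Only the current mini-batch is copied to the training device.
    theta_batch, x_batch = theta_batch.to(device), x_batch.to(device)
    moved_devices.append(x_batch.device.type)
    # ... forward pass, loss and optimizer step would go here ...
```

Because only one batch at a time lives on the GPU, GPU memory use depends on the batch size rather than on the size of the whole dataset.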

The rest of the code you posted will probably not work; I think there is a misunderstanding. You do not need to write an explicit training loop. All you need to do is append the data to SNPE and call `train()`, like this:

```python
inference.append_simulations(your_sampled_theta, your_simulated_data)
inference.train()

posterior = inference.build_posterior()
```

(If you really have to use your custom dataloader, e.g., because your data does not fit into RAM at once, things become a bit tricky; a similar question was discussed in https://github.com/sbi-dev/sbi/discussions/1193.)