Closed janfb closed 4 months ago
Hi @ningyuxin1999
the PR under which you commented has been merged into the main branch but has not been released yet. Therefore, you still get the warning and the error. We will make a new release soon, probably in August. In the meantime, there is a fix you can apply by setting up the embedding net accordingly.
Essentially, you have to pass the device only once, i.e., when you create the inference class. Importantly, your custom embedding net should not be on the CUDA device yet; it will be moved there internally.
net = YourCustomEmbeddingNet(...)  # lives on the CPU; output layer returns 10 units
neural_posterior = posterior_nn(model="maf", embedding_net=net, hidden_features=10, num_transforms=2)
inference = SNPE(prior=prior, device="cuda", density_estimator=neural_posterior)
Then, SNPE will take care of concatenating your embedding net and the MAF density estimator and then moving the combined network to the desired device for training.
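As a minimal sketch of what "lives on the CPU" means here, assuming a plain PyTorch module: the class below is a hypothetical stand-in for `YourCustomEmbeddingNet` (the input size of 100 and the hidden width of 50 are illustrative assumptions, not taken from your code). The only point is that it is constructed without any `.to("cuda")` or `.cuda()` call, so its parameters stay on the CPU until the inference class moves them.

```python
import torch
from torch import nn


class YourCustomEmbeddingNet(nn.Module):
    """Hypothetical embedding net: maps raw data to 10 summary features."""

    def __init__(self, input_dim: int = 100, output_dim: int = 10):
        super().__init__()
        self.layers = nn.Sequential(
            nn.Linear(input_dim, 50),
            nn.ReLU(),
            nn.Linear(50, output_dim),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return self.layers(x)


# Construct the net without calling .to("cuda") or .cuda().
net = YourCustomEmbeddingNet()

# Collect the device types of all parameters; nothing has been moved yet.
param_devices = {p.device.type for p in net.parameters()}

# Forward pass on a batch of 4 dummy observations (also on the CPU).
embedding = net(torch.randn(4, 100))
```

If `param_devices` contains anything other than the CPU before you hand the net to the inference class, the net has been moved too early.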
The same applies to the data: sbi takes care of the device handling of the data. In most cases, it is even better to keep the data on the CPU and move only the batches to the GPU during training.
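To illustrate the "keep the data on the CPU, move only the batches" pattern, here is a generic PyTorch sketch. It is not sbi's internal training code, and the tensor shapes and batch size are made up for the example.

```python
import torch
from torch.utils.data import DataLoader, TensorDataset

# Fall back to the CPU when no GPU is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# The full dataset stays in CPU memory.
theta = torch.randn(1000, 3)   # dummy parameters
x = torch.randn(1000, 100)     # dummy simulated data
loader = DataLoader(TensorDataset(theta, x), batch_size=50, shuffle=True)

num_batches = 0
for theta_batch, x_batch in loader:
    # Only the current batch is copied to the training device.
    theta_batch = theta_batch.to(device)
    x_batch = x_batch.to(device)
    num_batches += 1
```

This way GPU memory only ever holds one batch at a time, while the complete set of simulations remains on the CPU.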
The remaining code you posted will probably not work; I think there is a misunderstanding. You do not need to write an explicit training loop. All you need to do is append the data to SNPE and call train(), like this:
inference.append_simulations(your_sampled_theta, your_simulated_data)
inference.train()
posterior = inference.build_posterior()
(If you really have to use your own custom dataloader, e.g., because your data does not fit into RAM at once, then things become a bit tricky, but a similar case has been discussed before, see https://github.com/sbi-dev/sbi/discussions/1193)
Moving the comment by @ningyuxin1999 under https://github.com/sbi-dev/sbi/pull/1186 here for discussion:
Here are the errors I got from it:
I tried with cpu device as well, but got:
I'm not sure if I understood it correctly, could you maybe help? I really appreciate that.
Originally posted by @ningyuxin1999 in https://github.com/sbi-dev/sbi/issues/1186#issuecomment-2245461488