wesselb / neuralprocesses

A framework for composing Neural Processes in Python
https://wesselb.github.io/neuralprocesses
MIT License

Add dtype_lik option to nps.predict #21

Closed: DrJonnyT closed this pull request 5 months ago.

DrJonnyT commented 6 months ago

Hi Wessel!

Problem: Apple silicon GPUs (the PyTorch MPS backend) don't support FP64. Running nps.predict on the GPU with PyTorch therefore raised an error, because dtype_lik was hard-coded to float64:

TypeError: Cannot convert a MPS Tensor to float64 dtype as the MPS framework doesn't support float64. Please use float32 instead.

Solution: I've added dtype_lik as an optional argument to nps.predict, mirroring how it is already implemented in elbo and loglik.
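A rough sketch of the pattern, purely illustrative: the optional argument keeps float64 as the default, so existing behaviour is preserved, and lets a caller on an MPS device opt into float32. The helper `resolve_dtype_lik` and the `SUPPORTED_DTYPES` table are hypothetical and are not part of the neuralprocesses API. Dtypes are plain strings so the snippet runs without PyTorch.

```python
from typing import Optional

# Hypothetical sketch, NOT the neuralprocesses implementation: it only
# illustrates an optional dtype_lik argument that defaults to float64
# but can be overridden by the caller.

# Which floating-point dtypes each (hypothetical) device string accepts.
SUPPORTED_DTYPES = {
    "cpu": {"float32", "float64"},
    "cuda": {"float32", "float64"},
    "mps": {"float32"},  # Apple silicon GPUs: no float64 support
}


def resolve_dtype_lik(device: str, dtype_lik: Optional[str] = None) -> str:
    """Return the dtype to use for likelihood computations on `device`.

    If dtype_lik is None, fall back to float64 (the previous hard-coded
    behaviour); otherwise use the caller's choice. Raise TypeError when the
    device cannot represent the resulting dtype, as MPS does for float64.
    """
    dtype = "float64" if dtype_lik is None else dtype_lik
    if dtype not in SUPPORTED_DTYPES[device]:
        raise TypeError(
            f"{device} does not support {dtype}; pass dtype_lik='float32' instead."
        )
    return dtype


if __name__ == "__main__":
    print(resolve_dtype_lik("cpu"))             # default is unchanged
    print(resolve_dtype_lik("mps", "float32"))  # explicit override works on MPS
```

Keeping float64 as the default means results on CPU and CUDA are unaffected. Only callers who explicitly pass a lower precision, such as those on Apple silicon, see a change.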

wesselb commented 5 months ago

Hey @DrJonnyT! Very sorry for the slow reply. Things have been super hectic lately.

Thanks for the fix. This looks super sensible! Merging right away.

I think the failing CI is due to an unrelated issue.