cyoahs closed this issue 3 years ago
Thanks for noticing and reporting the issue! Indeed, device placement was not considered in this code. Fixing it will probably require quite a bit more work than that, as I see tensors created without a device specification in the init method as well.
@toshas, thank you for contributing this implementation, as it looks like https://github.com/pytorch/pytorch/pull/32377 has gone dormant. I started catching up on the state of truncated normal distributions in PyTorch after joining an applied inference project in which truncated Gaussians are critical for establishing the priors. It looks like your `TruncatedNormal` implementation is the most mature one so far. Do you have any plans to expand the CUDA integration to resolve the `icdf()` error? Much appreciation for your time.
Now it all should be working as expected.
Thank you, @toshas! I'll be sure to cite this when writing the manuscript for the corresponding project. Are you going to use this in one of your papers? Otherwise, I'll just cite the GitHub link.
Hello~ Thanks for this implementation!
In the `TruncatedStandardNormal.rsample()` method, a new tensor is created and placed on the CPU by default, as shown in this line. When `loc` and `scale` are on a CUDA device, an error occurs in `icdf()`. Since the `torch.Distribution` class does not have a `to(device)` method, I think line 97 can be changed to
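The proposed replacement line is not preserved in this thread. Purely as an illustration of the general idea, and not the change suggested here or the one that was eventually committed, the sketch below draws the uniform noise on the same device and dtype as the distribution parameters, so that a later `icdf()` call never mixes CPU and CUDA tensors. The helper name `uniform_like_params` and the `eps` clamp are assumptions made for this example, and the stock `torch.distributions.Normal` stands in for the truncated distribution.

```python
import torch


def uniform_like_params(loc, scale, sample_shape=torch.Size(), eps=1e-6):
    """Hypothetical helper: U(eps, 1 - eps) noise on loc's device and dtype."""
    # Batch shape follows the broadcast of the two parameter tensors.
    batch_shape = torch.broadcast_shapes(loc.shape, scale.shape)
    shape = tuple(sample_shape) + tuple(batch_shape)
    # Passing device= and dtype= explicitly avoids the implicit CPU default.
    noise = torch.empty(shape, dtype=loc.dtype, device=loc.device)
    return noise.uniform_(eps, 1 - eps)


# Falls back to the CPU when no GPU is present, so the sketch runs anywhere.
device = "cuda" if torch.cuda.is_available() else "cpu"
loc = torch.zeros(3, device=device)
scale = torch.ones(3, device=device)

p = uniform_like_params(loc, scale, sample_shape=(5,))
# Inverse-CDF sampling; all tensors involved live on the same device.
z = torch.distributions.Normal(loc, scale).icdf(p)
```

Reading the device from an existing parameter, rather than adding a `to(device)` method, keeps the distribution object stateless with respect to placement: whatever device `loc` and `scale` live on is the device the samples come out on.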