toshas / torch_truncnorm

Truncated Normal Distribution in PyTorch
BSD 3-Clause "New" or "Revised" License

device issue in `rsample()` #2

Closed: cyoahs closed this issue 3 years ago

cyoahs commented 3 years ago

Hello~ Thanks for this implementation!

In the TruncatedStandardNormal.rsample() method, a new tensor is created, and by default it is allocated on the CPU, as shown in this line. When loc and scale are CUDA tensors, an error occurs in icdf().

Since the torch.distributions.Distribution class does not have a to(device) method, I think line 97 could be changed to

p = torch.empty(shape).uniform_(self._dtype_min_gt_0, self._dtype_max_lt_1).to(self._big_phi_a.device)
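As an alternative to the `.to(...)` call, the uniforms can be allocated directly on the parameters' device, and with their dtype, by passing `device=` and `dtype=` to `torch.empty`, which avoids a CPU allocation followed by a copy. The sketch below is a standalone, hypothetical helper (`rsample_truncated_standard_normal` is not the repository's `rsample()`); it assumes plain inverse-CDF sampling via `torch.erf`/`torch.erfinv`, with `torch.finfo(dtype).eps` standing in for the `_dtype_min_gt_0` bound.

```python
import math

import torch


def _std_normal_cdf(x: torch.Tensor) -> torch.Tensor:
    # Phi(x) = 0.5 * (1 + erf(x / sqrt(2)))
    return 0.5 * (1.0 + torch.erf(x / math.sqrt(2.0)))


def _std_normal_icdf(p: torch.Tensor) -> torch.Tensor:
    # Phi^{-1}(p) = sqrt(2) * erfinv(2p - 1)
    return math.sqrt(2.0) * torch.erfinv(2.0 * p - 1.0)


def rsample_truncated_standard_normal(a: torch.Tensor, b: torch.Tensor,
                                      sample_shape=torch.Size()) -> torch.Tensor:
    """Hypothetical helper: reparameterized draws from N(0, 1) truncated to [a, b]."""
    a, b = torch.broadcast_tensors(a, b)
    shape = torch.Size(sample_shape) + a.shape
    # Allocate the uniforms on the same device and dtype as the parameters,
    # instead of on the default (CPU) device followed by a copy.
    eps = torch.finfo(a.dtype).eps
    u = torch.empty(shape, dtype=a.dtype, device=a.device).uniform_(eps, 1.0 - eps)
    # Inverse-CDF sampling: map u into (Phi(a), Phi(b)), then apply Phi^{-1}.
    cdf_a, cdf_b = _std_normal_cdf(a), _std_normal_cdf(b)
    x = _std_normal_icdf(cdf_a + u * (cdf_b - cdf_a))
    # Clip round-off so samples never leave [a, b].
    return torch.max(torch.min(x, b), a)


if __name__ == "__main__":
    device = "cuda" if torch.cuda.is_available() else "cpu"
    a = torch.tensor([-1.0, 0.0], device=device)
    b = torch.tensor([1.0, 2.0], device=device)
    x = rsample_truncated_standard_normal(a, b, (1000,))
    print(x.shape, x.device)
```

The closed-form inverse CDF loses precision in the far tails (for example, when both bounds lie several standard deviations from zero), so treat this as an illustration of device handling rather than a robust sampler.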

toshas commented 3 years ago

Thanks for noticing and reporting the issue! Indeed, device placement was not considered in this code. Fixing it will probably take quite a bit more work than that, since tensors are also created without a device specification in the `__init__` method.
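One way to handle the `__init__` concern is to make every tensor built there either a function of the input bounds or explicitly placed on their device and dtype. The class below is a minimal, hypothetical sketch (`TruncatedStandardNormalSketch` is not the repository's class, and its `log_prob` is simplified); only the `_big_phi_a` attribute name echoes the line quoted in the issue.

```python
import math

import torch


class TruncatedStandardNormalSketch:
    """Hypothetical, stripped-down stand-in (not the repository's actual class).

    Every tensor created in __init__ is either computed from the inputs or
    explicitly placed on their device/dtype, so nothing silently lands on CPU.
    Expects floating-point bounds a and b.
    """

    def __init__(self, a, b):
        a = torch.as_tensor(a)
        b = torch.as_tensor(b, dtype=a.dtype, device=a.device)
        self.a, self.b = torch.broadcast_tensors(a, b)
        # Constants: build them with the inputs' dtype and device explicitly.
        self._log_sqrt_2pi = torch.tensor(0.5 * math.log(2.0 * math.pi),
                                          dtype=a.dtype, device=a.device)
        # Derived quantities: ops on self.a / self.b inherit their device.
        self._big_phi_a = 0.5 * (1.0 + torch.erf(self.a / math.sqrt(2.0)))
        self._big_phi_b = 0.5 * (1.0 + torch.erf(self.b / math.sqrt(2.0)))
        eps = torch.finfo(a.dtype).eps
        self._Z = (self._big_phi_b - self._big_phi_a).clamp(min=eps)

    def log_prob(self, value):
        # Move user input next to the parameters rather than the reverse.
        value = torch.as_tensor(value, dtype=self.a.dtype, device=self.a.device)
        # log p(x) = log phi(x) - log Z on [a, b], and -inf outside.
        out = -0.5 * value ** 2 - self._log_sqrt_2pi - torch.log(self._Z)
        inside = (value >= self.a) & (value <= self.b)
        return torch.where(inside, out, torch.full_like(out, -math.inf))
```

Because every tensor follows `a`, moving such an object to another device amounts to rebuilding it from `a.to(device)` and `b.to(device)`. Registering constants as buffers is the usual pattern for `nn.Module`, but `torch.distributions.Distribution` is not a module, so it offers no such mechanism.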

wallyxie commented 3 years ago

@toshas, thank you for contributing this implementation, as it looks like https://github.com/pytorch/pytorch/pull/32377 has gone dormant. I started catching up on the state of truncated normal distributions in PyTorch after joining an applied inference project in which truncated Gaussians are critical for specifying the priors. Your TruncatedNormal implementation looks like the most mature one so far. Do you have any plans to extend CUDA support to resolve the icdf() error? Much appreciated.

toshas commented 3 years ago

Everything should now work as expected.

wallyxie commented 3 years ago

Thank you, @toshas! I'll be sure to cite this when writing the manuscript for the corresponding project. Are you going to use this in one of your papers? Otherwise, I'll just cite the GitHub link.