Open · jg9zk opened this issue 1 year ago
I modified the forward() function to cast self.mu and self.theta to float64 tensors. That makes p float64 as well, so I recast p to float32, which the rest of the model expects. It seems to work now, with minimal impact on speed.
def forward(self, x):
    """
    Forward pass; overrides the forward method of torch.nn.Module.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (n_samples, n_features).

    Returns
    -------
    p : torch.Tensor
        Tensor of shape (n_samples, n_features): the probability of failure
        for each element of the data in the ZINB distribution.
    """
    self.log_mu = (
        self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu)
    self.log_pi = (
        self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi)
    # Cast to float64 before exponentiating for extra headroom against overflow.
    self.mu = torch.exp(self.log_mu.double())
    self.theta = torch.exp(self.log_theta.double())
    # Adaptive regularization terms keep the denominator away from zero:
    p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)
    # Recast to float32, which the rest of the model expects.
    p = p.float()
    return p
I thought the above worked, but self.mu can still reach inf even in float64. Since p is close to 1 when self.mu is inf, I replace all NaNs in p with 1 - 1e-10. I don't think the dtype recasting is necessary anymore, but I haven't removed it.
def forward(self, x):
    """
    Forward pass; overrides the forward method of torch.nn.Module.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (n_samples, n_features).

    Returns
    -------
    p : torch.Tensor
        Tensor of shape (n_samples, n_features): the probability of failure
        for each element of the data in the ZINB distribution.
    """
    self.log_mu = (
        self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu)
    self.log_pi = (
        self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi)
    # Cast to float64 before exponentiating for extra headroom against overflow.
    self.mu = torch.exp(self.log_mu.double())
    self.theta = torch.exp(self.log_theta.double())
    # Adaptive regularization terms keep the denominator away from zero:
    p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)
    p = p.float()
    # mu can still overflow to inf in float64, which makes p NaN; since p -> 1
    # as mu -> inf, replace those NaNs with a value just below 1.
    p[torch.isnan(p)] = 1 - 1e-10
    return p
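For what it's worth, the overflow could probably be sidestepped entirely by staying in log space. This is just a sketch, not a drop-in replacement: it reuses the attribute names from the code above and drops the 1e-4 regularization terms. It relies on the identity mu / (mu + theta) = sigmoid(log_mu - log_theta), so p is computed without ever forming mu or theta:

def forward(self, x):
    self.log_mu = (
        self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu)
    self.log_pi = (
        self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi)
    # sigmoid(log_mu - log_theta) equals mu / (mu + theta) but never
    # exponentiates log_mu or log_theta, so nothing can overflow to inf.
    p = torch.sigmoid(self.log_mu - self.log_theta)
    # sigmoid saturates to exactly 0.0 or 1.0 in float32 for large |inputs|;
    # clamp with an eps comfortably above float32's resolution near 1.0
    # (~1.2e-7) so p stays strictly inside (0, 1) as the distribution expects.
    eps = 1e-6
    return p.clamp(eps, 1 - eps)

Whether clamping or NaN replacement is preferable probably depends on how the loss treats values near the boundary.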
Thank you so much for raising the issue and providing potential solutions.
I have reviewed the issue and would like to provide my assistance. Could you please provide the data to reproduce the issue?
I will start investigating this issue and will provide updates as I make progress. If you have any additional information or thoughts, please feel free to share them.
Looking forward to resolving this issue together!
First of all, thanks so much for making this! I've been wanting to use the ZINB-WaVE method on my dataset, but it was too big for the original implementation.
I got your implementation to run without including batch variables. When I add the batch variables, it runs for some of the data, but eventually this error appears:
ValueError: Expected parameter probs (Tensor of shape (9556, 3000)) of distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([9556, 3000])) to satisfy the constraint HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), but found invalid values:
I think it's saying the issue is with the matrix supplied to probs in _loss(), which is calculated in forward(). The p matrix sometimes gets a NaN in one of its elements; everything else is within the bounds the model wants.
self.mu[torch.isnan(p)] yields inf, so something is happening there. self.log_mu[torch.isnan(p)] is about 88.8, and exp(88.8) exceeds float32's maximum of roughly 3.4e38, so the exponentiation overflows.
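A quick sanity check with stock PyTorch (nothing assumed beyond torch itself) reproduces the overflow:

import torch

x = torch.tensor(88.8)        # float32 by default
print(torch.exp(x))           # tensor(inf): exp overflows float32 past ~88.72
print(torch.exp(x.double()))  # finite (~3.7e38): still representable in float64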