HH197 / Grad-Learn

Large-Scale Statistical Learning and Inference with Gradient Descent, PyTorch, and Pyro
https://hh197.github.io/Grad-Learn/
MIT License

NaN in p matrix when running with batch covariates #1

Open jg9zk opened 1 year ago

jg9zk commented 1 year ago

First of all, thanks so much for making this! I've been wanting to use the ZINB-WaVE method on my dataset, but it was too big to run.

I got your implementation to run without batch variables. When I add the batch variables, training runs on some of the data, but eventually an error appears.

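import torch
from torch.nn.functional import one_hot

# ZINB_grad, adata, K, size, batch_size, n_epochs, device and data_loader are defined earlier in my script.
# data_loader must not shuffle: rows of one_hot_encoding are matched to batches by position.
geo = adata.obs['GEO']
groups = len(geo.cat.categories)  # matches the range of geo.cat.codes, even if some levels are unused
one_hot_encoding = one_hot(torch.as_tensor(geo.cat.codes.to_numpy()).long(), num_classes=groups)
X = torch.ones((batch_size, 1+groups))  # column 0: intercept; columns 1..groups: GEO one-hot

PATH = './zinb_grad_constants_geo/'
# Dummy model, only used to initialise the shared parameters.
model = ZINB_grad.ZINB_Grad(Y = torch.randint(0, 100, size = size), X = X, K = K, device=device).to(device)

for i, data in enumerate(data_loader):

    batch = data["X"].to(device, dtype=torch.int32)
    n = batch.shape[0]  # the last batch can be smaller than batch_size
    X = torch.ones((n, 1+groups))
    X[:,1:] = one_hot_encoding[(batch_size*i):(batch_size*i + n)]

    # Reuse the alphas, betas, and theta from the previous model (the dummy model on the first iteration).
    model = ZINB_grad.ZINB_Grad(Y = batch, K = K, X = X, device = device,
                            alpha_mu = model.alpha_mu,
                            alpha_pi = model.alpha_pi,
                            beta_mu = model.beta_mu,
                            beta_pi = model.beta_pi,
                            log_theta = model.log_theta).to(device)

    optimizer = torch.optim.Adam(model.parameters(), lr = 0.1, weight_decay=.001)

    losses, neg_log_liks = ZINB_grad.train_ZINB(batch, optimizer, model, epochs = n_epochs)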
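One property of this design matrix may be worth checking. This is an assumption and is not confirmed anywhere in the thread: with an intercept column plus a full one-hot encoding of GEO, the columns of X are linearly dependent, because the intercept equals the sum of the one-hot columns. That makes beta_mu and beta_pi non-identifiable, so only the weight_decay term keeps the optimizer from shifting weight between the intercept and the batch coefficients. Whether this contributes to the overflow is not established here. The sketch below uses toy labels (hypothetical data, not the real adata) to show the rank deficiency and a reference-coded alternative that drops the first level.

```python
import torch
from torch.nn.functional import one_hot

# Toy batch labels standing in for adata.obs['GEO'].cat.codes (hypothetical data).
codes = torch.tensor([0, 1, 2, 1, 0, 2])
groups = 3

onehot = one_hot(codes, num_classes=groups).float()
intercept = torch.ones((codes.shape[0], 1))

# Design used in the issue: intercept + all one-hot columns (1 + groups columns).
X_full = torch.cat([intercept, onehot], dim=1)

# Reference coding: drop the first level so the intercept represents group 0.
X_ref = torch.cat([intercept, onehot[:, 1:]], dim=1)

rank_full = torch.linalg.matrix_rank(X_full).item()
rank_ref = torch.linalg.matrix_rank(X_ref).item()
print(X_full.shape[1], rank_full)  # 4 columns but rank 3: one column is redundant
print(X_ref.shape[1], rank_ref)    # 3 columns, full column rank
```

I have not checked whether ZINB_Grad expects the intercept to be part of X, so treat the reference-coded version only as something to try.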

ValueError: Expected parameter probs (Tensor of shape (9556, 3000)) of distribution ZeroInflatedNegativeBinomial(gate_logits: torch.Size([9556, 3000])) to satisfy the constraint HalfOpenInterval(lower_bound=0.0, upper_bound=1.0), but found invalid values:

I think the error refers to the matrix passed as probs in _loss(), which is computed in forward(). Occasionally one element of p becomes NaN; everything else is within the bounds the model expects.
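The check that raises this error can be reproduced on its own. A minimal sketch, assuming the constraint named in the message is `torch.distributions.constraints.half_open_interval(0.0, 1.0)`, which accepts 0 <= p < 1: a NaN entry fails because every comparison involving NaN is False.

```python
import torch
from torch.distributions import constraints

# The constraint reported in the error message: 0 <= p < 1.
probs_constraint = constraints.half_open_interval(0.0, 1.0)

p = torch.tensor([0.0, 0.5, float("nan"), 1.0])
ok = probs_constraint.check(p)  # elementwise boolean tensor
print(ok)  # True, True, False (NaN), False (1.0 is excluded by the open upper bound)
```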

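For those entries, self.mu[torch.isnan(p)] is Inf, so the problem starts there. self.log_mu[torch.isnan(p)] is about 88.8, and exp(88.8) exceeds the largest float32 value (about 3.4e38, i.e. exp(88.72)), so torch.exp overflows to Inf. p is then computed as Inf / Inf, which is NaN.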
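The overflow described above can be reproduced with toy values. The sketch below uses the same expression as forward(), with a hypothetical log_theta of 0; the second log_mu entry mimics the value reported in this issue.

```python
import math
import torch

# Largest finite float32 and the log-scale threshold above which exp() overflows.
f32_max = torch.finfo(torch.float32).max
threshold = math.log(f32_max)  # about 88.72

log_mu = torch.tensor([10.0, 88.8])  # float32; second entry is above the threshold
log_theta = torch.tensor(0.0)        # hypothetical value

mu = torch.exp(log_mu)       # exp(88.8) overflows to inf in float32
theta = torch.exp(log_theta)

# Same expression as forward(): for the overflowed entry this is inf / inf = nan.
p = mu / (mu + theta + 1e-4 + 1e-4 * mu + 1e-4 * theta)
print(threshold, mu, p)
```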

jg9zk commented 1 year ago

I modified the forward() function to compute self.mu and self.theta as float64 tensors. Because p is then also float64, I cast it back to float32, which is what the model expects. It seems to work now, with minimal impact on speed.

def forward(self, x):
    """
    The forward method of class Module in torch.nn.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (n_samples, n_features).

    Returns
    -------
    p : torch.Tensor
        Tensor of shape (n_samples, n_features) which is the probability of failure
        for each element of data in the ZINB distribution.
    """
    self.log_mu = (
        self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu)
    self.log_pi = (
        self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi)

    # Exponentiate in float64 so exp(log_mu) does not overflow just above log_mu ~ 88.7 (the float32 limit).
    self.mu = torch.exp(self.log_mu.double())
    self.theta = torch.exp(self.log_theta.double())

    # Adaptive regulatory parameters are applied:
    p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)
    # Cast back to float32, which the rest of the model expects.
    p = p.float()

    return p
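Casting to float64 moves the overflow threshold rather than removing it: exp() overflows in float64 once its argument exceeds log(max float64), about 709.78. A small sketch with toy log_mu values, using the same `.double()` cast as the modified forward():

```python
import math
import torch

f64_threshold = math.log(torch.finfo(torch.float64).max)  # about 709.78

log_mu = torch.tensor([88.8, 709.0, 710.0])  # float32 input, as in forward()
mu = torch.exp(log_mu.double())              # finite, finite, inf

print(f64_threshold, mu)
```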

jg9zk commented 1 year ago

I thought the above worked, but self.mu can still overflow to inf even in float64. Since p is close to 1 when self.mu is inf, I replace all NaNs in p with 1-1e-10. I don't think the type casting is necessary anymore, but I didn't remove it.

def forward(self, x):
    """
    The forward method of class Module in torch.nn.

    Parameters
    ----------
    x : torch.Tensor
        Tensor of shape (n_samples, n_features).

    Returns
    -------
    p : torch.Tensor
        Tensor of shape (n_samples, n_features) which is the probability of failure
        for each element of data in the ZINB distribution.
    """
    self.log_mu = (
        self.X @ self.beta_mu + self.gamma_mu.T @ self.V.T + self.W @ self.alpha_mu)
    self.log_pi = (
        self.X @ self.beta_pi + self.gamma_pi.T @ self.V.T + self.W @ self.alpha_pi)

    # float64 exponentiation, kept from the previous version.
    self.mu = torch.exp(self.log_mu.double())
    self.theta = torch.exp(self.log_theta.double())

    # Adaptive regulatory parameters are applied:
    p = self.mu / (self.mu + self.theta + 1e-4 + 1e-4 * self.mu + 1e-4 * self.theta)
    p = p.float()
    # Where mu still overflows to inf, p = inf / inf = nan; p is close to 1 there, so replace those entries.
    p[torch.isnan(p)] = 1-1e-10

    return p
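Two details may matter for the NaN replacement. First, 1-1e-10 is not representable in float32: the nearest float32 value is exactly 1.0, and because the assignment happens after p.float(), the replaced entries become 1.0. If argument validation is enabled, as the error message indicates, 1.0 fails the same [0, 1) check. Second, the limit of the regularized expression as mu goes to inf is 1/(1 + 1e-4), about 0.9999, not 1.

An alternative that avoids inf altogether is my own suggestion, not what the repository does: divide the numerator and denominator by mu and work with log_mu and log_theta directly. With c = 1 + 1e-4 and eps = 1e-4, the expression in forward() is mu / (c*mu + c*theta + eps), which equals 1 / (c + c*exp(log_theta - log_mu) + eps*exp(-log_mu)). The denominator is always at least c, so large log_mu gives p close to 1/c, and very negative log_mu makes the denominator overflow to inf, giving p = 0 instead of NaN. The helpers `p_original` and `p_stable` and the constants `EPS` and `C` below are hypothetical names for illustration.

```python
import torch

EPS = 1e-4     # regularization constant used in forward()
C = 1.0 + EPS  # coefficient on mu and theta after collecting terms

def p_original(log_mu, log_theta):
    """Expression from forward(); returns nan where exp(log_mu) overflows."""
    mu, theta = torch.exp(log_mu), torch.exp(log_theta)
    return mu / (mu + theta + EPS + EPS * mu + EPS * theta)

def p_stable(log_mu, log_theta):
    """Same quantity with numerator and denominator divided by mu (hypothetical helper)."""
    denom = C + C * torch.exp(log_theta - log_mu) + EPS * torch.exp(-log_mu)
    return 1.0 / denom  # denom >= C, so p <= 1 / C < 1

log_mu = torch.tensor([-120.0, 0.0, 5.0, 88.8, 200.0])
log_theta = torch.zeros_like(log_mu)

out_orig = p_original(log_mu, log_theta)   # nan for the last two entries
out_stable = p_stable(log_mu, log_theta)   # finite everywhere
print(out_orig, out_stable)

# 1 - 1e-10 rounds to exactly 1.0 in float32.
print(torch.tensor(1 - 1e-10, dtype=torch.float32).item())
```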

HH197 commented 1 year ago

Thank you so much for raising the issue and providing potential solutions.

I have reviewed the issue and would like to help. Could you please share the data needed to reproduce it?

I will start investigating this issue and will provide updates as I make progress. If you have any additional information or thoughts, please feel free to share them.

Looking forward to resolving this issue together!