cvignac / DiGress

code for the paper "DiGress: Discrete Denoising diffusion for graph generation"
MIT License

Reporting KL divergence loss for training step #86

Open chinmay5 opened 8 months ago

chinmay5 commented 8 months ago

Thank you for releasing the code. I am using a custom dataset of 10k graphs. To check whether the model overfits on this smaller dataset, I modified the code to also report the KL divergence terms during training. While PosMSE looks fine, E_kl and X_kl are always NaN for the training samples. Can you please tell me if there is something wrong with my approach?

I added the KL metrics to the training metric collection:

```python
self.train_metrics = torchmetrics.MetricCollection([custom_metrics.PosMSE(), custom_metrics.XKl(), custom_metrics.EKl()])
```
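A frequent source of NaN in a categorical KL metric is a zero probability passed through a log. The sketch below is plain Python and does not reproduce the `XKl`/`EKl` classes, whose internals are not shown here. It assumes those metrics accumulate the element-wise term `xlogy(p, p) - p * log(q)` on log-probabilities computed beforehand, which as far as I know is how PyTorch's `F.kl_div` evaluates its input when `log_target=False`. Under that assumption, a class where both the true posterior `p` and the prediction `q` are exactly 0 gives `0 * (-inf) = nan`, while clamping `q` before the log (the hypothetical `guarded_kl`) keeps the value finite:

```python
import math


def log_or_neg_inf(x: float) -> float:
    """log(x), returning -inf for x == 0 instead of raising like math.log does."""
    return math.log(x) if x > 0 else float("-inf")


def xlogy(x: float, y: float) -> float:
    """x * log(y), defined as 0 when x == 0 (the same convention as torch.xlogy)."""
    return 0.0 if x == 0 else x * log_or_neg_inf(y)


def naive_kl(p: list[float], q: list[float]) -> float:
    """KL(p || q) as sum(xlogy(p, p) - p * log(q)), with log(q) taken up front."""
    log_q = [log_or_neg_inf(qi) for qi in q]
    return sum(xlogy(pi, pi) - pi * lqi for pi, lqi in zip(p, log_q))


def guarded_kl(p: list[float], q: list[float], eps: float = 1e-12) -> float:
    """Hypothetical fix: same quantity, but q is clamped to eps before the log."""
    log_q = [math.log(max(qi, eps)) for qi in q]
    return sum(xlogy(pi, pi) - pi * lqi for pi, lqi in zip(p, log_q))


# The true posterior puts no mass on class 2, and the prediction assigns it exactly 0.
p_true = [0.7, 0.3, 0.0]
p_pred = [0.6, 0.4, 0.0]

print(naive_kl(p_true, p_pred))    # nan: 0 * (-inf) for class 2
print(guarded_kl(p_true, p_pred))  # finite (about 0.0216)
```

Skipping terms with `p == 0` would also remove the NaN, but a class with `p > 0` and `q == 0` would still give `inf`, which clamping avoids.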

In my `training_step`, I call:

```python
nll, log_dict = self.compute_train_nll_loss(pred, z_t, clean_data=dense_data)
```
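One edge case that may be worth ruling out: `compute_train_nll_loss` sets `s_int = t_int - 1`. If the training-time noise sampler can draw `t_int = 0` (an assumption to check in the code that builds `z_t`), then `s_int = -1`, and a negative index into a schedule silently wraps to the last entry instead of raising. The sketch below uses a plain list and a hypothetical cosine-style `alpha_bar` schedule, not DiGress's own noise-schedule class, to show the wraparound and an optional strict check:

```python
import math


def cosine_alpha_bar(T: int, s: float = 0.008) -> list[float]:
    """Hypothetical cosine-style schedule: alpha_bar[t] for t = 0..T (length T + 1)."""
    f = [math.cos(((t / T) + s) / (1 + s) * math.pi / 2) ** 2 for t in range(T + 1)]
    return [v / f[0] for v in f]


def alpha_bar_at(schedule: list[float], t_int: int, strict: bool = False) -> float:
    """Look up alpha_bar[t_int]; with strict=True, negative timesteps are rejected."""
    if strict and t_int < 0:
        raise ValueError(f"timestep {t_int} is out of range")
    return schedule[t_int]  # without the check, a negative index wraps to the end


T = 500
alpha_bar = cosine_alpha_bar(T)

t_int = 0          # a timestep a training sampler might draw
s_int = t_int - 1  # -1, as in compute_train_nll_loss

print(alpha_bar_at(alpha_bar, t_int))  # 1.0: essentially clean data
print(alpha_bar_at(alpha_bar, s_int))  # same value as alpha_bar[T]: nearly pure noise
```

Whether this actually produces NaN depends on how the posterior is computed from these values; the sketch only shows that `s_int = -1` does not fail loudly.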

The method is defined as follows:

```python
def compute_train_nll_loss(self, pred, z_t, clean_data):
    node_mask = z_t.node_mask
    t_int = z_t.t_int
    s_int = t_int - 1  # previous timestep s = t - 1
    logger_metric = self.train_metrics

    # 1. Log-probability of each graph's node count N under the node-count distribution
    N = node_mask.sum(1).long()
    log_pN = self.node_dist.log_prob(N)

    # 2. The KL between q(z_T | x) and p(z_T) = Uniform(1/num_classes). Should be close to zero.
    kl_prior = self.kl_prior(clean_data, node_mask)

    # 3. Diffusion loss
    loss_all_t = self.compute_Lt(clean_data, pred, z_t, s_int, node_mask, logger_metric)

    # Combine terms
    nlls = -log_pN + kl_prior + loss_all_t
    # Update NLL metric object and return batch nll
    nll = self.train_nll(nlls)  # Average over the batch

    log_dict = {"train kl prior": kl_prior.mean(),
                "Estimator loss terms": loss_all_t.mean(),
                "log_pn": log_pN.mean(),
                "train_nll": nll}
    return nll, log_dict
```
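Because `nlls` is combined per sample and then averaged by `self.train_nll`, a single NaN in any sample's `loss_all_t` is enough to make the logged batch value NaN. A minimal sketch in plain Python, with made-up numbers (not real training output) and a hypothetical helper `find_nan_samples`, that combines the three terms the same way and reports which samples and which term are non-finite:

```python
import math


def combine_nll(log_pN: list[float], kl_prior: list[float], loss_all_t: list[float]) -> list[float]:
    """Per-sample NLL estimate: -log p(N) + KL prior + diffusion term."""
    return [-a + b + c for a, b, c in zip(log_pN, kl_prior, loss_all_t)]


def find_nan_samples(terms: dict[str, list[float]]) -> dict[str, list[int]]:
    """Hypothetical debugging helper: indices of non-finite values for each named term."""
    return {name: [i for i, v in enumerate(vals) if not math.isfinite(v)]
            for name, vals in terms.items()}


# Illustrative values for a batch of 3 graphs; sample 1 has a NaN diffusion term.
terms = {
    "log_pN": [-2.1, -1.8, -2.4],
    "kl_prior": [1e-4, 2e-4, 1e-4],
    "loss_all_t": [35.0, float("nan"), 41.0],
}

nlls = combine_nll(terms["log_pN"], terms["kl_prior"], terms["loss_all_t"])
batch_mean = sum(nlls) / len(nlls)

print(batch_mean)               # nan: one bad sample poisons the mean
print(find_nan_samples(terms))  # {'log_pN': [], 'kl_prior': [], 'loss_all_t': [1]}
```

On the real tensors, `torch.isfinite` gives the same per-element check for `log_pN`, `kl_prior` and `loss_all_t`.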

Any help would be highly appreciated.

Best, Chinmay