HannesStark / dirichlet-flow-matching


Question about FBD calculation #2

Closed: junhaobearxiong closed this issue 4 months ago

junhaobearxiong commented 4 months ago

Hi Hannes!

Thank you for the great work! I'm trying to reproduce the evaluations in the enhancer experiments, but I have a few questions about reproducing the FBD scores.

  1. For the results in Table 2 and Figure 5, are the FBDs calculated against the train, validation, or test set? I ran the commands for run names FB_dirichlet_fm (for unconditional generation) and dirichlet_fm_target2_guidance20 (for conditional generation with target class 2, guidance factor 20). The relevant fields in the outputted metrics.csv seem to be val_fxd_generated_to_allseqs for FBD against the unconditional data distribution (used in Table 2), val_fxd_generated_to_targetclsseqs for FBD against the class-conditional data distribution (used in Figure 5, left y-axis), and cleancls_mean_target_prob for the class-conditional probability (used in Figure 5, right y-axis). But when running the generated sequences in val_0.csv through the FBD calculation (see the sketch after this list) against either the train, valid, or test set, I can't exactly reproduce the numbers in metrics.csv. The numbers in metrics.csv closely (but not exactly) match the FBD calculated against the training set. I just want to confirm that calculating the FBD against the training set is the right way to reproduce the results. Also, how many generated sequences are used in Figure 5?
  2. To calculate the FBD of random sequences in Table 2, the code seems to be here. It's likely I'm missing something, but I'm wondering if this is correct: the code seems to compute the FBD using the random sequences themselves, but shouldn't the FBD be calculated using the embeddings of the random sequences? When using the embeddings of 10k random sequences, I got an unconditional FBD (against training set) of 97.22 instead of 880.38, and a conditional FBD (against class 2 in the training set) of 344.64 instead of 948.26.
  3. The paper mentions that classifier guidance performs worse than CFG but still yields significant improvements. I couldn't find the command to reproduce the classifier-guidance results. Would it be possible to make those available?
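For context, here is a minimal sketch of the FID-style Fréchet distance computation I referred to in point 1, operating on two embedding matrices (e.g., hidden activations of the clean classifier). This is an illustrative reimplementation, not the repo's code:

```python
import numpy as np
from scipy.linalg import sqrtm

def frechet_distance(emb_a: np.ndarray, emb_b: np.ndarray) -> float:
    """Frechet distance between Gaussians fit to two embedding matrices
    of shape (num_sequences, embedding_dim):
    ||mu_a - mu_b||^2 + Tr(cov_a + cov_b - 2 (cov_a cov_b)^(1/2))."""
    mu_a, mu_b = emb_a.mean(axis=0), emb_b.mean(axis=0)
    cov_a = np.cov(emb_a, rowvar=False)
    cov_b = np.cov(emb_b, rowvar=False)
    covmean = sqrtm(cov_a @ cov_b)
    if np.iscomplexobj(covmean):
        covmean = covmean.real  # discard tiny imaginary parts from numerical noise
    diff = mu_a - mu_b
    return float(diff @ diff + np.trace(cov_a + cov_b - 2.0 * covmean))
```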

Happy to include more details of my attempted reproductions if that's helpful!

Best, Bear

HannesStark commented 4 months ago

Hi Bear!

Thanks for your interest and for trying the code!

  1. Following image generative model evaluations, the FBD metrics are indeed calculated against the training set, and we evaluate the ability to reproduce the training distribution. The number of examples used to calculate the FBD is the validation set size, which is 9012 for Melanoma and 10434 for Flybrain. Since the generated sequences are random, it is possible that the produced numbers are slightly different, but that difference should be very small. Could you let me know what differences you observe?
  2. Thank you very much for pointing that out. I agree that the random FBD should be calculated against embeddings of the random sequences (a sketch of the corrected baseline follows after this list); we will update that in the next version of the paper and acknowledge your help in finding this.
  3. The results for classifier guidance are in Figure 9 (it is much worse than classifier-free guidance). Here is an example command where you can change the class and the guidance factor to obtain the results:

```bash
python -m train_dna \
  --run_name inf_DNA_target2_clsGuidance3 \
  --batch_size 256 \
  --print_freq 200 \
  --wandb \
  --dataset_type enhancer \
  --num_integration_steps 100 \
  --model cnn \
  --num_cnn_stacks 4 \
  --clean_cls_ckpt_hparams workdir/clsDNAclean_cnn_1stack_2023-12-30_15-01-30/lightning_logs/version_0/hparams.yaml \
  --clean_cls_ckpt workdir/clsDNAclean_cnn_1stack_2023-12-30_15-01-30/epoch=15-step=10480.ckpt \
  --target_class 2 \
  --check_val_every_n_epoch 1 \
  --subset_train_as_val \
  --mode dirichlet \
  --fid_early_stop \
  --max_epochs 741 \
  --limit_train_batches 3 \
  --limit_val_batches 100 \
  --validate_on_test \
  --ckpt workdir/DNA_diri_target2_valOnTrain_noDropoutEval_2024-01-08_10-11-25/epoch=739-step=242720.ckpt \
  --cls_guidance \
  --guidance_scale 3 \
  --cls_ckpt_hparams workdir/clsDNA_lr1e-3_2023-12-31_12-28-43/lightning_logs/version_0/hparams.yaml \
  --cls_ckpt workdir/clsDNA_lr1e-3_2023-12-31_12-28-43/epoch=184-step=121175.ckpt
```
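On point 2 above, a minimal sketch of what the corrected random-sequence baseline could look like. Sequence sizes are illustrative only, and clean_classifier.embed is a hypothetical stand-in for the clean classifier's feature extractor, not the repo's actual API:

```python
import numpy as np

rng = np.random.default_rng(0)
num_seqs, seq_len, alphabet = 10_000, 500, 4  # illustrative sizes only

# Uniformly random DNA sequences as integer tokens in {0, 1, 2, 3} (A/C/G/T).
random_tokens = rng.integers(0, alphabet, size=(num_seqs, seq_len))
random_onehot = np.eye(alphabet)[random_tokens]  # shape (num_seqs, seq_len, 4)

# Corrected baseline: compare embeddings to embeddings, not raw sequences.
#   emb_random = clean_classifier.embed(random_onehot)  # hypothetical API
#   emb_train  = clean_classifier.embed(train_onehot)
#   fbd_random = frechet_distance(emb_random, emb_train)  # from the sketch above
```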

I uploaded all weights needed for this here: https://publbuck.s3.us-east-2.amazonaws.com/cls_guidance_ckpts.zip

Please let me know if I can help in any other way!

junhaobearxiong commented 4 months ago

Thank you for the response Hannes!

For the unconditional FBD (calculated against the entire training set, with 10434 sampled sequences), I got 15.969 for FB_dirichlet_fm; and for the conditional FBD (calculated against all sequences in the training set with target class 2, guidance factor 20), I got 100.706. If these seem reasonable, I will use the same FBD settings when reproducing the other results.

A couple of additional questions:

  1. I calculated the FBD of the validation set against the training set, hoping to get a sense of how to calibrate the results in Table 2 and Figure 5, and got the following: Uncond FBD: 3.44; Cond FBD (cls=35): 99.27; Cond FBD (cls=2): 113.00; Cond FBD (cls=68): 92.99; Cond FBD (cls=16): 40.99 (see the per-class sketch after the attached plot). Do you think these serve as reasonable lower bounds for the FBDs? One potential issue is that the validation set has far fewer sequences per class, which may lead to higher variance in the conditional FBD score, so we might only want to evaluate on classes with a reasonable number of samples.

  2. Somewhat relatedly, I wanted to evaluate guidance on a few more classes to get a better sense. Supplemental section B.2 of the paper mentions that the chosen classes (35, 2, 68, 16) were based on the tradeoff between AUROC and AUPR. However, looking at Figure 11, it's only clear to me that the classifier performs well for class 2. I also tried to reproduce this figure using the provided checkpoint and got something pretty different (plot attached). Since both the FBD and the class probability rely on the clean classifier, I think it would make the most sense to evaluate on classes where the clean classifier performs well (i.e., classes in the upper right of the AUROC vs. AUPR plot), so I want to make sure I'm looking at the right plot.

[Attached screenshot: reproduced per-class AUROC vs. AUPR plot]
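Regarding point 1 above, a minimal sketch of the per-class calibration check, assuming val_emb/train_emb embedding matrices and val_labels/train_labels class arrays (illustrative names, not the repo's API):

```python
import numpy as np

def conditional_fbd(val_emb, val_labels, train_emb, train_labels, cls):
    """Conditional FBD for one class, plus the per-class sample counts
    that drive the variance of the estimate."""
    val_subset = val_emb[val_labels == cls]
    train_subset = train_emb[train_labels == cls]
    fbd = frechet_distance(val_subset, train_subset)  # from the earlier sketch
    return fbd, len(val_subset), len(train_subset)

# for cls in (35, 2, 68, 16):
#     fbd, n_val, n_train = conditional_fbd(val_emb, val_labels,
#                                           train_emb, train_labels, cls)
#     print(f"class {cls}: FBD={fbd:.2f} (n_val={n_val}, n_train={n_train})")
```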

Happy to provide more details if helpful! Thank you so much!

Best, Bear

HannesStark commented 4 months ago

Likewise!

While I also respond below, would you be interested in a quick Zoom call to go back and forth more quickly? I'd also be curious to chat in general. Let me know if so: hannes.staerk@gmail.com

The 15.969 FBD seems reasonable to me. As for the conditional FBD of 100.706 (against all sequences in the training set with target class 2, guidance factor 20), the number seems ~8% higher than ours, but maybe this could be explained by the higher variance that you mention in your next point 1.

  1. The 3.44 is surprisingly much better than 5.3, and, e.g., 99.27 is surprisingly higher than 56 (the latter could be explained by the variance you mention, but for the former that is less likely).

  2. I think you numbered the dots with 0-indexing (which the code also uses), while the appendix plot in the paper starts from 1. With that in mind, the corresponding classes 3, 17, and 69 are close to the right. And we chose the PNG class because Taskiran et al. had their sequences available for it. Nevertheless, the AUPR scale is different. The classifier in that plot was among the first we ran (I think it overfit, with a larger network and without early stopping), and we will clarify in the caption that the plot used for choosing the classes was not from the same weights as the final evaluation classifier.

junhaobearxiong commented 4 months ago

For the record here: Hannes and I had a call and clarified things. In summary, I didn't have issues reproducing the FBD results from the paper. The main difference was that I computed the FBD against the entire training set, whereas the paper computed the FBD against 10434 randomly sampled sequences from the training set. The numbers I mentioned in my point 1 were more for my own exploration and were not numbers reported in the paper.
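To make that distinction concrete, a minimal sketch of the paper's protocol as described above, assuming train_emb is an (N, D) matrix of training-set embeddings (names are illustrative):

```python
import numpy as np

rng = np.random.default_rng(0)
subset_size = 10_434  # Flybrain validation-set size quoted earlier in the thread

def subsample_embeddings(train_emb: np.ndarray, n: int) -> np.ndarray:
    """Draw n training embeddings without replacement, so the reference set
    matches the validation-set size rather than the full training set."""
    idx = rng.choice(len(train_emb), size=n, replace=False)
    return train_emb[idx]

# fbd = frechet_distance(generated_emb,
#                        subsample_embeddings(train_emb, subset_size))
```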