atong01 / conditional-flow-matching

TorchCFM: a Conditional Flow Matching library
https://arxiv.org/abs/2302.00482
MIT License

Variance reduction over minibatches? #91

Closed: ameya98 closed this issue 8 months ago

ameya98 commented 11 months ago

Hi, thanks for this great implementation! I was wondering where the variance reduction method described in Appendix C.1 is implemented in this code. Also, is there an easy way to reproduce the variance reduction results in Appendix D.1?

atong01 commented 10 months ago

Hi @ameya98,

The code I used is here: https://github.com/atong01/conditional-flow-matching/blob/6b3adb46750df045e9285f0b2833382ed87cec7a/runner/src/models/cfm_module.py#L203-L216
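
For readers who want the idea without opening the file, here is a minimal sketch, not the repository's code, of one way to reduce the variance of the regression target over a minibatch. It assumes the method replaces the single-pair target x1 - x0 with an average of x1_k - x0_k over K conditioning pairs (K playing the role of `avg_size`), weighted by the Gaussian path density p_t(x | x0, x1) = N(x; (1 - t) * x0 + t * x1, sigma^2 I). That is a self-normalized Monte Carlo estimate of the marginal vector field u_t(x). The names `conditional_mean` and `averaged_target` are hypothetical; check the linked lines for what `avg_size` actually does.

```python
import math
import random

# Hypothetical sketch (not the TorchCFM code). Assumption: variance reduction
# replaces the single-pair target x1 - x0 with a density-weighted average of
# x1_k - x0_k over K conditioning pairs, where K plays the role of avg_size.


def conditional_mean(x0, x1, t):
    """Mean of the Gaussian path p_t(x | x0, x1): (1 - t) * x0 + t * x1."""
    return [(1 - t) * a + t * b for a, b in zip(x0, x1)]


def averaged_target(x, t, pairs, sigma):
    """Self-normalized Monte Carlo estimate of the marginal field u_t(x).

    Pair k contributes its conditional field x1_k - x0_k with weight
    proportional to N(x; mu_t(x0_k, x1_k), sigma^2 I). The Gaussian
    normalizing constant is shared by all pairs, so it cancels.
    """
    log_w = []
    for x0, x1 in pairs:
        mu = conditional_mean(x0, x1, t)
        sq_dist = sum((xi - mi) ** 2 for xi, mi in zip(x, mu))
        log_w.append(-sq_dist / (2 * sigma**2))
    shift = max(log_w)  # log-sum-exp style shift keeps the largest weight at 1
    w = [math.exp(lw - shift) for lw in log_w]
    total = sum(w)
    u = [0.0] * len(x)
    for wk, (x0, x1) in zip(w, pairs):
        for d in range(len(x)):
            u[d] += (wk / total) * (x1[d] - x0[d])
    return u


# Usage: sample x from the path of the first pair, then compare targets.
random.seed(0)
sigma, t = 0.5, 0.5
pairs = [
    ([random.gauss(0, 1), random.gauss(0, 1)],  # x0 ~ N((0, 0), I)
     [random.gauss(3, 1), random.gauss(0, 1)])  # x1 ~ N((3, 0), I)
    for _ in range(8)
]
x0, x1 = pairs[0]
x = [m + sigma * random.gauss(0, 1) for m in conditional_mean(x0, x1, t)]
print(averaged_target(x, t, pairs[:1], sigma))  # K = 1: exactly x1 - x0
print(averaged_target(x, t, pairs, sigma))      # K = 8: weighted average
```

With a single pair the estimate reduces to the ordinary CFM target x1 - x0; with more pairs it becomes a convex combination of the pairs' conditional fields, which under this assumption is where the lower variance would come from.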

I believe a command like the following should get you close to the results in Appendix D.1:

#!/bin/bash
python src/train.py -m experiment=cfm \
  model=sbcfm,cfm,fm \
  +model.avg_size=1,2,4,8,16,32,64,128,256,512 \
  datamodule=twodim \
  datamodule.batch_size=512 \
  trainer.check_val_every_n_epoch=1 \
  trainer.max_epochs=1000 \
  seed=42,43,44,45,46 &

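
As a rough check of how much compute this sweep needs: with `-m` (multirun), Hydra's default basic sweeper launches one job per combination of the comma-separated override values, so the command above runs 3 models × 10 averaging sizes × 5 seeds. A short Python sketch of that grid (the launch order Hydra actually uses may differ):

```python
from itertools import product

# Sweep values copied from the command above.
models = ["sbcfm", "cfm", "fm"]
avg_sizes = [2**k for k in range(10)]  # 1, 2, 4, ..., 512
seeds = [42, 43, 44, 45, 46]

# Cartesian product, one tuple per Hydra job.
runs = list(product(models, avg_sizes, seeds))
print(len(runs))  # 150
for model, avg_size, seed in runs[:3]:
    print(f"model={model} +model.avg_size={avg_size} seed={seed}")
```

Each of those 150 jobs trains for `trainer.max_epochs=1000` epochs. The leading `+` in `+model.avg_size` is Hydra's override syntax for appending a key that is not already present in the composed config.
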
kilianFatras commented 10 months ago

@ameya98, did it work for you?