microsoft / protein-frame-flow

Fast protein backbone generation with SE(3) flow matching.
MIT License

How should the auxiliary losses be weighted? #15

Closed: amorehead closed this issue 9 months ago

amorehead commented 9 months ago

Hello. Thank you for making this work fully open-source.

I have a question about how FrameFlow's auxiliary losses (i.e., the backbone atom and pairwise distance losses) are gated by time. In the line of code referenced below, should the condition read t[:, 0] > training_cfg.aux_loss_t_pass, or instead t[:, 0] > (1 - training_cfg.aux_loss_t_pass)? As I understand it, the intention behind setting aux_loss_t_pass to 0.25 (the default) is to have the network learn these atomic-level details only during the last 0.25 time units, when we would expect highly plausible structures to emerge from the network.

In other words, shouldn't we backpropagate these losses only when t > 0.75, not when t > 0.25? Or, with flow matching, do highly plausible structures emerge much earlier in the sampling process (e.g., compared to diffusion models)?

https://github.com/microsoft/frame-flow/blob/b8d868b9666a4ee7405feb4f54a71c31488fedcb/models/flow_module.py#L135

jasonkyuyim commented 9 months ago

As you say, for flow matching we reverse time relative to diffusion, so setting aux_loss_t_pass=0.25 in fact turns on the aux losses for the last 75% of the generative process. This was intentional: we noticed that with flow matching the global structure is set early on, so there was a benefit to turning on the aux losses this early. We also didn't experiment much with different settings for aux_loss_t_pass, so the default is probably not optimal.
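
For readers comparing the two conditions discussed in this thread, here is a minimal NumPy sketch of the gating. It is not the repository's code. It assumes that t has shape (batch, 1), that t = 0 is noise and t = 1 is data, and that the auxiliary loss is a per-example vector. The function name gate_aux_loss and the example values are hypothetical.

```python
import numpy as np

def gate_aux_loss(aux_loss, t, aux_loss_t_pass=0.25):
    """Zero the per-example auxiliary loss wherever t <= aux_loss_t_pass.

    Assumption (not taken from the repository): t has shape (batch, 1),
    with t = 0 meaning pure noise and t = 1 meaning data.
    """
    aux_loss = np.asarray(aux_loss, dtype=float)
    t = np.asarray(t, dtype=float)
    # Condition discussed in the issue: keep the loss only for t > threshold.
    mask = t[:, 0] > aux_loss_t_pass
    return aux_loss * mask

# Hypothetical batch of four time steps, each with the same raw aux loss.
t = np.array([[0.10], [0.25], [0.30], [0.90]])
aux_loss = np.array([2.0, 2.0, 2.0, 2.0])

# Default condition: active for t > 0.25, the last 75% of the process.
gated = gate_aux_loss(aux_loss, t)

# Alternative reading from the question: active only for t > 0.75.
alt_mask = t[:, 0] > (1 - 0.25)

print(gated)     # loss kept only where t > 0.25
print(alt_mask)  # True only where t > 0.75
```

With a threshold of 0.25, only the examples at t = 0.30 and t = 0.90 keep their auxiliary loss. Under the alternative condition, only t = 0.90 would.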

amorehead commented 9 months ago

This answers my question. Thanks for clarifying that this was intentional. Given the trajectory animations you uploaded with this repository, I think your explanation makes a lot of sense in the context of flow matching.