jpata / particleflow

Machine-learned, GPU-accelerated particle flow reconstruction
Apache License 2.0

enable FlashAttention in pytorch, update to torch 2.2.0 #292

Closed: jpata closed this 8 months ago

jpata commented 9 months ago


The physics performance on the high-pT QCD sample is as follows:

Mamba (~98M): [plots: jet_res, met_res]

GNNLSH (~98M): [plots: jet_res, met_res]

FlashAttention (~4M): [plots: jet_res, met_res]

Training on the full dataset for 10 epochs (pyg-cms_20240208_214210_447656) gives the following performance: [plots: jet_res, met_res]
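For context on what the FlashAttention option computes: in PyTorch 2.x, FlashAttention is one of the fused backends behind `torch.nn.functional.scaled_dot_product_attention`. It produces the same result as ordinary softmax attention but processes keys and values block by block, keeping a running max and running sum (an "online softmax"), so the full N×N score matrix is never stored. The pure-Python sketch below illustrates only that idea for a single query row. It is not the CUDA kernel, and it is not this repository's attention layer. The function names `naive_attention` and `tiled_attention` and the `block_size` parameter are hypothetical.

```python
import math


def naive_attention(q, k, v):
    """Reference softmax(q . k_i / sqrt(d)) weighted sum of v_i for one query row."""
    scale = 1.0 / math.sqrt(len(q))
    scores = [scale * sum(qi * ki for qi, ki in zip(q, key)) for key in k]
    m = max(scores)  # subtract the max for numerical stability
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    return [sum(w * val[j] for w, val in zip(weights, v)) / total
            for j in range(len(v[0]))]


def tiled_attention(q, k, v, block_size=2):
    """Hypothetical sketch: online-softmax attention over key/value blocks."""
    scale = 1.0 / math.sqrt(len(q))
    d_v = len(v[0])
    running_max = -math.inf  # max score seen so far
    running_sum = 0.0        # sum of exp(score - running_max) so far
    acc = [0.0] * d_v        # unnormalized weighted sum of values
    for start in range(0, len(k), block_size):
        k_blk = k[start:start + block_size]
        v_blk = v[start:start + block_size]
        scores = [scale * sum(qi * ki for qi, ki in zip(q, key)) for key in k_blk]
        new_max = max(running_max, max(scores))
        # Rescale earlier partial results to the new max (exp(-inf) == 0.0 on the first block).
        correction = math.exp(running_max - new_max)
        running_sum *= correction
        acc = [a * correction for a in acc]
        for s, val in zip(scores, v_blk):
            w = math.exp(s - new_max)
            running_sum += w
            for j in range(d_v):
                acc[j] += w * val[j]
        running_max = new_max
    # Normalize once at the end, after all blocks have been seen.
    return [a / running_sum for a in acc]


if __name__ == "__main__":
    q = [0.5, -1.0]
    k = [[1.0, 0.0], [0.0, 1.0], [2.0, -1.0], [-1.0, 3.0]]
    v = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0], [7.0, 8.0]]
    print(naive_attention(q, k, v))
    print(tiled_attention(q, k, v, block_size=2))
```

Both functions should agree to floating-point precision for any `block_size`. Only the memory access pattern differs, which is where the fused kernel gets its speed and memory savings.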

jpata commented 9 months ago

Deps fail because https://data.pyg.org/whl/torch-2.2.0+cpu.html does not yet contain pyg-lib; pytorch-geometric was updated to PyTorch 2.2 only in the last few days.

jpata commented 9 months ago

pyg-lib is now released and the tests pass again.