PlaytikaOSS / tft-torch

A Python library that implements "Temporal Fusion Transformers for Interpretable Multi-horizon Time Series Forecasting"
MIT License

Usage of flash attention #12

Open shaharbar1 opened 9 months ago

shaharbar1 commented 9 months ago

Consider wrapping the call to self.attention in InterpretableMultiHeadAttention in a with torch.backends.cuda.sdp_kernel(enable_flash=True, enable_math=False, enable_mem_efficient=True): block, in order to improve speed and memory efficiency.
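
A minimal sketch of what this could look like. The self.attention attribute name comes from the issue text; the surrounding forward signature and the assumption that self.attention behaves like torch.nn.MultiheadAttention are illustrative, not taken from the tft-torch source:

```python
import torch

# Sketch only: the forward signature below is assumed for illustration;
# the issue only names the self.attention call that should be wrapped.
def forward(self, q, k, v, mask=None):
    # Restrict PyTorch's scaled_dot_product_attention to the flash and
    # memory-efficient backends (disable the fallback math kernel) to
    # reduce runtime and memory use on CUDA.
    with torch.backends.cuda.sdp_kernel(
        enable_flash=True,
        enable_math=False,
        enable_mem_efficient=True,
    ):
        attn_output, attn_weights = self.attention(q, k, v, attn_mask=mask)
    return attn_output, attn_weights
```

One caveat: the fused flash backend generally applies only when the attention weight matrix does not need to be materialized and returned, so whether this helps may depend on how the per-head attention scores are used for interpretability.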

otto-dev commented 3 weeks ago

No; see https://github.com/pytorch/pytorch/issues/125674