Following Flux.jl's MultiHeadAttention implementation, dot-product attention now uses only permutedims and batched_mul.
Big improvements in training (dataset 2 as an example):
- Memory allocation per training step down from 24 GiB to 10 GiB
- Step execution time down from 651.299 ms ± 99.684 ms to 143.854 ms ± 77.410 ms
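For reference, the core of this style of attention can be sketched with just those two primitives. The function name dot_attention and the (d, len, batch) array layout below are illustrative assumptions, not the exact code in this repository:

```julia
using NNlib: batched_mul, softmax

# Illustrative sketch: batched dot-product attention built only from
# permutedims and batched_mul. Function name and layout are assumptions.
# q, k, v are (d, len, batch) arrays.
function dot_attention(q, k, v)
    d = size(q, 1)
    # kᵀq for every batch slice:
    # (len_k, d, batch) × (d, len_q, batch) -> (len_k, len_q, batch)
    scores = batched_mul(permutedims(k, (2, 1, 3)), q) ./ sqrt(Float32(d))
    α = softmax(scores; dims = 1)   # attention weights over keys, one column per query
    # (d, len_k, batch) × (len_k, len_q, batch) -> (d, len_q, batch)
    return batched_mul(v, α)
end

# usage example: self-attention on random data
# q = rand(Float32, 64, 10, 8); dot_attention(q, q, q)  # -> 64×10×8 array
```

Because both primitives operate on whole 3-d arrays rather than slice by slice, the intermediate allocations stay limited to the score and output arrays.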