guiyrt / MLinJulia

GSoC @ CERN-HSF - "Machine Learning in Julia for Calorimeter Showers"

Optimized attention dot product #2

Status: closed by guiyrt 1 month ago

guiyrt commented 1 month ago

Following Flux.jl's MultiHeadAttention implementation, dot-product attention now uses only permutedims and batched_mul. This brings big improvements in training performance (dataset 2 used as the example).
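A minimal sketch of the permutedims + batched_mul idea, assuming NNlib.jl (the package Flux builds on) and a `(features, sequence, batch)` array layout. The names `split_heads`, `merge_heads` and `attention` are hypothetical and not taken from this repository. The heads are folded into the batch dimension with `permutedims`/`reshape`, so the scores and the weighted sum are each a single `batched_mul` over all heads and batch elements at once.

```julia
using NNlib: batched_mul, batched_transpose, softmax

# Fold heads into the batch dimension:
# (nheads*dh, len, batch) -> (dh, len, nheads*batch)
function split_heads(x::AbstractArray{<:Real,3}, nheads::Integer)
    d, len, batch = size(x)
    dh = d ÷ nheads
    x4 = reshape(x, dh, nheads, len, batch)   # (dh, nheads, len, batch)
    x4 = permutedims(x4, (1, 3, 2, 4))        # (dh, len, nheads, batch)
    return reshape(x4, dh, len, nheads * batch)
end

# Inverse of split_heads:
# (dh, len, nheads*batch) -> (nheads*dh, len, batch)
function merge_heads(x::AbstractArray{<:Real,3}, nheads::Integer)
    dh, len, hb = size(x)
    batch = hb ÷ nheads
    x4 = reshape(x, dh, len, nheads, batch)   # (dh, len, nheads, batch)
    x4 = permutedims(x4, (1, 3, 2, 4))        # (dh, nheads, len, batch)
    return reshape(x4, dh * nheads, len, batch)
end

# Multi-head scaled dot-product attention with no per-head loop.
function attention(q, k, v, nheads::Integer)
    qh = split_heads(q, nheads)                # (dh, qlen, nheads*batch)
    kh = split_heads(k, nheads)                # (dh, klen, nheads*batch)
    vh = split_heads(v, nheads)                # (dh, klen, nheads*batch)
    dh = size(qh, 1)
    # scores[j, i, b] = sum_d kh[d, j, b] * qh[d, i, b] / sqrt(dh)
    scores = batched_mul(batched_transpose(kh), qh) ./ sqrt(eltype(qh)(dh))
    α = softmax(scores; dims=1)                # normalise over the keys
    out = batched_mul(vh, α)                   # (dh, qlen, nheads*batch)
    return merge_heads(out, nheads)            # (nheads*dh, qlen, batch)
end
```

Packing heads and batch into one trailing dimension lets each matrix product run as one batched GEMM instead of many small multiplications, which is the same layout trick Flux uses for its MultiHeadAttention.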