A float attention mask consumes too much GPU memory and makes the attention kernel slow. This PR switches to a 0/1 attention mask stored as a bit-packed array (1 bit per element, 8 elements packed into one uint8) to save GPU memory.
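For illustration, here is a minimal sketch of the bit-packing idea on the host side, using NumPy rather than the PR's actual CUDA/host code; the helper names `pack_mask` / `unpack_mask` are hypothetical and not part of this PR's API:

```python
import numpy as np

def pack_mask(mask: np.ndarray) -> np.ndarray:
    # Pack a 0/1 mask into uint8, 8 mask bits per byte.
    # A float32 mask costs 4 bytes/element; the packed form costs
    # 1 bit/element, i.e. roughly a 32x memory reduction.
    flat = mask.astype(np.uint8).ravel()
    return np.packbits(flat)  # bitorder="big" by default

def unpack_mask(packed: np.ndarray, shape: tuple) -> np.ndarray:
    # Recover the boolean mask; `count` trims the padding bits in the
    # final byte when the element count is not a multiple of 8.
    n = int(np.prod(shape))
    bits = np.unpackbits(packed, count=n)
    return bits.reshape(shape).astype(bool)

if __name__ == "__main__":
    mask = np.tril(np.ones((5, 7), dtype=bool))  # causal-style 0/1 mask
    packed = pack_mask(mask)
    assert np.array_equal(unpack_mask(packed, mask.shape), mask)
    print(f"float32 mask: {mask.size * 4} bytes -> packed: {packed.nbytes} bytes")
```

Inside the kernel, each thread would read one uint8 and test the relevant bit, so the memory-traffic saving carries over to the attention computation as well.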