lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
MIT License
344 stars, 24 forks

Is it a typo in FLASH module? #10

Closed: marsggbo closed this issue 1 year ago

marsggbo commented 1 year ago

The original code is below: https://github.com/lucidrains/FLASH-pytorch/blob/edce0fd9a172e65d94844162fd7b31defa1e9fea/flash_pytorch/flash_pytorch.py#L338

Is that a typo? Maybe the correct version is n=x.shape[-2], or it should set g=self.group_size.

lucidrains commented 1 year ago

@marsggbo ohh yea, the einops equation isn't very clear

it should be b (n g) d -> b n g d, with g = self.group_size, but otherwise it is correct
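
To illustrate the point about axis names, here is a minimal pure-Python sketch of what the pattern does along the sequence axis. It is not code from the repository: `split_into_groups` is a hypothetical helper, and the divisibility check is an assumption made for the example. In an einops pattern such as `b (g n) d -> b g n d`, the first name inside the parentheses is the outer (slower) axis and the second is the inner axis, so passing the inner axis as `self.group_size` yields consecutive groups of that size whichever letters are used.

```python
def split_into_groups(seq, group_size):
    """Split a flat sequence into consecutive groups of `group_size` items.

    Mirrors, for the sequence axis only, an einops pattern of the form
    'b (outer inner) d -> b outer inner d' with inner = group_size:
    the outer axis counts groups, the inner axis indexes positions
    inside one group. The axis names themselves are just labels.
    """
    # Assumption for this sketch: the length is already a multiple of group_size.
    if len(seq) % group_size != 0:
        raise ValueError("sequence length must be divisible by group_size")
    return [list(seq[i:i + group_size]) for i in range(0, len(seq), group_size)]


tokens = list(range(8))  # a toy sequence of 8 token positions
groups = split_into_groups(tokens, group_size=4)
print(groups)  # [[0, 1, 2, 3], [4, 5, 6, 7]]
```

Whether the pattern is written as `(g n)` with `n = group_size` or as `(n g)` with `g = group_size`, the inner axis has length `group_size` in both cases, which is why the original line behaves correctly even though its naming is confusing.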