marsggbo closed this issue 1 year ago.
The original code is below: https://github.com/lucidrains/FLASH-pytorch/blob/edce0fd9a172e65d94844162fd7b31defa1e9fea/flash_pytorch/flash_pytorch.py#L338
Is that a typo? Maybe the correct version is `n=x.shape[-2]`, or set `g=self.group_size`.
@marsggbo ohh yea, the einops equation isn't very clear
It should be `b (n g) d -> b n g d`, with `g = self.group_size`, but otherwise it is correct.
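Here is a minimal sketch of the grouping step under discussion. It uses NumPy in place of einops, which is an assumption for illustration, and `group_sequence` is a hypothetical helper, not part of FLASH-pytorch. It assumes the sequence length is already a multiple of the group size. Under that assumption, `b (n g) d -> b n g d` with `g = group_size` splits the sequence axis into `n` groups of `g` tokens each, with `g` as the innermost axis. That is the same as a row-major reshape. Writing the pattern with the letters swapped (`b (g n) d -> b g n d` with `n = group_size`) yields the identical tensor. Only the letter naming the group size changes, which is why the pattern is correct but hard to read.

```python
import numpy as np


def group_sequence(t: np.ndarray, group_size: int) -> np.ndarray:
    """Hypothetical helper: split the sequence axis of a (batch, seq_len, dim) array.

    Sketch of einops' rearrange(t, 'b (n g) d -> b n g d', g=group_size):
    n = number of groups, g = tokens per group (the innermost axis).
    """
    b, seq_len, d = t.shape
    if seq_len % group_size != 0:
        raise ValueError("seq_len must be divisible by group_size")
    # Row-major reshape: consecutive runs of group_size tokens form one group.
    return t.reshape(b, seq_len // group_size, group_size, d)


x = np.arange(2 * 8 * 3).reshape(2, 8, 3)  # batch=2, seq_len=8, dim=3
grouped = group_sequence(x, group_size=4)
print(grouped.shape)  # (2, 2, 4, 3)
# Token 5 of batch 0 lands in group 1, position 1 (5 = 1 * 4 + 1).
assert (grouped[0, 1, 1] == x[0, 5]).all()
```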