lucidrains / FLASH-pytorch

Implementation of the Transformer variant proposed in "Transformer Quality in Linear Time"
MIT License

AttributeError: module 'torch' has no attribute 'special' #12

Closed: bibo-msft closed this issue 10 months ago

bibo-msft commented 10 months ago

I have torch==1.7.0a0 installed, but when I run `out = gau(x)`, I get an error:

Traceback (most recent call last):
  File "", line 1, in <module>
  File "MYPATH/work/venv/pytorch-rocm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "MYPATH/work/flash-pytorch/git-FLASH-pytorch/flash_pytorch/flash_pytorch.py", line 201, in forward
    attn = self.attn_fn(sim / seq_len)
  File "MYPATH/work/venv/pytorch-rocm/lib/python3.8/site-packages/torch/nn/modules/module.py", line 727, in _call_impl
    result = self.forward(*input, **kwargs)
  File "MYPATH/work/flash-pytorch/git-FLASH-pytorch/flash_pytorch/flash_pytorch.py", line 131, in forward
    return (1 + torch.special.erf((x - mu) / (std * math.sqrt(2)))) * 0.5
AttributeError: module 'torch' has no attribute 'special'
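For reference, the expression returned on line 131 of `flash_pytorch.py` is the cumulative distribution function of a normal distribution with mean `mu` and standard deviation `std` (written $\mu$ and $\sigma$ below). That is why the error function is the only special function it needs:

```math
\Phi\left(\frac{x-\mu}{\sigma}\right) = \frac{1}{2}\left[1 + \operatorname{erf}\left(\frac{x-\mu}{\sigma\sqrt{2}}\right)\right]
```

So any `erf` implementation, including the long-standing `torch.erf`, gives the same values.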

It seems that `torch.special` was introduced in PyTorch 1.9.
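A minimal sketch of a version-tolerant lookup, assuming the only goal is to get an error function on both older and newer PyTorch builds. It uses `torch.special.erf` where the `torch.special` module exists and falls back to `torch.erf` otherwise. The name `erf_compat` is hypothetical and not part of FLASH-pytorch.

```python
import torch

# Hypothetical helper (not part of FLASH-pytorch): prefer torch.special.erf
# when the torch.special module exists (PyTorch >= 1.9); otherwise fall back
# to torch.erf, which is also available in older releases such as 1.7.
if hasattr(torch, "special"):
    erf_compat = torch.special.erf
else:
    erf_compat = torch.erf

x = torch.tensor([-1.0, 0.0, 1.0])
# erf is an odd function, so the printed values are symmetric around 0
print(erf_compat(x))
```

Feature detection with `hasattr` avoids parsing version strings such as `1.7.0a0`, which pre-release builds like the one in this issue can make awkward.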

lucidrains commented 10 months ago

@bibo-msft ah ok, I put in a quick fix, thanks!

bibo-msft commented 10 months ago

I used `torch.erf` instead of `torch.special.erf`, and it now works for me with torch==1.7.0a0, so there is no need to raise the requirement to >1.9.
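A minimal sketch of that workaround, assuming the attention function only needs the expression shown on line 131 of the traceback. The class name `LaplaceAttnFnCompat` and the default `mu`/`std` values are illustrative assumptions, not the constants used in FLASH-pytorch; only the formula itself comes from the traceback.

```python
import math

import torch
from torch import nn


class LaplaceAttnFnCompat(nn.Module):
    """Hypothetical stand-in for the attention function in the traceback.

    Computes 0.5 * (1 + erf((x - mu) / (std * sqrt(2)))) with torch.erf,
    which exists in PyTorch 1.7, instead of torch.special.erf (1.9+).
    The default mu and std are placeholders, not the library's values.
    """

    def __init__(self, mu: float = 0.0, std: float = 1.0):
        super().__init__()
        self.mu = mu
        self.std = std

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return (1 + torch.erf((x - self.mu) / (self.std * math.sqrt(2)))) * 0.5


attn_fn = LaplaceAttnFnCompat(mu=0.5, std=0.3)
sim = torch.randn(2, 4, 4)
attn = attn_fn(sim)

# Cross-check: the result should match the normal CDF with the same mean and std
ref = torch.distributions.Normal(0.5, 0.3).cdf(sim)
print(torch.allclose(attn, ref, atol=1e-6))
```

On PyTorch versions that provide both functions, `torch.erf` and `torch.special.erf` compute the same error function, so swapping one for the other should not change the output.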


I will let you know if there are further issues with my torch==1.7.0a0 setup.