lucidrains / reformer-pytorch

Reformer, the efficient Transformer, in Pytorch
MIT License
2.1k stars 254 forks source link

Autopadder not working with Reformer #115

Closed jamarju closed 4 years ago

jamarju commented 4 years ago

Running this example code:

import torch
from reformer_pytorch import Reformer, Autopadder

# Reproduction for issue #115: wrap a plain (non-LM) Reformer in Autopadder
# and run it on a sequence (7777) shorter than max_seq_len (8192), so that
# Autopadder has to pad the input before the forward pass.

# The failure being reported is a tensor-rank mismatch, not a device issue,
# so fall back to CPU when no GPU is present.
device = 'cuda' if torch.cuda.is_available() else 'cpu'

model = Reformer(
    dim = 512,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = False
).to(device)

# Autopadder pads the input (and builds a matching mask) internally, then
# slices the output back to the original sequence length.
model = Autopadder(model)

x = torch.randn(1, 7777, 512, device = device)
y = model(x) # (1, 7777, 512) -- padded internally, trimmed back to the input length

This results in the following error:

---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-e7e425245696> in <module>
     14 
     15 x = torch.randn(1, 7777, 512).cuda()
---> 16 y = model(x) # (1, 8192, 512)

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/autopadder.py in forward(self, x, **kwargs)
     52                 kwargs.update(input_attn_mask=new_mask)
     53 
---> 54         out = self.net(x, **kwargs)
     55         return out[:, 0:t]

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, **kwargs)
    682         x = torch.cat([x, x], dim = -1)
    683         arg_route = (True, self.twin_attention)
--> 684         x = self.layers(x, arg_route = arg_route, **kwargs)
    685         return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
    686 

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, x, arg_route, **kwargs)
    160             return x
    161 
--> 162         return _ReversibleFunction.apply(x, blocks, block_kwargs)

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(ctx, x, blocks, kwargs)
    121         ctx.kwargs = kwargs
    122         for block in blocks:
--> 123             x = block(x, **kwargs)
    124         ctx.y = x.detach()
    125         ctx.blocks = blocks

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, x, f_args, g_args)
     57 
     58         with torch.no_grad():
---> 59             y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
     60             y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
     61 

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
     25 
     26         if not set_rng:
---> 27             return self.net(*args, **kwargs)
     28 
     29         rng_devices = []

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, **kwargs)
    145     def forward(self, x, **kwargs):
    146         x = self.norm(x)
--> 147         return self.fn(x, **kwargs)
    148 
    149 class Chunk(nn.Module):

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720             result = self._slow_forward(*input, **kwargs)
    721         else:
--> 722             result = self.forward(*input, **kwargs)
    723         for hook in itertools.chain(
    724                 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, keys, input_mask, input_attn_mask, context_mask, **kwargs)
    549             m_mask = default_mask.expand(b, m)
    550             c_mask = default(context_mask, default_mask.expand(b, c))
--> 551             mask = torch.cat((i_mask, m_mask, c_mask), dim=1)
    552             mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask))
    553             masks['input_mask'] = mask

RuntimeError: Tensors must have same number of dimensions: got 2 and 3
lucidrains commented 4 years ago

@jamarju It should work for the Reformer class now, as of release 1.1.5: https://github.com/lucidrains/reformer-pytorch/releases/tag/1.1.5. I had previously only tested it with ReformerLM — my bad!