Closed: jamarju closed this issue 4 years ago.
Running this example code:
# Minimal reproduction: a (non-causal) Reformer wrapped in Autopadder, fed a
# sequence of length 7777, which is shorter than max_seq_len (8192).
# Requires a CUDA device (.cuda() calls below).
# On the reporter's install this raised
#   RuntimeError: Tensors must have same number of dimensions: got 2 and 3
# from the mask concatenation in reformer_pytorch; the maintainer reports the
# Reformer class works as of reformer_pytorch release 1.1.5.
import torch
from reformer_pytorch import Reformer, Autopadder

model = Reformer(
    dim = 512,
    depth = 12,
    max_seq_len = 8192,
    heads = 8,
    lsh_dropout = 0.1,
    causal = False
).cuda()

# Autopadder.forward ends with `return out[:, 0:t]`, slicing the wrapped
# model's output along the sequence axis. Presumably t is the original input
# length, so the result is NOT padded out to max_seq_len -- confirm against
# the installed version.
model = Autopadder(model)

x = torch.randn(1, 7777, 512).cuda()
y = model(x) # (1, 7777, 512) -- trimmed back to the input length, not (1, 8192, 512)
Results in the following error:
---------------------------------------------------------------------------
RuntimeError                              Traceback (most recent call last)
<ipython-input-2-e7e425245696> in <module>
     14
     15 x = torch.randn(1, 7777, 512).cuda()
---> 16 y = model(x) # (1, 8192, 512)

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/autopadder.py in forward(self, x, **kwargs)
     52 kwargs.update(input_attn_mask=new_mask)
     53
---> 54 out = self.net(x, **kwargs)
     55 return out[:, 0:t]

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, **kwargs)
    682 x = torch.cat([x, x], dim = -1)
    683 arg_route = (True, self.twin_attention)
--> 684 x = self.layers(x, arg_route = arg_route, **kwargs)
    685 return torch.stack(x.chunk(2, dim=-1)).mean(dim=0)
    686

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, x, arg_route, **kwargs)
    160 return x
    161
--> 162 return _ReversibleFunction.apply(x, blocks, block_kwargs)

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(ctx, x, blocks, kwargs)
    121 ctx.kwargs = kwargs
    122 for block in blocks:
--> 123 x = block(x, **kwargs)
    124 ctx.y = x.detach()
    125 ctx.blocks = blocks

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, x, f_args, g_args)
     57
     58 with torch.no_grad():
---> 59 y1 = x1 + self.f(x2, record_rng=self.training, **f_args)
     60 y2 = x2 + self.g(y1, record_rng=self.training, **g_args)
     61

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reversible.py in forward(self, record_rng, set_rng, *args, **kwargs)
     25
     26 if not set_rng:
---> 27 return self.net(*args, **kwargs)
     28
     29 rng_devices = []

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, **kwargs)
    145 def forward(self, x, **kwargs):
    146 x = self.norm(x)
--> 147 return self.fn(x, **kwargs)
    148
    149 class Chunk(nn.Module):

~/anaconda3/envs/sped37/lib/python3.7/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    720 result = self._slow_forward(*input, **kwargs)
    721 else:
--> 722 result = self.forward(*input, **kwargs)
    723 for hook in itertools.chain(
    724 _global_forward_hooks.values(),

~/anaconda3/envs/sped37/lib/python3.7/site-packages/reformer_pytorch/reformer_pytorch.py in forward(self, x, keys, input_mask, input_attn_mask, context_mask, **kwargs)
    549 m_mask = default_mask.expand(b, m)
    550 c_mask = default(context_mask, default_mask.expand(b, c))
--> 551 mask = torch.cat((i_mask, m_mask, c_mask), dim=1)
    552 mask = merge_batch_and_heads(expand_dim(1, lsh_h, mask))
    553 masks['input_mask'] = mask

RuntimeError: Tensors must have same number of dimensions: got 2 and 3
@jamarju It should work for the Reformer class now! https://github.com/lucidrains/reformer-pytorch/releases/tag/1.1.5 I had previously only tested it for ReformerLM, my bad!
Running this example code:
Results in the following error: