lucidrains / performer-pytorch

An implementation of Performer, a linear attention-based transformer, in Pytorch
MIT License

Getting error with the check_redraw_projections when using DataParallel #61

Closed. Warvito closed this issue 3 years ago.

Warvito commented 3 years ago

Hi,

I have been trying to use PerformerLM (version 0.15.0) with DataParallel, but after a few epochs I get this error:

Traceback (most recent call last):
  File "/project/train_transformer.py", line 168, in <module>
    main(args)
  File "/project/train_transformer.py", line 116, in main
    val_loss = train_transformer(model,
  File "/project/training_functions.py", line 963, in train_transformer
    train_epoch_transformer(model, vqvae, train_loader, optimizer, scheduler, loss_fn, device, epoch, writer_train)
  File "/project/training_functions.py", line 1032, in train_epoch_transformer
    outputs = model(encoded_in)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 167, in forward
    outputs = self.parallel_apply(replicas, inputs, kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py", line 177, in parallel_apply
    return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 86, in parallel_apply
    output.reraise()
  File "/opt/conda/lib/python3.8/site-packages/torch/_utils.py", line 428, in reraise
    raise self.exc_type(msg)
StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/project/models/transformer.py", line 92, in forward
    return self.model(x)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 467, in forward
    x = self.performer(x, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 429, in forward
    self.check_redraw_projections()
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 417, in check_redraw_projections
    device = get_module_device(self)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 40, in get_module_device
    return next(module.parameters()).device
StopIteration
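The last frame shows the failure mode: get_module_device calls next(module.parameters()), and next() on an empty iterator raises StopIteration. A likely explanation (an assumption, not confirmed in this thread) is that in recent PyTorch versions the per-GPU replicas built by DataParallel do not expose their weights through parameters(). The sketch below reproduces the error on CPU with a parameter-free stand-in module and adds a safer lookup that falls back to buffers and then to a default device. NoParams and get_module_device_safe are hypothetical names for illustration, not part of performer-pytorch.

```python
import torch
from torch import nn


class NoParams(nn.Module):
    """Hypothetical stand-in for a module whose parameters() iterator is empty
    (assumed to be the situation on a DataParallel replica)."""

    def forward(self, x):
        return x


def get_module_device(module):
    # Same pattern as the helper in the traceback: fails on an empty iterator.
    return next(module.parameters()).device


def get_module_device_safe(module, default=torch.device("cpu")):
    # Hypothetical safer variant: try parameters, then buffers, then a default.
    for tensor in module.parameters():
        return tensor.device
    for tensor in module.buffers():
        return tensor.device
    return default


try:
    get_module_device(NoParams())
except StopIteration:
    print("StopIteration: module has no registered parameters")

print(get_module_device(nn.Linear(2, 2)))   # → cpu
print(get_module_device_safe(NoParams()))  # → cpu (falls back to the default)
```

The buffer fallback is a design choice under the stated assumption that buffers remain registered on replicas; passing the device of the input tensor would be another option.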

lucidrains commented 3 years ago

@Warvito Hi Walter! Perhaps we could disable the automatic invocation at https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py#L429 so that you can call it yourself in the training loop?

lucidrains commented 3 years ago

Are you using DataParallel or DistributedDataParallel (DDP)?

Warvito commented 3 years ago

@lucidrains I am using DataParallel.

Here is code to reproduce the issue:

import torch
from performer_pytorch import PerformerLM
from tqdm import trange

model = PerformerLM(
    num_tokens=20000,
    max_seq_len=2048,
    dim=512,
    depth=12,
    heads=8,
    causal=False,
    nb_features=256,
    feature_redraw_interval=5,
)

model.train()
device = torch.device("cuda")
model = torch.nn.DataParallel(model).to(device)

x = torch.randint(0, 20000, (1, 2048))

for i in trange(10):
    outputs = model(x)

On a single GPU it works fine, but running it on multiple GPUs I get this error:


StopIteration                             Traceback (most recent call last)
     21 
     22 for i in trange(10):
---> 23     outputs=model(x)
     24 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    742             result = self._slow_forward(*input, **kwargs)
    743         else:
--> 744             result = self.forward(*input, **kwargs)
    745         for hook in itertools.chain(
    746                 _global_forward_hooks.values(),

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    165             return self.module(*inputs[0], **kwargs[0])
    166         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 167         outputs = self.parallel_apply(replicas, inputs, kwargs)
    168         return self.gather(outputs, self.output_device)
    169 

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    175 
    176     def parallel_apply(self, replicas, inputs, kwargs):
--> 177         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    178 
    179     def gather(self, outputs, output_device):

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     84         output = results[i]
     85         if isinstance(output, ExceptionWrapper):
---> 86             output.reraise()
     87         outputs.append(output)
     88     return outputs

/opt/conda/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429 
    430 

StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 467, in forward
    x = self.performer(x, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 429, in forward
    self.check_redraw_projections()
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 417, in check_redraw_projections
    device = get_module_device(self)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 40, in get_module_device
    return next(module.parameters()).device
StopIteration

As expected, setting "feature_redraw_interval=None" avoids the error.
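Why an interval of None sidesteps the error can be illustrated with a toy counter. This is a hypothetical sketch, assuming the redraw path (and with it the device lookup that fails) is only reached once an interval is set and enough calls have accumulated; RedrawSchedule is an illustrative name, not the library's class, and its exact counting may differ from performer-pytorch.

```python
from typing import Optional


class RedrawSchedule:
    """Hypothetical interval-based redraw check (not performer-pytorch's code)."""

    def __init__(self, interval: Optional[int]):
        self.interval = interval
        self.calls_since_last_redraw = 0
        self.redraws = 0

    def check(self) -> bool:
        # With no interval, never enter the redraw path (no device lookup at all).
        if self.interval is None:
            return False
        if self.calls_since_last_redraw >= self.interval:
            # Real code would look up the device and redraw projections here.
            self.redraws += 1
            self.calls_since_last_redraw = 0
            return True
        self.calls_since_last_redraw += 1
        return False


every_5 = RedrawSchedule(interval=5)
never = RedrawSchedule(interval=None)
for _ in range(12):
    every_5.check()
    never.check()
print(every_5.redraws, never.redraws)  # → 2 0
```

With interval=5 the sketch redraws on the 6th and 12th calls, which matches the repro above failing within 10 forward passes; with None the redraw branch, and so the failing lookup, is never reached.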

As suggested, I tried performing the redraw in the training loop, and it works:

import torch
from performer_pytorch import PerformerLM
from tqdm import trange
from performer_pytorch.performer_pytorch import find_modules, FastAttention

model = PerformerLM(
    num_tokens=20000,
    max_seq_len=2048,
    dim=512,
    depth=12,
    heads=8,
    causal=False,
    nb_features=256,
    feature_redraw_interval=None,  # disable the automatic redraw check
)

model.train()
device = torch.device("cuda")
model = torch.nn.DataParallel(model).to(device)

x = torch.randint(0, 20000, (1, 2048))

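# A few steps with the initial projection matrices
for i in trange(5):
    outputs = model(x)

# Redraw the random projections manually, outside of forward(), so this runs
# on the original module; find_modules walks the modules inside the wrapper
fast_attentions = find_modules(model, FastAttention)
for fast_attention in fast_attentions:
    fast_attention.redraw_projection_matrix(device)

# Reset the redraw counter on the unwrapped Performer
model.module.performer.calls_since_last_redraw.zero_()

for i in trange(5):
    outputs = model(x)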

@lucidrains Thanks very much for the suggestion ^^
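The workaround above redraws once, between two loops. A minimal sketch of doing it periodically from the training loop follows, assuming a stand-in module that has the same redraw_projection_matrix(device) method; ToyFastAttention, redraw_all and maybe_redraw are hypothetical names. Because the call happens outside forward(), it acts on the original module rather than on a DataParallel replica, and DataParallel copies the updated buffers to the replicas on the next forward pass.

```python
import torch
from torch import nn


class ToyFastAttention(nn.Module):
    """Hypothetical stand-in for FastAttention: holds a random projection
    buffer and exposes redraw_projection_matrix(device)."""

    def __init__(self, dim=4, nb_features=8):
        super().__init__()
        self.dim, self.nb_features = dim, nb_features
        self.register_buffer("projection_matrix", torch.randn(nb_features, dim))

    def redraw_projection_matrix(self, device):
        # Draw fresh random features and copy them into the buffer in place.
        new = torch.randn(self.nb_features, self.dim, device=device)
        self.projection_matrix.copy_(new)


def redraw_all(model, attn_type, device):
    # Find every attention module by type (same idea as find_modules) and redraw it.
    modules = [m for m in model.modules() if isinstance(m, attn_type)]
    for m in modules:
        m.redraw_projection_matrix(device)
    return len(modules)


def maybe_redraw(model, step, interval, attn_type, device):
    # Redraw every `interval` steps; None disables redrawing entirely.
    if interval is not None and step > 0 and step % interval == 0:
        return redraw_all(model, attn_type, device)
    return 0


model = nn.Sequential(ToyFastAttention(), nn.Linear(4, 4), ToyFastAttention())
cpu = torch.device("cpu")
counts = [maybe_redraw(model, step, 5, ToyFastAttention, cpu) for step in range(11)]
print(counts)  # → [0, 0, 0, 0, 0, 2, 0, 0, 0, 0, 2]
```

In the thread's setup, attn_type would be FastAttention from performer_pytorch.performer_pytorch, model could be the DataParallel wrapper itself (modules() also visits the wrapped model), and step would be the training step count.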

lucidrains commented 3 years ago

@Warvito No problem, Walter! I added a way to turn off automatic redrawing with auto_check_redraw = False; you can then invoke check_redraw_projections() yourself wherever you like.