Closed Warvito closed 3 years ago
@Warvito Hi Walter! Perhaps what we can do is disable the invocation https://github.com/lucidrains/performer-pytorch/blob/main/performer_pytorch/performer_pytorch.py#L429 and you can call it yourself in the training loop?
are you using data parallel or distributed data parallel (ddp)?
@lucidrains I am using DataParallel
Here is some code to reproduce the issue.
import torch
from performer_pytorch import PerformerLM
from tqdm import trange
model = PerformerLM(
    num_tokens=20000,
    max_seq_len=2048,
    dim=512,
    depth=12,
    heads=8,
    causal=False,
    nb_features=256,
    feature_redraw_interval=5,
)
model.train()

device = torch.device("cuda")
model = torch.nn.DataParallel(model).to(device)

x = torch.randint(0, 20000, (1, 2048))
for i in trange(10):
    outputs = model(x)
On a single GPU it works fine, but when running on multiple GPUs I get this error:
StopIteration                             Traceback (most recent call last)
 in 
     21 
     22 for i in trange(10):
---> 23     outputs=model(x)
     24 

/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py in _call_impl(self, *input, **kwargs)
    742             result = self._slow_forward(*input, **kwargs)
    743         else:
--> 744             result = self.forward(*input, **kwargs)
    745         for hook in itertools.chain(
    746                 _global_forward_hooks.values(),

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in forward(self, *inputs, **kwargs)
    165             return self.module(*inputs[0], **kwargs[0])
    166         replicas = self.replicate(self.module, self.device_ids[:len(inputs)])
--> 167         outputs = self.parallel_apply(replicas, inputs, kwargs)
    168         return self.gather(outputs, self.output_device)
    169 

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py in parallel_apply(self, replicas, inputs, kwargs)
    175 
    176     def parallel_apply(self, replicas, inputs, kwargs):
--> 177         return parallel_apply(replicas, inputs, kwargs, self.device_ids[:len(replicas)])
    178 
    179     def gather(self, outputs, output_device):

/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py in parallel_apply(modules, inputs, kwargs_tup, devices)
     84         output = results[i]
     85         if isinstance(output, ExceptionWrapper):
---> 86             output.reraise()
     87         outputs.append(output)
     88     return outputs

/opt/conda/lib/python3.8/site-packages/torch/_utils.py in reraise(self)
    426             # have message field
    427             raise self.exc_type(message=msg)
--> 428         raise self.exc_type(msg)
    429 
    430 

StopIteration: Caught StopIteration in replica 0 on device 0.
Original Traceback (most recent call last):
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py", line 61, in _worker
    output = module(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 467, in forward
    x = self.performer(x, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/torch/nn/modules/module.py", line 744, in _call_impl
    result = self.forward(*input, **kwargs)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 429, in forward
    self.check_redraw_projections()
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 417, in check_redraw_projections
    device = get_module_device(self)
  File "/opt/conda/lib/python3.8/site-packages/performer_pytorch/performer_pytorch.py", line 40, in get_module_device
    return next(module.parameters()).device
StopIteration
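For anyone reading the traceback: the last frame calls `next(module.parameters())`, and `next()` on an empty iterator raises StopIteration. That suggests the DataParallel replica exposes no parameters through `parameters()` at that point. Below is a minimal sketch of a lookup that tolerates an empty iterator. The helper name `get_device_or_default` and the `FakeTensor` / `FakeReplica` stand-ins are hypothetical and only illustrate the idea; they are not part of performer-pytorch or PyTorch, and this is not how the library fixed the issue.

```python
def get_device_or_default(module, default=None):
    """Return the device of the first parameter or buffer, else `default`.

    Hypothetical helper: works with any object exposing `parameters()` and
    `buffers()` iterables whose items have a `.device` attribute.
    """
    for tensors in (module.parameters(), module.buffers()):
        # Passing a default to next() avoids StopIteration on an empty iterator.
        first = next(iter(tensors), None)
        if first is not None:
            return first.device
    return default


class FakeTensor:
    """Stand-in for a tensor; only carries a device label."""

    def __init__(self, device):
        self.device = device


class FakeReplica:
    """Stand-in for a module; a replica is modeled as having no parameters."""

    def __init__(self, params=(), bufs=()):
        self._params = list(params)
        self._bufs = list(bufs)

    def parameters(self):
        return iter(self._params)

    def buffers(self):
        return iter(self._bufs)


# A replica with neither parameters nor buffers falls back to the default.
empty_device = get_device_or_default(FakeReplica(), default="cuda:0")

# A replica with only a buffer reports the buffer's device.
buffer_device = get_device_or_default(FakeReplica(bufs=[FakeTensor("cuda:1")]))
```

Whether buffers are a reliable fallback inside real replicas is an assumption here, so passing an explicit default device is the safer part of the sketch.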
As expected, setting feature_redraw_interval=None avoids the error.
As suggested, I tried performing the redraw in the training loop instead, and it works:
import torch
from performer_pytorch import PerformerLM
from tqdm import trange
from performer_pytorch.performer_pytorch import find_modules, FastAttention
model = PerformerLM(
    num_tokens=20000,
    max_seq_len=2048,
    dim=512,
    depth=12,
    heads=8,
    causal=False,
    nb_features=256,
    feature_redraw_interval=None,
)
model.train()

device = torch.device("cuda")
model = torch.nn.DataParallel(model).to(device)

x = torch.randint(0, 20000, (1, 2048))
for i in trange(5):
    outputs = model(x)

fast_attentions = find_modules(model, FastAttention)
for fast_attention in fast_attentions:
    fast_attention.redraw_projection_matrix(device)
model.module.performer.calls_since_last_redraw.zero_()

for i in trange(5):
    outputs = model(x)
@lucidrains Thanks very much for the suggestion ^^
@Warvito no problem Walter! I added a way to turn off automatic redrawing with auto_check_redraw = False, and then you can just invoke check_redraw_projections() yourself wherever
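To make "invoke it yourself" concrete, here is a rough sketch of the counting logic a training loop could use to trigger a redraw every N steps from outside the wrapped model. The `ManualRedraw` class and its `redraw_fn` callback are hypothetical names for illustration only. In real use the callback would hold whatever redraw code works for your setup (for example the manual redraw shown earlier in this thread), and the forward pass would go where the comment indicates.

```python
class ManualRedraw:
    """Hypothetical helper: call `redraw_fn` once every `interval` steps."""

    def __init__(self, interval, redraw_fn):
        if interval <= 0:
            raise ValueError("interval must be a positive integer")
        self.interval = interval
        self.redraw_fn = redraw_fn
        self.calls_since_last_redraw = 0

    def step(self):
        """Count one training step; redraw and reset when the interval is hit."""
        self.calls_since_last_redraw += 1
        if self.calls_since_last_redraw >= self.interval:
            self.redraw_fn()
            self.calls_since_last_redraw = 0
            return True
        return False


# Record redraws in a list instead of touching a real model.
redraws = []
scheduler = ManualRedraw(interval=5, redraw_fn=lambda: redraws.append("redraw"))

for step in range(12):
    # outputs = model(x) would run here in a real training loop
    scheduler.step()
```

Keeping the counter in the training loop rather than inside the module means it runs once per step on the main process, outside the replicas that DataParallel creates for each forward pass.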
Hi,
I have been trying to use the PerformerLM (from version 0.15.0) with DataParallel, but after a few epochs I am getting this error.