lucidrains / PaLM-rlhf-pytorch

Implementation of RLHF (Reinforcement Learning with Human Feedback) on top of the PaLM architecture. Basically ChatGPT but with PaLM

Cannot train the model using PyTorch version 2? #36

Closed. linhduongtuan closed this issue 1 year ago

linhduongtuan commented 1 year ago

Dear Phil, I love not only the source code but also all your contributions to the open community. While trying to adapt this code to run with PyTorch 2, I struggle with an error. The error message is:

from user code: " File "/path/to/code/PaLM-rlhf/palm_rlhf_pytorch/palm.py", line 254, in sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)

Set torch._dynamo.config.verbose=True for more information

You can suppress this exception and fall back to eager by setting: torch._dynamo.config.suppress_errors = True "

Can you help me fix it? All the best, Linh
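For reference, the line the traceback points at is the causal masking of the attention scores before the softmax. A minimal, self-contained sketch of that pattern (illustrative shapes and names, not the actual palm.py code):

```python
import torch

# Illustrative sketch of the line from the traceback:
# attention scores are masked with a causal mask before the softmax.
seq_len = 4
sim = torch.randn(1, seq_len, seq_len)                                      # attention scores (batch, i, j)
causal_mask = torch.triu(torch.ones(seq_len, seq_len), diagonal=1).bool()   # True above the diagonal
sim = sim.masked_fill(causal_mask, -torch.finfo(sim.dtype).max)             # block attention to future positions
attn = sim.softmax(dim=-1)
```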

nicholasruha commented 1 year ago

I got a similar error because I'm using a GTX 1080, and I got this warning: RuntimeError: Found NVIDIA GeForce GTX 1080 which is too old to be supported by the triton GPU compiler, which is used as the backend. Triton only supports devices of CUDA Capability >= 7.0, but your device is of CUDA capability 6.

When I tried setting torch._dynamo.config.suppress_errors = True I would get an error saying torch does not have a module called "_dynamo". So I turned compile off and the error went away.
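What "turning compile off" can look like in a training script is roughly this (a sketch; the use_compile flag is a hypothetical knob, and the PaLM constructor arguments follow the repo's README example):

```python
import torch
from palm_rlhf_pytorch import PaLM

# Hypothetical flag: leave it off on GPUs Triton can't target (CUDA capability < 7.0).
use_compile = False

model = PaLM(num_tokens=20000, dim=512, depth=12)

if use_compile and hasattr(torch, 'compile'):
    model = torch.compile(model)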

But apparently you can import _dynamo.config like this: import torch._dynamo.config

Now when I set torch._dynamo.config.suppress_errors = True, the error is suppressed, although I am not sure how much it undercuts the benefits of using torch.compile.
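Put together, the suppression workaround looks roughly like this (a sketch with a stand-in module; suppressing errors makes dynamo fall back to eager for the graphs it can't compile, so some of the torch.compile speedup is lost):

```python
import torch
import torch._dynamo.config  # importing the submodule makes torch._dynamo accessible

torch._dynamo.config.verbose = True          # print details about what failed to compile
torch._dynamo.config.suppress_errors = True  # fall back to eager instead of raising

model = torch.nn.Linear(8, 8)                # stand-in module; substitute the PaLM model here
compiled = torch.compile(model)
out = compiled(torch.randn(2, 8))
```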

You can also make the error go away by changing the backend used by torch.compile. The backend parameter defaults to "inductor", so the default call looks like torch.compile(model, backend='inductor').
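For example, a sketch of switching to a backend that doesn't go through Triton, so it still works on pre-Volta cards like the GTX 1080 (backend names as shipped with PyTorch 2.x):

```python
import torch

model = torch.nn.Linear(8, 8)  # stand-in module; substitute the PaLM model here

# 'inductor' (the default) generates GPU kernels with Triton; 'aot_eager' and 'eager'
# skip Triton codegen entirely, trading the compiled-kernel speedup for compatibility.
compiled = torch.compile(model, backend='aot_eager')
out = compiled(torch.randn(2, 8))
```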

If you want to see the other backends available to you, run torch._dynamo.list_backends() to list the non-experimental backends, or torch._dynamo.list_backends(None) to include the experimental ones as well.
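A quick sketch of checking what is registered on a given install (the exact list depends on the PyTorch build):

```python
import torch._dynamo as dynamo

print(dynamo.list_backends())      # non-experimental backends, e.g. ['aot_eager', 'cudagraphs', 'inductor', ...]
print(dynamo.list_backends(None))  # pass None to also include experimental/debug backends
```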

For completeness, I'm using this PyTorch version: 2.1.0.dev20230313+cu117