huggingface / transformers

🤗 Transformers: State-of-the-art Machine Learning for Pytorch, TensorFlow, and JAX.
https://huggingface.co/transformers
Apache License 2.0

When training RWKV, it reports a "backward error" #31413

Open lxianl455 opened 2 weeks ago

lxianl455 commented 2 weeks ago

System Info

Who can help?

Using /root/.cache/torch_extensions/py38_cu118 as PyTorch extensions root...
Loading extension module wkv_20...
/usr/local/python/lib/python3.8/site-packages/torch/autograd/__init__.py:251: UserWarning: Error detected in RwkvLinearAttentionBackward. Traceback of forward call that caused the error:
  File "/data1/rl_server/rl_learner/code//train.py", line 18, in <module>
    main()
  File "/data1/rl_server/rl_learner/code//train.py", line 14, in main
    trainer.run()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/__init__.py", line 14, in run
    self.bench.run()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 236, in run
    self._do_train()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 210, in _do_train
    self.do_train_step(step_context, _input_datas)
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 172, in do_train_step
    outputs = self.net_wrapper(_input_datas, self.local_step)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1519, in forward
    else self._run_ddp_forward(*inputs, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/parallel/distributed.py", line 1355, in _run_ddp_forward
    return self.module(*inputs, **kwargs)  # type: ignore[index]
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 338, in forward
    each_hero_fc_result_list, all_lstm_state = self._inference(each_hero_data_list, lstm_initial_state, pos_lstm_initial_state)
  File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 391, in _inference
    lstm_outputs, lstm_state = self.public_lstm(reshape_new_fc_public_results, lstm_initial_state)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dockerdata/rl_server/rl_learner/code/algorithm.py", line 261, in forward
    lstm_outputs, rwkv_state = self.lstm(reshape_new_fc_public_results, )  # state not passed for now
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 573, in forward
    hidden_states, state = block(
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 402, in forward
    attention, state = self.attention(self.ln1(hidden), state=state, use_cache=use_cache)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/usr/local/python/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 323, in forward
    rwkv, layer_state = rwkv_linear_attention(
  File "/dockerdata/rl_server/rl_learner/code/components/custom_lstm_torch.py", line 260, in rwkv_linear_attention
    return RwkvLinearAttention.apply(time_decay, time_first, key, value, state, return_state)
  File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
 (Triggered internally at ../torch/csrc/autograd/python_anomaly_mode.cpp:114.)
  Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
Traceback (most recent call last):
  File "/data1/rl_server/rl_learner/code//train.py", line 18, in <module>
    main()
  File "/data1/rl_server/rl_learner/code//train.py", line 14, in main
    trainer.run()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/__init__.py", line 14, in run
    self.bench.run()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 236, in run
    self._do_train()
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 210, in _do_train
    self.do_train_step(step_context, _input_datas)
  File "/usr/local/python/lib/python3.8/site-packages/sail/learner/framework/apd_benchmark.py", line 174, in do_train_step
    total_loss.backward()
  File "/usr/local/python/lib/python3.8/site-packages/torch/_tensor.py", line 492, in backward
    torch.autograd.backward(
  File "/usr/local/python/lib/python3.8/site-packages/torch/autograd/__init__.py", line 251, in backward
    Variable._execution_engine.run_backward(  # Calls into the C++ engine to run the backward pass
RuntimeError: CUDA error: an illegal memory access was encountered
Compile with TORCH_USE_CUDA_DSA to enable device-side assertions.

Information

Tasks

Reproduction

import torch
from transformers import AutoTokenizer, RwkvForCausalLM
import os

os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
torch.autograd.set_detect_anomaly(True)

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Hello, my dog is cute", "i like this"], return_tensors="pt", padding=True)
outputs = model(**inputs, labels=inputs["input_ids"])
loss = outputs.loss
logits = outputs.logits
loss.backward()

Expected behavior

The backward pass should succeed. Please figure out why this happens and fix it.

amyeroberts commented 2 weeks ago

cc @ArthurZucker

RUFFY-369 commented 2 weeks ago

Hi @lxianl455, please put your model in train mode with model.train() before performing the backward pass:

model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
.
.
.

Cheers!

lxianl455 commented 2 weeks ago

Hi @lxianl455, please put your model in train mode with model.train() before performing the backward pass:

model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=True)
tokenizer.pad_token = tokenizer.eos_token
model.train()
.
.
.

Cheers!

Actually, I want to combine the RWKV block with some other modules to predict time-series information. I am not using the whole RWKV model, only its blocks. In this scenario, the Transformers Trainer cannot be used. How can I solve this backward error?

RUFFY-369 commented 2 weeks ago

@lxianl455 Okay, so if you want to use the RWKV model just for inference or in the default eval() mode and don't want to put it in train mode, then modify your code as follows and the error will go away:

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
.
.
.

Give use_cache a False value: during caching, the model stores intermediate results to speed up computation, which can interfere with gradient computation, and use_cache is not needed for training or gradient computation anyway.
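
For reference, here is a minimal sketch of the original reproduction with that change applied (same checkpoint and tokenizer as above, not tested on your exact setup):

import torch
from transformers import AutoTokenizer, RwkvForCausalLM

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
# use_cache=False: do not ask the forward pass to return the recurrent state
model = RwkvForCausalLM.from_pretrained("RWKV/rwkv-4-169m-pile", use_cache=False)
tokenizer.pad_token = tokenizer.eos_token

inputs = tokenizer(["Hello, my dog is cute", "i like this"], return_tensors="pt", padding=True)
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # should no longer hit the illegal memory access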

Cheers!

lxianl455 commented 2 weeks ago

Actually, the behavior I want is for it to work like an LSTM. During training, an LSTM can take in an initial state and can also return the final state afterwards. https://pytorch.org/docs/stable/generated/torch.nn.LSTM.html
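
For reference, a minimal sketch of that interface with torch.nn.LSTM (the shapes are arbitrary, purely for illustration):

import torch
import torch.nn as nn

lstm = nn.LSTM(input_size=32, hidden_size=64, batch_first=True)
x = torch.randn(4, 10, 32)            # (batch, seq_len, input_size)
h0 = torch.zeros(1, 4, 64)            # initial hidden state
c0 = torch.zeros(1, 4, 64)            # initial cell state
output, (hn, cn) = lstm(x, (h0, c0))  # initial state goes in, final state comes out
output.sum().backward()               # and this works during training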

I think use_cache in the RWKV code is not like use_cache in other models. What it actually does is ask for the recurrent state to be returned, much like an LSTM returning its hidden state and cell state. Here is the code (line 300 in modeling_rwkv.py):

rwkv, layer_state = rwkv_linear_attention(
    self.time_decay,
    self.time_first,
    key,
    value,
    state=layer_state,
    return_state=use_cache,
)
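
From the outside, the state round-trip I mean looks roughly like this; a sketch that assumes the RwkvModel output exposes the returned state as a state attribute (as in the current modeling_rwkv.py), and that stays under torch.no_grad() because doing the same with gradients enabled is exactly what hits this backward error:

import torch
from transformers import AutoTokenizer, RwkvModel

tokenizer = AutoTokenizer.from_pretrained("RWKV/rwkv-4-169m-pile")
model = RwkvModel.from_pretrained("RWKV/rwkv-4-169m-pile")

chunk1 = tokenizer("Hello, my dog", return_tensors="pt")
chunk2 = tokenizer(" is cute", return_tensors="pt")

with torch.no_grad():
    out1 = model(**chunk1, use_cache=True)                    # use_cache=True returns the recurrent state
    out2 = model(**chunk2, state=out1.state, use_cache=True)  # feed it back in, like an LSTM hidden/cell state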

ArthurZucker commented 1 week ago

The RecurrentGemma model implements something more along the lines of an RNN (so closer to an LSTM), if you are looking for an equivalent.
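
For example, something along these lines (a sketch assuming the RecurrentGemmaForCausalLM class from a recent transformers release and the google/recurrentgemma-2b checkpoint; substitute whichever checkpoint you have access to):

from transformers import AutoTokenizer, RecurrentGemmaForCausalLM

tokenizer = AutoTokenizer.from_pretrained("google/recurrentgemma-2b")
model = RecurrentGemmaForCausalLM.from_pretrained("google/recurrentgemma-2b")
model.train()

inputs = tokenizer("Hello, my dog is cute", return_tensors="pt")
outputs = model(**inputs, labels=inputs["input_ids"])
outputs.loss.backward()  # recurrent (RNN-style) blocks train through the standard backward pass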