THUDM / ChatGLM3

ChatGLM3 series: Open Bilingual Chat LLMs

[lora finetune] RuntimeError: CUDA error: device-side assert triggered #1219

Closed: Janet-Baker closed this issue 2 months ago

Janet-Baker commented 2 months ago

System Info

CUDA:

(base) root@instance:~# nvidia-smi
Sat May 18 10:04:24 2024
+---------------------------------------------------------------------------------------+
| NVIDIA-SMI 535.146.02             Driver Version: 535.146.02   CUDA Version: 12.2     |
|-----------------------------------------+----------------------+----------------------+
| GPU  Name                 Persistence-M | Bus-Id        Disp.A | Volatile Uncorr. ECC |
| Fan  Temp   Perf          Pwr:Usage/Cap |         Memory-Usage | GPU-Util  Compute M. |
|                                         |                      |               MIG M. |
|=========================================+======================+======================|
|   0  NVIDIA A100-PCIE-40GB          On  | 00000000:00:08.0 Off |                    0 |
| N/A   34C    P0              40W / 250W |      0MiB / 40960MiB |      0%      Default |
|                                         |                      |             Disabled |
+-----------------------------------------+----------------------+----------------------+

transformers:

(base) root@instance:~# pip show transformers
Name: transformers
Version: 4.38.1
Summary: State-of-the-art Machine Learning for JAX, PyTorch and TensorFlow
Home-page: https://github.com/huggingface/transformers
Author: The Hugging Face team (past and future) with the help of all our contributors (https://github.com/huggingface/transformers/graphs/contributors)
Author-email: transformers@huggingface.co
License: Apache 2.0 License
Location: /root/miniconda3/lib/python3.10/site-packages
Requires: filelock, huggingface-hub, numpy, packaging, pyyaml, regex, requests, safetensors, tokenizers, tqdm
Required-by: peft, sentence-transformers, transformers-stream-generator

python:

(base) root@instance:~# python --version
Python 3.10.12

OS: Ubuntu

(base) root@instance:~# lsb_release -a
No LSB modules are available.
Distributor ID: Ubuntu
Description:    Ubuntu 22.04.2 LTS
Release:        22.04
Codename:       jammy
(base) root@instance:~# uname -a
Linux instance 6.1.56-1.2.3 #1 SMP PREEMPT_DYNAMIC Tue Oct 10 19:39:42 CST 2023 x86_64 x86_64 x86_64 GNU/Linux

Hardware information:

(base) root@instance:~# free -m
               total        used        free      shared  buff/cache   available
Mem:           65536         159       65368           0           8       65376
Swap:              0           0           0
(base) root@instance:~# lscpu
Architecture:           x86_64
  CPU op-mode(s):       32-bit, 64-bit
  Address sizes:        46 bits physical, 48 bits virtual
  Byte Order:           Little Endian
CPU(s):                 10
  On-line CPU(s) list:  0-9
Vendor ID:              GenuineIntel
  Model name:           Intel Xeon Processor (Cascadelake)
    CPU family:         6
    Model:              85
    Thread(s) per core: 2
    Core(s) per socket: 5
    Socket(s):          1
    Stepping:           6
    BogoMIPS:           5985.76
    Flags:              fpu vme de pse tsc msr pae mce cx8 apic sep mtrr pge mca cmov pat pse36 clflush mmx fxsr sse sse2 ht syscall nx pdpe1gb rdtscp lm constant_tsc rep_good nopl xtopolog
                        y cpuid pni pclmulqdq ssse3 fma cx16 pcid sse4_1 sse4_2 x2apic movbe popcnt tsc_deadline_timer aes xsave avx f16c rdrand hypervisor lahf_lm abm 3dnowprefetch cpuid_f
                        ault invpcid_single ssbd ibrs ibpb fsgsbase bmi1 hle avx2 smep bmi2 erms invpcid rtm avx512f avx512dq rdseed adx smap clflushopt clwb avx512cd avx512bw avx512vl xsav
                        eopt xsavec xgetbv1 arat pku ospke avx512_vnni
Caches (sum of all):
  L1d:                  320 KiB (10 instances)
  L1i:                  320 KiB (10 instances)
  L2:                   20 MiB (5 instances)
  L3:                   16 MiB (1 instance)
NUMA:
  NUMA node(s):         2
  NUMA node0 CPU(s):    0-39
  NUMA node1 CPU(s):    40-79
Vulnerabilities:
  Gather data sampling: Unknown: Dependent on hypervisor status
  Itlb multihit:        KVM: Mitigation: VMX unsupported
  L1tf:                 Mitigation; PTE Inversion
  Mds:                  Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
  Meltdown:             Vulnerable
  Mmio stale data:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
  Retbleed:             Mitigation; IBRS
  Spec rstack overflow: Not affected
  Spec store bypass:    Mitigation; Speculative Store Bypass disabled via prctl
  Spectre v1:           Mitigation; usercopy/swapgs barriers and __user pointer sanitization
  Spectre v2:           Mitigation; IBRS, IBPB conditional, STIBP disabled, RSB filling, PBRSB-eIBRS Not affected
  Srbds:                Not affected
  Tsx async abort:      Vulnerable: Clear CPU buffers attempted, no microcode; SMT Host state unknown
(base) root@instance:~# df -h
Filesystem      Size  Used Avail Use% Mounted on
overlay          20G  1.9G   19G  10% /
tmpfs            64M     0   64M   0% /dev
tmpfs           315G     0  315G   0% /sys/fs/cgroup
/dev/sdc        9.0T  1.1T  8.0T  13% /.ot_readonly
/dev/sda1        50G  8.2G   39G  18% /etc/hosts
shm              32G     0   32G   0% /dev/shm
tmpfs           315G   12K  315G   1% /proc/driver/nvidia
tmpfs           315G  134M  315G   1% /run/nvidia-persistenced/socket
tmpfs           315G     0  315G   0% /proc/acpi
tmpfs           315G     0  315G   0% /proc/scsi
tmpfs           315G     0  315G   0% /sys/firmware

Who can help?

finetune_demo: @Btlmd

Information

Reproduction

Data file: https://pan.baidu.com/s/10PlfKWF1f5tP_gKSidE28w?pwd=niq3 (extraction code: niq3). The error is raised after roughly 3050 training steps.

Only the contents of /root/onethingai-tmp/ChatGLM3/finetune_demo/configs/lora.yaml were modified; all other scripts are unchanged:

data_config:
  train_file: result.json
  val_file: val_50.json
  test_file: 2023sfks.json
  num_proc: 6
max_input_length: 8192
max_output_length: 8192
training_args:
  # see `transformers.Seq2SeqTrainingArguments`
  output_dir: /root/onethingai-tmp/output/
  max_steps: 12000
  # needed to be fit for the dataset
  learning_rate: 1e-5
  # settings for data loading
  per_device_train_batch_size: 4
  dataloader_num_workers: 8
  remove_unused_columns: false
  # settings for saving checkpoints
  save_strategy: steps
  save_steps: 11999
  # settings for logging
  log_level: info
  logging_strategy: steps
  logging_steps: 500
  # settings for evaluation
  per_device_eval_batch_size: 1
  evaluation_strategy: steps
  eval_steps: 100001
  # settings for optimizer
  # adam_epsilon: 1e-6
  # uncomment the following line to detect nan or inf values
  # debug: underflow_overflow
  predict_with_generate: true
  # see `transformers.GenerationConfig`
  generation_config:
    max_new_tokens: 4096
  # set your absolute deepspeed path here
  #deepspeed: ds_zero_2.json
  # set to true if train with cpu.
  use_cpu: false
peft_config:
  peft_type: LORA
  task_type: CAUSAL_LM
  r: 8
  lora_alpha: 32
  lora_dropout: 0.1
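
For reference, the `training_args` block above maps roughly onto `transformers.Seq2SeqTrainingArguments` (as the comment in the file notes). A purely illustrative sketch of that mapping, since finetune_hf.py does its own YAML parsing:

# Rough Python equivalent of the training_args block above (illustrative only;
# finetune_hf.py constructs these objects itself from the YAML).
from transformers import GenerationConfig, Seq2SeqTrainingArguments

training_args = Seq2SeqTrainingArguments(
    output_dir="/root/onethingai-tmp/output/",
    max_steps=12000,
    learning_rate=1e-5,
    per_device_train_batch_size=4,
    dataloader_num_workers=8,
    remove_unused_columns=False,
    save_strategy="steps",
    save_steps=11999,
    log_level="info",
    logging_strategy="steps",
    logging_steps=500,
    per_device_eval_batch_size=1,
    evaluation_strategy="steps",
    eval_steps=100001,
    predict_with_generate=True,
    generation_config=GenerationConfig(max_new_tokens=4096),
    use_cpu=False,
)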

Console command:

cd /root/onethingai-tmp/ChatGLM3/finetune_demo
python finetune_hf.py /root/onethingai-tmp /root/onethingai-tmp/ZhipuAI/chatglm3-6b configs/lora.yaml

Console output (truncated because of its length):

../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [1204,0,0], thread: [125,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [1204,0,0], thread: [126,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
../aten/src/ATen/native/cuda/IndexKernel.cu:92: operator(): block: [1204,0,0], thread: [127,0,0] Assertion `index >= -sizes[i] && index < sizes[i] && "index out of bounds"` failed.
╭─────────────────────────────── Traceback (most recent call last) ────────────────────────────────╮
│ /root/onethingai-tmp/ChatGLM3/finetune_demo/finetune_hf.py:532 in main                           │
│                                                                                                  │
│   529 │   )                                                                                      │
│   530 │                                                                                          │
│   531 │   if auto_resume_from_checkpoint.upper() == "" or auto_resume_from_checkpoint is None:   │
│ ❱ 532 │   │   trainer.train()                                                                    │
│   533 │   else:                                                                                  │
│   534 │   │   output_dir = ft_config.training_args.output_dir                                    │
│   535 │   │   dirlist = os.listdir(output_dir)                                                   │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:1859 in train              │
│                                                                                                  │
│   1856 │   │   │   finally:                                                                      │
│   1857 │   │   │   │   hf_hub_utils.enable_progress_bars()                                       │
│   1858 │   │   else:                                                                             │
│ ❱ 1859 │   │   │   return inner_training_loop(                                                   │
│   1860 │   │   │   │   args=args,                                                                │
│   1861 │   │   │   │   resume_from_checkpoint=resume_from_checkpoint,                            │
│   1862 │   │   │   │   trial=trial,                                                              │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:2203 in                    │
│ _inner_training_loop                                                                             │
│                                                                                                  │
│   2200 │   │   │   │   │   self.control = self.callback_handler.on_step_begin(args, self.state,  │
│   2201 │   │   │   │                                                                             │
│   2202 │   │   │   │   with self.accelerator.accumulate(model):                                  │
│ ❱ 2203 │   │   │   │   │   tr_loss_step = self.training_step(model, inputs)                      │
│   2204 │   │   │   │                                                                             │
│   2205 │   │   │   │   if (                                                                      │
│   2206 │   │   │   │   │   args.logging_nan_inf_filter                                           │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3138 in training_step      │
│                                                                                                  │
│   3135 │   │   │   return loss_mb.reduce_mean().detach().to(self.args.device)                    │
│   3136 │   │                                                                                     │
│   3137 │   │   with self.compute_loss_context_manager():                                         │
│ ❱ 3138 │   │   │   loss = self.compute_loss(model, inputs)                                       │
│   3139 │   │                                                                                     │
│   3140 │   │   if self.args.n_gpu > 1:                                                           │
│   3141 │   │   │   loss = loss.mean()  # mean() to average on multi-gpu parallel training        │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/transformers/trainer.py:3161 in compute_loss       │
│                                                                                                  │
│   3158 │   │   │   labels = inputs.pop("labels")                                                 │
│   3159 │   │   else:                                                                             │
│   3160 │   │   │   labels = None                                                                 │
│ ❱ 3161 │   │   outputs = model(**inputs)                                                         │
│   3162 │   │   # Save past state if it exists                                                    │
│   3163 │   │   # TODO: this needs to be fixed and made cleaner later.                            │
│   3164 │   │   if self.args.past_index >= 0:                                                     │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/peft/peft_model.py:1129 in forward                 │
│                                                                                                  │
│   1126 │   │   │                                                                                 │
│   1127 │   │   │   with self._enable_peft_forward_hooks(**kwargs):                               │
│   1128 │   │   │   │   kwargs = {k: v for k, v in kwargs.items() if k not in self.special_peft_  │
│ ❱ 1129 │   │   │   │   return self.base_model(                                                   │
│   1130 │   │   │   │   │   input_ids=input_ids,                                                  │
│   1131 │   │   │   │   │   attention_mask=attention_mask,                                        │
│   1132 │   │   │   │   │   inputs_embeds=inputs_embeds,                                          │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/peft/tuners/tuners_utils.py:161 in forward         │
│                                                                                                  │
│   158 │   │   return self.active_adapter                                                         │
│   159 │                                                                                          │
│   160 │   def forward(self, *args: Any, **kwargs: Any):                                          │
│ ❱ 161 │   │   return self.model.forward(*args, **kwargs)                                         │
│   162 │                                                                                          │
│   163 │   @abstractmethod                                                                        │
│   164 │   def _prepare_adapter_config(self, peft_config: PeftConfig, model_config: dict) -> Pe   │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:937 in     │
│ forward                                                                                          │
│                                                                                                  │
│    934 │   │   use_cache = use_cache if use_cache is not None else self.config.use_cache         │
│    935 │   │   return_dict = return_dict if return_dict is not None else self.config.use_return  │
│    936 │   │                                                                                     │
│ ❱  937 │   │   transformer_outputs = self.transformer(                                           │
│    938 │   │   │   input_ids=input_ids,                                                          │
│    939 │   │   │   position_ids=position_ids,                                                    │
│    940 │   │   │   attention_mask=attention_mask,                                                │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:830 in     │
│ forward                                                                                          │
│                                                                                                  │
│    827 │   │   rotary_pos_emb = rotary_pos_emb.transpose(0, 1).contiguous()                      │
│    828 │   │                                                                                     │
│    829 │   │   # Run encoder.                                                                    │
│ ❱  830 │   │   hidden_states, presents, all_hidden_states, all_self_attentions = self.encoder(   │
│    831 │   │   │   inputs_embeds, full_attention_mask, rotary_pos_emb=rotary_pos_emb,            │
│    832 │   │   │   kv_caches=past_key_values, use_cache=use_cache, output_hidden_states=output_  │
│    833 │   │   )                                                                                 │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:631 in     │
│ forward                                                                                          │
│                                                                                                  │
│    628 │   │   │                                                                                 │
│    629 │   │   │   layer = self._get_layer(index)                                                │
│    630 │   │   │   if self.gradient_checkpointing and self.training:                             │
│ ❱  631 │   │   │   │   layer_ret = torch.utils.checkpoint.checkpoint(                            │
│    632 │   │   │   │   │   layer,                                                                │
│    633 │   │   │   │   │   hidden_states,                                                        │
│    634 │   │   │   │   │   attention_mask,                                                       │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:249 in checkpoint        │
│                                                                                                  │
│   246 │   │   raise ValueError("Unexpected keyword arguments: " + ",".join(arg for arg in kwar   │
│   247 │                                                                                          │
│   248 │   if use_reentrant:                                                                      │
│ ❱ 249 │   │   return CheckpointFunction.apply(function, preserve, *args)                         │
│   250 │   else:                                                                                  │
│   251 │   │   return _checkpoint_without_reentrant(                                              │
│   252 │   │   │   function,                                                                      │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/autograd/function.py:506 in apply            │
│                                                                                                  │
│   503 │   │   if not torch._C._are_functorch_transforms_active():                                │
│   504 │   │   │   # See NOTE: [functorch vjp and autograd interaction]                           │
│   505 │   │   │   args = _functorch.utils.unwrap_dead_wrappers(args)                             │
│ ❱ 506 │   │   │   return super().apply(*args, **kwargs)  # type: ignore[misc]                    │
│   507 │   │                                                                                      │
│   508 │   │   if cls.setup_context == _SingleLevelFunction.setup_context:                        │
│   509 │   │   │   raise RuntimeError(                                                            │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/utils/checkpoint.py:107 in forward           │
│                                                                                                  │
│   104 │   │   ctx.save_for_backward(*tensor_inputs)                                              │
│   105 │   │                                                                                      │
│   106 │   │   with torch.no_grad():                                                              │
│ ❱ 107 │   │   │   outputs = run_function(*args)                                                  │
│   108 │   │   return outputs                                                                     │
│   109 │                                                                                          │
│   110 │   @staticmethod                                                                          │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:544 in     │
│ forward                                                                                          │
│                                                                                                  │
│    541 │   │   # Layer norm at the beginning of the transformer layer.                           │
│    542 │   │   layernorm_output = self.input_layernorm(hidden_states)                            │
│    543 │   │   # Self attention.                                                                 │
│ ❱  544 │   │   attention_output, kv_cache = self.self_attention(                                 │
│    545 │   │   │   layernorm_output,                                                             │
│    546 │   │   │   attention_mask,                                                               │
│    547 │   │   │   rotary_pos_emb,                                                               │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:441 in     │
│ forward                                                                                          │
│                                                                                                  │
│    438 │   │   # core attention computation                                                      │
│    439 │   │   # ==================================                                              │
│    440 │   │                                                                                     │
│ ❱  441 │   │   context_layer = self.core_attention(query_layer, key_layer, value_layer, attenti  │
│    442 │   │                                                                                     │
│    443 │   │   # =================                                                               │
│    444 │   │   # Output. [sq, b, h]                                                              │
│                                                                                                  │
│ /root/miniconda3/lib/python3.10/site-packages/torch/nn/modules/module.py:1501 in _call_impl      │
│                                                                                                  │
│   1498 │   │   if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks   │
│   1499 │   │   │   │   or _global_backward_pre_hooks or _global_backward_hooks                   │
│   1500 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   │
│ ❱ 1501 │   │   │   return forward_call(*args, **kwargs)                                          │
│   1502 │   │   # Do not call functions when jit is used                                          │
│   1503 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             │
│   1504 │   │   backward_pre_hooks = []                                                           │
│                                                                                                  │
│ /root/.cache/huggingface/modules/transformers_modules/chatglm3-6b/modeling_chatglm.py:231 in     │
│ forward                                                                                          │
│                                                                                                  │
│    228 │   │   │   else:                                                                         │
│    229 │   │   │   │   if attention_mask is not None:                                            │
│    230 │   │   │   │   │   attention_mask = ~attention_mask                                      │
│ ❱  231 │   │   │   │   context_layer = torch.nn.functional.scaled_dot_product_attention(query_l  │
│    232 │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │   │    attenti  │
│    233 │   │   │   context_layer = context_layer.permute(2, 0, 1, 3)                             │
│    234 │   │   │   new_context_layer_shape = context_layer.size()[:-2] + (self.hidden_size_per_  │
╰──────────────────────────────────────────────────────────────────────────────────────────────────╯
RuntimeError: CUDA error: device-side assert triggered
CUDA kernel errors might be asynchronously reported at some other API call, so the stacktrace below might be incorrect.
For debugging consider passing CUDA_LAUNCH_BLOCKING=1.
Compile with `TORCH_USE_CUDA_DSA` to enable device-side assertions.
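
Following the hint in the message above, `CUDA_LAUNCH_BLOCKING=1` can be set before the first CUDA call (for example at the very top of finetune_hf.py, before torch touches the GPU, or on the command line) so that kernel launches become synchronous and the traceback points at the op that actually fails. A minimal debugging sketch:

import os

# Must run before any CUDA work: makes kernel launches synchronous so the
# device-side assert surfaces at its real call site. Debugging only; it is slow.
os.environ.setdefault("CUDA_LAUNCH_BLOCKING", "1")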

Expected behavior

Training should be able to finish without reducing the number of steps.

When an over-long training sample is encountered (if that is indeed the cause), it should be truncated or skipped automatically, with a clear message that the sample exceeds the length limit.
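
Until such a check is built in, a rough pre-scan of the data can at least flag the offending samples. A minimal sketch, assuming the multi-turn `conversations` format used by finetune_demo and one JSON object per line; the file name, model path, and length limit come from the config and command above, and the concatenation below is only an approximation of the real chat template:

import json
from transformers import AutoTokenizer

# Hypothetical pre-scan: flags samples whose rough token count exceeds the
# configured limit. Adjust the parsing if result.json is not line-delimited.
MODEL_PATH = "/root/onethingai-tmp/ZhipuAI/chatglm3-6b"
DATA_FILE = "result.json"   # train_file from lora.yaml
MAX_LEN = 8192              # max_input_length / max_output_length from lora.yaml

tokenizer = AutoTokenizer.from_pretrained(MODEL_PATH, trust_remote_code=True)

with open(DATA_FILE, encoding="utf-8") as f:
    for lineno, raw in enumerate(f, 1):
        raw = raw.strip()
        if not raw:
            continue
        sample = json.loads(raw)
        # Concatenating the turns ignores the prompt template, so treat the
        # count as an estimate rather than the exact training length.
        text = "".join(turn.get("content", "") for turn in sample.get("conversations", []))
        n_tokens = len(tokenizer(text, add_special_tokens=False)["input_ids"])
        if n_tokens > MAX_LEN:
            print(f"line {lineno}: ~{n_tokens} tokens (> {MAX_LEN}); truncate or drop this sample")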