MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome
Apache License 2.0
227 stars 51 forks source link

flash_attn_triton.py #11

Open up4472 opened 1 year ago

up4472 commented 1 year ago

When calling `hidden_states = model(inputs)[0]  # [1, sequence_length, 768]`

we receive the following traceback:

in <cell line: 1>:1 │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │ │ │ │ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1501 │ │ │ return forward_call(*args, kwargs) │ │ 1502 │ │ # Do not call functions when jit is used │ │ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │ │ 1504 │ │ backward_pre_hooks = [] │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/bert_layers.py:608 in forward │ │ │ │ 605 │ │ │ first_col_mask[:, 0] = True │ │ 606 │ │ │ subset_mask = masked_tokens_mask | first_col_mask │ │ 607 │ │ │ │ ❱ 608 │ │ encoder_outputs = self.encoder( │ │ 609 │ │ │ embedding_output, │ │ 610 │ │ │ attention_mask, │ │ 611 │ │ │ output_all_encoded_layers=output_all_encoded_layers, │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │ │ │ │ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1501 │ │ │ return forward_call(*args, *kwargs) │ │ 1502 │ │ # Do not call functions when jit is used │ │ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │ │ 1504 │ │ backward_pre_hooks = [] │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/bert_layers.py:446 in forward │ │ │ │ 443 │ │ all_encoder_layers = [] │ │ 444 │ │ if subset_mask is None: │ │ 445 │ │ │ for layer_module in self.layer: │ │ ❱ 446 │ │ │ │ hidden_states = layer_module(hidden_states, │ │ 447 │ │ │ │ │ │ │ │ │ │ │ cu_seqlens, │ │ 448 │ │ │ │ │ │ │ │ │ │ │ 
seqlen, │ │ 449 │ │ │ │ │ │ │ │ │ │ │ None, │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │ │ │ │ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1501 │ │ │ return forward_call(args, kwargs) │ │ 1502 │ │ # Do not call functions when jit is used │ │ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │ │ 1504 │ │ backward_pre_hooks = [] │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/bert_layers.py:327 in forward │ │ │ │ 324 │ │ │ attn_mask: None or (batch, max_seqlen_in_batch) │ │ 325 │ │ │ bias: None or (batch, heads, max_seqlen_in_batch, max_seqlen_in_batch) │ │ 326 │ │ """ │ │ ❱ 327 │ │ attention_output = self.attention(hidden_states, cu_seqlens, seqlen, │ │ 328 │ │ │ │ │ │ │ │ │ │ subset_idx, indices, attn_mask, bias) │ │ 329 │ │ layer_output = self.mlp(attention_output) │ │ 330 │ │ return layer_output │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │ │ │ │ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1501 │ │ │ return forward_call(*args, kwargs) │ │ 1502 │ │ # Do not call functions when jit is used │ │ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │ │ 1504 │ │ backward_pre_hooks = [] │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/bert_layers.py:240 in forward │ │ │ │ 237 │ │ │ attn_mask: None or (batch, max_seqlen_in_batch) │ │ 238 │ │ │ bias: None or (batch, heads, max_seqlen_in_batch, 
max_seqlen_in_batch) │ │ 239 │ │ """ │ │ ❱ 240 │ │ self_output = self.self(input_tensor, cu_seqlens, max_s, indices, │ │ 241 │ │ │ │ │ │ │ │ attn_mask, bias) │ │ 242 │ │ if subset_idx is not None: │ │ 243 │ │ │ return self.output(index_first_axis(self_output, subset_idx), │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/nn/modules/module.py:1501 in _call_impl │ │ │ │ 1498 │ │ if not (self._backward_hooks or self._backward_pre_hooks or self._forward_hooks │ │ 1499 │ │ │ │ or _global_backward_pre_hooks or _global_backward_hooks │ │ 1500 │ │ │ │ or _global_forward_hooks or _global_forward_pre_hooks): │ │ ❱ 1501 │ │ │ return forward_call(*args, *kwargs) │ │ 1502 │ │ # Do not call functions when jit is used │ │ 1503 │ │ full_backward_hooks, non_full_backward_hooks = [], [] │ │ 1504 │ │ backward_pre_hooks = [] │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/bert_layers.py:181 in forward │ │ │ │ 178 │ │ │ │ qkv = qkv.to(torch.float16) │ │ 179 │ │ │ │ bias_dtype = bias.dtype │ │ 180 │ │ │ │ bias = bias.to(torch.float16) │ │ ❱ 181 │ │ │ │ attention = flash_attn_qkvpacked_func(qkv, bias) │ │ 182 │ │ │ │ attention = attention.to(orig_dtype) │ │ 183 │ │ │ │ bias = bias.to(bias_dtype) │ │ 184 │ │ │ else: │ │ │ │ /usr/local/lib/python3.10/dist-packages/torch/autograd/function.py:506 in apply │ │ │ │ 503 │ │ if not torch._C._are_functorch_transforms_active(): │ │ 504 │ │ │ # See NOTE: [functorch vjp and autograd interaction] │ │ 505 │ │ │ args = _functorch.utils.unwrap_dead_wrappers(args) │ │ ❱ 506 │ │ │ return super().apply(args, kwargs) # type: ignore[misc] │ │ 507 │ │ │ │ 508 │ │ if cls.setup_context == _SingleLevelFunction.setup_context: │ │ 509 │ │ │ raise RuntimeError( │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/flash_attn_triton.py:1021 in forward │ │ │ │ 1018 │ │ # Make sure that the last 
dimension is contiguous │ │ 1019 │ │ if qkv.stride(-1) != 1: │ │ 1020 │ │ │ qkv = qkv.contiguous() │ │ ❱ 1021 │ │ o, lse, ctx.softmax_scale = _flash_attn_forward( │ │ 1022 │ │ │ qkv[:, :, 0], │ │ 1023 │ │ │ qkv[:, :, 1], │ │ 1024 │ │ │ qkv[:, :, 2], │ │ │ │ /root/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/5fd206e1a13cee3e │ │ f4a608677312175eb6f8143d/flash_attn_triton.py:781 in _flash_attn_forward │ │ │ │ 778 │ assert q.dtype == k.dtype == v.dtype, 'All tensors must have the same type' │ │ 779 │ assert q.dtype in [torch.float16, │ │ 780 │ │ │ │ │ torch.bfloat16], 'Only support fp16 and bf16' │ │ ❱ 781 │ assert q.is_cuda and k.is_cuda and v.is_cuda │ │ 782 │ softmax_scale = softmax_scale or 1.0 / math.sqrt(d) │ │ 783 │ │ │ 784 │ has_bias = bias is not None │ ╰──────────────────────────────────────────────────────────────────────────────────────────────────╯ AssertionError

finngaida commented 1 year ago

+1, same issue. Interestingly, on macOS I can run on the CPU just fine, but on Linux and on Colab (no GPU) I get this error.

hosnaa commented 1 year ago

I had the same issue; I just moved the model and the inputs to the GPU with `.cuda()`.

Zhihan1996 commented 1 year ago

Hey,

Can you please try to install triton from source?

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build-time dependency
pip install -e .
tlimato commented 12 months ago

ERROR: Failed building editable for triton Failed to build triton ERROR: Could not build wheels for triton, which is required to install pyproject.toml-based projects

leannmlindsey commented 10 months ago

I was able to get past this by uninstalling triton:

Here are my instructions for creating a conda env:

git clone https://github.com/Zhihan1996/DNABERT_2.git
cd DNABERT_2
conda create -n dna_sandbox python=3.8
conda activate dna_sandbox

Then open requirements.txt and change the transformers line to read:

transformers==4.28.0

Then save and close the file.

python3 -m pip install -r requirements.txt
pip uninstall triton
pip install scikit-learn

Hopefully that will also work for you.

jessicakan789 commented 6 months ago

Thanks leannmlindsey!

Just for other people who have had errors:

At first I had an fp16 error for train.py:

$ python train.py     --model_name_or_path zhihan1996/DNABERT-2-117M     --data_path  ${DATA_PATH}     --kmer -1     --run_name DNABERT2_${DATA_PATH}     --model_max_length ${MAX_LENGTH}     --per_device_train_batch_size 8     --per_device_eval_batch_size 16     --gradient_accumulation_steps 1     --learning_rate ${LR}     --num_train_epochs 5    --fp16    --save_steps 200     --output_dir output/dnabert2     --evaluation_strategy steps     --eval_steps 200     --warmup_steps 50     --logging_steps 100     --overwrite_output_dir True     --log_level info     --find_unused_parameters False
Traceback (most recent call last):
  File "train.py", line 303, in <module>
    train()
  File "train.py", line 227, in train
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/hf_argparser.py", line 346, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 117, in __init__
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/training_args.py", line 1337, in __post_init__
    raise ValueError(
ValueError: FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation (`--fp16_full_eval`) can only be used on CUDA devices.

Then I removed the --fp16 flag and got a CUDA assertion error:

$ python train.py     --model_name_or_path zhihan1996/DNABERT-2-117M     --data_path  ${DATA_PATH}     --kmer -1     --run_name DNABERT2_${DATA_PATH}     --model_max_length ${MAX_LENGTH}     --per_device_train_batch_size 8     --per_device_eval_batch_size 16     --gradient_accumulation_steps 1     --learning_rate ${LR}     --num_train_epochs 5     --save_steps 200     --output_dir output/dnabert2     --evaluation_strategy steps     --eval_steps 200     --warmup_steps 50     --logging_steps 100     --overwrite_output_dir True     --log_level info     --find_unused_parameters False
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['classifier.weight', 'bert.pooler.dense.bias', 'classifier.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
***** Running training *****
  Num examples = 36,496
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 22,810
  Number of trainable parameters = 117,070,851
  0%|                                                                                                                                                                                                              | 0/22810 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 303, in <module>
    train()
  File "train.py", line 285, in train
    trainer.train()
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 2767, in compute_loss
    outputs = model(**inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 862, in forward
    outputs = self.bert(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 608, in forward
    encoder_outputs = self.encoder(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 446, in forward
    hidden_states = layer_module(hidden_states,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 327, in forward
    attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 240, in forward
    self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 181, in forward
    attention = flash_attn_qkvpacked_func(qkv, bias)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/flash_attn_triton.py", line 1021, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/flash_attn_triton.py", line 781, in _flash_attn_forward
    assert q.is_cuda and k.is_cuda and v.is_cuda
AssertionError
  0%|      

Then I applied leannmlindsey's method which seems to be working but taking a long time.

Earthones commented 4 months ago

I also ran into this problem and solved it with leannmlindsey's method.

mengchengyao commented 2 months ago

Thanks leannmlindsey!

Just for other people who have had errors:

At first I had an fp16 error for train.py:

$ python train.py     --model_name_or_path zhihan1996/DNABERT-2-117M     --data_path  ${DATA_PATH}     --kmer -1     --run_name DNABERT2_${DATA_PATH}     --model_max_length ${MAX_LENGTH}     --per_device_train_batch_size 8     --per_device_eval_batch_size 16     --gradient_accumulation_steps 1     --learning_rate ${LR}     --num_train_epochs 5    --fp16    --save_steps 200     --output_dir output/dnabert2     --evaluation_strategy steps     --eval_steps 200     --warmup_steps 50     --logging_steps 100     --overwrite_output_dir True     --log_level info     --find_unused_parameters False
Traceback (most recent call last):
  File "train.py", line 303, in <module>
    train()
  File "train.py", line 227, in train
    model_args, data_args, training_args = parser.parse_args_into_dataclasses()
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/hf_argparser.py", line 346, in parse_args_into_dataclasses
    obj = dtype(**inputs)
  File "<string>", line 117, in __init__
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/training_args.py", line 1337, in __post_init__
    raise ValueError(
ValueError: FP16 Mixed precision training with AMP or APEX (`--fp16`) and FP16 half precision evaluation (`--fp16_full_eval`) can only be used on CUDA devices.

Then I removed the --fp16 flag and got a CUDA assertion error:

$ python train.py     --model_name_or_path zhihan1996/DNABERT-2-117M     --data_path  ${DATA_PATH}     --kmer -1     --run_name DNABERT2_${DATA_PATH}     --model_max_length ${MAX_LENGTH}     --per_device_train_batch_size 8     --per_device_eval_batch_size 16     --gradient_accumulation_steps 1     --learning_rate ${LR}     --num_train_epochs 5     --save_steps 200     --output_dir output/dnabert2     --evaluation_strategy steps     --eval_steps 200     --warmup_steps 50     --logging_steps 100     --overwrite_output_dir True     --log_level info     --find_unused_parameters False
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
WARNING:root:Perform single sequence classification...
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...
To disable this warning, you can either:
        - Avoid using `tokenizers` before the fork if possible
        - Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)
Some weights of the model checkpoint at zhihan1996/DNABERT-2-117M were not used when initializing BertForSequenceClassification: ['cls.predictions.transform.LayerNorm.bias', 'cls.predictions.transform.LayerNorm.weight', 'cls.predictions.transform.dense.weight', 'cls.predictions.decoder.bias', 'cls.predictions.decoder.weight', 'cls.predictions.transform.dense.bias']
- This IS expected if you are initializing BertForSequenceClassification from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing BertForSequenceClassification from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of BertForSequenceClassification were not initialized from the model checkpoint at zhihan1996/DNABERT-2-117M and are newly initialized: ['classifier.weight', 'bert.pooler.dense.bias', 'classifier.bias', 'bert.pooler.dense.weight']
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.
***** Running training *****
  Num examples = 36,496
  Num Epochs = 5
  Instantaneous batch size per device = 8
  Total train batch size (w. parallel, distributed & accumulation) = 8
  Gradient Accumulation steps = 1
  Total optimization steps = 22,810
  Number of trainable parameters = 117,070,851
  0%|                                                                                                                                                                                                              | 0/22810 [00:00<?, ?it/s]Traceback (most recent call last):
  File "train.py", line 303, in <module>
    train()
  File "train.py", line 285, in train
    trainer.train()
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 1664, in train
    return inner_training_loop(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 1940, in _inner_training_loop
    tr_loss_step = self.training_step(model, inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 2735, in training_step
    loss = self.compute_loss(model, inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/transformers/trainer.py", line 2767, in compute_loss
    outputs = model(**inputs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 862, in forward
    outputs = self.bert(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 608, in forward
    encoder_outputs = self.encoder(
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 446, in forward
    hidden_states = layer_module(hidden_states,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 327, in forward
    attention_output = self.attention(hidden_states, cu_seqlens, seqlen,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 240, in forward
    self_output = self.self(input_tensor, cu_seqlens, max_s, indices,
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/bert_layers.py", line 181, in forward
    attention = flash_attn_qkvpacked_func(qkv, bias)
  File "/data/jess_tmp/fh/miniconda3/envs/dna/lib/python3.8/site-packages/torch/autograd/function.py", line 539, in apply
    return super().apply(*args, **kwargs)  # type: ignore[misc]
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/flash_attn_triton.py", line 1021, in forward
    o, lse, ctx.softmax_scale = _flash_attn_forward(
  File "/home/jess/.cache/huggingface/modules/transformers_modules/zhihan1996/DNABERT-2-117M/25abaf0bd247444fcfa837109f12088114898d98/flash_attn_triton.py", line 781, in _flash_attn_forward
    assert q.is_cuda and k.is_cuda and v.is_cuda
AssertionError
  0%|      

Then I applied leannmlindsey's method which seems to be working but taking a long time.

Maybe this method could solve the problem: https://github.com/MAGICS-LAB/DNABERT_2/issues/71