MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome
Apache License 2.0

AssertionError in model(inputs) #54

Closed ZhaoyueZhang closed 11 months ago

ZhaoyueZhang commented 11 months ago

I set up the environment according to the following instructions:

create and activate virtual python environment

conda create -n dna python=3.8
conda activate dna

(optional if you would like to use flash attention)

install triton from source

git clone https://github.com/openai/triton.git;
cd triton/python;
pip install cmake; # build-time dependency
pip install -e .

install required packages

python3 -m pip install -r requirements.txt

However, an AssertionError occurred when I ran model(inputs):

....flash_attn_triton.py", line 781, in _flash_attn_forward assert q.is_cuda and k.is_cuda and v.is_cuda

I found that other users have also encountered this. Could you help me solve this problem?
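For context, the assertion quoted above checks that the query, key, and value tensors are all on a CUDA device. A minimal sketch of choosing a device and keeping the model and inputs on it is shown below. The helper name pick_device is hypothetical, and the sketch assumes PyTorch is the framework in use; it falls back to "cpu" when PyTorch or a GPU is not available.

```python
import importlib.util


def pick_device() -> str:
    """Return "cuda" when PyTorch is installed and sees a GPU, else "cpu"."""
    # Hypothetical helper: avoid importing torch if it is not installed at all.
    if importlib.util.find_spec("torch") is None:
        return "cpu"
    import torch

    return "cuda" if torch.cuda.is_available() else "cpu"


device = pick_device()
print(f"selected device: {device}")

# With a model and an input tensor already created, both would be moved
# to the same device before the forward pass:
#   model = model.to(device)
#   inputs = inputs.to(device)
#   hidden_states = model(inputs)[0]
```

Moving only one of the two (the model or the inputs) leaves tensors on different devices, so both are moved together here.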

ZhaoyueZhang commented 11 months ago

Which triton version should be installed, triton 2.1.0 or triton 2.0.0?
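When asking which version to install, it can help to report what is currently installed. The sketch below lists the installed versions of the packages mentioned in this thread using only the standard library. The function names installed_version and report are hypothetical, and the package list is an assumption based on the packages discussed here.

```python
from importlib import metadata


def installed_version(package: str):
    """Return the installed version string of a package, or None if absent."""
    try:
        return metadata.version(package)
    except metadata.PackageNotFoundError:
        return None


def report(packages=("torch", "transformers", "triton")):
    """Map each package name to its installed version (None if not installed)."""
    return {name: installed_version(name) for name in packages}


for name, version in report().items():
    print(f"{name}: {version or 'not installed'}")
```

Including this output in a bug report makes it easier to see whether a version mismatch is involved.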

ZhaoyueZhang commented 11 months ago

When I tried model.cuda() and inputs.cuda(), I got:

code_generator.py", line 1233, in ast_to_ttir
    raise CompilationError(fn.src, node, repr(e)) from e
triton.compiler.errors.CompilationError: at 114:24:
        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=(start_n + offs_n)[:, None] < seqlen_k, other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn, mask=((start_n + offs_n)[:, None] < seqlen_k) & (offs_d[None, :] < headdim), other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
TypeError("dot() got an unexpected keyword argument 'trans_b'")
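The TypeError above says that the installed dot() does not accept a trans_b keyword, which points to a mismatch between the kernel code and the installed triton version. One way to check whether a function accepts a given keyword is sketched below with the standard inspect module. The helper accepts_keyword and the two stand-in functions are hypothetical; whether triton's own dot can be introspected this way is an assumption, so the helper returns None when the signature cannot be read.

```python
import inspect


def accepts_keyword(func, name: str):
    """Return True/False if func accepts keyword `name`, or None if unknown."""
    try:
        params = inspect.signature(func).parameters
    except (TypeError, ValueError):
        return None  # signature is not introspectable
    keyword_kinds = (
        inspect.Parameter.POSITIONAL_OR_KEYWORD,
        inspect.Parameter.KEYWORD_ONLY,
    )
    if name in params and params[name].kind in keyword_kinds:
        return True
    # A **kwargs parameter accepts any keyword name.
    return any(p.kind is inspect.Parameter.VAR_KEYWORD for p in params.values())


# Hypothetical stand-ins for two API generations (not triton's real code).
def dot_with_trans(a, b, trans_b=False):
    return None


def dot_without_trans(a, b, allow_tf32=True):
    return None


print(accepts_keyword(dot_with_trans, "trans_b"))
print(accepts_keyword(dot_without_trans, "trans_b"))

# With triton installed, one could try (assuming introspection works on it):
#   import triton.language as tl
#   accepts_keyword(tl.dot, "trans_b")
```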

ZhaoyueZhang commented 11 months ago

After uninstalling triton, the AssertionError still occurred.

ZhaoyueZhang commented 11 months ago

Changing transformers from 4.28.0 to 4.30.2 did not fix the problem.

ZhaoyueZhang commented 11 months ago

I deleted the previous environment and rebuilt it. I then got "AttributeError: module 'triton' has no attribute 'autotune'" when running model = AutoModel.from_pretrained("zhihan1996/DNABERT-2-117M", trust_remote_code=True)

ZhaoyueZhang commented 11 months ago

I deleted the previous environment and built a new one without installing triton (or, if it is already present, run pip uninstall triton). With that setup, 'hidden_states = model(inputs)[0]' runs without error.
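The resolution above can be sketched as a quick check before loading the model: confirm that triton is not importable in the active environment. The helper names triton_installed and load_model are hypothetical. The load_model function uses the same AutoModel.from_pretrained call quoted earlier in this thread and is only defined here, not called, since it needs transformers and downloads the weights.

```python
import importlib.util


def triton_installed() -> bool:
    """Return True if a `triton` package is importable in this environment."""
    return importlib.util.find_spec("triton") is not None


def load_model(name: str = "zhihan1996/DNABERT-2-117M"):
    """Load the model as in the thread; requires transformers and network access."""
    from transformers import AutoModel

    return AutoModel.from_pretrained(name, trust_remote_code=True)


if triton_installed():
    print("triton is installed; in this thread the errors went away after removing it")
else:
    print("triton not found; this matches the setup that ran without error")

# Usage, once the environment is ready:
#   model = load_model()
#   hidden_states = model(inputs)[0]
```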