MAGICS-LAB / DNABERT_2

[ICLR 2024] DNABERT-2: Efficient Foundation Model and Benchmark for Multi-Species Genome
Apache License 2.0
Triton version issue #60

Closed: yuddecho closed this issue 6 months ago

yuddecho commented 6 months ago

The test code is as follows (I am using transformers==4.29.2):

import torch
from transformers import AutoTokenizer, AutoModel

model_id = 'zhihan1996/DNABERT-2-117M'

tokenizer = AutoTokenizer.from_pretrained(model_id, trust_remote_code=True)
model = AutoModel.from_pretrained(model_id, trust_remote_code=True)

print(model)

dna = "ACGTAGCATCGGATCTATCTATCGACACTTGGTTATCGATCTACGAGCATCTCGTTAGC"
inputs = tokenizer(dna, return_tensors = 'pt')["input_ids"]

inputs = inputs.to('cuda')
model = model.to('cuda')

hidden_states = model(inputs)[0] # [1, sequence_length, 768]

# embedding with mean pooling
embedding_mean = torch.mean(hidden_states[0], dim=0)
print(embedding_mean.shape) # expect to be 768

# embedding with max pooling
embedding_max = torch.max(hidden_states[0], dim=0)[0]
print(embedding_max.shape) # expect to be 768

When I use triton==2.0.0.dev20221103, I get a RuntimeError (CUDA: Error- invalid source). Log:

File ~/.conda/envs/yudd/lib/python3.10/site-packages/triton/compiler.py:1301, in CompiledKernel.__init__(self, fn_name, so_path, cache_dir, device)
   1298 with open(os.path.join(cache_dir, f"{fn_name}.ttir"), "r") as f:
   1299     self.asm["ttir"] = f.read()
-> 1301 mod, func, n_regs, n_spills = _triton.code_gen.load_binary(metadata["name"], self.asm["cubin"], self.shared, device)
   1302 self.fn_name = fn_name
   1303 self.cu_module = mod

RuntimeError: CUDA: Error- invalid source

When I use triton==2.1.0, I get a TypeError("dot() got an unexpected keyword argument 'trans_b'"). Log:

File ~/.conda/envs/yudd/lib/python3.10/site-packages/triton/compiler/code_generator.py:1133, in ast_to_ttir(fn, signature, specialization, constants, debug, arch)
   1131     if node is None:
   1132         raise
-> 1133     raise CompilationError(fn.src, node, repr(e)) from e
   1134 ret = generator.module
   1135 # module takes ownership of the context

CompilationError: at 114:24:        else:
            if EVEN_HEADDIM:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=(start_n + offs_n)[:, None] < seqlen_k,
                            other=0.0)
            else:
                k = tl.load(k_ptrs + start_n * stride_kn,
                            mask=((start_n + offs_n)[:, None] < seqlen_k) &
                            (offs_d[None, :] < headdim),
                            other=0.0)
        qk = tl.zeros([BLOCK_M, BLOCK_N], dtype=tl.float32)
        qk += tl.dot(q, k, trans_b=True)
                        ^
TypeError("dot() got an unexpected keyword argument 'trans_b'")

What should I do now? Thank you.

yuddecho commented 6 months ago

After I ran pip uninstall triton, the test code works.
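
For anyone hitting the same errors, a quick way to check whether triton is present in the active environment before loading the model is sketched below. It uses only the Python standard library; the helper name triton_status is made up for illustration and is not part of DNABERT-2 or transformers. The likely reason uninstalling helps is that the model's remote code falls back to a non-Triton attention path when triton cannot be imported, but that is an assumption based on the behavior in this thread, not something confirmed here.

```python
import importlib.metadata
import importlib.util


def triton_status():
    """Return (installed, version) for the triton package in this environment.

    Hypothetical helper for illustration; not part of DNABERT-2.
    """
    # find_spec returns None when the top-level package cannot be located.
    if importlib.util.find_spec("triton") is None:
        return False, None
    try:
        return True, importlib.metadata.version("triton")
    except importlib.metadata.PackageNotFoundError:
        # Importable but without distribution metadata (e.g. a source checkout).
        return True, None


installed, version = triton_status()
if installed:
    print(f"triton {version or '(unknown version)'} is installed; "
          "if the errors in this thread appear, try: pip uninstall triton")
else:
    print("triton is not installed; in this thread the model ran without it")
```

Running this in the same conda environment as the test code shows which triton build, if any, the model will pick up.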