uci-cbcl / esm-efficient


Loading directly from Hugging Face #3

pengzhangzhi opened 15 hours ago

pengzhangzhi commented 15 hours ago

Hi! In the current codebase, we have to download the checkpoint locally and load it like this:

```python
# first download "{model}.safetensors" locally,
# then load it:
model = ESM2.from_pretrained("{model}.safetensors", device=0)
```

I wonder if we could load the checkpoint directly from the Hugging Face Hub instead, e.g.:

```python
model = ESM2.from_pretrained("facebook/esm2_t30_150M_UR50D", device=0)
```

That way, it would be more straightforward to drop a FlashAttention version of ESM2 into an existing codebase.

It seems doable to me, because esm-efficient shares the same model architecture as ESM2 except for the use of FlashAttention?
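One possible shape for this, sketched below: `from_pretrained` first checks whether its argument is a local file, and only otherwise treats it as a Hub repo id. This is just an assumption about how it could work; the filename `model.safetensors` and the idea that the weights would be published on the Hub under that name are hypothetical.

```python
import os


def resolve_checkpoint(name_or_path: str) -> str:
    """Return a local checkpoint path for either a local file or a Hub repo id.

    Sketch only: assumes Hub-hosted weights would be stored as
    "model.safetensors" in the repo, which is not confirmed.
    """
    if os.path.exists(name_or_path):
        # already a local file: use it directly
        return name_or_path
    # otherwise treat the string as a Hub repo id;
    # hf_hub_download fetches the file and caches it locally
    from huggingface_hub import hf_hub_download
    return hf_hub_download(repo_id=name_or_path, filename="model.safetensors")
```

`ESM2.from_pretrained` could call something like `resolve_checkpoint` first, so both the current local-file usage and a repo id such as `facebook/esm2_t30_150M_UR50D` would go through the same code path.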

Would love to hear your thoughts!

pengzhangzhi commented 12 hours ago

I'm working on it. I guess the main task is converting the parameter names defined in esm-efficient to the ones defined in ESM2, i.e., the standard Hugging Face names? Let me know if you are interested! We can talk more about it!
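The name-conversion step could look roughly like the sketch below: walk the state dict and rewrite each key with a list of regex rules. The rule patterns here are purely illustrative; the actual esm-efficient and Hugging Face ESM2 key names would have to be read off the two state dicts.

```python
import re


def remap_keys(state_dict, rules):
    """Rename state-dict keys according to (pattern, replacement) regex rules."""
    out = {}
    for key, value in state_dict.items():
        new_key = key
        for pattern, repl in rules:
            new_key = re.sub(pattern, repl, new_key)
        out[new_key] = value
    return out


# Hypothetical rules -- NOT the real esm-efficient / HF ESM2 names,
# just an example of the kind of mapping that would be needed.
RULES = [
    (r"^layers\.", "esm.encoder.layer."),
    (r"\.attn\.", ".attention.self."),
]

sd = {"layers.0.attn.q_proj.weight": 1}
remap_keys(sd, RULES)
# → {'esm.encoder.layer.0.attention.self.q_proj.weight': 1}
```

Once the real mapping is written down, the converted state dict could be loaded with the model's usual `load_state_dict`.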