HazyResearch / hyena-dna

Official implementation for HyenaDNA, a long-range genomic foundation model built with Hyena
https://arxiv.org/abs/2306.15794
Apache License 2.0

Running with standard Huggingface config and trainer files does not give optimal results #60

Open leannmlindsey opened 3 months ago

leannmlindsey commented 3 months ago

Hello, I have been running your model since last summer using a standard Hugging Face model framework (see code below), and it has not been giving us the same results on the benchmarking tests that you report in the paper. For example:

GenomicBenchmarks

- Mouse Enhancers: you report 85.1, our result is 63.6
- Human Enhancers Cohn: you report 74.2, our result is 66.3

I think it is possible that this is because we are not using the parameter that you have at the bottom of the config file, `freeze_backbone: false`, but I am not sure how to incorporate this into a standard Hugging Face Trainer.
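For context, below is a rough, untested sketch of how we imagine `freeze_backbone` could be replicated with a plain Hugging Face Trainer, by toggling `requires_grad` on everything outside the classification head of the model loaded in the sample code further down. The head parameter name (`"score"`) is a guess on our part; we would inspect `model.named_parameters()` on the actual checkpoint to confirm it.

```python
# Sketch (untested): approximate the repo's `freeze_backbone` flag by toggling
# requires_grad. The head name "score" is an assumption -- check
# model.named_parameters() on the real checkpoint to find the actual head name.
freeze_backbone = False  # matches the value at the bottom of your config file

for name, param in model.named_parameters():
    if freeze_backbone:
        # only the classification head stays trainable
        param.requires_grad = "score" in name
    else:
        # full fine-tuning: backbone and head are both trainable
        param.requires_grad = True

# sanity check before handing the model to the Trainer
trainable = sum(p.numel() for p in model.parameters() if p.requires_grad)
print(f"trainable parameters: {trainable:,}")
```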

Do you support the Hugging Face Trainer, or only the Hydra/Lightning trainer?

My concern is that, since we are not able to match your reported results, your model may not be performing optimally on our specific classification task. I had expected it to be comparable in performance to DNABERT2, but it was not, and I think this may be because we have not set up our run correctly. Any direction would be appreciated. Thank you.

Sample code:

```python
import torch
from transformers import (
    AutoModelForSequenceClassification,
    AutoTokenizer,
    Trainer,
    TrainingArguments,
)

checkpoint = 'LongSafari/hyenadna-tiny-16k-seqlen-d128-hf'
max_length = 4010

# Load the tokenizer and model before constructing the Trainer
tokenizer = AutoTokenizer.from_pretrained(checkpoint, trust_remote_code=True)
model = AutoModelForSequenceClassification.from_pretrained(
    checkpoint,
    torch_dtype=torch.bfloat16,
    device_map="auto",
    pad_token_id=tokenizer.pad_token_id,
    trust_remote_code=True,
)

args = {
    "output_dir": "test_output",
    "num_train_epochs": 25,
    "per_device_train_batch_size": 512,
    "per_device_eval_batch_size": 512,
    "gradient_accumulation_steps": 4,
    "gradient_checkpointing": False,
    "learning_rate": 2e-5,
    "evaluation_strategy": "steps",
    "eval_steps": 1,
    "report_to": "none",  # was "wandb": "null"; TrainingArguments has no "wandb" argument
}
training_args = TrainingArguments(**args)

# ds_tok_train, ds_tok_val, and compute_metrics are defined elsewhere in our code
trainer = Trainer(
    model=model,
    args=training_args,
    train_dataset=ds_tok_train,
    eval_dataset=ds_tok_val,
    compute_metrics=compute_metrics,
)
trainer.train()
```
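For reference, a minimal `compute_metrics` compatible with the Trainer call above (not necessarily the exact one we used) could look like this, reporting plain accuracy:

```python
import numpy as np
from sklearn.metrics import accuracy_score

def compute_metrics(eval_pred):
    # eval_pred unpacks into (logits, labels) from the Trainer's evaluation loop
    logits, labels = eval_pred
    preds = np.argmax(logits, axis=-1)
    return {"accuracy": accuracy_score(labels, preds)}
```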