kuleshov-group / caduceus

Bi-Directional Equivariant Long-Range DNA Sequence Modeling
Apache License 2.0

Error using seq classification with d_model 256 #27

Closed: leannmlindsey closed this issue 1 month ago

leannmlindsey commented 2 months ago

I am getting an error when I try to run any sequence classification task with models that have d_model = 256.

Error:

```
Traceback (most recent call last):
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 715, in main
    train(config)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 680, in train
    trainer.fit(model)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(**kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, *kwargs.values())
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in validation_step
    return self.model.validation_step(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 447, in validation_step
    loss = self._shared_step(
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 339, in _shared_step
    x, y, w = self.forward(batch)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 327, in forward
    return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/tasks/tasks.py", line 157, in forward
    x, w = decoder(x, state=state, **z)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/models/nn/utils.py", line 112, in forward
    x, kwargs = wrap_kwargs(layer.forward)(x, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/models/nn/utils.py", line 77, in f_kwargs
    y = f(**bound.arguments)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/tasks/decoders.py", line 148, in forward
    x = self.output_transform(x.squeeze())
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x128 and 256x2)
```
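The failing frame is `F.linear` inside the decoder's `output_transform`. For a 2-D input, `F.linear` multiplies the input by the transposed weight, so the message compares mat1 = input (rows × features) with mat2 = weight.T (in_features × out_features). Read that way, the classification head expected 256 input features and produces 2 outputs, while the tensor reaching it had only 128 features (and 128 rows, presumably the batch size of 128 shown in the log file name). The pure-Python sketch below mimics only that shape rule; `linear_output_shape` is a hypothetical helper, not part of PyTorch or Caduceus.

```python
def linear_output_shape(input_shape, in_features, out_features):
    """Hypothetical helper: mimic the shape check behind torch.nn.functional.linear.

    For a 2-D input, F.linear computes input @ weight.T, where weight has shape
    (out_features, in_features); mat2 is therefore (in_features, out_features).
    """
    rows, features = input_shape
    if features != in_features:
        # Same wording as the PyTorch RuntimeError in the traceback.
        raise RuntimeError(
            "mat1 and mat2 shapes cannot be multiplied "
            f"({rows}x{features} and {in_features}x{out_features})"
        )
    return (rows, out_features)


# A head sized for 256 features / 2 classes, fed a (128, 128) tensor.
try:
    linear_output_shape((128, 128), in_features=256, out_features=2)
except RuntimeError as err:
    print(err)  # mat1 and mat2 shapes cannot be multiplied (128x128 and 256x2)

# With matching widths the head returns 2 outputs per row.
print(linear_output_shape((128, 256), in_features=256, out_features=2))  # (128, 2)
```

The point of the sketch is that the error is a width mismatch between the backbone output and the head, not a problem with the data itself.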

yair-schiff commented 2 months ago

Can you provide more details about what model you are using and what values you are passing for these flags:

model.conjoin_test
decoder.conjoin_train
decoder.conjoin_test

leannmlindsey commented 2 months ago

Sorry about the delayed response; I am at a conference this week.


What should those variables be?

I was using the provided genomic benchmark (gb) code, but I swapped in a model that I had pre-trained on a different dataset. I don't fully understand what these flags do, and that could be the cause of the problem; I will reread the paper and see if that clears things up for me. These are the flag values from the run that failed:

model.conjoin_test=false
decoder.conjoin_train=true
decoder.conjoin_test=false

I will attach the log file.

caduceus_ps_LR-2e-3_BATCH_SIZE-128_RC_AUG-false_701970.log
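For readers who, like the reporter, are unsure what "conjoin" refers to: the Caduceus paper describes post-hoc conjoining, i.e. running a model on a sequence and on its reverse complement and combining the two outputs, so that the prediction does not depend on which DNA strand was read. Whether `decoder.conjoin_train` and `decoder.conjoin_test` implement exactly this is not confirmed in this thread. The sketch below only illustrates the general idea, with a toy, hypothetical `fraction_a` scoring function standing in for a model.

```python
from typing import Callable

# Watson-Crick complement for uppercase DNA; N (unknown base) maps to itself.
_COMPLEMENT = str.maketrans("ACGTN", "TGCAN")


def reverse_complement(seq: str) -> str:
    """Complement each base, then reverse the sequence."""
    return seq.translate(_COMPLEMENT)[::-1]


def conjoined_score(seq: str, score_fn: Callable[[str], float]) -> float:
    """Average a (possibly strand-sensitive) score over a sequence and its reverse complement."""
    return (score_fn(seq) + score_fn(reverse_complement(seq))) / 2


def fraction_a(seq: str) -> float:
    """Toy, hypothetical 'model': fraction of bases that are A (strand-sensitive)."""
    return seq.count("A") / len(seq)


seq = "AAAC"
print(reverse_complement(seq))              # GTTT
print(fraction_a(seq), fraction_a("GTTT"))  # 0.75 0.0
print(conjoined_score(seq, fraction_a))     # 0.375
print(conjoined_score("GTTT", fraction_a))  # 0.375, same result for either strand
```

The raw score differs between the two strands, but the conjoined score does not, because the reverse complement of the reverse complement is the original sequence.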

yair-schiff commented 2 months ago

In your log, I see that d_model = 128. I think in your case you'd want it to be 256.
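Because the backbone and head widths must agree, one way to catch this kind of mismatch before a long run is to compare the `d_model` in the fine-tuning config with the width stored in the pre-trained checkpoint. Below is a minimal, framework-free sketch of that check. It assumes you have already collected parameter shapes into a dict (for example from a PyTorch `state_dict`), and it relies on an `nn.Embedding`-style weight of shape `(vocab_size, d_model)`. The parameter key `EMBEDDING_KEY` and the helper names are hypothetical, not names confirmed for Caduceus checkpoints.

```python
from typing import Mapping, Tuple

# Hypothetical key; inspect your own checkpoint for the real embedding parameter name.
EMBEDDING_KEY = "backbone.embeddings.word_embeddings.weight"


def checkpoint_d_model(shapes: Mapping[str, Tuple[int, ...]], key: str = EMBEDDING_KEY) -> int:
    """Read d_model from an embedding weight of shape (vocab_size, d_model)."""
    if key not in shapes:
        raise KeyError(f"embedding weight {key!r} not found in checkpoint")
    return shapes[key][-1]


def check_d_model(shapes: Mapping[str, Tuple[int, ...]], config_d_model: int) -> None:
    """Raise early if the config width disagrees with the pre-trained weights."""
    ckpt = checkpoint_d_model(shapes)
    if ckpt != config_d_model:
        raise ValueError(
            f"config d_model={config_d_model} but checkpoint was trained with d_model={ckpt}"
        )


# Toy shapes standing in for a checkpoint pre-trained with d_model = 256.
shapes = {EMBEDDING_KEY: (16, 256)}
check_d_model(shapes, 256)  # passes silently
try:
    check_d_model(shapes, 128)  # mirrors the mismatch diagnosed in this thread
except ValueError as err:
    print(err)  # config d_model=128 but checkpoint was trained with d_model=256
```

Running a check like this at start-up fails in seconds, whereas the original error only surfaced once Lightning's validation sanity check reached the decoder.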

yair-schiff commented 1 month ago

Feel free to re-open if there are additional issues here.

leannmlindsey commented 1 month ago

Sorry I didn't respond sooner. Yes, that solved it.