Closed leannmlindsey closed 1 month ago
Can you provide more details about what model you are using and what values you are passing for these flags:
model.conjoin_test
decoder.conjoin_train
decoder.conjoin_test
Sorry about the delay in response, I am at a conference this week.
caduceus_ps_LR-2e-3_BATCH_SIZE-128_RC_AUG-false_701970.log
What should those variables be?
These are the flags in the run that failed. I was using the given gb benchmark code...but I did change the pre-trained model to one that I had pre-trained on a different dataset. I think that I do not fully understand what these flags do, and that could be the cause of the problem...I will read the paper again and see if that clears things up for me.
model.conjoin_test=false decoder.conjoin_train=true decoder.conjoin_test=false
I will attach the log file.
In your log I see that d_model = 128. I think in your case you'd want it to be 256.
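A width mismatch like this can be caught before training starts by comparing the feature width the backbone emits with the input width the decoder was built for. The following is only a minimal sketch: `assert_matching_width` and its two arguments are hypothetical names, not part of the Caduceus code base, and how you read the two widths from your config or checkpoint is left as an assumption.

```python
def assert_matching_width(backbone_width: int, decoder_in_features: int) -> None:
    """Fail early if the decoder was sized for a different feature width.

    Both arguments are hypothetical: pass whatever your config/checkpoint
    reports as the backbone's output width and the decoder's input width.
    """
    if backbone_width != decoder_in_features:
        raise ValueError(
            f"backbone emits {backbone_width} features but the decoder expects "
            f"{decoder_in_features}; check d_model in the fine-tuning config"
        )


# Matching widths pass silently.
assert_matching_width(256, 256)
```

Calling it with 128 and 256 raises a ValueError up front, instead of failing inside the first validation step.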
Feel free to re-open if there are additional issues here
Sorry that I didn't respond. Yes, that solved it.
I am getting an error when trying to do any sequence classification with models that have d_model = 256.
Error:
Traceback (most recent call last):
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 715, in main
    train(config)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 680, in train
    trainer.fit(model)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 603, in fit
    call._call_and_handle_interrupt(
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/call.py", line 38, in _call_and_handle_interrupt
    return trainer_fn(*args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 645, in _fit_impl
    self._run(model, ckpt_path=self.ckpt_path)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1098, in _run
    results = self._run_stage()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1177, in _run_stage
    self._run_train()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1190, in _run_train
    self._run_sanity_check()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1262, in _run_sanity_check
    val_loop.run()
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(*args, *kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/dataloader/evaluation_loop.py", line 152, in advance
    dl_outputs = self.epoch_loop.run(self._data_fetcher, dl_max_batches, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/loop.py", line 199, in run
    self.advance(args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 137, in advance
    output = self._evaluation_step(kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/loops/epoch/evaluation_epoch_loop.py", line 234, in _evaluation_step
    output = self.trainer._call_strategy_hook(hook_name, kwargs.values())
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/trainer/trainer.py", line 1480, in _call_strategy_hook
    output = fn(args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/pytorch_lightning/strategies/strategy.py", line 390, in validation_step
    return self.model.validation_step(args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 447, in validation_step
    loss = self._shared_step(
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 339, in _shared_step
    x, y, w = self.forward(batch)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/train.py", line 327, in forward
    return self.task.forward(batch, self.encoder, self.model, self.decoder, self._state)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/tasks/tasks.py", line 157, in forward
    x, w = decoder(x, state=state, z)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/models/nn/utils.py", line 112, in forward
    x, kwargs = wrap_kwargs(layer.forward)(x, kwargs)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/models/nn/utils.py", line 77, in f_kwargs
    y = f(*bound.arguments)
  File "/uufs/chpc.utah.edu/common/home/sundar-group2/PHAGE/MODELS/CADUCEUS_PHAGE/caduceus/src/tasks/decoders.py", line 148, in forward
    x = self.output_transform(x.squeeze())
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1511, in _wrapped_call_impl
    return self._call_impl(args, kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/module.py", line 1520, in _call_impl
    return forward_call(*args, **kwargs)
  File "/uufs/chpc.utah.edu/common/home/u1323098/software/pkg/miniconda3/envs/CADUCEUS_3/lib/python3.8/site-packages/torch/nn/modules/linear.py", line 116, in forward
    return F.linear(input, self.weight, self.bias)
RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x128 and 256x2)
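For anyone reading the last line of this traceback: in an error raised from F.linear, mat1 is the input (rows x features) and mat2 is the transposed layer weight (in_features x out_features). Here that suggests the final linear layer was built for 256 input features and 2 outputs, but received 128 features per row. The small helper below is only an illustrative sketch (`parse_matmul_error` and its dictionary keys are made-up names) that pulls those four numbers out of the message. Reading the first 128 as the batch size is an assumption based on BATCH_SIZE-128 in the log file name.

```python
import re


def parse_matmul_error(message):
    """Extract the shapes from a 'mat1 and mat2 shapes cannot be multiplied' message.

    Returns None if the message does not contain an '(AxB and CxD)' shape pair.
    The key names are illustrative labels, not PyTorch terminology.
    """
    match = re.search(r"\((\d+)x(\d+) and (\d+)x(\d+)\)", message)
    if match is None:
        return None
    rows, features_received, features_expected, num_outputs = map(int, match.groups())
    return {
        "rows": rows,  # leading dimension of the input (assumed here to be the batch)
        "features_received": features_received,  # width of what reached the layer
        "features_expected": features_expected,  # in_features of the linear layer
        "num_outputs": num_outputs,  # out_features of the linear layer
    }


info = parse_matmul_error(
    "RuntimeError: mat1 and mat2 shapes cannot be multiplied (128x128 and 256x2)"
)
print(info["features_received"], info["features_expected"])
```

The two printed values differ (128 against 256), which is the mismatch to chase in the config.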