Yoon-Jeong-ho opened this issue 1 month ago (status: Open)
I can't reproduce this on my local machine, but I also don't have multiple GPUs. Does this only happen when using multiple GPUs?
Yes. I didn't encounter this error when using a single GPU; it only occurs when I use multiple GPUs with a larger context size and latent size, which pushes GPU memory usage higher.
The same error also occurred previously when training a sparse autoencoder of the same size.
self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
Training SAE:   0%|          | 0/2048000000 [00:00<?, ?it/s]
/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/activations_store.py:283: UserWarning: To copy construct from a tensor, it is recommended to use sourceTensor.clone().detach() or sourceTensor.clone().detach().requires_grad_(True), rather than torch.tensor(sourceTensor).
  yield torch.tensor(
Estimating norm scaling factor:   0%|          | 0/1000 [00:00<?, ?it/s]
Estimating norm scaling factor: 100%|██████████| 1000/1000 [08:33<00:00, 1.95it/s]
249900| MSE Loss 7.922 | L1 72.222:  50%|█████     | 1023590400/2048000000 [37:11:49<40:19:52, 7055.51it/s]
Traceback (most recent call last):
  File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
    sparse_autoencoder = SAETrainingRunner(cfg).run()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
    sae = trainer.fit()
          ^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
    self._run_and_log_evals()
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
    eval_metrics = run_evals(
                   ^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
    metrics |= get_sparsity_and_variance_metrics(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
    flattened_sae_input = flattened_sae_input[flattened_mask]
                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:1)
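For context, the failing line at sae_lens/evals.py:289 indexes flattened_sae_input with a boolean mask that has ended up on a different GPU. Below is a minimal sketch of the kind of guard that avoids this error; it is not the actual patch in sae_lens, and only the two tensor names are taken from the traceback.

```python
import torch


def masked_select_rows(
    flattened_sae_input: torch.Tensor, flattened_mask: torch.Tensor
) -> torch.Tensor:
    """Boolean-mask indexing that tolerates the mask living on a different
    device than the indexed tensor (e.g. cuda:0 vs cuda:1 in multi-GPU runs)."""
    # Moving the mask onto the indexed tensor's device avoids:
    #   RuntimeError: indices should be either on cpu or on the same device
    #   as the indexed tensor
    flattened_mask = flattened_mask.to(flattened_sae_input.device)
    return flattened_sae_input[flattened_mask]
```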
I tried lowering the learning rate under the same conditions, but the same error occurred at the same point.
self.scaler = torch.cuda.amp.GradScaler(enabled=self.cfg.autocast)
Estimating norm scaling factor: 100%|██████████| 1000/1000 [08:41<00:00, 1.92it/s]
89900| MSE Loss 141.638 | L1 62.357:  18%|█▊        | 368230400/2048000000 [13:34:58<67:16:47, 6935.25it/s]
Traceback (most recent call last):
  File "/data_x/aa007878/SAE/sae_training.py", line 95, in <module>
    sparse_autoencoder = SAETrainingRunner(cfg).run()
                         ^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 106, in run
    sae = self.run_trainer_with_interruption_handling(trainer)
          ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/sae_training_runner.py", line 150, in run_trainer_with_interruption_handling
    sae = trainer.fit()
          ^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 176, in fit
    self._run_and_log_evals()
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/training/sae_trainer.py", line 333, in _run_and_log_evals
    eval_metrics = run_evals(
                   ^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/torch/utils/_contextlib.py", line 116, in decorate_context
    return func(*args, **kwargs)
           ^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 105, in run_evals
    metrics |= get_sparsity_and_variance_metrics(
               ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data_x/aa007878/miniconda3/envs/sae/lib/python3.11/site-packages/sae_lens/evals.py", line 289, in get_sparsity_and_variance_metrics
    flattened_sae_input = flattened_sae_input[flattened_mask]
                          ~~~~~~~~~~~~~~~~~~~^^^^^^^^^^^^^^^^
RuntimeError: indices should be either on cpu or on the same device as the indexed tensor (cuda:2)
Update here: I believe the root cause of this issue might be last week's PR that added get_sae_config, which introduced a bug where SAEs wouldn't be loaded onto the specified device. I also hit a device-related bug, and making this change has fixed things for me. See the link to my PR.
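For anyone stuck on an affected version, a rough workaround sketch is below: force the SAE onto the configured device before training starts. The `sae` attribute name and the use of `cfg.device` are assumptions about the runner's internals, not a documented API.

```python
import torch
from sae_lens import SAETrainingRunner

# Workaround sketch (assumed internals): `cfg` is the same runner config used
# in the failing script. Move the SAE onto cfg.device in case the loading path
# left it on the wrong GPU.
runner = SAETrainingRunner(cfg)
sae = getattr(runner, "sae", None)  # attribute name is an assumption
if sae is not None:
    sae.to(torch.device(cfg.device))
sparse_autoencoder = runner.run()
```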
@Yoon-Jeong-ho Is this fixed in the most recent version of SAELens (4.0.9)? Thanks for the fix @callummcdougall!
Describe the bug
I encountered a RuntimeError during training with sae_lens. The error appears to be a device mismatch when indexing: the indexed tensor and the boolean index end up on different devices (here, different CUDA devices).
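The exception itself is a generic PyTorch one and can be reproduced outside sae_lens whenever at least two GPUs are available; the snippet below is only an illustration of the error class, not code from the library.

```python
import torch

# Requires two GPUs: the error only fires when the indexed tensor and the
# boolean mask sit on *different* CUDA devices (CPU indices are allowed).
acts = torch.randn(16, 8, device="cuda:1")                # indexed tensor on cuda:1
mask = torch.ones(16, dtype=torch.bool, device="cuda:0")  # mask left on cuda:0

acts[mask]
# RuntimeError: indices should be either on cpu or on the same device as the
# indexed tensor (cuda:1)
```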
Error message
See the full traceback posted above.
Code example
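The original report did not include the training script, so the sketch below only reconstructs the entry point visible in the traceback (SAETrainingRunner(cfg).run()); every config field and value shown is an illustrative assumption, not the reporter's actual setup.

```python
from sae_lens import LanguageModelSAERunnerConfig, SAETrainingRunner

# Hypothetical config -- field values are placeholders, not the reporter's settings.
cfg = LanguageModelSAERunnerConfig(
    model_name="gpt2",                    # assumed model
    hook_name="blocks.8.hook_resid_pre",  # assumed hook point
    d_in=768,                             # assumed activation width
    device="cuda:0",                      # device the SAE should train on
    act_store_device="cuda:1",            # activations held on a second GPU (assumed)
    autocast=True,                        # matches the GradScaler line in the log
)

# Entry point as seen in the traceback (sae_training.py, line 95).
sparse_autoencoder = SAETrainingRunner(cfg).run()
```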
System Info
- Python: 3.11.9
- CUDA: 12.4
- GPU: NVIDIA RTX A6000
- PyTorch: 2.0.1
- Ubuntu: 20.04.1 LTS
- sae-lens: 3.22.2
- torch: 2.4.1
- transformer-lens: 2.7.0