ai-safety-foundation / sparse_autoencoder

Sparse Autoencoder for Mechanistic Interpretability
https://ai-safety-foundation.github.io/sparse_autoencoder/
MIT License

demo pre encoder bias device error #209

Open qwenzo opened 7 months ago

Hello,

Thank you for your work! I get the following error when running the demo as-is, without any changes.

Traceback (most recent call last):
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/train/sweep.py", line 334, in train
    run_training_pipeline(
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/train/sweep.py", line 299, in run_training_pipeline
    pipeline.run_pipeline(
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/pydantic/validate_call_decorator.py", line 58, in wrapper_function
    return validate_call_wrapper(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py", line 81, in __call__
    res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/train/pipeline.py", line 518, in run_pipeline
    self.validate_sae(validation_n_activations)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/pydantic/validate_call_decorator.py", line 58, in wrapper_function
    return validate_call_wrapper(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/pydantic/_internal/_validate_call.py", line 81, in __call__
    res = self.__pydantic_validator__.validate_python(pydantic_core.ArgsKwargs(args, kwargs))
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/train/pipeline.py", line 379, in validate_sae
    loss_with_reconstruction = self.source_model.run_with_hooks(
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/transformer_lens/hook_points.py", line 365, in run_with_hooks
    return hooked_model.forward(*model_args, **model_kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/transformer_lens/HookedTransformer.py", line 562, in forward
    residual = block(
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/transformer_lens/components.py", line 1464, in forward
    mlp_out = self.hook_mlp_out(
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1581, in _call_impl
    hook_result = hook(self, args, result)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/transformer_lens/hook_points.py", line 65, in full_hook
    return hook(module_output, hook=self)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/source_model/replace_activations_hook.py", line 64, in replace_activations_hook
    _learned_activations, output_activations = sparse_autoencoder.forward(expanded)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/autoencoder/model.py", line 180, in forward
    x = self.pre_encoder_bias(x)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1518, in _wrapped_call_impl
    return self._call_impl(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/torch/nn/modules/module.py", line 1527, in _call_impl
    return forward_call(*args, **kwargs)
  File "/nfs/homedirs/abdallam/anaconda3/envs/sparse_auto/lib/python3.10/site-packages/sparse_autoencoder/autoencoder/components/tied_bias.py", line 82, in forward
    return x - self.bias
RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cuda:0 and cpu!
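
The last frame shows `x - self.bias` failing because the two operands sit on different devices (`cuda:0` and `cpu`). The traceback does not say which side is on the CPU or why, so what follows is only a general sketch of this kind of mismatch and one common workaround, not the library's actual code or the maintainers' fix. `TiedBiasSketch`, `parameter_devices`, and `move_to_input_device` are hypothetical names introduced for illustration; the sketch falls back to the CPU when no GPU is available.

```python
import torch
from torch import nn


class TiedBiasSketch(nn.Module):
    """Hypothetical stand-in for a tied-bias layer: subtracts a learned bias."""

    def __init__(self, n_features: int) -> None:
        super().__init__()
        # Parameters are created on the CPU unless a device is given.
        self.bias = nn.Parameter(torch.zeros(n_features))

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Raises the "same device" RuntimeError if x and bias are on different devices.
        return x - self.bias


def parameter_devices(module: nn.Module) -> set[torch.device]:
    """Return the set of devices that the module's parameters live on."""
    return {p.device for p in module.parameters()}


def move_to_input_device(module: nn.Module, x: torch.Tensor) -> nn.Module:
    """Move the module to x's device if any parameter is elsewhere (assumed workaround)."""
    if parameter_devices(module) != {x.device}:
        module.to(x.device)
    return module


# Use the GPU when present; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
x = torch.ones(2, 4, device=device)

layer = TiedBiasSketch(4)        # bias starts on the CPU
move_to_input_device(layer, x)   # align the layer with the input before the forward pass
out = layer(x)
```

Printing `parameter_devices(...)` for the autoencoder just before the validation step would show whether its parameters were left on the CPU while the source model's activations are on `cuda:0`.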