AlignmentResearch / tuned-lens

Tools for understanding how transformer predictions are built layer-by-layer
https://tuned-lens.readthedocs.io/en/latest/
MIT License

Problems running the notebook on GPU #53

Closed: svenschultze closed this issue 1 year ago

svenschultze commented 1 year ago

Describe the bug

I am trying to run the interactive notebook on a GPU, but there seems to be an error in the plot_lens function.

To Reproduce

Steps to reproduce the behavior:

  1. Open the interactive notebook on Colab and change the runtime type to GPU.
  2. Change the device to "cuda" and run the notebook.

The same problem also happens when running locally on an H100 GPU.

Exception Traceback

 /usr/local/lib/python3.9/dist-packages/ipywidgets/widgets/interaction.py:257 in update           

   254 │   │   │   │   for widget in self.kwargs_widgets:                                         
   255 │   │   │   │   │   value = widget.get_interact_value()                                    
   256 │   │   │   │   │   self.kwargs[widget._kwarg] = value                                     
 ❱ 257 │   │   │   │   self.result = self.f(**self.kwargs)                                        
   258 │   │   │   │   show_inline_matplotlib_plots()                                             
   259 │   │   │   │   if self.auto_display and self.result is not None:                          
   260 │   │   │   │   │   display(self.result)                                                   
 in make_plot:16                                                                                  

 /usr/local/lib/python3.9/dist-packages/torch/amp/autocast_mode.py:14 in decorate_autocast        

    11 │   @functools.wraps(func)                                                                 
    12 │   def decorate_autocast(*args, **kwargs):                                                
    13 │   │   with autocast_instance:                                                            
 ❱  14 │   │   │   return func(*args, **kwargs)                                                   
    15 │   decorate_autocast.__script_unsupported = '@autocast() decorator is not supported in    
    16 │   return decorate_autocast                                                               
    17                                                                                            

 /usr/local/lib/python3.9/dist-packages/torch/autograd/grad_mode.py:27 in decorate_context        

    24 │   │   @functools.wraps(func)                                                             
    25 │   │   def decorate_context(*args, **kwargs):                                             
    26 │   │   │   with self.clone():                                                             
 ❱  27 │   │   │   │   return func(*args, **kwargs)                                               
    28 │   │   return cast(F, decorate_context)                                                   
    29 │                                                                                          
    30 │   def _wrap_generator(self, func):                                                       

 /usr/local/lib/python3.9/dist-packages/tuned_lens/plotting/plot_lens.py:120 in plot_lens         

   117 │   │                                                                                      
   118 │   │   return logits.log_softmax(dim=-1)                                                  
   119 │                                                                                          
 ❱ 120 │   hidden_lps = stream.zip_map(                                                           
   121 │   │   decode_tl,                                                                         
   122 │   │   range(len(stream) - 1),                                                            
   123 │   )                                                                                      

 /usr/local/lib/python3.9/dist-packages/tuned_lens/residual_stream.py:105 in zip_map              

   102 │                                                                                          
   103 │   def zip_map(self, fn: Callable, *others: "Iterable") -> "ResidualStream":              
   104 │   │   """Map over corresponding states, returning a new `ResidualStream`."""             
 ❱ 105 │   │   return self.new_from_list(list(starmap(fn, zip(self, *others))))                   
   106 │                                                                                          
   107 │   def new_from_list(self, states: list[th.Tensor]) -> "ResidualStream":                  
   108 │   │   """Create a new `ResidualStream` with the given states."""                         

 /usr/local/lib/python3.9/dist-packages/tuned_lens/plotting/plot_lens.py:114 in decode_tl         

   111 │   tokens = tokens[start_pos:end_pos]                                                     
   112 │                                                                                          
   113 │   def decode_tl(h, i):                                                                   
 ❱ 114 │   │   logits = lens.forward(h, i)                                                        
   115 │   │   if mask_input:                                                                     
   116 │   │   │   logits[..., input_ids] = -th.finfo(h.dtype).max                                
   117                                                                                            

 /usr/local/lib/python3.9/dist-packages/tuned_lens/nn/lenses.py:320 in forward                    

   317 │   │   │   h_ = self.layer_norm(h)                                                        
   318 │   │   │   return self[idx](h_)                                                           
   319 │   │                                                                                      
 ❱ 320 │   │   h = self.transform_hidden(h, idx)                                                  
   321 │   │   return self.to_logits(h)                                                           
   322 │                                                                                          
   323 │   def __len__(self) -> int:                                                              

 /usr/local/lib/python3.9/dist-packages/tuned_lens/nn/lenses.py:299 in transform_hidden           

   296 │   │   # Note that we add the translator output residually, in contrast to the formula    
   297 │   │   # in the paper. By parametrizing it this way we ensure that weight decay           
   298 │   │   # regularizes the transform toward the identity, not the zero transformation.      
 ❱ 299 │   │   return h + self[idx](h)                                                            
   300 │                                                                                          
   301 │   def to_logits(self, h: th.Tensor) -> th.Tensor:                                        
   302 │   │   """Decode a hidden state into logits."""                                           

 /usr/local/lib/python3.9/dist-packages/torch/nn/modules/module.py:1194 in _call_impl             

   1191 │   │   # this function, and just call forward.                                           
   1192 │   │   if not (self._backward_hooks or self._forward_hooks or self._forward_pre_hooks o  
   1193 │   │   │   │   or _global_forward_hooks or _global_forward_pre_hooks):                   
 ❱ 1194 │   │   │   return forward_call(*input, **kwargs)                                         
   1195 │   │   # Do not call functions when jit is used                                          
   1196 │   │   full_backward_hooks, non_full_backward_hooks = [], []                             
   1197 │   │   if self._backward_hooks or _global_backward_hooks:                                

 /usr/local/lib/python3.9/dist-packages/torch/nn/modules/linear.py:114 in forward                 

   111 │   │   │   init.uniform_(self.bias, -bound, bound)                                        
   112 │                                                                                          
   113 │   def forward(self, input: Tensor) -> Tensor:                                            
 ❱ 114 │   │   return F.linear(input, self.weight, self.bias)                                     
   115 │                                                                                          
   116 │   def extra_repr(self) -> str:                                                           
   117 │   │   return 'in_features={}, out_features={}, bias={}'.format(                          

RuntimeError: Expected all tensors to be on the same device, but found at least two devices, cpu and cuda:0! (when checking argument for argument mat1 in method wrapper_addmm)
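Reading the traceback bottom-up, the failure happens in F.linear inside transform_hidden, which is consistent with the hidden state h being on cuda:0 while the lens weights are still on the CPU. The following is a minimal plain-PyTorch sketch of a pre-flight check that surfaces such a mismatch with a clearer message before the layer is called. The helper names module_devices and check_same_device are hypothetical and not part of tuned-lens, and the "meta" device stands in for the GPU so the check can be demonstrated on any machine:

```python
import torch
from torch import nn


def module_devices(module: nn.Module) -> set[torch.device]:
    """Collect every device that holds one of the module's parameters or buffers."""
    tensors = list(module.parameters()) + list(module.buffers())
    return {t.device for t in tensors}


def check_same_device(module: nn.Module, x: torch.Tensor) -> None:
    """Raise early, with a readable message, if `x` and `module` live on different devices."""
    devices = module_devices(module)
    if devices and devices != {x.device}:
        names = sorted(str(d) for d in devices)
        raise RuntimeError(f"input is on {x.device}, but the module is on {names}")


# A freshly built layer lives on the CPU; feed it a tensor on another device.
# ("meta" is used here only so the mismatch can be shown without a GPU.)
layer = nn.Linear(4, 4)
try:
    check_same_device(layer, torch.empty(2, 4, device="meta"))
except RuntimeError as err:
    print(err)
```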
levmckinney commented 1 year ago

Thank you for reporting this! I believe the problem is that the tuned lens is not being moved to the GPU along with the model. It should be resolved by adding tuned_lens = tuned_lens.to(device) to the second cell. I've opened pull request #54 to fix this for future users.
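A minimal sketch of why the one-line fix works, written in plain PyTorch rather than the tuned-lens API: a bare nn.Linear stands in (hypothetically) for one translator layer of the lens, and a random tensor stands in for a hidden state from a model that is already on the target device. Newly constructed modules start on the CPU, so the layer has to be moved with .to(device) before it can consume hidden states that live on the GPU. The code falls back to the CPU when no GPU is available:

```python
import torch
from torch import nn

# Use the GPU when one is present; otherwise everything stays on the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Hypothetical stand-in for one translator layer of the tuned lens.
# Newly constructed modules always start out on the CPU.
translator = nn.Linear(16, 16)

# Stand-in for a hidden state from a model that was already moved to `device`.
h = torch.randn(2, 5, 16, device=device)

# The fix: move the lens weights to the same device as the hidden states.
# On a GPU, skipping this line makes translator(h) raise
# "Expected all tensors to be on the same device ...".
translator = translator.to(device)

# Residual update, mirroring `h + self[idx](h)` in transform_hidden.
out = h + translator(h)
print(out.shape, out.device)
```

For modules, nn.Module.to moves the parameters in place and returns the module itself, so the reassignment is harmless but optional; for plain tensors, .to returns a new tensor and the result must be kept.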