ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License
399 stars 37 forks source link

ellipses does not work in the FakeTensor system #174

Closed loftusa closed 3 months ago

loftusa commented 3 months ago
from nnsight import LanguageModel
key = "meta-llama/Meta-Llama-3-8B"
lm = LanguageModel(key)
mlp = lm.model.layers[-1].mlp.down_proj
out_layer = lm.lm_head
neurons = [3260, 7737, 8894]
with lm.trace('The truth is the', remote=True):
  print(mlp.input[0][0][..., neurons].shape)

raises

ValidationError: 33 validation errors for RequestModel
intervention_graph.dict[str,union[function-after[<lambda>(), is-instance[Node]],NodeModel]].setitem_0.function-after[<lambda>(), is-instance[Node]].args.1.tagged-union[Reference,SliceModel,TensorModel,PrimitiveModel,ListModel,TupleModel,DictModel]
  Input should be a valid dictionary or object to extract fields from [type=model_attributes_type, input_value=(Ellipsis, [3260, 7737, 8894]), input_type=tuple]
    For further information visit https://errors.pydantic.dev/2.8/v/model_attributes_type

Full traceback (doesn't include a bunch of repetitions of the intervention_graph.dict line):

---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
File /share/u/lofty/abliteration.py:9
      [7](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/abliteration.py:7) out_layer = lm.lm_head
      [8](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/abliteration.py:8) neurons = [3260, 7737, 8894]
----> [9](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/abliteration.py:9) with lm.trace('The truth is the', remote=True):
     [10](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/abliteration.py:10)   mlp.input[0][0][..., neurons] = 10
     [11](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/abliteration.py:11)   logits = out_layer.output.save()

File ~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:44, in Runner.__exit__(self, exc_type, exc_val, exc_tb)
     [41](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:41)     raise exc_val
     [43](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:43) if self.remote:
---> [44](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:44)     self.run_server()
     [46](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:46)     self._graph.tracing = False
     [47](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:47)     self._graph = None

File ~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:53, in Runner.run_server(self)
     [51](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:51) def run_server(self):
     [52](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:52)     # Create the pydantic object for the request.
---> [53](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:53)     request = pydantics.RequestModel(
     [54](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:54)         kwargs=self._kwargs,
     [55](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:55)         repo_id=self._model._model_key,
     [56](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:56)         batched_input=self._batched_input,
     [57](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:57)         intervention_graph=self._graph.nodes,
     [58](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:58)     )
     [60](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:60)     if self.blocking:
     [61](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/nnsight/contexts/Runner.py:61)         self.blocking_request(request)

File ~/miniconda3/lib/python3.12/site-packages/pydantic/main.py:193, in BaseModel.__init__(self, **data)
    [191](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/pydantic/main.py:191) # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
    [192](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/pydantic/main.py:192) __tracebackhide__ = True
--> [193](https://vscode-remote+ssh-002dremote-002brtx.vscode-resource.vscode-cdn.net/share/u/lofty/~/miniconda3/lib/python3.12/site-packages/pydantic/main.py:193) self.__pydantic_validator__.validate_python(data, self_instance=self)

ValidationError: 33 validation errors for RequestModel
intervention_graph.dict[str,union[function-after[<lambda>(), is-instance[Node]],NodeModel]].setitem_0.function-after[<lambda>(), is-instance[Node]].args.1.tagged-union[Reference,SliceModel,TensorModel,PrimitiveModel,ListModel,TupleModel,DictModel]
  Input should be a valid dictionary or object to extract fields from [type=model_attributes_type, input_value=(Ellipsis, [3260, 7737, 8894]), input_type=tuple]
    For further information visit https://errors.pydantic.dev/2.8/v/model_attributes_type
JadenFiotto-Kaufman commented 3 months ago

@loftusa Yeah, this is fixed on dev and will be in the next release. We just need to explicitly handle every datatype that we want to be able to send remotely, such as Ellipsis.