ndif-team / nnsight

The nnsight package enables interpreting and manipulating the internals of deep learned models.
https://nnsight.net/
MIT License

LogitsProcessor not working when remote=True #241

Open LouisHernandez17 opened 1 month ago

LouisHernandez17 commented 1 month ago

I am trying to use constrained generation with a remote model. However, it only works locally so far, and I am not sure why. This might be related to #137. Here is an MWE, with the error trace:

from nnsight import LanguageModel
from outlines import models
from outlines.processors import RegexLogitsProcessor
import transformers

model = LanguageModel("meta-llama/Meta-Llama-3.1-405B")
prompts = "What is 2+2?"
logits_processor = transformers.LogitsProcessorList(
    [RegexLogitsProcessor("4|5", tokenizer=models.TransformerTokenizer(model.tokenizer))])

# %%
with model.generate(prompts, remote=True, logits_processor=logits_processor) as tracer:
    out = model.generator.output.save()

Error trace:

{
    "name": "ValidationError",
    "message": "32 validation errors for RequestModel
object.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.`tagged-union[...,IteratorModel,SessionModel,Reference,SliceModel,TensorModel,PrimitiveModel,TupleModel,ListModel,DictModel,EllipsisModel]`
  Unable to extract tag using discriminator 'type_name' [type=union_tag_not_found, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/union_tag_not_found
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Tracer]]
  Input should be an instance of Tracer [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Iterator]]
  Input should be an instance of Iterator [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Node]]
  Input should be an instance of Node [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[slice]]
  Input should be an instance of slice [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Tensor]]
  Input should be an instance of Tensor [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].int
  Input should be a valid integer [type=int_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/int_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].float
  Input should be a valid number [type=float_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/float_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].str
  Input should be a valid string [type=string_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].bool
  Input should be a valid boolean [type=bool_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/bool_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.`function-after[<lambda>(), tuple[any, ...]]`
  Input should be a valid tuple [type=tuple_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/tuple_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.`tagged-union[...,...,...,Reference,...,TensorModel,PrimitiveModel,...,...,DictModel,EllipsisModel]`
  Unable to extract tag using discriminator 'type_name' [type=union_tag_not_found, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/union_tag_not_found
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Tracer]]
  Input should be an instance of Tracer [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Iterator]]
  Input should be an instance of Iterator [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Node]]
  Input should be an instance of Node [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[slice]]
  Input should be an instance of slice [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Tensor]]
  Input should be an instance of Tensor [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].int
  Input should be a valid integer [type=int_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/int_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].float
  Input should be a valid number [type=float_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/float_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].str
  Input should be a valid string [type=string_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].bool
  Input should be a valid boolean [type=bool_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/bool_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.`function-after[<lambda>(), tuple[any, ...]]`
  Input should be a valid tuple [type=tuple_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/tuple_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), list[any]]
  Input should be a valid list [type=list_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/list_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), dict[any,any]]
  Input should be a valid dictionary [type=dict_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/dict_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[ellipsis]]
  Input should be an instance of ellipsis [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), dict[any,any]]
  Input should be a valid dictionary [type=dict_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/dict_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[ellipsis]]
  Input should be an instance of ellipsis [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.SessionModel
  Input should be a valid dictionary or instance of SessionModel [type=model_type, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/model_type
object.TracerModel
  Input should be a valid dictionary or instance of TracerModel [type=model_type, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/model_type",
    "stack": "---------------------------------------------------------------------------
ValidationError                           Traceback (most recent call last)
File /home/louis/Documents/outlines-nnsight/quicktest.py:2
      1 # %%
----> 2 with model.generate(prompts, remote=True, logits_processor=logits_processor) as tracer:
      3     out = model.generator.output.save()
      4 print(out)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/nnsight/contexts/Tracer.py:102, in Tracer.__exit__(self, exc_type, exc_val, exc_tb)
     97     self.invoker.__exit__(None, None, None)
     99 self.model._envoy._reset()
--> 102 super().__exit__(exc_type, exc_val, exc_tb)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/nnsight/contexts/GraphBasedContext.py:217, in GraphBasedContext.__exit__(self, exc_type, exc_val, exc_tb)
    214     self.graph = None
    215     raise exc_val
--> 217 self.backend(self)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/nnsight/contexts/backends/RemoteBackend.py:104, in RemoteBackend.__call__(self, obj)
    100 self.handle_result = obj.remote_backend_handle_result_value
    102 if self.blocking:
--> 104     request = self.request(obj)
    106     # Do blocking request.
    107     self.blocking_request(request)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/nnsight/contexts/backends/RemoteBackend.py:96, in RemoteBackend.request(self, obj)
     93 from ...schema.Request import RequestModel
     95 # Create request using pydantic to parse the object itself.
---> 96 return RequestModel(object=obj, model_key=model_key)

File ~/miniconda3/envs/outlines-dev/lib/python3.10/site-packages/pydantic/main.py:209, in BaseModel.__init__(self, **data)
    207 # `__tracebackhide__` tells pytest and some other tools to omit this function from tracebacks
    208 __tracebackhide__ = True
--> 209 validated_self = self.__pydantic_validator__.validate_python(data, self_instance=self)
    210 if self is not validated_self:
    211     warnings.warn(
    212         'A custom validator is returning a value other than `self`.\n'
    213         "Returning anything other than `self` from a top level model validator isn't supported when validating via `__init__`.\n"
    214         'See the `model_validator` docs (https://docs.pydantic.dev/latest/concepts/validators/#model-validators) for more details.',
    215         category=None,
    216     )

ValidationError: 32 validation errors for RequestModel
object.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.`tagged-union[...,IteratorModel,SessionModel,Reference,SliceModel,TensorModel,PrimitiveModel,TupleModel,ListModel,DictModel,EllipsisModel]`
  Unable to extract tag using discriminator 'type_name' [type=union_tag_not_found, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/union_tag_not_found
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Tracer]]
  Input should be an instance of Tracer [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Iterator]]
  Input should be an instance of Iterator [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Node]]
  Input should be an instance of Node [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[slice]]
  Input should be an instance of slice [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[Tensor]]
  Input should be an instance of Tensor [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].int
  Input should be a valid integer [type=int_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/int_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].float
  Input should be a valid number [type=float_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/float_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].str
  Input should be a valid string [type=string_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), nullable[union[int,float,str,bool]]].bool
  Input should be a valid boolean [type=bool_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/bool_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.`function-after[<lambda>(), tuple[any, ...]]`
  Input should be a valid tuple [type=tuple_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/tuple_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.`tagged-union[...,...,...,Reference,...,TensorModel,PrimitiveModel,...,...,DictModel,EllipsisModel]`
  Unable to extract tag using discriminator 'type_name' [type=union_tag_not_found, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/union_tag_not_found
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Tracer]]
  Input should be an instance of Tracer [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Iterator]]
  Input should be an instance of Iterator [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Session]]
  Input should be an instance of Session [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Node]]
  Input should be an instance of Node [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[slice]]
  Input should be an instance of slice [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[Tensor]]
  Input should be an instance of Tensor [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].int
  Input should be a valid integer [type=int_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/int_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].float
  Input should be a valid number [type=float_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/float_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].str
  Input should be a valid string [type=string_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/string_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), nullable[union[int,float,str,bool]]].bool
  Input should be a valid boolean [type=bool_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/bool_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.`function-after[<lambda>(), tuple[any, ...]]`
  Input should be a valid tuple [type=tuple_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/tuple_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), list[any]]
  Input should be a valid list [type=list_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/list_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), dict[any,any]]
  Input should be a valid dictionary [type=dict_type, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/dict_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), list[any]].values.0.function-after[<lambda>(), is-instance[ellipsis]]
  Input should be an instance of ellipsis [type=is_instance_of, input_value=<outlines.processors.stru...bject at 0x703f8c5da830>, input_type=RegexLogitsProcessor]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), dict[any,any]]
  Input should be a valid dictionary [type=dict_type, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/dict_type
object.function-after[<lambda>(), is-instance[Tracer]].kwargs.logits_processor.function-after[<lambda>(), is-instance[ellipsis]]
  Input should be an instance of ellipsis [type=is_instance_of, input_value=[<outlines.processors.str...ject at 0x703f8c5da830>], input_type=LogitsProcessorList]
    For further information visit https://errors.pydantic.dev/2.9/v/is_instance_of
object.SessionModel
  Input should be a valid dictionary or instance of SessionModel [type=model_type, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/model_type
object.TracerModel
  Input should be a valid dictionary or instance of TracerModel [type=model_type, input_value=<Tracer at 0x703f652d5b70>, input_type=Tracer]
    For further information visit https://errors.pydantic.dev/2.9/v/model_type"
}

JadenFiotto-Kaufman commented 1 month ago

@LouisHernandez17 Hey!

So the problem here is that the intervention graph (which stores all of the interventions that are sent to the server) can only contain object types that are registered with nnsight, so that it knows how to serialize and deserialize them. LogitsProcessor is not one of them. What does RegexLogitsProcessor do? Does it just stop generation if it sees a 4 or a 5? I think you can create interventions that will do this for you :)
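To illustrate the idea of a serializer that only accepts registered types, here is a toy sketch. It is not nnsight's actual schema code: the `SERIALIZERS` table, the `serialize` helper, and the `FakeLogitsProcessor` class are all hypothetical names made up for this example.

```python
# Toy illustration only; this is NOT how nnsight implements its request schema.

def _primitive(value):
    return {"type_name": "PRIMITIVE", "value": value}

def _list(values):
    # Lists are serialized element by element, so one bad element fails the whole list.
    return {"type_name": "LIST", "values": [serialize(v) for v in values]}

# Hypothetical registry: only these exact types can be sent over the wire.
SERIALIZERS = {
    int: _primitive,
    float: _primitive,
    str: _primitive,
    bool: _primitive,
    list: _list,
}

def serialize(value):
    handler = SERIALIZERS.get(type(value))
    if handler is None:
        raise TypeError(f"cannot serialize {type(value).__name__}: type not registered")
    return handler(value)

class FakeLogitsProcessor:
    """Stand-in for an arbitrary user object such as a logits processor."""

# Registered types serialize fine.
payload = serialize([1, "a", 2.5])

# An unregistered object anywhere in the arguments is rejected.
try:
    serialize([FakeLogitsProcessor()])
    error = None
except TypeError as exc:
    error = str(exc)
```

In this toy version the failure is a single `TypeError`; the long pydantic trace above is the same kind of rejection, reported once per type the validator tried.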

LouisHernandez17 commented 1 month ago

Hi! Thank you for your answer. Actually, it is a bit more complicated than that: it builds a finite-state machine (FSM) that determines, for each state of the regex, the next legal moves. In this case the FSM is very simple: there are just two legal moves from the initial state to a final state, generating either a 4 or a 5. For more complex regexes, though, it could be much more involved.
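As a rough sketch of that FSM for the regex `4|5`, assuming for simplicity that every token is a single character (a real tokenizer has multi-character tokens, and the table below is written by hand rather than compiled from the regex). The names `TRANSITIONS`, `legal_tokens`, and `step` are made up for this example and are not the Outlines API.

```python
# Hand-written FSM for the regex "4|5".
# State 0 is the start state; state 1 is the only accepting state.
TRANSITIONS = {
    0: {"4": 1, "5": 1},
    1: {},  # nothing may follow once the match is complete
}
ACCEPTING = {1}

def legal_tokens(state):
    """Return the tokens that keep the output consistent with the regex."""
    return sorted(TRANSITIONS[state])

def step(state, token):
    """Advance the FSM by one generated token, rejecting illegal moves."""
    if token not in TRANSITIONS[state]:
        raise ValueError(f"token {token!r} is not legal in state {state}")
    return TRANSITIONS[state][token]

state = 0
first_moves = legal_tokens(state)  # the two legal moves from the initial state
state = step(state, "4")
finished = state in ACCEPTING
```

For a harder regex the table simply has more states and transitions; the per-step question stays the same: which tokens are legal from the current state?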

At each sampling step, it sets the logits of illegal tokens to $-\infty$, so that, after the softmax, only legal tokens have non-zero probability.
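A minimal sketch of that masking step in plain Python, with made-up logits and token ids. `mask_logits` and `softmax` are helper names chosen for this example, and the sketch assumes at least one token is legal (otherwise every logit would be $-\infty$ and the softmax would be undefined).

```python
import math

def mask_logits(logits, legal_ids):
    """Set the logit of every token that is not legal to -inf."""
    legal = set(legal_ids)
    return [x if i in legal else -math.inf for i, x in enumerate(logits)]

def softmax(logits):
    """Numerically stable softmax; exp(-inf) evaluates to exactly 0.0."""
    m = max(logits)
    exps = [math.exp(x - m) for x in logits]
    total = sum(exps)
    return [e / total for e in exps]

# Made-up example: a 4-token vocabulary where only ids 1 and 2 are legal.
logits = [2.0, 0.5, 1.0, -1.0]
legal_ids = [1, 2]

probs = softmax(mask_logits(logits, legal_ids))
# probs[0] and probs[3] are exactly 0.0; probs[1] and probs[2] share all the mass.
```

Note that token 0 had the highest raw logit but ends up with zero probability, which is the whole point of the constraint.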

A relevant library I use is Outlines, which is based on this paper: https://arxiv.org/pdf/2307.09702

Although I think it would be interesting to support such post-processing, I suppose that for simpler regexes I could find a workaround, such as applying a simple mask by hand.