manzt / anywidget

jupyter widgets made easy
https://anywidget.dev
MIT License

pydantic model serialization disallows binary data (experimental API) #310

Open kylebarron opened 10 months ago

kylebarron commented 10 months ago

The current implementation of _get_pydantic_state_v2 is:

https://github.com/manzt/anywidget/blob/620d380598467cae798c2795c4a7c52d1c3c8243/anywidget/_descriptor.py#L653-L655

When a model has a field that serializes to binary data, mode="json" works only when the bytes happen to be valid UTF-8. For example:

from pydantic import BaseModel

class Model(BaseModel):
    a: bytes

Model(a=b'abc').model_dump(mode='json') 
# {'a': 'abc'}

Model(a=b'\xa1').model_dump(mode='json')
# UnicodeDecodeError: 'utf-8' codec can't decode byte 0xa1 in position 0: invalid utf-8

But because anywidget uses remove_buffers to pull binary values out of the state, setting mode="json" isn't necessary for bytes fields:

state = Model(a=b'\xa1').model_dump()

from anywidget._util import remove_buffers
state2, buffer_paths, buffers = remove_buffers(state)
# {}, [['a']], [b'\xa1']
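To make the buffer-splitting step concrete, here is a minimal standard-library sketch of the idea. The name split_buffers is hypothetical and this is not anywidget's actual remove_buffers implementation; it only assumes the state is a (possibly nested) dict and ignores lists:

```python
from typing import Any


def split_buffers(
    state: dict[str, Any],
) -> tuple[dict[str, Any], list[list[str]], list[bytes]]:
    """Hypothetical, simplified stand-in for a buffer-splitting step.

    Returns the state without binary values, the key paths where the
    binary values were found, and the binary values themselves.
    """
    json_state: dict[str, Any] = {}
    buffer_paths: list[list[str]] = []
    buffers: list[bytes] = []

    def walk(src: dict[str, Any], dst: dict[str, Any], path: list[str]) -> None:
        for key, value in src.items():
            if isinstance(value, (bytes, bytearray, memoryview)):
                # Binary values leave the JSON state and travel separately.
                buffer_paths.append(path + [key])
                buffers.append(bytes(value))
            elif isinstance(value, dict):
                dst[key] = {}
                walk(value, dst[key], path + [key])
            else:
                # Everything else stays put and must be JSON-serializable later.
                dst[key] = value

    walk(state, json_state, [])
    return json_state, buffer_paths, buffers


state2, buffer_paths, buffers = split_buffers(
    {"a": b"\xa1", "n": 1, "nested": {"b": b"\x00"}}
)
print(state2)        # {'n': 1, 'nested': {}}
print(buffer_paths)  # [['a'], ['nested', 'b']]
print(buffers)       # [b'\xa1', b'\x00']
```

Only the bytes values are lifted out; whatever remains in the first return value still has to survive JSON encoding.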

Though removing mode="json" might have regressions for other data types, like dates?

kylebarron commented 10 months ago

Though removing mode="json" might have regressions for other data types, like dates?

I'm assuming this will end up being a problem...

It looks like I got this to work by overriding model_dump to serialize the entire model except a specific key in JSON mode, then serializing only that key in Python mode, and then merging the two. E.g.:

from typing import Any

from pydantic import BaseModel

class Model(BaseModel):
    a: bytes

    def model_dump(self, *args, **kwargs) -> dict[str, Any]:
        # Always exclude 'a' from the JSON-mode dump.
        # Assumes any caller-supplied `exclude` is a flat collection of field names.
        kwargs['exclude'] = set(kwargs.get('exclude') or ()) | {'a'}
        kwargs['mode'] = 'json'

        json_model_dump = super().model_dump(*args, **kwargs)
        # Dump 'a' in the default Python mode so it stays as raw bytes.
        python_model_dump = super().model_dump(include={'a'})
        return {**json_model_dump, **python_model_dump}

Model(a=b'abc').model_dump(mode='json')
# {'a': b'abc'}

Model(a=b'\xa1').model_dump(mode='json')
# {'a': b'\xa1'}

So now it's always binary!

manzt commented 10 months ago

Hmm, yeah, this is because binary data isn't serializable to JSON (without base64-encoding it). _get_pydantic_state_v2 is just a helper function for pydantic. Maybe in this case the workaround is to just implement your own state getter function:

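from pydantic import BaseModel

class Model(BaseModel):
    a: bytes

    def _get_anywidget_state(self):
        # Dump everything except 'a' as JSON-compatible values, then add the raw bytes back.
        state = self.model_dump(exclude={"a"}, mode="json")
        state["a"] = self.a
        return state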

This will take precedence over the inferred _get_pydantic_state_v2 helper (and won't force you to override the model_dump behavior). Here's the logic for how we determine the state getter:

https://github.com/manzt/anywidget/blob/62ed7e470b2294bc2338e051594865ff7e89579e/anywidget/_descriptor.py#L457-L502

manzt commented 10 months ago

Though removing mode="json" might have regressions for other data types, like dates?

Hmm, I'm not too worried about regressions for other data types. All of this stuff is behind experimental, and I'm starting to come around to the idea of just using model.model_dump(). Currently, with mode="json", serializing bytes fields (an important use case) requires either overriding model_dump or implementing _get_anywidget_state, as in the workarounds above.

Whereas if we switched to model_dump(), supporting a datetime could look like:

from datetime import datetime

from pydantic import BaseModel, field_serializer

class Model(BaseModel):
    a: bytes
    dt: datetime

    @field_serializer("dt")
    def serialize_dt(self, dt: datetime, _info):
        return dt.isoformat()

If someone really wants to keep the behavior of anywidget and pydantic separate, then they can always implement _get_anywidget_state. I just want _get_pydantic_state_v2 to provide the most sensible default.

kylebarron commented 10 months ago

It seems like you either need to manually handle bytes if setting model_dump(mode="json") or you'd need to manually handle dates if setting model_dump(), so not sure which is easiest 🤷‍♂️

manzt commented 10 months ago

Since bytes is a Python builtin, I think my preference would be to use model_dump() and make users manually handle higher-level objects (just like how one would need to write a serializer for their own class).

kylebarron commented 10 months ago

To play devil's advocate, datetime.datetime is a python builtin too? It's hard because it seems the Jupyter comm requires almost JSON, in that everything but bytes needs to be json-serializable?
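As a rough illustration of what "almost JSON" could mean here, the sketch below converts a state dict so that everything except bytes becomes a JSON-compatible value. The helper to_almost_json is hypothetical (it is not part of anywidget, pydantic, or ipywidgets), and the choice of ISO 8601 strings for dates is just one assumption among several reasonable encodings:

```python
from datetime import date, datetime
from typing import Any


def to_almost_json(value: Any) -> Any:
    """Hypothetical helper: make a value JSON-compatible, but keep bytes as-is."""
    if isinstance(value, (bytes, bytearray, memoryview)):
        # Binary data is left untouched so it can be sent as a separate buffer.
        return bytes(value)
    if isinstance(value, (datetime, date)):
        # Assumption: dates are encoded as ISO 8601 strings.
        return value.isoformat()
    if isinstance(value, dict):
        return {str(k): to_almost_json(v) for k, v in value.items()}
    if isinstance(value, (list, tuple, set)):
        return [to_almost_json(v) for v in value]
    if value is None or isinstance(value, (str, int, float, bool)):
        return value
    raise TypeError(f"Don't know how to serialize {type(value).__name__}")


state = to_almost_json({"a": b"\xa1", "dt": datetime(2024, 1, 2, 3, 4, 5)})
print(state)  # {'a': b'\xa1', 'dt': '2024-01-02T03:04:05'}
```

Unknown types raise instead of being converted silently, which keeps the conversion explicit at the cost of more work for the user.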

manzt commented 10 months ago

Sorry, I was referring to builtins, not modules in the standard library. In particular, I was inspired by msgspec.to_builtins.

import builtins

builtins.datetime
# AttributeError: module 'builtins' has no attribute 'datetime'

But I see your point; I probably need to think more about it. Less magic is better IMO with regard to serialization, and I could imagine that there might be several different ways to serialize a datetime compared to bytes. If we default to model_dump (without mode="json"), users are likely to get a Python error when ipywidgets tries to serialize to JSON (rather than some implicit conversion to a JSON-encoded value they aren't aware of). Either way, I think better documentation of these edge cases is the best strategy.
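For reference, this is roughly what the "explicit error" behavior looks like with the standard library's json module. It is only a sketch: it assumes the comm layer behaves like plain json.dumps, which may not hold for the encoder ipywidgets actually uses, and encode_datetime is a hypothetical hook written for this example:

```python
import json
from datetime import datetime

state = {"dt": datetime(2024, 1, 2, 3, 4, 5)}

# Without any conversion, the standard encoder refuses the datetime outright.
try:
    json.dumps(state)
except TypeError as exc:
    error_message = str(exc)
    print(error_message)  # Object of type datetime is not JSON serializable


def encode_datetime(obj):
    """Hypothetical opt-in hook: the user decides how a datetime is encoded."""
    if isinstance(obj, datetime):
        return obj.isoformat()
    raise TypeError(f"Cannot serialize {type(obj).__name__}")


encoded = json.dumps(state, default=encode_datetime)
print(encoded)  # {"dt": "2024-01-02T03:04:05"}
```

The error points at the offending type, and the fix is an explicit, user-chosen encoding rather than an implicit one.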