jcrist / msgspec

A fast serialization and validation library, with builtin support for JSON, MessagePack, YAML, and TOML
https://jcristharif.com/msgspec/
BSD 3-Clause "New" or "Revised" License

Nested custom extensions #710

Closed makarr closed 4 months ago

makarr commented 4 months ago

Question

Here's an example that extends msgspec to NumPy arrays:

import io
from typing import Any

import msgspec
import numpy as np

# Application-chosen MessagePack extension type code for NumPy arrays
NP_NDARRAY_CODE = 1

class NumpyStruct(msgspec.Struct):
    arr: np.ndarray

def enc_hook(obj: Any) -> Any:
    # Called by the encoder for objects it doesn't natively support
    if isinstance(obj, np.ndarray):
        f = io.BytesIO()
        np.save(f, obj)
        data = f.getvalue()
        return msgspec.msgpack.Ext(NP_NDARRAY_CODE, data)
    else:
        raise NotImplementedError(f"Objects of type {type(obj)} are not supported")

def ext_hook(code: int, data: memoryview) -> Any:
    # Called by the decoder for each extension value in the message
    if code == NP_NDARRAY_CODE:
        return np.load(io.BytesIO(data))
    else:
        raise NotImplementedError(f"Extension type code {code} is not supported")

enc = msgspec.msgpack.Encoder(enc_hook=enc_hook)
dec = msgspec.msgpack.Decoder(NumpyStruct, ext_hook=ext_hook)

s1 = NumpyStruct(arr=np.random.rand(8))

msg = enc.encode(s1)
s2 = dec.decode(msg)

np.allclose(s1.arr, s2.arr) # True

(Incidentally I believe this may solve the problem presented in issue 655)

My problem is the recursive application of extended types. How should I handle the following?

class NumpyStructContainer(msgspec.Struct):
    numpy_structs: list[NumpyStruct]

makarr commented 4 months ago

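Never mind, I was overthinking this. It's as simple as defining a different decoder, typed with the container. The same ext_hook is called for every extension value in the message, however deeply it is nested.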

dec = msgspec.msgpack.Decoder(NumpyStructContainer, ext_hook=ext_hook)