dtolnay / typetag

Serde serializable and deserializable trait objects
Apache License 2.0
1.19k stars 38 forks source link

Deserializing map into a vector of types #47

Closed vikigenius closed 2 years ago

vikigenius commented 2 years ago

Reproducing the README example:

use serde::{Deserialize, Serialize};

/// An event that can be inspected.
///
/// `#[typetag::serde]` makes `Box<dyn WebEvent>` serializable and deserializable.
/// In JSON, each value is tagged with the implementing type's name, e.g. `{"Click": {...}}`.
#[typetag::serde]
trait WebEvent {
    /// Reports on the event; the implementations in this example print one line to stdout.
    fn inspect(&self);
}

/// Page-load event with no fields.
///
/// As a unit struct it appears as `null` in the tagged JSON form, e.g. `"PageLoad": null`.
#[derive(Serialize, Deserialize)]
struct PageLoad;

#[typetag::serde]
impl WebEvent for PageLoad {
    /// Prints a fixed message; a `PageLoad` carries no data of its own.
    fn inspect(&self) {
        let message = "200 milliseconds or bust";
        println!("{message}");
    }
}

/// Click event with two integer fields.
///
/// Its tagged JSON form looks like `"Click": {"x": 10, "y": 10}`.
#[derive(Serialize, Deserialize)]
struct Click {
    // Presumably the click's horizontal position — TODO confirm units/origin.
    x: i32,
    // Presumably the click's vertical position — TODO confirm units/origin.
    y: i32,
}

#[typetag::serde]
impl WebEvent for Click {
    fn inspect(&self) {
        println!("negative space between the ads: x={} y={}", self.x, self.y);
    }
}

I would like to deserialize a map of events instead (a JSON object whose keys are event type names) into a vector of events.

/// Reproduces the failure in this issue.
///
/// Deserializing a JSON object straight into `Vec<Box<dyn WebEvent>>` panics, because `Vec`
/// asks for a sequence and the input is a map.
#[test]
fn test_events_de() {
    let input = r#"{"Click": {"x": 10, "y": 10}, "PageLoad": null}"#;
    let parsed = serde_json::from_str::<Vec<Box<dyn WebEvent>>>(input);
    let _events = parsed.unwrap();
}

This panics, because deserializing a `Vec` expects a sequence (a JSON array) but the input is a map (a JSON object):

thread 'test_events_de' panicked at 'called `Result::unwrap()` on an `Err` value: Error("invalid type: map, expected a sequence", line: 1, column: 0)', tests/test_trait_vecs.rs:42:76

Is there an easy way to achieve the deserialization I want?

dtolnay commented 2 years ago

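You could do something like this. Drive the deserializer as a map, and hand each entry to `T::deserialize` as its own single-entry map, which is the externally tagged shape typetag expects: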

/// Demonstrates `vec_element_from_each_entry`.
///
/// It turns a JSON object keyed by event type into a `Vec<Box<dyn WebEvent>>` and inspects
/// each resulting event.
fn main() {
    let json_events: &str = r#"{"Click": {"x": 10, "y": 10}, "PageLoad": null}"#;
    let mut de = serde_json::Deserializer::from_str(json_events);
    let events: Vec<Box<dyn WebEvent>> =
        vec_element_from_each_entry(&mut de).expect("failed to deserialize event map");
    // Driving the Deserializer by hand skips the trailing-input check that
    // `serde_json::from_str` performs; `end` restores it.
    de.end().expect("unexpected trailing characters after event map");
    // Print through `WebEvent::inspect`, which needs only the trait itself.
    // Formatting with `{:?}` would additionally require `dyn WebEvent: Debug`.
    for event in &events {
        event.inspect();
    }
}

use serde::de::{Deserialize, DeserializeSeed, Deserializer, MapAccess, Visitor};
use std::fmt::{self, Display};
use std::marker::PhantomData;

fn vec_element_from_each_entry<'de, T, D>(deserializer: D) -> Result<Vec<T>, D::Error>
where
    T: Deserialize<'de>,
    D: Deserializer<'de>,
{
    struct VecElementFromEachEntry<T> {
        marker: PhantomData<T>,
    }

    impl<'de, T> Visitor<'de> for VecElementFromEachEntry<T>
    where
        T: Deserialize<'de>,
    {
        type Value = Vec<T>;

        fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            formatter.write_str("typetag map")
        }

        fn visit_map<M>(self, mut map: M) -> Result<Self::Value, M::Error>
        where
            M: MapAccess<'de>,
        {
            let mut vec = Vec::new();
            loop {
                match T::deserialize(serde::de::value::MapAccessDeserializer::new(
                    SingleEntryFromMapAccess {
                        map_access: &mut map,
                        done: false,
                    },
                )) {
                    Ok(entry) => vec.push(entry),
                    Err(MapError::EndOfMap) => return Ok(vec),
                    Err(MapError::Error(error)) => return Err(error),
                }
            }
        }
    }

    struct SingleEntryFromMapAccess<M> {
        map_access: M,
        done: bool,
    }

    impl<'de, M> MapAccess<'de> for SingleEntryFromMapAccess<M>
    where
        M: MapAccess<'de>,
    {
        type Error = MapError<M::Error>;

        fn next_key_seed<K>(&mut self, seed: K) -> Result<Option<K::Value>, Self::Error>
        where
            K: DeserializeSeed<'de>,
        {
            if self.done {
                Ok(None)
            } else {
                self.done = true;
                match self.map_access.next_key_seed(seed) {
                    Ok(None) => Err(MapError::EndOfMap),
                    Ok(Some(key)) => Ok(Some(key)),
                    Err(error) => Err(MapError::Error(error)),
                }
            }
        }

        fn next_value_seed<V>(&mut self, seed: V) -> Result<V::Value, Self::Error>
        where
            V: DeserializeSeed<'de>,
        {
            self.map_access
                .next_value_seed(seed)
                .map_err(MapError::Error)
        }
    }

    #[derive(Debug)]
    enum MapError<E> {
        Error(E),
        EndOfMap,
    }

    impl<E> serde::de::Error for MapError<E>
    where
        E: serde::de::Error,
    {
        fn custom<T: Display>(msg: T) -> Self {
            MapError::Error(serde::de::Error::custom(msg))
        }
    }

    impl<E> std::error::Error for MapError<E> where E: std::error::Error {}

    impl<E> Display for MapError<E>
    where
        E: Display,
    {
        fn fmt(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
            match self {
                MapError::Error(error) => Display::fmt(error, formatter),
                MapError::EndOfMap => formatter.write_str("end of map"),
            }
        }
    }

    deserializer.deserialize_map(VecElementFromEachEntry {
        marker: PhantomData,
    })
}