serde-rs / json

Strongly typed JSON library for Rust
Apache License 2.0

How to allow trailing characters #632

Closed: Timmmm closed this issue 4 years ago.

Timmmm commented 4 years ago

I have a slightly annoying use case. I have a large JSON file with a top-level key that contains the version. The key is named so that it (hopefully) occurs early in the file. I want to extract this object without reading the entire file, which can be very, very big. For example:

{
  ".version": 2,
  "data": { many gigabytes of data }
}

Then I can have code like:

let version = Version::deserialize(...);
match version {
  0 => DataV0::deserialize(...),
  1 => DataV1::deserialize(...),
}

JSON is totally the wrong format for this, since it doesn't support "sparse reads". Most formats don't seem to (notable exceptions: Amazon Ion, Protobuf). But anyway, I have JSON, and I know the key is probably near the start of the file. So I want to start parsing the file and bail out as soon as I reach the .version key.

I almost got this to work with serde_json by writing a custom visitor, something like this:

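    use serde::de::{self, IgnoredAny, MapAccess, Visitor};
    use std::fmt;

    // `Field` (an enum of known keys) and `VersionInfo` are defined elsewhere.
    struct JsonVersionVisitor;

    impl<'de> Visitor<'de> for JsonVersionVisitor {
      type Value = VersionInfo;

      fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("a map with a \".version\" key")
      }

      #[inline]
      fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
      where
        A: MapAccess<'de>,
      {
        while let Some(key) = map.next_key::<Field>()? {
          match key {
            Field::VersionInfo => { // ".version"
              let version_info = map.next_value::<VersionInfo>()?;
              // We're done!
              return Ok(version_info);
            }
            _ => {
              // Skip this value without building it.
              let _ = map.next_value::<IgnoredAny>()?;
            }
          }
        }

        Err(de::Error::missing_field(".version"))
      }
    }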

It almost works! Except that serde_json checks that you really are at the end of the map when visit_map() returns, and returns an error if you aren't ("trailing characters").

The second thing I tried was implementing DeserializeSeed. The idea was that I could store the VersionInfo somewhere and then return an error immediately. The calling code could then check whether the VersionInfo had been stored and, if so, ignore the error.

Unfortunately DeserializeSeed is super complicated, and it also seems to consume the seed, so you can pass arbitrary things into it but you can't get arbitrary things out. Or at least I couldn't figure out how to.

Is there any way to do this? Or do I have to make a custom fork of serde_json (or write my own crappy JSON parser)?

dtolnay commented 4 years ago

The best way would be not to bail out after the version, but to have a Visitor that reads ".version" and then also deserializes the subsequent "data" field in the same pass through the input file.

If you really need to bail out, you can make the DeserializeSeed hold a &mut Option<n> where it writes the version.

Timmmm commented 4 years ago

> also deserializes the subsequent "data" field in the same pass through the input file.

Ah, my example was too simple: I actually have lots of fields at the same level as the version, and I'd like to use serde_derive for them, so it isn't quite that simple.

> If you really need to bail out, you can make the DeserializeSeed hold a &mut Option where it writes the version.

Yeah, I did try this but couldn't figure out the lifetimes. :-/ Anyway, I've come up with a much hackier solution that seems to work. I'll close this.

dtolnay commented 4 years ago

That's too bad, DeserializeSeed is definitely the way to do it. If you share a compilable minimal repro of what you did with lifetimes, I can help fix it.

Timmmm commented 4 years ago

Thanks for the kind offer of help! I was going to take you up on it (I've been on lockdown holiday for a couple of weeks), so I resurrected my code, cleaned it up, and magically the lifetime issues are gone. I swear it didn't work before! Anyway, here is my solution in case anyone else comes across this problem. I also have the additional problem that the version info can be under either of two key names, versionInfo or .versionInfo, but that is easy to handle in this case.

use serde::{Deserialize, Deserializer, de::Visitor};
use std::fmt;

#[derive(Deserialize, Debug, Clone, Copy)]
pub struct VersionInfo {
  pub major: u32,
  pub minor: u32,
}

#[derive(Debug)]
pub struct VersionedFile {
  pub version_info: VersionInfo,
  // Other fields are ignored.
}

impl<'de> Deserialize<'de> for VersionedFile {

  fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
  where
    D: Deserializer<'de>,
  {
    enum Field {
      VersionInfo,
      Ignore,
    }
    struct FieldVisitor;

    impl<'de> Visitor<'de> for FieldVisitor {
      type Value = Field;

      fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("field identifier")
      }

      fn visit_str<E>(self, value: &str) -> Result<Self::Value, E>
      where
        E: serde::de::Error,
      {
        match value {
          // "versionInfo" is for backwards compatibility.
          ".versionInfo" | "versionInfo" => Ok(Field::VersionInfo),
          _ => Ok(Field::Ignore),
        }
      }
    }

    impl<'de> Deserialize<'de> for Field {
      #[inline]
      fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
      where
        D: Deserializer<'de>,
      {
        deserializer.deserialize_identifier(FieldVisitor)
      }
    }

    // If found, store the VersionInfo here, then stop parsing straight away.
    // serde_json will then report an error, which we ignore. Yes, this is ugly.
    struct VersionedFileVisitor<'a>(&'a mut Option<VersionInfo>);

    impl<'de, 'a> Visitor<'de> for VersionedFileVisitor<'a> {
      // We don't actually return anything directly so the return type is ().
      type Value = ();

      fn expecting(&self, formatter: &mut fmt::Formatter) -> fmt::Result {
        formatter.write_str("struct VersionedFile")
      }

      #[inline]
      fn visit_map<A>(self, mut map: A) -> Result<Self::Value, A::Error>
      where
        A: serde::de::MapAccess<'de>,
      {
        while let Some(key) = map.next_key()? {
          match key {
            Field::VersionInfo => {
              let version_info = map.next_value::<VersionInfo>()?;
              *self.0 = Some(version_info);
              // We're done! Stop without consuming the rest of the map.
              break;
            }
            _ => {
              let _ = map.next_value::<serde::de::IgnoredAny>()?;
            }
          }
        }

        // If we broke out early, serde_json reports an error because the map
        // is unfinished. The caller ignores it when the version was found.
        Ok(())
      }
    }

    const FIELDS: &'static [&'static str] = &[".versionInfo", "versionInfo"];

    let mut version_info: Option<VersionInfo> = None;

    let res = deserializer.deserialize_struct(
      "VersionedFile",
      FIELDS,
      VersionedFileVisitor(&mut version_info),
    );

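    match version_info {
      Some(v) => Ok(VersionedFile { version_info: v }),
      // No version found: either serde_json reported an error, or the whole
      // object was read without a version key (in which case `res` is Ok).
      None => Err(res.err().unwrap_or_else(|| serde::de::Error::missing_field("versionInfo"))),
    }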
  }
}

pub fn deserialize_partial(data: &[u8]) -> Result<VersionedFile, serde_json::Error> {
  // We can't use `serde_json::from_slice(data)` because it complains about trailing characters.
  // `from_reader` also accepts any `std::io::Read` (e.g. a `BufReader<File>`),
  // so the whole file does not have to be loaded into memory first.
  let mut deserializer = serde_json::Deserializer::from_reader(data);
  VersionedFile::deserialize(&mut deserializer)
}

#[cfg(test)]
mod tests {
  #[test]
  fn deserialize_incomplete_object() {
    // Test both key names.
    for key in ["versionInfo", ".versionInfo"].iter() {

      let data = format!(r#"
{{
  "foo": 3,
  "bar": [2, 3, 4],
  "{}": {{
    "major": 1,
    "minor": 2
  }},
  "baz": [
"#,
        key,
      );

      let version = super::deserialize_partial(data.as_bytes()).unwrap();

      assert_eq!(version.version_info.major, 1);
      assert_eq!(version.version_info.minor, 2);
    }
  }
}