tokio-rs / prost

PROST! a Protocol Buffers implementation for the Rust Language
Apache License 2.0
3.66k stars 477 forks source link

Allow encode BufMut arg to be unsized #1080

Open dspyz-matician opened 1 month ago

dspyz-matician commented 1 month ago

This makes it possible to pass a `&mut dyn BufMut`.

This allows creating a custom "zero-copy" implementation of `Message` for `bytes` fields whose contents are themselves encoded protobuf values, as shown below (useful when modeling generics with protobuf):

/// Writes a value's protobuf encoding directly into a caller-supplied buffer.
///
/// `Send + Sync` are required supertraits: `MessageValue` stores an
/// `Arc<dyn MySerializer>`, and `prost::Message` requires its implementors to be
/// `Send + Sync`. Without these bounds `MessageValue` cannot implement `Message`.
trait MySerializer: Send + Sync {
    /// Appends this value's encoded bytes to `b`.
    ///
    /// The number of bytes written must equal [`MySerializer::encoded_len`]:
    /// `MessageValue` writes a length prefix computed from `encoded_len()`
    /// before calling this method.
    fn serialize_to(&self, b: &mut dyn BufMut);

    /// Returns the exact number of bytes `serialize_to` will write.
    fn encoded_len(&self) -> usize;
}

/// Hand-written counterpart of the protobuf message:
///
/// ```proto
/// message MessageValue {
///   bytes serialized_value = 1;
/// }
/// ```
///
/// The `bytes` payload is either produced on the fly from a serializer
/// (encoding side) or kept as the raw received bytes (decoding side).
///
/// `Default` (= `Empty`) is derived because `prost::Message::decode` and
/// generated parent messages construct values via `Default::default()`.
/// `Empty` is also what `Message::clear` resets to.
#[derive_where(Debug)]
#[derive(Default)]
pub enum MessageValue {
    /// Field not set; encodes to zero bytes.
    #[default]
    Empty,
    /// Payload is written by the serializer at encode time.
    Encoding(#[derive_where(skip)] Arc<dyn MySerializer>),
    /// Payload as read from the wire by `merge_field`.
    Decoding(#[derive_where(skip)] Bytes),
}

/// Wraps a serializer so that its output is written straight into the
/// encode buffer when the resulting `MessageValue` is encoded.
impl From<Arc<dyn MySerializer>> for MessageValue {
    fn from(serializer: Arc<dyn MySerializer>) -> Self {
        Self::Encoding(serializer)
    }
}

impl Message for MessageValue {
    /// Writes field 1 (`serialized_value`) as a length-delimited record.
    /// `Empty` writes nothing.
    ///
    /// NOTE(review): the signature assumes the trait form proposed here
    /// (`B: BufMut + ?Sized`); confirm it against the prost version in use.
    fn encode_raw<B>(&self, mut buf: &mut B)
    where
        B: bytes::BufMut + ?Sized,
        Self: Sized,
    {
        match self {
            MessageValue::Empty => (),
            MessageValue::Encoding(val) => {
                // `&mut buf` is a sized `&mut &mut B`, and `&mut B: BufMut` even
                // when `B` is unsized. So it can go to the helper and be coerced
                // to `&mut dyn BufMut`.
                encode_metadata(val.encoded_len(), &mut buf);
                val.serialize_to(&mut buf);
            }
            MessageValue::Decoding(val) => {
                encode_metadata(val.len(), &mut buf);
                buf.put_slice(val);
            }
        }
    }

    /// Decodes one field. Tag 1 must be length-delimited and replaces `self`
    /// with its raw bytes. Any other tag is skipped.
    ///
    /// # Errors
    /// Returns a `DecodeError` when tag 1 has the wrong wire type, when the
    /// length prefix is malformed or exceeds the remaining buffer, or when
    /// skipping an unknown field fails.
    fn merge_field<B>(
        &mut self,
        tag: u32,
        wire_type: WireType,
        buf: &mut B,
        ctx: DecodeContext,
    ) -> Result<(), prost::DecodeError>
    where
        B: bytes::Buf,
        Self: Sized,
    {
        if tag != 1 {
            // Unknown fields must still be consumed; otherwise their payload
            // would be parsed as the next key and corrupt the rest of the decode.
            return prost::encoding::skip_field(wire_type, tag, buf, ctx);
        }
        if wire_type != WireType::LengthDelimited {
            return Err(prost::DecodeError::new(format!(
                "invalid wire type; expected length delimited, got {wire_type:?}"
            )));
        }
        let len = prost::encoding::decode_varint(buf)?;
        // `copy_to_bytes` panics on short input, so a truncated or malicious
        // length prefix is reported as a decode error instead.
        if len > buf.remaining() as u64 {
            return Err(prost::DecodeError::new("buffer underflow"));
        }
        *self = MessageValue::Decoding(buf.copy_to_bytes(len as usize));
        Ok(())
    }

    /// Wire size of what `encode_raw` writes: 0 for `Empty`, otherwise
    /// key + length prefix + payload.
    fn encoded_len(&self) -> usize {
        match self {
            MessageValue::Empty => 0,
            MessageValue::Encoding(val) => message_len_from_byte_len(val.encoded_len()),
            MessageValue::Decoding(val) => message_len_from_byte_len(val.len()),
        }
    }

    /// Resets to the unset state.
    fn clear(&mut self) {
        *self = Self::Empty;
    }
}

/// Writes the key for field 1 (wire type length-delimited) followed by the
/// varint length prefix `encoded_len`. The payload itself is not written.
///
/// Accepts unsized buffers (e.g. `dyn BufMut`) so it can be called from an
/// `encode_raw` whose buffer type is `?Sized`. Sized callers are unaffected.
fn encode_metadata(encoded_len: usize, mut buf: &mut (impl bytes::BufMut + ?Sized)) {
    // `&mut T: BufMut` for any `T: BufMut + ?Sized`, so passing `&mut buf`
    // gives the prost helpers a sized buffer whatever bounds they declare.
    prost::encoding::encode_key(1, prost::encoding::WireType::LengthDelimited, &mut buf);
    prost::encoding::encode_varint(encoded_len as u64, &mut buf);
}

fn message_len_from_byte_len(encoded_len: usize) -> usize {
    1 + prost::encoding::encoded_len_varint(encoded_len as u64) + encoded_len
}

This change won't break compilation for code generated by prost-build, but hand-written `Message` implementations will have to add the `+ ?Sized` bound to their buffer parameter, so it is not fully backward-compatible.