adamsitnik / SafePayloadReader

MIT License

Safe Binary Format Payload Reader

Goal

The goal of this library is to allow safe reading of Binary Format payloads from untrusted input.

The principles:

API

The BinaryFormatter.Serialize method accepts two arguments: Stream serializationStream and object graph. The first is the stream that we call the payload, and the second is the root object of the serialization graph.

The Binary Formatter payload consists of serialization records that represent the serialized objects and their metadata. To read the whole payload and get the root object, the user needs to call the static SerializationRecord Read(Stream payload, PayloadOptions? options = null, bool leaveOpen = false) method of PayloadReader. There are more than a dozen different serialization record types, but this library provides a set of abstractions, so users need to learn only a few of them:

SerializationRecord rootObject = PayloadReader.Read(payload);

if (rootObject is PrimitiveTypeRecord<string> stringRecord)
{
    Console.WriteLine($"It was a string: '{stringRecord.Value}'");
}
else if (rootObject is ClassRecord classRecord)
{
    Console.WriteLine($"It was a class record of '{classRecord.TypeName}' type.");
}
else if (rootObject is ArrayRecord<byte> arrayOfBytes)
{
    Console.WriteLine($"It was an array of bytes: '{string.Join(",", arrayOfBytes.ToArray())}'");
}

Besides Read, PayloadReader exposes a ReadClassRecord method that returns a ClassRecord (or throws). It also provides two ContainsBinaryFormatterPayload methods that check whether a given stream or buffer contains a Binary Formatter payload.

ClassRecord

The most important type that derives from SerializationRecord is ClassRecord, which represents all class and struct instances besides arrays and selected primitive types.

public class ClassRecord : SerializationRecord
{
    public TypeName TypeName { get; }
    public IEnumerable<string> MemberNames { get; }

    // Checks whether a member with the given name was present in the payload (useful for versioning scenarios)
    public bool HasMember(string memberName);

    // Retrieves the value of the provided memberName
    public string? GetString(string memberName);
    public bool GetBoolean(string memberName);
    public byte GetByte(string memberName);
    public sbyte GetSByte(string memberName);
    public short GetInt16(string memberName);
    public ushort GetUInt16(string memberName);
    public char GetChar(string memberName);
    public int GetInt32(string memberName);
    public uint GetUInt32(string memberName);
    public float GetSingle(string memberName);
    public long GetInt64(string memberName);
    public ulong GetUInt64(string memberName);
    public double GetDouble(string memberName);
    public decimal GetDecimal(string memberName);
    public TimeSpan GetTimeSpan(string memberName);
    public DateTime GetDateTime(string memberName);
    public object? GetRawValue(string memberName);

    // Retrieves an array for the provided memberName, with default max length
    public T[]? GetArrayOfPrimitiveType<T>(string memberName, int maxLength = 64000);

    // Retrieves an instance of ClassRecord that describes non-primitive type for the provided memberName
    public ClassRecord? GetClassRecord(string memberName);
    // Retrieves any other serialization record like jagged array or array of complex types
    public SerializationRecord? GetSerializationRecord(string memberName);
}

The Get$PrimitiveType methods (GetInt32, GetBoolean, and so on) read a value of the given primitive type. GetArrayOfPrimitiveType<T> reads an array of values of the given primitive type. GetClassRecord reads an instance of ClassRecord that describes a non-primitive type such as a custom class or struct.

[Serializable]
public class Sample
{
    public int Integer;
    public string? Text;
    public byte[]? ArrayOfBytes;
    public Sample? ClassInstance;
}

ClassRecord rootRecord = PayloadReader.ReadClassRecord(payload);
Sample output = new()
{
    // using the dedicated methods to read primitive values
    Integer = rootRecord.GetInt32(nameof(Sample.Integer)),
    Text = rootRecord.GetString(nameof(Sample.Text)),
    // using dedicated method to read an array of bytes
    ArrayOfBytes = rootRecord.GetArrayOfPrimitiveType<byte>(nameof(Sample.ArrayOfBytes)),
    // using GetClassRecord to read a class record
    ClassInstance = new()
    {
        Text = rootRecord
            .GetClassRecord(nameof(Sample.ClassInstance))!
            .GetString(nameof(Sample.Text))
    }  
};
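
The ClassRecord listing above describes HasMember as useful for versioning scenarios. A minimal sketch of that use, written as a fragment in the same style as the previous example and assuming that payload is the same stream: the Version member is hypothetical (Sample does not declare it), and the fallback value and the smaller maxLength are illustrative choices, not library defaults.

```csharp
// Same entry point as in the previous example; `payload` is the untrusted input Stream.
ClassRecord rootRecord = PayloadReader.ReadClassRecord(payload);

// "Version" is a hypothetical member used only for illustration.
// Check for it first so that payloads written before the field existed can still be read.
int version = rootRecord.HasMember("Version")
    ? rootRecord.GetInt32("Version")
    : 1; // assumed fallback for payloads that predate the field

// Pass an explicit maxLength instead of relying on the default of 64000.
byte[]? bytes = rootRecord.GetArrayOfPrimitiveType<byte>(
    nameof(Sample.ArrayOfBytes), maxLength: 1024);
```

Only HasMember, GetInt32, and GetArrayOfPrimitiveType<T> are used here, with the signatures shown in the ClassRecord listing above.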

TODO: describe how to work with Jagged and Rectangular arrays