dotnet / fsharp

The F# compiler, F# core library, F# language service, and F# tooling integration for Visual Studio
https://dotnet.microsoft.com/languages/fsharp
MIT License
3.91k stars 785 forks source link

Optimize struct discriminated unions when all cases are reference types #9901

Open blowin opened 4 years ago

blowin commented 4 years ago

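This optimization would reduce the size of the struct. Only one case is active at a time, so the three reference fields (one per case) can share a single `object` field. That leaves one reference plus the tag instead of three references plus the tag.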

// Placeholder reference types (classes) used as the payloads of the union cases below.
type LegalPerson(name: string) = class end
type NaturalPerson(name: string) = class end
type OrganizationDepartment(name: string) = class end

// A struct (value-type) union in which every case carries exactly one reference-typed field.
[<Struct>]
type Contractor = 
    | LegalPerson of lp: LegalPerson
    | NaturalPerson of np: NaturalPerson
    | OrganizationDepartment of od: OrganizationDepartment

Generated C# code (decompiled)

// Decompiled C# view of what the F# compiler currently emits for the [<Struct>] Contractor
// union: one backing field per case plus an int tag.
[Serializable]
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Auto, Size = 1)]
[Struct]
[DebuggerDisplay("{__DebugDisplay(),nq}")]
[CompilationMapping(SourceConstructFlags.SumType)]
public struct Contractor : IEquatable<Contractor>, IStructuralEquatable
{
    // Integer values stored in _tag / returned by Tag, one per union case.
    public static class Tags
    {
        public const int LegalPerson = 0;
        public const int NaturalPerson = 1;
        public const int OrganizationDepartment = 2;
    }

    // One backing field per case. Each constructor below assigns only its own case's field,
    // so the other two stay null, yet the struct still reserves space for all three references.
    internal LegalPerson _lp;
    internal NaturalPerson _np;
    internal OrganizationDepartment _od;
    // Index of the active case (one of the Tags values).
    internal int _tag;

    public int Tag => _tag;

    // Case tests compare Tag against the literal Tags values.
    public bool IsLegalPerson => Tag == 0;
    public bool IsNaturalPerson => Tag == 1;
    public bool IsOrganizationDepartment => Tag == 2;

    // Case field accessors do not check Tag: for an inactive case they return the
    // never-assigned field, i.e. null.
    public LegalPerson lp => _lp;
    public NaturalPerson np => _np;
    public OrganizationDepartment od => _od;

    // One factory per case; each passes its case's tag to the matching internal constructor.
    [CompilationMapping(SourceConstructFlags.UnionCase, 0)]
    public static Contractor NewLegalPerson(LegalPerson _lp) => new Contractor(_lp, 0, false);

    [CompilationMapping(SourceConstructFlags.UnionCase, 1)]
    public static Contractor NewNaturalPerson(NaturalPerson _np) => new Contractor(_np, 1, 0);

    [CompilationMapping(SourceConstructFlags.UnionCase, 2)]
    public static Contractor NewOrganizationDepartment(OrganizationDepartment _od) => new Contractor(_od, 2, 0);

    // Per-case constructors: store the payload in that case's field and record the tag.
    // P_2 is never read in these bodies; its purpose is not evident from this code.
    internal Contractor(LegalPerson _lp, int _tag, bool P_2)
    {
        this._lp = _lp;
        this._tag = _tag;
    }

    internal Contractor(NaturalPerson _np, int _tag, byte P_2)
    {
        this._np = _np;
        this._tag = _tag;
    }

    internal Contractor(OrganizationDepartment _od, int _tag, sbyte P_2)
    {
        this._od = _od;
        this._tag = _tag;
    }

    // Debugger text (see [DebuggerDisplay]); formats this value with the F# "%+0.8A" printer.
    internal object __DebugDisplay()
    {
        return ExtraTopLevelOperators.PrintFormatToString(new PrintfFormat<FSharpFunc<Contractor, string>, Unit, string, string, string>("%+0.8A")).Invoke(this);
    }

    // Formats this value with the F# "%+A" printer.
    public override string ToString()
    {
        return ExtraTopLevelOperators.PrintFormatToString(new PrintfFormat<FSharpFunc<Contractor, string>, Unit, string, string, Contractor>("%+A")).Invoke(this);
    }

    // Hashes the active case's field and mixes in the case index (num).
    // `default` handles tag 0. -1640531527 == unchecked((int)0x9E3779B9); the
    // `c + h + (num << 6) + (num >> 2)` shape matches Boost's hash_combine.
    public sealed override int GetHashCode(IEqualityComparer comp)
    {
        int num = 0;
        switch (Tag)
        {
            default:
                num = 0;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, _lp) + ((num << 6) + (num >> 2)));
            case 1:
                num = 1;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, _np) + ((num << 6) + (num >> 2)));
            case 2:
                num = 2;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, _od) + ((num << 6) + (num >> 2)));
        }
    }

    // Delegates to the comparer-based overload using F#'s generic equality comparer.
    public sealed override int GetHashCode()
    {
        return GetHashCode(LanguagePrimitives.GenericEqualityComparer);
    }

    // Structural equality: false unless obj is a Contractor with the same tag; otherwise
    // compares only the active case's field using the supplied comparer.
    public sealed override bool Equals(object obj, IEqualityComparer comp)
    {
        if (!LanguagePrimitives.IntrinsicFunctions.TypeTestGeneric<Contractor>(obj))
        {
            return false;
        }
        Contractor contractor = (Contractor)obj;
        int tag = _tag;
        int tag2 = contractor._tag;
        if (tag == tag2)
        {
            switch (Tag)
            {
                default:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, _lp, contractor._lp);
                case 1:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, _np, contractor._np);
                case 2:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, _od, contractor._od);
            }
        }
        return false;
    }

    // Typed equality: same tag, then the active case's fields compared with GenericEqualityERIntrinsic.
    public sealed override bool Equals(Contractor obj)
    {
        int tag = _tag;
        int tag2 = obj._tag;
        if (tag == tag2)
        {
            switch (Tag)
            {
                default:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic(_lp, obj._lp);
                case 1:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic(_np, obj._np);
                case 2:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic(_od, obj._od);
            }
        }
        return false;
    }

    // Object equality: type test, then defer to the typed overload.
    public sealed override bool Equals(object obj)
    {
        if (!LanguagePrimitives.IntrinsicFunctions.TypeTestGeneric<Contractor>(obj))
        {
            return false;
        }
        return Equals((Contractor)obj);
    }
}

Proposal: replace the 3 per-case fields with 1 shared field

// Proposed layout: only one case is active at a time and every case holds a reference
// type, so all case payloads share a single `object` field (_value) next to the int tag.
// Code that reads the payload casts _value back to the active case's type.
[Serializable]
[StructLayout(LayoutKind.Sequential, CharSet = CharSet.Auto, Size = 1)]
[Struct]
[DebuggerDisplay("{__DebugDisplay(),nq}")]
[CompilationMapping(SourceConstructFlags.SumType)]
public struct Contractor : IEquatable<Contractor>, IStructuralEquatable
{
    // Integer values stored in _tag / returned by Tag, one per union case.
    public static class Tags
    {
        public const int LegalPerson = 0;
        public const int NaturalPerson = 1;
        public const int OrganizationDepartment = 2;
    }

    // Shared storage for whichever case is active; its runtime type is determined by _tag.
    internal object _value;
    // Index of the active case (one of the Tags values).
    internal int _tag;

    public int Tag => _tag;

    public bool IsLegalPerson => Tag == 0;
    public bool IsNaturalPerson => Tag == 1;
    public bool IsOrganizationDepartment => Tag == 2;

    // Each accessor is guarded by its tag so that reading an inactive case yields null,
    // exactly as in the three-field layout (where inactive fields are never assigned).
    // An unguarded cast would instead throw InvalidCastException when _value holds another
    // case's payload, or return the wrong case's value if two cases shared a payload type.
    public LegalPerson lp => Tag == 0 ? (LegalPerson)_value : null;
    public NaturalPerson np => Tag == 1 ? (NaturalPerson)_value : null;
    public OrganizationDepartment od => Tag == 2 ? (OrganizationDepartment)_value : null;

    // One factory per case; each passes its case's tag to the matching internal constructor.
    [CompilationMapping(SourceConstructFlags.UnionCase, 0)]
    public static Contractor NewLegalPerson(LegalPerson _lp) => new Contractor(_lp, 0, false);

    [CompilationMapping(SourceConstructFlags.UnionCase, 1)]
    public static Contractor NewNaturalPerson(NaturalPerson _np) => new Contractor(_np, 1, 0);

    [CompilationMapping(SourceConstructFlags.UnionCase, 2)]
    public static Contractor NewOrganizationDepartment(OrganizationDepartment _od) => new Contractor(_od, 2, 0);

    // Per-case constructors: all store their payload in the shared _value slot and record
    // the tag. P_2 is never read; it is kept so the constructor signatures match the
    // current compiler output.
    internal Contractor(LegalPerson _lp, int _tag, bool P_2)
    {
        this._value = _lp;
        this._tag = _tag;
    }

    internal Contractor(NaturalPerson _np, int _tag, byte P_2)
    {
        this._value = _np;
        this._tag = _tag;
    }

    internal Contractor(OrganizationDepartment _od, int _tag, sbyte P_2)
    {
        this._value = _od;
        this._tag = _tag;
    }

    // Debugger text (see [DebuggerDisplay]); formats this value with the F# "%+0.8A" printer.
    internal object __DebugDisplay()
    {
        return ExtraTopLevelOperators.PrintFormatToString(new PrintfFormat<FSharpFunc<Contractor, string>, Unit, string, string, string>("%+0.8A")).Invoke(this);
    }

    // Formats this value with the F# "%+A" printer.
    public override string ToString()
    {
        return ExtraTopLevelOperators.PrintFormatToString(new PrintfFormat<FSharpFunc<Contractor, string>, Unit, string, string, Contractor>("%+A")).Invoke(this);
    }

    // Hashes the active case's payload and mixes in the case index (num); `default`
    // handles tag 0. The casts are safe because each branch is selected by the tag that
    // was stored alongside _value. -1640531527 == unchecked((int)0x9E3779B9).
    public sealed override int GetHashCode(IEqualityComparer comp)
    {
        int num = 0;
        switch (Tag)
        {
            default:
                num = 0;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, (LegalPerson)_value) + ((num << 6) + (num >> 2)));
            case 1:
                num = 1;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, (NaturalPerson)_value) + ((num << 6) + (num >> 2)));
            case 2:
                num = 2;
                return -1640531527 + (LanguagePrimitives.HashCompare.GenericHashWithComparerIntrinsic(comp, (OrganizationDepartment)_value) + ((num << 6) + (num >> 2)));
        }
    }

    // Delegates to the comparer-based overload using F#'s generic equality comparer.
    public sealed override int GetHashCode()
    {
        return GetHashCode(LanguagePrimitives.GenericEqualityComparer);
    }

    // Structural equality: false unless obj is a Contractor with the same tag. Because the
    // tags match, both payloads are cast to the same case type before comparison.
    public sealed override bool Equals(object obj, IEqualityComparer comp)
    {
        if (!LanguagePrimitives.IntrinsicFunctions.TypeTestGeneric<Contractor>(obj))
        {
            return false;
        }
        Contractor contractor = (Contractor)obj;
        int tag = _tag;
        int tag2 = contractor._tag;
        if (tag == tag2)
        {
            switch (Tag)
            {
                default:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, (LegalPerson)_value, (LegalPerson)contractor._value);
                case 1:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, (NaturalPerson)_value, (NaturalPerson)contractor._value);
                case 2:
                    return LanguagePrimitives.HashCompare.GenericEqualityWithComparerIntrinsic(comp, (OrganizationDepartment)_value, (OrganizationDepartment)contractor._value);
            }
        }
        return false;
    }

    // Typed equality: same tag, then both payloads cast to that case's type and compared.
    public sealed override bool Equals(Contractor obj)
    {
        int tag = _tag;
        int tag2 = obj._tag;
        if (tag == tag2)
        {
            switch (Tag)
            {
                default:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic((LegalPerson)_value, (LegalPerson)obj._value);
                case 1:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic((NaturalPerson)_value, (NaturalPerson)obj._value);
                case 2:
                    return LanguagePrimitives.HashCompare.GenericEqualityERIntrinsic((OrganizationDepartment)_value, (OrganizationDepartment)obj._value);
            }
        }
        return false;
    }

    // Object equality: type test, then defer to the typed overload.
    public sealed override bool Equals(object obj)
    {
        if (!LanguagePrimitives.IntrinsicFunctions.TypeTestGeneric<Contractor>(obj))
        {
            return false;
        }
        return Equals((Contractor)obj);
    }
}
cartermp commented 4 years ago

I think this would be interesting to explore. This kind of stuff is always kind of tricky, but if it's possible and non-breaking (at least in normal usage) then it'd be worth doing.

kerams commented 3 years ago

The assumption is that for every use case the memory savings outweigh the casting overhead, even when there are 2 cases? Since this only changes the internal representation and FSharp.Core/reflect.fs isn't affected (the public properties for case reading and construction won't change), could it conceivably break anything else?

By the way, what's the purpose of that P_2 constructor parameter (https://github.com/dotnet/fsharp/issues/9767) and why isn't the tag baked into every constructor?

dsyme commented 3 years ago

This is covered by https://github.com/fsharp/fslang-suggestions/issues/699 and other suggestions.

We should use Explicit overlapping layouts where possible