sklose / NCalc2

expression evaluator for .NET with built-in compiler
MIT License
166 stars 58 forks source link

C#-like handling of `null` in comparison operators #71

Open Yankes opened 1 year ago

Yankes commented 1 year ago

In C#, when you write code like `(int?)0 == null`, the result is false, but in NCalc2, because of `LambdaExpressionVistor.UnwrapNullable`, every null is converted to 0.

In effect, this causes `NullInt() == 0` to return true.

As some user expressions could be relying on this behavior, I think the best approach would be a new option flag for it. When this flag is enabled, add code like the following to `LambdaExpressionVistor.WithCommonNumericType`:

        /// <summary>
        /// Tells whether the static type of <paramref name="expression"/> is built from <c>Nullable&lt;&gt;</c>.
        /// </summary>
        /// <param name="expression">Expression whose <c>Type</c> is inspected.</param>
        /// <returns><c>true</c> when the type is generic and its generic definition is <c>Nullable&lt;&gt;</c>.</returns>
        private bool IsNullable(L.Expression expression)
        {
            var typeInfo = expression.Type.GetTypeInfo();
            if (!typeInfo.IsGenericType)
            {
                return false;
            }

            return typeInfo.GetGenericTypeDefinition() == typeof(Nullable<>);
        }

        /// <summary>
        /// Builds a boolean expression telling whether <paramref name="expression"/> carries a value:
        /// the <c>HasValue</c> property for nullable operands, a constant <c>true</c> for everything else.
        /// </summary>
        private L.Expression HaveValueExpression(L.Expression expression)
        {
            if (!IsNullable(expression))
            {
                return L.Expression.Constant(true);
            }

            return L.Expression.Property(expression, "HasValue");
        }

        /// <summary>
        /// Wraps <paramref name="action"/> so that a missing operand value follows C# lifted-operator rules
        /// instead of being treated as the default value of the underlying type.
        /// </summary>
        /// <param name="left">Left operand before unwrapping; its <c>HasValue</c> is tested.</param>
        /// <param name="rigth">Right operand before unwrapping; its <c>HasValue</c> is tested.</param>
        /// <param name="action">Factory that builds the operation from the unwrapped operands.</param>
        /// <param name="expressiontype">Operator kind; selects the result used when a value is missing.</param>
        /// <returns>
        /// <paramref name="action"/> itself when no null handling applies, otherwise a factory that guards the
        /// operation with a "both operands have a value" test.
        /// </returns>
        private Func<L.Expression, L.Expression, L.Expression> UnwrapNullableAction(L.Expression left, L.Expression rigth, Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype)
        {
            if (/*TODO should we use null like C#? */ && (IsNullable(left) || IsNullable(rigth)))
            {
                L.Expression whenValueMissing;
                switch (expressiontype)
                {
                    case BinaryExpressionType.Equal:
                        // `null == null` is true, `null == value` is false.
                        whenValueMissing = L.Expression.Equal(HaveValueExpression(left), HaveValueExpression(rigth));
                        break;

                    case BinaryExpressionType.NotEqual:
                        // `null != null` is false, `null != value` is true.
                        whenValueMissing = L.Expression.NotEqual(HaveValueExpression(left), HaveValueExpression(rigth));
                        break;

                    case BinaryExpressionType.Lesser:
                    case BinaryExpressionType.LesserOrEqual:
                    case BinaryExpressionType.Greater:
                    case BinaryExpressionType.GreaterOrEqual:
                        // Ordering comparisons are false as soon as one side is null, even for `null <= null`.
                        whenValueMissing = L.Expression.Constant(false);
                        break;

                    default:
                        // Non-comparison operators yield null when an operand is null. Both branches of the
                        // conditional must share one type, so the result is lifted to Nullable<T>.
                        return (l, r) =>
                        {
                            var result = action(l, r);
                            var lifted = typeof(Nullable<>).MakeGenericType(result.Type);
                            return L.Expression.Condition(
                                L.Expression.AndAlso(HaveValueExpression(left), HaveValueExpression(rigth)),
                                L.Expression.Convert(result, lifted),
                                L.Expression.Default(lifted));
                        };
                }

                // Comparisons keep their bool result type; only the "value missing" outcome differs.
                return (l, r) => L.Expression.Condition(
                    L.Expression.AndAlso(HaveValueExpression(left), HaveValueExpression(rigth)),
                    action(l, r),
                    whenValueMissing);
            }
            else
            {
                return action;
            }
        }

        // Sketch of the proposed change to WithCommonNumericType; the remainder of the method is elided
        // by the TODO below, so this fragment is not complete on its own.
        private L.Expression WithCommonNumericType(L.Expression left, L.Expression right,
            Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype = BinaryExpressionType.Unknown)
        {
            // The action is wrapped first, while left/right still have their original (possibly nullable) types.
            action = UnwrapNullableAction(left, right, action, expressiontype);
            // Only afterwards are the operands replaced by their unwrapped forms.
            left = UnwrapNullable(left);
            right = UnwrapNullable(right);

            /*TODO rest of logic */
        }
Yankes commented 1 year ago

For my project I hacked `LambdaExpressionVistor` to work better with nulls:


        internal class HackExpressionVistor : LogicalExpressionVisitor
        {
            // Named parameter values; Visit(Identifier) reads from here when no context expression is set.
            private readonly IDictionary<string, object> _parameters;
            // Expression tree left behind by the most recent Visit call; also exposed through Result.
            private L.Expression _result;
            // Context object expression; when set, identifiers become its members and functions its methods.
            private readonly L.Expression _context;
            // Evaluation flags supplied by the constructors.
            private readonly EvaluateOptions _options = EvaluateOptions.None;
            // Maps a source type to the set of types it may be converted to implicitly.
            // Consulted by TryCastImplicitly when matching call arguments against method parameters.
            private readonly Dictionary<Type, HashSet<Type>> _implicitPrimitiveConversionTable = new Dictionary<Type, HashSet<Type>>() {
                { typeof(sbyte), new HashSet<Type> { typeof(short), typeof(int), typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(byte), new HashSet<Type> { typeof(short), typeof(ushort), typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(short), new HashSet<Type> { typeof(int), typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(ushort), new HashSet<Type> { typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(int), new HashSet<Type> { typeof(long), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(uint), new HashSet<Type> { typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(long), new HashSet<Type> { typeof(float), typeof(double), typeof(decimal) }},
                { typeof(char), new HashSet<Type> { typeof(ushort), typeof(int), typeof(uint), typeof(long), typeof(ulong), typeof(float), typeof(double), typeof(decimal) }},
                { typeof(float), new HashSet<Type> { typeof(double) }},
                { typeof(ulong), new HashSet<Type> { typeof(float), typeof(double), typeof(decimal) }},
            };

            /// <summary>True when <c>EvaluateOptions.MatchStringsOrdinal</c> is set in the options.</summary>
            private bool Ordinal => _options.HasFlag(EvaluateOptions.MatchStringsOrdinal);

            /// <summary>True when <c>EvaluateOptions.MatchStringsWithIgnoreCase</c> is set in the options.</summary>
            private bool IgnoreCaseString => _options.HasFlag(EvaluateOptions.MatchStringsWithIgnoreCase);

            /// <summary>True when <c>EvaluateOptions.OverflowProtection</c> is set in the options.</summary>
            private bool Checked => _options.HasFlag(EvaluateOptions.OverflowProtection);

            /// <summary>Currently hard-wired to <c>true</c>. TODO: should we use null like C#?</summary>
            private bool CSharpNull => true;

            /// <summary>
            /// Creates a visitor that resolves identifiers from a parameter dictionary.
            /// </summary>
            /// <param name="parameters">Values looked up by identifier name.</param>
            /// <param name="options">Evaluation flags.</param>
            public HackExpressionVistor(IDictionary<string, object> parameters, EvaluateOptions options)
            {
                _options = options;
                _parameters = parameters;
            }

            /// <summary>
            /// Creates a visitor that resolves identifiers and functions against a context object expression.
            /// </summary>
            /// <param name="context">Parameter expression standing for the context object.</param>
            /// <param name="options">Evaluation flags.</param>
            public HackExpressionVistor(L.ParameterExpression context, EvaluateOptions options)
            {
                _options = options;
                _context = context;
            }

            /// <summary>Expression tree produced by the most recent visit.</summary>
            public L.Expression Result
            {
                get { return _result; }
            }

            /// <summary>Not supported for the abstract node type; always throws.</summary>
            /// <exception cref="NotImplementedException">Always.</exception>
            public override void Visit(LogicalExpression expression) => throw new NotImplementedException();

            /// <summary>
            /// Compiles a ternary node into a conditional expression.
            /// </summary>
            public override void Visit(TernaryExpression expression)
            {
                // Each Accept call overwrites _result, so every branch is captured right after it is compiled.
                expression.LeftExpression.Accept(this);
                L.Expression condition = _result;

                expression.MiddleExpression.Accept(this);
                L.Expression whenTrue = _result;

                expression.RightExpression.Accept(this);
                L.Expression whenFalse = _result;

                _result = L.Expression.Condition(condition, whenTrue, whenFalse);
            }

            /// <summary>
            /// Compiles a binary node. Comparison and arithmetic operators go through
            /// <c>WithCommonNumericType</c>; logical, bitwise and shift operators are emitted directly.
            /// </summary>
            /// <exception cref="ArgumentOutOfRangeException">The operator kind is not handled.</exception>
            public override void Visit(BinaryExpression expression)
            {
                // Operands are compiled left to right; each Accept call leaves its tree in _result.
                expression.LeftExpression.Accept(this);
                L.Expression lhs = _result;

                expression.RightExpression.Accept(this);
                L.Expression rhs = _result;

                var kind = expression.Type;
                switch (kind)
                {
                    // Short-circuiting logical operators.
                    case BinaryExpressionType.And:
                        _result = L.Expression.AndAlso(lhs, rhs);
                        return;
                    case BinaryExpressionType.Or:
                        _result = L.Expression.OrElse(lhs, rhs);
                        return;

                    // Comparisons pass the operator kind along.
                    case BinaryExpressionType.NotEqual:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.NotEqual, kind);
                        return;
                    case BinaryExpressionType.LesserOrEqual:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.LessThanOrEqual, kind);
                        return;
                    case BinaryExpressionType.GreaterOrEqual:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.GreaterThanOrEqual, kind);
                        return;
                    case BinaryExpressionType.Lesser:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.LessThan, kind);
                        return;
                    case BinaryExpressionType.Greater:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.GreaterThan, kind);
                        return;
                    case BinaryExpressionType.Equal:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.Equal, kind);
                        return;

                    // Arithmetic; the checked factory is used when overflow protection is on.
                    case BinaryExpressionType.Minus:
                        _result = Checked
                            ? WithCommonNumericType(lhs, rhs, L.Expression.SubtractChecked)
                            : WithCommonNumericType(lhs, rhs, L.Expression.Subtract);
                        return;
                    case BinaryExpressionType.Plus:
                        _result = Checked
                            ? WithCommonNumericType(lhs, rhs, L.Expression.AddChecked)
                            : WithCommonNumericType(lhs, rhs, L.Expression.Add);
                        return;
                    case BinaryExpressionType.Modulo:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.Modulo);
                        return;
                    case BinaryExpressionType.Div:
                        _result = WithCommonNumericType(lhs, rhs, L.Expression.Divide);
                        return;
                    case BinaryExpressionType.Times:
                        _result = Checked
                            ? WithCommonNumericType(lhs, rhs, L.Expression.MultiplyChecked)
                            : WithCommonNumericType(lhs, rhs, L.Expression.Multiply);
                        return;

                    // Bitwise and shift operators are emitted without type unification.
                    case BinaryExpressionType.BitwiseOr:
                        _result = L.Expression.Or(lhs, rhs);
                        return;
                    case BinaryExpressionType.BitwiseAnd:
                        _result = L.Expression.And(lhs, rhs);
                        return;
                    case BinaryExpressionType.BitwiseXOr:
                        _result = L.Expression.ExclusiveOr(lhs, rhs);
                        return;
                    case BinaryExpressionType.LeftShift:
                        _result = L.Expression.LeftShift(lhs, rhs);
                        return;
                    case BinaryExpressionType.RightShift:
                        _result = L.Expression.RightShift(lhs, rhs);
                        return;

                    default:
                        throw new ArgumentOutOfRangeException();
                }
            }

            /// <summary>
            /// Compiles a unary node. Logical and bitwise negation both map to <c>Expression.Not</c>.
            /// </summary>
            /// <exception cref="ArgumentOutOfRangeException">The operator kind is not handled.</exception>
            public override void Visit(UnaryExpression expression)
            {
                expression.Expression.Accept(this);
                L.Expression operand = _result;

                switch (expression.Type)
                {
                    case UnaryExpressionType.Negate:
                        _result = L.Expression.Negate(operand);
                        break;
                    case UnaryExpressionType.Not:
                    case UnaryExpressionType.BitwiseNot:
                        _result = L.Expression.Not(operand);
                        break;
                    default:
                        throw new ArgumentOutOfRangeException();
                }
            }

            /// <summary>Compiles a literal value into a constant expression.</summary>
            public override void Visit(ValueExpression expression)
            {
                object literal = expression.Value;
                _result = L.Expression.Constant(literal);
            }

            /// <summary>
            /// Compiles a function call. <c>if</c> and <c>in</c> are always handled as built-ins; for every other
            /// name a matching context method is preferred, and <c>min</c>, <c>max</c> and <c>pow</c> are the fallback.
            /// </summary>
            /// <param name="function">Function node whose arguments are compiled in order.</param>
            /// <exception cref="ArgumentException">A built-in function received fewer arguments than it reads.</exception>
            /// <exception cref="MissingMethodException">Neither a context method nor a built-in matches the name.</exception>
            public override void Visit(Function function)
            {
                var args = new L.Expression[function.Expressions.Length];
                for (int i = 0; i < function.Expressions.Length; i++)
                {
                    function.Expressions[i].Accept(this);
                    args[i] = _result;
                }

                string functionName = function.Identifier.Name.ToLowerInvariant();
                if (functionName == "if")
                {
                    RequireArguments(functionName, args, 3);

                    // When both branches are numeric but of different types, widen both to the higher-priority one
                    // so that Expression.Condition receives matching branch types.
                    var numberTypePriority = new Type[] { typeof(double), typeof(float), typeof(long), typeof(int), typeof(short) };
                    var index1 = Array.IndexOf(numberTypePriority, args[1].Type);
                    var index2 = Array.IndexOf(numberTypePriority, args[2].Type);
                    if (index1 >= 0 && index2 >= 0 && index1 != index2)
                    {
                        args[1] = L.Expression.Convert(args[1], numberTypePriority[Math.Min(index1, index2)]);
                        args[2] = L.Expression.Convert(args[2], numberTypePriority[Math.Min(index1, index2)]);
                    }
                    _result = L.Expression.Condition(args[0], args[1], args[2]);
                    return;
                }
                else if (functionName == "in")
                {
                    RequireArguments(functionName, args, 1);

                    // args[0] is the probe; the remaining arguments form the array that is searched.
                    var items = L.Expression.NewArrayInit(args[0].Type,
                            new ArraySegment<L.Expression>(args, 1, args.Length - 1));
                    var smi = typeof(Array).GetRuntimeMethod("IndexOf", new[] { typeof(Array), typeof(object) });
                    var r = L.Expression.Call(smi, L.Expression.Convert(items, typeof(Array)), L.Expression.Convert(args[0], typeof(object)));
                    _result = L.Expression.GreaterThanOrEqual(r, L.Expression.Constant(0));
                    return;
                }

                //Context methods take precedence over built-in functions because they're user-customisable.
                var mi = FindMethod(function.Identifier.Name, args);
                if (mi != null)
                {
                    _result = L.Expression.Call(_context, mi.BaseMethodInfo, mi.PreparedArguments);
                    return;
                }

                switch (functionName)
                {
                    case "min":
                        RequireArguments(functionName, args, 2);
                        var minArg0 = L.Expression.Convert(args[0], typeof(double));
                        var minArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Condition(L.Expression.LessThan(minArg0, minArg1), minArg0, minArg1);
                        break;
                    case "max":
                        RequireArguments(functionName, args, 2);
                        var maxArg0 = L.Expression.Convert(args[0], typeof(double));
                        var maxArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Condition(L.Expression.GreaterThan(maxArg0, maxArg1), maxArg0, maxArg1);
                        break;
                    case "pow":
                        RequireArguments(functionName, args, 2);
                        var powArg0 = L.Expression.Convert(args[0], typeof(double));
                        var powArg1 = L.Expression.Convert(args[1], typeof(double));
                        _result = L.Expression.Power(powArg0, powArg1);
                        break;
                    default:
                        throw new MissingMethodException($"method not found: {functionName}");
                }
            }

            /// <summary>
            /// Rejects a built-in call that supplies fewer arguments than the built-in reads, with a message
            /// naming the function instead of an index error.
            /// </summary>
            private static void RequireArguments(string functionName, L.Expression[] args, int required)
            {
                if (args.Length < required)
                {
                    throw new ArgumentException($"{functionName}() requires at least {required} argument(s), but {args.Length} were supplied.");
                }
            }

            /// <summary>
            /// Compiles an identifier: a constant taken from the parameter dictionary when there is no context
            /// expression, otherwise a property or field access on the context.
            /// </summary>
            public override void Visit(Identifier function)
            {
                _result = _context == null
                    ? (L.Expression)L.Expression.Constant(_parameters[function.Name])
                    : L.Expression.PropertyOrField(_context, function.Name);
            }

            /// <summary>
            /// Searches the context type, then each of its base types below <c>object</c>, for a public instance
            /// method whose name matches <paramref name="methodName"/> (case-insensitively) and whose parameters
            /// accept <paramref name="methodArgs"/>.
            /// </summary>
            /// <param name="methodName">Name of the function as written in the expression.</param>
            /// <param name="methodArgs">Compiled argument expressions.</param>
            /// <returns>
            /// The best candidate of the first type in the hierarchy that has any (lowest score wins, a score of 0
            /// returns immediately), or <c>null</c> when there is no context or no method matches.
            /// </returns>
            private ExtendedMethodInfo FindMethod(string methodName, L.Expression[] methodArgs)
            {
                if (_context == null) return null;

                TypeInfo contextTypeInfo = _context.Type.GetTypeInfo();
                while (true)
                {
                    var methods = contextTypeInfo.DeclaredMethods.Where(m => m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase) && m.IsPublic && !m.IsStatic);
                    var candidates = new List<ExtendedMethodInfo>();
                    foreach (var potentialMethod in methods)
                    {
                        var methodParams = potentialMethod.GetParameters();
                        var preparedArguments = PrepareMethodArgumentsIfValid(methodParams, methodArgs);

                        if (preparedArguments != null)
                        {
                            var candidate = new ExtendedMethodInfo()
                            {
                                BaseMethodInfo = potentialMethod,
                                PreparedArguments = preparedArguments.Item2,
                                Score = preparedArguments.Item1
                            };
                            if (candidate.Score == 0) return candidate;
                            candidates.Add(candidate);
                        }
                    }
                    if (candidates.Any()) return candidates.OrderBy(method => method.Score).First();

                    // BaseType is null for interfaces and for System.Object itself; stop there instead of
                    // dereferencing it. Methods declared on System.Object are never searched as a base type.
                    Type baseType = contextTypeInfo.BaseType;
                    if (baseType == null || baseType == typeof(object)) return null;
                    contextTypeInfo = baseType.GetTypeInfo();
                }
            }

            /// <summary>
            /// Matches call arguments against a method's parameters.
            /// The result pairs a score with the prepared argument list: the score counts how many arguments
            /// needed an implicit conversion (0 means every type matched exactly), and the prepared list holds
            /// the converted arguments with any trailing "params" arguments gathered into a single array.
            /// </summary>
            /// <param name="parameters">Parameters declared by the candidate method.</param>
            /// <param name="arguments">Compiled argument expressions of the call.</param>
            /// <returns>The score and prepared arguments, or <c>null</c> when the method cannot take the arguments.</returns>
            private Tuple<int, L.Expression[]> PrepareMethodArgumentsIfValid(ParameterInfo[] parameters, L.Expression[] arguments)
            {
                if (parameters.Length == 0)
                {
                    // A parameterless method only matches a call without arguments.
                    return arguments.Length == 0 ? Tuple.Create(0, arguments) : null;
                }

                ParameterInfo trailing = parameters[parameters.Length - 1];
                bool isVariadic = trailing.IsDefined(typeof(ParamArrayAttribute));

                // With "params" at least one argument per parameter is required; without it the counts must agree.
                bool countMismatch = isVariadic
                    ? parameters.Length > arguments.Length
                    : parameters.Length != arguments.Length;
                if (countMismatch)
                {
                    return null;
                }

                int variadicStart = isVariadic ? trailing.Position : 0;
                Type variadicElementType = isVariadic ? trailing.ParameterType.GetElementType() : null;
                L.Expression[] variadicItems = isVariadic
                    ? new L.Expression[arguments.Length - parameters.Length + 1]
                    : null;

                var prepared = new L.Expression[parameters.Length];
                int score = 0;
                for (int index = 0; index < arguments.Length; index++)
                {
                    bool inVariadicTail = isVariadic && index >= variadicStart;
                    Type expected = inVariadicTail ? variadicElementType : parameters[index].ParameterType;
                    L.Expression current = arguments[index];

                    if (current.Type != expected)
                    {
                        if (!TryCastImplicitly(current.Type, expected, ref current))
                        {
                            return null;
                        }
                        score++;
                    }

                    if (inVariadicTail)
                    {
                        variadicItems[index - variadicStart] = current;
                    }
                    else
                    {
                        prepared[index] = current;
                    }
                }

                if (isVariadic)
                {
                    prepared[variadicStart] = L.Expression.NewArrayInit(variadicElementType, variadicItems);
                }
                return Tuple.Create(score, prepared);
            }

            /// <summary>
            /// Converts <paramref name="argument"/> to <paramref name="to"/> when the implicit conversion table
            /// allows it. With C#-style null handling and a nullable target, the check is done on the underlying
            /// types, and identical underlying types are accepted without a table lookup.
            /// </summary>
            /// <param name="from">Current type of the argument.</param>
            /// <param name="to">Parameter type to convert to.</param>
            /// <param name="argument">Replaced by the converted expression on success, set to <c>null</c> on failure.</param>
            /// <returns><c>true</c> when the conversion was applied.</returns>
            private bool TryCastImplicitly(Type from, Type to, ref L.Expression argument)
            {
                bool nullableTarget = CSharpNull && IsNullable(to);
                Type source = nullableTarget ? (GetNullableType(from) ?? from) : from;
                Type target = nullableTarget ? GetNullableType(to) : to;

                // Only a nullable target with a matching underlying type may skip the table.
                bool needsLookup = !nullableTarget || source != target;
                if (needsLookup)
                {
                    HashSet<Type> allowedTargets;
                    bool known = _implicitPrimitiveConversionTable.TryGetValue(source, out allowedTargets);
                    if (!known || !allowedTargets.Contains(target))
                    {
                        argument = null;
                        return false;
                    }
                }

                argument = L.Expression.Convert(argument, to);
                return true;
            }

            /// <summary>
            /// Builds a binary operation after bringing both operands to a compatible shape: nullable operands
            /// are unwrapped, booleans optionally become 1.0/0.0, numeric operands are widened to the
            /// higher-precedence type, and string comparisons are routed through a <c>StringComparer</c>.
            /// </summary>
            /// <param name="left">Left operand.</param>
            /// <param name="right">Right operand.</param>
            /// <param name="action">Factory that builds the operation from the prepared operands.</param>
            /// <param name="expressiontype">Operator kind; only comparison kinds get the string treatment.</param>
            /// <returns>The resulting expression.</returns>
            private L.Expression WithCommonNumericType(L.Expression left, L.Expression right,
                Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype = BinaryExpressionType.Unknown)
            {
                // The action is wrapped first, while the operands still have their original types.
                action = UnwrapNullableAction(left, right, action, expressiontype);
                left = UnwrapNullable(left);
                right = UnwrapNullable(right);

                if (_options.HasFlag(EvaluateOptions.BooleanCalculation))
                {
                    L.Expression AsNumber(L.Expression operand) =>
                        operand.Type == typeof(bool)
                            ? L.Expression.Condition(operand, L.Expression.Constant(1.0), L.Expression.Constant(0.0))
                            : operand;

                    left = AsNumber(left);
                    right = AsNumber(right);
                }

                // Earlier entries win when the two operand types differ.
                Type[] ranking =
                {
                    typeof(decimal), typeof(double), typeof(float),
                    typeof(ulong), typeof(long), typeof(uint), typeof(int),
                    typeof(ushort), typeof(short), typeof(byte), typeof(sbyte)
                };

                int leftRank = Array.IndexOf(ranking, left.Type);
                int rightRank = Array.IndexOf(ranking, right.Type);
                if (leftRank >= 0 && rightRank >= 0)
                {
                    Type common = ranking[Math.Min(leftRank, rightRank)];
                    if (left.Type != common)
                    {
                        left = L.Expression.Convert(left, common);
                    }

                    if (right.Type != common)
                    {
                        right = L.Expression.Convert(right, common);
                    }
                }

                bool involvesString = left.Type == typeof(string) || right.Type == typeof(string);
                if (!involvesString)
                {
                    return action(left, right);
                }

                string comparerName;
                if (!IgnoreCaseString)
                {
                    comparerName = "Ordinal";
                }
                else
                {
                    comparerName = Ordinal ? "OrdinalIgnoreCase" : "CurrentCultureIgnoreCase";
                }
                L.Expression comparer = L.Expression.Property(null, typeof(StringComparer), comparerName);
                Type[] twoStrings = { typeof(string), typeof(string) };

                switch (expressiontype)
                {
                    case BinaryExpressionType.Equal:
                    case BinaryExpressionType.NotEqual:
                    {
                        L.Expression same = L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Equals", twoStrings), new L.Expression[] { left, right });
                        return expressiontype == BinaryExpressionType.Equal ? same : L.Expression.Not(same);
                    }

                    case BinaryExpressionType.GreaterOrEqual:
                    case BinaryExpressionType.LesserOrEqual:
                    case BinaryExpressionType.Greater:
                    case BinaryExpressionType.Lesser:
                    {
                        // The comparer's Compare result is tested against zero with the requested operator.
                        L.Expression ordering = L.Expression.Call(comparer, typeof(StringComparer).GetRuntimeMethod("Compare", twoStrings), new L.Expression[] { left, right });
                        L.Expression zero = L.Expression.Constant(0);
                        if (expressiontype == BinaryExpressionType.GreaterOrEqual) return L.Expression.GreaterThanOrEqual(ordering, zero);
                        if (expressiontype == BinaryExpressionType.LesserOrEqual) return L.Expression.LessThanOrEqual(ordering, zero);
                        if (expressiontype == BinaryExpressionType.Greater) return L.Expression.GreaterThan(ordering, zero);
                        return L.Expression.LessThan(ordering, zero);
                    }

                    default:
                        // Any other operator falls back to the supplied factory.
                        return action(left, right);
                }
            }

            /// <summary>
            /// Tells whether <paramref name="type"/> is built from <c>Nullable&lt;&gt;</c>.
            /// </summary>
            /// <returns><c>true</c> when the type is generic and its generic definition is <c>Nullable&lt;&gt;</c>.</returns>
            private bool IsNullable(Type type)
            {
                TypeInfo info = type.GetTypeInfo();
                if (!info.IsGenericType)
                {
                    return false;
                }

                return info.GetGenericTypeDefinition() == typeof(Nullable<>);
            }

            /// <summary>Tells whether the static type of <paramref name="expression"/> is nullable.</summary>
            private bool IsNullable(L.Expression expression) => IsNullable(expression.Type);

            /// <summary>
            /// Returns the first generic type argument of a nullable <paramref name="type"/>,
            /// or <c>null</c> when the type is not nullable.
            /// </summary>
            private Type GetNullableType(Type type)
            {
                if (!IsNullable(type))
                {
                    return null;
                }

                return type.GetTypeInfo().GenericTypeArguments[0];
            }

            /// <summary>
            /// Returns the underlying type of a nullable <paramref name="expression"/>, or <c>null</c> otherwise.
            /// </summary>
            private Type GetNullableType(L.Expression expression) => GetNullableType(expression.Type);

            /// <summary>
            /// Builds a boolean expression telling whether <paramref name="expression"/> carries a value:
            /// the <c>HasValue</c> property for nullable operands, a constant <c>true</c> for everything else.
            /// </summary>
            private L.Expression HaveValueExpression(L.Expression expression)
            {
                if (!IsNullable(expression))
                {
                    return L.Expression.Constant(true);
                }

                return L.Expression.Property(expression, "HasValue");
            }

            /// <summary>
            /// Wraps <paramref name="action"/> so that a missing operand value is handled separately from the
            /// operation itself. When both operands have a value the operation runs; otherwise equality tests
            /// compare the two <c>HasValue</c> flags, ordering comparisons yield <c>false</c>, and every other
            /// operator yields null with the result lifted to <c>Nullable&lt;T&gt;</c>.
            /// </summary>
            /// <param name="left">Left operand before unwrapping.</param>
            /// <param name="rigth">Right operand before unwrapping.</param>
            /// <param name="action">Factory that builds the operation from the unwrapped operands.</param>
            /// <param name="expressiontype">Operator kind.</param>
            /// <returns><paramref name="action"/> unchanged when neither operand is nullable or the mode is off.</returns>
            private Func<L.Expression, L.Expression, L.Expression> UnwrapNullableAction(L.Expression left, L.Expression rigth, Func<L.Expression, L.Expression, L.Expression> action, BinaryExpressionType expressiontype)
            {
                if (!CSharpNull || (!IsNullable(left) && !IsNullable(rigth)))
                {
                    return action;
                }

                // Result used when at least one operand has no value, and whether it must be lifted.
                L.Expression whenValueMissing;
                bool liftResult;
                switch (expressiontype)
                {
                    case BinaryExpressionType.Equal:
                        liftResult = false;
                        whenValueMissing = L.Expression.Equal(HaveValueExpression(left), HaveValueExpression(rigth));
                        break;

                    case BinaryExpressionType.NotEqual:
                        liftResult = false;
                        whenValueMissing = L.Expression.NotEqual(HaveValueExpression(left), HaveValueExpression(rigth));
                        break;

                    case BinaryExpressionType.Lesser:
                    case BinaryExpressionType.LesserOrEqual:
                    case BinaryExpressionType.Greater:
                    case BinaryExpressionType.GreaterOrEqual:
                        liftResult = false;
                        whenValueMissing = L.Expression.Constant(false);
                        break;

                    default:
                        liftResult = true;
                        whenValueMissing = L.Expression.Constant(null);
                        break;
                }

                return (l, r) =>
                {
                    L.Expression computed = action(l, r);
                    Type lifted = typeof(Nullable<>).MakeGenericType(computed.Type);
                    L.Expression bothPresent = L.Expression.And(HaveValueExpression(left), HaveValueExpression(rigth));

                    if (liftResult)
                    {
                        // Both branches are converted so that the conditional has a single nullable type.
                        return L.Expression.Condition(
                            bothPresent,
                            L.Expression.Convert(computed, lifted),
                            L.Expression.Convert(whenValueMissing, lifted));
                    }

                    return L.Expression.Condition(bothPresent, computed, whenValueMissing);
                };
            }

            /// <summary>
            /// Replaces a nullable operand by an expression of its underlying type: the wrapped value when one is
            /// present, the underlying type's default otherwise. Other operands are returned unchanged.
            /// </summary>
            private L.Expression UnwrapNullable(L.Expression expression)
            {
                if (!IsNullable(expression))
                {
                    return expression;
                }

                L.Expression hasValue = HaveValueExpression(expression);
                L.Expression wrappedValue = L.Expression.Property(expression, "Value");
                L.Expression fallback = L.Expression.Default(GetNullableType(expression));
                return L.Expression.Condition(hasValue, wrappedValue, fallback);
            }
        }

It's not a perfect solution, as I could change more code to get a simpler resulting expression tree, but I do not need 101% performance from this code, only that all automated tests pass.

sklose commented 1 year ago

Thanks for raising this issue. If you can turn your code into a PR I am happy to merge this change. I agree it should be behind an opt-in flag so as not to break existing users.

Yankes commented 1 year ago

OK, I will probably have some free time next week to prepare a PR for this.