From 8124d8c40f4bc9a3792a86af5bda0ffea05cf2a5 Mon Sep 17 00:00:00 2001 From: Stef Heyenrath Date: Thu, 29 Feb 2024 09:02:22 +0100 Subject: [PATCH] Keep original type from subquery (#777) * Fix 775 * ... * ... * Issue775a * InvalidOperationException * dict --- .../Parser/ExpressionParser.cs | 40 +- .../SupportedMethods/CompareConversionType.cs | 16 +- .../SupportedMethods/IDictionarySignatures.cs | 7 - .../SupportedMethods/IEnumerableSignatures.cs | 78 --- .../SupportedMethods/IQueryableSignatures.cs | 73 --- .../Parser/SupportedMethods/MethodFinder.cs | 508 +++++++++--------- .../Parser/TypeHelper.cs | 4 +- src/System.Linq.Dynamic.Core/Res.cs | 1 - .../Issues/Issue775.cs | 52 ++ 9 files changed, 336 insertions(+), 443 deletions(-) delete mode 100644 src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IDictionarySignatures.cs delete mode 100644 src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IEnumerableSignatures.cs delete mode 100644 src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IQueryableSignatures.cs create mode 100644 test/System.Linq.Dynamic.Core.Tests/Issues/Issue775.cs diff --git a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs index 74240e61..8dfc8426 100644 --- a/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs +++ b/src/System.Linq.Dynamic.Core/Parser/ExpressionParser.cs @@ -34,6 +34,7 @@ public class ExpressionParser private readonly TextParser _textParser; private readonly NumberParser _numberParser; private readonly IExpressionHelper _expressionHelper; + private readonly ConstantExpressionHelper _constantExpressionHelper; private readonly ITypeFinder _typeFinder; private readonly ITypeConverterFactory _typeConverterFactory; private readonly Dictionary _internals = new(); @@ -45,7 +46,6 @@ public class ExpressionParser private ParameterExpression? _root; private Type? _resultType; private bool _createParameterCtor; - private ConstantExpressionHelper _constantExpressionHelper; /// /// Gets name for the `it` field. By default this is set to the KeyWord value "it". @@ -387,19 +387,11 @@ private Expression ParseIn() throw ParseError(_textParser.CurrentToken.Pos, Res.IdentifierImplementingInterfaceExpected, typeof(IEnumerable)); } - var args = new[] { left }; - - Expression? nullExpressionReference = null; - if (_methodFinder.FindMethod(typeof(IEnumerableSignatures), nameof(IEnumerableSignatures.Contains), false, ref nullExpressionReference, ref args, out var containsSignature) != 1) - { - throw ParseError(op.Pos, Res.NoApplicableAggregate, nameof(IEnumerableSignatures.Contains), string.Join(",", args.Select(a => a.Type.Name).ToArray())); - } - var typeArgs = new[] { left.Type }; - args = new[] { right, left }; + var args = new[] { right, left }; - accumulate = Expression.Call(typeof(Enumerable), containsSignature!.Name, typeArgs, args); + accumulate = Expression.Call(typeof(Enumerable), nameof(Enumerable.Contains), typeArgs, args); } else { @@ -2014,9 +2006,6 @@ private Expression ParseAsEnum(string id) private Expression ParseEnumerable(Expression instance, Type elementType, string methodName, int errorPos, Type? type) { - bool isQueryable = TypeHelper.FindGenericType(typeof(IQueryable<>), type) != null; - bool isDictionary = TypeHelper.IsDictionary(type); - var oldParent = _parent; ParameterExpression? outerIt = _it; @@ -2024,7 +2013,7 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string _parent = _it; - if (methodName == "Contains" || methodName == "ContainsKey" || methodName == "Skip" || methodName == "Take") + if (new[] { "Contains", "ContainsKey", "Skip", "Take" }.Contains(methodName)) { // for any method that acts on the parent element type, we need to specify the outerIt as scope. _it = outerIt; @@ -2039,19 +2028,14 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string _it = outerIt; _parent = oldParent; - if (isDictionary && _methodFinder.ContainsMethod(typeof(IDictionarySignatures), methodName, false, null, ref args)) - { - var method = type!.GetMethod(methodName)!; - return Expression.Call(instance, method, args); - } - - if (!_methodFinder.ContainsMethod(typeof(IEnumerableSignatures), methodName, false, null, ref args)) + if (type != null && TypeHelper.IsDictionary(type) && _methodFinder.ContainsMethod(type, methodName, false)) { - throw ParseError(errorPos, Res.NoApplicableAggregate, methodName, string.Join(",", args.Select(a => a.Type.Name).ToArray())); + var dictionaryMethod = type.GetMethod(methodName)!; + return Expression.Call(instance, dictionaryMethod, args); } - Type callType = typeof(Enumerable); - if (isQueryable && _methodFinder.ContainsMethod(typeof(IQueryableSignatures), methodName, false, null, ref args)) + var callType = typeof(Enumerable); + if (type != null && TypeHelper.FindGenericType(typeof(IQueryable<>), type) != null && _methodFinder.ContainsMethod(type, methodName)) { callType = typeof(Queryable); } @@ -2073,10 +2057,14 @@ private Expression ParseEnumerable(Expression instance, Type elementType, string { typeArgs = new[] { elementType, args[0].Type, args[1].Type }; } - else + else if (args.Length == 1) { typeArgs = new[] { elementType, args[0].Type }; } + else + { + typeArgs = new[] { elementType }; + } } else if (methodName == "SelectMany") { diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/CompareConversionType.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/CompareConversionType.cs index 655b87e5..21bd6cce 100644 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/CompareConversionType.cs +++ b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/CompareConversionType.cs @@ -1,10 +1,8 @@ - -namespace System.Linq.Dynamic.Core.Parser.SupportedMethods +namespace System.Linq.Dynamic.Core.Parser.SupportedMethods; + +internal enum CompareConversionType { - internal enum CompareConversionType - { - Both = 0, - First = 1, - Second = -1 - } -} + Both = 0, + First = 1, + Second = -1 +} \ No newline at end of file diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IDictionarySignatures.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IDictionarySignatures.cs deleted file mode 100644 index b3cd8506..00000000 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IDictionarySignatures.cs +++ /dev/null @@ -1,7 +0,0 @@ -namespace System.Linq.Dynamic.Core.Parser.SupportedMethods -{ - internal interface IDictionarySignatures - { - void ContainsKey(object selector); - } -} diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IEnumerableSignatures.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IEnumerableSignatures.cs deleted file mode 100644 index d6c3f08f..00000000 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IEnumerableSignatures.cs +++ /dev/null @@ -1,78 +0,0 @@ -using System.Collections; - -namespace System.Linq.Dynamic.Core.Parser.SupportedMethods -{ - internal interface IEnumerableSignatures - { - void All(bool predicate); - void Any(); - void Any(bool predicate); - void Average(decimal? selector); - void Average(decimal selector); - void Average(double? selector); - void Average(double selector); - void Average(float? selector); - void Average(float selector); - void Average(int? selector); - void Average(int selector); - void Average(long? selector); - void Average(long selector); - void Cast(string type); - void Cast(Type type); - void Concat(IEnumerable enumerable); - void Contains(object selector); - void Count(); - void Count(bool predicate); - void DefaultIfEmpty(); - void DefaultIfEmpty(object defaultValue); - void Distinct(); - void Except(IEnumerable enumerable); - void First(); - void First(bool predicate); - void FirstOrDefault(); - void FirstOrDefault(bool predicate); - void GroupBy(object keySelector); - void GroupBy(object keySelector, object elementSelector); - void Intersect(IEnumerable enumerable); - void Last(); - void Last(bool predicate); - void LastOrDefault(); - void LastOrDefault(bool predicate); - void LongCount(); - void LongCount(bool predicate); - void Max(object selector); - void Min(object selector); - void OfType(string type); - void OfType(Type type); - void OrderBy(object selector); - void OrderByDescending(object selector); - void Select(object selector); - void SelectMany(object selector); - void Single(); - void Single(bool predicate); - void SingleOrDefault(); - void SingleOrDefault(bool predicate); - void Skip(int count); - void SkipWhile(bool predicate); - void Sum(decimal? selector); - void Sum(decimal selector); - void Sum(double? selector); - void Sum(double selector); - void Sum(float? selector); - void Sum(float selector); - void Sum(int? selector); - void Sum(int selector); - void Sum(long? selector); - void Sum(long selector); - void Take(int count); - void TakeWhile(bool predicate); - void ThenBy(object selector); - void ThenByDescending(object selector); - void Union(IEnumerable enumerable); - void Where(bool predicate); - - // Executors - void ToArray(); - void ToList(); - } -} diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IQueryableSignatures.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IQueryableSignatures.cs deleted file mode 100644 index 60d6dffb..00000000 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/IQueryableSignatures.cs +++ /dev/null @@ -1,73 +0,0 @@ -using System.Collections; - -namespace System.Linq.Dynamic.Core.Parser.SupportedMethods -{ - internal interface IQueryableSignatures - { - void All(bool predicate); - void Any(); - void Any(bool predicate); - void Average(decimal? selector); - void Average(decimal selector); - void Average(double? selector); - void Average(double selector); - void Average(float? selector); - void Average(float selector); - void Average(int? selector); - void Average(int selector); - void Average(long? selector); - void Average(long selector); - void Concat(IEnumerable enumerable); - void Cast(string type); - void Cast(Type type); - void Count(); - void Count(bool predicate); - void DefaultIfEmpty(); - void DefaultIfEmpty(object defaultValue); - void Distinct(); - void Except(IEnumerable enumerable); - void First(); - void First(bool predicate); - void FirstOrDefault(); - void FirstOrDefault(bool predicate); - void GroupBy(object keySelector); - void GroupBy(object keySelector, object elementSelector); - void Intersect(IEnumerable enumerable); - void Last(); - void Last(bool predicate); - void LastOrDefault(); - void LastOrDefault(bool predicate); - void LongCount(); - void LongCount(bool predicate); - void Max(object selector); - void Min(object selector); - void OfType(string type); - void OfType(Type type); - void OrderBy(object selector); - void OrderByDescending(object selector); - void Select(object selector); - void SelectMany(object selector); - void Single(); - void Single(bool predicate); - void SingleOrDefault(); - void SingleOrDefault(bool predicate); - void Skip(int count); - void SkipWhile(bool predicate); - void Sum(decimal? selector); - void Sum(decimal selector); - void Sum(double? selector); - void Sum(double selector); - void Sum(float? selector); - void Sum(float selector); - void Sum(int? selector); - void Sum(int selector); - void Sum(long? selector); - void Sum(long selector); - void Take(int count); - void TakeWhile(bool predicate); - void ThenBy(object selector); - void ThenByDescending(object selector); - void Union(IEnumerable enumerable); - void Where(bool predicate); - } -} diff --git a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs index d5d667b9..98afb028 100644 --- a/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs +++ b/src/System.Linq.Dynamic.Core/Parser/SupportedMethods/MethodFinder.cs @@ -3,29 +3,43 @@ using System.Linq.Expressions; using System.Reflection; -namespace System.Linq.Dynamic.Core.Parser.SupportedMethods +namespace System.Linq.Dynamic.Core.Parser.SupportedMethods; + +internal class MethodFinder { - internal class MethodFinder + private readonly ParsingConfig _parsingConfig; + private readonly IExpressionHelper _expressionHelper; + + public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHelper) { - private readonly ParsingConfig _parsingConfig; - private readonly IExpressionHelper _expressionHelper; + _parsingConfig = Check.NotNull(parsingConfig); + _expressionHelper = Check.NotNull(expressionHelper); + } - public MethodFinder(ParsingConfig parsingConfig, IExpressionHelper expressionHelper) - { - _parsingConfig = Check.NotNull(parsingConfig); - _expressionHelper = Check.NotNull(expressionHelper); - } + public bool ContainsMethod(Type type, string methodName, bool staticAccess = true) + { + Check.NotNull(type); - public bool ContainsMethod(Type type, string methodName, bool staticAccess, Expression? instance, ref Expression[] args) - { - // NOTE: `instance` is not passed by ref in the method signature by design. The ContainsMethod should not change the instance. - // However, args by reference is required for backward compatibility (removing "ref" will break some tests) +#if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) + var flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); + return type.FindMembers(MemberTypes.Method, flags, Type.FilterNameIgnoreCase, methodName).Any(); +#else + return type.GetTypeInfo().DeclaredMethods.Any(m => (m.IsStatic || !staticAccess) && m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase)); +#endif + } - return FindMethod(type, methodName, staticAccess, ref instance, ref args, out _) == 1; - } + public bool ContainsMethod(Type type, string methodName, bool staticAccess, Expression? instance, ref Expression[] args) + { + Check.NotNull(type); - public int FindMethod(Type? type, string methodName, bool staticAccess, ref Expression? instance, ref Expression[] args, out MethodBase? method) - { + // NOTE: `instance` is not passed by ref in the method signature by design. The ContainsMethod should not change the instance. + // However, args by reference is required for backward compatibility (removing "ref" will break some tests) + + return FindMethod(type, methodName, staticAccess, ref instance, ref args, out _) == 1; + } + + public int FindMethod(Type? type, string methodName, bool staticAccess, ref Expression? instance, ref Expression[] args, out MethodBase? method) + { #if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) BindingFlags flags = BindingFlags.Public | BindingFlags.DeclaredOnly | (staticAccess ? BindingFlags.Static : BindingFlags.Instance); foreach (Type t in SelfAndBaseTypes(type)) @@ -38,327 +52,327 @@ public int FindMethod(Type? type, string methodName, bool staticAccess, ref Expr } } #else - foreach (Type t in SelfAndBaseTypes(type)) + foreach (Type t in SelfAndBaseTypes(type)) + { + var methods = t.GetTypeInfo().DeclaredMethods.Where(m => (m.IsStatic || !staticAccess) && m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase)).ToArray(); + int count = FindBestMethodBasedOnArguments(methods, ref args, out method); + if (count != 0) { - var methods = t.GetTypeInfo().DeclaredMethods.Where(m => (m.IsStatic || !staticAccess) && m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase)).ToArray(); - int count = FindBestMethodBasedOnArguments(methods, ref args, out method); - if (count != 0) - { - return count; - } + return count; } + } #endif - if (instance != null) + if (instance != null) + { + // Try to solve with registered extension methods from this type and all base types + var methods = new List(); + foreach (var t in SelfAndBaseTypes(type)) { - // Try to solve with registered extension methods from this type and all base types - var methods = new List(); - foreach (var t in SelfAndBaseTypes(type)) + if (_parsingConfig.CustomTypeProvider.GetExtensionMethods().TryGetValue(t, out var extensionMethodsOfType)) { - if (_parsingConfig.CustomTypeProvider.GetExtensionMethods().TryGetValue(t, out var extensionMethodsOfType)) - { - methods.AddRange(extensionMethodsOfType.Where(m => m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase))); - } + methods.AddRange(extensionMethodsOfType.Where(m => m.Name.Equals(methodName, StringComparison.OrdinalIgnoreCase))); } + } - if (methods.Any()) - { - var argsList = args.ToList(); - argsList.Insert(0, instance); + if (methods.Any()) + { + var argsList = args.ToList(); + argsList.Insert(0, instance); - var extensionMethodArgs = argsList.ToArray(); + var extensionMethodArgs = argsList.ToArray(); - // ReSharper disable once RedundantEnumerableCastCall - int count = FindBestMethodBasedOnArguments(methods.Cast(), ref extensionMethodArgs, out method); - if (count != 0) - { - instance = null; - args = extensionMethodArgs; - return count; - } + // ReSharper disable once RedundantEnumerableCastCall + int count = FindBestMethodBasedOnArguments(methods.Cast(), ref extensionMethodArgs, out method); + if (count != 0) + { + instance = null; + args = extensionMethodArgs; + return count; } } - - method = null; - return 0; } - public int FindBestMethodBasedOnArguments(IEnumerable methods, ref Expression[] args, out MethodBase? method) + method = null; + return 0; + } + + public int FindBestMethodBasedOnArguments(IEnumerable methods, ref Expression[] args, out MethodBase? method) + { + // Passing args by reference is now required with the params array support. + var inlineArgs = args; + + MethodData[] applicable = methods + .Select(m => new MethodData { MethodBase = m, Parameters = m.GetParameters() }) + .Where(m => IsApplicable(m, inlineArgs)) + .ToArray(); + + if (applicable.Length > 1) { - // Passing args by reference is now required with the params array support. - var inlineArgs = args; + applicable = applicable.Where(m => applicable.All(n => m == n || FirstIsBetterThanSecond(inlineArgs, m, n))).ToArray(); + } - MethodData[] applicable = methods - .Select(m => new MethodData { MethodBase = m, Parameters = m.GetParameters() }) - .Where(m => IsApplicable(m, inlineArgs)) - .ToArray(); + if (args.Length == 2 && applicable.Length > 1 && (args[0].Type == typeof(Guid?) || args[1].Type == typeof(Guid?))) + { + applicable = applicable.Take(1).ToArray(); + } - if (applicable.Length > 1) + if (applicable.Length == 1) + { + var methodData = applicable[0]; + if (methodData.MethodBase is MethodInfo methodInfo) { - applicable = applicable.Where(m => applicable.All(n => m == n || FirstIsBetterThanSecond(inlineArgs, m, n))).ToArray(); + method = methodInfo.GetBaseDefinition(); } - - if (args.Length == 2 && applicable.Length > 1 && (args[0].Type == typeof(Guid?) || args[1].Type == typeof(Guid?))) + else { - applicable = applicable.Take(1).ToArray(); + method = methodData.MethodBase; } - if (applicable.Length == 1) + if (args.Length == 0 || args.Length != methodData.Args.Length) { - var methodData = applicable[0]; - if (methodData.MethodBase is MethodInfo methodInfo) - { - method = methodInfo.GetBaseDefinition(); - } - else - { - method = methodData.MethodBase; - } - - if (args.Length == 0 || args.Length != methodData.Args.Length) - { - args = methodData.Args; - } - else + args = methodData.Args; + } + else + { + for (var i = 0; i < args.Length; i++) { - for (var i = 0; i < args.Length; i++) + if (args[i].Type != methodData.Args[i].Type && + args[i].Type.IsArray && methodData.Args[i].Type.IsArray && + args[i].Type != typeof(string) && methodData.Args[i].Type == typeof(object[])) { - if (args[i].Type != methodData.Args[i].Type && - args[i].Type.IsArray && methodData.Args[i].Type.IsArray && - args[i].Type != typeof(string) && methodData.Args[i].Type == typeof(object[])) - { - args[i] = _expressionHelper.ConvertAnyArrayToObjectArray(args[i]); - } - else - { - args[i] = methodData.Args[i]; - } + args[i] = _expressionHelper.ConvertAnyArrayToObjectArray(args[i]); + } + else + { + args[i] = methodData.Args[i]; } } } - else - { - method = null; - } - - return applicable.Length; } + else + { + method = null; + } + + return applicable.Length; + } - public int FindIndexer(Type type, Expression[] args, out MethodBase? method) + public int FindIndexer(Type type, Expression[] args, out MethodBase? method) + { + foreach (Type t in SelfAndBaseTypes(type)) { - foreach (Type t in SelfAndBaseTypes(type)) + MemberInfo[] members = t.GetDefaultMembers(); + if (members.Length != 0) { - MemberInfo[] members = t.GetDefaultMembers(); - if (members.Length != 0) - { - IEnumerable methods = members.OfType(). + IEnumerable methods = members.OfType(). #if !(NETFX_CORE || WINDOWS_APP || UAP10_0 || NETSTANDARD) Select(p => (MethodBase)p.GetGetMethod()). Where(m => m != null); #else Select(p => (MethodBase)p.GetMethod); #endif - int count = FindBestMethodBasedOnArguments(methods, ref args, out method); - if (count != 0) - { - return count; - } + int count = FindBestMethodBasedOnArguments(methods, ref args, out method); + if (count != 0) + { + return count; } } - - method = null; - return 0; } - private bool IsApplicable(MethodData method, Expression[] args) - { - bool isParamArray = method.Parameters.Length > 0 && method.Parameters.Last().IsDefined(typeof(ParamArrayAttribute), false); + method = null; + return 0; + } - // if !paramArray, the number of parameter must be equal - // if paramArray, the last parameter is optional - if ((!isParamArray && method.Parameters.Length != args.Length) || - (isParamArray && method.Parameters.Length - 1 > args.Length)) - { - return false; - } + private bool IsApplicable(MethodData method, Expression[] args) + { + bool isParamArray = method.Parameters.Length > 0 && method.Parameters.Last().IsDefined(typeof(ParamArrayAttribute), false); + + // if !paramArray, the number of parameter must be equal + // if paramArray, the last parameter is optional + if ((!isParamArray && method.Parameters.Length != args.Length) || + (isParamArray && method.Parameters.Length - 1 > args.Length)) + { + return false; + } - Expression[] promotedArgs = new Expression[method.Parameters.Length]; - for (int i = 0; i < method.Parameters.Length; i++) + Expression[] promotedArgs = new Expression[method.Parameters.Length]; + for (int i = 0; i < method.Parameters.Length; i++) + { + if (isParamArray && i == method.Parameters.Length - 1) { - if (isParamArray && i == method.Parameters.Length - 1) + if (method.Parameters.Length == args.Length + 1 + || (method.Parameters.Length == args.Length && args[i] is ConstantExpression constantExpression && constantExpression.Value == null)) { - if (method.Parameters.Length == args.Length + 1 - || (method.Parameters.Length == args.Length && args[i] is ConstantExpression constantExpression && constantExpression.Value == null)) - { - promotedArgs[promotedArgs.Length - 1] = Expression.Constant(null, method.Parameters.Last().ParameterType); - } - else if (method.Parameters.Length == args.Length && method.Parameters.Last().ParameterType == args.Last().Type) - { - promotedArgs[promotedArgs.Length - 1] = args.Last(); - } - else - { - var paramType = method.Parameters.Last().ParameterType; - var paramElementType = paramType.GetElementType()!; - - var arrayInitializerExpressions = new List(); - - for (int j = method.Parameters.Length - 1; j < args.Length; j++) - { - var promotedExpression = _parsingConfig.ExpressionPromoter.Promote(args[j], paramElementType, false, method.MethodBase.DeclaringType != typeof(IEnumerableSignatures)); - if (promotedExpression == null) - { - return false; - } - - arrayInitializerExpressions.Add(promotedExpression); - } - - var paramExpression = Expression.NewArrayInit(paramElementType, arrayInitializerExpressions); - - promotedArgs[promotedArgs.Length - 1] = paramExpression; - } + promotedArgs[promotedArgs.Length - 1] = Expression.Constant(null, method.Parameters.Last().ParameterType); + } + else if (method.Parameters.Length == args.Length && method.Parameters.Last().ParameterType == args.Last().Type) + { + promotedArgs[promotedArgs.Length - 1] = args.Last(); } else { - var methodParameter = method.Parameters[i]; - if (methodParameter.IsOut && args[i] is ParameterExpression parameterExpression) - { -#if NET35 - return false; -#else - if (!parameterExpression.IsByRef) - { - return false; - } + var paramType = method.Parameters.Last().ParameterType; + var paramElementType = paramType.GetElementType()!; - promotedArgs[i] = Expression.Parameter(methodParameter.ParameterType, methodParameter.Name); -#endif - } - else + var arrayInitializerExpressions = new List(); + + for (int j = method.Parameters.Length - 1; j < args.Length; j++) { - var promotedExpression = _parsingConfig.ExpressionPromoter.Promote(args[i], methodParameter.ParameterType, false, method.MethodBase.DeclaringType != typeof(IEnumerableSignatures)); + var promotedExpression = _parsingConfig.ExpressionPromoter.Promote(args[j], paramElementType, false, true); if (promotedExpression == null) { return false; } - promotedArgs[i] = promotedExpression; + arrayInitializerExpressions.Add(promotedExpression); } - } - } - method.Args = promotedArgs; - return true; - } + var paramExpression = Expression.NewArrayInit(paramElementType, arrayInitializerExpressions); - bool FirstIsBetterThanSecond(Expression[] args, MethodData first, MethodData second) - { - // If args count is 0 -> parameterless method is better than method method with parameters - if (args.Length == 0) - { - return first.Parameters.Length == 0 && second.Parameters.Length != 0; + promotedArgs[promotedArgs.Length - 1] = paramExpression; + } } - - bool better = false; - for (int i = 0; i < args.Length; i++) + else { - CompareConversionType result = CompareConversions(args[i].Type, first.Parameters[i].ParameterType, second.Parameters[i].ParameterType); - - // If second is better, return false - if (result == CompareConversionType.Second) + var methodParameter = method.Parameters[i]; + if (methodParameter.IsOut && args[i] is ParameterExpression parameterExpression) { - return false; - } +#if NET35 + return false; +#else + if (!parameterExpression.IsByRef) + { + return false; + } - // If first is better, return true - if (result == CompareConversionType.First) - { - return true; + promotedArgs[i] = Expression.Parameter(methodParameter.ParameterType, methodParameter.Name); +#endif } - - // If both are same, just set better to true and continue - if (result == CompareConversionType.Both) + else { - better = true; + var promotedExpression = _parsingConfig.ExpressionPromoter.Promote(args[i], methodParameter.ParameterType, false, true); + if (promotedExpression == null) + { + return false; + } + + promotedArgs[i] = promotedExpression; } } - - return better; } - // Return "First" if s -> t1 is a better conversion than s -> t2 - // Return "Second" if s -> t2 is a better conversion than s -> t1 - // Return "Both" if neither conversion is better - CompareConversionType CompareConversions(Type source, Type first, Type second) + method.Args = promotedArgs; + return true; + } + + private static bool FirstIsBetterThanSecond(Expression[] args, MethodData first, MethodData second) + { + // If args count is 0 -> parameterless method is better than method method with parameters + if (args.Length == 0) { - if (first == second) - { - return CompareConversionType.Both; - } - if (source == first) - { - return CompareConversionType.First; - } - if (source == second) - { - return CompareConversionType.Second; - } + return first.Parameters.Length == 0 && second.Parameters.Length != 0; + } - bool firstIsCompatibleWithSecond = TypeHelper.IsCompatibleWith(first, second); - bool secondIsCompatibleWithFirst = TypeHelper.IsCompatibleWith(second, first); + bool better = false; + for (int i = 0; i < args.Length; i++) + { + CompareConversionType result = CompareConversions(args[i].Type, first.Parameters[i].ParameterType, second.Parameters[i].ParameterType); - if (firstIsCompatibleWithSecond && !secondIsCompatibleWithFirst) + // If second is better, return false + if (result == CompareConversionType.Second) { - return CompareConversionType.First; - } - if (secondIsCompatibleWithFirst && !firstIsCompatibleWithSecond) - { - return CompareConversionType.Second; + return false; } - if (TypeHelper.IsSignedIntegralType(first) && TypeHelper.IsUnsignedIntegralType(second)) + // If first is better, return true + if (result == CompareConversionType.First) { - return CompareConversionType.First; + return true; } - if (TypeHelper.IsSignedIntegralType(second) && TypeHelper.IsUnsignedIntegralType(first)) + + // If both are same, just set better to true and continue + if (result == CompareConversionType.Both) { - return CompareConversionType.Second; + better = true; } + } + + return better; + } + // Return "First" if s -> t1 is a better conversion than s -> t2 + // Return "Second" if s -> t2 is a better conversion than s -> t1 + // Return "Both" if neither conversion is better + private static CompareConversionType CompareConversions(Type source, Type first, Type second) + { + if (first == second) + { return CompareConversionType.Both; } + if (source == first) + { + return CompareConversionType.First; + } + if (source == second) + { + return CompareConversionType.Second; + } - IEnumerable SelfAndBaseTypes(Type? type) + bool firstIsCompatibleWithSecond = TypeHelper.IsCompatibleWith(first, second); + bool secondIsCompatibleWithFirst = TypeHelper.IsCompatibleWith(second, first); + + if (firstIsCompatibleWithSecond && !secondIsCompatibleWithFirst) { - if (type?.GetTypeInfo().IsInterface == true) - { - var types = new List(); - AddInterface(types, type); - return types; - } - return SelfAndBaseClasses(type); + return CompareConversionType.First; + } + if (secondIsCompatibleWithFirst && !firstIsCompatibleWithSecond) + { + return CompareConversionType.Second; } - IEnumerable SelfAndBaseClasses(Type? type) + if (TypeHelper.IsSignedIntegralType(first) && TypeHelper.IsUnsignedIntegralType(second)) { - while (type != null) - { - yield return type; - type = type.GetTypeInfo().BaseType; - } + return CompareConversionType.First; } + if (TypeHelper.IsSignedIntegralType(second) && TypeHelper.IsUnsignedIntegralType(first)) + { + return CompareConversionType.Second; + } + + return CompareConversionType.Both; + } - void AddInterface(List types, Type type) + private static IEnumerable SelfAndBaseTypes(Type? type) + { + if (type?.GetTypeInfo().IsInterface == true) { - if (!types.Contains(type)) + var types = new List(); + AddInterfaces(types, type); + return types; + } + + return SelfAndBaseClasses(type); + } + + private static IEnumerable SelfAndBaseClasses(Type? type) + { + while (type != null) + { + yield return type; + type = type.GetTypeInfo().BaseType; + } + } + + private static void AddInterfaces(ICollection types, Type type) + { + if (!types.Contains(type)) + { + types.Add(type); + foreach (var interfaceType in type.GetInterfaces()) { - types.Add(type); - foreach (Type t in type.GetInterfaces()) - { - AddInterface(types, t); - } + AddInterfaces(types, interfaceType); } } } -} +} \ No newline at end of file diff --git a/src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs b/src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs index 7b8dfaca..0b067e09 100644 --- a/src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs +++ b/src/System.Linq.Dynamic.Core/Parser/TypeHelper.cs @@ -446,7 +446,7 @@ public static Type GetUnderlyingType(Type type) public static IList GetSelfAndBaseTypes(Type type, bool excludeObject = false) { - Check.NotNull(type, nameof(type)); + Check.NotNull(type); if (type.GetTypeInfo().IsInterface) { @@ -458,7 +458,7 @@ public static IList GetSelfAndBaseTypes(Type type, bool excludeObject = fa return GetSelfAndBaseClasses(type).Where(t => !excludeObject || t != typeof(object)).ToList(); } - private static IEnumerable GetSelfAndBaseClasses(Type type) + private static IEnumerable GetSelfAndBaseClasses(Type? type) { while (type != null) { diff --git a/src/System.Linq.Dynamic.Core/Res.cs b/src/System.Linq.Dynamic.Core/Res.cs index faaaa21a..4b074e08 100644 --- a/src/System.Linq.Dynamic.Core/Res.cs +++ b/src/System.Linq.Dynamic.Core/Res.cs @@ -55,7 +55,6 @@ internal static class Res public const string MissingAsClause = "Expression is missing an 'as' clause"; public const string NeitherTypeConvertsToOther = "Neither of the types '{0}' and '{1}' converts to the other"; public const string NewOperatorIsNotAllowed = "Using the new operator is not allowed via the ParsingConfig."; - public const string NoApplicableAggregate = "No applicable aggregate method '{0}({1})' exists"; public const string NoApplicableIndexer = "No applicable indexer exists in type '{0}'"; public const string NoApplicableMethod = "No applicable method '{0}' exists in type '{1}'"; public const string NoItInScope = "No 'it' is in scope"; diff --git a/test/System.Linq.Dynamic.Core.Tests/Issues/Issue775.cs b/test/System.Linq.Dynamic.Core.Tests/Issues/Issue775.cs new file mode 100644 index 00000000..14aa13b9 --- /dev/null +++ b/test/System.Linq.Dynamic.Core.Tests/Issues/Issue775.cs @@ -0,0 +1,52 @@ +using NFluent; +using System.Linq.Dynamic.Core.Tests.Helpers.Models; +using FluentAssertions; +using Xunit; + +namespace System.Linq.Dynamic.Core.Tests; + +public partial class QueryableTests +{ + [Fact] + public void Issue775a() + { + // Arrange + var users = User.GenerateSampleModels(10); + + // Act + var realResult = users.Where(x => x.Income == users.Select(p => p.Income).Min()).Select(x => x.Id).ToArray(); + var result = users.AsQueryable().Where("Income == @0.Select(Income).Min()", users).Select("Id"); + + // Assert + Check.That(result.ToDynamicArray().Cast()).ContainsExactly(realResult); + } + + [Fact] + public void Issue775b() + { + // Arrange + var users = User.GenerateSampleModels(10); + var pets = new[] { new Pet() }.AsQueryable(); + + // Act + var realResult = users.Where(x => x.Income == pets.Select(p => p.Id).FirstOrDefault()).Select(x => x.Id).ToArray(); + var result = users.AsQueryable().Where("Income == @0.Select(Id).FirstOrDefault()", pets).Select("Id"); + + // Assert + Check.That(result.ToDynamicArray().Cast()).ContainsExactly(realResult); + } + + [Fact] + public void Issue775_Exception() + { + // Arrange + var users = User.GenerateSampleModels(10); + var pets = new[] { new Pet() }.AsQueryable(); + + // Act + Action act = () => users.AsQueryable().Where("Income == @0.Select(Id).XXX()", pets); + + // Assert + act.Should().Throw(); + } +} \ No newline at end of file