JasperFx / lamar

Fast Inversion of Control Tool and Successor to StructureMap
https://jasperfx.github.io/lamar
MIT License

Scope validation like in .NET Core DI #383

Closed: DrMueller closed this issue 11 months ago

DrMueller commented 1 year ago

Hello there

I'm a big fan of Lamar and use it in every project. Recently, however, I've run into an issue. The default service provider of Microsoft.Extensions.DependencyInjection performs scope validation (https://learn.microsoft.com/en-us/dotnet/core/extensions/dependency-injection#scope-validation), which, for example, prevents scoped services from being injected into singleton services. Since we have had some problems in that area, it would be great to have something similar in Lamar. Is there something I'm not aware of? If not, I would be happy to build it myself, provided you think the necessary extension points exist.

With best regards

Matthias
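For comparison, here is a minimal sketch of what Microsoft's built-in container does. The names IRequestContext, RequestContext, ICache, Cache and ScopeValidationDemo are hypothetical. The sketch turns on ServiceProviderOptions.ValidateScopes and ValidateOnBuild, so BuildServiceProvider throws an AggregateException because a singleton consumes a scoped service. When a host is built with Host.CreateDefaultBuilder, both options are switched on automatically in the Development environment. They configure Microsoft's provider, though, and that provider is exactly the piece Lamar replaces.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.DependencyInjection;

// Hypothetical example types: a per-request context and a cache that captures it.
public interface IRequestContext { }
public sealed class RequestContext : IRequestContext { }
public interface ICache { }
public sealed class Cache : ICache
{
    public Cache(IRequestContext context) => Context = context;
    public IRequestContext Context { get; }
}

public static class ScopeValidationDemo
{
    // Registers a singleton that depends on a scoped service, builds the provider with
    // Microsoft's validation enabled and returns the validation messages (empty if none).
    public static IReadOnlyList<string> BuildWithValidation()
    {
        var services = new ServiceCollection();
        services.AddScoped<IRequestContext, RequestContext>();
        services.AddSingleton<ICache, Cache>();

        try
        {
            using var provider = services.BuildServiceProvider(new ServiceProviderOptions
            {
                ValidateScopes = true,  // scoped services must not be consumed by singletons or the root provider
                ValidateOnBuild = true  // run the checks while building instead of on first resolve
            });
            return Array.Empty<string>();
        }
        catch (AggregateException ex)
        {
            return ex.InnerExceptions.Select(e => e.Message).ToList();
        }
    }
}
```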

rizi commented 11 months ago

@DrMueller

Maybe this helps: just call this extension method after the container has been built.

Note: IsTypeFromUpperAssembly restricts the check to our own classes (they all live in assemblies whose names start with "Upper.").

using System.Reflection;

using Lamar;
using Lamar.Diagnostics;
using Lamar.IoC;

using Microsoft.Extensions.DependencyInjection;

using Upper.Essentials.Registration.Lamar;

using static System.FormattableString;

namespace Upper.Essentials.Lamar.Extensions;

/// <summary>
/// Ensures recursively that a singleton does not inject a type that is registered with lifetime scoped.
/// </summary>
public static class CheckLifetimeExtension
{
    /// <summary>
    /// Ensures recursively that a singleton does not inject a type that is registered with lifetime scoped.
    /// </summary>
    public static void CheckLifetimeRegistrations(this IContainer container)
    {
        if (container == null)
            throw new ArgumentNullException(nameof(container));

        IReadOnlyCollection<InstanceRef> registrations = container.Model.AllInstances.ToList();

        IReadOnlyCollection<Type> singletonTypes = GetSingletonTypes(registrations);
        IReadOnlyCollection<Type> lifetimeScopedTypes = GetScopedTypes(registrations);

        // Each singleton gets its own "already checked" set. A set shared across all singletons
        // would skip dependencies already visited for an earlier singleton and miss later offenders.
        IReadOnlyCollection<Type> singletonTypesWithInjectedLifetimeScopeTypes = singletonTypes
            .Where(singletonType => IsScopedUsedForAtLeastOneTypeInAtLeastOneConstructor(registrations, singletonType, lifetimeScopedTypes, new HashSet<Type>()))
            .ToList();

        if (!singletonTypesWithInjectedLifetimeScopeTypes.Any())
            return;

        string types = string.Join(Environment.NewLine, singletonTypesWithInjectedLifetimeScopeTypes.Select(type => Invariant($"Type: {type.FullName}")));
        string errorMessage = Invariant($"Types with lifetime of Singleton must not inject a scoped instance:{Environment.NewLine}{types}");

        throw new LamarException(errorMessage);
    }

    private static IReadOnlyCollection<Type> GetScopedTypes(IEnumerable<InstanceRef> registrations)
    {
        return registrations
            .Where(registration => registration.Lifetime == ServiceLifetime.Scoped)
            .Select(registration => registration.ImplementationType)
            .ToList();
    }

    private static IReadOnlyCollection<Type> GetSingletonTypes(IEnumerable<InstanceRef> registrations)
    {
        return registrations
            .Where(registration => registration.Lifetime == ServiceLifetime.Singleton)
            .Where(IsTypeFromUpperAssembly)
            .Select(registration => registration.ImplementationType)
            .Where(type => type.GetConstructors(BindingFlags.Instance | BindingFlags.Public | BindingFlags.NonPublic).Any(constructorInfo => constructorInfo.GetParameters().Any()))
            .Distinct()
            .ToList();
    }

    private static bool IsTypeFromUpperAssembly(InstanceRef instance)
    {
        bool IsFromUpperAssembly(string name)
        {
            return name != null
                   && name.StartsWith("Upper.", StringComparison.OrdinalIgnoreCase);
        }

        string serviceTypeName = instance.ServiceType?.Assembly.GetName().Name;
        string implementationTypeName = instance.ImplementationType?.Assembly.GetName().Name;

        return IsFromUpperAssembly(serviceTypeName) || IsFromUpperAssembly(implementationTypeName);
    }

    private static bool IsScopedUsedForAtLeastOneTypeInAtLeastOneConstructor(IReadOnlyCollection<InstanceRef> registrations, Type concreteType, IReadOnlyCollection<Type> allScopedTypes, HashSet<Type> typesAlreadyChecked)
    {
        typesAlreadyChecked.Add(concreteType);

        IReadOnlyCollection<Type> typesFromConstructors = concreteType.GetConstructors()
            .SelectMany(constructor => constructor.GetParameters().Select(parameterInfo => parameterInfo.ParameterType))
            .Where(parameterType => !typesAlreadyChecked.Contains(parameterType))
            .Select(parameterType =>
                    {
                        typesAlreadyChecked.Add(parameterType);

                        return parameterType;
                    })
            .ToList();

        return typesFromConstructors.Any(type => IsLifetimeScopeUsed(registrations, type, allScopedTypes, typesAlreadyChecked)

                                                 //we must handle generics! Func<T>, IEnumerable<T>,...
                                                 || type.IsGenericType && type.GetGenericArguments().Any(genericArgument => IsLifetimeScopeUsed(registrations, genericArgument, allScopedTypes, typesAlreadyChecked))
        );
    }

    private static bool IsLifetimeScopeUsed(IReadOnlyCollection<InstanceRef> registrations, Type serviceTypeToCheck, IReadOnlyCollection<Type> allScopedTypes, HashSet<Type> typesAlreadyChecked)
    {
        IReadOnlyCollection<Type> serviceTypesToCheck = registrations
            .Where(registration => registration.ServiceType == serviceTypeToCheck || registration.ImplementationType == serviceTypeToCheck)
            .Select(registration => registration.ImplementationType)
            .ToList();

        return allScopedTypes.Any(scopedType => serviceTypesToCheck.Contains(scopedType))
               || serviceTypesToCheck.Any(currentServiceTypeToCheck => IsScopedUsedForAtLeastOneTypeInAtLeastOneConstructor(registrations, currentServiceTypeToCheck, allScopedTypes, typesAlreadyChecked));
    }
}

br

DrMueller commented 11 months ago

@rizi Hey, that works like a charm, thank you very much! I could even make it much simpler, since I can rely on a few invariants in our DI setup, but great stuff overall. Wouldn't this be a cool feature to have directly in the API? I think most people could use it, as tracking down scoping errors isn't much fun.

One question: do you know why usually only singleton -> scoped is checked? I'd think the other combinations (singleton -> transient, scoped -> transient) don't make much sense either, but Microsoft doesn't verify those.
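A common explanation for why singleton -> scoped gets special treatment is the "captive dependency" problem. A singleton is created once, so the scoped instance it receives lives for the lifetime of the application and is shared by every scope. The sketch below shows this with Microsoft's built-in provider and validation left off, which is the default for BuildServiceProvider(). UnitOfWork, ReportCache and CaptiveDependencyDemo are hypothetical names. A transient captured by a singleton is held in the same way, but a transient is not tied to any scope. That is arguably why only singleton -> scoped is validated by default.

```csharp
using System;
using Microsoft.Extensions.DependencyInjection;

// Hypothetical types for illustration only.
public sealed class UnitOfWork { }
public sealed class ReportCache
{
    public ReportCache(UnitOfWork unitOfWork) => UnitOfWork = unitOfWork;
    public UnitOfWork UnitOfWork { get; }
}

public static class CaptiveDependencyDemo
{
    public static (bool SameAcrossScopes, bool MatchesScopeInstance) Run()
    {
        var services = new ServiceCollection();
        services.AddScoped<UnitOfWork>();
        services.AddSingleton<ReportCache>();

        // Default options: ValidateScopes is false, so nothing prevents the capture.
        using var provider = services.BuildServiceProvider();
        using var scope1 = provider.CreateScope();
        using var scope2 = provider.CreateScope();

        // The singleton (and the UnitOfWork it captured) is the same in both scopes.
        var fromScope1 = scope1.ServiceProvider.GetRequiredService<ReportCache>().UnitOfWork;
        var fromScope2 = scope2.ServiceProvider.GetRequiredService<ReportCache>().UnitOfWork;

        // Scope 1's own UnitOfWork is a different instance from the captured one.
        var scope1Own = scope1.ServiceProvider.GetRequiredService<UnitOfWork>();

        return (ReferenceEquals(fromScope1, fromScope2), ReferenceEquals(fromScope1, scope1Own));
    }
}
```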

rizi commented 11 months ago

@DrMueller You're absolutely right, checking other combinations could make sense as well. I'm just not sure that forbidding Singleton/Scoped -> Transient really makes sense, so I only implemented Singleton -> Scoped.

And yes, it would be great if this were built in. But Autofac doesn't have this feature either. As far as I remember, that's because of some edge cases that could cause problems, at least if there is no way to exclude certain types from the check (which should be very easy to implement).

I'd be curious how you simplified it (maybe the same is possible on my side).

Br
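The opt-out rizi mentions could be sketched like this. CaptiveDependencyCheck.FindViolations is a hypothetical helper, not part of Lamar or Microsoft.Extensions.DependencyInjection. It works on plain ServiceDescriptor entries, so it applies to any IServiceCollection, including a Lamar ServiceRegistry. It assumes the 8.0+ abstractions (for IsKeyedService), only inspects direct constructor parameters, and skips descriptors without an ImplementationType (factories, instances, and possibly some Lamar-specific registrations).

```csharp
using System;
using System.Collections.Generic;
using System.Linq;
using Microsoft.Extensions.DependencyInjection;

public static class CaptiveDependencyCheck
{
    // Hypothetical helper: reports "Singleton -> ScopedDependency" for every singleton whose
    // constructor takes a service type registered as scoped. Types for which 'exclude'
    // returns true are ignored on either side of the dependency.
    public static IReadOnlyList<string> FindViolations(
        IEnumerable<ServiceDescriptor> services,
        Func<Type, bool>? exclude = null)
    {
        bool IsExcluded(Type type) => exclude != null && exclude(type);

        // Keyed descriptors throw when ImplementationType is read, so leave them out.
        var descriptors = services.Where(d => !d.IsKeyedService).ToList();

        var scopedServiceTypes = new HashSet<Type>(descriptors
            .Where(d => d.Lifetime == ServiceLifetime.Scoped)
            .Select(d => d.ServiceType));

        var violations = new List<string>();
        foreach (var singleton in descriptors.Where(d => d.Lifetime == ServiceLifetime.Singleton))
        {
            var implementation = singleton.ImplementationType;
            if (implementation is null || IsExcluded(implementation))
                continue;

            var dependencies = implementation.GetConstructors()
                .SelectMany(constructor => constructor.GetParameters())
                .Select(parameter => parameter.ParameterType)
                .Distinct();

            foreach (var dependency in dependencies)
            {
                if (scopedServiceTypes.Contains(dependency) && !IsExcluded(dependency))
                    violations.Add($"{implementation.FullName} -> {dependency.FullName}");
            }
        }

        return violations;
    }
}
```

In a Lamar setup, this could run against the ServiceRegistry before the container is built. An exclude predicate such as type => type.Namespace?.StartsWith("Microsoft") == true would mirror the framework filter used in the test further down.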

DrMueller commented 11 months ago

@rizi Here is my current attempt:

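using System;
using System.Collections.Generic;
using System.Linq;
using System.Text;
using FluentAssertions;
using Lamar;
using Lamar.Diagnostics;
using Lamar.IoC;
using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace QualityTests.TestingAreas.CrossCutting.DependencyInjection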
{
    public partial class DependencyInjectionTests : QualityTestBase
    {
        public DependencyInjectionTests(QualityTestFixture fixture) : base(fixture)
        {
        }

        private static IReadOnlyCollection<Type> GetAllConstructorTypes(Type type)
        {
            return type.GetConstructors()
                .SelectMany(constructor => constructor.GetParameters().Select(parameterInfo => parameterInfo.ParameterType))
                .ToList();
        }

        private static IReadOnlyCollection<Type> GetImplementationsOfLifetime(IEnumerable<InstanceRef> registrations, ServiceLifetime lifetime)
        {
            return registrations
                .Where(registration => registration.Lifetime == lifetime)
                .Select(registration => registration.ImplementationType)
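                // Skip framework registrations; only our own types are of interest.
                .Where(type => !type.FullName!.StartsWith("Microsoft", StringComparison.Ordinal))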
                .Distinct()
                .ToList();
        }

        private static IReadOnlyCollection<Type> GetInterfacesOfLifetime(IEnumerable<InstanceRef> registrations, params ServiceLifetime[] lifetimes)
        {
            return registrations
                .Where(registration => lifetimes.Contains(registration.Lifetime))
                .Select(registration => registration.ServiceType)
                .Distinct()
                .ToList();
        }

        private void AssertInjectionScoping(
            ServiceLifetime lifeTimeToTest,
            params ServiceLifetime[] notAllowedLifeTimes)
        {
            var registrations = LoadRegistrations();
            var typesToCheck = GetImplementationsOfLifetime(registrations, lifeTimeToTest).ToList();
            var forbiddenInterfaces = GetInterfacesOfLifetime(registrations, notAllowedLifeTimes).ToList();

            var sb = new StringBuilder();

            foreach (var typeToCheck in typesToCheck)
            {
                var ctorTypes = GetAllConstructorTypes(typeToCheck);
                var typesUsed = ctorTypes.Intersect(forbiddenInterfaces).ToList();

                foreach (var type in typesUsed)
                {
                    sb.AppendLine($"{typeToCheck.FullName} -> {type.FullName}");
                }
            }

            var str = sb.ToString();
            if (!string.IsNullOrEmpty(str))
            {
                Assert.Fail(str);
            }
        }

        private IReadOnlyCollection<InstanceRef> LoadRegistrations()
        {
            var serviceContainer = AppFactory.Services;
            serviceContainer.Should().BeOfType<Container>();
            var container = (IContainer)serviceContainer;
            return container.Model.AllInstances.ToList();
        }
    }
}

using Microsoft.Extensions.DependencyInjection;
using Xunit;

namespace QualityTests.TestingAreas.CrossCutting.DependencyInjection
{
    public partial class DependencyInjectionTests
    {
        [Fact]
        public void SingletonTypes_UseOnlyOtherSingletonTypes()
        {
            AssertInjectionScoping(
                ServiceLifetime.Singleton,
                ServiceLifetime.Scoped,
                ServiceLifetime.Transient);
        }

        [Fact]
        public void ScopedTypes_DoNotUseTransientTypes()
        {
            AssertInjectionScoping(
                ServiceLifetime.Scoped,
                ServiceLifetime.Transient);
        }
    }
}

I could drop some of the Where filters. I guess that's possible because we don't have many special cases (except Microsoft 😄).
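This test differs from rizi's version in one respect: it compares constructor parameter types as-is. A wrapped dependency such as IEnumerable<IFoo>, Func<IFoo> or Lazy<IFoo> (IFoo being a hypothetical scoped service) therefore never matches the forbidden service type IFoo. rizi's code covers this by also inspecting generic arguments. A hypothetical replacement for GetAllConstructorTypes that does the same could look like this. ConstructorDependencies and its method name are illustrative, not from the thread.

```csharp
using System;
using System.Collections.Generic;
using System.Linq;

public static class ConstructorDependencies
{
    // Hypothetical drop-in for GetAllConstructorTypes: returns every public constructor
    // parameter type and, recursively, the generic arguments of constructed generic types,
    // so IEnumerable<IFoo> or Func<IFoo> also yields IFoo.
    public static IReadOnlyCollection<Type> GetConstructorTypesWithGenericArguments(Type type)
    {
        var result = new HashSet<Type>();
        var parameterTypes = type.GetConstructors()
            .SelectMany(constructor => constructor.GetParameters())
            .Select(parameter => parameter.ParameterType);

        foreach (var parameterType in parameterTypes)
            AddWithGenericArguments(parameterType, result);

        return result;
    }

    private static void AddWithGenericArguments(Type type, HashSet<Type> result)
    {
        // HashSet.Add returns false for types already seen, which also ends the recursion early.
        if (!result.Add(type) || !type.IsGenericType)
            return;

        foreach (var argument in type.GetGenericArguments())
            AddWithGenericArguments(argument, result);
    }
}
```

In AssertInjectionScoping, GetAllConstructorTypes(typeToCheck) would then be replaced by ConstructorDependencies.GetConstructorTypesWithGenericArguments(typeToCheck). The existing Intersect call stays unchanged. Whether each wrapper really captures the inner service depends on how the container resolves it, so matches are best treated as candidates to review.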