Azure / azure-functions-dotnet-worker

Azure Functions out-of-process .NET language worker
MIT License

BindInputAsync caches without regard to type, which leads to invalid cast later #2332

Open stevendarby opened 7 months ago

stevendarby commented 7 months ago

Description

If I use BindInputAsync<T>(metadata) where T is a base type of the actual input type, the input deserializes successfully, but the function later fails to invoke with an error saying it cannot cast from the base class to the actual class. I think this is because BindInputAsync<T> caches the binding result regardless of what T is.

I want to enforce a common base class for ActivityTrigger inputs so that I can have middleware that can read the input for any ActivityTrigger.

See the repro below. If you change the middleware to call BindInputAsync<MessageA>(), it all works, but then the middleware isn't general enough.
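To make the suspected mechanism concrete, here is a minimal, self-contained sketch that models a binding cache keyed only by binding name. `NameKeyedBindingCache` and `Demo` are hypothetical names, and the cache shape is an assumption based on the behaviour described above, not the worker's actual implementation: the first call decides the runtime type of the cached object, and a later request for a more derived type fails the cast.

```csharp
using System;
using System.Collections.Generic;

public class MessageBase { public int TenantId { get; set; } }
public class MessageA : MessageBase { public string Name { get; set; } = ""; }

// Hypothetical model of the reported behaviour; not the worker's actual code.
public sealed class NameKeyedBindingCache
{
    // Keyed by binding name only: the requested type is ignored on later calls.
    private readonly Dictionary<string, object> _cache = new();

    public T Bind<T>(string bindingName) where T : new()
    {
        if (!_cache.TryGetValue(bindingName, out var value))
        {
            // The first caller decides the runtime type of the cached object.
            value = new T();
            _cache[bindingName] = value;
        }

        // Throws InvalidCastException when an earlier call cached a less derived type.
        return (T)value;
    }
}

public static class Demo
{
    public static bool SecondBindFails()
    {
        var cache = new NameKeyedBindingCache();
        cache.Bind<MessageBase>("message");      // middleware binds as the base type first
        try
        {
            cache.Bind<MessageA>("message");     // function invocation then needs the declared type
            return false;
        }
        catch (InvalidCastException)
        {
            return true;
        }
    }
}
```

`Demo.SecondBindFails()` returns `true`: the middleware-style call caches a `MessageBase`, so the function-style request for `MessageA` throws `InvalidCastException`.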

Steps to reproduce

using Azure.Messaging;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Middleware;
using Microsoft.DurableTask;
using Microsoft.DurableTask.Client;
using Microsoft.Extensions.Hosting;
using Microsoft.Extensions.Logging;

await new HostBuilder()
    .ConfigureFunctionsWorkerDefaults(x => x.UseMiddleware<TenantMiddleware>())
    .Build()
    .RunAsync();

public class MessageBase
{
    public int TenantId { get; set; }
}

public class MessageA : MessageBase
{
    public string Name { get; set; }
}

internal class TenantMiddleware : IFunctionsWorkerMiddleware
{
    public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
    {
        int tenantId = 0;

        if (context.FunctionDefinition.InputBindings.Values.FirstOrDefault(x => x.Type == "activityTrigger") is { } activity)
        {
            // Binding as the base type here is what later breaks the cast to MessageA.
            var input = await context.BindInputAsync<MessageBase>(activity);
            tenantId = input.Value!.TenantId;
        }

        context.Items.Add("TenantId", tenantId);

        await next(context);
    }
}

public class CloudEventFunction
{
    private readonly ILogger<CloudEventFunction> _logger;

    public CloudEventFunction(ILogger<CloudEventFunction> logger)
    {
        _logger = logger;
    }

    [Function("CloudEventFunction")]
    public async Task RunCloudEventAsync(
        [EventGridTrigger] CloudEvent cloudEvent,
        [DurableClient] DurableTaskClient durableTaskClient)
    {
        await durableTaskClient.ScheduleNewOrchestrationInstanceAsync("OrchestratorFunction");
    }

    [Function("OrchestratorFunction")]
    public async Task RunOrchestratorAsync([OrchestrationTrigger] TaskOrchestrationContext taskOrchestrationContext)
    {
        var message = new MessageA { TenantId = 42, Name = "Test" };
        await taskOrchestrationContext.CallActivityAsync("ActivityFunction", message);
    }

    [Function("ActivityFunction")]
    public Task RunActivityAsync([ActivityTrigger] MessageA message, FunctionContext context)
    {
        _logger.LogInformation("TenantId: {TenantId}", context.Items["TenantId"]);
        return Task.CompletedTask;
    }
}
stevendarby commented 7 months ago

Workaround: change the middleware to create and cache a generic function specific to each concrete parameter type:

using System.Collections.Concurrent;
using System.Linq.Expressions;
using System.Reflection;
using Microsoft.Azure.Functions.Worker;
using Microsoft.Azure.Functions.Worker.Middleware;

internal class TenantMiddleware : IFunctionsWorkerMiddleware
{
    // One compiled delegate per concrete parameter type, built on first use.
    private readonly ConcurrentDictionary<Type, Func<FunctionContext, BindingMetadata, Task<MessageBase?>>> _getMessageFuncs = new();

    public async Task Invoke(FunctionContext context, FunctionExecutionDelegate next)
    {
        int tenantId = 0;

        if (context.FunctionDefinition.InputBindings.Values.FirstOrDefault(x => x.Type == "activityTrigger") is { } activity &&
            context.FunctionDefinition.Parameters.FirstOrDefault(x => x.Name == activity.Name) is { } param)
        {
            var getMessageAsync = _getMessageFuncs.GetOrAdd(
                param.Type,
                type =>
                {
                    // Close GetMessageAsync<T> over the declared parameter type, so BindInputAsync
                    // caches an instance of the type the function actually expects.
                    var method = typeof(TenantMiddleware)
                        .GetMethod(nameof(GetMessageAsync), BindingFlags.Static | BindingFlags.NonPublic)!
                        .MakeGenericMethod(type);

                    var param0 = Expression.Parameter(typeof(FunctionContext));
                    var param1 = Expression.Parameter(typeof(BindingMetadata));
                    var call = Expression.Call(method, param0, param1);

                    return Expression.Lambda<Func<FunctionContext, BindingMetadata, Task<MessageBase?>>>(call, param0, param1).Compile();
                });

            var input = await getMessageAsync(context, activity);
            tenantId = input?.TenantId ?? 0;
        }

        context.Items.Add("TenantId", tenantId);

        await next(context);
    }

    private static async Task<MessageBase?> GetMessageAsync<T>(FunctionContext context, BindingMetadata bindingMetadata) where T : MessageBase
    {
        var inputBindingData = await context.BindInputAsync<T>(bindingMetadata);
        return inputBindingData.Value;
    }
}
jviau commented 7 months ago

@stevendarby This is expected behavior. By calling BindInputAsync<BaseClass>(...) you have instantiated that input as the base class. When we later try to invoke the function, the built input type and the expected type no longer match, and it fails. What you could do instead is the inverse: have your trigger accept the base class, and then call BindInputAsync with a more derived type, because then the cast will succeed. To put it another way, you are trying to do the following:

object input = new Base();
MyFunction((Derived)input); // throws InvalidCastException at run time

static void MyFunction(Derived input) { }

public class Base { }
public class Derived : Base { }

The above will obviously fail because input is not castable to Derived. This is what is happening in the input converter: you explicitly asked the functions binder to instantiate that input as the base type, when your function signature expects the derived type. This will always result in a runtime failure when we try to cast to the expected type.
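The asymmetry described here can be checked directly with the Base and Derived classes from the example. In this minimal sketch (`CastDirection` is a hypothetical helper name), an object created as `Derived` satisfies a `Base`-typed parameter, which is why binding the more derived type while the function accepts the base class works, while the reverse never does.

```csharp
public class Base { }
public class Derived : Base { }

// Hypothetical helper illustrating which direction of cast can succeed.
public static class CastDirection
{
    // Created as Derived, requested as Base: the function receives it fine.
    public static bool DerivedInstanceFitsBaseParameter()
    {
        object input = new Derived();
        return input is Base;
    }

    // Created as Base, requested as Derived: always fails, as in the issue.
    public static bool BaseInstanceFitsDerivedParameter()
    {
        object input = new Base();
        return input is Derived;
    }
}
```

The catch, as the reply below points out, is that generic middleware usually cannot know which derived type to request.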

stevendarby commented 7 months ago

@jviau I understand why the cast fails. Was that not clear from my issue? I'm saying that the current behaviour of BindInputAsync (caching and reusing the first result per input binding metadata, regardless of which generic T is passed) is bad.

If you re-read my original post, it has a use case for calling BindInputAsync with the base class in middleware, and explains why the suggestion to use the derived type is no good: middleware that serves multiple functions with different input types does not know the correct generic type argument. We have to resort to reflection (see my second post above, and my comment below).

The first point to make is that it is not at all clear that calling BindInputAsync in middleware has this side effect, one that can break the function invocation later. The documentation could be clearer about this up front, and the exception could suggest why the failure might be happening and make it clear that this usage isn't supported.

But I wouldn't just document this better and stop there. Why not improve the behaviour?

The caching could be done per typeof(T), for example.
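A minimal sketch of that suggestion, assuming the cache key is simply extended from the binding name to the pair (binding name, requested type). `TypeAwareBindingCache` and its `convert` callback are hypothetical stand-ins for the worker's internals, not its real API.

```csharp
#nullable enable
using System;
using System.Collections.Generic;

public class MessageBase { public int TenantId { get; set; } }
public class MessageA : MessageBase { public string Name { get; set; } = ""; }

// Hypothetical sketch of the "cache per typeof(T)" idea; not the worker's actual code.
public sealed class TypeAwareBindingCache
{
    // The key includes the requested type, so binding as MessageBase can never
    // shadow a later request for MessageA on the same binding.
    private readonly Dictionary<(string BindingName, Type TargetType), object?> _cache = new();

    // Number of times the (potentially expensive) conversion actually ran.
    public int ConversionCount { get; private set; }

    public T? Bind<T>(string bindingName, Func<Type, object?> convert) where T : class
    {
        var key = (bindingName, typeof(T));
        if (!_cache.TryGetValue(key, out var value))
        {
            // Cache miss: convert the raw input to exactly the requested type.
            ConversionCount++;
            value = convert(typeof(T));
            _cache[key] = value;
        }

        // Assuming convert honours the requested type, the cached object was created for this T.
        return (T?)value;
    }
}
```

With this key, binding "message" as `MessageBase` and then as `MessageA` runs the conversion twice and yields two independent objects, while a repeated request for the same type is served from the cache. The trade-off is that the payload may be deserialized once per distinct type, and changes the middleware makes to its `MessageBase` instance are not visible to the function.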

Or perhaps the FunctionContext could contain the input values, already bound to their declared types (it's going to do this later anyway), accessible via an object[] property or similar, which developers could then try casting to any type, e.g. if (context.InputValues[0] is MessageBase msg) { ... }.
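Under that proposal the middleware would not need reflection at all, because ordinary pattern matching handles any subclass. A minimal sketch, where the `inputValues` parameter stands in for the proposed (and hypothetical) `InputValues` property and `TenantLookup` is an illustrative name:

```csharp
#nullable enable
using System.Collections.Generic;

public class MessageBase { public int TenantId { get; set; } }
public class MessageA : MessageBase { public string Name { get; set; } = ""; }

public static class TenantLookup
{
    // inputValues stands in for the hypothetical context.InputValues: inputs
    // already bound to the parameter types declared on the function.
    public static int GetTenantId(IReadOnlyList<object?> inputValues)
    {
        foreach (var value in inputValues)
        {
            // Matches MessageBase and every subclass, so the middleware never
            // needs to know the concrete parameter type.
            if (value is MessageBase message)
            {
                return message.TenantId;
            }
        }

        // No tenant-bearing input on this function.
        return 0;
    }
}
```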

I want to expand what is possible with function middleware, which is clearly in its early days :)

stevendarby commented 7 months ago

Here is the reusable, generic extension method that I'm using instead of BindInputAsync directly, for anyone who may be interested.

using System.Reflection;
using Microsoft.Azure.Functions.Worker;

public static class FunctionContextExtensions
{
    private static readonly MethodInfo BindInputAsyncMethod
        = typeof(FunctionContextExtensions).GetMethod(nameof(BindInputAsync), BindingFlags.Static | BindingFlags.NonPublic)!;

    public static async Task<T?> GetInputValueAsync<T>(this FunctionContext context, BindingMetadata bindingMetadata)
    {
        var parameter = context.FunctionDefinition.Parameters
            .FirstOrDefault(parameter => parameter.Name == bindingMetadata.Name && parameter.Type.IsAssignableTo(typeof(T)));

        if (parameter is null)
        {
            return default;
        }

        // BindInputAsync deserializes the input to T and caches the result against the binding metadata.
        // This causes an issue if T is e.g. a base class, because the function invocation later binds
        // the same input as the actual parameter type and can't cast the cached result
        // (e.g. the base class object) to that type. Therefore, if T is not exactly the parameter type, we use
        // reflection to call BindInputAsync with the actual parameter type, then cast the result to T.
        if (parameter.Type == typeof(T))
        {
            var result = await context.BindInputAsync<T>(bindingMetadata);
            return result.Value;
        }
        else
        {
            var method = BindInputAsyncMethod.MakeGenericMethod(parameter.Type);
            var value = await (Task<object?>)method.Invoke(null, [context, bindingMetadata])!;
            return (T?)value;
        }
    }

    private static async Task<object?> BindInputAsync<T>(FunctionContext context, BindingMetadata bindingMetadata)
    {
        var result = await context.BindInputAsync<T>(bindingMetadata);
        return result.Value;
    }
}