dotnet / extensions

This repository contains a suite of libraries that provide facilities commonly needed when creating production-ready applications.
MIT License
2.63k stars 751 forks source link

[New Feature Proposal] Add support for HttpRequestMessage objects containing StreamContent to AddStandardHedgingHandler() #5105

Open adamhammond opened 6 months ago

adamhammond commented 6 months ago

Background and motivation

Many clients that use the AddStandardHedgingHandler() resilience API, which is built on top of Polly v8, have requirements that force them to send HttpRequestMessage objects containing StreamContent. Today, if a client built from an IHttpClientBuilder that was configured with a resilience handler via AddStandardHedgingHandler() sends an HttpRequestMessage whose content is StreamContent to a downstream service, an InvalidOperationException is thrown. The exception comes from the Initialize() method in RequestMessageSnapshot.cs: https://github.com/dotnet/extensions/blob/10681a1cdb1e044b05341150203b94d5eec41557/src/Libraries/Microsoft.Extensions.Http.Resilience/Internal/RequestMessageSnapshot.cs#L73-L80

There is no reason to restrict users to a subset of HttpContent types. Support can be added for HttpRequestMessage objects whose content is StreamContent, while the existing shallow-copy logic remains the default for content of any other HttpContent type. This keeps the change fully backward compatible with the existing APIs and has no side effects for current users of the AddStandardHedgingHandler() API.

Feature Proposal

Proposed Changes to RequestMessageSnapshot.cs:

/// <summary>
/// Captures the parts of an <see cref="HttpRequestMessage"/> needed to create independent copies of it,
/// so that the same request can be sent more than once (for example, by hedged attempts).
/// </summary>
/// <remarks>
/// Instances are pooled: obtain one with <see cref="CreateAsync(HttpRequestMessage)"/> and dispose it to return it to the pool.
/// A <see cref="StreamContent"/> body is read once into an in-memory buffer, and every copy (including the original
/// request) reads from that shared, read-only buffer through its own stream. Any other content type is shared by
/// reference between all copies (shallow copy).
/// </remarks>
internal sealed class RequestMessageSnapshot : IResettable, IDisposable
{
    private static readonly ObjectPool<RequestMessageSnapshot> _snapshots = PoolFactory.CreateResettingPool<RequestMessageSnapshot>();

    private readonly List<KeyValuePair<string, IEnumerable<string>>> _headers = [];
    private readonly List<KeyValuePair<string, object?>> _properties = [];

    // Content headers of a buffered StreamContent; re-applied to every copy of the body.
    private readonly List<KeyValuePair<string, IEnumerable<string>>> _contentHeaders = [];

    private HttpMethod? _method;
    private Uri? _requestUri;
    private Version? _version;

    // Non-stream content, shared by reference between the original request and all copies.
    private HttpContent? _content;

    // Buffered StreamContent body. It is written only by InitializeAsync and never modified afterwards,
    // so concurrent CreateRequestMessageAsync calls can share it without synchronization.
    private byte[]? _contentBuffer;

    /// <summary>
    /// Rents a snapshot from the pool and initializes it from <paramref name="request"/>.
    /// </summary>
    /// <param name="request">The request to capture. If its content is <see cref="StreamContent"/>, the body is read
    /// into memory and <paramref name="request"/>'s content is replaced with a replayable copy.</param>
    /// <returns>The initialized snapshot; dispose it to return it to the pool.</returns>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
    public static async Task<RequestMessageSnapshot> CreateAsync(HttpRequestMessage request)
    {
        _ = Throw.IfNull(request);

        var snapshot = _snapshots.Get();
        try
        {
            await snapshot.InitializeAsync(request).ConfigureAwait(false);
        }
        catch
        {
            // The caller never receives a snapshot that failed to initialize, so it would never be disposed;
            // return it here (the pool resets it) before propagating the error.
            _snapshots.Return(snapshot);
            throw;
        }

        return snapshot;
    }

    /// <summary>
    /// Creates a new <see cref="HttpRequestMessage"/> populated from this snapshot.
    /// </summary>
    /// <remarks>
    /// A buffered body is exposed through a new stream per call, so the returned messages do not share a read
    /// position and can be sent concurrently. The snapshot itself is not modified.
    /// </remarks>
    /// <returns>A completed task holding the new request message, or a faulted task with an
    /// <see cref="InvalidOperationException"/> if the snapshot has been reset and not re-initialized.</returns>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
    public Task<HttpRequestMessage> CreateRequestMessageAsync()
    {
        if (IsReset())
        {
            return Task.FromException<HttpRequestMessage>(
                new InvalidOperationException($"{nameof(CreateRequestMessageAsync)}() cannot be called on a snapshot object that has been reset and has not been initialized"));
        }

        var clone = new HttpRequestMessage(_method!, _requestUri)
        {
            Version = _version!,
            Content = _contentBuffer is null ? _content : CreateBufferedContent(),
        };

#if NET5_0_OR_GREATER
        foreach (var prop in _properties)
        {
            _ = clone.Options.TryAdd(prop.Key, prop.Value);
        }
#else
        foreach (var prop in _properties)
        {
            clone.Properties.Add(prop);
        }
#endif
        foreach (KeyValuePair<string, IEnumerable<string>> header in _headers)
        {
            _ = clone.Headers.TryAddWithoutValidation(header.Key, header.Value);
        }

        return Task.FromResult(clone);
    }

    [System.Diagnostics.CodeAnalysis.SuppressMessage("Critical Bug", "S2952:Classes should \"Dispose\" of members from the classes' own \"Dispose\" methods", Justification = "Handled by ObjectPool")]
    bool IResettable.TryReset()
    {
        _properties.Clear();
        _headers.Clear();
        _contentHeaders.Clear();

        _method = null;
        _version = null;
        _requestUri = null;

        // Nothing here needs disposing: _content is the caller's content object, and the buffer is a plain array
        // (the streams wrapping it belong to the request messages that were created from it).
        _content = null;
        _contentBuffer = null;

        return true;
    }

    void IDisposable.Dispose() => _snapshots.Return(this);

    /// <summary>
    /// Reads the body of <paramref name="content"/>, from its current position to the end, into a new array.
    /// </summary>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
    private static async Task<byte[]> ReadAllBytesAsync(HttpContent content)
    {
        // The stream belongs to 'content', so it is not disposed here.
        Stream body = await content.ReadAsStreamAsync().ConfigureAwait(false);
        using var buffer = new MemoryStream();
        await body.CopyToAsync(buffer).ConfigureAwait(false);
        return buffer.ToArray();
    }

    /// <summary>
    /// Wraps the buffered body in a new <see cref="StreamContent"/> that carries the captured content headers.
    /// </summary>
    private StreamContent CreateBufferedContent()
    {
        // A non-writable MemoryStream over the shared array gives each message its own read position
        // without copying the bytes again.
        var content = new StreamContent(new MemoryStream(_contentBuffer!, writable: false));
        foreach (KeyValuePair<string, IEnumerable<string>> header in _contentHeaders)
        {
            _ = content.Headers.TryAddWithoutValidation(header.Key, header.Value);
        }

        return content;
    }

    private bool IsReset() => _method == null;

    /// <summary>
    /// Captures method, version, URI, content, headers and properties of <paramref name="request"/>.
    /// </summary>
    [System.Diagnostics.CodeAnalysis.SuppressMessage("Resilience", "EA0014:The async method doesn't support cancellation", Justification = "Past the point of no cancellation.")]
    private async Task InitializeAsync(HttpRequestMessage request)
    {
        _method = request.Method;
        _version = request.Version;
        _requestUri = request.RequestUri;

        if (request.Content is StreamContent streamContent)
        {
            _contentBuffer = await ReadAllBytesAsync(streamContent).ConfigureAwait(false);
            _contentHeaders.AddRange(streamContent.Headers);

            // Reading may have consumed a forward-only stream, so the original request gets its own
            // replayable view of the buffered body.
            request.Content = CreateBufferedContent();
        }
        else
        {
            _content = request.Content;
        }

        // headers
        _headers.AddRange(request.Headers);

        // props
#if NET5_0_OR_GREATER
        _properties.AddRange(request.Options);
#else
        _properties.AddRange(request.Properties);
#endif
    }
}

Proposed Changes to ResilienceHttpClientBuilderExtensions.Hedging.cs:

/// <summary>
/// Adds a standard hedging handler to the HTTP client.
/// </summary>
/// <remarks>
/// Two resilience handlers are registered:
/// an outer handler (routing strategy, request snapshot, total request timeout, hedging), and
/// an inner handler (endpoint rate limiter, circuit breaker, attempt timeout) whose pipeline is selected via
/// <c>SelectPipelineByAuthority()</c>.
/// Each hedged attempt sends a new request message created from the snapshot captured by the outer handler.
/// </remarks>
/// <param name="builder">The HTTP client builder.</param>
/// <returns>A builder for further configuring the standard hedging handler.</returns>
public static IStandardHedgingHandlerBuilder AddStandardHedgingHandler(this IHttpClientBuilder builder)
{
    _ = Throw.IfNull(builder);

    var optionsName = builder.Name;
    var routingBuilder = new RoutingStrategyBuilder(builder.Name, builder.Services);

    builder.Services.TryAddSingleton<Randomizer>();

    _ = builder.Services.AddOptionsWithValidateOnStart<HttpStandardHedgingResilienceOptions, HttpStandardHedgingResilienceOptionsValidator>(optionsName);
    _ = builder.Services.AddOptionsWithValidateOnStart<HttpStandardHedgingResilienceOptions, HttpStandardHedgingResilienceOptionsCustomValidator>(optionsName);
    _ = builder.Services.PostConfigure<HttpStandardHedgingResilienceOptions>(optionsName, options =>
    {
        options.Hedging.ActionGenerator = args =>
        {
            if (!args.PrimaryContext.Properties.TryGetValue(ResilienceKeys.RequestSnapshot, out var snapshot))
            {
                Throw.InvalidOperationException("Request message snapshot is not attached to the resilience context.");
            }

            // If a routing strategy has been configured but it does not return the next route, then no more routes
            // are available, so stop hedging.
            Uri? route;
            if (args.PrimaryContext.Properties.TryGetValue(ResilienceKeys.RoutingStrategy, out var routingPipeline))
            {
                if (!routingPipeline.TryGetNextRoute(out route))
                {
                    return null;
                }
            }
            else
            {
                route = null;
            }

            return async () =>
            {
                HttpRequestMessage requestMessage;
                try
                {
                    requestMessage = await snapshot.CreateRequestMessageAsync().ConfigureAwait(false);
                }
                catch (IOException e)
                {
                    // A failure to copy the request body becomes this attempt's outcome instead of an unhandled exception.
                    return Outcome.FromException<HttpResponseMessage>(e);
                }

                // The secondary request message should use the action resilience context.
                requestMessage.SetResilienceContext(args.ActionContext);

                // Replace the request message.
                args.ActionContext.Properties.Set(ResilienceKeys.RequestMessage, requestMessage);

                if (route != null)
                {
                    // Replace the RequestUri of the request per the routing strategy.
                    requestMessage.RequestUri = requestMessage.RequestUri!.ReplaceHost(route);
                }

                return await args.Callback(args.ActionContext).ConfigureAwait(args.ActionContext.ContinueOnCapturedContext);
            };
        };
    });

    // Configure the outer handler.
    _ = builder.AddResilienceHandler(HedgingConstants.HandlerPostfix, (pipelineBuilder, context) =>
    {
        var options = context.GetOptions<HttpStandardHedgingResilienceOptions>(optionsName);
        context.EnableReloads<HttpStandardHedgingResilienceOptions>(optionsName);
        var routingOptions = context.GetOptions<RequestRoutingOptions>(routingBuilder.Name);

        _ = pipelineBuilder
            .AddStrategy(_ => new RoutingResilienceStrategy(routingOptions.RoutingStrategyProvider))
            .AddStrategy(_ => new RequestMessageSnapshotStrategy())
            .AddTimeout(options.TotalRequestTimeout)
            .AddHedging(options.Hedging);
    });

    // Configure the inner handler.
    _ = builder
        .AddResilienceHandler(
            HedgingConstants.InnerHandlerPostfix,
            (pipelineBuilder, context) =>
            {
                var options = context.GetOptions<HttpStandardHedgingResilienceOptions>(optionsName);
                context.EnableReloads<HttpStandardHedgingResilienceOptions>(optionsName);

                _ = pipelineBuilder
                    .AddRateLimiter(options.Endpoint.RateLimiter)
                    .AddCircuitBreaker(options.Endpoint.CircuitBreaker)
                    .AddTimeout(options.Endpoint.Timeout);
            })
        .SelectPipelineByAuthority();

    return new StandardHedgingHandlerBuilder(builder.Name, builder.Services, routingBuilder);
}
...

Proposed Changes to RequestMessageSnapshotStrategy.cs:

/// <summary>
/// Captures a snapshot of the current request before running the rest of the pipeline.
/// </summary>
/// <remarks>
/// The snapshot is stored in the resilience context under <c>ResilienceKeys.RequestSnapshot</c> for the duration of
/// <paramref name="callback"/> and is disposed when the callback completes. An <see cref="IOException"/> raised while
/// the snapshot is created or while the callback runs is returned as an exception outcome.
/// </remarks>
/// <exception cref="InvalidOperationException">No request message is stored in <paramref name="context"/>.</exception>
protected override async ValueTask<Outcome<TResult>> ExecuteCore<TResult, TState>(
    Func<ResilienceContext, TState, ValueTask<Outcome<TResult>>> callback,
    ResilienceContext context,
    TState state)
{
    var found = context.Properties.TryGetValue(ResilienceKeys.RequestMessage, out var requestMessage);
    if (!found || requestMessage is null)
    {
        Throw.InvalidOperationException("The HTTP request message was not found in the resilience context.");
    }

    try
    {
        using var requestSnapshot = await RequestMessageSnapshot
            .CreateAsync(requestMessage)
            .ConfigureAwait(context.ContinueOnCapturedContext);

        // Make the snapshot available to strategies further down the pipeline.
        context.Properties.Set(ResilienceKeys.RequestSnapshot, requestSnapshot);

        var outcome = await callback(context, state).ConfigureAwait(context.ContinueOnCapturedContext);
        return outcome;
    }
    catch (IOException ioException)
    {
        return Outcome.FromException<TResult>(ioException);
    }
}
...
joperezr commented 6 months ago

cc: @iliar-turdushev Can you please take a look at this?