dotnet / reactive

The Reactive Extensions for .NET
http://reactivex.io
MIT License
6.73k stars 751 forks source link

Fix for: The ToObservable operator can throw an unhandled exception #2171

Closed fedeAlterio closed 1 month ago

fedeAlterio commented 1 month ago

This is a follow-up to #1677.

I'd like to fix this bug, but I don't know the procedure for creating a PR. Below is the fixed code, along with a test that fails against the current Rx.NET code. How should I proceed?

Thank you!

// Licensed to the .NET Foundation under one or more agreements.
// The .NET Foundation licenses this file to you under the MIT License.
// See the LICENSE file in the project root for more information. 

using System.Collections.Generic;

namespace System.Linq
{
    public static partial class AsyncEnumerable
    {
        /// <summary>
        /// Converts an async-enumerable sequence to an observable sequence.
        /// </summary>
        /// <typeparam name="TSource">The type of the elements in the source sequence.</typeparam>
        /// <param name="source">Enumerable sequence to convert to an observable sequence.</param>
        /// <returns>The observable sequence whose elements are pulled from the given enumerable sequence.</returns>
        /// <exception cref="ArgumentNullException"><paramref name="source"/> is null.</exception>
        public static IObservable<TSource> ToObservable<TSource>(this IAsyncEnumerable<TSource> source)
        {
            if (source == null)
                throw Error.ArgumentNull(nameof(source));

            return new ToObservableObservable<TSource>(source);
        }

        /// <summary>
        /// Observable adapter that, for every subscription, enumerates the source
        /// asynchronously and forwards its elements to the subscribed observer.
        /// </summary>
        private sealed class ToObservableObservable<T> : IObservable<T>
        {
            private readonly IAsyncEnumerable<T> _source;

            public ToObservableObservable(IAsyncEnumerable<T> source)
            {
                _source = source;
            }

            /// <summary>
            /// Starts a fire-and-forget enumeration of the source on behalf of <paramref name="observer"/>.
            /// </summary>
            /// <param name="observer">Receives the elements followed by at most one terminal notification.</param>
            /// <returns>
            /// A disposable whose disposal cancels the token handed to the enumerator and
            /// suppresses any notifications that have not been delivered yet.
            /// </returns>
            public IDisposable Subscribe(IObserver<T> observer)
            {
                var ctd = new CancellationTokenDisposable();

                // Core is async void: an exception escaping it is unhandled (raised on the
                // captured SynchronizationContext or the thread pool). Every failure that
                // originates in the source - GetAsyncEnumerator, MoveNextAsync, Current or
                // DisposeAsync - must therefore be routed to observer.OnError instead.
                async void Core()
                {
                    IAsyncEnumerator<T> e;

                    try
                    {
                        e = _source.GetAsyncEnumerator(ctd.Token);
                    }
                    catch (Exception ex)
                    {
                        if (!ctd.Token.IsCancellationRequested)
                        {
                            observer.OnError(ex);
                        }

                        return;
                    }

                    // The terminal notification is deferred until the enumerator has been
                    // disposed, so that a failing DisposeAsync can still be reported.
                    Exception? error = null;
                    var completed = false;

                    try
                    {
                        do
                        {
                            bool hasNext;
                            var value = default(T)!;

                            try
                            {
                                hasNext = await e.MoveNextAsync().ConfigureAwait(false);
                                if (hasNext)
                                {
                                    value = e.Current;
                                }
                            }
                            catch (Exception ex)
                            {
                                error = ex;
                                break;
                            }

                            if (!hasNext)
                            {
                                completed = true;
                                break;
                            }

                            // The subscription may have been disposed while MoveNextAsync
                            // was pending; do not notify a disposed observer.
                            if (ctd.Token.IsCancellationRequested)
                            {
                                break;
                            }

                            // Exceptions thrown by the observer itself are deliberately not
                            // caught; they propagate after the enumerator is disposed below.
                            observer.OnNext(value);
                        }
                        while (!ctd.Token.IsCancellationRequested);
                    }
                    finally
                    {
                        try
                        {
                            await e.DisposeAsync().ConfigureAwait(false);
                        }
                        catch (Exception ex)
                        {
                            // Only one terminal notification may be sent; an earlier
                            // MoveNextAsync failure takes precedence over this one.
                            error ??= ex;
                        }
                    }

                    if (ctd.Token.IsCancellationRequested)
                    {
                        return;
                    }

                    if (error != null)
                    {
                        observer.OnError(error);
                    }
                    else if (completed)
                    {
                        observer.OnCompleted();
                    }
                }

                // Fire and forget
                Core();

                return ctd;
            }
        }
    }
}
 [Fact]
 public void ToObservable_ShouldForwardExceptionOnGetEnumeratorAsync()
 {
     // Arrange: a source whose GetAsyncEnumerator throws synchronously.
     var exception = new Exception("Exception message");
     Exception? receivedException = null;
     var enumerable = AsyncEnumerable.Create<int>(_ => throw exception);
     using var evt = new ManualResetEvent(false);

     // Act: any notification releases the wait; only OnError records the exception.
     var observable = enumerable.ToObservable();
     using var subscription = observable.Subscribe(new MyObserver<int>(
         _ => evt.Set(),
         e =>
         {
             receivedException = e;
             evt.Set();
         },
         () => evt.Set()));

     // Assert: bounded wait so that a regression fails the test instead of hanging the run.
     Assert.True(evt.WaitOne(TimeSpan.FromSeconds(5)), "The observer received no notification.");
     Assert.NotNull(receivedException);
     Assert.Equal(exception.Message, receivedException!.Message);
 }